def train(self, mixture_or_task_name, steps, init_checkpoint=None):
  """Train the model on the given Mixture or Task.

  Args:
    mixture_or_task_name: str, the name of the Mixture or Task to train on.
      Must be pre-registered in the global `TaskRegistry` or
      `MixtureRegistry`.
    steps: int, the total number of steps to train for.
    init_checkpoint: a string, if not None then read in variables from this
      checkpoint path when initializing variables. Will only initialize
      variables that appear both in the current graph and the checkpoint.
  """
  vocabulary = t5.data.get_mixture_or_task(
      mixture_or_task_name).get_vocabulary()
  dataset_fn = functools.partial(
      mesh_train_dataset_fn, mixture_or_task_name=mixture_or_task_name)
  utils.train_model(self.estimator(vocabulary, init_checkpoint), vocabulary,
                    self._sequence_length, self.batch_size, dataset_fn,
                    steps, self._ensemble_inputs)
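# Hedged usage sketch: the enclosing class is not shown in this excerpt, so
# the constructor arguments below are assumptions based on the attributes the
# method uses (self._sequence_length, self.batch_size); the model_dir and
# task name are illustrative.
#
#   model = MtfModel(
#       model_dir="/tmp/t5-model", tpu=None,
#       sequence_length={"inputs": 512, "targets": 128}, batch_size=16)
#   model.train("glue_mrpc_v002", steps=1000)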
@gin.configurable()
def mesh_eval_dataset_fn(mixture_or_task_name, sequence_length, vocabulary,
                         dataset_split, num_eval_examples=None,
                         use_cached=False):
  """Returns all tf.data.Datasets for evaluation on a given mixture or task.

  Args:
    mixture_or_task_name: string, an identifier for a Mixture or Task in the
      appropriate registry. Must be specified via gin.
    sequence_length: dict mapping each feature key to its max int sequence
      length.
    vocabulary: a SentencePieceVocabulary.
    dataset_split: string, which split of the dataset to load.
    num_eval_examples: maximum number of examples per task to use for
      continuous eval. If None, use all examples.
    use_cached: bool, whether to load the cached version of this dataset.

  Returns:
    A list of mesh_tensorflow.transformer.dataset.EvalDataset tuples.
  """
  if not isinstance(vocabulary, t5.data.SentencePieceVocabulary):
    raise ValueError("vocabulary must be a SentencePieceVocabulary")

  mixture_or_task = t5.data.get_mixture_or_task(mixture_or_task_name)

  def _get_dataset_for_single_task(task):
    """Get a tensorflow.data.Dataset for the provided task."""
    ds = task.get_dataset(
        sequence_length, split=dataset_split,
        use_cached=use_cached, shuffle=False)
    ds = transformer_dataset.pack_or_pad(
        ds, sequence_length, pack=False, feature_keys=task.output_features,
        ensure_eos=True)
    if num_eval_examples is not None:
      ds = ds.take(num_eval_examples)
    return ds

  # Build one EvalDataset per subtask, skipping tasks without this split.
  outputs = []
  for task in t5.data.get_subtasks(mixture_or_task):
    if dataset_split not in task.splits:
      logging.info("Task %s has no '%s' split, skipping eval.",
                   task.name, dataset_split)
      continue
    outputs.append(
        transformer_dataset.EvalDataset(
            task.name,
            functools.partial(_get_dataset_for_single_task, task),
            task.postprocess_fn,
            task.metric_fns))
  return outputs
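# Hedged usage sketch: in normal operation this function is wired in via gin
# and functools.partial (see `eval` below) rather than called directly. A
# direct call would look roughly like the following; the task name and
# sequence lengths are illustrative, and t5.data.get_default_vocabulary() is
# assumed to return a SentencePieceVocabulary.
#
#   eval_datasets = mesh_eval_dataset_fn(
#       mixture_or_task_name="glue_mrpc_v002",
#       sequence_length={"inputs": 512, "targets": 128},
#       vocabulary=t5.data.get_default_vocabulary(),
#       dataset_split="validation",
#       num_eval_examples=100)
#   for eval_dataset in eval_datasets:
#     print(eval_dataset.name)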
def eval(self, mixture_or_task_name, checkpoint_steps=None, summary_dir=None,
         split="validation"):
  """Evaluate the model on the given Mixture or Task.

  Args:
    mixture_or_task_name: str, the name of the Mixture or Task to evaluate
      on. Must be pre-registered in the global `TaskRegistry` or
      `MixtureRegistry`.
    checkpoint_steps: int, list of ints, or None. If an int or list of ints,
      evaluation will be run on the checkpoint files in `model_dir` whose
      global steps are closest to the global steps provided. If None, run
      eval continuously, waiting for new checkpoints. If -1, get the latest
      checkpoint from the model directory.
    summary_dir: str, path to write TensorBoard events file summaries for
      eval. If None, use model_dir/eval_{split}.
    split: str, the split to evaluate on.
  """
  if checkpoint_steps == -1:
    checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)
  vocabulary = t5.data.get_mixture_or_task(
      mixture_or_task_name).get_vocabulary()
  dataset_fn = functools.partial(
      mesh_eval_dataset_fn, mixture_or_task_name=mixture_or_task_name)
  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(self._model_dir))
  utils.eval_model(self.estimator(vocabulary), vocabulary,
                   self._sequence_length, self.batch_size, split,
                   self._model_dir, dataset_fn, summary_dir, checkpoint_steps)
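# Hedged usage sketch, continuing the hypothetical `model` from the train
# example above: evaluate the newest checkpoint on the validation split.
#
#   model.eval("glue_mrpc_v002", checkpoint_steps=-1, split="validation")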
def get_sentencepiece_model_path(mixture_or_task_name):
  """Returns the SentencePiece model path for the given Mixture or Task."""
  return t5.data.get_mixture_or_task(
      mixture_or_task_name).sentencepiece_model_path
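# Illustrative call (the task name is hypothetical); the returned path can be
# passed wherever a SentencePiece model file is expected:
#
#   spm_path = get_sentencepiece_model_path("glue_mrpc_v002")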
def main(_):
  if FLAGS.module_import:
    for module in FLAGS.module_import:
      importlib.import_module(module)

  if FLAGS.t5_tfds_data_dir:
    t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)
  t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

  # Add search path for gin files stored in package.
  gin.add_config_file_search_path(
      pkg_resources.resource_filename(__name__, "gin"))

  # Record the launch command in the model directory without overwriting
  # records from previous runs.
  tf.io.gfile.makedirs(FLAGS.model_dir)
  suffix = 0
  command_filename = os.path.join(FLAGS.model_dir, "command")
  while tf.io.gfile.exists(command_filename):
    suffix += 1
    command_filename = os.path.join(
        FLAGS.model_dir, "command.{}".format(suffix))
  with tf.io.gfile.GFile(command_filename, "w") as f:
    f.write(" ".join(sys.argv))
  utils.parse_gin_defaults_and_flags()
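# The while loop above yields "command" on the first run, then "command.1",
# "command.2", and so on. A standalone sketch of the same naming scheme,
# using the local filesystem instead of tf.io.gfile:
#
#   import os
#
#   def unique_command_filename(model_dir):
#     path = os.path.join(model_dir, "command")
#     suffix = 0
#     while os.path.exists(path):
#       suffix += 1
#       path = os.path.join(model_dir, "command.{}".format(suffix))
#     return path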