How to use the t5.data module in t5

To help you get started, we’ve selected a few t5.data examples based on popular ways it is used in public projects.

github google-research/text-to-text-transfer-transformer/t5/models/mtf_model.py
def train(self, mixture_or_task_name, steps, init_checkpoint=None):
    """Train the model on the given Mixture or Task.

    Args:
      mixture_or_task_name: str, the name of the Mixture or Task to train on.
        Must be pre-registered in the global `TaskRegistry` or
        `MixtureRegistry`.
      steps: int, the total number of steps to train for.
      init_checkpoint: a string, if not None then read in variables from this
        checkpoint path when initializing variables. Will only initialize
        variables that appear both in the current graph and the checkpoint.
    """
    vocabulary = t5.data.get_mixture_or_task(
        mixture_or_task_name).get_vocabulary()
    dataset_fn = functools.partial(
        mesh_train_dataset_fn, mixture_or_task_name=mixture_or_task_name)
    utils.train_model(self.estimator(vocabulary, init_checkpoint), vocabulary,
                      self._sequence_length, self.batch_size, dataset_fn,
                      steps, self._ensemble_inputs)
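
For context, the t5.data lookup that train() leans on can be exercised on its own. A minimal sketch: the task name "glue_mrpc_v002" is an assumption (any name registered in your TaskRegistry or MixtureRegistry works), and importing t5.data.tasks is assumed to perform the standard registrations as a side effect.

import t5.data
import t5.data.tasks  # side effect: registers the standard T5 tasks

# Look up a registered Task/Mixture and fetch the vocabulary train() would use.
task = t5.data.get_mixture_or_task("glue_mrpc_v002")  # assumed task name
vocabulary = task.get_vocabulary()
print(vocabulary.vocab_size)
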
github google-research/text-to-text-transfer-transformer/t5/models/mesh_transformer.py
def mesh_eval_dataset_fn(
    mixture_or_task_name,
    sequence_length,
    vocabulary,
    dataset_split,
    num_eval_examples=None,
    use_cached=False):
  """Returns all tf.data.Datasets for evaluation on a given mixture or task.

  Args:
    mixture_or_task_name: string, an identifier for a Mixture or Task in the
      appropriate registry. Must be specified via gin.
    sequence_length: dict mapping each feature key to the int maximum
      sequence length for that feature.
    vocabulary: a SentencePieceVocabulary.
    dataset_split: string, which split of the dataset to load.
    num_eval_examples: maximum number of examples per task to use for continuous
      eval. If None, use all examples.
    use_cached: bool, whether to load the cached version of this dataset.

  Returns:
    A list of mesh_tensorflow.transformer.dataset.EvalDataset tuples.
  """
  if not isinstance(vocabulary, t5.data.SentencePieceVocabulary):
    raise ValueError("vocabulary must be a SentencePieceVocabulary")

  mixture_or_task = t5.data.get_mixture_or_task(mixture_or_task_name)

  def _get_dataset_for_single_task(task):
    """Get a tensorflow.data.Dataset for the provided task."""
    ds = task.get_dataset(
        sequence_length, split=dataset_split,
        use_cached=use_cached, shuffle=False
    )
    ds = transformer_dataset.pack_or_pad(
        ds, sequence_length, pack=False, feature_keys=task.output_features,
        ensure_eos=True)
    if num_eval_examples is not None:
      ds = ds.take(num_eval_examples)
    return ds

  outputs = []
  for task in t5.data.get_subtasks(mixture_or_task):
    if dataset_split not in task.splits:
      logging.info("Task %s has no '%s' split, skipping eval.",
                   task.name, dataset_split)
      continue
    # Wrap each sub-Task as an EvalDataset tuple for the Mesh TF eval loop.
    outputs.append(
        transformer_dataset.EvalDataset(
            task.name,
            functools.partial(_get_dataset_for_single_task, task),
            task.postprocess_fn,
            task.metric_fns))
  return outputs
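
The Task.get_dataset call wrapped by _get_dataset_for_single_task can also be driven directly. A minimal sketch, again assuming the registered task name "glue_mrpc_v002" and illustrative sequence lengths:

import t5.data
import t5.data.tasks  # side effect: registers the standard T5 tasks

task = t5.data.get_mixture_or_task("glue_mrpc_v002")  # assumed task name
ds = task.get_dataset(
    sequence_length={"inputs": 512, "targets": 128},  # illustrative lengths
    split="validation",
    use_cached=False,
    shuffle=False)
for ex in ds.take(1):
  print(ex["inputs"], ex["targets"])
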
github google-research/text-to-text-transfer-transformer/t5/models/mtf_model.py
  def eval(self, mixture_or_task_name, checkpoint_steps=None, summary_dir=None,
           split="validation"):
    """Evaluate the model on the given Mixture or Task.

    Args:
      mixture_or_task_name: str, the name of the Mixture or Task to evaluate on.
        Must be pre-registered in the global `TaskRegistry` or
        `MixtureRegistry`.
      checkpoint_steps: int, list of ints, or None. If an int or list of ints,
        evaluation will be run on the checkpoint files in `model_dir` whose
        global steps are closest to the global steps provided. If None, run eval
        continuously waiting for new checkpoints. If -1, get the latest
        checkpoint from the model directory.
      summary_dir: str, path to write TensorBoard events file summaries for
        eval. If None, use model_dir/eval_{split}.
      split: str, the split to evaluate on.
    """
    if checkpoint_steps == -1:
      checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)
    vocabulary = t5.data.get_mixture_or_task(
        mixture_or_task_name).get_vocabulary()
    dataset_fn = functools.partial(
        mesh_eval_dataset_fn, mixture_or_task_name=mixture_or_task_name)
    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))
    utils.eval_model(self.estimator(vocabulary), vocabulary,
                     self._sequence_length, self.batch_size, split,
                     self._model_dir, dataset_fn, summary_dir, checkpoint_steps)
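
Through the public API, this eval path is reached via t5.models.MtfModel. A hedged sketch, not a verbatim recipe: the constructor arguments shown (model_dir, tpu, batch_size, sequence_length) and the task name are placeholders, and the real constructor accepts additional TPU/topology options.

import t5.models

model = t5.models.MtfModel(
    model_dir="/tmp/t5-model",  # placeholder directory
    tpu=None,  # no TPU for this sketch
    batch_size=8,
    sequence_length={"inputs": 512, "targets": 128})
# checkpoint_steps=-1 picks the latest checkpoint, per the docstring above.
model.eval("glue_mrpc_v002", checkpoint_steps=-1, split="validation")
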
github google-research/text-to-text-transfer-transformer/t5/models/mesh_transformer.py
def get_sentencepiece_model_path(mixture_or_task_name):
  return t5.data.get_mixture_or_task(
      mixture_or_task_name).sentencepiece_model_path
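
Usage is a one-liner; the task name below is an assumption:

spm_path = get_sentencepiece_model_path("glue_mrpc_v002")  # assumed task name
print(spm_path)  # path to the .model file backing the task's vocabulary
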
github google-research/text-to-text-transfer-transformer/t5/models/mesh_transformer_main.py
def main(_):
  if FLAGS.module_import:
    for module in FLAGS.module_import:
      importlib.import_module(module)

  if FLAGS.t5_tfds_data_dir:
    t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)
  t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

  # Add search path for gin files stored in package.
  gin.add_config_file_search_path(
      pkg_resources.resource_filename(__name__, "gin"))

  tf.io.gfile.makedirs(FLAGS.model_dir)
  suffix = 0
  command_filename = os.path.join(FLAGS.model_dir, "command")
  while tf.io.gfile.exists(command_filename):
    suffix += 1
    command_filename = os.path.join(
        FLAGS.model_dir, "command.{}".format(suffix))
  with tf.io.gfile.GFile(command_filename, "w") as f:
    f.write(" ".join(sys.argv))

  utils.parse_gin_defaults_and_flags()
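
The two t5.data calls at the top of main() are also useful on their own, e.g. to point task loading at custom storage. A minimal sketch with placeholder paths:

import t5.data

# Override where TFDS looks for dataset files (placeholder path).
t5.data.set_tfds_data_dir_override("gs://my-bucket/tfds")
# Add extra directories to search for pre-cached task data (placeholder path).
t5.data.add_global_cache_dirs(["gs://my-bucket/t5-cache"])
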