    mixture_or_task_name: string, an identifier for a Mixture or Task in the
      appropriate registry. Must be specified via gin.
    sequence_length: dict mapping each feature key to the int maximum sequence
      length for that feature.
    vocabulary: a SentencePieceVocabulary.
    dataset_split: string, which split of the dataset to load. In most cases
      this should be "train".
    use_cached: bool, whether to load the cached version of this dataset.

  Returns:
    A tf.data.Dataset of preprocessed, tokenized, and packed examples.
  """
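  # The vocabulary is only validated here: it must be a SentencePieceVocabulary.
  if not isinstance(vocabulary, t5.data.SentencePieceVocabulary):
    raise ValueError("vocabulary must be a SentencePieceVocabulary")

  # Look up the registered Mixture or Task by name.
  mixture_or_task = t5.data.get_mixture_or_task(mixture_or_task_name)

  # Load the preprocessed, tokenized split, shuffled for training.
  ds = mixture_or_task.get_dataset(
      sequence_length, split=dataset_split, use_cached=use_cached, shuffle=True)

  # Pack several examples into each fixed-length sequence instead of padding
  # each one separately; only the Task's output features are kept, and
  # ensure_eos=True makes examples end with an EOS token.
  ds = transformer_dataset.pack_or_pad(
      ds, sequence_length, pack=True,
      feature_keys=tuple(mixture_or_task.output_features), ensure_eos=True)
  return ds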
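

# A minimal, pure-Python sketch of what packing (pack=True above) does
# conceptually. This is NOT the Mesh TensorFlow implementation, which operates
# on tf.data.Dataset objects. The helper name greedy_pack is hypothetical, and
# the "<key>_segmentation" / "<key>_position" outputs and the simple greedy
# policy are simplifying assumptions made for illustration. sequence_length
# follows the docstring above: a dict mapping each feature key to that
# feature's max length. EOS handling (ensure_eos=True) is omitted.
from typing import Dict, List


def greedy_pack(examples: List[Dict[str, List[int]]],
                sequence_length: Dict[str, int]) -> List[Dict[str, List[int]]]:
  """Greedily packs tokenized examples into fixed-length rows (hypothetical).

  Assumes every example contains every key in sequence_length. Segment ids
  start at 1 so that 0 can mark padding.
  """
  keys = list(sequence_length)
  packed = []
  current = None
  segment_id = 0

  def new_row():
    row = {}
    for k in keys:
      row[k] = []
      row[k + "_segmentation"] = []
      row[k + "_position"] = []
    return row

  def fits(row, ex):
    # An example fits only if every feature still fits within its max length.
    return all(len(row[k]) + len(ex[k]) <= sequence_length[k] for k in keys)

  for ex in examples:
    # Truncate each feature to its max length before packing.
    ex = {k: ex[k][:sequence_length[k]] for k in keys}
    if current is None or not fits(current, ex):
      if current is not None:
        packed.append(current)
      current = new_row()
      segment_id = 0
    segment_id += 1
    for k in keys:
      current[k].extend(ex[k])
      current[k + "_segmentation"].extend([segment_id] * len(ex[k]))
      current[k + "_position"].extend(range(len(ex[k])))
  if current is not None:
    packed.append(current)

  # Pad every packed row with zeros up to the fixed length of each feature.
  for row in packed:
    for k in keys:
      pad = sequence_length[k] - len(row[k])
      for name in (k, k + "_segmentation", k + "_position"):
        row[name].extend([0] * pad)
  return packed


if __name__ == "__main__":
  # The first two examples share one row; the third starts a new row.
  demo = greedy_pack(
      [{"inputs": [5, 6, 1], "targets": [7, 1]},
       {"inputs": [8, 9, 10, 1], "targets": [11, 1]},
       {"inputs": [12, 13, 1], "targets": [14, 15, 1]}],
      {"inputs": 8, "targets": 4})
  for packed_row in demo:
    print(packed_row["inputs"], packed_row["inputs_segmentation"])
  # Prints:
  # [5, 6, 1, 8, 9, 10, 1, 0] [1, 1, 1, 2, 2, 2, 2, 0]
  # [12, 13, 1, 0, 0, 0, 0, 0] [1, 1, 1, 0, 0, 0, 0, 0]

# The segmentation and position features let attention and position
# embeddings treat each packed example separately, which is the trade-off
# packing makes against simply padding every example to full length.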