      sentencepiece_model_path: str, path to the SentencePiece model file to use
for decoding. Must match the one used during training.
"""
# TODO(sharannarang) : It would be nice to have a function like
# load_checkpoint that loads the model once and then call decode_from_file
# multiple times without having to restore the checkpoint weights again.
# This would be particularly useful in colab demo.
if checkpoint_steps == -1:
checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)
with gin.unlock_config():
gin.parse_config_file(_operative_config_path(self._model_dir))
gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
gin.bind_parameter("Bitransformer.decode.temperature", temperature)
vocabulary = t5.data.SentencePieceVocabulary(sentencepiece_model_path)
utils.infer_model(self.estimator(vocabulary), vocabulary,
self._sequence_length, self.batch_size,
self._model_type, self._model_dir, checkpoint_steps,
input_file, output_file)
the Mesh TF transformer standalone.
Args:
mixture_or_task_name: string, an identifier for a Mixture or Task in the
appropriate registry. Must be specified via gin.
    sequence_length: dict mapping feature key to the maximum int sequence
      length for that feature.
vocabulary: a SentencePieceVocabulary.
dataset_split: string, which split of the dataset to load. In most cases
this should be "train".
use_cached: bool, whether to load the cached version of this dataset.
Returns:
A tf.data.Dataset of preprocessed, tokenized, and batched examples.
"""
if not isinstance(vocabulary, t5.data.SentencePieceVocabulary):
raise ValueError("vocabulary must be a SentencePieceVocabulary")
mixture_or_task = t5.data.get_mixture_or_task(mixture_or_task_name)
ds = mixture_or_task.get_dataset(
sequence_length, split=dataset_split, use_cached=use_cached, shuffle=True)
ds = transformer_dataset.pack_or_pad(
ds, sequence_length, pack=True,
feature_keys=tuple(mixture_or_task.output_features), ensure_eos=True)
return ds