Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
from __future__ import print_function
import collections
import os
import shutil
from absl import logging
from absl.testing import absltest
import numpy as np
import six
from t5.data import sentencepiece_vocabulary
from t5.data import utils as dataset_utils
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
# Short module-level aliases for the registries defined in t5.data.utils.
TaskRegistry, MixtureRegistry = (
    dataset_utils.TaskRegistry,
    dataset_utils.MixtureRegistry,
)
# Alias for `absltest.mock`, so tests can write `mock.patch(...)` etc.
mock = absltest.mock
# Absolute path of the `test_data` directory located next to this source file.
TEST_DATA_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "test_data",
)
# Python 2 refuses to instantiate absltest.TestCase directly, so this small
# subclass (which defines `runTest`) is used whenever a TestCase instance is
# needed only for its helper methods.
class _ProxyTest(absltest.TestCase):
  """Instantiable TestCase whose methods are reused by test helpers."""

  # No limit on the length of diffs reported by failing assertions.
  maxDiff = None

  def runTest(self):
    """No-op body; the method only has to exist for instantiation to work."""
_FAKE_CACHED_DATASET["train"], [2, 1], _dump_examples_to_tfrecord)
_dump_fake_dataset(
os.path.join(cached_task_dir, "validation.tfrecord"),
_FAKE_CACHED_DATASET["validation"], [2], _dump_examples_to_tfrecord)
# Prepare uncached TfdsTask.
add_tfds_task("uncached_task")
self.uncached_task = TaskRegistry.get("uncached_task")
# Prepare uncached TextLineTask.
_dump_fake_dataset(
os.path.join(self.test_data_dir, "train.tsv"),
_FAKE_DATASET["train"], [2, 1], _dump_examples_to_tsv)
TaskRegistry.add(
"text_line_task",
dataset_utils.TextLineTask,
split_to_filepattern={
"train": os.path.join(self.test_data_dir, "train.tsv*"),
},
skip_header_lines=1,
text_preprocessor=[_split_tsv_preprocessor, test_text_preprocessor],
sentencepiece_model_path=os.path.join(
TEST_DATA_DIR, "sentencepiece", "sentencepiece.model"),
metric_fns=[])
self.text_line_task = TaskRegistry.get("text_line_task")
# Auto-verify any split by just retuning the split name
dataset_utils.verify_tfds_split = absltest.mock.Mock(
side_effect=lambda x, y: y
)
def add_fake_tfds(fake_tfds):
  """Store `fake_tfds` in `LazyTfdsLoader`'s memoized-instance table.

  The object is written into `dataset_utils.LazyTfdsLoader._MEMOIZED_INSTANCES`
  under the key `(fake_tfds.name, None)`. Presumably this makes later loader
  lookups for that key return the fake instead of building a real loader;
  confirm against `LazyTfdsLoader`.

  Args:
    fake_tfds: an object exposing a `name` attribute. It is presumably a
      stand-in for a `LazyTfdsLoader` instance (TODO confirm against callers).
  """
  memo = dataset_utils.LazyTfdsLoader._MEMOIZED_INSTANCES  # pylint:disable=protected-access
  memo_key = (fake_tfds.name, None)
  memo[memo_key] = fake_tfds