@property
def splits(self):
  """Override since we can't call `info.splits` until after init."""
  return self._splits or self._tfds_dataset.info.splits

@property
def tfds_dataset(self):
  return self._tfds_dataset

def num_input_examples(self, split):
  return self.tfds_dataset.size(split)
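
# A minimal usage sketch of the accessors above, assuming the `TfdsTask`
# defined below (the task name, TFDS name/version, and model path are
# hypothetical):
#
#   task = TfdsTask(
#       "demo_task",                                    # hypothetical name
#       tfds_name="glue/cola:1.0.0",                    # hypothetical TFDS id
#       text_preprocessor=None,
#       sentencepiece_model_path="/path/to/spm.model",  # hypothetical path
#       metric_fns=[])
#   task.splits                       # falls back to TFDS `info.splits`
#   task.num_input_examples("train")  # split size per the TFDS metadata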
class TextLineTask(Task):
"""A `Task` that reads text lines as input.
Requires a text_processor to be passed that takes a tf.data.Dataset of
strings and returns a tf.data.Dataset of feature dictionaries.
e.g. preprocessors.preprocess_tsv()
"""
  def __init__(
      self,
      name,
      split_to_filepattern,
      text_preprocessor,
      sentencepiece_model_path,
      metric_fns,
      skip_header_lines=0,
      **task_kwargs):
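    # A usage sketch of this constructor (the file pattern and model path
    # below are hypothetical; `preprocess_tsv` comes from the docstring's
    # own example):
    #
    #   task = TextLineTask(
    #       name="my_tsv_task",
    #       split_to_filepattern={"train": "/tmp/data/train-*.tsv"},
    #       text_preprocessor=preprocessors.preprocess_tsv,
    #       sentencepiece_model_path="/tmp/spm.model",
    #       metric_fns=[],
    #       skip_header_lines=1)  # skip a single TSV header row per file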
def get_subtasks(task_or_mixture):
  """Returns a list of the Tasks in a Mixture, or the given Task in a list."""
  if isinstance(task_or_mixture, Task):
    return [task_or_mixture]
  else:
    return task_or_mixture.tasks
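
# Behavior sketch (`my_task` and `my_mixture` are hypothetical instances):
#
#   get_subtasks(my_task)     # -> [my_task]
#   get_subtasks(my_mixture)  # -> my_mixture.tasks, every component Task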
"%s-*-of-*%d" % (
get_tfrecord_prefix(self.cache_dir, split),
split_info["num_shards"]),
shuffle=shuffle)
ds = ds.interleave(
tf.data.TFRecordDataset,
cycle_length=16, block_length=16,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds = ds.map(lambda ex: tf.parse_single_example(ex, feature_desc),
num_parallel_calls=tf.data.experimental.AUTOTUNE)
if self.get_cached_stats(split)["examples"] <= _MAX_EXAMPLES_TO_MEM_CACHE:
ds = ds.cache()
return ds
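
# A standalone sketch of the same sharded-TFRecord read pattern, runnable
# outside the class (the file pattern and feature spec are hypothetical):
#
#   import tensorflow as tf
#   feature_desc = {"inputs": tf.io.VarLenFeature(tf.int64)}
#   files = tf.data.Dataset.list_files("/tmp/cache/train.tfrecord-*")
#   ds = files.interleave(
#       tf.data.TFRecordDataset, cycle_length=16, block_length=16,
#       num_parallel_calls=tf.data.experimental.AUTOTUNE)
#   ds = ds.map(lambda ex: tf.io.parse_single_example(ex, feature_desc),
#               num_parallel_calls=tf.data.experimental.AUTOTUNE)
#   ds = ds.cache()  # worthwhile only when the split fits in memory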
class TfdsTask(Task):
  """A `Task` that uses TensorFlow Datasets to provide the input dataset."""

  def __init__(
      self,
      name,
      tfds_name,
      text_preprocessor,
      sentencepiece_model_path,
      metric_fns,
      tfds_data_dir=None,
      splits=None,
      **task_kwargs):
    """TfdsTask constructor.

    Args:
      name: string, a unique name for the Task. A ValueError will be raised if