Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
("cola", [metrics.matthews_corrcoef]),
("sst2", [metrics.accuracy]),
("mrpc", [metrics.f1_score_with_invalid, metrics.accuracy]),
("stsb", [metrics.pearson_corrcoef, metrics.spearman_corrcoef]),
("qqp", [metrics.f1_score_with_invalid, metrics.accuracy]),
("mnli", [metrics.accuracy]),
("mnli_matched", [metrics.accuracy]),
("mnli_mismatched", [metrics.accuracy]),
("qnli", [metrics.accuracy]),
("rte", [metrics.accuracy]),
("wnli", [metrics.accuracy]),
("ax", []), # Only test set available.
])
# Register one "glue_<config>_v002" task per GLUE builder config exposed by
# TFDS. The "ax" config is special-cased: it is pinned to dataset version
# 1.0.0 (all other configs use 0.0.2) and is restricted to its "test" split.
for b in tfds.text.glue.Glue.builder_configs.values():
    TaskRegistry.add(
        "glue_%s_v002" % b.name,
        TfdsTask,
        tfds_name="glue/%s:%s" % (
            b.name,
            "0.0.2" if b.name != "ax" else "1.0.0",
        ),
        # Pre/post-processing helpers are selected per builder config.
        text_preprocessor=_get_glue_text_preprocessor(b),
        metric_fns=GLUE_METRICS[b.name],
        sentencepiece_model_path=DEFAULT_SPM_PATH,
        postprocess_fn=_get_glue_postprocess_fn(b),
        # None leaves the split selection to TfdsTask's default.
        splits=None if b.name != "ax" else ["test"],
    )
# =============================== CNN DailyMail ================================
TaskRegistry.add(
"cnn_dailymail_v002",
TfdsTask,
tfds_name="cnn_dailymail/plain_text:0.0.2",
text_preprocessor=functools.partial(preprocessors.summarize,
def get_mixture_or_task(task_or_mixture_name):
    """Look up a name in the Mixture and Task registries.

    Mixtures take precedence: when the name is registered in both
    registries, a warning is logged and the Mixture is returned.

    Args:
        task_or_mixture_name: the name to search for in
            `MixtureRegistry.names()` and `TaskRegistry.names()`.

    Returns:
        The result of `MixtureRegistry.get(...)` if the name is a Mixture,
        otherwise the result of `TaskRegistry.get(...)`.

    Raises:
        ValueError: if the name is found in neither registry.
    """
    known_mixtures = MixtureRegistry.names()
    known_tasks = TaskRegistry.names()
    is_mixture = task_or_mixture_name in known_mixtures
    is_task = task_or_mixture_name in known_tasks

    # Unknown name: fail loudly rather than returning None.
    if not (is_mixture or is_task):
        raise ValueError("No Task or Mixture found with name: %s" %
                         task_or_mixture_name)

    # Plain Task with no Mixture of the same name.
    if not is_mixture:
        return TaskRegistry.get(task_or_mixture_name)

    # Name collision: the Mixture wins, but tell the user about it.
    if is_task:
        logging.warning("%s is both a Task and a Mixture, returning Mixture",
                        task_or_mixture_name)
    return MixtureRegistry.get(task_or_mixture_name)
def add(cls, name, task_cls=Task, **kwargs):
    """Register `name` through the parent registry's `add`.

    `name` is forwarded twice: once in the first positional slot and once
    more after `task_cls`. NOTE(review): presumably the parent uses the
    first as the registry key and passes the second on to `task_cls` --
    confirm against the parent class, which is not visible here.

    Args:
        name: identifier to register.
        task_cls: class to register under `name`; defaults to `Task`.
        **kwargs: forwarded unchanged to the parent `add`.
    """
    parent_registry = super(TaskRegistry, cls)
    parent_registry.add(name, task_cls, name, **kwargs)
from t5.data import postprocessors
from t5.data import preprocessors
from t5.data.utils import DEFAULT_SPM_PATH
from t5.data.utils import set_global_cache_dirs
from t5.data.utils import TaskRegistry
from t5.data.utils import TfdsTask
from t5.evaluation import metrics
import tensorflow_datasets as tfds
# ==================================== C4 ======================================
# Suffixes appended to "c4/en" to form the TFDS config name; the empty
# string selects the plain "c4/en" config.
_c4_config_suffixes = [
    "",
    ".noclean",
    ".realnewslike",
    ".webtextlike",
]

# Register one unsupervised task per C4 config.
for config_suffix in _c4_config_suffixes:
    TaskRegistry.add(
        # Dots in the suffix become underscores in the task name,
        # e.g. ".noclean" -> "c4_noclean_v020_unsupervised".
        "c4{name}_v020_unsupervised".format(
            name=config_suffix.replace(".", "_"),
        ),
        TfdsTask,
        tfds_name="c4/en{config}:1.0.0".format(config=config_suffix),
        # Re-key each example: "targets" comes from "text"; "inputs" is
        # mapped to None.
        text_preprocessor=functools.partial(
            preprocessors.rekey,
            key_map={"inputs": None, "targets": "text"},
        ),
        token_preprocessor=preprocessors.unsupervised,
        sentencepiece_model_path=DEFAULT_SPM_PATH,
        metric_fns=[],
    )
# ================================ Wikipedia ===================================
TaskRegistry.add(
"wikipedia_20190301.en_v003_unsupervised",
TfdsTask,
# 0.0.4 is identical to 0.0.3 except empty records removed.
tfds_name="wikipedia/20190301.en:0.0.4",