Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
metric_fns=[metrics.qa],
sentencepiece_model_path=DEFAULT_SPM_PATH)
# Context-free SQuAD: same data, postprocessor and metric as the
# all-answers task, but the preprocessor is called with
# include_context=False (presumably the passage is left out of the input;
# confirm against preprocessors.squad). Metrics are maximized over all answers.
# NOTE(review): this task name is registered by another statement in this
# file as well; confirm the registry accepts duplicate names.
TaskRegistry.add(
    "squad_v010_context_free",
    TfdsTask,
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    tfds_name="squad/plain_text:0.1.0",
    metric_fns=[metrics.qa],
    postprocess_fn=postprocessors.qa,
    text_preprocessor=functools.partial(
        preprocessors.squad,
        include_context=False,
    ),
)
# SQuAD as span prediction rather than free text: the span-specific
# preprocessor, postprocessor and metric are used together.
# NOTE(review): this task name is registered by another statement in this
# file as well; confirm the registry accepts duplicate names.
TaskRegistry.add(
    "squad_v010_allanswers_span",
    TfdsTask,
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    metric_fns=[metrics.span_qa],
    postprocess_fn=postprocessors.span_qa,
    text_preprocessor=preprocessors.squad_span_space_tokenized,
    tfds_name="squad/plain_text:0.1.0",
)
# Deprecated: Use `squad_v010_allanswers` instead.
# Plain SQuAD task scored with metrics.qa. Unlike the all-answers variant,
# no postprocess_fn is passed here.
# NOTE(review): this task name is registered by another statement in this
# file as well; confirm the registry accepts duplicate names.
TaskRegistry.add(
    "squad_v010",
    TfdsTask,
    tfds_name="squad/plain_text:0.1.0",
    text_preprocessor=preprocessors.squad,
    metric_fns=[metrics.qa],
    sentencepiece_model_path=DEFAULT_SPM_PATH)
# =================================== WNLI =====================================
# Evaluation-only WNLI task (validation and test splits). It reuses the WSC
# "simple" postprocessor and is scored by accuracy.
# NOTE(review): this task name is registered by other statements in this
# file as well; confirm the registry accepts duplicate names.
TaskRegistry.add(
    "glue_wnli_v002_simple_eval",
    TfdsTask,
    splits=["validation", "test"],
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    metric_fns=[metrics.accuracy],
    postprocess_fn=postprocessors.wsc_simple,
    text_preprocessor=preprocessors.wnli_simple,
    tfds_name="glue/wnli:0.0.2",
)
# =================================== Squad ====================================
# Standard SQuAD text task; predictions go through postprocessors.qa and
# evaluation metrics are maximized over all reference answers.
# NOTE(review): a second (truncated) registration of this name also appears
# in this file; confirm the registry accepts duplicate names.
TaskRegistry.add(
    "squad_v010_allanswers",
    TfdsTask,
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    metric_fns=[metrics.qa],
    postprocess_fn=postprocessors.qa,
    text_preprocessor=preprocessors.squad,
    tfds_name="squad/plain_text:0.1.0",
)
# Maximized evaluation metrics over all answers.
# NOTE(review): this statement looks like two registrations spliced together.
# It opens as "squad_v010_context_free", then repeats `tfds_name` and
# `text_preprocessor` with definite_pronoun_resolution values. Python rejects
# a repeated keyword argument at compile time (SyntaxError), so the module
# cannot import as written. Restore the intended registration(s) from the
# upstream source. "squad_v010_context_free" is also registered by a complete
# statement elsewhere in this file.
TaskRegistry.add(
"squad_v010_context_free",
TfdsTask,
tfds_name="squad/plain_text:0.1.0",
text_preprocessor=functools.partial(
preprocessors.squad, include_context=False),
postprocess_fn=postprocessors.qa,
# NOTE(review): the keywords below duplicate ones already passed above.
tfds_name="definite_pronoun_resolution/plain_text:0.0.1",
text_preprocessor=preprocessors.definite_pronoun_resolution_simple,
metric_fns=[metrics.accuracy],
sentencepiece_model_path=DEFAULT_SPM_PATH)
# =================================== WSC ======================================
# Training-only WSC task ("train" split, no metrics). The preprocessor is
# called with correct_referent_only=True, so presumably only examples with
# the correct referent are kept — confirm in preprocessors.wsc_simple.
# NOTE(review): this task name is registered by another statement in this
# file as well; confirm the registry accepts duplicate names.
TaskRegistry.add(
    "super_glue_wsc_v102_simple_train",
    TfdsTask,
    splits=["train"],
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    metric_fns=[],
    text_preprocessor=functools.partial(
        preprocessors.wsc_simple,
        correct_referent_only=True,
    ),
    tfds_name="super_glue/wsc.fixed:1.0.2",
)
# Evaluation counterpart of the WSC simple task: validation/test splits,
# preprocessor called with correct_referent_only=False, predictions mapped by
# postprocessors.wsc_simple and scored by accuracy.
# NOTE(review): this task name is registered by another statement in this
# file as well; confirm the registry accepts duplicate names.
TaskRegistry.add(
    "super_glue_wsc_v102_simple_eval",
    TfdsTask,
    splits=["validation", "test"],
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    metric_fns=[metrics.accuracy],
    postprocess_fn=postprocessors.wsc_simple,
    text_preprocessor=functools.partial(
        preprocessors.wsc_simple,
        correct_referent_only=False,
    ),
    tfds_name="super_glue/wsc.fixed:1.0.2",
)
# =================================== WNLI =====================================
# NOTE(review): corrupted span. The WNLI registration below is cut off after
# `text_preprocessor=` and runs straight into the tail of the SuperGLUE
# per-builder loop (the closing `]` of a preprocessor list and an `else:`
# branch whose loop header and `if` are missing). The brackets do not
# balance, so this is a SyntaxError. Restore from the upstream source. A
# complete "glue_wnli_v002_simple_eval" registration exists elsewhere in
# this file.
TaskRegistry.add(
"glue_wnli_v002_simple_eval",
TfdsTask,
tfds_name="glue/wnli:0.0.2",
text_preprocessor=preprocessors.wnli_simple,
_get_glue_text_preprocessor(b)
]
else:
text_preprocessor = _get_glue_text_preprocessor(b)
# Register "super_glue_<builder>_v102" for builder config `b`. For "axb" and
# "axg" only the "test" split is requested; for other builders splits=None
# is passed (presumably meaning the default splits — confirm in TfdsTask).
TaskRegistry.add(
"super_glue_%s_v102" % b.name,
TfdsTask,
tfds_name="super_glue/%s:1.0.2" % b.name,
text_preprocessor=text_preprocessor,
metric_fns=SUPERGLUE_METRICS[b.name],
sentencepiece_model_path=DEFAULT_SPM_PATH,
postprocess_fn=_get_glue_postprocess_fn(b),
splits=["test"] if b.name in ["axb", "axg"] else None)
# ======================== Definite Pronoun Resolution =========================
# Definite pronoun resolution with the "simple" preprocessor, scored by
# accuracy. No splits or postprocess_fn are passed, so the TfdsTask defaults
# apply.
# NOTE(review): this task name is registered by another (corrupted)
# statement in this file as well; confirm the registry accepts duplicates.
TaskRegistry.add(
    "dpr_v001_simple",
    TfdsTask,
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    metric_fns=[metrics.accuracy],
    text_preprocessor=preprocessors.definite_pronoun_resolution_simple,
    tfds_name="definite_pronoun_resolution/plain_text:0.0.1",
)
# =================================== WSC ======================================
# NOTE(review): corrupted span. The WSC-train registration below never
# closes its call. It is followed by a bare `continue` and by the body of the
# SuperGLUE per-builder loop, whose `for` header is missing. This is a
# SyntaxError as written; restore from the upstream source. A complete
# "super_glue_wsc_v102_simple_train" registration exists elsewhere in this
# file.
TaskRegistry.add(
"super_glue_wsc_v102_simple_train",
TfdsTask,
tfds_name="super_glue/wsc.fixed:1.0.2",
text_preprocessor=functools.partial(
preprocessors.wsc_simple, correct_referent_only=True),
metric_fns=[],
sentencepiece_model_path=DEFAULT_SPM_PATH,
continue
# For the "axb" builder, run preprocessors.rekey over the premise/hypothesis/
# label/idx fields before the shared GLUE text preprocessor. Judging by the
# C4 rekey usage in this file, key_map is presumably new_key -> old_key —
# confirm. Every other builder uses the GLUE preprocessor alone.
if b.name == "axb":
text_preprocessor = [
functools.partial(
preprocessors.rekey,
key_map={
"premise": "sentence1",
"hypothesis": "sentence2",
"label": "label",
"idx": "idx",
}),
_get_glue_text_preprocessor(b)
]
else:
text_preprocessor = _get_glue_text_preprocessor(b)
# Register "super_glue_<builder>_v102"; "axb"/"axg" request only "test".
TaskRegistry.add(
"super_glue_%s_v102" % b.name,
TfdsTask,
tfds_name="super_glue/%s:1.0.2" % b.name,
text_preprocessor=text_preprocessor,
metric_fns=SUPERGLUE_METRICS[b.name],
sentencepiece_model_path=DEFAULT_SPM_PATH,
postprocess_fn=_get_glue_postprocess_fn(b),
splits=["test"] if b.name in ["axb", "axg"] else None)
# ======================== Definite Pronoun Resolution =========================
# NOTE(review): corrupted span. `metric_fns` is passed twice, which is a
# SyntaxError (repeated keyword argument). The trailing
# `metric_fns=[]` / `splits=["train"]` match the tail of the WSC-train
# registration, so two calls appear to have been spliced. A complete
# "dpr_v001_simple" registration exists elsewhere in this file; restore the
# intended code from the upstream source.
TaskRegistry.add(
"dpr_v001_simple",
TfdsTask,
tfds_name="definite_pronoun_resolution/plain_text:0.0.1",
text_preprocessor=preprocessors.definite_pronoun_resolution_simple,
metric_fns=[metrics.accuracy],
metric_fns=[],
sentencepiece_model_path=DEFAULT_SPM_PATH,
splits=["train"])
# Evaluation-only WSC task: validation/test splits, preprocessor called with
# correct_referent_only=False, accuracy over wsc_simple-postprocessed output.
# NOTE(review): this task name is registered by another statement in this
# file as well; confirm the registry accepts duplicate names.
TaskRegistry.add(
    "super_glue_wsc_v102_simple_eval",
    TfdsTask,
    tfds_name="super_glue/wsc.fixed:1.0.2",
    splits=["validation", "test"],
    metric_fns=[metrics.accuracy],
    postprocess_fn=postprocessors.wsc_simple,
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    text_preprocessor=functools.partial(
        preprocessors.wsc_simple, correct_referent_only=False),
)
# =================================== WNLI =====================================
# WNLI evaluated with the WSC-style postprocessor and accuracy; only the
# validation and test splits are registered.
# NOTE(review): this task name is registered by other statements in this
# file as well; confirm the registry accepts duplicate names.
TaskRegistry.add(
    "glue_wnli_v002_simple_eval",
    TfdsTask,
    tfds_name="glue/wnli:0.0.2",
    splits=["validation", "test"],
    metric_fns=[metrics.accuracy],
    postprocess_fn=postprocessors.wsc_simple,
    text_preprocessor=preprocessors.wnli_simple,
    sentencepiece_model_path=DEFAULT_SPM_PATH,
)
# =================================== Squad ====================================
# Maximized evaluation metrics over all answers.
# NOTE(review): corrupted span. This call is never closed: the next line
# starts the C4 section, so the remaining arguments and the closing `)` are
# missing. A complete "squad_v010_allanswers" registration exists elsewhere
# in this file; delete or restore this fragment from the upstream source.
TaskRegistry.add(
"squad_v010_allanswers",
TfdsTask,
tfds_name="squad/plain_text:0.1.0",
text_preprocessor=preprocessors.squad,
# ==================================== C4 ======================================
# One unsupervised task per C4 English config. Each suffix is appended to
# the TFDS name ("c4/en.noclean:1.0.0", ...). With "." replaced by "_", it is
# also spliced into the task name ("c4_noclean_v020_unsupervised", ...). The
# empty suffix yields the base "c4_v020_unsupervised" task.
_c4_config_suffixes = ["", ".noclean", ".realnewslike", ".webtextlike"]
for config_suffix in _c4_config_suffixes:
    TaskRegistry.add(
        "c4{name}_v020_unsupervised".format(
            name=config_suffix.replace(".", "_")),
        TfdsTask,
        tfds_name="c4/en{config}:1.0.0".format(config=config_suffix),
        # Rekey so that "targets" is taken from the raw "text" field and
        # "inputs" is None; preprocessors.unsupervised then presumably
        # builds the actual training pair from the tokens.
        text_preprocessor=functools.partial(
            preprocessors.rekey, key_map={"inputs": None, "targets": "text"}),
        token_preprocessor=preprocessors.unsupervised,
        sentencepiece_model_path=DEFAULT_SPM_PATH,
        # Pre-training data: no evaluation metrics.
        metric_fns=[])
# ================================ Wikipedia ===================================
# Unsupervised task over the 20190301.en Wikipedia config, prepared exactly
# like the C4 tasks: "targets" rekeyed from "text", no inputs, the
# unsupervised token preprocessor, and no metrics.
TaskRegistry.add(
    "wikipedia_20190301.en_v003_unsupervised",
    TfdsTask,
    # TFDS 0.0.4 is identical to 0.0.3 except that empty records were removed.
    tfds_name="wikipedia/20190301.en:0.0.4",
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    text_preprocessor=functools.partial(
        preprocessors.rekey,
        key_map={"inputs": None, "targets": "text"},
    ),
    token_preprocessor=preprocessors.unsupervised,
    metric_fns=[],
)
# =================================== GLUE =====================================
def _get_glue_text_preprocessor(builder_config):
"""Return the glue preprocessor.
Args:
# Register one BLEU-scored translation task per (year prefix, builder config,
# TFDS version) entry of b_configs. The task is named
# "wmt<prefix>_<source><target>_v003" and translates b.language_pair[1]
# (source) into b.language_pair[0] (target).
# NOTE(review): b_configs is assigned later in this file, so at this
# position the loop raises NameError unless b_configs is also defined
# earlier (outside this view). An identical loop also appears after the
# b_configs assignment; the duplicate registrations need reconciling.
for prefix, b, tfds_version in b_configs:
    TaskRegistry.add(
        "wmt%s_%s%s_v003" % (prefix, b.language_pair[1], b.language_pair[0]),
        TfdsTask,
        tfds_name="wmt%s_translate/%s:%s" % (prefix, b.name, tfds_version),
        text_preprocessor=functools.partial(
            preprocessors.translate,
            source_language=b.language_pair[1],
            target_language=b.language_pair[0],
        ),
        metric_fns=[metrics.bleu],
        sentencepiece_model_path=DEFAULT_SPM_PATH)
# Special case for t2t ende: the "de-en" config of the wmt_t2t_translate
# builder, pinned at TFDS version 0.0.1. Like the per-year WMT tasks, it
# translates language_pair[1] into language_pair[0] and is scored with BLEU.
# The module-level name `b` is rebound here on purpose; the call below reads it.
b = tfds.translate.wmt_t2t.WmtT2tTranslate.builder_configs["de-en"]
TaskRegistry.add(
    "wmt_t2t_ende_v003",
    TfdsTask,
    tfds_name="wmt_t2t_translate/de-en:0.0.1",
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    metric_fns=[metrics.bleu],
    text_preprocessor=functools.partial(
        preprocessors.translate,
        target_language=b.language_pair[0],
        source_language=b.language_pair[1]),
)
# ================================= SuperGlue ==================================
# Per-builder metric functions for SuperGLUE, keyed by builder name (looked
# up as SUPERGLUE_METRICS[b.name] by the SuperGLUE registration loop).
# NOTE(review): corrupted span. The mapping is truncated inside the "cb"
# entry and followed by the argument tail of an unrelated call, so the
# brackets do not balance (SyntaxError). Restore the full mapping from the
# upstream source.
SUPERGLUE_METRICS = collections.OrderedDict([
("boolq", [metrics.accuracy]),
("cb", [
# num_classes=3 — presumably CB's three label classes; confirm.
metrics.mean_multiclass_f1(num_classes=3),
postprocess_fn=postprocessors.qa,
metric_fns=[metrics.qa],
sentencepiece_model_path=DEFAULT_SPM_PATH)
# SQuAD posed as span prediction instead of text generation; inputs are
# space-tokenized for spans and scored with the span QA metric.
# NOTE(review): this task name is registered by another statement in this
# file as well; confirm the registry accepts duplicate names.
TaskRegistry.add(
    "squad_v010_allanswers_span",
    TfdsTask,
    tfds_name="squad/plain_text:0.1.0",
    metric_fns=[metrics.span_qa],
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    text_preprocessor=preprocessors.squad_span_space_tokenized,
    postprocess_fn=postprocessors.span_qa,
)
# Deprecated: Use `squad_v010_allanswers` instead.
# Legacy SQuAD task scored with metrics.qa. Unlike the all-answers variant,
# it passes no postprocess_fn.
# NOTE(review): this task name is registered by another statement in this
# file as well; confirm the registry accepts duplicate names.
TaskRegistry.add(
    "squad_v010",
    TfdsTask,
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    metric_fns=[metrics.qa],
    text_preprocessor=preprocessors.squad,
    tfds_name="squad/plain_text:0.1.0",
)
# ================================= TriviaQA ===================================
# TriviaQA task with no metric functions. Besides the text preprocessor, a
# token-level preprocessor (preprocessors.trivia_qa_truncate_inputs)
# truncates the inputs.
TaskRegistry.add(
    "trivia_qa_v010",
    TfdsTask,
    tfds_name="trivia_qa:0.1.0",
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    text_preprocessor=preprocessors.trivia_qa,
    token_preprocessor=preprocessors.trivia_qa_truncate_inputs,
    metric_fns=[],
)
# WMT translation configs as (year prefix, TFDS builder config, TFDS version)
# tuples, in registration order. Each entry comes from a compact
# (year, builder class, language-pair config name, version) spec.
b_configs = [
    (year, builder_cls.builder_configs[pair_name], version)
    for year, builder_cls, pair_name, version in (
        ("14", tfds.translate.wmt14.Wmt14Translate, "de-en", "0.0.3"),
        ("14", tfds.translate.wmt14.Wmt14Translate, "fr-en", "0.0.3"),
        ("16", tfds.translate.wmt16.Wmt16Translate, "ro-en", "0.0.3"),
        ("15", tfds.translate.wmt15.Wmt15Translate, "fr-en", "0.0.4"),
        ("19", tfds.translate.wmt19.Wmt19Translate, "de-en", "0.0.3"),
    )
]
# Register one BLEU-scored translation task per b_configs entry. The task is
# named "wmt<prefix>_<source><target>_v003" and translates
# b.language_pair[1] (source) into b.language_pair[0] (target). The loop
# variables (prefix, b, tfds_version) remain bound at module scope afterwards.
for prefix, b, tfds_version in b_configs:
    TaskRegistry.add(
        "wmt%s_%s%s_v003" % (prefix, b.language_pair[1], b.language_pair[0]),
        TfdsTask,
        tfds_name="wmt%s_translate/%s:%s" % (prefix, b.name, tfds_version),
        text_preprocessor=functools.partial(
            preprocessors.translate,
            source_language=b.language_pair[1],
            target_language=b.language_pair[0],
        ),
        metric_fns=[metrics.bleu],
        sentencepiece_model_path=DEFAULT_SPM_PATH)
# Special case for t2t ende.
b = tfds.translate.wmt_t2t.WmtT2tTranslate.builder_configs["de-en"]
TaskRegistry.add(
"wmt_t2t_ende_v003",
TfdsTask,