Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
pair whose first element is the task name and whose second element
is either a float (rate) or a function from Task to float.
default_rate: a float or a function from Task to float. This specifies the
default rate if rates are not provided in the `tasks` argument.
"""
self._task_to_rate = {}
self._tasks = []
for t in tasks:
if isinstance(t, str):
task_name = t
rate = default_rate
if default_rate is None:
raise ValueError("need a rate for each task")
else:
task_name, rate = t
self._tasks.append(TaskRegistry.get(task_name))
self._task_to_rate[task_name] = rate
if len(set(tuple(t.output_features) for t in self._tasks)) != 1:
raise ValueError(
"All Tasks in a Mixture must have the same output features."
)
if len(set(t.sentencepiece_model_path for t in self._tasks)) != 1:
raise ValueError(
"All Tasks in a Mixture must have the same sentencepiece_model_path."
)
def get_mixture_or_task(task_or_mixture_name):
  """Resolve a name against the Mixture and Task registries.

  Mixtures win ties: when the name is registered both as a Mixture and as a
  Task, a warning is logged and the Mixture is returned.

  Args:
    task_or_mixture_name: the registered name to look up.

  Returns:
    The result of `MixtureRegistry.get` (preferred) or `TaskRegistry.get`
    for the given name.

  Raises:
    ValueError: if the name is registered in neither registry.
  """
  # Query both registries up front, Mixtures first, as lookups are ordered.
  mixture_names = MixtureRegistry.names()
  task_names = TaskRegistry.names()
  is_mixture = task_or_mixture_name in mixture_names
  is_task = task_or_mixture_name in task_names

  if not is_mixture:
    # Not a Mixture: fall back to the Task registry, or fail loudly.
    if not is_task:
      raise ValueError("No Task or Mixture found with name: %s" %
                       task_or_mixture_name)
    return TaskRegistry.get(task_or_mixture_name)

  if is_task:
    # Ambiguous name; surface the collision but keep Mixture precedence.
    logging.warning("%s is both a Task and a Mixture, returning Mixture",
                    task_or_mixture_name)
  return MixtureRegistry.get(task_or_mixture_name)