    def flow(self, X_dict: Dict[str, Any], task_names: List[str]) -> Dict[str, Any]:
        """Forward based on input and task flow.

        Note:
            We assume that all shared modules from all tasks are based on the
            same input.

        Args:
            X_dict: The input data.
            task_names: The names of the tasks to forward.

        Returns:
            The output of all forwarded modules.
        """
        X_dict = move_to_device(X_dict, Meta.config["model_config"]["device"])
        output_dict = dict(_input_=X_dict)

        # Call forward for each task
        for task_name in task_names:
            for action in self.task_flows[task_name]:
                if action["name"] not in output_dict:
                    if action["inputs"]:
                        try:
                            # Gather inputs produced by earlier actions (or by
                            # "_input_") via (action_name, output_index) pairs.
                            inputs = [
                                output_dict[action_name][output_index]
                                for action_name, output_index in action["inputs"]
                            ]
                        except Exception:
                            raise ValueError(f"Unrecognized action {action}.")
                        output = self.module_pool[action["module"]].forward(*inputs)
                    else:
                        # No declared inputs: pass the full output dict
                        # (assumed fallback mirroring the inputs-present branch;
                        # the snippet is truncated here).
                        output = self.module_pool[action["module"]].forward(output_dict)
                    if not isinstance(output, (list, tuple)):
                        output = [output]
                    output_dict[action["name"]] = output

        return output_dict
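# Hedged usage sketch (not from the source): the task-flow format that `flow`
# consumes. Each action names a module in `module_pool`; "inputs" holds
# (producer_name, output_index) pairs resolved against output_dict, where
# "_input_" refers to the raw X_dict. All names below are hypothetical.
import torch.nn as nn

toy_module_pool = nn.ModuleDict(
    {"encoder": nn.Linear(10, 8), "head": nn.Linear(8, 2)}
)
toy_task_flow = [
    {"name": "encoder", "module": "encoder", "inputs": [("_input_", "data")]},
    {"name": "head", "module": "head", "inputs": [("encoder", 0)]},
]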
    def write_config(self, config_filename: str = "config.yaml") -> None:
        """Write the config to TensorBoard and dump it to file.

        Args:
            config_filename: The config filename, defaults to "config.yaml".
        """
        config = json.dumps(Meta.config)
        self.writer.add_text(tag="config", text_string=config)
        super().write_config(config_filename)
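# Hedged sketch (assumes the standard torch.utils.tensorboard API; the log
# directory "runs/example" and the sample config are hypothetical).
# write_config above relies on SummaryWriter.add_text, which renders the
# JSON-serialized config in the TensorBoard TEXT tab.
import json
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/example")
writer.add_text(tag="config", text_string=json.dumps({"model_config": {"device": 0}}))
writer.close()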
                {
                    "name": pred_head_module_name,
                    "module": shared_pred_head_module_name,
                    "inputs": [(pred_transform_module_name, 0)],
                },
            ]
        )

        # Loss function: weight classes by the slice distribution when one is
        # available for this task, otherwise use the unweighted loss.
        if pred_task_name in slice_distribution:
            loss = partial(
                utils.ce_loss,
                pred_head_module_name,
                weight=move_to_device(
                    slice_distribution[pred_task_name],
                    Meta.config["model_config"]["device"],
                ),
            )
        else:
            loss = partial(utils.ce_loss, pred_head_module_name)

        tasks.append(
            EmmentalTask(
                name=pred_task_name,
                module_pool=pred_module_pool,
                task_flow=pred_task_flow,
                loss_func=loss,
                output_func=partial(utils.output, pred_head_module_name),
                scorer=task.scorer,
            )
        )
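# Hedged sketch (not from the source): why `partial` is used above. Binding
# the head module name up front gives every task's loss callable one shared
# (output_dict, Y, active) interface. The ce_loss body below is an assumed
# stand-in for utils.ce_loss, and "example_head" is hypothetical.
from functools import partial

import torch.nn.functional as F

def ce_loss(module_name, output_dict, Y, active, weight=None):
    # Cross entropy over the named head's first output, restricted to the
    # active (non-ignored) rows; `weight` rebalances classes if given.
    return F.cross_entropy(output_dict[module_name][0][active], Y[active], weight=weight)

loss = partial(ce_loss, "example_head")  # later invoked as loss(output_dict, Y, active)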
    def __init__(
        self,
        name: str,
        module_pool: nn.ModuleDict,
        task_flow: List[Dict[str, Any]],
        loss_func: Callable,
        output_func: Callable,
        scorer: Scorer,
        weight: Union[float, int] = 1.0,
    ) -> None:
        """Initialize EmmentalTask."""
        self.name = name
        assert isinstance(module_pool, nn.ModuleDict)
        self.module_pool = module_pool
        self.task_flow = task_flow
        self.loss_func = loss_func
        self.output_func = output_func
        self.scorer = scorer
        self.weight = weight

        if Meta.config["meta_config"]["verbose"]:
            logger.info(f"Created task: {self.name}")
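# Hedged usage sketch (not from the source): wiring a minimal EmmentalTask
# from the pieces validated above. The module/task names, the toy loss and
# output functions, and the Scorer metrics list are all assumptions.
import torch.nn as nn
import torch.nn.functional as F
from emmental.scorer import Scorer
from emmental.task import EmmentalTask

def toy_loss(output_dict, Y, active):
    # Stand-in matching the (output_dict, Y, active) loss interface.
    return F.cross_entropy(output_dict["head"][0][active], Y[active])

def toy_output(output_dict):
    # Stand-in output_func: probabilities from the head's raw output.
    return F.softmax(output_dict["head"][0], dim=1)

toy_task = EmmentalTask(
    name="toy_task",
    module_pool=nn.ModuleDict({"head": nn.Linear(8, 2)}),
    task_flow=[{"name": "head", "module": "head", "inputs": [("_input_", "data")]}],
    loss_func=toy_loss,
    output_func=toy_output,
    scorer=Scorer(metrics=["accuracy"]),
)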
        # Only merge lists of tensors; non-tensor fields are left as-is
        if isinstance(values[0], Tensor):
            item_tensor, item_mask_tensor = list_to_tensor(
                values,
                min_len=Meta.config["data_config"]["min_data_len"],
                max_len=Meta.config["data_config"]["max_data_len"],
            )
            X_batch[field_name] = item_tensor
            if item_mask_tensor is not None:
                X_batch[f"{field_name}_mask"] = item_mask_tensor

    # Labels are always merged; keep only the padded tensor, not the mask
    for label_name, values in Y_batch.items():
        Y_batch[label_name] = list_to_tensor(
            values,
            min_len=Meta.config["data_config"]["min_data_len"],
            max_len=Meta.config["data_config"]["max_data_len"],
        )[0]

    return dict(X_batch), dict(Y_batch)
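# Hedged usage sketch (assumes list_to_tensor keeps the (tensor, mask) return
# used above and lives at this import path; the sample values are hypothetical).
import torch
from emmental.utils.utils import list_to_tensor

batch_values = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
item_tensor, item_mask_tensor = list_to_tensor(batch_values)
# item_tensor is the (2, 3) padded matrix; item_mask_tensor marks the padding.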
"""
uid_dict: Dict[str, List[str]] = defaultdict(list)
loss_dict: Dict[str, ndarray] = defaultdict(float)
gold_dict: Dict[str, ndarray] = defaultdict(list)
prob_dict: Dict[str, ndarray] = defaultdict(list)
output_dict = self.flow(X_dict, list(task_to_label_dict.keys()))
# Calculate loss for each task
for task_name, label_name in task_to_label_dict.items():
Y = Y_dict[label_name]
# Select the active samples
if Meta.config["learner_config"]["ignore_index"] is not None:
if len(Y.size()) == 1:
active = Y.detach() != Meta.config["learner_config"]["ignore_index"]
else:
active = torch.any(
Y.detach() != Meta.config["learner_config"]["ignore_index"],
dim=1,
)
else:
active = torch.ByteTensor([True] * Y.size()[0])
# Only calculate the loss when active example exists
if active.any():
uid_dict[task_name] = [*itertools.compress(uids, active.numpy())]
loss_dict[task_name] = self.loss_funcs[task_name](
output_dict,
)
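# Hedged sketch (not from the source): the effect of the active mask above.
# Rows whose gold label equals ignore_index (-1 here, hypothetically) drop
# out of the loss, so padded or abstained examples contribute nothing.
import torch

Y = torch.tensor([2, -1, 0, -1])
active = Y.detach() != -1      # tensor([True, False, True, False])
kept_labels = Y[active]        # tensor([2, 0])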
        # Set up checkpoint unit
        self.checkpoint_unit = Meta.config["logging_config"]["counter_unit"]

        logger.info(
            f"Save checkpoints at {self.checkpoint_path} every "
            f"{self.checkpoint_freq} {self.checkpoint_unit}"
        )

        # Set up checkpoint metric
        self.checkpoint_metric = Meta.config["logging_config"]["checkpointer_config"][
            "checkpoint_metric"
        ]
        self.checkpoint_all_metrics = Meta.config["logging_config"][
            "checkpointer_config"
        ]["checkpoint_task_metrics"]

        # Collect all metrics to checkpoint
        if self.checkpoint_all_metrics is None:
            self.checkpoint_all_metrics = dict()
        if self.checkpoint_metric:
            self.checkpoint_all_metrics.update(self.checkpoint_metric)

        # Check evaluation metric mode
        for metric, mode in self.checkpoint_all_metrics.items():
            if mode not in ["min", "max"]:
                raise ValueError(
                    f"Unrecognized checkpoint metric mode {mode} for metric {metric}, "
                    f"must be 'min' or 'max'."
                )