Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def add_task(self, task: EmmentalTask) -> None:
    r"""Add a single task into MTL network.

    The task's modules are merged into the model's ``module_pool`` by key:

    * A key the model has not seen yet is registered in ``self.module_pool``.
      It is wrapped in ``nn.DataParallel`` when
      ``Meta.config["model_config"]["dataparallel"]`` is set.
    * A key that is already registered is shared. The task's entry is
      replaced by the model's existing module object.

    The task's name, task flow, loss function and output function are then
    recorded under ``task.name``.

    Args:
      task(EmmentalTask): A task to add.
    """
    # Combine module_pool from all tasks
    for key in task.module_pool.keys():
        if key in self.module_pool:
            # The shared module was already wrapped (if DataParallel is
            # enabled) when it was first registered. Wrapping it again would
            # nest DataParallel and hand the task a different object than the
            # one the model actually uses, so reuse it as-is.
            task.module_pool[key] = self.module_pool[key]
        elif Meta.config["model_config"]["dataparallel"]:
            self.module_pool[key] = nn.DataParallel(task.module_pool[key])
        else:
            self.module_pool[key] = task.module_pool[key]
    # Collect task name
    self.task_names.add(task.name)
    # Collect task flow
    self.task_flows[task.name] = task.task_flow
    # Collect loss function
    self.loss_funcs[task.name] = task.loss_func
    # Collect output function
    self.output_funcs[task.name] = task.output_func
def __init__(self) -> None:
    """Initialize TensorBoardWriter.

    Runs the parent initializer first, then creates a ``SummaryWriter`` whose
    output directory is ``Meta.log_path``.
    """
    super().__init__()
    log_dir = Meta.log_path
    # TensorBoard output for this run goes into the run's log directory.
    self.writer = SummaryWriter(log_dir)
def __init__(self) -> None:
"""Initialize the checkpointer."""
# Set up checkpoint directory
self.checkpoint_path = Meta.config["logging_config"]["checkpointer_config"][
"checkpoint_path"
]
if self.checkpoint_path is None:
self.checkpoint_path = Meta.log_path
# Create checkpoint directory if necessary
if not os.path.exists(self.checkpoint_path):
os.makedirs(self.checkpoint_path)
# Set up checkpoint frequency
self.checkpoint_freq = (
Meta.config["logging_config"]["evaluation_freq"]
* Meta.config["logging_config"]["checkpointer_config"]["checkpoint_freq"]
)
if self.checkpoint_freq <= 0:
def write_log(self, log_filename: str = "log.json") -> None:
    """Serialize ``self.run_log`` as JSON into the run's log directory.

    The file is created (or truncated) at ``Meta.log_path/log_filename``.

    Args:
      log_filename: Name of the file to write inside ``Meta.log_path``,
        defaults to "log.json".
    """
    with open(os.path.join(Meta.log_path, log_filename), "w") as out_file:
        json.dump(obj=self.run_log, fp=out_file)
X_batch[field_name].append(value)
for label_name, value in y_dict.items():
if isinstance(value, list):
Y_batch[label_name] += value
else:
Y_batch[label_name].append(value)
field_names = copy.deepcopy(list(X_batch.keys()))
for field_name in field_names:
values = X_batch[field_name]
# Only merge list of tensors
if isinstance(values[0], Tensor):
item_tensor, item_mask_tensor = list_to_tensor(
values,
min_len=Meta.config["data_config"]["min_data_len"],
max_len=Meta.config["data_config"]["max_data_len"],
)
X_batch[field_name] = item_tensor
if item_mask_tensor is not None:
X_batch[f"{field_name}_mask"] = item_mask_tensor
for label_name, values in Y_batch.items():
Y_batch[label_name] = list_to_tensor(
values,
min_len=Meta.config["data_config"]["min_data_len"],
max_len=Meta.config["data_config"]["max_data_len"],
)[0]
return dict(X_batch), dict(Y_batch)
gold_dict: Dict[str, ndarray] = defaultdict(list)
prob_dict: Dict[str, ndarray] = defaultdict(list)
output_dict = self.flow(X_dict, list(task_to_label_dict.keys()))
# Calculate loss for each task
for task_name, label_name in task_to_label_dict.items():
Y = Y_dict[label_name]
# Select the active samples
if Meta.config["learner_config"]["ignore_index"] is not None:
if len(Y.size()) == 1:
active = Y.detach() != Meta.config["learner_config"]["ignore_index"]
else:
active = torch.any(
Y.detach() != Meta.config["learner_config"]["ignore_index"],
dim=1,
)
else:
active = torch.BoolTensor([True] * Y.size()[0]) # type: ignore
# Only calculate the loss when active example exists
if active.any():
uid_dict[task_name] = [*itertools.compress(uids, active.numpy())]
loss_dict[task_name] = self.loss_funcs[task_name](
output_dict,
move_to_device(
Y_dict[label_name], Meta.config["model_config"]["device"]
),
move_to_device(active, Meta.config["model_config"]["device"]),
)