How to use the emmental.meta.Meta class in emmental

To help you get started, we've selected a few emmental examples based on popular ways Meta is used in public projects.

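Meta is emmental's global state object: calling emmental.init(...) populates Meta.config with the merged configuration and sets Meta.log_path for the run, and the snippets below all read from it. A minimal hedged sketch of that flow (the "logs" directory and the printed keys are illustrative, assuming the emmental.init entry point):

import emmental
from emmental.meta import Meta

# Initialize emmental's global state; afterwards the Meta singleton
# exposes the merged configuration and the run's log directory.
emmental.init("logs")

print(Meta.log_path)                          # the run's log directory
print(Meta.config["model_config"]["device"])  # e.g. -1 for CPU or a GPU index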

github SenWu / emmental / src / emmental / model.py
def add_task(self, task: EmmentalTask) -> None:
    r"""Add a single task into MTL network.

    Args:
      task(EmmentalTask): A task to add.

    """

    # Combine module_pool from all tasks
    for key in task.module_pool.keys():
        if key in self.module_pool.keys():
            if Meta.config["model_config"]["dataparallel"]:
                task.module_pool[key] = nn.DataParallel(self.module_pool[key])
            else:
                task.module_pool[key] = self.module_pool[key]
        else:
            if Meta.config["model_config"]["dataparallel"]:
                self.module_pool[key] = nn.DataParallel(task.module_pool[key])
            else:
                self.module_pool[key] = task.module_pool[key]
    # Collect task name
    self.task_names.add(task.name)
    # Collect task flow
    self.task_flows[task.name] = task.task_flow
    # Collect loss function
    self.loss_funcs[task.name] = task.loss_func
    # Collect output function
    self.output_funcs[task.name] = task.output_func
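
Here add_task consults Meta.config["model_config"]["dataparallel"] to decide whether shared modules get wrapped in nn.DataParallel. A hedged sketch of overriding that flag at init time (whether a partial nested dict merges with the defaults like this is an assumption about emmental's config handling):

import emmental

# "model_config" / "dataparallel" are the keys read in add_task above.
emmental.init("logs", config={"model_config": {"dataparallel": False}})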

github SenWu / emmental / src / emmental / logging / tensorboard_writer.py
def __init__(self) -> None:
    """Initialize TensorBoardWriter."""
    super().__init__()

    # Set up tensorboard summary writer
    self.writer = SummaryWriter(Meta.log_path)
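
The writer simply points torch's SummaryWriter at Meta.log_path, so TensorBoard event files land in the run's log directory. A hedged usage sketch (it assumes emmental.init has already run and that TensorBoardWriter forwards add_scalar to the underlying SummaryWriter):

import emmental
from emmental.logging.tensorboard_writer import TensorBoardWriter

emmental.init("logs")

# Event files are written under Meta.log_path; "train/loss" is an
# illustrative metric name, not one emmental defines.
writer = TensorBoardWriter()
writer.add_scalar("train/loss", 0.5, 1)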

github SenWu / emmental / src / emmental / logging / checkpointer.py
def __init__(self) -> None:
    """Initialize the checkpointer."""
    # Set up checkpoint directory
    self.checkpoint_path = Meta.config["logging_config"]["checkpointer_config"][
        "checkpoint_path"
    ]
    if self.checkpoint_path is None:
        self.checkpoint_path = Meta.log_path

    # Create checkpoint directory if necessary
    if not os.path.exists(self.checkpoint_path):
        os.makedirs(self.checkpoint_path)

    # Set up checkpoint frequency
    self.checkpoint_freq = (
        Meta.config["logging_config"]["evaluation_freq"]
        * Meta.config["logging_config"]["checkpointer_config"]["checkpoint_freq"]
    )

    if self.checkpoint_freq <= 0:
        # The excerpt cuts off here; in the source, a non-positive
        # frequency is rejected with an error along these lines.
        raise ValueError(
            f"Invalid checkpoint freq {self.checkpoint_freq}, "
            f"must be greater than 0."
        )
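
The effective frequency is the product of logging_config's evaluation_freq and the checkpointer's checkpoint_freq, so both must be positive. A hedged sketch of setting them (values are illustrative; the nested-merge behavior is again an assumption):

import emmental

# With these values a checkpoint is written every 2 * 1 = 2 evaluation
# intervals, per the multiplication in __init__ above.
emmental.init(
    "logs",
    config={
        "logging_config": {
            "evaluation_freq": 2,
            "checkpointer_config": {"checkpoint_freq": 1},
        }
    },
)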

github SenWu / emmental / src / emmental / logging / log_writer.py
def write_log(self, log_filename: str = "log.json") -> None:
    """Dump the log to file.

    Args:
      log_filename: The log filename, defaults to "log.json".
    """
    log_path = os.path.join(Meta.log_path, log_filename)
    with open(log_path, "w") as f:
        json.dump(self.run_log, f)
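
Because write_log joins the filename onto Meta.log_path, the dump always lands inside the run's log directory. A small hedged sketch of locating it:

import os
from emmental.meta import Meta

# Valid once emmental.init(...) has set Meta.log_path for the run.
print(os.path.join(Meta.log_path, "log.json"))  # where write_log writes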

github SenWu / emmental / src / emmental / data.py
        # ... (excerpt begins mid-loop: the elided lines above unpack each
        # batch item into x_dict / y_dict and fill X_batch field by field)
                X_batch[field_name].append(value)
        for label_name, value in y_dict.items():
            if isinstance(value, list):
                Y_batch[label_name] += value
            else:
                Y_batch[label_name].append(value)

    field_names = copy.deepcopy(list(X_batch.keys()))

    for field_name in field_names:
        values = X_batch[field_name]
        # Only merge list of tensors
        if isinstance(values[0], Tensor):
            item_tensor, item_mask_tensor = list_to_tensor(
                values,
                min_len=Meta.config["data_config"]["min_data_len"],
                max_len=Meta.config["data_config"]["max_data_len"],
            )
            X_batch[field_name] = item_tensor
            if item_mask_tensor is not None:
                X_batch[f"{field_name}_mask"] = item_mask_tensor

    for label_name, values in Y_batch.items():
        Y_batch[label_name] = list_to_tensor(
            values,
            min_len=Meta.config["data_config"]["min_data_len"],
            max_len=Meta.config["data_config"]["max_data_len"],
        )[0]

    return dict(X_batch), dict(Y_batch)
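
Both calls to list_to_tensor pull their padding and truncation bounds from data_config in Meta. A hedged sketch of overriding those bounds (values are illustrative; treating 0 as "no cap" is an assumption):

import emmental

# min_data_len / max_data_len are the keys read in the collate code above.
emmental.init(
    "logs",
    config={"data_config": {"min_data_len": 0, "max_data_len": 128}},
)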

github SenWu / emmental / src / emmental / model.py
        # ... (excerpt begins mid-method: uid_dict and loss_dict are created
        # in the elided lines just above)
        gold_dict: Dict[str, ndarray] = defaultdict(list)
        prob_dict: Dict[str, ndarray] = defaultdict(list)

        output_dict = self.flow(X_dict, list(task_to_label_dict.keys()))

        # Calculate loss for each task
        for task_name, label_name in task_to_label_dict.items():
            Y = Y_dict[label_name]

            # Select the active samples
            if Meta.config["learner_config"]["ignore_index"] is not None:
                if len(Y.size()) == 1:
                    active = Y.detach() != Meta.config["learner_config"]["ignore_index"]
                else:
                    active = torch.any(
                        Y.detach() != Meta.config["learner_config"]["ignore_index"],
                        dim=1,
                    )
            else:
                active = torch.BoolTensor([True] * Y.size()[0])  # type: ignore

            # Only calculate the loss when active example exists
            if active.any():
                uid_dict[task_name] = [*itertools.compress(uids, active.numpy())]

                loss_dict[task_name] = self.loss_funcs[task_name](
                    output_dict,
                    move_to_device(
                        Y_dict[label_name], Meta.config["model_config"]["device"]
                    ),
                    move_to_device(active, Meta.config["model_config"]["device"]),
                )
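
The active-sample mask here is driven entirely by learner_config's ignore_index: when it is set, labels equal to that sentinel are excluded from the loss. A hedged sketch of enabling it (-100 is an illustrative sentinel, not an emmental default):

import emmental

# Labels equal to -100 will be masked out by the `active` logic above.
emmental.init(
    "logs",
    config={"learner_config": {"ignore_index": -100}},
)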