How to use the emmental.task.EmmentalTask class in emmental

To help you get started, we’ve selected a few emmental examples based on popular ways it is used in public projects.

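Before diving into the examples, here is a minimal, self-contained sketch of a typical EmmentalTask definition, distilled from the snippets below. The task name ("my_task"), module names ("encoder", "pred_head"), layer sizes, and log directory are placeholders, not anything prescribed by the library:

import emmental
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from emmental.scorer import Scorer
from emmental.task import EmmentalTask

# Initialize Emmental (config and logs are written to the given directory).
emmental.init("temp_example")

def ce_loss(module_name, immediate_output_dict, Y, active):
    # Cross-entropy over the named module's output, restricted to active examples.
    return F.cross_entropy(
        immediate_output_dict[module_name][0][active], (Y.view(-1))[active]
    )

def output(module_name, immediate_output_dict):
    # Softmax probabilities from the prediction head.
    return F.softmax(immediate_output_dict[module_name][0], dim=1)

task = EmmentalTask(
    name="my_task",
    module_pool=nn.ModuleDict(
        {"encoder": nn.Linear(10, 5), "pred_head": nn.Linear(5, 2)}
    ),
    task_flow=[
        {"name": "encoder", "module": "encoder", "inputs": [("_input_", "data")]},
        {"name": "pred_head", "module": "pred_head", "inputs": [("encoder", 0)]},
    ],
    loss_func=partial(ce_loss, "pred_head"),
    output_func=partial(output, "pred_head"),
    scorer=Scorer(metrics=["accuracy"]),
)

Every example below follows the same pattern: a module_pool of named nn.Modules, a task_flow describing how data moves through them, and a loss function, output function, and scorer attached to the prediction head.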

SenWu/emmental: tests/test_model.py (view on GitHub)
caplog.set_level(logging.INFO)

    dirpath = "temp_test_model"

    Meta.reset()
    emmental.init(dirpath)

    def ce_loss(module_name, immediate_output_dict, Y, active):
        return F.cross_entropy(
            immediate_output_dict[module_name][0][active], (Y.view(-1))[active]
        )

    def output(module_name, immediate_output_dict):
        return F.softmax(immediate_output_dict[module_name][0], dim=1)

    task1 = EmmentalTask(
        name="task_1",
        module_pool=nn.ModuleDict(
            {"m1": nn.Linear(10, 10, bias=False), "m2": nn.Linear(10, 2, bias=False)}
        ),
        task_flow=[
            {"name": "m1", "module": "m1", "inputs": [("_input_", "data")]},
            {"name": "m2", "module": "m2", "inputs": [("m1", 0)]},
        ],
        loss_func=partial(ce_loss, "m2"),
        output_func=partial(output, "m2"),
        scorer=Scorer(metrics=["accuracy"]),
    )

SenWu/emmental: tests/test_model.py (view on GitHub)
new_task1 = EmmentalTask(
        name="task_1",
        module_pool=nn.ModuleDict(
            {"m1": nn.Linear(10, 5, bias=False), "m2": nn.Linear(5, 2, bias=False)}
        ),
        task_flow=[
            {"name": "m1", "module": "m1", "inputs": [("_input_", "data")]},
            {"name": "m2", "module": "m2", "inputs": [("m1", 0)]},
        ],
        loss_func=partial(ce_loss, "m2"),
        output_func=partial(output, "m2"),
        scorer=Scorer(metrics=["accuracy"]),
    )

    task2 = EmmentalTask(
        name="task_2",
        module_pool=nn.ModuleDict(
            {"m1": nn.Linear(10, 5, bias=False), "m2": nn.Linear(5, 2, bias=False)}
        ),
        task_flow=[
            {"name": "m1", "module": "m1", "inputs": [("_input_", "data")]},
            {"name": "m2", "module": "m2", "inputs": [("m1", 0)]},
        ],
        loss_func=partial(ce_loss, "m2"),
        output_func=partial(output, "m2"),
        scorer=Scorer(metrics=["accuracy"]),
    )

    # Test w/ dataparallel
    model = EmmentalModel(name="test", tasks=task1)
SenWu/emmental: tests/test_e2e.py (view on GitHub)
# Create task
    def ce_loss(task_name, immediate_output_dict, Y, active):
        module_name = f"{task_name}_pred_head"
        return F.cross_entropy(
            immediate_output_dict[module_name][0][active], (Y.view(-1))[active]
        )

    def output(task_name, immediate_output_dict):
        module_name = f"{task_name}_pred_head"
        return F.softmax(immediate_output_dict[module_name][0], dim=1)

    task_metrics = {"task1": ["accuracy"], "task2": ["accuracy", "roc_auc"]}

    tasks = [
        EmmentalTask(
            name=task_name,
            module_pool=nn.ModuleDict(
                {
                    "input_module": nn.Linear(2, 8),
                    f"{task_name}_pred_head": nn.Linear(8, 2),
                }
            ),
            task_flow=[
                {
                    "name": "input",
                    "module": "input_module",
                    "inputs": [("_input_", "data")],
                },
                {
                    "name": f"{task_name}_pred_head",
                    "module": f"{task_name}_pred_head",
SenWu/emmental: src/emmental/model.py (view on GitHub)
r"""Build the MTL network using all tasks.

        Args:
          tasks(EmmentalTask or List[EmmentalTask]): A task or a list of tasks.

        """

        if not isinstance(tasks, Iterable):
            tasks = [tasks]
        for task in tasks:
            if task.name in self.task_names:
                raise ValueError(
                    f"Found duplicate task {task.name}, different task should use "
                    f"different task name."
                )
            if not isinstance(task, EmmentalTask):
                raise ValueError(f"Unrecognized task type {task}.")
            self.add_task(task)
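The validation above runs when tasks are handed to the model; a single task or a list of tasks is accepted. As a minimal usage sketch, assuming task1 and task2 were defined as in the test snippets above:

from emmental.model import EmmentalModel

# Builds the MTL network; each task's modules are merged into a shared module pool.
model = EmmentalModel(name="my_model", tasks=[task1, task2])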
SenWu/emmental: src/emmental/contrib/slicing/task.py (view on GitHub)
# Loss function
        if pred_task_name in slice_distribution:
            loss = partial(
                utils.ce_loss,
                pred_head_module_name,
                weight=move_to_device(
                    slice_distribution[pred_task_name],
                    Meta.config["model_config"]["device"],
                ),
            )
        else:
            loss = partial(utils.ce_loss, pred_head_module_name)

        tasks.append(
            EmmentalTask(
                name=pred_task_name,
                module_pool=pred_module_pool,
                task_flow=pred_task_flow,
                loss_func=loss,
                output_func=partial(utils.output, pred_head_module_name),
                scorer=task.scorer,
            )
        )

    # Create master task

    # Create task name
    master_task_name = task.name

    # Create attention module
    master_attention_module_name = f"{master_task_name}_attention"
SenWu/emmental: src/emmental/model.py (view on GitHub)
def add_task(self, task: EmmentalTask) -> None:
        """Add a single task into MTL network.

        Args:
          task: A task to add.
        """
        if not isinstance(task, EmmentalTask):
            raise ValueError(f"Unrecognized task type {task}.")

        if task.name in self.task_names:
            raise ValueError(
                f"Found duplicate task {task.name}, different task should use "
                f"different task name."
            )

        # Combine module_pool from all tasks
        for key in task.module_pool.keys():
            if key in self.module_pool.keys():
                if Meta.config["model_config"]["dataparallel"]:
                    task.module_pool[key] = nn.DataParallel(self.module_pool[key])
                else:
                    task.module_pool[key] = self.module_pool[key]
            else:
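Tasks can also be added to an existing model one at a time through this method; a minimal sketch, reusing the model and task2 from the earlier examples:

# Registers task_2 on the model; raises ValueError if a task with the same name
# already exists, and reuses modules already present in the shared module pool.
model.add_task(task2)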
HazyResearch/fonduer: src/fonduer/learning/task.py (view on GitHub)
"inputs": [
                        ("_input_", "feature_index"),
                        ("_input_", "feature_weight"),
                    ],
                },
                {
                    "name": f"{task_name}_pred_head",
                    "module": f"{task_name}_pred_head",
                    "inputs": None,
                },
            ]
        else:
            raise ValueError(f"Unrecognized model {model}.")

        tasks.append(
            EmmentalTask(
                name=task_name,
                module_pool=module_pool,
                task_flow=task_flow,
                loss_func=partial(loss, f"{task_name}_pred_head"),
                output_func=partial(output, f"{task_name}_pred_head"),
                scorer=Scorer(metrics=["accuracy", "precision", "recall", "f1"]),
            )
        )

    return tasks