How to use the torchelastic.p2p.coordinator_p2p.CoordinatorP2P class in torchelastic

To help you get started, we’ve selected a few torchelastic examples, based on popular ways it is used in public projects.
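
Each example below follows the same basic pattern: construct a CoordinatorP2P with a c10d backend, a rendezvous init_method URL, an upper bound on the number of trainers, and a process-group timeout (in milliseconds), then pass it to torchelastic.train together with a train_step function and a state object. Here is a minimal sketch of that pattern; rdzv_init_url, max_world_size, train_step, and state are placeholders you would supply yourself:

import torchelastic
from torchelastic.p2p.coordinator_p2p import CoordinatorP2P

# Placeholders: supply your own rendezvous URL, world-size bound,
# train_step function, and state object.
coordinator = CoordinatorP2P(
    c10d_backend="gloo",              # the c10d backend used in the examples below
    init_method=rdzv_init_url,        # rendezvous URL (see get_rdzv_url in the tests)
    max_num_trainers=max_world_size,  # upper bound on the elastic world size
    process_group_timeout=60000,      # milliseconds
)

# Run the elastic training loop with the coordinator.
torchelastic.train(coordinator, train_step, state)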

From pytorch/elastic: test/p2p/elastic_trainer_test_base.py
def _train_rerendezvous(self, _, run_id, train_step, hooks, state_override=None):
        """
        Alternate sub-process trainer entry point used by tests that want to
        force a re-rendezvous after every iteration.
        """

        class RerendezvousCoordinatorP2P(CoordinatorP2P):
            def should_rendezvous(self, state):
                return True

        elastic_coordinator = RerendezvousCoordinatorP2P(
            c10d_backend="gloo",
            init_method=self.get_rdzv_url(run_id, self.min_size, self.max_size),
            max_num_trainers=self.max_size,
            process_group_timeout=10000,
        )
        state = self._train_common(
            _, elastic_coordinator, train_step, hooks, state_override
        )
        return state
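
Overriding should_rendezvous to always return True forces the coordinator to perform a fresh rendezvous after every iteration, which is how this test exercises elasticity under constant membership changes.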

From pytorch/elastic: test/p2p/elastic_trainer_test_base.py
def test_normal_flow_with_worker_stats(self):
        """
        Test a very simple 4 trainer case, where elastic_train_step
        also returns a non-None WorkerStats instance.
        """
        run_id = self._generate_run_id()

        nprocs = 4
        qouts = []
        qerrs = []
        prog_rates = [100, 95, 42, None]

        CoordinatorP2P.MONITOR_PROGRESS_FREQ = 1
        original_monitor_progress = CoordinatorP2P.monitor_progress

        def patched_monitor_progress(self, state, worker_stats):
            original_monitor_progress(self, state, worker_stats)

            # Save into state for retrieval in `_get_or_raise` below.
            if hasattr(self, "last_relative_prog_rate"):
                state._test_relative_prog_rate = self.last_relative_prog_rate
            if hasattr(self, "is_worker_straggler"):
                state._test_is_straggler = self.is_worker_straggler

        with patch.object(CoordinatorP2P, "monitor_progress", patched_monitor_progress):
            for i in range(0, nprocs):
                _, qout, qerr = self._spawn(
                    self._train_with_worker_stats,
                    run_id,

From pytorch/elastic: test/p2p/elastic_trainer_test_base.py
def _train(self, _, run_id, train_step, hooks, state_override=None):
        """
        Common sub-process trainer entry point used by most tests.
        """
        elastic_coordinator = CoordinatorP2P(
            c10d_backend="gloo",
            init_method=self.get_rdzv_url(run_id, self.min_size, self.max_size),
            max_num_trainers=self.max_size,
            process_group_timeout=10000,
        )
        return self._train_common(
            _, elastic_coordinator, train_step, hooks, state_override
        )

From pytorch/elastic: test/p2p/elastic_trainer_test_base.py
def _train_with_worker_stats(
        self,
        _,
        run_id,
        train_step,
        hooks,
        state_override=None,
        worker_stats_progress_rate=None,
    ):
        """
        Similar to `_train`, but uses a coordinator that validates WorkerStats object
        """
        fixed_worker_stats = TestWorkerStats(progress_rate=worker_stats_progress_rate)

        elastic_coordinator = CoordinatorP2P(
            c10d_backend="gloo",
            init_method=self.get_rdzv_url(run_id, self.min_size, self.max_size),
            max_num_trainers=self.max_size,
            process_group_timeout=10000,
        )
        return self._train_common(
            _,
            elastic_coordinator,
            train_step,
            hooks,
            state_override,
            fixed_worker_stats,
        )

From pytorch/elastic: examples/imagenet/main.py
            # Trick 2: linear layers are initialized by
            # drawing weights from a zero-mean Gaussian with
            # standard deviation of 0.01. In the paper it was only
            # fc layer, but in practice we found this better for
            # accuracy.
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)

    model.train()

    torch.cuda.set_device(local_rank)
    device = torch.cuda.current_device()
    model.cuda()
    log.info(f"Rank [{local_rank}] running on GPU [{device}]")

    coordinator = CoordinatorP2P(
        c10d_backend=c10d_backend,
        init_method=rdzv_init_url,
        max_num_trainers=max_world_size,
        process_group_timeout=60000,
    )

    state = ImagenetState(
        model=model,
        params=training_params,
        dataset=train_dataset,
        num_epochs=training_params.num_epochs,
    )

    log.info(f"Entering torchelastic train_loop")
    torchelastic.train(coordinator, train_step, state)
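
torchelastic.train runs the elastic training loop: it repeatedly applies train_step to the shared state while using the coordinator's hooks (such as should_rendezvous and monitor_progress, shown in the test snippets above) to decide when to re-form the process group and to track per-worker progress.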