How to use the torchelastic.train_loop function in torchelastic

To help you get started, we've selected a few torchelastic examples drawn from popular ways the library is used in public projects.

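Both examples below revolve around a single call, elastic_train_loop.train(coordinator, elastic_train_step, state), which runs the training loop under an elastic coordinator and returns the final state. As a quick orientation, here is a minimal sketch of that call; the CoordinatorP2P import path and constructor arguments, and the MyState / my_elastic_train_step names, are illustrative assumptions rather than code from this page.

import torchelastic.train_loop as elastic_train_loop         # module name as the tests below reference it
from torchelastic.p2p.coordinator_p2p import CoordinatorP2P  # assumed import path

# Assumed user-defined pieces (hypothetical, not from this page): a State
# subclass holding your training progress, and a step callable that the
# elastic loop invokes repeatedly.
state = MyState()

coordinator = CoordinatorP2P(
    c10d_backend="gloo",                           # assumed constructor
    init_method="etcd://localhost:2379/demo_job",  # arguments; consult your
    max_num_trainers=4,                            # torchelastic version's docs
    process_group_timeout=10000,
)

# train() rendezvous with the other workers, drives elastic_train_step
# against the synchronized state, and returns the final state.
state = elastic_train_loop.train(coordinator, my_elastic_train_step, state)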

From pytorch/elastic (test/p2p/elastic_trainer_test_base.py):
# A method on the test base class. TestState, _make_elastic_train_step,
# and the elastic_train_loop module (torchelastic.train_loop) are
# defined or imported elsewhere in this file.
def _train_common(
    self,
    _,
    elastic_coordinator,
    train_step,
    hooks,
    state_override=None,
    worker_stats=None,
):
    state = TestState() if state_override is None else state_override

    # Wrap the plain train_step and the test hooks into the callable the
    # elastic loop expects, then run the loop to completion.
    elastic_train_step = _make_elastic_train_step(train_step, hooks, worker_stats)
    state = elastic_train_loop.train(elastic_coordinator, elastic_train_step, state)
    return state
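
_make_elastic_train_step is a helper defined elsewhere in this test file; it adapts a plain train_step plus the test hooks into the callable the loop drives. A rough sketch of such an adapter, assuming the loop expects a (state) -> (state, worker_stats) callable, could look like:

# Hypothetical adapter in the spirit of _make_elastic_train_step; the
# (state) -> (state, worker_stats) return contract is an assumption here.
def make_elastic_train_step(train_step, hooks, worker_stats=None):
    def elastic_train_step(state):
        # Tests inject failures by registering hooks that raise.
        if hooks and "process_retryable_exception" in hooks:
            hooks["process_retryable_exception"]()
        state = train_step(state)
        return state, worker_stats
    return elastic_train_step
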
From pytorch/elastic (test/p2p/elastic_trainer_test_base.py):
# Excerpted from inside a test method on the same base class; run_id,
# _train_step, self._spawn, and patch (unittest.mock) are defined or
# imported elsewhere in the file.
def process_retryable_exception():
    # Raise on every invocation so the affected trainer keeps failing.
    raise RuntimeError("train_step throws RuntimeError (retryable exception)")

hooks = {"process_retryable_exception": process_retryable_exception}

nprocs = 4
qouts = []
qerrs = []

# Spawn two healthy trainers with no failure hooks.
for _ in range(nprocs - 2):
    _, qout, qerr = self._spawn(self._train, run_id, _train_step, None)
    qouts.append(qout)
    qerrs.append(qerr)

# Spawn two faulty trainers whose hook raises inside every train_step;
# MAX_FAILURES is patched to 5 so their retry budget is bounded.
with patch.object(elastic_train_loop, "MAX_FAILURES", 5):
    for _ in range(nprocs - 2, nprocs):
        _, qout, qerr = self._spawn(self._train, run_id, _train_step, hooks)
        qouts.append(qout)
        qerrs.append(qerr)
# Gather all "trained" values from all trainers, and ensure that the
# bad trainers raise the expected exception.
sums = []
for i in range(nprocs):
    if i <= 1:
        state = _get_or_raise(qouts[i], qerrs[i])
        sums.append(state.total_sum)
        # Initially, 4 trainers consume 2 samples each; then the surviving
        # 2 trainers divide the remaining 20 - 8 = 12 samples, so each
        # survivor successfully processes 2 + 6 = 8 samples. nums keeps
        # track of the samples "seen", so each surviving trainer should
        # report nums == 8.
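
Patching elastic_train_loop.MAX_FAILURES to 5 bounds how many times the loop retries a trainer whose step keeps raising, so the two hooked trainers exhaust their budget and surface the RuntimeError, while trainers 0 and 1 run to completion. With the 20-sample dataset the comment assumes, each of the four trainers consumes 2 samples before the failures hit, and the two survivors split the remaining 12, finishing with 8 samples apiece.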