How to use torchelastic - 10 common examples

To help you get started, we’ve selected a few torchelastic examples based on popular ways it is used in public projects.


pytorch / elastic / examples / imagenet / main.py (view on GitHub)
    elif epoch < 80:
        lr = world_size * params.base_learning_rate * (0.1 ** (epoch // 30))
    else:
        lr = world_size * params.base_learning_rate * (0.1 ** 3)
    for param_group in optimizer.param_groups:
        lr_old = param_group["lr"]
        param_group["lr"] = lr
        # Trick: apply momentum correction when lr is updated
        if lr > lr_old:
            param_group["momentum"] = lr / lr_old * 0.9  # momentum
        else:
            param_group["momentum"] = 0.9  # default momentum
    return


class ImagenetState(torchelastic.State):
    """
    Client-provided State object; it is serializable and captures the entire
    state needed for executing one iteration of training
    """

    def __init__(self, model, params, dataset, num_epochs, epoch=0):
        self.model = model
        self.params = params
        self.dataset = dataset
        self.total_batch_size = params.batch_per_device

        self.num_epochs = num_epochs
        self.epoch = epoch

        self.iteration = 0
        self.data_start_index = 0
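
A State like this is handed to torchelastic's train loop together with a coordinator and a train_step callable, which is the same wiring the test helpers below use. The following is a minimal sketch of that wiring; the import paths, the rendezvous URL, and the shape of train_step are assumptions based on the 0.1.x-era API shown on this page, and model, params, and dataset are the objects the imagenet example builds during setup.

# Hedged sketch: wiring ImagenetState into the elastic train loop.
# NOTE: import paths and the train_step return shape are assumptions for the
# torchelastic 0.1.x API; model, params, dataset come from the example's setup.
import torchelastic.train_loop as elastic_train_loop
from torchelastic.p2p import CoordinatorP2P

coordinator = CoordinatorP2P(
    c10d_backend="gloo",
    init_method="etcd://localhost:2379/my_job",  # placeholder rendezvous URL
    max_num_trainers=4,
    process_group_timeout=10000,
)

def train_step(state):
    # One unit of work: train a batch/epoch and advance state.iteration,
    # state.epoch, etc. Depending on the version, a WorkerStats object may
    # also be returned alongside the updated state.
    return state

state = ImagenetState(model, params, dataset, num_epochs=90)
state = elastic_train_loop.train(coordinator, train_step, state)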

pytorch / elastic / test / p2p / elastic_trainer_test_base.py (view on GitHub)
    def _train_rerendezvous(self, _, run_id, train_step, hooks, state_override=None):
        """
        Alternate sub-process trainer entry point used by tests that want to
        force a re-rendezvous after every iteration.
        """

        class RerendezvousCoordinatorP2P(CoordinatorP2P):
            def should_rendezvous(self, state):
                return True

        elastic_coordinator = RerendezvousCoordinatorP2P(
            c10d_backend="gloo",
            init_method=self.get_rdzv_url(run_id, self.min_size, self.max_size),
            max_num_trainers=self.max_size,
            process_group_timeout=10000,
        )
        state = self._train_common(
            _, elastic_coordinator, train_step, hooks, state_override
        )
        return state
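
Overriding should_rendezvous to always return True makes the coordinator request a fresh rendezvous before every train step, so these tests repeatedly exercise the re-rendezvous path without actually adding or removing workers.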

pytorch / elastic / test / p2p / elastic_trainer_test_base.py (view on GitHub)
    def test_normal_flow_with_worker_stats(self):
        """
        Test a very simple 4 trainer case, where elastic_train_step
        also returns a non-None WorkerStats instance.
        """
        run_id = self._generate_run_id()

        nprocs = 4
        qouts = []
        qerrs = []
        prog_rates = [100, 95, 42, None]

        CoordinatorP2P.MONITOR_PROGRESS_FREQ = 1
        original_monitor_progress = CoordinatorP2P.monitor_progress

        def patched_monitor_progress(self, state, worker_stats):
            original_monitor_progress(self, state, worker_stats)

            # Save into state for retrieval in `_get_or_raise` below.
            if hasattr(self, "last_relative_prog_rate"):
                state._test_relative_prog_rate = self.last_relative_prog_rate
            if hasattr(self, "is_worker_straggler"):
                state._test_is_straggler = self.is_worker_straggler

        with patch.object(CoordinatorP2P, "monitor_progress", patched_monitor_progress):
            for i in range(0, nprocs):
                _, qout, qerr = self._spawn(
                    self._train_with_worker_stats,
                    run_id,

pytorch / elastic / test / p2p / elastic_trainer_test_base.py (view on GitHub)
    def _train(self, _, run_id, train_step, hooks, state_override=None):
        """
        Common sub-process trainer entry point used by most tests.
        """
        elastic_coordinator = CoordinatorP2P(
            c10d_backend="gloo",
            init_method=self.get_rdzv_url(run_id, self.min_size, self.max_size),
            max_num_trainers=self.max_size,
            process_group_timeout=10000,
        )
        return self._train_common(
            _, elastic_coordinator, train_step, hooks, state_override
        )

pytorch / elastic / test / p2p / elastic_trainer_test_base.py (view on GitHub)
    def _train_with_worker_stats(
        self,
        _,
        run_id,
        train_step,
        hooks,
        state_override=None,
        worker_stats_progress_rate=None,
    ):
        """
        Similar to `_train`, but uses a coordinator that validates WorkerStats object
        """
        fixed_worker_stats = TestWorkerStats(progress_rate=worker_stats_progress_rate)

        elastic_coordinator = CoordinatorP2P(
            c10d_backend="gloo",
            init_method=self.get_rdzv_url(run_id, self.min_size, self.max_size),
            max_num_trainers=self.max_size,
            process_group_timeout=10000,
        )
        return self._train_common(
            _,
            elastic_coordinator,
            train_step,
            hooks,
            state_override,
            fixed_worker_stats,
        )

pytorch / elastic / test / p2p / elastic_trainer_test_base.py (view on GitHub)
    def _train_common(
        self,
        _,
        elastic_coordinator,
        train_step,
        hooks,
        state_override=None,
        worker_stats=None,
    ):
        state = TestState() if state_override is None else state_override

        elastic_train_step = _make_elastic_train_step(train_step, hooks, worker_stats)
        state = elastic_train_loop.train(elastic_coordinator, elastic_train_step, state)
        return state
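
_train_common is the piece every entry point above shares: build a CoordinatorP2P (or a subclass), wrap the user-supplied train_step together with the test hooks and an optional WorkerStats via _make_elastic_train_step, then hand everything to elastic_train_loop.train, which returns the final State once training completes.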

pytorch / elastic / test / p2p / elastic_trainer_test_base.py (view on GitHub)
        def process_retryable_exception():
            # Raise exception repeatedly
            raise RuntimeError("train_step throws RuntimeError (retryable exception)")

        hooks = {"process_retryable_exception": process_retryable_exception}

        nprocs = 4
        qouts = []
        qerrs = []

        for _ in range(0, nprocs - 2):
            _, qout, qerr = self._spawn(self._train, run_id, _train_step, None)
            qouts.append(qout)
            qerrs.append(qerr)

        with patch.object(elastic_train_loop, "MAX_FAILURES", 5):
            for _ in range(nprocs - 2, nprocs):
                _, qout, qerr = self._spawn(self._train, run_id, _train_step, hooks)
                qouts.append(qout)
                qerrs.append(qerr)

        # Gather all "trained" values from all trainers, and ensure
        # that the bad trainers raise the expected exception.
        sums = []
        for i in range(0, nprocs):
            if i <= 1:
                state = _get_or_raise(qouts[i], qerrs[i])
                sums.append(state.total_sum)
                # Initially, 4 trainers consume 2 samples each, then the
                # surviving 2 trainers divide the remaining 20-8=12 samples, so
                # the surviving trainers each successfully process 2+6=8 samples.
                # nums keeps track of the samples "seen" so the surviving trainers

pytorch / elastic / torchelastic / rendezvous / etcd_rendezvous.py (view on GitHub)
    def join_rendezvous(self, expected_version):
        # Use compare-and-swap to add self to rendezvous state:
        while True:
            cas_delay()
            active_version, state = self.get_rdzv_state()

            if state["status"] != "joinable":
                raise EtcdRendezvousRetryableFailure(
                    "Rendezvous state became non-joinable before we could join. "
                    "Must join next one."
                )

            if state["version"] != expected_version:
                raise EtcdRendezvousRetryImmediately(
                    "Rendezvous version changed. Must try join the new one."
                )

            assert (
                len(state["participants"]) < self._num_max_workers
            ), "Logic error: joinable rendezvous should always have space left"

            this_rank = len(state["participants"])
            state["participants"].append(this_rank)

            # When reaching min workers, or changing state to frozen, we'll set
            # the active_version node to be ephemeral.
            if len(state["participants"]) == self._num_max_workers:
                state["status"] = "frozen"
                state["keep_alives"] = []
                set_ttl = CONST_ETCD_FROZEN_TTL

pytorch / elastic / torchelastic / rendezvous / etcd_rendezvous.py (view on GitHub)
    def announce_self_waiting(self, expected_version):
        while True:
            cas_delay()
            active_version, state = self.get_rdzv_state()

            if state["status"] != "final" or state["version"] != expected_version:
                raise EtcdRendezvousRetryImmediately()

            # Increment counter to signal an additional waiting worker.
            state["num_workers_waiting"] += 1

            try:
                active_version = self.client.test_and_set(
                    key=self.get_path("/rdzv/active_version"),
                    value=json.dumps(state),
                    prev_value=active_version.value,
                )
                return active_version

            except etcd.EtcdCompareFailed:
                log.info("Announce self as waiting CAS unsuccessful, retrying")

pytorch / elastic / torchelastic / rendezvous / etcd_rendezvous.py (view on GitHub)
    def confirm_membership(self, expected_version, this_rank):
        # Compare-and-swap loop
        while True:
            cas_delay()
            active_version, state = self.get_rdzv_state()

            if state["status"] != "frozen":
                raise EtcdRendezvousRetryImmediately(
                    "Rendezvous no longer frozen, before we confirmed. "
                    "Must join next one"
                )
            if state["version"] != expected_version:
                raise EtcdRendezvousRetryImmediately(
                    "Rendezvous version changed. Must try join the new one."
                )

            this_lease_key = self.get_path(
                "/rdzv/v_{}/rank_{}".format(expected_version, this_rank)
            )
            self.client.set(this_lease_key, value=None, ttl=CONST_WORKER_KEEPALIVE_TTL)

            state["keep_alives"].append(this_lease_key)
            if len(state["keep_alives"]) == len(state["participants"]):
                # Everyone confirmed (this rank is last to do so)