Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _train_common(
    self,
    _,
    elastic_coordinator,
    train_step,
    hooks,
    state_override=None,
    worker_stats=None,
):
    """Run ``elastic_train_loop.train`` for one trainer and return its result.

    Args:
        _: Ignored positional argument (presumably the run id supplied by
            the spawning helper -- TODO confirm against callers).
        elastic_coordinator: Passed straight through to
            ``elastic_train_loop.train``.
        train_step: Per-step callable; wrapped by
            ``_make_elastic_train_step`` together with ``hooks`` and
            ``worker_stats``.
        hooks: Forwarded unchanged to ``_make_elastic_train_step``.
        state_override: Initial state to train from. When ``None``, a fresh
            ``TestState()`` is constructed instead.
        worker_stats: Forwarded unchanged to ``_make_elastic_train_step``.

    Returns:
        Whatever ``elastic_train_loop.train`` returns for the given
        coordinator, wrapped step and initial state.
    """
    # Only build a default TestState when the caller did not provide one.
    if state_override is not None:
        initial_state = state_override
    else:
        initial_state = TestState()
    wrapped_step = _make_elastic_train_step(train_step, hooks, worker_stats)
    return elastic_train_loop.train(elastic_coordinator, wrapped_step, initial_state)
def process_retryable_exception():
# Raise exception repeatedly
raise RuntimeError("train_step throws RuntimeError (retryable exception)")
hooks = {"process_retryable_exception": process_retryable_exception}
nprocs = 4
qouts = []
qerrs = []
for _ in range(0, nprocs - 2):
_, qout, qerr = self._spawn(self._train, run_id, _train_step, None)
qouts.append(qout)
qerrs.append(qerr)
with patch.object(elastic_train_loop, "MAX_FAILURES", 5):
for _ in range(nprocs - 2, nprocs):
_, qout, qerr = self._spawn(self._train, run_id, _train_step, hooks)
qouts.append(qout)
qerrs.append(qerr)
# Gather all "trained" values from all trainers, and ensure
# that the bad trainers raise the expected exception.
sums = []
for i in range(0, nprocs):
if i <= 1:
state = _get_or_raise(qouts[i], qerrs[i])
sums.append(state.total_sum)
# Initially, 4 trainers consume 2 samples each, then the
# surviving 2 trainers divide the remaining 20-8=12 samples, so
# the surviving trainers each successfully process 2+6=8 samples.
# nums keeps track of the samples "seen" so the surviving trainers