def _train_rerendezvous(self, _, run_id, train_step, hooks, state_override=None):
    """
    Alternate sub-process trainer entry point used by tests that want to
    force a re-rendezvous after every iteration.
    """

    class RerendezvousCoordinatorP2P(CoordinatorP2P):
        def should_rendezvous(self, state):
            return True

    elastic_coordinator = RerendezvousCoordinatorP2P(
        c10d_backend="gloo",
        init_method=self.get_rdzv_url(run_id, self.min_size, self.max_size),
        max_num_trainers=self.max_size,
        process_group_timeout=10000,
    )
    state = self._train_common(
        _, elastic_coordinator, train_step, hooks, state_override
    )
    return state
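
# Hedged usage sketch (not part of the original source): a test could drive the
# forced re-rendezvous entry point the same way test_normal_flow_with_worker_stats
# below drives _train_with_worker_stats. `sample_train_step` and the empty hooks
# list are hypothetical stand-ins for whatever the real test passes.
def test_rerendezvous_every_step_sketch(self):
    run_id = self._generate_run_id()
    qouts, qerrs = [], []
    for _ in range(4):
        _, qout, qerr = self._spawn(
            self._train_rerendezvous, run_id, sample_train_step, []
        )
        qouts.append(qout)
        qerrs.append(qerr)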
def test_normal_flow_with_worker_stats(self):
    """
    Test a very simple 4-trainer case, where elastic_train_step
    also returns a non-None WorkerStats instance.
    """
    run_id = self._generate_run_id()
    nprocs = 4
    qouts = []
    qerrs = []
    prog_rates = [100, 95, 42, None]

    CoordinatorP2P.MONITOR_PROGRESS_FREQ = 1
    original_monitor_progress = CoordinatorP2P.monitor_progress

    def patched_monitor_progress(self, state, worker_stats):
        original_monitor_progress(self, state, worker_stats)

        # Save into state for retrieval in `_get_or_raise` below.
        if hasattr(self, "last_relative_prog_rate"):
            state._test_relative_prog_rate = self.last_relative_prog_rate
        if hasattr(self, "is_worker_straggler"):
            state._test_is_straggler = self.is_worker_straggler

    with patch.object(CoordinatorP2P, "monitor_progress", patched_monitor_progress):
        for i in range(0, nprocs):
            _, qout, qerr = self._spawn(
                self._train_with_worker_stats,
                run_id,
                # The original snippet is truncated here; the remaining
                # arguments are the train_step callable, hooks,
                # state_override, and prog_rates[i] as the per-worker
                # progress rate (see _train_with_worker_stats below).
            )
            qouts.append(qout)
            qerrs.append(qerr)
def _train(self, _, run_id, train_step, hooks, state_override=None):
    """
    Common sub-process trainer entry point used by most tests.
    """
    elastic_coordinator = CoordinatorP2P(
        c10d_backend="gloo",
        init_method=self.get_rdzv_url(run_id, self.min_size, self.max_size),
        max_num_trainers=self.max_size,
        process_group_timeout=10000,
    )
    return self._train_common(
        _, elastic_coordinator, train_step, hooks, state_override
    )
def _train_with_worker_stats(
    self,
    _,
    run_id,
    train_step,
    hooks,
    state_override=None,
    worker_stats_progress_rate=None,
):
    """
    Similar to `_train`, but uses a coordinator that validates the
    WorkerStats object.
    """
    fixed_worker_stats = TestWorkerStats(progress_rate=worker_stats_progress_rate)
    elastic_coordinator = CoordinatorP2P(
        c10d_backend="gloo",
        init_method=self.get_rdzv_url(run_id, self.min_size, self.max_size),
        max_num_trainers=self.max_size,
        process_group_timeout=10000,
    )
    return self._train_common(
        _,
        elastic_coordinator,
        train_step,
        hooks,
        state_override,
        fixed_worker_stats,
    )
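
# Hedged sketch (assumption, not from this source): the TestWorkerStats used
# above could be as small as a fixed-rate wrapper around the WorkerStats
# interface, assuming that interface lives at torchelastic.worker_stats and
# exposes a get_progress_rate() accessor.
from torchelastic.worker_stats import WorkerStats  # import path assumed


class TestWorkerStats(WorkerStats):
    """Reports a constant progress rate for tests."""

    def __init__(self, progress_rate=None):
        self.progress_rate = progress_rate

    def get_progress_rate(self):
        return self.progress_rate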
# Trick 2: linear layers are initialized by drawing weights from a
# zero-mean Gaussian with a standard deviation of 0.01. In the paper
# this was applied only to the final fc layer, but in practice we
# found it better for accuracy.
if isinstance(m, nn.Linear):
    m.weight.data.normal_(0, 0.01)

model.train()

torch.cuda.set_device(local_rank)
device = torch.cuda.current_device()
model.cuda()
log.info(f"Rank [{local_rank}] running on GPU [{device}]")

coordinator = CoordinatorP2P(
    c10d_backend=c10d_backend,
    init_method=rdzv_init_url,
    max_num_trainers=max_world_size,
    process_group_timeout=60000,
)

state = ImagenetState(
    model=model,
    params=training_params,
    dataset=train_dataset,
    num_epochs=training_params.num_epochs,
)

log.info("Entering torchelastic train_loop")
torchelastic.train(coordinator, train_step, state)
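
# Hedged launch sketch (not part of the original source): one way to start one
# trainer process per local GPU for the code above. `main_worker` is a
# hypothetical wrapper around that setup; torch.multiprocessing passes the
# process index as its first argument, which serves as local_rank here, and
# the remaining argument names are assumptions.
import torch
import torch.multiprocessing as mp


def main_worker(local_rank, rdzv_init_url, max_world_size):
    # Hypothetical wrapper: the body would be the model / coordinator /
    # train-loop setup shown above, with local_rank taken from the spawn index.
    pass


def launch(rdzv_init_url, max_world_size):
    nproc_per_node = torch.cuda.device_count()
    mp.spawn(
        main_worker,
        args=(rdzv_init_url, max_world_size),
        nprocs=nproc_per_node,
    )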