Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
store, rank, world_size = elastic_coordinator.rendezvous_barrier()
elastic_coordinator.init_process_group()
# load checkpoint if necessary
state = checkpoint_util.load_checkpoint(state, rank)
state_sync_start_time = time.time()
state.sync(world_size, rank)
publish_metric(
"torchelastic",
"state_sync.duration.ms",
get_elapsed_time_ms(state_sync_start_time),
)
checkpoint_util.set_checkpoint_loaded()
log.info("Rank {0} synced state with other nodes".format(rank))
except StopException:
log.info("Rank {0} received stopped signal. Exiting training.".format(rank))
break
except RuntimeError as e:
# See: https://github.com/pytorch/elastic/issues/7
elastic_coordinator.on_error(e)
state.apply_snapshot(snapshot)
failure_count += 1
continue
except (NonRetryableException, Exception) as e:
elastic_coordinator.on_error(e)
raise
finally:
publish_metric(
"torch_elastic",
"outer_train_loop.duration.ms",
get_elapsed_time_ms(start_time),
def rendezvous_barrier(self):
    """Tear down the current process group and join the next rendezvous round.

    On success, ``self.store``, ``self.rank`` and ``self.world_size`` are
    replaced with the values returned by ``self.rendezvous.next_rendezvous()``,
    the new rank/world size are logged, and ``self.is_worker_straggler`` is
    reset to False.

    Raises:
        StopException: if ``next_rendezvous()`` raises
            ``RendezvousClosedException``; ``self.stop_training`` is set to
            True before raising.
        NonRetryableException: for any other exception raised while
            destroying-then-rejoining the rendezvous. The original exception
            is chained as ``__cause__`` so its traceback is preserved.
    """
    self._destroy_process_group()
    try:
        self.store, self.rank, self.world_size = self.rendezvous.next_rendezvous()
    except RendezvousClosedException:
        # Flag on the instance that training should stop, then signal the
        # caller. self.rank still holds whatever value it had before this
        # call, because the tuple assignment above did not complete.
        self.stop_training = True
        raise StopException(
            "Rank {0} received RendezvousClosedException."
            " Raising a StopException".format(self.rank)
        )
    except Exception as e:
        # RuntimeError is a subclass of Exception, so it need not be listed
        # separately. Chain the cause explicitly so the underlying failure
        # is reported as the direct cause, not as "during handling of ...".
        raise NonRetryableException(
            "Rank {0} received an Exception."
            " Detailed message: {1}".format(self.rank, str(e))
        ) from e
    log.info(
        "Got next rendezvous: rank {0}, world size {1}".format(
            self.rank, self.world_size
        )
    )
    # Assume straggler state is unreliable after rendezvous
    self.is_worker_straggler = False