@metrics.profile("torchelastic")
def on_error(self, e):
    log.error(
        "Rank: {0}\n"
        "Error: {1}\n"
        "ErrorType: {2}\n"
        "StackTrace: {3}".format(self.rank, str(e), type(e), traceback.format_exc())
    )
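
# --- Illustrative sketch (not part of the original file) ---
# A driver loop might funnel worker exceptions through on_error() so every
# failure is logged with rank, message, type, and stack trace before the
# exception propagates. `coordinator`, `run_train_step`, and `state` are
# hypothetical names used only for this example.
def _example_run(coordinator, run_train_step, state):
    try:
        while not coordinator.stop_training:
            state = run_train_step(state)
    except Exception as e:
        coordinator.on_error(e)
        raise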

@metrics.profile("torchelastic")
def signal_training_done(self):
    # Close the rendezvous to indicate termination of the overall execution.
    # This also propagates the stop signal to other trainers.
    self.rendezvous.set_closed()
    self._destroy_process_group()
    self.stop_training = True
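
# --- Illustrative sketch (not part of the original file) ---
# A minimal shutdown path, assuming the driver decides training is complete
# (e.g. the dataset is exhausted). Closing the rendezvous is what tells the
# remaining trainers to stop instead of waiting for another rendezvous round.
# `coordinator` and `epochs_remaining` are hypothetical names.
def _example_finish_if_done(coordinator, epochs_remaining):
    if epochs_remaining == 0:
        coordinator.signal_training_done()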

@metrics.profile("torchelastic")
def should_save_checkpoint(self):
    """
    Whether the PET training loop needs to save a checkpoint.
    This normally happens when the job is explicitly asked to checkpoint,
    e.g. the executor received a preemption notice from the scheduler.
    """
    return False
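
# --- Illustrative sketch (not part of the original file) ---
# This coordinator never requests a checkpoint. A subclass that reacts to a
# scheduler preemption notice might override the hook as below. The base
# class name CoordinatorP2P and the _preemption_requested flag are
# assumptions made for this example only.
class PreemptionAwareCoordinator(CoordinatorP2P):
    def should_save_checkpoint(self):
        # Ask the training loop to checkpoint only when the executor has
        # been told it is about to be preempted.
        return getattr(self, "_preemption_requested", False)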

@metrics.profile("torchelastic")
def monitor_progress(self, state, worker_stats):
    self.monitor_progress_step += 1
    if (self.monitor_progress_step % self.MONITOR_PROGRESS_FREQ) != 0:
        return

    # In P2P, workers exchange progress rate info, and everyone compares itself
    # to the best worker in the group.
    # Logging only, no enforcement (T42935591)
    # All workers must participate in the collective communication, even if some
    # of them don't have a non-null WorkerStats or progress rate.
    if worker_stats is not None and worker_stats.get_progress_rate() is not None:
        prog_rate = worker_stats.get_progress_rate()
        prog_rate_known = True
    else:
        prog_rate = 0.0
        prog_rate_known = False
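
# --- Illustrative sketch (not part of the original file) ---
# monitor_progress() is shown truncated above. One way the exchange it
# describes can be implemented is an all_gather of (rate, known) pairs over
# the CPU/GLOO group, followed by a comparison against the best known rate.
# This is a sketch under that assumption, not the original implementation;
# the 0.8 lag threshold is arbitrary.
import torch
import torch.distributed as dist

def _example_exchange_progress(prog_rate, prog_rate_known, group, rank, logger):
    world_size = dist.get_world_size(group=group)
    mine = torch.tensor(
        [prog_rate, 1.0 if prog_rate_known else 0.0], dtype=torch.float64
    )
    gathered = [torch.zeros(2, dtype=torch.float64) for _ in range(world_size)]
    dist.all_gather(gathered, mine, group=group)

    known_rates = [t[0].item() for t in gathered if t[1].item() > 0.0]
    if not known_rates or not prog_rate_known:
        return
    best = max(known_rates)
    if best > 0 and prog_rate < 0.8 * best:
        logger.warning(
            "Rank {0} progress rate {1:.2f} lags best worker rate {2:.2f}".format(
                rank, prog_rate, best
            )
        )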

@metrics.profile("torchelastic")
def report_progress(self, state):
    # Intentionally a no-op for this coordinator.
    pass

@metrics.profile("torchelastic")
def should_rendezvous(self, state):
    # The group is already at max capacity; no new trainers can be admitted.
    if dist.get_world_size() == self.max_num_trainers:
        return False
    # Check if there are any new workers waiting at the rendezvous barrier
    num_new_nodes = torch.LongTensor([self.rendezvous.num_nodes_waiting()])
    # Use the GLOO based coordinator_process_group to perform the
    # collective op as we don't want to transfer these back-and-forth
    # between GPU and CPU (when GPUs are available).
    dist.all_reduce(
        num_new_nodes, op=dist.ReduceOp.MAX, group=self.coordinator_process_group
    )
    if num_new_nodes > 0:
        log.info(
            "Rank {0}: detected {1} new worker(s) waiting at the rendezvous "
            "barrier; a re-rendezvous is needed".format(
                self.rank, num_new_nodes.item()
            )
        )
        return True
    return False
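
# --- Illustrative sketch (not part of the original file) ---
# How a driver loop could consume should_rendezvous(): when new workers are
# detected at the barrier, the group re-rendezvouses so they can be admitted
# before training continues. `coordinator`, `train_step`, `state`, and the
# rendezvous_barrier() call are hypothetical names used for this example.
def _example_elastic_loop(coordinator, train_step, state):
    while not coordinator.stop_training:
        if coordinator.should_rendezvous(state):
            # Assumed to re-run rendezvous and rebuild the process group so
            # that the waiting workers can join.
            coordinator.rendezvous_barrier()
        state = train_step(state)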