def _sync_state(self, rank):
    # Sync from the rank that has made the most progress, i.e. the one
    # with the largest data_start_index.
    max_rank, _ = edist.all_gather_return_max_long(self.data_start_index)

    # Serialize this trainer's state and broadcast max_rank's copy: first
    # the size of the serialized buffer, then the buffer itself.
    buffer = io.BytesIO()
    self.save(buffer)
    state_tensor = torch.ByteTensor(list(buffer.getvalue()))
    state_size = torch.LongTensor([state_tensor.numel()])
    dist.broadcast(state_size, src=max_rank)
    if rank != max_rank:
        # Receiving ranks allocate an empty buffer of the broadcast size.
        state_tensor = torch.ByteTensor([0] * int(state_size[0]))
    dist.broadcast(state_tensor, src=max_rank)

    # Every rank now holds max_rank's serialized state; deserialize it.
    buffer = io.BytesIO(state_tensor.numpy().tobytes())
    self.load(buffer)
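
# The methods in this snippet rely on `edist.all_gather_return_max_long`,
# which is a helper from the surrounding project rather than a stock
# torch.distributed call. Judging from how it is used here, it all-gathers
# one long value per trainer and returns (rank_holding_the_max, max_value).
# A minimal sketch of such a helper, assuming `torch` and
# `torch.distributed as dist` are imported as elsewhere in this file and the
# process group is already initialized, could look like this:
def _all_gather_return_max_long_sketch(value: int):
    world_size = dist.get_world_size()
    tensor = torch.LongTensor([value])
    gathered = [torch.LongTensor([0]) for _ in range(world_size)]
    dist.all_gather(gathered, tensor)
    values = [int(t[0]) for t in gathered]
    max_value = max(values)
    max_rank = values.index(max_value)  # lowest rank wins ties
    return max_rank, max_value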

def load_checkpoint(self, state, rank: int):
    """
    Loads the checkpoint on rank 0 if the checkpoint manager has been
    configured and no trainer has loaded a checkpoint yet; the other
    trainers then sync their state from rank 0.
    """
    if not self.checkpoint_manager:
        # Checkpointing is not enabled.
        return state

    # All-gather `checkpoint_loaded` from all trainers; the result is 1 if
    # any trainer has ever loaded a checkpoint.
    _, loaded = edist.all_gather_return_max_long(
        1 if self.checkpoint_loaded else 0
    )
    any_checkpoint_loaded = loaded == 1
    if any_checkpoint_loaded:
        # A checkpoint has already been loaded by one of the existing
        # trainers, so this trainer can simply sync state from them.
        return state

    # Load the checkpoint only when all trainers start from scratch; if a
    # healthy trainer already exists, a new trainer can sync state from it
    # instead. Keep the simple scenario for now: a single trainer (rank 0)
    # loads the checkpoint and the others sync from it.
    if rank == 0:
        state = self._do_load_checkpoint(state)
    return state
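
# A hypothetical illustration of how a (re)joining trainer might use the two
# methods above. `TrainerState` and its constructor are stand-ins for whatever
# class these methods live on; only the call order is the point: rank 0 loads
# the checkpoint when nobody holds state yet, then every rank syncs to the
# most advanced trainer.
#
#   dist.init_process_group(backend="gloo")
#   rank = dist.get_rank()
#   state = TrainerState(checkpoint_manager=checkpoint_manager)
#   state = state.load_checkpoint(state, rank)
#   state._sync_state(rank)
#
# Note: newer PyTorch releases ship `dist.broadcast_object_list`, which
# packages the same serialize / broadcast-size / broadcast-payload pattern
# that `_sync_state` implements by hand.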