# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
elif epoch < 80:
lr = world_size * params.base_learning_rate * (0.1 ** (epoch // 30))
else:
lr = world_size * params.base_learning_rate * (0.1 ** 3)
for param_group in optimizer.param_groups:
lr_old = param_group["lr"]
param_group["lr"] = lr
# Trick: apply momentum correction when lr is updated
if lr > lr_old:
param_group["momentum"] = lr / lr_old * 0.9 # momentum
else:
param_group["momentum"] = 0.9 # default momentum
return
class ImagenetState(torchelastic.State):
    """
    Serializable, client-provided state object.

    Captures everything required to execute one iteration of training:
    the model, hyper-parameters, dataset handle, and progress counters.
    """

    def __init__(self, model, params, dataset, num_epochs, epoch=0):
        # Training artifacts supplied by the caller.
        self.model, self.params, self.dataset = model, params, dataset

        # Progress bookkeeping: total epochs to run, the epoch being
        # resumed from (0 for a fresh run), the iteration counter, and
        # the offset into the dataset for the next batch.
        self.num_epochs = num_epochs
        self.epoch = epoch
        self.iteration = 0
        self.data_start_index = 0

        # Per-device batch size, taken from the hyper-parameter bundle.
        # NOTE(review): despite the name, this mirrors params.batch_per_device
        # rather than a global (world-size-scaled) batch size.
        self.total_batch_size = params.batch_per_device
def train(elastic_coordinator, train_step, state):
"""
This is the main elastic data parallel loop. It starts from an initial 'state'.
Each iteration calls 'train_step' and returns a new state. 'train_step'
has the following interface:
state, worker_stats = train_step(state)
When 'train_step' exhausts all the data, a StopIteration exception should be
thrown.
"""
assert isinstance(state, torchelastic.State)
failure_count = 0
rank = 0
checkpoint_util = CheckpointUtil(elastic_coordinator)
while not elastic_coordinator.should_stop_training():
# See: https://github.com/pytorch/elastic/issues/7
if failure_count >= MAX_FAILURES:
e = RuntimeError(
"Exceeded max number of recoverable failures: {}".format(failure_count)
)
elastic_coordinator.on_error(e)
raise e
start_time = time.time()