Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def train_epoch_end(config, compression_algo, net, epoch, iteration, epoch_size, lr_scheduler, optimizer,
                    test_data_loader):
    """Run the end-of-epoch bookkeeping: scheduler steps, periodic evaluation, checkpointing.

    Args:
        config: Run configuration. This function reads ``test_interval``,
            ``save_freq``, ``device``, ``multiprocessing_distributed``,
            ``intermediate_checkpoints_path`` and ``model`` from it.
        compression_algo: Compression controller; its ``scheduler`` is stepped
            and serialized, and its ``statistics()`` are printed.
        net: The network being trained; switched to eval mode for testing and
            back to train mode afterwards.
        epoch: Index of the epoch that has just finished.
        iteration: Global iteration counter; evaluation is skipped while it is 0.
        epoch_size: Number of iterations per epoch, used to convert
            ``config.test_interval`` (in iterations) into an epoch frequency.
        lr_scheduler: Learning-rate scheduler. A ``ReduceLROnPlateau`` instance
            is stepped with the evaluation metric; any other scheduler is
            stepped with ``epoch``.
        optimizer: Optimizer whose state is stored in the checkpoint.
        test_data_loader: Data loader handed to ``test_net``.

    Returns:
        None.
    """
    # Evaluate at least once per epoch, even if test_interval < epoch_size.
    test_freq_in_epochs = max(config.test_interval // epoch_size, 1)
    plateau_scheduler = isinstance(lr_scheduler, ReduceLROnPlateau)

    compression_algo.scheduler.epoch_step(epoch)
    if not plateau_scheduler:
        lr_scheduler.step(epoch)

    if epoch % test_freq_in_epochs == 0 and iteration != 0:
        if is_on_first_rank(config):
            print_statistics(compression_algo.statistics())
        with torch.no_grad():
            net.eval()
            try:
                mAP = test_net(net, config.device, test_data_loader,
                               distributed=config.multiprocessing_distributed)
                # The plateau scheduler needs a metric, so it can only be
                # stepped on epochs where an evaluation actually ran.
                if plateau_scheduler:
                    lr_scheduler.step(mAP)
            finally:
                # Always restore training mode, even if evaluation raised.
                net.train()

    if epoch > 0 and epoch % config.save_freq == 0 and is_on_first_rank(config):
        print('Saving state, iter:', iteration)
        checkpoint_file_path = osp.join(config.intermediate_checkpoints_path,
                                        "{}_{}.pth".format(config.model, iteration))
        torch.save({
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'iter': iteration,
            'scheduler': compression_algo.scheduler.state_dict()
        }, checkpoint_file_path)
# Fragment of a larger function whose header is not visible here; the
# original indentation was lost, so nesting below is inferred from syntax.
# When an ONNX target is configured, export the model through the
# compression algorithm and return without testing or training.
if config.to_onnx is not None:
compression_algo.export_model(config.to_onnx)
print("Saved to", config.to_onnx)
return
# Dispatch on the execution mode, compared case-insensitively.
if config.mode.lower() == 'test':
print(model)
# Sum the element counts of every parameter that requires gradients.
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print("Trainable argument count:{params}".format(params=params))
model = model.to(config.device)
loaders, w_class = load_dataset(dataset, config)
# Only the second loader returned by load_dataset is used for testing.
_, val_loader = loaders
test(model, val_loader, w_class, color_encoding, config)
print_statistics(compression_algo.statistics())
elif config.mode.lower() == 'train':
loaders, w_class = load_dataset(dataset, config)
train_loader, val_loader = loaders
# NOTE(review): presumably only the initialize() call is guarded by this
# check and train() always runs -- confirm against the indented source.
if not resuming_checkpoint:
compression_algo.initialize(train_loader)
model = \
train(model, model_without_dp, compression_algo, train_loader, val_loader, w_class, color_encoding, config)
else:
# Should never happen...but just in case it does
raise RuntimeError(
"\"{0}\" is not a valid choice for execution mode.".format(
config.mode))
# Fragment of a training routine; `epoch`, `is_best`, `iou` and `best_miou`
# are defined by code that is not visible in this excerpt.
# NOTE(review): the plateau scheduler is stepped with best_miou; whether a
# per-epoch metric was intended cannot be told from here -- confirm.
if isinstance(lr_scheduler, ReduceLROnPlateau):
lr_scheduler.step(best_miou)
# Print per class IoU on last epoch or if best iou
if epoch + 1 == config.epochs or is_best:
for key, class_iou in zip(class_encoding.keys(), iou):
print("{0}: {1:.4f}".format(key, class_iou))
# Save a checkpoint from the main process; is_best is passed on to
# make_additional_checkpoints rather than gating the save itself.
if is_main_process():
checkpoint_path = save_checkpoint(model,
optimizer, epoch + 1, best_miou,
compression_algo.scheduler, config)
make_additional_checkpoints(checkpoint_path, is_best, epoch + 1, config)
print_statistics(compression_algo.statistics())
return model
# Fragment of a training routine; the enclosing function and loop are not
# visible in this excerpt.
compression_algo.scheduler.epoch_step()
# compute compression algo statistics
stats = compression_algo.statistics()
# Default to the best accuracy so far; overwritten below when validation runs.
acc1 = best_acc1
if epoch % config.test_every_n_epochs == 0:
# evaluate on validation set
acc1, _ = validate(val_loader, model, criterion, config)
# remember best acc@1 and save checkpoint
is_best = acc1 > best_acc1
best_acc1 = max(acc1, best_acc1)
if is_main_process():
print_statistics(stats)
# The checkpoint goes to "<name>_last.pth" under config.checkpoint_save_dir.
checkpoint_path = osp.join(config.checkpoint_save_dir, get_name(config) + '_last.pth')
checkpoint = {
'epoch': epoch + 1,
'arch': model_name,
'state_dict': model.state_dict(),
'best_acc1': best_acc1,
'optimizer': optimizer.state_dict(),
'scheduler': compression_algo.scheduler.state_dict()
}
torch.save(checkpoint, checkpoint_path)
make_additional_checkpoints(checkpoint_path, is_best, epoch + 1, config)
# The excerpt is cut off inside this loop: the body of the isinstance check
# on numeric statistics is not visible.
for key, value in stats.items():
if isinstance(value, (int, float)):
def load_torch_model(config, cuda=False):
    """Build the model named in ``config``, wrap it for compression and optionally load weights.

    Args:
        config: Configuration object supporting both ``config.get(key, default)``
            and attribute access (``config.model``). The keys read here are
            ``weights``, ``pretrained``, ``num_classes`` and ``model_params``.
        cuda: When true, move the model to CUDA and wrap it in
            ``torch.nn.DataParallel``.

    Returns:
        The compressed model, wrapped in ``DataParallel`` if ``cuda`` is set.
    """
    weights = config.get('weights')
    # Pretrained weights are requested only when no checkpoint will be loaded.
    # This uses the same truthiness test as the load below, so an empty
    # ``weights`` value cannot disable both sources of weights at once.
    use_pretrained = False if weights else config.get('pretrained', True)
    model = load_model(config.model,
                       pretrained=use_pretrained,
                       num_classes=config.get('num_classes', 1000),
                       model_params=config.get('model_params', {}))
    compression_algo, model = create_compressed_model(model, config)
    if weights:
        # NOTE(security): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be passed in ``weights``.
        sd = torch.load(weights, map_location='cpu')
        load_state(model, sd)
    if cuda:
        model = torch.nn.DataParallel(model.cuda())
    print_statistics(compression_algo.statistics())
    return model
# Fragment of a larger entry-point function whose header is not visible; any
# condition guarding the resume call lies outside this excerpt.
resume_from_checkpoint(resuming_checkpoint, model,
config, optimizer, compression_algo)
# When an ONNX target is configured, export the model and return before any
# data loading happens.
if config.to_onnx is not None:
compression_algo.export_model(config.to_onnx)
print("Saved to", config.to_onnx)
return
# Set cudnn.benchmark unless the execution mode is CPU-only.
if config.execution_mode != ExecutionMode.CPU_ONLY:
cudnn.benchmark = True
# Data loading code
train_loader, train_sampler, val_loader = create_dataloaders(config)
if config.mode.lower() == 'test':
print_statistics(compression_algo.statistics())
validate(val_loader, model, criterion, config)
# NOTE(review): 'test' and 'train' are checked by two separate ifs; within
# this excerpt no branch handles any other mode value.
if config.mode.lower() == 'train':
# The compression algorithm is initialized from the training loader only
# when no checkpoint is being resumed.
if not resuming_checkpoint:
compression_algo.initialize(train_loader)
train(config, compression_algo, model, criterion, is_inception, lr_scheduler, model_name, optimizer,
train_loader, train_sampler, val_loader, best_acc1)