# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# --- Classification training setup (interior of a larger entry point; names
# bound here, e.g. `model`, `optimizer`, `criterion`, are used by code that
# follows this chunk, so they must keep their names). ---
configure_logging(config)
print_args(config)
if config.seed is not None:
    # Fixed seed plus deterministic cuDNN gives reproducible runs
    # (at the cost of disabling the cuDNN autotuner).
    manual_seed(config.seed)
    cudnn.deterministic = True
    cudnn.benchmark = False
# create model
model_name = config['model']
weights = config.get('weights')
# When an explicit weights file is given, skip the pretrained download.
model = load_model(model_name,
                   pretrained=config.get('pretrained', True) if weights is None else False,
                   num_classes=config.get('num_classes', 1000),
                   model_params=config.get('model_params'))
# Wrap the model with its compression controller before loading weights,
# so the checkpoint layout matches the compressed model.
compression_algo, model = create_compressed_model(model, config)
if weights:
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoint files from trusted sources.
    load_state(model, torch.load(weights, map_location='cpu'))
model, _ = prepare_model_for_execution(model, config)
if config.distributed:
    compression_algo.distributed()
# Presumably consumed later to special-case Inception aux logits — verify downstream.
is_inception = 'inception' in model_name
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(config.device)
params_to_optimize = get_parameter_groups(model, config)
optimizer, lr_scheduler = make_optimizer(params_to_optimize, config)
resuming_checkpoint = config.resuming_checkpoint
def load_torch_model(config, cuda=False):
    """Build the configured model, wrap it with compression, and report statistics.

    Creates the model named by ``config.model``, wraps it via
    ``create_compressed_model``, optionally restores weights from
    ``config['weights']``, and — when *cuda* is true — moves it to GPU
    under ``torch.nn.DataParallel``. Compression statistics are printed
    before the model is returned.

    :param config: project config object (attribute and dict-style access).
    :param cuda: move the model to GPU and wrap it in DataParallel.
    :return: the (possibly GPU-resident, DataParallel-wrapped) model.
    """
    weights_path = config.get('weights')
    # An explicit weights file overrides any pretrained-download request.
    use_pretrained = weights_path is None and config.get('pretrained', True)
    model = load_model(config.model,
                       pretrained=use_pretrained,
                       num_classes=config.get('num_classes', 1000),
                       model_params=config.get('model_params', {}))
    # Compression wrapping happens before weight loading so the state dict
    # matches the compressed model's layout.
    compression_algo, model = create_compressed_model(model, config)
    if weights_path:
        # NOTE(review): torch.load unpickles arbitrary objects — the
        # checkpoint file must come from a trusted source.
        load_state(model, torch.load(weights_path, map_location='cpu'))
    if cuda:
        model = model.cuda()
        model = torch.nn.DataParallel(model)
    print_statistics(compression_algo.statistics())
    return model
# --- Segmentation training setup (interior of a larger entry point; names
# bound here are used by code after this chunk). ---
configure_logging(config)
print_args(config)
print(config)
config.device = get_device(config)
# The number of output classes is derived from the dataset's color encoding.
dataset = get_dataset(config.dataset)
color_encoding = dataset.color_encoding
num_classes = len(color_encoding)
weights = config.get('weights')
# When an explicit weights file is given, skip the pretrained download.
model = load_model(config.model,
                   pretrained=config.get('pretrained', True) if weights is None else False,
                   num_classes=num_classes,
                   model_params=config.get('model_params', {}))
# Wrap with the compression controller before restoring weights so the
# checkpoint layout matches the compressed model.
compression_algo, model = create_compressed_model(model, config)
if weights:
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoint files from trusted sources.
    sd = torch.load(weights, map_location='cpu')
    load_state(model, sd)
model, model_without_dp = prepare_model_for_execution(model, config)
if config.distributed:
    compression_algo.distributed()
resuming_checkpoint = config.resuming_checkpoint
if resuming_checkpoint is not None:
if not config.pretrained:
# Load the previously saved model state
model, _, _, _, _ = \
load_checkpoint(model, resuming_checkpoint, config.device,