Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Prepare the training state for each ensemble member in turn: its own output
# directory and TensorBoard writer, the model (restored from a checkpoint or
# built from args, optionally moved to CUDA), an initial checkpoint on disk,
# an optimizer and a learning-rate scheduler.
for model_idx in range(args.ensemble_size):
    # Each member writes into <args.save_dir>/model_<idx>; the same directory
    # is handed to the TensorBoard SummaryWriter.
    save_dir = os.path.join(args.save_dir, f'model_{model_idx}')
    os.makedirs(save_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=save_dir)

    # Restore from the member's checkpoint path when checkpoint paths were
    # supplied; otherwise build a new model from args.
    restore_from_checkpoint = args.checkpoint_paths is not None
    if not restore_from_checkpoint:
        debug(f'Building model {model_idx}')
        model = build_model(args)
    else:
        member_ckpt_path = args.checkpoint_paths[model_idx]
        debug(f'Loading model {model_idx} from {member_ckpt_path}')
        model = load_checkpoint(member_ckpt_path, current_args=args, logger=logger)

    debug(model)
    debug(f'Number of parameters = {param_count(model):,}')

    if args.cuda:
        debug('Moving model to cuda')
        model = model.cuda()

    # Write a checkpoint right away so evaluation finds one in the expected
    # location even when training runs for 0 epochs.
    initial_ckpt_path = os.path.join(save_dir, 'model.pt')
    save_checkpoint(initial_ckpt_path, model, scaler, features_scaler, args)

    # When requested, store compute_pnorm(model) (presumably the parameter
    # norm — confirm) as the target used by weight-decay adjustment.
    if args.adjust_weight_decay:
        args.pnorm_target = compute_pnorm(model)

    # Optimizer first, then the learning-rate scheduler that wraps it.
    optimizer = build_optimizer(model, args)
    scheduler = build_lr_scheduler(optimizer, args)
# Evaluate one hyperparameter configuration: copy the base args, override
# them with `hyperparams`, cross-validate, then log and record the outcome.
#
# Reads from the enclosing scope: args, hyperparams (mapping name -> value),
# logger, train_logger, and results (a list that receives one summary dict
# per evaluated configuration).
hyper_args = deepcopy(args)
if args.save_dir is not None:
    # One output folder per configuration, named by joining "<key>_<value>"
    # pairs with "_", e.g. {'a': 1, 'b': 0.5} -> 'a_1_b_0.5'.
    # Every value is formatted identically, so no per-key special-casing is
    # needed here.
    # NOTE(review): if integer-valued hyperparameters can arrive as floats
    # (e.g. 3.0 instead of 3), confirm they are converted to int before this
    # point — both the folder name and the setattr below use them verbatim.
    folder_name = '_'.join(f'{key}_{value}' for key, value in hyperparams.items())
    hyper_args.save_dir = os.path.join(hyper_args.save_dir, folder_name)
for key, value in hyperparams.items():
    setattr(hyper_args, key, value)

# Record hyperparameters
logger.info(hyperparams)

# Cross validate
mean_score, std_score = cross_validate(hyper_args, train_logger)

# Record results. A model is built only to report its parameter count.
temp_model = build_model(hyper_args)
num_params = param_count(temp_model)
logger.info(f'num params: {num_params:,}')
logger.info(f'{mean_score} +/- {std_score} {hyper_args.metric}')
# The raw mean_score (possibly NaN) is what gets stored in results; the NaN
# substitution below only affects the local mean_score afterwards.
results.append({
    'mean_score': mean_score,
    'std_score': std_score,
    'hyperparams': hyperparams,
    'num_params': num_params,
})

# Deal with nan: classification scores fall back to 0, any other dataset
# type is treated as an error.
if np.isnan(mean_score):
    if hyper_args.dataset_type == 'classification':
        mean_score = 0
    else:
        raise ValueError('Can\'t handle nan score for non-classification dataset.')
# Finalize the training args, mirror TRAIN_LOGGER output into a per-run log
# file, cross-validate, and log the score and model size.
modify_train_args(args)

# Set up logging for training: while attached, TRAIN_LOGGER records are also
# written to <args.save_dir>/<args.log_name> at DEBUG level and above.
os.makedirs(args.save_dir, exist_ok=True)
fh = logging.FileHandler(os.path.join(args.save_dir, args.log_name))
fh.setLevel(logging.DEBUG)

# Cross validate. The handler is detached and closed in `finally` so the log
# file is released, and TRAIN_LOGGER is left clean, even if cross-validation
# raises. removeHandler() alone does not close the underlying file.
TRAIN_LOGGER.addHandler(fh)
try:
    mean_score, std_score = cross_validate(args, TRAIN_LOGGER)
finally:
    TRAIN_LOGGER.removeHandler(fh)
    fh.close()

# Record results. A model is built only to report its parameter count.
# NOTE(review): `metric` is not bound in the visible code; confirm it is
# defined in the enclosing scope (a similar log line elsewhere in this file
# reads the metric from the args object instead).
logger.info(f'{mean_score} +/- {std_score} {metric}')
temp_model = build_model(args)
logger.info(f'num params: {param_count(temp_model):,}')