"""
debug = logger.debug if logger is not None else print
# Load model and args
state = torch.load(path, map_location=lambda storage, loc: storage)
args, loaded_state_dict = state['args'], state['state_dict']
if current_args is not None:
    args = current_args
args.cuda = cuda
load_encoder_only = current_args.load_encoder_only if current_args is not None else False
# Build model
model = build_model(args)
model_state_dict = model.state_dict()
# Skip missing parameters and parameters of mismatched size
pretrained_state_dict = {}
for param_name in loaded_state_dict.keys():
    if load_encoder_only and 'encoder' not in param_name:
        continue

    if param_name not in model_state_dict:
        debug(f'Pretrained parameter "{param_name}" cannot be found in model parameters.')
    elif model_state_dict[param_name].shape != loaded_state_dict[param_name].shape:
        debug(f'Pretrained parameter "{param_name}" '
              f'of shape {loaded_state_dict[param_name].shape} does not match corresponding '
              f'model parameter of shape {model_state_dict[param_name].shape}.')
    else:
        debug(f'Loading pretrained parameter "{param_name}".')
        pretrained_state_dict[param_name] = loaded_state_dict[param_name]
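# A minimal sketch of what typically follows this loop: folding the filtered
# pretrained parameters back into the freshly built model (an assumption based on
# the pretrained_state_dict collected above; not shown in this snippet).
model_state_dict.update(pretrained_state_dict)
model.load_state_dict(model_state_dict)
return model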
# Only evaluate vocab targets at masked-out positions (mask == 0); all others become None
test_targets['vocab'] = [target if mask == 0 else None for target, mask in zip(test_targets['vocab'], test_data.mask())]
# Train ensemble of models
for model_idx in range(args.ensemble_size):
    # Tensorboard writer
    save_dir = os.path.join(args.save_dir, f'model_{model_idx}')
    os.makedirs(save_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=save_dir)

    # Load/build model
    if args.checkpoint_paths is not None:
        debug(f'Loading model {model_idx} from {args.checkpoint_paths[model_idx]}')
        model = load_checkpoint(args.checkpoint_paths[model_idx], current_args=args, logger=logger)
    else:
        debug(f'Building model {model_idx}')
        model = build_model(args)

    debug(model)
    debug(f'Number of parameters = {param_count(model):,}')
    if args.cuda:
        debug('Moving model to cuda')
        model = model.cuda()

    # Ensure that model is saved in correct location for evaluation if 0 epochs
    save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler, features_scaler, args)

    if args.adjust_weight_decay:
        args.pnorm_target = compute_pnorm(model)

    # Optimizers
    optimizer = build_optimizer(model, args)
# Featurize the batch; in the BERT-pretraining setting the atom mask is also applied
if args.dataset_type == 'bert_pretraining':
    batch = mol2graph(smiles_batch, args)
    batch.bert_mask(mol_batch.mask())
else:
    batch = smiles_batch
if args.maml:  # TODO refactor with train loop
    # Inner-loop adaptation: one gradient step on this batch gives theta' = theta - maml_lr * grad
    model.zero_grad()
    intermediate_preds = model(batch, features_batch)
    loss = get_loss_func(args)(intermediate_preds, targets)
    loss = loss.sum() / len(batch)
    grad = torch.autograd.grad(loss, [p for p in model.parameters() if p.requires_grad])
    theta = [p for p in model.named_parameters() if p[1].requires_grad]  # comes in same order as grad
    theta_prime = {p[0]: p[1] - args.maml_lr * grad[i] for i, p in enumerate(theta)}
    for name, nongrad_param in [p for p in model.named_parameters() if not p[1].requires_grad]:
        theta_prime[name] = nongrad_param + torch.zeros(nongrad_param.size()).to(nongrad_param)
    # Evaluate the adapted parameters on the task's test data
    model_prime = build_model(args=args, params=theta_prime)

    smiles_batch, features_batch, targets_batch = task_test_data.smiles(), task_test_data.features(), task_test_data.targets(task_idx)
    # no mask since we only picked data points that have the desired target
    with torch.no_grad():
        batch_preds = model_prime(smiles_batch, features_batch)
        full_targets.extend([[t] for t in targets_batch])
else:
    with torch.no_grad():
        if args.parallel_featurization:
            previous_graph_input_mode = model.encoder.graph_input
            model.encoder.graph_input = True  # force model to accept already processed input
            batch_preds = model(featurized_mol_batch, features_batch)
            model.encoder.graph_input = previous_graph_input_mode
        else:
            batch_preds = model(batch, features_batch)
if args.dataset_type == 'bert_pretraining':
    ...  # presumably unpacks the BERT-pretraining prediction structure; the branch body is truncated in this snippet
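# Sketch of the collection step that would typically follow in a predict loop like
# this (an assumption: `preds` is the running list of predictions across batches).
batch_preds = batch_preds.data.cpu().numpy().tolist()
preds.extend(batch_preds)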
# Learning rate scheduler
scheduler = build_lr_scheduler(optimizer, args)

# Run training
best_score = float('inf') if args.minimize_score else -float('inf')
best_epoch, n_iter = 0, 0
for epoch in trange(args.epochs):
    debug(f'Epoch {epoch}')

    if args.prespecified_chunk_dir is not None:
        # load some different random chunks each epoch
        train_data, val_data = load_prespecified_chunks(args, logger)
        debug('Loaded prespecified chunks for epoch')

    if args.dataset_type == 'unsupervised':  # won't work with moe
        full_data = MoleculeDataset(train_data.data + val_data.data)
        generate_unsupervised_cluster_labels(build_model(args), full_data, args)  # cluster with a new random init
        model.create_ffn(args)  # reset the ffn since we're changing targets -- we're just pretraining the encoder
        optimizer.param_groups.pop()  # remove ffn parameters
        optimizer.add_param_group({'params': model.ffn.parameters(), 'lr': args.init_lr[1], 'weight_decay': args.weight_decay[1]})
        if args.cuda:
            model.ffn.cuda()

    if args.gradual_unfreezing:
        if epoch % args.epochs_per_unfreeze == 0:
            unfroze_layer = model.unfreeze_next()  # consider just stopping early after we have nothing left to unfreeze?
            if unfroze_layer:
                debug('Unfroze last frozen layer')

    n_iter = train(
        model=model,
        data=train_data,
        loss_func=loss_func,
# Per-element loss, weighted by class and masked to positions that have targets
loss = loss_func(preds, targets) * class_weights * mask
if args.predict_features_and_task:
    # Up-weight the task targets relative to the predicted features by task_weight
    loss = (loss.sum() + loss[:, :-args.features_size].sum() * (args.task_weight - 1)) \
        / (mask.sum() + mask[:, :-args.features_size].sum() * (args.task_weight - 1))
else:
    loss = loss.sum() / mask.sum()

if args.dataset_type == 'bert_pretraining' and features_targets is not None:
    loss += features_loss(features_preds, features_targets)

loss_sum += loss.item()
iter_count += len(mol_batch)
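# Sketch of the optimization step that would normally follow this loss in the
# training loop (an assumption; whether scheduler.step() runs per batch depends on
# the scheduler type returned by build_lr_scheduler).
loss.backward()
optimizer.step()
scheduler.step()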
if args.maml:
    # Meta-objective: evaluate the adapted model (theta') on the task's held-out data
    model_prime = build_model(args=args, params=theta_prime)
    smiles_batch, features_batch, target_batch = task_test_data.smiles(), task_test_data.features(), [t[task_idx] for t in task_test_data.targets()]
    # no mask since we only picked data points that have the desired target
    targets = torch.Tensor([[t] for t in target_batch])
    if next(model_prime.parameters()).is_cuda:
        targets = targets.cuda()
    model_prime.zero_grad()
    preds = model_prime(smiles_batch, features_batch)
    loss = loss_func(preds, targets)
    loss = loss.sum() / len(smiles_batch)
    loss_sum += loss.item()
    iter_count += len(smiles_batch)  # TODO check that this makes sense, but it's just for display
    maml_sum_loss += loss
    if i % args.maml_batch_size == args.maml_batch_size - 1:
        # Meta-update: step the original parameters once per maml_batch_size tasks
        maml_sum_loss.backward()
        optimizer.step()
        model.zero_grad()
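        # The accumulated meta-loss would typically be reset here before the next
        # meta-batch (a sketch/assumption; not shown in this snippet).
        maml_sum_loss = 0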
args.features_path = [os.path.join(features_dir, dataset_name + '.pckl')]
modify_train_args(args)
# Set up logging for training
os.makedirs(args.save_dir, exist_ok=True)
fh = logging.FileHandler(os.path.join(args.save_dir, args.log_name))
fh.setLevel(logging.DEBUG)
# Cross validate
TRAIN_LOGGER.addHandler(fh)
mean_score, std_score = cross_validate(args, TRAIN_LOGGER)
TRAIN_LOGGER.removeHandler(fh)
# Record results
logger.info(f'{mean_score} +/- {std_score} {metric}')
temp_model = build_model(args)
logger.info(f'num params: {param_count(temp_model):,}')
# Update args with hyperparams
hyper_args = deepcopy(args)
if args.save_dir is not None:
    folder_name = '_'.join([f'{key}_{int(value)}' if key in INT_KEYS else f'{key}_{value}' for key, value in hyperparams.items()])
    hyper_args.save_dir = os.path.join(hyper_args.save_dir, folder_name)
for key, value in hyperparams.items():
    setattr(hyper_args, key, value)
# Record hyperparameters
logger.info(hyperparams)
# Cross validate
mean_score, std_score = cross_validate(hyper_args, train_logger)
# Record results
temp_model = build_model(hyper_args)
num_params = param_count(temp_model)
logger.info(f'num params: {num_params:,}')
logger.info(f'{mean_score} +/- {std_score} {hyper_args.metric}')
results.append({
    'mean_score': mean_score,
    'std_score': std_score,
    'hyperparams': hyperparams,
    'num_params': num_params
})
# Deal with nan
if np.isnan(mean_score):
    if hyper_args.dataset_type == 'classification':
        mean_score = 0
    else: