Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
for name, nongrad_param in [p for p in model.named_parameters() if not p[1].requires_grad]:
theta_prime[name] = nongrad_param + torch.zeros(nongrad_param.size()).to(nongrad_param)
else:
# Prepare batch
if args.parallel_featurization:
if len(currently_loaded_batches) == 0:
currently_loaded_batches = batch_queue.get()
mol_batch, featurized_mol_batch = currently_loaded_batches.pop()
else:
if not args.last_batch and i + args.batch_size > len(data):
break
mol_batch = MoleculeDataset(data[i:i + args.batch_size])
smiles_batch, features_batch, target_batch = mol_batch.smiles(), mol_batch.features(), mol_batch.targets()
if args.dataset_type == 'bert_pretraining':
batch = mol2graph(smiles_batch, args)
mask = mol_batch.mask()
batch.bert_mask(mask)
mask = 1 - torch.FloatTensor(mask) # num_atoms
features_targets = torch.FloatTensor(target_batch['features']) if target_batch['features'] is not None else None # num_molecules x features_size
targets = torch.FloatTensor(target_batch['vocab']) # num_atoms
if args.bert_vocab_func == 'feature_vector':
mask = mask.reshape(-1, 1)
else:
targets = targets.long()
else:
batch = smiles_batch
mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])
if next(model.parameters()).is_cuda:
mask, targets = mask.cuda(), targets.cuda()
targets = torch.Tensor(targets_batch).unsqueeze(1)
if args.cuda:
targets = targets.cuda()
else:
# Prepare batch
if args.parallel_featurization:
if len(currently_loaded_batches) == 0:
currently_loaded_batches = batch_queue.get()
mol_batch, featurized_mol_batch = currently_loaded_batches.pop(0)
else:
mol_batch = MoleculeDataset(data[i:i + args.batch_size])
smiles_batch, features_batch = mol_batch.smiles(), mol_batch.features()
# Run model
if args.dataset_type == 'bert_pretraining':
batch = mol2graph(smiles_batch, args)
batch.bert_mask(mol_batch.mask())
else:
batch = smiles_batch
if args.maml: # TODO refactor with train loop
model.zero_grad()
intermediate_preds = model(batch, features_batch)
loss = get_loss_func(args)(intermediate_preds, targets)
loss = loss.sum() / len(batch)
grad = torch.autograd.grad(loss, [p for p in model.parameters() if p.requires_grad])
theta = [p for p in model.named_parameters() if p[1].requires_grad] # comes in same order as grad
theta_prime = {p[0]: p[1] - args.maml_lr * grad[i] for i, p in enumerate(theta)}
for name, nongrad_param in [p for p in model.named_parameters() if not p[1].requires_grad]:
theta_prime[name] = nongrad_param + torch.zeros(nongrad_param.size()).to(nongrad_param)
model_prime = build_model(args=args, params=theta_prime)
smiles_batch, features_batch, targets_batch = task_test_data.smiles(), task_test_data.features(), task_test_data.targets(task_idx)
def forward(self,
            batch: Union[List[str], BatchMolGraph],
            features_batch: List[np.ndarray] = None) -> torch.Tensor:
    """
    Encodes a batch of molecular SMILES strings.

    :param batch: A list of SMILES strings or a BatchMolGraph (if self.graph_input).
    :param features_batch: A list of ndarrays containing additional features (may be None).
    :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule.
    """
    # Raw SMILES must be turned into a graph first, unless the caller already
    # passed a graph (self.graph_input) or the model runs in features-only
    # mode, in which case the batch won't even be used.
    needs_featurization = not self.graph_input and not self.args.features_only
    if needs_featurization:
        batch = mol2graph(batch, self.args)

    # Invoke the encoder through its call operator rather than `.forward(...)`:
    # for a torch.nn.Module, a direct `.forward` call silently skips any
    # registered forward hooks.
    output = self.encoder(batch, features_batch)

    if self.args.adversarial:
        # Keep the encoding on the instance; presumably read later by the
        # adversarial training code -- confirm against callers.
        self.saved_encoder_output = output

    return output
def viz_attention(self,
                  viz_dir: str,
                  batch: Union[List[str], BatchMolGraph],
                  features_batch: List[np.ndarray] = None) -> None:
    """
    Visualizes attention weights for a batch of molecular SMILES strings

    The encoder is run once with ``viz_dir`` forwarded to it; its output is
    discarded and nothing is returned.

    :param viz_dir: Directory in which to save visualized attention weights.
    :param batch: A list of SMILES strings or a BatchMolGraph (if self.graph_input).
    :param features_batch: A list of ndarrays containing additional features (may be None).
    """
    # Convert raw SMILES into a graph unless the caller already supplied one.
    if not self.graph_input:
        batch = mol2graph(batch, self.args)

    # Invoke the encoder through its call operator rather than `.forward(...)`:
    # for a torch.nn.Module, a direct `.forward` call silently skips any
    # registered forward hooks. Keyword arguments are passed through unchanged.
    self.encoder(batch, features_batch, viz_dir=viz_dir)