def forward(self,
            fnode: torch.Tensor,
            fmess: torch.Tensor,
            node_graph: torch.Tensor,
            mess_graph: torch.Tensor,
            scope: List[Tuple[int, int]]) -> torch.Tensor:
    # One message vector per edge in the batched junction trees
    messages = torch.zeros(mess_graph.size(0), self.hidden_size)
    if next(self.parameters()).is_cuda:
        fnode, fmess, node_graph, mess_graph, messages = fnode.cuda(), fmess.cuda(), node_graph.cuda(), mess_graph.cuda(), messages.cuda()

    # Embed nodes, gather the source-node embedding for each message, then run GRU message passing
    fnode = self.embedding(fnode)
    fmess = index_select_ND(fnode, fmess)
    messages = self.GRU(messages, fmess, mess_graph)

    # Aggregate incoming messages per node and compute node representations
    mess_nei = index_select_ND(messages, node_graph)
    fnode = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
    fnode = self.outputNN(fnode)

    # Mean-pool node vectors per tree; scope holds a (start, length) slice per molecule
    tree_vec = []
    for st, le in scope:
        tree_vec.append(fnode.narrow(0, st, le).mean(dim=0))
    return torch.stack(tree_vec, dim=0)
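index_select_ND is used throughout these snippets but is not defined in them. Judging from how it is called (a 2-D source tensor and a 2-D index tensor producing a 3-D gather), a minimal sketch of a compatible implementation would be:

def index_select_ND(source: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Gather rows of `source` (num_items x feature_dim) for every entry of `index`
    # (num_targets x max_neighbors), giving num_targets x max_neighbors x feature_dim.
    index_size = index.size()
    suffix_dim = source.size()[1:]
    final_size = index_size + suffix_dim
    target = source.index_select(dim=0, index=index.view(-1))
    return target.view(final_size)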
    attention_scores = [attention_scores[i] * message_attention_mask + (1 - message_attention_mask) * (-1e+20)
                        for i in range(self.num_heads)]  # num_bonds x maxnb
    attention_weights = [F.softmax(attention_scores[i], dim=1)
                         for i in range(self.num_heads)]  # num_bonds x maxnb
    message_components = [nei_message * attention_weights[i].unsqueeze(2).repeat((1, 1, self.hidden_size))
                          for i in range(self.num_heads)]  # num_bonds x maxnb x hidden
    message_components = [component.sum(dim=1) for component in message_components]  # num_bonds x hidden
    message = torch.cat(message_components, dim=1)  # num_bonds x num_heads * hidden
elif self.atom_messages:
    nei_a_message = index_select_ND(message, a2a)  # num_atoms x max_num_bonds x hidden
    nei_f_bonds = index_select_ND(f_bonds, a2b)  # num_atoms x max_num_bonds x bond_fdim
    nei_message = torch.cat((nei_a_message, nei_f_bonds), dim=2)  # num_atoms x max_num_bonds x (hidden + bond_fdim)
    message = nei_message.sum(dim=1)  # num_atoms x (hidden + bond_fdim)
else:
    # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)
    # i.e. message = a_message - rev_message, where a_message = sum(nei_a_message)
    nei_a_message = index_select_ND(message, a2b)  # num_atoms x max_num_bonds x hidden
    a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
    rev_message = message[b2revb]  # num_bonds x hidden
    message = a_message[b2a] - rev_message  # num_bonds x hidden

for lpm in range(self.layers_per_message - 1):
    message = self.W_h[lpm][depth](message)  # num_bonds x hidden
    message = self.act_func(message)
message = self.W_h[self.layers_per_message - 1][depth](message)

if self.normalize_messages:
    message = message / message.norm(dim=1, keepdim=True)

if self.master_node:
    # master_state = self.W_master_in(self.act_func(nei_message.sum(dim=0)))  # try something like this to preserve invariance for master node
    # master_state = self.GRU_master(nei_message.unsqueeze(1))
    # master_state = master_state[-1].squeeze(0)  # this actually doesn't preserve order invariance anymore
if self.message_attention:
    # TODO: Parallelize attention heads
    nei_message = index_select_ND(message, b2b)
    message = message.unsqueeze(1).repeat((1, nei_message.size(1), 1))  # num_bonds x maxnb x hidden
    attention_scores = [(self.W_ma[i](nei_message) * message).sum(dim=2)
                        for i in range(self.num_heads)]  # num_bonds x maxnb
    attention_scores = [attention_scores[i] * message_attention_mask + (1 - message_attention_mask) * (-1e+20)
                        for i in range(self.num_heads)]  # num_bonds x maxnb
    attention_weights = [F.softmax(attention_scores[i], dim=1)
                         for i in range(self.num_heads)]  # num_bonds x maxnb
    message_components = [nei_message * attention_weights[i].unsqueeze(2).repeat((1, 1, self.hidden_size))
                          for i in range(self.num_heads)]  # num_bonds x maxnb x hidden
    message_components = [component.sum(dim=1) for component in message_components]  # num_bonds x hidden
    message = torch.cat(message_components, dim=1)  # num_bonds x num_heads * hidden
elif self.atom_messages:
    nei_a_message = index_select_ND(message, a2a)  # num_atoms x max_num_bonds x hidden
    nei_f_bonds = index_select_ND(f_bonds, a2b)  # num_atoms x max_num_bonds x bond_fdim
    nei_message = torch.cat((nei_a_message, nei_f_bonds), dim=2)  # num_atoms x max_num_bonds x (hidden + bond_fdim)
    message = nei_message.sum(dim=1)  # num_atoms x (hidden + bond_fdim)
else:
    # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)
    # i.e. message = a_message - rev_message, where a_message = sum(nei_a_message)
    nei_a_message = index_select_ND(message, a2b)  # num_atoms x max_num_bonds x hidden
    a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
    rev_message = message[b2revb]  # num_bonds x hidden
    message = a_message[b2a] - rev_message  # num_bonds x hidden

for lpm in range(self.layers_per_message - 1):
    message = self.W_h[lpm][depth](message)  # num_bonds x hidden
    message = self.act_func(message)
message = self.W_h[self.layers_per_message - 1][depth](message)

if self.normalize_messages:
    message = message / message.norm(dim=1, keepdim=True)
        global_attention_mask[i, start:start + length] = 1

if next(self.parameters()).is_cuda:
    global_attention_mask = global_attention_mask.cuda()

# Message passing
for depth in range(self.depth - 1):
    if self.undirected:
        message = (message + message[b2revb]) / 2

    if self.learn_virtual_edges:
        message = message * straight_through_mask

    if self.message_attention:
        # TODO: Parallelize attention heads
        nei_message = index_select_ND(message, b2b)
        message = message.unsqueeze(1).repeat((1, nei_message.size(1), 1))  # num_bonds x maxnb x hidden
        attention_scores = [(self.W_ma[i](nei_message) * message).sum(dim=2)
                            for i in range(self.num_heads)]  # num_bonds x maxnb
        attention_scores = [attention_scores[i] * message_attention_mask + (1 - message_attention_mask) * (-1e+20)
                            for i in range(self.num_heads)]  # num_bonds x maxnb
        attention_weights = [F.softmax(attention_scores[i], dim=1)
                             for i in range(self.num_heads)]  # num_bonds x maxnb
        message_components = [nei_message * attention_weights[i].unsqueeze(2).repeat((1, 1, self.hidden_size))
                              for i in range(self.num_heads)]  # num_bonds x maxnb x hidden
        message_components = [component.sum(dim=1) for component in message_components]  # num_bonds x hidden
        message = torch.cat(message_components, dim=1)  # num_bonds x num_heads * hidden
    elif self.atom_messages:
        nei_a_message = index_select_ND(message, a2a)  # num_atoms x max_num_bonds x hidden
        nei_f_bonds = index_select_ND(f_bonds, a2b)  # num_atoms x max_num_bonds x bond_fdim
        nei_message = torch.cat((nei_a_message, nei_f_bonds), dim=2)  # num_atoms x max_num_bonds x (hidden + bond_fdim)
        message = nei_message.sum(dim=1)  # num_atoms x (hidden + bond_fdim)
def forward(self, smiles_batch: List[str]):
    # Get MolTrees with memoization
    mol_batch = [SMILES_TO_MOLTREE[smiles]
                 if smiles in SMILES_TO_MOLTREE else SMILES_TO_MOLTREE.setdefault(smiles, MolTree(smiles))
                 for smiles in smiles_batch]
    fnode, fmess, node_graph, mess_graph, scope = self.tensorize(mol_batch)
    if next(self.parameters()).is_cuda:
        fnode, fmess, node_graph, mess_graph = fnode.cuda(), fmess.cuda(), node_graph.cuda(), mess_graph.cuda()

    fnode = self.embedding(fnode)
    fmess = index_select_ND(fnode, fmess)
    tree_vec = self.jtnn((fnode, fmess, node_graph, mess_graph, scope, []))
    mol_vec = self.mpn(smiles_batch)
    return torch.cat([tree_vec, mol_vec], dim=-1)
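A minimal usage sketch of this forward pass, assuming the enclosing module is named JTNN and that the constructor arguments shown below exist (they are hypothetical; the real constructor is not part of these fragments):

# Hypothetical construction; only the call pattern on a list of SMILES is taken from the code above.
encoder = JTNN(hidden_size=300, depth=3)
encoder.eval()

smiles_batch = ['CCO', 'c1ccccc1O']
with torch.no_grad():
    vecs = encoder(smiles_batch)  # batch_size x (tree_vec_dim + mol_vec_dim)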
if self.master_node and self.use_master_as_output:
    assert self.hidden_size == self.master_dim
    mol_vecs = []
    for start, size in b_scope:
        if size == 0:
            mol_vecs.append(self.cached_zero_vector)
        else:
            mol_vecs.append(master_state[start])
    return torch.stack(mol_vecs, dim=0)
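b_scope is not constructed in these fragments; from its use above it is presumably a list of (start, size) slices locating each molecule's bond rows in the batched tensors. A toy illustration with hypothetical values:

# Hypothetical: two molecules, the first owning 6 directed-bond rows starting at
# index 1 (assuming row 0 is reserved for padding), the second owning 8.
b_scope = [(1, 6), (7, 8)]
for start, size in b_scope:
    per_mol = message.narrow(0, start, size)  # bond hidden states belonging to one molecule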
# Get atom hidden states from message hidden states
if self.learn_virtual_edges:
    message = message * straight_through_mask

a2x = a2a if self.atom_messages else a2b
nei_a_message = index_select_ND(message, a2x)  # num_atoms x max_num_bonds x hidden
a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
a_input = torch.cat([f_atoms, a_message], dim=1)  # num_atoms x (atom_fdim + hidden)
atom_hiddens = self.act_func(self.W_o(a_input))  # num_atoms x hidden
atom_hiddens = self.dropout_layer(atom_hiddens)  # num_atoms x hidden

if self.deepset:
    atom_hiddens = self.W_s2s_a(atom_hiddens)
    atom_hiddens = self.act_func(atom_hiddens)
    atom_hiddens = self.W_s2s_b(atom_hiddens)

if self.bert_pretraining:
    atom_preds = self.W_v(atom_hiddens)[1:]  # num_atoms x vocab/output size (leave out atom padding)

# Readout
if self.set2set:
    # Set up sizes
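Neither the set2set body nor the default readout appears in these fragments. As a rough sketch of a simple mean-pooling readout, assuming a_scope lists a (start, size) atom slice per molecule (analogous to b_scope in the master-node branch above):

mol_vecs = []
for start, size in a_scope:
    if size == 0:
        mol_vecs.append(self.cached_zero_vector)
    else:
        cur_hiddens = atom_hiddens.narrow(0, start, size)  # atoms of one molecule
        mol_vecs.append(cur_hiddens.mean(dim=0))  # hidden
mol_vec = torch.stack(mol_vecs, dim=0)  # num_molecules x hidden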