Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
"""
predicted = {}
for batch in self.dataloader:
# build batch for prediction
batch = {k: v.to(device) for k, v in batch.items()}
# predict
result = self.model(batch)
# store prediction batches to dict
for p in result.keys():
value = result[p].cpu().detach().numpy()
if p in predicted.keys():
predicted[p].append(value)
else:
predicted[p] = [value]
# store positions, numbers and mask to dict
for p in [Properties.R, Properties.Z, Properties.atom_mask]:
value = batch[p].cpu().detach().numpy()
if p in predicted.keys():
predicted[p].append(value)
else:
predicted[p] = [value]
max_shapes = {
prop: max([list(val.shape) for val in values])
for prop, values in predicted.items()
}
for prop, values in predicted.items():
max_shape = max_shapes[prop]
predicted[prop] = np.vstack(
[
np.lib.pad(
batch,
def forward(self, inputs):
r"""
predicts atomwise property
"""
atomic_numbers = inputs[Properties.Z]
atom_mask = inputs[Properties.atom_mask]
# run prediction
yi = self.out_net(inputs)
yi = self.standardize(yi)
if self.atomref is not None:
y0 = self.atomref(atomic_numbers)
yi = yi + y0
y = self.atom_pool(yi, atom_mask)
# collect results
result = {self.property: y}
if self.contributions:
def forward(self, inputs):
"""Compute atomic representations/embeddings.
Args:
inputs (dict of torch.Tensor): SchNetPack dictionary of input tensors.
Returns:
torch.Tensor: atom-wise representation.
list of torch.Tensor: intermediate atom-wise representations, if
return_intermediate=True was used.
"""
# get tensors from input dictionary
# NOTE(review): positions, cell, cell_offset, neighbors and neighbor_mask are
# read here but not used in the visible part of this method; presumably they
# are consumed by code that follows this excerpt - confirm before removing.
atomic_numbers = inputs[Properties.Z]
positions = inputs[Properties.R]
cell = inputs[Properties.cell]
cell_offset = inputs[Properties.cell_offset]
neighbors = inputs[Properties.neighbors]
neighbor_mask = inputs[Properties.neighbor_mask]
atom_mask = inputs[Properties.atom_mask]
# get atom embeddings for the input atomic numbers
x = self.embedding(atomic_numbers)
# NOTE(review): the leading `False` short-circuits this condition, so it is
# never true and self.charged_systems / inputs are not even inspected here.
# The charge handling below looks deliberately disabled - confirm intent.
if False and self.charged_systems and Properties.charge in inputs.keys():
# sum atom_mask over dim 1, keeping that dim with size 1
# (assumes the mask is 1 for real atoms and 0 for padding - TODO confirm)
n_atoms = torch.sum(atom_mask, dim=1, keepdim=True)
charge = inputs[Properties.charge] / n_atoms # B
charge = charge[:, None] * self.charge # B x F
x = x + charge
# NOTE(review): this excerpt ends here - no return statement is visible even
# though the docstring describes return values.
def forward(self, inputs):
    """
    Evaluate the gated network and sum its output over the last dimension.

    The gate is computed from the atomic numbers, the network from the
    "representation" entry; both results are multiplied before summing.

    Args:
        inputs (dict of torch.Tensor): SchNetPack format dictionary of input tensors.

    Returns:
        torch.Tensor: Output of the gated network. The summed (last)
        dimension is kept with size one.
    """
    # At this point, inputs should be the general schnetpack container;
    # read both entries up front so a missing key fails before any module runs
    z = inputs[Properties.Z]
    features = inputs["representation"]

    gate_values = self.gate(z)
    network_values = self.network(features)
    weighted = gate_values * network_values
    return torch.sum(weighted, dim=-1, keepdim=True)
def get_properties(self, idx):
    """
    Load one database row and collect its structural properties as tensors.

    The visible part of the method fills a dictionary with atomic numbers,
    centered positions, the cell, the decompressed connectivity matrix and,
    when stored with the row, pre-computed distances. It then queries the
    environment provider for the atom environment.

    Args:
        idx: index of the requested entry; it is translated with
            ``self._subset_index`` before the database lookup.
    """
    _idx = self._subset_index(idx)
    with connect(self.dbpath) as conn:
        # the row is fetched with _idx + 1
        # (presumably because database ids start at 1 - TODO confirm)
        row = conn.get(_idx + 1)
    at = row.toatoms()

    # extract/calculate structure
    properties = {}
    # np.int was removed in NumPy 1.24; np.int64 is the dtype a LongTensor holds
    properties[Properties.Z] = torch.LongTensor(at.numbers.astype(np.int64))
    positions = at.positions.astype(np.float32)
    positions -= at.get_center_of_mass()  # center positions
    properties[Properties.R] = torch.FloatTensor(positions)
    properties[Properties.cell] = torch.FloatTensor(at.cell.astype(np.float32))

    # recover connectivity matrix from compressed format
    con_mat = self.connectivity_compressor.decompress(row.data['con_mat'])
    # save in dictionary
    properties['_con_mat'] = torch.FloatTensor(con_mat.astype(np.float32))

    # extract pre-computed distances (if they exist)
    if 'dists' in row.data:
        properties['dists'] = row.data['dists']

    # get atom environment
    nbh_idx, offsets = self.environment_provider.get_environment(at)
    # NOTE(review): this excerpt ends here - the code that consumes nbh_idx and
    # offsets (and any return statement) is not visible.
replica_idx (int): Replica of the molecule to extract (e.g. for ring polymer molecular dynamics). If
replica_idx is set to None (default), the centroid is returned if multiple replicas are
present.
atomistic (bool): Whether the property is atomistic (e.g. forces) or defined for the whole molecule
(e.g. energies, dipole moments). If set to True, the array is masked according to the
number of atoms for the requested molecule to counteract potential zero-padding.
(default=False)
Returns:
np.array: N_steps x [ property dimensions... ] array containing the requested property collected during the
simulation.
"""
# Special case for atom types
if property_name == Properties.Z:
return self.properties[Properties.Z][mol_idx]
# Check whether property is present
if property_name not in self.properties:
raise HDF5LoaderError(f"Property {property_name} not found in database.")
# Mask by number of atoms if property is declared atomistic
if atomistic:
n_atoms = self.n_atoms[mol_idx]
target_property = self.properties[property_name][
:, :, mol_idx, :n_atoms, ...
]
else:
target_property = self.properties[property_name][:, :, mol_idx, ...]
# Compute the centroid unless requested otherwise
if replica_idx is None:
for prop, val in properties.items():
shape = val.size()
s = (k,) + tuple([slice(0, d) for d in shape])
batch[prop][s] = val
# add mask
if not has_neighbor_mask:
nbh = properties[Properties.neighbors]
shape = nbh.size()
s = (k,) + tuple([slice(0, d) for d in shape])
mask = nbh >= 0
batch[Properties.neighbor_mask][s] = mask
batch[Properties.neighbors][s] = nbh * mask.long()
if not has_atom_mask:
z = properties[Properties.Z]
shape = z.size()
s = (k,) + tuple([slice(0, d) for d in shape])
batch[Properties.atom_mask][s] = z > 0
# Check if neighbor pair indices are present
# Since the structure of both idx_j and idx_k is identical
# (not the values), only one cutoff mask has to be generated
if Properties.neighbor_pairs_j in properties:
nbh_idx_j = properties[Properties.neighbor_pairs_j]
shape = nbh_idx_j.size()
s = (k,) + tuple([slice(0, d) for d in shape])
batch[Properties.neighbor_pairs_mask][s] = nbh_idx_j >= 0
return batch
Args:
mol_dict (dict of torch.Tensor): dict containing the atom positions
('_positions') ordered by distance to the center of mass of the molecule,
the atomic numbers ('_atomic_numbers'), the connectivity matrix
('_con_mat'), and, optionally, the precomputed distances ('dists')
seed (int, optional): a seed for the random selection of the focus at each
step (default: None)
'''
# set seed
if seed is not None:
old_state = torch.get_rng_state()
torch.manual_seed(seed)
# extract positions, atomic numbers, and connectivity matrix
numbers = mol_dict[Properties.Z]
n_atoms = len(numbers)
con_mat = (mol_dict['_con_mat'] > 0).float()
current = [-1] # in the first step, none of the atoms is focused
order = [0] # the new ordering always starts with the first atom (closest to com)
pred_types = [numbers[0]] # the first predicted type is that of the first atom
# start from first atom and traverse molecular graph (choosing the focus randomly)
con_mat[:, 0] = 0 # mark first atom as placed by removing its bonds
avail = torch.zeros(n_atoms).float() # list with atoms available as focus
avail[0] = 1. # first atom is available
i = 1
while torch.sum(con_mat > 0) or (torch.sum(avail) > 0):
# take random current focus
cur_i = torch.multinomial(avail, 1)[0]
current += [cur_i]
mol_idx (int): Index of the molecule to extract, by default uses the first molecule (mol_idx=0)
replica_idx (int): Replica of the molecule to extract (e.g. for ring polymer molecular dynamics). If
replica_idx is set to None (default), the centroid is returned if multiple replicas are
present.
atomistic (bool): Whether the property is atomistic (e.g. forces) or defined for the whole molecule
(e.g. energies, dipole moments). If set to True, the array is masked according to the
number of atoms for the requested molecule to counteract potential zero-padding.
(default=False)
Returns:
np.array: N_steps x property dimensions array containing the requested property collected during the simulation.
"""
# Special case for atom types
if property_name == Properties.Z:
return self.properties[Properties.Z][mol_idx]
# Check whether property is present
if property_name not in self.properties:
raise HDF5LoaderError(f"Property {property_name} not found in database.")
# Mask by number of atoms if property is declared atomistic
if atomistic:
n_atoms = self.n_atoms[mol_idx]
target_property = self.properties[property_name][
:, :, mol_idx, :n_atoms, ...
]
else:
target_property = self.properties[property_name][:, :, mol_idx, ...]
# Compute the centroid unless requested otherwise
if replica_idx is None:
def evaluate(self, device):
    """
    Collect predictions and write one database entry per structure.

    Positions, atomic numbers and atom masks are taken out of the prediction
    dictionary returned by ``self._get_predicted``; every remaining entry is
    treated as a predicted property. For each structure the mask selects its
    atoms, and the properties are passed through ``self._unpad`` before being
    written to the database at ``self.out_file``.

    Args:
        device: forwarded unchanged to ``self._get_predicted``.
    """
    predictions = self._get_predicted(device)

    # structural entries are removed so that only properties remain
    all_positions = predictions.pop(Properties.R)
    all_numbers = predictions.pop(Properties.Z)
    all_masks = predictions.pop(Properties.atom_mask).astype(bool)

    with connect(self.out_file) as db:
        for idx, mask in enumerate(all_masks):
            # keep only the atoms selected by this structure's mask
            numbers = all_numbers[idx, mask]
            coords = all_positions[idx, mask]
            molecule = Atoms(numbers=numbers, positions=coords)

            row_data = {}
            for name, values in predictions.items():
                row_data[name] = self._unpad(mask, values[idx])
            db.write(molecule, data=row_data)