Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _create_generator(self, *args, **kwargs) -> Union[GraphBatchDistanceConvert, GraphBatchGenerator]:
    """Build the batch generator that matches ``self.graph_converter``.

    Args:
        *args: positional arguments forwarded unchanged to the generator
            constructor
        **kwargs: keyword arguments forwarded to the generator constructor;
            when the graph converter has a ``bond_converter`` attribute it is
            passed along as 'distance_converter', replacing any value the
            caller supplied under that key

    Returns:
        A GraphBatchDistanceConvert when ``self.graph_converter`` has a
        ``bond_converter`` attribute, otherwise a plain GraphBatchGenerator
    """
    # No bond converter available: nothing extra to inject
    if not hasattr(self.graph_converter, 'bond_converter'):
        return GraphBatchGenerator(*args, **kwargs)

    kwargs['distance_converter'] = self.graph_converter.bond_converter
    return GraphBatchDistanceConvert(*args, **kwargs)
- [ndarray]: List of indices for the start of each bond
- [ndarray]: List of indices for the end of each bond
"""
# Get the features and connectivity lists for this batch
it = itemgetter(*batch_index)
feature_list_temp = itemgetter_list(self.atom_features, batch_index)
connection_list_temp = itemgetter_list(self.bond_features, batch_index)
global_list_temp = itemgetter_list(self.state_features, batch_index)
index1_temp = itemgetter_list(self.index1_list, batch_index)
index2_temp = itemgetter_list(self.index2_list, batch_index)
return feature_list_temp, connection_list_temp, global_list_temp, index1_temp, index2_temp
class GraphBatchDistanceConvert(GraphBatchGenerator):
"""
Generate batch of structures with bond distance being expanded using a Expansor
Args:
atom_features: (list of np.array) list of atom feature matrix,
bond_features: (list of np.array) list of bond features matrix
state_features: (list of np.array) list of [1, G] state features, where G is the global state feature dimension
index1_list: (list of integer) list of (M, ) one side atomic index of the bond, M is different for different
structures
index2_list: (list of integer) list of (M, ) the other side atomic index of the bond, M is different for
different structures, but it has to be the same as the corresponding index1.
targets: (numpy array), N*1, where N is the number of structures
batch_size: (int) number of samples in a batch
is_shuffle: (bool) whether to shuffle the structure, default to True
distance_converter: (bool) converter for processing the distances
def _create_generator(self, *args, **kwargs) -> Union[GraphBatchDistanceConvert, GraphBatchGenerator]:
    """Build the batch generator that matches ``self.graph_converter``.

    Args:
        *args: positional arguments forwarded unchanged to the generator
            constructor
        **kwargs: keyword arguments forwarded to the generator constructor;
            when the graph converter has a ``bond_converter`` attribute it is
            passed along as 'distance_converter', replacing any value the
            caller supplied under that key

    Returns:
        A GraphBatchDistanceConvert when ``self.graph_converter`` has a
        ``bond_converter`` attribute, otherwise a plain GraphBatchGenerator
    """
    # No bond converter available: nothing extra to inject
    if not hasattr(self.graph_converter, 'bond_converter'):
        return GraphBatchGenerator(*args, **kwargs)

    kwargs['distance_converter'] = self.graph_converter.bond_converter
    return GraphBatchDistanceConvert(*args, **kwargs)
def create_cached_generator(self) -> GraphBatchGenerator:
    """Featurize every molecule up front and serve the results from memory.

    Returns:
        (GraphBatchGenerator) Graph generator backed by graphs that are
        already held in memory
    """
    # Build the graph for each molecule in one pass
    mol_graphs = self._generate_graphs(self.mols)

    # Flatten the graphs, together with the targets, into the positional
    # inputs the generator is constructed from
    flat_inputs = self.converter.get_flat_data(mol_graphs, self.targets)

    return GraphBatchGenerator(
        *flat_inputs,
        is_shuffle=self.is_shuffle,
        batch_size=self.batch_size,
    )