Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
itraj : int
trajectory index
stride : int
return value is the number of frames in the trajectory when
running through it with a step size of `stride`.
skip: int or None
skip n frames.
Returns
-------
int : length of trajectory
"""
if itraj >= self.ntraj:
raise IndexError("given index (%s) exceeds number of data sets (%s)."
" Zero based indexing!" % (itraj, self.ntraj))
if not IteratorState.is_uniform_stride(stride):
selection = stride[stride[:, 0] == itraj][:, 0]
return 0 if itraj not in selection else len(selection)
else:
res = max((self._lengths[itraj] - skip - 1) // int(stride) + 1, 0)
return res
def __init_stride(self, stride):
    """Record ``stride`` on the iterator state and prepare per-trajectory lookups.

    Parameters
    ----------
    stride : np.ndarray or other
        If an ``np.ndarray``, column 0 is read as trajectory indices and
        column 1 as the per-trajectory index values to visit. Any other
        value is stored as-is and no trajectory keys are recorded.

    Raises
    ------
    ValueError
        If the largest trajectory index in column 0 is not smaller than
        ``self.number_of_trajectories()``, or if the stride is non-uniform
        and ``self.state.is_stride_sorted()`` is false.
    """
    state = self.state
    state.stride = stride

    if not isinstance(stride, np.ndarray):
        state.traj_keys = None
    else:
        traj_ids = stride[:, 0]
        if traj_ids.max() >= self.number_of_trajectories():
            raise ValueError("provided too large trajectory index in stride argument (given max index: %s, "
                             "allowed: %s)" % (traj_ids.max(), self.number_of_trajectories() - 1))
        state.traj_keys, state.trajectory_lengths = np.unique(traj_ids, return_counts=True)
        # map each referenced trajectory to the column-1 entries of its rows
        state.ra_indices_for_traj_dict = {
            key: state.stride[state.stride[:, 0] == key][:, 1]
            for key in state.traj_keys
        }

    uniform = IteratorState.is_uniform_stride(stride)
    state.uniform_stride = uniform
    if uniform:
        return

    if not state.is_stride_sorted():
        raise ValueError("Only sorted arrays allowed for iterator pseudo random access")
    # move the cursor past trajectories the stride array never references
    while state.itraj not in state.traj_keys and state.itraj < self._data_source.ntraj:
        state.itraj += 1
def __init_stride(self, stride):
    """Attach ``stride`` to ``self.state`` and derive the lookup tables from it.

    For an ``np.ndarray`` stride the unique values of column 0 (with their
    counts) are stored as ``traj_keys`` / ``trajectory_lengths``, and the
    column-1 entries are grouped per trajectory key. For anything else
    ``traj_keys`` is reset to ``None``.

    Parameters
    ----------
    stride : np.ndarray or other
        The stride specification to install on the iterator state.

    Raises
    ------
    ValueError
        When column 0 contains an index that is not below
        ``self.number_of_trajectories()``, or when a non-uniform stride is
        reported as unsorted by ``self.state.is_stride_sorted()``.
    """
    self.state.stride = stride
    array_stride = isinstance(stride, np.ndarray)

    if array_stride:
        column0 = stride[:, 0]
        too_large = column0.max() >= self.number_of_trajectories()
        if too_large:
            raise ValueError("provided too large trajectory index in stride argument (given max index: %s, "
                             "allowed: %s)" % (column0.max(), self.number_of_trajectories() - 1))
        unique_keys, counts = np.unique(column0, return_counts=True)
        self.state.traj_keys = unique_keys
        self.state.trajectory_lengths = counts
        self.state.ra_indices_for_traj_dict = {}
        for key in self.state.traj_keys:
            row_mask = self.state.stride[:, 0] == key
            self.state.ra_indices_for_traj_dict[key] = self.state.stride[row_mask][:, 1]
    else:
        self.state.traj_keys = None

    self.state.uniform_stride = IteratorState.is_uniform_stride(stride)
    if IteratorState.is_uniform_stride(stride):
        return

    if not self.state.is_stride_sorted():
        raise ValueError("Only sorted arrays allowed for iterator pseudo random access")
    # step over trajectories that do not occur in the stride array
    while self.state.itraj not in self.state.traj_keys and self.state.itraj < self._data_source.ntraj:
        self.state.itraj += 1
# NOTE(review): the `def` line for these statements is not visible in this
# excerpt; they read like the body of a stride-initialisation method that
# additionally applies `self.state.skip` — confirm against the full file.
# Store the stride on the iterator state.
self.state.stride = stride
if isinstance(stride, np.ndarray):
# shift frame indices by skip
# NOTE(review): `+=` works in place on state.stride, which was assigned from
# `stride` just above — presumably this also changes the caller's array;
# verify that callers expect that.
self.state.stride[:, 1] += self.state.skip
# column 0 of the stride array is read as trajectory indices
keys = stride[:, 0]
if keys.max() >= self.number_of_trajectories():
raise ValueError("provided too large trajectory index in stride argument (given max index: %s, "
"allowed: %s)" % (keys.max(), self.number_of_trajectories() - 1))
# unique trajectory indices and how many rows reference each of them
self.state.traj_keys, self.state.trajectory_lengths = np.unique(keys, return_counts=True)
# per trajectory key: the column-1 entries of the rows that reference it
self.state.ra_indices_for_traj_dict = {}
for traj in self.state.traj_keys:
self.state.ra_indices_for_traj_dict[traj] = self.state.stride[self.state.stride[:, 0] == traj][:, 1]
else:
# non-array stride: no trajectory keys are recorded
self.state.traj_keys = None
self.state.uniform_stride = IteratorState.is_uniform_stride(stride)
if not IteratorState.is_uniform_stride(stride):
# a non-uniform stride must pass the state's sortedness check
if not self.state.is_stride_sorted():
raise ValueError("Only sorted arrays allowed for iterator pseudo random access")
# skip trajs which are not included in stride
while self.state.itraj not in self.state.traj_keys and self.state.itraj < self._data_source.ntraj:
self.state.itraj += 1
def number_of_trajectories(self, stride=None):
    r""" Returns the number of trajectories.

    Parameters
    ----------
    stride: None (default) or np.ndarray
        When ``IteratorState.is_uniform_stride`` reports the stride as
        non-uniform, the distinct values in its first column are counted;
        otherwise ``self.ntraj`` is returned.

    Returns
    -------
    int : number of trajectories
    """
    if IteratorState.is_uniform_stride(stride):
        return self.ntraj
    # only the distinct entries of column 0 count as trajectories
    distinct_ids = np.unique(stride[:, 0])
    return len(distinct_ids)
def is_uniform_stride(stride):
    """Report whether ``stride`` is uniform, delegating to ``IteratorState``.

    Parameters
    ----------
    stride
        The stride specification to classify.

    Returns
    -------
    Whatever ``IteratorState.is_uniform_stride(stride)`` returns.
    """
    uniform = IteratorState.is_uniform_stride(stride)
    return uniform