Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
color=self.ax_color,
data_bounds=ax_db,
box_index=box_index,
)
# Vertical ticks every millisecond.
steps = np.arange(np.round(self.wave_duration * 1000))
# A vline every millisecond.
x = .001 * steps
# Scale to [-1, 1], same coordinates as the waveform points.
x = -1 + 2 * x / self.wave_duration
# Take overlap into account.
x = _overlap_transform(x, offset=bunch.offset, n=bunch.n_clu, overlap=self.overlap)
x = np.tile(x, len(channel_ids_loc))
# Generate the box index.
box_index = _index_of(channel_ids_loc, self.channel_ids)
box_index = np.repeat(box_index, x.size // len(box_index))
assert x.size == box_index.size
self.tick_visual.add_batch_data(
x=x, y=np.zeros_like(x),
data_bounds=ax_db,
box_index=box_index,
)
----------
spike_clusters : array-like
The spike-cluster assignments.
cluster_ids : array-like
The set of unique selected cluster ids appearing in spike_clusters, in a given order
Returns
-------
spike_colors : array-like
For each spike, the RGBA color (in [0,1]) depending on the index of the cluster within
`cluster_ids`.
"""
spike_clusters_idx = _index_of(spike_clusters, cluster_ids)
return add_alpha(colormaps.default[np.mod(spike_clusters_idx, colormaps.default.shape[0])])
def _categorical_colormap(colormap, values, vmin=None, vmax=None):
"""Convert values into colors given a specified categorical colormap."""
assert np.issubdtype(values.dtype, np.integer)
assert colormap.shape[1] == 3
n = colormap.shape[0]
if vmin is None and vmax is None:
# Find unique values and keep the order.
_, idx = np.unique(values, return_index=True)
lookup = values[np.sort(idx)]
x = _index_of(values, lookup)
else:
x = values
return colormap[x % n, :]
def _get_box_index(self, bunch):
    """Get the box_index array for a cluster.

    Every vertex of the cluster's template gets a `(channel_idx, cluster_idx)`
    pair: the relative index of its channel within `self.channel_ids`, and
    `bunch.cluster_idx`.

    Parameters
    ----------
    bunch : object
        Must provide `template` (a 2D array of shape
        `(n_samples, len(bunch.channel_ids))`), `channel_ids` and
        `cluster_idx`.

    Returns
    -------
    box_index : array
        Array of shape `(len(bunch.channel_ids) * n_samples, 2)`. The second
        column is built with `np.ones`, so the array is floating-point.

    """
    # Unpacking also checks that the template is 2D; the second dimension is
    # only checked indirectly by the size assertion below.
    n_samples, _ = bunch.template.shape
    # Relative channel index, repeated once per time sample.
    box_index = _index_of(bunch.channel_ids, self.channel_ids)
    box_index = np.repeat(box_index, n_samples)
    box_index = np.c_[
        box_index.reshape((-1, 1)),
        bunch.cluster_idx * np.ones((n_samples * len(bunch.channel_ids), 1))]
    assert box_index.shape == (len(bunch.channel_ids) * n_samples, 2)
    assert box_index.size == bunch.template.size * 2
    return box_index
def _get_box_index(self):
    """Return, for every spike in `self.spike_ids`, its row in the raster plot.

    The row is the index of the spike's cluster within `self.all_cluster_ids`,
    so it depends on the ordering of that array.

    """
    spike_cluster_ids = self.spike_clusters[self.spike_ids]
    # NOTE: membership of every spike cluster in self.all_cluster_ids is not
    # checked here (the check was disabled); the row returned for an unknown
    # cluster depends on how _index_of handles missing values.
    return _index_of(spike_cluster_ids, self.all_cluster_ids)
channel_positions = self.model.channel_positions
assert channel_positions.ndim == 2
cluster_channels = np.load(p / 'clusters.peakChannel.npy')
assert cluster_channels.ndim == 1
n_clusters = cluster_channels.shape[0]
spike_clusters = self.model.spike_clusters
assert spike_clusters.ndim == 1
n_spikes = spike_clusters.shape[0]
self.cluster_ids = _unique(self.model.spike_clusters)
clusters_depths = channel_positions[cluster_channels, 1]
assert clusters_depths.shape == (n_clusters,)
spike_clusters_rel = _index_of(spike_clusters, self.cluster_ids)
assert spike_clusters_rel.max() < clusters_depths.shape[0]
spikes_depths = clusters_depths[spike_clusters_rel]
assert spikes_depths.shape == (n_spikes,)
np.save(p / 'spikes.depths.npy', spikes_depths)
np.save(p / 'clusters.depths.npy', clusters_depths)
ns = len(spike_ids)
nc = len(channel_ids)
# Initialize the output array.
features = np.empty((ns, n_channels_loc, n_pcs))
features[:] = np.NAN
if self.features_rows is not None:
s = np.intersect1d(spike_ids, self.features_rows)
# Relative indices of the spikes in the self.features_spike_ids
# array, necessary to load features from all_features which only
# contains the subset of the spikes.
rows = _index_of(s, self.features_rows)
# Relative indices of the non-null rows in the output features
# array.
rows_out = _index_of(s, spike_ids)
else:
rows = spike_ids
rows_out = slice(None, None, None)
features[rows_out, ...] = data[rows]
if self.features_cols is not None:
assert self.features_cols.shape[1] == n_channels_loc
cols = self.features_cols[self.spike_templates[spike_ids]]
features = from_sparse(features, cols, channel_ids)
assert features.shape == (ns, nc, n_pcs)
return features