# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): fragment — the enclosing method's `def` (an add-visual routine
# on a canvas object, by the look of it) is outside this chunk; `visual`,
# `v_inserter`, `exclude_origins`, and `kwargs` are defined above. TODO confirm.
# in insert_into_shaders() below.
v_inserter += self.inserter
# Now, we insert the transforms GLSL into the shaders.
vs, fs = visual.vertex_shader, visual.fragment_shader
vs, fs = v_inserter.insert_into_shaders(vs, fs, exclude_origins=exclude_origins)
# Finally, we create the visual's program.
visual.program = LazyProgram(vs, fs)
# Log the final shader sources at a very verbose custom level (5 < DEBUG).
logger.log(5, "Vertex shader: %s", vs)
logger.log(5, "Fragment shader: %s", fs)
# Initialize the size.
visual.on_resize(self.size().width(), self.size().height())
# Register the visual in the list of visuals in the canvas.
self.visuals.append(Bunch(visual=visual, **kwargs))
# Notify listeners that a new visual has been added, then hand it back.
emit('visual_added', self, visual)
return visual
def _get_traces(interval):
    """Return the raw traces for the given time interval, wrapped in a Bunch.

    NOTE: `self` is captured from the enclosing scope (closure), not passed in.
    """
    traces = select_traces(self.data, interval, sample_rate=self.sample_rate)
    return Bunch(data=traces)
def _get_waveforms(self, cluster_id):
    """Return a selection of waveforms for a cluster."""
    channel_positions = self.model.channel_positions
    # Pick a subset of the cluster's spikes, in batches.
    spike_ids = self.selector.select_spikes(
        [cluster_id], self.n_spikes_waveforms, self.batch_size_waveforms)
    waveforms = self.model.all_waveforms[spike_ids]
    # Per-channel amplitude of the mean waveform, weighted by the mean masks.
    mean_masks = self._get_mean_masks(cluster_id)
    mean_waveform = np.mean(waveforms, axis=0)
    amplitude = get_waveform_amplitude(mean_masks, mean_waveform)
    masks = self._get_masks(cluster_id)
    # Best channels first: order channels by decreasing amplitude.
    channel_ids = np.argsort(amplitude)[::-1]
    return Bunch(
        data=waveforms[..., channel_ids],
        channel_ids=channel_ids,
        channel_positions=channel_positions[channel_ids],
        masks=masks[:, channel_ids],
    )
def validate(self, pos=None, data_bounds=None, **kwargs):
    """Validate the requested data before passing it to set_data()."""
    assert pos is not None
    # Coerce to a (n, 2) array of 2D positions.
    pos = np.atleast_2d(pos)
    assert pos.ndim == 2
    assert pos.shape[1] == 2
    # By default, we assume that the coordinates are in NDC.
    if data_bounds is None:
        data_bounds = NDC
    data_bounds = _get_data_bounds(data_bounds).astype(np.float64)
    assert data_bounds.shape == (1, 4)
    n_items = pos.shape[0]
    return Bunch(
        pos=pos,
        data_bounds=data_bounds,
        _n_items=n_items,
        _n_vertices=self.vertex_count(pos=pos),
    )
# NOTE(review): fragment — tail of a text-visual validate() method; `text`,
# `n_text`, and `pos` are defined above this chunk. TODO confirm against the
# full file.
assert len(text) == n_text
# Default anchor is (0, 0); a single anchor is broadcast to every text item.
anchor = anchor if anchor is not None else (0., 0.)
anchor = np.atleast_2d(anchor)
if anchor.shape[0] == 1:
anchor = np.repeat(anchor, n_text, axis=0)
assert anchor.ndim == 2
assert anchor.shape == (n_text, 2)
# One (xmin, ymin, xmax, ymax) bounds row per text item; NDC by default.
data_bounds = data_bounds if data_bounds is not None else NDC
data_bounds = _get_data_bounds(data_bounds, pos)
assert data_bounds.shape[0] == n_text
data_bounds = data_bounds.astype(np.float64)
assert data_bounds.shape == (n_text, 4)
return Bunch(
pos=pos, text=text, anchor=anchor, data_bounds=data_bounds,
_n_items=n_text, _n_vertices=self.vertex_count(text=text))
# NOTE(review): fragment — tail of a template-feature loading method; the
# opening try/def and the `data`, `n_channels_loc`, `n_spikes` bindings are
# outside this chunk.
except IOError:
# The main feature array is missing: nothing to return.
return
try:
cols = self._read_array('template_feature_ind')
assert cols.shape == (self.n_templates, n_channels_loc)
except IOError:
# Column (channel index) array is optional.
cols = None
try:
rows = self._read_array('template_feature_spike_ids')
assert rows.shape == (n_spikes,)
except IOError:
# Row (spike id) array is optional too.
rows = None
return Bunch(data=data, cols=cols, rows=rows)
def data(self):
    """Return the concatenated data as a dictionary."""
    out = {}
    for key in self.items:
        out[key] = getattr(self, key)
    return Bunch(out)
# NOTE(review): fragment — tail of a PC-feature loading method, parallel to
# the template-feature loader above; the opening try/def and the `data`,
# `n_channels_loc`, `n_spikes` bindings are outside this chunk.
except IOError:
# The main feature array is missing: nothing to return.
return
try:
cols = self._read_array('pc_feature_ind')
assert cols.shape == (self.n_templates, n_channels_loc)
except IOError:
# Column (channel index) array is optional.
cols = None
try:
rows = self._read_array('pc_feature_spike_ids')
assert rows.shape == (n_spikes,)
except IOError:
# Row (spike id) array is optional too.
rows = None
return Bunch(data=data, cols=cols, rows=rows)
def get_view_state(self, view):
    """Return the state of a view instance."""
    # Fall back to an empty Bunch when no state was saved under this name.
    default = Bunch()
    return self.get(view.name, default)
# NOTE(review): fragment — body of a spike/waveform extraction loop inside a
# generator; the enclosing def and the loop variables (i, t, m, p, sr, s0, s1,
# k, ns, traces_interval, supervisor, show_selected, show_all_spikes,
# get_best_channels) are all defined outside this chunk. TODO confirm.
c = m.spike_clusters[i]
is_selected = c in p.selected
# Show non selected spikes first, then selected spikes so that they appear on top.
if is_selected is not show_selected:
continue
# Skip non-selected spikes if requested.
if (not show_all_spikes and c not in supervisor.selected):
continue
# cg = p.cluster_meta.get('group', c)
channel_ids = get_best_channels(c)
# Spike sample index relative to the start of the interval.
s = int(round(t * sr)) - s0
# Skip partial spikes.
if s - k < 0 or s + k >= (s1 - s0): # pragma: no cover
continue
# Extract the waveform.
wave = Bunch(
data=traces_interval[s - k:s + ns - k, channel_ids],
channel_ids=channel_ids,
start_time=(s + s0 - k) / sr,
spike_id=i,
spike_time=t,
spike_cluster=c,
select_index=p.selected.index(c) if c in p.selected else None,
)
assert wave.data.shape == (ns, len(channel_ids))
yield wave