import numpy as np
from spikeextractors import SubRecordingExtractor

def _initialize_subrecording_extractor(self):
    # If bad channels were given explicitly, simply exclude them.
    if isinstance(self._bad_channel_ids, (list, np.ndarray)):
        active_channels = []
        for chan in self._recording.get_channel_ids():
            if chan not in self._bad_channel_ids:
                active_channels.append(chan)
        self._subrecording = SubRecordingExtractor(self._recording, channel_ids=active_channels)
    elif self._bad_channel_ids is None:
        # Otherwise detect bad channels automatically: take a short chunk from the
        # middle of the recording and flag channels whose standard deviation exceeds
        # bad_threshold times the median channel std.
        start_frame = self._recording.get_num_frames() // 2
        end_frame = int(start_frame + self._seconds * self._recording.get_sampling_frequency())
        if end_frame > self._recording.get_num_frames():
            end_frame = self._recording.get_num_frames()
        traces = self._recording.get_traces(start_frame=start_frame, end_frame=end_frame)
        stds = np.std(traces, axis=1)
        bad_channel_ids = [ch for ch, std in enumerate(stds) if std > self._bad_threshold * np.median(stds)]
        if self.verbose:
            print('Automatically removing channels:', bad_channel_ids)
        active_channels = []
        for chan in self._recording.get_channel_ids():
            if chan not in bad_channel_ids:
                active_channels.append(chan)
        self._subrecording = SubRecordingExtractor(self._recording, channel_ids=active_channels)
    else:
        # Assumed fallback: keep the recording unchanged when _bad_channel_ids is
        # neither a list/array nor None.
        self._subrecording = self._recording
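
# --- Hedged usage sketch (not part of the source): the automatic detection above
# --- boils down to flagging channels whose std exceeds bad_threshold times the
# --- median channel std. The names below (detect_noisy_channels, bad_threshold)
# --- are illustrative only.
import numpy as np

def detect_noisy_channels(traces, bad_threshold=2.0):
    """Return indices of channels whose std exceeds bad_threshold * median(std).

    traces is assumed to be a (num_channels, num_frames) array, matching the
    layout returned by RecordingExtractor.get_traces().
    """
    stds = np.std(traces, axis=1)
    return [ch for ch, std in enumerate(stds) if std > bad_threshold * np.median(stds)]

# Example: channel 2 is ~10x noisier than the rest and gets flagged.
rng = np.random.default_rng(0)
traces = rng.normal(scale=1.0, size=(4, 30000))
traces[2] *= 10
print(detect_noisy_channels(traces, bad_threshold=2.0))  # -> [2]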
def test_multi_sub_recording_extractor(self):
    RX_multi = se.MultiRecordingTimeExtractor(
        recordings=[self.RX, self.RX, self.RX],
        epoch_names=['A', 'B', 'C']
    )
    RX_sub = RX_multi.get_epoch('C')
    self._check_recordings_equal(self.RX, RX_sub)
    self.assertEqual(4, len(RX_sub.get_channel_ids()))

    RX_multi = se.MultiRecordingChannelExtractor(
        recordings=[self.RX, self.RX2, self.RX3],
        groups=[1, 2, 3]
    )
    print(RX_multi.get_channel_groups())
    RX_sub = se.SubRecordingExtractor(RX_multi, channel_ids=[4, 5, 6, 7], renamed_channel_ids=[0, 1, 2, 3])
    self._check_recordings_equal(self.RX2, RX_sub)
    self.assertEqual([2, 2, 2, 2], RX_sub.get_channel_groups())
    self.assertEqual(12, len(RX_multi.get_channel_ids()))
import numpy as np
import spikeextractors as se

def get_max_channels_per_waveforms(recording, grouping_property, channel_ids, max_channels_per_waveforms):
    # With no grouping property, the cap is the number of selected channels;
    # otherwise it is the size of the largest channel group. In both cases an
    # explicit max_channels_per_waveforms can only lower the value.
    if grouping_property is None:
        if max_channels_per_waveforms is None:
            n_channels = len(channel_ids)
        elif max_channels_per_waveforms >= len(channel_ids):
            n_channels = len(channel_ids)
        else:
            n_channels = max_channels_per_waveforms
    else:
        rec = se.SubRecordingExtractor(recording, channel_ids=channel_ids)
        rec_groups = np.array([rec.get_channel_property(ch, grouping_property) for ch in rec.get_channel_ids()])
        groups, count = np.unique(rec_groups, return_counts=True)
        if max_channels_per_waveforms is None:
            n_channels = np.max(count)
        elif max_channels_per_waveforms >= np.max(count):
            n_channels = np.max(count)
        else:
            n_channels = max_channels_per_waveforms
    return n_channels
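
# --- Hedged usage sketch (not part of the source): with a grouping property such
# --- as 'group', the helper caps waveforms at the size of the largest channel
# --- group. The group labels below are made up for illustration.
import numpy as np

rec_groups = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])   # per-channel group labels
groups, count = np.unique(rec_groups, return_counts=True)
print(groups, count)          # [0 1 2] [3 2 4]
print(np.max(count))          # 4 -> n_channels when max_channels_per_waveforms is None
print(min(3, np.max(count)))  # 3 -> n_channels when max_channels_per_waveforms = 3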
# Tail of a waveform-extraction helper: write the waveforms into a memmap,
# re-instantiating it if the shape changed (e.g. some channels were dropped).
# The opening condition is assumed; it is not part of the original snippet.
if wf.shape == memmap_array.shape:
    memmap_array[:] = wf
else:
    # some channels are missing - re-instantiate object
    memmap_file = memmap_array.filename
    del memmap_array
    memmap_array = np.memmap(memmap_file, mode='w+', shape=wf.shape, dtype=wf.dtype)
    memmap_array[:] = wf
waveforms = memmap_array
return waveforms, list(indexes), list(max_channel_idxs)
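
# --- Hedged sketch (not part of the source): re-instantiating a numpy memmap when
# --- the new waveform array has a different shape, mirroring the branch above.
# --- File name and shapes are illustrative only.
import os
import tempfile
import numpy as np

path = os.path.join(tempfile.mkdtemp(), 'waveforms.raw')
memmap_array = np.memmap(path, mode='w+', shape=(10, 4, 60), dtype='float32')

wf = np.ones((10, 3, 60), dtype='float32')   # one channel missing
if wf.shape == memmap_array.shape:
    memmap_array[:] = wf
else:
    # shape mismatch: recreate the memmap on the same file with the new shape
    memmap_file = memmap_array.filename
    del memmap_array
    memmap_array = np.memmap(memmap_file, mode='w+', shape=wf.shape, dtype=wf.dtype)
    memmap_array[:] = wf
print(memmap_array.shape)  # (10, 3, 60)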
# Fragment of a waveform-extraction routine: for the requested unit, work out how
# many channels and spikes to use, then sample random spike waveforms.
else:
    for i, unit_id in enumerate(unit_ids):
        if unit == unit_id:
            if channel_ids is None:
                channel_ids = recording.get_channel_ids()
            rec = se.SubRecordingExtractor(recording, channel_ids=channel_ids)
            rec_groups = np.array(rec.get_channel_groups())
            groups, count = np.unique(rec_groups, return_counts=True)
            if max_channels_per_waveforms is None:
                max_channels_per_waveforms = np.max(count)
            elif max_channels_per_waveforms >= np.max(count):
                max_channels_per_waveforms = np.max(count)
            if max_spikes_per_unit is None:
                max_spikes = len(sorting.get_unit_spike_train(unit_id))
            else:
                max_spikes = max_spikes_per_unit
            if verbose:
                print('Waveform ' + str(i + 1) + '/' + str(len(unit_ids)))
            wf, indexes = _get_random_spike_waveforms(recording=recording,
                                                      sorting=sorting,