Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _initialize_subrecording_extractor(self):
    """Create ``self._subrecording``: ``self._recording`` minus its bad channels.

    * If ``self._bad_channel_ids`` is a list or ndarray, those channel ids are
      excluded.
    * If it is ``None``, bad channels are detected automatically. The standard
      deviation of every channel is computed over a window that starts at the
      middle of the recording and lasts ``self._seconds`` seconds (clipped to
      the end of the recording). A channel is bad when its std exceeds
      ``self._bad_threshold`` times the median std across channels. When
      ``self.verbose`` is true the detected channel ids are printed.
    * Any other value is not handled by these branches, and
      ``self._subrecording`` is left untouched.
    """
    recording = self._recording
    channel_ids = recording.get_channel_ids()

    if isinstance(self._bad_channel_ids, (list, np.ndarray)):
        bad_channel_ids = self._bad_channel_ids
    elif self._bad_channel_ids is None:
        num_frames = recording.get_num_frames()
        start_frame = num_frames // 2
        end_frame = int(start_frame + self._seconds * recording.get_sampling_frequency())
        end_frame = min(end_frame, num_frames)
        traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame)
        stds = np.std(traces, axis=1)
        threshold = self._bad_threshold * np.median(stds)
        # Rows of ``traces`` are positions, not ids. Map each row back to its
        # channel id so the filter below compares ids with ids; the previous
        # ``enumerate`` version dropped the wrong channels whenever the ids
        # were not exactly 0..N-1.
        # Assumes get_traces() returns rows in get_channel_ids() order.
        bad_channel_ids = [ch for ch, std in zip(channel_ids, stds) if std > threshold]
        if self.verbose:
            print('Automatically removing channels:', bad_channel_ids)
    else:
        # Neither explicit ids nor auto-detection requested: nothing to do here.
        return

    excluded = set(bad_channel_ids)
    active_channels = [ch for ch in channel_ids if ch not in excluded]
    self._subrecording = SubRecordingExtractor(recording, channel_ids=active_channels)
else:
# NOTE(review): this `else:` has no real body in this excerpt. It looks like
# the tail of the if/elif chain in `_initialize_subrecording_extractor` above.
# The lines below appear to be a pasted fragment of a study-folder export
# routine: they repeat the per-recording loop body further down and use names
# (`study_folder`, `rec_name`, `recording`, `sorting_gt`, `gt_dict`, `se`,
# `json`) that are not defined here.
# write recording as binary format + json + prb
raw_filename = study_folder / 'raw_files' / (rec_name + '.dat')
prb_filename = study_folder / 'raw_files' / (rec_name + '.prb')
json_filename = study_folder / 'raw_files' / (rec_name + '.json')
num_chan = recording.get_num_channels()
# Chunk length chosen so one chunk holds about 2**24 values across all
# channels (presumably frames per chunk). Divides by zero for 0 channels.
chunksize = 2 ** 24 // num_chan
sr = recording.get_sampling_frequency()
# Traces are written with time_axis=0 as float32. The JSON sidecar below
# records the same sample rate, channel count, dtype and frames_first layout.
se.write_binary_dat_format(recording, raw_filename, time_axis=0, dtype='float32', chunksize=chunksize)
se.save_probe_file(recording, prb_filename, format='spyking_circus')
with open(json_filename, 'w', encoding='utf8') as f:
info = dict(sample_rate=sr, num_chan=num_chan, dtype='float32', frames_first=True)
json.dump(info, f, indent=4)
# write the ground-truth sorting (sorting_gt) in npz format
se.NpzSortingExtractor.write_sorting(sorting_gt, study_folder / 'ground_truth' / (rec_name + '.npz'))
# make an index of recording names: names.txt is overwritten with one name
# per line, in gt_dict iteration order
with open(study_folder / 'names.txt', mode='w', encoding='utf8') as f:
for rec_name in gt_dict:
f.write(rec_name + '\n')
# NOTE(review): excerpt of a study-folder setup routine. `study_folder`,
# `gt_dict`, `se`, `os` and `json` are defined outside these lines.
# Create the output sub-folders. os.makedirs (no exist_ok) raises
# FileExistsError if a folder already exists.
# NOTE(review): files are also written under study_folder / 'raw_files', which
# is not created in these lines. Confirm it is created earlier.
os.makedirs(str(study_folder / 'ground_truth'))
os.makedirs(str(study_folder / 'sortings'))
os.makedirs(str(study_folder / 'sortings/run_log' ))
# Export every entry of gt_dict, which maps rec_name -> (recording, sorting_gt).
for rec_name, (recording, sorting_gt) in gt_dict.items():
# write recording as binary format + json + prb
raw_filename = study_folder / 'raw_files' / (rec_name + '.dat')
prb_filename = study_folder / 'raw_files' / (rec_name + '.prb')
json_filename = study_folder / 'raw_files' / (rec_name + '.json')
num_chan = recording.get_num_channels()
# About 2**24 values per chunk across all channels (presumably frames per
# chunk). Divides by zero for a recording with 0 channels.
chunksize = 2 ** 24 // num_chan
sr = recording.get_sampling_frequency()
se.write_binary_dat_format(recording, raw_filename, time_axis=0, dtype='float32', chunksize=chunksize)
se.save_probe_file(recording, prb_filename, format='spyking_circus')
with open(json_filename, 'w', encoding='utf8') as f:
# Sidecar describing the .dat written just above: rate, channels, dtype, layout.
info = dict(sample_rate=sr, num_chan=num_chan, dtype='float32', frames_first=True)
json.dump(info, f, indent=4)
# write the ground-truth sorting (sorting_gt) in npz format
se.NpzSortingExtractor.write_sorting(sorting_gt, study_folder / 'ground_truth' / (rec_name + '.npz'))
# make an index of recording names (one per line, gt_dict iteration order)
# NOTE(review): indentation is lost here, so it is ambiguous whether this runs
# inside the loop. It writes every name in gt_dict, so running it once after
# the loop is enough.
with open(study_folder / 'names.txt', mode='w', encoding='utf8') as f:
for rec_name in gt_dict:
f.write(rec_name + '\n')
# source file
# NOTE(review): excerpt of a sorter setup step that fills a klusta config
# template. `recording`, `output_folder`, `p`, `source_dir` and `se` are
# defined outside these lines. The excerpt ends mid-call: the `.format(...)`
# argument list at the bottom is not closed.
# Reuse the recording's existing binary file when it is a frame-first BinDat
# file with zero offset; otherwise write an int16 copy into output_folder.
if isinstance(recording, se.BinDatRecordingExtractor) and recording._frame_first and\
recording._timeseries.offset==0:
# no need to copy
raw_filename = str(recording._datfile)
# Drop the '.dat' extension before raw_filename is formatted into the config
# template. Both branches leave raw_filename without a '.dat' suffix
# (presumably the template appends it; confirm).
# NOTE(review): str.replace removes *every* '.dat' substring, including ones
# inside directory names (e.g. 'run.data/rec.dat' -> 'runa/rec'). Strip only
# a trailing suffix instead.
raw_filename = raw_filename.replace('.dat', '')
dtype = recording._timeseries.dtype.str
# NOTE(review): nb_chan is not used in the visible lines; the template is
# filled with recording.get_num_channels() instead.
nb_chan = len(recording._channels)
else:
# save binary file (chunk by chunk) into a new file
raw_filename = output_folder / 'recording'
n_chan = recording.get_num_channels()
chunksize = 2**24// n_chan
dtype='int16'
se.write_binary_dat_format(recording, raw_filename, time_axis=0, dtype=dtype, chunksize=chunksize)
# Translate the numeric detect_sign parameter into the config keyword:
# negative -> 'negative', positive -> 'positive', zero -> 'both'.
if p['detect_sign'] < 0:
detect_sign = 'negative'
elif p['detect_sign'] > 0:
detect_sign = 'positive'
else:
detect_sign = 'both'
# set up klusta config file
with (source_dir / 'config_default.prm').open('r') as f:
klusta_config = f.readlines()
# Note: should use format with dict approach here
# Fill the template's positional {} slots in order: raw file path, probe file,
# sampling rate, channel count, quoted dtype, ... (rest not in this excerpt).
klusta_config = ''.join(klusta_config).format(raw_filename,
p['probe_file'], float(recording.get_sampling_frequency()),
recording.get_num_channels(), "'{}'".format(dtype),
def _setup_recording(self, recording, output_folder):
# Prepare Kilosort inputs: verify the installation, write the recording as an
# int16 binary into output_folder, read the three Kilosort text templates, and
# compute numeric settings (threshold, Nfilt, nsamples).
# NOTE(review): this definition is cut off in this excerpt. The templates and
# the values computed at the end are consumed by code not shown here.
source_dir = Path(__file__).parent
p = self.params
# Raises a plain Exception carrying the installation message when
# check_if_installed reports Kilosort / npy-matlab as missing.
if not check_if_installed(KilosortSorter.kilosort_path, KilosortSorter.npy_matlab_path):
raise Exception(KilosortSorter.installation_mesg)
# save binary file
file_name = 'recording'
# NOTE(review): the path passed here has no suffix, while dat_file below points
# to 'recording.dat'. Confirm write_binary_dat_format appends '.dat'.
se.write_binary_dat_format(recording, output_folder / file_name, dtype='int16')
# set up kilosort config files and run kilosort on data
# The templates are text files in the same directory as this module.
with (source_dir / 'kilosort_master.txt').open('r') as f:
kilosort_master = f.readlines()
with (source_dir / 'kilosort_config.txt').open('r') as f:
kilosort_config = f.readlines()
with (source_dir / 'kilosort_channelmap.txt').open('r') as f:
kilosort_channelmap = f.readlines()
nchan = recording.get_num_channels()
dat_file = (output_folder / (file_name + '.dat')).absolute()
kilo_thresh = p['detect_threshold']
# Nfilt = 8 x (channel count rounded down to a multiple of 32). Falls back to
# 8 x nchan when there are fewer than 32 channels (rounding gives 0).
Nfilt = (nchan // 32) * 32 * 8
if Nfilt == 0:
Nfilt = nchan * 8
# NOTE(review): magic constant. Its meaning is defined by the Kilosort config
# template, which is not visible here.
nsamples = 128 * 1024 + 64
# source file
# NOTE(review): excerpt of a sorter setup step built on `tdc.DataIO`.
# `self`, `recording`, `output_folder`, `probe_file`, `se` and `tdc` are
# defined outside these lines; `probe_file` in particular is never assigned
# in this excerpt.
# Read a frame-first BinDat recording in place (keeping its dtype and offset);
# otherwise write a float32 copy into output_folder.
if isinstance(recording, se.BinDatRecordingExtractor) and recording._frame_first:
# no need to copy
raw_filename = recording._datfile
dtype = recording._timeseries.dtype.str
# NOTE(review): if the code below runs after both branches (indentation is
# lost in this excerpt), this value is overwritten before it is used.
nb_chan = len(recording._channels)
offset = recording._timeseries.offset
else:
if self.debug:
print('Local copy of recording')
# save binary file (chunk by chunk) into a new file
raw_filename = output_folder / 'raw_signals.raw'
n_chan = recording.get_num_channels()
chunksize = 2**24// n_chan
se.write_binary_dat_format(recording, raw_filename, time_axis=0, dtype='float32', chunksize=chunksize)
dtype='float32'
# Offset 0 for the freshly written copy.
offset = 0
# initialize source and probe file
tdc_dataio = tdc.DataIO(dirname=str(output_folder))
nb_chan = recording.get_num_channels()
# Register the binary (path, dtype, rate, channel count, offset) as a raw source.
tdc_dataio.set_data_source(type='RawData', filenames=[str(raw_filename)],
dtype=dtype, sample_rate=recording.get_sampling_frequency(),
total_channel=nb_chan, offset=offset)
tdc_dataio.set_probe_file(str(probe_file))
if self.debug:
print(tdc_dataio)
def _create_example(self):
# Build a small synthetic fixture, made of:
# - three recordings (RX, RX2, RX3) backed by the same 4-channel integer
#   array X;
# - three NumpySortingExtractors (SX, SX2, SX3) with random spike trains,
#   unit properties and spike features;
# - an `example_info` dict describing them.
# NOTE(review): this definition is truncated in this excerpt. The `dict(` call
# at the end is never closed and nothing is returned here.
channel_ids = [0, 1, 2, 3]
num_channels = 4
num_frames = 10000
sampling_frequency = 30000
# No seed is set in these lines, so the draws below differ between runs.
X = np.random.normal(0, 1, (num_channels, num_frames))
geom = np.random.normal(0, 1, (num_channels, 2))
X = (X * 100).astype(int)
RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
RX2 = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
RX3 = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
SX = se.NumpySortingExtractor()
# Number of spikes for units 1, 2, 3 (counts, not times, despite the name).
spike_times = [200, 300, 400]
train1 = np.sort(np.rint(np.random.uniform(0, num_frames, spike_times[0])).astype(int))
SX.add_unit(unit_id=1, times=train1)
# NOTE(review): units 2 and 3 receive sorted *float* spike times (no rint /
# astype), unlike unit 1.
SX.add_unit(unit_id=2, times=np.sort(np.random.uniform(0, num_frames, spike_times[1])))
SX.add_unit(unit_id=3, times=np.sort(np.random.uniform(0, num_frames, spike_times[2])))
SX.set_unit_property(unit_id=1, property_name='stability', value=80)
SX.set_sampling_frequency(sampling_frequency)
SX2 = se.NumpySortingExtractor()
# Spike counts for SX2 units 3, 4, 5.
spike_times2 = [100, 150, 450]
train2 = np.rint(np.random.uniform(0, num_frames, spike_times2[0])).astype(int)
SX2.add_unit(unit_id=3, times=train2)
# NOTE(review): units 4 and 5 get unsorted float spike times.
SX2.add_unit(unit_id=4, times=np.random.uniform(0, num_frames, spike_times2[1]))
SX2.add_unit(unit_id=5, times=np.random.uniform(0, num_frames, spike_times2[2]))
SX2.set_unit_property(unit_id=4, property_name='stability', value=80)
# One 'widths' value (3) per spike of unit 3.
SX2.set_unit_spike_features(unit_id=3, feature_name='widths', value=np.asarray([3] * spike_times2[0]))
RX.set_channel_property(channel_id=0, property_name='location', value=(0, 0))
# NOTE(review): the eight lines below repeat the eight lines above verbatim.
# train2 is redrawn and units 3-5 are added to SX2 a second time. This looks
# like a copy/paste duplication. Whether add_unit overwrites or rejects an
# existing unit id is not visible here; confirm and remove one copy.
spike_times2 = [100, 150, 450]
train2 = np.rint(np.random.uniform(0, num_frames, spike_times2[0])).astype(int)
SX2.add_unit(unit_id=3, times=train2)
SX2.add_unit(unit_id=4, times=np.random.uniform(0, num_frames, spike_times2[1]))
SX2.add_unit(unit_id=5, times=np.random.uniform(0, num_frames, spike_times2[2]))
SX2.set_unit_property(unit_id=4, property_name='stability', value=80)
SX2.set_unit_spike_features(unit_id=3, feature_name='widths', value=np.asarray([3] * spike_times2[0]))
RX.set_channel_property(channel_id=0, property_name='location', value=(0, 0))
# Give every SX2 unit a 'shared_unit_prop' property equal to its position i,
# and a per-spike 'shared_unit_feature' of length spike_times2[i].
# NOTE(review): assumes get_unit_ids() returns units in insertion order
# (3, 4, 5), so that spike_times2[i] matches each unit's spike count.
for i, unit_id in enumerate(SX2.get_unit_ids()):
SX2.set_unit_property(unit_id=unit_id, property_name='shared_unit_prop', value=i)
SX2.set_unit_spike_features(unit_id=unit_id, feature_name='shared_unit_feature',
value=np.asarray([i] * spike_times2[i]))
# Give every RX channel a 'shared_channel_prop' equal to its position.
for i, channel_id in enumerate(RX.get_channel_ids()):
RX.set_channel_property(channel_id=channel_id, property_name='shared_channel_prop', value=i)
# SX3: a single unit with 8 fixed spike times and one 'dummy' feature value
# per spike.
SX3 = se.NumpySortingExtractor()
train3= np.asarray([1,20,21,35,38,45,46,47])
SX3.add_unit(unit_id=0, times=train3)
features3 = np.asarray([0,5,10,15,20,25,30,35])
SX3.set_unit_spike_features(unit_id=0, feature_name='dummy', value=features3)
# Summary of the fixture's parameters and the trains/features built above.
example_info = dict(
channel_ids=channel_ids,
num_channels=num_channels,
num_frames=num_frames,
sampling_frequency=sampling_frequency,
unit_ids=[1, 2, 3],
train1=train1,
train2=train2,
train3=train3,
features3=features3,
unit_prop=80,
def setUp(self):
    """Build a deterministic 4-channel recording and a 3-unit sorting.

    Every random draw uses a brand-new ``RandomState`` seeded with 0, so each
    draw restarts the same stream and the fixture is identical on every run.
    As a consequence, units 2 and 3 receive identical spike times, and unit 1
    is their rounded integer copy.
    """
    seed = 0
    n_channels, n_frames, n_spikes = 4, 10000, 200
    self._sampling_frequency = fs = 30000

    def fresh_rng():
        # New generator per draw, so every draw starts from the same seed.
        return np.random.RandomState(seed=seed)

    self._X = fresh_rng().normal(0, 1, (n_channels, n_frames))
    self._geom = fresh_rng().normal(0, 1, (n_channels, 2))
    self.RX = se.NumpyRecordingExtractor(
        timeseries=self._X, sampling_frequency=fs, geom=self._geom
    )

    self.SX = se.NumpySortingExtractor()
    self._train1 = np.rint(fresh_rng().uniform(0, n_frames, n_spikes)).astype(int)
    self.SX.add_unit(unit_id=1, times=self._train1)
    for unit_id in (2, 3):
        self.SX.add_unit(
            unit_id=unit_id, times=fresh_rng().uniform(0, n_frames, n_spikes)
        )
def _create_example(self):
channel_ids = [0, 1, 2, 3]
num_channels = 4
num_frames = 10000
sampling_frequency = 30000
X = np.random.normal(0, 1, (num_channels, num_frames))
geom = np.random.normal(0, 1, (num_channels, 2))
X = (X * 100).astype(int)
RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
RX2 = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
RX3 = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
SX = se.NumpySortingExtractor()
spike_times = [200, 300, 400]
train1 = np.sort(np.rint(np.random.uniform(0, num_frames, spike_times[0])).astype(int))
SX.add_unit(unit_id=1, times=train1)
SX.add_unit(unit_id=2, times=np.sort(np.random.uniform(0, num_frames, spike_times[1])))
SX.add_unit(unit_id=3, times=np.sort(np.random.uniform(0, num_frames, spike_times[2])))
SX.set_unit_property(unit_id=1, property_name='stability', value=80)
SX.set_sampling_frequency(sampling_frequency)
SX2 = se.NumpySortingExtractor()
spike_times2 = [100, 150, 450]
train2 = np.rint(np.random.uniform(0, num_frames, spike_times2[0])).astype(int)
SX2.add_unit(unit_id=3, times=train2)
SX2.add_unit(unit_id=4, times=np.random.uniform(0, num_frames, spike_times2[1]))
SX2.add_unit(unit_id=5, times=np.random.uniform(0, num_frames, spike_times2[2]))
SX2.set_unit_property(unit_id=4, property_name='stability', value=80)
SX2.set_unit_spike_features(unit_id=3, feature_name='widths', value=np.asarray([3] * spike_times2[0]))