Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# write recording as binary format + json + prb
raw_filename = study_folder / 'raw_files' / (rec_name + '.dat')
prb_filename = study_folder / 'raw_files' / (rec_name + '.prb')
json_filename = study_folder / 'raw_files' / (rec_name + '.json')
num_chan = recording.get_num_channels()
chunksize = 2 ** 24 // num_chan
sr = recording.get_sampling_frequency()
se.write_binary_dat_format(recording, raw_filename, time_axis=0, dtype='float32', chunksize=chunksize)
se.save_probe_file(recording, prb_filename, format='spyking_circus')
with open(json_filename, 'w', encoding='utf8') as f:
info = dict(sample_rate=sr, num_chan=num_chan, dtype='float32', frames_first=True)
json.dump(info, f, indent=4)
# write the ground-truth sorting (sorting_gt) in npz format
se.NpzSortingExtractor.write_sorting(sorting_gt, study_folder / 'ground_truth' / (rec_name + '.npz'))
# make an index of recording names
with open(study_folder / 'names.txt', mode='w', encoding='utf8') as f:
for rec_name in gt_dict:
f.write(rec_name + '\n')
def test_write_then_read(self):
    """Round-trip a toy sorting through the npz format and check it is preserved.

    The ground-truth sorting from the toy example is written with
    ``se.NpzSortingExtractor.write_sorting`` and then read back twice: once
    as a raw npz archive via ``np.load`` and once through
    ``se.NpzSortingExtractor``. Both must expose the same unit ids as the
    sorting that was written.
    """
    npz_path = 'test_NpzSortingExtractors.npz'

    # Only the sorting is needed here; the recording half of the pair is unused.
    _, sorting_gt = se.example_datasets.toy_example(num_channels=4, duration=10, seed=0)
    se.NpzSortingExtractor.write_sorting(sorting_gt, npz_path)
    expected_unit_ids = list(sorting_gt.get_unit_ids())

    # np.load on an .npz archive returns a lazy NpzFile that keeps the
    # underlying file open; the context manager guarantees it is closed even
    # when an assertion below fails.
    with np.load(npz_path) as npz:
        units_ids = list(npz['unit_ids'])
    self.assertEqual(units_ids, expected_unit_ids)

    sorting_npz = se.NpzSortingExtractor(npz_path)
    self.assertEqual(list(sorting_npz.get_unit_ids()), expected_unit_ids)
    # NOTE(review): 30000.0 is presumably the toy example's sampling
    # frequency - confirm against toy_example's defaults.
    self.assertEqual(sorting_npz.get_sampling_frequency(), 30000.0)
def test_write_then_read(self):
    """Write a toy sorting to npz, read it back, and compare with the source.

    Checks that the ``unit_ids`` array stored in the archive and the unit ids
    reported by a freshly constructed ``se.NpzSortingExtractor`` both match
    the ground-truth sorting that was written, and that the reloaded sorting
    reports a sampling frequency of 30000.0.
    """
    out_file = 'test_NpzSortingExtractors.npz'

    # The recording returned alongside the sorting is not used by this test.
    _, sorting_gt = se.example_datasets.toy_example(num_channels=4, duration=10, seed=0)
    se.NpzSortingExtractor.write_sorting(sorting_gt, out_file)
    reference_ids = list(sorting_gt.get_unit_ids())

    # An .npz archive is opened lazily by np.load and holds a file handle
    # until closed, so read what is needed inside a ``with`` block.
    with np.load(out_file) as archive:
        stored_ids = list(archive['unit_ids'])
    self.assertEqual(stored_ids, reference_ids)

    sorting_npz = se.NpzSortingExtractor(out_file)
    self.assertEqual(list(sorting_npz.get_unit_ids()), reference_ids)
    # NOTE(review): the expected 30000.0 presumably mirrors the toy
    # example's sampling frequency - verify against toy_example.
    self.assertEqual(sorting_npz.get_sampling_frequency(), 30000.0)
def iter_computed_sorting(study_folder):
    """
    Iter over sorting files.

    Scans ``<study_folder>/sortings`` for files named
    ``<rec_name>[#]<sorter_name>.npz`` and loads each one with
    ``se.NpzSortingExtractor``. Files that do not end in ``.npz`` or do not
    contain the ``[#]`` separator are ignored.

    Parameters
    ----------
    study_folder: str or Path
        The study folder.

    Yields
    ----------
    rec_name: str
        The part of the file name before ``[#]``.
    sorter_name: str
        The part of the file name between ``[#]`` and the ``.npz`` suffix.
    sorting: NpzSortingExtractor
        The sorting loaded from the file.
    """
    suffix = '.npz'
    separator = '[#]'
    sorting_folder = Path(study_folder) / 'sortings'
    for filename in os.listdir(sorting_folder):
        if not filename.endswith(suffix) or separator not in filename:
            continue
        # Strip only the trailing extension: str.replace would also delete
        # any '.npz' embedded earlier in the recording or sorter name.
        stem = filename[:-len(suffix)]
        rec_name, sorter_name = stem.split(separator)
        sorting = se.NpzSortingExtractor(sorting_folder / filename)
        yield rec_name, sorter_name, sorting
----------
study_folder: str
The study folder.
Returns
----------
ground_truths: dict
Dict of sorting_gt.
"""
study_folder = Path(study_folder)
rec_names = get_rec_names(study_folder)
ground_truths = {}
for rec_name in rec_names:
sorting = se.NpzSortingExtractor(study_folder / 'ground_truth' / (rec_name + '.npz'))
ground_truths[rec_name] = sorting
return ground_truths
def get_ground_truth(self, rec_name):
    """Load the ground-truth sorting saved for one recording.

    Reads ``<study_folder>/ground_truth/<rec_name>.npz`` and returns it
    wrapped in an ``se.NpzSortingExtractor``.
    """
    npz_name = rec_name + '.npz'
    gt_path = self.study_folder / 'ground_truth' / npz_name
    return se.NpzSortingExtractor(gt_path)