import os
from os.path import join

import numpy as np
import spikeextractors as se
# Analyzer and bandpass_filter are provided by the surrounding spiketoolkit module


def exportToPhy(recording, sorting, output_folder, nPCchan=3, nPC=5, filter=False,
                electrode_dimensions=None, max_num_waveforms=np.inf):
    # validate inputs before doing any work with them
    if not isinstance(recording, se.RecordingExtractor) or not isinstance(sorting, se.SortingExtractor):
        raise AttributeError("'recording' and 'sorting' must be RecordingExtractor and SortingExtractor objects")
    analyzer = Analyzer(recording, sorting)
    output_folder = os.path.abspath(output_folder)
    if not os.path.isdir(output_folder):
        os.makedirs(output_folder)
    if filter:
        recording = bandpass_filter(recording, freq_min=300, freq_max=6000)
    # save dat file
    se.writeBinaryDatFormat(recording, join(output_folder, 'recording.dat'), dtype='int16')
    # write params.py
    with open(join(output_folder, 'params.py'), 'w') as f:
        f.write("dat_path = " + "'" + join(output_folder, 'recording.dat') + "'" + '\n')
        f.write('n_channels_dat = ' + str(recording.getNumChannels()) + '\n')
        f.write("dtype = 'int16'\n")
# fragment from MdaRecordingExtractor.write_recording: save_file_path, num_chan,
# num_frames, chunk_size, chunk_mb, params, parent_dir and geom are defined
# earlier in that method; json and np are imported at module level
dtype = 'int16'
with save_file_path.open('wb') as f:
    header = MdaHeader(dt0=dtype, dims0=(num_chan, num_frames))
    header.write(f)
    # write_to_binary_dat_format takes care of the chunking
    write_to_binary_dat_format(recording, file_handle=f, dtype=dtype, chunk_size=chunk_size,
                               chunk_mb=chunk_mb)

params["samplerate"] = recording.get_sampling_frequency()
with (parent_dir / params_fname).open('w') as f:
    json.dump(params, f)
np.savetxt(str(parent_dir / geom_fname), geom, delimiter=',')
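
# For orientation, a hedged sketch of what the code above leaves on disk
# (file names follow the variables in the fragment; contents illustrative):
#   <save_file_path>  - MDA header followed by the int16 traces, written in chunks
#   <params_fname>    - e.g. {"samplerate": 30000.0}
#   <geom_fname>      - one "x,y" row per channel, comma-delimited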
class MdaSortingExtractor(SortingExtractor):
    extractor_name = 'MdaSortingExtractor'
    installed = True  # check at class level if installed or not
    is_writable = True
    mode = 'file'
    installation_mesg = ""  # error message when not installed

    def __init__(self, file_path, sampling_frequency=None):
        SortingExtractor.__init__(self)
        self._firings_path = file_path
        self._firings = readmda(self._firings_path)  # readmda comes from the mdaio helper module
        self._max_channels = self._firings[0, :]  # row 0: channel with the peak for each event
        self._times = self._firings[1, :]         # row 1: spike times in frames
        self._labels = self._firings[2, :]        # row 2: unit labels
        self._unit_ids = np.unique(self._labels).astype(int)
        self._sampling_frequency = sampling_frequency
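
# A minimal, self-contained sketch of the firings layout assumed above
# (rows: max channel, spike time in frames, unit label); values illustrative:
import numpy as np
firings = np.array([[0,   1,   0],      # row 0: channel with the peak
                    [100, 250, 400],    # row 1: spike times (frames)
                    [1,   2,   1]])     # row 2: unit labels
unit_ids = np.unique(firings[2, :]).astype(int)  # -> array([1, 2])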
try:
    import h5py  # the matching h5py guard implied by HAVE_H5PY below
    HAVE_H5PY = True
except ImportError:
    HAVE_H5PY = False

try:
    from scipy.io.matlab import loadmat, savemat
    HAVE_LOADMAT = True
except ImportError:
    HAVE_LOADMAT = False

HAVE_MAT = HAVE_H5PY and HAVE_LOADMAT

from pathlib import Path
from typing import Union

from spikeextractors import SortingExtractor

PathType = Union[str, Path]


class MATSortingExtractor(SortingExtractor):
    extractor_name = "MATSortingExtractor"
    installed = HAVE_MAT  # check at class level if installed or not
    is_writable = False
    mode = "file"
    installation_mesg = "To use the MATSortingExtractor install h5py and scipy: \n\n pip install h5py scipy\n\n"  # error message when not installed

    def __init__(self, file_path: PathType):
        assert HAVE_MAT, self.installation_mesg
        super().__init__()

        file_path = Path(file_path) if isinstance(file_path, str) else file_path
        if not isinstance(file_path, Path):
            raise TypeError(f"Expected a str or Path file_path but got '{type(file_path).__name__}'")

        file_path = file_path.resolve()  # get absolute path to this file
        if not file_path.is_file():
            # assumed completion: the snippet is truncated at this check
            raise FileNotFoundError(f"No such file: '{file_path}'")
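
# Hypothetical subclass sketch: concrete .mat-based extractors reuse the
# base-class __init__ above for the dependency and path checks.
# class MySorterMATSortingExtractor(MATSortingExtractor):
#     def __init__(self, file_path):
#         super().__init__(file_path)
#         # ... then parse units and spike trains out of the .mat contents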
from spikeextractors import SortingExtractor
from pathlib import Path
import numpy as np


class NpzSortingExtractor(SortingExtractor):
    """
    Dead-simple, lightweight format based on the NumPy NPZ format:
    https://docs.scipy.org/doc/numpy/reference/generated/numpy.savez.html#numpy.savez

    An .npz file is in fact an archive of several .npy files.
    All spikes are stored in a two-column manner: indices + labels.
    """
    extractor_name = 'NpzSortingExtractor'
    exporter_name = 'NpzSortingExporter'
    exporter_gui_params = [
        {'name': 'save_path', 'type': 'file', 'title': "Save path (.npz)"},
    ]
    installed = True  # depends only on numpy
    installation_mesg = "Always installed"
        # fragment from the Phy recording extractor's __init__: attach
        # per-channel group and location metadata when the optional
        # .npy files exist in the Phy folder
        if (phy_folder / 'channel_groups.npy').is_file():
            channel_groups = np.load(phy_folder / 'channel_groups.npy')
            assert len(channel_groups) == self.get_num_channels()
            for (ch, cg) in zip(self.get_channel_ids(), channel_groups):
                self.set_channel_property(ch, 'group', cg)

        if (phy_folder / 'channel_positions.npy').is_file():
            channel_locations = np.load(phy_folder / 'channel_positions.npy')
            assert len(channel_locations) == self.get_num_channels()
            for (ch, loc) in zip(self.get_channel_ids(), channel_locations):
                self.set_channel_property(ch, 'location', loc)

        self._kwargs = {'folder_path': str(Path(folder_path).absolute())}
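
# Hedged read-back sketch for the properties set above (assuming `rec` is the
# constructed extractor and 'channel_groups.npy' was present in the folder):
# group = rec.get_channel_property(rec.get_channel_ids()[0], 'group')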
class PhySortingExtractor(SortingExtractor):
    extractor_name = 'PhySortingExtractor'
    exporter_name = 'PhySortingExporter'
    exporter_gui_params = [
        {'name': 'save_path', 'type': 'folder', 'title': "Save path"},
    ]
    installed = True  # check at class level if installed or not
    is_writable = True
    mode = 'folder'
    installation_mesg = ""  # error message when not installed

    def __init__(self, folder_path, exclude_cluster_groups=None, load_waveforms=False, verbose=False):
        SortingExtractor.__init__(self)
        phy_folder = Path(folder_path)

        spike_times = np.load(phy_folder / 'spike_times.npy')
        spike_templates = np.load(phy_folder / 'spike_templates.npy')

        if (phy_folder / 'spike_clusters.npy').is_file():
            spike_clusters = np.load(phy_folder / 'spike_clusters.npy')
        else:
            spike_clusters = spike_templates

        if (phy_folder / 'amplitudes.npy').is_file():
            amplitudes = np.load(phy_folder / 'amplitudes.npy')
        else:
            amplitudes = np.ones(len(spike_times))

        if (phy_folder / 'pc_features.npy').is_file():
            pc_features = np.load(phy_folder / 'pc_features.npy')  # assumed completion: snippet is cut at this branch
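
# Usage sketch (assumes 'phy_output' is a folder produced by a Phy export):
# sorting = PhySortingExtractor('phy_output', exclude_cluster_groups=['noise'])
# spike_train = sorting.get_unit_spike_train(sorting.get_unit_ids()[0])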
def _get_phy_data(recording, sorting, compute_pc_features, compute_amplitudes,
                  max_channels_per_template, **kwargs):
    if not isinstance(recording, se.RecordingExtractor) or not isinstance(sorting, se.SortingExtractor):
        raise AttributeError("'recording' and 'sorting' must be RecordingExtractor and SortingExtractor objects")
    if len(sorting.get_unit_ids()) == 0:
        raise Exception("No units in the sorting result, can't compute phy information.")

    # merge user kwargs with the library-wide defaults
    params_dict = update_all_param_dicts_with_kwargs(kwargs)
    n_comp = params_dict['n_comp']
    max_spikes_for_pca = params_dict['max_spikes_for_pca']
    recompute_info = params_dict['recompute_info']
    save_property_or_features = params_dict['save_property_or_features']
    verbose = params_dict['verbose']
    grouping_property = params_dict['grouping_property']
    ms_before = params_dict['ms_before']
    ms_after = params_dict['ms_after']
    dtype = params_dict['dtype']
    memmap = params_dict['memmap']
    n_jobs = params_dict['n_jobs']

    # threshold out units whose spike count is less than or equal to 0
    sorting.threshold_sorting(0, "less_or_equal")
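
# Hedged call sketch; keyword names mirror the params_dict lookups above,
# and the values are illustrative defaults, not the library's:
# _get_phy_data(recording, sorting, compute_pc_features=True,
#               compute_amplitudes=True, max_channels_per_template=16,
#               n_comp=3, ms_before=1.0, ms_after=2.0, memmap=True)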
        # fragment from MetricData.__init__: keep only the requested units
        # that actually exist in the sorting
        if unit_ids is None:
            unit_ids = sorting.get_unit_ids()
        else:
            unit_ids = set(unit_ids)
            unit_ids = list(unit_ids.intersection(sorting.get_unit_ids()))

        if len(unit_ids) == 0:
            raise ValueError("No units found.")

        spike_times, spike_clusters = get_spike_times_metrics_data(
            sorting, self._sampling_frequency
        )
        assert isinstance(
            sorting, SortingExtractor
        ), "'sorting' must be a SortingExtractor object"

        self._sorting = sorting
        self._set_unit_ids(unit_ids)
        self._set_epochs(epoch_tuples, epoch_names)
        self._spike_times = spike_times
        self._spike_clusters = spike_clusters
        self._total_units = len(unit_ids)
        self._unit_indices = _get_unit_indices(self._sorting, unit_ids)

        # to compute this data, all metric data must first be computed
        self._amplitudes = None
        self._pc_features = None
        self._pc_feature_ind = None
        self._spike_clusters_pca = None
        self._spike_clusters_amps = None
        self._spike_times_pca = None
        self._spike_times_amps = None
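
# The None fields above act as lazy caches: the amplitude and PCA arrays are
# filled in only when a metric that needs them is first computed. Hedged
# access sketch (method name is hypothetical, not confirmed by this snippet):
# metric_data.compute_amplitudes()   # populates self._amplitudes and friends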