# Spike sorting tutorial: simulate a toy recording, sort it with Klusta, curate, validate, compare,
# preprocess and post-process the results.
# Simulate a small toy dataset (4 channels, duration=30, fixed seed for reproducibility); the toy
# example also returns the sorting used to generate the data.
recording, sorting = se.example_datasets.toy_example(num_channels=4, duration=30, seed=0)
##############################################################################
# Next, spike sort the toy recording with Klusta and report the units it found.
sorting_KL = ss.run_klusta(recording)
klusta_unit_ids = sorting_KL.get_unit_ids()
print('Units:', klusta_unit_ids)
print('Number of units:', len(klusta_unit_ids))
##############################################################################
# Several functions are available to keep only the units that satisfy some rules. For example,
# let's automatically curate the sorting output so that only the units with SNR > 10 and mean firing rate > 2.3 Hz are
# kept. With ``threshold_sign='less'``, units whose metric falls below ``threshold`` are discarded.
sorting_fr = st.curation.threshold_firing_rate(sorting_KL, threshold=2.3, threshold_sign='less')
print('Units after FR threshold:', sorting_fr.get_unit_ids())
print('Number of units after FR threshold:', len(sorting_fr.get_unit_ids()))
# The SNR threshold is applied to the already firing-rate-curated sorting, so surviving units satisfy both rules.
sorting_snr = st.curation.threshold_snr(sorting_fr, recording, threshold=10, threshold_sign='less')
print('Units after SNR threshold:', sorting_snr.get_unit_ids())
print('Number of units after SNR threshold:', len(sorting_snr.get_unit_ids()))
##############################################################################
# Let's now check with the :code:`toolkit.validation` submodule that all remaining units have a firing rate
# above 2.3 Hz and an SNR above 10, i.e. the thresholds used for curation.
fr = st.validation.compute_firing_rates(sorting_snr)
snrs = st.validation.compute_snrs(sorting_snr, recording)
print('Firing rates:', fr)
# Show the SNRs as well: they are half of the check and ``snrs`` is reassigned further down.
print('SNRs:', snrs)
##############################################################################
# The :code:`toolkit.validation` submodule also provides quality metrics to assess the goodness of sorted
# units. Among those, for example, are signal-to-noise ratio, ISI violation ratio, isolation distance, and
# many more.
# Recording-based metrics use the same ``recording`` that was sorted, so SNR and isolation distance agree.
snrs = st.validation.compute_snrs(sorting_KL, recording)
isi_violations = st.validation.compute_isi_violations(sorting_KL)
isolations = st.validation.compute_isolation_distances(sorting_KL, recording)
print('SNR', snrs)
print('ISI violation ratios', isi_violations)
print('Isolation distances', isolations)
##############################################################################
# Quality metrics can be also used to automatically curate the spike sorting output. For example, you can select
# sorted units with a SNR above a certain threshold:
sorting_curated_snr = st.curation.threshold_snr(sorting_KL, recording, threshold=5, threshold_sign='less')
# Recompute SNRs on the same recording used for thresholding, so the printed values are comparable to 5.
snrs_above = st.validation.compute_snrs(sorting_curated_snr, recording)
print('Curated SNR', snrs_above)
##############################################################################
# The final part of this tutorial deals with comparing spike sorting outputs.
# We can either (1) compare the spike sorting results with the ground-truth sorting :code:`sorting` returned by
# the toy example, (2) compare the output of two sorters (Klusta and Mountainsort4), or (3) compare the output of
# multiple sorters:
# NOTE(review): ``sorting_MS4`` is not created in the visible part of this script; presumably it should come
# from running Mountainsort4 on ``recording`` -- confirm where it is defined.
comp_gt_KL = sc.compare_sorter_to_ground_truth(gt_sorting=sorting, tested_sorting=sorting_KL)
comp_KL_MS4 = sc.compare_two_sorters(sorting1=sorting_KL, sorting2=sorting_MS4)
# ``sorting_list`` and ``name_list`` are paired by position, so their orders must match.
comp_multi = sc.compare_multiple_sorters(sorting_list=[sorting_KL, sorting_MS4],
                                         name_list=['klusta', 'ms4'])
##############################################################################
# NOTE(review): ``recording_car``, ``recording_cmr`` and ``recording_single_groups`` are not created in the
# visible part of this script -- confirm where they are defined.
# Overlay row 0 of the traces of ``recording_car`` and ``recording_cmr`` in one figure.
fig1, ax1 = plt.subplots()
for referenced_recording in (recording_car, recording_cmr):
    ax1.plot(referenced_recording.get_traces()[0])
# In a second figure, plot rows 1 and then 0 of the traces of ``recording_single_groups``.
fig2, ax2 = plt.subplots()
for trace_row in (1, 0):
    ax2.plot(recording_single_groups.get_traces()[trace_row])
##############################################################################
# Remove bad channels
# ----------------------
#
# To remove noisy channels from the analysis, the
# :code:`remove_bad_channels` function can be used.
recording_remove_bad = st.preprocessing.remove_bad_channels(
    recording,
    bad_channel_ids=[0],  # drop channel 0 explicitly
)
print(recording_remove_bad.get_channel_ids())
##############################################################################
# As expected, channel 0 is removed. Bad channels removal can also be done
# automatically by passing ``bad_channel_ids=None``. In this case, the channels
# with a standard deviation exceeding :code:`bad_threshold` times the median
# standard deviation are removed. The standard deviations are computed on a
# stretch of traces :code:`seconds` long taken from the middle of the recording.
recording_remove_bad_auto = st.preprocessing.remove_bad_channels(
    recording,
    bad_channel_ids=None,
    bad_threshold=2,
    seconds=2,
)
print(recording_remove_bad_auto.get_channel_ids())
##############################################################################
# Compute each unit's max channel and, with ``save_as_property=True``, store it on ``sorting`` as a unit property.
max_chan = st.postprocessing.get_unit_max_channels(
    recording,
    sorting,
    save_as_property=True,
    verbose=True,
)
print(max_chan)
##############################################################################
# The saved property should now appear among the shared unit property names.
print(sorting.get_shared_unit_property_names())
##############################################################################
# Compute pca scores
# ---------------------
#
# For some applications, for example validating the spike sorting output,
# PCA scores can be computed.
pca_scores = st.postprocessing.compute_unit_pca_scores(recording, sorting, n_comp=3, verbose=True)
# Print the shape of each returned score array (presumably one per unit).
for pc in pca_scores:
    print(pc.shape)
# Scatter the first two score columns of the arrays at positions 0 (red) and 2 (blue).
fig, ax = plt.subplots()
ax.plot(pca_scores[0][:, 0], pca_scores[0][:, 1], 'r*')
ax.plot(pca_scores[2][:, 0], pca_scores[2][:, 1], 'b*')
##############################################################################
# PCA scores can be also computed electrode-wise. In the previous example,
# PCA was applied to the concatenation of the waveforms over channels.
pca_scores_by_electrode = st.postprocessing.compute_unit_pca_scores(recording, sorting, n_comp=3, by_electrode=True)
# Print the shape of each electrode-wise score array.
for pc in pca_scores_by_electrode:
    print(pc.shape)
# The :code:`postprocessing` module allows extracting all relevant information
# from the paired recording-sorting.
##############################################################################
# Compute spike waveforms
# --------------------------
#
# Waveforms are extracted with the :code:`get_unit_waveforms` function, which
# cuts snippets of the recording at the spike times. Once extracted, they can
# be stored in the :code:`SortingExtractor` object as features
# (:code:`save_as_features=True`). The ms before and after the spike event can
# be chosen. Waveforms are returned as a list of np.arrays (n\_spikes,
# n\_channels, n\_points)
wf = st.postprocessing.get_unit_waveforms(
    recording,
    sorting,
    ms_before=1,
    ms_after=2,
    save_as_features=True,
    verbose=True,
)
##############################################################################
# Now :code:`waveforms` is a unit spike feature!
spike_feature_names = sorting.get_shared_unit_spike_feature_names()
print(spike_feature_names)
# Shape of the waveform array of the first unit.
print(wf[0].shape)
##############################################################################
# Plot the waveforms of the units at list positions 0, 1 and 2 on channel 0,
# in black, red and blue respectively.
fig, ax = plt.subplots()
for unit_position, trace_color in zip(range(3), ('k', 'r', 'b')):
    ax.plot(wf[unit_position][:, 0, :].T, color=trace_color, lw=0.3)