diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 3e0eb0b632..dd0e475982 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -1,7 +1,9 @@ """Sorting components: template matching.""" import numpy as np +from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core import get_noise_levels, get_channel_distances +from spikeinterface.core.sparsity import compute_sparsity from .base import BaseTemplateMatching, _base_matching_dtype @@ -23,8 +25,15 @@ class NearestTemplatesPeeler(BaseTemplateMatching): The threshold for peak detection in term of k x MAD noise_levels : None | array If None the noise levels are estimated using random chunks of the recording. If array it should be an array of size (num_channels,) with the noise level of each channel - radius_um : float - The radius to define the neighborhood between channels in micrometers while detecting the peaks + detection_radius_um : float, default 100.0 + The radius to define the neighborhood while detecting the peaks for locally exclusive detection. + neighborhood_radius_um : float, default 50.0 + The radius to use to select neighbour templates when assigning a detected peak to a template. + The neighborhood is defined around the extremum channel of the templates. + sparsity_radius_um : float, default 100.0 + The radius in um to use to compute the sparsity of the templates when the templates are not already sparse. + support_radius_um : float, default 50.0 + The radius in um to use to define the support of the templates when computing the distance between templates and waveforms. """ def __init__( @@ -39,6 +48,7 @@ def __init__( detection_radius_um=100.0, neighborhood_radius_um=50.0, sparsity_radius_um=100.0, + support_radius_um=50.0, ): BaseTemplateMatching.__init__(self, recording, templates, return_output=return_output) @@ -48,13 +58,12 @@ def __init__( self.peak_sign = peak_sign self.channel_distance = get_channel_distances(recording) self.neighbours_mask = self.channel_distance <= detection_radius_um + self.support_radius_um = support_radius_um num_templates = len(self.templates.unit_ids) num_channels = recording.get_num_channels() if neighborhood_radius_um is not None: - from spikeinterface.core.template_tools import get_template_extremum_channel - best_channels = get_template_extremum_channel(self.templates, peak_sign=self.peak_sign, outputs="index") best_channels = np.array([best_channels[i] for i in templates.unit_ids]) channel_locations = recording.get_channel_locations() @@ -65,20 +74,20 @@ def __init__( else: self.neighborhood_mask = np.ones((num_channels, num_templates), dtype=bool) - if sparsity_radius_um is not None: + if support_radius_um is not None: if not templates.are_templates_sparse(): - from spikeinterface.core.sparsity import compute_sparsity - - sparsity = compute_sparsity( - templates, method="radius", radius_um=sparsity_radius_um, peak_sign=self.peak_sign - ) + if sparsity_radius_um is not None: + sparsity = compute_sparsity( + templates, method="radius", radius_um=sparsity_radius_um, peak_sign=self.peak_sign + ) + else: + raise ValueError("sparsity_radius_um should be provided if templates are not sparse") else: sparsity = templates.sparsity - self.sparsity_mask = np.zeros((num_channels, num_channels), dtype=bool) - for channel_index in np.arange(num_channels): - mask = self.neighborhood_mask[channel_index] - self.sparsity_mask[channel_index] = np.sum(sparsity.mask[mask], axis=0) > 0 + channel_locations = recording.get_channel_locations() + channel_distances = np.linalg.norm(channel_locations[:, None] - channel_locations[np.newaxis, :], axis=2) + self.sparsity_support_mask = channel_distances <= self.support_radius_um else: self.sparsity_mask = np.ones((num_channels, num_channels), dtype=bool) @@ -86,13 +95,14 @@ def __init__( self.exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) self.nbefore = self.templates.nbefore self.nafter = self.templates.nafter - self.margin = max(self.nbefore, self.nafter) + self.width = self.nbefore + self.nafter + self.margin = self.width + 1 + self.exclude_sweep_size self.lookup_tables = {} self.lookup_tables["templates"] = {} self.lookup_tables["channels"] = {} for i in range(num_channels): self.lookup_tables["templates"][i] = np.flatnonzero(self.neighborhood_mask[i]) - self.lookup_tables["channels"][i] = np.flatnonzero(self.sparsity_mask[i]) + self.lookup_tables["channels"][i] = np.flatnonzero(self.sparsity_support_mask[i]) def get_trace_margin(self): return self.margin @@ -104,13 +114,14 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): from scipy.spatial.distance import cdist if self.margin > 0: - peak_traces = traces[self.margin : -self.margin, :] + peak_traces = traces[self.width : -self.width, :] else: peak_traces = traces + peak_sample_ind, peak_chan_ind = detect_peaks_numba_locally_exclusive_on_chunk( peak_traces, self.peak_sign, self.abs_threholds, self.exclude_sweep_size, self.neighbours_mask ) - peak_sample_ind += self.margin + peak_sample_ind += self.width spikes = np.empty(peak_sample_ind.size, dtype=_base_matching_dtype) spikes["sample_index"] = peak_sample_ind @@ -164,6 +175,7 @@ def __init__( detection_radius_um=100.0, neighborhood_radius_um=50.0, sparsity_radius_um=100.0, + support_radius_um=50.0, ): NearestTemplatesPeeler.__init__( @@ -178,6 +190,7 @@ def __init__( detection_radius_um=detection_radius_um, neighborhood_radius_um=neighborhood_radius_um, sparsity_radius_um=sparsity_radius_um, + support_radius_um=support_radius_um, ) from spikeinterface.sortingcomponents.waveforms.waveform_utils import ( @@ -206,13 +219,13 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): ) if self.margin > 0: - peak_traces = traces[self.margin : -self.margin, :] + peak_traces = traces[self.width : -self.width, :] else: peak_traces = traces peak_sample_ind, peak_chan_ind = detect_peaks_numba_locally_exclusive_on_chunk( peak_traces, self.peak_sign, self.abs_threholds, self.exclude_sweep_size, self.neighbours_mask ) - peak_sample_ind += self.margin + peak_sample_ind += self.width spikes = np.empty(peak_sample_ind.size, dtype=_base_matching_dtype) spikes["sample_index"] = peak_sample_ind