From 7966b6d8cdf7118660ca096097912dd1ed2eddcd Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 15 Apr 2026 17:13:21 +0200 Subject: [PATCH 1/3] Patch --- .../sortingcomponents/matching/nearest.py | 55 ++++++++++++------- 1 file changed, 35 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 3e0eb0b632..2c36cba0f1 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -1,7 +1,9 @@ """Sorting components: template matching.""" import numpy as np +from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core import get_noise_levels, get_channel_distances +from spikeinterface.core.sparsity import compute_sparsity from .base import BaseTemplateMatching, _base_matching_dtype @@ -23,8 +25,15 @@ class NearestTemplatesPeeler(BaseTemplateMatching): The threshold for peak detection in term of k x MAD noise_levels : None | array If None the noise levels are estimated using random chunks of the recording. If array it should be an array of size (num_channels,) with the noise level of each channel - radius_um : float - The radius to define the neighborhood between channels in micrometers while detecting the peaks + detection_radius_um : float, default 100.0 + The radius to define the neighborhood while detecting the peaks for locally exclusive detection. + neighborhood_radius_um : float, default 50.0 + The radius to use to select neighbour templates when assigning a detected peak to a template. + The neighborhood is defined around the extremum channel of the templates. + sparsity_radius_um : float, default 100.0 + The radius in um to use to compute the sparsity of the templates when the templates are not already sparse. + support_radius_um : float, default 50.0 + The radius in um to use to define the support of the templates when computing the distance between templates and waveforms. """ def __init__( @@ -39,6 +48,7 @@ def __init__( detection_radius_um=100.0, neighborhood_radius_um=50.0, sparsity_radius_um=100.0, + support_radius_um=50.0, ): BaseTemplateMatching.__init__(self, recording, templates, return_output=return_output) @@ -48,13 +58,12 @@ def __init__( self.peak_sign = peak_sign self.channel_distance = get_channel_distances(recording) self.neighbours_mask = self.channel_distance <= detection_radius_um + self.support_radius_um = support_radius_um num_templates = len(self.templates.unit_ids) num_channels = recording.get_num_channels() if neighborhood_radius_um is not None: - from spikeinterface.core.template_tools import get_template_extremum_channel - best_channels = get_template_extremum_channel(self.templates, peak_sign=self.peak_sign, outputs="index") best_channels = np.array([best_channels[i] for i in templates.unit_ids]) channel_locations = recording.get_channel_locations() @@ -65,20 +74,22 @@ def __init__( else: self.neighborhood_mask = np.ones((num_channels, num_templates), dtype=bool) - if sparsity_radius_um is not None: + if support_radius_um is not None: if not templates.are_templates_sparse(): - from spikeinterface.core.sparsity import compute_sparsity - - sparsity = compute_sparsity( - templates, method="radius", radius_um=sparsity_radius_um, peak_sign=self.peak_sign - ) + if sparsity_radius_um is not None: + sparsity = compute_sparsity( + templates, method="radius", radius_um=sparsity_radius_um, peak_sign=self.peak_sign + ) + else: + raise ValueError("sparsity_radius_um should be provided if templates are not sparse") else: sparsity = templates.sparsity - self.sparsity_mask = np.zeros((num_channels, num_channels), dtype=bool) - for channel_index in np.arange(num_channels): - mask = self.neighborhood_mask[channel_index] - self.sparsity_mask[channel_index] = np.sum(sparsity.mask[mask], axis=0) > 0 + channel_locations = recording.get_channel_locations() + channel_distances = np.linalg.norm( + channel_locations[:, None] - channel_locations[np.newaxis, :], axis=2 + ) + self.sparsity_support_mask = (channel_distances <= self.support_radius_um) else: self.sparsity_mask = np.ones((num_channels, num_channels), dtype=bool) @@ -86,13 +97,14 @@ def __init__( self.exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) self.nbefore = self.templates.nbefore self.nafter = self.templates.nafter - self.margin = max(self.nbefore, self.nafter) + self.width = self.nbefore + self.nafter + self.margin = self.width + 1 + self.exclude_sweep_size self.lookup_tables = {} self.lookup_tables["templates"] = {} self.lookup_tables["channels"] = {} for i in range(num_channels): self.lookup_tables["templates"][i] = np.flatnonzero(self.neighborhood_mask[i]) - self.lookup_tables["channels"][i] = np.flatnonzero(self.sparsity_mask[i]) + self.lookup_tables["channels"][i] = np.flatnonzero(self.sparsity_support_mask[i]) def get_trace_margin(self): return self.margin @@ -104,13 +116,14 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): from scipy.spatial.distance import cdist if self.margin > 0: - peak_traces = traces[self.margin : -self.margin, :] + peak_traces = traces[self.width : -self.width, :] else: peak_traces = traces + peak_sample_ind, peak_chan_ind = detect_peaks_numba_locally_exclusive_on_chunk( peak_traces, self.peak_sign, self.abs_threholds, self.exclude_sweep_size, self.neighbours_mask ) - peak_sample_ind += self.margin + peak_sample_ind += self.width spikes = np.empty(peak_sample_ind.size, dtype=_base_matching_dtype) spikes["sample_index"] = peak_sample_ind @@ -164,6 +177,7 @@ def __init__( detection_radius_um=100.0, neighborhood_radius_um=50.0, sparsity_radius_um=100.0, + support_radius_um=50.0, ): NearestTemplatesPeeler.__init__( @@ -178,6 +192,7 @@ def __init__( detection_radius_um=detection_radius_um, neighborhood_radius_um=neighborhood_radius_um, sparsity_radius_um=sparsity_radius_um, + support_radius_um=support_radius_um ) from spikeinterface.sortingcomponents.waveforms.waveform_utils import ( @@ -206,13 +221,13 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): ) if self.margin > 0: - peak_traces = traces[self.margin : -self.margin, :] + peak_traces = traces[self.width : -self.wdith, :] else: peak_traces = traces peak_sample_ind, peak_chan_ind = detect_peaks_numba_locally_exclusive_on_chunk( peak_traces, self.peak_sign, self.abs_threholds, self.exclude_sweep_size, self.neighbours_mask ) - peak_sample_ind += self.margin + peak_sample_ind += self.wdith spikes = np.empty(peak_sample_ind.size, dtype=_base_matching_dtype) spikes["sample_index"] = peak_sample_ind From fe31bfc27a1db2eedd245a4a367fe8dc1a8ec82b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Apr 2026 15:16:49 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/matching/nearest.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 2c36cba0f1..5d5af00470 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -28,7 +28,7 @@ class NearestTemplatesPeeler(BaseTemplateMatching): detection_radius_um : float, default 100.0 The radius to define the neighborhood while detecting the peaks for locally exclusive detection. neighborhood_radius_um : float, default 50.0 - The radius to use to select neighbour templates when assigning a detected peak to a template. + The radius to use to select neighbour templates when assigning a detected peak to a template. The neighborhood is defined around the extremum channel of the templates. sparsity_radius_um : float, default 100.0 The radius in um to use to compute the sparsity of the templates when the templates are not already sparse. @@ -79,17 +79,15 @@ def __init__( if sparsity_radius_um is not None: sparsity = compute_sparsity( templates, method="radius", radius_um=sparsity_radius_um, peak_sign=self.peak_sign - ) + ) else: raise ValueError("sparsity_radius_um should be provided if templates are not sparse") else: sparsity = templates.sparsity channel_locations = recording.get_channel_locations() - channel_distances = np.linalg.norm( - channel_locations[:, None] - channel_locations[np.newaxis, :], axis=2 - ) - self.sparsity_support_mask = (channel_distances <= self.support_radius_um) + channel_distances = np.linalg.norm(channel_locations[:, None] - channel_locations[np.newaxis, :], axis=2) + self.sparsity_support_mask = channel_distances <= self.support_radius_um else: self.sparsity_mask = np.ones((num_channels, num_channels), dtype=bool) @@ -119,7 +117,7 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): peak_traces = traces[self.width : -self.width, :] else: peak_traces = traces - + peak_sample_ind, peak_chan_ind = detect_peaks_numba_locally_exclusive_on_chunk( peak_traces, self.peak_sign, self.abs_threholds, self.exclude_sweep_size, self.neighbours_mask ) @@ -192,7 +190,7 @@ def __init__( detection_radius_um=detection_radius_um, neighborhood_radius_um=neighborhood_radius_um, sparsity_radius_um=sparsity_radius_um, - support_radius_um=support_radius_um + support_radius_um=support_radius_um, ) from spikeinterface.sortingcomponents.waveforms.waveform_utils import ( From 2b4cea5152aaeb7ed5a8a37ac9e0a6f24b0b11ba Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 16 Apr 2026 09:37:43 +0200 Subject: [PATCH 3/3] Typo --- src/spikeinterface/sortingcomponents/matching/nearest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 5d5af00470..dd0e475982 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -219,13 +219,13 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): ) if self.margin > 0: - peak_traces = traces[self.width : -self.wdith, :] + peak_traces = traces[self.width : -self.width, :] else: peak_traces = traces peak_sample_ind, peak_chan_ind = detect_peaks_numba_locally_exclusive_on_chunk( peak_traces, self.peak_sign, self.abs_threholds, self.exclude_sweep_size, self.neighbours_mask ) - peak_sample_ind += self.wdith + peak_sample_ind += self.width spikes = np.empty(peak_sample_ind.size, dtype=_base_matching_dtype) spikes["sample_index"] = peak_sample_ind