From 3dc57290dbde0aeaa5048f2301ee75015a93fe26 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 29 Dec 2025 15:43:44 +0100 Subject: [PATCH 1/4] Test IBL extractors tests failing for PI update --- src/spikeinterface/extractors/tests/test_iblextractors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index 972a8e7bb0..56d01e38cf 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -76,8 +76,8 @@ def test_offsets(self): def test_probe_representation(self): probe = self.recording.get_probe() - expected_probe_representation = "Probe - 384ch - 1shanks" - assert repr(probe) == expected_probe_representation + expected_probe_representation = "Probe - 384ch" + assert expected_probe_representation in repr(probe) def test_property_keys(self): expected_property_keys = [ From 61c317aba92608d9f096a3a374bc3d43e27faaba Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 6 Mar 2026 10:09:46 -0800 Subject: [PATCH 2/4] Fix OpenEphys tests --- .../extractors/neoextractors/openephys.py | 20 ++++++++++++------- .../extractors/tests/test_neoextractors.py | 3 +++ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index 1c39a1b97c..1d16df534b 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -351,13 +351,19 @@ def __init__( # Ensure device channel index corresponds to channel_ids probe_channel_names = probe.contact_annotations.get("channel_name", None) if probe_channel_names is not None and not np.array_equal(probe_channel_names, self.channel_ids): - device_channel_indices = [] - probe_channel_names = list(probe_channel_names) - device_channel_indices = np.zeros(len(self.channel_ids), dtype=int) - for i, ch in enumerate(self.channel_ids): - index_in_probe = probe_channel_names.index(ch) - device_channel_indices[index_in_probe] = i - probe.set_device_channel_indices(device_channel_indices) + if set(probe_channel_names) == set(self.channel_ids): + device_channel_indices = [] + probe_channel_names = list(probe_channel_names) + device_channel_indices = np.zeros(len(self.channel_ids), dtype=int) + for i, ch in enumerate(self.channel_ids): + index_in_probe = probe_channel_names.index(ch) + device_channel_indices[index_in_probe] = i + probe.set_device_channel_indices(device_channel_indices) + else: + warnings.warn( + "Channel names in the probe do not match the channel ids from Neo. " + "Cannot set device channel indices, but this might lead to incorrect probe geometries" + ) if probe.shank_ids is not None: self.set_probe(probe, in_place=True, group_mode="by_shank") diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index f80f62ebf0..f40b4d05ab 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -121,6 +121,9 @@ class OpenEphysBinaryRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ("openephysbinary/v0.5.x_two_nodes", {"stream_id": "0"}), ("openephysbinary/v0.5.x_two_nodes", {"stream_id": "1"}), ("openephysbinary/v0.6.x_neuropixels_multiexp_multistream", {"stream_id": "0", "block_index": 0}), + # TODO: block_indices 1/2 of v0.6.x_neuropixels_multiexp_multistream have a mismatch in the channel names between + # the settings files (starting with CH0) and structure.oebin (starting at CH1). + # Currently, the extractor will skip remapping to match order in oebin and settings file, raising a warning ("openephysbinary/v0.6.x_neuropixels_multiexp_multistream", {"stream_id": "1", "block_index": 1}), ( "openephysbinary/v0.6.x_neuropixels_multiexp_multistream", From 2c7bcebf5dbff9fe9f78726a66789c0dc6b864e1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 17 Apr 2026 11:20:24 +0200 Subject: [PATCH 3/4] wip: add censor_period_ms to slay --- src/spikeinterface/curation/auto_merge.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index b66b553fd7..6a2905d468 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -88,7 +88,7 @@ "censored_period_ms": 0.3, }, "quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, - "slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5}, + "slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5, "censored_period_ms": 0.2}, } @@ -1552,6 +1552,7 @@ def compute_slay_matrix( sorting_analyzer: SortingAnalyzer, k1: float, k2: float, + censor_period_ms: float, templates_diff: np.ndarray | None, pair_mask: np.ndarray | None = None, ): @@ -1569,6 +1570,9 @@ def compute_slay_matrix( Coefficient determining the importance of the cross-correlation significance k2 : float Coefficient determining the importance of the sliding rp violation + censor_period_ms : float + The censored period to exclude from the refractory period computation to discard + duplicated spikes. templates_diff : np.ndarray | None Pre-computed template similarity difference matrix. If None, it will be retrieved from the sorting_analyzer. pair_mask : None | np.ndarray, default: None @@ -1592,14 +1596,14 @@ def compute_slay_matrix( sigma_ij = 1 - templates_diff else: sigma_ij = sorting_analyzer.get_extension("template_similarity").get_data() - rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, pair_mask) + rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, pair_mask, censor_period_ms) M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij return M_ij -def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarray): +def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarray, censor_period_ms: float): """ Computes a cross-correlation significance measure and a sliding refractory period violation measure for all units in the `sorting_analyzer`. @@ -1610,6 +1614,9 @@ def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarra The sorting analyzer object containing the spike sorting data pair_mask : np.ndarray A bool matrix describing which pairs are possible merges based on previous steps + censor_period_ms : float + The censored period to exclude from the refractory period computation to discard + duplicated spikes. """ correlograms_extension = sorting_analyzer.get_extension("correlograms") @@ -1628,7 +1635,12 @@ def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarra if not pair_mask[unit_index_1, unit_index_2]: continue + # TODO: test this xgram = ccgs[unit_index_1, unit_index_2, :] + if censor_period_ms > 0: + center_bin = len(xgram) // 2 + censor_bins = int(round(censor_period_ms / bin_size_ms)) + xgram[center_bin - censor_bins : center_bin + censor_bins + 1] = 0 rho_ij[unit_index_1, unit_index_2] = _compute_xcorr_pair( xgram, bin_size_s=bin_size_ms / 1000, min_xcorr_rate=0 From e0ea8595fbe75b148ba36446eca3be1d5209ca39 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 17 Apr 2026 16:15:05 +0200 Subject: [PATCH 4/4] subtract coincident spikes and use merged ACG for slay RP violation --- src/spikeinterface/curation/auto_merge.py | 93 ++++++++++++++++--- src/spikeinterface/curation/tests/common.py | 15 ++- .../curation/tests/test_auto_merge.py | 56 ++++++++++- .../postprocessing/correlograms.py | 2 +- 4 files changed, 144 insertions(+), 22 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 6a2905d468..66ebf439c5 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -309,7 +309,6 @@ def compute_merge_unit_groups( win_sizes, pair_mask=pair_mask, ) - # print(correlogram_diff) pair_mask = pair_mask & (correlogram_diff < params["corr_diff_thresh"]) outs["correlograms"] = correlograms outs["bins"] = bins @@ -373,12 +372,20 @@ def compute_merge_unit_groups( outs["pairs_decreased_score"] = pairs_decreased_score elif step == "slay_score": - - M_ij = compute_slay_matrix( - sorting_analyzer, params["k1"], params["k2"], templates_diff=outs["templates_diff"], pair_mask=pair_mask + M_ij, sigma_ij, rho_ij, eta_ij = compute_slay_matrix( + sorting_analyzer, + params["k1"], + params["k2"], + params["censored_period_ms"], + templates_diff=outs["templates_diff"], + pair_mask=pair_mask, ) pair_mask = pair_mask & (M_ij > params["slay_threshold"]) + outs["slay_M_ij"] = M_ij + outs["slay_sigma_ij"] = sigma_ij + outs["slay_rho_ij"] = rho_ij + outs["slay_eta_ij"] = eta_ij # FINAL STEP : create the final list from pair_mask boolean matrix ind1, ind2 = np.nonzero(pair_mask) @@ -1600,7 +1607,28 @@ def compute_slay_matrix( M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij - return M_ij + return M_ij, sigma_ij, rho_ij, eta_ij + + +def _count_coincident_spikes(t1, t2, max_samples): + """ + Count spikes in t1 that have a matching spike in t2 within max_samples, + split by lag direction. + + Returns (n_nonneg, n_neg) where n_nonneg counts pairs where t2 >= t1 + (non-negative lag, mapped to the right center CCG bin) and n_neg counts + pairs where t2 < t1 (negative lag, mapped to the left center CCG bin). + """ + if len(t1) == 0 or len(t2) == 0: + return 0, 0 + indices = np.searchsorted(t2, t1, side="left") + right_valid = indices < len(t2) + right_diffs = np.where(right_valid, t2[np.minimum(indices, len(t2) - 1)] - t1, max_samples + 1) + left_valid = indices > 0 + left_diffs = np.where(left_valid, t1 - t2[np.maximum(indices - 1, 0)], max_samples + 1) + n_nonneg = int(np.sum(right_diffs <= max_samples)) + n_neg = int(np.sum((left_diffs <= max_samples) & (left_diffs > 0))) + return n_nonneg, n_neg def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarray, censor_period_ms: float): @@ -1617,14 +1645,33 @@ def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarra censor_period_ms : float The censored period to exclude from the refractory period computation to discard duplicated spikes. + + Returns + ------- + rho_ij : np.ndarray + The cross-correlation significance measure for each pair of units. + eta_ij : np.ndarray + The sliding refractory period violation measure for each pair of units. """ correlograms_extension = sorting_analyzer.get_extension("correlograms") - ccgs, _ = correlograms_extension.get_data() + ccgs, bin_edges = correlograms_extension.get_data() # convert to seconds for SLAy functions bin_size_ms = correlograms_extension.params["bin_ms"] + # pre-fetch spike trains for duplicate counting (sub-bin resolution) + if censor_period_ms > 0: + sorting = sorting_analyzer.sorting + censor_period_samples = int(censor_period_ms / 1000 * sorting_analyzer.sampling_frequency) + n_segments = sorting_analyzer.get_num_segments() + spike_trains = [ + [sorting.get_unit_spike_train(unit_id=uid, segment_index=seg) for seg in range(n_segments)] + for uid in sorting_analyzer.unit_ids + ] + # lag=0 spike pairs land in the bin starting at 0: xgram[num_half_bins] + center_bin = ccgs.shape[2] // 2 + rho_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)]) eta_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)]) @@ -1635,17 +1682,41 @@ def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarra if not pair_mask[unit_index_1, unit_index_2]: continue - # TODO: test this xgram = ccgs[unit_index_1, unit_index_2, :] + + # Merged ACG approximation: sum of individual ACGs and both CCG directions. + # _sliding_RP_viol_pair expects the ACG of the merged unit; the merged ACG + # has a large center bin when duplicates are present, which the LP filter + # attenuates so bin_rate_max reflects the flank rate — making RP violations + # detectable (unlike using the CCG alone where bin_rate_max is dominated by + # the duplicate peak and masks violations). + merged_acg = ( + ccgs[unit_index_1, unit_index_1, :] + + ccgs[unit_index_2, unit_index_2, :] + + ccgs[unit_index_1, unit_index_2, :] + + ccgs[unit_index_2, unit_index_1, :] + ) + if censor_period_ms > 0: - center_bin = len(xgram) // 2 - censor_bins = int(round(censor_period_ms / bin_size_ms)) - xgram[center_bin - censor_bins : center_bin + censor_bins + 1] = 0 + # count number of spikes within the censor period from the two units + n_right, n_left = 0, 0 + for seg in range(n_segments): + r, l = _count_coincident_spikes( + spike_trains[unit_index_1][seg], spike_trains[unit_index_2][seg], censor_period_samples + ) + n_right += r + n_left += l + # subtract number of duplicates from central bin(s) of the merged ACG: + # n_right pairs land in center_bin (lag ≥ 0), n_left in center_bin-1 (lag < 0); + # each direction is counted in both ccgs[i,j] and ccgs[j,i], hence the factor 2 + merged_acg = merged_acg.copy() + merged_acg[center_bin] = max(0, merged_acg[center_bin] - 2 * n_right) + merged_acg[center_bin - 1] = max(0, merged_acg[center_bin - 1] - 2 * n_left) rho_ij[unit_index_1, unit_index_2] = _compute_xcorr_pair( xgram, bin_size_s=bin_size_ms / 1000, min_xcorr_rate=0 ) - eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(xgram, bin_size_ms=bin_size_ms) + eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(merged_acg, bin_size_ms=bin_size_ms) return rho_ij, eta_ij diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py index a665b074a6..4b1c1095fb 100644 --- a/src/spikeinterface/curation/tests/common.py +++ b/src/spikeinterface/curation/tests/common.py @@ -35,12 +35,9 @@ def make_sorting_analyzer(sparse=True, num_units=5, durations=[300.0]): sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) sorting_analyzer = create_sorting_analyzer( - sorting=sorting, - recording=recording, - format="memory", - sparse=sparse, + sorting=sorting, recording=recording, format="memory", sparse=sparse, n_jobs=-1 ) - sorting_analyzer.compute(extensions) + sorting_analyzer.compute(extensions, n_jobs=-1) return sorting_analyzer @@ -58,9 +55,9 @@ def make_sorting_analyzer_with_splits(sorting_analyzer, num_unit_splitted=1, num ) sorting_analyzer_with_splits = create_sorting_analyzer( - sorting=sorting_with_split, recording=sorting_analyzer.recording, format="memory", sparse=True + sorting=sorting_with_split, recording=sorting_analyzer.recording, format="memory", sparse=True, n_jobs=-1 ) - sorting_analyzer_with_splits.compute(extensions) + sorting_analyzer_with_splits.compute(extensions, n_jobs=-1) return sorting_analyzer_with_splits, num_unit_splitted, other_ids @@ -78,8 +75,8 @@ def sorting_analyzer_for_unitrefine_curation(): recording, sorting_1 = generate_ground_truth_recording(num_channels=4, seed=1, num_units=6) _, sorting_2 = generate_ground_truth_recording(num_channels=4, seed=2, num_units=6) both_sortings = aggregate_units([sorting_1, sorting_2]) - analyzer = create_sorting_analyzer(sorting=both_sortings, recording=recording) - analyzer.compute(["random_spikes", "noise_levels", "templates"]) + analyzer = create_sorting_analyzer(sorting=both_sortings, recording=recording, n_jobs=-1) + analyzer.compute(["random_spikes", "noise_levels", "templates"], n_jobs=-1) return analyzer diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index cab508f4fb..eeafdc9a07 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -1,12 +1,14 @@ import pytest +import numpy as np -from spikeinterface.core import create_sorting_analyzer +from spikeinterface.core import create_sorting_analyzer, NumpySorting from spikeinterface.curation import compute_merge_unit_groups, auto_merge_units from spikeinterface.generation import split_sorting_by_times from spikeinterface.curation.tests.common import ( + extensions, make_sorting_analyzer, sorting_analyzer_for_curation, sorting_analyzer_with_splits, @@ -68,6 +70,58 @@ def test_compute_merge_unit_groups_multi_segment(sorting_analyzer_multi_segment_ ) +def test_slay_discard_duplicated_spikes(sorting_analyzer_with_splits): + sorting_analyzer, num_unit_splitted, split_ids = sorting_analyzer_with_splits + + # now for the split units, we add some duplicated spikes + percent_duplicated = 0.7 + split_units = [] + for split in split_ids: + split_units.extend(split_ids[split]) + + # add unsplit spiketrains untouched + new_spiketrains = {} + for unit_id in sorting_analyzer.unit_ids: + if unit_id in split_ids: + continue + new_spiketrains[unit_id] = sorting_analyzer.sorting.get_unit_spike_train(unit_id=unit_id) + # ad duplicated spikes for split units + for unit_id in split_ids: + split_units = split_ids[unit_id] + spiketrains0 = sorting_analyzer.sorting.get_unit_spike_train(unit_id=split_units[0]) + spiketrains1 = sorting_analyzer.sorting.get_unit_spike_train(unit_id=split_units[1]) + num_duplicated = int(percent_duplicated * min(len(spiketrains0), len(spiketrains1))) + duplicated_spikes0 = np.random.choice(spiketrains0, size=num_duplicated, replace=False) + new_spiketrain1 = np.sort(np.concatenate([spiketrains1, duplicated_spikes0])) + + new_spiketrains[split_units[0]] = spiketrains0 + new_spiketrains[split_units[1]] = new_spiketrain1 + + sorting_duplicated = NumpySorting.from_unit_dict( + new_spiketrains, sampling_frequency=sorting_analyzer.sampling_frequency + ) + + sorting_analyzer_duplicated = create_sorting_analyzer( + sorting_duplicated, sorting_analyzer.recording, format="memory" + ) + sorting_analyzer_duplicated.compute(extensions) + + # Without censor period the split should not be found because of duplicates. + merges_no_censor_period, outs_no_censor_period = compute_merge_unit_groups( + sorting_analyzer_duplicated, + preset="slay", + steps_params={"slay_score": {"censored_period_ms": 0.0}}, + extra_outputs=True, + ) + merges_censor_period, outs_censor_period = compute_merge_unit_groups( + sorting_analyzer_duplicated, + preset="slay", + steps_params={"slay_score": {"censored_period_ms": 0.5}}, + extra_outputs=True, + ) + assert np.sum(outs_censor_period["slay_eta_ij"]) < np.sum(outs_no_censor_period["slay_eta_ij"]) + + def test_auto_merge_units(sorting_analyzer_for_curation): recording = sorting_analyzer_for_curation.recording new_sorting, _ = split_sorting_by_times(sorting_analyzer_for_curation) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 40ac386ecc..492440a91e 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -208,7 +208,7 @@ def _run(self, verbose=False, **job_kwargs): self.data["bins"] = bins def _get_data(self): - return self.data["ccgs"], self.data["bins"] + return self.data["ccgs"].copy(), self.data["bins"].copy() class ComputeAutoCorrelograms(AnalyzerExtension):