Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
3dc5729
Test IBL extractors tests failing for PI update
alejoe91 Dec 29, 2025
d1a0532
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Jan 6, 2026
33c6769
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Jan 16, 2026
2c94bac
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Jan 20, 2026
a412bd8
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Feb 2, 2026
504e19d
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Feb 12, 2026
cd09c19
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Feb 19, 2026
a40d073
Merge branch 'main' of github.com:alejoe91/spikeinterface
alejoe91 Feb 24, 2026
a1da327
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 2, 2026
ef19a8e
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 3, 2026
a098b51
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 6, 2026
61c317a
Fix OpenEphys tests
alejoe91 Mar 6, 2026
c9ff247
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 9, 2026
3520138
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 16, 2026
f61329d
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 16, 2026
d64ae6a
Merge branch 'main' of github.com:alejoe91/spikeinterface
alejoe91 Mar 16, 2026
aef197d
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 17, 2026
e82331b
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 20, 2026
710cb6f
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 23, 2026
c2f8db1
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 23, 2026
161d25b
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 27, 2026
1d09ec6
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 30, 2026
afb7d33
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 30, 2026
fa556ba
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 30, 2026
8e68f16
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Apr 14, 2026
1c80910
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Apr 14, 2026
5eff246
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Apr 17, 2026
2c7bceb
wip: add censor_period_ms to slay
alejoe91 Apr 17, 2026
e0ea859
subtract coincident spikes and use merged ACG for slay RP violation
alejoe91 Apr 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 93 additions & 10 deletions src/spikeinterface/curation/auto_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
"censored_period_ms": 0.3,
},
"quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3},
"slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5},
"slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5, "censored_period_ms": 0.2},
}


Expand Down Expand Up @@ -309,7 +309,6 @@ def compute_merge_unit_groups(
win_sizes,
pair_mask=pair_mask,
)
# print(correlogram_diff)
pair_mask = pair_mask & (correlogram_diff < params["corr_diff_thresh"])
outs["correlograms"] = correlograms
outs["bins"] = bins
Expand Down Expand Up @@ -373,12 +372,20 @@ def compute_merge_unit_groups(
outs["pairs_decreased_score"] = pairs_decreased_score

elif step == "slay_score":

M_ij = compute_slay_matrix(
sorting_analyzer, params["k1"], params["k2"], templates_diff=outs["templates_diff"], pair_mask=pair_mask
M_ij, sigma_ij, rho_ij, eta_ij = compute_slay_matrix(
sorting_analyzer,
params["k1"],
params["k2"],
params["censored_period_ms"],
templates_diff=outs["templates_diff"],
pair_mask=pair_mask,
)

pair_mask = pair_mask & (M_ij > params["slay_threshold"])
outs["slay_M_ij"] = M_ij
outs["slay_sigma_ij"] = sigma_ij
outs["slay_rho_ij"] = rho_ij
outs["slay_eta_ij"] = eta_ij

# FINAL STEP : create the final list from pair_mask boolean matrix
ind1, ind2 = np.nonzero(pair_mask)
Expand Down Expand Up @@ -1552,6 +1559,7 @@ def compute_slay_matrix(
sorting_analyzer: SortingAnalyzer,
k1: float,
k2: float,
censor_period_ms: float,
templates_diff: np.ndarray | None,
pair_mask: np.ndarray | None = None,
):
Expand All @@ -1569,6 +1577,9 @@ def compute_slay_matrix(
Coefficient determining the importance of the cross-correlation significance
k2 : float
Coefficient determining the importance of the sliding rp violation
censor_period_ms : float
The censored period to exclude from the refractory period computation to discard
duplicated spikes.
templates_diff : np.ndarray | None
Pre-computed template similarity difference matrix. If None, it will be retrieved from the sorting_analyzer.
pair_mask : None | np.ndarray, default: None
Expand All @@ -1592,14 +1603,35 @@ def compute_slay_matrix(
sigma_ij = 1 - templates_diff
else:
sigma_ij = sorting_analyzer.get_extension("template_similarity").get_data()
rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, pair_mask)
rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, pair_mask, censor_period_ms)

M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij

return M_ij
return M_ij, sigma_ij, rho_ij, eta_ij


def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarray):
def _count_coincident_spikes(t1, t2, max_samples):
"""
Count spikes in t1 that have a matching spike in t2 within max_samples,
split by lag direction.

Returns (n_nonneg, n_neg) where n_nonneg counts pairs where t2 >= t1
(non-negative lag, mapped to the right center CCG bin) and n_neg counts
pairs where t2 < t1 (negative lag, mapped to the left center CCG bin).
"""
if len(t1) == 0 or len(t2) == 0:
return 0, 0
indices = np.searchsorted(t2, t1, side="left")
right_valid = indices < len(t2)
right_diffs = np.where(right_valid, t2[np.minimum(indices, len(t2) - 1)] - t1, max_samples + 1)
left_valid = indices > 0
left_diffs = np.where(left_valid, t1 - t2[np.maximum(indices - 1, 0)], max_samples + 1)
n_nonneg = int(np.sum(right_diffs <= max_samples))
n_neg = int(np.sum((left_diffs <= max_samples) & (left_diffs > 0)))
return n_nonneg, n_neg


def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarray, censor_period_ms: float):
"""
Computes a cross-correlation significance measure and a sliding refractory period violation
measure for all units in the `sorting_analyzer`.
Expand All @@ -1610,14 +1642,36 @@ def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarra
The sorting analyzer object containing the spike sorting data
pair_mask : np.ndarray
A bool matrix describing which pairs are possible merges based on previous steps
censor_period_ms : float
The censored period to exclude from the refractory period computation to discard
duplicated spikes.

Returns
-------
rho_ij : np.ndarray
The cross-correlation significance measure for each pair of units.
eta_ij : np.ndarray
The sliding refractory period violation measure for each pair of units.
"""

correlograms_extension = sorting_analyzer.get_extension("correlograms")
ccgs, _ = correlograms_extension.get_data()
ccgs, bin_edges = correlograms_extension.get_data()

# convert to seconds for SLAy functions
bin_size_ms = correlograms_extension.params["bin_ms"]

# pre-fetch spike trains for duplicate counting (sub-bin resolution)
if censor_period_ms > 0:
sorting = sorting_analyzer.sorting
censor_period_samples = int(censor_period_ms / 1000 * sorting_analyzer.sampling_frequency)
n_segments = sorting_analyzer.get_num_segments()
spike_trains = [
[sorting.get_unit_spike_train(unit_id=uid, segment_index=seg) for seg in range(n_segments)]
for uid in sorting_analyzer.unit_ids
]
# lag=0 spike pairs land in the bin starting at 0: xgram[num_half_bins]
center_bin = ccgs.shape[2] // 2

rho_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)])
eta_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)])

Expand All @@ -1630,10 +1684,39 @@ def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarra

xgram = ccgs[unit_index_1, unit_index_2, :]

# Merged ACG approximation: sum of individual ACGs and both CCG directions.
# _sliding_RP_viol_pair expects the ACG of the merged unit; the merged ACG
# has a large center bin when duplicates are present, which the LP filter
# attenuates so bin_rate_max reflects the flank rate — making RP violations
# detectable (unlike using the CCG alone where bin_rate_max is dominated by
# the duplicate peak and masks violations).
merged_acg = (
ccgs[unit_index_1, unit_index_1, :]
+ ccgs[unit_index_2, unit_index_2, :]
+ ccgs[unit_index_1, unit_index_2, :]
+ ccgs[unit_index_2, unit_index_1, :]
)

if censor_period_ms > 0:
# count number of spikes within the censor period from the two units
n_right, n_left = 0, 0
for seg in range(n_segments):
r, l = _count_coincident_spikes(
spike_trains[unit_index_1][seg], spike_trains[unit_index_2][seg], censor_period_samples
)
n_right += r
n_left += l
# subtract number of duplicates from central bin(s) of the merged ACG:
# n_right pairs land in center_bin (lag ≥ 0), n_left in center_bin-1 (lag < 0);
# each direction is counted in both ccgs[i,j] and ccgs[j,i], hence the factor 2
merged_acg = merged_acg.copy()
merged_acg[center_bin] = max(0, merged_acg[center_bin] - 2 * n_right)
merged_acg[center_bin - 1] = max(0, merged_acg[center_bin - 1] - 2 * n_left)

rho_ij[unit_index_1, unit_index_2] = _compute_xcorr_pair(
xgram, bin_size_s=bin_size_ms / 1000, min_xcorr_rate=0
)
eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(xgram, bin_size_ms=bin_size_ms)
eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(merged_acg, bin_size_ms=bin_size_ms)

return rho_ij, eta_ij

Expand Down
15 changes: 6 additions & 9 deletions src/spikeinterface/curation/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,9 @@ def make_sorting_analyzer(sparse=True, num_units=5, durations=[300.0]):
sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)

sorting_analyzer = create_sorting_analyzer(
sorting=sorting,
recording=recording,
format="memory",
sparse=sparse,
sorting=sorting, recording=recording, format="memory", sparse=sparse, n_jobs=-1
)
sorting_analyzer.compute(extensions)
sorting_analyzer.compute(extensions, n_jobs=-1)

return sorting_analyzer

Expand All @@ -58,9 +55,9 @@ def make_sorting_analyzer_with_splits(sorting_analyzer, num_unit_splitted=1, num
)

sorting_analyzer_with_splits = create_sorting_analyzer(
sorting=sorting_with_split, recording=sorting_analyzer.recording, format="memory", sparse=True
sorting=sorting_with_split, recording=sorting_analyzer.recording, format="memory", sparse=True, n_jobs=-1
)
sorting_analyzer_with_splits.compute(extensions)
sorting_analyzer_with_splits.compute(extensions, n_jobs=-1)

return sorting_analyzer_with_splits, num_unit_splitted, other_ids

Expand All @@ -78,8 +75,8 @@ def sorting_analyzer_for_unitrefine_curation():
recording, sorting_1 = generate_ground_truth_recording(num_channels=4, seed=1, num_units=6)
_, sorting_2 = generate_ground_truth_recording(num_channels=4, seed=2, num_units=6)
both_sortings = aggregate_units([sorting_1, sorting_2])
analyzer = create_sorting_analyzer(sorting=both_sortings, recording=recording)
analyzer.compute(["random_spikes", "noise_levels", "templates"])
analyzer = create_sorting_analyzer(sorting=both_sortings, recording=recording, n_jobs=-1)
analyzer.compute(["random_spikes", "noise_levels", "templates"], n_jobs=-1)
return analyzer


Expand Down
56 changes: 55 additions & 1 deletion src/spikeinterface/curation/tests/test_auto_merge.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import pytest

import numpy as np

from spikeinterface.core import create_sorting_analyzer
from spikeinterface.core import create_sorting_analyzer, NumpySorting
from spikeinterface.curation import compute_merge_unit_groups, auto_merge_units
from spikeinterface.generation import split_sorting_by_times


from spikeinterface.curation.tests.common import (
extensions,
make_sorting_analyzer,
sorting_analyzer_for_curation,
sorting_analyzer_with_splits,
Expand Down Expand Up @@ -68,6 +70,58 @@ def test_compute_merge_unit_groups_multi_segment(sorting_analyzer_multi_segment_
)


def test_slay_discard_duplicated_spikes(sorting_analyzer_with_splits):
    """
    Check that the `censored_period_ms` of the "slay" preset discards duplicated spikes.

    Duplicated (coincident) spikes across a split pair inflate the center bins of the
    merged ACG, which masks refractory-period violations. With a non-zero censor
    period the duplicates are subtracted, so the summed eta (RP violation) scores
    must be strictly smaller than without censoring.
    """
    sorting_analyzer, num_unit_splitted, split_ids = sorting_analyzer_with_splits

    # seeded RNG: an unseeded np.random.choice makes this test nondeterministic/flaky
    rng = np.random.default_rng(seed=2205)
    percent_duplicated = 0.7

    # keep spike trains of unsplit units untouched
    new_spiketrains = {}
    for unit_id in sorting_analyzer.unit_ids:
        if unit_id in split_ids:
            continue
        new_spiketrains[unit_id] = sorting_analyzer.sorting.get_unit_spike_train(unit_id=unit_id)

    # add duplicated spikes for split units: copy a fraction of the first split's
    # spikes into the second split's train to create coincident duplicates
    for unit_id in split_ids:
        split_units = split_ids[unit_id]
        spiketrain0 = sorting_analyzer.sorting.get_unit_spike_train(unit_id=split_units[0])
        spiketrain1 = sorting_analyzer.sorting.get_unit_spike_train(unit_id=split_units[1])
        num_duplicated = int(percent_duplicated * min(len(spiketrain0), len(spiketrain1)))
        duplicated_spikes0 = rng.choice(spiketrain0, size=num_duplicated, replace=False)

        new_spiketrains[split_units[0]] = spiketrain0
        new_spiketrains[split_units[1]] = np.sort(np.concatenate([spiketrain1, duplicated_spikes0]))

    sorting_duplicated = NumpySorting.from_unit_dict(
        new_spiketrains, sampling_frequency=sorting_analyzer.sampling_frequency
    )

    sorting_analyzer_duplicated = create_sorting_analyzer(
        sorting_duplicated, sorting_analyzer.recording, format="memory"
    )
    sorting_analyzer_duplicated.compute(extensions)

    # without a censor period the duplicate peak dominates the merged ACG center bins
    _, outs_no_censor_period = compute_merge_unit_groups(
        sorting_analyzer_duplicated,
        preset="slay",
        steps_params={"slay_score": {"censored_period_ms": 0.0}},
        extra_outputs=True,
    )
    # with a censor period the coincident spikes are subtracted before the RP measure
    _, outs_censor_period = compute_merge_unit_groups(
        sorting_analyzer_duplicated,
        preset="slay",
        steps_params={"slay_score": {"censored_period_ms": 0.5}},
        extra_outputs=True,
    )
    assert np.sum(outs_censor_period["slay_eta_ij"]) < np.sum(outs_no_censor_period["slay_eta_ij"])


def test_auto_merge_units(sorting_analyzer_for_curation):
recording = sorting_analyzer_for_curation.recording
new_sorting, _ = split_sorting_by_times(sorting_analyzer_for_curation)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/correlograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def _run(self, verbose=False, **job_kwargs):
self.data["bins"] = bins

def _get_data(self):
return self.data["ccgs"], self.data["bins"]
return self.data["ccgs"].copy(), self.data["bins"].copy()


class ComputeAutoCorrelograms(AnalyzerExtension):
Expand Down
Loading