1 change: 1 addition & 0 deletions changelog.d/mdn-cache-key.fixed.md
@@ -0,0 +1 @@
Fixed collision-prone MDN cache key (#5). `_generate_data_hash` previously used `pd.util.hash_pandas_object(X).sum()`, which is order-insensitive (any row permutation hashes identically) and reduces each dataset to a single 64-bit sum, so a different dataset with matching `(shape, columns, sum-of-hashes)` could silently load a stale MDN from disk (a correctness bug). The function now hashes the raw bytes of the per-row hashes with `hashlib.sha256`, producing an order-sensitive content digest.
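To see why the old key was collision-prone, here is a minimal sketch of the permutation collision (illustrative data, pandas only; not part of the patch):

```python
import pandas as pd

X = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]})
# Reorder rows; .iloc carries index labels along with their rows, so the
# multiset of per-row hashes is unchanged; only their order differs.
X_perm = X.iloc[[2, 0, 1]]

# Old cache-key ingredient: summation discards row order entirely.
assert (
    pd.util.hash_pandas_object(X).sum()
    == pd.util.hash_pandas_object(X_perm).sum()
)
```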
34 changes: 24 additions & 10 deletions microimpute/models/mdn.py
@@ -122,17 +122,23 @@ def _suppress_pytorch_logging() -> None:


def _generate_data_hash(X: pd.DataFrame, y: pd.Series) -> str:
"""Generate a hash from the training data for cache identification.
"""Generate a content-sensitive hash from the training data for cache
identification.

Creates a reproducible hash based on the data shape, column names,
and a sample of the data values.
Creates an order-sensitive hash based on data shape, column names, and
the SHA-256 digest of the per-row hash bytes of X and y. Previously
this summed the per-row uint64 hashes from
``hash_pandas_object``, which lost row ordering (permutations hashed
identically) and was collision-prone across semantically different
datasets of matching shape — a cache hit on a collision would load a
stale model for a new dataset (silent correctness bug).

Args:
X: Feature DataFrame.
y: Target Series.

Returns:
A short hash string identifying the dataset.
A 16-character hex string identifying the dataset.
"""
# Include shape, column names, and data statistics for identification
hash_components = [
@@ -142,12 +148,18 @@ def _generate_data_hash(X: pd.DataFrame, y: pd.Series) -> str:
str(len(y)),
]

# Add hash of actual data values for uniqueness
# Use pandas hash_pandas_object for consistent hashing
# Order-sensitive content hash: SHA-256 over the raw bytes of per-row
# hashes. Any change in row values OR row ordering produces a
# different digest, eliminating the sum-of-hashes collision trap.
try:
data_hash = pd.util.hash_pandas_object(X).sum()
y_hash = pd.util.hash_pandas_object(y).sum()
hash_components.extend([str(data_hash), str(y_hash)])
x_row_hashes = pd.util.hash_pandas_object(X, index=True).values
y_row_hashes = pd.util.hash_pandas_object(y, index=True).values
hash_components.extend(
[
hashlib.sha256(x_row_hashes.tobytes()).hexdigest(),
hashlib.sha256(y_row_hashes.tobytes()).hexdigest(),
]
)
except Exception:
# Fallback to basic stats if hashing fails
hash_components.extend(
@@ -158,7 +170,9 @@ def _generate_data_hash(X: pd.DataFrame, y: pd.Series) -> str:
)

combined = "_".join(hash_components)
return hashlib.md5(combined.encode()).hexdigest()[:12]
# SHA-256 truncated to 16 hex chars (64 bits) — collision-resistant
# for any realistic cache size while keeping filesystem paths short.
return hashlib.sha256(combined.encode()).hexdigest()[:16]


def _get_package_versions_hash() -> str:
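The truncation comment appeals to collision resistance at 64 bits; a rough birthday-bound check (back-of-the-envelope arithmetic, not part of the patch) supports it:

```python
# Approximate birthday bound: with n distinct cached datasets, the chance
# of any two 64-bit digests colliding is roughly n * (n - 1) / 2**65.
n = 100_000  # hypothetical, far larger than any realistic model cache
p_collision = n * (n - 1) / 2**65
print(f"P(collision) ~= {p_collision:.1e}")  # ~= 2.7e-10
```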
3 changes: 2 additions & 1 deletion tests/test_models/test_mdn.py
@@ -460,7 +460,8 @@ def test_generate_cache_key():


def test_generate_data_hash():
"""Test data hash generation."""
"""Basic smoke test for data hash generation (extended coverage lives
in test_mdn_cache_key.py so it can run without torch)."""
np.random.seed(42)

# Create test data
85 changes: 85 additions & 0 deletions tests/test_models/test_mdn_cache_key.py
@@ -0,0 +1,85 @@
"""Tests for the MDN dataset-hashing helpers used to key the model cache.

These regression tests cover the cache-key collision bug (#5). The tests
are gated on the full MDN import stack being available because mdn.py's
top-level ``import torch`` is not optional.
"""

import numpy as np
import pandas as pd
import pytest

pytest.importorskip("torch")
pytest.importorskip("pytorch_tabular")

from microimpute.models.mdn import _generate_cache_key, _generate_data_hash


def test_generate_data_hash_is_stable() -> None:
"""The same data must always hash to the same digest."""
X = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]})
y = pd.Series([7, 8, 9], name="target")
assert _generate_data_hash(X, y) == _generate_data_hash(X.copy(), y.copy())


def test_generate_data_hash_detects_value_change() -> None:
"""A single-value change must produce a different digest."""
X1 = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]})
X2 = pd.DataFrame({"a": [1.0, 2.0, 4.0], "b": [4.0, 5.0, 6.0]})
y = pd.Series([7, 8, 9], name="target")
assert _generate_data_hash(X1, y) != _generate_data_hash(X2, y)


def test_generate_data_hash_is_order_sensitive() -> None:
"""Regression test for the sum-of-hashes collision bug (#5).

    Permuting rows must produce a different cache key. Previously the
    key was the SUM of per-row uint64 hashes, so any row permutation
    (with index labels carried along, as ``.iloc`` does by default)
    hashed identically and cache lookups could load a stale model
    trained on differently ordered data.
"""
X = pd.DataFrame({"a": [1.0, 2.0, 3.0, 4.0], "b": [10, 20, 30, 40]})
y = pd.Series([100, 200, 300, 400], name="target")

    # Keep the original index labels: resetting them would change the
    # per-row hashes themselves and mask the old sum-of-hashes collision.
    X_perm = X.iloc[[3, 1, 0, 2]]
    y_perm = y.iloc[[3, 1, 0, 2]]

assert _generate_data_hash(X, y) != _generate_data_hash(X_perm, y_perm), (
"Permuted rows must produce a different cache key"
)


def test_generate_data_hash_avoids_sum_collision() -> None:
"""Two datasets of matching shape/columns whose per-row hashes sum to
the same value must still produce different cache keys.

Previously the cache key was effectively
``hash(sum(per_row_hashes))``, making collisions trivial — a stale
cached MDN could be loaded for a new, different dataset (silent
correctness bug).
"""
X1 = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]})
y1 = pd.Series([7.0, 8.0, 9.0], name="target")

X2 = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.1]})
y2 = pd.Series([7.0, 8.0, 9.0], name="target")

assert _generate_data_hash(X1, y1) != _generate_data_hash(X2, y2)


def test_generate_data_hash_differs_across_random_datasets() -> None:
"""Property-style check: 50 random datasets produce 50 distinct hashes."""
rng = np.random.default_rng(0)
hashes = set()
for _ in range(50):
X = pd.DataFrame(rng.normal(size=(20, 3)), columns=["a", "b", "c"])
y = pd.Series(rng.normal(size=20), name="target")
hashes.add(_generate_data_hash(X, y))
assert len(hashes) == 50, "50 random datasets should produce 50 distinct cache keys"


def test_generate_cache_key_integrates_data_hash() -> None:
"""_generate_cache_key must change when _generate_data_hash changes."""
k1 = _generate_cache_key(["a", "b"], "target", "data_hash_1")
k2 = _generate_cache_key(["a", "b"], "target", "data_hash_2")
assert k1 != k2