diff --git a/changelog.d/mdn-cache-key.fixed.md b/changelog.d/mdn-cache-key.fixed.md
new file mode 100644
index 0000000..c4afe0d
--- /dev/null
+++ b/changelog.d/mdn-cache-key.fixed.md
@@ -0,0 +1 @@
+Fixed collision-prone MDN cache key (#5). `_generate_data_hash` previously used `pd.util.hash_pandas_object(X).sum()`, which loses row ordering (any permutation that keeps its index labels, e.g. `df.sample(frac=1)`, hashes identically) and makes cross-dataset collisions trivial, so a different dataset with a matching `(shape, columns, sum-of-hashes)` triple could silently load a stale MDN from disk (correctness bug). The function now applies `hashlib.sha256` to the raw bytes of the per-row hashes, producing an order-sensitive content digest.
diff --git a/microimpute/models/mdn.py b/microimpute/models/mdn.py
index aac265e..8832e6e 100644
--- a/microimpute/models/mdn.py
+++ b/microimpute/models/mdn.py
@@ -122,17 +122,23 @@ def _suppress_pytorch_logging() -> None:
 
 
 def _generate_data_hash(X: pd.DataFrame, y: pd.Series) -> str:
-    """Generate a hash from the training data for cache identification.
+    """Generate a content-sensitive hash from the training data for cache
+    identification.
 
-    Creates a reproducible hash based on the data shape, column names,
-    and a sample of the data values.
+    Creates an order-sensitive hash based on data shape, column names, and
+    the SHA-256 digest of the per-row hash bytes of X and y. Previously
+    this summed the per-row uint64 hashes from ``hash_pandas_object``,
+    which lost row ordering (label-preserving permutations hashed
+    identically) and was collision-prone across semantically different
+    datasets of matching shape — a cache hit on a collision would load a
+    stale model for a new dataset (silent correctness bug).
 
     Args:
         X: Feature DataFrame.
         y: Target Series.
 
     Returns:
-        A short hash string identifying the dataset.
+        A 16-character hex string identifying the dataset.
     """
     # Include shape, column names, and data statistics for identification
     hash_components = [
@@ -142,12 +148,18 @@ def _generate_data_hash(X: pd.DataFrame, y: pd.Series) -> str:
         str(len(y)),
     ]
 
-    # Add hash of actual data values for uniqueness
-    # Use pandas hash_pandas_object for consistent hashing
+    # Order-sensitive content hash: SHA-256 over the raw bytes of per-row
+    # hashes. Any change in row values OR row ordering produces a
+    # different digest, eliminating the sum-of-hashes collision trap.
     try:
-        data_hash = pd.util.hash_pandas_object(X).sum()
-        y_hash = pd.util.hash_pandas_object(y).sum()
-        hash_components.extend([str(data_hash), str(y_hash)])
+        x_row_hashes = pd.util.hash_pandas_object(X, index=True).values
+        y_row_hashes = pd.util.hash_pandas_object(y, index=True).values
+        hash_components.extend(
+            [
+                hashlib.sha256(x_row_hashes.tobytes()).hexdigest(),
+                hashlib.sha256(y_row_hashes.tobytes()).hexdigest(),
+            ]
+        )
     except Exception:
         # Fallback to basic stats if hashing fails
         hash_components.extend(
@@ -158,7 +170,9 @@ def _generate_data_hash(X: pd.DataFrame, y: pd.Series) -> str:
         )
 
     combined = "_".join(hash_components)
-    return hashlib.md5(combined.encode()).hexdigest()[:12]
+    # SHA-256 truncated to 16 hex chars (64 bits) — collision-resistant
+    # for any realistic cache size while keeping filesystem paths short.
+    return hashlib.sha256(combined.encode()).hexdigest()[:16]
 
 
 def _get_package_versions_hash() -> str:
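To make the failure mode concrete, here is a minimal standalone sketch (illustration only, not part of the patch; the file name and helper names are hypothetical): a permutation that keeps its index labels leaves the old sum-of-hashes key unchanged, while the new order-sensitive digest distinguishes the two orderings.

# collision_demo.py — standalone sketch, not included in the diff above
import hashlib

import pandas as pd


def old_key(df: pd.DataFrame) -> str:
    # Previous scheme: permutation-invariant sum of per-row uint64 hashes.
    return str(pd.util.hash_pandas_object(df).sum())


def new_key(df: pd.DataFrame) -> str:
    # Fixed scheme: order-sensitive SHA-256 over the per-row hash bytes.
    return hashlib.sha256(
        pd.util.hash_pandas_object(df, index=True).values.tobytes()
    ).hexdigest()


X = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]})
X_shuffled = X.iloc[[2, 0, 1]]  # reorders rows but keeps index labels

assert old_key(X) == old_key(X_shuffled)  # silent collision under the old key
assert new_key(X) != new_key(X_shuffled)  # the fix tells the orderings apart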
diff --git a/tests/test_models/test_mdn.py b/tests/test_models/test_mdn.py
index 3b2b080..b0e64e1 100644
--- a/tests/test_models/test_mdn.py
+++ b/tests/test_models/test_mdn.py
@@ -460,7 +460,8 @@ def test_generate_cache_key():
 
 
 def test_generate_data_hash():
-    """Test data hash generation."""
+    """Basic smoke test for data hash generation (extended regression
+    coverage lives in test_mdn_cache_key.py, which skips when torch is absent)."""
     np.random.seed(42)
 
     # Create test data
diff --git a/tests/test_models/test_mdn_cache_key.py b/tests/test_models/test_mdn_cache_key.py
new file mode 100644
index 0000000..ea4ecb0
--- /dev/null
+++ b/tests/test_models/test_mdn_cache_key.py
@@ -0,0 +1,91 @@
+"""Tests for the MDN dataset-hashing helpers used to key the model cache.
+
+These regression tests cover the cache-key collision bug (#5). The tests
+are gated on the full MDN import stack being available because mdn.py's
+top-level ``import torch`` is not optional.
+"""
+
+import numpy as np
+import pandas as pd
+import pytest
+
+pytest.importorskip("torch")
+pytest.importorskip("pytorch_tabular")
+
+from microimpute.models.mdn import _generate_cache_key, _generate_data_hash
+
+
+def test_generate_data_hash_is_stable() -> None:
+    """The same data must always hash to the same digest."""
+    X = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]})
+    y = pd.Series([7, 8, 9], name="target")
+    assert _generate_data_hash(X, y) == _generate_data_hash(X.copy(), y.copy())
+
+
+def test_generate_data_hash_detects_value_change() -> None:
+    """A single-value change must produce a different digest."""
+    X1 = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]})
+    X2 = pd.DataFrame({"a": [1.0, 2.0, 4.0], "b": [4.0, 5.0, 6.0]})
+    y = pd.Series([7, 8, 9], name="target")
+    assert _generate_data_hash(X1, y) != _generate_data_hash(X2, y)
+
+
+def test_generate_data_hash_is_order_sensitive() -> None:
+    """Regression test for the sum-of-hashes collision bug (#5).
+
+    Permuting rows must produce a different cache key. Previously the
+    key summed the per-row uint64 hashes, so any permutation that kept
+    its index labels (e.g. ``df.sample(frac=1)``) hashed identically,
+    and cache lookups could load a stale model trained on differently
+    ordered data.
+    """
+    X = pd.DataFrame({"a": [1.0, 2.0, 3.0, 4.0], "b": [10, 20, 30, 40]})
+    y = pd.Series([100, 200, 300, 400], name="target")
+
+    # Label-preserving permutation: the old sum-based key collides here,
+    # because the multiset of per-row hashes is unchanged.
+    X_perm = X.iloc[[3, 1, 0, 2]]
+    y_perm = y.iloc[[3, 1, 0, 2]]
+
+    assert _generate_data_hash(X, y) != _generate_data_hash(X_perm, y_perm), (
+        "Permuted rows must produce a different cache key"
+    )
+ """ + X1 = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]}) + y1 = pd.Series([7.0, 8.0, 9.0], name="target") + + X2 = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.1]}) + y2 = pd.Series([7.0, 8.0, 9.0], name="target") + + assert _generate_data_hash(X1, y1) != _generate_data_hash(X2, y2) + + +def test_generate_data_hash_differs_across_random_datasets() -> None: + """Property-style check: 50 random datasets produce 50 distinct hashes.""" + rng = np.random.default_rng(0) + hashes = set() + for _ in range(50): + X = pd.DataFrame(rng.normal(size=(20, 3)), columns=["a", "b", "c"]) + y = pd.Series(rng.normal(size=20), name="target") + hashes.add(_generate_data_hash(X, y)) + assert len(hashes) == 50, "50 random datasets should produce 50 distinct cache keys" + + +def test_generate_cache_key_integrates_data_hash() -> None: + """_generate_cache_key must change when _generate_data_hash changes.""" + k1 = _generate_cache_key(["a", "b"], "target", "data_hash_1") + k2 = _generate_cache_key(["a", "b"], "target", "data_hash_2") + assert k1 != k2