From 5d685babd9e8ee79277510772171ba8722c957b2 Mon Sep 17 00:00:00 2001
From: Max Ghenis
Date: Fri, 17 Apr 2026 08:34:47 -0400
Subject: [PATCH] Export SparseCalibrationWeights and add optional `seed`
 parameter

- `SparseCalibrationWeights` was documented in the README and in
  `CRITICAL_TEMPERATURE_BUG.md` as the temperature-correct path for
  calibration workflows, but it was not re-exported from `l0/__init__.py`.
  `SparseL0Linear` was already exported, so the asymmetry made
  `SparseCalibrationWeights` noticeably harder to discover. Adds the
  import and the matching `__all__` entry.

- `SparseCalibrationWeights.__init__` and `SparseL0Linear.__init__` drew
  their `log_alpha` jitter from PyTorch's global RNG, so two constructions
  with the same inputs were only deterministic if the caller remembered to
  call `torch.manual_seed(...)` beforehand. Adds an optional
  `seed: int | None` kwarg to both classes; when set, the `log_alpha`
  jitter (and the `log_weight` jitter inside
  `SparseCalibrationWeights.fit`) is drawn from a local `torch.Generator`.
  `seed=None` preserves legacy behaviour.

- Updates `CLAUDE.md` to list `calibration.py` and `sparse.py` as
  first-class modules, fixes the stale `black -l 79` formatter hint, and
  adds the matching test files to the repo tree listing. Adds seed/export
  regression tests in `test_calibration.py` and `test_sparse.py`.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 CLAUDE.md                              | 18 ++++++----
 .../fix-api-exports-and-seeds.added.md |  6 ++++
 l0/__init__.py                         |  3 ++
 l0/calibration.py                      | 33 +++++++++++++++++--
 l0/sparse.py                           | 23 ++++++++++---
 tests/test_calibration.py              | 24 ++++++++++++++
 tests/test_sparse.py                   | 17 ++++++++++
 7 files changed, 112 insertions(+), 12 deletions(-)
 create mode 100644 changelog.d/fix-api-exports-and-seeds.added.md

diff --git a/CLAUDE.md b/CLAUDE.md
index 3d351e0..e145d4c 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -13,6 +13,8 @@ L0 is a PyTorch package implementing L0 regularization from Louizos, Welling, &
 - **l0/layers.py**: L0Linear, L0Conv2d, L0DepthwiseConv2d, SparseMLP
 - **l0/gates.py**: L0Gate, SampleGate, FeatureGate, HybridGate
 - **l0/penalties.py**: L0/L2/L0L2 penalties, TemperatureScheduler, PenaltyTracker
+- **l0/calibration.py**: SparseCalibrationWeights (non-negative L0 weights for survey calibration)
+- **l0/sparse.py**: SparseL0Linear (L0-regularized linear regression for scipy.sparse inputs)
 - **tests/**: Comprehensive test coverage using TDD approach
 - **CI/CD**: GitHub Actions workflow for Python 3.13
 
@@ -24,12 +26,16 @@ l0/
 │   ├── distributions.py  # HardConcrete distribution
 │   ├── layers.py         # Neural network layers with L0
 │   ├── gates.py          # Standalone gates for selection
-│   └── penalties.py      # Penalty computation and utilities
+│   ├── penalties.py      # Penalty computation and utilities
+│   ├── calibration.py    # SparseCalibrationWeights (positive, sparse)
+│   └── sparse.py         # SparseL0Linear (scipy.sparse inputs)
 ├── tests/
 │   ├── test_distributions.py
 │   ├── test_layers.py
 │   ├── test_gates.py
-│   └── test_penalties.py
+│   ├── test_penalties.py
+│   ├── test_calibration.py
+│   └── test_sparse.py
 ├── docs/          # Jupyter Book documentation (pending)
 ├── examples/      # Example notebooks (pending)
 ├── .github/workflows/ci.yml
@@ -50,11 +56,11 @@ pytest tests/ -v --cov=l0
 # Run specific test
 pytest tests/test_layers.py::TestL0Linear -v
 
-# Format code (79 char line length)
-black . -l 79
+# Format code (uses ruff format, default 88-char line length)
+ruff format .
 
 # Check formatting
-black . -l 79 --check
+ruff format --check .
 
 # Lint with ruff
 ruff check .
@@ -125,7 +131,7 @@ l0_lambda = 5.0e-07  # Tuned value from PolicyEngine
 ## Code Standards
 
 - **Python 3.13**: Required for latest features
-- **Black Formatter**: 79-character line length (PolicyEngine standard)
+- **Ruff Formatter**: 88-character line length (ruff default)
 - **Type Hints**: All public functions fully typed
 - **Docstrings**: NumPy style with examples
 - **Imports**: Grouped (stdlib, third-party, local) and alphabetized
diff --git a/changelog.d/fix-api-exports-and-seeds.added.md b/changelog.d/fix-api-exports-and-seeds.added.md
new file mode 100644
index 0000000..57f86d5
--- /dev/null
+++ b/changelog.d/fix-api-exports-and-seeds.added.md
@@ -0,0 +1,6 @@
+Export `SparseCalibrationWeights` from the top-level `l0` package (closing
+the discoverability gap with `SparseL0Linear`), add an optional `seed`
+parameter to both `SparseCalibrationWeights` and `SparseL0Linear` so
+`log_alpha` / `log_weight` jitter is reproducible without managing
+PyTorch's global RNG, and update `CLAUDE.md` to list `calibration.py` and
+`sparse.py` as first-class modules.
diff --git a/l0/__init__.py b/l0/__init__.py
index fe58b59..2d11d8a 100644
--- a/l0/__init__.py
+++ b/l0/__init__.py
@@ -7,6 +7,7 @@
 
 __version__ = "0.5.0"
 
+from .calibration import SparseCalibrationWeights
 from .distributions import HardConcrete
 from .gates import FeatureGate, HybridGate, L0Gate, SampleGate
 from .layers import (
@@ -39,6 +40,8 @@
     "prune_model",
     # Sparse
     "SparseL0Linear",
+    # Calibration
+    "SparseCalibrationWeights",
     # Gates
     "L0Gate",
     "SampleGate",
diff --git a/l0/calibration.py b/l0/calibration.py
index 1348832..950ad3e 100644
--- a/l0/calibration.py
+++ b/l0/calibration.py
@@ -47,6 +47,12 @@ class SparseCalibrationWeights(nn.Module):
         Set to 0 to disable jitter. Default is 0.01 (following Louizos et al.).
     device : str or torch.device
         Device to run computations on ('cpu' or 'cuda')
+    seed : int, optional
+        Seed for the RNG used by the ``log_alpha`` init jitter and (inside
+        ``fit``) the ``log_weight`` jitter. When set, two models with the
+        same inputs produce byte-identical initial ``log_alpha`` and jitter
+        without the caller having to manage PyTorch's global RNG. ``None``
+        (default) preserves legacy behaviour of using the global RNG.
     """
 
     def __init__(
@@ -61,6 +67,7 @@ def __init__(
         log_alpha_jitter_sd: float = 0.01,
         device: str | torch.device = "cpu",
         use_gates: bool = True,
+        seed: int | None = None,
     ):
         super().__init__()
         self.n_features = n_features
@@ -71,6 +78,17 @@ def __init__(
         self.log_alpha_jitter_sd = log_alpha_jitter_sd
         self.device = torch.device(device)
         self.use_gates = use_gates
+        self.seed = seed
+
+        # Local RNG (only used when `seed` is provided). Kept on the module so
+        # `fit`'s `log_weight` jitter uses the same deterministic stream.
+        if seed is not None:
+            self._generator: torch.Generator | None = torch.Generator(
+                device=self.device
+            )
+            self._generator.manual_seed(int(seed))
+        else:
+            self._generator = None
 
         # Initialize weights (on original scale)
         if init_weights is None:
@@ -121,7 +139,8 @@ def __init__(
         # Add jitter to break symmetry (if specified)
         if self.log_alpha_jitter_sd > 0:
             jitter = (
-                torch.randn(n_features, device=self.device) * self.log_alpha_jitter_sd
+                torch.randn(n_features, generator=self._generator, device=self.device)
+                * self.log_alpha_jitter_sd
             )
             self.log_alpha = nn.Parameter(mu + jitter)
         else:
@@ -398,7 +417,17 @@ def fit(
         # Add jitter to weights to break symmetry (if jitter_sd > 0)
         if self.log_weight_jitter_sd > 0:
             with torch.no_grad():
-                jitter = torch.randn_like(self.log_weight) * self.log_weight_jitter_sd
+                # `torch.randn_like` can't take a generator kwarg, so draw
+                # explicitly via `torch.randn` when a local RNG is set.
+                if self._generator is not None:
+                    noise = torch.randn(
+                        self.log_weight.shape,
+                        generator=self._generator,
+                        device=self.device,
+                    )
+                else:
+                    noise = torch.randn_like(self.log_weight)
+                jitter = noise * self.log_weight_jitter_sd
                 self.log_weight.data += jitter
 
         # Setup optimizer
diff --git a/l0/sparse.py b/l0/sparse.py
index 0b6a425..36ba3e6 100644
--- a/l0/sparse.py
+++ b/l0/sparse.py
@@ -36,6 +36,12 @@ class SparseL0Linear(nn.Module):
         Initial probability of keeping each feature
     device : str or torch.device
         Device to run computations on ('cpu' or 'cuda')
+    seed : int, optional
+        Seed for the RNG used by the ``log_alpha`` init. When set, two
+        models with the same inputs produce byte-identical initial
+        ``log_alpha`` without the caller having to manage PyTorch's
+        global RNG. ``None`` (default) preserves legacy behaviour of
+        using the global RNG.
     """
     def __init__(
         self,
@@ -47,6 +53,7 @@ def __init__(
         zeta: float = 1.1,
         init_keep_prob: float = 0.5,
         device: str | torch.device = "cpu",
+        seed: int | None = None,
     ):
         super().__init__()
         self.n_features = n_features
@@ -55,6 +62,15 @@ def __init__(
         self.gamma = gamma
         self.zeta = zeta
         self.device = torch.device(device)
+        self.seed = seed
+
+        if seed is not None:
+            self._generator: torch.Generator | None = torch.Generator(
+                device=self.device
+            )
+            self._generator.manual_seed(int(seed))
+        else:
+            self._generator = None
 
         # Model parameters
         self.weight = nn.Parameter(torch.zeros(n_features, device=self.device))
@@ -63,11 +79,10 @@ def __init__(
         else:
             self.register_parameter("bias", None)
 
-        # L0 gate parameters
+        # L0 gate parameters: mu + Gaussian jitter (sd 0.01), optionally seeded.
         mu = torch.log(torch.tensor(init_keep_prob / (1 - init_keep_prob)))
-        self.log_alpha = nn.Parameter(
-            torch.normal(mu.item(), 0.01, size=(n_features,), device=self.device)
-        )
+        noise = torch.randn(n_features, generator=self._generator, device=self.device)
+        self.log_alpha = nn.Parameter(mu.item() + 0.01 * noise)
 
         # Cache for sparse tensor conversion
         self._cached_X_torch: torch.sparse.Tensor | None = None
diff --git a/tests/test_calibration.py b/tests/test_calibration.py
index b48ed9c..77304d1 100644
--- a/tests/test_calibration.py
+++ b/tests/test_calibration.py
@@ -961,3 +961,27 @@ def test_normalize_groups_false_uniform_weights(self):
             f"normalize_groups=False should behave like no groups: "
             f"{err_none:.4f} vs {err_no_norm:.4f}"
         )
+
+    def test_seed_produces_deterministic_log_alpha(self):
+        """Two models with the same `seed` must share `log_alpha` init."""
+        a = SparseCalibrationWeights(n_features=50, log_alpha_jitter_sd=0.1, seed=123)
+        b = SparseCalibrationWeights(n_features=50, log_alpha_jitter_sd=0.1, seed=123)
+        torch.testing.assert_close(a.log_alpha.data, b.log_alpha.data)
+
+        c = SparseCalibrationWeights(n_features=50, log_alpha_jitter_sd=0.1, seed=456)
+        assert not torch.allclose(a.log_alpha.data, c.log_alpha.data)
+
+    def test_seed_none_uses_global_rng(self):
+        """`seed=None` is the legacy behaviour: global RNG, caller-managed."""
+        torch.manual_seed(0)
+        a = SparseCalibrationWeights(n_features=20, log_alpha_jitter_sd=0.1)
+        torch.manual_seed(0)
+        b = SparseCalibrationWeights(n_features=20, log_alpha_jitter_sd=0.1)
+        torch.testing.assert_close(a.log_alpha.data, b.log_alpha.data)
+
+    def test_sparse_calibration_weights_exported(self):
+        """`SparseCalibrationWeights` must be importable from the top-level."""
+        import l0
+
+        assert "SparseCalibrationWeights" in l0.__all__
+        assert l0.SparseCalibrationWeights is SparseCalibrationWeights
diff --git a/tests/test_sparse.py b/tests/test_sparse.py
index fa285f8..955bbe7 100644
--- a/tests/test_sparse.py
+++ b/tests/test_sparse.py
@@ -224,3 +224,20 @@ def test_deterministic_vs_stochastic(self):
         assert not torch.allclose(y_stoch1, y_stoch2), (
             "Stochastic predictions should differ"
         )
+
+    def test_seed_produces_deterministic_log_alpha(self):
+        """Two models with the same `seed` must share `log_alpha` init."""
+        a = SparseL0Linear(n_features=50, init_keep_prob=0.5, seed=123)
+        b = SparseL0Linear(n_features=50, init_keep_prob=0.5, seed=123)
+        torch.testing.assert_close(a.log_alpha.data, b.log_alpha.data)
+
+        c = SparseL0Linear(n_features=50, init_keep_prob=0.5, seed=456)
+        assert not torch.allclose(a.log_alpha.data, c.log_alpha.data)
+
+    def test_seed_none_uses_global_rng(self):
+        """`seed=None` is the legacy behaviour: global RNG, caller-managed."""
+        torch.manual_seed(0)
+        a = SparseL0Linear(n_features=20, init_keep_prob=0.5)
+        torch.manual_seed(0)
+        b = SparseL0Linear(n_features=20, init_keep_prob=0.5)
+        torch.testing.assert_close(a.log_alpha.data, b.log_alpha.data)
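-- 
Usage sketch (illustrative, not part of the applyable diff above): with
`seed` handled by a local `torch.Generator`, reproducible construction no
longer requires global-RNG bookkeeping. The argument values below are
arbitrary examples, mirroring what the regression tests assert.

    import torch
    from l0 import SparseCalibrationWeights  # top-level export added above

    # Same seed => byte-identical log_alpha init jitter, with no
    # torch.manual_seed() call by the caller.
    a = SparseCalibrationWeights(n_features=50, log_alpha_jitter_sd=0.1, seed=123)
    b = SparseCalibrationWeights(n_features=50, log_alpha_jitter_sd=0.1, seed=123)
    torch.testing.assert_close(a.log_alpha.data, b.log_alpha.data)

    # Different seed => a different local generator stream.
    c = SparseCalibrationWeights(n_features=50, log_alpha_jitter_sd=0.1, seed=456)
    assert not torch.allclose(a.log_alpha.data, c.log_alpha.data)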
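The same contract applies to `SparseL0Linear`, whose `log_alpha` jitter is
the only RNG draw visible in the constructor diff. A minimal sketch, again
with arbitrary argument values, including the preserved `seed=None` legacy
path:

    import torch
    from l0 import SparseL0Linear

    # Seeded: deterministic without touching the global RNG.
    a = SparseL0Linear(n_features=100, init_keep_prob=0.5, seed=7)
    b = SparseL0Linear(n_features=100, init_keep_prob=0.5, seed=7)
    torch.testing.assert_close(a.log_alpha.data, b.log_alpha.data)

    # Unseeded (legacy): determinism stays the caller's responsibility.
    torch.manual_seed(0)
    c = SparseL0Linear(n_features=100, init_keep_prob=0.5)
    torch.manual_seed(0)
    d = SparseL0Linear(n_features=100, init_keep_prob=0.5)
    torch.testing.assert_close(c.log_alpha.data, d.log_alpha.data)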