18 changes: 12 additions & 6 deletions CLAUDE.md
@@ -13,6 +13,8 @@ L0 is a PyTorch package implementing L0 regularization from Louizos, Welling, &
- **l0/layers.py**: L0Linear, L0Conv2d, L0DepthwiseConv2d, SparseMLP
- **l0/gates.py**: L0Gate, SampleGate, FeatureGate, HybridGate
- **l0/penalties.py**: L0/L2/L0L2 penalties, TemperatureScheduler, PenaltyTracker
- **l0/calibration.py**: SparseCalibrationWeights (non-negative L0 weights for survey calibration)
- **l0/sparse.py**: SparseL0Linear (L0-regularized linear regression for scipy.sparse inputs)
- **tests/**: Comprehensive test coverage using TDD approach
- **CI/CD**: GitHub Actions workflow for Python 3.13

@@ -24,12 +26,16 @@ l0/
│ ├── distributions.py # HardConcrete distribution
│ ├── layers.py # Neural network layers with L0
│ ├── gates.py # Standalone gates for selection
│ └── penalties.py # Penalty computation and utilities
│ ├── penalties.py # Penalty computation and utilities
│ ├── calibration.py # SparseCalibrationWeights (positive, sparse)
│ └── sparse.py # SparseL0Linear (scipy.sparse inputs)
├── tests/
│ ├── test_distributions.py
│ ├── test_layers.py
│ ├── test_gates.py
│ └── test_penalties.py
│ ├── test_penalties.py
│ ├── test_calibration.py
│ └── test_sparse.py
├── docs/ # Jupyter Book documentation (pending)
├── examples/ # Example notebooks (pending)
├── .github/workflows/ci.yml
@@ -50,11 +56,11 @@ pytest tests/ -v --cov=l0
# Run specific test
pytest tests/test_layers.py::TestL0Linear -v

# Format code (79 char line length)
black . -l 79
# Format code (uses ruff format, default 88 char line length)
ruff format .

# Check formatting
black . -l 79 --check
ruff format --check .

# Lint with ruff
ruff check .
@@ -125,7 +131,7 @@ l0_lambda = 5.0e-07 # Tuned value from PolicyEngine
## Code Standards

- **Python 3.13**: Required for latest features
- **Black Formatter**: 79-character line length (PolicyEngine standard)
- **Ruff formatter** (default 88-char line length)
- **Type Hints**: All public functions fully typed
- **Docstrings**: NumPy style with examples
- **Imports**: Grouped (stdlib, third-party, local) and alphabetized
6 changes: 6 additions & 0 deletions changelog.d/fix-api-exports-and-seeds.added.md
@@ -0,0 +1,6 @@
Export `SparseCalibrationWeights` from the top-level `l0` package (closing
the discoverability gap with `SparseL0Linear`), add an optional `seed`
parameter to both `SparseCalibrationWeights` and `SparseL0Linear` so
`log_alpha` / `log_weight` jitter is reproducible without managing
PyTorch's global RNG, and update `CLAUDE.md` to list `calibration.py` and
`sparse.py` as first-class modules.
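A minimal usage sketch of the behaviour this entry describes, assuming only the constructor arguments that appear in the diffs below; the values are illustrative, not part of the package API:

```python
import torch

from l0 import SparseCalibrationWeights, SparseL0Linear  # both top-level exports after this change

# Same seed -> identical initial log_alpha, with no torch.manual_seed() bookkeeping.
a = SparseCalibrationWeights(n_features=50, log_alpha_jitter_sd=0.1, seed=123)
b = SparseCalibrationWeights(n_features=50, log_alpha_jitter_sd=0.1, seed=123)
assert torch.equal(a.log_alpha.data, b.log_alpha.data)

# The same contract holds for the scipy.sparse regression layer.
x = SparseL0Linear(n_features=50, init_keep_prob=0.5, seed=123)
y = SparseL0Linear(n_features=50, init_keep_prob=0.5, seed=123)
assert torch.equal(x.log_alpha.data, y.log_alpha.data)
```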
3 changes: 3 additions & 0 deletions l0/__init__.py
@@ -7,6 +7,7 @@

__version__ = "0.5.0"

from .calibration import SparseCalibrationWeights
from .distributions import HardConcrete
from .gates import FeatureGate, HybridGate, L0Gate, SampleGate
from .layers import (
@@ -39,6 +40,8 @@
"prune_model",
# Sparse
"SparseL0Linear",
# Calibration
"SparseCalibrationWeights",
# Gates
"L0Gate",
"SampleGate",
33 changes: 31 additions & 2 deletions l0/calibration.py
@@ -47,6 +47,12 @@ class SparseCalibrationWeights(nn.Module):
Set to 0 to disable jitter. Default is 0.01 (following Louizos et al.).
device : str or torch.device
Device to run computations on ('cpu' or 'cuda')
seed : int, optional
Seed for the RNG used by the `log_alpha` init jitter and (inside
``fit``) the `log_weight` jitter. When set, two models with the same
inputs produce byte-identical initial `log_alpha` and jitter without
the caller having to manage PyTorch's global RNG. ``None`` (default)
preserves legacy behaviour of using the global RNG.
"""

def __init__(
@@ -61,6 +67,7 @@ def __init__(
log_alpha_jitter_sd: float = 0.01,
device: str | torch.device = "cpu",
use_gates: bool = True,
seed: int | None = None,
):
super().__init__()
self.n_features = n_features
@@ -71,6 +78,17 @@ def __init__(
self.log_alpha_jitter_sd = log_alpha_jitter_sd
self.device = torch.device(device)
self.use_gates = use_gates
self.seed = seed

# Local RNG (only used when `seed` is provided). Kept on the module so
# `fit`'s `log_weight` jitter uses the same deterministic stream.
if seed is not None:
self._generator: torch.Generator | None = torch.Generator(
device=self.device
)
self._generator.manual_seed(int(seed))
else:
self._generator = None

# Initialize weights (on original scale)
if init_weights is None:
@@ -121,7 +139,8 @@ def __init__(
# Add jitter to break symmetry (if specified)
if self.log_alpha_jitter_sd > 0:
jitter = (
torch.randn(n_features, device=self.device) * self.log_alpha_jitter_sd
torch.randn(n_features, generator=self._generator, device=self.device)
* self.log_alpha_jitter_sd
)
self.log_alpha = nn.Parameter(mu + jitter)
else:
@@ -398,7 +417,17 @@ def fit(
# Add jitter to weights to break symmetry (if jitter_sd > 0)
if self.log_weight_jitter_sd > 0:
with torch.no_grad():
jitter = torch.randn_like(self.log_weight) * self.log_weight_jitter_sd
# `torch.randn_like` can't take a generator kwarg, so draw
# explicitly via `torch.randn` when a local RNG is set.
if self._generator is not None:
noise = torch.randn(
self.log_weight.shape,
generator=self._generator,
device=self.device,
)
else:
noise = torch.randn_like(self.log_weight)
jitter = noise * self.log_weight_jitter_sd
self.log_weight.data += jitter

# Setup optimizer
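For reference, a standalone sketch of the generator pattern this file adopts: a local `torch.Generator` when `seed` is given, the global RNG otherwise. The helper name `jittered_noise` is illustrative only; the explicit-shape `torch.randn` call mirrors the workaround in `fit`, since `torch.randn_like` does not accept a `generator` argument.

```python
import torch


def jittered_noise(shape, device="cpu", generator=None):
    """Draw standard-normal noise, honouring a local generator when supplied."""
    if generator is not None:
        # torch.randn_like has no generator kwarg, so the seeded path needs an
        # explicit shape -- the same workaround used in SparseCalibrationWeights.fit.
        return torch.randn(shape, generator=generator, device=device)
    return torch.randn(shape, device=device)  # legacy path: global RNG


gen = torch.Generator(device="cpu")
gen.manual_seed(123)
n1 = jittered_noise((5,), generator=gen)

gen.manual_seed(123)
n2 = jittered_noise((5,), generator=gen)

assert torch.equal(n1, n2)  # deterministic without touching the global RNG
```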
23 changes: 19 additions & 4 deletions l0/sparse.py
@@ -36,6 +36,12 @@ class SparseL0Linear(nn.Module):
Initial probability of keeping each feature
device : str or torch.device
Device to run computations on ('cpu' or 'cuda')
seed : int, optional
Seed for the RNG used by the ``log_alpha`` init. When set, two
models with the same inputs produce byte-identical initial
``log_alpha`` without the caller having to manage PyTorch's
global RNG. ``None`` (default) preserves legacy behaviour of
using the global RNG.
"""

def __init__(
@@ -47,6 +53,7 @@ def __init__(
zeta: float = 1.1,
init_keep_prob: float = 0.5,
device: str | torch.device = "cpu",
seed: int | None = None,
):
super().__init__()
self.n_features = n_features
@@ -55,6 +62,15 @@ def __init__(
self.gamma = gamma
self.zeta = zeta
self.device = torch.device(device)
self.seed = seed

if seed is not None:
self._generator: torch.Generator | None = torch.Generator(
device=self.device
)
self._generator.manual_seed(int(seed))
else:
self._generator = None

# Model parameters
self.weight = nn.Parameter(torch.zeros(n_features, device=self.device))
@@ -63,11 +79,10 @@ def __init__(
else:
self.register_parameter("bias", None)

# L0 gate parameters
# L0 gate parameters: mu + N(0, 0.01) jitter, optionally seeded.
mu = torch.log(torch.tensor(init_keep_prob / (1 - init_keep_prob)))
self.log_alpha = nn.Parameter(
torch.normal(mu.item(), 0.01, size=(n_features,), device=self.device)
)
noise = torch.randn(n_features, generator=self._generator, device=self.device)
self.log_alpha = nn.Parameter(mu.item() + 0.01 * noise)

# Cache for sparse tensor conversion
self._cached_X_torch: torch.sparse.Tensor | None = None
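The init rewrite above is distribution-preserving: `torch.normal(mu, 0.01, size=...)` and `mu + 0.01 * torch.randn(...)` both sample N(mu, 0.01²), but the second form makes it easy to route the draw through the module's local generator. A small self-contained check (sample size and tolerances are illustrative):

```python
import torch

# Logit of init_keep_prob = 0.5, i.e. mu = 0.0.
mu = torch.log(torch.tensor(0.5 / (1 - 0.5)))

gen = torch.Generator().manual_seed(7)
noise = torch.randn(1000, generator=gen)
log_alpha = mu.item() + 0.01 * noise  # new-style init, seedable via `gen`

# Sample statistics match the old torch.normal(mu, 0.01, size=(1000,)) draw.
assert abs(log_alpha.mean().item() - mu.item()) < 1e-2
assert abs(log_alpha.std().item() - 0.01) < 5e-3
```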
24 changes: 24 additions & 0 deletions tests/test_calibration.py
@@ -961,3 +961,27 @@ def test_normalize_groups_false_uniform_weights(self):
f"normalize_groups=False should behave like no groups: "
f"{err_none:.4f} vs {err_no_norm:.4f}"
)

def test_seed_produces_deterministic_log_alpha(self):
"""Two models with the same `seed` must share `log_alpha` init."""
a = SparseCalibrationWeights(n_features=50, log_alpha_jitter_sd=0.1, seed=123)
b = SparseCalibrationWeights(n_features=50, log_alpha_jitter_sd=0.1, seed=123)
torch.testing.assert_close(a.log_alpha.data, b.log_alpha.data)

c = SparseCalibrationWeights(n_features=50, log_alpha_jitter_sd=0.1, seed=456)
assert not torch.allclose(a.log_alpha.data, c.log_alpha.data)

def test_seed_none_uses_global_rng(self):
"""`seed=None` is the legacy behaviour: global RNG, caller-managed."""
torch.manual_seed(0)
a = SparseCalibrationWeights(n_features=20, log_alpha_jitter_sd=0.1)
torch.manual_seed(0)
b = SparseCalibrationWeights(n_features=20, log_alpha_jitter_sd=0.1)
torch.testing.assert_close(a.log_alpha.data, b.log_alpha.data)

def test_sparse_calibration_weights_exported(self):
"""`SparseCalibrationWeights` must be importable from the top-level."""
import l0

assert "SparseCalibrationWeights" in l0.__all__
assert l0.SparseCalibrationWeights is SparseCalibrationWeights
17 changes: 17 additions & 0 deletions tests/test_sparse.py
@@ -224,3 +224,20 @@ def test_deterministic_vs_stochastic(self):
assert not torch.allclose(y_stoch1, y_stoch2), (
"Stochastic predictions should differ"
)

def test_seed_produces_deterministic_log_alpha(self):
"""Two models with the same `seed` must share `log_alpha` init."""
a = SparseL0Linear(n_features=50, init_keep_prob=0.5, seed=123)
b = SparseL0Linear(n_features=50, init_keep_prob=0.5, seed=123)
torch.testing.assert_close(a.log_alpha.data, b.log_alpha.data)

c = SparseL0Linear(n_features=50, init_keep_prob=0.5, seed=456)
assert not torch.allclose(a.log_alpha.data, c.log_alpha.data)

def test_seed_none_uses_global_rng(self):
"""`seed=None` is the legacy behaviour: global RNG, caller-managed."""
torch.manual_seed(0)
a = SparseL0Linear(n_features=20, init_keep_prob=0.5)
torch.manual_seed(0)
b = SparseL0Linear(n_features=20, init_keep_prob=0.5)
torch.testing.assert_close(a.log_alpha.data, b.log_alpha.data)