Export SparseCalibrationWeights and add optional seed parameter#43
Merged
Export SparseCalibrationWeights and add optional seed parameter#43
Conversation
- `SparseCalibrationWeights` was documented in the README and in `CRITICAL_TEMPERATURE_BUG.md` as the temperature-correct path for calibration workflows, but it was not re-exported from `l0/__init__.py`. `SparseL0Linear` was already exported, so the asymmetry made `SparseCalibrationWeights` noticeably harder to discover. Adds it to the package's `__all__`. - `SparseCalibrationWeights.__init__` and `SparseL0Linear.__init__` drew their `log_alpha` jitter from PyTorch's global RNG, so two constructions with the same inputs were only deterministic if the caller remembered to `torch.manual_seed(...)` beforehand. Adds an optional `seed: int | None` kwarg to both classes; when set, the jitter (and `log_weight` jitter inside `SparseCalibrationWeights.fit`) draws from a local `torch.Generator`. `seed=None` preserves legacy behaviour. - Updates `CLAUDE.md` to list `calibration.py` and `sparse.py` as first-class modules, fixes the stale `black -l 79` formatter hint, and adds the matching test files. Adds seed/export regression tests in `test_calibration.py` and `test_sparse.py`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2 tasks
MaxGhenis
commented
Apr 17, 2026
Contributor
Author
MaxGhenis
left a comment
There was a problem hiding this comment.
LGTM.
Verified:
l0/__init__.py:SparseCalibrationWeightsadded to__all__alongsideSparseL0Linear(previously only the latter was exported). Import works:from l0 import SparseCalibrationWeightssucceeds;l0.SparseCalibrationWeights is SparseCalibrationWeights.seed: int | None = Noneon bothSparseCalibrationWeights.__init__andSparseL0Linear.__init__, with matching docstring entries.- Default
Nonepreserves legacy global-RNG behaviour —test_seed_none_uses_global_rngpins this withtorch.manual_seed(0)on both sides. - When
seedis set, a localtorch.Generatoris created (on the module's device) and used consistently in both init jitter and (forSparseCalibrationWeights) thefitlog_weightjitter. Therandn_like-has-no-generator edge case is handled explicitly viatorch.randnwith shape + generator. SparseL0Linear.__init__also rewritestorch.normal(mu.item(), 0.01, size=...)tomu.item() + 0.01 * torch.randn(..., generator=...)so the seeded path has a single consistent RNG — nice tidy.CLAUDE.mdnow listscalibration.py/sparse.pyas first-class modules, updates the package tree, addstest_calibration.py/test_sparse.py, and replaces the staleblack -l 79formatter hint withruff format. Subsumes finding #5's doc staleness.- 28 tests pass locally.
Will go green once #44 merges.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
SparseCalibrationWeightsfrom the top-levell0package — it was already documented in the README and inCRITICAL_TEMPERATURE_BUG.mdas the temperature-correct path for calibration but was missing from__all__.SparseL0Linearwas already exported.seed: int | Nonekwarg toSparseCalibrationWeightsandSparseL0Linear. When set,log_alphajitter (andlog_weightjitter insideSparseCalibrationWeights.fit) draws from a localtorch.Generator;seed=Nonepreserves the legacy global-RNG behaviour.CLAUDE.mdto listcalibration.py/sparse.pyas first-class modules, update test file layout, and replace the staleblack -l 79formatter hint with the currentruff formatinvocation.Addresses bug-hunt findings
SparseCalibrationWeightsmissing froml0/__init__.py::__all__andCLAUDE.mdlisting stale module set.seedparameter onSparseCalibrationWeights.__init__/SparseL0Linear.__init__; jitter relied on the global RNG.Test plan
uv run pytest tests -x -qpasses (88 passed, 1 skipped; previously 83).test_sparse_calibration_weights_exportedpins"SparseCalibrationWeights" in l0.__all__.test_seed_produces_deterministic_log_alphaandtest_seed_none_uses_global_rngverify both RNG paths intest_calibration.pyandtest_sparse.py.