Standardize Hard Concrete numerical guards across modules (#42)
Merged
MaxGhenis (Contributor, Author) commented on Apr 17, 2026:
LGTM. Verified:

- Epsilon standardized to `1e-6` across `distributions.py`, `calibration.py`, and `sparse.py` — no more `1e-8` in `distributions.py`.
- `LOG_ALPHA_BOUND = 20.0` class constant added to all three modules, with the clamp applied in every path that feeds a sigmoid: `_sample_gates`, `_deterministic_gates`/`get_deterministic_gates`, `get_penalty`/`get_l0_penalty`, and `get_active_prob` in `distributions.py`. Symmetric across modules.
- fp16 stress test passes locally: `HardConcrete(...).half()` with `qz_logits=100.0` produces finite sampled gates, eval gates, and penalty.
- Regression tests: `test_extreme_logits_stay_finite` (distributions) and `test_extreme_log_alpha_stays_finite` (calibration, sparse); `test_uniform_eps_is_fp16_safe` pins the standardized floor.
- 44 tests pass locally (87 total in PR + pre-existing).
- The clamp rebinds to a local `logits`/`log_alpha` before each sigmoid, which is clean — it doesn't mutate the `nn.Parameter`.
Tiny nit (not blocking): in `calibration.py` and `sparse.py`, `eps = 1e-6` is still a local variable in `_sample_gates` rather than a class constant like `_UNIFORM_EPS` in `distributions.py`. Consider unifying in a follow-up.
Will go green once #44 merges.
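The clamp-before-every-sigmoid symmetry the review verifies can be sketched in plain Python. This is an illustrative sketch, not this repo's code: the stretch limits `(-0.1, 1.1)`, the `temperature` default, and the function names are the standard Hard Concrete choices from Louizos et al. 2017, assumed here.

```python
import math

LOG_ALPHA_BOUND = 20.0        # shared clamp bound introduced by this PR
LIMIT_A, LIMIT_B = -0.1, 1.1  # standard Hard Concrete stretch interval (assumed)

def _sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def deterministic_gate(log_alpha: float) -> float:
    """Eval-time gate: no noise, but the same clamp before the sigmoid."""
    # Rebind a local, mirroring the review's point: the stored parameter
    # is never mutated, only the value fed into the sigmoid is bounded.
    log_alpha = max(-LOG_ALPHA_BOUND, min(LOG_ALPHA_BOUND, log_alpha))
    s = _sigmoid(log_alpha)
    # Stretch to (LIMIT_A, LIMIT_B), then hard-clip to [0, 1].
    return min(1.0, max(0.0, s * (LIMIT_B - LIMIT_A) + LIMIT_A))

def l0_penalty(log_alpha: float, temperature: float = 0.5) -> float:
    """P(gate != 0), clamped the same way so the penalty path cannot saturate."""
    log_alpha = max(-LOG_ALPHA_BOUND, min(LOG_ALPHA_BOUND, log_alpha))
    return _sigmoid(log_alpha - temperature * math.log(-LIMIT_A / LIMIT_B))
```

With the clamp in both paths, even a logit of `±1e3` yields finite gates and a penalty in `[0, 1]` instead of a saturated sigmoid with vanished gradients.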
Three modules implement the same Hard Concrete math but disagreed on their numerical guards:

- `distributions.py:109` sampled uniform noise from `[1e-8, 1 - 1e-8]`. At fp16/bf16, `1e-8` underflows to zero and `log(0) = -inf` produces `NaN` gates; at fp32 with `temperature=0.1`, the log-odds can reach ~-184 before the stretch, which saturates downstream sigmoids.
- `calibration.py:171` and `sparse.py:115` already used the fp16-safe `1e-6`; this PR brings `distributions.py` in line.
- None of the three modules clamped `qz_logits`/`log_alpha`. With `init_mean=0.999` the initial logit is ~6.9, and optimizer updates can easily push it past 30, at which point `sigmoid(log_alpha + c)` saturates and gradients vanish (the paper explicitly recommends clamping).

Adds a shared `LOG_ALPHA_BOUND = 20.0` class constant to all three modules and clamps `log_alpha`/`qz_logits` before the sigmoid in the sampling, deterministic, and penalty paths.

Regression tests in `test_distributions.py`, `test_calibration.py`, and `test_sparse.py` fill `log_alpha`/`qz_logits` with `±1e3` and assert that sampled gates, deterministic gates, and the L0 penalty all stay finite and in `[0, 1]`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
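The two guards can be combined in a minimal sampling sketch, in plain Python rather than torch. Names, the `temperature` default, and the stretch limits are the usual Hard Concrete choices, assumed for illustration — not this repo's actual API:

```python
import math
import random

LOG_ALPHA_BOUND = 20.0  # clamp bound added by this PR
UNIFORM_EPS = 1e-6      # fp16-safe noise floor; 1e-8 underflows to 0 in fp16

def sample_gate(log_alpha, temperature=0.5, limit_a=-0.1, limit_b=1.1):
    """One Hard Concrete gate sample with both numerical guards applied."""
    # Guard 1: clamp the logit so the sigmoid cannot saturate and the
    # penalty/gradient paths stay well-behaved.
    log_alpha = max(-LOG_ALPHA_BOUND, min(LOG_ALPHA_BOUND, log_alpha))
    # Guard 2: keep the uniform noise away from {0, 1} so both logs stay finite.
    u = random.uniform(UNIFORM_EPS, 1.0 - UNIFORM_EPS)
    z = (math.log(u) - math.log(1.0 - u) + log_alpha) / temperature
    s = 1.0 / (1.0 + math.exp(-z))
    # Stretch to (limit_a, limit_b), then hard-clip to [0, 1].
    return min(1.0, max(0.0, s * (limit_b - limit_a) + limit_a))
```

Without Guard 2, `u` could round to 0 in half precision and `math.log(u)` would blow up; without Guard 1, a logit pushed to `±1e3` by the optimizer would pin the sigmoid at exactly 0 or 1.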
Force-pushed from 7c6e94b to 5a55733.
Summary

- `distributions.py` uniform-noise epsilon: was `1e-8` (underflows at fp16), now `1e-6` to match `calibration.py` and `sparse.py`.
- `LOG_ALPHA_BOUND = 20.0` clamp on `qz_logits`/`log_alpha` across all three modules in the sampling, deterministic, and penalty paths. This prevents `sigmoid` saturation, matches the Louizos et al. 2017 recommendation, and keeps the math well-defined for fp16/bf16.
- Addresses bug-hunt findings: inconsistent epsilon (`1e-8` vs `1e-6`) and no `log_alpha` clamp.

Test plan

- `uv run pytest tests -x -q` passes (87 passed, 1 skipped; previously 83).
- `test_extreme_logits_stay_finite` in `test_distributions.py` and `test_extreme_log_alpha_stays_finite` in `test_calibration.py` and `test_sparse.py` — fill logits with `±1e3`, confirm sampled/deterministic/penalty output stays finite and in `[0, 1]`.
- `test_uniform_eps_is_fp16_safe` pins the standardized `1e-6` floor.
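The fp16 claim behind the `1e-6` floor can be checked without torch: Python's `struct` module supports the IEEE 754 half-precision format (`'e'`), so a round-trip through it shows exactly why `1e-8` is unsafe while `1e-6` survives. This is a standalone demonstration, not the repo's test.

```python
import struct

def to_fp16(x: float) -> float:
    """Round-trip a float through IEEE 754 half precision (struct 'e' format)."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

# The smallest fp16 subnormal is 2**-24 ~= 5.96e-8. 1e-8 is below half of
# that, so it rounds to exactly 0.0 -- and log(0) = -inf then poisons the
# sampled gates.
print(to_fp16(1e-8))  # 0.0
# 1e-6 lands on a representable fp16 subnormal (~1.01e-6), so log(eps) is
# finite and the noise stays strictly inside (0, 1).
print(to_fp16(1e-6))  # ~1.013e-06, nonzero
```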