
Standardize Hard Concrete numerical guards across modules #42

Merged
MaxGhenis merged 1 commit into main from fix/numerical-guards on Apr 17, 2026

Conversation

@MaxGhenis
Contributor

Summary

  • Fix the distributions.py uniform-noise epsilon: it was 1e-8, which underflows to zero at fp16; it is now 1e-6, matching calibration.py and sparse.py.
  • Add a LOG_ALPHA_BOUND = 20.0 clamp on qz_logits/log_alpha across all three modules in the sampling, deterministic, and penalty paths. This prevents sigmoid saturation, matches the Louizos et al. 2017 recommendation, and keeps the math well-defined at fp16/bf16.
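The guarded paths can be sketched as follows. This is a pure-NumPy illustration of the Hard Concrete math with both fixes applied (the real modules are PyTorch `nn.Module`s); the function names, `GAMMA`/`ZETA`/`BETA` values, and `sample_gates`/`deterministic_gates`/`l0_penalty` signatures are illustrative assumptions, not the repository's actual API.

```python
import numpy as np

LOG_ALPHA_BOUND = 20.0   # clamp bound standardized by this PR
UNIFORM_EPS = 1e-6       # fp16-safe noise floor (was 1e-8 in distributions.py)
GAMMA, ZETA = -0.1, 1.1  # stretch interval from Louizos et al. 2017
BETA = 2.0 / 3.0         # temperature

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sample_gates(log_alpha, rng):
    # Clamp before the sigmoid so extreme logits cannot saturate it.
    la = np.clip(log_alpha, -LOG_ALPHA_BOUND, LOG_ALPHA_BOUND)
    # Uniform noise floored at 1e-6: 1e-8 underflows to zero at fp16,
    # and log(0) = -inf would poison the gates.
    u = rng.uniform(UNIFORM_EPS, 1.0 - UNIFORM_EPS, size=np.shape(la))
    s = sigmoid((np.log(u) - np.log1p(-u) + la) / BETA)
    return np.clip(s * (ZETA - GAMMA) + GAMMA, 0.0, 1.0)

def deterministic_gates(log_alpha):
    la = np.clip(log_alpha, -LOG_ALPHA_BOUND, LOG_ALPHA_BOUND)
    return np.clip(sigmoid(la) * (ZETA - GAMMA) + GAMMA, 0.0, 1.0)

def l0_penalty(log_alpha):
    la = np.clip(log_alpha, -LOG_ALPHA_BOUND, LOG_ALPHA_BOUND)
    return sigmoid(la - BETA * np.log(-GAMMA / ZETA)).mean()
```

With the clamp in place, even logits of ±1e3 reach the sigmoid as ±20, so all three paths stay finite and in range.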

Addresses bug-hunt findings

Test plan

  • uv run pytest tests -x -q passes (87 passed, 1 skipped; previously 83).
  • Adds test_extreme_logits_stay_finite in test_distributions.py and test_extreme_log_alpha_stays_finite in test_calibration.py and test_sparse.py — fill logits with ±1e3, confirm sampled/deterministic/penalty output stays finite and in [0, 1].
  • Adds test_uniform_eps_is_fp16_safe to pin the standardized 1e-6 floor.

Contributor Author

@MaxGhenis MaxGhenis left a comment

LGTM.

Verified:

  • Epsilon standardized to 1e-6 across distributions.py, calibration.py, sparse.py — no more 1e-8 in distributions.py.
  • LOG_ALPHA_BOUND = 20.0 class constant added to all three modules with clamp applied in every path that feeds sigmoid: _sample_gates, _deterministic_gates / get_deterministic_gates, get_penalty / get_l0_penalty, and get_active_prob in distributions.py. Symmetric across modules.
  • fp16 stress test passes locally: HardConcrete(...).half() with qz_logits=100.0 produces finite sampled gates, eval gates, and penalty.
  • Regression tests: test_extreme_logits_stay_finite (distributions), test_extreme_log_alpha_stays_finite (calibration, sparse), test_uniform_eps_is_fp16_safe pins the standardized floor.
  • 44 tests pass locally (87 total in PR + pre-existing).
  • Rebinding the clamp result to local logits/log_alpha before each sigmoid is clean; it doesn't mutate the underlying nn.Parameter.

Tiny nit (not blocking): in calibration.py and sparse.py, eps = 1e-6 is still a local variable in _sample_gates rather than a class constant like _UNIFORM_EPS in distributions.py. Consider unifying in a follow-up.
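For the follow-up, the suggested unification could look like this minimal sketch. The class name and method here are hypothetical stand-ins for the real modules; the point is only that the floor lives in one class-level constant instead of a per-method literal.

```python
class HardConcreteGates:
    # Shared fp16-safe noise floor, mirroring _UNIFORM_EPS in
    # distributions.py instead of a local `eps = 1e-6` in _sample_gates.
    _UNIFORM_EPS = 1e-6

    def _noise_bounds(self):
        # All sampling paths read the same constant, so the floor is
        # pinned in exactly one place per module.
        return self._UNIFORM_EPS, 1.0 - self._UNIFORM_EPS
```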

Will go green once #44 merges.

Three modules implement the same Hard Concrete math but disagreed on their
numerical guards:

- `distributions.py:109` sampled uniform noise from `[1e-8, 1 - 1e-8]`. At
  fp16/bf16 `1e-8` underflows to zero and `log(0) = -inf` produces `NaN`
  gates; at fp32 with `temperature=0.1` the log-odds can reach ~-184 before
  the stretch, which saturates downstream sigmoids.
- `calibration.py:171` and `sparse.py:115` already used the fp16-safe
  `1e-6`; this PR brings `distributions.py` in line.
- None of the three modules clamped `qz_logits`/`log_alpha`. With
  `init_mean=0.999` the initial logit is ~6.9, and optimizer updates can
  easily push it past 30, at which point `sigmoid(log_alpha + c)` saturates
  and gradients vanish (the paper explicitly recommends clamping).
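The two failure modes above can be reproduced in a few lines of NumPy; the `0.1` temperature is the value quoted in the bullet, used here only to show the ~-184 figure.

```python
import numpy as np

# float16's smallest positive subnormal is 2**-24 (about 5.96e-8),
# so casting 1e-8 rounds to zero and log(0) = -inf poisons the gates.
assert np.float16(1e-8) == 0.0   # underflows
assert np.float16(1e-6) > 0.0    # the standardized floor survives the cast

# At fp32 with temperature 0.1, the log-odds of the old 1e-8 floor
# before the stretch:
eps = 1e-8
log_odds = (np.log(eps) - np.log1p(-eps)) / 0.1
print(log_odds)  # approximately -184, deep in sigmoid saturation
```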

Adds a shared `LOG_ALPHA_BOUND = 20.0` class constant to all three modules
and clamps `log_alpha`/`qz_logits` before the sigmoid in the sampling,
deterministic, and penalty paths.

Regression tests in `test_distributions.py`, `test_calibration.py`, and
`test_sparse.py` fill `log_alpha`/`qz_logits` with `±1e3` and assert that
sampled gates, deterministic gates, and the L0 penalty all stay finite and
in `[0, 1]`.
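The shape of those regression tests can be illustrated with a pure-NumPy stand-in; the real tests exercise the PyTorch modules, and the gate function below only assumes the guarded deterministic path described above.

```python
import numpy as np

LOG_ALPHA_BOUND = 20.0
GAMMA, ZETA = -0.1, 1.1  # stretch interval from Louizos et al. 2017

def deterministic_gates(log_alpha):
    # Clamp before the sigmoid, as the PR does in every path.
    la = np.clip(log_alpha, -LOG_ALPHA_BOUND, LOG_ALPHA_BOUND)
    s = 1.0 / (1.0 + np.exp(-la))
    return np.clip(s * (ZETA - GAMMA) + GAMMA, 0.0, 1.0)

def test_extreme_log_alpha_stays_finite():
    # Fill with +/-1e3 and assert finite gates in [0, 1].
    for sign in (1.0, -1.0):
        gates = deterministic_gates(np.full(16, sign * 1e3))
        assert np.isfinite(gates).all()
        assert ((gates >= 0.0) & (gates <= 1.0)).all()

test_extreme_log_alpha_stays_finite()
```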

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@MaxGhenis MaxGhenis force-pushed the fix/numerical-guards branch from 7c6e94b to 5a55733 on April 17, 2026 16:28
@MaxGhenis MaxGhenis merged commit 0c78b10 into main Apr 17, 2026
2 of 3 checks passed
@MaxGhenis MaxGhenis deleted the fix/numerical-guards branch April 17, 2026 16:28
