Skip to content

Fix temperature scaling in HardConcrete deterministic gates#41

Merged
MaxGhenis merged 1 commit intomainfrom
fix/temperature-deterministic-gates
Apr 17, 2026
Merged

Fix temperature scaling in HardConcrete deterministic gates#41
MaxGhenis merged 1 commit intomainfrom
fix/temperature-deterministic-gates

Conversation

@MaxGhenis
Copy link
Copy Markdown
Contributor

Summary

  • HardConcrete._deterministic_gates was computing sigmoid(qz_logits) without the / temperature scaling, so .eval() output silently ignored the temperature parameter. For PolicyEngine's default temperature=0.25 this was a 4x distortion in log-odds space and broke train/eval consistency for every L0Linear / L0Conv2d / L0DepthwiseConv2d / L0Gate caller, and TemperatureScheduler updates were dropped at eval time.
  • The fix uses sigmoid(qz_logits / temperature), matching Louizos et al. 2017 Eq. 11 and the temperature-correct implementations in SparseCalibrationWeights (calibration.py) and SparseL0Linear (sparse.py).
  • Adds three regression tests.

Addresses bug-hunt findings

Test plan

  • uv run pytest tests -x -q passes (86 passed, 1 skipped; previously 83).
  • Manually reverted l0/distributions.py and re-ran test_deterministic_gates_respect_temperature — fails with AssertionError: Deterministic gates should depend on temperature.
  • test_deterministic_gates_match_reference_formula pins the closed form against SparseCalibrationWeights / SparseL0Linear.
  • test_sparsity_stats_match_eval_activation verifies get_sparsity_stats agrees with actual eval-time activation rate.

`HardConcrete._deterministic_gates` computed `sigmoid(qz_logits)` without
dividing by `temperature`, so `.eval()` output silently ignored the
temperature parameter. This contradicted `_sample_gates`, `get_penalty`, and
`get_active_prob`, which all apply the temperature scaling. For PolicyEngine's
default `temperature=0.25` this was a 4x distortion in log-odds space and
broke train/eval consistency for every L0Linear / L0Conv2d / L0DepthwiseConv2d
/ L0Gate user. TemperatureScheduler updates were also silently dropped at
eval time.

The fix uses `sigmoid(qz_logits / temperature)`, matching Louizos et al.
2017 Eq. 11 and the temperature-correct implementations in
`SparseCalibrationWeights` (calibration.py) and `SparseL0Linear` (sparse.py).

Adds three regression tests:

- `test_deterministic_gates_respect_temperature` fixes identical logits,
  varies temperature, and asserts eval output differs. Confirmed to FAIL on
  the pre-fix code (`AssertionError: Deterministic gates should depend on
  temperature`).
- `test_deterministic_gates_match_reference_formula` pins the closed-form
  output against `sigmoid(log_alpha / beta) * (zeta - gamma) + gamma`.
- `test_sparsity_stats_match_eval_activation` checks that `get_sparsity()`
  (temperature-aware) stays consistent with the fraction of gates actually
  non-zero at eval time.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
MaxGhenis added a commit that referenced this pull request Apr 17, 2026
This file documented a temperature-scaling bug in
`HardConcrete._deterministic_gates` and pointed readers at a "gold
standard" reference file `l0_louizos_improved_gate.py` that does not
actually exist in the repository. The bug itself is fixed by
#41 (the `sigmoid(qz_logits /
temperature)` change in `l0/distributions.py`), and the three standalone
modules (`distributions.py`, `calibration.py`, `sparse.py`) are now
consistent. Remove the stale doc now that its content is either obsolete
or misleading.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Copy link
Copy Markdown
Contributor Author

@MaxGhenis MaxGhenis left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM — the critical fix is exactly right.

Verified:

  • l0/distributions.py:137 changes torch.sigmoid(self.qz_logits) to torch.sigmoid(self.qz_logits / self.temperature) — matches Louizos et al. 2017 Eq. 11 and the standalone SparseCalibrationWeights/SparseL0Linear closed form.
  • Only the _deterministic_gates mean is changed; _sample_gates, get_penalty, get_active_prob were already temperature-aware and are untouched (no over-reach).
  • Three regression tests are load-bearing:
    • test_deterministic_gates_respect_temperature — confirmed to FAIL on the pre-fix impl: I monkeypatched _deterministic_gates back to sigmoid(qz_logits) and observed torch.allclose(high_gates, low_gates) == True, which violates the test's not torch.allclose assertion.
    • test_deterministic_gates_match_reference_formula — pins the closed form (the fix's Eq. 11).
    • test_sparsity_stats_match_eval_activation — subsumes finding #3 (reported vs actual sparsity agreement within 0.15).
  • Findings #1 (critical), #2 (test gap), #3 (within-class inconsistency) all addressed.
  • No unrelated changes; just the 5-line code fix + 89-line test block + changelog.

Will go green once #44 merges and lint unblocks.

MaxGhenis added a commit that referenced this pull request Apr 17, 2026
* Delete CRITICAL_TEMPERATURE_BUG.md

This file documented a temperature-scaling bug in
`HardConcrete._deterministic_gates` and pointed readers at a "gold
standard" reference file `l0_louizos_improved_gate.py` that does not
actually exist in the repository. The bug itself is fixed by
#41 (the `sigmoid(qz_logits /
temperature)` change in `l0/distributions.py`), and the three standalone
modules (`distributions.py`, `calibration.py`, `sparse.py`) are now
consistent. Remove the stale doc now that its content is either obsolete
or misleading.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Switch CI lint job from black to ruff

#40 switched the repo's formatter from `black -l 79` to `ruff format`
(default 88-char line length) and updated the Makefile, but the reusable
lint workflow still invoked `lgeiger/black-action` with `. -l 79 --check`.
Since ruff-formatted files don't pass `black -l 79`, every PR's `lint`
check has been failing since #40.

Replace the black action with a `uvx ruff format --check .` run, matching
what `make format-check` would do locally.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@MaxGhenis MaxGhenis merged commit 845dd6b into main Apr 17, 2026
2 of 3 checks passed
@MaxGhenis MaxGhenis deleted the fix/temperature-deterministic-gates branch April 17, 2026 16:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant