From 10c8356ea8f8ccda9da863976b2865816894842d Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 17 Apr 2026 08:25:56 -0400 Subject: [PATCH] Fix temperature scaling in HardConcrete deterministic gates `HardConcrete._deterministic_gates` computed `sigmoid(qz_logits)` without dividing by `temperature`, so `.eval()` output silently ignored the temperature parameter. This contradicted `_sample_gates`, `get_penalty`, and `get_active_prob`, which all apply the temperature scaling. For PolicyEngine's default `temperature=0.25` this was a 4x distortion in log-odds space and broke train/eval consistency for every L0Linear / L0Conv2d / L0DepthwiseConv2d / L0Gate user. TemperatureScheduler updates were also silently dropped at eval time. The fix uses `sigmoid(qz_logits / temperature)`, matching Louizos et al. 2017 Eq. 11 and the temperature-correct implementations in `SparseCalibrationWeights` (calibration.py) and `SparseL0Linear` (sparse.py). Adds three regression tests: - `test_deterministic_gates_respect_temperature` fixes identical logits, varies temperature, and asserts eval output differs. Confirmed to FAIL on the pre-fix code (`AssertionError: Deterministic gates should depend on temperature`). - `test_deterministic_gates_match_reference_formula` pins the closed-form output against `sigmoid(log_alpha / beta) * (zeta - gamma) + gamma`. - `test_sparsity_stats_match_eval_activation` checks that `get_sparsity()` (temperature-aware) stays consistent with the fraction of gates actually non-zero at eval time. Co-Authored-By: Claude Opus 4.7 (1M context) --- ...x-temperature-deterministic-gates.fixed.md | 6 ++ l0/distributions.py | 9 +- tests/test_distributions.py | 89 +++++++++++++++++++ 3 files changed, 102 insertions(+), 2 deletions(-) create mode 100644 changelog.d/fix-temperature-deterministic-gates.fixed.md diff --git a/changelog.d/fix-temperature-deterministic-gates.fixed.md b/changelog.d/fix-temperature-deterministic-gates.fixed.md new file mode 100644 index 0000000..717d32a --- /dev/null +++ b/changelog.d/fix-temperature-deterministic-gates.fixed.md @@ -0,0 +1,6 @@ +Fix `HardConcrete._deterministic_gates` to apply the temperature scaling so +that `eval()` output matches `_sample_gates`, `get_penalty`, and +`get_active_prob`. Previously the deterministic branch used +`sigmoid(qz_logits)` without dividing by `temperature`, producing a 4x +distortion for PolicyEngine's default `temperature=0.25` and silently +ignoring `TemperatureScheduler` updates at eval time. diff --git a/l0/distributions.py b/l0/distributions.py index 40c5c98..9a03e35 100644 --- a/l0/distributions.py +++ b/l0/distributions.py @@ -123,13 +123,18 @@ def _deterministic_gates(self) -> torch.Tensor: """ Compute deterministic gates for evaluation. + Uses the mean of the Hard Concrete distribution, which applies the + temperature scaling to ``qz_logits`` so that ``eval()`` output is + consistent with ``_sample_gates`` and ``get_penalty`` / + ``get_active_prob`` (see Louizos et al. 2017, Eq. 11). + Returns ------- torch.Tensor Deterministic gate values in [0, 1] """ - # Use mean of the distribution - probs = torch.sigmoid(self.qz_logits) + # Mean of the binary concrete before stretch: sigmoid(logits / beta). + probs = torch.sigmoid(self.qz_logits / self.temperature) # Apply stretching transformation gates = probs * (self.zeta - self.gamma) + self.gamma diff --git a/tests/test_distributions.py b/tests/test_distributions.py index bfac63b..92aea6d 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -191,3 +191,92 @@ def test_zero_gradient_in_eval(self, basic_gate): # In eval mode, gates should be deterministic # but gradients can still flow through assert basic_gate.qz_logits.grad is not None + + def test_deterministic_gates_respect_temperature(self): + """Regression test for the temperature bug in ``_deterministic_gates``. + + Previously ``_deterministic_gates`` computed ``sigmoid(qz_logits)`` + without the ``/ temperature`` scaling, so two models with identical + ``qz_logits`` but different temperatures produced identical eval + gates. The fix matches Louizos et al. 2017 Eq. 11: the mean of the + binary concrete before stretch is ``sigmoid(log_alpha / beta)``. + """ + high_temp = HardConcrete(10, temperature=2.0, init_mean=0.5) + low_temp = HardConcrete(10, temperature=0.1, init_mean=0.5) + + # Install identical, non-zero logits so temperature scaling matters. + fixed_logits = torch.linspace(-3.0, 3.0, 10) + with torch.no_grad(): + high_temp.qz_logits.copy_(fixed_logits) + low_temp.qz_logits.copy_(fixed_logits) + + high_temp.eval() + low_temp.eval() + + high_gates = high_temp() + low_gates = low_temp() + + # With the bug these would be bit-identical. After the fix they must + # differ everywhere the logits are non-zero. + assert not torch.allclose(high_gates, low_gates, atol=1e-6), ( + "Deterministic gates should depend on temperature" + ) + + # Spot-check the closed form of the fix: sigmoid(logits / beta) * (zeta - gamma) + gamma + expected_low = torch.clamp( + torch.sigmoid(fixed_logits / 0.1) * (low_temp.zeta - low_temp.gamma) + + low_temp.gamma, + 0, + 1, + ) + assert torch.allclose(low_gates, expected_low, atol=1e-6) + + def test_deterministic_gates_match_reference_formula(self): + """``_deterministic_gates`` must follow ``sigmoid(log_alpha / beta)``. + + This is the mean of the stretched binary concrete (Louizos et al. 2017 + Eq. 11) and is what the standalone ``SparseCalibrationWeights`` / + ``SparseL0Linear`` modules implement. Any regression that drops the + temperature from this branch makes the three modules inconsistent. + """ + gate = HardConcrete(6, temperature=0.4, init_mean=0.5) + with torch.no_grad(): + gate.qz_logits.copy_(torch.linspace(-2.0, 2.0, 6)) + gate.eval() + + deterministic = gate() + + expected = torch.clamp( + torch.sigmoid(gate.qz_logits / gate.temperature) + * (gate.zeta - gate.gamma) + + gate.gamma, + 0, + 1, + ) + assert torch.allclose(deterministic, expected, atol=1e-6) + + def test_sparsity_stats_match_eval_activation(self): + """``get_sparsity`` should agree with the fraction of active eval gates. + + ``get_sparsity`` uses ``get_active_prob`` (which applies the + temperature shift), while ``_deterministic_gates`` now applies the + same temperature scaling. With the fix the two should be numerically + consistent: the reported sparsity cannot diverge wildly from the + fraction of gates that are effectively zero at eval time. + """ + torch.manual_seed(0) + gate = HardConcrete(200, temperature=0.25, init_mean=0.5) + with torch.no_grad(): + gate.qz_logits.copy_(torch.randn(200)) + + reported_sparsity = gate.get_sparsity() + + gate.eval() + eval_gates = gate() + # Active at eval-time = gate value strictly above zero. + eval_active_fraction = (eval_gates > 0).float().mean().item() + eval_sparsity = 1.0 - eval_active_fraction + + # Both numbers come from the same temperature-aware distribution + # so they must be close (within sampling-free rounding slack). + assert abs(reported_sparsity - eval_sparsity) < 0.15