Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions changelog.d/fix-temperature-deterministic-gates.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Fix `HardConcrete._deterministic_gates` to apply the temperature scaling so
that `eval()` output matches `_sample_gates`, `get_penalty`, and
`get_active_prob`. Previously the deterministic branch used
`sigmoid(qz_logits)` without dividing by `temperature`, producing a 4x
distortion for PolicyEngine's default `temperature=0.25` and silently
ignoring `TemperatureScheduler` updates at eval time.
9 changes: 7 additions & 2 deletions l0/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,18 @@ def _deterministic_gates(self) -> torch.Tensor:
"""
Compute deterministic gates for evaluation.

Uses the mean of the Hard Concrete distribution, which applies the
temperature scaling to ``qz_logits`` so that ``eval()`` output is
consistent with ``_sample_gates`` and ``get_penalty`` /
``get_active_prob`` (see Louizos et al. 2017, Eq. 11).

Returns
-------
torch.Tensor
Deterministic gate values in [0, 1]
"""
# Use mean of the distribution
probs = torch.sigmoid(self.qz_logits)
# Mean of the binary concrete before stretch: sigmoid(logits / beta).
probs = torch.sigmoid(self.qz_logits / self.temperature)

# Apply stretching transformation
gates = probs * (self.zeta - self.gamma) + self.gamma
Expand Down
89 changes: 89 additions & 0 deletions tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,92 @@ def test_zero_gradient_in_eval(self, basic_gate):
# In eval mode, gates should be deterministic
# but gradients can still flow through
assert basic_gate.qz_logits.grad is not None

def test_deterministic_gates_respect_temperature(self):
    """Regression test for the temperature bug in ``_deterministic_gates``.

    The broken implementation evaluated ``sigmoid(qz_logits)`` with no
    ``/ temperature`` division, so identical logits under different
    temperatures collapsed to identical eval gates. The corrected code
    follows Louizos et al. 2017 Eq. 11: the pre-stretch mean of the
    binary concrete is ``sigmoid(log_alpha / beta)``.
    """
    # Two gates that differ only in temperature.
    hot = HardConcrete(10, temperature=2.0, init_mean=0.5)
    cold = HardConcrete(10, temperature=0.1, init_mean=0.5)

    # Identical non-zero logits, so the temperature division is observable.
    shared_logits = torch.linspace(-3.0, 3.0, 10)
    with torch.no_grad():
        hot.qz_logits.copy_(shared_logits)
        cold.qz_logits.copy_(shared_logits)

    hot.eval()
    cold.eval()

    hot_gates = hot()
    cold_gates = cold()

    # Under the bug both tensors were bit-identical; with the fix they
    # must differ wherever a logit is non-zero.
    assert not torch.allclose(hot_gates, cold_gates, atol=1e-6), (
        "Deterministic gates should depend on temperature"
    )

    # Closed-form check of the fixed branch:
    # clamp(sigmoid(logits / beta) * (zeta - gamma) + gamma, 0, 1).
    stretched = (
        torch.sigmoid(shared_logits / 0.1) * (cold.zeta - cold.gamma)
        + cold.gamma
    )
    expected_cold = torch.clamp(stretched, 0, 1)
    assert torch.allclose(cold_gates, expected_cold, atol=1e-6)

def test_deterministic_gates_match_reference_formula(self):
    """``_deterministic_gates`` must follow ``sigmoid(log_alpha / beta)``.

    That expression is the mean of the stretched binary concrete
    (Louizos et al. 2017 Eq. 11), also implemented by the standalone
    ``SparseCalibrationWeights`` / ``SparseL0Linear`` modules. Dropping
    the temperature from this branch would desynchronize the three
    modules, so we pin the formula here.
    """
    gate = HardConcrete(6, temperature=0.4, init_mean=0.5)
    with torch.no_grad():
        gate.qz_logits.copy_(torch.linspace(-2.0, 2.0, 6))
    gate.eval()

    deterministic = gate()

    # Reference: stretch the temperature-scaled sigmoid, then hard-clip.
    probs = torch.sigmoid(gate.qz_logits / gate.temperature)
    expected = (probs * (gate.zeta - gate.gamma) + gate.gamma).clamp(0, 1)
    assert torch.allclose(deterministic, expected, atol=1e-6)

def test_sparsity_stats_match_eval_activation(self):
    """``get_sparsity`` should agree with the fraction of active eval gates.

    ``get_sparsity`` is derived from ``get_active_prob`` (which applies
    the temperature shift), and ``_deterministic_gates`` now applies the
    same scaling. The two views of the distribution must therefore be
    numerically consistent: reported sparsity cannot diverge wildly from
    the share of gates that evaluate to exactly zero in ``eval()`` mode.
    """
    torch.manual_seed(0)
    gate = HardConcrete(200, temperature=0.25, init_mean=0.5)
    with torch.no_grad():
        gate.qz_logits.copy_(torch.randn(200))

    reported_sparsity = gate.get_sparsity()

    gate.eval()
    eval_gates = gate()
    # A gate counts as active at eval time when its value is strictly positive.
    active_fraction = (eval_gates > 0).float().mean().item()
    observed_sparsity = 1.0 - active_fraction

    # Same temperature-aware distribution feeds both numbers, so they
    # must be close (within sampling-free rounding slack).
    assert abs(reported_sparsity - observed_sparsity) < 0.15
Loading