2 changes: 0 additions & 2 deletions .github/workflows/pr.yml

This file was deleted.

155 changes: 155 additions & 0 deletions CRITICAL_TEMPERATURE_BUG.md
@@ -0,0 +1,155 @@
# Temperature Preservation in L0 Regularization: Critical Implementation Notes

## Executive Summary

The temperature parameter (β) in Hard Concrete distributions is **critical** for L0 regularization performance. Our analysis reveals that:

1. **The original authors' implementation** (`distributions.py`) contains a bug where temperature is incorrectly dropped in deterministic mode
2. **Our standalone implementations** (`calibration.py` and `sparse.py`) correctly preserve temperature in all modes
3. **The gold standard** (`l0_louizos_improved_gate.py`) confirms temperature must always be preserved

## The Temperature Bug in distributions.py

### What We Found

In `distributions.py` (from the authors' repository), the deterministic gates incorrectly drop temperature:

```python
# distributions.py line 134 - INCORRECT
def _deterministic_gates(self) -> torch.Tensor:
    probs = torch.sigmoid(self.qz_logits)  # ❌ Missing temperature!
    gates = probs * (self.zeta - self.gamma) + self.gamma
    return torch.clamp(gates, 0, 1)
```
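
For contrast, here is a minimal corrected sketch of the same method. (The `self.beta` attribute name is assumed; the quoted snippet only shows `qz_logits`, `zeta`, and `gamma`.)

```python
# Corrected sketch (assumes the temperature is stored as self.beta)
def _deterministic_gates(self) -> torch.Tensor:
    probs = torch.sigmoid(self.qz_logits / self.beta)  # ✅ temperature preserved
    gates = probs * (self.zeta - self.gamma) + self.gamma
    return torch.clamp(gates, 0, 1)
```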

### Why This Matters

The temperature parameter controls the "hardness" of the concrete distribution:
- **Lower temperature** (e.g., 0.1): More discrete, binary-like gates
- **Higher temperature** (e.g., 2.0): Softer, more continuous gates

Dropping temperature is equivalent to setting β=1, which fundamentally changes the distribution's behavior (the numerical sketch after this list makes the difference concrete) and can severely impact:
- Convergence speed
- Final sparsity levels
- Model performance
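
A quick numerical sketch in plain PyTorch, using the paper's β=2/3, shows how much the gate probabilities move when temperature is dropped:

```python
import torch

log_alpha = torch.tensor([-2.0, -0.5, 0.5, 2.0])
beta = 2 / 3  # temperature from Louizos et al.

with_temp = torch.sigmoid(log_alpha / beta)  # correct deterministic gates
without_temp = torch.sigmoid(log_alpha)      # the bug: implicitly beta = 1

print(with_temp)     # tensor([0.0474, 0.3208, 0.6792, 0.9526]) - sharper
print(without_temp)  # tensor([0.1192, 0.3775, 0.6225, 0.8808]) - softer
```

With β<1 the gates are pushed toward 0 or 1; dropping β softens every gate and shifts the effective sparsity level.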

## Correct Implementations

### Gold Standard (l0_louizos_improved_gate.py)

Your validated implementation consistently uses temperature:

```python
# Sampling (line 55)
X = (torch.log(u) - torch.log(1 - u) + log_alpha) / beta # ✅

# Deterministic (line 151)
z_final = ((log_alpha / beta).sigmoid() * (zeta - gamma) + gamma).clamp(0, 1) # ✅

# Penalty computation (line 133)
c = -beta * torch.log(torch.tensor(-gamma / zeta)) # ✅
pi = torch.sigmoid(log_alpha + c)
```

### calibration.py (Standalone, Correct)

```python
# Sampling (lines 160-164)
def _sample_gates(self) -> torch.Tensor:
    u = torch.rand_like(self.log_alpha).clamp(eps, 1 - eps)
    s = (torch.log(u) - torch.log(1 - u) + self.log_alpha) / self.beta  # ✅
    s = torch.sigmoid(s)
    s_bar = s * (self.zeta - self.gamma) + self.gamma
    return s_bar.clamp(0, 1)

# Deterministic (lines 167-170)
def get_deterministic_gates(self) -> torch.Tensor:
    s = torch.sigmoid(self.log_alpha / self.beta)  # ✅
    s_bar = s * (self.zeta - self.gamma) + self.gamma
    return s_bar.clamp(0, 1)

# Penalty (lines 233-237)
def get_l0_penalty(self) -> torch.Tensor:
    c = -self.beta * torch.log(torch.tensor(-self.gamma / self.zeta))  # ✅
    pi = torch.sigmoid(self.log_alpha + c)
    return pi.sum()
```

### sparse.py (Standalone, Correct)

```python
# Sampling (lines 113-120)
def _sample_gates(self) -> torch.Tensor:
    u = torch.rand_like(self.log_alpha).clamp(eps, 1 - eps)
    X = (torch.log(u) - torch.log(1 - u) + self.log_alpha) / self.beta  # ✅
    s = torch.sigmoid(X)
    s_bar = s * (self.zeta - self.gamma) + self.gamma
    return s_bar.clamp(0, 1)

# Deterministic (lines 122-127)
def get_deterministic_gates(self) -> torch.Tensor:
    X = self.log_alpha / self.beta  # ✅
    s = torch.sigmoid(X)
    s_bar = s * (self.zeta - self.gamma) + self.gamma
    return s_bar.clamp(0, 1)

# Penalty (lines 172-182)
def get_l0_penalty(self) -> torch.Tensor:
    c = -self.beta * torch.log(torch.tensor(-self.gamma / self.zeta))  # ✅
    pi = torch.sigmoid(self.log_alpha + c)
    return pi.sum()
```

## Mathematical Correctness

The Hard Concrete distribution requires temperature in three key places:

### 1. Sampling (Training Mode)
The concrete distribution samples as:
```
s = sigmoid((log(u) - log(1-u) + log_α) / β)
```
Temperature β controls the sharpness of the sigmoid, essential for the reparameterization trick.

### 2. Deterministic Gates (Inference Mode)
The median of the underlying concrete distribution, obtained by fixing the noise at u = 0.5 so that log(u) − log(1−u) = 0 in the sampling formula:
```
s = sigmoid(log_α / β)
```
**Must use the same temperature** as during training to maintain consistency.

### 3. L0 Penalty Computation
The probability of a gate being active:
```
P(gate > 0) = sigmoid(log_α - β * log(-γ/ζ))
```
Temperature affects the shift in log-odds space.
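
A short self-contained check (assuming the paper's constants β=2/3, γ=−0.1, ζ=1.1) ties the three formulas together:

```python
import torch

beta, gamma, zeta = 2 / 3, -0.1, 1.1
log_alpha = torch.zeros(5)  # 5 gates, each starting at even odds

# 1. Sampling (training): reparameterized draw, temperature inside the sigmoid
u = torch.rand(5).clamp(1e-6, 1 - 1e-6)
s = torch.sigmoid((torch.log(u) - torch.log(1 - u) + log_alpha) / beta)
z_train = (s * (zeta - gamma) + gamma).clamp(0, 1)

# 2. Deterministic (inference): same temperature, noise fixed at its median u = 0.5
z_eval = (torch.sigmoid(log_alpha / beta) * (zeta - gamma) + gamma).clamp(0, 1)

# 3. L0 penalty: expected number of active gates
pi = torch.sigmoid(log_alpha - beta * torch.log(torch.tensor(-gamma / zeta)))
print(pi.sum())  # ≈ 4.16: each gate has P(z > 0) = sigmoid(-(2/3)·log(1/11)) ≈ 0.83
```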

## Practical Implications

### For Survey Calibration (calibration.py)

Your implementation with β=2/3 (from the paper) is correct; a usage sketch follows the list below. The temperature:
- Provides the right balance between exploration and exploitation
- Enables smooth gradient flow during optimization
- Allows fine control over sparsity levels
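
A usage sketch under those settings (constructor signature taken from the `l0/calibration.py` diff later in this PR; the import path is assumed):

```python
from l0.calibration import SparseCalibrationWeights  # assumed import path

layer = SparseCalibrationWeights(
    n_features=1000,
    beta=2 / 3,   # temperature from the paper, used in all three modes
    gamma=-0.1,
    zeta=1.1,
    init_keep_prob=0.5,
    device="cpu",
)
gates = layer.get_deterministic_gates()    # inference gates, sigmoid(log_alpha / beta)
expected_nonzero = layer.get_l0_penalty()  # differentiable count of active weights
```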

### For Sparse Linear Models (sparse.py)

The preserved temperature ensures the following (a consistency check follows the list):
- Proper feature selection dynamics
- Stable convergence to sparse solutions
- Consistency between training and inference
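
One way to sanity-check that consistency (a sketch, assuming `model` is an instance exposing the methods quoted above):

```python
import torch

# Empirical mean of stochastic training gates vs. deterministic inference gates.
with torch.no_grad():
    sampled = torch.stack([model._sample_gates() for _ in range(1000)]).mean(0)
    deterministic = model.get_deterministic_gates()
    print((sampled - deterministic).abs().max())  # small but nonzero:
    # the deterministic gate is the clamped median of the samples, not their mean
```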

## Recommendations

1. **Continue using** `calibration.py` and `sparse.py` as standalone modules - they're correct
2. **Avoid importing** from `distributions.py` until the temperature bug is fixed
3. **Keep temperature** in the range [0.1, 2/3] for best results (lower = harder gates)
4. **Document this issue** when sharing code to prevent others from using the buggy version

## Key Takeaway

Your instinct about never dropping temperature was absolutely correct. The temperature parameter is fundamental to the Hard Concrete distribution's behavior, and dropping it (as in the authors' `distributions.py`) is a significant bug that can severely impact model performance.

Both `calibration.py` and `sparse.py` correctly implement the Hard Concrete distribution with proper temperature preservation, making them reliable for production use.
11 changes: 11 additions & 0 deletions changelog_entry.yaml
```diff
@@ -0,0 +1,11 @@
+- bump: minor
+  changes:
+    added:
+      - Jitter parameter for improved numerical stability in calibration layers
+      - More flexible parameter initialization options for calibration
+      - Python 3.11 support (previously required 3.13)
+    changed:
+      - Enhanced argument handling in calibration module
+      - Improved parameter initialization defaults
+    fixed:
+      - Numerical stability issues in edge cases during calibration
```
1 change: 1 addition & 0 deletions examples/sparse_calibration_stress_test.py
```diff
@@ -49,6 +49,7 @@
     zeta=1.1,
     init_keep_prob=0.5,
     init_weight_scale=0.5,
+    log_weight_jitter_sd=0.5,  # Maintain backward compatibility
     device="cpu",
 )
 
```
101 changes: 80 additions & 21 deletions l0/calibration.py
```diff
@@ -30,10 +30,21 @@ class SparseCalibrationWeights(nn.Module):
         Lower bound of stretched concrete distribution
     zeta : float
         Upper bound of stretched concrete distribution
-    init_keep_prob : float
-        Initial probability of keeping each weight active
-    init_weight_scale : float
-        Initial scale for log weights (controls initial weight magnitude)
+    init_keep_prob : float or array-like
+        Initial probability of keeping each weight active.
+        If float, all weights use the same probability.
+        If array, must have shape (n_features,)
+    init_weights : float, array-like, or None
+        Initial weight values (on original scale, not log).
+        If float, all weights initialized to this value.
+        If array, must have shape (n_features,).
+        If None, defaults to 1.0 for all weights.
+    log_weight_jitter_sd : float
+        Standard deviation of noise added to log weights at start of fit() to break symmetry.
+        Set to 0 to disable jitter. Default is 0.0 (no jitter).
+    log_alpha_jitter_sd : float
+        Standard deviation of noise added to log_alpha at initialization to break symmetry.
+        Set to 0 to disable jitter. Default is 0.01 (following Louizos et al.).
     device : str or torch.device
         Device to run computations on ('cpu' or 'cuda')
     """
@@ -44,34 +55,76 @@ def __init__(
         beta: float = 2 / 3,
         gamma: float = -0.1,
         zeta: float = 1.1,
-        init_keep_prob: float = 0.5,
-        init_weight_scale: float = 1.0,
+        init_keep_prob: float | np.ndarray = 0.5,
+        init_weights: float | np.ndarray | None = None,
+        log_weight_jitter_sd: float = 0.0,
+        log_alpha_jitter_sd: float = 0.01,
         device: str | torch.device = "cpu",
     ):
         super().__init__()
         self.n_features = n_features
         self.beta = beta
         self.gamma = gamma
         self.zeta = zeta
+        self.log_weight_jitter_sd = log_weight_jitter_sd
+        self.log_alpha_jitter_sd = log_alpha_jitter_sd
         self.device = torch.device(device)
 
-        # Log weights to ensure positivity via exp transformation
-        self.log_weight = nn.Parameter(
-            torch.normal(
-                mean=0.0,
-                std=init_weight_scale,
-                size=(n_features,),
-                device=self.device,
+        # Initialize weights (on original scale)
+        if init_weights is None:
+            # Default: all weights start at 1.0
+            weight_values = torch.ones(n_features, device=self.device)
+        elif isinstance(init_weights, (int, float)):
+            # Scalar: all weights start at this value
+            weight_values = torch.full(
+                (n_features,), float(init_weights), device=self.device
             )
-        )
+        else:
+            # Array: use provided values
+            weight_values = torch.tensor(
+                init_weights, dtype=torch.float32, device=self.device
+            )
+            if weight_values.shape != (n_features,):
+                raise ValueError(
+                    f"init_weights array must have shape ({n_features},), "
+                    f"got {weight_values.shape}"
+                )
 
-        # L0 gate parameters
-        mu = torch.log(torch.tensor(init_keep_prob / (1 - init_keep_prob)))
-        self.log_alpha = nn.Parameter(
-            torch.normal(
-                mu.item(), 0.01, size=(n_features,), device=self.device
+        # Convert to log space to ensure positivity via exp transformation
+        # Add small epsilon to avoid log(0)
+        self.log_weight = nn.Parameter(torch.log(weight_values + 1e-8))
+
+        # Handle init_keep_prob as scalar or array
+        if isinstance(init_keep_prob, (int, float)):
+            # Scalar: broadcast to all features
+            keep_prob_values = torch.full(
+                (n_features,), float(init_keep_prob), device=self.device
             )
-        )
+        else:
+            # Array: use provided values
+            keep_prob_values = torch.tensor(
+                init_keep_prob, dtype=torch.float32, device=self.device
+            )
+            if keep_prob_values.shape != (n_features,):
+                raise ValueError(
+                    f"init_keep_prob array must have shape ({n_features},), "
+                    f"got {keep_prob_values.shape}"
+                )
+        # Clip to valid probability range to avoid log(0) or log(inf)
+        keep_prob_values = keep_prob_values.clamp(1e-6, 1 - 1e-6)
+
+        # Convert probabilities to log_alpha
+        mu = torch.log(keep_prob_values / (1 - keep_prob_values))
+        # Add jitter to break symmetry (if specified)
+        if self.log_alpha_jitter_sd > 0:
+            jitter = (
+                torch.randn(n_features, device=self.device)
+                * self.log_alpha_jitter_sd
+            )
+            self.log_alpha = nn.Parameter(mu + jitter)
+        else:
+            self.log_alpha = nn.Parameter(mu)
 
         # Cache for sparse tensor conversion
         self._cached_M_torch: torch.sparse.Tensor | None = None
@@ -312,8 +365,14 @@ def fit(
             # No grouping - all targets weighted equally
             group_weights = torch.ones_like(y)
 
-        # Initialize weights
-        nn.init.normal_(self.log_weight, 0, 0.5)
+        # Add jitter to weights to break symmetry (if jitter_sd > 0)
+        if self.log_weight_jitter_sd > 0:
+            with torch.no_grad():
+                jitter = (
+                    torch.randn_like(self.log_weight)
+                    * self.log_weight_jitter_sd
+                )
+                self.log_weight.data += jitter
 
         # Setup optimizer
         optimizer = torch.optim.Adam([self.log_weight, self.log_alpha], lr=lr)
```
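
A quick sketch of how the new initialization options compose (parameter names from the diff above; the import path and shapes are illustrative):

```python
import numpy as np
from l0.calibration import SparseCalibrationWeights  # assumed import path

layer = SparseCalibrationWeights(
    n_features=3,
    init_keep_prob=np.array([0.9, 0.5, 0.1]),  # per-feature keep probabilities
    init_weights=np.array([1.0, 2.0, 0.5]),    # starting weights on original scale
    log_weight_jitter_sd=0.5,   # applied once at the start of fit()
    log_alpha_jitter_sd=0.01,   # applied at construction (Louizos et al. default)
    device="cpu",
)
```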
4 changes: 2 additions & 2 deletions pyproject.toml
```diff
@@ -11,7 +11,7 @@ authors = [
     {name = "PolicyEngine", email = "hello@policyengine.org"},
 ]
 license = {file = "LICENSE"}
-requires-python = ">=3.13"
+requires-python = ">=3.11"
 classifiers = [
     "Development Status :: 4 - Beta",
     "Intended Audience :: Science/Research",
@@ -34,7 +34,7 @@ dev = [
     "black>=22.0",
     "build>=1.3",
     "scikit-learn>=1.0",
-    "yaml-changelog>=0.3.1",
+    "yaml-changelog>=0.3.0",
     "twine>=4.0.0",
 ]
 docs = [
```