diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml deleted file mode 100644 index 33a4329..0000000 --- a/.github/workflows/pr.yml +++ /dev/null @@ -1,2 +0,0 @@ -# This file is deprecated - use pr_changelog.yaml and pr_code_changes.yaml instead -# Keeping for backwards compatibility \ No newline at end of file diff --git a/CRITICAL_TEMPERATURE_BUG.md b/CRITICAL_TEMPERATURE_BUG.md new file mode 100644 index 0000000..959e530 --- /dev/null +++ b/CRITICAL_TEMPERATURE_BUG.md @@ -0,0 +1,155 @@ +# Temperature Preservation in L0 Regularization: Critical Implementation Notes + +## Executive Summary + +The temperature parameter (β) in Hard Concrete distributions is **critical** for L0 regularization performance. Our analysis reveals that: + +1. **The original authors' implementation** (`distributions.py`) contains a bug where temperature is incorrectly dropped in deterministic mode +2. **Our standalone implementations** (`calibration.py` and `sparse.py`) correctly preserve temperature in all modes +3. **The gold standard** (`l0_louizos_improved_gate.py`) confirms temperature must always be preserved + +## The Temperature Bug in distributions.py + +### What We Found + +In `distributions.py` (from the authors' repository), the deterministic gates incorrectly drop temperature: + +```python +# distributions.py line 134 - INCORRECT +def _deterministic_gates(self) -> torch.Tensor: + probs = torch.sigmoid(self.qz_logits) # ❌ Missing temperature! + gates = probs * (self.zeta - self.gamma) + self.gamma + return torch.clamp(gates, 0, 1) +``` + +### Why This Matters + +The temperature parameter controls the "hardness" of the concrete distribution: +- **Lower temperature** (e.g., 0.1): More discrete, binary-like gates +- **Higher temperature** (e.g., 2.0): Softer, more continuous gates + +Dropping temperature is equivalent to setting β=1, which fundamentally changes the distribution's behavior and can severely impact: +- Convergence speed +- Final sparsity levels +- Model performance + +## Correct Implementations + +### Gold Standard (l0_louizos_improved_gate.py) + +Your validated implementation consistently uses temperature: + +```python +# Sampling (line 55) +X = (torch.log(u) - torch.log(1 - u) + log_alpha) / beta # ✅ + +# Deterministic (line 151) +z_final = ((log_alpha / beta).sigmoid() * (zeta - gamma) + gamma).clamp(0, 1) # ✅ + +# Penalty computation (line 133) +c = -beta * torch.log(torch.tensor(-gamma / zeta)) # ✅ +pi = torch.sigmoid(log_alpha + c) +``` + +### calibration.py (Standalone, Correct) + +```python +# Sampling (lines 160-164) +def _sample_gates(self) -> torch.Tensor: + u = torch.rand_like(self.log_alpha).clamp(eps, 1 - eps) + s = (torch.log(u) - torch.log(1 - u) + self.log_alpha) / self.beta # ✅ + s = torch.sigmoid(s) + s_bar = s * (self.zeta - self.gamma) + self.gamma + return s_bar.clamp(0, 1) + +# Deterministic (lines 167-170) +def get_deterministic_gates(self) -> torch.Tensor: + s = torch.sigmoid(self.log_alpha / self.beta) # ✅ + s_bar = s * (self.zeta - self.gamma) + self.gamma + return s_bar.clamp(0, 1) + +# Penalty (lines 233-237) +def get_l0_penalty(self) -> torch.Tensor: + c = -self.beta * torch.log(torch.tensor(-self.gamma / self.zeta)) # ✅ + pi = torch.sigmoid(self.log_alpha + c) + return pi.sum() +``` + +### sparse.py (Standalone, Correct) + +```python +# Sampling (lines 113-120) +def _sample_gates(self) -> torch.Tensor: + u = torch.rand_like(self.log_alpha).clamp(eps, 1 - eps) + X = (torch.log(u) - torch.log(1 - u) + self.log_alpha) / self.beta # ✅ + s = torch.sigmoid(X) + s_bar = s * (self.zeta - self.gamma) + self.gamma + return s_bar.clamp(0, 1) + +# Deterministic (lines 122-127) +def get_deterministic_gates(self) -> torch.Tensor: + X = self.log_alpha / self.beta # ✅ + s = torch.sigmoid(X) + s_bar = s * (self.zeta - self.gamma) + self.gamma + return s_bar.clamp(0, 1) + +# Penalty (lines 172-182) +def get_l0_penalty(self) -> torch.Tensor: + c = -self.beta * torch.log(torch.tensor(-self.gamma / self.zeta)) # ✅ + pi = torch.sigmoid(self.log_alpha + c) + return pi.sum() +``` + +## Mathematical Correctness + +The Hard Concrete distribution requires temperature in three key places: + +### 1. Sampling (Training Mode) +The concrete distribution samples as: +``` +s = sigmoid((log(u) - log(1-u) + log_α) / β) +``` +Temperature β controls the sharpness of the sigmoid, essential for the reparameterization trick. + +### 2. Deterministic Gates (Inference Mode) +The mean of the distribution: +``` +s = sigmoid(log_α / β) +``` +**Must use the same temperature** as during training to maintain consistency. + +### 3. L0 Penalty Computation +The probability of a gate being active: +``` +P(gate > 0) = sigmoid(log_α - β * log(-γ/ζ)) +``` +Temperature affects the shift in log-odds space. + +## Practical Implications + +### For Survey Calibration (calibration.py) + +Your implementation with β=2/3 (from the paper) is correct. The temperature: +- Provides the right balance between exploration and exploitation +- Enables smooth gradient flow during optimization +- Allows fine control over sparsity levels + +### For Sparse Linear Models (sparse.py) + +The preserved temperature ensures: +- Proper feature selection dynamics +- Stable convergence to sparse solutions +- Consistency between training and inference + +## Recommendations + +1. **Continue using** `calibration.py` and `sparse.py` as standalone modules - they're correct +2. **Avoid importing** from `distributions.py` until the temperature bug is fixed +3. **Keep temperature** in the range [0.1, 2/3] for best results (lower = harder gates) +4. **Document this issue** when sharing code to prevent others from using the buggy version + +## Key Takeaway + +Your instinct about never dropping temperature was absolutely correct. The temperature parameter is fundamental to the Hard Concrete distribution's behavior, and dropping it (as in the authors' `distributions.py`) is a significant bug that can severely impact model performance. + +Both `calibration.py` and `sparse.py` correctly implement the Hard Concrete distribution with proper temperature preservation, making them reliable for production use. \ No newline at end of file diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29..53bdd45 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,11 @@ +- bump: minor + changes: + added: + - Jitter parameter for improved numerical stability in calibration layers + - More flexible parameter initialization options for calibration + - Python 3.11 support (previously required 3.13) + changed: + - Enhanced argument handling in calibration module + - Improved parameter initialization defaults + fixed: + - Numerical stability issues in edge cases during calibration \ No newline at end of file diff --git a/examples/sparse_calibration_stress_test.py b/examples/sparse_calibration_stress_test.py index 43d96d9..885011c 100644 --- a/examples/sparse_calibration_stress_test.py +++ b/examples/sparse_calibration_stress_test.py @@ -49,6 +49,7 @@ zeta=1.1, init_keep_prob=0.5, init_weight_scale=0.5, + log_weight_jitter_sd=0.5, # Maintain backward compatibility device="cpu", ) diff --git a/l0/calibration.py b/l0/calibration.py index fb59cdc..3b2eb90 100644 --- a/l0/calibration.py +++ b/l0/calibration.py @@ -30,10 +30,21 @@ class SparseCalibrationWeights(nn.Module): Lower bound of stretched concrete distribution zeta : float Upper bound of stretched concrete distribution - init_keep_prob : float - Initial probability of keeping each weight active - init_weight_scale : float - Initial scale for log weights (controls initial weight magnitude) + init_keep_prob : float or array-like + Initial probability of keeping each weight active. + If float, all weights use the same probability. + If array, must have shape (n_features,) + init_weights : float, array-like, or None + Initial weight values (on original scale, not log). + If float, all weights initialized to this value. + If array, must have shape (n_features,). + If None, defaults to 1.0 for all weights. + log_weight_jitter_sd : float + Standard deviation of noise added to log weights at start of fit() to break symmetry. + Set to 0 to disable jitter. Default is 0.0 (no jitter). + log_alpha_jitter_sd : float + Standard deviation of noise added to log_alpha at initialization to break symmetry. + Set to 0 to disable jitter. Default is 0.01 (following Louizos et al.). device : str or torch.device Device to run computations on ('cpu' or 'cuda') """ @@ -44,8 +55,10 @@ def __init__( beta: float = 2 / 3, gamma: float = -0.1, zeta: float = 1.1, - init_keep_prob: float = 0.5, - init_weight_scale: float = 1.0, + init_keep_prob: float | np.ndarray = 0.5, + init_weights: float | np.ndarray | None = None, + log_weight_jitter_sd: float = 0.0, + log_alpha_jitter_sd: float = 0.01, device: str | torch.device = "cpu", ): super().__init__() @@ -53,25 +66,65 @@ def __init__( self.beta = beta self.gamma = gamma self.zeta = zeta + self.log_weight_jitter_sd = log_weight_jitter_sd + self.log_alpha_jitter_sd = log_alpha_jitter_sd self.device = torch.device(device) - # Log weights to ensure positivity via exp transformation - self.log_weight = nn.Parameter( - torch.normal( - mean=0.0, - std=init_weight_scale, - size=(n_features,), - device=self.device, + # Initialize weights (on original scale) + if init_weights is None: + # Default: all weights start at 1.0 + weight_values = torch.ones(n_features, device=self.device) + elif isinstance(init_weights, (int, float)): + # Scalar: all weights start at this value + weight_values = torch.full( + (n_features,), float(init_weights), device=self.device ) - ) + else: + # Array: use provided values + weight_values = torch.tensor( + init_weights, dtype=torch.float32, device=self.device + ) + if weight_values.shape != (n_features,): + raise ValueError( + f"init_weights array must have shape ({n_features},), " + f"got {weight_values.shape}" + ) + + # Convert to log space to ensure positivity via exp transformation + # Add small epsilon to avoid log(0) + self.log_weight = nn.Parameter(torch.log(weight_values + 1e-8)) # L0 gate parameters - mu = torch.log(torch.tensor(init_keep_prob / (1 - init_keep_prob))) - self.log_alpha = nn.Parameter( - torch.normal( - mu.item(), 0.01, size=(n_features,), device=self.device + # Handle init_keep_prob as scalar or array + if isinstance(init_keep_prob, (int, float)): + # Scalar: broadcast to all features + keep_prob_values = torch.full( + (n_features,), float(init_keep_prob), device=self.device ) - ) + else: + # Array: use provided values + keep_prob_values = torch.tensor( + init_keep_prob, dtype=torch.float32, device=self.device + ) + if keep_prob_values.shape != (n_features,): + raise ValueError( + f"init_keep_prob array must have shape ({n_features},), " + f"got {keep_prob_values.shape}" + ) + # Clip to valid probability range to avoid log(0) or log(inf) + keep_prob_values = keep_prob_values.clamp(1e-6, 1 - 1e-6) + + # Convert probabilities to log_alpha + mu = torch.log(keep_prob_values / (1 - keep_prob_values)) + # Add jitter to break symmetry (if specified) + if self.log_alpha_jitter_sd > 0: + jitter = ( + torch.randn(n_features, device=self.device) + * self.log_alpha_jitter_sd + ) + self.log_alpha = nn.Parameter(mu + jitter) + else: + self.log_alpha = nn.Parameter(mu) # Cache for sparse tensor conversion self._cached_M_torch: torch.sparse.Tensor | None = None @@ -312,8 +365,14 @@ def fit( # No grouping - all targets weighted equally group_weights = torch.ones_like(y) - # Initialize weights - nn.init.normal_(self.log_weight, 0, 0.5) + # Add jitter to weights to break symmetry (if jitter_sd > 0) + if self.log_weight_jitter_sd > 0: + with torch.no_grad(): + jitter = ( + torch.randn_like(self.log_weight) + * self.log_weight_jitter_sd + ) + self.log_weight.data += jitter # Setup optimizer optimizer = torch.optim.Adam([self.log_weight, self.log_alpha], lr=lr) diff --git a/pyproject.toml b/pyproject.toml index 6a243a8..32e794e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = [ {name = "PolicyEngine", email = "hello@policyengine.org"}, ] license = {file = "LICENSE"} -requires-python = ">=3.13" +requires-python = ">=3.11" classifiers = [ "Development Status :: 4 - Beta", "Intended Audience :: Science/Research", @@ -34,7 +34,7 @@ dev = [ "black>=22.0", "build>=1.3", "scikit-learn>=1.0", - "yaml-changelog>=0.3.1", + "yaml-changelog>=0.3.0", "twine>=4.0.0", ] docs = [ diff --git a/tests/test_calibration.py b/tests/test_calibration.py index f0ee9c3..af01c3a 100644 --- a/tests/test_calibration.py +++ b/tests/test_calibration.py @@ -21,7 +21,7 @@ def test_positive_weights(self): M = sp.random(Q, N, density=0.3, format="csr") y = np.random.randn(Q) + 10 - model = SparseCalibrationWeights(n_features=N) + model = SparseCalibrationWeights(n_features=N, init_weights=1.0) model.fit(M, y, epochs=100, verbose=False) # Check positivity @@ -57,7 +57,8 @@ def test_sparse_ground_truth_relative_loss(self): gamma=-0.1, zeta=1.1, init_keep_prob=0.3, - init_weight_scale=0.5, + init_weights=1.0, # Start all weights at 1.0 + log_weight_jitter_sd=0.5, # Add jitter for symmetry breaking ) model.fit( @@ -171,7 +172,7 @@ def test_sparsity_control(self): def test_get_active_weights(self): """Test active weight extraction.""" N = 100 - model = SparseCalibrationWeights(n_features=N) + model = SparseCalibrationWeights(n_features=N, init_weights=1.0) # Simple test data M = sp.eye(N, format="csr") @@ -202,7 +203,7 @@ def test_deterministic_inference(self): M = sp.random(Q, N, density=0.5, format="csr") y = np.random.randn(Q) - model = SparseCalibrationWeights(n_features=N) + model = SparseCalibrationWeights(n_features=N, init_weights=1.0) model.fit(M, y, epochs=100, verbose=False) # Multiple predictions should be identical @@ -357,7 +358,7 @@ def test_group_wise_averaging_edge_cases(self): M = sp.random(Q, N, density=0.3, format="csr") y = np.random.uniform(100, 1000, size=Q) - model = SparseCalibrationWeights(n_features=N) + model = SparseCalibrationWeights(n_features=N, init_weights=1.0) # Test 1: All targets in one group (should behave like no grouping) target_groups_single = np.zeros(Q, dtype=int) @@ -421,3 +422,124 @@ def test_group_wise_averaging_edge_cases(self): assert ( np.mean(small_group_errors) < 0.5 ), "Small groups should not be ignored" + + def test_init_weights_options(self): + """Test different weight initialization options.""" + N = 50 + + # Test 1: Default (None) should give all weights = 1.0 + model_default = SparseCalibrationWeights(n_features=N) + with torch.no_grad(): + weights = torch.exp(model_default.log_weight) + assert torch.allclose(weights, torch.ones(N), atol=1e-6) + + # Test 2: Scalar initialization + model_scalar = SparseCalibrationWeights(n_features=N, init_weights=2.5) + with torch.no_grad(): + weights = torch.exp(model_scalar.log_weight) + assert torch.allclose(weights, torch.full((N,), 2.5), atol=1e-6) + + # Test 3: Array initialization + init_array = np.random.uniform(0.5, 2.0, size=N) + model_array = SparseCalibrationWeights( + n_features=N, init_weights=init_array + ) + with torch.no_grad(): + weights = torch.exp(model_array.log_weight).cpu().numpy() + np.testing.assert_allclose(weights, init_array, rtol=1e-5) + + # Test 4: Wrong shape should raise error + with pytest.raises(ValueError, match="must have shape"): + SparseCalibrationWeights(n_features=N, init_weights=np.ones(N + 1)) + + def test_weight_jitter(self): + """Test that weight jitter works correctly.""" + N = 100 + Q = 20 + + np.random.seed(42) + torch.manual_seed(42) + + M = sp.random(Q, N, density=0.3, format="csr") + y = np.random.randn(Q) + 10 + + # Model with jitter + model_with_jitter = SparseCalibrationWeights( + n_features=N, init_weights=1.0, log_weight_jitter_sd=0.5 + ) + + # Store initial weights + initial_weights = model_with_jitter.log_weight.data.clone() + + # Fit should add jitter + model_with_jitter.fit(M, y, epochs=10, verbose=False) + + # Weights should have changed due to jitter (and training) + final_weights = model_with_jitter.log_weight.data + assert not torch.allclose(initial_weights, final_weights) + + # Model without jitter + torch.manual_seed(42) # Reset seed + model_no_jitter = SparseCalibrationWeights( + n_features=N, + init_weights=1.0, + log_weight_jitter_sd=0.0, # No jitter + ) + + initial_weights_no_jitter = model_no_jitter.log_weight.data.clone() + model_no_jitter.fit(M, y, epochs=1, verbose=False) # Just 1 epoch + + # After 1 epoch, change should be small without jitter + weights_after_1_epoch = model_no_jitter.log_weight.data + # The change is due to gradient updates only + change = ( + (weights_after_1_epoch - initial_weights_no_jitter).abs().max() + ) + assert change < 1.0, "Without jitter, initial change should be small" + + def test_init_keep_prob_options(self): + """Test init_keep_prob as scalar and array.""" + n_features = 20 + + # Test 1: Scalar init_keep_prob (existing behavior) + model_scalar = SparseCalibrationWeights( + n_features=n_features, + init_keep_prob=0.7, + ) + # All log_alpha values should be similar (around log(0.7/0.3) plus small jitter) + expected_mu = np.log(0.7 / 0.3) + with torch.no_grad(): + log_alphas = model_scalar.log_alpha.numpy() + # Check they're all close to expected value (within jitter range) + assert np.all(np.abs(log_alphas - expected_mu) < 0.1) + + # Test 2: Array init_keep_prob + keep_probs = np.linspace(0.1, 0.9, n_features) + model_array = SparseCalibrationWeights( + n_features=n_features, + init_keep_prob=keep_probs, + ) + # Each log_alpha should correspond to its keep_prob + with torch.no_grad(): + log_alphas = model_array.log_alpha.numpy() + expected_mus = np.log(keep_probs / (1 - keep_probs)) + # Check each is close to its expected value + assert np.all(np.abs(log_alphas - expected_mus) < 0.1) + + # Test 3: Wrong shape should raise error + with pytest.raises(ValueError, match="must have shape"): + SparseCalibrationWeights( + n_features=10, + init_keep_prob=np.ones(5), # Wrong size + ) + + # Test 4: Edge case probabilities get clamped + extreme_probs = np.array([0.0, 0.5, 1.0]) + model_extreme = SparseCalibrationWeights( + n_features=3, + init_keep_prob=extreme_probs, + ) + with torch.no_grad(): + log_alphas = model_extreme.log_alpha.numpy() + # Should not have inf or -inf values + assert np.all(np.isfinite(log_alphas))