From 4950851d837975d76734ba7df324a8ce02a699d8 Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Fri, 5 Sep 2025 09:39:46 -0400 Subject: [PATCH 1/8] reducing requirements to Python 3.12 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6a243a8..9570634 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = [ {name = "PolicyEngine", email = "hello@policyengine.org"}, ] license = {file = "LICENSE"} -requires-python = ">=3.13" +requires-python = ">=3.12" classifiers = [ "Development Status :: 4 - Beta", "Intended Audience :: Science/Research", From 25c7710567c08fceec742d7d6fc1cdfc39967353 Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Sat, 6 Sep 2025 20:40:59 -0400 Subject: [PATCH 2/8] jitter parameter improvements --- examples/sparse_calibration_stress_test.py | 1 + l0/calibration.py | 12 ++++++++++-- tests/test_calibration.py | 1 + 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/examples/sparse_calibration_stress_test.py b/examples/sparse_calibration_stress_test.py index 43d96d9..885011c 100644 --- a/examples/sparse_calibration_stress_test.py +++ b/examples/sparse_calibration_stress_test.py @@ -49,6 +49,7 @@ zeta=1.1, init_keep_prob=0.5, init_weight_scale=0.5, + log_weight_jitter_sd=0.5, # Maintain backward compatibility device="cpu", ) diff --git a/l0/calibration.py b/l0/calibration.py index fb59cdc..9b0148d 100644 --- a/l0/calibration.py +++ b/l0/calibration.py @@ -34,6 +34,9 @@ class SparseCalibrationWeights(nn.Module): Initial probability of keeping each weight active init_weight_scale : float Initial scale for log weights (controls initial weight magnitude) + log_weight_jitter_sd : float + Standard deviation of noise added to log weights at start of fit() to break symmetry. + Set to 0 to disable jitter. Default is 0.5 for backward compatibility. 
device : str or torch.device Device to run computations on ('cpu' or 'cuda') """ @@ -46,6 +49,7 @@ def __init__( zeta: float = 1.1, init_keep_prob: float = 0.5, init_weight_scale: float = 1.0, + log_weight_jitter_sd: float = 0.5, device: str | torch.device = "cpu", ): super().__init__() @@ -53,6 +57,7 @@ def __init__( self.beta = beta self.gamma = gamma self.zeta = zeta + self.log_weight_jitter_sd = log_weight_jitter_sd self.device = torch.device(device) # Log weights to ensure positivity via exp transformation @@ -312,8 +317,11 @@ def fit( # No grouping - all targets weighted equally group_weights = torch.ones_like(y) - # Initialize weights - nn.init.normal_(self.log_weight, 0, 0.5) + # Add jitter to weights to break symmetry (if jitter_sd > 0) + if self.log_weight_jitter_sd > 0: + with torch.no_grad(): + jitter = torch.randn_like(self.log_weight) * self.log_weight_jitter_sd + self.log_weight.data += jitter # Setup optimizer optimizer = torch.optim.Adam([self.log_weight, self.log_alpha], lr=lr) diff --git a/tests/test_calibration.py b/tests/test_calibration.py index f0ee9c3..cff0ce0 100644 --- a/tests/test_calibration.py +++ b/tests/test_calibration.py @@ -58,6 +58,7 @@ def test_sparse_ground_truth_relative_loss(self): zeta=1.1, init_keep_prob=0.3, init_weight_scale=0.5, + log_weight_jitter_sd=0.5, # Maintain backward compatibility in tests ) model.fit( From 0abf19f58180375881822cda54447cd0a6052175 Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Sun, 7 Sep 2025 17:22:40 -0400 Subject: [PATCH 3/8] better arguments --- l0/calibration.py | 86 ++++++++++++++++++------- tests/test_calibration.py | 131 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 187 insertions(+), 30 deletions(-) diff --git a/l0/calibration.py b/l0/calibration.py index 9b0148d..da0bef2 100644 --- a/l0/calibration.py +++ b/l0/calibration.py @@ -30,13 +30,18 @@ class SparseCalibrationWeights(nn.Module): Lower bound of stretched concrete distribution zeta : float Upper bound of stretched concrete distribution - init_keep_prob : float - Initial probability of keeping each weight active - init_weight_scale : float - Initial scale for log weights (controls initial weight magnitude) - log_weight_jitter_sd : float + init_keep_prob : float or array-like + Initial probability of keeping each weight active. + If float, all weights use the same probability. + If array, must have shape (n_features,) + init_weights : float, array-like, or None + Initial weight values (on original scale, not log). + If float, all weights initialized to this value. + If array, must have shape (n_features,). + If None, defaults to 1.0 for all weights. + weight_jitter_sd : float Standard deviation of noise added to log weights at start of fit() to break symmetry. - Set to 0 to disable jitter. Default is 0.5 for backward compatibility. + Set to 0 to disable jitter. Default is 0.0 (no jitter). 
device : str or torch.device Device to run computations on ('cpu' or 'cuda') """ @@ -47,9 +52,9 @@ def __init__( beta: float = 2 / 3, gamma: float = -0.1, zeta: float = 1.1, - init_keep_prob: float = 0.5, - init_weight_scale: float = 1.0, - log_weight_jitter_sd: float = 0.5, + init_keep_prob: float | np.ndarray = 0.5, + init_weights: float | np.ndarray | None = None, + weight_jitter_sd: float = 0.0, device: str | torch.device = "cpu", ): super().__init__() @@ -57,25 +62,58 @@ def __init__( self.beta = beta self.gamma = gamma self.zeta = zeta - self.log_weight_jitter_sd = log_weight_jitter_sd + self.weight_jitter_sd = weight_jitter_sd self.device = torch.device(device) - # Log weights to ensure positivity via exp transformation - self.log_weight = nn.Parameter( - torch.normal( - mean=0.0, - std=init_weight_scale, - size=(n_features,), - device=self.device, + # Initialize weights (on original scale) + if init_weights is None: + # Default: all weights start at 1.0 + weight_values = torch.ones(n_features, device=self.device) + elif isinstance(init_weights, (int, float)): + # Scalar: all weights start at this value + weight_values = torch.full( + (n_features,), float(init_weights), device=self.device ) - ) + else: + # Array: use provided values + weight_values = torch.tensor( + init_weights, dtype=torch.float32, device=self.device + ) + if weight_values.shape != (n_features,): + raise ValueError( + f"init_weights array must have shape ({n_features},), " + f"got {weight_values.shape}" + ) + + # Convert to log space to ensure positivity via exp transformation + # Add small epsilon to avoid log(0) + self.log_weight = nn.Parameter(torch.log(weight_values + 1e-8)) # L0 gate parameters - mu = torch.log(torch.tensor(init_keep_prob / (1 - init_keep_prob))) - self.log_alpha = nn.Parameter( - torch.normal( - mu.item(), 0.01, size=(n_features,), device=self.device + # Handle init_keep_prob as scalar or array + if isinstance(init_keep_prob, (int, float)): + # Scalar: broadcast to all features + keep_prob_values = torch.full( + (n_features,), float(init_keep_prob), device=self.device ) + else: + # Array: use provided values + keep_prob_values = torch.tensor( + init_keep_prob, dtype=torch.float32, device=self.device + ) + if keep_prob_values.shape != (n_features,): + raise ValueError( + f"init_keep_prob array must have shape ({n_features},), " + f"got {keep_prob_values.shape}" + ) + # Clip to valid probability range to avoid log(0) or log(inf) + keep_prob_values = keep_prob_values.clamp(1e-6, 1 - 1e-6) + + # Convert probabilities to log_alpha + mu = torch.log(keep_prob_values / (1 - keep_prob_values)) + # Add small jitter to break symmetry + self.log_alpha = nn.Parameter( + mu + torch.randn(n_features, device=self.device) * 0.01 ) # Cache for sparse tensor conversion @@ -318,9 +356,9 @@ def fit( group_weights = torch.ones_like(y) # Add jitter to weights to break symmetry (if jitter_sd > 0) - if self.log_weight_jitter_sd > 0: + if self.weight_jitter_sd > 0: with torch.no_grad(): - jitter = torch.randn_like(self.log_weight) * self.log_weight_jitter_sd + jitter = torch.randn_like(self.log_weight) * self.weight_jitter_sd self.log_weight.data += jitter # Setup optimizer diff --git a/tests/test_calibration.py b/tests/test_calibration.py index cff0ce0..0f3d4ff 100644 --- a/tests/test_calibration.py +++ b/tests/test_calibration.py @@ -21,7 +21,7 @@ def test_positive_weights(self): M = sp.random(Q, N, density=0.3, format="csr") y = np.random.randn(Q) + 10 - model = SparseCalibrationWeights(n_features=N) + model = 
SparseCalibrationWeights(n_features=N, init_weights=1.0) model.fit(M, y, epochs=100, verbose=False) # Check positivity @@ -57,8 +57,8 @@ def test_sparse_ground_truth_relative_loss(self): gamma=-0.1, zeta=1.1, init_keep_prob=0.3, - init_weight_scale=0.5, - log_weight_jitter_sd=0.5, # Maintain backward compatibility in tests + init_weights=1.0, # Start all weights at 1.0 + weight_jitter_sd=0.5, # Add jitter for symmetry breaking ) model.fit( @@ -172,7 +172,7 @@ def test_sparsity_control(self): def test_get_active_weights(self): """Test active weight extraction.""" N = 100 - model = SparseCalibrationWeights(n_features=N) + model = SparseCalibrationWeights(n_features=N, init_weights=1.0) # Simple test data M = sp.eye(N, format="csr") @@ -203,7 +203,7 @@ def test_deterministic_inference(self): M = sp.random(Q, N, density=0.5, format="csr") y = np.random.randn(Q) - model = SparseCalibrationWeights(n_features=N) + model = SparseCalibrationWeights(n_features=N, init_weights=1.0) model.fit(M, y, epochs=100, verbose=False) # Multiple predictions should be identical @@ -358,7 +358,7 @@ def test_group_wise_averaging_edge_cases(self): M = sp.random(Q, N, density=0.3, format="csr") y = np.random.uniform(100, 1000, size=Q) - model = SparseCalibrationWeights(n_features=N) + model = SparseCalibrationWeights(n_features=N, init_weights=1.0) # Test 1: All targets in one group (should behave like no grouping) target_groups_single = np.zeros(Q, dtype=int) @@ -422,3 +422,122 @@ def test_group_wise_averaging_edge_cases(self): assert ( np.mean(small_group_errors) < 0.5 ), "Small groups should not be ignored" + + def test_init_weights_options(self): + """Test different weight initialization options.""" + N = 50 + + # Test 1: Default (None) should give all weights = 1.0 + model_default = SparseCalibrationWeights(n_features=N) + with torch.no_grad(): + weights = torch.exp(model_default.log_weight) + assert torch.allclose(weights, torch.ones(N), atol=1e-6) + + # Test 2: Scalar initialization + model_scalar = SparseCalibrationWeights(n_features=N, init_weights=2.5) + with torch.no_grad(): + weights = torch.exp(model_scalar.log_weight) + assert torch.allclose(weights, torch.full((N,), 2.5), atol=1e-6) + + # Test 3: Array initialization + init_array = np.random.uniform(0.5, 2.0, size=N) + model_array = SparseCalibrationWeights(n_features=N, init_weights=init_array) + with torch.no_grad(): + weights = torch.exp(model_array.log_weight).cpu().numpy() + np.testing.assert_allclose(weights, init_array, rtol=1e-5) + + # Test 4: Wrong shape should raise error + with pytest.raises(ValueError, match="must have shape"): + SparseCalibrationWeights(n_features=N, init_weights=np.ones(N + 1)) + + def test_weight_jitter(self): + """Test that weight jitter works correctly.""" + N = 100 + Q = 20 + + np.random.seed(42) + torch.manual_seed(42) + + M = sp.random(Q, N, density=0.3, format="csr") + y = np.random.randn(Q) + 10 + + # Model with jitter + model_with_jitter = SparseCalibrationWeights( + n_features=N, + init_weights=1.0, + weight_jitter_sd=0.5 + ) + + # Store initial weights + initial_weights = model_with_jitter.log_weight.data.clone() + + # Fit should add jitter + model_with_jitter.fit(M, y, epochs=10, verbose=False) + + # Weights should have changed due to jitter (and training) + final_weights = model_with_jitter.log_weight.data + assert not torch.allclose(initial_weights, final_weights) + + # Model without jitter + torch.manual_seed(42) # Reset seed + model_no_jitter = SparseCalibrationWeights( + n_features=N, + init_weights=1.0, 
+ weight_jitter_sd=0.0 # No jitter + ) + + initial_weights_no_jitter = model_no_jitter.log_weight.data.clone() + model_no_jitter.fit(M, y, epochs=1, verbose=False) # Just 1 epoch + + # After 1 epoch, change should be small without jitter + weights_after_1_epoch = model_no_jitter.log_weight.data + # The change is due to gradient updates only + change = (weights_after_1_epoch - initial_weights_no_jitter).abs().max() + assert change < 1.0, "Without jitter, initial change should be small" + + def test_init_keep_prob_options(self): + """Test init_keep_prob as scalar and array.""" + n_features = 20 + + # Test 1: Scalar init_keep_prob (existing behavior) + model_scalar = SparseCalibrationWeights( + n_features=n_features, + init_keep_prob=0.7, + ) + # All log_alpha values should be similar (around log(0.7/0.3) plus small jitter) + expected_mu = np.log(0.7 / 0.3) + with torch.no_grad(): + log_alphas = model_scalar.log_alpha.numpy() + # Check they're all close to expected value (within jitter range) + assert np.all(np.abs(log_alphas - expected_mu) < 0.1) + + # Test 2: Array init_keep_prob + keep_probs = np.linspace(0.1, 0.9, n_features) + model_array = SparseCalibrationWeights( + n_features=n_features, + init_keep_prob=keep_probs, + ) + # Each log_alpha should correspond to its keep_prob + with torch.no_grad(): + log_alphas = model_array.log_alpha.numpy() + expected_mus = np.log(keep_probs / (1 - keep_probs)) + # Check each is close to its expected value + assert np.all(np.abs(log_alphas - expected_mus) < 0.1) + + # Test 3: Wrong shape should raise error + with pytest.raises(ValueError, match="must have shape"): + SparseCalibrationWeights( + n_features=10, + init_keep_prob=np.ones(5), # Wrong size + ) + + # Test 4: Edge case probabilities get clamped + extreme_probs = np.array([0.0, 0.5, 1.0]) + model_extreme = SparseCalibrationWeights( + n_features=3, + init_keep_prob=extreme_probs, + ) + with torch.no_grad(): + log_alphas = model_extreme.log_alpha.numpy() + # Should not have inf or -inf values + assert np.all(np.isfinite(log_alphas)) From a74c6da54020734bb031554ca7ebbbb23582df76 Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Sun, 7 Sep 2025 17:31:12 -0400 Subject: [PATCH 4/8] allowing 3.11 now --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9570634..3b33f60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = [ {name = "PolicyEngine", email = "hello@policyengine.org"}, ] license = {file = "LICENSE"} -requires-python = ">=3.12" +requires-python = ">=3.11" classifiers = [ "Development Status :: 4 - Beta", "Intended Audience :: Science/Research", From d374e44f7388420922c434e9c7a6301c71aa54b6 Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Sun, 7 Sep 2025 18:32:03 -0400 Subject: [PATCH 5/8] argument improvement --- l0/calibration.py | 10 +++++----- tests/test_calibration.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/l0/calibration.py b/l0/calibration.py index da0bef2..14f6d25 100644 --- a/l0/calibration.py +++ b/l0/calibration.py @@ -39,7 +39,7 @@ class SparseCalibrationWeights(nn.Module): If float, all weights initialized to this value. If array, must have shape (n_features,). If None, defaults to 1.0 for all weights. - weight_jitter_sd : float + log_weight_jitter_sd : float Standard deviation of noise added to log weights at start of fit() to break symmetry. Set to 0 to disable jitter. Default is 0.0 (no jitter). 
device : str or torch.device @@ -54,7 +54,7 @@ def __init__( zeta: float = 1.1, init_keep_prob: float | np.ndarray = 0.5, init_weights: float | np.ndarray | None = None, - weight_jitter_sd: float = 0.0, + log_weight_jitter_sd: float = 0.0, device: str | torch.device = "cpu", ): super().__init__() @@ -62,7 +62,7 @@ def __init__( self.beta = beta self.gamma = gamma self.zeta = zeta - self.weight_jitter_sd = weight_jitter_sd + self.log_weight_jitter_sd = log_weight_jitter_sd self.device = torch.device(device) # Initialize weights (on original scale) @@ -356,9 +356,9 @@ def fit( group_weights = torch.ones_like(y) # Add jitter to weights to break symmetry (if jitter_sd > 0) - if self.weight_jitter_sd > 0: + if self.log_weight_jitter_sd > 0: with torch.no_grad(): - jitter = torch.randn_like(self.log_weight) * self.weight_jitter_sd + jitter = torch.randn_like(self.log_weight) * self.log_weight_jitter_sd self.log_weight.data += jitter # Setup optimizer diff --git a/tests/test_calibration.py b/tests/test_calibration.py index 0f3d4ff..dc5c518 100644 --- a/tests/test_calibration.py +++ b/tests/test_calibration.py @@ -58,7 +58,7 @@ def test_sparse_ground_truth_relative_loss(self): zeta=1.1, init_keep_prob=0.3, init_weights=1.0, # Start all weights at 1.0 - weight_jitter_sd=0.5, # Add jitter for symmetry breaking + log_weight_jitter_sd=0.5, # Add jitter for symmetry breaking ) model.fit( @@ -465,7 +465,7 @@ def test_weight_jitter(self): model_with_jitter = SparseCalibrationWeights( n_features=N, init_weights=1.0, - weight_jitter_sd=0.5 + log_weight_jitter_sd=0.5 ) # Store initial weights @@ -483,7 +483,7 @@ def test_weight_jitter(self): model_no_jitter = SparseCalibrationWeights( n_features=N, init_weights=1.0, - weight_jitter_sd=0.0 # No jitter + log_weight_jitter_sd=0.0 # No jitter ) initial_weights_no_jitter = model_no_jitter.log_weight.data.clone() From 3dfe5e463aa9bd30f1a637b09be2dfbbaef36bd2 Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Mon, 8 Sep 2025 10:25:15 -0400 Subject: [PATCH 6/8] discarding empty workflow --- .github/workflows/pr.yml | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 .github/workflows/pr.yml diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml deleted file mode 100644 index 33a4329..0000000 --- a/.github/workflows/pr.yml +++ /dev/null @@ -1,2 +0,0 @@ -# This file is deprecated - use pr_changelog.yaml and pr_code_changes.yaml instead -# Keeping for backwards compatibility \ No newline at end of file From 1c0780b6ab64c1f38042ead4dd4db4597429ab0a Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Mon, 8 Sep 2025 11:29:45 -0400 Subject: [PATCH 7/8] jitter --- CRITICAL_TEMPERATURE_BUG.md | 155 ++++++++++++++++++++++++++++++++++++ l0/calibration.py | 27 +++++-- tests/test_calibration.py | 50 ++++++------ 3 files changed, 201 insertions(+), 31 deletions(-) create mode 100644 CRITICAL_TEMPERATURE_BUG.md diff --git a/CRITICAL_TEMPERATURE_BUG.md b/CRITICAL_TEMPERATURE_BUG.md new file mode 100644 index 0000000..959e530 --- /dev/null +++ b/CRITICAL_TEMPERATURE_BUG.md @@ -0,0 +1,155 @@ +# Temperature Preservation in L0 Regularization: Critical Implementation Notes + +## Executive Summary + +The temperature parameter (β) in Hard Concrete distributions is **critical** for L0 regularization performance. Our analysis reveals that: + +1. **The original authors' implementation** (`distributions.py`) contains a bug where temperature is incorrectly dropped in deterministic mode +2. 
**Our standalone implementations** (`calibration.py` and `sparse.py`) correctly preserve temperature in all modes +3. **The gold standard** (`l0_louizos_improved_gate.py`) confirms temperature must always be preserved + +## The Temperature Bug in distributions.py + +### What We Found + +In `distributions.py` (from the authors' repository), the deterministic gates incorrectly drop temperature: + +```python +# distributions.py line 134 - INCORRECT +def _deterministic_gates(self) -> torch.Tensor: + probs = torch.sigmoid(self.qz_logits) # ❌ Missing temperature! + gates = probs * (self.zeta - self.gamma) + self.gamma + return torch.clamp(gates, 0, 1) +``` + +### Why This Matters + +The temperature parameter controls the "hardness" of the concrete distribution: +- **Lower temperature** (e.g., 0.1): More discrete, binary-like gates +- **Higher temperature** (e.g., 2.0): Softer, more continuous gates + +Dropping temperature is equivalent to setting β=1, which fundamentally changes the distribution's behavior and can severely impact: +- Convergence speed +- Final sparsity levels +- Model performance + +## Correct Implementations + +### Gold Standard (l0_louizos_improved_gate.py) + +Your validated implementation consistently uses temperature: + +```python +# Sampling (line 55) +X = (torch.log(u) - torch.log(1 - u) + log_alpha) / beta # ✅ + +# Deterministic (line 151) +z_final = ((log_alpha / beta).sigmoid() * (zeta - gamma) + gamma).clamp(0, 1) # ✅ + +# Penalty computation (line 133) +c = -beta * torch.log(torch.tensor(-gamma / zeta)) # ✅ +pi = torch.sigmoid(log_alpha + c) +``` + +### calibration.py (Standalone, Correct) + +```python +# Sampling (lines 160-164) +def _sample_gates(self) -> torch.Tensor: + u = torch.rand_like(self.log_alpha).clamp(eps, 1 - eps) + s = (torch.log(u) - torch.log(1 - u) + self.log_alpha) / self.beta # ✅ + s = torch.sigmoid(s) + s_bar = s * (self.zeta - self.gamma) + self.gamma + return s_bar.clamp(0, 1) + +# Deterministic (lines 167-170) +def get_deterministic_gates(self) -> torch.Tensor: + s = torch.sigmoid(self.log_alpha / self.beta) # ✅ + s_bar = s * (self.zeta - self.gamma) + self.gamma + return s_bar.clamp(0, 1) + +# Penalty (lines 233-237) +def get_l0_penalty(self) -> torch.Tensor: + c = -self.beta * torch.log(torch.tensor(-self.gamma / self.zeta)) # ✅ + pi = torch.sigmoid(self.log_alpha + c) + return pi.sum() +``` + +### sparse.py (Standalone, Correct) + +```python +# Sampling (lines 113-120) +def _sample_gates(self) -> torch.Tensor: + u = torch.rand_like(self.log_alpha).clamp(eps, 1 - eps) + X = (torch.log(u) - torch.log(1 - u) + self.log_alpha) / self.beta # ✅ + s = torch.sigmoid(X) + s_bar = s * (self.zeta - self.gamma) + self.gamma + return s_bar.clamp(0, 1) + +# Deterministic (lines 122-127) +def get_deterministic_gates(self) -> torch.Tensor: + X = self.log_alpha / self.beta # ✅ + s = torch.sigmoid(X) + s_bar = s * (self.zeta - self.gamma) + self.gamma + return s_bar.clamp(0, 1) + +# Penalty (lines 172-182) +def get_l0_penalty(self) -> torch.Tensor: + c = -self.beta * torch.log(torch.tensor(-self.gamma / self.zeta)) # ✅ + pi = torch.sigmoid(self.log_alpha + c) + return pi.sum() +``` + +## Mathematical Correctness + +The Hard Concrete distribution requires temperature in three key places: + +### 1. Sampling (Training Mode) +The concrete distribution samples as: +``` +s = sigmoid((log(u) - log(1-u) + log_α) / β) +``` +Temperature β controls the sharpness of the sigmoid, essential for the reparameterization trick. + +### 2. 
Deterministic Gates (Inference Mode) +The mean of the distribution: +``` +s = sigmoid(log_α / β) +``` +**Must use the same temperature** as during training to maintain consistency. + +### 3. L0 Penalty Computation +The probability of a gate being active: +``` +P(gate > 0) = sigmoid(log_α - β * log(-γ/ζ)) +``` +Temperature affects the shift in log-odds space. + +## Practical Implications + +### For Survey Calibration (calibration.py) + +Your implementation with β=2/3 (from the paper) is correct. The temperature: +- Provides the right balance between exploration and exploitation +- Enables smooth gradient flow during optimization +- Allows fine control over sparsity levels + +### For Sparse Linear Models (sparse.py) + +The preserved temperature ensures: +- Proper feature selection dynamics +- Stable convergence to sparse solutions +- Consistency between training and inference + +## Recommendations + +1. **Continue using** `calibration.py` and `sparse.py` as standalone modules - they're correct +2. **Avoid importing** from `distributions.py` until the temperature bug is fixed +3. **Keep temperature** in the range [0.1, 2/3] for best results (lower = harder gates) +4. **Document this issue** when sharing code to prevent others from using the buggy version + +## Key Takeaway + +Your instinct about never dropping temperature was absolutely correct. The temperature parameter is fundamental to the Hard Concrete distribution's behavior, and dropping it (as in the authors' `distributions.py`) is a significant bug that can severely impact model performance. + +Both `calibration.py` and `sparse.py` correctly implement the Hard Concrete distribution with proper temperature preservation, making them reliable for production use. \ No newline at end of file diff --git a/l0/calibration.py b/l0/calibration.py index 14f6d25..3b2eb90 100644 --- a/l0/calibration.py +++ b/l0/calibration.py @@ -42,6 +42,9 @@ class SparseCalibrationWeights(nn.Module): log_weight_jitter_sd : float Standard deviation of noise added to log weights at start of fit() to break symmetry. Set to 0 to disable jitter. Default is 0.0 (no jitter). + log_alpha_jitter_sd : float + Standard deviation of noise added to log_alpha at initialization to break symmetry. + Set to 0 to disable jitter. Default is 0.01 (following Louizos et al.). 
device : str or torch.device Device to run computations on ('cpu' or 'cuda') """ @@ -55,6 +58,7 @@ def __init__( init_keep_prob: float | np.ndarray = 0.5, init_weights: float | np.ndarray | None = None, log_weight_jitter_sd: float = 0.0, + log_alpha_jitter_sd: float = 0.01, device: str | torch.device = "cpu", ): super().__init__() @@ -63,6 +67,7 @@ def __init__( self.gamma = gamma self.zeta = zeta self.log_weight_jitter_sd = log_weight_jitter_sd + self.log_alpha_jitter_sd = log_alpha_jitter_sd self.device = torch.device(device) # Initialize weights (on original scale) @@ -84,7 +89,7 @@ def __init__( f"init_weights array must have shape ({n_features},), " f"got {weight_values.shape}" ) - + # Convert to log space to ensure positivity via exp transformation # Add small epsilon to avoid log(0) self.log_weight = nn.Parameter(torch.log(weight_values + 1e-8)) @@ -108,13 +113,18 @@ def __init__( ) # Clip to valid probability range to avoid log(0) or log(inf) keep_prob_values = keep_prob_values.clamp(1e-6, 1 - 1e-6) - + # Convert probabilities to log_alpha mu = torch.log(keep_prob_values / (1 - keep_prob_values)) - # Add small jitter to break symmetry - self.log_alpha = nn.Parameter( - mu + torch.randn(n_features, device=self.device) * 0.01 - ) + # Add jitter to break symmetry (if specified) + if self.log_alpha_jitter_sd > 0: + jitter = ( + torch.randn(n_features, device=self.device) + * self.log_alpha_jitter_sd + ) + self.log_alpha = nn.Parameter(mu + jitter) + else: + self.log_alpha = nn.Parameter(mu) # Cache for sparse tensor conversion self._cached_M_torch: torch.sparse.Tensor | None = None @@ -358,7 +368,10 @@ def fit( # Add jitter to weights to break symmetry (if jitter_sd > 0) if self.log_weight_jitter_sd > 0: with torch.no_grad(): - jitter = torch.randn_like(self.log_weight) * self.log_weight_jitter_sd + jitter = ( + torch.randn_like(self.log_weight) + * self.log_weight_jitter_sd + ) self.log_weight.data += jitter # Setup optimizer diff --git a/tests/test_calibration.py b/tests/test_calibration.py index dc5c518..af01c3a 100644 --- a/tests/test_calibration.py +++ b/tests/test_calibration.py @@ -426,79 +426,81 @@ def test_group_wise_averaging_edge_cases(self): def test_init_weights_options(self): """Test different weight initialization options.""" N = 50 - + # Test 1: Default (None) should give all weights = 1.0 model_default = SparseCalibrationWeights(n_features=N) with torch.no_grad(): weights = torch.exp(model_default.log_weight) assert torch.allclose(weights, torch.ones(N), atol=1e-6) - + # Test 2: Scalar initialization model_scalar = SparseCalibrationWeights(n_features=N, init_weights=2.5) with torch.no_grad(): weights = torch.exp(model_scalar.log_weight) assert torch.allclose(weights, torch.full((N,), 2.5), atol=1e-6) - + # Test 3: Array initialization init_array = np.random.uniform(0.5, 2.0, size=N) - model_array = SparseCalibrationWeights(n_features=N, init_weights=init_array) + model_array = SparseCalibrationWeights( + n_features=N, init_weights=init_array + ) with torch.no_grad(): weights = torch.exp(model_array.log_weight).cpu().numpy() np.testing.assert_allclose(weights, init_array, rtol=1e-5) - + # Test 4: Wrong shape should raise error with pytest.raises(ValueError, match="must have shape"): SparseCalibrationWeights(n_features=N, init_weights=np.ones(N + 1)) - + def test_weight_jitter(self): """Test that weight jitter works correctly.""" N = 100 Q = 20 - + np.random.seed(42) torch.manual_seed(42) - + M = sp.random(Q, N, density=0.3, format="csr") y = np.random.randn(Q) + 10 - + 
# Model with jitter model_with_jitter = SparseCalibrationWeights( - n_features=N, - init_weights=1.0, - log_weight_jitter_sd=0.5 + n_features=N, init_weights=1.0, log_weight_jitter_sd=0.5 ) - + # Store initial weights initial_weights = model_with_jitter.log_weight.data.clone() - + # Fit should add jitter model_with_jitter.fit(M, y, epochs=10, verbose=False) - + # Weights should have changed due to jitter (and training) final_weights = model_with_jitter.log_weight.data assert not torch.allclose(initial_weights, final_weights) - + # Model without jitter torch.manual_seed(42) # Reset seed model_no_jitter = SparseCalibrationWeights( n_features=N, init_weights=1.0, - log_weight_jitter_sd=0.0 # No jitter + log_weight_jitter_sd=0.0, # No jitter ) - + initial_weights_no_jitter = model_no_jitter.log_weight.data.clone() model_no_jitter.fit(M, y, epochs=1, verbose=False) # Just 1 epoch - + # After 1 epoch, change should be small without jitter weights_after_1_epoch = model_no_jitter.log_weight.data # The change is due to gradient updates only - change = (weights_after_1_epoch - initial_weights_no_jitter).abs().max() + change = ( + (weights_after_1_epoch - initial_weights_no_jitter).abs().max() + ) assert change < 1.0, "Without jitter, initial change should be small" def test_init_keep_prob_options(self): """Test init_keep_prob as scalar and array.""" n_features = 20 - + # Test 1: Scalar init_keep_prob (existing behavior) model_scalar = SparseCalibrationWeights( n_features=n_features, @@ -510,7 +512,7 @@ def test_init_keep_prob_options(self): log_alphas = model_scalar.log_alpha.numpy() # Check they're all close to expected value (within jitter range) assert np.all(np.abs(log_alphas - expected_mu) < 0.1) - + # Test 2: Array init_keep_prob keep_probs = np.linspace(0.1, 0.9, n_features) model_array = SparseCalibrationWeights( @@ -523,14 +525,14 @@ def test_init_keep_prob_options(self): expected_mus = np.log(keep_probs / (1 - keep_probs)) # Check each is close to its expected value assert np.all(np.abs(log_alphas - expected_mus) < 0.1) - + # Test 3: Wrong shape should raise error with pytest.raises(ValueError, match="must have shape"): SparseCalibrationWeights( n_features=10, init_keep_prob=np.ones(5), # Wrong size ) - + # Test 4: Edge case probabilities get clamped extreme_probs = np.array([0.0, 0.5, 1.0]) model_extreme = SparseCalibrationWeights( From 9f51ce9a9150229ece9dc471d1229a3c917ce5b7 Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Tue, 9 Sep 2025 22:57:23 -0400 Subject: [PATCH 8/8] Add changelog entry and fix yaml-changelog version requirement --- changelog_entry.yaml | 11 +++++++++++ pyproject.toml | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29..53bdd45 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,11 @@ +- bump: minor + changes: + added: + - Jitter parameter for improved numerical stability in calibration layers + - More flexible parameter initialization options for calibration + - Python 3.11 support (previously required 3.13) + changed: + - Enhanced argument handling in calibration module + - Improved parameter initialization defaults + fixed: + - Numerical stability issues in edge cases during calibration \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 3b33f60..32e794e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dev = [ "black>=22.0", "build>=1.3", "scikit-learn>=1.0", - "yaml-changelog>=0.3.1", + "yaml-changelog>=0.3.0", 
"twine>=4.0.0", ] docs = [