From 86597ccf012cd51d3c78fcd5d892ca28d8edf346 Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Sun, 28 Sep 2025 18:39:47 -0400 Subject: [PATCH] better logging --- README.md | 11 ++ hierarchical_penalty.md | 223 ++++++++++++++++++++++++++++++++++++++ l0/calibration.py | 8 +- tests/test_calibration.py | 193 +++++++++++++++++++++++++++++++++ 4 files changed, 431 insertions(+), 4 deletions(-) create mode 100644 hierarchical_penalty.md diff --git a/README.md b/README.md index e7065eb..aed2b3c 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,17 @@ A PyTorch implementation of L0 regularization for neural network sparsification and intelligent sampling, based on [Louizos, Welling, & Kingma (2017)](https://arxiv.org/abs/1712.01312). +This method is considered a more faithful interpretation of L0 regularization because it directly integrates a differentiable approximation of the L0 norm into the training objective, allowing the model to learn which weights should be zero as part of the optimization process. Simply setting small weights to zero is a post-training heuristic that is disconnected from the learning objective. + +This Paper's Approach (Principled Optimization) 🧠: + +The method creates a + +- differentiable surrogate for the L0 norm. It achieves this by introducing stochastic gates that control whether a weight is active. +- The objective function is modified to minimize both the task-specific error and the expected number of "on" gates. +- By using a special "hard concrete" distribution, the gates can become exactly zero during training while still allowing gradients to flow. +- This means the network actively learns a sparse structure that balances performance and complexity from the very beginning of training. + ## Features - **Hard Concrete Distribution**: Differentiable approximation of L0 norm diff --git a/hierarchical_penalty.md b/hierarchical_penalty.md new file mode 100644 index 0000000..966fe98 --- /dev/null +++ b/hierarchical_penalty.md @@ -0,0 +1,223 @@ +# Hierarchical Penalty for Geographic Calibration + +This is describing a TODO + +## Problem Statement + +When calibrating weights at a granular geographic level (e.g., congressional districts), we want to ensure that aggregations to higher geographic levels (states, national) remain consistent with known totals at those levels. Instead of adding redundant rows to the calibration matrix, we implement this as a penalty term in the loss function. + +## Mathematical Formulation + +### Base Problem + +Given: +- Matrix `X` of shape `(n_targets, n_features)` where each row represents a geographic-specific target +- Target values `T` of length `n_targets` +- Weights `w` of length `n_features` + +The base loss function is: +``` +L_orig(w) = Σ_i ((X_i·w - T_i) / T_i)² +``` + +### Hierarchical Penalty + +We add a penalty term that measures consistency at aggregate levels: + +``` +P(w) = Σ_agg ((Σ_j∈agg X_j·w - Σ_j∈agg T_j) / Σ_j∈agg T_j)² +``` + +Where `agg` represents each aggregation level (e.g., each state, national total). + +The new loss function becomes: +``` +L_new(w) = L_orig(w) + λ·P(w) +``` + +## Implementation Specification + +### Geography Mapping Structure + +The `geography_mapping` should be a dictionary with the following structure: + +```python +geography_mapping = { + 'hierarchy': { + 'cd_to_state': { + '0101': '01', # CD 0101 belongs to state 01 + '0102': '01', # CD 0102 belongs to state 01 + '0201': '02', # CD 0201 belongs to state 02 + # ... for all CDs + }, + 'state_to_nation': { + '01': 'US', + '02': 'US', + # ... all states map to US + } + }, + 'target_indices': { + '0101': [0, 1, 2, ...], # Indices in X/T for CD 0101 + '0102': [50, 51, ...], # Indices in X/T for CD 0102 + # ... for all geographic units + }, + 'aggregation_targets': { + 'state': { + '01': { + 'indices': [1000, 1001, ...], # Where state 01's targets would be + 'values': [100000, 200000, ...] # Actual state-level target values + }, + '02': {...}, + # ... for all states + }, + 'national': { + 'US': { + 'indices': [2000, 2001, ...], # Where national targets would be + 'values': [5000000, 10000000, ...] # Actual national target values + } + } + } +} +``` + +### Alternative Simpler Structure + +A simpler mapping structure that just handles geographic aggregation: + +```python +geography_mapping = { + 'cd_to_state': { + '0101': '01', + '0102': '01', + '0201': '02', + # ... for all CDs + }, + 'target_groups': { + # Group indices that should sum to the same aggregate + # Each tuple is (target_indices, aggregate_target_value) + 'state_01_pop': ([0, 50, 100], 1234567), # Indices for pop targets in CDs of state 01 + 'state_01_snap': ([1, 51, 101], 45678), # Indices for SNAP in CDs of state 01 + 'state_02_pop': ([150, 200], 2345678), # Indices for pop in CDs of state 02 + # ... + 'national_pop': ([0, 50, 100, 150, 200, ...], 300000000), # All pop targets + 'national_snap': ([1, 51, 101, 151, 201, ...], 50000000), # All SNAP targets + } +} +``` + +### Function Signature + +```python +def add_hierarchical_penalty( + loss_function, + X: sparse.csr_matrix, + targets: np.ndarray, + geography_mapping: dict, + lambda_state: float = 1.0, + lambda_national: float = 1.0 +) -> callable: + """ + Wraps a loss function to add hierarchical consistency penalties. + + Args: + loss_function: Base loss function(w, X, targets) -> scalar + X: Calibration matrix (n_targets x n_features) + targets: Target values (n_targets,) + geography_mapping: Geographic hierarchy and target mappings + lambda_state: Weight for state-level consistency penalty + lambda_national: Weight for national-level consistency penalty + + Returns: + New loss function with hierarchical penalties + """ +``` + +## Implementation Details + +### Computing State-Level Penalties + +For each state: +1. Identify all CD target indices belonging to that state +2. Compute predicted sum: `state_pred = Σ(X[cd_indices] @ w)` +3. Compute target sum: `state_target = Σ(targets[cd_indices])` +4. Compute penalty: `((state_pred - state_target) / state_target)²` + +### Computing National Penalty + +1. Compute predicted sum: `national_pred = Σ(X @ w)` +2. Compute target sum: `national_target = Σ(targets)` +3. Compute penalty: `((national_pred - national_target) / national_target)²` + +### Gradient Computation + +The gradient of the penalty term with respect to w: + +For state s: +``` +∂P_s/∂w = 2 * ((Σ_j∈s X_j·w - Σ_j∈s T_j) / (Σ_j∈s T_j)²) * Σ_j∈s X_j +``` + +For national: +``` +∂P_nat/∂w = 2 * ((Σ X·w - Σ T) / (Σ T)²) * Σ X +``` + +### Efficient Implementation + +To avoid recomputing aggregations: +1. Pre-compute aggregation matrices `X_state` and `X_national` where each row is the sum of relevant CD rows +2. Pre-compute aggregate targets `T_state` and `T_national` +3. Then the penalty computation becomes simple matrix operations + +```python +# Precompute once +X_states = [] # Each row is sum of CDs for that state +T_states = [] # Corresponding state targets + +for state in states: + cd_indices = get_cd_indices(state) + X_states.append(X[cd_indices].sum(axis=0)) + T_states.append(targets[cd_indices].sum()) + +X_states = sparse.vstack(X_states) +T_states = np.array(T_states) + +# During optimization +state_preds = X_states @ w +state_penalties = ((state_preds - T_states) / T_states) ** 2 +``` + +## Usage Example + +```python +from l0.hierarchical import add_hierarchical_penalty + +# Set up geography mapping +geography_mapping = create_geography_mapping(targets_df) + +# Create penalized loss function +penalized_loss = add_hierarchical_penalty( + original_loss, + X_sparse, + targets, + geography_mapping, + lambda_state=0.1, + lambda_national=0.05 +) + +# Use in optimization +model.fit(X, y, loss_fn=penalized_loss) +``` + +## Benefits + +1. **No matrix expansion**: Don't need to add redundant rows for state/national targets +2. **Tunable enforcement**: Lambda parameters control strictness of hierarchical consistency +3. **Efficient computation**: Aggregations can be pre-computed +4. **Flexible hierarchy**: Can handle arbitrary geographic hierarchies (regions, divisions, etc.) + +## Considerations + +1. **Lambda tuning**: May need cross-validation to find optimal lambda values +2. **Different lambdas per variable**: Some variables (e.g., population) might need stricter consistency than others (e.g., income) +3. **Weighted penalties**: Could weight penalties by the importance/reliability of aggregate targets diff --git a/l0/calibration.py b/l0/calibration.py index 3b2eb90..a1e08dd 100644 --- a/l0/calibration.py +++ b/l0/calibration.py @@ -444,9 +444,9 @@ def fit( rel_errors[group_mask].mean().item() ) group_losses.append(group_mean_err) - mean_group_loss = np.mean(group_losses) + mean_group_mare = np.mean(group_losses) else: - mean_group_loss = mean_rel_err + mean_group_mare = mean_rel_err # Calculate sparsity percentage sparsity_pct = 100 * ( @@ -461,7 +461,7 @@ def fit( if target_groups is not None: print( f"Epoch {epoch+1:4d}: " - f"mean_group_loss={mean_group_loss:.1%}, " + f"mean_group_mare={mean_group_mare:.4%}, " f"max_error={max_rel_err:.1%}, " f"total_loss={actual_total_loss:.3f}, " f"active={active_info['count']:4d}/{self.n_features} ({sparsity_pct:.1f}% sparse)" @@ -469,7 +469,7 @@ def fit( else: print( f"Epoch {epoch+1:4d}: " - f"mean_error={mean_rel_err:.1%}, " + f"mean_error={mean_rel_err:.4%}, " f"max_error={max_rel_err:.1%}, " f"total_loss={actual_total_loss:.3f}, " f"active={active_info['count']:4d}/{self.n_features} ({sparsity_pct:.1f}% sparse)" diff --git a/tests/test_calibration.py b/tests/test_calibration.py index af01c3a..51a006f 100644 --- a/tests/test_calibration.py +++ b/tests/test_calibration.py @@ -247,6 +247,199 @@ def test_l2_regularization(self): weights_with_l2.max() <= weights_no_l2.max() * 2.0 ), "L2 should prevent extreme weights" + def test_pure_l2_penalty(self): + """Test pure L2 regularization without L0.""" + N = 50 + Q = 30 + + np.random.seed(42) + torch.manual_seed(42) + + # Create underdetermined problem where L2 helps regularize + M = sp.random(Q, N, density=0.5, format="csr") + M.data = np.abs(M.data) * 2 # Scale up to create larger weights + y = np.ones(Q) * 10 # Simple target + + # Train without any regularization + model_no_reg = SparseCalibrationWeights( + n_features=N, + init_weights=2.0, # Start with larger weights + init_keep_prob=0.999 # Keep all weights active + ) + model_no_reg.fit( + M, y, + lambda_l0=0.0, # No L0 + lambda_l2=0.0, # No L2 + epochs=1000, + lr=0.02, + verbose=False + ) + + # Train with strong L2 regularization + model_l2_only = SparseCalibrationWeights( + n_features=N, + init_weights=2.0, # Same starting point + init_keep_prob=0.999 # Keep all weights active + ) + model_l2_only.fit( + M, y, + lambda_l0=0.0, # No L0 + lambda_l2=10.0, # Very strong L2 to see clear effect + epochs=1000, + lr=0.02, + verbose=False + ) + + with torch.no_grad(): + weights_no_reg = model_no_reg.get_weights(deterministic=True) + weights_l2 = model_l2_only.get_weights(deterministic=True) + + # Get non-zero weights + active_no_reg = weights_no_reg[weights_no_reg > 1e-8] + active_l2 = weights_l2[weights_l2 > 1e-8] + + # L2 should keep most weights active (no sparsity from L2) + assert len(active_l2) >= N * 0.9, f"Pure L2 should not induce sparsity, got {len(active_l2)}/{N}" + + # L2 norm should be smaller with regularization + l2_norm_no_reg = (weights_no_reg ** 2).sum().sqrt() + l2_norm_with_l2 = (weights_l2 ** 2).sum().sqrt() + assert l2_norm_with_l2 < l2_norm_no_reg, \ + f"L2 regularization should reduce L2 norm: {l2_norm_with_l2:.2f} vs {l2_norm_no_reg:.2f}" + + # Check coefficient of variation (CV = std/mean) is lower with L2 + # This captures that L2 shrinks weights toward each other + if active_no_reg.mean() > 1e-6 and active_l2.mean() > 1e-6: + cv_no_reg = active_no_reg.std() / active_no_reg.mean() + cv_l2 = active_l2.std() / active_l2.mean() + assert cv_l2 < cv_no_reg * 1.2, \ + f"L2 should reduce relative variation: CV {cv_l2:.2f} vs {cv_no_reg:.2f}" + + # L2 should prevent extreme weights + assert weights_l2.max() < weights_no_reg.max() * 1.5, \ + f"L2 should limit max weights: {weights_l2.max():.2f} vs {weights_no_reg.max():.2f}" + + # Both should still fit reasonably well + y_pred_no_reg = model_no_reg.predict(M).cpu().numpy() + y_pred_l2 = model_l2_only.predict(M).cpu().numpy() + + error_no_reg = np.abs((y - y_pred_no_reg) / y).mean() + error_l2 = np.abs((y - y_pred_l2) / y).mean() + + # L2 model may have slightly worse fit due to regularization + assert error_l2 < 0.5, f"L2 model should still fit reasonably: {error_l2:.3f}" + # But the trade-off is worth it for regularization + + def test_l0_l2_combination(self): + """Test that combining L0 and L2 gives both sparsity and regularization.""" + N = 100 # features + Q = 50 # targets + + np.random.seed(42) + torch.manual_seed(42) + + # Create problem with potential for overfitting + M = sp.random(Q, N, density=0.4, format="csr") + M.data = np.abs(M.data) * 3 + y = np.random.uniform(5, 20, size=Q) + + # Model 1: Only L0 (sparsity without weight regularization) + model_l0_only = SparseCalibrationWeights( + n_features=N, + init_weights=2.0, + init_keep_prob=0.5 # Start with 50% probability + ) + model_l0_only.fit( + M, y, + lambda_l0=0.01, # Stronger L0 for sparsity + lambda_l2=0.0, # No L2 + epochs=2000, + lr=0.02, + verbose=False + ) + + # Model 2: Only L2 (weight regularization without sparsity) + model_l2_only = SparseCalibrationWeights( + n_features=N, + init_weights=2.0, + init_keep_prob=0.999 # Keep all weights active + ) + model_l2_only.fit( + M, y, + lambda_l0=0.0, # No L0 + lambda_l2=0.1, # Moderate L2 + epochs=1500, + lr=0.02, + verbose=False + ) + + # Model 3: Combined L0+L2 (both sparsity and weight regularization) + model_l0_l2 = SparseCalibrationWeights( + n_features=N, + init_weights=2.0, + init_keep_prob=0.5 # Same starting point as L0-only + ) + model_l0_l2.fit( + M, y, + lambda_l0=0.01, # Same L0 as model 1 + lambda_l2=0.1, # Add L2 regularization + epochs=2000, + lr=0.02, + verbose=False + ) + + with torch.no_grad(): + # Get weights and stats for all models + weights_l0_only = model_l0_only.get_weights(deterministic=True) + weights_l2_only = model_l2_only.get_weights(deterministic=True) + weights_l0_l2 = model_l0_l2.get_weights(deterministic=True) + + # Count active weights + active_l0_only = (weights_l0_only > 1e-6).sum().item() + active_l2_only = (weights_l2_only > 1e-6).sum().item() + active_l0_l2 = (weights_l0_l2 > 1e-6).sum().item() + + # L0-only should have sparsity + assert active_l0_only < N * 0.8, f"L0 should induce sparsity: {active_l0_only}/{N} active" + + # L2-only should have no/little sparsity + assert active_l2_only > N * 0.9, f"L2 alone shouldn't induce sparsity: {active_l2_only}/{N} active" + + # L0+L2 should have sparsity (from L0) + assert active_l0_l2 < N * 0.8, f"L0+L2 should have sparsity: {active_l0_l2}/{N} active" + + # Among active weights, L0+L2 should have smaller magnitudes than L0-only (from L2) + active_mask_l0_only = weights_l0_only > 1e-6 + active_mask_l0_l2 = weights_l0_l2 > 1e-6 + + if active_mask_l0_only.any() and active_mask_l0_l2.any(): + # Compare L2 norms of active weights + l2_norm_l0_only = (weights_l0_only[active_mask_l0_only] ** 2).sum().sqrt() + l2_norm_l0_l2 = (weights_l0_l2[active_mask_l0_l2] ** 2).sum().sqrt() + + # L0+L2 should have smaller weight norms than L0-only + # (L2 regularization effect on top of sparsity) + assert l2_norm_l0_l2 < l2_norm_l0_only * 1.2, \ + f"L0+L2 should have controlled weights: {l2_norm_l0_l2:.2f} vs L0-only {l2_norm_l0_only:.2f}" + + # Check prediction quality for all models + y_pred_l0_only = model_l0_only.predict(M).cpu().numpy() + y_pred_l2_only = model_l2_only.predict(M).cpu().numpy() + y_pred_l0_l2 = model_l0_l2.predict(M).cpu().numpy() + + error_l0_only = np.abs((y - y_pred_l0_only) / (y + 1)).mean() + error_l2_only = np.abs((y - y_pred_l2_only) / (y + 1)).mean() + error_l0_l2 = np.abs((y - y_pred_l0_l2) / (y + 1)).mean() + + # All should fit reasonably well (L2-only may have slightly worse fit due to regularization) + assert error_l0_only < 0.35, f"L0-only should fit well: {error_l0_only:.3f}" + assert error_l2_only < 0.35, f"L2-only should fit well: {error_l2_only:.3f}" + assert error_l0_l2 < 0.35, f"L0+L2 should fit well: {error_l0_l2:.3f}" + + print(f"\nL0-only: {active_l0_only}/{N} active, error={error_l0_only:.3f}") + print(f"L2-only: {active_l2_only}/{N} active, error={error_l2_only:.3f}") + print(f"L0+L2: {active_l0_l2}/{N} active, error={error_l0_l2:.3f}") + def test_group_wise_averaging(self): """Test that group-wise averaging balances loss contributions.""" N = 100 # features (households)