diff --git a/changelog_entry.yaml b/changelog_entry.yaml
index e69de29..a5c937b 100644
--- a/changelog_entry.yaml
+++ b/changelog_entry.yaml
@@ -0,0 +1,5 @@
+- bump: minor
+  changes:
+    added:
+      - Weight distribution diagnostics during training showing buckets (<0.01, 0.01-0.1, 0.1-1, 1-10, 10-1000, >1000)
+      - use_gates parameter to optionally disable L0 gates in SparseCalibrationWeights
diff --git a/l0/calibration.py b/l0/calibration.py
index a1e08dd..2fcf516 100644
--- a/l0/calibration.py
+++ b/l0/calibration.py
@@ -60,6 +60,7 @@ def __init__(
         log_weight_jitter_sd: float = 0.0,
         log_alpha_jitter_sd: float = 0.01,
         device: str | torch.device = "cpu",
+        use_gates: bool = True,
     ):
         super().__init__()
         self.n_features = n_features
@@ -69,6 +70,7 @@ def __init__(
         self.log_weight_jitter_sd = log_weight_jitter_sd
         self.log_alpha_jitter_sd = log_alpha_jitter_sd
         self.device = torch.device(device)
+        self.use_gates = use_gates

         # Initialize weights (on original scale)
         if init_weights is None:
@@ -193,6 +195,10 @@ def get_weights(self, deterministic: bool = False) -> torch.Tensor:
         torch.Tensor
             Positive calibration weights with L0 sparsity applied
         """
+        if not self.use_gates:
+            # No gates - just return positive weights
+            return torch.exp(self.log_weight)
+
         # Sample or get deterministic gates
         if deterministic:
             gates = self.get_deterministic_gates()
@@ -453,6 +459,46 @@ def fit(
                         1 - active_info["count"] / self.n_features
                     )

+                    # Weight distribution diagnostic
+                    if active_weights.numel() > 0:
+                        w_tiny = (active_weights < 0.01).sum().item()
+                        w_small = (
+                            ((active_weights >= 0.01) & (active_weights < 0.1))
+                            .sum()
+                            .item()
+                        )
+                        w_med = (
+                            ((active_weights >= 0.1) & (active_weights < 1.0))
+                            .sum()
+                            .item()
+                        )
+                        w_normal = (
+                            ((active_weights >= 1.0) & (active_weights < 10.0))
+                            .sum()
+                            .item()
+                        )
+                        w_large = (
+                            (
+                                (active_weights >= 10.0)
+                                & (active_weights < 1000.0)
+                            )
+                            .sum()
+                            .item()
+                        )
+                        w_huge = (active_weights >= 1000.0).sum().item()
+
+                        total_active = active_weights.numel()
+                        weight_dist = (
+                            f"[<0.01: {100*w_tiny/total_active:.1f}%, "
+                            f"0.01-0.1: {100*w_small/total_active:.1f}%, "
+                            f"0.1-1: {100*w_med/total_active:.1f}%, "
+                            f"1-10: {100*w_normal/total_active:.1f}%, "
+                            f"10-1000: {100*w_large/total_active:.1f}%, "
+                            f">1000: {100*w_huge/total_active:.1f}%]"
+                        )
+                    else:
+                        weight_dist = "[no active weights]"
+
                     # Calculate components of the actual loss being minimized
                     actual_data_loss = data_loss.item()
                     actual_l0_loss = l0_loss.item()
@@ -464,7 +510,8 @@ def fit(
                             f"mean_group_mare={mean_group_mare:.4%}, "
                             f"max_error={max_rel_err:.1%}, "
                             f"total_loss={actual_total_loss:.3f}, "
-                            f"active={active_info['count']:4d}/{self.n_features} ({sparsity_pct:.1f}% sparse)"
+                            f"active={active_info['count']:4d}/{self.n_features} ({sparsity_pct:.1f}% sparse)\n"
+                            f"    Weight dist: {weight_dist}"
                         )
                     else:
                         print(
@@ -472,7 +519,8 @@ def fit(
                             f"mean_error={mean_rel_err:.4%}, "
                             f"max_error={max_rel_err:.1%}, "
                             f"total_loss={actual_total_loss:.3f}, "
-                            f"active={active_info['count']:4d}/{self.n_features} ({sparsity_pct:.1f}% sparse)"
+                            f"active={active_info['count']:4d}/{self.n_features} ({sparsity_pct:.1f}% sparse)\n"
+                            f"    Weight dist: {weight_dist}"
                         )

         return self
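
A minimal sketch of how the new `use_gates` flag might be exercised. It assumes the remaining `SparseCalibrationWeights` constructor arguments are optional and keep their defaults, and that `log_weight` has shape `(n_features,)`; only `n_features`, `device`, and `use_gates` are confirmed by the diff above.

```python
import torch

from l0.calibration import SparseCalibrationWeights

# With use_gates=False, get_weights() skips L0 gate sampling entirely
# and returns exp(log_weight), so no feature is gated to zero.
model = SparseCalibrationWeights(
    n_features=4,
    device="cpu",
    use_gates=False,
)

weights = model.get_weights(deterministic=True)
assert weights.shape == (4,)
assert bool((weights > 0).all())  # exp() guarantees strict positivity
```

With `use_gates=True` (the default) behavior is unchanged, so existing callers are unaffected; independently of the flag, each progress report during `fit()` gains a second `Weight dist: [...]` line summarizing active weights across the six buckets listed in the changelog entry.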