Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
- bump: minor
changes:
added:
- Weight distribution diagnostics printed during training, showing the percentage of active weights in each bucket (<0.01, 0.01-0.1, 0.1-1, 1-10, 10-1000, >1000)
- use_gates parameter (default True) to optionally disable L0 gates in SparseCalibrationWeights; when False, get_weights returns the ungated positive weights
52 changes: 50 additions & 2 deletions l0/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
log_weight_jitter_sd: float = 0.0,
log_alpha_jitter_sd: float = 0.01,
device: str | torch.device = "cpu",
use_gates: bool = True,
):
super().__init__()
self.n_features = n_features
Expand All @@ -69,6 +70,7 @@ def __init__(
self.log_weight_jitter_sd = log_weight_jitter_sd
self.log_alpha_jitter_sd = log_alpha_jitter_sd
self.device = torch.device(device)
self.use_gates = use_gates

# Initialize weights (on original scale)
if init_weights is None:
Expand Down Expand Up @@ -193,6 +195,10 @@ def get_weights(self, deterministic: bool = False) -> torch.Tensor:
torch.Tensor
Positive calibration weights with L0 sparsity applied
"""
if not self.use_gates:
# No gates - just return positive weights
return torch.exp(self.log_weight)

# Sample or get deterministic gates
if deterministic:
gates = self.get_deterministic_gates()
Expand Down Expand Up @@ -453,6 +459,46 @@ def fit(
1 - active_info["count"] / self.n_features
)

# Weight distribution diagnostic
if active_weights.numel() > 0:
w_tiny = (active_weights < 0.01).sum().item()
w_small = (
((active_weights >= 0.01) & (active_weights < 0.1))
.sum()
.item()
)
w_med = (
((active_weights >= 0.1) & (active_weights < 1.0))
.sum()
.item()
)
w_normal = (
((active_weights >= 1.0) & (active_weights < 10.0))
.sum()
.item()
)
w_large = (
(
(active_weights >= 10.0)
& (active_weights < 1000.0)
)
.sum()
.item()
)
w_huge = (active_weights >= 1000.0).sum().item()

total_active = active_weights.numel()
weight_dist = (
f"[<0.01: {100*w_tiny/total_active:.1f}%, "
f"0.01-0.1: {100*w_small/total_active:.1f}%, "
f"0.1-1: {100*w_med/total_active:.1f}%, "
f"1-10: {100*w_normal/total_active:.1f}%, "
f"10-1000: {100*w_large/total_active:.1f}%, "
f">1000: {100*w_huge/total_active:.1f}%]"
)
else:
weight_dist = "[no active weights]"

# Calculate components of the actual loss being minimized
actual_data_loss = data_loss.item()
actual_l0_loss = l0_loss.item()
Expand All @@ -464,15 +510,17 @@ def fit(
f"mean_group_mare={mean_group_mare:.4%}, "
f"max_error={max_rel_err:.1%}, "
f"total_loss={actual_total_loss:.3f}, "
f"active={active_info['count']:4d}/{self.n_features} ({sparsity_pct:.1f}% sparse)"
f"active={active_info['count']:4d}/{self.n_features} ({sparsity_pct:.1f}% sparse)\n"
f" Weight dist: {weight_dist}"
)
else:
print(
f"Epoch {epoch+1:4d}: "
f"mean_error={mean_rel_err:.4%}, "
f"max_error={max_rel_err:.1%}, "
f"total_loss={actual_total_loss:.3f}, "
f"active={active_info['count']:4d}/{self.n_features} ({sparsity_pct:.1f}% sparse)"
f"active={active_info['count']:4d}/{self.n_features} ({sparsity_pct:.1f}% sparse)\n"
f" Weight dist: {weight_dist}"
)

return self
Expand Down