Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
- bump: minor
changes:
added:
- Group-wise loss averaging for calibration to balance contributions from targets with different cardinalities
- Improved training output with meaningful error percentages and sparsity statistics
changed:
- Simplified active weight detection in SparseCalibrationWeights (removed threshold parameter)
- Enhanced verbose output during calibration training to show relative errors and sparsity percentage
119 changes: 96 additions & 23 deletions l0/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,26 +218,20 @@ def get_sparsity(self) -> float:
"""
with torch.no_grad():
gates = self.get_deterministic_gates()
return (gates < 0.01).float().mean().item()
return (gates == 0).float().mean().item()

def get_active_weights(self, threshold: float = 0.01) -> dict:
def get_active_weights(self) -> dict:
"""
Get indices and values of active (non-zero) weights.

Parameters
----------
threshold : float
Gate values below this are considered zero

Returns
-------
dict
Dictionary with 'indices' and 'values' of active weights
"""
with torch.no_grad():
weights = self.get_weights(deterministic=True)
gates = self.get_deterministic_gates()
active_mask = gates > threshold
active_mask = weights > 0

return {
"indices": torch.where(active_mask)[0],
Expand All @@ -256,6 +250,7 @@ def fit(
loss_type: str = "mse",
verbose: bool = False,
verbose_freq: int = 100,
target_groups: np.ndarray | None = None,
) -> "SparseCalibrationWeights":
"""
Fit calibration weights using gradient descent.
Expand All @@ -280,6 +275,10 @@ def fit(
Whether to print progress
verbose_freq : int
How often to print progress
target_groups : numpy.ndarray, optional
Array of group IDs for each target. Targets in the same group
will be averaged together so each group contributes equally to loss.
If None, all targets are treated independently.

Returns
-------
Expand All @@ -292,6 +291,27 @@ def fit(
# Convert M to torch sparse (will be cached)
M_torch = self._convert_sparse_to_torch(M)

# Compute group weights for loss averaging
if target_groups is not None:
# Convert to tensor
target_groups = torch.tensor(
target_groups, dtype=torch.long, device=self.device
)

# Calculate group weights: 1 / group_size for each target
unique_groups = torch.unique(target_groups)
group_weights = torch.zeros_like(y)

for group_id in unique_groups:
group_mask = target_groups == group_id
group_size = group_mask.sum().item()
# Each target in the group gets weight 1/group_size
# so the group's total contribution is 1
group_weights[group_mask] = 1.0 / group_size
else:
# No grouping - all targets weighted equally
group_weights = torch.ones_like(y)

# Initialize weights
nn.init.normal_(self.log_weight, 0, 0.5)

Expand All @@ -303,15 +323,25 @@ def fit(
# Forward pass
y_pred = self.forward(M, deterministic=False)

# Compute loss
# Compute loss with group weighting
if loss_type == "relative":
# Relative error: (y - y_pred)^2 / (y + 1)^2
# Adding 1 to avoid division by zero
relative_errors = (y - y_pred) / (y + 1)
data_loss = relative_errors.pow(2).mean()
# Apply group weights and then average
weighted_squared_errors = (
relative_errors.pow(2) * group_weights
)
data_loss = (
weighted_squared_errors.sum()
) # Sum because weights already normalize
else:
# Standard MSE
data_loss = (y - y_pred).pow(2).mean()
# Standard MSE with group weighting
squared_errors = (y - y_pred).pow(2)
weighted_squared_errors = squared_errors * group_weights
data_loss = (
weighted_squared_errors.sum()
) # Sum because weights already normalize

l0_loss = self.get_l0_penalty()
loss = data_loss + lambda_l0 * l0_loss
Expand All @@ -331,18 +361,61 @@ def fit(
with torch.no_grad():
active_info = self.get_active_weights()
weights = self.get_weights(deterministic=True)
# Compute MSE for monitoring even if using relative loss
mse = (y - y_pred).pow(2).mean().item()
print(
f"Epoch {epoch+1:4d}: "
f"loss={loss.item():.4f}, "
f"data_loss={data_loss.item():.4f}, "
f"mse={mse:.4f}, "
f"l0={l0_loss.item():.2f}, "
f"active={active_info['count']}, "
f"mean_weight={weights[weights > 0.01].mean().item() if (weights > 0.01).any() else 0:.3f}"
active_weights = weights[weights > 0]

# Compute relative errors for meaningful output
y_det = self.forward(M, deterministic=True)
if loss_type == "relative":
rel_errors = torch.abs((y - y_det) / (y + 1))
else:
# For MSE, show relative errors anyway for interpretability
rel_errors = torch.abs((y - y_det) / (y + 1))

# For reporting, we can show both overall and group-averaged errors
mean_rel_err = rel_errors.mean().item()
max_rel_err = rel_errors.max().item()

# Compute mean group loss if groups are used
if target_groups is not None:
# Calculate mean loss per group
group_losses = []
for group_id in torch.unique(target_groups):
group_mask = target_groups == group_id
group_mean_err = (
rel_errors[group_mask].mean().item()
)
group_losses.append(group_mean_err)
mean_group_loss = np.mean(group_losses)
else:
mean_group_loss = mean_rel_err

# Calculate sparsity percentage
sparsity_pct = 100 * (
1 - active_info["count"] / self.n_features
)

# Calculate components of the actual loss being minimized
actual_data_loss = data_loss.item()
actual_l0_loss = l0_loss.item()
actual_total_loss = loss.item()

if target_groups is not None:
print(
f"Epoch {epoch+1:4d}: "
f"mean_group_loss={mean_group_loss:.1%}, "
f"max_error={max_rel_err:.1%}, "
f"total_loss={actual_total_loss:.3f}, "
f"active={active_info['count']:4d}/{self.n_features} ({sparsity_pct:.1f}% sparse)"
)
else:
print(
f"Epoch {epoch+1:4d}: "
f"mean_error={mean_rel_err:.1%}, "
f"max_error={max_rel_err:.1%}, "
f"total_loss={actual_total_loss:.3f}, "
f"active={active_info['count']:4d}/{self.n_features} ({sparsity_pct:.1f}% sparse)"
)

return self

def predict(self, M: sp.spmatrix) -> torch.Tensor:
Expand Down
Loading