6 changes: 6 additions & 0 deletions changelog_entry.yaml
@@ -0,0 +1,6 @@
- bump: minor
changes:
added:
- Per-group loss multipliers via group_multipliers parameter in SparseCalibrationWeights.fit()
- normalize_groups parameter to control within-group normalization independently of target_groups
- Verbose echo of group weighting config at start of fit()
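A minimal usage sketch of the two new fit() parameters, assuming the model and data setup used in the tests of this PR (M, y, and N are placeholders, not defined here):

    import numpy as np

    # Hypothetical setup: M is a sparse Q x N matrix, y a length-Q target vector.
    model = SparseCalibrationWeights(n_features=N)
    model.fit(
        M,
        y,
        target_groups=np.array([0, 0, 0, 1, 1]),  # group ID per target
        normalize_groups=True,        # each group contributes equally to the loss
        group_multipliers={0: 10.0},  # then scale group 0's loss 10x
    )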
58 changes: 46 additions & 12 deletions l0/calibration.py
@@ -310,6 +310,8 @@ def fit(
verbose: bool = False,
verbose_freq: int = 100,
target_groups: np.ndarray | None = None,
normalize_groups: bool = True,
group_multipliers: dict[int, float] | None = None,
) -> "SparseCalibrationWeights":
"""
Fit calibration weights using gradient descent.
@@ -338,6 +340,12 @@
Array of group IDs for each target. Targets in the same group
will be averaged together so each group contributes equally to loss.
If None, all targets are treated independently.
normalize_groups : bool
Whether to normalize within groups so each group contributes
equally to the loss. Default True (backward-compatible).
group_multipliers : dict[int, float], optional
Per-group loss scaling factors. Applied after normalization
(if enabled). Requires target_groups to be set.

Returns
-------
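To make the docstring's "applied after normalization" concrete, here is a runnable sketch of how the per-target loss weights are formed under normalize_groups=True plus a multiplier (the group IDs and multiplier value are illustrative, not from the PR):

    import torch

    y = torch.zeros(4)
    target_groups = torch.tensor([0, 0, 0, 1])  # group 0 has 3 targets, group 1 has 1

    # normalize_groups=True: each target gets weight 1/group_size
    group_weights = torch.zeros_like(y)
    for gid in torch.unique(target_groups):
        mask = target_groups == gid
        group_weights[mask] = 1.0 / mask.sum().item()

    # group_multipliers={0: 6.0} is applied after normalization
    group_weights[target_groups == 0] *= 6.0

    print(group_weights)  # tensor([2., 2., 2., 1.])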
@@ -350,27 +358,53 @@
# Convert M to torch sparse (will be cached)
M_torch = self._convert_sparse_to_torch(M)

# Validate group_multipliers
if group_multipliers is not None and target_groups is None:
raise ValueError(
"group_multipliers requires target_groups to be set"
)

# Compute group weights for loss averaging
if target_groups is not None:
# Convert to tensor
target_groups = torch.tensor(
target_groups, dtype=torch.long, device=self.device
)

# Calculate group weights: 1 / group_size for each target
unique_groups = torch.unique(target_groups)
group_weights = torch.zeros_like(y)

for group_id in unique_groups:
group_mask = target_groups == group_id
group_size = group_mask.sum().item()
# Each target in the group gets weight 1/group_size
# so the group's total contribution is 1
group_weights[group_mask] = 1.0 / group_size

if normalize_groups:
group_weights = torch.zeros_like(y)
for group_id in unique_groups:
group_mask = target_groups == group_id
group_size = group_mask.sum().item()
group_weights[group_mask] = 1.0 / group_size
else:
group_weights = torch.ones_like(y)

if group_multipliers is not None:
for gid, mult in group_multipliers.items():
group_mask = target_groups == gid
if not group_mask.any():
raise ValueError(
f"group_multipliers key {gid} not found "
f"in target_groups"
)
group_weights[group_mask] *= mult
else:
# No grouping - all targets weighted equally
group_weights = torch.ones_like(y)

if verbose:
if target_groups is not None:
n_groups = len(torch.unique(target_groups))
parts = [f"{n_groups} groups"]
parts.append(
f"normalize={'on' if normalize_groups else 'off'}"
)
if group_multipliers is not None:
parts.append(f"multipliers={group_multipliers}")
print(f"Groups: {', '.join(parts)}")
else:
print("Groups: none (uniform weights)")

# Add jitter to weights to break symmetry (if jitter_sd > 0)
if self.log_weight_jitter_sd > 0:
with torch.no_grad():
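For reference, the verbose echo added above prints a single line at the start of fit(). Assuming a run with five groups, normalization on, and the multiplier used in the tests below, it would read:

    Groups: 5 groups, normalize=on, multipliers={3: 5.0}

With target_groups=None it prints "Groups: none (uniform weights)" instead.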
221 changes: 221 additions & 0 deletions tests/test_calibration.py
@@ -769,3 +769,224 @@ def test_init_keep_prob_options(self):
log_alphas = model_extreme.log_alpha.numpy()
# Should not have inf or -inf values
assert np.all(np.isfinite(log_alphas))

def test_group_multipliers_without_normalization(self):
"""Multipliers work with normalize_groups=False."""
N = 100
Q = 39 # 3 singletons + 18 + 18

np.random.seed(42)
torch.manual_seed(42)

M = sp.random(Q, N, density=0.3, format="csr")
y_singletons = np.array([1e9, 5e8, 2e9])
y_group1 = np.random.uniform(1e3, 1e6, size=18)
y_group2 = np.random.uniform(1e3, 1e6, size=18)
y = np.concatenate([y_singletons, y_group1, y_group2])

target_groups = np.array([0, 1, 2] + [3] * 18 + [4] * 18)

# Without multipliers (baseline)
model_base = SparseCalibrationWeights(n_features=N)
model_base.fit(
M,
y,
lambda_l0=0.0001,
lr=0.1,
epochs=500,
loss_type="relative",
verbose=False,
target_groups=target_groups,
normalize_groups=False,
)

# With 10x multiplier on singletons
model_mult = SparseCalibrationWeights(n_features=N)
model_mult.fit(
M,
y,
lambda_l0=0.0001,
lr=0.1,
epochs=500,
loss_type="relative",
verbose=False,
target_groups=target_groups,
normalize_groups=False,
group_multipliers={0: 10.0, 1: 10.0, 2: 10.0},
)

with torch.no_grad():
y_pred_base = model_base.predict(M).cpu().numpy()
y_pred_mult = model_mult.predict(M).cpu().numpy()

# Singleton relative errors
sing_err_base = np.abs(
(y[:3] - y_pred_base[:3]) / (y[:3] + 1)
).mean()
sing_err_mult = np.abs(
(y[:3] - y_pred_mult[:3]) / (y[:3] + 1)
).mean()

assert sing_err_mult < sing_err_base * 1.1, (
f"10x multiplier should improve singleton fit: "
f"{sing_err_mult:.4f} vs {sing_err_base:.4f}"
)

def test_group_multipliers_with_normalization(self):
"""Multipliers compose with normalization."""
N = 100
Q = 39

np.random.seed(42)
torch.manual_seed(42)

M = sp.random(Q, N, density=0.3, format="csr")
y_singletons = np.array([1e9, 5e8, 2e9])
y_group1 = np.random.uniform(1e3, 1e6, size=18)
y_group2 = np.random.uniform(1e3, 1e6, size=18)
y = np.concatenate([y_singletons, y_group1, y_group2])

target_groups = np.array([0, 1, 2] + [3] * 18 + [4] * 18)

# With normalization only (baseline)
model_norm = SparseCalibrationWeights(n_features=N)
model_norm.fit(
M,
y,
lambda_l0=0.0001,
lr=0.1,
epochs=500,
loss_type="relative",
verbose=False,
target_groups=target_groups,
normalize_groups=True,
)

# With normalization + 5x multiplier on group 3
model_norm_mult = SparseCalibrationWeights(n_features=N)
model_norm_mult.fit(
M,
y,
lambda_l0=0.0001,
lr=0.1,
epochs=500,
loss_type="relative",
verbose=False,
target_groups=target_groups,
normalize_groups=True,
group_multipliers={3: 5.0},
)

with torch.no_grad():
y_pred_norm = model_norm.predict(M).cpu().numpy()
y_pred_mult = model_norm_mult.predict(M).cpu().numpy()

# Group 3 error should improve with multiplier
g3_err_norm = np.abs(
(y[3:21] - y_pred_norm[3:21]) / (y[3:21] + 1)
).mean()
g3_err_mult = np.abs(
(y[3:21] - y_pred_mult[3:21]) / (y[3:21] + 1)
).mean()

assert g3_err_mult < g3_err_norm * 1.1, (
f"5x multiplier should improve group 3 fit: "
f"{g3_err_mult:.4f} vs {g3_err_norm:.4f}"
)

def test_group_multipliers_requires_target_groups(self):
"""ValueError when group_multipliers set but target_groups is None."""
N = 50
Q = 10

M = sp.random(Q, N, density=0.3, format="csr")
y = np.random.uniform(100, 1000, size=Q)

model = SparseCalibrationWeights(n_features=N)
with pytest.raises(
ValueError,
match="group_multipliers requires target_groups",
):
model.fit(
M,
y,
epochs=10,
verbose=False,
group_multipliers={0: 2.0},
)

def test_group_multipliers_invalid_group_id(self):
"""ValueError when multiplier key doesn't exist in target_groups."""
N = 50
Q = 10

M = sp.random(Q, N, density=0.3, format="csr")
y = np.random.uniform(100, 1000, size=Q)
target_groups = np.zeros(Q, dtype=int) # All group 0

model = SparseCalibrationWeights(n_features=N)
with pytest.raises(
ValueError,
match="group_multipliers key 99 not found",
):
model.fit(
M,
y,
epochs=10,
verbose=False,
target_groups=target_groups,
group_multipliers={99: 2.0},
)

def test_normalize_groups_false_uniform_weights(self):
"""normalize_groups=False with target_groups gives uniform weights."""
N = 50
Q = 10

np.random.seed(42)
torch.manual_seed(42)

M = sp.random(Q, N, density=0.3, format="csr")
y = np.random.uniform(100, 1000, size=Q)

# Without target_groups at all
model_none = SparseCalibrationWeights(n_features=N)
model_none.fit(
M,
y,
lambda_l0=0.0001,
lr=0.1,
epochs=500,
loss_type="relative",
verbose=False,
target_groups=None,
)

# With target_groups but normalize_groups=False (should be equivalent)
torch.manual_seed(42)
target_groups = np.array([0] * 5 + [1] * 5)
model_no_norm = SparseCalibrationWeights(n_features=N)
model_no_norm.fit(
M,
y,
lambda_l0=0.0001,
lr=0.1,
epochs=500,
loss_type="relative",
verbose=False,
target_groups=target_groups,
normalize_groups=False,
)

with torch.no_grad():
y_pred_none = model_none.predict(M).cpu().numpy()
y_pred_no_norm = model_no_norm.predict(M).cpu().numpy()

err_none = np.abs((y - y_pred_none) / (y + 1)).mean()
err_no_norm = np.abs((y - y_pred_no_norm) / (y + 1)).mean()

# Both should achieve similar error levels
assert abs(err_none - err_no_norm) < 0.1, (
f"normalize_groups=False should behave like no groups: "
f"{err_none:.4f} vs {err_no_norm:.4f}"
)