diff --git a/changelog_entry.yaml b/changelog_entry.yaml
index e69de29..d5177fc 100644
--- a/changelog_entry.yaml
+++ b/changelog_entry.yaml
@@ -0,0 +1,6 @@
+- bump: minor
+  changes:
+    added:
+      - Per-group loss multipliers via group_multipliers parameter in SparseCalibrationWeights.fit()
+      - normalize_groups parameter to control within-group normalization independently of target_groups
+      - Verbose echo of group weighting config at start of fit()
diff --git a/l0/calibration.py b/l0/calibration.py
index 2fcf516..2a02996 100644
--- a/l0/calibration.py
+++ b/l0/calibration.py
@@ -310,6 +310,8 @@ def fit(
         verbose: bool = False,
         verbose_freq: int = 100,
         target_groups: np.ndarray | None = None,
+        normalize_groups: bool = True,
+        group_multipliers: dict[int, float] | None = None,
     ) -> "SparseCalibrationWeights":
         """
         Fit calibration weights using gradient descent.
@@ -338,6 +340,12 @@
             Array of group IDs for each target. Targets in the same group
             will be averaged together so each group contributes equally
             to loss. If None, all targets are treated independently.
+        normalize_groups : bool
+            Whether to normalize within groups so each group contributes
+            equally to the loss. Default True (backward-compatible).
+        group_multipliers : dict[int, float], optional
+            Per-group loss scaling factors. Applied after normalization
+            (if enabled). Requires target_groups to be set.
 
         Returns
         -------
@@ -350,27 +358,53 @@
         # Convert M to torch sparse (will be cached)
         M_torch = self._convert_sparse_to_torch(M)
 
+        # Validate group_multipliers
+        if group_multipliers is not None and target_groups is None:
+            raise ValueError(
+                "group_multipliers requires target_groups to be set"
+            )
+
         # Compute group weights for loss averaging
         if target_groups is not None:
-            # Convert to tensor
             target_groups = torch.tensor(
                 target_groups, dtype=torch.long, device=self.device
            )
-
-            # Calculate group weights: 1 / group_size for each target
             unique_groups = torch.unique(target_groups)
-            group_weights = torch.zeros_like(y)
-
-            for group_id in unique_groups:
-                group_mask = target_groups == group_id
-                group_size = group_mask.sum().item()
-                # Each target in the group gets weight 1/group_size
-                # so the group's total contribution is 1
-                group_weights[group_mask] = 1.0 / group_size
+
+            if normalize_groups:
+                group_weights = torch.zeros_like(y)
+                for group_id in unique_groups:
+                    group_mask = target_groups == group_id
+                    group_size = group_mask.sum().item()
+                    group_weights[group_mask] = 1.0 / group_size
+            else:
+                group_weights = torch.ones_like(y)
+
+            if group_multipliers is not None:
+                for gid, mult in group_multipliers.items():
+                    group_mask = target_groups == gid
+                    if not group_mask.any():
+                        raise ValueError(
+                            f"group_multipliers key {gid} not found "
+                            f"in target_groups"
+                        )
+                    group_weights[group_mask] *= mult
         else:
-            # No grouping - all targets weighted equally
             group_weights = torch.ones_like(y)
 
+        if verbose:
+            if target_groups is not None:
+                n_groups = len(torch.unique(target_groups))
+                parts = [f"{n_groups} groups"]
+                parts.append(
+                    f"normalize={'on' if normalize_groups else 'off'}"
+                )
+                if group_multipliers is not None:
+                    parts.append(f"multipliers={group_multipliers}")
+                print(f"Groups: {', '.join(parts)}")
+            else:
+                print("Groups: none (uniform weights)")
+
         # Add jitter to weights to break symmetry (if jitter_sd > 0)
         if self.log_weight_jitter_sd > 0:
             with torch.no_grad():
diff --git a/tests/test_calibration.py b/tests/test_calibration.py
index 2df75df..cd8869b 100644
--- a/tests/test_calibration.py
+++ b/tests/test_calibration.py
@@ -769,3 +769,224 @@ def test_init_keep_prob_options(self):
         log_alphas = model_extreme.log_alpha.numpy()
         # Should not have inf or -inf values
         assert np.all(np.isfinite(log_alphas))
+
+    def test_group_multipliers_without_normalization(self):
+        """Multipliers work with normalize_groups=False."""
+        N = 100
+        Q = 39  # 3 singletons + 18 + 18
+
+        np.random.seed(42)
+        torch.manual_seed(42)
+
+        M = sp.random(Q, N, density=0.3, format="csr")
+        y_singletons = np.array([1e9, 5e8, 2e9])
+        y_group1 = np.random.uniform(1e3, 1e6, size=18)
+        y_group2 = np.random.uniform(1e3, 1e6, size=18)
+        y = np.concatenate([y_singletons, y_group1, y_group2])
+
+        target_groups = np.array([0, 1, 2] + [3] * 18 + [4] * 18)
+
+        # Without multipliers (baseline)
+        model_base = SparseCalibrationWeights(n_features=N)
+        model_base.fit(
+            M,
+            y,
+            lambda_l0=0.0001,
+            lr=0.1,
+            epochs=500,
+            loss_type="relative",
+            verbose=False,
+            target_groups=target_groups,
+            normalize_groups=False,
+        )
+
+        # With 10x multiplier on singletons
+        model_mult = SparseCalibrationWeights(n_features=N)
+        model_mult.fit(
+            M,
+            y,
+            lambda_l0=0.0001,
+            lr=0.1,
+            epochs=500,
+            loss_type="relative",
+            verbose=False,
+            target_groups=target_groups,
+            normalize_groups=False,
+            group_multipliers={0: 10.0, 1: 10.0, 2: 10.0},
+        )
+
+        with torch.no_grad():
+            y_pred_base = model_base.predict(M).cpu().numpy()
+            y_pred_mult = model_mult.predict(M).cpu().numpy()
+
+        # Singleton relative errors
+        sing_err_base = np.abs(
+            (y[:3] - y_pred_base[:3]) / (y[:3] + 1)
+        ).mean()
+        sing_err_mult = np.abs(
+            (y[:3] - y_pred_mult[:3]) / (y[:3] + 1)
+        ).mean()
+
+        assert sing_err_mult < sing_err_base * 1.1, (
+            f"10x multiplier should improve singleton fit: "
+            f"{sing_err_mult:.4f} vs {sing_err_base:.4f}"
+        )
+
+    def test_group_multipliers_with_normalization(self):
+        """Multipliers compose with normalization."""
+        N = 100
+        Q = 39
+
+        np.random.seed(42)
+        torch.manual_seed(42)
+
+        M = sp.random(Q, N, density=0.3, format="csr")
+        y_singletons = np.array([1e9, 5e8, 2e9])
+        y_group1 = np.random.uniform(1e3, 1e6, size=18)
+        y_group2 = np.random.uniform(1e3, 1e6, size=18)
+        y = np.concatenate([y_singletons, y_group1, y_group2])
+
+        target_groups = np.array([0, 1, 2] + [3] * 18 + [4] * 18)
+
+        # With normalization only (baseline)
+        model_norm = SparseCalibrationWeights(n_features=N)
+        model_norm.fit(
+            M,
+            y,
+            lambda_l0=0.0001,
+            lr=0.1,
+            epochs=500,
+            loss_type="relative",
+            verbose=False,
+            target_groups=target_groups,
+            normalize_groups=True,
+        )
+
+        # With normalization + 5x multiplier on group 3
+        model_norm_mult = SparseCalibrationWeights(n_features=N)
+        model_norm_mult.fit(
+            M,
+            y,
+            lambda_l0=0.0001,
+            lr=0.1,
+            epochs=500,
+            loss_type="relative",
+            verbose=False,
+            target_groups=target_groups,
+            normalize_groups=True,
+            group_multipliers={3: 5.0},
+        )
+
+        with torch.no_grad():
+            y_pred_norm = model_norm.predict(M).cpu().numpy()
+            y_pred_mult = model_norm_mult.predict(M).cpu().numpy()
+
+        # Group 3 error should improve with multiplier
+        g3_err_norm = np.abs(
+            (y[3:21] - y_pred_norm[3:21]) / (y[3:21] + 1)
+        ).mean()
+        g3_err_mult = np.abs(
+            (y[3:21] - y_pred_mult[3:21]) / (y[3:21] + 1)
+        ).mean()
+
+        assert g3_err_mult < g3_err_norm * 1.1, (
+            f"5x multiplier should improve group 3 fit: "
+            f"{g3_err_mult:.4f} vs {g3_err_norm:.4f}"
+        )
+
+    def test_group_multipliers_requires_target_groups(self):
+        """ValueError when group_multipliers set but target_groups is None."""
+        N = 50
+        Q = 10
+
+        M = sp.random(Q, N, density=0.3, format="csr")
+        y = np.random.uniform(100, 1000, size=Q)
+
+        model = SparseCalibrationWeights(n_features=N)
+        with pytest.raises(
+            ValueError,
+            match="group_multipliers requires target_groups",
+        ):
+            model.fit(
+                M,
+                y,
+                epochs=10,
+                verbose=False,
+                group_multipliers={0: 2.0},
+            )
+
+    def test_group_multipliers_invalid_group_id(self):
+        """ValueError when multiplier key doesn't exist in target_groups."""
+        N = 50
+        Q = 10
+
+        M = sp.random(Q, N, density=0.3, format="csr")
+        y = np.random.uniform(100, 1000, size=Q)
+        target_groups = np.zeros(Q, dtype=int)  # All group 0
+
+        model = SparseCalibrationWeights(n_features=N)
+        with pytest.raises(
+            ValueError,
+            match="group_multipliers key 99 not found",
+        ):
+            model.fit(
+                M,
+                y,
+                epochs=10,
+                verbose=False,
+                target_groups=target_groups,
+                group_multipliers={99: 2.0},
+            )
+
+    def test_normalize_groups_false_uniform_weights(self):
+        """normalize_groups=False with target_groups gives uniform weights."""
+        N = 50
+        Q = 10
+
+        np.random.seed(42)
+        torch.manual_seed(42)
+
+        M = sp.random(Q, N, density=0.3, format="csr")
+        y = np.random.uniform(100, 1000, size=Q)
+
+        # Without target_groups at all
+        model_none = SparseCalibrationWeights(n_features=N)
+        model_none.fit(
+            M,
+            y,
+            lambda_l0=0.0001,
+            lr=0.1,
+            epochs=500,
+            loss_type="relative",
+            verbose=False,
+            target_groups=None,
+        )
+
+        # With target_groups but normalize_groups=False (should be equivalent)
+        torch.manual_seed(42)
+        target_groups = np.array([0] * 5 + [1] * 5)
+        model_no_norm = SparseCalibrationWeights(n_features=N)
+        model_no_norm.fit(
+            M,
+            y,
+            lambda_l0=0.0001,
+            lr=0.1,
+            epochs=500,
+            loss_type="relative",
+            verbose=False,
+            target_groups=target_groups,
+            normalize_groups=False,
+        )
+
+        with torch.no_grad():
+            y_pred_none = model_none.predict(M).cpu().numpy()
+            y_pred_no_norm = model_no_norm.predict(M).cpu().numpy()
+
+        err_none = np.abs((y - y_pred_none) / (y + 1)).mean()
+        err_no_norm = np.abs((y - y_pred_no_norm) / (y + 1)).mean()
+
+        # Both should achieve similar error levels
+        assert abs(err_none - err_no_norm) < 0.1, (
+            f"normalize_groups=False should behave like no groups: "
+            f"{err_none:.4f} vs {err_no_norm:.4f}"
+        )
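
Usage sketch (not part of the diff above): a minimal example of the new fit() options, using only arguments that appear in this change. The import path (l0.calibration) is assumed from the file layout, and the hyperparameter values (lambda_l0, lr, epochs) are illustrative, not recommendations.

    import numpy as np
    import scipy.sparse as sp

    from l0.calibration import SparseCalibrationWeights  # assumed import path

    # Three single-target "groups" (0, 1, 2) and one five-target group (3).
    M = sp.random(8, 200, density=0.3, format="csr")
    y = np.concatenate([[1e9, 5e8, 2e9], np.full(5, 1e4)])
    target_groups = np.array([0, 1, 2, 3, 3, 3, 3, 3])

    model = SparseCalibrationWeights(n_features=200)
    model.fit(
        M,
        y,
        lambda_l0=0.0001,
        lr=0.1,
        epochs=500,
        loss_type="relative",
        target_groups=target_groups,
        normalize_groups=True,  # each group contributes equally to the loss
        group_multipliers={0: 10.0, 1: 10.0, 2: 10.0},  # then up-weight the singleton groups 10x
        verbose=True,  # echoes the group weighting config at the start of fit()
    )

With verbose=True, the snippet above would print something like "Groups: 4 groups, normalize=on, multipliers={0: 10.0, 1: 10.0, 2: 10.0}" before training, per the echo added in this change.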