diff --git a/.github/workflows/site-snapshot.yml b/.github/workflows/site-snapshot.yml index 76dd0e1..cab4b56 100644 --- a/.github/workflows/site-snapshot.yml +++ b/.github/workflows/site-snapshot.yml @@ -41,6 +41,7 @@ jobs: run: | uv run --extra dev --with pydantic --with-editable ../microplex pytest -q \ tests/test_package_imports.py \ + tests/test_calibration_harness.py \ tests/targets/test_supabase.py \ tests/pipelines/test_check_site_snapshot.py \ tests/pipelines/test_imputation_ablation.py \ diff --git a/src/microplex_us/calibration_harness.py b/src/microplex_us/calibration_harness.py index ae9653b..4066389 100644 --- a/src/microplex_us/calibration_harness.py +++ b/src/microplex_us/calibration_harness.py @@ -3,6 +3,7 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Any import numpy as np import pandas as pd @@ -11,6 +12,8 @@ FilterOperator, TargetAggregation, TargetFilter, + TargetProvider, + TargetQuery, TargetSpec, ) @@ -19,7 +22,10 @@ TargetLevel, TargetRegistry, get_registry, + target_available_in_cps, target_category, + target_group_name, + target_level, target_requires_imputation, ) @@ -59,10 +65,50 @@ def summary(self) -> str: class CalibrationHarness: """Harness for calibration experiments over one entity frame at a time.""" - def __init__(self, registry: TargetRegistry | None = None): - self.registry = registry or get_registry() + def __init__( + self, + registry: TargetRegistry | None = None, + *, + target_provider: TargetProvider | None = None, + ): + if target_provider is None: + self.registry = registry or get_registry() + self.target_provider = self.registry + else: + self.registry = registry + self.target_provider = target_provider self._results: dict[str, CalibrationResult] = {} + def select_targets( + self, + *, + categories: list[TargetCategory] | None = None, + levels: list[TargetLevel] | None = None, + groups: list[str] | None = None, + only_available: bool = False, + entity: EntityType | str | None = None, + period: int | str | None = None, + provider_filters: dict[str, Any] | None = None, + ) -> list[TargetSpec]: + """Select canonical targets from the configured provider.""" + query = TargetQuery( + period=period, + entity=entity, + provider_filters=dict(provider_filters or {}), + ) + targets = self.target_provider.load_target_set(query).targets + return [ + target + for target in targets + if _matches_us_target_filters( + target, + categories=categories, + levels=levels, + groups=groups, + only_available=only_available, + ) + ] + def get_target_vector( self, df: pd.DataFrame, @@ -202,15 +248,19 @@ def run_experiment( groups: list[str] | None = None, only_available: bool = False, entity: EntityType | str | None = None, + period: int | str | None = None, + provider_filters: dict[str, Any] | None = None, **calibrate_kwargs, ) -> CalibrationResult: """Run a calibration experiment over a filtered target subset.""" - selected = self.registry.select_targets( + selected = self.select_targets( categories=categories, levels=levels, groups=groups, only_available=only_available, entity=entity, + period=period, + provider_filters=provider_filters, ) selected = [ target @@ -262,7 +312,7 @@ def print_target_coverage( print("TARGET COVERAGE ANALYSIS") print("=" * 70) - all_targets = self.registry.select_targets(entity=entity) + all_targets = self.select_targets(entity=entity) columns = set(df.columns) available: list[TargetSpec] = [] @@ -388,6 +438,25 @@ def _weight_stats(weights: np.ndarray) -> dict[str, float]: } +def _matches_us_target_filters( + target: TargetSpec, + *, + categories: list[TargetCategory] | None = None, + levels: list[TargetLevel] | None = None, + groups: list[str] | None = None, + only_available: bool = False, +) -> bool: + if categories and target_category(target) not in categories: + return False + if levels and target_level(target) not in levels: + return False + if groups and target_group_name(target) not in groups: + return False + if only_available and not target_available_in_cps(target): + return False + return True + + def _build_constraint_row(df: pd.DataFrame, spec: TargetSpec) -> np.ndarray: if spec.aggregation is TargetAggregation.MEAN: raise NotImplementedError("Mean targets are not supported by this harness") diff --git a/tests/targets/test_supabase.py b/tests/targets/test_supabase.py index ead8ca9..e02a30f 100644 --- a/tests/targets/test_supabase.py +++ b/tests/targets/test_supabase.py @@ -5,10 +5,12 @@ from dataclasses import dataclass from typing import Any +import pandas as pd import pytest from microplex.core import EntityType from microplex.targets import FilterOperator, TargetAggregation, TargetQuery +from microplex_us.calibration_harness import CalibrationHarness from microplex_us.supabase_targets import ( SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY, SUPABASE_TARGET_TYPE_KEY, @@ -393,3 +395,52 @@ def test_load_target_set_filters_rows_with_core_query( assert [target.name for target in target_set.targets] == ["employment_income"] assert calls[0]["params"]["period"] == "eq.2024" + + +def test_calibration_harness_can_use_supabase_target_provider( + provider: SupabaseTargetProvider, + request_queue, +) -> None: + request_queue( + [ + { + "id": "target-1", + "variable": "employment_income", + "value": 30, + "target_type": "amount", + "period": 2024, + "source": {"name": "IRS SOI", "institution": "IRS"}, + "stratum": {"name": "National", "jurisdiction": "us"}, + }, + { + "id": "target-2", + "variable": "unknown_cash_income", + "value": 100, + "target_type": "amount", + "period": 2024, + "source": {"name": "Unknown", "institution": "Other"}, + "stratum": {"name": "National", "jurisdiction": "us"}, + }, + ] + ) + harness = CalibrationHarness(target_provider=provider) + frame = pd.DataFrame( + { + "employment_income": [10.0, 20.0], + "weight": [1.0, 1.0], + } + ) + + result = harness.run_experiment( + frame, + "supabase_income", + categories=[TargetCategory.INCOME], + only_available=True, + period=2024, + provider_filters={"include_unsupported": False}, + entity=EntityType.PERSON, + verbose=False, + ) + + assert result.targets_used == ["employment_income"] + assert result.errors == {"employment_income": 0.0} diff --git a/tests/test_calibration_harness.py b/tests/test_calibration_harness.py index 64bc044..bdd5cc4 100644 --- a/tests/test_calibration_harness.py +++ b/tests/test_calibration_harness.py @@ -3,7 +3,13 @@ import numpy as np import pandas as pd from microplex.core import EntityType -from microplex.targets import TargetAggregation, TargetFilter, TargetSpec +from microplex.targets import ( + StaticTargetProvider, + TargetAggregation, + TargetFilter, + TargetSet, + TargetSpec, +) from microplex_us.calibration_harness import CalibrationHarness from microplex_us.target_registry import ( @@ -101,3 +107,58 @@ def test_run_experiment_filters_to_selected_canonical_targets(self): assert result.targets_used == ["ca_people", "ca_income"] np.testing.assert_allclose(result.weights, np.ones(3)) + + def test_run_experiment_can_use_core_target_provider(self): + targets = _make_registry().get_all_targets() + [ + TargetSpec( + name="future_income", + entity=EntityType.PERSON, + value=50.0, + period=2025, + measure="employment_income", + aggregation=TargetAggregation.SUM, + metadata={ + "us_category": "income", + "us_level": "national", + "us_group": "future", + "available_in_cps": True, + "requires_imputation": False, + }, + ) + ] + provider = StaticTargetProvider(TargetSet(targets)) + harness = CalibrationHarness(target_provider=provider) + df = pd.DataFrame( + { + "state_fips": ["06", "06", "08"], + "employment_income": [10.0, 20.0, 5.0], + "weight": [1.0, 1.0, 1.0], + } + ) + + result = harness.run_experiment( + df, + "provider_people_only", + groups=["people"], + only_available=True, + period=2024, + entity=EntityType.PERSON, + verbose=False, + ) + + assert result.targets_used == ["ca_people", "ca_income"] + + def test_print_target_coverage_can_use_core_target_provider(self, capsys): + provider = StaticTargetProvider(TargetSet(_make_registry().get_all_targets())) + harness = CalibrationHarness(target_provider=provider) + df = pd.DataFrame( + { + "state_fips": ["06", "06", "08"], + "employment_income": [10.0, 20.0, 5.0], + "weight": [1.0, 1.0, 1.0], + } + ) + + harness.print_target_coverage(df, entity=EntityType.PERSON) + + assert "Available (2 targets)" in capsys.readouterr().out