Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/site-snapshot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ jobs:
run: |
uv run --extra dev --with pydantic --with-editable ../microplex pytest -q \
tests/test_package_imports.py \
tests/test_calibration_harness.py \
tests/targets/test_supabase.py \
tests/pipelines/test_check_site_snapshot.py \
tests/pipelines/test_imputation_ablation.py \
Expand Down
77 changes: 73 additions & 4 deletions src/microplex_us/calibration_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import numpy as np
import pandas as pd
Expand All @@ -11,6 +12,8 @@
FilterOperator,
TargetAggregation,
TargetFilter,
TargetProvider,
TargetQuery,
TargetSpec,
)

Expand All @@ -19,7 +22,10 @@
TargetLevel,
TargetRegistry,
get_registry,
target_available_in_cps,
target_category,
target_group_name,
target_level,
target_requires_imputation,
)

Expand Down Expand Up @@ -59,10 +65,50 @@ def summary(self) -> str:
class CalibrationHarness:
"""Harness for calibration experiments over one entity frame at a time."""

def __init__(self, registry: TargetRegistry | None = None):
self.registry = registry or get_registry()
def __init__(
    self,
    registry: TargetRegistry | None = None,
    *,
    target_provider: TargetProvider | None = None,
):
    """Initialise the harness with either a registry or an explicit provider.

    When ``target_provider`` is omitted, the registry (defaulting to the
    global one) doubles as the provider; otherwise the given provider is
    used as-is and ``registry`` is stored untouched (it may be ``None``).
    """
    if target_provider is not None:
        self.registry = registry
        self.target_provider = target_provider
    else:
        self.registry = registry or get_registry()
        self.target_provider = self.registry
    # Completed experiments, keyed by experiment name.
    self._results: dict[str, CalibrationResult] = {}

def select_targets(
    self,
    *,
    categories: list[TargetCategory] | None = None,
    levels: list[TargetLevel] | None = None,
    groups: list[str] | None = None,
    only_available: bool = False,
    entity: EntityType | str | None = None,
    period: int | str | None = None,
    provider_filters: dict[str, Any] | None = None,
) -> list[TargetSpec]:
    """Select canonical targets from the configured provider."""
    # Provider-side narrowing: period, entity, and backend-specific filters
    # are passed through the core TargetQuery.
    loaded = self.target_provider.load_target_set(
        TargetQuery(
            period=period,
            entity=entity,
            provider_filters=dict(provider_filters or {}),
        )
    )
    # US-specific metadata filters are applied client-side.
    selected: list[TargetSpec] = []
    for candidate in loaded.targets:
        if _matches_us_target_filters(
            candidate,
            categories=categories,
            levels=levels,
            groups=groups,
            only_available=only_available,
        ):
            selected.append(candidate)
    return selected

def get_target_vector(
self,
df: pd.DataFrame,
Expand Down Expand Up @@ -202,15 +248,19 @@ def run_experiment(
groups: list[str] | None = None,
only_available: bool = False,
entity: EntityType | str | None = None,
period: int | str | None = None,
provider_filters: dict[str, Any] | None = None,
**calibrate_kwargs,
) -> CalibrationResult:
"""Run a calibration experiment over a filtered target subset."""
selected = self.registry.select_targets(
selected = self.select_targets(
categories=categories,
levels=levels,
groups=groups,
only_available=only_available,
entity=entity,
period=period,
provider_filters=provider_filters,
)
selected = [
target
Expand Down Expand Up @@ -262,7 +312,7 @@ def print_target_coverage(
print("TARGET COVERAGE ANALYSIS")
print("=" * 70)

all_targets = self.registry.select_targets(entity=entity)
all_targets = self.select_targets(entity=entity)
columns = set(df.columns)

available: list[TargetSpec] = []
Expand Down Expand Up @@ -388,6 +438,25 @@ def _weight_stats(weights: np.ndarray) -> dict[str, float]:
}


def _matches_us_target_filters(
    target: TargetSpec,
    *,
    categories: list[TargetCategory] | None = None,
    levels: list[TargetLevel] | None = None,
    groups: list[str] | None = None,
    only_available: bool = False,
) -> bool:
    """Return True when *target* passes every supplied US metadata filter.

    Each filter is skipped when its argument is falsy (``None`` or empty),
    matching the registry's select semantics.
    """
    # Pair each allow-list with the accessor that extracts the matching
    # metadata value from the target.
    membership_checks = (
        (categories, target_category),
        (levels, target_level),
        (groups, target_group_name),
    )
    for allowed, extract in membership_checks:
        if allowed and extract(target) not in allowed:
            return False
    # Availability is a simple boolean gate rather than a membership test.
    return not only_available or target_available_in_cps(target)


def _build_constraint_row(df: pd.DataFrame, spec: TargetSpec) -> np.ndarray:
if spec.aggregation is TargetAggregation.MEAN:
raise NotImplementedError("Mean targets are not supported by this harness")
Expand Down
51 changes: 51 additions & 0 deletions tests/targets/test_supabase.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from dataclasses import dataclass
from typing import Any

import pandas as pd
import pytest
from microplex.core import EntityType
from microplex.targets import FilterOperator, TargetAggregation, TargetQuery

from microplex_us.calibration_harness import CalibrationHarness
from microplex_us.supabase_targets import (
SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY,
SUPABASE_TARGET_TYPE_KEY,
Expand Down Expand Up @@ -393,3 +395,52 @@ def test_load_target_set_filters_rows_with_core_query(

assert [target.name for target in target_set.targets] == ["employment_income"]
assert calls[0]["params"]["period"] == "eq.2024"


def test_calibration_harness_can_use_supabase_target_provider(
    provider: SupabaseTargetProvider,
    request_queue,
) -> None:
    """Harness calibrates against targets fetched via the Supabase provider."""
    # Queue one supported row and one row the column map cannot handle; the
    # provider filters are expected to drop the unsupported one.
    rows = [
        {
            "id": "target-1",
            "variable": "employment_income",
            "value": 30,
            "target_type": "amount",
            "period": 2024,
            "source": {"name": "IRS SOI", "institution": "IRS"},
            "stratum": {"name": "National", "jurisdiction": "us"},
        },
        {
            "id": "target-2",
            "variable": "unknown_cash_income",
            "value": 100,
            "target_type": "amount",
            "period": 2024,
            "source": {"name": "Unknown", "institution": "Other"},
            "stratum": {"name": "National", "jurisdiction": "us"},
        },
    ]
    request_queue(rows)

    df = pd.DataFrame(
        {
            "employment_income": [10.0, 20.0],
            "weight": [1.0, 1.0],
        }
    )
    harness = CalibrationHarness(target_provider=provider)

    result = harness.run_experiment(
        df,
        "supabase_income",
        categories=[TargetCategory.INCOME],
        only_available=True,
        period=2024,
        provider_filters={"include_unsupported": False},
        entity=EntityType.PERSON,
        verbose=False,
    )

    # Only the supported target survives, and the frame already sums to its
    # value (10 + 20 == 30), so the calibration error is exactly zero.
    assert result.targets_used == ["employment_income"]
    assert result.errors == {"employment_income": 0.0}
63 changes: 62 additions & 1 deletion tests/test_calibration_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
import numpy as np
import pandas as pd
from microplex.core import EntityType
from microplex.targets import TargetAggregation, TargetFilter, TargetSpec
from microplex.targets import (
StaticTargetProvider,
TargetAggregation,
TargetFilter,
TargetSet,
TargetSpec,
)

from microplex_us.calibration_harness import CalibrationHarness
from microplex_us.target_registry import (
Expand Down Expand Up @@ -101,3 +107,58 @@ def test_run_experiment_filters_to_selected_canonical_targets(self):

assert result.targets_used == ["ca_people", "ca_income"]
np.testing.assert_allclose(result.weights, np.ones(3))

def test_run_experiment_can_use_core_target_provider(self):
    """run_experiment works when targets come from a core TargetProvider."""
    # An extra 2025 target that the period=2024 query should exclude.
    future_target = TargetSpec(
        name="future_income",
        entity=EntityType.PERSON,
        value=50.0,
        period=2025,
        measure="employment_income",
        aggregation=TargetAggregation.SUM,
        metadata={
            "us_category": "income",
            "us_level": "national",
            "us_group": "future",
            "available_in_cps": True,
            "requires_imputation": False,
        },
    )
    all_targets = [*_make_registry().get_all_targets(), future_target]
    harness = CalibrationHarness(
        target_provider=StaticTargetProvider(TargetSet(all_targets))
    )
    frame = pd.DataFrame(
        {
            "state_fips": ["06", "06", "08"],
            "employment_income": [10.0, 20.0, 5.0],
            "weight": [1.0, 1.0, 1.0],
        }
    )

    result = harness.run_experiment(
        frame,
        "provider_people_only",
        groups=["people"],
        only_available=True,
        period=2024,
        entity=EntityType.PERSON,
        verbose=False,
    )

    assert result.targets_used == ["ca_people", "ca_income"]

def test_print_target_coverage_can_use_core_target_provider(self, capsys):
    """print_target_coverage works when targets come from a core provider."""
    registry_targets = _make_registry().get_all_targets()
    harness = CalibrationHarness(
        target_provider=StaticTargetProvider(TargetSet(registry_targets))
    )
    frame = pd.DataFrame(
        {
            "state_fips": ["06", "06", "08"],
            "employment_income": [10.0, 20.0, 5.0],
            "weight": [1.0, 1.0, 1.0],
        }
    )

    harness.print_target_coverage(frame, entity=EntityType.PERSON)

    captured = capsys.readouterr().out
    assert "Available (2 targets)" in captured
Loading