From 0bf9f40a1102333e4fd64bcd2b47b32ac26c08b9 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sun, 26 Apr 2026 23:18:50 -0400 Subject: [PATCH 1/3] Expose axes in household calculators --- docs/households.md | 38 ++++++++ .../tax_benefit_models/common/__init__.py | 2 + .../tax_benefit_models/common/axes.py | 81 +++++++++++++++++ .../tax_benefit_models/uk/household.py | 40 +++++++-- .../tax_benefit_models/us/household.py | 39 +++++++- tests/test_household_impact.py | 90 +++++++++++++++++++ 6 files changed, 281 insertions(+), 9 deletions(-) create mode 100644 src/policyengine/tax_benefit_models/common/axes.py diff --git a/docs/households.md b/docs/households.md index cae9501a..29488727 100644 --- a/docs/households.md +++ b/docs/households.md @@ -103,6 +103,44 @@ result = pe.us.calculate_household( ) ``` +## Axes + +Use `axes` to evaluate one household across a grid of input values. Pass either +the lower-level nested shape or a flat list of axis dictionaries; missing +`period` values default to `year`. + +```python +result = pe.us.calculate_household( + people=[ + { + "age": 35, + "employment_income": 60_000, + "is_tax_unit_head": True, + "charitable_cash_donations": 0, + } + ], + tax_unit={"filing_status": "SINGLE"}, + household={"state_code": "CA"}, + year=2026, + axes=[ + { + "name": "charitable_cash_donations", + "min": 0, + "max": 10_000, + "count": 3, + } + ], + extra_variables=["charitable_cash_donations"], +) + +result.person[0].charitable_cash_donations # [0, 5000, 10000] +result.tax_unit.income_tax # one value per axis point +``` + +When axes are present, result values are lists ordered by the axis grid instead +of scalars. For person results, each person still has their own result object; +each variable on that person is its own axis series. + ## Accessing the result ```python diff --git a/src/policyengine/tax_benefit_models/common/__init__.py b/src/policyengine/tax_benefit_models/common/__init__.py index 654f350d..3c482ebd 100644 --- a/src/policyengine/tax_benefit_models/common/__init__.py +++ b/src/policyengine/tax_benefit_models/common/__init__.py @@ -5,6 +5,8 @@ their public ``calculate_household`` / ``analyze_reform`` entry points. """ +from .axes import normalize_axes as normalize_axes +from .axes import values_for_entity as values_for_entity from .extra_variables import dispatch_extra_variables as dispatch_extra_variables from .model_version import ( MicrosimulationModelVersion as MicrosimulationModelVersion, diff --git a/src/policyengine/tax_benefit_models/common/axes.py b/src/policyengine/tax_benefit_models/common/axes.py new file mode 100644 index 00000000..94a5facf --- /dev/null +++ b/src/policyengine/tax_benefit_models/common/axes.py @@ -0,0 +1,81 @@ +"""Axes helpers for household calculators.""" + +from __future__ import annotations + +from collections.abc import Mapping +from difflib import get_close_matches +from typing import Any, Optional + +from policyengine.core.tax_benefit_model_version import TaxBenefitModelVersion + + +def normalize_axes( + *, + axes: Optional[list[Any]], + year: int, + model_version: TaxBenefitModelVersion, +) -> Optional[list[list[dict[str, Any]]]]: + """Validate and periodise household-calculator axes. + + The country packages expect the lower-level OpenFisca/PolicyEngine Core + shape: ``[[{"name": ..., "min": ..., "max": ..., "count": ...}]]``. + For convenience, callers may also pass a flat list of axis dictionaries. + Missing ``period`` values default to the household calculator's ``year``. + """ + if axes is None: + return None + if not isinstance(axes, list) or not axes: + raise ValueError("axes must be a non-empty list of axis dictionaries.") + + axis_groups = axes if isinstance(axes[0], list) else [axes] + normalized: list[list[dict[str, Any]]] = [] + variables_by_name = model_version.variables_by_name + + for group in axis_groups: + if not isinstance(group, list) or not group: + raise ValueError("each axes group must be a non-empty list.") + + normalized_group: list[dict[str, Any]] = [] + for axis in group: + if not isinstance(axis, Mapping): + raise ValueError("each axis must be a dictionary.") + + axis_dict = dict(axis) + name = axis_dict.get("name") + if not isinstance(name, str): + raise ValueError("each axis must include a string 'name'.") + if name not in variables_by_name: + suggestions = get_close_matches( + name, list(variables_by_name), n=1, cutoff=0.7 + ) + suggestion = ( + f" (did you mean '{suggestions[0]}'?)" if suggestions else "" + ) + raise ValueError( + f"axis variable '{name}' is not defined on " + f"{model_version.model.id} {model_version.version}{suggestion}" + ) + + for required_key in ("min", "max", "count"): + if required_key not in axis_dict: + raise ValueError(f"axis '{name}' must include '{required_key}'.") + + axis_dict.setdefault("period", year) + normalized_group.append(axis_dict) + + normalized.append(normalized_group) + + return normalized + + +def values_for_entity( + values: list[Any], + *, + entity_index: int, + entity_count: int, + axes_active: bool, +): + """Return scalar or axis-series values for one entity member.""" + if not axes_active: + return values[entity_index] + return values[entity_index::entity_count] diff --git a/src/policyengine/tax_benefit_models/uk/household.py b/src/policyengine/tax_benefit_models/uk/household.py index 5dbd71bb..4ca38200 100644 --- a/src/policyengine/tax_benefit_models/uk/household.py +++ b/src/policyengine/tax_benefit_models/uk/household.py @@ -28,6 +28,8 @@ HouseholdResult, compile_reform, dispatch_extra_variables, + normalize_axes, + values_for_entity, ) from policyengine.utils.household_validation import validate_household_input @@ -62,6 +64,7 @@ def _build_situation( benunit: Mapping[str, Any], household: Mapping[str, Any], year: int, + axes: Optional[list[Any]] = None, ) -> dict[str, Any]: year_str = str(year) @@ -74,15 +77,18 @@ def _periodise(spec: Mapping[str, Any]) -> dict[str, dict[str, Any]]: def _group(spec: Mapping[str, Any]) -> dict[str, Any]: return {"members": list(person_ids), **_periodise(spec)} - return { + situation = { "people": persons, "benunits": {"benunit_0": _group(benunit)}, "households": {"household_0": _group(household)}, } + if axes is not None: + situation["axes"] = axes + return situation _ALLOWED_KWARGS = frozenset( - {"people", "benunit", "household", "year", "reform", "extra_variables"} + {"people", "benunit", "household", "year", "reform", "extra_variables", "axes"} ) @@ -99,7 +105,7 @@ def _raise_unexpected_kwargs(unexpected: Mapping[str, Any]) -> None: ) lines.append(f" - '{name}'{hint}") lines.append( - "Valid kwargs: people, benunit, household, year, reform, extra_variables." + "Valid kwargs: people, benunit, household, year, reform, extra_variables, axes." ) raise TypeError("\n".join(lines)) @@ -112,6 +118,7 @@ def calculate_household( year: int = 2026, reform: Optional[Mapping[str, Any]] = None, extra_variables: Optional[list[str]] = None, + axes: Optional[list[Any]] = None, **unexpected: Any, ) -> HouseholdResult: """Compute tax and benefit variables for a single UK household. @@ -126,6 +133,11 @@ def calculate_household( close-match suggestion. extra_variables: Flat list of extra UK variables to compute; the library dispatches each to its entity. + axes: Optional household-calculator axes. Pass either the lower-level + ``[[{"name": ..., "min": ..., "max": ..., "count": ...}]]`` + shape or a flat list of axis dictionaries. Missing ``period`` + values default to ``year``. When axes are present, result values + are lists ordered by the axis grid instead of scalars. Returns: :class:`HouseholdResult` with dot-accessible entity results. @@ -160,6 +172,8 @@ def calculate_household( ) output_columns = _default_output_columns(extra_by_entity) reform_dict = compile_reform(reform, year=year, model_version=uk_latest) + normalized_axes = normalize_axes(axes=axes, year=year, model_version=uk_latest) + axes_active = normalized_axes is not None simulation = Simulation( situation=_build_situation( @@ -167,6 +181,7 @@ def calculate_household( benunit=benunit_dict, household=household_dict, year=year, + axes=normalized_axes, ), reform=reform_dict, ) @@ -180,12 +195,27 @@ def calculate_household( if entity == "person": result["person"] = [ EntityResult( - {variable: _safe_convert(raw[variable][i]) for variable in columns} + { + variable: values_for_entity( + [_safe_convert(value) for value in raw[variable]], + entity_index=i, + entity_count=len(people), + axes_active=axes_active, + ) + for variable in columns + } ) for i in range(len(people)) ] else: result[entity] = EntityResult( - {variable: _safe_convert(raw[variable][0]) for variable in columns} + { + variable: ( + [_safe_convert(value) for value in raw[variable]] + if axes_active + else _safe_convert(raw[variable][0]) + ) + for variable in columns + } ) return result diff --git a/src/policyengine/tax_benefit_models/us/household.py b/src/policyengine/tax_benefit_models/us/household.py index 5258043a..7a933a20 100644 --- a/src/policyengine/tax_benefit_models/us/household.py +++ b/src/policyengine/tax_benefit_models/us/household.py @@ -45,6 +45,8 @@ HouseholdResult, compile_reform, dispatch_extra_variables, + normalize_axes, + values_for_entity, ) from policyengine.utils.household_validation import validate_household_input @@ -65,7 +67,7 @@ def _raise_unexpected_kwargs(unexpected: Mapping[str, Any]) -> None: lines.append(f" - '{name}'{hint}") lines.append( "Valid kwargs: people, marital_unit, family, spm_unit, tax_unit, " - "household, year, reform, extra_variables." + "household, year, reform, extra_variables, axes." ) raise TypeError("\n".join(lines)) @@ -101,6 +103,7 @@ def _build_situation( tax_unit: Mapping[str, Any], household: Mapping[str, Any], year: int, + axes: Optional[list[Any]] = None, ) -> dict[str, Any]: year_str = str(year) @@ -113,7 +116,7 @@ def _periodise(spec: Mapping[str, Any]) -> dict[str, dict[str, Any]]: def _group(spec: Mapping[str, Any]) -> dict[str, Any]: return {"members": list(person_ids), **_periodise(spec)} - return { + situation = { "people": persons, "marital_units": {"marital_unit_0": _group(marital_unit)}, "families": {"family_0": _group(family)}, @@ -121,6 +124,9 @@ def _group(spec: Mapping[str, Any]) -> dict[str, Any]: "tax_units": {"tax_unit_0": _group(tax_unit)}, "households": {"household_0": _group(household)}, } + if axes is not None: + situation["axes"] = axes + return situation _ALLOWED_KWARGS = frozenset( @@ -134,6 +140,7 @@ def _group(spec: Mapping[str, Any]) -> dict[str, Any]: "year", "reform", "extra_variables", + "axes", } ) @@ -149,6 +156,7 @@ def calculate_household( year: int = 2026, reform: Optional[Mapping[str, Any]] = None, extra_variables: Optional[list[str]] = None, + axes: Optional[list[Any]] = None, **unexpected: Any, ) -> HouseholdResult: """Compute tax and benefit variables for a single US household. @@ -170,6 +178,11 @@ def calculate_household( the default output columns; the library dispatches each name to its entity. Unknown names raise ``ValueError`` with a close-match suggestion. + axes: Optional household-calculator axes. Pass either the lower-level + ``[[{"name": ..., "min": ..., "max": ..., "count": ...}]]`` + shape or a flat list of axis dictionaries. Missing ``period`` + values default to ``year``. When axes are present, result values + are lists ordered by the axis grid instead of scalars. Returns: :class:`HouseholdResult` with dot-accessible per-entity @@ -211,6 +224,8 @@ def calculate_household( ) output_columns = _default_output_columns(extra_by_entity) reform_dict = compile_reform(reform, year=year, model_version=us_latest) + normalized_axes = normalize_axes(axes=axes, year=year, model_version=us_latest) + axes_active = normalized_axes is not None simulation = Simulation( situation=_build_situation( @@ -221,6 +236,7 @@ def calculate_household( tax_unit=entities["tax_unit"], household=entities["household"], year=year, + axes=normalized_axes, ), reform=reform_dict, ) @@ -234,12 +250,27 @@ def calculate_household( if entity == "person": result["person"] = [ EntityResult( - {variable: _safe_convert(raw[variable][i]) for variable in columns} + { + variable: values_for_entity( + [_safe_convert(value) for value in raw[variable]], + entity_index=i, + entity_count=len(people), + axes_active=axes_active, + ) + for variable in columns + } ) for i in range(len(people)) ] else: result[entity] = EntityResult( - {variable: _safe_convert(raw[variable][0]) for variable in columns} + { + variable: ( + [_safe_convert(value) for value in raw[variable]] + if axes_active + else _safe_convert(raw[variable][0]) + ) + for variable in columns + } ) return result diff --git a/tests/test_household_impact.py b/tests/test_household_impact.py index d99d144b..668b5fe2 100644 --- a/tests/test_household_impact.py +++ b/tests/test_household_impact.py @@ -65,6 +65,30 @@ def test__reform_changes_child_benefit__then_dict_compiles_and_applies(self): assert isinstance(reformed.benunit.child_benefit, float) assert isinstance(baseline.benunit.child_benefit, float) + def test__axes__then_result_values_are_axis_series(self): + result = pe.uk.calculate_household( + people=[ + { + "age": 35, + "employment_income": 50000, + "gift_aid": 0, + } + ], + year=2026, + axes=[ + { + "name": "gift_aid", + "min": 0, + "max": 10000, + "count": 3, + } + ], + extra_variables=["gift_aid"], + ) + assert result.person[0].gift_aid == [0, 5000, 10000] + assert len(result.person[0].income_tax) == 3 + assert len(result.household.household_tax) == 3 + class TestUSCalculateHousehold: def test__single_adult__then_returns_result_with_net_income(self): @@ -111,6 +135,57 @@ def test__extra_variables_flat_list__then_values_appear_on_entity(self): assert "adjusted_gross_income" in result.tax_unit assert result.tax_unit.adjusted_gross_income > 0 + def test__axes__then_result_values_are_axis_series(self): + result = pe.us.calculate_household( + people=[ + { + "age": 35, + "employment_income": 60000, + "is_tax_unit_head": True, + "charitable_cash_donations": 0, + } + ], + tax_unit={"filing_status": "SINGLE"}, + household={"state_code": "CA"}, + year=2026, + axes=[ + { + "name": "charitable_cash_donations", + "min": 0, + "max": 10000, + "count": 3, + } + ], + extra_variables=["charitable_cash_donations"], + ) + assert result.person[0].charitable_cash_donations == [0, 5000, 10000] + assert result.tax_unit.income_tax == [5020, 4900, 4900] + assert len(result.household.household_net_income) == 3 + + def test__nested_axes_shape__then_supported(self): + result = pe.us.calculate_household( + people=[ + { + "age": 35, + "employment_income": 0, + "is_tax_unit_head": True, + } + ], + tax_unit={"filing_status": "SINGLE"}, + year=2026, + axes=[ + [ + { + "name": "employment_income", + "min": 0, + "max": 10000, + "count": 2, + } + ] + ], + ) + assert result.person[0].employment_income == [0, 10000] + def test__reform_compiles_effective_date_form(self): result = pe.us.calculate_household( people=[{"age": 30, "is_tax_unit_head": True}], @@ -165,6 +240,21 @@ def test__unknown_reform_path__then_raises_with_close_match(self): reform={"gov.irs.not_a_real_parameter": 0}, ) + def test__unknown_axis_variable__then_raises_with_suggestion(self): + with pytest.raises(ValueError, match="axis variable"): + pe.us.calculate_household( + people=[{"age": 35, "is_tax_unit_head": True}], + year=2026, + axes=[ + { + "name": "employment_incme", + "min": 0, + "max": 10000, + "count": 3, + } + ], + ) + def test__us_kwarg_on_uk__then_raises_with_uk_hint(self): with pytest.raises(TypeError, match="US-only"): pe.uk.calculate_household( From 624b5e4b67d144d2bf4c6f75e8099529997f4b1a Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sun, 26 Apr 2026 23:20:45 -0400 Subject: [PATCH 2/3] Add changelog fragment for household axes --- changelog.d/codex-household-axes.added.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/codex-household-axes.added.md diff --git a/changelog.d/codex-household-axes.added.md b/changelog.d/codex-household-axes.added.md new file mode 100644 index 00000000..e2d15b5d --- /dev/null +++ b/changelog.d/codex-household-axes.added.md @@ -0,0 +1 @@ +Expose axes in the US and UK household calculators. From ad2e3e05b0b57cf3c3881ed5781ddae1366dba9e Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sun, 26 Apr 2026 23:27:23 -0400 Subject: [PATCH 3/3] Require changelog fragments from PR diff --- .github/check-changelog.sh | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/.github/check-changelog.sh b/.github/check-changelog.sh index 7e9e5dd3..855b8437 100755 --- a/.github/check-changelog.sh +++ b/.github/check-changelog.sh @@ -1,9 +1,15 @@ #!/usr/bin/env bash set -euo pipefail -FRAGMENTS=$(find changelog.d -type f ! -name '.gitkeep' | wc -l) -if [ "$FRAGMENTS" -eq 0 ]; then - echo "::error::No changelog fragment found in changelog.d/" +if [ "${GITHUB_EVENT_NAME:-}" = "pull_request" ] && [ -n "${GITHUB_BASE_REF:-}" ]; then + git fetch --no-tags --depth=1 origin "$GITHUB_BASE_REF:refs/remotes/origin/$GITHUB_BASE_REF" + FRAGMENTS=$(git diff --name-only --diff-filter=ACMRT "origin/$GITHUB_BASE_REF" HEAD -- changelog.d/ | grep -v '^changelog.d/.gitkeep$' || true) +else + FRAGMENTS=$(find changelog.d -type f ! -name '.gitkeep' -print) +fi + +if [ -z "$FRAGMENTS" ]; then + echo "::error::No changelog fragment found in this pull request." echo "Add one with: echo 'Description.' > changelog.d/\$(git branch --show-current)..md" echo "Types: added, changed, fixed, removed, breaking" exit 1