Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions .github/check-changelog.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
#!/usr/bin/env bash
set -euo pipefail

FRAGMENTS=$(find changelog.d -type f ! -name '.gitkeep' | wc -l)
if [ "$FRAGMENTS" -eq 0 ]; then
echo "::error::No changelog fragment found in changelog.d/"
if [ "${GITHUB_EVENT_NAME:-}" = "pull_request" ] && [ -n "${GITHUB_BASE_REF:-}" ]; then
git fetch --no-tags --depth=1 origin "$GITHUB_BASE_REF:refs/remotes/origin/$GITHUB_BASE_REF"
FRAGMENTS=$(git diff --name-only --diff-filter=ACMRT "origin/$GITHUB_BASE_REF" HEAD -- changelog.d/ | grep -v '^changelog.d/.gitkeep$' || true)
else
FRAGMENTS=$(find changelog.d -type f ! -name '.gitkeep' -print)
fi

if [ -z "$FRAGMENTS" ]; then
echo "::error::No changelog fragment found in this pull request."
echo "Add one with: echo 'Description.' > changelog.d/\$(git branch --show-current).<type>.md"
echo "Types: added, changed, fixed, removed, breaking"
exit 1
Expand Down
1 change: 1 addition & 0 deletions changelog.d/codex-household-axes.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Expose axes in the US and UK household calculators.
38 changes: 38 additions & 0 deletions docs/households.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,44 @@ result = pe.us.calculate_household(
)
```

## Axes

Use `axes` to evaluate one household across a grid of input values. Pass either
the lower-level nested shape or a flat list of axis dictionaries; missing
`period` values default to `year`.

```python
result = pe.us.calculate_household(
people=[
{
"age": 35,
"employment_income": 60_000,
"is_tax_unit_head": True,
"charitable_cash_donations": 0,
}
],
tax_unit={"filing_status": "SINGLE"},
household={"state_code": "CA"},
year=2026,
axes=[
{
"name": "charitable_cash_donations",
"min": 0,
"max": 10_000,
"count": 3,
}
],
extra_variables=["charitable_cash_donations"],
)

result.person[0].charitable_cash_donations # [0, 5000, 10000]
result.tax_unit.income_tax # one value per axis point
```

When axes are present, result values are lists ordered by the axis grid instead
of scalars. For person results, each person still has their own result object;
each variable on that person is its own axis series.

## Accessing the result

```python
Expand Down
2 changes: 2 additions & 0 deletions src/policyengine/tax_benefit_models/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
their public ``calculate_household`` / ``analyze_reform`` entry points.
"""

from .axes import normalize_axes as normalize_axes
from .axes import values_for_entity as values_for_entity
from .extra_variables import dispatch_extra_variables as dispatch_extra_variables
from .model_version import (
MicrosimulationModelVersion as MicrosimulationModelVersion,
Expand Down
81 changes: 81 additions & 0 deletions src/policyengine/tax_benefit_models/common/axes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Axes helpers for household calculators."""

from __future__ import annotations

from collections.abc import Mapping
from difflib import get_close_matches
from typing import Any, Optional

from policyengine.core.tax_benefit_model_version import TaxBenefitModelVersion


def normalize_axes(
*,
axes: Optional[list[Any]],
year: int,
model_version: TaxBenefitModelVersion,
) -> Optional[list[list[dict[str, Any]]]]:
"""Validate and periodise household-calculator axes.

The country packages expect the lower-level OpenFisca/PolicyEngine Core
shape: ``[[{"name": ..., "min": ..., "max": ..., "count": ...}]]``.
For convenience, callers may also pass a flat list of axis dictionaries.
Missing ``period`` values default to the household calculator's ``year``.
"""
if axes is None:
return None
if not isinstance(axes, list) or not axes:
raise ValueError("axes must be a non-empty list of axis dictionaries.")

axis_groups = axes if isinstance(axes[0], list) else [axes]
normalized: list[list[dict[str, Any]]] = []
variables_by_name = model_version.variables_by_name

for group in axis_groups:
if not isinstance(group, list) or not group:
raise ValueError("each axes group must be a non-empty list.")

normalized_group: list[dict[str, Any]] = []
for axis in group:
if not isinstance(axis, Mapping):
raise ValueError("each axis must be a dictionary.")

axis_dict = dict(axis)
name = axis_dict.get("name")
if not isinstance(name, str):
raise ValueError("each axis must include a string 'name'.")
if name not in variables_by_name:
suggestions = get_close_matches(
name, list(variables_by_name), n=1, cutoff=0.7
)
suggestion = (
f" (did you mean '{suggestions[0]}'?)" if suggestions else ""
)
raise ValueError(
f"axis variable '{name}' is not defined on "
f"{model_version.model.id} {model_version.version}{suggestion}"
)

for required_key in ("min", "max", "count"):
if required_key not in axis_dict:
raise ValueError(f"axis '{name}' must include '{required_key}'.")

axis_dict.setdefault("period", year)
normalized_group.append(axis_dict)

normalized.append(normalized_group)

return normalized


def values_for_entity(
values: list[Any],
*,
entity_index: int,
entity_count: int,
axes_active: bool,
):
"""Return scalar or axis-series values for one entity member."""
if not axes_active:
return values[entity_index]
return values[entity_index::entity_count]
40 changes: 35 additions & 5 deletions src/policyengine/tax_benefit_models/uk/household.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
HouseholdResult,
compile_reform,
dispatch_extra_variables,
normalize_axes,
values_for_entity,
)
from policyengine.utils.household_validation import validate_household_input

Expand Down Expand Up @@ -62,6 +64,7 @@ def _build_situation(
benunit: Mapping[str, Any],
household: Mapping[str, Any],
year: int,
axes: Optional[list[Any]] = None,
) -> dict[str, Any]:
year_str = str(year)

Expand All @@ -74,15 +77,18 @@ def _periodise(spec: Mapping[str, Any]) -> dict[str, dict[str, Any]]:
def _group(spec: Mapping[str, Any]) -> dict[str, Any]:
return {"members": list(person_ids), **_periodise(spec)}

return {
situation = {
"people": persons,
"benunits": {"benunit_0": _group(benunit)},
"households": {"household_0": _group(household)},
}
if axes is not None:
situation["axes"] = axes
return situation


_ALLOWED_KWARGS = frozenset(
{"people", "benunit", "household", "year", "reform", "extra_variables"}
{"people", "benunit", "household", "year", "reform", "extra_variables", "axes"}
)


Expand All @@ -99,7 +105,7 @@ def _raise_unexpected_kwargs(unexpected: Mapping[str, Any]) -> None:
)
lines.append(f" - '{name}'{hint}")
lines.append(
"Valid kwargs: people, benunit, household, year, reform, extra_variables."
"Valid kwargs: people, benunit, household, year, reform, extra_variables, axes."
)
raise TypeError("\n".join(lines))

Expand All @@ -112,6 +118,7 @@ def calculate_household(
year: int = 2026,
reform: Optional[Mapping[str, Any]] = None,
extra_variables: Optional[list[str]] = None,
axes: Optional[list[Any]] = None,
**unexpected: Any,
) -> HouseholdResult:
"""Compute tax and benefit variables for a single UK household.
Expand All @@ -126,6 +133,11 @@ def calculate_household(
close-match suggestion.
extra_variables: Flat list of extra UK variables to compute;
the library dispatches each to its entity.
axes: Optional household-calculator axes. Pass either the lower-level
``[[{"name": ..., "min": ..., "max": ..., "count": ...}]]``
shape or a flat list of axis dictionaries. Missing ``period``
values default to ``year``. When axes are present, result values
are lists ordered by the axis grid instead of scalars.

Returns:
:class:`HouseholdResult` with dot-accessible entity results.
Expand Down Expand Up @@ -160,13 +172,16 @@ def calculate_household(
)
output_columns = _default_output_columns(extra_by_entity)
reform_dict = compile_reform(reform, year=year, model_version=uk_latest)
normalized_axes = normalize_axes(axes=axes, year=year, model_version=uk_latest)
axes_active = normalized_axes is not None

simulation = Simulation(
situation=_build_situation(
people=people,
benunit=benunit_dict,
household=household_dict,
year=year,
axes=normalized_axes,
),
reform=reform_dict,
)
Expand All @@ -180,12 +195,27 @@ def calculate_household(
if entity == "person":
result["person"] = [
EntityResult(
{variable: _safe_convert(raw[variable][i]) for variable in columns}
{
variable: values_for_entity(
[_safe_convert(value) for value in raw[variable]],
entity_index=i,
entity_count=len(people),
axes_active=axes_active,
)
for variable in columns
}
)
for i in range(len(people))
]
else:
result[entity] = EntityResult(
{variable: _safe_convert(raw[variable][0]) for variable in columns}
{
variable: (
[_safe_convert(value) for value in raw[variable]]
if axes_active
else _safe_convert(raw[variable][0])
)
for variable in columns
}
)
return result
39 changes: 35 additions & 4 deletions src/policyengine/tax_benefit_models/us/household.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
HouseholdResult,
compile_reform,
dispatch_extra_variables,
normalize_axes,
values_for_entity,
)
from policyengine.utils.household_validation import validate_household_input

Expand All @@ -65,7 +67,7 @@ def _raise_unexpected_kwargs(unexpected: Mapping[str, Any]) -> None:
lines.append(f" - '{name}'{hint}")
lines.append(
"Valid kwargs: people, marital_unit, family, spm_unit, tax_unit, "
"household, year, reform, extra_variables."
"household, year, reform, extra_variables, axes."
)
raise TypeError("\n".join(lines))

Expand Down Expand Up @@ -101,6 +103,7 @@ def _build_situation(
tax_unit: Mapping[str, Any],
household: Mapping[str, Any],
year: int,
axes: Optional[list[Any]] = None,
) -> dict[str, Any]:
year_str = str(year)

Expand All @@ -113,14 +116,17 @@ def _periodise(spec: Mapping[str, Any]) -> dict[str, dict[str, Any]]:
def _group(spec: Mapping[str, Any]) -> dict[str, Any]:
return {"members": list(person_ids), **_periodise(spec)}

return {
situation = {
"people": persons,
"marital_units": {"marital_unit_0": _group(marital_unit)},
"families": {"family_0": _group(family)},
"spm_units": {"spm_unit_0": _group(spm_unit)},
"tax_units": {"tax_unit_0": _group(tax_unit)},
"households": {"household_0": _group(household)},
}
if axes is not None:
situation["axes"] = axes
return situation


_ALLOWED_KWARGS = frozenset(
Expand All @@ -134,6 +140,7 @@ def _group(spec: Mapping[str, Any]) -> dict[str, Any]:
"year",
"reform",
"extra_variables",
"axes",
}
)

Expand All @@ -149,6 +156,7 @@ def calculate_household(
year: int = 2026,
reform: Optional[Mapping[str, Any]] = None,
extra_variables: Optional[list[str]] = None,
axes: Optional[list[Any]] = None,
**unexpected: Any,
) -> HouseholdResult:
"""Compute tax and benefit variables for a single US household.
Expand All @@ -170,6 +178,11 @@ def calculate_household(
the default output columns; the library dispatches each
name to its entity. Unknown names raise ``ValueError``
with a close-match suggestion.
axes: Optional household-calculator axes. Pass either the lower-level
``[[{"name": ..., "min": ..., "max": ..., "count": ...}]]``
shape or a flat list of axis dictionaries. Missing ``period``
values default to ``year``. When axes are present, result values
are lists ordered by the axis grid instead of scalars.

Returns:
:class:`HouseholdResult` with dot-accessible per-entity
Expand Down Expand Up @@ -211,6 +224,8 @@ def calculate_household(
)
output_columns = _default_output_columns(extra_by_entity)
reform_dict = compile_reform(reform, year=year, model_version=us_latest)
normalized_axes = normalize_axes(axes=axes, year=year, model_version=us_latest)
axes_active = normalized_axes is not None

simulation = Simulation(
situation=_build_situation(
Expand All @@ -221,6 +236,7 @@ def calculate_household(
tax_unit=entities["tax_unit"],
household=entities["household"],
year=year,
axes=normalized_axes,
),
reform=reform_dict,
)
Expand All @@ -234,12 +250,27 @@ def calculate_household(
if entity == "person":
result["person"] = [
EntityResult(
{variable: _safe_convert(raw[variable][i]) for variable in columns}
{
variable: values_for_entity(
[_safe_convert(value) for value in raw[variable]],
entity_index=i,
entity_count=len(people),
axes_active=axes_active,
)
for variable in columns
}
)
for i in range(len(people))
]
else:
result[entity] = EntityResult(
{variable: _safe_convert(raw[variable][0]) for variable in columns}
{
variable: (
[_safe_convert(value) for value in raw[variable]]
if axes_active
else _safe_convert(raw[variable][0])
)
for variable in columns
}
)
return result
Loading
Loading