Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/3487.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Resolve immutable PolicyEngine v4 runtime bundles for simulations.
219 changes: 219 additions & 0 deletions policyengine_api/libs/runtime_bundle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
"""Resolve reproducibility metadata for simulation API runs.

The API has two dataset identities to keep straight:

* the worker URI, usually ``gs://...``, which Modal can read efficiently;
* the canonical artifact URI, usually ``hf://...@version``, which belongs in
provenance and cache keys.

This module centralizes that mapping so economy runs do not silently lose the
data version before they are cached or sent to the worker.
"""

from __future__ import annotations

import hashlib
import json
from importlib.metadata import PackageNotFoundError, version
from typing import Any

from pydantic import BaseModel

from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS

try:
from policyengine.provenance.manifest import (
CountryReleaseManifest,
get_release_manifest,
resolve_dataset_reference,
resolve_region_dataset_path,
)
except Exception: # pragma: no cover - exercised only if pe.py import is broken.
CountryReleaseManifest = Any # type: ignore[misc,assignment]
get_release_manifest = None
resolve_dataset_reference = None
resolve_region_dataset_path = None


# Maps HuggingFace repo ids to the GCS buckets that mirror their artifacts.
# Used to translate canonical ``hf://`` dataset URIs into ``gs://`` URIs the
# Modal worker can read directly (see ``_hf_to_worker_uri``).
_HF_TO_GS_REPOS = {
    "policyengine/policyengine-us-data": "gs://policyengine-us-data",
    "policyengine/policyengine-uk-data-private": "gs://policyengine-uk-data-private",
}


class RuntimeBundle(BaseModel):
    """Resolved model/data identity for one simulation API run.

    Collects everything needed to reproduce a run: the country model and
    data package versions, the canonical dataset URI for provenance, and
    the worker-facing dataset URI the Modal worker reads.
    """

    country_id: str
    # Release-bundle identifier from the country manifest, if one was found.
    bundle_id: str | None = None
    # Version of the installed ``policyengine`` package (None if absent).
    policyengine_version: str | None = None
    # ``policyengine`` version recorded in the release manifest.
    bundle_policyengine_version: str | None = None
    model_package: str | None = None
    # Effective model version: caller pin or COUNTRY_PACKAGE_VERSIONS entry.
    model_version: str | None = None
    # Model version the release manifest certifies (may differ from the pin).
    certified_model_version: str | None = None
    data_package: str | None = None
    data_version: str | None = None
    # Dataset name or keyword as requested/resolved (e.g. "default").
    dataset: str
    # Canonical artifact URI (usually ``hf://...@version``) for provenance.
    canonical_dataset_uri: str | None = None
    # URI the Modal worker reads directly (usually ``gs://...``).
    worker_dataset_uri: str
    dataset_sha256: str | None = None
    compatibility_basis: str | None = None
    certified_by: str | None = None
    # "managed" when resolved through the release machinery, else "unmanaged".
    provenance_status: str

    @property
    def fingerprint(self) -> str:
        """Stable content hash of this bundle, as ``sha256:<hex>``."""
        # Canonical JSON (sorted keys, compact separators) so equal bundles
        # always hash identically regardless of field insertion order.
        payload = self.model_dump(mode="json")
        encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode(
            "utf-8"
        )
        return f"sha256:{hashlib.sha256(encoded).hexdigest()}"

    def as_payload(self) -> dict[str, Any]:
        """Return the bundle as a JSON-safe dict with ``fingerprint`` added."""
        payload = self.model_dump(mode="json")
        payload["fingerprint"] = self.fingerprint
        return payload


def _package_version(name: str) -> str | None:
try:
return version(name)
except PackageNotFoundError:
return None


def _strip_hf_revision(uri: str) -> tuple[str, str | None]:
if "@" not in uri:
return uri, None
base, revision = uri.rsplit("@", 1)
return base, revision


def _hf_to_worker_uri(uri: str) -> str:
    """Translate a canonical ``hf://`` dataset URI into its ``gs://`` mirror.

    URIs that are not ``hf://``, are not ``org/repo/path`` shaped, or whose
    repo has no configured GCS mirror are returned unchanged.
    """
    if not uri.startswith("hf://"):
        return uri

    stripped, _revision = _strip_hf_revision(uri[len("hf://"):])
    pieces = stripped.split("/", 2)
    if len(pieces) != 3:
        # Malformed: leave for the caller to deal with verbatim.
        return uri
    org, repo, artifact_path = pieces
    bucket = _HF_TO_GS_REPOS.get(f"{org}/{repo}")
    return uri if bucket is None else f"{bucket}/{artifact_path}"


def _canonical_region_dataset_uri(
    country_id: str,
    region: str,
) -> str | None:
    """Best-effort canonical dataset path for a default country/region run.

    Returns None whenever the managed resolver is unavailable, the region
    shape is not recognized, or resolution raises.
    """
    if resolve_region_dataset_path is None:
        return None

    try:
        if region == country_id:
            # National-level run for the country itself.
            return resolve_region_dataset_path(country_id, "national")

        if country_id == "us":
            kind, sep, code = region.partition("/")
            if sep and kind == "state":
                return resolve_region_dataset_path(
                    "us", "state", state_code=code.upper()
                )
            if sep and kind == "congressional_district":
                return resolve_region_dataset_path(
                    "us",
                    "congressional_district",
                    district_code=code.upper(),
                )
            if sep and kind == "place":
                # A place region maps onto its parent state's dataset.
                parent_state = code.split("-", 1)[0].upper()
                return resolve_region_dataset_path(
                    "us", "state", state_code=parent_state
                )
    except Exception:
        # Resolution is strictly best-effort here.
        return None

    return None


def _manifest_for(country_id: str) -> CountryReleaseManifest | None:
    """Fetch the release manifest for *country_id*, or None on any failure."""
    if get_release_manifest is None:
        # The pe.py import fallback left us without manifest support.
        return None
    try:
        manifest = get_release_manifest(country_id)
    except Exception:
        # Best-effort: a missing or broken manifest must not fail the run.
        return None
    return manifest


def resolve_runtime_bundle(
    *,
    country_id: str,
    region: str,
    dataset: str,
    requested_model_version: str | None = None,
) -> RuntimeBundle:
    """Resolve the model/data bundle for an API simulation request.

    Args:
        country_id: Country id (e.g. "us", "uk").
        region: Region identifier (e.g. "us", "state/CA").
        dataset: Explicit dataset URI, the keyword "default", or a dataset
            name for ``resolve_dataset_reference`` to look up.
        requested_model_version: Optional pin for the country model package
            version; falls back to ``COUNTRY_PACKAGE_VERSIONS``.

    Returns:
        A ``RuntimeBundle`` carrying the canonical and worker dataset URIs
        plus whatever manifest metadata could be resolved.
    """

    manifest = _manifest_for(country_id)
    canonical_dataset_uri: str | None = None
    dataset_name = dataset
    provenance_status = "unmanaged"

    # Explicit URIs (gs://, hf://, ...) are taken as-is and stay "unmanaged".
    if "://" in dataset:
        canonical_dataset_uri = dataset
    elif dataset == "default":
        canonical_dataset_uri = _canonical_region_dataset_uri(country_id, region)
        # Only national runs get renamed to the manifest's default dataset;
        # sub-national regions keep the literal "default" label.
        dataset_name = (
            manifest.default_dataset
            if manifest is not None and region == country_id
            else dataset
        )
        # NOTE(review): status becomes "managed" even when resolution above
        # returned None — confirm this is intended for the fallback path.
        provenance_status = "managed"
    elif resolve_dataset_reference is not None:
        try:
            canonical_dataset_uri = resolve_dataset_reference(country_id, dataset)
            provenance_status = "managed"
        except Exception:
            # Unknown dataset name: fall through and send the raw value to
            # the worker as an unmanaged reference.
            canonical_dataset_uri = None

    if canonical_dataset_uri is None:
        worker_dataset_uri = dataset
        data_version = None
    else:
        # Workers read the gs:// mirror; the hf://...@version revision
        # doubles as the data version for provenance.
        worker_dataset_uri = _hf_to_worker_uri(canonical_dataset_uri)
        _base_uri, data_version = _strip_hf_revision(canonical_dataset_uri)

    certified_artifact = (
        manifest.certified_data_artifact if manifest is not None else None
    )
    certification = manifest.certification if manifest is not None else None
    # A caller-pinned model version wins over the API's package table.
    model_version = requested_model_version or COUNTRY_PACKAGE_VERSIONS.get(country_id)

    return RuntimeBundle(
        country_id=country_id,
        bundle_id=manifest.bundle_id if manifest is not None else None,
        policyengine_version=_package_version("policyengine"),
        bundle_policyengine_version=(
            manifest.policyengine_version if manifest is not None else None
        ),
        model_package=(manifest.model_package.name if manifest is not None else None),
        model_version=model_version,
        certified_model_version=(
            manifest.model_package.version if manifest is not None else None
        ),
        data_package=manifest.data_package.name if manifest is not None else None,
        # Prefer the URI revision; fall back to the manifest's data version.
        data_version=(
            data_version
            or (manifest.data_package.version if manifest is not None else None)
        ),
        dataset=dataset_name,
        canonical_dataset_uri=canonical_dataset_uri,
        worker_dataset_uri=worker_dataset_uri,
        dataset_sha256=certified_artifact.sha256 if certified_artifact else None,
        compatibility_basis=(
            certification.compatibility_basis if certification is not None else None
        ),
        certified_by=certification.certified_by if certification is not None else None,
        provenance_status=provenance_status,
    )
2 changes: 0 additions & 2 deletions policyengine_api/libs/simulation_api_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,6 @@ def run(self, payload: dict) -> ModalSimulationExecution:
modal_payload = dict(payload)
if "model_version" in modal_payload:
modal_payload["version"] = modal_payload.pop("model_version")
# Remove data_version as Modal doesn't use it
modal_payload.pop("data_version", None)

response = self.client.post(
f"{self.base_url}/simulate/economy/comparison",
Expand Down
127 changes: 127 additions & 0 deletions policyengine_api/libs/simulation_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""Api-local types for the Modal simulation-worker request payload.

Historically these lived on the pre-v4 ``policyengine`` package and
were imported via ``from policyengine.simulation import
SimulationOptions``. The v4 rearchitecture dropped that module
(``policyengine.core.simulation.Simulation`` has a different shape),
but the wire contract between the api and the Modal simulation
worker has not changed: the worker still deserializes a dict with
the fields below, regardless of which pe.py version assembles it.

Owning the type api-side removes the coupling to pe.py's internal
class layout. The worker's own pe.py version can evolve
independently as long as it accepts this JSON shape.

Field semantics match the pre-v4 ``SimulationOptions``:
country Country id ("us", "uk", "canada", "ng", "il").
scope "macro" (society-wide) or "household".
reform Reform policy JSON dict.
baseline Baseline policy JSON dict.
time_period Simulation year as a string.
include_cliffs Whether the worker should compute cliff-impact.
region Region identifier — "us", "state/CA", etc.
data Dataset URI or keyword ("default", "enhanced_cps",
a GCS URI, or a passthrough keyword understood
directly by the worker).
model_version Optional pin for the country model package version.
data_version Optional pin for the country data package version.
"""

from __future__ import annotations

from typing import Any, Literal, Optional

from pydantic import BaseModel

from policyengine_api.libs.runtime_bundle import resolve_runtime_bundle


class SimulationOptions(BaseModel):
    """Payload for a Modal simulation worker job.

    Schema is kept deliberately permissive on ``reform`` / ``baseline``
    so the worker receives whatever policy JSON the api policy service
    produced without shape enforcement at this boundary.
    """

    # Country id ("us", "uk", "canada", "ng", "il").
    country: str
    # "macro" = society-wide run; "household" = single-household run.
    scope: Literal["macro", "household"]
    # Reform policy JSON, passed through unvalidated.
    reform: dict[str, Any]
    # Baseline policy JSON, passed through unvalidated.
    baseline: dict[str, Any]
    # Simulation year as a string.
    time_period: str
    # Whether the worker should also compute cliff-impact.
    include_cliffs: bool = False
    # Region identifier — "us", "state/CA", etc.
    region: str
    # Dataset URI or keyword ("default", "enhanced_cps", a GCS URI, or a
    # passthrough keyword understood directly by the worker).
    data: str
    # Optional pin for the country model package version.
    model_version: Optional[str] = None
    # Optional pin for the country data package version.
    data_version: Optional[str] = None


# GCS-hosted artifact bucket names per country. The Modal simulation
# worker reads these paths directly; this is an infrastructure
# contract distinct from the HuggingFace-hosted canonical release
# manifest that ``policyengine.py`` resolves (see
# ``policyengine.provenance.manifest.resolve_managed_dataset_reference``).
# State and congressional-district regions each have their own h5
# artifact under ``states/`` and ``districts/`` respectively;
# ``place/`` regions reuse the parent state's h5.
_US_DATA_BUCKET = "gs://policyengine-us-data"
_UK_DATA_BUCKET = "gs://policyengine-uk-data-private"


def _resolve_us_region_dataset(region: str) -> str:
if region == "us":
return f"{_US_DATA_BUCKET}/enhanced_cps_2024.h5"
if region.startswith("state/"):
state_code = region.split("/", 1)[1].upper()
return f"{_US_DATA_BUCKET}/states/{state_code}.h5"
if region.startswith("congressional_district/"):
district_id = region.split("/", 1)[1].upper()
return f"{_US_DATA_BUCKET}/districts/{district_id}.h5"
if region.startswith("place/"):
# A ``place/NJ-57000`` region reuses the parent state's h5.
place_id = region.split("/", 1)[1]
parent_state = place_id.split("-", 1)[0].upper()
return f"{_US_DATA_BUCKET}/states/{parent_state}.h5"
raise ValueError(f"Unknown US region for dataset resolution: {region!r}")


def _resolve_uk_region_dataset(region: str) -> str:
if region == "uk":
return f"{_UK_DATA_BUCKET}/enhanced_frs_2023_24.h5"
raise ValueError(f"Unknown UK region for dataset resolution: {region!r}")


def get_default_dataset(country_id: str, region: str) -> str:
    """Resolve the default dataset URI for a country + region pair.

    Returns a ``gs://...`` URI that the Modal simulation worker reads
    directly. The managed runtime-bundle resolver is consulted first; when
    it yields no canonical URI, the legacy naming conventions (preserved
    from the pre-v4 ``policyengine.utils.data.datasets.get_default_dataset``
    helper) apply:

        us + "us" -> enhanced_cps_2024.h5
        us + "state/CA" -> states/CA.h5
        us + "congressional_district/CA-37" -> districts/CA-37.h5
        us + "place/NJ-57000" -> states/NJ.h5
        uk + "uk" -> enhanced_frs_2023_24.h5

    Args:
        country_id: Country id.
        region: Region identifier.

    Raises:
        ValueError: country has no configured resolver, or region is
            not recognized.
    """
    bundle = resolve_runtime_bundle(
        country_id=country_id,
        region=region,
        dataset="default",
    )
    if bundle.canonical_dataset_uri is not None:
        return bundle.worker_dataset_uri

    # Fall back to the per-country legacy naming conventions.
    fallback_resolvers = {
        "us": _resolve_us_region_dataset,
        "uk": _resolve_uk_region_dataset,
    }
    resolver = fallback_resolvers.get(country_id)
    if resolver is None:
        raise ValueError(f"No default dataset configured for country_id={country_id!r}.")
    return resolver(region)
Loading
Loading