Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11", "3.12", "3.13"]
python-version: ["3.11", "3.12", "3.13", "3.14"]

steps:
- uses: actions/checkout@v4
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Information Analysis",
]
Expand Down
7 changes: 7 additions & 0 deletions src/microplex/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
- Multi-resolution dataset generation
"""

from microplex.core.checkpoints import (
load_entity_table_checkpoint,
save_entity_table_checkpoint,
)
from microplex.core.entities import (
BenefitUnit,
Entity,
Expand Down Expand Up @@ -115,4 +119,7 @@
"for_browser",
"for_api",
"for_research",
# Checkpoints
"save_entity_table_checkpoint",
"load_entity_table_checkpoint",
]
136 changes: 136 additions & 0 deletions src/microplex/core/checkpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""Generic entity-table pipeline checkpoints.

``save_entity_table_checkpoint`` and ``load_entity_table_checkpoint``
persist a dict of named entity tables to disk as parquet files plus a
``metadata.json`` index keyed by a pipeline stage. Country-specific
microplex packages (e.g. ``microplex-us``) wrap these with typed entity
bundles so a downstream rerun can resume from a saved state without
redoing expensive upstream work (synthesis, donor imputation,
tax-benefit microsim).

Usage
-----

.. code-block:: python

from microplex.core.checkpoints import save_entity_table_checkpoint

save_entity_table_checkpoint(
{"households": households_df, "persons": persons_df},
Path("artifacts/run/checkpoint"),
stage="post_imputation",
extra_metadata={"config_fingerprint": config_hash},
)

At load time, ``extra_metadata`` is available for cache-invalidation
decisions (stale checkpoint if the fingerprint differs from the
current config).
"""

from __future__ import annotations

import json
import shutil
from pathlib import Path
from typing import Any, Mapping

import pandas as pd


def save_entity_table_checkpoint(
tables: Mapping[str, pd.DataFrame | None],
path: str | Path,
*,
stage: str,
extra_metadata: Mapping[str, Any] | None = None,
) -> Path:
"""Persist a named dict of entity tables to ``path`` as parquet + metadata.

Args:
tables: Mapping from entity-table name (``"households"``,
``"persons"``, ...) to the DataFrame for that table, or
``None`` if the country / pipeline stage doesn't populate
that entity.
path: Target directory. Any existing directory at this path is
removed and replaced.
stage: Non-empty identifier describing the pipeline stage the
checkpoint was taken at (``"post_imputation"``,
``"post_microsim"``, ...). Stored in ``metadata.json`` and
validated by ``expected_stage`` on load.
extra_metadata: Optional mapping attached to the checkpoint
under the ``"extra"`` key — use for config fingerprints,
source-data versions, etc. that a caller wants to check for
cache invalidation.

Returns:
The directory the checkpoint was written to.
"""
if not stage:
raise ValueError("stage must be a non-empty string")

checkpoint_dir = Path(path)
if checkpoint_dir.exists():
shutil.rmtree(checkpoint_dir)
checkpoint_dir.mkdir(parents=True)

table_metadata: dict[str, dict[str, Any] | None] = {}
for table_name, frame in tables.items():
if frame is None:
table_metadata[table_name] = None
continue
frame.to_parquet(checkpoint_dir / f"{table_name}.parquet", index=False)
table_metadata[table_name] = {
"rows": int(len(frame)),
"columns": list(frame.columns),
}

metadata: dict[str, Any] = {
"format_version": 1,
"stage": stage,
"tables": table_metadata,
}
if extra_metadata is not None:
metadata["extra"] = dict(extra_metadata)

(checkpoint_dir / "metadata.json").write_text(
json.dumps(metadata, indent=2, default=str)
)
return checkpoint_dir


def load_entity_table_checkpoint(
path: str | Path,
*,
expected_stage: str | None = None,
) -> tuple[dict[str, pd.DataFrame | None], dict[str, Any]]:
"""Load a dict of entity tables previously saved by ``save_entity_table_checkpoint``.

Returns ``(tables, metadata)``. If ``expected_stage`` is set and
the saved stage doesn't match, a ``ValueError`` is raised —
protects against resuming from the wrong pipeline stage.
"""
checkpoint_dir = Path(path)
metadata_path = checkpoint_dir / "metadata.json"
if not metadata_path.exists():
raise FileNotFoundError(
f"Entity-table checkpoint not found at {checkpoint_dir}"
)
metadata = json.loads(metadata_path.read_text())

saved_stage = metadata.get("stage")
if expected_stage is not None and saved_stage != expected_stage:
raise ValueError(
f"Checkpoint at {checkpoint_dir} has stage {saved_stage!r}, "
f"expected {expected_stage!r}"
)

tables: dict[str, pd.DataFrame | None] = {}
table_metadata = metadata.get("tables", {})
for table_name, table_info in table_metadata.items():
if table_info is None:
tables[table_name] = None
continue
tables[table_name] = pd.read_parquet(
checkpoint_dir / f"{table_name}.parquet"
)
return tables, metadata
52 changes: 33 additions & 19 deletions src/microplex/geography.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from collections.abc import Callable
from dataclasses import dataclass, field
from importlib import import_module
from typing import Any, Protocol, runtime_checkable

import numpy as np
Expand Down Expand Up @@ -380,28 +381,41 @@ def _raise_missing_us_geography() -> ModuleNotFoundError:
)


try:
from microplex_us.geography import ( # noqa: F401
BlockGeography,
derive_geographies,
load_block_probabilities,
)
except ModuleNotFoundError as exc:
if exc.name != "microplex_us":
def _load_us_geography_symbol(name: str) -> Any:
    """Resolve ``name`` from ``microplex_us.geography``, importing lazily."""
    try:
        us_geography = import_module("microplex_us.geography")
    except ModuleNotFoundError as exc:
        # Only a missing microplex_us package gets translated into the
        # friendly install hint; any other import failure propagates.
        if exc.name != "microplex_us":
            raise
        raise _raise_missing_us_geography() from exc
    return getattr(us_geography, name)


def load_block_probabilities(*args: Any, **kwargs: Any) -> pd.DataFrame:
    """Delegate to ``microplex_us.geography.load_block_probabilities``."""
    impl = _load_us_geography_symbol("load_block_probabilities")
    return impl(*args, **kwargs)


def derive_geographies(*args: Any, **kwargs: Any) -> pd.DataFrame:
    """Delegate to ``microplex_us.geography.derive_geographies``."""
    impl = _load_us_geography_symbol("derive_geographies")
    return impl(*args, **kwargs)


class _USBlockGeographyProxy(type):
    """Metaclass forwarding class-level access to the real US adapter.

    Attribute lookup, ``isinstance``, and ``issubclass`` on the
    compatibility ``BlockGeography`` all resolve the genuine
    ``microplex_us.geography.BlockGeography`` lazily on each call.
    """

    def __getattr__(cls, name: str) -> Any:
        real_cls = _load_us_geography_symbol("BlockGeography")
        return getattr(real_cls, name)

    def __instancecheck__(cls, instance: Any) -> bool:
        real_cls = _load_us_geography_symbol("BlockGeography")
        return isinstance(instance, real_cls)

    def __subclasscheck__(cls, subclass: type) -> bool:
        real_cls = _load_us_geography_symbol("BlockGeography")
        return issubclass(subclass, real_cls)

def derive_geographies(*args: Any, **kwargs: Any) -> pd.DataFrame:
raise _raise_missing_us_geography()

class BlockGeography(metaclass=_USBlockGeographyProxy):
    """Compatibility proxy for the moved US block geography adapter.

    Construction and ``from_data`` forward to
    ``microplex_us.geography.BlockGeography`` when that package is
    installed; otherwise the missing-package error is raised.
    """

    @classmethod
    def from_data(cls, data: pd.DataFrame) -> Any:
        real_cls = _load_us_geography_symbol("BlockGeography")
        return real_cls.from_data(data)

    def __new__(cls, *args: Any, **kwargs: Any) -> Any:
        real_cls = _load_us_geography_symbol("BlockGeography")
        return real_cls(*args, **kwargs)
Loading
Loading