diff --git a/src/microplex/core/__init__.py b/src/microplex/core/__init__.py index 992d110..3fc69f3 100644 --- a/src/microplex/core/__init__.py +++ b/src/microplex/core/__init__.py @@ -2,44 +2,68 @@ Core data models for microplex. This module provides the foundational data structures for microdata representation: -- Entity types (Person, TaxUnit, Household, Family, SPMUnit, Record) +- Entity types (Person, TaxUnit, Household, Family, BenefitUnit, SPMUnit, Record) - Variable definitions with legal references - Period arithmetic - Multi-resolution dataset generation """ from microplex.core.entities import ( + BenefitUnit, + Entity, EntityType, + Family, FilingStatus, - RecordType, - Entity, - Person, - TaxUnit, Household, - Family, - SPMUnit, + Person, Record, -) -from microplex.core.variables import ( - DataType, - VariableRole, - LegalReference, - Variable, - VariableRegistry, + RecordType, + SPMUnit, + TaxUnit, ) from microplex.core.periods import ( - PeriodType, Period, + PeriodType, ) from microplex.core.resolution import ( - ResolutionLevel, - ResolutionConfig, HardConcreteGate, + ResolutionConfig, + ResolutionLevel, compress_dataset, - for_browser, for_api, + for_browser, for_research, ) +from microplex.core.source_manifests import ( + SourceColumnManifest, + SourceColumnValueType, + SourceManifest, + SourceObservationManifest, + load_source_manifest, +) +from microplex.core.sources import ( + EntityObservation, + EntityRelationship, + ObservationFrame, + RelationshipCardinality, + Shareability, + SourceAdapter, + SourceArchetype, + SourceDescriptor, + SourceProvider, + SourceQuery, + SourceVariableCapability, + StaticSourceProvider, + TimeStructure, + apply_source_query, +) +from microplex.core.variables import ( + DataType, + LegalReference, + Variable, + VariableRegistry, + VariableRole, +) __all__ = [ # Entities @@ -51,8 +75,24 @@ "TaxUnit", "Household", "Family", + "BenefitUnit", "SPMUnit", "Record", + # Sources + "TimeStructure", + "Shareability", + "SourceArchetype", + "RelationshipCardinality", + "EntityObservation", + "SourceVariableCapability", + "SourceDescriptor", + "EntityRelationship", + "ObservationFrame", + "SourceQuery", + "SourceProvider", + "StaticSourceProvider", + "apply_source_query", + "SourceAdapter", # Variables "DataType", "VariableRole", @@ -66,6 +106,11 @@ "ResolutionLevel", "ResolutionConfig", "HardConcreteGate", + "SourceColumnValueType", + "SourceColumnManifest", + "SourceObservationManifest", + "SourceManifest", + "load_source_manifest", "compress_dataset", "for_browser", "for_api", diff --git a/src/microplex/core/entities.py b/src/microplex/core/entities.py index 7c9fe03..646759f 100644 --- a/src/microplex/core/entities.py +++ b/src/microplex/core/entities.py @@ -3,8 +3,9 @@ Entities represent the hierarchical structure of tax-benefit microdata: - Person: Individual-level attributes - TaxUnit: Tax filing unit (IRS perspective) -- Household: Census household (housing costs, geography) -- Family: SPM family unit (poverty calculation) +- Household: Residential unit (housing costs, geography) +- Family: Family grouping used by some systems +- BenefitUnit: Benefit assessment unit (UK-style family benefit unit) - SPMUnit: Supplemental Poverty Measure unit - Record: Sub-person records (W-2s, K-1s, 1099s, etc.) 
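+
+Example (illustrative; levels and grouping follow ``EntityType`` below):
+    >>> from microplex.core.entities import EntityType
+    >>> EntityType.RECORD.level < EntityType.PERSON.level < EntityType.TAX_UNIT.level
+    True
+    >>> EntityType.BENEFIT_UNIT.is_group
+    True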
""" @@ -18,28 +19,32 @@ class EntityType(Enum): """Types of entities in the microdata hierarchy.""" + RECORD = "record" PERSON = "person" TAX_UNIT = "tax_unit" HOUSEHOLD = "household" FAMILY = "family" + BENEFIT_UNIT = "benefit_unit" SPM_UNIT = "spm_unit" @property def level(self) -> int: """Hierarchy level (0 = lowest/most granular).""" levels = { - EntityType.PERSON: 0, - EntityType.TAX_UNIT: 1, - EntityType.HOUSEHOLD: 1, - EntityType.FAMILY: 1, - EntityType.SPM_UNIT: 1, + EntityType.RECORD: 0, + EntityType.PERSON: 1, + EntityType.TAX_UNIT: 2, + EntityType.HOUSEHOLD: 2, + EntityType.FAMILY: 2, + EntityType.BENEFIT_UNIT: 2, + EntityType.SPM_UNIT: 2, } return levels[self] @property def is_group(self) -> bool: """Whether this entity groups persons.""" - return self != EntityType.PERSON + return self not in (EntityType.RECORD, EntityType.PERSON) class FilingStatus(Enum): @@ -207,6 +212,21 @@ def entity_type(self) -> EntityType: return EntityType.FAMILY +class BenefitUnit(Entity): + """Benefit assessment unit. + + Used by tax-benefit systems that determine eligibility/resources at a + family-benefit-unit grain distinct from households or tax units. + """ + + member_ids: list[str] = Field(default_factory=list) + head_id: str | None = None + + @property + def entity_type(self) -> EntityType: + return EntityType.BENEFIT_UNIT + + class SPMUnit(Entity): """Supplemental Poverty Measure unit.""" @@ -275,4 +295,4 @@ class Record(Entity): @property def entity_type(self) -> EntityType: - return EntityType.PERSON # Records are person-level + return EntityType.RECORD diff --git a/src/microplex/core/source_manifests.py b/src/microplex/core/source_manifests.py new file mode 100644 index 0000000..ef342ef --- /dev/null +++ b/src/microplex/core/source_manifests.py @@ -0,0 +1,117 @@ +"""Typed source-manifest loader for externalized provider specs.""" + +from __future__ import annotations + +import json +from collections.abc import Iterable, Mapping +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path + +from microplex.core.entities import EntityType +from microplex.core.sources import SourceArchetype + + +class SourceColumnValueType(Enum): + """How a raw source column should be coerced during canonical mapping.""" + + NUMERIC = "numeric" + CATEGORICAL = "categorical" + + +@dataclass(frozen=True) +class SourceColumnManifest: + """One raw-to-canonical column mapping.""" + + raw_column: str + canonical_name: str + value_type: SourceColumnValueType = SourceColumnValueType.NUMERIC + + +@dataclass(frozen=True) +class SourceObservationManifest: + """Manifest for one observed entity table.""" + + entity: EntityType + key_column: str + table_name: str | None = None + weight_column: str | None = None + period_column: str | None = None + excluded_columns: tuple[str, ...] = () + aliases: Mapping[str, str] = field(default_factory=dict) + columns: tuple[SourceColumnManifest, ...] 
= () + + def observed_variable_names( + self, + frame_columns: Iterable[str] | None = None, + ) -> tuple[str, ...]: + """Return canonical observed variables for this entity.""" + reserved = {self.key_column} + if self.weight_column is not None: + reserved.add(self.weight_column) + if self.period_column is not None: + reserved.add(self.period_column) + reserved.update(self.excluded_columns) + if self.columns: + return tuple( + column.canonical_name + for column in self.columns + if column.canonical_name not in reserved + ) + if frame_columns is None: + raise ValueError( + "frame_columns must be provided when manifest columns are implicit" + ) + return tuple(column for column in frame_columns if column not in reserved) + + +@dataclass(frozen=True) +class SourceManifest: + """Typed manifest for one source-provider family.""" + + name: str + archetype: SourceArchetype + population: str | None = None + description: str | None = None + observations: tuple[SourceObservationManifest, ...] = () + + def observation_for(self, entity: EntityType) -> SourceObservationManifest: + """Return the manifest entry for one entity.""" + for observation in self.observations: + if observation.entity is entity: + return observation + raise KeyError(f"Manifest '{self.name}' has no entity '{entity.value}'") + + +def load_source_manifest(path: str | Path) -> SourceManifest: + """Load a typed source manifest from JSON.""" + payload = json.loads(Path(path).read_text()) + observations = tuple( + SourceObservationManifest( + entity=EntityType(observation_payload["entity"]), + key_column=observation_payload["key_column"], + table_name=observation_payload.get("table_name"), + weight_column=observation_payload.get("weight_column"), + period_column=observation_payload.get("period_column"), + excluded_columns=tuple(observation_payload.get("excluded_columns", ())), + aliases=dict(observation_payload.get("aliases", {})), + columns=tuple( + SourceColumnManifest( + raw_column=column_payload["raw_column"], + canonical_name=column_payload["canonical_name"], + value_type=SourceColumnValueType( + column_payload.get("value_type", "numeric") + ), + ) + for column_payload in observation_payload.get("columns", ()) + ), + ) + for observation_payload in payload["observations"] + ) + return SourceManifest( + name=payload["name"], + archetype=SourceArchetype(payload["archetype"]), + population=payload.get("population"), + description=payload.get("description"), + observations=observations, + ) diff --git a/src/microplex/core/sources.py b/src/microplex/core/sources.py new file mode 100644 index 0000000..f362439 --- /dev/null +++ b/src/microplex/core/sources.py @@ -0,0 +1,362 @@ +"""Source and observation metadata for multientity fusion.""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Protocol, runtime_checkable + +import pandas as pd + +from microplex.core.entities import EntityType + + +class TimeStructure(Enum): + """How observations are distributed across time.""" + + CROSS_SECTION = "cross_section" + REPEATED_CROSS_SECTION = "repeated_cross_section" + PANEL = "panel" + EVENT_HISTORY = "event_history" + + +class Shareability(Enum): + """Whether source microdata can appear directly in released artifacts.""" + + PUBLIC = "public" + RESTRICTED = "restricted" + NON_SHAREABLE = "non_shareable" + + @property + def allows_direct_release(self) -> bool: + """Whether this source can be directly represented in public 
outputs.""" + return self is Shareability.PUBLIC + + +class SourceArchetype(Enum): + """Cross-country source role used for planning analogous survey families.""" + + HOUSEHOLD_INCOME = "household_income" + TAX_MICRODATA = "tax_microdata" + WEALTH = "wealth" + CONSUMPTION = "consumption" + LONGITUDINAL_SOCIOECONOMIC = "longitudinal_socioeconomic" + + +class RelationshipCardinality(Enum): + """Cardinality from parent entity to child entity.""" + + ONE_TO_ONE = "one_to_one" + ONE_TO_MANY = "one_to_many" + MANY_TO_ONE = "many_to_one" + MANY_TO_MANY = "many_to_many" + + +@dataclass(frozen=True) +class SourceVariableCapability: + """How one source variable should be used during fusion and imputation.""" + + authoritative: bool = True + usable_as_condition: bool = True + notes: str | None = None + + +@dataclass(frozen=True) +class EntityObservation: + """Observed variables for one entity within a source.""" + + entity: EntityType + key_column: str + variable_names: tuple[str, ...] + weight_column: str | None = None + period_column: str | None = None + + def __post_init__(self) -> None: + if not self.key_column: + raise ValueError("EntityObservation.key_column must be non-empty") + if not self.variable_names: + raise ValueError("EntityObservation.variable_names must be non-empty") + if len(set(self.variable_names)) != len(self.variable_names): + raise ValueError( + f"Duplicate variables declared for entity '{self.entity.value}'" + ) + + +@dataclass(frozen=True) +class SourceDescriptor: + """Metadata describing one source as a partial view of the population.""" + + name: str + shareability: Shareability + time_structure: TimeStructure + observations: tuple[EntityObservation, ...] + archetype: SourceArchetype | None = None + population: str | None = None + description: str | None = None + variable_capabilities: Mapping[str, SourceVariableCapability] = field( + default_factory=dict + ) + + def __post_init__(self) -> None: + if not self.name: + raise ValueError("SourceDescriptor.name must be non-empty") + if not self.observations: + raise ValueError("SourceDescriptor.observations must be non-empty") + entities = [observation.entity for observation in self.observations] + if len(set(entities)) != len(entities): + raise ValueError( + f"Source '{self.name}' declares the same entity more than once" + ) + unknown_capabilities = set(self.variable_capabilities) - self.all_variable_names + if unknown_capabilities: + missing = ", ".join(sorted(unknown_capabilities)) + raise ValueError( + f"Source '{self.name}' declares capabilities for unknown variables: {missing}" + ) + + @property + def observed_entities(self) -> tuple[EntityType, ...]: + """Entities observed by this source.""" + return tuple(observation.entity for observation in self.observations) + + @property + def all_variable_names(self) -> frozenset[str]: + """All variables observed across every entity table.""" + return frozenset( + variable + for observation in self.observations + for variable in observation.variable_names + ) + + def observation_for(self, entity: EntityType) -> EntityObservation: + """Return the observation metadata for one entity.""" + for observation in self.observations: + if observation.entity is entity: + return observation + raise KeyError( + f"Source '{self.name}' does not observe entity '{entity.value}'" + ) + + def variables_for(self, entity: EntityType) -> frozenset[str]: + """Return the variables observed for one entity.""" + return frozenset(self.observation_for(entity).variable_names) + + def observes(self, 
variable_name: str, entity: EntityType | None = None) -> bool: + """Whether this source observes a variable.""" + if entity is not None: + return variable_name in self.variables_for(entity) + return any(variable_name in observation.variable_names for observation in self.observations) + + def capability_for(self, variable_name: str) -> SourceVariableCapability: + """Return usage metadata for one variable, defaulting to permissive behavior.""" + return self.variable_capabilities.get(variable_name, SourceVariableCapability()) + + def is_authoritative_for(self, variable_name: str) -> bool: + """Whether the source should be trusted to donate this variable.""" + return self.capability_for(variable_name).authoritative + + def allows_conditioning_on(self, variable_name: str) -> bool: + """Whether the variable is semantically valid as a shared conditioning feature.""" + return self.capability_for(variable_name).usable_as_condition + + +@dataclass(frozen=True) +class EntityRelationship: + """Relationship between two observed entity tables.""" + + parent_entity: EntityType + child_entity: EntityType + parent_key: str + child_key: str + cardinality: RelationshipCardinality = RelationshipCardinality.ONE_TO_MANY + + def __post_init__(self) -> None: + if self.parent_entity is self.child_entity: + raise ValueError("EntityRelationship must connect different entities") + if not self.parent_key or not self.child_key: + raise ValueError("EntityRelationship keys must be non-empty") + + +@dataclass +class ObservationFrame: + """Observed tables and relationships for one source realization.""" + + source: SourceDescriptor + tables: Mapping[EntityType, pd.DataFrame] + relationships: tuple[EntityRelationship, ...] = field(default_factory=tuple) + + def validate(self) -> None: + """Validate table schemas, primary keys, and foreign-key relationships.""" + for observation in self.source.observations: + table = self.tables.get(observation.entity) + if table is None: + raise ValueError( + f"Source '{self.source.name}' is missing table for " + f"entity '{observation.entity.value}'" + ) + + required_columns = set(observation.variable_names) | {observation.key_column} + if observation.weight_column is not None: + required_columns.add(observation.weight_column) + if observation.period_column is not None: + required_columns.add(observation.period_column) + + missing_columns = required_columns - set(table.columns) + if missing_columns: + missing = ", ".join(sorted(missing_columns)) + raise ValueError( + f"Source '{self.source.name}' entity '{observation.entity.value}' " + f"is missing columns: {missing}" + ) + + if table[observation.key_column].isna().any(): + raise ValueError( + f"Source '{self.source.name}' entity '{observation.entity.value}' " + "contains null primary keys" + ) + if table[observation.key_column].duplicated().any(): + raise ValueError( + f"Source '{self.source.name}' entity '{observation.entity.value}' " + "contains duplicate primary keys" + ) + + for relationship in self.relationships: + parent_table = self._table_for(relationship.parent_entity) + child_table = self._table_for(relationship.child_entity) + self._require_columns( + relationship=relationship, + table=parent_table, + entity=relationship.parent_entity, + columns=(relationship.parent_key,), + ) + self._require_columns( + relationship=relationship, + table=child_table, + entity=relationship.child_entity, + columns=(relationship.child_key,), + ) + + parent_keys = set(parent_table[relationship.parent_key].dropna()) + child_keys = 
set(child_table[relationship.child_key].dropna()) + missing_parent_keys = sorted(child_keys - parent_keys) + if missing_parent_keys: + raise ValueError( + "Relationship " + f"{relationship.child_entity.value}->{relationship.parent_entity.value} " + f"has missing parent keys: {missing_parent_keys}" + ) + + if relationship.cardinality is RelationshipCardinality.ONE_TO_ONE: + duplicates = child_table[relationship.child_key].dropna().duplicated() + if duplicates.any(): + raise ValueError( + "Relationship " + f"{relationship.child_entity.value}->{relationship.parent_entity.value} " + "violates one-to-one cardinality" + ) + + def observation_mask(self, entity: EntityType) -> pd.DataFrame: + """Return a boolean observation mask for one entity table.""" + observation = self.source.observation_for(entity) + table = self._table_for(entity) + mask = table.loc[:, list(observation.variable_names)].notna().copy() + mask.index = pd.Index(table[observation.key_column], name=observation.key_column) + return mask + + def _table_for(self, entity: EntityType) -> pd.DataFrame: + table = self.tables.get(entity) + if table is None: + raise ValueError( + f"Source '{self.source.name}' is missing table for entity '{entity.value}'" + ) + return table + + def _require_columns( + self, + relationship: EntityRelationship, + table: pd.DataFrame, + entity: EntityType, + columns: Sequence[str], + ) -> None: + missing_columns = set(columns) - set(table.columns) + if missing_columns: + missing = ", ".join(sorted(missing_columns)) + raise ValueError( + "Relationship " + f"{relationship.child_entity.value}->{relationship.parent_entity.value} " + f"entity '{entity.value}' is missing columns: {missing}" + ) + + +@dataclass(frozen=True) +class SourceQuery: + """Generic query parameters for loading observation frames.""" + + period: int | str | None = None + provider_filters: dict[str, Any] = field(default_factory=dict) + + +def apply_source_query( + frame: ObservationFrame, + query: SourceQuery | None = None, +) -> ObservationFrame: + """Filter an observation frame using generic query semantics.""" + if query is None or query.period is None: + return frame + + filtered_tables: dict[EntityType, pd.DataFrame] = {} + for observation in frame.source.observations: + table = frame.tables[observation.entity] + if observation.period_column is None: + filtered_tables[observation.entity] = table.copy() + continue + filtered_tables[observation.entity] = table.loc[ + table[observation.period_column] == query.period + ].copy() + + filtered = ObservationFrame( + source=frame.source, + tables=filtered_tables, + relationships=frame.relationships, + ) + filtered.validate() + return filtered + + +@runtime_checkable +class SourceProvider(Protocol): + """Protocol for providers that materialize observation frames.""" + + @property + def descriptor(self) -> SourceDescriptor: + """Return metadata describing the source.""" + + def load_frame(self, query: SourceQuery | None = None) -> ObservationFrame: + """Load the source into a validated observation frame.""" + + +@dataclass +class StaticSourceProvider: + """A provider backed by an in-memory observation frame.""" + + frame: ObservationFrame + + @property + def descriptor(self) -> SourceDescriptor: + return self.frame.source + + def load_frame(self, query: SourceQuery | None = None) -> ObservationFrame: + self.frame.validate() + return apply_source_query(self.frame, query) + + +class SourceAdapter(Protocol): + """Protocol for adapters that materialize observation frames.""" + + @property + def 
descriptor(self) -> SourceDescriptor: + """Return metadata describing the source.""" + + def load(self) -> ObservationFrame: + """Load the source into a validated observation frame.""" diff --git a/src/microplex/fusion/__init__.py b/src/microplex/fusion/__init__.py index d28e12e..1f3cc67 100644 --- a/src/microplex/fusion/__init__.py +++ b/src/microplex/fusion/__init__.py @@ -17,15 +17,16 @@ """ from .harmonize import ( - harmonize_surveys, - stack_surveys, COMMON_SCHEMA, CPS_MAPPING, PUF_MAPPING, - apply_transform, apply_inverse_transform, + apply_transform, + harmonize_surveys, + stack_surveys, ) from .masked_maf import MaskedMAF +from .multi_source_fusion import MultiSourceFusion from .pipeline import ( FusionConfig, FusionResult, @@ -34,7 +35,7 @@ load_puf_for_fusion, synthesize_from_surveys, ) -from .multi_source_fusion import MultiSourceFusion +from .planning import FusionPlan, VariableCoverage __all__ = [ # Low-level harmonization @@ -47,6 +48,8 @@ "apply_inverse_transform", # Masked MAF model "MaskedMAF", + "VariableCoverage", + "FusionPlan", # High-level pipeline "FusionConfig", "FusionResult", diff --git a/src/microplex/fusion/planning.py b/src/microplex/fusion/planning.py new file mode 100644 index 0000000..2500305 --- /dev/null +++ b/src/microplex/fusion/planning.py @@ -0,0 +1,142 @@ +"""Source-symmetric planning for multientity fusion.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from microplex.core import EntityType +from microplex.core.sources import ( + Shareability, + SourceArchetype, + SourceDescriptor, + TimeStructure, +) + + +@dataclass(frozen=True) +class VariableCoverage: + """Coverage metadata for one entity-scoped variable.""" + + entity: EntityType + variable_name: str + sources: tuple[str, ...] + shareabilities: frozenset[Shareability] + time_structures: frozenset[TimeStructure] + + @property + def publicly_observed(self) -> bool: + """Whether at least one public source observes this variable.""" + return Shareability.PUBLIC in self.shareabilities + + @property + def requires_synthetic_release(self) -> bool: + """Whether public release must synthesize this variable.""" + return not self.publicly_observed + + +@dataclass(frozen=True) +class FusionPlan: + """A source-symmetric plan describing multientity fusion coverage.""" + + source_names: tuple[str, ...] 
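+    # Maps each source name to its declared archetype (None when a source
+    # declares no archetype); populated by from_sources().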
+ source_archetypes: dict[str, SourceArchetype | None] + coverage: dict[EntityType, dict[str, VariableCoverage]] + + @classmethod + def from_sources(cls, sources: list[SourceDescriptor]) -> FusionPlan: + """Build a fusion plan from a set of source descriptors.""" + if not sources: + raise ValueError("FusionPlan requires at least one source") + + source_names = [source.name for source in sources] + if len(set(source_names)) != len(source_names): + raise ValueError("FusionPlan source names must be unique") + + by_variable: dict[ + tuple[EntityType, str], + dict[str, set[str] | set[Shareability] | set[TimeStructure]], + ] = {} + + for source in sources: + for observation in source.observations: + for variable_name in observation.variable_names: + key = (observation.entity, variable_name) + entry = by_variable.setdefault( + key, + { + "sources": set(), + "shareabilities": set(), + "time_structures": set(), + }, + ) + entry["sources"].add(source.name) + entry["shareabilities"].add(source.shareability) + entry["time_structures"].add(source.time_structure) + + coverage: dict[EntityType, dict[str, VariableCoverage]] = {} + for (entity, variable_name), entry in sorted( + by_variable.items(), + key=lambda item: (item[0][0].value, item[0][1]), + ): + entity_coverage = coverage.setdefault(entity, {}) + entity_coverage[variable_name] = VariableCoverage( + entity=entity, + variable_name=variable_name, + sources=tuple(sorted(entry["sources"])), + shareabilities=frozenset(entry["shareabilities"]), + time_structures=frozenset(entry["time_structures"]), + ) + + return cls( + source_names=tuple(source_names), + source_archetypes={ + source.name: source.archetype for source in sources + }, + coverage=coverage, + ) + + @property + def output_entities(self) -> tuple[EntityType, ...]: + """Entities that appear in the planned fusion output.""" + return tuple(self.coverage.keys()) + + def variables_for(self, entity: EntityType) -> frozenset[str]: + """Variables covered for one entity.""" + entity_coverage = self.coverage.get(entity, {}) + return frozenset(entity_coverage) + + def sources_for_archetype( + self, + archetype: SourceArchetype, + ) -> tuple[str, ...]: + """Return source names registered to one cross-country archetype.""" + return tuple( + source_name + for source_name in self.source_names + if self.source_archetypes.get(source_name) is archetype + ) + + def variable_plan( + self, + entity: EntityType, + variable_name: str, + ) -> VariableCoverage: + """Return coverage metadata for one variable.""" + entity_coverage = self.coverage.get(entity) + if entity_coverage is None or variable_name not in entity_coverage: + raise KeyError( + f"Fusion plan has no coverage for {entity.value}.{variable_name}" + ) + return entity_coverage[variable_name] + + def variables_requiring_synthetic_release( + self, + entity: EntityType, + ) -> frozenset[str]: + """Variables that must be synthesized for public release.""" + entity_coverage = self.coverage.get(entity, {}) + return frozenset( + variable_name + for variable_name, variable_coverage in entity_coverage.items() + if variable_coverage.requires_synthetic_release + ) diff --git a/src/microplex/geography.py b/src/microplex/geography.py index 302d83b..bc09a93 100644 --- a/src/microplex/geography.py +++ b/src/microplex/geography.py @@ -1,578 +1,407 @@ """ -Block-based geography derivation utilities for microplex. - -This module provides tools for working with Census block GEOIDs and deriving -higher-level geographies (tract, county, state, congressional district). 
- -Census GEOID Structure (15 characters total): -- State FIPS: 2 chars (positions 0-1) -- County FIPS: 3 chars (positions 2-4) -- Tract: 6 chars (positions 5-10) -- Block: 4 chars (positions 11-14) - -Example: - >>> from microplex.geography import BlockGeography - >>> geo = BlockGeography() - >>> geo.get_state("010010201001000") - '01' - >>> geo.get_county("010010201001000") - '01001' - >>> geo.get_tract("010010201001000") - '01001020100' +Atomic geography helpers and provider protocols for microplex. + +Core `microplex` owns the generic crosswalk, assignment, and provider +abstractions. Country-specific concrete geography adapters live in extension +packages such as `microplex-us`. """ -from functools import lru_cache -from pathlib import Path -from typing import Dict, List, Optional, Union +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, Protocol, runtime_checkable import numpy as np import pandas as pd - -# GEOID structure constants +# Kept as compatibility constants for callers that still use US Census GEOIDs. STATE_LEN = 2 -COUNTY_LEN = 3 # County portion after state (total 5 chars for state+county) -TRACT_LEN = 6 # Tract portion after county (total 11 chars for tract GEOID) -BLOCK_LEN = 4 # Block portion after tract (total 15 chars for full GEOID) +COUNTY_LEN = 3 +TRACT_LEN = 6 +BLOCK_LEN = 4 -# Full length constants for convenience -STATE_GEOID_LEN = STATE_LEN # 2 -COUNTY_GEOID_LEN = STATE_LEN + COUNTY_LEN # 5 -TRACT_GEOID_LEN = STATE_LEN + COUNTY_LEN + TRACT_LEN # 11 -BLOCK_GEOID_LEN = STATE_LEN + COUNTY_LEN + TRACT_LEN + BLOCK_LEN # 15 +STATE_GEOID_LEN = STATE_LEN +COUNTY_GEOID_LEN = STATE_LEN + COUNTY_LEN +TRACT_GEOID_LEN = STATE_LEN + COUNTY_LEN + TRACT_LEN +BLOCK_GEOID_LEN = STATE_LEN + COUNTY_LEN + TRACT_LEN + BLOCK_LEN -# Default data directory (relative to package root) -DEFAULT_DATA_DIR = Path(__file__).parent.parent.parent / "data" -DEFAULT_BLOCK_PROBABILITIES_PATH = DEFAULT_DATA_DIR / "block_probabilities.parquet" +PartitionKey = tuple[Any, ...] +PartitionFallbackResolver = Callable[[PartitionKey, tuple[PartitionKey, ...]], PartitionKey] +PartitionNormalizer = Callable[[Any], Any] -def load_block_probabilities( - path: Optional[Union[str, Path]] = None -) -> pd.DataFrame: - """ - Load block probabilities from parquet file. - - Args: - path: Path to parquet file. If None, uses default package data location. - - Returns: - DataFrame with columns: geoid, state_fips, county, tract, block, - population, tract_geoid, cd_id, state_total, prob, national_prob - - Raises: - FileNotFoundError: If the parquet file doesn't exist. - - Example: - >>> df = load_block_probabilities() - >>> print(f"Loaded {len(df):,} blocks") - """ - if path is None: - path = DEFAULT_BLOCK_PROBABILITIES_PATH - else: - path = Path(path) - - if not path.exists(): - raise FileNotFoundError( - f"Block probabilities file not found at {path}.\n" - "Run the data preparation script to generate this file." - ) +@dataclass(frozen=True) +class AtomicGeographyCrosswalk: + """Crosswalk from atomic geography units to materialized parent geographies.""" - return pd.read_parquet(path) + data: pd.DataFrame + atomic_id_column: str + geography_columns: tuple[str, ...] 
= () + probability_column: str | None = None + def __post_init__(self) -> None: + if self.atomic_id_column not in self.data.columns: + raise ValueError( + f"Atomic geography column '{self.atomic_id_column}' not found in crosswalk" + ) + if self.data[self.atomic_id_column].duplicated().any(): + raise ValueError("Atomic geography crosswalk must have unique atomic ids") + if ( + self.probability_column is not None + and self.probability_column not in self.data.columns + ): + raise ValueError( + f"Probability column '{self.probability_column}' not found in crosswalk" + ) -def derive_geographies( - block_geoids: Union[List[str], np.ndarray, pd.Series], - include_cd: bool = False, - include_sld: bool = False, - block_data: Optional[pd.DataFrame] = None, -) -> pd.DataFrame: - """ - Derive all higher-level geographies from block GEOIDs. - - This is a convenience function for batch processing. For repeated lookups, - use BlockGeography class with caching. - - Args: - block_geoids: List/array of 15-character block GEOIDs - include_cd: If True, include congressional district lookup (requires block_data) - include_sld: If True, include state legislative district lookup (requires block_data) - block_data: Block probabilities DataFrame for CD/SLD lookup - - Returns: - DataFrame with columns: block_geoid, state_fips, county_fips, tract_geoid - If include_cd=True, also includes cd_id column. - If include_sld=True, also includes sldu_id and sldl_id columns. - - Example: - >>> geoids = ["010010201001000", "060372073021001"] - >>> result = derive_geographies(geoids) - >>> print(result) - """ - geoids = pd.Series(block_geoids).astype(str) - - result = pd.DataFrame({ - "block_geoid": geoids, - "state_fips": geoids.str[:STATE_GEOID_LEN], - "county_fips": geoids.str[:COUNTY_GEOID_LEN], - "tract_geoid": geoids.str[:TRACT_GEOID_LEN], - }) - - if include_cd or include_sld: - if block_data is None: - block_data = load_block_probabilities() - - if include_cd: - # Create lookup dict for CD - cd_lookup = dict(zip(block_data["geoid"], block_data["cd_id"])) - result["cd_id"] = geoids.map(cd_lookup) - - if include_sld: - # Create lookup dicts for SLD - if "sldu_id" in block_data.columns: - sldu_lookup = dict(zip(block_data["geoid"], block_data["sldu_id"])) - result["sldu_id"] = geoids.map(sldu_lookup) - if "sldl_id" in block_data.columns: - sldl_lookup = dict(zip(block_data["geoid"], block_data["sldl_id"])) - result["sldl_id"] = geoids.map(sldl_lookup) - - return result - - -class BlockGeography: - """ - Geography derivation from Census block GEOIDs. - - Provides efficient methods for deriving higher-level geographies - (tract, county, state, congressional district) from block GEOIDs, - with caching for performance. 
- - Attributes: - data: Block probabilities DataFrame (loaded lazily) - - Example: - >>> geo = BlockGeography() - >>> block = "060372073021001" # A block in Los Angeles County, CA - >>> geo.get_state(block) - '06' - >>> geo.get_county(block) - '06037' - >>> geo.get_tract(block) - '06037207302' - >>> geo.get_cd(block) - 'CA-37' - """ - - def __init__( + geography_columns = self.geography_columns or tuple( + column + for column in self.data.columns + if column not in {self.atomic_id_column, self.probability_column} + ) + missing_columns = [ + column for column in geography_columns if column not in self.data.columns + ] + if missing_columns: + raise ValueError( + f"Geography columns not found in crosswalk: {sorted(missing_columns)}" + ) + object.__setattr__(self, "geography_columns", tuple(geography_columns)) + + def lookup( + self, + atomic_ids: pd.Series | np.ndarray | list[Any], + *, + columns: tuple[str, ...] | list[str] | None = None, + ) -> pd.DataFrame: + """Lookup parent geographies for atomic ids.""" + requested_columns = tuple(columns or self.geography_columns) + lookup = self.data[[self.atomic_id_column, *requested_columns]].copy() + requested_ids = pd.DataFrame({self.atomic_id_column: list(atomic_ids)}) + return requested_ids.merge(lookup, on=self.atomic_id_column, how="left") + + def materialize( self, - data_path: Optional[Union[str, Path]] = None, - lazy_load: bool = True, - ): - """ - Initialize BlockGeography. - - Args: - data_path: Path to block probabilities parquet. If None, uses default. - lazy_load: If True, defer loading data until needed. If False, load immediately. - """ - self._data_path = data_path - self._data: Optional[pd.DataFrame] = None - self._cd_lookup: Optional[Dict[str, str]] = None - self._sldu_lookup: Optional[Dict[str, str]] = None - self._sldl_lookup: Optional[Dict[str, str]] = None - self._state_blocks: Optional[Dict[str, pd.DataFrame]] = None - - if not lazy_load: - self._load_data() - - def _load_data(self) -> None: - """Load block probabilities data if not already loaded.""" - if self._data is None: - self._data = load_block_probabilities(self._data_path) - - @property - def data(self) -> pd.DataFrame: - """Block probabilities DataFrame (loaded lazily).""" - if self._data is None: - self._load_data() - return self._data - - @staticmethod - @lru_cache(maxsize=100000) - def get_state(block_geoid: str) -> str: - """ - Get state FIPS from block GEOID. - - Args: - block_geoid: 15-character Census block GEOID - - Returns: - 2-character state FIPS code - - Example: - >>> BlockGeography.get_state("060372073021001") - '06' - """ - return block_geoid[:STATE_GEOID_LEN] - - @staticmethod - @lru_cache(maxsize=100000) - def get_county(block_geoid: str) -> str: - """ - Get county FIPS (state + county) from block GEOID. - - Args: - block_geoid: 15-character Census block GEOID - - Returns: - 5-character county FIPS (state + county) - - Example: - >>> BlockGeography.get_county("060372073021001") - '06037' - """ - return block_geoid[:COUNTY_GEOID_LEN] - - @staticmethod - @lru_cache(maxsize=100000) - def get_tract(block_geoid: str) -> str: - """ - Get tract GEOID from block GEOID. - - Args: - block_geoid: 15-character Census block GEOID - - Returns: - 11-character tract GEOID - - Example: - >>> BlockGeography.get_tract("060372073021001") - '06037207302' - """ - return block_geoid[:TRACT_GEOID_LEN] - - def get_cd(self, block_geoid: str) -> Optional[str]: - """ - Get congressional district ID from block GEOID. 
- - Unlike state/county/tract (which can be derived from the GEOID string), - congressional district requires a lookup in the block data. - - Args: - block_geoid: 15-character Census block GEOID - - Returns: - Congressional district ID (e.g., "CA-37") or None if not found - - Example: - >>> geo = BlockGeography() - >>> geo.get_cd("060372073021001") - 'CA-37' - """ - if self._cd_lookup is None: - self._build_lookups() - - return self._cd_lookup.get(block_geoid) - - def get_sldu(self, block_geoid: str) -> Optional[str]: - """ - Get State Senate (upper chamber) district ID from block GEOID. - - Args: - block_geoid: 15-character Census block GEOID - - Returns: - SLDU ID (e.g., "CA-SLDU-01") or None if not found - - Example: - >>> geo = BlockGeography() - >>> geo.get_sldu("060372073021001") - 'CA-SLDU-22' - """ - if self._sldu_lookup is None: - self._build_lookups() - - return self._sldu_lookup.get(block_geoid) - - def get_sldl(self, block_geoid: str) -> Optional[str]: - """ - Get State House (lower chamber) district ID from block GEOID. - - Note: Nebraska has a unicameral legislature, so SLDL will be None - for Nebraska blocks. - - Args: - block_geoid: 15-character Census block GEOID - - Returns: - SLDL ID (e.g., "CA-SLDL-40") or None if not found - - Example: - >>> geo = BlockGeography() - >>> geo.get_sldl("060372073021001") - 'CA-SLDL-46' - """ - if self._sldl_lookup is None: - self._build_lookups() + frame: pd.DataFrame, + *, + columns: tuple[str, ...] | list[str] | None = None, + atomic_id_column: str | None = None, + overwrite: bool = False, + ) -> pd.DataFrame: + """Attach parent geography columns to a frame that already has atomic ids.""" + join_column = atomic_id_column or self.atomic_id_column + if join_column not in frame.columns: + raise ValueError( + f"Atomic geography column '{join_column}' not found in frame" + ) + requested_columns = tuple(columns or self.geography_columns) + if not overwrite: + requested_columns = tuple( + column + for column in requested_columns + if column not in frame.columns or column == join_column + ) + if not requested_columns: + return frame.copy() + lookup = self.data[[self.atomic_id_column, *requested_columns]].copy() + if join_column != self.atomic_id_column: + lookup = lookup.rename(columns={self.atomic_id_column: join_column}) + if overwrite: + columns_to_drop = [ + column + for column in requested_columns + if column != join_column and column in frame.columns + ] + base_frame = frame.drop(columns=columns_to_drop) + else: + base_frame = frame + return base_frame.merge(lookup, on=join_column, how="left") - return self._sldl_lookup.get(block_geoid) - - def _build_lookups(self) -> None: - """Build lookup dictionaries for CD and SLD.""" - self._cd_lookup = dict(zip(self.data["geoid"], self.data["cd_id"])) - # SLD lookups (may not exist in older data) - if "sldu_id" in self.data.columns: - self._sldu_lookup = dict(zip(self.data["geoid"], self.data["sldu_id"])) - else: - self._sldu_lookup = {} +def materialize_geographies( + frame: pd.DataFrame, + crosswalk: AtomicGeographyCrosswalk, + *, + columns: tuple[str, ...] 
| list[str] | None = None, + atomic_id_column: str | None = None, + overwrite: bool = False, +) -> pd.DataFrame: + """Attach geography columns from an atomic-geography crosswalk.""" + return crosswalk.materialize( + frame, + columns=columns, + atomic_id_column=atomic_id_column, + overwrite=overwrite, + ) + + +def nearest_numeric_partition_key( + requested_key: PartitionKey, + available_keys: tuple[PartitionKey, ...], +) -> PartitionKey: + """Resolve a missing partition key to the nearest numeric available key.""" + if len(requested_key) != 1: + raise ValueError("nearest_numeric_partition_key only supports one-column partitions") + if not available_keys: + raise ValueError("Cannot resolve nearest key without available partitions") + + requested_value = int(round(float(requested_key[0]))) + distances = np.array( + [ + abs(int(round(float(candidate[0]))) - requested_value) + for candidate in available_keys + ] + ) + return available_keys[int(np.argmin(distances))] + + +def normalize_us_state_fips(value: Any) -> str: + """Compatibility helper for US state FIPS normalization.""" + return str(int(round(float(value)))).zfill(2) + + +@dataclass +class ProbabilisticAtomicGeographyAssigner: + """Assign atomic geography ids from grouped probability distributions.""" + + crosswalk: AtomicGeographyCrosswalk + partition_columns: tuple[str, ...] + probability_column: str | None = None + partition_normalizers: dict[str, PartitionNormalizer] = field(default_factory=dict) + fallback_resolver: PartitionFallbackResolver | None = None + + def __post_init__(self) -> None: + missing_columns = [ + column + for column in self.partition_columns + if column not in self.crosswalk.data.columns + ] + if missing_columns: + raise ValueError( + f"Partition columns not found in crosswalk: {sorted(missing_columns)}" + ) + if self.probability_column is None: + self.probability_column = self.crosswalk.probability_column + if self.probability_column is None: + raise ValueError("A probability column is required for probabilistic assignment") + if self.probability_column not in self.crosswalk.data.columns: + raise ValueError( + f"Probability column '{self.probability_column}' not found in crosswalk" + ) + self._group_lookup = self._build_group_lookup() - if "sldl_id" in self.data.columns: - self._sldl_lookup = dict(zip(self.data["geoid"], self.data["sldl_id"])) - else: - self._sldl_lookup = {} - - def get_all_geographies(self, block_geoid: str) -> Dict[str, Optional[str]]: - """ - Get all derived geographies for a block GEOID. - - Args: - block_geoid: 15-character Census block GEOID - - Returns: - Dictionary with keys: state_fips, county_fips, tract_geoid, cd_id, - sldu_id, sldl_id - - Example: - >>> geo = BlockGeography() - >>> geo.get_all_geographies("060372073021001") - {'state_fips': '06', 'county_fips': '06037', - 'tract_geoid': '06037207302', 'cd_id': 'CA-37', - 'sldu_id': 'CA-SLDU-22', 'sldl_id': 'CA-SLDL-46'} - """ - return { - "state_fips": self.get_state(block_geoid), - "county_fips": self.get_county(block_geoid), - "tract_geoid": self.get_tract(block_geoid), - "cd_id": self.get_cd(block_geoid), - "sldu_id": self.get_sldu(block_geoid), - "sldl_id": self.get_sldl(block_geoid), - } - - def sample_blocks( + def assign( self, - state_fips: str, - n: int, - replace: bool = True, - random_state: Optional[int] = None, - ) -> np.ndarray: - """ - Sample blocks from a state using population-weighted probabilities. 
- - Args: - state_fips: 2-character state FIPS code - n: Number of blocks to sample - replace: Sample with replacement (default True) - random_state: Random seed for reproducibility - - Returns: - Array of sampled block GEOIDs - - Raises: - ValueError: If state_fips not found in data - - Example: - >>> geo = BlockGeography() - >>> blocks = geo.sample_blocks("06", n=100, random_state=42) - >>> print(f"Sampled {len(blocks)} blocks from California") - """ - # Build state index if needed - if self._state_blocks is None: - self._build_state_index() - - if state_fips not in self._state_blocks: + frame: pd.DataFrame, + *, + atomic_id_column: str | None = None, + random_state: int | None = None, + ) -> pd.DataFrame: + """Assign atomic geography ids to each row of a frame.""" + missing_columns = [ + column for column in self.partition_columns if column not in frame.columns + ] + if missing_columns: raise ValueError( - f"State FIPS '{state_fips}' not found in block data. " - f"Available states: {sorted(self._state_blocks.keys())}" + f"Partition columns not found in frame: {sorted(missing_columns)}" ) - state_df = self._state_blocks[state_fips] + rng = np.random.default_rng(random_state) + assigned_atomic_ids: list[Any] = [] + available_keys = tuple(self._group_lookup.keys()) + + for raw_key in frame.loc[:, self.partition_columns].itertuples( + index=False, + name=None, + ): + normalized_key = self._normalize_partition_key(raw_key) + lookup_key = normalized_key + if lookup_key not in self._group_lookup: + if self.fallback_resolver is None: + raise ValueError( + f"No atomic geography distribution available for partition key {lookup_key}" + ) + lookup_key = self.fallback_resolver(lookup_key, available_keys) + group = self._group_lookup[lookup_key] + sampled_index = rng.choice(len(group["atomic_ids"]), p=group["probabilities"]) + assigned_atomic_ids.append(group["atomic_ids"][sampled_index]) + + result = frame.copy() + result[atomic_id_column or self.crosswalk.atomic_id_column] = assigned_atomic_ids + return result + + def _build_group_lookup(self) -> dict[PartitionKey, dict[str, np.ndarray]]: + group_lookup: dict[PartitionKey, dict[str, np.ndarray]] = {} + grouped = self.crosswalk.data.groupby( + list(self.partition_columns), dropna=False, sort=False + ) + for raw_key, group in grouped: + key_tuple = raw_key if isinstance(raw_key, tuple) else (raw_key,) + normalized_key = self._normalize_partition_key(key_tuple) + probabilities = pd.to_numeric( + group[self.probability_column], + errors="coerce", + ).fillna(0.0) + total_probability = float(probabilities.sum()) + if total_probability <= 0: + raise ValueError( + f"Partition {normalized_key} has non-positive total probability" + ) + group_lookup[normalized_key] = { + "atomic_ids": group[self.crosswalk.atomic_id_column].to_numpy(), + "probabilities": (probabilities / total_probability).to_numpy(dtype=float), + } + return group_lookup + + def _normalize_partition_key(self, raw_key: tuple[Any, ...]) -> PartitionKey: + normalized: list[Any] = [] + for column, value in zip(self.partition_columns, raw_key, strict=False): + normalizer = self.partition_normalizers.get(column) + normalized_value = normalizer(value) if normalizer is not None else value + if hasattr(normalized_value, "item"): + normalized_value = normalized_value.item() + normalized.append(normalized_value) + return tuple(normalized) + + +@dataclass(frozen=True) +class GeographyQuery: + """Generic query parameters for atomic-geography providers.""" + + geography_columns: tuple[str, ...] 
= () + partition_columns: tuple[str, ...] = () + probability_column: str | None = None + partition_normalizers: dict[str, PartitionNormalizer] = field(default_factory=dict) + fallback_resolver: PartitionFallbackResolver | None = None + + +@dataclass(frozen=True) +class GeographyAssignmentPlan: + """How a model should assign atomic geography ids during synthesis.""" + + partition_columns: tuple[str, ...] + atomic_id_column: str + geography_columns: tuple[str, ...] = () + probability_column: str | None = None + partition_normalizers: dict[str, PartitionNormalizer] = field(default_factory=dict) + fallback_resolver: PartitionFallbackResolver | None = None + sync_partition_columns: bool = True + + def requested_geography_columns(self) -> tuple[str, ...]: + """Columns that should be materialized after assignment.""" + ordered_columns: list[str] = [] + if self.sync_partition_columns: + ordered_columns.extend(self.partition_columns) + ordered_columns.extend(self.geography_columns) + return tuple(dict.fromkeys(ordered_columns)) + + def to_query(self) -> GeographyQuery: + """Convert the assignment plan into a provider query.""" + return GeographyQuery( + geography_columns=self.requested_geography_columns(), + partition_columns=self.partition_columns, + probability_column=self.probability_column, + partition_normalizers=dict(self.partition_normalizers), + fallback_resolver=self.fallback_resolver, + ) - if random_state is not None: - np.random.seed(random_state) - # Use within-state probabilities (prob column) - sampled_indices = np.random.choice( - len(state_df), - size=n, - replace=replace, - p=state_df["prob"].values, - ) +@runtime_checkable +class GeographyProvider(Protocol): + """Protocol for providers of atomic geography crosswalks and assigners.""" + + def load_crosswalk( + self, + query: GeographyQuery | None = None, + ) -> AtomicGeographyCrosswalk: + """Load an atomic geography crosswalk.""" + + def load_assigner( + self, + query: GeographyQuery | None = None, + ) -> ProbabilisticAtomicGeographyAssigner: + """Load a probabilistic atomic geography assigner.""" + - return state_df["geoid"].values[sampled_indices] +@dataclass +class StaticGeographyProvider: + """A geography provider backed by an in-memory atomic crosswalk.""" - def _build_state_index(self) -> None: - """Build index of blocks by state for efficient sampling.""" - self._state_blocks = {} - for state_fips, group in self.data.groupby("state_fips"): - # Store as a copy to ensure prob column is contiguous - self._state_blocks[state_fips] = group[["geoid", "prob"]].copy() + crosswalk: AtomicGeographyCrosswalk + default_partition_columns: tuple[str, ...] = () + default_partition_normalizers: dict[str, PartitionNormalizer] = field( + default_factory=dict + ) + default_fallback_resolver: PartitionFallbackResolver | None = None - def sample_blocks_national( + def load_crosswalk( self, - n: int, - replace: bool = True, - random_state: Optional[int] = None, - ) -> np.ndarray: - """ - Sample blocks nationally using population-weighted probabilities. 
- - Args: - n: Number of blocks to sample - replace: Sample with replacement (default True) - random_state: Random seed for reproducibility - - Returns: - Array of sampled block GEOIDs - - Example: - >>> geo = BlockGeography() - >>> blocks = geo.sample_blocks_national(n=1000, random_state=42) - >>> print(f"Sampled {len(blocks)} blocks from US") - """ - if random_state is not None: - np.random.seed(random_state) - - sampled_indices = np.random.choice( - len(self.data), - size=n, - replace=replace, - p=self.data["national_prob"].values, + query: GeographyQuery | None = None, + ) -> AtomicGeographyCrosswalk: + query = query or GeographyQuery() + geography_columns = query.geography_columns or self.crosswalk.geography_columns + probability_column = query.probability_column or self.crosswalk.probability_column + return AtomicGeographyCrosswalk( + data=self.crosswalk.data.copy(), + atomic_id_column=self.crosswalk.atomic_id_column, + geography_columns=tuple(geography_columns), + probability_column=probability_column, ) - return self.data["geoid"].values[sampled_indices] - - def get_blocks_in_state(self, state_fips: str) -> pd.DataFrame: - """ - Get all blocks in a state. - - Args: - state_fips: 2-character state FIPS code - - Returns: - DataFrame with block data for the state - - Example: - >>> geo = BlockGeography() - >>> ca_blocks = geo.get_blocks_in_state("06") - >>> print(f"California has {len(ca_blocks):,} blocks") - """ - return self.data[self.data["state_fips"] == state_fips].copy() - - def get_blocks_in_county(self, county_fips: str) -> pd.DataFrame: - """ - Get all blocks in a county. - - Args: - county_fips: 5-character county FIPS (state + county) - - Returns: - DataFrame with block data for the county - - Example: - >>> geo = BlockGeography() - >>> la_blocks = geo.get_blocks_in_county("06037") - >>> print(f"Los Angeles County has {len(la_blocks):,} blocks") - """ - state = county_fips[:STATE_GEOID_LEN] - county = county_fips[STATE_GEOID_LEN:] - return self.data[ - (self.data["state_fips"] == state) & - (self.data["county"] == county) - ].copy() - - def get_blocks_in_tract(self, tract_geoid: str) -> pd.DataFrame: - """ - Get all blocks in a tract. - - Args: - tract_geoid: 11-character tract GEOID - - Returns: - DataFrame with block data for the tract - - Example: - >>> geo = BlockGeography() - >>> tract_blocks = geo.get_blocks_in_tract("06037207302") - >>> print(f"Tract has {len(tract_blocks)} blocks") - """ - return self.data[self.data["tract_geoid"] == tract_geoid].copy() - - def get_blocks_in_cd(self, cd_id: str) -> pd.DataFrame: - """ - Get all blocks in a congressional district. - - Args: - cd_id: Congressional district ID (e.g., "CA-37") - - Returns: - DataFrame with block data for the congressional district - - Example: - >>> geo = BlockGeography() - >>> cd_blocks = geo.get_blocks_in_cd("CA-37") - >>> print(f"District has {len(cd_blocks):,} blocks") - """ - return self.data[self.data["cd_id"] == cd_id].copy() - - def get_blocks_in_sldu(self, sldu_id: str) -> pd.DataFrame: - """ - Get all blocks in a State Senate (upper chamber) district. 
- - Args: - sldu_id: SLDU ID (e.g., "CA-SLDU-22") - - Returns: - DataFrame with block data for the State Senate district - - Example: - >>> geo = BlockGeography() - >>> sldu_blocks = geo.get_blocks_in_sldu("CA-SLDU-22") - >>> print(f"District has {len(sldu_blocks):,} blocks") - """ - if "sldu_id" not in self.data.columns: - return pd.DataFrame() - return self.data[self.data["sldu_id"] == sldu_id].copy() - - def get_blocks_in_sldl(self, sldl_id: str) -> pd.DataFrame: - """ - Get all blocks in a State House (lower chamber) district. - - Args: - sldl_id: SLDL ID (e.g., "CA-SLDL-46") - - Returns: - DataFrame with block data for the State House district - - Example: - >>> geo = BlockGeography() - >>> sldl_blocks = geo.get_blocks_in_sldl("CA-SLDL-46") - >>> print(f"District has {len(sldl_blocks):,} blocks") - """ - if "sldl_id" not in self.data.columns: - return pd.DataFrame() - return self.data[self.data["sldl_id"] == sldl_id].copy() - - @property - def states(self) -> List[str]: - """List of all state FIPS codes in the data.""" - return sorted(self.data["state_fips"].unique()) - - @property - def n_blocks(self) -> int: - """Total number of blocks in the data.""" - return len(self.data) - - def __repr__(self) -> str: - if self._data is None: - return "BlockGeography(not loaded)" - return f"BlockGeography({self.n_blocks:,} blocks, {len(self.states)} states)" + def load_assigner( + self, + query: GeographyQuery | None = None, + ) -> ProbabilisticAtomicGeographyAssigner: + query = query or GeographyQuery() + partition_columns = query.partition_columns or self.default_partition_columns + if not partition_columns: + raise ValueError("partition_columns are required to build a geography assigner") + probability_column = query.probability_column or self.crosswalk.probability_column + return ProbabilisticAtomicGeographyAssigner( + crosswalk=self.load_crosswalk(query), + partition_columns=tuple(partition_columns), + probability_column=probability_column, + partition_normalizers=( + query.partition_normalizers or self.default_partition_normalizers + ), + fallback_resolver=( + query.fallback_resolver + if query.fallback_resolver is not None + else self.default_fallback_resolver + ), + ) + + +def _raise_missing_us_geography() -> ModuleNotFoundError: + return ModuleNotFoundError( + "US block geography helpers moved to the separate `microplex-us` package. " + "Install or add `microplex-us`, then import `microplex_us.geography`." + ) + + +try: + from microplex_us.geography import ( # noqa: F401 + BlockGeography, + derive_geographies, + load_block_probabilities, + ) +except ModuleNotFoundError as exc: + if exc.name != "microplex_us": + raise + + def load_block_probabilities(*args: Any, **kwargs: Any) -> pd.DataFrame: + raise _raise_missing_us_geography() + + def derive_geographies(*args: Any, **kwargs: Any) -> pd.DataFrame: + raise _raise_missing_us_geography() + + class BlockGeography: # type: ignore[no-redef] + """Compatibility placeholder for the moved US block geography adapter.""" + + @classmethod + def from_data(cls, data: pd.DataFrame) -> BlockGeography: + raise _raise_missing_us_geography() + + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise _raise_missing_us_geography() diff --git a/src/microplex/synthesizer.py b/src/microplex/synthesizer.py index 93634bd..1164d30 100644 --- a/src/microplex/synthesizer.py +++ b/src/microplex/synthesizer.py @@ -5,27 +5,30 @@ variables conditioned on context variables. 
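+
+Example (a minimal sketch; the frames and column names are placeholders):
+    >>> from microplex.synthesizer import Synthesizer
+    >>> synth = Synthesizer(target_vars=["income"], condition_vars=["age"])
+    >>> synth.fit(train_df, epochs=10)          # doctest: +SKIP
+    >>> synthetic = synth.generate(conditions)  # doctest: +SKIP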
""" +from __future__ import annotations + from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Self + import numpy as np import pandas as pd import torch import torch.nn as nn from torch.utils.data import DataLoader, TensorDataset -from .transforms import MultiVariableTransformer -from .flows import ConditionalMAF from .discrete import BinaryModel, DiscreteModelCollection +from .flows import ConditionalMAF +from .transforms import MultiVariableTransformer @dataclass class SynthesizerConfig: """Configuration for Synthesizer.""" - target_vars: List[str] - condition_vars: List[str] - discrete_vars: Optional[List[str]] = None + target_vars: list[str] + condition_vars: list[str] + discrete_vars: list[str] | None = None # Model architecture n_layers: int = 6 @@ -64,9 +67,9 @@ class Synthesizer: def __init__( self, - target_vars: List[str], - condition_vars: List[str], - discrete_vars: Optional[List[str]] = None, + target_vars: list[str], + condition_vars: list[str], + discrete_vars: list[str] | None = None, n_layers: int = 6, hidden_dim: int = 64, zero_inflated: bool = True, @@ -99,27 +102,52 @@ def __init__( self.sample_clipping = sample_clipping # Will be set during fit - self.transformer_: Optional[MultiVariableTransformer] = None - self.flow_model_: Optional[ConditionalMAF] = None - self.zero_indicators_: Optional[nn.ModuleDict] = None - self.discrete_model_: Optional[DiscreteModelCollection] = None + self.transformer_: MultiVariableTransformer | None = None + self.flow_model_: ConditionalMAF | None = None + self.zero_indicators_: nn.ModuleDict | None = None + self.discrete_model_: DiscreteModelCollection | None = None self.is_fitted_: bool = False - self.training_history_: List[float] = [] + self.training_history_: list[float] = [] self._actual_n_context: int = 0 # Actual context dim (may include dummy) - self._train_target_std: Optional[torch.Tensor] = None # Store target std for variance reg - self._train_target_max: Optional[torch.Tensor] = None # Store max for clipping calibration - self._original_scale_stats: Optional[Dict[str, Dict[str, float]]] = None # Original scale stats for clipping - self._training_data: Optional[pd.DataFrame] = None # Store for full synthesis + self._train_target_std: torch.Tensor | None = None # Store target std for variance reg + self._train_target_max: torch.Tensor | None = None # Store max for clipping calibration + self._original_scale_stats: dict[str, dict[str, float]] | None = None # Original scale stats for clipping + self._training_data: pd.DataFrame | None = None # Store for full synthesis + + def _build_context_tensor(self, data: pd.DataFrame) -> torch.Tensor: + if self.condition_vars: + context_np = np.column_stack([data[var].values for var in self.condition_vars]) + else: + context_np = np.zeros((len(data), 1)) + return torch.tensor(context_np, dtype=torch.float32) + + def _build_original_scale_stats( + self, + data: pd.DataFrame, + ) -> dict[str, dict[str, float]]: + stats: dict[str, dict[str, float]] = {} + for var in self.target_vars: + values = np.asarray(data[var].values, dtype=float) + positive_values = values[values > 0] + if len(positive_values) > 0: + stats[var] = { + "max": float(np.max(positive_values)), + "p99": float(np.percentile(positive_values, 99)), + "p999": float(np.percentile(positive_values, 99.9)), + } + else: + stats[var] = {"max": 1.0, "p99": 1.0, "p999": 1.0} + return stats def fit( self, data: pd.DataFrame, - weight_col: Optional[str] = "weight", + 
weight_col: str | None = "weight", epochs: int = 100, batch_size: int = 256, learning_rate: float = 1e-3, verbose: bool = True, - ) -> "Synthesizer": + ) -> Self: """ Fit synthesizer on training data. @@ -156,34 +184,14 @@ def fit( transformed = self.transformer_.transform(data_dict) # Store original scale statistics for adaptive clipping - self._original_scale_stats = {} - for var in self.target_vars: - values = data[var].values - positive_values = values[values > 0] - if len(positive_values) > 0: - self._original_scale_stats[var] = { - 'max': float(np.max(positive_values)), - 'p99': float(np.percentile(positive_values, 99)), - 'p999': float(np.percentile(positive_values, 99.9)), - } - else: - self._original_scale_stats[var] = {'max': 1.0, 'p99': 1.0, 'p999': 1.0} + self._original_scale_stats = self._build_original_scale_stats(data) # Prepare tensors - n_context = len(self.condition_vars) n_targets = len(self.target_vars) # Context tensor (handle empty condition_vars for unconditional generation) - if n_context > 0: - context_np = np.column_stack([ - data[var].values for var in self.condition_vars - ]) - self._actual_n_context = n_context - else: - # Unconditional: use dummy context of zeros - context_np = np.zeros((len(data), 1)) - self._actual_n_context = 1 - context = torch.tensor(context_np, dtype=torch.float32) + context = self._build_context_tensor(data) + self._actual_n_context = len(self.condition_vars) or 1 # Target tensor and observation mask # NaN values in ORIGINAL data indicate missing observations (from multi-survey stacking) @@ -511,7 +519,7 @@ def _train_discrete( def generate( self, conditions: pd.DataFrame, - seed: Optional[int] = None, + seed: int | None = None, ) -> pd.DataFrame: """ Generate synthetic target variables for given conditions. @@ -535,14 +543,7 @@ def generate( np.random.seed(seed) # Prepare context tensor (handle empty condition_vars for unconditional generation) - if len(self.condition_vars) > 0: - context_np = np.column_stack([ - conditions[var].values for var in self.condition_vars - ]) - else: - # Unconditional: use dummy context of zeros (matching fit behavior) - context_np = np.zeros((len(conditions), 1)) - context = torch.tensor(context_np, dtype=torch.float32) + context = self._build_context_tensor(conditions) # Sample from flow (with optional clipping) with torch.no_grad(): @@ -561,9 +562,10 @@ def generate( # Apply zero indicators with torch.no_grad(): + zero_indicators = self.zero_indicators_ for var in self.target_vars: - if self.zero_indicators_ and var in self.zero_indicators_: - prob_positive = self.zero_indicators_[var](context).squeeze(-1) + if zero_indicators is not None and var in zero_indicators: + prob_positive = zero_indicators[var](context).squeeze(-1) is_positive = torch.bernoulli(prob_positive).numpy() original_dict[var] = np.where( is_positive > 0.5, @@ -603,7 +605,7 @@ def generate( def sample( self, n: int, - seed: Optional[int] = None, + seed: int | None = None, ) -> pd.DataFrame: """ Generate fully synthetic records (both conditions and targets). @@ -638,7 +640,7 @@ def sample( # Generate targets conditioned on sampled conditions return self.generate(conditions, seed=seed) - def save(self, path: Union[str, Path]) -> None: + def save(self, path: str | Path) -> None: """Save fitted model to disk.""" if not self.is_fitted_: raise ValueError("Synthesizer not fitted. 
Call fit() first.") @@ -670,7 +672,7 @@ def save(self, path: Union[str, Path]) -> None: torch.save(state, Path(path)) @classmethod - def load(cls, path: Union[str, Path]) -> "Synthesizer": + def load(cls, path: str | Path) -> Self: """Load fitted model from disk.""" state = torch.load(Path(path), weights_only=False) diff --git a/src/microplex/targets/__init__.py b/src/microplex/targets/__init__.py index 8122fc3..aba87b9 100644 --- a/src/microplex/targets/__init__.py +++ b/src/microplex/targets/__init__.py @@ -1,41 +1,31 @@ -""" -Microplex Calibration Targets Framework +"""Target primitives for microplex.""" -General-purpose framework for calibration targets. -Country-specific targets and loaders are in microplex-sources. - -Classes: - Target: A single calibration target - TargetCategory: Categories of targets (income, benefits, demographics) - TargetsDatabase: Collection of targets with indexing - -RAC Mapping: - RACVariable: Variable definition linked to statute - RAC_VARIABLE_MAP: Mapping from variable names to RAC definitions -""" - -from microplex.targets.database import TargetsDatabase, Target, TargetCategory -from microplex.targets.rac_mapping import ( - RACVariable, - RAC_VARIABLE_MAP, - POLICYENGINE_TO_RAC, - MICRODATA_TO_RAC, - get_rac_for_target, - get_rac_for_pe_variable, - get_rac_for_microdata_column, +from microplex.targets.database import Target, TargetCategory, TargetsDatabase +from microplex.targets.provider import ( + StaticTargetProvider, + TargetProvider, + TargetQuery, + apply_target_query, +) +from microplex.targets.spec import ( + FilterOperator, + TargetAggregation, + TargetFilter, + TargetSet, + TargetSpec, ) __all__ = [ - # Core classes "TargetsDatabase", "Target", "TargetCategory", - # RAC mapping - "RACVariable", - "RAC_VARIABLE_MAP", - "POLICYENGINE_TO_RAC", - "MICRODATA_TO_RAC", - "get_rac_for_target", - "get_rac_for_pe_variable", - "get_rac_for_microdata_column", + "FilterOperator", + "TargetAggregation", + "TargetFilter", + "TargetProvider", + "TargetQuery", + "StaticTargetProvider", + "apply_target_query", + "TargetSet", + "TargetSpec", ] diff --git a/src/microplex/targets/provider.py b/src/microplex/targets/provider.py new file mode 100644 index 0000000..3c77738 --- /dev/null +++ b/src/microplex/targets/provider.py @@ -0,0 +1,75 @@ +"""Provider abstractions for canonical target specs.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Protocol, runtime_checkable + +from microplex.core import EntityType +from microplex.targets.spec import TargetSet, TargetSpec + + +@dataclass(frozen=True) +class TargetQuery: + """Generic query parameters for loading canonical targets.""" + + period: int | str | None = None + entity: EntityType | str | None = None + names: tuple[str, ...] 
= () + metadata_filters: dict[str, Any] = field(default_factory=dict) + provider_filters: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + entity = self.entity + if entity is not None and not isinstance(entity, EntityType): + entity = EntityType(entity) + object.__setattr__(self, "entity", entity) + object.__setattr__(self, "names", tuple(self.names)) + + +def apply_target_query( + targets: TargetSet | list[TargetSpec], + query: TargetQuery | None = None, +) -> TargetSet: + """Filter a canonical target collection using generic query semantics.""" + target_set = targets if isinstance(targets, TargetSet) else TargetSet(list(targets)) + if query is None: + return target_set + + selected: list[TargetSpec] = [] + for target in target_set.targets: + if query.period is not None and target.period != query.period: + continue + if query.entity is not None and target.entity is not query.entity: + continue + if query.names and target.name not in query.names: + continue + if not _matches_metadata(target, query.metadata_filters): + continue + selected.append(target) + return TargetSet(selected) + + +@runtime_checkable +class TargetProvider(Protocol): + """Protocol for loading canonical target sets.""" + + def load_target_set(self, query: TargetQuery | None = None) -> TargetSet: + """Return a canonical target set for the requested slice.""" + + +@dataclass +class StaticTargetProvider: + """A provider backed by an in-memory canonical target set.""" + + target_set: TargetSet = field(default_factory=TargetSet) + + def load_target_set(self, query: TargetQuery | None = None) -> TargetSet: + return apply_target_query(self.target_set, query) + + +def _matches_metadata(target: TargetSpec, metadata_filters: dict[str, Any]) -> bool: + for key, expected in metadata_filters.items(): + if target.metadata.get(key) != expected: + return False + return True diff --git a/src/microplex/targets/spec.py b/src/microplex/targets/spec.py new file mode 100644 index 0000000..ee69eb7 --- /dev/null +++ b/src/microplex/targets/spec.py @@ -0,0 +1,120 @@ +"""Canonical target specification primitives for microplex.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from microplex.core import EntityType + + +class FilterOperator(str, Enum): + """Supported operators for target filters.""" + + EQ = "==" + NE = "!=" + GT = ">" + GTE = ">=" + LT = "<" + LTE = "<=" + IN = "in" + NOT_IN = "not_in" + + +class TargetAggregation(str, Enum): + """Supported target aggregation modes.""" + + COUNT = "count" + SUM = "sum" + MEAN = "mean" + + +@dataclass(frozen=True) +class TargetFilter: + """A boolean filter over a materialized feature.""" + + feature: str + operator: FilterOperator | str + value: Any + + def __post_init__(self) -> None: + object.__setattr__(self, "operator", FilterOperator(self.operator)) + + +@dataclass(frozen=True) +class TargetSpec: + """Canonical representation of a calibration target.""" + + name: str + entity: EntityType | str + value: float + period: int | str + measure: str | None = None + aggregation: TargetAggregation | str = TargetAggregation.SUM + filters: tuple[TargetFilter, ...] 
= () + tolerance: float | None = None + source: str | None = None + units: str | None = None + description: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + entity = self.entity + if not isinstance(entity, EntityType): + entity = EntityType(entity) + aggregation = self.aggregation + if not isinstance(aggregation, TargetAggregation): + aggregation = TargetAggregation(aggregation) + + normalized_filters = tuple( + target_filter + if isinstance(target_filter, TargetFilter) + else TargetFilter(**target_filter) + for target_filter in self.filters + ) + + if aggregation is TargetAggregation.COUNT and self.measure is not None: + raise ValueError("Count targets must not define a measure column") + + object.__setattr__(self, "entity", entity) + object.__setattr__(self, "aggregation", aggregation) + object.__setattr__(self, "filters", normalized_filters) + + @property + def required_features(self) -> tuple[str, ...]: + """Features that must be materialized to evaluate this target.""" + features = [] + if self.measure is not None: + features.append(self.measure) + features.extend(target_filter.feature for target_filter in self.filters) + ordered_unique = dict.fromkeys(features) + return tuple(ordered_unique) + + +@dataclass +class TargetSet: + """Collection helpers for canonical target specs.""" + + targets: list[TargetSpec] = field(default_factory=list) + + def add(self, target: TargetSpec) -> None: + self.targets.append(target) + + def add_many(self, targets: list[TargetSpec]) -> None: + self.targets.extend(targets) + + def for_entity(self, entity: EntityType | str) -> list[TargetSpec]: + entity_type = entity if isinstance(entity, EntityType) else EntityType(entity) + return [target for target in self.targets if target.entity is entity_type] + + def for_period(self, period: int | str) -> list[TargetSpec]: + return [target for target in self.targets if target.period == period] + + def required_features(self, entity: EntityType | str | None = None) -> tuple[str, ...]: + relevant_targets = self.targets if entity is None else self.for_entity(entity) + features: list[str] = [] + for target in relevant_targets: + features.extend(target.required_features) + ordered_unique = dict.fromkeys(features) + return tuple(ordered_unique) diff --git a/tests/core/test_sources.py b/tests/core/test_sources.py new file mode 100644 index 0000000..c158968 --- /dev/null +++ b/tests/core/test_sources.py @@ -0,0 +1,344 @@ +"""Tests for source-symmetric multientity fusion planning.""" + +import json + +import pandas as pd +import pytest + +from microplex.core import ( + BenefitUnit, + EntityObservation, + EntityRelationship, + EntityType, + ObservationFrame, + Record, + RecordType, + RelationshipCardinality, + Shareability, + SourceArchetype, + SourceColumnValueType, + SourceDescriptor, + SourceManifest, + SourceProvider, + SourceQuery, + SourceVariableCapability, + StaticSourceProvider, + TimeStructure, + load_source_manifest, +) + + +def test_record_is_first_class_entity_type(): + assert EntityType.RECORD.value == "record" + + record = Record(id="job-1", person_id="p1", record_type=RecordType.W2) + + assert record.entity_type is EntityType.RECORD + + +def test_benefit_unit_is_first_class_entity_type(): + assert EntityType.BENEFIT_UNIT.value == "benefit_unit" + + benefit_unit = BenefitUnit(id="bu-1", member_ids=["p1", "p2"], head_id="p1") + + assert benefit_unit.entity_type is EntityType.BENEFIT_UNIT + assert benefit_unit.entity_type.is_group + + +def 
test_source_descriptor_tracks_variables_by_entity(): + descriptor = SourceDescriptor( + name="cps", + shareability=Shareability.PUBLIC, + time_structure=TimeStructure.REPEATED_CROSS_SECTION, + archetype=SourceArchetype.HOUSEHOLD_INCOME, + observations=( + EntityObservation( + entity=EntityType.PERSON, + key_column="person_id", + variable_names=("age", "employment_income"), + ), + EntityObservation( + entity=EntityType.HOUSEHOLD, + key_column="household_id", + variable_names=("state_fips", "rent"), + weight_column="household_weight", + ), + ), + ) + + assert set(descriptor.observed_entities) == { + EntityType.PERSON, + EntityType.HOUSEHOLD, + } + assert descriptor.variables_for(EntityType.PERSON) == frozenset( + {"age", "employment_income"} + ) + assert descriptor.archetype is SourceArchetype.HOUSEHOLD_INCOME + assert descriptor.variables_for(EntityType.HOUSEHOLD) == frozenset( + {"state_fips", "rent"} + ) + assert descriptor.observes("rent", entity=EntityType.HOUSEHOLD) + assert not descriptor.observes("rent", entity=EntityType.PERSON) + + +def test_source_descriptor_tracks_variable_capabilities(): + descriptor = SourceDescriptor( + name="puf_like", + shareability=Shareability.RESTRICTED, + time_structure=TimeStructure.REPEATED_CROSS_SECTION, + observations=( + EntityObservation( + entity=EntityType.HOUSEHOLD, + key_column="household_id", + variable_names=("state_fips",), + ), + EntityObservation( + entity=EntityType.PERSON, + key_column="person_id", + variable_names=("age", "taxable_interest_income"), + ), + ), + variable_capabilities={ + "state_fips": SourceVariableCapability(usable_as_condition=False), + "taxable_interest_income": SourceVariableCapability(authoritative=True), + }, + ) + + assert descriptor.all_variable_names == frozenset( + {"state_fips", "age", "taxable_interest_income"} + ) + assert not descriptor.allows_conditioning_on("state_fips") + assert descriptor.is_authoritative_for("taxable_interest_income") + assert descriptor.is_authoritative_for("age") + + +def test_load_source_manifest_reads_typed_json(tmp_path): + manifest_path = tmp_path / "spi.json" + manifest_path.write_text( + json.dumps( + { + "name": "uk_spi", + "archetype": "tax_microdata", + "population": "UK tax units", + "description": "Tax-unit source", + "observations": [ + { + "entity": "tax_unit", + "key_column": "tax_unit_id", + "weight_column": "weight", + "period_column": "year", + "columns": [ + { + "raw_column": "FACT", + "canonical_name": "weight", + }, + { + "raw_column": "DIVIDENDS", + "canonical_name": "dividend_income", + "value_type": "numeric", + }, + { + "raw_column": "GORCODE", + "canonical_name": "region_code", + "value_type": "categorical", + }, + ], + } + ], + } + ) + ) + + manifest = load_source_manifest(manifest_path) + + assert isinstance(manifest, SourceManifest) + assert manifest.archetype is SourceArchetype.TAX_MICRODATA + observation = manifest.observation_for(EntityType.TAX_UNIT) + assert observation.weight_column == "weight" + assert observation.observed_variable_names() == ("dividend_income", "region_code") + assert observation.columns[-1].value_type is SourceColumnValueType.CATEGORICAL + + +def test_source_descriptor_rejects_unknown_capability_variables(): + with pytest.raises(ValueError, match="unknown variables"): + SourceDescriptor( + name="bad", + shareability=Shareability.PUBLIC, + time_structure=TimeStructure.REPEATED_CROSS_SECTION, + observations=( + EntityObservation( + entity=EntityType.PERSON, + key_column="person_id", + variable_names=("age",), + ), + ), + 
variable_capabilities={ + "state_fips": SourceVariableCapability(usable_as_condition=False), + }, + ) + + +def test_observation_frame_validates_relationships_and_builds_masks(): + descriptor = SourceDescriptor( + name="cps", + shareability=Shareability.PUBLIC, + time_structure=TimeStructure.REPEATED_CROSS_SECTION, + observations=( + EntityObservation( + entity=EntityType.PERSON, + key_column="person_id", + variable_names=("age", "employment_income"), + ), + EntityObservation( + entity=EntityType.HOUSEHOLD, + key_column="household_id", + variable_names=("state_fips", "rent"), + ), + ), + ) + households = pd.DataFrame( + { + "household_id": ["h1"], + "state_fips": ["06"], + "rent": [2400.0], + } + ) + persons = pd.DataFrame( + { + "person_id": ["p1", "p2"], + "household_id": ["h1", "h1"], + "age": [35, None], + "employment_income": [50_000.0, 0.0], + } + ) + frame = ObservationFrame( + source=descriptor, + tables={ + EntityType.HOUSEHOLD: households, + EntityType.PERSON: persons, + }, + relationships=( + EntityRelationship( + parent_entity=EntityType.HOUSEHOLD, + child_entity=EntityType.PERSON, + parent_key="household_id", + child_key="household_id", + cardinality=RelationshipCardinality.ONE_TO_MANY, + ), + ), + ) + + frame.validate() + mask = frame.observation_mask(EntityType.PERSON) + + assert mask.index.tolist() == ["p1", "p2"] + assert list(mask.columns) == ["age", "employment_income"] + assert mask["age"].tolist() == [True, False] + assert mask["employment_income"].tolist() == [True, True] + + +def test_observation_frame_rejects_missing_parent_keys(): + descriptor = SourceDescriptor( + name="cps", + shareability=Shareability.PUBLIC, + time_structure=TimeStructure.REPEATED_CROSS_SECTION, + observations=( + EntityObservation( + entity=EntityType.PERSON, + key_column="person_id", + variable_names=("age",), + ), + EntityObservation( + entity=EntityType.HOUSEHOLD, + key_column="household_id", + variable_names=("state_fips",), + ), + ), + ) + frame = ObservationFrame( + source=descriptor, + tables={ + EntityType.HOUSEHOLD: pd.DataFrame( + { + "household_id": ["h1"], + "state_fips": ["06"], + } + ), + EntityType.PERSON: pd.DataFrame( + { + "person_id": ["p1", "p2"], + "household_id": ["h1", "missing"], + "age": [35, 41], + } + ), + }, + relationships=( + EntityRelationship( + parent_entity=EntityType.HOUSEHOLD, + child_entity=EntityType.PERSON, + parent_key="household_id", + child_key="household_id", + cardinality=RelationshipCardinality.ONE_TO_MANY, + ), + ), + ) + + with pytest.raises(ValueError, match="missing parent keys"): + frame.validate() + + +def test_static_source_provider_filters_on_period_columns(): + descriptor = SourceDescriptor( + name="survey", + shareability=Shareability.PUBLIC, + time_structure=TimeStructure.REPEATED_CROSS_SECTION, + observations=( + EntityObservation( + entity=EntityType.HOUSEHOLD, + key_column="household_id", + variable_names=("state_fips",), + period_column="year", + ), + EntityObservation( + entity=EntityType.PERSON, + key_column="person_id", + variable_names=("age",), + period_column="year", + ), + ), + ) + frame = ObservationFrame( + source=descriptor, + tables={ + EntityType.HOUSEHOLD: pd.DataFrame( + { + "household_id": ["h1", "h2"], + "state_fips": ["06", "36"], + "year": [2024, 2023], + } + ), + EntityType.PERSON: pd.DataFrame( + { + "person_id": ["p1", "p2"], + "household_id": ["h1", "h2"], + "age": [35, 41], + "year": [2024, 2023], + } + ), + }, + relationships=( + EntityRelationship( + parent_entity=EntityType.HOUSEHOLD, + 
child_entity=EntityType.PERSON, + parent_key="household_id", + child_key="household_id", + ), + ), + ) + provider = StaticSourceProvider(frame) + + assert isinstance(provider, SourceProvider) + filtered = provider.load_frame(SourceQuery(period=2024)) + + assert filtered.tables[EntityType.HOUSEHOLD]["household_id"].tolist() == ["h1"] + assert filtered.tables[EntityType.PERSON]["person_id"].tolist() == ["p1"] diff --git a/tests/fusion/test_planning.py b/tests/fusion/test_planning.py new file mode 100644 index 0000000..706bf6f --- /dev/null +++ b/tests/fusion/test_planning.py @@ -0,0 +1,105 @@ +"""Tests for source-symmetric fusion planning.""" + +from microplex.core import ( + EntityObservation, + EntityType, + Shareability, + SourceArchetype, + SourceDescriptor, + TimeStructure, +) +from microplex.fusion import FusionPlan + + +def test_fusion_plan_is_source_symmetric_and_release_aware(): + cps = SourceDescriptor( + name="cps", + shareability=Shareability.PUBLIC, + time_structure=TimeStructure.REPEATED_CROSS_SECTION, + archetype=SourceArchetype.HOUSEHOLD_INCOME, + observations=( + EntityObservation( + entity=EntityType.PERSON, + key_column="person_id", + variable_names=("age", "employment_income"), + ), + EntityObservation( + entity=EntityType.HOUSEHOLD, + key_column="household_id", + variable_names=("state_fips",), + ), + ), + ) + puf = SourceDescriptor( + name="puf", + shareability=Shareability.NON_SHAREABLE, + time_structure=TimeStructure.CROSS_SECTION, + archetype=SourceArchetype.TAX_MICRODATA, + observations=( + EntityObservation( + entity=EntityType.PERSON, + key_column="person_id", + variable_names=("age",), + ), + EntityObservation( + entity=EntityType.TAX_UNIT, + key_column="tax_unit_id", + variable_names=("adjusted_gross_income", "capital_gains"), + ), + ), + ) + sipp = SourceDescriptor( + name="sipp", + shareability=Shareability.PUBLIC, + time_structure=TimeStructure.PANEL, + archetype=SourceArchetype.LONGITUDINAL_SOCIOECONOMIC, + observations=( + EntityObservation( + entity=EntityType.PERSON, + key_column="person_id", + variable_names=("age", "employment_income", "is_disabled"), + ), + EntityObservation( + entity=EntityType.RECORD, + key_column="record_id", + variable_names=("job_hours", "job_wages"), + ), + ), + ) + + plan = FusionPlan.from_sources([cps, puf, sipp]) + + assert set(plan.output_entities) == { + EntityType.PERSON, + EntityType.HOUSEHOLD, + EntityType.TAX_UNIT, + EntityType.RECORD, + } + + employment_income = plan.variable_plan( + entity=EntityType.PERSON, + variable_name="employment_income", + ) + assert set(employment_income.sources) == {"cps", "sipp"} + assert employment_income.publicly_observed + assert not employment_income.requires_synthetic_release + assert plan.sources_for_archetype(SourceArchetype.HOUSEHOLD_INCOME) == ("cps",) + assert plan.sources_for_archetype(SourceArchetype.TAX_MICRODATA) == ("puf",) + + adjusted_gross_income = plan.variable_plan( + entity=EntityType.TAX_UNIT, + variable_name="adjusted_gross_income", + ) + assert set(adjusted_gross_income.sources) == {"puf"} + assert not adjusted_gross_income.publicly_observed + assert adjusted_gross_income.requires_synthetic_release + assert plan.variables_requiring_synthetic_release(EntityType.TAX_UNIT) == ( + frozenset({"adjusted_gross_income", "capital_gains"}) + ) + + job_hours = plan.variable_plan( + entity=EntityType.RECORD, + variable_name="job_hours", + ) + assert set(job_hours.sources) == {"sipp"} + assert job_hours.publicly_observed diff --git a/tests/targets/test_provider.py 
b/tests/targets/test_provider.py new file mode 100644 index 0000000..96caacd --- /dev/null +++ b/tests/targets/test_provider.py @@ -0,0 +1,61 @@ +"""Tests for canonical target providers.""" + +from microplex.core import EntityType +from microplex.targets import ( + StaticTargetProvider, + TargetAggregation, + TargetProvider, + TargetQuery, + TargetSet, + TargetSpec, + apply_target_query, +) + + +def test_apply_target_query_filters_target_set(): + target_set = TargetSet( + [ + TargetSpec( + name="ca_people", + entity=EntityType.PERSON, + value=2.0, + period=2024, + aggregation=TargetAggregation.COUNT, + metadata={"kind": "admin"}, + ), + TargetSpec( + name="ny_people", + entity=EntityType.PERSON, + value=3.0, + period=2023, + aggregation=TargetAggregation.COUNT, + metadata={"kind": "survey"}, + ), + ] + ) + + selected = apply_target_query( + target_set, + TargetQuery( + period=2024, + entity=EntityType.PERSON, + names=("ca_people",), + metadata_filters={"kind": "admin"}, + ), + ) + + assert selected.targets == [target_set.targets[0]] + + +def test_static_target_provider_implements_protocol(): + target = TargetSpec( + name="ca_people", + entity=EntityType.PERSON, + value=2.0, + period=2024, + aggregation=TargetAggregation.COUNT, + ) + provider = StaticTargetProvider(TargetSet([target])) + + assert isinstance(provider, TargetProvider) + assert provider.load_target_set(TargetQuery(entity=EntityType.PERSON)).targets == [target] diff --git a/tests/targets/test_spec.py b/tests/targets/test_spec.py new file mode 100644 index 0000000..a34aeb7 --- /dev/null +++ b/tests/targets/test_spec.py @@ -0,0 +1,89 @@ +"""Tests for the canonical target specification primitives.""" + +from microplex.core import EntityType +from microplex.targets import ( + FilterOperator, + TargetAggregation, + TargetFilter, + TargetSet, + TargetSpec, +) + + +class TestTargetFilter: + def test_operator_normalizes_from_string(self): + target_filter = TargetFilter(feature="snap", operator=">", value=0) + + assert target_filter.operator is FilterOperator.GT + + +class TestTargetSpec: + def test_entity_and_aggregation_normalize_from_strings(self): + target = TargetSpec( + name="snap_recipients", + entity="spm_unit", + value=100.0, + period=2024, + aggregation="count", + filters=(TargetFilter(feature="snap", operator=">", value=0),), + ) + + assert target.entity is EntityType.SPM_UNIT + assert target.aggregation is TargetAggregation.COUNT + + def test_count_target_rejects_measure(self): + try: + TargetSpec( + name="bad_target", + entity=EntityType.HOUSEHOLD, + value=1.0, + period=2024, + measure="income", + aggregation=TargetAggregation.COUNT, + ) + except ValueError as exc: + assert "Count targets" in str(exc) + else: + raise AssertionError("Expected ValueError for count target with measure") + + def test_required_features_deduplicates_and_preserves_order(self): + target = TargetSpec( + name="california_snap", + entity=EntityType.HOUSEHOLD, + value=1_000.0, + period=2024, + measure="snap", + aggregation=TargetAggregation.SUM, + filters=( + TargetFilter(feature="state_fips", operator="==", value="06"), + TargetFilter(feature="snap", operator=">", value=0), + ), + ) + + assert target.required_features == ("snap", "state_fips") + + +class TestTargetSet: + def test_collection_helpers(self): + targets = TargetSet( + targets=[ + TargetSpec( + name="households", + entity=EntityType.HOUSEHOLD, + value=10.0, + period=2024, + aggregation=TargetAggregation.COUNT, + ), + TargetSpec( + name="people", + entity=EntityType.PERSON, + value=20.0, + 
period=2025, + aggregation=TargetAggregation.COUNT, + ), + ] + ) + + assert len(targets.for_entity(EntityType.HOUSEHOLD)) == 1 + assert len(targets.for_period(2025)) == 1 + assert targets.required_features() == () diff --git a/tests/test_geography_shims.py b/tests/test_geography_shims.py new file mode 100644 index 0000000..b0d38a1 --- /dev/null +++ b/tests/test_geography_shims.py @@ -0,0 +1,50 @@ +"""Compatibility coverage for moved US geography helpers.""" + +from __future__ import annotations + +import importlib.util + +import pandas as pd +import pytest + +from microplex.geography import BlockGeography, derive_geographies + +HAS_MICROPLEX_US = importlib.util.find_spec("microplex_us") is not None + + +def _sample_block_frame() -> pd.DataFrame: + return pd.DataFrame( + { + "geoid": ["060010001001001", "360610001001001"], + "state_fips": ["06", "36"], + "county": ["001", "061"], + "tract": ["000100", "000100"], + "tract_geoid": ["06001000100", "36061000100"], + "cd_id": ["CA-13", "NY-12"], + "prob": [0.6, 1.0], + "national_prob": [0.4, 0.6], + } + ) + + +def test_block_geography_from_data_shim() -> None: + if not HAS_MICROPLEX_US: + with pytest.raises(ModuleNotFoundError, match="microplex-us"): + BlockGeography.from_data(_sample_block_frame()) + return + + geography = BlockGeography.from_data(_sample_block_frame()) + assigned = geography.assign(pd.DataFrame({"state_fips": ["06", "36"]}), random_state=1) + assert "block_geoid" in assigned.columns + assert assigned["block_geoid"].str.startswith(("06", "36")).all() + + +def test_derive_geographies_shim() -> None: + if not HAS_MICROPLEX_US: + with pytest.raises(ModuleNotFoundError, match="microplex-us"): + derive_geographies(["060010001001001", "360610001001001"]) + return + + result = derive_geographies(["060010001001001", "360610001001001"]) + assert list(result["state_fips"]) == ["06", "36"] + assert list(result["county_fips"]) == ["06001", "36061"] diff --git a/tests/test_package_surface.py b/tests/test_package_surface.py new file mode 100644 index 0000000..ffee24d --- /dev/null +++ b/tests/test_package_surface.py @@ -0,0 +1,18 @@ +"""Package-surface regression tests for the core engine.""" + +from __future__ import annotations + +import microplex + + +def test_top_level_package_does_not_export_us_specific_helpers() -> None: + assert not hasattr(microplex, "load_cps_asec") + assert not hasattr(microplex, "load_cps_for_synthesis") + assert not hasattr(microplex, "create_sample_data") + assert not hasattr(microplex, "get_data_info") + assert not hasattr(microplex, "CPSSummaryStats") + assert not hasattr(microplex, "CPSSyntheticGenerator") + assert not hasattr(microplex, "validate_synthetic") + assert not hasattr(microplex, "BlockGeography") + assert not hasattr(microplex, "load_block_probabilities") + assert not hasattr(microplex, "derive_geographies") diff --git a/tests/test_synthesizer.py b/tests/test_synthesizer.py index 803ab5a..37c04dc 100644 --- a/tests/test_synthesizer.py +++ b/tests/test_synthesizer.py @@ -8,10 +8,9 @@ 4. 
Save and load models """ -import pytest import numpy as np import pandas as pd -import torch +import pytest class TestSynthesizerInit: @@ -99,6 +98,28 @@ def test_fit_learns_transforms(self, sample_data): assert synth.transformer_ is not None assert "income" in synth.transformer_.transformers_ + def test_fit_handles_boolean_zero_inflated_targets(self): + """Fit should handle boolean-valued zero-inflated targets without percentile errors.""" + from microplex import Synthesizer + + data = pd.DataFrame( + { + "age": [25, 40, 55, 32, 61, 47], + "owns_asset": [False, True, False, True, True, False], + "weight": np.ones(6), + } + ) + + synth = Synthesizer( + target_vars=["owns_asset"], + condition_vars=["age"], + discrete_vars=["owns_asset"], + ) + + synth.fit(data, epochs=2, verbose=False) + + assert synth.is_fitted_ + def test_fit_trains_flow(self, sample_data): """Fit should train the normalizing flow.""" from microplex import Synthesizer @@ -327,7 +348,7 @@ def test_variance_ratio_in_acceptable_range(self, high_variance_data): variance_ratio = mean_synthetic_var / real_var - print(f"\nVariance Ratio Test:") + print("\nVariance Ratio Test:") print(f" Real variance: {real_var:,.0f}") print(f" Synthetic variance (mean of 5): {mean_synthetic_var:,.0f}") print(f" Variance ratio: {variance_ratio:.3f}") @@ -389,7 +410,7 @@ def test_variance_ratio_multiple_variables(self, high_variance_data): else: variance_ratios[var] = 1.0 - print(f"\nMulti-Variable Variance Ratios (vs training data):") + print("\nMulti-Variable Variance Ratios (vs training data):") for var, ratio in variance_ratios.items(): print(f" {var}: {ratio:.3f}")
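
Usage sketches follow; they are illustrative notes against the surfaces introduced above, not part of the patch itself. First, the geography compatibility shim: when `microplex-us` is absent, the placeholder `BlockGeography` still imports cleanly but raises on use, so downstream code can degrade gracefully. The two-column `blocks` frame below is an invented fixture; real block frames carry the full column set shipped by `microplex-us`.

import pandas as pd

from microplex.geography import BlockGeography

# Invented minimal fixture, for illustration only.
blocks = pd.DataFrame({"geoid": ["060010001001001"], "state_fips": ["06"]})

try:
    geography = BlockGeography.from_data(blocks)
except ModuleNotFoundError:
    # Raised by the compatibility placeholder when microplex-us is not installed.
    geography = None  # fall back to coarser geography, or tell users to install microplex-us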
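
The refactored `Synthesizer` keeps its public fit/generate surface; the context tensor for `generate()` is now built by the same helper used in `fit()`, so conditional generation only needs the condition columns. A minimal round trip on toy data (values invented, epochs kept tiny):

import numpy as np
import pandas as pd

from microplex import Synthesizer

rng = np.random.default_rng(0)
train = pd.DataFrame(
    {
        "age": rng.integers(18, 80, size=512).astype(float),
        "income": rng.gamma(2.0, 20_000.0, size=512),
        "weight": np.ones(512),  # picked up via the default weight_col="weight"
    }
)

synth = Synthesizer(target_vars=["income"], condition_vars=["age"])
synth.fit(train, epochs=5, verbose=False)

conditions = pd.DataFrame({"age": [30.0, 45.0, 60.0]})
synthetic = synth.generate(conditions, seed=42)
assert len(synthetic) == 3  # one synthetic record per condition row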
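
The canonical target primitives normalize strings to enums, so specs can be written tersely and then sliced through the provider protocol. A sketch with an invented aggregate value:

from microplex.targets import (
    StaticTargetProvider,
    TargetFilter,
    TargetQuery,
    TargetSet,
    TargetSpec,
)

ca_snap = TargetSpec(
    name="california_snap",
    entity="household",        # normalized to EntityType.HOUSEHOLD
    value=1_000_000.0,         # hypothetical aggregate, for illustration only
    period=2024,
    measure="snap",
    aggregation="sum",         # normalized to TargetAggregation.SUM
    filters=(TargetFilter(feature="state_fips", operator="==", value="06"),),
)

provider = StaticTargetProvider(TargetSet([ca_snap]))
selected = provider.load_target_set(TargetQuery(period=2024, entity="household"))

assert selected.targets == [ca_snap]
# Measure first, then filter features, deduplicated in order:
assert ca_snap.required_features == ("snap", "state_fips")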
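
Source descriptors declare what each source observes per entity, and observation frames validate relational integrity and expose per-cell observation masks (NaN means unobserved). A two-table toy mirroring the new tests:

import pandas as pd

from microplex.core import (
    EntityObservation,
    EntityRelationship,
    EntityType,
    ObservationFrame,
    Shareability,
    SourceDescriptor,
    TimeStructure,
)

descriptor = SourceDescriptor(
    name="toy_survey",
    shareability=Shareability.PUBLIC,
    time_structure=TimeStructure.CROSS_SECTION,
    observations=(
        EntityObservation(
            entity=EntityType.HOUSEHOLD,
            key_column="household_id",
            variable_names=("rent",),
        ),
        EntityObservation(
            entity=EntityType.PERSON,
            key_column="person_id",
            variable_names=("age",),
        ),
    ),
)

frame = ObservationFrame(
    source=descriptor,
    tables={
        EntityType.HOUSEHOLD: pd.DataFrame({"household_id": ["h1"], "rent": [1200.0]}),
        EntityType.PERSON: pd.DataFrame(
            {"person_id": ["p1", "p2"], "household_id": ["h1", "h1"], "age": [34, None]}
        ),
    },
    relationships=(
        EntityRelationship(
            parent_entity=EntityType.HOUSEHOLD,
            child_entity=EntityType.PERSON,
            parent_key="household_id",
            child_key="household_id",
        ),
    ),
)

frame.validate()                                  # raises on dangling child keys
mask = frame.observation_mask(EntityType.PERSON)  # boolean frame indexed by person_id
assert mask["age"].tolist() == [True, False]      # p2's age is unobserved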
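
Externalized source manifests are plain JSON coerced into typed dataclasses; note that the weight mapping is excluded from the observed variable names. A sketch with invented raw column names:

import json
import tempfile
from pathlib import Path

from microplex.core import EntityType, load_source_manifest

spec = {
    "name": "toy_tax_source",
    "archetype": "tax_microdata",
    "population": "toy tax units",
    "description": "Invented manifest, for illustration only",
    "observations": [
        {
            "entity": "tax_unit",
            "key_column": "tax_unit_id",
            "weight_column": "weight",
            "columns": [
                {"raw_column": "WT", "canonical_name": "weight"},
                {"raw_column": "WAGES", "canonical_name": "employment_income"},
                {
                    "raw_column": "REGION",
                    "canonical_name": "region_code",
                    "value_type": "categorical",
                },
            ],
        }
    ],
}

path = Path(tempfile.mkdtemp()) / "toy_tax_source.json"
path.write_text(json.dumps(spec))

manifest = load_source_manifest(path)
observation = manifest.observation_for(EntityType.TAX_UNIT)
assert observation.observed_variable_names() == ("employment_income", "region_code")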
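
Finally, fusion planning: a plan built from descriptors records, per entity and variable, which sources observe it and whether releasing it requires synthesis (i.e., it is observed only in non-shareable sources). Two pared-down, invented descriptors:

from microplex.core import (
    EntityObservation,
    EntityType,
    Shareability,
    SourceDescriptor,
    TimeStructure,
)
from microplex.fusion import FusionPlan

public_survey = SourceDescriptor(
    name="survey",
    shareability=Shareability.PUBLIC,
    time_structure=TimeStructure.REPEATED_CROSS_SECTION,
    observations=(
        EntityObservation(
            entity=EntityType.PERSON,
            key_column="person_id",
            variable_names=("age",),
        ),
    ),
)
restricted_admin = SourceDescriptor(
    name="admin",
    shareability=Shareability.NON_SHAREABLE,
    time_structure=TimeStructure.CROSS_SECTION,
    observations=(
        EntityObservation(
            entity=EntityType.PERSON,
            key_column="person_id",
            variable_names=("age", "taxable_income"),
        ),
    ),
)

plan = FusionPlan.from_sources([public_survey, restricted_admin])

age_plan = plan.variable_plan(entity=EntityType.PERSON, variable_name="age")
assert age_plan.publicly_observed               # visible in a PUBLIC source
assert not age_plan.requires_synthetic_release

income_plan = plan.variable_plan(
    entity=EntityType.PERSON, variable_name="taxable_income"
)
assert income_plan.requires_synthetic_release   # observed only in NON_SHAREABLE data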