Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,16 @@ jobs:
run: |
pytest tests/ -v --tb=short

- name: Run linter
- name: Run focused linter
run: |
ruff check src/microplex/core src/microplex/targets src/microplex/supabase_targets.py tests/targets tests/test_package_surface.py tests/test_supabase_targets.py

- name: Run advisory full linter
continue-on-error: true
run: |
ruff check src/

- name: Type check
- name: Run advisory type check
continue-on-error: true
run: |
mypy src/microplex/
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ classifiers = [
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
Expand Down Expand Up @@ -103,14 +102,14 @@ addopts = "-v --tb=short"

[tool.ruff]
line-length = 88
target-version = "py310"
target-version = "py311"

[tool.ruff.lint]
select = ["E", "F", "I", "N", "W", "UP"]
ignore = ["E501"] # Line length handled by formatter

[tool.mypy]
python_version = "3.10"
python_version = "3.11"
warn_return_any = true
warn_unused_configs = true
ignore_missing_imports = true
10 changes: 6 additions & 4 deletions scripts/load_pe_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@

# Supabase connection
SUPABASE_URL = "https://nsupqhfchdtqclomlrgs.supabase.co"
SUPABASE_KEY = os.environ.get(
"COSILICO_SUPABASE_SERVICE_KEY",
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Im5zdXBxaGZjaGR0cWNsb21scmdzIiwicm9sZSI6InNlcnZpY2Vfcm9sZSIsImlhdCI6MTc2NjkzMTEwOCwiZXhwIjoyMDgyNTA3MTA4fQ.IZX2C6dM6CCuxzBeg3zoZSA31p_jy9XLjdxjaE126BU"
)
SUPABASE_KEY = os.environ.get("COSILICO_SUPABASE_SERVICE_KEY")

PE_BASE = "https://raw.githubusercontent.com/PolicyEngine/policyengine-us-data/main/policyengine_us_data/storage/calibration_targets"

Expand All @@ -41,6 +38,11 @@ class BatchSupabaseClient:
"""Supabase client optimized for batch operations."""

def __init__(self, url: str, key: str, schema: str = "microplex"):
if not key:
raise ValueError(
"COSILICO_SUPABASE_SERVICE_KEY must be set before loading "
"PolicyEngine calibration targets."
)
self.base_url = f"{url}/rest/v1"
self.headers = {
"apikey": key,
Expand Down
10 changes: 6 additions & 4 deletions scripts/run_supabase_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,12 @@ def __init__(self):
"SUPABASE_URL",
"https://nsupqhfchdtqclomlrgs.supabase.co"
)
self.key = os.environ.get(
"COSILICO_SUPABASE_SERVICE_KEY",
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Im5zdXBxaGZjaGR0cWNsb21scmdzIiwicm9sZSI6InNlcnZpY2Vfcm9sZSIsImlhdCI6MTc2NjkzMTEwOCwiZXhwIjoyMDgyNTA3MTA4fQ.IZX2C6dM6CCuxzBeg3zoZSA31p_jy9XLjdxjaE126BU"
)
self.key = os.environ.get("COSILICO_SUPABASE_SERVICE_KEY")
if not self.key:
raise ValueError(
"COSILICO_SUPABASE_SERVICE_KEY must be set before running "
"Supabase calibration."
)
self.base_url = f"{self.url}/rest/v1"
self.headers = {
"apikey": self.key,
Expand Down
3 changes: 2 additions & 1 deletion src/microplex/core/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@

import json
import shutil
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Mapping
from typing import Any

import pandas as pd

Expand Down
4 changes: 2 additions & 2 deletions src/microplex/core/periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

from __future__ import annotations

from collections.abc import Iterator
from datetime import date
from enum import Enum
from functools import total_ordering
from typing import Iterator

from pydantic import BaseModel, Field, model_validator

Expand Down Expand Up @@ -59,7 +59,7 @@ class Period(BaseModel):
model_config = {"frozen": True}

@model_validator(mode="after")
def validate_period_consistency(self) -> "Period":
def validate_period_consistency(self) -> Period:
"""Ensure period components match period type."""
if self.period_type == PeriodType.YEAR:
if self.month is not None or self.day is not None or self.quarter is not None:
Expand Down
3 changes: 0 additions & 3 deletions src/microplex/core/resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from dataclasses import dataclass
from enum import Enum
from typing import Literal

import numpy as np

Expand Down Expand Up @@ -225,7 +224,6 @@ def compress_dataset(
# Target moments from features
target_vector = np.array(list(targets.values()))

best_weights = weights.copy()
best_loss = float("inf")

for iteration in range(max_iterations):
Expand Down Expand Up @@ -270,7 +268,6 @@ def compress_dataset(

if total_loss < best_loss:
best_loss = total_loss
best_weights = effective_weights.copy()

if iteration % 100 == 0:
n_active = np.sum(gates > 0.01)
Expand Down
60 changes: 45 additions & 15 deletions src/microplex/supabase_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,25 @@

Provides SupabaseTargetLoader for loading PE calibration targets from the
microplex Supabase schema and mapping them to CPS columns for calibration.

Deprecated:
This is US-specific compatibility code and will move to microplex-us.
"""

from __future__ import annotations

import os
import warnings
from typing import Any

import requests
from typing import List, Dict, Any, Optional

warnings.warn(
"microplex.supabase_targets is US-specific compatibility code and will move "
"to microplex-us.",
DeprecationWarning,
stacklevel=2,
)


class SupabaseTargetLoader:
Expand Down Expand Up @@ -60,7 +74,12 @@ class SupabaseTargetLoader:
"56": "wy"
}

def __init__(self, url: str = None, key: str = None, schema: str = "microplex"):
def __init__(
self,
url: str | None = None,
key: str | None = None,
schema: str = "microplex",
):
"""Initialize the loader.

Args:
Expand All @@ -72,10 +91,12 @@ def __init__(self, url: str = None, key: str = None, schema: str = "microplex"):
"SUPABASE_URL",
"https://nsupqhfchdtqclomlrgs.supabase.co"
)
self.key = key or os.environ.get(
"COSILICO_SUPABASE_SERVICE_KEY",
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Im5zdXBxaGZjaGR0cWNsb21scmdzIiwicm9sZSI6InNlcnZpY2Vfcm9sZSIsImlhdCI6MTc2NjkzMTEwOCwiZXhwIjoyMDgyNTA3MTA4fQ.IZX2C6dM6CCuxzBeg3zoZSA31p_jy9XLjdxjaE126BU"
)
self.key = key or os.environ.get("COSILICO_SUPABASE_SERVICE_KEY")
if not self.key:
raise ValueError(
"Supabase service key must be provided via the key argument or "
"COSILICO_SUPABASE_SERVICE_KEY."
)
self.base_url = f"{self.url}/rest/v1"
self.headers = {
"apikey": self.key,
Expand All @@ -86,7 +107,12 @@ def __init__(self, url: str = None, key: str = None, schema: str = "microplex"):
}
self._cache = {}

def _get(self, endpoint: str, params: Dict = None, paginate: bool = True) -> List[Dict]:
def _get(
self,
endpoint: str,
params: dict[str, Any] | None = None,
paginate: bool = True,
) -> list[dict[str, Any]]:
"""Make a GET request to Supabase with optional pagination.

Args:
Expand Down Expand Up @@ -128,7 +154,7 @@ def _get(self, endpoint: str, params: Dict = None, paginate: bool = True) -> Lis

return all_results

def load_all(self, period: int = None) -> List[Dict]:
def load_all(self, period: int | None = None) -> list[dict[str, Any]]:
"""Load all targets with source and stratum info.

Args:
Expand All @@ -146,7 +172,11 @@ def load_all(self, period: int = None) -> List[Dict]:

return self._get("targets", params)

def load_by_institution(self, institution: str, period: int = None) -> List[Dict]:
def load_by_institution(
self,
institution: str,
period: int | None = None,
) -> list[dict[str, Any]]:
"""Load targets from a specific institution.

Args:
Expand All @@ -173,7 +203,7 @@ def load_by_institution(self, institution: str, period: int = None) -> List[Dict

return self._get("targets", params)

def load_by_period(self, period: int) -> List[Dict]:
def load_by_period(self, period: int) -> list[dict[str, Any]]:
"""Load targets for a specific year.

Args:
Expand All @@ -184,15 +214,15 @@ def load_by_period(self, period: int) -> List[Dict]:
"""
return self.load_all(period=period)

def get_cps_column_map(self) -> Dict[str, str]:
def get_cps_column_map(self) -> dict[str, str]:
"""Get the mapping from Supabase variable names to CPS columns.

Returns:
Dict mapping variable -> CPS column name.
"""
return self.CPS_COLUMN_MAP.copy()

def _parse_jurisdiction(self, jurisdiction: str) -> Optional[str]:
def _parse_jurisdiction(self, jurisdiction: str) -> str | None:
"""Parse jurisdiction to get state code if applicable.

Args:
Expand Down Expand Up @@ -221,8 +251,8 @@ def build_calibration_constraints(
self,
period: int = 2024,
include_states: bool = False,
target_types: List[str] = None
) -> Dict[str, float]:
target_types: list[str] | None = None,
) -> dict[str, float]:
"""Build calibration constraint dict from Supabase targets.

Args:
Expand Down Expand Up @@ -268,7 +298,7 @@ def build_calibration_constraints(

return constraints

def get_summary(self) -> Dict[str, Any]:
def get_summary(self) -> dict[str, Any]:
"""Get summary of available targets in Supabase.

Returns:
Expand Down
43 changes: 21 additions & 22 deletions src/microplex/targets/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@

from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Union
import pandas as pd

import numpy as np
from pathlib import Path
import pandas as pd


class TargetCategory(Enum):
Expand Down Expand Up @@ -50,14 +49,14 @@ class Target:
value: float
year: int
source: str
source_url: Optional[str] = None
source_url: str | None = None

# Geographic scope
geography: str = "US" # US, state FIPS, county FIPS
state_fips: Optional[str] = None
state_fips: str | None = None

# Filtering dimensions
filing_status: Optional[str] = None # All, Single, MFJ, MFS, HOH
filing_status: str | None = None # All, Single, MFJ, MFS, HOH
agi_lower: float = -np.inf
agi_upper: float = np.inf

Expand All @@ -66,15 +65,15 @@ class Target:
is_taxable_only: bool = False

# RAC mapping
rac_variable: Optional[str] = None # e.g., "adjusted_gross_income"
rac_statute: Optional[str] = None # e.g., "26/62"
rac_variable: str | None = None # e.g., "adjusted_gross_income"
rac_statute: str | None = None # e.g., "26/62"

# Microdata column mapping
microdata_column: Optional[str] = None # Column in CPS/PUF
microdata_column: str | None = None # Column in CPS/PUF

# Metadata
notes: Optional[str] = None
last_updated: Optional[str] = None
notes: str | None = None
last_updated: str | None = None


@dataclass
Expand All @@ -85,9 +84,9 @@ class TargetsDatabase:
Maintains parity with PolicyEngine targets while adding
RAC variable mappings for Cosilico integration.
"""
targets: List[Target] = field(default_factory=list)
_by_category: Dict[TargetCategory, List[Target]] = field(default_factory=dict)
_by_geography: Dict[str, List[Target]] = field(default_factory=dict)
targets: list[Target] = field(default_factory=list)
_by_category: dict[TargetCategory, list[Target]] = field(default_factory=dict)
_by_geography: dict[str, list[Target]] = field(default_factory=dict)

def add(self, target: Target):
"""Add a target to the database."""
Expand All @@ -103,28 +102,28 @@ def add(self, target: Target):
self._by_geography[target.geography] = []
self._by_geography[target.geography].append(target)

def add_many(self, targets: List[Target]):
def add_many(self, targets: list[Target]):
"""Add multiple targets."""
for t in targets:
self.add(t)

def get_by_category(self, category: TargetCategory) -> List[Target]:
def get_by_category(self, category: TargetCategory) -> list[Target]:
"""Get all targets in a category."""
return self._by_category.get(category, [])

def get_by_geography(self, geography: str) -> List[Target]:
def get_by_geography(self, geography: str) -> list[Target]:
"""Get all targets for a geography."""
return self._by_geography.get(geography, [])

def get_national(self) -> List[Target]:
def get_national(self) -> list[Target]:
"""Get national-level targets."""
return self.get_by_geography("US")

def get_state(self, state_fips: str) -> List[Target]:
def get_state(self, state_fips: str) -> list[Target]:
"""Get state-level targets."""
return [t for t in self.targets if t.state_fips == state_fips]

def get_with_rac_mapping(self) -> List[Target]:
def get_with_rac_mapping(self) -> list[Target]:
"""Get targets that have RAC variable mappings."""
return [t for t in self.targets if t.rac_variable is not None]

Expand Down Expand Up @@ -154,7 +153,7 @@ def to_calibration_format(
self,
geography: str = "US",
year: int = 2021,
) -> tuple[Dict[str, Dict], Dict[str, float]]:
) -> tuple[dict[str, dict], dict[str, float]]:
"""
Convert to microplex calibration format.

Expand Down Expand Up @@ -209,7 +208,7 @@ def compare_to_policyengine(self, pe_targets: pd.DataFrame) -> pd.DataFrame:

return comparison

def coverage_summary(self) -> Dict[str, int]:
def coverage_summary(self) -> dict[str, int]:
"""Summarize target coverage by category."""
summary = {}
for cat in TargetCategory:
Expand Down
Loading
Loading