diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3a2118b..064c52f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -30,12 +30,16 @@ jobs: run: | pytest tests/ -v --tb=short - - name: Run linter + - name: Run focused linter + run: | + ruff check src/microplex/core src/microplex/targets src/microplex/supabase_targets.py tests/targets tests/test_package_surface.py tests/test_supabase_targets.py + + - name: Run advisory full linter continue-on-error: true run: | ruff check src/ - - name: Type check + - name: Run advisory type check continue-on-error: true run: | mypy src/microplex/ diff --git a/pyproject.toml b/pyproject.toml index 3e86a38..baced51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ classifiers = [ "Intended Audience :: Science/Research", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", @@ -103,14 +102,14 @@ addopts = "-v --tb=short" [tool.ruff] line-length = 88 -target-version = "py310" +target-version = "py311" [tool.ruff.lint] select = ["E", "F", "I", "N", "W", "UP"] ignore = ["E501"] # Line length handled by formatter [tool.mypy] -python_version = "3.10" +python_version = "3.11" warn_return_any = true warn_unused_configs = true ignore_missing_imports = true diff --git a/scripts/load_pe_targets.py b/scripts/load_pe_targets.py index 0279c26..24c63f7 100644 --- a/scripts/load_pe_targets.py +++ b/scripts/load_pe_targets.py @@ -14,10 +14,7 @@ # Supabase connection SUPABASE_URL = "https://nsupqhfchdtqclomlrgs.supabase.co" -SUPABASE_KEY = os.environ.get( - "COSILICO_SUPABASE_SERVICE_KEY", - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Im5zdXBxaGZjaGR0cWNsb21scmdzIiwicm9sZSI6InNlcnZpY2Vfcm9sZSIsImlhdCI6MTc2NjkzMTEwOCwiZXhwIjoyMDgyNTA3MTA4fQ.IZX2C6dM6CCuxzBeg3zoZSA31p_jy9XLjdxjaE126BU" -) +SUPABASE_KEY = os.environ.get("COSILICO_SUPABASE_SERVICE_KEY") PE_BASE = "https://raw.githubusercontent.com/PolicyEngine/policyengine-us-data/main/policyengine_us_data/storage/calibration_targets" @@ -41,6 +38,11 @@ class BatchSupabaseClient: """Supabase client optimized for batch operations.""" def __init__(self, url: str, key: str, schema: str = "microplex"): + if not key: + raise ValueError( + "COSILICO_SUPABASE_SERVICE_KEY must be set before loading " + "PolicyEngine calibration targets." + ) self.base_url = f"{url}/rest/v1" self.headers = { "apikey": key, diff --git a/scripts/run_supabase_calibration.py b/scripts/run_supabase_calibration.py index 9987ab8..8caa41b 100644 --- a/scripts/run_supabase_calibration.py +++ b/scripts/run_supabase_calibration.py @@ -109,10 +109,12 @@ def __init__(self): "SUPABASE_URL", "https://nsupqhfchdtqclomlrgs.supabase.co" ) - self.key = os.environ.get( - "COSILICO_SUPABASE_SERVICE_KEY", - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Im5zdXBxaGZjaGR0cWNsb21scmdzIiwicm9sZSI6InNlcnZpY2Vfcm9sZSIsImlhdCI6MTc2NjkzMTEwOCwiZXhwIjoyMDgyNTA3MTA4fQ.IZX2C6dM6CCuxzBeg3zoZSA31p_jy9XLjdxjaE126BU" - ) + self.key = os.environ.get("COSILICO_SUPABASE_SERVICE_KEY") + if not self.key: + raise ValueError( + "COSILICO_SUPABASE_SERVICE_KEY must be set before running " + "Supabase calibration." + ) self.base_url = f"{self.url}/rest/v1" self.headers = { "apikey": self.key, diff --git a/src/microplex/core/checkpoints.py b/src/microplex/core/checkpoints.py index 1e378dd..6343922 100644 --- a/src/microplex/core/checkpoints.py +++ b/src/microplex/core/checkpoints.py @@ -31,8 +31,9 @@ import json import shutil +from collections.abc import Mapping from pathlib import Path -from typing import Any, Mapping +from typing import Any import pandas as pd diff --git a/src/microplex/core/periods.py b/src/microplex/core/periods.py index 5f6efd9..18576bc 100644 --- a/src/microplex/core/periods.py +++ b/src/microplex/core/periods.py @@ -6,10 +6,10 @@ from __future__ import annotations +from collections.abc import Iterator from datetime import date from enum import Enum from functools import total_ordering -from typing import Iterator from pydantic import BaseModel, Field, model_validator @@ -59,7 +59,7 @@ class Period(BaseModel): model_config = {"frozen": True} @model_validator(mode="after") - def validate_period_consistency(self) -> "Period": + def validate_period_consistency(self) -> Period: """Ensure period components match period type.""" if self.period_type == PeriodType.YEAR: if self.month is not None or self.day is not None or self.quarter is not None: diff --git a/src/microplex/core/resolution.py b/src/microplex/core/resolution.py index 7486e96..4cbaf79 100644 --- a/src/microplex/core/resolution.py +++ b/src/microplex/core/resolution.py @@ -7,7 +7,6 @@ from dataclasses import dataclass from enum import Enum -from typing import Literal import numpy as np @@ -225,7 +224,6 @@ def compress_dataset( # Target moments from features target_vector = np.array(list(targets.values())) - best_weights = weights.copy() best_loss = float("inf") for iteration in range(max_iterations): @@ -270,7 +268,6 @@ def compress_dataset( if total_loss < best_loss: best_loss = total_loss - best_weights = effective_weights.copy() if iteration % 100 == 0: n_active = np.sum(gates > 0.01) diff --git a/src/microplex/supabase_targets.py b/src/microplex/supabase_targets.py index eb9e2ba..782677d 100644 --- a/src/microplex/supabase_targets.py +++ b/src/microplex/supabase_targets.py @@ -3,11 +3,25 @@ Provides SupabaseTargetLoader for loading PE calibration targets from the microplex Supabase schema and mapping them to CPS columns for calibration. + +Deprecated: + This is US-specific compatibility code and will move to microplex-us. """ +from __future__ import annotations + import os +import warnings +from typing import Any + import requests -from typing import List, Dict, Any, Optional + +warnings.warn( + "microplex.supabase_targets is US-specific compatibility code and will move " + "to microplex-us.", + DeprecationWarning, + stacklevel=2, +) class SupabaseTargetLoader: @@ -60,7 +74,12 @@ class SupabaseTargetLoader: "56": "wy" } - def __init__(self, url: str = None, key: str = None, schema: str = "microplex"): + def __init__( + self, + url: str | None = None, + key: str | None = None, + schema: str = "microplex", + ): """Initialize the loader. Args: @@ -72,10 +91,12 @@ def __init__(self, url: str = None, key: str = None, schema: str = "microplex"): "SUPABASE_URL", "https://nsupqhfchdtqclomlrgs.supabase.co" ) - self.key = key or os.environ.get( - "COSILICO_SUPABASE_SERVICE_KEY", - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Im5zdXBxaGZjaGR0cWNsb21scmdzIiwicm9sZSI6InNlcnZpY2Vfcm9sZSIsImlhdCI6MTc2NjkzMTEwOCwiZXhwIjoyMDgyNTA3MTA4fQ.IZX2C6dM6CCuxzBeg3zoZSA31p_jy9XLjdxjaE126BU" - ) + self.key = key or os.environ.get("COSILICO_SUPABASE_SERVICE_KEY") + if not self.key: + raise ValueError( + "Supabase service key must be provided via the key argument or " + "COSILICO_SUPABASE_SERVICE_KEY." + ) self.base_url = f"{self.url}/rest/v1" self.headers = { "apikey": self.key, @@ -86,7 +107,12 @@ def __init__(self, url: str = None, key: str = None, schema: str = "microplex"): } self._cache = {} - def _get(self, endpoint: str, params: Dict = None, paginate: bool = True) -> List[Dict]: + def _get( + self, + endpoint: str, + params: dict[str, Any] | None = None, + paginate: bool = True, + ) -> list[dict[str, Any]]: """Make a GET request to Supabase with optional pagination. Args: @@ -128,7 +154,7 @@ def _get(self, endpoint: str, params: Dict = None, paginate: bool = True) -> Lis return all_results - def load_all(self, period: int = None) -> List[Dict]: + def load_all(self, period: int | None = None) -> list[dict[str, Any]]: """Load all targets with source and stratum info. Args: @@ -146,7 +172,11 @@ def load_all(self, period: int = None) -> List[Dict]: return self._get("targets", params) - def load_by_institution(self, institution: str, period: int = None) -> List[Dict]: + def load_by_institution( + self, + institution: str, + period: int | None = None, + ) -> list[dict[str, Any]]: """Load targets from a specific institution. Args: @@ -173,7 +203,7 @@ def load_by_institution(self, institution: str, period: int = None) -> List[Dict return self._get("targets", params) - def load_by_period(self, period: int) -> List[Dict]: + def load_by_period(self, period: int) -> list[dict[str, Any]]: """Load targets for a specific year. Args: @@ -184,7 +214,7 @@ def load_by_period(self, period: int) -> List[Dict]: """ return self.load_all(period=period) - def get_cps_column_map(self) -> Dict[str, str]: + def get_cps_column_map(self) -> dict[str, str]: """Get the mapping from Supabase variable names to CPS columns. Returns: @@ -192,7 +222,7 @@ def get_cps_column_map(self) -> Dict[str, str]: """ return self.CPS_COLUMN_MAP.copy() - def _parse_jurisdiction(self, jurisdiction: str) -> Optional[str]: + def _parse_jurisdiction(self, jurisdiction: str) -> str | None: """Parse jurisdiction to get state code if applicable. Args: @@ -221,8 +251,8 @@ def build_calibration_constraints( self, period: int = 2024, include_states: bool = False, - target_types: List[str] = None - ) -> Dict[str, float]: + target_types: list[str] | None = None, + ) -> dict[str, float]: """Build calibration constraint dict from Supabase targets. Args: @@ -268,7 +298,7 @@ def build_calibration_constraints( return constraints - def get_summary(self) -> Dict[str, Any]: + def get_summary(self) -> dict[str, Any]: """Get summary of available targets in Supabase. Returns: diff --git a/src/microplex/targets/database.py b/src/microplex/targets/database.py index e898296..cbc7155 100644 --- a/src/microplex/targets/database.py +++ b/src/microplex/targets/database.py @@ -7,10 +7,9 @@ from dataclasses import dataclass, field from enum import Enum -from typing import Dict, List, Optional, Union -import pandas as pd + import numpy as np -from pathlib import Path +import pandas as pd class TargetCategory(Enum): @@ -50,14 +49,14 @@ class Target: value: float year: int source: str - source_url: Optional[str] = None + source_url: str | None = None # Geographic scope geography: str = "US" # US, state FIPS, county FIPS - state_fips: Optional[str] = None + state_fips: str | None = None # Filtering dimensions - filing_status: Optional[str] = None # All, Single, MFJ, MFS, HOH + filing_status: str | None = None # All, Single, MFJ, MFS, HOH agi_lower: float = -np.inf agi_upper: float = np.inf @@ -66,15 +65,15 @@ class Target: is_taxable_only: bool = False # RAC mapping - rac_variable: Optional[str] = None # e.g., "adjusted_gross_income" - rac_statute: Optional[str] = None # e.g., "26/62" + rac_variable: str | None = None # e.g., "adjusted_gross_income" + rac_statute: str | None = None # e.g., "26/62" # Microdata column mapping - microdata_column: Optional[str] = None # Column in CPS/PUF + microdata_column: str | None = None # Column in CPS/PUF # Metadata - notes: Optional[str] = None - last_updated: Optional[str] = None + notes: str | None = None + last_updated: str | None = None @dataclass @@ -85,9 +84,9 @@ class TargetsDatabase: Maintains parity with PolicyEngine targets while adding RAC variable mappings for Cosilico integration. """ - targets: List[Target] = field(default_factory=list) - _by_category: Dict[TargetCategory, List[Target]] = field(default_factory=dict) - _by_geography: Dict[str, List[Target]] = field(default_factory=dict) + targets: list[Target] = field(default_factory=list) + _by_category: dict[TargetCategory, list[Target]] = field(default_factory=dict) + _by_geography: dict[str, list[Target]] = field(default_factory=dict) def add(self, target: Target): """Add a target to the database.""" @@ -103,28 +102,28 @@ def add(self, target: Target): self._by_geography[target.geography] = [] self._by_geography[target.geography].append(target) - def add_many(self, targets: List[Target]): + def add_many(self, targets: list[Target]): """Add multiple targets.""" for t in targets: self.add(t) - def get_by_category(self, category: TargetCategory) -> List[Target]: + def get_by_category(self, category: TargetCategory) -> list[Target]: """Get all targets in a category.""" return self._by_category.get(category, []) - def get_by_geography(self, geography: str) -> List[Target]: + def get_by_geography(self, geography: str) -> list[Target]: """Get all targets for a geography.""" return self._by_geography.get(geography, []) - def get_national(self) -> List[Target]: + def get_national(self) -> list[Target]: """Get national-level targets.""" return self.get_by_geography("US") - def get_state(self, state_fips: str) -> List[Target]: + def get_state(self, state_fips: str) -> list[Target]: """Get state-level targets.""" return [t for t in self.targets if t.state_fips == state_fips] - def get_with_rac_mapping(self) -> List[Target]: + def get_with_rac_mapping(self) -> list[Target]: """Get targets that have RAC variable mappings.""" return [t for t in self.targets if t.rac_variable is not None] @@ -154,7 +153,7 @@ def to_calibration_format( self, geography: str = "US", year: int = 2021, - ) -> tuple[Dict[str, Dict], Dict[str, float]]: + ) -> tuple[dict[str, dict], dict[str, float]]: """ Convert to microplex calibration format. @@ -209,7 +208,7 @@ def compare_to_policyengine(self, pe_targets: pd.DataFrame) -> pd.DataFrame: return comparison - def coverage_summary(self) -> Dict[str, int]: + def coverage_summary(self) -> dict[str, int]: """Summarize target coverage by category.""" summary = {} for cat in TargetCategory: diff --git a/src/microplex/targets/rac_mapping.py b/src/microplex/targets/rac_mapping.py index 7fef945..88cf5bd 100644 --- a/src/microplex/targets/rac_mapping.py +++ b/src/microplex/targets/rac_mapping.py @@ -6,7 +6,6 @@ """ from dataclasses import dataclass -from typing import Dict, Optional @dataclass @@ -22,7 +21,7 @@ class RACVariable: # Map from target variable names to RAC definitions # Based on cosilico-us/statute structure -RAC_VARIABLE_MAP: Dict[str, RACVariable] = { +RAC_VARIABLE_MAP: dict[str, RACVariable] = { # Income (IRC Section 61 - Gross Income) "adjusted_gross_income": RACVariable( name="adjusted_gross_income", @@ -366,7 +365,7 @@ class RACVariable: # Map from PolicyEngine variable names to our RAC variables -POLICYENGINE_TO_RAC: Dict[str, str] = { +POLICYENGINE_TO_RAC: dict[str, str] = { "adjusted_gross_income": "adjusted_gross_income", "irs_employment_income": "employment_income", "self_employment_income": "self_employment_income", @@ -397,7 +396,7 @@ class RACVariable: # Map from microdata column names (CPS/PUF) to RAC variables -MICRODATA_TO_RAC: Dict[str, str] = { +MICRODATA_TO_RAC: dict[str, str] = { # CPS columns "wage_income": "employment_income", "self_employment_income": "self_employment_income", @@ -424,12 +423,12 @@ class RACVariable: } -def get_rac_for_target(target_name: str) -> Optional[RACVariable]: +def get_rac_for_target(target_name: str) -> RACVariable | None: """Get RAC variable definition for a target name.""" return RAC_VARIABLE_MAP.get(target_name) -def get_rac_for_pe_variable(pe_variable: str) -> Optional[RACVariable]: +def get_rac_for_pe_variable(pe_variable: str) -> RACVariable | None: """Get RAC variable for a PolicyEngine variable name.""" rac_name = POLICYENGINE_TO_RAC.get(pe_variable) if rac_name: @@ -437,7 +436,7 @@ def get_rac_for_pe_variable(pe_variable: str) -> Optional[RACVariable]: return None -def get_rac_for_microdata_column(column: str) -> Optional[RACVariable]: +def get_rac_for_microdata_column(column: str) -> RACVariable | None: """Get RAC variable for a microdata column name.""" rac_name = MICRODATA_TO_RAC.get(column) if rac_name: diff --git a/src/microplex/targets/spec.py b/src/microplex/targets/spec.py index ee69eb7..94e0768 100644 --- a/src/microplex/targets/spec.py +++ b/src/microplex/targets/spec.py @@ -3,13 +3,13 @@ from __future__ import annotations from dataclasses import dataclass, field -from enum import Enum +from enum import StrEnum from typing import Any from microplex.core import EntityType -class FilterOperator(str, Enum): +class FilterOperator(StrEnum): """Supported operators for target filters.""" EQ = "==" @@ -22,7 +22,7 @@ class FilterOperator(str, Enum): NOT_IN = "not_in" -class TargetAggregation(str, Enum): +class TargetAggregation(StrEnum): """Supported target aggregation modes.""" COUNT = "count" diff --git a/tests/test_package_surface.py b/tests/test_package_surface.py index ffee24d..82b2089 100644 --- a/tests/test_package_surface.py +++ b/tests/test_package_surface.py @@ -13,6 +13,7 @@ def test_top_level_package_does_not_export_us_specific_helpers() -> None: assert not hasattr(microplex, "CPSSummaryStats") assert not hasattr(microplex, "CPSSyntheticGenerator") assert not hasattr(microplex, "validate_synthetic") + assert not hasattr(microplex, "SupabaseTargetLoader") assert not hasattr(microplex, "BlockGeography") assert not hasattr(microplex, "load_block_probabilities") assert not hasattr(microplex, "derive_geographies") diff --git a/tests/test_supabase_targets.py b/tests/test_supabase_targets.py index 85e9e1e..5d080d5 100644 --- a/tests/test_supabase_targets.py +++ b/tests/test_supabase_targets.py @@ -7,19 +7,18 @@ 3. Calibration constraints can be built from targets """ -import pytest -import responses -import json - +import importlib.util import sys from pathlib import Path +import pytest +import responses + # Direct import to avoid torch dependency in __init__.py src_path = Path(__file__).parent.parent / "src" / "microplex" sys.path.insert(0, str(src_path.parent)) # Import directly to avoid package __init__.py -import importlib.util spec = importlib.util.spec_from_file_location("supabase_targets", src_path / "supabase_targets.py") module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) @@ -37,6 +36,13 @@ class TestSupabaseTargetLoader: def loader(self): return SupabaseTargetLoader(SUPABASE_URL, SUPABASE_KEY) + def test_missing_service_key_raises(self, monkeypatch): + """Should never fall back to an embedded service-role key.""" + monkeypatch.delenv("COSILICO_SUPABASE_SERVICE_KEY", raising=False) + + with pytest.raises(ValueError, match="COSILICO_SUPABASE_SERVICE_KEY"): + SupabaseTargetLoader(SUPABASE_URL) + @responses.activate def test_load_all_targets(self, loader): """Should load all targets with source and stratum info.""" @@ -242,8 +248,8 @@ def loader(self): @pytest.mark.skip(reason="Integration test requires real Supabase connection") def test_calibration_with_supabase_targets(self, loader): """End-to-end test: load targets from Supabase and run calibration.""" - import pandas as pd import numpy as np + import pandas as pd try: # Direct import to avoid torch dependency import importlib.util @@ -253,7 +259,7 @@ def test_calibration_with_supabase_targets(self, loader): ) cal_module = importlib.util.module_from_spec(cal_spec) cal_spec.loader.exec_module(cal_module) - Calibrator = cal_module.Calibrator + calibrator_cls = cal_module.Calibrator except Exception as e: pytest.skip(f"Cannot import Calibrator: {e}") @@ -279,7 +285,7 @@ def test_calibration_with_supabase_targets(self, loader): pytest.skip("No matching targets for test data") # Run calibration - calibrator = Calibrator(method="ipf", max_iter=100) + calibrator = calibrator_cls(method="ipf", max_iter=100) calibrator.fit(df, marginal_targets={}, continuous_targets=available, weight_col="weight") assert calibrator.weights_ is not None