diff --git a/.github/workflows/pipeline.yaml b/.github/workflows/pipeline.yaml index f557ed2e3..a906121ab 100644 --- a/.github/workflows/pipeline.yaml +++ b/.github/workflows/pipeline.yaml @@ -33,6 +33,22 @@ on: description: "Override version (default: read from pyproject.toml)" default: "" type: string + chunked_matrix: + description: "Build the calibration matrix in chunks (opt-in)" + default: false + type: boolean + chunk_size: + description: "Clone-household columns per chunk" + default: "25000" + type: string + parallel_matrix: + description: "Fan chunked matrix building across Modal workers" + default: false + type: boolean + num_matrix_workers: + description: "Number of Modal workers for parallel matrix build" + default: "50" + type: string concurrency: group: pipeline-main @@ -68,6 +84,10 @@ jobs: SKIP_NATIONAL="${{ inputs.skip_national || 'false' }}" RESUME_RUN_ID="${{ inputs.resume_run_id || '' }}" VERSION_OVERRIDE="${{ inputs.version_override || '' }}" + CHUNKED_MATRIX="${{ inputs.chunked_matrix || 'false' }}" + CHUNK_SIZE="${{ inputs.chunk_size || '25000' }}" + PARALLEL_MATRIX="${{ inputs.parallel_matrix || 'false' }}" + NUM_MATRIX_WORKERS="${{ inputs.num_matrix_workers || '50' }}" python -c " import modal @@ -81,6 +101,10 @@ jobs: skip_national='${SKIP_NATIONAL}' == 'true', resume_run_id='${RESUME_RUN_ID}' or None, version_override='${VERSION_OVERRIDE}' or '', + chunked_matrix='${CHUNKED_MATRIX}' == 'true', + chunk_size=int('${CHUNK_SIZE}'), + parallel_matrix='${PARALLEL_MATRIX}' == 'true', + num_matrix_workers=int('${NUM_MATRIX_WORKERS}'), ) print(f'Pipeline spawned.') print(f'Function call ID: {fc.object_id}') diff --git a/changelog.d/818.changed.md b/changelog.d/818.changed.md new file mode 100644 index 000000000..b83a4a658 --- /dev/null +++ b/changelog.d/818.changed.md @@ -0,0 +1,3 @@ +Extract `ChunkedMatrixAssembler` from `UnifiedMatrixBuilder.build_matrix_chunked` and replace the list-then-concat final assembly with a two-pass streaming CSR build. The facade signature and all existing chunked-matrix behaviour are unchanged. + +Parallelize chunked matrix building across Modal workers. Adds `dispatch_chunks_modal` (`policyengine_us_data/calibration/chunked_matrix_modal.py`) that pickles `SharedBuildState` to the pipeline volume, fans contiguous chunk-id batches to `build_matrix_chunk_worker` (`modal_app/matrix_chunk_worker.py`, registered on the `policyengine-us-data-fit-weights` app), and streams the final CSR from shards on the volume. New CLI flags on `unified_calibration`: `--parallel` (default off) and `--num-matrix-workers` (default 50). `build_package_remote` threads `run_id` to the subprocess via the `POLICYENGINE_US_DATA_RUN_ID` env var and forwards `--parallel` / `--num-matrix-workers` when `parallel_matrix=True`. `--parallel` without `--chunked-matrix` logs an info message and runs the non-chunked path unchanged. diff --git a/modal_app/matrix_chunk_worker.py b/modal_app/matrix_chunk_worker.py new file mode 100644 index 000000000..b9aa0a760 --- /dev/null +++ b/modal_app/matrix_chunk_worker.py @@ -0,0 +1,119 @@ +"""Modal worker that materializes a batch of matrix chunks. + +The ``@app.function`` decorator here attaches to ``_calibration_app`` +(declared as ``policyengine-us-data-fit-weights`` in +``remote_calibration_runner.py``), alongside ``build_package_remote``. 
+At deploy time, ``modal_app/pipeline.py`` merges that app into the +pipeline app via ``app.include(_calibration_app)``, so after +``modal deploy modal_app/pipeline.py`` the function is registered +under ``policyengine-us-data-pipeline`` in Modal's registry — that's +the name ``dispatch_chunks_modal`` uses in its +``modal.Function.from_name`` lookup. + +Each worker reads the shared ``ChunkedMatrixAssembler`` state from +``pipeline_volume``, materializes its assigned chunks to COO shard +files on the volume, and commits. The coordinator reads the shards +back after all workers finish and streams them into the final CSR +matrix. +""" + +from __future__ import annotations + +import pickle +import sys +import traceback +from pathlib import Path +from typing import Dict, List + +_baked = "/root/policyengine-us-data" +_local = str(Path(__file__).resolve().parent.parent) +for _p in (_baked, _local): + if _p not in sys.path: + sys.path.insert(0, _p) + +from modal_app.images import cpu_image # noqa: E402 +from modal_app.remote_calibration_runner import ( # noqa: E402 + PIPELINE_MOUNT, + app, + hf_secret, + pipeline_vol, +) + + +def _chunk_root(run_id: str) -> str: + return f"{PIPELINE_MOUNT}/artifacts/{run_id}/matrix_build" + + +@app.function( + image=cpu_image, + secrets=[hf_secret], + volumes={PIPELINE_MOUNT: pipeline_vol}, + memory=16384, + cpu=1.0, + timeout=28800, + max_containers=50, + nonpreemptible=True, +) +def build_matrix_chunk_worker(run_id: str, chunk_ids: List[int]) -> Dict: + """Materialize ``chunk_ids`` from the pickled ``SharedBuildState``. + + Args: + run_id: Pipeline run identifier; selects the volume path for + this worker's shared state and shard output directory. + chunk_ids: Chunk indices this worker is responsible for. + + Returns: + Dict with ``chunk_ids``, ``nnz_per_chunk``, and ``errors`` + lists suitable for the coordinator to aggregate. + """ + from policyengine_us_data.calibration.chunked_matrix_assembler import ( + ChunkedMatrixAssembler, + ) + + pipeline_vol.reload() + chunk_root = Path(_chunk_root(run_id)) + state_path = chunk_root / "chunk_build_state.pkl" + if not state_path.exists(): + return { + "chunk_ids": list(chunk_ids), + "nnz_per_chunk": [], + "errors": [ + { + "chunk_ids": list(chunk_ids), + "error": f"Missing shared state at {state_path}", + } + ], + } + + with open(state_path, "rb") as f: + shared_state = pickle.load(f) + + assembler = ChunkedMatrixAssembler( + shared_state=shared_state, + chunk_root=chunk_root, + chunk_size=shared_state.chunk_size, + resume=True, + keep_chunks=False, + ) + + errors: List[Dict] = [] + nnz_per_chunk: List[int] = [] + for chunk_id in chunk_ids: + try: + result = assembler.run_single_chunk(chunk_id) + nnz_per_chunk.append(result.nnz) + except Exception as exc: + errors.append( + { + "chunk_id": chunk_id, + "error": str(exc), + "traceback": traceback.format_exc(), + } + ) + + pipeline_vol.commit() + return { + "chunk_ids": list(chunk_ids), + "nnz_per_chunk": nnz_per_chunk, + "errors": errors, + } diff --git a/modal_app/pipeline.py b/modal_app/pipeline.py index c02d6f10e..74a810b74 100644 --- a/modal_app/pipeline.py +++ b/modal_app/pipeline.py @@ -284,6 +284,13 @@ def _record_step( PACKAGE_GPU_FUNCTIONS, ) +# Import registers ``build_matrix_chunk_worker`` on ``_calibration_app`` +# so a single ``modal deploy modal_app/pipeline.py`` also deploys the +# worker via ``app.include(_calibration_app)`` below. Without this the +# dispatch layer's ``modal.Function.from_name`` lookup would fail at +# runtime. 
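+# (The failure would typically surface as ``modal.exception.NotFoundError``
+# once the dispatch layer hydrates the function handle.)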
+from modal_app.matrix_chunk_worker import build_matrix_chunk_worker # noqa: F401 + app.include(_calibration_app) from modal_app.local_area import app as _local_area_app @@ -681,6 +688,10 @@ def run_pipeline( resume_run_id: str = None, clear_checkpoints: bool = False, version_override: str = "", + chunked_matrix: bool = False, + chunk_size: int = 25_000, + parallel_matrix: bool = False, + num_matrix_workers: int = 50, ) -> str: """Run the full pipeline end-to-end. @@ -699,6 +710,15 @@ def run_pipeline( scoped by commit SHA, so stale ones from other commits are cleaned automatically. Use True only to force a full rebuild of the current commit. + chunked_matrix: Build the calibration matrix in clone-household + chunks instead of the non-chunked path. Opt-in; default off. + chunk_size: Clone-household columns per chunk when + ``chunked_matrix`` is True. + parallel_matrix: Fan chunked matrix building across Modal + workers via ``build_matrix_chunk_worker``. Only meaningful + when ``chunked_matrix`` is True; ignored otherwise. + num_matrix_workers: Number of Modal workers when + ``parallel_matrix`` is True. Returns: The run ID for use with promote. @@ -832,6 +852,10 @@ def run_pipeline( workers=num_workers, n_clones=n_clones, run_id=run_id, + chunked_matrix=chunked_matrix, + chunk_size=chunk_size, + parallel_matrix=parallel_matrix, + num_matrix_workers=num_matrix_workers, ) print(f" Package at: {pkg_path}") diff --git a/modal_app/remote_calibration_runner.py b/modal_app/remote_calibration_runner.py index f66113bdc..3b372fecb 100644 --- a/modal_app/remote_calibration_runner.py +++ b/modal_app/remote_calibration_runner.py @@ -363,6 +363,10 @@ def _build_package_impl( workers: int = 8, n_clones: int = 430, run_id: str = "", + chunked_matrix: bool = False, + chunk_size: int = 25_000, + parallel_matrix: bool = False, + num_matrix_workers: int = 50, ) -> str: """Read data from pipeline volume, build X matrix, save package.""" _setup_repo() @@ -401,10 +405,26 @@ def _build_package_impl( if workers > 1: cmd.extend(["--workers", str(workers)]) cmd.extend(["--n-clones", str(n_clones)]) + if chunked_matrix: + cmd.extend(["--chunked-matrix", "--chunk-size", str(chunk_size)]) + if parallel_matrix: + cmd.extend( + [ + "--parallel", + "--num-matrix-workers", + str(num_matrix_workers), + ] + ) + build_env = os.environ.copy() + if run_id: + # ``unified_calibration.py`` reads this env var so workers can + # locate their shared state at {pipeline-artifacts}/{run_id}/ + # matrix_build/chunk_build_state.pkl on the pipeline volume. 
+ build_env["POLICYENGINE_US_DATA_RUN_ID"] = run_id build_rc, build_lines = _run_streaming( cmd, - env=os.environ.copy(), + env=build_env, label="build", ) if build_rc != 0: @@ -443,6 +463,10 @@ def build_package_remote( workers: int = 8, n_clones: int = 430, run_id: str = "", + chunked_matrix: bool = False, + chunk_size: int = 25_000, + parallel_matrix: bool = False, + num_matrix_workers: int = 50, ) -> str: return _build_package_impl( branch, @@ -451,6 +475,10 @@ def build_package_remote( workers=workers, n_clones=n_clones, run_id=run_id, + chunked_matrix=chunked_matrix, + chunk_size=chunk_size, + parallel_matrix=parallel_matrix, + num_matrix_workers=num_matrix_workers, ) diff --git a/policyengine_us_data/calibration/chunked_matrix_assembler.py b/policyengine_us_data/calibration/chunked_matrix_assembler.py new file mode 100644 index 000000000..59b6699bb --- /dev/null +++ b/policyengine_us_data/calibration/chunked_matrix_assembler.py @@ -0,0 +1,549 @@ +"""Coordinator for chunked sparse calibration matrix building. + +Extracted from ``UnifiedMatrixBuilder.build_matrix_chunked`` so per-chunk +work, final assembly, and any future parallel dispatch share one +well-tested seam. Phase-1 scope: in-process serial execution and +streaming CSR assembly. A later phase will add a Modal dispatch +function that constructs ``ChunkedMatrixAssembler`` on each worker and +calls ``run_chunks`` with its assigned chunk ids. +""" + +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Set, Tuple + +import numpy as np +from scipy import sparse + +logger = logging.getLogger(__name__) + + +@dataclass +class ChunkPlan: + """Identity and output paths for one column chunk.""" + + chunk_id: int + col_start: int + col_end: int + coo_path: Path + h5_path: Path + + +@dataclass +class ChunkResult: + """Per-chunk summary returned after materialization. + + ``nnz`` is the number of nonzero entries written to the COO shard. + ``cached`` is true when the shard already existed and was reused + under resume semantics (in which case no kernel work ran). + """ + + chunk_id: int + nnz: int + cached: bool = False + n_households: Optional[int] = None + n_persons: Optional[int] = None + unique_states: Optional[int] = None + unique_counties: Optional[int] = None + unique_cds: Optional[int] = None + + +@dataclass +class SharedBuildState: + """Read-only state every chunk consumes. + + Pickle-clean: only data, no bound instance methods. A Modal worker + can unpickle this and reconstruct a ``ChunkedMatrixAssembler`` + without access to the originating ``UnifiedMatrixBuilder``. + """ + + source_dataset_path: str + time_period: int + rerandomize_takeup: bool + n_records: int + n_clones: int + n_targets: int + chunk_size: int + target_variables: List[str] + target_reform_ids: List[int] + target_geo_info: List[Tuple[str, str]] + non_geo_constraints_list: List[List[dict]] + unique_variables: Set[str] + unique_constraint_vars: Set[str] + reform_variables: Set[str] + target_names: List[str] + base_entity_maps: object + block_geoid: np.ndarray + cd_geoid: np.ndarray + county_fips: np.ndarray + state_fips: np.ndarray + + @property + def n_total(self) -> int: + return self.n_records * self.n_clones + + +def partition_chunks( + n_total: int, chunk_size: int, coo_dir: Path, h5_dir: Path +) -> List[ChunkPlan]: + """Split ``n_total`` columns into ``ChunkPlan`` objects of ``chunk_size``. + + The last chunk may be smaller. 
``chunk_size`` must be positive and + ``n_total`` must be non-negative. + """ + if chunk_size <= 0: + raise ValueError("chunk_size must be positive") + if n_total < 0: + raise ValueError("n_total must be non-negative") + plans: List[ChunkPlan] = [] + chunk_id = 0 + for col_start in range(0, n_total, chunk_size): + col_end = min(col_start + chunk_size, n_total) + plans.append( + ChunkPlan( + chunk_id=chunk_id, + col_start=col_start, + col_end=col_end, + coo_path=coo_dir / f"chunk_{chunk_id:06d}.npz", + h5_path=h5_dir / f"chunk_{chunk_id:06d}.h5", + ) + ) + chunk_id += 1 + return plans + + +def stream_csr_from_shards( + shard_dir: Path, + n_chunks: int, + n_targets: int, + n_total: int, +) -> sparse.csr_matrix: + """Assemble a CSR matrix from per-chunk COO ``.npz`` shards without + materializing a full COO triple or scipy's internal COO->CSR copy. + + Two passes over shards: pass 1 counts per-row nonzeros across all + shards to compute ``indptr``; pass 2 scatters each shard's entries + into preallocated ``data``/``indices`` arrays at the right offsets. + + Peak memory during pass 2 is one shard plus the final CSR arrays. + """ + row_nnz = np.zeros(n_targets, dtype=np.int64) + shard_paths: List[Path] = [] + for chunk_id in range(n_chunks): + path = shard_dir / f"chunk_{chunk_id:06d}.npz" + shard_paths.append(path) + with np.load(str(path)) as shard: + rows = shard["rows"] + if rows.size == 0: + continue + counts = np.bincount(rows.astype(np.int64), minlength=n_targets) + row_nnz += counts + + total_nnz = int(row_nnz.sum()) + indptr = np.empty(n_targets + 1, dtype=np.int64) + indptr[0] = 0 + np.cumsum(row_nnz, out=indptr[1:]) + + data = np.empty(total_nnz, dtype=np.float32) + indices = np.empty(total_nnz, dtype=np.int32) + row_cursor = indptr[:-1].copy() + + for path in shard_paths: + with np.load(str(path)) as shard: + rows = shard["rows"] + if rows.size == 0: + continue + cols = shard["cols"] + vals = shard["vals"] + # Group entries by row within the shard so we can write + # contiguous slices per row instead of looping entry-by-entry. + order = np.argsort(rows, kind="stable") + rows_sorted = rows[order] + cols_sorted = cols[order] + vals_sorted = vals[order] + unique_rows, starts, counts = np.unique( + rows_sorted, return_index=True, return_counts=True + ) + for row, start, count in zip(unique_rows, starts, counts): + offset = int(row_cursor[row]) + end = start + count + data[offset : offset + count] = vals_sorted[start:end] + indices[offset : offset + count] = cols_sorted[start:end] + row_cursor[row] += count + + # scipy requires indptr/indices to be int32 for canonical CSR; cast + # once at the end. indices are already int32; indptr may need to be + # downcast if total_nnz fits. + if indptr[-1] <= np.iinfo(np.int32).max: + indptr_final = indptr.astype(np.int32) + else: + indptr_final = indptr + X = sparse.csr_matrix( + (data, indices, indptr_final), + shape=(n_targets, n_total), + ) + X.sort_indices() + return X + + +class ChunkedMatrixAssembler: + """Coordinate partitioning, per-chunk execution, and streaming assembly. + + Serial execution today; a Modal dispatch function can construct one + of these per worker container and call ``run_chunks`` with the + worker's assigned chunk ids. + + This class is a deliberate precursor to the ``MatrixAssembler`` + extraction described in ``US Data Pipeline Refactor.md`` Phase 4. + It owns chunking/assembly today and will absorb target repository, + simulation batching, and constraint evaluation responsibilities as + that refactor lands. 
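+
+    A minimal usage sketch (construction of ``state`` and ``sim`` is
+    elided; the chunk root is illustrative)::
+
+        assembler = ChunkedMatrixAssembler(
+            shared_state=state,
+            chunk_root=Path("/tmp/matrix_build"),
+            chunk_size=state.chunk_size,
+            resume=True,
+            keep_chunks=False,
+            base_sim=sim,
+        )
+        assembler.run_chunks(range(assembler.n_chunks))
+        X = assembler.assemble_final()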
+ """ + + def __init__( + self, + shared_state: SharedBuildState, + chunk_root: Path, + chunk_size: int, + resume: bool, + keep_chunks: bool, + base_sim=None, + ): + self.shared_state = shared_state + self.chunk_root = Path(chunk_root) + self.coo_dir = self.chunk_root / "coo" + self.h5_dir = self.chunk_root / "h5" + self.coo_dir.mkdir(parents=True, exist_ok=True) + self.h5_dir.mkdir(parents=True, exist_ok=True) + self.chunk_size = chunk_size + self.resume = resume + self.keep_chunks = keep_chunks + self.plans: List[ChunkPlan] = partition_chunks( + shared_state.n_total, chunk_size, self.coo_dir, self.h5_dir + ) + self.n_chunks: int = len(self.plans) + # ``base_sim`` is the source ``Microsimulation`` whose household + # arrays are sliced by ``materialize_clone_household_chunk``. The + # facade builds it once and passes it in; a Modal worker in + # phase 2 would construct it from ``source_dataset_path`` on the + # volume (unpicklable, so not part of ``SharedBuildState``). + self._base_sim = base_sim + + def run_chunks(self, chunk_ids: Iterable[int]) -> List[ChunkResult]: + """Materialize the given chunks serially, honoring resume skip.""" + ids = list(chunk_ids) + results: List[ChunkResult] = [] + t_build = time.time() + processed_times: List[float] = [] + cached_chunks = 0 + for i, chunk_id in enumerate(ids): + t0 = time.time() + result = self.run_single_chunk(chunk_id) + results.append(result) + if result.cached: + cached_chunks += 1 + else: + processed_times.append(time.time() - t0) + remaining = len(ids) - (i + 1) + if processed_times: + avg = float(np.mean(processed_times)) + eta = avg * remaining + else: + avg = 0.0 + eta = 0.0 + elapsed = time.time() - t_build + plan = self.plans[chunk_id] + if result.cached: + logger.info( + "Chunk %d/%d cached: cols %d-%d, cached=%d", + chunk_id + 1, + self.n_chunks, + plan.col_start, + plan.col_end - 1, + cached_chunks, + ) + else: + from policyengine_us_data.calibration.unified_matrix_builder import ( + _current_rss_mb, + _format_duration, + ) + + rss = _current_rss_mb() + rss_part = f", rss={rss:,.0f} MB" if rss is not None else "" + logger.info( + "Chunk %d/%d: cols %d-%d, hh=%s, persons=%s, " + "states=%s, counties=%s, cds=%s, nnz=%d, " + "chunk=%s, avg=%s, elapsed=%s, eta=%s%s", + chunk_id + 1, + self.n_chunks, + plan.col_start, + plan.col_end - 1, + result.n_households, + result.n_persons, + result.unique_states, + result.unique_counties, + result.unique_cds, + result.nnz, + _format_duration(time.time() - t0), + _format_duration(avg), + _format_duration(elapsed), + _format_duration(eta), + rss_part, + ) + return results + + def run_single_chunk(self, chunk_id: int) -> ChunkResult: + """Run one chunk's kernel: materialize H5, simulate, write shard. + + If ``resume=True`` and a valid shard already exists at the + expected ``coo_path``, the kernel is skipped and a cached + ``ChunkResult`` is returned. 
+ """ + plan = self.plans[chunk_id] + state = self.shared_state + + if self.resume and plan.coo_path.exists(): + with np.load(str(plan.coo_path)) as cached_chunk: + if "col_start" not in cached_chunk or "col_end" not in cached_chunk: + raise ValueError( + f"Cached chunk {plan.coo_path} is missing " + "col_start/col_end metadata" + ) + cached_col_start = int(np.asarray(cached_chunk["col_start"]).item()) + cached_col_end = int(np.asarray(cached_chunk["col_end"]).item()) + cached_nnz = int(cached_chunk["rows"].shape[0]) + if cached_col_start != plan.col_start or cached_col_end != plan.col_end: + raise ValueError( + f"Cached chunk {plan.coo_path} covers cols " + f"{cached_col_start}-{cached_col_end - 1}, " + f"expected {plan.col_start}-{plan.col_end - 1}" + ) + return ChunkResult(chunk_id=chunk_id, nnz=cached_nnz, cached=True) + + # Imports are local so the module is import-safe in lightweight + # environments (e.g., cold Modal containers that haven't yet + # run ``uv sync`` for the heavy deps). + from policyengine_us import Microsimulation + + from policyengine_us_data.calibration.entity_clone import ( + materialize_clone_household_chunk, + ) + from policyengine_us_data.calibration.unified_matrix_builder import ( + _build_entity_index_maps, + _calculate_target_values_standalone, + _make_neutralize_variable_reform, + build_entity_relationship, + ) + + global_cols = np.arange(plan.col_start, plan.col_end, dtype=np.int64) + active_hh = global_cols % state.n_records + active_clone_indices = global_cols // state.n_records + active_blocks = np.asarray(state.block_geoid)[global_cols] + active_cd_geoids = np.asarray(state.cd_geoid, dtype=str)[global_cols] + active_states = np.asarray(state.state_fips)[global_cols] + active_counties = np.asarray(state.county_fips, dtype=str)[global_cols] + + if self._base_sim is None: + self._base_sim = Microsimulation(dataset=state.source_dataset_path) + summary = materialize_clone_household_chunk( + sim=self._base_sim, + entity_maps=state.base_entity_maps, + active_hh=active_hh, + active_blocks=active_blocks, + active_cd_geoids=active_cd_geoids, + active_clone_indices=active_clone_indices, + output_path=plan.h5_path, + apply_takeup=state.rerandomize_takeup, + ) + + chunk_sim = Microsimulation(dataset=str(plan.h5_path)) + chunk_n = len(global_cols) + entity_rel = build_entity_relationship(chunk_sim) + household_ids = chunk_sim.calculate("household_id", map_to="household").values + entity_hh_idx_map, person_to_entity_idx_map = _build_entity_index_maps( + entity_rel, household_ids, chunk_sim + ) + + variable_entity_map: Dict[str, str] = {} + hh_vars: Dict[str, np.ndarray] = {} + target_entity_vars: Dict[str, np.ndarray] = {} + for variable in sorted(state.unique_variables): + if variable in chunk_sim.tax_benefit_system.variables: + variable_entity_map[variable] = chunk_sim.tax_benefit_system.variables[ + variable + ].entity.key + if variable.endswith("_count"): + continue + try: + hh_vars[variable] = chunk_sim.calculate( + variable, state.time_period, map_to="household" + ).values.astype(np.float32) + except Exception as exc: + logger.warning( + "Chunk %d cannot calculate target '%s': %s", + chunk_id, + variable, + exc, + ) + entity_key = variable_entity_map.get(variable, "household") + if entity_key == "household": + continue + try: + target_entity_vars[variable] = chunk_sim.calculate( + variable, state.time_period, map_to=entity_key + ).values.astype(np.float32) + except Exception as exc: + logger.warning( + "Chunk %d cannot calculate entity-level target '%s' " + 
"(map_to=%s): %s", + chunk_id, + variable, + entity_key, + exc, + ) + + person_vars: Dict[str, np.ndarray] = {} + for variable in sorted(state.unique_constraint_vars): + try: + raw = chunk_sim.calculate( + variable, state.time_period, map_to="person" + ).values + try: + person_vars[variable] = raw.astype(np.float32) + except (ValueError, TypeError): + person_vars[variable] = raw + except Exception as exc: + logger.warning( + "Chunk %d cannot calculate constraint '%s': %s", + chunk_id, + variable, + exc, + ) + + reform_hh_vars: Dict[str, np.ndarray] = {} + if state.reform_variables: + baseline_income_tax = chunk_sim.calculate( + "income_tax", state.time_period, map_to="household" + ).values.astype(np.float32) + for variable in sorted(state.reform_variables): + try: + reform_sim = Microsimulation( + dataset=str(plan.h5_path), + reform=_make_neutralize_variable_reform(variable), + ) + reform_income_tax = reform_sim.calculate( + "income_tax", state.time_period, map_to="household" + ).values.astype(np.float32) + reform_hh_vars[variable] = reform_income_tax - baseline_income_tax + except Exception as exc: + logger.warning( + "Chunk %d cannot calculate reform target '%s': %s", + chunk_id, + variable, + exc, + ) + + target_value_cache: Dict[tuple, np.ndarray] = {} + rows_list: List[np.ndarray] = [] + cols_list: List[np.ndarray] = [] + vals_list: List[np.ndarray] = [] + + for row_idx in range(state.n_targets): + variable = state.target_variables[row_idx] + reform_id = state.target_reform_ids[row_idx] + geo_level, geo_id = state.target_geo_info[row_idx] + non_geo = state.non_geo_constraints_list[row_idx] + + if geo_level == "district": + geo_mask = active_cd_geoids == str(geo_id) + elif geo_level == "state": + geo_mask = active_states.astype(np.int64) == int(geo_id) + elif geo_level == "county": + geo_mask = active_counties == str(geo_id).zfill(5) + else: + geo_mask = np.ones(chunk_n, dtype=bool) + + if not geo_mask.any(): + continue + + constraint_key = tuple( + sorted((c["variable"], c["operation"], c["value"]) for c in non_geo) + ) + value_key = (variable, constraint_key, reform_id) + if value_key not in target_value_cache: + target_value_cache[value_key] = _calculate_target_values_standalone( + target_variable=variable, + non_geo_constraints=non_geo, + n_households=chunk_n, + hh_vars=hh_vars, + reform_hh_vars=reform_hh_vars, + target_entity_vars=target_entity_vars, + person_vars=person_vars, + entity_rel=entity_rel, + household_ids=household_ids, + variable_entity_map=variable_entity_map, + entity_hh_idx_map=entity_hh_idx_map, + person_to_entity_idx_map=person_to_entity_idx_map, + reform_id=reform_id, + ) + values = target_value_cache[value_key] + + vals = values[geo_mask] + nonzero = vals != 0 + if nonzero.any(): + rows_list.append(np.full(nonzero.sum(), row_idx, dtype=np.int32)) + cols_list.append(global_cols[geo_mask][nonzero].astype(np.int32)) + vals_list.append(vals[nonzero].astype(np.float32, copy=False)) + + if rows_list: + rows = np.concatenate(rows_list) + cols = np.concatenate(cols_list) + vals = np.concatenate(vals_list) + else: + rows = np.array([], dtype=np.int32) + cols = np.array([], dtype=np.int32) + vals = np.array([], dtype=np.float32) + + np.savez_compressed( + str(plan.coo_path), + rows=rows, + cols=cols, + vals=vals, + col_start=np.array([plan.col_start], dtype=np.int64), + col_end=np.array([plan.col_end], dtype=np.int64), + ) + + if not self.keep_chunks and plan.h5_path.exists(): + plan.h5_path.unlink() + + return ChunkResult( + chunk_id=chunk_id, + nnz=int(vals.shape[0]), 
+ cached=False, + n_households=getattr(summary, "n_households", None), + n_persons=getattr(summary, "n_persons", None), + unique_states=getattr(summary, "unique_states", None), + unique_counties=getattr(summary, "unique_counties", None), + unique_cds=getattr(summary, "unique_cds", None), + ) + + def assemble_final(self) -> sparse.csr_matrix: + """Stream-assemble the final CSR matrix from all shards on disk.""" + logger.info("Assembling matrix from %d chunk files...", self.n_chunks) + X_csr = stream_csr_from_shards( + shard_dir=self.coo_dir, + n_chunks=self.n_chunks, + n_targets=self.shared_state.n_targets, + n_total=self.shared_state.n_total, + ) + logger.info( + "Chunked matrix: %d targets x %d cols, %d nnz", + X_csr.shape[0], + X_csr.shape[1], + X_csr.nnz, + ) + return X_csr diff --git a/policyengine_us_data/calibration/chunked_matrix_modal.py b/policyengine_us_data/calibration/chunked_matrix_modal.py new file mode 100644 index 000000000..7be9a2b7d --- /dev/null +++ b/policyengine_us_data/calibration/chunked_matrix_modal.py @@ -0,0 +1,214 @@ +"""Coordinator-side Modal dispatch for chunked matrix building. + +Writes shared ``SharedBuildState`` to a pipeline-volume path, +spawns ``build_matrix_chunk_worker`` per contiguous batch of +chunk ids, collects per-worker results, then streams the final CSR +from all shards on the volume. + +Kept separate from ``unified_matrix_builder`` so the core matrix +builder doesn't import Modal; only the dispatch path does. +""" + +from __future__ import annotations + +import logging +import math +import pickle +import time +from pathlib import Path +from typing import Any, Dict, List, Optional + +from scipy import sparse + +from policyengine_us_data.calibration.chunked_matrix_assembler import ( + ChunkedMatrixAssembler, + SharedBuildState, +) + +logger = logging.getLogger(__name__) + +DEFAULT_NUM_MATRIX_WORKERS = 50 +# The worker is declared on ``_calibration_app`` in +# ``modal_app/matrix_chunk_worker.py`` (``policyengine-us-data-fit-weights``), +# but ``modal_app/pipeline.py`` merges that app into the pipeline app via +# ``app.include(_calibration_app)``. After ``modal deploy modal_app/pipeline.py`` +# the function is registered under the pipeline app's name — that's the name +# ``modal.Function.from_name`` looks up. +MODAL_APP_NAME = "policyengine-us-data-pipeline" +WORKER_FUNCTION_NAME = "build_matrix_chunk_worker" + + +def partition_chunk_ids_contiguous(n_chunks: int, num_workers: int) -> List[List[int]]: + """Split ``range(n_chunks)`` into ``num_workers`` contiguous batches. + + Contiguous (not round-robin) so that a partial run leaves complete + prefixes on disk that future `--resume-chunks` invocations can + skip cleanly. Returns at most ``num_workers`` non-empty batches. + """ + if n_chunks <= 0: + return [] + if num_workers <= 0: + raise ValueError("num_workers must be positive") + batch_size = math.ceil(n_chunks / num_workers) + batches: List[List[int]] = [] + for start in range(0, n_chunks, batch_size): + end = min(start + batch_size, n_chunks) + batches.append(list(range(start, end))) + return batches + + +def _lookup_worker_function(): + """Resolve the deployed Modal worker function. + + Using ``Function.from_name`` avoids importing the worker module + here (Modal imports are heavy and would pull into every caller). + It also means unit tests can monkeypatch this function without + touching the worker module. 
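+
+    A unit-test sketch (``fake_worker`` is an illustrative stand-in)::
+
+        import policyengine_us_data.calibration.chunked_matrix_modal as m
+        monkeypatch.setattr(
+            m, "_lookup_worker_function", lambda: fake_worker
+        )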
+ """ + import modal + + return modal.Function.from_name(MODAL_APP_NAME, WORKER_FUNCTION_NAME) + + +def dispatch_chunks_modal( + *, + shared_state: SharedBuildState, + chunk_root: Path, + run_id: str, + num_workers: int = DEFAULT_NUM_MATRIX_WORKERS, + worker_function: Optional[Any] = None, + volume: Optional[Any] = None, +) -> sparse.csr_matrix: + """Fan chunk materialization across Modal workers, then assemble. + + Args: + shared_state: Read-only per-build state; pickled once to the + volume so workers can reconstruct the assembler without + receiving the arrays through ``.spawn()`` args. + chunk_root: Directory on the pipeline volume where shards land + (``{chunk_root}/coo/chunk_XXXXXX.npz``) and where the + shared state pickle lives + (``{chunk_root}/chunk_build_state.pkl``). + run_id: Forwarded to each worker so its volume paths align + with the coordinator's. + num_workers: Upper bound on workers; actual count equals + ``min(num_workers, n_chunks)``. + worker_function: Override for the Modal function (tests only). + volume: Override for the pipeline volume (tests only). When + omitted, resolves ``modal.Volume.from_name("pipeline-artifacts")``. + + Raises: + RuntimeError: if any worker reports one or more chunk errors + after all workers finish. Raised after aggregating so no + errors are silently dropped. + """ + chunk_root = Path(chunk_root) + chunk_root.mkdir(parents=True, exist_ok=True) + state_path = chunk_root / "chunk_build_state.pkl" + with open(state_path, "wb") as f: + pickle.dump(shared_state, f) + + if volume is None: + import modal + + volume = modal.Volume.from_name("pipeline-artifacts", create_if_missing=True) + # Make the shared-state pickle visible to workers. + volume.commit() + + n_chunks = math.ceil(shared_state.n_total / shared_state.chunk_size) + batches = partition_chunk_ids_contiguous(n_chunks, num_workers) + + if not batches: + # Nothing to materialize; fall through to assembly (which will + # return an empty CSR). 
+ volume.reload() + assembler = ChunkedMatrixAssembler( + shared_state=shared_state, + chunk_root=chunk_root, + chunk_size=shared_state.chunk_size, + resume=True, + keep_chunks=False, + ) + return assembler.assemble_final() + + if worker_function is None: + worker_function = _lookup_worker_function() + + logger.info( + "Dispatching %d chunks across %d workers (batch sizes: %s)", + n_chunks, + len(batches), + [len(b) for b in batches[:5]] + (["..."] if len(batches) > 5 else []), + ) + t_dispatch = time.time() + handles = [] + for batch_idx, chunk_ids in enumerate(batches): + handle = worker_function.spawn(run_id=run_id, chunk_ids=chunk_ids) + logger.info( + "Worker %d/%d: %d chunks (%d-%d), fc=%s", + batch_idx + 1, + len(batches), + len(chunk_ids), + chunk_ids[0], + chunk_ids[-1], + getattr(handle, "object_id", "unknown"), + ) + handles.append((batch_idx, handle)) + + aggregated_errors: List[Dict] = [] + for batch_idx, handle in handles: + try: + result = handle.get() + except Exception as exc: + aggregated_errors.append( + { + "batch": batch_idx, + "error": f"Worker crashed: {exc}", + } + ) + logger.error("Worker %d crashed: %s", batch_idx, exc) + continue + if result is None: + aggregated_errors.append( + {"batch": batch_idx, "error": "Worker returned None"} + ) + continue + errors = result.get("errors", []) + if errors: + for err in errors: + err_copy = dict(err) + err_copy["batch"] = batch_idx + aggregated_errors.append(err_copy) + logger.info( + "Worker %d done: %d chunks completed, %d errors", + batch_idx, + len(result.get("chunk_ids", [])) - len(errors), + len(errors), + ) + + logger.info( + "All workers finished in %.1fs; %d errors total", + time.time() - t_dispatch, + len(aggregated_errors), + ) + + if aggregated_errors: + preview = "; ".join( + f"batch {e['batch']}: {e.get('error', 'unknown')[:120]}" + for e in aggregated_errors[:3] + ) + raise RuntimeError( + f"Parallel chunked matrix build failed with " + f"{len(aggregated_errors)} error(s). First: {preview}" + ) + + # All shards present on the volume; reload and assemble. + volume.reload() + assembler = ChunkedMatrixAssembler( + shared_state=shared_state, + chunk_root=chunk_root, + chunk_size=shared_state.chunk_size, + resume=True, + keep_chunks=False, + ) + return assembler.assemble_final() diff --git a/policyengine_us_data/calibration/unified_calibration.py b/policyengine_us_data/calibration/unified_calibration.py index 4056638eb..31d101573 100644 --- a/policyengine_us_data/calibration/unified_calibration.py +++ b/policyengine_us_data/calibration/unified_calibration.py @@ -323,6 +323,25 @@ def parse_args(argv=None): action="store_true", help="Reuse existing chunk COO files in --chunk-dir when present.", ) + parser.add_argument( + "--parallel", + action="store_true", + default=False, + help=( + "Fan chunked matrix building across Modal workers. Requires " + "--chunked-matrix and a deployed build_matrix_chunk_worker; " + "silently ignored on the non-chunked path." + ), + ) + parser.add_argument( + "--num-matrix-workers", + type=int, + default=50, + help=( + "Number of Modal workers to fan chunked matrix building " + "across when --parallel is set (default: 50)." + ), + ) parser.add_argument( "--package-path", default=None, @@ -1085,6 +1104,9 @@ def run_calibration( chunk_dir: str = None, keep_chunks: bool = False, resume_chunks: bool = False, + parallel: bool = False, + num_matrix_workers: int = 50, + run_id: str = "", ): """Run unified calibration pipeline. 
@@ -1327,8 +1349,16 @@ def run_calibration( keep_chunks=keep_chunks, resume_chunks=resume_chunks, rerandomize_takeup=do_rerandomize, + parallel=parallel, + num_matrix_workers=num_matrix_workers, + run_id=run_id, ) else: + if parallel: + logger.info( + "--parallel is ignored on the non-chunked matrix path; " + "pass --chunked-matrix to enable Modal fan-out" + ) targets_df, X_sparse, target_names = builder.build_matrix( geography=geography, sim=sim, @@ -1588,6 +1618,9 @@ def main(argv=None): chunk_dir=args.chunk_dir, keep_chunks=args.keep_chunks, resume_chunks=args.resume_chunks, + parallel=args.parallel, + num_matrix_workers=args.num_matrix_workers, + run_id=os.environ.get("POLICYENGINE_US_DATA_RUN_ID", ""), ) source_imputed = geography_info.get("dataset_for_matrix") diff --git a/policyengine_us_data/calibration/unified_matrix_builder.py b/policyengine_us_data/calibration/unified_matrix_builder.py index e38c28e6a..510cb38f3 100644 --- a/policyengine_us_data/calibration/unified_matrix_builder.py +++ b/policyengine_us_data/calibration/unified_matrix_builder.py @@ -130,6 +130,22 @@ def apply(self): return NeutralizeVariable +def build_entity_relationship(sim) -> pd.DataFrame: + """Build a person-level DataFrame of entity id columns. + + Module-level so the chunked matrix kernel can call it from any + process without depending on an ``UnifiedMatrixBuilder`` instance. + """ + return pd.DataFrame( + { + "person_id": sim.calculate("person_id", map_to="person").values, + "household_id": sim.calculate("household_id", map_to="person").values, + "tax_unit_id": sim.calculate("tax_unit_id", map_to="person").values, + "spm_unit_id": sim.calculate("spm_unit_id", map_to="person").values, + } + ) + + def _compute_reform_household_values( dataset_path: str, time_period: int, @@ -1304,15 +1320,7 @@ def __init__( def _build_entity_relationship(self, sim) -> pd.DataFrame: if self._entity_rel_cache is not None: return self._entity_rel_cache - - self._entity_rel_cache = pd.DataFrame( - { - "person_id": sim.calculate("person_id", map_to="person").values, - "household_id": sim.calculate("household_id", map_to="person").values, - "tax_unit_id": sim.calculate("tax_unit_id", map_to="person").values, - "spm_unit_id": sim.calculate("spm_unit_id", map_to="person").values, - } - ) + self._entity_rel_cache = build_entity_relationship(sim) return self._entity_rel_cache # --------------------------------------------------------------- @@ -3073,21 +3081,26 @@ def build_matrix_chunked( keep_chunks: bool = False, resume_chunks: bool = False, rerandomize_takeup: bool = True, + parallel: bool = False, + num_matrix_workers: int = 50, + run_id: str = "", ) -> Tuple[pd.DataFrame, sparse.csr_matrix, List[str]]: """Build a sparse matrix by materializing mixed-geography chunks. - This path chunks over global clone-household columns. Each chunk H5 - contains only the selected base households and their dependent entities, - with geography inputs assigned row-by-row from the clone assignment. + Thin facade: target querying, uprating, constraint extraction, + and manifest handling live here; chunking, per-chunk execution, + and streaming final assembly live in + :class:`ChunkedMatrixAssembler`. 
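+
+        When ``parallel`` is True, per-chunk execution is delegated to
+        ``dispatch_chunks_modal`` (which requires a non-empty
+        ``run_id``); otherwise chunks run serially in-process before
+        streaming assembly.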
""" import shutil import tempfile - import time - from policyengine_us import Microsimulation + from policyengine_us_data.calibration.chunked_matrix_assembler import ( + ChunkedMatrixAssembler, + SharedBuildState, + ) from policyengine_us_data.calibration.entity_clone import ( build_household_entity_maps, - materialize_clone_household_chunk, ) if self.dataset_path is None: @@ -3169,6 +3182,9 @@ def build_matrix_chunked( unique_constraint_vars.add(constraint["variable"]) base_entity_maps = build_household_entity_maps(sim) + target_variables = [ + str(targets_df.iloc[i]["variable"]) for i in range(n_targets) + ] if chunk_dir is None: chunk_root = Path(tempfile.mkdtemp(prefix="matrix_chunks_")) @@ -3177,23 +3193,7 @@ def build_matrix_chunked( chunk_root = Path(chunk_dir) remove_chunk_root = False coo_dir = chunk_root / "coo" - h5_dir = chunk_root / "h5" coo_dir.mkdir(parents=True, exist_ok=True) - h5_dir.mkdir(parents=True, exist_ok=True) - - n_chunks = (n_total + chunk_size - 1) // chunk_size - logger.info( - "Chunked matrix build: %d targets x %d columns, " - "%d chunks of up to %d columns", - n_targets, - n_total, - n_chunks, - chunk_size, - ) - - target_variables = [ - str(targets_df.iloc[i]["variable"]) for i in range(n_targets) - ] chunk_manifest_path = chunk_root / "chunk_manifest.json" chunk_signature = build_chunk_lineage_signature( dataset_path=self.dataset_path, @@ -3217,303 +3217,72 @@ def build_matrix_chunked( else: _save_chunk_manifest(chunk_manifest_path, chunk_signature) - t_build = time.time() - processed_chunk_times: list[float] = [] - cached_chunks = 0 - - for chunk_id, col_start in enumerate(range(0, n_total, chunk_size)): - t_chunk = time.time() - col_end = min(col_start + chunk_size, n_total) - coo_path = coo_dir / f"chunk_{chunk_id:06d}.npz" - h5_path = h5_dir / f"chunk_{chunk_id:06d}.h5" - - if resume_chunks and coo_path.exists(): - with np.load(str(coo_path)) as cached_chunk: - if "col_start" not in cached_chunk or "col_end" not in cached_chunk: - raise ValueError( - f"Cached chunk {coo_path} is missing col_start/col_end metadata" - ) - cached_col_start = int(np.asarray(cached_chunk["col_start"]).item()) - cached_col_end = int(np.asarray(cached_chunk["col_end"]).item()) - if cached_col_start != col_start or cached_col_end != col_end: - raise ValueError( - f"Cached chunk {coo_path} covers cols {cached_col_start}-{cached_col_end - 1}, " - f"expected {col_start}-{col_end - 1}" - ) - cached_chunks += 1 - logger.info( - "Chunk %d/%d cached: cols %d-%d, cached=%d", - chunk_id + 1, - n_chunks, - col_start, - col_end - 1, - cached_chunks, - ) - continue - - global_cols = np.arange(col_start, col_end, dtype=np.int64) - active_hh = global_cols % n_records - active_clone_indices = global_cols // n_records - active_blocks = np.asarray(geography.block_geoid)[global_cols] - active_cd_geoids = np.asarray(geography.cd_geoid, dtype=str)[global_cols] - active_states = np.asarray(geography.state_fips)[global_cols] - active_counties = np.asarray(geography.county_fips, dtype=str)[global_cols] - - summary = materialize_clone_household_chunk( - sim=sim, - entity_maps=base_entity_maps, - active_hh=active_hh, - active_blocks=active_blocks, - active_cd_geoids=active_cd_geoids, - active_clone_indices=active_clone_indices, - output_path=h5_path, - apply_takeup=rerandomize_takeup, - ) - - chunk_sim = Microsimulation(dataset=str(h5_path)) - chunk_n = len(global_cols) - self._entity_rel_cache = None - entity_rel = self._build_entity_relationship(chunk_sim) - household_ids = chunk_sim.calculate( - 
"household_id", - map_to="household", - ).values - entity_hh_idx_map, person_to_entity_idx_map = _build_entity_index_maps( - entity_rel, - household_ids, - chunk_sim, - ) - - variable_entity_map: Dict[str, str] = {} - hh_vars = {} - target_entity_vars = {} - for variable in sorted(unique_variables): - if variable in chunk_sim.tax_benefit_system.variables: - variable_entity_map[variable] = ( - chunk_sim.tax_benefit_system.variables[variable].entity.key - ) - if variable.endswith("_count"): - continue - try: - hh_vars[variable] = chunk_sim.calculate( - variable, - self.time_period, - map_to="household", - ).values.astype(np.float32) - except Exception as exc: - logger.warning( - "Chunk %d cannot calculate target '%s': %s", - chunk_id, - variable, - exc, - ) - entity_key = variable_entity_map.get(variable, "household") - if entity_key == "household": - continue - try: - target_entity_vars[variable] = chunk_sim.calculate( - variable, - self.time_period, - map_to=entity_key, - ).values.astype(np.float32) - except Exception as exc: - logger.warning( - "Chunk %d cannot calculate entity-level target '%s' " - "(map_to=%s): %s", - chunk_id, - variable, - entity_key, - exc, - ) - - person_vars = {} - for variable in sorted(unique_constraint_vars): - try: - raw = chunk_sim.calculate( - variable, - self.time_period, - map_to="person", - ).values - try: - person_vars[variable] = raw.astype(np.float32) - except (ValueError, TypeError): - person_vars[variable] = raw - except Exception as exc: - logger.warning( - "Chunk %d cannot calculate constraint '%s': %s", - chunk_id, - variable, - exc, - ) - - reform_hh_vars = {} - if reform_variables: - baseline_income_tax = chunk_sim.calculate( - "income_tax", - self.time_period, - map_to="household", - ).values.astype(np.float32) - for variable in sorted(reform_variables): - try: - reform_sim = Microsimulation( - dataset=str(h5_path), - reform=_make_neutralize_variable_reform(variable), - ) - reform_income_tax = reform_sim.calculate( - "income_tax", - self.time_period, - map_to="household", - ).values.astype(np.float32) - reform_hh_vars[variable] = ( - reform_income_tax - baseline_income_tax - ) - except Exception as exc: - logger.warning( - "Chunk %d cannot calculate reform target '%s': %s", - chunk_id, - variable, - exc, - ) - target_value_cache: Dict[tuple, np.ndarray] = {} - rows_list: list = [] - cols_list: list = [] - vals_list: list = [] - - for row_idx in range(n_targets): - variable = target_variables[row_idx] - reform_id = target_reform_ids[row_idx] - geo_level, geo_id = target_geo_info[row_idx] - non_geo = non_geo_constraints_list[row_idx] - - if geo_level == "district": - geo_mask = active_cd_geoids == str(geo_id) - elif geo_level == "state": - geo_mask = active_states.astype(np.int64) == int(geo_id) - elif geo_level == "county": - geo_mask = active_counties == str(geo_id).zfill(5) - else: - geo_mask = np.ones(chunk_n, dtype=bool) - - if not geo_mask.any(): - continue - - constraint_key = tuple( - sorted( - ( - c["variable"], - c["operation"], - c["value"], - ) - for c in non_geo - ) + shared_state = SharedBuildState( + source_dataset_path=str(self.dataset_path), + time_period=self.time_period, + rerandomize_takeup=rerandomize_takeup, + n_records=n_records, + n_clones=n_clones, + n_targets=n_targets, + chunk_size=chunk_size, + target_variables=target_variables, + target_reform_ids=target_reform_ids, + target_geo_info=target_geo_info, + non_geo_constraints_list=non_geo_constraints_list, + unique_variables=unique_variables, + 
unique_constraint_vars=unique_constraint_vars, + reform_variables=reform_variables, + target_names=target_names, + base_entity_maps=base_entity_maps, + block_geoid=np.asarray(geography.block_geoid), + cd_geoid=np.asarray(geography.cd_geoid, dtype=str), + county_fips=np.asarray(geography.county_fips, dtype=str), + state_fips=np.asarray(geography.state_fips), + ) + assembler = ChunkedMatrixAssembler( + shared_state=shared_state, + chunk_root=chunk_root, + chunk_size=chunk_size, + resume=resume_chunks, + keep_chunks=keep_chunks, + base_sim=sim, + ) + logger.info( + "Chunked matrix build: %d targets x %d columns, " + "%d chunks of up to %d columns%s", + n_targets, + n_total, + assembler.n_chunks, + chunk_size, + ( + f", parallel across {num_matrix_workers} Modal workers" + if parallel + else "" + ), + ) + if parallel: + if not run_id: + raise ValueError( + "run_id is required when parallel=True so workers can " + "find the shared state on the pipeline volume" ) - - value_key = (variable, constraint_key, reform_id) - if value_key not in target_value_cache: - target_value_cache[value_key] = _calculate_target_values_standalone( - target_variable=variable, - non_geo_constraints=non_geo, - n_households=chunk_n, - hh_vars=hh_vars, - reform_hh_vars=reform_hh_vars, - target_entity_vars=target_entity_vars, - person_vars=person_vars, - entity_rel=entity_rel, - household_ids=household_ids, - variable_entity_map=variable_entity_map, - entity_hh_idx_map=entity_hh_idx_map, - person_to_entity_idx_map=person_to_entity_idx_map, - reform_id=reform_id, - ) - values = target_value_cache[value_key] - - vals = values[geo_mask] - nonzero = vals != 0 - if nonzero.any(): - rows_list.append(np.full(nonzero.sum(), row_idx, dtype=np.int32)) - cols_list.append(global_cols[geo_mask][nonzero].astype(np.int32)) - vals_list.append(vals[nonzero].astype(np.float32, copy=False)) - - if rows_list: - rows = np.concatenate(rows_list) - cols = np.concatenate(cols_list) - vals = np.concatenate(vals_list) - else: - rows = np.array([], dtype=np.int32) - cols = np.array([], dtype=np.int32) - vals = np.array([], dtype=np.float32) - - np.savez_compressed( - str(coo_path), - rows=rows, - cols=cols, - vals=vals, - col_start=np.array([col_start], dtype=np.int64), - col_end=np.array([col_end], dtype=np.int64), + from policyengine_us_data.calibration.chunked_matrix_modal import ( + dispatch_chunks_modal, ) - if not keep_chunks and h5_path.exists(): - h5_path.unlink() - - chunk_seconds = time.time() - t_chunk - processed_chunk_times.append(chunk_seconds) - remaining_chunks = n_chunks - (chunk_id + 1) - avg_seconds = float(np.mean(processed_chunk_times)) - eta_seconds = avg_seconds * remaining_chunks - elapsed_seconds = time.time() - t_build - rss = _current_rss_mb() - rss_part = f", rss={rss:,.0f} MB" if rss is not None else "" - logger.info( - "Chunk %d/%d: cols %d-%d, hh=%d, persons=%d, " - "states=%d, counties=%d, cds=%d, nnz=%d, " - "chunk=%s, avg=%s, elapsed=%s, eta=%s%s", - chunk_id + 1, - n_chunks, - col_start, - col_end - 1, - summary.n_households, - summary.n_persons, - summary.unique_states, - summary.unique_counties, - summary.unique_cds, - len(vals), - _format_duration(chunk_seconds), - _format_duration(avg_seconds), - _format_duration(elapsed_seconds), - _format_duration(eta_seconds), - rss_part, + X_csr = dispatch_chunks_modal( + shared_state=shared_state, + chunk_root=chunk_root, + run_id=run_id, + num_workers=num_matrix_workers, ) - - del chunk_sim, hh_vars, person_vars, reform_hh_vars - - logger.info("Assembling matrix from %d 
chunk files...", n_chunks) - all_rows, all_cols, all_vals = [], [], [] - for chunk_id in range(n_chunks): - path = coo_dir / f"chunk_{chunk_id:06d}.npz" - data = np.load(str(path)) - all_rows.append(data["rows"]) - all_cols.append(data["cols"]) - all_vals.append(data["vals"]) - - rows = np.concatenate(all_rows) if all_rows else np.array([], dtype=np.int32) - cols = np.concatenate(all_cols) if all_cols else np.array([], dtype=np.int32) - vals = np.concatenate(all_vals) if all_vals else np.array([], dtype=np.float32) - - X_csr = sparse.csr_matrix( - (vals, (rows, cols)), - shape=(n_targets, n_total), - dtype=np.float32, - ) - logger.info( - "Chunked matrix: %d targets x %d cols, %d nnz", - X_csr.shape[0], - X_csr.shape[1], - X_csr.nnz, - ) + else: + assembler.run_chunks(range(assembler.n_chunks)) + X_csr = assembler.assemble_final() if remove_chunk_root and chunk_root.exists(): shutil.rmtree(chunk_root) - elif not keep_chunks and h5_dir.exists(): - shutil.rmtree(h5_dir) + elif not keep_chunks and assembler.h5_dir.exists(): + shutil.rmtree(assembler.h5_dir) return targets_df, X_csr, target_names diff --git a/tests/integration/test_matrix_chunk_worker_modal.py b/tests/integration/test_matrix_chunk_worker_modal.py new file mode 100644 index 000000000..c5f2cd26c --- /dev/null +++ b/tests/integration/test_matrix_chunk_worker_modal.py @@ -0,0 +1,76 @@ +"""Env-gated Modal smoke test for the matrix-chunk worker. + +Skipped by default. Runs only when all of: + - ``MODAL_TOKEN_ID`` and ``MODAL_TOKEN_SECRET`` are set (Modal auth) + - ``POLICYENGINE_US_DATA_MODAL_SMOKE=1`` is set + +Assumes the pipeline app (which ``.include()``s the worker) has been +deployed via: + + modal deploy modal_app/pipeline.py + +The worker's ``@app.function`` decorator attaches to ``_calibration_app`` +(``policyengine-us-data-fit-weights``) at the Python level, but +``pipeline.py`` merges that app into the pipeline app. After deploy, +the function is registered in Modal's registry under the pipeline +app's name — that's the name this test looks up. + +This test validates two things without running a full chunk build: + 1. The Modal worker function is discoverable via + ``modal.Function.from_name`` — catches "worker not deployed" and + signature mismatches at test time rather than at pipeline time. + 2. ``partition_chunk_ids_contiguous`` produces the same batches the + coordinator would send — sanity on the shape we'd ship. + +A true end-to-end fan-out smoke (write shared state, spawn, verify +shards) requires the pipeline volume to hold a real fixture dataset; +that is a pre-merge manual step for phase 2, documented in the PR. 
+""" + +from __future__ import annotations + +import os + +import pytest + + +_SMOKE_ENABLED = ( + os.environ.get("MODAL_TOKEN_ID") + and os.environ.get("MODAL_TOKEN_SECRET") + and os.environ.get("POLICYENGINE_US_DATA_MODAL_SMOKE") == "1" +) + + +pytestmark = pytest.mark.skipif( + not _SMOKE_ENABLED, + reason=( + "Modal smoke test; set MODAL_TOKEN_ID, MODAL_TOKEN_SECRET, and " + "POLICYENGINE_US_DATA_MODAL_SMOKE=1 to enable" + ), +) + + +def test_worker_function_is_deployed() -> None: + """The deployed worker must be lookupable under the pipeline app.""" + import modal + + worker = modal.Function.from_name( + "policyengine-us-data-pipeline", + "build_matrix_chunk_worker", + ) + assert worker is not None + + +def test_dispatch_contiguous_batching_shape() -> None: + """The batching produced for a typical production-scale run is sane.""" + from policyengine_us_data.calibration.chunked_matrix_modal import ( + partition_chunk_ids_contiguous, + ) + + # ~207 chunks at production; check the 50-worker default shape. + batches = partition_chunk_ids_contiguous(n_chunks=207, num_workers=50) + assert len(batches) <= 50 + assert sum(len(b) for b in batches) == 207 + # Contiguous: every batch is a range of consecutive ids. + for batch in batches: + assert list(batch) == list(range(batch[0], batch[0] + len(batch))) diff --git a/tests/unit/calibration/test_chunked_matrix_assembler.py b/tests/unit/calibration/test_chunked_matrix_assembler.py new file mode 100644 index 000000000..589867880 --- /dev/null +++ b/tests/unit/calibration/test_chunked_matrix_assembler.py @@ -0,0 +1,356 @@ +"""Unit tests for ``ChunkedMatrixAssembler`` pure helpers and class. + +Tests here avoid constructing a real ``Microsimulation``. The kernel +(``run_single_chunk``) is exercised by the integration suite at +``tests/integration/test_chunked_matrix_builder.py``. This file +covers: chunk partitioning, streaming CSR assembly (including memory +profile), and resume-skip behaviour on pre-staged shards. 
+""" + +from __future__ import annotations + +import gc +from pathlib import Path +from typing import List +from unittest import mock + +import numpy as np +import pytest +from scipy import sparse + +from policyengine_us_data.calibration.chunked_matrix_assembler import ( + ChunkResult, + ChunkedMatrixAssembler, + SharedBuildState, + partition_chunks, + stream_csr_from_shards, +) + + +def _write_shard( + shard_dir: Path, + chunk_id: int, + rows: np.ndarray, + cols: np.ndarray, + vals: np.ndarray, + col_start: int, + col_end: int, +) -> Path: + path = shard_dir / f"chunk_{chunk_id:06d}.npz" + np.savez_compressed( + str(path), + rows=np.asarray(rows, dtype=np.int32), + cols=np.asarray(cols, dtype=np.int32), + vals=np.asarray(vals, dtype=np.float32), + col_start=np.array([col_start], dtype=np.int64), + col_end=np.array([col_end], dtype=np.int64), + ) + return path + + +def _make_shared_state( + n_records: int = 10, + n_clones: int = 2, + n_targets: int = 3, +) -> SharedBuildState: + """Minimal ``SharedBuildState`` for tests that don't run the kernel.""" + n_total = n_records * n_clones + return SharedBuildState( + source_dataset_path="/nonexistent/fixture.h5", + time_period=2024, + rerandomize_takeup=False, + n_records=n_records, + n_clones=n_clones, + n_targets=n_targets, + chunk_size=10, + target_variables=["x"] * n_targets, + target_reform_ids=[0] * n_targets, + target_geo_info=[("national", "US")] * n_targets, + non_geo_constraints_list=[[] for _ in range(n_targets)], + unique_variables={"x"}, + unique_constraint_vars=set(), + reform_variables=set(), + target_names=[f"t{i}" for i in range(n_targets)], + base_entity_maps=None, + block_geoid=np.zeros(n_total, dtype="U15"), + cd_geoid=np.zeros(n_total, dtype="U4"), + county_fips=np.zeros(n_total, dtype="U5"), + state_fips=np.zeros(n_total, dtype=np.int32), + ) + + +# ----------------------------------------------------------------------- +# partition_chunks +# ----------------------------------------------------------------------- + + +def test_partition_chunks_exact_multiple(tmp_path: Path) -> None: + plans = partition_chunks( + n_total=1000, chunk_size=250, coo_dir=tmp_path / "coo", h5_dir=tmp_path / "h5" + ) + assert len(plans) == 4 + assert [p.col_start for p in plans] == [0, 250, 500, 750] + assert [p.col_end for p in plans] == [250, 500, 750, 1000] + assert plans[0].coo_path.name == "chunk_000000.npz" + assert plans[3].coo_path.name == "chunk_000003.npz" + + +def test_partition_chunks_remainder(tmp_path: Path) -> None: + plans = partition_chunks( + n_total=1050, chunk_size=250, coo_dir=tmp_path / "coo", h5_dir=tmp_path / "h5" + ) + assert len(plans) == 5 + assert plans[-1].col_start == 1000 + assert plans[-1].col_end == 1050 + + +def test_partition_chunks_zero_total(tmp_path: Path) -> None: + plans = partition_chunks( + n_total=0, chunk_size=100, coo_dir=tmp_path / "coo", h5_dir=tmp_path / "h5" + ) + assert plans == [] + + +def test_partition_chunks_rejects_invalid_chunk_size(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="chunk_size"): + partition_chunks(n_total=10, chunk_size=0, coo_dir=tmp_path, h5_dir=tmp_path) + + +# ----------------------------------------------------------------------- +# stream_csr_from_shards +# ----------------------------------------------------------------------- + + +def test_stream_csr_from_shards_matches_coo_reference(tmp_path: Path) -> None: + shard_dir = tmp_path + # Three chunks, 4 cols per chunk, 5 targets. 
+    # Hand-built entries that span multiple rows and include duplicates
+    # across shards (different cols) to exercise the scatter pass.
+    _write_shard(
+        shard_dir,
+        0,
+        rows=[0, 2, 4],
+        cols=[0, 2, 3],
+        vals=[1.0, 2.0, 3.0],
+        col_start=0,
+        col_end=4,
+    )
+    _write_shard(
+        shard_dir,
+        1,
+        rows=[0, 1, 2],
+        cols=[4, 5, 7],
+        vals=[4.0, 5.0, 6.0],
+        col_start=4,
+        col_end=8,
+    )
+    _write_shard(
+        shard_dir,
+        2,
+        rows=[],  # empty shard
+        cols=[],
+        vals=[],
+        col_start=8,
+        col_end=12,
+    )
+    n_targets = 5
+    n_total = 12
+
+    X = stream_csr_from_shards(
+        shard_dir, n_chunks=3, n_targets=n_targets, n_total=n_total
+    )
+
+    # Reference: build the same CSR via scipy's COO constructor.
+    all_rows = np.array([0, 2, 4, 0, 1, 2], dtype=np.int32)
+    all_cols = np.array([0, 2, 3, 4, 5, 7], dtype=np.int32)
+    all_vals = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=np.float32)
+    reference = sparse.csr_matrix(
+        (all_vals, (all_rows, all_cols)),
+        shape=(n_targets, n_total),
+        dtype=np.float32,
+    )
+    reference.sort_indices()
+
+    assert X.shape == reference.shape
+    assert X.nnz == reference.nnz
+    assert np.array_equal(X.indptr, reference.indptr)
+    assert np.array_equal(X.indices, reference.indices)
+    assert np.allclose(X.data, reference.data)
+
+
+def test_stream_csr_from_shards_all_empty(tmp_path: Path) -> None:
+    shard_dir = tmp_path
+    for chunk_id in range(2):
+        _write_shard(
+            shard_dir,
+            chunk_id,
+            rows=[],
+            cols=[],
+            vals=[],
+            col_start=chunk_id * 4,
+            col_end=(chunk_id + 1) * 4,
+        )
+    X = stream_csr_from_shards(shard_dir, n_chunks=2, n_targets=3, n_total=8)
+    assert X.shape == (3, 8)
+    assert X.nnz == 0
+
+
+def test_stream_csr_memory_within_bound(tmp_path: Path) -> None:
+    """Peak RSS during streaming assembly should not exceed the final
+    CSR arrays plus one shard by a wide margin. Guards against a future
+    regression to list-then-concat assembly.
+    """
+    psutil = pytest.importorskip("psutil")
+    rng = np.random.default_rng(0)
+    n_targets = 500
+    n_total = 500_000
+    nnz_per_shard = 50_000
+    n_chunks = 8  # ~400k total nnz
+
+    for chunk_id in range(n_chunks):
+        col_start = chunk_id * (n_total // n_chunks)
+        col_end = col_start + (n_total // n_chunks)
+        rows = rng.integers(0, n_targets, size=nnz_per_shard, dtype=np.int32)
+        cols = rng.integers(col_start, col_end, size=nnz_per_shard, dtype=np.int32)
+        vals = rng.standard_normal(nnz_per_shard).astype(np.float32)
+        _write_shard(tmp_path, chunk_id, rows, cols, vals, col_start, col_end)
+
+    gc.collect()
+    process = psutil.Process()
+    baseline = process.memory_info().rss
+
+    # Take a single RSS reading right after assembly returns, while the
+    # result's arrays are still live. This misses transient peaks inside
+    # the call, but it is simpler than a sampling thread and sufficient
+    # for a coarse bound check.
+    X = stream_csr_from_shards(
+        tmp_path, n_chunks=n_chunks, n_targets=n_targets, n_total=n_total
+    )
+    peak = process.memory_info().rss
+
+    final_csr_bytes = X.data.nbytes + X.indices.nbytes + X.indptr.nbytes
+    # Bound: 4x final CSR + 32 MB slack. The slack absorbs interpreter
+    # overhead and heap fragmentation on shared CI runners; the 4x term
+    # is the part a list-then-concat regression would overrun once the
+    # shards' combined COO arrays dwarf the slack.
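+    # Concretely: nnz = 8 * 50_000 = 400_000 entries, so data + indices
+    # come to roughly 400_000 * (4 + 4) bytes ≈ 3.2 MB with 32-bit
+    # indices (indptr is ~2 KB), which puts the bound below at about
+    # 4 * 3.2 MB + 32 MB ≈ 45 MB of allowed growth.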
+    bound = 4 * final_csr_bytes + 32 * 1024 * 1024
+    delta = peak - baseline
+    assert delta < bound, (
+        f"Peak RSS grew by {delta:,} bytes; expected <{bound:,} "
+        f"(final CSR is {final_csr_bytes:,} bytes)"
+    )
+
+
+# -----------------------------------------------------------------------
+# ChunkedMatrixAssembler resume semantics
+# -----------------------------------------------------------------------
+
+
+def test_assembler_skips_existing_shards_when_resume(tmp_path: Path) -> None:
+    """A pre-staged shard with matching col_start/col_end should skip
+    the kernel and return a cached ``ChunkResult``.
+    """
+    state = _make_shared_state(n_records=10, n_clones=2, n_targets=3)
+    assembler = ChunkedMatrixAssembler(
+        shared_state=state,
+        chunk_root=tmp_path,
+        chunk_size=10,
+        resume=True,
+        keep_chunks=False,
+    )
+    # Two plans: cols 0-9 and 10-19.
+    plan0 = assembler.plans[0]
+    _write_shard(
+        assembler.coo_dir,
+        0,
+        rows=np.array([1, 2], dtype=np.int32),
+        cols=np.array([3, 7], dtype=np.int32),
+        vals=np.array([1.0, 2.0], dtype=np.float32),
+        col_start=plan0.col_start,
+        col_end=plan0.col_end,
+    )
+    # Run only chunk 0; kernel would fail (sim=None, fixture path does
+    # not exist), so hitting the cache path is proof of the skip.
+    result = assembler.run_single_chunk(0)
+    assert result.cached is True
+    assert result.nnz == 2
+    assert result.chunk_id == 0
+
+
+def test_assembler_rejects_shard_with_mismatched_range(tmp_path: Path) -> None:
+    state = _make_shared_state(n_records=10, n_clones=2, n_targets=3)
+    assembler = ChunkedMatrixAssembler(
+        shared_state=state,
+        chunk_root=tmp_path,
+        chunk_size=10,
+        resume=True,
+        keep_chunks=False,
+    )
+    # Write a shard whose metadata claims col_start=5 (not 0).
+    _write_shard(
+        assembler.coo_dir,
+        0,
+        rows=np.array([], dtype=np.int32),
+        cols=np.array([], dtype=np.int32),
+        vals=np.array([], dtype=np.float32),
+        col_start=5,
+        col_end=15,
+    )
+    with pytest.raises(ValueError, match="expected 0-9"):
+        assembler.run_single_chunk(0)
+
+
+def test_shared_build_state_roundtrips_pickle() -> None:
+    """Pickling ``SharedBuildState`` and loading it back must preserve
+    every field. This guards phase-2 Modal dispatch where each worker
+    receives its state by reading a pickle file from the volume.
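+
+    The numpy array fields are checked with ``np.array_equal`` below
+    because ``==`` on arrays broadcasts elementwise rather than giving
+    a single truth value.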
+    """
+    import pickle
+
+    state = _make_shared_state(n_records=12, n_clones=5, n_targets=7)
+    blob = pickle.dumps(state)
+    restored = pickle.loads(blob)
+
+    assert restored.source_dataset_path == state.source_dataset_path
+    assert restored.time_period == state.time_period
+    assert restored.rerandomize_takeup == state.rerandomize_takeup
+    assert restored.n_records == state.n_records
+    assert restored.n_clones == state.n_clones
+    assert restored.n_targets == state.n_targets
+    assert restored.chunk_size == state.chunk_size
+    assert restored.n_total == state.n_total
+    assert restored.target_variables == state.target_variables
+    assert restored.target_reform_ids == state.target_reform_ids
+    assert restored.target_geo_info == state.target_geo_info
+    assert restored.non_geo_constraints_list == state.non_geo_constraints_list
+    assert restored.unique_variables == state.unique_variables
+    assert restored.unique_constraint_vars == state.unique_constraint_vars
+    assert restored.reform_variables == state.reform_variables
+    assert restored.target_names == state.target_names
+    assert np.array_equal(restored.block_geoid, state.block_geoid)
+    assert np.array_equal(restored.cd_geoid, state.cd_geoid)
+    assert np.array_equal(restored.county_fips, state.county_fips)
+    assert np.array_equal(restored.state_fips, state.state_fips)
+
+
+def test_assembler_run_chunks_dispatches_each_id(tmp_path: Path) -> None:
+    state = _make_shared_state(n_records=10, n_clones=3, n_targets=2)
+    assembler = ChunkedMatrixAssembler(
+        shared_state=state,
+        chunk_root=tmp_path,
+        chunk_size=10,
+        resume=False,
+        keep_chunks=False,
+    )
+    assert assembler.n_chunks == 3
+
+    observed: List[int] = []
+
+    def fake_run(chunk_id: int) -> ChunkResult:
+        observed.append(chunk_id)
+        return ChunkResult(chunk_id=chunk_id, nnz=0, cached=False)
+
+    with mock.patch.object(assembler, "run_single_chunk", side_effect=fake_run):
+        results = assembler.run_chunks([0, 2])
+
+    assert observed == [0, 2]
+    assert [r.chunk_id for r in results] == [0, 2]
diff --git a/tests/unit/calibration/test_chunked_matrix_modal.py b/tests/unit/calibration/test_chunked_matrix_modal.py
new file mode 100644
index 000000000..7dc50a489
--- /dev/null
+++ b/tests/unit/calibration/test_chunked_matrix_modal.py
@@ -0,0 +1,257 @@
+"""Unit tests for the Modal dispatch layer.
+
+Covers the pure partition helper plus ``dispatch_chunks_modal`` under
+a fake worker function (via injected ``worker_function`` and
+``volume`` overrides). No real Modal calls.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Dict, List, Optional
+from unittest import mock
+
+import numpy as np
+import pytest
+
+from policyengine_us_data.calibration.chunked_matrix_assembler import (
+    SharedBuildState,
+)
+from policyengine_us_data.calibration.chunked_matrix_modal import (
+    dispatch_chunks_modal,
+    partition_chunk_ids_contiguous,
+)
+
+
+# -----------------------------------------------------------------------
+# partition_chunk_ids_contiguous
+# -----------------------------------------------------------------------
+
+
+def test_contiguous_batch_exact_division() -> None:
+    batches = partition_chunk_ids_contiguous(n_chunks=12, num_workers=3)
+    assert batches == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
+
+
+def test_contiguous_batch_remainder() -> None:
+    batches = partition_chunk_ids_contiguous(n_chunks=10, num_workers=3)
+    # ceil(10/3) = 4, so batches are [0..3], [4..7], [8..9]: 3 workers.
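+    # (A plausible implementation, consistent with every case in this
+    # file but not asserted to be the real one:
+    #     if n_chunks == 0:
+    #         return []
+    #     step = math.ceil(n_chunks / num_workers)
+    #     return [list(range(i, min(i + step, n_chunks)))
+    #             for i in range(0, n_chunks, step)])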
+    assert batches == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
+    assert sum(len(b) for b in batches) == 10
+
+
+def test_contiguous_batch_more_workers_than_chunks() -> None:
+    batches = partition_chunk_ids_contiguous(n_chunks=3, num_workers=10)
+    # ceil(3/10) = 1, so at most 3 non-empty batches.
+    assert batches == [[0], [1], [2]]
+
+
+def test_contiguous_batch_zero_chunks() -> None:
+    assert partition_chunk_ids_contiguous(n_chunks=0, num_workers=5) == []
+
+
+def test_contiguous_batch_rejects_non_positive_workers() -> None:
+    with pytest.raises(ValueError, match="num_workers"):
+        partition_chunk_ids_contiguous(n_chunks=5, num_workers=0)
+
+
+# -----------------------------------------------------------------------
+# dispatch_chunks_modal with injected worker + volume fakes
+# -----------------------------------------------------------------------
+
+
+def _minimal_shared_state(
+    n_records: int = 10, n_clones: int = 2, chunk_size: int = 10
+) -> SharedBuildState:
+    n_total = n_records * n_clones
+    return SharedBuildState(
+        source_dataset_path="/nonexistent/fixture.h5",
+        time_period=2024,
+        rerandomize_takeup=False,
+        n_records=n_records,
+        n_clones=n_clones,
+        n_targets=2,
+        chunk_size=chunk_size,
+        target_variables=["x", "y"],
+        target_reform_ids=[0, 0],
+        target_geo_info=[("national", "US"), ("national", "US")],
+        non_geo_constraints_list=[[], []],
+        unique_variables={"x", "y"},
+        unique_constraint_vars=set(),
+        reform_variables=set(),
+        target_names=["t0", "t1"],
+        base_entity_maps=None,
+        block_geoid=np.zeros(n_total, dtype="U15"),
+        cd_geoid=np.zeros(n_total, dtype="U4"),
+        county_fips=np.zeros(n_total, dtype="U5"),
+        state_fips=np.zeros(n_total, dtype=np.int32),
+    )
+
+
+class _FakeHandle:
+    def __init__(self, result: Dict, *, raise_on_get: Optional[Exception] = None):
+        self._result = result
+        self._raise = raise_on_get
+        self.object_id = "fc-fake"
+
+    def get(self) -> Dict:
+        if self._raise is not None:
+            raise self._raise
+        return self._result
+
+
+class _FakeVolume:
+    def __init__(self) -> None:
+        self.commit_count = 0
+        self.reload_count = 0
+
+    def commit(self) -> None:
+        self.commit_count += 1
+
+    def reload(self) -> None:
+        self.reload_count += 1
+
+
+def _write_fake_shard(
+    shard_dir: Path, chunk_id: int, col_start: int, col_end: int
+) -> None:
+    shard_dir.mkdir(parents=True, exist_ok=True)
+    np.savez_compressed(
+        str(shard_dir / f"chunk_{chunk_id:06d}.npz"),
+        rows=np.array([0], dtype=np.int32),
+        cols=np.array([col_start], dtype=np.int32),
+        vals=np.array([1.0], dtype=np.float32),
+        col_start=np.array([col_start], dtype=np.int64),
+        col_end=np.array([col_end], dtype=np.int64),
+    )
+
+
+def test_dispatch_spawns_per_batch_and_assembles(tmp_path: Path) -> None:
+    state = _minimal_shared_state(n_records=10, n_clones=4, chunk_size=10)
+    # n_total=40, chunk_size=10 -> 4 chunks. num_workers=2 -> 2 batches.
+    n_chunks = 4
+
+    # Fake worker writes shards as a side effect of .spawn() so that
+    # by the time assemble_final() runs, the shard files exist.
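+    # The coordinator sequence these assertions pin down, piecewise:
+    #   1. pickle SharedBuildState under chunk_root, volume.commit()
+    #   2. worker_function.spawn(run_id=..., chunk_ids=batch) per batch
+    #   3. handle.get() on every spawn, aggregating worker errors
+    #   4. volume.reload(), then stream the final CSR from the shards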
+    spawn_calls: List[Dict] = []
+
+    def fake_spawn(*, run_id: str, chunk_ids: List[int]) -> _FakeHandle:
+        spawn_calls.append({"run_id": run_id, "chunk_ids": list(chunk_ids)})
+        for chunk_id in chunk_ids:
+            col_start = chunk_id * state.chunk_size
+            col_end = col_start + state.chunk_size
+            _write_fake_shard(tmp_path / "coo", chunk_id, col_start, col_end)
+        return _FakeHandle(
+            {
+                "chunk_ids": list(chunk_ids),
+                "nnz_per_chunk": [1] * len(chunk_ids),
+                "errors": [],
+            }
+        )
+
+    fake_worker = mock.MagicMock()
+    fake_worker.spawn.side_effect = fake_spawn
+    volume = _FakeVolume()
+
+    X = dispatch_chunks_modal(
+        shared_state=state,
+        chunk_root=tmp_path,
+        run_id="run-test",
+        num_workers=2,
+        worker_function=fake_worker,
+        volume=volume,
+    )
+
+    # 2 batches spawned, each covering 2 contiguous chunk ids.
+    assert [c["chunk_ids"] for c in spawn_calls] == [[0, 1], [2, 3]]
+    # Every spawn carried the run_id.
+    assert all(c["run_id"] == "run-test" for c in spawn_calls)
+    # Final CSR covers all 4 chunks' nnz.
+    assert X.shape == (state.n_targets, state.n_total)
+    assert X.nnz == n_chunks
+    # Volume was committed (pre-spawn) and reloaded (pre-assemble).
+    assert volume.commit_count >= 1
+    assert volume.reload_count >= 1
+
+
+def test_dispatch_short_circuits_when_zero_chunks(tmp_path: Path) -> None:
+    state = _minimal_shared_state(n_records=0, n_clones=0, chunk_size=10)
+    fake_worker = mock.MagicMock()
+    volume = _FakeVolume()
+
+    X = dispatch_chunks_modal(
+        shared_state=state,
+        chunk_root=tmp_path,
+        run_id="run-test",
+        num_workers=4,
+        worker_function=fake_worker,
+        volume=volume,
+    )
+
+    assert X.shape == (state.n_targets, 0)
+    assert X.nnz == 0
+    fake_worker.spawn.assert_not_called()
+
+
+def test_dispatch_aggregates_worker_errors(tmp_path: Path) -> None:
+    state = _minimal_shared_state(n_records=10, n_clones=2, chunk_size=10)
+    # n_total=20, chunk_size=10 -> 2 chunks, 2 workers.
+
+    def fake_spawn(*, run_id: str, chunk_ids: List[int]) -> _FakeHandle:
+        # First worker returns a per-chunk error; second crashes in .get().
+        if chunk_ids == [0]:
+            return _FakeHandle(
+                {
+                    "chunk_ids": chunk_ids,
+                    "nnz_per_chunk": [],
+                    "errors": [{"chunk_id": 0, "error": "boom"}],
+                }
+            )
+        return _FakeHandle(None, raise_on_get=RuntimeError("worker oom"))
+
+    fake_worker = mock.MagicMock()
+    fake_worker.spawn.side_effect = fake_spawn
+    volume = _FakeVolume()
+
+    with pytest.raises(RuntimeError, match="Parallel chunked matrix build failed"):
+        dispatch_chunks_modal(
+            shared_state=state,
+            chunk_root=tmp_path,
+            run_id="run-test",
+            num_workers=2,
+            worker_function=fake_worker,
+            volume=volume,
+        )
+
+
+def test_dispatch_writes_shared_state_pickle(tmp_path: Path) -> None:
+    import pickle
+
+    state = _minimal_shared_state(n_records=5, n_clones=1, chunk_size=10)
+    # n_total=5, chunk_size=10 -> 1 chunk, 1 worker.
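+    # chunk_build_state.pkl under chunk_root is the coordinator/worker
+    # handoff contract: it must land on the volume before any spawn so
+    # workers can read it back; the assertions below pin its path and
+    # a sample of its contents.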
+ + def fake_spawn(*, run_id: str, chunk_ids: List[int]) -> _FakeHandle: + for chunk_id in chunk_ids: + _write_fake_shard(tmp_path / "coo", chunk_id, 0, 5) + return _FakeHandle({"chunk_ids": chunk_ids, "nnz_per_chunk": [1], "errors": []}) + + fake_worker = mock.MagicMock() + fake_worker.spawn.side_effect = fake_spawn + volume = _FakeVolume() + + dispatch_chunks_modal( + shared_state=state, + chunk_root=tmp_path, + run_id="run-test", + num_workers=1, + worker_function=fake_worker, + volume=volume, + ) + + pickle_path = tmp_path / "chunk_build_state.pkl" + assert pickle_path.exists() + with open(pickle_path, "rb") as f: + roundtripped = pickle.load(f) + assert roundtripped.n_total == state.n_total + assert roundtripped.chunk_size == state.chunk_size + assert roundtripped.target_names == state.target_names