diff --git a/fern/versions/latest/pages/concepts/workflow-chaining.mdx b/fern/versions/latest/pages/concepts/workflow-chaining.mdx index ed41363b2..2f7272452 100644 --- a/fern/versions/latest/pages/concepts/workflow-chaining.mdx +++ b/fern/versions/latest/pages/concepts/workflow-chaining.mdx @@ -138,6 +138,35 @@ workflow.add_stage("enriched", enriched) `on_success_version` is part of the stage resume identity. Change it when the callback's output semantics change. If a callback returns zero rows, the workflow raises by default; set `allow_empty=True` to mark that stage as completed empty and skip downstream stages. +## Repeating until a filtered count + +Use `repeat_until` when a stage should keep generating candidates until its selected output reaches a target row count. This is useful for bounded rejection sampling, such as generating many candidates and keeping only rows that pass a judge or quality gate. + +```python +from data_designer.interface import RepeatUntil + +workflow = data_designer.compose_workflow(name="judge-disagreements") +workflow.add_stage( + "judged", + judges, + num_records=1_000, + on_success=keep_disagreements, + on_success_version="disagreements-v1", + repeat_until=RepeatUntil( + output_records=5_000, + max_iterations=10, + max_generated_records=20_000, + ), +) +workflow.add_stage("enriched", enriched) +``` + +`num_records` is the per-attempt size. In the default `mode="append"`, each iteration requests the cumulative stage size (`num_records`, then `2 * num_records`, and so on), reruns `on_success` over the accumulated stage output, and, once the target is met, feeds exactly `output_records` selected rows downstream (surplus rows are trimmed unless `trim=False`). + +Set `on_exhausted="return_partial"` to keep the selected output of the last attempt when the bounds are reached before the target is met; otherwise the workflow raises. With `return_partial`, if no rows pass, the stage completes empty and downstream stages are skipped, matching `allow_empty=True` behavior. 
+ +Use `mode="discard"` when each attempt should replace the previous selected output instead of accumulating it. Discard mode restarts the stage on resume because previous attempts are intentionally replaced. Keep the limits tight: a low acceptance rate is usually a signal to inspect the recipe, not a reason to raise `max_iterations`. In append mode, `max_generated_records` caps the cumulative requested stage size; in discard mode, it caps the total records generated across all attempts. + ## Changing row counts between stages Each stage has a fixed requested row count while it runs. To resize a workflow, change the selected output at a stage boundary and let the next stage seed from that output. diff --git a/packages/data-designer/src/data_designer/interface/__init__.py b/packages/data-designer/src/data_designer/interface/__init__.py index febf02f55..a64119abb 100644 --- a/packages/data-designer/src/data_designer/interface/__init__.py +++ b/packages/data-designer/src/data_designer/interface/__init__.py @@ -11,6 +11,9 @@ from data_designer.interface.composite_workflow import ( # noqa: F401 CompositeWorkflow, CompositeWorkflowResults, + RepeatUntil, + RepeatUntilExhaustion, + RepeatUntilMode, SkippedStageResult, SkippedStageStatus, ) @@ -33,6 +36,9 @@ "DataDesignerWorkflowError": ("data_designer.interface.errors", "DataDesignerWorkflowError"), "DatasetCreationResults": ("data_designer.interface.results", "DatasetCreationResults"), "ResumeMode": ("data_designer.engine.storage.artifact_storage", "ResumeMode"), + "RepeatUntil": ("data_designer.interface.composite_workflow", "RepeatUntil"), + "RepeatUntilExhaustion": ("data_designer.interface.composite_workflow", "RepeatUntilExhaustion"), + "RepeatUntilMode": ("data_designer.interface.composite_workflow", "RepeatUntilMode"), "SkippedStageResult": ("data_designer.interface.composite_workflow", "SkippedStageResult"), "SkippedStageStatus": ("data_designer.interface.composite_workflow", "SkippedStageStatus"), } diff --git 
a/packages/data-designer/src/data_designer/interface/composite_workflow.py b/packages/data-designer/src/data_designer/interface/composite_workflow.py index 408083be5..0dcd26488 100644 --- a/packages/data-designer/src/data_designer/interface/composite_workflow.py +++ b/packages/data-designer/src/data_designer/interface/composite_workflow.py @@ -60,9 +60,69 @@ "callback_output_path", "output_processor_output_path", "stage_output_override_path", + "repeat_until_output_path", ) +class RepeatUntilMode(StrEnum): + APPEND = "append" + DISCARD = "discard" + + +class RepeatUntilExhaustion(StrEnum): + RAISE = "raise" + RETURN_PARTIAL = "return_partial" + + +@dataclass(frozen=True) +class RepeatUntil: + """Bounded stage-level retry policy for exact selected-output counts.""" + + output_records: int + max_iterations: int + mode: RepeatUntilMode | str = RepeatUntilMode.APPEND + max_generated_records: int | None = None + on_exhausted: RepeatUntilExhaustion | str = RepeatUntilExhaustion.RAISE + trim: bool = True + + def __post_init__(self) -> None: + if self.output_records < 1: + raise DataDesignerWorkflowError("repeat_until.output_records must be at least 1.") + if self.max_iterations < 1: + raise DataDesignerWorkflowError("repeat_until.max_iterations must be at least 1.") + if self.max_generated_records is not None and self.max_generated_records < 1: + raise DataDesignerWorkflowError("repeat_until.max_generated_records must be at least 1.") + try: + mode = RepeatUntilMode(self.mode) + except ValueError as exc: + raise DataDesignerWorkflowError( + f"repeat_until.mode must be one of: {_enum_values(RepeatUntilMode)}." + ) from exc + try: + on_exhausted = RepeatUntilExhaustion(self.on_exhausted) + except ValueError as exc: + raise DataDesignerWorkflowError( + f"repeat_until.on_exhausted must be one of: {_enum_values(RepeatUntilExhaustion)}." 
+ ) from exc + object.__setattr__(self, "mode", mode) + object.__setattr__(self, "on_exhausted", on_exhausted) + + +@dataclass(frozen=True) +class _StageRunResult: + output_result: DatasetCreationResults + actual_records: int + output_seed_path: Path + output_records: int + callback_output_path: Path | None + output_processor_output_path: Path | None + num_records_requested: int + repeat_iterations: int | None = None + repeat_generated_records: int | None = None + repeat_satisfied: bool | None = None + repeat_until_output_path: Path | None = None + + @dataclass(frozen=True) class _WorkflowStage: name: str @@ -76,6 +136,7 @@ class _WorkflowStage: allow_empty: bool sampling_strategy: SamplingStrategy selection_strategy: IndexRange | PartitionBlock | None + repeat_until: RepeatUntil | None class SkippedStageStatus(StrEnum): @@ -207,6 +268,7 @@ def add_stage( allow_empty: bool = False, sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED, selection_strategy: IndexRange | PartitionBlock | None = None, + repeat_until: RepeatUntil | None = None, ) -> CompositeWorkflow: """Add a stage to the workflow. @@ -214,7 +276,8 @@ def add_stage( ``output_processors`` for stage-boundary transforms whose output should feed downstream stages by default. ``output="processor:"`` selects a named processor artifact, and ``on_success`` can override the selected - output by returning a parquet file or directory. + output by returning a parquet file or directory. Use ``repeat_until`` to + rerun the stage until the selected output reaches an exact row count. 
""" _validate_dir_name(name, "stage name") if any(stage.name == name for stage in self._stages): @@ -238,6 +301,7 @@ def add_stage( allow_empty=allow_empty, sampling_strategy=sampling_strategy, selection_strategy=selection_strategy, + repeat_until=repeat_until, ) ) return self @@ -409,46 +473,23 @@ def run( start_time = time.monotonic() try: - result = self._data_designer.create( - stage_builder, + run_result = self._run_stage( + stage=stage, + stage_builder=stage_builder, + workflow_path=workflow_path, + stage_dir_name=stage_dir_name, + stage_path=stage_path, num_records=num_records, - dataset_name=stage_dir_name, - artifact_path=workflow_path, resume=stage_resume, ) - actual_records = result.count_records() - output_result = result - output_source_result = result - if stage.output_processors: - output_processor_path = stage_path / "output-processors" - if output_processor_path.exists(): - shutil.rmtree(output_processor_path) - output_processor_builder = _output_processor_config_builder( - stage_builder=stage_builder, - seed_path=result.artifact_storage.final_dataset_path, - output_processors=stage.output_processors, - ) - output_result = self._data_designer.create( - output_processor_builder, - num_records=actual_records, - dataset_name="output-processors", - artifact_path=workflow_path / stage_dir_name, - ) - output_source_result = _select_output_result(stage, result, output_result) - - callback_output_path = None - if stage.on_success is not None: - callback_output_path = Path(stage.on_success(result.artifact_storage.base_dataset_path)) - output_seed_path = callback_output_path - else: - output_seed_path = _resolve_stage_output_path(output_source_result, stage.output) + output_seed_path = run_result.output_seed_path override_path = _stage_output_override(stage.name, stage_output_overrides) if override_path is not None: output_seed_path = override_path output_records = _count_parquet_records(output_seed_path) if output_records == 0: - if not stage.allow_empty: + 
if not _allows_empty_stage_output(stage, run_result): raise DataDesignerWorkflowError(f"Stage {stage.name!r} produced an empty output.") status = "completed_empty" skipped_upstream_stage = stage.name @@ -458,29 +499,45 @@ def run( stage_metadata.update( { "status": status, - "num_records_actual": actual_records, + "num_records_requested": run_result.num_records_requested, + "num_records_actual": run_result.actual_records, "output_records": output_records, "output_seed_path": _metadata_path_value(workflow_path, output_seed_path), "callback_output_path": ( - _metadata_path_value(workflow_path, callback_output_path) if callback_output_path else None + _metadata_path_value(workflow_path, run_result.callback_output_path) + if run_result.callback_output_path + else None ), "stage_output_override_path": ( _metadata_path_value(workflow_path, override_path) if override_path else None ), "output_processor_output_path": ( - _metadata_path_value(workflow_path, output_result.artifact_storage.base_dataset_path) - if stage.output_processors + _metadata_path_value(workflow_path, run_result.output_processor_output_path) + if run_result.output_processor_output_path else None ), "duration_sec": time.monotonic() - start_time, } ) + if run_result.repeat_iterations is not None: + stage_metadata.update( + { + "repeat_iterations": run_result.repeat_iterations, + "repeat_generated_records": run_result.repeat_generated_records, + "repeat_satisfied": run_result.repeat_satisfied, + "repeat_until_output_path": ( + _metadata_path_value(workflow_path, run_result.repeat_until_output_path) + if run_result.repeat_until_output_path + else None + ), + } + ) except Exception: stage_metadata.update({"status": "failed", "duration_sec": time.monotonic() - start_time}) _write_workflow_metadata(workflow_path, metadata) raise - stage_results[stage.name] = output_result + stage_results[stage.name] = run_result.output_result stage_output_paths[stage.name] = output_seed_path previous_seed_path = 
output_seed_path previous_output_records = None if status == "completed_empty" else output_records @@ -496,11 +553,305 @@ def run( stage_output_paths=stage_output_paths, ) + def _run_stage( + self, + *, + stage: _WorkflowStage, + stage_builder: DataDesignerConfigBuilder, + workflow_path: Path, + stage_dir_name: str, + stage_path: Path, + num_records: int, + resume: ResumeMode, + ) -> _StageRunResult: + if stage.repeat_until is None: + return self._run_stage_attempt( + stage=stage, + stage_builder=stage_builder, + artifact_path=workflow_path, + dataset_name=stage_dir_name, + num_records=num_records, + resume=resume, + ) + if stage.repeat_until.mode == RepeatUntilMode.DISCARD: + return self._run_stage_until_discard( + stage=stage, + stage_builder=stage_builder, + workflow_path=workflow_path, + stage_dir_name=stage_dir_name, + stage_path=stage_path, + num_records=num_records, + resume=resume, + ) + return self._run_stage_until_append( + stage=stage, + stage_builder=stage_builder, + workflow_path=workflow_path, + stage_dir_name=stage_dir_name, + stage_path=stage_path, + num_records=num_records, + resume=resume, + ) + + def _run_stage_attempt( + self, + *, + stage: _WorkflowStage, + stage_builder: DataDesignerConfigBuilder, + artifact_path: Path, + dataset_name: str, + num_records: int, + resume: ResumeMode, + ) -> _StageRunResult: + result = self._data_designer.create( + stage_builder, + num_records=num_records, + dataset_name=dataset_name, + artifact_path=artifact_path, + resume=resume, + ) + actual_records = result.count_records() + output_result = result + output_source_result = result + stage_path = artifact_path / dataset_name + output_processor_output_path = None + if stage.output_processors: + output_processor_path = stage_path / "output-processors" + if output_processor_path.exists(): + shutil.rmtree(output_processor_path) + output_processor_builder = _output_processor_config_builder( + stage_builder=stage_builder, + 
seed_path=result.artifact_storage.final_dataset_path, + output_processors=stage.output_processors, + ) + output_result = self._data_designer.create( + output_processor_builder, + num_records=actual_records, + dataset_name="output-processors", + artifact_path=stage_path, + ) + output_source_result = _select_output_result(stage, result, output_result) + output_processor_output_path = output_result.artifact_storage.base_dataset_path + + callback_output_path = None + if stage.on_success is not None: + callback_output_path = Path(stage.on_success(result.artifact_storage.base_dataset_path)) + output_seed_path = callback_output_path + else: + output_seed_path = _resolve_stage_output_path(output_source_result, stage.output) + output_records = _count_parquet_records(output_seed_path) + return _StageRunResult( + output_result=output_result, + actual_records=actual_records, + output_seed_path=output_seed_path, + output_records=output_records, + callback_output_path=callback_output_path, + output_processor_output_path=output_processor_output_path, + num_records_requested=num_records, + ) + + def _run_stage_until_append( + self, + *, + stage: _WorkflowStage, + stage_builder: DataDesignerConfigBuilder, + workflow_path: Path, + stage_dir_name: str, + stage_path: Path, + num_records: int, + resume: ResumeMode, + ) -> _StageRunResult: + repeat_until = _require_repeat_until(stage) + last_result = None + for iteration in range(1, repeat_until.max_iterations + 1): + requested_records = num_records * iteration + if _exceeds_max_generated_records(repeat_until, requested_records): + break + attempt_resume = resume if iteration == 1 else ResumeMode.ALWAYS + last_result = self._run_stage_attempt( + stage=stage, + stage_builder=stage_builder, + artifact_path=workflow_path, + dataset_name=stage_dir_name, + num_records=requested_records, + resume=attempt_resume, + ) + if last_result.output_records >= repeat_until.output_records: + return _with_repeat_result( + last_result, + 
stage_path=stage_path, + repeat_until=repeat_until, + iterations=iteration, + generated_records=last_result.actual_records, + satisfied=True, + ) + + return _handle_repeat_until_exhausted( + stage=stage, + repeat_until=repeat_until, + last_result=last_result, + stage_path=stage_path, + iterations=(last_result.num_records_requested // num_records if last_result else 0), + generated_records=(last_result.actual_records if last_result else 0), + ) + + def _run_stage_until_discard( + self, + *, + stage: _WorkflowStage, + stage_builder: DataDesignerConfigBuilder, + workflow_path: Path, + stage_dir_name: str, + stage_path: Path, + num_records: int, + resume: ResumeMode, + ) -> _StageRunResult: + repeat_until = _require_repeat_until(stage) + if resume == ResumeMode.ALWAYS: + logger.warning( + "Stage %r uses repeat_until mode='discard'; previous attempts cannot be resumed and will be replaced.", + stage.name, + ) + last_result = None + generated_records = 0 + iterations_run = 0 + for iteration in range(1, repeat_until.max_iterations + 1): + if _exceeds_max_generated_records(repeat_until, generated_records + num_records): + break + if stage_path.exists(): + shutil.rmtree(stage_path) + last_result = self._run_stage_attempt( + stage=stage, + stage_builder=stage_builder, + artifact_path=workflow_path, + dataset_name=stage_dir_name, + num_records=num_records, + resume=ResumeMode.NEVER, + ) + iterations_run = iteration + generated_records += last_result.actual_records + if last_result.output_records >= repeat_until.output_records: + return _with_repeat_result( + last_result, + stage_path=stage_path, + repeat_until=repeat_until, + iterations=iteration, + generated_records=generated_records, + satisfied=True, + ) + + return _handle_repeat_until_exhausted( + stage=stage, + repeat_until=repeat_until, + last_result=last_result, + stage_path=stage_path, + iterations=iterations_run, + generated_records=generated_records, + ) + def _stage_indices_by_name(stages: list[_WorkflowStage]) -> 
dict[str, int]: return {stage.name: index for index, stage in enumerate(stages)} +def _require_repeat_until(stage: _WorkflowStage) -> RepeatUntil: + if stage.repeat_until is None: + raise DataDesignerWorkflowError(f"Stage {stage.name!r} has no repeat_until policy.") + return stage.repeat_until + + +def _exceeds_max_generated_records(repeat_until: RepeatUntil, generated_records: int) -> bool: + return repeat_until.max_generated_records is not None and generated_records > repeat_until.max_generated_records + + +def _with_repeat_result( + result: _StageRunResult, + *, + stage_path: Path, + repeat_until: RepeatUntil, + iterations: int, + generated_records: int, + satisfied: bool, +) -> _StageRunResult: + output_seed_path = result.output_seed_path + output_records = result.output_records + repeat_until_output_path = None + if output_records > repeat_until.output_records and repeat_until.trim: + repeat_until_output_path = stage_path / "repeat-until" / "selected-output" + _write_parquet_head(output_seed_path, repeat_until_output_path, repeat_until.output_records) + output_seed_path = repeat_until_output_path + output_records = repeat_until.output_records + return _StageRunResult( + output_result=result.output_result, + actual_records=result.actual_records, + output_seed_path=output_seed_path, + output_records=output_records, + callback_output_path=result.callback_output_path, + output_processor_output_path=result.output_processor_output_path, + num_records_requested=result.num_records_requested, + repeat_iterations=iterations, + repeat_generated_records=generated_records, + repeat_satisfied=satisfied, + repeat_until_output_path=repeat_until_output_path, + ) + + +def _handle_repeat_until_exhausted( + *, + stage: _WorkflowStage, + repeat_until: RepeatUntil, + last_result: _StageRunResult | None, + stage_path: Path, + iterations: int, + generated_records: int, +) -> _StageRunResult: + selected_records = last_result.output_records if last_result is not None else 0 + if 
repeat_until.on_exhausted == RepeatUntilExhaustion.RAISE: + raise DataDesignerWorkflowError( + f"Stage {stage.name!r} repeat_until exhausted after {iterations} iteration(s): " + f"selected {selected_records} of {repeat_until.output_records} requested records." + ) + if last_result is None: + raise DataDesignerWorkflowError( + f"Stage {stage.name!r} repeat_until did not run because no iteration fit within the configured limits." + ) + return _with_repeat_result( + last_result, + stage_path=stage_path, + repeat_until=repeat_until, + iterations=iterations, + generated_records=generated_records, + satisfied=False, + ) + + +def _allows_empty_stage_output(stage: _WorkflowStage, run_result: _StageRunResult) -> bool: + if stage.allow_empty: + return True + return ( + stage.repeat_until is not None + and stage.repeat_until.on_exhausted == RepeatUntilExhaustion.RETURN_PARTIAL + and run_result.repeat_satisfied is False + ) + + +def _repeat_until_payload(repeat_until: RepeatUntil | None) -> dict[str, Any] | None: + if repeat_until is None: + return None + return { + "output_records": repeat_until.output_records, + "max_iterations": repeat_until.max_iterations, + "mode": repeat_until.mode.value, + "max_generated_records": repeat_until.max_generated_records, + "on_exhausted": repeat_until.on_exhausted.value, + "trim": repeat_until.trim, + } + + +def _enum_values(enum_type: type[StrEnum]) -> str: + return ", ".join(repr(item.value) for item in enum_type) + + def _normalize_stage_names( stage_names: StageTargets | None, stage_indices: dict[str, int], @@ -758,6 +1109,7 @@ def _base_stage_metadata(index: int, stage: _WorkflowStage, stage_dir_name: str) "output": stage.output, "sampling_strategy": stage.sampling_strategy.value, "selection_strategy": _selection_strategy_payload(stage.selection_strategy), + "repeat_until": _repeat_until_payload(stage.repeat_until), } @@ -777,6 +1129,7 @@ def _stage_fingerprint( "on_success_version": stage.on_success_version, "output_processors": 
[processor.model_dump(mode="json") for processor in stage.output_processors], "output": stage.output, + "repeat_until": _repeat_until_payload(stage.repeat_until), "library_version": get_library_version(), "upstream_fingerprint": upstream_fingerprint, } @@ -841,6 +1194,14 @@ def _load_parquet_dataset(path: Path) -> pd.DataFrame: raise DataDesignerWorkflowError(f"Failed to read parquet files at {str(path)!r}: {e}") from e +def _write_parquet_head(source_path: Path, output_path: Path, num_records: int) -> None: + df = _load_parquet_dataset(source_path).head(num_records) + if output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True) + df.to_parquet(output_path / "data.parquet", index=False) + + def _export_parquet_dataset(source_path: Path, output_path: Path, *, format: ExportFormat | None = None) -> Path: resolved_format: str = format if format is not None else output_path.suffix.lstrip(".").lower() if resolved_format not in SUPPORTED_EXPORT_FORMATS: diff --git a/packages/data-designer/tests/interface/test_composite_workflow.py b/packages/data-designer/tests/interface/test_composite_workflow.py index eb9a7abad..2a76b3026 100644 --- a/packages/data-designer/tests/interface/test_composite_workflow.py +++ b/packages/data-designer/tests/interface/test_composite_workflow.py @@ -22,7 +22,7 @@ from data_designer.config.seed_source_dataframe import DataFrameSeedSource from data_designer.engine.secret_resolver import PlaintextResolver from data_designer.engine.storage.artifact_storage import ArtifactStorage, BatchStage, ResumeMode -from data_designer.interface.composite_workflow import SkippedStageResult, SkippedStageStatus +from data_designer.interface.composite_workflow import RepeatUntil, SkippedStageResult, SkippedStageStatus from data_designer.interface.data_designer import DataDesigner from data_designer.interface.errors import DataDesignerWorkflowError from data_designer.interface.results import DatasetCreationResults @@ -1030,6 +1030,297 @@ 
def expand(stage_path: Path) -> Path: assert results.count_stage_output_records("personas") == 4 +def test_composite_workflow_repeat_until_append_accumulates_and_trims( + tmp_path: Path, + stub_model_providers: list[ModelProvider], + stub_model_configs: list[ModelConfig], +) -> None: + stage_1 = _seeded_builder( + stub_model_configs, + [ + {"name": "Ada", "keep": False}, + {"name": "Grace", "keep": True}, + {"name": "Linus", "keep": True}, + {"name": "Margaret", "keep": True}, + ], + ) + stage_1.add_column(ExpressionColumnConfig(name="candidate", expr="{{ name }} candidate")) + + def keep_rows(stage_path: Path) -> Path: + df = lazy.pd.read_parquet(stage_path / "parquet-files") + output_path = stage_path / "callback-outputs" / "kept" + if output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True) + df[df["keep"]].to_parquet(output_path / "data.parquet", index=False) + return output_path + + stage_2 = _expression_builder(stub_model_configs, "final", "{{ candidate }} final") + workflow = _real_data_designer(tmp_path / "artifacts", stub_model_providers).compose_workflow(name="repeat-append") + workflow.add_stage( + "candidates", + stage_1, + num_records=2, + on_success=keep_rows, + on_success_version="kept-v1", + repeat_until=RepeatUntil(output_records=2, max_iterations=3), + ) + workflow.add_stage("final", stage_2) + + results = workflow.run() + final = results.load_dataset().sort_values("name").reset_index(drop=True) + metadata = _load_workflow_metadata(tmp_path / "artifacts", "repeat-append") + + assert final[["name", "candidate", "final"]].to_dict(orient="records") == [ + {"name": "Grace", "candidate": "Grace candidate", "final": "Grace candidate final"}, + {"name": "Linus", "candidate": "Linus candidate", "final": "Linus candidate final"}, + ] + assert results["candidates"].count_records() == 4 + assert results.count_stage_output_records("candidates") == 2 + assert metadata["stages"][0]["num_records_requested"] == 4 + assert 
metadata["stages"][0]["repeat_iterations"] == 2 + assert metadata["stages"][0]["repeat_generated_records"] == 4 + assert metadata["stages"][0]["repeat_satisfied"] is True + assert metadata["stages"][0]["repeat_until_output_path"].endswith("stage-0-candidates/repeat-until/selected-output") + + +def test_composite_workflow_repeat_until_returns_partial_when_exhausted( + tmp_path: Path, + stub_model_providers: list[ModelProvider], + stub_model_configs: list[ModelConfig], +) -> None: + stage_1 = _seeded_builder( + stub_model_configs, + [ + {"name": "Ada", "keep": False}, + {"name": "Grace", "keep": True}, + {"name": "Linus", "keep": False}, + {"name": "Margaret", "keep": False}, + ], + ) + stage_1.add_column(ExpressionColumnConfig(name="candidate", expr="{{ name }} candidate")) + + def keep_rows(stage_path: Path) -> Path: + df = lazy.pd.read_parquet(stage_path / "parquet-files") + output_path = stage_path / "callback-outputs" / "kept" + if output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True) + df[df["keep"]].to_parquet(output_path / "data.parquet", index=False) + return output_path + + workflow = _real_data_designer(tmp_path / "artifacts", stub_model_providers).compose_workflow(name="repeat-partial") + workflow.add_stage( + "candidates", + stage_1, + num_records=2, + on_success=keep_rows, + on_success_version="kept-v1", + repeat_until=RepeatUntil(output_records=3, max_iterations=2, on_exhausted="return_partial"), + ) + + results = workflow.run() + metadata = _load_workflow_metadata(tmp_path / "artifacts", "repeat-partial") + + assert results.load_dataset()["name"].tolist() == ["Grace"] + assert results.count_records() == 1 + assert metadata["stages"][0]["repeat_iterations"] == 2 + assert metadata["stages"][0]["repeat_generated_records"] == 4 + assert metadata["stages"][0]["repeat_satisfied"] is False + + +def test_composite_workflow_repeat_until_returns_empty_partial_when_exhausted( + tmp_path: Path, + stub_model_providers: 
list[ModelProvider], + stub_model_configs: list[ModelConfig], +) -> None: + stage = _seeded_builder( + stub_model_configs, + [{"name": "Ada", "keep": False}, {"name": "Grace", "keep": False}], + ) + stage.add_column(ExpressionColumnConfig(name="candidate", expr="{{ name }} candidate")) + + def keep_rows(stage_path: Path) -> Path: + df = lazy.pd.read_parquet(stage_path / "parquet-files") + output_path = stage_path / "callback-outputs" / "kept" + if output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True) + df[df["keep"]].to_parquet(output_path / "data.parquet", index=False) + return output_path + + workflow = _real_data_designer(tmp_path / "artifacts", stub_model_providers).compose_workflow( + name="repeat-empty-partial" + ) + workflow.add_stage( + "candidates", + stage, + num_records=1, + on_success=keep_rows, + on_success_version="kept-v1", + repeat_until=RepeatUntil(output_records=1, max_iterations=2, on_exhausted="return_partial"), + ) + + results = workflow.run() + metadata = _load_workflow_metadata(tmp_path / "artifacts", "repeat-empty-partial") + + assert results.count_records() == 0 + assert results.load_dataset().empty + assert metadata["stages"][0]["status"] == "completed_empty" + assert metadata["stages"][0]["repeat_satisfied"] is False + + +def test_composite_workflow_repeat_until_raises_when_exhausted( + tmp_path: Path, + stub_model_providers: list[ModelProvider], + stub_model_configs: list[ModelConfig], +) -> None: + stage = _seeded_builder(stub_model_configs, [{"name": "Ada", "keep": False}]) + stage.add_column(ExpressionColumnConfig(name="candidate", expr="{{ name }} candidate")) + + def keep_rows(stage_path: Path) -> Path: + df = lazy.pd.read_parquet(stage_path / "parquet-files") + output_path = stage_path / "callback-outputs" / "kept" + if output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True) + df[df["keep"]].to_parquet(output_path / "data.parquet", index=False) + return output_path + + 
workflow = _real_data_designer(tmp_path / "artifacts", stub_model_providers).compose_workflow(name="repeat-raises") + workflow.add_stage( + "candidates", + stage, + num_records=1, + on_success=keep_rows, + repeat_until=RepeatUntil(output_records=1, max_iterations=2), + ) + + with pytest.raises(DataDesignerWorkflowError, match="repeat_until exhausted"): + workflow.run() + + +def test_composite_workflow_repeat_until_discard_keeps_latest_attempt( + tmp_path: Path, + stub_model_providers: list[ModelProvider], + stub_model_configs: list[ModelConfig], +) -> None: + stage = _seeded_builder( + stub_model_configs, + [{"name": "Ada"}, {"name": "Grace"}, {"name": "Linus"}], + ) + stage.add_column(ExpressionColumnConfig(name="candidate", expr="{{ name }} candidate")) + callback_calls = 0 + + def keep_more_each_time(stage_path: Path) -> Path: + nonlocal callback_calls + callback_calls += 1 + df = lazy.pd.read_parquet(stage_path / "parquet-files").head(callback_calls) + output_path = stage_path / "callback-outputs" / "latest" + if output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True) + df.to_parquet(output_path / "data.parquet", index=False) + return output_path + + workflow = _real_data_designer(tmp_path / "artifacts", stub_model_providers).compose_workflow(name="repeat-discard") + workflow.add_stage( + "candidates", + stage, + num_records=3, + on_success=keep_more_each_time, + repeat_until=RepeatUntil(output_records=2, max_iterations=3, mode="discard"), + ) + + results = workflow.run() + metadata = _load_workflow_metadata(tmp_path / "artifacts", "repeat-discard") + + assert callback_calls == 2 + assert results.load_dataset()["name"].tolist() == ["Ada", "Grace"] + assert results["candidates"].count_records() == 3 + assert metadata["stages"][0]["repeat_iterations"] == 2 + assert metadata["stages"][0]["repeat_generated_records"] == 6 + + +def test_composite_workflow_repeat_until_discard_warns_when_resuming( + stub_artifact_path: Path, + 
stub_model_providers: list[ModelProvider], + stub_model_configs: list[ModelConfig], + stub_dataset_profiler_results, + caplog: pytest.LogCaptureFixture, +) -> None: + data_designer = _data_designer(stub_artifact_path, stub_model_providers) + _patch_create(data_designer, stub_dataset_profiler_results) + workflow = data_designer.compose_workflow(name="repeat-discard-resume") + workflow.add_stage( + "candidates", + _category_builder(stub_model_configs), + num_records=2, + repeat_until=RepeatUntil(output_records=1, max_iterations=2, mode="discard"), + ) + workflow.run() + + metadata_path = stub_artifact_path / "repeat-discard-resume" / "workflow-metadata.json" + metadata = json.loads(metadata_path.read_text()) + _mark_stage_resumable(metadata, 0, "failed") + metadata_path.write_text(json.dumps(metadata)) + + resumed = data_designer.compose_workflow(name="repeat-discard-resume") + resumed.add_stage( + "candidates", + _category_builder(stub_model_configs), + num_records=2, + repeat_until=RepeatUntil(output_records=1, max_iterations=2, mode="discard"), + ) + + with caplog.at_level("WARNING", logger="data_designer.interface.composite_workflow"): + resumed.run(resume=ResumeMode.IF_POSSIBLE) + + assert "previous attempts cannot be resumed" in caplog.text + + +def test_composite_workflow_repeat_until_uses_processor_output( + tmp_path: Path, + stub_model_providers: list[ModelProvider], + stub_model_configs: list[ModelConfig], +) -> None: + stage_1 = _seeded_builder( + stub_model_configs, + [{"name": "Ada"}, {"name": "Linus"}, {"name": "Grace"}], + ) + stage_1.add_column(ExpressionColumnConfig(name="persona", expr="{{ name }}")) + stage_2 = _expression_builder(stub_model_configs, "final", "{{ compact_name }} final") + + workflow = _real_data_designer(tmp_path / "artifacts", stub_model_providers).compose_workflow( + name="repeat-processor-output" + ) + workflow.add_stage( + "compact", + stage_1, + num_records=1, + output_processors=[SchemaTransformProcessorConfig(name="compact", 
template={"compact_name": "{{ persona }}"})], + output="processor:compact", + repeat_until=RepeatUntil(output_records=2, max_iterations=3), + ) + workflow.add_stage("final", stage_2) + + results = workflow.run() + final = results.load_dataset().sort_values("compact_name").reset_index(drop=True) + stage_output = results.load_stage_output("compact").sort_values("compact_name").reset_index(drop=True) + metadata = _load_workflow_metadata(tmp_path / "artifacts", "repeat-processor-output") + + assert stage_output.to_dict(orient="records") == [{"compact_name": "Ada"}, {"compact_name": "Linus"}] + assert final.to_dict(orient="records") == [ + {"compact_name": "Ada", "final": "Ada final"}, + {"compact_name": "Linus", "final": "Linus final"}, + ] + assert metadata["stages"][0]["num_records_requested"] == 2 + assert metadata["stages"][0]["repeat_iterations"] == 2 + assert metadata["stages"][0]["output_seed_path"].endswith( + "stage-0-compact/output-processors/processors-files/compact" + ) + + def test_composite_workflow_does_not_forward_dropped_processor_columns( tmp_path: Path, stub_model_providers: list[ModelProvider],