From 70e1d2e3995a063ced6bc7ac06c2482afdfe46d0 Mon Sep 17 00:00:00 2001 From: Brian O'Kelley Date: Fri, 1 May 2026 11:58:21 -0400 Subject: [PATCH 1/2] feat(types): narrow discriminated-union ValidationErrors to user's intended variant MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stability AI Emma backend test (verdict 5/10) flagged that constructing a CreativeManifest whose assets value matched one variant (e.g. ImageAsset) but was missing fields required by THAT variant produced a 60-line pydantic ValidationError listing every variant of the asset content union (13+ variants, 26 errors). The user's actual mistake (one missing field on the variant they picked) was buried. adcp.types.error_narrowing.narrow_union_errors post-processes ValidationError.errors() to keep only errors from the closest-fit variant. Strategy: 1. Discriminator match — variants with no literal_error / union_tag_not_found had their discriminator value match the user's input. If exactly one such variant exists, surface only its errors. Stability case: ImageAsset. 2. Fewest-errors fallback — when no clear winner, pick the variant with fewest non-literal errors as a closest-fit guess. 3. Tie → pass through all errors so the adopter can disambiguate. Wired into create_tool_caller's INVALID_REQUEST projection so wire-side validation errors get the narrowing automatically. Adopters can also call narrow_union_errors manually on a caught ValidationError. Before/after: BEFORE: 26 errors (every variant of the asset union) AFTER: 2 errors (assets.hero.ImageAsset.width / .height) Tests: 6 new. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/adcp/server/mcp_tools.py | 7 + src/adcp/types/error_narrowing.py | 216 ++++++++++++++++++++++++++++++ tests/test_error_narrowing.py | 204 ++++++++++++++++++++++++++++ 3 files changed, 427 insertions(+) create mode 100644 src/adcp/types/error_narrowing.py create mode 100644 tests/test_error_narrowing.py diff --git a/src/adcp/server/mcp_tools.py b/src/adcp/server/mcp_tools.py index 942afafa4..3f35061b3 100644 --- a/src/adcp/server/mcp_tools.py +++ b/src/adcp/server/mcp_tools.py @@ -1746,6 +1746,13 @@ async def call_tool(params: dict[str, Any], context: ToolContext | None = None) errors_list = exc.errors( include_input=False, include_context=False, include_url=False ) + # Narrow discriminated-union failures to the variant + # the user actually intended (Stability AI Emma P2: + # 60-line dump → focused error). For non-union + # failures the function is a no-op. + from adcp.types.error_narrowing import narrow_union_errors + + errors_list = narrow_union_errors(errors_list) first: dict[str, Any] = dict(errors_list[0]) if errors_list else {} field_path = ".".join(str(loc) for loc in first.get("loc", ())) message = first.get("msg", "validation failed") diff --git a/src/adcp/types/error_narrowing.py b/src/adcp/types/error_narrowing.py new file mode 100644 index 000000000..8160c46aa --- /dev/null +++ b/src/adcp/types/error_narrowing.py @@ -0,0 +1,216 @@ +"""Narrow pydantic discriminated-union ValidationErrors to the +variant the user actually intended. + +Background (Stability AI Emma backend test, verdict 5/10): when an +adopter constructs a ``CreativeManifest`` whose ``assets`` value +matches one variant of a discriminated union but is missing fields +required by THAT variant, pydantic 2 reports validation errors for +EVERY variant in the union (13+ for asset content types). The error +dump runs 60+ lines and obscures the actual problem (a single +missing field on the variant the user picked). + +This module exposes :func:`narrow_union_errors` which post-processes +``ValidationError.errors()`` to keep only the errors from the +"closest fit" variant — the one whose discriminator matched OR the +one with the fewest non-discriminator errors. The result is a +focused error pointing at the user's actual mistake. + +Used by: + +* :func:`adcp.server.mcp_tools.create_tool_caller` — narrows + wire-side ``INVALID_REQUEST`` errors automatically. +* Adopter code via :func:`narrow_validation_error` (manual) — for + adopters who construct typed models in their platform method + bodies and want the same friendlier error UX. +""" + +from __future__ import annotations + +from typing import Any + + +# Heuristic: a "variant" location segment is a class name. +# Pydantic emits ``("assets", "hero", "ImageAsset", "width")`` for +# union-validation errors. The variant name is the +# second-to-last position when the error is on a field, OR the +# last position when the error is at the variant itself +# (e.g. ``union_tag_not_found``). +def _looks_like_variant_name(segment: Any) -> bool: + """Heuristic: a Python class name (CamelCase, starts with capital). + + Used to detect variant segments in ``ValidationError.errors()[i].loc``. + Pydantic interleaves variant class names into the loc tuple for + discriminated-union failures; we strip those segments to identify + "this error belongs to variant X." + """ + if not isinstance(segment, str): + return False + if not segment: + return False + # Class names start with an uppercase letter and contain only + # alphanumerics. Reject ``snake_case`` field names. + return segment[0].isupper() and segment.replace("_", "").isalnum() and "_" not in segment + + +def _split_at_variant( + loc: tuple[Any, ...], +) -> tuple[tuple[Any, ...], str, tuple[Any, ...]] | None: + """Split a loc tuple at the first variant-name segment. + + Returns ``(prefix_before_variant, variant_name, suffix_after_variant)`` + or ``None`` if no variant segment is found. Used to group + union-validation errors by their containing field path + variant. + + Example:: + + loc = ("assets", "hero", "ImageAsset", "width") + → (("assets", "hero"), "ImageAsset", ("width",)) + """ + for i, segment in enumerate(loc): + if _looks_like_variant_name(segment): + return loc[:i], segment, loc[i + 1 :] + return None + + +def narrow_union_errors( + errors: Any, +) -> list[Any]: + """Return a focused subset of ``errors`` for discriminated-union + failures. + + For each (parent_loc) where multiple variant errors exist, pick + the "best fit" variant by: + + 1. **Discriminator match**: if exactly one variant lacks a + ``literal_error`` whose input doesn't match the expected + discriminator, that variant matched the user's discriminator + value. Keep ONLY its errors. + 2. **Fewest non-discriminator errors**: if no clear discriminator + winner, the variant with the smallest count of non-literal + errors is the closest fit. Keep ONLY its errors. + + Errors that aren't part of a union failure (no variant in their + ``loc``) pass through unchanged. The function never returns an + empty list when the input is non-empty — the worst case falls + back to the input. + + Mirrors the JS-side ``narrowUnionValidationErrors`` (when ported). + """ + if not errors: + return list(errors) if errors else [] + + # Bucket errors by (prefix_before_variant) — every error sharing + # the same prefix is contending for the same logical slot, and + # different errors in the same bucket are different variants of + # the same union. + buckets: dict[tuple[Any, ...], list[tuple[str, Any]]] = {} + passthrough: list[Any] = [] + + for err in errors: + loc = tuple(err.get("loc", ())) + split = _split_at_variant(loc) + if split is None: + passthrough.append(err) + continue + prefix, variant, _suffix = split + buckets.setdefault(prefix, []).append((variant, err)) + + if not buckets: + return list(errors) + + narrowed: list[Any] = list(passthrough) + for prefix, variant_errors in buckets.items(): + # Group by variant name within this bucket. + per_variant: dict[str, list[Any]] = {} + for variant, err in variant_errors: + per_variant.setdefault(variant, []).append(err) + + if len(per_variant) <= 1: + # Only one variant in this bucket — no narrowing needed. + for errs in per_variant.values(): + narrowed.extend(errs) + continue + + winner = _pick_winning_variant(per_variant, prefix) + if winner is None: + # Couldn't disambiguate; fall back to all variants for + # this bucket so the adopter doesn't lose information. + for errs in per_variant.values(): + narrowed.extend(errs) + continue + narrowed.extend(per_variant[winner]) + + return narrowed + + +def _pick_winning_variant( + per_variant: dict[str, list[Any]], + prefix: tuple[Any, ...], +) -> str | None: + """Return the variant name whose errors are the closest fit. + + Strategy (in order): + + 1. **Discriminator match**: variants with ZERO ``literal_error`` + errors had their discriminator value match the user's input. + If exactly one such variant exists, it's the winner. This is + the Stability AI / AudioStack 60-line-dump fix — when the user + provides ``asset_type='image'`` and ImageAsset's other fields + fail, we surface ImageAsset errors only. + 2. **Fewest errors among matched**: if multiple variants matched + the discriminator, pick the one with fewest errors (closest + fit to user's input shape). + 3. **Fallback to fewest errors overall**: if NO variant matched + the discriminator (the user provided an invalid discriminator + value, e.g. ``asset_type='image_asset'`` instead of ``'image'``), + pick the variant with fewest non-literal errors as a closest- + fit guess. + 4. **Tie**: return ``None`` so the caller passes through all + errors — adopter can disambiguate manually. + """ + if not per_variant: + return None + + # Step 1: variants with ZERO discriminator-mismatch errors are the + # discriminator-matched candidates. Discriminator-mismatch signals: + # - ``literal_error`` on the discriminator field (variant's + # literal value didn't match the user's input) + # - ``union_tag_not_found`` (a NESTED union inside this variant + # couldn't be narrowed to any of ITS sub-variants — meaning + # the user's input doesn't fit this variant's shape at all) + discriminator_mismatch_types = {"literal_error", "union_tag_not_found"} + matched = { + variant: errs + for variant, errs in per_variant.items() + if not any(e.get("type") in discriminator_mismatch_types for e in errs) + } + + if matched: + if len(matched) == 1: + return next(iter(matched.keys())) + # Multiple matched (rare — would mean two variants share the + # same discriminator literal). Pick the one with fewest + # errors. + sorted_matched = sorted(matched.items(), key=lambda kv: len(kv[1])) + if len(sorted_matched) >= 2 and len(sorted_matched[0][1]) == len(sorted_matched[1][1]): + return None # tie + return sorted_matched[0][0] + + # Step 3: no discriminator match — pick the variant with fewest + # non-literal errors. That's the closest-fit guess for an adopter + # who used an invalid discriminator. + def _non_literal_score(item: tuple[str, list[Any]]) -> int: + _variant, errs = item + return sum(1 for e in errs if e.get("type") != "literal_error") + + sorted_all = sorted(per_variant.items(), key=_non_literal_score) + if len(sorted_all) >= 2 and _non_literal_score(sorted_all[0]) == _non_literal_score( + sorted_all[1] + ): + return None # tie — don't guess + return sorted_all[0][0] + + +__all__ = [ + "narrow_union_errors", +] diff --git a/tests/test_error_narrowing.py b/tests/test_error_narrowing.py new file mode 100644 index 000000000..f7bc34507 --- /dev/null +++ b/tests/test_error_narrowing.py @@ -0,0 +1,204 @@ +"""narrow_union_errors — discriminated-union ValidationError narrowing. + +Stability AI Emma backend test (verdict 5/10) flagged that constructing +a ``CreativeManifest`` whose asset is missing a required field (e.g. +``ImageAsset.width``) produced a 60-line pydantic ValidationError +listing every variant of the asset content union (13+ variants). The +user's actual mistake (one missing field) was buried. + +This file pins the post-fix behavior: + +* End-to-end: the framework's typed-dispatch path runs the + ``CreativeManifest`` construction through ``narrow_union_errors`` + and surfaces only the variant the user matched. +* Algorithm-level: ``narrow_union_errors`` correctly identifies the + discriminator-matched variant, the fewest-error fallback, and the + pass-through case for non-union errors. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from adcp.types.error_narrowing import narrow_union_errors + +# ---- Pass-through cases ---- + + +def test_pass_through_empty_errors() -> None: + """Empty input → empty output. No surprises on no-op input.""" + assert narrow_union_errors([]) == [] + + +def test_pass_through_non_union_errors() -> None: + """Errors with no variant in their loc (typical scalar field + validation) pass through unchanged.""" + errors: list[dict[str, Any]] = [ + {"type": "missing", "loc": ("brief",), "msg": "Field required"}, + {"type": "string_type", "loc": ("po_number",), "msg": "Input should be a valid string"}, + ] + assert narrow_union_errors(errors) == errors + + +# ---- Variant detection ---- + + +def test_narrow_picks_discriminator_matched_variant() -> None: + """When one variant has only field-level errors and the others + have ``literal_error`` on what looks like the discriminator, + keep ONLY the matched variant's errors. This is the + Stability/AudioStack 60-line-dump fix.""" + errors: list[dict[str, Any]] = [ + # Matched variant — missing fields, no literal_error. + { + "type": "missing", + "loc": ("assets", "hero", "ImageAsset", "width"), + "msg": "Field required", + }, + { + "type": "missing", + "loc": ("assets", "hero", "ImageAsset", "height"), + "msg": "Field required", + }, + # Non-matching variants — discriminator-only failure. + { + "type": "literal_error", + "loc": ("assets", "hero", "VideoAsset", "asset_type"), + "msg": "Input should be 'video'", + }, + { + "type": "literal_error", + "loc": ("assets", "hero", "AudioAsset", "asset_type"), + "msg": "Input should be 'audio'", + }, + ] + narrowed = narrow_union_errors(errors) + assert len(narrowed) == 2, f"Expected 2 ImageAsset errors, got {len(narrowed)}: {narrowed}" + assert all("ImageAsset" in err["loc"] for err in narrowed) + + +def test_narrow_picks_fewest_errors_when_no_discriminator_winner() -> None: + """When NO variant has a clean discriminator match (e.g. the user + provided an invalid discriminator value, so every variant has a + ``literal_error`` on the discriminator field), pick the variant + with the fewest non-discriminator errors as the closest-fit + guess.""" + errors: list[dict[str, Any]] = [ + # ImageAsset has 2 missing fields + literal mismatch. + { + "type": "missing", + "loc": ("assets", "hero", "ImageAsset", "width"), + "msg": "Field required", + }, + { + "type": "missing", + "loc": ("assets", "hero", "ImageAsset", "height"), + "msg": "Field required", + }, + { + "type": "literal_error", + "loc": ("assets", "hero", "ImageAsset", "asset_type"), + "msg": "Input should be 'image'", + }, + # VideoAsset has 4 missing fields + literal mismatch. + { + "type": "missing", + "loc": ("assets", "hero", "VideoAsset", "width"), + "msg": "Field required", + }, + { + "type": "missing", + "loc": ("assets", "hero", "VideoAsset", "height"), + "msg": "Field required", + }, + { + "type": "missing", + "loc": ("assets", "hero", "VideoAsset", "duration_ms"), + "msg": "Field required", + }, + { + "type": "missing", + "loc": ("assets", "hero", "VideoAsset", "codec"), + "msg": "Field required", + }, + { + "type": "literal_error", + "loc": ("assets", "hero", "VideoAsset", "asset_type"), + "msg": "Input should be 'video'", + }, + ] + narrowed = narrow_union_errors(errors) + # Both variants have literal_error → no discriminator winner → + # fewest-non-literal-errors heuristic. ImageAsset: 2 non-literal, + # VideoAsset: 4 non-literal. ImageAsset wins. + image_errors = [e for e in narrowed if "ImageAsset" in e["loc"]] + video_errors = [e for e in narrowed if "VideoAsset" in e["loc"]] + assert image_errors and not video_errors + + +def test_narrow_falls_back_when_variants_tie() -> None: + """When multiple variants tie on error count, the function returns + all errors rather than guessing — surfaces the ambiguity to the + adopter who can disambiguate via the discriminator.""" + errors: list[dict[str, Any]] = [ + { + "type": "missing", + "loc": ("assets", "hero", "ImageAsset", "width"), + "msg": "Field required", + }, + { + "type": "missing", + "loc": ("assets", "hero", "VideoAsset", "duration_ms"), + "msg": "Field required", + }, + ] + # Both variants have 1 non-literal error. Tie → keep both. + narrowed = narrow_union_errors(errors) + assert len(narrowed) == 2 + + +# ---- End-to-end: CreativeManifest with missing field ---- + + +def test_e2e_creative_manifest_missing_width_height_narrows_to_image_asset() -> None: + """End-to-end regression for the exact Stability AI report case. + + Before the narrowing: 26 ValidationErrors covering every variant + of the asset content union. After: just the 2 ImageAsset.width / + ImageAsset.height errors the adopter cares about.""" + from pydantic import ValidationError + + from adcp.types import CreativeManifest, FormatReferenceStructuredObject + + try: + CreativeManifest( + creative_id="cr-1", + format_id=FormatReferenceStructuredObject(agent_url="https://x", id="img"), + assets={ + "hero": { + "asset_role": "hero", + "asset_type": "image", + "url": "https://x.png", + } + }, + ) + except ValidationError as exc: + narrowed = narrow_union_errors( + exc.errors(include_input=False, include_context=False, include_url=False) + ) + # Should be ~2 errors (ImageAsset's width + height), not ~26. + assert ( + len(narrowed) <= 4 + ), f"narrow_union_errors didn't narrow: {len(narrowed)} errors remain" + assert all("ImageAsset" in err["loc"] for err in narrowed), ( + "narrowed result should be ImageAsset-only; got: " f"{[err['loc'] for err in narrowed]}" + ) + missing_fields = {err["loc"][-1] for err in narrowed if err["type"] == "missing"} + assert "width" in missing_fields and "height" in missing_fields + else: + pytest.fail( + "CreativeManifest accepted invalid asset (missing width/height); " + "regression in upstream pydantic discriminated-union behavior" + ) From 74a221b20df18a613535c86c65df4b3eafe7a9c9 Mon Sep 17 00:00:00 2001 From: Brian O'Kelley Date: Fri, 1 May 2026 12:24:40 -0400 Subject: [PATCH 2/2] fix(types): expert-review hardening on PR #340 narrowing heuristic MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two reviewers (code-reviewer, python-expert) flagged five P0/P1s: P0 (python-expert): pydantic upper bound. The narrowing heuristic depends on pydantic-2 ValidationError.errors() internals (err["type"] literals, CamelCase variant names interleaved in err["loc"]). Pydantic 3 makes no API guarantee on these. Pin <3. P1 (code-reviewer): missing union_tag_invalid in mismatch types. Sibling of union_tag_not_found that pydantic-2 emits when a tag is found but invalid. Without it, a variant with that error stays in the candidate pool and may falsely win. P1 (python-expert): _split_at_variant used FIRST CamelCase segment. Nested unions (Union[Outer[Union[A, B]], C]) emit loc like ("field", "Outer", "inner", "A", "subfield"); splitting at "Outer" collapsed A and B into one bucket. Switch to LAST variant segment (innermost wins). P1 (code-reviewer): narrowing call was unguarded. A bug in the heuristic would 500 the wire path. Wrap in try/except in the INVALID_REQUEST projection; fall back to unfiltered errors with WARNING log. P1 (python-expert): defensive copy of returned dicts. Caller mutation could leak back into pydantic's internal error list. ``dict(err)`` per-error at the boundary; cheap insurance. P2 cleanup: - Lift import out of hot path to module top. - DoS guard: cap input at 500 errors (narrowing is UX, not correctness; an attacker shouldn't get to amplify CPU through bucketing). - Document residual literal_error overloading edge case. Tests: 4 new (union_tag_invalid mismatch handling, nested-union innermost-variant resolution, defensive-copy guard, DoS cap). Deferred to follow-up (per python-expert): schema-level fix — teach codegen to emit Annotated[Union[...], Field(discriminator=...)] so pydantic narrows correctly without our post-processor. Test count: 2878 passed (was 2874 — +4 hardening tests). Co-Authored-By: Claude Opus 4.7 (1M context) --- pyproject.toml | 8 +- src/adcp/server/mcp_tools.py | 20 ++++- src/adcp/types/error_narrowing.py | 127 ++++++++++++++++++++++-------- tests/test_error_narrowing.py | 91 +++++++++++++++++++++ 4 files changed, 209 insertions(+), 37 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9848ed126..5b9e7c7d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,13 @@ dependencies = [ # ``tests/conformance/signing/test_ip_pinned_transport_contract.py`` # guards the specific API shapes we rely on. "httpcore>=1.0,<2.0", - "pydantic>=2.0.0", + # Upper bound is load-bearing for ``adcp.types.error_narrowing``, + # which depends on pydantic-2 ValidationError.errors() internals + # (``err["type"]`` literals like ``"literal_error"`` / + # ``"union_tag_not_found"`` and CamelCase variant names interleaved + # in ``err["loc"]``). Pydantic 3 has no API guarantee on these + # internals; bump only after porting the narrowing heuristics. + "pydantic>=2.0.0,<3", "typing-extensions>=4.5.0", # A2A protocol v1.0 (protobuf types, ProtoJSON on the wire). We run # on the v1.0 Python SDK with ``enable_v0_3_compat=True`` on the diff --git a/src/adcp/server/mcp_tools.py b/src/adcp/server/mcp_tools.py index 3f35061b3..963a69429 100644 --- a/src/adcp/server/mcp_tools.py +++ b/src/adcp/server/mcp_tools.py @@ -26,6 +26,7 @@ from adcp.server.base import ADCPHandler, ToolContext from adcp.server.test_controller import SCENARIOS as _CONTROLLER_SCENARIOS +from adcp.types.error_narrowing import narrow_union_errors from adcp.validation.client_hooks import ValidationHookConfig logger = logging.getLogger(__name__) @@ -1750,9 +1751,22 @@ async def call_tool(params: dict[str, Any], context: ToolContext | None = None) # the user actually intended (Stability AI Emma P2: # 60-line dump → focused error). For non-union # failures the function is a no-op. - from adcp.types.error_narrowing import narrow_union_errors - - errors_list = narrow_union_errors(errors_list) + # + # Defensive: if the narrowing helper itself raises + # (heuristic edge case, future pydantic format + # change), keep the original error list rather than + # 500'ing the wire path. The narrowed-error UX is a + # nice-to-have; correctness is surfacing SOME error. + try: + errors_list = list(narrow_union_errors(errors_list)) + except Exception: + logger.warning( + "narrow_union_errors raised on %s — passing through " + "unfiltered errors. This is a bug in the narrowing " + "heuristic, NOT in the validation itself.", + method_name, + exc_info=True, + ) first: dict[str, Any] = dict(errors_list[0]) if errors_list else {} field_path = ".".join(str(loc) for loc in first.get("loc", ())) message = first.get("msg", "validation failed") diff --git a/src/adcp/types/error_narrowing.py b/src/adcp/types/error_narrowing.py index 8160c46aa..91df6b481 100644 --- a/src/adcp/types/error_narrowing.py +++ b/src/adcp/types/error_narrowing.py @@ -55,21 +55,46 @@ def _looks_like_variant_name(segment: Any) -> bool: def _split_at_variant( loc: tuple[Any, ...], ) -> tuple[tuple[Any, ...], str, tuple[Any, ...]] | None: - """Split a loc tuple at the first variant-name segment. + """Split a loc tuple at the LAST variant-name segment. Returns ``(prefix_before_variant, variant_name, suffix_after_variant)`` or ``None`` if no variant segment is found. Used to group union-validation errors by their containing field path + variant. + The LAST variant segment (innermost) is the one whose error + mattered. For nested unions + (``Union[Outer[Union[A, B]], C]``) pydantic emits + ``("field", "Outer", "inner", "A", "subfield")`` — splitting at + the FIRST variant segment ("Outer") would collapse "A" and "B" + into one bucket; splitting at the LAST ("A") correctly groups by + innermost variant. + Example:: loc = ("assets", "hero", "ImageAsset", "width") → (("assets", "hero"), "ImageAsset", ("width",)) + + loc = ("field", "Outer", "inner", "A", "subfield") + → (("field", "Outer", "inner"), "A", ("subfield",)) """ + last_variant_idx: int | None = None for i, segment in enumerate(loc): if _looks_like_variant_name(segment): - return loc[:i], segment, loc[i + 1 :] - return None + last_variant_idx = i + if last_variant_idx is None: + return None + variant = loc[last_variant_idx] + assert isinstance(variant, str) # _looks_like_variant_name guarantees this + return loc[:last_variant_idx], variant, loc[last_variant_idx + 1 :] + + +#: Cap on input list size. Beyond this we pass through unchanged — +#: narrowing is a UX feature, not correctness, and an attacker +#: submitting a request that triggers thousands of validation errors +#: shouldn't get to amplify CPU through O(N) bucketing logic. The cap +#: is generous enough that genuine union dumps (~30 errors for a 13- +#: variant asset union) never hit it. +_MAX_NARROW_INPUT_SIZE = 500 def narrow_union_errors( @@ -81,10 +106,10 @@ def narrow_union_errors( For each (parent_loc) where multiple variant errors exist, pick the "best fit" variant by: - 1. **Discriminator match**: if exactly one variant lacks a - ``literal_error`` whose input doesn't match the expected - discriminator, that variant matched the user's discriminator - value. Keep ONLY its errors. + 1. **Discriminator match**: variants with no ``literal_error``, + ``union_tag_not_found``, or ``union_tag_invalid`` had their + discriminator value match the user's input. Keep ONLY their + errors. 2. **Fewest non-discriminator errors**: if no clear discriminator winner, the variant with the smallest count of non-literal errors is the closest fit. Keep ONLY its errors. @@ -94,10 +119,30 @@ def narrow_union_errors( empty list when the input is non-empty — the worst case falls back to the input. + **Edge case** (residual): pydantic's ``literal_error`` type fires + on ANY ``Literal[...]`` field mismatch, not just the discriminator. + A user input that hits a non-discriminator literal mismatch on the + matched variant (e.g., correct ``asset_type`` but wrong + ``codec``) will eliminate the matched variant from step 1 and the + fallback may pick a wrong variant. The narrowing reduces noise + even in this case but may surface the wrong variant's errors. + Resolving this requires knowing the discriminator field name, + which the heuristic doesn't have access to. Schema-level fix + (``Annotated[Union[...], Field(discriminator=...)]``) avoids the + issue entirely; tracked as a follow-up. + Mirrors the JS-side ``narrowUnionValidationErrors`` (when ported). """ if not errors: - return list(errors) if errors else [] + return [] + + errors_list = list(errors) + # DoS guard: don't process pathologically-large inputs. Below the + # cap, narrowing helps. Above it, we're either in a hostile + # request or a legitimately massive schema; either way, the + # narrowing UX win doesn't justify the CPU. + if len(errors_list) > _MAX_NARROW_INPUT_SIZE: + return errors_list # Bucket errors by (prefix_before_variant) — every error sharing # the same prefix is contending for the same logical slot, and @@ -106,7 +151,7 @@ def narrow_union_errors( buckets: dict[tuple[Any, ...], list[tuple[str, Any]]] = {} passthrough: list[Any] = [] - for err in errors: + for err in errors_list: loc = tuple(err.get("loc", ())) split = _split_at_variant(loc) if split is None: @@ -116,10 +161,15 @@ def narrow_union_errors( buckets.setdefault(prefix, []).append((variant, err)) if not buckets: - return list(errors) - - narrowed: list[Any] = list(passthrough) - for prefix, variant_errors in buckets.items(): + return errors_list + + # Defensive copy of the dicts we're about to surface — the caller + # might mutate the returned list and we don't want that to leak + # back into the input. ``dict(err)`` is shallow which is fine: + # ``loc`` is a tuple (immutable), and other values are scalars or + # nested dicts pydantic doesn't share across errors. + narrowed: list[Any] = [dict(err) for err in passthrough] + for _prefix, variant_errors in buckets.items(): # Group by variant name within this bucket. per_variant: dict[str, list[Any]] = {} for variant, err in variant_errors: @@ -128,35 +178,54 @@ def narrow_union_errors( if len(per_variant) <= 1: # Only one variant in this bucket — no narrowing needed. for errs in per_variant.values(): - narrowed.extend(errs) + narrowed.extend(dict(e) for e in errs) continue - winner = _pick_winning_variant(per_variant, prefix) + winner = _pick_winning_variant(per_variant) if winner is None: # Couldn't disambiguate; fall back to all variants for # this bucket so the adopter doesn't lose information. for errs in per_variant.values(): - narrowed.extend(errs) + narrowed.extend(dict(e) for e in errs) continue - narrowed.extend(per_variant[winner]) + narrowed.extend(dict(e) for e in per_variant[winner]) return narrowed +#: Pydantic-2 error types that signal "this variant's discriminator +#: didn't match the user's input". A variant whose error list contains +#: ANY of these is eliminated from step 1's candidate pool. +#: +#: * ``literal_error`` — a ``Literal[...]`` field rejected the value. +#: Discriminators are typically Literal-typed (``asset_type: +#: Literal["image"]``); a mismatch here means this variant isn't +#: the user's intent. +#: * ``union_tag_not_found`` — a NESTED tagged union inside this +#: variant couldn't be narrowed to any of ITS sub-variants. Means +#: the user's input doesn't fit this variant's shape at all. +#: * ``union_tag_invalid`` — pydantic-2's "tag found but invalid for +#: this union" code. Same semantic as ``union_tag_not_found`` for +#: our purposes. +_DISCRIMINATOR_MISMATCH_TYPES = frozenset( + {"literal_error", "union_tag_not_found", "union_tag_invalid"} +) + + def _pick_winning_variant( per_variant: dict[str, list[Any]], - prefix: tuple[Any, ...], ) -> str | None: """Return the variant name whose errors are the closest fit. Strategy (in order): - 1. **Discriminator match**: variants with ZERO ``literal_error`` - errors had their discriminator value match the user's input. - If exactly one such variant exists, it's the winner. This is - the Stability AI / AudioStack 60-line-dump fix — when the user - provides ``asset_type='image'`` and ImageAsset's other fields - fail, we surface ImageAsset errors only. + 1. **Discriminator match**: variants with ZERO discriminator-mismatch + errors (see :data:`_DISCRIMINATOR_MISMATCH_TYPES`) had their + discriminator value match the user's input. If exactly one + such variant exists, it's the winner. This is the Stability AI + / AudioStack 60-line-dump fix — when the user provides + ``asset_type='image'`` and ImageAsset's other fields fail, we + surface ImageAsset errors only. 2. **Fewest errors among matched**: if multiple variants matched the discriminator, pick the one with fewest errors (closest fit to user's input shape). @@ -171,18 +240,10 @@ def _pick_winning_variant( if not per_variant: return None - # Step 1: variants with ZERO discriminator-mismatch errors are the - # discriminator-matched candidates. Discriminator-mismatch signals: - # - ``literal_error`` on the discriminator field (variant's - # literal value didn't match the user's input) - # - ``union_tag_not_found`` (a NESTED union inside this variant - # couldn't be narrowed to any of ITS sub-variants — meaning - # the user's input doesn't fit this variant's shape at all) - discriminator_mismatch_types = {"literal_error", "union_tag_not_found"} matched = { variant: errs for variant, errs in per_variant.items() - if not any(e.get("type") in discriminator_mismatch_types for e in errs) + if not any(e.get("type") in _DISCRIMINATOR_MISMATCH_TYPES for e in errs) } if matched: diff --git a/tests/test_error_narrowing.py b/tests/test_error_narrowing.py index f7bc34507..8cd542342 100644 --- a/tests/test_error_narrowing.py +++ b/tests/test_error_narrowing.py @@ -202,3 +202,94 @@ def test_e2e_creative_manifest_missing_width_height_narrows_to_image_asset() -> "CreativeManifest accepted invalid asset (missing width/height); " "regression in upstream pydantic discriminated-union behavior" ) + + +# ---- Expert-review hardening (PR #340 round 2) ---- + + +def test_narrow_handles_union_tag_invalid_as_discriminator_mismatch() -> None: + """``union_tag_invalid`` is pydantic-2's "tag found but invalid" + code (sibling of ``union_tag_not_found``). Must be treated as a + discriminator-mismatch signal so a variant with that error gets + eliminated from the candidate pool.""" + errors: list[dict[str, Any]] = [ + { + "type": "missing", + "loc": ("assets", "hero", "ImageAsset", "width"), + "msg": "Field required", + }, + # Invalid tag: pydantic found a discriminator field but its + # value doesn't match this variant's literal. + { + "type": "union_tag_invalid", + "loc": ("assets", "hero", "VideoAsset"), + "msg": "Input tag invalid for this union", + }, + ] + narrowed = narrow_union_errors(errors) + # ImageAsset wins — VideoAsset has the discriminator-mismatch signal. + assert len(narrowed) == 1 + assert "ImageAsset" in narrowed[0]["loc"] + + +def test_narrow_uses_innermost_variant_for_nested_unions() -> None: + """Nested unions: ``Union[Outer[Union[A, B]], C]``. Pydantic emits + ``loc`` like ``("field", "Outer", "inner", "A", "subfield")``. + The split MUST pick the LAST variant segment (innermost ``A``) so + A and B end up in different buckets and narrow correctly. Splitting + at the FIRST variant segment would collapse them and break the + narrowing.""" + errors: list[dict[str, Any]] = [ + # Inner variant A — clean (no literal_error). + { + "type": "missing", + "loc": ("field", "Outer", "inner", "A", "subfield"), + "msg": "Field required", + }, + # Inner variant B — discriminator mismatch. + { + "type": "literal_error", + "loc": ("field", "Outer", "inner", "B", "kind"), + "msg": "Input should be 'b'", + }, + ] + narrowed = narrow_union_errors(errors) + # A wins; B's literal_error eliminates it from candidates. + assert len(narrowed) == 1 + assert "A" in narrowed[0]["loc"] + assert "B" not in narrowed[0]["loc"] + + +def test_narrow_returns_defensive_copies() -> None: + """The returned list contains COPIES of the input dicts — + mutating the returned list MUST NOT affect the input. Cheap + insurance; matters because pydantic's ``ErrorDetails`` flows from + user input through a wide call graph.""" + original_err: dict[str, Any] = { + "type": "missing", + "loc": ("assets", "hero", "ImageAsset", "width"), + "msg": "Field required", + } + errors = [original_err] + narrowed = narrow_union_errors(errors) + # Mutate the returned dict — must not propagate to input. + narrowed[0]["msg"] = "MUTATED" + assert original_err["msg"] == "Field required" + + +def test_narrow_passes_through_oversized_input() -> None: + """DoS guard: above 500 errors we don't run narrowing — narrowing + is UX, not correctness, and a hostile request shouldn't get to + amplify CPU through the bucketing logic. Verifies the cap is + enforced and the original list comes back unchanged.""" + # Build 600 errors that WOULD be narrowable if we ran the algorithm. + errors: list[dict[str, Any]] = [ + { + "type": "missing", + "loc": ("field", f"Variant{i}", "x"), + "msg": "Field required", + } + for i in range(600) + ] + narrowed = narrow_union_errors(errors) + assert len(narrowed) == 600 # bypass — input came back as-is