diff --git a/examples/typed_handler_demo.py b/examples/typed_handler_demo.py index 9462f5d1..1e9d3158 100644 --- a/examples/typed_handler_demo.py +++ b/examples/typed_handler_demo.py @@ -21,7 +21,7 @@ from typing import Any -from adcp.server import ADCPHandler, ToolContext, serve +from adcp.server import ADCPHandler, ServeConfig, ToolContext, serve from adcp.types import ( GetAdcpCapabilitiesResponse, GetProductsRequest, @@ -87,4 +87,8 @@ async def get_products( # For production, wrap with an auth middleware (see # ``examples/mcp_with_auth_middleware.py``) and restrict the host # via reverse-proxy config or the ``port=`` / bind-host hooks. - serve(TypedSeller(), name="typed-demo-seller", transport="streamable-http") + + # ServeConfig bundles all options — IDE autocomplete shows each field + # with its doc. The legacy kwargs form still works unchanged. + config = ServeConfig(name="typed-demo-seller", transport="streamable-http") + serve(TypedSeller(), config=config) diff --git a/src/adcp/decisioning/serve.py b/src/adcp/decisioning/serve.py index f58e1012..bf36e86c 100644 --- a/src/adcp/decisioning/serve.py +++ b/src/adcp/decisioning/serve.py @@ -426,7 +426,11 @@ def serve( spec-compliance storyboards) pass ``True``. :param serve_kwargs: Forwarded to :func:`adcp.server.serve`. Use for ``host``, ``port``, ``transport``, ``test_controller``, - ``context_factory``, ``middleware``, ``validation``, etc. + ``context_factory``, ``middleware``, ``validation``, + ``config`` (:class:`adcp.server.ServeConfig` bundle), etc. + Pass ``config=ServeConfig(transport="a2a", ...)`` to supply + all server options as a single typed object rather than + individual kwargs. Pass ``validation=ValidationHookConfig(requests="strict", responses="strict")`` to enable schema-driven request/response validation against the bundled AdCP JSON schemas — sellers who diff --git a/src/adcp/server/__init__.py b/src/adcp/server/__init__.py index 6cbd71c9..dd5e4cf9 100644 --- a/src/adcp/server/__init__.py +++ b/src/adcp/server/__init__.py @@ -133,6 +133,7 @@ async def get_products(params, context=None): ASGIMiddlewareEntry, ContextFactory, RequestMetadata, + ServeConfig, SkillMiddleware, create_mcp_server, serve, @@ -181,6 +182,7 @@ async def get_products(params, context=None): "DISCOVERY_TOOLS", "MCPToolSet", "RequestMetadata", + "ServeConfig", "create_mcp_tools", "create_mcp_server", "get_tools_for_handler", diff --git a/src/adcp/server/serve.py b/src/adcp/server/serve.py index 4216378f..671f0bc8 100644 --- a/src/adcp/server/serve.py +++ b/src/adcp/server/serve.py @@ -21,7 +21,7 @@ async def get_adcp_capabilities(self, params, context=None): import logging import os import warnings -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal @@ -48,8 +48,6 @@ async def get_adcp_capabilities(self, params, context=None): # import via this module. if TYPE_CHECKING: - from collections.abc import Sequence - from a2a.server.tasks.push_notification_config_store import ( PushNotificationConfigStore, ) @@ -85,6 +83,94 @@ class RequestMetadata: request_id: str | None = None +@dataclass(frozen=True) +class ServeConfig: + """Configuration bundle for :func:`serve`. + + Consolidates the 22 keyword arguments of :func:`serve` into a single + named, IDE-friendly object. Use either the bundled form or individual + kwargs — not both:: + + # Bundled (cleaner IDE signature, easy to share / reuse) + serve(MyAgent(), config=ServeConfig(name="my-agent", transport="a2a")) + + # Individual kwargs (backwards-compatible, unchanged) + serve(MyAgent(), name="my-agent", transport="a2a") + + When *config* is supplied, all field values come from it; any individual + kwargs passed alongside are ignored. To vary a single field from a + shared base config use :func:`dataclasses.replace`:: + + base = ServeConfig(name="my-agent", validation=strict) + serve(handler, config=dataclasses.replace(base, transport="a2a")) + + **Transport-specific fields** — fields marked *(A2A only)* or + *(MCP only)* are silently ignored by the other transport. Setting + cross-transport fields triggers a ``UserWarning`` at boot. + """ + + # --- Identity / networking --- + name: str = "adcp-agent" + port: int | None = None + host: str | None = None + transport: str = "streamable-http" + + # --- MCP only --- + instructions: str | None = None + streaming_responses: bool = False + + # --- A2A only --- + task_store: TaskStore | None = None + push_config_store: PushNotificationConfigStore | None = None + message_parser: MessageParser | None = None + + # --- Shared infrastructure --- + test_controller: TestControllerStore | None = None + context_factory: ContextFactory | None = None + middleware: Sequence[SkillMiddleware] | None = None + asgi_middleware: Sequence[tuple[type, dict[str, Any]]] | None = None + advertise_all: bool = False + max_request_size: int | None = None + validation: ValidationHookConfig | None = None + + # --- Discovery manifest --- + base_url: str | None = None + specialisms: list[str] | None = None + description: str | None = None + + # --- Debug endpoints --- + enable_debug_endpoints: bool = False + debug_traffic_source: Callable[[], dict[str, int]] | None = None + + def __post_init__(self) -> None: + _a2a_only = ("task_store", "push_config_store", "message_parser") + _mcp_only = ("instructions", "streaming_responses") + if self.transport == "a2a": + mcp_set = sorted( + f for f in _mcp_only if getattr(self, f) not in (None, False) + ) + if mcp_set: + warnings.warn( + f"ServeConfig sets MCP-only fields {mcp_set} but " + f"transport='a2a'. These fields will be ignored.", + UserWarning, + stacklevel=3, + ) + elif self.transport not in ("both", "streamable-http", "sse", "stdio"): + pass # unknown transport — let serve() raise a clear error + elif self.transport not in ("a2a", "both"): + a2a_set = sorted( + f for f in _a2a_only if getattr(self, f) is not None + ) + if a2a_set: + warnings.warn( + f"ServeConfig sets A2A-only fields {a2a_set} but " + f"transport={self.transport!r}. These fields will be ignored.", + UserWarning, + stacklevel=3, + ) + + SkillMiddleware = Callable[ [str, dict[str, Any], ToolContext, Callable[[], Awaitable[Any]]], Awaitable[Any], @@ -423,6 +509,7 @@ def build_context(meta: RequestMetadata) -> ToolContext: def serve( handler: ADCPHandler[Any] | Any, *, + config: ServeConfig | None = None, name: str = "adcp-agent", port: int | None = None, host: str | None = None, @@ -461,6 +548,10 @@ def serve( Args: handler: An ADCPHandler subclass instance with your tool implementations. + config: Optional :class:`ServeConfig` bundle. When supplied, all + field values come from it and any individual kwargs passed + alongside are ignored. Use ``dataclasses.replace(config, ...)`` + to vary a single field from a shared base config. name: Server name shown to clients / in the A2A agent card. port: Port to listen on. Defaults to PORT env var, then 3001. transport: ``"streamable-http"`` (default, MCP), ``"a2a"``, or @@ -637,6 +728,33 @@ async def force_account_status(self, account_id, status): serve(MyAgent(), name="my-agent", test_controller=MyStore()) """ + # When a ServeConfig bundle is provided, extract all fields from it. + # Individual kwargs are ignored so that config= is the single source of + # truth. Callers who need to vary one field should use + # dataclasses.replace(config, field=value) rather than mixing styles. + if config is not None: + name = config.name + port = config.port + host = config.host + transport = config.transport + instructions = config.instructions + test_controller = config.test_controller + context_factory = config.context_factory + task_store = config.task_store + push_config_store = config.push_config_store + middleware = config.middleware + asgi_middleware = config.asgi_middleware + message_parser = config.message_parser + advertise_all = config.advertise_all + max_request_size = config.max_request_size + streaming_responses = config.streaming_responses + validation = config.validation + enable_debug_endpoints = config.enable_debug_endpoints + debug_traffic_source = config.debug_traffic_source + base_url = config.base_url + specialisms = config.specialisms + description = config.description + # Accept ADCPServerBuilder from adcp_server() decorator pattern from adcp.server.builder import ADCPServerBuilder diff --git a/tests/test_serve_config.py b/tests/test_serve_config.py new file mode 100644 index 00000000..a4f1e8c8 --- /dev/null +++ b/tests/test_serve_config.py @@ -0,0 +1,163 @@ +"""Tests for ServeConfig dataclass and its integration with serve(). + +ServeConfig provides a bundled alternative to passing 22 individual kwargs +to serve(). When config= is supplied, values come from the dataclass; +when it's absent, individual kwargs work as before. +""" + +from __future__ import annotations + +import dataclasses +import importlib +import warnings +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from adcp.server import ServeConfig +from adcp.server.base import ADCPHandler, ToolContext + +_serve_mod = importlib.import_module("adcp.server.serve") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _StubHandler(ADCPHandler[Any]): + async def get_products(self, params: dict[str, Any], ctx: ToolContext) -> dict[str, Any]: + return {"products": []} + + +# --------------------------------------------------------------------------- +# ServeConfig basic construction +# --------------------------------------------------------------------------- + + +def test_serve_config_defaults() -> None: + cfg = ServeConfig() + assert cfg.name == "adcp-agent" + assert cfg.transport == "streamable-http" + assert cfg.port is None + assert cfg.host is None + assert cfg.advertise_all is False + assert cfg.streaming_responses is False + assert cfg.enable_debug_endpoints is False + assert cfg.middleware is None + assert cfg.validation is None + + +def test_serve_config_frozen() -> None: + cfg = ServeConfig(name="my-agent") + with pytest.raises((dataclasses.FrozenInstanceError, TypeError)): + cfg.name = "other" # type: ignore[misc] + + +def test_serve_config_replace() -> None: + base = ServeConfig(name="base", transport="a2a") + updated = dataclasses.replace(base, name="updated") + assert updated.name == "updated" + assert updated.transport == "a2a" + + +def test_serve_config_exportable_from_adcp_server() -> None: + """ServeConfig must be importable from the public adcp.server namespace.""" + import adcp.server as _server + + assert _server.ServeConfig is ServeConfig + + +# --------------------------------------------------------------------------- +# ServeConfig transport-field warnings +# --------------------------------------------------------------------------- + + +def test_serve_config_warns_a2a_only_on_mcp_transport() -> None: + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + ServeConfig(transport="streamable-http", task_store=MagicMock()) + messages = [str(w.message) for w in caught if issubclass(w.category, UserWarning)] + assert any("A2A-only" in m for m in messages), messages + + +def test_serve_config_warns_mcp_only_on_a2a_transport() -> None: + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + ServeConfig(transport="a2a", instructions="hello") + messages = [str(w.message) for w in caught if issubclass(w.category, UserWarning)] + assert any("MCP-only" in m for m in messages), messages + + +def test_serve_config_no_warning_on_both_transport() -> None: + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + ServeConfig(transport="both", task_store=MagicMock(), instructions="hi") + user_warnings = [w for w in caught if issubclass(w.category, UserWarning)] + assert not user_warnings, "No warning expected for transport='both'" + + +def test_serve_config_no_warning_clean_config() -> None: + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + ServeConfig(name="my-agent", transport="a2a") + user_warnings = [w for w in caught if issubclass(w.category, UserWarning)] + assert not user_warnings + + +# --------------------------------------------------------------------------- +# serve() respects config= over default kwargs +# --------------------------------------------------------------------------- + + +def test_serve_config_name_propagates() -> None: + handler = _StubHandler() + cfg = ServeConfig(name="from-config", transport="streamable-http", port=9999) + + with patch.object(_serve_mod, "_serve_mcp") as mock_mcp: + _serve_mod.serve(handler, config=cfg) + + mock_mcp.assert_called_once() + _, kwargs = mock_mcp.call_args + assert kwargs.get("name") == "from-config" + + +def test_serve_config_kwargs_ignored_when_config_provided() -> None: + """When config= is supplied, individual kwargs must be ignored.""" + handler = _StubHandler() + cfg = ServeConfig(name="from-config", transport="streamable-http", port=9999) + + with patch.object(_serve_mod, "_serve_mcp") as mock_mcp: + # Pass a contradicting name kwarg — config should win + _serve_mod.serve(handler, config=cfg, name="ignored-name") + + mock_mcp.assert_called_once() + _, kwargs = mock_mcp.call_args + assert kwargs.get("name") == "from-config", ( + "config.name should override the per-kwarg name when config= is provided" + ) + + +def test_serve_without_config_uses_kwargs() -> None: + """Without config=, individual kwargs must still reach the transport.""" + handler = _StubHandler() + + with patch.object(_serve_mod, "_serve_mcp") as mock_mcp: + _serve_mod.serve(handler, name="kwarg-name", transport="streamable-http") + + mock_mcp.assert_called_once() + _, kwargs = mock_mcp.call_args + assert kwargs.get("name") == "kwarg-name" + + +def test_serve_config_advertise_all_propagates() -> None: + handler = _StubHandler() + cfg = ServeConfig(transport="streamable-http", advertise_all=True) + + with patch.object(_serve_mod, "_serve_mcp") as mock_mcp: + _serve_mod.serve(handler, config=cfg) + + mock_mcp.assert_called_once() + _, kwargs = mock_mcp.call_args + assert kwargs.get("advertise_all") is True