Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/adcp/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ async def get_products(params, context=None):
update_media_buy_response,
)
from adcp.server.serve import (
ASGIMiddlewareEntry,
ContextFactory,
RequestMetadata,
SkillMiddleware,
Expand Down Expand Up @@ -186,6 +187,7 @@ async def get_products(params, context=None):
# A2A integration
"ADCPAgentExecutor",
"MessageParser",
"ASGIMiddlewareEntry",
"SkillMiddleware",
"create_a2a_server",
# Bearer-token auth middleware (seller-facing recipe)
Expand Down
92 changes: 66 additions & 26 deletions src/adcp/server/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,19 @@ def build_context(meta: RequestMetadata) -> ToolContext:
mcp = create_mcp_server(MyAgent(), context_factory=build_context)
"""

ASGIMiddlewareEntry = tuple[Callable[..., Any], dict[str, Any]] | Callable[..., Any]
"""A single ASGI middleware entry for :func:`serve`'s ``asgi_middleware`` param.

Each entry is either:

- A ``(callable, kwargs)`` tuple — invoked as ``callable(app, **kwargs)``.
Both plain class constructors and :func:`functools.partial` instances work
as the first element.
- A bare callable factory ``f(app) -> app`` — invoked as ``factory(app)``.

Both forms can be mixed in the same list.
"""


def serve(
handler: ADCPHandler[Any] | Any,
Expand All @@ -420,7 +433,7 @@ def serve(
task_store: TaskStore | None = None,
push_config_store: PushNotificationConfigStore | None = None,
middleware: Sequence[SkillMiddleware] | None = None,
asgi_middleware: Sequence[tuple[type, dict[str, Any]]] | None = None,
asgi_middleware: Sequence[ASGIMiddlewareEntry] | None = None,
message_parser: MessageParser | None = None,
advertise_all: bool = False,
max_request_size: int | None = None,
Expand Down Expand Up @@ -472,23 +485,40 @@ def serve(
rate limiting, tracing. Composes outermost-first. See
:data:`SkillMiddleware` for the signature and composition
semantics.
asgi_middleware: Optional sequence of ``(MiddlewareClass, kwargs)``
tuples — Starlette-shape ASGI middleware applied to the
outer HTTP app before uvicorn binds. Use for cross-cutting
HTTP concerns the SDK does not own: tenant resolution
(:class:`adcp.server.SubdomainTenantMiddleware`), CORS,
request-id propagation, IP allowlists, custom auth.
Composes outermost-first — the first entry sees every
request before later entries. Each class is invoked as
``cls(app, **kwargs)``. Applied on every HTTP transport
(``streamable-http``, ``a2a``, ``both``); ignored on
``stdio``.
asgi_middleware: Optional sequence of ASGI middleware entries
applied to the outer HTTP app before uvicorn binds. Use for
cross-cutting HTTP concerns the SDK does not own: tenant
resolution (:class:`adcp.server.SubdomainTenantMiddleware`),
CORS, request-id propagation, IP allowlists, custom auth.
Composes outermost-first — the first entry sees every request
before later entries. Applied on every HTTP transport
(``streamable-http``, ``sse``, ``a2a``, ``both``); ignored
on ``stdio``.

Each entry is either a ``(MiddlewareClass, kwargs)`` tuple
invoked as ``cls(app, **kwargs)``, or a callable factory
``f(app) -> app``. Both forms can appear in the same list.

Middleware sees ``lifespan`` and ``websocket`` scopes in
addition to ``http`` — guard non-HTTP scopes by passing
them through unchanged (``if scope['type'] != 'http':
await self.app(scope, receive, send); return``) so the
framework's lifespan composition still runs.

Example (tuple form)::

from starlette.middleware.cors import CORSMiddleware
serve(handler, asgi_middleware=[
(CORSMiddleware, {"allow_origins": ["*"]}),
])

Example (callable factory form, e.g. with ``functools.partial``)::

import functools
from starlette.middleware.cors import CORSMiddleware
serve(handler, asgi_middleware=[
functools.partial(CORSMiddleware, allow_origins=["*"]),
])
message_parser: Optional
:data:`~adcp.server.a2a_server.MessageParser` callable for
alternative A2A wire shapes (A2A transport only). The
Expand Down Expand Up @@ -690,11 +720,11 @@ async def force_account_status(self, account_id, status):


def _prepend_debug_endpoint(
asgi_middleware: Sequence[tuple[type, dict[str, Any]]] | None,
asgi_middleware: Sequence[ASGIMiddlewareEntry] | None,
*,
enable_debug_endpoints: bool,
debug_traffic_source: Callable[[], dict[str, int]] | None,
) -> Sequence[tuple[type, dict[str, Any]]] | None:
) -> Sequence[ASGIMiddlewareEntry] | None:
"""Prepend :class:`DebugTrafficMiddleware` to the asgi_middleware
sequence when debug endpoints are enabled.

Expand Down Expand Up @@ -728,21 +758,27 @@ def _prepend_debug_endpoint(

def _apply_asgi_middleware(
app: Any,
asgi_middleware: Sequence[tuple[type, dict[str, Any]]] | None,
asgi_middleware: Sequence[ASGIMiddlewareEntry] | None,
) -> Any:
"""Wrap ``app`` with operator-supplied Starlette-style ASGI middleware.

Each entry is ``(MiddlewareClass, kwargs)`` and is invoked as
``cls(app, **kwargs)``. Composition is outermost-first — the first
entry sees every request before later entries — so we wrap in
reverse, matching :meth:`Starlette.add_middleware` semantics.
Each entry is either ``(MiddlewareClass, kwargs)`` invoked as
``cls(app, **kwargs)``, or a callable factory ``f(app) -> app`` invoked
as ``factory(app)``. Both forms can appear in the same list. Composition
is outermost-first — the first entry sees every request before later
entries — so we wrap in reverse, matching :meth:`Starlette.add_middleware`
semantics.

No-op when the sequence is empty or ``None``.
"""
if not asgi_middleware:
return app
for cls, kwargs in reversed(list(asgi_middleware)):
app = cls(app, **kwargs)
for entry in reversed(list(asgi_middleware)):
if isinstance(entry, tuple):
cls, kwargs = entry
app = cls(app, **kwargs)
else:
app = entry(app)
return app


Expand Down Expand Up @@ -952,7 +988,7 @@ def _serve_mcp(
test_controller: TestControllerStore | None,
context_factory: ContextFactory | None = None,
middleware: Sequence[SkillMiddleware] | None = None,
asgi_middleware: Sequence[tuple[type, dict[str, Any]]] | None = None,
asgi_middleware: Sequence[ASGIMiddlewareEntry] | None = None,
advertise_all: bool = False,
max_request_size: int | None = None,
streaming_responses: bool = False,
Expand Down Expand Up @@ -985,24 +1021,28 @@ def _serve_mcp(
_run_mcp_http(
mcp,
transport=transport,
max_request_size=max_request_size,
asgi_middleware=asgi_middleware,
max_request_size=max_request_size,
discovery_name=name,
discovery_base_url=base_url,
discovery_specialisms=specialisms,
discovery_description=description,
)
else:
# stdio — no listening socket, nothing to configure.
if asgi_middleware:
logger.warning(
"asgi_middleware is ignored on transport='stdio'; " "ASGI middleware will not run"
)
mcp.run(transport=transport)


def _run_mcp_http(
mcp: Any,
*,
transport: str,
asgi_middleware: Sequence[ASGIMiddlewareEntry] | None = None,
max_request_size: int | None = None,
asgi_middleware: Sequence[tuple[type, dict[str, Any]]] | None = None,
discovery_name: str = "adcp-agent",
discovery_base_url: str | None = None,
discovery_specialisms: list[str] | None = None,
Expand Down Expand Up @@ -1080,7 +1120,7 @@ def _serve_a2a(
task_store: TaskStore | None = None,
push_config_store: PushNotificationConfigStore | None = None,
middleware: Sequence[SkillMiddleware] | None = None,
asgi_middleware: Sequence[tuple[type, dict[str, Any]]] | None = None,
asgi_middleware: Sequence[ASGIMiddlewareEntry] | None = None,
message_parser: MessageParser | None = None,
advertise_all: bool = False,
max_request_size: int | None = None,
Expand Down Expand Up @@ -1287,7 +1327,7 @@ def _serve_mcp_and_a2a(
task_store: TaskStore | None = None,
push_config_store: PushNotificationConfigStore | None = None,
middleware: Sequence[SkillMiddleware] | None = None,
asgi_middleware: Sequence[tuple[type, dict[str, Any]]] | None = None,
asgi_middleware: Sequence[ASGIMiddlewareEntry] | None = None,
message_parser: MessageParser | None = None,
advertise_all: bool = False,
max_request_size: int | None = None,
Expand Down
58 changes: 56 additions & 2 deletions tests/test_serve_asgi_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
Operators wiring tenant routing, CORS, request-id propagation, and
custom auth use this kwarg to layer Starlette-style ASGI middleware
on the outer HTTP app before uvicorn binds. The kwarg accepts a
sequence of ``(MiddlewareClass, kwargs)`` tuples and composes
outermost-first.
sequence of ``(MiddlewareClass, kwargs)`` tuples, callable factories,
or a mix of both, and composes outermost-first.
"""

from __future__ import annotations

import functools

from adcp.server.serve import _apply_asgi_middleware


Expand Down Expand Up @@ -65,3 +67,55 @@ def test_apply_asgi_middleware_passes_kwargs_through():
assert isinstance(wrapped, _TaggingMiddleware)
assert wrapped.name == "audit"
assert wrapped.app is app


def test_apply_asgi_middleware_callable_factory():
"""Callable factory form ``f(app) -> app`` is accepted."""
app = _NoOpAsgi()

def cors_factory(inner):
return _TaggingMiddleware(inner, name="cors")

wrapped = _apply_asgi_middleware(app, [cors_factory])
assert isinstance(wrapped, _TaggingMiddleware)
assert wrapped.name == "cors"
assert wrapped.app is app


def test_apply_asgi_middleware_callable_factory_with_partial():
"""``functools.partial`` is a valid callable factory."""
app = _NoOpAsgi()
factory = functools.partial(_TaggingMiddleware, name="partial-cors")
wrapped = _apply_asgi_middleware(app, [factory])
assert isinstance(wrapped, _TaggingMiddleware)
assert wrapped.name == "partial-cors"
assert wrapped.app is app


def test_apply_asgi_middleware_mixed_tuple_and_callable_preserves_order():
"""Mixed list composes outermost-first regardless of entry type.

Given ``[tuple_entry("outer"), callable("middle"), tuple_entry("inner")]``,
the result must be outer → middle → inner → app, verified by walking
the ``.app`` chain.
"""
app = _NoOpAsgi()

def middle_factory(inner):
return _TaggingMiddleware(inner, name="middle")

wrapped = _apply_asgi_middleware(
app,
[
(_TaggingMiddleware, {"name": "outer"}),
middle_factory,
(_TaggingMiddleware, {"name": "inner"}),
],
)
assert isinstance(wrapped, _TaggingMiddleware)
assert wrapped.name == "outer"
assert isinstance(wrapped.app, _TaggingMiddleware)
assert wrapped.app.name == "middle"
assert isinstance(wrapped.app.app, _TaggingMiddleware)
assert wrapped.app.app.name == "inner"
assert wrapped.app.app.app is app
Loading