diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 8c746fe92..6368da0ac 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -321,6 +321,7 @@ def on_event( @controller.command(name="run") @click.argument("name", type=str) @click.argument("parameters", type=ParametersType(), default={}, required=False) +@click.option("--ws", type=bool, is_flag=True, default=False) @click.option( "--foreground/--background", "--fg/--bg", type=bool, is_flag=True, default=True ) @@ -348,6 +349,7 @@ def run_plan( name: str, timeout: float | None, foreground: bool, + ws: bool, instrument_session: str, parameters: TaskParameters, ) -> None: @@ -374,7 +376,13 @@ def on_event(event: AnyEvent) -> None: elif isinstance(event, DataEvent): callback(event.name, event.doc) - resp = client.run_task(task, on_event=on_event) + client.add_callback(on_event) + + if ws: + resp = client.run_blocking(task) + else: + resp = client.run_task(task) + match resp.result: case TaskResult(result=None, type="NoneType"): print("Plan succeeded") diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 20de15892..3d6e3bf6a 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -459,6 +459,26 @@ def get_active_task(self) -> WorkerTask: return self.active_task + @start_as_current_span(TRACER, "request") + def run_blocking( + self, request: TaskRequest, on_event: OnAnyEvent | None = None + ) -> TaskStatus: + for event in self._rest.run_blocking(request): + if on_event is not None: + on_event(event) + for cb in self._callbacks.values(): + try: + cb(event) + except Exception as e: + log.error(f"Callback ({cb}) failed for event: {event}", exc_info=e) + if isinstance(event, WorkerEvent) and event.is_complete(): + if event.task_status is None: + raise BlueskyRemoteControlError( + "Server completed without task status" + ) + return event.task_status + raise BlueskyRemoteControlError("Connection closed before plan completed.") + @start_as_current_span(TRACER, "task", "timeout") def run_task( self, diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index 0bddb5c87..7cb9cd830 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -1,6 +1,6 @@ import json import logging -from collections.abc import Callable, Mapping +from collections.abc import Callable, Iterable, Mapping from typing import Any, Literal, TypeVar import requests @@ -10,12 +10,15 @@ get_tracer, start_as_current_span, ) -from pydantic import BaseModel, TypeAdapter, ValidationError +from pydantic import BaseModel, TypeAdapter, ValidationError, WebsocketUrl from pydantic_core import PydanticSerializationError +from websockets.exceptions import InvalidStatus +from websockets.sync.client import connect from blueapi import __version__ from blueapi.client import client from blueapi.config import RestConfig +from blueapi.core.bluesky_types import DataEvent from blueapi.service.authentication import JWTAuth, SessionManager from blueapi.service.model import ( DeviceModel, @@ -31,7 +34,15 @@ TasksListResponse, WorkerTask, ) +from blueapi.service.protocol import ( + ControlResponse, + InvalidArgs, + PlanNotFound, + Submit, + Update, +) from blueapi.worker import TrackableTask, WorkerState +from blueapi.worker.event import ProgressEvent, WorkerEvent T = TypeVar("T") @@ -39,6 +50,8 @@ LOGGER = logging.getLogger(__name__) +USER_AGENT = f"blueapi cli {__version__}" + class BlueskyRequestError(Exception): """An error response from the blueapi server.""" @@ -86,8 +99,8 @@ def __init__(self, target_type: type) -> None: class ParameterError(BaseModel): loc: list[str | int] - msg: str - type: str + msg: str | None + type: str | None input: Any def field(self): @@ -307,14 +320,15 @@ def _request_and_deserialize( ) -> T: url = self._config.url.unicode_string().removesuffix("/") + suffix # Get the trace context to propagate to the REST API - carr = get_context_propagator() + headers = get_context_propagator() + headers["User-Agent"] = USER_AGENT try: response = self._pool.request( method, url, json=data, params=params, - headers=carr, + headers=headers, auth=JWTAuth(self._session_manager), ) except requests.exceptions.ConnectionError as ce: @@ -340,6 +354,53 @@ def _request_and_deserialize( ) return deserialized + def run_blocking( + self, req: TaskRequest + ) -> Iterable[DataEvent | WorkerEvent | ProgressEvent]: + url = self._ws_address().unicode_string().rstrip("/") + "/api/v2/run_plan" + headers = get_context_propagator() + if self._session_manager: + auth = self._session_manager.get_valid_access_token() + headers["Authorization"] = f"Bearer {auth}" + try: + with connect( + url, + additional_headers=headers, + user_agent_header=USER_AGENT, + ) as ws: + ws.send(Submit(task=req).model_dump_json()) + for message in ws: + event = ControlResponse.validate_json(message) + match event: + case Update(data=data): + yield data + case InvalidArgs(errors=errors): + raise InvalidParametersError( + [ + ParameterError( + loc=e.loc, msg=e.msg, type=e.type, input=e.input + ) + for e in errors + ] + ) + case PlanNotFound(plan_name=name): + raise UnknownPlanError(message=name) + except InvalidStatus as istat: + match istat.response.status_code: + case 401 | 403: + raise UnauthorisedAccessError() from None + print(vars(istat)) + return + + def _ws_address(self) -> WebsocketUrl: + api = self._config.url + if api.host is None: + raise ValueError("No host configured") + scheme = "ws" if api.scheme == "http" else "wss" + return WebsocketUrl.build( + scheme=scheme, host=api.host, port=api.port, path=api.path + ) + # https://github.com/DiamondLightSource/blueapi/issues/1256 - remove before 2.0 def __getattr__(name: str): diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index 6761256de..614f42b7c 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -15,7 +15,8 @@ import httpx import jwt import requests -from fastapi import Depends, HTTPException, Request +from fastapi import Cookie, Depends, Header, HTTPException +from fastapi.requests import HTTPConnection from fastapi.security.utils import get_authorization_scheme_param from pydantic import TypeAdapter from requests.auth import AuthBase @@ -278,14 +279,16 @@ def sync_auth_flow(self, request): yield request -def unchecked_bearer_token(req: Request) -> str | None: +def unchecked_bearer_token( + auth_header: str | None = Header(alias="Authorization", default=None), + auth_cookie: str | None = Cookie(alias="Authorization", default=None), +) -> str | None: """Get bearer token value from authorization header""" # This is an abridged version of the same feature of # OAuth2AuthorizationCodeBearer from fastapi. Replicating here prevents # passing unused configuration and means the schema does not include auth # details for servers that do not support it. - auth = req.headers.get("Authorization") - scheme, param = get_authorization_scheme_param(auth) + scheme, param = get_authorization_scheme_param(auth_header or auth_cookie) if scheme.casefold() != "bearer": return None return param.strip() @@ -303,7 +306,7 @@ def build_access_token_check(config: OIDCConfig): """ jwkclient = jwt.PyJWKClient(config.jwks_uri) - def validate_bearer_token(request: Request, token: UncheckedBearerToken): + def validate_bearer_token(request: HTTPConnection, token: UncheckedBearerToken): """Check that a bearer token is valid and inject into request state""" if not token: raise HTTPException( @@ -326,7 +329,7 @@ def validate_bearer_token(request: Request, token: UncheckedBearerToken): return validate_bearer_token -def access_token(request: Request) -> Mapping[str, Any] | None: +def access_token(request: HTTPConnection) -> Mapping[str, Any] | None: """Get the decoded and verified access token of the user making the request""" return getattr(request.state, "decoded_access_token", None) diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 335d00477..65c59242b 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -1,5 +1,8 @@ +import logging from collections.abc import Mapping +from dataclasses import dataclass from functools import cache +from multiprocessing.connection import Connection from typing import Any from bluesky.callbacks.tiled_writer import TiledWriter @@ -9,6 +12,7 @@ from blueapi.cli.scratch import get_python_environment from blueapi.config import ApplicationConfig, OIDCConfig, ServiceAccount, StompConfig +from blueapi.core.bluesky_types import DataEvent from blueapi.core.context import BlueskyContext from blueapi.core.event import EventStream from blueapi.log import set_up_logging @@ -22,14 +26,14 @@ WorkerTask, ) from blueapi.utils.serialization import access_blob -from blueapi.worker.event import TaskStatusEnum, WorkerEvent, WorkerState +from blueapi.worker.event import ProgressEvent, TaskStatusEnum, WorkerEvent, WorkerState from blueapi.worker.task import Task from blueapi.worker.task_worker import TaskWorker, TrackableTask """This module provides interface between web application and underlying Bluesky context and worker""" - +LOGGER = logging.getLogger(__name__) _CONFIG: ApplicationConfig = ApplicationConfig() @@ -286,3 +290,37 @@ def get_python_env( """Retrieve information about the Python environment""" scratch = config().scratch return get_python_environment(config=scratch, name=name, source=source) + + +@dataclass +class SubHandles: + worker: int + progress: int + data: int + + +def pipe_events(tx: Connection) -> SubHandles: + tw = worker() + + def handler( + worker_event: WorkerEvent | DataEvent | ProgressEvent, + _cor_id: str | None, + ) -> None: + + try: + tx.send(worker_event) + except BrokenPipeError: + LOGGER.warning("Sending event to broken pipe") + pass + + w = tw.worker_events.subscribe(handler) + d = tw.data_events.subscribe(handler) + p = tw.progress_events.subscribe(handler) + return SubHandles(worker=w, data=d, progress=p) + + +def unpipe_events(hnd: SubHandles) -> None: + tw = worker() + tw.worker_events.unsubscribe(hnd.worker) + tw.data_events.unsubscribe(hnd.data) + tw.progress_events.unsubscribe(hnd.progress) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index f3c343564..a01ddbcfd 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -14,6 +14,8 @@ HTTPException, Request, Response, + WebSocket, + WebSocketDisconnect, status, ) from fastapi.datastructures import Address @@ -32,14 +34,27 @@ from super_state_machine.errors import TransitionError from blueapi.config import ApplicationConfig, OIDCConfig, Tag +from blueapi.core.bluesky_types import DataEvent from blueapi.service import interface -from blueapi.service.authentication import Fedid, build_access_token_check +from blueapi.service.authentication import ( + Fedid, + build_access_token_check, +) from blueapi.service.middleware import ( ObservabilityContextPropagator, VersionHeaders, + WebsocketTracing, +) +from blueapi.service.protocol import ( + ControlRequest, + InvalidArgs, + PlanNotFound, + ServerBusy, + Update, ) from blueapi.worker import TrackableTask, WorkerState -from blueapi.worker.event import TaskStatusEnum +from blueapi.worker.event import ProgressEvent, TaskStatusEnum, WorkerEvent +from blueapi.worker.worker_errors import WorkerBusyError from .authorization import OpaClient, validate_tiled_config from .model import ( @@ -66,6 +81,9 @@ TRACER = get_tracer("interface") +AnyEvent = WorkerEvent | DataEvent | ProgressEvent + + def _runner() -> WorkerDispatcher: """Intended to be used only with FastAPI Depends""" if RUNNER is None: @@ -109,6 +127,7 @@ async def inner(app: FastAPI): open_router = APIRouter() secure_router = APIRouter(deprecated=True) secure_router_v1 = APIRouter(prefix="/api/v1") +secure_router_v2 = APIRouter(prefix="/api/v2") def get_app(config: ApplicationConfig): @@ -130,12 +149,14 @@ def get_app(config: ApplicationConfig): } app.include_router(open_router) app.include_router(secure_router_v1, dependencies=dependencies) + app.include_router(secure_router_v2, dependencies=dependencies) app.include_router(secure_router, dependencies=dependencies) app.add_exception_handler(KeyError, on_key_error_404) app.add_exception_handler(jwt.PyJWTError, on_token_error_401) app.add_middleware(ObservabilityContextPropagator) app.add_middleware(VersionHeaders) + app.add_middleware(WebsocketTracing) app.middleware("http")(log_request_details) if config.api.cors: app.add_middleware( @@ -564,6 +585,61 @@ def logout(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> Response: ) +@secure_router_v2.websocket("/run_plan") +async def run_plan( + ws: WebSocket, runner: Annotated[WorkerDispatcher, Depends(_runner)], user: Fedid +): + LOGGER.info("Starting WS plan as %s", user) + await ws.accept() + rq = await ws.receive_text() + try: + task_request = ControlRequest.validate_json(rq) + except ValidationError: + LOGGER.error("Failed to deserialize request", exc_info=True) + await ws.close(code=1007, reason="Invalid Request") + return + LOGGER.info("Plan request: %s", task_request) + + try: + task_id: str = runner.run( + interface.submit_task, task_request.task, {"user": user} + ) + LOGGER.info("Task ID: %s", task_id) + except ValidationError as ve: + LOGGER.info("Plan args not valid: %s - %s", task_request, ve) + await ws.send_text(InvalidArgs.from_validation_error(ve).model_dump_json()) + await ws.close(code=4002, reason="Invalid Args") + return + except KeyError as ke: + LOGGER.error("Plan %r not found", ke.args[0]) + await ws.send_text(PlanNotFound(plan_name=ke.args[0]).model_dump_json()) + await ws.close(code=4001, reason="unknown plan") + return + + try: + with runner.event_pipe() as events: + LOGGER.info("Created event pipe") + runner.run(interface.begin_task, task=WorkerTask(task_id=task_id)) + async for evt in events: + LOGGER.debug("Event: %s", evt) + await ws.send_text(Update(data=evt).model_dump_json()) + if isinstance(evt, WorkerEvent) and evt.is_complete(): + LOGGER.info("End of stream") + break + except WorkerBusyError: + LOGGER.error("Worker was busy") + await ws.send_text(ServerBusy().model_dump_json()) + await ws.close(code=1013, reason="Worker busy") + except WebSocketDisconnect: + LOGGER.info("Client disconnected") + runner.run( + interface.cancel_active_task, failure=True, reason="Client disconnected" + ) + else: + LOGGER.info("Plan complete") + await ws.close() + + @start_as_current_span(TRACER, "config") def start(config: ApplicationConfig): import uvicorn diff --git a/src/blueapi/service/middleware.py b/src/blueapi/service/middleware.py index b31fe0fb9..cb94747c5 100644 --- a/src/blueapi/service/middleware.py +++ b/src/blueapi/service/middleware.py @@ -1,4 +1,6 @@ import logging +import uuid +from collections.abc import Iterable from opentelemetry.context import attach from opentelemetry.propagate import get_global_textmap @@ -8,6 +10,7 @@ from blueapi.config import ApplicationConfig OBS_LOGGER = logging.getLogger("blueapi.service.middleware.observability") +WS_LOGGER = logging.getLogger("blueapi.service.middleware.websocket") CONTEXT_HEADER = ApplicationConfig.CONTEXT_HEADER.encode() VENDOR_CONTEXT_HEADER = ApplicationConfig.VENDOR_CONTEXT_HEADER.encode() @@ -56,3 +59,79 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): attach(get_global_textmap().extract(carrier)) return await self.app(scope, receive, send) + + +Header = tuple[bytes, bytes] + + +def _redact_headers(headers: list[Header] | None) -> Iterable[Header]: + for key, value in headers or []: + if key == b"authorization": + if (space := value.find(b" ")) >= 0: + value = value[:space] + b" [REDACTED]" + yield (key, value) + + +class WebsocketTracing: + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + active = WS_LOGGER.isEnabledFor(logging.DEBUG) + + if scope.get("type") != "websocket" or not active: + return await self.app(scope, receive, send) + + conn_id = uuid.uuid4() + client: tuple[str, int] = scope.get("client", ("unknown", 0)) + extra = {"conn": conn_id, "client": client} + + WS_LOGGER.debug( + "New Connection from %r", + {**scope, "headers": list(_redact_headers(scope.get("headers")))}, + extra=extra, + ) + + async def local_send(msg: Message): + match msg.get("type"): + case "websocket.send": + WS_LOGGER.debug("Sending: %r", msg.get("text"), extra=extra) + case "websocket.accept": + WS_LOGGER.debug( + "Accepting websocket - sending headers: %r", + msg.get("headers"), + extra=extra, + ) + case "websocket.close": + WS_LOGGER.debug( + "Closing with code: %r, reason: %r", + msg.get("code"), + msg.get("reason"), + extra=extra, + ) + case "websocket.http.response.start": + WS_LOGGER.debug( + "HTTP Response: status=%r, headers=%r", + msg.get("status"), + msg.get("headers"), + extra=extra, + ) + case "websocket.http.response.body": + WS_LOGGER.debug( + "HTTP Response Content: %r", msg.get("body"), extra=extra + ) + case _: + WS_LOGGER.debug("Sending other: %r", msg, extra=extra) + + await send(msg) + + async def local_receive() -> Message: + message = await receive() + match message.get("type"): + case "websocket.receive": + WS_LOGGER.debug("Received: %r", message.get("text")) + case "websocket.connect": + WS_LOGGER.debug("New connection from %s:%d", *client) + return message + + return await self.app(scope, local_receive, local_send) diff --git a/src/blueapi/service/protocol.py b/src/blueapi/service/protocol.py new file mode 100644 index 000000000..67924318a --- /dev/null +++ b/src/blueapi/service/protocol.py @@ -0,0 +1,94 @@ +""" +The application level sub-protocol used to communicate between the server and +client when running plans via websockets +""" + +# Client to server +# * Submit task +# * Pause +# * Resume +# * Abort +# +# Server to client +# * Plan not found +# * Args not valid +# * Server busy +# * Event update + +from typing import Annotated, Any, Literal, Self + +from pydantic import BaseModel, Field, TypeAdapter, ValidationError + +from blueapi.core.bluesky_types import DataEvent +from blueapi.service.model import TaskRequest +from blueapi.worker.event import ProgressEvent, WorkerEvent + + +class ArgumentError(BaseModel): + loc: list[str | int] + msg: str | None + type: str | None + input: Any + + +class Submit(BaseModel): + kind: Literal["submit"] = "submit" + task: TaskRequest + + +class Pause(BaseModel): + kind: Literal["pause"] = "pause" + + +class Resume(BaseModel): + kind: Literal["resume"] = "resume" + + +class Abort(BaseModel): + kind: Literal["abort"] = "abort" + reason: str | None = None + + +ControlRequest = TypeAdapter( + Annotated[Submit | Pause | Resume | Abort, Field(discriminator="kind")] +) + + +class PlanNotFound(BaseModel): + kind: Literal["plan_not_found"] = "plan_not_found" + plan_name: str + + +class InvalidArgs(BaseModel): + kind: Literal["invalid_args"] = "invalid_args" + errors: list[ArgumentError] + + @classmethod + def from_validation_error(cls, e: ValidationError) -> Self: + errors = [ + ArgumentError( + loc=["body", "params", *err.get("loc", [])], + msg=err.get("msg", None), + type=err.get("type", None), + # Input is not listed as required but is useful to have if available + input=err.get("input", None), + ) + for err in e.errors() + ] + return cls(errors=errors) + + +class ServerBusy(BaseModel): + kind: Literal["busy"] = "busy" + + +class Update(BaseModel): + kind: Literal["update"] = "update" + data: WorkerEvent | DataEvent | ProgressEvent + + +ControlResponse = TypeAdapter( + Annotated[ + PlanNotFound | InvalidArgs | ServerBusy | Update, Field(discriminator="kind") + ] +) diff --git a/src/blueapi/service/runner.py b/src/blueapi/service/runner.py index 2b5a5f37f..83153ed5f 100644 --- a/src/blueapi/service/runner.py +++ b/src/blueapi/service/runner.py @@ -1,10 +1,12 @@ +import asyncio import inspect import logging import signal import uuid -from collections.abc import Callable +from collections.abc import AsyncIterator, Callable from importlib import import_module from multiprocessing import Pool, set_start_method +from multiprocessing.connection import Connection, Pipe from multiprocessing.pool import Pool as PoolClass from typing import Any, ParamSpec, TypeVar @@ -18,8 +20,11 @@ from pydantic import TypeAdapter from blueapi.config import ApplicationConfig -from blueapi.service.interface import setup, teardown +from blueapi.core.bluesky_types import DataEvent +from blueapi.service import interface +from blueapi.service.interface import SubHandles, setup, teardown from blueapi.service.model import EnvironmentResponse +from blueapi.worker.event import ProgressEvent, WorkerEvent # The default multiprocessing start method is fork set_start_method("spawn", force=True) @@ -145,11 +150,57 @@ def run( kwargs, ) + def event_pipe(self): + return EventPipe(self) + @property def state(self) -> EnvironmentResponse: return self._state +class EventStream: + def __init__(self, rx: Connection): + self._rx = rx + + def __aiter__(self) -> AsyncIterator[WorkerEvent | DataEvent | ProgressEvent]: + return self + + async def __anext__(self) -> WorkerEvent | DataEvent | ProgressEvent: + data_available = asyncio.Event() + asyncio.get_event_loop().add_reader(self._rx.fileno(), data_available.set) + try: + while not self._rx.poll(): + await data_available.wait() + data_available.clear() + return self._rx.recv() + except BrokenPipeError: + raise StopAsyncIteration() from None + finally: + asyncio.get_event_loop().remove_reader(self._rx.fileno()) + + +class EventPipe: + runner: WorkerDispatcher + handles: list[tuple[SubHandles, Connection]] + + def __init__(self, runner: WorkerDispatcher): + self.runner = runner + self.handles = [] + + def __enter__(self) -> EventStream: + tx, rx = Pipe() + hnd = self.runner.run(interface.pipe_events, tx) + LOGGER.debug("Subscribing new event pipe: %s", hnd) + self.handles.append((hnd, tx)) + return EventStream(rx) + + def __exit__(self, *exc): + hnd, conn = self.handles.pop() + LOGGER.debug("Unsubscribing event pipe: %s", hnd) + conn.close() + self.runner.run(interface.unpipe_events, hnd) + + class InvalidRunnerStateError(Exception): def __init__(self, message): super().__init__(message) diff --git a/tests/unit_tests/cli/test_cli.py b/tests/unit_tests/cli/test_cli.py index 3e27cfc98..fe71825c1 100644 --- a/tests/unit_tests/cli/test_cli.py +++ b/tests/unit_tests/cli/test_cli.py @@ -7,7 +7,6 @@ from pathlib import Path from textwrap import dedent from typing import Any, TypeVar -from unittest import mock from unittest.mock import Mock, patch import pytest @@ -385,9 +384,9 @@ def test_run_plan_feedback( main, ["controller", "run", "-i", "cm12345-1", "name"], ) + bc.add_callback.assert_called_once() bc.run_task.assert_called_once_with( TaskRequest(name="name", params={}, instrument_session="cm12345-1"), - on_event=mock.ANY, ) assert res.exit_code == 0 assert res.stdout == message diff --git a/tests/unit_tests/service/test_authentication.py b/tests/unit_tests/service/test_authentication.py index 01bc426e2..e3ee86f9d 100644 --- a/tests/unit_tests/service/test_authentication.py +++ b/tests/unit_tests/service/test_authentication.py @@ -189,18 +189,25 @@ def test_tiled_auth_sync_auth_flow(): @pytest.mark.parametrize( - "header,token", + "header,cookie,token", [ - (None, None), - ("ApiKey foobar", None), - ("Bearer foobar", "foobar"), - ("Bearer with_whitespace ", "with_whitespace"), - ("Bearerfoobar", None), + (None, None, None), + ("", None, None), + ("ApiKey foobar", None, None), + ("Bearer foobar", None, "foobar"), + ("Bearer with_whitespace ", None, "with_whitespace"), + ("Bearerfoobar", None, None), + (None, "Bearer foobar", "foobar"), + ("", "Bearer foo", "foo"), + ("Bearer foo", "bearer bar", "foo"), ], ) -def test_unchecked_bearer_token(header: str | None, token: str | None): +def test_unchecked_bearer_token( + header: str | None, cookie: str | None, token: str | None +): req = Mock() req.headers.get.side_effect = lambda key: header if key == "Authorization" else None + req.cookies.get.side_effect = lambda key: cookie if key == "Authorization" else None assert unchecked_bearer_token(req) == token diff --git a/tests/unit_tests/service/test_protocol.py b/tests/unit_tests/service/test_protocol.py new file mode 100644 index 000000000..ef95bb8ce --- /dev/null +++ b/tests/unit_tests/service/test_protocol.py @@ -0,0 +1,70 @@ +from typing import Any + +import pytest + +from blueapi.service.model import TaskRequest +from blueapi.service.protocol import ( + Abort, + ArgumentError, + ControlRequest, + ControlResponse, + InvalidArgs, + Pause, + Resume, + Submit, +) + + +@pytest.mark.parametrize( + "src,res", + [ + ( + """{ + "kind": "submit", + "task": { + "name": "foo", + "instrument_session": "cm12345-1" + } + }""", + Submit( + task=TaskRequest(name="foo", params={}, instrument_session="cm12345-1") + ), + ), + ('{"kind": "pause"}', Pause()), + ('{"kind": "resume"}', Resume()), + ('{"kind": "abort"}', Abort()), + ], +) +def test_request_deserialization(src: str, res: Any): + req = ControlRequest.validate_json(src) + assert req == res + + +@pytest.mark.parametrize( + "src,res", + [ + ( + """{ + "kind": "invalid_args", + "errors":[{ + "loc":["body","params","spec"], + "msg":"error_message", + "type":"error_type", + "input":"original input" + }]}""", + InvalidArgs( + errors=[ + ArgumentError( + loc=["body", "params", "spec"], + msg="error_message", + type="error_type", + input="original input", + ) + ] + ), + ), + ], +) +def test_response_deserialization(src: str, res: Any): + req = ControlResponse.validate_json(src) + assert req == res