Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
6077054
wip
abbiemery Feb 24, 2026
cdf023a
client wip
abbiemery Feb 24, 2026
a176322
use normal iter
abbiemery Feb 24, 2026
477d81a
close ws
abbiemery Feb 24, 2026
afc87a8
add some trys
abbiemery Feb 24, 2026
219b58a
unpipe
abbiemery Feb 24, 2026
89a8c77
Move websocket handling into BlueapiRestClient
tpoliaw Mar 3, 2026
0643d84
Send all events through websocket
tpoliaw Mar 3, 2026
999784e
Split pipe subscribe handles
tpoliaw Mar 4, 2026
01e5ae6
Re-use run subcommand for websockets
tpoliaw Mar 4, 2026
2a119e3
Raise for connection closing pre plan completed
abbiemery Mar 4, 2026
3b1d1f6
Remove run blocking from cli
abbiemery Mar 4, 2026
98495ca
Catch plan key error in run_plan
abbiemery Mar 4, 2026
4b7817e
Refactor event pipe handling into context manager and iterable
tpoliaw Mar 6, 2026
6c9b021
Testing auth tokens
tpoliaw Mar 12, 2026
e89a46b
Re-use existing auth dependency for websocket endpoint
tpoliaw Mar 12, 2026
154d8ae
Add user auth token in websocket client
tpoliaw Mar 12, 2026
8bd4b65
Read authorization from cookie as well as header
tpoliaw Mar 12, 2026
4408ecd
Add user agent to websocket request
tpoliaw Mar 12, 2026
3049944
Add user agent to all requests
tpoliaw Apr 7, 2026
ff8d916
Use new fedid dependency for user name
tpoliaw Jun 29, 2026
12d9623
Test auth from cookie
tpoliaw Jun 29, 2026
4df7628
Fix CLI event handler test
tpoliaw Jun 29, 2026
9e2658e
Reinstate _valid_return check
tpoliaw Jun 29, 2026
23ed486
Use versioned api for websockets
tpoliaw Jun 29, 2026
431332f
Add type annotation to unpipe
tpoliaw Jun 30, 2026
a25cd56
Move ws endpoint to v2 api
tpoliaw Jun 30, 2026
793182d
Use Depends for header and cookie
tpoliaw Mar 12, 2026
f8db53d
Add sub-protocol to ws communication
tpoliaw Apr 17, 2026
de2fea2
Used configured host for websockets
tpoliaw Apr 17, 2026
d3db28a
Add debug logging of all websocket traffic
tpoliaw Apr 20, 2026
f22ea82
Include connection info in logging
tpoliaw Apr 20, 2026
0fccf70
Split receive logging by type
tpoliaw Apr 20, 2026
4bfad5e
Correct typing in rest run_blocking
tpoliaw Apr 29, 2026
b8bb5f9
Use rstrip instead of removesuffix to remove multiple trailing slashes
tpoliaw Jul 1, 2026
b73fc3e
Redact auth tokens in websocket logging
tpoliaw Jul 1, 2026
8413977
Use send_text instead of send json
tpoliaw Jul 1, 2026
11b43f8
Improve error handling
tpoliaw Jul 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/blueapi/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ def on_event(
@controller.command(name="run")
@click.argument("name", type=str)
@click.argument("parameters", type=ParametersType(), default={}, required=False)
@click.option("--ws", type=bool, is_flag=True, default=False)
@click.option(
"--foreground/--background", "--fg/--bg", type=bool, is_flag=True, default=True
)
Expand Down Expand Up @@ -348,6 +349,7 @@ def run_plan(
name: str,
timeout: float | None,
foreground: bool,
ws: bool,
instrument_session: str,
parameters: TaskParameters,
) -> None:
Expand All @@ -374,7 +376,13 @@ def on_event(event: AnyEvent) -> None:
elif isinstance(event, DataEvent):
callback(event.name, event.doc)

resp = client.run_task(task, on_event=on_event)
client.add_callback(on_event)

if ws:
resp = client.run_blocking(task)
else:
resp = client.run_task(task)

match resp.result:
case TaskResult(result=None, type="NoneType"):
print("Plan succeeded")
Expand Down
20 changes: 20 additions & 0 deletions src/blueapi/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,26 @@ def get_active_task(self) -> WorkerTask:

return self.active_task

@start_as_current_span(TRACER, "request")
def run_blocking(
self, request: TaskRequest, on_event: OnAnyEvent | None = None
) -> TaskStatus:
for event in self._rest.run_blocking(request):
if on_event is not None:
on_event(event)
for cb in self._callbacks.values():
try:
cb(event)
except Exception as e:
log.error(f"Callback ({cb}) failed for event: {event}", exc_info=e)
if isinstance(event, WorkerEvent) and event.is_complete():
if event.task_status is None:
raise BlueskyRemoteControlError(
"Server completed without task status"
)
return event.task_status
raise BlueskyRemoteControlError("Connection closed before plan completed.")

@start_as_current_span(TRACER, "task", "timeout")
def run_task(
self,
Expand Down
73 changes: 67 additions & 6 deletions src/blueapi/client/rest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import logging
from collections.abc import Callable, Mapping
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Literal, TypeVar

import requests
Expand All @@ -10,12 +10,15 @@
get_tracer,
start_as_current_span,
)
from pydantic import BaseModel, TypeAdapter, ValidationError
from pydantic import BaseModel, TypeAdapter, ValidationError, WebsocketUrl
from pydantic_core import PydanticSerializationError
from websockets.exceptions import InvalidStatus
from websockets.sync.client import connect

from blueapi import __version__
from blueapi.client import client
from blueapi.config import RestConfig
from blueapi.core.bluesky_types import DataEvent
from blueapi.service.authentication import JWTAuth, SessionManager
from blueapi.service.model import (
DeviceModel,
Expand All @@ -31,14 +34,24 @@
TasksListResponse,
WorkerTask,
)
from blueapi.service.protocol import (
ControlResponse,
InvalidArgs,
PlanNotFound,
Submit,
Update,
)
from blueapi.worker import TrackableTask, WorkerState
from blueapi.worker.event import ProgressEvent, WorkerEvent

T = TypeVar("T")

TRACER = get_tracer("rest")

LOGGER = logging.getLogger(__name__)

USER_AGENT = f"blueapi cli {__version__}"


class BlueskyRequestError(Exception):
"""An error response from the blueapi server."""
Expand Down Expand Up @@ -86,8 +99,8 @@ def __init__(self, target_type: type) -> None:

class ParameterError(BaseModel):
loc: list[str | int]
msg: str
type: str
msg: str | None
type: str | None
input: Any

def field(self):
Expand Down Expand Up @@ -307,14 +320,15 @@ def _request_and_deserialize(
) -> T:
url = self._config.url.unicode_string().removesuffix("/") + suffix
# Get the trace context to propagate to the REST API
carr = get_context_propagator()
headers = get_context_propagator()
headers["User-Agent"] = USER_AGENT
try:
response = self._pool.request(
method,
url,
json=data,
params=params,
headers=carr,
headers=headers,
auth=JWTAuth(self._session_manager),
)
except requests.exceptions.ConnectionError as ce:
Expand All @@ -340,6 +354,53 @@ def _request_and_deserialize(
)
return deserialized

def run_blocking(
self, req: TaskRequest
) -> Iterable[DataEvent | WorkerEvent | ProgressEvent]:
url = self._ws_address().unicode_string().rstrip("/") + "/api/v2/run_plan"
headers = get_context_propagator()
if self._session_manager:
auth = self._session_manager.get_valid_access_token()
headers["Authorization"] = f"Bearer {auth}"
try:
with connect(
url,
additional_headers=headers,
user_agent_header=USER_AGENT,
) as ws:
ws.send(Submit(task=req).model_dump_json())
for message in ws:
event = ControlResponse.validate_json(message)
match event:
case Update(data=data):
yield data
case InvalidArgs(errors=errors):
raise InvalidParametersError(
[
ParameterError(
loc=e.loc, msg=e.msg, type=e.type, input=e.input
)
for e in errors
]
)
case PlanNotFound(plan_name=name):
raise UnknownPlanError(message=name)
except InvalidStatus as istat:
match istat.response.status_code:
case 401 | 403:
raise UnauthorisedAccessError() from None
print(vars(istat))
return

def _ws_address(self) -> WebsocketUrl:
api = self._config.url
if api.host is None:
raise ValueError("No host configured")
scheme = "ws" if api.scheme == "http" else "wss"
return WebsocketUrl.build(
scheme=scheme, host=api.host, port=api.port, path=api.path
)


# https://github.com/DiamondLightSource/blueapi/issues/1256 - remove before 2.0
def __getattr__(name: str):
Expand Down
15 changes: 9 additions & 6 deletions src/blueapi/service/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import httpx
import jwt
import requests
from fastapi import Depends, HTTPException, Request
from fastapi import Cookie, Depends, Header, HTTPException
from fastapi.requests import HTTPConnection
from fastapi.security.utils import get_authorization_scheme_param
from pydantic import TypeAdapter
from requests.auth import AuthBase
Expand Down Expand Up @@ -278,14 +279,16 @@ def sync_auth_flow(self, request):
yield request


def unchecked_bearer_token(req: Request) -> str | None:
def unchecked_bearer_token(
auth_header: str | None = Header(alias="Authorization", default=None),
auth_cookie: str | None = Cookie(alias="Authorization", default=None),
) -> str | None:
"""Get bearer token value from authorization header"""
# This is an abridged version of the same feature of
# OAuth2AuthorizationCodeBearer from fastapi. Replicating here prevents
# passing unused configuration and means the schema does not include auth
# details for servers that do not support it.
auth = req.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(auth)
scheme, param = get_authorization_scheme_param(auth_header or auth_cookie)
if scheme.casefold() != "bearer":
return None
return param.strip()
Expand All @@ -303,7 +306,7 @@ def build_access_token_check(config: OIDCConfig):
"""
jwkclient = jwt.PyJWKClient(config.jwks_uri)

def validate_bearer_token(request: Request, token: UncheckedBearerToken):
def validate_bearer_token(request: HTTPConnection, token: UncheckedBearerToken):
"""Check that a bearer token is valid and inject into request state"""
if not token:
raise HTTPException(
Expand All @@ -326,7 +329,7 @@ def validate_bearer_token(request: Request, token: UncheckedBearerToken):
return validate_bearer_token


def access_token(request: Request) -> Mapping[str, Any] | None:
def access_token(request: HTTPConnection) -> Mapping[str, Any] | None:
"""Get the decoded and verified access token of the user making the request"""
return getattr(request.state, "decoded_access_token", None)

Expand Down
42 changes: 40 additions & 2 deletions src/blueapi/service/interface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import logging
from collections.abc import Mapping
from dataclasses import dataclass
from functools import cache
from multiprocessing.connection import Connection
from typing import Any

from bluesky.callbacks.tiled_writer import TiledWriter
Expand All @@ -9,6 +12,7 @@

from blueapi.cli.scratch import get_python_environment
from blueapi.config import ApplicationConfig, OIDCConfig, ServiceAccount, StompConfig
from blueapi.core.bluesky_types import DataEvent
from blueapi.core.context import BlueskyContext
from blueapi.core.event import EventStream
from blueapi.log import set_up_logging
Expand All @@ -22,14 +26,14 @@
WorkerTask,
)
from blueapi.utils.serialization import access_blob
from blueapi.worker.event import TaskStatusEnum, WorkerEvent, WorkerState
from blueapi.worker.event import ProgressEvent, TaskStatusEnum, WorkerEvent, WorkerState
from blueapi.worker.task import Task
from blueapi.worker.task_worker import TaskWorker, TrackableTask

"""This module provides interface between web application and underlying Bluesky
context and worker"""


LOGGER = logging.getLogger(__name__)
_CONFIG: ApplicationConfig = ApplicationConfig()


Expand Down Expand Up @@ -286,3 +290,37 @@ def get_python_env(
"""Retrieve information about the Python environment"""
scratch = config().scratch
return get_python_environment(config=scratch, name=name, source=source)


@dataclass
class SubHandles:
worker: int
progress: int
data: int


def pipe_events(tx: Connection) -> SubHandles:
tw = worker()

def handler(
worker_event: WorkerEvent | DataEvent | ProgressEvent,
_cor_id: str | None,
) -> None:

try:
tx.send(worker_event)
except BrokenPipeError:
LOGGER.warning("Sending event to broken pipe")
pass
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
Comment thread
tpoliaw marked this conversation as resolved.
Dismissed

w = tw.worker_events.subscribe(handler)
d = tw.data_events.subscribe(handler)
p = tw.progress_events.subscribe(handler)
return SubHandles(worker=w, data=d, progress=p)


def unpipe_events(hnd: SubHandles) -> None:
tw = worker()
tw.worker_events.unsubscribe(hnd.worker)
tw.data_events.unsubscribe(hnd.data)
tw.progress_events.unsubscribe(hnd.progress)
Loading
Loading