From c474afd59bf019e5f61c2fca963cb0bad0f061c3 Mon Sep 17 00:00:00 2001 From: David Pastl Date: Tue, 27 Jan 2026 12:01:34 -0600 Subject: [PATCH 01/13] Updating authenticators from latest in Tiled --- bluesky_httpserver/app.py | 9 +- bluesky_httpserver/authentication.py | 17 +- bluesky_httpserver/authentication/__init__.py | 11 + .../authentication/authenticator_base.py | 39 ++ bluesky_httpserver/authenticators.py | 462 ++++++++++++------ 5 files changed, 365 insertions(+), 173 deletions(-) create mode 100644 bluesky_httpserver/authentication/__init__.py create mode 100644 bluesky_httpserver/authentication/authenticator_base.py diff --git a/bluesky_httpserver/app.py b/bluesky_httpserver/app.py index f09acb3..9a8420a 100644 --- a/bluesky_httpserver/app.py +++ b/bluesky_httpserver/app.py @@ -15,7 +15,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.utils import get_openapi -from .authentication import Mode +from .authentication import ExternalAuthenticator, InternalAuthenticator from .console_output import CollectPublishedConsoleOutput, ConsoleOutputStream, SystemInfoStream from .core import PatchedStreamingResponse from .database.core import purge_expired @@ -179,12 +179,11 @@ def build_app(authentication=None, api_access=None, resource_access=None, server for spec in authentication["providers"]: provider = spec["provider"] authenticator = spec["authenticator"] - mode = authenticator.mode - if mode == Mode.password: + if isinstance(authenticator, InternalAuthenticator): authentication_router.post(f"/provider/{provider}/token")( build_handle_credentials_route(authenticator, provider) ) - elif mode == Mode.external: + elif isinstance(authenticator, ExternalAuthenticator): authentication_router.get(f"/provider/{provider}/code")( build_auth_code_route(authenticator, provider) ) @@ -192,7 +191,7 @@ def build_app(authentication=None, api_access=None, resource_access=None, server build_auth_code_route(authenticator, provider) ) else: - raise ValueError(f"unknown authentication mode {mode}") + raise ValueError(f"unknown authenticator type {type(authenticator)}") for custom_router in getattr(authenticator, "include_routers", []): authentication_router.include_router(custom_router, prefix=f"/provider/{provider}") diff --git a/bluesky_httpserver/authentication.py b/bluesky_httpserver/authentication.py index 9772974..a30db6a 100644 --- a/bluesky_httpserver/authentication.py +++ b/bluesky_httpserver/authentication.py @@ -31,6 +31,11 @@ from pydantic_settings import BaseSettings from . import schemas +from .authentication.authenticator_base import ( + ExternalAuthenticator, + InternalAuthenticator, + UserSessionState, +) from .authorization._defaults import _DEFAULT_ANONYMOUS_PROVIDER_NAME from .core import json_or_msgpack from .database import orm @@ -54,12 +59,6 @@ def utcnow(): "UTC now with second resolution" return datetime.utcnow().replace(microsecond=0) - -class Mode(enum.Enum): - password = "password" - external = "external" - - class Token(BaseModel): access_token: str token_type: str @@ -455,7 +454,8 @@ async def auth_code( api_access_manager=Depends(get_api_access_manager), ): request.state.endpoint = "auth" - username = await authenticator.authenticate(request) + user_session_state = await authenticator.authenticate(request) + username = user_session_state.user_name if user_session_state else None if username and api_access_manager.is_user_known(username): scopes = api_access_manager.get_user_scopes(username) @@ -484,7 +484,8 @@ async def handle_credentials( api_access_manager=Depends(get_api_access_manager), ): request.state.endpoint = "auth" - username = await authenticator.authenticate(username=form_data.username, password=form_data.password) + user_session_state = await authenticator.authenticate(username=form_data.username, password=form_data.password) + username = user_session_state.user_name if user_session_state else None err_msg = None if not username: diff --git a/bluesky_httpserver/authentication/__init__.py b/bluesky_httpserver/authentication/__init__.py new file mode 100644 index 0000000..58c758f --- /dev/null +++ b/bluesky_httpserver/authentication/__init__.py @@ -0,0 +1,11 @@ +from .authenticator_base import ( + ExternalAuthenticator, + InternalAuthenticator, + UserSessionState, +) + +__all__ = [ + "ExternalAuthenticator", + "InternalAuthenticator", + "UserSessionState", +] diff --git a/bluesky_httpserver/authentication/authenticator_base.py b/bluesky_httpserver/authentication/authenticator_base.py new file mode 100644 index 0000000..7a2cff3 --- /dev/null +++ b/bluesky_httpserver/authentication/authenticator_base.py @@ -0,0 +1,39 @@ +from abc import ABC +from dataclasses import dataclass +from typing import Optional + +from fastapi import Request + + +@dataclass +class UserSessionState: + """Data transfer class to communicate custom session state information.""" + + user_name: str + state: dict = None + + +class InternalAuthenticator(ABC): + """ + Base class for authenticators that use username/password credentials. + + Subclasses must implement the authenticate method which takes a username + and password and returns a UserSessionState on success or None on failure. + """ + + async def authenticate( + self, username: str, password: str + ) -> Optional[UserSessionState]: + raise NotImplementedError + + +class ExternalAuthenticator(ABC): + """ + Base class for authenticators that use external identity providers. + + Subclasses must implement the authenticate method which takes a FastAPI + Request object and returns a UserSessionState on success or None on failure. + """ + + async def authenticate(self, request: Request) -> Optional[UserSessionState]: + raise NotImplementedError diff --git a/bluesky_httpserver/authenticators.py b/bluesky_httpserver/authenticators.py index 61c2da4..3b439f4 100644 --- a/bluesky_httpserver/authenticators.py +++ b/bluesky_httpserver/authenticators.py @@ -1,21 +1,32 @@ import asyncio +import base64 import functools import logging import re import secrets from collections.abc import Iterable +from datetime import timedelta +from typing import Any, List, Mapping, Optional, cast +import httpx +from cachetools import TTLCache, cached from fastapi import APIRouter, Request -from jose import JWTError, jwk, jwt +from fastapi.security import OAuth2, OAuth2AuthorizationCodeBearer +from jose import JWTError, jwt +from pydantic import Secret from starlette.responses import RedirectResponse -from .authentication import Mode -from .utils import modules_available +from .authentication import ( + ExternalAuthenticator, + InternalAuthenticator, + UserSessionState, +) +from .utils import get_root_url, modules_available logger = logging.getLogger(__name__) -class DummyAuthenticator: +class DummyAuthenticator(InternalAuthenticator): """ For test and demo purposes only! @@ -23,26 +34,20 @@ class DummyAuthenticator: """ - mode = Mode.password + def __init__(self, confirmation_message: str = ""): + self.confirmation_message = confirmation_message - async def authenticate(self, username: str, password: str): - return username + async def authenticate(self, username: str, password: str) -> UserSessionState: + return UserSessionState(username, {}) -class DictionaryAuthenticator: +class DictionaryAuthenticator(InternalAuthenticator): """ For test and demo purposes only! Check passwords from a dictionary of usernames mapped to passwords. - - Parameters - ---------- - - users_to_passwords: dict(str, str) - Mapping of usernames to passwords. """ - mode = Mode.password configuration_schema = """ $schema": http://json-schema.org/draft-07/schema# type: object @@ -50,25 +55,32 @@ class DictionaryAuthenticator: properties: users_to_password: type: object - description: | - Mapping usernames to password. Environment variable expansion should be - used to avoid placing passwords directly in configuration. + description: | + Mapping usernames to password. Environment variable expansion should be + used to avoid placing passwords directly in configuration. + confirmation_message: + type: string + description: May be displayed by client after successful login. """ - def __init__(self, users_to_passwords): + def __init__( + self, users_to_passwords: Mapping[str, str], confirmation_message: str = "" + ): self._users_to_passwords = users_to_passwords + self.confirmation_message = confirmation_message - async def authenticate(self, username: str, password: str): + async def authenticate( + self, username: str, password: str + ) -> Optional[UserSessionState]: true_password = self._users_to_passwords.get(username) if not true_password: # Username is not valid. - return + return None if secrets.compare_digest(true_password, password): - return username + return UserSessionState(username, {}) -class PAMAuthenticator: - mode = Mode.password +class PAMAuthenticator(InternalAuthenticator): configuration_schema = """ $schema": http://json-schema.org/draft-07/schema# type: object @@ -77,90 +89,149 @@ class PAMAuthenticator: service: type: string description: PAM service. Default is 'login'. + confirmation_message: + type: string + description: May be displayed by client after successful login. """ - def __init__(self, service="login"): + def __init__(self, service: str = "login", confirmation_message: str = ""): if not modules_available("pamela"): - raise ModuleNotFoundError("This PAMAuthenticator requires the module 'pamela' to be installed.") + raise ModuleNotFoundError( + "This PAMAuthenticator requires the module 'pamela' to be installed." + ) self.service = service + self.confirmation_message = confirmation_message # TODO Try to open a PAM session. - async def authenticate(self, username: str, password: str): + async def authenticate( + self, username: str, password: str + ) -> Optional[UserSessionState]: import pamela try: pamela.authenticate(username, password, service=self.service) + return UserSessionState(username, {}) except pamela.PAMError: # Authentication failed. - return - else: - return username + return None -class OIDCAuthenticator: - mode = Mode.external +class OIDCAuthenticator(ExternalAuthenticator): configuration_schema = """ $schema": http://json-schema.org/draft-07/schema# type: object additionalProperties: false properties: + audience: + type: string client_id: type: string client_secret: type: string - redirect_uri: + well_known_uri: type: string - token_uri: + confirmation_message: type: string - authorization_endpoint: + redirect_on_success: + type: string + redirect_on_failure: type: string - public_keys: - type: array - item: - type: object - properties: - - alg: - type: string - - e - type: string - - kid - type: string - - kty - type: string - - n - type: string - - use - type: string - required: - - alg - - e - - kid - - kty - - n - - use """ def __init__( self, - client_id, - client_secret, - redirect_uri, - public_keys, - token_uri, - authorization_endpoint, - confirmation_message, + audience: str, + client_id: str, + client_secret: str, + well_known_uri: str, + confirmation_message: str = "", + redirect_on_success: Optional[str] = None, + redirect_on_failure: Optional[str] = None, ): - self.client_id = client_id - self.client_secret = client_secret + self._audience = audience + self._client_id = client_id + self._client_secret = Secret(client_secret) + self._well_known_url = well_known_uri self.confirmation_message = confirmation_message - self.redirect_uri = redirect_uri - self.public_keys = public_keys - self.token_uri = token_uri - self.authorization_endpoint = authorization_endpoint.format(client_id=client_id, redirect_uri=redirect_uri) - - async def authenticate(self, request): - code = request.query_params["code"] - response = await exchange_code(self.token_uri, code, self.client_id, self.client_secret, self.redirect_uri) + self.redirect_on_success = redirect_on_success + self.redirect_on_failure = redirect_on_failure + + @functools.cached_property + def _config_from_oidc_url(self) -> dict[str, Any]: + response: httpx.Response = httpx.get(self._well_known_url) + response.raise_for_status() + return response.json() + + @functools.cached_property + def client_id(self) -> str: + return self._client_id + + @functools.cached_property + def id_token_signing_alg_values_supported(self) -> list[str]: + return cast( + list[str], + self._config_from_oidc_url.get("id_token_signing_alg_values_supported"), + ) + + @functools.cached_property + def issuer(self) -> str: + return cast(str, self._config_from_oidc_url.get("issuer")) + + @functools.cached_property + def jwks_uri(self) -> str: + return cast(str, self._config_from_oidc_url.get("jwks_uri")) + + @functools.cached_property + def token_endpoint(self) -> str: + return cast(str, self._config_from_oidc_url.get("token_endpoint")) + + @functools.cached_property + def authorization_endpoint(self) -> httpx.URL: + return httpx.URL( + cast(str, self._config_from_oidc_url.get("authorization_endpoint")) + ) + + @functools.cached_property + def device_authorization_endpoint(self) -> str: + return cast( + str, self._config_from_oidc_url.get("device_authorization_endpoint") + ) + + @functools.cached_property + def end_session_endpoint(self) -> str: + return cast(str, self._config_from_oidc_url.get("end_session_endpoint")) + + @cached(TTLCache(maxsize=1, ttl=timedelta(days=7).total_seconds())) + def keys(self) -> List[str]: + return httpx.get(self.jwks_uri).raise_for_status().json().get("keys", []) + + def decode_token(self, token: str) -> dict[str, Any]: + return jwt.decode( + token, + key=self.keys(), + algorithms=self.id_token_signing_alg_values_supported, + audience=self._audience, + issuer=self.issuer, + ) + + async def authenticate(self, request: Request) -> Optional[UserSessionState]: + code = request.query_params.get("code") + if not code: + logger.warning( + "Authentication failed: No authorization code parameter provided." + ) + return None + # A proxy in the middle may make the request into something like + # 'http://localhost:8000/...' so we fix the first part but keep + # the original URI path. + redirect_uri = f"{get_root_url(request)}{request.url.path}" + response = await exchange_code( + self.token_endpoint, + code, + self._client_id, + self._client_secret.get_secret_value(), + redirect_uri, + ) response_body = response.json() if response.is_error: logger.error("Authentication error: %r", response_body) @@ -168,63 +239,84 @@ async def authenticate(self, request): response_body = response.json() id_token = response_body["id_token"] access_token = response_body["access_token"] - # Match the kid in id_token to a key in the list of public_keys. - key = find_key(id_token, self.public_keys) try: - verified_body = jwt.decode(id_token, key, access_token=access_token, audience=self.client_id) + verified_body = self.decode_token(access_token) except JWTError: logger.exception( "Authentication error. Unverified token: %r", jwt.get_unverified_claims(id_token), ) return None - return verified_body["sub"] + return UserSessionState(verified_body["sub"], {}) -class KeyNotFoundError(Exception): - pass - - -def find_key(token, keys): - """ - Find a key from the configured keys based on the kid claim of the token - - Parameters - ---------- - token : token to search for the kid from - keys: list of keys - - Raises - ------ - KeyNotFoundError: - returned if the token does not have a kid claim - - Returns - ------ - key: found key object - """ +class ProxiedOIDCAuthenticator(OIDCAuthenticator): + configuration_schema = """ +$schema": http://json-schema.org/draft-07/schema# +type: object +additionalProperties: false +properties: + audience: + type: string + client_id: + type: string + well_known_uri: + type: string + scopes: + type: array + items: + type: string + description: | + Optional list of OAuth2 scopes to request. If provided, authorization + should be enforced by an external policy agent (for example ExternalPolicyDecisionPoint) + rather than by this authenticator. + device_flow_client_id: + type: string + confirmation_message: + type: string +""" - unverified = jwt.get_unverified_header(token) - kid = unverified.get("kid") - if not kid: - raise KeyNotFoundError("No 'kid' in token") + def __init__( + self, + audience: str, + client_id: str, + well_known_uri: str, + device_flow_client_id: str, + scopes: Optional[List[str]] = None, + confirmation_message: str = "", + ): + super().__init__( + audience=audience, + client_id=client_id, + client_secret="", + well_known_uri=well_known_uri, + confirmation_message=confirmation_message, + ) + self.scopes = scopes + self.device_flow_client_id = device_flow_client_id + self._oidc_bearer = OAuth2AuthorizationCodeBearer( + authorizationUrl=str(self.authorization_endpoint), + tokenUrl=self.token_endpoint, + ) - for key in keys: - if key["kid"] == kid: - return jwk.construct(key) - return KeyNotFoundError(f"Token specifies {kid} but we have {[k['kid'] for k in keys]}") + @property + def oauth2_schema(self) -> OAuth2: + return self._oidc_bearer -async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect_uri): +async def exchange_code( + token_uri: str, + auth_code: str, + client_id: str, + client_secret: str, + redirect_uri: str, +) -> httpx.Response: """Method that talks to an IdP to exchange a code for an access_token and/or id_token Args: token_url ([type]): [description] auth_code ([type]): [description] """ - if not modules_available("httpx"): - raise ModuleNotFoundError("This authenticator requires 'httpx'. (pip install httpx)") - import httpx - + auth_value = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode() response = httpx.post( url=token_uri, data={ @@ -234,18 +326,18 @@ async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect "code": auth_code, "client_secret": client_secret, }, + headers={"Authorization": f"Basic {auth_value}"}, ) return response -class SAMLAuthenticator: - mode = Mode.external +class SAMLAuthenticator(ExternalAuthenticator): def __init__( self, saml_settings, # See EXAMPLE_SAML_SETTINGS below. - attribute_name, # which SAML attribute to use as 'id' for Idenity - confirmation_message=None, + attribute_name: str, # which SAML attribute to use as 'id' for Identity + confirmation_message: str = "", ): self.saml_settings = saml_settings self.attribute_name = attribute_name @@ -258,30 +350,26 @@ def __init__( # The PyPI package name is 'python3-saml' # but it imports as 'onelogin'. # https://github.com/onelogin/python3-saml - raise ModuleNotFoundError("This SAMLAuthenticator requires 'python3-saml' to be installed.") + raise ModuleNotFoundError( + "This SAMLAuthenticator requires 'python3-saml' to be installed." + ) from onelogin.saml2.auth import OneLogin_Saml2_Auth @router.get("/login") - async def saml_login(request: Request): + async def saml_login(request: Request) -> RedirectResponse: req = await prepare_saml_from_fastapi_request(request) auth = OneLogin_Saml2_Auth(req, self.saml_settings) - # saml_settings = auth.get_settings() - # metadata = saml_settings.get_sp_metadata() - # errors = saml_settings.validate_metadata(metadata) - # if len(errors) == 0: - # print(metadata) - # else: - # print("Error found on Metadata: %s" % (', '.join(errors))) callback_url = auth.login() - response = RedirectResponse(url=callback_url) - return response + return RedirectResponse(url=callback_url) self.include_routers = [router] - async def authenticate(self, request): + async def authenticate(self, request: Request) -> Optional[UserSessionState]: if not modules_available("onelogin"): - raise ModuleNotFoundError("This SAMLAuthenticator requires the module 'oneline' to be installed.") + raise ModuleNotFoundError( + "This SAMLAuthenticator requires the module 'oneline' to be installed." + ) from onelogin.saml2.auth import OneLogin_Saml2_Auth req = await prepare_saml_from_fastapi_request(request, True) @@ -290,26 +378,27 @@ async def authenticate(self, request): errors = auth.get_errors() # This method receives an array with the errors if errors: raise Exception( - "Error when processing SAML Response: %s %s" % (", ".join(errors), auth.get_last_error_reason()) + "Error when processing SAML Response: %s %s" + % (", ".join(errors), auth.get_last_error_reason()) ) if auth.is_authenticated(): # Return a string that the Identity can use as id. attribute_as_list = auth.get_attributes()[self.attribute_name] # Confused in what situation this would have more than one item.... assert len(attribute_as_list) == 1 - return attribute_as_list[0] + return UserSessionState(attribute_as_list[0], {}) else: return None -async def prepare_saml_from_fastapi_request(request, debug=False): +async def prepare_saml_from_fastapi_request(request: Request) -> Mapping[str, str]: form_data = await request.form() rv = { "http_host": request.client.host, "server_port": request.url.port, "script_name": request.url.path, "post_data": {}, - "get_data": {}, + "get_data": {} # Advanced request options # "https": "", # "request_uri": "", @@ -328,7 +417,7 @@ async def prepare_saml_from_fastapi_request(request, debug=False): return rv -class LDAPAuthenticator: +class LDAPAuthenticator(InternalAuthenticator): """ LDAP authenticator. The authenticator code is based on https://github.com/jupyterhub/ldapauthenticator @@ -472,6 +561,8 @@ class LDAPAuthenticator: This can be useful in an heterogeneous environment, when supplying a UNIX username to authenticate against AD. + confirmation_message: str + May be displayed by client after successful login. Examples -------- @@ -510,8 +601,6 @@ class LDAPAuthenticator: id: user02 """ - mode = Mode.password - def __init__( self, server_address, @@ -536,6 +625,7 @@ def __init__( attributes=None, auth_state_attributes=None, use_lookup_dn_username=True, + confirmation_message="", ): self.use_ssl = use_ssl self.use_tls = use_tls @@ -554,7 +644,9 @@ def __init__( self.escape_userdn = escape_userdn self.search_filter = search_filter self.attributes = attributes if attributes else [] - self.auth_state_attributes = auth_state_attributes if auth_state_attributes else [] + self.auth_state_attributes = ( + auth_state_attributes if auth_state_attributes else [] + ) self.use_lookup_dn_username = use_lookup_dn_username if isinstance(server_address, str): @@ -567,10 +659,15 @@ def __init__( f"type(server_address)={type(server_address)}" ) if not server_address_list: - raise ValueError("No servers are specified: 'server_address' is an empty list") + raise ValueError( + "No servers are specified: 'server_address' is an empty list" + ) self.server_address_list = server_address_list - self.server_port = server_port if server_port is not None else self._server_port_default() + self.server_port = ( + server_port if server_port is not None else self._server_port_default() + ) + self.confirmation_message = confirmation_message def _server_port_default(self): if self.use_ssl: @@ -623,8 +720,15 @@ async def resolve_username(self, username_supplied_by_user): response = conn.response if len(response) == 0 or "attributes" not in response[0].keys(): - msg = "No entry found for user '{username}' " "when looking up attribute '{attribute}'" - logger.warning(msg.format(username=username_supplied_by_user, attribute=self.user_attribute)) + msg = ( + "No entry found for user '{username}' " + "when looking up attribute '{attribute}'" + ) + logger.warning( + msg.format( + username=username_supplied_by_user, attribute=self.user_attribute + ) + ) return (None, None) user_dn = response[0]["attributes"][self.lookup_dn_user_dn_attribute] @@ -655,7 +759,7 @@ async def resolve_username(self, username_supplied_by_user): def get_connection(self, userdn, password): import ldap3 - # NOTE: setting 'acitve=False' essentially disables exclusion of inactive servers from the pool. + # NOTE: setting 'active=False' essentially disables exclusion of inactive servers from the pool. # It probably does not matter if the pool contains only one server, but it could have implications # when there are multiple servers in the pool. It is not clear what those implications are. # But using the default 'activate=True' results in the thread being blocked indefinitely @@ -675,14 +779,23 @@ def get_connection(self, userdn, password): server_port = self.server_port server = ldap3.Server( - server_addr, port=server_port, use_ssl=self.use_ssl, connect_timeout=self.connect_timeout + server_addr, + port=server_port, + use_ssl=self.use_ssl, + connect_timeout=self.connect_timeout, ) server_pool.add(server) - auto_bind_no_ssl = ldap3.AUTO_BIND_TLS_BEFORE_BIND if self.use_tls else ldap3.AUTO_BIND_NO_TLS + auto_bind_no_ssl = ( + ldap3.AUTO_BIND_TLS_BEFORE_BIND if self.use_tls else ldap3.AUTO_BIND_NO_TLS + ) auto_bind = ldap3.AUTO_BIND_NO_TLS if self.use_ssl else auto_bind_no_ssl conn = ldap3.Connection( - server_pool, user=userdn, password=password, auto_bind=auto_bind, receive_timeout=self.receive_timeout + server_pool, + user=userdn, + password=password, + auto_bind=auto_bind, + receive_timeout=self.receive_timeout, ) return conn @@ -690,14 +803,19 @@ async def get_user_attributes(self, conn, userdn): attrs = {} if self.auth_state_attributes: search_func = functools.partial( - conn.search, userdn, "(objectClass=*)", attributes=self.auth_state_attributes + conn.search, + userdn, + "(objectClass=*)", + attributes=self.auth_state_attributes, ) found = await asyncio.get_running_loop().run_in_executor(None, search_func) if found: attrs = conn.entries[0].entry_attributes_as_dict return attrs - async def authenticate(self, username: str, password: str): + async def authenticate( + self, username: str, password: str + ) -> Optional[UserSessionState]: import ldap3 username_saved = username # Save the user name passed as a parameter @@ -723,7 +841,9 @@ async def authenticate(self, username: str, password: str): # sanity check if not self.lookup_dn and not bind_dn_template: - logger.warning("Login not allowed, please configure 'lookup_dn' or 'bind_dn_template'.") + logger.warning( + "Login not allowed, please configure 'lookup_dn' or 'bind_dn_template'." + ) return None if self.lookup_dn: @@ -761,7 +881,9 @@ async def authenticate(self, username: str, password: str): if conn.bound: is_bound = True else: - is_bound = await asyncio.get_running_loop().run_in_executor(None, conn.bind) + is_bound = await asyncio.get_running_loop().run_in_executor( + None, conn.bind + ) msg = msg.format(username=username, userdn=userdn, is_bound=is_bound) logger.debug(msg) @@ -774,7 +896,9 @@ async def authenticate(self, username: str, password: str): return None if self.search_filter: - search_filter = self.search_filter.format(userattr=self.user_attribute, username=username) + search_filter = self.search_filter.format( + userattr=self.user_attribute, username=username + ) search_func = functools.partial( conn.search, @@ -788,18 +912,33 @@ async def authenticate(self, username: str, password: str): n_users = len(conn.response) if n_users == 0: msg = "User with '{userattr}={username}' not found in directory" - logger.warning(msg.format(userattr=self.user_attribute, username=username)) + logger.warning( + msg.format(userattr=self.user_attribute, username=username) + ) return None if n_users > 1: - msg = "Duplicate users found! " "{n_users} users found with '{userattr}={username}'" - logger.warning(msg.format(userattr=self.user_attribute, username=username, n_users=n_users)) + msg = ( + "Duplicate users found! " + "{n_users} users found with '{userattr}={username}'" + ) + logger.warning( + msg.format( + userattr=self.user_attribute, username=username, n_users=n_users + ) + ) return None if self.allowed_groups: logger.debug("username:%s Using dn %s", username, userdn) found = False for group in self.allowed_groups: - group_filter = "(|" "(member={userdn})" "(uniqueMember={userdn})" "(memberUid={uid})" ")" + group_filter = ( + "(|" + "(member={userdn})" + "(uniqueMember={userdn})" + "(memberUid={uid})" + ")" + ) group_filter = group_filter.format(userdn=userdn, uid=username) group_attributes = ["member", "uniqueMember", "memberUid"] @@ -810,7 +949,9 @@ async def authenticate(self, username: str, password: str): search_filter=group_filter, attributes=group_attributes, ) - found = await asyncio.get_running_loop().run_in_executor(None, search_func) + found = await asyncio.get_running_loop().run_in_executor( + None, search_func + ) if found: break @@ -826,5 +967,6 @@ async def authenticate(self, username: str, password: str): user_info = await self.get_user_attributes(conn, userdn) if user_info: logger.debug("username:%s attributes:%s", username, user_info) - return {"name": username, "auth_state": user_info} - return username + # this path might never have been worked out...is it ever hit? + return UserSessionState(username, user_info) + return UserSessionState(username, {}) From f54ef363ba1b9a04dad06b1fe7cb76ab184a5078 Mon Sep 17 00:00:00 2001 From: Dmitri Gavrilov Date: Thu, 5 Feb 2026 11:20:34 -0500 Subject: [PATCH 02/13] TST: fix unit tests --- .../{authentication.py => _authentication.py} | 11 ++++------- bluesky_httpserver/authentication/__init__.py | 14 ++++++++++++++ bluesky_httpserver/tests/test_authenticators.py | 6 +++--- requirements.txt | 1 + 4 files changed, 22 insertions(+), 10 deletions(-) rename bluesky_httpserver/{authentication.py => _authentication.py} (99%) diff --git a/bluesky_httpserver/authentication.py b/bluesky_httpserver/_authentication.py similarity index 99% rename from bluesky_httpserver/authentication.py rename to bluesky_httpserver/_authentication.py index a30db6a..0375794 100644 --- a/bluesky_httpserver/authentication.py +++ b/bluesky_httpserver/_authentication.py @@ -1,5 +1,4 @@ import asyncio -import enum import hashlib import secrets import uuid as uuid_module @@ -31,11 +30,6 @@ from pydantic_settings import BaseSettings from . import schemas -from .authentication.authenticator_base import ( - ExternalAuthenticator, - InternalAuthenticator, - UserSessionState, -) from .authorization._defaults import _DEFAULT_ANONYMOUS_PROVIDER_NAME from .core import json_or_msgpack from .database import orm @@ -59,6 +53,7 @@ def utcnow(): "UTC now with second resolution" return datetime.utcnow().replace(microsecond=0) + class Token(BaseModel): access_token: str token_type: str @@ -484,7 +479,9 @@ async def handle_credentials( api_access_manager=Depends(get_api_access_manager), ): request.state.endpoint = "auth" - user_session_state = await authenticator.authenticate(username=form_data.username, password=form_data.password) + user_session_state = await authenticator.authenticate( + username=form_data.username, password=form_data.password + ) username = user_session_state.user_name if user_session_state else None err_msg = None diff --git a/bluesky_httpserver/authentication/__init__.py b/bluesky_httpserver/authentication/__init__.py index 58c758f..fc35cdd 100644 --- a/bluesky_httpserver/authentication/__init__.py +++ b/bluesky_httpserver/authentication/__init__.py @@ -1,3 +1,11 @@ +from .._authentication import ( + base_authentication_router, + build_auth_code_route, + build_handle_credentials_route, + get_current_principal, + get_current_principal_websocket, + oauth2_scheme, +) from .authenticator_base import ( ExternalAuthenticator, InternalAuthenticator, @@ -8,4 +16,10 @@ "ExternalAuthenticator", "InternalAuthenticator", "UserSessionState", + "get_current_principal", + "get_current_principal_websocket", + "base_authentication_router", + "build_auth_code_route", + "build_handle_credentials_route", + "oauth2_scheme", ] diff --git a/bluesky_httpserver/tests/test_authenticators.py b/bluesky_httpserver/tests/test_authenticators.py index cc2984c..183ce75 100644 --- a/bluesky_httpserver/tests/test_authenticators.py +++ b/bluesky_httpserver/tests/test_authenticators.py @@ -3,7 +3,7 @@ import pytest # fmt: off -from ..authenticators import LDAPAuthenticator +from ..authenticators import LDAPAuthenticator, UserSessionState @pytest.mark.parametrize("ldap_server_address, ldap_server_port", [ @@ -35,8 +35,8 @@ def test_LDAPAuthenticator_01(use_tls, use_ssl, ldap_server_address, ldap_server ) async def testing(): - assert await authenticator.authenticate("user01", "password1") == "user01" - assert await authenticator.authenticate("user02", "password2") == "user02" + assert await authenticator.authenticate("user01", "password1") == UserSessionState("user01", {}) + assert await authenticator.authenticate("user02", "password2") == UserSessionState("user02", {}) assert await authenticator.authenticate("user02a", "password2") is None assert await authenticator.authenticate("user02", "password2a") is None diff --git a/requirements.txt b/requirements.txt index f465abd..818362f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ alembic bluesky-queueserver bluesky-queueserver-api +cachetools fastapi ldap3 orjson From be73eda1c78b7a76e7784a979de38783325aa32e Mon Sep 17 00:00:00 2001 From: Dmitri Gavrilov Date: Thu, 5 Feb 2026 11:26:58 -0500 Subject: [PATCH 03/13] STY: reformat with black --- .../authentication/authenticator_base.py | 4 +- bluesky_httpserver/authenticators.py | 111 ++++-------------- 2 files changed, 27 insertions(+), 88 deletions(-) diff --git a/bluesky_httpserver/authentication/authenticator_base.py b/bluesky_httpserver/authentication/authenticator_base.py index 7a2cff3..af103c5 100644 --- a/bluesky_httpserver/authentication/authenticator_base.py +++ b/bluesky_httpserver/authentication/authenticator_base.py @@ -21,9 +21,7 @@ class InternalAuthenticator(ABC): and password and returns a UserSessionState on success or None on failure. """ - async def authenticate( - self, username: str, password: str - ) -> Optional[UserSessionState]: + async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: raise NotImplementedError diff --git a/bluesky_httpserver/authenticators.py b/bluesky_httpserver/authenticators.py index 3b439f4..78b6cf1 100644 --- a/bluesky_httpserver/authenticators.py +++ b/bluesky_httpserver/authenticators.py @@ -63,15 +63,11 @@ class DictionaryAuthenticator(InternalAuthenticator): description: May be displayed by client after successful login. """ - def __init__( - self, users_to_passwords: Mapping[str, str], confirmation_message: str = "" - ): + def __init__(self, users_to_passwords: Mapping[str, str], confirmation_message: str = ""): self._users_to_passwords = users_to_passwords self.confirmation_message = confirmation_message - async def authenticate( - self, username: str, password: str - ) -> Optional[UserSessionState]: + async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: true_password = self._users_to_passwords.get(username) if not true_password: # Username is not valid. @@ -96,16 +92,12 @@ class PAMAuthenticator(InternalAuthenticator): def __init__(self, service: str = "login", confirmation_message: str = ""): if not modules_available("pamela"): - raise ModuleNotFoundError( - "This PAMAuthenticator requires the module 'pamela' to be installed." - ) + raise ModuleNotFoundError("This PAMAuthenticator requires the module 'pamela' to be installed.") self.service = service self.confirmation_message = confirmation_message # TODO Try to open a PAM session. - async def authenticate( - self, username: str, password: str - ) -> Optional[UserSessionState]: + async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: import pamela try: @@ -187,15 +179,11 @@ def token_endpoint(self) -> str: @functools.cached_property def authorization_endpoint(self) -> httpx.URL: - return httpx.URL( - cast(str, self._config_from_oidc_url.get("authorization_endpoint")) - ) + return httpx.URL(cast(str, self._config_from_oidc_url.get("authorization_endpoint"))) @functools.cached_property def device_authorization_endpoint(self) -> str: - return cast( - str, self._config_from_oidc_url.get("device_authorization_endpoint") - ) + return cast(str, self._config_from_oidc_url.get("device_authorization_endpoint")) @functools.cached_property def end_session_endpoint(self) -> str: @@ -217,9 +205,7 @@ def decode_token(self, token: str) -> dict[str, Any]: async def authenticate(self, request: Request) -> Optional[UserSessionState]: code = request.query_params.get("code") if not code: - logger.warning( - "Authentication failed: No authorization code parameter provided." - ) + logger.warning("Authentication failed: No authorization code parameter provided.") return None # A proxy in the middle may make the request into something like # 'http://localhost:8000/...' so we fix the first part but keep @@ -350,9 +336,7 @@ def __init__( # The PyPI package name is 'python3-saml' # but it imports as 'onelogin'. # https://github.com/onelogin/python3-saml - raise ModuleNotFoundError( - "This SAMLAuthenticator requires 'python3-saml' to be installed." - ) + raise ModuleNotFoundError("This SAMLAuthenticator requires 'python3-saml' to be installed.") from onelogin.saml2.auth import OneLogin_Saml2_Auth @@ -367,9 +351,7 @@ async def saml_login(request: Request) -> RedirectResponse: async def authenticate(self, request: Request) -> Optional[UserSessionState]: if not modules_available("onelogin"): - raise ModuleNotFoundError( - "This SAMLAuthenticator requires the module 'oneline' to be installed." - ) + raise ModuleNotFoundError("This SAMLAuthenticator requires the module 'oneline' to be installed.") from onelogin.saml2.auth import OneLogin_Saml2_Auth req = await prepare_saml_from_fastapi_request(request, True) @@ -378,8 +360,7 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]: errors = auth.get_errors() # This method receives an array with the errors if errors: raise Exception( - "Error when processing SAML Response: %s %s" - % (", ".join(errors), auth.get_last_error_reason()) + "Error when processing SAML Response: %s %s" % (", ".join(errors), auth.get_last_error_reason()) ) if auth.is_authenticated(): # Return a string that the Identity can use as id. @@ -398,7 +379,7 @@ async def prepare_saml_from_fastapi_request(request: Request) -> Mapping[str, st "server_port": request.url.port, "script_name": request.url.path, "post_data": {}, - "get_data": {} + "get_data": {}, # Advanced request options # "https": "", # "request_uri": "", @@ -644,9 +625,7 @@ def __init__( self.escape_userdn = escape_userdn self.search_filter = search_filter self.attributes = attributes if attributes else [] - self.auth_state_attributes = ( - auth_state_attributes if auth_state_attributes else [] - ) + self.auth_state_attributes = auth_state_attributes if auth_state_attributes else [] self.use_lookup_dn_username = use_lookup_dn_username if isinstance(server_address, str): @@ -659,14 +638,10 @@ def __init__( f"type(server_address)={type(server_address)}" ) if not server_address_list: - raise ValueError( - "No servers are specified: 'server_address' is an empty list" - ) + raise ValueError("No servers are specified: 'server_address' is an empty list") self.server_address_list = server_address_list - self.server_port = ( - server_port if server_port is not None else self._server_port_default() - ) + self.server_port = server_port if server_port is not None else self._server_port_default() self.confirmation_message = confirmation_message def _server_port_default(self): @@ -720,15 +695,8 @@ async def resolve_username(self, username_supplied_by_user): response = conn.response if len(response) == 0 or "attributes" not in response[0].keys(): - msg = ( - "No entry found for user '{username}' " - "when looking up attribute '{attribute}'" - ) - logger.warning( - msg.format( - username=username_supplied_by_user, attribute=self.user_attribute - ) - ) + msg = "No entry found for user '{username}' " "when looking up attribute '{attribute}'" + logger.warning(msg.format(username=username_supplied_by_user, attribute=self.user_attribute)) return (None, None) user_dn = response[0]["attributes"][self.lookup_dn_user_dn_attribute] @@ -786,9 +754,7 @@ def get_connection(self, userdn, password): ) server_pool.add(server) - auto_bind_no_ssl = ( - ldap3.AUTO_BIND_TLS_BEFORE_BIND if self.use_tls else ldap3.AUTO_BIND_NO_TLS - ) + auto_bind_no_ssl = ldap3.AUTO_BIND_TLS_BEFORE_BIND if self.use_tls else ldap3.AUTO_BIND_NO_TLS auto_bind = ldap3.AUTO_BIND_NO_TLS if self.use_ssl else auto_bind_no_ssl conn = ldap3.Connection( server_pool, @@ -813,9 +779,7 @@ async def get_user_attributes(self, conn, userdn): attrs = conn.entries[0].entry_attributes_as_dict return attrs - async def authenticate( - self, username: str, password: str - ) -> Optional[UserSessionState]: + async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: import ldap3 username_saved = username # Save the user name passed as a parameter @@ -841,9 +805,7 @@ async def authenticate( # sanity check if not self.lookup_dn and not bind_dn_template: - logger.warning( - "Login not allowed, please configure 'lookup_dn' or 'bind_dn_template'." - ) + logger.warning("Login not allowed, please configure 'lookup_dn' or 'bind_dn_template'.") return None if self.lookup_dn: @@ -881,9 +843,7 @@ async def authenticate( if conn.bound: is_bound = True else: - is_bound = await asyncio.get_running_loop().run_in_executor( - None, conn.bind - ) + is_bound = await asyncio.get_running_loop().run_in_executor(None, conn.bind) msg = msg.format(username=username, userdn=userdn, is_bound=is_bound) logger.debug(msg) @@ -896,9 +856,7 @@ async def authenticate( return None if self.search_filter: - search_filter = self.search_filter.format( - userattr=self.user_attribute, username=username - ) + search_filter = self.search_filter.format(userattr=self.user_attribute, username=username) search_func = functools.partial( conn.search, @@ -912,33 +870,18 @@ async def authenticate( n_users = len(conn.response) if n_users == 0: msg = "User with '{userattr}={username}' not found in directory" - logger.warning( - msg.format(userattr=self.user_attribute, username=username) - ) + logger.warning(msg.format(userattr=self.user_attribute, username=username)) return None if n_users > 1: - msg = ( - "Duplicate users found! " - "{n_users} users found with '{userattr}={username}'" - ) - logger.warning( - msg.format( - userattr=self.user_attribute, username=username, n_users=n_users - ) - ) + msg = "Duplicate users found! " "{n_users} users found with '{userattr}={username}'" + logger.warning(msg.format(userattr=self.user_attribute, username=username, n_users=n_users)) return None if self.allowed_groups: logger.debug("username:%s Using dn %s", username, userdn) found = False for group in self.allowed_groups: - group_filter = ( - "(|" - "(member={userdn})" - "(uniqueMember={userdn})" - "(memberUid={uid})" - ")" - ) + group_filter = "(|" "(member={userdn})" "(uniqueMember={userdn})" "(memberUid={uid})" ")" group_filter = group_filter.format(userdn=userdn, uid=username) group_attributes = ["member", "uniqueMember", "memberUid"] @@ -949,9 +892,7 @@ async def authenticate( search_filter=group_filter, attributes=group_attributes, ) - found = await asyncio.get_running_loop().run_in_executor( - None, search_func - ) + found = await asyncio.get_running_loop().run_in_executor(None, search_func) if found: break From 122aa17434a1a35a2e0dab9ebc402a04b0722873 Mon Sep 17 00:00:00 2001 From: David Pastl Date: Tue, 17 Feb 2026 15:04:02 -0600 Subject: [PATCH 04/13] Working version for logging in with Entra This is working okay, although it doens't really work smoothly for the API based login and the http command based login isn't great, as it requires the user to copy and past token around. Compared to ldap which just logs the user in. So still some work to do here to smooth out the user experience. --- bluesky_httpserver/_authentication.py | 363 +++++++++++++++++- bluesky_httpserver/app.py | 27 ++ bluesky_httpserver/authentication/__init__.py | 10 + bluesky_httpserver/authenticators.py | 27 +- .../config_schemas/examples/oidc_config.yml | 78 ++++ .../config_schemas/service_configuration.yml | 32 +- bluesky_httpserver/database/core.py | 40 +- bluesky_httpserver/database/orm.py | 21 + bluesky_httpserver/schemas.py | 17 + bluesky_httpserver/tests/conftest.py | 25 ++ .../tests/test_oidc_authenticators.py | 224 +++++++++++ requirements-dev.txt | 3 + requirements.txt | 1 + 13 files changed, 843 insertions(+), 25 deletions(-) create mode 100644 bluesky_httpserver/config_schemas/examples/oidc_config.yml create mode 100644 bluesky_httpserver/tests/test_oidc_authenticators.py diff --git a/bluesky_httpserver/_authentication.py b/bluesky_httpserver/_authentication.py index 0375794..a0d28b1 100644 --- a/bluesky_httpserver/_authentication.py +++ b/bluesky_httpserver/_authentication.py @@ -6,12 +6,13 @@ from datetime import datetime, timedelta from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Request, Response, Security, WebSocket +from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, Response, Security, WebSocket from fastapi.openapi.models import APIKey, APIKeyIn -from fastapi.responses import JSONResponse +from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm, SecurityScopes from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyQuery from fastapi.security.utils import get_authorization_scheme_param +from sqlalchemy.exc import IntegrityError # To hide third-party warning # .../jose/backends/cryptography_backend.py:18: CryptographyDeprecationWarning: @@ -33,7 +34,14 @@ from .authorization._defaults import _DEFAULT_ANONYMOUS_PROVIDER_NAME from .core import json_or_msgpack from .database import orm -from .database.core import create_user, latest_principal_activity, lookup_valid_api_key, lookup_valid_session +from .database.core import ( + create_user, + latest_principal_activity, + lookup_valid_api_key, + lookup_valid_pending_session_by_device_code, + lookup_valid_pending_session_by_user_code, + lookup_valid_session, +) from .settings import get_sessionmaker, get_settings from .utils import ( API_KEY_COOKIE_NAME, @@ -48,6 +56,10 @@ ALGORITHM = "HS256" UNIT_SECOND = timedelta(seconds=1) +# Device code flow constants +DEVICE_CODE_MAX_AGE = timedelta(minutes=10) +DEVICE_CODE_POLLING_INTERVAL = 5 # seconds + def utcnow(): "UTC now with second resolution" @@ -505,6 +517,351 @@ async def handle_credentials( return handle_credentials +def create_pending_session(db): + """ + Create a pending session for device code flow. + + Returns a dict with 'user_code' (user-facing code) and 'device_code' (for polling). + """ + device_code = secrets.token_bytes(32) + hashed_device_code = hashlib.sha256(device_code).digest() + for _ in range(3): + user_code = secrets.token_hex(4).upper() # 8 digit code + pending_session = orm.PendingSession( + user_code=user_code, + hashed_device_code=hashed_device_code, + expiration_time=utcnow() + DEVICE_CODE_MAX_AGE, + ) + db.add(pending_session) + try: + db.commit() + except IntegrityError: + # Since the user_code is short, we cannot completely dismiss the + # possibility of a collision. Retry. + db.rollback() + continue + break + formatted_user_code = f"{user_code[:4]}-{user_code[4:]}" + return { + "user_code": formatted_user_code, + "device_code": device_code.hex(), + } + + +def build_authorize_route(authenticator, provider): + """Build a GET route that redirects the browser to the OIDC provider for authentication.""" + + async def authorize_redirect( + request: Request, + state: Optional[str] = Query(None), + ): + """Redirect browser to OAuth provider for authentication.""" + redirect_uri = f"{get_base_url(request)}/auth/provider/{provider}/code" + + params = { + "client_id": authenticator.client_id, + "response_type": "code", + "scope": "openid profile email", + "redirect_uri": redirect_uri, + } + if state: + params["state"] = state + + auth_url = authenticator.authorization_endpoint.copy_with(params=params) + return RedirectResponse(url=str(auth_url)) + + return authorize_redirect + + +def build_device_code_authorize_route(authenticator, provider): + """Build a POST route that initiates the device code flow for CLI/headless clients.""" + + async def device_code_authorize( + request: Request, + settings: BaseSettings = Depends(get_settings), + ): + """ + Initiate device code flow. + + Returns authorization_uri for the user to visit in browser, + and device_code + user_code for the CLI client to poll. + """ + request.state.endpoint = "auth" + with get_sessionmaker(settings.database_settings)() as db: + pending_session = create_pending_session(db) + + verification_uri = f"{get_base_url(request)}/auth/provider/{provider}/token" + authorization_uri = authenticator.authorization_endpoint.copy_with( + params={ + "client_id": authenticator.client_id, + "response_type": "code", + "scope": "openid profile email", + "redirect_uri": f"{get_base_url(request)}/auth/provider/{provider}/device_code", + } + ) + return { + "authorization_uri": str(authorization_uri), # URL that user should visit in browser + "verification_uri": str(verification_uri), # URL that terminal client will poll + "interval": DEVICE_CODE_POLLING_INTERVAL, # suggested polling interval + "device_code": pending_session["device_code"], + "expires_in": int(DEVICE_CODE_MAX_AGE.total_seconds()), # seconds + "user_code": pending_session["user_code"], + } + + return device_code_authorize + + +def build_device_code_form_route(authenticator, provider): + """Build a GET route that shows the user code entry form.""" + + async def device_code_form( + request: Request, + code: str, + ): + """Show form for user to enter user code after browser auth.""" + action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}" + html_content = f""" + + + + Authorize Session + + + +

Authorize Bluesky HTTP Server Session

+
+ + + +
+ +
+ + +""" + return HTMLResponse(content=html_content) + + return device_code_form + + +def build_device_code_submit_route(authenticator, provider): + """Build a POST route that handles user code submission after browser auth.""" + + async def device_code_submit( + request: Request, + code: str = Form(), + user_code: str = Form(), + settings: BaseSettings = Depends(get_settings), + api_access_manager=Depends(get_api_access_manager), + ): + """Handle user code submission and link to authenticated session.""" + request.state.endpoint = "auth" + action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}" + normalized_user_code = user_code.upper().replace("-", "").strip() + + with get_sessionmaker(settings.database_settings)() as db: + pending_session = lookup_valid_pending_session_by_user_code(db, normalized_user_code) + if pending_session is None: + error_html = f""" + + +Error + + + +

Authorization Failed

+
Invalid user code. It may have been mistyped, or the pending request may have expired.
+
Try again + + +""" + return HTMLResponse(content=error_html, status_code=401) + + # Authenticate with the OIDC provider using the authorization code + user_session_state = await authenticator.authenticate(request) + if not user_session_state: + error_html = """ + + +Authentication Failed + + + +

Authentication Failed

+
User code was correct but authentication with the identity provider failed. Please contact the administrator.
+ + +""" + return HTMLResponse(content=error_html, status_code=401) + + username = user_session_state.user_name + if not api_access_manager.is_user_known(username): + error_html = f""" + + +Authorization Failed + + + +

Authorization Failed

+
User '{username}' is not authorized to access this server.
+ + +""" + return HTMLResponse(content=error_html, status_code=403) + + scopes = api_access_manager.get_user_scopes(username) + + # Create the session + session = await asyncio.get_running_loop().run_in_executor( + None, _create_session_orm, settings, provider, username, db + ) + + # Link the pending session to the real session + pending_session.session_id = session.id + db.add(pending_session) + db.commit() + + success_html = f""" + + +Success + + + +

Success!

+
You have been authenticated. Return to your terminal application - within {DEVICE_CODE_POLLING_INTERVAL} seconds it should be successfully logged in.
+ + +""" + return HTMLResponse(content=success_html) + + return device_code_submit + + +def _create_session_orm(settings, identity_provider, id, db): + """ + Create a session and return the ORM object (for device code flow). + + Unlike create_session(), this returns the ORM object so we can link it + to the pending session. + """ + # Have we seen this Identity before? + identity = ( + db.query(orm.Identity) + .filter(orm.Identity.id == id) + .filter(orm.Identity.provider == identity_provider) + .first() + ) + now = utcnow() + if identity is None: + # We have not. Make a new Principal and link this new Identity to it. + principal = create_user(db, identity_provider, id) + (new_identity,) = principal.identities + new_identity.latest_login = now + else: + identity.latest_login = now + principal = identity.principal + + session = orm.Session( + principal_id=principal.id, + expiration_time=utcnow() + settings.session_max_age, + ) + db.add(session) + db.commit() + db.refresh(session) + return session + + +def build_device_code_token_route(authenticator, provider): + """Build a POST route for the CLI client to poll for tokens.""" + + async def device_code_token( + request: Request, + body: schemas.DeviceCode, + settings: BaseSettings = Depends(get_settings), + api_access_manager=Depends(get_api_access_manager), + ): + """ + Poll for tokens after device code flow authentication. + + Returns tokens if the user has authenticated, or 400 with + 'authorization_pending' error if still waiting. + """ + request.state.endpoint = "auth" + device_code_hex = body.device_code + try: + device_code = bytes.fromhex(device_code_hex) + except Exception: + # Not valid hex, therefore not a valid device_code + raise HTTPException(status_code=401, detail="Invalid device code") + + with get_sessionmaker(settings.database_settings)() as db: + pending_session = lookup_valid_pending_session_by_device_code(db, device_code) + if pending_session is None: + raise HTTPException( + status_code=404, + detail="No such device_code. The pending request may have expired.", + ) + if pending_session.session_id is None: + raise HTTPException(status_code=400, detail={"error": "authorization_pending"}) + + session = pending_session.session + principal = session.principal + + # Get scopes for the user + # Find an identity to get the username + identity = db.query(orm.Identity).filter(orm.Identity.principal_id == principal.id).first() + if identity and api_access_manager.is_user_known(identity.id): + scopes = api_access_manager.get_user_scopes(identity.id) + else: + scopes = set() + + # The pending session can only be used once + db.delete(pending_session) + db.commit() + + # Generate tokens + data = { + "sub": principal.uuid.hex, + "sub_typ": principal.type.value, + "scp": list(scopes), + "ids": [{"id": ident.id, "idp": ident.provider} for ident in principal.identities], + } + access_token = create_access_token( + data=data, + expires_delta=settings.access_token_max_age, + secret_key=settings.secret_keys[0], + ) + refresh_token = create_refresh_token( + session_id=session.uuid.hex, + expires_delta=settings.refresh_token_max_age, + secret_key=settings.secret_keys[0], + ) + + return { + "access_token": access_token, + "expires_in": int(settings.access_token_max_age / UNIT_SECOND), + "refresh_token": refresh_token, + "refresh_token_expires_in": int(settings.refresh_token_max_age / UNIT_SECOND), + "token_type": "bearer", + } + + return device_code_token + + def generate_apikey(db, principal, apikey_params, request, allowed_scopes, source_api_key_scopes): # Use API key scopes if API key is generated based on existing API key, otherwise used allowed scopes if (source_api_key_scopes is not None) and ("inherit" not in source_api_key_scopes): diff --git a/bluesky_httpserver/app.py b/bluesky_httpserver/app.py index 9a8420a..0d96667 100644 --- a/bluesky_httpserver/app.py +++ b/bluesky_httpserver/app.py @@ -160,6 +160,11 @@ def build_app(authentication=None, api_access=None, resource_access=None, server from .authentication import ( base_authentication_router, build_auth_code_route, + build_authorize_route, + build_device_code_authorize_route, + build_device_code_form_route, + build_device_code_submit_route, + build_device_code_token_route, build_handle_credentials_route, oauth2_scheme, ) @@ -184,12 +189,34 @@ def build_app(authentication=None, api_access=None, resource_access=None, server build_handle_credentials_route(authenticator, provider) ) elif isinstance(authenticator, ExternalAuthenticator): + # Standard OAuth callback route (authorization code flow) authentication_router.get(f"/provider/{provider}/code")( build_auth_code_route(authenticator, provider) ) authentication_router.post(f"/provider/{provider}/code")( build_auth_code_route(authenticator, provider) ) + # Device code flow routes for CLI/headless clients + # GET /authorize - redirects browser to OIDC provider + authentication_router.get(f"/provider/{provider}/authorize")( + build_authorize_route(authenticator, provider) + ) + # POST /authorize - initiates device code flow (returns device_code, user_code, etc.) + authentication_router.post(f"/provider/{provider}/authorize")( + build_device_code_authorize_route(authenticator, provider) + ) + # GET /device_code - shows user code entry form + authentication_router.get(f"/provider/{provider}/device_code")( + build_device_code_form_route(authenticator, provider) + ) + # POST /device_code - handles user code submission after browser auth + authentication_router.post(f"/provider/{provider}/device_code")( + build_device_code_submit_route(authenticator, provider) + ) + # POST /token - CLI client polls this for tokens + authentication_router.post(f"/provider/{provider}/token")( + build_device_code_token_route(authenticator, provider) + ) else: raise ValueError(f"unknown authenticator type {type(authenticator)}") for custom_router in getattr(authenticator, "include_routers", []): diff --git a/bluesky_httpserver/authentication/__init__.py b/bluesky_httpserver/authentication/__init__.py index fc35cdd..85d835e 100644 --- a/bluesky_httpserver/authentication/__init__.py +++ b/bluesky_httpserver/authentication/__init__.py @@ -1,6 +1,11 @@ from .._authentication import ( base_authentication_router, build_auth_code_route, + build_authorize_route, + build_device_code_authorize_route, + build_device_code_form_route, + build_device_code_submit_route, + build_device_code_token_route, build_handle_credentials_route, get_current_principal, get_current_principal_websocket, @@ -20,6 +25,11 @@ "get_current_principal_websocket", "base_authentication_router", "build_auth_code_route", + "build_authorize_route", + "build_device_code_authorize_route", + "build_device_code_form_route", + "build_device_code_submit_route", + "build_device_code_token_route", "build_handle_credentials_route", "oauth2_scheme", ] diff --git a/bluesky_httpserver/authenticators.py b/bluesky_httpserver/authenticators.py index 78b6cf1..e8d108d 100644 --- a/bluesky_httpserver/authenticators.py +++ b/bluesky_httpserver/authenticators.py @@ -224,16 +224,37 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]: return None response_body = response.json() id_token = response_body["id_token"] - access_token = response_body["access_token"] + # NOTE: We decode the id_token, not access_token, because: + # 1. The id_token is the OIDC identity assertion meant for the client + # 2. Some providers (like Microsoft Entra) return opaque access_tokens + # that cannot be decoded with the JWKS keys when the resource is + # a first-party Microsoft API (e.g., Graph API with User.Read scope) try: - verified_body = self.decode_token(access_token) + verified_body = self.decode_token(id_token) except JWTError: logger.exception( "Authentication error. Unverified token: %r", jwt.get_unverified_claims(id_token), ) return None - return UserSessionState(verified_body["sub"], {}) + # Use preferred_username as the user identifier, extracting just the username + # part if it's in email format (user@domain.com -> user) + preferred_username = verified_body.get("preferred_username") + if preferred_username and "@" in preferred_username: + user_id = preferred_username.split("@")[0] + elif preferred_username: + user_id = preferred_username + else: + user_id = verified_body["sub"] + logger.info( + "OIDC authentication successful. user_id=%r (sub=%r, preferred_username=%r, email=%r, name=%r)", + user_id, + verified_body.get("sub"), + verified_body.get("preferred_username"), + verified_body.get("email"), + verified_body.get("name"), + ) + return UserSessionState(user_id, {}) class ProxiedOIDCAuthenticator(OIDCAuthenticator): diff --git a/bluesky_httpserver/config_schemas/examples/oidc_config.yml b/bluesky_httpserver/config_schemas/examples/oidc_config.yml new file mode 100644 index 0000000..c2f8d24 --- /dev/null +++ b/bluesky_httpserver/config_schemas/examples/oidc_config.yml @@ -0,0 +1,78 @@ +# Example OIDC Configuration for Bluesky HTTP Server +# +# This example shows how to configure OIDC (OpenID Connect) authentication. +# OIDC is used by providers like Google, Microsoft Entra (Azure AD), Okta, Keycloak, etc. +# +# Required environment variables: +# - OIDC_CLIENT_ID: The client ID from your OIDC provider +# - OIDC_CLIENT_SECRET: The client secret from your OIDC provider +# - OIDC_WELL_KNOWN_URI: The .well-known/openid-configuration URL +# +# Example for Google: +# OIDC_WELL_KNOWN_URI=https://accounts.google.com/.well-known/openid-configuration +# +# Example for Microsoft Entra (Azure AD): +# OIDC_WELL_KNOWN_URI=https://login.microsoftonline.com/{tenant-id}/v2.0/.well-known/openid-configuration +# +# Example for Keycloak: +# OIDC_WELL_KNOWN_URI=https://your-keycloak-server/realms/{realm}/.well-known/openid-configuration + +authentication: + providers: + - provider: oidc + authenticator: bluesky_httpserver.authenticators:OIDCAuthenticator + args: + # The audience should match the client_id or be a value expected by your OIDC provider + audience: ${OIDC_CLIENT_ID} + client_id: ${OIDC_CLIENT_ID} + client_secret: ${OIDC_CLIENT_SECRET} + well_known_uri: ${OIDC_WELL_KNOWN_URI} + confirmation_message: "You have successfully logged in via OIDC as {id}." + # Optional: redirect URLs after authentication + # redirect_on_success: https://your-app.example.com/success + # redirect_on_failure: https://your-app.example.com/login-failed + + # Secret keys used to sign secure tokens (generate with: openssl rand -hex 32) + secret_keys: + - ${SECRET_KEY} + + # Allow unauthenticated access to public endpoints + allow_anonymous_access: false + + # Token lifetimes (in seconds) + access_token_max_age: 900 # 15 minutes + refresh_token_max_age: 604800 # 7 days + +# Database for storing sessions and API keys +database: + uri: ${DATABASE_URI} + pool_size: 5 + pool_pre_ping: true + +# API access control - configure which users have access +api_access: + policy: bluesky_httpserver.authorization:DictionaryAPIAccessControl + args: + users: + # Add users identified by their OIDC subject ID (sub claim) + # The ID typically looks like an email or UUID depending on your OIDC provider + user@example.com: + roles: + - admin + - user + +# Resource access control +resource_access: + policy: bluesky_httpserver.authorization:DefaultResourceAccessControl + args: + default_group: root + +# Queue Server connection +qserver_zmq_configuration: + control_address: tcp://localhost:60615 + info_address: tcp://localhost:60625 + +# HTTP Server configuration +uvicorn: + host: 0.0.0.0 + port: 8000 diff --git a/bluesky_httpserver/config_schemas/service_configuration.yml b/bluesky_httpserver/config_schemas/service_configuration.yml index 57343f7..12f01a3 100644 --- a/bluesky_httpserver/config_schemas/service_configuration.yml +++ b/bluesky_httpserver/config_schemas/service_configuration.yml @@ -47,14 +47,14 @@ properties: properties: custom_routers: type: array - item: + items: type: string description: | The list of Python modules with custom routers. Overrides the list of modules set using QSERVER_HTTP_CUSTOM_ROUTERS environment variable. custom_modules: type: array - item: + items: type: string description: | THE FUNCTIONALITY WILL BE DEPRECATED IN FAVOR OF CUSTOM ROUTERS. Overrides the list of modules @@ -65,7 +65,7 @@ properties: properties: providers: type: array - item: + items: type: object additionalProperties: false required: @@ -83,7 +83,7 @@ properties: description: | Type of Authenticator to use. - These are typically from the tiled.authenticators module, + These are typically from the bluesky_httpserver.authenticators module, though user-defined ones may be used as well. This is given as an import path. In an import path, packages/modules @@ -92,21 +92,21 @@ properties: Example: ```yaml - authenticator: bluesky_httpserver.examples.DummyAuthenticator + authenticator: bluesky_httpserver.authenticators:DummyAuthenticator ``` - args: - type: [object, "null"] - description: | - Named arguments to pass to Authenticator. If there are none, - `args` may be omitted or empty. + args: + type: object + description: | + Named arguments to pass to Authenticator. If there are none, + `args` may be omitted or empty. - Example: + Example: - ```yaml - authenticator: bluesky_httpserver.examples.PAMAuthenticator - args: - service: "custom_service" - ``` + ```yaml + authenticator: bluesky_httpserver.authenticators:PAMAuthenticator + args: + service: "custom_service" + ``` # qserver_admins: # type: array # items: diff --git a/bluesky_httpserver/database/core.py b/bluesky_httpserver/database/core.py index 163fac3..f096edc 100644 --- a/bluesky_httpserver/database/core.py +++ b/bluesky_httpserver/database/core.py @@ -1,6 +1,7 @@ import hashlib import uuid as uuid_module from datetime import datetime +from typing import Optional from alembic import command from alembic.config import Config @@ -10,13 +11,13 @@ from .alembic_utils import temp_alembic_ini from .base import Base -from .orm import APIKey, Identity, Principal, Session # , Role +from .orm import APIKey, Identity, PendingSession, Principal, Session # , Role # This is the alembic revision ID of the database revision # required by this version of Tiled. -REQUIRED_REVISION = "722ff4e4fcc7" +REQUIRED_REVISION = "a1b2c3d4e5f6" # This is list of all valid revisions (from current to oldest). -ALL_REVISIONS = ["722ff4e4fcc7", "481830dd6c11"] +ALL_REVISIONS = ["a1b2c3d4e5f6", "722ff4e4fcc7", "481830dd6c11"] # def create_default_roles(engine): @@ -294,3 +295,36 @@ def latest_principal_activity(db, principal): if all([t is None for t in all_activity]): return None return max(t for t in all_activity if t is not None) + + +def lookup_valid_pending_session_by_device_code(db, device_code: bytes) -> Optional[PendingSession]: + """ + Look up a pending session by its device code. + + Returns None if the pending session is not found or has expired. + """ + hashed_device_code = hashlib.sha256(device_code).digest() + pending_session = db.query(PendingSession).filter(PendingSession.hashed_device_code == hashed_device_code).first() + if pending_session is None: + return None + if pending_session.expiration_time is not None and pending_session.expiration_time < datetime.utcnow(): + db.delete(pending_session) + db.commit() + return None + return pending_session + + +def lookup_valid_pending_session_by_user_code(db, user_code: str) -> Optional[PendingSession]: + """ + Look up a pending session by its user code. + + Returns None if the pending session is not found or has expired. + """ + pending_session = db.query(PendingSession).filter(PendingSession.user_code == user_code).first() + if pending_session is None: + return None + if pending_session.expiration_time is not None and pending_session.expiration_time < datetime.utcnow(): + db.delete(pending_session) + db.commit() + return None + return pending_session diff --git a/bluesky_httpserver/database/orm.py b/bluesky_httpserver/database/orm.py index 17d7c82..7611824 100644 --- a/bluesky_httpserver/database/orm.py +++ b/bluesky_httpserver/database/orm.py @@ -181,3 +181,24 @@ class Session(Timestamped, Base): revoked = Column(Boolean, default=False, nullable=False) principal = relationship("Principal", back_populates="sessions") + pending_sessions = relationship("PendingSession", back_populates="session") + + +class PendingSession(Timestamped, Base): + """ + This is used only in Device Code Flow for OIDC authentication. + + When a CLI client initiates the device code flow, a pending session is created + with a device_code (for the client to poll) and a user_code (for the user to + enter in the browser). Once the user authenticates, the pending session is + linked to a real session, which the polling client then receives. + """ + + __tablename__ = "pending_sessions" + + hashed_device_code = Column(LargeBinary(32), primary_key=True, index=True, nullable=False) + user_code = Column(Unicode(8), index=True, nullable=False) + expiration_time = Column(DateTime(timezone=False), nullable=False) + session_id = Column(Integer, ForeignKey("sessions.id"), nullable=True) + + session = relationship("Session", back_populates="pending_sessions") diff --git a/bluesky_httpserver/schemas.py b/bluesky_httpserver/schemas.py index c52d8f2..f1d9fcb 100644 --- a/bluesky_httpserver/schemas.py +++ b/bluesky_httpserver/schemas.py @@ -163,6 +163,23 @@ class RefreshToken(pydantic.BaseModel): refresh_token: str +class DeviceCode(pydantic.BaseModel): + """Schema for device code token polling request.""" + + device_code: str + + +class DeviceCodeResponse(pydantic.BaseModel): + """Schema for device code flow initiation response.""" + + authorization_uri: str + verification_uri: str + device_code: str + user_code: str + expires_in: int + interval: int + + class AuthenticationMode(str, enum.Enum): password = "password" external = "external" diff --git a/bluesky_httpserver/tests/conftest.py b/bluesky_httpserver/tests/conftest.py index ec69415..3c43529 100644 --- a/bluesky_httpserver/tests/conftest.py +++ b/bluesky_httpserver/tests/conftest.py @@ -195,3 +195,28 @@ def wait_for_ip_kernel_idle(timeout, polling_period=0.2, api_key=API_KEY_FOR_TES return True return False + + +# ============================================================================ +# OIDC Test Fixtures +# ============================================================================ + +@pytest.fixture +def oidc_base_url() -> str: + """Base URL for mock OIDC provider.""" + return "https://example.com/realms/example/" + + +@pytest.fixture +def well_known_response(oidc_base_url: str) -> dict: + """Mock OIDC well-known configuration response.""" + return { + "id_token_signing_alg_values_supported": ["RS256"], + "issuer": oidc_base_url.rstrip("/"), + "jwks_uri": f"{oidc_base_url}protocol/openid-connect/certs", + "authorization_endpoint": f"{oidc_base_url}protocol/openid-connect/auth", + "token_endpoint": f"{oidc_base_url}protocol/openid-connect/token", + "device_authorization_endpoint": f"{oidc_base_url}protocol/openid-connect/auth/device", + "end_session_endpoint": f"{oidc_base_url}protocol/openid-connect/logout", + } + diff --git a/bluesky_httpserver/tests/test_oidc_authenticators.py b/bluesky_httpserver/tests/test_oidc_authenticators.py new file mode 100644 index 0000000..30303e4 --- /dev/null +++ b/bluesky_httpserver/tests/test_oidc_authenticators.py @@ -0,0 +1,224 @@ +"""Tests for OIDC Authenticator functionality.""" + +import time +from typing import Any, Tuple + +import httpx +import pytest +from cryptography.hazmat.primitives.asymmetric import rsa +from jose import ExpiredSignatureError, jwt +from jose.backends import RSAKey +from respx import MockRouter + +from bluesky_httpserver.authenticators import OIDCAuthenticator, ProxiedOIDCAuthenticator + + +@pytest.fixture +def oidc_well_known_url(oidc_base_url: str) -> str: + return f"{oidc_base_url}.well-known/openid-configuration" + + +@pytest.fixture +def keys() -> Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]: + """Generate RSA key pair for testing.""" + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + public_key = private_key.public_key() + return (private_key, public_key) + + +@pytest.fixture +def json_web_keyset(keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]) -> list[dict[str, Any]]: + """Create a JSON Web Key Set from the test keys.""" + _, public_key = keys + return [RSAKey(key=public_key, algorithm="RS256").to_dict()] + + +@pytest.fixture +def mock_oidc_server( + respx_mock: MockRouter, + oidc_well_known_url: str, + well_known_response: dict[str, Any], + json_web_keyset: list[dict[str, Any]], +) -> MockRouter: + """Set up mock OIDC server endpoints.""" + respx_mock.get(oidc_well_known_url).mock( + return_value=httpx.Response(httpx.codes.OK, json=well_known_response) + ) + respx_mock.get(well_known_response["jwks_uri"]).mock( + return_value=httpx.Response(httpx.codes.OK, json={"keys": json_web_keyset}) + ) + return respx_mock + + +def create_token(issued: bool, expired: bool) -> dict[str, Any]: + """Create a test JWT token.""" + now = time.time() + return { + "aud": "test_client", + "exp": (now - 1500) if expired else (now + 1500), + "iat": (now - 1500) if issued else (now + 1500), + "iss": "https://example.com/realms/example", + "sub": "test_user", + } + + +def encrypt_token(token: dict[str, Any], private_key: rsa.RSAPrivateKey) -> str: + """Encrypt a token with the test private key.""" + return jwt.encode( + token, + key=private_key, + algorithm="RS256", + headers={"kid": "test_key"}, + ) + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +class TestOIDCAuthenticator: + """Tests for OIDCAuthenticator class.""" + + def test_oidc_authenticator_caching( + self, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + well_known_response: dict[str, Any], + json_web_keyset: list[dict[str, Any]], + ): + """Test that OIDC configuration is cached after first fetch.""" + authenticator = OIDCAuthenticator( + audience="test_client", + client_id="test_client", + client_secret="secret", + well_known_uri=oidc_well_known_url, + ) + + # Access multiple properties to ensure caching works + assert authenticator.client_id == "test_client" + assert authenticator.authorization_endpoint == well_known_response["authorization_endpoint"] + assert ( + authenticator.id_token_signing_alg_values_supported + == well_known_response["id_token_signing_alg_values_supported"] + ) + assert authenticator.issuer == well_known_response["issuer"] + assert authenticator.jwks_uri == well_known_response["jwks_uri"] + assert authenticator.token_endpoint == well_known_response["token_endpoint"] + assert ( + authenticator.device_authorization_endpoint + == well_known_response["device_authorization_endpoint"] + ) + assert authenticator.end_session_endpoint == well_known_response["end_session_endpoint"] + + # Should only call well-known endpoint once due to caching + assert len(mock_oidc_server.calls) == 1 + call_request = mock_oidc_server.calls[0].request + assert call_request.method == "GET" + assert call_request.url == oidc_well_known_url + + # Keys should also be cached + assert authenticator.keys() == json_web_keyset + assert len(mock_oidc_server.calls) == 2 # Now also fetched JWKS + + # Multiple calls should still be cached + for _ in range(5): + assert authenticator.keys() == json_web_keyset + assert len(mock_oidc_server.calls) == 2 # No new calls + + @pytest.mark.parametrize("issued", [True, False]) + @pytest.mark.parametrize("expired", [True, False]) + def test_oidc_token_decoding( + self, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + issued: bool, + expired: bool, + keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey], + ): + """Test token decoding with various validity scenarios.""" + private_key, _ = keys + authenticator = OIDCAuthenticator( + audience="test_client", + client_id="test_client", + client_secret="secret", + well_known_uri=oidc_well_known_url, + ) + + token = create_token(issued, expired) + encrypted = encrypt_token(token, private_key) + + if not expired: + # Non-expired tokens should decode successfully + decoded = authenticator.decode_token(encrypted) + assert decoded["sub"] == "test_user" + assert decoded["aud"] == "test_client" + else: + # Expired tokens should raise an error + with pytest.raises(ExpiredSignatureError): + authenticator.decode_token(encrypted) + + def test_oidc_authenticator_properties( + self, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + well_known_response: dict[str, Any], + ): + """Test that all authenticator properties are correctly set.""" + authenticator = OIDCAuthenticator( + audience="my_audience", + client_id="my_client_id", + client_secret="my_secret", + well_known_uri=oidc_well_known_url, + confirmation_message="Logged in as {id}", + redirect_on_success="https://app.example.com/success", + redirect_on_failure="https://app.example.com/failure", + ) + + assert authenticator.client_id == "my_client_id" + assert authenticator.confirmation_message == "Logged in as {id}" + assert authenticator.redirect_on_success == "https://app.example.com/success" + assert authenticator.redirect_on_failure == "https://app.example.com/failure" + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +class TestProxiedOIDCAuthenticator: + """Tests for ProxiedOIDCAuthenticator class.""" + + @pytest.mark.asyncio + async def test_proxied_oidc_oauth2_schema( + self, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + ): + """Test that ProxiedOIDCAuthenticator extracts bearer token correctly.""" + authenticator = ProxiedOIDCAuthenticator( + audience="test_client", + client_id="test_client", + well_known_uri=oidc_well_known_url, + device_flow_client_id="test_cli_client", + ) + + # Create a mock request with Authorization header + test_request = httpx.Request( + "GET", + "http://example.com/api/test", + headers={"Authorization": "Bearer TEST_TOKEN"}, + ) + + # The oauth2_schema should extract the bearer token + token = await authenticator.oauth2_schema(test_request) + assert token == "TEST_TOKEN" + + def test_proxied_oidc_with_scopes( + self, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + ): + """Test ProxiedOIDCAuthenticator with custom scopes.""" + authenticator = ProxiedOIDCAuthenticator( + audience="test_client", + client_id="test_client", + well_known_uri=oidc_well_known_url, + device_flow_client_id="test_cli_client", + scopes=["openid", "profile", "email"], + ) + + assert authenticator.scopes == ["openid", "profile", "email"] + assert authenticator.device_flow_client_id == "test_cli_client" diff --git a/requirements-dev.txt b/requirements-dev.txt index dd7212a..e47dd72 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,13 +3,16 @@ black codecov coverage +cryptography fastapi[all] flake8 isort pre-commit pytest +pytest-asyncio pytest-xprocess py +respx sphinx ipython numpydoc diff --git a/requirements.txt b/requirements.txt index 818362f..1377ef0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ bluesky-queueserver bluesky-queueserver-api cachetools fastapi +httpx ldap3 orjson pamela From d90ad0cad8720d09786dfd0f02a5250c996f7165 Mon Sep 17 00:00:00 2001 From: David Pastl Date: Tue, 17 Feb 2026 15:14:06 -0600 Subject: [PATCH 05/13] Removing some unnecessary code. --- bluesky_httpserver/_authentication.py | 2 -- bluesky_httpserver/authenticators.py | 1 - 2 files changed, 3 deletions(-) diff --git a/bluesky_httpserver/_authentication.py b/bluesky_httpserver/_authentication.py index a0d28b1..c745dff 100644 --- a/bluesky_httpserver/_authentication.py +++ b/bluesky_httpserver/_authentication.py @@ -721,8 +721,6 @@ async def device_code_submit( """ return HTMLResponse(content=error_html, status_code=403) - scopes = api_access_manager.get_user_scopes(username) - # Create the session session = await asyncio.get_running_loop().run_in_executor( None, _create_session_orm, settings, provider, username, db diff --git a/bluesky_httpserver/authenticators.py b/bluesky_httpserver/authenticators.py index e8d108d..a58fedf 100644 --- a/bluesky_httpserver/authenticators.py +++ b/bluesky_httpserver/authenticators.py @@ -222,7 +222,6 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]: if response.is_error: logger.error("Authentication error: %r", response_body) return None - response_body = response.json() id_token = response_body["id_token"] # NOTE: We decode the id_token, not access_token, because: # 1. The id_token is the OIDC identity assertion meant for the client From 24857905f573f805940d6e3a0c4cefd409b87bf9 Mon Sep 17 00:00:00 2001 From: David Pastl Date: Mon, 23 Feb 2026 10:06:43 -0600 Subject: [PATCH 06/13] Working example that does not require device-codes This solves the problem that what was implemented was actually authenticating the application and not the user like expected. It worked but it required that the user input a code. This solves that problem so that when you click the login link, if you are already logged in with you SSO provider you'll just automatically log in to the HTTP Server. Likewise if you use the bluesky queueserver api, when you call RM.Login you'll automatically be logged in, no user interaction required. --- bluesky_httpserver/_authentication.py | 193 +++++++++++------- .../config_schemas/examples/oidc_config.yml | 78 ------- .../config_schemas/service_configuration.yml | 43 ++-- 3 files changed, 128 insertions(+), 186 deletions(-) delete mode 100644 bluesky_httpserver/config_schemas/examples/oidc_config.yml diff --git a/bluesky_httpserver/_authentication.py b/bluesky_httpserver/_authentication.py index c745dff..0cb046f 100644 --- a/bluesky_httpserver/_authentication.py +++ b/bluesky_httpserver/_authentication.py @@ -597,6 +597,7 @@ async def device_code_authorize( "response_type": "code", "scope": "openid profile email", "redirect_uri": f"{get_base_url(request)}/auth/provider/{provider}/device_code", + "state": pending_session["user_code"].replace("-", ""), } ) return { @@ -611,66 +612,23 @@ async def device_code_authorize( return device_code_authorize -def build_device_code_form_route(authenticator, provider): - """Build a GET route that shows the user code entry form.""" - - async def device_code_form( - request: Request, - code: str, - ): - """Show form for user to enter user code after browser auth.""" - action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}" - html_content = f""" - - - - Authorize Session - - - -

Authorize Bluesky HTTP Server Session

-
- - - -
- -
- - -""" - return HTMLResponse(content=html_content) - - return device_code_form - - -def build_device_code_submit_route(authenticator, provider): - """Build a POST route that handles user code submission after browser auth.""" - - async def device_code_submit( - request: Request, - code: str = Form(), - user_code: str = Form(), - settings: BaseSettings = Depends(get_settings), - api_access_manager=Depends(get_api_access_manager), - ): - """Handle user code submission and link to authenticated session.""" - request.state.endpoint = "auth" - action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}" - normalized_user_code = user_code.upper().replace("-", "").strip() +async def _complete_device_code_authorization( + request: Request, + authenticator, + provider: str, + code: str, + user_code: str, + settings: BaseSettings, + api_access_manager, +): + request.state.endpoint = "auth" + action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}" + normalized_user_code = user_code.upper().replace("-", "").strip() - with get_sessionmaker(settings.database_settings)() as db: - pending_session = lookup_valid_pending_session_by_user_code(db, normalized_user_code) - if pending_session is None: - error_html = f""" + with get_sessionmaker(settings.database_settings)() as db: + pending_session = lookup_valid_pending_session_by_user_code(db, normalized_user_code) + if pending_session is None: + error_html = f""" Error @@ -684,12 +642,12 @@ async def device_code_submit( """ - return HTMLResponse(content=error_html, status_code=401) + return HTMLResponse(content=error_html, status_code=401) - # Authenticate with the OIDC provider using the authorization code - user_session_state = await authenticator.authenticate(request) - if not user_session_state: - error_html = """ + # Authenticate with the OIDC provider using the authorization code + user_session_state = await authenticator.authenticate(request) + if not user_session_state: + error_html = """ Authentication Failed @@ -702,11 +660,11 @@ async def device_code_submit( """ - return HTMLResponse(content=error_html, status_code=401) + return HTMLResponse(content=error_html, status_code=401) - username = user_session_state.user_name - if not api_access_manager.is_user_known(username): - error_html = f""" + username = user_session_state.user_name + if not api_access_manager.is_user_known(username): + error_html = f""" Authorization Failed @@ -719,19 +677,19 @@ async def device_code_submit( """ - return HTMLResponse(content=error_html, status_code=403) + return HTMLResponse(content=error_html, status_code=403) - # Create the session - session = await asyncio.get_running_loop().run_in_executor( - None, _create_session_orm, settings, provider, username, db - ) + # Create the session + session = await asyncio.get_running_loop().run_in_executor( + None, _create_session_orm, settings, provider, username, db + ) - # Link the pending session to the real session - pending_session.session_id = session.id - db.add(pending_session) - db.commit() + # Link the pending session to the real session + pending_session.session_id = session.id + db.add(pending_session) + db.commit() - success_html = f""" + success_html = f""" Success @@ -744,7 +702,84 @@ async def device_code_submit( """ - return HTMLResponse(content=success_html) + return HTMLResponse(content=success_html) + + +def build_device_code_form_route(authenticator, provider): + """Build a GET route that shows the user code entry form.""" + + async def device_code_form( + request: Request, + code: str, + state: Optional[str] = Query(None), + settings: BaseSettings = Depends(get_settings), + api_access_manager=Depends(get_api_access_manager), + ): + """Show form for user to enter user code after browser auth.""" + if state: + return await _complete_device_code_authorization( + request=request, + authenticator=authenticator, + provider=provider, + code=code, + user_code=state, + settings=settings, + api_access_manager=api_access_manager, + ) + + action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}" + html_content = f""" + + + + Authorize Session + + + +

Authorize Bluesky HTTP Server Session

+
+ + + +
+ +
+ + +""" + return HTMLResponse(content=html_content) + + return device_code_form + + +def build_device_code_submit_route(authenticator, provider): + """Build a POST route that handles user code submission after browser auth.""" + + async def device_code_submit( + request: Request, + code: str = Form(), + user_code: str = Form(), + settings: BaseSettings = Depends(get_settings), + api_access_manager=Depends(get_api_access_manager), + ): + """Handle user code submission and link to authenticated session.""" + return await _complete_device_code_authorization( + request=request, + authenticator=authenticator, + provider=provider, + code=code, + user_code=user_code, + settings=settings, + api_access_manager=api_access_manager, + ) return device_code_submit diff --git a/bluesky_httpserver/config_schemas/examples/oidc_config.yml b/bluesky_httpserver/config_schemas/examples/oidc_config.yml deleted file mode 100644 index c2f8d24..0000000 --- a/bluesky_httpserver/config_schemas/examples/oidc_config.yml +++ /dev/null @@ -1,78 +0,0 @@ -# Example OIDC Configuration for Bluesky HTTP Server -# -# This example shows how to configure OIDC (OpenID Connect) authentication. -# OIDC is used by providers like Google, Microsoft Entra (Azure AD), Okta, Keycloak, etc. -# -# Required environment variables: -# - OIDC_CLIENT_ID: The client ID from your OIDC provider -# - OIDC_CLIENT_SECRET: The client secret from your OIDC provider -# - OIDC_WELL_KNOWN_URI: The .well-known/openid-configuration URL -# -# Example for Google: -# OIDC_WELL_KNOWN_URI=https://accounts.google.com/.well-known/openid-configuration -# -# Example for Microsoft Entra (Azure AD): -# OIDC_WELL_KNOWN_URI=https://login.microsoftonline.com/{tenant-id}/v2.0/.well-known/openid-configuration -# -# Example for Keycloak: -# OIDC_WELL_KNOWN_URI=https://your-keycloak-server/realms/{realm}/.well-known/openid-configuration - -authentication: - providers: - - provider: oidc - authenticator: bluesky_httpserver.authenticators:OIDCAuthenticator - args: - # The audience should match the client_id or be a value expected by your OIDC provider - audience: ${OIDC_CLIENT_ID} - client_id: ${OIDC_CLIENT_ID} - client_secret: ${OIDC_CLIENT_SECRET} - well_known_uri: ${OIDC_WELL_KNOWN_URI} - confirmation_message: "You have successfully logged in via OIDC as {id}." - # Optional: redirect URLs after authentication - # redirect_on_success: https://your-app.example.com/success - # redirect_on_failure: https://your-app.example.com/login-failed - - # Secret keys used to sign secure tokens (generate with: openssl rand -hex 32) - secret_keys: - - ${SECRET_KEY} - - # Allow unauthenticated access to public endpoints - allow_anonymous_access: false - - # Token lifetimes (in seconds) - access_token_max_age: 900 # 15 minutes - refresh_token_max_age: 604800 # 7 days - -# Database for storing sessions and API keys -database: - uri: ${DATABASE_URI} - pool_size: 5 - pool_pre_ping: true - -# API access control - configure which users have access -api_access: - policy: bluesky_httpserver.authorization:DictionaryAPIAccessControl - args: - users: - # Add users identified by their OIDC subject ID (sub claim) - # The ID typically looks like an email or UUID depending on your OIDC provider - user@example.com: - roles: - - admin - - user - -# Resource access control -resource_access: - policy: bluesky_httpserver.authorization:DefaultResourceAccessControl - args: - default_group: root - -# Queue Server connection -qserver_zmq_configuration: - control_address: tcp://localhost:60615 - info_address: tcp://localhost:60625 - -# HTTP Server configuration -uvicorn: - host: 0.0.0.0 - port: 8000 diff --git a/bluesky_httpserver/config_schemas/service_configuration.yml b/bluesky_httpserver/config_schemas/service_configuration.yml index 12f01a3..a76e4d3 100644 --- a/bluesky_httpserver/config_schemas/service_configuration.yml +++ b/bluesky_httpserver/config_schemas/service_configuration.yml @@ -47,14 +47,14 @@ properties: properties: custom_routers: type: array - items: + item: type: string description: | The list of Python modules with custom routers. Overrides the list of modules set using QSERVER_HTTP_CUSTOM_ROUTERS environment variable. custom_modules: type: array - items: + item: type: string description: | THE FUNCTIONALITY WILL BE DEPRECATED IN FAVOR OF CUSTOM ROUTERS. Overrides the list of modules @@ -65,7 +65,7 @@ properties: properties: providers: type: array - items: + item: type: object additionalProperties: false required: @@ -94,34 +94,19 @@ properties: ```yaml authenticator: bluesky_httpserver.authenticators:DummyAuthenticator ``` - args: - type: object - description: | - Named arguments to pass to Authenticator. If there are none, - `args` may be omitted or empty. + args: + type: object + description: | + Named arguments to pass to Authenticator. If there are none, + `args` may be omitted or empty. - Example: + Example: - ```yaml - authenticator: bluesky_httpserver.authenticators:PAMAuthenticator - args: - service: "custom_service" - ``` - # qserver_admins: - # type: array - # items: - # type: object - # additionalProperties: false - # required: - # - provider - # - id - # properties: - # provider: - # type: string - # id: - # type: string - # description: | - # Give users with these identities 'admin' Role. + ```yaml + authenticator: bluesky_httpserver.authenticators:PAMAuthenticator + args: + service: "custom_service" + ``` secret_keys: type: array items: From 96cd9db5b4f23104d737cbc90f00bc9bd67d07c4 Mon Sep 17 00:00:00 2001 From: David Pastl Date: Mon, 23 Feb 2026 11:00:17 -0600 Subject: [PATCH 07/13] Fixes from running black --- bluesky_httpserver/database/core.py | 4 +++- bluesky_httpserver/tests/conftest.py | 2 +- bluesky_httpserver/tests/test_oidc_authenticators.py | 9 ++------- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/bluesky_httpserver/database/core.py b/bluesky_httpserver/database/core.py index f096edc..52d102f 100644 --- a/bluesky_httpserver/database/core.py +++ b/bluesky_httpserver/database/core.py @@ -304,7 +304,9 @@ def lookup_valid_pending_session_by_device_code(db, device_code: bytes) -> Optio Returns None if the pending session is not found or has expired. """ hashed_device_code = hashlib.sha256(device_code).digest() - pending_session = db.query(PendingSession).filter(PendingSession.hashed_device_code == hashed_device_code).first() + pending_session = ( + db.query(PendingSession).filter(PendingSession.hashed_device_code == hashed_device_code).first() + ) if pending_session is None: return None if pending_session.expiration_time is not None and pending_session.expiration_time < datetime.utcnow(): diff --git a/bluesky_httpserver/tests/conftest.py b/bluesky_httpserver/tests/conftest.py index 3c43529..8851e71 100644 --- a/bluesky_httpserver/tests/conftest.py +++ b/bluesky_httpserver/tests/conftest.py @@ -201,6 +201,7 @@ def wait_for_ip_kernel_idle(timeout, polling_period=0.2, api_key=API_KEY_FOR_TES # OIDC Test Fixtures # ============================================================================ + @pytest.fixture def oidc_base_url() -> str: """Base URL for mock OIDC provider.""" @@ -219,4 +220,3 @@ def well_known_response(oidc_base_url: str) -> dict: "device_authorization_endpoint": f"{oidc_base_url}protocol/openid-connect/auth/device", "end_session_endpoint": f"{oidc_base_url}protocol/openid-connect/logout", } - diff --git a/bluesky_httpserver/tests/test_oidc_authenticators.py b/bluesky_httpserver/tests/test_oidc_authenticators.py index 30303e4..f3249cd 100644 --- a/bluesky_httpserver/tests/test_oidc_authenticators.py +++ b/bluesky_httpserver/tests/test_oidc_authenticators.py @@ -41,9 +41,7 @@ def mock_oidc_server( json_web_keyset: list[dict[str, Any]], ) -> MockRouter: """Set up mock OIDC server endpoints.""" - respx_mock.get(oidc_well_known_url).mock( - return_value=httpx.Response(httpx.codes.OK, json=well_known_response) - ) + respx_mock.get(oidc_well_known_url).mock(return_value=httpx.Response(httpx.codes.OK, json=well_known_response)) respx_mock.get(well_known_response["jwks_uri"]).mock( return_value=httpx.Response(httpx.codes.OK, json={"keys": json_web_keyset}) ) @@ -101,10 +99,7 @@ def test_oidc_authenticator_caching( assert authenticator.issuer == well_known_response["issuer"] assert authenticator.jwks_uri == well_known_response["jwks_uri"] assert authenticator.token_endpoint == well_known_response["token_endpoint"] - assert ( - authenticator.device_authorization_endpoint - == well_known_response["device_authorization_endpoint"] - ) + assert authenticator.device_authorization_endpoint == well_known_response["device_authorization_endpoint"] assert authenticator.end_session_endpoint == well_known_response["end_session_endpoint"] # Should only call well-known endpoint once due to caching From 967fcbab3de634ce66f79ba9451173435476edbe Mon Sep 17 00:00:00 2001 From: David Pastl Date: Mon, 23 Feb 2026 13:07:07 -0600 Subject: [PATCH 08/13] Adding documentation on how to use OIDC --- docs/source/configuration.rst | 79 +++++++++++++++++++++++++++++++++++ docs/source/usage.rst | 43 +++++++++++++++++++ 2 files changed, 122 insertions(+) diff --git a/docs/source/configuration.rst b/docs/source/configuration.rst index eb31efa..8852a31 100644 --- a/docs/source/configuration.rst +++ b/docs/source/configuration.rst @@ -294,6 +294,85 @@ See the documentation on ``LDAPAuthenticator`` for more details. authenticators.LDAPAuthenticator +OIDC Authenticator +++++++++++++++++++ + +``OIDCAuthenticator`` integrates the server with third-party OpenID Connect providers +such as Google, Microsoft Entra ID, ORCID and others. The server does not process user +passwords directly: authentication is delegated to the provider and the server validates +the returned OIDC token. + +General setup steps: + +#. Register an application with the OIDC provider. +#. Configure redirect URIs for the provider application. For provider name ``entra`` and + host ``https://your-server.example`` the redirect URIs are: + + - ``https://your-server.example/api/auth/provider/entra/code`` + - ``https://your-server.example/api/auth/provider/entra/device_code`` + +#. Store the client secret in environment variable and reference it in config. +#. Use provider's ``.well-known/openid-configuration`` URL. + +Typical ``well_known_uri`` values: + +- Google: ``https://accounts.google.com/.well-known/openid-configuration`` +- Microsoft Entra ID: ``https://login.microsoftonline.com//v2.0/.well-known/openid-configuration`` +- ORCID: ``https://orcid.org/.well-known/openid-configuration`` + +Example configuration (Microsoft Entra ID):: + + authentication: + providers: + - provider: entra + authenticator: bluesky_httpserver.authenticators:OIDCAuthenticator + args: + audience: 00000000-0000-0000-0000-000000000000 + client_id: 00000000-0000-0000-0000-000000000000 + client_secret: ${BSKY_ENTRA_SECRET} + well_known_uri: https://login.microsoftonline.com//v2.0/.well-known/openid-configuration + confirmation_message: "You have logged in successfully." + api_access: + policy: bluesky_httpserver.authorization:DictionaryAPIAccessControl + args: + users: + : + roles: + - admin + - expert + +Example configuration (Google):: + + authentication: + providers: + - provider: google + authenticator: bluesky_httpserver.authenticators:OIDCAuthenticator + args: + audience: + client_id: + client_secret: ${BSKY_GOOGLE_SECRET} + well_known_uri: https://accounts.google.com/.well-known/openid-configuration + api_access: + policy: bluesky_httpserver.authorization:DictionaryAPIAccessControl + args: + users: + : + roles: user + +.. note:: + + The name used in ``api_access/args/users`` must match the identity string produced by + the authenticator for your provider configuration. Verify with ``/api/auth/whoami`` after + successful login. + +See the documentation on ``OIDCAuthenticator`` for parameter details. + +.. autosummary:: + :nosignatures: + :toctree: generated + + authenticators.OIDCAuthenticator + Expiration Time for Tokens and Sessions +++++++++++++++++++++++++++++++++++++++ diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 5e1e9b3..d6c3a10 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -154,6 +154,49 @@ Then users ``bob``, ``alice`` and ``tom`` can log into the server as :: If authentication is successful, then the server returns access and refresh tokens. +Logging in with OIDC Providers (Google, Entra, ORCID, ...) +----------------------------------------------------------- + +For providers configured with ``OIDCAuthenticator``, use provider-specific endpoints +under ``/api/auth/provider//...``. + +Browser-first flow +++++++++++++++++++ + +If you are already in a browser context, open: + +``/api/auth/provider//authorize`` + +This redirects to the OIDC provider login page and then back to the server callback. + +CLI/device flow ++++++++++++++++ + +For terminal clients, start with ``POST /api/auth/provider//authorize``. +The response includes: + +- ``authorization_uri``: open this URL in a browser +- ``verification_uri``: polling endpoint for the terminal client +- ``device_code`` and ``interval``: values for polling + +Example using ``httpie`` (provider ``entra``):: + + http POST http://localhost:60610/api/auth/provider/entra/authorize + +After opening ``authorization_uri`` in a browser and completing provider login, +poll ``verification_uri`` using ``device_code`` until tokens are issued:: + + http POST http://localhost:60610/api/auth/provider/entra/token \ + device_code='' + +When authorization is still pending, the endpoint returns ``authorization_pending``. +When complete, it returns access and refresh tokens. + +.. note:: + + In common same-device flows the callback can complete automatically without manually + typing the user code. Manual code entry remains available as a fallback path. + Generating API Keys ------------------- From 5906b28b889224527c937da912944460513723ac Mon Sep 17 00:00:00 2001 From: David Pastl Date: Tue, 24 Feb 2026 16:17:24 -0600 Subject: [PATCH 09/13] Fixes for unit tests, moving start LDAP These should correct some of the problems in the last CI workflow. I moved the LDAP and docker image into the continuous_integration folder so it matches tiled. --- .github/workflows/testing.yml | 2 +- bluesky_httpserver/_authentication.py | 79 +++++- bluesky_httpserver/tests/conftest.py | 19 +- .../tests/test_authenticators.py | 245 +++++++++++++++++- .../docker-configs/ldap-docker-compose.yml | 6 +- continuous_integration/scripts/start_LDAP.sh | 7 + docs/source/usage.rst | 4 +- start_LDAP.sh | 8 - 8 files changed, 340 insertions(+), 30 deletions(-) rename {.github/workflows => continuous_integration}/docker-configs/ldap-docker-compose.yml (74%) create mode 100755 continuous_integration/scripts/start_LDAP.sh delete mode 100644 start_LDAP.sh diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index b7d9d54..5355c05 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -43,7 +43,7 @@ jobs: popd # Start LDAP - source start_LDAP.sh + source continuous_integration/scripts/start_LDAP.sh # These packages are installed in the base environment but may be older # versions. Explicitly upgrade them because they often create diff --git a/bluesky_httpserver/_authentication.py b/bluesky_httpserver/_authentication.py index 0cb046f..c1144f5 100644 --- a/bluesky_httpserver/_authentication.py +++ b/bluesky_httpserver/_authentication.py @@ -632,12 +632,22 @@ async def _complete_device_code_authorization( Error - +

Authorization Failed

-
Invalid user code. It may have been mistyped, or the pending request may have expired.
+
+ Invalid user code. It may have been mistyped, or the pending request may have expired. +

Try again @@ -651,12 +661,23 @@ async def _complete_device_code_authorization( Authentication Failed - +

Authentication Failed

-
User code was correct but authentication with the identity provider failed. Please contact the administrator.
+
+ User code was correct but authentication with the identity provider failed. + Please contact the administrator. +
""" @@ -668,8 +689,16 @@ async def _complete_device_code_authorization( Authorization Failed - +

Authorization Failed

@@ -693,12 +722,23 @@ async def _complete_device_code_authorization( Success - +

Success!

-
You have been authenticated. Return to your terminal application - within {DEVICE_CODE_POLLING_INTERVAL} seconds it should be successfully logged in.
+
+ You have been authenticated. Return to your terminal application - + within {DEVICE_CODE_POLLING_INTERVAL} seconds it should be successfully logged in. +
""" @@ -738,8 +778,21 @@ async def device_code_form( h1 {{ color: #333; }} form {{ margin-top: 20px; }} label {{ display: block; margin-bottom: 10px; }} - input[type="text"] {{ padding: 10px; font-size: 16px; width: 200px; text-transform: uppercase; }} - input[type="submit"] {{ padding: 10px 20px; font-size: 16px; background-color: #007bff; color: white; border: none; cursor: pointer; margin-top: 10px; }} + input[type="text"] {{ + padding: 10px; + font-size: 16px; + width: 200px; + text-transform: uppercase; + }} + input[type="submit"] {{ + padding: 10px 20px; + font-size: 16px; + background-color: #007bff; + color: white; + border: none; + cursor: pointer; + margin-top: 10px; + }} input[type="submit"]:hover {{ background-color: #0056b3; }} diff --git a/bluesky_httpserver/tests/conftest.py b/bluesky_httpserver/tests/conftest.py index 8851e71..d5cafdb 100644 --- a/bluesky_httpserver/tests/conftest.py +++ b/bluesky_httpserver/tests/conftest.py @@ -18,6 +18,22 @@ _user_group = "primary" +def _wait_for_http_server_ready(*, timeout=10, request_prefix="/api"): + """Wait until HTTP server accepts connections and responds to /status.""" + t_stop = ttime.time() + timeout + url = f"http://{SERVER_ADDRESS}:{SERVER_PORT}{request_prefix}/status" + while ttime.time() < t_stop: + try: + response = requests.get(url, timeout=0.5) + # Any HTTP response means the server is up (auth may still reject request). + if response.status_code: + return + except requests.RequestException: + pass + ttime.sleep(0.1) + raise TimeoutError(f"HTTP server is not ready after {timeout} s: {url}") + + @pytest.fixture(scope="module") def fastapi_server(xprocess): class Starter(ProcessStarter): @@ -29,6 +45,7 @@ class Starter(ProcessStarter): # args = f"start-bluesky-httpserver --host={SERVER_ADDRESS} --port {SERVER_PORT}".split() xprocess.ensure("fastapi_server", Starter) + _wait_for_http_server_ready() yield @@ -55,7 +72,7 @@ class Starter(ProcessStarter): args = f"uvicorn --host={http_server_host} --port {http_server_port} {bqss.__name__}:app".split() xprocess.ensure("fastapi_server", Starter) - ttime.sleep(1) + _wait_for_http_server_ready() yield start diff --git a/bluesky_httpserver/tests/test_authenticators.py b/bluesky_httpserver/tests/test_authenticators.py index 183ce75..28e2601 100644 --- a/bluesky_httpserver/tests/test_authenticators.py +++ b/bluesky_httpserver/tests/test_authenticators.py @@ -1,9 +1,17 @@ import asyncio +import time +from typing import Any, Tuple +import httpx import pytest +from cryptography.hazmat.primitives.asymmetric import rsa +from jose import ExpiredSignatureError, jwt +from jose.backends import RSAKey +from respx import MockRouter +from starlette.datastructures import QueryParams, URL # fmt: off -from ..authenticators import LDAPAuthenticator, UserSessionState +from ..authenticators import LDAPAuthenticator, OIDCAuthenticator, ProxiedOIDCAuthenticator, UserSessionState @pytest.mark.parametrize("ldap_server_address, ldap_server_port", [ @@ -41,3 +49,238 @@ async def testing(): assert await authenticator.authenticate("user02", "password2a") is None asyncio.run(testing()) + + +@pytest.fixture +def oidc_well_known_url(oidc_base_url: str) -> str: + return f"{oidc_base_url}.well-known/openid-configuration" + + +@pytest.fixture +def keys() -> Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]: + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + public_key = private_key.public_key() + return (private_key, public_key) + + +@pytest.fixture +def json_web_keyset(keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]) -> list[dict[str, Any]]: + _, public_key = keys + return [RSAKey(key=public_key, algorithm="RS256").to_dict()] + + +@pytest.fixture +def mock_oidc_server( + respx_mock: MockRouter, + oidc_well_known_url: str, + well_known_response: dict[str, Any], + json_web_keyset: list[dict[str, Any]], +) -> MockRouter: + respx_mock.get(oidc_well_known_url).mock(return_value=httpx.Response(httpx.codes.OK, json=well_known_response)) + respx_mock.get(well_known_response["jwks_uri"]).mock( + return_value=httpx.Response(httpx.codes.OK, json={"keys": json_web_keyset}) + ) + return respx_mock + + +def token(issued: bool, expired: bool) -> dict[str, str]: + now = time.time() + return { + "aud": "tiled", + "exp": (now - 1500) if expired else (now + 1500), + "iat": (now - 1500) if issued else (now + 1500), + "iss": "https://example.com/realms/example", + "sub": "Jane Doe", + } + + +def encrypted_token(token_data: dict[str, str], private_key: rsa.RSAPrivateKey) -> str: + return jwt.encode( + token_data, + key=private_key, + algorithm="RS256", + headers={"kid": "secret"}, + ) + + +def test_oidc_authenticator_caching( + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + well_known_response: dict[str, Any], + json_web_keyset: list[dict[str, Any]], +): + authenticator = OIDCAuthenticator("tiled", "tiled", "secret", well_known_uri=oidc_well_known_url) + assert authenticator.client_id == "tiled" + assert authenticator.authorization_endpoint == well_known_response["authorization_endpoint"] + assert authenticator.id_token_signing_alg_values_supported == well_known_response[ + "id_token_signing_alg_values_supported" + ] + assert authenticator.issuer == well_known_response["issuer"] + assert authenticator.jwks_uri == well_known_response["jwks_uri"] + assert authenticator.token_endpoint == well_known_response["token_endpoint"] + assert authenticator.device_authorization_endpoint == well_known_response["device_authorization_endpoint"] + assert authenticator.end_session_endpoint == well_known_response["end_session_endpoint"] + + assert len(mock_oidc_server.calls) == 1 + call_request = mock_oidc_server.calls[0].request + assert call_request.method == "GET" + assert call_request.url == oidc_well_known_url + + assert authenticator.keys() == json_web_keyset + assert len(mock_oidc_server.calls) == 2 + keys_request = mock_oidc_server.calls[1].request + assert keys_request.method == "GET" + assert keys_request.url == well_known_response["jwks_uri"] + + for _ in range(10): + assert authenticator.keys() == json_web_keyset + + assert len(mock_oidc_server.calls) == 2 + + +@pytest.mark.parametrize("issued", [True, False]) +@pytest.mark.parametrize("expired", [True, False]) +def test_oidc_decoding( + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + issued: bool, + expired: bool, + keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey], +): + private_key, _ = keys + authenticator = OIDCAuthenticator("tiled", "tiled", "secret", well_known_uri=oidc_well_known_url) + access_token = token(issued, expired) + encrypted_access_token = encrypted_token(access_token, private_key) + + if not expired: + assert authenticator.decode_token(encrypted_access_token) == access_token + else: + with pytest.raises(ExpiredSignatureError): + authenticator.decode_token(encrypted_access_token) + + +@pytest.mark.asyncio +async def test_proxied_oidc_token_retrieval(oidc_well_known_url: str, mock_oidc_server: MockRouter): + authenticator = ProxiedOIDCAuthenticator("tiled", "tiled", oidc_well_known_url, + device_flow_client_id="tiled-cli") + test_request = httpx.Request("GET", "http://example.com", headers={"Authorization": "bearer FOO"}) + + assert "FOO" == await authenticator.oauth2_schema(test_request) + + +def create_mock_oidc_request(query_params=None): + if query_params is None: + query_params = {} + + class MockRequest: + def __init__(self, request_query_params): + self.query_params = QueryParams(request_query_params) + self.scope = { + "type": "http", + "scheme": "http", + "server": ("localhost", 8000), + "path": "/api/v1/auth/provider/orcid/code", + "headers": [], + } + self.headers = {"host": "localhost:8000"} + self.url = URL("http://localhost:8000/api/v1/auth/provider/orcid/code") + + return MockRequest(query_params) + + +@pytest.mark.asyncio +async def test_OIDCAuthenticator_mock( + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + well_known_response: dict[str, Any], + monkeypatch, +): + mock_jwt_payload = { + "sub": "0009-0008-8698-7745", + "aud": "APP-TEST-CLIENT-ID", + "iss": well_known_response["issuer"], + "exp": 9999999999, + "iat": 1000000000, + "given_name": "Test User", + } + + mock_oidc_server.post(well_known_response["token_endpoint"]).mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "mock-access-token", + "id_token": "mock-id-token", + "token_type": "bearer", + }, + ) + ) + + authenticator = OIDCAuthenticator( + audience="APP-TEST-CLIENT-ID", + client_id="APP-TEST-CLIENT-ID", + client_secret="test-secret", + well_known_uri=oidc_well_known_url, + ) + + mock_request = create_mock_oidc_request({"code": "test-auth-code"}) + + def mock_jwt_decode(*args, **kwargs): + return mock_jwt_payload + + def mock_jwk_construct(*args, **kwargs): + class MockJWK: + pass + + return MockJWK() + + monkeypatch.setattr("jose.jwt.decode", mock_jwt_decode) + monkeypatch.setattr("jose.jwk.construct", mock_jwk_construct) + + user_session = await authenticator.authenticate(mock_request) + + assert user_session is not None + assert user_session.user_name == "0009-0008-8698-7745" + + +@pytest.mark.asyncio +async def test_OIDCAuthenticator_missing_code_parameter(oidc_well_known_url: str): + authenticator = OIDCAuthenticator( + audience="APP-TEST-CLIENT-ID", + client_id="APP-TEST-CLIENT-ID", + client_secret="test-secret", + well_known_uri=oidc_well_known_url, + ) + + mock_request = create_mock_oidc_request({}) + + result = await authenticator.authenticate(mock_request) + assert result is None + + +@pytest.mark.asyncio +async def test_OIDCAuthenticator_token_exchange_failure( + oidc_well_known_url: str, + mock_oidc_server, + well_known_response, +): + mock_oidc_server.post(well_known_response["token_endpoint"]).mock( + return_value=httpx.Response( + 400, + json={ + "error": "invalid_client", + "error_description": "Client not found: APP-TEST-CLIENT-ID", + }, + ) + ) + + authenticator = OIDCAuthenticator( + audience="APP-TEST-CLIENT-ID", + client_id="APP-TEST-CLIENT-ID", + client_secret="test-secret", + well_known_uri=oidc_well_known_url, + ) + + mock_request = create_mock_oidc_request({"code": "invalid-code"}) + + result = await authenticator.authenticate(mock_request) + assert result is None diff --git a/.github/workflows/docker-configs/ldap-docker-compose.yml b/continuous_integration/docker-configs/ldap-docker-compose.yml similarity index 74% rename from .github/workflows/docker-configs/ldap-docker-compose.yml rename to continuous_integration/docker-configs/ldap-docker-compose.yml index 5cf12a8..2b2c45a 100644 --- a/.github/workflows/docker-configs/ldap-docker-compose.yml +++ b/continuous_integration/docker-configs/ldap-docker-compose.yml @@ -1,8 +1,6 @@ -version: '2' - services: openldap: - image: docker.io/bitnami/openldap:latest + image: osixia/openldap:latest ports: - '1389:1389' - '1636:1636' @@ -12,7 +10,7 @@ services: - LDAP_USERS=user01,user02 - LDAP_PASSWORDS=password1,password2 volumes: - - 'openldap_data:/bitnami/openldap' + - 'openldap_data:/var/lib/ldap' volumes: openldap_data: diff --git a/continuous_integration/scripts/start_LDAP.sh b/continuous_integration/scripts/start_LDAP.sh new file mode 100755 index 0000000..c6a5fbc --- /dev/null +++ b/continuous_integration/scripts/start_LDAP.sh @@ -0,0 +1,7 @@ +#!/bin/bash +set -e + +# Start LDAP server in docker container +docker pull osixia/openldap:latest +docker compose -f continuous_integration/docker-configs/ldap-docker-compose.yml up -d +docker ps \ No newline at end of file diff --git a/docs/source/usage.rst b/docs/source/usage.rst index d6c3a10..6cd168c 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -161,7 +161,7 @@ For providers configured with ``OIDCAuthenticator``, use provider-specific endpo under ``/api/auth/provider//...``. Browser-first flow -++++++++++++++++++ +~~~~~~~~~~~~~~~~~ If you are already in a browser context, open: @@ -170,7 +170,7 @@ If you are already in a browser context, open: This redirects to the OIDC provider login page and then back to the server callback. CLI/device flow -+++++++++++++++ +~~~~~~~~~~~~~~~ For terminal clients, start with ``POST /api/auth/provider//authorize``. The response includes: diff --git a/start_LDAP.sh b/start_LDAP.sh deleted file mode 100644 index 8b612de..0000000 --- a/start_LDAP.sh +++ /dev/null @@ -1,8 +0,0 @@ - -#!/bin/bash -set -e - -# Start LDAP server in docker container -# sudo docker pull osixia/openldap:latest -sudo docker compose -f .github/workflows/docker-configs/ldap-docker-compose.yml up -d -sudo docker ps From 28483f97c8b90ca330ba289cec09e68235a5c83e Mon Sep 17 00:00:00 2001 From: David Pastl Date: Tue, 24 Feb 2026 16:19:37 -0600 Subject: [PATCH 10/13] fixing pre-commit issues --- bluesky_httpserver/tests/test_authenticators.py | 4 ++-- continuous_integration/scripts/start_LDAP.sh | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bluesky_httpserver/tests/test_authenticators.py b/bluesky_httpserver/tests/test_authenticators.py index 28e2601..53c6bbe 100644 --- a/bluesky_httpserver/tests/test_authenticators.py +++ b/bluesky_httpserver/tests/test_authenticators.py @@ -8,7 +8,7 @@ from jose import ExpiredSignatureError, jwt from jose.backends import RSAKey from respx import MockRouter -from starlette.datastructures import QueryParams, URL +from starlette.datastructures import URL, QueryParams # fmt: off from ..authenticators import LDAPAuthenticator, OIDCAuthenticator, ProxiedOIDCAuthenticator, UserSessionState @@ -161,7 +161,7 @@ def test_oidc_decoding( @pytest.mark.asyncio async def test_proxied_oidc_token_retrieval(oidc_well_known_url: str, mock_oidc_server: MockRouter): - authenticator = ProxiedOIDCAuthenticator("tiled", "tiled", oidc_well_known_url, + authenticator = ProxiedOIDCAuthenticator("tiled", "tiled", oidc_well_known_url, device_flow_client_id="tiled-cli") test_request = httpx.Request("GET", "http://example.com", headers={"Authorization": "bearer FOO"}) diff --git a/continuous_integration/scripts/start_LDAP.sh b/continuous_integration/scripts/start_LDAP.sh index c6a5fbc..ecfa1cf 100755 --- a/continuous_integration/scripts/start_LDAP.sh +++ b/continuous_integration/scripts/start_LDAP.sh @@ -4,4 +4,4 @@ set -e # Start LDAP server in docker container docker pull osixia/openldap:latest docker compose -f continuous_integration/docker-configs/ldap-docker-compose.yml up -d -docker ps \ No newline at end of file +docker ps From a4551e34e70394a8563cc88de4e00c89c43a1c8b Mon Sep 17 00:00:00 2001 From: David Pastl Date: Wed, 25 Feb 2026 08:54:52 -0600 Subject: [PATCH 11/13] fixing documentation issues This addresses documentation problems, the levels were incorrect as I did not understand what the next level should have been in the docs. I've also updated the usage documentation a little to be more useful. --- docs/source/usage.rst | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 6cd168c..299bdcb 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -161,7 +161,7 @@ For providers configured with ``OIDCAuthenticator``, use provider-specific endpo under ``/api/auth/provider//...``. Browser-first flow -~~~~~~~~~~~~~~~~~ +****************** If you are already in a browser context, open: @@ -169,10 +169,22 @@ If you are already in a browser context, open: This redirects to the OIDC provider login page and then back to the server callback. +This can similarly be acheived using ``httpie`` by opening the URL in a browser after getting +the authorization URI from the server:: + + http POST http://localhost:60610/api/auth/provider/entra/authorize + +Which will return a token back to the bluesky http server after the user logs in to the provider +in their browser (or automatically if already logged in). The user then gets a token +for the bluesky HTTP server to use for subsequent API requests. This flow can be used +even when using the bluesky queueserver api in a terminal so long as that session can +spawn a browser for the user to log in to the provider. + CLI/device flow -~~~~~~~~~~~~~~~ +*************** -For terminal clients, start with ``POST /api/auth/provider//authorize``. +For terminal clients (i.e. no browser possible), start with +``POST /api/auth/provider//authorize``. The response includes: - ``authorization_uri``: open this URL in a browser From 8fa89ab6509052f584c99b46545bee510ee0419c Mon Sep 17 00:00:00 2001 From: David Pastl Date: Fri, 13 Mar 2026 12:02:04 -0600 Subject: [PATCH 12/13] Adding in helper scripts for testing These allow for running the unit tests in a containerized system just like how they are done in the ci pipeline, but locally and in a way that can maximize processor usage and minimize runtime. --- docker/test.Dockerfile | 29 ++ scripts/docker/run_shard_in_container.sh | 80 ++++ scripts/run_ci_docker_parallel.sh | 480 +++++++++++++++++++++++ 3 files changed, 589 insertions(+) create mode 100644 docker/test.Dockerfile create mode 100755 scripts/docker/run_shard_in_container.sh create mode 100755 scripts/run_ci_docker_parallel.sh diff --git a/docker/test.Dockerfile b/docker/test.Dockerfile new file mode 100644 index 0000000..2e994cf --- /dev/null +++ b/docker/test.Dockerfile @@ -0,0 +1,29 @@ +ARG PYTHON_VERSION=3.13 +FROM python:${PYTHON_VERSION}-slim + +ENV PYTHONUNBUFFERED=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + bash \ + build-essential \ + git \ + redis-server \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /workspace + +COPY requirements.txt requirements-dev.txt ./ +COPY pyproject.toml setup.py setup.cfg MANIFEST.in versioneer.py README.rst AUTHORS.rst LICENSE ./ +COPY bluesky_httpserver ./bluesky_httpserver + +RUN python -m pip install --upgrade pip setuptools wheel numpy && \ + python -m pip install git+https://github.com/bluesky/bluesky-queueserver.git && \ + python -m pip install git+https://github.com/bluesky/bluesky-queueserver-api.git && \ + python -m pip install -r requirements-dev.txt && \ + python -m pip install . + +COPY scripts/docker/run_shard_in_container.sh /usr/local/bin/run_shard_in_container.sh +RUN chmod +x /usr/local/bin/run_shard_in_container.sh + +ENTRYPOINT ["/usr/local/bin/run_shard_in_container.sh"] diff --git a/scripts/docker/run_shard_in_container.sh b/scripts/docker/run_shard_in_container.sh new file mode 100755 index 0000000..7fc23a7 --- /dev/null +++ b/scripts/docker/run_shard_in_container.sh @@ -0,0 +1,80 @@ +#!/usr/bin/env bash +set -euo pipefail + +SHARD_GROUP="${SHARD_GROUP:-1}" +SHARD_COUNT="${SHARD_COUNT:-3}" +ARTIFACTS_DIR="${ARTIFACTS_DIR:-/artifacts}" +PYTEST_EXTRA_ARGS="${PYTEST_EXTRA_ARGS:-}" + +mkdir -p "$ARTIFACTS_DIR" + +if [[ "$SHARD_GROUP" -lt 1 || "$SHARD_COUNT" -lt 1 || "$SHARD_GROUP" -gt "$SHARD_COUNT" ]]; then + echo "Invalid shard settings: SHARD_GROUP=$SHARD_GROUP SHARD_COUNT=$SHARD_COUNT" >&2 + exit 2 +fi + +export COVERAGE_FILE="$ARTIFACTS_DIR/.coverage.${SHARD_GROUP}" + +redis-server --save "" --appendonly no --daemonize yes +for _ in $(seq 1 50); do + if redis-cli ping >/dev/null 2>&1; then + break + fi + sleep 0.2 +done + +if ! redis-cli ping >/dev/null 2>&1; then + echo "Failed to start redis-server inside container" >&2 + exit 2 +fi + +mapfile -t shard_tests < <( + python - <<'PY' "$SHARD_GROUP" "$SHARD_COUNT" +import glob +import sys + +group = int(sys.argv[1]) +count = int(sys.argv[2]) + +tests = sorted(glob.glob("bluesky_httpserver/tests/test_*.py")) +selected = [path for idx, path in enumerate(tests) if idx % count == (group - 1)] + +for path in selected: + print(path) +PY +) + +if [[ "${#shard_tests[@]}" -eq 0 ]]; then + echo "No tests selected for shard ${SHARD_GROUP}/${SHARD_COUNT}; treating as success." + exit 0 +fi + +pytest_cmd=( + coverage + run + -m + pytest + --junitxml="$ARTIFACTS_DIR/junit.${SHARD_GROUP}.xml" + -vv +) + +if [[ -n "$PYTEST_EXTRA_ARGS" ]]; then + read -r -a extra_args <<< "$PYTEST_EXTRA_ARGS" + pytest_cmd+=("${extra_args[@]}") +fi + +pytest_cmd+=("${shard_tests[@]}") + +set +e +"${pytest_cmd[@]}" +test_status=$? +set -e + +if [[ "$test_status" -eq 5 ]]; then + echo "Pytest collected no tests for shard ${SHARD_GROUP}/${SHARD_COUNT}; treating as success." + test_status=0 +fi + +redis-cli shutdown nosave >/dev/null 2>&1 || true + +exit "$test_status" diff --git a/scripts/run_ci_docker_parallel.sh b/scripts/run_ci_docker_parallel.sh new file mode 100755 index 0000000..c9caee7 --- /dev/null +++ b/scripts/run_ci_docker_parallel.sh @@ -0,0 +1,480 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +IMAGE_TAG_BASE="bluesky-httpserver-test:local" +WORKER_COUNT="3" +CHUNK_COUNT="" +PYTHON_VERSIONS="latest" +PYTEST_EXTRA_ARGS="" +ARTIFACTS_DIR="$ROOT_DIR/.docker-test-artifacts" +DOCKER_NETWORK_NAME="bhs-ci-net" +LDAP_CONTAINER_NAME="bhs-ci-ldap" + +SUMMARY_TSV="" +SUMMARY_FAIL_LOGS="" +SUMMARY_TXT="" +SUMMARY_JSON="" +TESTS_START_EPOCH="" +TESTS_START_HUMAN="" + +SUPPORTED_PYTHON_VERSIONS=("3.10" "3.11" "3.12" "3.13") + +usage() { + cat <<'EOF' +Run bluesky-httpserver unit tests in Docker with dynamic chunk dispatch and optional Python-version matrix. + +Usage: + scripts/run_ci_docker_parallel.sh [options] + +Options: + --workers N, --worker-count N + Number of concurrent chunk workers (default: 3). + + --chunks N, --chunk-count N + Number of total chunks/splits to execute per Python version. + Default: workers * 3. + + --python-versions VALUE + Python version selection: latest | all | comma-separated list. + Examples: latest, all, 3.12, 3.11,3.13 + Default: latest (currently 3.13). + + --pytest-args "ARGS" + Extra arguments passed to pytest in each chunk. + Example: --pytest-args "-k oidc --maxfail=1" + + --artifacts-dir PATH + Output directory for all artifacts. + Default: .docker-test-artifacts under repository root. + + --image-tag TAG + Base docker image tag. Per-version tags will append -py. + Default: bluesky-httpserver-test:local + + -h, --help + Show this help message. + +Examples: + scripts/run_ci_docker_parallel.sh + scripts/run_ci_docker_parallel.sh --workers 8 --chunks 24 + scripts/run_ci_docker_parallel.sh --python-versions all --workers 8 --chunks 24 + scripts/run_ci_docker_parallel.sh --python-versions 3.11,3.13 --pytest-args "-k test_access_control" +EOF +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --workers|--worker-count) + WORKER_COUNT="$2" + shift 2 + ;; + --chunks|--chunk-count) + CHUNK_COUNT="$2" + shift 2 + ;; + --python-versions) + PYTHON_VERSIONS="$2" + shift 2 + ;; + --pytest-args) + PYTEST_EXTRA_ARGS="$2" + shift 2 + ;; + --artifacts-dir) + ARTIFACTS_DIR="$2" + shift 2 + ;; + --image-tag) + IMAGE_TAG_BASE="$2" + shift 2 + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Unknown option: $1" >&2 + usage + exit 2 + ;; + esac +done + +if [[ "$WORKER_COUNT" -lt 1 ]]; then + echo "WORKER_COUNT must be >= 1" >&2 + exit 2 +fi + +if [[ -z "$CHUNK_COUNT" ]]; then + CHUNK_COUNT=$(( WORKER_COUNT * 3 )) +fi + +if [[ "$CHUNK_COUNT" -lt 1 ]]; then + echo "CHUNK_COUNT must be >= 1" >&2 + exit 2 +fi + +if ! command -v docker >/dev/null 2>&1; then + echo "docker is required but not found in PATH" >&2 + exit 2 +fi + +if ! docker info >/dev/null 2>&1; then + echo "docker daemon is not available" >&2 + exit 2 +fi + +normalize_python_versions() { + local selection="$1" + local raw + local normalized=() + + if [[ "$selection" == "latest" ]]; then + normalized=("3.13") + elif [[ "$selection" == "all" ]]; then + normalized=("${SUPPORTED_PYTHON_VERSIONS[@]}") + else + raw="${selection//,/ }" + read -r -a normalized <<< "$raw" + fi + + if [[ "${#normalized[@]}" -eq 0 ]]; then + echo "PYTHON_VERSIONS selection produced no versions" >&2 + exit 2 + fi + + for version in "${normalized[@]}"; do + if [[ ! " ${SUPPORTED_PYTHON_VERSIONS[*]} " =~ " ${version} " ]]; then + echo "Unsupported Python version '${version}'. Supported: ${SUPPORTED_PYTHON_VERSIONS[*]}" >&2 + exit 2 + fi + done + + echo "${normalized[@]}" +} + +ensure_ldap_image() { + local image_ref="bitnami/openldap:latest" + if docker image inspect "$image_ref" >/dev/null 2>&1; then + return + fi + + echo "LDAP image $image_ref not found locally; trying docker pull..." + if docker pull "$image_ref"; then + return + fi + + echo "docker pull failed; building bitnami/openldap:latest from source (CI fallback)." + local workdir="$ROOT_DIR/.docker-test-artifacts/bitnami-containers" + rm -rf "$workdir" + git clone --depth 1 https://github.com/bitnami/containers.git "$workdir" + (cd "$workdir/bitnami/openldap/2.6/debian-12" && docker build -t "$image_ref" .) +} + +start_services() { + ensure_ldap_image + + docker network rm "$DOCKER_NETWORK_NAME" >/dev/null 2>&1 || true + docker network create "$DOCKER_NETWORK_NAME" >/dev/null + + docker rm -f "$LDAP_CONTAINER_NAME" >/dev/null 2>&1 || true + docker run -d --rm \ + --name "$LDAP_CONTAINER_NAME" \ + --network "$DOCKER_NETWORK_NAME" \ + -e LDAP_ADMIN_USERNAME=admin \ + -e LDAP_ADMIN_PASSWORD=adminpassword \ + -e LDAP_USERS=user01,user02 \ + -e LDAP_PASSWORDS=password1,password2 \ + bitnami/openldap:latest >/dev/null + + sleep 2 +} + +stop_services() { + docker rm -f "$LDAP_CONTAINER_NAME" >/dev/null 2>&1 || true + docker network rm "$DOCKER_NETWORK_NAME" >/dev/null 2>&1 || true +} + +cleanup() { + stop_services +} + +collect_junit_totals() { + local artifacts_dir="$1" + + python - "$artifacts_dir" <<'PY' +import glob +import os +import sys +import xml.etree.ElementTree as ET + +artifacts_dir = sys.argv[1] +tests = failures = errors = files = 0 + +for path in sorted(glob.glob(os.path.join(artifacts_dir, "junit.*.xml"))): + files += 1 + try: + root = ET.parse(path).getroot() + except Exception: + continue + + if root.tag == "testsuite": + suites = [root] + elif root.tag == "testsuites": + suites = root.findall("testsuite") + else: + suites = [] + + for suite in suites: + tests += int(suite.attrib.get("tests", 0) or 0) + failures += int(suite.attrib.get("failures", 0) or 0) + errors += int(suite.attrib.get("errors", 0) or 0) + +print(f"{tests} {failures} {errors} {files}") +PY +} + +append_summary_row() { + local py_version="$1" + local chunks_total="$2" + local junit_files="$3" + local tests="$4" + local failures="$5" + local errors="$6" + local status="$7" + + printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ + "$py_version" "$chunks_total" "$junit_files" "$tests" "$failures" "$errors" "$status" >> "$SUMMARY_TSV" +} + +write_summary_files() { + local end_epoch end_human elapsed_sec + + if [[ -z "$SUMMARY_TSV" || -z "$SUMMARY_TXT" || -z "$SUMMARY_JSON" ]]; then + return + fi + + if [[ ! -f "$SUMMARY_TSV" ]]; then + return + fi + + end_epoch="$(date +%s)" + end_human="$(date -u +"%Y-%m-%dT%H:%M:%SZ")" + + if [[ -n "$TESTS_START_EPOCH" ]]; then + elapsed_sec=$(( end_epoch - TESTS_START_EPOCH )) + else + elapsed_sec=0 + fi + + { + echo "Test Run Summary" + echo "Start (UTC): ${TESTS_START_HUMAN:-N/A}" + echo "End (UTC): $end_human" + echo "Elapsed: ${elapsed_sec}s" + echo + printf "%-8s %-8s %-7s %-8s %-10s %-8s %-6s\n" \ + "Python" "Status" "Chunks" "JUnit" "Tests" "Failures" "Errors" + printf "%-8s %-8s %-7s %-8s %-10s %-8s %-6s\n" \ + "------" "------" "------" "-----" "-----" "--------" "------" + + if [[ -s "$SUMMARY_TSV" ]]; then + while IFS=$'\t' read -r py_version chunks_total junit_files tests failures errors status; do + printf "%-8s %-8s %-7s %-8s %-10s %-8s %-6s\n" \ + "$py_version" "$status" "$chunks_total" "$junit_files" "$tests" "$failures" "$errors" + done < "$SUMMARY_TSV" + else + echo "No per-version summary rows were recorded." + fi + + if [[ -s "$SUMMARY_FAIL_LOGS" ]]; then + echo + echo "Failed Chunk Logs" + cat "$SUMMARY_FAIL_LOGS" + fi + } > "$SUMMARY_TXT" + + python - "$SUMMARY_TSV" "$SUMMARY_FAIL_LOGS" "$SUMMARY_JSON" "${TESTS_START_HUMAN:-N/A}" "$end_human" "$elapsed_sec" <<'PY' +import json +import sys + +summary_tsv, fail_logs_path, output_path, start_utc, end_utc, elapsed_sec = sys.argv[1:] + +rows = [] +with open(summary_tsv) as f: + for line in f: + parts = line.rstrip("\n").split("\t") + if len(parts) != 7: + continue + py_version, chunks_total, junit_files, tests, failures, errors, status = parts + rows.append( + { + "python_version": py_version, + "status": status, + "chunks_total": int(chunks_total), + "junit_files": int(junit_files), + "tests": int(tests), + "failures": int(failures), + "errors": int(errors), + } + ) + +failed_logs = [] +with open(fail_logs_path) as f: + failed_logs = [line.strip() for line in f if line.strip()] + +payload = { + "start_utc": start_utc, + "end_utc": end_utc, + "elapsed_seconds": int(elapsed_sec), + "python_versions": rows, + "failed_chunk_logs": failed_logs, +} + +with open(output_path, "w") as f: + json.dump(payload, f, indent=2) + f.write("\n") +PY + + echo "==> Test run end time (UTC): $end_human" + echo "==> Test run elapsed: ${elapsed_sec}s" + echo "==> Summary written: $SUMMARY_TXT" + echo "==> Summary JSON: $SUMMARY_JSON" +} + +on_exit() { + local exit_code=$? + write_summary_files || true + cleanup + trap - EXIT + exit "$exit_code" +} + +trap on_exit EXIT + +read -r -a SELECTED_PYTHON_VERSIONS <<< "$(normalize_python_versions "$PYTHON_VERSIONS")" + +echo "==> Preparing artifacts directory: $ARTIFACTS_DIR" +rm -rf "$ARTIFACTS_DIR" +mkdir -p "$ARTIFACTS_DIR" + +SUMMARY_TSV="$ARTIFACTS_DIR/.summary_rows.tsv" +SUMMARY_FAIL_LOGS="$ARTIFACTS_DIR/.summary_fail_logs.txt" +SUMMARY_TXT="$ARTIFACTS_DIR/summary.txt" +SUMMARY_JSON="$ARTIFACTS_DIR/summary.json" + +: > "$SUMMARY_TSV" +: > "$SUMMARY_FAIL_LOGS" + +echo "==> Starting shared services (LDAP)" +start_services + +TESTS_START_EPOCH="$(date +%s)" +TESTS_START_HUMAN="$(date -u +"%Y-%m-%dT%H:%M:%SZ")" +echo "==> Test run start time (UTC): $TESTS_START_HUMAN" +echo "==> Python versions selected: ${SELECTED_PYTHON_VERSIONS[*]}" + +run_chunk() { + local group="$1" + local log_file="$CURRENT_ARTIFACTS_DIR/shard.${group}.log" + + if docker run --rm \ + --network "$DOCKER_NETWORK_NAME" \ + -e SHARD_GROUP="$group" \ + -e SHARD_COUNT="$CHUNK_COUNT" \ + -e ARTIFACTS_DIR="/artifacts" \ + -e PYTEST_EXTRA_ARGS="$PYTEST_EXTRA_ARGS" \ + -e QSERVER_TEST_LDAP_HOST="$LDAP_CONTAINER_NAME" \ + -e QSERVER_TEST_LDAP_PORT="1389" \ + -e QSERVER_TEST_REDIS_ADDR="localhost" \ + -e QSERVER_HTTP_TEST_BIND_HOST="127.0.0.1" \ + -e QSERVER_HTTP_TEST_HOST="127.0.0.1" \ + -v "$CURRENT_ARTIFACTS_DIR:/artifacts" \ + "$CURRENT_IMAGE_TAG" >"$log_file" 2>&1; then + : > "$CURRENT_ARTIFACTS_DIR/.status.${group}.ok" + else + : > "$CURRENT_ARTIFACTS_DIR/.status.${group}.fail" + exit 1 + fi +} + +export -f run_chunk +export CHUNK_COUNT PYTEST_EXTRA_ARGS DOCKER_NETWORK_NAME LDAP_CONTAINER_NAME + +for PYTHON_VERSION in "${SELECTED_PYTHON_VERSIONS[@]}"; do + CURRENT_IMAGE_TAG="${IMAGE_TAG_BASE}-py${PYTHON_VERSION}" + CURRENT_ARTIFACTS_DIR="$ARTIFACTS_DIR/py${PYTHON_VERSION}" + export CURRENT_IMAGE_TAG CURRENT_ARTIFACTS_DIR + + echo "==> Building test image: $CURRENT_IMAGE_TAG (Python $PYTHON_VERSION)" + docker build \ + --build-arg PYTHON_VERSION="$PYTHON_VERSION" \ + -f "$ROOT_DIR/docker/test.Dockerfile" \ + -t "$CURRENT_IMAGE_TAG" \ + "$ROOT_DIR" + + mkdir -p "$CURRENT_ARTIFACTS_DIR" + + echo "==> [Python $PYTHON_VERSION] Starting dynamic dispatch: $WORKER_COUNT workers over $CHUNK_COUNT chunks" + if ! seq 1 "$CHUNK_COUNT" | xargs -P "$WORKER_COUNT" -I {} bash -lc 'run_chunk "$1"' _ {}; then + echo "One or more chunks failed for Python $PYTHON_VERSION." >&2 + read -r TOTAL_TESTS TOTAL_FAILURES TOTAL_ERRORS TOTAL_JUNIT_FILES < <(collect_junit_totals "$CURRENT_ARTIFACTS_DIR") + for group in $(seq 1 "$CHUNK_COUNT"); do + if [[ -f "$CURRENT_ARTIFACTS_DIR/.status.${group}.fail" ]]; then + echo "Chunk $group failed. Log: $CURRENT_ARTIFACTS_DIR/shard.${group}.log" >&2 + echo "$CURRENT_ARTIFACTS_DIR/shard.${group}.log" >> "$SUMMARY_FAIL_LOGS" + fi + done + append_summary_row "py${PYTHON_VERSION}" "$CHUNK_COUNT" "$TOTAL_JUNIT_FILES" \ + "$TOTAL_TESTS" "$TOTAL_FAILURES" "$TOTAL_ERRORS" "FAIL" + exit 1 + fi + + for group in $(seq 1 "$CHUNK_COUNT"); do + if [[ -f "$CURRENT_ARTIFACTS_DIR/.status.${group}.ok" ]]; then + echo "[Python $PYTHON_VERSION] Chunk $group completed successfully" + fi + done + + rm -f "$CURRENT_ARTIFACTS_DIR"/.status.*.ok "$CURRENT_ARTIFACTS_DIR"/.status.*.fail + + echo "==> [Python $PYTHON_VERSION] Merging coverage artifacts" + docker run --rm \ + --entrypoint bash \ + -v "$CURRENT_ARTIFACTS_DIR:/artifacts" \ + "$CURRENT_IMAGE_TAG" \ + -lc "set -euo pipefail; \ + python -m coverage combine /artifacts/.coverage.* && \ + python -m coverage xml -o /artifacts/coverage.xml && \ + python -m coverage report -m > /artifacts/coverage.txt" + + if [[ "${#SELECTED_PYTHON_VERSIONS[@]}" -eq 1 ]]; then + cp "$CURRENT_ARTIFACTS_DIR/coverage.xml" "$ROOT_DIR/coverage.xml" + else + cp "$CURRENT_ARTIFACTS_DIR/coverage.xml" "$ROOT_DIR/coverage.py${PYTHON_VERSION}.xml" + fi + + read -r TOTAL_TESTS TOTAL_FAILURES TOTAL_ERRORS TOTAL_JUNIT_FILES < <(collect_junit_totals "$CURRENT_ARTIFACTS_DIR") + echo "==> [Python $PYTHON_VERSION] JUnit summary: tests=$TOTAL_TESTS failures=$TOTAL_FAILURES errors=$TOTAL_ERRORS files=$TOTAL_JUNIT_FILES" + + VERSION_STATUS="PASS" + if [[ "$TOTAL_FAILURES" -gt 0 || "$TOTAL_ERRORS" -gt 0 ]]; then + VERSION_STATUS="FAIL" + fi + + append_summary_row "py${PYTHON_VERSION}" "$CHUNK_COUNT" "$TOTAL_JUNIT_FILES" \ + "$TOTAL_TESTS" "$TOTAL_FAILURES" "$TOTAL_ERRORS" "$VERSION_STATUS" +done + +echo "==> Completed. Artifacts:" +echo " versioned logs : $ARTIFACTS_DIR/py/shard..log" +echo " versioned junit : $ARTIFACTS_DIR/py/junit..xml" +echo " versioned coverage : $ARTIFACTS_DIR/py/{coverage.txt,coverage.xml}" +echo " run summary : $ARTIFACTS_DIR/{summary.txt,summary.json}" + +if [[ "${#SELECTED_PYTHON_VERSIONS[@]}" -eq 1 ]]; then + echo " root coverage xml : $ROOT_DIR/coverage.xml" +else + echo " root coverage xmls : $ROOT_DIR/coverage.py.xml" +fi From 8298a7b29acd1d461eb7dc50391609a41453e738 Mon Sep 17 00:00:00 2001 From: davidpcls Date: Fri, 20 Mar 2026 11:24:56 -0600 Subject: [PATCH 13/13] Fixing unit tests (#3) This is a set of test changes intended to improve the reliability of unit testing, as the current unit tests are randomly failing due to test design. Primarily this appears to be centered around LDAP. So this work was to: * Fix for ldap errors * Hardening unit tests so they fail less frequency * Try to handle console output more reliably --- .github/workflows/testing.yml | 23 +- bluesky_httpserver/tests/conftest.py | 23 +- .../tests/test_authenticators.py | 29 ++- .../tests/test_console_output.py | 81 +++++--- .../tests/test_core_api_main.py | 44 +++- bluesky_httpserver/tests/test_server.py | 26 ++- .../docker-configs/ldap-docker-compose.yml | 11 +- .../dockerfiles}/test.Dockerfile | 0 continuous_integration/scripts/start_LDAP.sh | 196 +++++++++++++++++- docs/source/usage.rst | 4 +- scripts/run_ci_docker_parallel.sh | 56 ++--- 11 files changed, 371 insertions(+), 122 deletions(-) rename {docker => continuous_integration/dockerfiles}/test.Dockerfile (100%) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 5355c05..adef4fc 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -26,7 +26,7 @@ jobs: - name: Fetch tags run: git fetch --tags --prune --unshallow - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - uses: shogo82148/actions-setup-redis@v1 @@ -36,14 +36,8 @@ jobs: run: | # sudo apt install redis - pushd .. - git clone https://github.com/bitnami/containers.git - cd containers/bitnami/openldap/2.6/debian-12 - docker build -t bitnami/openldap:latest . - popd - # Start LDAP - source continuous_integration/scripts/start_LDAP.sh + bash continuous_integration/scripts/start_LDAP.sh # These packages are installed in the base environment but may be older # versions. Explicitly upgrade them because they often create @@ -70,6 +64,19 @@ jobs: pip list - name: Test with pytest + env: + PYTEST_ADDOPTS: "--durations=20" run: | coverage run -m pytest -vv coverage report -m + - name: Dump LDAP diagnostics on failure + if: failure() + run: | + docker ps + docker compose -f continuous_integration/docker-configs/ldap-docker-compose.yml ps + LDAP_CONTAINER_ID=$(docker compose -f continuous_integration/docker-configs/ldap-docker-compose.yml ps -q openldap | tr -d '[:space:]') + if [ -n "$LDAP_CONTAINER_ID" ]; then + docker logs --tail 200 "$LDAP_CONTAINER_ID" + else + docker compose -f continuous_integration/docker-configs/ldap-docker-compose.yml logs --tail 200 openldap + fi diff --git a/bluesky_httpserver/tests/conftest.py b/bluesky_httpserver/tests/conftest.py index d5cafdb..8a81df9 100644 --- a/bluesky_httpserver/tests/conftest.py +++ b/bluesky_httpserver/tests/conftest.py @@ -4,6 +4,7 @@ import pytest import requests from bluesky_queueserver.manager.comms import zmq_single_request +from bluesky_queueserver.manager.tests.common import re_manager_cmd # noqa: F401 from bluesky_queueserver.manager.tests.common import set_qserver_zmq_encoding # noqa: F401 from xprocess import ProcessStarter @@ -60,7 +61,11 @@ def fastapi_server_fs(xprocess): to perform additional steps (such as setting environmental variables) before the server is started. """ - def start(http_server_host=SERVER_ADDRESS, http_server_port=SERVER_PORT, api_key=API_KEY_FOR_TESTS): + def start( + http_server_host=SERVER_ADDRESS, + http_server_port=SERVER_PORT, + api_key=API_KEY_FOR_TESTS, + ): class Starter(ProcessStarter): max_read_lines = 53 @@ -112,7 +117,12 @@ def add_plans_to_queue(): user_group = _user_group user = "HTTP unit test setup" - plan1 = {"name": "count", "args": [["det1", "det2"]], "kwargs": {"num": 10, "delay": 1}, "item_type": "plan"} + plan1 = { + "name": "count", + "args": [["det1", "det2"]], + "kwargs": {"num": 10, "delay": 1}, + "item_type": "plan", + } plan2 = {"name": "count", "args": [["det1", "det2"]], "item_type": "plan"} for plan in (plan1, plan2, plan2): resp2, _ = zmq_single_request("queue_item_add", {"item": plan, "user": user, "user_group": user_group}) @@ -120,7 +130,14 @@ def add_plans_to_queue(): def request_to_json( - request_type, path, *, request_prefix="/api", api_key=API_KEY_FOR_TESTS, token=None, login=None, **kwargs + request_type, + path, + *, + request_prefix="/api", + api_key=API_KEY_FOR_TESTS, + token=None, + login=None, + **kwargs, ): if login: auth = None diff --git a/bluesky_httpserver/tests/test_authenticators.py b/bluesky_httpserver/tests/test_authenticators.py index 53c6bbe..7b7dd4b 100644 --- a/bluesky_httpserver/tests/test_authenticators.py +++ b/bluesky_httpserver/tests/test_authenticators.py @@ -1,4 +1,5 @@ import asyncio +import os import time from typing import Any, Tuple @@ -10,20 +11,28 @@ from respx import MockRouter from starlette.datastructures import URL, QueryParams -# fmt: off from ..authenticators import LDAPAuthenticator, OIDCAuthenticator, ProxiedOIDCAuthenticator, UserSessionState +LDAP_TEST_HOST = os.environ.get("QSERVER_TEST_LDAP_HOST", "localhost") +LDAP_TEST_PORT = int(os.environ.get("QSERVER_TEST_LDAP_PORT", "1389")) +LDAP_TEST_ALT_HOST = os.environ.get("QSERVER_TEST_LDAP_ALT_HOST") +if not LDAP_TEST_ALT_HOST: + LDAP_TEST_ALT_HOST = "127.0.0.1" if LDAP_TEST_HOST == "localhost" else LDAP_TEST_HOST + + +# fmt: off + @pytest.mark.parametrize("ldap_server_address, ldap_server_port", [ - ("localhost", 1389), - ("localhost:1389", 904), # Random port, ignored - ("localhost:1389", None), - ("127.0.0.1", 1389), - ("127.0.0.1:1389", 904), - (["localhost"], 1389), - (["localhost", "127.0.0.1"], 1389), - (["localhost", "127.0.0.1:1389"], 1389), - (["localhost:1389", "127.0.0.1:1389"], None), + (LDAP_TEST_HOST, LDAP_TEST_PORT), + (f"{LDAP_TEST_HOST}:{LDAP_TEST_PORT}", 904), # Random port, ignored + (f"{LDAP_TEST_HOST}:{LDAP_TEST_PORT}", None), + (LDAP_TEST_ALT_HOST, LDAP_TEST_PORT), + (f"{LDAP_TEST_ALT_HOST}:{LDAP_TEST_PORT}", 904), + ([LDAP_TEST_HOST], LDAP_TEST_PORT), + ([LDAP_TEST_HOST, LDAP_TEST_ALT_HOST], LDAP_TEST_PORT), + ([LDAP_TEST_HOST, f"{LDAP_TEST_ALT_HOST}:{LDAP_TEST_PORT}"], LDAP_TEST_PORT), + ([f"{LDAP_TEST_HOST}:{LDAP_TEST_PORT}", f"{LDAP_TEST_ALT_HOST}:{LDAP_TEST_PORT}"], None), ]) # fmt: on @pytest.mark.parametrize("use_tls,use_ssl", [(False, False)]) diff --git a/bluesky_httpserver/tests/test_console_output.py b/bluesky_httpserver/tests/test_console_output.py index 1f089ec..6193db0 100644 --- a/bluesky_httpserver/tests/test_console_output.py +++ b/bluesky_httpserver/tests/test_console_output.py @@ -3,17 +3,16 @@ import re import threading import time as ttime +from typing import Any import pytest import requests -from bluesky_queueserver.manager.tests.common import re_manager_cmd # noqa F401 from websockets.sync.client import connect from bluesky_httpserver.tests.conftest import ( # noqa F401 API_KEY_FOR_TESTS, SERVER_ADDRESS, SERVER_PORT, - fastapi_server_fs, request_to_json, set_qserver_zmq_encoding, wait_for_environment_to_be_closed, @@ -36,37 +35,42 @@ def __init__(self, api_key=API_KEY_FOR_TESTS, **kwargs): self._api_key = api_key def run(self): - kwargs = {"stream": True} + kwargs: dict[str, Any] = {"stream": True} if self._api_key: - auth = None headers = {"Authorization": f"ApiKey {self._api_key}"} - kwargs.update({"auth": auth, "headers": headers}) + kwargs.update({"headers": headers}) + + kwargs["timeout"] = (5, 1) - with requests.get(f"http://{SERVER_ADDRESS}:{SERVER_PORT}/api/stream_console_output", **kwargs) as r: - r.encoding = "utf-8" + while not self._exit: + try: + with requests.get( + f"http://{SERVER_ADDRESS}:{SERVER_PORT}/api/stream_console_output", + **kwargs, + ) as r: + r.encoding = "utf-8" - characters = [] - n_brackets = 0 + characters = [] + n_brackets = 0 - for ch in r.iter_content(decode_unicode=True): - # Note, that some output must be received from the server before the loop exits - if self._exit: - break + for ch in r.iter_content(decode_unicode=True): + if self._exit: + return - characters.append(ch) - if ch == "{": - n_brackets += 1 - elif ch == "}": - n_brackets -= 1 + characters.append(ch) + if ch == "{": + n_brackets += 1 + elif ch == "}": + n_brackets -= 1 - # If the received buffer ('characters') is not empty and the message contains - # equal number of opening and closing brackets then consider the message complete. - if characters and not n_brackets: - line = "".join(characters) - characters = [] + if characters and not n_brackets: + line = "".join(characters) + characters = [] - print(f"{line}") - self.received_data_buffer.append(json.loads(line)) + print(f"{line}") + self.received_data_buffer.append(json.loads(line)) + except requests.exceptions.ReadTimeout: + continue def stop(self): """ @@ -81,7 +85,10 @@ def __del__(self): @pytest.mark.parametrize("zmq_port", (None, 60619)) def test_http_server_stream_console_output_1( - monkeypatch, re_manager_cmd, fastapi_server_fs, zmq_port # noqa F811 + monkeypatch, + re_manager_cmd, + fastapi_server_fs, + zmq_port, # noqa F811 ): """ Test for ``stream_console_output`` API @@ -122,7 +129,8 @@ def test_http_server_stream_console_output_1( assert resp2["items"][0] == resp1["item"] assert resp2["running_item"] == {} - rsc.join() + rsc.join(timeout=10) + assert not rsc.is_alive(), "Timed out waiting for stream_console_output thread to terminate" assert len(rsc.received_data_buffer) >= 2, pprint.pformat(rsc.received_data_buffer) @@ -160,7 +168,11 @@ def test_http_server_stream_console_output_1( @pytest.mark.parametrize("zmq_encoding", (None, "json", "msgpack")) @pytest.mark.parametrize("zmq_port", (None, 60619)) def test_http_server_console_output_1( - monkeypatch, re_manager_cmd, fastapi_server_fs, zmq_port, zmq_encoding # noqa F811 + monkeypatch, + re_manager_cmd, + fastapi_server_fs, + zmq_port, + zmq_encoding, # noqa F811 ): """ Test for ``console_output`` API (not a streaming version). @@ -238,7 +250,10 @@ def test_http_server_console_output_1( @pytest.mark.parametrize("zmq_port", (None, 60619)) def test_http_server_console_output_update_1( - monkeypatch, re_manager_cmd, fastapi_server_fs, zmq_port # noqa F811 + monkeypatch, + re_manager_cmd, + fastapi_server_fs, + zmq_port, # noqa F811 ): """ Test for ``console_output`` API (not a streaming version). @@ -379,7 +394,10 @@ def __del__(self): @pytest.mark.parametrize("zmq_port", (None, 60619)) def test_http_server_console_output_socket_1( - monkeypatch, re_manager_cmd, fastapi_server_fs, zmq_port # noqa F811 + monkeypatch, + re_manager_cmd, + fastapi_server_fs, + zmq_port, # noqa F811 ): """ Test for ``/console_output/ws`` websocket @@ -421,7 +439,8 @@ def test_http_server_console_output_socket_1( assert resp2["items"][0] == resp1["item"] assert resp2["running_item"] == {} - rsc.join() + rsc.join(timeout=10) + assert not rsc.is_alive(), "Timed out waiting for console_output websocket thread to terminate" assert len(rsc.received_data_buffer) >= 2, pprint.pformat(rsc.received_data_buffer) diff --git a/bluesky_httpserver/tests/test_core_api_main.py b/bluesky_httpserver/tests/test_core_api_main.py index b2b5140..0c471bd 100644 --- a/bluesky_httpserver/tests/test_core_api_main.py +++ b/bluesky_httpserver/tests/test_core_api_main.py @@ -30,8 +30,17 @@ # Plans used in most of the tests: '_plan1' and '_plan2' are quickly executed '_plan3' runs for 5 seconds. _plan1 = {"name": "count", "args": [["det1", "det2"]], "item_type": "plan"} -_plan2 = {"name": "scan", "args": [["det1", "det2"], "motor", -1, 1, 10], "item_type": "plan"} -_plan3 = {"name": "count", "args": [["det1", "det2"]], "kwargs": {"num": 5, "delay": 1}, "item_type": "plan"} +_plan2 = { + "name": "scan", + "args": [["det1", "det2"], "motor", -1, 1, 10], + "item_type": "plan", +} +_plan3 = { + "name": "count", + "args": [["det1", "det2"]], + "kwargs": {"num": 5, "delay": 1}, + "item_type": "plan", +} _instruction_stop = {"name": "queue_stop", "item_type": "instruction"} @@ -515,8 +524,10 @@ def test_http_server_queue_item_update_2_fail(re_manager, fastapi_server, replac resp2 = request_to_json("post", "/queue/item/update", json=params) assert resp2["success"] is False - assert resp2["msg"] == "Failed to add an item: Failed to replace item: " \ - "Item with UID 'incorrect_uid' is not in the queue" + assert ( + resp2["msg"] == "Failed to add an item: Failed to replace item: " + "Item with UID 'incorrect_uid' is not in the queue" + ) resp3 = request_to_json("get", "/queue/get") assert resp3["items"] != [] @@ -1286,16 +1297,33 @@ def test_http_server_history_clear(re_manager, fastapi_server, clear_params, exp def test_http_server_manager_kill(re_manager, fastapi_server): # noqa F811 + timeout_variants = ( + "Request timeout: ZMQ communication error: timeout occurred", + "Request timeout: ZMQ communication error: Resource temporarily unavailable", + ) + request_to_json("post", "/environment/open") assert wait_for_environment_to_be_created(10), "Timeout" resp = request_to_json("post", "/test/manager/kill") assert "success" not in resp - assert "Request timeout: ZMQ communication error: timeout occurred" in resp["detail"] - - ttime.sleep(10) + assert any(_ in resp["detail"] for _ in timeout_variants) + + deadline = ttime.time() + 20 + last_status = None + while ttime.time() < deadline: + ttime.sleep(0.2) + last_status = request_to_json("get", "/status") + if ( + isinstance(last_status, dict) + and last_status.get("manager_state") == "idle" + and last_status.get("worker_environment_exists") is True + ): + break + else: + assert False, f"Timeout while waiting for manager recovery after kill. Last status: {last_status!r}" - resp = request_to_json("get", "/status") + resp = last_status assert resp["msg"].startswith("RE Manager") assert resp["manager_state"] == "idle" assert resp["items_in_queue"] == 0 diff --git a/bluesky_httpserver/tests/test_server.py b/bluesky_httpserver/tests/test_server.py index 117f4df..33b82a2 100644 --- a/bluesky_httpserver/tests/test_server.py +++ b/bluesky_httpserver/tests/test_server.py @@ -27,8 +27,17 @@ # Plans used in most of the tests: '_plan1' and '_plan2' are quickly executed '_plan3' runs for 5 seconds. _plan1 = {"name": "count", "args": [["det1", "det2"]], "item_type": "plan"} -_plan2 = {"name": "scan", "args": [["det1", "det2"], "motor", -1, 1, 10], "item_type": "plan"} -_plan3 = {"name": "count", "args": [["det1", "det2"]], "kwargs": {"num": 5, "delay": 1}, "item_type": "plan"} +_plan2 = { + "name": "scan", + "args": [["det1", "det2"], "motor", -1, 1, 10], + "item_type": "plan", +} +_plan3 = { + "name": "count", + "args": [["det1", "det2"]], + "kwargs": {"num": 5, "delay": 1}, + "item_type": "plan", +} _config_public_key = """ @@ -122,7 +131,7 @@ def test_http_server_secure_1(monkeypatch, tmpdir, re_manager_cmd, fastapi_serve @pytest.mark.parametrize("option", ["ev", "cfg_file", "both"]) # fmt: on def test_http_server_set_zmq_address_1( - monkeypatch, tmpdir, re_manager_cmd, fastapi_server_fs, option # noqa: F811 + monkeypatch, tmpdir, re_manager_cmd, fastapi_server_fs, free_tcp_port_factory, option # noqa: F811 ): """ Test if ZMQ address of RE Manager is passed to the HTTP server using 'QSERVER_ZMQ_ADDRESS_CONTROL' @@ -130,11 +139,12 @@ def test_http_server_set_zmq_address_1( channel different from default address, add and execute a plan. """ - # Change ZMQ address to use port 60616 instead of the default port 60615. - zmq_control_address_server = "tcp://*:60616" - zmq_info_address_server = "tcp://*:60617" - zmq_control_address = "tcp://localhost:60616" - zmq_info_address = "tcp://localhost:60617" + zmq_control_port = free_tcp_port_factory() + zmq_info_port = free_tcp_port_factory() + zmq_control_address_server = f"tcp://*:{zmq_control_port}" + zmq_info_address_server = f"tcp://*:{zmq_info_port}" + zmq_control_address = f"tcp://localhost:{zmq_control_port}" + zmq_info_address = f"tcp://localhost:{zmq_info_port}" if option == "ev": monkeypatch.setenv("QSERVER_ZMQ_CONTROL_ADDRESS", zmq_control_address) monkeypatch.setenv("QSERVER_ZMQ_INFO_ADDRESS", zmq_info_address) diff --git a/continuous_integration/docker-configs/ldap-docker-compose.yml b/continuous_integration/docker-configs/ldap-docker-compose.yml index 2b2c45a..5fbfc53 100644 --- a/continuous_integration/docker-configs/ldap-docker-compose.yml +++ b/continuous_integration/docker-configs/ldap-docker-compose.yml @@ -1,14 +1,13 @@ services: openldap: - image: osixia/openldap:latest + image: osixia/openldap:1.5.0 ports: - - '1389:1389' - - '1636:1636' + - '1389:389' + - '1636:636' environment: - - LDAP_ADMIN_USERNAME=admin + - LDAP_ORGANISATION=Example Inc. + - LDAP_DOMAIN=example.org - LDAP_ADMIN_PASSWORD=adminpassword - - LDAP_USERS=user01,user02 - - LDAP_PASSWORDS=password1,password2 volumes: - 'openldap_data:/var/lib/ldap' diff --git a/docker/test.Dockerfile b/continuous_integration/dockerfiles/test.Dockerfile similarity index 100% rename from docker/test.Dockerfile rename to continuous_integration/dockerfiles/test.Dockerfile diff --git a/continuous_integration/scripts/start_LDAP.sh b/continuous_integration/scripts/start_LDAP.sh index ecfa1cf..d2bd48d 100755 --- a/continuous_integration/scripts/start_LDAP.sh +++ b/continuous_integration/scripts/start_LDAP.sh @@ -1,7 +1,195 @@ -#!/bin/bash -set -e +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +COMPOSE_FILE="${LDAP_COMPOSE_FILE:-$ROOT_DIR/continuous_integration/docker-configs/ldap-docker-compose.yml}" +COMPOSE_PROJECT="${LDAP_COMPOSE_PROJECT:-}" +LDAP_HOST="${LDAP_HOST:-127.0.0.1}" +LDAP_PORT="${LDAP_PORT:-1389}" +LDAP_ADMIN_DN="cn=admin,dc=example,dc=org" +LDAP_ADMIN_PASSWORD="adminpassword" +LDAP_BASE_DN="dc=example,dc=org" + +compose_cmd() { + if [[ -n "$COMPOSE_PROJECT" ]]; then + docker compose -p "$COMPOSE_PROJECT" -f "$COMPOSE_FILE" "$@" + else + docker compose -f "$COMPOSE_FILE" "$@" + fi +} + +get_openldap_container_id() { + compose_cmd ps -q openldap | tr -d '[:space:]' +} + +wait_for_ldap() { + local timeout_seconds="${1:-60}" + local deadline=$((SECONDS + timeout_seconds)) + + while (( SECONDS < deadline )); do + if python - </dev/null 2>&1 +import socket + +with socket.create_connection(("${LDAP_HOST}", ${LDAP_PORT}), timeout=1): + pass +PY + then + return 0 + fi + sleep 1 + done + + return 1 +} + +wait_for_ldap_bind() { + local container_id="$1" + local timeout_seconds="${2:-60}" + local deadline=$((SECONDS + timeout_seconds)) + local rc=0 + + while (( SECONDS < deadline )); do + rc=0 + docker exec "$container_id" ldapsearch \ + -x \ + -H "ldap://127.0.0.1:389" \ + -D "$LDAP_ADMIN_DN" \ + -w "$LDAP_ADMIN_PASSWORD" \ + -b "$LDAP_BASE_DN" \ + -s base \ + "(objectclass=*)" dn >/dev/null 2>&1 || rc=$? + if [[ "$rc" -eq 0 ]]; then + return 0 + fi + sleep 1 + done + + return 1 +} + +wait_for_ldap_test_user_bind() { + local container_id="$1" + local timeout_seconds="${2:-60}" + local deadline=$((SECONDS + timeout_seconds)) + local rc=0 + + while (( SECONDS < deadline )); do + rc=0 + docker exec "$container_id" ldapwhoami \ + -x \ + -H "ldap://127.0.0.1:389" \ + -D "cn=user01,ou=users,$LDAP_BASE_DN" \ + -w "password1" >/dev/null 2>&1 || rc=$? + if [[ "$rc" -eq 0 ]]; then + return 0 + fi + sleep 1 + done + + return 1 +} + +print_ldap_diagnostics() { + local container_id="${1:-}" + + echo "LDAP startup diagnostics:" >&2 + compose_cmd ps >&2 || true + + if [[ -z "$container_id" ]]; then + container_id="$(get_openldap_container_id)" + fi + + if [[ -n "$container_id" ]]; then + docker logs --tail 200 "$container_id" >&2 || true + else + compose_cmd logs --tail 200 openldap >&2 || true + fi +} + +ldap_entry_exists() { + local container_id="$1" + local dn="$2" + + docker exec "$container_id" ldapsearch \ + -x \ + -H "ldap://127.0.0.1:389" \ + -D "$LDAP_ADMIN_DN" \ + -w "$LDAP_ADMIN_PASSWORD" \ + -b "$dn" \ + -s base \ + "(objectclass=*)" dn >/dev/null 2>&1 +} + +ldap_add_if_missing() { + local container_id="$1" + local dn="$2" + local ldif="$3" + + if ldap_entry_exists "$container_id" "$dn"; then + return 0 + fi + + docker exec -i "$container_id" ldapadd \ + -x \ + -H "ldap://127.0.0.1:389" \ + -D "$LDAP_ADMIN_DN" \ + -w "$LDAP_ADMIN_PASSWORD" >/dev/null <&2 + print_ldap_diagnostics + exit 1 +fi + +if ! wait_for_ldap 120; then + echo "LDAP port ${LDAP_HOST}:${LDAP_PORT} did not become reachable in time." >&2 + print_ldap_diagnostics "$CONTAINER_ID" + exit 1 +fi + +echo "LDAP port ${LDAP_HOST}:${LDAP_PORT} is reachable. Waiting for slapd initialization..." +sleep 3 + +if ! wait_for_ldap_bind "$CONTAINER_ID" 120; then + echo "LDAP admin bind did not become ready in time." >&2 + print_ldap_diagnostics "$CONTAINER_ID" + exit 1 +fi + +seed_ldap_test_users "$CONTAINER_ID" + +if ! wait_for_ldap_test_user_bind "$CONTAINER_ID" 60; then + echo "LDAP test-user bind did not become ready in time." >&2 + print_ldap_diagnostics "$CONTAINER_ID" + exit 1 +fi + docker ps diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 299bdcb..bcae133 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -169,7 +169,7 @@ If you are already in a browser context, open: This redirects to the OIDC provider login page and then back to the server callback. -This can similarly be acheived using ``httpie`` by opening the URL in a browser after getting +This can similarly be acheived using ``httpie`` by opening the URL in a browser after getting the authorization URI from the server:: http POST http://localhost:60610/api/auth/provider/entra/authorize @@ -183,7 +183,7 @@ spawn a browser for the user to log in to the provider. CLI/device flow *************** -For terminal clients (i.e. no browser possible), start with +For terminal clients (i.e. no browser possible), start with ``POST /api/auth/provider//authorize``. The response includes: diff --git a/scripts/run_ci_docker_parallel.sh b/scripts/run_ci_docker_parallel.sh index c9caee7..efb6594 100755 --- a/scripts/run_ci_docker_parallel.sh +++ b/scripts/run_ci_docker_parallel.sh @@ -8,8 +8,10 @@ CHUNK_COUNT="" PYTHON_VERSIONS="latest" PYTEST_EXTRA_ARGS="" ARTIFACTS_DIR="$ROOT_DIR/.docker-test-artifacts" -DOCKER_NETWORK_NAME="bhs-ci-net" -LDAP_CONTAINER_NAME="bhs-ci-ldap" +LDAP_COMPOSE_FILE="$ROOT_DIR/continuous_integration/docker-configs/ldap-docker-compose.yml" +LDAP_COMPOSE_PROJECT="bhs-ci-ldap-parallel-$$" +LDAP_SERVICE_NAME="openldap" +DOCKER_NETWORK_NAME="${LDAP_COMPOSE_PROJECT}_default" SUMMARY_TSV="" SUMMARY_FAIL_LOGS="" @@ -154,46 +156,16 @@ normalize_python_versions() { echo "${normalized[@]}" } -ensure_ldap_image() { - local image_ref="bitnami/openldap:latest" - if docker image inspect "$image_ref" >/dev/null 2>&1; then - return - fi - - echo "LDAP image $image_ref not found locally; trying docker pull..." - if docker pull "$image_ref"; then - return - fi - - echo "docker pull failed; building bitnami/openldap:latest from source (CI fallback)." - local workdir="$ROOT_DIR/.docker-test-artifacts/bitnami-containers" - rm -rf "$workdir" - git clone --depth 1 https://github.com/bitnami/containers.git "$workdir" - (cd "$workdir/bitnami/openldap/2.6/debian-12" && docker build -t "$image_ref" .) -} - start_services() { - ensure_ldap_image - - docker network rm "$DOCKER_NETWORK_NAME" >/dev/null 2>&1 || true - docker network create "$DOCKER_NETWORK_NAME" >/dev/null - - docker rm -f "$LDAP_CONTAINER_NAME" >/dev/null 2>&1 || true - docker run -d --rm \ - --name "$LDAP_CONTAINER_NAME" \ - --network "$DOCKER_NETWORK_NAME" \ - -e LDAP_ADMIN_USERNAME=admin \ - -e LDAP_ADMIN_PASSWORD=adminpassword \ - -e LDAP_USERS=user01,user02 \ - -e LDAP_PASSWORDS=password1,password2 \ - bitnami/openldap:latest >/dev/null - - sleep 2 + LDAP_COMPOSE_FILE="$LDAP_COMPOSE_FILE" \ + LDAP_COMPOSE_PROJECT="$LDAP_COMPOSE_PROJECT" \ + LDAP_HOST="127.0.0.1" \ + LDAP_PORT="1389" \ + bash "$ROOT_DIR/continuous_integration/scripts/start_LDAP.sh" >/dev/null } stop_services() { - docker rm -f "$LDAP_CONTAINER_NAME" >/dev/null 2>&1 || true - docker network rm "$DOCKER_NETWORK_NAME" >/dev/null 2>&1 || true + docker compose -p "$LDAP_COMPOSE_PROJECT" -f "$LDAP_COMPOSE_FILE" down -v >/dev/null 2>&1 || true } cleanup() { @@ -385,8 +357,8 @@ run_chunk() { -e SHARD_COUNT="$CHUNK_COUNT" \ -e ARTIFACTS_DIR="/artifacts" \ -e PYTEST_EXTRA_ARGS="$PYTEST_EXTRA_ARGS" \ - -e QSERVER_TEST_LDAP_HOST="$LDAP_CONTAINER_NAME" \ - -e QSERVER_TEST_LDAP_PORT="1389" \ + -e QSERVER_TEST_LDAP_HOST="$LDAP_SERVICE_NAME" \ + -e QSERVER_TEST_LDAP_PORT="389" \ -e QSERVER_TEST_REDIS_ADDR="localhost" \ -e QSERVER_HTTP_TEST_BIND_HOST="127.0.0.1" \ -e QSERVER_HTTP_TEST_HOST="127.0.0.1" \ @@ -400,7 +372,7 @@ run_chunk() { } export -f run_chunk -export CHUNK_COUNT PYTEST_EXTRA_ARGS DOCKER_NETWORK_NAME LDAP_CONTAINER_NAME +export CHUNK_COUNT PYTEST_EXTRA_ARGS DOCKER_NETWORK_NAME LDAP_SERVICE_NAME for PYTHON_VERSION in "${SELECTED_PYTHON_VERSIONS[@]}"; do CURRENT_IMAGE_TAG="${IMAGE_TAG_BASE}-py${PYTHON_VERSION}" @@ -410,7 +382,7 @@ for PYTHON_VERSION in "${SELECTED_PYTHON_VERSIONS[@]}"; do echo "==> Building test image: $CURRENT_IMAGE_TAG (Python $PYTHON_VERSION)" docker build \ --build-arg PYTHON_VERSION="$PYTHON_VERSION" \ - -f "$ROOT_DIR/docker/test.Dockerfile" \ + -f "$ROOT_DIR/continuous_integration/dockerfiles/test.Dockerfile" \ -t "$CURRENT_IMAGE_TAG" \ "$ROOT_DIR"