From 0395cec56b088cc95c1dcea2625379d9bc506315 Mon Sep 17 00:00:00 2001 From: arbaz Date: Tue, 12 May 2026 01:10:47 +0530 Subject: [PATCH] Add durable control-plane state storage --- src/api/dependencies.py | 45 +--- src/api/routes/admin.py | 39 ++- src/api/routes/api_keys.py | 46 ++-- src/api/routes/auth.py | 130 ++++------ src/database/api_key_store.py | 19 ++ src/database/control_plane_store.py | 300 ++++++++++++++++++++++ tests/api/test_dependencies_and_routes.py | 13 +- tests/unit/test_control_plane_store.py | 54 ++++ tests/unit/test_database_stores.py | 4 + 9 files changed, 505 insertions(+), 145 deletions(-) create mode 100644 src/database/control_plane_store.py create mode 100644 tests/unit/test_control_plane_store.py diff --git a/src/api/dependencies.py b/src/api/dependencies.py index 65c3f92..40186a3 100644 --- a/src/api/dependencies.py +++ b/src/api/dependencies.py @@ -10,8 +10,6 @@ import hashlib import hmac import logging -import time -from collections import defaultdict from typing import Optional from fastapi import Depends, HTTPException, Request, status @@ -19,6 +17,7 @@ from jose import JWTError, jwt from src.config import settings +from src.database.control_plane_store import control_plane_store from src.database.api_key_store import APIKeyStore from src.database.user_store import UserStore from src.pipelines.ingest import IngestPipeline @@ -175,7 +174,13 @@ async def require_api_key( return user # 2. Check MongoDB for user-generated API keys - key_doc = _api_key_store.validate_api_key(token) + try: + key_doc = _api_key_store.validate_api_key(token) + except RuntimeError as exc: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=str(exc), + ) if key_doc: user_id = key_doc.get("user_id") if user_id: @@ -276,40 +281,14 @@ async def require_user(current_user: Optional[dict] = Depends(get_current_user)) return current_user -# ═══════════════════════════════════════════════════════════════════════════ -# Sliding-window rate limiter (in-process, per-key) -# ═══════════════════════════════════════════════════════════════════════════ - -class _SlidingWindowRateLimiter: - """Thread-safe sliding-window counter keyed by API identity.""" - - def __init__(self, max_requests: int, window_seconds: int = 60): - self.max_requests = max_requests - self.window = window_seconds - self._hits: dict[str, list[float]] = defaultdict(list) - self._lock = asyncio.Lock() +class _ControlPlaneRateLimiter: + """Rate limiter backed by shared control-plane storage.""" async def check(self, key: str) -> tuple[bool, int]: - """Return (allowed, remaining) for *key*.""" - now = time.monotonic() - cutoff = now - self.window - - async with self._lock: - timestamps = self._hits[key] - self._hits[key] = [t for t in timestamps if t > cutoff] - - if len(self._hits[key]) >= self.max_requests: - return False, 0 - - self._hits[key].append(now) - remaining = self.max_requests - len(self._hits[key]) - return True, remaining + return control_plane_store.check_rate_limit(key, settings.rate_limit, 60) -_rate_limiter = _SlidingWindowRateLimiter( - max_requests=settings.rate_limit, - window_seconds=60, -) +_rate_limiter = _ControlPlaneRateLimiter() async def enforce_rate_limit( diff --git a/src/api/routes/admin.py b/src/api/routes/admin.py index 2207af3..7101587 100644 --- a/src/api/routes/admin.py +++ b/src/api/routes/admin.py @@ -34,6 +34,7 @@ from src.config import settings from src.config.analytics import analytics +from src.database.control_plane_store import control_plane_store logger = logging.getLogger("xmem.api.admin") @@ -45,7 +46,6 @@ # ═══════════════════════════════════════════════════════════════════════════ _admin_collection = None -_admin_sessions: Dict[str, Dict[str, Any]] = {} # token → {user, expires} def _get_admin_collection(): @@ -88,15 +88,18 @@ def _verify_admin_token(request: Request) -> Dict[str, Any]: if auth.startswith("Bearer "): token = auth[7:] - if not token or token not in _admin_sessions: + if not token: raise HTTPException(status_code=401, detail="Not authenticated") - session = _admin_sessions[token] - if datetime.now(timezone.utc) > session["expires"]: - del _admin_sessions[token] + try: + session_user = control_plane_store.get_admin_session(token) + except RuntimeError: + raise HTTPException(status_code=503, detail="Admin session storage unavailable") + + if not session_user: raise HTTPException(status_code=401, detail="Session expired") - return session["user"] + return session_user # ═══════════════════════════════════════════════════════════════════════════ @@ -115,11 +118,10 @@ async def admin_login(req: AdminLoginRequest): raise HTTPException(status_code=401, detail="Invalid credentials") # Generate session token - token = hashlib.sha256(f"{req.username}{time.time()}".encode()).hexdigest() - _admin_sessions[token] = { - "user": {"username": user["username"], "role": user.get("role", "admin")}, - "expires": datetime.now(timezone.utc) + timedelta(hours=24), - } + token = control_plane_store.create_admin_session( + {"username": user["username"], "role": user.get("role", "admin")}, + ttl_hours=24, + ) response = JSONResponse({"status": "ok", "token": token, "username": user["username"]}) response.set_cookie( @@ -135,8 +137,11 @@ async def admin_login(req: AdminLoginRequest): @router.post("/api/logout") async def admin_logout(request: Request): token = request.cookies.get("xmem_admin_token") - if token and token in _admin_sessions: - del _admin_sessions[token] + if token: + try: + control_plane_store.delete_admin_session(token) + except RuntimeError: + pass response = JSONResponse({"status": "ok"}) response.delete_cookie("xmem_admin_token") return response @@ -220,7 +225,13 @@ async def ws_live_logs(websocket: WebSocket): # Validate auth token from query param token = websocket.query_params.get("token", "") - if token not in _admin_sessions: + try: + session_user = control_plane_store.get_admin_session(token) + except RuntimeError: + await websocket.close(code=1011, reason="Admin session storage unavailable") + return + + if not session_user: await websocket.close(code=4001, reason="Not authenticated") return diff --git a/src/api/routes/api_keys.py b/src/api/routes/api_keys.py index 8d27681..5fccffb 100644 --- a/src/api/routes/api_keys.py +++ b/src/api/routes/api_keys.py @@ -87,10 +87,13 @@ async def list_api_keys( Returns metadata about each key but NOT the actual key values. """ - keys = api_key_store.get_user_api_keys( - user_id=current_user["id"], - include_inactive=include_inactive - ) + try: + keys = api_key_store.get_user_api_keys( + user_id=current_user["id"], + include_inactive=include_inactive + ) + except RuntimeError as exc: + raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(exc)) # Convert to response model key_responses = [ @@ -118,10 +121,13 @@ async def create_api_key( WARNING: The full API key is only returned once in this response. Make sure to save it securely - it cannot be retrieved again. """ - result = api_key_store.create_api_key( - user_id=current_user["id"], - name=request.name, - ) + try: + result = api_key_store.create_api_key( + user_id=current_user["id"], + name=request.name, + ) + except RuntimeError as exc: + raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(exc)) return APIKeyCreateResponse( key=result["key"], @@ -138,11 +144,14 @@ async def update_api_key( current_user: dict = Depends(require_auth), ): """Update an API key's name.""" - success = api_key_store.update_api_key_name( - user_id=current_user["id"], - key_id=key_id, - new_name=request.name, - ) + try: + success = api_key_store.update_api_key_name( + user_id=current_user["id"], + key_id=key_id, + new_name=request.name, + ) + except RuntimeError as exc: + raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(exc)) if not success: raise HTTPException( @@ -180,10 +189,13 @@ async def revoke_api_key( Once revoked, the key cannot be used for authentication. This action is reversible - you can reactivate a key if needed (not implemented yet). """ - success = api_key_store.revoke_api_key( - user_id=current_user["id"], - key_id=key_id, - ) + try: + success = api_key_store.revoke_api_key( + user_id=current_user["id"], + key_id=key_id, + ) + except RuntimeError as exc: + raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(exc)) if not success: raise HTTPException( diff --git a/src/api/routes/auth.py b/src/api/routes/auth.py index a4a70b1..e5d5b37 100644 --- a/src/api/routes/auth.py +++ b/src/api/routes/auth.py @@ -1,7 +1,5 @@ """Authentication routes for Google OAuth and JWT management.""" -import secrets -import string from datetime import datetime, timedelta from typing import Optional, Dict, Any @@ -14,6 +12,7 @@ from src.api.dependencies import get_current_user, require_api_key, require_user from src.config import settings +from src.database.control_plane_store import control_plane_store from src.database.user_store import UserStore from src.database.api_key_store import APIKeyStore @@ -23,87 +22,32 @@ user_store = UserStore() api_key_store = APIKeyStore() -# ═══════════════════════════════════════════════════════════════════════════ -# MCP OAuth Temp Token Store (in-memory with TTL) -# ═══════════════════════════════════════════════════════════════════════════ -_mcp_temp_tokens: Dict[str, Dict[str, Any]] = {} TEMP_TOKEN_PREFIX = "xm-temp-" TEMP_TOKEN_TTL_MINUTES = 10 -TEMP_TOKEN_LENGTH = 32 - - -def _generate_mcp_temp_token() -> str: - """Generate a temporary token for MCP OAuth flow.""" - alphabet = string.ascii_letters + string.digits - random_part = "".join(secrets.choice(alphabet) for _ in range(TEMP_TOKEN_LENGTH)) - return f"{TEMP_TOKEN_PREFIX}{random_part}" def _create_mcp_temp_token(user_id: str) -> str: """Create and store a temporary token for the user.""" - token = _generate_mcp_temp_token() - expires_at = datetime.utcnow() + timedelta(minutes=TEMP_TOKEN_TTL_MINUTES) - - _mcp_temp_tokens[token] = { - "user_id": user_id, - "created_at": datetime.utcnow(), - "expires_at": expires_at, - "exchanged": False, - } - + token, _ = control_plane_store.create_temp_token( + user_id=user_id, + ttl_minutes=TEMP_TOKEN_TTL_MINUTES, + prefix=TEMP_TOKEN_PREFIX, + ) return token def _get_and_invalidate_mcp_token(token: str) -> Optional[str]: """Validate temp token and return user_id if valid, None otherwise.""" - if token not in _mcp_temp_tokens: - return None - - token_data = _mcp_temp_tokens[token] - - # Check expiry - if datetime.utcnow() > token_data["expires_at"]: - del _mcp_temp_tokens[token] - return None + return control_plane_store.consume_temp_token(token) - # Check if already exchanged - if token_data["exchanged"]: - return None - - # Mark as exchanged and return user_id - user_id = token_data["user_id"] - del _mcp_temp_tokens[token] # Single-use token - return user_id - - -# ═══════════════════════════════════════════════════════════════════════════ -# Standard OAuth 2.0 Store (for ChatGPT UI) -# ═══════════════════════════════════════════════════════════════════════════ -_oauth_auth_codes: Dict[str, Dict[str, Any]] = {} def _generate_auth_code(user_id: str) -> str: """Generate a standard OAuth 2.0 authorization code.""" - alphabet = string.ascii_letters + string.digits - code = "".join(secrets.choice(alphabet) for _ in range(32)) - - _oauth_auth_codes[code] = { - "user_id": user_id, - "expires_at": datetime.utcnow() + timedelta(minutes=10) - } - return code + return control_plane_store.create_auth_code(user_id=user_id, ttl_minutes=10) def _get_and_invalidate_auth_code(code: str) -> Optional[str]: """Validate auth code and return user_id if valid.""" - if code not in _oauth_auth_codes: - return None - - data = _oauth_auth_codes[code] - del _oauth_auth_codes[code] # Single-use - - if datetime.utcnow() > data["expires_at"]: - return None - - return data["user_id"] + return control_plane_store.consume_auth_code(code) # ═══════════════════════════════════════════════════════════════════════════ @@ -460,12 +404,18 @@ async def generate_mcp_temp_token(current_user: dict = Depends(require_user)): ) user_id = str(current_user.get("id")) - temp_token = _create_mcp_temp_token(user_id) + try: + temp_token = _create_mcp_temp_token(user_id) + except RuntimeError: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Authentication storage unavailable", + ) return MCPTempTokenResponse( temp_token=temp_token, expires_in=TEMP_TOKEN_TTL_MINUTES * 60, - expires_at=_mcp_temp_tokens[temp_token]["expires_at"] + expires_at=datetime.utcnow() + timedelta(minutes=TEMP_TOKEN_TTL_MINUTES) ) @@ -480,7 +430,13 @@ async def exchange_mcp_token(request: MCPExchangeRequest): The temp token is single-use and invalidated after exchange. """ # Validate and consume the temp token - user_id = _get_and_invalidate_mcp_token(request.temp_token) + try: + user_id = _get_and_invalidate_mcp_token(request.temp_token) + except RuntimeError: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Authentication storage unavailable", + ) if not user_id: raise HTTPException( @@ -497,10 +453,16 @@ async def exchange_mcp_token(request: MCPExchangeRequest): ) # Create a new API key for this user - key_result = api_key_store.create_api_key( - user_id=user_id, - name=f"MCP Client - {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}" - ) + try: + key_result = api_key_store.create_api_key( + user_id=user_id, + name=f"MCP Client - {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}" + ) + except RuntimeError: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="API key storage unavailable", + ) # Prepare user response user_response = { @@ -531,7 +493,13 @@ async def oauth_approve(request: OAuthApproveRequest, current_user: dict = Depen raise HTTPException(status_code=401, detail="Authentication required") user_id = str(current_user.get("id")) - code = _generate_auth_code(user_id) + try: + code = _generate_auth_code(user_id) + except RuntimeError: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Authentication storage unavailable", + ) return OAuthApproveResponse(code=code) @@ -555,15 +523,21 @@ async def oauth_token( if not code: return JSONResponse(status_code=400, content={"error": "invalid_request", "error_description": "code is required"}) - user_id = _get_and_invalidate_auth_code(code) + try: + user_id = _get_and_invalidate_auth_code(code) + except RuntimeError: + return JSONResponse(status_code=503, content={"error": "server_error", "error_description": "Authentication storage unavailable"}) if not user_id: return JSONResponse(status_code=400, content={"error": "invalid_grant", "error_description": "Invalid or expired authorization code"}) # Generate a permanent API key acting as the access token - key_result = api_key_store.create_api_key( - user_id=user_id, - name=f"OAuth Client ({client_id or 'Unknown'}) - {datetime.utcnow().strftime('%Y-%m-%d')}" - ) + try: + key_result = api_key_store.create_api_key( + user_id=user_id, + name=f"OAuth Client ({client_id or 'Unknown'}) - {datetime.utcnow().strftime('%Y-%m-%d')}" + ) + except RuntimeError: + return JSONResponse(status_code=503, content={"error": "server_error", "error_description": "API key storage unavailable"}) return { "access_token": key_result["key"], diff --git a/src/database/api_key_store.py b/src/database/api_key_store.py index 9f017e4..75dd4c6 100644 --- a/src/database/api_key_store.py +++ b/src/database/api_key_store.py @@ -57,6 +57,10 @@ def _try_connect(self) -> None: self._in_memory = True self.api_keys = None + def _require_durable_storage(self) -> None: + if self._in_memory and settings.environment.lower() == "production": + raise RuntimeError("MongoDB is required for API key storage in production") + def _ensure_indexes(self) -> None: """Create necessary indexes for the api_keys collection.""" if not self._connected: @@ -85,6 +89,7 @@ def create_api_key( name: str = "Default", ) -> Dict[str, Any]: """Create a new API key for a user.""" + self._require_durable_storage() key = self._generate_api_key() key_hash = self._hash_key(key) key_prefix = key[:8] @@ -131,11 +136,14 @@ def create_api_key( } except Exception as e: logger.error(f"Database error creating API key: {e}") + if settings.environment.lower() == "production": + raise RuntimeError("Failed to create API key") from e self._in_memory = True return self.create_api_key(user_id, name) def get_user_api_keys(self, user_id: str, include_inactive: bool = False) -> List[Dict[str, Any]]: """Get all API keys for a user.""" + self._require_durable_storage() if self._in_memory: keys = [ {**k, "id": str(k["_id"])} @@ -159,10 +167,13 @@ def get_user_api_keys(self, user_id: str, include_inactive: bool = False) -> Lis return keys except Exception as e: logger.error(f"Database error getting API keys: {e}") + if settings.environment.lower() == "production": + raise RuntimeError("Failed to get API keys") from e return [] def validate_api_key(self, key: str) -> Optional[Dict[str, Any]]: """Validate an API key and return associated user info.""" + self._require_durable_storage() key_hash = self._hash_key(key) if self._in_memory: @@ -191,10 +202,13 @@ def validate_api_key(self, key: str) -> Optional[Dict[str, Any]]: return key_doc except Exception as e: logger.error(f"Database error validating API key: {e}") + if settings.environment.lower() == "production": + raise RuntimeError("Failed to validate API key") from e return None def revoke_api_key(self, user_id: str, key_id: str) -> bool: """Revoke (deactivate) an API key.""" + self._require_durable_storage() if self._in_memory: if key_id in _in_memory_api_keys: if _in_memory_api_keys[key_id].get("user_id") == user_id: @@ -211,6 +225,8 @@ def revoke_api_key(self, user_id: str, key_id: str) -> bool: return result.modified_count > 0 except Exception as e: logger.error(f"Failed to revoke API key {key_id}: {e}") + if settings.environment.lower() == "production": + raise RuntimeError("Failed to revoke API key") from e return False def update_api_key_name( @@ -220,6 +236,7 @@ def update_api_key_name( new_name: str, ) -> bool: """Update the name of an API key.""" + self._require_durable_storage() if self._in_memory: if key_id in _in_memory_api_keys: if _in_memory_api_keys[key_id].get("user_id") == user_id: @@ -236,6 +253,8 @@ def update_api_key_name( return result.modified_count > 0 except Exception as e: logger.error(f"Failed to update API key name {key_id}: {e}") + if settings.environment.lower() == "production": + raise RuntimeError("Failed to update API key") from e return False def close(self) -> None: diff --git a/src/database/control_plane_store.py b/src/database/control_plane_store.py new file mode 100644 index 0000000..f038dfb --- /dev/null +++ b/src/database/control_plane_store.py @@ -0,0 +1,300 @@ +"""Shared durable storage for auth codes, temp tokens, admin sessions, and rate limits.""" + +from __future__ import annotations + +import hashlib +import logging +import secrets +import string +import time +from datetime import datetime, timedelta +from typing import Any, Dict, Optional, Tuple + +from src.config import settings + +logger = logging.getLogger("xmem.database.control_plane_store") + +_in_memory_temp_tokens: Dict[str, Dict[str, Any]] = {} +_in_memory_auth_codes: Dict[str, Dict[str, Any]] = {} +_in_memory_admin_sessions: Dict[str, Dict[str, Any]] = {} +_in_memory_rate_limits: Dict[str, Dict[str, Any]] = {} + + +def _hash_secret(value: str) -> str: + return hashlib.sha256(value.encode()).hexdigest() + + +def _random_token(prefix: str, length: int) -> str: + alphabet = string.ascii_letters + string.digits + random_part = "".join(secrets.choice(alphabet) for _ in range(length)) + return f"{prefix}{random_part}" + + +class ControlPlaneStore: + """MongoDB-backed storage for cross-instance auth/session state. + + The store keeps an in-memory fallback for development and tests, but + production mode refuses to serve from memory if MongoDB is unavailable. + """ + + def __init__(self, uri: str = None, database: str = None) -> None: + self._uri = uri or settings.mongodb_uri + self._database = database or settings.mongodb_database + self._client = None + self._db = None + self._connected = False + self._in_memory = False + + self.temp_tokens = None + self.auth_codes = None + self.admin_sessions = None + self.rate_limits = None + + self._try_connect() + + def _try_connect(self) -> None: + try: + from pymongo import MongoClient + + self._client = MongoClient(self._uri, serverSelectionTimeoutMS=5000) + self._client.admin.command("ping") + self._db = self._client[self._database] + self.temp_tokens = self._db["mcp_temp_tokens"] + self.auth_codes = self._db["oauth_auth_codes"] + self.admin_sessions = self._db["admin_sessions"] + self.rate_limits = self._db["rate_limits"] + self._connected = True + self._in_memory = False + self._ensure_indexes() + logger.info("Connected to MongoDB for control-plane storage") + except Exception as exc: + self._connected = False + self._in_memory = True + logger.warning("MongoDB connection failed, using in-memory control-plane storage: %s", exc) + + def _ensure_indexes(self) -> None: + if not self._connected: + return + try: + from pymongo import ASCENDING + + self.temp_tokens.create_index([("token_hash", ASCENDING)], unique=True) + self.temp_tokens.create_index([("expires_at", ASCENDING)], expireAfterSeconds=0) + + self.auth_codes.create_index([("code_hash", ASCENDING)], unique=True) + self.auth_codes.create_index([("expires_at", ASCENDING)], expireAfterSeconds=0) + + self.admin_sessions.create_index([("session_hash", ASCENDING)], unique=True) + self.admin_sessions.create_index([("expires_at", ASCENDING)], expireAfterSeconds=0) + + self.rate_limits.create_index([("identity", ASCENDING), ("window_key", ASCENDING)], unique=True) + self.rate_limits.create_index([("window_expires_at", ASCENDING)], expireAfterSeconds=0) + except Exception as exc: + logger.warning("Failed to create control-plane indexes: %s", exc) + + def _require_durable_storage(self) -> None: + if self._in_memory and settings.environment.lower() == "production": + raise RuntimeError("MongoDB is required for control-plane state in production") + + def create_temp_token(self, user_id: str, ttl_minutes: int, prefix: str = "xm-temp-") -> Tuple[str, datetime]: + token = _random_token(prefix, 32) + expires_at = datetime.utcnow() + timedelta(minutes=ttl_minutes) + token_hash = _hash_secret(token) + doc = { + "token_hash": token_hash, + "user_id": user_id, + "created_at": datetime.utcnow(), + "expires_at": expires_at, + } + + if self._in_memory: + self._require_durable_storage() + _in_memory_temp_tokens[token_hash] = doc + return token, expires_at + + try: + self.temp_tokens.insert_one(doc) + return token, expires_at + except Exception as exc: + logger.error("Failed to create temp token: %s", exc) + raise RuntimeError("Failed to create temp token") from exc + + def consume_temp_token(self, token: str) -> Optional[str]: + token_hash = _hash_secret(token) + + if self._in_memory: + self._require_durable_storage() + token_doc = _in_memory_temp_tokens.get(token_hash) + if not token_doc: + return None + if datetime.utcnow() > token_doc["expires_at"]: + _in_memory_temp_tokens.pop(token_hash, None) + return None + _in_memory_temp_tokens.pop(token_hash, None) + return token_doc["user_id"] + + try: + now = datetime.utcnow() + token_doc = self.temp_tokens.find_one_and_delete( + {"token_hash": token_hash, "expires_at": {"$gt": now}}, + ) + if not token_doc: + return None + return token_doc["user_id"] + except Exception as exc: + logger.error("Failed to consume temp token: %s", exc) + raise RuntimeError("Failed to consume temp token") from exc + + def create_auth_code(self, user_id: str, ttl_minutes: int = 10) -> str: + code = _random_token("", 32) + code_hash = _hash_secret(code) + expires_at = datetime.utcnow() + timedelta(minutes=ttl_minutes) + doc = { + "code_hash": code_hash, + "user_id": user_id, + "created_at": datetime.utcnow(), + "expires_at": expires_at, + } + + if self._in_memory: + self._require_durable_storage() + _in_memory_auth_codes[code_hash] = doc + return code + + try: + self.auth_codes.insert_one(doc) + return code + except Exception as exc: + logger.error("Failed to create auth code: %s", exc) + raise RuntimeError("Failed to create auth code") from exc + + def consume_auth_code(self, code: str) -> Optional[str]: + code_hash = _hash_secret(code) + + if self._in_memory: + self._require_durable_storage() + code_doc = _in_memory_auth_codes.get(code_hash) + if not code_doc: + return None + if datetime.utcnow() > code_doc["expires_at"]: + _in_memory_auth_codes.pop(code_hash, None) + return None + _in_memory_auth_codes.pop(code_hash, None) + return code_doc["user_id"] + + try: + now = datetime.utcnow() + code_doc = self.auth_codes.find_one_and_delete( + {"code_hash": code_hash, "expires_at": {"$gt": now}}, + ) + if not code_doc: + return None + return code_doc["user_id"] + except Exception as exc: + logger.error("Failed to consume auth code: %s", exc) + raise RuntimeError("Failed to consume auth code") from exc + + def create_admin_session(self, user: Dict[str, Any], ttl_hours: int = 24) -> str: + token = _hash_secret(f"{user.get('username', 'admin')}:{secrets.token_hex(32)}:{time.time()}") + expires_at = datetime.utcnow() + timedelta(hours=ttl_hours) + doc = { + "session_hash": token, + "user": user, + "created_at": datetime.utcnow(), + "expires_at": expires_at, + } + + if self._in_memory: + self._require_durable_storage() + _in_memory_admin_sessions[token] = doc + return token + + try: + self.admin_sessions.insert_one(doc) + return token + except Exception as exc: + logger.error("Failed to create admin session: %s", exc) + raise RuntimeError("Failed to create admin session") from exc + + def get_admin_session(self, token: str) -> Optional[Dict[str, Any]]: + session_hash = token + + if self._in_memory: + self._require_durable_storage() + session_doc = _in_memory_admin_sessions.get(session_hash) + if not session_doc: + return None + if datetime.utcnow() > session_doc["expires_at"]: + _in_memory_admin_sessions.pop(session_hash, None) + return None + return session_doc["user"] + + try: + now = datetime.utcnow() + session_doc = self.admin_sessions.find_one( + {"session_hash": session_hash, "expires_at": {"$gt": now}}, + ) + if not session_doc: + return None + return session_doc["user"] + except Exception as exc: + logger.error("Failed to fetch admin session: %s", exc) + raise RuntimeError("Failed to fetch admin session") from exc + + def delete_admin_session(self, token: str) -> None: + session_hash = token + + if self._in_memory: + self._require_durable_storage() + _in_memory_admin_sessions.pop(session_hash, None) + return + + try: + self.admin_sessions.delete_one({"session_hash": session_hash}) + except Exception as exc: + logger.error("Failed to delete admin session: %s", exc) + raise RuntimeError("Failed to delete admin session") from exc + + def check_rate_limit(self, identity: str, max_requests: int, window_seconds: int = 60) -> tuple[bool, int]: + window_key = int(time.time() // window_seconds) + window_expires_at = datetime.utcnow() + timedelta(seconds=window_seconds) + + if self._in_memory: + self._require_durable_storage() + bucket = _in_memory_rate_limits.get(identity) + if not bucket or bucket.get("window_key") != window_key: + bucket = {"window_key": window_key, "count": 0} + if bucket["count"] >= max_requests: + return False, 0 + bucket["count"] += 1 + _in_memory_rate_limits[identity] = bucket + return True, max_requests - bucket["count"] + + try: + from pymongo import ReturnDocument + + doc = self.rate_limits.find_one_and_update( + {"identity": identity, "window_key": window_key}, + { + "$setOnInsert": { + "identity": identity, + "window_key": window_key, + "count": 0, + "window_started_at": datetime.utcnow(), + "window_expires_at": window_expires_at, + }, + "$inc": {"count": 1}, + }, + upsert=True, + return_document=ReturnDocument.AFTER, + ) + count = int(doc.get("count", 0)) + if count > max_requests: + return False, 0 + return True, max_requests - count + except Exception as exc: + logger.error("Failed to check rate limit: %s", exc) + raise RuntimeError("Failed to check rate limit") from exc + + +control_plane_store = ControlPlaneStore() diff --git a/tests/api/test_dependencies_and_routes.py b/tests/api/test_dependencies_and_routes.py index 6fbdd9b..174331f 100644 --- a/tests/api/test_dependencies_and_routes.py +++ b/tests/api/test_dependencies_and_routes.py @@ -10,6 +10,7 @@ from src.api import dependencies as deps from src.api.middleware import RequestContextMiddleware, SecurityHeadersMiddleware from src.api.routes.health import router as health_router +from src.database.control_plane_store import _in_memory_rate_limits from src.schemas.retrieval import RetrievalResult @@ -35,6 +36,7 @@ def close(self): @pytest.fixture def dependency_app(monkeypatch): + monkeypatch.setattr(deps.settings, "environment", "dev", raising=False) monkeypatch.setattr(deps.settings, "api_keys", ["test-static-key"], raising=False) deps._init_error = None deps._pipelines_ready.set() @@ -87,6 +89,11 @@ def test_dependency_injection_returns_configured_pipeline(dependency_app): @pytest.mark.asyncio async def test_rate_limiter_blocks_after_limit(monkeypatch): - limiter = deps._SlidingWindowRateLimiter(max_requests=1, window_seconds=60) - assert await limiter.check("user-1") == (True, 0) - assert await limiter.check("user-1") == (False, 0) + monkeypatch.setattr(deps.settings, "environment", "dev", raising=False) + _in_memory_rate_limits.clear() + assert await deps._rate_limiter.check("user-1") == (True, deps.settings.rate_limit - 1) + limiter = deps._rate_limiter + for _ in range(deps.settings.rate_limit - 1): + await limiter.check("user-2") + assert await limiter.check("user-2") == (True, 0) + assert await limiter.check("user-2") == (False, 0) diff --git a/tests/unit/test_control_plane_store.py b/tests/unit/test_control_plane_store.py new file mode 100644 index 0000000..2156477 --- /dev/null +++ b/tests/unit/test_control_plane_store.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from datetime import datetime + +from src.config import settings +from src.database.control_plane_store import ( + ControlPlaneStore, + _in_memory_admin_sessions, + _in_memory_auth_codes, + _in_memory_rate_limits, + _in_memory_temp_tokens, +) + + +def _force_control_memory(self): + self._connected = False + self._in_memory = True + self.temp_tokens = None + self.auth_codes = None + self.admin_sessions = None + self.rate_limits = None + + +def test_control_plane_store_handles_tokens_sessions_and_rate_limits_in_memory(monkeypatch): + monkeypatch.setattr(settings, "environment", "dev", raising=False) + monkeypatch.setattr(ControlPlaneStore, "_try_connect", _force_control_memory) + _in_memory_temp_tokens.clear() + _in_memory_auth_codes.clear() + _in_memory_admin_sessions.clear() + _in_memory_rate_limits.clear() + + store = ControlPlaneStore() + + token, expires_at = store.create_temp_token("user-1", ttl_minutes=10) + assert token.startswith("xm-temp-") + assert expires_at > datetime.utcnow() + assert store.consume_temp_token(token) == "user-1" + assert store.consume_temp_token(token) is None + + code = store.create_auth_code("user-2", ttl_minutes=10) + assert store.consume_auth_code(code) == "user-2" + assert store.consume_auth_code(code) is None + + session_token = store.create_admin_session({"username": "admin", "role": "superadmin"}, ttl_hours=24) + assert store.get_admin_session(session_token)["username"] == "admin" + store.delete_admin_session(session_token) + assert store.get_admin_session(session_token) is None + + allowed, remaining = store.check_rate_limit("user-3", max_requests=2, window_seconds=60) + assert allowed and remaining == 1 + allowed, remaining = store.check_rate_limit("user-3", max_requests=2, window_seconds=60) + assert allowed and remaining == 0 + allowed, remaining = store.check_rate_limit("user-3", max_requests=2, window_seconds=60) + assert not allowed and remaining == 0 diff --git a/tests/unit/test_database_stores.py b/tests/unit/test_database_stores.py index 400dd78..cbbf45f 100644 --- a/tests/unit/test_database_stores.py +++ b/tests/unit/test_database_stores.py @@ -1,6 +1,7 @@ from __future__ import annotations from src.database.api_key_store import APIKeyStore, _in_memory_api_keys +from src.config import settings from src.database.project_store import ProjectStore from src.database.models import TeamRole from src.database.user_store import UserStore, _in_memory_users @@ -25,6 +26,7 @@ def _force_project_memory(self): def test_api_key_store_creates_validates_updates_and_revokes_in_memory(monkeypatch): _in_memory_api_keys.clear() + monkeypatch.setattr(settings, "environment", "dev", raising=False) monkeypatch.setattr(APIKeyStore, "_try_connect", _force_api_key_memory) store = APIKeyStore() @@ -45,6 +47,7 @@ def test_api_key_store_creates_validates_updates_and_revokes_in_memory(monkeypat def test_user_store_get_or_create_and_username_helpers_in_memory(monkeypatch): _in_memory_users.clear() + monkeypatch.setattr(settings, "environment", "dev", raising=False) monkeypatch.setattr(UserStore, "_try_connect", _force_user_memory) store = UserStore() @@ -63,6 +66,7 @@ def test_user_store_get_or_create_and_username_helpers_in_memory(monkeypatch): def test_project_store_team_permissions_in_memory(monkeypatch): + monkeypatch.setattr(settings, "environment", "dev", raising=False) monkeypatch.setattr(ProjectStore, "_try_connect", _force_project_memory) store = ProjectStore()