From 60ff58a35f1c0d663ed2d5f14ecc519af633b18f Mon Sep 17 00:00:00 2001 From: webhop123 Date: Mon, 11 May 2026 20:40:18 +0100 Subject: [PATCH 1/2] Add API key production storage guards --- src/api/dependencies.py | 13 +++- src/api/routes/api_keys.py | 28 +++++++ src/config/settings.py | 2 +- src/database/api_key_store.py | 120 +++++++++++++++++++++++++++-- src/database/user_store.py | 30 +++++--- tests/unit/test_database_stores.py | 52 +++++++++++++ 6 files changed, 227 insertions(+), 18 deletions(-) diff --git a/src/api/dependencies.py b/src/api/dependencies.py index 65c3f92..9cf6e40 100644 --- a/src/api/dependencies.py +++ b/src/api/dependencies.py @@ -12,7 +12,7 @@ import logging import time from collections import defaultdict -from typing import Optional +from typing import TYPE_CHECKING, Optional from fastapi import Depends, HTTPException, Request, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -24,6 +24,9 @@ from src.pipelines.ingest import IngestPipeline from src.pipelines.retrieval import RetrievalPipeline +if TYPE_CHECKING: + from src.pipelines.code_retrieval import CodeRetrievalPipeline + logger = logging.getLogger("xmem.api.deps") # Initialize stores @@ -181,7 +184,15 @@ async def require_api_key( if user_id: user = _user_store.get_user_by_id(user_id) if user: + user = dict(user) user["id"] = str(user.pop("_id")) + user["api_key"] = { + "id": key_doc.get("id"), + "scopes": key_doc.get("scopes", ["*"]), + "org_id": key_doc.get("org_id"), + "project_id": key_doc.get("project_id"), + "expires_at": key_doc.get("expires_at"), + } request.state.user = user return user diff --git a/src/api/routes/api_keys.py b/src/api/routes/api_keys.py index 8d27681..0ff0b99 100644 --- a/src/api/routes/api_keys.py +++ b/src/api/routes/api_keys.py @@ -22,6 +22,10 @@ class APIKeyCreateRequest(BaseModel): """Request model for creating a new API key.""" name: str = Field(default="Default", description="Name for this API key") + scopes: List[str] = Field(default_factory=lambda: ["*"], description="Allowed API scopes") + expires_at: Optional[datetime] = Field(default=None, description="Optional key expiration time") + org_id: Optional[str] = Field(default=None, description="Optional organization binding") + project_id: Optional[str] = Field(default=None, description="Optional project binding") class APIKeyCreateResponse(BaseModel): @@ -32,6 +36,10 @@ class APIKeyCreateResponse(BaseModel): key: str = Field(..., description="The full API key (only shown once)") key_id: str = Field(..., description="ID of the API key for reference") name: str = Field(..., description="Name of the API key") + scopes: List[str] = Field(default_factory=list, description="Allowed API scopes") + expires_at: Optional[datetime] = Field(None, description="Optional key expiration time") + org_id: Optional[str] = Field(None, description="Optional organization binding") + project_id: Optional[str] = Field(None, description="Optional project binding") created_at: datetime = Field(..., description="Creation timestamp") @@ -40,6 +48,10 @@ class APIKeyResponse(BaseModel): id: str = Field(..., description="API key ID") key_prefix: str = Field(..., description="First 8 characters of the key") name: str = Field(..., description="Name of the API key") + scopes: List[str] = Field(default_factory=list, description="Allowed API scopes") + expires_at: Optional[datetime] = Field(None, description="Optional key expiration time") + org_id: Optional[str] = Field(None, description="Optional organization binding") + project_id: Optional[str] = Field(None, description="Optional project binding") created_at: datetime = Field(..., description="Creation timestamp") last_used: Optional[datetime] = Field(None, description="Last usage timestamp") is_active: bool = Field(..., description="Whether the key is active") @@ -98,6 +110,10 @@ async def list_api_keys( id=key["id"], key_prefix=key.get("key_prefix", "xxxx-xxxx"), name=key["name"], + scopes=key.get("scopes", ["*"]), + expires_at=key.get("expires_at"), + org_id=key.get("org_id"), + project_id=key.get("project_id"), created_at=key["created_at"], last_used=key.get("last_used"), is_active=key.get("is_active", True), @@ -121,12 +137,20 @@ async def create_api_key( result = api_key_store.create_api_key( user_id=current_user["id"], name=request.name, + scopes=request.scopes, + expires_at=request.expires_at, + org_id=request.org_id, + project_id=request.project_id, ) return APIKeyCreateResponse( key=result["key"], key_id=result["key_id"], name=result["name"], + scopes=result["scopes"], + expires_at=result.get("expires_at"), + org_id=result.get("org_id"), + project_id=result.get("project_id"), created_at=result["created_at"], ) @@ -164,6 +188,10 @@ async def update_api_key( id=updated_key["id"], key_prefix=updated_key.get("key_prefix", "xxxx-xxxx"), name=updated_key["name"], + scopes=updated_key.get("scopes", ["*"]), + expires_at=updated_key.get("expires_at"), + org_id=updated_key.get("org_id"), + project_id=updated_key.get("project_id"), created_at=updated_key["created_at"], last_used=updated_key.get("last_used"), is_active=updated_key.get("is_active", True), diff --git a/src/config/settings.py b/src/config/settings.py index d06ce0e..a1343da 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -264,7 +264,7 @@ class Settings(BaseSettings): # Monitoring & Observability # ============================================================================= environment: str = Field( - default="production", + default="development", description="Deployment environment: dev, staging, production" ) sentry_dsn: Optional[str] = Field( diff --git a/src/database/api_key_store.py b/src/database/api_key_store.py index 9f017e4..4f92f0f 100644 --- a/src/database/api_key_store.py +++ b/src/database/api_key_store.py @@ -38,6 +38,24 @@ def __init__( # Try to connect self._try_connect() + def _requires_durable_storage(self) -> bool: + """Return True when in-memory API key fallback must not be used.""" + return settings.environment.lower() in {"production", "prod"} + + def _enable_in_memory_fallback(self, error: Exception) -> None: + """Switch to in-memory storage unless the environment forbids it.""" + message = f"MongoDB connection failed for API key storage: {error}" + if self._requires_durable_storage(): + logger.error("%s; refusing in-memory fallback in production", message) + raise RuntimeError( + "MongoDB is required for API key storage when ENVIRONMENT=production" + ) from error + + logger.warning("%s; using in-memory storage", message) + self._connected = False + self._in_memory = True + self.api_keys = None + def _try_connect(self) -> None: """Attempt to connect to MongoDB, fall back to in-memory if unavailable.""" try: @@ -52,10 +70,7 @@ def _try_connect(self) -> None: logger.info("Connected to MongoDB for API key storage") self._ensure_indexes() except Exception as e: - logger.warning(f"MongoDB connection failed, using in-memory storage: {e}") - self._connected = False - self._in_memory = True - self.api_keys = None + self._enable_in_memory_fallback(e) def _ensure_indexes(self) -> None: """Create necessary indexes for the api_keys collection.""" @@ -66,6 +81,8 @@ def _ensure_indexes(self) -> None: self.api_keys.create_index([("user_id", ASCENDING)]) self.api_keys.create_index([("key_hash", ASCENDING)], unique=True) self.api_keys.create_index([("is_active", ASCENDING)]) + self.api_keys.create_index([("expires_at", ASCENDING)]) + self.api_keys.create_index([("org_id", ASCENDING), ("project_id", ASCENDING)]) except Exception as e: logger.warning(f"Failed to create indexes: {e}") @@ -79,16 +96,62 @@ def _hash_key(self, key: str) -> str: """Create SHA-256 hash of the API key.""" return hashlib.sha256(key.encode()).hexdigest() + def _normalize_scopes(self, scopes: Optional[List[str]]) -> List[str]: + """Return a stable non-empty scope list for storage.""" + normalized = sorted({scope.strip() for scope in (scopes or ["*"]) if scope and scope.strip()}) + return normalized or ["*"] + + def _is_expired(self, key_doc: Dict[str, Any]) -> bool: + expires_at = key_doc.get("expires_at") + return bool(expires_at and datetime.utcnow() > expires_at) + + def _scope_allowed(self, key_doc: Dict[str, Any], required_scope: Optional[str]) -> bool: + if not required_scope: + return True + scopes = key_doc.get("scopes") or ["*"] + return "*" in scopes or required_scope in scopes + + def _binding_allowed( + self, + key_doc: Dict[str, Any], + org_id: Optional[str], + project_id: Optional[str], + ) -> bool: + bound_org = key_doc.get("org_id") + bound_project = key_doc.get("project_id") + if org_id is not None and bound_org not in (None, org_id): + return False + if project_id is not None and bound_project not in (None, project_id): + return False + return True + + def _deactivate_expired_key(self, key_doc: Dict[str, Any]) -> None: + key_doc["is_active"] = False + if self._in_memory: + return + try: + self.api_keys.update_one( + {"_id": key_doc["_id"]}, + {"$set": {"is_active": False}}, + ) + except Exception as e: + logger.warning(f"Failed to deactivate expired API key: {e}") + def create_api_key( self, user_id: str, name: str = "Default", + scopes: Optional[List[str]] = None, + expires_at: Optional[datetime] = None, + org_id: Optional[str] = None, + project_id: Optional[str] = None, ) -> Dict[str, Any]: """Create a new API key for a user.""" key = self._generate_api_key() key_hash = self._hash_key(key) key_prefix = key[:8] now = datetime.utcnow() + normalized_scopes = self._normalize_scopes(scopes) if self._in_memory: key_id = f"mem_{len(_in_memory_api_keys)}" @@ -98,6 +161,10 @@ def create_api_key( "key_hash": key_hash, "key_prefix": key_prefix, "name": name, + "scopes": normalized_scopes, + "expires_at": expires_at, + "org_id": org_id, + "project_id": project_id, "created_at": now, "last_used": None, "is_active": True, @@ -108,6 +175,10 @@ def create_api_key( "key": key, "key_id": key_id, "name": name, + "scopes": normalized_scopes, + "expires_at": expires_at, + "org_id": org_id, + "project_id": project_id, "created_at": now, } @@ -117,6 +188,10 @@ def create_api_key( "key_hash": key_hash, "key_prefix": key_prefix, "name": name, + "scopes": normalized_scopes, + "expires_at": expires_at, + "org_id": org_id, + "project_id": project_id, "created_at": now, "last_used": None, "is_active": True, @@ -127,12 +202,23 @@ def create_api_key( "key": key, "key_id": str(result.inserted_id), "name": name, + "scopes": normalized_scopes, + "expires_at": expires_at, + "org_id": org_id, + "project_id": project_id, "created_at": now, } except Exception as e: logger.error(f"Database error creating API key: {e}") - self._in_memory = True - return self.create_api_key(user_id, name) + self._enable_in_memory_fallback(e) + return self.create_api_key( + user_id=user_id, + name=name, + scopes=scopes, + expires_at=expires_at, + org_id=org_id, + project_id=project_id, + ) def get_user_api_keys(self, user_id: str, include_inactive: bool = False) -> List[Dict[str, Any]]: """Get all API keys for a user.""" @@ -161,13 +247,26 @@ def get_user_api_keys(self, user_id: str, include_inactive: bool = False) -> Lis logger.error(f"Database error getting API keys: {e}") return [] - def validate_api_key(self, key: str) -> Optional[Dict[str, Any]]: + def validate_api_key( + self, + key: str, + required_scope: Optional[str] = None, + org_id: Optional[str] = None, + project_id: Optional[str] = None, + ) -> Optional[Dict[str, Any]]: """Validate an API key and return associated user info.""" key_hash = self._hash_key(key) if self._in_memory: for key_doc in _in_memory_api_keys.values(): if key_doc.get("key_hash") == key_hash and key_doc.get("is_active", True): + if self._is_expired(key_doc): + self._deactivate_expired_key(key_doc) + return None + if not self._scope_allowed(key_doc, required_scope): + return None + if not self._binding_allowed(key_doc, org_id, project_id): + return None key_doc["last_used"] = datetime.utcnow() result = {**key_doc, "id": str(key_doc["_id"])} result.pop("key_hash", None) @@ -180,6 +279,13 @@ def validate_api_key(self, key: str) -> Optional[Dict[str, Any]]: "is_active": True, }) if key_doc: + if self._is_expired(key_doc): + self._deactivate_expired_key(key_doc) + return None + if not self._scope_allowed(key_doc, required_scope): + return None + if not self._binding_allowed(key_doc, org_id, project_id): + return None now = datetime.utcnow() self.api_keys.update_one( {"_id": key_doc["_id"]}, diff --git a/src/database/user_store.py b/src/database/user_store.py index 7b059a6..5cd2891 100644 --- a/src/database/user_store.py +++ b/src/database/user_store.py @@ -31,11 +31,28 @@ def __init__( # Try to connect, but don't fail if MongoDB is unavailable self._try_connect() + def _requires_durable_storage(self) -> bool: + """Return True when in-memory user fallback must not be used.""" + return settings.environment.lower() in {"production", "prod"} + + def _enable_in_memory_fallback(self, error: Exception) -> None: + """Switch to in-memory storage unless the environment forbids it.""" + message = f"MongoDB connection failed for user storage: {error}" + if self._requires_durable_storage(): + logger.error("%s; refusing in-memory fallback in production", message) + raise RuntimeError( + "MongoDB is required for user storage when ENVIRONMENT=production" + ) from error + + logger.warning("%s; using in-memory storage", message) + self._connected = False + self._in_memory = True + self.users = None + def _try_connect(self) -> None: """Attempt to connect to MongoDB, fall back to in-memory if unavailable.""" try: - from pymongo import MongoClient, ASCENDING - from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError + from pymongo import MongoClient self._client = MongoClient(self._uri, serverSelectionTimeoutMS=5000) # Test connection @@ -47,10 +64,7 @@ def _try_connect(self) -> None: logger.info("Connected to MongoDB for user storage") self._ensure_indexes() except Exception as e: - logger.warning(f"MongoDB connection failed, using in-memory storage: {e}") - self._connected = False - self._in_memory = True - self.users = None + self._enable_in_memory_fallback(e) def _ensure_indexes(self) -> None: """Create necessary indexes for the users collection.""" @@ -101,8 +115,6 @@ def get_or_create_user( # MongoDB path try: - from pymongo.errors import DuplicateKeyError - existing = self.users.find_one({"google_id": google_id}) if existing: @@ -130,7 +142,7 @@ def get_or_create_user( except Exception as e: logger.error(f"Database error in get_or_create_user: {e}") # Fall back to in-memory - self._in_memory = True + self._enable_in_memory_fallback(e) return self.get_or_create_user(google_id, email, name, picture) def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: diff --git a/tests/unit/test_database_stores.py b/tests/unit/test_database_stores.py index 400dd78..abe4da2 100644 --- a/tests/unit/test_database_stores.py +++ b/tests/unit/test_database_stores.py @@ -1,5 +1,9 @@ from __future__ import annotations +from datetime import datetime, timedelta + +import pytest + from src.database.api_key_store import APIKeyStore, _in_memory_api_keys from src.database.project_store import ProjectStore from src.database.models import TeamRole @@ -43,6 +47,54 @@ def test_api_key_store_creates_validates_updates_and_revokes_in_memory(monkeypat assert store.validate_api_key(created["key"]) is None +def test_api_key_store_enforces_scopes_expiry_and_bindings_in_memory(monkeypatch): + _in_memory_api_keys.clear() + monkeypatch.setattr(APIKeyStore, "_try_connect", _force_api_key_memory) + store = APIKeyStore() + + created = store.create_api_key( + "user-1", + name="Project key", + scopes=["memory:read", "scanner:write"], + expires_at=datetime.utcnow() + timedelta(minutes=5), + org_id="org-1", + project_id="project-1", + ) + + validated = store.validate_api_key( + created["key"], + required_scope="memory:read", + org_id="org-1", + project_id="project-1", + ) + assert validated["scopes"] == ["memory:read", "scanner:write"] + assert validated["org_id"] == "org-1" + assert validated["project_id"] == "project-1" + + assert store.validate_api_key(created["key"], required_scope="admin:write") is None + assert store.validate_api_key(created["key"], org_id="other-org") is None + assert store.validate_api_key(created["key"], project_id="other-project") is None + + expired = store.create_api_key( + "user-1", + expires_at=datetime.utcnow() - timedelta(seconds=1), + ) + assert store.validate_api_key(expired["key"]) is None + assert _in_memory_api_keys[expired["key_id"]]["is_active"] is False + + +def test_api_key_store_refuses_memory_fallback_in_production(monkeypatch): + store = object.__new__(APIKeyStore) + store._connected = False + store._in_memory = False + store.api_keys = None + + monkeypatch.setattr("src.database.api_key_store.settings.environment", "production") + + with pytest.raises(RuntimeError, match="MongoDB is required"): + store._enable_in_memory_fallback(ConnectionError("offline")) + + def test_user_store_get_or_create_and_username_helpers_in_memory(monkeypatch): _in_memory_users.clear() monkeypatch.setattr(UserStore, "_try_connect", _force_user_memory) From 0be9a6d76764ad826825efa26eecb4c75211c62a Mon Sep 17 00:00:00 2001 From: webhop123 Date: Tue, 12 May 2026 11:12:33 +0100 Subject: [PATCH 2/2] Address API key validation feedback --- src/database/api_key_store.py | 24 +++++++++++++++--------- tests/unit/test_database_stores.py | 8 +++++--- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/database/api_key_store.py b/src/database/api_key_store.py index 4f92f0f..2a01f8a 100644 --- a/src/database/api_key_store.py +++ b/src/database/api_key_store.py @@ -4,7 +4,7 @@ import logging import secrets import string -from datetime import datetime +from datetime import datetime, timezone from typing import List, Optional, Dict, Any from src.config import settings @@ -81,8 +81,6 @@ def _ensure_indexes(self) -> None: self.api_keys.create_index([("user_id", ASCENDING)]) self.api_keys.create_index([("key_hash", ASCENDING)], unique=True) self.api_keys.create_index([("is_active", ASCENDING)]) - self.api_keys.create_index([("expires_at", ASCENDING)]) - self.api_keys.create_index([("org_id", ASCENDING), ("project_id", ASCENDING)]) except Exception as e: logger.warning(f"Failed to create indexes: {e}") @@ -101,9 +99,17 @@ def _normalize_scopes(self, scopes: Optional[List[str]]) -> List[str]: normalized = sorted({scope.strip() for scope in (scopes or ["*"]) if scope and scope.strip()}) return normalized or ["*"] + def _utc_now(self) -> datetime: + return datetime.now(timezone.utc) + + def _as_utc(self, value: datetime) -> datetime: + if value.tzinfo is None or value.tzinfo.utcoffset(value) is None: + return value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc) + def _is_expired(self, key_doc: Dict[str, Any]) -> bool: expires_at = key_doc.get("expires_at") - return bool(expires_at and datetime.utcnow() > expires_at) + return bool(expires_at and self._utc_now() > self._as_utc(expires_at)) def _scope_allowed(self, key_doc: Dict[str, Any], required_scope: Optional[str]) -> bool: if not required_scope: @@ -119,9 +125,9 @@ def _binding_allowed( ) -> bool: bound_org = key_doc.get("org_id") bound_project = key_doc.get("project_id") - if org_id is not None and bound_org not in (None, org_id): + if bound_org is not None and bound_org != org_id: return False - if project_id is not None and bound_project not in (None, project_id): + if bound_project is not None and bound_project != project_id: return False return True @@ -150,7 +156,7 @@ def create_api_key( key = self._generate_api_key() key_hash = self._hash_key(key) key_prefix = key[:8] - now = datetime.utcnow() + now = self._utc_now() normalized_scopes = self._normalize_scopes(scopes) if self._in_memory: @@ -267,7 +273,7 @@ def validate_api_key( return None if not self._binding_allowed(key_doc, org_id, project_id): return None - key_doc["last_used"] = datetime.utcnow() + key_doc["last_used"] = self._utc_now() result = {**key_doc, "id": str(key_doc["_id"])} result.pop("key_hash", None) return result @@ -286,7 +292,7 @@ def validate_api_key( return None if not self._binding_allowed(key_doc, org_id, project_id): return None - now = datetime.utcnow() + now = self._utc_now() self.api_keys.update_one( {"_id": key_doc["_id"]}, {"$set": {"last_used": now}} diff --git a/tests/unit/test_database_stores.py b/tests/unit/test_database_stores.py index abe4da2..3d1c1b2 100644 --- a/tests/unit/test_database_stores.py +++ b/tests/unit/test_database_stores.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import pytest @@ -56,7 +56,7 @@ def test_api_key_store_enforces_scopes_expiry_and_bindings_in_memory(monkeypatch "user-1", name="Project key", scopes=["memory:read", "scanner:write"], - expires_at=datetime.utcnow() + timedelta(minutes=5), + expires_at=datetime.now(timezone.utc) + timedelta(minutes=5), org_id="org-1", project_id="project-1", ) @@ -71,13 +71,15 @@ def test_api_key_store_enforces_scopes_expiry_and_bindings_in_memory(monkeypatch assert validated["org_id"] == "org-1" assert validated["project_id"] == "project-1" + assert store.validate_api_key(created["key"], required_scope="memory:read") is None + assert store.validate_api_key(created["key"], required_scope="memory:read", org_id="org-1") is None assert store.validate_api_key(created["key"], required_scope="admin:write") is None assert store.validate_api_key(created["key"], org_id="other-org") is None assert store.validate_api_key(created["key"], project_id="other-project") is None expired = store.create_api_key( "user-1", - expires_at=datetime.utcnow() - timedelta(seconds=1), + expires_at=datetime.now(timezone.utc) - timedelta(seconds=1), ) assert store.validate_api_key(expired["key"]) is None assert _in_memory_api_keys[expired["key_id"]]["is_active"] is False