Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import logging
import time
from collections import defaultdict
from typing import Optional
from typing import TYPE_CHECKING, Optional

from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
Expand All @@ -24,6 +24,9 @@
from src.pipelines.ingest import IngestPipeline
from src.pipelines.retrieval import RetrievalPipeline

if TYPE_CHECKING:
from src.pipelines.code_retrieval import CodeRetrievalPipeline

logger = logging.getLogger("xmem.api.deps")

# Initialize stores
Expand Down Expand Up @@ -181,7 +184,15 @@ async def require_api_key(
if user_id:
user = _user_store.get_user_by_id(user_id)
if user:
user = dict(user)
user["id"] = str(user.pop("_id"))
user["api_key"] = {
"id": key_doc.get("id"),
"scopes": key_doc.get("scopes", ["*"]),
"org_id": key_doc.get("org_id"),
"project_id": key_doc.get("project_id"),
"expires_at": key_doc.get("expires_at"),
}
request.state.user = user
return user

Expand Down
28 changes: 28 additions & 0 deletions src/api/routes/api_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
class APIKeyCreateRequest(BaseModel):
"""Request model for creating a new API key."""
name: str = Field(default="Default", description="Name for this API key")
scopes: List[str] = Field(default_factory=lambda: ["*"], description="Allowed API scopes")
expires_at: Optional[datetime] = Field(default=None, description="Optional key expiration time")
org_id: Optional[str] = Field(default=None, description="Optional organization binding")
project_id: Optional[str] = Field(default=None, description="Optional project binding")


class APIKeyCreateResponse(BaseModel):
Expand All @@ -32,6 +36,10 @@ class APIKeyCreateResponse(BaseModel):
key: str = Field(..., description="The full API key (only shown once)")
key_id: str = Field(..., description="ID of the API key for reference")
name: str = Field(..., description="Name of the API key")
scopes: List[str] = Field(default_factory=list, description="Allowed API scopes")
expires_at: Optional[datetime] = Field(None, description="Optional key expiration time")
org_id: Optional[str] = Field(None, description="Optional organization binding")
project_id: Optional[str] = Field(None, description="Optional project binding")
created_at: datetime = Field(..., description="Creation timestamp")


Expand All @@ -40,6 +48,10 @@ class APIKeyResponse(BaseModel):
id: str = Field(..., description="API key ID")
key_prefix: str = Field(..., description="First 8 characters of the key")
name: str = Field(..., description="Name of the API key")
scopes: List[str] = Field(default_factory=list, description="Allowed API scopes")
expires_at: Optional[datetime] = Field(None, description="Optional key expiration time")
org_id: Optional[str] = Field(None, description="Optional organization binding")
project_id: Optional[str] = Field(None, description="Optional project binding")
created_at: datetime = Field(..., description="Creation timestamp")
last_used: Optional[datetime] = Field(None, description="Last usage timestamp")
is_active: bool = Field(..., description="Whether the key is active")
Expand Down Expand Up @@ -98,6 +110,10 @@ async def list_api_keys(
id=key["id"],
key_prefix=key.get("key_prefix", "xxxx-xxxx"),
name=key["name"],
scopes=key.get("scopes", ["*"]),
expires_at=key.get("expires_at"),
org_id=key.get("org_id"),
project_id=key.get("project_id"),
created_at=key["created_at"],
last_used=key.get("last_used"),
is_active=key.get("is_active", True),
Expand All @@ -121,12 +137,20 @@ async def create_api_key(
result = api_key_store.create_api_key(
user_id=current_user["id"],
name=request.name,
scopes=request.scopes,
expires_at=request.expires_at,
org_id=request.org_id,
project_id=request.project_id,
)

return APIKeyCreateResponse(
key=result["key"],
key_id=result["key_id"],
name=result["name"],
scopes=result["scopes"],
expires_at=result.get("expires_at"),
org_id=result.get("org_id"),
project_id=result.get("project_id"),
created_at=result["created_at"],
)

Expand Down Expand Up @@ -164,6 +188,10 @@ async def update_api_key(
id=updated_key["id"],
key_prefix=updated_key.get("key_prefix", "xxxx-xxxx"),
name=updated_key["name"],
scopes=updated_key.get("scopes", ["*"]),
expires_at=updated_key.get("expires_at"),
org_id=updated_key.get("org_id"),
project_id=updated_key.get("project_id"),
created_at=updated_key["created_at"],
last_used=updated_key.get("last_used"),
is_active=updated_key.get("is_active", True),
Expand Down
2 changes: 1 addition & 1 deletion src/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ class Settings(BaseSettings):
# Monitoring & Observability
# =============================================================================
environment: str = Field(
default="production",
default="development",
description="Deployment environment: dev, staging, production"
)
sentry_dsn: Optional[str] = Field(
Expand Down
134 changes: 123 additions & 11 deletions src/database/api_key_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import secrets
import string
from datetime import datetime
from datetime import datetime, timezone
from typing import List, Optional, Dict, Any

from src.config import settings
Expand Down Expand Up @@ -38,6 +38,24 @@ def __init__(
# Try to connect
self._try_connect()

def _requires_durable_storage(self) -> bool:
"""Return True when in-memory API key fallback must not be used."""
return settings.environment.lower() in {"production", "prod"}

def _enable_in_memory_fallback(self, error: Exception) -> None:
"""Switch to in-memory storage unless the environment forbids it."""
message = f"MongoDB connection failed for API key storage: {error}"
if self._requires_durable_storage():
logger.error("%s; refusing in-memory fallback in production", message)
raise RuntimeError(
"MongoDB is required for API key storage when ENVIRONMENT=production"
) from error

logger.warning("%s; using in-memory storage", message)
self._connected = False
self._in_memory = True
self.api_keys = None

def _try_connect(self) -> None:
"""Attempt to connect to MongoDB, fall back to in-memory if unavailable."""
try:
Expand All @@ -52,10 +70,7 @@ def _try_connect(self) -> None:
logger.info("Connected to MongoDB for API key storage")
self._ensure_indexes()
except Exception as e:
logger.warning(f"MongoDB connection failed, using in-memory storage: {e}")
self._connected = False
self._in_memory = True
self.api_keys = None
self._enable_in_memory_fallback(e)

def _ensure_indexes(self) -> None:
"""Create necessary indexes for the api_keys collection."""
Expand All @@ -79,16 +94,70 @@ def _hash_key(self, key: str) -> str:
"""Create SHA-256 hash of the API key."""
return hashlib.sha256(key.encode()).hexdigest()

def _normalize_scopes(self, scopes: Optional[List[str]]) -> List[str]:
"""Return a stable non-empty scope list for storage."""
normalized = sorted({scope.strip() for scope in (scopes or ["*"]) if scope and scope.strip()})
return normalized or ["*"]

def _utc_now(self) -> datetime:
return datetime.now(timezone.utc)

def _as_utc(self, value: datetime) -> datetime:
if value.tzinfo is None or value.tzinfo.utcoffset(value) is None:
return value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc)

def _is_expired(self, key_doc: Dict[str, Any]) -> bool:
expires_at = key_doc.get("expires_at")
return bool(expires_at and self._utc_now() > self._as_utc(expires_at))

def _scope_allowed(self, key_doc: Dict[str, Any], required_scope: Optional[str]) -> bool:
if not required_scope:
return True
scopes = key_doc.get("scopes") or ["*"]
return "*" in scopes or required_scope in scopes

def _binding_allowed(
self,
key_doc: Dict[str, Any],
org_id: Optional[str],
project_id: Optional[str],
) -> bool:
bound_org = key_doc.get("org_id")
bound_project = key_doc.get("project_id")
if bound_org is not None and bound_org != org_id:
return False
if bound_project is not None and bound_project != project_id:
return False
return True

def _deactivate_expired_key(self, key_doc: Dict[str, Any]) -> None:
key_doc["is_active"] = False
if self._in_memory:
return
try:
self.api_keys.update_one(
{"_id": key_doc["_id"]},
{"$set": {"is_active": False}},
)
except Exception as e:
logger.warning(f"Failed to deactivate expired API key: {e}")

def create_api_key(
self,
user_id: str,
name: str = "Default",
scopes: Optional[List[str]] = None,
expires_at: Optional[datetime] = None,
org_id: Optional[str] = None,
project_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Create a new API key for a user."""
key = self._generate_api_key()
key_hash = self._hash_key(key)
key_prefix = key[:8]
now = datetime.utcnow()
now = self._utc_now()
normalized_scopes = self._normalize_scopes(scopes)

if self._in_memory:
key_id = f"mem_{len(_in_memory_api_keys)}"
Expand All @@ -98,6 +167,10 @@ def create_api_key(
"key_hash": key_hash,
"key_prefix": key_prefix,
"name": name,
"scopes": normalized_scopes,
"expires_at": expires_at,
"org_id": org_id,
"project_id": project_id,
"created_at": now,
"last_used": None,
"is_active": True,
Expand All @@ -108,6 +181,10 @@ def create_api_key(
"key": key,
"key_id": key_id,
"name": name,
"scopes": normalized_scopes,
"expires_at": expires_at,
"org_id": org_id,
"project_id": project_id,
"created_at": now,
}

Expand All @@ -117,6 +194,10 @@ def create_api_key(
"key_hash": key_hash,
"key_prefix": key_prefix,
"name": name,
"scopes": normalized_scopes,
"expires_at": expires_at,
"org_id": org_id,
"project_id": project_id,
"created_at": now,
"last_used": None,
"is_active": True,
Expand All @@ -127,12 +208,23 @@ def create_api_key(
"key": key,
"key_id": str(result.inserted_id),
"name": name,
"scopes": normalized_scopes,
"expires_at": expires_at,
"org_id": org_id,
"project_id": project_id,
"created_at": now,
}
except Exception as e:
logger.error(f"Database error creating API key: {e}")
self._in_memory = True
return self.create_api_key(user_id, name)
self._enable_in_memory_fallback(e)
return self.create_api_key(
user_id=user_id,
name=name,
scopes=scopes,
expires_at=expires_at,
org_id=org_id,
project_id=project_id,
)

def get_user_api_keys(self, user_id: str, include_inactive: bool = False) -> List[Dict[str, Any]]:
"""Get all API keys for a user."""
Expand Down Expand Up @@ -161,14 +253,27 @@ def get_user_api_keys(self, user_id: str, include_inactive: bool = False) -> Lis
logger.error(f"Database error getting API keys: {e}")
return []

def validate_api_key(self, key: str) -> Optional[Dict[str, Any]]:
def validate_api_key(
self,
key: str,
required_scope: Optional[str] = None,
org_id: Optional[str] = None,
project_id: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
"""Validate an API key and return associated user info."""
key_hash = self._hash_key(key)

if self._in_memory:
for key_doc in _in_memory_api_keys.values():
if key_doc.get("key_hash") == key_hash and key_doc.get("is_active", True):
key_doc["last_used"] = datetime.utcnow()
if self._is_expired(key_doc):
self._deactivate_expired_key(key_doc)
return None
if not self._scope_allowed(key_doc, required_scope):
return None
if not self._binding_allowed(key_doc, org_id, project_id):
return None
key_doc["last_used"] = self._utc_now()
result = {**key_doc, "id": str(key_doc["_id"])}
result.pop("key_hash", None)
return result
Expand All @@ -180,7 +285,14 @@ def validate_api_key(self, key: str) -> Optional[Dict[str, Any]]:
"is_active": True,
})
if key_doc:
now = datetime.utcnow()
if self._is_expired(key_doc):
self._deactivate_expired_key(key_doc)
return None
if not self._scope_allowed(key_doc, required_scope):
return None
if not self._binding_allowed(key_doc, org_id, project_id):
return None
now = self._utc_now()
self.api_keys.update_one(
{"_id": key_doc["_id"]},
{"$set": {"last_used": now}}
Expand Down
Loading
Loading