From e0f8b9f11ab17e009aa7811511c1c78f4ce8e070 Mon Sep 17 00:00:00 2001 From: strongkeep-debug Date: Mon, 11 May 2026 11:19:55 -0700 Subject: [PATCH 1/5] Add low-latency memory search path --- src/api/app.py | 2 + src/api/routes/__init__.py | 3 +- src/api/routes/memory.py | 48 ++++-- src/api/schemas.py | 13 +- src/pipelines/retrieval.py | 168 +++++++++++++++++-- src/schemas/retrieval.py | 4 +- tests/api/test_memory_search_routes.py | 104 ++++++++++++ tests/integration/test_retrieval_pipeline.py | 79 +++++++++ 8 files changed, 393 insertions(+), 28 deletions(-) create mode 100644 tests/api/test_memory_search_routes.py diff --git a/src/api/app.py b/src/api/app.py index 7b9dc79..d072398 100644 --- a/src/api/app.py +++ b/src/api/app.py @@ -33,6 +33,7 @@ from src.api.routes.enterprise import router as enterprise_router from src.api.routes.health import router as health_router from src.api.routes.memory import router as memory_router +from src.api.routes.memory import search_router as memory_search_router from src.api.routes.memory import scrape_router as memory_scrape_router from src.api.routes.memory_graph import router as memory_graph_router from src.api.routes.scanner import router as scanner_router @@ -155,6 +156,7 @@ async def lifespan(app: FastAPI): # ── Routes ──────────────────────────────────────────────────────── app.include_router(health_router) app.include_router(memory_scrape_router) + app.include_router(memory_search_router) app.include_router(memory_router) app.include_router(memory_graph_router) app.include_router(code_router) diff --git a/src/api/routes/__init__.py b/src/api/routes/__init__.py index 17d0c0e..768617e 100644 --- a/src/api/routes/__init__.py +++ b/src/api/routes/__init__.py @@ -1,4 +1,5 @@ from .health import router as health_router from .memory import router as memory_router +from .memory import search_router as memory_search_router -__all__ = ["health_router", "memory_router"] +__all__ = ["health_router", "memory_router", "memory_search_router"] diff --git a/src/api/routes/memory.py b/src/api/routes/memory.py index b4be36d..88ba771 100644 --- a/src/api/routes/memory.py +++ b/src/api/routes/memory.py @@ -63,6 +63,11 @@ dependencies=[Depends(enforce_rate_limit)], ) +search_router = APIRouter( + tags=["memory"], + dependencies=[Depends(require_ready), Depends(enforce_rate_limit)], +) + # Helpers def _model_name(model: Any) -> str: @@ -667,6 +672,7 @@ async def retrieve_memory(req: RetrieveRequest, request: Request, user: dict = D confidence=result.confidence, ) elapsed = round((time.perf_counter() - start) * 1000, 2) + pipeline.record_latency("agentic", elapsed) return _wrap(request, data, elapsed) except Exception as exc: @@ -676,10 +682,15 @@ async def retrieve_memory(req: RetrieveRequest, request: Request, user: dict = D # POST /v1/memory/search +@search_router.post( + "/search", + response_model=APIResponse, + summary="Raw semantic search across memory domains with optional answer synthesis", +) @router.post( "/search", response_model=APIResponse, - summary="Raw semantic search across memory domains (no LLM answer)", + summary="Raw semantic search across memory domains with optional answer synthesis", ) async def search_memory(req: SearchRequest, request: Request, user: dict = Depends(require_api_key)): start = time.perf_counter() @@ -689,17 +700,34 @@ async def search_memory(req: SearchRequest, request: Request, user: dict = Depen user_id = user.get("username") or user.get("name") or user["id"] try: - all_results: List[SourceRecord] = [] - - if "profile" in 
req.domains: - all_results.extend(_search_profile(pipeline, user_id)) - if "temporal" in req.domains: - all_results.extend(_search_temporal(pipeline, req.query, user_id, req.top_k)) - if "summary" in req.domains: - all_results.extend(await _search_summary(pipeline, req.query, user_id, req.top_k)) + all_results = await pipeline.search_raw( + query=req.query, + user_id=user_id, + domains=req.domains, + top_k=req.top_k, + ) + answer = "" + if req.answer: + answer = await pipeline.answer_from_sources(req.query, all_results) - data = SearchResponse(results=all_results, total=len(all_results)) elapsed = round((time.perf_counter() - start) * 1000, 2) + pipeline.record_latency("answer" if req.answer else "raw", elapsed) + data = SearchResponse( + results=[ + SourceRecord( + domain=s.domain, + content=s.content, + score=round(s.score, 3), + metadata=s.metadata, + ) + for s in all_results + ], + total=len(all_results), + answer=answer, + model=_model_name(pipeline.model) if req.answer else "", + confidence=min(1.0, len(all_results) * 0.2) if answer else 0.0, + latency=pipeline.get_latency_snapshot(), + ) return _wrap(request, data, elapsed) except Exception as exc: diff --git a/src/api/schemas.py b/src/api/schemas.py index b7ee122..c701616 100644 --- a/src/api/schemas.py +++ b/src/api/schemas.py @@ -7,7 +7,6 @@ from __future__ import annotations -from datetime import datetime from enum import Enum from typing import Any, Dict, List, Optional @@ -159,15 +158,19 @@ class SearchRequest(BaseModel): ..., min_length=1, max_length=256, pattern=r"^[\w.\-@]+$", ) domains: List[str] = Field( - default=["profile", "temporal", "summary"], + default=["profile", "temporal", "summary", "snippet"], description="Which memory domains to search", ) top_k: int = Field(default=10, ge=1, le=100) + answer: bool = Field( + default=False, + description="When true, synthesize an answer from the raw hits without agentic tool selection.", + ) @field_validator("domains") @classmethod def validate_domains(cls, v: List[str]) -> List[str]: - allowed = {"profile", "temporal", "summary"} + allowed = {"profile", "temporal", "summary", "snippet"} for d in v: if d not in allowed: raise ValueError(f"Invalid domain '{d}'. 
Allowed: {allowed}") @@ -177,6 +180,10 @@ def validate_domains(cls, v: List[str]) -> List[str]: class SearchResponse(BaseModel): results: List[SourceRecord] = Field(default_factory=list) total: int = 0 + answer: str = "" + model: str = "" + confidence: float = 0.0 + latency: Dict[str, Dict[str, float | int]] = Field(default_factory=dict) # ── Scrape (extract from shared chat links) ──────────────────────────────── diff --git a/src/pipelines/retrieval.py b/src/pipelines/retrieval.py index 3516561..3fc7acd 100644 --- a/src/pipelines/retrieval.py +++ b/src/pipelines/retrieval.py @@ -22,6 +22,7 @@ import asyncio import logging +import time from typing import Any, Callable, Dict, List, Optional from dotenv import load_dotenv @@ -41,6 +42,10 @@ logger = logging.getLogger("xmem.pipelines.retrieval") +_CACHE_TTL_SECONDS = 60.0 +_LATENCY_SAMPLE_LIMIT = 200 + + # ═══════════════════════════════════════════════════════════════════════════ # Tool schemas — These are the "function signatures" exposed to the LLM # ═══════════════════════════════════════════════════════════════════════════ @@ -132,6 +137,9 @@ def __init__( self.embed_fn = embed_fn self._snippet_stores: Dict[str, PineconeVectorStore] = {} + self._profile_catalog_cache: Dict[str, tuple[float, List[Dict[str, str]], List[Any]]] = {} + self._retrieval_plan_cache: Dict[tuple[str, str, int, str], tuple[float, AIMessage]] = {} + self._latency_samples: Dict[str, List[float]] = {} logger.info("RetrievalPipeline initialized") @@ -154,7 +162,7 @@ async def run( logger.info("=" * 60) # ── Step 0: Fetch available profile catalog for this user ───── - profile_catalog, profile_records = self._fetch_profile_catalog(user_id) + profile_catalog, profile_records = self._get_profile_catalog(user_id) catalog_text = self._format_catalog(profile_catalog) logger.info("Available profiles: %s", catalog_text) @@ -168,7 +176,11 @@ async def run( HumanMessage(content=query), ] - ai_response: AIMessage = await self.model_with_tools.ainvoke(messages) + plan_key = (user_id, query.strip(), top_k, catalog_text) + ai_response = self._get_cached_retrieval_plan(plan_key) + if ai_response is None: + ai_response = await self.model_with_tools.ainvoke(messages) + self._cache_retrieval_plan(plan_key, ai_response) logger.info("LLM response received (tool_calls=%d)", len(ai_response.tool_calls or [])) # ── Step 2: Execute tool calls ──────────────────────────────── @@ -236,16 +248,7 @@ async def _process_tool_call(tc): answer = ai_response.content logger.info("LLM answered without tool calls") - if isinstance(answer, list): - parts = [] - for c in answer: - if isinstance(c, dict) and "text" in c: - parts.append(c["text"]) - elif isinstance(c, str): - parts.append(c) - else: - parts.append(str(c)) - answer = "\n".join(parts) + answer = self._coerce_answer(answer) confidence = min(1.0, len(sources) * 0.2) if sources else 0.1 @@ -262,6 +265,52 @@ async def _process_tool_call(tc): confidence=confidence, ) + async def search_raw( + self, + query: str, + user_id: str, + domains: List[str], + top_k: int = 10, + ) -> List[SourceRecord]: + """Return ranked memory hits without asking the LLM for a retrieval plan.""" + + domain_set = set(domains) + results: List[SourceRecord] = [] + + if "profile" in domain_set: + results.extend(await self._search_profile_raw(query, user_id, top_k)) + if "temporal" in domain_set: + results.extend(await self._search_temporal(query, user_id, top_k)) + if "summary" in domain_set: + results.extend(await self._search_summary(query, user_id, top_k)) + if "snippet" 
in domain_set: + results.extend(await self._search_snippet(query, user_id, top_k)) + + return sorted(results, key=lambda record: record.score, reverse=True) + + async def answer_from_sources(self, query: str, sources: List[SourceRecord]) -> str: + """Generate an answer from already-fetched sources without tool selection.""" + + context_text = self._format_tool_results(sources) + answer_prompt = ANSWER_PROMPT.format(context=context_text, query=query) + final_response = await self.model.ainvoke([HumanMessage(content=answer_prompt)]) + return self._coerce_answer(final_response.content) + + def record_latency(self, mode: str, elapsed_ms: float) -> None: + """Track bounded latency samples for raw, answer, and agentic modes.""" + + samples = self._latency_samples.setdefault(mode, []) + samples.append(float(elapsed_ms)) + if len(samples) > _LATENCY_SAMPLE_LIMIT: + del samples[0 : len(samples) - _LATENCY_SAMPLE_LIMIT] + + def get_latency_snapshot(self) -> Dict[str, Dict[str, float | int]]: + return { + mode: self._percentiles(samples) + for mode, samples in self._latency_samples.items() + if samples + } + # ------------------------------------------------------------------ # Tool execution # ------------------------------------------------------------------ @@ -348,6 +397,35 @@ def _search_profile( # -- Temporal: Neo4j semantic search ─────────────────────────────── + async def _search_profile_raw( + self, + query: str, + user_id: str, + top_k: int = 10, + ) -> List[SourceRecord]: + """Semantic profile search for the low-latency raw search endpoint.""" + + try: + results = await self.vector_store.search_by_text( + query_text=query, + top_k=top_k, + filters={"user_id": user_id, "domain": "profile"}, + ) + except Exception as exc: + logger.warning("Profile raw search failed, using cached catalog: %s", exc) + _, results = self._get_profile_catalog(user_id) + + records = [] + for r in results[:top_k]: + records.append(SourceRecord( + domain="profile", + content=r.content, + score=r.score, + metadata={"id": r.id, **r.metadata}, + )) + logger.info("Profile raw [%s]: %d results", query, len(records)) + return records + async def _search_temporal( self, query: str, @@ -491,6 +569,20 @@ async def _search_snippet( # Profile catalog (tells the LLM what profile keys exist) # ------------------------------------------------------------------ + def _get_profile_catalog(self, user_id: str): + cached = self._profile_catalog_cache.get(user_id) + now = time.monotonic() + if cached and cached[0] > now: + return cached[1], cached[2] + + catalog, results = self._fetch_profile_catalog(user_id) + self._profile_catalog_cache[user_id] = ( + now + _CACHE_TTL_SECONDS, + catalog, + results, + ) + return catalog, results + def _fetch_profile_catalog(self, user_id: str): """Fetch all profile entries for a user. 
@@ -543,6 +635,58 @@ def _format_catalog(self, catalog: List[Dict[str, str]]) -> str: lines.append(f" - {t} / {st}") return "\n".join(lines) + def _get_cached_retrieval_plan( + self, + key: tuple[str, str, int, str], + ) -> AIMessage | None: + cached = self._retrieval_plan_cache.get(key) + if not cached: + return None + expires_at, response = cached + if expires_at <= time.monotonic(): + self._retrieval_plan_cache.pop(key, None) + return None + return response + + def _cache_retrieval_plan( + self, + key: tuple[str, str, int, str], + response: AIMessage, + ) -> None: + self._retrieval_plan_cache[key] = ( + time.monotonic() + _CACHE_TTL_SECONDS, + response, + ) + + def _coerce_answer(self, answer: Any) -> str: + if isinstance(answer, list): + parts = [] + for c in answer: + if isinstance(c, dict) and "text" in c: + parts.append(c["text"]) + elif isinstance(c, str): + parts.append(c) + else: + parts.append(str(c)) + return "\n".join(parts) + return str(answer) + + def _percentiles(self, samples: List[float]) -> Dict[str, float | int]: + ordered = sorted(samples) + + def pick(percentile: float) -> float: + if not ordered: + return 0.0 + index = min(len(ordered) - 1, round((len(ordered) - 1) * percentile)) + return round(ordered[index], 2) + + return { + "count": len(ordered), + "p50": pick(0.50), + "p95": pick(0.95), + "p99": pick(0.99), + } + # ------------------------------------------------------------------ # Formatting helpers # ------------------------------------------------------------------ diff --git a/src/schemas/retrieval.py b/src/schemas/retrieval.py index 8896726..3b0bb5e 100644 --- a/src/schemas/retrieval.py +++ b/src/schemas/retrieval.py @@ -5,7 +5,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List @dataclass @@ -26,7 +26,7 @@ def source_count(self) -> int: class SourceRecord: """A single piece of evidence fetched from a data store.""" - domain: str # "profile", "temporal", "summary" + domain: str # "profile", "temporal", "summary", "snippet" content: str # the actual text score: float = 0.0 # similarity score (if applicable) metadata: Dict[str, Any] = field(default_factory=dict) diff --git a/tests/api/test_memory_search_routes.py b/tests/api/test_memory_search_routes.py new file mode 100644 index 0000000..3c7221c --- /dev/null +++ b/tests/api/test_memory_search_routes.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from src.api import dependencies as deps +from src.api.middleware import RequestContextMiddleware +from src.api.routes.memory import router as memory_router +from src.api.routes.memory import search_router as memory_search_router +from src.schemas.retrieval import SourceRecord + + +class FakeSearchPipeline: + model = SimpleNamespace(model="fake-retrieval") + + def __init__(self) -> None: + self.answer_calls = 0 + self.latencies: dict[str, list[float]] = {} + + async def search_raw(self, query: str, user_id: str, domains: list[str], top_k: int): + assert query == "latency" + assert user_id == "Static Key User" + assert domains == ["profile", "summary"] + assert top_k == 3 + return [ + SourceRecord(domain="summary", content="Low-latency summary", score=0.9), + SourceRecord(domain="profile", content="work / company = XMem", score=0.7), + ] + + async def answer_from_sources(self, query: str, sources: list[SourceRecord]) -> 
str: + self.answer_calls += 1 + return "Alice is working on low-latency retrieval." + + def record_latency(self, mode: str, elapsed_ms: float) -> None: + self.latencies.setdefault(mode, []).append(elapsed_ms) + + def get_latency_snapshot(self): + return { + mode: {"count": len(samples), "p50": samples[-1], "p95": samples[-1], "p99": samples[-1]} + for mode, samples in self.latencies.items() + } + + +@pytest.fixture +def memory_search_app(monkeypatch): + pipeline = FakeSearchPipeline() + monkeypatch.setattr(deps.settings, "api_keys", ["test-static-key"], raising=False) + deps._init_error = None + deps._pipelines_ready.set() + deps.set_pipelines(SimpleNamespace(), pipeline) + + app = FastAPI() + app.add_middleware(RequestContextMiddleware) + app.include_router(memory_search_router) + app.include_router(memory_router) + return app, pipeline + + +def test_memory_search_route_returns_raw_hits_without_answer(memory_search_app): + app, pipeline = memory_search_app + response = TestClient(app).post( + "/v1/memory/search", + headers={"Authorization": "Bearer test-static-key"}, + json={ + "query": "latency", + "user_id": "ignored-by-auth", + "domains": ["profile", "summary"], + "top_k": 3, + }, + ) + + payload = response.json() + + assert response.status_code == 200 + assert payload["data"]["total"] == 2 + assert payload["data"]["answer"] == "" + assert payload["data"]["latency"]["raw"]["count"] == 1 + assert pipeline.answer_calls == 0 + + +def test_root_search_alias_can_synthesize_answer(memory_search_app): + app, pipeline = memory_search_app + response = TestClient(app).post( + "/search", + headers={"Authorization": "Bearer test-static-key"}, + json={ + "query": "latency", + "user_id": "ignored-by-auth", + "domains": ["profile", "summary"], + "top_k": 3, + "answer": True, + }, + ) + + payload = response.json() + + assert response.status_code == 200 + assert payload["data"]["answer"] == "Alice is working on low-latency retrieval." 
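+    # Synthesis mode should report the pipeline's model name (via _model_name)
+    # and record this request under the "answer" latency bucket.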
+ assert payload["data"]["model"] == "fake-retrieval" + assert payload["data"]["latency"]["answer"]["count"] == 1 + assert pipeline.answer_calls == 1 diff --git a/tests/integration/test_retrieval_pipeline.py b/tests/integration/test_retrieval_pipeline.py index 076b26d..dcfc805 100644 --- a/tests/integration/test_retrieval_pipeline.py +++ b/tests/integration/test_retrieval_pipeline.py @@ -44,6 +44,85 @@ async def test_retrieval_pipeline_executes_tool_calls_and_generates_answer(vecto assert result.confidence > 0.1 +@pytest.mark.asyncio +async def test_retrieval_pipeline_caches_catalog_and_retrieval_plan(vector_store, neo4j_client): + vector_store.seed( + "profile-1", + "work / company = XMem", + {"user_id": "alice", "domain": "profile", "main_content": "work_company"}, + ) + model = FakeChatModel( + tool_responses=[ + FakeLLMResponse("", tool_calls=[ + {"name": "search_profile", "args": {"topic": "work"}, "id": "call-profile"}, + ]) + ], + responses=["Alice works at XMem.", "Alice still works at XMem."], + ) + pipeline = RetrievalPipeline(model=model, vector_store=vector_store, neo4j_client=neo4j_client) + + first = await pipeline.run("Where does Alice work?", "alice") + second = await pipeline.run("Where does Alice work?", "alice") + + assert "XMem" in first.answer + assert "XMem" in second.answer + assert len(pipeline.model_with_tools.calls) == 1 + + +@pytest.mark.asyncio +async def test_raw_search_returns_ranked_hits_without_tool_selection(vector_store, neo4j_client): + vector_store.seed( + "profile-1", + "work / company = XMem", + {"user_id": "alice", "domain": "profile", "main_content": "work_company"}, + score=0.7, + ) + vector_store.seed( + "summary-1", + "Alice is tuning low-latency retrieval.", + {"user_id": "alice", "domain": "summary"}, + score=0.9, + ) + neo4j_client.seed_event( + user_id="alice", + date="05-11", + event_name="Latency review", + desc="Measured raw search latency", + year="2026", + similarity_score=0.8, + ) + model = FakeChatModel() + pipeline = RetrievalPipeline(model=model, vector_store=vector_store, neo4j_client=neo4j_client) + + results = await pipeline.search_raw( + "latency", + "alice", + ["profile", "temporal", "summary"], + top_k=5, + ) + + assert [record.score for record in results] == sorted( + [record.score for record in results], + reverse=True, + ) + assert {record.domain for record in results} == {"profile", "temporal", "summary"} + assert not pipeline.model_with_tools.calls + + +@pytest.mark.asyncio +async def test_answer_from_sources_skips_tool_selection(vector_store, neo4j_client): + model = FakeChatModel(responses=["Alice works at XMem."]) + pipeline = RetrievalPipeline(model=model, vector_store=vector_store, neo4j_client=neo4j_client) + + answer = await pipeline.answer_from_sources( + "Where does Alice work?", + [], + ) + + assert answer == "Alice works at XMem." 
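+    # answer_from_sources formats the sources into ANSWER_PROMPT and invokes
+    # the base model directly, so the tool-bound planner records no calls.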
+ assert not pipeline.model_with_tools.calls + + @pytest.mark.asyncio async def test_retrieval_tool_dispatch_handles_unknown_and_snippet(vector_store, neo4j_client): model = FakeChatModel() From 5dfca957ab49777e80c25c3759b0541835e60769 Mon Sep 17 00:00:00 2001 From: strongkeep-debug Date: Mon, 11 May 2026 11:51:44 -0700 Subject: [PATCH 2/5] Tighten low-latency retrieval path --- src/pipelines/retrieval.py | 197 ++++++++++++------- tests/integration/test_retrieval_pipeline.py | 196 +++++++++++++++--- 2 files changed, 301 insertions(+), 92 deletions(-) diff --git a/src/pipelines/retrieval.py b/src/pipelines/retrieval.py index 3fc7acd..08ac00e 100644 --- a/src/pipelines/retrieval.py +++ b/src/pipelines/retrieval.py @@ -21,8 +21,10 @@ from __future__ import annotations import asyncio +import hashlib import logging import time +from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional from dotenv import load_dotenv @@ -43,6 +45,8 @@ _CACHE_TTL_SECONDS = 60.0 +_PROFILE_CATALOG_CACHE_LIMIT = 256 +_RETRIEVAL_PLAN_CACHE_LIMIT = 512 _LATENCY_SAMPLE_LIMIT = 200 @@ -50,6 +54,7 @@ # Tool schemas — These are the "function signatures" exposed to the LLM # ═══════════════════════════════════════════════════════════════════════════ + class SearchProfile(BaseModel): """Look up user profile facts by topic. Use when the question asks about a specific attribute like job, name, hobby, food preference, etc. @@ -63,21 +68,28 @@ class SearchTemporal(BaseModel): """Search for date-based events like appointments, birthdays, milestones. Use when the question involves 'when', dates, schedules, or events.""" - query: str = Field(description="Short search query describing the event, e.g. 'dentist appointment'") + query: str = Field( + description="Short search query describing the event, e.g. 'dentist appointment'" + ) class SearchSummary(BaseModel): """Search general conversation summaries for broad context. Use as a fallback for questions that don't fit profile or temporal domains.""" - query: str = Field(description="Short search query, e.g. 'what does the user enjoy'") + query: str = Field( + description="Short search query, e.g. 'what does the user enjoy'" + ) class SearchSnippet(BaseModel): """Search for personal code snippets previously saved by the user. - Use when the question asks about a specific piece of code, script, or technical configuration the user wrote.""" + Use when the question asks about a specific piece of code, script, or technical configuration the user wrote. + """ - query: str = Field(description="Short search query, e.g. 'python database connection script'") + query: str = Field( + description="Short search query, e.g. 
'python database connection script'" + ) TOOLS = [SearchProfile, SearchTemporal, SearchSummary, SearchSnippet] @@ -87,8 +99,10 @@ class SearchSnippet(BaseModel): # Embedding helper (reuses the cached model from ingest) # ═══════════════════════════════════════════════════════════════════════════ + def _get_embed_fn() -> Callable[[str], List[float]]: from src.pipelines.ingest import embed_text + return embed_text @@ -96,6 +110,7 @@ def _get_embed_fn() -> Callable[[str], List[float]]: # RetrievalPipeline # ═══════════════════════════════════════════════════════════════════════════ + class RetrievalPipeline: """Two-step agentic retrieval: tool-call → fetch → answer.""" @@ -108,6 +123,7 @@ def __init__( # ── LLM ─────────────────────────────────────────────────────── if model is None: from src.models import get_model + override = settings.retrieval_model self.model = get_model(model_name=override) if override else get_model() else: @@ -137,8 +153,12 @@ def __init__( self.embed_fn = embed_fn self._snippet_stores: Dict[str, PineconeVectorStore] = {} - self._profile_catalog_cache: Dict[str, tuple[float, List[Dict[str, str]], List[Any]]] = {} - self._retrieval_plan_cache: Dict[tuple[str, str, int, str], tuple[float, AIMessage]] = {} + self._profile_catalog_cache: OrderedDict[ + str, tuple[float, List[Dict[str, str]], List[Any]] + ] = OrderedDict() + self._retrieval_plan_cache: OrderedDict[ + tuple[str, str, int, str], tuple[float, Any] + ] = OrderedDict() self._latency_samples: Dict[str, List[float]] = {} logger.info("RetrievalPipeline initialized") @@ -162,7 +182,7 @@ async def run( logger.info("=" * 60) # ── Step 0: Fetch available profile catalog for this user ───── - profile_catalog, profile_records = self._get_profile_catalog(user_id) + profile_catalog, profile_records = await self._get_profile_catalog(user_id) catalog_text = self._format_catalog(profile_catalog) logger.info("Available profiles: %s", catalog_text) @@ -176,12 +196,15 @@ async def run( HumanMessage(content=query), ] - plan_key = (user_id, query.strip(), top_k, catalog_text) + catalog_hash = hashlib.sha256(catalog_text.encode("utf-8")).hexdigest() + plan_key = (user_id, query.strip(), top_k, catalog_hash) ai_response = self._get_cached_retrieval_plan(plan_key) if ai_response is None: ai_response = await self.model_with_tools.ainvoke(messages) self._cache_retrieval_plan(plan_key, ai_response) - logger.info("LLM response received (tool_calls=%d)", len(ai_response.tool_calls or [])) + logger.info( + "LLM response received (tool_calls=%d)", len(ai_response.tool_calls or []) + ) # ── Step 2: Execute tool calls ──────────────────────────────── sources: List[SourceRecord] = [] @@ -196,11 +219,16 @@ async def _process_tool_call(tc): tool_id = tc["id"] logger.info(" Tool call: %s(%s)", tool_name, tool_args) records = await self._execute_tool( - tool_name, tool_args, user_id, top_k, + tool_name, + tool_args, + user_id, + top_k, ) return tool_name, tool_args, tool_id, records - tool_results = await asyncio.gather(*[_process_tool_call(tc) for tc in ai_response.tool_calls]) + tool_results = await asyncio.gather( + *[_process_tool_call(tc) for tc in ai_response.tool_calls] + ) for tool_name, tool_args, tool_id, records in tool_results: sources.extend(records) @@ -217,7 +245,9 @@ async def _process_tool_call(tc): if "searchsummary" not in called_tools: logger.info(" Auto-adding summary context (top_k=5)") extra = await self._search_summary( - query=query, user_id=user_id, top_k=20, + query=query, + user_id=user_id, + top_k=20, ) if extra: 
sources.extend(extra) @@ -255,7 +285,9 @@ async def _process_tool_call(tc): logger.info("=" * 60) logger.info("RETRIEVAL PIPELINE COMPLETE") logger.info(" sources: %d", len(sources)) - logger.info(" answer: %s", answer[:100] + "..." if len(answer) > 100 else answer) + logger.info( + " answer: %s", answer[:100] + "..." if len(answer) > 100 else answer + ) logger.info("=" * 60) return RetrievalResult( @@ -275,16 +307,23 @@ async def search_raw( """Return ranked memory hits without asking the LLM for a retrieval plan.""" domain_set = set(domains) - results: List[SourceRecord] = [] - + tasks = [] if "profile" in domain_set: - results.extend(await self._search_profile_raw(query, user_id, top_k)) + tasks.append(self._search_profile_raw(query, user_id, top_k)) if "temporal" in domain_set: - results.extend(await self._search_temporal(query, user_id, top_k)) + tasks.append(self._search_temporal(query, user_id, top_k)) if "summary" in domain_set: - results.extend(await self._search_summary(query, user_id, top_k)) + tasks.append(self._search_summary(query, user_id, top_k)) if "snippet" in domain_set: - results.extend(await self._search_snippet(query, user_id, top_k)) + tasks.append(self._search_snippet(query, user_id, top_k)) + + if not tasks: + return [] + + task_results = await asyncio.gather(*tasks) + results = [ + record for domain_results in task_results for record in domain_results + ] return sorted(results, key=lambda record: record.score, reverse=True) @@ -380,17 +419,19 @@ def _search_profile( parts = main_content.split("_", 1) sub_topic = parts[1] if len(parts) == 2 else "" - records.append(SourceRecord( - domain="profile", - content=r.content, - score=r.score, - metadata={ - "id": r.id, - "topic": topic, - "sub_topic": sub_topic, - **r.metadata, - }, - )) + records.append( + SourceRecord( + domain="profile", + content=r.content, + score=r.score, + metadata={ + "id": r.id, + "topic": topic, + "sub_topic": sub_topic, + **r.metadata, + }, + ) + ) logger.info(" → Profile [%s]: %d results", topic, len(records)) return records @@ -413,16 +454,18 @@ async def _search_profile_raw( ) except Exception as exc: logger.warning("Profile raw search failed, using cached catalog: %s", exc) - _, results = self._get_profile_catalog(user_id) + _, results = await self._get_profile_catalog(user_id) records = [] for r in results[:top_k]: - records.append(SourceRecord( - domain="profile", - content=r.content, - score=r.score, - metadata={"id": r.id, **r.metadata}, - )) + records.append( + SourceRecord( + domain="profile", + content=r.content, + score=r.score, + metadata={"id": r.id, **r.metadata}, + ) + ) logger.info("Profile raw [%s]: %d results", query, len(records)) return records @@ -445,7 +488,7 @@ async def _search_temporal( query_text=query, top_k=top_k, similarity_threshold=0.15, - ) + ), ) records = [] @@ -472,12 +515,14 @@ async def _search_temporal( content = " | ".join(content_parts) - records.append(SourceRecord( - domain="temporal", - content=content, - score=ev.get("similarity_score", 0.0), - metadata=ev, - )) + records.append( + SourceRecord( + domain="temporal", + content=content, + score=ev.get("similarity_score", 0.0), + metadata=ev, + ) + ) logger.info(" → Temporal [%s]: %d results", query, len(records)) return records @@ -503,12 +548,14 @@ async def _search_summary( records = [] for r in results: - records.append(SourceRecord( - domain="summary", - content=r.content, - score=r.score, - metadata={"id": r.id, **r.metadata}, - )) + records.append( + SourceRecord( + domain="summary", + 
content=r.content, + score=r.score, + metadata={"id": r.id, **r.metadata}, + ) + ) logger.info(" → Summary [%s]: %d results", query, len(records)) return records @@ -538,7 +585,7 @@ async def _search_snippet( ) -> List[SourceRecord]: """Semantic search over user code snippets (sandboxed namespace).""" store = self._get_snippet_store(user_id) - + # In the sandboxed namespace, we can just search. We pass domain filter just in case. results = await store.search_by_text( query_text=query, @@ -555,12 +602,14 @@ async def _search_snippet( lang = r.metadata.get("language", "") content += f"\n```{lang}\n{snippet}\n```" - records.append(SourceRecord( - domain="snippet", - content=content, - score=r.score, - metadata={"id": r.id, **r.metadata}, - )) + records.append( + SourceRecord( + domain="snippet", + content=content, + score=r.score, + metadata={"id": r.id, **r.metadata}, + ) + ) logger.info(" → Snippet [%s]: %d results", query, len(records)) return records @@ -569,18 +618,22 @@ async def _search_snippet( # Profile catalog (tells the LLM what profile keys exist) # ------------------------------------------------------------------ - def _get_profile_catalog(self, user_id: str): + async def _get_profile_catalog(self, user_id: str): cached = self._profile_catalog_cache.get(user_id) now = time.monotonic() if cached and cached[0] > now: + self._profile_catalog_cache.move_to_end(user_id) return cached[1], cached[2] + if cached: + self._profile_catalog_cache.pop(user_id, None) - catalog, results = self._fetch_profile_catalog(user_id) + catalog, results = await asyncio.to_thread(self._fetch_profile_catalog, user_id) self._profile_catalog_cache[user_id] = ( now + _CACHE_TTL_SECONDS, catalog, results, ) + self._trim_cache(self._profile_catalog_cache, _PROFILE_CATALOG_CACHE_LIMIT) return catalog, results def _fetch_profile_catalog(self, user_id: str): @@ -611,15 +664,19 @@ def _fetch_profile_catalog(self, user_id: str): parts = main_content.split("_", 1) if len(parts) == 2: - catalog.append({ - "topic": parts[0], - "sub_topic": parts[1], - }) + catalog.append( + { + "topic": parts[0], + "sub_topic": parts[1], + } + ) else: - catalog.append({ - "topic": main_content, - "sub_topic": "", - }) + catalog.append( + { + "topic": main_content, + "sub_topic": "", + } + ) return catalog, results @@ -646,6 +703,7 @@ def _get_cached_retrieval_plan( if expires_at <= time.monotonic(): self._retrieval_plan_cache.pop(key, None) return None + self._retrieval_plan_cache.move_to_end(key) return response def _cache_retrieval_plan( @@ -657,6 +715,11 @@ def _cache_retrieval_plan( time.monotonic() + _CACHE_TTL_SECONDS, response, ) + self._trim_cache(self._retrieval_plan_cache, _RETRIEVAL_PLAN_CACHE_LIMIT) + + def _trim_cache(self, cache: OrderedDict, limit: int) -> None: + while len(cache) > limit: + cache.popitem(last=False) def _coerce_answer(self, answer: Any) -> str: if isinstance(answer, list): diff --git a/tests/integration/test_retrieval_pipeline.py b/tests/integration/test_retrieval_pipeline.py index dcfc805..7671b79 100644 --- a/tests/integration/test_retrieval_pipeline.py +++ b/tests/integration/test_retrieval_pipeline.py @@ -1,13 +1,21 @@ from __future__ import annotations +import asyncio + import pytest -from src.pipelines.retrieval import RetrievalPipeline +from src.pipelines.retrieval import ( + _RETRIEVAL_PLAN_CACHE_LIMIT, + RetrievalPipeline, +) +from src.schemas.retrieval import SourceRecord from tests.conftest import FakeChatModel, FakeLLMResponse @pytest.mark.asyncio -async def 
test_retrieval_pipeline_executes_tool_calls_and_generates_answer(vector_store, neo4j_client): +async def test_retrieval_pipeline_executes_tool_calls_and_generates_answer( + vector_store, neo4j_client +): vector_store.seed( "profile-1", "work / company = XMem", @@ -28,24 +36,43 @@ async def test_retrieval_pipeline_executes_tool_calls_and_generates_answer(vecto ) model = FakeChatModel( tool_responses=[ - FakeLLMResponse("", tool_calls=[ - {"name": "search_profile", "args": {"topic": "work"}, "id": "call-profile"}, - {"name": "search_temporal", "args": {"query": "launch"}, "id": "call-event"}, - ]) + FakeLLMResponse( + "", + tool_calls=[ + { + "name": "search_profile", + "args": {"topic": "work"}, + "id": "call-profile", + }, + { + "name": "search_temporal", + "args": {"query": "launch"}, + "id": "call-event", + }, + ], + ) ], responses=["Alice works at XMem and has a launch on 05-11."], ) - pipeline = RetrievalPipeline(model=model, vector_store=vector_store, neo4j_client=neo4j_client) + pipeline = RetrievalPipeline( + model=model, vector_store=vector_store, neo4j_client=neo4j_client + ) result = await pipeline.run("Where does Alice work and what is upcoming?", "alice") assert "XMem" in result.answer - assert {source.domain for source in result.sources} == {"profile", "temporal", "summary"} + assert {source.domain for source in result.sources} == { + "profile", + "temporal", + "summary", + } assert result.confidence > 0.1 @pytest.mark.asyncio -async def test_retrieval_pipeline_caches_catalog_and_retrieval_plan(vector_store, neo4j_client): +async def test_retrieval_pipeline_caches_catalog_and_retrieval_plan( + vector_store, neo4j_client +): vector_store.seed( "profile-1", "work / company = XMem", @@ -53,13 +80,22 @@ async def test_retrieval_pipeline_caches_catalog_and_retrieval_plan(vector_store ) model = FakeChatModel( tool_responses=[ - FakeLLMResponse("", tool_calls=[ - {"name": "search_profile", "args": {"topic": "work"}, "id": "call-profile"}, - ]) + FakeLLMResponse( + "", + tool_calls=[ + { + "name": "search_profile", + "args": {"topic": "work"}, + "id": "call-profile", + }, + ], + ) ], responses=["Alice works at XMem.", "Alice still works at XMem."], ) - pipeline = RetrievalPipeline(model=model, vector_store=vector_store, neo4j_client=neo4j_client) + pipeline = RetrievalPipeline( + model=model, vector_store=vector_store, neo4j_client=neo4j_client + ) first = await pipeline.run("Where does Alice work?", "alice") second = await pipeline.run("Where does Alice work?", "alice") @@ -70,7 +106,9 @@ async def test_retrieval_pipeline_caches_catalog_and_retrieval_plan(vector_store @pytest.mark.asyncio -async def test_raw_search_returns_ranked_hits_without_tool_selection(vector_store, neo4j_client): +async def test_raw_search_returns_ranked_hits_without_tool_selection( + vector_store, neo4j_client +): vector_store.seed( "profile-1", "work / company = XMem", @@ -92,7 +130,9 @@ async def test_raw_search_returns_ranked_hits_without_tool_selection(vector_stor similarity_score=0.8, ) model = FakeChatModel() - pipeline = RetrievalPipeline(model=model, vector_store=vector_store, neo4j_client=neo4j_client) + pipeline = RetrievalPipeline( + model=model, vector_store=vector_store, neo4j_client=neo4j_client + ) results = await pipeline.search_raw( "latency", @@ -109,10 +149,100 @@ async def test_raw_search_returns_ranked_hits_without_tool_selection(vector_stor assert not pipeline.model_with_tools.calls +@pytest.mark.asyncio +async def test_raw_search_runs_requested_domains_concurrently( + vector_store, 
neo4j_client +): + model = FakeChatModel() + pipeline = RetrievalPipeline( + model=model, vector_store=vector_store, neo4j_client=neo4j_client + ) + started: list[str] = [] + release = asyncio.Event() + + async def fake_domain(name: str, score: float): + started.append(name) + if len(started) == 3: + release.set() + await asyncio.wait_for(release.wait(), timeout=0.1) + return [SourceRecord(domain=name, content=f"{name} hit", score=score)] + + pipeline._search_profile_raw = lambda *_args: fake_domain("profile", 0.7) + pipeline._search_temporal = lambda *_args: fake_domain("temporal", 0.8) + pipeline._search_summary = lambda *_args: fake_domain("summary", 0.9) + + results = await pipeline.search_raw( + "latency", + "alice", + ["profile", "temporal", "summary"], + top_k=5, + ) + + assert set(started) == {"profile", "temporal", "summary"} + assert [record.domain for record in results] == ["summary", "temporal", "profile"] + + +@pytest.mark.asyncio +async def test_profile_catalog_fetch_does_not_block_event_loop( + vector_store, neo4j_client +): + import threading + + class BlockingVectorStore: + def search_by_metadata(self, filters, top_k=10): + threading.Event().wait(0.05) + return vector_store.search_by_metadata(filters, top_k=top_k) + + async def search_by_text(self, *args, **kwargs): + return await vector_store.search_by_text(*args, **kwargs) + + vector_store.seed( + "profile-1", + "work / company = XMem", + {"user_id": "alice", "domain": "profile", "main_content": "work_company"}, + ) + model = FakeChatModel() + pipeline = RetrievalPipeline( + model=model, + vector_store=BlockingVectorStore(), + neo4j_client=neo4j_client, + ) + ticks: list[str] = [] + + async def ticker(): + await asyncio.sleep(0.01) + ticks.append("tick") + + tick_task = asyncio.create_task(ticker()) + catalog, records = await pipeline._get_profile_catalog("alice") + await tick_task + + assert ticks == ["tick"] + assert catalog == [{"topic": "work", "sub_topic": "company"}] + assert len(records) == 1 + + +def test_retrieval_plan_cache_evicts_oldest_entry(vector_store, neo4j_client): + model = FakeChatModel() + pipeline = RetrievalPipeline( + model=model, vector_store=vector_store, neo4j_client=neo4j_client + ) + first_key = ("alice", "query-0", 5, "catalog-0") + + for index in range(_RETRIEVAL_PLAN_CACHE_LIMIT + 1): + key = ("alice", f"query-{index}", 5, f"catalog-{index}") + pipeline._cache_retrieval_plan(key, FakeLLMResponse(f"response-{index}")) + + assert len(pipeline._retrieval_plan_cache) == _RETRIEVAL_PLAN_CACHE_LIMIT + assert pipeline._get_cached_retrieval_plan(first_key) is None + + @pytest.mark.asyncio async def test_answer_from_sources_skips_tool_selection(vector_store, neo4j_client): model = FakeChatModel(responses=["Alice works at XMem."]) - pipeline = RetrievalPipeline(model=model, vector_store=vector_store, neo4j_client=neo4j_client) + pipeline = RetrievalPipeline( + model=model, vector_store=vector_store, neo4j_client=neo4j_client + ) answer = await pipeline.answer_from_sources( "Where does Alice work?", @@ -124,22 +254,38 @@ async def test_answer_from_sources_skips_tool_selection(vector_store, neo4j_clie @pytest.mark.asyncio -async def test_retrieval_tool_dispatch_handles_unknown_and_snippet(vector_store, neo4j_client): +async def test_retrieval_tool_dispatch_handles_unknown_and_snippet( + vector_store, neo4j_client +): model = FakeChatModel() - pipeline = RetrievalPipeline(model=model, vector_store=vector_store, neo4j_client=neo4j_client) + pipeline = RetrievalPipeline( + model=model, 
vector_store=vector_store, neo4j_client=neo4j_client + ) + class SnippetStore: async def search_by_text(self, **kwargs): - return [type("R", (), { - "id": "snip-1", - "content": "Binary search helper", - "score": 0.9, - "metadata": {"code_snippet": "def bs(): pass", "language": "python"}, - })()] + return [ + type( + "R", + (), + { + "id": "snip-1", + "content": "Binary search helper", + "score": 0.9, + "metadata": { + "code_snippet": "def bs(): pass", + "language": "python", + }, + }, + )() + ] snippet_store = SnippetStore() pipeline._snippet_stores["user-1"] = snippet_store assert await pipeline._execute_tool("missing_tool", {}, "user-1", 5) == [] - snippets = await pipeline._execute_tool("SearchSnippet", {"query": "binary search"}, "user-1", 5) + snippets = await pipeline._execute_tool( + "SearchSnippet", {"query": "binary search"}, "user-1", 5 + ) assert snippets[0].domain == "snippet" assert "def bs" in snippets[0].content From c06c5fea2d5a02f4fa8e351005cf7799b52f084c Mon Sep 17 00:00:00 2001 From: strongkeep-debug Date: Mon, 11 May 2026 12:50:05 -0700 Subject: [PATCH 3/5] Include code annotations in raw search --- src/api/schemas.py | 4 +- src/pipelines/retrieval.py | 49 +++++++++++++++ src/schemas/retrieval.py | 2 +- tests/api/test_memory_search_routes.py | 63 +++++++++++++++++--- tests/integration/test_retrieval_pipeline.py | 26 +++++++- 5 files changed, 132 insertions(+), 12 deletions(-) diff --git a/src/api/schemas.py b/src/api/schemas.py index c701616..9379bd5 100644 --- a/src/api/schemas.py +++ b/src/api/schemas.py @@ -158,7 +158,7 @@ class SearchRequest(BaseModel): ..., min_length=1, max_length=256, pattern=r"^[\w.\-@]+$", ) domains: List[str] = Field( - default=["profile", "temporal", "summary", "snippet"], + default=["profile", "temporal", "summary", "snippet", "code"], description="Which memory domains to search", ) top_k: int = Field(default=10, ge=1, le=100) @@ -170,7 +170,7 @@ class SearchRequest(BaseModel): @field_validator("domains") @classmethod def validate_domains(cls, v: List[str]) -> List[str]: - allowed = {"profile", "temporal", "summary", "snippet"} + allowed = {"profile", "temporal", "summary", "snippet", "code"} for d in v: if d not in allowed: raise ValueError(f"Invalid domain '{d}'. 
Allowed: {allowed}") diff --git a/src/pipelines/retrieval.py b/src/pipelines/retrieval.py index 08ac00e..8327675 100644 --- a/src/pipelines/retrieval.py +++ b/src/pipelines/retrieval.py @@ -316,6 +316,8 @@ async def search_raw( tasks.append(self._search_summary(query, user_id, top_k)) if "snippet" in domain_set: tasks.append(self._search_snippet(query, user_id, top_k)) + if "code" in domain_set: + tasks.append(self._search_code(query, user_id, top_k)) if not tasks: return [] @@ -560,6 +562,53 @@ async def _search_summary( logger.info(" → Summary [%s]: %d results", query, len(records)) return records + # -- Code: Pinecone semantic search -------------------------------- + + async def _search_code( + self, + query: str, + user_id: str, + top_k: int = 10, + ) -> List[SourceRecord]: + """Semantic search over stored code annotations.""" + + results = await self.vector_store.search_by_text( + query_text=query, + top_k=top_k, + filters={ + "user_id": user_id, + "domain": "code", + }, + ) + + records = [] + for r in results: + metadata = dict(r.metadata) + detail_parts = [] + for label, key in ( + ("repo", "repo"), + ("file", "target_file"), + ("symbol", "target_symbol"), + ("type", "annotation_type"), + ("severity", "severity"), + ): + value = metadata.get(key) + if value: + detail_parts.append(f"{label}={value}") + + prefix = f"[{'; '.join(detail_parts)}] " if detail_parts else "" + records.append( + SourceRecord( + domain="code", + content=f"{prefix}{r.content}", + score=r.score, + metadata={"id": r.id, **metadata}, + ) + ) + + logger.info(" -> Code [%s]: %d results", query, len(records)) + return records + # -- Snippet: Pinecone semantic search ───────────────────────────── def _get_snippet_store(self, user_id: str) -> PineconeVectorStore: diff --git a/src/schemas/retrieval.py b/src/schemas/retrieval.py index 3b0bb5e..a9ce819 100644 --- a/src/schemas/retrieval.py +++ b/src/schemas/retrieval.py @@ -26,7 +26,7 @@ def source_count(self) -> int: class SourceRecord: """A single piece of evidence fetched from a data store.""" - domain: str # "profile", "temporal", "summary", "snippet" + domain: str # "profile", "temporal", "summary", "snippet", "code" content: str # the actual text score: float = 0.0 # similarity score (if applicable) metadata: Dict[str, Any] = field(default_factory=dict) diff --git a/tests/api/test_memory_search_routes.py b/tests/api/test_memory_search_routes.py index 3c7221c..d05500e 100644 --- a/tests/api/test_memory_search_routes.py +++ b/tests/api/test_memory_search_routes.py @@ -18,17 +18,38 @@ class FakeSearchPipeline: def __init__(self) -> None: self.answer_calls = 0 + self.search_calls: list[dict[str, object]] = [] self.latencies: dict[str, list[float]] = {} - async def search_raw(self, query: str, user_id: str, domains: list[str], top_k: int): + async def search_raw( + self, query: str, user_id: str, domains: list[str], top_k: int + ): assert query == "latency" assert user_id == "Static Key User" - assert domains == ["profile", "summary"] assert top_k == 3 - return [ - SourceRecord(domain="summary", content="Low-latency summary", score=0.9), - SourceRecord(domain="profile", content="work / company = XMem", score=0.7), - ] + self.search_calls.append( + {"query": query, "user_id": user_id, "domains": domains, "top_k": top_k} + ) + + fixtures = { + "summary": SourceRecord( + domain="summary", + content="Low-latency summary", + score=0.9, + ), + "profile": SourceRecord( + domain="profile", + content="work / company = XMem", + score=0.7, + ), + "code": SourceRecord( + 
domain="code", + content="[file=src/retry.py; symbol=RetryLoop] Timeout retry note", + score=0.8, + metadata={"target_file": "src/retry.py"}, + ), + } + return [fixtures[domain] for domain in domains if domain in fixtures] async def answer_from_sources(self, query: str, sources: list[SourceRecord]) -> str: self.answer_calls += 1 @@ -39,7 +60,12 @@ def record_latency(self, mode: str, elapsed_ms: float) -> None: def get_latency_snapshot(self): return { - mode: {"count": len(samples), "p50": samples[-1], "p95": samples[-1], "p99": samples[-1]} + mode: { + "count": len(samples), + "p50": samples[-1], + "p95": samples[-1], + "p99": samples[-1], + } for mode, samples in self.latencies.items() } @@ -78,6 +104,7 @@ def test_memory_search_route_returns_raw_hits_without_answer(memory_search_app): assert payload["data"]["total"] == 2 assert payload["data"]["answer"] == "" assert payload["data"]["latency"]["raw"]["count"] == 1 + assert pipeline.search_calls[0]["domains"] == ["profile", "summary"] assert pipeline.answer_calls == 0 @@ -102,3 +129,25 @@ def test_root_search_alias_can_synthesize_answer(memory_search_app): assert payload["data"]["model"] == "fake-retrieval" assert payload["data"]["latency"]["answer"]["count"] == 1 assert pipeline.answer_calls == 1 + + +def test_memory_search_route_accepts_code_domain(memory_search_app): + app, pipeline = memory_search_app + response = TestClient(app).post( + "/v1/memory/search", + headers={"Authorization": "Bearer test-static-key"}, + json={ + "query": "latency", + "user_id": "ignored-by-auth", + "domains": ["code"], + "top_k": 3, + }, + ) + + payload = response.json() + + assert response.status_code == 200 + assert payload["data"]["total"] == 1 + assert payload["data"]["results"][0]["domain"] == "code" + assert payload["data"]["results"][0]["metadata"]["target_file"] == "src/retry.py" + assert pipeline.search_calls[0]["domains"] == ["code"] diff --git a/tests/integration/test_retrieval_pipeline.py b/tests/integration/test_retrieval_pipeline.py index 7671b79..aff7fc4 100644 --- a/tests/integration/test_retrieval_pipeline.py +++ b/tests/integration/test_retrieval_pipeline.py @@ -121,6 +121,20 @@ async def test_raw_search_returns_ranked_hits_without_tool_selection( {"user_id": "alice", "domain": "summary"}, score=0.9, ) + vector_store.seed( + "code-1", + "RetryLoop can spin when the first retrieval attempt times out.", + { + "user_id": "alice", + "domain": "code", + "annotation_type": "bug_report", + "target_symbol": "RetryLoop", + "target_file": "src/retry.py", + "repo": "xmem", + "severity": "high", + }, + score=0.95, + ) neo4j_client.seed_event( user_id="alice", date="05-11", @@ -137,7 +151,7 @@ async def test_raw_search_returns_ranked_hits_without_tool_selection( results = await pipeline.search_raw( "latency", "alice", - ["profile", "temporal", "summary"], + ["profile", "temporal", "summary", "code"], top_k=5, ) @@ -145,7 +159,15 @@ async def test_raw_search_returns_ranked_hits_without_tool_selection( [record.score for record in results], reverse=True, ) - assert {record.domain for record in results} == {"profile", "temporal", "summary"} + assert {record.domain for record in results} == { + "profile", + "temporal", + "summary", + "code", + } + code_hit = next(record for record in results if record.domain == "code") + assert "file=src/retry.py" in code_hit.content + assert code_hit.metadata["target_symbol"] == "RetryLoop" assert not pipeline.model_with_tools.calls From 5548b63e63fa61e1bbf5d550dc28ecf825b152f1 Mon Sep 17 00:00:00 2001 From: 
strongkeep-debug Date: Tue, 12 May 2026 02:48:35 -0700 Subject: [PATCH 4/5] Harden raw search score handling --- src/api/routes/memory.py | 22 ++++++++---- src/pipelines/retrieval.py | 38 +++++++++++++++----- tests/api/test_memory_search_routes.py | 3 +- tests/integration/test_retrieval_pipeline.py | 29 +++++++++++++++ 4 files changed, 75 insertions(+), 17 deletions(-) diff --git a/src/api/routes/memory.py b/src/api/routes/memory.py index 88ba771..966a736 100644 --- a/src/api/routes/memory.py +++ b/src/api/routes/memory.py @@ -8,6 +8,8 @@ import asyncio import logging +import math +import threading import time from typing import Any, Dict, List @@ -113,6 +115,14 @@ def _error(request: Request, detail: str, code: int, elapsed_ms: float = 0) -> J return JSONResponse(content=body.model_dump(), status_code=code) +def _safe_score(score: Any) -> float: + try: + value = float(score) + except (TypeError, ValueError): + return 0.0 + return value if math.isfinite(value) else 0.0 + + def _detect_chat_provider(*urls: str) -> str: for url in urls: lowered = (url or "").lower() @@ -150,8 +160,6 @@ async def _render_chat_share(url: str) -> tuple[str, str]: # reuse it across scrape requests. The browser is thread-safe when each # request uses its own BrowserContext. -import threading - _browser_lock = threading.Lock() _pw_instance = None _browser_instance = None @@ -665,7 +673,7 @@ async def retrieve_memory(req: RetrieveRequest, request: Request, user: dict = D sources=[ SourceRecord( domain=s.domain, content=s.content, - score=round(s.score, 3), metadata=s.metadata, + score=round(_safe_score(s.score), 3), metadata=s.metadata, ) for s in result.sources ], @@ -717,7 +725,7 @@ async def search_memory(req: SearchRequest, request: Request, user: dict = Depen SourceRecord( domain=s.domain, content=s.content, - score=round(s.score, 3), + score=round(_safe_score(s.score), 3), metadata=s.metadata, ) for s in all_results @@ -741,7 +749,7 @@ def _search_profile(pipeline: RetrievalPipeline, user_id: str) -> List[SourceRec raw = pipeline.vector_store.search_by_metadata( filters={"user_id": user_id, "domain": "profile"}, top_k=100, ) - return [SourceRecord(domain="profile", content=r.content, score=r.score, metadata=r.metadata) for r in raw] + return [SourceRecord(domain="profile", content=r.content, score=_safe_score(r.score), metadata=r.metadata) for r in raw] except Exception as exc: logger.warning("Profile search error: %s", exc) return [] @@ -768,7 +776,7 @@ def _search_temporal(pipeline: RetrievalPipeline, query: str, user_id: str, top_ parts.append(f"Time: {ev['time']}") results.append(SourceRecord( domain="temporal", content=" | ".join(parts), - score=ev.get("similarity_score", 0.0), metadata=ev, + score=_safe_score(ev.get("similarity_score", 0.0)), metadata=ev, )) return results except Exception as exc: @@ -783,7 +791,7 @@ async def _search_summary(pipeline: RetrievalPipeline, query: str, user_id: str, filters={"user_id": user_id, "domain": "summary"}, ) return [ - SourceRecord(domain="summary", content=r.content, score=r.score, metadata={"id": r.id, **r.metadata}) + SourceRecord(domain="summary", content=r.content, score=_safe_score(r.score), metadata={"id": r.id, **r.metadata}) for r in raw ] except Exception as exc: diff --git a/src/pipelines/retrieval.py b/src/pipelines/retrieval.py index 8327675..0065402 100644 --- a/src/pipelines/retrieval.py +++ b/src/pipelines/retrieval.py @@ -23,6 +23,7 @@ import asyncio import hashlib import logging +import math import time from collections import OrderedDict from 
typing import Any, Callable, Dict, List, Optional @@ -322,10 +323,15 @@ async def search_raw( if not tasks: return [] - task_results = await asyncio.gather(*tasks) + task_results = await asyncio.gather(*tasks, return_exceptions=True) results = [ - record for domain_results in task_results for record in domain_results + record + for domain_results in task_results + if not self._log_search_error(domain_results) + for record in domain_results ] + for record in results: + record.score = self._score_value(record.score) return sorted(results, key=lambda record: record.score, reverse=True) @@ -425,7 +431,7 @@ def _search_profile( SourceRecord( domain="profile", content=r.content, - score=r.score, + score=self._score_value(r.score), metadata={ "id": r.id, "topic": topic, @@ -464,7 +470,7 @@ async def _search_profile_raw( SourceRecord( domain="profile", content=r.content, - score=r.score, + score=self._score_value(r.score), metadata={"id": r.id, **r.metadata}, ) ) @@ -521,7 +527,7 @@ async def _search_temporal( SourceRecord( domain="temporal", content=content, - score=ev.get("similarity_score", 0.0), + score=self._score_value(ev.get("similarity_score", 0.0)), metadata=ev, ) ) @@ -554,7 +560,7 @@ async def _search_summary( SourceRecord( domain="summary", content=r.content, - score=r.score, + score=self._score_value(r.score), metadata={"id": r.id, **r.metadata}, ) ) @@ -601,7 +607,7 @@ async def _search_code( SourceRecord( domain="code", content=f"{prefix}{r.content}", - score=r.score, + score=self._score_value(r.score), metadata={"id": r.id, **metadata}, ) ) @@ -655,7 +661,7 @@ async def _search_snippet( SourceRecord( domain="snippet", content=content, - score=r.score, + score=self._score_value(r.score), metadata={"id": r.id, **r.metadata}, ) ) @@ -770,6 +776,19 @@ def _trim_cache(self, cache: OrderedDict, limit: int) -> None: while len(cache) > limit: cache.popitem(last=False) + def _log_search_error(self, domain_results: Any) -> bool: + if isinstance(domain_results, Exception): + logger.warning("Raw search domain failed: %s", domain_results) + return True + return False + + def _score_value(self, score: Any) -> float: + try: + value = float(score) + except (TypeError, ValueError): + return 0.0 + return value if math.isfinite(value) else 0.0 + def _coerce_answer(self, answer: Any) -> str: if isinstance(answer, list): parts = [] @@ -810,7 +829,8 @@ def _format_tool_results(self, records: List[SourceRecord]) -> str: lines = [] for i, rec in enumerate(records, 1): - score_str = f" (score: {rec.score:.2f})" if rec.score > 0 else "" + score = self._score_value(rec.score) + score_str = f" (score: {score:.2f})" if score > 0 else "" lines.append(f"{i}. 
[{rec.domain}]{score_str} {rec.content}") return "\n".join(lines) diff --git a/tests/api/test_memory_search_routes.py b/tests/api/test_memory_search_routes.py index d05500e..e674463 100644 --- a/tests/api/test_memory_search_routes.py +++ b/tests/api/test_memory_search_routes.py @@ -40,7 +40,7 @@ async def search_raw( "profile": SourceRecord( domain="profile", content="work / company = XMem", - score=0.7, + score=None, ), "code": SourceRecord( domain="code", @@ -102,6 +102,7 @@ def test_memory_search_route_returns_raw_hits_without_answer(memory_search_app): assert response.status_code == 200 assert payload["data"]["total"] == 2 + assert payload["data"]["results"][0]["score"] == 0.0 assert payload["data"]["answer"] == "" assert payload["data"]["latency"]["raw"]["count"] == 1 assert pipeline.search_calls[0]["domains"] == ["profile", "summary"] diff --git a/tests/integration/test_retrieval_pipeline.py b/tests/integration/test_retrieval_pipeline.py index aff7fc4..3f40d7d 100644 --- a/tests/integration/test_retrieval_pipeline.py +++ b/tests/integration/test_retrieval_pipeline.py @@ -204,6 +204,35 @@ async def fake_domain(name: str, score: float): assert [record.domain for record in results] == ["summary", "temporal", "profile"] +@pytest.mark.asyncio +async def test_raw_search_skips_failed_domains_and_normalizes_scores( + vector_store, neo4j_client +): + model = FakeChatModel() + pipeline = RetrievalPipeline( + model=model, vector_store=vector_store, neo4j_client=neo4j_client + ) + + async def profile_domain(*_args): + return [SourceRecord(domain="profile", content="No backend score", score=None)] + + async def summary_domain(*_args): + raise RuntimeError("summary backend offline") + + pipeline._search_profile_raw = profile_domain + pipeline._search_summary = summary_domain + + results = await pipeline.search_raw( + "latency", + "alice", + ["profile", "summary"], + top_k=5, + ) + + assert [(record.domain, record.score) for record in results] == [("profile", 0.0)] + assert pipeline._format_tool_results(results) == "1. [profile] No backend score" + + @pytest.mark.asyncio async def test_profile_catalog_fetch_does_not_block_event_loop( vector_store, neo4j_client From a3f6429912326b470597236acbb8b72973376878 Mon Sep 17 00:00:00 2001 From: strongkeep-debug Date: Wed, 13 May 2026 08:04:28 -0700 Subject: [PATCH 5/5] Add raw answer missing-score regression --- tests/integration/test_retrieval_pipeline.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/integration/test_retrieval_pipeline.py b/tests/integration/test_retrieval_pipeline.py index 3f40d7d..a5fc439 100644 --- a/tests/integration/test_retrieval_pipeline.py +++ b/tests/integration/test_retrieval_pipeline.py @@ -304,6 +304,25 @@ async def test_answer_from_sources_skips_tool_selection(vector_store, neo4j_clie assert not pipeline.model_with_tools.calls +@pytest.mark.asyncio +async def test_answer_from_sources_handles_missing_source_scores( + vector_store, neo4j_client +): + model = FakeChatModel(responses=["Alice works at XMem."]) + pipeline = RetrievalPipeline( + model=model, vector_store=vector_store, neo4j_client=neo4j_client + ) + + answer = await pipeline.answer_from_sources( + "Where does Alice work?", + [SourceRecord(domain="profile", content="work / company = XMem", score=None)], + ) + + assert answer == "Alice works at XMem." 
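+    # A None score is normalized to 0.0 by _score_value, and
+    # _format_tool_results only renders "(score: ...)" for positive scores,
+    # so the synthesized prompt must stay score-free.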
+ assert "score:" not in model.calls[0][0].content + assert not pipeline.model_with_tools.calls + + @pytest.mark.asyncio async def test_retrieval_tool_dispatch_handles_unknown_and_snippet( vector_store, neo4j_client
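
Usage sketch for the new raw-search path (illustrative only; the host,
API key, and user id below are placeholders, and httpx is just one
convenient client):

    import httpx

    resp = httpx.post(
        "http://localhost:8000/v1/memory/search",
        headers={"Authorization": "Bearer <api-key>"},
        json={
            "query": "latency",
            # user_id must match ^[\w.\-@]+$; identity is derived from the key
            "user_id": "alice",
            "domains": ["profile", "summary", "code"],
            "top_k": 5,
            # set True to synthesize an answer without agentic tool selection
            "answer": False,
        },
    )
    data = resp.json()["data"]
    print(data["total"], data["latency"].get("raw"))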