From bb08f0752527af3e4ac9e21165ddf2eb9f9cad8f Mon Sep 17 00:00:00 2001 From: Michael Schvarcz Date: Mon, 11 May 2026 11:37:50 -0700 Subject: [PATCH 1/3] Add raw memory search fast path --- src/api/routes/memory.py | 80 +++++- src/api/schemas.py | 28 +- src/pipelines/retrieval.py | 279 +++++++++++++++++-- tests/integration/test_retrieval_pipeline.py | 100 +++++++ 4 files changed, 450 insertions(+), 37 deletions(-) diff --git a/src/api/routes/memory.py b/src/api/routes/memory.py index b4be36d..53ae411 100644 --- a/src/api/routes/memory.py +++ b/src/api/routes/memory.py @@ -16,6 +16,7 @@ from src.api.dependencies import ( enforce_rate_limit, + get_code_pipeline, get_ingest_pipeline, get_retrieval_pipeline, require_api_key, @@ -689,16 +690,77 @@ async def search_memory(req: SearchRequest, request: Request, user: dict = Depen user_id = user.get("username") or user.get("name") or user["id"] try: - all_results: List[SourceRecord] = [] - - if "profile" in req.domains: - all_results.extend(_search_profile(pipeline, user_id)) - if "temporal" in req.domains: - all_results.extend(_search_temporal(pipeline, req.query, user_id, req.top_k)) - if "summary" in req.domains: - all_results.extend(await _search_summary(pipeline, req.query, user_id, req.top_k)) + if "code" in req.domains and not req.org_id: + elapsed = round((time.perf_counter() - start) * 1000, 2) + return _error(request, "org_id is required when searching the code domain.", 400, elapsed) + + memory_domains = [domain for domain in req.domains if domain != "code"] + result = await pipeline.raw_search( + query=req.query, + user_id=user_id, + domains=memory_domains, + top_k=req.top_k, + include_answer=False, + ) + records = list(result.sources) + + if "code" in req.domains: + code_pipeline = get_code_pipeline(org_id=req.org_id or "", repo=req.repo) + code_results = await asyncio.gather( + code_pipeline._execute_tool( + tool_name="search_symbols", + tool_args={"query": req.query, "repo": req.repo}, + repo=req.repo, + top_k=req.top_k, + user_id=user_id, + ), + code_pipeline._execute_tool( + tool_name="search_files", + tool_args={"query": req.query, "repo": req.repo}, + repo=req.repo, + top_k=req.top_k, + user_id=user_id, + ), + return_exceptions=True, + ) + for code_records in code_results: + if isinstance(code_records, Exception): + logger.warning("Code search subquery failed: %s", code_records) + continue + records.extend(code_records) + + records = sorted(records, key=lambda s: s.score or 0.0, reverse=True) + + answer = "" + if req.answer: + answer = await pipeline.answer_from_sources(query=req.query, sources=records) + pipeline._record_latency( + "raw_search_answer", + (time.perf_counter() - start) * 1000, + ) + elif "code" in req.domains: + pipeline._record_latency( + "raw_search_code", + (time.perf_counter() - start) * 1000, + ) - data = SearchResponse(results=all_results, total=len(all_results)) + confidence = pipeline.confidence_from_sources(records) + data = SearchResponse( + results=[ + SourceRecord( + domain=s.domain, + content=s.content, + score=round(s.score, 3) if s.score is not None else 0.0, + metadata=s.metadata, + ) + for s in records + ], + total=len(records), + answer=answer, + model=_model_name(pipeline.model) if req.answer else "", + confidence=confidence, + latency=pipeline.latency_snapshot(), + ) elapsed = round((time.perf_counter() - start) * 1000, 2) return _wrap(request, data, elapsed) diff --git a/src/api/schemas.py b/src/api/schemas.py index b7ee122..0794f6d 100644 --- a/src/api/schemas.py +++ b/src/api/schemas.py @@ 
-159,24 +159,48 @@ class SearchRequest(BaseModel): ..., min_length=1, max_length=256, pattern=r"^[\w.\-@]+$", ) domains: List[str] = Field( - default=["profile", "temporal", "summary"], + default=["profile", "temporal", "summary", "snippet"], description="Which memory domains to search", ) top_k: int = Field(default=10, ge=1, le=100) + answer: bool = Field( + default=False, + description="When true, synthesize an LLM answer from the raw hits.", + ) + org_id: Optional[str] = Field( + default=None, + min_length=1, + max_length=256, + description="Required when including the code domain.", + ) + repo: str = Field( + default="", + max_length=256, + description="Optional repository scope for code search.", + ) @field_validator("domains") @classmethod def validate_domains(cls, v: List[str]) -> List[str]: - allowed = {"profile", "temporal", "summary"} + allowed = {"profile", "temporal", "summary", "snippet", "code"} for d in v: if d not in allowed: raise ValueError(f"Invalid domain '{d}'. Allowed: {allowed}") return v + @field_validator("query") + @classmethod + def strip_search_query(cls, v: str) -> str: + return v.strip() + class SearchResponse(BaseModel): results: List[SourceRecord] = Field(default_factory=list) total: int = 0 + answer: str = "" + model: str = "" + confidence: float = 0.0 + latency: Dict[str, Dict[str, float]] = Field(default_factory=dict) # ── Scrape (extract from shared chat links) ──────────────────────────────── diff --git a/src/pipelines/retrieval.py b/src/pipelines/retrieval.py index 3516561..e498017 100644 --- a/src/pipelines/retrieval.py +++ b/src/pipelines/retrieval.py @@ -22,7 +22,9 @@ import asyncio import logging -from typing import Any, Callable, Dict, List, Optional +import time +from collections import OrderedDict, defaultdict, deque +from typing import Any, Callable, Deque, Dict, List, Optional, Tuple from dotenv import load_dotenv from langchain_core.language_models import BaseChatModel @@ -40,6 +42,45 @@ logger = logging.getLogger("xmem.pipelines.retrieval") +CONFIDENCE_PER_SOURCE = 0.2 +PROFILE_CATALOG_CACHE_MAX = 512 +RETRIEVAL_PLAN_CACHE_MAX = 1024 + + +class RetrievalLatencyTracker: + """Bounded in-memory latency samples for retrieval modes.""" + + def __init__(self, max_samples: int = 512) -> None: + self._samples: Dict[str, Deque[float]] = defaultdict( + lambda: deque(maxlen=max_samples) + ) + + def record(self, mode: str, elapsed_ms: float) -> None: + self._samples[mode].append(elapsed_ms) + + def snapshot(self) -> Dict[str, Dict[str, float]]: + return { + mode: self._percentiles(samples) + for mode, samples in self._samples.items() + if samples + } + + @staticmethod + def _percentiles(samples: Deque[float]) -> Dict[str, float]: + ordered = sorted(samples) + count = len(ordered) + + def percentile(pct: float) -> float: + index = min(count - 1, max(0, int(round((pct / 100) * (count - 1))))) + return round(ordered[index], 2) + + return { + "count": count, + "p50_ms": percentile(50), + "p95_ms": percentile(95), + "p99_ms": percentile(99), + } + # ═══════════════════════════════════════════════════════════════════════════ # Tool schemas — These are the "function signatures" exposed to the LLM @@ -132,6 +173,14 @@ def __init__( self.embed_fn = embed_fn self._snippet_stores: Dict[str, PineconeVectorStore] = {} + self._profile_catalog_cache: OrderedDict[ + str, Tuple[float, List[Dict[str, str]], List[Any]] + ] = OrderedDict() + self._retrieval_plan_cache: OrderedDict[ + str, List[Dict[str, Any]] + ] = OrderedDict() + self._cache_ttl_s = 300.0 + self._latency 
= RetrievalLatencyTracker() logger.info("RetrievalPipeline initialized") @@ -146,6 +195,7 @@ async def run( top_k: int = 5, ) -> RetrievalResult: """Run the two-step retrieval pipeline.""" + mode_start = time.perf_counter() logger.info("=" * 60) logger.info("RETRIEVAL PIPELINE START") @@ -154,28 +204,42 @@ async def run( logger.info("=" * 60) # ── Step 0: Fetch available profile catalog for this user ───── - profile_catalog, profile_records = self._fetch_profile_catalog(user_id) + profile_catalog, _ = self._fetch_profile_catalog(user_id) catalog_text = self._format_catalog(profile_catalog) logger.info("Available profiles: %s", catalog_text) - # Store profile records so _search_profile can reuse them - self._cached_profile_records = profile_records - # ── Step 1: Ask LLM what to fetch (tool calls) ──────────────── system_prompt = build_system_prompt(profile_catalog=catalog_text) - messages = [ - SystemMessage(content=system_prompt), - HumanMessage(content=query), - ] + plan_key = self._plan_cache_key(user_id, query, top_k, catalog_text) + cached_tool_calls = self._get_cached_retrieval_plan(plan_key) + ai_response: Optional[AIMessage] = None - ai_response: AIMessage = await self.model_with_tools.ainvoke(messages) - logger.info("LLM response received (tool_calls=%d)", len(ai_response.tool_calls or [])) + if cached_tool_calls is not None: + tool_calls = cached_tool_calls + logger.info("Using cached retrieval plan (tool_calls=%d)", len(tool_calls)) + else: + messages = [ + SystemMessage(content=system_prompt), + HumanMessage(content=query), + ] + ai_response = await self.model_with_tools.ainvoke(messages) + tool_calls = list(ai_response.tool_calls or []) + logger.info("LLM response received (tool_calls=%d)", len(tool_calls)) + if tool_calls: + self._cache_retrieval_plan(plan_key, [ + { + "name": tc["name"], + "args": dict(tc.get("args") or {}), + "id": f"cached-{idx}", + } + for idx, tc in enumerate(tool_calls) + ]) # ── Step 2: Execute tool calls ──────────────────────────────── sources: List[SourceRecord] = [] tool_messages: List[ToolMessage] = [] - if ai_response.tool_calls: + if tool_calls: called_tools = set() async def _process_tool_call(tc): @@ -188,7 +252,7 @@ async def _process_tool_call(tc): ) return tool_name, tool_args, tool_id, records - tool_results = await asyncio.gather(*[_process_tool_call(tc) for tc in ai_response.tool_calls]) + tool_results = await asyncio.gather(*[_process_tool_call(tc) for tc in tool_calls]) for tool_name, tool_args, tool_id, records in tool_results: sources.extend(records) @@ -214,7 +278,7 @@ async def _process_tool_call(tc): tool_messages.append( ToolMessage( content=f"[Auto-fetched summary context]\n{extra_text}", - tool_call_id=ai_response.tool_calls[-1]["id"], + tool_call_id=tool_calls[-1]["id"], ) ) @@ -233,7 +297,7 @@ async def _process_tool_call(tc): answer = final_response.content else: # No tool calls — LLM answered directly (shouldn't happen often) - answer = ai_response.content + answer = ai_response.content if ai_response is not None else "" logger.info("LLM answered without tool calls") if isinstance(answer, list): @@ -247,7 +311,7 @@ async def _process_tool_call(tc): parts.append(str(c)) answer = "\n".join(parts) - confidence = min(1.0, len(sources) * 0.2) if sources else 0.1 + confidence = self.confidence_from_sources(sources) logger.info("=" * 60) logger.info("RETRIEVAL PIPELINE COMPLETE") @@ -255,12 +319,141 @@ async def _process_tool_call(tc): logger.info(" answer: %s", answer[:100] + "..." 
if len(answer) > 100 else answer) logger.info("=" * 60) - return RetrievalResult( + result = RetrievalResult( query=query, answer=answer, sources=sources, confidence=confidence, ) + self._record_latency("agentic", (time.perf_counter() - mode_start) * 1000) + return result + + async def raw_search( + self, + query: str, + user_id: str, + domains: Optional[List[str]] = None, + top_k: int = 10, + include_answer: bool = False, + ) -> RetrievalResult: + """Fast retrieval path that returns ranked hits without LLM planning.""" + mode_start = time.perf_counter() + requested = domains if domains is not None else [ + "profile", "temporal", "summary", "snippet", + ] + sources: List[SourceRecord] = [] + tasks = [] + + if "profile" in requested: + sources.extend(self._search_profile_catalog(user_id=user_id, top_k=top_k)) + if "temporal" in requested: + tasks.append(self._search_temporal(query=query, user_id=user_id, top_k=top_k)) + if "summary" in requested: + tasks.append(self._search_summary(query=query, user_id=user_id, top_k=top_k)) + if "snippet" in requested: + tasks.append(self._search_snippet(query=query, user_id=user_id, top_k=top_k)) + + if tasks: + for records in await asyncio.gather(*tasks): + sources.extend(records) + + sources = self._rank_sources(sources) + answer = "" + if include_answer: + answer = await self.answer_from_sources(query=query, sources=sources) + + confidence = self.confidence_from_sources(sources) + result = RetrievalResult( + query=query, + answer=answer, + sources=sources, + confidence=confidence, + ) + + mode = "raw_search_answer" if include_answer else "raw_search" + self._record_latency(mode, (time.perf_counter() - mode_start) * 1000) + return result + + def latency_snapshot(self) -> Dict[str, Dict[str, float]]: + """Return p50/p95/p99 latency snapshots by retrieval mode.""" + return self._latency.snapshot() + + def _plan_cache_key( + self, + user_id: str, + query: str, + top_k: int, + catalog_text: str, + ) -> str: + return "\0".join([user_id, str(top_k), query, catalog_text]) + + def _get_cached_retrieval_plan( + self, + key: str, + ) -> Optional[List[Dict[str, Any]]]: + plan = self._retrieval_plan_cache.get(key) + if plan is not None: + self._retrieval_plan_cache.move_to_end(key) + return plan + + def _cache_retrieval_plan( + self, + key: str, + plan: List[Dict[str, Any]], + ) -> None: + self._retrieval_plan_cache[key] = plan + self._retrieval_plan_cache.move_to_end(key) + if len(self._retrieval_plan_cache) > RETRIEVAL_PLAN_CACHE_MAX: + self._retrieval_plan_cache.popitem(last=False) + + def _record_latency(self, mode: str, elapsed_ms: float) -> None: + self._latency.record(mode, elapsed_ms) + try: + from src.config.metrics import METRICS + METRICS.pipeline_stage_duration.labels( + pipeline="retrieval", + stage=mode, + ).observe(elapsed_ms / 1000) + except Exception: + pass + + def _rank_sources(self, sources: List[SourceRecord]) -> List[SourceRecord]: + return sorted( + sources, + key=lambda source: source.score if source.score is not None else 0.0, + reverse=True, + ) + + def confidence_from_sources(self, sources: List[SourceRecord]) -> float: + return min(1.0, len(sources) * CONFIDENCE_PER_SOURCE) if sources else 0.1 + + async def answer_from_sources( + self, + query: str, + sources: List[SourceRecord], + ) -> str: + if not sources: + return "I could not find matching memory records for that query." 
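+        # Single synthesis step: the ranked hits are rendered via
+        # _format_tool_results into ANSWER_PROMPT, and the base model (not the
+        # tool-bound planner) produces one grounded answer from that context.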
+ + answer_prompt = ANSWER_PROMPT.format( + context=self._format_tool_results(sources), + query=query, + ) + final_response = await self.model.ainvoke([HumanMessage(content=answer_prompt)]) + return self._coerce_answer_text(final_response.content) + + def _coerce_answer_text(self, answer: Any) -> str: + if isinstance(answer, list): + parts = [] + for c in answer: + if isinstance(c, dict) and "text" in c: + parts.append(c["text"]) + elif isinstance(c, str): + parts.append(c) + else: + parts.append(str(c)) + return "\n".join(parts) + return str(answer) # ------------------------------------------------------------------ # Tool execution @@ -320,9 +513,10 @@ def _search_profile( 'work_company', 'work_title', etc.). """ topic_prefix = topic.strip().lower().replace(" ", "_") + _, profile_records = self._fetch_profile_catalog(user_id) records = [] - for r in getattr(self, "_cached_profile_records", []): + for r in profile_records: main_content = r.metadata.get("main_content", "") if not main_content.startswith(topic_prefix): continue @@ -346,6 +540,25 @@ def _search_profile( logger.info(" → Profile [%s]: %d results", topic, len(records)) return records + def _search_profile_catalog( + self, + user_id: str, + top_k: int, + ) -> List[SourceRecord]: + """Return cached profile records for raw search.""" + _, profile_records = self._fetch_profile_catalog(user_id) + records = [ + SourceRecord( + domain="profile", + content=r.content, + score=r.score, + metadata={"id": r.id, **r.metadata}, + ) + for r in profile_records[:top_k] + ] + logger.info(" → Profile catalog: %d results", len(records)) + return records + # -- Temporal: Neo4j semantic search ─────────────────────────────── async def _search_temporal( @@ -459,14 +672,18 @@ async def _search_snippet( top_k: int = 5, ) -> List[SourceRecord]: """Semantic search over user code snippets (sandboxed namespace).""" - store = self._get_snippet_store(user_id) - - # In the sandboxed namespace, we can just search. We pass domain filter just in case. - results = await store.search_by_text( - query_text=query, - top_k=top_k, - filters={"domain": "snippet"}, - ) + try: + store = self._get_snippet_store(user_id) + + # In the sandboxed namespace, we can just search. We pass domain filter just in case. 
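+            # Failures anywhere in this try block (store construction above, or
+            # the vector query below) are logged and collapse to an empty result
+            # list so the other retrieval domains still return hits.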
+ results = await store.search_by_text( + query_text=query, + top_k=top_k, + filters={"domain": "snippet"}, + ) + except Exception as exc: + logger.warning("Snippet search error: %s", exc) + return [] records = [] for r in results: @@ -499,6 +716,12 @@ def _fetch_profile_catalog(self, user_id: str): catalog — list of {topic, sub_topic} for the prompt raw_results — the full SearchResult list, cached for _search_profile """ + cached = self._profile_catalog_cache.get(user_id) + now = time.monotonic() + if cached and now - cached[0] < self._cache_ttl_s: + self._profile_catalog_cache.move_to_end(user_id) + return cached[1], cached[2] + try: results = self.vector_store.search_by_metadata( filters={"user_id": user_id, "domain": "profile"}, @@ -529,6 +752,10 @@ def _fetch_profile_catalog(self, user_id: str): "sub_topic": "", }) + self._profile_catalog_cache[user_id] = (now, catalog, results) + self._profile_catalog_cache.move_to_end(user_id) + if len(self._profile_catalog_cache) > PROFILE_CATALOG_CACHE_MAX: + self._profile_catalog_cache.popitem(last=False) return catalog, results def _format_catalog(self, catalog: List[Dict[str, str]]) -> str: diff --git a/tests/integration/test_retrieval_pipeline.py b/tests/integration/test_retrieval_pipeline.py index 076b26d..f280d6d 100644 --- a/tests/integration/test_retrieval_pipeline.py +++ b/tests/integration/test_retrieval_pipeline.py @@ -64,3 +64,103 @@ async def search_by_text(self, **kwargs): snippets = await pipeline._execute_tool("SearchSnippet", {"query": "binary search"}, "user-1", 5) assert snippets[0].domain == "snippet" assert "def bs" in snippets[0].content + + +@pytest.mark.asyncio +async def test_raw_search_returns_ranked_hits_without_llm_planning(vector_store, neo4j_client): + vector_store.seed( + "profile-1", + "work / company = XMem", + {"user_id": "alice", "domain": "profile", "main_content": "work_company"}, + score=0.7, + ) + vector_store.seed( + "summary-1", + "Alice is tuning low-latency retrieval.", + {"user_id": "alice", "domain": "summary"}, + score=0.9, + ) + neo4j_client.seed_event( + user_id="alice", + date="05-11", + event_name="Launch", + desc="Raw search launch", + year="2026", + similarity_score=0.8, + ) + + class SnippetStore: + async def search_by_text(self, **kwargs): + return [type("R", (), { + "id": "snip-1", + "content": "Search helper", + "score": 0.95, + "metadata": {"code_snippet": "def search(): pass", "language": "python"}, + })()] + + model = FakeChatModel(responses=["unused"]) + pipeline = RetrievalPipeline(model=model, vector_store=vector_store, neo4j_client=neo4j_client) + pipeline._snippet_stores["alice"] = SnippetStore() + + result = await pipeline.raw_search( + query="search launch", + user_id="alice", + domains=["profile", "summary", "temporal", "snippet"], + top_k=5, + ) + + assert [source.domain for source in result.sources] == [ + "snippet", "summary", "temporal", "profile", + ] + assert result.answer == "" + assert model.calls == [] + assert pipeline.model_with_tools.calls == [] + assert "raw_search" in pipeline.latency_snapshot() + + +@pytest.mark.asyncio +async def test_raw_search_can_synthesize_answer_after_hits(vector_store, neo4j_client): + vector_store.seed( + "summary-1", + "Alice is building raw search.", + {"user_id": "alice", "domain": "summary"}, + ) + model = FakeChatModel(responses=["Alice is building raw search."]) + pipeline = RetrievalPipeline(model=model, vector_store=vector_store, neo4j_client=neo4j_client) + + result = await pipeline.raw_search( + query="What is Alice building?", + 
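+        # include_answer=True below should trigger exactly one synthesis call to
+        # the base model; the tool-bound planner (model_with_tools) stays unused.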
user_id="alice", + domains=["summary"], + include_answer=True, + ) + + assert result.answer == "Alice is building raw search." + assert len(model.calls) == 1 + assert pipeline.model_with_tools.calls == [] + assert "raw_search_answer" in pipeline.latency_snapshot() + + +@pytest.mark.asyncio +async def test_retrieval_pipeline_reuses_cached_tool_plan(vector_store, neo4j_client): + vector_store.seed( + "profile-1", + "work / company = XMem", + {"user_id": "alice", "domain": "profile", "main_content": "work_company"}, + ) + model = FakeChatModel( + tool_responses=[ + FakeLLMResponse("", tool_calls=[ + {"name": "search_profile", "args": {"topic": "work"}, "id": "call-profile"}, + ]) + ], + responses=["first answer", "second answer"], + ) + pipeline = RetrievalPipeline(model=model, vector_store=vector_store, neo4j_client=neo4j_client) + + first = await pipeline.run("Where does Alice work?", "alice") + second = await pipeline.run("Where does Alice work?", "alice") + + assert first.answer == "first answer" + assert second.answer == "second answer" + assert len(pipeline.model_with_tools.calls) == 1 From d2df8052e9316c4521edcfd873be6012df35627e Mon Sep 17 00:00:00 2001 From: Michael Schvarcz Date: Tue, 12 May 2026 06:44:45 -0700 Subject: [PATCH 2/3] Harden raw search score handling --- src/api/routes/memory.py | 8 ++++++-- src/pipelines/retrieval.py | 5 ++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/api/routes/memory.py b/src/api/routes/memory.py index 53ae411..be2d2a5 100644 --- a/src/api/routes/memory.py +++ b/src/api/routes/memory.py @@ -70,6 +70,10 @@ def _model_name(model: Any) -> str: return getattr(model, "model", getattr(model, "model_name", "unknown")) +def _score_value(score: float | None) -> float: + return round(score, 3) if score is not None else 0.0 + + def _build_domain_result(judge: Any, weaver: Any) -> DomainResult | None: if not judge or not getattr(judge, "operations", None): return None @@ -661,7 +665,7 @@ async def retrieve_memory(req: RetrieveRequest, request: Request, user: dict = D sources=[ SourceRecord( domain=s.domain, content=s.content, - score=round(s.score, 3), metadata=s.metadata, + score=_score_value(s.score), metadata=s.metadata, ) for s in result.sources ], @@ -750,7 +754,7 @@ async def search_memory(req: SearchRequest, request: Request, user: dict = Depen SourceRecord( domain=s.domain, content=s.content, - score=round(s.score, 3) if s.score is not None else 0.0, + score=_score_value(s.score), metadata=s.metadata, ) for s in records diff --git a/src/pipelines/retrieval.py b/src/pipelines/retrieval.py index e498017..e65c36a 100644 --- a/src/pipelines/retrieval.py +++ b/src/pipelines/retrieval.py @@ -354,7 +354,10 @@ async def raw_search( tasks.append(self._search_snippet(query=query, user_id=user_id, top_k=top_k)) if tasks: - for records in await asyncio.gather(*tasks): + for records in await asyncio.gather(*tasks, return_exceptions=True): + if isinstance(records, Exception): + logger.warning("Raw memory search subquery failed: %s", records) + continue sources.extend(records) sources = self._rank_sources(sources) From 89fe3e4076e4e0059cba9d1569dc2224843b4af7 Mon Sep 17 00:00:00 2001 From: Michael Schvarcz Date: Wed, 13 May 2026 06:10:31 -0700 Subject: [PATCH 3/3] Handle missing source scores in answer formatting --- src/pipelines/retrieval.py | 3 ++- tests/integration/test_retrieval_pipeline.py | 23 ++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/pipelines/retrieval.py b/src/pipelines/retrieval.py index 
e65c36a..6d16319 100644 --- a/src/pipelines/retrieval.py +++ b/src/pipelines/retrieval.py @@ -784,7 +784,8 @@ def _format_tool_results(self, records: List[SourceRecord]) -> str: lines = [] for i, rec in enumerate(records, 1): - score_str = f" (score: {rec.score:.2f})" if rec.score > 0 else "" + score = rec.score if rec.score is not None else 0.0 + score_str = f" (score: {score:.2f})" if score > 0 else "" lines.append(f"{i}. [{rec.domain}]{score_str} {rec.content}") return "\n".join(lines) diff --git a/tests/integration/test_retrieval_pipeline.py b/tests/integration/test_retrieval_pipeline.py index f280d6d..9a02d8c 100644 --- a/tests/integration/test_retrieval_pipeline.py +++ b/tests/integration/test_retrieval_pipeline.py @@ -3,6 +3,7 @@ import pytest from src.pipelines.retrieval import RetrievalPipeline +from src.schemas.retrieval import SourceRecord from tests.conftest import FakeChatModel, FakeLLMResponse @@ -141,6 +142,28 @@ async def test_raw_search_can_synthesize_answer_after_hits(vector_store, neo4j_c assert "raw_search_answer" in pipeline.latency_snapshot() +@pytest.mark.asyncio +async def test_answer_from_sources_handles_missing_source_scores(vector_store, neo4j_client): + model = FakeChatModel(responses=["Answer synthesized from a missing-score source."]) + pipeline = RetrievalPipeline(model=model, vector_store=vector_store, neo4j_client=neo4j_client) + + answer = await pipeline.answer_from_sources( + query="What is Alice building?", + sources=[ + SourceRecord( + domain="summary", + content="Alice is building raw search.", + score=None, + ), + ], + ) + + assert answer == "Answer synthesized from a missing-score source." + assert len(model.calls) == 1 + prompt = model.calls[0][0].content + assert "[summary] Alice is building raw search." in prompt + + @pytest.mark.asyncio async def test_retrieval_pipeline_reuses_cached_tool_plan(vector_store, neo4j_client): vector_store.seed(