Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 76 additions & 10 deletions src/api/routes/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from src.api.dependencies import (
enforce_rate_limit,
get_code_pipeline,
get_ingest_pipeline,
get_retrieval_pipeline,
require_api_key,
Expand Down Expand Up @@ -69,6 +70,10 @@ def _model_name(model: Any) -> str:
return getattr(model, "model", getattr(model, "model_name", "unknown"))


def _score_value(score: float | None) -> float:
return round(score, 3) if score is not None else 0.0


def _build_domain_result(judge: Any, weaver: Any) -> DomainResult | None:
if not judge or not getattr(judge, "operations", None):
return None
Expand Down Expand Up @@ -660,7 +665,7 @@ async def retrieve_memory(req: RetrieveRequest, request: Request, user: dict = D
sources=[
SourceRecord(
domain=s.domain, content=s.content,
score=round(s.score, 3), metadata=s.metadata,
score=_score_value(s.score), metadata=s.metadata,
)
for s in result.sources
],
Expand Down Expand Up @@ -689,16 +694,77 @@ async def search_memory(req: SearchRequest, request: Request, user: dict = Depen
user_id = user.get("username") or user.get("name") or user["id"]

try:
all_results: List[SourceRecord] = []

if "profile" in req.domains:
all_results.extend(_search_profile(pipeline, user_id))
if "temporal" in req.domains:
all_results.extend(_search_temporal(pipeline, req.query, user_id, req.top_k))
if "summary" in req.domains:
all_results.extend(await _search_summary(pipeline, req.query, user_id, req.top_k))
if "code" in req.domains and not req.org_id:
elapsed = round((time.perf_counter() - start) * 1000, 2)
return _error(request, "org_id is required when searching the code domain.", 400, elapsed)

memory_domains = [domain for domain in req.domains if domain != "code"]
result = await pipeline.raw_search(
query=req.query,
user_id=user_id,
domains=memory_domains,
top_k=req.top_k,
include_answer=False,
)
records = list(result.sources)

if "code" in req.domains:
code_pipeline = get_code_pipeline(org_id=req.org_id or "", repo=req.repo)
code_results = await asyncio.gather(
code_pipeline._execute_tool(
tool_name="search_symbols",
tool_args={"query": req.query, "repo": req.repo},
repo=req.repo,
top_k=req.top_k,
user_id=user_id,
),
code_pipeline._execute_tool(
tool_name="search_files",
tool_args={"query": req.query, "repo": req.repo},
repo=req.repo,
top_k=req.top_k,
user_id=user_id,
),
return_exceptions=True,
)
Comment on lines +713 to +729
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If any of the code retrieval tool calls fail, the entire search request will return a 500 error. For a more robust search experience, consider using asyncio.gather(..., return_exceptions=True) and filtering out failed results, allowing the user to see memory hits even if the code domain search is temporarily unavailable.

for code_records in code_results:
if isinstance(code_records, Exception):
logger.warning("Code search subquery failed: %s", code_records)
continue
records.extend(code_records)

records = sorted(records, key=lambda s: s.score or 0.0, reverse=True)

answer = ""
if req.answer:
answer = await pipeline.answer_from_sources(query=req.query, sources=records)
pipeline._record_latency(
"raw_search_answer",
(time.perf_counter() - start) * 1000,
)
elif "code" in req.domains:
pipeline._record_latency(
"raw_search_code",
(time.perf_counter() - start) * 1000,
)

data = SearchResponse(results=all_results, total=len(all_results))
confidence = pipeline.confidence_from_sources(records)
data = SearchResponse(
results=[
SourceRecord(
domain=s.domain,
content=s.content,
score=_score_value(s.score),
metadata=s.metadata,
)
for s in records
],
total=len(records),
answer=answer,
model=_model_name(pipeline.model) if req.answer else "",
confidence=confidence,
latency=pipeline.latency_snapshot(),
)
elapsed = round((time.perf_counter() - start) * 1000, 2)
return _wrap(request, data, elapsed)

Expand Down
28 changes: 26 additions & 2 deletions src/api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,24 +159,48 @@ class SearchRequest(BaseModel):
..., min_length=1, max_length=256, pattern=r"^[\w.\-@]+$",
)
domains: List[str] = Field(
default=["profile", "temporal", "summary"],
default=["profile", "temporal", "summary", "snippet"],
description="Which memory domains to search",
)
top_k: int = Field(default=10, ge=1, le=100)
answer: bool = Field(
default=False,
description="When true, synthesize an LLM answer from the raw hits.",
)
org_id: Optional[str] = Field(
default=None,
min_length=1,
max_length=256,
description="Required when including the code domain.",
)
repo: str = Field(
default="",
max_length=256,
description="Optional repository scope for code search.",
)

@field_validator("domains")
@classmethod
def validate_domains(cls, v: List[str]) -> List[str]:
allowed = {"profile", "temporal", "summary"}
allowed = {"profile", "temporal", "summary", "snippet", "code"}
for d in v:
if d not in allowed:
raise ValueError(f"Invalid domain '{d}'. Allowed: {allowed}")
return v

@field_validator("query")
@classmethod
def strip_search_query(cls, v: str) -> str:
return v.strip()


class SearchResponse(BaseModel):
results: List[SourceRecord] = Field(default_factory=list)
total: int = 0
answer: str = ""
model: str = ""
confidence: float = 0.0
latency: Dict[str, Dict[str, float]] = Field(default_factory=dict)


# ── Scrape (extract from shared chat links) ────────────────────────────────
Expand Down
Loading
Loading