From 81826b7b4701b1ba0f91fdc0700e369713191db0 Mon Sep 17 00:00:00 2001
From: Kaiserbuffle
Date: Thu, 7 May 2026 12:06:18 +0100
Subject: [PATCH 1/4] wip(rag): V1 of a RAG pipeline working with an Ollama
 LLM

Works with the current pipeline as RAG + embedding + LLM (Ollama).
Should also work with vLLM, but this is untested.
---
 src/modules/modules.py |   3 +-
 src/modules/rag/rag.py | 310 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 312 insertions(+), 1 deletion(-)
 create mode 100644 src/modules/rag/rag.py

diff --git a/src/modules/modules.py b/src/modules/modules.py
index 69ebb45..7283983 100644
--- a/src/modules/modules.py
+++ b/src/modules/modules.py
@@ -3,9 +3,10 @@
 from src.modules.speech_to_text.record_speech import MIC
 from src.modules.speech_to_text.speech_to_text import STT
 from src.modules.speech_to_text.text_aggregator import TAG
+from src.modules.rag.rag import RAG
 from .factory import Module
 
 
 def get_modules() -> Dict[str, Type[Module]]:
-    return {"mic": MIC, "stt": STT, "tag": TAG}
+    return {"mic": MIC, "stt": STT, "tag": TAG, "rag": RAG}
diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py
new file mode 100644
index 0000000..0884c76
--- /dev/null
+++ b/src/modules/rag/rag.py
@@ -0,0 +1,310 @@
+from typing import Any, Optional
+from dataclasses import dataclass, field
+
+from ray import serve
+from ray.serve import handle
+from src.core.module import ModuleWithHandle
+from qdrant_client.models import Filter, FieldCondition, MatchValue
+from sentence_transformers import SentenceTransformer
+from qdrant_client import QdrantClient
+
+import httpx
+
+
+@dataclass
+class RAGQuery:
+    """What flows from RAG module to RAGHandle."""
+    user_id: str
+    question: str
+    preferences: dict = field(default_factory=dict)
+    # preferences can include: language, tone, response_format, max_length, system_prompt, extra_instructions, etc.
+
+
+@dataclass
+class RAGResult:
+    """What RAGHandle returns."""
+    answer: str
+    sources: list[dict] = field(default_factory=list)
+
+
+@serve.deployment(
+    num_replicas=2,
+    ray_actor_options={"num_cpus": 1},
+)
+class RAGHandle:
+    """
+    Stateless RAG processor. Knows nothing about sessions.
+    Receives a user_id + question, uses user_id to find the right
+    collection/data in the vector DB, runs embed -> search -> LLM.
+    """
+
+    def __init__(
+        self,
+        qdrant_url: str = "http://localhost:6333",
+        default_collection: str = "documents",
+        embedding_model: str = "BAAI/bge-large-en-v1.5",
+        llm_provider: str = "ollama",  # "vllm", "ollama", "api"
+        llm_url: str = "http://localhost:11434",
+        llm_model: str = "mistral:7b",
+        llm_api_key: str = "",
+        top_k: int = 5,
+        score_threshold: float = 0.5,
+    ):
+        self.embed_model = SentenceTransformer(embedding_model)
+        self.qdrant = QdrantClient(url=qdrant_url)
+        self.default_collection = default_collection
+        self.top_k = top_k
+        self.score_threshold = score_threshold
+
+        self.llm_provider = llm_provider
+        self.llm_url = llm_url
+        self.llm_model = llm_model
+        self.llm_api_key = llm_api_key
+
+    def _resolve_user_context(self, user_id: str) -> tuple[str, dict | None]:
+        """
+        Given a user_id, decide which collection to search
+        and which filters to apply.
+
+        Options (pick what fits your data model):
+        A) One collection per user: collection = f"user_{user_id}"
+        B) Shared collection, filter by user_id in payload
+        C) Lookup in a DB to find the user's config
+        """
+
+        # Option A: separate collection per user
+        # collection = f"user_{user_id}"
+        # filters = None
+
+        # Option B: shared collection with user_id filter (recommended)
+        collection = self.default_collection
+        filters = {"user_id": user_id}
+
+        return collection, filters
+
+
+    def _embed(self, text) -> list[float]:
+        return self.embed_model.encode(str(text), normalize_embeddings=True).tolist()
+
+
+    def _search(
+        self,
+        query_vector: list[float],
+        collection: str,
+        filters: dict | None = None,
+    ) -> list[dict]:
+
+        # Build qdrant filter from user context
+        qdrant_filter = None
+        if filters:
+            conditions = [
+                FieldCondition(key=k, match=MatchValue(value=v))
+                for k, v in filters.items()
+            ]
+            qdrant_filter = Filter(must=conditions)
+
+        results = self.qdrant.query_points(
+            collection_name=collection,
+            query=query_vector,
+            query_filter=qdrant_filter,
+            limit=self.top_k,
+            score_threshold=self.score_threshold,
+        ).points
+
+        return [
+            {
+                "text": point.payload.get("text", ""),
+                "score": point.score,
+                "metadata": {k: v for k, v in point.payload.items() if k != "text"},
+            }
+            for point in results
+        ]
+
+
+    def _build_prompt(
+        self,
+        question: str,
+        chunks: list[dict],
+        preferences: dict,
+    ) -> tuple[str, str]:
+
+        parts = [
+            "You are a helpful assistant. Answer based on the provided context.",
+            "If the context is insufficient, say so clearly.",
+        ]
+        if preferences.get("language"):
+            parts.append(f"Always respond in {preferences['language']}.")
+        if preferences.get("tone"):
+            parts.append(f"Use a {preferences['tone']} tone.")
+        if preferences.get("response_format") == "bullet_points":
+            parts.append("Format your answer as bullet points.")
+        elif preferences.get("response_format") == "short":
+            parts.append("Keep your answer to 2-3 sentences maximum.")
+        if preferences.get("extra_instructions"):
+            parts.append(preferences["extra_instructions"])
+        system_prompt = " ".join(parts)
+
+        if not chunks:
+            user_prompt = (
+                "No relevant context was found.\n\n"
+                f"Question: {question}\n\n"
+                "Answer based on general knowledge and mention no documents were found."
+            )
+        else:
+            context_parts = []
+            for i, chunk in enumerate(chunks, 1):
+                source = chunk["metadata"].get("source", "unknown")
+                context_parts.append(
+                    f"[{i}] (source: {source}, score: {chunk['score']:.2f})\n{chunk['text']}"
+                )
+            context_block = "\n\n".join(context_parts)
+            user_prompt = (
+                f"Context:\n{context_block}\n\n"
+                f"Question: {question}\n\n"
+                "Answer based on the context above. Cite sources by number."
+            )
+
+        return system_prompt, user_prompt
+
+
+    async def _llm_generate(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        preferences: dict,
+    ) -> str:
+        max_tokens = preferences.get("max_length", 1024)
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ]
+
+        if self.llm_provider == "vllm":
+            return await self._call_openai_compatible(
+                f"{self.llm_url}/v1/chat/completions", messages, max_tokens
+            )
+        elif self.llm_provider == "ollama":
+            return await self._call_ollama(messages, max_tokens)
+        elif self.llm_provider == "api":
+            return await self._call_openai_compatible(
+                f"{self.llm_url}/v1/chat/completions", messages, max_tokens, self.llm_api_key
+            )
+        else:
+            raise ValueError(f"Unknown llm_provider: {self.llm_provider}")
+
+
+    async def _call_openai_compatible(
+        self, url: str, messages: list, max_tokens: int, api_key: str = ""
+    ) -> str:
+        headers = {"Content-Type": "application/json"}
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
+        async with httpx.AsyncClient(timeout=60.0) as client:
+            resp = await client.post(url, headers=headers, json={
+                "model": self.llm_model,
+                "messages": messages,
+                "max_tokens": max_tokens,
+                "temperature": 0.1,
+            })
+            resp.raise_for_status()
+            return resp.json()["choices"][0]["message"]["content"]
+
+
+    async def _call_ollama(self, messages: list, max_tokens: int) -> str:
+        async with httpx.AsyncClient(timeout=60.0) as client:
+            resp = await client.post(f"{self.llm_url}/api/chat", json={
+                "model": self.llm_model,
+                "messages": messages,
+                "stream": False,
+                "options": {"num_predict": max_tokens, "temperature": 0.1},
+            })
+            resp.raise_for_status()
+            return resp.json()["message"]["content"]
+
+
+    async def process(self, query: RAGQuery) -> RAGResult:
+        """
+        Main entry point. Called by the RAG module.
+        Uses user_id to determine which collection / filters to use.
+        """
+
+        print(f"[RAG] Question: {query.question}")
+        collection, filters = self._resolve_user_context(query.user_id)
+        query_vector = self._embed(query.question)
+        chunks = self._search(query_vector, collection, filters)
+
+        print(f"[RAG] Found {len(chunks)} chunks")
+        for c in chunks:
+            print(f"  - score: {c['score']:.2f} | {c['text'][:100]}...")
+
+        system_prompt, user_prompt = self._build_prompt(
+            query.question, chunks, query.preferences
+        )
+        print(f"[RAG] System prompt: {system_prompt[:200]}...")
+        answer = await self._llm_generate(system_prompt, user_prompt, query.preferences)
+        print(f"[RAG] Answer: {answer}")
+
+        return RAGResult(
+            answer=answer,
+            sources=[
+                {"text": c["text"], "score": c["score"], "metadata": c["metadata"]}
+                for c in chunks
+            ],
+        )
+
+
+class RAG(ModuleWithHandle):
+    """
+    Session-bound module. HuRI instantiates this when a client connects,
+    passing the user_id from the WebSocket config.
+
+    Listens to "question" events.
+    Forwards question + user_id to the detached RAGHandle.
+    Emits "rag_response" event with the answer.
+    """
+    _handle_cls = RAGHandle
+    input_type = "question"
+    output_type = "rag_response"
+
+    def __init__(
+        self,
+        handle: handle.DeploymentHandle[RAGHandle],
+        user_id: str = "",
+        language: str = "en",
+        tone: str = "formal",
+        response_format: str = "paragraph",
+        max_length: int = 1024,
+        extra_instructions: str = "",
+    ):
+        super().__init__(handle)
+        self.user_id = user_id
+        self.preferences = {
+            "language": language,
+            "tone": tone,
+            "response_format": response_format,
+            "max_length": max_length,
+            "extra_instructions": extra_instructions,
+        }
+
+    async def process(self, data) -> Optional[Any]:
+        """
+        Called when a "question" event arrives through the event bus.
+        Packages user_id + question, sends to the stateless RAGHandle.
+        """
+        question_text = data.text if hasattr(data, 'text') else str(data)
+
+        query = RAGQuery(
+            user_id=self.user_id if self.user_id else "anonymous",
+            question=question_text,
+            preferences=self.preferences,
+        )
+
+        result: RAGResult = await self.handle.process.remote(query)
+        return result
+
+    def update_preferences(self, new_preferences: dict):
+        """Client can update preferences mid-session via the event bus."""
+        self.preferences.update(new_preferences)
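
The module above reaches RAGHandle only through a Ray Serve DeploymentHandle. A quick smoke test of the deployment on its own might look like the following sketch (assuming a recent Ray Serve, a local Qdrant with a populated "documents" collection, and an Ollama server on the default port, i.e. the defaults hard-coded in rag.py):

    from ray import serve
    from src.modules.rag.rag import RAGHandle, RAGQuery

    # Bind the deployment with its default constructor arguments and start it.
    handle = serve.run(RAGHandle.bind())

    # Send one query through the handle, the same way RAG.process does.
    query = RAGQuery(user_id="anonymous", question="Where is the main office?")
    result = handle.process.remote(query).result()
    print(result.answer)

Once documents have been ingested for that user_id, result.sources should come back non-empty.
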
From 87588fbd4f3093afe2556b4abc26fee2fffb1ef2 Mon Sep 17 00:00:00 2001
From: Kaiserbuffle
Date: Thu, 7 May 2026 12:17:18 +0100
Subject: [PATCH 2/4] wip(rag): set the filter to None to be able to retrieve
 collections without a user_id

---
 src/modules/rag/rag.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py
index 0884c76..07a3244 100644
--- a/src/modules/rag/rag.py
+++ b/src/modules/rag/rag.py
@@ -79,7 +79,7 @@ def _resolve_user_context(self, user_id: str) -> tuple[str, dict | None]:
 
         # Option B: shared collection with user_id filter (recommended)
         collection = self.default_collection
-        filters = {"user_id": user_id}
+        filters = None  # {"user_id": user_id}
 
         return collection, filters
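
The filter this patch disables is an exact-match condition on the point payload. The difference between the two settings can be seen directly with the Qdrant client (a sketch, assuming a local Qdrant and an existing "documents" collection; the query vector is a stand-in for a real 1024-dim bge-large embedding):

    from qdrant_client import QdrantClient
    from qdrant_client.models import Filter, FieldCondition, MatchValue

    client = QdrantClient(url="http://localhost:6333")
    vector = [0.0] * 1024  # stand-in for a real embedding

    # filters = None: every point in the collection is a candidate.
    unscoped = client.query_points("documents", query=vector, limit=5).points

    # filters = {"user_id": ...}: only points whose payload carries that user_id.
    scoped = client.query_points(
        "documents",
        query=vector,
        query_filter=Filter(
            must=[FieldCondition(key="user_id", match=MatchValue(value="some-user-id"))]
        ),
        limit=5,
    ).points
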
From 9b7dadea6f372792d867071451a63807d9513bb6 Mon Sep 17 00:00:00 2001
From: Kaiserbuffle
Date: Mon, 11 May 2026 15:39:53 +0100
Subject: [PATCH 3/4] feat(id): add user ids to make sessions work with the
 RAG system, plus an ingestion script

---
 src/client.py                | 37 +++++++++++++++---
 src/core/huri.py             | 12 ++++--
 src/modules/rag/ingestion.py | 75 ++++++++++++++++++++++++++++++++++++
 src/modules/rag/rag.py       |  8 ++--
 4 files changed, 119 insertions(+), 13 deletions(-)
 create mode 100644 src/modules/rag/ingestion.py

diff --git a/src/client.py b/src/client.py
index ca29146..2240709 100644
--- a/src/client.py
+++ b/src/client.py
@@ -1,6 +1,7 @@
 import argparse
 import asyncio
 import json
+import os
 from dataclasses import asdict
 from typing import Dict
@@ -12,15 +13,29 @@
 from src.core.dataclasses.config import ClientConfig
 
 
+USER_ID_FILE = os.path.expanduser("~/.huri_user_id")
+
+
+def load_user_id() -> str | None:
+    if os.path.exists(USER_ID_FILE):
+        with open(USER_ID_FILE) as f:
+            return f.read().strip()
+    return None
+
+
+def save_user_id(user_id: str):
+    with open(USER_ID_FILE, "w") as f:
+        f.write(user_id)
+
 def load_client_config(path: str) -> ClientConfig:
     with open(path) as f:
         dict_config = OmegaConf.load(f)
-    raw_resolved = OmegaConf.to_container(dict_config, resolve=True)
+        raw_resolved = OmegaConf.to_container(dict_config, resolve=True)
 
-    if not isinstance(raw_resolved, Dict):
-        raise RuntimeError("error yaml does not output a dict")
+        if not isinstance(raw_resolved, Dict):
+            raise RuntimeError("error yaml does not output a dict")
 
-    return ClientConfig.from_dict(raw_resolved)
+        return ClientConfig.from_dict(raw_resolved)
 
 
 async def stream_audio():
@@ -38,7 +53,19 @@ async def stream_audio():
     async with websockets.connect(config.huri_url) as ws:
         print("Connected to server")
 
-        await ws.send(json.dumps(asdict(config)))
+        payload = asdict(config)
+        user_id = load_user_id()
+        if user_id:
+            payload["user_id"] = user_id
+            print(f"Reconnecting with user_id: {user_id}")
+
+        await ws.send(json.dumps(payload))
+
+        init_msg = json.loads(await ws.recv())
+        if init_msg.get("type") == "session_init":
+            user_id = init_msg["user_id"]
+            save_user_id(user_id)
+            print(f"Session started with user_id: {user_id}")
 
         async def receive(ws: websockets.ClientConnection):
             while True:
diff --git a/src/core/huri.py b/src/core/huri.py
index 6d6d747..a07c4fe 100644
--- a/src/core/huri.py
+++ b/src/core/huri.py
@@ -30,11 +30,17 @@ def __init__(
     @app.websocket("/session")
     async def run_session(self, ws: WebSocket):
         await ws.accept()
-
         client_config_raw: Dict = await ws.receive_json()
-
         client_config = ClientConfig.from_dict(client_config_raw)
 
+        user_id = client_config_raw.get("user_id") or str(uuid.uuid4())
+        await ws.send_json({"type": "session_init", "user_id": user_id})
+
+        if "rag" in client_config.modules:
+            if client_config.modules["rag"].args is None:
+                client_config.modules["rag"].args = {}
+            client_config.modules["rag"].args["user_id"] = user_id
+
         senders: List[Module] = [
             Sender(ws, topic) for topic in client_config.topic_list
         ]
         )
 
         session_id = str(uuid.uuid4())
-
         self.clients[session_id] = Session(modules)
-
         print("Client registered successfully with config:", client_config)
diff --git a/src/modules/rag/ingestion.py b/src/modules/rag/ingestion.py
new file mode 100644
index 0000000..d517c67
--- /dev/null
+++ b/src/modules/rag/ingestion.py
@@ -0,0 +1,75 @@
+# ingestion.py
+import argparse
+import os
+import uuid
+
+from qdrant_client import QdrantClient
+from qdrant_client.models import VectorParams, Distance, PointStruct
+from sentence_transformers import SentenceTransformer
+
+USER_ID_FILE = os.path.expanduser("~/.huri_user_id")
+
+
+def get_user_id(provided_id: str | None = None) -> str:
+    """Use provided ID, or load from file, or generate new one."""
+    if provided_id:
+        return provided_id
+    if os.path.exists(USER_ID_FILE):
+        with open(USER_ID_FILE) as f:
+            return f.read().strip()
+    new_id = str(uuid.uuid4())
+    with open(USER_ID_FILE, "w") as f:
+        f.write(new_id)
+    return new_id
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Ingest documents into Qdrant")
+    parser.add_argument("--user-id", type=str, default=None, help="User ID (reads from ~/.huri_user_id if not provided)")
+    parser.add_argument("--collection", type=str, default="documents")
+    parser.add_argument("--qdrant-url", type=str, default="http://localhost:6333")
+    args = parser.parse_args()
+
+    user_id = get_user_id(args.user_id)
+    print(f"Ingesting for user_id: {user_id}")
+
+    client = QdrantClient(url=args.qdrant_url)
+    model = SentenceTransformer("BAAI/bge-large-en-v1.5")
+
+    # Create collection if it doesn't exist
+    collections = [c.name for c in client.get_collections().collections]
+    if args.collection not in collections:
+        client.create_collection(
+            collection_name=args.collection,
+            vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
+        )
+        print(f"Created collection: {args.collection}")
+
+    # Sample documents
+    docs = [
+        {"text": "The company budget for 2026 is 2 million euros.", "source": "budget.pdf"},
+        {"text": "The project deadline is June 15th 2026.", "source": "planning.pdf"},
+        {"text": "The team consists of 5 developers and 2 designers.", "source": "team.pdf"},
+        {"text": "The main office is located in Paris, France.", "source": "info.pdf"},
+    ]
+
+    # Embed and insert with user_id
+    points = []
+    for doc in docs:
+        vector = model.encode(doc["text"], normalize_embeddings=True).tolist()
+        points.append(PointStruct(
+            id=str(uuid.uuid4()),
+            vector=vector,
+            payload={
+                "text": doc["text"],
+                "source": doc["source"],
+                "user_id": user_id,  # ← scoped to this user
+            },
+        ))
+
+    client.upsert(collection_name=args.collection, points=points)
+    print(f"Ingested {len(points)} documents for user {user_id}")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py
index 07a3244..e37db92 100644
--- a/src/modules/rag/rag.py
+++ b/src/modules/rag/rag.py
@@ -79,7 +79,7 @@ def _resolve_user_context(self, user_id: str) -> tuple[str, dict | None]:
 
         # Option B: shared collection with user_id filter (recommended)
         collection = self.default_collection
-        filters = None  # {"user_id": user_id}
+        filters = {"user_id": user_id}
 
         return collection, filters
@@ -131,7 +131,7 @@ def _build_prompt(
     ) -> tuple[str, str]:
 
         parts = [
-            "You are a helpful assistant. Answer based on the provided context.",
+            "You are a robot speaking to a user. Answer based on the provided context.",
             "If the context is insufficient, say so clearly.",
         ]
         if preferences.get("language"):
@@ -150,7 +150,7 @@
             user_prompt = (
                 "No relevant context was found.\n\n"
                 f"Question: {question}\n\n"
-                "Answer based on general knowledge and mention no documents were found."
+                "Answer based on general knowledge."
             )
         else:
             context_parts = []
@@ -163,7 +163,7 @@
             user_prompt = (
                 f"Context:\n{context_block}\n\n"
                 f"Question: {question}\n\n"
-                "Answer based on the context above. Cite sources by number."
+                "Answer based on the context above. Don't mention the sources; just use them to answer the question."
             )
 
         return system_prompt, user_prompt
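
With the ingestion script and the restored user_id filter in place, the per-user scoping can be verified straight against Qdrant (a sketch, assuming ingestion.py has been run once so that ~/.huri_user_id exists):

    import os

    from qdrant_client import QdrantClient
    from qdrant_client.models import Filter, FieldCondition, MatchValue

    with open(os.path.expanduser("~/.huri_user_id")) as f:
        user_id = f.read().strip()
    client = QdrantClient(url="http://localhost:6333")

    # Scroll through the points stored for this user only.
    points, _offset = client.scroll(
        collection_name="documents",
        scroll_filter=Filter(
            must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
        ),
        limit=10,
        with_payload=True,
    )
    for point in points:
        print(point.payload["source"], "->", point.payload["text"])

Each printed line should be one of the four sample documents, tagged with the same user_id the client sends on reconnect.
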
From 52b2cc567f51ce3861bcfa8f6256c724a3f28a9f Mon Sep 17 00:00:00 2001
From: Kaiserbuffle
Date: Mon, 11 May 2026 15:43:54 +0100
Subject: [PATCH 4/4] clean(id): clean code

---
 src/modules/rag/ingestion.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/modules/rag/ingestion.py b/src/modules/rag/ingestion.py
index d517c67..5c458d1 100644
--- a/src/modules/rag/ingestion.py
+++ b/src/modules/rag/ingestion.py
@@ -36,7 +36,6 @@ def main():
     client = QdrantClient(url=args.qdrant_url)
     model = SentenceTransformer("BAAI/bge-large-en-v1.5")
 
-    # Create collection if it doesn't exist
     collections = [c.name for c in client.get_collections().collections]
     if args.collection not in collections:
         client.create_collection(
@@ -45,7 +44,6 @@ def main():
         )
         print(f"Created collection: {args.collection}")
 
-    # Sample documents
     docs = [
         {"text": "The company budget for 2026 is 2 million euros.", "source": "budget.pdf"},
         {"text": "The project deadline is June 15th 2026.", "source": "planning.pdf"},
         {"text": "The team consists of 5 developers and 2 designers.", "source": "team.pdf"},
         {"text": "The main office is located in Paris, France.", "source": "info.pdf"},
     ]
 
-    # Embed and insert with user_id
     points = []
     for doc in docs:
         vector = model.encode(doc["text"], normalize_embeddings=True).tolist()
@@ -63,7 +60,7 @@ def main():
             payload={
                 "text": doc["text"],
                 "source": doc["source"],
-                "user_id": user_id,  # ← scoped to this user
+                "user_id": user_id,
             },
         ))
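
Taken together, the series gives a per-user RAG loop: the client persists its user_id in ~/.huri_user_id, the server echoes or assigns one in the session_init message and injects it into the rag module's args, and the ingestion script tags documents with the same id. The handshake from patch 3 can be exercised without the full audio client (a sketch; the host and port are assumptions, only the /session route is fixed by huri.py, and the minimal config payload stands in for a real ClientConfig):

    import asyncio
    import json

    import websockets

    async def handshake():
        async with websockets.connect("ws://localhost:8000/session") as ws:
            # Hypothetical minimal config; a real one would declare modules and topics.
            await ws.send(json.dumps({"modules": {}, "topic_list": []}))
            init_msg = json.loads(await ws.recv())
            assert init_msg["type"] == "session_init"
            print("Server assigned user_id:", init_msg["user_id"])

    asyncio.run(handshake())
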