Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions src/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import asyncio
import json
import os
from dataclasses import asdict
from typing import Dict

Expand All @@ -12,15 +13,29 @@
from src.core.dataclasses.config import ClientConfig


# File in the user's home directory where the user id is persisted, so the
# same id can be sent again on a later connection.
USER_ID_FILE = os.path.expanduser("~/.huri_user_id")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mettre dans un .env, peut-être ? Ou pas, en vrai.



def load_user_id() -> str | None:
if os.path.exists(USER_ID_FILE):
with open(USER_ID_FILE) as f:
return f.read().strip()
return None


def save_user_id(user_id: str):
with open(USER_ID_FILE, "w") as f:
f.write(user_id)

def load_client_config(path: str) -> ClientConfig:
    """Load the configuration file at ``path`` and build a ``ClientConfig``.

    The file is parsed with OmegaConf and converted to a plain container with
    ``resolve=True`` before being handed to ``ClientConfig.from_dict``.

    Args:
        path: Location of the configuration file.

    Returns:
        The ``ClientConfig`` built from the resolved configuration.

    Raises:
        RuntimeError: If the resolved configuration is not a dict.
    """
    with open(path) as f:
        dict_config = OmegaConf.load(f)

    # Convert exactly once; the file handle is no longer needed at this point.
    raw_resolved = OmegaConf.to_container(dict_config, resolve=True)

    # ClientConfig.from_dict is only ever given a dict from here.
    if not isinstance(raw_resolved, Dict):
        raise RuntimeError("error yaml does not output a dict")

    return ClientConfig.from_dict(raw_resolved)


async def stream_audio():
Expand All @@ -38,7 +53,19 @@ async def stream_audio():
async with websockets.connect(config.huri_url) as ws:
print("Connected to server")

await ws.send(json.dumps(asdict(config)))
payload = asdict(config)
user_id = load_user_id()
if user_id:
payload["user_id"] = user_id
print(f"Reconnecting with user_id: {user_id}")

await ws.send(json.dumps(payload))

init_msg = json.loads(await ws.recv())
if init_msg.get("type") == "session_init":
user_id = init_msg["user_id"]
save_user_id(user_id)
print(f"Session started with user_id: {user_id}")

async def receive(ws: websockets.ClientConnection):
while True:
Expand Down
12 changes: 8 additions & 4 deletions src/core/huri.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Je pense qu'il faut initialiser la Session avec le user_id pour éviter le if "rag".
Quitte à ajouter un ModuleWithId qui s'initialise avec un user_id, dans le module Factory.

Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,17 @@ def __init__(
@app.websocket("/session")
async def run_session(self, ws: WebSocket):
await ws.accept()

client_config_raw: Dict = await ws.receive_json()

client_config = ClientConfig.from_dict(client_config_raw)

user_id = client_config_raw.get("user_id") or str(uuid.uuid4())
await ws.send_json({"type": "session_init", "user_id": user_id})

if "rag" in client_config.modules:
if client_config.modules["rag"].args is None:
client_config.modules["rag"].args = {}
client_config.modules["rag"].args["user_id"] = user_id

senders: List[Module] = [
Sender(ws, topic) for topic in client_config.topic_list
]
Expand All @@ -43,9 +49,7 @@ async def run_session(self, ws: WebSocket):
)

session_id = str(uuid.uuid4())

self.clients[session_id] = Session(modules)

print("Client registered successfully with config:", client_config)

async def receive_loop(session: Session, ws: WebSocket):
Expand Down
3 changes: 2 additions & 1 deletion src/modules/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from src.modules.speech_to_text.record_speech import MIC
from src.modules.speech_to_text.speech_to_text import STT
from src.modules.speech_to_text.text_aggregator import TAG
from src.modules.rag.rag import RAG

from .factory import Module


def get_modules() -> Dict[str, Type[Module]]:
    """Return the registry mapping each module name to its class.

    Returns:
        A dict with the keys ``"mic"``, ``"stt"``, ``"tag"`` and ``"rag"``.
    """
    # Single return: a registry without "rag" placed before this one would
    # make the RAG entry unreachable.
    return {"mic": MIC, "stt": STT, "tag": TAG, "rag": RAG}
72 changes: 72 additions & 0 deletions src/modules/rag/ingestion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# ingestion.py
import argparse
import os
import uuid

from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct
from sentence_transformers import SentenceTransformer

# File in the user's home directory that holds the persisted user id; it is
# read, and created when missing, by get_user_id.
USER_ID_FILE = os.path.expanduser("~/.huri_user_id")


def get_user_id(provided_id: str = None) -> str:
"""Use provided ID, or load from file, or generate new one."""
if provided_id:
return provided_id
if os.path.exists(USER_ID_FILE):
with open(USER_ID_FILE) as f:
return f.read().strip()
new_id = str(uuid.uuid4())
with open(USER_ID_FILE, "w") as f:
f.write(new_id)
return new_id


def main():
    """Embed a fixed set of sample documents and upsert them into Qdrant.

    Command-line options: ``--user-id`` (resolved through ``get_user_id``),
    ``--collection`` (default ``"documents"``) and ``--qdrant-url`` (default
    ``"http://localhost:6333"``). The collection is created when it is not
    already listed by the server. Every stored point carries the document
    text, its source name and the user id in its payload.
    """
    cli = argparse.ArgumentParser(description="Ingest documents into Qdrant")
    cli.add_argument(
        "--user-id",
        type=str,
        default=None,
        help="User ID (reads from ~/.huri_user_id if not provided)",
    )
    cli.add_argument("--collection", type=str, default="documents")
    cli.add_argument("--qdrant-url", type=str, default="http://localhost:6333")
    args = cli.parse_args()

    user_id = get_user_id(args.user_id)
    print(f"Ingesting for user_id: {user_id}")

    client = QdrantClient(url=args.qdrant_url)
    model = SentenceTransformer("BAAI/bge-large-en-v1.5")

    existing_names = [entry.name for entry in client.get_collections().collections]
    if args.collection not in existing_names:
        # NOTE(review): 1024 is presumably the embedding size of the model
        # loaded above; confirm before switching models.
        vector_settings = VectorParams(size=1024, distance=Distance.COSINE)
        client.create_collection(
            collection_name=args.collection,
            vectors_config=vector_settings,
        )
        print(f"Created collection: {args.collection}")

    # (text, source) pairs of the sample documents to ingest.
    samples = [
        ("The company budget for 2026 is 2 million euros.", "budget.pdf"),
        ("The project deadline is June 15th 2026.", "planning.pdf"),
        ("The team consists of 5 developers and 2 designers.", "team.pdf"),
        ("The main office is located in Paris, France.", "info.pdf"),
    ]

    def build_point(text, source):
        # Embed first, then assign a random point id.
        embedding = model.encode(text, normalize_embeddings=True).tolist()
        return PointStruct(
            id=str(uuid.uuid4()),
            vector=embedding,
            payload={"text": text, "source": source, "user_id": user_id},
        )

    points = [build_point(text, source) for text, source in samples]

    client.upsert(collection_name=args.collection, points=points)
    print(f"Ingested {len(points)} documents for user {user_id}")


# Run the ingestion only when this file is executed as a script.
if __name__ == "__main__":
    main()
Loading