From 0b491b3643eeb99196194717a66afdb265f8a15e Mon Sep 17 00:00:00 2001 From: Roberto Musso Date: Mon, 23 Mar 2026 00:23:59 +0100 Subject: [PATCH] fix: langfuse v4 SDK compatibility and pass user message as trace input --- .env.example | 6 +- README.md | 2 +- app/core/llm.py | 12 +- requirements.txt | 1 + services/chat/app/deep_agent.py | 16 ++- services/chat/app/main.py | 5 + services/chat/app/redis_consumer.py | 138 +++++++++---------- services/chat/app/tracing.py | 207 ++++++++++++++-------------- services/chat/requirements.txt | 2 +- services/ws-gateway/app/main.py | 7 + tests/test_e2e_flow.py | 124 +++++++++++++++++ 11 files changed, 330 insertions(+), 190 deletions(-) create mode 100644 tests/test_e2e_flow.py diff --git a/.env.example b/.env.example index 2c54566..8038cb9 100644 --- a/.env.example +++ b/.env.example @@ -25,7 +25,6 @@ OPENAI_API_KEY= ANTHROPIC_API_KEY= GOOGLE_API_KEY= LLM_MODEL=gpt-4o -LLM_ROUTER_MODEL=gpt-4o-mini # ── Stripe (leave empty to stub billing) ────────────────────────────────────── STRIPE_SECRET_KEY= @@ -50,3 +49,8 @@ QDRANT_API_KEY= # ── CORS ────────────────────────────────────────────────────────────────────── # Comma-separated list parsed by Settings (override default if needed) # CORS_ORIGINS=["app://.","http://localhost:3000"] + +# ── Langfuse (observability) ───────────────────────────────────────────────── +LANGFUSE_SECRET_KEY=sk-lf-... +LANGFUSE_PUBLIC_KEY=pk-lf-... 
+LANGFUSE_HOST=https://cloud.langfuse.com # or self-hosted URL \ No newline at end of file diff --git a/README.md b/README.md index 19da6ea..ca039c9 100644 --- a/README.md +++ b/README.md @@ -739,7 +739,7 @@ adiuva-api/ │ │ │ ├── core/ # Orchestration engine │ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry -│ │ ├── llm.py # LiteLLM factory (get_llm, get_router_llm) +│ │ ├── llm.py # LiteLLM factory (get_llm) │ │ ├── orchestrator.py # Intent classification & routing │ │ └── execution_plan.py # Plan builder, templates, cache │ │ diff --git a/app/core/llm.py b/app/core/llm.py index 3415921..1787ce9 100644 --- a/app/core/llm.py +++ b/app/core/llm.py @@ -1,6 +1,6 @@ """LLM factory — centralised model instantiation via LiteLLM. -Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()`` +Every agent and the orchestrator call ``get_llm()`` instead of directly constructing a provider-specific class. The model string follows the `LiteLLM model naming convention `_: @@ -11,7 +11,7 @@ follows the `LiteLLM model naming convention * Ollama: ``ollama/llama3`` * Bedrock: ``bedrock/anthropic.claude-v2`` -Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env`` +Switch providers by changing **LLM_MODEL** in ``.env`` — no code changes required. """ @@ -95,14 +95,6 @@ def get_llm( ) -def get_router_llm( - *, - temperature: float = 0, -) -> ChatOpenAI | ChatLiteLLM: - """Return the lighter model used for intent classification / routing.""" - return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature) - - async def embed(text: str) -> list[float]: """Return an embedding vector for *text*. 
diff --git a/requirements.txt b/requirements.txt index bd15886..e707b0d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,4 +33,5 @@ google-auth-httplib2>=0.2.0 msal>=1.28.0 cryptography>=42.0.0 redis>=5.0.0 +langfuse>=3.0.0 ruff>=0.8.0 diff --git a/services/chat/app/deep_agent.py b/services/chat/app/deep_agent.py index 1472b15..486bbb7 100644 --- a/services/chat/app/deep_agent.py +++ b/services/chat/app/deep_agent.py @@ -528,7 +528,9 @@ def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> return {"type": "task", "id": None, "section": None} -async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[str, str | None]: +async def _infer_floating_domain( + message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None, +) -> dict[str, str | None]: resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None @@ -538,10 +540,14 @@ async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[ } try: - llm = get_llm() + classifier_prompt = _get_system_prompt( + "floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_SYSTEM, + ) + callbacks = _build_callbacks(langfuse_handler) + llm = get_llm(callbacks=callbacks) response = await llm.ainvoke( [ - SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_SYSTEM), + SystemMessage(content=classifier_prompt), HumanMessage( content=( f"Message:\n{message}\n\n" @@ -784,7 +790,7 @@ async def run_home(user_id: str, message: str, context: dict[str, Any], *, langf async def run_floating(user_id: str, message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None) -> tuple[str, dict[str, str | None]]: prepared_context = await _prepare_context(message, context) - domain = await _infer_floating_domain(message, prepared_context) + domain = await _infer_floating_domain(message, prepared_context, 
langfuse_handler=langfuse_handler) system_prompt = _get_system_prompt("floating_system", _FLOATING_SINGLE_AGENT_SYSTEM) response = await _run_single_agent( user_id=user_id, @@ -835,7 +841,7 @@ async def run_floating_stream( langfuse_handler: Any | None = None, ) -> AsyncGenerator[tuple[str, Any], None]: prepared_context = await _prepare_context(message, context) - domain = await _infer_floating_domain(message, prepared_context) + domain = await _infer_floating_domain(message, prepared_context, langfuse_handler=langfuse_handler) yield "floating_domain", domain system_prompt = _get_system_prompt("floating_system", _FLOATING_SINGLE_AGENT_SYSTEM) diff --git a/services/chat/app/main.py b/services/chat/app/main.py index 1ac3bad..1a41daf 100644 --- a/services/chat/app/main.py +++ b/services/chat/app/main.py @@ -31,6 +31,11 @@ logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING) @asynccontextmanager async def lifespan(app: FastAPI): + # Initialise Langfuse tracing (no-op if keys are missing) + from app.tracing import init_langfuse + + init_langfuse() + # Start Redis consumer in background from app.redis_consumer import start_consumer diff --git a/services/chat/app/redis_consumer.py b/services/chat/app/redis_consumer.py index 70689d2..b663b0a 100644 --- a/services/chat/app/redis_consumer.py +++ b/services/chat/app/redis_consumer.py @@ -85,52 +85,51 @@ async def _handle_home_request(user_id: str, frame: dict) -> None: user_id, request_id, message[:200], ) - # Create Langfuse trace - trace = tracing.create_trace( + response_chunks: list[str] = [] + + with tracing.trace_span( name="home_request", user_id=user_id, session_id=session_id, trace_id=request_id, + input=message, metadata={"message_preview": message[:200]}, tags=["home"], - ) - langfuse_handler = tracing.get_langfuse_callback( - trace=trace, span_name="home_agent", - ) + ) as span: + langfuse_handler = tracing.get_langfuse_callback() - # Enrich with memory context - async with async_session() as db: - 
memory = MemoryMiddleware(db) - memory_context = await memory.enrich_context( - user_id, message, - trace_id=request_id, session_id=session_id, - ) + # Enrich with memory context + async with async_session() as db: + memory = MemoryMiddleware(db) + memory_context = await memory.enrich_context( + user_id, message, + trace_id=request_id, session_id=session_id, + ) - context: dict = { - "conversation_history": frame.get("conversation_history", []), - "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id}, - **memory_context, - } + context: dict = { + "conversation_history": frame.get("conversation_history", []), + "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id}, + **memory_context, + } - set_current_user(user_id) - response_chunks: list[str] = [] - try: - event_stream = run_home_stream(user_id, message, context, langfuse_handler=langfuse_handler) - formatter = StreamFormatter(request_id=request_id) - async for ws_frame in formatter.format(event_stream): - await _publish_frame(user_id, ws_frame.model_dump_json()) - if hasattr(ws_frame, "chunk"): - response_chunks.append(ws_frame.chunk) - except Exception as exc: - logger.error("redis_consumer: home_request failed user=%s req=%s: %s", user_id, request_id, exc) - finally: - clear_current_user() + set_current_user(user_id) + try: + event_stream = run_home_stream(user_id, message, context, langfuse_handler=langfuse_handler) + formatter = StreamFormatter(request_id=request_id) + async for ws_frame in formatter.format(event_stream): + await _publish_frame(user_id, ws_frame.model_dump_json()) + if hasattr(ws_frame, "chunk"): + response_chunks.append(ws_frame.chunk) + except Exception as exc: + logger.error("redis_consumer: home_request failed user=%s req=%s: %s", user_id, request_id, exc) + finally: + clear_current_user() - # Link prompt and flush trace - if trace is not None: - tracing.link_prompt_to_trace(trace, "home_system") + # Link prompt and attach output 
preview + tracing.link_prompt_to_trace(span, "home_system") response_text = "".join(response_chunks) - trace.update(output=response_text[:500] if response_text else None) + span.update(output=response_text[:500] if response_text else None) + tracing.flush() # Store episode @@ -154,52 +153,51 @@ async def _handle_floating_request(user_id: str, frame: dict) -> None: user_id, request_id, json.dumps(scope)[:200], message[:200], ) - # Create Langfuse trace - trace = tracing.create_trace( + response_chunks: list[str] = [] + + with tracing.trace_span( name="floating_request", user_id=user_id, session_id=session_id, trace_id=request_id, + input=message, metadata={"message_preview": message[:200], "scope": scope}, tags=["floating"], - ) - langfuse_handler = tracing.get_langfuse_callback( - trace=trace, span_name="floating_agent", - ) + ) as span: + langfuse_handler = tracing.get_langfuse_callback() - # Enrich with memory context - async with async_session() as db: - memory = MemoryMiddleware(db) - memory_context = await memory.enrich_context( - user_id, message, - trace_id=request_id, session_id=session_id, - ) + # Enrich with memory context + async with async_session() as db: + memory = MemoryMiddleware(db) + memory_context = await memory.enrich_context( + user_id, message, + trace_id=request_id, session_id=session_id, + ) - context: dict = { - "scope": scope, - "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id}, - **memory_context, - } + context: dict = { + "scope": scope, + "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id}, + **memory_context, + } - set_current_user(user_id) - response_chunks: list[str] = [] - try: - event_stream = run_floating_stream(user_id, message, context, langfuse_handler=langfuse_handler) - formatter = StreamFormatter(request_id=request_id) - async for ws_frame in formatter.format(event_stream): - await _publish_frame(user_id, ws_frame.model_dump_json()) - if hasattr(ws_frame, 
"chunk"): - response_chunks.append(ws_frame.chunk) - except Exception as exc: - logger.error("redis_consumer: floating_request failed user=%s req=%s: %s", user_id, request_id, exc) - finally: - clear_current_user() + set_current_user(user_id) + try: + event_stream = run_floating_stream(user_id, message, context, langfuse_handler=langfuse_handler) + formatter = StreamFormatter(request_id=request_id) + async for ws_frame in formatter.format(event_stream): + await _publish_frame(user_id, ws_frame.model_dump_json()) + if hasattr(ws_frame, "chunk"): + response_chunks.append(ws_frame.chunk) + except Exception as exc: + logger.error("redis_consumer: floating_request failed user=%s req=%s: %s", user_id, request_id, exc) + finally: + clear_current_user() - # Link prompt and flush trace - if trace is not None: - tracing.link_prompt_to_trace(trace, "floating_system") + # Link prompt and attach output preview + tracing.link_prompt_to_trace(span, "floating_system") response_text = "".join(response_chunks) - trace.update(output=response_text[:500] if response_text else None) + span.update(output=response_text[:500] if response_text else None) + tracing.flush() # Store episode diff --git a/services/chat/app/tracing.py b/services/chat/app/tracing.py index aa95d28..d115d5b 100644 --- a/services/chat/app/tracing.py +++ b/services/chat/app/tracing.py @@ -1,137 +1,156 @@ -"""Langfuse tracing & prompt management for the Chat Service. +"""Langfuse tracing & prompt management for the Chat Service (v4 SDK). 
Provides: -- ``langfuse`` — singleton Langfuse client (lazy, no-op when keys are missing) -- ``create_trace()`` — start a new trace for a chat request -- ``get_langfuse_callback()`` — LangChain callback handler for a trace/span +- ``init_langfuse()`` — initialise the singleton client at startup +- ``trace_span()`` — context manager that creates a trace + span +- ``get_langfuse_callback()`` — LangChain callback handler (auto-inherits trace) - ``get_prompt()`` — fetch a managed prompt from Langfuse by name -- ``flush()`` — ensure all events are sent before shutdown +- ``flush()`` / ``shutdown()`` — lifecycle management All functions gracefully degrade to no-ops when Langfuse is not configured, so the service works identically with or without observability keys. + +Requires ``langfuse >= 3.0.0`` (v4 / "Fast Preview" SDK). """ from __future__ import annotations import logging +from contextlib import contextmanager from typing import Any from shared.config import settings logger = logging.getLogger(__name__) -# ── Lazy singleton ─────────────────────────────────────────────────────── +# ── State ──────────────────────────────────────────────────────────────── -_langfuse_client: Any | None = None -_langfuse_disabled: bool = False +_initialised: bool = False +_disabled: bool = False def _is_configured() -> bool: return bool(settings.LANGFUSE_SECRET_KEY and settings.LANGFUSE_PUBLIC_KEY) -def _get_langfuse() -> Any | None: - """Return the Langfuse client singleton, or None if not configured.""" - global _langfuse_client, _langfuse_disabled +def init_langfuse() -> None: + """Initialise the Langfuse singleton. 
Call once at startup.""" + global _initialised, _disabled - if _langfuse_disabled: - return None - - if _langfuse_client is not None: - return _langfuse_client + if _initialised or _disabled: + return if not _is_configured(): - _langfuse_disabled = True + _disabled = True logger.info("tracing: Langfuse keys not set — tracing disabled") - return None + return try: from langfuse import Langfuse - _langfuse_client = Langfuse( + Langfuse( secret_key=settings.LANGFUSE_SECRET_KEY, public_key=settings.LANGFUSE_PUBLIC_KEY, host=settings.LANGFUSE_HOST, ) + _initialised = True logger.info("tracing: Langfuse client initialised (host=%s)", settings.LANGFUSE_HOST) - return _langfuse_client except Exception as exc: - _langfuse_disabled = True + _disabled = True logger.warning("tracing: failed to initialise Langfuse: %s", exc) + + +def _get_client() -> Any | None: + """Return the singleton Langfuse client, or *None* if disabled.""" + if _disabled: + return None + if not _initialised: + init_langfuse() + if _disabled: + return None + try: + from langfuse import get_client + return get_client() + except Exception: return None -# ── Trace lifecycle ────────────────────────────────────────────────────── +# ── Null span (no-op when Langfuse is disabled) ───────────────────────── -def create_trace( +class _NullSpan: + """Drop-in replacement when Langfuse is disabled.""" + + def update(self, **_: Any) -> None: ... + def set_trace_io(self, **_: Any) -> None: ... + def score_trace(self, **_: Any) -> None: ... + + +# ── Trace context manager ─────────────────────────────────────────────── + + +@contextmanager +def trace_span( *, name: str, user_id: str, session_id: str | None = None, trace_id: str | None = None, + input: Any = None, metadata: dict[str, Any] | None = None, tags: list[str] | None = None, -) -> Any | None: - """Create a Langfuse trace. Returns the trace object, or None if disabled.""" - lf = _get_langfuse() +): + """Context manager that creates a Langfuse trace/span. 
+ + Yields the span object (or a ``_NullSpan`` if Langfuse is disabled). + A ``CallbackHandler`` created inside this block auto-inherits the trace + context, so there is no need to pass trace IDs manually. + """ + lf = _get_client() if lf is None: - return None + yield _NullSpan() + return try: - return lf.trace( - id=trace_id, + from langfuse import Langfuse, propagate_attributes + + trace_ctx: dict[str, str] = {} + if trace_id is not None: + trace_ctx["trace_id"] = Langfuse.create_trace_id(seed=trace_id) + + with lf.start_as_current_observation( + as_type="span", name=name, - user_id=user_id, - session_id=session_id, + input=input, metadata=metadata or {}, - tags=tags or [], - ) + **({"trace_context": trace_ctx} if trace_ctx else {}), + ) as span: + with propagate_attributes( + user_id=user_id, + session_id=session_id, + tags=tags or [], + ): + yield span except Exception as exc: - logger.warning("tracing: create_trace failed: %s", exc) - return None + logger.warning("tracing: trace_span(%s) failed: %s", name, exc) + yield _NullSpan() # ── LangChain callback handler ────────────────────────────────────────── -def get_langfuse_callback( - *, - trace_id: str | None = None, - trace: Any | None = None, - span_name: str | None = None, - update_parent: bool = True, -) -> Any | None: - """Return a ``CallbackHandler`` wired to an existing trace. +def get_langfuse_callback() -> Any | None: + """Return a LangChain ``CallbackHandler`` that auto-inherits the current trace. - This handler is passed to LangChain's ``ainvoke`` / ``astream`` as a - callback so every LLM generation and tool call is automatically - captured as a nested span inside the trace. - - If both *trace* and *trace_id* are given, *trace* takes precedence. - Returns None when Langfuse is disabled. + Must be called inside a ``trace_span()`` block for proper linking. + Returns *None* when Langfuse is disabled. 
""" - lf = _get_langfuse() - if lf is None: + if _disabled and not _initialised: return None try: - from langfuse.callback import CallbackHandler - - kwargs: dict[str, Any] = { - "secret_key": settings.LANGFUSE_SECRET_KEY, - "public_key": settings.LANGFUSE_PUBLIC_KEY, - "host": settings.LANGFUSE_HOST, - "update_parent": update_parent, - } - if trace is not None: - kwargs["trace_id"] = trace.id - elif trace_id is not None: - kwargs["trace_id"] = trace_id - if span_name: - kwargs["root_span"] = span_name - - return CallbackHandler(**kwargs) + from langfuse.langchain import CallbackHandler + return CallbackHandler() except Exception as exc: logger.warning("tracing: get_langfuse_callback failed: %s", exc) return None @@ -152,21 +171,8 @@ def get_prompt( Returns the compiled prompt string, or *fallback* if the prompt is not found or Langfuse is disabled. - - Parameters - ---------- - name : str - Prompt name as registered in Langfuse. - version : int, optional - Pin to a specific version; omit for the latest production version. - label : str, optional - Fetch by label (e.g. ``"production"``, ``"staging"``). - fallback : str, optional - Value returned when the prompt cannot be fetched. - cache_ttl_seconds : int - How long to cache the prompt locally (default 5 min). """ - lf = _get_langfuse() + lf = _get_client() if lf is None: return fallback @@ -187,20 +193,15 @@ def get_prompt( def link_prompt_to_trace( - trace: Any, + span: Any, prompt_name: str, *, version: int | None = None, label: str | None = None, ) -> None: - """Attach a Langfuse prompt reference to a trace/generation. - - Call this *after* creating a generation on the trace to associate the - prompt that was used. The prompt object is fetched and linked so - Langfuse can display prompt→trace associations in the dashboard. 
- """ - lf = _get_langfuse() - if lf is None or trace is None: + """Attach prompt metadata to a span/trace.""" + lf = _get_client() + if lf is None or isinstance(span, _NullSpan): return try: @@ -210,7 +211,7 @@ def link_prompt_to_trace( if label is not None: kwargs["label"] = label prompt = lf.get_prompt(**kwargs) - trace.update(metadata={"prompt": {"name": prompt_name, "version": prompt.version}}) + span.update(metadata={"prompt": {"name": prompt_name, "version": prompt.version}}) except Exception as exc: logger.warning("tracing: link_prompt_to_trace(%s) failed: %s", prompt_name, exc) @@ -226,12 +227,12 @@ def score_trace( comment: str | None = None, ) -> None: """Post a score to a trace (e.g. user feedback, latency, quality).""" - lf = _get_langfuse() + lf = _get_client() if lf is None: return try: - lf.score(trace_id=trace_id, name=name, value=value, comment=comment) + lf.create_score(trace_id=trace_id, name=name, value=value, comment=comment) except Exception as exc: logger.warning("tracing: score_trace failed: %s", exc) @@ -240,22 +241,24 @@ def score_trace( def flush() -> None: - """Flush pending Langfuse events. 
Call this on service shutdown.""" - if _langfuse_client is not None: + """Flush pending Langfuse events.""" + lf = _get_client() + if lf is not None: try: - _langfuse_client.flush() + lf.flush() except Exception as exc: logger.warning("tracing: flush failed: %s", exc) def shutdown() -> None: """Flush and close the Langfuse client.""" - global _langfuse_client, _langfuse_disabled - if _langfuse_client is not None: + global _initialised, _disabled + lf = _get_client() + if lf is not None: try: - _langfuse_client.flush() - _langfuse_client.shutdown() + lf.flush() + lf.shutdown() except Exception as exc: logger.warning("tracing: shutdown failed: %s", exc) - _langfuse_client = None - _langfuse_disabled = False + _initialised = False + _disabled = False diff --git a/services/chat/requirements.txt b/services/chat/requirements.txt index 21d884c..b57bb10 100644 --- a/services/chat/requirements.txt +++ b/services/chat/requirements.txt @@ -14,4 +14,4 @@ langchain-litellm>=0.3.0 litellm>=1.50.0 openai>=1.50.0 httpx>=0.27.0 -langfuse>=2.0.0 +langfuse>=3.0.0 diff --git a/services/ws-gateway/app/main.py b/services/ws-gateway/app/main.py index 49ff4d9..00ced31 100644 --- a/services/ws-gateway/app/main.py +++ b/services/ws-gateway/app/main.py @@ -6,8 +6,15 @@ and routes frames between Electron and downstream services via Redis pub/sub. This service has NO business logic — it only routes JSON frames. """ +import sys from contextlib import asynccontextmanager import logging +from pathlib import Path + +# Ensure the repo root is on sys.path so "shared" is importable in local dev. 
+_repo_root = str(Path(__file__).resolve().parents[3]) +if _repo_root not in sys.path: + sys.path.insert(0, _repo_root) from fastapi import FastAPI from shared.config import settings diff --git a/tests/test_e2e_flow.py b/tests/test_e2e_flow.py new file mode 100644 index 0000000..d961d3b --- /dev/null +++ b/tests/test_e2e_flow.py @@ -0,0 +1,124 @@ +"""End-to-end test: Auth → WS Gateway → Chat Service round-trip. + +Usage (from repo root, with venv activated): + python tests/test_e2e_flow.py + +Requires: Auth (8001), WS Gateway (8002), Chat (8003) all running. +""" + +import asyncio +import json +import uuid + +import httpx +import websockets + +AUTH_URL = "http://127.0.0.1:8001/api/v1/auth" +WS_URL = "ws://127.0.0.1:8002/api/v1/ws/device" + +# ── 1. Authenticate ───────────────────────────────────────────────── + + +async def get_token() -> str: + async with httpx.AsyncClient() as client: + # Try login first, register if user doesn't exist + resp = await client.post( + f"{AUTH_URL}/login", + json={"email": "e2e@test.com", "password": "Test1234!"}, + ) + if resp.status_code == 200: + print("[1/4] Logged in as e2e@test.com") + return resp.json()["access_token"] + + resp = await client.post( + f"{AUTH_URL}/register", + json={ + "email": "e2e@test.com", + "password": "Test1234!", + "name": "E2E", + "surname": "Test", + }, + ) + resp.raise_for_status() + print("[1/4] Registered + logged in as e2e@test.com") + return resp.json()["access_token"] + + +# ── 2. 
WebSocket flow ─────────────────────────────────────────────── + + +async def run_e2e(): + token = await get_token() + + uri = f"{WS_URL}?token={token}" + async with websockets.connect(uri) as ws: + # Send device_hello + await ws.send(json.dumps({ + "type": "device_hello", + "device_id": str(uuid.uuid4()), + "agent_ids": ["task", "note", "project", "timeline"], + })) + print("[2/4] Device registered with WS Gateway") + + # Send a home_request (simple greeting — unlikely to need tools) + await ws.send(json.dumps({ + "type": "home_request", + "message": "Hello! How are you doing today?", + "context": {}, + })) + print("[3/4] Sent home_request → waiting for Chat Service response...") + + # Listen for response frames (text_chunk, tool_call, final) + full_response = [] + try: + while True: + raw = await asyncio.wait_for(ws.recv(), timeout=60) + frame = json.loads(raw) + ftype = frame.get("type") + + if ftype == "text_chunk": + chunk = frame.get("chunk", frame.get("text", "")) + full_response.append(chunk) + print(f" ← text_chunk: {chunk[:80]}") + + elif ftype == "tool_call": + # Respond with a mock tool_result so the agent doesn't hang + call_id = frame.get("id") + action = frame.get("action") + table = frame.get("table", "") + print(f" ← tool_call: {action} {table} (id={call_id})") + + mock_result = {"rows": [], "row": None} + await ws.send(json.dumps({ + "type": "tool_result", + "id": call_id, + **mock_result, + })) + print(f" → tool_result (mock) for {call_id}") + + elif ftype == "final": + text = frame.get("text", "") + if text: + full_response.append(text) + print(f" ← final") + break + + elif ftype == "ping": + # Ignore heartbeats + continue + + else: + print(f" ← {ftype}: {json.dumps(frame)[:120]}") + + except asyncio.TimeoutError: + print(" ⚠ Timed out waiting for response (60s)") + + print() + if full_response: + print(f"[4/4] Full response: {''.join(full_response)}") + else: + print("[4/4] No text response received (check Chat Service logs)") + + +if __name__ 
== "__main__": + asyncio.run(run_e2e())