"""Chat routes: POST /chat and WebSocket /chat/stream.""" from __future__ import annotations import asyncio import json from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect from fastapi.responses import JSONResponse from jose import JWTError, jwt from app.api.deps import get_current_user from app.config.settings import settings from app.core.orchestrator import orchestrate, orchestrate_stream from app.schemas import ChatRequest, UserProfile router = APIRouter(prefix="/chat", tags=["chat"]) _HEARTBEAT_INTERVAL = 30 # seconds @router.post("") async def chat( body: ChatRequest, current_user: UserProfile = Depends(get_current_user), ) -> JSONResponse: """Route a chat message through the orchestrator. Returns ``ChatResponse`` for ``execution_mode='direct'``, or ``ExecutionPlan`` for ``execution_mode='plan'``. """ result = await orchestrate(body) return JSONResponse(content=result.model_dump()) @router.websocket("/stream") async def chat_stream(websocket: WebSocket) -> None: """Streaming chat via WebSocket. Auth: ``?token=`` query param (Bearer not possible during WS handshake). Protocol: 1. Client sends ``ChatRequest`` as the first JSON text frame. 2. Server streams response text chunks. 3. Final frame: JSON ``{"done": true, "response": "...", "actions": [...]}``. 4. Server pings every 30 s to keep the connection alive. """ # Authenticate before accepting the connection token = websocket.query_params.get("token", "") try: payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]) user_id: str | None = payload.get("sub") if not user_id: raise JWTError("missing sub") except JWTError: await websocket.close(code=1008) # 1008 = Policy Violation return await websocket.accept() try: raw = await websocket.receive_text() body = ChatRequest.model_validate_json(raw) async def _heartbeat() -> None: while True: await asyncio.sleep(_HEARTBEAT_INTERVAL) await websocket.send_text(json.dumps({"ping": True})) heartbeat_task = asyncio.create_task(_heartbeat()) try: async for chunk in orchestrate_stream(body): await websocket.send_text(chunk) finally: heartbeat_task.cancel() except WebSocketDisconnect: pass