79 lines
2.4 KiB
Python
79 lines
2.4 KiB
Python
"""Chat routes: POST /chat and WebSocket /chat/stream."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
|
|
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
|
from fastapi.responses import JSONResponse
|
|
from jose import JWTError, jwt
|
|
|
|
from app.api.deps import get_current_user
|
|
from app.config.settings import settings
|
|
from app.core.orchestrator import orchestrate, orchestrate_stream
|
|
from app.schemas import ChatRequest, UserProfile
|
|
|
|
# Router for the chat endpoints; every route registered on it is served under
# the "/chat" prefix and tagged "chat".
router = APIRouter(prefix="/chat", tags=["chat"])
|
|
|
|
# Delay between keep-alive ping frames sent by the WebSocket heartbeat loop.
_HEARTBEAT_INTERVAL = 30 # seconds
|
|
|
|
|
|
@router.post("")
async def chat(
    body: ChatRequest,
    current_user: UserProfile = Depends(get_current_user),
) -> JSONResponse:
    """Route a chat message through the orchestrator.

    Args:
        body: The chat request parsed from the JSON request body.
        current_user: Resolved by the ``get_current_user`` dependency. It is
            not read in this function; declaring it makes the dependency run
            for every request to this route.

    Returns:
        A ``JSONResponse`` wrapping the orchestrator result: ``ChatResponse``
        for ``execution_mode='direct'``, or ``ExecutionPlan`` for
        ``execution_mode='plan'``.
    """
    result = await orchestrate(body)
    # ``mode="json"`` makes Pydantic emit JSON-safe primitives (ISO strings for
    # datetimes, plain values for enums/UUIDs, ...). ``JSONResponse`` encodes
    # with the stdlib ``json`` module, which raises ``TypeError`` on the raw
    # Python objects a default-mode dump may contain.
    return JSONResponse(content=result.model_dump(mode="json"))
|
|
|
|
|
|
@router.websocket("/stream")
async def chat_stream(websocket: WebSocket) -> None:
    """Streaming chat via WebSocket.

    Auth: ``?token=<jwt>`` query param (Bearer not possible during WS handshake).
    A token that fails to decode, or that carries no ``sub`` claim, is rejected
    with close code 1008 before the connection is accepted.

    Protocol:
    1. Client sends ``ChatRequest`` as the first JSON text frame. A frame that
       does not validate as a ``ChatRequest`` closes the connection with
       code 1007.
    2. Server streams response text chunks.
    3. Final frame: JSON ``{"done": true, "response": "...", "actions": [...]}``.
    4. While streaming, the server sends a ``{"ping": true}`` text frame every
       ``_HEARTBEAT_INTERVAL`` seconds to keep the connection alive.

    A client disconnect at any point after the accept ends the handler quietly.
    """
    # Authenticate before accepting the connection
    token = websocket.query_params.get("token", "")
    try:
        payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
        user_id: str | None = payload.get("sub")
        if not user_id:
            raise JWTError("missing sub")
    except JWTError:
        await websocket.close(code=1008)  # 1008 = Policy Violation
        return

    await websocket.accept()

    async def _heartbeat() -> None:
        """Send a ping frame every ``_HEARTBEAT_INTERVAL`` seconds until cancelled."""
        while True:
            await asyncio.sleep(_HEARTBEAT_INTERVAL)
            await websocket.send_text(json.dumps({"ping": True}))

    try:
        raw = await websocket.receive_text()
        try:
            body = ChatRequest.model_validate_json(raw)
        except ValueError:
            # Pydantic's ValidationError subclasses ValueError. Reject the
            # malformed request with an explicit close code instead of letting
            # the error escape the handler as an unhandled exception.
            await websocket.close(code=1007)  # 1007 = Invalid frame payload data
            return

        heartbeat_task = asyncio.create_task(_heartbeat())
        try:
            async for chunk in orchestrate_stream(body):
                await websocket.send_text(chunk)
        finally:
            heartbeat_task.cancel()
            # Wait for the cancellation to land so the task never outlives the
            # handler. ``return_exceptions=True`` collects the task's
            # CancelledError (or a send failure it already hit) instead of
            # re-raising it here, keeping the heartbeat best-effort while
            # avoiding "Task exception was never retrieved" warnings.
            await asyncio.gather(heartbeat_task, return_exceptions=True)
    except WebSocketDisconnect:
        # The client went away; there is nothing left to send.
        pass
|