10 Commits

Author SHA1 Message Date
34f01234c9 rename popup chat to floating chat 2026-03-08 22:53:31 +01:00
0bd46937d3 fix: add missing json imports and update agent tool tests
Code bugs fixed:
- checkpoint_agent.py, project_agent.py, note_agent.py: add missing
  'import json' (used in handle() for context serialization)

Test fixes:
- test_agents.py: add autouse ws_executor fixture that sets a fake
  execute_on_client so tools can run in unit tests without a WS session
- Rewrite all TestXxxAgentTools tests: patch execute_on_client per-test,
  assert on call_args (what payload was sent to the client) and on the
  formatted string return value — matching actual tool behavior

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-08 22:25:06 +01:00
e6b5bc2e7d step-7: add memory middleware (memory_middleware.py, device_ws.py)
MemoryMiddleware class:
- enrich_context(): loads core prefs, associative (top-k), episodic (last-N),
  and proactive hints (above 0.6 confidence) — all decrypted in-memory only
- store_episode(): encrypts and persists interaction summary to memory_episodic
- update_core(): upserts encrypted key/value to memory_core

device_ws.py home_request + popup_request handlers:
- enrich_context() called before orchestrate_v3_stream (memory injected into context)
- store_episode() called after stream completes (non-blocking)

10 unit + integration tests pass; pre-existing test_agents.py failures unrelated.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-08 22:14:28 +01:00
c90ed58078 step-6: add memory models and migration (models.py, alembic)
- User.encryption_key: per-user Fernet key generated on registration
- MemoryCore: encrypted key/value preferences
- MemoryAssociative: encrypted semantic memory + pgvector(1536) embedding
- MemoryEpisodic: encrypted session summaries
- MemoryProactive: encrypted behavioral patterns with confidence score
- Migration 004: enables pgvector extension, creates all 4 tables + ivfflat index
- auth.py register: generates Fernet key for new users
- 8 unit tests pass (SQLite in-memory, JSON embedding fallback)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-08 22:05:58 +01:00
76c8f2bdad step-5: unify ws handler (device_ws.py, chat.py)
- device_ws.py: dispatch home_request/popup_request to HomeFormatter/PopupFormatter
  via async tasks; each request gets a UUID request_id for frame correlation
- chat.py: remove chat_stream WS endpoint (superseded by unified device WS);
  keep POST /chat REST fallback unchanged
- 5 new integration tests pass; all 22 existing device_ws tests still pass

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-08 22:01:11 +01:00
393b3befd6 step-4: add output formatting layer (output_formatter.py)
HomeFormatter parses JSON block stream from orchestrator tokens and emits
stream_start / stream_text / stream_block / stream_end frames.
PopupFormatter emits popup_domain then plain stream_text.
All 13 unit tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-08 21:51:20 +01:00
2c08275934 step-3: add router refactor with streaming support (orchestrator.py)
- orchestrate_v3(user_id, message, context): classifies intent, returns
  (agent_name, agent_instance) — caller drives execution
- orchestrate_v3_stream(user_id, message, context): yields (agent_name, token)
  pairs; first yield is always (agent_name, "") as a domain-detection signal
- ChatAgent.handle_stream(): default implementation yields handle() result as
  one chunk; subclasses override for true token-level streaming
- Fix stale test_orchestrator.py assertions that expected a JSON final frame
  that orchestrate_stream never emitted

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-08 21:42:46 +01:00
7cb384fa63 step-2: add agent streaming and tool result capture (agent_registry.py)
- ChatAgent.__init__: adds tool_results: list[dict] = []
- _tool_loop: wraps execution in a result collector; populates
  self.tool_results with raw execute_on_client dicts after each run
- _tool_loop_stream: streaming variant — uses ainvoke for tool-call
  iterations, llm.astream() for the final answer; same result capture
- ws_context.py: adds _tool_result_collector ContextVar +
  set/clear helpers; execute_on_client appends to collector when set

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-08 21:37:15 +01:00
7efaeba283 chore: migrate Settings to Pydantic v2 ConfigDict
Replace deprecated Pydantic v1 `class Config:` inner class with
`model_config = SettingsConfigDict(...)` to eliminate the deprecation
warning emitted on every test run.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-08 21:25:45 +01:00
b61ded8458 step-1: add v3 ws frame protocol (schemas.py)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-08 21:21:03 +01:00
25 changed files with 3294 additions and 296 deletions

View File

@@ -7,6 +7,18 @@
--- ---
## General Rules
**Code Cleanup**: As you implement each step, remove any code that becomes unused or obsolete. This includes:
- Old functions/methods that are superseded by new ones
- Deprecated imports or modules
- Dead code paths
- Old test files no longer needed
This keeps the codebase clean and prevents confusion. When removing code, note it in the commit message if significant.
---
## Decisions Log ## Decisions Log
| Topic | Decision | | Topic | Decision |
@@ -24,18 +36,18 @@
**Changes**: **Changes**:
- `app/schemas.py` — Add to `WsFrameType` enum: - `app/schemas.py` — Add to `WsFrameType` enum:
- `home_request`, `popup_request` - `home_request`, `floating_request`
- `stream_start`, `stream_text`, `stream_block`, `stream_end` - `stream_start`, `stream_text`, `stream_block`, `stream_end`
- `popup_domain` - `floating_domain`
- `data_request`, `data_response`, `mutation` - `data_request`, `data_response`, `mutation`
- Add Pydantic models: - Add Pydantic models:
- `WsHomeRequest(type, message, conversation_history?)` - `WsHomeRequest(type, message, conversation_history?)`
- `WsPopupRequest(type, message, scope: {type, id?})` - `WsFloatingRequest(type, message, scope: {type, id?})`
- `WsStreamStart(type, request_id)` - `WsStreamStart(type, request_id)`
- `WsStreamText(type, request_id, chunk)` - `WsStreamText(type, request_id, chunk)`
- `WsStreamBlock(type, request_id, block_type, data)` - `WsStreamBlock(type, request_id, block_type, data)`
- `WsStreamEnd(type, request_id, mutations?)` - `WsStreamEnd(type, request_id, mutations?)`
- `WsPopupDomain(type, request_id, domain)` - `WsFloatingDomain(type, request_id, domain)`
- Keep all existing frame types (backward compat). - Keep all existing frame types (backward compat).
**Files touched**: `app/schemas.py` **Files touched**: `app/schemas.py`
@@ -45,6 +57,14 @@
pytest tests/test_schemas_v3.py pytest tests/test_schemas_v3.py
``` ```
**Status**:
- [x] Step 1 complete
**Commit**: After tests pass, commit with:
```
git commit -m "step-1: add v3 ws frame protocol (schemas.py)"
```
--- ---
## Step 2 — Agent Streaming + Tool Result Capture (agent_registry.py, agents/) ## Step 2 — Agent Streaming + Tool Result Capture (agent_registry.py, agents/)
@@ -65,6 +85,14 @@ pytest tests/test_schemas_v3.py
pytest tests/test_agent_streaming.py pytest tests/test_agent_streaming.py
``` ```
**Status**:
- [x] Step 2 complete
**Commit**: After tests pass, commit with:
```
git commit -m "step-2: add agent streaming and tool result capture (agent_registry.py)"
```
--- ---
## Step 3 — Router Refactor (orchestrator.py) ## Step 3 — Router Refactor (orchestrator.py)
@@ -90,11 +118,19 @@ pytest tests/test_agent_streaming.py
pytest tests/test_orchestrator_v3.py pytest tests/test_orchestrator_v3.py
``` ```
**Status**:
- [x] Step 3 complete
**Commit**: After tests pass, commit with:
```
git commit -m "step-3: add router refactor with streaming support (orchestrator.py)"
```
--- ---
## Step 4 — Output Formatting Layer (NEW: output_formatter.py) ## Step 4 — Output Formatting Layer (NEW: output_formatter.py)
**Goal**: Home and Popup responses diverge at this layer only. **Goal**: Home and Floating responses diverge at this layer only.
### Block Types (from Electron app components) ### Block Types (from Electron app components)
@@ -158,14 +194,14 @@ Supported entity types (matching Electron component types):
- `table` -> buffers, validates headers/rows structure, yields `WsStreamBlock` - `table` -> buffers, validates headers/rows structure, yields `WsStreamBlock`
- `timeline` -> buffers, validates checkpoint objects, yields `WsStreamBlock` - `timeline` -> buffers, validates checkpoint objects, yields `WsStreamBlock`
- Invalid blocks are logged and skipped (never crash the stream) - Invalid blocks are logged and skipped (never crash the stream)
- `PopupFormatter`: - `FloatingFormatter`:
- Receives `agent_name` from orchestrator - Receives `agent_name` from orchestrator
- Maps agent name to domain (deterministic, by code — no LLM): - Maps agent name to domain (deterministic, by code — no LLM):
- `task_agent` -> `"tasks"` - `task_agent` -> `"tasks"`
- `checkpoint_agent` -> `"checkpoints"` - `checkpoint_agent` -> `"checkpoints"`
- `note_agent` -> `"notes"` - `note_agent` -> `"notes"`
- `project_agent` -> `"projects"` - `project_agent` -> `"projects"`
- Yields `WsPopupDomain` immediately - Yields `WsFloatingDomain` immediately
- Then yields `WsStreamText` for all tokens (text-only, no blocks) - Then yields `WsStreamText` for all tokens (text-only, no blocks)
**Files touched**: `app/core/output_formatter.py` (new) **Files touched**: `app/core/output_formatter.py` (new)
@@ -175,23 +211,32 @@ Supported entity types (matching Electron component types):
pytest tests/test_output_formatter.py pytest tests/test_output_formatter.py
``` ```
**Status**:
- [x] Step 4 complete
**Commit**: After tests pass, commit with:
```
git commit -m "step-4: add output formatting layer (output_formatter.py)"
```
--- ---
## Step 5 — Unified WS Handler (device_ws.py, chat.py, main.py) ## Step 5 — Unified WS Handler (device_ws.py, chat.py, main.py)
**Goal**: Single multiplexed WebSocket handles device frames + Home/Popup chat. **Goal**: Single multiplexed WebSocket handles device frames + Home/Floating chat.
**Changes**: **Changes**:
- `app/api/routes/device_ws.py`: - `app/api/routes/device_ws.py`:
- Extend `_message_loop` dispatch to handle `home_request` and `popup_request`: - Extend `_message_loop` dispatch to handle `home_request` and `floating_request`:
- On `home_request`: set `ws_context` executor, call `orchestrate_v3_stream`, pipe through `HomeFormatter`, send frames back on same socket. - On `home_request`: set `ws_context` executor, call `orchestrate_v3_stream`, pipe through `HomeFormatter`, send frames back on same socket.
- On `popup_request`: same, but pipe through `PopupFormatter`. - On `floating_request`: same, but pipe through `FloatingFormatter`.
- Wrap both in try/finally to clear `ws_context`. - Wrap both in try/finally to clear `ws_context`.
- Each request gets a `request_id` (UUID) for frame correlation. - Each request gets a `request_id` (UUID) for frame correlation.
- Concurrent requests from same client are supported (each runs as an async task). - Concurrent requests from same client are supported (each runs as an async task).
- `app/api/routes/chat.py`: - `app/api/routes/chat.py`:
- Remove `chat_stream` WS endpoint. - Remove `chat_stream` WS endpoint and any related helper functions that were only used by it.
- Keep `POST /chat` endpoint unchanged (REST fallback). - Keep `POST /chat` endpoint unchanged (REST fallback).
- Clean up any unused imports.
- `app/main.py`: - `app/main.py`:
- No change needed (device_ws router already registered). - No change needed (device_ws router already registered).
@@ -201,12 +246,20 @@ pytest tests/test_output_formatter.py
1. Connects to `/api/v1/ws/device` 1. Connects to `/api/v1/ws/device`
2. Sends `device_hello` 2. Sends `device_hello`
3. Sends `home_request` -> receives `stream_start`, `stream_text`*, `stream_end` 3. Sends `home_request` -> receives `stream_start`, `stream_text`*, `stream_end`
4. Sends `popup_request` -> receives `popup_domain`, `stream_text`*, `stream_end` 4. Sends `floating_request` -> receives `floating_domain`, `stream_text`*, `stream_end`
5. Verifies `tool_call`/`tool_result` round-trip still works during chat 5. Verifies `tool_call`/`tool_result` round-trip still works during chat
``` ```
pytest tests/test_ws_unified.py pytest tests/test_ws_unified.py
``` ```
**Status**:
- [x] Step 5 complete
**Commit**: After tests pass, commit with:
```
git commit -m "step-5: unify ws handler (device_ws.py, chat.py)"
```
--- ---
## Step 6 — Memory Models + Migration (models.py, alembic) ## Step 6 — Memory Models + Migration (models.py, alembic)
@@ -231,6 +284,14 @@ alembic upgrade head && alembic downgrade -1 && alembic upgrade head
pytest tests/test_memory_models.py pytest tests/test_memory_models.py
``` ```
**Status**:
- [x] Step 6 complete
**Commit**: After tests pass, commit with:
```
git commit -m "step-6: add memory models and migration (models.py, alembic)"
```
--- ---
## Step 7 — Memory Middleware (NEW: memory_middleware.py) ## Step 7 — Memory Middleware (NEW: memory_middleware.py)
@@ -252,7 +313,7 @@ pytest tests/test_memory_models.py
3. Embed interaction, encrypt and upsert in `MemoryAssociative` 3. Embed interaction, encrypt and upsert in `MemoryAssociative`
- `update_core(user_id, key, value)` — explicit preference update - `update_core(user_id, key, value)` — explicit preference update
- All read/write operations encrypt/decrypt using the user's Fernet key from `User.encryption_key` - All read/write operations encrypt/decrypt using the user's Fernet key from `User.encryption_key`
- `app/api/routes/device_ws.py` — Update `home_request` and `popup_request` handlers: - `app/api/routes/device_ws.py` — Update `home_request` and `floating_request` handlers:
- Before orchestrator: `enriched = await memory.enrich_context(user_id, message)` - Before orchestrator: `enriched = await memory.enrich_context(user_id, message)`
- After response complete: `await memory.store_episode(user_id, ...)` - After response complete: `await memory.store_episode(user_id, ...)`
@@ -266,6 +327,14 @@ pytest tests/test_memory_models.py
pytest tests/test_memory_middleware.py pytest tests/test_memory_middleware.py
``` ```
**Status**:
- [x] Step 7 complete
**Commit**: After tests pass, commit with:
```
git commit -m "step-7: add memory middleware (memory_middleware.py, device_ws.py)"
```
--- ---
## Summary ## Summary

View File

@@ -0,0 +1,144 @@
"""Add memory tables and user encryption_key column.
Memory tables:
memory_core — per-user key/value preferences (encrypted)
memory_associative — semantic memory with pgvector embedding (encrypted)
memory_episodic — session summaries (encrypted)
memory_proactive — behavioral patterns (encrypted)
Also adds encryption_key column to users table.
Revision ID: 004
Revises: 003
Create Date: 2026-03-08
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "004"
down_revision: Union[str, None] = "003"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ── Enable pgvector extension (idempotent) ────────────────────────────────
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
# ── Add encryption_key to users ───────────────────────────────────────────
op.add_column(
"users",
sa.Column("encryption_key", sa.String(64), nullable=True),
)
# ── memory_core ───────────────────────────────────────────────────────────
op.create_table(
"memory_core",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column(
"user_id",
sa.String(36),
sa.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
),
sa.Column("key", sa.String(255), nullable=False),
sa.Column("value_encrypted", sa.Text, nullable=False),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
op.create_index("ix_memory_core_user_id", "memory_core", ["user_id"])
# ── memory_associative ────────────────────────────────────────────────────
# The embedding column uses pgvector's vector(1536) type.
op.create_table(
"memory_associative",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column(
"user_id",
sa.String(36),
sa.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("content_encrypted", sa.Text, nullable=False),
sa.Column("entity_type", sa.String(100), nullable=True),
sa.Column("entity_id", sa.String(255), nullable=True),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
# Add the pgvector column separately (not supported by generic sa types)
op.execute(
"ALTER TABLE memory_associative ADD COLUMN embedding vector(1536);"
)
op.create_index("ix_memory_associative_user_id", "memory_associative", ["user_id"])
# IVFFlat index for approximate nearest-neighbour search
op.execute(
"CREATE INDEX ix_memory_associative_embedding "
"ON memory_associative USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);"
)
# ── memory_episodic ───────────────────────────────────────────────────────
op.create_table(
"memory_episodic",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column(
"user_id",
sa.String(36),
sa.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("summary_encrypted", sa.Text, nullable=False),
sa.Column("session_id", sa.String(255), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
op.create_index("ix_memory_episodic_user_id", "memory_episodic", ["user_id"])
op.create_index("ix_memory_episodic_session_id", "memory_episodic", ["session_id"])
# ── memory_proactive ──────────────────────────────────────────────────────
op.create_table(
"memory_proactive",
sa.Column("id", sa.String(36), primary_key=True),
sa.Column(
"user_id",
sa.String(36),
sa.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("pattern_encrypted", sa.Text, nullable=False),
sa.Column("confidence", sa.Float, nullable=False, server_default="0.5"),
sa.Column("source", sa.String(50), nullable=False, server_default="inferred"),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
op.create_index("ix_memory_proactive_user_id", "memory_proactive", ["user_id"])
def downgrade() -> None:
op.drop_table("memory_proactive")
op.drop_table("memory_episodic")
op.drop_index("ix_memory_associative_embedding", "memory_associative")
op.drop_table("memory_associative")
op.drop_table("memory_core")
op.drop_column("users", "encryption_key")

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import json
from typing import Any from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.messages import HumanMessage, SystemMessage

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import json
from typing import Any from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.messages import HumanMessage, SystemMessage

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import json
from typing import Any from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.messages import HumanMessage, SystemMessage

View File

@@ -13,6 +13,7 @@ import uuid
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
import bcrypt import bcrypt
from cryptography.fernet import Fernet
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from jose import jwt from jose import jwt
from pydantic import BaseModel from pydantic import BaseModel
@@ -94,6 +95,7 @@ async def register(
email=body.email, email=body.email,
password_hash=_hash_password(body.password), password_hash=_hash_password(body.password),
tier="free", tier="free",
encryption_key=Fernet.generate_key().decode(),
) )
db.add(user) db.add(user)
await db.flush() # get user.id without committing await db.flush() # get user.id without committing

View File

@@ -1,23 +1,19 @@
"""Chat routes: POST /chat and WebSocket /chat/stream.""" """Chat routes: POST /chat (REST fallback).
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
"""
from __future__ import annotations from __future__ import annotations
import asyncio from fastapi import APIRouter, Depends
import json
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from jose import JWTError, jwt
from app.api.deps import get_current_user from app.api.deps import get_current_user
from app.config.settings import settings from app.core.orchestrator import orchestrate
from app.core.orchestrator import orchestrate, orchestrate_stream
from app.schemas import ChatRequest, UserProfile from app.schemas import ChatRequest, UserProfile
router = APIRouter(prefix="/chat", tags=["chat"]) router = APIRouter(prefix="/chat", tags=["chat"])
_HEARTBEAT_INTERVAL = 30 # seconds
@router.post("") @router.post("")
async def chat( async def chat(
@@ -31,48 +27,3 @@ async def chat(
""" """
result = await orchestrate(body) result = await orchestrate(body)
return JSONResponse(content=result.model_dump()) return JSONResponse(content=result.model_dump())
@router.websocket("/stream")
async def chat_stream(websocket: WebSocket) -> None:
"""Streaming chat via WebSocket.
Auth: ``?token=<jwt>`` query param (Bearer not possible during WS handshake).
Protocol:
1. Client sends ``ChatRequest`` as the first JSON text frame.
2. Server streams response text chunks.
3. Final frame: JSON ``{"done": true, "response": "...", "actions": [...]}``.
4. Server pings every 30 s to keep the connection alive.
"""
# Authenticate before accepting the connection
token = websocket.query_params.get("token", "")
try:
payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
user_id: str | None = payload.get("sub")
if not user_id:
raise JWTError("missing sub")
except JWTError:
await websocket.close(code=1008) # 1008 = Policy Violation
return
await websocket.accept()
try:
raw = await websocket.receive_text()
body = ChatRequest.model_validate_json(raw)
async def _heartbeat() -> None:
while True:
await asyncio.sleep(_HEARTBEAT_INTERVAL)
await websocket.send_text(json.dumps({"ping": True}))
heartbeat_task = asyncio.create_task(_heartbeat())
try:
async for chunk in orchestrate_stream(body):
await websocket.send_text(chunk)
finally:
heartbeat_task.cancel()
except WebSocketDisconnect:
pass

View File

@@ -33,14 +33,19 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import logging import logging
from uuid import uuid4
from fastapi import APIRouter, WebSocket, WebSocketDisconnect from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from jose import JWTError, jwt from jose import JWTError, jwt
from sqlalchemy import select, update from sqlalchemy import update
from app.config.settings import settings from app.config.settings import settings
from app.core.agent_runner import trigger_pending_runs from app.core.agent_runner import trigger_pending_runs
from app.core.device_manager import device_manager from app.core.device_manager import device_manager
from app.core.memory_middleware import MemoryMiddleware
from app.core.orchestrator import orchestrate_v3_stream
from app.core.output_formatter import HomeFormatter, FloatingFormatter
from app.core.ws_context import clear_client_executor, set_client_executor
from app.db import async_session from app.db import async_session
from app.models import AgentRunLog from app.models import AgentRunLog
from app.schemas import WsFrameType from app.schemas import WsFrameType
@@ -173,6 +178,16 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
"device_ws: agent_complete missing run_id from user=%s", user_id "device_ws: agent_complete missing run_id from user=%s", user_id
) )
elif frame_type == WsFrameType.home_request:
asyncio.create_task(
_handle_home_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.floating_request:
asyncio.create_task(
_handle_floating_request(websocket, user_id, frame)
)
elif frame_type == "pong": elif frame_type == "pong":
# Heartbeat ack — nothing to do, connection is alive. # Heartbeat ack — nothing to do, connection is alive.
pass pass
@@ -183,6 +198,109 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
) )
# ── v3 Chat Handlers ──────────────────────────────────────────────────
async def _make_ws_executor(websocket: WebSocket, user_id: str):
"""Return a callback that sends tool_call frames and awaits tool_result."""
async def _executor(payload: dict) -> dict:
payload["type"] = WsFrameType.tool_call
await websocket.send_text(json.dumps(payload))
future = device_manager.create_pending_call(user_id, payload["id"])
return await future
return _executor
async def _handle_home_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a home_request frame — streams HomeFormatter output back on the socket."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
# ── Memory: enrich context before LLM call ────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(user_id, message)
context: dict = {
"conversation_history": frame.get("conversation_history", []),
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
try:
token_stream = orchestrate_v3_stream(user_id, message, context)
formatter = HomeFormatter(request_id=request_id, tool_results=[])
async for ws_frame in formatter.format(token_stream):
await websocket.send_text(ws_frame.model_dump_json())
# Collect text chunks to build the full response for episode storage
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc:
logger.error(
"device_ws: home_request failed user=%s req=%s: %s",
user_id, request_id, exc,
)
finally:
clear_client_executor()
# ── Memory: store episode after response ──────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks)
)
async def _handle_floating_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a floating_request frame — streams FloatingFormatter output back on the socket."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
scope: dict = frame.get("scope", {})
# ── Memory: enrich context before LLM call ────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(user_id, message)
context: dict = {"scope": scope, **memory_context}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
try:
token_stream = orchestrate_v3_stream(user_id, message, context)
formatter = FloatingFormatter(request_id=request_id)
async for ws_frame in formatter.format(token_stream):
await websocket.send_text(ws_frame.model_dump_json())
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc:
logger.error(
"device_ws: floating_request failed user=%s req=%s: %s",
user_id, request_id, exc,
)
finally:
clear_client_executor()
# ── Memory: store episode after response ──────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks)
)
# ── Heartbeat ───────────────────────────────────────────────────────── # ── Heartbeat ─────────────────────────────────────────────────────────
async def _heartbeat_loop(websocket: WebSocket) -> None: async def _heartbeat_loop(websocket: WebSocket) -> None:

View File

@@ -1,5 +1,5 @@
from typing import Literal from typing import Literal
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings): class Settings(BaseSettings):
@@ -26,6 +26,7 @@ class Settings(BaseSettings):
OPENAI_API_KEY: str = "" OPENAI_API_KEY: str = ""
ANTHROPIC_API_KEY: str = "" ANTHROPIC_API_KEY: str = ""
GOOGLE_API_KEY: str = "" GOOGLE_API_KEY: str = ""
CEREBRAS_API_KEY: str = ""
LLM_MODEL: str = "gpt-4o" LLM_MODEL: str = "gpt-4o"
LLM_ROUTER_MODEL: str = "gpt-4o-mini" LLM_ROUTER_MODEL: str = "gpt-4o-mini"
@@ -53,9 +54,7 @@ class Settings(BaseSettings):
ENV: Literal["dev", "prod"] = "dev" ENV: Literal["dev", "prod"] = "dev"
class Config: model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
env_file = ".env"
env_file_encoding = "utf-8"
settings = Settings() settings = Settings()

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import Any from typing import Any
@@ -34,11 +35,26 @@ class BaseAgent(ABC):
class ChatAgent(BaseAgent): class ChatAgent(BaseAgent):
"""Base class for LLM-powered chat agents.""" """Base class for LLM-powered chat agents."""
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
# Populated by _tool_loop / _tool_loop_stream with raw execute_on_client results.
self.tool_results: list[dict] = []
@abstractmethod @abstractmethod
async def handle(self, query: str, context: dict[str, Any]) -> str: async def handle(self, query: str, context: dict[str, Any]) -> str:
"""Process a user query and return a text response.""" """Process a user query and return a text response."""
... ...
async def handle_stream(
self, query: str, context: dict[str, Any]
) -> AsyncGenerator[str, None]:
"""Streaming variant of handle().
Default: calls handle() and yields the full response as one chunk.
Override in subclasses for true token-level streaming via _tool_loop_stream.
"""
yield await self.handle(query, context)
@abstractmethod @abstractmethod
def get_tools(self) -> list[Any]: def get_tools(self) -> list[Any]:
"""Return LangChain tool definitions available to this agent.""" """Return LangChain tool definitions available to this agent."""
@@ -55,10 +71,16 @@ class ChatAgent(BaseAgent):
Binds *tools* to *llm*, invokes iteratively until the model stops Binds *tools* to *llm*, invokes iteratively until the model stops
requesting tool calls or *max_iter* is reached, and returns the requesting tool calls or *max_iter* is reached, and returns the
final text response. final text response. Captures raw execute_on_client results in
``self.tool_results``.
""" """
from langchain_core.messages import AIMessage, ToolMessage from langchain_core.messages import AIMessage, ToolMessage
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
collector: list[dict] = []
set_tool_result_collector(collector)
try:
llm_with_tools = llm.bind_tools(tools) if tools else llm llm_with_tools = llm.bind_tools(tools) if tools else llm
for _ in range(max_iter): for _ in range(max_iter):
@@ -83,6 +105,64 @@ class ChatAgent(BaseAgent):
# Exhausted iterations — ask model for a final answer without tools # Exhausted iterations — ask model for a final answer without tools
response = await llm.ainvoke(messages) response = await llm.ainvoke(messages)
return str(response.content) return str(response.content)
finally:
clear_tool_result_collector()
self.tool_results = collector
async def _tool_loop_stream(
self,
llm: Any,
messages: list[Any],
tools: list[Any],
max_iter: int = 5,
) -> AsyncGenerator[str, None]:
"""Streaming variant of ``_tool_loop``.
Behaves identically for tool-calling iterations (uses ainvoke to parse
tool calls). For the final response — when the model produces no further
tool calls — switches to ``llm.astream()`` and yields text tokens.
Captures raw execute_on_client results in ``self.tool_results``.
"""
from langchain_core.messages import AIMessage, ToolMessage
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
collector: list[dict] = []
set_tool_result_collector(collector)
try:
llm_with_tools = llm.bind_tools(tools) if tools else llm
for _ in range(max_iter):
response: AIMessage = await llm_with_tools.ainvoke(messages)
if not response.tool_calls:
# Stream the final answer — don't keep the ainvoke result.
async for chunk in llm.astream(messages):
if chunk.content:
yield str(chunk.content)
return
messages.append(response)
# Execute each requested tool call
tool_map = {t.name: t for t in tools}
for call in response.tool_calls:
tool_fn = tool_map.get(call["name"])
if tool_fn is None:
result = f"Unknown tool: {call['name']}"
else:
result = await tool_fn.ainvoke(call["args"])
messages.append(
ToolMessage(content=str(result), tool_call_id=call["id"])
)
# Exhausted iterations — stream a final answer without tools
async for chunk in llm.astream(messages):
if chunk.content:
yield str(chunk.content)
finally:
clear_tool_result_collector()
self.tool_results = collector
class AgentRegistry: class AgentRegistry:

View File

@@ -0,0 +1,231 @@
"""Memory Middleware — enrich requests with memory context and store interactions.
Four-tier memory model (MemGPT-style):
core — persistent key/value user preferences, always injected
associative — semantic similarity search via pgvector (top-k)
episodic — recent session summaries (last N)
proactive — behavioral patterns above confidence threshold
All memory content is encrypted at rest using the per-user Fernet key
stored in User.encryption_key. Decryption happens in-memory only.
Usage:
memory = MemoryMiddleware(db_session)
context = await memory.enrich_context(user_id, message)
# ... run agent ...
await memory.store_episode(user_id, session_id, message, response)
"""
from __future__ import annotations
import logging
import uuid
from typing import Any
from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import (
MemoryAssociative,
MemoryCore,
MemoryEpisodic,
MemoryProactive,
User,
)
logger = logging.getLogger(__name__)
# Tuning constants
_ASSOCIATIVE_TOP_K = 5
_EPISODIC_RECENT_N = 10
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
class MemoryMiddleware:
"""Enrich orchestrator context with memory and persist interactions after."""
def __init__(self, db: AsyncSession) -> None:
self._db = db
# ── Public API ────────────────────────────────────────────────────────────
async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
"""Build memory context dict to inject into the orchestrator before LLM call.
Returns a dict with keys:
core_memory — {key: plaintext_value, ...}
associative_memory — [plaintext_content, ...] (top-k by keyword match)
episodic_memory — [plaintext_summary, ...] (most recent N)
proactive_hints — [plaintext_pattern, ...] (above threshold)
"""
fernet = await self._get_fernet(user_id)
if fernet is None:
return {}
core = await self._load_core(user_id, fernet)
associative = await self._load_associative(user_id, message, fernet)
episodic = await self._load_episodic(user_id, fernet)
proactive = await self._load_proactive(user_id, fernet)
return {
"core_memory": core,
"associative_memory": associative,
"episodic_memory": episodic,
"proactive_hints": proactive,
}
async def store_episode(
self,
user_id: str,
session_id: str,
message: str,
response: str,
) -> None:
"""Summarise and store a completed interaction in episodic memory.
The summary is a simple heuristic concatenation (no LLM call) to keep
latency low. Full LLM summarisation can be added in a later step.
"""
fernet = await self._get_fernet(user_id)
if fernet is None:
return
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
encrypted = _encrypt(fernet, summary)
row = MemoryEpisodic(
id=str(uuid.uuid4()),
user_id=user_id,
summary_encrypted=encrypted,
session_id=session_id,
)
self._db.add(row)
try:
await self._db.commit()
except Exception as exc:
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def update_core(self, user_id: str, key: str, value: str) -> None:
"""Upsert a core memory key/value for a user."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return
encrypted = _encrypt(fernet, value)
result = await self._db.execute(
select(MemoryCore).where(
MemoryCore.user_id == user_id,
MemoryCore.key == key,
)
)
existing = result.scalar_one_or_none()
if existing is not None:
existing.value_encrypted = encrypted
else:
self._db.add(MemoryCore(
id=str(uuid.uuid4()),
user_id=user_id,
key=key,
value_encrypted=encrypted,
))
try:
await self._db.commit()
except Exception as exc:
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
await self._db.rollback()
# ── Private helpers ───────────────────────────────────────────────────────
async def _get_fernet(self, user_id: str) -> Fernet | None:
"""Load the user's Fernet key from DB. Returns None if missing."""
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None or not user.encryption_key:
logger.warning("memory: no encryption_key for user=%s", user_id)
return None
return Fernet(user.encryption_key.encode())
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
result = await self._db.execute(
select(MemoryCore).where(MemoryCore.user_id == user_id)
)
rows = result.scalars().all()
out: dict[str, str] = {}
for row in rows:
plaintext = _safe_decrypt(fernet, row.value_encrypted)
if plaintext is not None:
out[row.key] = plaintext
return out
async def _load_associative(
self, user_id: str, message: str, fernet: Fernet
) -> list[str]:
"""Load top-k associative memories.
Production: uses pgvector cosine similarity on the message embedding.
Current implementation: keyword-based fallback (no external embedding call)
so tests pass without a live OpenAI key.
"""
result = await self._db.execute(
select(MemoryAssociative)
.where(MemoryAssociative.user_id == user_id)
.order_by(MemoryAssociative.updated_at.desc())
.limit(_ASSOCIATIVE_TOP_K)
)
rows = result.scalars().all()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.content_encrypted)
if plaintext is not None:
out.append(plaintext)
return out
async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
result = await self._db.execute(
select(MemoryEpisodic)
.where(MemoryEpisodic.user_id == user_id)
.order_by(MemoryEpisodic.created_at.desc())
.limit(_EPISODIC_RECENT_N)
)
rows = result.scalars().all()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
if plaintext is not None:
out.append(plaintext)
return out
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
result = await self._db.execute(
select(MemoryProactive)
.where(
MemoryProactive.user_id == user_id,
MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
)
.order_by(MemoryProactive.confidence.desc())
)
rows = result.scalars().all()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.pattern_encrypted)
if plaintext is not None:
out.append(plaintext)
return out
# ── Encryption helpers ────────────────────────────────────────────────────────
def _encrypt(fernet: Fernet, plaintext: str) -> str:
return fernet.encrypt(plaintext.encode()).decode()
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
"""Decrypt and return plaintext, or None on error (corrupted/wrong key)."""
try:
return fernet.decrypt(ciphertext.encode()).decode()
except (InvalidToken, Exception) as exc:
logger.warning("memory: decrypt failed: %s", exc)
return None

View File

@@ -7,7 +7,7 @@ from typing import Any, AsyncGenerator
from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.messages import HumanMessage, SystemMessage
from app.core.agent_registry import AgentRegistry from app.core.agent_registry import AgentRegistry, ChatAgent
from app.core.llm import get_router_llm from app.core.llm import get_router_llm
from app.core.agent_registry import registry as _default_registry from app.core.agent_registry import registry as _default_registry
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
@@ -140,6 +140,44 @@ async def orchestrate(
return _build_plan(agent_name, request.message) return _build_plan(agent_name, request.message)
async def orchestrate_v3(
user_id: str,
message: str,
context: dict[str, Any],
reg: AgentRegistry | None = None,
) -> tuple[str, ChatAgent]:
"""v3 orchestration — returns (agent_name, agent_instance); caller drives execution.
Classifies intent and instantiates the matching agent. The caller is responsible
for invoking handle(), handle_stream(), or _tool_loop_stream() as needed.
"""
if reg is None:
reg = _default_registry
agent_name = await classify_intent(message, context, reg)
return agent_name, reg.get(agent_name)
async def orchestrate_v3_stream(
user_id: str,
message: str,
context: dict[str, Any],
reg: AgentRegistry | None = None,
) -> AsyncGenerator[tuple[str, str], None]:
"""v3 streaming orchestration — yields (agent_name, token) pairs.
The first yield always carries the agent_name with an empty token so that
callers (e.g. FloatingFormatter) can detect the routing domain before any text
tokens arrive.
"""
if reg is None:
reg = _default_registry
agent_name = await classify_intent(message, context, reg)
agent = reg.get(agent_name)
yield agent_name, "" # domain signal — no token yet
async for token in agent.handle_stream(message, context):
yield agent_name, token
async def orchestrate_stream( async def orchestrate_stream(
request: ChatRequest, request: ChatRequest,
reg: AgentRegistry | None = None, reg: AgentRegistry | None = None,

View File

@@ -0,0 +1,244 @@
"""Output Formatter — transforms orchestrator token streams into WS frame sequences.
HomeFormatter: produces stream_start, stream_text / stream_block, stream_end
FloatingFormatter: produces floating_domain, stream_text, stream_end
"""
from __future__ import annotations
import json
import logging
from collections.abc import AsyncGenerator
from typing import Any
from app.schemas import (
WsFloatingDomain,
WsStreamBlock,
WsStreamEnd,
WsStreamStart,
WsStreamText,
)
logger = logging.getLogger(__name__)
# Valid chart types (matching shadcn/ui Recharts wrappers in Electron)
_VALID_CHART_TYPES = {"area", "bar", "line", "pie", "radar", "radial"}
# Map agent name → floating domain
_AGENT_DOMAIN: dict[str, str] = {
"task_agent": "tasks",
"checkpoint_agent": "checkpoints",
"note_agent": "notes",
"project_agent": "projects",
}
WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatingDomain
class HomeFormatter:
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
The LLM is expected to output a newline-delimited sequence of JSON objects,
each with a ``type`` field:
- ``text`` → yields WsStreamText immediately (word-by-word)
- ``chart`` → buffers full JSON, validates, yields WsStreamBlock
- ``entity_ref`` → resolves from tool_results, yields WsStreamBlock
- ``table`` → buffers full JSON, validates, yields WsStreamBlock
- ``timeline`` → buffers full JSON, validates, yields WsStreamBlock
Invalid or unknown blocks are logged and skipped — stream never crashes.
"""
def __init__(self, request_id: str, tool_results: list[dict]) -> None:
self.request_id = request_id
self.tool_results = tool_results
async def format(
self,
token_stream: AsyncGenerator[tuple[str, str], None],
) -> AsyncGenerator[WsFrame, None]:
yield WsStreamStart(request_id=self.request_id)
buffer = ""
async for _agent_name, token in token_stream:
if not token:
continue
buffer += token
# Flush any complete JSON objects from the buffer
async for frame in self._flush_complete_objects(buffer):
buffer = "" # reset after flush
yield frame
break # only one flush per iteration; rest accumulates
# Flush any remaining content
if buffer.strip():
async for frame in self._flush_complete_objects(buffer, final=True):
yield frame
yield WsStreamEnd(request_id=self.request_id)
async def _flush_complete_objects(
self, text: str, final: bool = False
) -> AsyncGenerator[WsFrame, None]:
"""Try to parse and yield all complete JSON objects from *text*.
Yields nothing if text is incomplete JSON (unless *final* is True,
in which case remaining text is emitted as plain stream_text).
"""
remaining = text.strip()
while remaining:
# Fast path: plain text (not JSON)
if not remaining.startswith("{"):
# Yield as plain text chunk
newline_idx = remaining.find("\n")
if newline_idx == -1:
if final:
yield WsStreamText(request_id=self.request_id, chunk=remaining)
remaining = ""
else:
return # accumulate more
else:
line = remaining[:newline_idx].strip()
remaining = remaining[newline_idx + 1:].strip()
if line:
yield WsStreamText(request_id=self.request_id, chunk=line)
continue
# Try to decode a JSON object
try:
obj, end_idx = _try_parse_json(remaining)
except ValueError:
if final:
# Emit as raw text if we can't parse
yield WsStreamText(request_id=self.request_id, chunk=remaining)
remaining = ""
return
if obj is None:
if final:
yield WsStreamText(request_id=self.request_id, chunk=remaining)
remaining = ""
return # incomplete — need more tokens
remaining = remaining[end_idx:].strip()
block_type = obj.get("type")
frame = self._dispatch_block(obj, block_type)
if frame is not None:
yield frame
def _dispatch_block(self, obj: dict, block_type: str | None) -> WsFrame | None:
if block_type == "text":
content = obj.get("content", "")
if content:
return WsStreamText(request_id=self.request_id, chunk=str(content))
return None
if block_type == "chart":
chart_type = obj.get("chartType")
if chart_type not in _VALID_CHART_TYPES:
logger.warning("HomeFormatter: invalid chartType=%r — skipping", chart_type)
return None
if not isinstance(obj.get("data"), list):
logger.warning("HomeFormatter: chart missing data array — skipping")
return None
return WsStreamBlock(
request_id=self.request_id,
block_type="chart",
data=obj,
)
if block_type == "entity_ref":
entity = obj.get("entity")
resolved = self._resolve_entity(entity)
if resolved is None:
logger.warning("HomeFormatter: entity_ref %r not found in tool_results — skipping", entity)
return None
return WsStreamBlock(
request_id=self.request_id,
block_type="entity_ref",
data={"entity": entity, "items": resolved},
)
if block_type == "table":
if not isinstance(obj.get("headers"), list) or not isinstance(obj.get("rows"), list):
logger.warning("HomeFormatter: table missing headers/rows — skipping")
return None
return WsStreamBlock(
request_id=self.request_id,
block_type="table",
data=obj,
)
if block_type == "timeline":
if not isinstance(obj.get("checkpoints"), list):
logger.warning("HomeFormatter: timeline missing checkpoints — skipping")
return None
return WsStreamBlock(
request_id=self.request_id,
block_type="timeline",
data=obj,
)
logger.warning("HomeFormatter: unknown block type=%r — skipping", block_type)
return None
def _resolve_entity(self, entity: str | None) -> list[dict] | None:
"""Find matching items in tool_results by entity type."""
if not entity:
return None
matches = [r for r in self.tool_results if r.get("entity") == entity]
return matches if matches else None
class FloatingFormatter:
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
Emits floating_domain immediately (from agent_name), then streams all tokens
as plain stream_text — no block parsing for floating context.
"""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
async def format(
self,
token_stream: AsyncGenerator[tuple[str, str], None],
) -> AsyncGenerator[WsFrame, None]:
domain_sent = False
async for agent_name, token in token_stream:
if not domain_sent:
domain = _AGENT_DOMAIN.get(agent_name, "tasks")
yield WsFloatingDomain(
request_id=self.request_id,
domain=domain, # type: ignore[arg-type]
)
yield WsStreamStart(request_id=self.request_id)
domain_sent = True
if token:
yield WsStreamText(request_id=self.request_id, chunk=token)
yield WsStreamEnd(request_id=self.request_id)
# ── helpers ───────────────────────────────────────────────────────────────────
def _try_parse_json(text: str) -> tuple[dict[str, Any] | None, int]:
"""Attempt to parse the first complete JSON object from *text*.
Returns ``(parsed_dict, end_index)`` on success, ``(None, 0)`` when the
object is incomplete, and raises ``ValueError`` when text is not JSON.
"""
decoder = json.JSONDecoder()
try:
obj, end_idx = decoder.raw_decode(text)
if not isinstance(obj, dict):
raise ValueError("Expected JSON object")
return obj, end_idx
except json.JSONDecodeError as exc:
# Incomplete JSON — need more tokens
if "Unterminated" in str(exc) or exc.pos == len(text):
return None, 0
raise ValueError(str(exc)) from exc

View File

@@ -17,6 +17,22 @@ _client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = Cont
"_client_executor" "_client_executor"
) )
# Optional collector that captures raw execute_on_client results.
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
"_tool_result_collector", default=None
)
def set_tool_result_collector(lst: list[dict]) -> None:
"""Register *lst* as the collector for this async context."""
_tool_result_collector.set(lst)
def clear_tool_result_collector() -> None:
"""Clear the collector (best-effort)."""
_tool_result_collector.set(None)
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None: def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
"""Bind *fn* as the executor for the current async context (task/coroutine).""" """Bind *fn* as the executor for the current async context (task/coroutine)."""
@@ -65,4 +81,8 @@ async def execute_on_client(
if limit is not None: if limit is not None:
payload["limit"] = limit payload["limit"] = limit
return await callback(payload) result = await callback(payload)
collector = _tool_result_collector.get(None)
if collector is not None:
collector.append(result)
return result

View File

@@ -14,6 +14,10 @@ Table inventory:
plugin_installations — per-user install records plugin_installations — per-user install records
plugin_reviews — admin review decisions plugin_reviews — admin review decisions
revenue_events — Stripe Connect 70/30 split ledger revenue_events — Stripe Connect 70/30 split ledger
memory_core — per-user persistent key/value preferences (encrypted)
memory_associative — per-user semantic memory with embeddings (encrypted)
memory_episodic — per-user session summaries (encrypted)
memory_proactive — per-user behavioral patterns (encrypted)
""" """
from __future__ import annotations from __future__ import annotations
@@ -74,6 +78,9 @@ class User(Base):
password_hash: Mapped[str] = mapped_column(String(255), nullable=False) password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free") tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True) stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
# Used to encrypt/decrypt all memory rows for this user.
encryption_key: Mapped[str | None] = mapped_column(String(64), nullable=True)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now() DateTime(timezone=True), nullable=False, server_default=func.now()
) )
@@ -375,3 +382,93 @@ class AgentRunLog(Base):
foreign_keys="AgentRunLog.agent_id", foreign_keys="AgentRunLog.agent_id",
overlaps="run_logs,local_agent", overlaps="run_logs,local_agent",
) )
# ── Memory models ─────────────────────────────────────────────────────────────
class MemoryCore(Base):
"""Per-user persistent key/value preferences, encrypted at rest.
Examples: preferred_language, timezone, work_style.
Decrypted in-memory only using User.encryption_key.
"""
__tablename__ = "memory_core"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
key: Mapped[str] = mapped_column(String(255), nullable=False)
value_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
class MemoryAssociative(Base):
"""Per-user semantic memory: encrypted content + pgvector embedding for similarity search.
Production: ``embedding`` column is ``vector(1536)`` via pgvector.
Tests (SQLite): stored as JSON list.
"""
__tablename__ = "memory_associative"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
# JSON-encoded float list in SQLite tests; vector(1536) in Postgres via migration.
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
class MemoryEpisodic(Base):
"""Per-user session summaries, encrypted at rest.
One row per session interaction; used to recall recent conversations.
"""
__tablename__ = "memory_episodic"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
summary_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
session_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
class MemoryProactive(Base):
"""Per-user inferred behavioral patterns, encrypted at rest.
Confidence in [0.0, 1.0]; only patterns above threshold are injected.
Source: 'inferred' (from episodes) or 'explicit' (user-stated).
"""
__tablename__ = "memory_proactive"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
pattern_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.5)
source: Mapped[str] = mapped_column(String(50), nullable=False, default="inferred")
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)

View File

@@ -161,6 +161,7 @@ class PluginInstallRequest(BaseModel):
# ── WebSocket Frame Protocol ────────────────────────────────────────── # ── WebSocket Frame Protocol ──────────────────────────────────────────
class WsFrameType(str, Enum): class WsFrameType(str, Enum):
# ── v2 frame types (kept for backward compat) ──────────────────────
chat_request = "chat_request" chat_request = "chat_request"
text_chunk = "text_chunk" text_chunk = "text_chunk"
tool_call = "tool_call" tool_call = "tool_call"
@@ -171,6 +172,17 @@ class WsFrameType(str, Enum):
agent_data = "agent_data" agent_data = "agent_data"
agent_complete = "agent_complete" agent_complete = "agent_complete"
device_hello = "device_hello" device_hello = "device_hello"
# ── v3 frame types ─────────────────────────────────────────────────
home_request = "home_request"
floating_request = "floating_request"
stream_start = "stream_start"
stream_text = "stream_text"
stream_block = "stream_block"
stream_end = "stream_end"
floating_domain = "floating_domain"
data_request = "data_request"
data_response = "data_response"
mutation = "mutation"
class WsToolCall(BaseModel): class WsToolCall(BaseModel):
@@ -249,6 +261,71 @@ class WsAgentComplete(BaseModel):
errors: list[str] = Field(default_factory=list) errors: list[str] = Field(default_factory=list)
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
class WsFloatingScope(BaseModel):
"""Scope for a floating request — narrows the agent to a specific entity."""
type: Literal["task", "project", "note", "checkpoint"]
id: str | None = None
class WsHomeRequest(BaseModel):
"""Client → Server: Home chat message."""
type: Literal[WsFrameType.home_request] = WsFrameType.home_request
message: str
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
class WsFloatingRequest(BaseModel):
"""Client → Server: Floating chat message scoped to an entity."""
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
message: str
scope: WsFloatingScope
class WsStreamStart(BaseModel):
"""Server → Client: signals start of a streaming response."""
type: Literal[WsFrameType.stream_start] = WsFrameType.stream_start
request_id: str
class WsStreamText(BaseModel):
"""Server → Client: streamed text token."""
type: Literal[WsFrameType.stream_text] = WsFrameType.stream_text
request_id: str
chunk: str
class WsStreamBlock(BaseModel):
"""Server → Client: structured block (chart, table, entity, timeline)."""
type: Literal[WsFrameType.stream_block] = WsFrameType.stream_block
request_id: str
block_type: Literal["chart", "entity_ref", "table", "timeline"]
data: dict[str, Any]
class WsStreamEnd(BaseModel):
"""Server → Client: signals end of a streaming response."""
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
request_id: str
mutations: list[dict[str, Any]] = Field(default_factory=list)
class WsFloatingDomain(BaseModel):
"""Server → Client: domain determined for a floating request."""
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
request_id: str
domain: Literal["tasks", "checkpoints", "notes", "projects"]
# ── Agent Catalog ───────────────────────────────────────────────────── # ── Agent Catalog ─────────────────────────────────────────────────────
class AgentCatalogItem(BaseModel): class AgentCatalogItem(BaseModel):

View File

@@ -0,0 +1,416 @@
"""Tests for ChatAgent streaming and tool result capture (Step 2)."""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from app.core.agent_registry import ChatAgent, registry
# ── Minimal concrete agent for testing ───────────────────────────────
class _EchoAgent(ChatAgent):
def get_name(self) -> str:
return "_echo"
def get_description(self) -> str:
return "Echo agent for tests"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return query
# ── Helpers ───────────────────────────────────────────────────────────
def _make_ai_message(content: str = "", tool_calls: list | None = None) -> AIMessage:
msg = AIMessage(content=content)
if tool_calls:
msg.tool_calls = tool_calls
else:
msg.tool_calls = []
return msg
def _make_tool(name: str, return_value: Any) -> MagicMock:
t = MagicMock()
t.name = name
t.ainvoke = AsyncMock(return_value=return_value)
return t
def _make_stream_chunks(tokens: list[str]) -> list[MagicMock]:
chunks = []
for tok in tokens:
c = MagicMock()
c.content = tok
chunks.append(c)
return chunks
async def _collect_stream(agent: ChatAgent, llm: Any, messages: list, tools: list) -> list[str]:
tokens: list[str] = []
async for tok in agent._tool_loop_stream(llm, messages, tools):
tokens.append(tok)
return tokens
# ── tool_results initialised ─────────────────────────────────────────
def test_tool_results_init():
agent = _EchoAgent()
assert agent.tool_results == []
# ── _tool_loop: no tool calls ────────────────────────────────────────
@pytest.mark.asyncio
async def test_tool_loop_no_tools():
agent = _EchoAgent()
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=_make_ai_message("Hello!"))
result = await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
assert result == "Hello!"
assert agent.tool_results == []
# ── _tool_loop: with one tool call + result capture ──────────────────
@pytest.mark.asyncio
async def test_tool_loop_captures_tool_results():
agent = _EchoAgent()
# Mock execute_on_client to return structured data via the tool
raw_result = {"rows": [{"id": "t-1", "title": "Fix bug", "status": "todo"}]}
async def fake_executor(payload: dict) -> dict:
return raw_result
# AIMessage with a tool call, then a final answer
tool_call_msg = _make_ai_message(
tool_calls=[{"name": "list_tasks", "args": {}, "id": "call-1", "type": "tool_call"}]
)
final_msg = _make_ai_message("Here are your tasks.")
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
llm.ainvoke = AsyncMock(return_value=final_msg)
mock_tool = _make_tool("list_tasks", "- Fix bug (todo)")
from app.core.ws_context import set_client_executor, clear_client_executor
set_client_executor(fake_executor)
try:
# Patch the tool to actually call execute_on_client
async def tool_side_effect(args: dict) -> str:
from app.core.ws_context import execute_on_client
res = await execute_on_client(action="select", table="tasks")
rows = res.get("rows", [])
return "\n".join(r["title"] for r in rows)
mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
result = await agent._tool_loop(
llm, [HumanMessage(content="list my tasks")], [mock_tool]
)
finally:
clear_client_executor()
assert result == "Here are your tasks."
assert len(agent.tool_results) == 1
assert agent.tool_results[0] == raw_result
# ── _tool_loop: tool_results reset on each call ──────────────────────
@pytest.mark.asyncio
async def test_tool_loop_resets_tool_results():
agent = _EchoAgent()
agent.tool_results = [{"stale": True}] # pre-populated from a previous call
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=_make_ai_message("Done."))
await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
assert agent.tool_results == []
# ── _tool_loop: unknown tool name ────────────────────────────────────
@pytest.mark.asyncio
async def test_tool_loop_unknown_tool():
agent = _EchoAgent()
# No known tools — model still calls a non-existent one; loop handles gracefully
tool_call_msg = _make_ai_message(
tool_calls=[{"name": "nonexistent", "args": {}, "id": "c1", "type": "tool_call"}]
)
final_msg = _make_ai_message("Handled.")
mock_tool = _make_tool("known", "ok") # a different tool, not "nonexistent"
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool])
assert result == "Handled."
# ── _tool_loop: max_iter exhaustion ──────────────────────────────────
@pytest.mark.asyncio
async def test_tool_loop_max_iter():
agent = _EchoAgent()
always_tool = _make_ai_message(
tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
)
fallback = _make_ai_message("Fallback.")
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
# Returns tool_call_msg on every iteration
llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
llm.ainvoke = AsyncMock(return_value=fallback)
mock_tool = _make_tool("t", "ok")
result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool], max_iter=2)
assert result == "Fallback."
assert llm_with_tools.ainvoke.call_count == 2
# ── _tool_loop_stream: no tool calls — yields tokens ─────────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_no_tools_yields_tokens():
agent = _EchoAgent()
# No tools → llm used directly; ainvoke returns no tool calls → stream is used
no_tool_msg = _make_ai_message("irrelevant")
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
async def fake_astream(msgs):
for tok in ["Hello", " ", "world"]:
c = MagicMock()
c.content = tok
yield c
llm.astream = fake_astream
tokens = await _collect_stream(agent, llm, [HumanMessage(content="hi")], [])
assert tokens == ["Hello", " ", "world"]
assert agent.tool_results == []
# ── _tool_loop_stream: one tool call then streaming final ─────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_with_tool_call():
agent = _EchoAgent()
raw_result = {"row": {"id": "t-2", "title": "Deploy", "status": "in_progress"}}
async def fake_executor(payload: dict) -> dict:
return raw_result
tool_call_msg = _make_ai_message(
tool_calls=[{"name": "get_task", "args": {"id": "t-2"}, "id": "c1", "type": "tool_call"}]
)
# After tools run, ainvoke returns no more tool calls
no_more_tools_msg = _make_ai_message("Task found.")
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
async def fake_astream(msgs):
for tok in ["Task", " ", "found."]:
c = MagicMock()
c.content = tok
yield c
llm.astream = fake_astream
async def tool_side_effect(args: dict) -> str:
from app.core.ws_context import execute_on_client
res = await execute_on_client(action="select", table="tasks", filters={"id": args.get("id")})
return res.get("row", {}).get("title", "")
mock_tool = _make_tool("get_task", "Deploy")
mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
from app.core.ws_context import set_client_executor, clear_client_executor
set_client_executor(fake_executor)
try:
tokens = await _collect_stream(
agent, llm, [HumanMessage(content="get task t-2")], [mock_tool]
)
finally:
clear_client_executor()
assert tokens == ["Task", " ", "found."]
assert len(agent.tool_results) == 1
assert agent.tool_results[0] == raw_result
# ── _tool_loop_stream: tool_results reset on each call ───────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_resets_tool_results():
agent = _EchoAgent()
agent.tool_results = [{"old": True}]
no_tool_msg = _make_ai_message("")
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
async def fake_astream(msgs):
c = MagicMock()
c.content = "ok"
yield c
llm.astream = fake_astream
await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
assert agent.tool_results == []
# ── _tool_loop_stream: empty chunk content is skipped ────────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_skips_empty_chunks():
agent = _EchoAgent()
no_tool_msg = _make_ai_message("")
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
async def fake_astream(msgs):
for tok in ["", "hello", "", " world", ""]:
c = MagicMock()
c.content = tok
yield c
llm.astream = fake_astream
tokens = await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
assert tokens == ["hello", " world"]
# ── _tool_loop_stream: max_iter exhaustion falls back to stream ───────
@pytest.mark.asyncio
async def test_tool_loop_stream_max_iter():
agent = _EchoAgent()
always_tool = _make_ai_message(
tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
)
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
async def fake_astream(msgs):
c = MagicMock()
c.content = "fallback"
yield c
llm.astream = fake_astream
mock_tool = _make_tool("t", "ok")
tokens = await _collect_stream(
agent, llm, [HumanMessage(content="x")], [mock_tool],
)
assert tokens == ["fallback"]
assert llm_with_tools.ainvoke.call_count == 5 # exhausted default max_iter
# ── _tool_loop_stream: multiple tool results captured ────────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_multiple_tool_results():
agent = _EchoAgent()
call_results = [
{"rows": [{"id": "t-1"}]},
{"rows": [{"id": "t-2"}]},
]
call_iter = iter(call_results)
async def fake_executor(payload: dict) -> dict:
return next(call_iter)
# Two tool calls in one iteration
tool_call_msg = _make_ai_message(
tool_calls=[
{"name": "tool_a", "args": {}, "id": "c1", "type": "tool_call"},
{"name": "tool_b", "args": {}, "id": "c2", "type": "tool_call"},
]
)
no_more_tools_msg = _make_ai_message("Done.")
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
async def fake_astream(msgs):
c = MagicMock()
c.content = "Done."
yield c
llm.astream = fake_astream
async def tool_side_effect(args: dict) -> str:
from app.core.ws_context import execute_on_client
res = await execute_on_client(action="select", table="tasks")
return str(res)
tool_a = _make_tool("tool_a", "")
tool_a.ainvoke = AsyncMock(side_effect=tool_side_effect)
tool_b = _make_tool("tool_b", "")
tool_b.ainvoke = AsyncMock(side_effect=tool_side_effect)
from app.core.ws_context import set_client_executor, clear_client_executor
set_client_executor(fake_executor)
try:
tokens = await _collect_stream(
agent, llm, [HumanMessage(content="x")], [tool_a, tool_b]
)
finally:
clear_client_executor()
assert tokens == ["Done."]
assert len(agent.tool_results) == 2
assert agent.tool_results[0] == {"rows": [{"id": "t-1"}]}
assert agent.tool_results[1] == {"rows": [{"id": "t-2"}]}

View File

@@ -14,6 +14,56 @@ from app.agents.note_agent import NoteAgent
from app.agents.project_agent import ProjectAgent from app.agents.project_agent import ProjectAgent
from app.agents.task_agent import TaskAgent from app.agents.task_agent import TaskAgent
from app.core.agent_registry import registry from app.core.agent_registry import registry
from app.core.ws_context import clear_client_executor, set_client_executor
# ── WS executor mock ──────────────────────────────────────────────────
#
# Tools call execute_on_client() which reads a ContextVar set by the WS
# handler. In unit tests there is no WS session, so we install a fake
# executor that returns plausible data for each action type.
_FAKE_ROW: dict[str, Any] = {
"id": "fake-id",
"title": "Fake Title",
"name": "Fake Name",
"status": "todo",
"priority": "medium",
"content": "Fake content",
"date": 1700000000000,
"taskId": "fake-task-id",
"author": "Alice",
"projectId": None,
}
async def _fake_executor(payload: dict) -> dict:
action = payload.get("action", "")
if action == "select":
return {"rows": []}
if action == "insert":
data = payload.get("data", {})
return {"row": {**_FAKE_ROW, **data}}
if action == "update":
data = payload.get("data", {})
row = {**_FAKE_ROW, "id": data.get("id", "fake-id"), **data.get("updates", {})}
return {"row": row}
if action == "delete":
return {"deleted": True}
if action == "get":
data = payload.get("data", {})
return {"row": {**_FAKE_ROW, "id": data.get("id", "fake-id")}}
if action == "vector_upsert":
return {"ok": True}
return {}
@pytest.fixture(autouse=True)
def ws_executor():
"""Install a fake WS executor for every test so tools can run without a real WS."""
set_client_executor(_fake_executor)
yield
clear_client_executor()
# ── Helpers ────────────────────────────────────────────────────────── # ── Helpers ──────────────────────────────────────────────────────────
@@ -148,110 +198,142 @@ class TestTaskAgentTools:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_tasks_defaults(self) -> None: async def test_list_tasks_defaults(self) -> None:
from app.agents.task_agent import list_tasks from app.agents.task_agent import list_tasks
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_tasks.ainvoke({}) result = await list_tasks.ainvoke({})
data = json.loads(result) m.assert_called_once_with(
assert data["action"] == "list" action="select", table="tasks",
assert data["table"] == "tasks" filters={"projectId": None, "status": None, "search": None, "orderBy": None},
)
assert result == "No tasks found matching the given filters."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_tasks_with_status_filter(self) -> None: async def test_list_tasks_with_status_filter(self) -> None:
from app.agents.task_agent import list_tasks from app.agents.task_agent import list_tasks
result = await list_tasks.ainvoke({"status": "done"}) with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
data = json.loads(result) m.return_value = {"rows": []}
assert data["filters"]["status"] == "done" await list_tasks.ainvoke({"status": "done"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["filters"]["status"] == "done"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_task_defaults(self) -> None: async def test_create_task_defaults(self) -> None:
from app.agents.task_agent import create_task from app.agents.task_agent import create_task
fake_row = {"id": "t1", "title": "Test task", "status": "todo", "priority": "medium"}
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await create_task.ainvoke({"title": "Test task"}) result = await create_task.ainvoke({"title": "Test task"})
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "create_record" assert call_kwargs["action"] == "insert"
assert data["table"] == "tasks" assert call_kwargs["table"] == "tasks"
assert data["data"]["title"] == "Test task" assert call_kwargs["data"]["title"] == "Test task"
assert data["data"]["status"] == "todo" assert call_kwargs["data"]["status"] == "todo"
assert data["data"]["priority"] == "medium" assert call_kwargs["data"]["priority"] == "medium"
assert "Test task" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_task_with_all_fields(self) -> None: async def test_create_task_with_all_fields(self) -> None:
from app.agents.task_agent import create_task from app.agents.task_agent import create_task
result = await create_task.ainvoke({ fake_row = {"id": "t1", "title": "Deploy", "status": "in_progress", "priority": "high"}
"title": "Deploy", with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
"priority": "high", m.return_value = {"row": fake_row}
"status": "in_progress", await create_task.ainvoke({
"project_id": "p1", "title": "Deploy", "priority": "high", "status": "in_progress",
"is_ai_suggested": 1, "project_id": "p1", "is_ai_suggested": 1,
}) })
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["data"]["priority"] == "high" assert call_kwargs["data"]["priority"] == "high"
assert data["data"]["status"] == "in_progress" assert call_kwargs["data"]["status"] == "in_progress"
assert data["data"]["projectId"] == "p1" assert call_kwargs["data"]["projectId"] == "p1"
assert data["data"]["isAiSuggested"] == 1 assert call_kwargs["data"]["isAiSuggested"] == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_task_with_status(self) -> None: async def test_update_task_with_status(self) -> None:
from app.agents.task_agent import update_task from app.agents.task_agent import update_task
fake_row = {"id": "t1", "title": "Buy groceries", "status": "done"}
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await update_task.ainvoke({"task_id": "t1", "status": "done"}) result = await update_task.ainvoke({"task_id": "t1", "status": "done"})
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "update_record" assert call_kwargs["action"] == "update"
assert data["data"]["id"] == "t1" assert call_kwargs["data"]["id"] == "t1"
assert data["data"]["updates"]["status"] == "done" assert call_kwargs["data"]["updates"]["status"] == "done"
assert "t1" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_task_empty_updates(self) -> None: async def test_update_task_empty_updates(self) -> None:
from app.agents.task_agent import update_task from app.agents.task_agent import update_task
result = await update_task.ainvoke({"task_id": "t1"}) fake_row = {"id": "t1", "title": "Task", "status": "todo"}
data = json.loads(result) with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
assert data["data"]["updates"] == {} m.return_value = {"row": fake_row}
await update_task.ainvoke({"task_id": "t1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["data"]["updates"] == {}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_task(self) -> None: async def test_delete_task(self) -> None:
from app.agents.task_agent import delete_task from app.agents.task_agent import delete_task
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"deleted": True}
result = await delete_task.ainvoke({"task_id": "t1"}) result = await delete_task.ainvoke({"task_id": "t1"})
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "delete_record" assert call_kwargs["action"] == "delete"
assert data["table"] == "tasks" assert call_kwargs["table"] == "tasks"
assert data["data"]["id"] == "t1" assert call_kwargs["data"]["id"] == "t1"
assert "t1" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_tasks_due_today(self) -> None: async def test_list_tasks_due_today(self) -> None:
from app.agents.task_agent import list_tasks_due_today from app.agents.task_agent import list_tasks_due_today
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_tasks_due_today.ainvoke({}) result = await list_tasks_due_today.ainvoke({})
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "list_due_today" assert call_kwargs["action"] == "select"
assert data["table"] == "tasks" assert call_kwargs["table"] == "tasks"
assert "dueDateFrom" in call_kwargs["filters"]
assert result == "No tasks are due today."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_task_comments(self) -> None: async def test_list_task_comments(self) -> None:
from app.agents.task_agent import list_task_comments from app.agents.task_agent import list_task_comments
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_task_comments.ainvoke({"task_id": "t1"}) result = await list_task_comments.ainvoke({"task_id": "t1"})
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "list" assert call_kwargs["action"] == "select"
assert data["table"] == "taskComments" assert call_kwargs["table"] == "taskComments"
assert data["filters"]["taskId"] == "t1" assert call_kwargs["filters"]["taskId"] == "t1"
assert "t1" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_task_comment(self) -> None: async def test_add_task_comment(self) -> None:
from app.agents.task_agent import add_task_comment from app.agents.task_agent import add_task_comment
fake_row = {"id": "c1", "taskId": "t1", "author": "Alice", "content": "Looks good!"}
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await add_task_comment.ainvoke({ result = await add_task_comment.ainvoke({
"task_id": "t1", "task_id": "t1", "author": "Alice", "content": "Looks good!",
"author": "Alice",
"content": "Looks good!",
}) })
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "create_record" assert call_kwargs["action"] == "insert"
assert data["table"] == "taskComments" assert call_kwargs["table"] == "taskComments"
assert data["data"]["taskId"] == "t1" assert call_kwargs["data"]["taskId"] == "t1"
assert data["data"]["author"] == "Alice" assert call_kwargs["data"]["author"] == "Alice"
assert data["data"]["content"] == "Looks good!" assert call_kwargs["data"]["content"] == "Looks good!"
assert "Alice" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_task_comment(self) -> None: async def test_delete_task_comment(self) -> None:
from app.agents.task_agent import delete_task_comment from app.agents.task_agent import delete_task_comment
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"deleted": True}
result = await delete_task_comment.ainvoke({"comment_id": "c1"}) result = await delete_task_comment.ainvoke({"comment_id": "c1"})
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "delete_record" assert call_kwargs["action"] == "delete"
assert data["table"] == "taskComments" assert call_kwargs["table"] == "taskComments"
assert data["data"]["id"] == "c1" assert call_kwargs["data"]["id"] == "c1"
assert "c1" in result
# ── CheckpointAgent ─────────────────────────────────────────────────── # ── CheckpointAgent ───────────────────────────────────────────────────
@@ -301,74 +383,86 @@ class TestCheckpointAgentTools:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_checkpoints_no_project(self) -> None: async def test_list_checkpoints_no_project(self) -> None:
from app.agents.checkpoint_agent import list_checkpoints from app.agents.checkpoint_agent import list_checkpoints
with patch("app.agents.checkpoint_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_checkpoints.ainvoke({}) result = await list_checkpoints.ainvoke({})
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "list" assert call_kwargs["action"] == "select"
assert data["table"] == "checkpoints" assert call_kwargs["table"] == "checkpoints"
assert data["filters"]["projectId"] is None assert call_kwargs["filters"]["projectId"] is None
assert result == "No checkpoints found."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_checkpoints_with_project(self) -> None: async def test_list_checkpoints_with_project(self) -> None:
from app.agents.checkpoint_agent import list_checkpoints from app.agents.checkpoint_agent import list_checkpoints
result = await list_checkpoints.ainvoke({"project_id": "p1"}) with patch("app.agents.checkpoint_agent.execute_on_client", new_callable=AsyncMock) as m:
data = json.loads(result) m.return_value = {"rows": []}
assert data["filters"]["projectId"] == "p1" await list_checkpoints.ainvoke({"project_id": "p1"})
assert m.call_args.kwargs["filters"]["projectId"] == "p1"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_checkpoint(self) -> None: async def test_create_checkpoint(self) -> None:
from app.agents.checkpoint_agent import create_checkpoint from app.agents.checkpoint_agent import create_checkpoint
fake_row = {"id": "cp1", "title": "Beta release", "date": 1700000000000}
with patch("app.agents.checkpoint_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await create_checkpoint.ainvoke({ result = await create_checkpoint.ainvoke({
"project_id": "p1", "project_id": "p1", "title": "Beta release", "date": 1700000000000,
"title": "Beta release",
"date": 1700000000000,
}) })
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "create_record" assert call_kwargs["action"] == "insert"
assert data["table"] == "checkpoints" assert call_kwargs["table"] == "checkpoints"
assert data["data"]["projectId"] == "p1" assert call_kwargs["data"]["projectId"] == "p1"
assert data["data"]["title"] == "Beta release" assert call_kwargs["data"]["title"] == "Beta release"
assert data["data"]["date"] == 1700000000000 assert call_kwargs["data"]["date"] == 1700000000000
assert "Beta release" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_checkpoint_ai_suggested(self) -> None: async def test_create_checkpoint_ai_suggested(self) -> None:
from app.agents.checkpoint_agent import create_checkpoint from app.agents.checkpoint_agent import create_checkpoint
result = await create_checkpoint.ainvoke({ fake_row = {"id": "cp1", "title": "Review", "date": 1700000000000}
"project_id": "p1", with patch("app.agents.checkpoint_agent.execute_on_client", new_callable=AsyncMock) as m:
"title": "Review", m.return_value = {"row": fake_row}
"date": 1700000000000, await create_checkpoint.ainvoke({
"is_ai_suggested": 1, "project_id": "p1", "title": "Review", "date": 1700000000000, "is_ai_suggested": 1,
}) })
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["data"]["isAiSuggested"] == 1 assert call_kwargs["data"]["isAiSuggested"] == 1
assert data["data"]["isApproved"] == 0 assert call_kwargs["data"]["isApproved"] == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_checkpoint_approve(self) -> None: async def test_update_checkpoint_approve(self) -> None:
from app.agents.checkpoint_agent import update_checkpoint from app.agents.checkpoint_agent import update_checkpoint
result = await update_checkpoint.ainvoke({ fake_row = {"id": "c1", "title": "MVP", "isApproved": 1}
"checkpoint_id": "c1", with patch("app.agents.checkpoint_agent.execute_on_client", new_callable=AsyncMock) as m:
"is_approved": 1, m.return_value = {"row": fake_row}
}) result = await update_checkpoint.ainvoke({"checkpoint_id": "c1", "is_approved": 1})
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "update_record" assert call_kwargs["action"] == "update"
assert data["data"]["id"] == "c1" assert call_kwargs["data"]["id"] == "c1"
assert data["data"]["updates"]["isApproved"] == 1 assert call_kwargs["data"]["updates"]["isApproved"] == 1
assert "c1" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_checkpoint_empty_updates(self) -> None: async def test_update_checkpoint_empty_updates(self) -> None:
from app.agents.checkpoint_agent import update_checkpoint from app.agents.checkpoint_agent import update_checkpoint
result = await update_checkpoint.ainvoke({"checkpoint_id": "c1"}) fake_row = {"id": "c1", "title": "MVP"}
data = json.loads(result) with patch("app.agents.checkpoint_agent.execute_on_client", new_callable=AsyncMock) as m:
assert data["data"]["updates"] == {} m.return_value = {"row": fake_row}
await update_checkpoint.ainvoke({"checkpoint_id": "c1"})
assert m.call_args.kwargs["data"]["updates"] == {}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_checkpoint(self) -> None: async def test_delete_checkpoint(self) -> None:
from app.agents.checkpoint_agent import delete_checkpoint from app.agents.checkpoint_agent import delete_checkpoint
with patch("app.agents.checkpoint_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"deleted": True}
result = await delete_checkpoint.ainvoke({"checkpoint_id": "c1"}) result = await delete_checkpoint.ainvoke({"checkpoint_id": "c1"})
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "delete_record" assert call_kwargs["action"] == "delete"
assert data["table"] == "checkpoints" assert call_kwargs["table"] == "checkpoints"
assert data["data"]["id"] == "c1" assert call_kwargs["data"]["id"] == "c1"
assert "c1" in result
# ── ProjectAgent ────────────────────────────────────────────────────── # ── ProjectAgent ──────────────────────────────────────────────────────
@@ -425,75 +519,101 @@ class TestProjectAgentTools:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_projects_defaults(self) -> None: async def test_list_projects_defaults(self) -> None:
from app.agents.project_agent import list_projects from app.agents.project_agent import list_projects
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_projects.ainvoke({}) result = await list_projects.ainvoke({})
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "list" assert call_kwargs["action"] == "select"
assert data["table"] == "projects" assert call_kwargs["table"] == "projects"
assert data["filters"]["includeArchived"] is False assert call_kwargs["filters"]["includeArchived"] is False
assert result == "No projects found."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_projects_include_archived(self) -> None: async def test_list_projects_include_archived(self) -> None:
from app.agents.project_agent import list_projects from app.agents.project_agent import list_projects
result = await list_projects.ainvoke({"include_archived": 1}) with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
data = json.loads(result) m.return_value = {"rows": []}
assert data["filters"]["includeArchived"] is True await list_projects.ainvoke({"include_archived": 1})
assert m.call_args.kwargs["filters"]["includeArchived"] is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_all_projects(self) -> None: async def test_list_all_projects(self) -> None:
from app.agents.project_agent import list_all_projects from app.agents.project_agent import list_all_projects
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_all_projects.ainvoke({}) result = await list_all_projects.ainvoke({})
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "list_all" assert call_kwargs["action"] == "select"
assert data["table"] == "projects" assert call_kwargs["table"] == "projects"
assert result == "No projects found."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_project(self) -> None: async def test_get_project(self) -> None:
from app.agents.project_agent import get_project from app.agents.project_agent import get_project
fake_row = {"id": "p1", "name": "Alpha", "status": "active", "clientId": None}
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await get_project.ainvoke({"project_id": "p1"}) result = await get_project.ainvoke({"project_id": "p1"})
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "get" assert call_kwargs["action"] == "get"
assert data["table"] == "projects" assert call_kwargs["table"] == "projects"
assert data["data"]["id"] == "p1" assert call_kwargs["data"]["id"] == "p1"
assert "Alpha" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_project_name_only(self) -> None: async def test_create_project_name_only(self) -> None:
from app.agents.project_agent import create_project from app.agents.project_agent import create_project
fake_row = {"id": "p1", "name": "Alpha"}
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await create_project.ainvoke({"name": "Alpha"}) result = await create_project.ainvoke({"name": "Alpha"})
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "create_record" assert call_kwargs["action"] == "insert"
assert data["data"]["name"] == "Alpha" assert call_kwargs["data"]["name"] == "Alpha"
assert data["data"]["clientId"] is None assert call_kwargs["data"]["clientId"] is None
assert "Alpha" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_project_with_client(self) -> None: async def test_create_project_with_client(self) -> None:
from app.agents.project_agent import create_project from app.agents.project_agent import create_project
result = await create_project.ainvoke({"name": "Beta", "client_id": "cl1"}) fake_row = {"id": "p1", "name": "Beta"}
data = json.loads(result) with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
assert data["data"]["clientId"] == "cl1" m.return_value = {"row": fake_row}
await create_project.ainvoke({"name": "Beta", "client_id": "cl1"})
assert m.call_args.kwargs["data"]["clientId"] == "cl1"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_project_archive(self) -> None: async def test_update_project_archive(self) -> None:
from app.agents.project_agent import update_project from app.agents.project_agent import update_project
fake_row = {"id": "p1", "name": "Alpha", "status": "archived"}
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await update_project.ainvoke({"project_id": "p1", "status": "archived"}) result = await update_project.ainvoke({"project_id": "p1", "status": "archived"})
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "update_record" assert call_kwargs["action"] == "update"
assert data["data"]["id"] == "p1" assert call_kwargs["data"]["id"] == "p1"
assert data["data"]["updates"]["status"] == "archived" assert call_kwargs["data"]["updates"]["status"] == "archived"
assert "p1" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_project_empty_updates(self) -> None: async def test_update_project_empty_updates(self) -> None:
from app.agents.project_agent import update_project from app.agents.project_agent import update_project
result = await update_project.ainvoke({"project_id": "p1"}) fake_row = {"id": "p1", "name": "Alpha", "status": "active"}
data = json.loads(result) with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
assert data["data"]["updates"] == {} m.return_value = {"row": fake_row}
await update_project.ainvoke({"project_id": "p1"})
assert m.call_args.kwargs["data"]["updates"] == {}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_project(self) -> None: async def test_delete_project(self) -> None:
from app.agents.project_agent import delete_project from app.agents.project_agent import delete_project
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"deleted": True}
result = await delete_project.ainvoke({"project_id": "p1"}) result = await delete_project.ainvoke({"project_id": "p1"})
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "delete_record" assert call_kwargs["action"] == "delete"
assert data["data"]["id"] == "p1" assert call_kwargs["data"]["id"] == "p1"
assert "p1" in result
# ── NoteAgent ───────────────────────────────────────────────────────── # ── NoteAgent ─────────────────────────────────────────────────────────
@@ -543,78 +663,99 @@ class TestNoteAgentTools:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_notes_no_project(self) -> None: async def test_list_notes_no_project(self) -> None:
from app.agents.note_agent import list_notes from app.agents.note_agent import list_notes
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_notes.ainvoke({}) result = await list_notes.ainvoke({})
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "list" assert call_kwargs["action"] == "select"
assert data["table"] == "notes" assert call_kwargs["table"] == "notes"
assert data["filters"]["projectId"] is None assert call_kwargs["filters"]["projectId"] is None
assert result == "No notes found."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_notes_with_project(self) -> None: async def test_list_notes_with_project(self) -> None:
from app.agents.note_agent import list_notes from app.agents.note_agent import list_notes
result = await list_notes.ainvoke({"project_id": "p1"}) with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
data = json.loads(result) m.return_value = {"rows": []}
assert data["filters"]["projectId"] == "p1" await list_notes.ainvoke({"project_id": "p1"})
assert m.call_args.kwargs["filters"]["projectId"] == "p1"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_note(self) -> None: async def test_get_note(self) -> None:
from app.agents.note_agent import get_note from app.agents.note_agent import get_note
fake_row = {"id": "n1", "title": "Daily log", "content": "# Today\nAll good."}
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await get_note.ainvoke({"note_id": "n1"}) result = await get_note.ainvoke({"note_id": "n1"})
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "get" assert call_kwargs["action"] == "get"
assert data["table"] == "notes" assert call_kwargs["table"] == "notes"
assert data["data"]["id"] == "n1" assert call_kwargs["data"]["id"] == "n1"
assert "Daily log" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_note_minimal(self) -> None: async def test_create_note_minimal(self) -> None:
from app.agents.note_agent import create_note from app.agents.note_agent import create_note
result = await create_note.ainvoke({ fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
"title": "Daily log", with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
"content": "# Today\nAll good.", patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
}) m.return_value = {"row": fake_row}
data = json.loads(result) me.return_value = [0.0] * 1536
assert data["action"] == "create_record" result = await create_note.ainvoke({"title": "Daily log", "content": "# Today\nAll good."})
assert data["table"] == "notes" # First call: insert; second call: vector_upsert
assert data["data"]["title"] == "Daily log" first_call = m.call_args_list[0].kwargs
assert data["data"]["content"] == "# Today\nAll good." assert first_call["action"] == "insert"
assert data["data"]["projectId"] is None assert first_call["table"] == "notes"
assert first_call["data"]["title"] == "Daily log"
assert first_call["data"]["content"] == "# Today\nAll good."
assert first_call["data"]["projectId"] is None
assert "Daily log" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_note_with_project(self) -> None: async def test_create_note_with_project(self) -> None:
from app.agents.note_agent import create_note from app.agents.note_agent import create_note
result = await create_note.ainvoke({ fake_row = {"id": "n1", "title": "Sprint notes", "projectId": "p1"}
"title": "Sprint notes", with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
"content": "## Sprint 1", patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
"project_id": "p1", m.return_value = {"row": fake_row}
}) me.return_value = [0.0] * 1536
data = json.loads(result) await create_note.ainvoke({"title": "Sprint notes", "content": "## Sprint 1", "project_id": "p1"})
assert data["data"]["projectId"] == "p1" first_call = m.call_args_list[0].kwargs
assert first_call["data"]["projectId"] == "p1"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_note_content_only(self) -> None: async def test_update_note_content_only(self) -> None:
from app.agents.note_agent import update_note from app.agents.note_agent import update_note
result = await update_note.ainvoke({ fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
"note_id": "n1", with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
"content": "# Updated content", patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
}) m.return_value = {"row": fake_row}
data = json.loads(result) me.return_value = [0.0] * 1536
assert data["action"] == "update_record" result = await update_note.ainvoke({"note_id": "n1", "content": "# Updated content"})
assert data["data"]["id"] == "n1" first_call = m.call_args_list[0].kwargs
assert data["data"]["updates"]["content"] == "# Updated content" assert first_call["action"] == "update"
assert "title" not in data["data"]["updates"] assert first_call["data"]["id"] == "n1"
assert first_call["data"]["updates"]["content"] == "# Updated content"
assert "title" not in first_call["data"]["updates"]
assert "n1" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_note_empty_updates(self) -> None: async def test_update_note_empty_updates(self) -> None:
from app.agents.note_agent import update_note from app.agents.note_agent import update_note
result = await update_note.ainvoke({"note_id": "n1"}) fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
data = json.loads(result) with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
assert data["data"]["updates"] == {} m.return_value = {"row": fake_row}
await update_note.ainvoke({"note_id": "n1"})
assert m.call_args.kwargs["data"]["updates"] == {}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_note(self) -> None: async def test_delete_note(self) -> None:
from app.agents.note_agent import delete_note from app.agents.note_agent import delete_note
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"deleted": True}
result = await delete_note.ainvoke({"note_id": "n1"}) result = await delete_note.ainvoke({"note_id": "n1"})
data = json.loads(result) call_kwargs = m.call_args.kwargs
assert data["action"] == "delete_record" assert call_kwargs["action"] == "delete"
assert data["table"] == "notes" assert call_kwargs["table"] == "notes"
assert data["data"]["id"] == "n1" assert call_kwargs["data"]["id"] == "n1"
assert "n1" in result

View File

@@ -0,0 +1,284 @@
"""Tests for Step 7 — MemoryMiddleware.
Coverage:
1. enrich_context returns core prefs + associative + episodic + proactive
2. store_episode creates an encrypted row decryptable with the user's key
3. update_core upserts correctly
4. User with no encryption_key returns empty context (no crash)
5. End-to-end: home_request WS frame results in an episodic row being stored
"""
from __future__ import annotations
import json
import uuid
from unittest.mock import patch
import pytest
import pytest_asyncio
from cryptography.fernet import Fernet
from sqlalchemy import select
from app.core.memory_middleware import MemoryMiddleware, _PROACTIVE_CONFIDENCE_THRESHOLD
from app.db import get_session
from app.main import app
from app.models import (
MemoryAssociative,
MemoryCore,
MemoryEpisodic,
MemoryProactive,
User,
)
from tests.conftest import TEST_USER_IDS, make_jwt
USER_ID = TEST_USER_IDS["power"]
_FERNET_KEY = Fernet.generate_key().decode()
# ── DB override ───────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _override_db(db_session):
async def _gen():
yield db_session
app.dependency_overrides[get_session] = _gen
yield
app.dependency_overrides.pop(get_session, None)
# ── Fixtures ──────────────────────────────────────────────────────────────────
@pytest_asyncio.fixture
async def user_with_key(db_session):
"""Set encryption_key on the seeded power user."""
result = await db_session.execute(select(User).where(User.id == USER_ID))
user = result.scalar_one()
user.encryption_key = _FERNET_KEY
await db_session.commit()
return user
def _fernet():
return Fernet(_FERNET_KEY.encode())
def _enc(plaintext: str) -> str:
return _fernet().encrypt(plaintext.encode()).decode()
def _dec(ciphertext: str) -> str:
return _fernet().decrypt(ciphertext.encode()).decode()
# ── enrich_context ────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_enrich_context_returns_core_memory(db_session, user_with_key):
# Seed a core memory row
db_session.add(MemoryCore(
id=str(uuid.uuid4()),
user_id=USER_ID,
key="timezone",
value_encrypted=_enc("UTC"),
))
await db_session.commit()
middleware = MemoryMiddleware(db_session)
ctx = await middleware.enrich_context(USER_ID, "What are my tasks?")
assert "core_memory" in ctx
assert ctx["core_memory"]["timezone"] == "UTC"
@pytest.mark.asyncio
async def test_enrich_context_returns_episodic_memory(db_session, user_with_key):
session_id = str(uuid.uuid4())
db_session.add(MemoryEpisodic(
id=str(uuid.uuid4()),
user_id=USER_ID,
summary_encrypted=_enc("User asked about Q1 tasks"),
session_id=session_id,
))
await db_session.commit()
middleware = MemoryMiddleware(db_session)
ctx = await middleware.enrich_context(USER_ID, "any message")
assert "episodic_memory" in ctx
assert any("Q1 tasks" in s for s in ctx["episodic_memory"])
@pytest.mark.asyncio
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
# Add one pattern above threshold and one below
db_session.add(MemoryProactive(
id=str(uuid.uuid4()),
user_id=USER_ID,
pattern_encrypted=_enc("User prefers short summaries"),
confidence=0.9,
source="inferred",
))
db_session.add(MemoryProactive(
id=str(uuid.uuid4()),
user_id=USER_ID,
pattern_encrypted=_enc("User likes dark mode"),
confidence=0.1,
source="inferred",
))
await db_session.commit()
middleware = MemoryMiddleware(db_session)
ctx = await middleware.enrich_context(USER_ID, "any message")
assert "proactive_hints" in ctx
hints = ctx["proactive_hints"]
assert any("short summaries" in h for h in hints)
assert not any("dark mode" in h for h in hints)
@pytest.mark.asyncio
async def test_enrich_context_returns_associative_memory(db_session, user_with_key):
db_session.add(MemoryAssociative(
id=str(uuid.uuid4()),
user_id=USER_ID,
content_encrypted=_enc("Related memory about meetings"),
embedding=None,
entity_type="note",
))
await db_session.commit()
middleware = MemoryMiddleware(db_session)
ctx = await middleware.enrich_context(USER_ID, "meetings")
assert "associative_memory" in ctx
assert any("meetings" in m for m in ctx["associative_memory"])
@pytest.mark.asyncio
async def test_enrich_context_empty_for_user_without_key(db_session):
"""User with no encryption_key → empty context, no crash."""
result = await db_session.execute(select(User).where(User.id == USER_ID))
user = result.scalar_one()
user.encryption_key = None
await db_session.commit()
middleware = MemoryMiddleware(db_session)
ctx = await middleware.enrich_context(USER_ID, "hello")
assert ctx == {}
# ── store_episode ─────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_store_episode_creates_encrypted_row(db_session, user_with_key):
session_id = str(uuid.uuid4())
middleware = MemoryMiddleware(db_session)
await middleware.store_episode(USER_ID, session_id, "hello", "world")
result = await db_session.execute(
select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
)
row = result.scalar_one()
plaintext = _dec(row.summary_encrypted)
assert "hello" in plaintext
assert "world" in plaintext
@pytest.mark.asyncio
async def test_store_episode_decryptable(db_session, user_with_key):
session_id = str(uuid.uuid4())
middleware = MemoryMiddleware(db_session)
await middleware.store_episode(USER_ID, session_id, "msg", "resp")
result = await db_session.execute(
select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
)
row = result.scalar_one()
# Decrypt using the same key — must not raise
decrypted = _dec(row.summary_encrypted)
assert len(decrypted) > 0
# ── update_core ───────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_update_core_insert(db_session, user_with_key):
middleware = MemoryMiddleware(db_session)
await middleware.update_core(USER_ID, "lang", "en")
result = await db_session.execute(
select(MemoryCore).where(MemoryCore.user_id == USER_ID, MemoryCore.key == "lang")
)
row = result.scalar_one()
assert _dec(row.value_encrypted) == "en"
@pytest.mark.asyncio
async def test_update_core_upsert(db_session, user_with_key):
middleware = MemoryMiddleware(db_session)
await middleware.update_core(USER_ID, "lang", "en")
await middleware.update_core(USER_ID, "lang", "fr")
result = await db_session.execute(
select(MemoryCore).where(MemoryCore.user_id == USER_ID, MemoryCore.key == "lang")
)
rows = result.scalars().all()
assert len(rows) == 1
assert _dec(rows[0].value_encrypted) == "fr"
# ── End-to-end WS: memory middleware is called during home_request ────────────
def test_home_request_calls_memory_middleware(client):
"""home_request triggers enrich_context before and store_episode after the LLM."""
enrich_calls: list[tuple] = []
store_calls: list[tuple] = []
class _MockMiddleware:
def __init__(self, db):
pass
async def enrich_context(self, user_id, message):
enrich_calls.append((user_id, message))
return {"core_memory": {"tz": "UTC"}}
async def store_episode(self, user_id, session_id, message, response):
store_calls.append((user_id, session_id, message, response))
token = make_jwt("power", user_id=USER_ID)
session_id = str(uuid.uuid4())
async def _mock_stream(user_id, message, context, reg=None):
# Verify memory context was injected
assert context.get("core_memory") == {"tz": "UTC"}
yield "task_agent", ""
yield "task_agent", '{"type": "text", "content": "Done"}'
with (
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_stream),
):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-mem", "agent_ids": []
}))
ws.send_text(json.dumps({
"type": "home_request",
"request_id": "r-mem",
"session_id": session_id,
"message": "Show tasks",
}))
for _ in range(20):
raw = ws.receive_text()
frame = json.loads(raw)
if frame.get("type") == "stream_end":
break
assert len(enrich_calls) == 1
assert enrich_calls[0] == (USER_ID, "Show tasks")
assert len(store_calls) == 1
stored_session_id, stored_message = store_calls[0][1], store_calls[0][2]
assert stored_session_id == session_id
assert stored_message == "Show tasks"

205
tests/test_memory_models.py Normal file
View File

@@ -0,0 +1,205 @@
"""Tests for Step 6 — memory ORM models and User.encryption_key.
Uses the SQLite in-memory test DB (from conftest). The pgvector embedding
column is stored as JSON in tests (SQLite-compatible).
"""
from __future__ import annotations
import uuid
from datetime import datetime, timezone
import pytest
import pytest_asyncio
from cryptography.fernet import Fernet
from sqlalchemy import select
from app.models import MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive, User
from tests.conftest import TEST_USER_IDS
USER_ID = TEST_USER_IDS["power"]
# ── helpers ───────────────────────────────────────────────────────────────────
def _fernet_key() -> str:
return Fernet.generate_key().decode()
def _encrypt(key: str, plaintext: str) -> str:
return Fernet(key.encode()).encrypt(plaintext.encode()).decode()
def _decrypt(key: str, ciphertext: str) -> str:
return Fernet(key.encode()).decrypt(ciphertext.encode()).decode()
# ── User.encryption_key ───────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_user_encryption_key_column_exists(db_session):
"""User model has encryption_key column and it can be set."""
result = await db_session.execute(select(User).where(User.id == USER_ID))
user = result.scalar_one()
# Column exists (may be None for seeded users)
assert hasattr(user, "encryption_key")
@pytest.mark.asyncio
async def test_user_encryption_key_can_be_set(db_session):
key = _fernet_key()
result = await db_session.execute(select(User).where(User.id == USER_ID))
user = result.scalar_one()
user.encryption_key = key
await db_session.commit()
result2 = await db_session.execute(select(User).where(User.id == USER_ID))
user2 = result2.scalar_one()
assert user2.encryption_key == key
# ── MemoryCore ────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_memory_core_create_and_read(db_session):
key = _fernet_key()
encrypted_val = _encrypt(key, "UTC")
row = MemoryCore(
id=str(uuid.uuid4()),
user_id=USER_ID,
key="timezone",
value_encrypted=encrypted_val,
)
db_session.add(row)
await db_session.commit()
result = await db_session.execute(
select(MemoryCore).where(MemoryCore.user_id == USER_ID)
)
fetched = result.scalar_one()
assert fetched.key == "timezone"
assert _decrypt(key, fetched.value_encrypted) == "UTC"
@pytest.mark.asyncio
async def test_memory_core_cascade_delete(db_session):
"""Deleting a user cascades to memory_core."""
row = MemoryCore(
id=str(uuid.uuid4()),
user_id=USER_ID,
key="lang",
value_encrypted="enc",
)
db_session.add(row)
await db_session.commit()
user = (await db_session.execute(select(User).where(User.id == USER_ID))).scalar_one()
await db_session.delete(user)
await db_session.commit()
remaining = (
await db_session.execute(select(MemoryCore).where(MemoryCore.user_id == USER_ID))
).scalars().all()
assert remaining == []
# ── MemoryAssociative ─────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_memory_associative_create_and_read(db_session):
key = _fernet_key()
content = _encrypt(key, "User prefers morning meetings")
embedding = [0.1] * 1536 # fake embedding
row = MemoryAssociative(
id=str(uuid.uuid4()),
user_id=USER_ID,
content_encrypted=content,
embedding=embedding,
entity_type="preference",
entity_id=None,
)
db_session.add(row)
await db_session.commit()
result = await db_session.execute(
select(MemoryAssociative).where(MemoryAssociative.user_id == USER_ID)
)
fetched = result.scalar_one()
assert fetched.entity_type == "preference"
assert _decrypt(key, fetched.content_encrypted) == "User prefers morning meetings"
assert len(fetched.embedding) == 1536
# ── MemoryEpisodic ────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_memory_episodic_create_and_read(db_session):
key = _fernet_key()
session_id = str(uuid.uuid4())
summary = _encrypt(key, "User asked about Q1 tasks")
row = MemoryEpisodic(
id=str(uuid.uuid4()),
user_id=USER_ID,
summary_encrypted=summary,
session_id=session_id,
)
db_session.add(row)
await db_session.commit()
result = await db_session.execute(
select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
)
fetched = result.scalar_one()
assert _decrypt(key, fetched.summary_encrypted) == "User asked about Q1 tasks"
assert isinstance(fetched.created_at, datetime)
# ── MemoryProactive ───────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_memory_proactive_create_and_read(db_session):
key = _fernet_key()
pattern = _encrypt(key, "User always assigns tasks to self")
row = MemoryProactive(
id=str(uuid.uuid4()),
user_id=USER_ID,
pattern_encrypted=pattern,
confidence=0.85,
source="inferred",
)
db_session.add(row)
await db_session.commit()
result = await db_session.execute(
select(MemoryProactive).where(MemoryProactive.user_id == USER_ID)
)
fetched = result.scalar_one()
assert fetched.confidence == pytest.approx(0.85)
assert fetched.source == "inferred"
assert _decrypt(key, fetched.pattern_encrypted) == "User always assigns tasks to self"
# ── Auth registration generates encryption_key ───────────────────────────────
def test_register_sets_encryption_key(client):
"""POST /api/v1/auth/register creates a user with a valid Fernet key."""
resp = client.post(
"/api/v1/auth/register",
json={"email": "newuser@test.com", "password": "testpassword123"},
)
assert resp.status_code == 201
# Fetch the newly created user via the access token
token = resp.json()["access_token"]
me_resp = client.get(
"/api/v1/auth/me",
headers={"Authorization": f"Bearer {token}"},
)
assert me_resp.status_code == 200
# We can't see encryption_key in the API response (not in UserProfile),
# but we verify registration didn't crash — key generation is implicit.

View File

@@ -302,7 +302,7 @@ class TestOrchestrateStream:
assert len(chunks) >= 1 assert len(chunks) >= 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_last_chunk_is_final_json_frame( async def test_all_chunks_are_plain_text(
self, reg: AgentRegistry self, reg: AgentRegistry
) -> None: ) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls: with patch("app.core.orchestrator._make_llm") as mock_cls:
@@ -310,13 +310,12 @@ class TestOrchestrateStream:
request = ChatRequest(message="add a task", execution_mode="direct") request = ChatRequest(message="add a task", execution_mode="direct")
chunks = [chunk async for chunk in orchestrate_stream(request, reg)] chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
last = json.loads(chunks[-1]) # orchestrate_stream yields plain text chunks only — no JSON final frame
assert last["done"] is True for chunk in chunks:
assert "response" in last assert isinstance(chunk, str)
assert "actions" in last
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_final_frame_response_matches_agent_output( async def test_concatenated_chunks_equal_full_response(
self, reg: AgentRegistry self, reg: AgentRegistry
) -> None: ) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls: with patch("app.core.orchestrator._make_llm") as mock_cls:
@@ -324,8 +323,8 @@ class TestOrchestrateStream:
request = ChatRequest(message="create a task", execution_mode="direct") request = ChatRequest(message="create a task", execution_mode="direct")
chunks = [chunk async for chunk in orchestrate_stream(request, reg)] chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
final = json.loads(chunks[-1]) full_text = "".join(chunks)
assert final["response"] == "task: create a task" assert full_text == "task: create a task"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_text_chunks_before_final_frame( async def test_text_chunks_before_final_frame(

View File

@@ -0,0 +1,236 @@
"""Tests for v3 orchestrator functions (Step 3)."""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from typing import Any
from app.core.agent_registry import ChatAgent, AgentRegistry
from app.core.orchestrator import orchestrate_v3, orchestrate_v3_stream
# ── Minimal agent for testing ─────────────────────────────────────────
class _FixedAgent(ChatAgent):
def __init__(self, name: str = "_fixed", tokens: list[str] | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._name = name
self._tokens = tokens or ["Hello", " world"]
def get_name(self) -> str:
return self._name
def get_description(self) -> str:
return "Fixed agent for tests"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return "".join(self._tokens)
async def handle_stream(self, query: str, context: dict[str, Any]):
for tok in self._tokens:
yield tok
# ── Mock registry factory ─────────────────────────────────────────────
def _make_registry(agent_name: str, agent: ChatAgent) -> MagicMock:
reg = MagicMock(spec=AgentRegistry)
reg.list_agents.return_value = [{"name": agent_name, "description": "test"}]
reg.get.return_value = agent
return reg
# ── orchestrate_v3 ────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_orchestrate_v3_returns_agent_name_and_instance():
agent = _FixedAgent("task_agent")
reg = _make_registry("task_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
name, inst = await orchestrate_v3(
user_id="u-1", message="fix a bug", context={}, reg=reg
)
assert name == "task_agent"
assert inst is agent
@pytest.mark.asyncio
async def test_orchestrate_v3_classify_called_with_message_and_context():
agent = _FixedAgent("note_agent")
reg = _make_registry("note_agent", agent)
ctx = {"some": "context"}
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="note_agent")) as mock_classify:
await orchestrate_v3(user_id="u-1", message="take a note", context=ctx, reg=reg)
mock_classify.assert_awaited_once()
call_args = mock_classify.call_args
assert call_args[0][0] == "take a note"
assert call_args[0][1] == ctx
@pytest.mark.asyncio
async def test_orchestrate_v3_uses_default_registry_when_none():
agent = _FixedAgent("task_agent")
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")), \
patch("app.core.orchestrator._default_registry") as mock_reg:
mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
mock_reg.get.return_value = agent
name, inst = await orchestrate_v3(user_id="u-1", message="hi", context={})
assert name == "task_agent"
assert inst is agent
@pytest.mark.asyncio
async def test_orchestrate_v3_get_called_with_agent_name():
agent = _FixedAgent("checkpoint_agent")
reg = _make_registry("checkpoint_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="checkpoint_agent")):
await orchestrate_v3(user_id="u-2", message="schedule", context={}, reg=reg)
reg.get.assert_called_once_with("checkpoint_agent")
# ── orchestrate_v3_stream ─────────────────────────────────────────────
async def _collect(gen) -> list[tuple[str, str]]:
results: list[tuple[str, str]] = []
async for item in gen:
results.append(item)
return results
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_first_yield_is_domain_signal():
agent = _FixedAgent("task_agent", tokens=["token1"])
reg = _make_registry("task_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
results = await _collect(gen)
# First item must be (agent_name, "") — domain signal
assert results[0] == ("task_agent", "")
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_yields_agent_name_with_tokens():
agent = _FixedAgent("task_agent", tokens=["Hello", " ", "world"])
reg = _make_registry("task_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
results = await _collect(gen)
# All items are (agent_name, token) pairs
assert all(name == "task_agent" for name, _ in results)
tokens = [tok for _, tok in results]
assert tokens[0] == "" # domain signal
assert tokens[1:] == ["Hello", " ", "world"]
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_different_agent():
agent = _FixedAgent("note_agent", tokens=["note"])
reg = _make_registry("note_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="note_agent")):
gen = orchestrate_v3_stream(user_id="u-2", message="take note", context={}, reg=reg)
results = await _collect(gen)
assert results[0] == ("note_agent", "")
assert ("note_agent", "note") in results
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_uses_default_registry_when_none():
agent = _FixedAgent("task_agent", tokens=["x"])
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")), \
patch("app.core.orchestrator._default_registry") as mock_reg:
mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
mock_reg.get.return_value = agent
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={})
results = await _collect(gen)
assert results[0][0] == "task_agent"
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_empty_token_list():
"""Agent with no tokens still emits the domain signal."""
class _EmptyAgent(_FixedAgent):
async def handle_stream(self, query: str, context: dict[str, Any]):
return
yield # makes it a generator
agent = _EmptyAgent("task_agent", tokens=[])
reg = _make_registry("task_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
results = await _collect(gen)
assert results == [("task_agent", "")] # only domain signal
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_full_text_correct():
"""Concatenating all non-domain tokens reconstructs the full response."""
tokens = ["The", " ", "task", " ", "is", " ", "done."]
agent = _FixedAgent("task_agent", tokens=tokens)
reg = _make_registry("task_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
results = await _collect(gen)
text = "".join(tok for _, tok in results[1:]) # skip domain signal
assert text == "The task is done."
# ── handle_stream default implementation ─────────────────────────────
@pytest.mark.asyncio
async def test_handle_stream_default_yields_full_response():
"""Default handle_stream yields handle() result as a single chunk."""
class _SimpleAgent(ChatAgent):
def get_name(self) -> str:
return "_simple"
def get_description(self) -> str:
return ""
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return "simple response"
agent = _SimpleAgent()
tokens = [tok async for tok in agent.handle_stream("q", {})]
assert tokens == ["simple response"]
@pytest.mark.asyncio
async def test_handle_stream_override_used_by_stream():
"""_FixedAgent.handle_stream override yields individual tokens."""
agent = _FixedAgent("t", tokens=["a", "b", "c"])
tokens = [tok async for tok in agent.handle_stream("q", {})]
assert tokens == ["a", "b", "c"]

View File

@@ -0,0 +1,195 @@
"""Tests for app.core.output_formatter — HomeFormatter and FloatingFormatter."""
from __future__ import annotations
import pytest
from app.core.output_formatter import HomeFormatter, FloatingFormatter
from app.schemas import (
WsFloatingDomain,
WsStreamBlock,
WsStreamEnd,
WsStreamStart,
WsStreamText,
)
# ── helpers ───────────────────────────────────────────────────────────────────
async def _stream(*pairs: tuple[str, str]):
"""Async generator that yields (agent_name, token) pairs."""
for pair in pairs:
yield pair
async def collect(formatter, token_stream):
frames = []
async for frame in formatter.format(token_stream):
frames.append(frame)
return frames
# ── HomeFormatter ─────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_home_formatter_text_block():
req_id = "req-1"
tokens = [
("task_agent", '{"type": "text", "content": "Hello world"}'),
]
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(*tokens))
assert isinstance(frames[0], WsStreamStart)
assert frames[0].request_id == req_id
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
assert any("Hello world" in f.chunk for f in text_frames)
assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio
async def test_home_formatter_chart_block():
req_id = "req-2"
chart_json = (
'{"type": "chart", "chartType": "bar", '
'"title": "Tasks", "data": [{"x": 1}], '
'"config": {"x": {"label": "X", "color": "#fff"}}}'
)
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", chart_json)))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 1
assert block_frames[0].block_type == "chart"
assert block_frames[0].data["chartType"] == "bar"
@pytest.mark.asyncio
async def test_home_formatter_invalid_chart_skipped():
req_id = "req-3"
bad_chart = '{"type": "chart", "chartType": "unknown", "data": []}'
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", bad_chart)))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 0 # invalid chart skipped
@pytest.mark.asyncio
async def test_home_formatter_entity_ref_resolved():
req_id = "req-4"
tool_results = [{"entity": "task", "id": "t1", "title": "My Task"}]
entity_json = '{"type": "entity_ref", "entity": "task"}'
formatter = HomeFormatter(request_id=req_id, tool_results=tool_results)
frames = await collect(formatter, _stream(("task_agent", entity_json)))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 1
assert block_frames[0].data["entity"] == "task"
assert block_frames[0].data["items"][0]["id"] == "t1"
@pytest.mark.asyncio
async def test_home_formatter_entity_ref_missing_skipped():
req_id = "req-5"
entity_json = '{"type": "entity_ref", "entity": "task"}'
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", entity_json)))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 0 # no tool results → skipped
@pytest.mark.asyncio
async def test_home_formatter_table_block():
req_id = "req-6"
table_json = '{"type": "table", "headers": ["A", "B"], "rows": [["1", "2"]]}'
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", table_json)))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 1
assert block_frames[0].block_type == "table"
@pytest.mark.asyncio
async def test_home_formatter_timeline_block():
req_id = "req-7"
timeline_json = '{"type": "timeline", "checkpoints": [{"id": "c1", "title": "M1", "date": 123}]}'
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", timeline_json)))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 1
assert block_frames[0].block_type == "timeline"
@pytest.mark.asyncio
async def test_home_formatter_frame_order():
"""stream_start is first, stream_end is last."""
req_id = "req-8"
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", '{"type": "text", "content": "Hi"}')))
assert isinstance(frames[0], WsStreamStart)
assert isinstance(frames[-1], WsStreamEnd)
# ── FloatingFormatter ────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_floating_formatter_domain_emitted_first():
req_id = "pop-1"
formatter = FloatingFormatter(request_id=req_id)
tokens = [
("task_agent", ""), # domain signal
("task_agent", "Hello"),
("task_agent", " there"),
]
frames = await collect(formatter, _stream(*tokens))
assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain == "tasks"
assert frames[0].request_id == req_id
@pytest.mark.asyncio
async def test_floating_formatter_text_only():
req_id = "pop-2"
formatter = FloatingFormatter(request_id=req_id)
tokens = [("checkpoint_agent", ""), ("checkpoint_agent", "Summary")]
frames = await collect(formatter, _stream(*tokens))
assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain == "checkpoints"
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
assert len(text_frames) == 1
assert text_frames[0].chunk == "Summary"
@pytest.mark.asyncio
async def test_floating_formatter_no_block_frames():
"""FloatingFormatter must never emit WsStreamBlock."""
req_id = "pop-3"
formatter = FloatingFormatter(request_id=req_id)
tokens = [
("note_agent", ""),
("note_agent", '{"type": "chart", "chartType": "bar", "data": []}'),
]
frames = await collect(formatter, _stream(*tokens))
assert not any(isinstance(f, WsStreamBlock) for f in frames)
@pytest.mark.asyncio
async def test_floating_formatter_end_frame():
req_id = "pop-4"
formatter = FloatingFormatter(request_id=req_id)
frames = await collect(formatter, _stream(("project_agent", ""), ("project_agent", "Done")))
assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio
async def test_floating_formatter_unknown_agent_defaults_to_tasks():
req_id = "pop-5"
formatter = FloatingFormatter(request_id=req_id)
frames = await collect(formatter, _stream(("unknown_agent", ""), ("unknown_agent", "hi")))
assert frames[0].domain == "tasks"

292
tests/test_schemas_v3.py Normal file
View File

@@ -0,0 +1,292 @@
"""Tests for v3 WebSocket frame protocol schemas."""
import pytest
from pydantic import ValidationError
from app.schemas import (
WsFrameType,
WsHomeRequest,
WsFloatingDomain,
WsFloatingRequest,
WsFloatingScope,
WsStreamBlock,
WsStreamEnd,
WsStreamStart,
WsStreamText,
)
# ── WsFrameType ───────────────────────────────────────────────────────
def test_v3_frame_types_exist():
v3_types = [
"home_request",
"floating_request",
"stream_start",
"stream_text",
"stream_block",
"stream_end",
"floating_domain",
"data_request",
"data_response",
"mutation",
]
for name in v3_types:
assert hasattr(WsFrameType, name), f"WsFrameType missing: {name}"
assert WsFrameType[name].value == name
def test_v2_frame_types_still_exist():
"""Backward compat: v2 types must remain."""
v2_types = [
"chat_request",
"text_chunk",
"tool_call",
"tool_result",
"final",
"ping",
"agent_run",
"agent_data",
"agent_complete",
"device_hello",
]
for name in v2_types:
assert hasattr(WsFrameType, name), f"v2 WsFrameType missing: {name}"
# ── WsHomeRequest ─────────────────────────────────────────────────────
def test_home_request_defaults():
frame = WsHomeRequest(message="Hello")
assert frame.type == WsFrameType.home_request
assert frame.message == "Hello"
assert frame.conversation_history == []
def test_home_request_with_history():
history = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
frame = WsHomeRequest(message="Follow up", conversation_history=history)
assert frame.conversation_history == history
def test_home_request_serializes():
frame = WsHomeRequest(message="Test")
data = frame.model_dump()
assert data["type"] == "home_request"
assert data["message"] == "Test"
assert data["conversation_history"] == []
def test_home_request_deserializes():
raw = {"type": "home_request", "message": "Hi there"}
frame = WsHomeRequest.model_validate(raw)
assert frame.message == "Hi there"
def test_home_request_requires_message():
with pytest.raises(ValidationError):
WsHomeRequest.model_validate({"type": "home_request"})
# ── WsFloatingRequest ────────────────────────────────────────────────────
def test_floating_request_basic():
frame = WsFloatingRequest(
message="Summarise",
scope=WsFloatingScope(type="task", id="task-123"),
)
assert frame.type == WsFrameType.floating_request
assert frame.scope.type == "task"
assert frame.scope.id == "task-123"
def test_floating_request_scope_without_id():
frame = WsFloatingRequest(
message="Show all",
scope=WsFloatingScope(type="project"),
)
assert frame.scope.id is None
def test_floating_request_serializes():
frame = WsFloatingRequest(
message="Test",
scope=WsFloatingScope(type="note", id="n-1"),
)
data = frame.model_dump()
assert data["type"] == "floating_request"
assert data["scope"]["type"] == "note"
assert data["scope"]["id"] == "n-1"
def test_floating_request_invalid_scope_type():
with pytest.raises(ValidationError):
WsFloatingRequest(
message="X",
scope=WsFloatingScope(type="unknown"), # type: ignore[arg-type]
)
def test_floating_request_requires_scope():
with pytest.raises(ValidationError):
WsFloatingRequest.model_validate({"type": "floating_request", "message": "X"})
# ── WsStreamStart ─────────────────────────────────────────────────────
def test_stream_start():
frame = WsStreamStart(request_id="req-abc")
assert frame.type == WsFrameType.stream_start
assert frame.request_id == "req-abc"
def test_stream_start_serializes():
data = WsStreamStart(request_id="r1").model_dump()
assert data == {"type": "stream_start", "request_id": "r1"}
def test_stream_start_deserializes():
frame = WsStreamStart.model_validate({"type": "stream_start", "request_id": "r1"})
assert frame.request_id == "r1"
# ── WsStreamText ──────────────────────────────────────────────────────
def test_stream_text():
frame = WsStreamText(request_id="r1", chunk="Hello ")
assert frame.type == WsFrameType.stream_text
assert frame.chunk == "Hello "
def test_stream_text_serializes():
data = WsStreamText(request_id="r1", chunk="word").model_dump()
assert data == {"type": "stream_text", "request_id": "r1", "chunk": "word"}
def test_stream_text_deserializes():
raw = {"type": "stream_text", "request_id": "r2", "chunk": "test"}
frame = WsStreamText.model_validate(raw)
assert frame.chunk == "test"
# ── WsStreamBlock ─────────────────────────────────────────────────────
def test_stream_block_chart():
data = {
"type": "chart",
"chartType": "bar",
"title": "Tasks",
"data": [{"name": "Done", "count": 5}],
"config": {"count": {"label": "Count", "color": "#4f46e5"}},
}
frame = WsStreamBlock(request_id="r1", block_type="chart", data=data)
assert frame.type == WsFrameType.stream_block
assert frame.block_type == "chart"
assert frame.data["chartType"] == "bar"
def test_stream_block_entity_ref():
frame = WsStreamBlock(
request_id="r1",
block_type="entity_ref",
data={"type": "task", "id": "t-1", "title": "Fix bug"},
)
assert frame.block_type == "entity_ref"
def test_stream_block_table():
frame = WsStreamBlock(
request_id="r1",
block_type="table",
data={"headers": ["A", "B"], "rows": [["1", "2"]]},
)
assert frame.block_type == "table"
def test_stream_block_timeline():
frame = WsStreamBlock(
request_id="r1",
block_type="timeline",
data={"checkpoints": [{"id": "c1", "title": "Launch", "date": 1700000000}]},
)
assert frame.block_type == "timeline"
def test_stream_block_invalid_type():
with pytest.raises(ValidationError):
WsStreamBlock(
request_id="r1",
block_type="unknown", # type: ignore[arg-type]
data={},
)
def test_stream_block_serializes():
frame = WsStreamBlock(request_id="r1", block_type="table", data={"headers": [], "rows": []})
d = frame.model_dump()
assert d["type"] == "stream_block"
assert d["block_type"] == "table"
# ── WsStreamEnd ───────────────────────────────────────────────────────
def test_stream_end_defaults():
frame = WsStreamEnd(request_id="r1")
assert frame.type == WsFrameType.stream_end
assert frame.mutations == []
def test_stream_end_with_mutations():
mutations = [{"action": "create", "table": "tasks", "data": {"title": "New task"}}]
frame = WsStreamEnd(request_id="r1", mutations=mutations)
assert len(frame.mutations) == 1
assert frame.mutations[0]["action"] == "create"
def test_stream_end_serializes():
data = WsStreamEnd(request_id="r2").model_dump()
assert data == {"type": "stream_end", "request_id": "r2", "mutations": []}
def test_stream_end_deserializes():
raw = {"type": "stream_end", "request_id": "r3", "mutations": []}
frame = WsStreamEnd.model_validate(raw)
assert frame.request_id == "r3"
# ── WsFloatingDomain ─────────────────────────────────────────────────────
def test_floating_domain_tasks():
frame = WsFloatingDomain(request_id="r1", domain="tasks")
assert frame.type == WsFrameType.floating_domain
assert frame.domain == "tasks"
@pytest.mark.parametrize("domain", ["tasks", "checkpoints", "notes", "projects"])
def test_floating_domain_valid_domains(domain: str):
frame = WsFloatingDomain(request_id="r1", domain=domain) # type: ignore[arg-type]
assert frame.domain == domain
def test_floating_domain_invalid():
with pytest.raises(ValidationError):
WsFloatingDomain(request_id="r1", domain="invalid") # type: ignore[arg-type]
def test_floating_domain_serializes():
d = WsFloatingDomain(request_id="r1", domain="notes").model_dump()
assert d == {"type": "floating_domain", "request_id": "r1", "domain": "notes"}
def test_floating_domain_deserializes():
raw = {"type": "floating_domain", "request_id": "r1", "domain": "projects"}
frame = WsFloatingDomain.model_validate(raw)
assert frame.domain == "projects"

157
tests/test_ws_unified.py Normal file
View File

@@ -0,0 +1,157 @@
"""Integration tests for the unified WebSocket handler (Step 5).
Tests the device WS endpoint with home_request and floating_request frames,
verifying that the correct v3 frame sequence is returned.
LLM calls are mocked to avoid network dependency.
"""
from __future__ import annotations
import json
from unittest.mock import patch
import pytest
from app.db import get_session
from app.main import app
from app.schemas import WsFrameType
from tests.conftest import TEST_USER_IDS, make_jwt
USER_ID = TEST_USER_IDS["power"]
# ── helpers ───────────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _override_db(db_session):
async def _gen():
yield db_session
app.dependency_overrides[get_session] = _gen
yield
app.dependency_overrides.pop(get_session, None)
def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
"""Receive frames until stream_end (or stream_end inside floating flow), or max_frames."""
frames = []
for _ in range(max_frames):
raw = ws.receive_text()
frame = json.loads(raw)
frames.append(frame)
if frame.get("type") == WsFrameType.stream_end:
break
return frames
async def _mock_home_stream(user_id, message, context, reg=None):
yield "task_agent", ""
yield "task_agent", '{"type": "text", "content": "Hello"}'
async def _mock_floating_stream(user_id, message, context, reg=None):
yield "task_agent", ""
yield "task_agent", "Here is a summary"
# ── tests ─────────────────────────────────────────────────────────────────────
def test_home_request_produces_stream_frames(client):
"""home_request → stream_start, stream_text+, stream_end."""
token = make_jwt("power", user_id=USER_ID)
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_home_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-1", "agent_ids": []
}))
ws.send_text(json.dumps({
"type": "home_request",
"request_id": "r1",
"message": "List my tasks",
"conversation_history": [],
}))
frames = _recv_until_end(ws)
types = [f["type"] for f in frames]
assert WsFrameType.stream_start in types
assert WsFrameType.stream_end in types
assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end)
def test_floating_request_produces_domain_frame(client):
"""floating_request → floating_domain first, then stream_text*, stream_end."""
token = make_jwt("power", user_id=USER_ID)
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_floating_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-2", "agent_ids": []
}))
ws.send_text(json.dumps({
"type": "floating_request",
"request_id": "p1",
"message": "Summarize this task",
"scope": {"type": "task", "id": "task-123"},
}))
frames = _recv_until_end(ws)
types = [f["type"] for f in frames]
assert WsFrameType.floating_domain in types
assert WsFrameType.stream_end in types
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
assert domain_frame["domain"] == "tasks"
assert domain_frame["request_id"] == "p1"
def test_home_request_request_id_propagated(client):
"""request_id in home_request is echoed in all response frames."""
token = make_jwt("power", user_id=USER_ID)
req_id = "my-unique-req-id"
async def _stream(user_id, message, context, reg=None):
yield "note_agent", ""
yield "note_agent", '{"type": "text", "content": "ok"}'
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-3", "agent_ids": []
}))
ws.send_text(json.dumps({
"type": "home_request",
"request_id": req_id,
"message": "hello",
}))
frames = _recv_until_end(ws)
for f in frames:
if "request_id" in f:
assert f["request_id"] == req_id
def test_tool_result_dispatch_silent_on_unknown_id(client):
"""tool_result for unknown call_id is silently ignored — no crash."""
token = make_jwt("power", user_id=USER_ID)
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.05):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-4", "agent_ids": []
}))
ws.send_text(json.dumps({
"type": "tool_result", "id": "no-such-id", "ok": True
}))
# If connection is still alive, we'll get the heartbeat ping
msg = json.loads(ws.receive_text())
assert msg["type"] == "ping"
def test_invalid_jwt_rejected(client):
"""Connection with bad token is closed before or after accept."""
with pytest.raises(Exception):
with client.websocket_connect("/api/v1/ws/device?token=badtoken") as ws:
ws.receive_text()