Compare commits
10 Commits
ac71d99f9a...34f01234c9
| Author | SHA1 | Date |
|---|---|---|
| | 34f01234c9 | |
| | 0bd46937d3 | |
| | e6b5bc2e7d | |
| | c90ed58078 | |
| | 76c8f2bdad | |
| | 393b3befd6 | |
| | 2c08275934 | |
| | 7cb384fa63 | |
| | 7efaeba283 | |
| | b61ded8458 | |
@@ -7,6 +7,18 @@
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## General Rules
|
||||||
|
|
||||||
|
**Code Cleanup**: As you implement each step, remove any code that becomes unused or obsolete. This includes:
|
||||||
|
- Old functions/methods that are superseded by new ones
|
||||||
|
- Deprecated imports or modules
|
||||||
|
- Dead code paths
|
||||||
|
- Old test files no longer needed
|
||||||
|
|
||||||
|
This keeps the codebase clean and prevents confusion. When removing code, note it in the commit message if significant.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Decisions Log
|
## Decisions Log
|
||||||
|
|
||||||
| Topic | Decision |
|
| Topic | Decision |
|
||||||
@@ -24,18 +36,18 @@
|
|||||||
|
|
||||||
**Changes**:
|
**Changes**:
|
||||||
- `app/schemas.py` — Add to `WsFrameType` enum:
|
- `app/schemas.py` — Add to `WsFrameType` enum:
|
||||||
- `home_request`, `popup_request`
|
- `home_request`, `floating_request`
|
||||||
- `stream_start`, `stream_text`, `stream_block`, `stream_end`
|
- `stream_start`, `stream_text`, `stream_block`, `stream_end`
|
||||||
- `popup_domain`
|
- `floating_domain`
|
||||||
- `data_request`, `data_response`, `mutation`
|
- `data_request`, `data_response`, `mutation`
|
||||||
- Add Pydantic models:
|
- Add Pydantic models:
|
||||||
- `WsHomeRequest(type, message, conversation_history?)`
|
- `WsHomeRequest(type, message, conversation_history?)`
|
||||||
- `WsPopupRequest(type, message, scope: {type, id?})`
|
- `WsFloatingRequest(type, message, scope: {type, id?})`
|
||||||
- `WsStreamStart(type, request_id)`
|
- `WsStreamStart(type, request_id)`
|
||||||
- `WsStreamText(type, request_id, chunk)`
|
- `WsStreamText(type, request_id, chunk)`
|
||||||
- `WsStreamBlock(type, request_id, block_type, data)`
|
- `WsStreamBlock(type, request_id, block_type, data)`
|
||||||
- `WsStreamEnd(type, request_id, mutations?)`
|
- `WsStreamEnd(type, request_id, mutations?)`
|
||||||
- `WsPopupDomain(type, request_id, domain)`
|
- `WsFloatingDomain(type, request_id, domain)`
|
||||||
- Keep all existing frame types (backward compat).
|
- Keep all existing frame types (backward compat).
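
As a rough illustration of the frame additions listed above, here is a dependency-free sketch that uses standard-library dataclasses in place of the Pydantic models. The enum members and field names follow the list; the concrete field types, the `Scope` helper, and `to_json` are assumptions made for illustration and are not taken from `app/schemas.py`.

```python
from __future__ import annotations

import json
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any


class WsFrameType(str, Enum):
    # Only the frame types named in the plan; existing members are omitted.
    home_request = "home_request"
    floating_request = "floating_request"
    stream_start = "stream_start"
    stream_text = "stream_text"
    stream_block = "stream_block"
    stream_end = "stream_end"
    floating_domain = "floating_domain"
    data_request = "data_request"
    data_response = "data_response"
    mutation = "mutation"


@dataclass
class Scope:
    # Hypothetical shape for `scope: {type, id?}`.
    type: str
    id: str | None = None


@dataclass
class WsFloatingRequest:
    message: str
    scope: Scope
    type: str = WsFrameType.floating_request.value


@dataclass
class WsStreamText:
    request_id: str
    chunk: str
    type: str = WsFrameType.stream_text.value


def to_json(frame: Any) -> str:
    """Serialize a frame dataclass (including nested ones) to a JSON text frame."""
    return json.dumps(asdict(frame))
```

With Pydantic, the same shapes would presumably be `BaseModel` subclasses serialized via `model_dump_json()`, which is the call the handlers in this diff use.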
|
||||||
|
|
||||||
**Files touched**: `app/schemas.py`
|
**Files touched**: `app/schemas.py`
|
||||||
@@ -45,6 +57,14 @@
|
|||||||
pytest tests/test_schemas_v3.py
|
pytest tests/test_schemas_v3.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Status**:
|
||||||
|
- [x] Step 1 complete
|
||||||
|
|
||||||
|
**Commit**: After tests pass, commit with:
|
||||||
|
```
|
||||||
|
git commit -m "step-1: add v3 ws frame protocol (schemas.py)"
|
||||||
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Step 2 — Agent Streaming + Tool Result Capture (agent_registry.py, agents/)
|
## Step 2 — Agent Streaming + Tool Result Capture (agent_registry.py, agents/)
|
||||||
@@ -65,6 +85,14 @@ pytest tests/test_schemas_v3.py
|
|||||||
pytest tests/test_agent_streaming.py
|
pytest tests/test_agent_streaming.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Status**:
|
||||||
|
- [x] Step 2 complete
|
||||||
|
|
||||||
|
**Commit**: After tests pass, commit with:
|
||||||
|
```
|
||||||
|
git commit -m "step-2: add agent streaming and tool result capture (agent_registry.py)"
|
||||||
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Step 3 — Router Refactor (orchestrator.py)
|
## Step 3 — Router Refactor (orchestrator.py)
|
||||||
@@ -90,11 +118,19 @@ pytest tests/test_agent_streaming.py
|
|||||||
pytest tests/test_orchestrator_v3.py
|
pytest tests/test_orchestrator_v3.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Status**:
|
||||||
|
- [x] Step 3 complete
|
||||||
|
|
||||||
|
**Commit**: After tests pass, commit with:
|
||||||
|
```
|
||||||
|
git commit -m "step-3: add router refactor with streaming support (orchestrator.py)"
|
||||||
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Step 4 — Output Formatting Layer (NEW: output_formatter.py)
|
## Step 4 — Output Formatting Layer (NEW: output_formatter.py)
|
||||||
|
|
||||||
**Goal**: Home and Popup responses diverge at this layer only.
|
**Goal**: Home and Floating responses diverge at this layer only.
|
||||||
|
|
||||||
### Block Types (from Electron app components)
|
### Block Types (from Electron app components)
|
||||||
|
|
||||||
@@ -158,14 +194,14 @@ Supported entity types (matching Electron component types):
|
|||||||
- `table` -> buffers, validates headers/rows structure, yields `WsStreamBlock`
|
- `table` -> buffers, validates headers/rows structure, yields `WsStreamBlock`
|
||||||
- `timeline` -> buffers, validates checkpoint objects, yields `WsStreamBlock`
|
- `timeline` -> buffers, validates checkpoint objects, yields `WsStreamBlock`
|
||||||
- Invalid blocks are logged and skipped (never crash the stream)
|
- Invalid blocks are logged and skipped (never crash the stream)
|
||||||
- `PopupFormatter`:
|
- `FloatingFormatter`:
|
||||||
- Receives `agent_name` from orchestrator
|
- Receives `agent_name` from orchestrator
|
||||||
- Maps agent name to domain (deterministic, by code — no LLM):
|
- Maps agent name to domain (deterministic, by code — no LLM):
|
||||||
- `task_agent` -> `"tasks"`
|
- `task_agent` -> `"tasks"`
|
||||||
- `checkpoint_agent` -> `"checkpoints"`
|
- `checkpoint_agent` -> `"checkpoints"`
|
||||||
- `note_agent` -> `"notes"`
|
- `note_agent` -> `"notes"`
|
||||||
- `project_agent` -> `"projects"`
|
- `project_agent` -> `"projects"`
|
||||||
- Yields `WsPopupDomain` immediately
|
- Yields `WsFloatingDomain` immediately
|
||||||
- Then yields `WsStreamText` for all tokens (text-only, no blocks)
|
- Then yields `WsStreamText` for all tokens (text-only, no blocks)
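
The `FloatingFormatter` behaviour described above can be sketched as a plain async generator. This is a minimal illustration, not the code in `output_formatter.py`: frames are plain dicts rather than Pydantic models, `format_floating` and `_demo` are hypothetical names, and the handling of unknown agent names (a `KeyError` here) is an assumption because the plan does not specify it.

```python
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator

# Mapping taken from the plan; resolved in code, with no LLM involved.
AGENT_DOMAINS = {
    "task_agent": "tasks",
    "checkpoint_agent": "checkpoints",
    "note_agent": "notes",
    "project_agent": "projects",
}


async def format_floating(
    request_id: str, agent_name: str, tokens: AsyncIterator[str]
) -> AsyncIterator[dict]:
    """Yield the domain frame first, then one text frame per token."""
    yield {
        "type": "floating_domain",
        "request_id": request_id,
        # Assumption: an unknown agent name raises KeyError.
        "domain": AGENT_DOMAINS[agent_name],
    }
    async for chunk in tokens:
        yield {"type": "stream_text", "request_id": request_id, "chunk": chunk}


async def _demo() -> list[dict]:
    async def tokens() -> AsyncIterator[str]:
        for piece in ("Two ", "tasks ", "due."):
            yield piece

    return [f async for f in format_floating("req-1", "task_agent", tokens())]


frames = asyncio.run(_demo())
```

Because the domain frame is emitted before any token is consumed, a client can switch its view without waiting for the first text chunk.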
|
||||||
|
|
||||||
**Files touched**: `app/core/output_formatter.py` (new)
|
**Files touched**: `app/core/output_formatter.py` (new)
|
||||||
@@ -175,23 +211,32 @@ Supported entity types (matching Electron component types):
|
|||||||
pytest tests/test_output_formatter.py
|
pytest tests/test_output_formatter.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Status**:
|
||||||
|
- [x] Step 4 complete
|
||||||
|
|
||||||
|
**Commit**: After tests pass, commit with:
|
||||||
|
```
|
||||||
|
git commit -m "step-4: add output formatting layer (output_formatter.py)"
|
||||||
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Step 5 — Unified WS Handler (device_ws.py, chat.py, main.py)
|
## Step 5 — Unified WS Handler (device_ws.py, chat.py, main.py)
|
||||||
|
|
||||||
**Goal**: Single multiplexed WebSocket handles device frames + Home/Popup chat.
|
**Goal**: Single multiplexed WebSocket handles device frames + Home/Floating chat.
|
||||||
|
|
||||||
**Changes**:
|
**Changes**:
|
||||||
- `app/api/routes/device_ws.py`:
|
- `app/api/routes/device_ws.py`:
|
||||||
- Extend `_message_loop` dispatch to handle `home_request` and `popup_request`:
|
- Extend `_message_loop` dispatch to handle `home_request` and `floating_request`:
|
||||||
- On `home_request`: set `ws_context` executor, call `orchestrate_v3_stream`, pipe through `HomeFormatter`, send frames back on same socket.
|
- On `home_request`: set `ws_context` executor, call `orchestrate_v3_stream`, pipe through `HomeFormatter`, send frames back on same socket.
|
||||||
- On `popup_request`: same, but pipe through `PopupFormatter`.
|
- On `floating_request`: same, but pipe through `FloatingFormatter`.
|
||||||
- Wrap both in try/finally to clear `ws_context`.
|
- Wrap both in try/finally to clear `ws_context`.
|
||||||
- Each request gets a `request_id` (UUID) for frame correlation.
|
- Each request gets a `request_id` (UUID) for frame correlation.
|
||||||
- Concurrent requests from same client are supported (each runs as an async task).
|
- Concurrent requests from same client are supported (each runs as an async task).
|
||||||
- `app/api/routes/chat.py`:
|
- `app/api/routes/chat.py`:
|
||||||
- Remove `chat_stream` WS endpoint.
|
- Remove `chat_stream` WS endpoint and any related helper functions that were only used by it.
|
||||||
- Keep `POST /chat` endpoint unchanged (REST fallback).
|
- Keep `POST /chat` endpoint unchanged (REST fallback).
|
||||||
|
- Clean up any unused imports.
|
||||||
- `app/main.py`:
|
- `app/main.py`:
|
||||||
- No change needed (device_ws router already registered).
|
- No change needed (device_ws router already registered).
|
||||||
|
|
||||||
@@ -201,12 +246,20 @@ pytest tests/test_output_formatter.py
|
|||||||
1. Connects to `/api/v1/ws/device`
|
1. Connects to `/api/v1/ws/device`
|
||||||
2. Sends `device_hello`
|
2. Sends `device_hello`
|
||||||
3. Sends `home_request` -> receives `stream_start`, `stream_text`*, `stream_end`
|
3. Sends `home_request` -> receives `stream_start`, `stream_text`*, `stream_end`
|
||||||
4. Sends `popup_request` -> receives `popup_domain`, `stream_text`*, `stream_end`
|
4. Sends `floating_request` -> receives `floating_domain`, `stream_text`*, `stream_end`
|
||||||
5. Verifies `tool_call`/`tool_result` round-trip still works during chat
|
5. Verifies `tool_call`/`tool_result` round-trip still works during chat
|
||||||
```
|
```
|
||||||
pytest tests/test_ws_unified.py
|
pytest tests/test_ws_unified.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Status**:
|
||||||
|
- [x] Step 5 complete
|
||||||
|
|
||||||
|
**Commit**: After tests pass, commit with:
|
||||||
|
```
|
||||||
|
git commit -m "step-5: unify ws handler (device_ws.py, chat.py)"
|
||||||
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Step 6 — Memory Models + Migration (models.py, alembic)
|
## Step 6 — Memory Models + Migration (models.py, alembic)
|
||||||
@@ -231,6 +284,14 @@ alembic upgrade head && alembic downgrade -1 && alembic upgrade head
|
|||||||
pytest tests/test_memory_models.py
|
pytest tests/test_memory_models.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Status**:
|
||||||
|
- [x] Step 6 complete
|
||||||
|
|
||||||
|
**Commit**: After tests pass, commit with:
|
||||||
|
```
|
||||||
|
git commit -m "step-6: add memory models and migration (models.py, alembic)"
|
||||||
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Step 7 — Memory Middleware (NEW: memory_middleware.py)
|
## Step 7 — Memory Middleware (NEW: memory_middleware.py)
|
||||||
@@ -252,7 +313,7 @@ pytest tests/test_memory_models.py
|
|||||||
3. Embed interaction, encrypt and upsert in `MemoryAssociative`
|
3. Embed interaction, encrypt and upsert in `MemoryAssociative`
|
||||||
- `update_core(user_id, key, value)` — explicit preference update
|
- `update_core(user_id, key, value)` — explicit preference update
|
||||||
- All read/write operations encrypt/decrypt using the user's Fernet key from `User.encryption_key`
|
- All read/write operations encrypt/decrypt using the user's Fernet key from `User.encryption_key`
|
||||||
- `app/api/routes/device_ws.py` — Update `home_request` and `popup_request` handlers:
|
- `app/api/routes/device_ws.py` — Update `home_request` and `floating_request` handlers:
|
||||||
- Before orchestrator: `enriched = await memory.enrich_context(user_id, message)`
|
- Before orchestrator: `enriched = await memory.enrich_context(user_id, message)`
|
||||||
- After response complete: `await memory.store_episode(user_id, ...)`
|
- After response complete: `await memory.store_episode(user_id, ...)`
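
The ordering in the two bullets above (enrich before the orchestrator, store after the response completes) can be sketched with in-memory stubs. `InMemoryMemory` and `fake_orchestrator` are hypothetical stand-ins with no encryption, embeddings, or database; the `store_episode` argument order mirrors the call made by the handlers in this diff.

```python
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator


class InMemoryMemory:
    """Hypothetical stand-in for the memory middleware."""

    def __init__(self) -> None:
        self.episodes: list[tuple[str, str, str, str]] = []

    async def enrich_context(self, user_id: str, message: str) -> dict:
        # Return earlier episodes for this user as extra context.
        return {"memory": [e for e in self.episodes if e[0] == user_id]}

    async def store_episode(
        self, user_id: str, session_id: str, message: str, response: str
    ) -> None:
        self.episodes.append((user_id, session_id, message, response))


async def fake_orchestrator(message: str, context: dict) -> AsyncIterator[str]:
    # Placeholder for the streaming orchestrator call.
    for piece in ("echo: ", message):
        yield piece


async def handle(
    memory: InMemoryMemory, user_id: str, session_id: str, message: str
) -> str:
    enriched = await memory.enrich_context(user_id, message)  # before orchestrator
    chunks = [c async for c in fake_orchestrator(message, enriched)]
    response = "".join(chunks)
    await memory.store_episode(user_id, session_id, message, response)  # after response
    return response


memory = InMemoryMemory()
reply = asyncio.run(handle(memory, "u1", "s1", "hello"))
```

Storing only after the stream has finished means the episode contains the full response text rather than a partial one.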
|
||||||
|
|
||||||
@@ -266,6 +327,14 @@ pytest tests/test_memory_models.py
|
|||||||
pytest tests/test_memory_middleware.py
|
pytest tests/test_memory_middleware.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Status**:
|
||||||
|
- [x] Step 7 complete
|
||||||
|
|
||||||
|
**Commit**: After tests pass, commit with:
|
||||||
|
```
|
||||||
|
git commit -m "step-7: add memory middleware (memory_middleware.py, device_ws.py)"
|
||||||
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Summary
|
## Summary
|
||||||
|
|||||||
144
alembic/versions/004_add_memory_tables.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
"""Add memory tables and user encryption_key column.
|
||||||
|
|
||||||
|
Memory tables:
|
||||||
|
memory_core — per-user key/value preferences (encrypted)
|
||||||
|
memory_associative — semantic memory with pgvector embedding (encrypted)
|
||||||
|
memory_episodic — session summaries (encrypted)
|
||||||
|
memory_proactive — behavioral patterns (encrypted)
|
||||||
|
|
||||||
|
Also adds encryption_key column to users table.
|
||||||
|
|
||||||
|
Revision ID: 004
|
||||||
|
Revises: 003
|
||||||
|
Create Date: 2026-03-08
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "004"
|
||||||
|
down_revision: Union[str, None] = "003"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ── Enable pgvector extension (idempotent) ────────────────────────────────
|
||||||
|
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
|
||||||
|
|
||||||
|
# ── Add encryption_key to users ───────────────────────────────────────────
|
||||||
|
op.add_column(
|
||||||
|
"users",
|
||||||
|
sa.Column("encryption_key", sa.String(64), nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── memory_core ───────────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"memory_core",
|
||||||
|
sa.Column("id", sa.String(36), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
sa.String(36),
|
||||||
|
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
),
|
||||||
|
sa.Column("key", sa.String(255), nullable=False),
|
||||||
|
sa.Column("value_encrypted", sa.Text, nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.create_index("ix_memory_core_user_id", "memory_core", ["user_id"])
|
||||||
|
|
||||||
|
# ── memory_associative ────────────────────────────────────────────────────
|
||||||
|
# The embedding column uses pgvector's vector(1536) type.
|
||||||
|
op.create_table(
|
||||||
|
"memory_associative",
|
||||||
|
sa.Column("id", sa.String(36), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
sa.String(36),
|
||||||
|
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("content_encrypted", sa.Text, nullable=False),
|
||||||
|
sa.Column("entity_type", sa.String(100), nullable=True),
|
||||||
|
sa.Column("entity_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# Add the pgvector column separately (not supported by generic sa types)
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE memory_associative ADD COLUMN embedding vector(1536);"
|
||||||
|
)
|
||||||
|
op.create_index("ix_memory_associative_user_id", "memory_associative", ["user_id"])
|
||||||
|
# IVFFlat index for approximate nearest-neighbour search
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX ix_memory_associative_embedding "
|
||||||
|
"ON memory_associative USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── memory_episodic ───────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"memory_episodic",
|
||||||
|
sa.Column("id", sa.String(36), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
sa.String(36),
|
||||||
|
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("summary_encrypted", sa.Text, nullable=False),
|
||||||
|
sa.Column("session_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.create_index("ix_memory_episodic_user_id", "memory_episodic", ["user_id"])
|
||||||
|
op.create_index("ix_memory_episodic_session_id", "memory_episodic", ["session_id"])
|
||||||
|
|
||||||
|
# ── memory_proactive ──────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"memory_proactive",
|
||||||
|
sa.Column("id", sa.String(36), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
sa.String(36),
|
||||||
|
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("pattern_encrypted", sa.Text, nullable=False),
|
||||||
|
sa.Column("confidence", sa.Float, nullable=False, server_default="0.5"),
|
||||||
|
sa.Column("source", sa.String(50), nullable=False, server_default="inferred"),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.create_index("ix_memory_proactive_user_id", "memory_proactive", ["user_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("memory_proactive")
|
||||||
|
op.drop_table("memory_episodic")
|
||||||
|
op.drop_index("ix_memory_associative_embedding", "memory_associative")
|
||||||
|
op.drop_table("memory_associative")
|
||||||
|
op.drop_table("memory_core")
|
||||||
|
op.drop_column("users", "encryption_key")
|
||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import uuid
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -94,6 +95,7 @@ async def register(
|
|||||||
email=body.email,
|
email=body.email,
|
||||||
password_hash=_hash_password(body.password),
|
password_hash=_hash_password(body.password),
|
||||||
tier="free",
|
tier="free",
|
||||||
|
encryption_key=Fernet.generate_key().decode(),
|
||||||
)
|
)
|
||||||
db.add(user)
|
db.add(user)
|
||||||
await db.flush() # get user.id without committing
|
await db.flush() # get user.id without committing
|
||||||
|
|||||||
@@ -1,23 +1,19 @@
|
|||||||
"""Chat routes: POST /chat and WebSocket /chat/stream."""
|
"""Chat routes: POST /chat (REST fallback).
|
||||||
|
|
||||||
|
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
from fastapi import APIRouter, Depends
|
||||||
import json
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from jose import JWTError, jwt
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.config.settings import settings
|
from app.core.orchestrator import orchestrate
|
||||||
from app.core.orchestrator import orchestrate, orchestrate_stream
|
|
||||||
from app.schemas import ChatRequest, UserProfile
|
from app.schemas import ChatRequest, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
|
|
||||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("")
|
@router.post("")
|
||||||
async def chat(
|
async def chat(
|
||||||
@@ -31,48 +27,3 @@ async def chat(
|
|||||||
"""
|
"""
|
||||||
result = await orchestrate(body)
|
result = await orchestrate(body)
|
||||||
return JSONResponse(content=result.model_dump())
|
return JSONResponse(content=result.model_dump())
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/stream")
|
|
||||||
async def chat_stream(websocket: WebSocket) -> None:
|
|
||||||
"""Streaming chat via WebSocket.
|
|
||||||
|
|
||||||
Auth: ``?token=<jwt>`` query param (Bearer not possible during WS handshake).
|
|
||||||
|
|
||||||
Protocol:
|
|
||||||
1. Client sends ``ChatRequest`` as the first JSON text frame.
|
|
||||||
2. Server streams response text chunks.
|
|
||||||
3. Final frame: JSON ``{"done": true, "response": "...", "actions": [...]}``.
|
|
||||||
4. Server pings every 30 s to keep the connection alive.
|
|
||||||
"""
|
|
||||||
# Authenticate before accepting the connection
|
|
||||||
token = websocket.query_params.get("token", "")
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
|
|
||||||
user_id: str | None = payload.get("sub")
|
|
||||||
if not user_id:
|
|
||||||
raise JWTError("missing sub")
|
|
||||||
except JWTError:
|
|
||||||
await websocket.close(code=1008) # 1008 = Policy Violation
|
|
||||||
return
|
|
||||||
|
|
||||||
await websocket.accept()
|
|
||||||
|
|
||||||
try:
|
|
||||||
raw = await websocket.receive_text()
|
|
||||||
body = ChatRequest.model_validate_json(raw)
|
|
||||||
|
|
||||||
async def _heartbeat() -> None:
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(_HEARTBEAT_INTERVAL)
|
|
||||||
await websocket.send_text(json.dumps({"ping": True}))
|
|
||||||
|
|
||||||
heartbeat_task = asyncio.create_task(_heartbeat())
|
|
||||||
try:
|
|
||||||
async for chunk in orchestrate_stream(body):
|
|
||||||
await websocket.send_text(chunk)
|
|
||||||
finally:
|
|
||||||
heartbeat_task.cancel()
|
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -33,14 +33,19 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from sqlalchemy import select, update
|
from sqlalchemy import update
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.core.agent_runner import trigger_pending_runs
|
from app.core.agent_runner import trigger_pending_runs
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.core.orchestrator import orchestrate_v3_stream
|
||||||
|
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
||||||
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog
|
from app.models import AgentRunLog
|
||||||
from app.schemas import WsFrameType
|
from app.schemas import WsFrameType
|
||||||
@@ -173,6 +178,16 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
"device_ws: agent_complete missing run_id from user=%s", user_id
|
"device_ws: agent_complete missing run_id from user=%s", user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.home_request:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_home_request(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.floating_request:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_floating_request(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
elif frame_type == "pong":
|
elif frame_type == "pong":
|
||||||
# Heartbeat ack — nothing to do, connection is alive.
|
# Heartbeat ack — nothing to do, connection is alive.
|
||||||
pass
|
pass
|
||||||
@@ -183,6 +198,109 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
||||||
|
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
||||||
|
async def _executor(payload: dict) -> dict:
|
||||||
|
payload["type"] = WsFrameType.tool_call
|
||||||
|
await websocket.send_text(json.dumps(payload))
|
||||||
|
future = device_manager.create_pending_call(user_id, payload["id"])
|
||||||
|
return await future
|
||||||
|
return _executor
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_home_request(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a home_request frame — streams HomeFormatter output back on the socket."""
|
||||||
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
|
message: str = frame.get("message", "")
|
||||||
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
|
||||||
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(user_id, message)
|
||||||
|
|
||||||
|
context: dict = {
|
||||||
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
response_chunks: list[str] = []
|
||||||
|
try:
|
||||||
|
token_stream = orchestrate_v3_stream(user_id, message, context)
|
||||||
|
formatter = HomeFormatter(request_id=request_id, tool_results=[])
|
||||||
|
async for ws_frame in formatter.format(token_stream):
|
||||||
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
|
# Collect text chunks to build the full response for episode storage
|
||||||
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: home_request failed user=%s req=%s: %s",
|
||||||
|
user_id, request_id, exc,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
# ── Memory: store episode after response ──────────────────────────
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.store_episode(
|
||||||
|
user_id, session_id, message, "".join(response_chunks)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_floating_request(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a floating_request frame — streams FloatingFormatter output back on the socket."""
|
||||||
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
|
message: str = frame.get("message", "")
|
||||||
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
scope: dict = frame.get("scope", {})
|
||||||
|
|
||||||
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(user_id, message)
|
||||||
|
|
||||||
|
context: dict = {"scope": scope, **memory_context}
|
||||||
|
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
response_chunks: list[str] = []
|
||||||
|
try:
|
||||||
|
token_stream = orchestrate_v3_stream(user_id, message, context)
|
||||||
|
formatter = FloatingFormatter(request_id=request_id)
|
||||||
|
async for ws_frame in formatter.format(token_stream):
|
||||||
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: floating_request failed user=%s req=%s: %s",
|
||||||
|
user_id, request_id, exc,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
# ── Memory: store episode after response ──────────────────────────
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.store_episode(
|
||||||
|
user_id, session_id, message, "".join(response_chunks)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
@@ -26,6 +26,7 @@ class Settings(BaseSettings):
|
|||||||
OPENAI_API_KEY: str = ""
|
OPENAI_API_KEY: str = ""
|
||||||
ANTHROPIC_API_KEY: str = ""
|
ANTHROPIC_API_KEY: str = ""
|
||||||
GOOGLE_API_KEY: str = ""
|
GOOGLE_API_KEY: str = ""
|
||||||
|
CEREBRAS_API_KEY: str = ""
|
||||||
|
|
||||||
LLM_MODEL: str = "gpt-4o"
|
LLM_MODEL: str = "gpt-4o"
|
||||||
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
|
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
|
||||||
@@ -53,9 +54,7 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
ENV: Literal["dev", "prod"] = "dev"
|
ENV: Literal["dev", "prod"] = "dev"
|
||||||
|
|
||||||
class Config:
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||||
env_file = ".env"
|
|
||||||
env_file_encoding = "utf-8"
|
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
@@ -34,11 +35,26 @@ class BaseAgent(ABC):
|
|||||||
class ChatAgent(BaseAgent):
|
class ChatAgent(BaseAgent):
|
||||||
"""Base class for LLM-powered chat agents."""
|
"""Base class for LLM-powered chat agents."""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
# Populated by _tool_loop / _tool_loop_stream with raw execute_on_client results.
|
||||||
|
self.tool_results: list[dict] = []
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||||
"""Process a user query and return a text response."""
|
"""Process a user query and return a text response."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
async def handle_stream(
|
||||||
|
self, query: str, context: dict[str, Any]
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Streaming variant of handle().
|
||||||
|
|
||||||
|
Default: calls handle() and yields the full response as one chunk.
|
||||||
|
Override in subclasses for true token-level streaming via _tool_loop_stream.
|
||||||
|
"""
|
||||||
|
yield await self.handle(query, context)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_tools(self) -> list[Any]:
|
def get_tools(self) -> list[Any]:
|
||||||
"""Return LangChain tool definitions available to this agent."""
|
"""Return LangChain tool definitions available to this agent."""
|
||||||
@@ -55,10 +71,16 @@ class ChatAgent(BaseAgent):
|
|||||||
|
|
||||||
Binds *tools* to *llm*, invokes iteratively until the model stops
|
Binds *tools* to *llm*, invokes iteratively until the model stops
|
||||||
requesting tool calls or *max_iter* is reached, and returns the
|
requesting tool calls or *max_iter* is reached, and returns the
|
||||||
final text response.
|
final text response. Captures raw execute_on_client results in
|
||||||
|
``self.tool_results``.
|
||||||
"""
|
"""
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
|
||||||
|
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
|
||||||
|
|
||||||
|
collector: list[dict] = []
|
||||||
|
set_tool_result_collector(collector)
|
||||||
|
try:
|
||||||
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
||||||
|
|
||||||
for _ in range(max_iter):
|
for _ in range(max_iter):
|
||||||
@@ -83,6 +105,64 @@ class ChatAgent(BaseAgent):
|
|||||||
# Exhausted iterations — ask model for a final answer without tools
|
# Exhausted iterations — ask model for a final answer without tools
|
||||||
response = await llm.ainvoke(messages)
|
response = await llm.ainvoke(messages)
|
||||||
return str(response.content)
|
return str(response.content)
|
||||||
|
finally:
|
||||||
|
clear_tool_result_collector()
|
||||||
|
self.tool_results = collector
|
||||||
|
|
||||||
|
async def _tool_loop_stream(
|
||||||
|
self,
|
||||||
|
llm: Any,
|
||||||
|
messages: list[Any],
|
||||||
|
tools: list[Any],
|
||||||
|
max_iter: int = 5,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Streaming variant of ``_tool_loop``.
|
||||||
|
|
||||||
|
Behaves identically for tool-calling iterations (uses ainvoke to parse
|
||||||
|
tool calls). For the final response — when the model produces no further
|
||||||
|
tool calls — switches to ``llm.astream()`` and yields text tokens.
|
||||||
|
Captures raw execute_on_client results in ``self.tool_results``.
|
||||||
|
"""
|
||||||
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
|
||||||
|
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
|
||||||
|
|
||||||
|
collector: list[dict] = []
|
||||||
|
set_tool_result_collector(collector)
|
||||||
|
try:
|
||||||
|
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
||||||
|
|
||||||
|
for _ in range(max_iter):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
# Stream the final answer with a fresh astream() call. The ainvoke result above is discarded, so the model is queried twice for the final turn.
|
||||||
|
async for chunk in llm.astream(messages):
|
||||||
|
if chunk.content:
|
||||||
|
yield str(chunk.content)
|
||||||
|
return
|
||||||
|
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
# Execute each requested tool call
|
||||||
|
tool_map = {t.name: t for t in tools}
|
||||||
|
for call in response.tool_calls:
|
||||||
|
tool_fn = tool_map.get(call["name"])
|
||||||
|
if tool_fn is None:
|
||||||
|
result = f"Unknown tool: {call['name']}"
|
||||||
|
else:
|
||||||
|
result = await tool_fn.ainvoke(call["args"])
|
||||||
|
messages.append(
|
||||||
|
ToolMessage(content=str(result), tool_call_id=call["id"])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Exhausted iterations — stream a final answer without tools
|
||||||
|
async for chunk in llm.astream(messages):
|
||||||
|
if chunk.content:
|
||||||
|
yield str(chunk.content)
|
||||||
|
finally:
|
||||||
|
clear_tool_result_collector()
|
||||||
|
self.tool_results = collector
|
||||||
|
|
||||||
|
|
||||||
class AgentRegistry:
|
class AgentRegistry:
|
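The control flow that `_tool_loop_stream` adds can be sketched without LangChain. This is a minimal illustration only: `FakeLLM`, `FakeReply`, `count_tasks` and `tool_loop_stream` are hypothetical stand-ins, and their `ainvoke`/`astream` methods merely mimic the shape of the calls used in the diff. It shows the two phases: blocking invocations while the model still requests tools, then a token stream for the final answer.

```python
import asyncio
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field


@dataclass
class FakeReply:
    """Hypothetical stand-in for an AI message."""
    content: str = ""
    tool_calls: list[dict] = field(default_factory=list)


class FakeLLM:
    """Hypothetical stub: requests one tool call, then answers."""

    def __init__(self) -> None:
        self.turn = 0

    async def ainvoke(self, messages: list) -> FakeReply:
        self.turn += 1
        if self.turn == 1:
            return FakeReply(tool_calls=[{"name": "count_tasks", "args": {}, "id": "c1"}])
        return FakeReply(content="You have 3 tasks.")

    async def astream(self, messages: list) -> AsyncGenerator[FakeReply, None]:
        for word in ("You ", "have ", "3 ", "tasks."):
            yield FakeReply(content=word)


async def count_tasks(args: dict) -> int:
    return 3


async def tool_loop_stream(llm, messages: list, tools: dict, max_iter: int = 5):
    for _ in range(max_iter):
        reply = await llm.ainvoke(messages)
        if not reply.tool_calls:
            # No further tool calls: stream the final answer token by token.
            async for chunk in llm.astream(messages):
                if chunk.content:
                    yield chunk.content
            return
        messages.append(reply)
        for call in reply.tool_calls:
            fn = tools.get(call["name"])
            result = await fn(call["args"]) if fn else f"Unknown tool: {call['name']}"
            messages.append({"tool_call_id": call["id"], "content": str(result)})
    # Iterations exhausted: stream whatever the model can say without tools.
    async for chunk in llm.astream(messages):
        if chunk.content:
            yield chunk.content


async def collect() -> list[str]:
    llm = FakeLLM()
    return [t async for t in tool_loop_stream(llm, ["hi"], {"count_tasks": count_tasks})]


tokens = asyncio.run(collect())
```

One consequence of this design, visible in the diff as well: the final turn costs two model calls (one `ainvoke` to detect that no tools are needed, one `astream` to produce the streamed text), and the streamed text is not guaranteed to match the discarded `ainvoke` answer.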
||||||
|
|||||||
231
app/core/memory_middleware.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
"""Memory Middleware — enrich requests with memory context and store interactions.
|
||||||
|
|
||||||
|
Four-tier memory model (MemGPT-style):
|
||||||
|
core — persistent key/value user preferences, always injected
|
||||||
|
associative — semantic similarity search via pgvector (top-k)
|
||||||
|
episodic — recent session summaries (last N)
|
||||||
|
proactive — behavioral patterns above confidence threshold
|
||||||
|
|
||||||
|
All memory content is encrypted at rest using the per-user Fernet key
|
||||||
|
stored in User.encryption_key. Decryption happens in-memory only.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
memory = MemoryMiddleware(db_session)
|
||||||
|
context = await memory.enrich_context(user_id, message)
|
||||||
|
# ... run agent ...
|
||||||
|
await memory.store_episode(user_id, session_id, message, response)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models import (
|
||||||
|
MemoryAssociative,
|
||||||
|
MemoryCore,
|
||||||
|
MemoryEpisodic,
|
||||||
|
MemoryProactive,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Tuning constants
|
||||||
|
_ASSOCIATIVE_TOP_K = 5
|
||||||
|
_EPISODIC_RECENT_N = 10
|
||||||
|
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryMiddleware:
|
||||||
|
"""Enrich orchestrator context with memory and persist interactions after."""
|
||||||
|
|
||||||
|
def __init__(self, db: AsyncSession) -> None:
|
||||||
|
self._db = db
|
||||||
|
|
||||||
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
|
||||||
|
"""Build memory context dict to inject into the orchestrator before LLM call.
|
||||||
|
|
||||||
|
Returns a dict with keys:
|
||||||
|
core_memory — {key: plaintext_value, ...}
|
||||||
|
associative_memory — [plaintext_content, ...] (top-k most recently updated)
|
||||||
|
episodic_memory — [plaintext_summary, ...] (most recent N)
|
||||||
|
proactive_hints — [plaintext_pattern, ...] (above threshold)
|
||||||
|
"""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
core = await self._load_core(user_id, fernet)
|
||||||
|
associative = await self._load_associative(user_id, message, fernet)
|
||||||
|
episodic = await self._load_episodic(user_id, fernet)
|
||||||
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"core_memory": core,
|
||||||
|
"associative_memory": associative,
|
||||||
|
"episodic_memory": episodic,
|
||||||
|
"proactive_hints": proactive,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def store_episode(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
message: str,
|
||||||
|
response: str,
|
||||||
|
) -> None:
|
||||||
|
"""Summarise and store a completed interaction in episodic memory.
|
||||||
|
|
||||||
|
The summary concatenates the first 200 characters of the message and of the response (no LLM call) to keep
|
||||||
|
latency low. Full LLM summarisation can be added in a later step.
|
||||||
|
"""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
||||||
|
encrypted = _encrypt(fernet, summary)
|
||||||
|
|
||||||
|
row = MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
summary_encrypted=encrypted,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
self._db.add(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def update_core(self, user_id: str, key: str, value: str) -> None:
|
||||||
|
"""Upsert a core memory key/value for a user."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
encrypted = _encrypt(fernet, value)
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(
|
||||||
|
MemoryCore.user_id == user_id,
|
||||||
|
MemoryCore.key == key,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
existing = result.scalar_one_or_none()
|
||||||
|
if existing is not None:
|
||||||
|
existing.value_encrypted = encrypted
|
||||||
|
else:
|
||||||
|
self._db.add(MemoryCore(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
key=key,
|
||||||
|
value_encrypted=encrypted,
|
||||||
|
))
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
# ── Private helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
||||||
|
"""Load the user's Fernet key from DB. Returns None if missing."""
|
||||||
|
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not user.encryption_key:
|
||||||
|
logger.warning("memory: no encryption_key for user=%s", user_id)
|
||||||
|
return None
|
||||||
|
return Fernet(user.encryption_key.encode())
|
||||||
|
|
||||||
|
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: dict[str, str] = {}
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out[row.key] = plaintext
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _load_associative(
|
||||||
|
self, user_id: str, message: str, fernet: Fernet
|
||||||
|
) -> list[str]:
|
||||||
|
"""Load top-k associative memories.
|
||||||
|
|
||||||
|
Production: uses pgvector cosine similarity on the message embedding.
|
||||||
|
Current implementation: recency-based fallback (most recently updated rows; *message* is unused, no external embedding call)
|
||||||
|
so tests pass without a live OpenAI key.
|
||||||
|
"""
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryAssociative)
|
||||||
|
.where(MemoryAssociative.user_id == user_id)
|
||||||
|
.order_by(MemoryAssociative.updated_at.desc())
|
||||||
|
.limit(_ASSOCIATIVE_TOP_K)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryEpisodic)
|
||||||
|
.where(MemoryEpisodic.user_id == user_id)
|
||||||
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
|
.limit(_EPISODIC_RECENT_N)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryProactive)
|
||||||
|
.where(
|
||||||
|
MemoryProactive.user_id == user_id,
|
||||||
|
MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
|
||||||
|
)
|
||||||
|
.order_by(MemoryProactive.confidence.desc())
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.pattern_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# ── Encryption helpers ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _encrypt(fernet: Fernet, plaintext: str) -> str:
|
||||||
|
return fernet.encrypt(plaintext.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
|
||||||
|
"""Decrypt and return plaintext, or None on error (corrupted/wrong key)."""
|
||||||
|
try:
|
||||||
|
return fernet.decrypt(ciphertext.encode()).decode()
|
||||||
|
except (InvalidToken, Exception) as exc:  # Exception already covers InvalidToken; the tuple only documents the expected failure
|
||||||
|
logger.warning("memory: decrypt failed: %s", exc)
|
||||||
|
return None
|
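The `_load_associative` docstring names cosine similarity as the production ranking, while the code in this diff returns the most recently updated rows. As a rough illustration of what similarity ranking means, here is a pure-Python sketch. It assumes embeddings are already available as float lists; `cosine_similarity`, `top_k_by_similarity` and the sample rows are hypothetical and not part of the diff. With pgvector the ranking would normally happen in SQL rather than in application memory.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two vectors; 0.0 if either is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k_by_similarity(
    query: list[float],
    rows: list[tuple[str, list[float]]],
    k: int = 5,
) -> list[str]:
    """Return the contents of the k rows whose embedding is closest to *query*."""
    scored = sorted(rows, key=lambda r: cosine_similarity(query, r[1]), reverse=True)
    return [content for content, _ in scored[:k]]


# Toy two-dimensional embeddings, purely for illustration.
rows = [
    ("likes dark mode", [1.0, 0.0]),
    ("works late on Fridays", [0.0, 1.0]),
    ("prefers dark themes", [0.9, 0.1]),
]
best = top_k_by_similarity([1.0, 0.0], rows, k=2)
```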
||||||
@@ -7,7 +7,7 @@ from typing import Any, AsyncGenerator
|
|||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|
||||||
from app.core.agent_registry import AgentRegistry
|
from app.core.agent_registry import AgentRegistry, ChatAgent
|
||||||
from app.core.llm import get_router_llm
|
from app.core.llm import get_router_llm
|
||||||
from app.core.agent_registry import registry as _default_registry
|
from app.core.agent_registry import registry as _default_registry
|
||||||
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
|
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
|
||||||
@@ -140,6 +140,44 @@ async def orchestrate(
|
|||||||
return _build_plan(agent_name, request.message)
|
return _build_plan(agent_name, request.message)
|
||||||
|
|
||||||
|
|
||||||
|
async def orchestrate_v3(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
reg: AgentRegistry | None = None,
|
||||||
|
) -> tuple[str, ChatAgent]:
|
||||||
|
"""v3 orchestration — returns (agent_name, agent_instance); caller drives execution.
|
||||||
|
|
||||||
|
Classifies intent and instantiates the matching agent. The caller is responsible
|
||||||
|
for invoking handle(), handle_stream(), or _tool_loop_stream() as needed.
|
||||||
|
"""
|
||||||
|
if reg is None:
|
||||||
|
reg = _default_registry
|
||||||
|
agent_name = await classify_intent(message, context, reg)
|
||||||
|
return agent_name, reg.get(agent_name)
|
||||||
|
|
||||||
|
|
||||||
|
async def orchestrate_v3_stream(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
reg: AgentRegistry | None = None,
|
||||||
|
) -> AsyncGenerator[tuple[str, str], None]:
|
||||||
|
"""v3 streaming orchestration — yields (agent_name, token) pairs.
|
||||||
|
|
||||||
|
The first yield always carries the agent_name with an empty token so that
|
||||||
|
callers (e.g. FloatingFormatter) can detect the routing domain before any text
|
||||||
|
tokens arrive.
|
||||||
|
"""
|
||||||
|
if reg is None:
|
||||||
|
reg = _default_registry
|
||||||
|
agent_name = await classify_intent(message, context, reg)
|
||||||
|
agent = reg.get(agent_name)
|
||||||
|
yield agent_name, "" # domain signal — no token yet
|
||||||
|
async for token in agent.handle_stream(message, context):
|
||||||
|
yield agent_name, token
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate_stream(
|
async def orchestrate_stream(
|
||||||
request: ChatRequest,
|
request: ChatRequest,
|
||||||
reg: AgentRegistry | None = None,
|
reg: AgentRegistry | None = None,
|
||||||
|
|||||||
244
app/core/output_formatter.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
"""Output Formatter — transforms orchestrator token streams into WS frame sequences.
|
||||||
|
|
||||||
|
HomeFormatter: produces stream_start, stream_text / stream_block, stream_end
|
||||||
|
FloatingFormatter: produces floating_domain, stream_text, stream_end
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.schemas import (
|
||||||
|
WsFloatingDomain,
|
||||||
|
WsStreamBlock,
|
||||||
|
WsStreamEnd,
|
||||||
|
WsStreamStart,
|
||||||
|
WsStreamText,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Valid chart types (matching shadcn/ui Recharts wrappers in Electron)
|
||||||
|
_VALID_CHART_TYPES = {"area", "bar", "line", "pie", "radar", "radial"}
|
||||||
|
|
||||||
|
# Map agent name → floating domain
|
||||||
|
_AGENT_DOMAIN: dict[str, str] = {
|
||||||
|
"task_agent": "tasks",
|
||||||
|
"checkpoint_agent": "checkpoints",
|
||||||
|
"note_agent": "notes",
|
||||||
|
"project_agent": "projects",
|
||||||
|
}
|
||||||
|
|
||||||
|
WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatingDomain
|
||||||
|
|
||||||
|
|
||||||
|
class HomeFormatter:
|
||||||
|
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
|
||||||
|
|
||||||
|
The LLM is expected to output a newline-delimited sequence of JSON objects,
|
||||||
|
each with a ``type`` field:
|
||||||
|
- ``text`` → yields WsStreamText immediately (word-by-word)
|
||||||
|
- ``chart`` → buffers full JSON, validates, yields WsStreamBlock
|
||||||
|
- ``entity_ref`` → resolves from tool_results, yields WsStreamBlock
|
||||||
|
- ``table`` → buffers full JSON, validates, yields WsStreamBlock
|
||||||
|
- ``timeline`` → buffers full JSON, validates, yields WsStreamBlock
|
||||||
|
|
||||||
|
Invalid or unknown blocks are logged and skipped — stream never crashes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, request_id: str, tool_results: list[dict]) -> None:
|
||||||
|
self.request_id = request_id
|
||||||
|
self.tool_results = tool_results
|
||||||
|
|
||||||
|
async def format(
|
||||||
|
self,
|
||||||
|
token_stream: AsyncGenerator[tuple[str, str], None],
|
||||||
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
|
||||||
|
buffer = ""
|
||||||
|
async for _agent_name, token in token_stream:
|
||||||
|
if not token:
|
||||||
|
continue
|
||||||
|
buffer += token
|
||||||
|
# Flush any complete JSON objects from the buffer
|
||||||
|
async for frame in self._flush_complete_objects(buffer):
|
||||||
|
buffer = "" # reset after flush
|
||||||
|
yield frame
|
||||||
|
break # stops after the first frame; because buffer was reset above, anything after that frame in this buffer is dropped rather than accumulated
|
||||||
|
|
||||||
|
# Flush any remaining content
|
||||||
|
if buffer.strip():
|
||||||
|
async for frame in self._flush_complete_objects(buffer, final=True):
|
||||||
|
yield frame
|
||||||
|
|
||||||
|
yield WsStreamEnd(request_id=self.request_id)
|
||||||
|
|
||||||
|
async def _flush_complete_objects(
|
||||||
|
self, text: str, final: bool = False
|
||||||
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
|
"""Try to parse and yield all complete JSON objects from *text*.
|
||||||
|
|
||||||
|
Yields nothing if text is incomplete JSON (unless *final* is True,
|
||||||
|
in which case remaining text is emitted as plain stream_text).
|
||||||
|
"""
|
||||||
|
remaining = text.strip()
|
||||||
|
while remaining:
|
||||||
|
# Fast path: plain text (not JSON)
|
||||||
|
if not remaining.startswith("{"):
|
||||||
|
# Yield as plain text chunk
|
||||||
|
newline_idx = remaining.find("\n")
|
||||||
|
if newline_idx == -1:
|
||||||
|
if final:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=remaining)
|
||||||
|
remaining = ""
|
||||||
|
else:
|
||||||
|
return # accumulate more
|
||||||
|
else:
|
||||||
|
line = remaining[:newline_idx].strip()
|
||||||
|
remaining = remaining[newline_idx + 1:].strip()
|
||||||
|
if line:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=line)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Try to decode a JSON object
|
||||||
|
try:
|
||||||
|
obj, end_idx = _try_parse_json(remaining)
|
||||||
|
except ValueError:
|
||||||
|
if final:
|
||||||
|
# Emit as raw text if we can't parse
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=remaining)
|
||||||
|
remaining = ""
|
||||||
|
return
|
||||||
|
|
||||||
|
if obj is None:
|
||||||
|
if final:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=remaining)
|
||||||
|
remaining = ""
|
||||||
|
return # incomplete — need more tokens
|
||||||
|
|
||||||
|
remaining = remaining[end_idx:].strip()
|
||||||
|
block_type = obj.get("type")
|
||||||
|
|
||||||
|
frame = self._dispatch_block(obj, block_type)
|
||||||
|
if frame is not None:
|
||||||
|
yield frame
|
||||||
|
|
||||||
|
def _dispatch_block(self, obj: dict, block_type: str | None) -> WsFrame | None:
|
||||||
|
if block_type == "text":
|
||||||
|
content = obj.get("content", "")
|
||||||
|
if content:
|
||||||
|
return WsStreamText(request_id=self.request_id, chunk=str(content))
|
||||||
|
return None
|
||||||
|
|
||||||
|
if block_type == "chart":
|
||||||
|
chart_type = obj.get("chartType")
|
||||||
|
if chart_type not in _VALID_CHART_TYPES:
|
||||||
|
logger.warning("HomeFormatter: invalid chartType=%r — skipping", chart_type)
|
||||||
|
return None
|
||||||
|
if not isinstance(obj.get("data"), list):
|
||||||
|
logger.warning("HomeFormatter: chart missing data array — skipping")
|
||||||
|
return None
|
||||||
|
return WsStreamBlock(
|
||||||
|
request_id=self.request_id,
|
||||||
|
block_type="chart",
|
||||||
|
data=obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
if block_type == "entity_ref":
|
||||||
|
entity = obj.get("entity")
|
||||||
|
resolved = self._resolve_entity(entity)
|
||||||
|
if resolved is None:
|
||||||
|
logger.warning("HomeFormatter: entity_ref %r not found in tool_results — skipping", entity)
|
||||||
|
return None
|
||||||
|
return WsStreamBlock(
|
||||||
|
request_id=self.request_id,
|
||||||
|
block_type="entity_ref",
|
||||||
|
data={"entity": entity, "items": resolved},
|
||||||
|
)
|
||||||
|
|
||||||
|
if block_type == "table":
|
||||||
|
if not isinstance(obj.get("headers"), list) or not isinstance(obj.get("rows"), list):
|
||||||
|
logger.warning("HomeFormatter: table missing headers/rows — skipping")
|
||||||
|
return None
|
||||||
|
return WsStreamBlock(
|
||||||
|
request_id=self.request_id,
|
||||||
|
block_type="table",
|
||||||
|
data=obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
if block_type == "timeline":
|
||||||
|
if not isinstance(obj.get("checkpoints"), list):
|
||||||
|
logger.warning("HomeFormatter: timeline missing checkpoints — skipping")
|
||||||
|
return None
|
||||||
|
return WsStreamBlock(
|
||||||
|
request_id=self.request_id,
|
||||||
|
block_type="timeline",
|
||||||
|
data=obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.warning("HomeFormatter: unknown block type=%r — skipping", block_type)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _resolve_entity(self, entity: str | None) -> list[dict] | None:
|
||||||
|
"""Find matching items in tool_results by entity type."""
|
||||||
|
if not entity:
|
||||||
|
return None
|
||||||
|
matches = [r for r in self.tool_results if r.get("entity") == entity]
|
||||||
|
return matches if matches else None
|
||||||
|
|
||||||
|
|
||||||
|
class FloatingFormatter:
|
||||||
|
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
|
||||||
|
|
||||||
|
Emits floating_domain immediately (from agent_name), then streams all tokens
|
||||||
|
as plain stream_text — no block parsing for floating context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, request_id: str) -> None:
|
||||||
|
self.request_id = request_id
|
||||||
|
|
||||||
|
async def format(
|
||||||
|
self,
|
||||||
|
token_stream: AsyncGenerator[tuple[str, str], None],
|
||||||
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
|
domain_sent = False
|
||||||
|
|
||||||
|
async for agent_name, token in token_stream:
|
||||||
|
if not domain_sent:
|
||||||
|
domain = _AGENT_DOMAIN.get(agent_name, "tasks")
|
||||||
|
yield WsFloatingDomain(
|
||||||
|
request_id=self.request_id,
|
||||||
|
domain=domain, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
domain_sent = True
|
||||||
|
|
||||||
|
if token:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=token)
|
||||||
|
|
||||||
|
yield WsStreamEnd(request_id=self.request_id)
|
||||||
|
|
||||||
|
|
||||||
|
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _try_parse_json(text: str) -> tuple[dict[str, Any] | None, int]:
|
||||||
|
"""Attempt to parse the first complete JSON object from *text*.
|
||||||
|
|
||||||
|
Returns ``(parsed_dict, end_index)`` on success, ``(None, 0)`` when the
|
||||||
|
object is incomplete, and raises ``ValueError`` when text is not JSON.
|
||||||
|
"""
|
||||||
|
decoder = json.JSONDecoder()
|
||||||
|
try:
|
||||||
|
obj, end_idx = decoder.raw_decode(text)
|
||||||
|
if not isinstance(obj, dict):
|
||||||
|
raise ValueError("Expected JSON object")
|
||||||
|
return obj, end_idx
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
# Incomplete JSON — need more tokens
|
||||||
|
if "Unterminated" in str(exc) or exc.pos == len(text):
|
||||||
|
return None, 0
|
||||||
|
raise ValueError(str(exc)) from exc
|
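The buffering idea behind `_try_parse_json` (decode the first complete object, keep the rest for later) can be sketched with the standard library alone. The helper below is an assumption for illustration, not code from this diff: `drain_json_objects` returns every complete object at the front of the buffer together with the unconsumed remainder, so the caller can append the next token to that remainder instead of resetting the buffer. For simplicity it treats any decode error as "incomplete".

```python
import json

_decoder = json.JSONDecoder()


def drain_json_objects(buffer: str) -> tuple[list[dict], str]:
    """Split *buffer* into (complete JSON objects, unconsumed remainder)."""
    objects: list[dict] = []
    rest = buffer.lstrip()
    while rest.startswith("{"):
        try:
            # raw_decode parses one value from the start of the string and
            # reports the index at which that value ends.
            obj, end = _decoder.raw_decode(rest)
        except json.JSONDecodeError:
            break  # treated as incomplete here: wait for more tokens
        objects.append(obj)
        rest = rest[end:].lstrip()
    return objects, rest


# First chunk: one complete object plus the start of a second one.
objs, rest = drain_json_objects('{"type": "text", "content": "hi"}\n{"type": "ch')
# Second chunk completes the pending object.
objs2, rest2 = drain_json_objects(rest + 'art", "data": []}')
```

Returning the remainder explicitly makes it hard to lose partially received objects between tokens.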
||||||
@@ -17,6 +17,22 @@ _client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = Cont
|
|||||||
"_client_executor"
|
"_client_executor"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Optional collector that captures raw execute_on_client results.
|
||||||
|
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
|
||||||
|
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||||
|
"_tool_result_collector", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_tool_result_collector(lst: list[dict]) -> None:
|
||||||
|
"""Register *lst* as the collector for this async context."""
|
||||||
|
_tool_result_collector.set(lst)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_tool_result_collector() -> None:
|
||||||
|
"""Clear the collector (best-effort)."""
|
||||||
|
_tool_result_collector.set(None)
|
||||||
|
|
||||||
|
|
||||||
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
|
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
|
||||||
"""Bind *fn* as the executor for the current async context (task/coroutine)."""
|
"""Bind *fn* as the executor for the current async context (task/coroutine)."""
|
||||||
@@ -65,4 +81,8 @@ async def execute_on_client(
|
|||||||
if limit is not None:
|
if limit is not None:
|
||||||
payload["limit"] = limit
|
payload["limit"] = limit
|
||||||
|
|
||||||
return await callback(payload)
|
result = await callback(payload)
|
||||||
|
collector = _tool_result_collector.get(None)
|
||||||
|
if collector is not None:
|
||||||
|
collector.append(result)
|
||||||
|
return result
|
||||||
|
|||||||
@@ -14,6 +14,10 @@ Table inventory:
|
|||||||
plugin_installations — per-user install records
|
plugin_installations — per-user install records
|
||||||
plugin_reviews — admin review decisions
|
plugin_reviews — admin review decisions
|
||||||
revenue_events — Stripe Connect 70/30 split ledger
|
revenue_events — Stripe Connect 70/30 split ledger
|
||||||
|
memory_core — per-user persistent key/value preferences (encrypted)
|
||||||
|
memory_associative — per-user semantic memory with embeddings (encrypted)
|
||||||
|
memory_episodic — per-user session summaries (encrypted)
|
||||||
|
memory_proactive — per-user behavioral patterns (encrypted)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -74,6 +78,9 @@ class User(Base):
|
|||||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||||
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
||||||
|
# Used to encrypt/decrypt all memory rows for this user.
|
||||||
|
encryption_key: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
)
|
)
|
||||||
@@ -375,3 +382,93 @@ class AgentRunLog(Base):
|
|||||||
foreign_keys="AgentRunLog.agent_id",
|
foreign_keys="AgentRunLog.agent_id",
|
||||||
overlaps="run_logs,local_agent",
|
overlaps="run_logs,local_agent",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Memory models ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryCore(Base):
|
||||||
|
"""Per-user persistent key/value preferences, encrypted at rest.
|
||||||
|
|
||||||
|
Examples: preferred_language, timezone, work_style.
|
||||||
|
Decrypted in-memory only using User.encryption_key.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "memory_core"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
value_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryAssociative(Base):
|
||||||
|
"""Per-user semantic memory: encrypted content + pgvector embedding for similarity search.
|
||||||
|
|
||||||
|
Production: ``embedding`` column is ``vector(1536)`` via pgvector.
|
||||||
|
Tests (SQLite): stored as JSON list.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "memory_associative"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
# JSON-encoded float list in SQLite tests; vector(1536) in Postgres via migration.
|
||||||
|
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||||
|
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
|
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryEpisodic(Base):
|
||||||
|
"""Per-user session summaries, encrypted at rest.
|
||||||
|
|
||||||
|
One row per session interaction; used to recall recent conversations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "memory_episodic"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
summary_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
session_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryProactive(Base):
|
||||||
|
"""Per-user inferred behavioral patterns, encrypted at rest.
|
||||||
|
|
||||||
|
Confidence in [0.0, 1.0]; only patterns above threshold are injected.
|
||||||
|
Source: 'inferred' (from episodes) or 'explicit' (user-stated).
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "memory_proactive"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
pattern_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.5)
|
||||||
|
source: Mapped[str] = mapped_column(String(50), nullable=False, default="inferred")
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|||||||
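`MemoryAssociative` stores the embedding as a plain JSON list when running on SQLite, so similarity search in tests cannot rely on pgvector. A minimal sketch of how such rows could be ranked in pure Python is shown below. The function names `cosine_similarity` and `top_k`, and the dict-shaped rows, are assumptions for illustration and do not appear in the diff.

```python
import math
from typing import Any


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("embeddings must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # a zero vector has no direction; treat it as unrelated
    return dot / (norm_a * norm_b)


def top_k(query: list[float], rows: list[dict[str, Any]], k: int = 3) -> list[str]:
    """Return ids of the k rows most similar to the query.

    Rows whose embedding is None (the column is nullable) are skipped.
    """
    scored = [
        (cosine_similarity(query, row["embedding"]), row["id"])
        for row in rows
        if row["embedding"] is not None
    ]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [row_id for _, row_id in scored[:k]]


rows = [
    {"id": "a", "embedding": [1.0, 0.0]},
    {"id": "b", "embedding": [0.0, 1.0]},
    {"id": "c", "embedding": None},
]
ranked = top_k([1.0, 0.1], rows, k=2)
```

In production the same ranking would be delegated to pgvector; this fallback only makes sense for small test fixtures.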
@@ -161,6 +161,7 @@ class PluginInstallRequest(BaseModel):
|
|||||||
# ── WebSocket Frame Protocol ──────────────────────────────────────────
|
# ── WebSocket Frame Protocol ──────────────────────────────────────────
|
||||||
|
|
||||||
class WsFrameType(str, Enum):
|
class WsFrameType(str, Enum):
|
||||||
|
# ── v2 frame types (kept for backward compat) ──────────────────────
|
||||||
chat_request = "chat_request"
|
chat_request = "chat_request"
|
||||||
text_chunk = "text_chunk"
|
text_chunk = "text_chunk"
|
||||||
tool_call = "tool_call"
|
tool_call = "tool_call"
|
||||||
@@ -171,6 +172,17 @@ class WsFrameType(str, Enum):
|
|||||||
agent_data = "agent_data"
|
agent_data = "agent_data"
|
||||||
agent_complete = "agent_complete"
|
agent_complete = "agent_complete"
|
||||||
device_hello = "device_hello"
|
device_hello = "device_hello"
|
||||||
|
# ── v3 frame types ─────────────────────────────────────────────────
|
||||||
|
home_request = "home_request"
|
||||||
|
floating_request = "floating_request"
|
||||||
|
stream_start = "stream_start"
|
||||||
|
stream_text = "stream_text"
|
||||||
|
stream_block = "stream_block"
|
||||||
|
stream_end = "stream_end"
|
||||||
|
floating_domain = "floating_domain"
|
||||||
|
data_request = "data_request"
|
||||||
|
data_response = "data_response"
|
||||||
|
mutation = "mutation"
|
||||||
|
|
||||||
|
|
||||||
class WsToolCall(BaseModel):
|
class WsToolCall(BaseModel):
|
||||||
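`WsFrameType` mixes `str` into `Enum`, which is what lets frame types travel over the wire as plain strings. The sketch below illustrates that behaviour with a reduced, hypothetical `FrameType` enum (two members only) and a hypothetical `parse_frame_type` helper; it is not the project's parsing code.

```python
import json
from enum import Enum
from typing import Optional


class FrameType(str, Enum):
    # Reduced stand-in for WsFrameType, for illustration only.
    home_request = "home_request"
    stream_text = "stream_text"


def parse_frame_type(raw: str) -> Optional[FrameType]:
    """Map a wire string to an enum member, or None if it is unknown."""
    try:
        return FrameType(raw)
    except ValueError:
        return None


# A str-mixin member serializes as its plain string value.
wire = json.dumps({"type": FrameType.stream_text, "request_id": "r1", "chunk": "hi"})
decoded = json.loads(wire)
kind = parse_frame_type(decoded["type"])
```

Returning `None` for unknown values is one possible policy; a server might instead reject the frame outright.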
@@ -249,6 +261,71 @@ class WsAgentComplete(BaseModel):
|
|||||||
errors: list[str] = Field(default_factory=list)
|
errors: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
||||||
|
|
||||||
|
class WsFloatingScope(BaseModel):
|
||||||
|
"""Scope for a floating request — narrows the agent to a specific entity."""
|
||||||
|
|
||||||
|
type: Literal["task", "project", "note", "checkpoint"]
|
||||||
|
id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WsHomeRequest(BaseModel):
|
||||||
|
"""Client → Server: Home chat message."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.home_request] = WsFrameType.home_request
|
||||||
|
message: str
|
||||||
|
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class WsFloatingRequest(BaseModel):
|
||||||
|
"""Client → Server: Floating chat message scoped to an entity."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
|
||||||
|
message: str
|
||||||
|
scope: WsFloatingScope
|
||||||
|
|
||||||
|
|
||||||
|
class WsStreamStart(BaseModel):
|
||||||
|
"""Server → Client: signals start of a streaming response."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.stream_start] = WsFrameType.stream_start
|
||||||
|
request_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class WsStreamText(BaseModel):
|
||||||
|
"""Server → Client: streamed text token."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.stream_text] = WsFrameType.stream_text
|
||||||
|
request_id: str
|
||||||
|
chunk: str
|
||||||
|
|
||||||
|
|
||||||
|
class WsStreamBlock(BaseModel):
|
||||||
|
"""Server → Client: structured block (chart, entity_ref, table, timeline)."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.stream_block] = WsFrameType.stream_block
|
||||||
|
request_id: str
|
||||||
|
block_type: Literal["chart", "entity_ref", "table", "timeline"]
|
||||||
|
data: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class WsStreamEnd(BaseModel):
|
||||||
|
"""Server → Client: signals end of a streaming response."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
||||||
|
request_id: str
|
||||||
|
mutations: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class WsFloatingDomain(BaseModel):
|
||||||
|
"""Server → Client: domain determined for a floating request."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
||||||
|
request_id: str
|
||||||
|
domain: Literal["tasks", "checkpoints", "notes", "projects"]
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Catalog ─────────────────────────────────────────────────────
|
# ── Agent Catalog ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
class AgentCatalogItem(BaseModel):
|
class AgentCatalogItem(BaseModel):
|
||||||
|
|||||||
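The v3 stream frames share a `request_id`, which suggests a client can reassemble one response from `stream_start`, any number of `stream_text` and `stream_block` frames, and a closing `stream_end`. The sketch below shows one way a client might do that with plain dicts. The ordering guarantee, the `StreamState` class, and the `apply_frame` function are assumptions for illustration; the diff only defines the frame shapes.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class StreamState:
    """Accumulated content for one request_id (hypothetical client-side type)."""

    text: list[str] = field(default_factory=list)
    blocks: list[dict[str, Any]] = field(default_factory=list)
    mutations: list[dict[str, Any]] = field(default_factory=list)
    done: bool = False


def apply_frame(streams: dict[str, StreamState], frame: dict[str, Any]) -> None:
    """Fold one decoded frame into the per-request state."""
    kind = frame["type"]
    request_id = frame["request_id"]
    if kind == "stream_start":
        streams[request_id] = StreamState()
        return
    state = streams.get(request_id)
    if state is None or state.done:
        raise ValueError(f"frame {kind!r} outside an open stream {request_id!r}")
    if kind == "stream_text":
        state.text.append(frame["chunk"])
    elif kind == "stream_block":
        state.blocks.append({"block_type": frame["block_type"], "data": frame["data"]})
    elif kind == "stream_end":
        state.mutations.extend(frame.get("mutations", []))
        state.done = True
    else:
        raise ValueError(f"not a stream frame: {kind!r}")


streams: dict[str, StreamState] = {}
for frame in [
    {"type": "stream_start", "request_id": "r1"},
    {"type": "stream_text", "request_id": "r1", "chunk": "Hello"},
    {"type": "stream_text", "request_id": "r1", "chunk": " world"},
    {"type": "stream_block", "request_id": "r1", "block_type": "table", "data": {"rows": []}},
    {"type": "stream_end", "request_id": "r1", "mutations": []},
]:
    apply_frame(streams, frame)

message = "".join(streams["r1"].text)
```

Keying the state by `request_id` lets several responses interleave on one socket without mixing their chunks.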
416
tests/test_agent_streaming.py
Normal file
@@ -0,0 +1,416 @@
|
|||||||
|
"""Tests for ChatAgent streaming and tool result capture (Step 2)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||||
|
|
||||||
|
from app.core.agent_registry import ChatAgent, registry
|
||||||
|
|
||||||
|
|
||||||
|
# ── Minimal concrete agent for testing ───────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _EchoAgent(ChatAgent):
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "_echo"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "Echo agent for tests"
|
||||||
|
|
||||||
|
def get_tools(self) -> list[Any]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||||
|
return query
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ai_message(content: str = "", tool_calls: list | None = None) -> AIMessage:
|
||||||
|
msg = AIMessage(content=content)
|
||||||
|
if tool_calls:
|
||||||
|
msg.tool_calls = tool_calls
|
||||||
|
else:
|
||||||
|
msg.tool_calls = []
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tool(name: str, return_value: Any) -> MagicMock:
|
||||||
|
t = MagicMock()
|
||||||
|
t.name = name
|
||||||
|
t.ainvoke = AsyncMock(return_value=return_value)
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def _make_stream_chunks(tokens: list[str]) -> list[MagicMock]:
|
||||||
|
chunks = []
|
||||||
|
for tok in tokens:
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = tok
|
||||||
|
chunks.append(c)
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
async def _collect_stream(agent: ChatAgent, llm: Any, messages: list, tools: list) -> list[str]:
|
||||||
|
tokens: list[str] = []
|
||||||
|
async for tok in agent._tool_loop_stream(llm, messages, tools):
|
||||||
|
tokens.append(tok)
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
# ── tool_results initialised ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_results_init():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
assert agent.tool_results == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop: no tool calls ────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_no_tools():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
llm = AsyncMock()
|
||||||
|
llm.ainvoke = AsyncMock(return_value=_make_ai_message("Hello!"))
|
||||||
|
|
||||||
|
result = await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
|
||||||
|
assert result == "Hello!"
|
||||||
|
assert agent.tool_results == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop: with one tool call + result capture ──────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_captures_tool_results():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
|
||||||
|
# Install a fake client executor that returns structured data; the tool reaches it via execute_on_client
|
||||||
|
raw_result = {"rows": [{"id": "t-1", "title": "Fix bug", "status": "todo"}]}
|
||||||
|
|
||||||
|
async def fake_executor(payload: dict) -> dict:
|
||||||
|
return raw_result
|
||||||
|
|
||||||
|
# AIMessage with a tool call, then a final answer
|
||||||
|
tool_call_msg = _make_ai_message(
|
||||||
|
tool_calls=[{"name": "list_tasks", "args": {}, "id": "call-1", "type": "tool_call"}]
|
||||||
|
)
|
||||||
|
final_msg = _make_ai_message("Here are your tasks.")
|
||||||
|
|
||||||
|
llm = MagicMock()
|
||||||
|
llm_with_tools = MagicMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||||
|
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
||||||
|
llm.ainvoke = AsyncMock(return_value=final_msg)
|
||||||
|
|
||||||
|
mock_tool = _make_tool("list_tasks", "- Fix bug (todo)")
|
||||||
|
|
||||||
|
from app.core.ws_context import set_client_executor, clear_client_executor
|
||||||
|
set_client_executor(fake_executor)
|
||||||
|
try:
|
||||||
|
# Patch the tool to actually call execute_on_client
|
||||||
|
async def tool_side_effect(args: dict) -> str:
|
||||||
|
from app.core.ws_context import execute_on_client
|
||||||
|
res = await execute_on_client(action="select", table="tasks")
|
||||||
|
rows = res.get("rows", [])
|
||||||
|
return "\n".join(r["title"] for r in rows)
|
||||||
|
|
||||||
|
mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
|
||||||
|
|
||||||
|
result = await agent._tool_loop(
|
||||||
|
llm, [HumanMessage(content="list my tasks")], [mock_tool]
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
assert result == "Here are your tasks."
|
||||||
|
assert len(agent.tool_results) == 1
|
||||||
|
assert agent.tool_results[0] == raw_result
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop: tool_results reset on each call ──────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_resets_tool_results():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
agent.tool_results = [{"stale": True}] # pre-populated from a previous call
|
||||||
|
|
||||||
|
llm = AsyncMock()
|
||||||
|
llm.ainvoke = AsyncMock(return_value=_make_ai_message("Done."))
|
||||||
|
|
||||||
|
await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
|
||||||
|
assert agent.tool_results == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop: unknown tool name ────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_unknown_tool():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
|
||||||
|
# The model calls a tool that is not registered; the loop should handle it gracefully
|
||||||
|
tool_call_msg = _make_ai_message(
|
||||||
|
tool_calls=[{"name": "nonexistent", "args": {}, "id": "c1", "type": "tool_call"}]
|
||||||
|
)
|
||||||
|
final_msg = _make_ai_message("Handled.")
|
||||||
|
|
||||||
|
mock_tool = _make_tool("known", "ok") # a different tool, not "nonexistent"
|
||||||
|
llm = MagicMock()
|
||||||
|
llm_with_tools = MagicMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||||
|
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
||||||
|
|
||||||
|
result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool])
|
||||||
|
assert result == "Handled."
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop: max_iter exhaustion ──────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_max_iter():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
|
||||||
|
always_tool = _make_ai_message(
|
||||||
|
tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
|
||||||
|
)
|
||||||
|
fallback = _make_ai_message("Fallback.")
|
||||||
|
|
||||||
|
llm = MagicMock()
|
||||||
|
llm_with_tools = MagicMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||||
|
# Returns always_tool on every iteration, so the loop never sees a final answer
|
||||||
|
llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
|
||||||
|
llm.ainvoke = AsyncMock(return_value=fallback)
|
||||||
|
|
||||||
|
mock_tool = _make_tool("t", "ok")
|
||||||
|
|
||||||
|
result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool], max_iter=2)
|
||||||
|
assert result == "Fallback."
|
||||||
|
assert llm_with_tools.ainvoke.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop_stream: no tool calls — yields tokens ─────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_stream_no_tools_yields_tokens():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
|
||||||
|
# No tools → llm used directly; ainvoke returns no tool calls → stream is used
|
||||||
|
no_tool_msg = _make_ai_message("irrelevant")
|
||||||
|
llm = AsyncMock()
|
||||||
|
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
|
||||||
|
|
||||||
|
async def fake_astream(msgs):
|
||||||
|
for tok in ["Hello", " ", "world"]:
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = tok
|
||||||
|
yield c
|
||||||
|
|
||||||
|
llm.astream = fake_astream
|
||||||
|
|
||||||
|
tokens = await _collect_stream(agent, llm, [HumanMessage(content="hi")], [])
|
||||||
|
assert tokens == ["Hello", " ", "world"]
|
||||||
|
assert agent.tool_results == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop_stream: one tool call then streaming final ─────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_stream_with_tool_call():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
|
||||||
|
raw_result = {"row": {"id": "t-2", "title": "Deploy", "status": "in_progress"}}
|
||||||
|
|
||||||
|
async def fake_executor(payload: dict) -> dict:
|
||||||
|
return raw_result
|
||||||
|
|
||||||
|
tool_call_msg = _make_ai_message(
|
||||||
|
tool_calls=[{"name": "get_task", "args": {"id": "t-2"}, "id": "c1", "type": "tool_call"}]
|
||||||
|
)
|
||||||
|
# After tools run, ainvoke returns no more tool calls
|
||||||
|
no_more_tools_msg = _make_ai_message("Task found.")
|
||||||
|
|
||||||
|
llm = MagicMock()
|
||||||
|
llm_with_tools = MagicMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||||
|
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
|
||||||
|
|
||||||
|
async def fake_astream(msgs):
|
||||||
|
for tok in ["Task", " ", "found."]:
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = tok
|
||||||
|
yield c
|
||||||
|
|
||||||
|
llm.astream = fake_astream
|
||||||
|
|
||||||
|
async def tool_side_effect(args: dict) -> str:
|
||||||
|
from app.core.ws_context import execute_on_client
|
||||||
|
res = await execute_on_client(action="select", table="tasks", filters={"id": args.get("id")})
|
||||||
|
return res.get("row", {}).get("title", "")
|
||||||
|
|
||||||
|
mock_tool = _make_tool("get_task", "Deploy")
|
||||||
|
mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
|
||||||
|
|
||||||
|
from app.core.ws_context import set_client_executor, clear_client_executor
|
||||||
|
set_client_executor(fake_executor)
|
||||||
|
try:
|
||||||
|
tokens = await _collect_stream(
|
||||||
|
agent, llm, [HumanMessage(content="get task t-2")], [mock_tool]
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
assert tokens == ["Task", " ", "found."]
|
||||||
|
assert len(agent.tool_results) == 1
|
||||||
|
assert agent.tool_results[0] == raw_result
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop_stream: tool_results reset on each call ───────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_stream_resets_tool_results():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
agent.tool_results = [{"old": True}]
|
||||||
|
|
||||||
|
no_tool_msg = _make_ai_message("")
|
||||||
|
llm = AsyncMock()
|
||||||
|
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
|
||||||
|
|
||||||
|
async def fake_astream(msgs):
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = "ok"
|
||||||
|
yield c
|
||||||
|
|
||||||
|
llm.astream = fake_astream
|
||||||
|
|
||||||
|
await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
|
||||||
|
assert agent.tool_results == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop_stream: empty chunk content is skipped ────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_stream_skips_empty_chunks():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
no_tool_msg = _make_ai_message("")
|
||||||
|
|
||||||
|
llm = AsyncMock()
|
||||||
|
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
|
||||||
|
|
||||||
|
async def fake_astream(msgs):
|
||||||
|
for tok in ["", "hello", "", " world", ""]:
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = tok
|
||||||
|
yield c
|
||||||
|
|
||||||
|
llm.astream = fake_astream
|
||||||
|
|
||||||
|
tokens = await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
|
||||||
|
assert tokens == ["hello", " world"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop_stream: max_iter exhaustion falls back to stream ───────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_stream_max_iter():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
|
||||||
|
always_tool = _make_ai_message(
|
||||||
|
tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = MagicMock()
|
||||||
|
llm_with_tools = MagicMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||||
|
llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
|
||||||
|
|
||||||
|
async def fake_astream(msgs):
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = "fallback"
|
||||||
|
yield c
|
||||||
|
|
||||||
|
llm.astream = fake_astream
|
||||||
|
mock_tool = _make_tool("t", "ok")
|
||||||
|
|
||||||
|
tokens = await _collect_stream(
|
||||||
|
agent, llm, [HumanMessage(content="x")], [mock_tool],
|
||||||
|
)
|
||||||
|
assert tokens == ["fallback"]
|
||||||
|
assert llm_with_tools.ainvoke.call_count == 5 # exhausted default max_iter
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop_stream: multiple tool results captured ────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_stream_multiple_tool_results():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
|
||||||
|
call_results = [
|
||||||
|
{"rows": [{"id": "t-1"}]},
|
||||||
|
{"rows": [{"id": "t-2"}]},
|
||||||
|
]
|
||||||
|
call_iter = iter(call_results)
|
||||||
|
|
||||||
|
async def fake_executor(payload: dict) -> dict:
|
||||||
|
return next(call_iter)
|
||||||
|
|
||||||
|
# Two tool calls in one iteration
|
||||||
|
tool_call_msg = _make_ai_message(
|
||||||
|
tool_calls=[
|
||||||
|
{"name": "tool_a", "args": {}, "id": "c1", "type": "tool_call"},
|
||||||
|
{"name": "tool_b", "args": {}, "id": "c2", "type": "tool_call"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
no_more_tools_msg = _make_ai_message("Done.")
|
||||||
|
|
||||||
|
llm = MagicMock()
|
||||||
|
llm_with_tools = MagicMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||||
|
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
|
||||||
|
|
||||||
|
async def fake_astream(msgs):
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = "Done."
|
||||||
|
yield c
|
||||||
|
|
||||||
|
llm.astream = fake_astream
|
||||||
|
|
||||||
|
async def tool_side_effect(args: dict) -> str:
|
||||||
|
from app.core.ws_context import execute_on_client
|
||||||
|
res = await execute_on_client(action="select", table="tasks")
|
||||||
|
return str(res)
|
||||||
|
|
||||||
|
tool_a = _make_tool("tool_a", "")
|
||||||
|
tool_a.ainvoke = AsyncMock(side_effect=tool_side_effect)
|
||||||
|
tool_b = _make_tool("tool_b", "")
|
||||||
|
tool_b.ainvoke = AsyncMock(side_effect=tool_side_effect)
|
||||||
|
|
||||||
|
from app.core.ws_context import set_client_executor, clear_client_executor
|
||||||
|
set_client_executor(fake_executor)
|
||||||
|
try:
|
||||||
|
tokens = await _collect_stream(
|
||||||
|
agent, llm, [HumanMessage(content="x")], [tool_a, tool_b]
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
assert tokens == ["Done."]
|
||||||
|
assert len(agent.tool_results) == 2
|
||||||
|
assert agent.tool_results[0] == {"rows": [{"id": "t-1"}]}
|
||||||
|
assert agent.tool_results[1] == {"rows": [{"id": "t-2"}]}
|
||||||
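Taken together, the streaming tests above pin down a contract: call the model until it stops requesting tools or `max_iter` is reached, run each requested tool, then stream the final answer while skipping empty chunks. The sketch below is one loop that satisfies that contract, using stand-in classes instead of LangChain. `Reply`, `ScriptedLLM`, and `tool_loop_stream` are hypothetical names, tools are plain async callables keyed by name, and the `bind_tools` step and `tool_results` capture seen in the tests are omitted. It is not the actual `ChatAgent._tool_loop_stream` implementation.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Awaitable, Callable


@dataclass
class Reply:
    content: str = ""
    tool_calls: list[dict[str, Any]] = field(default_factory=list)


class ScriptedLLM:
    """Stand-in model: replays scripted replies, then streams fixed tokens."""

    def __init__(self, replies: list[Reply], tokens: list[str]) -> None:
        self._replies = list(replies)
        self._tokens = tokens
        self.invoke_count = 0

    async def ainvoke(self, messages: list[Any]) -> Reply:
        self.invoke_count += 1
        # Once only one scripted reply is left, keep returning it.
        return self._replies.pop(0) if len(self._replies) > 1 else self._replies[0]

    async def astream(self, messages: list[Any]) -> AsyncIterator[Reply]:
        for tok in self._tokens:
            yield Reply(content=tok)


ToolFn = Callable[[dict[str, Any]], Awaitable[str]]


async def tool_loop_stream(
    llm: ScriptedLLM, messages: list[Any], tools: dict[str, ToolFn], max_iter: int = 5
) -> AsyncIterator[str]:
    for _ in range(max_iter):
        reply = await llm.ainvoke(messages)
        if not reply.tool_calls:
            break  # the model is ready to answer
        messages.append(reply)
        for call in reply.tool_calls:
            tool = tools.get(call["name"])
            # An unknown tool name is reported back to the model, not raised.
            result = await tool(call["args"]) if tool else f"Unknown tool: {call['name']}"
            messages.append({"tool_call_id": call["id"], "content": result})
    # Stream the final answer, dropping empty chunks.
    async for chunk in llm.astream(messages):
        if chunk.content:
            yield chunk.content


async def demo() -> list[str]:
    async def get_task(args: dict[str, Any]) -> str:
        return "Deploy"

    llm = ScriptedLLM(
        [Reply(tool_calls=[{"name": "get_task", "args": {"id": "t-2"}, "id": "c1"}]),
         Reply(content="Task found.")],
        ["Task", "", " found."],
    )
    return [tok async for tok in tool_loop_stream(llm, [], {"get_task": get_task})]


tokens = asyncio.run(demo())
```

Falling through to the stream after `max_iter` rounds mirrors `test_tool_loop_stream_max_iter`, where the caller still receives a response when the model keeps requesting tools.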
@@ -14,6 +14,56 @@ from app.agents.note_agent import NoteAgent
|
|||||||
from app.agents.project_agent import ProjectAgent
|
from app.agents.project_agent import ProjectAgent
|
||||||
from app.agents.task_agent import TaskAgent
|
from app.agents.task_agent import TaskAgent
|
||||||
from app.core.agent_registry import registry
|
from app.core.agent_registry import registry
|
||||||
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
|
|
||||||
|
|
||||||
|
# ── WS executor mock ──────────────────────────────────────────────────
|
||||||
|
#
|
||||||
|
# Tools call execute_on_client() which reads a ContextVar set by the WS
|
||||||
|
# handler. In unit tests there is no WS session, so we install a fake
|
||||||
|
# executor that returns plausible data for each action type.
|
||||||
|
|
||||||
|
_FAKE_ROW: dict[str, Any] = {
|
||||||
|
"id": "fake-id",
|
||||||
|
"title": "Fake Title",
|
||||||
|
"name": "Fake Name",
|
||||||
|
"status": "todo",
|
||||||
|
"priority": "medium",
|
||||||
|
"content": "Fake content",
|
||||||
|
"date": 1700000000000,
|
||||||
|
"taskId": "fake-task-id",
|
||||||
|
"author": "Alice",
|
||||||
|
"projectId": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _fake_executor(payload: dict) -> dict:
|
||||||
|
action = payload.get("action", "")
|
||||||
|
if action == "select":
|
||||||
|
return {"rows": []}
|
||||||
|
if action == "insert":
|
||||||
|
data = payload.get("data", {})
|
||||||
|
return {"row": {**_FAKE_ROW, **data}}
|
||||||
|
if action == "update":
|
||||||
|
data = payload.get("data", {})
|
||||||
|
row = {**_FAKE_ROW, "id": data.get("id", "fake-id"), **data.get("updates", {})}
|
||||||
|
return {"row": row}
|
||||||
|
if action == "delete":
|
||||||
|
return {"deleted": True}
|
||||||
|
if action == "get":
|
||||||
|
data = payload.get("data", {})
|
||||||
|
return {"row": {**_FAKE_ROW, "id": data.get("id", "fake-id")}}
|
||||||
|
if action == "vector_upsert":
|
||||||
|
return {"ok": True}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def ws_executor():
|
||||||
|
"""Install a fake WS executor for every test so tools can run without a real WS."""
|
||||||
|
set_client_executor(_fake_executor)
|
||||||
|
yield
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────
|
# ── Helpers ──────────────────────────────────────────────────────────
|
||||||
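The comment above the fake executor explains that tools call `execute_on_client()`, which reads a `ContextVar` set by the WS handler. A minimal sketch of that pattern follows, assuming the keyword arguments are packed into a single payload dict (which matches how `_fake_executor` reads `payload["action"]`). The module-level variable `_executor` and the `RuntimeError` raised when nothing is installed are assumptions; the real `app.core.ws_context` is not shown in this diff.

```python
import asyncio
from contextvars import ContextVar
from typing import Any, Awaitable, Callable, Optional

Executor = Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]

# Holds the executor for the current context (one per WS session in practice).
_executor: ContextVar[Optional[Executor]] = ContextVar("client_executor", default=None)


def set_client_executor(executor: Executor) -> None:
    _executor.set(executor)


def clear_client_executor() -> None:
    _executor.set(None)


async def execute_on_client(**payload: Any) -> dict[str, Any]:
    """Forward the keyword arguments as one payload dict to the installed executor."""
    executor = _executor.get()
    if executor is None:
        # Assumed behaviour: fail loudly when no session is attached.
        raise RuntimeError("no client executor installed for this context")
    return await executor(payload)


async def fake_executor(payload: dict[str, Any]) -> dict[str, Any]:
    if payload.get("action") == "select":
        return {"rows": []}
    return {}


async def demo() -> dict[str, Any]:
    set_client_executor(fake_executor)
    try:
        return await execute_on_client(action="select", table="tasks")
    finally:
        clear_client_executor()


result = asyncio.run(demo())
```

Because each asyncio task runs in its own copy of the context, an executor set while handling one connection is not visible to tasks serving another, which is the usual reason to prefer a `ContextVar` over a module-level global here.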
@@ -148,110 +198,142 @@ class TestTaskAgentTools:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_tasks_defaults(self) -> None:
|
async def test_list_tasks_defaults(self) -> None:
|
||||||
from app.agents.task_agent import list_tasks
|
from app.agents.task_agent import list_tasks
|
||||||
|
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"rows": []}
|
||||||
result = await list_tasks.ainvoke({})
|
result = await list_tasks.ainvoke({})
|
||||||
data = json.loads(result)
|
m.assert_called_once_with(
|
||||||
assert data["action"] == "list"
|
action="select", table="tasks",
|
||||||
assert data["table"] == "tasks"
|
filters={"projectId": None, "status": None, "search": None, "orderBy": None},
|
||||||
|
)
|
||||||
|
assert result == "No tasks found matching the given filters."
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_tasks_with_status_filter(self) -> None:
|
async def test_list_tasks_with_status_filter(self) -> None:
|
||||||
from app.agents.task_agent import list_tasks
|
from app.agents.task_agent import list_tasks
|
||||||
result = await list_tasks.ainvoke({"status": "done"})
|
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
data = json.loads(result)
|
m.return_value = {"rows": []}
|
||||||
assert data["filters"]["status"] == "done"
|
await list_tasks.ainvoke({"status": "done"})
|
||||||
|
call_kwargs = m.call_args.kwargs
|
||||||
|
assert call_kwargs["filters"]["status"] == "done"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_task_defaults(self) -> None:
|
async def test_create_task_defaults(self) -> None:
|
||||||
from app.agents.task_agent import create_task
|
from app.agents.task_agent import create_task
|
||||||
|
fake_row = {"id": "t1", "title": "Test task", "status": "todo", "priority": "medium"}
|
||||||
|
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"row": fake_row}
|
||||||
result = await create_task.ainvoke({"title": "Test task"})
|
result = await create_task.ainvoke({"title": "Test task"})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "create_record"
|
assert call_kwargs["action"] == "insert"
|
||||||
assert data["table"] == "tasks"
|
assert call_kwargs["table"] == "tasks"
|
||||||
assert data["data"]["title"] == "Test task"
|
assert call_kwargs["data"]["title"] == "Test task"
|
||||||
assert data["data"]["status"] == "todo"
|
assert call_kwargs["data"]["status"] == "todo"
|
||||||
assert data["data"]["priority"] == "medium"
|
assert call_kwargs["data"]["priority"] == "medium"
|
||||||
|
assert "Test task" in result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_task_with_all_fields(self) -> None:
|
async def test_create_task_with_all_fields(self) -> None:
|
||||||
from app.agents.task_agent import create_task
|
from app.agents.task_agent import create_task
|
||||||
result = await create_task.ainvoke({
|
fake_row = {"id": "t1", "title": "Deploy", "status": "in_progress", "priority": "high"}
|
||||||
"title": "Deploy",
|
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
"priority": "high",
|
m.return_value = {"row": fake_row}
|
||||||
"status": "in_progress",
|
await create_task.ainvoke({
|
||||||
"project_id": "p1",
|
"title": "Deploy", "priority": "high", "status": "in_progress",
|
||||||
"is_ai_suggested": 1,
|
"project_id": "p1", "is_ai_suggested": 1,
|
||||||
})
|
})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["data"]["priority"] == "high"
|
assert call_kwargs["data"]["priority"] == "high"
|
||||||
assert data["data"]["status"] == "in_progress"
|
assert call_kwargs["data"]["status"] == "in_progress"
|
||||||
assert data["data"]["projectId"] == "p1"
|
assert call_kwargs["data"]["projectId"] == "p1"
|
||||||
assert data["data"]["isAiSuggested"] == 1
|
assert call_kwargs["data"]["isAiSuggested"] == 1
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_task_with_status(self) -> None:
|
async def test_update_task_with_status(self) -> None:
|
||||||
from app.agents.task_agent import update_task
|
from app.agents.task_agent import update_task
|
||||||
|
fake_row = {"id": "t1", "title": "Buy groceries", "status": "done"}
|
||||||
|
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"row": fake_row}
|
||||||
result = await update_task.ainvoke({"task_id": "t1", "status": "done"})
|
result = await update_task.ainvoke({"task_id": "t1", "status": "done"})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "update_record"
|
assert call_kwargs["action"] == "update"
|
||||||
assert data["data"]["id"] == "t1"
|
assert call_kwargs["data"]["id"] == "t1"
|
||||||
assert data["data"]["updates"]["status"] == "done"
|
assert call_kwargs["data"]["updates"]["status"] == "done"
|
||||||
|
assert "t1" in result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_task_empty_updates(self) -> None:
|
async def test_update_task_empty_updates(self) -> None:
|
||||||
from app.agents.task_agent import update_task
|
from app.agents.task_agent import update_task
|
||||||
result = await update_task.ainvoke({"task_id": "t1"})
|
fake_row = {"id": "t1", "title": "Task", "status": "todo"}
|
||||||
data = json.loads(result)
|
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
assert data["data"]["updates"] == {}
|
m.return_value = {"row": fake_row}
|
||||||
|
await update_task.ainvoke({"task_id": "t1"})
|
||||||
|
call_kwargs = m.call_args.kwargs
|
||||||
|
assert call_kwargs["data"]["updates"] == {}
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_task(self) -> None:
|
async def test_delete_task(self) -> None:
|
||||||
from app.agents.task_agent import delete_task
|
from app.agents.task_agent import delete_task
|
||||||
|
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"deleted": True}
|
||||||
result = await delete_task.ainvoke({"task_id": "t1"})
|
result = await delete_task.ainvoke({"task_id": "t1"})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "delete_record"
|
assert call_kwargs["action"] == "delete"
|
||||||
assert data["table"] == "tasks"
|
assert call_kwargs["table"] == "tasks"
|
||||||
assert data["data"]["id"] == "t1"
|
assert call_kwargs["data"]["id"] == "t1"
|
||||||
|
assert "t1" in result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_tasks_due_today(self) -> None:
|
async def test_list_tasks_due_today(self) -> None:
|
||||||
from app.agents.task_agent import list_tasks_due_today
|
from app.agents.task_agent import list_tasks_due_today
|
||||||
|
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"rows": []}
|
||||||
result = await list_tasks_due_today.ainvoke({})
|
result = await list_tasks_due_today.ainvoke({})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "list_due_today"
|
assert call_kwargs["action"] == "select"
|
||||||
assert data["table"] == "tasks"
|
assert call_kwargs["table"] == "tasks"
|
||||||
|
assert "dueDateFrom" in call_kwargs["filters"]
|
||||||
|
assert result == "No tasks are due today."
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_task_comments(self) -> None:
|
async def test_list_task_comments(self) -> None:
|
||||||
from app.agents.task_agent import list_task_comments
|
from app.agents.task_agent import list_task_comments
|
||||||
|
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"rows": []}
|
||||||
result = await list_task_comments.ainvoke({"task_id": "t1"})
|
result = await list_task_comments.ainvoke({"task_id": "t1"})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "list"
|
assert call_kwargs["action"] == "select"
|
||||||
assert data["table"] == "taskComments"
|
assert call_kwargs["table"] == "taskComments"
|
||||||
assert data["filters"]["taskId"] == "t1"
|
assert call_kwargs["filters"]["taskId"] == "t1"
|
||||||
|
assert "t1" in result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_add_task_comment(self) -> None:
|
async def test_add_task_comment(self) -> None:
|
||||||
from app.agents.task_agent import add_task_comment
|
from app.agents.task_agent import add_task_comment
|
||||||
|
fake_row = {"id": "c1", "taskId": "t1", "author": "Alice", "content": "Looks good!"}
|
||||||
|
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"row": fake_row}
|
||||||
result = await add_task_comment.ainvoke({
|
result = await add_task_comment.ainvoke({
|
||||||
"task_id": "t1",
|
"task_id": "t1", "author": "Alice", "content": "Looks good!",
|
||||||
"author": "Alice",
|
|
||||||
"content": "Looks good!",
|
|
||||||
})
|
})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "create_record"
|
assert call_kwargs["action"] == "insert"
|
||||||
assert data["table"] == "taskComments"
|
assert call_kwargs["table"] == "taskComments"
|
||||||
assert data["data"]["taskId"] == "t1"
|
assert call_kwargs["data"]["taskId"] == "t1"
|
||||||
assert data["data"]["author"] == "Alice"
|
assert call_kwargs["data"]["author"] == "Alice"
|
||||||
assert data["data"]["content"] == "Looks good!"
|
assert call_kwargs["data"]["content"] == "Looks good!"
|
||||||
|
assert "Alice" in result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_task_comment(self) -> None:
|
async def test_delete_task_comment(self) -> None:
|
||||||
from app.agents.task_agent import delete_task_comment
|
from app.agents.task_agent import delete_task_comment
|
||||||
|
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"deleted": True}
|
||||||
result = await delete_task_comment.ainvoke({"comment_id": "c1"})
|
result = await delete_task_comment.ainvoke({"comment_id": "c1"})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "delete_record"
|
assert call_kwargs["action"] == "delete"
|
||||||
assert data["table"] == "taskComments"
|
assert call_kwargs["table"] == "taskComments"
|
||||||
assert data["data"]["id"] == "c1"
|
assert call_kwargs["data"]["id"] == "c1"
|
||||||
|
assert "c1" in result
|
||||||
|
|
||||||
|
|
||||||
# ── CheckpointAgent ───────────────────────────────────────────────────
|
# ── CheckpointAgent ───────────────────────────────────────────────────
|
||||||
@@ -301,74 +383,86 @@ class TestCheckpointAgentTools:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_checkpoints_no_project(self) -> None:
|
async def test_list_checkpoints_no_project(self) -> None:
|
||||||
from app.agents.checkpoint_agent import list_checkpoints
|
from app.agents.checkpoint_agent import list_checkpoints
|
||||||
|
with patch("app.agents.checkpoint_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"rows": []}
|
||||||
result = await list_checkpoints.ainvoke({})
|
result = await list_checkpoints.ainvoke({})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "list"
|
assert call_kwargs["action"] == "select"
|
||||||
assert data["table"] == "checkpoints"
|
assert call_kwargs["table"] == "checkpoints"
|
||||||
assert data["filters"]["projectId"] is None
|
assert call_kwargs["filters"]["projectId"] is None
|
||||||
|
assert result == "No checkpoints found."
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_checkpoints_with_project(self) -> None:
|
async def test_list_checkpoints_with_project(self) -> None:
|
||||||
from app.agents.checkpoint_agent import list_checkpoints
|
from app.agents.checkpoint_agent import list_checkpoints
|
||||||
result = await list_checkpoints.ainvoke({"project_id": "p1"})
|
with patch("app.agents.checkpoint_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
data = json.loads(result)
|
m.return_value = {"rows": []}
|
||||||
assert data["filters"]["projectId"] == "p1"
|
await list_checkpoints.ainvoke({"project_id": "p1"})
|
||||||
|
assert m.call_args.kwargs["filters"]["projectId"] == "p1"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_checkpoint(self) -> None:
|
async def test_create_checkpoint(self) -> None:
|
||||||
from app.agents.checkpoint_agent import create_checkpoint
|
from app.agents.checkpoint_agent import create_checkpoint
|
||||||
|
fake_row = {"id": "cp1", "title": "Beta release", "date": 1700000000000}
|
||||||
|
with patch("app.agents.checkpoint_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"row": fake_row}
|
||||||
result = await create_checkpoint.ainvoke({
|
result = await create_checkpoint.ainvoke({
|
||||||
"project_id": "p1",
|
"project_id": "p1", "title": "Beta release", "date": 1700000000000,
|
||||||
"title": "Beta release",
|
|
||||||
"date": 1700000000000,
|
|
||||||
})
|
})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "create_record"
|
assert call_kwargs["action"] == "insert"
|
||||||
assert data["table"] == "checkpoints"
|
assert call_kwargs["table"] == "checkpoints"
|
||||||
assert data["data"]["projectId"] == "p1"
|
assert call_kwargs["data"]["projectId"] == "p1"
|
||||||
assert data["data"]["title"] == "Beta release"
|
assert call_kwargs["data"]["title"] == "Beta release"
|
||||||
assert data["data"]["date"] == 1700000000000
|
assert call_kwargs["data"]["date"] == 1700000000000
|
||||||
|
assert "Beta release" in result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_checkpoint_ai_suggested(self) -> None:
|
async def test_create_checkpoint_ai_suggested(self) -> None:
|
||||||
from app.agents.checkpoint_agent import create_checkpoint
|
from app.agents.checkpoint_agent import create_checkpoint
|
||||||
result = await create_checkpoint.ainvoke({
|
fake_row = {"id": "cp1", "title": "Review", "date": 1700000000000}
|
||||||
"project_id": "p1",
|
with patch("app.agents.checkpoint_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
"title": "Review",
|
m.return_value = {"row": fake_row}
|
||||||
"date": 1700000000000,
|
await create_checkpoint.ainvoke({
|
||||||
"is_ai_suggested": 1,
|
"project_id": "p1", "title": "Review", "date": 1700000000000, "is_ai_suggested": 1,
|
||||||
})
|
})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["data"]["isAiSuggested"] == 1
|
assert call_kwargs["data"]["isAiSuggested"] == 1
|
||||||
assert data["data"]["isApproved"] == 0
|
assert call_kwargs["data"]["isApproved"] == 0
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_checkpoint_approve(self) -> None:
|
async def test_update_checkpoint_approve(self) -> None:
|
||||||
from app.agents.checkpoint_agent import update_checkpoint
|
from app.agents.checkpoint_agent import update_checkpoint
|
||||||
result = await update_checkpoint.ainvoke({
|
fake_row = {"id": "c1", "title": "MVP", "isApproved": 1}
|
||||||
"checkpoint_id": "c1",
|
with patch("app.agents.checkpoint_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
"is_approved": 1,
|
m.return_value = {"row": fake_row}
|
||||||
})
|
result = await update_checkpoint.ainvoke({"checkpoint_id": "c1", "is_approved": 1})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "update_record"
|
assert call_kwargs["action"] == "update"
|
||||||
assert data["data"]["id"] == "c1"
|
assert call_kwargs["data"]["id"] == "c1"
|
||||||
assert data["data"]["updates"]["isApproved"] == 1
|
assert call_kwargs["data"]["updates"]["isApproved"] == 1
|
||||||
|
assert "c1" in result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_checkpoint_empty_updates(self) -> None:
|
async def test_update_checkpoint_empty_updates(self) -> None:
|
||||||
from app.agents.checkpoint_agent import update_checkpoint
|
from app.agents.checkpoint_agent import update_checkpoint
|
||||||
result = await update_checkpoint.ainvoke({"checkpoint_id": "c1"})
|
fake_row = {"id": "c1", "title": "MVP"}
|
||||||
data = json.loads(result)
|
with patch("app.agents.checkpoint_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
assert data["data"]["updates"] == {}
|
m.return_value = {"row": fake_row}
|
||||||
|
await update_checkpoint.ainvoke({"checkpoint_id": "c1"})
|
||||||
|
assert m.call_args.kwargs["data"]["updates"] == {}
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_checkpoint(self) -> None:
|
async def test_delete_checkpoint(self) -> None:
|
||||||
from app.agents.checkpoint_agent import delete_checkpoint
|
from app.agents.checkpoint_agent import delete_checkpoint
|
||||||
|
with patch("app.agents.checkpoint_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"deleted": True}
|
||||||
result = await delete_checkpoint.ainvoke({"checkpoint_id": "c1"})
|
result = await delete_checkpoint.ainvoke({"checkpoint_id": "c1"})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "delete_record"
|
assert call_kwargs["action"] == "delete"
|
||||||
assert data["table"] == "checkpoints"
|
assert call_kwargs["table"] == "checkpoints"
|
||||||
assert data["data"]["id"] == "c1"
|
assert call_kwargs["data"]["id"] == "c1"
|
||||||
|
assert "c1" in result
|
||||||
|
|
||||||
|
|
||||||
# ── ProjectAgent ──────────────────────────────────────────────────────
|
# ── ProjectAgent ──────────────────────────────────────────────────────
|
||||||
@@ -425,75 +519,101 @@ class TestProjectAgentTools:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_projects_defaults(self) -> None:
|
async def test_list_projects_defaults(self) -> None:
|
||||||
from app.agents.project_agent import list_projects
|
from app.agents.project_agent import list_projects
|
||||||
|
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"rows": []}
|
||||||
result = await list_projects.ainvoke({})
|
result = await list_projects.ainvoke({})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "list"
|
assert call_kwargs["action"] == "select"
|
||||||
assert data["table"] == "projects"
|
assert call_kwargs["table"] == "projects"
|
||||||
assert data["filters"]["includeArchived"] is False
|
assert call_kwargs["filters"]["includeArchived"] is False
|
||||||
|
assert result == "No projects found."
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_projects_include_archived(self) -> None:
|
async def test_list_projects_include_archived(self) -> None:
|
||||||
from app.agents.project_agent import list_projects
|
from app.agents.project_agent import list_projects
|
||||||
result = await list_projects.ainvoke({"include_archived": 1})
|
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
data = json.loads(result)
|
m.return_value = {"rows": []}
|
||||||
assert data["filters"]["includeArchived"] is True
|
await list_projects.ainvoke({"include_archived": 1})
|
||||||
|
assert m.call_args.kwargs["filters"]["includeArchived"] is True
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_all_projects(self) -> None:
|
async def test_list_all_projects(self) -> None:
|
||||||
from app.agents.project_agent import list_all_projects
|
from app.agents.project_agent import list_all_projects
|
||||||
|
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"rows": []}
|
||||||
result = await list_all_projects.ainvoke({})
|
result = await list_all_projects.ainvoke({})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "list_all"
|
assert call_kwargs["action"] == "select"
|
||||||
assert data["table"] == "projects"
|
assert call_kwargs["table"] == "projects"
|
||||||
|
assert result == "No projects found."
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_project(self) -> None:
|
async def test_get_project(self) -> None:
|
||||||
from app.agents.project_agent import get_project
|
from app.agents.project_agent import get_project
|
||||||
|
fake_row = {"id": "p1", "name": "Alpha", "status": "active", "clientId": None}
|
||||||
|
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"row": fake_row}
|
||||||
result = await get_project.ainvoke({"project_id": "p1"})
|
result = await get_project.ainvoke({"project_id": "p1"})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "get"
|
assert call_kwargs["action"] == "get"
|
||||||
assert data["table"] == "projects"
|
assert call_kwargs["table"] == "projects"
|
||||||
assert data["data"]["id"] == "p1"
|
assert call_kwargs["data"]["id"] == "p1"
|
||||||
|
assert "Alpha" in result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_project_name_only(self) -> None:
|
async def test_create_project_name_only(self) -> None:
|
||||||
from app.agents.project_agent import create_project
|
from app.agents.project_agent import create_project
|
||||||
|
fake_row = {"id": "p1", "name": "Alpha"}
|
||||||
|
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"row": fake_row}
|
||||||
result = await create_project.ainvoke({"name": "Alpha"})
|
result = await create_project.ainvoke({"name": "Alpha"})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "create_record"
|
assert call_kwargs["action"] == "insert"
|
||||||
assert data["data"]["name"] == "Alpha"
|
assert call_kwargs["data"]["name"] == "Alpha"
|
||||||
assert data["data"]["clientId"] is None
|
assert call_kwargs["data"]["clientId"] is None
|
||||||
|
assert "Alpha" in result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_project_with_client(self) -> None:
|
async def test_create_project_with_client(self) -> None:
|
||||||
from app.agents.project_agent import create_project
|
from app.agents.project_agent import create_project
|
||||||
result = await create_project.ainvoke({"name": "Beta", "client_id": "cl1"})
|
fake_row = {"id": "p1", "name": "Beta"}
|
||||||
data = json.loads(result)
|
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
assert data["data"]["clientId"] == "cl1"
|
m.return_value = {"row": fake_row}
|
||||||
|
await create_project.ainvoke({"name": "Beta", "client_id": "cl1"})
|
||||||
|
assert m.call_args.kwargs["data"]["clientId"] == "cl1"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_project_archive(self) -> None:
|
async def test_update_project_archive(self) -> None:
|
||||||
from app.agents.project_agent import update_project
|
from app.agents.project_agent import update_project
|
||||||
|
fake_row = {"id": "p1", "name": "Alpha", "status": "archived"}
|
||||||
|
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"row": fake_row}
|
||||||
result = await update_project.ainvoke({"project_id": "p1", "status": "archived"})
|
result = await update_project.ainvoke({"project_id": "p1", "status": "archived"})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "update_record"
|
assert call_kwargs["action"] == "update"
|
||||||
assert data["data"]["id"] == "p1"
|
assert call_kwargs["data"]["id"] == "p1"
|
||||||
assert data["data"]["updates"]["status"] == "archived"
|
assert call_kwargs["data"]["updates"]["status"] == "archived"
|
||||||
|
assert "p1" in result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_project_empty_updates(self) -> None:
|
async def test_update_project_empty_updates(self) -> None:
|
||||||
from app.agents.project_agent import update_project
|
from app.agents.project_agent import update_project
|
||||||
result = await update_project.ainvoke({"project_id": "p1"})
|
fake_row = {"id": "p1", "name": "Alpha", "status": "active"}
|
||||||
data = json.loads(result)
|
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
assert data["data"]["updates"] == {}
|
m.return_value = {"row": fake_row}
|
||||||
|
await update_project.ainvoke({"project_id": "p1"})
|
||||||
|
assert m.call_args.kwargs["data"]["updates"] == {}
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_project(self) -> None:
|
async def test_delete_project(self) -> None:
|
||||||
from app.agents.project_agent import delete_project
|
from app.agents.project_agent import delete_project
|
||||||
|
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"deleted": True}
|
||||||
result = await delete_project.ainvoke({"project_id": "p1"})
|
result = await delete_project.ainvoke({"project_id": "p1"})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "delete_record"
|
assert call_kwargs["action"] == "delete"
|
||||||
assert data["data"]["id"] == "p1"
|
assert call_kwargs["data"]["id"] == "p1"
|
||||||
|
assert "p1" in result
|
||||||
|
|
||||||
|
|
||||||
# ── NoteAgent ─────────────────────────────────────────────────────────
|
# ── NoteAgent ─────────────────────────────────────────────────────────
|
||||||
@@ -543,78 +663,99 @@ class TestNoteAgentTools:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_notes_no_project(self) -> None:
|
async def test_list_notes_no_project(self) -> None:
|
||||||
from app.agents.note_agent import list_notes
|
from app.agents.note_agent import list_notes
|
||||||
|
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"rows": []}
|
||||||
result = await list_notes.ainvoke({})
|
result = await list_notes.ainvoke({})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "list"
|
assert call_kwargs["action"] == "select"
|
||||||
assert data["table"] == "notes"
|
assert call_kwargs["table"] == "notes"
|
||||||
assert data["filters"]["projectId"] is None
|
assert call_kwargs["filters"]["projectId"] is None
|
||||||
|
assert result == "No notes found."
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_notes_with_project(self) -> None:
|
async def test_list_notes_with_project(self) -> None:
|
||||||
from app.agents.note_agent import list_notes
|
from app.agents.note_agent import list_notes
|
||||||
result = await list_notes.ainvoke({"project_id": "p1"})
|
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
data = json.loads(result)
|
m.return_value = {"rows": []}
|
||||||
assert data["filters"]["projectId"] == "p1"
|
await list_notes.ainvoke({"project_id": "p1"})
|
||||||
|
assert m.call_args.kwargs["filters"]["projectId"] == "p1"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_note(self) -> None:
|
async def test_get_note(self) -> None:
|
||||||
from app.agents.note_agent import get_note
|
from app.agents.note_agent import get_note
|
||||||
|
fake_row = {"id": "n1", "title": "Daily log", "content": "# Today\nAll good."}
|
||||||
|
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"row": fake_row}
|
||||||
result = await get_note.ainvoke({"note_id": "n1"})
|
result = await get_note.ainvoke({"note_id": "n1"})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "get"
|
assert call_kwargs["action"] == "get"
|
||||||
assert data["table"] == "notes"
|
assert call_kwargs["table"] == "notes"
|
||||||
assert data["data"]["id"] == "n1"
|
assert call_kwargs["data"]["id"] == "n1"
|
||||||
|
assert "Daily log" in result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_note_minimal(self) -> None:
|
async def test_create_note_minimal(self) -> None:
|
||||||
from app.agents.note_agent import create_note
|
from app.agents.note_agent import create_note
|
||||||
result = await create_note.ainvoke({
|
fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
|
||||||
"title": "Daily log",
|
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
|
||||||
"content": "# Today\nAll good.",
|
patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
|
||||||
})
|
m.return_value = {"row": fake_row}
|
||||||
data = json.loads(result)
|
me.return_value = [0.0] * 1536
|
||||||
assert data["action"] == "create_record"
|
result = await create_note.ainvoke({"title": "Daily log", "content": "# Today\nAll good."})
|
||||||
assert data["table"] == "notes"
|
# First call: insert; second call: vector_upsert
|
||||||
assert data["data"]["title"] == "Daily log"
|
first_call = m.call_args_list[0].kwargs
|
||||||
assert data["data"]["content"] == "# Today\nAll good."
|
assert first_call["action"] == "insert"
|
||||||
assert data["data"]["projectId"] is None
|
assert first_call["table"] == "notes"
|
||||||
|
assert first_call["data"]["title"] == "Daily log"
|
||||||
|
assert first_call["data"]["content"] == "# Today\nAll good."
|
||||||
|
assert first_call["data"]["projectId"] is None
|
||||||
|
assert "Daily log" in result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_note_with_project(self) -> None:
|
async def test_create_note_with_project(self) -> None:
|
||||||
from app.agents.note_agent import create_note
|
from app.agents.note_agent import create_note
|
||||||
result = await create_note.ainvoke({
|
fake_row = {"id": "n1", "title": "Sprint notes", "projectId": "p1"}
|
||||||
"title": "Sprint notes",
|
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
|
||||||
"content": "## Sprint 1",
|
patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
|
||||||
"project_id": "p1",
|
m.return_value = {"row": fake_row}
|
||||||
})
|
me.return_value = [0.0] * 1536
|
||||||
data = json.loads(result)
|
await create_note.ainvoke({"title": "Sprint notes", "content": "## Sprint 1", "project_id": "p1"})
|
||||||
assert data["data"]["projectId"] == "p1"
|
first_call = m.call_args_list[0].kwargs
|
||||||
|
assert first_call["data"]["projectId"] == "p1"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_note_content_only(self) -> None:
|
async def test_update_note_content_only(self) -> None:
|
||||||
from app.agents.note_agent import update_note
|
from app.agents.note_agent import update_note
|
||||||
result = await update_note.ainvoke({
|
fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
|
||||||
"note_id": "n1",
|
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
|
||||||
"content": "# Updated content",
|
patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
|
||||||
})
|
m.return_value = {"row": fake_row}
|
||||||
data = json.loads(result)
|
me.return_value = [0.0] * 1536
|
||||||
assert data["action"] == "update_record"
|
result = await update_note.ainvoke({"note_id": "n1", "content": "# Updated content"})
|
||||||
assert data["data"]["id"] == "n1"
|
first_call = m.call_args_list[0].kwargs
|
||||||
assert data["data"]["updates"]["content"] == "# Updated content"
|
assert first_call["action"] == "update"
|
||||||
assert "title" not in data["data"]["updates"]
|
assert first_call["data"]["id"] == "n1"
|
||||||
|
assert first_call["data"]["updates"]["content"] == "# Updated content"
|
||||||
|
assert "title" not in first_call["data"]["updates"]
|
||||||
|
assert "n1" in result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_note_empty_updates(self) -> None:
|
async def test_update_note_empty_updates(self) -> None:
|
||||||
from app.agents.note_agent import update_note
|
from app.agents.note_agent import update_note
|
||||||
result = await update_note.ainvoke({"note_id": "n1"})
|
fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
|
||||||
data = json.loads(result)
|
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
assert data["data"]["updates"] == {}
|
m.return_value = {"row": fake_row}
|
||||||
|
await update_note.ainvoke({"note_id": "n1"})
|
||||||
|
assert m.call_args.kwargs["data"]["updates"] == {}
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_note(self) -> None:
|
async def test_delete_note(self) -> None:
|
||||||
from app.agents.note_agent import delete_note
|
from app.agents.note_agent import delete_note
|
||||||
|
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = {"deleted": True}
|
||||||
result = await delete_note.ainvoke({"note_id": "n1"})
|
result = await delete_note.ainvoke({"note_id": "n1"})
|
||||||
data = json.loads(result)
|
call_kwargs = m.call_args.kwargs
|
||||||
assert data["action"] == "delete_record"
|
assert call_kwargs["action"] == "delete"
|
||||||
assert data["table"] == "notes"
|
assert call_kwargs["table"] == "notes"
|
||||||
assert data["data"]["id"] == "n1"
|
assert call_kwargs["data"]["id"] == "n1"
|
||||||
|
assert "n1" in result
|
||||||
|
|||||||
284
tests/test_memory_middleware.py
Normal file
@@ -0,0 +1,284 @@
"""Tests for Step 7 — MemoryMiddleware.

Coverage:
1. enrich_context returns core prefs + associative + episodic + proactive
2. store_episode creates an encrypted row decryptable with the user's key
3. update_core upserts correctly
4. User with no encryption_key returns empty context (no crash)
5. End-to-end: home_request WS frame results in an episodic row being stored
"""

from __future__ import annotations

import json
import uuid
from unittest.mock import patch

import pytest
import pytest_asyncio
from cryptography.fernet import Fernet
from sqlalchemy import select

from app.core.memory_middleware import MemoryMiddleware, _PROACTIVE_CONFIDENCE_THRESHOLD
from app.db import get_session
from app.main import app
from app.models import (
    MemoryAssociative,
    MemoryCore,
    MemoryEpisodic,
    MemoryProactive,
    User,
)
from tests.conftest import TEST_USER_IDS, make_jwt


USER_ID = TEST_USER_IDS["power"]
_FERNET_KEY = Fernet.generate_key().decode()


# ── DB override ───────────────────────────────────────────────────────────────

@pytest.fixture(autouse=True)
def _override_db(db_session):
    async def _gen():
        yield db_session

    app.dependency_overrides[get_session] = _gen
    yield
    app.dependency_overrides.pop(get_session, None)


# ── Fixtures ──────────────────────────────────────────────────────────────────

@pytest_asyncio.fixture
async def user_with_key(db_session):
    """Set encryption_key on the seeded power user."""
    result = await db_session.execute(select(User).where(User.id == USER_ID))
    user = result.scalar_one()
    user.encryption_key = _FERNET_KEY
    await db_session.commit()
    return user


def _fernet():
    return Fernet(_FERNET_KEY.encode())


def _enc(plaintext: str) -> str:
    return _fernet().encrypt(plaintext.encode()).decode()


def _dec(ciphertext: str) -> str:
    return _fernet().decrypt(ciphertext.encode()).decode()


# ── enrich_context ────────────────────────────────────────────────────────────

@pytest.mark.asyncio
async def test_enrich_context_returns_core_memory(db_session, user_with_key):
    # Seed a core memory row
    db_session.add(MemoryCore(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        key="timezone",
        value_encrypted=_enc("UTC"),
    ))
    await db_session.commit()

    middleware = MemoryMiddleware(db_session)
    ctx = await middleware.enrich_context(USER_ID, "What are my tasks?")

    assert "core_memory" in ctx
    assert ctx["core_memory"]["timezone"] == "UTC"


@pytest.mark.asyncio
async def test_enrich_context_returns_episodic_memory(db_session, user_with_key):
    session_id = str(uuid.uuid4())
    db_session.add(MemoryEpisodic(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        summary_encrypted=_enc("User asked about Q1 tasks"),
        session_id=session_id,
    ))
    await db_session.commit()

    middleware = MemoryMiddleware(db_session)
    ctx = await middleware.enrich_context(USER_ID, "any message")

    assert "episodic_memory" in ctx
    assert any("Q1 tasks" in s for s in ctx["episodic_memory"])


@pytest.mark.asyncio
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
    # Add one pattern above threshold and one below
    db_session.add(MemoryProactive(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        pattern_encrypted=_enc("User prefers short summaries"),
        confidence=0.9,
        source="inferred",
    ))
    db_session.add(MemoryProactive(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        pattern_encrypted=_enc("User likes dark mode"),
        confidence=0.1,
        source="inferred",
    ))
    await db_session.commit()

    middleware = MemoryMiddleware(db_session)
    ctx = await middleware.enrich_context(USER_ID, "any message")

    assert "proactive_hints" in ctx
    hints = ctx["proactive_hints"]
    assert any("short summaries" in h for h in hints)
|
||||||
|
assert not any("dark mode" in h for h in hints)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_returns_associative_memory(db_session, user_with_key):
|
||||||
|
db_session.add(MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
content_encrypted=_enc("Related memory about meetings"),
|
||||||
|
embedding=None,
|
||||||
|
entity_type="note",
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "meetings")
|
||||||
|
|
||||||
|
assert "associative_memory" in ctx
|
||||||
|
assert any("meetings" in m for m in ctx["associative_memory"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_empty_for_user_without_key(db_session):
|
||||||
|
"""User with no encryption_key → empty context, no crash."""
|
||||||
|
result = await db_session.execute(select(User).where(User.id == USER_ID))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.encryption_key = None
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "hello")
|
||||||
|
assert ctx == {}
|
||||||
|
|
||||||
|
|
||||||
|
# ── store_episode ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_store_episode_creates_encrypted_row(db_session, user_with_key):
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
await middleware.store_episode(USER_ID, session_id, "hello", "world")
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
|
||||||
|
)
|
||||||
|
row = result.scalar_one()
|
||||||
|
plaintext = _dec(row.summary_encrypted)
|
||||||
|
assert "hello" in plaintext
|
||||||
|
assert "world" in plaintext
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_store_episode_decryptable(db_session, user_with_key):
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
await middleware.store_episode(USER_ID, session_id, "msg", "resp")
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
|
||||||
|
)
|
||||||
|
row = result.scalar_one()
|
||||||
|
# Decrypt using the same key — must not raise
|
||||||
|
decrypted = _dec(row.summary_encrypted)
|
||||||
|
assert len(decrypted) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── update_core ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_core_insert(db_session, user_with_key):
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
await middleware.update_core(USER_ID, "lang", "en")
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == USER_ID, MemoryCore.key == "lang")
|
||||||
|
)
|
||||||
|
row = result.scalar_one()
|
||||||
|
assert _dec(row.value_encrypted) == "en"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_core_upsert(db_session, user_with_key):
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
await middleware.update_core(USER_ID, "lang", "en")
|
||||||
|
await middleware.update_core(USER_ID, "lang", "fr")
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == USER_ID, MemoryCore.key == "lang")
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
assert len(rows) == 1
|
||||||
|
assert _dec(rows[0].value_encrypted) == "fr"
|
||||||
|
|
||||||
|
|
||||||
|
# ── End-to-end WS: memory middleware is called during home_request ────────────
|
||||||
|
|
||||||
|
def test_home_request_calls_memory_middleware(client):
    """home_request triggers enrich_context before and store_episode after the LLM."""
    enrich_calls: list[tuple] = []
    store_calls: list[tuple] = []

    class _MockMiddleware:
        def __init__(self, db):
            pass

        async def enrich_context(self, user_id, message):
            enrich_calls.append((user_id, message))
            return {"core_memory": {"tz": "UTC"}}

        async def store_episode(self, user_id, session_id, message, response):
            store_calls.append((user_id, session_id, message, response))

    token = make_jwt("power", user_id=USER_ID)
    session_id = str(uuid.uuid4())

    async def _mock_stream(user_id, message, context, reg=None):
        # Verify memory context was injected
        assert context.get("core_memory") == {"tz": "UTC"}
        yield "task_agent", ""
        yield "task_agent", '{"type": "text", "content": "Done"}'

    with (
        patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
        patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_stream),
    ):
        with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
            ws.send_text(json.dumps({
                "type": "device_hello", "device_id": "dev-mem", "agent_ids": []
            }))
            ws.send_text(json.dumps({
                "type": "home_request",
                "request_id": "r-mem",
                "session_id": session_id,
                "message": "Show tasks",
            }))
            for _ in range(20):
                raw = ws.receive_text()
                frame = json.loads(raw)
                if frame.get("type") == "stream_end":
                    break

    assert len(enrich_calls) == 1
    assert enrich_calls[0] == (USER_ID, "Show tasks")
    assert len(store_calls) == 1
    stored_session_id, stored_message = store_calls[0][1], store_calls[0][2]
    assert stored_session_id == session_id
    assert stored_message == "Show tasks"
|
||||||
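The `update_core` tests above expect upsert behaviour: writing the same `(user_id, key)` twice leaves one row holding the latest value. The following is a minimal sketch of that behaviour using only the standard-library `sqlite3` module. The table layout, the `update_core` helper shown here, and the plaintext `value` column are assumptions made for illustration; the real `MemoryMiddleware.update_core` is not shown in this diff and stores Fernet-encrypted values through the ORM.

```python
import sqlite3


def update_core(conn: sqlite3.Connection, user_id: str, key: str, value: str) -> None:
    """Hypothetical helper: update the row if it exists, otherwise insert it."""
    cur = conn.execute(
        "UPDATE memory_core SET value = ? WHERE user_id = ? AND key = ?",
        (value, user_id, key),
    )
    if cur.rowcount == 0:  # no existing row was updated, so insert a new one
        conn.execute(
            "INSERT INTO memory_core (user_id, key, value) VALUES (?, ?, ?)",
            (user_id, key, value),
        )


conn = sqlite3.connect(":memory:")
# Assumed schema: one row per (user_id, key) pair.
conn.execute(
    "CREATE TABLE memory_core ("
    "user_id TEXT NOT NULL, key TEXT NOT NULL, value TEXT NOT NULL, "
    "PRIMARY KEY (user_id, key))"
)

update_core(conn, "u-1", "lang", "en")
update_core(conn, "u-1", "lang", "fr")

rows = conn.execute(
    "SELECT value FROM memory_core WHERE user_id = ? AND key = ?",
    ("u-1", "lang"),
).fetchall()
```

The composite primary key is what makes a duplicate insert impossible here; the test's `len(rows) == 1` assertion checks the same invariant at the ORM level.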
205
tests/test_memory_models.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
"""Tests for Step 6 — memory ORM models and User.encryption_key.
|
||||||
|
|
||||||
|
Uses the SQLite in-memory test DB (from conftest). The pgvector embedding
|
||||||
|
column is stored as JSON in tests (SQLite-compatible).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.models import MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive, User
|
||||||
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
|
|
||||||
|
USER_ID = TEST_USER_IDS["power"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _fernet_key() -> str:
|
||||||
|
return Fernet.generate_key().decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _encrypt(key: str, plaintext: str) -> str:
|
||||||
|
return Fernet(key.encode()).encrypt(plaintext.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _decrypt(key: str, ciphertext: str) -> str:
|
||||||
|
return Fernet(key.encode()).decrypt(ciphertext.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
# ── User.encryption_key ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_encryption_key_column_exists(db_session):
|
||||||
|
"""User model has encryption_key column and it can be set."""
|
||||||
|
result = await db_session.execute(select(User).where(User.id == USER_ID))
|
||||||
|
user = result.scalar_one()
|
||||||
|
# Column exists (may be None for seeded users)
|
||||||
|
assert hasattr(user, "encryption_key")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_encryption_key_can_be_set(db_session):
|
||||||
|
key = _fernet_key()
|
||||||
|
result = await db_session.execute(select(User).where(User.id == USER_ID))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.encryption_key = key
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result2 = await db_session.execute(select(User).where(User.id == USER_ID))
|
||||||
|
user2 = result2.scalar_one()
|
||||||
|
assert user2.encryption_key == key
|
||||||
|
|
||||||
|
|
||||||
|
# ── MemoryCore ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_core_create_and_read(db_session):
|
||||||
|
key = _fernet_key()
|
||||||
|
encrypted_val = _encrypt(key, "UTC")
|
||||||
|
|
||||||
|
row = MemoryCore(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
key="timezone",
|
||||||
|
value_encrypted=encrypted_val,
|
||||||
|
)
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == USER_ID)
|
||||||
|
)
|
||||||
|
fetched = result.scalar_one()
|
||||||
|
assert fetched.key == "timezone"
|
||||||
|
assert _decrypt(key, fetched.value_encrypted) == "UTC"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_core_cascade_delete(db_session):
|
||||||
|
"""Deleting a user cascades to memory_core."""
|
||||||
|
row = MemoryCore(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
key="lang",
|
||||||
|
value_encrypted="enc",
|
||||||
|
)
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
user = (await db_session.execute(select(User).where(User.id == USER_ID))).scalar_one()
|
||||||
|
await db_session.delete(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
remaining = (
|
||||||
|
await db_session.execute(select(MemoryCore).where(MemoryCore.user_id == USER_ID))
|
||||||
|
).scalars().all()
|
||||||
|
assert remaining == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── MemoryAssociative ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_associative_create_and_read(db_session):
|
||||||
|
key = _fernet_key()
|
||||||
|
content = _encrypt(key, "User prefers morning meetings")
|
||||||
|
embedding = [0.1] * 1536 # fake embedding
|
||||||
|
|
||||||
|
row = MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
content_encrypted=content,
|
||||||
|
embedding=embedding,
|
||||||
|
entity_type="preference",
|
||||||
|
entity_id=None,
|
||||||
|
)
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryAssociative).where(MemoryAssociative.user_id == USER_ID)
|
||||||
|
)
|
||||||
|
fetched = result.scalar_one()
|
||||||
|
assert fetched.entity_type == "preference"
|
||||||
|
assert _decrypt(key, fetched.content_encrypted) == "User prefers morning meetings"
|
||||||
|
assert len(fetched.embedding) == 1536
|
||||||
|
|
||||||
|
|
||||||
|
# ── MemoryEpisodic ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_episodic_create_and_read(db_session):
|
||||||
|
key = _fernet_key()
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
summary = _encrypt(key, "User asked about Q1 tasks")
|
||||||
|
|
||||||
|
row = MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
summary_encrypted=summary,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
|
||||||
|
)
|
||||||
|
fetched = result.scalar_one()
|
||||||
|
assert _decrypt(key, fetched.summary_encrypted) == "User asked about Q1 tasks"
|
||||||
|
assert isinstance(fetched.created_at, datetime)
|
||||||
|
|
||||||
|
|
||||||
|
# ── MemoryProactive ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_proactive_create_and_read(db_session):
|
||||||
|
key = _fernet_key()
|
||||||
|
pattern = _encrypt(key, "User always assigns tasks to self")
|
||||||
|
|
||||||
|
row = MemoryProactive(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
pattern_encrypted=pattern,
|
||||||
|
confidence=0.85,
|
||||||
|
source="inferred",
|
||||||
|
)
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryProactive).where(MemoryProactive.user_id == USER_ID)
|
||||||
|
)
|
||||||
|
fetched = result.scalar_one()
|
||||||
|
assert fetched.confidence == pytest.approx(0.85)
|
||||||
|
assert fetched.source == "inferred"
|
||||||
|
assert _decrypt(key, fetched.pattern_encrypted) == "User always assigns tasks to self"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Auth registration generates encryption_key ───────────────────────────────
|
||||||
|
|
||||||
|
def test_register_sets_encryption_key(client):
|
||||||
|
"""POST /api/v1/auth/register creates a user with a valid Fernet key."""
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": "newuser@test.com", "password": "testpassword123"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
|
||||||
|
# Fetch the newly created user via the access token
|
||||||
|
token = resp.json()["access_token"]
|
||||||
|
me_resp = client.get(
|
||||||
|
"/api/v1/auth/me",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert me_resp.status_code == 200
|
||||||
|
# We can't see encryption_key in the API response (not in UserProfile),
|
||||||
|
# but we verify registration didn't crash — key generation is implicit.
|
||||||
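`test_memory_core_cascade_delete` above relies on user deletion removing the dependent `memory_core` rows. A rough database-level sketch of that behaviour with the standard-library `sqlite3` module follows. The schema here is an assumption for illustration, and the real models may implement the cascade through ORM relationship settings instead of (or in addition to) `ON DELETE CASCADE`. Note that SQLite only enforces foreign keys when `PRAGMA foreign_keys = ON` is set on the connection.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# SQLite leaves foreign-key enforcement off by default; enable it per connection.
conn.execute("PRAGMA foreign_keys = ON")

# Assumed, simplified schema for illustration only.
conn.execute("CREATE TABLE users (id TEXT PRIMARY KEY)")
conn.execute(
    "CREATE TABLE memory_core ("
    "id TEXT PRIMARY KEY, "
    "user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, "
    "key TEXT NOT NULL, "
    "value_encrypted TEXT NOT NULL)"
)

conn.execute("INSERT INTO users (id) VALUES (?)", ("u-1",))
conn.execute(
    "INSERT INTO memory_core (id, user_id, key, value_encrypted) VALUES (?, ?, ?, ?)",
    ("m-1", "u-1", "lang", "enc"),
)

# Deleting the parent row should remove the child row as well.
conn.execute("DELETE FROM users WHERE id = ?", ("u-1",))

remaining = conn.execute(
    "SELECT id FROM memory_core WHERE user_id = ?", ("u-1",)
).fetchall()
```

If the pragma is omitted, the delete still succeeds but the `memory_core` row survives, which is a common reason a cascade test passes on PostgreSQL and fails on an SQLite test database.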
@@ -302,7 +302,7 @@ class TestOrchestrateStream:
|
|||||||
assert len(chunks) >= 1
|
assert len(chunks) >= 1
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_last_chunk_is_final_json_frame(
|
async def test_all_chunks_are_plain_text(
|
||||||
self, reg: AgentRegistry
|
self, reg: AgentRegistry
|
||||||
) -> None:
|
) -> None:
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
@@ -310,13 +310,12 @@ class TestOrchestrateStream:
|
|||||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
request = ChatRequest(message="add a task", execution_mode="direct")
|
||||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||||
|
|
||||||
last = json.loads(chunks[-1])
|
# orchestrate_stream yields plain text chunks only — no JSON final frame
|
||||||
assert last["done"] is True
|
for chunk in chunks:
|
||||||
assert "response" in last
|
assert isinstance(chunk, str)
|
||||||
assert "actions" in last
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_final_frame_response_matches_agent_output(
|
async def test_concatenated_chunks_equal_full_response(
|
||||||
self, reg: AgentRegistry
|
self, reg: AgentRegistry
|
||||||
) -> None:
|
) -> None:
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
@@ -324,8 +323,8 @@ class TestOrchestrateStream:
|
|||||||
request = ChatRequest(message="create a task", execution_mode="direct")
|
request = ChatRequest(message="create a task", execution_mode="direct")
|
||||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||||
|
|
||||||
final = json.loads(chunks[-1])
|
full_text = "".join(chunks)
|
||||||
assert final["response"] == "task: create a task"
|
assert full_text == "task: create a task"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_text_chunks_before_final_frame(
|
async def test_text_chunks_before_final_frame(
|
||||||
|
|||||||
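The renamed tests in the hunk above change the streaming contract: every chunk is a plain string, and joining the chunks reproduces the full response, with no trailing JSON frame. A small self-contained sketch of that contract is shown below. `stream_chunks` and its fixed chunk size are hypothetical stand-ins, not the project's `orchestrate_stream`.

```python
import asyncio
from typing import AsyncIterator


async def stream_chunks(text: str, size: int = 4) -> AsyncIterator[str]:
    """Hypothetical streamer: yield the response as plain-text slices only."""
    for start in range(0, len(text), size):
        yield text[start:start + size]


async def _collect_chunks(text: str) -> list[str]:
    # Gather every chunk the async generator produces.
    return [chunk async for chunk in stream_chunks(text)]


chunks = asyncio.run(_collect_chunks("task: create a task"))
full_text = "".join(chunks)
```

Because no chunk carries metadata, a consumer can rebuild the response with a plain `"".join(...)`, which is exactly what `test_concatenated_chunks_equal_full_response` asserts.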
236
tests/test_orchestrator_v3.py
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
"""Tests for v3 orchestrator functions (Step 3)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.agent_registry import ChatAgent, AgentRegistry
|
||||||
|
from app.core.orchestrator import orchestrate_v3, orchestrate_v3_stream
|
||||||
|
|
||||||
|
|
||||||
|
# ── Minimal agent for testing ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _FixedAgent(ChatAgent):
    def __init__(self, name: str = "_fixed", tokens: list[str] | None = None, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._name = name
        self._tokens = tokens or ["Hello", " world"]

    def get_name(self) -> str:
        return self._name

    def get_description(self) -> str:
        return "Fixed agent for tests"

    def get_tools(self) -> list[Any]:
        return []

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        return "".join(self._tokens)

    async def handle_stream(self, query: str, context: dict[str, Any]):
        for tok in self._tokens:
            yield tok
|
||||||
|
|
||||||
|
|
||||||
|
# ── Mock registry factory ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_registry(agent_name: str, agent: ChatAgent) -> MagicMock:
|
||||||
|
reg = MagicMock(spec=AgentRegistry)
|
||||||
|
reg.list_agents.return_value = [{"name": agent_name, "description": "test"}]
|
||||||
|
reg.get.return_value = agent
|
||||||
|
return reg
|
||||||
|
|
||||||
|
|
||||||
|
# ── orchestrate_v3 ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_orchestrate_v3_returns_agent_name_and_instance():
|
||||||
|
agent = _FixedAgent("task_agent")
|
||||||
|
reg = _make_registry("task_agent", agent)
|
||||||
|
|
||||||
|
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
|
||||||
|
name, inst = await orchestrate_v3(
|
||||||
|
user_id="u-1", message="fix a bug", context={}, reg=reg
|
||||||
|
)
|
||||||
|
|
||||||
|
assert name == "task_agent"
|
||||||
|
assert inst is agent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_orchestrate_v3_classify_called_with_message_and_context():
|
||||||
|
agent = _FixedAgent("note_agent")
|
||||||
|
reg = _make_registry("note_agent", agent)
|
||||||
|
ctx = {"some": "context"}
|
||||||
|
|
||||||
|
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="note_agent")) as mock_classify:
|
||||||
|
await orchestrate_v3(user_id="u-1", message="take a note", context=ctx, reg=reg)
|
||||||
|
|
||||||
|
mock_classify.assert_awaited_once()
|
||||||
|
call_args = mock_classify.call_args
|
||||||
|
assert call_args[0][0] == "take a note"
|
||||||
|
assert call_args[0][1] == ctx
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_orchestrate_v3_uses_default_registry_when_none():
|
||||||
|
agent = _FixedAgent("task_agent")
|
||||||
|
|
||||||
|
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")), \
|
||||||
|
patch("app.core.orchestrator._default_registry") as mock_reg:
|
||||||
|
mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
|
||||||
|
mock_reg.get.return_value = agent
|
||||||
|
name, inst = await orchestrate_v3(user_id="u-1", message="hi", context={})
|
||||||
|
|
||||||
|
assert name == "task_agent"
|
||||||
|
assert inst is agent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_orchestrate_v3_get_called_with_agent_name():
|
||||||
|
agent = _FixedAgent("checkpoint_agent")
|
||||||
|
reg = _make_registry("checkpoint_agent", agent)
|
||||||
|
|
||||||
|
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="checkpoint_agent")):
|
||||||
|
await orchestrate_v3(user_id="u-2", message="schedule", context={}, reg=reg)
|
||||||
|
|
||||||
|
reg.get.assert_called_once_with("checkpoint_agent")
|
||||||
|
|
||||||
|
|
||||||
|
# ── orchestrate_v3_stream ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _collect(gen) -> list[tuple[str, str]]:
|
||||||
|
results: list[tuple[str, str]] = []
|
||||||
|
async for item in gen:
|
||||||
|
results.append(item)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_orchestrate_v3_stream_first_yield_is_domain_signal():
|
||||||
|
agent = _FixedAgent("task_agent", tokens=["token1"])
|
||||||
|
reg = _make_registry("task_agent", agent)
|
||||||
|
|
||||||
|
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
|
||||||
|
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
|
||||||
|
results = await _collect(gen)
|
||||||
|
|
||||||
|
# First item must be (agent_name, "") — domain signal
|
||||||
|
assert results[0] == ("task_agent", "")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_orchestrate_v3_stream_yields_agent_name_with_tokens():
|
||||||
|
agent = _FixedAgent("task_agent", tokens=["Hello", " ", "world"])
|
||||||
|
reg = _make_registry("task_agent", agent)
|
||||||
|
|
||||||
|
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
|
||||||
|
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
|
||||||
|
results = await _collect(gen)
|
||||||
|
|
||||||
|
# All items are (agent_name, token) pairs
|
||||||
|
assert all(name == "task_agent" for name, _ in results)
|
||||||
|
tokens = [tok for _, tok in results]
|
||||||
|
assert tokens[0] == "" # domain signal
|
||||||
|
assert tokens[1:] == ["Hello", " ", "world"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_orchestrate_v3_stream_different_agent():
|
||||||
|
agent = _FixedAgent("note_agent", tokens=["note"])
|
||||||
|
reg = _make_registry("note_agent", agent)
|
||||||
|
|
||||||
|
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="note_agent")):
|
||||||
|
gen = orchestrate_v3_stream(user_id="u-2", message="take note", context={}, reg=reg)
|
||||||
|
results = await _collect(gen)
|
||||||
|
|
||||||
|
assert results[0] == ("note_agent", "")
|
||||||
|
assert ("note_agent", "note") in results
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_orchestrate_v3_stream_uses_default_registry_when_none():
|
||||||
|
agent = _FixedAgent("task_agent", tokens=["x"])
|
||||||
|
|
||||||
|
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")), \
|
||||||
|
patch("app.core.orchestrator._default_registry") as mock_reg:
|
||||||
|
mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
|
||||||
|
mock_reg.get.return_value = agent
|
||||||
|
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={})
|
||||||
|
results = await _collect(gen)
|
||||||
|
|
||||||
|
assert results[0][0] == "task_agent"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_empty_token_list():
    """Agent with no tokens still emits the domain signal."""

    class _EmptyAgent(_FixedAgent):
        async def handle_stream(self, query: str, context: dict[str, Any]):
            return
            yield  # makes it a generator

    agent = _EmptyAgent("task_agent", tokens=[])
    reg = _make_registry("task_agent", agent)

    with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
        gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
        results = await _collect(gen)

    assert results == [("task_agent", "")]  # only domain signal
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_orchestrate_v3_stream_full_text_correct():
|
||||||
|
"""Concatenating all non-domain tokens reconstructs the full response."""
|
||||||
|
tokens = ["The", " ", "task", " ", "is", " ", "done."]
|
||||||
|
agent = _FixedAgent("task_agent", tokens=tokens)
|
||||||
|
reg = _make_registry("task_agent", agent)
|
||||||
|
|
||||||
|
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
|
||||||
|
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
|
||||||
|
results = await _collect(gen)
|
||||||
|
|
||||||
|
text = "".join(tok for _, tok in results[1:]) # skip domain signal
|
||||||
|
assert text == "The task is done."
|
||||||
|
|
||||||
|
|
||||||
|
# ── handle_stream default implementation ─────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_handle_stream_default_yields_full_response():
    """Default handle_stream yields handle() result as a single chunk."""

    class _SimpleAgent(ChatAgent):
        def get_name(self) -> str:
            return "_simple"

        def get_description(self) -> str:
            return ""

        def get_tools(self) -> list[Any]:
            return []

        async def handle(self, query: str, context: dict[str, Any]) -> str:
            return "simple response"

    agent = _SimpleAgent()
    tokens = [tok async for tok in agent.handle_stream("q", {})]
    assert tokens == ["simple response"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_stream_override_used_by_stream():
|
||||||
|
"""_FixedAgent.handle_stream override yields individual tokens."""
|
||||||
|
agent = _FixedAgent("t", tokens=["a", "b", "c"])
|
||||||
|
tokens = [tok async for tok in agent.handle_stream("q", {})]
|
||||||
|
assert tokens == ["a", "b", "c"]
|
||||||
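The `orchestrate_v3_stream` tests above describe a simple protocol: the first item is a domain signal `(agent_name, "")`, and every later item is `(agent_name, token)`. The sketch below illustrates that protocol with plain async generators. `stream_with_domain_signal`, `token_source`, and `empty_source` are hypothetical names chosen for this example; intent classification and the agent registry are left out.

```python
import asyncio
from typing import AsyncIterator, Iterable


async def token_source(tokens: Iterable[str]) -> AsyncIterator[str]:
    """Stand-in for an agent's handle_stream: yield tokens one by one."""
    for tok in tokens:
        yield tok


async def empty_source() -> AsyncIterator[str]:
    """An async generator that yields nothing."""
    return
    yield  # unreachable, but its presence makes this an async generator


async def stream_with_domain_signal(
    agent_name: str, source: AsyncIterator[str]
) -> AsyncIterator[tuple[str, str]]:
    # Emit the domain signal first so the consumer learns the agent early.
    yield agent_name, ""
    async for tok in source:
        yield agent_name, tok


async def collect(gen: AsyncIterator[tuple[str, str]]) -> list[tuple[str, str]]:
    return [item async for item in gen]


results = asyncio.run(
    collect(stream_with_domain_signal("task_agent", token_source(["Hello", " ", "world"])))
)
empty = asyncio.run(collect(stream_with_domain_signal("task_agent", empty_source())))

# Skip the domain signal when rebuilding the response text.
text = "".join(tok for _, tok in results[1:])
```

Sending the signal before any token means a client can route or label the response even when the agent produces no output at all, which is the case the empty-token test covers.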
195
tests/test_output_formatter.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
"""Tests for app.core.output_formatter — HomeFormatter and FloatingFormatter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
||||||
|
from app.schemas import (
|
||||||
|
WsFloatingDomain,
|
||||||
|
WsStreamBlock,
|
||||||
|
WsStreamEnd,
|
||||||
|
WsStreamStart,
|
||||||
|
WsStreamText,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _stream(*pairs: tuple[str, str]):
|
||||||
|
"""Async generator that yields (agent_name, token) pairs."""
|
||||||
|
for pair in pairs:
|
||||||
|
yield pair
|
||||||
|
|
||||||
|
|
||||||
|
async def collect(formatter, token_stream):
|
||||||
|
frames = []
|
||||||
|
async for frame in formatter.format(token_stream):
|
||||||
|
frames.append(frame)
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
# ── HomeFormatter ─────────────────────────────────────────────────────────────

@pytest.mark.asyncio
async def test_home_formatter_text_block():
    req_id = "req-1"
    tokens = [
        ("task_agent", '{"type": "text", "content": "Hello world"}'),
    ]
    formatter = HomeFormatter(request_id=req_id, tool_results=[])
    frames = await collect(formatter, _stream(*tokens))

    assert isinstance(frames[0], WsStreamStart)
    assert frames[0].request_id == req_id
    text_frames = [f for f in frames if isinstance(f, WsStreamText)]
    assert any("Hello world" in f.chunk for f in text_frames)
    assert isinstance(frames[-1], WsStreamEnd)


@pytest.mark.asyncio
async def test_home_formatter_chart_block():
    req_id = "req-2"
    chart_json = (
        '{"type": "chart", "chartType": "bar", '
        '"title": "Tasks", "data": [{"x": 1}], '
        '"config": {"x": {"label": "X", "color": "#fff"}}}'
    )
    formatter = HomeFormatter(request_id=req_id, tool_results=[])
    frames = await collect(formatter, _stream(("task_agent", chart_json)))

    block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
    assert len(block_frames) == 1
    assert block_frames[0].block_type == "chart"
    assert block_frames[0].data["chartType"] == "bar"


@pytest.mark.asyncio
async def test_home_formatter_invalid_chart_skipped():
    req_id = "req-3"
    bad_chart = '{"type": "chart", "chartType": "unknown", "data": []}'
    formatter = HomeFormatter(request_id=req_id, tool_results=[])
    frames = await collect(formatter, _stream(("task_agent", bad_chart)))

    block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
    assert len(block_frames) == 0  # invalid chart skipped


@pytest.mark.asyncio
async def test_home_formatter_entity_ref_resolved():
    req_id = "req-4"
    tool_results = [{"entity": "task", "id": "t1", "title": "My Task"}]
    entity_json = '{"type": "entity_ref", "entity": "task"}'
    formatter = HomeFormatter(request_id=req_id, tool_results=tool_results)
    frames = await collect(formatter, _stream(("task_agent", entity_json)))

    block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
    assert len(block_frames) == 1
    assert block_frames[0].data["entity"] == "task"
    assert block_frames[0].data["items"][0]["id"] == "t1"


@pytest.mark.asyncio
async def test_home_formatter_entity_ref_missing_skipped():
    req_id = "req-5"
    entity_json = '{"type": "entity_ref", "entity": "task"}'
    formatter = HomeFormatter(request_id=req_id, tool_results=[])
    frames = await collect(formatter, _stream(("task_agent", entity_json)))

    block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
    assert len(block_frames) == 0  # no tool results → skipped


@pytest.mark.asyncio
async def test_home_formatter_table_block():
    req_id = "req-6"
    table_json = '{"type": "table", "headers": ["A", "B"], "rows": [["1", "2"]]}'
    formatter = HomeFormatter(request_id=req_id, tool_results=[])
    frames = await collect(formatter, _stream(("task_agent", table_json)))

    block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
    assert len(block_frames) == 1
    assert block_frames[0].block_type == "table"


@pytest.mark.asyncio
async def test_home_formatter_timeline_block():
    req_id = "req-7"
    timeline_json = '{"type": "timeline", "checkpoints": [{"id": "c1", "title": "M1", "date": 123}]}'
    formatter = HomeFormatter(request_id=req_id, tool_results=[])
    frames = await collect(formatter, _stream(("task_agent", timeline_json)))

    block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
    assert len(block_frames) == 1
    assert block_frames[0].block_type == "timeline"


@pytest.mark.asyncio
async def test_home_formatter_frame_order():
    """stream_start is first, stream_end is last."""
    req_id = "req-8"
    formatter = HomeFormatter(request_id=req_id, tool_results=[])
    frames = await collect(formatter, _stream(("task_agent", '{"type": "text", "content": "Hi"}')))
    assert isinstance(frames[0], WsStreamStart)
    assert isinstance(frames[-1], WsStreamEnd)

# ── FloatingFormatter ────────────────────────────────────────────────────────────

@pytest.mark.asyncio
async def test_floating_formatter_domain_emitted_first():
    req_id = "pop-1"
    formatter = FloatingFormatter(request_id=req_id)
    tokens = [
        ("task_agent", ""),  # domain signal
        ("task_agent", "Hello"),
        ("task_agent", " there"),
    ]
    frames = await collect(formatter, _stream(*tokens))

    assert isinstance(frames[0], WsFloatingDomain)
    assert frames[0].domain == "tasks"
    assert frames[0].request_id == req_id


@pytest.mark.asyncio
async def test_floating_formatter_text_only():
    req_id = "pop-2"
    formatter = FloatingFormatter(request_id=req_id)
    tokens = [("checkpoint_agent", ""), ("checkpoint_agent", "Summary")]
    frames = await collect(formatter, _stream(*tokens))

    assert isinstance(frames[0], WsFloatingDomain)
    assert frames[0].domain == "checkpoints"
    text_frames = [f for f in frames if isinstance(f, WsStreamText)]
    assert len(text_frames) == 1
    assert text_frames[0].chunk == "Summary"


@pytest.mark.asyncio
async def test_floating_formatter_no_block_frames():
    """FloatingFormatter must never emit WsStreamBlock."""
    req_id = "pop-3"
    formatter = FloatingFormatter(request_id=req_id)
    tokens = [
        ("note_agent", ""),
        ("note_agent", '{"type": "chart", "chartType": "bar", "data": []}'),
    ]
    frames = await collect(formatter, _stream(*tokens))
    assert not any(isinstance(f, WsStreamBlock) for f in frames)


@pytest.mark.asyncio
async def test_floating_formatter_end_frame():
    req_id = "pop-4"
    formatter = FloatingFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(("project_agent", ""), ("project_agent", "Done")))
    assert isinstance(frames[-1], WsStreamEnd)


@pytest.mark.asyncio
async def test_floating_formatter_unknown_agent_defaults_to_tasks():
    req_id = "pop-5"
    formatter = FloatingFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(("unknown_agent", ""), ("unknown_agent", "hi")))
    assert frames[0].domain == "tasks"

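The FloatingFormatter tests above imply an agent-to-domain lookup with a fallback: `task_agent` yields `tasks`, `checkpoint_agent` yields `checkpoints`, and an unrecognized agent falls back to `tasks`. A minimal sketch of such a lookup follows. The names `AGENT_DOMAINS`, `DEFAULT_DOMAIN`, and `domain_for_agent` are hypothetical and do not appear in this diff, and the `note_agent` and `project_agent` entries are assumptions that the tests do not assert.

```python
# Hypothetical sketch: these names are illustrative, not taken from the code under test.
AGENT_DOMAINS = {
    "task_agent": "tasks",              # asserted by the tests
    "checkpoint_agent": "checkpoints",  # asserted by the tests
    "note_agent": "notes",              # assumed; not asserted by the tests
    "project_agent": "projects",        # assumed; not asserted by the tests
}

# Fallback used when the agent name is not recognized.
DEFAULT_DOMAIN = "tasks"


def domain_for_agent(agent_name: str) -> str:
    """Return the domain for an agent name, falling back to DEFAULT_DOMAIN."""
    return AGENT_DOMAINS.get(agent_name, DEFAULT_DOMAIN)
```

A plain dictionary with `.get()` keeps the fallback in one place, so adding a new agent would only need a new entry.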
292
tests/test_schemas_v3.py
Normal file
@@ -0,0 +1,292 @@
"""Tests for v3 WebSocket frame protocol schemas."""

import pytest
from pydantic import ValidationError

from app.schemas import (
    WsFrameType,
    WsHomeRequest,
    WsFloatingDomain,
    WsFloatingRequest,
    WsFloatingScope,
    WsStreamBlock,
    WsStreamEnd,
    WsStreamStart,
    WsStreamText,
)

# ── WsFrameType ───────────────────────────────────────────────────────


def test_v3_frame_types_exist():
    v3_types = [
        "home_request",
        "floating_request",
        "stream_start",
        "stream_text",
        "stream_block",
        "stream_end",
        "floating_domain",
        "data_request",
        "data_response",
        "mutation",
    ]
    for name in v3_types:
        assert hasattr(WsFrameType, name), f"WsFrameType missing: {name}"
        assert WsFrameType[name].value == name


def test_v2_frame_types_still_exist():
    """Backward compat: v2 types must remain."""
    v2_types = [
        "chat_request",
        "text_chunk",
        "tool_call",
        "tool_result",
        "final",
        "ping",
        "agent_run",
        "agent_data",
        "agent_complete",
        "device_hello",
    ]
    for name in v2_types:
        assert hasattr(WsFrameType, name), f"v2 WsFrameType missing: {name}"

# ── WsHomeRequest ─────────────────────────────────────────────────────


def test_home_request_defaults():
    frame = WsHomeRequest(message="Hello")
    assert frame.type == WsFrameType.home_request
    assert frame.message == "Hello"
    assert frame.conversation_history == []


def test_home_request_with_history():
    history = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
    frame = WsHomeRequest(message="Follow up", conversation_history=history)
    assert frame.conversation_history == history


def test_home_request_serializes():
    frame = WsHomeRequest(message="Test")
    data = frame.model_dump()
    assert data["type"] == "home_request"
    assert data["message"] == "Test"
    assert data["conversation_history"] == []


def test_home_request_deserializes():
    raw = {"type": "home_request", "message": "Hi there"}
    frame = WsHomeRequest.model_validate(raw)
    assert frame.message == "Hi there"


def test_home_request_requires_message():
    with pytest.raises(ValidationError):
        WsHomeRequest.model_validate({"type": "home_request"})

# ── WsFloatingRequest ────────────────────────────────────────────────────


def test_floating_request_basic():
    frame = WsFloatingRequest(
        message="Summarise",
        scope=WsFloatingScope(type="task", id="task-123"),
    )
    assert frame.type == WsFrameType.floating_request
    assert frame.scope.type == "task"
    assert frame.scope.id == "task-123"


def test_floating_request_scope_without_id():
    frame = WsFloatingRequest(
        message="Show all",
        scope=WsFloatingScope(type="project"),
    )
    assert frame.scope.id is None


def test_floating_request_serializes():
    frame = WsFloatingRequest(
        message="Test",
        scope=WsFloatingScope(type="note", id="n-1"),
    )
    data = frame.model_dump()
    assert data["type"] == "floating_request"
    assert data["scope"]["type"] == "note"
    assert data["scope"]["id"] == "n-1"


def test_floating_request_invalid_scope_type():
    with pytest.raises(ValidationError):
        WsFloatingRequest(
            message="X",
            scope=WsFloatingScope(type="unknown"),  # type: ignore[arg-type]
        )


def test_floating_request_requires_scope():
    with pytest.raises(ValidationError):
        WsFloatingRequest.model_validate({"type": "floating_request", "message": "X"})

# ── WsStreamStart ─────────────────────────────────────────────────────


def test_stream_start():
    frame = WsStreamStart(request_id="req-abc")
    assert frame.type == WsFrameType.stream_start
    assert frame.request_id == "req-abc"


def test_stream_start_serializes():
    data = WsStreamStart(request_id="r1").model_dump()
    assert data == {"type": "stream_start", "request_id": "r1"}


def test_stream_start_deserializes():
    frame = WsStreamStart.model_validate({"type": "stream_start", "request_id": "r1"})
    assert frame.request_id == "r1"

# ── WsStreamText ──────────────────────────────────────────────────────


def test_stream_text():
    frame = WsStreamText(request_id="r1", chunk="Hello ")
    assert frame.type == WsFrameType.stream_text
    assert frame.chunk == "Hello "


def test_stream_text_serializes():
    data = WsStreamText(request_id="r1", chunk="word").model_dump()
    assert data == {"type": "stream_text", "request_id": "r1", "chunk": "word"}


def test_stream_text_deserializes():
    raw = {"type": "stream_text", "request_id": "r2", "chunk": "test"}
    frame = WsStreamText.model_validate(raw)
    assert frame.chunk == "test"

# ── WsStreamBlock ─────────────────────────────────────────────────────


def test_stream_block_chart():
    data = {
        "type": "chart",
        "chartType": "bar",
        "title": "Tasks",
        "data": [{"name": "Done", "count": 5}],
        "config": {"count": {"label": "Count", "color": "#4f46e5"}},
    }
    frame = WsStreamBlock(request_id="r1", block_type="chart", data=data)
    assert frame.type == WsFrameType.stream_block
    assert frame.block_type == "chart"
    assert frame.data["chartType"] == "bar"


def test_stream_block_entity_ref():
    frame = WsStreamBlock(
        request_id="r1",
        block_type="entity_ref",
        data={"type": "task", "id": "t-1", "title": "Fix bug"},
    )
    assert frame.block_type == "entity_ref"


def test_stream_block_table():
    frame = WsStreamBlock(
        request_id="r1",
        block_type="table",
        data={"headers": ["A", "B"], "rows": [["1", "2"]]},
    )
    assert frame.block_type == "table"


def test_stream_block_timeline():
    frame = WsStreamBlock(
        request_id="r1",
        block_type="timeline",
        data={"checkpoints": [{"id": "c1", "title": "Launch", "date": 1700000000}]},
    )
    assert frame.block_type == "timeline"


def test_stream_block_invalid_type():
    with pytest.raises(ValidationError):
        WsStreamBlock(
            request_id="r1",
            block_type="unknown",  # type: ignore[arg-type]
            data={},
        )


def test_stream_block_serializes():
    frame = WsStreamBlock(request_id="r1", block_type="table", data={"headers": [], "rows": []})
    d = frame.model_dump()
    assert d["type"] == "stream_block"
    assert d["block_type"] == "table"

# ── WsStreamEnd ───────────────────────────────────────────────────────


def test_stream_end_defaults():
    frame = WsStreamEnd(request_id="r1")
    assert frame.type == WsFrameType.stream_end
    assert frame.mutations == []


def test_stream_end_with_mutations():
    mutations = [{"action": "create", "table": "tasks", "data": {"title": "New task"}}]
    frame = WsStreamEnd(request_id="r1", mutations=mutations)
    assert len(frame.mutations) == 1
    assert frame.mutations[0]["action"] == "create"


def test_stream_end_serializes():
    data = WsStreamEnd(request_id="r2").model_dump()
    assert data == {"type": "stream_end", "request_id": "r2", "mutations": []}


def test_stream_end_deserializes():
    raw = {"type": "stream_end", "request_id": "r3", "mutations": []}
    frame = WsStreamEnd.model_validate(raw)
    assert frame.request_id == "r3"

# ── WsFloatingDomain ─────────────────────────────────────────────────────


def test_floating_domain_tasks():
    frame = WsFloatingDomain(request_id="r1", domain="tasks")
    assert frame.type == WsFrameType.floating_domain
    assert frame.domain == "tasks"


@pytest.mark.parametrize("domain", ["tasks", "checkpoints", "notes", "projects"])
def test_floating_domain_valid_domains(domain: str):
    frame = WsFloatingDomain(request_id="r1", domain=domain)  # type: ignore[arg-type]
    assert frame.domain == domain


def test_floating_domain_invalid():
    with pytest.raises(ValidationError):
        WsFloatingDomain(request_id="r1", domain="invalid")  # type: ignore[arg-type]


def test_floating_domain_serializes():
    d = WsFloatingDomain(request_id="r1", domain="notes").model_dump()
    assert d == {"type": "floating_domain", "request_id": "r1", "domain": "notes"}


def test_floating_domain_deserializes():
    raw = {"type": "floating_domain", "request_id": "r1", "domain": "projects"}
    frame = WsFloatingDomain.model_validate(raw)
    assert frame.domain == "projects"

157
tests/test_ws_unified.py
Normal file
@@ -0,0 +1,157 @@
"""Integration tests for the unified WebSocket handler (Step 5).

Tests the device WS endpoint with home_request and floating_request frames,
verifying that the correct v3 frame sequence is returned.

LLM calls are mocked to avoid network dependency.
"""

from __future__ import annotations

import json
from unittest.mock import patch

import pytest

from app.db import get_session
from app.main import app
from app.schemas import WsFrameType
from tests.conftest import TEST_USER_IDS, make_jwt

USER_ID = TEST_USER_IDS["power"]

# ── helpers ───────────────────────────────────────────────────────────────────

@pytest.fixture(autouse=True)
def _override_db(db_session):
    async def _gen():
        yield db_session

    app.dependency_overrides[get_session] = _gen
    yield
    app.dependency_overrides.pop(get_session, None)


def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
    """Receive frames until a stream_end frame arrives or max_frames is reached."""
    frames = []
    for _ in range(max_frames):
        raw = ws.receive_text()
        frame = json.loads(raw)
        frames.append(frame)
        if frame.get("type") == WsFrameType.stream_end:
            break
    return frames


async def _mock_home_stream(user_id, message, context, reg=None):
    yield "task_agent", ""
    yield "task_agent", '{"type": "text", "content": "Hello"}'


async def _mock_floating_stream(user_id, message, context, reg=None):
    yield "task_agent", ""
    yield "task_agent", "Here is a summary"

# ── tests ─────────────────────────────────────────────────────────────────────

def test_home_request_produces_stream_frames(client):
    """home_request → stream_start, stream_text+, stream_end."""
    token = make_jwt("power", user_id=USER_ID)

    with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_home_stream):
        with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
            ws.send_text(json.dumps({
                "type": "device_hello", "device_id": "dev-1", "agent_ids": []
            }))
            ws.send_text(json.dumps({
                "type": "home_request",
                "request_id": "r1",
                "message": "List my tasks",
                "conversation_history": [],
            }))
            frames = _recv_until_end(ws)

    types = [f["type"] for f in frames]
    assert WsFrameType.stream_start in types
    assert WsFrameType.stream_end in types
    assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end)


def test_floating_request_produces_domain_frame(client):
    """floating_request → floating_domain first, then stream_text*, stream_end."""
    token = make_jwt("power", user_id=USER_ID)

    with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_floating_stream):
        with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
            ws.send_text(json.dumps({
                "type": "device_hello", "device_id": "dev-2", "agent_ids": []
            }))
            ws.send_text(json.dumps({
                "type": "floating_request",
                "request_id": "p1",
                "message": "Summarize this task",
                "scope": {"type": "task", "id": "task-123"},
            }))
            frames = _recv_until_end(ws)

    types = [f["type"] for f in frames]
    assert WsFrameType.floating_domain in types
    assert WsFrameType.stream_end in types
    assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)

    domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
    assert domain_frame["domain"] == "tasks"
    assert domain_frame["request_id"] == "p1"


def test_home_request_request_id_propagated(client):
    """request_id in home_request is echoed in all response frames."""
    token = make_jwt("power", user_id=USER_ID)
    req_id = "my-unique-req-id"

    async def _stream(user_id, message, context, reg=None):
        yield "note_agent", ""
        yield "note_agent", '{"type": "text", "content": "ok"}'

    with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_stream):
        with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
            ws.send_text(json.dumps({
                "type": "device_hello", "device_id": "dev-3", "agent_ids": []
            }))
            ws.send_text(json.dumps({
                "type": "home_request",
                "request_id": req_id,
                "message": "hello",
            }))
            frames = _recv_until_end(ws)

    for f in frames:
        if "request_id" in f:
            assert f["request_id"] == req_id


def test_tool_result_dispatch_silent_on_unknown_id(client):
    """tool_result for an unknown call_id is silently ignored (no crash)."""
    token = make_jwt("power", user_id=USER_ID)

    with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.05):
        with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
            ws.send_text(json.dumps({
                "type": "device_hello", "device_id": "dev-4", "agent_ids": []
            }))
            ws.send_text(json.dumps({
                "type": "tool_result", "id": "no-such-id", "ok": True
            }))
            # If connection is still alive, we'll get the heartbeat ping
            msg = json.loads(ws.receive_text())
            assert msg["type"] == "ping"


def test_invalid_jwt_rejected(client):
    """Connection with bad token is closed before or after accept."""
    with pytest.raises(Exception):
        with client.websocket_connect("/api/v1/ws/device?token=badtoken") as ws:
            ws.receive_text()

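The integration tests above check frame ordering by comparing `types.index(...)` values after confirming both frame types are present. Those three assertions could be folded into one helper. The sketch below assumes a hypothetical function named `appears_before` (not part of this diff) that works on the plain frame-type strings of the v3 protocol, such as `stream_start`, `floating_domain`, and `stream_end`.

```python
from __future__ import annotations


def appears_before(types: list[str], opener: str, closer: str = "stream_end") -> bool:
    """Return True if both frame types are present and `opener` precedes `closer`.

    Hypothetical helper for illustration only. It compares the first occurrence
    of each type, mirroring the `types.index(a) < types.index(b)` checks.
    """
    if opener not in types or closer not in types:
        return False
    return types.index(opener) < types.index(closer)
```

Returning `False` when either frame is missing keeps a single assertion meaningful, although separate assertions, as in the tests above, give more specific failure messages.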