Compare commits
50 Commits
feature/de
...
48036397f1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
48036397f1 | ||
|
|
57b5648915 | ||
|
|
7e4374c69b | ||
|
|
fe0dd038ee | ||
|
|
d3f7099d93 | ||
|
|
63fa119543 | ||
|
|
d856dfd28c | ||
|
|
ccba54ac24 | ||
|
|
55500cc818 | ||
|
|
75a826c9d8 | ||
|
|
971f1dd84f | ||
|
|
333bba6fdd | ||
|
|
229e20d073 | ||
|
|
0b491b3643 | ||
|
|
0d5fa3e569 | ||
|
|
aff68a9051 | ||
|
|
5e9ef2809e | ||
|
|
90018af311 | ||
|
|
1e2e395676 | ||
|
|
59d3a53980 | ||
|
|
9feeaa79c8 | ||
|
|
aa219a4d08 | ||
|
|
552b8eb305 | ||
|
|
0d93b3960d | ||
|
|
f07580574b | ||
|
|
1a8bf11f90 | ||
|
|
e7cdce8287 | ||
|
|
58bc6efd4b | ||
|
|
6c450805cb | ||
|
|
f340d0fa3e | ||
|
|
edc53cb6eb | ||
|
|
725cece5c1 | ||
|
|
297e20ce8d | ||
|
|
5a03bd1cfb | ||
|
|
87b7a1c6c9 | ||
|
|
826f64d6bb | ||
| 5faa6b1d7c | |||
| 02a9684cd6 | |||
| fae9efee0d | |||
| 30b062dd4a | |||
| 2a0331d7ce | |||
| 13fd8677c1 | |||
| 9bd629cb59 | |||
| 9c97702daa | |||
| a1e364c9c0 | |||
| 5b55f1292a | |||
| 5bc9ea6cd6 | |||
| f7404b6f66 | |||
| d667e43c73 | |||
| fe085a7951 |
20
.env.example
20
.env.example
@@ -4,9 +4,17 @@ ENV=dev
|
|||||||
# ── Database ──────────────────────────────────────────────────────────────────
|
# ── Database ──────────────────────────────────────────────────────────────────
|
||||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
|
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
|
||||||
|
|
||||||
# ── Auth ──────────────────────────────────────────────────────────────────────
|
# ── Redis ─────────────────────────────────────────────────────────────────────
|
||||||
JWT_SECRET=replace-with-a-long-random-secret
|
REDIS_URL=redis://localhost:6379/0
|
||||||
JWT_ALGORITHM=HS256
|
|
||||||
|
# ── Auth (JWT RS256) ──────────────────────────────────────────────────────────
|
||||||
|
# Public key for optional local JWT verification (Traefik ForwardAuth handles
|
||||||
|
# this in production — services trust X-User-* headers from Traefik).
|
||||||
|
# Generate keypair:
|
||||||
|
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
||||||
|
# openssl rsa -in private.pem -pubout -out public.pem
|
||||||
|
# Paste PEM content with literal \n for newlines.
|
||||||
|
JWT_PUBLIC_KEY=
|
||||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||||
|
|
||||||
@@ -17,7 +25,6 @@ OPENAI_API_KEY=
|
|||||||
ANTHROPIC_API_KEY=
|
ANTHROPIC_API_KEY=
|
||||||
GOOGLE_API_KEY=
|
GOOGLE_API_KEY=
|
||||||
LLM_MODEL=gpt-4o
|
LLM_MODEL=gpt-4o
|
||||||
LLM_ROUTER_MODEL=gpt-4o-mini
|
|
||||||
|
|
||||||
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
||||||
STRIPE_SECRET_KEY=
|
STRIPE_SECRET_KEY=
|
||||||
@@ -42,3 +49,8 @@ QDRANT_API_KEY=
|
|||||||
# ── CORS ──────────────────────────────────────────────────────────────────────
|
# ── CORS ──────────────────────────────────────────────────────────────────────
|
||||||
# Comma-separated list parsed by Settings (override default if needed)
|
# Comma-separated list parsed by Settings (override default if needed)
|
||||||
# CORS_ORIGINS=["app://.","http://localhost:3000"]
|
# CORS_ORIGINS=["app://.","http://localhost:3000"]
|
||||||
|
|
||||||
|
# ── Langfuse (observability) ─────────────────────────────────────────────────
|
||||||
|
LANGFUSE_SECRET_KEY=sk-lf-...
|
||||||
|
LANGFUSE_PUBLIC_KEY=pk-lf-...
|
||||||
|
LANGFUSE_HOST=https://cloud.langfuse.com # or self-hosted URL
|
||||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -13,6 +13,9 @@ env/
|
|||||||
# Environment variables
|
# Environment variables
|
||||||
.env
|
.env
|
||||||
|
|
||||||
|
# Cryptographic keys
|
||||||
|
*.pem
|
||||||
|
|
||||||
# IDE
|
# IDE
|
||||||
.vscode/
|
.vscode/
|
||||||
.idea/
|
.idea/
|
||||||
@@ -32,3 +35,6 @@ Thumbs.db
|
|||||||
# Claude Code
|
# Claude Code
|
||||||
.claude/
|
.claude/
|
||||||
logs/
|
logs/
|
||||||
|
|
||||||
|
# Eval private test data
|
||||||
|
services/batch-agent/eval/fixtures/private_data/
|
||||||
|
|||||||
@@ -1,523 +0,0 @@
|
|||||||
# AI Refactor Plan — Adiuva Backend
|
|
||||||
|
|
||||||
> **Objective:** Transform backend tools from JSON-action-descriptor-returning functions into real bidirectional executors. Each tool sends structured CRUD operations to the Electron client via WebSocket, receives real data back, and returns meaningful results to the LLM. The LLM reasons about actual user data instead of serialized action payloads.
|
|
||||||
>
|
|
||||||
> **Electron app:** Lives at `../adiuva/`. See `../adiuva/AI_REFACTOR_PLAN.md`.
|
|
||||||
>
|
|
||||||
> **Protocol:** Execute steps sequentially. Each step is atomic and committable. Mark `[x]` when done.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Architecture — Before vs After
|
|
||||||
|
|
||||||
### Before (current)
|
|
||||||
```
|
|
||||||
LLM calls list_tasks(status="todo")
|
|
||||||
→ tool returns: '{"action":"list","table":"tasks","filters":{"status":"todo"}}'
|
|
||||||
→ _tool_loop feeds that JSON string as ToolMessage to LLM
|
|
||||||
→ LLM sees a descriptor, NOT real data — cannot reason about tasks
|
|
||||||
→ Final response: generic "Here are your tasks" (no actual task data)
|
|
||||||
→ Action descriptors sent in final WS frame for Electron to execute post-response
|
|
||||||
```
|
|
||||||
|
|
||||||
### After (target)
|
|
||||||
```
|
|
||||||
LLM calls list_tasks(status="todo")
|
|
||||||
→ tool calls execute_on_client(action="select", table="tasks", filters={status:"todo"})
|
|
||||||
→ WS frame sent to Electron: {type:"tool_call", id:"abc", action:"select", table:"tasks", filters:{status:"todo"}}
|
|
||||||
→ Electron runs: db.select().from(tasks).where(eq(tasks.status, "todo")).all()
|
|
||||||
→ WS frame back: {type:"tool_result", id:"abc", rows:[{id:"1",title:"Buy milk",...}, ...]}
|
|
||||||
→ tool returns: "Found 3 tasks: 1. Buy milk (high, due tomorrow) 2. ..."
|
|
||||||
→ _tool_loop feeds that as ToolMessage to LLM
|
|
||||||
→ LLM sees REAL data — can reason, count, compare, summarize
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## WS Protocol — Typed Frames
|
|
||||||
|
|
||||||
| Direction | `type` | Payload |
|
|
||||||
|---|---|---|
|
|
||||||
| Client → Server | `chat_request` | `{ message: str, context: ChatContext }` |
|
|
||||||
| Server → Client | `text_chunk` | `{ text: str }` |
|
|
||||||
| Server → Client | `tool_call` | `{ id: str, action: str, table?: str, data?: dict, filters?: dict, vector?: list[float], limit?: int }` |
|
|
||||||
| Client → Server | `tool_result` | `{ id: str, row?: dict, rows?: list[dict], results?: list[dict], deleted?: bool, ok?: bool, error?: str }` |
|
|
||||||
| Server → Client | `final` | `{ response: str }` |
|
|
||||||
| Server → Client | `ping` | `{}` |
|
|
||||||
|
|
||||||
**Actions:**
|
|
||||||
|
|
||||||
| `action` | What Electron does (Drizzle) | `tool_result` shape |
|
|
||||||
|---|---|---|
|
|
||||||
| `select` | `db.select().from(table).where(filters)` | `{ rows: [...] }` |
|
|
||||||
| `get` | `db.select().from(table).where(id=...).get()` | `{ row: {...} or null }` |
|
|
||||||
| `insert` | `db.insert(table).values({id: uuid(), ...data}).returning().get()` | `{ row: {...} }` |
|
|
||||||
| `update` | `db.update(table).set(updates).where(id=...).returning().get()` | `{ row: {...} }` |
|
|
||||||
| `delete` | `db.delete(table).where(id=...).run()` | `{ deleted: true }` |
|
|
||||||
| `vector_upsert` | LanceDB upsert with pre-computed vector | `{ ok: true }` |
|
|
||||||
| `vector_search` | LanceDB search by vector | `{ results: [{id, content, score}...] }` |
|
|
||||||
|
|
||||||
**Electron generates IDs + timestamps.** Backend tools never send `id` or `createdAt` in `insert` data — Electron adds `id: uuid()`, `createdAt: Date.now()`, `updatedAt: Date.now()`.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## SQLite Schema Reference (Electron's local database)
|
|
||||||
|
|
||||||
Tools must use **camelCase** field names (Drizzle maps them to snake_case internally):
|
|
||||||
|
|
||||||
| Table | Columns |
|
|
||||||
|---|---|
|
|
||||||
| `tasks` | id, projectId, title, description, status (todo\|in_progress\|done), priority (high\|medium\|low), assignee (JSON array string), dueDate (ms), isAiSuggested (0\|1), isApproved (0\|1), createdAt (ms) |
|
|
||||||
| `projects` | id, clientId, name, status (active\|archived), aiSummary, createdAt (ms) |
|
|
||||||
| `timelines` | id, projectId (required), title, date (ms), isAiSuggested (0\|1), isApproved (0\|1), createdAt (ms) |
|
|
||||||
| `notes` | id, projectId, title, content (markdown), createdAt (ms), updatedAt (ms) |
|
|
||||||
| `taskComments` | id, taskId, author, content, createdAt (ms) |
|
|
||||||
| `clients` | id, parentId, name, industry, createdAt (ms) |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase B — Backend Changes
|
|
||||||
|
|
||||||
### Step B.1 — WS context + frame types
|
|
||||||
- [x] Create `app/core/ws_context.py` (~25 lines):
|
|
||||||
- `_client_executor: ContextVar[Callable]` — holds the async callback for the current WS session
|
|
||||||
- `async def execute_on_client(action, table=None, data=None, filters=None, vector=None, limit=None) -> dict`:
|
|
||||||
- Reads callback from ContextVar
|
|
||||||
- Builds `tool_call` payload: `{id: str(uuid4()), action, table, data, filters, vector, limit}` (omits None fields)
|
|
||||||
- Calls `await callback(payload)` — which sends the WS frame and waits for `tool_result`
|
|
||||||
- Returns the result dict
|
|
||||||
- `def set_client_executor(fn)` / `def clear_client_executor()` — ContextVar management
|
|
||||||
- [x] Add to `app/schemas.py`:
|
|
||||||
- `WsFrameType(str, Enum)`: `chat_request`, `text_chunk`, `tool_call`, `tool_result`, `final`, `ping`
|
|
||||||
- `WsToolCall(BaseModel)`: `type`, `id`, `action`, `table?`, `data?`, `filters?`, `vector?`, `limit?`
|
|
||||||
- `WsToolResult(BaseModel)`: `type`, `id`, `row?`, `rows?`, `results?`, `deleted?`, `ok?`, `error?`
|
|
||||||
- `WsTextChunk(BaseModel)`: `type`, `text`
|
|
||||||
- `WsFinal(BaseModel)`: `type`, `response`
|
|
||||||
- **Files:** `app/core/ws_context.py`, `app/schemas.py`
|
|
||||||
- **Outcome:** Any tool can `await execute_on_client(...)` to query/mutate the user's local DB.
|
|
||||||
|
|
||||||
### Step B.2 — Rewrite all 23 tools to use `execute_on_client()`
|
|
||||||
- [x] Each tool: same `@tool` decorator, same parameters, same docstring. Replace `return json.dumps({...})` body with:
|
|
||||||
1. Call `result = await execute_on_client(action=..., table=..., data/filters=...)`
|
|
||||||
2. Return human-readable string with confirmation + key data from `result`
|
|
||||||
|
|
||||||
- [x] **`app/agents/task_agent.py` (8 tools):**
|
|
||||||
- `list_tasks(project_id, status, search, order_by)`:
|
|
||||||
```python
|
|
||||||
result = await execute_on_client(action="select", table="tasks", filters={
|
|
||||||
"projectId": project_id or None,
|
|
||||||
"status": status or None,
|
|
||||||
"search": search or None,
|
|
||||||
"orderBy": order_by or None,
|
|
||||||
})
|
|
||||||
rows = result.get("rows", [])
|
|
||||||
if not rows:
|
|
||||||
return "No tasks found matching the given filters."
|
|
||||||
lines = [f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})" for r in rows]
|
|
||||||
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
|
||||||
```
|
|
||||||
- `create_task(title, ...)`:
|
|
||||||
```python
|
|
||||||
result = await execute_on_client(action="insert", table="tasks", data={
|
|
||||||
"title": title, "description": description or None, "status": status,
|
|
||||||
"priority": priority, "assignee": assignees, "dueDate": due_date or None,
|
|
||||||
"projectId": project_id or None, "isAiSuggested": is_ai_suggested, "isApproved": is_approved,
|
|
||||||
})
|
|
||||||
row = result["row"]
|
|
||||||
return f"Task created: '{row['title']}' (id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
|
|
||||||
```
|
|
||||||
- `update_task(task_id, ...)`: build updates dict (same logic as now) → `execute_on_client(action="update", table="tasks", data={"id": task_id, "updates": updates})` → return "Task updated: {title}"
|
|
||||||
- `delete_task(task_id)`: `execute_on_client(action="delete", table="tasks", data={"id": task_id})` → return "Task deleted"
|
|
||||||
- `list_tasks_due_today()`: calculate today's start/end ms → `execute_on_client(action="select", table="tasks", filters={"dueDateFrom": start, "dueDateTo": end})` → format + return
|
|
||||||
- `list_task_comments(task_id)`: `execute_on_client(action="select", table="taskComments", filters={"taskId": task_id})` → format + return
|
|
||||||
- `add_task_comment(task_id, author, content)`: `execute_on_client(action="insert", table="taskComments", data={...})` → return confirmation
|
|
||||||
- `delete_task_comment(comment_id)`: `execute_on_client(action="delete", table="taskComments", data={"id": comment_id})` → return confirmation
|
|
||||||
|
|
||||||
- [x] **`app/agents/project_agent.py` (6 tools):**
|
|
||||||
- `list_projects(client_id, include_archived)`: `execute_on_client(action="select", table="projects", filters={clientId, includeArchived})` → format + return
|
|
||||||
- `list_all_projects()`: `execute_on_client(action="select", table="projects")` → format + return
|
|
||||||
- `get_project(project_id)`: `execute_on_client(action="get", table="projects", data={"id": project_id})` → return project details or "not found"
|
|
||||||
- `create_project(name, client_id)`: `execute_on_client(action="insert", table="projects", data={name, clientId})` → return confirmation + id
|
|
||||||
- `update_project(project_id, ...)`: build updates → `execute_on_client(action="update", ...)` → return confirmation
|
|
||||||
- `delete_project(project_id)`: `execute_on_client(action="delete", ...)` → return confirmation
|
|
||||||
|
|
||||||
- [x] **`app/agents/timeline_agent.py` (4 tools):**
|
|
||||||
- `list_timelines(project_id)`: `execute_on_client(action="select", table="timelines", filters={projectId})` → format + return
|
|
||||||
- `create_timeline(project_id, title, date, ...)`: `execute_on_client(action="insert", table="timelines", data={...})` → return confirmation + id
|
|
||||||
- `update_timeline(timeline_id, ...)`: build updates → `execute_on_client(action="update", ...)` → return confirmation
|
|
||||||
- `delete_timeline(timeline_id)`: `execute_on_client(action="delete", ...)` → return confirmation
|
|
||||||
|
|
||||||
- [x] **`app/agents/note_agent.py` (5 tools):**
|
|
||||||
- `list_notes(project_id)`: `execute_on_client(action="select", table="notes", filters={projectId})` → format + return
|
|
||||||
- `get_note(note_id)`: `execute_on_client(action="get", table="notes", data={"id": note_id})` → return full content or "not found"
|
|
||||||
- `create_note(title, content, project_id)`: `execute_on_client(action="insert", table="notes", data={...})` → then `execute_on_client(action="vector_upsert", data={id, projectId, content}, vector=await embed(content))` → return confirmation
|
|
||||||
- `update_note(note_id, ...)`: build updates → `execute_on_client(action="update", ...)` → then vector_upsert for updated content → return confirmation
|
|
||||||
- `delete_note(note_id)`: `execute_on_client(action="delete", ...)` → return confirmation
|
|
||||||
|
|
||||||
- **Files:** `app/agents/task_agent.py`, `app/agents/project_agent.py`, `app/agents/timeline_agent.py`, `app/agents/note_agent.py`
|
|
||||||
- **Outcome:** All 23 tools query real user data via WS. LLM sees actual rows, not action descriptors.
|
|
||||||
|
|
||||||
### Step B.3 — Bidirectional WebSocket handler
|
|
||||||
- [x] Refactor `app/api/routes/chat.py` WS endpoint:
|
|
||||||
- After auth + accept + receive `chat_request`:
|
|
||||||
1. Create `execute_on_client` callback closure capturing the websocket:
|
|
||||||
```python
|
|
||||||
pending_calls: dict[str, asyncio.Future] = {}
|
|
||||||
|
|
||||||
async def on_client_result(frame: dict):
|
|
||||||
"""Called when a tool_result frame arrives from Electron."""
|
|
||||||
fut = pending_calls.pop(frame["id"], None)
|
|
||||||
if fut and not fut.done():
|
|
||||||
fut.set_result(frame)
|
|
||||||
|
|
||||||
async def execute_callback(payload: dict) -> dict:
|
|
||||||
"""Send tool_call to Electron, wait for tool_result."""
|
|
||||||
call_id = payload["id"]
|
|
||||||
fut = asyncio.get_event_loop().create_future()
|
|
||||||
pending_calls[call_id] = fut
|
|
||||||
await websocket.send_text(json.dumps({"type": "tool_call", **payload}))
|
|
||||||
return await asyncio.wait_for(fut, timeout=30.0)
|
|
||||||
```
|
|
||||||
2. Set `client_executor` ContextVar with `execute_callback`
|
|
||||||
3. Run orchestrator in a task — it calls agents, agents call tools, tools call `execute_on_client()` which goes through the callback
|
|
||||||
4. In parallel, run a message receive loop that dispatches incoming frames:
|
|
||||||
- `tool_result` → `on_client_result(frame)`
|
|
||||||
- `ping` → ignore
|
|
||||||
5. Orchestrator yields `text_chunk` frames → send to client
|
|
||||||
6. Send `final` frame when done
|
|
||||||
7. Clear ContextVar
|
|
||||||
- Keep heartbeat ping every 30s
|
|
||||||
- 30s timeout on `tool_result` — if Electron doesn't respond, future raises `TimeoutError`, tool returns error string to LLM
|
|
||||||
- **Files:** `app/api/routes/chat.py`
|
|
||||||
- **Outcome:** Full bidirectional WS. Tool calls and text streaming happen concurrently on the same connection.
|
|
||||||
|
|
||||||
### Step B.4 — `_tool_loop` — no changes needed
|
|
||||||
- [x] Verify `app/core/agent_registry.py` works unchanged:
|
|
||||||
- `_tool_loop` calls `tool_fn.ainvoke(args)` → tool awaits `execute_on_client()` (WS round-trip) → returns string → `ToolMessage(content=string)` → LLM sees real data
|
|
||||||
- The async WS round-trip happens inside each tool. `_tool_loop` just sees an awaited tool returning a string — same as before, different content.
|
|
||||||
- **No code changes.** Just verify + add a log line for tool execution times if desired.
|
|
||||||
|
|
||||||
### Step B.5 — Orchestrator cleanup
|
|
||||||
- [x] Update `app/core/orchestrator.py`:
|
|
||||||
- `orchestrate_stream()`: remove `"actions": []` from final frame. Final becomes: `{"done": true, "response": "..."}`
|
|
||||||
- No other changes — `classify_intent` → `call_agent` → chunk response → final frame
|
|
||||||
- **Files:** `app/core/orchestrator.py`
|
|
||||||
- **Outcome:** Clean final frame. No more action descriptors in the protocol.
|
|
||||||
|
|
||||||
### Step B.6 — Add `/vectors/embed` endpoint
|
|
||||||
- [x] Add to `app/api/routes/vectors.py`:
|
|
||||||
- `POST /api/v1/storage/vectors/embed`:
|
|
||||||
- Request: `{ text: str }`
|
|
||||||
- Response: `{ vector: list[float] }` (1536-dim from `text-embedding-3-small`)
|
|
||||||
- Auth required (JWT)
|
|
||||||
- Used by:
|
|
||||||
- Backend tools: `note_agent` calls this before `vector_upsert`
|
|
||||||
- Electron: `vectordb.ts` calls this for note embedding on create/update
|
|
||||||
- **Files:** `app/api/routes/vectors.py`
|
|
||||||
- **Outcome:** Single embedding endpoint. Both backend tools and Electron can generate vectors.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Verification
|
|
||||||
|
|
||||||
| What to test | How |
|
|
||||||
|---|---|
|
|
||||||
| **Read flow** | "List my tasks" → `list_tasks` → `tool_call{select, tasks}` → Electron returns rows → LLM describes real tasks |
|
|
||||||
| **Write flow** | "Create a task called Buy milk" → `create_task` → `tool_call{insert, tasks, data:{title:"Buy milk"}}` → Electron inserts + returns row → tool confirms with id |
|
|
||||||
| **Multi-tool** | "How many todo tasks do I have?" → `list_tasks(status=todo)` → LLM counts actual rows → "You have 3 todo tasks" |
|
|
||||||
| **Vector search** | "Find notes about deployment" → tool embeds → `tool_call{vector_search, vector:[...]}` → Electron searches LanceDB → returns matching notes |
|
|
||||||
| **Vector upsert** | "Create a note about..." → insert note → vector_upsert with embedding → both SQLite + LanceDB updated |
|
|
||||||
| **Tool timeout** | Disconnect Electron mid-conversation → 30s timeout → tool returns error → LLM handles gracefully |
|
|
||||||
| **Concurrent calls** | Agent calls 2 tools in sequence → each does WS round-trip → both succeed → LLM sees both results |
|
|
||||||
| **_tool_loop max iter** | Verify 5-iteration limit still works → after 5 tool calls, LLM forced to answer without tools |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Execution Notes
|
|
||||||
|
|
||||||
- **Phase 1 is the critical path.** Auth + backend client + drizzle executor + orchestrator refactor must land first.
|
|
||||||
- **Steps 1.1–1.4 are additive** — existing app keeps working until Step 1.5 swaps the orchestrator.
|
|
||||||
- **Step 2.1 is the point of no return** — after removing LangChain, there's no local AI fallback.
|
|
||||||
- **Phase B (backend changes) must land before Phase 1.3–1.5** — Electron needs the bidirectional WS to talk to.
|
|
||||||
- **Phase 3 and Phase 4 are independent** — can be parallelized after Phase 2.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 3 — Agent System: Config, Orchestration & Cloud Connectors
|
|
||||||
|
|
||||||
> **Objective:** Backend manages all agent configuration, scheduling, orchestration, and cloud data fetching. Two agent types: **Local Directory Agent** (backend triggers Electron to read files, then AI analyzes) and **Cloud Connector Agent** (backend fetches Gmail/Teams data directly, AI analyzes, pushes results to Electron via WS tool_call). All extracted items use existing WS tool infrastructure to insert into Electron's local DB with `is_ai_suggested=True`.
|
|
||||||
>
|
|
||||||
> **Electron Phase 3 plan:** `../adiuva/AI_REFACTOR_PLAN.md` Phase 3 section.
|
|
||||||
>
|
|
||||||
> **Electron UI status (2025):** Steps 3.6, 3.7, 3.8 of the Electron plan are ✅ complete. Agents are configured inside the Settings page (`/settings?section=agents`) — not a standalone route. The `JourneyDialog` (Step 3.8) is embedded inline in the Settings → Agents section. `LocalAgentConfigPanel` and `CloudAgentConfigPanel` (Step 3.7) are also inline. This affects the journey API contract (see Step 3.5 below).
|
|
||||||
|
|
||||||
### Architecture
|
|
||||||
|
|
||||||
```
|
|
||||||
Local Agent:
|
|
||||||
Scheduler/manual trigger ──► check device online ──► WS agent_run → Electron
|
|
||||||
──► Electron reads files ──► WS agent_data → Backend
|
|
||||||
──► Backend AI (prompt_template + file content) ──► WS tool_call(insert) → Electron
|
|
||||||
──► Electron persists with isAiSuggested=1
|
|
||||||
|
|
||||||
Cloud Agent:
|
|
||||||
Scheduler/manual trigger ──► Backend fetches Gmail/Teams (OAuth) ──► Backend AI analyzes
|
|
||||||
──► check device online ──► WS tool_call(insert) → Electron ──► Electron persists
|
|
||||||
```
|
|
||||||
|
|
||||||
**New WS frame types:**
|
|
||||||
|
|
||||||
| Direction | `type` | Payload |
|
|
||||||
|---|---|---|
|
|
||||||
| Server → Client | `agent_run` | `{ run_id, agent_id, config: { paths, file_extensions, prompt_template, data_types } }` |
|
|
||||||
| Client → Server | `agent_data` | `{ run_id, files: [{ path, name, content, metadata }] }` |
|
|
||||||
| Client → Server | `agent_complete` | `{ run_id, files_read, errors }` |
|
|
||||||
| Client → Server | `device_hello` | `{ device_id, agent_ids }` |
|
|
||||||
|
|
||||||
### Step 3.1 — Agent config tables
|
|
||||||
- [x] Add to `app/models.py`:
|
|
||||||
- **`LocalAgentConfig`**:
|
|
||||||
- `id` UUID PK
|
|
||||||
- `user_id` FK → users
|
|
||||||
- `device_id` str — identifies which Electron install this config belongs to
|
|
||||||
- `name` str
|
|
||||||
- `directory_paths` JSON — list of absolute paths on the device
|
|
||||||
- `data_types` JSON — which tables to extract to: `["tasks", "notes", "timelines", "projects"]`
|
|
||||||
- `prompt_template` text — user-configured via Chatbot Journey
|
|
||||||
- `file_extensions` JSON — e.g. `[".eml", ".txt", ".pdf", ".md"]`
|
|
||||||
- `schedule_cron` str — e.g. `"0 */6 * * *"` (every 6h)
|
|
||||||
- `enabled` bool (default True)
|
|
||||||
- `last_run_at` datetime nullable
|
|
||||||
- `created_at`, `updated_at` timestamps
|
|
||||||
- **`CloudAgentConfig`**:
|
|
||||||
- `id` UUID PK
|
|
||||||
- `user_id` FK → users
|
|
||||||
- `provider` str — enum: `gmail`, `teams`, `outlook`
|
|
||||||
- `name` str
|
|
||||||
- `data_types` JSON — same format as local
|
|
||||||
- `prompt_template` text
|
|
||||||
- `oauth_token_encrypted` text — Fernet-encrypted OAuth2 credentials
|
|
||||||
- `schedule_cron` str
|
|
||||||
- `enabled` bool (default True)
|
|
||||||
- `last_run_at` datetime nullable
|
|
||||||
- `filter_config` JSON — provider-specific: `{ labels: [], date_range: {from, to}, senders: [] }`
|
|
||||||
- `created_at`, `updated_at` timestamps
|
|
||||||
- **`AgentRunLog`**:
|
|
||||||
- `id` UUID PK
|
|
||||||
- `agent_id` str — references LocalAgentConfig.id or CloudAgentConfig.id
|
|
||||||
- `agent_type` str — `local` or `cloud`
|
|
||||||
- `user_id` FK → users
|
|
||||||
- `status` str — `running`, `success`, `error`, `partial`
|
|
||||||
- `items_processed` int (default 0)
|
|
||||||
- `items_created` int (default 0)
|
|
||||||
- `errors` JSON — list of error strings
|
|
||||||
- `started_at` datetime
|
|
||||||
- `completed_at` datetime nullable
|
|
||||||
- [x] Add Pydantic schemas to `app/schemas.py`:
|
|
||||||
- `LocalAgentConfigCreate`, `LocalAgentConfigUpdate`, `LocalAgentConfigResponse`
|
|
||||||
- `CloudAgentConfigCreate`, `CloudAgentConfigUpdate`, `CloudAgentConfigResponse`
|
|
||||||
- `AgentRunLogResponse`
|
|
||||||
- `AgentCatalogItem` — `{ type, name, description, config_schema }`
|
|
||||||
- `WsAgentRun`, `WsAgentData`, `WsAgentComplete`, `WsDeviceHello`
|
|
||||||
- [x] Generate Alembic migration
|
|
||||||
- **Files:** `app/models.py`, `app/schemas.py`, `alembic/versions/`
|
|
||||||
- **Outcome:** Agent config and run tracking tables in PostgreSQL.
|
|
||||||
|
|
||||||
### Step 3.2 — Agent CRUD API routes
|
|
||||||
- [x] Create `app/api/routes/agents.py`:
|
|
||||||
- `GET /api/v1/agents/catalog` — returns hardcoded agent type catalog:
|
|
||||||
- `local_directory`: "Watches local directories, extracts data from files using AI"
|
|
||||||
- `gmail`: "Scans Gmail inbox, extracts tasks/notes from emails"
|
|
||||||
- `teams`: "Monitors Teams messages, extracts action items"
|
|
||||||
- `outlook`: "Scans Outlook inbox, extracts tasks/notes"
|
|
||||||
- `GET /api/v1/agents/local` — list user's local agent configs
|
|
||||||
- `POST /api/v1/agents/local` — create local agent config
|
|
||||||
- Body: `{ name, device_id, directory_paths, data_types, prompt_template, file_extensions, schedule_cron }`
|
|
||||||
- Tier check: count enabled agents ≤ `batch_active` limit
|
|
||||||
- `PUT /api/v1/agents/local/{id}` — update config (ownership check)
|
|
||||||
- `DELETE /api/v1/agents/local/{id}` — delete config + associated run logs
|
|
||||||
- `GET /api/v1/agents/cloud` — list user's cloud agent configs
|
|
||||||
- `POST /api/v1/agents/cloud` — create cloud connector config
|
|
||||||
- Body: `{ provider, name, data_types, prompt_template, oauth_token_encrypted, schedule_cron, filter_config }`
|
|
||||||
- Tier check: same `batch_active` limit (local + cloud count together)
|
|
||||||
- `PUT /api/v1/agents/cloud/{id}` — update config
|
|
||||||
- `DELETE /api/v1/agents/cloud/{id}` — delete config + run logs
|
|
||||||
- `GET /api/v1/agents/runs` — query params: `agent_id`, `page`, `limit` → paginated run logs
|
|
||||||
- `POST /api/v1/agents/{id}/run` — manual trigger (dispatches to agent runner)
|
|
||||||
- All routes require JWT auth; ownership enforced on all mutations
|
|
||||||
- [x] Register router in `app/main.py`
|
|
||||||
- **Files:** `app/api/routes/agents.py`, `app/main.py`
|
|
||||||
- **Outcome:** Full CRUD for agent configs with tier-gated creation limits.
|
|
||||||
|
|
||||||
### Step 3.3 — Device WS endpoint
|
|
||||||
- [x] Create `app/api/routes/device_ws.py`:
|
|
||||||
- `WebSocket /api/v1/ws/device?token=<jwt>` — persistent connection from Electron
|
|
||||||
- On connect:
|
|
||||||
- Authenticate JWT
|
|
||||||
- Receive `device_hello` frame → extract `device_id`, `agent_ids`
|
|
||||||
- Store connection in `DeviceConnectionManager` (in-memory dict: `user_id → { ws, device_id }`)
|
|
||||||
- Check for overdue agent runs → trigger them immediately
|
|
||||||
- Message loop:
|
|
||||||
- `agent_data` → route to active agent run handler
|
|
||||||
- `agent_complete` → finalize agent run
|
|
||||||
- `tool_result` → route to pending tool call (same pattern as chat WS)
|
|
||||||
- `pong` → heartbeat ack
|
|
||||||
- On disconnect:
|
|
||||||
- Remove from `DeviceConnectionManager`
|
|
||||||
- Mark any in-progress agent runs as `error` with "device disconnected"
|
|
||||||
- Heartbeat: send `ping` every 30s, disconnect if no `pong` within 10s
|
|
||||||
- [x] Create `app/core/device_manager.py`:
|
|
||||||
- `DeviceConnectionManager` (singleton):
|
|
||||||
- `register(user_id, device_id, ws)` — stores active connection
|
|
||||||
- `unregister(user_id)` — removes connection
|
|
||||||
- `get_ws(user_id) -> WebSocket | None` — returns active WS if device is online
|
|
||||||
- `is_online(user_id, device_id=None) -> bool` — optionally checks specific device
|
|
||||||
- `send_frame(user_id, frame: dict)` — sends JSON frame to device
|
|
||||||
- **Files:** `app/api/routes/device_ws.py`, `app/core/device_manager.py`, `app/main.py`
|
|
||||||
- **Outcome:** Backend maintains persistent WS connections to Electron devices for agent triggers.
|
|
||||||
|
|
||||||
### Step 3.4 — Agent run orchestrator
|
|
||||||
- [x] Create `app/core/agent_runner.py`:
|
|
||||||
- `async run_local_agent(user_id, config: LocalAgentConfig, device_mgr: DeviceConnectionManager)`:
|
|
||||||
1. Check device is online with matching `device_id` → abort if offline
|
|
||||||
2. Create `AgentRunLog` with `status=running`
|
|
||||||
3. Send `WsAgentRun` frame to Electron with config (paths, extensions, prompt)
|
|
||||||
4. Await `WsAgentData` frames — collect file contents
|
|
||||||
5. Await `WsAgentComplete` frame — Electron signals done reading
|
|
||||||
6. For each file: call LLM with `prompt_template` + file content → extract structured items
|
|
||||||
7. For each extracted item: send `WsToolCall(insert, table, data)` to Electron → await `WsToolResult`
|
|
||||||
- All inserts include `is_ai_suggested=True, is_approved=False`
|
|
||||||
8. Update `AgentRunLog`: `status=success`, `items_processed`, `items_created`
|
|
||||||
- `async run_cloud_agent(user_id, config: CloudAgentConfig, device_mgr: DeviceConnectionManager)`:
|
|
||||||
1. Check device is online → abort if offline (results must push to Electron)
|
|
||||||
2. Create `AgentRunLog` with `status=running`
|
|
||||||
3. Decrypt OAuth credentials from `config.oauth_token_encrypted`
|
|
||||||
4. Fetch data from cloud provider (Step 3.6):
|
|
||||||
- Gmail: `google-api-python-client` + `filter_config` label/date filters
|
|
||||||
- Teams: `msgraph-sdk` + channel/date filters
|
|
||||||
- Outlook: `msgraph-sdk` + folder/date filters
|
|
||||||
5. For each item: call LLM with `prompt_template` + email/message content → extract structured items
|
|
||||||
6. For each extracted item: send `WsToolCall(insert)` to Electron → await `WsToolResult`
|
|
||||||
7. Update `AgentRunLog`
|
|
||||||
- `async trigger_pending_runs(user_id, device_id, device_mgr)`:
|
|
||||||
- Called when Electron connects (after `device_hello`)
|
|
||||||
- Queries all enabled agent configs where `last_run_at + schedule_interval < now()`
|
|
||||||
- For local agents: only triggers if `config.device_id == device_id`
|
|
||||||
- For cloud agents: triggers regardless of device (any connected device can receive results)
|
|
||||||
- Executes runs sequentially (one at a time to avoid overwhelming the WS)
|
|
||||||
- Error handling: on any failure, update `AgentRunLog` with `status=error` + error details
|
|
||||||
- [x] Wire `POST /agents/{id}/run` endpoint to dispatch background task via `asyncio.create_task()`
|
|
||||||
- [x] Replace `_trigger_pending_runs_stub` in `device_ws.py` with real `trigger_pending_runs` call
|
|
||||||
- [x] Add `croniter>=3.0.0` to `requirements.txt`
|
|
||||||
- [x] 23 unit + integration tests covering all code paths
|
|
||||||
- **Files:** `app/core/agent_runner.py`, `app/api/routes/agents.py`, `app/api/routes/device_ws.py`, `requirements.txt`, `tests/test_agent_runner.py`
|
|
||||||
- **Outcome:** Backend drives all agent execution — both local (via WS file request) and cloud (direct API calls — stub until Step 3.6).
|
|
||||||
|
|
||||||
### Step 3.5 — Chatbot Journey endpoint
|
|
||||||
- [x] Create `app/api/routes/agent_setup.py`:
|
|
||||||
- `POST /api/v1/agents/journey/start`:
|
|
||||||
- Body: `{ agent_type: "local"|"cloud", agent_id: str | None }`
|
|
||||||
- `agent_type`: which kind of agent this journey configures.
|
|
||||||
- `agent_id`: optional — if provided, the session is pre-seeded with the existing agent's `prompt_template` so the user can refine it. If absent, fresh journey.
|
|
||||||
- **No `data_types` field** — data types are determined through the conversation itself, not sent upfront.
|
|
||||||
- Creates a journey session (in-memory or Redis-backed)
|
|
||||||
- Returns first AI message: contextual question based on agent type
|
|
||||||
- Local: "What kind of files are in the directories you want to monitor? (emails, documents, logs, etc.)"
|
|
||||||
- Cloud: "What kind of emails/messages should I look for? (client communications, invoices, meeting notes, etc.)"
|
|
||||||
- Response: `{ session_id, message, done: false }`
|
|
||||||
- **Electron note:** `proxyPost` auto-converts camelCase keys to snake_case. Electron sends `{ agentType, agentId }` → backend receives `{ agent_type, agent_id }`.
|
|
||||||
- `POST /api/v1/agents/journey/message`:
|
|
||||||
- Body: `{ session_id, message }`
|
|
||||||
- AI processes user's answer, asks follow-up questions (max 5 turns)
|
|
||||||
- System prompt: "You are configuring a data extraction agent for a freelancer. Ask about file format, what data to extract (tasks, notes, timelines), naming conventions, priority rules, and any special mapping. After 3-5 questions, generate a detailed prompt_template."
|
|
||||||
- When AI determines enough context: `{ session_id, message: "Here's your configuration...", done: true, prompt_template: "..." }`
|
|
||||||
- The `prompt_template` is a structured instruction for the extraction LLM (e.g. "Extract tasks from email. Subject becomes task title. If body contains 'urgent' or 'ASAP', set priority to 'high'. Extract due dates if mentioned.")
|
|
||||||
- **Electron note:** `toCamelCase` converts the response → Electron reads `promptTemplate` from the final message and auto-fills the agent config panel. User clicks "Save & apply" which calls `agent.local.update` / `agent.cloud.update` tRPC mutation.
|
|
||||||
- **Files:** `app/api/routes/agent_setup.py`, `app/main.py`
|
|
||||||
- **Outcome:** Users configure AI prompts through guided conversation. Journey can refine an existing config when `agent_id` is provided. ✅
|
|
||||||
|
|
||||||
### Step 3.6 — Cloud provider integrations
|
|
||||||
- [x] Create `app/integrations/gmail.py`:
|
|
||||||
- `GmailClient`:
|
|
||||||
- `__init__(oauth_token)` — initializes Google API client
|
|
||||||
- `async fetch_messages(filter_config, since: datetime) -> list[EmailMessage]`
|
|
||||||
- `EmailMessage`: `{ id, subject, sender, body_text, date, labels }`
|
|
||||||
- Handles token refresh via Google OAuth2 refresh flow
|
|
||||||
- Respects `filter_config.labels`, `filter_config.date_range`, `filter_config.senders`
|
|
||||||
- [x] Create `app/integrations/ms_graph.py`:
|
|
||||||
- `MSGraphClient`:
|
|
||||||
- `__init__(oauth_token)` — initializes MS Graph client
|
|
||||||
- `async fetch_emails(filter_config, since: datetime) -> list[EmailMessage]` (Outlook)
|
|
||||||
- `async fetch_messages(filter_config, since: datetime) -> list[ChatMessage]` (Teams)
|
|
||||||
- `ChatMessage`: `{ id, content, sender, channel, date }`
|
|
||||||
- Handles token refresh via MSAL
|
|
||||||
- [x] Create `app/integrations/__init__.py` — factory: `get_provider(provider_name) -> GmailClient | MSGraphClient`
|
|
||||||
- **Dependencies:** `google-api-python-client`, `google-auth-oauthlib`, `msgraph-sdk`, `msal`
|
|
||||||
- **Files:** `app/integrations/gmail.py`, `app/integrations/ms_graph.py`, `app/integrations/__init__.py`
|
|
||||||
- **Outcome:** Backend can fetch emails/messages from Gmail, Outlook, and Teams.
|
|
||||||
|
|
||||||
### Step 3.7 — Agent scheduler
|
|
||||||
- [ ] Create `app/core/agent_scheduler.py`:
|
|
||||||
- Uses `APScheduler` (or simple asyncio loop) to check agent schedules
|
|
||||||
- Every 60s: query enabled agents where `last_run_at + cron_interval < now()`
|
|
||||||
- For each due agent:
|
|
||||||
- Check if user's device is online via `DeviceConnectionManager`
|
|
||||||
- If online: dispatch to `agent_runner`
|
|
||||||
- If offline: skip (will trigger on next `device_hello`)
|
|
||||||
- Locks: use PostgreSQL advisory locks to prevent duplicate runs in multi-instance deployments
|
|
||||||
- [ ] Integrate with FastAPI lifespan (start scheduler on app startup, shutdown gracefully)
|
|
||||||
- **Dependencies:** `apscheduler>=4.0`
|
|
||||||
- **Files:** `app/core/agent_scheduler.py`, `app/main.py`
|
|
||||||
- **Outcome:** Agents run automatically on their configured schedules.
|
|
||||||
|
|
||||||
### Step 3.8 — OAuth flow endpoints
|
|
||||||
- [ ] Create `app/api/routes/oauth.py`:
|
|
||||||
- `GET /api/v1/oauth/{provider}/authorize` — returns OAuth authorization URL
|
|
||||||
- Gmail: Google OAuth2 with `gmail.readonly` scope
|
|
||||||
- Outlook/Teams: MS identity platform with `Mail.Read`, `ChannelMessage.Read.All` scopes
|
|
||||||
- `GET /api/v1/oauth/{provider}/callback` — handles OAuth redirect
|
|
||||||
- Exchanges auth code for access + refresh tokens
|
|
||||||
- Encrypts tokens with Fernet (server-side key from settings)
|
|
||||||
- Returns encrypted token blob for storage in `CloudAgentConfig.oauth_token_encrypted`
|
|
||||||
- `POST /api/v1/oauth/{provider}/refresh` — refresh expired OAuth token
|
|
||||||
- **Files:** `app/api/routes/oauth.py`, `app/main.py`
|
|
||||||
- **Outcome:** Users can connect Gmail/Teams/Outlook accounts securely.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Phase 3 — Verification
|
|
||||||
|
|
||||||
| # | Scenario | Expected |
|
|
||||||
|---|---|---|
|
|
||||||
| 1 | **Agent CRUD** | Create/read/update/delete local and cloud configs; tier limits enforced (free=2, pro=10) |
|
|
||||||
| 2 | **WS device connect** | Electron connects → `device_hello` → backend stores connection → triggers overdue runs |
|
|
||||||
| 3 | **Local agent run** | Backend sends `agent_run` → Electron reads files → `agent_data` → backend AI extracts → `tool_call(insert)` → Electron persists with `isAiSuggested=1` |
|
|
||||||
| 4 | **Cloud agent run** | Backend fetches Gmail → AI extracts tasks → `tool_call(insert)` → Electron persists |
|
|
||||||
| 5 | **Device binding** | Local agent config with `device_id=A` only triggers when device A is connected |
|
|
||||||
| 6 | **Chatbot Journey** | Start journey → 3-5 Q&A turns → produces valid `prompt_template` |
|
|
||||||
| 7 | **Schedule** | Agent with `schedule_cron="0 */6 * * *"` runs every 6h when device is online |
|
|
||||||
| 8 | **Offline resilience** | Device offline → runs skipped → device reconnects → overdue runs trigger immediately |
|
|
||||||
| 9 | **OAuth flow** | Gmail authorize → callback → token encrypted → stored in config → fetch emails works |
|
|
||||||
|
|
||||||
### Phase 3 — New Dependencies
|
|
||||||
|
|
||||||
| Package | Purpose |
|
|
||||||
|---|---|
|
|
||||||
| `google-api-python-client` | Gmail API access |
|
|
||||||
| `google-auth-oauthlib` | Gmail OAuth2 flow |
|
|
||||||
| `msgraph-sdk` | Outlook + Teams API access |
|
|
||||||
| `msal` | MS identity platform auth |
|
|
||||||
| `apscheduler>=4.0` | Agent scheduling |
|
|
||||||
| `cryptography` (Fernet) | OAuth token encryption at rest |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ~~Phase 5 — Shared Memory~~ (SUPERSEDED)
|
|
||||||
|
|
||||||
> **This phase has been fully replaced by `V3_MIGRATION_PLAN.md`.**
|
|
||||||
>
|
|
||||||
> - Chat WS fix → V3 Step 5 (Unified WS Handler — single multiplexed socket)
|
|
||||||
> - Agent memory → V3 Steps 6–7 (Cloud-side MemGPT-style memory in PostgreSQL + pgvector, encrypted at rest with per-user Fernet key)
|
|
||||||
>
|
|
||||||
> The on-device KV approach (Electron SQLite `agent_memory` table) is no longer the target architecture.
|
|
||||||
> See `V3_MIGRATION_PLAN.md` for the current plan.
|
|
||||||
572
BACKEND_PLAN.md
572
BACKEND_PLAN.md
@@ -1,572 +0,0 @@
|
|||||||
# Backend Plan — Adiuva Cloud API
|
|
||||||
|
|
||||||
> **Separate repository.** This document defines the FastAPI backend that the Electron app communicates with.
|
|
||||||
>
|
|
||||||
> The backend owns: orchestration logic, chat agent intelligence, prompt IP, auth, billing, E2E backup blob storage, cloud storage (encrypted blobs), cloud vector store, and plugin marketplace.
|
|
||||||
> The backend NEVER persists user data in plaintext. Cloud storage blobs are E2E encrypted before upload — the backend only verifies integrity, never decrypts.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Project Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
adiuva-api/
|
|
||||||
├── app/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ ├── main.py # FastAPI entry + CORS + lifespan + router includes
|
|
||||||
│ ├── core/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── agent_registry.py # Base classes + singleton registry
|
|
||||||
│ │ ├── orchestrator.py # LLM-based intent router
|
|
||||||
│ │ ├── execution_plan.py # Plan builder + cache
|
|
||||||
│ │ └── plugin_loader.py # Dynamic agent loading
|
|
||||||
│ ├── agents/ # Chat agents (proprietary logic + prompts)
|
|
||||||
│ │ ├── __init__.py # Auto-registers all agents
|
|
||||||
│ │ ├── task_agent.py
|
|
||||||
│ │ ├── calendar_agent.py
|
|
||||||
│ │ ├── email_agent.py
|
|
||||||
│ │ └── analytics_agent.py
|
|
||||||
│ ├── api/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── routes/
|
|
||||||
│ │ │ ├── __init__.py
|
|
||||||
│ │ │ ├── chat.py # POST /chat + WS /chat/stream
|
|
||||||
│ │ │ ├── plans.py # GET /plans/playbook
|
|
||||||
│ │ │ ├── storage.py # CRUD cloud storage (E2E encrypted blobs)
|
|
||||||
│ │ │ ├── vectors.py # Upsert/search cloud vector store
|
|
||||||
│ │ │ ├── backup.py # PUT/GET /backup
|
|
||||||
│ │ │ ├── plugins.py # Plugin marketplace
|
|
||||||
│ │ │ ├── auth.py # Register/login/refresh
|
|
||||||
│ │ │ └── billing.py # Checkout/webhook/subscription
|
|
||||||
│ │ └── middleware/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── auth.py # JWT validation
|
|
||||||
│ │ ├── rate_limit.py # Tier-aware rate limiting
|
|
||||||
│ │ └── sanitizer.py # Strip prompt metadata from responses
|
|
||||||
│ ├── storage/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── blob_store.py # S3 for E2E encrypted blobs
|
|
||||||
│ │ ├── vector_store.py # Cloud vector store (Pinecone/Qdrant)
|
|
||||||
│ │ └── encryption.py # Integrity verification only — NO decryption
|
|
||||||
│ ├── marketplace/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── plugin_registry.py # Plugin catalog (metadata, versions, ratings)
|
|
||||||
│ │ ├── plugin_review.py # Review queue + approval workflow
|
|
||||||
│ │ └── revenue_share.py # 70/30 split tracking with Stripe Connect
|
|
||||||
│ ├── billing/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── stripe_service.py # Stripe checkout + webhooks
|
|
||||||
│ │ └── tier_manager.py # Feature matrix per tier
|
|
||||||
│ └── config/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ └── settings.py # Pydantic BaseSettings (env-based)
|
|
||||||
├── tests/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ ├── conftest.py # Fixtures: test client, mock agents, mock LLM
|
|
||||||
│ ├── test_orchestrator.py
|
|
||||||
│ ├── test_agents.py
|
|
||||||
│ ├── test_auth.py
|
|
||||||
│ ├── test_backup.py
|
|
||||||
│ ├── test_storage.py
|
|
||||||
│ └── test_plugins.py
|
|
||||||
├── alembic/ # DB migrations (auth/billing/marketplace tables only)
|
|
||||||
│ ├── alembic.ini
|
|
||||||
│ └── versions/
|
|
||||||
├── requirements.txt
|
|
||||||
├── Dockerfile
|
|
||||||
├── docker-compose.yml # App + PostgreSQL + Redis (dev)
|
|
||||||
├── .env.example
|
|
||||||
└── README.md
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step-by-Step Implementation
|
|
||||||
|
|
||||||
### Step 1 — Project scaffolding ✅
|
|
||||||
- [x] Initialize repo with the directory structure above
|
|
||||||
- [x] Write `requirements.txt`:
|
|
||||||
```
|
|
||||||
fastapi>=0.115.0
|
|
||||||
uvicorn[standard]>=0.34.0
|
|
||||||
langchain>=0.3.0
|
|
||||||
langchain-openai>=0.3.0
|
|
||||||
pydantic>=2.10.0
|
|
||||||
python-jose[cryptography]>=3.3.0
|
|
||||||
stripe>=11.0.0
|
|
||||||
boto3>=1.35.0
|
|
||||||
slowapi>=0.1.9
|
|
||||||
sqlalchemy>=2.0.0
|
|
||||||
asyncpg>=0.30.0
|
|
||||||
alembic>=1.14.0
|
|
||||||
bcrypt>=4.2.0
|
|
||||||
python-dotenv>=1.0.0
|
|
||||||
httpx>=0.28.0
|
|
||||||
websockets>=14.0
|
|
||||||
pytest>=8.0.0
|
|
||||||
pytest-asyncio>=0.24.0
|
|
||||||
```
|
|
||||||
- [x] Write `app/main.py`: FastAPI app with CORS (allow `app://`, `http://localhost:*`), lifespan (init DB pool, init agent registry), include all routers under `/api/v1`
|
|
||||||
- [x] Write `app/config/settings.py`: `Settings(BaseSettings)` with fields: `DATABASE_URL`, `JWT_SECRET`, `JWT_ALGORITHM` (default HS256), `STRIPE_SECRET_KEY`, `STRIPE_WEBHOOK_SECRET`, `S3_BUCKET`, `S3_REGION`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `OPENAI_API_KEY`, `CORS_ORIGINS`, `ENV` (dev/prod), `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
|
|
||||||
- [x] Write `Dockerfile`: Python 3.12 slim, multi-stage (builder + runtime), non-root user
|
|
||||||
- [x] Write `docker-compose.yml`: app, postgres:16, optional redis
|
|
||||||
- [x] Write `.env.example`
|
|
||||||
- **Outcome:** Runnable FastAPI skeleton (returns 404 on all routes).
|
|
||||||
|
|
||||||
### Step 2 — Pydantic schemas (API contracts) ✅
|
|
||||||
- [x] Create `app/schemas.py` (mirrors `src/shared/api-types.ts` from Electron repo):
|
|
||||||
- `ChatRequest`: `message: str`, `context: ChatContext`, `execution_mode: Literal['direct', 'plan']`
|
|
||||||
- `ChatContext`: `user_profile: dict`, `relevant_documents: list[str]`, `recent_tasks: list[dict]`, `conversation_history: list[dict]`
|
|
||||||
- `ChatResponse`: `response: str`, `actions: list[PlanAction]`
|
|
||||||
- `PlanAction`: `type: Literal['create_record', 'update_record', 'delete_record', 'index_document', 'send_notification', 'call_agent']`, `table: str | None`, `data: dict | None`, `agent: str | None`
|
|
||||||
- `ExecutionPlan`: `agent: str`, `steps: list[PlanStep]`
|
|
||||||
- `PlanStep`: `action: str`, `prompt_template: str | None`, `variables: dict | None`, `data_from_step: int | None`
|
|
||||||
- `BackupMetadata`: `version: int`, `timestamp: int`, `checksum: str`, `chunk_count: int`
|
|
||||||
- `BillingTier`: `Literal['free', 'pro', 'power', 'team']`
|
|
||||||
- `AuthTokens`: `access_token: str`, `refresh_token: str`, `expires_at: int`
|
|
||||||
- `UserProfile`: `id: str`, `email: str`, `tier: BillingTier`
|
|
||||||
- `StorageRecord`: `id: str`, `user_id: str`, `table: str`, `blob: bytes`, `checksum: str`, `created_at: int`, `updated_at: int` — blob is always E2E encrypted by client
|
|
||||||
- `StorageRecordCreate`: `table: str`, `blob: bytes`, `checksum: str`
|
|
||||||
- `StorageRecordUpdate`: `blob: bytes`, `checksum: str`
|
|
||||||
- `VectorUpsertRequest`: `vectors: list[VectorItem]`
|
|
||||||
- `VectorItem`: `id: str`, `blob: bytes`, `checksum: str` — vector + metadata encrypted by client
|
|
||||||
- `VectorSearchRequest`: `query_blob: bytes`, `top_k: int = 10`
|
|
||||||
- `VectorSearchResponse`: `results: list[VectorSearchResult]`
|
|
||||||
- `VectorSearchResult`: `id: str`, `score: float`, `blob: bytes`
|
|
||||||
- `PluginManifest`: `id: str`, `name: str`, `description: str`, `version: str`, `author: str`, `permissions: list[str]`, `category: str`, `price_cents: int = 0`
|
|
||||||
- `PluginListResponse`: `plugins: list[PluginManifest]`, `total: int`, `page: int`
|
|
||||||
- `PluginInstallRequest`: `plugin_id: str`
|
|
||||||
- **Outcome:** All request/response models defined and validated.
|
|
||||||
|
|
||||||
### Step 3 — Agent Registry + base classes ✅
|
|
||||||
- [x] `app/core/agent_registry.py`:
|
|
||||||
- `BaseAgent(ABC)`:
|
|
||||||
- `user_id: str`, `shared_memory: dict`, `vector_store_context: list[str]`, `skills: list[str]`
|
|
||||||
- Abstract `get_name() -> str`, `get_description() -> str`
|
|
||||||
- `ChatAgent(BaseAgent)`:
|
|
||||||
- Abstract `async handle(query: str, context: dict) -> str`
|
|
||||||
- Abstract `get_tools() -> list` (LangChain tool definitions)
|
|
||||||
- Concrete `_tool_loop(llm, messages, tools, max_iter=5) -> str` — shared tool-calling loop
|
|
||||||
- `AgentRegistry` (singleton):
|
|
||||||
- `_agents: dict[str, ChatAgent]`
|
|
||||||
- `register(agent_class)` — decorator pattern
|
|
||||||
- `get(name) -> ChatAgent`
|
|
||||||
- `list_agents() -> list[dict]` — returns `[{name, description}]` for orchestrator prompt
|
|
||||||
- `async call_agent(name, query, context) -> str` — for inter-agent calls
|
|
||||||
- [x] Unit tests: register, get, list, call_agent with mock
|
|
||||||
- **Outcome:** Pluggable agent framework.
|
|
||||||
|
|
||||||
### Step 4 — Orchestrator ✅
|
|
||||||
- [x] `app/core/orchestrator.py`:
|
|
||||||
- `async classify_intent(message, context, registry) -> str`:
|
|
||||||
- System prompt: "You are an intent classifier. Given the user message and context, decide which agent to route to. Available agents: {registry.list_agents()}. Respond with just the agent name."
|
|
||||||
- Uses gpt-4o-mini via LangChain for low latency
|
|
||||||
- Falls back to `task_agent` if no clear match
|
|
||||||
- `async route_single(agent_name, message, context) -> ChatResponse`:
|
|
||||||
- Instantiates agent from registry
|
|
||||||
- Calls `agent.handle(message, context)`
|
|
||||||
- Returns response + any actions the agent produced
|
|
||||||
- `async route_pipeline(agent_names, message, context) -> ChatResponse`:
|
|
||||||
- Executes agents in sequence
|
|
||||||
- Each agent receives `{...context, previous_results: [...]}`
|
|
||||||
- Final synthesis via LLM: "Summarize these agent results into a coherent response"
|
|
||||||
- `async orchestrate(request: ChatRequest) -> ChatResponse | ExecutionPlan`:
|
|
||||||
- Main entry point
|
|
||||||
- Context is transparent to orchestrator — data may originate from local or cloud storage on the client side
|
|
||||||
- Classifies intent
|
|
||||||
- If `execution_mode == 'direct'`: route + return response
|
|
||||||
- If `execution_mode == 'plan'`: route + return execution plan with template IDs
|
|
||||||
- `async orchestrate_stream(request: ChatRequest) -> AsyncGenerator[str, None]`:
|
|
||||||
- Same as orchestrate but yields tokens for WebSocket streaming
|
|
||||||
- [x] Integration tests with mocked LLM and mocked agents
|
|
||||||
- **Outcome:** Intelligent routing with single-agent and pipeline modes.
|
|
||||||
|
|
||||||
### Step 5 — Execution Plan generator ✅
|
|
||||||
- [x] `app/core/execution_plan.py`:
|
|
||||||
- `PromptTemplateRegistry`: dict of `template_id -> prompt_text`. Templates are server-side only — client receives IDs.
|
|
||||||
- `ExecutionPlanBuilder`:
|
|
||||||
- `add_step(action, params) -> self`
|
|
||||||
- `add_llm_step(template_id, variables) -> self`
|
|
||||||
- `add_data_step(action, data_from_step) -> self`
|
|
||||||
- `build() -> ExecutionPlan` — validates step references
|
|
||||||
- `PlanCache`:
|
|
||||||
- In-memory LRU (maxsize=1000)
|
|
||||||
- `cache_plan(key, plan)`, `get_plan(key)`, `get_all_playbooks() -> list[ExecutionPlan]`
|
|
||||||
- Playbooks are pre-built plans for common operations (e.g., "create task from email", "generate weekly report")
|
|
||||||
- **Outcome:** Plans are cacheable as playbooks. Prompt IP never leaves the server.
|
|
||||||
|
|
||||||
### Step 6 — Chat Agents ✅
|
|
||||||
- [x] `app/agents/task_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
|
||||||
- Tools (8): `list_tasks(project_id, status, search, order_by)`, `create_task(title, description, status, priority, assignees, due_date, project_id, is_ai_suggested, is_approved)`, `update_task(task_id, ...)`, `delete_task(task_id)`, `list_tasks_due_today()`, `list_task_comments(task_id)`, `add_task_comment(task_id, author, content)`, `delete_task_comment(comment_id)`
|
|
||||||
- status: `todo|in_progress|done`; priority: `high|medium|low`; assignees: JSON-encoded string; due_date: ms timestamp
|
|
||||||
- Accepts flexible context; sentinel `-1` for optional integer update fields
|
|
||||||
- [x] `app/agents/timeline_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages project timelines (milestones): list, create, update, delete"
|
|
||||||
- Tools (4): `list_timelines(project_id)`, `create_timeline(project_id, title, date, is_ai_suggested, is_approved)`, `update_timeline(timeline_id, ...)`, `delete_timeline(timeline_id)`
|
|
||||||
- `project_id` is required for create; date is a ms timestamp; supports AI-suggestion + approval workflow
|
|
||||||
- [x] `app/agents/project_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages projects: list, get, create, update, archive, delete"
|
|
||||||
- Tools (6): `list_projects(client_id, include_archived)`, `list_all_projects()`, `get_project(project_id)`, `create_project(name, client_id)`, `update_project(project_id, ...)`, `delete_project(project_id)`
|
|
||||||
- status: `active|archived`; prefers archive over deletion (docstring guard on delete)
|
|
||||||
- [x] `app/agents/note_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages notes: list, get, create, update, delete"
|
|
||||||
- Tools (5): `list_notes(project_id)`, `get_note(note_id)`, `create_note(title, content, project_id)`, `update_note(note_id, ...)`, `delete_note(note_id)`
|
|
||||||
- content is Markdown; `get_note` should be called before update to preserve existing content
|
|
||||||
- [x] `app/agents/__init__.py`: imports all four agent modules to trigger `@registry.register` decorators
|
|
||||||
- [x] Unit tests per agent with mocked LLM (registration, names, tool counts, handle(), direct tool invocation)
|
|
||||||
- **Outcome:** Four domain-specific agents matching the UI data model (Tasks, Timelines, Projects, Notes), all registered and tested.
|
|
||||||
|
|
||||||
### Step 7 — Storage Layer ✅
|
|
||||||
- [x] `app/storage/blob_store.py`:
|
|
||||||
- `BlobStore`: `async upload`, `async download`, `async delete` (idempotent), `async list_keys`
|
|
||||||
- Keys: `{user_id}/{table}/{record_id}` — backend never inspects blob content
|
|
||||||
- boto3 S3 with SSE-S3 at-rest encryption; client checksum stored in S3 object metadata
|
|
||||||
- [x] `app/storage/vector_store.py`:
|
|
||||||
- `VectorStore`: `async upsert`, `async search`, `async delete`
|
|
||||||
- Pinecone (default, `namespace=user_id`) or Qdrant (`user_id` payload filter) — runtime-configurable
|
|
||||||
- 32-dim SHA-256-derived float vector; blob stored as base64 in metadata/payload
|
|
||||||
- ANN on encrypted data: known accuracy trade-off, documented
|
|
||||||
- [x] `app/storage/encryption.py`:
|
|
||||||
- `verify_checksum(blob, checksum) -> bool` — SHA-256 + `hmac.compare_digest` (constant-time)
|
|
||||||
- `reject_if_tampered(blob, checksum)` — raises `HTTP 400` on mismatch
|
|
||||||
- Backend NEVER holds decryption keys
|
|
||||||
- [x] `app/schemas.py`: added `StorageRecord*`, `VectorItem`, `VectorUpsertRequest`, `VectorSearch*`, `Plugin*` schemas
|
|
||||||
- [x] `app/config/settings.py`: added `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
|
|
||||||
- [x] `requirements.txt`: added `moto[s3]`, `pinecone`, `qdrant-client`
|
|
||||||
- [x] 37 unit tests covering encryption, BlobStore (moto), VectorStore Pinecone, VectorStore Qdrant
|
|
||||||
- **Outcome:** Cloud storage layer that handles E2E encrypted blobs without ever accessing plaintext.
|
|
||||||
|
|
||||||
### Step 8 — API Routes ✅
|
|
||||||
|
|
||||||
#### 8a — Chat endpoint
|
|
||||||
- [x] `app/api/routes/chat.py`:
|
|
||||||
- `POST /api/v1/chat`:
|
|
||||||
- Request: `ChatRequest`
|
|
||||||
- Calls `orchestrate(request)` or `orchestrate()` + `build_plan()`
|
|
||||||
- Response: `ChatResponse` or `ExecutionPlan`
|
|
||||||
- `WebSocket /api/v1/chat/stream`:
|
|
||||||
- Client sends `ChatRequest` as first JSON frame
|
|
||||||
- Server yields token strings via `orchestrate_stream()`
|
|
||||||
- Final frame: JSON `ChatResponse` with `{"done": true, "response": "...", "actions": [...]}`
|
|
||||||
- Heartbeat ping every 30s to keep connection alive
|
|
||||||
|
|
||||||
#### 8b — Plans endpoint
|
|
||||||
- [x] `app/api/routes/plans.py`:
|
|
||||||
- `GET /api/v1/plans/playbook`: Returns all playbooks available for the user's tier
|
|
||||||
- `GET /api/v1/plans/playbook/{plan_id}`: Returns a specific plan
|
|
||||||
|
|
||||||
#### 8c — Storage endpoint (cloud records)
|
|
||||||
- [x] `app/api/routes/storage.py`:
|
|
||||||
- `POST /api/v1/storage/records`: Create encrypted record
|
|
||||||
- Request: `StorageRecordCreate`
|
|
||||||
- Verifies checksum, stores blob in S3, inserts metadata row in PostgreSQL
|
|
||||||
- Response: `{id: str, created_at: int}`
|
|
||||||
- `GET /api/v1/storage/records`: List record metadata (no blobs)
|
|
||||||
- Query params: `table: str`, `page: int`, `limit: int`
|
|
||||||
- Response: `list[{id, table, checksum, created_at, updated_at}]`
|
|
||||||
- `GET /api/v1/storage/records/{id}`: Download encrypted blob
|
|
||||||
- Response: blob bytes + `X-Checksum` header
|
|
||||||
- `PUT /api/v1/storage/records/{id}`: Update encrypted blob
|
|
||||||
- Request: `StorageRecordUpdate`
|
|
||||||
- `DELETE /api/v1/storage/records/{id}`: Delete record + S3 blob
|
|
||||||
- All routes enforce tier cloud_storage_gb quota via `TierManager.check_quota(user_id)`
|
|
||||||
|
|
||||||
#### 8d — Vectors endpoint (cloud vector store)
|
|
||||||
- [x] `app/api/routes/vectors.py`:
|
|
||||||
- `POST /api/v1/storage/vectors/upsert`:
|
|
||||||
- Request: `VectorUpsertRequest`
|
|
||||||
- Verifies checksums, delegates to `VectorStore.upsert()`
|
|
||||||
- Response: `{upserted: int}`
|
|
||||||
- `POST /api/v1/storage/vectors/search`:
|
|
||||||
- Request: `VectorSearchRequest`
|
|
||||||
- Delegates to `VectorStore.search()`
|
|
||||||
- Response: `VectorSearchResponse`
|
|
||||||
- `DELETE /api/v1/storage/vectors`:
|
|
||||||
- Request: `{ids: list[str]}`
|
|
||||||
|
|
||||||
#### 8e — Backup endpoint
|
|
||||||
- [x] `app/api/routes/backup.py`:
|
|
||||||
- `PUT /api/v1/backup`: Accepts binary blob + metadata headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Stores in S3 keyed by `{user_id}/{timestamp}`. Enforces tier limits:
|
|
||||||
- Free: 0 (no backup)
|
|
||||||
- Pro: 5 GB
|
|
||||||
- Power: 25 GB
|
|
||||||
- Team: unlimited
|
|
||||||
- `GET /api/v1/backup`: Returns latest blob for authenticated user. Supports `If-Modified-Since`.
|
|
||||||
- `GET /api/v1/backup/history`: Returns list of `BackupMetadata` (no blobs).
|
|
||||||
- `DELETE /api/v1/backup/{backup_id}`: Delete specific backup.
|
|
||||||
|
|
||||||
#### 8f — Plugins endpoint
|
|
||||||
- [x] `app/api/routes/plugins.py`:
|
|
||||||
- `GET /api/v1/plugins`:
|
|
||||||
- Query params: `category: str | None`, `q: str | None`, `page: int`, `sort: Literal['rating', 'installs', 'newest']`
|
|
||||||
- Response: `PluginListResponse`
|
|
||||||
- Available from Power tier and above
|
|
||||||
- `GET /api/v1/plugins/{id}`:
|
|
||||||
- Response: `PluginManifest` + ratings + install count
|
|
||||||
- `POST /api/v1/plugins/{id}/install`:
|
|
||||||
- Request: `PluginInstallRequest`
|
|
||||||
- Records installation for the user (billing tracking, analytics)
|
|
||||||
- If plugin is paid: triggers Stripe Connect charge + revenue split (70% developer, 30% platform)
|
|
||||||
- Response: `{ok: true, download_url: str}` — signed S3 URL for plugin package
|
|
||||||
- `DELETE /api/v1/plugins/{id}/install`:
|
|
||||||
- Unregisters installation
|
|
||||||
|
|
||||||
#### 8g — Auth endpoint
|
|
||||||
- [x] `app/api/routes/auth.py`:
|
|
||||||
- `POST /api/v1/auth/register`: `{email, password}` → bcrypt hash → insert user → return `AuthTokens`
|
|
||||||
- `POST /api/v1/auth/login`: Validate credentials → return `AuthTokens`
|
|
||||||
- `POST /api/v1/auth/refresh`: Rotate refresh token → return new `AuthTokens`
|
|
||||||
- `GET /api/v1/auth/me`: Return `UserProfile` for current JWT
|
|
||||||
|
|
||||||
#### 8h — Billing endpoint
|
|
||||||
- [x] `app/api/routes/billing.py`:
|
|
||||||
- `POST /api/v1/billing/checkout`: Creates Stripe checkout session → returns URL
|
|
||||||
- `POST /api/v1/billing/webhook`: Handles Stripe webhooks (subscription lifecycle)
|
|
||||||
- `GET /api/v1/billing/subscription`: Returns current subscription info
|
|
||||||
- `DELETE /api/v1/billing/subscription`: Cancels subscription
|
|
||||||
|
|
||||||
- **Outcome:** Complete REST + WebSocket API covering orchestration, storage, vectors, backup, marketplace.
|
|
||||||
|
|
||||||
### Step 9 — Middleware
|
|
||||||
|
|
||||||
#### 9a — Auth middleware
|
|
||||||
- [x] `app/api/middleware/auth.py`:
|
|
||||||
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
|
|
||||||
- Validates JWT signature, expiry, extracts `user_id` and `tier`
|
|
||||||
- Raises `401` on invalid/expired token
|
|
||||||
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
|
||||||
|
|
||||||
#### 9b — Rate limiter
|
|
||||||
- [x] `app/api/middleware/rate_limit.py`:
|
|
||||||
- Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)`
|
|
||||||
- Tier-based limits:
|
|
||||||
- Free: 20 req/min
|
|
||||||
- Pro: 60 req/min
|
|
||||||
- Power: 120 req/min
|
|
||||||
- Team: 200 req/seat/min
|
|
||||||
- Custom 429 response with `Retry-After` header
|
|
||||||
|
|
||||||
#### 9c — Sanitizer
|
|
||||||
- [x] `app/api/middleware/sanitizer.py`:
|
|
||||||
- Response middleware that scans response bodies
|
|
||||||
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
|
|
||||||
- Pattern-based detection + exact match against known prompt fingerprints
|
|
||||||
- Logs sanitization events for monitoring
|
|
||||||
|
|
||||||
- **Outcome:** Secure, rate-limited API with prompt IP protection.
|
|
||||||
|
|
||||||
### Step 10 — Plugin Marketplace ✅
|
|
||||||
- [x] `app/marketplace/plugin_registry.py`:
|
|
||||||
- `PluginRegistry`:
|
|
||||||
- `async list_plugins(category, query, page, sort) -> PluginListResponse`
|
|
||||||
- `async get_plugin(plugin_id) -> PluginManifest | None`
|
|
||||||
- `async submit_plugin(manifest: PluginManifest, package_s3_key: str) -> str` — returns plugin_id, sets status = 'pending_review'
|
|
||||||
- `async approve_plugin(plugin_id) -> None` — admin only, sets status = 'approved'
|
|
||||||
- `async reject_plugin(plugin_id, reason: str) -> None`
|
|
||||||
- [x] `app/marketplace/plugin_review.py`:
|
|
||||||
- `ReviewQueue`:
|
|
||||||
- `async get_pending() -> list[dict]`
|
|
||||||
- `async submit_review(plugin_id, reviewer_id, decision, notes) -> None`
|
|
||||||
- Security checklist enforced before approval: manifest schema valid, permissions are from allowed set, no binary blobs in manifest
|
|
||||||
- [x] `app/marketplace/revenue_share.py`:
|
|
||||||
- `RevenueShare`:
|
|
||||||
- `async record_install(plugin_id, user_id, amount_cents) -> None`
|
|
||||||
- `async payout_developer(plugin_id, period) -> None` — Stripe Connect transfer: 70% to developer
|
|
||||||
- `async get_earnings(developer_id, period) -> dict`
|
|
||||||
- **Outcome:** Plugin marketplace with catalog, review workflow, and revenue split.
|
|
||||||
|
|
||||||
### Step 11 — Billing & Tier management ✅
|
|
||||||
- [x] `app/billing/stripe_service.py`:
|
|
||||||
- `create_checkout_session(user_id, tier) -> str`
|
|
||||||
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
|
|
||||||
- `get_subscription(user_id) -> dict | None`
|
|
||||||
- `cancel_subscription(user_id) -> None`
|
|
||||||
- [x] `app/billing/tier_manager.py`:
|
|
||||||
- `TierManager`:
|
|
||||||
- Feature matrix:
|
|
||||||
```python
|
|
||||||
FEATURES = {
|
|
||||||
'free': {
|
|
||||||
'agents': 3,
|
|
||||||
'batch_active': 2,
|
|
||||||
'cloud_storage_gb': 0,
|
|
||||||
'backup_gb': 0,
|
|
||||||
'providers': 1,
|
|
||||||
'batch_builder': False,
|
|
||||||
'plugin_marketplace': False,
|
|
||||||
'sso': False,
|
|
||||||
},
|
|
||||||
'pro': {
|
|
||||||
'agents': -1, # unlimited
|
|
||||||
'batch_active': 10,
|
|
||||||
'cloud_storage_gb': 5,
|
|
||||||
'backup_gb': 5,
|
|
||||||
'providers': -1,
|
|
||||||
'batch_builder': False,
|
|
||||||
'plugin_marketplace': False,
|
|
||||||
'sso': False,
|
|
||||||
},
|
|
||||||
'power': {
|
|
||||||
'agents': -1,
|
|
||||||
'batch_active': -1, # unlimited
|
|
||||||
'cloud_storage_gb': 25,
|
|
||||||
'backup_gb': 25,
|
|
||||||
'providers': -1,
|
|
||||||
'batch_builder': True,
|
|
||||||
'plugin_marketplace': True,
|
|
||||||
'sso': False,
|
|
||||||
},
|
|
||||||
'team': {
|
|
||||||
'agents': -1,
|
|
||||||
'batch_active': -1,
|
|
||||||
'cloud_storage_gb': -1,
|
|
||||||
'backup_gb': -1,
|
|
||||||
'providers': -1,
|
|
||||||
'batch_builder': True,
|
|
||||||
'plugin_marketplace': True,
|
|
||||||
'sso': True,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
```
|
|
||||||
- `get_tier(user_id) -> BillingTier`
|
|
||||||
- `check_feature(user_id, feature) -> bool`
|
|
||||||
- `get_rate_limit(tier) -> int`
|
|
||||||
- `check_quota(user_id) -> bool` — checks cloud_storage_gb current usage vs limit
|
|
||||||
- [x] `app/billing/__init__.py`: exports `stripe_service` and `tier_manager` singletons
|
|
||||||
- [x] `app/api/routes/billing.py`: refactored to delegate to `StripeService`
|
|
||||||
- [x] `app/api/routes/storage.py` and `backup.py`: `_check_quota` now delegates to `tier_manager.enforce_quota` / `enforce_backup_quota`
|
|
||||||
- **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat).
|
|
||||||
|
|
||||||
### Step 12 — Database (auth/billing/marketplace only)
|
|
||||||
- [x] PostgreSQL schema via Alembic:
|
|
||||||
- `users`: `id UUID PK`, `email UNIQUE`, `password_hash`, `tier` (default 'free'), `stripe_customer_id`, `created_at`, `updated_at`
|
|
||||||
- `refresh_tokens`: `id UUID PK`, `user_id FK`, `token_hash`, `expires_at`, `created_at`
|
|
||||||
- `subscriptions`: `id UUID PK`, `user_id FK`, `stripe_subscription_id`, `tier`, `status`, `current_period_end`, `created_at`
|
|
||||||
- `backup_metadata`: `id UUID PK`, `user_id FK`, `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes`, `created_at`
|
|
||||||
- `storage_records`: `id UUID PK`, `user_id FK`, `table_name VARCHAR`, `s3_key`, `checksum`, `size_bytes`, `created_at`, `updated_at` — metadata only, no plaintext
|
|
||||||
- `plugins`: `id UUID PK`, `name`, `description`, `version`, `author_id FK`, `category`, `status` (pending_review/approved/rejected), `price_cents`, `s3_package_key`, `install_count`, `avg_rating`, `created_at`
|
|
||||||
- `plugin_installations`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `installed_at`
|
|
||||||
- `plugin_reviews`: `id UUID PK`, `plugin_id FK`, `reviewer_id FK`, `decision`, `notes`, `reviewed_at`
|
|
||||||
- `revenue_events`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `amount_cents`, `developer_share_cents`, `stripe_transfer_id`, `created_at`
|
|
||||||
- [x] Initial Alembic migration
|
|
||||||
- [x] SQLAlchemy models in `app/models.py`
|
|
||||||
- **Outcome:** Auth, billing, storage metadata, and marketplace persistence. Zero user data in plaintext.
|
|
||||||
|
|
||||||
### Step 13 — Testing & deployment ✅
|
|
||||||
- [x] `tests/conftest.py`: TestClient fixture, mock LLM fixture (`AsyncMock` returning canned responses), mock agent fixture, test DB (SQLite in-memory for speed), mock S3 (moto), mock Pinecone
|
|
||||||
- [x] `tests/test_orchestrator.py`: classify_intent routing, single agent, pipeline, plan mode
|
|
||||||
- [x] `tests/test_agents.py`: each agent with mocked tools
|
|
||||||
- [x] `tests/test_auth.py`: register → login → access protected → refresh → expired token
|
|
||||||
- [x] `tests/test_backup.py`: upload → download → history → delete, tier limit enforcement
|
|
||||||
- [x] `tests/test_storage.py`: create record → list → download → update → delete, checksum rejection, quota enforcement
|
|
||||||
- [x] `tests/test_plugins.py`: list plugins, install, uninstall, revenue event creation, tier gate (free user blocked)
|
|
||||||
- [x] `Dockerfile` optimized for production (gunicorn + uvicorn workers)
|
|
||||||
- [x] GitHub Actions CI: lint (ruff), test (pytest), build Docker image
|
|
||||||
- **Outcome:** Fully tested, deployable backend.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## API Contract Summary
|
|
||||||
|
|
||||||
| Method | Endpoint | Auth | Request | Response |
|
|
||||||
|--------|----------|------|---------|----------|
|
|
||||||
| POST | `/api/v1/auth/register` | No | `{email, password}` | `AuthTokens` |
|
|
||||||
| POST | `/api/v1/auth/login` | No | `{email, password}` | `AuthTokens` |
|
|
||||||
| POST | `/api/v1/auth/refresh` | No | `{refresh_token}` | `AuthTokens` |
|
|
||||||
| GET | `/api/v1/auth/me` | JWT | — | `UserProfile` |
|
|
||||||
| POST | `/api/v1/chat` | JWT | `ChatRequest` | `ChatResponse \| ExecutionPlan` |
|
|
||||||
| WS | `/api/v1/chat/stream` | JWT | `ChatRequest` (first frame) | Token stream + final JSON |
|
|
||||||
| GET | `/api/v1/plans/playbook` | JWT | — | `ExecutionPlan[]` |
|
|
||||||
| GET | `/api/v1/plans/playbook/:id` | JWT | — | `ExecutionPlan` |
|
|
||||||
| POST | `/api/v1/storage/records` | JWT | `StorageRecordCreate` | `{id, created_at}` |
|
|
||||||
| GET | `/api/v1/storage/records` | JWT | `?table&page&limit` | `RecordMeta[]` |
|
|
||||||
| GET | `/api/v1/storage/records/:id` | JWT | — | Binary blob |
|
|
||||||
| PUT | `/api/v1/storage/records/:id` | JWT | `StorageRecordUpdate` | `{ok: true}` |
|
|
||||||
| DELETE | `/api/v1/storage/records/:id` | JWT | — | `{ok: true}` |
|
|
||||||
| POST | `/api/v1/storage/vectors/upsert` | JWT | `VectorUpsertRequest` | `{upserted: int}` |
|
|
||||||
| POST | `/api/v1/storage/vectors/search` | JWT | `VectorSearchRequest` | `VectorSearchResponse` |
|
|
||||||
| DELETE | `/api/v1/storage/vectors` | JWT | `{ids: list[str]}` | `{ok: true}` |
|
|
||||||
| PUT | `/api/v1/backup` | JWT | Binary blob + headers | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/backup` | JWT | — | Binary blob |
|
|
||||||
| GET | `/api/v1/backup/history` | JWT | — | `BackupMetadata[]` |
|
|
||||||
| DELETE | `/api/v1/backup/:id` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/plugins` | JWT | `?category&q&page&sort` | `PluginListResponse` |
|
|
||||||
| GET | `/api/v1/plugins/:id` | JWT | — | `PluginManifest` + stats |
|
|
||||||
| POST | `/api/v1/plugins/:id/install` | JWT | `PluginInstallRequest` | `{ok, download_url}` |
|
|
||||||
| DELETE | `/api/v1/plugins/:id/install` | JWT | — | `{ok: true}` |
|
|
||||||
| POST | `/api/v1/billing/checkout` | JWT | `{tier}` | `{checkout_url}` |
|
|
||||||
| POST | `/api/v1/billing/webhook` | Stripe sig | Stripe event | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/billing/subscription` | JWT | — | Subscription info |
|
|
||||||
| DELETE | `/api/v1/billing/subscription` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/health` | No | — | `{status, version}` |
|
|
||||||
| GET | `/api/v1/agents/catalog` | JWT | — | `AgentCatalogItem[]` |
|
|
||||||
| GET | `/api/v1/agents/local` | JWT | — | `LocalAgentConfigResponse[]` |
|
|
||||||
| POST | `/api/v1/agents/local` | JWT | `LocalAgentConfigCreate` | `LocalAgentConfigResponse` |
|
|
||||||
| PUT | `/api/v1/agents/local/{id}` | JWT | `LocalAgentConfigUpdate` | `LocalAgentConfigResponse` |
|
|
||||||
| DELETE | `/api/v1/agents/local/{id}` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/agents/cloud` | JWT | — | `CloudAgentConfigResponse[]` |
|
|
||||||
| POST | `/api/v1/agents/cloud` | JWT | `CloudAgentConfigCreate` | `CloudAgentConfigResponse` |
|
|
||||||
| PUT | `/api/v1/agents/cloud/{id}` | JWT | `CloudAgentConfigUpdate` | `CloudAgentConfigResponse` |
|
|
||||||
| DELETE | `/api/v1/agents/cloud/{id}` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/agents/runs` | JWT | `?agent_id&page&limit` | `AgentRunLogResponse[]` |
|
|
||||||
| POST | `/api/v1/agents/{id}/run` | JWT | — | `{ok: true, run_id}` |
|
|
||||||
| POST | `/api/v1/agents/journey/start` | JWT | `{agent_type, data_types}` | `{session_id, message, done}` |
|
|
||||||
| POST | `/api/v1/agents/journey/message` | JWT | `{session_id, message}` | `{session_id, message, done, prompt_template?}` |
|
|
||||||
| GET | `/api/v1/oauth/{provider}/authorize` | JWT | — | `{authorization_url}` |
|
|
||||||
| GET | `/api/v1/oauth/{provider}/callback` | — | OAuth code | `{encrypted_token}` |
|
|
||||||
| WS | `/api/v1/ws/device` | JWT | `device_hello` (first frame) | Agent trigger + tool_call frames |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Stack
|
|
||||||
|
|
||||||
| Layer | Technology |
|
|
||||||
|-------|-----------|
|
|
||||||
| Framework | FastAPI + Uvicorn |
|
|
||||||
| LLM | LangChain + langchain-openai |
|
|
||||||
| Auth | PyJWT + bcrypt + OAuth2 |
|
|
||||||
| Billing | stripe-python + Stripe Connect |
|
|
||||||
| Blob storage | boto3 (S3) |
|
|
||||||
| Vector store | Pinecone or Qdrant (configurable) |
|
|
||||||
| Database | PostgreSQL + SQLAlchemy + Alembic |
|
|
||||||
| Rate limiting | slowapi |
|
|
||||||
| Cloud integrations | google-api-python-client, msgraph-sdk, msal |
|
|
||||||
| Agent scheduling | APScheduler |
|
|
||||||
| Testing | pytest + pytest-asyncio + httpx + moto (S3 mock) |
|
|
||||||
| Deployment | Docker → fly.io / Railway / AWS ECS |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 3 — New Files
|
|
||||||
|
|
||||||
| File | Purpose |
|
|
||||||
|---|---|
|
|
||||||
| `app/models.py` | Add `LocalAgentConfig`, `CloudAgentConfig`, `AgentRunLog` models |
|
|
||||||
| `app/schemas.py` | Add agent config schemas + WS agent frame types |
|
|
||||||
| `app/api/routes/agents.py` | Agent CRUD endpoints (catalog, local, cloud, runs, manual trigger) |
|
|
||||||
| `app/api/routes/agent_setup.py` | Chatbot Journey endpoints (start + message) |
|
|
||||||
| `app/api/routes/device_ws.py` | Persistent device WS endpoint (`/api/v1/ws/device`) |
|
|
||||||
| `app/api/routes/oauth.py` | OAuth authorize/callback for Gmail, Teams, Outlook |
|
|
||||||
| `app/core/agent_runner.py` | Agent run orchestration — local (WS file request) + cloud (API fetch) |
|
|
||||||
| `app/core/device_manager.py` | `DeviceConnectionManager` — tracks active Electron WS connections |
|
|
||||||
| `app/core/agent_scheduler.py` | Periodic scheduler for agent cron triggers |
|
|
||||||
| `app/integrations/gmail.py` | Gmail API client (fetch messages with filters) |
|
|
||||||
| `app/integrations/ms_graph.py` | MS Graph client for Outlook emails + Teams messages |
|
|
||||||
| `app/integrations/__init__.py` | Provider factory |
|
|
||||||
|
|
||||||
> **Full Phase 3 step-by-step plan:** See `AI_REFACTOR_PLAN.md` Phase 3 section.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Development Rules
|
|
||||||
|
|
||||||
1. **NEVER persist user data in plaintext.** The DB stores only auth, billing, storage metadata, and marketplace data. User context arrives in requests and is discarded. Cloud blobs are E2E encrypted client-side — backend only stores opaque bytes.
|
|
||||||
2. **NEVER expose prompts.** System prompts are composed server-side from fragments. Responses are sanitized before sending. In plan mode, `prompt_template` fields are reference IDs only.
|
|
||||||
3. **NEVER decrypt user blobs.** `app/storage/encryption.py` only verifies checksums. No decryption key ever reaches the backend.
|
|
||||||
4. **Stateless request handling.** No server-side session state. All context comes from the client + JWT.
|
|
||||||
5. **Type hints everywhere.** All functions have full type annotations.
|
|
||||||
6. **Test every agent.** Each chat agent has unit tests with mocked LLM responses.
|
|
||||||
7. **Structured logging.** JSON logs with request ID correlation.
|
|
||||||
8. **Tier gates are enforced server-side.** Never trust client-reported tier. Always fetch from DB via `TierManager.get_tier(user_id)`.
|
|
||||||
9. **One step at a time.** Implement one numbered step per session. When the step is fully done, mark all its checkboxes as `[x]` in this file and commit with message `step N complete: <outcome line>`.
|
|
||||||
@@ -739,7 +739,7 @@ adiuva-api/
|
|||||||
│ │
|
│ │
|
||||||
│ ├── core/ # Orchestration engine
|
│ ├── core/ # Orchestration engine
|
||||||
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
|
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
|
||||||
│ │ ├── llm.py # LiteLLM factory (get_llm, get_router_llm)
|
│ │ ├── llm.py # LiteLLM factory (get_llm)
|
||||||
│ │ ├── orchestrator.py # Intent classification & routing
|
│ │ ├── orchestrator.py # Intent classification & routing
|
||||||
│ │ └── execution_plan.py # Plan builder, templates, cache
|
│ │ └── execution_plan.py # Plan builder, templates, cache
|
||||||
│ │
|
│ │
|
||||||
|
|||||||
@@ -1,353 +0,0 @@
|
|||||||
# V3 Migration Plan — Multi-Agent AI Productivity App
|
|
||||||
|
|
||||||
> Incremental migration from current architecture to v3.
|
|
||||||
> Each step is self-contained, testable, and backwards-compatible.
|
|
||||||
> No BYOK — server manages all LLM keys.
|
|
||||||
> Memory encryption: server-side per-user Fernet key (Option A).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## General Rules
|
|
||||||
|
|
||||||
**Code Cleanup**: As you implement each step, remove any code that becomes unused or obsolete. This includes:
|
|
||||||
- Old functions/methods that are superseded by new ones
|
|
||||||
- Deprecated imports or modules
|
|
||||||
- Dead code paths
|
|
||||||
- Old test files no longer needed
|
|
||||||
|
|
||||||
This keeps the codebase clean and prevents confusion. When removing code, note it in the commit message if significant.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Decisions Log
|
|
||||||
|
|
||||||
| Topic | Decision |
|
|
||||||
|---|---|
|
|
||||||
| WS topology | Single multiplexed socket (merge chat into device WS) |
|
|
||||||
| LLM keys | Server-managed only, no user key passthrough |
|
|
||||||
| Memory encryption | Per-user server-generated Fernet key, encrypted at rest, decrypted in-memory |
|
|
||||||
| device_manager | Already multi-user correct (keyed by user_id), no structural change |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 1 — WS Frame Protocol (schemas.py)
|
|
||||||
|
|
||||||
**Goal**: Define the v3 frame vocabulary so all subsequent steps can import it.
|
|
||||||
|
|
||||||
**Changes**:
|
|
||||||
- `app/schemas.py` — Add to `WsFrameType` enum:
|
|
||||||
- `home_request`, `floating_request`
|
|
||||||
- `stream_start`, `stream_text`, `stream_block`, `stream_end`
|
|
||||||
- `floating_domain`
|
|
||||||
- `data_request`, `data_response`, `mutation`
|
|
||||||
- Add Pydantic models:
|
|
||||||
- `WsHomeRequest(type, message, conversation_history?)`
|
|
||||||
- `WsFloatingRequest(type, message, scope: {type, id?})`
|
|
||||||
- `WsStreamStart(type, request_id)`
|
|
||||||
- `WsStreamText(type, request_id, chunk)`
|
|
||||||
- `WsStreamBlock(type, request_id, block_type, data)`
|
|
||||||
- `WsStreamEnd(type, request_id, mutations?)`
|
|
||||||
- `WsFloatingDomain(type, request_id, domain)`
|
|
||||||
- Keep all existing frame types (backward compat).
|
|
||||||
|
|
||||||
**Files touched**: `app/schemas.py`
|
|
||||||
|
|
||||||
**Test**: Unit test that validates each new model serializes/deserializes correctly.
|
|
||||||
```
|
|
||||||
pytest tests/test_schemas_v3.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 1 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-1: add v3 ws frame protocol (schemas.py)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 2 — Agent Streaming + Tool Result Capture (agent_registry.py, agents/)
|
|
||||||
|
|
||||||
**Goal**: Agents can stream LLM tokens and expose structured tool results.
|
|
||||||
|
|
||||||
**Changes**:
|
|
||||||
- `app/core/agent_registry.py`:
|
|
||||||
- Add `_tool_loop_stream()` to `ChatAgent` — same logic as `_tool_loop()` but the **final** LLM call (when no more tool calls) uses `llm.astream()` and yields tokens.
|
|
||||||
- Add `self.tool_results: list[dict]` attribute to `ChatAgent.__init__()`.
|
|
||||||
- In both `_tool_loop` and `_tool_loop_stream`, capture raw `execute_on_client` results when tools run (store in `self.tool_results`).
|
|
||||||
- `app/agents/*.py` — Each agent's tools already return text summaries. No change to tools. The raw data capture happens at the `_tool_loop` level by intercepting `ToolMessage` content that comes from `execute_on_client`.
|
|
||||||
|
|
||||||
**Files touched**: `app/core/agent_registry.py`
|
|
||||||
|
|
||||||
**Test**: Unit test with mocked LLM that verifies `_tool_loop_stream()` yields tokens and `agent.tool_results` contains structured data after a tool call.
|
|
||||||
```
|
|
||||||
pytest tests/test_agent_streaming.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 2 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-2: add agent streaming and tool result capture (agent_registry.py)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 3 — Router Refactor (orchestrator.py)
|
|
||||||
|
|
||||||
**Goal**: Orchestrator returns agent name alongside execution, supports streaming.
|
|
||||||
|
|
||||||
**Changes**:
|
|
||||||
- `app/core/orchestrator.py`:
|
|
||||||
- Add `orchestrate_v3(user_id, message, context, mode)` that:
|
|
||||||
1. Calls `classify_intent()` (unchanged) -> `agent_name`
|
|
||||||
2. Instantiates agent via registry
|
|
||||||
3. Returns `(agent_name, agent_instance)` — caller drives execution
|
|
||||||
- Add `orchestrate_v3_stream(user_id, message, context)` -> `AsyncGenerator` that:
|
|
||||||
1. Calls `classify_intent()` -> `agent_name`
|
|
||||||
2. Calls `agent.handle_stream()` (uses `_tool_loop_stream`)
|
|
||||||
3. Yields `(agent_name, token)` tuples — first yield includes agent name for domain detection
|
|
||||||
- Keep `orchestrate()` and `orchestrate_stream()` unchanged (backward compat for POST /chat).
|
|
||||||
|
|
||||||
**Files touched**: `app/core/orchestrator.py`
|
|
||||||
|
|
||||||
**Test**: Unit test with mocked LLM and mocked registry that verifies `orchestrate_v3_stream` yields `(agent_name, token)` pairs.
|
|
||||||
```
|
|
||||||
pytest tests/test_orchestrator_v3.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 3 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-3: add router refactor with streaming support (orchestrator.py)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 4 — Output Formatting Layer (NEW: output_formatter.py)
|
|
||||||
|
|
||||||
**Goal**: Home and Floating responses diverge at this layer only.
|
|
||||||
|
|
||||||
### Block Types (from Electron app components)
|
|
||||||
|
|
||||||
The LLM outputs a JSON block stream. Each block has a `type` field that maps to
|
|
||||||
an Electron renderer component. The server validates and forwards these blocks.
|
|
||||||
|
|
||||||
**Text block** — streamed immediately, word-by-word:
|
|
||||||
```json
|
|
||||||
{ "type": "text", "content": "Here's your task summary..." }
|
|
||||||
```
|
|
||||||
|
|
||||||
**Chart blocks** — buffered until complete, validated, sent as `stream_block`.
|
|
||||||
Chart types match shadcn/ui Recharts wrappers used in the Electron app:
|
|
||||||
```json
|
|
||||||
{ "type": "chart", "chartType": "<type>", "title": "...", "data": [...], "config": {...} }
|
|
||||||
```
|
|
||||||
Supported `chartType` values:
|
|
||||||
- `area` — Area chart (shadcn AreaChart)
|
|
||||||
- `bar` — Bar chart (shadcn BarChart)
|
|
||||||
- `line` — Line chart (shadcn LineChart)
|
|
||||||
- `pie` — Pie chart (shadcn PieChart)
|
|
||||||
- `radar` — Radar chart (shadcn RadarChart)
|
|
||||||
- `radial` — Radial/gauge chart (shadcn RadialChart)
|
|
||||||
|
|
||||||
`data` is an array of objects with keys matching the chart's dataKey config.
|
|
||||||
`config` follows the shadcn ChartConfig format: `{ [dataKey]: { label, color } }`.
|
|
||||||
|
|
||||||
**Entity blocks** — server serializes from `agent.tool_results` (not LLM-generated data):
|
|
||||||
```json
|
|
||||||
{ "type": "entity_ref", "entity": "task" }
|
|
||||||
```
|
|
||||||
The server resolves this by looking up the structured data from the agent's
|
|
||||||
tool call results and emitting a `stream_block` with the full entity data.
|
|
||||||
|
|
||||||
Supported entity types (matching Electron component types):
|
|
||||||
- `task` — TaskRow component (`TaskItem`: id, title, status, priority, assignee, dueDate, projectId, ...)
|
|
||||||
- `project` — Project card (id, name, clientId, status)
|
|
||||||
- `note` — Note card (id, title, createdAt, projectId)
|
|
||||||
- `timeline` — Timeline card (GanttTimeline: id, title, date, projectId, isAiSuggested, isApproved)
|
|
||||||
|
|
||||||
**Table block** — buffered, validated:
|
|
||||||
```json
|
|
||||||
{ "type": "table", "headers": ["Col1", "Col2"], "rows": [["val1", "val2"]] }
|
|
||||||
```
|
|
||||||
|
|
||||||
**Timeline block** — buffered, validated (renders via GanttChart component):
|
|
||||||
```json
|
|
||||||
{ "type": "timeline", "timelines": [{ "id": "...", "title": "...", "date": 1234567890 }] }
|
|
||||||
```
|
|
||||||
|
|
||||||
### Changes
|
|
||||||
|
|
||||||
- `app/core/output_formatter.py` (new file):
|
|
||||||
- `HomeFormatter`:
|
|
||||||
- Receives token stream from orchestrator
|
|
||||||
- Accumulates tokens into a JSON-aware buffer
|
|
||||||
- Detects block boundaries by `type` field:
|
|
||||||
- `text` -> yields `WsStreamText` immediately (streams content word-by-word)
|
|
||||||
- `chart` -> buffers until JSON complete, validates `chartType` against allowed set, yields `WsStreamBlock`
|
|
||||||
- `entity_ref` -> looks up data from `agent.tool_results`, serializes full entity, yields `WsStreamBlock`
|
|
||||||
- `table` -> buffers, validates headers/rows structure, yields `WsStreamBlock`
|
|
||||||
- `timeline` -> buffers, validates timeline objects, yields `WsStreamBlock`
|
|
||||||
- Invalid blocks are logged and skipped (never crash the stream)
|
|
||||||
- `FloatingFormatter`:
|
|
||||||
- Receives `agent_name` from orchestrator
|
|
||||||
- Maps agent name to domain (deterministic, by code — no LLM):
|
|
||||||
- `task_agent` -> `"tasks"`
|
|
||||||
- `timeline_agent` -> `"timelines"`
|
|
||||||
- `note_agent` -> `"notes"`
|
|
||||||
- `project_agent` -> `"projects"`
|
|
||||||
- Yields `WsFloatingDomain` immediately
|
|
||||||
- Then yields `WsStreamText` for all tokens (text-only, no blocks)
|
|
||||||
|
|
||||||
**Files touched**: `app/core/output_formatter.py` (new)
|
|
||||||
|
|
||||||
**Test**: Unit test that feeds a mock token stream through each formatter and asserts correct frame output sequence.
|
|
||||||
```
|
|
||||||
pytest tests/test_output_formatter.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 4 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-4: add output formatting layer (output_formatter.py)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 5 — Unified WS Handler (device_ws.py, chat.py, main.py)
|
|
||||||
|
|
||||||
**Goal**: Single multiplexed WebSocket handles device frames + Home/Floating chat.
|
|
||||||
|
|
||||||
**Changes**:
|
|
||||||
- `app/api/routes/device_ws.py`:
|
|
||||||
- Extend `_message_loop` dispatch to handle `home_request` and `floating_request`:
|
|
||||||
- On `home_request`: set `ws_context` executor, call `orchestrate_v3_stream`, pipe through `HomeFormatter`, send frames back on same socket.
|
|
||||||
- On `floating_request`: same, but pipe through `FloatingFormatter`.
|
|
||||||
- Wrap both in try/finally to clear `ws_context`.
|
|
||||||
- Each request gets a `request_id` (UUID) for frame correlation.
|
|
||||||
- Concurrent requests from same client are supported (each runs as an async task).
|
|
||||||
- `app/api/routes/chat.py`:
|
|
||||||
- Remove `chat_stream` WS endpoint and any related helper functions that were only used by it.
|
|
||||||
- Keep `POST /chat` endpoint unchanged (REST fallback).
|
|
||||||
- Clean up any unused imports.
|
|
||||||
- `app/main.py`:
|
|
||||||
- No change needed (device_ws router already registered).
|
|
||||||
|
|
||||||
**Files touched**: `app/api/routes/device_ws.py`, `app/api/routes/chat.py`, `app/main.py`
|
|
||||||
|
|
||||||
**Test**: Integration test with a WebSocket test client that:
|
|
||||||
1. Connects to `/api/v1/ws/device`
|
|
||||||
2. Sends `device_hello`
|
|
||||||
3. Sends `home_request` -> receives `stream_start`, `stream_text`*, `stream_end`
|
|
||||||
4. Sends `floating_request` -> receives `floating_domain`, `stream_text`*, `stream_end`
|
|
||||||
5. Verifies `tool_call`/`tool_result` round-trip still works during chat
|
|
||||||
```
|
|
||||||
pytest tests/test_ws_unified.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 5 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-5: unify ws handler (device_ws.py, chat.py)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 6 — Memory Models + Migration (models.py, alembic)
|
|
||||||
|
|
||||||
**Goal**: Database tables for 4-tier memory, with per-user encryption key.
|
|
||||||
|
|
||||||
**Changes**:
|
|
||||||
- `app/models.py`:
|
|
||||||
- Add `encryption_key` column to `User` model (Fernet key, generated on registration).
|
|
||||||
- Add `MemoryCore` model: `id, user_id, key, value_encrypted, updated_at`
|
|
||||||
- Add `MemoryAssociative` model: `id, user_id, content_encrypted, embedding (Vector(1536)), entity_type, entity_id, updated_at`
|
|
||||||
- Add `MemoryEpisodic` model: `id, user_id, summary_encrypted, session_id, created_at`
|
|
||||||
- Add `MemoryProactive` model: `id, user_id, pattern_encrypted, confidence, source, created_at`
|
|
||||||
- `alembic/versions/` — New migration adding the 4 memory tables + user encryption_key column.
|
|
||||||
- `app/api/routes/auth.py` — On user registration, generate and store a Fernet key.
|
|
||||||
|
|
||||||
**Files touched**: `app/models.py`, `alembic/versions/xxx_add_memory_tables.py`, `app/api/routes/auth.py`
|
|
||||||
|
|
||||||
**Test**: Run migration up/down, verify tables exist with correct columns.
|
|
||||||
```
|
|
||||||
alembic upgrade head && alembic downgrade -1 && alembic upgrade head
|
|
||||||
pytest tests/test_memory_models.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 6 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-6: add memory models and migration (models.py, alembic)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 7 — Memory Middleware (NEW: memory_middleware.py)
|
|
||||||
|
|
||||||
**Goal**: Enrich every Router call with memory context, store interactions after.
|
|
||||||
|
|
||||||
**Changes**:
|
|
||||||
- `app/core/memory_middleware.py` (new file):
|
|
||||||
- `MemoryMiddleware` class with:
|
|
||||||
- `enrich_context(user_id, message) -> dict` (pre-LLM):
|
|
||||||
1. Load core memory (user prefs) — always injected
|
|
||||||
2. Embed `message`, search `MemoryAssociative` via pgvector — top-k relevant
|
|
||||||
3. Fetch recent `MemoryEpisodic` entries — last N sessions
|
|
||||||
4. Fetch active `MemoryProactive` patterns — above confidence threshold
|
|
||||||
5. Return merged context dict
|
|
||||||
- `store_episode(user_id, session_id, message, response)` (post-LLM):
|
|
||||||
1. Summarize interaction (short LLM call or heuristic)
|
|
||||||
2. Encrypt and store in `MemoryEpisodic`
|
|
||||||
3. Embed interaction, encrypt and upsert in `MemoryAssociative`
|
|
||||||
- `update_core(user_id, key, value)` — explicit preference update
|
|
||||||
- All read/write operations encrypt/decrypt using the user's Fernet key from `User.encryption_key`
|
|
||||||
- `app/api/routes/device_ws.py` — Update `home_request` and `floating_request` handlers:
|
|
||||||
- Before orchestrator: `enriched = await memory.enrich_context(user_id, message)`
|
|
||||||
- After response complete: `await memory.store_episode(user_id, ...)`
|
|
||||||
|
|
||||||
**Files touched**: `app/core/memory_middleware.py` (new), `app/api/routes/device_ws.py`
|
|
||||||
|
|
||||||
**Test**: Unit test with seeded memory rows that verifies:
|
|
||||||
1. `enrich_context` returns core prefs + associative matches + episodic summaries
|
|
||||||
2. `store_episode` creates encrypted rows that can be decrypted with the user's key
|
|
||||||
3. End-to-end WS test: send `home_request`, verify memory enrichment is passed to orchestrator
|
|
||||||
```
|
|
||||||
pytest tests/test_memory_middleware.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 7 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-7: add memory middleware (memory_middleware.py, device_ws.py)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Summary
|
|
||||||
|
|
||||||
| Step | Component | Effort | Depends On |
|
|
||||||
|------|-----------|--------|------------|
|
|
||||||
| 1 | WS Frame Protocol | Low | — |
|
|
||||||
| 2 | Agent Streaming | Medium | Step 1 |
|
|
||||||
| 3 | Router Refactor | Medium | Step 2 |
|
|
||||||
| 4 | Output Formatter | High | Steps 1, 3 |
|
|
||||||
| 5 | Unified WS Handler | High | Steps 1–4 |
|
|
||||||
| 6 | Memory Models | Medium | — |
|
|
||||||
| 7 | Memory Middleware | High | Steps 5, 6 |
|
|
||||||
|
|
||||||
Steps 1–5 form the streaming pipeline. Steps 6–7 form the memory system.
|
|
||||||
Step 6 can run in parallel with Steps 2–4 (no dependencies).
|
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
"""Deprecate backend agent config tables.
|
||||||
|
|
||||||
|
The Electron client is now the source of truth for agent configuration
|
||||||
|
(directory, extract targets, batch interval, custom prompt). Backend keeps
|
||||||
|
billing checks and trigger/run logs only.
|
||||||
|
|
||||||
|
Revision ID: 9a1f2d0b6c7e
|
||||||
|
Revises: 818478c251dc
|
||||||
|
Create Date: 2026-03-16
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "9a1f2d0b6c7e"
|
||||||
|
down_revision: Union[str, None] = "818478c251dc"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
bind = op.get_bind()
|
||||||
|
inspector = sa.inspect(bind)
|
||||||
|
existing = set(inspector.get_table_names())
|
||||||
|
|
||||||
|
if "cloud_agent_configs" in existing:
|
||||||
|
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
|
||||||
|
op.drop_table("cloud_agent_configs")
|
||||||
|
|
||||||
|
if "local_agent_configs" in existing:
|
||||||
|
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
|
||||||
|
op.drop_table("local_agent_configs")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"local_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("device_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"cloud_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"provider",
|
||||||
|
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
||||||
|
sa.Column("filter_config", sa.JSON, nullable=True),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
"""Import all agent modules to trigger @registry.register decorators."""
|
"""Expose tool modules used by deep orchestrator-worker graphs."""
|
||||||
|
|
||||||
from app.agents import timeline_agent, note_agent, project_agent, task_agent
|
from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent
|
||||||
|
|
||||||
__all__ = ["timeline_agent", "note_agent", "project_agent", "task_agent"]
|
__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"]
|
||||||
|
|||||||
85
app/agents/filesystem_agent.py
Normal file
85
app/agents/filesystem_agent.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""Filesystem agent — tools for reading local directories and files on Electron.
|
||||||
|
|
||||||
|
These tools delegate to the Electron client via ``execute_on_client()`` using
|
||||||
|
the same WS tool-call round-trip pattern as CRUD tools. The Electron app
|
||||||
|
handles actual disk I/O and responds with ``tool_result`` frames.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_directory(path: str) -> str:
|
||||||
|
"""List files and folders in a local directory on the user's device.
|
||||||
|
|
||||||
|
Returns a formatted listing of entries with name, type (file/directory),
|
||||||
|
and full path.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="list_directory",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
entries: list[dict[str, Any]] = result.get("entries", [])
|
||||||
|
if not entries:
|
||||||
|
return f"Directory '{path}' is empty or does not exist."
|
||||||
|
lines: list[str] = []
|
||||||
|
for entry in entries:
|
||||||
|
entry_type = entry.get("type", "unknown")
|
||||||
|
entry_name = entry.get("name", "")
|
||||||
|
entry_path = entry.get("path", "")
|
||||||
|
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
||||||
|
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def read_file_content(path: str) -> str:
|
||||||
|
"""Read the text content of a local file on the user's device.
|
||||||
|
|
||||||
|
Returns the file content as a string. Large files may be truncated
|
||||||
|
by the Electron client.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="read_file_content",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
content: str = result.get("content", "")
|
||||||
|
if not content:
|
||||||
|
return f"File '{path}' is empty or could not be read."
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_file_metadata(path: str) -> str:
|
||||||
|
"""Get metadata for a local file: size, creation date, modification date, extension.
|
||||||
|
|
||||||
|
Returns a formatted summary of the file's metadata.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="get_file_metadata",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
size = result.get("size", "unknown")
|
||||||
|
created = result.get("createdAt", "unknown")
|
||||||
|
modified = result.get("modifiedAt", "unknown")
|
||||||
|
extension = result.get("extension", "unknown")
|
||||||
|
name = result.get("name", path)
|
||||||
|
return (
|
||||||
|
f"File: {name}\n"
|
||||||
|
f" Extension: {extension}\n"
|
||||||
|
f" Size: {size} bytes\n"
|
||||||
|
f" Created: {created}\n"
|
||||||
|
f" Modified: {modified}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
FILESYSTEM_TOOLS: list[Any] = [
|
||||||
|
list_directory,
|
||||||
|
read_file_content,
|
||||||
|
get_file_metadata,
|
||||||
|
]
|
||||||
@@ -2,17 +2,23 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
from app.core.llm import embed
|
||||||
from app.core.llm import embed, get_llm
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
NOTE_SYSTEM_PROMPT = (
|
||||||
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
||||||
"and delete Markdown notes in their workspace.\n\n"
|
"and delete Markdown notes in their workspace.\n\n"
|
||||||
"Rules:\n"
|
"Rules:\n"
|
||||||
@@ -22,6 +28,7 @@ _SYSTEM_PROMPT = (
|
|||||||
" before appending or replacing sections\n"
|
" before appending or replacing sections\n"
|
||||||
" - list_notes without project_id returns all notes; scope with project_id\n"
|
" - list_notes without project_id returns all notes; scope with project_id\n"
|
||||||
" when the user is working within a specific project\n"
|
" when the user is working within a specific project\n"
|
||||||
|
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
|
||||||
" - Do not fabricate note content — reflect what the user provides or what\n"
|
" - Do not fabricate note content — reflect what the user provides or what\n"
|
||||||
" is already in the note (retrieved via get_note)."
|
" is already in the note (retrieved via get_note)."
|
||||||
)
|
)
|
||||||
@@ -30,10 +37,11 @@ _SYSTEM_PROMPT = (
|
|||||||
@tool
|
@tool
|
||||||
async def list_notes(project_id: str = "") -> str:
|
async def list_notes(project_id: str = "") -> str:
|
||||||
"""List notes, optionally scoped to a project by project_id."""
|
"""List notes, optionally scoped to a project by project_id."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="notes",
|
table="notes",
|
||||||
filters={"projectId": project_id or None},
|
filters={"projectId": normalized_project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
@@ -122,23 +130,10 @@ async def delete_note(note_id: str) -> str:
|
|||||||
return f"Note {note_id} deleted."
|
return f"Note {note_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
NOTE_TOOLS: list[Any] = [
|
||||||
class NoteAgent(ChatAgent):
|
list_notes,
|
||||||
def get_name(self) -> str:
|
get_note,
|
||||||
return "note_agent"
|
create_note,
|
||||||
|
update_note,
|
||||||
def get_description(self) -> str:
|
delete_note,
|
||||||
return "Manages notes: list, get, create, update, delete"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return [list_notes, get_note, create_note, update_note, delete_note]
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = get_llm()
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
|
|||||||
@@ -2,17 +2,13 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
from app.core.llm import get_llm
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
PROJECT_SYSTEM_PROMPT = (
|
||||||
"You are a project management assistant. You help users create, find,\n"
|
"You are a project management assistant. You help users create, find,\n"
|
||||||
"update, and archive projects in their workspace.\n\n"
|
"update, and archive projects in their workspace.\n\n"
|
||||||
"Rules:\n"
|
"Rules:\n"
|
||||||
@@ -137,16 +133,7 @@ async def delete_project(project_id: str) -> str:
|
|||||||
return f"Project {project_id} permanently deleted."
|
return f"Project {project_id} permanently deleted."
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
PROJECT_TOOLS: list[Any] = [
|
||||||
class ProjectAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "project_agent"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Manages projects: list, get, create, update, archive, delete"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return [
|
|
||||||
list_projects,
|
list_projects,
|
||||||
list_all_projects,
|
list_all_projects,
|
||||||
get_project,
|
get_project,
|
||||||
@@ -154,13 +141,3 @@ class ProjectAgent(ChatAgent):
|
|||||||
update_project,
|
update_project,
|
||||||
delete_project,
|
delete_project,
|
||||||
]
|
]
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = get_llm()
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
|
|||||||
@@ -2,18 +2,23 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
from app.core.llm import get_llm
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
TASK_SYSTEM_PROMPT = (
|
||||||
"You are a task management assistant for a project workspace.\n"
|
"You are a task management assistant for a project workspace.\n"
|
||||||
"You create, update, list, and track tasks and their comments.\n\n"
|
"You create, update, list, and track tasks and their comments.\n\n"
|
||||||
"Rules:\n"
|
"Rules:\n"
|
||||||
@@ -24,7 +29,7 @@ _SYSTEM_PROMPT = (
|
|||||||
" - project_id is optional; link to a project when the user mentions one\n"
|
" - project_id is optional; link to a project when the user mentions one\n"
|
||||||
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
||||||
" did not explicitly request; 0 otherwise\n"
|
" did not explicitly request; 0 otherwise\n"
|
||||||
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
|
" - is_ai_suggested: 1 only when proactively proposing a task the user did not explicitly request; 0 otherwise\n"
|
||||||
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
||||||
" - For update_task, use -1 for integer fields you do not want to change\n"
|
" - For update_task, use -1 for integer fields you do not want to change\n"
|
||||||
" - Always confirm the action in plain, user-friendly language."
|
" - Always confirm the action in plain, user-friendly language."
|
||||||
@@ -43,11 +48,12 @@ async def list_tasks(
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
||||||
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="tasks",
|
table="tasks",
|
||||||
filters={
|
filters={
|
||||||
"projectId": project_id or None,
|
"projectId": normalized_project_id or None,
|
||||||
"status": status or None,
|
"status": status or None,
|
||||||
"search": search or None,
|
"search": search or None,
|
||||||
"orderBy": order_by or None,
|
"orderBy": order_by or None,
|
||||||
@@ -73,7 +79,6 @@ async def create_task(
|
|||||||
due_date: int = 0,
|
due_date: int = 0,
|
||||||
project_id: str = "",
|
project_id: str = "",
|
||||||
is_ai_suggested: int = 0,
|
is_ai_suggested: int = 0,
|
||||||
is_approved: int = 0,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a new task.
|
"""Create a new task.
|
||||||
title: task title (required)
|
title: task title (required)
|
||||||
@@ -84,7 +89,6 @@ async def create_task(
|
|||||||
due_date: Unix timestamp in milliseconds; 0 means no due date
|
due_date: Unix timestamp in milliseconds; 0 means no due date
|
||||||
project_id: optional UUID of the parent project
|
project_id: optional UUID of the parent project
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
is_approved: 0 until the user confirms; 1 when confirmed
|
|
||||||
"""
|
"""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="insert",
|
action="insert",
|
||||||
@@ -98,7 +102,6 @@ async def create_task(
|
|||||||
"dueDate": due_date or None,
|
"dueDate": due_date or None,
|
||||||
"projectId": project_id or None,
|
"projectId": project_id or None,
|
||||||
"isAiSuggested": is_ai_suggested,
|
"isAiSuggested": is_ai_suggested,
|
||||||
"isApproved": is_approved,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
@@ -118,12 +121,10 @@ async def update_task(
|
|||||||
assignees: str = "",
|
assignees: str = "",
|
||||||
due_date: int = -1,
|
due_date: int = -1,
|
||||||
project_id: str = "",
|
project_id: str = "",
|
||||||
is_approved: int = -1,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Update fields on an existing task. Only pass fields you want to change.
|
"""Update fields on an existing task. Only pass fields you want to change.
|
||||||
task_id: the task's UUID (required)
|
task_id: the task's UUID (required)
|
||||||
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
||||||
is_approved: -1 means unchanged; 0 or 1 sets the value
|
|
||||||
"""
|
"""
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if title:
|
if title:
|
||||||
@@ -140,8 +141,6 @@ async def update_task(
|
|||||||
updates["dueDate"] = due_date or None
|
updates["dueDate"] = due_date or None
|
||||||
if project_id:
|
if project_id:
|
||||||
updates["projectId"] = project_id
|
updates["projectId"] = project_id
|
||||||
if is_approved != -1:
|
|
||||||
updates["isApproved"] = is_approved
|
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="update",
|
action="update",
|
||||||
table="tasks",
|
table="tasks",
|
||||||
@@ -209,8 +208,12 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
|||||||
table="taskComments",
|
table="taskComments",
|
||||||
data={"taskId": task_id, "author": author, "content": content},
|
data={"taskId": task_id, "author": author, "content": content},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result.get("row", {})
|
||||||
return f"Comment added by {row['author']} on task {row['taskId']} (comment id: {row['id']})."
|
row_author = row.get("author", author)
|
||||||
|
# Electron payloads can vary (taskId vs task_id). Fall back to input task_id.
|
||||||
|
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
||||||
|
row_comment_id = row.get("id", "unknown")
|
||||||
|
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -223,16 +226,7 @@ async def delete_task_comment(comment_id: str) -> str:
|
|||||||
# ── Agent ─────────────────────────────────────────────────────────────
|
# ── Agent ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
TASK_TOOLS: list[Any] = [
|
||||||
class TaskAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "task_agent"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return [
|
|
||||||
list_tasks,
|
list_tasks,
|
||||||
create_task,
|
create_task,
|
||||||
update_task,
|
update_task,
|
||||||
@@ -242,13 +236,3 @@ class TaskAgent(ChatAgent):
|
|||||||
add_task_comment,
|
add_task_comment,
|
||||||
delete_task_comment,
|
delete_task_comment,
|
||||||
]
|
]
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = get_llm()
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
|
|||||||
@@ -2,24 +2,30 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
from app.core.llm import get_llm
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
TIMELINE_SYSTEM_PROMPT = (
|
||||||
"You are a project timeline assistant. Timelines are milestone dates that\n"
|
"You are a project timeline assistant. Timelines are milestone dates that\n"
|
||||||
"track progress on a project — they are not calendar events.\n\n"
|
"track progress on a project — they are not calendar events.\n\n"
|
||||||
"Rules:\n"
|
"Rules:\n"
|
||||||
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
||||||
|
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
|
||||||
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
||||||
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
||||||
" - is_approved: 0 until the user explicitly confirms; then 1\n"
|
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
||||||
" - For update_timeline, use -1 for integer fields you do not want to change\n"
|
" - For update_timeline, use -1 for integer fields you do not want to change\n"
|
||||||
" - Listing without a project_id returns all timelines across projects\n"
|
" - Listing without a project_id returns all timelines across projects\n"
|
||||||
" - Always echo the title and formatted date in your confirmation."
|
" - Always echo the title and formatted date in your confirmation."
|
||||||
@@ -29,10 +35,11 @@ _SYSTEM_PROMPT = (
|
|||||||
@tool
|
@tool
|
||||||
async def list_timelines(project_id: str = "") -> str:
|
async def list_timelines(project_id: str = "") -> str:
|
||||||
"""List timelines. Provide project_id to scope to a specific project."""
|
"""List timelines. Provide project_id to scope to a specific project."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="timelines",
|
table="timelines",
|
||||||
filters={"projectId": project_id or None},
|
filters={"projectId": normalized_project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
@@ -47,14 +54,12 @@ async def create_timeline(
|
|||||||
title: str,
|
title: str,
|
||||||
date: int,
|
date: int,
|
||||||
is_ai_suggested: int = 0,
|
is_ai_suggested: int = 0,
|
||||||
is_approved: int = 0,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a project timeline (milestone).
|
"""Create a project timeline (milestone).
|
||||||
project_id: REQUIRED UUID of the parent project
|
project_id: REQUIRED UUID of the parent project
|
||||||
title: descriptive name for the milestone
|
title: descriptive name for the milestone
|
||||||
date: Unix timestamp in milliseconds
|
date: Unix timestamp in milliseconds
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
is_approved: 0 until the user confirms
|
|
||||||
"""
|
"""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="insert",
|
action="insert",
|
||||||
@@ -64,7 +69,6 @@ async def create_timeline(
|
|||||||
"title": title,
|
"title": title,
|
||||||
"date": date,
|
"date": date,
|
||||||
"isAiSuggested": is_ai_suggested,
|
"isAiSuggested": is_ai_suggested,
|
||||||
"isApproved": is_approved,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
@@ -76,20 +80,16 @@ async def update_timeline(
|
|||||||
timeline_id: str,
|
timeline_id: str,
|
||||||
title: str = "",
|
title: str = "",
|
||||||
date: int = -1,
|
date: int = -1,
|
||||||
is_approved: int = -1,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Update a timeline. Only pass fields that should change.
|
"""Update a timeline. Only pass fields that should change.
|
||||||
timeline_id: UUID of the timeline (required)
|
timeline_id: UUID of the timeline (required)
|
||||||
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
||||||
is_approved: -1 means unchanged; 0 or 1 sets the approval state
|
|
||||||
"""
|
"""
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if title:
|
if title:
|
||||||
updates["title"] = title
|
updates["title"] = title
|
||||||
if date != -1:
|
if date != -1:
|
||||||
updates["date"] = date
|
updates["date"] = date
|
||||||
if is_approved != -1:
|
|
||||||
updates["isApproved"] = is_approved
|
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="update",
|
action="update",
|
||||||
table="timelines",
|
table="timelines",
|
||||||
@@ -106,23 +106,9 @@ async def delete_timeline(timeline_id: str) -> str:
|
|||||||
return f"Timeline {timeline_id} deleted."
|
return f"Timeline {timeline_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
TIMELINE_TOOLS: list[Any] = [
|
||||||
class TimelineAgent(ChatAgent):
|
list_timelines,
|
||||||
def get_name(self) -> str:
|
create_timeline,
|
||||||
return "timeline_agent"
|
update_timeline,
|
||||||
|
delete_timeline,
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Manages project timelines (milestones): list, create, update, delete"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return [list_timelines, create_timeline, update_timeline, delete_timeline]
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = get_llm()
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
|
|||||||
@@ -55,12 +55,15 @@ async def get_current_user(
|
|||||||
raise credentials_exc
|
raise credentials_exc
|
||||||
|
|
||||||
# Live tier lookup — subscription row is the authoritative source.
|
# Live tier lookup — subscription row is the authoritative source.
|
||||||
|
# In dev, fall back to 'power' (unlimited) so quota limits don't
|
||||||
|
# block local development when no Stripe subscription exists.
|
||||||
from app.models import Subscription, User # noqa: PLC0415
|
from app.models import Subscription, User # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
tier: str = result.scalar_one_or_none() or "free"
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
|
tier: str = result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
# Fetch name/surname from user row.
|
# Fetch name/surname from user row.
|
||||||
user_result = await db.execute(
|
user_result = await db.execute(
|
||||||
|
|||||||
@@ -1,54 +1,40 @@
|
|||||||
"""Chatbot Journey endpoints — guided conversation to build an agent prompt_template.
|
"""Chatbot Journey — WS-based guided conversation to build an agent prompt_template.
|
||||||
|
|
||||||
Endpoints:
|
The journey is driven entirely through WebSocket frames (no REST endpoints).
|
||||||
POST /agents/journey/start — start a new journey session
|
The device WS handler dispatches ``journey_start`` and ``journey_message``
|
||||||
POST /agents/journey/message — continue the conversation
|
frames to the functions exported here.
|
||||||
|
|
||||||
Sessions are stored in-memory with a 30-minute TTL. Stale entries are
|
|
||||||
cleaned up lazily on access. Upgrade to Redis for multi-instance deployments.
|
|
||||||
|
|
||||||
Journey flow:
|
Journey flow:
|
||||||
1. Client sends ``{ agent_type, agent_id? }`` to ``/start``.
|
1. FE sends ``journey_start`` frame with basic agent config (directory,
|
||||||
2. Server creates a session, calls the LLM with a contextual system prompt,
|
data_types, schedule).
|
||||||
and returns the first question.
|
2. Server creates an in-memory session, sets up a WS executor so the
|
||||||
3. Client sends follow-up messages to ``/message``.
|
setup LLM can use file-system tools, does a first directory scrape,
|
||||||
4. After 3-5 turns the LLM wraps up by emitting a ``prompt_template`` block
|
and sends back a ``journey_reply`` with the first question.
|
||||||
delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
|
3. FE sends ``journey_message`` frames for each user reply.
|
||||||
5. Server parses the block, sets ``done=True``, and returns the template.
|
4. Server appends the user message, calls the LLM (which may read files
|
||||||
|
via tools), and sends back a ``journey_reply``.
|
||||||
The ``prompt_template`` from the final response is meant to be stored in
|
5. After 3-5 turns the LLM wraps up by emitting a ``prompt_template``
|
||||||
``LocalAgentConfig.prompt_template`` or ``CloudAgentConfig.prompt_template``
|
block delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
|
||||||
by the Electron client (via the agent CRUD endpoints).
|
6. Server parses the block, sends ``journey_reply`` with ``done=True``
|
||||||
|
and the template. FE stores it locally.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
||||||
from app.core.llm import get_llm
|
from app.core.llm import get_llm
|
||||||
from app.db import get_session
|
|
||||||
from app.models import CloudAgentConfig, LocalAgentConfig
|
|
||||||
from app.schemas import (
|
|
||||||
JourneyMessageRequest,
|
|
||||||
JourneyResponse,
|
|
||||||
JourneyStartRequest,
|
|
||||||
UserProfile,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/agents/journey", tags=["agents"])
|
|
||||||
|
|
||||||
# ── Session TTL ───────────────────────────────────────────────────────────
|
# ── Session TTL ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||||
@@ -57,18 +43,25 @@ _SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
|||||||
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
||||||
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
||||||
|
|
||||||
# Maximum number of conversation turns before the LLM is nudged to wrap up.
|
# Minimum turns before we consider nudging the LLM to wrap up.
|
||||||
_MAX_TURNS: int = 5
|
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
||||||
|
# Hard cap to avoid infinite loops (safety net, not the primary stopping criterion).
|
||||||
|
_MAX_TURNS: int = 15
|
||||||
|
# Max tool-calling steps per LLM invocation.
|
||||||
|
_MAX_TOOL_STEPS: int = 6
|
||||||
|
|
||||||
# ── In-memory session store ───────────────────────────────────────────────
|
# ── In-memory session store ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _JourneySession:
|
class JourneySession:
|
||||||
session_id: str
|
session_id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
agent_type: str # "local" | "cloud"
|
agent_type: str # "local" | "cloud"
|
||||||
|
directory: str
|
||||||
|
data_types: list[str]
|
||||||
history: list[dict[str, Any]] = field(default_factory=list)
|
history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
system_prompt: str = ""
|
||||||
created_at: float = field(default_factory=time.monotonic)
|
created_at: float = field(default_factory=time.monotonic)
|
||||||
|
|
||||||
def is_expired(self) -> bool:
|
def is_expired(self) -> bool:
|
||||||
@@ -76,67 +69,84 @@ class _JourneySession:
|
|||||||
|
|
||||||
|
|
||||||
# session_id → session
|
# session_id → session
|
||||||
_sessions: dict[str, _JourneySession] = {}
|
_sessions: dict[str, JourneySession] = {}
|
||||||
|
|
||||||
|
|
||||||
def _get_session(session_id: str, user_id: str) -> _JourneySession:
|
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
||||||
"""Retrieve session; raise 404 on missing, expired, or wrong owner."""
|
"""Retrieve session; return None on missing, expired, or wrong owner."""
|
||||||
s = _sessions.get(session_id)
|
s = _sessions.get(session_id)
|
||||||
if s is None or s.is_expired():
|
if s is None or s.is_expired():
|
||||||
_sessions.pop(session_id, None)
|
_sessions.pop(session_id, None)
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
return None
|
||||||
if s.user_id != user_id:
|
if s.user_id != user_id:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
return None
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
# ── System prompt builder ─────────────────────────────────────────────────
|
# ── System prompt builder ─────────────────────────────────────────────────
|
||||||
|
|
||||||
_LOCAL_PREAMBLE = """\
|
|
||||||
What kind of files are in the directories you want to monitor? \
|
|
||||||
(for example: emails saved as .eml, documents in .pdf or .txt, markdown notes, etc.)"""
|
|
||||||
|
|
||||||
_CLOUD_PREAMBLE = """\
|
|
||||||
What kind of emails or messages should I look for? \
|
|
||||||
(for example: client communications, invoices, meeting notes, project updates, etc.)"""
|
|
||||||
|
|
||||||
_SYSTEM_PROMPT_TEMPLATE = """\
|
_SYSTEM_PROMPT_TEMPLATE = """\
|
||||||
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||||
Your job is to understand exactly what data the user wants to extract from their {source_description} \
|
Your job is to understand exactly what data the user wants to extract from their
|
||||||
and produce a detailed prompt_template that a separate AI will use as its instruction set.
|
local directory and produce a detailed prompt_template that a separate AI will use
|
||||||
|
as its instruction set.
|
||||||
|
|
||||||
Ask concise, focused questions one at a time. Cover these topics (not necessarily in this order):
|
The extraction agent already has this base behaviour built in:
|
||||||
1. The type and format of the source content.
|
- Reads each file using file-system tools.
|
||||||
2. Which data types to extract: tasks, notes, timelines, and/or projects.
|
- Creates records (tasks, notes, timelines, projects) via CRUD tools.
|
||||||
3. How fields should be mapped (e.g. email subject → task title).
|
- Sets isAiSuggested=1 on every new record.
|
||||||
4. Priority or status rules (e.g. "urgent" keyword → high priority).
|
- Only extracts data explicitly present in the files — it never invents information.
|
||||||
5. Any special handling, date extraction, or exclusions.
|
The user's custom prompt is appended AFTER this base behaviour, so focus on
|
||||||
|
what to look for and how to map it — not on the general extraction mechanics.
|
||||||
|
|
||||||
After 3-5 questions (when you have enough information), output the final prompt_template between \
|
You have access to file-system tools to explore the user's directory:
|
||||||
these exact markers on their own lines:
|
- list_directory: to see folder structure
|
||||||
|
- read_file_content: to peek at file contents
|
||||||
|
- get_file_metadata: to check file info
|
||||||
|
|
||||||
|
The user's configured directory is: {directory}
|
||||||
|
Target data types: {data_types}
|
||||||
|
|
||||||
|
IMPORTANT — project assignment is handled automatically by the main agent runner
|
||||||
|
before the custom prompt is ever used. You MUST NOT ask the user about projects,
|
||||||
|
projectId, or how to link records to projects. Never include projectId logic or
|
||||||
|
project creation instructions in the generated prompt_template.
|
||||||
|
|
||||||
|
Start by exploring the directory to understand its structure. Then ask concise,
|
||||||
|
focused questions one at a time. Cover these topics (not necessarily in this order):
|
||||||
|
1. The type and format of the source content (confirmed by your exploration).
|
||||||
|
2. How fields should be mapped (e.g. filename → task title).
|
||||||
|
3. Priority or status rules (e.g. "urgent" keyword → high priority).
|
||||||
|
4. Any special handling, date extraction, or exclusions.
|
||||||
|
|
||||||
|
Once you reach 90% confidence, output the final prompt_template between these exact
|
||||||
|
markers on their own lines:
|
||||||
|
|
||||||
{template_start}
|
{template_start}
|
||||||
<the complete extraction prompt here>
|
<the complete extraction prompt here>
|
||||||
{template_end}
|
{template_end}
|
||||||
|
|
||||||
The prompt_template must be a self-contained instruction for an AI that receives a document/email/message \
|
The prompt_template must be a self-contained instruction for an AI that reads files
|
||||||
and must return a JSON array of records in this shape:
|
and must perform CRUD operations using tools to create records. It should specify:
|
||||||
[{{ "table": "<tasks|notes|timelines|projects>", "data": {{ <field: value> }} }}, ...]
|
- What entity types to create (tasks, notes, timelines) — never projects.
|
||||||
|
- How to map file content to record fields (camelCase: title, status, priority,
|
||||||
|
dueDate, content, etc.) — never include projectId.
|
||||||
|
- That isAiSuggested must be set to 1 on every new record.
|
||||||
|
- Concrete examples of mappings based on what you discovered in the directory.
|
||||||
|
|
||||||
Rules for the generated template:
|
|
||||||
- Be explicit about field names (camelCase: title, status, priority, dueDate, projectId, content, etc.).
|
|
||||||
- Include concrete examples of mappings.
|
|
||||||
- Mention that Electron adds id/createdAt/updatedAt automatically.
|
|
||||||
- Set isAiSuggested: true and isApproved: false on every record.
|
|
||||||
{existing_section}\
|
{existing_section}\
|
||||||
Do not ask more than {max_turns} questions total. Start with your first question now.\
|
Keep asking clarifying questions until you are at least 90% confident you have
|
||||||
|
enough information to generate an accurate prompt_template. Once you reach that
|
||||||
|
confidence level, stop asking and produce the final template immediately.
|
||||||
|
Begin by exploring the directory, then ask your first question.\
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
|
def _build_system_prompt(
|
||||||
source_description = (
|
directory: str,
|
||||||
"files in local directories" if agent_type == "local" else "emails and messages from cloud providers"
|
data_types: list[str],
|
||||||
)
|
existing_template: str | None = None,
|
||||||
|
) -> str:
|
||||||
existing_section = (
|
existing_section = (
|
||||||
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
||||||
f"---\n{existing_template}\n---\n"
|
f"---\n{existing_template}\n---\n"
|
||||||
@@ -144,18 +154,14 @@ def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
|
|||||||
else ""
|
else ""
|
||||||
)
|
)
|
||||||
return _SYSTEM_PROMPT_TEMPLATE.format(
|
return _SYSTEM_PROMPT_TEMPLATE.format(
|
||||||
source_description=source_description,
|
directory=directory,
|
||||||
|
data_types=", ".join(data_types),
|
||||||
template_start=_TEMPLATE_START,
|
template_start=_TEMPLATE_START,
|
||||||
template_end=_TEMPLATE_END,
|
template_end=_TEMPLATE_END,
|
||||||
existing_section=existing_section,
|
existing_section=existing_section,
|
||||||
max_turns=_MAX_TURNS,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _first_question(agent_type: str) -> str:
|
|
||||||
return _LOCAL_PREAMBLE if agent_type == "local" else _CLOUD_PREAMBLE
|
|
||||||
|
|
||||||
|
|
||||||
# ── Template extraction ───────────────────────────────────────────────────
|
# ── Template extraction ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -168,11 +174,37 @@ def _extract_template(text: str) -> str | None:
|
|||||||
return text[start_idx:end_idx].strip() or None
|
return text[start_idx:end_idx].strip() or None
|
||||||
|
|
||||||
|
|
||||||
# ── LLM call ─────────────────────────────────────────────────────────────
|
# ── LLM call with tool support ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def _call_llm(system_prompt: str, history: list[dict[str, Any]]) -> str:
|
def _as_text(content: Any) -> str:
|
||||||
"""Build LangChain messages from history and invoke the LLM."""
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_llm_with_tools(
|
||||||
|
system_prompt: str,
|
||||||
|
history: list[dict[str, Any]],
|
||||||
|
tools: list[Any],
|
||||||
|
) -> str:
|
||||||
|
"""Build LangChain messages from history and invoke the LLM with tools.
|
||||||
|
|
||||||
|
Handles tool-calling loops: if the LLM calls tools, execute them and
|
||||||
|
continue until a final text response is produced.
|
||||||
|
"""
|
||||||
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||||
for turn in history:
|
for turn in history:
|
||||||
if turn["role"] == "user":
|
if turn["role"] == "user":
|
||||||
@@ -181,137 +213,194 @@ async def _call_llm(system_prompt: str, history: list[dict[str, Any]]) -> str:
|
|||||||
messages.append(AIMessage(content=turn["content"]))
|
messages.append(AIMessage(content=turn["content"]))
|
||||||
|
|
||||||
llm = get_llm(model=None, temperature=0.4)
|
llm = get_llm(model=None, temperature=0.4)
|
||||||
response = await llm.ainvoke(messages)
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
return response.content # type: ignore[return-value]
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
|
for _ in range(_MAX_TOOL_STEPS):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
return _as_text(response.content)
|
||||||
|
|
||||||
|
for call in response.tool_calls:
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"agent_setup: journey tool_call name=%s args=%s",
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:500],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_setup: journey tool_result name=%s output=%s",
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:800],
|
||||||
|
)
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
# Fallback: exceeded max steps.
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
return _as_text(final.content)
|
||||||
|
|
||||||
|
|
||||||
# ── Existing-config loader ────────────────────────────────────────────────
|
# ── Journey handlers (called from device_ws.py) ──────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def _load_existing_template(
|
async def handle_journey_start(
|
||||||
agent_id: str,
|
|
||||||
user_id: str,
|
user_id: str,
|
||||||
db: AsyncSession,
|
frame: dict[str, Any],
|
||||||
) -> str | None:
|
) -> dict[str, Any]:
|
||||||
"""Return the prompt_template of an existing agent config, or None."""
|
"""Handle a ``journey_start`` WS frame.
|
||||||
# Try local first, then cloud.
|
|
||||||
local_result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(
|
|
||||||
LocalAgentConfig.id == agent_id,
|
|
||||||
LocalAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
local = local_result.scalar_one_or_none()
|
|
||||||
if local is not None:
|
|
||||||
return local.prompt_template
|
|
||||||
|
|
||||||
cloud_result = await db.execute(
|
Creates a session, runs the setup LLM with directory exploration,
|
||||||
select(CloudAgentConfig).where(
|
and returns the ``journey_reply`` payload.
|
||||||
CloudAgentConfig.id == agent_id,
|
|
||||||
CloudAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
cloud = cloud_result.scalar_one_or_none()
|
|
||||||
return cloud.prompt_template if cloud is not None else None
|
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/start", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
|
||||||
async def start_journey(
|
|
||||||
body: JourneyStartRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> JourneyResponse:
|
|
||||||
"""Start a new Chatbot Journey session.
|
|
||||||
|
|
||||||
If ``agent_id`` is provided the session is pre-seeded with the existing
|
|
||||||
agent's ``prompt_template`` so the user can refine it.
|
|
||||||
"""
|
"""
|
||||||
# Load existing template (may be None).
|
agent_type = frame.get("agent_type", "local")
|
||||||
existing_template: str | None = None
|
directory = frame.get("directory", "")
|
||||||
if body.agent_id:
|
data_types = frame.get("data_types", [])
|
||||||
existing_template = await _load_existing_template(body.agent_id, current_user.id, db)
|
existing_template = frame.get("existing_template")
|
||||||
# If agent_id was given but not found, proceed without seeding (don't 404 —
|
|
||||||
# the user may be starting a fresh journey for a not-yet-persisted config).
|
|
||||||
|
|
||||||
system_prompt = _build_system_prompt(body.agent_type, existing_template)
|
# Use the session_id provided by the FE so the reply matches the
|
||||||
first_question = _first_question(body.agent_type)
|
# listener key; fall back to a generated one if absent.
|
||||||
|
session_id = frame.get("session_id") or str(uuid.uuid4())
|
||||||
|
system_prompt = _build_system_prompt(directory, data_types, existing_template)
|
||||||
|
|
||||||
session_id = str(uuid.uuid4())
|
session = JourneySession(
|
||||||
session = _JourneySession(
|
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
user_id=current_user.id,
|
user_id=user_id,
|
||||||
agent_type=body.agent_type,
|
agent_type=agent_type,
|
||||||
# Seed history with the AI's first question so it stays consistent.
|
directory=directory,
|
||||||
history=[{"role": "assistant", "content": first_question}],
|
data_types=data_types,
|
||||||
|
system_prompt=system_prompt,
|
||||||
)
|
)
|
||||||
# Store the system prompt inside the session for reuse in /message.
|
|
||||||
session.__dict__["_system_prompt"] = system_prompt # type: ignore[index]
|
# The LLM will explore the directory using FILESYSTEM_TOOLS via the
|
||||||
|
# ws_context executor (already set by the WS handler before calling us).
|
||||||
|
# Seed with an initial user message — some providers (e.g. GitHub Copilot)
|
||||||
|
# require at least one user/input message to be present.
|
||||||
|
seed_history: list[dict[str, Any]] = [
|
||||||
|
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
||||||
|
]
|
||||||
|
ai_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history=seed_history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
)
|
||||||
|
|
||||||
|
session.history.extend(seed_history)
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
_sessions[session_id] = session
|
_sessions[session_id] = session
|
||||||
|
|
||||||
logger.info("Journey session %s started for user %s (agent_type=%s)", session_id, current_user.id, body.agent_type)
|
logger.info(
|
||||||
return JourneyResponse(session_id=session_id, message=first_question, done=False)
|
"agent_setup: journey session %s started for user %s (directory=%s)",
|
||||||
|
session_id,
|
||||||
|
user_id,
|
||||||
|
directory,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if the LLM produced the template on the first turn (unlikely but possible).
|
||||||
@router.post("/message", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
|
||||||
async def send_journey_message(
|
|
||||||
body: JourneyMessageRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> JourneyResponse:
|
|
||||||
"""Send a message in an existing Chatbot Journey session.
|
|
||||||
|
|
||||||
The server appends the user's message to the conversation history,
|
|
||||||
calls the LLM, and appends the AI reply. When the LLM wraps up with a
|
|
||||||
``prompt_template`` block the response includes ``done=True`` and the
|
|
||||||
extracted template.
|
|
||||||
"""
|
|
||||||
session = _get_session(body.session_id, current_user.id)
|
|
||||||
system_prompt: str = session.__dict__.get("_system_prompt", _build_system_prompt(session.agent_type, None)) # type: ignore[assignment]
|
|
||||||
|
|
||||||
# Append user turn to history.
|
|
||||||
session.history.append({"role": "user", "content": body.message})
|
|
||||||
|
|
||||||
# Call the LLM with the full conversation so far.
|
|
||||||
ai_reply = await _call_llm(system_prompt, session.history)
|
|
||||||
|
|
||||||
# Append AI turn.
|
|
||||||
session.history.append({"role": "assistant", "content": ai_reply})
|
|
||||||
|
|
||||||
# Check if the LLM produced the final template.
|
|
||||||
prompt_template = _extract_template(ai_reply)
|
prompt_template = _extract_template(ai_reply)
|
||||||
done = prompt_template is not None
|
done = prompt_template is not None
|
||||||
|
|
||||||
# Strip the sentinel markers from the message shown to the user.
|
|
||||||
display_message = ai_reply
|
display_message = ai_reply
|
||||||
if done:
|
if done:
|
||||||
display_message = (
|
display_message = (
|
||||||
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
or "Here is your agent configuration. You can save it or continue refining."
|
or "Here is your agent configuration. You can save it or continue refining."
|
||||||
)
|
)
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
|
||||||
if done:
|
return {
|
||||||
logger.info("Journey session %s completed for user %s", body.session_id, current_user.id)
|
"type": "journey_reply",
|
||||||
# Clean up the session immediately on completion.
|
"session_id": session_id,
|
||||||
_sessions.pop(body.session_id, None)
|
"message": display_message,
|
||||||
else:
|
"done": done,
|
||||||
# Nudge the LLM to wrap up after max turns.
|
"prompt_template": prompt_template,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_journey_message(
|
||||||
|
user_id: str,
|
||||||
|
frame: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Handle a ``journey_message`` WS frame.
|
||||||
|
|
||||||
|
Appends the user message, calls the LLM, and returns the
|
||||||
|
``journey_reply`` payload.
|
||||||
|
"""
|
||||||
|
session_id = frame.get("session_id", "")
|
||||||
|
message = frame.get("message", "")
|
||||||
|
|
||||||
|
session = get_journey_session(session_id, user_id)
|
||||||
|
if session is None:
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": "Journey session not found or expired. Please start a new setup.",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Append user turn.
|
||||||
|
session.history.append({"role": "user", "content": message})
|
||||||
|
|
||||||
|
# Call the LLM with tools.
|
||||||
|
ai_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=session.system_prompt,
|
||||||
|
history=session.history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
)
|
||||||
|
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
|
||||||
|
# Check if the LLM produced the final template.
|
||||||
|
prompt_template = _extract_template(ai_reply)
|
||||||
|
done = prompt_template is not None
|
||||||
|
|
||||||
|
# If the LLM didn't produce a template, nudge it once it has asked enough
|
||||||
|
# questions (>= _MIN_TURNS_BEFORE_NUDGE) or hits the hard safety cap.
|
||||||
|
if not done:
|
||||||
turns = sum(1 for t in session.history if t["role"] == "user")
|
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||||
if turns >= _MAX_TURNS:
|
if turns >= _MAX_TURNS:
|
||||||
# Add a system-level nudge as a hidden user message.
|
nudge_content = (
|
||||||
session.history.append({
|
|
||||||
"role": "user",
|
|
||||||
"content": (
|
|
||||||
"[System: You have enough information. Please generate the final "
|
"[System: You have enough information. Please generate the final "
|
||||||
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
||||||
),
|
|
||||||
})
|
|
||||||
|
|
||||||
return JourneyResponse(
|
|
||||||
session_id=body.session_id,
|
|
||||||
message=display_message,
|
|
||||||
done=done,
|
|
||||||
prompt_template=prompt_template,
|
|
||||||
)
|
)
|
||||||
|
session.history.append({"role": "user", "content": nudge_content})
|
||||||
|
|
||||||
|
nudge_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=session.system_prompt,
|
||||||
|
history=session.history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
)
|
||||||
|
session.history.append({"role": "assistant", "content": nudge_reply})
|
||||||
|
|
||||||
|
prompt_template = _extract_template(nudge_reply)
|
||||||
|
if prompt_template is not None:
|
||||||
|
done = True
|
||||||
|
ai_reply = nudge_reply
|
||||||
|
|
||||||
|
display_message = ai_reply
|
||||||
|
if done:
|
||||||
|
display_message = (
|
||||||
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
|
if _TEMPLATE_START in ai_reply
|
||||||
|
else "Here is your agent configuration. You can save it or continue refining."
|
||||||
|
)
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": display_message,
|
||||||
|
"done": done,
|
||||||
|
"prompt_template": prompt_template,
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,45 +1,36 @@
|
|||||||
"""Agent CRUD routes: local directory agents and cloud connector agents.
|
"""Agent routes.
|
||||||
|
|
||||||
Endpoints:
|
Backend responsibilities are intentionally minimal:
|
||||||
GET /agents/catalog — hardcoded agent type catalog
|
GET /agents/catalog — static catalog for UI display
|
||||||
GET /agents/local — list user's local agent configs
|
POST /agents/can-create — billing eligibility check
|
||||||
POST /agents/local — create local agent (tier-gated)
|
POST /agents/trigger — trigger a local agent run
|
||||||
PUT /agents/local/{agent_id} — partial update (ownership check)
|
|
||||||
DELETE /agents/local/{agent_id} — delete + cascade run logs
|
Agent configuration is owned by the Electron app and is not persisted
|
||||||
GET /agents/cloud — list user's cloud agent configs
|
in backend agent-config tables.
|
||||||
POST /agents/cloud — create cloud agent (tier-gated)
|
|
||||||
PUT /agents/cloud/{agent_id} — partial update (ownership check)
|
|
||||||
DELETE /agents/cloud/{agent_id} — delete + cascade run logs
|
|
||||||
GET /agents/runs — paginated run logs (agent_id, page, limit)
|
|
||||||
POST /agents/{agent_id}/run — manual trigger stub (dispatch in Step 3.4)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime
|
import uuid
|
||||||
from typing import Any
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from pydantic import BaseModel
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy import func, or_, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.billing.tier_manager import FEATURES
|
from app.billing.tier_manager import FEATURES
|
||||||
from app.core.agent_runner import run_cloud_agent, run_local_agent
|
from app.core.agent_runner import is_agent_running, run_local_agent
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
from app.models import AgentRunLog, LocalAgentConfig
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
AgentCatalogItem,
|
AgentCatalogItem,
|
||||||
|
AgentCreationCheckRequest,
|
||||||
|
AgentCreationCheckResponse,
|
||||||
AgentRunLogResponse,
|
AgentRunLogResponse,
|
||||||
CloudAgentConfigCreate,
|
AgentTriggerRequest,
|
||||||
CloudAgentConfigResponse,
|
|
||||||
CloudAgentConfigUpdate,
|
|
||||||
LocalAgentConfigCreate,
|
|
||||||
LocalAgentConfigResponse,
|
|
||||||
LocalAgentConfigUpdate,
|
|
||||||
UserProfile,
|
UserProfile,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -56,39 +47,21 @@ def _dt_ms_opt(dt: datetime | None) -> int | None:
|
|||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
# ── Model → schema converters ─────────────────────────────────────────
|
def _to_data_types(values: list[str]) -> list[str]:
|
||||||
|
normalize = {
|
||||||
def _to_local_response(a: LocalAgentConfig) -> LocalAgentConfigResponse:
|
"task": "tasks", "tasks": "tasks",
|
||||||
return LocalAgentConfigResponse(
|
"note": "notes", "notes": "notes",
|
||||||
id=a.id,
|
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
||||||
name=a.name,
|
"project": "projects", "projects": "projects",
|
||||||
device_id=a.device_id,
|
}
|
||||||
directory_paths=a.directory_paths,
|
seen: set[str] = set()
|
||||||
data_types=a.data_types,
|
result: list[str] = []
|
||||||
prompt_template=a.prompt_template,
|
for v in values:
|
||||||
file_extensions=a.file_extensions,
|
mapped = normalize.get(v)
|
||||||
schedule_cron=a.schedule_cron,
|
if mapped and mapped not in seen:
|
||||||
enabled=a.enabled,
|
seen.add(mapped)
|
||||||
last_run_at=_dt_ms_opt(a.last_run_at),
|
result.append(mapped)
|
||||||
created_at=_dt_ms(a.created_at),
|
return result
|
||||||
updated_at=_dt_ms(a.updated_at),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _to_cloud_response(a: CloudAgentConfig) -> CloudAgentConfigResponse:
|
|
||||||
return CloudAgentConfigResponse(
|
|
||||||
id=a.id,
|
|
||||||
provider=a.provider, # type: ignore[arg-type]
|
|
||||||
name=a.name,
|
|
||||||
data_types=a.data_types,
|
|
||||||
prompt_template=a.prompt_template,
|
|
||||||
schedule_cron=a.schedule_cron,
|
|
||||||
filter_config=a.filter_config,
|
|
||||||
enabled=a.enabled,
|
|
||||||
last_run_at=_dt_ms_opt(a.last_run_at),
|
|
||||||
created_at=_dt_ms(a.created_at),
|
|
||||||
updated_at=_dt_ms(a.updated_at),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
||||||
@@ -105,77 +78,42 @@ def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Ownership-checked lookups ─────────────────────────────────────────
|
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
||||||
|
|
||||||
async def _get_local_agent_for_user(
|
|
||||||
agent_id: str, user_id: str, db: AsyncSession
|
|
||||||
) -> LocalAgentConfig:
|
|
||||||
result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(
|
|
||||||
LocalAgentConfig.id == agent_id,
|
|
||||||
LocalAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
record = result.scalar_one_or_none()
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_cloud_agent_for_user(
|
|
||||||
agent_id: str, user_id: str, db: AsyncSession
|
|
||||||
) -> CloudAgentConfig:
|
|
||||||
result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(
|
|
||||||
CloudAgentConfig.id == agent_id,
|
|
||||||
CloudAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
record = result.scalar_one_or_none()
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tier limit helper ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _count_enabled_agents(user_id: str, db: AsyncSession) -> int:
|
|
||||||
"""Return combined enabled local + cloud agent count for the user."""
|
|
||||||
local_count = (
|
|
||||||
await db.execute(
|
|
||||||
select(func.count(LocalAgentConfig.id)).where(
|
|
||||||
LocalAgentConfig.user_id == user_id,
|
|
||||||
LocalAgentConfig.enabled == True, # noqa: E712
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).scalar_one()
|
|
||||||
cloud_count = (
|
|
||||||
await db.execute(
|
|
||||||
select(func.count(CloudAgentConfig.id)).where(
|
|
||||||
CloudAgentConfig.user_id == user_id,
|
|
||||||
CloudAgentConfig.enabled == True, # noqa: E712
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).scalar_one()
|
|
||||||
return local_count + cloud_count
|
|
||||||
|
|
||||||
|
|
||||||
def _enforce_agent_limit(tier: str, current_count: int) -> None:
|
|
||||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||||
if limit != -1 and current_count >= limit:
|
if limit != -1 and current_count >= limit:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||||
)
|
)
|
||||||
|
return limit
|
||||||
|
|
||||||
|
|
||||||
# ── Local page schema (used by runs endpoint) ─────────────────────────
|
async def _enforce_run_frequency(
|
||||||
|
tier: str,
|
||||||
|
user_id: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
|
||||||
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
||||||
|
if limit == -1:
|
||||||
|
return # unlimited
|
||||||
|
|
||||||
class _RunsPage(BaseModel):
|
today_start = datetime.now(timezone.utc).replace(
|
||||||
total: int
|
hour=0, minute=0, second=0, microsecond=0
|
||||||
page: int
|
)
|
||||||
limit: int
|
result = await db.execute(
|
||||||
items: list[AgentRunLogResponse]
|
select(func.count(AgentRunLog.id)).where(
|
||||||
|
AgentRunLog.user_id == user_id,
|
||||||
|
AgentRunLog.started_at >= today_start,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
runs_today: int = result.scalar_one()
|
||||||
|
|
||||||
|
if runs_today >= limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Catalog ───────────────────────────────────────────────────────────
|
# ── Catalog ───────────────────────────────────────────────────────────
|
||||||
@@ -209,229 +147,61 @@ async def get_agent_catalog(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# ── Local agent CRUD ──────────────────────────────────────────────────
|
@router.post("/can-create", response_model=AgentCreationCheckResponse)
|
||||||
|
async def can_create_agent(
|
||||||
@router.get("/local", response_model=list[LocalAgentConfigResponse])
|
body: AgentCreationCheckRequest,
|
||||||
async def list_local_agents(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_session),
|
) -> AgentCreationCheckResponse:
|
||||||
) -> list[LocalAgentConfigResponse]:
|
"""Check if the user can create one more agent based on billing tier.
|
||||||
"""List all local directory agent configs owned by the authenticated user."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(LocalAgentConfig.user_id == current_user.id)
|
|
||||||
)
|
|
||||||
return [_to_local_response(a) for a in result.scalars().all()]
|
|
||||||
|
|
||||||
|
Since configuration is client-owned, the Electron app sends its current
|
||||||
@router.post("/local", response_model=LocalAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
active agent count and the backend applies tier limits.
|
||||||
async def create_local_agent(
|
|
||||||
body: LocalAgentConfigCreate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> LocalAgentConfigResponse:
|
|
||||||
"""Create a new local directory agent config.
|
|
||||||
|
|
||||||
The combined count of enabled local and cloud agents for the user is
|
|
||||||
checked against the ``batch_active`` limit for their billing tier.
|
|
||||||
"""
|
"""
|
||||||
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
|
||||||
agent = LocalAgentConfig(
|
allowed = limit == -1 or body.active_agents < limit
|
||||||
user_id=current_user.id,
|
return AgentCreationCheckResponse(
|
||||||
name=body.name,
|
allowed=allowed,
|
||||||
device_id=body.device_id,
|
tier=current_user.tier,
|
||||||
directory_paths=body.directory_paths,
|
active_agents=body.active_agents,
|
||||||
data_types=body.data_types,
|
limit=limit,
|
||||||
prompt_template=body.prompt_template,
|
|
||||||
file_extensions=body.file_extensions,
|
|
||||||
schedule_cron=body.schedule_cron,
|
|
||||||
)
|
)
|
||||||
db.add(agent)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_local_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/local/{agent_id}", response_model=LocalAgentConfigResponse)
|
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
||||||
async def update_local_agent(
|
|
||||||
agent_id: str,
|
|
||||||
body: LocalAgentConfigUpdate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> LocalAgentConfigResponse:
|
|
||||||
"""Partially update a local agent config. Only provided fields are changed."""
|
|
||||||
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
for field, value in body.model_dump(exclude_unset=True).items():
|
|
||||||
setattr(agent, field, value)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_local_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/local/{agent_id}", response_model=dict)
|
|
||||||
async def delete_local_agent(
|
|
||||||
agent_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete a local agent config. Associated run logs are cascade-deleted."""
|
|
||||||
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
await db.delete(agent)
|
|
||||||
await db.commit()
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud agent CRUD ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/cloud", response_model=list[CloudAgentConfigResponse])
|
|
||||||
async def list_cloud_agents(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> list[CloudAgentConfigResponse]:
|
|
||||||
"""List all cloud connector agent configs owned by the authenticated user."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(CloudAgentConfig.user_id == current_user.id)
|
|
||||||
)
|
|
||||||
return [_to_cloud_response(a) for a in result.scalars().all()]
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/cloud", response_model=CloudAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
|
||||||
async def create_cloud_agent(
|
|
||||||
body: CloudAgentConfigCreate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> CloudAgentConfigResponse:
|
|
||||||
"""Create a new cloud connector agent config.
|
|
||||||
|
|
||||||
The combined count of enabled local and cloud agents for the user is
|
|
||||||
checked against the ``batch_active`` limit for their billing tier.
|
|
||||||
"""
|
|
||||||
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
|
||||||
agent = CloudAgentConfig(
|
|
||||||
user_id=current_user.id,
|
|
||||||
provider=body.provider,
|
|
||||||
name=body.name,
|
|
||||||
data_types=body.data_types,
|
|
||||||
prompt_template=body.prompt_template,
|
|
||||||
oauth_token_encrypted=body.oauth_token_encrypted,
|
|
||||||
schedule_cron=body.schedule_cron,
|
|
||||||
filter_config=body.filter_config,
|
|
||||||
)
|
|
||||||
db.add(agent)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_cloud_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/cloud/{agent_id}", response_model=CloudAgentConfigResponse)
|
|
||||||
async def update_cloud_agent(
|
|
||||||
agent_id: str,
|
|
||||||
body: CloudAgentConfigUpdate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> CloudAgentConfigResponse:
|
|
||||||
"""Partially update a cloud agent config. Only provided fields are changed."""
|
|
||||||
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
for field, value in body.model_dump(exclude_unset=True).items():
|
|
||||||
setattr(agent, field, value)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_cloud_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/cloud/{agent_id}", response_model=dict)
|
|
||||||
async def delete_cloud_agent(
|
|
||||||
agent_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete a cloud agent config. Associated run logs are cascade-deleted."""
|
|
||||||
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
await db.delete(agent)
|
|
||||||
await db.commit()
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Run logs ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/runs", response_model=_RunsPage)
|
|
||||||
async def list_run_logs(
|
|
||||||
agent_id: str | None = Query(default=None),
|
|
||||||
page: int = Query(default=1, ge=1),
|
|
||||||
limit: int = Query(default=20, ge=1, le=100),
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> _RunsPage:
|
|
||||||
"""Return paginated run logs for the authenticated user.
|
|
||||||
|
|
||||||
Optionally filter by ``agent_id``. Results are ordered from newest to oldest.
|
|
||||||
"""
|
|
||||||
base_filter = [AgentRunLog.user_id == current_user.id]
|
|
||||||
if agent_id:
|
|
||||||
base_filter.append(AgentRunLog.agent_id == agent_id)
|
|
||||||
|
|
||||||
total = (
|
|
||||||
await db.execute(select(func.count(AgentRunLog.id)).where(*base_filter))
|
|
||||||
).scalar_one()
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
select(AgentRunLog)
|
|
||||||
.where(*base_filter)
|
|
||||||
.order_by(AgentRunLog.started_at.desc())
|
|
||||||
.offset((page - 1) * limit)
|
|
||||||
.limit(limit)
|
|
||||||
)
|
|
||||||
items = [_to_run_log_response(log) for log in result.scalars().all()]
|
|
||||||
|
|
||||||
return _RunsPage(total=total, page=page, limit=limit, items=items)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Manual trigger stub ───────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.post("/{agent_id}/run", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
|
||||||
async def trigger_agent_run(
|
async def trigger_agent_run(
|
||||||
agent_id: str,
|
body: AgentTriggerRequest,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_session),
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> AgentRunLogResponse:
|
) -> AgentRunLogResponse:
|
||||||
"""Manually trigger an agent run.
|
"""Trigger a local agent run using client-provided configuration."""
|
||||||
|
_enforce_agent_limit(current_user.tier, body.active_agents)
|
||||||
|
await _enforce_run_frequency(current_user.tier, current_user.id, db)
|
||||||
|
|
||||||
Looks up the agent config (local or cloud) by ID with ownership check,
|
config = LocalAgentConfig(
|
||||||
creates a run log entry with ``status="running"``, and returns it.
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=current_user.id,
|
||||||
|
device_id=body.device_id,
|
||||||
|
name="Local Directory Monitor",
|
||||||
|
directory_paths=[body.directory],
|
||||||
|
data_types=_to_data_types(body.what_to_extract),
|
||||||
|
prompt_template=body.custom_agent_prompt,
|
||||||
|
file_extensions=[],
|
||||||
|
schedule_cron=body.batch_interval,
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
Actual dispatch to the agent runner is wired in Step 3.4 once
|
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
|
||||||
``DeviceConnectionManager`` and ``agent_runner`` are available.
|
stable_agent_id = body.agent_id or config.id
|
||||||
"""
|
|
||||||
# Determine agent type by trying local first, then cloud.
|
|
||||||
# Keep the full config object so we can pass it to the agent runner.
|
|
||||||
local_config: LocalAgentConfig | None = None
|
|
||||||
cloud_config: CloudAgentConfig | None = None
|
|
||||||
|
|
||||||
local_result = await db.execute(
|
if is_agent_running(stable_agent_id):
|
||||||
select(LocalAgentConfig).where(
|
raise HTTPException(
|
||||||
LocalAgentConfig.id == agent_id,
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
LocalAgentConfig.user_id == current_user.id,
|
detail="Agent is already running. Only one run per agent is allowed at a time.",
|
||||||
)
|
)
|
||||||
)
|
|
||||||
local_config = local_result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if local_config is not None:
|
|
||||||
agent_type = "local"
|
|
||||||
else:
|
|
||||||
cloud_result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(
|
|
||||||
CloudAgentConfig.id == agent_id,
|
|
||||||
CloudAgentConfig.user_id == current_user.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
cloud_config = cloud_result.scalar_one_or_none()
|
|
||||||
if cloud_config is not None:
|
|
||||||
agent_type = "cloud"
|
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
|
||||||
|
|
||||||
run_log = AgentRunLog(
|
run_log = AgentRunLog(
|
||||||
agent_id=agent_id,
|
agent_id=stable_agent_id,
|
||||||
agent_type=agent_type,
|
agent_type="local",
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
status="running",
|
status="running",
|
||||||
)
|
)
|
||||||
@@ -439,14 +209,14 @@ async def trigger_agent_run(
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(run_log)
|
await db.refresh(run_log)
|
||||||
|
|
||||||
# Dispatch the run as a background task — returns 202 immediately.
|
run_context = {
|
||||||
if agent_type == "local" and local_config is not None:
|
"type": "agent_batch",
|
||||||
|
"run_id": run_log.id,
|
||||||
|
"agent_id": stable_agent_id,
|
||||||
|
}
|
||||||
|
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
run_local_agent(current_user.id, local_config, run_log, device_manager)
|
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
|
||||||
)
|
|
||||||
elif agent_type == "cloud" and cloud_config is not None:
|
|
||||||
asyncio.create_task(
|
|
||||||
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return _to_run_log_response(run_log)
|
return _to_run_log_response(run_log)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from fastapi import APIRouter, Depends
|
|||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.core.orchestrator import orchestrate
|
from app.core.deep_agent import run_home
|
||||||
from app.schemas import ChatRequest, UserProfile
|
from app.schemas import ChatRequest, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
@@ -20,10 +20,10 @@ async def chat(
|
|||||||
body: ChatRequest,
|
body: ChatRequest,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
) -> JSONResponse:
|
) -> JSONResponse:
|
||||||
"""Route a chat message through the orchestrator.
|
"""REST fallback for home chat when websocket streaming is unavailable."""
|
||||||
|
response = await run_home(
|
||||||
Returns ``ChatResponse`` for ``execution_mode='direct'``,
|
user_id=current_user.id,
|
||||||
or ``ExecutionPlan`` for ``execution_mode='plan'``.
|
message=body.message,
|
||||||
"""
|
context=body.context.model_dump(),
|
||||||
result = await orchestrate(body)
|
)
|
||||||
return JSONResponse(content=result.model_dump())
|
return JSONResponse(content={"response": response})
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ Protocol:
|
|||||||
|
|
||||||
Incoming frame dispatch:
|
Incoming frame dispatch:
|
||||||
- ``tool_result`` → resolves a pending tool-call Future.
|
- ``tool_result`` → resolves a pending tool-call Future.
|
||||||
- ``agent_data`` → enqueued in the per-run agent data queue.
|
- ``journey_start`` → starts a guided setup journey session.
|
||||||
- ``agent_complete`` → sends None sentinel to close the queue stream.
|
- ``journey_message`` → continues a journey conversation.
|
||||||
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
||||||
- unknown types → logged, ignored.
|
- unknown types → logged, ignored.
|
||||||
|
|
||||||
@@ -39,12 +39,13 @@ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
|
|
||||||
|
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.core.agent_runner import trigger_pending_runs
|
from app.core.agent_runner import trigger_pending_runs
|
||||||
|
from app.core.deep_agent import run_floating_stream, run_home_stream
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.core.orchestrator import orchestrate_v3_stream
|
from app.core.output_formatter import StreamFormatter
|
||||||
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
|
||||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog
|
from app.models import AgentRunLog
|
||||||
@@ -147,37 +148,6 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
"device_ws: tool_result missing id from user=%s", user_id
|
"device_ws: tool_result missing id from user=%s", user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
elif frame_type == WsFrameType.agent_data:
|
|
||||||
run_id = frame.get("run_id")
|
|
||||||
if run_id:
|
|
||||||
try:
|
|
||||||
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
|
||||||
await queue.put(frame)
|
|
||||||
except RuntimeError:
|
|
||||||
logger.warning(
|
|
||||||
"device_ws: agent_data for unknown run user=%s run=%s",
|
|
||||||
user_id,
|
|
||||||
run_id,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"device_ws: agent_data missing run_id from user=%s", user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.agent_complete:
|
|
||||||
run_id = frame.get("run_id")
|
|
||||||
if run_id:
|
|
||||||
try:
|
|
||||||
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
|
||||||
# Sentinel: signals the agent data stream is finished.
|
|
||||||
await queue.put(None)
|
|
||||||
except RuntimeError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"device_ws: agent_complete missing run_id from user=%s", user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.home_request:
|
elif frame_type == WsFrameType.home_request:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
_handle_home_request(websocket, user_id, frame)
|
_handle_home_request(websocket, user_id, frame)
|
||||||
@@ -188,6 +158,16 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
_handle_floating_request(websocket, user_id, frame)
|
_handle_floating_request(websocket, user_id, frame)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.journey_start:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_journey_start(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.journey_message:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_journey_message(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
elif frame_type == "pong":
|
elif frame_type == "pong":
|
||||||
# Heartbeat ack — nothing to do, connection is alive.
|
# Heartbeat ack — nothing to do, connection is alive.
|
||||||
pass
|
pass
|
||||||
@@ -219,33 +199,37 @@ async def _handle_home_request(
|
|||||||
request_id = frame.get("request_id") or str(uuid4())
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
logger.info(
|
||||||
|
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
message[:200],
|
||||||
|
)
|
||||||
|
|
||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(user_id, message)
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id,
|
||||||
|
message,
|
||||||
|
trace_id=request_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
context: dict = {
|
context: dict = {
|
||||||
"conversation_history": frame.get("conversation_history", []),
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
**memory_context,
|
**memory_context,
|
||||||
}
|
}
|
||||||
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
agent_holder: list = []
|
|
||||||
try:
|
try:
|
||||||
token_stream = orchestrate_v3_stream(
|
event_stream = run_home_stream(user_id, message, context)
|
||||||
user_id, message, context, agent_holder=agent_holder
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
)
|
async for ws_frame in formatter.format(event_stream):
|
||||||
formatter = HomeFormatter(request_id=request_id, tool_results=[])
|
|
||||||
async for ws_frame in formatter.format(token_stream):
|
|
||||||
# Inject mutations from agent tool_results into stream_end
|
|
||||||
if ws_frame.type == "stream_end" and agent_holder: # type: ignore[union-attr]
|
|
||||||
ws_frame.mutations = [ # type: ignore[union-attr]
|
|
||||||
{"action": r["action"], "table": r["table"], "data": r["data"]}
|
|
||||||
for r in getattr(agent_holder[0], "tool_results", [])
|
|
||||||
]
|
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
# Collect text chunks to build the full response for episode storage
|
# Collect text chunks to build the full response for episode storage
|
||||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
@@ -262,7 +246,14 @@ async def _handle_home_request(
|
|||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks)
|
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: home_request_end user=%s req=%s session=%s response_chars=%d",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
len("".join(response_chunks)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -276,29 +267,38 @@ async def _handle_floating_request(
|
|||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
scope: dict = frame.get("scope", {})
|
scope: dict = frame.get("scope", {})
|
||||||
|
logger.info(
|
||||||
|
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
json.dumps(scope, ensure_ascii=True)[:200],
|
||||||
|
message[:200],
|
||||||
|
)
|
||||||
|
|
||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(user_id, message)
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id,
|
||||||
|
message,
|
||||||
|
trace_id=request_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
context: dict = {"scope": scope, **memory_context}
|
context: dict = {
|
||||||
|
"scope": scope,
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
agent_holder: list = []
|
|
||||||
try:
|
try:
|
||||||
token_stream = orchestrate_v3_stream(
|
event_stream = run_floating_stream(user_id, message, context)
|
||||||
user_id, message, context, agent_holder=agent_holder
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
)
|
async for ws_frame in formatter.format(event_stream):
|
||||||
formatter = FloatingFormatter(request_id=request_id)
|
|
||||||
async for ws_frame in formatter.format(token_stream):
|
|
||||||
if ws_frame.type == "stream_end" and agent_holder: # type: ignore[union-attr]
|
|
||||||
ws_frame.mutations = [ # type: ignore[union-attr]
|
|
||||||
{"action": r["action"], "table": r["table"], "data": r["data"]}
|
|
||||||
for r in getattr(agent_holder[0], "tool_results", [])
|
|
||||||
]
|
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
@@ -314,8 +314,72 @@ async def _handle_floating_request(
|
|||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks)
|
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
len("".join(response_chunks)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_journey_start(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a journey_start frame — explores directory and sends first question."""
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
try:
|
||||||
|
reply = await handle_journey_start(user_id, frame)
|
||||||
|
await websocket.send_text(json.dumps(reply))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: journey_start failed user=%s: %s", user_id, exc
|
||||||
|
)
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": frame.get("session_id", ""),
|
||||||
|
"message": f"Failed to start journey: {exc}",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
}))
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_journey_message(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a journey_message frame — continues the journey conversation."""
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
try:
|
||||||
|
reply = await handle_journey_message(user_id, frame)
|
||||||
|
await websocket.send_text(json.dumps(reply))
|
||||||
|
except Exception as exc:
|
||||||
|
session_id = frame.get("session_id", "")
|
||||||
|
logger.error(
|
||||||
|
"device_ws: journey_message failed user=%s session=%s: %s",
|
||||||
|
user_id, session_id, exc,
|
||||||
|
)
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": f"Journey error: {exc}",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
}))
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
|
||||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||||
@@ -351,6 +415,3 @@ async def _mark_runs_disconnected(user_id: str) -> None:
|
|||||||
user_id,
|
user_id,
|
||||||
exc,
|
exc,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,37 +0,0 @@
|
|||||||
"""Plans routes: GET /plans/playbook and GET /plans/playbook/{plan_id}."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.core.execution_plan import plan_cache
|
|
||||||
from app.schemas import ExecutionPlan, UserProfile
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/plans", tags=["plans"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/playbook", response_model=list[ExecutionPlan])
|
|
||||||
async def list_playbooks(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> list[ExecutionPlan]:
|
|
||||||
"""Return all cached execution plan playbooks for the authenticated user.
|
|
||||||
|
|
||||||
TODO(Step11): filter by tier — power+ plans gated behind batch_builder feature.
|
|
||||||
"""
|
|
||||||
return plan_cache.get_all_playbooks()
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/playbook/{plan_id}", response_model=ExecutionPlan)
|
|
||||||
async def get_playbook(
|
|
||||||
plan_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> ExecutionPlan:
|
|
||||||
"""Return a specific execution plan playbook by ID."""
|
|
||||||
plan = plan_cache.get_plan(plan_id)
|
|
||||||
if plan is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"Plan not found: {plan_id}",
|
|
||||||
)
|
|
||||||
return plan
|
|
||||||
@@ -21,6 +21,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"free": {
|
"free": {
|
||||||
"agents": 3,
|
"agents": 3,
|
||||||
"batch_active": 2,
|
"batch_active": 2,
|
||||||
|
"batch_runs_per_day": 5,
|
||||||
"cloud_storage_gb": 0,
|
"cloud_storage_gb": 0,
|
||||||
"backup_gb": 0,
|
"backup_gb": 0,
|
||||||
"providers": 1,
|
"providers": 1,
|
||||||
@@ -31,6 +32,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"pro": {
|
"pro": {
|
||||||
"agents": -1, # unlimited
|
"agents": -1, # unlimited
|
||||||
"batch_active": 10,
|
"batch_active": 10,
|
||||||
|
"batch_runs_per_day": 50,
|
||||||
"cloud_storage_gb": 5,
|
"cloud_storage_gb": 5,
|
||||||
"backup_gb": 5,
|
"backup_gb": 5,
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
@@ -41,6 +43,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"power": {
|
"power": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1, # unlimited
|
"batch_active": -1, # unlimited
|
||||||
|
"batch_runs_per_day": -1, # unlimited
|
||||||
"cloud_storage_gb": 25,
|
"cloud_storage_gb": 25,
|
||||||
"backup_gb": 25,
|
"backup_gb": 25,
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
@@ -51,6 +54,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"team": {
|
"team": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1,
|
"batch_active": -1,
|
||||||
|
"batch_runs_per_day": -1, # unlimited
|
||||||
"cloud_storage_gb": -1, # unlimited
|
"cloud_storage_gb": -1, # unlimited
|
||||||
"backup_gb": -1, # unlimited
|
"backup_gb": -1, # unlimited
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
@@ -77,16 +81,18 @@ class TierManager:
|
|||||||
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||||
"""Return the current billing tier for ``user_id`` from the DB.
|
"""Return the current billing tier for ``user_id`` from the DB.
|
||||||
|
|
||||||
Falls back to ``'free'`` when no subscription row exists.
|
Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod
|
||||||
|
when no subscription row exists.
|
||||||
"""
|
"""
|
||||||
from app.models import Subscription # noqa: PLC0415
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
from app.config.settings import settings # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
tier: str | None = result.scalar_one_or_none()
|
tier: str | None = result.scalar_one_or_none()
|
||||||
if tier is None or tier not in FEATURES:
|
if tier is None or tier not in FEATURES:
|
||||||
return "free"
|
return "power" if settings.ENV == "dev" else "free"
|
||||||
return tier # type: ignore[return-value]
|
return tier # type: ignore[return-value]
|
||||||
|
|
||||||
# ── Feature access ───────────────────────────────────────────────────
|
# ── Feature access ───────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -27,9 +27,9 @@ class Settings(BaseSettings):
|
|||||||
ANTHROPIC_API_KEY: str = ""
|
ANTHROPIC_API_KEY: str = ""
|
||||||
GOOGLE_API_KEY: str = ""
|
GOOGLE_API_KEY: str = ""
|
||||||
CEREBRAS_API_KEY: str = ""
|
CEREBRAS_API_KEY: str = ""
|
||||||
|
GITHUB_TOKEN: str = ""
|
||||||
|
|
||||||
LLM_MODEL: str = "gpt-4o"
|
LLM_MODEL: str = "gpt-4o"
|
||||||
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
|
|
||||||
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
||||||
|
|
||||||
# GitHub Copilot OAuth token storage directory.
|
# GitHub Copilot OAuth token storage directory.
|
||||||
@@ -54,7 +54,9 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
ENV: Literal["dev", "prod"] = "dev"
|
ENV: Literal["dev", "prod"] = "dev"
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
model_config = SettingsConfigDict(
|
||||||
|
env_file=".env", env_file_encoding="utf-8", extra="ignore"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
"""Agent Registry — base classes and singleton registry for chat agents."""
|
"""Minimal agent base types retained for compatibility with batch runners."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import AsyncGenerator
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent(ABC):
|
class BaseAgent(ABC):
|
||||||
"""Common base for all agents."""
|
"""Common base for non-chat agents still using the old base contract."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -28,190 +27,4 @@ class BaseAgent(ABC):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def skills(self) -> list[str]:
|
def skills(self) -> list[str]:
|
||||||
"""Override in subclasses to advertise capabilities."""
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
class ChatAgent(BaseAgent):
|
|
||||||
"""Base class for LLM-powered chat agents."""
|
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
# Populated by _tool_loop / _tool_loop_stream with raw execute_on_client results.
|
|
||||||
self.tool_results: list[dict] = []
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
"""Process a user query and return a text response."""
|
|
||||||
...
|
|
||||||
|
|
||||||
async def handle_stream(
|
|
||||||
self, query: str, context: dict[str, Any]
|
|
||||||
) -> AsyncGenerator[str, None]:
|
|
||||||
"""Streaming variant of handle().
|
|
||||||
|
|
||||||
Default: calls handle() and yields the full response as one chunk.
|
|
||||||
Override in subclasses for true token-level streaming via _tool_loop_stream.
|
|
||||||
"""
|
|
||||||
yield await self.handle(query, context)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
"""Return LangChain tool definitions available to this agent."""
|
|
||||||
...
|
|
||||||
|
|
||||||
async def _tool_loop(
|
|
||||||
self,
|
|
||||||
llm: Any,
|
|
||||||
messages: list[Any],
|
|
||||||
tools: list[Any],
|
|
||||||
max_iter: int = 5,
|
|
||||||
) -> str:
|
|
||||||
"""Shared tool-calling loop.
|
|
||||||
|
|
||||||
Binds *tools* to *llm*, invokes iteratively until the model stops
|
|
||||||
requesting tool calls or *max_iter* is reached, and returns the
|
|
||||||
final text response. Captures raw execute_on_client results in
|
|
||||||
``self.tool_results``.
|
|
||||||
"""
|
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
|
||||||
|
|
||||||
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
|
|
||||||
|
|
||||||
collector: list[dict] = []
|
|
||||||
set_tool_result_collector(collector)
|
|
||||||
try:
|
|
||||||
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
|
||||||
|
|
||||||
for _ in range(max_iter):
|
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
||||||
messages.append(response)
|
|
||||||
|
|
||||||
if not response.tool_calls:
|
|
||||||
return str(response.content)
|
|
||||||
|
|
||||||
# Execute each requested tool call
|
|
||||||
tool_map = {t.name: t for t in tools}
|
|
||||||
for call in response.tool_calls:
|
|
||||||
tool_fn = tool_map.get(call["name"])
|
|
||||||
if tool_fn is None:
|
|
||||||
result = f"Unknown tool: {call['name']}"
|
|
||||||
else:
|
|
||||||
result = await tool_fn.ainvoke(call["args"])
|
|
||||||
messages.append(
|
|
||||||
ToolMessage(content=str(result), tool_call_id=call["id"])
|
|
||||||
)
|
|
||||||
|
|
||||||
# Exhausted iterations — ask model for a final answer without tools
|
|
||||||
response = await llm.ainvoke(messages)
|
|
||||||
return str(response.content)
|
|
||||||
finally:
|
|
||||||
clear_tool_result_collector()
|
|
||||||
self.tool_results = collector
|
|
||||||
|
|
||||||
async def _tool_loop_stream(
|
|
||||||
self,
|
|
||||||
llm: Any,
|
|
||||||
messages: list[Any],
|
|
||||||
tools: list[Any],
|
|
||||||
max_iter: int = 5,
|
|
||||||
) -> AsyncGenerator[str, None]:
|
|
||||||
"""Streaming variant of ``_tool_loop``.
|
|
||||||
|
|
||||||
Behaves identically for tool-calling iterations (uses ainvoke to parse
|
|
||||||
tool calls). For the final response — when the model produces no further
|
|
||||||
tool calls — switches to ``llm.astream()`` and yields text tokens.
|
|
||||||
Captures raw execute_on_client results in ``self.tool_results``.
|
|
||||||
"""
|
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
|
||||||
|
|
||||||
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
|
|
||||||
|
|
||||||
collector: list[dict] = []
|
|
||||||
set_tool_result_collector(collector)
|
|
||||||
try:
|
|
||||||
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
|
||||||
|
|
||||||
for _ in range(max_iter):
|
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
||||||
|
|
||||||
if not response.tool_calls:
|
|
||||||
# Stream the final answer — don't keep the ainvoke result.
|
|
||||||
async for chunk in llm.astream(messages):
|
|
||||||
if chunk.content:
|
|
||||||
yield str(chunk.content)
|
|
||||||
return
|
|
||||||
|
|
||||||
messages.append(response)
|
|
||||||
|
|
||||||
# Execute each requested tool call
|
|
||||||
tool_map = {t.name: t for t in tools}
|
|
||||||
for call in response.tool_calls:
|
|
||||||
tool_fn = tool_map.get(call["name"])
|
|
||||||
if tool_fn is None:
|
|
||||||
result = f"Unknown tool: {call['name']}"
|
|
||||||
else:
|
|
||||||
result = await tool_fn.ainvoke(call["args"])
|
|
||||||
messages.append(
|
|
||||||
ToolMessage(content=str(result), tool_call_id=call["id"])
|
|
||||||
)
|
|
||||||
|
|
||||||
# Exhausted iterations — stream a final answer without tools
|
|
||||||
async for chunk in llm.astream(messages):
|
|
||||||
if chunk.content:
|
|
||||||
yield str(chunk.content)
|
|
||||||
finally:
|
|
||||||
clear_tool_result_collector()
|
|
||||||
self.tool_results = collector
|
|
||||||
|
|
||||||
|
|
||||||
class AgentRegistry:
|
|
||||||
"""Singleton registry for ChatAgent subclasses."""
|
|
||||||
|
|
||||||
_instance: AgentRegistry | None = None
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._agents: dict[str, type[ChatAgent]] = {}
|
|
||||||
|
|
||||||
def __new__(cls) -> AgentRegistry:
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = super().__new__(cls)
|
|
||||||
cls._instance._agents = {}
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
# ── public API ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
|
|
||||||
"""Class decorator — registers an agent by its name."""
|
|
||||||
instance = agent_class()
|
|
||||||
name = instance.get_name()
|
|
||||||
self._agents[name] = agent_class
|
|
||||||
return agent_class
|
|
||||||
|
|
||||||
def get(self, name: str) -> ChatAgent:
|
|
||||||
"""Return a fresh instance of the named agent."""
|
|
||||||
cls = self._agents.get(name)
|
|
||||||
if cls is None:
|
|
||||||
raise KeyError(f"Agent not found: {name}")
|
|
||||||
return cls()
|
|
||||||
|
|
||||||
def list_agents(self) -> list[dict[str, str]]:
|
|
||||||
"""Return ``[{name, description}]`` for the orchestrator prompt."""
|
|
||||||
result: list[dict[str, str]] = []
|
|
||||||
for cls in self._agents.values():
|
|
||||||
inst = cls()
|
|
||||||
result.append(
|
|
||||||
{"name": inst.get_name(), "description": inst.get_description()}
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def call_agent(
|
|
||||||
self, name: str, query: str, context: dict[str, Any]
|
|
||||||
) -> str:
|
|
||||||
"""Instantiate the named agent and call its ``handle`` method."""
|
|
||||||
agent = self.get(name)
|
|
||||||
return await agent.handle(query, context)
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
|
||||||
registry = AgentRegistry()
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
846
app/core/deep_agent.py
Normal file
846
app/core/deep_agent.py
Normal file
@@ -0,0 +1,846 @@
|
|||||||
|
"""Single-agent runners for home and floating chat contexts."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import date
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.agents.note_agent import NOTE_TOOLS
|
||||||
|
from app.agents.project_agent import PROJECT_TOOLS
|
||||||
|
from app.agents.task_agent import TASK_TOOLS
|
||||||
|
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
||||||
|
from app.db import async_session
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FloatingDomainType = Literal["task", "timeline", "project", "node"]
|
||||||
|
FloatingDomainSection = Literal["task", "timeline", "note"]
|
||||||
|
|
||||||
|
_HOME_SINGLE_AGENT_SYSTEM = (
|
||||||
|
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||||
|
"Always use tools for factual data retrieval before answering. "
|
||||||
|
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
||||||
|
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||||
|
"Return markdown and use tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
||||||
|
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>. "
|
||||||
|
"When listing tasks or timelines, each id tag must be on its own line with no prefix/suffix text. "
|
||||||
|
"Never put titles, priorities, or dates on the same line as <task> or <timeline> tags. "
|
||||||
|
"For questions about upcoming timelines (e.g. 'prossimi eventi'), include only future items in the current month unless the user asks a different range. "
|
||||||
|
"For upcoming tasks, after tag lines add a short recommendation based on due date and priority."
|
||||||
|
)
|
||||||
|
|
||||||
|
_FLOATING_SINGLE_AGENT_SYSTEM = (
|
||||||
|
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||||
|
"Stay focused on the floating scope in context.scope and answer concisely. "
|
||||||
|
"Return plain text only. Do not output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed id tag wrappers. "
|
||||||
|
"Always use tools for factual data retrieval before answering. "
|
||||||
|
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
||||||
|
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||||
|
)
|
||||||
|
|
||||||
|
_FLOATING_DOMAIN_CLASSIFIER_SYSTEM = (
|
||||||
|
"You are a strict domain classifier for websocket floating requests. "
|
||||||
|
"Return ONLY a JSON object with keys: type, id, section. "
|
||||||
|
"Allowed type values: task, timeline, project, node. "
|
||||||
|
"Allowed section values: task, timeline, note, or null. "
|
||||||
|
"Rules: infer from user message intent first; do not blindly trust scope.type. "
|
||||||
|
"If user asks tasks/timeline/notes for a project, set type=project and section accordingly. "
|
||||||
|
"If project id is unknown but context.resolved_project_id exists, use it as id. "
|
||||||
|
"If id is unknown, use null. "
|
||||||
|
"No markdown, no prose, JSON only."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _as_text(content: Any) -> str:
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
def _candidate_tokens(message: str) -> list[str]:
|
||||||
|
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
|
||||||
|
return [token for token in tokens if len(token) >= 3]
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_project_id_from_message(message: str) -> str | None:
|
||||||
|
"""Resolve likely project UUID from user message using client project list."""
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(action="select", table="projects")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("deep_agent: project resolve select failed: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not isinstance(rows, list) or not rows:
|
||||||
|
return None
|
||||||
|
|
||||||
|
tokens = _candidate_tokens(message)
|
||||||
|
scored: list[tuple[int, dict[str, Any]]] = []
|
||||||
|
for row in rows:
|
||||||
|
if not isinstance(row, dict):
|
||||||
|
continue
|
||||||
|
name = str(row.get("name", "")).lower()
|
||||||
|
score = sum(1 for token in tokens if token in name)
|
||||||
|
if score > 0:
|
||||||
|
scored.append((score, row))
|
||||||
|
|
||||||
|
if not scored:
|
||||||
|
return None
|
||||||
|
|
||||||
|
scored.sort(key=lambda item: item[0], reverse=True)
|
||||||
|
top_score = scored[0][0]
|
||||||
|
top_rows = [row for score, row in scored if score == top_score]
|
||||||
|
if len(top_rows) != 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
project_id = top_rows[0].get("id")
|
||||||
|
return project_id if isinstance(project_id, str) else None
|
||||||
|
|
||||||
|
|
||||||
|
def _needs_project_resolution(message: str) -> bool:
|
||||||
|
lowered = message.lower()
|
||||||
|
return any(keyword in lowered for keyword in ["project", "progetto", "progetti", "whitelist"])
|
||||||
|
|
||||||
|
|
||||||
|
async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
prepared = dict(context)
|
||||||
|
if _needs_project_resolution(message):
|
||||||
|
resolved_project_id = await _resolve_project_id_from_message(message)
|
||||||
|
if resolved_project_id:
|
||||||
|
prepared["resolved_project_id"] = resolved_project_id
|
||||||
|
logger.info("deep_agent: resolved_project_id=%s", resolved_project_id)
|
||||||
|
return prepared
|
||||||
|
|
||||||
|
|
||||||
|
def _all_tools() -> list[Any]:
|
||||||
|
return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
|
||||||
|
|
||||||
|
|
||||||
|
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
|
||||||
|
debug = context.get("_debug")
|
||||||
|
if isinstance(debug, dict):
|
||||||
|
request_id = debug.get("request_id")
|
||||||
|
if isinstance(request_id, str) and request_id:
|
||||||
|
return request_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
sanitized = dict(context)
|
||||||
|
sanitized.pop("_debug", None)
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
_TAG_LINE_RE = re.compile(r"<(task|timeline)>\[[^\]]+\]</\1>")
|
||||||
|
_TIMELINE_DMY_RE = re.compile(r"(?P<d>\d{2})/(?P<m>\d{2})/(?P<y>\d{4})")
|
||||||
|
|
||||||
|
|
||||||
|
def _is_upcoming_timeline_query(message: str) -> bool:
|
||||||
|
lowered = message.lower()
|
||||||
|
has_upcoming = "prossim" in lowered or "upcoming" in lowered or "next" in lowered
|
||||||
|
has_timeline_topic = any(
|
||||||
|
token in lowered
|
||||||
|
for token in ("event", "evento", "eventi", "timeline", "milestone", "scaden")
|
||||||
|
)
|
||||||
|
return has_upcoming and has_timeline_topic
|
||||||
|
|
||||||
|
|
||||||
|
def _timeline_date_in_current_month_or_future(dmy: str) -> bool:
|
||||||
|
match = _TIMELINE_DMY_RE.search(dmy)
|
||||||
|
if not match:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
parsed = date(
|
||||||
|
int(match.group("y")),
|
||||||
|
int(match.group("m")),
|
||||||
|
int(match.group("d")),
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
return True
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
return parsed >= today and parsed.year == today.year and parsed.month == today.month
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_tagged_list_lines(text: str, message: str) -> str:
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
|
||||||
|
upcoming_timeline_only = _is_upcoming_timeline_query(message)
|
||||||
|
output_lines: list[str] = []
|
||||||
|
|
||||||
|
for line in text.splitlines():
|
||||||
|
matches = list(_TAG_LINE_RE.finditer(line))
|
||||||
|
if not matches:
|
||||||
|
output_lines.append(line)
|
||||||
|
continue
|
||||||
|
|
||||||
|
had_non_tag_text = _TAG_LINE_RE.sub("", line).strip(" -\t0123456789.*:)")
|
||||||
|
if not had_non_tag_text and len(matches) == 1:
|
||||||
|
tag_text = matches[0].group(0)
|
||||||
|
if (
|
||||||
|
upcoming_timeline_only
|
||||||
|
and "<timeline>" in tag_text
|
||||||
|
and not _timeline_date_in_current_month_or_future(line)
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
output_lines.append(tag_text)
|
||||||
|
continue
|
||||||
|
|
||||||
|
for match in matches:
|
||||||
|
tag_text = match.group(0)
|
||||||
|
if (
|
||||||
|
upcoming_timeline_only
|
||||||
|
and "<timeline>" in tag_text
|
||||||
|
and not _timeline_date_in_current_month_or_future(line)
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
output_lines.append(tag_text)
|
||||||
|
|
||||||
|
return "\n".join(output_lines)
|
||||||
|
|
||||||
|
|
||||||
|
_GENERIC_TAG_RE = re.compile(r"</?(task|project|note|timeline|chart)>", re.IGNORECASE)
|
||||||
|
_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]")
|
||||||
|
_FLOATING_EMPTY_FALLBACK = "No results found."
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_floating_markup_fragment(text: str) -> str:
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
cleaned = _GENERIC_TAG_RE.sub("", text)
|
||||||
|
return _BRACKETED_ID_RE.sub("", cleaned)
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_floating_markup(text: str) -> str:
|
||||||
|
"""Ensure floating responses stay plain text with no XML-like tag wrappers."""
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
|
||||||
|
cleaned = _strip_floating_markup_fragment(text)
|
||||||
|
# Collapse excessive spaces introduced by tag/id removal while preserving lines.
|
||||||
|
lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()]
|
||||||
|
return "\n".join(line for line in lines if line)
|
||||||
|
|
||||||
|
|
||||||
|
def _fallback_from_raw_floating_text(raw_text: str) -> str:
|
||||||
|
fallback = _strip_floating_markup_fragment(raw_text or "")
|
||||||
|
fallback = re.sub(r"[ \t]{2,}", " ", fallback).strip()
|
||||||
|
return fallback or _FLOATING_EMPTY_FALLBACK
|
||||||
|
|
||||||
|
|
||||||
|
class _FloatingStreamSanitizer:
|
||||||
|
"""Streaming sanitizer that removes floating markup without buffering the full answer."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._pending = ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_safe_boundary(text: str) -> tuple[str, str]:
|
||||||
|
boundary = len(text)
|
||||||
|
|
||||||
|
last_lt = text.rfind("<")
|
||||||
|
if last_lt != -1 and ">" not in text[last_lt:]:
|
||||||
|
boundary = min(boundary, last_lt)
|
||||||
|
|
||||||
|
last_lb = text.rfind("[")
|
||||||
|
if last_lb != -1 and "]" not in text[last_lb:]:
|
||||||
|
boundary = min(boundary, last_lb)
|
||||||
|
|
||||||
|
if boundary == len(text):
|
||||||
|
return text, ""
|
||||||
|
return text[:boundary], text[boundary:]
|
||||||
|
|
||||||
|
def feed(self, chunk: str) -> str:
|
||||||
|
combined = f"{self._pending}{chunk}"
|
||||||
|
safe_text, self._pending = self._split_safe_boundary(combined)
|
||||||
|
return _strip_floating_markup_fragment(safe_text)
|
||||||
|
|
||||||
|
def finalize(self) -> str:
|
||||||
|
# Drop dangling unfinished wrappers at the very end.
|
||||||
|
tail = re.sub(r"<[^>\n]*$", "", self._pending)
|
||||||
|
tail = re.sub(r"\[[^\]\n]*$", "", tail)
|
||||||
|
self._pending = ""
|
||||||
|
return _strip_floating_markup_fragment(tail)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_memory_label(path_or_label: str) -> str:
|
||||||
|
value = path_or_label.strip()
|
||||||
|
if value.startswith("/memories/"):
|
||||||
|
value = value[len("/memories/"):]
|
||||||
|
value = value.strip("/")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||||
|
@tool
|
||||||
|
async def memory_list_blocks() -> str:
|
||||||
|
"""List all core memory blocks currently stored for the user."""
|
||||||
|
logger.info("deep_agent: memory_list_blocks trace=%s user=%s", trace_id or "-", user_id)
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
blocks = await memory.list_core_blocks(user_id)
|
||||||
|
if not blocks:
|
||||||
|
return "No memory blocks found."
|
||||||
|
lines = [f"- {b['label']}: {b['value']}" for b in blocks]
|
||||||
|
return "Memory blocks:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_get(path_or_label: str) -> str:
|
||||||
|
"""Get one memory block by label or /memories/<label> path."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_get trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
value = await memory.get_core_block(user_id, label)
|
||||||
|
if value is None:
|
||||||
|
return f"Memory block '{label}' not found."
|
||||||
|
return f"Memory block '{label}':\n{value}"
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_create(path_or_label: str, value: str) -> str:
|
||||||
|
"""Create or overwrite a memory block value by label or /memories/<label> path."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_create trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.update_core(user_id, label, value, trace_id=trace_id)
|
||||||
|
return f"Memory block '{label}' saved."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_append(path_or_label: str, content: str) -> str:
|
||||||
|
"""Append content to a memory block, creating it if missing."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_append trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.append_core(user_id, label, content)
|
||||||
|
return f"Memory block '{label}' appended."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_replace(path_or_label: str, old_string: str, new_string: str) -> str:
|
||||||
|
"""Replace one exact string in a memory block."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_replace trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
changed = await memory.replace_core(user_id, label, old_string, new_string)
|
||||||
|
if not changed:
|
||||||
|
return f"No replacement made in '{label}' (old string not found)."
|
||||||
|
return f"Memory block '{label}' updated."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_delete(path_or_label: str) -> str:
|
||||||
|
"""Delete a memory block by label or /memories/<label> path."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_delete trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
deleted = await memory.delete_core(user_id, label)
|
||||||
|
if not deleted:
|
||||||
|
return f"Memory block '{label}' not found."
|
||||||
|
return f"Memory block '{label}' deleted."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def archival_memory_insert(content: str) -> str:
|
||||||
|
"""Insert a long-term archival memory entry."""
|
||||||
|
logger.info("deep_agent: archival_memory_insert trace=%s user=%s", trace_id or "-", user_id)
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.insert_archival(user_id, content, source="assistant")
|
||||||
|
return "Archival memory saved."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def archival_memory_search(query: str, top_k: int = 5) -> str:
|
||||||
|
"""Search long-term archival memory by semantic fallback (keyword currently)."""
|
||||||
|
logger.info("deep_agent: archival_memory_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
results = await memory.search_archival(user_id, query, top_k=top_k)
|
||||||
|
if not results:
|
||||||
|
return "No archival memory results found."
|
||||||
|
lines = [f"- {item}" for item in results]
|
||||||
|
return "Archival memory results:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def conversation_search(query: str, top_k: int = 5) -> str:
|
||||||
|
"""Search recall memory from prior episodic conversation summaries."""
|
||||||
|
logger.info("deep_agent: conversation_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
results = await memory.search_recall(user_id, query, top_k=top_k)
|
||||||
|
if not results:
|
||||||
|
return "No recall memory results found."
|
||||||
|
lines = [f"- {item}" for item in results]
|
||||||
|
return "Recall memory results:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
return [
|
||||||
|
memory_list_blocks,
|
||||||
|
memory_get,
|
||||||
|
memory_create,
|
||||||
|
memory_append,
|
||||||
|
memory_replace,
|
||||||
|
memory_delete,
|
||||||
|
archival_memory_insert,
|
||||||
|
archival_memory_search,
|
||||||
|
conversation_search,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
|
||||||
|
return [*_all_tools(), *_memory_tools(user_id, trace_id)]
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_domain_section(message: str) -> FloatingDomainSection | None:
|
||||||
|
lowered = message.lower()
|
||||||
|
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
|
||||||
|
return "timeline"
|
||||||
|
if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]):
|
||||||
|
return "task"
|
||||||
|
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
|
||||||
|
return "note"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]:
|
||||||
|
type_raw = str(payload.get("type") or "").strip().lower()
|
||||||
|
domain_type: FloatingDomainType = "task"
|
||||||
|
if type_raw in {"task", "timeline", "project", "node"}:
|
||||||
|
domain_type = type_raw
|
||||||
|
|
||||||
|
id_value = payload.get("id")
|
||||||
|
domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None
|
||||||
|
if domain_type == "project" and not domain_id:
|
||||||
|
domain_id = fallback_id
|
||||||
|
|
||||||
|
section_raw = payload.get("section")
|
||||||
|
section: FloatingDomainSection | None = None
|
||||||
|
if isinstance(section_raw, str):
|
||||||
|
section_candidate = section_raw.strip().lower()
|
||||||
|
if section_candidate in {"task", "timeline", "note"}:
|
||||||
|
section = section_candidate
|
||||||
|
|
||||||
|
if domain_type != "project":
|
||||||
|
section = None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": domain_type,
|
||||||
|
"id": domain_id,
|
||||||
|
"section": section,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_json_object(text: str) -> dict[str, Any] | None:
|
||||||
|
raw = text.strip()
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = json.loads(raw)
|
||||||
|
return parsed if isinstance(parsed, dict) else None
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
||||||
|
if not match:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = json.loads(match.group(0))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
return parsed if isinstance(parsed, dict) else None
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]:
|
||||||
|
section = _detect_domain_section(message)
|
||||||
|
scope = context.get("scope") if isinstance(context, dict) else None
|
||||||
|
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
|
||||||
|
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
|
||||||
|
|
||||||
|
if isinstance(scope, dict):
|
||||||
|
scope_type = str(scope.get("type") or "").strip().lower()
|
||||||
|
scope_id = scope.get("id")
|
||||||
|
scope_id_value = scope_id if isinstance(scope_id, str) and scope_id else None
|
||||||
|
|
||||||
|
if scope_type in {"task", "tasks"}:
|
||||||
|
return {"type": "task", "id": scope_id_value, "section": None}
|
||||||
|
if scope_type in {"project", "projects"}:
|
||||||
|
project_scope_id = scope_id_value or project_id
|
||||||
|
return {
|
||||||
|
"type": "project",
|
||||||
|
"id": project_scope_id,
|
||||||
|
"section": section,
|
||||||
|
}
|
||||||
|
if scope_type in {"note", "notes"}:
|
||||||
|
return {
|
||||||
|
"type": "node",
|
||||||
|
"id": scope_id_value,
|
||||||
|
"section": None,
|
||||||
|
}
|
||||||
|
if scope_type in {"timeline", "timelines"}:
|
||||||
|
return {"type": "timeline", "id": scope_id_value, "section": None}
|
||||||
|
|
||||||
|
lowered = message.lower()
|
||||||
|
if any(keyword in lowered for keyword in ["project", "progetto", "client"]) or project_id:
|
||||||
|
return {
|
||||||
|
"type": "project",
|
||||||
|
"id": project_id,
|
||||||
|
"section": section,
|
||||||
|
}
|
||||||
|
if section == "timeline":
|
||||||
|
return {"type": "timeline", "id": None, "section": None}
|
||||||
|
if section == "note":
|
||||||
|
return {"type": "node", "id": None, "section": None}
|
||||||
|
return {"type": "task", "id": None, "section": None}
|
||||||
|
|
||||||
|
|
||||||
|
async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[str, str | None]:
|
||||||
|
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
|
||||||
|
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
|
||||||
|
|
||||||
|
classifier_context = {
|
||||||
|
"scope": context.get("scope") if isinstance(context.get("scope"), dict) else None,
|
||||||
|
"resolved_project_id": project_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
llm = get_llm()
|
||||||
|
response = await llm.ainvoke(
|
||||||
|
[
|
||||||
|
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_SYSTEM),
|
||||||
|
HumanMessage(
|
||||||
|
content=(
|
||||||
|
f"Message:\n{message}\n\n"
|
||||||
|
f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
parsed = _parse_json_object(_as_text(response.content))
|
||||||
|
if parsed is not None:
|
||||||
|
domain = _normalize_domain_payload(parsed, project_id)
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: floating_domain_classified type=%s id=%s section=%s",
|
||||||
|
domain.get("type"),
|
||||||
|
domain.get("id"),
|
||||||
|
domain.get("section"),
|
||||||
|
)
|
||||||
|
return domain
|
||||||
|
logger.warning("deep_agent: floating_domain classifier returned non-json output")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("deep_agent: floating_domain classifier failed: %s", exc)
|
||||||
|
|
||||||
|
return _infer_floating_domain_rule_based(message, context)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_single_agent(
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
system_prompt: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
max_steps: int = 6,
|
||||||
|
) -> str:
|
||||||
|
trace_id = _trace_id_from_context(context)
|
||||||
|
llm = get_llm()
|
||||||
|
tools = _all_tools_for_user(user_id, trace_id)
|
||||||
|
model_context = _context_for_model(context)
|
||||||
|
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
messages: list[Any] = [
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(
|
||||||
|
content=(
|
||||||
|
f"User message:\n{message}\n\n"
|
||||||
|
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
tool_calls_count = 0
|
||||||
|
collected: list[dict[str, Any]] = []
|
||||||
|
set_tool_result_collector(collected)
|
||||||
|
try:
|
||||||
|
for _ in range(max_steps):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
final_text = _as_text(response.content)
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
tool_calls_count,
|
||||||
|
len(final_text),
|
||||||
|
)
|
||||||
|
return final_text
|
||||||
|
|
||||||
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
for call in response.tool_calls:
|
||||||
|
tool_calls_count += 1
|
||||||
|
call_id = str(call.get("id", ""))
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
|
||||||
|
call_id,
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
|
||||||
|
call_id,
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:1200],
|
||||||
|
)
|
||||||
|
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
final_text = _as_text(final.content)
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
tool_calls_count,
|
||||||
|
len(final_text),
|
||||||
|
)
|
||||||
|
return final_text
|
||||||
|
finally:
|
||||||
|
clear_tool_result_collector()
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_single_agent_stream(
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
system_prompt: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
max_steps: int = 6,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
trace_id = _trace_id_from_context(context)
|
||||||
|
llm = get_llm()
|
||||||
|
tools = _all_tools_for_user(user_id, trace_id)
|
||||||
|
model_context = _context_for_model(context)
|
||||||
|
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
messages: list[Any] = [
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(
|
||||||
|
content=(
|
||||||
|
f"User message:\n{message}\n\n"
|
||||||
|
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
tool_calls_count = 0
|
||||||
|
streamed_chars = 0
|
||||||
|
collected: list[dict[str, Any]] = []
|
||||||
|
set_tool_result_collector(collected)
|
||||||
|
try:
|
||||||
|
for _ in range(max_steps):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
emitted_any = False
|
||||||
|
async for chunk in llm.astream(messages):
|
||||||
|
token = _as_text(getattr(chunk, "content", ""))
|
||||||
|
if token:
|
||||||
|
streamed_chars += len(token)
|
||||||
|
emitted_any = True
|
||||||
|
yield "token", token
|
||||||
|
|
||||||
|
# Some providers return final text in `response.content` but stream no chunks.
|
||||||
|
if not emitted_any:
|
||||||
|
fallback_text = _as_text(response.content)
|
||||||
|
if fallback_text:
|
||||||
|
streamed_chars += len(fallback_text)
|
||||||
|
yield "token", fallback_text
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
tool_calls_count,
|
||||||
|
streamed_chars,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
for call in response.tool_calls:
|
||||||
|
tool_calls_count += 1
|
||||||
|
call_id = str(call.get("id", ""))
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
|
||||||
|
call_id,
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
|
||||||
|
call_id,
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:1200],
|
||||||
|
)
|
||||||
|
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
async for chunk in llm.astream(messages):
|
||||||
|
token = _as_text(getattr(chunk, "content", ""))
|
||||||
|
if token:
|
||||||
|
streamed_chars += len(token)
|
||||||
|
yield "token", token
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
tool_calls_count,
|
||||||
|
streamed_chars,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_tool_result_collector()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
||||||
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
response = await _run_single_agent(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
|
||||||
|
message=message,
|
||||||
|
context=prepared_context,
|
||||||
|
)
|
||||||
|
return _normalize_tagged_list_lines(response, message)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]:
|
||||||
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
domain = await _infer_floating_domain(message, prepared_context)
|
||||||
|
response = await _run_single_agent(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
|
||||||
|
message=message,
|
||||||
|
context=prepared_context,
|
||||||
|
)
|
||||||
|
sanitized = _strip_floating_markup(response)
|
||||||
|
if not sanitized and response:
|
||||||
|
sanitized = _fallback_from_raw_floating_text(response)
|
||||||
|
return sanitized, domain
|
||||||
|
|
||||||
|
|
||||||
|
async def run_home_stream(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
text_chunks: list[str] = []
|
||||||
|
async for event in _run_single_agent_stream(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
|
||||||
|
message=message,
|
||||||
|
context=prepared_context,
|
||||||
|
):
|
||||||
|
event_type, data = event
|
||||||
|
if event_type != "token":
|
||||||
|
yield event
|
||||||
|
continue
|
||||||
|
text_chunks.append(str(data or ""))
|
||||||
|
|
||||||
|
normalized = _normalize_tagged_list_lines("".join(text_chunks), message)
|
||||||
|
if normalized:
|
||||||
|
yield "token", normalized
|
||||||
|
|
||||||
|
|
||||||
|
async def run_floating_stream(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
domain = await _infer_floating_domain(message, prepared_context)
|
||||||
|
yield "floating_domain", domain
|
||||||
|
|
||||||
|
sanitizer = _FloatingStreamSanitizer()
|
||||||
|
emitted_sanitized = False
|
||||||
|
raw_chunks: list[str] = []
|
||||||
|
async for event in _run_single_agent_stream(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
|
||||||
|
message=message,
|
||||||
|
context=prepared_context,
|
||||||
|
):
|
||||||
|
event_type, data = event
|
||||||
|
if event_type != "token":
|
||||||
|
yield event
|
||||||
|
continue
|
||||||
|
|
||||||
|
raw_chunk = str(data or "")
|
||||||
|
raw_chunks.append(raw_chunk)
|
||||||
|
sanitized_chunk = sanitizer.feed(raw_chunk)
|
||||||
|
if sanitized_chunk:
|
||||||
|
emitted_sanitized = True
|
||||||
|
yield "token", sanitized_chunk
|
||||||
|
|
||||||
|
tail = sanitizer.finalize()
|
||||||
|
if tail:
|
||||||
|
emitted_sanitized = True
|
||||||
|
yield "token", tail
|
||||||
|
|
||||||
|
if not emitted_sanitized and raw_chunks:
|
||||||
|
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
|
||||||
|
|
||||||
|
|
||||||
|
async def update_core_memory(user_id: str, key: str, value: str) -> None:
|
||||||
|
"""Compatibility helper kept for callers that expect explicit memory update API."""
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.update_core(user_id, key, value)
|
||||||
@@ -3,20 +3,15 @@
|
|||||||
Maintains in-memory state for all active Electron → backend WebSocket
|
Maintains in-memory state for all active Electron → backend WebSocket
|
||||||
connections. One connection per user (latest replaces previous).
|
connections. One connection per user (latest replaces previous).
|
||||||
|
|
||||||
The manager participates in two interaction patterns:
|
The manager handles the **tool-call round-trip** pattern:
|
||||||
|
- Backend sends ``tool_call`` frame → Electron executes the action →
|
||||||
1. **Tool-call round-trip** (bidirectional CRUD):
|
returns ``tool_result`` frame.
|
||||||
- Backend sends ``tool_call`` frame → Electron executes CRUD → returns
|
|
||||||
``tool_result`` frame.
|
|
||||||
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
||||||
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
||||||
receive the result dict from Electron.
|
receive the result dict from Electron.
|
||||||
|
|
||||||
2. **Agent-data streaming** (local directory agent runs):
|
This pattern is used by all tools (CRUD, file-system, etc.) via
|
||||||
- Backend sends ``agent_run`` frame → Electron reads files and sends
|
``execute_on_client()`` in ``ws_context.py``.
|
||||||
back a stream of ``agent_data`` frames followed by ``agent_complete``.
|
|
||||||
- ``get_agent_data_queue`` returns (or creates) an asyncio.Queue for
|
|
||||||
a specific ``run_id`` so the agent runner can iterate frames.
|
|
||||||
|
|
||||||
The ``device_manager`` module-level singleton is imported by both the
|
The ``device_manager`` module-level singleton is imported by both the
|
||||||
device WS route and the agent runner.
|
device WS route and the agent runner.
|
||||||
@@ -42,8 +37,6 @@ class DeviceConnection:
|
|||||||
device_id: str
|
device_id: str
|
||||||
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
||||||
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||||
# Per-run queues for agent_data / agent_complete frames.
|
|
||||||
agent_data_queues: dict[str, asyncio.Queue[dict | None]] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class DeviceConnectionManager:
|
class DeviceConnectionManager:
|
||||||
@@ -153,31 +146,6 @@ class DeviceConnectionManager:
|
|||||||
if fut is not None and not fut.done():
|
if fut is not None and not fut.done():
|
||||||
fut.set_result(result)
|
fut.set_result(result)
|
||||||
|
|
||||||
# ── Agent-data queue ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
def get_agent_data_queue(
|
|
||||||
self, user_id: str, run_id: str
|
|
||||||
) -> asyncio.Queue[dict | None]:
|
|
||||||
"""Return (creating if absent) the queue for *run_id* agent frames.
|
|
||||||
|
|
||||||
The agent runner reads from this queue. The device WS handler writes
|
|
||||||
to it. ``None`` is the sentinel that signals the stream is finished.
|
|
||||||
"""
|
|
||||||
conn = self._connections.get(user_id)
|
|
||||||
if conn is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"get_agent_data_queue: user {user_id!r} is not connected"
|
|
||||||
)
|
|
||||||
if run_id not in conn.agent_data_queues:
|
|
||||||
conn.agent_data_queues[run_id] = asyncio.Queue()
|
|
||||||
return conn.agent_data_queues[run_id]
|
|
||||||
|
|
||||||
def cleanup_agent_data_queue(self, user_id: str, run_id: str) -> None:
|
|
||||||
"""Remove the queue for *run_id* once a run has completed."""
|
|
||||||
conn = self._connections.get(user_id)
|
|
||||||
if conn:
|
|
||||||
conn.agent_data_queues.pop(run_id, None)
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton — import this everywhere.
|
# Module-level singleton — import this everywhere.
|
||||||
device_manager = DeviceConnectionManager()
|
device_manager = DeviceConnectionManager()
|
||||||
|
|||||||
@@ -1,222 +0,0 @@
|
|||||||
"""Execution Plan generator — builder, template registry, and LRU plan cache."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from app.schemas import ExecutionPlan, PlanStep
|
|
||||||
|
|
||||||
|
|
||||||
# ── Prompt Template Registry ──────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplateRegistry:
|
|
||||||
"""Server-side store mapping template IDs to prompt text.
|
|
||||||
|
|
||||||
Clients only ever receive template IDs (e.g. ``"tpl_task_agent_default"``).
|
|
||||||
The actual prompt text is resolved here on the server, keeping prompt IP
|
|
||||||
out of API responses.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._templates: dict[str, str] = {}
|
|
||||||
|
|
||||||
def register(self, template_id: str, prompt_text: str) -> None:
|
|
||||||
self._templates[template_id] = prompt_text
|
|
||||||
|
|
||||||
def get(self, template_id: str) -> str:
|
|
||||||
"""Resolve a template ID to its prompt text.
|
|
||||||
|
|
||||||
Raises ``KeyError`` if the template is not registered.
|
|
||||||
"""
|
|
||||||
text = self._templates.get(template_id)
|
|
||||||
if text is None:
|
|
||||||
raise KeyError(f"Template not found: {template_id!r}")
|
|
||||||
return text
|
|
||||||
|
|
||||||
def has(self, template_id: str) -> bool:
|
|
||||||
return template_id in self._templates
|
|
||||||
|
|
||||||
def list_ids(self) -> list[str]:
|
|
||||||
"""Return all registered template IDs (never the text)."""
|
|
||||||
return list(self._templates.keys())
|
|
||||||
|
|
||||||
|
|
||||||
# ── Execution Plan Builder ────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionPlanBuilder:
|
|
||||||
"""Fluent builder for ``ExecutionPlan`` objects.
|
|
||||||
|
|
||||||
Example::
|
|
||||||
|
|
||||||
plan = (
|
|
||||||
ExecutionPlanBuilder("task_agent")
|
|
||||||
.add_llm_step("tpl_task_agent_default", {"message": user_msg})
|
|
||||||
.add_data_step("create_record", data_from_step=0)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, agent: str) -> None:
|
|
||||||
self._agent = agent
|
|
||||||
self._steps: list[PlanStep] = []
|
|
||||||
|
|
||||||
# ── step adders ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def add_step(
|
|
||||||
self, action: str, params: dict[str, Any] | None = None
|
|
||||||
) -> ExecutionPlanBuilder:
|
|
||||||
"""Append a generic action step with optional parameters."""
|
|
||||||
self._steps.append(PlanStep(action=action, variables=params))
|
|
||||||
return self
|
|
||||||
|
|
||||||
def add_llm_step(
|
|
||||||
self, template_id: str, variables: dict[str, Any] | None = None
|
|
||||||
) -> ExecutionPlanBuilder:
|
|
||||||
"""Append an LLM step referencing a server-side template by ID."""
|
|
||||||
self._steps.append(
|
|
||||||
PlanStep(action="llm", prompt_template=template_id, variables=variables)
|
|
||||||
)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def add_data_step(self, action: str, data_from_step: int) -> ExecutionPlanBuilder:
|
|
||||||
"""Append a step whose input comes from the output of an earlier step."""
|
|
||||||
self._steps.append(PlanStep(action=action, data_from_step=data_from_step))
|
|
||||||
return self
|
|
||||||
|
|
||||||
# ── build ────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def build(self) -> ExecutionPlan:
|
|
||||||
"""Validate step references and return the ``ExecutionPlan``.
|
|
||||||
|
|
||||||
Raises ``ValueError`` if any ``data_from_step`` references a
|
|
||||||
non-existent or future step index.
|
|
||||||
"""
|
|
||||||
for i, step in enumerate(self._steps):
|
|
||||||
if step.data_from_step is not None:
|
|
||||||
if not (0 <= step.data_from_step < i):
|
|
||||||
raise ValueError(
|
|
||||||
f"Step {i}: data_from_step={step.data_from_step} must "
|
|
||||||
f"reference a preceding step index in range 0..{i - 1}"
|
|
||||||
)
|
|
||||||
return ExecutionPlan(agent=self._agent, steps=list(self._steps))
|
|
||||||
|
|
||||||
|
|
||||||
# ── Plan Cache (LRU) ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class PlanCache:
|
|
||||||
"""In-memory LRU cache for ``ExecutionPlan`` objects.
|
|
||||||
|
|
||||||
Plans stored here are accessible as playbooks via ``get_all_playbooks()``.
|
|
||||||
The cache also serves as a runtime memoisation layer so that repeated
|
|
||||||
identical intent classifications can skip re-building the plan.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, maxsize: int = 1000) -> None:
|
|
||||||
self._maxsize = maxsize
|
|
||||||
self._cache: OrderedDict[str, ExecutionPlan] = OrderedDict()
|
|
||||||
|
|
||||||
def cache_plan(self, key: str, plan: ExecutionPlan) -> None:
|
|
||||||
"""Store *plan* under *key*, evicting the LRU entry if at capacity."""
|
|
||||||
if key in self._cache:
|
|
||||||
del self._cache[key] # remove so re-insertion places it at the end
|
|
||||||
elif len(self._cache) >= self._maxsize:
|
|
||||||
self._cache.popitem(last=False) # evict least-recently-used
|
|
||||||
self._cache[key] = plan
|
|
||||||
|
|
||||||
def get_plan(self, key: str) -> ExecutionPlan | None:
|
|
||||||
"""Return the cached plan for *key*, or ``None`` if not present.
|
|
||||||
|
|
||||||
Accessing a plan marks it as most-recently used.
|
|
||||||
"""
|
|
||||||
if key not in self._cache:
|
|
||||||
return None
|
|
||||||
self._cache.move_to_end(key)
|
|
||||||
return self._cache[key]
|
|
||||||
|
|
||||||
def get_all_playbooks(self) -> list[ExecutionPlan]:
|
|
||||||
"""Return all cached plans (most-recently used last)."""
|
|
||||||
return list(self._cache.values())
|
|
||||||
|
|
||||||
|
|
||||||
# ── Module-level singletons ───────────────────────────────────────────
|
|
||||||
|
|
||||||
template_registry = PromptTemplateRegistry()
|
|
||||||
plan_cache = PlanCache()
|
|
||||||
|
|
||||||
|
|
||||||
def _register_builtin_templates() -> None:
|
|
||||||
"""Register the built-in server-side prompt templates.
|
|
||||||
|
|
||||||
These strings never leave the server. Clients only receive the IDs.
|
|
||||||
"""
|
|
||||||
_tpls: dict[str, str] = {
|
|
||||||
"tpl_task_agent_default": (
|
|
||||||
"You are a task management assistant. Help the user create, update, "
|
|
||||||
"list, and track tasks. Use correct status values (todo, in_progress, "
|
|
||||||
"done) and priority values (high, medium, low) from the workspace model."
|
|
||||||
),
|
|
||||||
"tpl_timeline_agent_default": (
|
|
||||||
"You are a project timeline assistant. Help the user create and manage "
|
|
||||||
"milestone timelines on their projects. Every timeline requires a "
|
|
||||||
"project_id and a date expressed as a Unix timestamp in milliseconds."
|
|
||||||
),
|
|
||||||
"tpl_project_agent_default": (
|
|
||||||
"You are a project management assistant. Help the user create, find, "
|
|
||||||
"update, and archive projects. Projects have a name, an optional client, "
|
|
||||||
"and a status of either active or archived."
|
|
||||||
),
|
|
||||||
"tpl_note_agent_default": (
|
|
||||||
"You are a note-taking assistant. Help the user create, retrieve, update, "
|
|
||||||
"and delete Markdown notes. Notes can optionally be linked to a project."
|
|
||||||
),
|
|
||||||
"tpl_task_extract_from_project": (
|
|
||||||
"Extract all actionable tasks from the provided project context. "
|
|
||||||
"Return a structured list of tasks, each with a title, inferred priority "
|
|
||||||
"(high, medium, or low), suggested status (todo), and a due_date in "
|
|
||||||
"milliseconds where a deadline can be inferred."
|
|
||||||
),
|
|
||||||
"tpl_note_weekly_summary": (
|
|
||||||
"Generate a weekly project summary note from the provided workspace data. "
|
|
||||||
"Include: tasks completed this week, tasks due soon, active projects, "
|
|
||||||
"and upcoming timelines. Format the output as clean Markdown."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
for tid, text in _tpls.items():
|
|
||||||
template_registry.register(tid, text)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_playbooks() -> None:
|
|
||||||
"""Pre-build and cache the built-in playbooks."""
|
|
||||||
playbooks: list[tuple[str, ExecutionPlan]] = [
|
|
||||||
(
|
|
||||||
"create_tasks_from_project",
|
|
||||||
ExecutionPlanBuilder("project_agent")
|
|
||||||
.add_llm_step(
|
|
||||||
"tpl_task_extract_from_project",
|
|
||||||
{"source": "project_context"},
|
|
||||||
)
|
|
||||||
.add_data_step("create_record", data_from_step=0)
|
|
||||||
.build(),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"generate_weekly_note",
|
|
||||||
ExecutionPlanBuilder("note_agent")
|
|
||||||
.add_llm_step(
|
|
||||||
"tpl_note_weekly_summary",
|
|
||||||
{"period": "last_7_days"},
|
|
||||||
)
|
|
||||||
.add_data_step("create_record", data_from_step=0)
|
|
||||||
.build(),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
for key, plan in playbooks:
|
|
||||||
plan_cache.cache_plan(key, plan)
|
|
||||||
|
|
||||||
|
|
||||||
# Initialise on module load
|
|
||||||
_register_builtin_templates()
|
|
||||||
_load_playbooks()
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
"""LLM factory — centralised model instantiation via LiteLLM.
|
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||||
|
|
||||||
Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()``
|
Every agent and the orchestrator call ``get_llm()``
|
||||||
instead of directly constructing a provider-specific class. The model string
|
instead of directly constructing a provider-specific class. The model string
|
||||||
follows the `LiteLLM model naming convention
|
follows the `LiteLLM model naming convention
|
||||||
<https://docs.litellm.ai/docs/providers>`_:
|
<https://docs.litellm.ai/docs/providers>`_:
|
||||||
@@ -11,13 +11,14 @@ follows the `LiteLLM model naming convention
|
|||||||
* Ollama: ``ollama/llama3``
|
* Ollama: ``ollama/llama3``
|
||||||
* Bedrock: ``bedrock/anthropic.claude-v2``
|
* Bedrock: ``bedrock/anthropic.claude-v2``
|
||||||
|
|
||||||
Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
|
Switch providers by changing **LLM_MODEL** in ``.env``
|
||||||
— no code changes required.
|
— no code changes required.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
import litellm
|
import litellm
|
||||||
@@ -32,6 +33,14 @@ from app.config.settings import settings
|
|||||||
# Drop them silently instead of raising UnsupportedParamsError.
|
# Drop them silently instead of raising UnsupportedParamsError.
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
|
|
||||||
|
# Some provider responses include a plain dict in the `usage` field where a
|
||||||
|
# richer Pydantic model is expected. This warning is noisy but non-fatal.
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
||||||
|
category=UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _api_key_for_model(model: str) -> str | None:
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
"""Return the most appropriate API key for the given LiteLLM model string."""
|
"""Return the most appropriate API key for the given LiteLLM model string."""
|
||||||
@@ -41,6 +50,8 @@ def _api_key_for_model(model: str) -> str | None:
|
|||||||
return settings.GOOGLE_API_KEY or None
|
return settings.GOOGLE_API_KEY or None
|
||||||
if model.startswith("cerebras/"):
|
if model.startswith("cerebras/"):
|
||||||
return settings.CEREBRAS_API_KEY or None
|
return settings.CEREBRAS_API_KEY or None
|
||||||
|
if model.startswith("github/"):
|
||||||
|
return settings.GITHUB_TOKEN or None
|
||||||
if model.startswith("github_copilot/"):
|
if model.startswith("github_copilot/"):
|
||||||
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
|
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
|
||||||
# No API key is required; returning None lets LiteLLM handle auth.
|
# No API key is required; returning None lets LiteLLM handle auth.
|
||||||
@@ -74,6 +85,9 @@ def get_llm(
|
|||||||
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
||||||
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
||||||
|
|
||||||
|
if settings.GITHUB_TOKEN:
|
||||||
|
os.environ.setdefault("GITHUB_TOKEN", settings.GITHUB_TOKEN)
|
||||||
|
|
||||||
# Use ChatLiteLLM for provider-prefixed models (github_copilot/, anthropic/, etc.)
|
# Use ChatLiteLLM for provider-prefixed models (github_copilot/, anthropic/, etc.)
|
||||||
# so LiteLLM handles routing and auth. ChatOpenAI for plain OpenAI model names.
|
# so LiteLLM handles routing and auth. ChatOpenAI for plain OpenAI model names.
|
||||||
if "/" in model:
|
if "/" in model:
|
||||||
@@ -86,14 +100,6 @@ def get_llm(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_router_llm(
|
|
||||||
*,
|
|
||||||
temperature: float = 0,
|
|
||||||
) -> ChatOpenAI | ChatLiteLLM:
|
|
||||||
"""Return the lighter model used for intent classification / routing."""
|
|
||||||
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
|
|
||||||
|
|
||||||
|
|
||||||
async def embed(text: str) -> list[float]:
|
async def embed(text: str) -> list[float]:
|
||||||
"""Return an embedding vector for *text*.
|
"""Return an embedding vector for *text*.
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,13 @@ class MemoryMiddleware:
|
|||||||
|
|
||||||
# ── Public API ────────────────────────────────────────────────────────────
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
|
async def enrich_context(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
trace_id: str | None = None,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
"""Build memory context dict to inject into the orchestrator before LLM call.
|
"""Build memory context dict to inject into the orchestrator before LLM call.
|
||||||
|
|
||||||
Returns a dict with keys:
|
Returns a dict with keys:
|
||||||
@@ -65,9 +71,21 @@ class MemoryMiddleware:
|
|||||||
|
|
||||||
core = await self._load_core(user_id, fernet)
|
core = await self._load_core(user_id, fernet)
|
||||||
associative = await self._load_associative(user_id, message, fernet)
|
associative = await self._load_associative(user_id, message, fernet)
|
||||||
episodic = await self._load_episodic(user_id, fernet)
|
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
||||||
proactive = await self._load_proactive(user_id, fernet)
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
logger.info(
|
||||||
|
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
user_dbg.get("tier") or "-",
|
||||||
|
len(core),
|
||||||
|
len(associative),
|
||||||
|
len(episodic),
|
||||||
|
len(proactive),
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"core_memory": core,
|
"core_memory": core,
|
||||||
"associative_memory": associative,
|
"associative_memory": associative,
|
||||||
@@ -81,6 +99,7 @@ class MemoryMiddleware:
|
|||||||
session_id: str,
|
session_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
response: str,
|
response: str,
|
||||||
|
trace_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Summarise and store a completed interaction in episodic memory.
|
"""Summarise and store a completed interaction in episodic memory.
|
||||||
|
|
||||||
@@ -103,11 +122,19 @@ class MemoryMiddleware:
|
|||||||
self._db.add(row)
|
self._db.add(row)
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
logger.info(
|
||||||
|
"memory: store_episode trace=%s user=%s tier=%s session=%s",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
user_dbg.get("tier") or "-",
|
||||||
|
session_id,
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
|
||||||
async def update_core(self, user_id: str, key: str, value: str) -> None:
|
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
||||||
"""Upsert a core memory key/value for a user."""
|
"""Upsert a core memory key/value for a user."""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
if fernet is None:
|
if fernet is None:
|
||||||
@@ -133,10 +160,176 @@ class MemoryMiddleware:
|
|||||||
))
|
))
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
logger.info(
|
||||||
|
"memory: update_core trace=%s user=%s tier=%s key=%s",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
user_dbg.get("tier") or "-",
|
||||||
|
key,
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
|
||||||
|
"""Return core memory as editable blocks (label/value)."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore)
|
||||||
|
.where(MemoryCore.user_id == user_id)
|
||||||
|
.order_by(MemoryCore.key.asc())
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[dict[str, str]] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append({"label": row.key, "value": plaintext})
|
||||||
|
logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def get_core_block(self, user_id: str, label: str) -> str | None:
|
||||||
|
"""Return a single core memory block value by label."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(
|
||||||
|
MemoryCore.user_id == user_id,
|
||||||
|
MemoryCore.key == label,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
|
||||||
|
return None
|
||||||
|
value = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0)
|
||||||
|
return value
|
||||||
|
|
||||||
|
async def delete_core(self, user_id: str, label: str) -> bool:
|
||||||
|
"""Delete a core memory block by label. Returns True if deleted."""
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(
|
||||||
|
MemoryCore.user_id == user_id,
|
||||||
|
MemoryCore.key == label,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
|
||||||
|
return False
|
||||||
|
|
||||||
|
await self._db.delete(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info("memory: delete_core user=%s label=%s", user_id, label)
|
||||||
|
return True
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def append_core(self, user_id: str, label: str, content: str) -> None:
|
||||||
|
"""Append content to a core block, creating it if missing."""
|
||||||
|
current = await self.get_core_block(user_id, label)
|
||||||
|
if current is None:
|
||||||
|
await self.update_core(user_id, label, content)
|
||||||
|
logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
|
||||||
|
return
|
||||||
|
await self.update_core(user_id, label, f"{current}\n{content}")
|
||||||
|
logger.info("memory: append_core user=%s label=%s created=0", user_id, label)
|
||||||
|
|
||||||
|
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
|
||||||
|
"""Replace one exact string inside a core block. Returns False if not found."""
|
||||||
|
current = await self.get_core_block(user_id, label)
|
||||||
|
if current is None or old not in current:
|
||||||
|
logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
|
||||||
|
return False
|
||||||
|
await self.update_core(user_id, label, current.replace(old, new, 1))
|
||||||
|
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
||||||
|
"""Insert a long-term archival memory entry."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
encrypted = _encrypt(fernet, content)
|
||||||
|
row = MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
content_encrypted=encrypted,
|
||||||
|
embedding=None,
|
||||||
|
entity_type=source,
|
||||||
|
entity_id=None,
|
||||||
|
)
|
||||||
|
self._db.add(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info("memory: insert_archival user=%s source=%s", user_id, source)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||||
|
"""Search archival memory (keyword fallback; semantic ranking can replace this)."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryAssociative)
|
||||||
|
.where(MemoryAssociative.user_id == user_id)
|
||||||
|
.order_by(MemoryAssociative.updated_at.desc())
|
||||||
|
.limit(100)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
needle = query.strip().lower()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
|
if plaintext is None:
|
||||||
|
continue
|
||||||
|
if not needle or needle in plaintext.lower():
|
||||||
|
out.append(plaintext)
|
||||||
|
if len(out) >= max(top_k, 1):
|
||||||
|
break
|
||||||
|
logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||||
|
"""Search recall memory (episodic summaries) by keyword."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryEpisodic)
|
||||||
|
.where(MemoryEpisodic.user_id == user_id)
|
||||||
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
|
.limit(100)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
needle = query.strip().lower()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||||
|
if plaintext is None:
|
||||||
|
continue
|
||||||
|
if not needle or needle in plaintext.lower():
|
||||||
|
out.append(plaintext)
|
||||||
|
if len(out) >= max(top_k, 1):
|
||||||
|
break
|
||||||
|
logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
||||||
|
return out
|
||||||
|
|
||||||
# ── Private helpers ───────────────────────────────────────────────────────
|
# ── Private helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
||||||
@@ -148,6 +341,16 @@ class MemoryMiddleware:
|
|||||||
return None
|
return None
|
||||||
return Fernet(user.encryption_key.encode())
|
return Fernet(user.encryption_key.encode())
|
||||||
|
|
||||||
|
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
|
||||||
|
"""Load lightweight user debug fields for trace logs."""
|
||||||
|
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None:
|
||||||
|
return {"tier": None}
|
||||||
|
return {
|
||||||
|
"tier": user.tier,
|
||||||
|
}
|
||||||
|
|
||||||
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
||||||
@@ -183,10 +386,17 @@ class MemoryMiddleware:
|
|||||||
out.append(plaintext)
|
out.append(plaintext)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
|
async def _load_episodic(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
fernet: Fernet,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> list[str]:
|
||||||
|
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
|
||||||
|
if session_id:
|
||||||
|
query = query.where(MemoryEpisodic.session_id == session_id)
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryEpisodic)
|
query
|
||||||
.where(MemoryEpisodic.user_id == user_id)
|
|
||||||
.order_by(MemoryEpisodic.created_at.desc())
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
.limit(_EPISODIC_RECENT_N)
|
.limit(_EPISODIC_RECENT_N)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,210 +0,0 @@
|
|||||||
"""Orchestrator — LLM-based intent router and agent pipeline."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any, AsyncGenerator
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
|
|
||||||
from app.core.agent_registry import AgentRegistry, ChatAgent
|
|
||||||
from app.core.llm import get_router_llm
|
|
||||||
from app.core.agent_registry import registry as _default_registry
|
|
||||||
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
|
|
||||||
|
|
||||||
_FALLBACK_AGENT = "task_agent"
|
|
||||||
|
|
||||||
_CLASSIFY_SYSTEM = (
|
|
||||||
"You are an intent classifier. Given the user message and context, decide "
|
|
||||||
"which agent to route to.\n"
|
|
||||||
"Available agents: {agents}\n"
|
|
||||||
"Respond with just the agent name, nothing else."
|
|
||||||
)
|
|
||||||
|
|
||||||
_SYNTHESIZE_HUMAN = (
|
|
||||||
"Combine the following agent results into one coherent response.\n\n"
|
|
||||||
"Agent results:\n{results}\n\n"
|
|
||||||
"Original message: {message}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_llm():
|
|
||||||
return get_router_llm()
|
|
||||||
|
|
||||||
|
|
||||||
async def classify_intent(
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry,
|
|
||||||
) -> str:
|
|
||||||
"""Use gpt-4o-mini to classify intent and return the matching agent name.
|
|
||||||
|
|
||||||
Falls back to ``task_agent`` when the registry is empty or the model
|
|
||||||
returns a name that is not registered.
|
|
||||||
"""
|
|
||||||
agents = reg.list_agents()
|
|
||||||
if not agents:
|
|
||||||
return _FALLBACK_AGENT
|
|
||||||
|
|
||||||
system = _CLASSIFY_SYSTEM.format(agents=json.dumps(agents))
|
|
||||||
# Truncate context to keep the classification prompt short
|
|
||||||
human = f"Message: {message}\nContext summary: {json.dumps(context)[:500]}"
|
|
||||||
|
|
||||||
llm = _make_llm()
|
|
||||||
response = await llm.ainvoke(
|
|
||||||
[SystemMessage(content=system), HumanMessage(content=human)]
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_name = str(response.content).strip().lower()
|
|
||||||
known = {a["name"] for a in agents}
|
|
||||||
return agent_name if agent_name in known else _FALLBACK_AGENT
|
|
||||||
|
|
||||||
|
|
||||||
async def route_single(
|
|
||||||
agent_name: str,
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry,
|
|
||||||
) -> ChatResponse:
|
|
||||||
"""Route to a single agent and wrap the result in a ``ChatResponse``."""
|
|
||||||
response_text = await reg.call_agent(agent_name, message, context)
|
|
||||||
return ChatResponse(response=response_text)
|
|
||||||
|
|
||||||
|
|
||||||
async def route_pipeline(
|
|
||||||
agent_names: list[str],
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry,
|
|
||||||
) -> ChatResponse:
|
|
||||||
"""Execute agents sequentially; each agent receives previous results in context.
|
|
||||||
|
|
||||||
A final LLM synthesis call merges all results into one coherent response.
|
|
||||||
"""
|
|
||||||
previous_results: list[str] = []
|
|
||||||
|
|
||||||
for agent_name in agent_names:
|
|
||||||
ctx = {**context, "previous_results": list(previous_results)}
|
|
||||||
result = await reg.call_agent(agent_name, message, ctx)
|
|
||||||
previous_results.append(result)
|
|
||||||
|
|
||||||
results_str = "\n\n".join(
|
|
||||||
f"[{name}]: {res}" for name, res in zip(agent_names, previous_results)
|
|
||||||
)
|
|
||||||
human = _SYNTHESIZE_HUMAN.format(results=results_str, message=message)
|
|
||||||
llm = _make_llm()
|
|
||||||
synthesis = await llm.ainvoke([HumanMessage(content=human)])
|
|
||||||
return ChatResponse(response=str(synthesis.content))
|
|
||||||
|
|
||||||
|
|
||||||
def _build_plan(agent_name: str, message: str) -> ExecutionPlan:
|
|
||||||
"""Build an ``ExecutionPlan`` for the resolved agent.
|
|
||||||
|
|
||||||
Uses ``ExecutionPlanBuilder`` with the server-side template registry.
|
|
||||||
If a default template exists for the agent, an LLM step is emitted;
|
|
||||||
otherwise a plain ``handle`` action step is used.
|
|
||||||
"""
|
|
||||||
from app.core.execution_plan import ExecutionPlanBuilder, template_registry
|
|
||||||
|
|
||||||
template_id = f"tpl_{agent_name}_default"
|
|
||||||
builder = ExecutionPlanBuilder(agent_name)
|
|
||||||
if template_registry.has(template_id):
|
|
||||||
builder.add_llm_step(template_id, {"message": message})
|
|
||||||
else:
|
|
||||||
builder.add_step("handle", {"message": message})
|
|
||||||
return builder.build()
|
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate(
|
|
||||||
request: ChatRequest,
|
|
||||||
reg: AgentRegistry | None = None,
|
|
||||||
) -> ChatResponse | ExecutionPlan:
|
|
||||||
"""Main orchestration entry point.
|
|
||||||
|
|
||||||
* Classifies the user's intent to select an agent.
|
|
||||||
* ``execution_mode == 'direct'``: routes to the agent and returns a
|
|
||||||
``ChatResponse``.
|
|
||||||
* ``execution_mode == 'plan'``: returns an ``ExecutionPlan`` with the
|
|
||||||
resolved agent and a template-ID-only step (prompt IP stays server-side).
|
|
||||||
"""
|
|
||||||
if reg is None:
|
|
||||||
reg = _default_registry
|
|
||||||
|
|
||||||
context = request.context.model_dump()
|
|
||||||
agent_name = await classify_intent(request.message, context, reg)
|
|
||||||
|
|
||||||
if request.execution_mode == "direct":
|
|
||||||
return await route_single(agent_name, request.message, context, reg)
|
|
||||||
|
|
||||||
# plan mode — return plan, do not execute
|
|
||||||
return _build_plan(agent_name, request.message)
|
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate_v3(
|
|
||||||
user_id: str,
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry | None = None,
|
|
||||||
) -> tuple[str, ChatAgent]:
|
|
||||||
"""v3 orchestration — returns (agent_name, agent_instance); caller drives execution.
|
|
||||||
|
|
||||||
Classifies intent and instantiates the matching agent. The caller is responsible
|
|
||||||
for invoking handle(), handle_stream(), or _tool_loop_stream() as needed.
|
|
||||||
"""
|
|
||||||
if reg is None:
|
|
||||||
reg = _default_registry
|
|
||||||
agent_name = await classify_intent(message, context, reg)
|
|
||||||
return agent_name, reg.get(agent_name)
|
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate_v3_stream(
|
|
||||||
user_id: str,
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry | None = None,
|
|
||||||
agent_holder: list | None = None,
|
|
||||||
) -> AsyncGenerator[tuple[str, str], None]:
|
|
||||||
"""v3 streaming orchestration — yields (agent_name, token) pairs.
|
|
||||||
|
|
||||||
The first yield always carries the agent_name with an empty token so that
|
|
||||||
callers (e.g. FloatingFormatter) can detect the routing domain before any text
|
|
||||||
tokens arrive.
|
|
||||||
|
|
||||||
If *agent_holder* is provided (a list), the agent instance is appended so
|
|
||||||
callers can access ``agent.tool_results`` after the stream completes.
|
|
||||||
"""
|
|
||||||
if reg is None:
|
|
||||||
reg = _default_registry
|
|
||||||
agent_name = await classify_intent(message, context, reg)
|
|
||||||
agent = reg.get(agent_name)
|
|
||||||
if agent_holder is not None:
|
|
||||||
agent_holder.append(agent)
|
|
||||||
yield agent_name, "" # domain signal — no token yet
|
|
||||||
async for token in agent.handle_stream(message, context):
|
|
||||||
yield agent_name, token
|
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate_stream(
|
|
||||||
request: ChatRequest,
|
|
||||||
reg: AgentRegistry | None = None,
|
|
||||||
) -> AsyncGenerator[str, None]:
|
|
||||||
"""Streaming orchestration — yields plain text chunks only.
|
|
||||||
|
|
||||||
The WebSocket handler in ``app/api/routes/chat.py`` is responsible for
|
|
||||||
wrapping each chunk in a ``text_chunk`` frame and sending the final
|
|
||||||
``final`` frame once the generator is exhausted.
|
|
||||||
|
|
||||||
Agents do not yet support token-level streaming; the full response is
|
|
||||||
fetched first (which may involve multiple WS round-trips for tool calls),
|
|
||||||
then emitted in fixed-size chunks.
|
|
||||||
"""
|
|
||||||
if reg is None:
|
|
||||||
reg = _default_registry
|
|
||||||
|
|
||||||
context = request.context.model_dump()
|
|
||||||
agent_name = await classify_intent(request.message, context, reg)
|
|
||||||
response_text = await reg.call_agent(agent_name, request.message, context)
|
|
||||||
|
|
||||||
chunk_size = 50
|
|
||||||
for i in range(0, len(response_text), chunk_size):
|
|
||||||
yield response_text[i : i + chunk_size]
|
|
||||||
@@ -1,244 +1,47 @@
|
|||||||
"""Output Formatter — transforms orchestrator token streams into WS frame sequences.
|
"""Output formatter for deep-agent stream events."""
|
||||||
|
|
||||||
HomeFormatter: produces stream_start, stream_text / stream_block, stream_end
|
|
||||||
FloatingFormatter: produces floating_domain, stream_text, stream_end
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.schemas import (
|
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
||||||
WsFloatingDomain,
|
|
||||||
WsStreamBlock,
|
|
||||||
WsStreamEnd,
|
|
||||||
WsStreamStart,
|
|
||||||
WsStreamText,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||||
|
|
||||||
# Valid chart types (matching shadcn/ui Recharts wrappers in Electron)
|
|
||||||
_VALID_CHART_TYPES = {"area", "bar", "line", "pie", "radar", "radial"}
|
|
||||||
|
|
||||||
# Map agent name → floating domain
|
|
||||||
_AGENT_DOMAIN: dict[str, str] = {
|
|
||||||
"task_agent": "tasks",
|
|
||||||
"timeline_agent": "timelines",
|
|
||||||
"note_agent": "notes",
|
|
||||||
"project_agent": "projects",
|
|
||||||
}
|
|
||||||
|
|
||||||
WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatingDomain
|
|
||||||
|
|
||||||
|
|
||||||
class HomeFormatter:
|
class StreamFormatter:
|
||||||
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
|
"""Convert `(event_type, data)` stream events into websocket frame models."""
|
||||||
|
|
||||||
The LLM is expected to output a newline-delimited sequence of JSON objects,
|
|
||||||
each with a ``type`` field:
|
|
||||||
- ``text`` → yields WsStreamText immediately (word-by-word)
|
|
||||||
- ``chart`` → buffers full JSON, validates, yields WsStreamBlock
|
|
||||||
- ``entity_ref`` → resolves from tool_results, yields WsStreamBlock
|
|
||||||
- ``table`` → buffers full JSON, validates, yields WsStreamBlock
|
|
||||||
- ``timeline`` → buffers full JSON, validates, yields WsStreamBlock
|
|
||||||
|
|
||||||
Invalid or unknown blocks are logged and skipped — stream never crashes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, request_id: str, tool_results: list[dict]) -> None:
|
|
||||||
self.request_id = request_id
|
|
||||||
self.tool_results = tool_results
|
|
||||||
|
|
||||||
async def format(
|
|
||||||
self,
|
|
||||||
token_stream: AsyncGenerator[tuple[str, str], None],
|
|
||||||
) -> AsyncGenerator[WsFrame, None]:
|
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
|
||||||
|
|
||||||
buffer = ""
|
|
||||||
async for _agent_name, token in token_stream:
|
|
||||||
if not token:
|
|
||||||
continue
|
|
||||||
buffer += token
|
|
||||||
# Flush any complete JSON objects from the buffer
|
|
||||||
async for frame in self._flush_complete_objects(buffer):
|
|
||||||
buffer = "" # reset after flush
|
|
||||||
yield frame
|
|
||||||
break # only one flush per iteration; rest accumulates
|
|
||||||
|
|
||||||
# Flush any remaining content
|
|
||||||
if buffer.strip():
|
|
||||||
async for frame in self._flush_complete_objects(buffer, final=True):
|
|
||||||
yield frame
|
|
||||||
|
|
||||||
yield WsStreamEnd(request_id=self.request_id)
|
|
||||||
|
|
||||||
async def _flush_complete_objects(
|
|
||||||
self, text: str, final: bool = False
|
|
||||||
) -> AsyncGenerator[WsFrame, None]:
|
|
||||||
"""Try to parse and yield all complete JSON objects from *text*.
|
|
||||||
|
|
||||||
Yields nothing if text is incomplete JSON (unless *final* is True,
|
|
||||||
in which case remaining text is emitted as plain stream_text).
|
|
||||||
"""
|
|
||||||
remaining = text.strip()
|
|
||||||
while remaining:
|
|
||||||
# Fast path: plain text (not JSON)
|
|
||||||
if not remaining.startswith("{"):
|
|
||||||
# Yield as plain text chunk
|
|
||||||
newline_idx = remaining.find("\n")
|
|
||||||
if newline_idx == -1:
|
|
||||||
if final:
|
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=remaining)
|
|
||||||
remaining = ""
|
|
||||||
else:
|
|
||||||
return # accumulate more
|
|
||||||
else:
|
|
||||||
line = remaining[:newline_idx].strip()
|
|
||||||
remaining = remaining[newline_idx + 1:].strip()
|
|
||||||
if line:
|
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=line)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Try to decode a JSON object
|
|
||||||
try:
|
|
||||||
obj, end_idx = _try_parse_json(remaining)
|
|
||||||
except ValueError:
|
|
||||||
if final:
|
|
||||||
# Emit as raw text if we can't parse
|
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=remaining)
|
|
||||||
remaining = ""
|
|
||||||
return
|
|
||||||
|
|
||||||
if obj is None:
|
|
||||||
if final:
|
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=remaining)
|
|
||||||
remaining = ""
|
|
||||||
return # incomplete — need more tokens
|
|
||||||
|
|
||||||
remaining = remaining[end_idx:].strip()
|
|
||||||
block_type = obj.get("type")
|
|
||||||
|
|
||||||
frame = self._dispatch_block(obj, block_type)
|
|
||||||
if frame is not None:
|
|
||||||
yield frame
|
|
||||||
|
|
||||||
def _dispatch_block(self, obj: dict, block_type: str | None) -> WsFrame | None:
|
|
||||||
if block_type == "text":
|
|
||||||
content = obj.get("content", "")
|
|
||||||
if content:
|
|
||||||
return WsStreamText(request_id=self.request_id, chunk=str(content))
|
|
||||||
return None
|
|
||||||
|
|
||||||
if block_type == "chart":
|
|
||||||
chart_type = obj.get("chartType")
|
|
||||||
if chart_type not in _VALID_CHART_TYPES:
|
|
||||||
logger.warning("HomeFormatter: invalid chartType=%r — skipping", chart_type)
|
|
||||||
return None
|
|
||||||
if not isinstance(obj.get("data"), list):
|
|
||||||
logger.warning("HomeFormatter: chart missing data array — skipping")
|
|
||||||
return None
|
|
||||||
return WsStreamBlock(
|
|
||||||
request_id=self.request_id,
|
|
||||||
block_type="chart",
|
|
||||||
data=obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
if block_type == "entity_ref":
|
|
||||||
entity = obj.get("entity")
|
|
||||||
resolved = self._resolve_entity(entity)
|
|
||||||
if resolved is None:
|
|
||||||
logger.warning("HomeFormatter: entity_ref %r not found in tool_results — skipping", entity)
|
|
||||||
return None
|
|
||||||
return WsStreamBlock(
|
|
||||||
request_id=self.request_id,
|
|
||||||
block_type="entity_ref",
|
|
||||||
data={"entity": entity, "items": resolved},
|
|
||||||
)
|
|
||||||
|
|
||||||
if block_type == "table":
|
|
||||||
if not isinstance(obj.get("headers"), list) or not isinstance(obj.get("rows"), list):
|
|
||||||
logger.warning("HomeFormatter: table missing headers/rows — skipping")
|
|
||||||
return None
|
|
||||||
return WsStreamBlock(
|
|
||||||
request_id=self.request_id,
|
|
||||||
block_type="table",
|
|
||||||
data=obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
if block_type == "timeline":
|
|
||||||
if not isinstance(obj.get("timelines"), list):
|
|
||||||
logger.warning("HomeFormatter: timeline missing timelines — skipping")
|
|
||||||
return None
|
|
||||||
return WsStreamBlock(
|
|
||||||
request_id=self.request_id,
|
|
||||||
block_type="timeline",
|
|
||||||
data=obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.warning("HomeFormatter: unknown block type=%r — skipping", block_type)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _resolve_entity(self, entity: str | None) -> list[dict] | None:
|
|
||||||
"""Find matching items in tool_results by entity type."""
|
|
||||||
if not entity:
|
|
||||||
return None
|
|
||||||
matches = [r for r in self.tool_results if r.get("entity") == entity]
|
|
||||||
return matches if matches else None
|
|
||||||
|
|
||||||
|
|
||||||
class FloatingFormatter:
|
|
||||||
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
|
|
||||||
|
|
||||||
Emits floating_domain immediately (from agent_name), then streams all tokens
|
|
||||||
as plain stream_text — no block parsing for floating context.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, request_id: str) -> None:
|
def __init__(self, request_id: str) -> None:
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
|
|
||||||
async def format(
|
async def format(
|
||||||
self,
|
self,
|
||||||
token_stream: AsyncGenerator[tuple[str, str], None],
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
) -> AsyncGenerator[WsFrame, None]:
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
domain_sent = False
|
started = False
|
||||||
|
|
||||||
async for agent_name, token in token_stream:
|
async for event_type, data in event_stream:
|
||||||
if not domain_sent:
|
if event_type == "floating_domain":
|
||||||
domain = _AGENT_DOMAIN.get(agent_name, "tasks")
|
if isinstance(data, dict):
|
||||||
yield WsFloatingDomain(
|
yield WsFloatingDomain(
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
domain=domain, # type: ignore[arg-type]
|
domain=data,
|
||||||
)
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if event_type != "token":
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not started:
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
domain_sent = True
|
started = True
|
||||||
|
|
||||||
if token:
|
text = str(data or "")
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=token)
|
if text:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=text)
|
||||||
|
|
||||||
|
if not started:
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
yield WsStreamEnd(request_id=self.request_id)
|
yield WsStreamEnd(request_id=self.request_id)
|
||||||
|
|
||||||
|
|
||||||
# ── helpers ───────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _try_parse_json(text: str) -> tuple[dict[str, Any] | None, int]:
|
|
||||||
"""Attempt to parse the first complete JSON object from *text*.
|
|
||||||
|
|
||||||
Returns ``(parsed_dict, end_index)`` on success, ``(None, 0)`` when the
|
|
||||||
object is incomplete, and raises ``ValueError`` when text is not JSON.
|
|
||||||
"""
|
|
||||||
decoder = json.JSONDecoder()
|
|
||||||
try:
|
|
||||||
obj, end_idx = decoder.raw_decode(text)
|
|
||||||
if not isinstance(obj, dict):
|
|
||||||
raise ValueError("Expected JSON object")
|
|
||||||
return obj, end_idx
|
|
||||||
except json.JSONDecodeError as exc:
|
|
||||||
# Incomplete JSON — need more tokens
|
|
||||||
if "Unterminated" in str(exc) or exc.pos == len(text):
|
|
||||||
return None, 0
|
|
||||||
raise ValueError(str(exc)) from exc
|
|
||||||
|
|||||||
@@ -18,9 +18,8 @@ from app.config.settings import settings
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Startup: initialise DB connection pool and agent registry
|
# Startup: ensure agent tool modules are loaded.
|
||||||
from app.core.agent_registry import registry # noqa: F401 — triggers module load
|
import app.agents # noqa: F401
|
||||||
import app.agents # noqa: F401 — triggers @registry.register decorators
|
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@@ -51,18 +50,16 @@ def create_app() -> FastAPI:
|
|||||||
app.add_middleware(SanitizerMiddleware)
|
app.add_middleware(SanitizerMiddleware)
|
||||||
app.add_middleware(TierRateLimitMiddleware)
|
app.add_middleware(TierRateLimitMiddleware)
|
||||||
|
|
||||||
from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plans, plugins, storage, vectors
|
from app.api.routes import agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
|
||||||
|
|
||||||
app.include_router(auth.router, prefix="/api/v1")
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
app.include_router(chat.router, prefix="/api/v1")
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
app.include_router(plans.router, prefix="/api/v1")
|
|
||||||
app.include_router(storage.router, prefix="/api/v1")
|
app.include_router(storage.router, prefix="/api/v1")
|
||||||
app.include_router(vectors.router, prefix="/api/v1")
|
app.include_router(vectors.router, prefix="/api/v1")
|
||||||
app.include_router(backup.router, prefix="/api/v1")
|
app.include_router(backup.router, prefix="/api/v1")
|
||||||
app.include_router(plugins.router, prefix="/api/v1")
|
app.include_router(plugins.router, prefix="/api/v1")
|
||||||
app.include_router(billing.router, prefix="/api/v1")
|
app.include_router(billing.router, prefix="/api/v1")
|
||||||
app.include_router(agents.router, prefix="/api/v1")
|
app.include_router(agents.router, prefix="/api/v1")
|
||||||
app.include_router(agent_setup.router, prefix="/api/v1")
|
|
||||||
app.include_router(device_ws.router, prefix="/api/v1")
|
app.include_router(device_ws.router, prefix="/api/v1")
|
||||||
|
|
||||||
@app.get("/api/v1/health", tags=["health"])
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
|
|||||||
184
app/schemas.py
184
app/schemas.py
@@ -41,41 +41,13 @@ class ChatContext(BaseModel):
|
|||||||
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class PlanAction(BaseModel):
|
|
||||||
type: Literal[
|
|
||||||
"create_record",
|
|
||||||
"update_record",
|
|
||||||
"delete_record",
|
|
||||||
"index_document",
|
|
||||||
"send_notification",
|
|
||||||
]
|
|
||||||
table: str | None = None
|
|
||||||
data: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
|
class ChatRequest(BaseModel):
|
||||||
message: str
|
message: str
|
||||||
context: ChatContext = Field(default_factory=ChatContext)
|
context: ChatContext = Field(default_factory=ChatContext)
|
||||||
execution_mode: Literal["direct", "plan"] = "direct"
|
|
||||||
|
|
||||||
|
|
||||||
class ChatResponse(BaseModel):
|
class ChatResponse(BaseModel):
|
||||||
response: str
|
response: str
|
||||||
actions: list[PlanAction] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Execution Plans ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class PlanStep(BaseModel):
|
|
||||||
action: str
|
|
||||||
prompt_template: str | None = None
|
|
||||||
variables: dict[str, Any] | None = None
|
|
||||||
data_from_step: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionPlan(BaseModel):
|
|
||||||
agent: str
|
|
||||||
steps: list[PlanStep] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Backup ───────────────────────────────────────────────────────────
|
# ── Backup ───────────────────────────────────────────────────────────
|
||||||
@@ -170,21 +142,21 @@ class WsFrameType(str, Enum):
|
|||||||
tool_result = "tool_result"
|
tool_result = "tool_result"
|
||||||
final = "final"
|
final = "final"
|
||||||
ping = "ping"
|
ping = "ping"
|
||||||
agent_run = "agent_run"
|
|
||||||
agent_data = "agent_data"
|
|
||||||
agent_complete = "agent_complete"
|
|
||||||
device_hello = "device_hello"
|
device_hello = "device_hello"
|
||||||
# ── v3 frame types ─────────────────────────────────────────────────
|
# ── v3 frame types ─────────────────────────────────────────────────
|
||||||
home_request = "home_request"
|
home_request = "home_request"
|
||||||
floating_request = "floating_request"
|
floating_request = "floating_request"
|
||||||
stream_start = "stream_start"
|
stream_start = "stream_start"
|
||||||
stream_text = "stream_text"
|
stream_text = "stream_text"
|
||||||
stream_block = "stream_block"
|
|
||||||
stream_end = "stream_end"
|
stream_end = "stream_end"
|
||||||
floating_domain = "floating_domain"
|
floating_domain = "floating_domain"
|
||||||
data_request = "data_request"
|
data_request = "data_request"
|
||||||
data_response = "data_response"
|
data_response = "data_response"
|
||||||
mutation = "mutation"
|
mutation = "mutation"
|
||||||
|
# ── v4 journey frame types ────────────────────────────────────────
|
||||||
|
journey_start = "journey_start"
|
||||||
|
journey_message = "journey_message"
|
||||||
|
journey_reply = "journey_reply"
|
||||||
|
|
||||||
|
|
||||||
class WsToolCall(BaseModel):
|
class WsToolCall(BaseModel):
|
||||||
@@ -237,31 +209,6 @@ class WsDeviceHello(BaseModel):
|
|||||||
agent_ids: list[str] = Field(default_factory=list)
|
agent_ids: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class WsAgentRun(BaseModel):
|
|
||||||
"""Server → Client: trigger an agent run on the connected device."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.agent_run] = WsFrameType.agent_run
|
|
||||||
run_id: str
|
|
||||||
agent_id: str
|
|
||||||
config: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class WsAgentData(BaseModel):
|
|
||||||
"""Client → Server: files read by the local agent."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.agent_data] = WsFrameType.agent_data
|
|
||||||
run_id: str
|
|
||||||
files: list[dict[str, Any]]
|
|
||||||
|
|
||||||
|
|
||||||
class WsAgentComplete(BaseModel):
|
|
||||||
"""Client → Server: Electron signals it has finished reading files."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.agent_complete] = WsFrameType.agent_complete
|
|
||||||
run_id: str
|
|
||||||
files_read: int
|
|
||||||
errors: list[str] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
||||||
|
|
||||||
@@ -303,21 +250,19 @@ class WsStreamText(BaseModel):
|
|||||||
chunk: str
|
chunk: str
|
||||||
|
|
||||||
|
|
||||||
class WsStreamBlock(BaseModel):
|
|
||||||
"""Server → Client: structured block (chart, table, entity, timeline)."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.stream_block] = WsFrameType.stream_block
|
|
||||||
request_id: str
|
|
||||||
block_type: Literal["chart", "entity_ref", "table", "timeline"]
|
|
||||||
data: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class WsStreamEnd(BaseModel):
|
class WsStreamEnd(BaseModel):
|
||||||
"""Server → Client: signals end of a streaming response."""
|
"""Server → Client: signals end of a streaming response."""
|
||||||
|
|
||||||
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
||||||
request_id: str
|
request_id: str
|
||||||
mutations: list[dict[str, Any]] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
class WsDomain(BaseModel):
|
||||||
|
"""Structured floating domain payload for UI routing decisions."""
|
||||||
|
|
||||||
|
type: Literal["task", "timeline", "project", "node"]
|
||||||
|
id: str | None = None
|
||||||
|
section: Literal["task", "timeline", "note"] | None = None
|
||||||
|
|
||||||
|
|
||||||
class WsFloatingDomain(BaseModel):
|
class WsFloatingDomain(BaseModel):
|
||||||
@@ -325,7 +270,7 @@ class WsFloatingDomain(BaseModel):
|
|||||||
|
|
||||||
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
||||||
request_id: str
|
request_id: str
|
||||||
domain: Literal["tasks", "timelines", "notes", "projects"]
|
domain: WsDomain
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Catalog ─────────────────────────────────────────────────────
|
# ── Agent Catalog ─────────────────────────────────────────────────────
|
||||||
@@ -334,84 +279,28 @@ class AgentCatalogItem(BaseModel):
|
|||||||
type: str
|
type: str
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
config_schema: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Local Agent Config ────────────────────────────────────────────────
|
class AgentCreationCheckRequest(BaseModel):
|
||||||
|
active_agents: int = Field(ge=0, default=0)
|
||||||
class LocalAgentConfigCreate(BaseModel):
|
|
||||||
name: str
|
|
||||||
device_id: str
|
|
||||||
directory_paths: list[str]
|
|
||||||
data_types: list[str]
|
|
||||||
prompt_template: str
|
|
||||||
file_extensions: list[str]
|
|
||||||
schedule_cron: str
|
|
||||||
|
|
||||||
|
|
||||||
class LocalAgentConfigUpdate(BaseModel):
|
class AgentCreationCheckResponse(BaseModel):
|
||||||
name: str | None = None
|
allowed: bool
|
||||||
device_id: str | None = None
|
tier: BillingTier
|
||||||
directory_paths: list[str] | None = None
|
active_agents: int
|
||||||
data_types: list[str] | None = None
|
limit: int
|
||||||
prompt_template: str | None = None
|
|
||||||
file_extensions: list[str] | None = None
|
|
||||||
schedule_cron: str | None = None
|
|
||||||
enabled: bool | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class LocalAgentConfigResponse(BaseModel):
|
class AgentTriggerRequest(BaseModel):
|
||||||
id: str
|
directory: str = Field(min_length=1)
|
||||||
name: str
|
device_id: str = Field(default="")
|
||||||
device_id: str
|
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
|
||||||
directory_paths: list[str]
|
what_to_extract: list[str] = Field(min_length=1)
|
||||||
data_types: list[str]
|
actions_by_type: dict[str, list[str]] | None = None
|
||||||
prompt_template: str
|
batch_interval: str = Field(min_length=1)
|
||||||
file_extensions: list[str]
|
custom_agent_prompt: str = Field(min_length=1)
|
||||||
schedule_cron: str
|
active_agents: int = Field(ge=0, default=0)
|
||||||
enabled: bool
|
|
||||||
last_run_at: int | None
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud Agent Config ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class CloudAgentConfigCreate(BaseModel):
|
|
||||||
provider: Literal["gmail", "teams", "outlook"]
|
|
||||||
name: str
|
|
||||||
data_types: list[str]
|
|
||||||
prompt_template: str
|
|
||||||
oauth_token_encrypted: str
|
|
||||||
schedule_cron: str
|
|
||||||
filter_config: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class CloudAgentConfigUpdate(BaseModel):
|
|
||||||
provider: Literal["gmail", "teams", "outlook"] | None = None
|
|
||||||
name: str | None = None
|
|
||||||
data_types: list[str] | None = None
|
|
||||||
prompt_template: str | None = None
|
|
||||||
oauth_token_encrypted: str | None = None
|
|
||||||
schedule_cron: str | None = None
|
|
||||||
filter_config: dict[str, Any] | None = None
|
|
||||||
enabled: bool | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class CloudAgentConfigResponse(BaseModel):
|
|
||||||
"""oauth_token_encrypted is intentionally excluded — never returned to clients."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
provider: Literal["gmail", "teams", "outlook"]
|
|
||||||
name: str
|
|
||||||
data_types: list[str]
|
|
||||||
prompt_template: str
|
|
||||||
schedule_cron: str
|
|
||||||
filter_config: dict[str, Any] | None
|
|
||||||
enabled: bool
|
|
||||||
last_run_at: int | None
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Run Log ─────────────────────────────────────────────────────
|
# ── Agent Run Log ─────────────────────────────────────────────────────
|
||||||
@@ -430,18 +319,3 @@ class AgentRunLogResponse(BaseModel):
|
|||||||
|
|
||||||
# ── Chatbot Journey ───────────────────────────────────────────────────
|
# ── Chatbot Journey ───────────────────────────────────────────────────
|
||||||
|
|
||||||
class JourneyStartRequest(BaseModel):
|
|
||||||
agent_type: Literal["local", "cloud"]
|
|
||||||
agent_id: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class JourneyMessageRequest(BaseModel):
|
|
||||||
session_id: str
|
|
||||||
message: str
|
|
||||||
|
|
||||||
|
|
||||||
class JourneyResponse(BaseModel):
|
|
||||||
session_id: str
|
|
||||||
message: str
|
|
||||||
done: bool
|
|
||||||
prompt_template: str | None = None
|
|
||||||
|
|||||||
941
docs/MICROSERVICES_ARCHITECTURE.md
Normal file
941
docs/MICROSERVICES_ARCHITECTURE.md
Normal file
@@ -0,0 +1,941 @@
|
|||||||
|
# Adiuva — Architettura Microservizi (MVP)
|
||||||
|
|
||||||
|
## Panoramica
|
||||||
|
|
||||||
|
Il monolite viene suddiviso in **4 servizi MVP** + un **API Gateway (Traefik)**, orchestrati con Docker Compose su un singolo VPS raggiungibile via Cloudflare.
|
||||||
|
|
||||||
|
> **Fuori dall'MVP**: Storage Service (S3/backup CRUD) e Plugin Service (marketplace). Verranno aggiunti come servizi indipendenti in una fase successiva.
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────┐
|
||||||
|
│ Cloudflare │
|
||||||
|
│ (DNS + CDN) │
|
||||||
|
└──────┬───────┘
|
||||||
|
│ HTTPS / WSS
|
||||||
|
┌──────▼───────┐
|
||||||
|
│ Traefik │
|
||||||
|
│ API Gateway │
|
||||||
|
│ (routing, │
|
||||||
|
│ TLS, rate │
|
||||||
|
│ limiting) │
|
||||||
|
└──────┬───────┘
|
||||||
|
│
|
||||||
|
┌──────────┬───────────┼───────────┐
|
||||||
|
│ │ │ │
|
||||||
|
┌─────▼────┐ ┌───▼───┐ ┌────▼────┐ ┌────▼───┐
|
||||||
|
│ Auth │ │ Chat │ │ Agent │ │Billing │
|
||||||
|
│ Service │ │Service│ │ Service │ │Service │
|
||||||
|
└─────┬────┘ └───┬───┘ └────┬────┘ └────┬───┘
|
||||||
|
│ │ │ │
|
||||||
|
┌─────▼──────────▼──────────▼───────────▼────┐
|
||||||
|
│ Infrastruttura │
|
||||||
|
│ PostgreSQL │ Redis │ Qdrant │
|
||||||
|
└─────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Suddivisione dei Servizi
|
||||||
|
|
||||||
|
### 1.1 Auth Service (`auth-service`)
|
||||||
|
|
||||||
|
**Responsabilità**: Registrazione, login, refresh token, profilo utente, encryption key.
|
||||||
|
|
||||||
|
| Endpoint originale | Metodo |
|
||||||
|
|---|---|
|
||||||
|
| `/api/v1/auth/register` | POST |
|
||||||
|
| `/api/v1/auth/login` | POST |
|
||||||
|
| `/api/v1/auth/refresh` | POST |
|
||||||
|
| `/api/v1/auth/me` | GET / PUT |
|
||||||
|
|
||||||
|
**Database**: Tabelle `users`, `refresh_tokens` (PostgreSQL condiviso, schema `auth`).
|
||||||
|
|
||||||
|
**Modifica chiave — JWT con RS256**:
|
||||||
|
Il monolite usa un `SECRET_KEY` simmetrico (HS256). Con i microservizi, passare a **RS256** (asimmetrico):
|
||||||
|
- L'Auth Service firma i JWT con la **chiave privata**.
|
||||||
|
- Tutti gli altri servizi verificano i JWT con la **chiave pubblica** senza mai contattare l'Auth Service.
|
||||||
|
- La chiave pubblica viene esposta via `GET /api/v1/auth/.well-known/jwks.json` oppure montata come volume condiviso.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# auth-service/app/auth/jwt.py
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
PRIVATE_KEY = ... # Da env/secret
|
||||||
|
PUBLIC_KEY = ... # Derivata o da env
|
||||||
|
|
||||||
|
def create_access_token(user_id: str, tier: str) -> str:
|
||||||
|
return jwt.encode(
|
||||||
|
{"sub": user_id, "tier": tier, "exp": ...},
|
||||||
|
PRIVATE_KEY,
|
||||||
|
algorithm="RS256",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# shared/auth.py (usato da tutti gli altri servizi)
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
PUBLIC_KEY = ... # Volume montato o fetched da JWKS endpoint
|
||||||
|
|
||||||
|
def verify_token(token: str) -> dict:
|
||||||
|
return jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
|
||||||
|
```
|
||||||
|
|
||||||
|
**Scaling**: 2 repliche sufficienti, stateless. Rate-limit dedicato su `/login` e `/register`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 1.2 Chat Service (`chat-service`) ⭐ Real-time
|
||||||
|
|
||||||
|
**Responsabilità**: WebSocket device connection, home chat, floating chat, memory middleware, streaming LLM responses verso il client.
|
||||||
|
|
||||||
|
Questo servizio gestisce la **connessione persistente** con l'app Electron e le interazioni **real-time** dell'utente (chat home, floating chat). È il proprietario della WebSocket.
|
||||||
|
|
||||||
|
| Endpoint | Tipo |
|
||||||
|
|---|---|
|
||||||
|
| `/api/v1/ws/device` | WebSocket (connessione persistente) |
|
||||||
|
| `/api/v1/chat` | POST (REST fallback) |
|
||||||
|
|
||||||
|
**Moduli inclusi**: `deep_agent`, `memory_middleware`, `ws_context`, `device_manager` (Redis-backed), `output_formatter`, `llm`, tutti gli agent tools (`task_agent`, `project_agent`, `note_agent`, `timeline_agent`).
|
||||||
|
|
||||||
|
**Perché separato dall'Agent Service**: Il Chat Service tiene la WebSocket aperta e risponde in tempo reale (streaming). Scalare aggiungendo repliche è semplice con sticky sessions + Redis pub/sub per il cross-instance routing dei tool_call.
|
||||||
|
|
||||||
|
**Scaling**: 2–N repliche. Sticky cookies per le WS + Redis per cross-instance.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 1.3 Agent Service (`agent-service`) ⭐ Batch
|
||||||
|
|
||||||
|
**Responsabilità**: Batch agent processing (directory scanning, file classification, entity extraction), agent setup journeys, agent configuration CRUD.
|
||||||
|
|
||||||
|
Questo servizio gestisce i processi **long-running** e **CPU-intensive**: scansione filesystem, classificazione file con LLM, estrazione entità in batch. Non possiede la WebSocket — comunica con il device dell'utente tramite **Redis pub/sub** passando per il Chat Service.
|
||||||
|
|
||||||
|
| Endpoint | Tipo |
|
||||||
|
|---|---|
|
||||||
|
| `/api/v1/agents/catalog` | GET |
|
||||||
|
| `/api/v1/agents/can-create` | POST |
|
||||||
|
| `/api/v1/agents/trigger` | POST |
|
||||||
|
| `/api/v1/agents/journey/start` | POST (o WS relay) |
|
||||||
|
| `/api/v1/agents/journey/message` | POST (o WS relay) |
|
||||||
|
|
||||||
|
**Moduli inclusi**: `agent_runner`, `agent_registry`, `filesystem_agent`, `llm`.
|
||||||
|
|
||||||
|
**Flusso tool-call cross-service** (l'Agent Service non ha la WS):
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────┐ ┌──────────────┐ ┌──────────┐
|
||||||
|
│ Agent Service│ │ Redis │ │ Chat │
|
||||||
|
│ (batch run) │ │ │ │ Service │
|
||||||
|
│ │ │ │ │ (ha WS) │
|
||||||
|
│ 1. Needs to │ PUBLISH │ │ SUBSCRIBE │ │
|
||||||
|
│ read file ├───────────►│tool_call:u123├───────────►│ 2. Invia │
|
||||||
|
│ from │ │ │ │ al │
|
||||||
|
│ device │ │ │ │ device│
|
||||||
|
│ │ │ │ │ via WS│
|
||||||
|
│ │ SUBSCRIBE │ │ PUBLISH │ │
|
||||||
|
│ 4. Riceve ◄────────────┤tool_result:id│◄───────────┤ 3. Device│
|
||||||
|
│ risultato │ │ │ │ reply │
|
||||||
|
└──────────────┘ └──────────────┘ └──────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
**Scaling**: 1–N repliche. Completamente stateless, scala indipendentemente dalla chat. Ogni replica processa batch job diversi. Può essere scalato a 0 se non ci sono agent attivi (risparmio risorse).
|
||||||
|
|
||||||
|
**Vantaggio dello split**: Se 50 utenti triggerano agenti batch contemporaneamente, il Chat Service non ne risente — le risposte real-time rimangono veloci.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 1.4 Billing Service (`billing-service`)
|
||||||
|
|
||||||
|
**Responsabilità**: Stripe checkout, webhook, subscription management.
|
||||||
|
|
||||||
|
| Endpoint originale | Metodo |
|
||||||
|
|---|---|
|
||||||
|
| `/api/v1/billing/checkout` | POST |
|
||||||
|
| `/api/v1/billing/webhook` | POST |
|
||||||
|
| `/api/v1/billing/subscription` | GET / DELETE |
|
||||||
|
|
||||||
|
**Database**: Tabelle `subscriptions` (schema `billing`).
|
||||||
|
|
||||||
|
**Comunicazione inter-servizio**: Quando Stripe invia un webhook e il tier cambia, il Billing Service pubblica un evento su **Redis pub/sub** channel `tier_changed:{user_id}`. L'Auth Service aggiorna il campo `tier` nella tabella users. Al prossimo token refresh il JWT conterrà il tier aggiornato.
|
||||||
|
|
||||||
|
**Scaling**: 1 replica sufficiente. Basso traffico.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 1.5 Servizi esclusi dall'MVP
|
||||||
|
|
||||||
|
I seguenti servizi verranno aggiunti post-MVP come servizi indipendenti:
|
||||||
|
|
||||||
|
| Servizio | Responsabilità | Note |
|
||||||
|
|---|---|---|
|
||||||
|
| **Storage Service** | S3 blobs CRUD, vector ops, backup | Le funzionalità vector/embed possono restare nel Chat Service per il MVP |
|
||||||
|
| **Plugin Service** | Marketplace, install, revenue split | Feature non critica per il lancio |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Tier Check — Dove e Come
|
||||||
|
|
||||||
|
Il tier dell'utente (free/pro/power/team) determina rate-limiting, quote e accesso a funzionalità. Con i microservizi, **ogni servizio controlla il tier autonomamente** senza chiamare l'Auth Service.
|
||||||
|
|
||||||
|
### Strategia: Tier nel JWT
|
||||||
|
|
||||||
|
L'Auth Service include il `tier` come claim nel JWT al momento del login/refresh:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"sub": "user_123",
|
||||||
|
"tier": "pro",
|
||||||
|
"exp": 1742515200,
|
||||||
|
"iat": 1742511600
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Ogni servizio:
|
||||||
|
1. Decodifica il JWT con la chiave pubblica (già lo fa per l'auth)
|
||||||
|
2. Legge `payload["tier"]` — **zero chiamate extra**
|
||||||
|
3. Applica le sue regole di enforcement localmente
|
||||||
|
|
||||||
|
```python
|
||||||
|
# shared/auth.py — dependency FastAPI condivisa
|
||||||
|
from fastapi import Depends, HTTPException, Request
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
PUBLIC_KEY = ...
|
||||||
|
|
||||||
|
class CurrentUser:
|
||||||
|
def __init__(self, user_id: str, tier: str):
|
||||||
|
self.user_id = user_id
|
||||||
|
self.tier = tier
|
||||||
|
|
||||||
|
async def get_current_user(request: Request) -> CurrentUser:
|
||||||
|
token = request.headers.get("Authorization", "").removeprefix("Bearer ")
|
||||||
|
payload = jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
|
||||||
|
return CurrentUser(user_id=payload["sub"], tier=payload["tier"])
|
||||||
|
|
||||||
|
def require_tier(*allowed_tiers: str):
|
||||||
|
"""Dependency che blocca se il tier non è tra quelli ammessi."""
|
||||||
|
async def check(user: CurrentUser = Depends(get_current_user)):
|
||||||
|
if user.tier not in allowed_tiers:
|
||||||
|
raise HTTPException(403, "Tier insufficient")
|
||||||
|
return user
|
||||||
|
return check
|
||||||
|
```
|
||||||
|
|
||||||
|
### Cosa succede quando il tier cambia (upgrade/downgrade)?
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────┐ Stripe webhook ┌──────────┐ tier_changed ┌──────────┐
|
||||||
|
│ Stripe │ ─────────────────►│ Billing │ ───────────────►│ Auth │
|
||||||
|
│ │ │ Service │ (Redis pub/sub) │ Service │
|
||||||
|
└──────────┘ └──────────┘ └────┬─────┘
|
||||||
|
│
|
||||||
|
UPDATE users
|
||||||
|
SET tier = 'power'
|
||||||
|
│
|
||||||
|
Al prossimo /refresh
|
||||||
|
il JWT conterrà tier='power'
|
||||||
|
```
|
||||||
|
|
||||||
|
**Latenza del cambio**: Il tier si propaga al prossimo token refresh (tipicamente 15–30 min, o il client può forzare un refresh immediato dopo il checkout). Per il billing webhook, il downgrade può essere forzato invalidando il refresh token su Redis → il client è obbligato a ri-autenticarsi.
|
||||||
|
|
||||||
|
### Dove si applica in ciascun servizio
|
||||||
|
|
||||||
|
| Servizio | Enforcement |
|
||||||
|
|---|---|
|
||||||
|
| **Auth Service** | Nessuno (è lui che scrive il tier) |
|
||||||
|
| **Chat Service** | Rate-limit per tier (req/min), quota messaggi |
|
||||||
|
| **Agent Service** | Max agent configs, max runs/day, max concurrent batches |
|
||||||
|
| **Billing Service** | Nessuno (gestisce i tier, non li consuma) |
|
||||||
|
|
||||||
|
### Rate-limit distribuito via Redis
|
||||||
|
|
||||||
|
Poiché ogni servizio ha le sue repliche, il rate-limiting deve essere **condiviso** via Redis:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# shared/middleware/rate_limit.py
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
|
||||||
|
class DistributedRateLimiter:
|
||||||
|
def __init__(self, redis: aioredis.Redis):
|
||||||
|
self._redis = redis
|
||||||
|
|
||||||
|
async def check(self, user_id: str, tier: str, service: str) -> bool:
|
||||||
|
limits = {"free": 20, "pro": 60, "power": 120, "team": 200}
|
||||||
|
max_req = limits.get(tier, 20)
|
||||||
|
key = f"rate:{service}:{user_id}"
|
||||||
|
|
||||||
|
pipe = self._redis.pipeline()
|
||||||
|
pipe.incr(key)
|
||||||
|
pipe.expire(key, 60)
|
||||||
|
count, _ = await pipe.execute()
|
||||||
|
|
||||||
|
return count <= max_req
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. WebSocket con Scaling Orizzontale — Il Problema Chiave
|
||||||
|
|
||||||
|
`DeviceConnectionManager` è un **singleton in-memory**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class DeviceConnectionManager:
|
||||||
|
def __init__(self):
|
||||||
|
self._connections: dict[str, DeviceConnection] = {} # ← In-memory!
|
||||||
|
```
|
||||||
|
|
||||||
|
Con N istanze del Chat Service, il device si connette a **una sola** istanza. Quando un'altra istanza deve inviare un `tool_call` a quel device (es. un agent trigger da un'API call), non trova la connessione.
|
||||||
|
|
||||||
|
### La soluzione: Redis Pub/Sub + Registry
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────────────────────────────────────────────────────┐
|
||||||
|
│ Redis │
|
||||||
|
│ │
|
||||||
|
│ Hash: ws:connections │
|
||||||
|
│ user_123 → instance_A │
|
||||||
|
│ user_456 → instance_B │
|
||||||
|
│ │
|
||||||
|
│ Pub/Sub channels: │
|
||||||
|
│ tool_call:{user_id} → tool call payloads │
|
||||||
|
│ tool_result:{call_id} → tool result payloads │
|
||||||
|
│ stream:{user_id} → text_chunk streaming │
|
||||||
|
└──────────────────────────────────────────────────────────────┘
|
||||||
|
|
||||||
|
Instance A (ha WS di user_123) Instance B (deve chiamare tool su user_123)
|
||||||
|
┌───────────────────────┐ ┌───────────────────────┐
|
||||||
|
│ 1. Sottoscrive a │ │ 1. Lookup Redis Hash │
|
||||||
|
│ tool_call:user_123│ │ → user_123 è su A │
|
||||||
|
│ │ │ │
|
||||||
|
│ 2. Riceve tool_call │◄─────────│ 2. PUBLISH │
|
||||||
|
│ da Redis channel │ │ tool_call:user_123 │
|
||||||
|
│ │ │ {id, action, ...} │
|
||||||
|
│ 3. Invia al device │ │ │
|
||||||
|
│ via WS │ │ 4. SUBSCRIBE │
|
||||||
|
│ │ │ tool_result:{id} │
|
||||||
|
│ 4. Device risponde │ │ │
|
||||||
|
│ tool_result │──────────│► 5. Riceve risultato │
|
||||||
|
│ │ │ │
|
||||||
|
│ 5. PUBLISH │ │ │
|
||||||
|
│ tool_result:{id} │ │ │
|
||||||
|
└───────────────────────┘ └───────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### Implementazione: `RedisDeviceManager`
|
||||||
|
|
||||||
|
```python
|
||||||
|
# chat-service/app/core/device_manager.py
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from fastapi import WebSocket
|
||||||
|
|
||||||
|
INSTANCE_ID = os.environ.get("INSTANCE_ID", os.urandom(8).hex())
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LocalConnection:
|
||||||
|
ws: WebSocket
|
||||||
|
device_id: str
|
||||||
|
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisDeviceManager:
|
||||||
|
"""Device manager backed by Redis for cross-instance communication."""
|
||||||
|
|
||||||
|
def __init__(self, redis_url: str = "redis://redis:6379"):
|
||||||
|
self._redis = aioredis.from_url(redis_url)
|
||||||
|
self._pubsub = self._redis.pubsub()
|
||||||
|
self._local: dict[str, LocalConnection] = {} # Solo connessioni locali
|
||||||
|
self._remote_futures: dict[str, asyncio.Future[dict]] = {}
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Avvia il listener Redis per tool_call in arrivo."""
|
||||||
|
asyncio.create_task(self._listen_tool_calls())
|
||||||
|
|
||||||
|
# ── Registrazione ──
|
||||||
|
|
||||||
|
async def register(self, user_id: str, device_id: str, ws: WebSocket):
|
||||||
|
# Registra localmente
|
||||||
|
self._local[user_id] = LocalConnection(ws=ws, device_id=device_id)
|
||||||
|
# Registra in Redis quale istanza ha la connessione
|
||||||
|
await self._redis.hset("ws:connections", user_id, INSTANCE_ID)
|
||||||
|
# Sottoscrivi ai tool_call per questo utente
|
||||||
|
await self._pubsub.subscribe(f"tool_call:{user_id}")
|
||||||
|
|
||||||
|
async def unregister(self, user_id: str):
|
||||||
|
conn = self._local.pop(user_id, None)
|
||||||
|
if conn:
|
||||||
|
for fut in conn.pending_calls.values():
|
||||||
|
if not fut.done():
|
||||||
|
fut.cancel()
|
||||||
|
await self._redis.hdel("ws:connections", user_id)
|
||||||
|
await self._pubsub.unsubscribe(f"tool_call:{user_id}")
|
||||||
|
|
||||||
|
# ── Presenza ──
|
||||||
|
|
||||||
|
async def is_online(self, user_id: str) -> bool:
|
||||||
|
return await self._redis.hexists("ws:connections", user_id)
|
||||||
|
|
||||||
|
# ── Tool-call round-trip (cross-instance) ──
|
||||||
|
|
||||||
|
async def execute_tool_call(self, user_id: str, payload: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Invia un tool_call al device dell'utente.
|
||||||
|
Funziona sia che la WS sia locale che su un'altra istanza.
|
||||||
|
"""
|
||||||
|
call_id = payload["id"]
|
||||||
|
|
||||||
|
# Caso 1: connessione locale → invio diretto
|
||||||
|
if user_id in self._local:
|
||||||
|
conn = self._local[user_id]
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
fut: asyncio.Future[dict] = loop.create_future()
|
||||||
|
conn.pending_calls[call_id] = fut
|
||||||
|
await conn.ws.send_text(json.dumps({"type": "tool_call", **payload}))
|
||||||
|
return await asyncio.wait_for(fut, timeout=30.0)
|
||||||
|
|
||||||
|
# Caso 2: connessione remota → Redis pub/sub
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
fut = loop.create_future()
|
||||||
|
self._remote_futures[call_id] = fut
|
||||||
|
|
||||||
|
# Sottoscrivi al canale di risposta
|
||||||
|
result_channel = f"tool_result:{call_id}"
|
||||||
|
await self._pubsub.subscribe(result_channel)
|
||||||
|
|
||||||
|
# Pubblica il tool_call
|
||||||
|
await self._redis.publish(
|
||||||
|
f"tool_call:{user_id}",
|
||||||
|
json.dumps(payload),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await asyncio.wait_for(fut, timeout=30.0)
|
||||||
|
finally:
|
||||||
|
self._remote_futures.pop(call_id, None)
|
||||||
|
await self._pubsub.unsubscribe(result_channel)
|
||||||
|
|
||||||
|
# ── Risoluzione tool_result (da WS locale) ──
|
||||||
|
|
||||||
|
def resolve_local(self, user_id: str, call_id: str, result: dict):
|
||||||
|
conn = self._local.get(user_id)
|
||||||
|
if conn:
|
||||||
|
fut = conn.pending_calls.pop(call_id, None)
|
||||||
|
if fut and not fut.done():
|
||||||
|
fut.set_result(result)
|
||||||
|
|
||||||
|
async def resolve_and_publish(self, user_id: str, call_id: str, result: dict):
|
||||||
|
"""Chiamato quando il device locale invia un tool_result."""
|
||||||
|
self.resolve_local(user_id, call_id, result)
|
||||||
|
# Pubblica anche su Redis per l'istanza remota che aspetta
|
||||||
|
await self._redis.publish(
|
||||||
|
f"tool_result:{call_id}",
|
||||||
|
json.dumps(result),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Listener Redis ──
|
||||||
|
|
||||||
|
async def _listen_tool_calls(self):
|
||||||
|
"""Loop che ascolta i tool_call in arrivo da altre istanze."""
|
||||||
|
async for message in self._pubsub.listen():
|
||||||
|
if message["type"] != "message":
|
||||||
|
continue
|
||||||
|
channel = message["channel"]
|
||||||
|
if isinstance(channel, bytes):
|
||||||
|
channel = channel.decode()
|
||||||
|
|
||||||
|
data = json.loads(message["data"])
|
||||||
|
|
||||||
|
if channel.startswith("tool_call:"):
|
||||||
|
# Un'altra istanza vuole che inviamo un tool_call al nostro device
|
||||||
|
user_id = channel.split(":", 1)[1]
|
||||||
|
conn = self._local.get(user_id)
|
||||||
|
if conn:
|
||||||
|
await conn.ws.send_text(json.dumps({"type": "tool_call", **data}))
|
||||||
|
|
||||||
|
elif channel.startswith("tool_result:"):
|
||||||
|
# Risposta a un tool_call che abbiamo inviato tramite Redis
|
||||||
|
call_id = channel.split(":", 1)[1]
|
||||||
|
fut = self._remote_futures.pop(call_id, None)
|
||||||
|
if fut and not fut.done():
|
||||||
|
fut.set_result(data)
|
||||||
|
|
||||||
|
# ── Stream cross-instance ──
|
||||||
|
|
||||||
|
async def publish_stream_chunk(self, user_id: str, chunk: dict):
|
||||||
|
"""Pubblica un chunk di streaming su Redis (per REST→WS relay)."""
|
||||||
|
await self._redis.publish(f"stream:{user_id}", json.dumps(chunk))
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Struttura Directory Proposta (MVP)
|
||||||
|
|
||||||
|
```
|
||||||
|
adiuva-api/
|
||||||
|
├── docker-compose.yml # Orchestrazione completa
|
||||||
|
├── docker-compose.dev.yml # Override per sviluppo locale
|
||||||
|
├── shared/ # Codice condiviso (montato come volume)
|
||||||
|
│ ├── auth.py # JWT verification (chiave pubblica)
|
||||||
|
│ ├── schemas.py # Pydantic schemas condivisi
|
||||||
|
│ ├── middleware/
|
||||||
|
│ │ ├── rate_limit.py # DistributedRateLimiter (Redis)
|
||||||
|
│ │ └── sanitizer.py
|
||||||
|
│ └── models/
|
||||||
|
│ └── base.py # SQLAlchemy base condivisa
|
||||||
|
│
|
||||||
|
├── auth-service/
|
||||||
|
│ ├── Dockerfile
|
||||||
|
│ ├── requirements.txt
|
||||||
|
│ └── app/
|
||||||
|
│ ├── main.py
|
||||||
|
│ ├── config.py
|
||||||
|
│ ├── db.py
|
||||||
|
│ ├── models.py # users, refresh_tokens
|
||||||
|
│ ├── routes/
|
||||||
|
│ │ └── auth.py
|
||||||
|
│ └── services/
|
||||||
|
│ ├── jwt_service.py # RS256 signing
|
||||||
|
│ └── user_service.py
|
||||||
|
│
|
||||||
|
├── chat-service/
|
||||||
|
│ ├── Dockerfile
|
||||||
|
│ ├── requirements.txt
|
||||||
|
│ └── app/
|
||||||
|
│ ├── main.py
|
||||||
|
│ ├── config.py
|
||||||
|
│ ├── db.py
|
||||||
|
│ ├── models.py # memory_*
|
||||||
|
│ ├── routes/
|
||||||
|
│ │ ├── device_ws.py # WS connection owner
|
||||||
|
│ │ └── chat.py # REST fallback
|
||||||
|
│ ├── core/
|
||||||
|
│ │ ├── device_manager.py # RedisDeviceManager
|
||||||
|
│ │ ├── deep_agent.py # Home + floating chat
|
||||||
|
│ │ ├── memory_middleware.py
|
||||||
|
│ │ ├── ws_context.py
|
||||||
|
│ │ ├── output_formatter.py
|
||||||
|
│ │ └── llm.py
|
||||||
|
│ └── agents/ # Tool definitions (used by deep_agent)
|
||||||
|
│ ├── task_agent.py
|
||||||
|
│ ├── project_agent.py
|
||||||
|
│ ├── note_agent.py
|
||||||
|
│ └── timeline_agent.py
|
||||||
|
│
|
||||||
|
├── agent-service/
|
||||||
|
│ ├── Dockerfile
|
||||||
|
│ ├── requirements.txt
|
||||||
|
│ └── app/
|
||||||
|
│ ├── main.py
|
||||||
|
│ ├── config.py
|
||||||
|
│ ├── db.py
|
||||||
|
│ ├── models.py # agent_run_logs, local/cloud_agent_configs
|
||||||
|
│ ├── routes/
|
||||||
|
│ │ ├── agents.py # catalog, can-create, trigger
|
||||||
|
│ │ └── agent_setup.py # journey start/message
|
||||||
|
│ ├── core/
|
||||||
|
│ │ ├── agent_runner.py # Batch classify → process
|
||||||
|
│ │ ├── agent_registry.py
|
||||||
|
│ │ ├── redis_executor.py # execute_on_client via Redis pub/sub
|
||||||
|
│ │ └── llm.py
|
||||||
|
│ └── agents/
|
||||||
|
│ ├── task_agent.py # Tool definitions (batch context)
|
||||||
|
│ ├── project_agent.py
|
||||||
|
│ ├── note_agent.py
|
||||||
|
│ ├── timeline_agent.py
|
||||||
|
│ └── filesystem_agent.py
|
||||||
|
│
|
||||||
|
├── billing-service/
|
||||||
|
│ ├── Dockerfile
|
||||||
|
│ ├── requirements.txt
|
||||||
|
│ └── app/
|
||||||
|
│ ├── main.py
|
||||||
|
│ ├── config.py
|
||||||
|
│ ├── db.py
|
||||||
|
│ ├── models.py # subscriptions
|
||||||
|
│ ├── routes/
|
||||||
|
│ │ └── billing.py
|
||||||
|
│ └── services/
|
||||||
|
│ ├── stripe_service.py
|
||||||
|
│ └── tier_manager.py
|
||||||
|
│
|
||||||
|
└── infra/
|
||||||
|
├── traefik/
|
||||||
|
│ └── traefik.yml
|
||||||
|
├── keys/
|
||||||
|
│ ├── jwt_private.pem # Solo auth-service
|
||||||
|
│ └── jwt_public.pem # Tutti i servizi
|
||||||
|
└── alembic/ # Migrazioni condivise o per-servizio
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Docker Compose — Configurazione MVP
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# docker-compose.yml
|
||||||
|
|
||||||
|
services:
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# API Gateway
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
traefik:
|
||||||
|
image: traefik:v3.2
|
||||||
|
command:
|
||||||
|
- "--api.insecure=true"
|
||||||
|
- "--providers.docker=true"
|
||||||
|
- "--providers.docker.exposedbydefault=false"
|
||||||
|
- "--entrypoints.web.address=:80"
|
||||||
|
- "--entrypoints.websecure.address=:443"
|
||||||
|
- "--entrypoints.web.http.redirections.entrypoint.to=websecure"
|
||||||
|
ports:
|
||||||
|
- "80:80"
|
||||||
|
- "443:443"
|
||||||
|
- "8080:8080" # Dashboard Traefik (disabilitare in prod)
|
||||||
|
volumes:
|
||||||
|
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||||
|
- ./infra/certs:/certs:ro
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Auth Service (2 repliche)
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
auth-service:
|
||||||
|
build: ./auth-service
|
||||||
|
deploy:
|
||||||
|
replicas: 2
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
REDIS_URL: redis://redis:6379
|
||||||
|
JWT_PRIVATE_KEY_FILE: /run/secrets/jwt_private_key
|
||||||
|
SERVICE_NAME: auth
|
||||||
|
secrets:
|
||||||
|
- jwt_private_key
|
||||||
|
- jwt_public_key
|
||||||
|
labels:
|
||||||
|
- "traefik.enable=true"
|
||||||
|
- "traefik.http.routers.auth.rule=PathPrefix(`/api/v1/auth`)"
|
||||||
|
- "traefik.http.services.auth.loadbalancer.server.port=8000"
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Chat Service — Real-time WS + Chat (scalabile)
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
chat-service:
|
||||||
|
build: ./chat-service
|
||||||
|
deploy:
|
||||||
|
replicas: 2
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
REDIS_URL: redis://redis:6379
|
||||||
|
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
||||||
|
SERVICE_NAME: chat
|
||||||
|
secrets:
|
||||||
|
- jwt_public_key
|
||||||
|
labels:
|
||||||
|
- "traefik.enable=true"
|
||||||
|
# REST chat endpoint
|
||||||
|
- "traefik.http.routers.chat.rule=PathPrefix(`/api/v1/chat`)"
|
||||||
|
- "traefik.http.services.chat.loadbalancer.server.port=8000"
|
||||||
|
# WebSocket route con sticky session
|
||||||
|
- "traefik.http.routers.ws.rule=PathPrefix(`/api/v1/ws`)"
|
||||||
|
- "traefik.http.routers.ws.service=chat-ws"
|
||||||
|
- "traefik.http.services.chat-ws.loadbalancer.server.port=8000"
|
||||||
|
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.name=ws_affinity"
|
||||||
|
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.httpOnly=true"
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Agent Service — Batch processing (scalabile indipendentemente)
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
agent-service:
|
||||||
|
build: ./agent-service
|
||||||
|
deploy:
|
||||||
|
replicas: 2
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
REDIS_URL: redis://redis:6379
|
||||||
|
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
||||||
|
SERVICE_NAME: agent
|
||||||
|
secrets:
|
||||||
|
- jwt_public_key
|
||||||
|
labels:
|
||||||
|
- "traefik.enable=true"
|
||||||
|
- "traefik.http.routers.agents.rule=PathPrefix(`/api/v1/agents`)"
|
||||||
|
- "traefik.http.services.agents.loadbalancer.server.port=8000"
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Billing Service (1 replica)
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
billing-service:
|
||||||
|
build: ./billing-service
|
||||||
|
deploy:
|
||||||
|
replicas: 1
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
REDIS_URL: redis://redis:6379
|
||||||
|
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
||||||
|
SERVICE_NAME: billing
|
||||||
|
secrets:
|
||||||
|
- jwt_public_key
|
||||||
|
labels:
|
||||||
|
- "traefik.enable=true"
|
||||||
|
- "traefik.http.routers.billing.rule=PathPrefix(`/api/v1/billing`)"
|
||||||
|
- "traefik.http.services.billing.loadbalancer.server.port=8000"
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Infrastruttura
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
db:
|
||||||
|
image: pgvector/pgvector:pg16
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: postgres
|
||||||
|
POSTGRES_PASSWORD: postgres
|
||||||
|
POSTGRES_DB: adiuva
|
||||||
|
volumes:
|
||||||
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
redis:
|
||||||
|
image: redis:7-alpine
|
||||||
|
command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
|
||||||
|
volumes:
|
||||||
|
- redis_data:/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "redis-cli", "ping"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 3s
|
||||||
|
retries: 5
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
qdrant:
|
||||||
|
image: qdrant/qdrant:latest
|
||||||
|
volumes:
|
||||||
|
- qdrant_data:/qdrant/storage
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
secrets:
|
||||||
|
jwt_private_key:
|
||||||
|
file: ./infra/keys/jwt_private.pem
|
||||||
|
jwt_public_key:
|
||||||
|
file: ./infra/keys/jwt_public.pem
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
|
redis_data:
|
||||||
|
qdrant_data:
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Configurazione Cloudflare + VPS
|
||||||
|
|
||||||
|
### 6.1 DNS
|
||||||
|
|
||||||
|
```
|
||||||
|
api.tuodominio.com → A record → IP del VPS
|
||||||
|
→ Proxy: ON (orange cloud)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.2 Cloudflare Settings
|
||||||
|
|
||||||
|
| Setting | Valore | Motivo |
|
||||||
|
|---------|--------|--------|
|
||||||
|
| SSL/TLS mode | **Full (Strict)** | Cloudflare ↔ VPS con certificato valido |
|
||||||
|
| WebSocket | **ON** | Necessario per `/api/v1/ws/device` |
|
||||||
|
| Proxy timeout | **100s** (Enterprise) o default | Le LLM calls possono durare 30s+ |
|
||||||
|
| Under Attack Mode | Off (attivare se necessario) | |
|
||||||
|
|
||||||
|
### 6.3 TLS sul VPS
|
||||||
|
|
||||||
|
Due opzioni:
|
||||||
|
- **Opzione A (consigliata)**: Cloudflare Origin Certificate → montato in Traefik
|
||||||
|
- **Opzione B**: Let's Encrypt via Traefik (con DNS challenge Cloudflare)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# traefik.yml — con Cloudflare Origin Certificate
|
||||||
|
entryPoints:
|
||||||
|
websecure:
|
||||||
|
address: ":443"
|
||||||
|
|
||||||
|
tls:
|
||||||
|
certificates:
|
||||||
|
- certFile: /certs/origin.pem
|
||||||
|
keyFile: /certs/origin-key.pem
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.4 Rete VPS
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# UFW firewall — solo Cloudflare può raggiungere le porte 80/443
|
||||||
|
# https://www.cloudflare.com/ips/
|
||||||
|
ufw default deny incoming
|
||||||
|
ufw allow from 173.245.48.0/20 to any port 443
|
||||||
|
ufw allow from 103.21.244.0/22 to any port 443
|
||||||
|
# ... (tutti gli IP range di Cloudflare)
|
||||||
|
ufw allow ssh
|
||||||
|
ufw enable
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. Comunicazione Inter-Servizio
|
||||||
|
|
||||||
|
### 7.1 Redis Pub/Sub — Event Bus
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────┐ tier_changed:user_123 ┌──────────┐
|
||||||
|
│ Billing │ ────────────────────────► │ Auth │
|
||||||
|
│ Service │ │ Service │
|
||||||
|
└──────────┘ └──────────┘
|
||||||
|
|
||||||
|
┌──────────┐ tool_call:user_123 ┌──────────┐
|
||||||
|
│ Agent │ ────────────────────────► │ Chat │
|
||||||
|
│ Service │ │ Service │
|
||||||
|
│ (batch) │ ◄────────────────────────│ (ha WS) │
|
||||||
|
└──────────┘ tool_result:{call_id} └──────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7.2 Health Checks e Service Discovery
|
||||||
|
|
||||||
|
Traefik gestisce automaticamente il service discovery via Docker labels. I servizi non devono conoscersi tra loro — comunicano solo via:
|
||||||
|
- **Redis pub/sub** (tool-call cross-instance, tier events)
|
||||||
|
- **Redis hash** (stato condiviso: `ws:connections`, rate-limit counters)
|
||||||
|
- **PostgreSQL** (dati persistenti condivisi)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. Piano di Migrazione Incrementale (MVP)
|
||||||
|
|
||||||
|
### Fase 1 — Preparazione (nel monolite attuale)
|
||||||
|
1. Aggiungere Redis al `docker-compose.yml` attuale
|
||||||
|
2. Migrare JWT da HS256 → RS256 (backward-compatible: accetta entrambi per un periodo)
|
||||||
|
3. Implementare `RedisDeviceManager` come drop-in replacement del singleton in-memory
|
||||||
|
4. Estrarre `shared/` con auth verification, schemas, middleware
|
||||||
|
|
||||||
|
### Fase 2 — Auth Service (primo split)
|
||||||
|
1. Estrarre `auth.py` routes + models in `auth-service/`
|
||||||
|
2. Verificare che i JWT firmati da `auth-service` vengano validati dal monolite
|
||||||
|
3. Aggiungere Traefik e routare `/api/v1/auth/*` al nuovo servizio
|
||||||
|
4. Il monolite continua a servire tutto il resto
|
||||||
|
|
||||||
|
### Fase 3 — Billing Service
|
||||||
|
1. Estrarre billing routes, Stripe service, tier manager
|
||||||
|
2. Configurare Redis pub/sub per `tier_changed` events
|
||||||
|
3. Routare via Traefik
|
||||||
|
|
||||||
|
### Fase 4 — Split Chat + Agent (il più delicato)
|
||||||
|
1. Il monolite residuo contiene WS + chat + agents
|
||||||
|
2. Separare Agent Service: estrarre `agent_runner`, `agent_registry`, `agent_setup`, route `/agents/*`
|
||||||
|
3. Implementare `redis_executor.py` nell'Agent Service per tool-call via Redis
|
||||||
|
4. Il Chat Service resta proprietario della WS e sottoscrive i canali `tool_call:{user_id}`
|
||||||
|
5. Testare: trigger agent dall'Agent Service → tool_call via Redis → Chat Service → WS → device → risposta
|
||||||
|
|
||||||
|
### Fase 5 — Scaling test
|
||||||
|
1. Scalare Chat Service a 2 repliche, verificare sticky sessions
|
||||||
|
2. Scalare Agent Service a 2 repliche, verificare batch processing distribuito
|
||||||
|
3. Monitoring (Prometheus + Grafana) per ogni servizio
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. Monitoraggio e Logging
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Aggiungere al docker-compose.yml
|
||||||
|
|
||||||
|
prometheus:
|
||||||
|
image: prom/prometheus:latest
|
||||||
|
volumes:
|
||||||
|
- ./infra/prometheus/prometheus.yml:/etc/prometheus/prometheus.yml
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
grafana:
|
||||||
|
image: grafana/grafana:latest
|
||||||
|
ports:
|
||||||
|
- "3000:3000"
|
||||||
|
volumes:
|
||||||
|
- grafana_data:/var/lib/grafana
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
loki:
|
||||||
|
image: grafana/loki:latest
|
||||||
|
restart: unless-stopped
|
||||||
|
```
|
||||||
|
|
||||||
|
Ogni servizio espone `/metrics` (Prometheus) e scrive log strutturati (JSON) raccolti da Loki.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 10. Sizing VPS Minimo Consigliato (MVP)
|
||||||
|
|
||||||
|
| Componente | CPU | RAM | Note |
|
||||||
|
|---|---|---|---|
|
||||||
|
| Traefik | 0.25 | 128MB | |
|
||||||
|
| Auth Service ×2 | 0.25 ×2 | 128MB ×2 | Stateless, leggero |
|
||||||
|
| Chat Service ×2 | 1.0 ×2 | 1GB ×2 | WS + streaming LLM |
|
||||||
|
| Agent Service ×2 | 0.75 ×2 | 512MB ×2 | Batch LLM, CPU-bound |
|
||||||
|
| Billing Service | 0.25 | 128MB | |
|
||||||
|
| PostgreSQL | 1.0 | 1GB | |
|
||||||
|
| Redis | 0.25 | 256MB | |
|
||||||
|
| Qdrant | 0.5 | 512MB | |
|
||||||
|
| **Totale MVP** | **~5.5 vCPU** | **~5 GB** | |
|
||||||
|
|
||||||
|
**Raccomandazione**: VPS con **8 vCPU / 16 GB RAM** per avere margine. Hetzner CPX41 (~€30/mese) o equivalente. Senza Storage/Plugin si risparmia ~1 vCPU e 512MB rispetto alla versione completa.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Riepilogo Architettura MVP
|
||||||
|
|
||||||
|
| Servizio | Repliche | Proprietario di |
|
||||||
|
|---|---|---|
|
||||||
|
| **Traefik** | 1 | Routing, TLS, sticky sessions |
|
||||||
|
| **Auth Service** | 2 | JWT RS256, registrazione, login, profilo |
|
||||||
|
| **Chat Service** | 2–N | WebSocket, home/floating chat, streaming |
|
||||||
|
| **Agent Service** | 2–N | Batch processing, directory scan, agent setup |
|
||||||
|
| **Billing Service** | 1 | Stripe, subscriptions, tier management |
|
||||||
|
|
||||||
|
| Decisione | Scelta | Motivazione |
|
||||||
|
|---|---|---|
|
||||||
|
| API Gateway | Traefik | Nativo Docker, WebSocket support, service discovery automatico |
|
||||||
|
| JWT | RS256 (asimmetrico) | Verifica distribuita senza contattare Auth Service |
|
||||||
|
| Tier check | Claim nel JWT | Ogni servizio verifica localmente, zero roundtrip |
|
||||||
|
| WebSocket scaling | Redis pub/sub + sticky cookies | Cross-instance tool-call routing |
|
||||||
|
| Chat ↔ Agent split | Servizi separati | Batch CPU-bound non impatta real-time chat |
|
||||||
|
| Agent → Device comms | Redis pub/sub via Chat Service | Agent non possiede la WS, usa un relay |
|
||||||
|
| Rate limiting | Redis contatori distribuiti | Sliding window condivisa tra repliche |
|
||||||
|
| Database | PostgreSQL condiviso | Semplicità MVP; split DB futuro facile |
|
||||||
|
| TLS | Cloudflare Origin Certificate | Zero maintenance |
|
||||||
|
| Orchestrazione | Docker Compose | Sufficiente per un singolo VPS |
|
||||||
|
| Storage / Plugin | Post-MVP | Non critici per il lancio |
|
||||||
@@ -32,4 +32,6 @@ google-auth-oauthlib>=1.2.0
|
|||||||
google-auth-httplib2>=0.2.0
|
google-auth-httplib2>=0.2.0
|
||||||
msal>=1.28.0
|
msal>=1.28.0
|
||||||
cryptography>=42.0.0
|
cryptography>=42.0.0
|
||||||
|
redis>=5.0.0
|
||||||
|
langfuse>=3.0.0
|
||||||
ruff>=0.8.0
|
ruff>=0.8.0
|
||||||
|
|||||||
19
services/auth/.env.example
Normal file
19
services/auth/.env.example
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# ── Auth Service ──────────────────────────────────────────────────────────────
|
||||||
|
# This file contains env vars specific to the Auth Service.
|
||||||
|
# Shared vars (DATABASE_URL, REDIS_URL, etc.) come from the root .env
|
||||||
|
# or from docker-compose environment.
|
||||||
|
|
||||||
|
# ── JWT RS256 Keys ────────────────────────────────────────────────────────────
|
||||||
|
# Generate keypair:
|
||||||
|
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
||||||
|
# openssl rsa -in private.pem -pubout -out public.pem
|
||||||
|
#
|
||||||
|
# Paste PEM content with literal \n for newlines:
|
||||||
|
# JWT_PRIVATE_KEY=-----BEGIN PRIVATE KEY-----\nMIIEvQ...
|
||||||
|
# JWT_PUBLIC_KEY=-----BEGIN PUBLIC KEY-----\nMIIBIj...
|
||||||
|
|
||||||
|
# PRIVATE KEY — used to SIGN JWTs. NEVER share outside this service.
|
||||||
|
JWT_PRIVATE_KEY=
|
||||||
|
|
||||||
|
# PUBLIC KEY — used to VERIFY JWTs.
|
||||||
|
JWT_PUBLIC_KEY=
|
||||||
36
services/auth/Dockerfile
Normal file
36
services/auth/Dockerfile
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# ── builder ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS builder
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
# Install shared + service deps in one layer
|
||||||
|
COPY services/auth/requirements.txt ./requirements.txt
|
||||||
|
RUN pip install --upgrade pip && \
|
||||||
|
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||||
|
|
||||||
|
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS runtime
|
||||||
|
|
||||||
|
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY --from=builder /install /usr/local
|
||||||
|
|
||||||
|
# Copy shared module (available to all services)
|
||||||
|
COPY shared/ shared/
|
||||||
|
|
||||||
|
# Copy service source
|
||||||
|
COPY services/auth/app/ app/
|
||||||
|
|
||||||
|
RUN chown -R appuser:appgroup /app
|
||||||
|
|
||||||
|
USER appuser
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
CMD ["gunicorn", "app.main:app", \
|
||||||
|
"-k", "uvicorn.workers.UvicornWorker", \
|
||||||
|
"--bind", "0.0.0.0:8000", \
|
||||||
|
"--workers", "2", \
|
||||||
|
"--timeout", "30"]
|
||||||
16
services/auth/README.md
Normal file
16
services/auth/README.md
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Auth Service
|
||||||
|
|
||||||
|
Owns: user registration, login, JWT RS256 issuance, token refresh, `/me` endpoint.
|
||||||
|
|
||||||
|
## Tables owned
|
||||||
|
- `users`
|
||||||
|
- `refresh_tokens`
|
||||||
|
- `subscriptions` (read; Billing Service writes)
|
||||||
|
|
||||||
|
## Endpoints
|
||||||
|
- `POST /auth/register`
|
||||||
|
- `POST /auth/login`
|
||||||
|
- `POST /auth/refresh`
|
||||||
|
- `GET /auth/me`
|
||||||
|
- `PUT /auth/me`
|
||||||
|
- `GET /auth/verify` (ForwardAuth for Traefik)
|
||||||
0
services/auth/app/__init__.py
Normal file
0
services/auth/app/__init__.py
Normal file
34
services/auth/app/config.py
Normal file
34
services/auth/app/config.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""Auth Service — local configuration.
|
||||||
|
|
||||||
|
Contains secrets that ONLY the Auth Service needs (e.g., JWT private key).
|
||||||
|
These are NOT in shared/config.py to prevent other services from accessing them.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pydantic import field_validator
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class AuthSettings(BaseSettings):
|
||||||
|
# RS256 private key (PEM format). Used to SIGN JWTs.
|
||||||
|
# Only the Auth Service has this. Generate with:
|
||||||
|
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
||||||
|
# Then set the env var (newlines as \n):
|
||||||
|
# JWT_PRIVATE_KEY="-----BEGIN PRIVATE KEY-----\nMIIEv..."
|
||||||
|
JWT_PRIVATE_KEY: str = ""
|
||||||
|
|
||||||
|
# RS256 public key (PEM format). Used to VERIFY JWTs.
|
||||||
|
# Derived from the private key:
|
||||||
|
# openssl rsa -in private.pem -pubout -out public.pem
|
||||||
|
JWT_PUBLIC_KEY: str = ""
|
||||||
|
|
||||||
|
@field_validator("JWT_PRIVATE_KEY", "JWT_PUBLIC_KEY", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _expand_pem_newlines(cls, v: str) -> str:
|
||||||
|
if isinstance(v, str) and r"\n" in v:
|
||||||
|
return v.replace(r"\n", "\n")
|
||||||
|
return v
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
auth_settings = AuthSettings()
|
||||||
69
services/auth/app/deps.py
Normal file
69
services/auth/app/deps.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
"""Auth dependencies — JWT validation for the Auth Service.
|
||||||
|
|
||||||
|
This is the canonical get_current_user used by protected endpoints
|
||||||
|
within the Auth Service itself (/me, /me PUT).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.db import get_session
|
||||||
|
from shared.models import Subscription, User
|
||||||
|
from shared.schemas import UserProfile
|
||||||
|
|
||||||
|
from app.config import auth_settings
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user(
|
||||||
|
token: str = Depends(oauth2_scheme),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> UserProfile:
|
||||||
|
"""Validate a Bearer JWT and return the authenticated user.
|
||||||
|
|
||||||
|
The JWT is used for identity and expiry. Tier is fetched live from the
|
||||||
|
subscriptions table so upgrades/downgrades take effect immediately.
|
||||||
|
"""
|
||||||
|
credentials_exc = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
|
||||||
|
)
|
||||||
|
user_id: str | None = payload.get("sub")
|
||||||
|
email: str | None = payload.get("email")
|
||||||
|
if not user_id or not email:
|
||||||
|
raise credentials_exc
|
||||||
|
except JWTError:
|
||||||
|
raise credentials_exc
|
||||||
|
|
||||||
|
# Live tier lookup
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
|
tier: str = result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
|
# Fetch name/surname
|
||||||
|
user_result = await db.execute(
|
||||||
|
select(User.name, User.surname).where(User.id == user_id)
|
||||||
|
)
|
||||||
|
user_row = user_result.one_or_none()
|
||||||
|
|
||||||
|
return UserProfile(
|
||||||
|
id=user_id,
|
||||||
|
email=email,
|
||||||
|
name=user_row.name if user_row else None,
|
||||||
|
surname=user_row.surname if user_row else None,
|
||||||
|
tier=tier,
|
||||||
|
) # type: ignore[arg-type]
|
||||||
62
services/auth/app/main.py
Normal file
62
services/auth/app/main.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""Auth Service — JWT issuance, user management, ForwardAuth verification.
|
||||||
|
|
||||||
|
Standalone FastAPI service extracted from the adiuva-api monolith.
|
||||||
|
Owns: users, refresh_tokens, subscriptions (read).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Ensure the repo root is on sys.path so "shared" is importable.
|
||||||
|
# In Docker, COPY shared/ puts it at /app/shared/ (already importable).
|
||||||
|
# In local dev, we need to add the repo root (two levels up from this file).
|
||||||
|
_repo_root = str(Path(__file__).resolve().parents[3])
|
||||||
|
if _repo_root not in sys.path:
|
||||||
|
sys.path.insert(0, _repo_root)
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
yield
|
||||||
|
from shared.db import engine
|
||||||
|
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
app = FastAPI(
|
||||||
|
title="Adiuva Auth Service",
|
||||||
|
version="0.1.0",
|
||||||
|
docs_url="/docs" if settings.ENV == "dev" else None,
|
||||||
|
redoc_url=None,
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=settings.CORS_ORIGINS,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.routes import router
|
||||||
|
from app.verify import router as verify_router
|
||||||
|
|
||||||
|
app.include_router(router, prefix="/api/v1")
|
||||||
|
app.include_router(verify_router, prefix="/api/v1")
|
||||||
|
|
||||||
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
|
async def health() -> dict:
|
||||||
|
return {"status": "ok", "service": "auth", "version": app.version}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
249
services/auth/app/routes.py
Normal file
249
services/auth/app/routes.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
"""Auth routes: register, login, refresh, me.
|
||||||
|
|
||||||
|
Extracted from app/api/routes/auth.py — uses shared.* imports instead of app.*.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from jose import jwt
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.db import get_session
|
||||||
|
from shared.models import RefreshToken, Subscription, User
|
||||||
|
from shared.schemas import AuthTokens, UserProfile
|
||||||
|
|
||||||
|
from app.config import auth_settings
|
||||||
|
from app.deps import get_current_user
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Internal helpers ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_password(password: str) -> str:
|
||||||
|
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_password(password: str, hashed: str) -> bool:
|
||||||
|
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_token(plain_token: str) -> str:
|
||||||
|
"""SHA-256 of the plain refresh token string."""
|
||||||
|
return hashlib.sha256(plain_token.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
||||||
|
"""Return (RS256-signed JWT, expires_at_ms)."""
|
||||||
|
now = int(time.time())
|
||||||
|
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||||
|
payload = {
|
||||||
|
"sub": user_id,
|
||||||
|
"email": email,
|
||||||
|
"tier": tier,
|
||||||
|
"exp": exp,
|
||||||
|
"iat": now,
|
||||||
|
}
|
||||||
|
token = jwt.encode(payload, auth_settings.JWT_PRIVATE_KEY, algorithm="RS256")
|
||||||
|
return token, exp * 1000 # ms for client
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_live_tier(db: AsyncSession, user_id: str) -> str:
|
||||||
|
"""Fetch authoritative tier from subscriptions table."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
|
return result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
|
|
||||||
|
# ── Request bodies ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _RegisterRequest(BaseModel):
|
||||||
|
email: str
|
||||||
|
password: str
|
||||||
|
name: str | None = None
|
||||||
|
surname: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class _LoginRequest(BaseModel):
|
||||||
|
email: str
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class _RefreshRequest(BaseModel):
|
||||||
|
refresh_token: str
|
||||||
|
|
||||||
|
|
||||||
|
class _UpdateProfileRequest(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
surname: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def register(
|
||||||
|
body: _RegisterRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
|
"""Create a new account and return JWT tokens."""
|
||||||
|
existing = await db.execute(select(User).where(User.email == body.email))
|
||||||
|
if existing.scalar_one_or_none() is not None:
|
||||||
|
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
email=body.email,
|
||||||
|
name=body.name,
|
||||||
|
surname=body.surname,
|
||||||
|
password_hash=_hash_password(body.password),
|
||||||
|
tier="free",
|
||||||
|
encryption_key=Fernet.generate_key().decode(),
|
||||||
|
)
|
||||||
|
db.add(user)
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
|
)
|
||||||
|
rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
db.add(rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", response_model=AuthTokens)
|
||||||
|
async def login(
|
||||||
|
body: _LoginRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
|
"""Validate credentials and return JWT tokens."""
|
||||||
|
result = await db.execute(select(User).where(User.email == body.email))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not _verify_password(body.password, user.password_hash):
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
||||||
|
|
||||||
|
# Fetch live tier for the JWT claim
|
||||||
|
tier = await _get_live_tier(db, user.id)
|
||||||
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
|
)
|
||||||
|
rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
db.add(rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=AuthTokens)
|
||||||
|
async def refresh(
|
||||||
|
body: _RefreshRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
|
"""Rotate a refresh token and return a new token pair."""
|
||||||
|
token_hash = _hash_token(body.refresh_token)
|
||||||
|
result = await db.execute(
|
||||||
|
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||||
|
)
|
||||||
|
rt = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
||||||
|
|
||||||
|
await db.delete(rt)
|
||||||
|
|
||||||
|
user_result = await db.execute(select(User).where(User.id == rt.user_id))
|
||||||
|
user = user_result.scalar_one_or_none()
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
||||||
|
|
||||||
|
# Fetch live tier for the new JWT
|
||||||
|
tier = await _get_live_tier(db, user.id)
|
||||||
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
|
new_rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=new_expires,
|
||||||
|
)
|
||||||
|
db.add(new_rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserProfile)
|
||||||
|
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
|
||||||
|
"""Return the profile for the authenticated user."""
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me", response_model=UserProfile)
|
||||||
|
async def update_profile(
|
||||||
|
body: _UpdateProfileRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> UserProfile:
|
||||||
|
"""Update the authenticated user's name and surname."""
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
|
||||||
|
if body.name is not None:
|
||||||
|
user.name = body.name
|
||||||
|
if body.surname is not None:
|
||||||
|
user.surname = body.surname
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
|
||||||
|
return UserProfile(
|
||||||
|
id=user.id,
|
||||||
|
email=user.email,
|
||||||
|
name=user.name,
|
||||||
|
surname=user.surname,
|
||||||
|
tier=current_user.tier,
|
||||||
|
)
|
||||||
66
services/auth/app/verify.py
Normal file
66
services/auth/app/verify.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""ForwardAuth verification endpoint for Traefik.
|
||||||
|
|
||||||
|
Traefik calls GET /api/v1/auth/verify on every request to a protected
|
||||||
|
service. This endpoint validates the JWT from the Authorization header
|
||||||
|
and returns identity headers that Traefik injects into downstream requests.
|
||||||
|
|
||||||
|
Downstream services NEVER validate JWTs themselves — they trust the
|
||||||
|
X-User-Id, X-User-Email, X-User-Tier headers injected by Traefik.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Request, Response
|
||||||
|
from fastapi import status as http_status
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.db import async_session
|
||||||
|
from shared.models import Subscription
|
||||||
|
|
||||||
|
from app.config import auth_settings
|
||||||
|
|
||||||
|
router = APIRouter(tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/auth/verify")
|
||||||
|
async def verify(request: Request) -> Response:
|
||||||
|
"""Validate JWT and return identity headers for Traefik ForwardAuth.
|
||||||
|
|
||||||
|
Returns 200 with X-User-* headers on success, 401 on failure.
|
||||||
|
Traefik copies response headers to the downstream request.
|
||||||
|
"""
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header.startswith("Bearer "):
|
||||||
|
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
|
||||||
|
|
||||||
|
token = auth_header[7:] # strip "Bearer "
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
|
||||||
|
)
|
||||||
|
user_id: str | None = payload.get("sub")
|
||||||
|
email: str | None = payload.get("email")
|
||||||
|
if not user_id or not email:
|
||||||
|
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
|
||||||
|
except JWTError:
|
||||||
|
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
|
||||||
|
|
||||||
|
# Live tier lookup from subscriptions table
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
|
tier: str = result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
status_code=http_status.HTTP_200_OK,
|
||||||
|
headers={
|
||||||
|
"X-User-Id": user_id,
|
||||||
|
"X-User-Email": email,
|
||||||
|
"X-User-Tier": tier,
|
||||||
|
},
|
||||||
|
)
|
||||||
11
services/auth/requirements.txt
Normal file
11
services/auth/requirements.txt
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
fastapi>=0.115.0
|
||||||
|
uvicorn[standard]>=0.34.0
|
||||||
|
gunicorn>=22.0.0
|
||||||
|
pydantic>=2.10.0
|
||||||
|
pydantic-settings>=2.7.0
|
||||||
|
python-jose[cryptography]>=3.3.0
|
||||||
|
sqlalchemy>=2.0.0
|
||||||
|
asyncpg>=0.30.0
|
||||||
|
bcrypt>=4.2.0
|
||||||
|
cryptography>=42.0.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
36
services/batch-agent/Dockerfile
Normal file
36
services/batch-agent/Dockerfile
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# ── builder ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS builder
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
COPY services/batch-agent/requirements.txt ./requirements.txt
|
||||||
|
RUN pip install --upgrade pip && \
|
||||||
|
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||||
|
|
||||||
|
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS runtime
|
||||||
|
|
||||||
|
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY --from=builder /install /usr/local
|
||||||
|
|
||||||
|
# Shared module
|
||||||
|
COPY shared/ shared/
|
||||||
|
|
||||||
|
# Service source
|
||||||
|
COPY services/batch-agent/app/ app/
|
||||||
|
|
||||||
|
RUN chown -R appuser:appgroup /app
|
||||||
|
|
||||||
|
USER appuser
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
# Batch runs are long-lived — use a longer timeout than chat (300s vs 120s)
|
||||||
|
CMD ["gunicorn", "app.main:app", \
|
||||||
|
"-k", "uvicorn.workers.UvicornWorker", \
|
||||||
|
"--bind", "0.0.0.0:8000", \
|
||||||
|
"--workers", "2", \
|
||||||
|
"--timeout", "300"]
|
||||||
23
services/batch-agent/README.md
Normal file
23
services/batch-agent/README.md
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# Batch Agent Service
|
||||||
|
|
||||||
|
Owns: agent_runner, journey builder, filesystem_agent, integrations (Gmail, MS Graph).
|
||||||
|
|
||||||
|
## Tables owned
|
||||||
|
- `local_agent_configs`
|
||||||
|
- `cloud_agent_configs`
|
||||||
|
- `agent_run_logs`
|
||||||
|
|
||||||
|
## Endpoints
|
||||||
|
- `GET /agents/catalog`
|
||||||
|
- `POST /agents/can-create`
|
||||||
|
- `POST /agents/trigger`
|
||||||
|
- `GET /agents/{id}/history`
|
||||||
|
|
||||||
|
## Redis channels
|
||||||
|
- Subscribe: `batch:request:{user_id}`
|
||||||
|
- Publish: `ws:out:{user_id}` (journey replies + tool calls)
|
||||||
|
- BRPOP: `tool:result:{call_id}` (30s timeout)
|
||||||
|
- SET+EX: `journey:{user_id}` (session state, TTL 1800s)
|
||||||
|
|
||||||
|
## TODO
|
||||||
|
- [ ] Integrate Langfuse tracing (reuse `services/chat/app/tracing.py` pattern — `trace_span()`, `get_langfuse_callback()`, prompt management). Each batch agent run should create a trace with input/output, link prompts, and pass the LangChain `CallbackHandler` to LLM calls.
|
||||||
0
services/batch-agent/app/__init__.py
Normal file
0
services/batch-agent/app/__init__.py
Normal file
910
services/batch-agent/app/agent_runner.py
Normal file
910
services/batch-agent/app/agent_runner.py
Normal file
@@ -0,0 +1,910 @@
|
|||||||
|
"""Agent run orchestrator — adapted for Batch Agent Service.
|
||||||
|
|
||||||
|
Key changes from monolith app/core/agent_runner.py:
|
||||||
|
- No DeviceConnectionManager — tool calls go through Redis ws_context.
|
||||||
|
- set_current_user / clear_current_user replace set_client_executor.
|
||||||
|
- run_local_agent accepts a serialized dict (from Redis / REST) instead
|
||||||
|
of SQLAlchemy model objects.
|
||||||
|
- _finalize_run writes to PostgreSQL via shared.db.async_session.
|
||||||
|
- Cloud agent import path changed to app.integrations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
||||||
|
from shared.agents.note_agent import NOTE_TOOLS
|
||||||
|
from shared.agents.project_agent import PROJECT_TOOLS
|
||||||
|
from shared.agents.task_agent import TASK_TOOLS
|
||||||
|
from shared.agents.timeline_agent import TIMELINE_TOOLS
|
||||||
|
from shared.llm import get_llm
|
||||||
|
from shared.ws_context import execute_on_client, set_current_user, clear_current_user
|
||||||
|
import app.tracing as tracing
|
||||||
|
from shared.db import async_session
|
||||||
|
from shared.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
|
from shared.redis import redis_client, ws_out_channel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Concurrency guard ─────────────────────────────────────────────────────
|
||||||
|
_running_agents: set[str] = set()
|
||||||
|
|
||||||
|
|
||||||
|
def is_agent_running(agent_id: str) -> bool:
|
||||||
|
return agent_id in _running_agents
|
||||||
|
|
||||||
|
|
||||||
|
# ── Timeouts ───────────────────────────────────────────────────────────────
|
||||||
|
_TOOL_CALL_TIMEOUT: int = 30
|
||||||
|
_MAX_PROCESSING_STEPS: int = 12
|
||||||
|
_MAX_SCAN_DEPTH: int = 5
|
||||||
|
|
||||||
|
# ── Data-type to tool mapping ─────────────────────────────────────────────
|
||||||
|
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
|
||||||
|
"tasks": TASK_TOOLS,
|
||||||
|
"notes": NOTE_TOOLS,
|
||||||
|
"timelines": TIMELINE_TOOLS,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Step 1: Classification prompt ─────────────────────────────────────────
|
||||||
|
|
||||||
|
_DOMAIN_DESCRIPTIONS: dict[str, str] = {
|
||||||
|
"tasks": (
|
||||||
|
"Action items, to-dos, deliverables — anything that describes work to be done, "
|
||||||
|
"assigned to someone, or tracked with a due date or status."
|
||||||
|
),
|
||||||
|
"notes": (
|
||||||
|
"Documentation, meeting notes, summaries, reference material — "
|
||||||
|
"written content meant to be read and referenced rather than acted on."
|
||||||
|
),
|
||||||
|
"timelines": (
|
||||||
|
"Project milestones, deadlines, scheduled events — "
|
||||||
|
"specific dates that mark a point in the progress of a project."
|
||||||
|
),
|
||||||
|
"projects": (
|
||||||
|
"High-level project entities — only relevant if the file clearly introduces "
|
||||||
|
"a new project or updates the scope of an existing one."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
_STEP1_SYSTEM_PROMPT = """\
|
||||||
|
You are a file classifier for a freelance project management tool.
|
||||||
|
|
||||||
|
Your job is to match a file to an existing project and identify which data domains to extract.
|
||||||
|
|
||||||
|
## Project matching rules (STRICT — follow in order)
|
||||||
|
|
||||||
|
1. Search the file content for any mention of a project name, client name, acronym, or topic
|
||||||
|
that overlaps with the existing projects listed below.
|
||||||
|
2. The match does NOT need to be exact — partial name, abbreviation, or topic similarity is enough.
|
||||||
|
3. STRONGLY PREFER matching an existing project. Only return "new" as an absolute last resort
|
||||||
|
when the file has zero meaningful connection to any listed project.
|
||||||
|
4. When in doubt, pick the closest match from the list.
|
||||||
|
|
||||||
|
## Response format
|
||||||
|
|
||||||
|
Respond ONLY with a JSON object — no markdown, no explanation:
|
||||||
|
|
||||||
|
{{"project_id": "<exact id from the list below, or new>", "new_project_name": "<concise 2-5 word name, only when project_id is new>", "domains": ["tasks", "notes"]}}
|
||||||
|
|
||||||
|
## Domain definitions (only consider domains in the allowed list)
|
||||||
|
|
||||||
|
{domain_definitions}
|
||||||
|
|
||||||
|
## Existing projects
|
||||||
|
|
||||||
|
{projects_list}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── Step 2: Processing prompt ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
_PROCESSING_SYSTEM_PROMPT = """\
|
||||||
|
You are a data extraction assistant for a freelance project management tool.
|
||||||
|
|
||||||
|
Your task: extract structured data from the file content and persist it using the available tools.
|
||||||
|
|
||||||
|
## Mandatory process — follow this order for EVERY item you extract
|
||||||
|
|
||||||
|
1. READ the existing records listed below for the relevant domain.
|
||||||
|
2. SEARCH for a match by title, topic, or semantic similarity.
|
||||||
|
3. If a match exists → call the update_* tool with the existing record's id.
|
||||||
|
4. If no match exists → call the create_* tool and set isAiSuggested=1.
|
||||||
|
|
||||||
|
NEVER call create_* without first checking the existing records.
|
||||||
|
NEVER duplicate a record that already exists under a different wording.
|
||||||
|
|
||||||
|
## Existing records (source of truth)
|
||||||
|
|
||||||
|
{existing_context}
|
||||||
|
|
||||||
|
## Context
|
||||||
|
|
||||||
|
Project: {project_context}
|
||||||
|
Domains to extract: {data_types}
|
||||||
|
|
||||||
|
{custom_prompt_section}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── Cloud processing prompt ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
_CLOUD_PROCESSING_PROMPT = """\
|
||||||
|
You are a data extraction and management assistant for a freelance project
|
||||||
|
management tool.
|
||||||
|
|
||||||
|
Available tools:
|
||||||
|
Filesystem : read_file_content, list_directory, get_file_metadata
|
||||||
|
Tasks : list_tasks, create_task, update_task, add_task_comment
|
||||||
|
Notes : list_notes, get_note, create_note, update_note
|
||||||
|
Timelines : list_timelines, create_timeline, update_timeline
|
||||||
|
Projects : list_all_projects, get_project, create_project, update_project
|
||||||
|
|
||||||
|
Your task:
|
||||||
|
1. Read the full content of each file below using read_file_content.
|
||||||
|
2. For each piece of information found, ALWAYS try to match and update an
|
||||||
|
existing record before creating a new one.
|
||||||
|
3. ONLY act on these entity types: {data_types}.
|
||||||
|
4. Do NOT invent data. Only extract what is clearly present in the files.
|
||||||
|
5. If a file contains no relevant data for the target entity types, skip it.
|
||||||
|
|
||||||
|
{project_context}
|
||||||
|
|
||||||
|
Files to process:
|
||||||
|
{file_list}
|
||||||
|
|
||||||
|
{custom_prompt_section}
|
||||||
|
|
||||||
|
After processing all files, respond with a brief summary of what you updated
|
||||||
|
and what you created.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM tool-calling loop ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _as_text(content: Any) -> str:
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_agent_with_tools(
|
||||||
|
*,
|
||||||
|
system_prompt: str,
|
||||||
|
user_message: str,
|
||||||
|
tools: list[Any],
|
||||||
|
max_steps: int,
|
||||||
|
langfuse_handler: Any | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Run an LLM agent with tool-calling, returning the final text response."""
|
||||||
|
callbacks = [langfuse_handler] if langfuse_handler else None
|
||||||
|
llm = get_llm(callbacks=callbacks)
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
messages: list[Any] = [
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(content=user_message),
|
||||||
|
]
|
||||||
|
|
||||||
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
|
for _ in range(max_steps):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
return _as_text(response.content)
|
||||||
|
|
||||||
|
for call in response.tool_calls:
|
||||||
|
call_id = str(call.get("id", ""))
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: tool_call name=%s args=%s",
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: tool_result name=%s output=%s",
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:200],
|
||||||
|
)
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
return _as_text(final.content)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tool list builder ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _build_processing_tools(data_types: list[str]) -> list[Any]:
|
||||||
|
tools: list[Any] = list(FILESYSTEM_TOOLS)
|
||||||
|
for dt in data_types:
|
||||||
|
dt_tools = _DATA_TYPE_TOOLS.get(dt)
|
||||||
|
if dt_tools:
|
||||||
|
tools.extend(dt_tools)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
|
||||||
|
# ── Code-based directory scanner ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _scan_directories(
|
||||||
|
paths: list[str],
|
||||||
|
extensions: list[str],
|
||||||
|
last_run_at: datetime | None,
|
||||||
|
) -> list[str]:
|
||||||
|
all_files: list[str] = []
|
||||||
|
ext_set = {e.lstrip(".").lower() for e in extensions} if extensions else set()
|
||||||
|
|
||||||
|
async def _walk(path: str, depth: int) -> None:
|
||||||
|
if depth > _MAX_SCAN_DEPTH:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(action="list_directory", data={"path": path})
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: list_directory failed %r: %s", path, exc)
|
||||||
|
return
|
||||||
|
for entry in result.get("entries", []):
|
||||||
|
entry_path = entry.get("path", "")
|
||||||
|
if not entry_path:
|
||||||
|
continue
|
||||||
|
if entry.get("type") == "directory":
|
||||||
|
await _walk(entry_path, depth + 1)
|
||||||
|
elif entry.get("type") == "file":
|
||||||
|
if ext_set:
|
||||||
|
dot_pos = entry_path.rfind(".")
|
||||||
|
file_ext = entry_path[dot_pos + 1:].lower() if dot_pos != -1 else ""
|
||||||
|
if file_ext not in ext_set:
|
||||||
|
continue
|
||||||
|
all_files.append(entry_path)
|
||||||
|
|
||||||
|
for root in paths:
|
||||||
|
await _walk(root, depth=0)
|
||||||
|
|
||||||
|
if last_run_at is None:
|
||||||
|
return all_files
|
||||||
|
|
||||||
|
last_run_ms = int(last_run_at.timestamp() * 1000)
|
||||||
|
filtered: list[str] = []
|
||||||
|
for file_path in all_files:
|
||||||
|
try:
|
||||||
|
meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
|
||||||
|
modified_at = meta.get("modifiedAt")
|
||||||
|
if modified_at is None:
|
||||||
|
filtered.append(file_path)
|
||||||
|
continue
|
||||||
|
if isinstance(modified_at, (int, float)):
|
||||||
|
mod_ms = int(modified_at)
|
||||||
|
else:
|
||||||
|
mod_ms = int(datetime.fromisoformat(str(modified_at)).timestamp() * 1000)
|
||||||
|
if mod_ms > last_run_ms:
|
||||||
|
filtered.append(file_path)
|
||||||
|
except Exception:
|
||||||
|
filtered.append(file_path)
|
||||||
|
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
|
# ── Code-based entity fetchers ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_projects() -> list[dict]:
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(action="select", table="projects")
|
||||||
|
return result.get("rows", [])
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: failed to fetch projects: %s", exc)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
_DOMAIN_TABLE: dict[str, str] = {
|
||||||
|
"tasks": "tasks",
|
||||||
|
"notes": "notes",
|
||||||
|
"timelines": "timelines",
|
||||||
|
"projects": "projects",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_domain_entities(domain: str, project_id: str) -> list[dict]:
|
||||||
|
table = _DOMAIN_TABLE.get(domain)
|
||||||
|
if not table:
|
||||||
|
return []
|
||||||
|
filters: dict[str, Any] = {}
|
||||||
|
if project_id != "standalone" and domain != "projects":
|
||||||
|
filters["projectId"] = project_id
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table=table,
|
||||||
|
filters=filters if filters else None,
|
||||||
|
)
|
||||||
|
return result.get("rows", [])
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: failed to fetch %s: %s", domain, exc)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _format_entities_for_context(domain: str, rows: list[dict]) -> str:
|
||||||
|
if not rows:
|
||||||
|
return f"No existing {domain}."
|
||||||
|
lines: list[str] = []
|
||||||
|
for r in rows:
|
||||||
|
if domain == "tasks":
|
||||||
|
desc = r.get("description") or ""
|
||||||
|
desc_part = f" — {desc[:120]}" if desc else ""
|
||||||
|
assignee = r.get("assignee") or r.get("assignees") or ""
|
||||||
|
due = r.get("dueDate") or r.get("due_date") or ""
|
||||||
|
meta = ", ".join(filter(None, [
|
||||||
|
f"priority: {r.get('priority', '')}" if r.get("priority") else "",
|
||||||
|
f"assignee: {assignee}" if assignee else "",
|
||||||
|
f"due: {due}" if due else "",
|
||||||
|
]))
|
||||||
|
lines.append(
|
||||||
|
f" - [{r.get('status', '?')}] {r.get('title', '')}{desc_part}"
|
||||||
|
f" ({meta}, id: {r['id']})"
|
||||||
|
)
|
||||||
|
elif domain == "notes":
|
||||||
|
snippet = (r.get("content") or "")[:200].replace("\n", " ")
|
||||||
|
snippet_part = f"\n Preview: {snippet}" if snippet else ""
|
||||||
|
lines.append(
|
||||||
|
f" - {r.get('title', '')} (id: {r['id']}){snippet_part}"
|
||||||
|
)
|
||||||
|
elif domain == "timelines":
|
||||||
|
lines.append(
|
||||||
|
f" - {r.get('title', '')} date={r.get('date', '')} (id: {r['id']})"
|
||||||
|
)
|
||||||
|
elif domain == "projects":
|
||||||
|
summary = (r.get("aiSummary") or r.get("ai_summary") or "")[:120]
|
||||||
|
summary_part = f" — {summary}" if summary else ""
|
||||||
|
lines.append(
|
||||||
|
f" - {r.get('name', '')} [{r.get('status', '')}]{summary_part}"
|
||||||
|
f" (id: {r['id']})"
|
||||||
|
)
|
||||||
|
return f"Existing {domain}:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 1: LLM file classifier ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _classify_file(
|
||||||
|
file_path: str,
|
||||||
|
file_content: str,
|
||||||
|
projects: list[dict],
|
||||||
|
config_data_types: list[str],
|
||||||
|
langfuse_handler: Any | None = None,
|
||||||
|
custom_system_prompt: str | None = None,
|
||||||
|
) -> tuple[str, list[str], str | None]:
|
||||||
|
fallback: tuple[str, list[str], str | None] = ("new", list(config_data_types), None)
|
||||||
|
|
||||||
|
if not file_content.strip():
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
valid_project_ids = {p["id"] for p in projects}
|
||||||
|
|
||||||
|
def _fmt_project(p: dict) -> str:
|
||||||
|
summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip()
|
||||||
|
summary_part = f" — {summary[:100]}" if summary else ""
|
||||||
|
return f" - id={p['id']} | name={p.get('name', '')} | status={p.get('status', '')}{summary_part}"
|
||||||
|
|
||||||
|
projects_list = "\n".join(_fmt_project(p) for p in projects) or " (none yet)"
|
||||||
|
|
||||||
|
domain_definitions = "\n".join(
|
||||||
|
f" - {d}: {_DOMAIN_DESCRIPTIONS[d]}"
|
||||||
|
for d in config_data_types
|
||||||
|
if d in _DOMAIN_DESCRIPTIONS
|
||||||
|
)
|
||||||
|
|
||||||
|
if custom_system_prompt:
|
||||||
|
# Fixture-provided prompt takes absolute priority
|
||||||
|
system = custom_system_prompt.format_map(
|
||||||
|
{"domain_definitions": domain_definitions, "projects_list": projects_list}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
system = tracing.compile_prompt(
|
||||||
|
"batch_file_classifier",
|
||||||
|
fallback=_STEP1_SYSTEM_PROMPT,
|
||||||
|
variables={
|
||||||
|
"domain_definitions": domain_definitions,
|
||||||
|
"projects_list": projects_list,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = get_llm(callbacks=[langfuse_handler] if langfuse_handler else None)
|
||||||
|
try:
|
||||||
|
response = await llm.ainvoke([
|
||||||
|
SystemMessage(content=system),
|
||||||
|
HumanMessage(content=f"File: {file_path}\n\nContent:\n{file_content[:4000]}"),
|
||||||
|
])
|
||||||
|
raw = _as_text(response.content).strip()
|
||||||
|
if raw.startswith("```"):
|
||||||
|
raw = raw.split("```")[1]
|
||||||
|
if raw.startswith("json"):
|
||||||
|
raw = raw[4:]
|
||||||
|
parsed = json.loads(raw.strip())
|
||||||
|
raw_project_id: str = str(parsed.get("project_id") or "new")
|
||||||
|
project_id = raw_project_id if raw_project_id in valid_project_ids else "new"
|
||||||
|
new_project_name: str | None = (
|
||||||
|
str(parsed["new_project_name"]).strip() or None
|
||||||
|
if project_id == "new" and parsed.get("new_project_name")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
domains: list[str] = [
|
||||||
|
d for d in parsed.get("domains", [])
|
||||||
|
if d in config_data_types
|
||||||
|
]
|
||||||
|
if not domains:
|
||||||
|
domains = list(config_data_types)
|
||||||
|
return project_id, domains, new_project_name
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"agent_runner: step1 classification failed for %r: %s", file_path, exc
|
||||||
|
)
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local agent runner (two-step per file) ────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def run_local_agent(user_id: str, trigger_data: dict[str, Any], *, langfuse_handler: Any | None = None) -> None:
|
||||||
|
"""Execute a local directory agent run.
|
||||||
|
|
||||||
|
In the microservice world, trigger_data is a serialized dict from
|
||||||
|
the REST route (forwarded via Redis), containing the agent config
|
||||||
|
fields and run_context.
|
||||||
|
|
||||||
|
set_current_user() must be called BEFORE this function.
|
||||||
|
"""
|
||||||
|
run_context: dict = trigger_data.get("run_context", {})
|
||||||
|
agent_id = run_context.get("agent_id", str(uuid.uuid4()))
|
||||||
|
run_id = run_context.get("run_id")
|
||||||
|
|
||||||
|
_running_agents.add(agent_id)
|
||||||
|
|
||||||
|
# Extract config from trigger payload
|
||||||
|
directory_paths: list[str] = trigger_data.get("directory_paths", [])
|
||||||
|
if not directory_paths:
|
||||||
|
directory = trigger_data.get("directory", "")
|
||||||
|
if directory:
|
||||||
|
directory_paths = [directory]
|
||||||
|
|
||||||
|
data_types: list[str] = trigger_data.get("data_types", [])
|
||||||
|
file_extensions: list[str] = trigger_data.get("file_extensions", [])
|
||||||
|
prompt_template: str = trigger_data.get("prompt_template", "")
|
||||||
|
last_run_at_raw = trigger_data.get("last_run_at")
|
||||||
|
last_run_at: datetime | None = None
|
||||||
|
if last_run_at_raw:
|
||||||
|
if isinstance(last_run_at_raw, str):
|
||||||
|
last_run_at = datetime.fromisoformat(last_run_at_raw)
|
||||||
|
elif isinstance(last_run_at_raw, (int, float)):
|
||||||
|
last_run_at = datetime.fromtimestamp(last_run_at_raw / 1000, tz=timezone.utc)
|
||||||
|
|
||||||
|
errors: list[str] = []
|
||||||
|
items_processed = 0
|
||||||
|
items_created = 0
|
||||||
|
|
||||||
|
custom_section = (
|
||||||
|
f"User instructions:\n{prompt_template}"
|
||||||
|
if prompt_template
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create or load run log
|
||||||
|
run_log_id = run_id
|
||||||
|
if not run_log_id:
|
||||||
|
async with async_session() as db:
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent_type="local",
|
||||||
|
user_id=user_id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
db.add(run_log)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(run_log)
|
||||||
|
run_log_id = run_log.id
|
||||||
|
|
||||||
|
try:
|
||||||
|
# ── Scan directories ─────────────────────────────────────────
|
||||||
|
logger.info("agent_runner: run=%s scanning directories user=%s", run_log_id, user_id)
|
||||||
|
file_paths = await _scan_directories(
|
||||||
|
paths=directory_paths,
|
||||||
|
extensions=file_extensions,
|
||||||
|
last_run_at=last_run_at,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: run=%s found %d file(s) after filtering", run_log_id, len(file_paths)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not file_paths:
|
||||||
|
await _finalize_run(run_log_id, status="success", items_processed=0, items_created=0)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Fetch all projects once ──────────────────────────────────
|
||||||
|
projects = await _fetch_projects()
|
||||||
|
|
||||||
|
for file_path in file_paths:
|
||||||
|
try:
|
||||||
|
file_result = await execute_on_client(
|
||||||
|
action="read_file_content", data={"path": file_path}
|
||||||
|
)
|
||||||
|
file_content: str = file_result.get("content", "")
|
||||||
|
if not file_content:
|
||||||
|
continue
|
||||||
|
|
||||||
|
items_processed += 1
|
||||||
|
|
||||||
|
# Step 1 — classify file
|
||||||
|
project_id, domains, new_project_name = await _classify_file(
|
||||||
|
file_path=file_path,
|
||||||
|
file_content=file_content,
|
||||||
|
projects=projects,
|
||||||
|
config_data_types=data_types,
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 2 — resolve project_id, fetch entities, process
|
||||||
|
if project_id == "new":
|
||||||
|
proj_name = new_project_name or "Untitled Project"
|
||||||
|
try:
|
||||||
|
proj_result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="projects",
|
||||||
|
data={"name": proj_name, "clientId": None},
|
||||||
|
)
|
||||||
|
created = proj_result.get("row", {})
|
||||||
|
effective_project_id = created.get("id", "standalone")
|
||||||
|
if "id" in created:
|
||||||
|
projects.append(created)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: run=%s create project failed: %s", run_log_id, exc)
|
||||||
|
effective_project_id = "standalone"
|
||||||
|
proj_name = "unknown"
|
||||||
|
project_context = (
|
||||||
|
f"Project: {proj_name} (id: {effective_project_id}). "
|
||||||
|
"Always set projectId to this id on every record you create."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
effective_project_id = project_id
|
||||||
|
proj = next((p for p in projects if p["id"] == project_id), None)
|
||||||
|
proj_name = proj.get("name", project_id) if proj else project_id
|
||||||
|
project_context = (
|
||||||
|
f"Project: {proj_name} (id: {project_id}). "
|
||||||
|
"Always set projectId to this id on every record you create."
|
||||||
|
)
|
||||||
|
|
||||||
|
domains = [d for d in domains if d != "projects"]
|
||||||
|
|
||||||
|
existing_blocks: list[str] = []
|
||||||
|
for domain in domains:
|
||||||
|
rows = await _fetch_domain_entities(domain, effective_project_id)
|
||||||
|
existing_blocks.append(_format_entities_for_context(domain, rows))
|
||||||
|
|
||||||
|
existing_context = "\n\n".join(existing_blocks)
|
||||||
|
|
||||||
|
system_prompt = tracing.compile_prompt(
|
||||||
|
"batch_processing",
|
||||||
|
fallback=_PROCESSING_SYSTEM_PROMPT,
|
||||||
|
variables={
|
||||||
|
"existing_context": existing_context,
|
||||||
|
"project_context": project_context,
|
||||||
|
"data_types": ", ".join(domains),
|
||||||
|
"custom_prompt_section": custom_section,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
processing_tools = _build_processing_tools(domains)
|
||||||
|
|
||||||
|
result_text = await _run_agent_with_tools(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
user_message=(
|
||||||
|
f"Process this file and extract relevant information.\n\n"
|
||||||
|
f"File: {file_path}\n\nContent:\n{file_content}"
|
||||||
|
),
|
||||||
|
tools=processing_tools,
|
||||||
|
max_steps=_MAX_PROCESSING_STEPS,
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: run=%s file=%r result=%s",
|
||||||
|
run_log_id, file_path, result_text[:200],
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Error processing '{file_path}': {exc}")
|
||||||
|
logger.error("agent_runner: run=%s file=%r failed: %s", run_log_id, file_path, exc)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Agent run failed: {exc}")
|
||||||
|
logger.error("agent_runner: run=%s failed: %s", run_log_id, exc)
|
||||||
|
finally:
|
||||||
|
_running_agents.discard(agent_id)
|
||||||
|
|
||||||
|
# ── Finalise ────────────────────────────────────────────────────
|
||||||
|
if errors and items_processed == 0:
|
||||||
|
final_status = "error"
|
||||||
|
elif errors:
|
||||||
|
final_status = "partial"
|
||||||
|
else:
|
||||||
|
final_status = "success"
|
||||||
|
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status=final_status,
|
||||||
|
items_processed=items_processed,
|
||||||
|
items_created=items_created,
|
||||||
|
errors=errors,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Notify Electron that the run is complete via Redis
|
||||||
|
if run_context:
|
||||||
|
try:
|
||||||
|
channel = ws_out_channel(user_id)
|
||||||
|
await redis_client.publish(channel, json.dumps({
|
||||||
|
"type": "run_complete",
|
||||||
|
"run_context": run_context,
|
||||||
|
"status": final_status,
|
||||||
|
}))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: run=%s failed to send run_complete: %s", run_log_id, exc)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud agent runner ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
|
||||||
|
|
||||||
|
|
||||||
|
async def run_cloud_agent(user_id: str, config_id: str, *, langfuse_handler: Any | None = None) -> None:
|
||||||
|
"""Execute a cloud connector agent run.
|
||||||
|
|
||||||
|
Loads the CloudAgentConfig from DB, decrypts OAuth tokens, fetches
|
||||||
|
messages from the provider, and runs LLM extraction.
|
||||||
|
|
||||||
|
set_current_user() must be called BEFORE this function.
|
||||||
|
"""
|
||||||
|
from app.integrations import decrypt_token, encrypt_token, get_provider
|
||||||
|
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
config = result.scalar_one_or_none()
|
||||||
|
if config is None:
|
||||||
|
logger.error("agent_runner: cloud config %s not found", config_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create run log
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
agent_id=config.id,
|
||||||
|
agent_type="cloud",
|
||||||
|
user_id=user_id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
db.add(run_log)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(run_log)
|
||||||
|
run_log_id = run_log.id
|
||||||
|
|
||||||
|
# ── Decrypt OAuth token ────────────────────────────────────────
|
||||||
|
if not config.oauth_token_encrypted:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status="error",
|
||||||
|
errors=[f"No OAuth token stored for cloud agent '{config.name}'"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
credentials_info = decrypt_token(config.oauth_token_encrypted)
|
||||||
|
except ValueError as exc:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Failed to decrypt OAuth token: {exc}"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Instantiate provider ──────────────────────────────────────
|
||||||
|
try:
|
||||||
|
provider = get_provider(config.provider, credentials_info)
|
||||||
|
except ValueError as exc:
|
||||||
|
await _finalize_run(run_log_id, status="error", errors=[str(exc)])
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Fetch messages ────────────────────────────────────────────
|
||||||
|
since: datetime | None = config.last_run_at
|
||||||
|
if since is None:
|
||||||
|
since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
|
||||||
|
if since.tzinfo is None:
|
||||||
|
since = since.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
errors: list[str] = []
|
||||||
|
items_processed = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
if config.provider == "gmail":
|
||||||
|
raw_messages = await provider.fetch_messages(
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
elif config.provider == "outlook":
|
||||||
|
raw_messages = await provider.fetch_emails(
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
elif config.provider == "teams":
|
||||||
|
raw_messages = await provider.fetch_messages(
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raw_messages = []
|
||||||
|
except RuntimeError as exc:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Provider fetch failed: {exc}"],
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: cloud agent %s fetched %d item(s) from %s",
|
||||||
|
config.id, len(raw_messages), config.provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Extract + insert via LLM ─────────────────────────────────
|
||||||
|
try:
|
||||||
|
processing_tools = _build_processing_tools(config.data_types)
|
||||||
|
custom_section = (
|
||||||
|
f"User instructions:\n{config.prompt_template}"
|
||||||
|
if config.prompt_template
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
for msg in raw_messages:
|
||||||
|
content_text = msg.as_text
|
||||||
|
if not content_text:
|
||||||
|
continue
|
||||||
|
items_processed += 1
|
||||||
|
|
||||||
|
processing_prompt = tracing.compile_prompt(
|
||||||
|
"batch_cloud_processing",
|
||||||
|
fallback=_CLOUD_PROCESSING_PROMPT,
|
||||||
|
variables={
|
||||||
|
"data_types": ", ".join(config.data_types),
|
||||||
|
"project_context": "Determine the appropriate project from the message context.",
|
||||||
|
"file_list": f"Message from {config.provider} (id: {msg.id})",
|
||||||
|
"custom_prompt_section": custom_section,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await _run_agent_with_tools(
|
||||||
|
system_prompt=processing_prompt,
|
||||||
|
user_message=f"Process this message content:\n\n{content_text[:8000]}",
|
||||||
|
tools=processing_tools,
|
||||||
|
max_steps=_MAX_PROCESSING_STEPS,
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Agent run failed: {exc}")
|
||||||
|
|
||||||
|
# ── Persist refreshed token ───────────────────────────────────
|
||||||
|
refreshed = getattr(provider, "refreshed_credentials", None)
|
||||||
|
if refreshed:
|
||||||
|
try:
|
||||||
|
new_encrypted = encrypt_token(refreshed)
|
||||||
|
async with async_session() as db:
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id)
|
||||||
|
)
|
||||||
|
cfg_row = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg_row:
|
||||||
|
cfg_row.oauth_token_encrypted = new_encrypted
|
||||||
|
await db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: failed to persist refreshed token: %s", exc)
|
||||||
|
|
||||||
|
# ── Finalise ──────────────────────────────────────────────────
|
||||||
|
if errors and items_processed == 0:
|
||||||
|
final_status = "error"
|
||||||
|
elif errors:
|
||||||
|
final_status = "partial"
|
||||||
|
else:
|
||||||
|
final_status = "success"
|
||||||
|
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status=final_status,
|
||||||
|
items_processed=items_processed,
|
||||||
|
items_created=0,
|
||||||
|
errors=errors,
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Internal helper ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _finalize_run(
|
||||||
|
run_log_id: int | str,
|
||||||
|
*,
|
||||||
|
status: str,
|
||||||
|
items_processed: int = 0,
|
||||||
|
items_created: int = 0,
|
||||||
|
errors: list[str] | None = None,
|
||||||
|
update_config_last_run: bool = False,
|
||||||
|
config_id: str | None = None,
|
||||||
|
config_type: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Persist the run outcome and optionally update last_run_at on the config."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
try:
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(
|
||||||
|
select(AgentRunLog).where(AgentRunLog.id == run_log_id)
|
||||||
|
)
|
||||||
|
managed = result.scalar_one_or_none()
|
||||||
|
if managed is None:
|
||||||
|
logger.warning("agent_runner: run_log %s not found for finalization", run_log_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
managed.status = status
|
||||||
|
managed.items_processed = items_processed
|
||||||
|
managed.items_created = items_created
|
||||||
|
managed.errors = errors or []
|
||||||
|
managed.completed_at = now
|
||||||
|
|
||||||
|
if update_config_last_run and config_id:
|
||||||
|
if config_type == "local":
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg:
|
||||||
|
cfg.last_run_at = now
|
||||||
|
elif config_type == "cloud":
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg:
|
||||||
|
cfg.last_run_at = now
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("agent_runner: failed to finalize run_log=%s: %s", run_log_id, exc)
|
||||||
1
services/batch-agent/app/agents/__init__.py
Normal file
1
services/batch-agent/app/agents/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Batch Agent Service domain agents and filesystem tools."""
|
||||||
83
services/batch-agent/app/agents/filesystem_agent.py
Normal file
83
services/batch-agent/app/agents/filesystem_agent.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""Filesystem agent — tools for reading local directories and files on Electron.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from app.ws_context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from shared.ws_context import execute_on_client
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_directory(path: str) -> str:
|
||||||
|
"""List files and folders in a local directory on the user's device.
|
||||||
|
|
||||||
|
Returns a formatted listing of entries with name, type (file/directory),
|
||||||
|
and full path.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="list_directory",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
entries: list[dict[str, Any]] = result.get("entries", [])
|
||||||
|
if not entries:
|
||||||
|
return f"Directory '{path}' is empty or does not exist."
|
||||||
|
lines: list[str] = []
|
||||||
|
for entry in entries:
|
||||||
|
entry_type = entry.get("type", "unknown")
|
||||||
|
entry_name = entry.get("name", "")
|
||||||
|
entry_path = entry.get("path", "")
|
||||||
|
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
||||||
|
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def read_file_content(path: str) -> str:
|
||||||
|
"""Read the text content of a local file on the user's device.
|
||||||
|
|
||||||
|
Returns the file content as a string. Large files may be truncated
|
||||||
|
by the Electron client.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="read_file_content",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
content: str = result.get("content", "")
|
||||||
|
if not content:
|
||||||
|
return f"File '{path}' is empty or could not be read."
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_file_metadata(path: str) -> str:
|
||||||
|
"""Get metadata for a local file: size, creation date, modification date, extension.
|
||||||
|
|
||||||
|
Returns a formatted summary of the file's metadata.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="get_file_metadata",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
size = result.get("size", "unknown")
|
||||||
|
created = result.get("createdAt", "unknown")
|
||||||
|
modified = result.get("modifiedAt", "unknown")
|
||||||
|
extension = result.get("extension", "unknown")
|
||||||
|
name = result.get("name", path)
|
||||||
|
return (
|
||||||
|
f"File: {name}\n"
|
||||||
|
f" Extension: {extension}\n"
|
||||||
|
f" Size: {size} bytes\n"
|
||||||
|
f" Created: {created}\n"
|
||||||
|
f" Modified: {modified}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
FILESYSTEM_TOOLS: list[Any] = [
|
||||||
|
list_directory,
|
||||||
|
read_file_content,
|
||||||
|
get_file_metadata,
|
||||||
|
]
|
||||||
108
services/batch-agent/app/integrations/__init__.py
Normal file
108
services/batch-agent/app/integrations/__init__.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""Cloud provider integration utilities.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from shared.config instead of app.config.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
* Shared message dataclasses (EmailMessage, ChatMessage)
|
||||||
|
* get_provider() — factory for Gmail/MS Graph clients
|
||||||
|
* encrypt_token() / decrypt_token() — Fernet-based OAuth token encryption
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.integrations.gmail import GmailClient
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmailMessage:
|
||||||
|
id: str
|
||||||
|
subject: str
|
||||||
|
sender: str
|
||||||
|
body_text: str
|
||||||
|
date: datetime
|
||||||
|
labels: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def as_text(self) -> str:
|
||||||
|
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||||
|
labels_str = f" [{', '.join(self.labels)}]" if self.labels else ""
|
||||||
|
return (
|
||||||
|
f"From: {self.sender}\n"
|
||||||
|
f"Date: {date_str}{labels_str}\n"
|
||||||
|
f"Subject: {self.subject}\n\n"
|
||||||
|
f"{self.body_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatMessage:
|
||||||
|
id: str
|
||||||
|
content: str
|
||||||
|
sender: str
|
||||||
|
channel: str | None
|
||||||
|
date: datetime
|
||||||
|
|
||||||
|
@property
|
||||||
|
def as_text(self) -> str:
|
||||||
|
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||||
|
channel_str = f" [channel: {self.channel}]" if self.channel else ""
|
||||||
|
return (
|
||||||
|
f"From: {self.sender}\n"
|
||||||
|
f"Date: {date_str}{channel_str}\n\n"
|
||||||
|
f"{self.content}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_fernet() -> Fernet:
|
||||||
|
key = settings.OAUTH_ENCRYPTION_KEY
|
||||||
|
if not key:
|
||||||
|
raise RuntimeError(
|
||||||
|
"OAUTH_ENCRYPTION_KEY is not set. "
|
||||||
|
"Generate one with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
|
||||||
|
)
|
||||||
|
return Fernet(key.encode() if isinstance(key, str) else key)
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt_token(token_info: dict) -> str:
|
||||||
|
if not isinstance(token_info, dict) or not token_info:
|
||||||
|
raise ValueError("token_info must be a non-empty dict")
|
||||||
|
plaintext = json.dumps(token_info).encode("utf-8")
|
||||||
|
return _get_fernet().encrypt(plaintext).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_token(encrypted: str) -> dict:
|
||||||
|
try:
|
||||||
|
plaintext = _get_fernet().decrypt(encrypted.encode("utf-8"))
|
||||||
|
return json.loads(plaintext)
|
||||||
|
except (InvalidToken, json.JSONDecodeError) as exc:
|
||||||
|
raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider(
|
||||||
|
provider: str,
|
||||||
|
credentials_info: dict,
|
||||||
|
) -> "GmailClient | MSGraphClient":
|
||||||
|
if provider == "gmail":
|
||||||
|
from app.integrations.gmail import GmailClient
|
||||||
|
return GmailClient(credentials_info)
|
||||||
|
if provider in {"outlook", "teams"}:
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
return MSGraphClient(credentials_info)
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown cloud provider {provider!r}. "
|
||||||
|
"Supported: 'gmail', 'outlook', 'teams'."
|
||||||
|
)
|
||||||
252
services/batch-agent/app/integrations/gmail.py
Normal file
252
services/batch-agent/app/integrations/gmail.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
"""Gmail API client for cloud agent integration.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from app.integrations instead of
|
||||||
|
app.integrations (same relative path within the service).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import email
|
||||||
|
import html
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.integrations import EmailMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_GMAIL_DATE_FMT = "%Y/%m/%d"
|
||||||
|
_BODY_TRUNCATE = 8_000
|
||||||
|
_MAX_MESSAGES = 200
|
||||||
|
|
||||||
|
|
||||||
|
def _build_gmail_query(
|
||||||
|
filter_config: dict[str, Any] | None,
|
||||||
|
since: datetime | None,
|
||||||
|
) -> str:
|
||||||
|
parts: list[str] = []
|
||||||
|
cfg = filter_config or {}
|
||||||
|
|
||||||
|
labels: list[str] = cfg.get("labels", [])
|
||||||
|
if labels:
|
||||||
|
if len(labels) == 1:
|
||||||
|
parts.append(f"label:{labels[0]}")
|
||||||
|
else:
|
||||||
|
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
|
||||||
|
parts.append(f"({label_expr})")
|
||||||
|
|
||||||
|
senders: list[str] = cfg.get("senders", [])
|
||||||
|
for sender in senders:
|
||||||
|
parts.append(f"from:{sender}")
|
||||||
|
|
||||||
|
date_range: dict = cfg.get("date_range", {})
|
||||||
|
from_str: str | None = date_range.get("from")
|
||||||
|
to_str: str | None = date_range.get("to")
|
||||||
|
|
||||||
|
effective_since: datetime | None = since
|
||||||
|
if from_str:
|
||||||
|
try:
|
||||||
|
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||||
|
if cfg_since.tzinfo is None:
|
||||||
|
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||||
|
if effective_since is None or cfg_since > effective_since:
|
||||||
|
effective_since = cfg_since
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("gmail: invalid date_range.from %r — ignoring", from_str)
|
||||||
|
|
||||||
|
if effective_since:
|
||||||
|
parts.append(f"after:{effective_since.strftime(_GMAIL_DATE_FMT)}")
|
||||||
|
|
||||||
|
if to_str:
|
||||||
|
try:
|
||||||
|
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||||
|
parts.append(f"before:{to_dt.strftime(_GMAIL_DATE_FMT)}")
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("gmail: invalid date_range.to %r — ignoring", to_str)
|
||||||
|
|
||||||
|
return " ".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_html(raw_html: str) -> str:
|
||||||
|
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
|
||||||
|
decoded = html.unescape(no_tags)
|
||||||
|
return re.sub(r"\s+", " ", decoded).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_body(payload: dict[str, Any]) -> str:
|
||||||
|
mime_type: str = payload.get("mimeType", "")
|
||||||
|
body: dict = payload.get("body", {})
|
||||||
|
parts: list[dict] = payload.get("parts", [])
|
||||||
|
|
||||||
|
if mime_type == "text/plain":
|
||||||
|
data = body.get("data", "")
|
||||||
|
if data:
|
||||||
|
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if mime_type == "text/html":
|
||||||
|
data = body.get("data", "")
|
||||||
|
if data:
|
||||||
|
raw = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||||
|
return _strip_html(raw)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
plain_fallback = ""
|
||||||
|
for part in parts:
|
||||||
|
part_mime = part.get("mimeType", "")
|
||||||
|
if part_mime == "text/plain":
|
||||||
|
return _parse_body(part)
|
||||||
|
if part_mime == "text/html" and not plain_fallback:
|
||||||
|
plain_fallback = _parse_body(part)
|
||||||
|
if part_mime.startswith("multipart/"):
|
||||||
|
nested = _parse_body(part)
|
||||||
|
if nested:
|
||||||
|
return nested
|
||||||
|
return plain_fallback
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_date(raw: str) -> datetime:
|
||||||
|
try:
|
||||||
|
parsed = email.utils.parsedate_to_datetime(raw)
|
||||||
|
if parsed.tzinfo is None:
|
||||||
|
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||||
|
return parsed.astimezone(timezone.utc)
|
||||||
|
except Exception:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
class GmailClient:
|
||||||
|
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||||
|
from google.oauth2.credentials import Credentials
|
||||||
|
|
||||||
|
self._credentials_info = credentials_info
|
||||||
|
expiry_str: str | None = credentials_info.get("expiry")
|
||||||
|
expiry: datetime | None = None
|
||||||
|
if expiry_str:
|
||||||
|
try:
|
||||||
|
expiry = datetime.fromisoformat(
|
||||||
|
expiry_str.replace("Z", "+00:00")
|
||||||
|
).replace(tzinfo=timezone.utc)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._credentials = Credentials(
|
||||||
|
token=credentials_info.get("token"),
|
||||||
|
refresh_token=credentials_info.get("refresh_token"),
|
||||||
|
token_uri=credentials_info.get("token_uri", "https://oauth2.googleapis.com/token"),
|
||||||
|
client_id=credentials_info.get("client_id"),
|
||||||
|
client_secret=credentials_info.get("client_secret"),
|
||||||
|
scopes=credentials_info.get("scopes"),
|
||||||
|
expiry=expiry,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def fetch_messages(
|
||||||
|
self,
|
||||||
|
filter_config: dict[str, Any] | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> list[EmailMessage]:
|
||||||
|
query = _build_gmail_query(filter_config, since)
|
||||||
|
logger.debug("gmail: executing search query %r", query)
|
||||||
|
return await asyncio.to_thread(self._fetch_sync, query)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||||
|
creds = self._credentials
|
||||||
|
if not creds.valid and creds.expired:
|
||||||
|
return None
|
||||||
|
if creds.token != self._credentials_info.get("token"):
|
||||||
|
result = {
|
||||||
|
"token": creds.token,
|
||||||
|
"refresh_token": creds.refresh_token,
|
||||||
|
"token_uri": creds.token_uri,
|
||||||
|
"client_id": creds.client_id,
|
||||||
|
"client_secret": creds.client_secret,
|
||||||
|
"scopes": list(creds.scopes or []),
|
||||||
|
}
|
||||||
|
if creds.expiry:
|
||||||
|
result["expiry"] = creds.expiry.isoformat()
|
||||||
|
return result
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _fetch_sync(self, query: str) -> list[EmailMessage]:
|
||||||
|
import googleapiclient.discovery
|
||||||
|
import googleapiclient.errors
|
||||||
|
from google.auth.transport.requests import Request
|
||||||
|
|
||||||
|
if self._credentials.expired and self._credentials.refresh_token:
|
||||||
|
try:
|
||||||
|
self._credentials.refresh(Request())
|
||||||
|
except Exception as exc:
|
||||||
|
raise RuntimeError(f"Gmail token refresh failed: {exc}") from exc
|
||||||
|
|
||||||
|
service = googleapiclient.discovery.build(
|
||||||
|
"gmail", "v1", credentials=self._credentials, cache_discovery=False
|
||||||
|
)
|
||||||
|
user_api = service.users()
|
||||||
|
|
||||||
|
ids: list[str] = []
|
||||||
|
page_token: str | None = None
|
||||||
|
while len(ids) < _MAX_MESSAGES:
|
||||||
|
batch_size = min(100, _MAX_MESSAGES - len(ids))
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"userId": "me",
|
||||||
|
"maxResults": batch_size,
|
||||||
|
}
|
||||||
|
if query:
|
||||||
|
kwargs["q"] = query
|
||||||
|
if page_token:
|
||||||
|
kwargs["pageToken"] = page_token
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = user_api.messages().list(**kwargs).execute()
|
||||||
|
except googleapiclient.errors.HttpError as exc:
|
||||||
|
raise RuntimeError(f"Gmail messages.list failed: {exc}") from exc
|
||||||
|
|
||||||
|
for msg in resp.get("messages", []):
|
||||||
|
ids.append(msg["id"])
|
||||||
|
|
||||||
|
page_token = resp.get("nextPageToken")
|
||||||
|
if not page_token:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info("gmail: fetching %d message(s)", len(ids))
|
||||||
|
|
||||||
|
messages: list[EmailMessage] = []
|
||||||
|
for msg_id in ids:
|
||||||
|
try:
|
||||||
|
msg = user_api.messages().get(
|
||||||
|
userId="me", id=msg_id, format="full"
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
headers: dict[str, str] = {
|
||||||
|
h["name"].lower(): h["value"]
|
||||||
|
for h in msg.get("payload", {}).get("headers", [])
|
||||||
|
}
|
||||||
|
subject = headers.get("subject", "(no subject)")
|
||||||
|
sender = headers.get("from", "unknown")
|
||||||
|
date_raw = headers.get("date", "")
|
||||||
|
date = _parse_date(date_raw) if date_raw else datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
body_text = _parse_body(msg.get("payload", {}))[:_BODY_TRUNCATE]
|
||||||
|
labels = msg.get("labelIds", [])
|
||||||
|
|
||||||
|
messages.append(EmailMessage(
|
||||||
|
id=msg_id,
|
||||||
|
subject=subject,
|
||||||
|
sender=sender,
|
||||||
|
body_text=body_text,
|
||||||
|
date=date,
|
||||||
|
labels=labels,
|
||||||
|
))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("gmail: skipping message %s: %s", msg_id, exc)
|
||||||
|
|
||||||
|
logger.info("gmail: returned %d message(s)", len(messages))
|
||||||
|
return messages
|
||||||
266
services/batch-agent/app/integrations/ms_graph.py
Normal file
266
services/batch-agent/app/integrations/ms_graph.py
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
"""Microsoft Graph API client for Outlook and Teams.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import settings from shared.config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from app.integrations import ChatMessage, EmailMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"
|
||||||
|
|
||||||
|
_MAX_EMAILS = 200
|
||||||
|
_MAX_MESSAGES = 200
|
||||||
|
_BODY_TRUNCATE = 8_000
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_html(raw: str) -> str:
|
||||||
|
no_tags = re.sub(r"<[^>]+>", " ", raw)
|
||||||
|
import html as _html
|
||||||
|
decoded = _html.unescape(no_tags)
|
||||||
|
return re.sub(r"\s+", " ", decoded).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _odata_datetime(dt: datetime) -> str:
|
||||||
|
utc = dt.astimezone(timezone.utc)
|
||||||
|
return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_email_filter(
|
||||||
|
filter_config: dict[str, Any] | None,
|
||||||
|
since: datetime | None,
|
||||||
|
) -> str:
|
||||||
|
clauses: list[str] = []
|
||||||
|
cfg = filter_config or {}
|
||||||
|
|
||||||
|
senders: list[str] = cfg.get("senders", [])
|
||||||
|
if senders:
|
||||||
|
sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders]
|
||||||
|
clauses.append("(" + " or ".join(sender_clauses) + ")")
|
||||||
|
|
||||||
|
date_range: dict = cfg.get("date_range", {})
|
||||||
|
from_str: str | None = date_range.get("from")
|
||||||
|
|
||||||
|
effective_since: datetime | None = since
|
||||||
|
if from_str:
|
||||||
|
try:
|
||||||
|
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||||
|
if cfg_since.tzinfo is None:
|
||||||
|
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||||
|
if effective_since is None or cfg_since > effective_since:
|
||||||
|
effective_since = cfg_since
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("ms_graph: invalid date_range.from %r — ignoring", from_str)
|
||||||
|
|
||||||
|
if effective_since:
|
||||||
|
clauses.append(f"receivedDateTime ge {_odata_datetime(effective_since)}")
|
||||||
|
|
||||||
|
to_str: str | None = date_range.get("to")
|
||||||
|
if to_str:
|
||||||
|
try:
|
||||||
|
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||||
|
if to_dt.tzinfo is None:
|
||||||
|
to_dt = to_dt.replace(tzinfo=timezone.utc)
|
||||||
|
clauses.append(f"receivedDateTime le {_odata_datetime(to_dt)}")
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("ms_graph: invalid date_range.to %r — ignoring", to_str)
|
||||||
|
|
||||||
|
return " and ".join(clauses)
|
||||||
|
|
||||||
|
|
||||||
|
class MSGraphClient:
|
||||||
|
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||||
|
self._credentials_info = credentials_info
|
||||||
|
self._access_token: str = credentials_info.get("access_token", "")
|
||||||
|
self._original_access_token: str = self._access_token
|
||||||
|
self._refresh_token: str | None = credentials_info.get("refresh_token")
|
||||||
|
|
||||||
|
def _auth_headers(self) -> dict[str, str]:
|
||||||
|
return {"Authorization": f"Bearer {self._access_token}"}
|
||||||
|
|
||||||
|
async def _refresh_access_token(self) -> None:
|
||||||
|
import msal
|
||||||
|
|
||||||
|
app = msal.ConfidentialClientApplication(
|
||||||
|
client_id=settings.MS_CLIENT_ID,
|
||||||
|
client_credential=settings.MS_CLIENT_SECRET,
|
||||||
|
authority=f"https://login.microsoftonline.com/{settings.MS_TENANT_ID}",
|
||||||
|
)
|
||||||
|
scopes: list[str] = self._credentials_info.get("scope", "").split()
|
||||||
|
if not scopes:
|
||||||
|
scopes = ["https://graph.microsoft.com/.default"]
|
||||||
|
|
||||||
|
result = app.acquire_token_by_refresh_token(
|
||||||
|
self._refresh_token,
|
||||||
|
scopes=scopes,
|
||||||
|
)
|
||||||
|
if "access_token" not in result:
|
||||||
|
error = result.get("error_description", result.get("error", "unknown"))
|
||||||
|
raise RuntimeError(f"MS Graph token refresh failed: {error}")
|
||||||
|
|
||||||
|
self._access_token = result["access_token"]
|
||||||
|
if "refresh_token" in result:
|
||||||
|
self._refresh_token = result["refresh_token"]
|
||||||
|
self._credentials_info["refresh_token"] = result["refresh_token"]
|
||||||
|
self._credentials_info["access_token"] = self._access_token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||||
|
if self._access_token != self._original_access_token:
|
||||||
|
return {**self._credentials_info, "access_token": self._access_token}
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _get(
|
||||||
|
self,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
url: str,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
*,
|
||||||
|
retry_on_401: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||||
|
if resp.status_code == 401 and retry_on_401 and self._refresh_token:
|
||||||
|
await self._refresh_access_token()
|
||||||
|
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||||
|
if resp.status_code == 429:
|
||||||
|
raise RuntimeError("MS Graph rate limit hit (429). Try again later.")
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
async def fetch_emails(
|
||||||
|
self,
|
||||||
|
filter_config: dict[str, Any] | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> list[EmailMessage]:
|
||||||
|
odata_filter = _build_email_filter(filter_config, since)
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"$top": 50,
|
||||||
|
"$select": "id,subject,from,receivedDateTime,body,bodyPreview",
|
||||||
|
"$orderby": "receivedDateTime desc",
|
||||||
|
}
|
||||||
|
if odata_filter:
|
||||||
|
params["$filter"] = odata_filter
|
||||||
|
|
||||||
|
emails: list[EmailMessage] = []
|
||||||
|
url = f"{_GRAPH_BASE}/me/messages"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
while url and len(emails) < _MAX_EMAILS:
|
||||||
|
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
||||||
|
for item in data.get("value", []):
|
||||||
|
emails.append(self._parse_email(item))
|
||||||
|
if len(emails) >= _MAX_EMAILS:
|
||||||
|
break
|
||||||
|
url = data.get("@odata.nextLink", "")
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
|
||||||
|
return emails
|
||||||
|
|
||||||
|
async def fetch_messages(
|
||||||
|
self,
|
||||||
|
filter_config: dict[str, Any] | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> list[ChatMessage]:
|
||||||
|
cfg = filter_config or {}
|
||||||
|
channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])]
|
||||||
|
params: dict[str, Any] = {"$top": 50}
|
||||||
|
if since:
|
||||||
|
params["$filter"] = f"createdDateTime ge {_odata_datetime(since)}"
|
||||||
|
|
||||||
|
messages: list[ChatMessage] = []
|
||||||
|
url = f"{_GRAPH_BASE}/me/chats/getAllMessages"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
while url and len(messages) < _MAX_MESSAGES:
|
||||||
|
try:
|
||||||
|
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
if exc.response.status_code in (403, 404):
|
||||||
|
logger.warning(
|
||||||
|
"ms_graph: /me/chats/getAllMessages not available (%d)",
|
||||||
|
exc.response.status_code,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
raise
|
||||||
|
|
||||||
|
for item in data.get("value", []):
|
||||||
|
msg = self._parse_teams_message(item)
|
||||||
|
if channel_filter and msg.channel:
|
||||||
|
if not any(c in msg.channel.lower() for c in channel_filter):
|
||||||
|
continue
|
||||||
|
messages.append(msg)
|
||||||
|
if len(messages) >= _MAX_MESSAGES:
|
||||||
|
break
|
||||||
|
url = data.get("@odata.nextLink", "")
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
|
||||||
|
return messages
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_email(item: dict[str, Any]) -> EmailMessage:
|
||||||
|
subject: str = item.get("subject", "(no subject)") or "(no subject)"
|
||||||
|
sender_block = item.get("from", {}) or {}
|
||||||
|
sender_addr = (
|
||||||
|
(sender_block.get("emailAddress") or {}).get("address", "unknown")
|
||||||
|
)
|
||||||
|
date_str: str = item.get("receivedDateTime", "")
|
||||||
|
try:
|
||||||
|
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||||
|
except Exception:
|
||||||
|
date = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
body_block = item.get("body", {}) or {}
|
||||||
|
content_type: str = body_block.get("contentType", "text")
|
||||||
|
raw_body: str = body_block.get("content", "")
|
||||||
|
if content_type == "html":
|
||||||
|
body_text = _strip_html(raw_body)
|
||||||
|
else:
|
||||||
|
body_text = raw_body or item.get("bodyPreview", "")
|
||||||
|
body_text = body_text[:_BODY_TRUNCATE]
|
||||||
|
|
||||||
|
return EmailMessage(
|
||||||
|
id=item.get("id", ""),
|
||||||
|
subject=subject,
|
||||||
|
sender=sender_addr,
|
||||||
|
body_text=body_text,
|
||||||
|
date=date,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_teams_message(item: dict[str, Any]) -> ChatMessage:
|
||||||
|
msg_id: str = item.get("id", "")
|
||||||
|
sender_block = (item.get("from") or {}).get("user") or {}
|
||||||
|
sender: str = sender_block.get("displayName", "unknown")
|
||||||
|
channel: str | None = (item.get("channelIdentity") or {}).get("channelId")
|
||||||
|
|
||||||
|
date_str: str = item.get("createdDateTime", "")
|
||||||
|
try:
|
||||||
|
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||||
|
except Exception:
|
||||||
|
date = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
body_block = item.get("body", {}) or {}
|
||||||
|
content_type: str = body_block.get("contentType", "text")
|
||||||
|
raw_content: str = body_block.get("content", "")
|
||||||
|
content = _strip_html(raw_content) if content_type == "html" else raw_content
|
||||||
|
content = content[:_BODY_TRUNCATE]
|
||||||
|
|
||||||
|
return ChatMessage(
|
||||||
|
id=msg_id,
|
||||||
|
content=content,
|
||||||
|
sender=sender,
|
||||||
|
channel=channel,
|
||||||
|
date=date,
|
||||||
|
)
|
||||||
395
services/batch-agent/app/journey.py
Normal file
395
services/batch-agent/app/journey.py
Normal file
@@ -0,0 +1,395 @@
|
|||||||
|
"""Chatbot Journey — guided conversation to build an agent prompt_template.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: imports from app.agents.filesystem_agent
|
||||||
|
and app.llm instead of monolith paths. Session state is in-memory (could
|
||||||
|
be moved to Redis for horizontal scaling in the future).
|
||||||
|
|
||||||
|
Journey flow:
|
||||||
|
1. Redis consumer dispatches ``journey_start`` with basic agent config.
|
||||||
|
2. Server creates an in-memory session, runs the setup LLM with
|
||||||
|
file-system tools to explore the directory, returns first question.
|
||||||
|
3. ``journey_message`` frames drive the conversation.
|
||||||
|
4. After 3-5 turns the LLM emits PROMPT_TEMPLATE_START / _END block.
|
||||||
|
5. Server parses the block and returns ``journey_reply`` with ``done=True``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
|
|
||||||
|
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
||||||
|
from shared.llm import get_llm
|
||||||
|
import app.tracing as tracing
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Session TTL ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||||
|
|
||||||
|
# Sentinel strings used to delimit the LLM-produced prompt_template.
|
||||||
|
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
||||||
|
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
||||||
|
|
||||||
|
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
||||||
|
_MAX_TURNS: int = 15
|
||||||
|
_MAX_TOOL_STEPS: int = 6
|
||||||
|
|
||||||
|
# ── In-memory session store ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class JourneySession:
|
||||||
|
session_id: str
|
||||||
|
user_id: str
|
||||||
|
agent_type: str # "local" | "cloud"
|
||||||
|
directory: str
|
||||||
|
data_types: list[str]
|
||||||
|
history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
system_prompt: str = ""
|
||||||
|
created_at: float = field(default_factory=time.monotonic)
|
||||||
|
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS
|
||||||
|
|
||||||
|
|
||||||
|
# session_id → session
|
||||||
|
_sessions: dict[str, JourneySession] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
||||||
|
"""Retrieve session; return None on missing, expired, or wrong owner."""
|
||||||
|
s = _sessions.get(session_id)
|
||||||
|
if s is None or s.is_expired():
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
return None
|
||||||
|
if s.user_id != user_id:
|
||||||
|
return None
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
# ── System prompt builder ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT_TEMPLATE = """\
|
||||||
|
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||||
|
Your job is to understand exactly what data the user wants to extract from their
|
||||||
|
local directory and produce a concise prompt_template that a separate AI will use
|
||||||
|
as its instruction set.
|
||||||
|
|
||||||
|
You have access to file-system tools to explore the user's directory:
|
||||||
|
- list_directory: to see folder structure
|
||||||
|
- read_file_content: to peek at file contents
|
||||||
|
- get_file_metadata: to check file info
|
||||||
|
|
||||||
|
The user's configured directory is: {directory}
|
||||||
|
Target data types: {data_types}
|
||||||
|
|
||||||
|
IMPORTANT — project assignment is handled automatically. You MUST NOT ask the user
|
||||||
|
about projects, projectId, or how to link records to projects. Never include
|
||||||
|
projectId logic or project creation instructions in the generated prompt_template.
|
||||||
|
|
||||||
|
Start by exploring the directory to understand its structure. Then ask concise,
|
||||||
|
focused questions one at a time. Cover only the topics relevant to the target
|
||||||
|
data types listed above:
|
||||||
|
|
||||||
|
1. Content type and format — confirmed by your exploration.
|
||||||
|
2. For TASKS (if in scope): field mapping for title, status, priority, content,
|
||||||
|
dueDate (where is the date found? what's the fallback when absent?),
|
||||||
|
and assignee (is there a person name to assign?).
|
||||||
|
3. For NOTES when TASKS are also in scope: note vs task distinction —
|
||||||
|
what makes something a note rather than a task?
|
||||||
|
4. For TIMELINES (if in scope): the date source — what marks a milestone or event?
|
||||||
|
5. Exclusions and special handling applicable to the target data types.
|
||||||
|
|
||||||
|
Keep asking focused questions until you are at least 90% confident. Then stop and
|
||||||
|
output the final prompt_template immediately, wrapped between these exact markers
|
||||||
|
on their own lines:
|
||||||
|
|
||||||
|
{template_start}
|
||||||
|
<the complete extraction prompt here>
|
||||||
|
{template_end}
|
||||||
|
|
||||||
|
The prompt_template must be concise (bullet points, ~15–25 lines maximum).
|
||||||
|
Specify only:
|
||||||
|
- Scope: what files/content qualify and what entity types to create.
|
||||||
|
- Field mapping rules per entity type (camelCase fields: title, status, priority,
|
||||||
|
dueDate, content, assignee, etc.).
|
||||||
|
- dueDate rule (if tasks in scope): source and fallback behaviour.
|
||||||
|
- Note vs task rule (if both in scope): the criterion that separates them.
|
||||||
|
- Timeline date rule (if timelines in scope): what constitutes a timeline event.
|
||||||
|
- Exclusion/filtering rules.
|
||||||
|
- 2–3 concrete mapping examples based on what you discovered.
|
||||||
|
|
||||||
|
{existing_section}Begin by exploring the directory, then ask your first question.\
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _build_system_prompt(
|
||||||
|
directory: str,
|
||||||
|
data_types: list[str],
|
||||||
|
existing_template: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
existing_section = (
|
||||||
|
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
||||||
|
f"---\n{existing_template}\n---\n"
|
||||||
|
if existing_template
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
# Use Langfuse compile_prompt ({{variable}} syntax) with Python .format() fallback
|
||||||
|
return tracing.compile_prompt(
|
||||||
|
"journey_system",
|
||||||
|
fallback=_SYSTEM_PROMPT_TEMPLATE,
|
||||||
|
variables={
|
||||||
|
"directory": directory,
|
||||||
|
"data_types": ", ".join(data_types),
|
||||||
|
"existing_section": existing_section,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Template extraction ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_template(text: str) -> str | None:
|
||||||
|
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
|
||||||
|
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
|
||||||
|
return None
|
||||||
|
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
|
||||||
|
end_idx = text.index(_TEMPLATE_END)
|
||||||
|
return text[start_idx:end_idx].strip() or None
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM call with tool support ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _as_text(content: Any) -> str:
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_llm_with_tools(
|
||||||
|
system_prompt: str,
|
||||||
|
history: list[dict[str, Any]],
|
||||||
|
tools: list[Any],
|
||||||
|
langfuse_handler: Any | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Build LangChain messages from history and invoke the LLM with tools.
|
||||||
|
|
||||||
|
Handles tool-calling loops: if the LLM calls tools, execute them and
|
||||||
|
continue until a final text response is produced.
|
||||||
|
"""
|
||||||
|
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||||
|
for turn in history:
|
||||||
|
if turn["role"] == "user":
|
||||||
|
messages.append(HumanMessage(content=turn["content"]))
|
||||||
|
else:
|
||||||
|
messages.append(AIMessage(content=turn["content"]))
|
||||||
|
|
||||||
|
callbacks = [langfuse_handler] if langfuse_handler else None
|
||||||
|
llm = get_llm(model=None, temperature=0.4, callbacks=callbacks)
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
|
for _ in range(_MAX_TOOL_STEPS):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
return _as_text(response.content)
|
||||||
|
|
||||||
|
for call in response.tool_calls:
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"journey: tool_call name=%s args=%s",
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:500],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"journey: tool_result name=%s output=%s",
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:800],
|
||||||
|
)
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
# Fallback: exceeded max tool steps.
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
return _as_text(final.content)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Journey handlers (called from redis_consumer) ────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_journey_start(
|
||||||
|
user_id: str,
|
||||||
|
frame: dict[str, Any],
|
||||||
|
*,
|
||||||
|
langfuse_handler: Any | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Handle a ``journey_start`` request.
|
||||||
|
|
||||||
|
Creates a session, runs the setup LLM with directory exploration,
|
||||||
|
and returns the ``journey_reply`` payload.
|
||||||
|
"""
|
||||||
|
agent_type = frame.get("agent_type", "local")
|
||||||
|
directory = frame.get("directory", "")
|
||||||
|
data_types = frame.get("data_types", [])
|
||||||
|
existing_template = frame.get("existing_template")
|
||||||
|
|
||||||
|
session_id = frame.get("session_id") or str(uuid.uuid4())
|
||||||
|
system_prompt = _build_system_prompt(directory, data_types, existing_template)
|
||||||
|
|
||||||
|
session = JourneySession(
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
agent_type=agent_type,
|
||||||
|
directory=directory,
|
||||||
|
data_types=data_types,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
seed_history: list[dict[str, Any]] = [
|
||||||
|
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
||||||
|
]
|
||||||
|
ai_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history=seed_history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.history.extend(seed_history)
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
_sessions[session_id] = session
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"journey: session %s started for user %s (directory=%s)",
|
||||||
|
session_id,
|
||||||
|
user_id,
|
||||||
|
directory,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_template = _extract_template(ai_reply)
|
||||||
|
done = prompt_template is not None
|
||||||
|
|
||||||
|
display_message = ai_reply
|
||||||
|
if done:
|
||||||
|
display_message = (
|
||||||
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
|
or "Here is your agent configuration. You can save it or continue refining."
|
||||||
|
)
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": display_message,
|
||||||
|
"done": done,
|
||||||
|
"prompt_template": prompt_template,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_journey_message(
|
||||||
|
user_id: str,
|
||||||
|
frame: dict[str, Any],
|
||||||
|
*,
|
||||||
|
langfuse_handler: Any | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Handle a ``journey_message`` request.
|
||||||
|
|
||||||
|
Appends the user message, calls the LLM, and returns the
|
||||||
|
``journey_reply`` payload.
|
||||||
|
"""
|
||||||
|
session_id = frame.get("session_id", "")
|
||||||
|
message = frame.get("message", "")
|
||||||
|
|
||||||
|
session = get_journey_session(session_id, user_id)
|
||||||
|
if session is None:
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": "Journey session not found or expired. Please start a new setup.",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
session.history.append({"role": "user", "content": message})
|
||||||
|
|
||||||
|
ai_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=session.system_prompt,
|
||||||
|
history=session.history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
|
||||||
|
prompt_template = _extract_template(ai_reply)
|
||||||
|
done = prompt_template is not None
|
||||||
|
|
||||||
|
if not done:
|
||||||
|
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||||
|
if turns >= _MAX_TURNS:
|
||||||
|
nudge_content = (
|
||||||
|
"[System: You have enough information. Please generate the final "
|
||||||
|
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
||||||
|
)
|
||||||
|
session.history.append({"role": "user", "content": nudge_content})
|
||||||
|
|
||||||
|
nudge_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=session.system_prompt,
|
||||||
|
history=session.history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
)
|
||||||
|
session.history.append({"role": "assistant", "content": nudge_reply})
|
||||||
|
|
||||||
|
prompt_template = _extract_template(nudge_reply)
|
||||||
|
if prompt_template is not None:
|
||||||
|
done = True
|
||||||
|
ai_reply = nudge_reply
|
||||||
|
|
||||||
|
display_message = ai_reply
|
||||||
|
if done:
|
||||||
|
display_message = (
|
||||||
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
|
if _TEMPLATE_START in ai_reply
|
||||||
|
else "Here is your agent configuration. You can save it or continue refining."
|
||||||
|
)
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
logger.info("journey: session %s completed for user %s", session_id, user_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": display_message,
|
||||||
|
"done": done,
|
||||||
|
"prompt_template": prompt_template,
|
||||||
|
}
|
||||||
76
services/batch-agent/app/llm.py
Normal file
76
services/batch-agent/app/llm.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||||
|
|
||||||
|
Identical to services/chat/app/llm.py. Uses shared.config.settings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain_litellm import ChatLiteLLM
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
litellm.drop_params = True
|
||||||
|
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
||||||
|
category=UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
|
if model.startswith("anthropic/"):
|
||||||
|
return settings.ANTHROPIC_API_KEY or None
|
||||||
|
if model.startswith("gemini/") or model.startswith("google/"):
|
||||||
|
return settings.GOOGLE_API_KEY or None
|
||||||
|
if model.startswith("cerebras/"):
|
||||||
|
return settings.CEREBRAS_API_KEY or None
|
||||||
|
if model.startswith("github/"):
|
||||||
|
return settings.GITHUB_TOKEN or None
|
||||||
|
if model.startswith("github_copilot/"):
|
||||||
|
return None
|
||||||
|
return settings.OPENAI_API_KEY or None
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm(
|
||||||
|
*,
|
||||||
|
model: str | None = None,
|
||||||
|
temperature: float = 0,
|
||||||
|
callbacks: list | None = None,
|
||||||
|
) -> ChatOpenAI | ChatLiteLLM:
|
||||||
|
model = model or settings.LLM_MODEL
|
||||||
|
|
||||||
|
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
||||||
|
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
||||||
|
|
||||||
|
if settings.GITHUB_TOKEN:
|
||||||
|
os.environ.setdefault("GITHUB_TOKEN", settings.GITHUB_TOKEN)
|
||||||
|
|
||||||
|
if "/" in model:
|
||||||
|
return ChatLiteLLM(model=model, temperature=temperature, callbacks=callbacks)
|
||||||
|
|
||||||
|
return ChatOpenAI(
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
api_key=_api_key_for_model(model),
|
||||||
|
callbacks=callbacks,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def embed(text: str) -> list[float]:
|
||||||
|
model = settings.LLM_EMBED_MODEL
|
||||||
|
|
||||||
|
if model.startswith("github_copilot/") or "/" in model:
|
||||||
|
response = await litellm.aembedding(model=model, input=[text])
|
||||||
|
return response.data[0]["embedding"]
|
||||||
|
|
||||||
|
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
||||||
|
response = await client.embeddings.create(model=model, input=text)
|
||||||
|
return response.data[0].embedding
|
||||||
79
services/batch-agent/app/main.py
Normal file
79
services/batch-agent/app/main.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""Batch Agent Service — FastAPI application.
|
||||||
|
|
||||||
|
Owns: agent_runner (local directory + cloud connectors), journey builder,
|
||||||
|
filesystem_agent, integrations (Gmail, MS Graph).
|
||||||
|
|
||||||
|
Communicates with WS Gateway via Redis:
|
||||||
|
- Subscribes to batch:request:{user_id} (journey_start, journey_message)
|
||||||
|
- Publishes to ws:out:{user_id} (journey replies + tool calls)
|
||||||
|
- BRPOP on tool:result:{call_id} (tool-call round-trip, 30s timeout)
|
||||||
|
- SET+EX on journey:{user_id} (journey session state, TTL 1800s)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Ensure the repo root is on sys.path so ``shared`` is importable when
|
||||||
|
# running locally (in Docker the COPY already places it at /app/shared/).
|
||||||
|
_repo_root = str(Path(__file__).resolve().parents[3])
|
||||||
|
if _repo_root not in sys.path:
|
||||||
|
sys.path.insert(0, _repo_root)
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.redis_consumer import start_consumer
|
||||||
|
from app.routes import router
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
|
# Initialise Langfuse tracing (no-op if keys are missing)
|
||||||
|
from app.tracing import init_langfuse
|
||||||
|
init_langfuse()
|
||||||
|
|
||||||
|
logger.info("batch-agent: starting Redis consumer")
|
||||||
|
task = asyncio.create_task(start_consumer())
|
||||||
|
yield
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
from app.tracing import shutdown as shutdown_langfuse
|
||||||
|
shutdown_langfuse()
|
||||||
|
|
||||||
|
from shared.db import engine
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
from shared.redis import redis_client
|
||||||
|
await redis_client.aclose()
|
||||||
|
|
||||||
|
logger.info("batch-agent: Redis consumer stopped")
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(title="Adiuva Batch Agent Service", lifespan=lifespan)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_methods=["GET", "POST"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health() -> dict[str, str]:
|
||||||
|
return {"status": "ok", "service": "batch-agent"}
|
||||||
183
services/batch-agent/app/redis_consumer.py
Normal file
183
services/batch-agent/app/redis_consumer.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""Redis consumer for the Batch Agent Service.
|
||||||
|
|
||||||
|
Subscribes to batch:request:* (pattern) and dispatches:
|
||||||
|
- journey_start → handle_journey_start
|
||||||
|
- journey_message → handle_journey_message
|
||||||
|
- agent_trigger → run_local_agent / run_cloud_agent
|
||||||
|
|
||||||
|
Results are published back to ws:out:{user_id} via Redis.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from shared.redis import redis_client, batch_request_channel, ws_out_channel
|
||||||
|
|
||||||
|
import app.tracing as tracing
|
||||||
|
from shared.ws_context import set_current_user, clear_current_user
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def _publish_to_user(user_id: str, payload: dict[str, Any]) -> None:
|
||||||
|
"""Publish a frame to the user's WS outbound channel."""
|
||||||
|
channel = ws_out_channel(user_id)
|
||||||
|
await redis_client.publish(channel, json.dumps(payload))
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_journey_start(user_id: str, data: dict[str, Any]) -> None:
|
||||||
|
"""Handle a journey_start request from WS Gateway."""
|
||||||
|
from app.journey import handle_journey_start
|
||||||
|
|
||||||
|
session_id = data.get("session_id", "")
|
||||||
|
set_current_user(user_id)
|
||||||
|
try:
|
||||||
|
with tracing.trace_span(
|
||||||
|
name="journey_start",
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
input=data.get("directory", ""),
|
||||||
|
metadata={"data_types": data.get("data_types", [])},
|
||||||
|
tags=["journey"],
|
||||||
|
) as span:
|
||||||
|
langfuse_handler = tracing.get_langfuse_callback()
|
||||||
|
reply = await handle_journey_start(user_id, data, langfuse_handler=langfuse_handler)
|
||||||
|
tracing.link_prompt_to_trace(span, "journey_system")
|
||||||
|
span.update(output=reply.get("message", "")[:500])
|
||||||
|
await _publish_to_user(user_id, reply)
|
||||||
|
tracing.flush()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("batch-agent: journey_start failed user=%s: %s", user_id, exc)
|
||||||
|
await _publish_to_user(user_id, {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": f"Journey setup failed: {exc}",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_journey_message(user_id: str, data: dict[str, Any]) -> None:
|
||||||
|
"""Handle a journey_message from WS Gateway."""
|
||||||
|
from app.journey import handle_journey_message
|
||||||
|
|
||||||
|
session_id = data.get("session_id", "")
|
||||||
|
set_current_user(user_id)
|
||||||
|
try:
|
||||||
|
with tracing.trace_span(
|
||||||
|
name="journey_message",
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
input=data.get("message", "")[:200],
|
||||||
|
tags=["journey"],
|
||||||
|
) as span:
|
||||||
|
langfuse_handler = tracing.get_langfuse_callback()
|
||||||
|
reply = await handle_journey_message(user_id, data, langfuse_handler=langfuse_handler)
|
||||||
|
tracing.link_prompt_to_trace(span, "journey_system")
|
||||||
|
span.update(output=reply.get("message", "")[:500])
|
||||||
|
await _publish_to_user(user_id, reply)
|
||||||
|
tracing.flush()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("batch-agent: journey_message failed user=%s: %s", user_id, exc)
|
||||||
|
await _publish_to_user(user_id, {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": f"Journey processing failed: {exc}",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_agent_trigger(user_id: str, data: dict[str, Any]) -> None:
|
||||||
|
"""Handle an agent_trigger request from the REST route (forwarded via Redis)."""
|
||||||
|
from app.agent_runner import run_local_agent
|
||||||
|
|
||||||
|
run_context = data.get("run_context", {})
|
||||||
|
agent_id = run_context.get("agent_id", "")
|
||||||
|
set_current_user(user_id)
|
||||||
|
try:
|
||||||
|
with tracing.trace_span(
|
||||||
|
name="agent_trigger",
|
||||||
|
user_id=user_id,
|
||||||
|
trace_id=run_context.get("run_id"),
|
||||||
|
input={"agent_id": agent_id, "directory": data.get("directory", "")},
|
||||||
|
metadata={"data_types": data.get("data_types", [])},
|
||||||
|
tags=["batch", "agent_run"],
|
||||||
|
) as span:
|
||||||
|
langfuse_handler = tracing.get_langfuse_callback()
|
||||||
|
await run_local_agent(user_id, data, langfuse_handler=langfuse_handler)
|
||||||
|
tracing.link_prompt_to_trace(span, "batch_processing")
|
||||||
|
span.update(output={"status": "completed"})
|
||||||
|
tracing.flush()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("batch-agent: agent_trigger failed user=%s: %s", user_id, exc)
|
||||||
|
await _publish_to_user(user_id, {
|
||||||
|
"type": "run_complete",
|
||||||
|
"status": "error",
|
||||||
|
"run_context": run_context,
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
|
||||||
|
async def _dispatch(user_id: str, message_data: dict[str, Any]) -> None:
|
||||||
|
"""Route a batch request to the correct handler."""
|
||||||
|
msg_type = message_data.get("type", "")
|
||||||
|
|
||||||
|
if msg_type == "journey_start":
|
||||||
|
await _handle_journey_start(user_id, message_data)
|
||||||
|
elif msg_type == "journey_message":
|
||||||
|
await _handle_journey_message(user_id, message_data)
|
||||||
|
elif msg_type == "agent_trigger":
|
||||||
|
await _handle_agent_trigger(user_id, message_data)
|
||||||
|
elif msg_type == "device_online":
|
||||||
|
logger.info("batch-agent: device_online user=%s device=%s", user_id, message_data.get("device_id", "?"))
|
||||||
|
else:
|
||||||
|
logger.warning("batch-agent: unknown message type %r from user=%s", msg_type, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def start_consumer() -> None:
|
||||||
|
"""Subscribe to batch:request:* and dispatch incoming frames."""
|
||||||
|
pubsub = redis_client.pubsub()
|
||||||
|
await pubsub.psubscribe("batch:request:*")
|
||||||
|
logger.info("batch-agent: subscribed to batch:request:*")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for message in pubsub.listen():
|
||||||
|
if message["type"] != "pmessage":
|
||||||
|
continue
|
||||||
|
|
||||||
|
channel: str = message["channel"]
|
||||||
|
if isinstance(channel, bytes):
|
||||||
|
channel = channel.decode()
|
||||||
|
|
||||||
|
# Extract user_id from channel: batch:request:{user_id}
|
||||||
|
parts = channel.split(":", 2)
|
||||||
|
if len(parts) < 3:
|
||||||
|
continue
|
||||||
|
user_id = parts[2]
|
||||||
|
|
||||||
|
raw = message["data"]
|
||||||
|
if isinstance(raw, bytes):
|
||||||
|
raw = raw.decode()
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("batch-agent: invalid JSON on channel %s", channel)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Dispatch in a separate task to avoid blocking the consumer
|
||||||
|
asyncio.create_task(_dispatch(user_id, data))
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("batch-agent: consumer shutting down")
|
||||||
|
finally:
|
||||||
|
await pubsub.punsubscribe("batch:request:*")
|
||||||
208
services/batch-agent/app/routes.py
Normal file
208
services/batch-agent/app/routes.py
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
"""Agent REST routes — catalog, billing checks, trigger.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: uses shared.db, shared.models, shared.schemas.
|
||||||
|
Agent trigger dispatches via Redis to the consumer instead of spawning
|
||||||
|
an in-process background task.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Header, HTTPException, status
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from shared.db import async_session
|
||||||
|
from shared.models import AgentRunLog
|
||||||
|
from shared.redis import redis_client, batch_request_channel
|
||||||
|
|
||||||
|
from app.agent_runner import is_agent_running
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/agents", tags=["agents"])
|
||||||
|
|
||||||
|
# ── Tier feature limits ───────────────────────────────────────────────
|
||||||
|
# Mirrors app/billing/tier_manager.py FEATURES dict.
|
||||||
|
FEATURES: dict[str, dict] = {
|
||||||
|
"free": {"batch_active": 1, "batch_runs_per_day": 3},
|
||||||
|
"pro": {"batch_active": 5, "batch_runs_per_day": 20},
|
||||||
|
"power": {"batch_active": 20, "batch_runs_per_day": 100},
|
||||||
|
"team": {"batch_active": -1, "batch_runs_per_day": -1},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _dt_ms(dt: datetime) -> int:
|
||||||
|
return int(dt.timestamp() * 1000)
|
||||||
|
|
||||||
|
|
||||||
|
def _dt_ms_opt(dt: datetime | None) -> int | None:
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
|
def _to_data_types(values: list[str]) -> list[str]:
|
||||||
|
normalize = {
|
||||||
|
"task": "tasks", "tasks": "tasks",
|
||||||
|
"note": "notes", "notes": "notes",
|
||||||
|
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
||||||
|
"project": "projects", "projects": "projects",
|
||||||
|
}
|
||||||
|
seen: set[str] = set()
|
||||||
|
result: list[str] = []
|
||||||
|
for v in values:
|
||||||
|
mapped = normalize.get(v)
|
||||||
|
if mapped and mapped not in seen:
|
||||||
|
seen.add(mapped)
|
||||||
|
result.append(mapped)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
||||||
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||||
|
if limit != -1 and current_count >= limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||||
|
)
|
||||||
|
return limit
|
||||||
|
|
||||||
|
|
||||||
|
async def _enforce_run_frequency(tier: str, user_id: str) -> None:
|
||||||
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
||||||
|
if limit == -1:
|
||||||
|
return
|
||||||
|
today_start = datetime.now(timezone.utc).replace(
|
||||||
|
hour=0, minute=0, second=0, microsecond=0
|
||||||
|
)
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(
|
||||||
|
select(func.count(AgentRunLog.id)).where(
|
||||||
|
AgentRunLog.user_id == user_id,
|
||||||
|
AgentRunLog.started_at >= today_start,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
runs_today: int = result.scalar_one()
|
||||||
|
|
||||||
|
if runs_today >= limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Daily batch run limit ({limit}) reached for your tier.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Catalog ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/catalog")
|
||||||
|
async def get_agent_catalog(
|
||||||
|
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||||
|
) -> list[dict]:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"type": "local_directory",
|
||||||
|
"name": "Local Directory Monitor",
|
||||||
|
"description": "Watches local directories, extracts data from files using AI",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "gmail",
|
||||||
|
"name": "Gmail Connector",
|
||||||
|
"description": "Scans Gmail inbox, extracts tasks/notes from emails",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "teams",
|
||||||
|
"name": "Microsoft Teams Connector",
|
||||||
|
"description": "Monitors Teams messages, extracts action items",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "outlook",
|
||||||
|
"name": "Outlook Connector",
|
||||||
|
"description": "Scans Outlook inbox, extracts tasks/notes",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Can-create check ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/can-create")
|
||||||
|
async def can_create_agent(
|
||||||
|
body: dict,
|
||||||
|
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||||
|
x_user_tier: str = Header("free", alias="X-User-Tier"),
|
||||||
|
) -> dict:
|
||||||
|
active_agents = body.get("active_agents", 0)
|
||||||
|
limit: int = FEATURES.get(x_user_tier, FEATURES["free"])["batch_active"]
|
||||||
|
allowed = limit == -1 or active_agents < limit
|
||||||
|
return {
|
||||||
|
"allowed": allowed,
|
||||||
|
"tier": x_user_tier,
|
||||||
|
"active_agents": active_agents,
|
||||||
|
"limit": limit,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Trigger ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/trigger", status_code=status.HTTP_202_ACCEPTED)
|
||||||
|
async def trigger_agent_run(
|
||||||
|
body: dict,
|
||||||
|
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||||
|
x_user_tier: str = Header("free", alias="X-User-Tier"),
|
||||||
|
) -> dict:
|
||||||
|
"""Trigger a local agent run — creates run log and dispatches via Redis."""
|
||||||
|
active_agents = body.get("active_agents", 0)
|
||||||
|
_enforce_agent_limit(x_user_tier, active_agents)
|
||||||
|
await _enforce_run_frequency(x_user_tier, x_user_id)
|
||||||
|
|
||||||
|
stable_agent_id = body.get("agent_id") or str(uuid.uuid4())
|
||||||
|
|
||||||
|
if is_agent_running(stable_agent_id):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="Agent is already running.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create run log in DB
|
||||||
|
async with async_session() as db:
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
agent_id=stable_agent_id,
|
||||||
|
agent_type="local",
|
||||||
|
user_id=x_user_id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
db.add(run_log)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(run_log)
|
||||||
|
run_log_id = run_log.id
|
||||||
|
|
||||||
|
run_context = {
|
||||||
|
"type": "agent_batch",
|
||||||
|
"run_id": run_log_id,
|
||||||
|
"agent_id": stable_agent_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Dispatch to the Redis consumer for processing
|
||||||
|
trigger_data = {
|
||||||
|
"type": "agent_trigger",
|
||||||
|
"directory": body.get("directory", ""),
|
||||||
|
"directory_paths": [body.get("directory", "")] if body.get("directory") else [],
|
||||||
|
"data_types": _to_data_types(body.get("what_to_extract", [])),
|
||||||
|
"file_extensions": body.get("file_extensions", []),
|
||||||
|
"prompt_template": body.get("custom_agent_prompt", ""),
|
||||||
|
"device_id": body.get("device_id", ""),
|
||||||
|
"run_context": run_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
channel = batch_request_channel(x_user_id)
|
||||||
|
await redis_client.publish(channel, json.dumps(trigger_data))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": run_log_id,
|
||||||
|
"agent_id": stable_agent_id,
|
||||||
|
"agent_type": "local",
|
||||||
|
"status": "running",
|
||||||
|
"items_processed": 0,
|
||||||
|
"items_created": 0,
|
||||||
|
"errors": [],
|
||||||
|
"started_at": _dt_ms(run_log.started_at),
|
||||||
|
"completed_at": None,
|
||||||
|
}
|
||||||
336
services/batch-agent/app/tracing.py
Normal file
336
services/batch-agent/app/tracing.py
Normal file
@@ -0,0 +1,336 @@
|
|||||||
|
"""Langfuse tracing & prompt management for the Batch Agent Service (v4 SDK).
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
- ``init_langfuse()`` — initialise the singleton client at startup
|
||||||
|
- ``trace_span()`` — context manager that creates a trace + span
|
||||||
|
- ``get_langfuse_callback()`` — LangChain callback handler (auto-inherits trace)
|
||||||
|
- ``get_prompt()`` — fetch a managed prompt from Langfuse by name
|
||||||
|
- ``flush()`` / ``shutdown()`` — lifecycle management
|
||||||
|
|
||||||
|
All functions gracefully degrade to no-ops when Langfuse is not configured,
|
||||||
|
so the service works identically with or without observability keys.
|
||||||
|
|
||||||
|
Requires ``langfuse >= 3.0.0`` (v4 / "Fast Preview" SDK).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── State ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_initialised: bool = False
|
||||||
|
_disabled: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_configured() -> bool:
|
||||||
|
return bool(settings.LANGFUSE_SECRET_KEY and settings.LANGFUSE_PUBLIC_KEY)
|
||||||
|
|
||||||
|
|
||||||
|
def init_langfuse() -> None:
|
||||||
|
"""Initialise the Langfuse singleton. Call once at startup."""
|
||||||
|
global _initialised, _disabled
|
||||||
|
|
||||||
|
if _initialised or _disabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not _is_configured():
|
||||||
|
_disabled = True
|
||||||
|
logger.info("tracing: Langfuse keys not set — tracing disabled")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse import Langfuse
|
||||||
|
|
||||||
|
Langfuse(
|
||||||
|
secret_key=settings.LANGFUSE_SECRET_KEY,
|
||||||
|
public_key=settings.LANGFUSE_PUBLIC_KEY,
|
||||||
|
host=settings.LANGFUSE_HOST,
|
||||||
|
)
|
||||||
|
_initialised = True
|
||||||
|
logger.info("tracing: Langfuse client initialised (host=%s)", settings.LANGFUSE_HOST)
|
||||||
|
except Exception as exc:
|
||||||
|
_disabled = True
|
||||||
|
logger.warning("tracing: failed to initialise Langfuse: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client() -> Any | None:
|
||||||
|
"""Return the singleton Langfuse client, or *None* if disabled."""
|
||||||
|
if _disabled:
|
||||||
|
return None
|
||||||
|
if not _initialised:
|
||||||
|
init_langfuse()
|
||||||
|
if _disabled:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
from langfuse import get_client
|
||||||
|
return get_client()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Null span (no-op when Langfuse is disabled) ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _NullSpan:
|
||||||
|
"""Drop-in replacement when Langfuse is disabled."""
|
||||||
|
|
||||||
|
def update(self, **_: Any) -> None: ...
|
||||||
|
def set_trace_io(self, **_: Any) -> None: ...
|
||||||
|
def score_trace(self, **_: Any) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
# ── Trace context manager ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def trace_span(
|
||||||
|
*,
|
||||||
|
name: str,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str | None = None,
|
||||||
|
trace_id: str | None = None,
|
||||||
|
input: Any = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
):
|
||||||
|
"""Context manager that creates a Langfuse trace/span.
|
||||||
|
|
||||||
|
Yields the span object (or a ``_NullSpan`` if Langfuse is disabled).
|
||||||
|
A ``CallbackHandler`` created inside this block auto-inherits the trace
|
||||||
|
context, so there is no need to pass trace IDs manually.
|
||||||
|
"""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None:
|
||||||
|
yield _NullSpan()
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse import Langfuse, propagate_attributes
|
||||||
|
|
||||||
|
trace_ctx: dict[str, str] = {}
|
||||||
|
if trace_id is not None:
|
||||||
|
trace_ctx["trace_id"] = Langfuse.create_trace_id(seed=trace_id)
|
||||||
|
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="span",
|
||||||
|
name=name,
|
||||||
|
input=input,
|
||||||
|
metadata=metadata or {},
|
||||||
|
**({"trace_context": trace_ctx} if trace_ctx else {}),
|
||||||
|
) as span:
|
||||||
|
with propagate_attributes(
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
tags=tags or [],
|
||||||
|
):
|
||||||
|
yield span
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: trace_span(%s) failed: %s", name, exc)
|
||||||
|
yield _NullSpan()
|
||||||
|
|
||||||
|
|
||||||
|
# ── LangChain callback handler ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def get_langfuse_callback() -> Any | None:
|
||||||
|
"""Return a LangChain ``CallbackHandler`` that auto-inherits the current trace.
|
||||||
|
|
||||||
|
Must be called inside a ``trace_span()`` block for proper linking.
|
||||||
|
Returns *None* when Langfuse is disabled.
|
||||||
|
"""
|
||||||
|
if _disabled and not _initialised:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse.langchain import CallbackHandler
|
||||||
|
return CallbackHandler()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: get_langfuse_callback failed: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Prompt management ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt(
|
||||||
|
name: str,
|
||||||
|
*,
|
||||||
|
version: int | None = None,
|
||||||
|
label: str | None = None,
|
||||||
|
fallback: str | None = None,
|
||||||
|
cache_ttl_seconds: int = 300,
|
||||||
|
) -> str | None:
|
||||||
|
"""Fetch a managed prompt from Langfuse by name (without variable compilation).
|
||||||
|
|
||||||
|
Returns the raw prompt string, or *fallback* if the prompt is not
|
||||||
|
found or Langfuse is disabled.
|
||||||
|
"""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None:
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
try:
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"name": name,
|
||||||
|
"cache_ttl_seconds": cache_ttl_seconds,
|
||||||
|
}
|
||||||
|
if version is not None:
|
||||||
|
kwargs["version"] = version
|
||||||
|
if label is not None:
|
||||||
|
kwargs["label"] = label
|
||||||
|
prompt = lf.get_prompt(**kwargs)
|
||||||
|
return prompt.prompt
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: get_prompt(%s) failed: %s", name, exc)
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
|
||||||
|
def compile_prompt(
|
||||||
|
name: str,
|
||||||
|
*,
|
||||||
|
fallback: str,
|
||||||
|
variables: dict[str, str],
|
||||||
|
version: int | None = None,
|
||||||
|
label: str | None = None,
|
||||||
|
cache_ttl_seconds: int = 300,
|
||||||
|
) -> str:
|
||||||
|
"""Fetch a managed prompt from Langfuse and compile it with ``{{variables}}``.
|
||||||
|
|
||||||
|
If the prompt exists in Langfuse, uses the SDK's ``.compile(**variables)``
|
||||||
|
which replaces ``{{key}}`` placeholders. If Langfuse is disabled or the
|
||||||
|
prompt is not found, falls back to ``fallback.format(**variables)`` (Python
|
||||||
|
``{key}`` placeholders).
|
||||||
|
|
||||||
|
This means:
|
||||||
|
- Langfuse prompts use ``{{variable}}`` syntax.
|
||||||
|
- Hardcoded fallback strings use Python ``{variable}`` syntax.
|
||||||
|
"""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None:
|
||||||
|
return fallback.format(**variables)
|
||||||
|
|
||||||
|
try:
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"name": name,
|
||||||
|
"cache_ttl_seconds": cache_ttl_seconds,
|
||||||
|
}
|
||||||
|
if version is not None:
|
||||||
|
kwargs["version"] = version
|
||||||
|
if label is not None:
|
||||||
|
kwargs["label"] = label
|
||||||
|
prompt = lf.get_prompt(**kwargs)
|
||||||
|
return prompt.compile(**variables)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: compile_prompt(%s) failed, using fallback: %s", name, exc)
|
||||||
|
return fallback.format(**variables)
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt_object(
|
||||||
|
name: str,
|
||||||
|
*,
|
||||||
|
version: int | None = None,
|
||||||
|
label: str | None = None,
|
||||||
|
cache_ttl_seconds: int = 300,
|
||||||
|
) -> Any | None:
|
||||||
|
"""Fetch the raw Langfuse prompt *object* (not the compiled string).
|
||||||
|
|
||||||
|
Returns ``None`` when Langfuse is disabled or the prompt is not found.
|
||||||
|
Use this when you need to pass the prompt to ``start_observation(prompt=...)``
|
||||||
|
for linking the prompt to a trace in the Langfuse UI.
|
||||||
|
"""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"name": name,
|
||||||
|
"cache_ttl_seconds": cache_ttl_seconds,
|
||||||
|
}
|
||||||
|
if version is not None:
|
||||||
|
kwargs["version"] = version
|
||||||
|
if label is not None:
|
||||||
|
kwargs["label"] = label
|
||||||
|
return lf.get_prompt(**kwargs)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: get_prompt_object(%s) failed: %s", name, exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def link_prompt_to_trace(
|
||||||
|
span: Any,
|
||||||
|
prompt_name: str,
|
||||||
|
*,
|
||||||
|
version: int | None = None,
|
||||||
|
label: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Link a Langfuse managed prompt to a span/observation.
|
||||||
|
|
||||||
|
Uses the SDK v4 ``prompt=`` parameter so that the prompt version
|
||||||
|
appears linked in the Langfuse UI with metrics tracking.
|
||||||
|
"""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None or isinstance(span, _NullSpan):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
prompt = get_prompt_object(prompt_name, version=version, label=label)
|
||||||
|
if prompt is not None:
|
||||||
|
span.update(prompt=prompt)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: link_prompt_to_trace(%s) failed: %s", prompt_name, exc)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Scoring helper ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def score_trace(
|
||||||
|
trace_id: str,
|
||||||
|
name: str,
|
||||||
|
value: float,
|
||||||
|
*,
|
||||||
|
comment: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Post a score to a trace (e.g. user feedback, latency, quality)."""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
lf.create_score(trace_id=trace_id, name=name, value=value, comment=comment)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: score_trace failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Shutdown ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def flush() -> None:
|
||||||
|
"""Flush pending Langfuse events."""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is not None:
|
||||||
|
try:
|
||||||
|
lf.flush()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: flush failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
def shutdown() -> None:
|
||||||
|
"""Flush and close the Langfuse client."""
|
||||||
|
global _initialised, _disabled
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is not None:
|
||||||
|
try:
|
||||||
|
lf.flush()
|
||||||
|
lf.shutdown()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: shutdown failed: %s", exc)
|
||||||
|
_initialised = False
|
||||||
|
_disabled = False
|
||||||
1
services/batch-agent/eval/__init__.py
Normal file
1
services/batch-agent/eval/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Batch Agent E2E evaluation harness."""
|
||||||
5
services/batch-agent/eval/__main__.py
Normal file
5
services/batch-agent/eval/__main__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Allow running the eval package as ``python -m eval``."""
|
||||||
|
|
||||||
|
from eval.cli import main
|
||||||
|
|
||||||
|
main()
|
||||||
285
services/batch-agent/eval/cli.py
Normal file
285
services/batch-agent/eval/cli.py
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
"""CLI entry point for the batch agent evaluation harness.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
# From services/batch-agent/:
|
||||||
|
python -m eval run # all agent fixtures, default model
|
||||||
|
python -m eval run --fixture=classify-invoices # single fixture
|
||||||
|
python -m eval run --models=gpt-4o,gpt-5.3-codex # multiple models
|
||||||
|
python -m eval run --mode=step1 # only step1 fixtures
|
||||||
|
python -m eval run --no-judge # skip LLM judge scoring
|
||||||
|
|
||||||
|
python -m eval interactive # interactive journey session
|
||||||
|
python -m eval interactive --fixture=journey-invoice-setup
|
||||||
|
python -m eval interactive --model=gpt-4o
|
||||||
|
python -m eval interactive --judge-model=github_copilot/gpt-4o-mini
|
||||||
|
|
||||||
|
python -m eval list # list all fixtures
|
||||||
|
python -m eval sync # sync fixtures to Langfuse datasets
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Ensure the service root and repo root are in sys.path.
|
||||||
|
# Service root must come BEFORE repo root so its ``app/`` package
|
||||||
|
# shadows the monolith ``app/`` in the repo root.
|
||||||
|
_SERVICE_ROOT = Path(__file__).resolve().parent.parent
|
||||||
|
_REPO_ROOT = _SERVICE_ROOT.parent.parent
|
||||||
|
_sr = str(_SERVICE_ROOT)
|
||||||
|
_rr = str(_REPO_ROOT)
|
||||||
|
if _rr not in sys.path:
|
||||||
|
sys.path.insert(0, _rr)
|
||||||
|
# Always force service root to position 0 (python -m may have already
|
||||||
|
# added CWD further down the list, which loses to repo root).
|
||||||
|
if _sr in sys.path:
|
||||||
|
sys.path.remove(_sr)
|
||||||
|
sys.path.insert(0, _sr)
|
||||||
|
|
||||||
|
from eval.config import discover_fixtures, discover_journey_fixtures
|
||||||
|
from eval.runner import run_fixture_eval, print_results
|
||||||
|
from eval.interactive import run_interactive
|
||||||
|
from eval import langfuse_eval
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_logging(verbose: bool) -> None:
|
||||||
|
level = logging.DEBUG if verbose else logging.INFO
|
||||||
|
logging.basicConfig(
|
||||||
|
level=level,
|
||||||
|
format="%(asctime)s %(name)-20s %(levelname)-5s %(message)s",
|
||||||
|
datefmt="%H:%M:%S",
|
||||||
|
)
|
||||||
|
# Quiet noisy libraries
|
||||||
|
for name in ("httpx", "httpcore", "openai", "litellm", "urllib3"):
|
||||||
|
logging.getLogger(name).setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Batch Agent E2E evaluation harness",
|
||||||
|
prog="python -m eval",
|
||||||
|
)
|
||||||
|
sub = parser.add_subparsers(dest="command", required=True)
|
||||||
|
|
||||||
|
# ── run ───────────────────────────────────────────────────────
|
||||||
|
run_cmd = sub.add_parser("run", help="Run evaluations")
|
||||||
|
run_cmd.add_argument(
|
||||||
|
"--fixture", "-f",
|
||||||
|
help="Run only the named fixture (default: all)",
|
||||||
|
)
|
||||||
|
run_cmd.add_argument(
|
||||||
|
"--models", "-m",
|
||||||
|
default="github_copilot/gpt-5.3-codex",
|
||||||
|
help="Comma-separated list of models to test (default: github_copilot/gpt-5.3-codex)",
|
||||||
|
)
|
||||||
|
run_cmd.add_argument(
|
||||||
|
"--mode",
|
||||||
|
default=None,
|
||||||
|
choices=["step1", "step2", "full"],
|
||||||
|
help="Only run fixtures with this mode (default: all)",
|
||||||
|
)
|
||||||
|
run_cmd.add_argument(
|
||||||
|
"--no-judge",
|
||||||
|
action="store_true",
|
||||||
|
help="Skip LLM-as-judge scoring",
|
||||||
|
)
|
||||||
|
run_cmd.add_argument(
|
||||||
|
"--judge-model",
|
||||||
|
default="gpt-4o",
|
||||||
|
help="Model for LLM judge (default: gpt-4o)",
|
||||||
|
)
|
||||||
|
run_cmd.add_argument(
|
||||||
|
"--fixtures-dir",
|
||||||
|
default=None,
|
||||||
|
help="Path to fixtures directory (default: eval/fixtures/)",
|
||||||
|
)
|
||||||
|
run_cmd.add_argument("-v", "--verbose", action="store_true")
|
||||||
|
|
||||||
|
# ── list ──────────────────────────────────────────────────────
|
||||||
|
list_cmd = sub.add_parser("list", help="List available fixtures")
|
||||||
|
list_cmd.add_argument("--fixtures-dir", default=None)
|
||||||
|
list_cmd.add_argument("-v", "--verbose", action="store_true")
|
||||||
|
|
||||||
|
# ── sync ──────────────────────────────────────────────────────
|
||||||
|
sync_cmd = sub.add_parser("sync", help="Sync fixtures to Langfuse datasets")
|
||||||
|
sync_cmd.add_argument("--fixture", "-f", default=None, help="Sync only the named fixture")
|
||||||
|
sync_cmd.add_argument("--fixtures-dir", default=None)
|
||||||
|
sync_cmd.add_argument("-v", "--verbose", action="store_true")
|
||||||
|
|
||||||
|
# ── interactive ───────────────────────────────────────────────
|
||||||
|
inter_cmd = sub.add_parser("interactive", help="Interactive journey session (human-in-the-loop)")
|
||||||
|
inter_cmd.add_argument(
|
||||||
|
"--fixture", "-f",
|
||||||
|
help="Journey fixture to use (default: pick interactively)",
|
||||||
|
)
|
||||||
|
inter_cmd.add_argument(
|
||||||
|
"--model", "-m",
|
||||||
|
default="github_copilot/gpt-5.3-codex",
|
||||||
|
help="Model for the journey AI (default: github_copilot/gpt-5.3-codex)",
|
||||||
|
)
|
||||||
|
inter_cmd.add_argument(
|
||||||
|
"--judge-model",
|
||||||
|
default="gpt-4o",
|
||||||
|
help="Model for LLM judge (default: gpt-4o)",
|
||||||
|
)
|
||||||
|
inter_cmd.add_argument(
|
||||||
|
"--fixtures-dir",
|
||||||
|
default=None,
|
||||||
|
help="Path to fixtures directory (default: eval/fixtures/)",
|
||||||
|
)
|
||||||
|
inter_cmd.add_argument(
|
||||||
|
"--data-dir",
|
||||||
|
default=None,
|
||||||
|
help="Override sample data directory (e.g. path to private test files not in git)",
|
||||||
|
)
|
||||||
|
inter_cmd.add_argument("-v", "--verbose", action="store_true")
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def _fixtures_dir(arg: str | None) -> Path | None:
|
||||||
|
if arg:
|
||||||
|
return Path(arg)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _cmd_run(args: argparse.Namespace) -> None:
|
||||||
|
fixtures = discover_fixtures(_fixtures_dir(args.fixtures_dir))
|
||||||
|
if not fixtures:
|
||||||
|
print("No fixtures found. Create YAML files in eval/fixtures/.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if args.fixture:
|
||||||
|
fixtures = [f for f in fixtures if f.name == args.fixture]
|
||||||
|
if not fixtures:
|
||||||
|
print(f"Fixture '{args.fixture}' not found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
models = [m.strip() for m in args.models.split(",")]
|
||||||
|
|
||||||
|
all_results = []
|
||||||
|
for fixture in fixtures:
|
||||||
|
if args.mode and fixture.mode != args.mode:
|
||||||
|
continue
|
||||||
|
results = await run_fixture_eval(
|
||||||
|
fixture,
|
||||||
|
models=models,
|
||||||
|
use_llm_judge=not args.no_judge,
|
||||||
|
judge_model=args.judge_model,
|
||||||
|
)
|
||||||
|
all_results.extend(results)
|
||||||
|
|
||||||
|
print_results(all_results)
|
||||||
|
|
||||||
|
|
||||||
|
def _cmd_list(args: argparse.Namespace) -> None:
|
||||||
|
fixtures = discover_fixtures(_fixtures_dir(args.fixtures_dir))
|
||||||
|
journey_fixtures = discover_journey_fixtures(_fixtures_dir(args.fixtures_dir))
|
||||||
|
|
||||||
|
if not fixtures and not journey_fixtures:
|
||||||
|
print("No fixtures found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if fixtures:
|
||||||
|
print(f"\n{'[Agent Fixtures]'}")
|
||||||
|
print(f"{'Name':<30} {'Mode':<6} {'Types':<25} {'Expected'}")
|
||||||
|
print("-" * 90)
|
||||||
|
for f in fixtures:
|
||||||
|
types = ", ".join(f.data_types)
|
||||||
|
n_expected = len(f.expected) + len(f.expected_classification)
|
||||||
|
print(f"{f.name:<30} {f.mode:<6} {types:<25} {n_expected}")
|
||||||
|
|
||||||
|
if journey_fixtures:
|
||||||
|
print(f"\n{'[Journey Fixtures]'}")
|
||||||
|
print(f"{'Name':<30} {'Types':<25} {'Messages':<10} {'Criteria'}")
|
||||||
|
print("-" * 90)
|
||||||
|
for f in journey_fixtures:
|
||||||
|
types = ", ".join(f.data_types)
|
||||||
|
print(f"{f.name:<30} {types:<25} {len(f.user_messages):<10} {len(f.expected_template_criteria)}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def _cmd_sync(args: argparse.Namespace) -> None:
|
||||||
|
fixtures = discover_fixtures(_fixtures_dir(args.fixtures_dir))
|
||||||
|
journey_fixtures = discover_journey_fixtures(_fixtures_dir(args.fixtures_dir))
|
||||||
|
|
||||||
|
if args.fixture:
|
||||||
|
fixtures = [f for f in fixtures if f.name == args.fixture]
|
||||||
|
journey_fixtures = [f for f in journey_fixtures if f.name == args.fixture]
|
||||||
|
|
||||||
|
if not fixtures and not journey_fixtures:
|
||||||
|
print("No fixtures to sync.")
|
||||||
|
return
|
||||||
|
|
||||||
|
for fixture in fixtures:
|
||||||
|
name = langfuse_eval.sync_fixture_to_dataset(fixture)
|
||||||
|
if name:
|
||||||
|
print(f"Synced: {fixture.name} → {name}")
|
||||||
|
else:
|
||||||
|
print(f"Skipped: {fixture.name} (Langfuse not configured)")
|
||||||
|
|
||||||
|
for fixture in journey_fixtures:
|
||||||
|
name = langfuse_eval.sync_journey_fixture_to_dataset(fixture)
|
||||||
|
if name:
|
||||||
|
print(f"Synced: {fixture.name} → {name}")
|
||||||
|
else:
|
||||||
|
print(f"Skipped: {fixture.name} (Langfuse not configured)")
|
||||||
|
|
||||||
|
|
||||||
|
async def _cmd_interactive(args: argparse.Namespace) -> None:
|
||||||
|
journey_fixtures = discover_journey_fixtures(_fixtures_dir(args.fixtures_dir))
|
||||||
|
if not journey_fixtures:
|
||||||
|
print("No journey fixtures found. Create YAML files with type: journey in eval/fixtures/.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if args.fixture:
|
||||||
|
fixtures = [f for f in journey_fixtures if f.name == args.fixture]
|
||||||
|
if not fixtures:
|
||||||
|
print(f"Journey fixture '{args.fixture}' not found.")
|
||||||
|
return
|
||||||
|
fixture = fixtures[0]
|
||||||
|
elif len(journey_fixtures) == 1:
|
||||||
|
fixture = journey_fixtures[0]
|
||||||
|
else:
|
||||||
|
# Let user pick
|
||||||
|
print("\nAvailable journey fixtures:")
|
||||||
|
for i, f in enumerate(journey_fixtures, 1):
|
||||||
|
print(f" {i}. {f.name} — {f.description[:60]}")
|
||||||
|
print()
|
||||||
|
try:
|
||||||
|
choice = int(input("Pick a fixture number: ").strip()) - 1
|
||||||
|
fixture = journey_fixtures[choice]
|
||||||
|
except (ValueError, IndexError, EOFError, KeyboardInterrupt):
|
||||||
|
print("Invalid choice.")
|
||||||
|
return
|
||||||
|
|
||||||
|
await run_interactive(
|
||||||
|
fixture,
|
||||||
|
model=args.model,
|
||||||
|
judge_model=args.judge_model,
|
||||||
|
data_dir=Path(args.data_dir).resolve() if args.data_dir else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
args = _parse_args()
|
||||||
|
_setup_logging(args.verbose)
|
||||||
|
|
||||||
|
if args.command == "run":
|
||||||
|
asyncio.run(_cmd_run(args))
|
||||||
|
elif args.command == "interactive":
|
||||||
|
asyncio.run(_cmd_interactive(args))
|
||||||
|
elif args.command == "list":
|
||||||
|
_cmd_list(args)
|
||||||
|
elif args.command == "sync":
|
||||||
|
_cmd_sync(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
220
services/batch-agent/eval/config.py
Normal file
220
services/batch-agent/eval/config.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
"""Eval configuration — YAML fixture loader and dataclasses.
|
||||||
|
|
||||||
|
Fixtures come in two families:
|
||||||
|
|
||||||
|
1. **Agent fixtures** — test the batch agent pipeline.
|
||||||
|
Three modes controlled by ``mode``:
|
||||||
|
|
||||||
|
``step1`` — classification prompt only.
|
||||||
|
``step2`` — processing prompt only.
|
||||||
|
``full`` — both steps in sequence.
|
||||||
|
|
||||||
|
2. **Journey fixtures** — test the prompt-template builder conversation
|
||||||
|
(unchanged).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
EvalMode = Literal["step1", "step2", "full"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExpectedRecord:
|
||||||
|
"""A single expected extraction result.
|
||||||
|
|
||||||
|
Only the fields specified are checked — unspecified fields are ignored.
|
||||||
|
"""
|
||||||
|
|
||||||
|
table: str # tasks | notes | timelines | projects
|
||||||
|
fields: dict[str, Any] # field_name → expected_value
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExpectedClassification:
|
||||||
|
"""Expected output of step-1 classification for one file."""
|
||||||
|
|
||||||
|
file: str # relative path to the sample file
|
||||||
|
project_id: str # expected matched project id, or "new"
|
||||||
|
domains: list[str] # expected domain list
|
||||||
|
new_project_name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EvalFixture:
|
||||||
|
"""A complete test scenario loaded from YAML.
|
||||||
|
|
||||||
|
``mode`` determines which pipeline steps are exercised:
|
||||||
|
|
||||||
|
- **step1**: only ``_classify_file``
|
||||||
|
- **step2**: only the processing LLM + tool loop
|
||||||
|
- **full**: both steps in sequence (``run_local_agent``)
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
mode: EvalMode
|
||||||
|
directory: str # relative path to sample files
|
||||||
|
data_types: list[str]
|
||||||
|
file_extensions: list[str]
|
||||||
|
models: list[str] # if empty, use CLI default
|
||||||
|
fixture_path: Path = field(default_factory=lambda: Path("."))
|
||||||
|
|
||||||
|
# ── Step-1 inputs (classification) ───────────────────────────
|
||||||
|
domain_definitions: str = ""
|
||||||
|
projects_list: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
custom_step1_prompt: str = ""
|
||||||
|
|
||||||
|
# ── Step-2 inputs (processing) ───────────────────────────────
|
||||||
|
existing_context: str = ""
|
||||||
|
project_context: str = ""
|
||||||
|
custom_prompt_section: str = ""
|
||||||
|
|
||||||
|
# ── Seed records for mock executor ───────────────────────────
|
||||||
|
seed_records: dict[str, list[dict]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# ── Expected outputs ─────────────────────────────────────────
|
||||||
|
expected_classification: list[ExpectedClassification] = field(default_factory=list)
|
||||||
|
expected: list[ExpectedRecord] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fixture_dir(self) -> Path:
|
||||||
|
"""Absolute path to the sample files directory."""
|
||||||
|
return self.fixture_path.parent / self.directory
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_yaml(cls, path: Path) -> "EvalFixture":
|
||||||
|
"""Load a fixture from a YAML file."""
|
||||||
|
raw = yaml.safe_load(path.read_text(encoding="utf-8"))
|
||||||
|
|
||||||
|
mode: EvalMode = raw.get("mode", "full")
|
||||||
|
|
||||||
|
# Parse expected records (step2/full)
|
||||||
|
expected: list[ExpectedRecord] = []
|
||||||
|
for table, records in (raw.get("expected") or {}).items():
|
||||||
|
for rec in records:
|
||||||
|
expected.append(ExpectedRecord(table=table, fields=rec))
|
||||||
|
|
||||||
|
# Parse expected classification (step1/full)
|
||||||
|
expected_classification: list[ExpectedClassification] = []
|
||||||
|
for item in raw.get("expected_classification") or []:
|
||||||
|
expected_classification.append(ExpectedClassification(
|
||||||
|
file=item["file"],
|
||||||
|
project_id=item["project_id"],
|
||||||
|
domains=item.get("domains", []),
|
||||||
|
new_project_name=item.get("new_project_name"),
|
||||||
|
))
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
name=raw["name"],
|
||||||
|
description=raw.get("description", ""),
|
||||||
|
mode=mode,
|
||||||
|
directory=raw.get("directory", "sample_files"),
|
||||||
|
data_types=raw.get("data_types", ["tasks"]),
|
||||||
|
file_extensions=raw.get("file_extensions", []),
|
||||||
|
models=raw.get("models", []),
|
||||||
|
fixture_path=path,
|
||||||
|
# Step-1 inputs
|
||||||
|
domain_definitions=raw.get("domain_definitions", ""),
|
||||||
|
projects_list=raw.get("projects_list", []),
|
||||||
|
# Step-2 inputs
|
||||||
|
existing_context=raw.get("existing_context", ""),
|
||||||
|
project_context=raw.get("project_context", ""),
|
||||||
|
custom_prompt_section=raw.get("custom_prompt_section", ""),
|
||||||
|
# Shared
|
||||||
|
seed_records=raw.get("seed_records", {}),
|
||||||
|
expected_classification=expected_classification,
|
||||||
|
expected=expected,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def discover_fixtures(fixtures_dir: Path | None = None) -> list[EvalFixture]:
|
||||||
|
"""Find and load all YAML fixtures in the fixtures directory."""
|
||||||
|
if fixtures_dir is None:
|
||||||
|
fixtures_dir = Path(__file__).parent / "fixtures"
|
||||||
|
|
||||||
|
fixtures: list[EvalFixture] = []
|
||||||
|
if not fixtures_dir.is_dir():
|
||||||
|
logger.warning("eval: fixtures directory not found: %s", fixtures_dir)
|
||||||
|
return fixtures
|
||||||
|
|
||||||
|
for yaml_path in sorted(fixtures_dir.glob("*.yaml")):
|
||||||
|
try:
|
||||||
|
raw = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
|
||||||
|
if raw.get("type") == "journey":
|
||||||
|
continue # Skip journey fixtures
|
||||||
|
fixtures.append(EvalFixture.from_yaml(yaml_path))
|
||||||
|
logger.info("eval: loaded fixture %s from %s", fixtures[-1].name, yaml_path.name)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("eval: failed to load fixture %s: %s", yaml_path.name, exc)
|
||||||
|
|
||||||
|
return fixtures
|
||||||
|
|
||||||
|
|
||||||
|
# ── Journey fixtures ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class JourneyFixture:
|
||||||
|
"""A journey test scenario — tests the prompt_template builder conversation."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
directory: str # relative path to sample files
|
||||||
|
data_types: list[str]
|
||||||
|
expected_template_criteria: list[str] # what the template should contain/satisfy
|
||||||
|
user_messages: list[str] = field(default_factory=list) # for automated journey runs (unused in interactive mode)
|
||||||
|
models: list[str] = field(default_factory=list)
|
||||||
|
fixture_path: Path = field(default_factory=lambda: Path("."))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fixture_dir(self) -> Path:
|
||||||
|
"""Absolute path to the sample files directory."""
|
||||||
|
return self.fixture_path.parent / self.directory
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_yaml(cls, path: Path) -> "JourneyFixture":
|
||||||
|
"""Load a journey fixture from a YAML file."""
|
||||||
|
raw = yaml.safe_load(path.read_text(encoding="utf-8"))
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
name=raw["name"],
|
||||||
|
description=raw.get("description", ""),
|
||||||
|
directory=raw.get("directory", "sample_files"),
|
||||||
|
data_types=raw.get("data_types", ["tasks"]),
|
||||||
|
user_messages=raw.get("user_messages", []),
|
||||||
|
expected_template_criteria=raw.get("expected_template_criteria", []),
|
||||||
|
models=raw.get("models", []),
|
||||||
|
fixture_path=path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def discover_journey_fixtures(fixtures_dir: Path | None = None) -> list[JourneyFixture]:
|
||||||
|
"""Find and load all journey YAML fixtures in the fixtures directory."""
|
||||||
|
if fixtures_dir is None:
|
||||||
|
fixtures_dir = Path(__file__).parent / "fixtures"
|
||||||
|
|
||||||
|
fixtures: list[JourneyFixture] = []
|
||||||
|
if not fixtures_dir.is_dir():
|
||||||
|
logger.warning("eval: fixtures directory not found: %s", fixtures_dir)
|
||||||
|
return fixtures
|
||||||
|
|
||||||
|
for yaml_path in sorted(fixtures_dir.glob("*.yaml")):
|
||||||
|
try:
|
||||||
|
raw = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
|
||||||
|
if raw.get("type") != "journey":
|
||||||
|
continue
|
||||||
|
fixtures.append(JourneyFixture.from_yaml(yaml_path))
|
||||||
|
logger.info("eval: loaded journey fixture %s from %s", fixtures[-1].name, yaml_path.name)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("eval: failed to load journey fixture %s: %s", yaml_path.name, exc)
|
||||||
|
|
||||||
|
return fixtures
|
||||||
40
services/batch-agent/eval/fixtures/classify_invoices.yaml
Normal file
40
services/batch-agent/eval/fixtures/classify_invoices.yaml
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
# Fixture: classify-invoices (step1)
|
||||||
|
# Tests _STEP1_SYSTEM_PROMPT — file classification and project matching.
|
||||||
|
# Verifies that the LLM correctly matches files to existing projects
|
||||||
|
# and identifies the right data domains.
|
||||||
|
|
||||||
|
name: classify-invoices
|
||||||
|
mode: step1
|
||||||
|
description: >
|
||||||
|
Test file classification on Italian freelance invoices and meeting notes.
|
||||||
|
Verifies project matching and domain identification.
|
||||||
|
|
||||||
|
directory: sample_files/invoices
|
||||||
|
data_types: [tasks, notes, timelines]
|
||||||
|
file_extensions: [txt, md]
|
||||||
|
|
||||||
|
# ── Step-1 prompt variables ──────────────────────────────────────
|
||||||
|
domain_definitions: |
|
||||||
|
- tasks: Action items, deliverables, things to do — anything that someone needs to complete.
|
||||||
|
- notes: Meeting summaries, decisions, reference information — permanent knowledge entries.
|
||||||
|
- timelines: Project milestones, deadlines, scheduled events — specific dates that mark a point in the progress of a project.
|
||||||
|
|
||||||
|
projects_list:
|
||||||
|
- id: "proj-web-redesign"
|
||||||
|
name: "Redesign Sito Web Corporate"
|
||||||
|
status: "active"
|
||||||
|
aiSummary: "Corporate website redesign for Studio Architettura Bianchi"
|
||||||
|
- id: "proj-ecommerce"
|
||||||
|
name: "E-Commerce FashionStore"
|
||||||
|
status: "active"
|
||||||
|
aiSummary: "Next.js e-commerce platform for FashionStore srl"
|
||||||
|
|
||||||
|
# ── Expected classification results ─────────────────────────────
|
||||||
|
expected_classification:
|
||||||
|
- file: "sample_files/invoices/fattura_042.txt"
|
||||||
|
project_id: "proj-web-redesign"
|
||||||
|
domains: [tasks, notes, timelines]
|
||||||
|
|
||||||
|
- file: "sample_files/invoices/meeting_ecommerce.md"
|
||||||
|
project_id: "proj-ecommerce"
|
||||||
|
domains: [tasks, notes, timelines]
|
||||||
108
services/batch-agent/eval/fixtures/full_invoices.yaml
Normal file
108
services/batch-agent/eval/fixtures/full_invoices.yaml
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
# Fixture: full-invoices (full)
|
||||||
|
# Tests both _STEP1_SYSTEM_PROMPT and _PROCESSING_SYSTEM_PROMPT in sequence
|
||||||
|
# via run_local_agent(). Verifies end-to-end classification + extraction.
|
||||||
|
|
||||||
|
name: full-invoices
|
||||||
|
mode: full
|
||||||
|
description: >
|
||||||
|
End-to-end test: classify Italian invoices/meeting notes into the
|
||||||
|
correct project, then extract tasks, notes, and timeline events.
|
||||||
|
|
||||||
|
directory: sample_files/invoices
|
||||||
|
data_types: [tasks, notes, timelines]
|
||||||
|
file_extensions: [txt, md]
|
||||||
|
|
||||||
|
# ── Step-1 prompt variables ──────────────────────────────────────
|
||||||
|
domain_definitions: |
|
||||||
|
- tasks: Action items, deliverables, things to do — anything that someone needs to complete.
|
||||||
|
- notes: Meeting summaries, decisions, reference information — permanent knowledge entries.
|
||||||
|
- timelines: Project milestones, deadlines, scheduled events — specific dates that mark a point in the progress of a project.
|
||||||
|
|
||||||
|
projects_list:
|
||||||
|
- id: "proj-web-redesign"
|
||||||
|
name: "Redesign Sito Web Corporate"
|
||||||
|
status: "active"
|
||||||
|
aiSummary: "Corporate website redesign for Studio Architettura Bianchi"
|
||||||
|
- id: "proj-ecommerce"
|
||||||
|
name: "E-Commerce FashionStore"
|
||||||
|
status: "active"
|
||||||
|
aiSummary: "Next.js e-commerce platform for FashionStore srl"
|
||||||
|
|
||||||
|
# ── Step-2 prompt variables ──────────────────────────────────────
|
||||||
|
existing_context: |
|
||||||
|
Existing tasks:
|
||||||
|
(none)
|
||||||
|
|
||||||
|
Existing notes:
|
||||||
|
(none)
|
||||||
|
|
||||||
|
Existing timelines:
|
||||||
|
(none)
|
||||||
|
|
||||||
|
project_context: ""
|
||||||
|
|
||||||
|
custom_prompt_section: |
|
||||||
|
User instructions:
|
||||||
|
Estrai i dati dai file come segue:
|
||||||
|
- TASK: ogni azione da fare, deliverable, o item con scadenza.
|
||||||
|
Mappa "URGENTE" o "ALTA PRIORITÀ" → priority: high.
|
||||||
|
Mappa "media priorità" → priority: medium.
|
||||||
|
Mappa "bassa priorità" → priority: low.
|
||||||
|
Se un item è marcato come "completato" o [x], impostalo status: done.
|
||||||
|
Altrimenti status: todo.
|
||||||
|
- NOTE: riassunti di meeting, decisioni prese, note tecniche.
|
||||||
|
- TIMELINE: date di scadenza, milestone, meeting futuri.
|
||||||
|
Imposta sempre isAiSuggested=1.
|
||||||
|
|
||||||
|
# ── Seed records (pre-existing DB state) ─────────────────────────
|
||||||
|
seed_records:
|
||||||
|
projects:
|
||||||
|
- id: "proj-web-redesign"
|
||||||
|
name: "Redesign Sito Web Corporate"
|
||||||
|
status: "active"
|
||||||
|
aiSummary: "Corporate website redesign for Studio Architettura Bianchi"
|
||||||
|
- id: "proj-ecommerce"
|
||||||
|
name: "E-Commerce FashionStore"
|
||||||
|
status: "active"
|
||||||
|
aiSummary: "Next.js e-commerce platform for FashionStore srl"
|
||||||
|
tasks: []
|
||||||
|
notes: []
|
||||||
|
timelines: []
|
||||||
|
|
||||||
|
# ── Expected classification (step 1) ─────────────────────────────
|
||||||
|
expected_classification:
|
||||||
|
- file: "sample_files/invoices/fattura_042.txt"
|
||||||
|
project_id: "proj-web-redesign"
|
||||||
|
domains: [tasks, notes, timelines]
|
||||||
|
|
||||||
|
- file: "sample_files/invoices/meeting_ecommerce.md"
|
||||||
|
project_id: "proj-ecommerce"
|
||||||
|
domains: [tasks, notes, timelines]
|
||||||
|
|
||||||
|
# ── Expected extractions (step 2) ────────────────────────────────
|
||||||
|
expected:
|
||||||
|
tasks:
|
||||||
|
- title: "Sviluppo frontend React"
|
||||||
|
priority: "high"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Integrazione API backend"
|
||||||
|
priority: "medium"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Testing cross-browser e fix bug responsive"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Preparare wireframe homepage"
|
||||||
|
priority: "high"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Setup progetto Next.js e configurare CI/CD"
|
||||||
|
priority: "medium"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Ricerca plugin Stripe per gestione abbonamenti"
|
||||||
|
priority: "low"
|
||||||
|
status: "todo"
|
||||||
|
|
||||||
|
notes:
|
||||||
|
- title: "Meeting Kickoff Progetto E-Commerce"
|
||||||
|
|
||||||
|
timelines:
|
||||||
|
- title: "MVP E-Commerce pronto"
|
||||||
|
- title: "Meeting di revisione"
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
# Journey Fixture: journey-invoice-setup
|
||||||
|
# Used by `python -m eval interactive` for human-in-the-loop testing
|
||||||
|
# of the journey chatbot's prompt-building conversation.
|
||||||
|
|
||||||
|
type: journey
|
||||||
|
name: journey-invoice-setup
|
||||||
|
description: >
|
||||||
|
Interactive test for the journey chatbot — explore a directory of
|
||||||
|
Italian invoices and meeting notes, answer the chatbot's questions,
|
||||||
|
and verify it produces a well-structured prompt_template for data
|
||||||
|
extraction.
|
||||||
|
|
||||||
|
directory: sample_files/invoices
|
||||||
|
data_types: [tasks, notes, timelines, projects]
|
||||||
|
|
||||||
|
# Criteria the generated prompt_template must satisfy
|
||||||
|
# Each is scored 0-1 by an LLM judge
|
||||||
|
expected_template_criteria:
|
||||||
|
- "Mentions creating tasks from action items and work descriptions"
|
||||||
|
- "Mentions creating notes from meeting summaries"
|
||||||
|
- "Mentions extracting timeline events from deadlines and meeting dates"
|
||||||
|
- "Mentions creating projects from relevant information"
|
||||||
|
- "Sets isAiSuggested=1 on all created records"
|
||||||
|
- "Does NOT include projectId assignment logic"
|
||||||
|
- "Uses camelCase field names (title, status, priority, dueDate, content)"
|
||||||
|
|
||||||
|
# Models to test (empty = use CLI --models default)
|
||||||
|
models: []
|
||||||
81
services/batch-agent/eval/fixtures/process_invoices.yaml
Normal file
81
services/batch-agent/eval/fixtures/process_invoices.yaml
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
# Fixture: process-invoices (step2)
|
||||||
|
# Tests _PROCESSING_SYSTEM_PROMPT — data extraction & tool calling.
|
||||||
|
# The classification step is skipped; prompt variables are injected directly.
|
||||||
|
|
||||||
|
name: process-invoices
|
||||||
|
mode: step2
|
||||||
|
description: >
|
||||||
|
Test data extraction from Italian freelance invoices.
|
||||||
|
Verifies correct record creation via tool calls with the right
|
||||||
|
fields, priorities, and status values.
|
||||||
|
|
||||||
|
directory: sample_files/invoices
|
||||||
|
data_types: [tasks, notes, timelines]
|
||||||
|
file_extensions: [txt, md]
|
||||||
|
|
||||||
|
# ── Step-2 prompt variables ──────────────────────────────────────
|
||||||
|
existing_context: |
|
||||||
|
Existing tasks:
|
||||||
|
(none)
|
||||||
|
|
||||||
|
Existing notes:
|
||||||
|
(none)
|
||||||
|
|
||||||
|
Existing timelines:
|
||||||
|
(none)
|
||||||
|
|
||||||
|
project_context: >
|
||||||
|
Project: Redesign Sito Web Corporate (id: proj-web-redesign).
|
||||||
|
Always set projectId to this id on every record you create.
|
||||||
|
|
||||||
|
custom_prompt_section: |
|
||||||
|
User instructions:
|
||||||
|
Estrai i dati dai file come segue:
|
||||||
|
- TASK: ogni azione da fare, deliverable, o item con scadenza.
|
||||||
|
Mappa "URGENTE" o "ALTA PRIORITÀ" → priority: high.
|
||||||
|
Mappa "media priorità" → priority: medium.
|
||||||
|
Mappa "bassa priorità" → priority: low.
|
||||||
|
Se un item è marcato come "completato" o [x], impostalo status: done.
|
||||||
|
Altrimenti status: todo.
|
||||||
|
- NOTE: riassunti di meeting, decisioni prese, note tecniche.
|
||||||
|
Il titolo deve essere descrittivo. Il content deve includere tutti i dettagli.
|
||||||
|
- TIMELINE: date di scadenza, milestone, meeting futuri.
|
||||||
|
Imposta sempre isAiSuggested=1.
|
||||||
|
|
||||||
|
# ── Seed records (pre-existing DB state) ─────────────────────────
|
||||||
|
seed_records:
|
||||||
|
projects:
|
||||||
|
- id: "proj-web-redesign"
|
||||||
|
name: "Redesign Sito Web Corporate"
|
||||||
|
status: "active"
|
||||||
|
tasks: []
|
||||||
|
notes: []
|
||||||
|
timelines: []
|
||||||
|
|
||||||
|
# ── Expected extractions ─────────────────────────────────────────
|
||||||
|
expected:
|
||||||
|
tasks:
|
||||||
|
- title: "Sviluppo frontend React"
|
||||||
|
priority: "high"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Integrazione API backend"
|
||||||
|
priority: "medium"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Testing cross-browser e fix bug responsive"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Preparare wireframe homepage"
|
||||||
|
priority: "high"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Setup progetto Next.js e configurare CI/CD"
|
||||||
|
priority: "medium"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Ricerca plugin Stripe per gestione abbonamenti"
|
||||||
|
priority: "low"
|
||||||
|
status: "todo"
|
||||||
|
|
||||||
|
notes:
|
||||||
|
- title: "Meeting Kickoff Progetto E-Commerce"
|
||||||
|
|
||||||
|
timelines:
|
||||||
|
- title: "MVP E-Commerce pronto"
|
||||||
|
- title: "Meeting di revisione"
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
FATTURA N. 2026-0042
|
||||||
|
Data: 15 Marzo 2026
|
||||||
|
Cliente: Studio Architettura Bianchi
|
||||||
|
|
||||||
|
Progetto: Redesign Sito Web Corporate
|
||||||
|
|
||||||
|
Descrizione lavori:
|
||||||
|
- Sviluppo frontend React (40 ore) — URGENTE, completare entro 20 marzo
|
||||||
|
- Integrazione API backend (20 ore) — priorità media
|
||||||
|
- Design UI/UX mockup homepage (8 ore) — completato
|
||||||
|
- Testing cross-browser e fix bug responsive (12 ore) — da iniziare
|
||||||
|
|
||||||
|
Totale: €4.800,00 + IVA
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Meeting di revisione previsto per il 18 marzo alle 10:00.
|
||||||
|
Il cliente ha richiesto modifiche al layout mobile della sezione contatti.
|
||||||
|
Attendere conferma budget aggiuntivo per sezione blog.
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
# Meeting Notes - Kickoff Progetto E-Commerce
|
||||||
|
|
||||||
|
**Data:** 10 Marzo 2026
|
||||||
|
**Partecipanti:** Marco R., Giulia T., Cliente (FashionStore srl)
|
||||||
|
|
||||||
|
## Decisioni prese
|
||||||
|
|
||||||
|
1. **Piattaforma**: Next.js + Stripe per i pagamenti
|
||||||
|
2. **Timeline**: MVP pronto entro 30 aprile 2026
|
||||||
|
3. **Budget**: €12.000 totale, €4.000 anticipo già ricevuto
|
||||||
|
|
||||||
|
## Action items
|
||||||
|
|
||||||
|
- [ ] Marco: preparare wireframe homepage entro 14 marzo — ALTA PRIORITÀ
|
||||||
|
- [ ] Giulia: setup progetto Next.js e configurare CI/CD — media priorità
|
||||||
|
- [ ] Marco: ricerca plugin Stripe per gestione abbonamenti — bassa priorità
|
||||||
|
- [x] Giulia: inviare contratto firmato al cliente — COMPLETATO
|
||||||
|
|
||||||
|
## Note aggiuntive
|
||||||
|
|
||||||
|
Il cliente vuole un design minimalista, ispirato a Zara.com.
|
||||||
|
Colori primari: nero, bianco, oro.
|
||||||
|
Font: Inter per body, Playfair Display per headings.
|
||||||
|
|
||||||
|
Prossimo meeting: 24 marzo 2026 ore 15:00.
|
||||||
471
services/batch-agent/eval/interactive.py
Normal file
471
services/batch-agent/eval/interactive.py
Normal file
@@ -0,0 +1,471 @@
|
|||||||
|
"""Interactive journey session — human-in-the-loop CLI conversation.
|
||||||
|
|
||||||
|
Flow:
|
||||||
|
1. Show the system prompt used by the journey AI.
|
||||||
|
2. Start the journey (AI explores files, asks first question).
|
||||||
|
3. User types responses in the terminal — AI replies.
|
||||||
|
4. User types `/done` to end the conversation.
|
||||||
|
5. User writes a comment about the interaction quality.
|
||||||
|
6. LLM judge scores the conversation + generated template.
|
||||||
|
7. Results are reported to Langfuse.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
python -m eval interactive # pick a fixture interactively
|
||||||
|
python -m eval interactive --fixture=journey-invoice-setup
|
||||||
|
python -m eval interactive --model=gpt-4o
|
||||||
|
python -m eval interactive --judge-model=github_copilot/gpt-4o-mini
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
from eval.config import JourneyFixture, discover_journey_fixtures
|
||||||
|
from eval.mock_executor import MockExecutor
|
||||||
|
from eval import langfuse_eval
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Special commands ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_CMD_DONE = "/done"
|
||||||
|
_CMD_QUIT = "/quit"
|
||||||
|
_CMD_TEMPLATE = "/template"
|
||||||
|
_CMD_HELP = "/help"
|
||||||
|
|
||||||
|
_HELP_TEXT = f"""\
|
||||||
|
{_CMD_DONE} — End the conversation and proceed to evaluation
|
||||||
|
{_CMD_QUIT} — Abort without evaluation
|
||||||
|
{_CMD_TEMPLATE} — Show the generated template (if any)
|
||||||
|
{_CMD_HELP} — Show this help"""
|
||||||
|
|
||||||
|
# ── Terminal colours (ANSI) ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
_C_RESET = "\033[0m"
|
||||||
|
_C_BOLD = "\033[1m"
|
||||||
|
_C_DIM = "\033[2m"
|
||||||
|
_C_CYAN = "\033[36m"
|
||||||
|
_C_GREEN = "\033[32m"
|
||||||
|
_C_YELLOW = "\033[33m"
|
||||||
|
_C_MAGENTA = "\033[35m"
|
||||||
|
_C_RED = "\033[31m"
|
||||||
|
_C_BLUE = "\033[34m"
|
||||||
|
|
||||||
|
|
||||||
|
def _print_header(text: str) -> None:
|
||||||
|
print(f"\n{_C_BOLD}{_C_CYAN}{'═' * 80}")
|
||||||
|
print(f" {text}")
|
||||||
|
print(f"{'═' * 80}{_C_RESET}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _print_ai(text: str) -> None:
|
||||||
|
print(f"\n{_C_GREEN}{_C_BOLD}AI:{_C_RESET} {text}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _print_system(text: str) -> None:
|
||||||
|
print(f"{_C_DIM}{text}{_C_RESET}")
|
||||||
|
|
||||||
|
|
||||||
|
def _print_score(label: str, score: float) -> None:
|
||||||
|
if score >= 0.7:
|
||||||
|
color = _C_GREEN
|
||||||
|
tag = "PASS"
|
||||||
|
elif score >= 0.4:
|
||||||
|
color = _C_YELLOW
|
||||||
|
tag = "PARTIAL"
|
||||||
|
else:
|
||||||
|
color = _C_RED
|
||||||
|
tag = "FAIL"
|
||||||
|
print(f" {color}{tag:>7}{_C_RESET} ({score:.1f}) {label}")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Result type ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InteractiveResult:
|
||||||
|
fixture_name: str
|
||||||
|
model: str
|
||||||
|
judge_model: str
|
||||||
|
prompt_template: str | None
|
||||||
|
conversation: list[dict[str, str]]
|
||||||
|
user_comment: str
|
||||||
|
done: bool
|
||||||
|
criteria_scores: dict[str, float]
|
||||||
|
overall_score: float
|
||||||
|
judge_reasoning: str
|
||||||
|
elapsed_seconds: float
|
||||||
|
|
||||||
|
def summary(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"fixture": self.fixture_name,
|
||||||
|
"model": self.model,
|
||||||
|
"judge_model": self.judge_model,
|
||||||
|
"done": self.done,
|
||||||
|
"turns": len([c for c in self.conversation if c["role"] == "user"]),
|
||||||
|
"overall_score": round(self.overall_score, 3),
|
||||||
|
"user_comment": self.user_comment,
|
||||||
|
"criteria_scores": {k: round(v, 3) for k, v in self.criteria_scores.items()},
|
||||||
|
"elapsed_s": round(self.elapsed_seconds, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM judge ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_INTERACTIVE_JUDGE_SYSTEM = """\
|
||||||
|
You are an evaluation judge for AI-generated prompt templates produced during
|
||||||
|
an interactive conversation between a human and a journey chatbot.
|
||||||
|
|
||||||
|
The chatbot explored a directory and through multi-turn conversation with the
|
||||||
|
user produced a prompt_template — an instruction set for a data-extraction agent.
|
||||||
|
|
||||||
|
You have access to:
|
||||||
|
- The full conversation transcript
|
||||||
|
- The generated prompt_template (if any)
|
||||||
|
- The user's own comment about the interaction
|
||||||
|
- A list of quality criteria
|
||||||
|
|
||||||
|
Score each criterion from 0 to 1:
|
||||||
|
- 1.0: Fully satisfied
|
||||||
|
- 0.5: Partially satisfied
|
||||||
|
- 0.0: Not satisfied
|
||||||
|
|
||||||
|
Also provide an overall_quality score (0-1) evaluating the conversation flow,
|
||||||
|
how well the AI understood the user, and the template quality.
|
||||||
|
|
||||||
|
Respond with ONLY a JSON object:
|
||||||
|
{
|
||||||
|
"criteria_scores": {"criterion_1": 0.8, ...},
|
||||||
|
"overall_quality": 0.85,
|
||||||
|
"reasoning": "Brief explanation covering both conversation quality and template accuracy"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def _judge_interactive(
|
||||||
|
conversation: list[dict[str, str]],
|
||||||
|
prompt_template: str | None,
|
||||||
|
user_comment: str,
|
||||||
|
criteria: list[str],
|
||||||
|
*,
|
||||||
|
judge_model: str = "gpt-4o-mini",
|
||||||
|
) -> tuple[dict[str, float], float, str]:
|
||||||
|
"""Score an interactive session. Returns (criteria_scores, overall_quality, reasoning)."""
|
||||||
|
from shared.llm import get_llm
|
||||||
|
|
||||||
|
llm = get_llm(model=judge_model, temperature=0)
|
||||||
|
|
||||||
|
conv_text = "\n".join(
|
||||||
|
f"{'USER' if t['role'] == 'user' else 'AI'}: {t['content']}"
|
||||||
|
for t in conversation
|
||||||
|
)
|
||||||
|
criteria_text = "\n".join(f" {i+1}. {c}" for i, c in enumerate(criteria))
|
||||||
|
|
||||||
|
user_content = (
|
||||||
|
f"## Conversation transcript\n```\n{conv_text}\n```\n\n"
|
||||||
|
f"## Generated prompt_template\n```\n{prompt_template or '(none — conversation did not complete)'}\n```\n\n"
|
||||||
|
f"## User's comment\n{user_comment}\n\n"
|
||||||
|
f"## Criteria to evaluate\n{criteria_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await llm.ainvoke([
|
||||||
|
SystemMessage(content=_INTERACTIVE_JUDGE_SYSTEM),
|
||||||
|
HumanMessage(content=user_content),
|
||||||
|
])
|
||||||
|
raw = response.content.strip()
|
||||||
|
if raw.startswith("```"):
|
||||||
|
raw = raw.split("```")[1]
|
||||||
|
if raw.startswith("json"):
|
||||||
|
raw = raw[4:]
|
||||||
|
parsed = json.loads(raw.strip())
|
||||||
|
|
||||||
|
scores_raw = parsed.get("criteria_scores", parsed.get("scores", {}))
|
||||||
|
criteria_scores: dict[str, float] = {}
|
||||||
|
for i, criterion in enumerate(criteria):
|
||||||
|
key_candidates = [f"criterion_{i+1}", criterion, criterion[:50], str(i + 1)]
|
||||||
|
score = 0.0
|
||||||
|
for key in key_candidates:
|
||||||
|
if key in scores_raw:
|
||||||
|
score = float(scores_raw[key])
|
||||||
|
break
|
||||||
|
if score == 0.0 and i < len(scores_raw):
|
||||||
|
score = float(list(scores_raw.values())[i])
|
||||||
|
criteria_scores[criterion] = score
|
||||||
|
|
||||||
|
overall = float(parsed.get("overall_quality", 0.0))
|
||||||
|
reasoning = str(parsed.get("reasoning", ""))
|
||||||
|
return criteria_scores, overall, reasoning
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("interactive judge failed: %s", exc)
|
||||||
|
return {c: 0.0 for c in criteria}, 0.0, f"Judge error: {exc}"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Interactive session ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def run_interactive(
|
||||||
|
fixture: JourneyFixture,
|
||||||
|
*,
|
||||||
|
model: str = "gpt-4o",
|
||||||
|
judge_model: str = "gpt-4o-mini",
|
||||||
|
data_dir: Path | None = None,
|
||||||
|
) -> InteractiveResult:
|
||||||
|
"""Run an interactive journey session in the terminal.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data_dir :
|
||||||
|
If set, overrides the fixture's sample-file directory. The LLM
|
||||||
|
will explore this folder instead of the default
|
||||||
|
``fixtures/sample_files/…``. Useful for private test data that
|
||||||
|
shouldn't be committed to git.
|
||||||
|
"""
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.ws_context import set_current_user, clear_current_user
|
||||||
|
from app.journey import (
|
||||||
|
handle_journey_start,
|
||||||
|
handle_journey_message,
|
||||||
|
_build_system_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
# When --data-dir is given, the MockExecutor's root becomes
|
||||||
|
# data_dir's parent and the journey directory is data_dir's name.
|
||||||
|
# This way the LLM sees a meaningful directory name (not ".") and
|
||||||
|
# MockExecutor resolves paths correctly.
|
||||||
|
# Otherwise, use the fixture's YAML parent and its relative path.
|
||||||
|
if data_dir:
|
||||||
|
mock_root = data_dir.parent
|
||||||
|
journey_directory = data_dir.name
|
||||||
|
else:
|
||||||
|
mock_root = fixture.fixture_path.parent
|
||||||
|
journey_directory = fixture.directory
|
||||||
|
|
||||||
|
mock = MockExecutor(
|
||||||
|
fixture_dir=mock_root,
|
||||||
|
seed_records={},
|
||||||
|
)
|
||||||
|
|
||||||
|
original_model = settings.LLM_MODEL
|
||||||
|
settings.LLM_MODEL = model
|
||||||
|
eval_user_id = f"interactive-{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
# ── Show system prompt ───────────────────────────────────────
|
||||||
|
system_prompt = _build_system_prompt(journey_directory, fixture.data_types)
|
||||||
|
|
||||||
|
_print_header("SYSTEM PROMPT")
|
||||||
|
print(f"{_C_DIM}{system_prompt}{_C_RESET}")
|
||||||
|
|
||||||
|
_print_header(f"INTERACTIVE JOURNEY | fixture: {fixture.name} | model: {model}")
|
||||||
|
print(f" Data dir: {mock_root}")
|
||||||
|
print(f" Type your responses. Commands: {_CMD_DONE}, {_CMD_QUIT}, {_CMD_TEMPLATE}, {_CMD_HELP}")
|
||||||
|
print(f" Judge model: {judge_model}")
|
||||||
|
print(f" Criteria: {len(fixture.expected_template_criteria)}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
conversation: list[dict[str, str]] = []
|
||||||
|
prompt_template: str | None = None
|
||||||
|
done = False
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
set_current_user(eval_user_id)
|
||||||
|
|
||||||
|
with mock.patch():
|
||||||
|
# ── Start ────────────────────────────────────────────
|
||||||
|
_print_system("Starting journey... (AI is exploring your files)")
|
||||||
|
|
||||||
|
start_frame: dict[str, Any] = {
|
||||||
|
"agent_type": "local",
|
||||||
|
"directory": journey_directory,
|
||||||
|
"data_types": fixture.data_types,
|
||||||
|
"session_id": f"interactive-{uuid.uuid4().hex[:8]}",
|
||||||
|
}
|
||||||
|
|
||||||
|
reply = await handle_journey_start(eval_user_id, start_frame)
|
||||||
|
session_id = reply["session_id"]
|
||||||
|
conversation.append({"role": "assistant", "content": reply["message"]})
|
||||||
|
_print_ai(reply["message"])
|
||||||
|
|
||||||
|
if reply["done"]:
|
||||||
|
prompt_template = reply.get("prompt_template")
|
||||||
|
done = True
|
||||||
|
_print_system("Journey completed on first reply (template generated).")
|
||||||
|
|
||||||
|
# ── Conversation loop ────────────────────────────────
|
||||||
|
while not done:
|
||||||
|
try:
|
||||||
|
user_input = input(f"{_C_BOLD}{_C_BLUE}YOU:{_C_RESET} ").strip()
|
||||||
|
except (EOFError, KeyboardInterrupt):
|
||||||
|
print()
|
||||||
|
user_input = _CMD_QUIT
|
||||||
|
|
||||||
|
if not user_input:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle commands
|
||||||
|
if user_input.lower() == _CMD_QUIT:
|
||||||
|
_print_system("Aborted — no evaluation will be performed.")
|
||||||
|
settings.LLM_MODEL = original_model
|
||||||
|
clear_current_user()
|
||||||
|
return InteractiveResult(
|
||||||
|
fixture_name=fixture.name, model=model, judge_model=judge_model,
|
||||||
|
prompt_template=None, conversation=conversation,
|
||||||
|
user_comment="(aborted)", done=False,
|
||||||
|
criteria_scores={}, overall_score=0.0,
|
||||||
|
judge_reasoning="Session aborted by user.",
|
||||||
|
elapsed_seconds=time.time() - start_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_input.lower() == _CMD_HELP:
|
||||||
|
print(_HELP_TEXT)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if user_input.lower() == _CMD_TEMPLATE:
|
||||||
|
if prompt_template:
|
||||||
|
print(f"\n{_C_MAGENTA}{prompt_template}{_C_RESET}\n")
|
||||||
|
else:
|
||||||
|
_print_system("No template generated yet.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if user_input.lower() == _CMD_DONE:
|
||||||
|
_print_system("Ending conversation...")
|
||||||
|
break
|
||||||
|
|
||||||
|
# ── Send message to AI ───────────────────────────
|
||||||
|
conversation.append({"role": "user", "content": user_input})
|
||||||
|
_print_system("AI is thinking...")
|
||||||
|
|
||||||
|
msg_frame: dict[str, Any] = {
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": user_input,
|
||||||
|
}
|
||||||
|
reply = await handle_journey_message(eval_user_id, msg_frame)
|
||||||
|
conversation.append({"role": "assistant", "content": reply["message"]})
|
||||||
|
_print_ai(reply["message"])
|
||||||
|
|
||||||
|
if reply["done"]:
|
||||||
|
prompt_template = reply.get("prompt_template")
|
||||||
|
done = True
|
||||||
|
_print_system("Journey completed — template generated!")
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("interactive journey failed: %s", exc)
|
||||||
|
_print_system(f"Error: {exc}")
|
||||||
|
finally:
|
||||||
|
settings.LLM_MODEL = original_model
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
turns = len([c for c in conversation if c["role"] == "user"])
|
||||||
|
|
||||||
|
# ── Show template if generated ───────────────────────────────
|
||||||
|
if prompt_template:
|
||||||
|
_print_header("GENERATED TEMPLATE")
|
||||||
|
print(f"{_C_MAGENTA}{prompt_template}{_C_RESET}\n")
|
||||||
|
else:
|
||||||
|
_print_system("No template was generated during this session.")
|
||||||
|
|
||||||
|
# ── User comment ─────────────────────────────────────────────
|
||||||
|
_print_header("YOUR EVALUATION")
|
||||||
|
print(" Write your comment about this interaction (press Enter twice to finish):")
|
||||||
|
print()
|
||||||
|
comment_lines: list[str] = []
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
line = input()
|
||||||
|
if line == "" and comment_lines and comment_lines[-1] == "":
|
||||||
|
comment_lines.pop() # remove trailing empty
|
||||||
|
break
|
||||||
|
comment_lines.append(line)
|
||||||
|
except (EOFError, KeyboardInterrupt):
|
||||||
|
pass
|
||||||
|
user_comment = "\n".join(comment_lines).strip() or "(no comment)"
|
||||||
|
|
||||||
|
# ── Judge ────────────────────────────────────────────────────
|
||||||
|
_print_header("LLM JUDGE EVALUATION")
|
||||||
|
_print_system(f"Scoring with {judge_model}...")
|
||||||
|
|
||||||
|
criteria_scores, overall_quality, judge_reasoning = await _judge_interactive(
|
||||||
|
conversation=conversation,
|
||||||
|
prompt_template=prompt_template,
|
||||||
|
user_comment=user_comment,
|
||||||
|
criteria=fixture.expected_template_criteria,
|
||||||
|
judge_model=judge_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Display scores ───────────────────────────────────────────
|
||||||
|
print()
|
||||||
|
for criterion, score in criteria_scores.items():
|
||||||
|
_print_score(criterion, score)
|
||||||
|
|
||||||
|
overall = (
|
||||||
|
sum(criteria_scores.values()) / len(criteria_scores)
|
||||||
|
if criteria_scores
|
||||||
|
else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n {_C_BOLD}Criteria avg: {overall:.2f}{_C_RESET}")
|
||||||
|
print(f" {_C_BOLD}Overall quality: {overall_quality:.2f}{_C_RESET}")
|
||||||
|
print(f" {_C_BOLD}Turns: {turns}{_C_RESET}")
|
||||||
|
print(f" {_C_BOLD}Time: {elapsed:.1f}s{_C_RESET}")
|
||||||
|
print(f"\n {_C_DIM}Judge: {judge_reasoning}{_C_RESET}")
|
||||||
|
print(f" {_C_DIM}Your comment: {user_comment}{_C_RESET}\n")
|
||||||
|
|
||||||
|
result = InteractiveResult(
|
||||||
|
fixture_name=fixture.name,
|
||||||
|
model=model,
|
||||||
|
judge_model=judge_model,
|
||||||
|
prompt_template=prompt_template,
|
||||||
|
conversation=conversation,
|
||||||
|
user_comment=user_comment,
|
||||||
|
done=done,
|
||||||
|
criteria_scores=criteria_scores,
|
||||||
|
overall_score=overall_quality,
|
||||||
|
judge_reasoning=judge_reasoning,
|
||||||
|
elapsed_seconds=elapsed,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Report to Langfuse ───────────────────────────────────────
|
||||||
|
trace_id = langfuse_eval.log_eval_trace(
|
||||||
|
fixture_name=fixture.name,
|
||||||
|
model=model,
|
||||||
|
prompt_variant="interactive",
|
||||||
|
prompt_template=prompt_template or "(not generated)",
|
||||||
|
actual_mutations=[{
|
||||||
|
"conversation": conversation[:30],
|
||||||
|
"user_comment": user_comment,
|
||||||
|
}],
|
||||||
|
scores_summary=result.summary(),
|
||||||
|
langfuse_prompt_names=["journey_system"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if trace_id:
|
||||||
|
from eval.scorer import EvalScores
|
||||||
|
scores_obj = EvalScores(
|
||||||
|
fixture_name=fixture.name,
|
||||||
|
model=model,
|
||||||
|
prompt_variant="interactive",
|
||||||
|
precision=overall,
|
||||||
|
recall=float(done),
|
||||||
|
f1=overall,
|
||||||
|
llm_judge_score=overall_quality,
|
||||||
|
llm_judge_reasoning=judge_reasoning,
|
||||||
|
)
|
||||||
|
langfuse_eval.post_eval_scores(scores_obj, trace_id=trace_id)
|
||||||
|
_print_system(f"Results reported to Langfuse (trace: {trace_id})")
|
||||||
|
else:
|
||||||
|
_print_system("Langfuse not configured — results not reported.")
|
||||||
|
|
||||||
|
return result
|
||||||
385
services/batch-agent/eval/journey_runner.py
Normal file
385
services/batch-agent/eval/journey_runner.py
Normal file
@@ -0,0 +1,385 @@
|
|||||||
|
"""Journey eval runner — tests the prompt_template builder conversation.
|
||||||
|
|
||||||
|
For each (journey_fixture × model) combination:
|
||||||
|
1. Build a MockExecutor (for filesystem tools used during journey)
|
||||||
|
2. Patch execute_on_client
|
||||||
|
3. Override LLM_MODEL
|
||||||
|
4. Call handle_journey_start to kick off the conversation
|
||||||
|
5. Feed simulated user_messages via handle_journey_message
|
||||||
|
6. Collect the generated prompt_template
|
||||||
|
7. Score it against expected_template_criteria (via LLM judge)
|
||||||
|
8. Report to Langfuse
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
from eval.config import JourneyFixture
|
||||||
|
from eval.mock_executor import MockExecutor
|
||||||
|
from eval import langfuse_eval
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Result type ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class JourneyEvalResult:
|
||||||
|
"""Result of one journey eval run."""
|
||||||
|
|
||||||
|
fixture_name: str
|
||||||
|
model: str
|
||||||
|
prompt_template: str | None # the generated template (None if journey failed)
|
||||||
|
conversation_turns: int
|
||||||
|
done: bool # whether journey reached completion
|
||||||
|
criteria_scores: dict[str, float] # criterion → 0-1 score
|
||||||
|
overall_score: float # average of criteria scores
|
||||||
|
judge_reasoning: str
|
||||||
|
elapsed_seconds: float
|
||||||
|
|
||||||
|
def summary(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"fixture": self.fixture_name,
|
||||||
|
"model": self.model,
|
||||||
|
"done": self.done,
|
||||||
|
"turns": self.conversation_turns,
|
||||||
|
"overall_score": round(self.overall_score, 3),
|
||||||
|
"criteria_scores": {k: round(v, 3) for k, v in self.criteria_scores.items()},
|
||||||
|
"elapsed_s": round(self.elapsed_seconds, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM judge for template quality ──────────────────────────────────────
|
||||||
|
|
||||||
|
_JOURNEY_JUDGE_SYSTEM = """\
|
||||||
|
You are an evaluation judge for AI-generated prompt templates.
|
||||||
|
|
||||||
|
A journey chatbot explored a user's directory structure and through
|
||||||
|
conversation produced a prompt_template — an instruction set for a
|
||||||
|
data-extraction agent.
|
||||||
|
|
||||||
|
Your task: evaluate the generated template against a list of criteria.
|
||||||
|
Score each criterion from 0 to 1:
|
||||||
|
- 1.0: Fully satisfied, clearly present in the template
|
||||||
|
- 0.5: Partially satisfied or ambiguously addressed
|
||||||
|
- 0.0: Not satisfied, missing from the template
|
||||||
|
|
||||||
|
Respond with ONLY a JSON object:
|
||||||
|
{
|
||||||
|
"scores": {"criterion_1": 0.8, "criterion_2": 1.0, ...},
|
||||||
|
"reasoning": "Brief explanation"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def _judge_template(
|
||||||
|
prompt_template: str,
|
||||||
|
criteria: list[str],
|
||||||
|
*,
|
||||||
|
judge_model: str = "gpt-4o-mini",
|
||||||
|
) -> tuple[dict[str, float], str]:
|
||||||
|
"""Use an LLM to evaluate a generated prompt_template against criteria.
|
||||||
|
|
||||||
|
Returns (criteria_scores, reasoning).
|
||||||
|
"""
|
||||||
|
from shared.llm import get_llm
|
||||||
|
|
||||||
|
llm = get_llm(model=judge_model, temperature=0)
|
||||||
|
|
||||||
|
criteria_text = "\n".join(f" {i+1}. {c}" for i, c in enumerate(criteria))
|
||||||
|
user_content = (
|
||||||
|
f"## Generated prompt_template\n```\n{prompt_template}\n```\n\n"
|
||||||
|
f"## Criteria to evaluate\n{criteria_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await llm.ainvoke([
|
||||||
|
SystemMessage(content=_JOURNEY_JUDGE_SYSTEM),
|
||||||
|
HumanMessage(content=user_content),
|
||||||
|
])
|
||||||
|
raw = response.content.strip()
|
||||||
|
if raw.startswith("```"):
|
||||||
|
raw = raw.split("```")[1]
|
||||||
|
if raw.startswith("json"):
|
||||||
|
raw = raw[4:]
|
||||||
|
parsed = json.loads(raw.strip())
|
||||||
|
|
||||||
|
scores_raw = parsed.get("scores", {})
|
||||||
|
# Map criterion keys back to the original criteria text
|
||||||
|
criteria_scores: dict[str, float] = {}
|
||||||
|
for i, criterion in enumerate(criteria):
|
||||||
|
# Try matching by index key or exact criterion text
|
||||||
|
key_candidates = [
|
||||||
|
f"criterion_{i+1}",
|
||||||
|
criterion,
|
||||||
|
criterion[:50],
|
||||||
|
str(i + 1),
|
||||||
|
]
|
||||||
|
score = 0.0
|
||||||
|
for key in key_candidates:
|
||||||
|
if key in scores_raw:
|
||||||
|
score = float(scores_raw[key])
|
||||||
|
break
|
||||||
|
# If no match found, try values in order
|
||||||
|
if score == 0.0 and i < len(scores_raw):
|
||||||
|
score = float(list(scores_raw.values())[i])
|
||||||
|
criteria_scores[criterion] = score
|
||||||
|
|
||||||
|
reasoning = str(parsed.get("reasoning", ""))
|
||||||
|
return criteria_scores, reasoning
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("journey_eval: LLM judge failed: %s", exc)
|
||||||
|
return {c: 0.0 for c in criteria}, f"Judge error: {exc}"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Journey runner ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def run_single_journey_eval(
|
||||||
|
fixture: JourneyFixture,
|
||||||
|
model: str,
|
||||||
|
*,
|
||||||
|
judge_model: str = "gpt-4o-mini",
|
||||||
|
data_dir: Path | None = None,
|
||||||
|
) -> JourneyEvalResult:
|
||||||
|
"""Execute one journey eval: start \u2192 messages \u2192 score template."""
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
# When data_dir is given, use its parent as MockExecutor root
|
||||||
|
# and its name as the journey directory so the LLM sees a
|
||||||
|
# meaningful path (not ".").
|
||||||
|
if data_dir:
|
||||||
|
mock_root = data_dir.parent
|
||||||
|
journey_directory = data_dir.name
|
||||||
|
else:
|
||||||
|
mock_root = fixture.fixture_path.parent
|
||||||
|
journey_directory = fixture.directory
|
||||||
|
|
||||||
|
mock = MockExecutor(
|
||||||
|
fixture_dir=mock_root,
|
||||||
|
seed_records={},
|
||||||
|
)
|
||||||
|
|
||||||
|
original_model = settings.LLM_MODEL
|
||||||
|
settings.LLM_MODEL = model
|
||||||
|
|
||||||
|
eval_user_id = f"eval-journey-{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"journey_eval: starting %s | model=%s",
|
||||||
|
fixture.name, model,
|
||||||
|
)
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
prompt_template: str | None = None
|
||||||
|
conversation: list[dict[str, str]] = []
|
||||||
|
done = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from shared.ws_context import set_current_user, clear_current_user
|
||||||
|
from app.journey import handle_journey_start, handle_journey_message, _sessions
|
||||||
|
|
||||||
|
set_current_user(eval_user_id)
|
||||||
|
with mock.patch():
|
||||||
|
# ── Start the journey ────────────────────────────────
|
||||||
|
start_frame: dict[str, Any] = {
|
||||||
|
"agent_type": "local",
|
||||||
|
"directory": journey_directory,
|
||||||
|
"data_types": fixture.data_types,
|
||||||
|
"session_id": f"eval-{uuid.uuid4().hex[:8]}",
|
||||||
|
}
|
||||||
|
|
||||||
|
reply = await handle_journey_start(eval_user_id, start_frame)
|
||||||
|
session_id = reply["session_id"]
|
||||||
|
conversation.append({"role": "assistant", "content": reply["message"]})
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"journey_eval: start reply (%d chars), done=%s",
|
||||||
|
len(reply["message"]), reply["done"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if reply["done"]:
|
||||||
|
prompt_template = reply.get("prompt_template")
|
||||||
|
done = True
|
||||||
|
else:
|
||||||
|
# ── Send user messages ───────────────────────────
|
||||||
|
for i, user_msg in enumerate(fixture.user_messages):
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
|
||||||
|
conversation.append({"role": "user", "content": user_msg})
|
||||||
|
|
||||||
|
msg_frame: dict[str, Any] = {
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": user_msg,
|
||||||
|
}
|
||||||
|
reply = await handle_journey_message(eval_user_id, msg_frame)
|
||||||
|
conversation.append({"role": "assistant", "content": reply["message"]})
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"journey_eval: turn %d reply (%d chars), done=%s",
|
||||||
|
i + 1, len(reply["message"]), reply["done"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if reply["done"]:
|
||||||
|
prompt_template = reply.get("prompt_template")
|
||||||
|
done = True
|
||||||
|
|
||||||
|
# If not done after all user messages, send a final nudge
|
||||||
|
if not done:
|
||||||
|
nudge = "Please generate the final prompt_template now. I'm satisfied with the configuration."
|
||||||
|
conversation.append({"role": "user", "content": nudge})
|
||||||
|
|
||||||
|
nudge_frame: dict[str, Any] = {
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": nudge,
|
||||||
|
}
|
||||||
|
reply = await handle_journey_message(eval_user_id, nudge_frame)
|
||||||
|
conversation.append({"role": "assistant", "content": reply["message"]})
|
||||||
|
if reply["done"]:
|
||||||
|
prompt_template = reply.get("prompt_template")
|
||||||
|
done = True
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("journey_eval: pipeline failed for %s/%s: %s", fixture.name, model, exc)
|
||||||
|
finally:
|
||||||
|
settings.LLM_MODEL = original_model
|
||||||
|
from shared.ws_context import clear_current_user
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
turns = len([c for c in conversation if c["role"] == "user"])
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"journey_eval: completed in %.1fs — %d turns, done=%s, template=%s",
|
||||||
|
elapsed, turns, done, "yes" if prompt_template else "no",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Score the template ───────────────────────────────────────
|
||||||
|
criteria_scores: dict[str, float] = {}
|
||||||
|
judge_reasoning = ""
|
||||||
|
|
||||||
|
if prompt_template and fixture.expected_template_criteria:
|
||||||
|
criteria_scores, judge_reasoning = await _judge_template(
|
||||||
|
prompt_template,
|
||||||
|
fixture.expected_template_criteria,
|
||||||
|
judge_model=judge_model,
|
||||||
|
)
|
||||||
|
elif not prompt_template:
|
||||||
|
criteria_scores = {c: 0.0 for c in fixture.expected_template_criteria}
|
||||||
|
judge_reasoning = "No prompt_template was generated — journey did not complete."
|
||||||
|
|
||||||
|
overall = (
|
||||||
|
sum(criteria_scores.values()) / len(criteria_scores)
|
||||||
|
if criteria_scores
|
||||||
|
else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
result = JourneyEvalResult(
|
||||||
|
fixture_name=fixture.name,
|
||||||
|
model=model,
|
||||||
|
prompt_template=prompt_template,
|
||||||
|
conversation_turns=turns,
|
||||||
|
done=done,
|
||||||
|
criteria_scores=criteria_scores,
|
||||||
|
overall_score=overall,
|
||||||
|
judge_reasoning=judge_reasoning,
|
||||||
|
elapsed_seconds=elapsed,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Report to Langfuse ───────────────────────────────────────
|
||||||
|
trace_id = langfuse_eval.log_eval_trace(
|
||||||
|
fixture_name=fixture.name,
|
||||||
|
model=model,
|
||||||
|
prompt_variant="journey",
|
||||||
|
prompt_template=prompt_template or "(not generated)",
|
||||||
|
actual_mutations=[{"conversation": conversation[:20]}],
|
||||||
|
scores_summary=result.summary(),
|
||||||
|
langfuse_prompt_names=["journey_system"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if trace_id:
|
||||||
|
from eval.scorer import EvalScores
|
||||||
|
scores_obj = EvalScores(
|
||||||
|
fixture_name=fixture.name,
|
||||||
|
model=model,
|
||||||
|
prompt_variant="journey",
|
||||||
|
precision=overall,
|
||||||
|
recall=float(done),
|
||||||
|
f1=overall,
|
||||||
|
llm_judge_score=overall,
|
||||||
|
llm_judge_reasoning=judge_reasoning,
|
||||||
|
)
|
||||||
|
langfuse_eval.post_eval_scores(scores_obj, trace_id=trace_id)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def run_journey_fixture_eval(
|
||||||
|
fixture: JourneyFixture,
|
||||||
|
models: list[str],
|
||||||
|
*,
|
||||||
|
judge_model: str = "gpt-4o-mini",
|
||||||
|
data_dir: Path | None = None,
|
||||||
|
) -> list[JourneyEvalResult]:
|
||||||
|
"""Run all models for a journey fixture."""
|
||||||
|
langfuse_eval.sync_journey_fixture_to_dataset(fixture)
|
||||||
|
|
||||||
|
results: list[JourneyEvalResult] = []
|
||||||
|
for model in models:
|
||||||
|
result = await run_single_journey_eval(
|
||||||
|
fixture, model, judge_model=judge_model,
|
||||||
|
data_dir=data_dir,
|
||||||
|
)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def print_journey_results(results: list[JourneyEvalResult]) -> None:
|
||||||
|
"""Print a formatted summary of journey eval results."""
|
||||||
|
if not results:
|
||||||
|
print("\nNo journey eval results.")
|
||||||
|
return
|
||||||
|
|
||||||
|
print("\n" + "=" * 95)
|
||||||
|
print(f"{'Fixture':<25} {'Model':<25} {'Done':>5} {'Turns':>6} {'Score':>7} {'Time':>7}")
|
||||||
|
print("-" * 95)
|
||||||
|
|
||||||
|
for r in results:
|
||||||
|
done_str = "yes" if r.done else "NO"
|
||||||
|
print(
|
||||||
|
f"{r.fixture_name:<25} {r.model:<25} {done_str:>5} "
|
||||||
|
f"{r.conversation_turns:>6} {r.overall_score:>7.2f} {r.elapsed_seconds:>6.1f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("=" * 95)
|
||||||
|
|
||||||
|
# Criteria breakdown
|
||||||
|
for r in results:
|
||||||
|
if r.criteria_scores:
|
||||||
|
print(f"\n[{r.model}] Criteria scores:")
|
||||||
|
for criterion, score in r.criteria_scores.items():
|
||||||
|
indicator = "PASS" if score >= 0.7 else "PARTIAL" if score >= 0.4 else "FAIL"
|
||||||
|
print(f" {indicator:>7} ({score:.1f}) {criterion}")
|
||||||
|
|
||||||
|
if r.judge_reasoning:
|
||||||
|
print(f" Judge: {r.judge_reasoning}")
|
||||||
|
|
||||||
|
if r.prompt_template:
|
||||||
|
preview = r.prompt_template[:200].replace("\n", " ")
|
||||||
|
print(f" Template preview: {preview}...")
|
||||||
|
|
||||||
|
print()
|
||||||
327
services/batch-agent/eval/langfuse_eval.py
Normal file
327
services/batch-agent/eval/langfuse_eval.py
Normal file
@@ -0,0 +1,327 @@
|
|||||||
|
"""Langfuse evaluation integration — datasets, runs, and scoring.
|
||||||
|
|
||||||
|
Uses the Langfuse Python SDK v4 (OpenTelemetry-based) to:
|
||||||
|
|
||||||
|
1. **Sync fixtures → Langfuse datasets**: Each YAML fixture becomes a dataset,
|
||||||
|
each prompt variant + expected pair becomes a dataset item.
|
||||||
|
|
||||||
|
2. **Track eval runs**: Each (fixture × model × prompt_variant) execution
|
||||||
|
is recorded as a trace with linked scores.
|
||||||
|
|
||||||
|
3. **Post scores**: precision, recall, F1, field_accuracy, llm_judge are
|
||||||
|
posted as numeric scores on the trace.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from eval.config import EvalFixture
|
||||||
|
from eval.scorer import EvalScores
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_langfuse():
|
||||||
|
"""Get or create a Langfuse client instance (SDK v4)."""
|
||||||
|
if not settings.LANGFUSE_SECRET_KEY or not settings.LANGFUSE_PUBLIC_KEY:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
os.environ.setdefault("LANGFUSE_SECRET_KEY", settings.LANGFUSE_SECRET_KEY)
|
||||||
|
os.environ.setdefault("LANGFUSE_PUBLIC_KEY", settings.LANGFUSE_PUBLIC_KEY)
|
||||||
|
if settings.LANGFUSE_HOST:
|
||||||
|
os.environ.setdefault("LANGFUSE_HOST", settings.LANGFUSE_HOST)
|
||||||
|
from langfuse import get_client
|
||||||
|
return get_client()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("langfuse_eval: failed to create client: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def sync_fixture_to_dataset(fixture: EvalFixture) -> str | None:
|
||||||
|
"""Create or update a Langfuse dataset from a fixture.
|
||||||
|
|
||||||
|
Each prompt variant becomes a separate dataset item with:
|
||||||
|
- input: {directory, data_types, prompt_template, seed_records}
|
||||||
|
- expected_output: {expected records}
|
||||||
|
|
||||||
|
Returns the dataset name, or None if Langfuse is unavailable.
|
||||||
|
"""
|
||||||
|
lf = _get_langfuse()
|
||||||
|
if lf is None:
|
||||||
|
logger.info("langfuse_eval: Langfuse not configured — skipping dataset sync")
|
||||||
|
return None
|
||||||
|
|
||||||
|
dataset_name = f"batch-eval-{fixture.name}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
lf.create_dataset(
|
||||||
|
name=dataset_name,
|
||||||
|
description=fixture.description,
|
||||||
|
metadata={
|
||||||
|
"data_types": ",".join(fixture.data_types),
|
||||||
|
"file_extensions": ",".join(fixture.file_extensions) if fixture.file_extensions else "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Dataset may already exist — that's fine
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Build expected_output appropriate to the fixture's mode
|
||||||
|
expected_output: dict[str, Any] = {}
|
||||||
|
if fixture.mode in ("step1", "full") and fixture.expected_classification:
|
||||||
|
expected_output["classifications"] = [
|
||||||
|
{"file": ec.file, "project_id": ec.project_id, "domains": ec.domains}
|
||||||
|
for ec in fixture.expected_classification
|
||||||
|
]
|
||||||
|
if fixture.mode in ("step2", "full") and fixture.expected:
|
||||||
|
for rec in fixture.expected:
|
||||||
|
expected_output.setdefault(rec.table, []).append(rec.fields)
|
||||||
|
|
||||||
|
item_id = f"{fixture.name}--{fixture.mode}"
|
||||||
|
try:
|
||||||
|
lf.create_dataset_item(
|
||||||
|
dataset_name=dataset_name,
|
||||||
|
id=item_id,
|
||||||
|
input={
|
||||||
|
"directory": fixture.directory,
|
||||||
|
"data_types": fixture.data_types,
|
||||||
|
"mode": fixture.mode,
|
||||||
|
"seed_records": fixture.seed_records,
|
||||||
|
},
|
||||||
|
expected_output=expected_output,
|
||||||
|
metadata={"mode": fixture.mode},
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"langfuse_eval: failed to upsert dataset item %s: %s", item_id, exc
|
||||||
|
)
|
||||||
|
|
||||||
|
lf.flush()
|
||||||
|
logger.info("langfuse_eval: synced fixture '%s' → dataset '%s'", fixture.name, dataset_name)
|
||||||
|
return dataset_name
|
||||||
|
|
||||||
|
|
||||||
|
def sync_journey_fixture_to_dataset(fixture) -> str | None:
|
||||||
|
"""Create or update a Langfuse dataset from a journey fixture.
|
||||||
|
|
||||||
|
Each journey fixture becomes a single dataset item with:
|
||||||
|
- input: {directory, data_types, user_messages}
|
||||||
|
- expected_output: {criteria}
|
||||||
|
"""
|
||||||
|
lf = _get_langfuse()
|
||||||
|
if lf is None:
|
||||||
|
logger.info("langfuse_eval: Langfuse not configured — skipping journey dataset sync")
|
||||||
|
return None
|
||||||
|
|
||||||
|
dataset_name = f"journey-eval-{fixture.name}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
lf.create_dataset(
|
||||||
|
name=dataset_name,
|
||||||
|
description=fixture.description,
|
||||||
|
metadata={"type": "journey", "data_types": ",".join(fixture.data_types)},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Dataset may already exist
|
||||||
|
|
||||||
|
item_id = f"{fixture.name}--journey"
|
||||||
|
try:
|
||||||
|
lf.create_dataset_item(
|
||||||
|
dataset_name=dataset_name,
|
||||||
|
id=item_id,
|
||||||
|
input={
|
||||||
|
"directory": fixture.directory,
|
||||||
|
"data_types": fixture.data_types,
|
||||||
|
"user_messages": fixture.user_messages,
|
||||||
|
},
|
||||||
|
expected_output={
|
||||||
|
"criteria": fixture.expected_template_criteria,
|
||||||
|
},
|
||||||
|
metadata={"type": "journey"},
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("langfuse_eval: failed to upsert journey dataset item %s: %s", item_id, exc)
|
||||||
|
|
||||||
|
lf.flush()
|
||||||
|
logger.info("langfuse_eval: synced journey fixture '%s' → dataset '%s'", fixture.name, dataset_name)
|
||||||
|
return dataset_name
|
||||||
|
|
||||||
|
|
||||||
|
def create_eval_run(
|
||||||
|
dataset_name: str,
|
||||||
|
run_name: str,
|
||||||
|
*,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Create a dataset run in Langfuse. Returns the run name.
|
||||||
|
|
||||||
|
Note: In SDK v4, dataset runs are created implicitly via
|
||||||
|
dataset.run_experiment(). This function is kept for backwards
|
||||||
|
compatibility but may not create a run.
|
||||||
|
"""
|
||||||
|
lf = _get_langfuse()
|
||||||
|
if lf is None:
|
||||||
|
return run_name
|
||||||
|
|
||||||
|
try:
|
||||||
|
if hasattr(lf, "create_dataset_run"):
|
||||||
|
lf.create_dataset_run(
|
||||||
|
dataset_name=dataset_name,
|
||||||
|
run_name=run_name,
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
lf.flush()
|
||||||
|
else:
|
||||||
|
logger.debug("langfuse_eval: create_dataset_run not available in SDK v4")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("langfuse_eval: failed to create run %s: %s", run_name, exc)
|
||||||
|
|
||||||
|
return run_name
|
||||||
|
|
||||||
|
|
||||||
|
def post_eval_scores(
|
||||||
|
scores: EvalScores,
|
||||||
|
*,
|
||||||
|
trace_id: str | None = None,
|
||||||
|
dataset_name: str | None = None,
|
||||||
|
run_name: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Post evaluation scores to Langfuse.
|
||||||
|
|
||||||
|
If trace_id is provided, scores are attached to that trace.
|
||||||
|
"""
|
||||||
|
lf = _get_langfuse()
|
||||||
|
if lf is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
score_data = [
|
||||||
|
("precision", scores.precision),
|
||||||
|
("recall", scores.recall),
|
||||||
|
("f1", scores.f1),
|
||||||
|
]
|
||||||
|
# Only post field_accuracy when there are field-level scores (step2/full)
|
||||||
|
if scores.field_scores:
|
||||||
|
score_data.append(("field_accuracy", scores.field_accuracy))
|
||||||
|
if scores.llm_judge_score is not None:
|
||||||
|
score_data.append(("llm_judge", scores.llm_judge_score))
|
||||||
|
|
||||||
|
for name, value in score_data:
|
||||||
|
try:
|
||||||
|
lf.create_score(
|
||||||
|
name=name,
|
||||||
|
value=value,
|
||||||
|
trace_id=trace_id,
|
||||||
|
data_type="NUMERIC",
|
||||||
|
comment=f"{scores.fixture_name} | {scores.model} | {scores.prompt_variant}",
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("langfuse_eval: failed to post score %s: %s", name, exc)
|
||||||
|
|
||||||
|
lf.flush()
|
||||||
|
logger.info(
|
||||||
|
"langfuse_eval: posted %d scores for %s/%s/%s",
|
||||||
|
len(score_data), scores.fixture_name, scores.model, scores.prompt_variant,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def log_eval_trace(
|
||||||
|
*,
|
||||||
|
fixture_name: str,
|
||||||
|
model: str,
|
||||||
|
prompt_variant: str,
|
||||||
|
prompt_template: str,
|
||||||
|
actual_mutations: list[dict],
|
||||||
|
scores_summary: dict[str, Any],
|
||||||
|
step1_results: list[dict] | None = None,
|
||||||
|
dataset_name: str | None = None,
|
||||||
|
run_name: str | None = None,
|
||||||
|
dataset_item_id: str | None = None,
|
||||||
|
langfuse_prompt_names: list[str] | None = None,
|
||||||
|
) -> str | None:
|
||||||
|
"""Create a Langfuse trace for one eval execution and link it to a dataset run.
|
||||||
|
|
||||||
|
Uses SDK v4 observation API (traces are created implicitly by root spans).
|
||||||
|
``langfuse_prompt_names`` can contain one or two prompt names to link
|
||||||
|
(e.g. ``["batch_file_classifier", "batch_processing"]`` for full mode).
|
||||||
|
Each prompt gets its own generation-type observation for per-version
|
||||||
|
metrics tracking.
|
||||||
|
|
||||||
|
Returns the trace_id, or None if Langfuse is unavailable.
|
||||||
|
"""
|
||||||
|
lf = _get_langfuse()
|
||||||
|
if lf is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse import propagate_attributes
|
||||||
|
|
||||||
|
# Fetch prompt objects for linking
|
||||||
|
prompt_objs: list[tuple[str, Any]] = []
|
||||||
|
for pname in (langfuse_prompt_names or []):
|
||||||
|
try:
|
||||||
|
obj = lf.get_prompt(name=pname, cache_ttl_seconds=300)
|
||||||
|
prompt_objs.append((pname, obj))
|
||||||
|
logger.info("langfuse_eval: linked prompt '%s' (type=%s)", pname, type(obj).__name__)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("langfuse_eval: prompt '%s' not found — %s", pname, exc)
|
||||||
|
|
||||||
|
# Build trace output dict
|
||||||
|
trace_output: dict[str, Any] = {"scores": scores_summary}
|
||||||
|
if step1_results:
|
||||||
|
trace_output["classifications"] = step1_results
|
||||||
|
if actual_mutations:
|
||||||
|
trace_output["mutations"] = actual_mutations[:50]
|
||||||
|
|
||||||
|
with propagate_attributes(
|
||||||
|
trace_name=f"eval-{fixture_name}",
|
||||||
|
metadata={
|
||||||
|
"eval": "true",
|
||||||
|
"fixture": fixture_name,
|
||||||
|
"model": model,
|
||||||
|
"prompt_variant": prompt_variant,
|
||||||
|
},
|
||||||
|
tags=["eval", f"model:{model}", f"variant:{prompt_variant}"],
|
||||||
|
):
|
||||||
|
# Root span for the eval run
|
||||||
|
span = lf.start_observation(name=f"eval-{fixture_name}")
|
||||||
|
span.update(
|
||||||
|
input={
|
||||||
|
"prompt_template": prompt_template,
|
||||||
|
"model": model,
|
||||||
|
"prompt_variant": prompt_variant,
|
||||||
|
},
|
||||||
|
output=trace_output,
|
||||||
|
)
|
||||||
|
trace_id = span.trace_id
|
||||||
|
|
||||||
|
# Create a generation-type observation per linked prompt
|
||||||
|
for pname, pobj in prompt_objs:
|
||||||
|
gen = lf.start_observation(
|
||||||
|
name=f"prompt-{pname}",
|
||||||
|
prompt=pobj,
|
||||||
|
as_type="generation",
|
||||||
|
)
|
||||||
|
gen.end()
|
||||||
|
|
||||||
|
# Link to dataset run if available
|
||||||
|
if dataset_name and run_name and dataset_item_id:
|
||||||
|
try:
|
||||||
|
dataset = lf.get_dataset(dataset_name)
|
||||||
|
for item in dataset.items:
|
||||||
|
if item.id == dataset_item_id:
|
||||||
|
item.link(span, run_name)
|
||||||
|
break
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("langfuse_eval: failed to link trace to dataset run: %s", exc)
|
||||||
|
|
||||||
|
span.end()
|
||||||
|
|
||||||
|
lf.flush()
|
||||||
|
return trace_id
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("langfuse_eval: failed to create eval trace: %s", exc)
|
||||||
|
return None
|
||||||
258
services/batch-agent/eval/mock_executor.py
Normal file
258
services/batch-agent/eval/mock_executor.py
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
"""Mock executor — intercepts execute_on_client for offline E2E testing.
|
||||||
|
|
||||||
|
Patches ``execute_on_client`` at all usage sites so agent pipeline runs don't
|
||||||
|
require a live Electron client or Redis. Instead:
|
||||||
|
|
||||||
|
- **Filesystem actions** (list_directory, read_file_content, get_file_metadata)
|
||||||
|
are served from local fixture files on disk.
|
||||||
|
- **Read actions** (select, get) return preseeded records from an in-memory
|
||||||
|
store provided by the test fixture.
|
||||||
|
- **Write actions** (insert, update, delete) are captured as *mutations* and
|
||||||
|
stored for later comparison against expected results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from contextlib import contextmanager, asynccontextmanager
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Mutation:
|
||||||
|
"""A single recorded write operation."""
|
||||||
|
|
||||||
|
action: str # insert | update | delete
|
||||||
|
table: str
|
||||||
|
data: dict[str, Any]
|
||||||
|
timestamp: float = field(default_factory=time.time)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fake DB helpers (used to bypass async_session in full mode) ───────
|
||||||
|
|
||||||
|
class _FakeRow:
|
||||||
|
"""Mimics an AgentRunLog row returned by SQLAlchemy."""
|
||||||
|
id = 0
|
||||||
|
status = "running"
|
||||||
|
items_processed = 0
|
||||||
|
items_created = 0
|
||||||
|
errors: list[str] = []
|
||||||
|
completed_at = None
|
||||||
|
|
||||||
|
def __setattr__(self, name: str, value: Any) -> None:
|
||||||
|
object.__setattr__(self, name, value)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeResult:
|
||||||
|
"""Mimics a SQLAlchemy ``Result`` with ``scalar_one_or_none``."""
|
||||||
|
def __init__(self, row: _FakeRow) -> None:
|
||||||
|
self._row = row
|
||||||
|
|
||||||
|
def scalar_one_or_none(self) -> _FakeRow:
|
||||||
|
return self._row
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockExecutor:
|
||||||
|
"""In-memory executor that replaces Redis-based tool round-trip.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
fixture_dir : Path
|
||||||
|
Directory containing sample files for filesystem tool calls.
|
||||||
|
seed_records : dict[str, list[dict]]
|
||||||
|
Pre-existing records per table, e.g. ``{"tasks": [...], "projects": [...]}``.
|
||||||
|
The executor returns these for ``select`` / ``get`` actions and auto-updates
|
||||||
|
them on ``insert`` / ``update`` / ``delete`` so subsequent selects reflect changes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
fixture_dir: Path
|
||||||
|
seed_records: dict[str, list[dict]] = field(default_factory=dict)
|
||||||
|
mutations: list[Mutation] = field(default_factory=list)
|
||||||
|
_id_counter: int = field(default=1000, repr=False)
|
||||||
|
|
||||||
|
# ── Public API ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Clear recorded mutations (keep seed_records intact)."""
|
||||||
|
self.mutations.clear()
|
||||||
|
|
||||||
|
def get_mutations(self, *, table: str | None = None, action: str | None = None) -> list[Mutation]:
|
||||||
|
"""Filter mutations by table and/or action."""
|
||||||
|
result = self.mutations
|
||||||
|
if table:
|
||||||
|
result = [m for m in result if m.table == table]
|
||||||
|
if action:
|
||||||
|
result = [m for m in result if m.action == action]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def created_records(self, table: str) -> list[dict]:
|
||||||
|
"""Return data dicts of all inserts into *table*."""
|
||||||
|
return [m.data for m in self.mutations if m.table == table and m.action == "insert"]
|
||||||
|
|
||||||
|
def updated_records(self, table: str) -> list[dict]:
|
||||||
|
"""Return data dicts of all updates to *table*."""
|
||||||
|
return [m.data for m in self.mutations if m.table == table and m.action == "update"]
|
||||||
|
|
||||||
|
# ── Context manager for patching ──────────────────────────────
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def patch(self):
|
||||||
|
"""Patch execute_on_client and DB session at all usage sites."""
|
||||||
|
mock_fn = AsyncMock(side_effect=self._handle)
|
||||||
|
targets = [
|
||||||
|
"shared.ws_context.execute_on_client",
|
||||||
|
"app.agent_runner.execute_on_client",
|
||||||
|
"app.agents.filesystem_agent.execute_on_client",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Mock async_session so run_local_agent / _finalize_run skip real DB
|
||||||
|
fake_row = _FakeRow()
|
||||||
|
fake_db = AsyncMock()
|
||||||
|
fake_db.commit = AsyncMock()
|
||||||
|
fake_db.refresh = AsyncMock()
|
||||||
|
fake_db.execute = AsyncMock(return_value=_FakeResult(fake_row))
|
||||||
|
fake_db.add = lambda obj: None # noqa: ARG005
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _fake_session():
|
||||||
|
yield fake_db
|
||||||
|
|
||||||
|
patches = [patch(t, new=mock_fn) for t in targets]
|
||||||
|
patches.append(patch("app.agent_runner.async_session", _fake_session))
|
||||||
|
for p in patches:
|
||||||
|
p.start()
|
||||||
|
try:
|
||||||
|
yield mock_fn
|
||||||
|
finally:
|
||||||
|
for p in patches:
|
||||||
|
p.stop()
|
||||||
|
|
||||||
|
# ── Internal dispatch ─────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _handle(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
table: str | None = None,
|
||||||
|
data: dict[str, Any] | None = None,
|
||||||
|
filters: dict[str, Any] | None = None,
|
||||||
|
vector: list[float] | None = None,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
# Filesystem
|
||||||
|
if action == "list_directory":
|
||||||
|
return self._list_directory(data or {})
|
||||||
|
if action == "read_file_content":
|
||||||
|
return self._read_file(data or {})
|
||||||
|
if action == "get_file_metadata":
|
||||||
|
return self._get_file_metadata(data or {})
|
||||||
|
|
||||||
|
# CRUD
|
||||||
|
if action == "select":
|
||||||
|
return self._select(table or "", filters)
|
||||||
|
if action == "get":
|
||||||
|
return self._get(table or "", data or {})
|
||||||
|
if action == "insert":
|
||||||
|
return self._insert(table or "", data or {})
|
||||||
|
if action == "update":
|
||||||
|
return self._update(table or "", data or {})
|
||||||
|
if action == "delete":
|
||||||
|
return self._delete(table or "", data or {})
|
||||||
|
|
||||||
|
# Vector (no-op for eval)
|
||||||
|
if action in ("vector_upsert", "vector_search"):
|
||||||
|
return {"rows": []}
|
||||||
|
|
||||||
|
return {"error": f"Unknown action: {action}"}
|
||||||
|
|
||||||
|
# ── Filesystem handlers ───────────────────────────────────────
|
||||||
|
|
||||||
|
def _list_directory(self, data: dict) -> dict:
|
||||||
|
rel_path = data.get("path", "")
|
||||||
|
abs_path = self.fixture_dir / rel_path.lstrip("/\\")
|
||||||
|
if not abs_path.is_dir():
|
||||||
|
return {"entries": []}
|
||||||
|
entries: list[dict] = []
|
||||||
|
for child in sorted(abs_path.iterdir()):
|
||||||
|
entry_type = "directory" if child.is_dir() else "file"
|
||||||
|
# Return paths relative to fixture_dir but with the original prefix
|
||||||
|
entry_path = rel_path.rstrip("/\\") + "/" + child.name
|
||||||
|
entries.append({
|
||||||
|
"name": child.name,
|
||||||
|
"path": entry_path,
|
||||||
|
"type": entry_type,
|
||||||
|
})
|
||||||
|
return {"entries": entries}
|
||||||
|
|
||||||
|
def _read_file(self, data: dict) -> dict:
|
||||||
|
rel_path = data.get("path", "")
|
||||||
|
abs_path = self.fixture_dir / rel_path.lstrip("/\\")
|
||||||
|
if not abs_path.is_file():
|
||||||
|
return {"content": "", "error": f"File not found: {rel_path}"}
|
||||||
|
return {"content": abs_path.read_text(encoding="utf-8", errors="replace")}
|
||||||
|
|
||||||
|
def _get_file_metadata(self, data: dict) -> dict:
|
||||||
|
rel_path = data.get("path", "")
|
||||||
|
abs_path = self.fixture_dir / rel_path.lstrip("/\\")
|
||||||
|
if not abs_path.exists():
|
||||||
|
return {"error": f"Not found: {rel_path}"}
|
||||||
|
stat = abs_path.stat()
|
||||||
|
return {
|
||||||
|
"path": rel_path,
|
||||||
|
"size": stat.st_size,
|
||||||
|
"modifiedAt": int(stat.st_mtime * 1000),
|
||||||
|
"createdAt": int(stat.st_ctime * 1000),
|
||||||
|
"isDirectory": abs_path.is_dir(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── CRUD handlers ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _select(self, table: str, filters: dict | None) -> dict:
|
||||||
|
rows = list(self.seed_records.get(table, []))
|
||||||
|
if filters:
|
||||||
|
rows = [
|
||||||
|
r for r in rows
|
||||||
|
if all(r.get(k) == v for k, v in filters.items() if v is not None)
|
||||||
|
]
|
||||||
|
return {"rows": rows}
|
||||||
|
|
||||||
|
def _get(self, table: str, data: dict) -> dict:
|
||||||
|
record_id = data.get("id", "")
|
||||||
|
rows = self.seed_records.get(table, [])
|
||||||
|
for r in rows:
|
||||||
|
if r.get("id") == record_id:
|
||||||
|
return {"row": r}
|
||||||
|
return {"row": None}
|
||||||
|
|
||||||
|
def _insert(self, table: str, data: dict) -> dict:
|
||||||
|
self._id_counter += 1
|
||||||
|
record = {**data, "id": str(self._id_counter)}
|
||||||
|
# Add to seed so subsequent selects can find it
|
||||||
|
self.seed_records.setdefault(table, []).append(record)
|
||||||
|
self.mutations.append(Mutation(action="insert", table=table, data=record))
|
||||||
|
return {"row": record}
|
||||||
|
|
||||||
|
def _update(self, table: str, data: dict) -> dict:
|
||||||
|
record_id = data.get("id", "")
|
||||||
|
rows = self.seed_records.get(table, [])
|
||||||
|
for r in rows:
|
||||||
|
if r.get("id") == record_id:
|
||||||
|
r.update({k: v for k, v in data.items() if v is not None and v != ""})
|
||||||
|
self.mutations.append(Mutation(action="update", table=table, data=dict(r)))
|
||||||
|
return {"row": r}
|
||||||
|
# Record not found — still log the mutation
|
||||||
|
self.mutations.append(Mutation(action="update", table=table, data=data))
|
||||||
|
return {"row": data}
|
||||||
|
|
||||||
|
def _delete(self, table: str, data: dict) -> dict:
|
||||||
|
record_id = data.get("id", "")
|
||||||
|
rows = self.seed_records.get(table, [])
|
||||||
|
self.seed_records[table] = [r for r in rows if r.get("id") != record_id]
|
||||||
|
self.mutations.append(Mutation(action="delete", table=table, data={"id": record_id}))
|
||||||
|
return {"deleted": True}
|
||||||
2
services/batch-agent/eval/requirements.txt
Normal file
2
services/batch-agent/eval/requirements.txt
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# Extra dependencies for the eval harness (on top of the service requirements.txt)
|
||||||
|
pyyaml>=6.0.0
|
||||||
545
services/batch-agent/eval/runner.py
Normal file
545
services/batch-agent/eval/runner.py
Normal file
@@ -0,0 +1,545 @@
|
|||||||
|
"""Eval runner — orchestrates fixture → mock → agent pipeline → scoring.
|
||||||
|
|
||||||
|
Supports three eval modes:
|
||||||
|
|
||||||
|
- **step1**: Test classification prompt only (``_STEP1_SYSTEM_PROMPT``).
|
||||||
|
Calls the LLM with fixture-provided ``domain_definitions`` and
|
||||||
|
``projects_list`` and compares output against ``expected_classification``.
|
||||||
|
|
||||||
|
- **step2**: Test processing prompt only (``_PROCESSING_SYSTEM_PROMPT``).
|
||||||
|
Compiles the prompt with fixture-provided ``existing_context``,
|
||||||
|
``project_context``, ``data_types``, and ``custom_prompt_section``,
|
||||||
|
then runs the tool-calling loop. Mutations are scored against
|
||||||
|
``expected`` records.
|
||||||
|
|
||||||
|
- **full**: Run ``run_local_agent()`` end-to-end (both steps).
|
||||||
|
Scored on both classification and extraction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from eval.config import EvalFixture, ExpectedClassification
|
||||||
|
from eval.mock_executor import MockExecutor
|
||||||
|
from eval.scorer import (
|
||||||
|
EvalScores,
|
||||||
|
FieldScore,
|
||||||
|
compute_precision_recall,
|
||||||
|
llm_judge_score,
|
||||||
|
score_field_match,
|
||||||
|
)
|
||||||
|
from eval import langfuse_eval
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 1 runner ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_step1(
|
||||||
|
fixture: EvalFixture,
|
||||||
|
model: str,
|
||||||
|
mock: MockExecutor,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Run step-1 classification for every file in the fixture directory.
|
||||||
|
|
||||||
|
Scans the directory recursively, classifies each file, and returns
|
||||||
|
a list of result dicts:
|
||||||
|
``[{file, project_id, domains, new_project_name}, ...]``
|
||||||
|
"""
|
||||||
|
from app.agent_runner import _classify_file
|
||||||
|
|
||||||
|
# Build project name lookup for display
|
||||||
|
proj_names: dict[str, str] = {
|
||||||
|
p.get("id", ""): p.get("name", "") for p in fixture.projects_list
|
||||||
|
}
|
||||||
|
|
||||||
|
# Discover all files in the fixture directory
|
||||||
|
all_files = await _scan_fixture_files(mock, fixture.directory)
|
||||||
|
print(f"\n Scanning {len(all_files)} files in {fixture.directory}\n")
|
||||||
|
|
||||||
|
results: list[dict[str, Any]] = []
|
||||||
|
for i, file_path in enumerate(all_files, 1):
|
||||||
|
file_result = await mock._handle(
|
||||||
|
action="read_file_content",
|
||||||
|
data={"path": file_path},
|
||||||
|
)
|
||||||
|
file_content: str = file_result.get("content", "")
|
||||||
|
if not file_content.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
project_id, domains, new_name = await _classify_file(
|
||||||
|
file_path=file_path,
|
||||||
|
file_content=file_content,
|
||||||
|
projects=fixture.projects_list,
|
||||||
|
config_data_types=fixture.data_types,
|
||||||
|
custom_system_prompt=fixture.custom_step1_prompt or None,
|
||||||
|
)
|
||||||
|
|
||||||
|
short_name = file_path.rsplit("/", 1)[-1] if "/" in file_path else file_path
|
||||||
|
proj_label = proj_names.get(project_id, new_name or "?")
|
||||||
|
print(f" [{i}/{len(all_files)}] {short_name} → {project_id} ({proj_label}) {domains}")
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"file": file_path,
|
||||||
|
"project_id": project_id,
|
||||||
|
"domains": domains,
|
||||||
|
"new_project_name": new_name,
|
||||||
|
})
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
async def _scan_fixture_files(mock: MockExecutor, directory: str) -> list[str]:
|
||||||
|
"""Recursively list all files under *directory* via the mock executor."""
|
||||||
|
files: list[str] = []
|
||||||
|
|
||||||
|
async def _walk(path: str) -> None:
|
||||||
|
result = await mock._handle(action="list_directory", data={"path": path})
|
||||||
|
for entry in result.get("entries", []):
|
||||||
|
if entry.get("type") == "directory":
|
||||||
|
await _walk(entry["path"])
|
||||||
|
elif entry.get("type") == "file":
|
||||||
|
files.append(entry["path"])
|
||||||
|
|
||||||
|
await _walk(directory)
|
||||||
|
return sorted(files)
|
||||||
|
|
||||||
|
|
||||||
|
def _score_step1(
|
||||||
|
fixture: EvalFixture,
|
||||||
|
results: list[dict[str, Any]],
|
||||||
|
) -> tuple[float, float, float, str]:
|
||||||
|
"""Score step-1 results. Returns (precision, recall, f1, reasoning).
|
||||||
|
|
||||||
|
Files with expected classifications are scored (OK/FAIL).
|
||||||
|
Files without expectations are shown as informational (INFO).
|
||||||
|
"""
|
||||||
|
if not fixture.expected_classification:
|
||||||
|
return 0.0, 0.0, 0.0, "No expected classifications"
|
||||||
|
|
||||||
|
# Build project name lookup
|
||||||
|
proj_names: dict[str, str] = {
|
||||||
|
p.get("id", ""): p.get("name", "") for p in fixture.projects_list
|
||||||
|
}
|
||||||
|
proj_names["new"] = "(new project)"
|
||||||
|
|
||||||
|
def _proj_label(pid: str, new_name: str | None = None) -> str:
|
||||||
|
name = proj_names.get(pid, "?")
|
||||||
|
if pid == "new" and new_name:
|
||||||
|
return f"new → \"{new_name}\""
|
||||||
|
return f"{pid} ({name})" if name and name != "?" else pid
|
||||||
|
|
||||||
|
def _short_file(path: str) -> str:
|
||||||
|
"""Use just the filename for cleaner display."""
|
||||||
|
return path.rsplit("/", 1)[-1] if "/" in path else path
|
||||||
|
|
||||||
|
expected_files = {ec.file for ec in fixture.expected_classification}
|
||||||
|
total = len(fixture.expected_classification)
|
||||||
|
matched = 0
|
||||||
|
|
||||||
|
scored_lines: list[str] = []
|
||||||
|
info_lines: list[str] = []
|
||||||
|
|
||||||
|
# Score expected files
|
||||||
|
for ec in fixture.expected_classification:
|
||||||
|
actual = next((r for r in results if r["file"] == ec.file), None)
|
||||||
|
fname = _short_file(ec.file)
|
||||||
|
if actual is None:
|
||||||
|
scored_lines.append(f" MISS {fname}")
|
||||||
|
scored_lines.append(f" expected: {_proj_label(ec.project_id)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
pid_ok = actual["project_id"] == ec.project_id
|
||||||
|
domains_ok = set(actual["domains"]) == set(ec.domains) if ec.domains else True
|
||||||
|
|
||||||
|
if pid_ok and domains_ok:
|
||||||
|
matched += 1
|
||||||
|
scored_lines.append(f" OK {fname}")
|
||||||
|
scored_lines.append(f" project: {_proj_label(actual['project_id'])}")
|
||||||
|
scored_lines.append(f" domains: {actual['domains']}")
|
||||||
|
else:
|
||||||
|
scored_lines.append(f" FAIL {fname}")
|
||||||
|
if not pid_ok:
|
||||||
|
scored_lines.append(f" project: {_proj_label(actual['project_id'])} (expected: {_proj_label(ec.project_id)})")
|
||||||
|
else:
|
||||||
|
scored_lines.append(f" project: {_proj_label(actual['project_id'])}")
|
||||||
|
if not domains_ok:
|
||||||
|
scored_lines.append(f" domains: {actual['domains']} (expected: {ec.domains})")
|
||||||
|
else:
|
||||||
|
scored_lines.append(f" domains: {actual['domains']}")
|
||||||
|
|
||||||
|
# Show unscored files
|
||||||
|
for r in results:
|
||||||
|
if r["file"] not in expected_files:
|
||||||
|
fname = _short_file(r["file"])
|
||||||
|
proj = _proj_label(r["project_id"], r.get("new_project_name"))
|
||||||
|
info_lines.append(f" · {fname}")
|
||||||
|
info_lines.append(f" project: {proj} | domains: {r['domains']}")
|
||||||
|
|
||||||
|
precision = matched / total if total > 0 else 0.0
|
||||||
|
recall = precision
|
||||||
|
f1 = precision
|
||||||
|
|
||||||
|
parts: list[str] = []
|
||||||
|
if scored_lines:
|
||||||
|
parts.append(f"Scored ({matched}/{total}):")
|
||||||
|
parts.extend(scored_lines)
|
||||||
|
if info_lines:
|
||||||
|
parts.append(f"\nOther files ({len(info_lines) // 2}):")
|
||||||
|
parts.extend(info_lines)
|
||||||
|
|
||||||
|
return precision, recall, f1, "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 2 runner ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_step2(
|
||||||
|
fixture: EvalFixture,
|
||||||
|
model: str,
|
||||||
|
mock: MockExecutor,
|
||||||
|
) -> None:
|
||||||
|
"""Run step-2 processing for each file in the fixture directory.
|
||||||
|
|
||||||
|
Compiles ``_PROCESSING_SYSTEM_PROMPT`` with fixture-provided variables
|
||||||
|
and runs the tool-calling loop. Mutations are captured by the mock.
|
||||||
|
"""
|
||||||
|
from app.agent_runner import (
|
||||||
|
_PROCESSING_SYSTEM_PROMPT,
|
||||||
|
_build_processing_tools,
|
||||||
|
_run_agent_with_tools,
|
||||||
|
_MAX_PROCESSING_STEPS,
|
||||||
|
)
|
||||||
|
from app import tracing
|
||||||
|
|
||||||
|
# Compile the processing prompt with fixture variables
|
||||||
|
system_prompt = tracing.compile_prompt(
|
||||||
|
"batch_processing",
|
||||||
|
fallback=_PROCESSING_SYSTEM_PROMPT,
|
||||||
|
variables={
|
||||||
|
"existing_context": fixture.existing_context,
|
||||||
|
"project_context": fixture.project_context,
|
||||||
|
"data_types": ", ".join(fixture.data_types),
|
||||||
|
"custom_prompt_section": fixture.custom_prompt_section,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
tools = _build_processing_tools(fixture.data_types)
|
||||||
|
|
||||||
|
# Scan files in the fixture directory
|
||||||
|
file_entries = await mock._handle(
|
||||||
|
action="list_directory",
|
||||||
|
data={"path": fixture.directory},
|
||||||
|
)
|
||||||
|
for entry in file_entries.get("entries", []):
|
||||||
|
if entry.get("type") != "file":
|
||||||
|
continue
|
||||||
|
# Filter by extension if specified
|
||||||
|
if fixture.file_extensions:
|
||||||
|
ext = entry["name"].rsplit(".", 1)[-1] if "." in entry["name"] else ""
|
||||||
|
if ext not in fixture.file_extensions:
|
||||||
|
continue
|
||||||
|
|
||||||
|
file_result = await mock._handle(
|
||||||
|
action="read_file_content",
|
||||||
|
data={"path": entry["path"]},
|
||||||
|
)
|
||||||
|
file_content: str = file_result.get("content", "")
|
||||||
|
if not file_content.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
await _run_agent_with_tools(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
user_message=(
|
||||||
|
f"Process this file and extract relevant information.\n\n"
|
||||||
|
f"File: {entry['path']}\n\nContent:\n{file_content}"
|
||||||
|
),
|
||||||
|
tools=tools,
|
||||||
|
max_steps=_MAX_PROCESSING_STEPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Full runner ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_full(
|
||||||
|
fixture: EvalFixture,
|
||||||
|
model: str,
|
||||||
|
mock: MockExecutor,
|
||||||
|
user_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Run the full two-step pipeline via ``run_local_agent``."""
|
||||||
|
from app.agent_runner import run_local_agent
|
||||||
|
|
||||||
|
trigger_data: dict[str, Any] = {
|
||||||
|
"type": "agent_trigger",
|
||||||
|
"directory": fixture.directory,
|
||||||
|
"directory_paths": [fixture.directory],
|
||||||
|
"data_types": fixture.data_types,
|
||||||
|
"file_extensions": fixture.file_extensions,
|
||||||
|
"prompt_template": fixture.custom_prompt_section,
|
||||||
|
"device_id": "eval-harness",
|
||||||
|
"run_context": {
|
||||||
|
"agent_id": f"eval-{fixture.name}",
|
||||||
|
"run_id": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
with mock.patch():
|
||||||
|
await run_local_agent(user_id, trigger_data)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Scoring helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _score_mutations(
|
||||||
|
fixture: EvalFixture,
|
||||||
|
mock: MockExecutor,
|
||||||
|
) -> tuple[list[FieldScore], float, float, float, int, int]:
|
||||||
|
"""Score mutations against expected records.
|
||||||
|
|
||||||
|
Returns (field_scores, precision, recall, f1, extra, missing).
|
||||||
|
"""
|
||||||
|
all_field_scores: list[FieldScore] = []
|
||||||
|
total_expected = 0
|
||||||
|
total_actual = 0
|
||||||
|
total_matched = 0
|
||||||
|
total_extra = 0
|
||||||
|
total_missing = 0
|
||||||
|
|
||||||
|
expected_by_table: dict[str, list[dict]] = {}
|
||||||
|
for rec in fixture.expected:
|
||||||
|
expected_by_table.setdefault(rec.table, []).append(rec.fields)
|
||||||
|
|
||||||
|
tables = set(expected_by_table.keys()) | {m.table for m in mock.mutations}
|
||||||
|
for table in tables:
|
||||||
|
expected_records = expected_by_table.get(table, [])
|
||||||
|
actual_records = mock.created_records(table) + mock.updated_records(table)
|
||||||
|
|
||||||
|
field_scores, extra, missing = score_field_match(expected_records, actual_records, table)
|
||||||
|
all_field_scores.extend(field_scores)
|
||||||
|
|
||||||
|
matched = sum(1 for s in field_scores if s.best_match is not None)
|
||||||
|
total_expected += len(expected_records)
|
||||||
|
total_actual += len(actual_records)
|
||||||
|
total_matched += matched
|
||||||
|
total_extra += extra
|
||||||
|
total_missing += missing
|
||||||
|
|
||||||
|
precision, recall, f1 = compute_precision_recall(total_expected, total_actual, total_matched)
|
||||||
|
return all_field_scores, precision, recall, f1, total_extra, total_missing
|
||||||
|
|
||||||
|
|
||||||
|
# ── Main entry point ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def run_single_eval(
|
||||||
|
fixture: EvalFixture,
|
||||||
|
model: str,
|
||||||
|
*,
|
||||||
|
use_llm_judge: bool = True,
|
||||||
|
judge_model: str = "gpt-4o-mini",
|
||||||
|
) -> EvalScores:
|
||||||
|
"""Execute one eval run for a fixture + model. Mode is read from the fixture."""
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.ws_context import set_current_user, clear_current_user
|
||||||
|
|
||||||
|
seed = copy.deepcopy(fixture.seed_records)
|
||||||
|
mock = MockExecutor(
|
||||||
|
fixture_dir=fixture.fixture_path.parent,
|
||||||
|
seed_records=seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
original_model = settings.LLM_MODEL
|
||||||
|
settings.LLM_MODEL = model
|
||||||
|
eval_user_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"eval: starting %s | mode=%s | model=%s",
|
||||||
|
fixture.name, fixture.mode, model,
|
||||||
|
)
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
step1_results: list[dict[str, Any]] = []
|
||||||
|
step1_reasoning = ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
set_current_user(eval_user_id)
|
||||||
|
|
||||||
|
if fixture.mode == "step1":
|
||||||
|
with mock.patch():
|
||||||
|
step1_results = await _run_step1(fixture, model, mock)
|
||||||
|
|
||||||
|
elif fixture.mode == "step2":
|
||||||
|
with mock.patch():
|
||||||
|
await _run_step2(fixture, model, mock)
|
||||||
|
|
||||||
|
elif fixture.mode == "full":
|
||||||
|
with mock.patch():
|
||||||
|
# Step 1 — classification (independent from run_local_agent)
|
||||||
|
if fixture.expected_classification:
|
||||||
|
step1_results = await _run_step1(fixture, model, mock)
|
||||||
|
|
||||||
|
# Step 2 — full pipeline (run_local_agent handles both steps)
|
||||||
|
await _run_full(fixture, model, mock, eval_user_id)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("eval: pipeline failed for %s/%s: %s", fixture.name, model, exc)
|
||||||
|
finally:
|
||||||
|
settings.LLM_MODEL = original_model
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
logger.info("eval: completed in %.1fs — %d mutations", elapsed, len(mock.mutations))
|
||||||
|
|
||||||
|
# ── Score ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
if fixture.mode == "step1":
|
||||||
|
s1_precision, s1_recall, s1_f1, step1_reasoning = _score_step1(fixture, step1_results)
|
||||||
|
scores = EvalScores(
|
||||||
|
fixture_name=fixture.name,
|
||||||
|
model=model,
|
||||||
|
prompt_variant=fixture.mode,
|
||||||
|
precision=s1_precision,
|
||||||
|
recall=s1_recall,
|
||||||
|
f1=s1_f1,
|
||||||
|
llm_judge_reasoning=step1_reasoning,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# step2 or full — score mutations
|
||||||
|
field_scores, precision, recall, f1, extra, missing = _score_mutations(fixture, mock)
|
||||||
|
scores = EvalScores(
|
||||||
|
fixture_name=fixture.name,
|
||||||
|
model=model,
|
||||||
|
prompt_variant=fixture.mode,
|
||||||
|
field_scores=field_scores,
|
||||||
|
precision=precision,
|
||||||
|
recall=recall,
|
||||||
|
f1=f1,
|
||||||
|
extra_records=extra,
|
||||||
|
missing_records=missing,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add step1 classification scores for full mode
|
||||||
|
if fixture.mode == "full" and fixture.expected_classification:
|
||||||
|
s1_p, s1_r, s1_f1, step1_reasoning = _score_step1(fixture, step1_results)
|
||||||
|
scores.llm_judge_reasoning = f"Step1 classification:\n{step1_reasoning}"
|
||||||
|
|
||||||
|
# Optional LLM judge for extraction quality
|
||||||
|
if use_llm_judge and fixture.expected:
|
||||||
|
all_expected = [r.fields for r in fixture.expected]
|
||||||
|
all_actual = [m.data for m in mock.mutations if m.action in ("insert", "update")]
|
||||||
|
judge_score, reasoning = await llm_judge_score(
|
||||||
|
all_expected, all_actual, judge_model=judge_model,
|
||||||
|
)
|
||||||
|
scores.llm_judge_score = judge_score
|
||||||
|
if step1_reasoning:
|
||||||
|
scores.llm_judge_reasoning += f"\n\nLLM judge:\n{reasoning}"
|
||||||
|
else:
|
||||||
|
scores.llm_judge_reasoning = reasoning
|
||||||
|
|
||||||
|
# ── Report to Langfuse ────────────────────────────────────────
|
||||||
|
prompt_names = {
|
||||||
|
"step1": ["batch_file_classifier"],
|
||||||
|
"step2": ["batch_processing"],
|
||||||
|
"full": ["batch_file_classifier", "batch_processing"],
|
||||||
|
}.get(fixture.mode, ["batch_processing"])
|
||||||
|
|
||||||
|
trace_id = langfuse_eval.log_eval_trace(
|
||||||
|
fixture_name=fixture.name,
|
||||||
|
model=model,
|
||||||
|
prompt_variant=fixture.mode,
|
||||||
|
prompt_template=fixture.custom_prompt_section or "(default)",
|
||||||
|
actual_mutations=[{"action": m.action, "table": m.table, "data": m.data} for m in mock.mutations],
|
||||||
|
scores_summary=scores.summary(),
|
||||||
|
step1_results=step1_results or None,
|
||||||
|
langfuse_prompt_names=prompt_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
if trace_id:
|
||||||
|
langfuse_eval.post_eval_scores(scores, trace_id=trace_id)
|
||||||
|
|
||||||
|
# For full mode, post classification scores separately
|
||||||
|
if fixture.mode == "full" and fixture.expected_classification:
|
||||||
|
s1_p, s1_r, s1_f1, _ = _score_step1(fixture, step1_results)
|
||||||
|
for name, value in [
|
||||||
|
("classification_precision", s1_p),
|
||||||
|
("classification_recall", s1_r),
|
||||||
|
("classification_f1", s1_f1),
|
||||||
|
]:
|
||||||
|
try:
|
||||||
|
from langfuse import get_client
|
||||||
|
lf = get_client()
|
||||||
|
if lf:
|
||||||
|
lf.create_score(
|
||||||
|
name=name,
|
||||||
|
value=value,
|
||||||
|
trace_id=trace_id,
|
||||||
|
data_type="NUMERIC",
|
||||||
|
comment=f"{fixture.name} | {model} | full",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
async def run_fixture_eval(
|
||||||
|
fixture: EvalFixture,
|
||||||
|
models: list[str],
|
||||||
|
*,
|
||||||
|
use_llm_judge: bool = True,
|
||||||
|
judge_model: str = "gpt-4o-mini",
|
||||||
|
) -> list[EvalScores]:
|
||||||
|
"""Run all models for a fixture."""
|
||||||
|
langfuse_eval.sync_fixture_to_dataset(fixture)
|
||||||
|
|
||||||
|
results: list[EvalScores] = []
|
||||||
|
for model in models:
|
||||||
|
scores = await run_single_eval(
|
||||||
|
fixture, model,
|
||||||
|
use_llm_judge=use_llm_judge,
|
||||||
|
judge_model=judge_model,
|
||||||
|
)
|
||||||
|
results.append(scores)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def print_results(results: list[EvalScores]) -> None:
|
||||||
|
"""Print a formatted summary table of eval results."""
|
||||||
|
if not results:
|
||||||
|
print("\nNo eval results.")
|
||||||
|
return
|
||||||
|
|
||||||
|
W = 90
|
||||||
|
|
||||||
|
print("\n" + "=" * W)
|
||||||
|
print(f"{'Fixture':<25} {'Mode':<6} {'Model':<25} {'P':>6} {'R':>6} {'F1':>6} {'FA':>6} {'LLM':>6}")
|
||||||
|
print("-" * W)
|
||||||
|
|
||||||
|
for s in results:
|
||||||
|
llm_str = f"{s.llm_judge_score:.2f}" if s.llm_judge_score is not None else " --"
|
||||||
|
fa_str = f"{s.field_accuracy:.2f}" if s.field_scores else " --"
|
||||||
|
print(
|
||||||
|
f"{s.fixture_name:<25} {s.prompt_variant:<6} {s.model:<25} "
|
||||||
|
f"{s.precision:>6.2f} {s.recall:>6.2f} {s.f1:>6.2f} "
|
||||||
|
f"{fa_str:>6} {llm_str:>6}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("=" * W)
|
||||||
|
|
||||||
|
for s in results:
|
||||||
|
if s.llm_judge_reasoning:
|
||||||
|
print(f"\n{'─' * W}")
|
||||||
|
print(f" {s.fixture_name} | {s.model} | {s.prompt_variant}")
|
||||||
|
print(f"{'─' * W}")
|
||||||
|
print(s.llm_judge_reasoning)
|
||||||
|
|
||||||
|
print()
|
||||||
268
services/batch-agent/eval/scorer.py
Normal file
268
services/batch-agent/eval/scorer.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
"""Scoring functions for batch agent evaluation.
|
||||||
|
|
||||||
|
Two scoring strategies:
|
||||||
|
|
||||||
|
1. **FieldMatchScorer** — deterministic check: for each expected record,
|
||||||
|
find the best-matching actual record and compare specified fields.
|
||||||
|
Returns precision, recall, and per-field accuracy.
|
||||||
|
|
||||||
|
2. **LLMJudgeScorer** — uses a secondary LLM to semantically evaluate
|
||||||
|
whether the actual extractions satisfy the expected intent, even if
|
||||||
|
wording differs. Returns a 0-1 score + reasoning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from difflib import SequenceMatcher
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Result types ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FieldScore:
|
||||||
|
"""Score for a single expected record against its best match."""
|
||||||
|
|
||||||
|
expected: dict[str, Any]
|
||||||
|
best_match: dict[str, Any] | None
|
||||||
|
matched_fields: dict[str, bool]
|
||||||
|
similarity: float # 0-1 overall similarity
|
||||||
|
|
||||||
|
@property
|
||||||
|
def field_accuracy(self) -> float:
|
||||||
|
if not self.matched_fields:
|
||||||
|
return 0.0
|
||||||
|
return sum(self.matched_fields.values()) / len(self.matched_fields)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EvalScores:
|
||||||
|
"""Aggregated scores for one eval run."""
|
||||||
|
|
||||||
|
fixture_name: str
|
||||||
|
model: str
|
||||||
|
prompt_variant: str
|
||||||
|
field_scores: list[FieldScore] = field(default_factory=list)
|
||||||
|
precision: float = 0.0
|
||||||
|
recall: float = 0.0
|
||||||
|
f1: float = 0.0
|
||||||
|
llm_judge_score: float | None = None
|
||||||
|
llm_judge_reasoning: str = ""
|
||||||
|
extra_records: int = 0 # records created but not expected
|
||||||
|
missing_records: int = 0 # expected but not found
|
||||||
|
|
||||||
|
@property
|
||||||
|
def field_accuracy(self) -> float:
|
||||||
|
if not self.field_scores:
|
||||||
|
return 0.0
|
||||||
|
return sum(s.field_accuracy for s in self.field_scores) / len(self.field_scores)
|
||||||
|
|
||||||
|
def summary(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"fixture": self.fixture_name,
|
||||||
|
"model": self.model,
|
||||||
|
"prompt_variant": self.prompt_variant,
|
||||||
|
"precision": round(self.precision, 3),
|
||||||
|
"recall": round(self.recall, 3),
|
||||||
|
"f1": round(self.f1, 3),
|
||||||
|
"field_accuracy": round(self.field_accuracy, 3),
|
||||||
|
"llm_judge_score": round(self.llm_judge_score, 3) if self.llm_judge_score is not None else None,
|
||||||
|
"extra_records": self.extra_records,
|
||||||
|
"missing_records": self.missing_records,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Field Match Scorer ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize(value: Any) -> str:
|
||||||
|
"""Normalize a value for comparison."""
|
||||||
|
if value is None:
|
||||||
|
return ""
|
||||||
|
return str(value).strip().lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _text_similarity(a: str, b: str) -> float:
|
||||||
|
"""Fuzzy text similarity using SequenceMatcher."""
|
||||||
|
if not a and not b:
|
||||||
|
return 1.0
|
||||||
|
if not a or not b:
|
||||||
|
return 0.0
|
||||||
|
return SequenceMatcher(None, a.lower(), b.lower()).ratio()
|
||||||
|
|
||||||
|
|
||||||
|
def _find_best_match(
|
||||||
|
expected: dict[str, Any],
|
||||||
|
actuals: list[dict[str, Any]],
|
||||||
|
) -> tuple[dict[str, Any] | None, float]:
|
||||||
|
"""Find the actual record most similar to expected, return (match, similarity)."""
|
||||||
|
if not actuals:
|
||||||
|
return None, 0.0
|
||||||
|
|
||||||
|
best_match = None
|
||||||
|
best_score = 0.0
|
||||||
|
|
||||||
|
# Primary matching key: title or name
|
||||||
|
expected_title = _normalize(expected.get("title", expected.get("name", "")))
|
||||||
|
|
||||||
|
for actual in actuals:
|
||||||
|
actual_title = _normalize(actual.get("title", actual.get("name", "")))
|
||||||
|
sim = _text_similarity(expected_title, actual_title)
|
||||||
|
if sim > best_score:
|
||||||
|
best_score = sim
|
||||||
|
best_match = actual
|
||||||
|
|
||||||
|
return best_match, best_score
|
||||||
|
|
||||||
|
|
||||||
|
def _compare_fields(
|
||||||
|
expected: dict[str, Any],
|
||||||
|
actual: dict[str, Any],
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Compare each expected field against the actual record."""
|
||||||
|
results: dict[str, bool] = {}
|
||||||
|
for key, expected_val in expected.items():
|
||||||
|
actual_val = actual.get(key)
|
||||||
|
# Exact match for non-string types
|
||||||
|
if not isinstance(expected_val, str):
|
||||||
|
results[key] = actual_val == expected_val
|
||||||
|
else:
|
||||||
|
# Fuzzy match for strings (threshold: 0.7)
|
||||||
|
results[key] = _text_similarity(
|
||||||
|
_normalize(expected_val), _normalize(actual_val)
|
||||||
|
) >= 0.7
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def score_field_match(
|
||||||
|
expected_records: list[dict[str, Any]],
|
||||||
|
actual_records: list[dict[str, Any]],
|
||||||
|
table: str,
|
||||||
|
) -> tuple[list[FieldScore], int, int]:
|
||||||
|
"""Score actual extractions against expected records for one table.
|
||||||
|
|
||||||
|
Returns (field_scores, extra_count, missing_count).
|
||||||
|
"""
|
||||||
|
field_scores: list[FieldScore] = []
|
||||||
|
matched_actuals: set[int] = set()
|
||||||
|
|
||||||
|
for exp in expected_records:
|
||||||
|
# Find best match among unmatched actuals
|
||||||
|
candidates = [
|
||||||
|
(i, a) for i, a in enumerate(actual_records) if i not in matched_actuals
|
||||||
|
]
|
||||||
|
if not candidates:
|
||||||
|
field_scores.append(FieldScore(
|
||||||
|
expected=exp, best_match=None, matched_fields={}, similarity=0.0,
|
||||||
|
))
|
||||||
|
continue
|
||||||
|
|
||||||
|
best_idx, best_match = None, None
|
||||||
|
best_sim = 0.0
|
||||||
|
for idx, actual in candidates:
|
||||||
|
_, sim = _find_best_match(exp, [actual])
|
||||||
|
if sim > best_sim:
|
||||||
|
best_sim = sim
|
||||||
|
best_idx = idx
|
||||||
|
best_match = actual
|
||||||
|
|
||||||
|
if best_sim >= 0.5 and best_match is not None:
|
||||||
|
matched_actuals.add(best_idx)
|
||||||
|
matched_fields = _compare_fields(exp, best_match)
|
||||||
|
field_scores.append(FieldScore(
|
||||||
|
expected=exp, best_match=best_match,
|
||||||
|
matched_fields=matched_fields, similarity=best_sim,
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
field_scores.append(FieldScore(
|
||||||
|
expected=exp, best_match=None, matched_fields={}, similarity=0.0,
|
||||||
|
))
|
||||||
|
|
||||||
|
extra_count = len(actual_records) - len(matched_actuals)
|
||||||
|
missing_count = sum(1 for s in field_scores if s.best_match is None)
|
||||||
|
|
||||||
|
return field_scores, extra_count, missing_count
|
||||||
|
|
||||||
|
|
||||||
|
def compute_precision_recall(
|
||||||
|
expected_count: int,
|
||||||
|
actual_count: int,
|
||||||
|
matched_count: int,
|
||||||
|
) -> tuple[float, float, float]:
|
||||||
|
"""Compute precision, recall, F1."""
|
||||||
|
precision = matched_count / actual_count if actual_count > 0 else 0.0
|
||||||
|
recall = matched_count / expected_count if expected_count > 0 else 0.0
|
||||||
|
f1 = (
|
||||||
|
2 * precision * recall / (precision + recall)
|
||||||
|
if (precision + recall) > 0
|
||||||
|
else 0.0
|
||||||
|
)
|
||||||
|
return precision, recall, f1
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM Judge Scorer ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_JUDGE_SYSTEM_PROMPT = """\
|
||||||
|
You are an evaluation judge for a data extraction system.
|
||||||
|
|
||||||
|
Your task is to compare the EXPECTED extractions against the ACTUAL extractions
|
||||||
|
produced by an AI agent, and assess quality on a 0-1 scale.
|
||||||
|
|
||||||
|
Scoring criteria:
|
||||||
|
- 1.0: All expected records found with correct fields, no significant extras
|
||||||
|
- 0.8: Most expected records found, minor field differences or extras
|
||||||
|
- 0.6: Core extractions present but some missing or incorrect
|
||||||
|
- 0.4: Partial match — several expected records missing or wrong
|
||||||
|
- 0.2: Poor quality — most expected records missing or incorrect
|
||||||
|
- 0.0: Complete failure — no meaningful overlap
|
||||||
|
|
||||||
|
Consider semantic equivalence: "Send invoice" and "Email the invoice" are matches.
|
||||||
|
Ignore field ordering and formatting differences.
|
||||||
|
|
||||||
|
Respond with ONLY a JSON object:
|
||||||
|
{"score": 0.85, "reasoning": "Brief explanation of the score"}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def llm_judge_score(
|
||||||
|
expected: list[dict[str, Any]],
|
||||||
|
actual: list[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
judge_model: str = "gpt-4o-mini",
|
||||||
|
) -> tuple[float, str]:
|
||||||
|
"""Use an LLM to semantically evaluate extraction quality.
|
||||||
|
|
||||||
|
Returns (score, reasoning).
|
||||||
|
"""
|
||||||
|
from shared.llm import get_llm
|
||||||
|
|
||||||
|
llm = get_llm(model=judge_model, temperature=0)
|
||||||
|
|
||||||
|
user_content = (
|
||||||
|
f"## Expected extractions\n```json\n{json.dumps(expected, indent=2, default=str)}\n```\n\n"
|
||||||
|
f"## Actual extractions\n```json\n{json.dumps(actual, indent=2, default=str)}\n```"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await llm.ainvoke([
|
||||||
|
SystemMessage(content=_JUDGE_SYSTEM_PROMPT),
|
||||||
|
HumanMessage(content=user_content),
|
||||||
|
])
|
||||||
|
raw = response.content.strip()
|
||||||
|
if raw.startswith("```"):
|
||||||
|
raw = raw.split("```")[1]
|
||||||
|
if raw.startswith("json"):
|
||||||
|
raw = raw[4:]
|
||||||
|
parsed = json.loads(raw.strip())
|
||||||
|
return float(parsed.get("score", 0.0)), str(parsed.get("reasoning", ""))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("eval: LLM judge failed: %s", exc)
|
||||||
|
return 0.0, f"Judge error: {exc}"
|
||||||
21
services/batch-agent/requirements.txt
Normal file
21
services/batch-agent/requirements.txt
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
fastapi>=0.115.0
|
||||||
|
uvicorn[standard]>=0.34.0
|
||||||
|
gunicorn>=22.0.0
|
||||||
|
pydantic>=2.10.0
|
||||||
|
pydantic-settings>=2.7.0
|
||||||
|
sqlalchemy>=2.0.0
|
||||||
|
asyncpg>=0.30.0
|
||||||
|
redis>=5.0.0
|
||||||
|
cryptography>=42.0.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
langchain-core>=0.3.0
|
||||||
|
langchain-openai>=0.3.0
|
||||||
|
langchain-litellm>=0.3.0
|
||||||
|
litellm>=1.50.0
|
||||||
|
openai>=1.50.0
|
||||||
|
httpx>=0.27.0
|
||||||
|
langfuse>=3.0.0
|
||||||
|
croniter>=2.0.0
|
||||||
|
google-api-python-client>=2.130.0
|
||||||
|
google-auth>=2.30.0
|
||||||
|
msal>=1.28.0
|
||||||
36
services/billing/Dockerfile
Normal file
36
services/billing/Dockerfile
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# ── builder ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS builder
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
COPY services/billing/requirements.txt ./requirements.txt
|
||||||
|
RUN pip install --upgrade pip && \
|
||||||
|
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||||
|
|
||||||
|
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS runtime
|
||||||
|
|
||||||
|
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY --from=builder /install /usr/local
|
||||||
|
|
||||||
|
# Shared module
|
||||||
|
COPY shared/ shared/
|
||||||
|
|
||||||
|
# Service source
|
||||||
|
COPY services/billing/app/ app/
|
||||||
|
|
||||||
|
RUN chown -R appuser:appgroup /app
|
||||||
|
|
||||||
|
USER appuser
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
# Billing is lightweight — single worker is fine
|
||||||
|
CMD ["gunicorn", "app.main:app", \
|
||||||
|
"-k", "uvicorn.workers.UvicornWorker", \
|
||||||
|
"--bind", "0.0.0.0:8000", \
|
||||||
|
"--workers", "1", \
|
||||||
|
"--timeout", "30"]
|
||||||
15
services/billing/README.md
Normal file
15
services/billing/README.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# Billing Service
|
||||||
|
|
||||||
|
Owns: Stripe integration, tier management, subscription CRUD.
|
||||||
|
|
||||||
|
## Tables owned (write)
|
||||||
|
- `subscriptions`
|
||||||
|
|
||||||
|
## Endpoints
|
||||||
|
- `POST /billing/checkout`
|
||||||
|
- `POST /billing/webhook` (Stripe, no JWT auth)
|
||||||
|
- `GET /billing/subscription`
|
||||||
|
- `DELETE /billing/subscription`
|
||||||
|
|
||||||
|
## Redis channels
|
||||||
|
- Publish: `tier:changed:{user_id}` on tier change
|
||||||
0
services/billing/app/__init__.py
Normal file
0
services/billing/app/__init__.py
Normal file
53
services/billing/app/main.py
Normal file
53
services/billing/app/main.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""Billing Service — FastAPI application.
|
||||||
|
|
||||||
|
Owns: Stripe checkout/webhook, subscription management, tier feature matrix,
|
||||||
|
quota enforcement.
|
||||||
|
|
||||||
|
Downstream services query this service (or read the user's tier from
|
||||||
|
the X-User-Tier header injected by Traefik) for billing decisions.
|
||||||
|
The webhook endpoint is exposed WITHOUT ForwardAuth so Stripe can reach it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
# Ensure the repo root is on sys.path so "shared" is importable in local dev.
|
||||||
|
_repo_root = str(Path(__file__).resolve().parents[3])
|
||||||
|
if _repo_root not in sys.path:
|
||||||
|
sys.path.insert(0, _repo_root)
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.routes import router
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
|
logger.info("billing: service started")
|
||||||
|
yield
|
||||||
|
logger.info("billing: service stopped")
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(title="Adiuva Billing Service", lifespan=lifespan)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_methods=["GET", "POST", "DELETE"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health() -> dict[str, str]:
|
||||||
|
return {"status": "ok", "service": "billing"}
|
||||||
134
services/billing/app/routes.py
Normal file
134
services/billing/app/routes.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
"""Billing routes: Stripe checkout, webhook, subscription, tier query.
|
||||||
|
|
||||||
|
Adapted for the Billing microservice:
|
||||||
|
- Authenticated routes use Traefik-injected headers (X-User-Id, X-User-Tier)
|
||||||
|
- Webhook route has NO auth (Stripe signature verification only)
|
||||||
|
- Added /tier/{user_id} for internal service-to-service tier lookups
|
||||||
|
- Added /features/{tier} for feature matrix queries
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Header, HTTPException, Request, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from shared.db import async_session
|
||||||
|
from shared.schemas import BillingTier
|
||||||
|
|
||||||
|
from app.stripe_service import stripe_service
|
||||||
|
from app.tier_manager import tier_manager, FEATURES, RATE_LIMITS
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Request bodies ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _CheckoutRequest(BaseModel):
|
||||||
|
tier: BillingTier
|
||||||
|
|
||||||
|
|
||||||
|
# ── Checkout ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/checkout")
|
||||||
|
async def create_checkout(
|
||||||
|
body: _CheckoutRequest,
|
||||||
|
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Create a Stripe checkout session for a tier upgrade."""
|
||||||
|
url = stripe_service.create_checkout_session(x_user_id, body.tier)
|
||||||
|
return {"checkout_url": url}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Webhook (NO auth — Stripe signature only) ─────────────────────────
|
||||||
|
|
||||||
|
@router.post("/webhook")
|
||||||
|
async def stripe_webhook(
|
||||||
|
request: Request,
|
||||||
|
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Handle Stripe webhook events.
|
||||||
|
|
||||||
|
This endpoint is exposed without ForwardAuth in Traefik config
|
||||||
|
so Stripe can reach it directly.
|
||||||
|
"""
|
||||||
|
payload = await request.body()
|
||||||
|
async with async_session() as db:
|
||||||
|
await stripe_service.handle_webhook(payload, stripe_signature, db)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Subscription CRUD ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/subscription")
|
||||||
|
async def get_subscription(
|
||||||
|
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||||
|
x_user_tier: str = Header("free", alias="X-User-Tier"),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Return the current subscription info for the authenticated user."""
|
||||||
|
async with async_session() as db:
|
||||||
|
sub = await stripe_service.get_subscription(x_user_id, db)
|
||||||
|
if sub is None:
|
||||||
|
return {
|
||||||
|
"tier": x_user_tier,
|
||||||
|
"status": "free",
|
||||||
|
"stripe_subscription_id": None,
|
||||||
|
"current_period_end": None,
|
||||||
|
}
|
||||||
|
return sub
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/subscription")
|
||||||
|
async def cancel_subscription(
|
||||||
|
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Cancel the active subscription."""
|
||||||
|
async with async_session() as db:
|
||||||
|
await stripe_service.cancel_subscription(x_user_id, db)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tier query (internal, service-to-service) ─────────────────────────
|
||||||
|
|
||||||
|
@router.get("/tier/{user_id}")
|
||||||
|
async def get_user_tier(user_id: str) -> dict[str, str]:
|
||||||
|
"""Return the billing tier for a given user_id.
|
||||||
|
|
||||||
|
Used by other services for tier lookups. Protected by Traefik
|
||||||
|
ForwardAuth — only internal services should call this.
|
||||||
|
"""
|
||||||
|
async with async_session() as db:
|
||||||
|
tier = await tier_manager.get_tier(user_id, db)
|
||||||
|
return {"user_id": user_id, "tier": tier}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Feature matrix (public, cacheable) ────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/features/{tier}")
|
||||||
|
async def get_tier_features(tier: str) -> dict[str, Any]:
|
||||||
|
"""Return the feature matrix for a tier."""
|
||||||
|
if tier not in FEATURES:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Unknown tier: {tier}",
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"tier": tier,
|
||||||
|
"features": FEATURES[tier],
|
||||||
|
"rate_limit_rpm": RATE_LIMITS.get(tier, RATE_LIMITS["free"]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/features")
|
||||||
|
async def get_all_features() -> dict[str, Any]:
|
||||||
|
"""Return the full feature matrix for all tiers."""
|
||||||
|
return {
|
||||||
|
"tiers": {
|
||||||
|
tier: {
|
||||||
|
"features": features,
|
||||||
|
"rate_limit_rpm": RATE_LIMITS.get(tier, RATE_LIMITS["free"]),
|
||||||
|
}
|
||||||
|
for tier, features in FEATURES.items()
|
||||||
|
},
|
||||||
|
}
|
||||||
240
services/billing/app/stripe_service.py
Normal file
240
services/billing/app/stripe_service.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
"""Stripe service: checkout sessions, webhook handling, subscription management.
|
||||||
|
|
||||||
|
Adapted for the Billing microservice — uses shared.models and shared.db.
|
||||||
|
All Stripe calls are gracefully stubbed when STRIPE_SECRET_KEY is not
|
||||||
|
configured, enabling local development without live credentials.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import stripe as stripe_lib
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.models import Subscription
|
||||||
|
|
||||||
|
# Stripe price IDs per tier — replace with real IDs in production .env
|
||||||
|
TIER_PRICE_IDS: dict[str, str] = {
|
||||||
|
"pro": "price_pro_monthly",
|
||||||
|
"power": "price_power_monthly",
|
||||||
|
"team": "price_team_monthly",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class StripeService:
|
||||||
|
"""Wraps all Stripe interactions and owns subscription persistence."""
|
||||||
|
|
||||||
|
# ── Internal helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _configured(self) -> bool:
|
||||||
|
return bool(settings.STRIPE_SECRET_KEY)
|
||||||
|
|
||||||
|
def _client(self) -> Any:
|
||||||
|
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||||
|
return stripe_lib
|
||||||
|
|
||||||
|
# ── Public API ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def create_checkout_session(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
tier: str,
|
||||||
|
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
||||||
|
cancel_url: str = "https://app.adiuva.app/billing/cancel",
|
||||||
|
) -> str:
|
||||||
|
"""Create a Stripe checkout session and return the URL."""
|
||||||
|
if tier == "free":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Cannot create a checkout session for the free tier",
|
||||||
|
)
|
||||||
|
|
||||||
|
price_id = TIER_PRICE_IDS.get(tier)
|
||||||
|
if not price_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Unknown tier: {tier}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self._configured():
|
||||||
|
return "https://stripe.com/stub-checkout"
|
||||||
|
|
||||||
|
s = self._client()
|
||||||
|
session = s.checkout.Session.create(
|
||||||
|
payment_method_types=["card"],
|
||||||
|
mode="subscription",
|
||||||
|
line_items=[{"price": price_id, "quantity": 1}],
|
||||||
|
success_url=success_url,
|
||||||
|
cancel_url=cancel_url,
|
||||||
|
metadata={"user_id": user_id, "tier": tier},
|
||||||
|
)
|
||||||
|
return session.url
|
||||||
|
|
||||||
|
async def handle_webhook(
|
||||||
|
self,
|
||||||
|
payload: bytes,
|
||||||
|
sig_header: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""Process a Stripe webhook event.
|
||||||
|
|
||||||
|
Verifies the signature, then dispatches on event type.
|
||||||
|
"""
|
||||||
|
if not self._configured():
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
s = self._client()
|
||||||
|
event = s.Webhook.construct_event(
|
||||||
|
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
|
||||||
|
)
|
||||||
|
except stripe_lib.error.SignatureVerificationError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Invalid Stripe signature",
|
||||||
|
)
|
||||||
|
|
||||||
|
event_type: str = event["type"]
|
||||||
|
data: dict[str, Any] = event["data"]["object"]
|
||||||
|
|
||||||
|
if event_type == "checkout.session.completed":
|
||||||
|
user_id = data.get("metadata", {}).get("user_id")
|
||||||
|
tier = data.get("metadata", {}).get("tier", "free")
|
||||||
|
sub_id = data.get("subscription")
|
||||||
|
period_end_ts = data.get("current_period_end")
|
||||||
|
period_end = (
|
||||||
|
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||||
|
if period_end_ts
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if user_id:
|
||||||
|
await self._upsert_subscription(
|
||||||
|
db, user_id, sub_id, tier, "active", period_end
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "customer.subscription.updated":
|
||||||
|
sub_id = data.get("id")
|
||||||
|
new_status = data.get("status", "active")
|
||||||
|
period_end_ts = data.get("current_period_end")
|
||||||
|
period_end = (
|
||||||
|
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||||
|
if period_end_ts
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if sub_id:
|
||||||
|
await self._update_subscription_by_stripe_id(
|
||||||
|
db, sub_id, status=new_status, current_period_end=period_end
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "customer.subscription.deleted":
|
||||||
|
sub_id = data.get("id")
|
||||||
|
if sub_id:
|
||||||
|
await self._update_subscription_by_stripe_id(
|
||||||
|
db, sub_id, tier="free", status="canceled"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "invoice.payment_failed":
|
||||||
|
sub_id = data.get("subscription")
|
||||||
|
if sub_id:
|
||||||
|
await self._update_subscription_by_stripe_id(
|
||||||
|
db, sub_id, status="past_due"
|
||||||
|
)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def get_subscription(
|
||||||
|
self, user_id: str, db: AsyncSession
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Return the subscription record for user_id, or None."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"tier": sub.tier,
|
||||||
|
"stripe_subscription_id": sub.stripe_subscription_id,
|
||||||
|
"status": sub.status,
|
||||||
|
"current_period_end": (
|
||||||
|
int(sub.current_period_end.timestamp() * 1000)
|
||||||
|
if sub.current_period_end
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
|
||||||
|
"""Cancel the user's Stripe subscription and downgrade to free."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None or not sub.stripe_subscription_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="No active subscription found",
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._configured():
|
||||||
|
s = self._client()
|
||||||
|
s.Subscription.cancel(sub.stripe_subscription_id)
|
||||||
|
|
||||||
|
sub.tier = "free"
|
||||||
|
sub.status = "canceled"
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# ── Private DB helpers ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _upsert_subscription(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
stripe_subscription_id: str | None,
|
||||||
|
tier: str,
|
||||||
|
sub_status: str,
|
||||||
|
current_period_end: datetime | None,
|
||||||
|
) -> None:
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
sub = Subscription(user_id=user_id)
|
||||||
|
db.add(sub)
|
||||||
|
sub.stripe_subscription_id = stripe_subscription_id
|
||||||
|
sub.tier = tier
|
||||||
|
sub.status = sub_status
|
||||||
|
sub.current_period_end = current_period_end
|
||||||
|
|
||||||
|
async def _update_subscription_by_stripe_id(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
stripe_subscription_id: str,
|
||||||
|
*,
|
||||||
|
tier: str | None = None,
|
||||||
|
status: str | None = None,
|
||||||
|
current_period_end: datetime | None = None,
|
||||||
|
) -> None:
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(
|
||||||
|
Subscription.stripe_subscription_id == stripe_subscription_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
return
|
||||||
|
if tier is not None:
|
||||||
|
sub.tier = tier
|
||||||
|
if status is not None:
|
||||||
|
sub.status = status
|
||||||
|
if current_period_end is not None:
|
||||||
|
sub.current_period_end = current_period_end
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
stripe_service = StripeService()
|
||||||
178
services/billing/app/tier_manager.py
Normal file
178
services/billing/app/tier_manager.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
"""Tier manager: feature matrix and quota enforcement.
|
||||||
|
|
||||||
|
Single source of truth for what each billing tier allows.
|
||||||
|
Other services can query the /tier/{user_id} endpoint or rely on the
|
||||||
|
X-User-Tier header injected by Traefik.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.models import Subscription
|
||||||
|
from shared.schemas import BillingTier
|
||||||
|
|
||||||
|
# Feature matrix per tier. -1 means unlimited; 0 means disabled.
|
||||||
|
FEATURES: dict[str, dict[str, Any]] = {
|
||||||
|
"free": {
|
||||||
|
"agents": 3,
|
||||||
|
"batch_active": 2,
|
||||||
|
"batch_runs_per_day": 5,
|
||||||
|
"cloud_storage_gb": 0,
|
||||||
|
"backup_gb": 0,
|
||||||
|
"providers": 1,
|
||||||
|
"batch_builder": False,
|
||||||
|
"plugin_marketplace": False,
|
||||||
|
"sso": False,
|
||||||
|
},
|
||||||
|
"pro": {
|
||||||
|
"agents": -1,
|
||||||
|
"batch_active": 10,
|
||||||
|
"batch_runs_per_day": 50,
|
||||||
|
"cloud_storage_gb": 5,
|
||||||
|
"backup_gb": 5,
|
||||||
|
"providers": -1,
|
||||||
|
"batch_builder": False,
|
||||||
|
"plugin_marketplace": False,
|
||||||
|
"sso": False,
|
||||||
|
},
|
||||||
|
"power": {
|
||||||
|
"agents": -1,
|
||||||
|
"batch_active": -1,
|
||||||
|
"batch_runs_per_day": -1,
|
||||||
|
"cloud_storage_gb": 25,
|
||||||
|
"backup_gb": 25,
|
||||||
|
"providers": -1,
|
||||||
|
"batch_builder": True,
|
||||||
|
"plugin_marketplace": True,
|
||||||
|
"sso": False,
|
||||||
|
},
|
||||||
|
"team": {
|
||||||
|
"agents": -1,
|
||||||
|
"batch_active": -1,
|
||||||
|
"batch_runs_per_day": -1,
|
||||||
|
"cloud_storage_gb": -1,
|
||||||
|
"backup_gb": -1,
|
||||||
|
"providers": -1,
|
||||||
|
"batch_builder": True,
|
||||||
|
"plugin_marketplace": True,
|
||||||
|
"sso": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Requests-per-minute limit per tier.
|
||||||
|
RATE_LIMITS: dict[str, int] = {
|
||||||
|
"free": 20,
|
||||||
|
"pro": 60,
|
||||||
|
"power": 120,
|
||||||
|
"team": 200,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TierManager:
|
||||||
|
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
|
||||||
|
|
||||||
|
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||||
|
"""Return the current billing tier for user_id from the DB."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
tier: str | None = result.scalar_one_or_none()
|
||||||
|
if tier is None or tier not in FEATURES:
|
||||||
|
return "power" if settings.ENV == "dev" else "free"
|
||||||
|
return tier # type: ignore[return-value]
|
||||||
|
|
||||||
|
def get_features(self, tier: BillingTier) -> dict[str, Any]:
|
||||||
|
"""Return the full feature dict for a tier."""
|
||||||
|
return FEATURES.get(tier, FEATURES["free"])
|
||||||
|
|
||||||
|
def check_feature(self, tier: BillingTier, feature: str) -> bool:
|
||||||
|
"""Return True if tier has feature enabled."""
|
||||||
|
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||||
|
if value is None:
|
||||||
|
return False
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
return value != 0
|
||||||
|
|
||||||
|
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
|
||||||
|
"""Raise HTTP 403 if tier does not have feature."""
|
||||||
|
if not self.check_feature(tier, feature):
|
||||||
|
detail = (
|
||||||
|
f"Feature '{feature}' requires {tier_name} tier or above."
|
||||||
|
if tier_name
|
||||||
|
else f"Feature '{feature}' is not available on your current tier."
|
||||||
|
)
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||||
|
|
||||||
|
def get_rate_limit(self, tier: BillingTier) -> int:
|
||||||
|
"""Return the requests-per-minute limit for tier."""
|
||||||
|
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
||||||
|
|
||||||
|
def enforce_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Raise HTTP 402 if the user would exceed their cloud storage quota."""
|
||||||
|
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Cloud storage is not available on the '{tier}' tier",
|
||||||
|
)
|
||||||
|
if limit_gb == -1:
|
||||||
|
return
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
if current_bytes + additional_bytes > limit_bytes:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Storage quota exceeded for tier '{tier}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
def enforce_backup_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Raise HTTP 402 if the user would exceed their backup quota."""
|
||||||
|
limit_gb: int = FEATURES[tier]["backup_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Backup is not available on the '{tier}' tier",
|
||||||
|
)
|
||||||
|
if limit_gb == -1:
|
||||||
|
return
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
if current_bytes + additional_bytes > limit_bytes:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Backup quota exceeded for tier '{tier}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> bool:
|
||||||
|
"""Return True if the user can store additional_bytes more data."""
|
||||||
|
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
return False
|
||||||
|
if limit_gb == -1:
|
||||||
|
return True
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
return current_bytes + additional_bytes <= limit_bytes
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
tier_manager = TierManager()
|
||||||
9
services/billing/requirements.txt
Normal file
9
services/billing/requirements.txt
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
fastapi>=0.115.0
|
||||||
|
uvicorn[standard]>=0.34.0
|
||||||
|
gunicorn>=22.0.0
|
||||||
|
pydantic>=2.10.0
|
||||||
|
pydantic-settings>=2.7.0
|
||||||
|
sqlalchemy>=2.0.0
|
||||||
|
asyncpg>=0.30.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
stripe>=8.0.0
|
||||||
36
services/chat/Dockerfile
Normal file
36
services/chat/Dockerfile
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# ── builder ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS builder
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
COPY services/chat/requirements.txt ./requirements.txt
|
||||||
|
RUN pip install --upgrade pip && \
|
||||||
|
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||||
|
|
||||||
|
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS runtime
|
||||||
|
|
||||||
|
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY --from=builder /install /usr/local
|
||||||
|
|
||||||
|
# Shared module
|
||||||
|
COPY shared/ shared/
|
||||||
|
|
||||||
|
# Service source
|
||||||
|
COPY services/chat/app/ app/
|
||||||
|
|
||||||
|
RUN chown -R appuser:appgroup /app
|
||||||
|
|
||||||
|
USER appuser
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
# Chat service is CPU-bound (LLM calls) — use multiple workers
|
||||||
|
CMD ["gunicorn", "app.main:app", \
|
||||||
|
"-k", "uvicorn.workers.UvicornWorker", \
|
||||||
|
"--bind", "0.0.0.0:8000", \
|
||||||
|
"--workers", "2", \
|
||||||
|
"--timeout", "120"]
|
||||||
21
services/chat/README.md
Normal file
21
services/chat/README.md
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# Chat Service
|
||||||
|
|
||||||
|
Owns: deep_agent (home + floating chat), memory middleware, domain agents
|
||||||
|
(task, note, project, timeline), LLM orchestration.
|
||||||
|
|
||||||
|
## Tables owned
|
||||||
|
- `memory_core`
|
||||||
|
- `memory_associative`
|
||||||
|
- `memory_episodic`
|
||||||
|
- `memory_proactive`
|
||||||
|
|
||||||
|
## Tables read (cross-service)
|
||||||
|
- `users` (for encryption_key — memory decryption)
|
||||||
|
|
||||||
|
## Endpoints
|
||||||
|
- `POST /chat` (REST fallback)
|
||||||
|
|
||||||
|
## Redis channels
|
||||||
|
- Subscribe: `chat:request:{user_id}`
|
||||||
|
- Publish: `ws:out:{user_id}` (stream frames + tool calls)
|
||||||
|
- BRPOP: `tool:result:{call_id}` (30s timeout)
|
||||||
0
services/chat/app/__init__.py
Normal file
0
services/chat/app/__init__.py
Normal file
883
services/chat/app/deep_agent.py
Normal file
883
services/chat/app/deep_agent.py
Normal file
@@ -0,0 +1,883 @@
|
|||||||
|
"""Single-agent runners for home and floating chat contexts.
|
||||||
|
|
||||||
|
Adapted from app/core/deep_agent.py for the Chat Service.
|
||||||
|
Import paths changed to use local app modules and shared/.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import date
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from shared.agents.note_agent import NOTE_TOOLS
|
||||||
|
from shared.agents.project_agent import PROJECT_TOOLS
|
||||||
|
from shared.agents.task_agent import TASK_TOOLS
|
||||||
|
from shared.agents.timeline_agent import TIMELINE_TOOLS
|
||||||
|
from shared.llm import get_llm
|
||||||
|
from app.memory_middleware import MemoryMiddleware
|
||||||
|
from shared.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
||||||
|
from app import tracing
|
||||||
|
from shared.db import async_session
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FloatingDomainType = Literal["task", "timeline", "project", "node"]
|
||||||
|
FloatingDomainSection = Literal["task", "timeline", "note"]
|
||||||
|
|
||||||
|
_HOME_SINGLE_AGENT_SYSTEM = (
|
||||||
|
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||||
|
"Always use tools for factual data retrieval before answering. "
|
||||||
|
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
||||||
|
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||||
|
"Return markdown and use tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
||||||
|
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>. "
|
||||||
|
"When listing tasks or timelines, each id tag must be on its own line with no prefix/suffix text. "
|
||||||
|
"Never put titles, priorities, or dates on the same line as <task> or <timeline> tags. "
|
||||||
|
"For questions about upcoming timelines (e.g. 'prossimi eventi'), include only future items in the current month unless the user asks a different range. "
|
||||||
|
"For upcoming tasks, after tag lines add a short recommendation based on due date and priority."
|
||||||
|
)
|
||||||
|
|
||||||
|
_FLOATING_SINGLE_AGENT_SYSTEM = (
|
||||||
|
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||||
|
"Stay focused on the floating scope in context.scope and answer concisely. "
|
||||||
|
"Return plain text only. Do not output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed id tag wrappers. "
|
||||||
|
"Always use tools for factual data retrieval before answering. "
|
||||||
|
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
||||||
|
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||||
|
)
|
||||||
|
|
||||||
|
_FLOATING_DOMAIN_CLASSIFIER_SYSTEM = (
|
||||||
|
"You are a strict domain classifier for websocket floating requests. "
|
||||||
|
"Return ONLY a JSON object with keys: type, id, section. "
|
||||||
|
"Allowed type values: task, timeline, project, node. "
|
||||||
|
"Allowed section values: task, timeline, note, or null. "
|
||||||
|
"Rules: infer from user message intent first; do not blindly trust scope.type. "
|
||||||
|
"If user asks tasks/timeline/notes for a project, set type=project and section accordingly. "
|
||||||
|
"If project id is unknown but context.resolved_project_id exists, use it as id. "
|
||||||
|
"If id is unknown, use null. "
|
||||||
|
"No markdown, no prose, JSON only."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _as_text(content: Any) -> str:
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
def _candidate_tokens(message: str) -> list[str]:
|
||||||
|
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
|
||||||
|
return [token for token in tokens if len(token) >= 3]
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_project_id_from_message(message: str) -> str | None:
|
||||||
|
"""Resolve likely project UUID from user message using client project list."""
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(action="select", table="projects")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("deep_agent: project resolve select failed: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not isinstance(rows, list) or not rows:
|
||||||
|
return None
|
||||||
|
|
||||||
|
tokens = _candidate_tokens(message)
|
||||||
|
scored: list[tuple[int, dict[str, Any]]] = []
|
||||||
|
for row in rows:
|
||||||
|
if not isinstance(row, dict):
|
||||||
|
continue
|
||||||
|
name = str(row.get("name", "")).lower()
|
||||||
|
score = sum(1 for token in tokens if token in name)
|
||||||
|
if score > 0:
|
||||||
|
scored.append((score, row))
|
||||||
|
|
||||||
|
if not scored:
|
||||||
|
return None
|
||||||
|
|
||||||
|
scored.sort(key=lambda item: item[0], reverse=True)
|
||||||
|
top_score = scored[0][0]
|
||||||
|
top_rows = [row for score, row in scored if score == top_score]
|
||||||
|
if len(top_rows) != 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
project_id = top_rows[0].get("id")
|
||||||
|
return project_id if isinstance(project_id, str) else None
|
||||||
|
|
||||||
|
|
||||||
|
def _needs_project_resolution(message: str) -> bool:
|
||||||
|
lowered = message.lower()
|
||||||
|
return any(keyword in lowered for keyword in ["project", "progetto", "progetti", "whitelist"])
|
||||||
|
|
||||||
|
|
||||||
|
async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
prepared = dict(context)
|
||||||
|
if _needs_project_resolution(message):
|
||||||
|
resolved_project_id = await _resolve_project_id_from_message(message)
|
||||||
|
if resolved_project_id:
|
||||||
|
prepared["resolved_project_id"] = resolved_project_id
|
||||||
|
logger.info("deep_agent: resolved_project_id=%s", resolved_project_id)
|
||||||
|
return prepared
|
||||||
|
|
||||||
|
|
||||||
|
def _all_tools() -> list[Any]:
|
||||||
|
return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
|
||||||
|
|
||||||
|
|
||||||
|
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
|
||||||
|
debug = context.get("_debug")
|
||||||
|
if isinstance(debug, dict):
|
||||||
|
request_id = debug.get("request_id")
|
||||||
|
if isinstance(request_id, str) and request_id:
|
||||||
|
return request_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
sanitized = dict(context)
|
||||||
|
sanitized.pop("_debug", None)
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
_TAG_LINE_RE = re.compile(r"<(task|timeline)>\[[^\]]+\]</\1>")
|
||||||
|
_TIMELINE_DMY_RE = re.compile(r"(?P<d>\d{2})/(?P<m>\d{2})/(?P<y>\d{4})")
|
||||||
|
|
||||||
|
|
||||||
|
def _is_upcoming_timeline_query(message: str) -> bool:
|
||||||
|
lowered = message.lower()
|
||||||
|
has_upcoming = "prossim" in lowered or "upcoming" in lowered or "next" in lowered
|
||||||
|
has_timeline_topic = any(
|
||||||
|
token in lowered
|
||||||
|
for token in ("event", "evento", "eventi", "timeline", "milestone", "scaden")
|
||||||
|
)
|
||||||
|
return has_upcoming and has_timeline_topic
|
||||||
|
|
||||||
|
|
||||||
|
def _timeline_date_in_current_month_or_future(dmy: str) -> bool:
|
||||||
|
match = _TIMELINE_DMY_RE.search(dmy)
|
||||||
|
if not match:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
parsed = date(
|
||||||
|
int(match.group("y")),
|
||||||
|
int(match.group("m")),
|
||||||
|
int(match.group("d")),
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
return True
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
return parsed >= today and parsed.year == today.year and parsed.month == today.month
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_tagged_list_lines(text: str, message: str) -> str:
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
|
||||||
|
upcoming_timeline_only = _is_upcoming_timeline_query(message)
|
||||||
|
output_lines: list[str] = []
|
||||||
|
|
||||||
|
for line in text.splitlines():
|
||||||
|
matches = list(_TAG_LINE_RE.finditer(line))
|
||||||
|
if not matches:
|
||||||
|
output_lines.append(line)
|
||||||
|
continue
|
||||||
|
|
||||||
|
had_non_tag_text = _TAG_LINE_RE.sub("", line).strip(" -\t0123456789.*:)")
|
||||||
|
if not had_non_tag_text and len(matches) == 1:
|
||||||
|
tag_text = matches[0].group(0)
|
||||||
|
if (
|
||||||
|
upcoming_timeline_only
|
||||||
|
and "<timeline>" in tag_text
|
||||||
|
and not _timeline_date_in_current_month_or_future(line)
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
output_lines.append(tag_text)
|
||||||
|
continue
|
||||||
|
|
||||||
|
for match in matches:
|
||||||
|
tag_text = match.group(0)
|
||||||
|
if (
|
||||||
|
upcoming_timeline_only
|
||||||
|
and "<timeline>" in tag_text
|
||||||
|
and not _timeline_date_in_current_month_or_future(line)
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
output_lines.append(tag_text)
|
||||||
|
|
||||||
|
return "\n".join(output_lines)
|
||||||
|
|
||||||
|
|
||||||
|
_GENERIC_TAG_RE = re.compile(r"</?(task|project|note|timeline|chart)>", re.IGNORECASE)
|
||||||
|
_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]")
|
||||||
|
_FLOATING_EMPTY_FALLBACK = "No results found."
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_floating_markup_fragment(text: str) -> str:
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
cleaned = _GENERIC_TAG_RE.sub("", text)
|
||||||
|
return _BRACKETED_ID_RE.sub("", cleaned)
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_floating_markup(text: str) -> str:
|
||||||
|
"""Ensure floating responses stay plain text with no XML-like tag wrappers."""
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
|
||||||
|
cleaned = _strip_floating_markup_fragment(text)
|
||||||
|
lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()]
|
||||||
|
return "\n".join(line for line in lines if line)
|
||||||
|
|
||||||
|
|
||||||
|
def _fallback_from_raw_floating_text(raw_text: str) -> str:
|
||||||
|
fallback = _strip_floating_markup_fragment(raw_text or "")
|
||||||
|
fallback = re.sub(r"[ \t]{2,}", " ", fallback).strip()
|
||||||
|
return fallback or _FLOATING_EMPTY_FALLBACK
|
||||||
|
|
||||||
|
|
||||||
|
class _FloatingStreamSanitizer:
|
||||||
|
"""Streaming sanitizer that removes floating markup without buffering the full answer."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._pending = ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_safe_boundary(text: str) -> tuple[str, str]:
|
||||||
|
boundary = len(text)
|
||||||
|
|
||||||
|
last_lt = text.rfind("<")
|
||||||
|
if last_lt != -1 and ">" not in text[last_lt:]:
|
||||||
|
boundary = min(boundary, last_lt)
|
||||||
|
|
||||||
|
last_lb = text.rfind("[")
|
||||||
|
if last_lb != -1 and "]" not in text[last_lb:]:
|
||||||
|
boundary = min(boundary, last_lb)
|
||||||
|
|
||||||
|
if boundary == len(text):
|
||||||
|
return text, ""
|
||||||
|
return text[:boundary], text[boundary:]
|
||||||
|
|
||||||
|
def feed(self, chunk: str) -> str:
|
||||||
|
combined = f"{self._pending}{chunk}"
|
||||||
|
safe_text, self._pending = self._split_safe_boundary(combined)
|
||||||
|
return _strip_floating_markup_fragment(safe_text)
|
||||||
|
|
||||||
|
def finalize(self) -> str:
|
||||||
|
tail = re.sub(r"<[^>\n]*$", "", self._pending)
|
||||||
|
tail = re.sub(r"\[[^\]\n]*$", "", tail)
|
||||||
|
self._pending = ""
|
||||||
|
return _strip_floating_markup_fragment(tail)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_memory_label(path_or_label: str) -> str:
|
||||||
|
value = path_or_label.strip()
|
||||||
|
if value.startswith("/memories/"):
|
||||||
|
value = value[len("/memories/"):]
|
||||||
|
value = value.strip("/")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||||
|
@tool
|
||||||
|
async def memory_list_blocks() -> str:
|
||||||
|
"""List all core memory blocks currently stored for the user."""
|
||||||
|
logger.info("deep_agent: memory_list_blocks trace=%s user=%s", trace_id or "-", user_id)
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
blocks = await memory.list_core_blocks(user_id)
|
||||||
|
if not blocks:
|
||||||
|
return "No memory blocks found."
|
||||||
|
lines = [f"- {b['label']}: {b['value']}" for b in blocks]
|
||||||
|
return "Memory blocks:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_get(path_or_label: str) -> str:
|
||||||
|
"""Get one memory block by label or /memories/<label> path."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_get trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
value = await memory.get_core_block(user_id, label)
|
||||||
|
if value is None:
|
||||||
|
return f"Memory block '{label}' not found."
|
||||||
|
return f"Memory block '{label}':\n{value}"
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_create(path_or_label: str, value: str) -> str:
|
||||||
|
"""Create or overwrite a memory block value by label or /memories/<label> path."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_create trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.update_core(user_id, label, value, trace_id=trace_id)
|
||||||
|
return f"Memory block '{label}' saved."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_append(path_or_label: str, content: str) -> str:
|
||||||
|
"""Append content to a memory block, creating it if missing."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_append trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.append_core(user_id, label, content)
|
||||||
|
return f"Memory block '{label}' appended."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_replace(path_or_label: str, old_string: str, new_string: str) -> str:
|
||||||
|
"""Replace one exact string in a memory block."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_replace trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
changed = await memory.replace_core(user_id, label, old_string, new_string)
|
||||||
|
if not changed:
|
||||||
|
return f"No replacement made in '{label}' (old string not found)."
|
||||||
|
return f"Memory block '{label}' updated."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_delete(path_or_label: str) -> str:
|
||||||
|
"""Delete a memory block by label or /memories/<label> path."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_delete trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
deleted = await memory.delete_core(user_id, label)
|
||||||
|
if not deleted:
|
||||||
|
return f"Memory block '{label}' not found."
|
||||||
|
return f"Memory block '{label}' deleted."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def archival_memory_insert(content: str) -> str:
|
||||||
|
"""Insert a long-term archival memory entry."""
|
||||||
|
logger.info("deep_agent: archival_memory_insert trace=%s user=%s", trace_id or "-", user_id)
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.insert_archival(user_id, content, source="assistant")
|
||||||
|
return "Archival memory saved."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def archival_memory_search(query: str, top_k: int = 5) -> str:
|
||||||
|
"""Search long-term archival memory by semantic fallback (keyword currently)."""
|
||||||
|
logger.info("deep_agent: archival_memory_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
results = await memory.search_archival(user_id, query, top_k=top_k)
|
||||||
|
if not results:
|
||||||
|
return "No archival memory results found."
|
||||||
|
lines = [f"- {item}" for item in results]
|
||||||
|
return "Archival memory results:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def conversation_search(query: str, top_k: int = 5) -> str:
|
||||||
|
"""Search recall memory from prior episodic conversation summaries."""
|
||||||
|
logger.info("deep_agent: conversation_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
results = await memory.search_recall(user_id, query, top_k=top_k)
|
||||||
|
if not results:
|
||||||
|
return "No recall memory results found."
|
||||||
|
lines = [f"- {item}" for item in results]
|
||||||
|
return "Recall memory results:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
return [
|
||||||
|
memory_list_blocks,
|
||||||
|
memory_get,
|
||||||
|
memory_create,
|
||||||
|
memory_append,
|
||||||
|
memory_replace,
|
||||||
|
memory_delete,
|
||||||
|
archival_memory_insert,
|
||||||
|
archival_memory_search,
|
||||||
|
conversation_search,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
|
||||||
|
return [*_all_tools(), *_memory_tools(user_id, trace_id)]
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_domain_section(message: str) -> FloatingDomainSection | None:
|
||||||
|
lowered = message.lower()
|
||||||
|
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
|
||||||
|
return "timeline"
|
||||||
|
if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]):
|
||||||
|
return "task"
|
||||||
|
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
|
||||||
|
return "note"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]:
|
||||||
|
type_raw = str(payload.get("type") or "").strip().lower()
|
||||||
|
domain_type: FloatingDomainType = "task"
|
||||||
|
if type_raw in {"task", "timeline", "project", "node"}:
|
||||||
|
domain_type = type_raw
|
||||||
|
|
||||||
|
id_value = payload.get("id")
|
||||||
|
domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None
|
||||||
|
if domain_type == "project" and not domain_id:
|
||||||
|
domain_id = fallback_id
|
||||||
|
|
||||||
|
section_raw = payload.get("section")
|
||||||
|
section: FloatingDomainSection | None = None
|
||||||
|
if isinstance(section_raw, str):
|
||||||
|
section_candidate = section_raw.strip().lower()
|
||||||
|
if section_candidate in {"task", "timeline", "note"}:
|
||||||
|
section = section_candidate
|
||||||
|
|
||||||
|
if domain_type != "project":
|
||||||
|
section = None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": domain_type,
|
||||||
|
"id": domain_id,
|
||||||
|
"section": section,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_json_object(text: str) -> dict[str, Any] | None:
|
||||||
|
raw = text.strip()
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = json.loads(raw)
|
||||||
|
return parsed if isinstance(parsed, dict) else None
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
||||||
|
if not match:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = json.loads(match.group(0))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
return parsed if isinstance(parsed, dict) else None
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]:
|
||||||
|
section = _detect_domain_section(message)
|
||||||
|
scope = context.get("scope") if isinstance(context, dict) else None
|
||||||
|
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
|
||||||
|
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
|
||||||
|
|
||||||
|
if isinstance(scope, dict):
|
||||||
|
scope_type = str(scope.get("type") or "").strip().lower()
|
||||||
|
scope_id = scope.get("id")
|
||||||
|
scope_id_value = scope_id if isinstance(scope_id, str) and scope_id else None
|
||||||
|
|
||||||
|
if scope_type in {"task", "tasks"}:
|
||||||
|
return {"type": "task", "id": scope_id_value, "section": None}
|
||||||
|
if scope_type in {"project", "projects"}:
|
||||||
|
project_scope_id = scope_id_value or project_id
|
||||||
|
return {
|
||||||
|
"type": "project",
|
||||||
|
"id": project_scope_id,
|
||||||
|
"section": section,
|
||||||
|
}
|
||||||
|
if scope_type in {"note", "notes"}:
|
||||||
|
return {
|
||||||
|
"type": "node",
|
||||||
|
"id": scope_id_value,
|
||||||
|
"section": None,
|
||||||
|
}
|
||||||
|
if scope_type in {"timeline", "timelines"}:
|
||||||
|
return {"type": "timeline", "id": scope_id_value, "section": None}
|
||||||
|
|
||||||
|
lowered = message.lower()
|
||||||
|
if any(keyword in lowered for keyword in ["project", "progetto", "client"]) or project_id:
|
||||||
|
return {
|
||||||
|
"type": "project",
|
||||||
|
"id": project_id,
|
||||||
|
"section": section,
|
||||||
|
}
|
||||||
|
if section == "timeline":
|
||||||
|
return {"type": "timeline", "id": None, "section": None}
|
||||||
|
if section == "note":
|
||||||
|
return {"type": "node", "id": None, "section": None}
|
||||||
|
return {"type": "task", "id": None, "section": None}
|
||||||
|
|
||||||
|
|
||||||
|
async def _infer_floating_domain(
|
||||||
|
message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None,
|
||||||
|
) -> dict[str, str | None]:
|
||||||
|
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
|
||||||
|
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
|
||||||
|
|
||||||
|
classifier_context = {
|
||||||
|
"scope": context.get("scope") if isinstance(context.get("scope"), dict) else None,
|
||||||
|
"resolved_project_id": project_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
classifier_prompt = _get_system_prompt(
|
||||||
|
"floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_SYSTEM,
|
||||||
|
)
|
||||||
|
callbacks = _build_callbacks(langfuse_handler)
|
||||||
|
llm = get_llm(callbacks=callbacks)
|
||||||
|
response = await llm.ainvoke(
|
||||||
|
[
|
||||||
|
SystemMessage(content=classifier_prompt),
|
||||||
|
HumanMessage(
|
||||||
|
content=(
|
||||||
|
f"Message:\n{message}\n\n"
|
||||||
|
f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
parsed = _parse_json_object(_as_text(response.content))
|
||||||
|
if parsed is not None:
|
||||||
|
domain = _normalize_domain_payload(parsed, project_id)
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: floating_domain_classified type=%s id=%s section=%s",
|
||||||
|
domain.get("type"),
|
||||||
|
domain.get("id"),
|
||||||
|
domain.get("section"),
|
||||||
|
)
|
||||||
|
return domain
|
||||||
|
logger.warning("deep_agent: floating_domain classifier returned non-json output")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("deep_agent: floating_domain classifier failed: %s", exc)
|
||||||
|
|
||||||
|
return _infer_floating_domain_rule_based(message, context)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_system_prompt(langfuse_name: str, fallback: str) -> str:
|
||||||
|
"""Fetch a managed prompt from Langfuse, falling back to the hardcoded string."""
|
||||||
|
managed = tracing.get_prompt(langfuse_name, fallback=None)
|
||||||
|
return managed if managed is not None else fallback
|
||||||
|
|
||||||
|
|
||||||
|
def _build_callbacks(langfuse_handler: Any | None) -> list[Any] | None:
|
||||||
|
"""Return a callbacks list if a Langfuse handler is available."""
|
||||||
|
if langfuse_handler is None:
|
||||||
|
return None
|
||||||
|
return [langfuse_handler]
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_single_agent(
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
system_prompt: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
max_steps: int = 6,
|
||||||
|
langfuse_handler: Any | None = None,
|
||||||
|
) -> str:
|
||||||
|
trace_id = _trace_id_from_context(context)
|
||||||
|
callbacks = _build_callbacks(langfuse_handler)
|
||||||
|
llm = get_llm(callbacks=callbacks)
|
||||||
|
tools = _all_tools_for_user(user_id, trace_id)
|
||||||
|
model_context = _context_for_model(context)
|
||||||
|
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
messages: list[Any] = [
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(
|
||||||
|
content=(
|
||||||
|
f"User message:\n{message}\n\n"
|
||||||
|
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
tool_calls_count = 0
|
||||||
|
collected: list[dict[str, Any]] = []
|
||||||
|
set_tool_result_collector(collected)
|
||||||
|
try:
|
||||||
|
for _ in range(max_steps):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
final_text = _as_text(response.content)
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
tool_calls_count,
|
||||||
|
len(final_text),
|
||||||
|
)
|
||||||
|
return final_text
|
||||||
|
|
||||||
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
for call in response.tool_calls:
|
||||||
|
tool_calls_count += 1
|
||||||
|
call_id = str(call.get("id", ""))
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
|
||||||
|
call_id,
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
|
||||||
|
call_id,
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:1200],
|
||||||
|
)
|
||||||
|
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
final_text = _as_text(final.content)
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
tool_calls_count,
|
||||||
|
len(final_text),
|
||||||
|
)
|
||||||
|
return final_text
|
||||||
|
finally:
|
||||||
|
clear_tool_result_collector()
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_single_agent_stream(
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
system_prompt: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
max_steps: int = 6,
|
||||||
|
langfuse_handler: Any | None = None,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
trace_id = _trace_id_from_context(context)
|
||||||
|
callbacks = _build_callbacks(langfuse_handler)
|
||||||
|
llm = get_llm(callbacks=callbacks)
|
||||||
|
tools = _all_tools_for_user(user_id, trace_id)
|
||||||
|
model_context = _context_for_model(context)
|
||||||
|
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
messages: list[Any] = [
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(
|
||||||
|
content=(
|
||||||
|
f"User message:\n{message}\n\n"
|
||||||
|
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
tool_calls_count = 0
|
||||||
|
streamed_chars = 0
|
||||||
|
collected: list[dict[str, Any]] = []
|
||||||
|
set_tool_result_collector(collected)
|
||||||
|
try:
|
||||||
|
for _ in range(max_steps):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
emitted_any = False
|
||||||
|
async for chunk in llm.astream(messages):
|
||||||
|
token = _as_text(getattr(chunk, "content", ""))
|
||||||
|
if token:
|
||||||
|
streamed_chars += len(token)
|
||||||
|
emitted_any = True
|
||||||
|
yield "token", token
|
||||||
|
|
||||||
|
if not emitted_any:
|
||||||
|
fallback_text = _as_text(response.content)
|
||||||
|
if fallback_text:
|
||||||
|
streamed_chars += len(fallback_text)
|
||||||
|
yield "token", fallback_text
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
tool_calls_count,
|
||||||
|
streamed_chars,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
for call in response.tool_calls:
|
||||||
|
tool_calls_count += 1
|
||||||
|
call_id = str(call.get("id", ""))
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
|
||||||
|
call_id,
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
|
||||||
|
call_id,
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:1200],
|
||||||
|
)
|
||||||
|
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
async for chunk in llm.astream(messages):
|
||||||
|
token = _as_text(getattr(chunk, "content", ""))
|
||||||
|
if token:
|
||||||
|
streamed_chars += len(token)
|
||||||
|
yield "token", token
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
tool_calls_count,
|
||||||
|
streamed_chars,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_tool_result_collector()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_home(user_id: str, message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None) -> str:
|
||||||
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
system_prompt = _get_system_prompt("home_system", _HOME_SINGLE_AGENT_SYSTEM)
|
||||||
|
response = await _run_single_agent(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
message=message,
|
||||||
|
context=prepared_context,
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
)
|
||||||
|
return _normalize_tagged_list_lines(response, message)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_floating(user_id: str, message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None) -> tuple[str, dict[str, str | None]]:
|
||||||
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
domain = await _infer_floating_domain(message, prepared_context, langfuse_handler=langfuse_handler)
|
||||||
|
system_prompt = _get_system_prompt("floating_system", _FLOATING_SINGLE_AGENT_SYSTEM)
|
||||||
|
response = await _run_single_agent(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
message=message,
|
||||||
|
context=prepared_context,
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
)
|
||||||
|
sanitized = _strip_floating_markup(response)
|
||||||
|
if not sanitized and response:
|
||||||
|
sanitized = _fallback_from_raw_floating_text(response)
|
||||||
|
return sanitized, domain
|
||||||
|
|
||||||
|
|
||||||
|
async def run_home_stream(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
*,
|
||||||
|
langfuse_handler: Any | None = None,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
system_prompt = _get_system_prompt("home_system", _HOME_SINGLE_AGENT_SYSTEM)
|
||||||
|
text_chunks: list[str] = []
|
||||||
|
async for event in _run_single_agent_stream(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
message=message,
|
||||||
|
context=prepared_context,
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
):
|
||||||
|
event_type, data = event
|
||||||
|
if event_type != "token":
|
||||||
|
yield event
|
||||||
|
continue
|
||||||
|
text_chunks.append(str(data or ""))
|
||||||
|
|
||||||
|
normalized = _normalize_tagged_list_lines("".join(text_chunks), message)
|
||||||
|
if normalized:
|
||||||
|
yield "token", normalized
|
||||||
|
|
||||||
|
|
||||||
|
async def run_floating_stream(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
*,
|
||||||
|
langfuse_handler: Any | None = None,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
domain = await _infer_floating_domain(message, prepared_context, langfuse_handler=langfuse_handler)
|
||||||
|
yield "floating_domain", domain
|
||||||
|
|
||||||
|
system_prompt = _get_system_prompt("floating_system", _FLOATING_SINGLE_AGENT_SYSTEM)
|
||||||
|
sanitizer = _FloatingStreamSanitizer()
|
||||||
|
emitted_sanitized = False
|
||||||
|
raw_chunks: list[str] = []
|
||||||
|
async for event in _run_single_agent_stream(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
message=message,
|
||||||
|
context=prepared_context,
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
):
|
||||||
|
event_type, data = event
|
||||||
|
if event_type != "token":
|
||||||
|
yield event
|
||||||
|
continue
|
||||||
|
|
||||||
|
raw_chunk = str(data or "")
|
||||||
|
raw_chunks.append(raw_chunk)
|
||||||
|
sanitized_chunk = sanitizer.feed(raw_chunk)
|
||||||
|
if sanitized_chunk:
|
||||||
|
emitted_sanitized = True
|
||||||
|
yield "token", sanitized_chunk
|
||||||
|
|
||||||
|
tail = sanitizer.finalize()
|
||||||
|
if tail:
|
||||||
|
emitted_sanitized = True
|
||||||
|
yield "token", tail
|
||||||
|
|
||||||
|
if not emitted_sanitized and raw_chunks:
|
||||||
|
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
|
||||||
|
|
||||||
|
|
||||||
|
async def update_core_memory(user_id: str, key: str, value: str) -> None:
|
||||||
|
"""Compatibility helper kept for callers that expect explicit memory update API."""
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.update_core(user_id, key, value)
|
||||||
77
services/chat/app/llm.py
Normal file
77
services/chat/app/llm.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||||
|
|
||||||
|
Adapted from app/core/llm.py for the Chat Service.
|
||||||
|
Uses shared.config.settings instead of app.config.settings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain_litellm import ChatLiteLLM
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
litellm.drop_params = True
|
||||||
|
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
||||||
|
category=UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
|
if model.startswith("anthropic/"):
|
||||||
|
return settings.ANTHROPIC_API_KEY or None
|
||||||
|
if model.startswith("gemini/") or model.startswith("google/"):
|
||||||
|
return settings.GOOGLE_API_KEY or None
|
||||||
|
if model.startswith("cerebras/"):
|
||||||
|
return settings.CEREBRAS_API_KEY or None
|
||||||
|
if model.startswith("github/"):
|
||||||
|
return settings.GITHUB_TOKEN or None
|
||||||
|
if model.startswith("github_copilot/"):
|
||||||
|
return None
|
||||||
|
return settings.OPENAI_API_KEY or None
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm(
|
||||||
|
*,
|
||||||
|
model: str | None = None,
|
||||||
|
temperature: float = 0,
|
||||||
|
callbacks: list | None = None,
|
||||||
|
) -> ChatOpenAI | ChatLiteLLM:
|
||||||
|
model = model or settings.LLM_MODEL
|
||||||
|
|
||||||
|
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
||||||
|
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
||||||
|
|
||||||
|
if settings.GITHUB_TOKEN:
|
||||||
|
os.environ.setdefault("GITHUB_TOKEN", settings.GITHUB_TOKEN)
|
||||||
|
|
||||||
|
if "/" in model:
|
||||||
|
return ChatLiteLLM(model=model, temperature=temperature, callbacks=callbacks)
|
||||||
|
|
||||||
|
return ChatOpenAI(
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
api_key=_api_key_for_model(model),
|
||||||
|
callbacks=callbacks,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def embed(text: str) -> list[float]:
|
||||||
|
model = settings.LLM_EMBED_MODEL
|
||||||
|
|
||||||
|
if model.startswith("github_copilot/") or "/" in model:
|
||||||
|
response = await litellm.aembedding(model=model, input=[text])
|
||||||
|
return response.data[0]["embedding"]
|
||||||
|
|
||||||
|
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
||||||
|
response = await client.embeddings.create(model=model, input=text)
|
||||||
|
return response.data[0].embedding
|
||||||
87
services/chat/app/main.py
Normal file
87
services/chat/app/main.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
"""Chat Service — LLM orchestration, domain agents, memory.
|
||||||
|
|
||||||
|
Consumes chat requests from Redis, executes deep_agent (home/floating),
|
||||||
|
streams responses back via Redis pub/sub to WS Gateway.
|
||||||
|
|
||||||
|
Owns: memory_core, memory_associative, memory_episodic, memory_proactive tables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Ensure the repo root is on sys.path so "shared" is importable in local dev.
|
||||||
|
_repo_root = str(Path(__file__).resolve().parents[3])
|
||||||
|
if _repo_root not in sys.path:
|
||||||
|
sys.path.insert(0, _repo_root)
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
# Initialise Langfuse tracing (no-op if keys are missing)
|
||||||
|
from app.tracing import init_langfuse
|
||||||
|
|
||||||
|
init_langfuse()
|
||||||
|
|
||||||
|
# Start Redis consumer in background
|
||||||
|
from app.redis_consumer import start_consumer
|
||||||
|
|
||||||
|
consumer_task = start_consumer()
|
||||||
|
yield
|
||||||
|
consumer_task.cancel()
|
||||||
|
|
||||||
|
from app.tracing import shutdown as shutdown_langfuse
|
||||||
|
|
||||||
|
shutdown_langfuse()
|
||||||
|
|
||||||
|
from shared.db import engine
|
||||||
|
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
from shared.redis import redis_client
|
||||||
|
|
||||||
|
await redis_client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
app = FastAPI(
|
||||||
|
title="Adiuva Chat Service",
|
||||||
|
version="0.1.0",
|
||||||
|
docs_url="/docs" if settings.ENV == "dev" else None,
|
||||||
|
redoc_url=None,
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=settings.CORS_ORIGINS,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.routes import router
|
||||||
|
|
||||||
|
app.include_router(router, prefix="/api/v1")
|
||||||
|
|
||||||
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
|
async def health() -> dict:
|
||||||
|
return {"status": "ok", "service": "chat", "version": app.version}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
295
services/chat/app/memory_middleware.py
Normal file
295
services/chat/app/memory_middleware.py
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
"""Memory Middleware — adapted for Chat Service.
|
||||||
|
|
||||||
|
Uses shared.models instead of app.models. Otherwise identical to the
|
||||||
|
monolith's app/core/memory_middleware.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from shared.models import (
|
||||||
|
MemoryAssociative,
|
||||||
|
MemoryCore,
|
||||||
|
MemoryEpisodic,
|
||||||
|
MemoryProactive,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_ASSOCIATIVE_TOP_K = 5
|
||||||
|
_EPISODIC_RECENT_N = 10
|
||||||
|
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryMiddleware:
|
||||||
|
|
||||||
|
def __init__(self, db: AsyncSession) -> None:
|
||||||
|
self._db = db
|
||||||
|
|
||||||
|
async def enrich_context(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
trace_id: str | None = None,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
core = await self._load_core(user_id, fernet)
|
||||||
|
associative = await self._load_associative(user_id, message, fernet)
|
||||||
|
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
||||||
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"memory: enrich_context trace=%s user=%s core=%d assoc=%d episodic=%d proactive=%d",
|
||||||
|
trace_id or "-", user_id, len(core), len(associative), len(episodic), len(proactive),
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"core_memory": core,
|
||||||
|
"associative_memory": associative,
|
||||||
|
"episodic_memory": episodic,
|
||||||
|
"proactive_hints": proactive,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def store_episode(
|
||||||
|
self, user_id: str, session_id: str, message: str, response: str,
|
||||||
|
trace_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
||||||
|
encrypted = _encrypt(fernet, summary)
|
||||||
|
|
||||||
|
row = MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
summary_encrypted=encrypted,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
self._db.add(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
encrypted = _encrypt(fernet, value)
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == key)
|
||||||
|
)
|
||||||
|
existing = result.scalar_one_or_none()
|
||||||
|
if existing is not None:
|
||||||
|
existing.value_encrypted = encrypted
|
||||||
|
else:
|
||||||
|
self._db.add(MemoryCore(
|
||||||
|
id=str(uuid.uuid4()), user_id=user_id, key=key, value_encrypted=encrypted,
|
||||||
|
))
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == user_id).order_by(MemoryCore.key.asc())
|
||||||
|
)
|
||||||
|
out: list[dict[str, str]] = []
|
||||||
|
for row in result.scalars().all():
|
||||||
|
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append({"label": row.key, "value": plaintext})
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def get_core_block(self, user_id: str, label: str) -> str | None:
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return None
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == label)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
|
||||||
|
async def delete_core(self, user_id: str, label: str) -> bool:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == label)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
return False
|
||||||
|
await self._db.delete(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
return True
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def append_core(self, user_id: str, label: str, content: str) -> None:
|
||||||
|
current = await self.get_core_block(user_id, label)
|
||||||
|
if current is None:
|
||||||
|
await self.update_core(user_id, label, content)
|
||||||
|
return
|
||||||
|
await self.update_core(user_id, label, f"{current}\n{content}")
|
||||||
|
|
||||||
|
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
|
||||||
|
current = await self.get_core_block(user_id, label)
|
||||||
|
if current is None or old not in current:
|
||||||
|
return False
|
||||||
|
await self.update_core(user_id, label, current.replace(old, new, 1))
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
encrypted = _encrypt(fernet, content)
|
||||||
|
row = MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()), user_id=user_id,
|
||||||
|
content_encrypted=encrypted, embedding=None,
|
||||||
|
entity_type=source, entity_id=None,
|
||||||
|
)
|
||||||
|
self._db.add(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryAssociative).where(MemoryAssociative.user_id == user_id)
|
||||||
|
.order_by(MemoryAssociative.updated_at.desc()).limit(100)
|
||||||
|
)
|
||||||
|
needle = query.strip().lower()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in result.scalars().all():
|
||||||
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
|
if plaintext is None:
|
||||||
|
continue
|
||||||
|
if not needle or needle in plaintext.lower():
|
||||||
|
out.append(plaintext)
|
||||||
|
if len(out) >= max(top_k, 1):
|
||||||
|
break
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
|
||||||
|
.order_by(MemoryEpisodic.created_at.desc()).limit(100)
|
||||||
|
)
|
||||||
|
needle = query.strip().lower()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in result.scalars().all():
|
||||||
|
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||||
|
if plaintext is None:
|
||||||
|
continue
|
||||||
|
if not needle or needle in plaintext.lower():
|
||||||
|
out.append(plaintext)
|
||||||
|
if len(out) >= max(top_k, 1):
|
||||||
|
break
|
||||||
|
return out
|
||||||
|
|
||||||
|
# ── Private ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
||||||
|
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not user.encryption_key:
|
||||||
|
logger.warning("memory: no encryption_key for user=%s", user_id)
|
||||||
|
return None
|
||||||
|
return Fernet(user.encryption_key.encode())
|
||||||
|
|
||||||
|
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
||||||
|
)
|
||||||
|
out: dict[str, str] = {}
|
||||||
|
for row in result.scalars().all():
|
||||||
|
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out[row.key] = plaintext
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _load_associative(self, user_id: str, message: str, fernet: Fernet) -> list[str]:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryAssociative).where(MemoryAssociative.user_id == user_id)
|
||||||
|
.order_by(MemoryAssociative.updated_at.desc()).limit(_ASSOCIATIVE_TOP_K)
|
||||||
|
)
|
||||||
|
out: list[str] = []
|
||||||
|
for row in result.scalars().all():
|
||||||
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _load_episodic(self, user_id: str, fernet: Fernet, session_id: str | None = None) -> list[str]:
|
||||||
|
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
|
||||||
|
if session_id:
|
||||||
|
query = query.where(MemoryEpisodic.session_id == session_id)
|
||||||
|
result = await self._db.execute(
|
||||||
|
query.order_by(MemoryEpisodic.created_at.desc()).limit(_EPISODIC_RECENT_N)
|
||||||
|
)
|
||||||
|
out: list[str] = []
|
||||||
|
for row in result.scalars().all():
|
||||||
|
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryProactive).where(
|
||||||
|
MemoryProactive.user_id == user_id,
|
||||||
|
MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
|
||||||
|
).order_by(MemoryProactive.confidence.desc())
|
||||||
|
)
|
||||||
|
out: list[str] = []
|
||||||
|
for row in result.scalars().all():
|
||||||
|
plaintext = _safe_decrypt(fernet, row.pattern_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _encrypt(fernet: Fernet, plaintext: str) -> str:
|
||||||
|
return fernet.encrypt(plaintext.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
|
||||||
|
try:
|
||||||
|
return fernet.decrypt(ciphertext.encode()).decode()
|
||||||
|
except (InvalidToken, Exception) as exc:
|
||||||
|
logger.warning("memory: decrypt failed: %s", exc)
|
||||||
|
return None
|
||||||
50
services/chat/app/output_formatter.py
Normal file
50
services/chat/app/output_formatter.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""Output formatter for deep-agent stream events — Chat Service copy.
|
||||||
|
|
||||||
|
Converts (event_type, data) tuples into WebSocket frame Pydantic models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from shared.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
||||||
|
|
||||||
|
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||||
|
|
||||||
|
|
||||||
|
class StreamFormatter:
|
||||||
|
"""Convert `(event_type, data)` stream events into websocket frame models."""
|
||||||
|
|
||||||
|
def __init__(self, request_id: str) -> None:
|
||||||
|
self.request_id = request_id
|
||||||
|
|
||||||
|
async def format(
|
||||||
|
self,
|
||||||
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
|
started = False
|
||||||
|
|
||||||
|
async for event_type, data in event_stream:
|
||||||
|
if event_type == "floating_domain":
|
||||||
|
if isinstance(data, dict):
|
||||||
|
yield WsFloatingDomain(
|
||||||
|
request_id=self.request_id,
|
||||||
|
domain=data,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if event_type != "token":
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not started:
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
started = True
|
||||||
|
|
||||||
|
text = str(data or "")
|
||||||
|
if text:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=text)
|
||||||
|
|
||||||
|
if not started:
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
yield WsStreamEnd(request_id=self.request_id)
|
||||||
209
services/chat/app/redis_consumer.py
Normal file
209
services/chat/app/redis_consumer.py
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
"""Redis consumer — listens for chat requests and dispatches to deep_agent.
|
||||||
|
|
||||||
|
Subscribes to a Redis pattern channel chat:request:* so it receives
|
||||||
|
requests for ALL users. Each request is processed in a separate asyncio task.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from shared.db import async_session
|
||||||
|
from shared.redis import redis_client, ws_out_channel
|
||||||
|
|
||||||
|
from app.deep_agent import run_floating_stream, run_home_stream
|
||||||
|
from app.memory_middleware import MemoryMiddleware
|
||||||
|
from app.output_formatter import StreamFormatter
|
||||||
|
from shared.ws_context import clear_current_user, set_current_user
|
||||||
|
from app import tracing
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def start_consumer() -> asyncio.Task:
|
||||||
|
"""Start the Redis consumer as a background asyncio task."""
|
||||||
|
return asyncio.create_task(_consumer_loop())
|
||||||
|
|
||||||
|
|
||||||
|
async def _consumer_loop() -> None:
|
||||||
|
"""Subscribe to chat:request:* and dispatch incoming frames."""
|
||||||
|
pubsub = redis_client.pubsub()
|
||||||
|
await pubsub.psubscribe("chat:request:*")
|
||||||
|
logger.info("redis_consumer: subscribed to chat:request:*")
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
message = await pubsub.get_message(
|
||||||
|
ignore_subscribe_messages=True, timeout=1.0
|
||||||
|
)
|
||||||
|
if message is not None and message["type"] == "pmessage":
|
||||||
|
frame = json.loads(message["data"])
|
||||||
|
asyncio.create_task(_dispatch(frame))
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("redis_consumer: shutting down")
|
||||||
|
finally:
|
||||||
|
await pubsub.punsubscribe()
|
||||||
|
await pubsub.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
async def _dispatch(frame: dict) -> None:
|
||||||
|
"""Route a chat request frame to the appropriate handler."""
|
||||||
|
frame_type = frame.get("type")
|
||||||
|
user_id = frame.get("user_id")
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
logger.warning("redis_consumer: frame missing user_id: %s", frame.get("type"))
|
||||||
|
return
|
||||||
|
|
||||||
|
if frame_type == "home_request":
|
||||||
|
await _handle_home_request(user_id, frame)
|
||||||
|
elif frame_type == "floating_request":
|
||||||
|
await _handle_floating_request(user_id, frame)
|
||||||
|
else:
|
||||||
|
logger.debug("redis_consumer: unknown frame type %r", frame_type)
|
||||||
|
|
||||||
|
|
||||||
|
async def _publish_frame(user_id: str, frame_data: str) -> None:
|
||||||
|
"""Publish a frame to ws:out:{user_id} for the WS Gateway to forward."""
|
||||||
|
channel = ws_out_channel(user_id)
|
||||||
|
await redis_client.publish(channel, frame_data)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_home_request(user_id: str, frame: dict) -> None:
|
||||||
|
"""Process a home_request — enrich with memory, run deep_agent, stream results."""
|
||||||
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
|
message: str = frame.get("message", "")
|
||||||
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"redis_consumer: home_request user=%s req=%s msg=%s",
|
||||||
|
user_id, request_id, message[:200],
|
||||||
|
)
|
||||||
|
|
||||||
|
response_chunks: list[str] = []
|
||||||
|
|
||||||
|
with tracing.trace_span(
|
||||||
|
name="home_request",
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
trace_id=request_id,
|
||||||
|
input=message,
|
||||||
|
metadata={"message_preview": message[:200]},
|
||||||
|
tags=["home"],
|
||||||
|
) as span:
|
||||||
|
langfuse_handler = tracing.get_langfuse_callback()
|
||||||
|
|
||||||
|
# Enrich with memory context
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id, message,
|
||||||
|
trace_id=request_id, session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
context: dict = {
|
||||||
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
set_current_user(user_id)
|
||||||
|
try:
|
||||||
|
event_stream = run_home_stream(user_id, message, context, langfuse_handler=langfuse_handler)
|
||||||
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
|
async for ws_frame in formatter.format(event_stream):
|
||||||
|
await _publish_frame(user_id, ws_frame.model_dump_json())
|
||||||
|
if hasattr(ws_frame, "chunk"):
|
||||||
|
response_chunks.append(ws_frame.chunk)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("redis_consumer: home_request failed user=%s req=%s: %s", user_id, request_id, exc)
|
||||||
|
finally:
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
# Link prompt and attach output preview
|
||||||
|
tracing.link_prompt_to_trace(span, "home_system")
|
||||||
|
response_text = "".join(response_chunks)
|
||||||
|
span.update(output=response_text[:500] if response_text else None)
|
||||||
|
|
||||||
|
tracing.flush()
|
||||||
|
|
||||||
|
# Store episode
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.store_episode(
|
||||||
|
user_id, session_id, message, "".join(response_chunks),
|
||||||
|
trace_id=request_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_floating_request(user_id: str, frame: dict) -> None:
|
||||||
|
"""Process a floating_request — enrich with memory, run deep_agent, stream results."""
|
||||||
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
|
message: str = frame.get("message", "")
|
||||||
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
scope: dict = frame.get("scope", {})
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"redis_consumer: floating_request user=%s req=%s scope=%s msg=%s",
|
||||||
|
user_id, request_id, json.dumps(scope)[:200], message[:200],
|
||||||
|
)
|
||||||
|
|
||||||
|
response_chunks: list[str] = []
|
||||||
|
|
||||||
|
with tracing.trace_span(
|
||||||
|
name="floating_request",
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
trace_id=request_id,
|
||||||
|
input=message,
|
||||||
|
metadata={"message_preview": message[:200], "scope": scope},
|
||||||
|
tags=["floating"],
|
||||||
|
) as span:
|
||||||
|
langfuse_handler = tracing.get_langfuse_callback()
|
||||||
|
|
||||||
|
# Enrich with memory context
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id, message,
|
||||||
|
trace_id=request_id, session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
context: dict = {
|
||||||
|
"scope": scope,
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
set_current_user(user_id)
|
||||||
|
try:
|
||||||
|
event_stream = run_floating_stream(user_id, message, context, langfuse_handler=langfuse_handler)
|
||||||
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
|
async for ws_frame in formatter.format(event_stream):
|
||||||
|
await _publish_frame(user_id, ws_frame.model_dump_json())
|
||||||
|
if hasattr(ws_frame, "chunk"):
|
||||||
|
response_chunks.append(ws_frame.chunk)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("redis_consumer: floating_request failed user=%s req=%s: %s", user_id, request_id, exc)
|
||||||
|
finally:
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
# Link prompt and attach output preview
|
||||||
|
tracing.link_prompt_to_trace(span, "floating_system")
|
||||||
|
response_text = "".join(response_chunks)
|
||||||
|
span.update(output=response_text[:500] if response_text else None)
|
||||||
|
|
||||||
|
tracing.flush()
|
||||||
|
|
||||||
|
# Store episode
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.store_episode(
|
||||||
|
user_id, session_id, message, "".join(response_chunks),
|
||||||
|
trace_id=request_id,
|
||||||
|
)
|
||||||
37
services/chat/app/routes.py
Normal file
37
services/chat/app/routes.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""Chat REST route — POST /chat fallback when WS is unavailable."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from shared.schemas import ChatRequest
|
||||||
|
|
||||||
|
from app.deep_agent import run_home
|
||||||
|
from shared.ws_context import clear_current_user, set_current_user
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("")
|
||||||
|
async def chat(body: ChatRequest, request: Request) -> JSONResponse:
|
||||||
|
"""REST fallback for home chat.
|
||||||
|
|
||||||
|
In the microservices setup, Traefik ForwardAuth has already validated
|
||||||
|
the JWT and injected X-User-Id / X-User-Email / X-User-Tier headers.
|
||||||
|
"""
|
||||||
|
user_id = request.headers.get("X-User-Id", "")
|
||||||
|
if not user_id:
|
||||||
|
return JSONResponse(status_code=401, content={"detail": "Missing X-User-Id header"})
|
||||||
|
|
||||||
|
set_current_user(user_id)
|
||||||
|
try:
|
||||||
|
response = await run_home(
|
||||||
|
user_id=user_id,
|
||||||
|
message=body.message,
|
||||||
|
context=body.context.model_dump(),
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
return JSONResponse(content={"response": response})
|
||||||
304
services/chat/app/tracing.py
Normal file
304
services/chat/app/tracing.py
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
"""Langfuse tracing & prompt management for the Chat Service (v4 SDK).
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
- ``init_langfuse()`` — initialise the singleton client at startup
|
||||||
|
- ``trace_span()`` — context manager that creates a trace + span
|
||||||
|
- ``get_langfuse_callback()`` — LangChain callback handler (auto-inherits trace)
|
||||||
|
- ``get_prompt()`` — fetch a managed prompt from Langfuse by name
|
||||||
|
- ``flush()`` / ``shutdown()`` — lifecycle management
|
||||||
|
|
||||||
|
All functions gracefully degrade to no-ops when Langfuse is not configured,
|
||||||
|
so the service works identically with or without observability keys.
|
||||||
|
|
||||||
|
Requires ``langfuse >= 3.0.0`` (v4 / "Fast Preview" SDK).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── State ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_initialised: bool = False
|
||||||
|
_disabled: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_configured() -> bool:
|
||||||
|
return bool(settings.LANGFUSE_SECRET_KEY and settings.LANGFUSE_PUBLIC_KEY)
|
||||||
|
|
||||||
|
|
||||||
|
def init_langfuse() -> None:
|
||||||
|
"""Initialise the Langfuse singleton. Call once at startup."""
|
||||||
|
global _initialised, _disabled
|
||||||
|
|
||||||
|
if _initialised or _disabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not _is_configured():
|
||||||
|
_disabled = True
|
||||||
|
logger.info("tracing: Langfuse keys not set — tracing disabled")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse import Langfuse
|
||||||
|
|
||||||
|
Langfuse(
|
||||||
|
secret_key=settings.LANGFUSE_SECRET_KEY,
|
||||||
|
public_key=settings.LANGFUSE_PUBLIC_KEY,
|
||||||
|
host=settings.LANGFUSE_HOST,
|
||||||
|
)
|
||||||
|
_initialised = True
|
||||||
|
logger.info("tracing: Langfuse client initialised (host=%s)", settings.LANGFUSE_HOST)
|
||||||
|
except Exception as exc:
|
||||||
|
_disabled = True
|
||||||
|
logger.warning("tracing: failed to initialise Langfuse: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client() -> Any | None:
|
||||||
|
"""Return the singleton Langfuse client, or *None* if disabled."""
|
||||||
|
if _disabled:
|
||||||
|
return None
|
||||||
|
if not _initialised:
|
||||||
|
init_langfuse()
|
||||||
|
if _disabled:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
from langfuse import get_client
|
||||||
|
return get_client()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Null span (no-op when Langfuse is disabled) ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _NullSpan:
|
||||||
|
"""Drop-in replacement when Langfuse is disabled."""
|
||||||
|
|
||||||
|
def update(self, **_: Any) -> None: ...
|
||||||
|
def set_trace_io(self, **_: Any) -> None: ...
|
||||||
|
def score_trace(self, **_: Any) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
# ── Trace context manager ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def trace_span(
|
||||||
|
*,
|
||||||
|
name: str,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str | None = None,
|
||||||
|
trace_id: str | None = None,
|
||||||
|
input: Any = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
):
|
||||||
|
"""Context manager that creates a Langfuse trace/span.
|
||||||
|
|
||||||
|
Yields the span object (or a ``_NullSpan`` if Langfuse is disabled).
|
||||||
|
A ``CallbackHandler`` created inside this block auto-inherits the trace
|
||||||
|
context, so there is no need to pass trace IDs manually.
|
||||||
|
"""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None:
|
||||||
|
yield _NullSpan()
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse import Langfuse, propagate_attributes
|
||||||
|
|
||||||
|
trace_ctx: dict[str, str] = {}
|
||||||
|
if trace_id is not None:
|
||||||
|
trace_ctx["trace_id"] = Langfuse.create_trace_id(seed=trace_id)
|
||||||
|
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="span",
|
||||||
|
name=name,
|
||||||
|
input=input,
|
||||||
|
metadata=metadata or {},
|
||||||
|
**({"trace_context": trace_ctx} if trace_ctx else {}),
|
||||||
|
) as span:
|
||||||
|
with propagate_attributes(
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
tags=tags or [],
|
||||||
|
):
|
||||||
|
yield span
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: trace_span(%s) failed: %s", name, exc)
|
||||||
|
yield _NullSpan()
|
||||||
|
|
||||||
|
|
||||||
|
# ── LangChain callback handler ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def get_langfuse_callback() -> Any | None:
|
||||||
|
"""Return a LangChain ``CallbackHandler`` that auto-inherits the current trace.
|
||||||
|
|
||||||
|
Must be called inside a ``trace_span()`` block for proper linking.
|
||||||
|
Returns *None* when Langfuse is disabled.
|
||||||
|
"""
|
||||||
|
if _disabled and not _initialised:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse.langchain import CallbackHandler
|
||||||
|
return CallbackHandler()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: get_langfuse_callback failed: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Prompt management ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt(
|
||||||
|
name: str,
|
||||||
|
*,
|
||||||
|
version: int | None = None,
|
||||||
|
label: str | None = None,
|
||||||
|
fallback: str | None = None,
|
||||||
|
cache_ttl_seconds: int = 300,
|
||||||
|
) -> str | None:
|
||||||
|
"""Fetch a managed prompt from Langfuse by name (without variable compilation).
|
||||||
|
|
||||||
|
Returns the raw prompt string, or *fallback* if the prompt is not
|
||||||
|
found or Langfuse is disabled.
|
||||||
|
"""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None:
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
try:
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"name": name,
|
||||||
|
"cache_ttl_seconds": cache_ttl_seconds,
|
||||||
|
}
|
||||||
|
if version is not None:
|
||||||
|
kwargs["version"] = version
|
||||||
|
if label is not None:
|
||||||
|
kwargs["label"] = label
|
||||||
|
prompt = lf.get_prompt(**kwargs)
|
||||||
|
return prompt.prompt
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: get_prompt(%s) failed: %s", name, exc)
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
|
||||||
|
def compile_prompt(
|
||||||
|
name: str,
|
||||||
|
*,
|
||||||
|
fallback: str,
|
||||||
|
variables: dict[str, str],
|
||||||
|
version: int | None = None,
|
||||||
|
label: str | None = None,
|
||||||
|
cache_ttl_seconds: int = 300,
|
||||||
|
) -> str:
|
||||||
|
"""Fetch a managed prompt from Langfuse and compile it with ``{{variables}}``.
|
||||||
|
|
||||||
|
If the prompt exists in Langfuse, uses the SDK's ``.compile(**variables)``
|
||||||
|
which replaces ``{{key}}`` placeholders. If Langfuse is disabled or the
|
||||||
|
prompt is not found, falls back to ``fallback.format(**variables)`` (Python
|
||||||
|
``{key}`` placeholders).
|
||||||
|
|
||||||
|
This means:
|
||||||
|
- Langfuse prompts use ``{{variable}}`` syntax.
|
||||||
|
- Hardcoded fallback strings use Python ``{variable}`` syntax.
|
||||||
|
"""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None:
|
||||||
|
return fallback.format(**variables)
|
||||||
|
|
||||||
|
try:
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"name": name,
|
||||||
|
"cache_ttl_seconds": cache_ttl_seconds,
|
||||||
|
}
|
||||||
|
if version is not None:
|
||||||
|
kwargs["version"] = version
|
||||||
|
if label is not None:
|
||||||
|
kwargs["label"] = label
|
||||||
|
prompt = lf.get_prompt(**kwargs)
|
||||||
|
return prompt.compile(**variables)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: compile_prompt(%s) failed, using fallback: %s", name, exc)
|
||||||
|
return fallback.format(**variables)
|
||||||
|
|
||||||
|
|
||||||
|
def link_prompt_to_trace(
|
||||||
|
span: Any,
|
||||||
|
prompt_name: str,
|
||||||
|
*,
|
||||||
|
version: int | None = None,
|
||||||
|
label: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Attach prompt metadata to a span/trace."""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None or isinstance(span, _NullSpan):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
kwargs: dict[str, Any] = {"name": prompt_name}
|
||||||
|
if version is not None:
|
||||||
|
kwargs["version"] = version
|
||||||
|
if label is not None:
|
||||||
|
kwargs["label"] = label
|
||||||
|
prompt = lf.get_prompt(**kwargs)
|
||||||
|
span.update(metadata={"prompt": {"name": prompt_name, "version": prompt.version}})
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: link_prompt_to_trace(%s) failed: %s", prompt_name, exc)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Scoring helper ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def score_trace(
|
||||||
|
trace_id: str,
|
||||||
|
name: str,
|
||||||
|
value: float,
|
||||||
|
*,
|
||||||
|
comment: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Post a score to a trace (e.g. user feedback, latency, quality)."""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
lf.create_score(trace_id=trace_id, name=name, value=value, comment=comment)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: score_trace failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Shutdown ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def flush() -> None:
|
||||||
|
"""Flush pending Langfuse events."""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is not None:
|
||||||
|
try:
|
||||||
|
lf.flush()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: flush failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
def shutdown() -> None:
|
||||||
|
"""Flush and close the Langfuse client."""
|
||||||
|
global _initialised, _disabled
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is not None:
|
||||||
|
try:
|
||||||
|
lf.flush()
|
||||||
|
lf.shutdown()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: shutdown failed: %s", exc)
|
||||||
|
_initialised = False
|
||||||
|
_disabled = False
|
||||||
17
services/chat/requirements.txt
Normal file
17
services/chat/requirements.txt
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
fastapi>=0.115.0
|
||||||
|
uvicorn[standard]>=0.34.0
|
||||||
|
gunicorn>=22.0.0
|
||||||
|
pydantic>=2.10.0
|
||||||
|
pydantic-settings>=2.7.0
|
||||||
|
sqlalchemy>=2.0.0
|
||||||
|
asyncpg>=0.30.0
|
||||||
|
redis>=5.0.0
|
||||||
|
cryptography>=42.0.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
langchain-core>=0.3.0
|
||||||
|
langchain-openai>=0.3.0
|
||||||
|
langchain-litellm>=0.3.0
|
||||||
|
litellm>=1.50.0
|
||||||
|
openai>=1.50.0
|
||||||
|
httpx>=0.27.0
|
||||||
|
langfuse>=3.0.0
|
||||||
36
services/ws-gateway/Dockerfile
Normal file
36
services/ws-gateway/Dockerfile
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# ── builder ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS builder
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
COPY services/ws-gateway/requirements.txt ./requirements.txt
|
||||||
|
RUN pip install --upgrade pip && \
|
||||||
|
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||||
|
|
||||||
|
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS runtime
|
||||||
|
|
||||||
|
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY --from=builder /install /usr/local
|
||||||
|
|
||||||
|
# Shared module
|
||||||
|
COPY shared/ shared/
|
||||||
|
|
||||||
|
# Service source
|
||||||
|
COPY services/ws-gateway/app/ app/
|
||||||
|
|
||||||
|
RUN chown -R appuser:appgroup /app
|
||||||
|
|
||||||
|
USER appuser
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
# Single worker — each instance handles many WS connections via asyncio
|
||||||
|
CMD ["gunicorn", "app.main:app", \
|
||||||
|
"-k", "uvicorn.workers.UvicornWorker", \
|
||||||
|
"--bind", "0.0.0.0:8000", \
|
||||||
|
"--workers", "1", \
|
||||||
|
"--timeout", "0"]
|
||||||
17
services/ws-gateway/README.md
Normal file
17
services/ws-gateway/README.md
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# WS Gateway
|
||||||
|
|
||||||
|
Stateless WebSocket proxy. Accepts Electron connections, authenticates JWT,
|
||||||
|
routes frames to Chat/Batch services via Redis pub/sub.
|
||||||
|
|
||||||
|
## No business logic
|
||||||
|
This service does NOT know what tasks, notes, or agents are.
|
||||||
|
It only routes JSON frames between Electron and downstream services.
|
||||||
|
|
||||||
|
## Scaling
|
||||||
|
Sticky sessions on `user_id` (Traefik consistent hashing).
|
||||||
|
|
||||||
|
## Redis channels used
|
||||||
|
- Subscribe: `ws:out:{user_id}` (frames to send to client)
|
||||||
|
- Publish: `chat:request:{user_id}`, `batch:request:{user_id}`
|
||||||
|
- LPUSH: `tool:result:{call_id}` (from client tool_result frames)
|
||||||
|
- HSET/HDEL: `ws:devices:{user_id}` (device registry)
|
||||||
0
services/ws-gateway/app/__init__.py
Normal file
0
services/ws-gateway/app/__init__.py
Normal file
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user