Compare commits
388 Commits
396fd2faa4
...
develop
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
53962a05a4 | ||
|
|
b9148c67a0 | ||
|
|
7c9e8296bf | ||
|
|
11e5e1e656 | ||
| f36ca72396 | |||
| f0adae6513 | |||
|
|
4b80bcb53b | ||
|
|
79a926e4d8 | ||
|
|
55c1bab7b1 | ||
|
|
60925da98c | ||
|
|
f64ca11888 | ||
|
|
93674456ec | ||
|
|
95d4e4be75 | ||
|
|
a7b9d51268 | ||
|
|
9733aa6a3a | ||
|
|
e8f56feaac | ||
|
|
b6468c755f | ||
|
|
6868d8813e | ||
|
|
486ff83a94 | ||
|
|
0a893d1929 | ||
|
|
e6d3f9d7be | ||
|
|
3b41e8e7aa | ||
|
|
4979c2b7d9 | ||
|
|
7bd4cc9d9e | ||
|
|
b9b0a10139 | ||
|
|
78767512f9 | ||
|
|
6e12429f92 | ||
|
|
e87b64cd68 | ||
|
|
1c65bbfe75 | ||
|
|
4cd1ac11cc | ||
|
|
1a4cfb07a5 | ||
|
|
0833db239c | ||
|
|
6adb13ff88 | ||
|
|
11b31e5814 | ||
|
|
cb274c9728 | ||
|
|
d3497a1908 | ||
|
|
0c0299808c | ||
|
|
d1016fd65a | ||
|
|
c559754532 | ||
|
|
ff1208fd3c | ||
|
|
9f21d5ae8f | ||
|
|
699bba3a30 | ||
|
|
3d4aef7fe3 | ||
|
|
1364b9ba37 | ||
|
|
27df8c0a8d | ||
|
|
4933f8055c | ||
|
|
ac33ac1c0d | ||
|
|
5cd895f04e | ||
|
|
fbd308d288 | ||
|
|
49b1d60fca | ||
|
|
b258ec3de5 | ||
|
|
f0a18d7011 | ||
|
|
9b66dc3329 | ||
|
|
105cf52083 | ||
|
|
c2b27d4fb7 | ||
|
|
b92e72b685 | ||
|
|
1ccb0282fe | ||
|
|
1a20c11e86 | ||
|
|
c1b1b289c1 | ||
|
|
6aa7cb3d22 | ||
|
|
70c19d3064 | ||
|
|
886730b47e | ||
|
|
052c7e3741 | ||
|
|
1f60931a0f | ||
|
|
42a457f973 | ||
|
|
e6357b0d61 | ||
|
|
63fc3cfa43 | ||
|
|
d63fd5f3b9 | ||
|
|
d50be8e7af | ||
|
|
d6b1a86e95 | ||
|
|
ca669a1c5c | ||
|
|
ffd0e97508 | ||
|
|
2bc9617b14 | ||
|
|
3aa7aa0d50 | ||
|
|
8a6befd481 | ||
|
|
652a6b830d | ||
|
|
b2b9607f64 | ||
|
|
bdc9411782 | ||
|
|
8529c3f0b6 | ||
|
|
732235c93a | ||
|
|
539beaf225 | ||
|
|
f9eb4b41b6 | ||
|
|
4e42ac8b04 | ||
|
|
869e0d82ee | ||
|
|
5e42b2abb1 | ||
|
|
2b71469e86 | ||
|
|
6188ae15b3 | ||
|
|
e1db7cdf06 | ||
|
|
c53f08229c | ||
|
|
3e2d80d5bb | ||
|
|
49c0ae2413 | ||
|
|
4b5f379126 | ||
|
|
aad8292f9e | ||
|
|
44a21d662d | ||
|
|
ae2cef4335 | ||
|
|
57462af4f4 | ||
|
|
425025ad68 | ||
|
|
b879760013 | ||
|
|
21aa1db07e | ||
|
|
81fe6d29e2 | ||
|
|
b2d7fa1723 | ||
|
|
4c641ab93a | ||
|
|
84720ff23c | ||
|
|
d7307e146a | ||
|
|
7d4059ca4b | ||
|
|
9691842e79 | ||
|
|
094840e671 | ||
|
|
e8592b25a8 | ||
|
|
27b385df53 | ||
|
|
e170844f17 | ||
|
|
27c1194384 | ||
|
|
26ea095f60 | ||
|
|
751d16a9f4 | ||
|
|
285214a2d2 | ||
|
|
89645f2abd | ||
|
|
7dadeb88fe | ||
|
|
13531fec40 | ||
|
|
e254efd420 | ||
|
|
6d79911414 | ||
|
|
69a859e19f | ||
|
|
098ce86c76 | ||
|
|
9ef809ba02 | ||
|
|
024d572ebb | ||
|
|
d24f09bbea | ||
|
|
56fe6c0754 | ||
|
|
c76de207d7 | ||
|
|
4e89a7a96c | ||
|
|
0fc3aa421e | ||
|
|
cc0e258e8c | ||
|
|
c10fbe22d7 | ||
|
|
12e203e63d | ||
|
|
e3e0b06fb6 | ||
|
|
b3d85b93f1 | ||
|
|
ffcd7390f0 | ||
|
|
91e880f9d4 | ||
|
|
7d47ca54be | ||
|
|
659607a1e9 | ||
|
|
80a0d2c56f | ||
|
|
66448a25f4 | ||
|
|
93144b9de8 | ||
|
|
b0c415f90f | ||
|
|
8a2225da7c | ||
|
|
e0c5971d20 | ||
|
|
a499d55636 | ||
|
|
c36890cc8b | ||
|
|
b80ba0434b | ||
|
|
01d3735dd1 | ||
|
|
e0bcb2fe0a | ||
|
|
a1c83a6134 | ||
|
|
bd5e3076ed | ||
|
|
316b8fa66a | ||
|
|
6f907f6a96 | ||
|
|
93caf0116d | ||
|
|
15af8d54e6 | ||
|
|
c4ed7b3482 | ||
|
|
066d407a5f | ||
|
|
c2826ae4be | ||
|
|
956fa88853 | ||
|
|
fb2f59ccea | ||
|
|
56dbb7f4cd | ||
|
|
506f517851 | ||
|
|
520c186991 | ||
|
|
582bf27deb | ||
|
|
2aeb453229 | ||
|
|
b7a4edac90 | ||
|
|
822b4cd8b1 | ||
|
|
ab24fc4c91 | ||
|
|
a98e99f7a2 | ||
|
|
a0ff285bcd | ||
|
|
177c1a87dd | ||
|
|
441a4ea05c | ||
|
|
a693a64bf5 | ||
|
|
adb1cc81ef | ||
|
|
a4fd10e640 | ||
|
|
efa3051c61 | ||
|
|
72e09501de | ||
|
|
875fe625b5 | ||
|
|
dac1d50b02 | ||
|
|
e104ffc3ab | ||
|
|
1cffb9bdbf | ||
|
|
bae84f1a48 | ||
|
|
938c8eef8a | ||
|
|
50d01c7aec | ||
|
|
ef04bec66f | ||
|
|
2e9ec31d83 | ||
|
|
ca290225b9 | ||
|
|
a5ec0647ec | ||
|
|
57f5470f0d | ||
|
|
33e5edc2ba | ||
|
|
fadda94135 | ||
|
|
5fa3df9c16 | ||
|
|
b48ceea0af | ||
|
|
9e31cfa78e | ||
|
|
c63c94b561 | ||
|
|
cbdb37f5a5 | ||
|
|
05de7405ba | ||
|
|
68286b61bd | ||
|
|
a7fbc4c7e3 | ||
|
|
1a5605569c | ||
|
|
ef71710244 | ||
|
|
ca78a4cbc0 | ||
|
|
b652248404 | ||
|
|
f5ac37867c | ||
|
|
37878df992 | ||
|
|
9e90791743 | ||
|
|
dd3f1442b0 | ||
|
|
a5556743f0 | ||
|
|
67562b8092 | ||
|
|
ca231e7b7c | ||
|
|
a5a6e25a89 | ||
|
|
df8cbb5c35 | ||
|
|
6f4c68b359 | ||
|
|
d0b344beec | ||
|
|
1f4adfca90 | ||
|
|
259ab50b25 | ||
|
|
c20c6d7853 | ||
|
|
6787e690ba | ||
|
|
a04c2434b6 | ||
|
|
cb8f56d909 | ||
|
|
c291fc689a | ||
|
|
b61a6de73a | ||
|
|
f2a68ee5f6 | ||
|
|
0c43f5633f | ||
|
|
4ebf0d4062 | ||
|
|
244d53f93d | ||
|
|
8dceacc2ce | ||
|
|
2c7cac9e03 | ||
|
|
7244810fe1 | ||
|
|
ea9094f47f | ||
|
|
d5fea95561 | ||
|
|
0b5ef48463 | ||
|
|
ca8721e1ac | ||
|
|
f658e5e6a3 | ||
|
|
e9c790e017 | ||
|
|
341ee140e5 | ||
|
|
741b9b87fb | ||
|
|
2d8abb6311 | ||
|
|
9b32d834b3 | ||
|
|
e668e3fd20 | ||
|
|
333b6cb769 | ||
|
|
87c444e78d | ||
|
|
811759dddb | ||
|
|
275edab4bf | ||
|
|
7ccdad431f | ||
|
|
0371a46731 | ||
|
|
cd8f6a6751 | ||
|
|
4073863dc6 | ||
|
|
dd98aaaf4d | ||
|
|
a85f8fde29 | ||
|
|
90500a3462 | ||
|
|
c1a8ac7669 | ||
|
|
20bc28e59b | ||
|
|
5d112c8dfd | ||
|
|
c510cbaae5 | ||
|
|
ce139bbac3 | ||
|
|
27bc9d90af | ||
|
|
3cf067faea | ||
|
|
016c44c6f0 | ||
|
|
7253f6fe72 | ||
|
|
41db3a7089 | ||
|
|
cc94194fd1 | ||
|
|
02a0f3635b | ||
|
|
96c91e386d | ||
|
|
109551f713 | ||
| f129b3ba43 | |||
| 7f0c6f45b0 | |||
|
|
2caea8e21d | ||
|
|
b23c4ef255 | ||
|
|
801ae43000 | ||
|
|
c0aef71141 | ||
|
|
467abc8d42 | ||
| bd9af5ddd6 | |||
|
|
5753f8def9 | ||
|
|
e672b58b6f | ||
|
|
d8add7e8cb | ||
|
|
c6c4578f9a | ||
|
|
3aa0b36a6c | ||
|
|
fa231a3642 | ||
|
|
d91c98f86d | ||
|
|
c0619f5c4d | ||
|
|
da282229ff | ||
|
|
7fa6ad5760 | ||
|
|
dcd14220ca | ||
|
|
3cc32569d9 | ||
|
|
bf445ac2ce | ||
|
|
a2d6d689e4 | ||
|
|
aa8bcbf0d8 | ||
|
|
1ce1d492b0 | ||
|
|
552b8eb305 | ||
|
|
0d93b3960d | ||
|
|
f07580574b | ||
|
|
1a8bf11f90 | ||
|
|
e7cdce8287 | ||
| 3ae9e450be | |||
|
|
58bc6efd4b | ||
| 7616153345 | |||
|
|
6c450805cb | ||
| 0c21f47a59 | |||
|
|
f340d0fa3e | ||
|
|
edc53cb6eb | ||
| 7256f1ef4e | |||
| bf635d9c30 | |||
| 5add259348 | |||
| 198fd62ef2 | |||
| 34a771bee3 | |||
|
|
725cece5c1 | ||
|
|
297e20ce8d | ||
| 65a08838c9 | |||
|
|
5a03bd1cfb | ||
| 8b5a05a16e | |||
| 6a87590176 | |||
| cd4644637b | |||
|
|
87b7a1c6c9 | ||
| 9fd441e7d7 | |||
|
|
826f64d6bb | ||
| 5faa6b1d7c | |||
| b7ddc95171 | |||
| 488dab7aa1 | |||
| a52e5362b3 | |||
| 582ad389e1 | |||
| 02a9684cd6 | |||
| 3283cc9ad5 | |||
| fae9efee0d | |||
| 30b062dd4a | |||
| 2a0331d7ce | |||
| 13fd8677c1 | |||
| 9bd629cb59 | |||
| 9c97702daa | |||
| a1e364c9c0 | |||
| 5b55f1292a | |||
| 5bc9ea6cd6 | |||
| f7404b6f66 | |||
| d667e43c73 | |||
| fe085a7951 | |||
| 2de67213f8 | |||
| f6ed383b3a | |||
| 9332e29e53 | |||
| 618076193a | |||
| 34f01234c9 | |||
| 0bd46937d3 | |||
| e6b5bc2e7d | |||
| c90ed58078 | |||
| 76c8f2bdad | |||
| 393b3befd6 | |||
| 2c08275934 | |||
| 7cb384fa63 | |||
| 7efaeba283 | |||
| b61ded8458 | |||
| ac71d99f9a | |||
| 3b3b3baf25 | |||
| 45415bb9ee | |||
| a775a2da18 | |||
| 24772f2b67 | |||
| fd1396a710 | |||
| 914f70bd85 | |||
| 608d6c784f | |||
| 19ad5be97f | |||
| 1dfd088e18 | |||
| c6e1e4e7fd | |||
| cc603aba06 | |||
| 6d9a16e513 | |||
| 27c087d5d8 | |||
|
|
4d7fd519c5 | ||
| 06de7c7ab0 | |||
| e3c7547c75 | |||
| 314780d59a | |||
| 091787a6da | |||
| 7f278c6f63 | |||
| 8bfce9da00 | |||
| 480e7ac5bd | |||
| d0b303e745 | |||
| 5d485b3665 | |||
| 9787befd4a | |||
| 8f7bc25611 | |||
| 3e07fff958 | |||
| 9119474e71 | |||
| 4c4df7335a | |||
| c8ef7b119b | |||
| 35dd9ac86f | |||
| e72d72f4f6 | |||
| 14d1a7351d | |||
| 68955d2fc2 | |||
| 864dfdc4e6 | |||
| 0d16729036 | |||
| 82669d3704 | |||
| 4d0917f5df | |||
| 71fd1a0a7c | |||
| 493b4dd12a |
@@ -1,169 +1,297 @@
|
||||
# CLAUDE.md
|
||||
|
||||
## Commands
|
||||
Guide Claude Code when work in repo.
|
||||
|
||||
## Keeping This File Up to Date
|
||||
|
||||
Update when lesson learned. Update when:
|
||||
|
||||
- Non-obvious arch decision made or found
|
||||
- Gotcha, footgun, surprising behavior hit (+ fix/workaround)
|
||||
- New command, workflow, tool added
|
||||
- Convention set that not obvious from code
|
||||
- Integration detail clarified (IPC protocol behavior, agent tool call edge cases)
|
||||
|
||||
Do **not** add derivable-from-code things, generic best practices, or ephemeral task notes — durable knowledge only.
|
||||
|
||||
> graphify rules live in the root `CLAUDE.md` (single source).
|
||||
|
||||
## Repository Layout
|
||||
|
||||
**Single merged monorepo.** Electron app and FastAPI backend were previously separate submodules; they now live as plain subdirectories in this repo.
|
||||
|
||||
| Directory | What |
|
||||
|-----------|------|
|
||||
| **`electron/`** | Electron desktop app (TypeScript/React) |
|
||||
| **`api/`** | FastAPI backend (Python) |
|
||||
| **`docs/`** | Planning docs & working memory |
|
||||
| **`graphify-out/`** | Knowledge graph (see root `CLAUDE.md`) |
|
||||
|
||||
---
|
||||
|
||||
## Electron App (`electron/`)
|
||||
|
||||
### Commands
|
||||
|
||||
```bash
|
||||
source ~/.nvm/nvm.sh && npm start # Dev with hot-reload
|
||||
source ~/.nvm/nvm.sh && npm run make # Build distributable packages
|
||||
source ~/.nvm/nvm.sh && npm run package # Package without installers
|
||||
source ~/.nvm/nvm.sh && npm run lint # ESLint (.ts/.tsx)
|
||||
source ~/.nvm/nvm.sh && npx drizzle-kit generate # Generate migration from schema
|
||||
source ~/.nvm/nvm.sh && npx drizzle-kit push # Push schema directly (dev only)
|
||||
cd electron
|
||||
npm run start # Dev (Electron + Vite)
|
||||
npm run lint # ESLint
|
||||
npm run knip # Dead code analysis
|
||||
npm run make # Build installers (Win/Linux/macOS)
|
||||
npm run package # Package without installers
|
||||
npm run dev:web # Standalone web SPA dev
|
||||
npm run build:web # Build standalone SPA → dist-web/
|
||||
npm run preview:web # Preview built web SPA
|
||||
npx drizzle-kit generate # Generate migration from schema
|
||||
npx drizzle-kit push # Push schema directly (dev only)
|
||||
```
|
||||
|
||||
No test suite currently.
|
||||
|
||||
## Architecture
|
||||
|
||||
Adiuva is a local-first Electron desktop app. The three processes communicate via a custom tRPC v11 ↔ IPC bridge (the public `electron-trpc` package is incompatible with tRPC v11).
|
||||
### Architecture
|
||||
|
||||
```
|
||||
Renderer (React 19) ──ipcLink──► Preload (contextBridge) ──IPC──► Main (tRPC router + SQLite)
|
||||
Renderer (React 19 + TanStack Router)
|
||||
↓ custom ipcLink (NOT electron-trpc — incompatible with tRPC v11)
|
||||
Preload (contextBridge: window.electronTRPC + window.electronAI)
|
||||
↓ IPC channels
|
||||
Main Process (Node.js)
|
||||
├── tRPC router (CRUD + AI proxy procedures)
|
||||
├── SQLite (better-sqlite3 + Drizzle ORM, WAL mode)
|
||||
└── Backend delegation layer (orchestrator.ts forwards to FastAPI WS)
|
||||
```
|
||||
|
||||
### Main Process (`src/main/`)
|
||||
**Local-first storage, cloud AI.** All user data (clients, projects, tasks, notes, timelines) in local SQLite. AI lives entirely on the FastAPI backend — Electron orchestrator is a thin delegation shell that forwards to `/api/v1/device` WS and dispatches v3 typed stream frames + tool-call ↔ DrizzleExecutor round-trips back to renderer.
|
||||
|
||||
Owns the database and all business logic.
|
||||
**IPC channels**:
|
||||
- `'trpc'` — bidirectional tRPC request/response (all CRUD + auth + scout + memory proxy)
|
||||
- `'ai:stream'` — one-way v3 stream frames main → renderer
|
||||
- `'ai:action'` — AI side-effects (e.g. agent auto-creates task)
|
||||
|
||||
| File | Purpose |
|
||||
|---|---|
|
||||
| `index.ts` | Window creation, app lifecycle |
|
||||
| `ipc.ts` | Bridges `ipcMain` to tRPC procedures |
|
||||
| `router/index.ts` | All tRPC sub-routers merged into `appRouter` |
|
||||
| `db/index.ts` | Drizzle + better-sqlite3, WAL mode, singleton `getDb()` |
|
||||
| `db/schema.ts` | Table definitions: clients, projects, tasks, checkpoints, notes, taskComments |
|
||||
| `db/vectordb.ts` | LanceDB vector store for note embeddings |
|
||||
| `store.ts` | electron-store for persistent UI settings |
|
||||
**Main process layout (`src/main/`)**:
|
||||
- `index.ts` — Window creation, app lifecycle, protocol handler
|
||||
- `ipc.ts` — Custom tRPC↔IPC bridge
|
||||
- `store.ts` — electron-store for `FormatPrefs` + `uiLanguage`; exports `getUiLanguage()`
|
||||
- `router/index.ts` — All tRPC sub-routers (~1627 LOC)
|
||||
- `db/schema.ts` — 10 tables: clients, projects, tasks, timelineEvents, timelineEventDependencies, notes, noteEdits, taskComments, scoutRuns, scoutRunActions
|
||||
- `db/index.ts` — Drizzle + better-sqlite3 (WAL), singleton `getDb()`, `initDb()` migrations
|
||||
- `db/notes-backfill.ts` — Startup backfill: generates `aiSummary` for notes with null summary
|
||||
- `ai/orchestrator.ts` — Thin backend-delegation layer (~304 LOC). Connectivity/auth guard → `BackendClient.sendHomeRequest()` / `sendFloatingRequest()` → forwards v3 stream frames to renderer. Also schedules daily-brief regeneration.
|
||||
- `ai/token.ts` — Two-tier token storage (safeStorage + electron-store fallback)
|
||||
- `scouts/scout-scheduler.ts` — Local scout scheduling (filesystem scouts)
|
||||
- `api/backend-client.ts` — WS client to FastAPI: handles tool-call round-trips, v3 stream frame dispatch, journey + scout proxies
|
||||
- `api/drizzle-executor.ts` — Executes backend-issued tool calls against local SQLite. Wraps results through `formatRow()`/`formatRows()` using user FormatPrefs
|
||||
- `auth/auth-manager.ts` — Login, register, logout, OAuth flow (singleton)
|
||||
- `auth/backup-key.ts` — Device-specific AES-256 backup key (safeStorage, not password-derived)
|
||||
- `auth/locale-defaults.ts` — Detects timezone, date/time format, language from OS locale
|
||||
|
||||
### Preload (`src/preload/trpc.ts`)
|
||||
**tRPC routers** (in `appRouter`): `health`, `settings`, `clients`, `projects`, `tasks`, `timelineEvents`, `timelineEventDependencies`, `notes`, `noteEdits`, `taskComments`, `ai`, `auth`, `scout` (with `local` / `cloud` / `journey` sub-routers), `memory`.
|
||||
|
||||
Exposes `window.electronTRPC` with `sendMessage()` / `onMessage()`.
|
||||
**Renderer** (`src/renderer/`): file-based routing via TanStack Router (`routeTree.gen.ts` auto-generated). shadcn/ui new-york theme, neutral colors. Path alias `@/*` → `src/renderer/*`. Notes editor: Milkdown (`@milkdown/crepe`).
|
||||
|
||||
### Renderer (`src/renderer/`)
|
||||
**Non-obvious details**:
|
||||
- `electron-trpc` NOT used — custom IPC bridge (`ipc.ts` + `lib/ipcLink.ts`) because electron-trpc bundles tRPC v10 internals
|
||||
- Vite configs use `.mts` extension to avoid ESM/CJS conflicts with electron-forge
|
||||
- `forge.config.ts` has cross-compilation hooks (downloads platform-specific native binaries for better-sqlite3)
|
||||
- DB has no foreign key constraints — cascade deletes in tRPC procedures
|
||||
- Timestamps are milliseconds (`Date.getTime()`), not ISO strings
|
||||
- Notes use `aiSummary` (≤250 char, backend `gpt-4o-mini` via `POST /api/v1/scouts/notes/summarize`) for AI navigation — LanceDB fully removed
|
||||
- AI note edits go through `noteEdits` HITL table (`type: append|insert|replace`, `status: pending|approved|rejected`); backend tool `propose_note_edit` → drizzle-executor inserts row; user approves/rejects in UI; auto-reject on missing anchor
|
||||
- `checkpoints` table replaced by `timelineEvents` + `timelineEventDependencies` (events are typed `milestone|checkpoint|activity`, with optional dep edges)
|
||||
- `scoutRuns` + `scoutRunActions` populated by backend-client on tool_call/run_complete frames; UI reads via `scout.runs` / `scout.runActions`
|
||||
|
||||
React 19 — never accesses Node APIs directly. All data through `trpc.*.useQuery()` / `trpc.*.useMutation()`.
|
||||
**Settings Page (shared Electron + Web)**:
|
||||
- Settings page runs in **both** Electron and standalone web SPA. Same React components — no duplication.
|
||||
- **Platform Adapter**: `PlatformProvider` context (`src/renderer/lib/platform.tsx`) exposes `isElectron`/`isWeb`/`hasLocalAgents`/`hasFileDialog`. Components use `usePlatform()` to gate Electron-only features.
|
||||
- **Web build**: `vite.web.config.mts` → `dist-web/`. Entry: `web.html` → `src/renderer/web-main.tsx` (uses `httpBatchLink` via `lib/httpLink.ts` instead of `ipcLink`).
|
||||
- **Electron-only gating**: Device ID card and local scout filesystem gated behind `platform.isElectron`. On web: visible but disabled, not hidden.
|
||||
- **Gotcha**: Do NOT add Electron-specific settings (server URL, native file pickers) without wrapping in `platform.isElectron`. Same component tree renders on web.
|
||||
|
||||
| File | Purpose |
|
||||
|---|---|
|
||||
| `lib/ipcLink.ts` | Custom TRPCLink routing through `window.electronTRPC` |
|
||||
| `lib/trpc.ts` | `createTRPCReact<AppRouter>()` typed client |
|
||||
| `index.tsx` | QueryClient + tRPC + Router providers |
|
||||
**Onboarding Wizard**:
|
||||
- First-run wizard collects 5 fields: `job_role`, `industry`, `primary_use_case`, `tone_preference`, `language`. Plus `user_name` from `name`+`surname`.
|
||||
- All fields stored as encrypted core memory (backend `MemoryMiddleware`), not local electron-store.
|
||||
- `onboarding_completed_at` on `users` table (nullable TIMESTAMPTZ) gates flow — `null` = show wizard, non-null = skip.
|
||||
- `AppShell.tsx` gates: if `profile.onboardingCompletedAt == null` → render `<OnboardingFlow>` instead of app.
|
||||
- `auth.status` tRPC procedure auto-seeds `language` and `user_name` into MemoryCore if missing (fire-and-forget `.catch(() => {})`).
|
||||
- Format prefs (timezone, dateFormat, timeFormat) stored in electron-store (`FormatPrefs`), not core memory — device-specific.
|
||||
- `drizzle-executor.ts` wraps all query results through `formatRow()`/`formatRows()` using user FormatPrefs.
|
||||
- Settings > Profile allows post-onboarding edit of all fields + format prefs.
|
||||
- **Gotcha — shadcn Button `outline` variant in dark mode**: variant defines `dark:bg-input/30 dark:border-input dark:hover:bg-input/50` — overrides custom `className` background. Fix: switch between `variant="default"` and `variant="outline"` instead of className overrides.
|
||||
- **Gotcha — locale codes vs human names**: `app.getLocale()` and `navigator.language` return codes like `en-US`. Use `Intl.DisplayNames(undefined, { type: 'language' })` to convert to "English". Must do in both main process (`locale-defaults.ts`) and renderer (`OnboardingFlow.tsx`).
|
||||
|
||||
### Routing
|
||||
**i18n (Internationalization)**:
|
||||
- `i18next` + `react-i18next` with bundled JSON translations (no lazy loading).
|
||||
- Config in `src/renderer/i18n.ts`. 5 languages: EN, IT, ES, FR, DE. `SUPPORTED_LANGUAGES` exported for UI selectors.
|
||||
- Translation files: `src/renderer/locales/{en,it,es,fr,de}/translation.json`. Namespaces: `nav`, `auth`, `tasks`, `settings`, `common`, `errors`, `home`, `timeline`, `projects`, `scouts`.
|
||||
- **`common.*` namespace** holds shared labels (`save`, `cancel`, `delete`, `edit`, `add`, `rename`, `saving`, `deleting`, `creating`, `renameDescription`, `deleteTitle`). Check `common.*` before adding new key.
|
||||
- Pluralization uses i18next `_one`/`_other` suffixes.
|
||||
- `LanguageSync` component in `src/renderer/index.tsx` reads persisted `uiLanguage` from electron-store via tRPC on startup, syncs to i18next.
|
||||
- Language selector in `GeneralSection.tsx` (Settings > General). On change: (1) calls `i18n.changeLanguage()`, (2) persists to electron-store via `setUiLanguage` mutation, (3) writes to backend core memory so AI responds in same language.
|
||||
- `getUiLanguage()` exported from `src/main/store.ts`.
|
||||
- Static data arrays needing translation use `labelKey` pattern: store translation key, call `t(labelKey)` at render. Used in `NAV_ITEMS`, `COLUMNS`, `SECTIONS`, `SUGGESTION_CHIPS`.
|
||||
- When adding new translated text: add key to **all 5** JSON files. Keep `common.*` consistent across all languages.
|
||||
|
||||
File-based via TanStack Router (`tsr.config.json` at root). Route tree auto-generated at `routeTree.gen.ts`.
|
||||
**Google OAuth (adiuvAI side)**:
|
||||
- `adiuvai://` NOT accepted by Google as redirect URI — Google only accepts `http://localhost` or `https://`. API backend exposes `GET /auth/oauth/google/web-callback` which receives Google redirect and bounces to `adiuvai://oauth/callback?...`. Redirect URI in Google Cloud Console points to backend, not Electron app.
|
||||
- `app.requestSingleInstanceLock()` required for `second-instance` event on Windows/Linux. If returns `false`, call `app.quit()` immediately.
|
||||
- In dev (`process.defaultApp === true`), `setAsDefaultProtocolClient('adiuvai')` must include `[path.resolve(process.argv[1])]` as third arg so OS protocol registration includes entry script.
|
||||
- `loginWithOAuth` uses `fetch()` directly (not `this.get()`) — authorize endpoint is public, `get()` throws when not authenticated.
|
||||
- Backup key in `backup-key.ts` stored in `encryptedTokens` under key `backup_key`, reusing `getToken/setToken` from `token.ts`. Device-bound, never password-derived — social-login users can use backup features.
|
||||
|
||||
Routes: `__root.tsx` (AppShell layout), `index`, `tasks`, `timeline`, `projects`, `notes.$noteId`
|
||||
---
|
||||
|
||||
### tRPC Routers
|
||||
## api (FastAPI Backend)
|
||||
|
||||
`health`, `settings`, `clients`, `projects`, `tasks`, `checkpoints`, `notes`, `taskComments`, `ai`
|
||||
### Commands
|
||||
|
||||
### Database
|
||||
```bash
|
||||
cd api
|
||||
|
||||
Schema in `src/main/db/schema.ts`, migrations in `src/main/db/migrations/`. DB created in Electron's `userData` as `adiuva.db`. On startup, `initDb()` runs non-destructive migrations.
|
||||
# Development
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
|
||||
To add a table/column: edit `schema.ts` → `drizzle-kit generate` → `drizzle-kit push` (dev) or commit the migration.
|
||||
# Production
|
||||
gunicorn app.main:app -k uvicorn.workers.UvicornWorker -w 4 --timeout 120
|
||||
|
||||
### Adding a Feature (end-to-end)
|
||||
# Database migrations
|
||||
alembic upgrade head
|
||||
|
||||
1. **Schema** — `src/main/db/schema.ts`
|
||||
2. **Router** — Add sub-router in `src/main/router/index.ts`, merge into `appRouter`
|
||||
3. **Types** — Flow automatically via `AppRouter` export
|
||||
4. **UI** — Components in `src/renderer/components/<feature>/`, data via `trpc.*.useQuery()`
|
||||
# Testing
|
||||
pytest # all tests
|
||||
pytest -v # verbose
|
||||
pytest tests/test_deep_agent.py # single file
|
||||
pytest tests/test_deep_agent.py -k test_name # single test
|
||||
|
||||
## AI Subsystem (`src/main/ai/`)
|
||||
# Linting/formatting
|
||||
ruff check .
|
||||
ruff format .
|
||||
|
||||
LangGraph-based agentic system with pluggable LLM providers.
|
||||
# Docker (full stack)
|
||||
docker compose up --build
|
||||
```
|
||||
|
||||
### Orchestrator (`orchestrator.ts`)
|
||||
### Architecture
|
||||
|
||||
Classifies user intent → routes to a specialist agent:
|
||||
```
|
||||
FastAPI app (app/main.py)
|
||||
├── Lifespan: APScheduler crons (memory hourly + audit weekly) when SCHEDULER_ENABLED
|
||||
├── Middleware: TierRateLimit → Sanitizer → CORS
|
||||
├── HTTP Routes (app/api/routes/) — all under /api/v1
|
||||
│ ├── auth.py — register, login, refresh, profile, OAuth, onboarding, password
|
||||
│ ├── chat.py — POST /chat, /chat/brief, /chat/embed
|
||||
│ ├── scouts.py — catalog, can-create, trigger, notes/summarize
|
||||
│ ├── scout_setup.py — guided scout setup (journey)
|
||||
│ ├── billing.py — Stripe checkout, webhook, subscription, invoices
|
||||
│ ├── device_ws.py — WS /device (unified streaming endpoint: home, floating, brief, journey)
|
||||
│ └── memory.py — core / relational / forget-all
|
||||
├── Agent System (app/agents/)
|
||||
│ ├── task_agent.py
|
||||
│ ├── project_agent.py
|
||||
│ ├── note_agent.py
|
||||
│ ├── timeline_agent.py
|
||||
│ └── filesystem_agent.py
|
||||
├── Core (app/core/)
|
||||
│ ├── deep_agent.py — main agent runner (run_home / run_floating / run_brief / run_journey)
|
||||
│ ├── brief_agent.py — daily brief generation
|
||||
│ ├── agent_runner.py — local + cloud agent run executor
|
||||
│ ├── agent_session_buffer.py — per-session conversation buffer
|
||||
│ ├── agent_registry.py — decorator-based agent registry
|
||||
│ ├── llm.py — LiteLLM factory (multi-provider)
|
||||
│ ├── memory_middleware.py — encrypted core memory read/write
|
||||
│ ├── memory_extraction.py — LLM extraction from conversation tail
|
||||
│ ├── memory_maintenance.py — drain queue, contradiction audit, proactive mining
|
||||
│ ├── note_summarizer.py — gpt-4o-mini summary for notes
|
||||
│ ├── output_formatter.py — render agent output to user-facing markdown
|
||||
│ ├── embeddings.py
|
||||
│ ├── device_manager.py — device registration / WS session tracking
|
||||
│ ├── ws_context.py — per-WS user context plumbing
|
||||
│ ├── langfuse_client.py — Langfuse prompt + tracing client
|
||||
│ └── preprocessors/ — input preprocessors (e.g. email_html.py)
|
||||
├── Auth (app/auth/oauth_providers.py) — GoogleOAuthProvider (httpx + manual PKCE)
|
||||
├── Billing (app/billing/) — tier_manager + stripe_service
|
||||
├── Integrations (app/integrations/) — gmail.py, ms_graph.py
|
||||
└── Models (app/models.py) — SQLAlchemy 2.0 ORM
|
||||
```
|
||||
|
||||
| Agent | Scope | Tools |
|
||||
|---|---|---|
|
||||
| Project | Project-scoped Q&A | `read_project_notes`, `add_task`, `get_summary`, `suggest_checkpoints`, `suggest_tasks` |
|
||||
| Knowledge | Cross-project search | `vector_search_all` |
|
||||
| General | Workspace-wide | `add_task` |
|
||||
**HTTP route prefix**: every router included with `prefix="/api/v1"`. So `/api/v1/auth/...`, `/api/v1/chat`, `/api/v1/scouts/...`, `/api/v1/memory/...`, `/api/v1/device` (WS).
|
||||
|
||||
All providers use LangChain `bindTools()` + ToolMessage loop (max 5 iterations).
|
||||
**ORM models** (`app/models.py`): `User`, `RefreshToken`, `OAuthAccount`, `Subscription`, `LocalScoutConfig`, `CloudScoutConfig`, `ScoutRunLog`, `MemoryCore`, `MemoryAssociative`, `MemoryEpisodic`, `MemoryProactive`, `ExtractionQueue`, `MemoryRelation`, `Plugin`. PostgreSQL (asyncpg + SQLAlchemy 2.0 async). Alembic migrations in `alembic/versions/`.
|
||||
|
||||
Also exports `dailyBrief()` for AI-generated daily summaries (`ai.dailyBrief` tRPC mutation).
|
||||
**Lifespan crons** (only if `settings.SCHEDULER_ENABLED`):
|
||||
- `_memory_cron_tick` — hourly: drains Free-tier extraction queue + mines proactive patterns for Power+ users
|
||||
- `_memory_audit_cron_tick` — weekly: contradiction scan + label canonicalization for all users (Phase 7)
|
||||
|
||||
### Streaming
|
||||
**LLM routing**: backend agents own all intelligence. Tool calls describe client-side ops (JSON) → Electron `drizzle-executor` runs them against local SQLite → result returned to backend over WS. Tool loop cap inside agent runner prevents runaway iteration.
|
||||
|
||||
`sendStreamChunk(sender, token, done)` over IPC `'ai:stream'`. Renderer subscribes via `window.electronAI.onStreamChunk()` in `AIChatPanel.tsx`. `<tool_call>` blocks are filtered before display.
|
||||
**Zero-trust data model**: backend never stores raw user content. PostgreSQL holds auth, billing, plugin metadata, encrypted memory (Core/Associative/Episodic/Proactive/Relational), scout configs, run logs.
|
||||
|
||||
### Providers (`llm.ts`)
|
||||
**Config**: `app/config/settings.py` — all env vars via Pydantic Settings. Copy `.env.example` to `.env` for local dev.
|
||||
|
||||
| Provider | Model | Notes |
|
||||
|---|---|---|
|
||||
| OpenAI | `gpt-4o-mini` | Via LangChain |
|
||||
| Anthropic | `claude-sonnet-4-20250514` | Via LangChain |
|
||||
| Copilot | `ChatCopilot` wrapper | `copilot.ts` / `chat-copilot.ts` |
|
||||
**Testing**: pytest + pytest-asyncio. Fixtures in `tests/conftest.py`. Active suites: agent runner, auth, brief/deep agents, device WS, integrations, journey, memory (audit/extraction/middleware/models/proactive/relations), middleware, output formatter, preprocessors, schemas, ws_unified.
|
||||
|
||||
All use `temperature: 0.3`, streaming enabled. Provider management in `provider.ts`.
|
||||
### Non-obvious details
|
||||
|
||||
### Vector Embeddings (`db/vectordb.ts`)
|
||||
**Provider factory** (`llm.ts`): `gpt-4o-mini` (OpenAI), `claude-sonnet-4-20250514` (Anthropic), or ChatCopilot wrapper — all with `temperature: 0.3` and streaming enabled.
|
||||
- **Tier from DB, not JWT**: `get_current_user` decodes JWT but fetches authoritative tier from `subscriptions` — tier changes take effect immediately, no re-login needed
|
||||
- **Refresh tokens hashed**: plaintext returned to client, stored as SHA-256 in DB — server can never retrieve plaintext (intentional)
|
||||
- **WebSocket auth via query param**: `?token=<jwt>` instead of Bearer header (WebSocket handshake limitation)
|
||||
- **Unified device WS**: `/api/v1/device` is the single bidirectional channel. Handles home requests, floating requests, daily briefs, journeys, heartbeats. Tool calls round-trip through the same socket
|
||||
- **Prompt IP protection**: prompts kept server-side via Langfuse (`langfuse_client`). `SanitizerMiddleware` strips leaked fragments from responses
|
||||
- **Agents don't execute operations**: tools return JSON describing client-side ops — Electron client executes against local SQLite
|
||||
- **Alembic async/sync split**: app uses `postgresql+asyncpg`, Alembic CLI needs `postgresql+psycopg2` — `env.py` handles URL conversion
|
||||
- **CORS includes `app://`**: Electron uses custom `app://` protocol, not http/https
|
||||
- **Run-disconnect tracking**: `_mark_runs_disconnected` flips active runs when WS drops so client can resume cleanly
|
||||
|
||||
**Token storage** (`token.ts`) — two-tier fallback:
|
||||
1. electron-store + `safeStorage` — encrypted at rest (preferred)
|
||||
2. Plain electron-store — last resort (e.g. WSL with no keyring)
|
||||
**Onboarding (API side)**:
|
||||
- `PUT /auth/me/memory` — updates core memory k/v pairs, optionally marks onboarding complete (`mark_onboarded: true` sets `users.onboarding_completed_at`).
|
||||
- `POST /auth/me/onboarding/reset` — nullifies `onboarding_completed_at` so wizard re-runs.
|
||||
- `POST /auth/onboarding/normalize` — LLM-normalizes free-text onboarding inputs via `gpt-4o-mini`; returns inputs unchanged on error.
|
||||
- `get_current_user()` in `auth.py` middleware decrypts core memory blocks, includes in `UserProfile.memory` dict.
|
||||
- `users.onboarding_completed_at` — nullable TIMESTAMPTZ, returned as epoch ms (int) in UserProfile schema.
|
||||
|
||||
**AI approval pattern**: Tasks and checkpoints have `isAiSuggested` (bool) and `isApproved` (bool) columns. AI-suggested items appear in the UI pending user approval before being treated as real records.
|
||||
**i18n (API side)**:
|
||||
- `_language_instruction()` in `app/core/deep_agent.py` reads user's `language` from `MemoryCore`, appends system prompt directive ("Always respond in {language}") to all `run_*` functions.
|
||||
- Electron client writes chosen language to backend core memory on change — API picks up on next agent call.
|
||||
|
||||
### Vector Embeddings (`src/main/db/vectordb.ts`)
|
||||
**Google OAuth (api side)**:
|
||||
- OAuth routes in `app/api/routes/auth.py`: `GET /auth/oauth/{provider}/authorize`, `POST /auth/oauth/{provider}/callback`, `GET /auth/oauth/{provider}/web-callback` (bounces to deep link, excluded from OpenAPI schema).
|
||||
- Provider abstraction in `app/auth/oauth_providers.py` — `GoogleOAuthProvider` uses `httpx` directly (no `authlib`). PKCE S256 implemented manually via `generate_pkce_pair()`.
|
||||
- `_pending_states` dict in `routes/auth.py` is **in-memory** — works for single-process dev, doesn't survive restarts, doesn't scale to multiple workers. Replace with Redis in production.
|
||||
- `users.password_hash` is **nullable** — social-only users have `password_hash=None`. `await db.flush()` required before creating linked `OAuthAccount` to populate `new_user.id` before commit.
|
||||
- `OAUTH_REDIRECT_URI` must point to **API backend** (e.g. `https://api.adiuvai.com/...`).
|
||||
- **Unverified email + existing account = 409**: if `email_verified=False` and email already registered, callback returns 409. Without this guard, branch 3 would INSERT duplicate email and crash with DB constraint violation (500).
|
||||
- **Testing OAuth routes**: mock `GoogleOAuthProvider.exchange_code` and `get_userinfo` with `patch.object(..., new=AsyncMock(...))` — works because FastAPI instantiates new provider per request. Use `monkeypatch.setattr(settings, "GOOGLE_AUTH_CLIENT_ID", ...)` to simulate configured credentials without restart.
|
||||
|
||||
LanceDB in `{userData}/vectors/`. Schema: `{ id, projectId, content, vector }` (1536-dim, `text-embedding-3-small` via `embeddings.ts`). Embedding priority: Copilot CLI token → OpenAI token.
|
||||
### Tier System
|
||||
|
||||
- `upsertNoteEmbedding()` on note create/update (fire-and-forget)
|
||||
- `migrateNotesIfNeeded()` backfills on first startup
|
||||
- `searchNotes(query, limit=5)` used by Knowledge agent
|
||||
Source of truth: `app/billing/tier_manager.py` (`FEATURES` + `RATE_LIMITS` dicts).
|
||||
|
||||
### AI Approval Pattern
|
||||
| Feature | Free | Pro | Power | Team |
|
||||
|---------------------|--------|-----------|-----------|-----------|
|
||||
| Rate limit | 20/min | 60/min | 120/min | 200/min |
|
||||
| Providers | 1 | unlimited | unlimited | unlimited |
|
||||
| Relational memory | no | yes | yes | yes |
|
||||
| Proactive mining | no | no | yes | yes |
|
||||
|
||||
Tasks and checkpoints have `isAiSuggested` + `isApproved` columns. AI suggestions appear pending user approval (dashed borders in UI).
|
||||
`tier_manager.get_tier()` falls back to `'power'` in dev (`settings.ENV == 'dev'`) when no subscription found, else `'free'`. Enforced in `app/api/middleware/rate_limit.py` (sliding window) and `tier_manager.check_feature()` calls scattered through agent + memory paths.
|
||||
|
||||
## Config Notes
|
||||
---
|
||||
|
||||
- Vite configs use `.mts` (not `.ts`) — avoids ESM/CJS conflicts with electron-forge
|
||||
- `@/*` path alias → `src/renderer/*` (TypeScript + Vite + shadcn/ui)
|
||||
- **shadcn/ui**: new-york style, neutral base color
|
||||
- **Icons**: lucide-react only — do not introduce other icon libraries
|
||||
- **Tailwind 4** — CSS variable theming in `globals.css`, no `tailwind.config.js`
|
||||
- **Notes editor**: Milkdown (`@milkdown/crepe`) at `src/renderer/components/notes/MilkdownEditor.tsx`
|
||||
## Cross-Project Integration
|
||||
|
||||
## Design Context
|
||||
Electron app and FastAPI backend communicate via **WebSocket** (`/api/v1/device`):
|
||||
|
||||
### Users
|
||||
Freelancers and solo professionals managing client work (projects, tasks, notes, timelines). Single workspace, no enterprise overhead. AI as force multiplier. They open the app mid-workday — often stressed — so the interface must feel immediately grounding and in control.
|
||||
1. Electron connects with `?token=<jwt>` query param
|
||||
2. Client sends typed request frames (home / floating / brief / journey_start / journey_message)
|
||||
3. Server streams v3 typed frames (text deltas, tool_call, run_complete, error)
|
||||
4. Tool call frames → Electron `drizzle-executor` runs against local SQLite → returns `tool_result` over same socket
|
||||
5. Heartbeat loop keeps connection alive; backend marks runs disconnected on drop
|
||||
|
||||
### Brand Personality
|
||||
**Calm. Intelligent. Warm.** A thoughtful companion, not a flashy tool. Confident and understated — never loud, gamified, or corporate. Fully original aesthetic (no external design system references; this look is intentional and owned).
|
||||
There is no fully-local AI fallback — the Electron orchestrator is a thin delegation shell that requires connectivity + auth. If offline or logged out, `checkConnectivity()` short-circuits with a user-facing error.
|
||||
|
||||
### Emotional Goal
|
||||
When a user opens Adiuva, the first impression should communicate **"everything is under control"** — calm clarity over urgency. The design should lower cognitive load, not raise it.
|
||||
---
|
||||
|
||||
### Aesthetic Direction
|
||||
- Light mode: pinkish-white canvas `#f4edf3`, golden yellow primary `#fbc881`, slate blue-gray secondary `#8a8ea9`, dusty lavender borders `#c8c3cd`
|
||||
- Dark mode: near-black `#0c0c0c`, pure white primary, dark gray `#323232` surfaces
|
||||
- Geist sans-serif, weights 400/500/600. Tight tracking (`-1px`) on headings. Body `text-sm`, metadata `text-xs`
|
||||
- 10px border-radius (`rounded-lg`), `rounded-2xl` for chat/AI elements
|
||||
- Glassmorphism on AI inputs (`backdrop-blur-xl`, transparency, gradient border via padding-box/border-box technique)
|
||||
- Spring animations (stiffness 400, damping 30), scale-and-fade transitions
|
||||
- No gamification (badges, streaks, confetti). Mature and professional
|
||||
- Dashed borders + Sparkles icon = AI-pending state marker
|
||||
## MCP Servers
|
||||
|
||||
### Accessibility
|
||||
Best-effort — not formally audited. Maintain reasonable contrast and keyboard operability without targeting a specific WCAG level.
|
||||
|
||||
### Current Design Focus
|
||||
**Polish and refinement.** The overall direction is solid; the priority is elevating specific areas that feel rough or inconsistent — tighter spacing, more intentional hierarchy, better empty/loading states, and smoother motion.
|
||||
|
||||
### Design Principles
|
||||
1. **Clarity over cleverness** — Clear hierarchy, generous whitespace, comfortable density. Never sacrifice legibility for style.
|
||||
2. **AI as quiet partner** — Deeply integrated but never intrusive. Dashed borders for pending AI items, Sparkles icon as the sole AI marker. Surface AI capabilities without making them the hero.
|
||||
3. **Warmth in restraint** — The warm palette feels approachable without being playful. Dark mode trades warmth for focus. Neither mode should feel cold or aggressive.
|
||||
4. **Motion with purpose** — Spring animations reinforce spatial relationships and acknowledge state changes. Never purely decorative. Respect reduced-motion preferences where possible.
|
||||
5. **Polish over features** — Every surface should feel considered. Prefer refining what exists over introducing new complexity. The right amount of visual weight is the minimum needed.
|
||||
- **Langfuse Docs** (`https://langfuse.com/api/mcp`) — workspace-level, prompt management docs
|
||||
- **shadcn** (`npx shadcn@latest mcp`) — configured in `electron/` for UI component generation
|
||||
|
||||
269
.claude/CLAUDE.original.md
Normal file
269
.claude/CLAUDE.original.md
Normal file
@@ -0,0 +1,269 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Keeping This File Up to Date
|
||||
|
||||
Update this file whenever a lesson is learned during development. Specifically, update CLAUDE.md when:
|
||||
|
||||
- A non-obvious architectural decision is made or discovered
|
||||
- A gotcha, footgun, or surprising behavior is encountered (and the fix/workaround)
|
||||
- A new command, workflow, or tool is added to the project
|
||||
- A convention is established that isn't obvious from reading the code
|
||||
- An integration detail is clarified (e.g., how the IPC protocol actually behaves, edge cases in the agent tool call cycle)
|
||||
|
||||
Do **not** add things already derivable from reading the code, generic best practices, or ephemeral task notes — only durable, reusable knowledge.
|
||||
|
||||
## Repository Layout
|
||||
|
||||
This is a **monorepo with git submodules**. Each submodule is an independent repo with its own `.claude/CLAUDE.md` for detailed guidance.
|
||||
|
||||
| Directory | What | Submodule |
|
||||
|-----------|------|-----------|
|
||||
| **`adiuvAI/`** | Electron desktop app (TypeScript/React) | `git.muticolturano.com/adiuvAI/adiuvAI` |
|
||||
| **`api/`** | FastAPI backend (Python) | `git.muticolturano.com/adiuvAI/api` |
|
||||
| **`website/`** | Landing page (single `index.html`) | `git.muticolturano.com/adiuvAI/website` |
|
||||
| **`docs/`** | Planning docs & working memory (not a submodule) | -- |
|
||||
|
||||
After cloning, run `git submodule update --init --recursive` to populate submodule contents.
|
||||
|
||||
---
|
||||
|
||||
## adiuvAI (Electron App)
|
||||
|
||||
> **Detailed docs**: `adiuvAI/.claude/CLAUDE.md` covers commands, architecture, AI subsystem, design context, and conventions in depth.
|
||||
|
||||
### Commands
|
||||
|
||||
```bash
|
||||
cd adiuvAI
|
||||
npm run start # Start dev server (Electron + Vite)
|
||||
npm run lint # ESLint
|
||||
npm run knip # Dead code analysis
|
||||
npm run make # Build installers (Windows/Linux/macOS)
|
||||
npm run package # Package without creating installers
|
||||
npx drizzle-kit generate # Generate migration from schema changes
|
||||
npx drizzle-kit push # Push schema directly (dev only)
|
||||
```
|
||||
|
||||
No test suite currently.
|
||||
|
||||
### Architecture
|
||||
|
||||
```
|
||||
Renderer (React 19 + TanStack Router)
|
||||
↓ custom ipcLink (NOT electron-trpc — incompatible with tRPC v11)
|
||||
Preload (contextBridge: window.electronTRPC + window.electronAI)
|
||||
↓ IPC channels
|
||||
Main Process (Node.js)
|
||||
├── tRPC router (all CRUD + AI procedures)
|
||||
├── SQLite (better-sqlite3 + Drizzle ORM, WAL mode)
|
||||
├── LanceDB (vector embeddings, 1536-dim text-embedding-3-small)
|
||||
└── LangGraph orchestrator (3 specialist agents, pluggable LLM providers)
|
||||
```
|
||||
|
||||
**This is a local-first app.** All user data (tasks, notes, projects) lives in local SQLite. The AI system (LangGraph + LangChain) runs entirely in the Electron main process with pluggable providers (OpenAI, Anthropic, GitHub Copilot).
|
||||
|
||||
**IPC channels**:
|
||||
- `'trpc'` — bidirectional tRPC request/response (all CRUD)
|
||||
- `'ai:stream'` — one-way token streaming from main → renderer
|
||||
- `'ai:action'` — AI side-effects (e.g., task auto-created by agent)
|
||||
|
||||
**Key source paths**:
|
||||
- `src/main/ipc.ts` — Custom tRPC↔IPC bridge
|
||||
- `src/main/router/index.ts` — All tRPC routers (~600 LOC)
|
||||
- `src/main/ai/orchestrator.ts` — LangGraph intent routing + 3 agents (~991 LOC)
|
||||
- `src/main/db/schema.ts` — 6 tables (clients, projects, tasks, checkpoints, notes, taskComments)
|
||||
- `src/renderer/routes/` — File-based routing (TanStack Router auto-generates `routeTree.gen.ts`)
|
||||
- `src/renderer/components/ui/` — shadcn/ui primitives (new-york theme, neutral colors)
|
||||
- `src/main/auth/auth-manager.ts` — Login, register, logout, OAuth flow (singleton)
|
||||
- `src/main/auth/backup-key.ts` — Device-specific AES-256 backup key (safeStorage, not password-derived)
|
||||
- `src/main/ai/token.ts` — Two-tier token storage: safeStorage + electron-store fallback
|
||||
- `src/main/auth/locale-defaults.ts` — Detects timezone, date/time format, language from OS locale
|
||||
- `src/main/api/format-row.ts` — Formats timestamp columns in query results using user's FormatPrefs
|
||||
|
||||
**Non-obvious details**:
|
||||
- `electron-trpc` is NOT used — custom IPC bridge in `ipc.ts` + `ipcLink.ts` because electron-trpc bundles tRPC v10 internals
|
||||
- Vite configs use `.mts` extension to avoid ESM/CJS conflicts with electron-forge
|
||||
- `forge.config.ts` has complex cross-compilation hooks (downloads platform-specific native binaries for better-sqlite3 and LanceDB)
|
||||
- DB has no foreign key constraints — cascade deletes are implemented in tRPC procedures
|
||||
- Timestamps are milliseconds (JavaScript `Date.getTime()`), not ISO strings
|
||||
- Notes auto-embed to LanceDB on create/update (fire-and-forget, errors swallowed)
|
||||
|
||||
**Settings Page (shared between Electron and Web)**:
|
||||
- The Settings page is designed to run in **both** the Electron app and a standalone web SPA (future landing-page user portal). The same React components are used — no duplication.
|
||||
- **Platform Adapter pattern**: `PlatformProvider` context (`src/renderer/lib/platform.tsx`) exposes `isElectron`/`isWeb`/`hasLocalAgents`/`hasFileDialog` flags. Components use `usePlatform()` to conditionally render Electron-only features (device ID, local agent filesystem) or disable them on web.
|
||||
- **6 sections**: Profile, AI Preferences, Account, Billing, Appearance, Agents. Sidebar nav with icons in `types.ts` (`SECTIONS` array).
|
||||
- **Web build**: `vite.web.config.mts` builds a standalone SPA to `dist-web/`. Entry: `web.html` → `src/renderer/web-main.tsx` (uses `httpBatchLink` via `src/renderer/lib/httpLink.ts` instead of `ipcLink`). Scripts: `npm run dev:web`, `npm run build:web`, `npm run preview:web`.
|
||||
- **Electron-only gating**: Device ID card and local agent filesystem features are gated behind `platform.isElectron`. On web, local agents are visible but disabled (not hidden).
|
||||
- **Gotcha**: Do NOT add Electron-specific settings (e.g. server URL, native file pickers) without wrapping in `platform.isElectron`. The same component tree renders on web.
|
||||
|
||||
**Onboarding Wizard**:
|
||||
- First-run wizard collects 5 fields: `job_role`, `industry`, `primary_use_case`, `tone_preference`, `language`. Plus `user_name` derived from profile `name`+`surname`.
|
||||
- All fields stored as encrypted core memory (backend `MemoryMiddleware`), not local electron-store.
|
||||
- `onboarding_completed_at` on the `users` table (nullable TIMESTAMPTZ) gates the flow — `null` = show wizard, non-null = skip.
|
||||
- `AppShell.tsx` gates: if `profile.onboardingCompletedAt == null` → render `<OnboardingFlow>` instead of the app.
|
||||
- `auth.status` tRPC procedure auto-seeds `language` and `user_name` into MemoryCore if missing (fire-and-forget `.catch(() => {})`).
|
||||
- Format prefs (timezone, dateFormat, timeFormat) are stored in electron-store (`FormatPrefs`), not core memory — they're device-specific.
|
||||
- `drizzle-executor.ts` wraps all query results through `formatRow()`/`formatRows()` using the user's FormatPrefs.
|
||||
- Settings > Profile section allows post-onboarding editing of all fields + format prefs.
|
||||
- **Gotcha — shadcn Button `outline` variant in dark mode**: The variant defines `dark:bg-input/30 dark:border-input dark:hover:bg-input/50` which overrides any custom `className` background. Fix: switch between `variant="default"` and `variant="outline"` instead of adding className overrides.
|
||||
- **Gotcha — locale codes vs human names**: `app.getLocale()` and `navigator.language` return codes like `en-US`. Use `Intl.DisplayNames(undefined, { type: 'language' })` to convert to "English". This must be done in both the main process (`locale-defaults.ts`) and renderer (`OnboardingFlow.tsx`).
|
||||
|
||||
**i18n (Internationalization)**:
|
||||
- Uses `i18next` + `react-i18next` with bundled JSON translations (no lazy loading).
|
||||
- Config in `src/renderer/i18n.ts`. 5 languages: EN, IT, ES, FR, DE. `SUPPORTED_LANGUAGES` array exported for UI selectors.
|
||||
- Translation files: `src/renderer/locales/{en,it,es,fr,de}/translation.json`. Namespaces: `nav`, `auth`, `tasks`, `settings`, `common`, `errors`, `home`, `timeline`, `projects`, `agents`.
|
||||
- **`common.*` namespace** holds shared labels (`save`, `cancel`, `delete`, `edit`, `add`, `rename`, `saving`, `deleting`, `creating`, `renameDescription`, `deleteTitle`). Before adding a new key, check if `common.*` already has it.
|
||||
- Pluralization uses i18next `_one`/`_other` suffixes (e.g. `tasksDueToday_one`, `tasksDueToday_other`).
|
||||
- `LanguageSync` component in `src/renderer/index.tsx` reads persisted `uiLanguage` from electron-store via tRPC on startup and syncs to i18next.
|
||||
- Language selector lives in `GeneralSection.tsx` (Settings > General). On change it: (1) calls `i18n.changeLanguage()`, (2) persists to electron-store via `setUiLanguage` mutation, (3) writes to backend core memory so AI responds in the same language.
|
||||
- `getUiLanguage()` exported from `src/main/store.ts` — used by `orchestrator.ts` to append language hint to daily brief prompt.
|
||||
- Static data arrays that need translation use `labelKey` pattern (not `label`): store a translation key, call `t(labelKey)` at render time. Used in `NAV_ITEMS`, `COLUMNS`, `SECTIONS`, `SUGGESTION_CHIPS`.
|
||||
- When adding new translated text: add the key to **all 5** JSON files. Keep `common.*` keys consistent across all languages.
|
||||
|
||||
**Google OAuth (adiuvAI side)**:
|
||||
- `adiuvai://` is NOT accepted by Google as a redirect URI — Google only accepts `http://localhost` or `https://`. The API backend exposes `GET /auth/oauth/google/web-callback` which receives the Google redirect and immediately bounces to `adiuvai://oauth/callback?...`. The redirect URI registered in Google Cloud Console points to the backend, not the Electron app.
|
||||
- `app.requestSingleInstanceLock()` is required for the `second-instance` event to fire on Windows/Linux. If it returns `false`, call `app.quit()` immediately (another instance is already running).
|
||||
- In dev (`process.defaultApp === true`), `setAsDefaultProtocolClient('adiuvai')` must include `[path.resolve(process.argv[1])]` as the third argument so the OS protocol registration includes the entry script.
|
||||
- `loginWithOAuth` uses `fetch()` directly (not `this.get()`) because the authorize endpoint is public — `get()` throws when not authenticated.
|
||||
- The backup key in `backup-key.ts` is stored in `encryptedTokens` under the key `backup_key`, reusing `getToken/setToken` from `token.ts`. It is device-bound and never password-derived, so social-login users can use backup features without issue.
|
||||
|
||||
---
|
||||
|
||||
## api (FastAPI Backend)
|
||||
|
||||
### Commands
|
||||
|
||||
```bash
|
||||
cd api
|
||||
|
||||
# Development
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
|
||||
# Production
|
||||
gunicorn app.main:app -k uvicorn.workers.UvicornWorker -w 4 --timeout 120
|
||||
|
||||
# Database migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Testing
|
||||
pytest # all tests
|
||||
pytest -v # verbose
|
||||
pytest tests/test_agents.py # single file
|
||||
pytest tests/test_agents.py -k test_name # single test
|
||||
|
||||
# Linting/formatting
|
||||
ruff check .
|
||||
ruff format .
|
||||
|
||||
# Docker (full stack: app + postgres + minio + qdrant)
|
||||
docker compose up --build
|
||||
```
|
||||
|
||||
### Architecture
|
||||
|
||||
```
|
||||
FastAPI app (app/main.py)
|
||||
├── Middleware: TierRateLimiter → Sanitizer → CORS
|
||||
├── HTTP Routes (app/api/routes/)
|
||||
│ ├── auth.py — register, login, token refresh (bcrypt + HS256 JWT)
|
||||
│ ├── chat.py — POST /chat, WS /chat/stream
|
||||
│ ├── plans.py — execution plan playbooks
|
||||
│ ├── storage.py — E2E-encrypted cloud storage (S3)
|
||||
│ ├── backup.py — encrypted backup upload/download
|
||||
│ ├── vectors.py — encrypted vector upsert/search (Pinecone/Qdrant)
|
||||
│ ├── plugins.py — plugin marketplace (Power+ tier)
|
||||
│ └── billing.py — Stripe subscriptions
|
||||
├── Agent System (app/agents/)
|
||||
│ ├── task_agent.py — 8 tools
|
||||
│ ├── project_agent.py — 6 tools
|
||||
│ ├── checkpoint_agent.py — 4 tools
|
||||
│ └── note_agent.py — 5 tools
|
||||
├── Orchestration (app/core/)
|
||||
│ ├── orchestrator.py — intent classification + agent routing
|
||||
│ ├── agent_registry.py — decorator-based agent registry
|
||||
│ ├── execution_plan.py — server-side prompt templates + plan builder
|
||||
│ ├── llm.py — LiteLLM factory (100+ providers)
|
||||
│ └── memory_middleware.py
|
||||
├── Billing (app/billing/)
|
||||
│ ├── tier_manager.py — feature matrix (Free/Pro/Power/Team)
|
||||
│ └── stripe_service.py — Stripe checkout + webhooks
|
||||
├── Storage (app/storage/) — S3 blob store, vector store, encryption
|
||||
└── Marketplace (app/marketplace/) — plugin catalog, review, revenue sharing
|
||||
```
|
||||
|
||||
**LLM routing**: `gpt-4o-mini` classifies intent → routes to domain agent → agent uses `gpt-4o` with tools → tool calls describe client-side operations (JSON) → Electron executes locally and returns results.
|
||||
|
||||
**Zero-trust data model**: The backend never stores or decrypts user content. PostgreSQL holds only auth, billing, plugin metadata, and storage record pointers. All user data is E2E-encrypted before leaving the Electron client.
|
||||
|
||||
**Key config**: `app/config/settings.py` — all env vars via Pydantic Settings. Copy `.env.example` to `.env` for local dev. Stripe and S3 gracefully stub when keys aren't configured.
|
||||
|
||||
**Database**: PostgreSQL with async SQLAlchemy 2.0 + asyncpg. 9 ORM models in `app/models.py`. Alembic migrations in `alembic/versions/`.
|
||||
|
||||
**Testing**: pytest + pytest-asyncio. Fixtures in `tests/conftest.py` create in-memory SQLite + moto-mocked S3. Test users seeded per tier (free/pro/power/team).
|
||||
|
||||
### Non-obvious details
|
||||
|
||||
- **Tier from DB, not JWT**: `get_current_user` decodes JWT but fetches authoritative tier from `subscriptions` table — tier changes take effect immediately without re-login
|
||||
- **Refresh tokens hashed**: Plaintext returned to client, stored as SHA-256 in DB — server can never retrieve the plaintext (intentional)
|
||||
- **WebSocket auth via query param**: `?token=<jwt>` instead of Bearer header (WebSocket handshake limitation)
|
||||
- **Prompt IP protection**: `PromptTemplateRegistry` keeps prompts server-side; clients receive opaque `template_id`. `SanitizerMiddleware` strips leaked fragments from responses
|
||||
- **Agents don't execute operations**: Tools return JSON describing client-side ops — the Electron client executes against local SQLite
|
||||
- **Alembic async/sync split**: App uses `postgresql+asyncpg`, Alembic CLI needs `postgresql+psycopg2` — `env.py` handles the URL conversion
|
||||
- **Tool loop cap**: Agent `_tool_loop` stops after 5 iterations to prevent infinite loops
|
||||
- **Route order matters**: `/backup/history` must be declared before `/backup/{backup_id}` to avoid path param shadowing
|
||||
- **CORS includes `app://`**: Electron uses custom `app://` protocol, not http/https
|
||||
- **Vector search on encrypted data is not semantic**: Backend derives deterministic 32-dim floats from blob SHA-256 for storage/search — a known trade-off
|
||||
|
||||
**Onboarding (API side)**:
|
||||
- `PUT /auth/me/memory` — updates core memory k/v pairs and optionally marks onboarding complete (`mark_onboarded: true` sets `users.onboarding_completed_at`).
|
||||
- `POST /auth/me/onboarding/reset` — nullifies `onboarding_completed_at` so the wizard re-runs.
|
||||
- `POST /auth/onboarding/normalize` — LLM-normalizes free-text onboarding inputs via `gpt-4o-mini`; returns inputs unchanged on error.
|
||||
- `get_current_user()` in `auth.py` middleware now decrypts core memory blocks and includes them in `UserProfile.memory` dict.
|
||||
- `users.onboarding_completed_at` is a nullable TIMESTAMPTZ column — returned as epoch ms (int) in UserProfile schema.
|
||||
|
||||
**i18n (API side)**:
|
||||
- `_language_instruction()` in `app/core/deep_agent.py` reads the user's `language` from `MemoryCore` and appends a system prompt directive ("Always respond in {language}") to all 4 `run_*` functions.
|
||||
- The Electron client writes the user's chosen language to backend core memory on language change, so the API picks it up on the next agent call.
|
||||
|
||||
**Google OAuth (api side)**:
|
||||
- OAuth routes live in `app/api/routes/auth.py`: `GET /auth/oauth/{provider}/authorize`, `POST /auth/oauth/{provider}/callback`, `GET /auth/oauth/{provider}/web-callback` (bounces to deep link, excluded from OpenAPI schema).
|
||||
- Provider abstraction in `app/auth/oauth_providers.py` — `GoogleOAuthProvider` uses `httpx` directly (no `authlib`). PKCE S256 is implemented manually via `generate_pkce_pair()`.
|
||||
- `_pending_states` dict in `routes/auth.py` is **in-memory** — works for single-process dev but does not survive restarts and does not scale to multiple workers. Replace with Redis in production.
|
||||
- `users.password_hash` is **nullable** — social-only users have `password_hash=None`. `await db.flush()` is required before creating a linked `OAuthAccount` to populate `new_user.id` before commit.
|
||||
- `OAUTH_REDIRECT_URI` must point to the **API backend** (e.g. `https://api.adiuvai.com/...`), not the website domain. `adiuvai.com` is a static site with no server-side routing.
|
||||
- **Unverified email + existing account = 409**: if `email_verified=False` and the email is already registered, the callback returns 409. Without this guard, branch 3 would attempt to INSERT a duplicate email and crash with a DB constraint violation (500).
|
||||
- **Testing OAuth routes**: mock `GoogleOAuthProvider.exchange_code` and `get_userinfo` with `patch.object(..., new=AsyncMock(...))` — works because FastAPI instantiates a new provider per request. Use `monkeypatch.setattr(settings, "GOOGLE_AUTH_CLIENT_ID", ...)` to simulate configured credentials without restarting the app.
|
||||
|
||||
### Tier System
|
||||
|
||||
| Feature | Free | Pro | Power | Team |
|
||||
|---------|------|-----|-------|------|
|
||||
| Rate limit | 20/min | 60/min | 120/min | 200/min |
|
||||
| Agents | 3 | unlimited | unlimited | unlimited |
|
||||
| Cloud storage | 0 GB | 5 GB | 25 GB | unlimited |
|
||||
| Plugin marketplace | no | no | yes | yes |
|
||||
|
||||
Enforced in `app/api/middleware/rate_limit.py` (sliding window) and `app/billing/tier_manager.py` (feature checks + quota enforcement).
|
||||
|
||||
---
|
||||
|
||||
## Cross-Project Integration
|
||||
|
||||
The Electron app and FastAPI backend communicate via **WebSocket** (`/chat/stream`):
|
||||
|
||||
1. Electron connects with `?token=<jwt>` query param
|
||||
2. Client sends `ChatRequest` JSON frame
|
||||
3. Server streams text chunks, then a final frame: `{"done": true, "response": "...", "actions": []}`
|
||||
4. Server sends `tool_call` frames → Electron executes against local SQLite → returns `tool_result`
|
||||
5. Server pings every 30 seconds to keep connection alive
|
||||
|
||||
The Electron app also has a **fully local AI path** (LangGraph orchestrator in main process) that doesn't require the backend — this is the primary path for desktop use.
|
||||
|
||||
---
|
||||
|
||||
## MCP Servers
|
||||
|
||||
- **Langfuse Docs** (`https://langfuse.com/api/mcp`) — configured at workspace level for prompt management documentation
|
||||
- **shadcn** (`npx shadcn@latest mcp`) — configured in `adiuvAI/` for UI component generation
|
||||
@@ -1,14 +1,21 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(git add AI_REFACTOR_PLAN.md)",
|
||||
"Bash(git commit:*)",
|
||||
"Read(//home/rmusso/adiuva-api/**)",
|
||||
"mcp__shadcn__get_item_examples_from_registries",
|
||||
"mcp__shadcn__view_items_in_registries",
|
||||
"Bash(npm run lint)",
|
||||
"Bash(npx eslint --ext .ts,.tsx src/renderer/components/ai/blocks/)",
|
||||
"WebFetch(domain:ui.shadcn.com)"
|
||||
"allow": []
|
||||
},
|
||||
"enabledPlugins": {
|
||||
"caveman@caveman": true
|
||||
},
|
||||
"hooks": {
|
||||
"PreToolUse": [
|
||||
{
|
||||
"matcher": "Bash",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "CMD=$(python3 -c \"import json,sys; d=json.load(sys.stdin); print(d.get('tool_input',d).get('command',''))\" 2>/dev/null || true); case \"$CMD\" in *grep*|*rg\\ *|*ripgrep*|*find\\ *|*fd\\ *|*ack\\ *|*ag\\ *) [ -f graphify-out/graph.json ] && echo '{\"hookSpecificOutput\":{\"hookEventName\":\"PreToolUse\",\"additionalContext\":\"graphify: Knowledge graph exists. Read graphify-out/GRAPH_REPORT.md for god nodes and community structure before searching raw files.\"}}' || true ;; esac"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
113
.gitignore
vendored
113
.gitignore
vendored
@@ -1,97 +1,16 @@
|
||||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
lerna-debug.log*
|
||||
|
||||
# Diagnostic reports (https://nodejs.org/api/report.html)
|
||||
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
|
||||
|
||||
# Runtime data
|
||||
pids
|
||||
*.pid
|
||||
*.seed
|
||||
*.pid.lock
|
||||
.DS_Store
|
||||
|
||||
# Directory for instrumented libs generated by jscoverage/JSCover
|
||||
lib-cov
|
||||
|
||||
# Coverage directory used by tools like istanbul
|
||||
coverage
|
||||
*.lcov
|
||||
|
||||
# nyc test coverage
|
||||
.nyc_output
|
||||
|
||||
# node-waf configuration
|
||||
.lock-wscript
|
||||
|
||||
# Compiled binary addons (https://nodejs.org/api/addons.html)
|
||||
build/Release
|
||||
|
||||
# Dependency directories
|
||||
node_modules/
|
||||
jspm_packages/
|
||||
|
||||
# TypeScript v1 declaration files
|
||||
typings/
|
||||
|
||||
# TypeScript cache
|
||||
*.tsbuildinfo
|
||||
|
||||
# Optional npm cache directory
|
||||
.npm
|
||||
|
||||
# Optional eslint cache
|
||||
.eslintcache
|
||||
|
||||
# Optional REPL history
|
||||
.node_repl_history
|
||||
|
||||
# Output of 'npm pack'
|
||||
*.tgz
|
||||
|
||||
# Yarn Integrity file
|
||||
.yarn-integrity
|
||||
|
||||
# dotenv environment variables file
|
||||
.env
|
||||
.env.test
|
||||
|
||||
# parcel-bundler cache (https://parceljs.org/)
|
||||
.cache
|
||||
|
||||
# next.js build output
|
||||
.next
|
||||
|
||||
# nuxt.js build output
|
||||
.nuxt
|
||||
|
||||
# vuepress build output
|
||||
.vuepress/dist
|
||||
|
||||
# Serverless directories
|
||||
.serverless/
|
||||
|
||||
# FuseBox cache
|
||||
.fusebox/
|
||||
|
||||
# DynamoDB Local files
|
||||
.dynamodb/
|
||||
|
||||
# Webpack
|
||||
.webpack/
|
||||
|
||||
# Vite
|
||||
.vite/
|
||||
|
||||
# Electron-Forge
|
||||
out/
|
||||
|
||||
# local config files
|
||||
.vscode/
|
||||
.agents/
|
||||
src/renderer/routeTree.gen.ts
|
||||
skills/
|
||||
unused_skills/
|
||||
.vscode/mcp.json
|
||||
.claude/skills/brand-guidelines/*
|
||||
.claude/skills/frontend-design/*
|
||||
.claude/skills/remotion-best-practices/*
|
||||
.mcp.json
|
||||
docs/node_modules
|
||||
docs/package.json
|
||||
docs/package-lock.json
|
||||
tmp/
|
||||
.superpowers/
|
||||
graphify-out/cache/
|
||||
graphify-out/manifest.json
|
||||
graphify-out/cost.json
|
||||
.claude/settings.local.json
|
||||
|
||||
11
.mcp.json
11
.mcp.json
@@ -1,11 +0,0 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"shadcn": {
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"shadcn@latest",
|
||||
"mcp"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,436 +0,0 @@
|
||||
# AI Refactor Plan — Adiuva Electron App
|
||||
|
||||
> **Objective:** Transform the Electron app into a backend-powered client. All AI intelligence (chat, tool calling, embeddings) lives on the backend. The Electron app owns the local database, executes structured CRUD operations from backend tools via Drizzle ORM, and handles auth, backup, and offline graceful degradation.
|
||||
>
|
||||
> **Backend:** Lives at `../adiuva-api/`. FastAPI + LiteLLM + 4 chat agents (task, checkpoint, project, note). Backend plan: `../adiuva-api/AI_REFACTOR_PLAN.md`.
|
||||
>
|
||||
> **Protocol:** Execute steps sequentially. Each step is atomic and committable. Mark `[x]` when done.
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
Renderer (React 19) ──ipcLink──► Main (tRPC + SQLite) ──HTTP/WS──► Backend (FastAPI + LiteLLM)
|
||||
UI only Data + Drizzle executor All AI intelligence
|
||||
```
|
||||
|
||||
**Data flow for chat (bidirectional WebSocket):**
|
||||
1. User types message in renderer → tRPC `ai.chat` mutation
|
||||
2. Main process builds `ChatContext` (queries SQLite for tasks, notes, profile)
|
||||
3. Main opens WS to backend `/api/v1/chat/stream?token=<jwt>`, sends `chat_request` frame
|
||||
4. Backend classifies intent → routes to agent → agent calls LLM with tools
|
||||
5. LLM calls a tool (e.g. `list_tasks`) → tool calls `execute_on_client()`:
|
||||
- Backend sends `tool_call` frame: `{id, action:"select", table:"tasks", filters:{...}}`
|
||||
- Electron receives frame → Drizzle executor: `db.select().from(tasks).where(...)` → real rows
|
||||
- Electron sends `tool_result` frame: `{id, rows: [{id, title, ...}, ...]}`
|
||||
- Tool receives real data → returns formatted string to LLM
|
||||
6. Steps 5 repeats (max 5 iterations) until LLM has enough data to respond
|
||||
7. Backend streams response text → `text_chunk` frames → main forwards via `ai:stream` IPC → renderer
|
||||
8. Backend sends `final` frame: `{"done": true, "response": "..."}`
|
||||
|
||||
**No local LLM.** When offline, AI features show "You're offline" — all other features (tasks, notes, projects) work normally.
|
||||
|
||||
---
|
||||
|
||||
## WS Protocol — Typed Frames
|
||||
|
||||
| Direction | `type` | Payload |
|
||||
|---|---|---|
|
||||
| Client → Server | `chat_request` | `{ message, context }` |
|
||||
| Server → Client | `text_chunk` | `{ text: string }` |
|
||||
| Server → Client | `tool_call` | `{ id, action, table?, data?, filters?, vector?, limit? }` |
|
||||
| Client → Server | `tool_result` | `{ id, row?, rows?, results?, deleted?, ok?, error? }` |
|
||||
| Server → Client | `final` | `{ response: string }` |
|
||||
| Server → Client | `ping` | `{}` |
|
||||
|
||||
**Tool call actions (Electron → Drizzle mapping):**
|
||||
|
||||
| `action` | Drizzle call | Returns |
|
||||
|---|---|---|
|
||||
| `select` | `db.select().from(table).where(filters).all()` | `{ rows: [...] }` |
|
||||
| `get` | `db.select().from(table).where(eq(id, ...)).get()` | `{ row: {...} \| null }` |
|
||||
| `insert` | `db.insert(table).values({id: uuid(), ...data, createdAt: now()}).returning().get()` | `{ row: {...} }` |
|
||||
| `update` | `db.update(table).set(data.updates).where(eq(id,...)).returning().get()` | `{ row: {...} }` |
|
||||
| `delete` | `db.delete(table).where(eq(id,...)).run()` | `{ deleted: true }` |
|
||||
| `vector_upsert` | LanceDB delete-then-add with pre-computed vector | `{ ok: true }` |
|
||||
| `vector_search` | LanceDB `table.search(vector).limit(n)` | `{ results: [{id, content, score}...] }` |
|
||||
|
||||
Electron generates `id` (UUID v4) and `createdAt`/`updatedAt` (Unix ms) for inserts. Backend never generates IDs.
|
||||
|
||||
---
|
||||
|
||||
## Phase 0 — API Contracts & Types ✅
|
||||
|
||||
### Step 0.1 — Define backend API contract types ✅
|
||||
- [x] Create `src/shared/api-types.ts` with Zod schemas + inferred types
|
||||
- [x] Create `src/shared/batch-types.ts` with batch builder + storage types
|
||||
- [x] Update `tsconfig.json` paths — added `@shared/*` alias
|
||||
- **Outcome:** Type-safe contracts for all backend communication.
|
||||
|
||||
---
|
||||
|
||||
## Phase 1 — Auth & Backend Client
|
||||
|
||||
### Step 1.1 — Align shared types with backend schemas
|
||||
- [x] Update `src/shared/api-types.ts` to match backend `app/schemas.py` exactly:
|
||||
- `AuthTokens.expiresAt`: change from `z.string().datetime()` to `z.number().int()` (Unix epoch)
|
||||
- `ChatContext`: replace with backend's flat structure — `{ userProfile, relevantDocuments, recentTasks, conversationHistory }`; remove UI-only fields (`type`, `projectId`, `uiContext`)
|
||||
- Remove `PlanAction` entirely — no more action descriptors
|
||||
- `ChatResponse`: just `{ response: string }` — no `actions` array
|
||||
- Align `PlanStep` / `ExecutionPlan` with backend or remove if plan mode is deferred
|
||||
- [x] Add WebSocket frame Zod schemas:
|
||||
- `ToolCallAction` enum: `select`, `get`, `insert`, `update`, `delete`, `vector_upsert`, `vector_search`
|
||||
- `WsToolCall`: `{ type: "tool_call", id: string, action, table?, data?, filters?, vector?, limit? }`
|
||||
- `WsToolResult`: `{ type: "tool_result", id: string, row?, rows?, results?, deleted?, ok?, error? }`
|
||||
- `WsTextChunk`, `WsFinal`, `WsPing`, `WsChatRequest`
|
||||
- `WsServerFrame` / `WsClientFrame` discriminated unions
|
||||
- [x] Create `src/shared/casing.ts`:
|
||||
- `toSnakeCase(obj)` — deep-converts camelCase keys to snake_case (outgoing)
|
||||
- `toCamelCase(obj)` — deep-converts snake_case keys to camelCase (incoming)
|
||||
- [x] Create `UIChatContext` type in `src/renderer/hooks/useAIChat.ts` for renderer-only fields
|
||||
- **Files:** `src/shared/api-types.ts`, `src/shared/casing.ts`, `src/renderer/hooks/useAIChat.ts`
|
||||
- **Outcome:** Shared types match the live backend 1:1. WS frames are fully typed.
|
||||
|
||||
### Step 1.2 — Auth manager + tRPC procedures
|
||||
- [x] Create `src/main/auth/auth-manager.ts`:
|
||||
- `AuthManager` class (singleton):
|
||||
- `register(email, password): Promise<AuthTokens>` — POST `/api/v1/auth/register`
|
||||
- `login(email, password): Promise<AuthTokens>` — POST `/api/v1/auth/login`
|
||||
- `logout(): void` — clears stored tokens
|
||||
- `getAccessToken(): string | null` — current JWT
|
||||
- `refreshToken(): Promise<void>` — POST `/api/v1/auth/refresh`
|
||||
- `isAuthenticated(): boolean`
|
||||
- `getProfile(): Promise<UserProfile>` — GET `/api/v1/auth/me`
|
||||
- Token storage: reuse `src/main/ai/token.ts` (`safeStorage` + electron-store fallback)
|
||||
- Auto-refresh: check token expiry on every `getAccessToken()` call; if < 5 min remaining, refresh in background
|
||||
- [x] Add `authRouter` tRPC sub-router to `src/main/router/index.ts`
|
||||
- [x] Update `src/main/store.ts`: add `backendUrl: string`
|
||||
- **Files:** `src/main/auth/auth-manager.ts`, `src/main/router/index.ts`, `src/main/store.ts`
|
||||
- **Outcome:** Electron can authenticate with the backend. JWTs stored securely.
|
||||
|
||||
### Step 1.3 — Backend client with bidirectional WebSocket
|
||||
- [x] Create `src/main/api/backend-client.ts`:
|
||||
- `BackendClient` class (singleton):
|
||||
- Constructor: reads `backendUrl` from store, gets JWT from `AuthManager`
|
||||
- `chatStream(request: ChatRequest, onChunk: (text: string) => void): Promise<ChatResponse>`:
|
||||
1. Opens WS to `/api/v1/chat/stream?token=<jwt>`
|
||||
2. Sends `{ type: "chat_request", ... }` frame
|
||||
3. Message loop:
|
||||
- `text_chunk` → calls `onChunk(text)`
|
||||
- `tool_call` → calls `DrizzleExecutor.execute(payload)`, sends back `{ type: "tool_result", id, ... }`
|
||||
- `final` → resolves with `{ response }`
|
||||
- `ping` → ignore
|
||||
- `isOnline(): Promise<boolean>` — GET `/api/v1/health` with 3s timeout
|
||||
- `embedText(text: string): Promise<number[]>` — POST `/api/v1/storage/vectors/embed`
|
||||
- All requests include `Authorization: Bearer <jwt>` header
|
||||
- Auto-retry with exponential backoff (max 3 attempts) for non-auth errors
|
||||
- Response parsing: `toCamelCase()` on all incoming JSON
|
||||
- Request serialization: `toSnakeCase()` on all outgoing JSON
|
||||
- Error categorization: 401 → `AuthExpiredError`, 429 → `RateLimitError`, 5xx → `ServerError`, timeout → `OfflineError`
|
||||
- **Files:** `src/main/api/backend-client.ts`
|
||||
- **Outcome:** Type-safe HTTP + bidirectional WS client. Tool calls handled in the message loop.
|
||||
|
||||
### Step 1.4 — Drizzle executor (the dumb Electron layer)
|
||||
- [x] Create `src/main/api/drizzle-executor.ts`:
|
||||
- Table registry: map string names → Drizzle table objects from `src/main/db/schema.ts`:
|
||||
```
|
||||
{ tasks, projects, clients, checkpoints, notes, taskComments }
|
||||
```
|
||||
- `execute(payload): Promise<object>` — dispatches on `payload.action`:
|
||||
- **`select`**: `db.select().from(table)` + build `.where()` from `payload.filters` using Drizzle `eq()`/`and()`/`like()` + optional `.orderBy()` → returns `{ rows }`
|
||||
- **`get`**: `db.select().from(table).where(eq(table.id, payload.data.id)).get()` → returns `{ row }`
|
||||
- **`insert`**: `db.insert(table).values({id: crypto.randomUUID(), ...payload.data, createdAt: Date.now()}).returning().get()` → returns `{ row }`
|
||||
- **`update`**: `db.update(table).set(payload.data.updates).where(eq(table.id, payload.data.id)).returning().get()` → returns `{ row }`
|
||||
- **`delete`**: `db.delete(table).where(eq(table.id, payload.data.id)).run()` → returns `{ deleted: true }`
|
||||
- **`vector_upsert`**: calls `upsertWithVector()` from `vectordb.ts` with pre-computed vector → returns `{ ok: true }`
|
||||
- **`vector_search`**: LanceDB `table.search(payload.vector).limit(payload.limit)` → returns `{ results }`
|
||||
- Filter builder: maps `{key: value}` objects → Drizzle `and(eq(table[key], value), ...)`. Special cases:
|
||||
- `null` value → `isNull(table[key])`
|
||||
- `search` key → `like(table.title, '%value%')` or `like(table.content, '%value%')`
|
||||
- `orderBy` key → `.orderBy(asc(table[field]))` or `.orderBy(desc(...))`
|
||||
- `includeArchived: false` → adds `eq(table.status, 'active')` filter
|
||||
- `dueDateFrom`/`dueDateTo` → `between(table.dueDate, from, to)`
|
||||
- Security: validate `table` against registry (reject unknown), validate `action` against enum
|
||||
- Uses `getDb()` from `src/main/db/index.ts` — same Drizzle instance as everywhere else
|
||||
- **Files:** `src/main/api/drizzle-executor.ts`
|
||||
- **Outcome:** ~120 lines. Backend sends structured ops, Electron maps to Drizzle. No SQL building.
|
||||
|
||||
### Step 1.5 — Refactor orchestrator to delegate to backend
|
||||
- [x] Replace `src/main/ai/orchestrator.ts` entirely (996 lines → ~190 lines):
|
||||
- `orchestrate({ message, context, sender })`:
|
||||
1. Check `BackendClient.isOnline()` — if offline, return `{ response: '', error: 'You are offline.' }`
|
||||
2. Check `AuthManager.isAuthenticated()` — if not, return `{ response: '', error: 'Please log in.' }`
|
||||
3. Build `ChatContext` from local SQLite (userProfile, recentTasks, conversationHistory)
|
||||
4. Call `BackendClient.chatStream(request, chunk => sendStreamChunk(sender, chunk, false))`
|
||||
- `tool_call` frames handled inside the WS message loop (Step 1.3)
|
||||
5. On completion: `sendStreamChunk(sender, '', true)`
|
||||
- No PlanRunner, no action handling — writes happen mid-conversation via tool calls
|
||||
- Keep `sendStreamChunk()` IPC helper
|
||||
- Export `orchestrate()` and `dailyBrief()`
|
||||
- [x] Update `aiRouter` in `src/main/router/index.ts`:
|
||||
- Remove `setToken` mutation and `hasToken` query (replaced by `auth.status`)
|
||||
- Keep `chat` mutation (same interface) and `dailyBrief`
|
||||
- [x] Update `src/renderer/components/ai/AIChatPanel.tsx`:
|
||||
- Replace `trpc.ai.hasToken.useQuery()` with `trpc.auth.status.useQuery()`
|
||||
- Update auth-gate condition and daily brief trigger to use `authStatusQuery.data?.authenticated`
|
||||
- Replace `KeyRound` icon + provider-config messaging with `LogIn` icon + login messaging
|
||||
- **Files:** `src/main/ai/orchestrator.ts`, `src/main/router/index.ts`, `src/renderer/hooks/useAIChat.ts`
|
||||
- **Outcome:** ~916 lines removed. Chat works through backend. All tool execution is bidirectional.
|
||||
|
||||
### Step 1.6 — Migrate embeddings to backend
|
||||
- [x] Update `src/main/db/vectordb.ts`:
|
||||
- Add `upsertWithVector(noteId, projectId, content, vector)` — takes pre-computed vector, stores in LanceDB
|
||||
- Update `upsertNoteEmbedding()` → calls `BackendClient.embedText(content)` → `upsertWithVector()`
|
||||
- Keep `searchNotes()` and `migrateNotesIfNeeded()` (migration will call backend for embeddings)
|
||||
- If offline: skip embedding (next edit will re-embed when online)
|
||||
- Add `searchNotesByVector(vector, limit)` for direct pre-computed-vector search
|
||||
- [x] Update `src/main/api/drizzle-executor.ts`: use `searchNotesByVector` with pre-computed vector from tool call payload
|
||||
- [x] Delete `src/main/ai/embeddings.ts`
|
||||
- **Files:** `src/main/db/vectordb.ts`, `src/main/api/drizzle-executor.ts`, `src/main/ai/embeddings.ts` (deleted)
|
||||
- **Outcome:** Embeddings generated by backend `/vectors/embed`. Local LanceDB for storage + search.
|
||||
|
||||
---
|
||||
|
||||
## Phase 2 — Remove Local AI Stack
|
||||
|
||||
### Step 2.1 — Remove local AI code and dependencies ✅
|
||||
- [x] Delete `src/main/ai/llm.ts`, `src/main/ai/chat-copilot.ts`, `src/main/ai/copilot.ts`, `src/main/ai/provider.ts`
|
||||
- [x] Remove `import './ai/copilot'` and `initAI()` from `src/main/index.ts`
|
||||
- [x] Remove deps: `@langchain/core`, `@langchain/openai`, `@langchain/anthropic`, `@langchain/langgraph`, `@github/copilot-sdk`
|
||||
- [x] Clean up `src/main/store.ts` (remove `aiProvider`; kept `encryptedTokens` — still used by `token.ts` → `auth-manager.ts` for JWT storage)
|
||||
- [x] Clean up `vite.main.config.mts` (remove externalized LangChain/Copilot packages)
|
||||
- [x] Clean up `forge.config.ts` (remove LangChain/Copilot from `externalPackages`; remove copilot-sdk clipboard cleanup block)
|
||||
- **Files:** `src/main/ai/{llm,chat-copilot,copilot,provider}.ts` (deleted), `package.json`, `src/main/index.ts`, `src/main/store.ts`, `vite.main.config.mts`, `forge.config.ts`
|
||||
- **Outcome:** 34 npm packages removed. No LangChain, no Copilot SDK, no local LLM.
|
||||
|
||||
---
|
||||
|
||||
## Phase 3 — Agent System (Local Directory + Cloud Connectors)
|
||||
|
||||
> Two agent types at launch: **Local Directory Agent** (watches folders, Electron reads + pre-processes, backend runs AI) and **Cloud Connector Agent** (Gmail, Teams — 100% backend-managed). All configs live on the backend (synced, device-bound for local agents). Backend triggers agent runs via new WS frames when Electron is connected. Extracted data inserts into existing tables (tasks, notes, checkpoints) with `isAiSuggested=1`. Configuration prompts are built via a dedicated "Chatbot Journey" (multi-turn AI conversation on a dedicated page).
|
||||
>
|
||||
> **Backend Phase 3 plan:** `../adiuva-api/AI_REFACTOR_PLAN.md` Phase 3 section.
|
||||
|
||||
```
|
||||
Cloud Agent Flow:
|
||||
Backend cron ──► Backend fetches Gmail/Teams ──► Backend AI analyzes
|
||||
──► WS tool_call(insert, table:'tasks') ──► Electron persists locally
|
||||
|
||||
Local Agent Flow:
|
||||
Backend detects Electron online ──► WS agent_run frame (config + prompt)
|
||||
──► Electron reads files + pre-processes ──► WS agent_data frame (content)
|
||||
──► Backend AI analyzes with user prompt ──► WS tool_call(insert) ──► Electron persists
|
||||
```
|
||||
|
||||
Key constraints:
|
||||
- Local agents only run when Electron is active AND on the device where the path was configured
|
||||
- Cloud agents only push results when Electron is connected (no server-side content storage)
|
||||
- All AI communication goes through the backend (no local LLM)
|
||||
- Tier gating: free=2 active, pro=10, power/team=unlimited
|
||||
|
||||
### Step 3.1 — WS frame types + agent handler ✅
|
||||
- [x] Update `src/shared/api-types.ts`:
|
||||
- Add `WsAgentRun` schema: `{ type: "agent_run", run_id, agent_id, config: { paths, file_extensions, prompt_template, data_types } }`
|
||||
- Add `WsAgentData` schema: `{ type: "agent_data", run_id, files: [{ path, name, content, metadata }] }`
|
||||
- Add `WsAgentComplete` schema: `{ type: "agent_complete", run_id, files_read, errors }`
|
||||
- Add `WsDeviceHello` schema: `{ type: "device_hello", device_id, agent_ids }`
|
||||
- Extend `WsServerFrame` discriminated union with `agent_run`
|
||||
- Extend `WsClientFrame` with `agent_data`, `agent_complete`, `device_hello`
|
||||
- [x] Update `src/main/api/backend-client.ts`:
|
||||
- In WS message loop, handle `agent_run` frames:
|
||||
1. Read files from configured paths using the local agent handler (Step 3.2)
|
||||
2. Send `agent_data` frames back with pre-processed content
|
||||
3. Continue handling `tool_call` frames for DB inserts as usual
|
||||
- **Files:** `src/shared/api-types.ts`, `src/main/api/backend-client.ts`
|
||||
- **Outcome:** Electron can receive agent trigger frames and respond with file data.
|
||||
|
||||
### Step 3.2 — Local file reader ✅
|
||||
- [x] Create `src/main/agents/file-reader.ts`:
|
||||
- `readDirectory(paths: string[], extensions: string[]): AsyncGenerator<FileData>` — recursively reads configured directories, filters by extension
|
||||
- `preProcess(filePath: string): { name, content, metadata }`:
|
||||
- `.txt`, `.md`, `.eml` — read as text
|
||||
- `.pdf` — text extraction (dep: `pdf-parse`)
|
||||
- `.docx` — text extraction (dep: `mammoth`)
|
||||
- `.csv`, `.json` — read as structured text
|
||||
- Binary files: skip with warning
|
||||
- Respects path boundaries (no symlink escape, no `..` traversal)
|
||||
- Chunks large files (>50KB) to stay within LLM context limits
|
||||
- Returns `{ path, name, content, metadata: { size, mtime, extension } }`
|
||||
- [x] Update `BackendClient.handleAgentRun()` to call `readAgentFiles()` and return `{ files, errors, filesRead }`
|
||||
- **Files:** `src/main/agents/file-reader.ts`, `src/main/api/backend-client.ts`, `package.json` (`pdf-parse`, `mammoth` added)
|
||||
- **Dependencies:** `pdf-parse`, `mammoth`
|
||||
- **Outcome:** Electron can safely read + pre-process local files for AI analysis.
|
||||
|
||||
### Step 3.3 — Device ID management ✅
|
||||
- [x] Update `src/main/store.ts`: add `deviceId: string` (UUID generated once on first launch and persisted)
|
||||
- [x] Add `getDeviceId()` helper — lazily generates UUID v4 on first call, persists it; subsequent calls return the same value
|
||||
- [x] Add `settings.deviceId` tRPC query to `settingsRouter` — renderer can read the device ID; Step 3.4 (agent router) injects it into local agent config creation calls to the backend
|
||||
- [x] Electron sends `deviceId` when creating local agent configs → backend stores it (Step 3.4)
|
||||
- [x] When backend triggers a local agent run, it checks `config.device_id` matches the connected Electron's `deviceId` (Step 3.5)
|
||||
- **Files:** `src/main/store.ts`, `src/main/router/index.ts`
|
||||
- **Outcome:** Local agents are device-bound. Only triggered on the correct machine.
|
||||
|
||||
### Step 3.4 — Agent tRPC router ✅
|
||||
- [x] Add `agentRouter` to `src/main/router/index.ts`:
|
||||
- `agent.catalog` — query: proxy to backend `GET /api/v1/agents/catalog`
|
||||
- `agent.local.list` / `agent.local.create` / `agent.local.update` / `agent.local.delete` — proxy to backend with `deviceId` injected
|
||||
- `agent.cloud.list` / `agent.cloud.create` / `agent.cloud.update` / `agent.cloud.delete` — proxy to backend
|
||||
- `agent.runs` — query: proxy to backend run log
|
||||
- `agent.runNow` — mutation: proxy to backend manual trigger
|
||||
- `agent.journey.start` / `agent.journey.message` — proxy chatbot journey endpoints
|
||||
- All proxy calls include JWT from AuthManager + snake_case/camelCase conversion
|
||||
- [x] Also added response schemas to `src/shared/api-types.ts`: `AgentCatalogItemSchema`, `LocalAgentConfigSchema`, `CloudAgentConfigSchema`, `AgentRunLogSchema`, `JourneyMessageSchema`
|
||||
- [x] Added `proxyGet/proxyPost/proxyPut/proxyDelete` methods to `BackendClient` (authenticated, casing-converted HTTP proxies)
|
||||
- **Files:** `src/main/router/index.ts`, `src/shared/api-types.ts`, `src/main/api/backend-client.ts`
|
||||
- **Outcome:** Renderer can manage agents through tRPC — all requests proxied to backend.
|
||||
|
||||
### Step 3.5 — Persistent WS connection for agent triggers ✅
|
||||
- [x] Update `src/main/api/backend-client.ts`:
|
||||
- `connectPersistent()` — opens persistent WS to `/api/v1/ws/device?token=<jwt>` on app start
|
||||
- On connect: sends `device_hello` frame with `deviceId` and active agent IDs
|
||||
- Handles incoming `agent_run` frames → dispatches to file reader → sends `agent_data` back
|
||||
- Handles `tool_call` frames for DB inserts (same as chat WS)
|
||||
- `handleAgentRunAndSend()` — validates device ID, calls `handleAgentRun()`, sends `agent_data` + `agent_complete` frames
|
||||
- Auto-reconnects on disconnect with exponential backoff (1s → 2s → 4s → 8s → 16s → 30s cap)
|
||||
- Heartbeat WS-level ping every 30s; pong/message timeout triggers force-reconnect
|
||||
- `disconnectPersistent()` — disables reconnect, clears timers, closes WS cleanly
|
||||
- [x] Call `connectPersistent()` from `src/main/index.ts` after auth check on app startup
|
||||
- [x] `will-quit` handler in `src/main/index.ts` calls `disconnectPersistent()` for clean exit
|
||||
- [x] `authRouter.login` calls `connectPersistent()` on success
|
||||
- [x] `authRouter.logout` calls `disconnectPersistent()`
|
||||
- [x] Device ID validation in `handleAgentRunAndSend()` (completes Step 3.3 final checkbox)
|
||||
|
||||
### Step 3.6 — Agent Library page ✅
|
||||
- [x] Created `src/renderer/routes/settings.tsx`:
|
||||
- Settings page with 2-column layout (left nav: General, Account, Agents, Appearance)
|
||||
- Agents section is the agent library — catalog grid + my agents list with status indicators
|
||||
- Settings icon in sidebar navigates to `/settings` (replaced dropdown)
|
||||
- `validateSearch` for deep-link to specific section (e.g. `?section=account`)
|
||||
- [x] Added route to `src/renderer/routeTree.gen.ts`
|
||||
- [x] Updated sidebar nav in `src/renderer/components/layout/AppShell.tsx` (Settings is now a link)
|
||||
|
||||
### Step 3.7 — Agent config dialogs ✅
|
||||
- [x] `LocalAgentConfigPanel` component (inline, inside expanded agent row in Settings → Agents):
|
||||
- Native `dialog.showOpenDialog` directory picker (via new `dialog:showOpenDialog` IPC + `window.electronDialog` bridge)
|
||||
- File extension filter (preset groups + custom)
|
||||
- Data type selector (checkboxes: tasks, notes, checkpoints, projects)
|
||||
- Schedule picker (preset: every 15min, hourly, 6h, daily, manual)
|
||||
- "Customize AI Prompt" button → opens Chatbot Journey dialog
|
||||
- [x] `CloudAgentConfigPanel` component (inline, inside expanded agent row):
|
||||
- Provider badge + OAuth placeholder note
|
||||
- Data type selector + schedule picker
|
||||
- "Customize AI Prompt" button
|
||||
- [x] `AddAgentDialog` for creating new agents from the catalog
|
||||
- [x] Added `dialog:showOpenDialog` IPC handler in `src/main/index.ts` + `window.electronDialog` exposed in `src/preload/trpc.ts` + type declared in `src/renderer/lib/ipcLink.ts`
|
||||
- **Files:** `src/renderer/routes/settings.tsx`, `src/main/index.ts`, `src/preload/trpc.ts`, `src/renderer/lib/ipcLink.ts`
|
||||
- **Outcome:** Users can fully configure local and cloud agents from the Settings → Agents section.
|
||||
|
||||
### Step 3.8 — Chatbot Journey page ✅
|
||||
- [x] `JourneyDialog` component in `src/renderer/routes/settings.tsx`:
|
||||
- Dialog with spring-animated chat interface (message list, input, send button)
|
||||
- Starts via `agent.journey.start` (passes `agentType` + optional `agentId`) on mount
|
||||
- Multi-turn via `agent.journey.message` tRPC calls
|
||||
- Shows generated prompt preview when `done === true` / `promptTemplate` present
|
||||
- "Save & apply" button: saves promptTemplate to agent via `agent.local.update` / `agent.cloud.update`
|
||||
- Works in both Create flow (from `AddAgentDialog`) and Edit flow (from expanded agent row)
|
||||
- **Files:** `src/renderer/routes/settings.tsx`
|
||||
- **Outcome:** Users configure AI prompts through a guided conversation, directly inside agent config.
|
||||
|
||||
### Step 3.9 — Agent run logs UI ✅
|
||||
- [x] Create `src/renderer/components/agents/AgentRunLog.tsx`:
|
||||
- Per-agent run history: timestamp, status badge, items processed/created, errors
|
||||
- Lazy-loaded (only fetches when agent row is expanded), limit 10 runs
|
||||
- Skeleton loading state + "No runs yet" empty state
|
||||
- Per-run expandable error list (click to reveal all error strings)
|
||||
- Duration display (completedAt - startedAt formatted as Xs / Xm Ys)
|
||||
- Data via `agent.runs` tRPC query
|
||||
- [x] Integrated into `AgentRow` in `src/renderer/routes/settings.tsx` — replaced inline block
|
||||
- **Files:** `src/renderer/components/agents/AgentRunLog.tsx`, `src/renderer/routes/settings.tsx`
|
||||
- **Outcome:** Users see full history and status of each agent's runs with expandable error details.
|
||||
|
||||
---
|
||||
|
||||
## Phase 4 — Security: E2E Backup & Offline
|
||||
|
||||
### Step 4.1 — E2E encrypted backup
|
||||
- [x] `src/main/backup/e2e-crypto.ts` + `backup-manager.ts`
|
||||
- **Outcome:** User data never leaves the device unencrypted.
|
||||
|
||||
### Step 4.2 — Offline sync queue
|
||||
- [x] `src/main/backup/sync-queue.ts` + `sync_queue` table
|
||||
- **Outcome:** Queued actions auto-sync when online.
|
||||
|
||||
> **Step 4.3 (SQLCipher) — Dropped.** OS-level FDE covers at-rest encryption for a local-first desktop app. Backups already E2E encrypted via Argon2id + AES-256-GCM. Native module build complexity, ~10% perf overhead, and key management UX friction not justified by the threat model.
|
||||
|
||||
---
|
||||
|
||||
## Phase 5 — Shared Memory ❌ DEPRECATED
|
||||
|
||||
> **Superseded by V3 architecture.** The backend now implements a 4-tier memory system (Core, Associative, Episodic, Proactive) with per-user Fernet encryption — see `../adiuva-api/V3_MIGRATION_PLAN.md` Steps 6–7. Memory lives server-side, not in Electron SQLite. The Electron orchestrator's `buildChatContext()` is removed in V3 (server fetches data on-demand via tool_call reverse API). Chat history is handled by `conversationHistory` passed in `home_request` frames.
|
||||
>
|
||||
> **See:** `V3_ELECTRON_MIGRATION_PLAN.md` for the replacement architecture.
|
||||
|
||||
---
|
||||
|
||||
## Phase 6 — Renderer UI Updates
|
||||
|
||||
### Step 6.1 — Auth UI + settings restructure ✅
|
||||
- [x] `LoginForm.tsx` — centered login/register screen (`src/renderer/components/auth/LoginForm.tsx`)
|
||||
- [x] Auth gate in `AppShell` — shows `LoginForm` when `auth.status` returns `authenticated: false`; passes through while loading to avoid flicker; `staleTime: 5min` to avoid hammering backend
|
||||
- [x] `SettingsPage.tsx` Account section simplified — login form removed (AppShell handles it), always shows profile + sign out
|
||||
|
||||
### Step 6.2 — ChatPage with context panel ❌ DEPRECATED
|
||||
> **Superseded by V3.** Home chat with block rendering (charts, entities, tables, timelines) and FloatingChat with domain navigation replace this. See `V3_ELECTRON_MIGRATION_PLAN.md` Steps 4–7.
|
||||
|
||||
### Step 6.3 — BatchBuilderPage
|
||||
- [ ] Natural language input, config preview, connector/storage/schedule pickers, batch cards, test runner
|
||||
|
||||
### Step 6.4 — PluginStorePage
|
||||
- [ ] Marketplace + installed tabs, permission dialog on install
|
||||
|
||||
### Step 6.5 — DataManagerPage
|
||||
- [ ] Storage overview, per-source cards, migration wizard
|
||||
|
||||
### Step 6.6 — ActivityLogPage
|
||||
- [ ] Filterable activity table with CSV export
|
||||
|
||||
---
|
||||
|
||||
## Phase 7 — Cleanup & Hardening
|
||||
|
||||
### Step 7.1 — Error handling and logging
|
||||
### Step 7.2 — Integration tests
|
||||
|
||||
---
|
||||
|
||||
## Dependencies to Add
|
||||
|
||||
| Package | Purpose |
|
||||
|---|---|
|
||||
| `ws` | WebSocket client for backend streaming |
|
||||
| `argon2` | Key derivation for E2E backup |
|
||||
| `node-cron` | Batch agent scheduling |
|
||||
| `chokidar` | File watching (plugin) |
|
||||
| `imapflow` | IMAP client (plugin) |
|
||||
|
||||
## Dependencies to Remove
|
||||
|
||||
| Package | Reason |
|
||||
|---|---|
|
||||
| `@langchain/core` | No local LLM |
|
||||
| `@langchain/openai` | No local LLM |
|
||||
| `@langchain/anthropic` | No local LLM |
|
||||
| `@langchain/langgraph` | No local orchestrator |
|
||||
| `@github/copilot-sdk` | No local Copilot |
|
||||
|
||||
---
|
||||
|
||||
## Execution Notes
|
||||
|
||||
- **Phase 1 is the critical path.** Auth + backend client + drizzle executor + orchestrator refactor must land first.
|
||||
- **Steps 1.1–1.4 are additive** — existing app keeps working until Step 1.5 swaps the orchestrator.
|
||||
- **Step 2.1 is the point of no return** — after removing LangChain, there's no local AI fallback.
|
||||
- **Phase B (backend changes) must land before Phase 1.3–1.5** — Electron needs the bidirectional WS to talk to.
|
||||
- **Phase 3 and Phase 4 are independent** — can be parallelized after Phase 2.
|
||||
- **One step at a time.** Mark `[x]` and commit with `step N.N complete: <outcome>`.
|
||||
9
CLAUDE.md
Normal file
9
CLAUDE.md
Normal file
@@ -0,0 +1,9 @@
|
||||
## graphify
|
||||
|
||||
This project has a graphify knowledge graph at graphify-out/.
|
||||
|
||||
Rules:
|
||||
- Before answering architecture or codebase questions, read graphify-out/GRAPH_REPORT.md for god nodes and community structure
|
||||
- If graphify-out/wiki/index.md exists, navigate it instead of reading raw files
|
||||
- For cross-module "how does X relate to Y" questions, prefer `graphify query "<question>"`, `graphify path "<A>" "<B>"`, or `graphify explain "<concept>"` over grep — these traverse the graph's EXTRACTED + INFERRED edges instead of scanning files
|
||||
- After modifying code files in this session, run `graphify update .` to keep the graph current (AST-only, no API cost)
|
||||
@@ -1,400 +0,0 @@
|
||||
# V3 Electron Migration Plan — Multi-Agent AI Productivity App
|
||||
|
||||
> Incremental migration of the Electron app to v3 streaming architecture.
|
||||
> Each step is self-contained, testable, and backwards-compatible until the final cutover.
|
||||
> The backend (`../adiuva-api`) v3 migration is already complete (Steps 1–7).
|
||||
> No test suite — each step is verified manually via the running app.
|
||||
|
||||
---
|
||||
|
||||
## General Rules
|
||||
|
||||
**Code Cleanup**: As you implement each step, remove any code that becomes unused or obsolete. This includes:
|
||||
- Old functions/methods that are superseded by new ones
|
||||
- Deprecated imports or modules
|
||||
- Dead code paths
|
||||
|
||||
This keeps the codebase clean and prevents confusion. When removing code, note it in the commit message if significant.
|
||||
|
||||
---
|
||||
|
||||
## Decisions Log
|
||||
|
||||
| Topic | Decision |
|
||||
|---|---|
|
||||
| WS topology | Merge chat into persistent device WS (no more separate `/api/v1/chat/stream` WS) |
|
||||
| Context building | Remove `buildChatContext()` — server fetches data via reverse API `tool_call` round-trips |
|
||||
| IPC channel | Keep single `ai:stream` channel, change payload shape to discriminated union by `type` field |
|
||||
| Floating surface | Reuse existing `FloatingChat.tsx` — adapt to v3 floating pipeline |
|
||||
| Block rendering | New `blocks/` directory under `components/ai/` for chart, entity, table, timeline components |
|
||||
| useAIChat | Shared hook handles v3 frames for both Home and Floating (no mode split — divergence is in the rendering components) |
|
||||
|
||||
---
|
||||
|
||||
## Step 1 — V3 Frame Types (`api-types.ts`)
|
||||
|
||||
**Goal**: Define the v3 frame vocabulary so all subsequent steps can import typed frames.
|
||||
|
||||
**Changes**:
|
||||
- `src/shared/api-types.ts`:
|
||||
- Add Client → Server frame schemas:
|
||||
- `WsHomeRequest(type: 'home_request', message, conversationHistory?)`
|
||||
- `WsFloatingRequest(type: 'floating_request', message, scope: { type: 'task'|'project'|'note'|'checkpoint', id? })`
|
||||
- Add Server → Client frame schemas:
|
||||
- `WsStreamStart(type: 'stream_start', requestId)`
|
||||
- `WsStreamText(type: 'stream_text', requestId, chunk)`
|
||||
- `WsStreamBlock(type: 'stream_block', requestId, blockType: 'chart'|'entity_ref'|'table'|'timeline', data: Record<string, unknown>)`
|
||||
- `WsStreamEnd(type: 'stream_end', requestId, mutations?)`
|
||||
- `WsFloatingDomain(type: 'floating_domain', requestId, domain: 'tasks'|'notes'|'checkpoints'|'projects')`
|
||||
- Add block data interfaces:
|
||||
- `ChartBlockData { chartType: 'area'|'bar'|'line'|'pie'|'radar'|'radial', title, data: Record<string, unknown>[], config: Record<string, { label: string, color: string }> }`
|
||||
- `EntityRefBlockData { entity: 'task'|'project'|'note'|'checkpoint', items: Record<string, unknown>[] }`
|
||||
- `TableBlockData { headers: string[], rows: string[][] }`
|
||||
- `TimelineBlockData { checkpoints: { id: string, title: string, date: number }[] }`
|
||||
- Add new frames to `WsClientFrameSchema` and `WsServerFrameSchema` discriminated unions
|
||||
- Keep all existing v2 frame types (backward compat until Step 3 removes them)
|
||||
|
||||
**Files touched**: `src/shared/api-types.ts`
|
||||
|
||||
**Test**: App compiles with no type errors. Existing chat still works (v2 frames untouched).
|
||||
```bash
|
||||
source ~/.nvm/nvm.sh && npm run lint
|
||||
```
|
||||
|
||||
**Status**:
|
||||
- [x] Step 1 complete
|
||||
|
||||
**Commit**:
|
||||
```
|
||||
git commit -m "step-1: add v3 ws frame types (api-types.ts)"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 2 — Unified WS Chat Transport (`backend-client.ts`)
|
||||
|
||||
**Goal**: Route Home and Floating chat through the persistent device WS instead of opening a separate per-chat WebSocket.
|
||||
|
||||
**Changes**:
|
||||
- `src/main/api/backend-client.ts`:
|
||||
- Add `private streamListeners: Map<string, StreamListener>` — keyed by `requestId`, each holding callbacks: `{ onStart, onText, onBlock, onEnd, onDomain, onError }`
|
||||
- Add `sendHomeRequest(message, conversationHistory?) -> { requestId, promise }`:
|
||||
1. Generates `requestId` (UUID)
|
||||
2. Registers a `StreamListener` in the map
|
||||
3. Sends `{ type: 'home_request', message, conversation_history }` on persistent WS
|
||||
4. Returns a promise that resolves when `stream_end` arrives (or rejects on error/timeout)
|
||||
- Add `sendFloatingRequest(message, scope) -> { requestId, promise }`:
|
||||
1. Same pattern but sends `{ type: 'floating_request', message, scope }`
|
||||
- Extend the persistent WS `on('message')` handler to dispatch v3 frames:
|
||||
- `stream_start` → call `listener.onStart()`
|
||||
- `stream_text` → call `listener.onText(chunk)`
|
||||
- `stream_block` → call `listener.onBlock(blockType, data)`
|
||||
- `stream_end` → call `listener.onEnd(mutations)`, remove listener, resolve promise
|
||||
- `floating_domain` → call `listener.onDomain(domain)`
|
||||
- `tool_call` frames already handled — no change needed (same persistent WS)
|
||||
- **Remove** `chatStream()` method and `openChatWebSocket()` private method (v2 per-chat WS)
|
||||
- **Remove** related imports: `ChatRequest`, `ChatResponse`, `ChatResponseSchema`
|
||||
|
||||
**Files touched**: `src/main/api/backend-client.ts`
|
||||
|
||||
**Test**: App starts, persistent WS connects, existing agent runs + tool calls still work. Chat is broken at this point (orchestrator still calls removed `chatStream()`) — that's expected, fixed in Step 3.
|
||||
```bash
|
||||
source ~/.nvm/nvm.sh && npm start
|
||||
# Verify: [DeviceWS] Connected. in console
|
||||
# Verify: agent_run still works if you have agents configured
|
||||
```
|
||||
|
||||
**Status**:
|
||||
- [x] Step 2 complete
|
||||
|
||||
**Commit**:
|
||||
```
|
||||
git commit -m "step-2: unify chat onto persistent device ws (backend-client.ts)"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 3 — Orchestrator + IPC Bridge Refactor (`orchestrator.ts`, `preload/trpc.ts`, `router/index.ts`)
|
||||
|
||||
**Goal**: Orchestrator sends v3 frames via BackendClient and forwards typed stream events to the renderer via IPC.
|
||||
|
||||
**Changes**:
|
||||
- `src/main/ai/orchestrator.ts`:
|
||||
- **Remove** `buildChatContext()` entirely (server fetches data via tool_call reverse API)
|
||||
- **Remove** `sendStreamChunk()` helper
|
||||
- Replace `orchestrate()` with v3 version:
|
||||
1. Check connectivity + auth (unchanged)
|
||||
2. Call `client.sendHomeRequest(message, conversationHistory)` with stream callbacks:
|
||||
- `onStart(requestId)` → `sender.send('ai:stream', { type: 'stream_start', requestId })`
|
||||
- `onText(chunk)` → `sender.send('ai:stream', { type: 'stream_text', requestId, chunk })`
|
||||
- `onBlock(blockType, data)` → `sender.send('ai:stream', { type: 'stream_block', requestId, blockType, data })`
|
||||
- `onEnd()` → `sender.send('ai:stream', { type: 'stream_end', requestId })`
|
||||
3. Return `{ response: 'ok' }` (actual content streamed via IPC)
|
||||
- Add `orchestrateFloating()`:
|
||||
1. Same connectivity + auth checks
|
||||
2. Call `client.sendFloatingRequest(message, scope)` with stream callbacks:
|
||||
- Same as above, plus `onDomain(domain)` → `sender.send('ai:stream', { type: 'floating_domain', requestId, domain })`
|
||||
3. Return `{ response: 'ok' }`
|
||||
- Update `dailyBrief()` to use the v3 `orchestrate()` path
|
||||
- `src/preload/trpc.ts`:
|
||||
- Change `onStreamChunk` payload type from `{ token: string; done: boolean }` to the v3 discriminated union: `{ type: 'stream_start'|'stream_text'|'stream_block'|'stream_end'|'floating_domain', ... }`
|
||||
- Rename export to `onStreamEvent` (breaking change for renderer — fixed in Step 4)
|
||||
- **Remove** `onAction` channel handler (superseded by `stream_block` mutation frames)
|
||||
- `src/main/router/index.ts`:
|
||||
- Update `aiRouter.chat` input to accept optional `mode: 'home' | 'floating'` and optional `scope` for floating
|
||||
- Route to `orchestrate()` or `orchestrateFloating()` based on mode
|
||||
- Keep `dailyBrief` mutation (calls updated `dailyBrief()`)
|
||||
|
||||
**Files touched**: `src/main/ai/orchestrator.ts`, `src/preload/trpc.ts`, `src/main/router/index.ts`
|
||||
|
||||
**Test**: App starts. Sending a chat message from Home triggers `home_request` on persistent WS. Backend streams `stream_start` → `stream_text`* → `stream_end`. Renderer is broken (still expects v2 payloads) — fixed in Step 4.
|
||||
```bash
|
||||
source ~/.nvm/nvm.sh && npm start
|
||||
# Open DevTools → Console: verify [DeviceWS] sends home_request frame
|
||||
# Verify stream_text frames appear in console logs
|
||||
```
|
||||
|
||||
**Status**:
|
||||
- [x] Step 3 complete
|
||||
|
||||
**Commit**:
|
||||
```
|
||||
git commit -m "step-3: refactor orchestrator + ipc bridge to v3 frames"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 4 — Renderer Streaming Hook (`useAIChat.ts`)
|
||||
|
||||
**Goal**: `useAIChat` handles v3 typed stream events and produces structured messages with interleaved text + blocks.
|
||||
|
||||
**Changes**:
|
||||
- `src/renderer/hooks/useAIChat.ts`:
|
||||
- Update `ChatMessage` type:
|
||||
```ts
|
||||
interface StreamBlock {
|
||||
id: string;
|
||||
blockType: 'chart' | 'entity_ref' | 'table' | 'timeline';
|
||||
data: Record<string, unknown>;
|
||||
}
|
||||
|
||||
interface ChatMessage {
|
||||
id: string;
|
||||
role: 'user' | 'assistant';
|
||||
content: string; // accumulated text segments
|
||||
blocks: StreamBlock[]; // interleaved blocks (ordered by arrival)
|
||||
error?: boolean;
|
||||
}
|
||||
```
|
||||
- Replace `window.electronAI.onStreamChunk()` subscription with `window.electronAI.onStreamEvent()`:
|
||||
- `stream_start` → init streaming state, store `requestId`
|
||||
- `stream_text` → append `chunk` to `streamingContentRef` (same as before)
|
||||
- `stream_block` → append `{ id, blockType, data }` to `streamingBlocksRef`
|
||||
- `stream_end` → finalize message with accumulated text + blocks, cleanup
|
||||
- `floating_domain` → call `options?.onDomainSignal?.(domain)` callback
|
||||
- Add `streamingBlocks` state (exposed in return) for live block rendering during stream
|
||||
- Keep `[SECTION:xxx]` tag parsing for backward compat (remove later when floating_domain fully replaces it)
|
||||
- Update `UseAIChatReturn` to include `streamingBlocks: StreamBlock[]`
|
||||
|
||||
**Files touched**: `src/renderer/hooks/useAIChat.ts`
|
||||
|
||||
**Test**: Home chat works end-to-end with text streaming. Text appears word-by-word as before. Blocks array is populated (but not rendered yet — Step 5). FloatingChat also works (shares hook).
|
||||
```bash
|
||||
source ~/.nvm/nvm.sh && npm start
|
||||
# Type a message in Home chat → text streams in
|
||||
# Check React DevTools: message.blocks array exists (may be empty if backend sends text-only)
|
||||
```
|
||||
|
||||
**Status**:
|
||||
- [x] Step 4 complete
|
||||
|
||||
**Commit**:
|
||||
```
|
||||
git commit -m "step-4: update useAIChat for v3 structured streaming"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 5 — Block Renderer Components (`components/ai/blocks/`)
|
||||
|
||||
**Goal**: Visual components that render `stream_block` data inline in chat messages.
|
||||
|
||||
**Changes**:
|
||||
- `src/renderer/components/ai/blocks/ChatChartBlock.tsx` (new):
|
||||
- Receives `ChartBlockData` (`chartType`, `title`, `data`, `config`)
|
||||
- Renders the appropriate shadcn/ui chart component based on `chartType`:
|
||||
- `area` → `AreaChart`, `bar` → `BarChart`, `line` → `LineChart`, `pie` → `PieChart`, `radar` → `RadarChart`, `radial` → `RadialChart`
|
||||
- Uses `ChartContainer` + `ChartTooltip` from shadcn/ui
|
||||
- Wrapped in a card with title, scale-and-fade entrance animation
|
||||
- Install any missing shadcn chart components if needed
|
||||
- `src/renderer/components/ai/blocks/ChatEntityBlock.tsx` (new):
|
||||
- Receives `EntityRefBlockData` (`entity`, `items`)
|
||||
- **Reuses existing components** — no new card renderers:
|
||||
- `task` → `TaskRow` from `components/tasks/TaskRow.tsx` (compact mode, read-only)
|
||||
- `project` → `Item` + `ItemMedia` + `ItemContent` from `components/ui/item.tsx` (same pattern as `ProjectDetail.tsx`)
|
||||
- `note` → `Item` + `ItemContent` (same pattern as note cards in `ProjectDetail.tsx`, with `FileText` icon)
|
||||
- `checkpoint` → `Item` with dashed-border variant + `Sparkles` icon for AI-suggested (same pattern as pending checkpoints in `ProjectDetail.tsx`)
|
||||
- Also reuses `PriorityBadge` from `components/tasks/PriorityBadge.tsx` for task priority display
|
||||
- Maps server data shape to each component's expected props
|
||||
- `src/renderer/components/ai/blocks/ChatTableBlock.tsx` (new):
|
||||
- Receives `TableBlockData` (`headers`, `rows`)
|
||||
- Renders a simple styled table (shadcn Table component)
|
||||
- `src/renderer/components/ai/blocks/ChatTimelineBlock.tsx` (new):
|
||||
- Receives `TimelineBlockData` (`checkpoints`)
|
||||
- **Reuses `GanttChart`** from `components/timeline/GanttChart.tsx` (compact mode, read-only, no context menu)
|
||||
- Maps `TimelineBlockData.checkpoints` to `GanttCheckpoint[]` interface
|
||||
- `src/renderer/components/ai/blocks/index.tsx` (new):
|
||||
- `BlockRenderer` component: switches on `blockType`, renders the appropriate block component
|
||||
- Wraps each block in a `motion.div` with scale-and-fade entrance (spring: stiffness 400, damping 30)
|
||||
|
||||
**Files touched**: `src/renderer/components/ai/blocks/` (5 new files)
|
||||
|
||||
**Test**: Components render correctly when given mock data. Can test by temporarily hardcoding a block in a chat message.
|
||||
```bash
|
||||
source ~/.nvm/nvm.sh && npm start
|
||||
# Temporarily add a mock block to a message in useAIChat to verify rendering
|
||||
```
|
||||
|
||||
**Status**:
|
||||
- [x] Step 5 complete
|
||||
|
||||
**Commit**:
|
||||
```
|
||||
git commit -m "step-5: add block renderer components (chart, entity, table, timeline)"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 6 — Home Chat Block Rendering (`AIChatPanel.tsx`)
|
||||
|
||||
**Goal**: Home chat renders blocks inline between text segments.
|
||||
|
||||
**Changes**:
|
||||
- `src/renderer/components/ai/AIChatPanel.tsx`:
|
||||
- Import `BlockRenderer` from `./blocks`
|
||||
- Update assistant message rendering:
|
||||
- After `ChatMarkdown` (text content), render `message.blocks.map(block => <BlockRenderer key={block.id} ... />)`
|
||||
- Blocks appear below/between text in the order they arrived
|
||||
- Update streaming state rendering:
|
||||
- Show `streamingBlocks` (from `useAIChat`) as they arrive during streaming (pop-in effect)
|
||||
- Each block gets a scale-and-fade entrance animation
|
||||
- Daily brief: if the brief response includes blocks, render them in the expandable toast
|
||||
|
||||
**Files touched**: `src/renderer/components/ai/AIChatPanel.tsx`
|
||||
|
||||
**Test**: Send a Home chat message that triggers the backend to return blocks (e.g., "Show me task status for project X" — should produce entity_ref or chart blocks). Text streams in, then blocks pop in when complete.
|
||||
```bash
|
||||
source ~/.nvm/nvm.sh && npm start
|
||||
# Ask: "Show me my task status" or "Give me a summary of project X"
|
||||
# Verify: text streams word-by-word, chart/entity blocks pop in after
|
||||
```
|
||||
|
||||
**Status**:
|
||||
- [x] Step 6 complete
|
||||
|
||||
**Commit**:
|
||||
```
|
||||
git commit -m "step-6: integrate block rendering in home chat (AIChatPanel.tsx)"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 7 — Floating Domain Navigation (`FloatingChat.tsx`)
|
||||
|
||||
**Goal**: FloatingChat sends `floating_request` and handles `floating_domain` for background page navigation. Text-only rendering (no blocks).
|
||||
|
||||
**Changes**:
|
||||
- `src/renderer/components/ai/FloatingChat.tsx`:
|
||||
- Update `useAIChat` call to pass `onDomainSignal` callback:
|
||||
- Maps domain to route: `tasks → /tasks`, `projects → /projects`, `checkpoints → /timeline`, `notes → /notes`
|
||||
- Calls `navigate()` to the target route (background navigation — panel stays open)
|
||||
- Replaces the `SECTION_ROUTES` + `[SECTION:xxx]` tag mechanism with the deterministic `floating_domain` signal
|
||||
- Update `chatContext` construction to include `scope`:
|
||||
- When opened on a specific entity (e.g., double-click task #42): `scope: { type: 'task', id: 'task_42' }`
|
||||
- When opened on an area (e.g., tasks list): `scope: { type: 'task' }` (no id)
|
||||
- Messages render text-only — explicitly do **not** render `message.blocks` (floating is text-only per v3 spec)
|
||||
- Remove `SECTION_ROUTES` constant and `handleSectionTag` callback (replaced by `floating_domain`)
|
||||
- Remove `onSectionTag` option from `useAIChat` call (cleanup — if no other consumer uses it, remove from hook too)
|
||||
|
||||
**Files touched**: `src/renderer/components/ai/FloatingChat.tsx`, possibly `src/renderer/hooks/useAIChat.ts` (remove `onSectionTag` if unused)
|
||||
|
||||
**Test**: Double-click an entity → FloatingChat opens → type a question → floating_domain signal arrives → background page navigates → text streams in the panel.
|
||||
```bash
|
||||
source ~/.nvm/nvm.sh && npm start
|
||||
# Double-click a task → FloatingChat opens
|
||||
# Ask: "What's the checkpoint status?"
|
||||
# Verify: background navigates to /timeline, text streams in floating
|
||||
```
|
||||
|
||||
**Status**:
|
||||
- [x] Step 7 complete
|
||||
|
||||
**Commit**:
|
||||
```
|
||||
git commit -m "step-7: floating domain navigation in floating chat (FloatingChat.tsx)"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 8 — Cleanup
|
||||
|
||||
**Goal**: Remove all v2 chat artifacts that are no longer used.
|
||||
|
||||
**Changes**:
|
||||
- `src/shared/api-types.ts`:
|
||||
- Remove v2 chat schemas: `WsChatRequestSchema`, `WsTextChunkSchema`, `WsFinalSchema`, `ChatContextSchema`, `ChatRequestSchema`, `ChatResponseSchema`
|
||||
- Remove from `WsClientFrameSchema` and `WsServerFrameSchema` discriminated unions
|
||||
- `src/main/api/backend-client.ts`:
|
||||
- Remove any leftover v2 imports or dead code
|
||||
- `src/main/ai/orchestrator.ts`:
|
||||
- Remove `OrchestrateResult` interface if no longer needed
|
||||
- Remove `AI_STREAM_CHANNEL` constant (now in preload only)
|
||||
- `src/preload/trpc.ts`:
|
||||
- Remove `onAction` channel if still present
|
||||
- `src/renderer/hooks/useAIChat.ts`:
|
||||
- Remove `onSectionTag` option if fully replaced by `onDomainSignal`
|
||||
- `src/main/router/index.ts`:
|
||||
- Clean up `aiRouter.chat` input schema (remove `uiContext` field — no longer sent)
|
||||
|
||||
**Files touched**: Multiple (cleanup pass)
|
||||
|
||||
**Test**: Full app smoke test — Home chat, Floating chat, daily brief, agent runs all work.
|
||||
```bash
|
||||
source ~/.nvm/nvm.sh && npm run lint && npm start
|
||||
```
|
||||
|
||||
**Status**:
|
||||
- [x] Step 8 complete
|
||||
|
||||
**Commit**:
|
||||
```
|
||||
git commit -m "step-8: remove v2 chat artifacts (cleanup)"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
| Step | Component | Effort | Depends On |
|
||||
|------|-----------|--------|------------|
|
||||
| 1 | V3 Frame Types | Low | — |
|
||||
| 2 | Unified WS Transport | High | Step 1 |
|
||||
| 3 | Orchestrator + IPC Bridge | Medium | Step 2 |
|
||||
| 4 | Renderer Streaming Hook | Medium | Step 3 |
|
||||
| 5 | Block Renderer Components | High | Step 1 (types only) |
|
||||
| 6 | Home Chat Blocks | Medium | Steps 4, 5 |
|
||||
| 7 | Floating Domain Navigation | Medium | Step 4 |
|
||||
| 8 | Cleanup | Low | Steps 6, 7 |
|
||||
|
||||
Steps 1–4 form the streaming pipeline (serial dependency chain).
|
||||
Step 5 can run in parallel with Steps 2–4 (only needs types from Step 1).
|
||||
Steps 6 and 7 can run in parallel after Step 4 + 5.
|
||||
Step 8 is the final cleanup after everything works.
|
||||
|
||||
### What stays untouched
|
||||
- `src/main/api/drizzle-executor.ts` — already v3-compatible (reverse API)
|
||||
- `src/main/ai/token.ts` — unchanged
|
||||
- `src/main/agents/file-reader.ts` — unchanged
|
||||
- `src/main/db/` — no schema changes
|
||||
- `src/renderer/routes/` — no route changes
|
||||
- All existing tRPC routers (tasks, projects, notes, checkpoints, etc.) — unchanged
|
||||
95
api/.env.example
Normal file
95
api/.env.example
Normal file
@@ -0,0 +1,95 @@
|
||||
# ── Application ──────────────────────────────────────────────────────────────
|
||||
ENV=dev
|
||||
|
||||
# ── Database ──────────────────────────────────────────────────────────────────
|
||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai
|
||||
|
||||
# ── Auth ──────────────────────────────────────────────────────────────────────
|
||||
JWT_SECRET=replace-with-a-long-random-secret
|
||||
JWT_ALGORITHM=HS256
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||
|
||||
# ── LLM ───────────────────────────────────────────────────────────────────────
|
||||
# LiteLLM model identifiers — change to swap providers without code changes.
|
||||
# Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3
|
||||
#
|
||||
# API keys — only the key(s) matching your chosen provider(s) are required.
|
||||
# The correct key is picked automatically from the model prefix (e.g.
|
||||
# "anthropic/..." → ANTHROPIC_API_KEY, "gemini/..." → GOOGLE_API_KEY).
|
||||
OPENAI_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
GOOGLE_API_KEY=
|
||||
CEREBRAS_API_KEY=
|
||||
GROQ_API_KEY=
|
||||
DEEPSEEK_API_KEY=
|
||||
|
||||
# Default model used by any agent that does not have a specific override below.
|
||||
LLM_MODEL=gpt-5-mini
|
||||
LLM_EMBED_MODEL=text-embedding-3-small
|
||||
|
||||
# GitHub Copilot — leave empty to use the LiteLLM default token directory.
|
||||
# In Docker, point this to a named-volume path so tokens survive restarts.
|
||||
# GITHUB_COPILOT_TOKEN_DIR=
|
||||
|
||||
# ── Per-agent model overrides ─────────────────────────────────────────────────
|
||||
# Leave a value empty to fall back to LLM_MODEL.
|
||||
# Each agent resolves its API key from the model prefix automatically.
|
||||
#
|
||||
# Intent classifier — routes user messages to the right domain agent.
|
||||
# A small/fast model (e.g. gpt-4o-mini) is usually sufficient here.
|
||||
LLM_MODEL_CLASSIFIER=
|
||||
|
||||
# Home-agent — handles chat from the home screen (all tools available).
|
||||
LLM_MODEL_HOME_AGENT=
|
||||
|
||||
# Floating-agent — handles contextual chat triggered from a task/project/note.
|
||||
LLM_MODEL_FLOATING_AGENT=
|
||||
|
||||
# Unified-processor — processes local directory files (local agent runner).
|
||||
LLM_MODEL_UNIFIED_PROCESSOR=
|
||||
|
||||
# Cloud-processor — fetches and processes data from cloud connectors.
|
||||
LLM_MODEL_CLOUD_PROCESSOR=
|
||||
|
||||
# Brief-agent — produces home and project text briefs.
|
||||
# A small model (e.g. gpt-4o-mini) is sufficient.
|
||||
# LLM_MODEL_BRIEF_AGENT=
|
||||
|
||||
# Task-brief-agent — per-task deep research (Stage 1 executive assistant).
|
||||
# Needs tool-use + reasoning; a capable model recommended (e.g. gpt-4o, gemini-2.5-flash).
|
||||
# LLM_MODEL_TASK_BRIEF_AGENT=
|
||||
|
||||
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
|
||||
LLM_MODEL_SETUP_AGENT=
|
||||
|
||||
# Memory-extractor — Mem0-style extract/decide pipeline (Phase 2).
|
||||
# Defaults to gpt-4o-mini when empty (fast + cheap, temperature=0).
|
||||
LLM_MODEL_MEMORY_EXTRACTOR=
|
||||
|
||||
# Memory-miner — proactive pattern mining from episodic history (Phase 5, Power+ only).
|
||||
# Defaults to gpt-4o-mini when empty.
|
||||
LLM_MODEL_MEMORY_MINER=
|
||||
|
||||
# Memory-auditor — weekly contradiction scan + relation label canonicalization (Phase 7).
|
||||
# Defaults to LLM_MODEL when empty (a reasoning-capable model is recommended).
|
||||
LLM_MODEL_MEMORY_AUDITOR=
|
||||
|
||||
# Scheduler — set to false to disable memory cron jobs (automatically false in tests).
|
||||
SCHEDULER_ENABLED=true
|
||||
|
||||
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
||||
STRIPE_SECRET_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
|
||||
|
||||
# ── Langfuse (leave empty to disable observability) ───────────────────────────
|
||||
LANGFUSE_SECRET_KEY=
|
||||
LANGFUSE_PUBLIC_KEY=
|
||||
# LANGFUSE_BASE_URL=https://cloud.langfuse.com # EU (default)
|
||||
# LANGFUSE_BASE_URL=https://us.cloud.langfuse.com # US
|
||||
# LANGFUSE_BASE_URL=http://localhost:3000 # Self-hosted
|
||||
|
||||
# ── CORS ──────────────────────────────────────────────────────────────────────
|
||||
# Comma-separated list parsed by Settings (override default if needed)
|
||||
# CORS_ORIGINS=["app://.","http://localhost:3000"]
|
||||
93
api/.gitea/workflows/deploy.yaml
Normal file
93
api/.gitea/workflows/deploy.yaml
Normal file
@@ -0,0 +1,93 @@
|
||||
name: Test & Deploy API
|
||||
run-name: ${{ gitea.ref_name }} → Docker LXC
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
|
||||
jobs:
|
||||
# ── 1. Run tests in an isolated Python container ──────────────────
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: python:3.12-slim
|
||||
|
||||
steps:
|
||||
- name: Install git
|
||||
run: apt-get update && apt-get install -y --no-install-recommends git
|
||||
|
||||
- name: Checkout Code
|
||||
run: |
|
||||
git clone --depth 1 --branch "${GITHUB_REF_NAME}" \
|
||||
"http://10.0.0.119:3000/${GITHUB_REPOSITORY}.git" . || \
|
||||
git clone --depth 1 "http://10.0.0.119:3000/${GITHUB_REPOSITORY}.git" . && \
|
||||
git checkout "${GITHUB_SHA}"
|
||||
|
||||
- name: Install Dependencies
|
||||
run: pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
- name: Run Linter
|
||||
run: ruff check app/ tests/
|
||||
|
||||
- name: Run Tests
|
||||
run: pytest tests/ -v --tb=short
|
||||
|
||||
# ── 2. Deploy to Docker LXC via SSH ─────────────────────────────────
|
||||
deploy:
|
||||
needs: test
|
||||
runs-on: ubuntu-latest
|
||||
if: gitea.event_name == 'push'
|
||||
|
||||
steps:
|
||||
- name: Deploy via SSH
|
||||
uses: appleboy/ssh-action@v1.0.0
|
||||
with:
|
||||
host: ${{ secrets.SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_KEY }}
|
||||
script: |
|
||||
set -e
|
||||
DEPLOY_DIR="/opt/adiuvai-api"
|
||||
REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git"
|
||||
TAG="${{ gitea.ref_name }}"
|
||||
|
||||
# ── Pull latest code ──
|
||||
cd /tmp && rm -rf adiuvai-api-deploy
|
||||
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuvai-api-deploy
|
||||
|
||||
# ── Sync source (preserve .env) ──
|
||||
cp -rf /tmp/adiuvai-api-deploy/app/ \
|
||||
/tmp/adiuvai-api-deploy/alembic/ \
|
||||
/tmp/adiuvai-api-deploy/alembic.ini \
|
||||
/tmp/adiuvai-api-deploy/Dockerfile \
|
||||
/tmp/adiuvai-api-deploy/docker-compose.yml \
|
||||
/tmp/adiuvai-api-deploy/requirements.txt \
|
||||
"$DEPLOY_DIR/"
|
||||
rm -rf /tmp/adiuvai-api-deploy
|
||||
|
||||
# ── Verify .env ──
|
||||
if [ ! -f "$DEPLOY_DIR/.env" ]; then
|
||||
echo "❌ $DEPLOY_DIR/.env not found. Create it before deploying."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# ── Build & restart ──
|
||||
cd "$DEPLOY_DIR"
|
||||
docker compose down --remove-orphans || true
|
||||
docker compose up -d --build
|
||||
|
||||
# ── Migrations ──
|
||||
docker compose exec -T app alembic upgrade head
|
||||
|
||||
# ── Health check ──
|
||||
echo "Waiting for app..."
|
||||
sleep 5
|
||||
HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/api/v1/health)
|
||||
if [ "$HTTP_CODE" -eq 200 ]; then
|
||||
echo "✅ API is healthy (HTTP ${HTTP_CODE})"
|
||||
else
|
||||
echo "❌ Health check failed (HTTP ${HTTP_CODE})"
|
||||
docker compose logs app --tail=50
|
||||
exit 1
|
||||
fi
|
||||
64
api/.github/workflows/ci.yml
vendored
Normal file
64
api/.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,64 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
name: Lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install ruff
|
||||
run: pip install ruff>=0.8.0
|
||||
|
||||
- name: Ruff check
|
||||
run: ruff check .
|
||||
|
||||
- name: Ruff format check
|
||||
run: ruff format --check .
|
||||
|
||||
test:
|
||||
name: Test
|
||||
runs-on: ubuntu-latest
|
||||
needs: lint
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Cache pip
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
|
||||
restore-keys: ${{ runner.os }}-pip-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pip install -r requirements.txt
|
||||
|
||||
- name: Run tests
|
||||
run: pytest -v --tb=short
|
||||
|
||||
docker:
|
||||
name: Docker Build
|
||||
runs-on: ubuntu-latest
|
||||
needs: test
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Build image
|
||||
run: docker build -t adiuvai-api:ci .
|
||||
|
||||
- name: Verify gunicorn installed
|
||||
run: docker run --rm adiuvai-api:ci gunicorn --version
|
||||
38
api/.gitignore
vendored
Normal file
38
api/.gitignore
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
|
||||
# Virtual environment
|
||||
.venv/
|
||||
venv/
|
||||
env/
|
||||
|
||||
# Environment variables
|
||||
.env
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# Testing / coverage
|
||||
.pytest_cache/
|
||||
htmlcov/
|
||||
.coverage
|
||||
tests/fixtures/private*/
|
||||
|
||||
# Docker
|
||||
*.log
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
|
||||
# Smoke scripts (dev-only, not for CI)
|
||||
scripts/smoke_*.py
|
||||
Thumbs.db
|
||||
|
||||
# Claude Code
|
||||
.claude/
|
||||
logs/
|
||||
39
api/Dockerfile
Normal file
39
api/Dockerfile
Normal file
@@ -0,0 +1,39 @@
|
||||
# ── builder ──────────────────────────────────────────────────────────────────
|
||||
FROM python:3.12-slim AS builder
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --upgrade pip && \
|
||||
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||
|
||||
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||
FROM python:3.12-slim AS runtime
|
||||
|
||||
# Non-root user
|
||||
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy installed packages from builder
|
||||
COPY --from=builder /install /usr/local
|
||||
|
||||
# Copy application source
|
||||
COPY app/ app/
|
||||
|
||||
# Copy Alembic migration files
|
||||
COPY alembic/ alembic/
|
||||
COPY alembic.ini .
|
||||
|
||||
# Ensure appuser owns the working directory
|
||||
RUN chown -R appuser:appgroup /app
|
||||
|
||||
USER appuser
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["gunicorn", "app.main:app", \
|
||||
"-k", "uvicorn.workers.UvicornWorker", \
|
||||
"--bind", "0.0.0.0:8000", \
|
||||
"--workers", "4", \
|
||||
"--timeout", "120"]
|
||||
5
api/README.md
Normal file
5
api/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
## DEV
|
||||
Run in DEV with command:
|
||||
```
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload --log-config logging.conf
|
||||
```
|
||||
47
api/alembic.ini
Normal file
47
api/alembic.ini
Normal file
@@ -0,0 +1,47 @@
|
||||
# Alembic configuration file.
|
||||
# The async app uses postgresql+asyncpg:// at runtime.
|
||||
# Alembic CLI uses the sync psycopg2 URL set in env.py (reads from DATABASE_URL env var).
|
||||
|
||||
[alembic]
|
||||
script_location = alembic
|
||||
prepend_sys_path = .
|
||||
version_path_separator = os
|
||||
|
||||
# sqlalchemy.url is overridden in alembic/env.py — leave as placeholder.
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
93
api/alembic/env.py
Normal file
93
api/alembic/env.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Alembic migration environment — async-compatible.
|
||||
|
||||
At runtime the app uses ``postgresql+asyncpg://``. Alembic's CLI is
|
||||
synchronous, so we derive a *sync* psycopg2 URL from the same DATABASE_URL
|
||||
env var by replacing the driver prefix.
|
||||
|
||||
Run migrations with:
|
||||
alembic upgrade head
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
# Alembic Config object (gives access to alembic.ini values).
|
||||
config = context.config
|
||||
|
||||
# Set up Python logging from alembic.ini.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# Import the Base so that Alembic can detect model changes for --autogenerate.
|
||||
from app.models import Base # noqa: E402
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def _sync_url(async_url: str) -> str:
|
||||
"""Convert an asyncpg URL to a psycopg2 URL for Alembic CLI."""
|
||||
return re.sub(r"postgresql\+asyncpg", "postgresql+psycopg2", async_url)
|
||||
|
||||
|
||||
def _get_url() -> str:
|
||||
db_url = os.environ.get("DATABASE_URL", "")
|
||||
if not db_url:
|
||||
# Fall back to settings if env var not set directly.
|
||||
from app.config.settings import settings # noqa: PLC0415
|
||||
db_url = settings.DATABASE_URL
|
||||
return _sync_url(db_url)
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Emit SQL without a live DB connection."""
|
||||
url = _get_url()
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
compare_type=True,
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection): # type: ignore[no-untyped-def]
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
compare_type=True,
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_migrations_online_async() -> None:
|
||||
"""Run migrations against a live DB using the async engine."""
|
||||
async_url = os.environ.get("DATABASE_URL", "")
|
||||
if not async_url:
|
||||
from app.config.settings import settings # noqa: PLC0415
|
||||
async_url = settings.DATABASE_URL
|
||||
|
||||
connectable = create_async_engine(async_url, poolclass=pool.NullPool)
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
asyncio.run(run_migrations_online_async())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
28
api/alembic/script.py.mako
Normal file
28
api/alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
84
api/alembic/versions/001_initial_schema.py
Normal file
84
api/alembic/versions/001_initial_schema.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Initial schema: users, refresh_tokens, subscriptions.
|
||||
|
||||
Revision ID: 001
|
||||
Revises:
|
||||
Create Date: 2026-03-02
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "001"
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ── Enum types — idempotent creation via exception handling ───────────
|
||||
op.execute("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE billing_tier AS ENUM ('free', 'pro', 'power', 'team');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
""")
|
||||
|
||||
# ── users ─────────────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"users",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("email", sa.String(255), nullable=False),
|
||||
sa.Column("password_hash", sa.String(255), nullable=False),
|
||||
sa.Column("tier", postgresql.ENUM("free", "pro", "power", "team", name="billing_tier", create_type=False), nullable=False, server_default="free"),
|
||||
sa.Column("stripe_customer_id", sa.String(255), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("email"),
|
||||
)
|
||||
op.create_index("ix_users_email", "users", ["email"])
|
||||
|
||||
# ── refresh_tokens ────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"refresh_tokens",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("token_hash", sa.String(64), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.UniqueConstraint("token_hash"),
|
||||
)
|
||||
op.create_index("ix_refresh_tokens_user_id", "refresh_tokens", ["user_id"])
|
||||
op.create_index("ix_refresh_tokens_token_hash", "refresh_tokens", ["token_hash"])
|
||||
|
||||
# ── subscriptions ─────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"subscriptions",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("stripe_subscription_id", sa.String(255), nullable=True),
|
||||
sa.Column("tier", postgresql.ENUM("free", "pro", "power", "team", name="billing_tier", create_type=False), nullable=False, server_default="free"),
|
||||
sa.Column("status", sa.String(50), nullable=False, server_default="free"),
|
||||
sa.Column("current_period_end", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.UniqueConstraint("user_id"),
|
||||
)
|
||||
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
|
||||
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("subscriptions")
|
||||
op.drop_table("refresh_tokens")
|
||||
op.drop_table("users")
|
||||
|
||||
op.execute("DROP TYPE IF EXISTS billing_tier")
|
||||
127
api/alembic/versions/003_agent_tables.py
Normal file
127
api/alembic/versions/003_agent_tables.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Add agent config and run log tables: local_agent_configs, cloud_agent_configs, agent_run_logs.
|
||||
|
||||
Revision ID: 003
|
||||
Revises: 002
|
||||
Create Date: 2026-03-05
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "003"
|
||||
down_revision: Union[str, None] = "001"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ── Enum types — idempotent creation ──────────────────────────────────
|
||||
op.execute("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE agent_type AS ENUM ('local', 'cloud');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
""")
|
||||
op.execute("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE agent_run_status AS ENUM ('running', 'success', 'error', 'partial');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
""")
|
||||
op.execute("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
""")
|
||||
|
||||
# ── local_agent_configs ───────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"local_agent_configs",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("device_id", sa.String(255), nullable=False),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
||||
|
||||
# ── cloud_agent_configs ───────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"cloud_agent_configs",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column(
|
||||
"provider",
|
||||
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
||||
sa.Column("filter_config", sa.JSON, nullable=True),
|
||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
||||
|
||||
# ── agent_run_logs ─────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"agent_run_logs",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
# Plain string — not a FK because it references either local_agent_configs or
|
||||
# cloud_agent_configs depending on agent_type.
|
||||
sa.Column("agent_id", sa.String(255), nullable=False),
|
||||
sa.Column(
|
||||
"agent_type",
|
||||
postgresql.ENUM("local", "cloud", name="agent_type", create_type=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column(
|
||||
"status",
|
||||
postgresql.ENUM("running", "success", "error", "partial", name="agent_run_status", create_type=False),
|
||||
nullable=False,
|
||||
server_default="running",
|
||||
),
|
||||
sa.Column("items_processed", sa.Integer, nullable=False, server_default="0"),
|
||||
sa.Column("items_created", sa.Integer, nullable=False, server_default="0"),
|
||||
sa.Column("errors", sa.JSON, nullable=True),
|
||||
sa.Column("started_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_agent_run_logs_user_id", "agent_run_logs", ["user_id"])
|
||||
op.create_index("ix_agent_run_logs_agent_id", "agent_run_logs", ["agent_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("agent_run_logs")
|
||||
op.drop_table("cloud_agent_configs")
|
||||
op.drop_table("local_agent_configs")
|
||||
|
||||
op.execute("DROP TYPE IF EXISTS cloud_provider;")
|
||||
op.execute("DROP TYPE IF EXISTS agent_run_status;")
|
||||
op.execute("DROP TYPE IF EXISTS agent_type;")
|
||||
144
api/alembic/versions/004_add_memory_tables.py
Normal file
144
api/alembic/versions/004_add_memory_tables.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Add memory tables and user encryption_key column.
|
||||
|
||||
Memory tables:
|
||||
memory_core — per-user key/value preferences (encrypted)
|
||||
memory_associative — semantic memory with pgvector embedding (encrypted)
|
||||
memory_episodic — session summaries (encrypted)
|
||||
memory_proactive — behavioral patterns (encrypted)
|
||||
|
||||
Also adds encryption_key column to users table.
|
||||
|
||||
Revision ID: 004
|
||||
Revises: 003
|
||||
Create Date: 2026-03-08
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "004"
|
||||
down_revision: Union[str, None] = "003"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ── Enable pgvector extension (idempotent) ────────────────────────────────
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
|
||||
|
||||
# ── Add encryption_key to users ───────────────────────────────────────────
|
||||
op.add_column(
|
||||
"users",
|
||||
sa.Column("encryption_key", sa.String(64), nullable=True),
|
||||
)
|
||||
|
||||
# ── memory_core ───────────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"memory_core",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
postgresql.UUID(as_uuid=False),
|
||||
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("key", sa.String(255), nullable=False),
|
||||
sa.Column("value_encrypted", sa.Text, nullable=False),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
)
|
||||
op.create_index("ix_memory_core_user_id", "memory_core", ["user_id"])
|
||||
|
||||
# ── memory_associative ────────────────────────────────────────────────────
|
||||
# The embedding column uses pgvector's vector(1536) type.
|
||||
op.create_table(
|
||||
"memory_associative",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
postgresql.UUID(as_uuid=False),
|
||||
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("content_encrypted", sa.Text, nullable=False),
|
||||
sa.Column("entity_type", sa.String(100), nullable=True),
|
||||
sa.Column("entity_id", sa.String(255), nullable=True),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
)
|
||||
# Add the pgvector column separately (not supported by generic sa types)
|
||||
op.execute(
|
||||
"ALTER TABLE memory_associative ADD COLUMN embedding vector(1536);"
|
||||
)
|
||||
op.create_index("ix_memory_associative_user_id", "memory_associative", ["user_id"])
|
||||
# IVFFlat index for approximate nearest-neighbour search
|
||||
op.execute(
|
||||
"CREATE INDEX ix_memory_associative_embedding "
|
||||
"ON memory_associative USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);"
|
||||
)
|
||||
|
||||
# ── memory_episodic ───────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"memory_episodic",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
postgresql.UUID(as_uuid=False),
|
||||
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("summary_encrypted", sa.Text, nullable=False),
|
||||
sa.Column("session_id", sa.String(255), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
)
|
||||
op.create_index("ix_memory_episodic_user_id", "memory_episodic", ["user_id"])
|
||||
op.create_index("ix_memory_episodic_session_id", "memory_episodic", ["session_id"])
|
||||
|
||||
# ── memory_proactive ──────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"memory_proactive",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
postgresql.UUID(as_uuid=False),
|
||||
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("pattern_encrypted", sa.Text, nullable=False),
|
||||
sa.Column("confidence", sa.Float, nullable=False, server_default="0.5"),
|
||||
sa.Column("source", sa.String(50), nullable=False, server_default="inferred"),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
)
|
||||
op.create_index("ix_memory_proactive_user_id", "memory_proactive", ["user_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("memory_proactive")
|
||||
op.drop_table("memory_episodic")
|
||||
op.drop_index("ix_memory_associative_embedding", "memory_associative")
|
||||
op.drop_table("memory_associative")
|
||||
op.drop_table("memory_core")
|
||||
op.drop_column("users", "encryption_key")
|
||||
54
api/alembic/versions/005_associative_pgvector.py
Normal file
54
api/alembic/versions/005_associative_pgvector.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Phase 1 — confirm pgvector activation on memory_associative.
|
||||
|
||||
Migration 004 created the embedding column as vector(1536) and added the
|
||||
IVFFlat index. This migration is the Phase-1 checkpoint:
|
||||
1. Ensures the pgvector extension is enabled (idempotent).
|
||||
2. Ensures the canonical Phase-1 IVFFlat index exists under the name
|
||||
memory_associative_embedding_idx (creates it only if absent).
|
||||
|
||||
Revision ID: 005
|
||||
Revises: 9a1f2d0b6c7e
|
||||
Create Date: 2026-04-15
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "005"
|
||||
down_revision: Union[str, None] = "e04100e88ace"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Ensure pgvector extension is enabled (also done in 004, idempotent).
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
|
||||
|
||||
# Ensure the canonical Phase-1 IVFFlat index exists.
|
||||
# 004 may have created ix_memory_associative_embedding; this adds the
|
||||
# Phase-1 name memory_associative_embedding_idx if it is missing.
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_indexes
|
||||
WHERE tablename = 'memory_associative'
|
||||
AND indexname = 'memory_associative_embedding_idx'
|
||||
) THEN
|
||||
CREATE INDEX memory_associative_embedding_idx
|
||||
ON memory_associative
|
||||
USING ivfflat (embedding vector_cosine_ops)
|
||||
WITH (lists = 100);
|
||||
END IF;
|
||||
END $$;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS memory_associative_embedding_idx;")
|
||||
74
api/alembic/versions/006_memory_relations.py
Normal file
74
api/alembic/versions/006_memory_relations.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Add memory_relations table (Phase 3 — relational tier).
|
||||
|
||||
Revision ID: 006
|
||||
Revises: 1f5975a4f3f4
|
||||
Create Date: 2026-04-16
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "006"
|
||||
down_revision: Union[str, None] = "1f5975a4f3f4"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"memory_relations",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
postgresql.UUID(as_uuid=False),
|
||||
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("subject_label", sa.String(128), nullable=False),
|
||||
sa.Column("subject_type", sa.String(32), nullable=False),
|
||||
sa.Column("predicate", sa.String(64), nullable=False),
|
||||
sa.Column("object_label", sa.String(128), nullable=False),
|
||||
sa.Column("object_type", sa.String(32), nullable=False),
|
||||
sa.Column("confidence", sa.Float, nullable=False, server_default="0.7"),
|
||||
sa.Column(
|
||||
"source_episode_id",
|
||||
postgresql.UUID(as_uuid=False),
|
||||
sa.ForeignKey("memory_episodic.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("notes_encrypted", sa.LargeBinary, nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
sa.Column("last_confirmed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
op.create_index(
|
||||
"memory_relations_user_subject_idx",
|
||||
"memory_relations",
|
||||
["user_id", "subject_label"],
|
||||
)
|
||||
op.create_index(
|
||||
"memory_relations_user_predicate_idx",
|
||||
"memory_relations",
|
||||
["user_id", "predicate"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("memory_relations_user_predicate_idx", "memory_relations")
|
||||
op.drop_index("memory_relations_user_subject_idx", "memory_relations")
|
||||
op.drop_table("memory_relations")
|
||||
41
api/alembic/versions/007_rename_agents_to_scouts.py
Normal file
41
api/alembic/versions/007_rename_agents_to_scouts.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Rename agents to scouts.
|
||||
|
||||
Revision ID: 007
|
||||
Revises: d6e3f4a5b6c7
|
||||
Create Date: 2026-05-15
|
||||
|
||||
Renames the entire agents subsystem identifiers to scouts.
|
||||
Pre-1.0 — no data preservation concerns beyond ALTER TABLE rename.
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "007"
|
||||
down_revision: Union[str, None] = "d6e3f4a5b6c7"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Tables
|
||||
op.rename_table("local_agent_configs", "local_scout_configs")
|
||||
op.rename_table("cloud_agent_configs", "cloud_scout_configs")
|
||||
op.rename_table("agent_run_logs", "scout_run_logs")
|
||||
|
||||
# Columns
|
||||
op.alter_column("local_scout_configs", "agent_config", new_column_name="scout_config")
|
||||
op.alter_column("scout_run_logs", "agent_id", new_column_name="scout_id")
|
||||
op.alter_column("scout_run_logs", "agent_type", new_column_name="scout_type")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column("scout_run_logs", "scout_type", new_column_name="agent_type")
|
||||
op.alter_column("scout_run_logs", "scout_id", new_column_name="agent_id")
|
||||
op.alter_column("local_scout_configs", "scout_config", new_column_name="agent_config")
|
||||
|
||||
op.rename_table("scout_run_logs", "agent_run_logs")
|
||||
op.rename_table("cloud_scout_configs", "cloud_agent_configs")
|
||||
op.rename_table("local_scout_configs", "local_agent_configs")
|
||||
59
api/alembic/versions/008_scout_triage_queue.py
Normal file
59
api/alembic/versions/008_scout_triage_queue.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Scout triage queue + cloud_scout_configs alterations.
|
||||
|
||||
Revision ID: 008
|
||||
Revises: 007
|
||||
Create Date: 2026-05-16
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "008"
|
||||
down_revision: Union[str, None] = "007"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"scout_triage_queue",
|
||||
sa.Column("id", sa.Uuid(as_uuid=False), primary_key=True),
|
||||
sa.Column("user_id", sa.Uuid(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True),
|
||||
sa.Column("scout_id", sa.Uuid(as_uuid=False), sa.ForeignKey("cloud_scout_configs.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("source_type", sa.String(50), nullable=False),
|
||||
sa.Column("source_msg_ref", sa.String(255), nullable=False),
|
||||
sa.Column("triage_verdict", sa.String(20), nullable=False),
|
||||
sa.Column("triage_reason", sa.Text, nullable=True),
|
||||
sa.Column("status", sa.String(20), nullable=False, server_default="queued"),
|
||||
sa.Column("triaged_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
|
||||
sa.Column("delivered_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("acked_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.UniqueConstraint("scout_id", "source_msg_ref", name="uq_scout_triage_queue_scout_msg"),
|
||||
)
|
||||
op.create_index("ix_scout_triage_queue_user_status", "scout_triage_queue", ["user_id", "status"])
|
||||
op.create_index(
|
||||
"ix_scout_triage_queue_expires_active",
|
||||
"scout_triage_queue",
|
||||
["expires_at"],
|
||||
postgresql_where=sa.text("status != 'acked'"),
|
||||
)
|
||||
|
||||
op.add_column("cloud_scout_configs", sa.Column("auto_trash_spam", sa.Boolean(), nullable=False, server_default=sa.text("false")))
|
||||
op.add_column("cloud_scout_configs", sa.Column("gmail_history_id", sa.String(64), nullable=True))
|
||||
op.add_column("cloud_scout_configs", sa.Column("gmail_watch_expires_at", sa.DateTime(timezone=True), nullable=True))
|
||||
op.add_column("cloud_scout_configs", sa.Column("device_inactivity_pause_days", sa.Integer(), nullable=False, server_default="14"))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("cloud_scout_configs", "device_inactivity_pause_days")
|
||||
op.drop_column("cloud_scout_configs", "gmail_watch_expires_at")
|
||||
op.drop_column("cloud_scout_configs", "gmail_history_id")
|
||||
op.drop_column("cloud_scout_configs", "auto_trash_spam")
|
||||
|
||||
op.drop_index("ix_scout_triage_queue_expires_active", table_name="scout_triage_queue")
|
||||
op.drop_index("ix_scout_triage_queue_user_status", table_name="scout_triage_queue")
|
||||
op.drop_table("scout_triage_queue")
|
||||
25
api/alembic/versions/009_cloud_scout_gmail_address.py
Normal file
25
api/alembic/versions/009_cloud_scout_gmail_address.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Add gmail_address to cloud_scout_configs.
|
||||
|
||||
Revision ID: 009
|
||||
Revises: 008
|
||||
Create Date: 2026-05-16
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "009"
|
||||
down_revision: Union[str, None] = "008"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("cloud_scout_configs", sa.Column("gmail_address", sa.String(320), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("cloud_scout_configs", "gmail_address")
|
||||
38
api/alembic/versions/1f5975a4f3f4_add_extraction_queue.py
Normal file
38
api/alembic/versions/1f5975a4f3f4_add_extraction_queue.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""add extraction_queue
|
||||
|
||||
Revision ID: 1f5975a4f3f4
|
||||
Revises: 005
|
||||
Create Date: 2026-04-16 17:26:25.790870
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '1f5975a4f3f4'
|
||||
down_revision: Union[str, None] = '005'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'extraction_queue',
|
||||
sa.Column('id', sa.Uuid(as_uuid=False), nullable=False),
|
||||
sa.Column('user_id', sa.Uuid(as_uuid=False), nullable=False),
|
||||
sa.Column('episode_id', sa.Uuid(as_uuid=False), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
op.create_index(op.f('ix_extraction_queue_user_id'), 'extraction_queue', ['user_id'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f('ix_extraction_queue_user_id'), table_name='extraction_queue')
|
||||
op.drop_table('extraction_queue')
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add name and surname to users table
|
||||
|
||||
Revision ID: 818478c251dc
|
||||
Revises: 004
|
||||
Create Date: 2026-03-10 15:10:42.811947
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '818478c251dc'
|
||||
down_revision: Union[str, None] = '004'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('users', sa.Column('name', sa.String(length=100), nullable=True))
|
||||
op.add_column('users', sa.Column('surname', sa.String(length=100), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('users', 'surname')
|
||||
op.drop_column('users', 'name')
|
||||
@@ -0,0 +1,92 @@
|
||||
"""Deprecate backend agent config tables.
|
||||
|
||||
The Electron client is now the source of truth for agent configuration
|
||||
(directory, extract targets, batch interval, custom prompt). Backend keeps
|
||||
billing checks and trigger/run logs only.
|
||||
|
||||
Revision ID: 9a1f2d0b6c7e
|
||||
Revises: 818478c251dc
|
||||
Create Date: 2026-03-16
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "9a1f2d0b6c7e"
|
||||
down_revision: Union[str, None] = "818478c251dc"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
existing = set(inspector.get_table_names())
|
||||
|
||||
if "cloud_agent_configs" in existing:
|
||||
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
|
||||
op.drop_table("cloud_agent_configs")
|
||||
|
||||
if "local_agent_configs" in existing:
|
||||
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
|
||||
op.drop_table("local_agent_configs")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_table(
|
||||
"local_agent_configs",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("device_id", sa.String(255), nullable=False),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"cloud_agent_configs",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column(
|
||||
"provider",
|
||||
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
||||
sa.Column("filter_config", sa.JSON, nullable=True),
|
||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
||||
@@ -0,0 +1,107 @@
|
||||
"""Restore agent config tables and add agent_config column.
|
||||
|
||||
9a1f2d0b6c7e dropped local_agent_configs and cloud_agent_configs, but both
|
||||
ORM models are still active. This migration recreates them with agent_config
|
||||
added to local_agent_configs.
|
||||
|
||||
Revision ID: a3b9c0d1e2f3
|
||||
Revises: 9a1f2d0b6c7e
|
||||
Create Date: 2026-04-07 00:00:00.000000
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "a3b9c0d1e2f3"
|
||||
down_revision: Union[str, None] = "9a1f2d0b6c7e"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Recreate enum types (idempotent — they may already exist from migration 003)
|
||||
op.execute("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE agent_type AS ENUM ('local', 'cloud');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
""")
|
||||
op.execute("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE agent_run_status AS ENUM ('running', 'success', 'error', 'partial');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
""")
|
||||
op.execute("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
""")
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
existing = set(inspector.get_table_names())
|
||||
|
||||
# ── local_agent_configs (with agent_config column) ────────────────────
|
||||
if "local_agent_configs" not in existing:
|
||||
op.create_table(
|
||||
"local_agent_configs",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("device_id", sa.String(255), nullable=False),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||
sa.Column("agent_config", sa.JSON, nullable=True),
|
||||
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
||||
|
||||
# ── cloud_agent_configs ───────────────────────────────────────────────
|
||||
if "cloud_agent_configs" not in existing:
|
||||
op.create_table(
|
||||
"cloud_agent_configs",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column(
|
||||
"provider",
|
||||
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
||||
sa.Column("filter_config", sa.JSON, nullable=True),
|
||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
|
||||
op.drop_table("cloud_agent_configs")
|
||||
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
|
||||
op.drop_table("local_agent_configs")
|
||||
56
api/alembic/versions/b4c0d1e2f3a4_add_oauth_and_avatar.py
Normal file
56
api/alembic/versions/b4c0d1e2f3a4_add_oauth_and_avatar.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Add oauth_accounts table, nullable password_hash, avatar_url to users.
|
||||
|
||||
Revision ID: b4c0d1e2f3a4
|
||||
Revises: a3b9c0d1e2f3
|
||||
Create Date: 2026-04-10 00:00:00.000000
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b4c0d1e2f3a4"
|
||||
down_revision: Union[str, None] = "a3b9c0d1e2f3"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ── users: make password_hash nullable (social users have no password) ──
|
||||
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=True)
|
||||
|
||||
# ── users: add avatar_url ─────────────────────────────────────────────
|
||||
op.add_column("users", sa.Column("avatar_url", sa.String(2048), nullable=True))
|
||||
|
||||
# ── oauth_accounts ────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"oauth_accounts",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("provider", sa.String(50), nullable=False),
|
||||
sa.Column("provider_user_id", sa.String(255), nullable=False),
|
||||
sa.Column("provider_email", sa.String(255), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"),
|
||||
)
|
||||
op.create_index("ix_oauth_accounts_user_id", "oauth_accounts", ["user_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_oauth_accounts_user_id", table_name="oauth_accounts")
|
||||
op.drop_table("oauth_accounts")
|
||||
op.drop_column("users", "avatar_url")
|
||||
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=False)
|
||||
@@ -0,0 +1,31 @@
|
||||
"""Add onboarding_completed_at column to users table.
|
||||
|
||||
Revision ID: c5d1e2f3a4b5
|
||||
Revises: b4c0d1e2f3a4
|
||||
Create Date: 2026-04-11 00:00:00.000000
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c5d1e2f3a4b5"
|
||||
down_revision: Union[str, None] = "b4c0d1e2f3a4"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"users",
|
||||
sa.Column("onboarding_completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("users", "onboarding_completed_at")
|
||||
46
api/alembic/versions/d6e3f4a5b6c7_folder_index_tables.py
Normal file
46
api/alembic/versions/d6e3f4a5b6c7_folder_index_tables.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Add token tracking columns for folder integration.
|
||||
|
||||
Revision ID: d6e3f4a5b6c7
|
||||
Revises: 006
|
||||
Create Date: 2026-05-11 00:00:00.000000
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d6e3f4a5b6c7"
|
||||
down_revision: Union[str, None] = "006"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"agent_run_logs",
|
||||
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
op.create_table(
|
||||
"monthly_token_usage",
|
||||
sa.Column("user_id", UUID(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("year_month", sa.String(7), nullable=False),
|
||||
sa.Column("feature", sa.String(64), nullable=False),
|
||||
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.PrimaryKeyConstraint("user_id", "year_month", "feature"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_monthly_token_usage_user_month",
|
||||
"monthly_token_usage",
|
||||
["user_id", "year_month"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_monthly_token_usage_user_month", table_name="monthly_token_usage")
|
||||
op.drop_table("monthly_token_usage")
|
||||
op.drop_column("agent_run_logs", "tokens_used")
|
||||
@@ -0,0 +1,34 @@
|
||||
"""avatar_url_varchar_to_text
|
||||
|
||||
Revision ID: e04100e88ace
|
||||
Revises: c5d1e2f3a4b5
|
||||
Create Date: 2026-04-13 09:13:06.733674
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'e04100e88ace'
|
||||
down_revision: Union[str, None] = 'c5d1e2f3a4b5'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column('users', 'avatar_url',
|
||||
existing_type=sa.VARCHAR(length=2048),
|
||||
type_=sa.Text(),
|
||||
existing_nullable=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column('users', 'avatar_url',
|
||||
existing_type=sa.Text(),
|
||||
type_=sa.VARCHAR(length=2048),
|
||||
existing_nullable=True)
|
||||
0
api/app/__init__.py
Normal file
0
api/app/__init__.py
Normal file
5
api/app/agents/__init__.py
Normal file
5
api/app/agents/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Expose tool modules used by deep orchestrator-worker graphs."""
|
||||
|
||||
from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent
|
||||
|
||||
__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"]
|
||||
52
api/app/agents/client_agent.py
Normal file
52
api/app/agents/client_agent.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Client agent — read-only tools for the clients table."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
|
||||
@tool
|
||||
async def list_clients(search: str = "", limit: int = 20) -> str:
|
||||
"""List clients, optionally filtered by a name/email substring search.
|
||||
|
||||
search: optional substring to match against client name or email.
|
||||
limit: max rows to return (default 20).
|
||||
"""
|
||||
filters: dict[str, Any] = {"limit": limit}
|
||||
if search:
|
||||
filters["search"] = search
|
||||
|
||||
result = await execute_on_client(action="select", table="clients", filters=filters)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No clients found."
|
||||
lines = [
|
||||
f"- {r.get('name', '?')} (id: {r.get('id')}, email: {r.get('email', '')}, "
|
||||
f"company: {r.get('company', '')})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Found {len(rows)} client(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def get_client(id: str) -> str:
|
||||
"""Get full details for one client by UUID.
|
||||
|
||||
id: the client's UUID.
|
||||
"""
|
||||
if not id:
|
||||
return "Client id is required."
|
||||
|
||||
result = await execute_on_client(action="get", table="clients", data={"id": id})
|
||||
row = result.get("row") or result.get("rows", [None])[0] if result else None
|
||||
if not row:
|
||||
return f"Client '{id}' not found."
|
||||
return f"Client details:\n{json.dumps(row, ensure_ascii=False, indent=2)}"
|
||||
|
||||
|
||||
CLIENT_TOOLS: list[Any] = [list_clients, get_client]
|
||||
194
api/app/agents/filesystem_agent.py
Normal file
194
api/app/agents/filesystem_agent.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Filesystem agent — tools for reading local directories and files on Electron.
|
||||
|
||||
These tools delegate to the Electron client via ``execute_on_client()`` using
|
||||
the same WS tool-call round-trip pattern as CRUD tools. The Electron app
|
||||
handles actual disk I/O and responds with ``tool_result`` frames.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
# Max characters returned by read_file_content in journey (exploration) tools.
|
||||
# The journey only needs to understand file structure, not full content.
|
||||
_JOURNEY_READ_MAX_CHARS: int = 4000
|
||||
|
||||
|
||||
def _resolve_path(path: str, base: str) -> str:
|
||||
"""Resolve *path* against *base* when *path* is relative.
|
||||
|
||||
The LLM often passes ``"."`` meaning "the configured directory".
|
||||
Without this, Electron resolves ``"."`` relative to its own CWD instead
|
||||
of the user's chosen directory.
|
||||
"""
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return str(Path(base) / path)
|
||||
|
||||
|
||||
@tool
|
||||
async def list_directory(path: str) -> str:
|
||||
"""List files and folders in a local directory on the user's device.
|
||||
|
||||
Returns a formatted listing of entries with name, type (file/directory),
|
||||
and full path.
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="list_directory",
|
||||
data={"path": path},
|
||||
)
|
||||
entries: list[dict[str, Any]] = result.get("entries", [])
|
||||
if not entries:
|
||||
return f"Directory '{path}' is empty or does not exist."
|
||||
lines: list[str] = []
|
||||
for entry in entries:
|
||||
entry_type = entry.get("type", "unknown")
|
||||
entry_name = entry.get("name", "")
|
||||
entry_path = entry.get("path", "")
|
||||
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
||||
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def read_file_content(path: str) -> str:
|
||||
"""Read the text content of a local file on the user's device.
|
||||
|
||||
Returns the file content as a string. Large files may be truncated
|
||||
by the Electron client.
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="read_file_content",
|
||||
data={"path": path},
|
||||
)
|
||||
content: str = result.get("content", "")
|
||||
if not content:
|
||||
return f"File '{path}' is empty or could not be read."
|
||||
return content
|
||||
|
||||
|
||||
@tool
|
||||
async def get_file_metadata(path: str) -> str:
|
||||
"""Get metadata for a local file: size, creation date, modification date, extension.
|
||||
|
||||
Returns a formatted summary of the file's metadata.
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="get_file_metadata",
|
||||
data={"path": path},
|
||||
)
|
||||
size = result.get("size", "unknown")
|
||||
created = result.get("createdAt", "unknown")
|
||||
modified = result.get("modifiedAt", "unknown")
|
||||
extension = result.get("extension", "unknown")
|
||||
name = result.get("name", path)
|
||||
return (
|
||||
f"File: {name}\n"
|
||||
f" Extension: {extension}\n"
|
||||
f" Size: {size} bytes\n"
|
||||
f" Created: {created}\n"
|
||||
f" Modified: {modified}"
|
||||
)
|
||||
|
||||
|
||||
FILESYSTEM_TOOLS: list[Any] = [
|
||||
list_directory,
|
||||
read_file_content,
|
||||
get_file_metadata,
|
||||
]
|
||||
|
||||
|
||||
def make_directory_tools(base_directory: str) -> list[Any]:
|
||||
"""Return filesystem tools that resolve relative paths against *base_directory*.
|
||||
|
||||
Use this instead of ``FILESYSTEM_TOOLS`` whenever you know the user's target
|
||||
directory upfront (e.g., journey setup sessions). Relative paths like ``"."``
|
||||
from the LLM are resolved to the correct absolute path before being sent to
|
||||
the Electron client, preventing it from falling back to its own CWD.
|
||||
"""
|
||||
|
||||
def _compact_for_journey(raw: str) -> str:
|
||||
"""Strip HTML noise and truncate for journey exploration.
|
||||
|
||||
The journey LLM only needs to understand file structure (headers,
|
||||
first paragraphs). Full CSS/style blocks are pure noise that eat
|
||||
up context window budget.
|
||||
"""
|
||||
text = re.sub(r"<style[^>]*>.*?</style>", "", raw, flags=re.DOTALL | re.IGNORECASE)
|
||||
text = re.sub(r"<script[^>]*>.*?</script>", "", text, flags=re.DOTALL | re.IGNORECASE)
|
||||
text = re.sub(r"<!--.*?-->", "", text, flags=re.DOTALL)
|
||||
if len(text) > _JOURNEY_READ_MAX_CHARS:
|
||||
text = text[:_JOURNEY_READ_MAX_CHARS] + "\n[…truncated for exploration]"
|
||||
return text
|
||||
|
||||
@tool
|
||||
async def list_directory(path: str) -> str: # noqa: F811
|
||||
"""List files and folders in a local directory on the user's device.
|
||||
|
||||
Returns a formatted listing of entries with name, type (file/directory),
|
||||
and full path.
|
||||
"""
|
||||
resolved = _resolve_path(path, base_directory)
|
||||
result = await execute_on_client(
|
||||
action="list_directory",
|
||||
data={"path": resolved},
|
||||
)
|
||||
entries: list[dict[str, Any]] = result.get("entries", [])
|
||||
if not entries:
|
||||
return f"Directory '{resolved}' is empty or does not exist."
|
||||
lines: list[str] = []
|
||||
for entry in entries:
|
||||
entry_type = entry.get("type", "unknown")
|
||||
entry_name = entry.get("name", "")
|
||||
entry_path = entry.get("path", "")
|
||||
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
||||
return f"Directory listing for '{resolved}' ({len(entries)} entries):\n" + "\n".join(lines)
|
||||
|
||||
@tool
|
||||
async def read_file_content(path: str) -> str: # noqa: F811
|
||||
"""Read the text content of a local file on the user's device.
|
||||
|
||||
Returns the file content as a string. Large files may be truncated
|
||||
by the Electron client.
|
||||
"""
|
||||
resolved = _resolve_path(path, base_directory)
|
||||
result = await execute_on_client(
|
||||
action="read_file_content",
|
||||
data={"path": resolved},
|
||||
)
|
||||
content: str = result.get("content", "")
|
||||
if not content:
|
||||
return f"File '{resolved}' is empty or could not be read."
|
||||
return _compact_for_journey(content)
|
||||
|
||||
@tool
|
||||
async def get_file_metadata(path: str) -> str: # noqa: F811
|
||||
"""Get metadata for a local file: size, creation date, modification date, extension.
|
||||
|
||||
Returns a formatted summary of the file's metadata.
|
||||
"""
|
||||
resolved = _resolve_path(path, base_directory)
|
||||
result = await execute_on_client(
|
||||
action="get_file_metadata",
|
||||
data={"path": resolved},
|
||||
)
|
||||
size = result.get("size", "unknown")
|
||||
created = result.get("createdAt", "unknown")
|
||||
modified = result.get("modifiedAt", "unknown")
|
||||
extension = result.get("extension", "unknown")
|
||||
name = result.get("name", resolved)
|
||||
return (
|
||||
f"File: {name}\n"
|
||||
f" Extension: {extension}\n"
|
||||
f" Size: {size} bytes\n"
|
||||
f" Created: {created}\n"
|
||||
f" Modified: {modified}"
|
||||
)
|
||||
|
||||
return [list_directory, read_file_content, get_file_metadata]
|
||||
168
api/app/agents/folder_agent.py
Normal file
168
api/app/agents/folder_agent.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Scoped file-read and search tools for the project folder feature."""
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.folder_indexer import _extract_docx_text, _extract_pdf_text
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
# Cap returned slice size to keep tool output under control.
|
||||
_MAX_RETURN_CHARS = 50_000
|
||||
_MAX_SEARCH_MATCHES = 20
|
||||
|
||||
|
||||
def _is_unsafe_path(rel: str) -> bool:
|
||||
if not rel:
|
||||
return True
|
||||
norm = rel.replace("\\", "/")
|
||||
if norm.startswith("/"):
|
||||
return True
|
||||
# Windows drive letter
|
||||
if len(rel) >= 2 and rel[1] == ":":
|
||||
return True
|
||||
parts = norm.split("/")
|
||||
return ".." in parts
|
||||
|
||||
|
||||
async def _fetch_file(project_id: str, relative_path: str, offset: int, length: int) -> dict:
|
||||
"""Return the raw Electron tool_result dict for a file read."""
|
||||
return await execute_on_client(
|
||||
action="read_project_folder_file",
|
||||
data={
|
||||
"projectId": project_id,
|
||||
"relativePath": relative_path,
|
||||
"offset": offset,
|
||||
"length": length,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _decode(result: dict) -> tuple[str, str, int]:
|
||||
"""Decode a tool_result into (text, kind, total_size). For pdf/docx,
|
||||
extracts text from base64. For images, returns a placeholder string.
|
||||
For text, content is already a sliced utf-8 string.
|
||||
"""
|
||||
kind = result.get("kind", "text")
|
||||
content = result.get("content", "") or ""
|
||||
total = int(result.get("totalSize", 0) or 0)
|
||||
if kind == "image":
|
||||
return ("[Image file — cannot be navigated as text. See manifest summary.]", kind, total)
|
||||
if kind == "pdf":
|
||||
return (_extract_pdf_text(content), kind, total)
|
||||
if kind == "docx":
|
||||
return (_extract_docx_text(content), kind, total)
|
||||
return (content, kind, total)
|
||||
|
||||
|
||||
@tool
|
||||
async def read_project_folder_file(
|
||||
project_id: str,
|
||||
relative_path: str,
|
||||
offset: int = 0,
|
||||
length: int = _MAX_RETURN_CHARS,
|
||||
) -> str:
|
||||
"""Read a slice of a file inside the project's linked folder.
|
||||
|
||||
Args:
|
||||
project_id: project ID.
|
||||
relative_path: path relative to the linked folder root.
|
||||
offset: char offset to start reading from (0 = beginning).
|
||||
length: max chars to return. Default 50000. Use smaller values to save tokens.
|
||||
|
||||
Returns text content slice with a header showing position. Header tells you
|
||||
when more content is available; call again with the suggested next offset.
|
||||
|
||||
For PDF / DOCX files the backend extracts text first, then applies offset/length
|
||||
on the extracted text. For images returns a placeholder; navigate with the
|
||||
manifest summary instead.
|
||||
"""
|
||||
if _is_unsafe_path(relative_path):
|
||||
return "Access denied"
|
||||
|
||||
result = await _fetch_file(project_id, relative_path, offset, length)
|
||||
text, kind, total_size = _decode(result)
|
||||
|
||||
if not text and kind in ("missing", "error"):
|
||||
return f"File not found or unreadable: {relative_path}"
|
||||
|
||||
if kind in ("pdf", "docx"):
|
||||
# Backend extracted full text — apply offset/length on chars.
|
||||
sliced = text[offset:offset + length]
|
||||
slice_end = min(offset + length, len(text))
|
||||
header = (
|
||||
f"[file={relative_path} kind={kind} offset={offset} end={slice_end} "
|
||||
f"totalChars={len(text)}]"
|
||||
)
|
||||
if slice_end < len(text):
|
||||
header += f"\n[More content available — call again with offset={slice_end}.]"
|
||||
return header + "\n" + sliced
|
||||
|
||||
if kind == "text":
|
||||
slice_end = offset + len(text)
|
||||
header = (
|
||||
f"[file={relative_path} kind=text offset={offset} end={slice_end} "
|
||||
f"totalBytes={total_size}]"
|
||||
)
|
||||
if slice_end < total_size:
|
||||
header += f"\n[More content available — call again with offset={slice_end}.]"
|
||||
return header + "\n" + text
|
||||
|
||||
# image or unknown
|
||||
return text
|
||||
|
||||
|
||||
@tool
|
||||
async def search_project_folder_file(
|
||||
project_id: str,
|
||||
relative_path: str,
|
||||
query: str,
|
||||
context_lines: int = 3,
|
||||
) -> str:
|
||||
"""Search a project folder file for a query string (case-insensitive substring).
|
||||
|
||||
Args:
|
||||
project_id: project ID.
|
||||
relative_path: path relative to the linked folder root.
|
||||
query: text to search for.
|
||||
context_lines: number of lines of context around each match (default 3).
|
||||
|
||||
Returns matching line ranges with surrounding context and 1-based line numbers.
|
||||
Capped at 20 matches; if more exist the header shows the total.
|
||||
|
||||
Works on text, code, markdown, PDF (extracted), and DOCX (extracted).
|
||||
Images and binary files are not searchable.
|
||||
"""
|
||||
if _is_unsafe_path(relative_path):
|
||||
return "Access denied"
|
||||
if not query:
|
||||
return "Empty query."
|
||||
|
||||
# For text we still need full file; pass length=very large.
|
||||
result = await _fetch_file(project_id, relative_path, offset=0, length=10_000_000)
|
||||
text, kind, _ = _decode(result)
|
||||
|
||||
if not text and kind in ("missing", "error"):
|
||||
return f"File not found or unreadable: {relative_path}"
|
||||
if kind == "image":
|
||||
return "Cannot search inside images."
|
||||
|
||||
lines = text.splitlines()
|
||||
q = query.lower()
|
||||
matches = [i for i, line in enumerate(lines) if q in line.lower()]
|
||||
if not matches:
|
||||
return f"No matches for '{query}' in {relative_path}."
|
||||
|
||||
shown = matches[:_MAX_SEARCH_MATCHES]
|
||||
snippets: list[str] = []
|
||||
for i in shown:
|
||||
start = max(0, i - context_lines)
|
||||
end = min(len(lines), i + context_lines + 1)
|
||||
block = "\n".join(f"{n + 1:5d}: {lines[n]}" for n in range(start, end))
|
||||
snippets.append(block)
|
||||
|
||||
header = f"[file={relative_path} matches={len(matches)} showing={len(shown)} query='{query}']"
|
||||
body = "\n---\n".join(snippets)
|
||||
return header + "\n" + body
|
||||
|
||||
|
||||
FOLDER_TOOLS = [read_project_folder_file, search_project_folder_file]
|
||||
206
api/app/agents/note_agent.py
Normal file
206
api/app/agents/note_agent.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""Note agent — Markdown note management (list, get, create, update, propose edit)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.note_summarizer import generate_note_summary
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(value: str) -> bool:
|
||||
return bool(_UUID_RE.match(value))
|
||||
|
||||
|
||||
def _fmt_summary(row: dict) -> str:
|
||||
summary = (row.get("aiSummary") or row.get("ai_summary") or "").strip()
|
||||
if summary:
|
||||
return f" — {summary}"
|
||||
snippet = (row.get("content") or "")[:120].replace("\n", " ").strip()
|
||||
return f" — {snippet}" if snippet else ""
|
||||
|
||||
|
||||
@tool
|
||||
async def list_notes(project_id: str = "") -> str:
|
||||
"""List notes with AI summaries, optionally scoped to a project by project_id.
|
||||
|
||||
Returns id, title, and ai_summary for each note so you can decide which
|
||||
note to read in full with get_note before creating or updating.
|
||||
"""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="notes",
|
||||
filters={"projectId": normalized_project_id or None},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No notes found."
|
||||
lines = [f" - [{r['id']}] {r['title']}{_fmt_summary(r)}" for r in rows]
|
||||
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def get_note(note_id: str) -> str:
|
||||
"""Fetch a single note by its UUID to read its full Markdown content."""
|
||||
result = await execute_on_client(action="get", table="notes", data={"id": note_id})
|
||||
row = result.get("row")
|
||||
if not row:
|
||||
return f"Note {note_id} not found."
|
||||
return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}"
|
||||
|
||||
|
||||
@tool
|
||||
async def create_note(
|
||||
title: str,
|
||||
content: str,
|
||||
project_id: str = "",
|
||||
) -> str:
|
||||
"""Create a new note.
|
||||
title: note heading (required)
|
||||
content: Markdown body text (required)
|
||||
project_id: optional UUID linking this note to a project
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="notes",
|
||||
data={
|
||||
"title": title,
|
||||
"content": content,
|
||||
"projectId": project_id or None,
|
||||
},
|
||||
)
|
||||
row = result["row"]
|
||||
note_id: str = row["id"]
|
||||
# Generate summary asynchronously — fire-and-forget.
|
||||
asyncio.create_task(_refresh_summary(note_id, title, content))
|
||||
return f"Note created: '{row['title']}' (id: {note_id})."
|
||||
|
||||
|
||||
@tool
|
||||
async def update_note(
|
||||
note_id: str,
|
||||
title: str = "",
|
||||
content: str = "",
|
||||
) -> str:
|
||||
"""Update an existing note directly (no approval required).
|
||||
Use propose_note_edit instead when human review is needed.
|
||||
note_id: UUID of the note (required)
|
||||
If you need to preserve existing content, call get_note first.
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if title:
|
||||
updates["title"] = title
|
||||
if content:
|
||||
updates["content"] = content
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="notes",
|
||||
data={"id": note_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
if content:
|
||||
new_title = title or row.get("title", "")
|
||||
asyncio.create_task(_refresh_summary(note_id, new_title, content))
|
||||
return f"Note updated: '{row['title']}' (id: {row['id']})."
|
||||
|
||||
|
||||
@tool
|
||||
async def propose_note_edit(
|
||||
note_id: str,
|
||||
edit_type: str,
|
||||
proposed_content: str,
|
||||
reasoning: str = "",
|
||||
anchor_before: str = "",
|
||||
anchor_text: str = "",
|
||||
agent_id: str = "",
|
||||
run_id: str = "",
|
||||
) -> str:
|
||||
"""Propose an AI edit to an existing note, pending human approval.
|
||||
|
||||
Use this instead of update_note when review_required is true.
|
||||
The user will see the proposal highlighted before it is merged.
|
||||
|
||||
note_id: UUID of the target note (required)
|
||||
edit_type: 'append' | 'insert' | 'replace'
|
||||
- append: adds proposed_content at the end of the note
|
||||
- insert: inserts proposed_content immediately after anchor_before text
|
||||
- replace: replaces the first occurrence of anchor_text with proposed_content
|
||||
proposed_content: the new Markdown text to add or substitute (required)
|
||||
reasoning: brief explanation shown to the user (recommended)
|
||||
anchor_before: for 'insert' — the text snippet that precedes the insertion point
|
||||
anchor_text: for 'replace' — the exact text to be replaced
|
||||
agent_id: agent identifier (for traceability)
|
||||
run_id: run identifier (for traceability)
|
||||
"""
|
||||
if edit_type not in ("append", "insert", "replace"):
|
||||
return f"Invalid edit_type '{edit_type}'. Use 'append', 'insert', or 'replace'."
|
||||
|
||||
result = await execute_on_client(
|
||||
action="propose_note_edit",
|
||||
data={
|
||||
"noteId": note_id,
|
||||
"type": edit_type,
|
||||
"proposedContent": proposed_content,
|
||||
"reasoning": reasoning or None,
|
||||
"anchorBefore": anchor_before or None,
|
||||
"anchorText": anchor_text or None,
|
||||
"agentId": agent_id or None,
|
||||
"runId": run_id or None,
|
||||
},
|
||||
)
|
||||
edit_id = result.get("id", "?")
|
||||
return (
|
||||
f"Edit proposal created (id: {edit_id}) for note {note_id}. "
|
||||
f"Status: pending user approval."
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_note(note_id: str) -> str:
|
||||
"""Delete a note permanently by its UUID."""
|
||||
await execute_on_client(action="delete", table="notes", data={"id": note_id})
|
||||
return f"Note {note_id} deleted."
|
||||
|
||||
|
||||
async def _refresh_summary(note_id: str, title: str, content: str) -> None:
|
||||
"""Generate and persist the AI summary for a note. Fire-and-forget."""
|
||||
try:
|
||||
summary = await generate_note_summary(title, content)
|
||||
if summary:
|
||||
await execute_on_client(
|
||||
action="update",
|
||||
table="notes",
|
||||
data={
|
||||
"id": note_id,
|
||||
"updates": {
|
||||
"aiSummary": summary,
|
||||
"aiSummaryUpdatedAt": int(__import__("time").time() * 1000),
|
||||
},
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
pass # fire-and-forget; errors logged by generate_note_summary
|
||||
|
||||
|
||||
NOTE_TOOLS: list[Any] = [
|
||||
list_notes,
|
||||
get_note,
|
||||
create_note,
|
||||
update_note,
|
||||
propose_note_edit,
|
||||
delete_note,
|
||||
]
|
||||
|
||||
NOTE_READ_TOOLS: list[Any] = [
|
||||
list_notes,
|
||||
get_note,
|
||||
]
|
||||
133
api/app/agents/project_agent.py
Normal file
133
api/app/agents/project_agent.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Project agent — full lifecycle management (list, get, create, update, archive, delete)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
|
||||
@tool
|
||||
async def list_projects(
|
||||
client_id: str = "",
|
||||
include_archived: int = 0,
|
||||
) -> str:
|
||||
"""List projects, optionally filtered by client_id.
|
||||
include_archived: 1 to include archived projects, 0 for active only (default).
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="projects",
|
||||
filters={
|
||||
"clientId": client_id or None,
|
||||
"includeArchived": bool(include_archived),
|
||||
},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No projects found."
|
||||
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
|
||||
return f"Found {len(rows)} project(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def list_all_projects() -> str:
|
||||
"""List every project regardless of client or status.
|
||||
Use only when the user wants a complete cross-client overview.
|
||||
"""
|
||||
result = await execute_on_client(action="select", table="projects")
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No projects found."
|
||||
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
|
||||
return f"All projects ({len(rows)}):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def get_project(project_id: str) -> str:
|
||||
"""Fetch a single project by its UUID."""
|
||||
result = await execute_on_client(action="get", table="projects", data={"id": project_id})
|
||||
row = result.get("row")
|
||||
if not row:
|
||||
return f"Project {project_id} not found."
|
||||
return (
|
||||
f"Project: '{row['name']}' (id: {row['id']}, status: {row['status']}, "
|
||||
f"clientId: {row.get('clientId', 'none')})"
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def create_project(
|
||||
name: str,
|
||||
client_id: str = "",
|
||||
) -> str:
|
||||
"""Create a new project.
|
||||
name: human-readable project name (required)
|
||||
client_id: optional UUID of the owning client
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="projects",
|
||||
data={"name": name, "clientId": client_id or None},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Project created: '{row['name']}' (id: {row['id']})"
|
||||
|
||||
|
||||
@tool
|
||||
async def update_project(
|
||||
project_id: str,
|
||||
name: str = "",
|
||||
client_id: str = "",
|
||||
status: str = "",
|
||||
ai_summary: str = "",
|
||||
) -> str:
|
||||
"""Update a project. Only pass fields that should change.
|
||||
project_id: UUID of the project (required)
|
||||
status: active | archived
|
||||
ai_summary: AI-generated summary text (populate only when explicitly requested)
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if name:
|
||||
updates["name"] = name
|
||||
if client_id:
|
||||
updates["clientId"] = client_id
|
||||
if status:
|
||||
updates["status"] = status
|
||||
if ai_summary:
|
||||
updates["aiSummary"] = ai_summary
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="projects",
|
||||
data={"id": project_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Project updated: '{row['name']}' (id: {row['id']}, status: {row['status']})"
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_project(project_id: str) -> str:
|
||||
"""Permanently delete a project and orphan its tasks.
|
||||
IMPORTANT: prefer update_project(status='archived') unless the user
|
||||
has explicitly confirmed they want permanent deletion.
|
||||
"""
|
||||
await execute_on_client(action="delete", table="projects", data={"id": project_id})
|
||||
return f"Project {project_id} permanently deleted."
|
||||
|
||||
|
||||
PROJECT_TOOLS: list[Any] = [
|
||||
list_projects,
|
||||
list_all_projects,
|
||||
get_project,
|
||||
create_project,
|
||||
update_project,
|
||||
delete_project,
|
||||
]
|
||||
|
||||
PROJECT_READ_TOOLS: list[Any] = [
|
||||
list_projects,
|
||||
list_all_projects,
|
||||
get_project,
|
||||
]
|
||||
63
api/app/agents/relations_agent.py
Normal file
63
api/app/agents/relations_agent.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Relations agent — read-only tool wrapping MemoryMiddleware.query_relations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import async_session
|
||||
|
||||
# Injected at tool-factory time by _brief_research_tools(); not a module-level global.
|
||||
# Each tool closure captures the user_id bound at factory time.
|
||||
|
||||
|
||||
def make_query_relations_tool(user_id: str, trace_id: str | None = None) -> Any:
|
||||
"""Return a query_relations tool bound to *user_id*."""
|
||||
|
||||
@tool
|
||||
async def query_relations(
|
||||
subject_label: str = "",
|
||||
predicate: str = "",
|
||||
object_label: str = "",
|
||||
limit: int = 10,
|
||||
) -> str:
|
||||
"""Query the relational memory graph for entity relationships.
|
||||
|
||||
Returns rows where subject ↔ predicate ↔ object match the given filters.
|
||||
All parameters are optional — omit to retrieve all relations up to limit.
|
||||
|
||||
subject_label: entity label on the left side (e.g. a client name, "Acme Corp").
|
||||
predicate: relationship type (e.g. "mentioned_in", "works_at", "related_to").
|
||||
object_label: entity label on the right side (e.g. a project name, "Website Redesign").
|
||||
limit: max rows to return (default 10).
|
||||
"""
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(
|
||||
"relations_agent: query_relations trace=%s user=%s subject=%r predicate=%r object=%r",
|
||||
trace_id or "-", user_id, subject_label, predicate, object_label,
|
||||
)
|
||||
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
rows = await memory.query_relations(
|
||||
user_id=user_id,
|
||||
subject=subject_label or None,
|
||||
predicate=predicate or None,
|
||||
object_=object_label or None,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
if not rows:
|
||||
return "No relational memory entries found for the given filters."
|
||||
|
||||
lines = [
|
||||
f"- {r.subject_label} —[{r.predicate}]→ {r.object_label}"
|
||||
+ (f" (confidence: {r.confidence:.2f})" if r.confidence is not None else "")
|
||||
for r in rows
|
||||
]
|
||||
return f"Found {len(rows)} relation(s):\n" + "\n".join(lines)
|
||||
|
||||
return query_relations
|
||||
358
api/app/agents/task_agent.py
Normal file
358
api/app/agents/task_agent.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""Task agent — full CRUD for tasks and task comments."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(value: str) -> bool:
|
||||
return bool(_UUID_RE.match(value))
|
||||
|
||||
|
||||
# ── Task tools ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@tool
|
||||
async def list_tasks(
|
||||
project_id: str = "",
|
||||
status: str = "",
|
||||
priority: str = "",
|
||||
assignee: str = "",
|
||||
search: str = "",
|
||||
order_by: str = "",
|
||||
order_dir: str = "",
|
||||
due_date_from: int = -1,
|
||||
due_date_to: int = -1,
|
||||
created_at_from: int = -1,
|
||||
created_at_to: int = -1,
|
||||
completed_at_from: int = -1,
|
||||
completed_at_to: int = -1,
|
||||
is_ai_suggested: int = -1,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> str:
|
||||
"""List tasks with optional filters. Returns up to `limit` results (default 50).
|
||||
|
||||
project_id: UUID of the project to scope results to.
|
||||
status: filter by status — todo | in_progress | done.
|
||||
priority: filter by priority — high | medium | low.
|
||||
assignee: substring to match against assignee names. OMIT unless the user explicitly
|
||||
names a person or refers to themselves ("my tasks", "assigned to me", "mine").
|
||||
Do NOT default to the current user.
|
||||
search: substring search across title and description.
|
||||
order_by: sort field — dueDate | priority | createdAt | completedAt.
|
||||
order_dir: asc (default) | desc.
|
||||
due_date_from / due_date_to: ms epoch range for dueDate. Use -1 to omit.
|
||||
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
|
||||
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
||||
is_ai_suggested: 0 or 1 to filter by AI-suggested flag; -1 = any.
|
||||
limit: max rows to return (default 50). Use with offset to paginate.
|
||||
offset: skip first N rows (default 0).
|
||||
|
||||
Tip — combine *_from and *_to for a closed range; pass only one for open-ended.
|
||||
Tip — prefer count_tasks for "how many" questions to avoid listing rows.
|
||||
Tip — for natural-language windows ("today", "tomorrow", "this week", "last month", etc.)
|
||||
take due_date_from / due_date_to verbatim from the DATE CONTEXT block in the system prompt;
|
||||
do not compute boundaries from the current UTC instant.
|
||||
"""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
filters: dict[str, Any] = {
|
||||
"projectId": normalized_project_id or None,
|
||||
"status": status or None,
|
||||
"priority": priority or None,
|
||||
"search": search or None,
|
||||
"orderBy": order_by or None,
|
||||
"orderDir": order_dir or None,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
if assignee:
|
||||
filters["assignee"] = assignee
|
||||
if due_date_from != -1:
|
||||
filters["dueDateFrom"] = due_date_from
|
||||
if due_date_to != -1:
|
||||
filters["dueDateTo"] = due_date_to
|
||||
if created_at_from != -1:
|
||||
filters["createdAtFrom"] = created_at_from
|
||||
if created_at_to != -1:
|
||||
filters["createdAtTo"] = created_at_to
|
||||
if completed_at_from != -1:
|
||||
filters["completedAtFrom"] = completed_at_from
|
||||
if completed_at_to != -1:
|
||||
filters["completedAtTo"] = completed_at_to
|
||||
if is_ai_suggested != -1:
|
||||
filters["isAiSuggested"] = is_ai_suggested
|
||||
|
||||
result = await execute_on_client(action="select", table="tasks", filters=filters)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No tasks found matching the given filters."
|
||||
lines = [
|
||||
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, "
|
||||
f"dueDate: {r.get('dueDate')}, completedAt: {r.get('completedAt')}, "
|
||||
f"projectId: {r.get('projectId')}, id: {r['id']})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def count_tasks(
|
||||
project_id: str = "",
|
||||
status: str = "",
|
||||
priority: str = "",
|
||||
assignee: str = "",
|
||||
search: str = "",
|
||||
due_date_from: int = -1,
|
||||
due_date_to: int = -1,
|
||||
created_at_from: int = -1,
|
||||
created_at_to: int = -1,
|
||||
completed_at_from: int = -1,
|
||||
completed_at_to: int = -1,
|
||||
is_ai_suggested: int = -1,
|
||||
) -> str:
|
||||
"""Count tasks matching the given filters without returning rows.
|
||||
|
||||
Use this instead of list_tasks for "how many" questions — it is much cheaper.
|
||||
Same filter parameters as list_tasks (no limit/offset/order_by needed).
|
||||
assignee: OMIT unless the user explicitly names a person or refers to themselves
|
||||
("my tasks"). Do NOT default to the current user.
|
||||
due_date_from / due_date_to: ms epoch range for dueDate. Use -1 to omit.
|
||||
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
|
||||
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
||||
Tip — for natural-language windows take due_date_from / due_date_to from the DATE CONTEXT block;
|
||||
do not compute boundaries from the current UTC instant.
|
||||
"""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
filters: dict[str, Any] = {
|
||||
"projectId": normalized_project_id or None,
|
||||
"status": status or None,
|
||||
"priority": priority or None,
|
||||
"search": search or None,
|
||||
}
|
||||
if assignee:
|
||||
filters["assignee"] = assignee
|
||||
if due_date_from != -1:
|
||||
filters["dueDateFrom"] = due_date_from
|
||||
if due_date_to != -1:
|
||||
filters["dueDateTo"] = due_date_to
|
||||
if created_at_from != -1:
|
||||
filters["createdAtFrom"] = created_at_from
|
||||
if created_at_to != -1:
|
||||
filters["createdAtTo"] = created_at_to
|
||||
if completed_at_from != -1:
|
||||
filters["completedAtFrom"] = completed_at_from
|
||||
if completed_at_to != -1:
|
||||
filters["completedAtTo"] = completed_at_to
|
||||
if is_ai_suggested != -1:
|
||||
filters["isAiSuggested"] = is_ai_suggested
|
||||
|
||||
result = await execute_on_client(action="count", table="tasks", filters=filters)
|
||||
return f"Task count: {result.get('count', 0)}"
|
||||
|
||||
|
||||
@tool
|
||||
async def create_task(
|
||||
title: str,
|
||||
description: str = "",
|
||||
status: str = "todo",
|
||||
priority: str = "medium",
|
||||
assignees: str = "[]",
|
||||
due_date: int = 0,
|
||||
project_id: str = "",
|
||||
is_ai_suggested: int = 0,
|
||||
) -> str:
|
||||
"""Create a new task.
|
||||
title: task title (required)
|
||||
description: optional details
|
||||
status: todo | in_progress | done (default: todo)
|
||||
priority: high | medium | low (default: medium)
|
||||
assignees: JSON-encoded array of assignee names, e.g. '["Alice"]'
|
||||
due_date: Unix timestamp in milliseconds; 0 means no due date
|
||||
project_id: optional UUID of the parent project
|
||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||
|
||||
completedAt is set automatically when status is 'done'.
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="tasks",
|
||||
data={
|
||||
"title": title,
|
||||
"description": description or None,
|
||||
"status": status,
|
||||
"priority": priority,
|
||||
"assignee": assignees,
|
||||
"dueDate": due_date or None,
|
||||
"projectId": project_id or None,
|
||||
"isAiSuggested": is_ai_suggested,
|
||||
},
|
||||
)
|
||||
row = result["row"]
|
||||
return (
|
||||
f"Task created: '{row['title']}' "
|
||||
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']}, projectId: {row.get('projectId')})"
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def update_task(
|
||||
task_id: str,
|
||||
title: str = "",
|
||||
description: str = "",
|
||||
status: str = "",
|
||||
priority: str = "",
|
||||
assignees: str = "",
|
||||
due_date: int = -1,
|
||||
project_id: str = "",
|
||||
) -> str:
|
||||
"""Update fields on an existing task. Only pass fields you want to change.
|
||||
task_id: the task's UUID (required)
|
||||
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
||||
|
||||
completedAt is managed automatically:
|
||||
- setting status to 'done' records the current timestamp
|
||||
- changing status away from 'done' clears completedAt
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if title:
|
||||
updates["title"] = title
|
||||
if description:
|
||||
updates["description"] = description
|
||||
if status:
|
||||
updates["status"] = status
|
||||
if priority:
|
||||
updates["priority"] = priority
|
||||
if assignees:
|
||||
updates["assignee"] = assignees
|
||||
if due_date != -1:
|
||||
updates["dueDate"] = due_date or None
|
||||
if project_id:
|
||||
updates["projectId"] = project_id
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="tasks",
|
||||
data={"id": task_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']}, projectId: {row.get('projectId')})"
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_task(task_id: str) -> str:
|
||||
"""Delete a task permanently by its UUID."""
|
||||
await execute_on_client(action="delete", table="tasks", data={"id": task_id})
|
||||
return f"Task {task_id} deleted."
|
||||
|
||||
|
||||
@tool
|
||||
async def list_tasks_due_today(user_timezone: str = "UTC", include_done: bool = False) -> str:
|
||||
"""List all tasks whose due date falls on today's date.
|
||||
|
||||
user_timezone: IANA timezone name (e.g. 'Europe/Rome', 'America/New_York').
|
||||
Always pass the user's timezone so 'today' is computed in their local time.
|
||||
include_done: set True to also include already-completed tasks due today (default False).
|
||||
"""
|
||||
try:
|
||||
from zoneinfo import ZoneInfo
|
||||
tz = ZoneInfo(user_timezone or "UTC")
|
||||
except Exception:
|
||||
tz = timezone.utc
|
||||
now_local = datetime.now(tz=tz)
|
||||
start_dt = datetime(now_local.year, now_local.month, now_local.day, tzinfo=tz)
|
||||
start_ms = int(start_dt.timestamp() * 1000)
|
||||
end_ms = start_ms + 86_400_000 - 1
|
||||
filters: dict[str, Any] = {"dueDateFrom": start_ms, "dueDateTo": end_ms}
|
||||
if not include_done:
|
||||
filters["status"] = "todo"
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="tasks",
|
||||
filters=filters,
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No tasks are due today."
|
||||
lines = [
|
||||
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, "
|
||||
f"projectId: {r.get('projectId')}, id: {r['id']})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
# ── Task comment tools ────────────────────────────────────────────────
|
||||
|
||||
|
||||
@tool
|
||||
async def list_task_comments(task_id: str) -> str:
|
||||
"""List all comments on a task by its UUID."""
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="taskComments",
|
||||
filters={"taskId": task_id},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return f"No comments found for task {task_id}."
|
||||
lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows]
|
||||
return f"Found {len(rows)} comment(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
||||
"""Add a comment to a task.
|
||||
task_id: UUID of the task to comment on
|
||||
author: name or ID of the comment author
|
||||
content: comment text
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="taskComments",
|
||||
data={"taskId": task_id, "author": author, "content": content},
|
||||
)
|
||||
row = result.get("row", {})
|
||||
row_author = row.get("author", author)
|
||||
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
||||
row_comment_id = row.get("id", "unknown")
|
||||
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_task_comment(comment_id: str) -> str:
|
||||
"""Delete a task comment by its UUID."""
|
||||
await execute_on_client(action="delete", table="taskComments", data={"id": comment_id})
|
||||
return f"Comment {comment_id} deleted."
|
||||
|
||||
|
||||
# ── Agent ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
TASK_TOOLS: list[Any] = [
|
||||
list_tasks,
|
||||
count_tasks,
|
||||
create_task,
|
||||
update_task,
|
||||
delete_task,
|
||||
list_tasks_due_today,
|
||||
list_task_comments,
|
||||
add_task_comment,
|
||||
delete_task_comment,
|
||||
]
|
||||
|
||||
TASK_READ_TOOLS: list[Any] = [
|
||||
list_tasks,
|
||||
count_tasks,
|
||||
list_tasks_due_today,
|
||||
list_task_comments,
|
||||
]
|
||||
270
api/app/agents/timeline_agent.py
Normal file
270
api/app/agents/timeline_agent.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""Timeline agent — project milestone management (list, create, update, delete)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(value: str) -> bool:
|
||||
return bool(_UUID_RE.match(value))
|
||||
|
||||
|
||||
@tool
|
||||
async def list_timelines(
|
||||
project_id: str = "",
|
||||
type: str = "",
|
||||
is_completed: int = -1,
|
||||
is_ai_suggested: int = -1,
|
||||
order_by: str = "",
|
||||
order_dir: str = "",
|
||||
date_from: int = -1,
|
||||
date_to: int = -1,
|
||||
created_at_from: int = -1,
|
||||
created_at_to: int = -1,
|
||||
completed_at_from: int = -1,
|
||||
completed_at_to: int = -1,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> str:
|
||||
"""List timeline events (milestones, checkpoints, activities) with optional filters.
|
||||
|
||||
project_id: UUID to scope results to a specific project.
|
||||
type: filter by event type — milestone | checkpoint | activity.
|
||||
is_completed: 0 = incomplete only, 1 = completed only, -1 = any (default).
|
||||
is_ai_suggested: 0 or 1 to filter by AI-suggested flag; -1 = any.
|
||||
order_by: sort field — date (default) | createdAt | completedAt.
|
||||
order_dir: asc (default) | desc.
|
||||
date_from / date_to: ms epoch range for the event date. Use -1 to omit.
|
||||
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
|
||||
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
||||
limit: max rows to return (default 50). Use with offset to paginate.
|
||||
offset: skip first N rows (default 0).
|
||||
|
||||
Tip — combine *_from and *_to for a closed range; pass only one for open-ended.
|
||||
Tip — prefer count_timelines for "how many" questions to avoid listing rows.
|
||||
Tip — for natural-language windows ("today", "this week", "last month", etc.)
|
||||
take date_from / date_to verbatim from the DATE CONTEXT block in the system prompt;
|
||||
do not compute boundaries from the current UTC instant.
|
||||
"""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
filters: dict[str, Any] = {
|
||||
"projectId": normalized_project_id or None,
|
||||
"orderBy": order_by or None,
|
||||
"orderDir": order_dir or None,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
if type:
|
||||
filters["type"] = type
|
||||
if is_completed != -1:
|
||||
filters["isCompleted"] = is_completed
|
||||
if is_ai_suggested != -1:
|
||||
filters["isAiSuggested"] = is_ai_suggested
|
||||
if date_from != -1:
|
||||
filters["dateFrom"] = date_from
|
||||
if date_to != -1:
|
||||
filters["dateTo"] = date_to
|
||||
if created_at_from != -1:
|
||||
filters["createdAtFrom"] = created_at_from
|
||||
if created_at_to != -1:
|
||||
filters["createdAtTo"] = created_at_to
|
||||
if completed_at_from != -1:
|
||||
filters["completedAtFrom"] = completed_at_from
|
||||
if completed_at_to != -1:
|
||||
filters["completedAtTo"] = completed_at_to
|
||||
|
||||
result = await execute_on_client(action="select", table="timelines", filters=filters)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No timeline events found."
|
||||
lines = [
|
||||
f"- {r['title']} (date: {r['date']}, type: {r.get('type')}, "
|
||||
f"completed: {bool(r.get('isCompleted'))}, completedAt: {r.get('completedAt')}, "
|
||||
f"projectId: {r.get('projectId')}, id: {r['id']})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Found {len(rows)} timeline event(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def count_timelines(
|
||||
project_id: str = "",
|
||||
type: str = "",
|
||||
is_completed: int = -1,
|
||||
is_ai_suggested: int = -1,
|
||||
date_from: int = -1,
|
||||
date_to: int = -1,
|
||||
created_at_from: int = -1,
|
||||
created_at_to: int = -1,
|
||||
completed_at_from: int = -1,
|
||||
completed_at_to: int = -1,
|
||||
) -> str:
|
||||
"""Count timeline events matching the given filters without returning rows.
|
||||
|
||||
Use this instead of list_timelines for "how many" questions — it is much cheaper.
|
||||
Same filter parameters as list_timelines (no limit/offset/order_by needed).
|
||||
|
||||
date_from / date_to: ms epoch range for the event date. Use -1 to omit.
|
||||
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
||||
Tip — for natural-language windows take date_from / date_to from the DATE CONTEXT block;
|
||||
do not compute boundaries from the current UTC instant.
|
||||
"""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
filters: dict[str, Any] = {"projectId": normalized_project_id or None}
|
||||
if type:
|
||||
filters["type"] = type
|
||||
if is_completed != -1:
|
||||
filters["isCompleted"] = is_completed
|
||||
if is_ai_suggested != -1:
|
||||
filters["isAiSuggested"] = is_ai_suggested
|
||||
if date_from != -1:
|
||||
filters["dateFrom"] = date_from
|
||||
if date_to != -1:
|
||||
filters["dateTo"] = date_to
|
||||
if created_at_from != -1:
|
||||
filters["createdAtFrom"] = created_at_from
|
||||
if created_at_to != -1:
|
||||
filters["createdAtTo"] = created_at_to
|
||||
if completed_at_from != -1:
|
||||
filters["completedAtFrom"] = completed_at_from
|
||||
if completed_at_to != -1:
|
||||
filters["completedAtTo"] = completed_at_to
|
||||
|
||||
result = await execute_on_client(action="count", table="timelines", filters=filters)
|
||||
return f"Timeline event count: {result.get('count', 0)}"
|
||||
|
||||
|
||||
@tool
|
||||
async def create_timeline(
|
||||
project_id: str,
|
||||
title: str,
|
||||
date: int,
|
||||
type: str = "milestone",
|
||||
is_completed: int = 0,
|
||||
is_ai_suggested: int = 0,
|
||||
) -> str:
|
||||
"""Create a project timeline event.
|
||||
project_id: REQUIRED UUID of the parent project
|
||||
title: descriptive name for the event
|
||||
date: Unix timestamp in milliseconds for the event date
|
||||
type: milestone (default) | checkpoint | activity
|
||||
is_completed: 1 if already completed, 0 if not (default 0)
|
||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||
|
||||
completedAt is set automatically when is_completed is 1.
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="timelines",
|
||||
data={
|
||||
"projectId": project_id,
|
||||
"title": title,
|
||||
"date": date,
|
||||
"type": type,
|
||||
"isCompleted": is_completed,
|
||||
"isAiSuggested": is_ai_suggested,
|
||||
},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Timeline event created: '{row['title']}' (id: {row['id']}, date: {row['date']}, type: {row.get('type')})"
|
||||
|
||||
|
||||
@tool
|
||||
async def update_timeline(
|
||||
timeline_id: str,
|
||||
title: str = "",
|
||||
date: int = -1,
|
||||
is_completed: int = -1,
|
||||
) -> str:
|
||||
"""Update a timeline event. Only pass fields that should change.
|
||||
timeline_id: UUID of the event (required)
|
||||
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
||||
is_completed: 0 = mark incomplete, 1 = mark complete, -1 = unchanged
|
||||
|
||||
completedAt is managed automatically:
|
||||
- setting is_completed to 1 records the current timestamp
|
||||
- setting is_completed to 0 clears completedAt
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if title:
|
||||
updates["title"] = title
|
||||
if date != -1:
|
||||
updates["date"] = date
|
||||
if is_completed != -1:
|
||||
updates["isCompleted"] = is_completed
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="timelines",
|
||||
data={"id": timeline_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Timeline event updated: '{row['title']}' (id: {row['id']})"
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_timeline(timeline_id: str) -> str:
|
||||
"""Delete a timeline event permanently by its UUID."""
|
||||
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
|
||||
return f"Timeline event {timeline_id} deleted."
|
||||
|
||||
|
||||
@tool
|
||||
async def list_timelines_today(user_timezone: str = "UTC", include_completed: bool = True) -> str:
|
||||
"""List all timeline events whose date falls on today.
|
||||
|
||||
user_timezone: IANA timezone name (e.g. 'Europe/Rome', 'America/New_York').
|
||||
Always pass the user's timezone so 'today' is computed in their local time.
|
||||
include_completed: set False to exclude already-completed events (default True).
|
||||
"""
|
||||
try:
|
||||
from zoneinfo import ZoneInfo
|
||||
tz = ZoneInfo(user_timezone or "UTC")
|
||||
except Exception:
|
||||
tz = timezone.utc
|
||||
now_local = datetime.now(tz=tz)
|
||||
start_dt = datetime(now_local.year, now_local.month, now_local.day, tzinfo=tz)
|
||||
start_ms = int(start_dt.timestamp() * 1000)
|
||||
end_ms = start_ms + 86_400_000 - 1
|
||||
filters: dict[str, Any] = {"dateFrom": start_ms, "dateTo": end_ms}
|
||||
if not include_completed:
|
||||
filters["isCompleted"] = 0
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="timelines",
|
||||
filters=filters,
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No timeline events today."
|
||||
lines = [
|
||||
f"- {r['title']} (date: {r['date']}, type: {r.get('type')}, "
|
||||
f"completed: {bool(r.get('isCompleted'))}, projectId: {r.get('projectId')}, id: {r['id']})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Timeline events today ({len(rows)}):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
TIMELINE_TOOLS: list[Any] = [
|
||||
list_timelines,
|
||||
count_timelines,
|
||||
list_timelines_today,
|
||||
create_timeline,
|
||||
update_timeline,
|
||||
delete_timeline,
|
||||
]
|
||||
|
||||
TIMELINE_READ_TOOLS: list[Any] = [
|
||||
list_timelines,
|
||||
count_timelines,
|
||||
list_timelines_today,
|
||||
]
|
||||
0
api/app/api/__init__.py
Normal file
0
api/app/api/__init__.py
Normal file
14
api/app/api/deps.py
Normal file
14
api/app/api/deps.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Shared FastAPI dependencies.
|
||||
|
||||
``get_current_user`` and ``oauth2_scheme`` live in ``app.api.middleware.auth``
|
||||
(the canonical location per Step 9). This module re-exports them so that all
|
||||
existing route imports (``from app.api.deps import get_current_user``) continue
|
||||
to work without modification.
|
||||
|
||||
Step 12 will update ``get_current_user`` to fetch the live tier from PostgreSQL
|
||||
instead of reading it from the JWT payload.
|
||||
"""
|
||||
|
||||
from app.api.middleware.auth import get_current_user, oauth2_scheme # noqa: F401
|
||||
|
||||
__all__ = ["get_current_user", "oauth2_scheme"]
|
||||
19
api/app/api/middleware/__init__.py
Normal file
19
api/app/api/middleware/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""API middleware package.
|
||||
|
||||
Exports the three middleware components introduced in Step 9:
|
||||
- Auth: ``get_current_user`` FastAPI dependency + ``oauth2_scheme``
|
||||
- Rate limit: ``TierRateLimitMiddleware`` + ``limiter`` (slowapi Limiter)
|
||||
- Sanitizer: ``SanitizerMiddleware``
|
||||
"""
|
||||
|
||||
from app.api.middleware.auth import get_current_user, oauth2_scheme
|
||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware, limiter
|
||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||
|
||||
__all__ = [
|
||||
"get_current_user",
|
||||
"oauth2_scheme",
|
||||
"TierRateLimitMiddleware",
|
||||
"limiter",
|
||||
"SanitizerMiddleware",
|
||||
]
|
||||
103
api/app/api/middleware/auth.py
Normal file
103
api/app/api/middleware/auth.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Auth middleware — JWT validation dependency.
|
||||
|
||||
``get_current_user`` is the FastAPI dependency used by all protected routes.
|
||||
It decodes the Bearer JWT (identity + expiry), then fetches the current tier
|
||||
from the ``subscriptions`` table so that tier changes take effect immediately
|
||||
without requiring token re-issue.
|
||||
|
||||
Exempt routes (no JWT required):
|
||||
- POST /api/v1/auth/register
|
||||
- POST /api/v1/auth/login
|
||||
- POST /api/v1/billing/webhook
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.db import get_session
|
||||
from app.schemas import UserProfile
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Validate a Bearer JWT and return the authenticated user.
|
||||
|
||||
The JWT is used for identity and expiry only. The tier is fetched live
|
||||
from the ``subscriptions`` table so that upgrades/downgrades take effect
|
||||
immediately. Falls back to ``'free'`` when no subscription row exists.
|
||||
|
||||
Raises HTTP 401 on any invalid or expired token.
|
||||
"""
|
||||
credentials_exc = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
user_id: str | None = payload.get("sub")
|
||||
email: str | None = payload.get("email")
|
||||
if not user_id or not email:
|
||||
raise credentials_exc
|
||||
except JWTError:
|
||||
raise credentials_exc
|
||||
|
||||
# Live tier lookup — subscription row is the authoritative source.
|
||||
# In dev, fall back to 'power' (unlimited) so quota limits don't
|
||||
# block local development when no Stripe subscription exists.
|
||||
from app.models import Subscription, User # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||
tier: str = result.scalar_one_or_none() or default_tier
|
||||
|
||||
# Fetch name/surname/avatar_url/onboarding_completed_at/password_hash from user row.
|
||||
user_result = await db.execute(
|
||||
select(
|
||||
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
||||
User.password_hash,
|
||||
).where(User.id == user_id)
|
||||
)
|
||||
user_row = user_result.one_or_none()
|
||||
|
||||
# Convert onboarding_completed_at to epoch ms (int) or None.
|
||||
onboarding_ms: int | None = None
|
||||
if user_row and user_row.onboarding_completed_at is not None:
|
||||
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
|
||||
|
||||
# Load decrypted core memory.
|
||||
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
|
||||
|
||||
memory_dict: dict[str, str] = {}
|
||||
try:
|
||||
mw = MemoryMiddleware(db)
|
||||
blocks = await mw.list_core_blocks(user_id)
|
||||
memory_dict = {b["label"]: b["value"] for b in blocks}
|
||||
except Exception:
|
||||
pass # Non-critical — return empty memory on failure
|
||||
|
||||
return UserProfile(
|
||||
id=user_id,
|
||||
email=email,
|
||||
name=user_row.name if user_row else None,
|
||||
surname=user_row.surname if user_row else None,
|
||||
avatar_url=user_row.avatar_url if user_row else None,
|
||||
has_password=bool(user_row.password_hash) if user_row else False,
|
||||
tier=tier,
|
||||
onboarding_completed_at=onboarding_ms,
|
||||
memory=memory_dict,
|
||||
) # type: ignore[arg-type]
|
||||
129
api/app/api/middleware/rate_limit.py
Normal file
129
api/app/api/middleware/rate_limit.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Tier-aware rate limiting middleware.
|
||||
|
||||
Uses a per-user sliding-window counter (in-process, no Redis required).
|
||||
The ``slowapi`` Limiter is also exported for optional route-level decoration.
|
||||
|
||||
Limits (requests per minute):
|
||||
- free: 20
|
||||
- pro: 60
|
||||
- power: 120
|
||||
- team: 200
|
||||
|
||||
Exempt paths bypass the limiter entirely:
|
||||
- POST /api/v1/auth/register
|
||||
- POST /api/v1/auth/login
|
||||
- POST /api/v1/billing/webhook
|
||||
- GET /api/v1/health
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
from fastapi import Request, Response
|
||||
from jose import JWTError, jwt
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
_TIER_LIMITS: dict[str, int] = {
|
||||
"free": 20,
|
||||
"pro": 60,
|
||||
"power": 120,
|
||||
"team": 200,
|
||||
}
|
||||
|
||||
_EXEMPT_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/billing/webhook",
|
||||
"/api/v1/health",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _get_user_id_from_jwt(request: Request) -> str:
|
||||
"""Key function for the slowapi Limiter: returns JWT sub or remote IP."""
|
||||
auth = request.headers.get("Authorization", "")
|
||||
token = auth.removeprefix("Bearer ").strip()
|
||||
if not token:
|
||||
return get_remote_address(request)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
return payload.get("sub") or get_remote_address(request)
|
||||
except JWTError:
|
||||
return get_remote_address(request)
|
||||
|
||||
|
||||
# Exported Limiter instance — available for optional route-level decoration.
|
||||
limiter = Limiter(key_func=_get_user_id_from_jwt)
|
||||
|
||||
|
||||
class TierRateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Sliding-window rate limiter applied globally across all non-exempt routes.
|
||||
|
||||
Each authenticated user gets their own 60-second window sized by tier.
|
||||
Unauthenticated requests pass through (the auth dependency will reject them
|
||||
with 401 before the route handler runs).
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(app)
|
||||
# user_id → list of request timestamps (float, seconds since epoch)
|
||||
self._window: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
|
||||
if request.url.path in _EXEMPT_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
# Extract JWT claims — if no valid token, pass through for auth dep to handle.
|
||||
auth = request.headers.get("Authorization", "")
|
||||
token = auth.removeprefix("Bearer ").strip()
|
||||
if not token:
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
user_id: str = payload.get("sub") or get_remote_address(request)
|
||||
tier: str = payload.get("tier", "free")
|
||||
except JWTError:
|
||||
return await call_next(request)
|
||||
|
||||
limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"])
|
||||
now = time.monotonic()
|
||||
window_start = now - 60.0
|
||||
|
||||
# Slide the window: discard timestamps older than 60 seconds.
|
||||
timestamps = [t for t in self._window[user_id] if t > window_start]
|
||||
|
||||
if len(timestamps) >= limit:
|
||||
retry_after = max(1, int(60 - (now - min(timestamps))))
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"detail": (
|
||||
f"Rate limit exceeded ({limit} req/min for {tier} tier). "
|
||||
f"Retry in {retry_after}s."
|
||||
)
|
||||
}
|
||||
),
|
||||
status_code=429,
|
||||
headers={
|
||||
"Retry-After": str(retry_after),
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
timestamps.append(now)
|
||||
self._window[user_id] = timestamps
|
||||
return await call_next(request)
|
||||
138
api/app/api/middleware/sanitizer.py
Normal file
138
api/app/api/middleware/sanitizer.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Response sanitizer middleware.
|
||||
|
||||
Scans JSON responses from the /api/v1/chat endpoint and strips any fragments
|
||||
that could reveal server-side prompt IP:
|
||||
- System prompt openers ("You are a/an/the …")
|
||||
- Agent routing metadata ("Available agents:", "intent classifier", …)
|
||||
- LangChain tool schema fragments (``"type": "function"``)
|
||||
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
|
||||
- Exact-match known prompt fingerprints
|
||||
|
||||
The middleware only activates for paths under /api/v1/chat.
|
||||
|
||||
Any sanitisation event is logged as a WARNING with the request path and the
|
||||
names of the fields that were modified.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detection patterns — order matters: fingerprints checked first (exact),
|
||||
# then compiled regexes.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FINGERPRINTS: tuple[str, ...] = (
|
||||
"You are an intent classifier",
|
||||
"Respond with just the agent name",
|
||||
"Summarize these agent results",
|
||||
"Available agents:",
|
||||
"route to:",
|
||||
)
|
||||
|
||||
_PATTERNS: tuple[re.Pattern[str], ...] = (
|
||||
re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"Available agents\s*:", re.IGNORECASE),
|
||||
re.compile(r"\bintent classifier\b", re.IGNORECASE),
|
||||
re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema
|
||||
re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE),
|
||||
re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers
|
||||
re.compile(r"route\s+to\s*:", re.IGNORECASE),
|
||||
re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE),
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_text(text: str) -> tuple[str, bool]:
|
||||
"""Scan *text* for prompt fragments and replace matches with ``[REDACTED]``.
|
||||
|
||||
Returns ``(cleaned_text, was_changed)``.
|
||||
"""
|
||||
# Fingerprint check — if any exact phrase is present, redact the whole string.
|
||||
for fp in _FINGERPRINTS:
|
||||
if fp in text:
|
||||
return "[REDACTED]", True
|
||||
|
||||
changed = False
|
||||
for pattern in _PATTERNS:
|
||||
new_text, n = pattern.subn("[REDACTED]", text)
|
||||
if n:
|
||||
text = new_text
|
||||
changed = True
|
||||
|
||||
return text, changed
|
||||
|
||||
|
||||
class SanitizerMiddleware(BaseHTTPMiddleware):
|
||||
"""Strip prompt IP from /api/v1/chat JSON responses."""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
|
||||
response: Response = await call_next(request)
|
||||
|
||||
# Only process chat endpoint responses.
|
||||
if not request.url.path.startswith("/api/v1/chat"):
|
||||
return response
|
||||
|
||||
# Read body — collect streaming chunks.
|
||||
body_bytes = b""
|
||||
async for chunk in response.body_iterator:
|
||||
body_bytes += chunk if isinstance(chunk, bytes) else chunk.encode()
|
||||
|
||||
# Skip non-JSON bodies (shouldn't happen on /chat, but be safe).
|
||||
try:
|
||||
body = json.loads(body_bytes.decode("utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
return Response(
|
||||
content=body_bytes,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.media_type,
|
||||
)
|
||||
|
||||
if not isinstance(body, dict):
|
||||
return Response(
|
||||
content=body_bytes,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.media_type,
|
||||
)
|
||||
|
||||
# Walk top-level string fields and sanitise.
|
||||
sanitised_fields: list[str] = []
|
||||
for key, value in body.items():
|
||||
if isinstance(value, str):
|
||||
cleaned, changed = _sanitize_text(value)
|
||||
if changed:
|
||||
body[key] = cleaned
|
||||
sanitised_fields.append(key)
|
||||
|
||||
if sanitised_fields:
|
||||
logger.warning(
|
||||
"Sanitizer redacted prompt fragments",
|
||||
extra={
|
||||
"path": request.url.path,
|
||||
"fields": sanitised_fields,
|
||||
},
|
||||
)
|
||||
|
||||
new_body = json.dumps(body).encode("utf-8")
|
||||
headers = dict(response.headers)
|
||||
headers["content-length"] = str(len(new_body))
|
||||
|
||||
return Response(
|
||||
content=new_body,
|
||||
status_code=response.status_code,
|
||||
headers=headers,
|
||||
media_type="application/json",
|
||||
)
|
||||
0
api/app/api/routes/__init__.py
Normal file
0
api/app/api/routes/__init__.py
Normal file
795
api/app/api/routes/auth.py
Normal file
795
api/app/api/routes/auth.py
Normal file
@@ -0,0 +1,795 @@
|
||||
"""Auth routes: register, login, refresh, me, OAuth social login, onboarding.
|
||||
|
||||
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
||||
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
||||
SHA-256 hashes so plaintext never reaches the DB.
|
||||
|
||||
OAuth (Google):
|
||||
GET /auth/oauth/{provider}/authorize — returns consent-screen URL + state
|
||||
POST /auth/oauth/{provider}/callback — exchanges code, issues JWT tokens
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Literal
|
||||
|
||||
import bcrypt
|
||||
from cryptography.fernet import Fernet
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.auth.oauth_providers import GoogleOAuthProvider, generate_pkce_pair
|
||||
from app.config.settings import settings
|
||||
from app.core.llm import get_llm
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.models import OAuthAccount, RefreshToken, User
|
||||
from app.schemas import AuthTokens, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
# ── OAuth provider registry ───────────────────────────────────────────
|
||||
|
||||
def _get_google_provider() -> GoogleOAuthProvider:
|
||||
if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
"Google login is not configured on this server",
|
||||
)
|
||||
return GoogleOAuthProvider(
|
||||
client_id=settings.GOOGLE_AUTH_CLIENT_ID,
|
||||
client_secret=settings.GOOGLE_AUTH_CLIENT_SECRET,
|
||||
redirect_uri=settings.OAUTH_REDIRECT_URI,
|
||||
)
|
||||
|
||||
|
||||
_PROVIDERS = {"google": _get_google_provider}
|
||||
|
||||
# In-memory state store: state → (code_verifier, expires_at_epoch_s)
|
||||
# Production note: replace with Redis for multi-process deployments.
|
||||
_pending_states: dict[str, tuple[str, float]] = {}
|
||||
_STATE_TTL_SECONDS = 600 # 10 minutes
|
||||
|
||||
|
||||
# ── Internal helpers ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
|
||||
def _verify_password(password: str, hashed: str) -> bool:
|
||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||
|
||||
|
||||
def _hash_token(plain_token: str) -> str:
|
||||
"""SHA-256 of the plain refresh token string."""
|
||||
return hashlib.sha256(plain_token.encode()).hexdigest()
|
||||
|
||||
|
||||
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
||||
"""Return (signed JWT, expires_at_ms)."""
|
||||
now = int(time.time())
|
||||
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"email": email,
|
||||
"tier": tier,
|
||||
"exp": exp,
|
||||
"iat": now,
|
||||
}
|
||||
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||
return token, exp * 1000 # ms for client
|
||||
|
||||
|
||||
# ── Request bodies ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _RegisterRequest(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
name: str | None = None
|
||||
surname: str | None = None
|
||||
|
||||
|
||||
class _LoginRequest(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
|
||||
|
||||
class _RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
||||
async def register(
|
||||
body: _RegisterRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Create a new account and return JWT tokens."""
|
||||
existing = await db.execute(select(User).where(User.email == body.email))
|
||||
if existing.scalar_one_or_none() is not None:
|
||||
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
||||
|
||||
user = User(
|
||||
id=str(uuid.uuid4()),
|
||||
email=body.email,
|
||||
name=body.name,
|
||||
surname=body.surname,
|
||||
password_hash=_hash_password(body.password),
|
||||
tier="free",
|
||||
encryption_key=Fernet.generate_key().decode(),
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush() # get user.id without committing
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=AuthTokens)
|
||||
async def login(
|
||||
body: _LoginRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Validate credentials and return JWT tokens."""
|
||||
result = await db.execute(select(User).where(User.email == body.email))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not _verify_password(body.password, user.password_hash):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=AuthTokens)
|
||||
async def refresh(
|
||||
body: _RefreshRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Rotate a refresh token and return a new token pair."""
|
||||
token_hash = _hash_token(body.refresh_token)
|
||||
result = await db.execute(
|
||||
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||
)
|
||||
rt = result.scalar_one_or_none()
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
||||
|
||||
# Rotate: delete old token, issue new one.
|
||||
await db.delete(rt)
|
||||
|
||||
user_result = await db.execute(select(User).where(User.id == rt.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
new_rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=new_expires,
|
||||
)
|
||||
db.add(new_rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
class _UpdateProfileRequest(BaseModel):
|
||||
name: str | None = None
|
||||
surname: str | None = None
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserProfile)
|
||||
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
|
||||
"""Return the profile for the authenticated user."""
|
||||
return current_user
|
||||
|
||||
|
||||
@router.put("/me", response_model=UserProfile)
|
||||
async def update_profile(
|
||||
body: _UpdateProfileRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Update the authenticated user's name and surname."""
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
|
||||
if body.name is not None:
|
||||
user.name = body.name
|
||||
if body.surname is not None:
|
||||
user.surname = body.surname
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
return UserProfile(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
name=user.name,
|
||||
surname=user.surname,
|
||||
avatar_url=user.avatar_url,
|
||||
tier=current_user.tier,
|
||||
)
|
||||
|
||||
|
||||
# ── OAuth helpers ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _issue_refresh_token(user: User, db: AsyncSession) -> tuple[str, AuthTokens]:
|
||||
"""Create a refresh token row and return (plain_token, AuthTokens)."""
|
||||
plain_token = str(uuid.uuid4())
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(rt)
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return plain_token, AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
# ── OAuth request/response schemas ───────────────────────────────────
|
||||
|
||||
|
||||
class _OAuthAuthorizeResponse(BaseModel):
|
||||
url: str
|
||||
state: str
|
||||
|
||||
|
||||
class _OAuthCallbackRequest(BaseModel):
|
||||
code: str
|
||||
state: str
|
||||
|
||||
|
||||
# ── OAuth routes ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get(
|
||||
"/oauth/{provider}/web-callback",
|
||||
summary="Web-facing OAuth redirect — bounces to the adiuvai:// deep link",
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def oauth_web_callback(
|
||||
provider: Literal["google"],
|
||||
code: str,
|
||||
state: str,
|
||||
) -> RedirectResponse:
|
||||
"""Google redirects here after user consent.
|
||||
|
||||
This endpoint immediately redirects to the Electron deep-link URI so the
|
||||
desktop app receives the authorization code. It is intentionally simple —
|
||||
no state validation here (the Electron app + backend callback do that).
|
||||
|
||||
Registered in Google Cloud Console as:
|
||||
http://localhost:8000/api/v1/auth/oauth/google/web-callback (dev)
|
||||
https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback (prod)
|
||||
"""
|
||||
params = urllib.parse.urlencode({"code": code, "state": state, "provider": provider})
|
||||
deep_link = f"adiuvai://oauth/callback?{params}"
|
||||
return RedirectResponse(url=deep_link, status_code=302)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/oauth/{provider}/authorize",
|
||||
response_model=_OAuthAuthorizeResponse,
|
||||
summary="Start OAuth flow — returns the provider consent-screen URL",
|
||||
)
|
||||
async def oauth_authorize(
|
||||
provider: Literal["google"],
|
||||
) -> _OAuthAuthorizeResponse:
|
||||
"""Generate a PKCE state + code_challenge and return the authorization URL.
|
||||
|
||||
The client opens this URL in the system browser. After the user grants
|
||||
consent, the provider redirects to the deep-link URI (adiuvai://oauth/callback)
|
||||
with ``code`` and ``state`` query params. The client then calls
|
||||
``POST /auth/oauth/{provider}/callback`` with those values.
|
||||
"""
|
||||
provider_factory = _PROVIDERS.get(provider)
|
||||
if provider_factory is None:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
|
||||
|
||||
oauth_provider = provider_factory()
|
||||
state = str(uuid.uuid4())
|
||||
code_verifier, code_challenge = generate_pkce_pair()
|
||||
|
||||
# Purge expired states to prevent unbounded growth.
|
||||
now = time.time()
|
||||
expired = [s for s, (_, exp) in _pending_states.items() if exp < now]
|
||||
for s in expired:
|
||||
del _pending_states[s]
|
||||
|
||||
_pending_states[state] = (code_verifier, now + _STATE_TTL_SECONDS)
|
||||
|
||||
url = oauth_provider.get_authorization_url(state=state, code_challenge=code_challenge)
|
||||
return _OAuthAuthorizeResponse(url=url, state=state)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/oauth/{provider}/callback",
|
||||
response_model=AuthTokens,
|
||||
summary="Complete OAuth flow — exchange code and issue JWT tokens",
|
||||
)
|
||||
async def oauth_callback(
|
||||
provider: Literal["google"],
|
||||
body: _OAuthCallbackRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Validate state, exchange the authorization code, and sign in (or register) the user.
|
||||
|
||||
Resolution order:
|
||||
1. ``oauth_accounts`` row match → existing user, log in.
|
||||
2. Email match + ``email_verified=True`` → link OAuth account to existing user.
|
||||
3. No match → create new user (password_hash=None, avatar from provider).
|
||||
"""
|
||||
provider_factory = _PROVIDERS.get(provider)
|
||||
if provider_factory is None:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
|
||||
|
||||
# Validate state (CSRF protection).
|
||||
now = time.time()
|
||||
entry = _pending_states.pop(body.state, None)
|
||||
if entry is None or entry[1] < now:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state")
|
||||
|
||||
code_verifier, _ = entry
|
||||
|
||||
oauth_provider = provider_factory()
|
||||
|
||||
# Exchange code for tokens.
|
||||
try:
|
||||
token_data = await oauth_provider.exchange_code(
|
||||
code=body.code,
|
||||
code_verifier=code_verifier,
|
||||
redirect_uri=settings.OAUTH_REDIRECT_URI,
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST, "Failed to exchange authorization code"
|
||||
)
|
||||
|
||||
access_token_google = token_data.get("access_token")
|
||||
if not access_token_google:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "No access token in provider response")
|
||||
|
||||
# Fetch user identity.
|
||||
try:
|
||||
userinfo = await oauth_provider.get_userinfo(access_token_google)
|
||||
except Exception:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Failed to fetch user info from provider")
|
||||
|
||||
# ── Resolution order ──────────────────────────────────────────────
|
||||
|
||||
# 1. Existing OAuth link?
|
||||
oauth_result = await db.execute(
|
||||
select(OAuthAccount).where(
|
||||
OAuthAccount.provider == provider,
|
||||
OAuthAccount.provider_user_id == userinfo.provider_user_id,
|
||||
)
|
||||
)
|
||||
oauth_account = oauth_result.scalar_one_or_none()
|
||||
|
||||
if oauth_account is not None:
|
||||
user_result = await db.execute(select(User).where(User.id == oauth_account.user_id))
|
||||
user = user_result.scalar_one()
|
||||
# Backfill avatar if the user doesn't have one yet.
|
||||
if user.avatar_url is None and userinfo.avatar_url:
|
||||
user.avatar_url = userinfo.avatar_url
|
||||
await db.commit()
|
||||
plain_token, tokens = await _issue_refresh_token(user, db)
|
||||
await db.commit()
|
||||
return tokens
|
||||
|
||||
# 2. Email match with a verified Google email → link accounts.
|
||||
if userinfo.email_verified:
|
||||
email_result = await db.execute(select(User).where(User.email == userinfo.email))
|
||||
existing_user = email_result.scalar_one_or_none()
|
||||
|
||||
if existing_user is not None:
|
||||
new_link = OAuthAccount(
|
||||
user_id=existing_user.id,
|
||||
provider=provider,
|
||||
provider_user_id=userinfo.provider_user_id,
|
||||
provider_email=userinfo.email,
|
||||
)
|
||||
db.add(new_link)
|
||||
if existing_user.avatar_url is None and userinfo.avatar_url:
|
||||
existing_user.avatar_url = userinfo.avatar_url
|
||||
plain_token, tokens = await _issue_refresh_token(existing_user, db)
|
||||
await db.commit()
|
||||
return tokens
|
||||
|
||||
# Guard: if the email is already taken but we couldn't auto-link (e.g.
|
||||
# email_verified=False), refuse with 409 instead of hitting a DB constraint.
|
||||
if not userinfo.email_verified:
|
||||
conflict = await db.execute(select(User).where(User.email == userinfo.email))
|
||||
if conflict.scalar_one_or_none() is not None:
|
||||
raise HTTPException(
|
||||
status.HTTP_409_CONFLICT,
|
||||
"An account with this email already exists. "
|
||||
"Please sign in with your password.",
|
||||
)
|
||||
|
||||
# 3. New user — social-only account (no password).
|
||||
new_user = User(
|
||||
id=str(uuid.uuid4()),
|
||||
email=userinfo.email,
|
||||
name=userinfo.name,
|
||||
password_hash=None,
|
||||
avatar_url=userinfo.avatar_url,
|
||||
tier="free",
|
||||
encryption_key=Fernet.generate_key().decode(),
|
||||
)
|
||||
db.add(new_user)
|
||||
await db.flush() # populate new_user.id
|
||||
|
||||
new_oauth = OAuthAccount(
|
||||
user_id=new_user.id,
|
||||
provider=provider,
|
||||
provider_user_id=userinfo.provider_user_id,
|
||||
provider_email=userinfo.email,
|
||||
)
|
||||
db.add(new_oauth)
|
||||
|
||||
plain_token, tokens = await _issue_refresh_token(new_user, db)
|
||||
await db.commit()
|
||||
return tokens
|
||||
|
||||
|
||||
# ── Onboarding helpers ────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _build_profile(user_id: str, email: str, db: AsyncSession) -> UserProfile:
|
||||
"""Re-fetch and return a full UserProfile (reuses get_current_user logic)."""
|
||||
|
||||
# We can't call the FastAPI dependency directly, but we can replicate
|
||||
# the core logic inline. Instead, we just re-query the same way.
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||
tier: str = result.scalar_one_or_none() or default_tier
|
||||
|
||||
user_result = await db.execute(
|
||||
select(
|
||||
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
||||
User.password_hash,
|
||||
).where(User.id == user_id)
|
||||
)
|
||||
user_row = user_result.one_or_none()
|
||||
|
||||
onboarding_ms: int | None = None
|
||||
if user_row and user_row.onboarding_completed_at is not None:
|
||||
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
|
||||
|
||||
memory_dict: dict[str, str] = {}
|
||||
try:
|
||||
mw = MemoryMiddleware(db)
|
||||
blocks = await mw.list_core_blocks(user_id)
|
||||
memory_dict = {b["label"]: b["value"] for b in blocks}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return UserProfile(
|
||||
id=user_id,
|
||||
email=email,
|
||||
name=user_row.name if user_row else None,
|
||||
surname=user_row.surname if user_row else None,
|
||||
avatar_url=user_row.avatar_url if user_row else None,
|
||||
has_password=bool(user_row.password_hash) if user_row else False,
|
||||
tier=tier,
|
||||
onboarding_completed_at=onboarding_ms,
|
||||
memory=memory_dict,
|
||||
)
|
||||
|
||||
|
||||
# ── Onboarding routes ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _UpdateMemoryRequest(BaseModel):
|
||||
memory: dict[str, str] = Field(default_factory=dict)
|
||||
mark_onboarded: bool = False
|
||||
|
||||
|
||||
@router.put("/me/memory", response_model=UserProfile)
|
||||
async def update_memory(
|
||||
body: _UpdateMemoryRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Update core memory key/value pairs and optionally mark onboarding complete."""
|
||||
mw = MemoryMiddleware(db)
|
||||
for key, value in body.memory.items():
|
||||
await mw.update_core(current_user.id, key, value)
|
||||
if body.mark_onboarded:
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
user.onboarding_completed_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
return await _build_profile(current_user.id, current_user.email, db)
|
||||
|
||||
|
||||
@router.post("/me/onboarding/reset")
|
||||
async def reset_onboarding(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Reset onboarding so the wizard runs again on next login."""
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
user.onboarding_completed_at = None
|
||||
await db.commit()
|
||||
return {"status": "reset"}
|
||||
|
||||
|
||||
class _NormalizeRequest(BaseModel):
|
||||
inputs: dict[str, str]
|
||||
|
||||
|
||||
class _NormalizeResponse(BaseModel):
|
||||
normalized: dict[str, str]
|
||||
|
||||
|
||||
@router.post("/onboarding/normalize", response_model=_NormalizeResponse)
|
||||
async def normalize_onboarding(
|
||||
body: _NormalizeRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> _NormalizeResponse:
|
||||
"""One-shot LLM normalization for free-text onboarding answers."""
|
||||
if not body.inputs:
|
||||
return _NormalizeResponse(normalized={})
|
||||
try:
|
||||
llm = get_llm(model="gpt-4o-mini", temperature=0)
|
||||
prompt = (
|
||||
"You normalize user onboarding answers into clean, ≤3-word canonical labels.\n"
|
||||
"Return a JSON object with the same keys and normalized values.\n"
|
||||
"Examples: 'i build websites' → 'Web Developer', 'tech-ish stuff' → 'Technology'\n"
|
||||
f"Input: {json.dumps(body.inputs)}"
|
||||
)
|
||||
response = await llm.ainvoke(
|
||||
[
|
||||
{"role": "system", "content": "You normalize user inputs. Return JSON only."},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
)
|
||||
normalized = json.loads(response.content)
|
||||
return _NormalizeResponse(normalized=normalized)
|
||||
except Exception:
|
||||
# LLM failure must never block onboarding — return inputs unchanged
|
||||
return _NormalizeResponse(normalized=body.inputs)
|
||||
|
||||
|
||||
# ── Password management ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class _ChangePasswordRequest(BaseModel):
|
||||
current_password: str = Field(min_length=1)
|
||||
new_password: str = Field(min_length=8)
|
||||
|
||||
|
||||
@router.put("/me/password", status_code=status.HTTP_200_OK)
|
||||
async def change_password(
|
||||
body: _ChangePasswordRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Change the authenticated user's password.
|
||||
|
||||
Requires the current password for verification.
|
||||
Returns 400 for social-only users (no password set).
|
||||
"""
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
|
||||
if user.password_hash is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
"This account uses social login and has no password to change",
|
||||
)
|
||||
|
||||
if not _verify_password(body.current_password, user.password_hash):
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Current password is incorrect")
|
||||
|
||||
user.password_hash = _hash_password(body.new_password)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── OAuth account management ─────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/me/oauth-accounts", response_model=list[dict])
|
||||
async def list_oauth_accounts(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[dict]:
|
||||
"""List all OAuth providers linked to the authenticated user."""
|
||||
result = await db.execute(
|
||||
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
|
||||
)
|
||||
accounts = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"provider": a.provider,
|
||||
"provider_email": a.provider_email,
|
||||
"created_at": int(a.created_at.timestamp() * 1000),
|
||||
}
|
||||
for a in accounts
|
||||
]
|
||||
|
||||
|
||||
@router.delete("/me/oauth-accounts/{provider}", status_code=status.HTTP_200_OK)
|
||||
async def unlink_oauth_account(
|
||||
provider: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Unlink an OAuth provider from the authenticated user.
|
||||
|
||||
Refuses if the user has no password and this is their only login method.
|
||||
"""
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
|
||||
oauth_result = await db.execute(
|
||||
select(OAuthAccount).where(
|
||||
OAuthAccount.user_id == current_user.id,
|
||||
OAuthAccount.provider == provider,
|
||||
)
|
||||
)
|
||||
account = oauth_result.scalar_one_or_none()
|
||||
if account is None:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, f"No linked {provider} account found")
|
||||
|
||||
# Safety: don't let users lock themselves out.
|
||||
all_oauth = await db.execute(
|
||||
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
|
||||
)
|
||||
oauth_count = len(all_oauth.scalars().all())
|
||||
|
||||
if user.password_hash is None and oauth_count <= 1:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
"Cannot unlink the only login method. Set a password first.",
|
||||
)
|
||||
|
||||
await db.delete(account)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── Avatar update ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _UpdateAvatarRequest(BaseModel):
|
||||
avatar_url: str = Field(min_length=1)
|
||||
|
||||
|
||||
@router.put("/me/avatar", response_model=UserProfile)
|
||||
async def update_avatar(
|
||||
body: _UpdateAvatarRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Update the authenticated user's avatar URL.
|
||||
|
||||
Accepts {"avatar_url": "https://..."} — the client uploads the image
|
||||
to its own storage and passes the resulting URL here.
|
||||
"""
|
||||
if not body.avatar_url.startswith(("https://", "http://", "data:image/")):
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid avatar URL")
|
||||
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
user.avatar_url = body.avatar_url
|
||||
await db.commit()
|
||||
|
||||
return await _build_profile(current_user.id, current_user.email, db)
|
||||
|
||||
|
||||
# ── Account deletion ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.delete("/me", status_code=status.HTTP_200_OK)
|
||||
async def delete_account(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Permanently delete the authenticated user's account.
|
||||
|
||||
Cascades: refresh tokens, OAuth accounts, subscription, and all memory
|
||||
rows are deleted via SQLAlchemy relationship cascades. Stripe subscription
|
||||
is cancelled if active.
|
||||
"""
|
||||
# Cancel Stripe subscription if present.
|
||||
try:
|
||||
from app.billing.stripe_service import stripe_service # noqa: PLC0415
|
||||
await stripe_service.cancel_subscription(current_user.id, db)
|
||||
except HTTPException:
|
||||
pass # No subscription — that's fine
|
||||
|
||||
# Delete all memory rows (core, associative, episodic, proactive).
|
||||
try:
|
||||
from app.models import ( # noqa: PLC0415
|
||||
MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive,
|
||||
)
|
||||
for model in (MemoryCore, MemoryAssociative, MemoryEpisodic, MemoryProactive):
|
||||
await db.execute(
|
||||
model.__table__.delete().where(model.user_id == current_user.id)
|
||||
)
|
||||
except Exception:
|
||||
pass # Non-critical — cascade on User will handle most
|
||||
|
||||
# Delete the user row — cascades handle refresh_tokens, oauth_accounts, subscription.
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
await db.delete(user)
|
||||
await db.commit()
|
||||
|
||||
return {"ok": True}
|
||||
132
api/app/api/routes/billing.py
Normal file
132
api/app/api/routes/billing.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Billing routes: Stripe checkout, webhook, subscription management.
|
||||
|
||||
Business logic lives in ``app.billing.stripe_service.StripeService``.
|
||||
The route layer handles HTTP concerns (request parsing, response shaping)
|
||||
and delegates everything else to the service singleton.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.billing.stripe_service import stripe_service
|
||||
from app.db import get_session
|
||||
from app.schemas import BillingTier, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||
|
||||
|
||||
# ── Request bodies ─────────────────────────────────────────────────────
|
||||
|
||||
class _CheckoutRequest(BaseModel):
|
||||
tier: BillingTier
|
||||
|
||||
|
||||
# ── Routes ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/checkout", response_model=dict)
|
||||
async def create_checkout(
|
||||
body: _CheckoutRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, str]:
|
||||
"""Create a Stripe checkout session for a tier upgrade.
|
||||
|
||||
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
|
||||
"""
|
||||
url = stripe_service.create_checkout_session(current_user.id, body.tier)
|
||||
return {"checkout_url": url}
|
||||
|
||||
|
||||
@router.post("/webhook", response_model=dict)
|
||||
async def stripe_webhook(
|
||||
request: Request,
|
||||
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Handle Stripe webhook events.
|
||||
|
||||
No JWT auth — authenticated via Stripe signature verification instead.
|
||||
Returns 200 immediately when Stripe is not configured (local dev).
|
||||
"""
|
||||
payload = await request.body()
|
||||
await stripe_service.handle_webhook(payload, stripe_signature, db)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/subscription", response_model=dict)
|
||||
async def get_subscription(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, Any]:
|
||||
"""Return the current subscription info for the authenticated user."""
|
||||
sub = await stripe_service.get_subscription(current_user.id, db)
|
||||
if sub is None:
|
||||
return {
|
||||
"tier": current_user.tier,
|
||||
"status": "free",
|
||||
"stripe_subscription_id": None,
|
||||
"current_period_end": None,
|
||||
}
|
||||
return sub
|
||||
|
||||
|
||||
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
|
||||
async def cancel_subscription(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Cancel the active subscription."""
|
||||
await stripe_service.cancel_subscription(current_user.id, db)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/invoices", response_model=list[dict])
|
||||
async def list_invoices(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return billing history (invoices) from Stripe.
|
||||
|
||||
Returns an empty list when Stripe is not configured.
|
||||
"""
|
||||
invoices = await stripe_service.list_invoices(current_user.id, db)
|
||||
return invoices
|
||||
|
||||
|
||||
# ── Quota check ────────────────────────────────────────────────────────
|
||||
|
||||
from app.billing.quota import check_folder_quota, QuotaExceeded # noqa: E402
|
||||
|
||||
|
||||
class QuotaCheckRequest(BaseModel):
|
||||
feature: str
|
||||
estimated_files: int
|
||||
|
||||
|
||||
@router.post("/quota/check")
|
||||
async def quota_check(
|
||||
payload: QuotaCheckRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict:
|
||||
"""Pre-flight folder quota check. 402 if tier limits would be exceeded."""
|
||||
if payload.feature != "folder_index":
|
||||
raise HTTPException(status_code=400, detail="Unknown feature")
|
||||
try:
|
||||
await check_folder_quota(
|
||||
user_id=current_user.id,
|
||||
tier=current_user.tier,
|
||||
estimated_files=payload.estimated_files,
|
||||
db=db,
|
||||
)
|
||||
except QuotaExceeded as exc:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail={"reason": exc.reason, "message": str(exc)},
|
||||
)
|
||||
return {"ok": True}
|
||||
116
api/app/api/routes/chat.py
Normal file
116
api/app/api/routes/chat.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Chat routes: POST /chat (REST fallback) and POST /chat/embed (text → vector).
|
||||
|
||||
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.brief_agent import run_home_brief, run_project_brief
|
||||
from app.core.deep_agent import run_home
|
||||
from app.core.llm import embed
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import async_session
|
||||
from app.schemas import ChatRequest, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
|
||||
|
||||
# ── Embed helpers ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _EmbedRequest(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class _EmbedResponse(BaseModel):
|
||||
vector: list[float]
|
||||
|
||||
|
||||
# ── Endpoints ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def chat(
|
||||
body: ChatRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> JSONResponse:
|
||||
"""REST fallback for home chat when websocket streaming is unavailable."""
|
||||
response = await run_home(
|
||||
user_id=current_user.id,
|
||||
message=body.message,
|
||||
context=body.context.model_dump(),
|
||||
)
|
||||
return JSONResponse(content={"response": response})
|
||||
|
||||
|
||||
class _BriefRequest(BaseModel):
|
||||
mode: Literal["home", "project"]
|
||||
project_id: str | None = None
|
||||
|
||||
|
||||
class _BriefResponse(BaseModel):
|
||||
response: str
|
||||
|
||||
|
||||
@router.post("/brief", response_model=_BriefResponse)
|
||||
async def brief(
|
||||
body: _BriefRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> _BriefResponse:
|
||||
"""REST fallback for brief when the device WebSocket is not ready."""
|
||||
if body.mode == "project":
|
||||
if not body.project_id:
|
||||
raise HTTPException(status_code=422, detail="project_id required for project mode")
|
||||
try:
|
||||
uuid.UUID(body.project_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=422, detail="project_id must be a valid UUID")
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(
|
||||
current_user.id,
|
||||
"",
|
||||
trace_id=request_id,
|
||||
session_id=request_id,
|
||||
)
|
||||
|
||||
context: dict = {
|
||||
"_debug": {"request_id": request_id, "user_id": current_user.id},
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
chunks: list[str] = []
|
||||
if body.mode == "project":
|
||||
stream = run_project_brief(current_user.id, body.project_id, context) # type: ignore[arg-type]
|
||||
else:
|
||||
stream = run_home_brief(current_user.id, context)
|
||||
|
||||
async for event_type, data in stream:
|
||||
if event_type == "token" and data:
|
||||
chunks.append(str(data))
|
||||
|
||||
return _BriefResponse(response="".join(chunks))
|
||||
|
||||
|
||||
@router.post("/embed", response_model=_EmbedResponse)
|
||||
async def embed_text(
|
||||
body: _EmbedRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> _EmbedResponse:
|
||||
"""Generate a 1536-dim embedding vector for the given text.
|
||||
|
||||
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
|
||||
Used by Electron (vectordb.ts) for local note search.
|
||||
"""
|
||||
vector = await embed(body.text)
|
||||
return _EmbedResponse(vector=vector)
|
||||
864
api/app/api/routes/device_ws.py
Normal file
864
api/app/api/routes/device_ws.py
Normal file
@@ -0,0 +1,864 @@
|
||||
"""Device WebSocket endpoint.
|
||||
|
||||
Persistent connection from Electron devices to the backend.
|
||||
|
||||
WS /api/v1/ws/device?token=<jwt>
|
||||
|
||||
Auth: JWT passed as ``?token=`` query parameter (Bearer header is not
|
||||
available during the WebSocket handshake).
|
||||
|
||||
Protocol:
|
||||
1. Client connects → JWT validated → connection accepted.
|
||||
2. Client sends ``device_hello`` frame: ``{ type, device_id, scout_ids }``.
|
||||
3. Backend registers the connection in ``DeviceConnectionManager``.
|
||||
4. Session enters message dispatch loop + heartbeat.
|
||||
|
||||
Incoming frame dispatch:
|
||||
- ``tool_result`` → resolves a pending tool-call Future.
|
||||
- ``journey_start`` → starts a guided setup journey session.
|
||||
- ``journey_message`` → continues a journey conversation.
|
||||
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
||||
- unknown types → logged, ignored.
|
||||
|
||||
Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.
|
||||
|
||||
On disconnect:
|
||||
- Unregisters from DeviceConnectionManager.
|
||||
- Marks all in-progress AgentRunLog rows for this user as ``error``
|
||||
with message "device disconnected".
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy import update
|
||||
|
||||
from app.api.routes.scout_setup import handle_journey_message, handle_journey_start
|
||||
from app.config.settings import settings
|
||||
from app.scouts.engine import ScoutEngine
|
||||
from app.core.scout_runner import trigger_pending_runs
|
||||
from app.core.scout_session_buffer import session_buffer
|
||||
from app.core.brief_agent import run_home_brief, run_project_brief
|
||||
from app.core.deep_agent import run_contextual_stream, run_home_stream, run_task_brief_research_stream
|
||||
from app.core.output_formatter import extract_canvas_block
|
||||
from app.core.device_manager import device_manager
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.core.output_formatter import StreamFormatter
|
||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||
from app.db import async_session
|
||||
from app.models import ScoutRunLog
|
||||
from app.schemas import WsFrameType, WsStreamEnd
|
||||
from app.schemas.contextual import ContextualScope, render_scope_block
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/ws", tags=["device-ws"])
|
||||
|
||||
# ── v7 folder index session state ─────────────────────────────────────
|
||||
# Keyed by sessionId; value: { user_id, project_id, processed, total, cancelled }
|
||||
_index_sessions: dict[str, dict] = {}
|
||||
|
||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
|
||||
|
||||
|
||||
@router.websocket("/device")
|
||||
async def device_ws(websocket: WebSocket) -> None:
|
||||
"""Persistent WebSocket endpoint for Electron device connections.
|
||||
|
||||
Authentication is via ``?token=<jwt>`` query parameter.
|
||||
"""
|
||||
# ── 1. Authenticate before accepting ─────────────────────────────
|
||||
token = websocket.query_params.get("token", "")
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
user_id: str | None = payload.get("sub")
|
||||
if not user_id:
|
||||
raise JWTError("missing sub")
|
||||
except JWTError:
|
||||
await websocket.close(code=1008) # Policy Violation
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
# ── 2. Await device_hello frame ───────────────────────────────────
|
||||
try:
|
||||
raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
|
||||
except (asyncio.TimeoutError, WebSocketDisconnect):
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
try:
|
||||
hello = json.loads(raw)
|
||||
if hello.get("type") != WsFrameType.device_hello:
|
||||
raise ValueError("expected device_hello as first frame")
|
||||
device_id: str = hello["device_id"]
|
||||
scout_ids: list[str] = hello.get("scout_ids", [])
|
||||
except (KeyError, ValueError, json.JSONDecodeError) as exc:
|
||||
logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
# ── 3. Register connection ────────────────────────────────────────
|
||||
device_manager.register(user_id, device_id, websocket)
|
||||
logger.info(
|
||||
"device_ws: connected user=%s device=%s scouts=%s",
|
||||
user_id,
|
||||
device_id,
|
||||
scout_ids,
|
||||
)
|
||||
|
||||
# Trigger any overdue agent runs now that the device is connected.
|
||||
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
|
||||
|
||||
# Drain any queued scout proposals and deliver to the client (non-blocking).
|
||||
async def _deliver_pending_safe() -> None:
|
||||
import uuid as _uuid # noqa: PLC0415
|
||||
try:
|
||||
await ScoutEngine().deliver_pending(_uuid.UUID(user_id), websocket)
|
||||
except Exception:
|
||||
logger.exception("scout deliver_pending failed for user %s", user_id)
|
||||
|
||||
asyncio.create_task(_deliver_pending_safe())
|
||||
|
||||
# ── 4. Concurrent message loop + heartbeat ────────────────────────
|
||||
try:
|
||||
await asyncio.gather(
|
||||
_message_loop(websocket, user_id),
|
||||
_heartbeat_loop(websocket),
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("device_ws: unhandled exception user=%s: %s", user_id, exc)
|
||||
finally:
|
||||
device_manager.unregister(user_id)
|
||||
logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id)
|
||||
await _mark_runs_disconnected(user_id)
|
||||
|
||||
|
||||
# ── Message dispatch loop ─────────────────────────────────────────────
|
||||
|
||||
async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
||||
"""Receive frames from Electron and dispatch to the appropriate handler."""
|
||||
async for raw in websocket.iter_text():
|
||||
try:
|
||||
frame: dict = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("device_ws: invalid JSON from user=%s", user_id)
|
||||
continue
|
||||
|
||||
frame_type = frame.get("type")
|
||||
|
||||
if frame_type == WsFrameType.tool_result:
|
||||
call_id = frame.get("id")
|
||||
if call_id:
|
||||
device_manager.resolve_pending_call(user_id, call_id, frame)
|
||||
else:
|
||||
logger.warning(
|
||||
"device_ws: tool_result missing id from user=%s", user_id
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.home_request:
|
||||
asyncio.create_task(
|
||||
_handle_home_request(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.brief_request:
|
||||
asyncio.create_task(
|
||||
_handle_brief_request(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.task_brief_request:
|
||||
asyncio.create_task(
|
||||
_handle_task_brief_request(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.journey_start:
|
||||
asyncio.create_task(
|
||||
_handle_journey_start(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.journey_message:
|
||||
asyncio.create_task(
|
||||
_handle_journey_message(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.index_session_start:
|
||||
asyncio.create_task(
|
||||
_handle_index_session_start(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.index_file_batch:
|
||||
asyncio.create_task(
|
||||
_handle_index_file_batch(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.index_session_cancel:
|
||||
await _handle_index_session_cancel(websocket, frame)
|
||||
|
||||
elif frame_type == WsFrameType.contextual_request:
|
||||
asyncio.create_task(
|
||||
_handle_contextual_request(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.contextual_scope_update:
|
||||
asyncio.create_task(
|
||||
_handle_contextual_scope_update(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == "scout_proposal_ack":
|
||||
proposal_id = frame.get("proposal_id")
|
||||
if proposal_id:
|
||||
try:
|
||||
await ScoutEngine().ack_proposal(proposal_id)
|
||||
except Exception:
|
||||
logger.exception("scout ack_proposal failed for %s", proposal_id)
|
||||
|
||||
elif frame_type == "pong":
|
||||
# Heartbeat ack — nothing to do, connection is alive.
|
||||
pass
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
"device_ws: unknown frame type %r from user=%s", frame_type, user_id
|
||||
)
|
||||
|
||||
|
||||
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
||||
|
||||
async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
||||
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
||||
async def _executor(payload: dict) -> dict:
|
||||
payload["type"] = WsFrameType.tool_call
|
||||
await websocket.send_text(json.dumps(payload))
|
||||
future = device_manager.create_pending_call(user_id, payload["id"])
|
||||
return await future
|
||||
return _executor
|
||||
|
||||
|
||||
async def _handle_home_request(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a home_request frame — streams HomeFormatter output back on the socket."""
|
||||
request_id = frame.get("request_id") or str(uuid4())
|
||||
message: str = frame.get("message", "")
|
||||
session_id: str = frame.get("session_id") or str(uuid4())
|
||||
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
|
||||
logger.info(
|
||||
"device_ws: home_request_start user=%s req=%s session=%s project=%s msg=%s",
|
||||
user_id,
|
||||
request_id,
|
||||
session_id,
|
||||
project_id,
|
||||
message[:200],
|
||||
)
|
||||
|
||||
# ── Memory: enrich context before LLM call ────────────────────────
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(
|
||||
user_id,
|
||||
message,
|
||||
trace_id=request_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
context: dict = {
|
||||
"conversation_history": frame.get("conversation_history", []),
|
||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||
"format_prefs": frame.get("format_prefs"),
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
response_chunks: list[str] = []
|
||||
try:
|
||||
event_stream = run_home_stream(user_id, message, context, project_id=project_id)
|
||||
formatter = StreamFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
# Collect text chunks to build the full response for episode storage
|
||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: home_request failed user=%s req=%s: %s",
|
||||
user_id, request_id, exc,
|
||||
)
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
# ── Memory: store episode after response ──────────────────────────
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
await memory.store_episode(
|
||||
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||
)
|
||||
logger.info(
|
||||
"device_ws: home_request_end user=%s req=%s session=%s response_chars=%d",
|
||||
user_id,
|
||||
request_id,
|
||||
session_id,
|
||||
len("".join(response_chunks)),
|
||||
)
|
||||
|
||||
|
||||
# ── v8 Contextual Sidebar Handlers ───────────────────────────────────
|
||||
|
||||
|
||||
def get_session_buffer(user_id: str, session_id: str, channel: str = "contextual"):
|
||||
"""Return a session-scoped buffer proxy for the given user+session.
|
||||
|
||||
Returns a _ContextualBufferProxy that exposes append_system_message().
|
||||
Defined at module level so tests can monkeypatch it.
|
||||
The channel kwarg is accepted for forward-compatibility.
|
||||
"""
|
||||
from app.core.scout_session_buffer import ContextualBufferProxy # noqa: PLC0415
|
||||
return ContextualBufferProxy(session_buffer, user_id, session_id)
|
||||
|
||||
|
||||
async def _handle_contextual_request(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a contextual_request frame — runs the contextual agent and streams frames."""
|
||||
request_id = frame.get("request_id") or str(uuid4())
|
||||
message: str = frame.get("message", "")
|
||||
session_id: str = frame.get("session_id") or str(uuid4())
|
||||
scope_payload: dict = frame.get("scope", {})
|
||||
logger.info(
|
||||
"device_ws: contextual_request_start user=%s req=%s session=%s msg=%s",
|
||||
user_id,
|
||||
request_id,
|
||||
session_id,
|
||||
message[:200],
|
||||
)
|
||||
|
||||
scope = ContextualScope.model_validate(scope_payload)
|
||||
|
||||
# Enrich context with memory before the LLM call.
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(
|
||||
user_id,
|
||||
message,
|
||||
trace_id=request_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
context: dict = {
|
||||
"conversation_history": frame.get("conversation_history", []),
|
||||
"format_prefs": frame.get("format_prefs"),
|
||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
response_chunks: list[str] = []
|
||||
try:
|
||||
event_stream = run_contextual_stream(
|
||||
user_id=user_id,
|
||||
message=message,
|
||||
context=context,
|
||||
scope=scope,
|
||||
)
|
||||
formatter = StreamFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: contextual_request failed user=%s req=%s: %s",
|
||||
user_id, request_id, exc,
|
||||
)
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
# Store episode so the contextual agent can recall prior turns.
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
await memory.store_episode(
|
||||
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||
)
|
||||
logger.info(
|
||||
"device_ws: contextual_request_end user=%s req=%s session=%s response_chars=%d",
|
||||
user_id,
|
||||
request_id,
|
||||
session_id,
|
||||
len("".join(response_chunks)),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_contextual_scope_update(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a contextual_scope_update frame.
|
||||
|
||||
Injects a synthetic system message into the session buffer so the next
|
||||
agent turn knows the user navigated. No LLM call is made.
|
||||
"""
|
||||
session_id: str = frame.get("session_id") or str(uuid4())
|
||||
scope = ContextualScope.model_validate(frame.get("scope", {}))
|
||||
block = render_scope_block(scope)
|
||||
buf = get_session_buffer(user_id, session_id, channel="contextual")
|
||||
buf.append_system_message(
|
||||
f"User navigated to a new view. {block} Treat this as the new active context."
|
||||
)
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.contextual_scope_ack,
|
||||
"session_id": session_id,
|
||||
}))
|
||||
logger.info(
|
||||
"device_ws: contextual_scope_update user=%s session=%s page=%s",
|
||||
user_id, session_id, scope.page,
|
||||
)
|
||||
|
||||
|
||||
async def _handle_brief_request(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a brief_request frame — streams plain-text brief back on the socket.
|
||||
|
||||
No episode storage — briefs are not conversations.
|
||||
"""
|
||||
import uuid as _uuid
|
||||
|
||||
request_id = frame.get("request_id") or str(uuid4())
|
||||
session_id = frame.get("session_id") or str(uuid4())
|
||||
mode: str = frame.get("mode", "home")
|
||||
project_id: str | None = frame.get("project_id")
|
||||
|
||||
logger.info(
|
||||
"device_ws: brief_request_start user=%s req=%s mode=%s project_id=%s",
|
||||
user_id, request_id, mode, project_id,
|
||||
)
|
||||
|
||||
# Validate project_id for project mode before touching LLM.
|
||||
if mode == "project":
|
||||
try:
|
||||
if not project_id:
|
||||
raise ValueError("project_id required for project mode")
|
||||
_uuid.UUID(project_id)
|
||||
except (ValueError, AttributeError) as exc:
|
||||
logger.warning(
|
||||
"device_ws: brief_request invalid project_id user=%s req=%s: %s",
|
||||
user_id, request_id, exc,
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
||||
)
|
||||
return
|
||||
|
||||
# Enrich context with memory (no user message — use empty string as probe).
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(
|
||||
user_id,
|
||||
"",
|
||||
trace_id=request_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
context: dict = {
|
||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||
"format_prefs": frame.get("format_prefs"),
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
try:
|
||||
if mode == "project":
|
||||
event_stream = run_project_brief(user_id, project_id, context) # type: ignore[arg-type]
|
||||
else:
|
||||
event_stream = run_home_brief(user_id, context)
|
||||
|
||||
formatter = StreamFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: brief_request failed user=%s req=%s: %s",
|
||||
user_id, request_id, exc,
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
||||
)
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
logger.info(
|
||||
"device_ws: brief_request_end user=%s req=%s mode=%s",
|
||||
user_id, request_id, mode,
|
||||
)
|
||||
|
||||
|
||||
# ── v6 Task Brief Handler ────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _handle_task_brief_request(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a task_brief_request frame — Stage-1 executive assistant deep research.
|
||||
|
||||
Streams the briefing markdown back to the client.
|
||||
On stream_end, emits a ``canvas_draft`` mutation if the agent produced one.
|
||||
"""
|
||||
request_id = frame.get("request_id") or str(uuid4())
|
||||
session_id = frame.get("session_id") or str(uuid4())
|
||||
task_id: str = frame.get("task_id") or frame.get("taskId") or ""
|
||||
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
|
||||
|
||||
logger.info(
|
||||
"device_ws: task_brief_request_start user=%s req=%s task=%s project=%s [cache_miss]",
|
||||
user_id, request_id, task_id, project_id,
|
||||
)
|
||||
|
||||
if not task_id:
|
||||
await websocket.send_text(
|
||||
WsStreamEnd(request_id=request_id, error="task_id is required").model_dump_json()
|
||||
)
|
||||
return
|
||||
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(
|
||||
user_id,
|
||||
f"task brief: {task_id}",
|
||||
trace_id=request_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
context: dict = {
|
||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||
"format_prefs": frame.get("format_prefs"),
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
response_chunks: list[str] = []
|
||||
|
||||
try:
|
||||
event_stream = run_task_brief_research_stream(user_id, task_id, context, project_id=project_id)
|
||||
formatter = StreamFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
elif ws_frame.type == "stream_start":
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
# stream_end is emitted below with mutations — skip formatter's version
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: task_brief_request failed user=%s req=%s task=%s: %s",
|
||||
user_id, request_id, task_id, exc,
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
||||
)
|
||||
return
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
# Extract canvas block then emit stream_end with optional mutations.
|
||||
full_response = "".join(response_chunks)
|
||||
_visible, canvas_content, canvas_kind = extract_canvas_block(full_response)
|
||||
|
||||
mutations: list[dict] = []
|
||||
if canvas_content:
|
||||
mutations.append({
|
||||
"type": "canvas_draft",
|
||||
"content": canvas_content,
|
||||
"kind": canvas_kind,
|
||||
})
|
||||
|
||||
await websocket.send_text(
|
||||
WsStreamEnd(request_id=request_id, mutations=mutations or None).model_dump_json()
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"device_ws: task_brief_request_end user=%s req=%s task=%s response_chars=%d canvas=%s",
|
||||
user_id, request_id, task_id, len(full_response), canvas_kind or "none",
|
||||
)
|
||||
|
||||
|
||||
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _handle_journey_start(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a journey_start frame — explores directory and sends first question."""
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
try:
|
||||
reply = await handle_journey_start(user_id, frame)
|
||||
await websocket.send_text(json.dumps(reply))
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: journey_start failed user=%s: %s", user_id, exc
|
||||
)
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": "journey_reply",
|
||||
"session_id": frame.get("session_id", ""),
|
||||
"message": f"Failed to start journey: {exc}",
|
||||
"done": True,
|
||||
"prompt_template": None,
|
||||
}))
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
|
||||
async def _handle_journey_message(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a journey_message frame — continues the journey conversation."""
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
try:
|
||||
reply = await handle_journey_message(user_id, frame)
|
||||
await websocket.send_text(json.dumps(reply))
|
||||
except Exception as exc:
|
||||
session_id = frame.get("session_id", "")
|
||||
logger.error(
|
||||
"device_ws: journey_message failed user=%s session=%s: %s",
|
||||
user_id, session_id, exc,
|
||||
)
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": "journey_reply",
|
||||
"session_id": session_id,
|
||||
"message": f"Journey error: {exc}",
|
||||
"done": True,
|
||||
"prompt_template": None,
|
||||
}))
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
|
||||
# ── v7 Folder Index Handlers ──────────────────────────────────────────
|
||||
|
||||
|
||||
async def _handle_index_session_start(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Register a new folder index session. No response sent — client is declaring intent."""
|
||||
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
||||
project_id: str | None = frame.get("projectId") or frame.get("project_id")
|
||||
total: int = int(frame.get("totalFiles") or frame.get("total_files") or 0)
|
||||
|
||||
if not session_id:
|
||||
logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id)
|
||||
return
|
||||
|
||||
_index_sessions[session_id] = {
|
||||
"user_id": user_id,
|
||||
"project_id": project_id,
|
||||
"processed": 0,
|
||||
"total": total,
|
||||
"cancelled": False,
|
||||
}
|
||||
logger.info(
|
||||
"device_ws: index_session_start user=%s session=%s project=%s total=%d",
|
||||
user_id, session_id, project_id, total,
|
||||
)
|
||||
|
||||
|
||||
async def _handle_index_session_cancel(
|
||||
websocket: WebSocket,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Mark a session as cancelled and emit index_session_done(cancelled)."""
|
||||
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
||||
session = _index_sessions.get(session_id)
|
||||
if session:
|
||||
session["cancelled"] = True
|
||||
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_done,
|
||||
"sessionId": session_id,
|
||||
"status": "cancelled",
|
||||
}))
|
||||
_index_sessions.pop(session_id, None)
|
||||
logger.info("device_ws: index_session_cancel session=%s", session_id)
|
||||
|
||||
|
||||
async def _handle_index_file_batch(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Process a batch of files for an index session, streaming results back."""
|
||||
# Lazy imports to avoid heavy load at module startup.
|
||||
from app.core.folder_indexer import ( # noqa: PLC0415
|
||||
summarize_image,
|
||||
summarize_pdf,
|
||||
summarize_docx,
|
||||
summarize_text,
|
||||
)
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
from app.billing.quota import add_token_usage # noqa: PLC0415
|
||||
|
||||
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
||||
files: list[dict] = frame.get("files", [])
|
||||
|
||||
session = _index_sessions.get(session_id)
|
||||
if not session or session.get("cancelled"):
|
||||
return
|
||||
|
||||
async with async_session() as db:
|
||||
tier = await tier_manager.get_tier(user_id, db)
|
||||
raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens")
|
||||
cap: int | None = None if raw_cap == -1 else raw_cap
|
||||
|
||||
for file_info in files:
|
||||
if session.get("cancelled"):
|
||||
return
|
||||
|
||||
# Electron's toSnakeCase converts payload keys, so accept both forms.
|
||||
rel_path: str = file_info.get("relPath") or file_info.get("rel_path") or ""
|
||||
kind: str = file_info.get("kind") or "text"
|
||||
content: str = file_info.get("content") or ""
|
||||
ext: str = file_info.get("ext") or ""
|
||||
mime: str = file_info.get("mime") or "application/octet-stream"
|
||||
name: str = rel_path.split("/")[-1] or rel_path
|
||||
|
||||
try:
|
||||
if kind == "image":
|
||||
res = await summarize_image(image_b64=content, mime=mime)
|
||||
elif kind == "pdf":
|
||||
res = await summarize_pdf(pdf_b64=content, name=name)
|
||||
elif kind == "docx":
|
||||
res = await summarize_docx(docx_b64=content, name=name)
|
||||
else:
|
||||
res = await summarize_text(content=content, ext=ext, name=name)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"device_ws: index_file_batch summarize failed session=%s path=%s: %s",
|
||||
session_id, rel_path, exc,
|
||||
)
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_file_result,
|
||||
"sessionId": session_id,
|
||||
"relPath": rel_path,
|
||||
"summary": None,
|
||||
"tokensUsed": 0,
|
||||
"error": str(exc),
|
||||
}))
|
||||
session["processed"] += 1
|
||||
continue
|
||||
|
||||
# Account for token usage and check cap.
|
||||
usage = await add_token_usage(
|
||||
user_id=user_id,
|
||||
feature="folder_index",
|
||||
tokens=res.tokens_used,
|
||||
db=db,
|
||||
cap=cap,
|
||||
)
|
||||
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_file_result,
|
||||
"sessionId": session_id,
|
||||
"relPath": rel_path,
|
||||
"summary": res.summary,
|
||||
"tokensUsed": res.tokens_used,
|
||||
}))
|
||||
session["processed"] += 1
|
||||
|
||||
if usage.exhausted:
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_done,
|
||||
"sessionId": session_id,
|
||||
"status": "quota_exceeded",
|
||||
}))
|
||||
_index_sessions.pop(session_id, None)
|
||||
logger.info(
|
||||
"device_ws: index_session quota_exceeded user=%s session=%s",
|
||||
user_id, session_id,
|
||||
)
|
||||
return
|
||||
|
||||
# After processing the batch, emit progress.
|
||||
processed = session["processed"]
|
||||
total = session["total"]
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_progress,
|
||||
"sessionId": session_id,
|
||||
"processed": processed,
|
||||
"total": total,
|
||||
}))
|
||||
|
||||
if processed >= total:
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_done,
|
||||
"sessionId": session_id,
|
||||
"status": "completed",
|
||||
}))
|
||||
_index_sessions.pop(session_id, None)
|
||||
logger.info(
|
||||
"device_ws: index_session_done completed user=%s session=%s processed=%d",
|
||||
user_id, session_id, processed,
|
||||
)
|
||||
|
||||
|
||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||
|
||||
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
||||
"""Send a ping frame every 30 s to keep the connection alive."""
|
||||
while True:
|
||||
await asyncio.sleep(_HEARTBEAT_INTERVAL)
|
||||
await websocket.send_text(json.dumps({"type": "ping"}))
|
||||
|
||||
|
||||
# ── Disconnect cleanup ────────────────────────────────────────────────
|
||||
|
||||
async def _mark_runs_disconnected(user_id: str) -> None:
|
||||
"""Mark all in-progress ScoutRunLog rows as 'error' for this user."""
|
||||
try:
|
||||
async with async_session() as db:
|
||||
await db.execute(
|
||||
update(ScoutRunLog)
|
||||
.where(
|
||||
ScoutRunLog.user_id == user_id,
|
||||
ScoutRunLog.status == "running",
|
||||
)
|
||||
.values(
|
||||
status="error",
|
||||
errors=["device disconnected"],
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: failed to mark runs as disconnected for user=%s: %s",
|
||||
user_id,
|
||||
exc,
|
||||
)
|
||||
225
api/app/api/routes/memory.py
Normal file
225
api/app/api/routes/memory.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Memory management routes — view/edit/delete user memory tiers.
|
||||
|
||||
All routes require authentication. Data is always user-scoped.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.models import (
|
||||
ExtractionQueue,
|
||||
MemoryAssociative,
|
||||
MemoryCore,
|
||||
MemoryEpisodic,
|
||||
MemoryProactive,
|
||||
MemoryRelation,
|
||||
)
|
||||
from app.schemas import UserProfile
|
||||
|
||||
router = APIRouter(prefix="/memory", tags=["memory"])
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ALLOWED_PREDICATES = {
|
||||
"works_at",
|
||||
"reports_to",
|
||||
"stakeholder_of",
|
||||
"last_contacted_on",
|
||||
"owes_followup",
|
||||
"manages",
|
||||
"collaborates_with",
|
||||
"owns",
|
||||
"member_of",
|
||||
"custom",
|
||||
}
|
||||
|
||||
|
||||
# ── Response schemas ─────────────────────────────────────────────────────────
|
||||
|
||||
class RelationOut(BaseModel):
|
||||
id: str
|
||||
subject_label: str
|
||||
subject_type: str
|
||||
predicate: str
|
||||
object_label: str
|
||||
object_type: str
|
||||
confidence: float
|
||||
last_confirmed_at: int | None = None # epoch ms
|
||||
|
||||
|
||||
class RelationPatch(BaseModel):
|
||||
subject_label: str | None = None
|
||||
object_label: str | None = None
|
||||
predicate: str | None = None
|
||||
confidence: float | None = Field(None, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class CoreAddBody(BaseModel):
|
||||
key: str = Field(..., min_length=1, max_length=255)
|
||||
value: str = Field(..., min_length=1)
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _relation_to_out(row: MemoryRelation) -> RelationOut:
|
||||
last_ms: int | None = None
|
||||
if row.last_confirmed_at is not None:
|
||||
last_ms = int(row.last_confirmed_at.timestamp() * 1000)
|
||||
return RelationOut(
|
||||
id=row.id,
|
||||
subject_label=row.subject_label,
|
||||
subject_type=row.subject_type,
|
||||
predicate=row.predicate,
|
||||
object_label=row.object_label,
|
||||
object_type=row.object_type,
|
||||
confidence=row.confidence,
|
||||
last_confirmed_at=last_ms,
|
||||
)
|
||||
|
||||
|
||||
# ── Routes ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/core", response_model=dict[str, str])
|
||||
async def get_core_memory(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""Return all core memory k/v pairs (plaintext) for the current user."""
|
||||
mw = MemoryMiddleware(db)
|
||||
blocks = await mw.list_core_blocks(current_user.id)
|
||||
return {b["label"]: b["value"] for b in blocks}
|
||||
|
||||
|
||||
@router.delete("/core/{key}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_core_key(
|
||||
key: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> None:
|
||||
"""Delete a single core memory key (GDPR Art. 17)."""
|
||||
mw = MemoryMiddleware(db)
|
||||
deleted = await mw.delete_core(current_user.id, key)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Key not found")
|
||||
|
||||
|
||||
@router.post("/core", status_code=status.HTTP_201_CREATED, response_model=dict[str, str])
|
||||
async def add_core_key(
|
||||
body: CoreAddBody,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""Add or overwrite a core memory key/value pair."""
|
||||
mw = MemoryMiddleware(db)
|
||||
await mw.update_core(current_user.id, body.key, body.value)
|
||||
return {body.key: body.value}
|
||||
|
||||
|
||||
@router.get("/relational", response_model=list[RelationOut])
|
||||
async def get_relational_memory(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[RelationOut]:
|
||||
"""Return all relational memory rows for the current user."""
|
||||
mw = MemoryMiddleware(db)
|
||||
rows = await mw.query_relations(current_user.id, limit=200)
|
||||
return [_relation_to_out(r) for r in rows]
|
||||
|
||||
|
||||
@router.patch("/relational/{relation_id}", response_model=RelationOut)
|
||||
async def patch_relation(
|
||||
relation_id: str,
|
||||
body: RelationPatch,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> RelationOut:
|
||||
"""Edit a relation row's labels, predicate, or confidence."""
|
||||
if body.predicate is not None and body.predicate not in _ALLOWED_PREDICATES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=f"predicate must be one of: {sorted(_ALLOWED_PREDICATES)}",
|
||||
)
|
||||
|
||||
result = await db.execute(
|
||||
select(MemoryRelation).where(
|
||||
MemoryRelation.id == relation_id,
|
||||
MemoryRelation.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
|
||||
|
||||
if body.subject_label is not None:
|
||||
row.subject_label = body.subject_label
|
||||
if body.object_label is not None:
|
||||
row.object_label = body.object_label
|
||||
if body.predicate is not None:
|
||||
row.predicate = body.predicate
|
||||
if body.confidence is not None:
|
||||
row.confidence = body.confidence
|
||||
row.last_confirmed_at = datetime.now(timezone.utc)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(row)
|
||||
logger.info("memory: patch_relation user=%s relation=%s", current_user.id, relation_id)
|
||||
return _relation_to_out(row)
|
||||
|
||||
|
||||
@router.delete("/relational/{relation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_relation(
|
||||
relation_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> None:
|
||||
"""Hard-delete a relation row (GDPR Art. 17)."""
|
||||
result = await db.execute(
|
||||
select(MemoryRelation).where(
|
||||
MemoryRelation.id == relation_id,
|
||||
MemoryRelation.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
|
||||
await db.delete(row)
|
||||
await db.commit()
|
||||
logger.info("memory: delete_relation user=%s relation=%s", current_user.id, relation_id)
|
||||
|
||||
|
||||
@router.post("/forget-all", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def forget_all(
|
||||
x_confirm: Annotated[str | None, Header(alias="X-Confirm")] = None,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> None:
|
||||
"""Wipe all memory tiers for the current user (GDPR Art. 17).
|
||||
|
||||
Requires ``X-Confirm: true`` header. Does NOT delete the user account.
|
||||
"""
|
||||
if x_confirm != "true":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Missing or invalid X-Confirm header. Send X-Confirm: true to confirm.",
|
||||
)
|
||||
|
||||
uid = current_user.id
|
||||
await db.execute(delete(MemoryCore).where(MemoryCore.user_id == uid))
|
||||
await db.execute(delete(MemoryAssociative).where(MemoryAssociative.user_id == uid))
|
||||
await db.execute(delete(MemoryEpisodic).where(MemoryEpisodic.user_id == uid))
|
||||
await db.execute(delete(MemoryProactive).where(MemoryProactive.user_id == uid))
|
||||
await db.execute(delete(MemoryRelation).where(MemoryRelation.user_id == uid))
|
||||
await db.execute(delete(ExtractionQueue).where(ExtractionQueue.user_id == uid))
|
||||
await db.commit()
|
||||
logger.warning("memory: forget_all GDPR wipe user=%s", uid)
|
||||
513
api/app/api/routes/scout_setup.py
Normal file
513
api/app/api/routes/scout_setup.py
Normal file
@@ -0,0 +1,513 @@
|
||||
"""Chatbot Journey — WS-based guided conversation to build an ScoutConfig.
|
||||
|
||||
The journey is driven entirely through WebSocket frames (no REST endpoints).
|
||||
The device WS handler dispatches ``journey_start`` and ``journey_message``
|
||||
frames to the functions exported here.
|
||||
|
||||
Journey flow:
|
||||
1. FE sends ``journey_start`` frame with basic agent info (directory,
|
||||
data_types, schedule).
|
||||
2. Server creates an in-memory session, sets up a WS executor so the
|
||||
setup LLM can use file-system tools, does a first directory scrape,
|
||||
and sends back a ``journey_reply`` with the first question.
|
||||
3. FE sends ``journey_message`` frames for each user reply.
|
||||
4. Server appends the user message, calls the LLM (which may read files
|
||||
via tools), and sends back a ``journey_reply``.
|
||||
5. After 3-5 turns the LLM wraps up by emitting an ``ScoutConfig`` JSON
|
||||
block delimited by ``AGENT_CONFIG_START`` / ``AGENT_CONFIG_END``.
|
||||
6. Server parses and validates the JSON with Pydantic, sends
|
||||
``journey_reply`` with ``done=True`` and the serialised config.
|
||||
FE stores it locally.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
|
||||
from app.agents.filesystem_agent import make_directory_tools
|
||||
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
||||
from app.core.llm import get_agent_llm, model_for_agent
|
||||
from app.schemas import ScoutConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Session TTL ───────────────────────────────────────────────────────────
|
||||
|
||||
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||
|
||||
# Sentinel strings used to delimit the LLM-produced ScoutConfig JSON.
|
||||
_CONFIG_START = "AGENT_CONFIG_START"
|
||||
_CONFIG_END = "AGENT_CONFIG_END"
|
||||
|
||||
# Minimum turns before we consider nudging the LLM to wrap up.
|
||||
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
||||
# Hard cap to avoid infinite loops (safety net, not the primary stopping criterion).
|
||||
_MAX_TURNS: int = 15
|
||||
# Max tool-calling steps per LLM invocation.
|
||||
_MAX_TOOL_STEPS: int = 6
|
||||
|
||||
# ── In-memory session store ───────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class JourneySession:
|
||||
session_id: str
|
||||
user_id: str
|
||||
agent_type: str # "local" | "cloud"
|
||||
directory: str
|
||||
data_types: list[str]
|
||||
history: list[dict[str, Any]] = field(default_factory=list)
|
||||
system_prompt: str = ""
|
||||
langfuse_prompt: Any = None
|
||||
created_at: float = field(default_factory=time.monotonic)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS
|
||||
|
||||
|
||||
# session_id → session
|
||||
_sessions: dict[str, JourneySession] = {}
|
||||
|
||||
|
||||
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
||||
"""Retrieve session; return None on missing, expired, or wrong owner."""
|
||||
s = _sessions.get(session_id)
|
||||
if s is None or s.is_expired():
|
||||
_sessions.pop(session_id, None)
|
||||
return None
|
||||
if s.user_id != user_id:
|
||||
return None
|
||||
return s
|
||||
|
||||
|
||||
# ── System prompt ─────────────────────────────────────────────────────────
|
||||
|
||||
_JOURNEY_SYSTEM_PROMPT = """\
|
||||
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||
Your job is to understand what files the user has in their directory and produce a
|
||||
structured ScoutConfig JSON that the extraction agent will use as its instruction set.
|
||||
|
||||
You have access to file-system tools to explore the user's directory:
|
||||
- list_directory: see folder structure and file names
|
||||
- read_file_content: peek at a file's content
|
||||
- get_file_metadata: check file size, extension, dates
|
||||
|
||||
The user's configured directory is: {directory}
|
||||
Target data types: {data_types}
|
||||
|
||||
## Your process
|
||||
|
||||
### Step 1 — Explore the directory
|
||||
Use list_directory and read_file_content to understand what types of files are present
|
||||
(HTML emails, plain-text documents, CSVs, etc.).
|
||||
|
||||
### Step 2 — Identify content types
|
||||
For each distinct file type found, decide:
|
||||
- A short id (e.g. "email_html", "plain_text", "csv")
|
||||
- Which preprocessing handler to use: "email_html" for HTML emails, "generic" for everything else
|
||||
- A human-readable label and optional detection_hint
|
||||
|
||||
### Step 3 — Ask focused questions (one at a time)
|
||||
Cover these topics based on what you discovered:
|
||||
1. How to map content to entity types (task / note / timeline entry)
|
||||
2. Field mapping rules (e.g. email Subject → task title, filename → note title)
|
||||
3. Priority or status rules (e.g. "urgent" in subject → high priority)
|
||||
4. Date extraction (e.g. "by Friday" → dueDate)
|
||||
5. Exclusion rules (e.g. skip newsletters, skip files with no project match)
|
||||
|
||||
### Step 4 — Produce the ScoutConfig JSON
|
||||
Once you are ≥ 90% confident, output the final config between these exact markers
|
||||
(each on its own line):
|
||||
|
||||
{config_start}
|
||||
{{
|
||||
"content_types": [
|
||||
{{
|
||||
"id": "email_html",
|
||||
"label": "Email HTML",
|
||||
"detection_hint": "HTML file with From/To/Subject headers",
|
||||
"preprocessing": "email_html",
|
||||
"extraction_prompt": "Detailed extraction instructions for this content type..."
|
||||
}}
|
||||
],
|
||||
"global_rules": [
|
||||
"If the file cannot be matched to any project, do not create any entity."
|
||||
],
|
||||
"data_types": {data_types_json}
|
||||
}}
|
||||
{config_end}
|
||||
|
||||
## Rules for the extraction_prompt field
|
||||
- Describe when to create a task vs note vs timeline entry (be specific and concrete)
|
||||
- Include field mapping rules based on what you found in the directory
|
||||
- Include priority/status/date rules if applicable
|
||||
- Do NOT include projectId logic — the runner handles project assignment automatically
|
||||
- Do NOT mention isAiSuggested — the runner always sets it to 1
|
||||
|
||||
## Constraints
|
||||
- Never ask about projects, projectId, or how to link records to projects
|
||||
- Never include projectId or project creation logic in the generated config
|
||||
- Keep asking questions until ≥ 90% confident, then output the JSON immediately
|
||||
|
||||
{existing_section}\
|
||||
Begin by exploring the directory, then ask your first question.\
|
||||
"""
|
||||
|
||||
|
||||
def _build_system_prompt(
|
||||
directory: str,
|
||||
data_types: list[str],
|
||||
existing_config: str | None = None,
|
||||
) -> tuple[str, Any]:
|
||||
"""Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``."""
|
||||
existing_section = (
|
||||
"\nThe user already has the following ScoutConfig — refine it based on their answers:\n"
|
||||
f"```json\n{existing_config}\n```\n"
|
||||
if existing_config
|
||||
else ""
|
||||
)
|
||||
template, prompt_obj = get_prompt_or_fallback(
|
||||
"journey_system", _JOURNEY_SYSTEM_PROMPT
|
||||
)
|
||||
compiled = compile_prompt(
|
||||
template,
|
||||
prompt_obj,
|
||||
directory=directory,
|
||||
data_types=", ".join(data_types),
|
||||
data_types_json=json.dumps(data_types),
|
||||
config_start=_CONFIG_START,
|
||||
config_end=_CONFIG_END,
|
||||
existing_section=existing_section,
|
||||
)
|
||||
return compiled, prompt_obj
|
||||
|
||||
|
||||
# ── ScoutConfig extraction ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _extract_agent_config(text: str) -> str | None:
|
||||
"""Return validated ScoutConfig JSON string from between markers, or None.
|
||||
|
||||
Parses the JSON with Pydantic to ensure it conforms to the schema before
|
||||
returning. Returns None if markers are absent or JSON is invalid.
|
||||
"""
|
||||
if _CONFIG_START not in text or _CONFIG_END not in text:
|
||||
return None
|
||||
start_idx = text.index(_CONFIG_START) + len(_CONFIG_START)
|
||||
end_idx = text.index(_CONFIG_END)
|
||||
raw = text[start_idx:end_idx].strip()
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
parsed = ScoutConfig.model_validate_json(raw)
|
||||
return parsed.model_dump_json()
|
||||
except Exception as exc:
|
||||
logger.warning("agent_setup: failed to parse ScoutConfig JSON: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
# ── LLM call with tool support ───────────────────────────────────────────
|
||||
|
||||
|
||||
def _as_text(content: Any) -> str:
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict):
|
||||
text = item.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
return "".join(parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
async def _call_llm_with_tools(
|
||||
system_prompt: str,
|
||||
history: list[dict[str, Any]],
|
||||
tools: list[Any],
|
||||
*,
|
||||
user_id: str = "",
|
||||
session_id: str = "",
|
||||
langfuse_prompt: Any = None,
|
||||
) -> str:
|
||||
"""Build LangChain messages from history and invoke the LLM with tools.
|
||||
|
||||
Handles tool-calling loops: if the LLM calls tools, execute them and
|
||||
continue until a final text response is produced.
|
||||
"""
|
||||
lf = get_langfuse()
|
||||
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||
for turn in history:
|
||||
if turn["role"] == "user":
|
||||
messages.append(HumanMessage(content=turn["content"]))
|
||||
else:
|
||||
messages.append(AIMessage(content=turn["content"]))
|
||||
|
||||
llm = get_agent_llm("setup", temperature=0.4)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||
|
||||
_lf_ctx = langfuse_context(user_id=user_id or None, session_id=session_id or None)
|
||||
_lf_ctx.__enter__()
|
||||
|
||||
_span_ctx = (
|
||||
lf.start_as_current_observation(
|
||||
as_type="span",
|
||||
name="journey-setup",
|
||||
input=history[-1]["content"] if history else "",
|
||||
)
|
||||
if lf else None
|
||||
)
|
||||
_span = _span_ctx.__enter__() if _span_ctx else None
|
||||
|
||||
try:
|
||||
for step in range(_MAX_TOOL_STEPS):
|
||||
_gen_ctx = (
|
||||
lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="journey-setup-llm",
|
||||
model=model_for_agent("setup"),
|
||||
prompt=langfuse_prompt,
|
||||
input=messages,
|
||||
)
|
||||
if lf else None
|
||||
)
|
||||
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||
if _gen_ctx:
|
||||
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||
_gen_ctx.__exit__(None, None, None)
|
||||
|
||||
resp_text = _as_text(response.content)
|
||||
|
||||
# Guard against empty responses (e.g. model returned finish_reason
|
||||
# 'error' which LiteLLM maps to 'stop' with empty content).
|
||||
if not response.tool_calls and not resp_text.strip():
|
||||
logger.warning(
|
||||
"agent_setup: journey LLM returned empty response at step %d — retrying",
|
||||
step,
|
||||
)
|
||||
# Drop the empty AIMessage so we don't pollute history, and retry.
|
||||
continue
|
||||
|
||||
messages.append(response)
|
||||
|
||||
if not response.tool_calls:
|
||||
if _span:
|
||||
_span.update(output=resp_text)
|
||||
return resp_text
|
||||
|
||||
for call in response.tool_calls:
|
||||
call_name = str(call.get("name", ""))
|
||||
call_args = call.get("args", {})
|
||||
logger.info(
|
||||
"agent_setup: journey tool_call name=%s args=%s",
|
||||
call_name,
|
||||
json.dumps(call_args, ensure_ascii=True)[:500],
|
||||
)
|
||||
|
||||
tool_fn = tool_map.get(call_name)
|
||||
if tool_fn is None:
|
||||
tool_output = f"Unknown tool: {call_name}"
|
||||
else:
|
||||
tool_output = await tool_fn.ainvoke(call_args)
|
||||
|
||||
logger.info(
|
||||
"agent_setup: journey tool_result name=%s output=%s",
|
||||
call_name,
|
||||
str(tool_output)[:800],
|
||||
)
|
||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||
|
||||
# Fallback: exceeded max steps.
|
||||
final = await llm.ainvoke(messages)
|
||||
final_text = _as_text(final.content)
|
||||
if _span:
|
||||
_span.update(output=final_text)
|
||||
return final_text or (
|
||||
"Sorry, I had trouble processing the files. "
|
||||
"Could you try again? If the issue persists, the files might be too large for me to analyse."
|
||||
)
|
||||
finally:
|
||||
if _span_ctx:
|
||||
_span_ctx.__exit__(None, None, None)
|
||||
_lf_ctx.__exit__(None, None, None)
|
||||
if lf:
|
||||
lf.flush()
|
||||
|
||||
|
||||
# ── Journey handlers (called from device_ws.py) ──────────────────────────
|
||||
|
||||
|
||||
async def handle_journey_start(
|
||||
user_id: str,
|
||||
frame: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Handle a ``journey_start`` WS frame.
|
||||
|
||||
Creates a session, runs the setup LLM with directory exploration,
|
||||
and returns the ``journey_reply`` payload.
|
||||
"""
|
||||
agent_type = frame.get("agent_type", "local")
|
||||
directory = frame.get("directory", "")
|
||||
data_types = frame.get("data_types", [])
|
||||
existing_config = frame.get("existing_config")
|
||||
|
||||
# Use the session_id provided by the FE so the reply matches the
|
||||
# listener key; fall back to a generated one if absent.
|
||||
session_id = frame.get("session_id") or str(uuid.uuid4())
|
||||
system_prompt, langfuse_prompt = _build_system_prompt(directory, data_types, existing_config)
|
||||
|
||||
session = JourneySession(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
agent_type=agent_type,
|
||||
directory=directory,
|
||||
data_types=data_types,
|
||||
system_prompt=system_prompt,
|
||||
langfuse_prompt=langfuse_prompt,
|
||||
)
|
||||
|
||||
# Seed with an initial user message — some providers require at least one
|
||||
# user/input message to be present.
|
||||
seed_history: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
||||
]
|
||||
ai_reply = await _call_llm_with_tools(
|
||||
system_prompt=system_prompt,
|
||||
history=seed_history,
|
||||
tools=make_directory_tools(directory),
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
langfuse_prompt=langfuse_prompt,
|
||||
)
|
||||
|
||||
session.history.extend(seed_history)
|
||||
session.history.append({"role": "assistant", "content": ai_reply})
|
||||
_sessions[session_id] = session
|
||||
|
||||
logger.info(
|
||||
"agent_setup: journey session %s started for user %s (directory=%s)",
|
||||
session_id,
|
||||
user_id,
|
||||
directory,
|
||||
)
|
||||
|
||||
# Check if the LLM produced the config on the first turn (unlikely but possible).
|
||||
agent_config = _extract_agent_config(ai_reply)
|
||||
done = agent_config is not None
|
||||
|
||||
display_message = ai_reply
|
||||
if done:
|
||||
display_message = (
|
||||
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
|
||||
or "Here is your agent configuration. You can save it or continue refining."
|
||||
)
|
||||
_sessions.pop(session_id, None)
|
||||
|
||||
return {
|
||||
"type": "journey_reply",
|
||||
"session_id": session_id,
|
||||
"message": display_message,
|
||||
"done": done,
|
||||
"agent_config": agent_config,
|
||||
}
|
||||
|
||||
|
||||
async def handle_journey_message(
|
||||
user_id: str,
|
||||
frame: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Handle a ``journey_message`` WS frame.
|
||||
|
||||
Appends the user message, calls the LLM, and returns the
|
||||
``journey_reply`` payload.
|
||||
"""
|
||||
session_id = frame.get("session_id", "")
|
||||
message = frame.get("message", "")
|
||||
|
||||
session = get_journey_session(session_id, user_id)
|
||||
if session is None:
|
||||
return {
|
||||
"type": "journey_reply",
|
||||
"session_id": session_id,
|
||||
"message": "Journey session not found or expired. Please start a new setup.",
|
||||
"done": True,
|
||||
"agent_config": None,
|
||||
}
|
||||
|
||||
# Append user turn.
|
||||
session.history.append({"role": "user", "content": message})
|
||||
|
||||
# Call the LLM with tools.
|
||||
session_tools = make_directory_tools(session.directory)
|
||||
ai_reply = await _call_llm_with_tools(
|
||||
system_prompt=session.system_prompt,
|
||||
history=session.history,
|
||||
tools=session_tools,
|
||||
user_id=session.user_id,
|
||||
session_id=session_id,
|
||||
langfuse_prompt=session.langfuse_prompt,
|
||||
)
|
||||
|
||||
session.history.append({"role": "assistant", "content": ai_reply})
|
||||
|
||||
# Check if the LLM produced the final config.
|
||||
agent_config = _extract_agent_config(ai_reply)
|
||||
done = agent_config is not None
|
||||
|
||||
# If the LLM didn't produce a config, nudge it once it hits the hard safety cap.
|
||||
if not done:
|
||||
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||
if turns >= _MAX_TURNS:
|
||||
nudge_content = (
|
||||
"[System: You have enough information. Please generate the final "
|
||||
f"ScoutConfig JSON now, wrapped in {_CONFIG_START} / {_CONFIG_END} markers.]"
|
||||
)
|
||||
session.history.append({"role": "user", "content": nudge_content})
|
||||
|
||||
nudge_reply = await _call_llm_with_tools(
|
||||
system_prompt=session.system_prompt,
|
||||
history=session.history,
|
||||
tools=session_tools,
|
||||
user_id=session.user_id,
|
||||
session_id=session_id,
|
||||
langfuse_prompt=session.langfuse_prompt,
|
||||
)
|
||||
session.history.append({"role": "assistant", "content": nudge_reply})
|
||||
|
||||
agent_config = _extract_agent_config(nudge_reply)
|
||||
if agent_config is not None:
|
||||
done = True
|
||||
ai_reply = nudge_reply
|
||||
|
||||
display_message = ai_reply
|
||||
if done:
|
||||
display_message = (
|
||||
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
|
||||
if _CONFIG_START in ai_reply
|
||||
else "Here is your agent configuration. You can save it or continue refining."
|
||||
)
|
||||
_sessions.pop(session_id, None)
|
||||
logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id)
|
||||
|
||||
return {
|
||||
"type": "journey_reply",
|
||||
"session_id": session_id,
|
||||
"message": display_message,
|
||||
"done": done,
|
||||
"agent_config": agent_config,
|
||||
}
|
||||
120
api/app/api/routes/scout_webhooks.py
Normal file
120
api/app/api/routes/scout_webhooks.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Gmail Pub/Sub push receiver.
|
||||
|
||||
Google Pub/Sub push subscriptions deliver Gmail watch notifications as POST
|
||||
requests with a JSON envelope. The body payload contains a base64-encoded
|
||||
JSON blob with ``emailAddress`` + ``historyId``. We resolve the user by
|
||||
email, look up their cloud_scout_configs row for provider='gmail', and
|
||||
hand off to ScoutEngine.trigger_scout.
|
||||
|
||||
Authentication: Pub/Sub push includes an OIDC JWT in the Authorization
|
||||
header. We verify it against Google's public keys with the audience
|
||||
configured in our Pub/Sub subscription.
|
||||
|
||||
Dev mode: when ``GMAIL_PUBSUB_AUDIENCE`` is empty, JWT verification is
|
||||
skipped and a warning is logged. Production must set this env var.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Header, HTTPException, Request, status
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.db import async_session
|
||||
from app.models import CloudScoutConfig, User
|
||||
from app.scouts.engine import ScoutEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/scouts/webhooks", tags=["scout-webhooks"])
|
||||
|
||||
|
||||
def _verify_pubsub_jwt(token: str) -> bool:
|
||||
"""Verify the Google Pub/Sub OIDC JWT.
|
||||
|
||||
Returns True when valid, False on any verification failure.
|
||||
|
||||
Dev skip: if ``settings.GMAIL_PUBSUB_AUDIENCE`` is empty, logs a
|
||||
warning and returns True so local development works without a real
|
||||
Pub/Sub subscription. Production must configure the audience.
|
||||
"""
|
||||
if not token:
|
||||
return False
|
||||
|
||||
if not settings.GMAIL_PUBSUB_AUDIENCE:
|
||||
logger.warning(
|
||||
"GMAIL_PUBSUB_AUDIENCE not set — skipping Pub/Sub JWT verification (dev mode only)"
|
||||
)
|
||||
return True
|
||||
|
||||
try:
|
||||
from google.auth.transport import requests as g_requests # noqa: PLC0415
|
||||
from google.oauth2 import id_token # noqa: PLC0415
|
||||
|
||||
id_token.verify_oauth2_token(
|
||||
token,
|
||||
g_requests.Request(),
|
||||
audience=settings.GMAIL_PUBSUB_AUDIENCE,
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
logger.warning("pubsub jwt verification failed", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
@router.post("/gmail", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def gmail_pubsub(
|
||||
request: Request,
|
||||
authorization: str = Header(default=""),
|
||||
) -> None:
|
||||
"""Receive a Gmail Pub/Sub push notification.
|
||||
|
||||
Verifies the OIDC JWT, decodes the Pub/Sub envelope, resolves the user
|
||||
by email, and triggers ScoutEngine.trigger_scout for each enabled Gmail
|
||||
scout belonging to that user.
|
||||
|
||||
Returns 204 No Content on success (including benign no-ops like unknown
|
||||
email or empty message data). Returns 401 on JWT verification failure.
|
||||
"""
|
||||
token = authorization.removeprefix("Bearer ").strip()
|
||||
if not _verify_pubsub_jwt(token):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid Pub/Sub JWT")
|
||||
|
||||
body = await request.json()
|
||||
msg = body.get("message") or {}
|
||||
raw = msg.get("data")
|
||||
if not raw:
|
||||
return # ack without action — empty message data
|
||||
|
||||
try:
|
||||
decoded = json.loads(base64.b64decode(raw).decode())
|
||||
except Exception:
|
||||
logger.warning("pubsub payload decode failed")
|
||||
return
|
||||
|
||||
email = decoded.get("emailAddress")
|
||||
if not email:
|
||||
return
|
||||
|
||||
async with async_session() as session:
|
||||
user_q = await session.execute(select(User).where(User.email == email))
|
||||
user = user_q.scalar_one_or_none()
|
||||
if user is None:
|
||||
logger.info("pubsub: no user for %s — ignoring", email)
|
||||
return
|
||||
scouts_q = await session.execute(
|
||||
select(CloudScoutConfig).where(
|
||||
CloudScoutConfig.user_id == user.id,
|
||||
CloudScoutConfig.provider == "gmail",
|
||||
CloudScoutConfig.enabled == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
scouts = scouts_q.scalars().all()
|
||||
|
||||
engine = ScoutEngine()
|
||||
for scout in scouts:
|
||||
await engine.trigger_scout(uuid.UUID(str(scout.id)))
|
||||
807
api/app/api/routes/scouts.py
Normal file
807
api/app/api/routes/scouts.py
Normal file
@@ -0,0 +1,807 @@
|
||||
"""Scout routes.
|
||||
|
||||
Backend responsibilities are intentionally minimal:
|
||||
GET /scouts/catalog — static catalog for UI display
|
||||
POST /scouts/can-create — billing eligibility check
|
||||
POST /scouts/trigger — trigger a local scout run
|
||||
|
||||
Scout configuration is owned by the Electron app and is not persisted
|
||||
in backend scout-config tables.
|
||||
|
||||
Gmail OAuth setup (scout-specific consent):
|
||||
GET /scouts/oauth/gmail/authorize — returns consent-screen URL
|
||||
GET /scouts/oauth/gmail/web-callback — bounces to deep link (excluded from schema)
|
||||
POST /scouts/oauth/gmail/callback — exchanges code, stores encrypted token
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import secrets
|
||||
import time
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from sqlalchemy import delete as sa_delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.auth.oauth_providers import generate_pkce_pair
|
||||
from app.billing.tier_manager import FEATURES
|
||||
from app.config.settings import settings
|
||||
from app.core.scout_runner import is_agent_running, run_local_agent
|
||||
from app.core.device_manager import device_manager
|
||||
from app.core.note_summarizer import generate_note_summary
|
||||
from app.db import get_session
|
||||
from app.integrations import decrypt_token, encrypt_token
|
||||
from app.models import CloudScoutConfig, ScoutRunLog, LocalScoutConfig
|
||||
from app.scouts.connectors.registry import get_connector
|
||||
from app.schemas import (
|
||||
CloudScoutCreateRequest,
|
||||
CloudScoutResponse,
|
||||
CloudScoutUpdateRequest,
|
||||
ScoutCatalogItem,
|
||||
ScoutCreationCheckRequest,
|
||||
ScoutCreationCheckResponse,
|
||||
ScoutRunLogResponse,
|
||||
ScoutTriggerRequest,
|
||||
UserProfile,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/scouts", tags=["scouts"])
|
||||
|
||||
|
||||
# ── Datetime helpers ──────────────────────────────────────────────────
|
||||
|
||||
def _dt_ms(dt: datetime) -> int:
|
||||
return int(dt.timestamp() * 1000)
|
||||
|
||||
|
||||
def _dt_ms_opt(dt: datetime | None) -> int | None:
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
def _to_data_types(values: list[str]) -> list[str]:
|
||||
normalize = {
|
||||
"task": "tasks", "tasks": "tasks",
|
||||
"note": "notes", "notes": "notes",
|
||||
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
||||
"project": "projects", "projects": "projects",
|
||||
}
|
||||
seen: set[str] = set()
|
||||
result: list[str] = []
|
||||
for v in values:
|
||||
mapped = normalize.get(v)
|
||||
if mapped and mapped not in seen:
|
||||
seen.add(mapped)
|
||||
result.append(mapped)
|
||||
return result
|
||||
|
||||
|
||||
def _to_run_log_response(log: ScoutRunLog) -> ScoutRunLogResponse:
|
||||
return ScoutRunLogResponse(
|
||||
id=log.id,
|
||||
agent_id=log.scout_id,
|
||||
agent_type=log.scout_type, # type: ignore[arg-type]
|
||||
status=log.status, # type: ignore[arg-type]
|
||||
items_processed=log.items_processed,
|
||||
items_created=log.items_created,
|
||||
errors=log.errors or [],
|
||||
started_at=_dt_ms(log.started_at),
|
||||
completed_at=_dt_ms_opt(log.completed_at),
|
||||
)
|
||||
|
||||
|
||||
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||
if limit != -1 and current_count >= limit:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||
)
|
||||
return limit
|
||||
|
||||
|
||||
async def _enforce_run_frequency(
|
||||
tier: str,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> None:
|
||||
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
|
||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
||||
if limit == -1:
|
||||
return # unlimited
|
||||
|
||||
today_start = datetime.now(timezone.utc).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
result = await db.execute(
|
||||
select(func.count(ScoutRunLog.id)).where(
|
||||
ScoutRunLog.user_id == user_id,
|
||||
ScoutRunLog.started_at >= today_start,
|
||||
)
|
||||
)
|
||||
runs_today: int = result.scalar_one()
|
||||
|
||||
if runs_today >= limit:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
|
||||
)
|
||||
|
||||
|
||||
# ── Catalog ───────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/catalog", response_model=list[ScoutCatalogItem])
|
||||
async def get_agent_catalog(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> list[ScoutCatalogItem]:
|
||||
"""Return the static list of available agent types and their descriptions."""
|
||||
return [
|
||||
ScoutCatalogItem(
|
||||
type="local_directory",
|
||||
name="Local Directory Monitor",
|
||||
description="Watches local directories, extracts data from files using AI",
|
||||
),
|
||||
ScoutCatalogItem(
|
||||
type="gmail",
|
||||
name="Gmail Connector",
|
||||
description="Scans Gmail inbox, extracts tasks/notes from emails",
|
||||
),
|
||||
ScoutCatalogItem(
|
||||
type="teams",
|
||||
name="Microsoft Teams Connector",
|
||||
description="Monitors Teams messages, extracts action items",
|
||||
),
|
||||
ScoutCatalogItem(
|
||||
type="outlook",
|
||||
name="Outlook Connector",
|
||||
description="Scans Outlook inbox, extracts tasks/notes",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@router.post("/can-create", response_model=ScoutCreationCheckResponse)
|
||||
async def can_create_agent(
|
||||
body: ScoutCreationCheckRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> ScoutCreationCheckResponse:
|
||||
"""Check if the user can create one more agent based on billing tier.
|
||||
|
||||
Since configuration is client-owned, the Electron app sends its current
|
||||
active agent count and the backend applies tier limits.
|
||||
"""
|
||||
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
|
||||
allowed = limit == -1 or body.active_agents < limit
|
||||
return ScoutCreationCheckResponse(
|
||||
allowed=allowed,
|
||||
tier=current_user.tier,
|
||||
active_agents=body.active_agents,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/trigger", response_model=ScoutRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
||||
async def trigger_agent_run(
|
||||
body: ScoutTriggerRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> ScoutRunLogResponse:
|
||||
"""Trigger a local agent run using client-provided configuration."""
|
||||
_enforce_agent_limit(current_user.tier, body.active_agents)
|
||||
await _enforce_run_frequency(current_user.tier, current_user.id, db)
|
||||
|
||||
last_run_dt = (
|
||||
datetime.fromtimestamp(body.last_run_at / 1000, tz=timezone.utc)
|
||||
if body.last_run_at
|
||||
else None
|
||||
)
|
||||
config = LocalScoutConfig(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=current_user.id,
|
||||
device_id=body.device_id,
|
||||
name="Local Directory Monitor",
|
||||
directory_paths=[body.directory],
|
||||
data_types=_to_data_types(body.what_to_extract),
|
||||
prompt_template=body.custom_agent_prompt or "",
|
||||
scout_config=body.agent_config,
|
||||
file_extensions=[],
|
||||
schedule_cron=body.batch_interval,
|
||||
enabled=True,
|
||||
last_run_at=last_run_dt,
|
||||
)
|
||||
|
||||
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
|
||||
stable_agent_id = body.agent_id or config.id
|
||||
|
||||
if is_agent_running(stable_agent_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Agent is already running. Only one run per agent is allowed at a time.",
|
||||
)
|
||||
|
||||
run_log = ScoutRunLog(
|
||||
scout_id=stable_agent_id,
|
||||
scout_type="local",
|
||||
user_id=current_user.id,
|
||||
status="running",
|
||||
)
|
||||
db.add(run_log)
|
||||
await db.commit()
|
||||
await db.refresh(run_log)
|
||||
|
||||
run_context = {
|
||||
"type": "agent_batch",
|
||||
"run_id": run_log.id,
|
||||
"agent_id": stable_agent_id,
|
||||
}
|
||||
|
||||
asyncio.create_task(
|
||||
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
|
||||
)
|
||||
|
||||
return _to_run_log_response(run_log)
|
||||
|
||||
|
||||
# ── Note summary endpoint ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class NoteSummarizeRequest(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
|
||||
|
||||
class NoteSummarizeResponse(BaseModel):
|
||||
summary: str
|
||||
|
||||
|
||||
@router.post("/notes/summarize", response_model=NoteSummarizeResponse)
|
||||
async def summarize_note(
|
||||
body: NoteSummarizeRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> NoteSummarizeResponse:
|
||||
"""Generate an AI summary for a note. Used by the Electron backfill on startup."""
|
||||
summary = await generate_note_summary(body.title, body.content)
|
||||
return NoteSummarizeResponse(summary=summary)
|
||||
|
||||
|
||||
# ── Cloud scout CRUD ──────────────────────────────────────────────────────────
|
||||
|
||||
_DEFAULT_CLOUD_SCHEDULE = "0 */6 * * *"
|
||||
|
||||
|
||||
def _to_cloud_response(scout: CloudScoutConfig) -> dict:
|
||||
return {
|
||||
"id": scout.id,
|
||||
"user_id": scout.user_id,
|
||||
"provider": scout.provider,
|
||||
"name": scout.name,
|
||||
"data_types": scout.data_types or [],
|
||||
"prompt_template": scout.prompt_template or "",
|
||||
"schedule_cron": scout.schedule_cron,
|
||||
"filter_config": scout.filter_config,
|
||||
"auto_trash_spam": scout.auto_trash_spam,
|
||||
"enabled": scout.enabled,
|
||||
"last_run_at": _dt_ms_opt(scout.last_run_at),
|
||||
"gmail_address": scout.gmail_address,
|
||||
"oauth_connected": scout.oauth_token_encrypted is not None,
|
||||
"created_at": _dt_ms(scout.created_at),
|
||||
"updated_at": _dt_ms(scout.updated_at),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/cloud", response_model=list[CloudScoutResponse])
|
||||
async def list_cloud_scouts(
|
||||
db: AsyncSession = Depends(get_session),
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
):
|
||||
rows = (await db.execute(
|
||||
select(CloudScoutConfig).where(CloudScoutConfig.user_id == current_user.id)
|
||||
)).scalars().all()
|
||||
return [_to_cloud_response(s) for s in rows]
|
||||
|
||||
|
||||
@router.post("/cloud", response_model=CloudScoutResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_cloud_scout(
|
||||
body: CloudScoutCreateRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
):
|
||||
scout = CloudScoutConfig(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=current_user.id,
|
||||
provider=body.provider,
|
||||
name=body.name,
|
||||
data_types=body.data_types,
|
||||
prompt_template=body.prompt_template,
|
||||
filter_config=body.filter_config,
|
||||
schedule_cron=body.schedule_cron or _DEFAULT_CLOUD_SCHEDULE,
|
||||
auto_trash_spam=body.auto_trash_spam,
|
||||
enabled=True,
|
||||
)
|
||||
db.add(scout)
|
||||
await db.commit()
|
||||
await db.refresh(scout)
|
||||
return _to_cloud_response(scout)
|
||||
|
||||
|
||||
@router.put("/cloud/{scout_id}", response_model=CloudScoutResponse)
|
||||
async def update_cloud_scout(
|
||||
scout_id: str,
|
||||
body: CloudScoutUpdateRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
):
|
||||
scout = await db.get(CloudScoutConfig, scout_id)
|
||||
if scout is None or scout.user_id != current_user.id:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, "Scout not found")
|
||||
if body.name is not None:
|
||||
scout.name = body.name
|
||||
if body.data_types is not None:
|
||||
scout.data_types = body.data_types
|
||||
if body.prompt_template is not None:
|
||||
scout.prompt_template = body.prompt_template
|
||||
if body.schedule_cron is not None:
|
||||
scout.schedule_cron = body.schedule_cron
|
||||
if body.filter_config is not None:
|
||||
scout.filter_config = body.filter_config
|
||||
if body.auto_trash_spam is not None:
|
||||
scout.auto_trash_spam = body.auto_trash_spam
|
||||
if body.enabled is not None:
|
||||
scout.enabled = body.enabled
|
||||
await db.commit()
|
||||
await db.refresh(scout)
|
||||
return _to_cloud_response(scout)
|
||||
|
||||
|
||||
@router.delete("/cloud/{scout_id}")
|
||||
async def delete_cloud_scout(
|
||||
scout_id: str,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
):
|
||||
scout = await db.get(CloudScoutConfig, scout_id)
|
||||
if scout is None or scout.user_id != current_user.id:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, "Scout not found")
|
||||
# Core deletes bypass the polymorphic ScoutRunLog relationship whose
|
||||
# varchar scout_id vs uuid id join is not directly comparable in Postgres.
|
||||
# scout_run_logs.scout_id is a plain string (matches the str scout_id);
|
||||
# scout_triage_queue rows cascade automatically via their FK ondelete.
|
||||
await db.execute(sa_delete(ScoutRunLog).where(ScoutRunLog.scout_id == scout_id))
|
||||
await db.execute(sa_delete(CloudScoutConfig).where(CloudScoutConfig.id == scout_id))
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/cloud/{scout_id}/gmail-labels")
|
||||
async def list_gmail_labels(
|
||||
scout_id: str,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
):
|
||||
scout = await db.get(CloudScoutConfig, scout_id)
|
||||
if scout is None or scout.user_id != current_user.id:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, "Scout not found")
|
||||
try:
|
||||
connector = get_connector("gmail")
|
||||
except KeyError:
|
||||
return []
|
||||
return await connector.list_labels(scout)
|
||||
|
||||
|
||||
@router.post("/cloud/{scout_id}/gmail-disconnect", response_model=CloudScoutResponse)
|
||||
async def disconnect_gmail(
|
||||
scout_id: str,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
):
|
||||
scout = await db.get(CloudScoutConfig, scout_id)
|
||||
if scout is None or scout.user_id != current_user.id:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, "Scout not found")
|
||||
try:
|
||||
connector = get_connector("gmail")
|
||||
await connector.stop_watch(scout)
|
||||
except KeyError:
|
||||
pass
|
||||
scout.oauth_token_encrypted = None
|
||||
scout.gmail_history_id = None
|
||||
scout.gmail_watch_expires_at = None
|
||||
scout.gmail_address = None
|
||||
scout.enabled = False
|
||||
await db.commit()
|
||||
await db.refresh(scout)
|
||||
return _to_cloud_response(scout)
|
||||
|
||||
|
||||
# ── Gmail OAuth setup (scout-specific) ───────────────────────────────────────
|
||||
|
||||
# Scopes required for Gmail scout connectivity.
|
||||
_GMAIL_SCOUT_SCOPES = [
|
||||
"openid",
|
||||
"email",
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/gmail.modify",
|
||||
]
|
||||
|
||||
# Google OAuth endpoints.
|
||||
_GOOGLE_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
_GOOGLE_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
# In-memory pending OAuth states for scout Gmail consent.
|
||||
#
|
||||
# state → {
|
||||
# "code_verifier": str,
|
||||
# "user_id": str,
|
||||
# "expires_at": float (epoch seconds),
|
||||
# "mode": "reconnect" | "create",
|
||||
# "scout_id": str | None, # set for reconnect mode
|
||||
# "draft": {name, prompt_template, auto_trash_spam} | None, # set for create mode
|
||||
# "token_encrypted": str | None, # populated after a successful create-mode callback
|
||||
# "gmail_address": str | None,
|
||||
# }
|
||||
#
|
||||
# Zero-trust: in create mode the encrypted Gmail token lives ONLY here, in
|
||||
# process memory, for at most _SCOUT_OAUTH_TTL_SECONDS. It is persisted to the
|
||||
# DB only when the user finalizes the scout (POST /scouts/cloud/finalize).
|
||||
# An abandoned/errored flow leaves no scout row and no stored token.
|
||||
#
|
||||
# Production note: this in-memory store is single-process only — replace with
|
||||
# Redis (keyed by state, TTL'd) for multi-worker deployments.
|
||||
_pending_scout_oauth_states: dict[str, dict] = {}
|
||||
_SCOUT_OAUTH_TTL_SECONDS = 900 # 15 minutes
|
||||
|
||||
|
||||
def _purge_expired_oauth_states() -> None:
|
||||
now = time.time()
|
||||
expired = [s for s, e in _pending_scout_oauth_states.items() if e.get("expires_at", 0) < now]
|
||||
for s in expired:
|
||||
del _pending_scout_oauth_states[s]
|
||||
|
||||
|
||||
def _scout_gmail_redirect_uri() -> str:
|
||||
"""Derive the scout Gmail web-callback URI from the configured base OAUTH_REDIRECT_URI.
|
||||
|
||||
``OAUTH_REDIRECT_URI`` is the full path used for login OAuth
|
||||
(e.g. http://localhost:8000/api/v1/auth/oauth/google/web-callback).
|
||||
We strip the path to get the scheme+host base, then append the scout path.
|
||||
"""
|
||||
parsed = urllib.parse.urlparse(settings.OAUTH_REDIRECT_URI)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
return f"{base}/api/v1/scouts/oauth/gmail/web-callback"
|
||||
|
||||
|
||||
class _ScoutGmailAuthorizeResponse(BaseModel):
|
||||
authorize_url: str
|
||||
|
||||
|
||||
class _ScoutGmailCallbackBody(BaseModel):
|
||||
code: str
|
||||
state: str
|
||||
|
||||
|
||||
class _ScoutGmailAuthorizeDraftBody(BaseModel):
|
||||
name: str
|
||||
prompt_template: str = ""
|
||||
auto_trash_spam: bool = False
|
||||
|
||||
|
||||
class _ScoutGmailFinalizeBody(BaseModel):
|
||||
session: str
|
||||
filter_config: dict | None = None
|
||||
|
||||
|
||||
def _build_gmail_authorize_url(state: str, code_challenge: str) -> str:
|
||||
"""Build the Google consent URL for the scout Gmail flow (shared by both modes)."""
|
||||
redirect_uri = _scout_gmail_redirect_uri()
|
||||
params = {
|
||||
"client_id": settings.GOOGLE_AUTH_CLIENT_ID,
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"scope": " ".join(_GMAIL_SCOUT_SCOPES),
|
||||
"state": state,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
}
|
||||
return f"{_GOOGLE_AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
|
||||
@router.get("/oauth/gmail/authorize", response_model=_ScoutGmailAuthorizeResponse)
|
||||
async def scout_gmail_oauth_authorize(
|
||||
scout_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> _ScoutGmailAuthorizeResponse:
|
||||
"""Start the Gmail OAuth flow for a specific cloud scout.
|
||||
|
||||
Returns the Google consent-screen URL. The client opens this URL in the
|
||||
system browser; after consent Google redirects to web-callback which bounces
|
||||
to the ``adiuvai://scout/oauth/gmail/callback`` deep link.
|
||||
"""
|
||||
if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
"Google OAuth is not configured on this server",
|
||||
)
|
||||
|
||||
code_verifier, code_challenge = generate_pkce_pair()
|
||||
state = secrets.token_urlsafe(32)
|
||||
|
||||
_purge_expired_oauth_states()
|
||||
|
||||
_pending_scout_oauth_states[state] = {
|
||||
"code_verifier": code_verifier,
|
||||
"user_id": current_user.id,
|
||||
"expires_at": time.time() + _SCOUT_OAUTH_TTL_SECONDS,
|
||||
"mode": "reconnect",
|
||||
"scout_id": scout_id,
|
||||
"draft": None,
|
||||
"token_encrypted": None,
|
||||
"gmail_address": None,
|
||||
}
|
||||
|
||||
return _ScoutGmailAuthorizeResponse(
|
||||
authorize_url=_build_gmail_authorize_url(state, code_challenge)
|
||||
)
|
||||
|
||||
|
||||
@router.post("/oauth/gmail/authorize-draft", response_model=_ScoutGmailAuthorizeResponse)
|
||||
async def scout_gmail_oauth_authorize_draft(
|
||||
body: _ScoutGmailAuthorizeDraftBody,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> _ScoutGmailAuthorizeResponse:
|
||||
"""Start the Gmail OAuth flow in *creation* mode — no scout row exists yet.
|
||||
|
||||
The draft scout fields are held in the pending OAuth session; the scout is
|
||||
only created once the user finalizes (POST /scouts/cloud/finalize).
|
||||
"""
|
||||
if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
"Google OAuth is not configured on this server",
|
||||
)
|
||||
|
||||
code_verifier, code_challenge = generate_pkce_pair()
|
||||
state = secrets.token_urlsafe(32)
|
||||
|
||||
_purge_expired_oauth_states()
|
||||
|
||||
_pending_scout_oauth_states[state] = {
|
||||
"code_verifier": code_verifier,
|
||||
"user_id": current_user.id,
|
||||
"expires_at": time.time() + _SCOUT_OAUTH_TTL_SECONDS,
|
||||
"mode": "create",
|
||||
"scout_id": None,
|
||||
"draft": {
|
||||
"name": body.name,
|
||||
"prompt_template": body.prompt_template,
|
||||
"auto_trash_spam": body.auto_trash_spam,
|
||||
},
|
||||
"token_encrypted": None,
|
||||
"gmail_address": None,
|
||||
}
|
||||
|
||||
return _ScoutGmailAuthorizeResponse(
|
||||
authorize_url=_build_gmail_authorize_url(state, code_challenge)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/oauth/gmail/web-callback", include_in_schema=False)
|
||||
async def scout_gmail_oauth_web_callback(code: str, state: str) -> RedirectResponse:
|
||||
"""Google redirects here after Gmail consent.
|
||||
|
||||
Immediately bounces to the Electron deep link so the desktop app
|
||||
receives the authorization code.
|
||||
"""
|
||||
params = urllib.parse.urlencode({"code": code, "state": state})
|
||||
deep_link = f"adiuvai://scout/oauth/gmail/callback?{params}"
|
||||
return RedirectResponse(url=deep_link, status_code=302)
|
||||
|
||||
|
||||
@router.post("/oauth/gmail/callback")
|
||||
async def scout_gmail_oauth_callback(
|
||||
body: _ScoutGmailCallbackBody,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Exchange the Gmail authorization code and store the encrypted token on the scout.
|
||||
|
||||
Called by the Electron app after it receives the deep-link callback with
|
||||
the ``code`` and ``state`` params.
|
||||
"""
|
||||
entry = _pending_scout_oauth_states.pop(body.state, None)
|
||||
if (
|
||||
entry is None
|
||||
or entry["expires_at"] < time.time()
|
||||
or entry["user_id"] != current_user.id
|
||||
):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state")
|
||||
|
||||
code_verifier = entry["code_verifier"]
|
||||
mode = entry["mode"]
|
||||
scout_id = entry.get("scout_id")
|
||||
|
||||
redirect_uri = _scout_gmail_redirect_uri()
|
||||
|
||||
import httpx
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
_GOOGLE_TOKEN_URL,
|
||||
data={
|
||||
"client_id": settings.GOOGLE_AUTH_CLIENT_ID,
|
||||
"client_secret": settings.GOOGLE_AUTH_CLIENT_SECRET,
|
||||
"code": body.code,
|
||||
"code_verifier": code_verifier,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
logger.error("Gmail token exchange failed: %s", exc.response.text)
|
||||
raise HTTPException(status.HTTP_502_BAD_GATEWAY, "Failed to exchange Gmail authorization code")
|
||||
|
||||
token_data = response.json()
|
||||
|
||||
creds_dict: dict = {
|
||||
"token": token_data["access_token"],
|
||||
"refresh_token": token_data.get("refresh_token"),
|
||||
"token_uri": _GOOGLE_TOKEN_URL,
|
||||
"client_id": settings.GOOGLE_AUTH_CLIENT_ID,
|
||||
"client_secret": settings.GOOGLE_AUTH_CLIENT_SECRET,
|
||||
"scopes": [
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/gmail.modify",
|
||||
],
|
||||
}
|
||||
encrypted = encrypt_token(creds_dict)
|
||||
|
||||
# Fetch the connected Gmail address for display.
|
||||
gmail_address: str | None = None
|
||||
try:
|
||||
from googleapiclient.discovery import build
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
def _fetch_email() -> str | None:
|
||||
creds = Credentials(
|
||||
token=creds_dict["token"],
|
||||
refresh_token=creds_dict.get("refresh_token"),
|
||||
token_uri=creds_dict["token_uri"],
|
||||
client_id=creds_dict["client_id"],
|
||||
client_secret=creds_dict["client_secret"],
|
||||
scopes=creds_dict["scopes"],
|
||||
)
|
||||
service = build("gmail", "v1", credentials=creds, cache_discovery=False)
|
||||
profile = service.users().getProfile(userId="me").execute()
|
||||
return profile.get("emailAddress")
|
||||
|
||||
gmail_address = await asyncio.to_thread(_fetch_email)
|
||||
except Exception:
|
||||
logger.exception("failed to fetch gmail address (mode=%s)", mode)
|
||||
|
||||
if mode == "create":
|
||||
# Do NOT create a scout yet. Hold the encrypted token + address in the
|
||||
# transient in-memory session; the scout is created at finalize.
|
||||
entry["token_encrypted"] = encrypted
|
||||
entry["gmail_address"] = gmail_address
|
||||
entry["expires_at"] = time.time() + _SCOUT_OAUTH_TTL_SECONDS
|
||||
_pending_scout_oauth_states[body.state] = entry
|
||||
return {"ok": True, "session_id": body.state, "gmail_address": gmail_address}
|
||||
|
||||
# mode == "reconnect": update the existing scout in place.
|
||||
scout = await db.get(CloudScoutConfig, scout_id)
|
||||
if scout is None or scout.user_id != current_user.id:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, "Scout not found")
|
||||
scout.oauth_token_encrypted = encrypted
|
||||
scout.gmail_address = gmail_address
|
||||
|
||||
await db.commit()
|
||||
|
||||
# Attempt to set up Gmail push watch so we start receiving Pub/Sub notifications.
|
||||
try:
|
||||
connector = get_connector("gmail")
|
||||
await connector.setup_watch(scout)
|
||||
await db.commit()
|
||||
except KeyError:
|
||||
logger.warning("gmail connector not registered — skipping setup_watch for scout %s", scout_id)
|
||||
except Exception:
|
||||
logger.exception("setup_watch failed for scout %s", scout_id)
|
||||
|
||||
return {"ok": True, "session_id": None, "gmail_address": gmail_address}
|
||||
|
||||
|
||||
@router.get("/oauth/gmail/session-labels")
|
||||
async def scout_gmail_session_labels(
|
||||
session: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> list[dict]:
|
||||
"""List Gmail labels for a pending create-mode OAuth session (no scout row yet).
|
||||
|
||||
Builds a Gmail service from the session's transient decrypted token.
|
||||
Returns [] on any error.
|
||||
"""
|
||||
entry = _pending_scout_oauth_states.get(session)
|
||||
if (
|
||||
entry is None
|
||||
or entry["expires_at"] < time.time()
|
||||
or entry["user_id"] != current_user.id
|
||||
or entry.get("token_encrypted") is None
|
||||
):
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, "Session not found or expired")
|
||||
|
||||
try:
|
||||
from app.scouts.connectors.gmail import _gmail_service_from_token
|
||||
|
||||
creds = decrypt_token(entry["token_encrypted"])
|
||||
|
||||
def _sync() -> list[dict]:
|
||||
service = _gmail_service_from_token(creds)
|
||||
resp = service.users().labels().list(userId="me").execute()
|
||||
return [{"id": lbl["id"], "name": lbl["name"]} for lbl in resp.get("labels", [])]
|
||||
|
||||
return await asyncio.to_thread(_sync)
|
||||
except Exception:
|
||||
logger.exception("session-labels failed for session %s", session)
|
||||
return []
|
||||
|
||||
|
||||
@router.post("/cloud/finalize", response_model=CloudScoutResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def finalize_cloud_scout(
|
||||
body: _ScoutGmailFinalizeBody,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
):
|
||||
"""Create the cloud scout from a completed create-mode OAuth session.
|
||||
|
||||
This is the only path that persists the Gmail token for a newly-created
|
||||
scout. Abandoned flows never reach here, so they leave no orphan rows.
|
||||
"""
|
||||
entry = _pending_scout_oauth_states.pop(body.session, None)
|
||||
if (
|
||||
entry is None
|
||||
or entry["expires_at"] < time.time()
|
||||
or entry["user_id"] != current_user.id
|
||||
or entry.get("mode") != "create"
|
||||
or entry.get("token_encrypted") is None
|
||||
):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth session")
|
||||
|
||||
draft = entry["draft"] or {}
|
||||
scout = CloudScoutConfig(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=current_user.id,
|
||||
provider="gmail",
|
||||
name=draft.get("name", ""),
|
||||
data_types=[],
|
||||
prompt_template=draft.get("prompt_template", ""),
|
||||
filter_config=body.filter_config,
|
||||
schedule_cron=_DEFAULT_CLOUD_SCHEDULE,
|
||||
auto_trash_spam=draft.get("auto_trash_spam", False),
|
||||
enabled=True,
|
||||
oauth_token_encrypted=entry["token_encrypted"],
|
||||
gmail_address=entry.get("gmail_address"),
|
||||
)
|
||||
db.add(scout)
|
||||
await db.commit()
|
||||
await db.refresh(scout)
|
||||
|
||||
# Best-effort Gmail push watch — failure must not block scout creation.
|
||||
try:
|
||||
connector = get_connector("gmail")
|
||||
await connector.setup_watch(scout)
|
||||
await db.commit()
|
||||
except KeyError:
|
||||
logger.warning("gmail connector not registered — skipping setup_watch for scout %s", scout.id)
|
||||
except Exception:
|
||||
logger.exception("setup_watch failed for scout %s", scout.id)
|
||||
|
||||
return _to_cloud_response(scout)
|
||||
1
api/app/auth/__init__.py
Normal file
1
api/app/auth/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"OAuth provider abstractions and utilities."
|
||||
135
api/app/auth/oauth_providers.py
Normal file
135
api/app/auth/oauth_providers.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""OAuth 2.0 + PKCE provider abstractions.
|
||||
|
||||
Each provider implements a three-step flow designed for a desktop (public) client:
|
||||
|
||||
1. get_authorization_url(state, code_challenge) → str
|
||||
Build the provider's consent-screen URL. State and code_challenge are
|
||||
generated server-side; the client opens this URL in the system browser.
|
||||
|
||||
2. exchange_code(code, code_verifier, redirect_uri) → dict
|
||||
Exchange the short-lived authorization code for an access token.
|
||||
The code_verifier proves ownership of the PKCE challenge.
|
||||
|
||||
3. get_userinfo(access_token) → OAuthUserInfo
|
||||
Fetch the canonical user identity from the provider.
|
||||
|
||||
Currently supported providers:
|
||||
- GoogleOAuthProvider (scope: openid email profile)
|
||||
|
||||
Adding a new provider:
|
||||
- Implement the three methods above.
|
||||
- Register in _PROVIDERS inside routes/auth.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
# ── Data transfer objects ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthUserInfo:
|
||||
"""Normalized user identity returned by any provider."""
|
||||
|
||||
provider_user_id: str
|
||||
email: str
|
||||
email_verified: bool
|
||||
avatar_url: str | None
|
||||
name: str | None
|
||||
|
||||
|
||||
# ── PKCE helpers ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def generate_pkce_pair() -> tuple[str, str]:
|
||||
"""Generate a (code_verifier, code_challenge) pair for PKCE S256.
|
||||
|
||||
The code_verifier is a random 32-byte URL-safe base64 string.
|
||||
The code_challenge is SHA-256(code_verifier) base64url-encoded (no padding).
|
||||
"""
|
||||
code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=").decode()
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||
return code_verifier, code_challenge
|
||||
|
||||
|
||||
# ── Google provider ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class GoogleOAuthProvider:
|
||||
"""Google OAuth 2.0 provider (openid email profile scope).
|
||||
|
||||
Uses Google's standard authorization endpoint with PKCE S256.
|
||||
Does NOT use google-auth-oauthlib to keep the flow generic and async.
|
||||
"""
|
||||
|
||||
name = "google"
|
||||
|
||||
_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
_USERINFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str) -> None:
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_authorization_url(self, state: str, code_challenge: str) -> str:
|
||||
"""Build the Google consent-screen URL."""
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"response_type": "code",
|
||||
"scope": "openid email profile",
|
||||
"state": state,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"access_type": "offline",
|
||||
"prompt": "select_account",
|
||||
}
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
async def exchange_code(
|
||||
self, code: str, code_verifier: str, redirect_uri: str
|
||||
) -> dict:
|
||||
"""Exchange authorization code for an access token."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self._TOKEN_URL,
|
||||
data={
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"code": code,
|
||||
"code_verifier": code_verifier,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def get_userinfo(self, access_token: str) -> OAuthUserInfo:
|
||||
"""Fetch the authenticated user's identity from Google."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
self._USERINFO_URL,
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
return OAuthUserInfo(
|
||||
provider_user_id=data["sub"],
|
||||
email=data["email"],
|
||||
email_verified=data.get("email_verified", False),
|
||||
avatar_url=data.get("picture"),
|
||||
name=data.get("name"),
|
||||
)
|
||||
4
api/app/billing/__init__.py
Normal file
4
api/app/billing/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from app.billing.stripe_service import stripe_service
|
||||
from app.billing.tier_manager import tier_manager
|
||||
|
||||
__all__ = ["stripe_service", "tier_manager"]
|
||||
139
api/app/billing/quota.py
Normal file
139
api/app/billing/quota.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""Quota checks and atomic token-usage accounting for folder integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.billing.tier_manager import TierManager
|
||||
from app.models import MonthlyTokenUsage
|
||||
from app.schemas import BillingTier
|
||||
|
||||
|
||||
class QuotaExceeded(Exception):
|
||||
"""Raised when a folder operation cannot proceed under the user's tier."""
|
||||
|
||||
def __init__(self, reason: str, message: str) -> None:
|
||||
super().__init__(message)
|
||||
self.reason = reason # "max_files" | "monthly_tokens"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenUsageResult:
|
||||
tokens_used: int
|
||||
exhausted: bool
|
||||
|
||||
|
||||
def _current_year_month() -> str:
|
||||
return datetime.now(timezone.utc).strftime("%Y-%m")
|
||||
|
||||
|
||||
_tier_manager = TierManager()
|
||||
|
||||
|
||||
async def check_folder_quota(
|
||||
*,
|
||||
user_id: str,
|
||||
tier: BillingTier,
|
||||
estimated_files: int,
|
||||
db: AsyncSession,
|
||||
) -> None:
|
||||
"""Raise QuotaExceeded if folder_max_files or folder_monthly_tokens
|
||||
would be violated. -1 in either feature means unlimited."""
|
||||
max_files = _tier_manager.get_feature_value(tier, "folder_max_files")
|
||||
if max_files != -1 and estimated_files > max_files:
|
||||
raise QuotaExceeded(
|
||||
"max_files",
|
||||
f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.",
|
||||
)
|
||||
|
||||
cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens")
|
||||
if cap == -1:
|
||||
return
|
||||
ym = _current_year_month()
|
||||
row = (
|
||||
await db.execute(
|
||||
select(MonthlyTokenUsage).where(
|
||||
MonthlyTokenUsage.user_id == user_id,
|
||||
MonthlyTokenUsage.year_month == ym,
|
||||
MonthlyTokenUsage.feature == "folder_index",
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
used = row.tokens_used if row else 0
|
||||
if used >= cap:
|
||||
raise QuotaExceeded(
|
||||
"monthly_tokens",
|
||||
f"Monthly token budget exhausted ({used}/{cap}); resets next month.",
|
||||
)
|
||||
|
||||
|
||||
async def add_token_usage(
|
||||
*,
|
||||
user_id: str,
|
||||
feature: str,
|
||||
tokens: int,
|
||||
db: AsyncSession,
|
||||
cap: int | None = None,
|
||||
) -> TokenUsageResult:
|
||||
"""Atomically add `tokens` to MonthlyTokenUsage row for (user, current month, feature).
|
||||
|
||||
Uses PostgreSQL ``INSERT … ON CONFLICT DO UPDATE`` when available; falls
|
||||
back to a read-then-write on other engines (e.g. aiosqlite in tests).
|
||||
Returns post-update total and whether cap is exhausted.
|
||||
"""
|
||||
ym = _current_year_month()
|
||||
|
||||
# Detect dialect to choose between native upsert and portable fallback.
|
||||
dialect_name: str = db.bind.dialect.name if db.bind is not None else "" # type: ignore[union-attr]
|
||||
|
||||
if dialect_name == "postgresql":
|
||||
# Native atomic upsert — production path.
|
||||
stmt = (
|
||||
pg_insert(MonthlyTokenUsage)
|
||||
.values(
|
||||
user_id=user_id,
|
||||
year_month=ym,
|
||||
feature=feature,
|
||||
tokens_used=tokens,
|
||||
)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["user_id", "year_month", "feature"],
|
||||
set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens},
|
||||
)
|
||||
.returning(MonthlyTokenUsage.tokens_used)
|
||||
)
|
||||
used: int = (await db.execute(stmt)).scalar_one()
|
||||
await db.commit()
|
||||
else:
|
||||
# Portable fallback — used in tests (SQLite) and any non-PG engine.
|
||||
row = (
|
||||
await db.execute(
|
||||
select(MonthlyTokenUsage).where(
|
||||
MonthlyTokenUsage.user_id == user_id,
|
||||
MonthlyTokenUsage.year_month == ym,
|
||||
MonthlyTokenUsage.feature == feature,
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if row is None:
|
||||
row = MonthlyTokenUsage(
|
||||
user_id=user_id,
|
||||
year_month=ym,
|
||||
feature=feature,
|
||||
tokens_used=tokens,
|
||||
)
|
||||
db.add(row)
|
||||
else:
|
||||
row.tokens_used += tokens
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(row)
|
||||
used = row.tokens_used
|
||||
|
||||
exhausted = cap is not None and cap != -1 and used >= cap
|
||||
return TokenUsageResult(tokens_used=used, exhausted=exhausted)
|
||||
295
api/app/billing/stripe_service.py
Normal file
295
api/app/billing/stripe_service.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""Stripe service: checkout sessions, webhook handling, subscription management.
|
||||
|
||||
Subscription records are persisted in the PostgreSQL ``subscriptions`` table.
|
||||
All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not
|
||||
configured, enabling local development without live credentials.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import stripe as stripe_lib
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
# Stripe price IDs per tier — replace with real IDs in production .env
|
||||
TIER_PRICE_IDS: dict[str, str] = {
|
||||
"pro": "price_pro_monthly",
|
||||
"power": "price_power_monthly",
|
||||
"team": "price_team_monthly",
|
||||
}
|
||||
|
||||
|
||||
class StripeService:
|
||||
"""Wraps all Stripe interactions and owns subscription persistence."""
|
||||
|
||||
# ── Internal helpers ────────────────────────────────────────────────
|
||||
|
||||
def _configured(self) -> bool:
|
||||
return bool(settings.STRIPE_SECRET_KEY)
|
||||
|
||||
def _client(self) -> Any:
|
||||
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||
return stripe_lib
|
||||
|
||||
# ── Public API ──────────────────────────────────────────────────────
|
||||
|
||||
def create_checkout_session(
|
||||
self,
|
||||
user_id: str,
|
||||
tier: str,
|
||||
success_url: str = "https://app.adiuvai.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
||||
cancel_url: str = "https://app.adiuvai.app/billing/cancel",
|
||||
) -> str:
|
||||
"""Create a Stripe checkout session and return the URL.
|
||||
|
||||
Returns a stub URL when Stripe is not configured.
|
||||
Raises ``HTTP 400`` for the free tier or an unknown tier.
|
||||
"""
|
||||
if tier == "free":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot create a checkout session for the free tier",
|
||||
)
|
||||
|
||||
price_id = TIER_PRICE_IDS.get(tier)
|
||||
if not price_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unknown tier: {tier}",
|
||||
)
|
||||
|
||||
if not self._configured():
|
||||
return "https://stripe.com/stub-checkout"
|
||||
|
||||
s = self._client()
|
||||
session = s.checkout.Session.create(
|
||||
payment_method_types=["card"],
|
||||
mode="subscription",
|
||||
line_items=[{"price": price_id, "quantity": 1}],
|
||||
success_url=success_url,
|
||||
cancel_url=cancel_url,
|
||||
metadata={"user_id": user_id, "tier": tier},
|
||||
)
|
||||
return session.url
|
||||
|
||||
async def handle_webhook(
|
||||
self,
|
||||
payload: bytes,
|
||||
sig_header: str,
|
||||
db: AsyncSession,
|
||||
) -> None:
|
||||
"""Process a Stripe webhook event.
|
||||
|
||||
Verifies the signature, then dispatches on event type.
|
||||
Raises ``HTTP 400`` on signature mismatch.
|
||||
No-ops when Stripe is not configured.
|
||||
"""
|
||||
if not self._configured():
|
||||
return
|
||||
|
||||
try:
|
||||
s = self._client()
|
||||
event = s.Webhook.construct_event(
|
||||
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except stripe_lib.error.SignatureVerificationError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid Stripe signature",
|
||||
)
|
||||
|
||||
event_type: str = event["type"]
|
||||
data: dict[str, Any] = event["data"]["object"]
|
||||
|
||||
if event_type == "checkout.session.completed":
|
||||
user_id = data.get("metadata", {}).get("user_id")
|
||||
tier = data.get("metadata", {}).get("tier", "free")
|
||||
sub_id = data.get("subscription")
|
||||
period_end_ts = data.get("current_period_end")
|
||||
period_end = (
|
||||
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||
if period_end_ts
|
||||
else None
|
||||
)
|
||||
if user_id:
|
||||
await self._upsert_subscription(
|
||||
db, user_id, sub_id, tier, "active", period_end
|
||||
)
|
||||
|
||||
elif event_type == "customer.subscription.updated":
|
||||
sub_id = data.get("id")
|
||||
new_status = data.get("status", "active")
|
||||
period_end_ts = data.get("current_period_end")
|
||||
period_end = (
|
||||
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||
if period_end_ts
|
||||
else None
|
||||
)
|
||||
if sub_id:
|
||||
await self._update_subscription_by_stripe_id(
|
||||
db, sub_id, status=new_status, current_period_end=period_end
|
||||
)
|
||||
|
||||
elif event_type == "customer.subscription.deleted":
|
||||
sub_id = data.get("id")
|
||||
if sub_id:
|
||||
await self._update_subscription_by_stripe_id(
|
||||
db, sub_id, tier="free", status="canceled"
|
||||
)
|
||||
|
||||
elif event_type == "invoice.payment_failed":
|
||||
sub_id = data.get("subscription")
|
||||
if sub_id:
|
||||
await self._update_subscription_by_stripe_id(
|
||||
db, sub_id, status="past_due"
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
|
||||
async def get_subscription(
|
||||
self, user_id: str, db: AsyncSession
|
||||
) -> dict[str, Any] | None:
|
||||
"""Return the subscription record for ``user_id``, or ``None`` if absent."""
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None:
|
||||
return None
|
||||
return {
|
||||
"tier": sub.tier,
|
||||
"stripe_subscription_id": sub.stripe_subscription_id,
|
||||
"status": sub.status,
|
||||
"current_period_end": (
|
||||
int(sub.current_period_end.timestamp() * 1000)
|
||||
if sub.current_period_end
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
|
||||
"""Cancel the user's Stripe subscription and downgrade them to free.
|
||||
|
||||
Raises ``HTTP 404`` when no active subscription exists.
|
||||
"""
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None or not sub.stripe_subscription_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No active subscription found",
|
||||
)
|
||||
|
||||
if self._configured():
|
||||
s = self._client()
|
||||
s.Subscription.cancel(sub.stripe_subscription_id)
|
||||
|
||||
sub.tier = "free"
|
||||
sub.status = "canceled"
|
||||
await db.commit()
|
||||
|
||||
async def list_invoices(
|
||||
self, user_id: str, db: AsyncSession, limit: int = 24
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return recent invoices for the user from Stripe.
|
||||
|
||||
Returns an empty list when Stripe is not configured or the user has
|
||||
no ``stripe_customer_id``.
|
||||
"""
|
||||
if not self._configured():
|
||||
return []
|
||||
|
||||
from app.models import User # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(User.stripe_customer_id).where(User.id == user_id)
|
||||
)
|
||||
customer_id = result.scalar_one_or_none()
|
||||
if not customer_id:
|
||||
return []
|
||||
|
||||
try:
|
||||
s = self._client()
|
||||
invoices = s.Invoice.list(customer=customer_id, limit=limit)
|
||||
return [
|
||||
{
|
||||
"id": inv.id,
|
||||
"amount_due": inv.amount_due,
|
||||
"amount_paid": inv.amount_paid,
|
||||
"currency": inv.currency,
|
||||
"status": inv.status,
|
||||
"created": inv.created * 1000, # epoch ms
|
||||
"invoice_url": inv.hosted_invoice_url,
|
||||
"invoice_pdf": inv.invoice_pdf,
|
||||
}
|
||||
for inv in invoices.auto_paging_iter()
|
||||
]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
# ── Private DB helpers ───────────────────────────────────────────────
|
||||
|
||||
async def _upsert_subscription(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
stripe_subscription_id: str | None,
|
||||
tier: str,
|
||||
sub_status: str,
|
||||
current_period_end: datetime | None,
|
||||
) -> None:
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None:
|
||||
sub = Subscription(user_id=user_id)
|
||||
db.add(sub)
|
||||
sub.stripe_subscription_id = stripe_subscription_id
|
||||
sub.tier = tier
|
||||
sub.status = sub_status
|
||||
sub.current_period_end = current_period_end
|
||||
|
||||
async def _update_subscription_by_stripe_id(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
stripe_subscription_id: str,
|
||||
*,
|
||||
tier: str | None = None,
|
||||
status: str | None = None,
|
||||
current_period_end: datetime | None = None,
|
||||
) -> None:
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(
|
||||
Subscription.stripe_subscription_id == stripe_subscription_id
|
||||
)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None:
|
||||
return
|
||||
if tier is not None:
|
||||
sub.tier = tier
|
||||
if status is not None:
|
||||
sub.status = status
|
||||
if current_period_end is not None:
|
||||
sub.current_period_end = current_period_end
|
||||
|
||||
|
||||
# Module-level singleton shared across the app.
|
||||
stripe_service = StripeService()
|
||||
149
api/app/billing/tier_manager.py
Normal file
149
api/app/billing/tier_manager.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Tier manager: feature matrix and quota enforcement.
|
||||
|
||||
``TierManager`` is the single source of truth for what each billing tier
|
||||
allows. ``get_tier`` queries the ``subscriptions`` table for the live tier.
|
||||
Quota-enforcement helpers take ``tier`` directly — the caller already has it
|
||||
from ``current_user.tier`` (provided by ``get_current_user``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.schemas import BillingTier
|
||||
|
||||
# Feature matrix per tier. -1 means unlimited; 0 means disabled.
|
||||
FEATURES: dict[str, dict[str, Any]] = {
|
||||
"free": {
|
||||
"agents": 3,
|
||||
"batch_active": 2,
|
||||
"batch_runs_per_day": 5,
|
||||
"providers": 1,
|
||||
"batch_builder": False,
|
||||
"sso": False,
|
||||
"real_embeddings": False, # keyword fallback only
|
||||
"realtime_extraction": False, # batch queue (Phase 2)
|
||||
"relational_memory": False, # relational tier (Phase 3) — Pro+
|
||||
"proactive_mining": False, # Power+ only (Phase 5)
|
||||
"folder_max_files": 200,
|
||||
"folder_monthly_tokens": 100_000,
|
||||
},
|
||||
"pro": {
|
||||
"agents": -1, # unlimited
|
||||
"batch_active": 10,
|
||||
"batch_runs_per_day": 50,
|
||||
"providers": -1,
|
||||
"batch_builder": False,
|
||||
"sso": False,
|
||||
"real_embeddings": True, # pgvector cosine search
|
||||
"realtime_extraction": True, # fire-and-forget asyncio.create_task
|
||||
"relational_memory": True, # person/project predicates
|
||||
"proactive_mining": False, # Power+ only (Phase 5)
|
||||
"folder_max_files": 5000,
|
||||
"folder_monthly_tokens": 2_000_000,
|
||||
},
|
||||
"power": {
|
||||
"agents": -1,
|
||||
"batch_active": -1, # unlimited
|
||||
"batch_runs_per_day": -1, # unlimited
|
||||
"providers": -1,
|
||||
"batch_builder": True,
|
||||
"sso": False,
|
||||
"real_embeddings": True,
|
||||
"realtime_extraction": True,
|
||||
"relational_memory": True, # all predicates incl. custom
|
||||
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
||||
"folder_max_files": -1, # unlimited
|
||||
"folder_monthly_tokens": -1, # unlimited
|
||||
},
|
||||
"team": {
|
||||
"agents": -1,
|
||||
"batch_active": -1,
|
||||
"batch_runs_per_day": -1, # unlimited
|
||||
"providers": -1,
|
||||
"batch_builder": True,
|
||||
"sso": True,
|
||||
"real_embeddings": True,
|
||||
"realtime_extraction": True,
|
||||
"relational_memory": True, # all predicates incl. custom
|
||||
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
||||
"folder_max_files": -1, # unlimited
|
||||
"folder_monthly_tokens": -1, # unlimited
|
||||
},
|
||||
}
|
||||
|
||||
# Requests-per-minute limit per tier.
|
||||
RATE_LIMITS: dict[str, int] = {
|
||||
"free": 20,
|
||||
"pro": 60,
|
||||
"power": 120,
|
||||
"team": 200,
|
||||
}
|
||||
|
||||
|
||||
class TierManager:
|
||||
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
|
||||
|
||||
# ── Tier lookup ─────────────────────────────────────────────────────
|
||||
|
||||
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||
"""Return the current billing tier for ``user_id`` from the DB.
|
||||
|
||||
Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod
|
||||
when no subscription row exists.
|
||||
"""
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
from app.config.settings import settings # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
tier: str | None = result.scalar_one_or_none()
|
||||
if tier is None or tier not in FEATURES:
|
||||
return "power" if settings.ENV == "dev" else "free"
|
||||
return tier # type: ignore[return-value]
|
||||
|
||||
# ── Feature access ───────────────────────────────────────────────────
|
||||
|
||||
def check_feature(self, tier: BillingTier, feature: str) -> bool:
|
||||
"""Return ``True`` if ``tier`` has ``feature`` enabled.
|
||||
|
||||
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
|
||||
"""
|
||||
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||
if value is None:
|
||||
return False
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
return value != 0
|
||||
|
||||
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
|
||||
"""Raise ``HTTP 403`` if ``tier`` does not have ``feature``."""
|
||||
if not self.check_feature(tier, feature):
|
||||
detail = (
|
||||
f"Feature '{feature}' requires {tier_name} tier or above."
|
||||
if tier_name
|
||||
else f"Feature '{feature}' is not available on your current tier."
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
def get_feature_value(self, tier: BillingTier, feature: str) -> int:
|
||||
"""Return integer feature value for tier. -1 means unlimited."""
|
||||
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||
if not isinstance(value, int):
|
||||
return 0
|
||||
return value
|
||||
|
||||
# ── Rate limiting ────────────────────────────────────────────────────
|
||||
|
||||
def get_rate_limit(self, tier: BillingTier) -> int:
|
||||
"""Return the requests-per-minute limit for ``tier``."""
|
||||
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
||||
|
||||
|
||||
# Module-level singleton shared across the app.
|
||||
tier_manager = TierManager()
|
||||
0
api/app/config/__init__.py
Normal file
0
api/app/config/__init__.py
Normal file
95
api/app/config/settings.py
Normal file
95
api/app/config/settings.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from typing import Literal
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai"
|
||||
JWT_SECRET: str = "change-me-in-production"
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 30
|
||||
|
||||
STRIPE_SECRET_KEY: str = ""
|
||||
STRIPE_WEBHOOK_SECRET: str = ""
|
||||
|
||||
OPENAI_API_KEY: str = ""
|
||||
ANTHROPIC_API_KEY: str = ""
|
||||
GOOGLE_API_KEY: str = ""
|
||||
CEREBRAS_API_KEY: str = ""
|
||||
GROQ_API_KEY: str = ""
|
||||
DEEPSEEK_API_KEY: str = ""
|
||||
|
||||
LLM_MODEL: str = "gpt-4o"
|
||||
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
||||
|
||||
# Per-agent model overrides. Leave empty to fall back to LLM_MODEL.
|
||||
LLM_MODEL_CLASSIFIER: str = "" # classifier (intent routing, future use)
|
||||
LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream)
|
||||
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
|
||||
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
|
||||
LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs)
|
||||
LLM_MODEL_TASK_BRIEF_AGENT: str = "" # task-brief-agent (per-task deep research)
|
||||
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
||||
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
|
||||
LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)
|
||||
LLM_MODEL_MEMORY_AUDITOR: str = "" # memory-auditor (Phase 7 weekly audit)
|
||||
|
||||
# GitHub Copilot OAuth token storage directory.
|
||||
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
||||
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
|
||||
GITHUB_COPILOT_TOKEN_DIR: str = ""
|
||||
|
||||
# OAuth client credentials — used for Gmail and Microsoft (Outlook/Teams) flows.
|
||||
GMAIL_CLIENT_ID: str = ""
|
||||
GMAIL_CLIENT_SECRET: str = ""
|
||||
MS_CLIENT_ID: str = ""
|
||||
MS_CLIENT_SECRET: str = ""
|
||||
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
|
||||
MS_TENANT_ID: str = "common"
|
||||
|
||||
# Google Login OAuth credentials — scope: openid email profile.
|
||||
# Separate from GMAIL_CLIENT_ID/SECRET (which uses gmail.readonly scope).
|
||||
GOOGLE_AUTH_CLIENT_ID: str = ""
|
||||
GOOGLE_AUTH_CLIENT_SECRET: str = ""
|
||||
# The redirect URI registered in Google Cloud Console.
|
||||
# Google redirects here after consent; this backend route then bounces to
|
||||
# the adiuvai:// deep link so the Electron app receives the code.
|
||||
# Dev: http://localhost:8000/api/v1/auth/oauth/google/web-callback
|
||||
# Prod: https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback
|
||||
OAUTH_REDIRECT_URI: str = "http://localhost:8000/api/v1/auth/oauth/google/web-callback"
|
||||
|
||||
# Gmail Pub/Sub topic for push notifications.
|
||||
# Full resource name, e.g. "projects/my-project/topics/gmail-push".
|
||||
# Leave empty in dev — setup_watch will skip registration gracefully.
|
||||
GMAIL_PUBSUB_TOPIC: str = ""
|
||||
# OIDC token audience for Pub/Sub push subscription JWT verification.
|
||||
# Set to the service account email or audience string configured in the
|
||||
# Pub/Sub push subscription. Leave empty in dev to skip verification
|
||||
# (a warning is logged — never silent in production).
|
||||
GMAIL_PUBSUB_AUDIENCE: str = ""
|
||||
|
||||
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
||||
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
||||
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
||||
OAUTH_ENCRYPTION_KEY: str = ""
|
||||
|
||||
CORS_ORIGINS: list[str] = [
|
||||
"app://.",
|
||||
"http://localhost:3000",
|
||||
"http://localhost:5173",
|
||||
"http://localhost:4173", # Vite preview (web SPA)
|
||||
"https://app.adiuvai.com", # Production web portal
|
||||
]
|
||||
|
||||
LANGFUSE_SECRET_KEY: str = ""
|
||||
LANGFUSE_PUBLIC_KEY: str = ""
|
||||
LANGFUSE_BASE_URL: str = "https://cloud.langfuse.com"
|
||||
|
||||
SCHEDULER_ENABLED: bool = True
|
||||
|
||||
ENV: Literal["dev", "prod"] = "dev"
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
0
api/app/core/__init__.py
Normal file
0
api/app/core/__init__.py
Normal file
228
api/app/core/brief_agent.py
Normal file
228
api/app/core/brief_agent.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""Brief agent — produces plain-text home and project status briefs.
|
||||
|
||||
Read-only tool subset only. Never calls _normalize_tagged_list_lines —
|
||||
the brief prompt forbids XML tags, so skipping post-processing is intentional.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import date
|
||||
from typing import Any
|
||||
|
||||
from app.agents.note_agent import NOTE_READ_TOOLS
|
||||
from app.agents.project_agent import PROJECT_READ_TOOLS
|
||||
from app.agents.task_agent import TASK_READ_TOOLS
|
||||
from app.agents.timeline_agent import TIMELINE_READ_TOOLS
|
||||
from app.core.deep_agent import (
|
||||
_language_instruction,
|
||||
_proactive_hints_injection,
|
||||
_read_only_memory_tools,
|
||||
_relational_memory_injection,
|
||||
_run_single_agent_stream,
|
||||
_trace_id_from_context,
|
||||
build_brief_multi_project_manifest,
|
||||
)
|
||||
from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback
|
||||
|
||||
_LANGUAGE_NAMES: dict[str, str] = {
|
||||
"en": "English", "it": "Italian", "es": "Spanish",
|
||||
"fr": "French", "de": "German",
|
||||
"english": "English", "italian": "Italian", "italiano": "Italian",
|
||||
"spanish": "Spanish", "español": "Spanish",
|
||||
"french": "French", "français": "French",
|
||||
"german": "German", "deutsch": "German",
|
||||
}
|
||||
|
||||
_HOME_BRIEF_FALLBACK = """\
|
||||
You are the user's personal assistant producing a short daily brief.
|
||||
|
||||
ROLE
|
||||
Act like a calm, attentive secretary writing a stand-up note for your boss.
|
||||
Warm and human, never breezy. Never cheerful filler, never emojis, never
|
||||
"here is your brief" meta-text. The user is opening the app mid-workday and
|
||||
is probably stressed — your job is to lower cognitive load, not add noise.
|
||||
|
||||
TOOLS — always call before writing
|
||||
Pull fresh data every run. Do not invent counts or titles. Use at minimum:
|
||||
- list_tasks_due_today — tasks the user owes today
|
||||
- list_timelines_today — events starting or ending today
|
||||
- list_all_projects — projects currently in progress or at risk
|
||||
- memory_list_blocks / memory_get — personal context about people, clients,
|
||||
payment habits, working preferences
|
||||
If a tool returns nothing, simply omit that topic. Never report zeros.
|
||||
|
||||
WHAT TO INCLUDE
|
||||
1. Tasks due today (title + priority; group the 1-2 most important).
|
||||
2. Timeline events starting or ending today (and anything that starts/ends
|
||||
tomorrow if the user has a very light day).
|
||||
3. Active projects that need a nudge — stalled, blocked, or awaiting input.
|
||||
4. Memory-aware colour where it sharpens the brief. Examples:
|
||||
- "Client Rossi tends to pay late — the Acme invoice is 6 days out."
|
||||
- "You usually dislike meetings before 10:00 — the call at 09:30 is unusual."
|
||||
Only add a memory line when it changes what the user does. Do not pad.
|
||||
|
||||
WHAT TO OMIT
|
||||
- Zero-counts ("no overdue items", "0 meetings today").
|
||||
- Statistics ("2 active projects, 3 completed tasks").
|
||||
- Headers, titles, greetings, sign-offs, dates, emojis, slang.
|
||||
- Meta-phrases ("here is", "let me know if", "hope this helps").
|
||||
- XML/HTML tags of any kind. Plain prose only.
|
||||
|
||||
LIGHT-DAY CLAUSE
|
||||
If tasks + events + active-project-nudges together produce fewer than two
|
||||
sentences of content, also list 1-2 projects in status on_hold or waiting
|
||||
and ask a single, specific question about them — e.g. "Is the Bianchi
|
||||
redesign still paused, or ready to pick back up?" One question max, grounded
|
||||
in a real project name.
|
||||
|
||||
VOICE
|
||||
- Calm. Concise. Human. Short sentences.
|
||||
- Use **bold** sparingly for task titles, project names, and people's names.
|
||||
- No bullet lists. Flow as 2-4 sentences of prose.
|
||||
|
||||
LENGTH
|
||||
2-4 sentences total. Hard cap 4. If the day is truly empty, one sentence.
|
||||
|
||||
Respond in the user's language ({language}). Today is {today}.\
|
||||
"""
|
||||
|
||||
_PROJECT_BRIEF_FALLBACK = """\
|
||||
You are the project assistant producing a short status brief for ONE project.
|
||||
|
||||
ROLE
|
||||
A senior project manager summarising state-of-play for the owner. Factual,
|
||||
sharp, forward-looking. Never reassuring filler, never emojis.
|
||||
|
||||
SCOPE
|
||||
Work only with project_id = {project_id}. Do not mention or pull data from
|
||||
other projects. Use tools to fetch fresh data:
|
||||
- get_project — current status, dates, description
|
||||
- list_tasks(project_id) — open work, split by status
|
||||
- list_timelines(project_id) — milestones hit, upcoming, overdue
|
||||
- list_notes(project_id) — any recent decisions or blockers
|
||||
- memory_get — relevant context about the client, collaborators, constraints
|
||||
|
||||
STRUCTURE — follow exactly, one short paragraph per section, no headers
|
||||
1. **State.** One sentence: current phase, health (on track / at risk / blocked),
|
||||
and why. Cite the concrete signal (overdue milestone, stalled tasks, recent
|
||||
blocker note).
|
||||
2. **What's moving.** What was completed or progressed recently. Name specific
|
||||
tasks or milestones.
|
||||
3. **Next steps.** The 1-3 most important things the user should do next, in
|
||||
priority order. Be concrete — task name, who owns it, when due if known.
|
||||
If waiting on someone else, name them and what the ask is.
|
||||
4. **Risks / memory-flagged items.** One line max. Only include when there is
|
||||
a real risk or a relevant memory (e.g. late-paying client, tight deadline,
|
||||
scope change). Omit the section entirely if nothing to say.
|
||||
|
||||
WHAT TO OMIT
|
||||
- Zero-counts ("no overdue tasks").
|
||||
- Generic advice ("keep up the good work").
|
||||
- Greetings, headers, bullet lists, emojis, sign-offs, meta-phrases.
|
||||
- XML/HTML tags or bracketed id lists. Plain prose only.
|
||||
|
||||
VOICE
|
||||
- Direct. Factual. No fluff.
|
||||
- Use **bold** sparingly for task titles, milestone names, and the owner's name.
|
||||
- Short sentences. Prefer verbs over nouns ("Client review is blocking release"
|
||||
not "There is a blocker which is the client review").
|
||||
|
||||
LENGTH
|
||||
4-8 sentences total across the 3-4 sections. Hard cap 8.
|
||||
|
||||
Respond in the user's language ({language}). Today is {today}.\
|
||||
"""
|
||||
|
||||
|
||||
def _resolve_language(context: dict[str, Any]) -> str:
|
||||
core = context.get("core_memory") or {}
|
||||
raw = (core.get("language") or "en").strip().lower()
|
||||
return _LANGUAGE_NAMES.get(raw, raw.title()) or "English"
|
||||
|
||||
|
||||
def _build_read_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||
return [
|
||||
*TASK_READ_TOOLS,
|
||||
*PROJECT_READ_TOOLS,
|
||||
*TIMELINE_READ_TOOLS,
|
||||
*NOTE_READ_TOOLS,
|
||||
*_read_only_memory_tools(user_id, trace_id),
|
||||
]
|
||||
|
||||
|
||||
async def run_home_brief(
|
||||
user_id: str,
|
||||
context: dict[str, Any],
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
"""Stream a plain-text daily home brief.
|
||||
|
||||
Yields (event_type, data) tuples identical to _run_single_agent_stream.
|
||||
Do NOT post-process output through _normalize_tagged_list_lines.
|
||||
"""
|
||||
from app.agents.folder_agent import FOLDER_TOOLS
|
||||
|
||||
trace_id = _trace_id_from_context(context)
|
||||
today = date.today().isoformat()
|
||||
language = _resolve_language(context)
|
||||
|
||||
raw_template, langfuse_prompt = get_prompt_or_fallback("home_brief", _HOME_BRIEF_FALLBACK)
|
||||
system_prompt = compile_prompt(raw_template, langfuse_prompt, language=language, today=today)
|
||||
system_prompt += _relational_memory_injection(context)
|
||||
system_prompt += _proactive_hints_injection(context)
|
||||
system_prompt += _language_instruction(context)
|
||||
if today not in system_prompt:
|
||||
system_prompt += f"\nToday is {today}."
|
||||
|
||||
brief_manifest = await build_brief_multi_project_manifest()
|
||||
system_prompt = system_prompt + ("\n\n" + brief_manifest if brief_manifest else "")
|
||||
|
||||
tools = [*_build_read_tools(user_id, trace_id), *FOLDER_TOOLS]
|
||||
async for event in _run_single_agent_stream(
|
||||
user_id=user_id,
|
||||
system_prompt=system_prompt,
|
||||
message="Generate the daily brief.",
|
||||
context=context,
|
||||
langfuse_prompt=langfuse_prompt,
|
||||
agent_name="brief-agent",
|
||||
tools=tools,
|
||||
):
|
||||
yield event
|
||||
|
||||
|
||||
async def run_project_brief(
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
context: dict[str, Any],
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
"""Stream a plain-text project status brief for project_id.
|
||||
|
||||
Yields (event_type, data) tuples identical to _run_single_agent_stream.
|
||||
Do NOT post-process output through _normalize_tagged_list_lines.
|
||||
"""
|
||||
trace_id = _trace_id_from_context(context)
|
||||
today = date.today().isoformat()
|
||||
language = _resolve_language(context)
|
||||
|
||||
raw_template, langfuse_prompt = get_prompt_or_fallback("project_brief", _PROJECT_BRIEF_FALLBACK)
|
||||
system_prompt = compile_prompt(
|
||||
raw_template, langfuse_prompt,
|
||||
language=language, today=today, project_id=project_id,
|
||||
)
|
||||
system_prompt += _relational_memory_injection(context)
|
||||
system_prompt += _proactive_hints_injection(context)
|
||||
system_prompt += _language_instruction(context)
|
||||
if today not in system_prompt:
|
||||
system_prompt += f"\nToday is {today}."
|
||||
|
||||
tools = _build_read_tools(user_id, trace_id)
|
||||
async for event in _run_single_agent_stream(
|
||||
user_id=user_id,
|
||||
system_prompt=system_prompt,
|
||||
message=f"Generate the project status brief for project {project_id}.",
|
||||
context=context,
|
||||
langfuse_prompt=langfuse_prompt,
|
||||
agent_name="brief-agent",
|
||||
tools=tools,
|
||||
):
|
||||
yield event
|
||||
1329
api/app/core/deep_agent.py
Normal file
1329
api/app/core/deep_agent.py
Normal file
File diff suppressed because it is too large
Load Diff
151
api/app/core/device_manager.py
Normal file
151
api/app/core/device_manager.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Device connection manager.
|
||||
|
||||
Maintains in-memory state for all active Electron → backend WebSocket
|
||||
connections. One connection per user (latest replaces previous).
|
||||
|
||||
The manager handles the **tool-call round-trip** pattern:
|
||||
- Backend sends ``tool_call`` frame → Electron executes the action →
|
||||
returns ``tool_result`` frame.
|
||||
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
||||
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
||||
receive the result dict from Electron.
|
||||
|
||||
This pattern is used by all tools (CRUD, file-system, etc.) via
|
||||
``execute_on_client()`` in ``ws_context.py``.
|
||||
|
||||
The ``device_manager`` module-level singleton is imported by both the
|
||||
device WS route and the agent runner.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceConnection:
|
||||
"""State for a single connected Electron device."""
|
||||
|
||||
ws: WebSocket
|
||||
device_id: str
|
||||
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
||||
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||
|
||||
|
||||
class DeviceConnectionManager:
|
||||
"""Singleton registry of active Electron WebSocket connections.
|
||||
|
||||
Thread/task safety note: asyncio is single-threaded by design. All
|
||||
mutations happen inside await-points on the main event loop, so no
|
||||
locking is required for the in-memory dicts.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._connections: dict[str, DeviceConnection] = {}
|
||||
|
||||
# ── Registration ──────────────────────────────────────────────────
|
||||
|
||||
def register(self, user_id: str, device_id: str, ws: WebSocket) -> None:
|
||||
"""Store the active connection for *user_id*, replacing any previous one."""
|
||||
if user_id in self._connections:
|
||||
old = self._connections[user_id]
|
||||
logger.info(
|
||||
"device_manager: replacing existing connection for user=%s device=%s",
|
||||
user_id,
|
||||
old.device_id,
|
||||
)
|
||||
# Cancel any futures that were waiting on the old connection.
|
||||
for fut in old.pending_calls.values():
|
||||
if not fut.done():
|
||||
fut.cancel()
|
||||
self._connections[user_id] = DeviceConnection(ws=ws, device_id=device_id)
|
||||
logger.info(
|
||||
"device_manager: registered user=%s device=%s", user_id, device_id
|
||||
)
|
||||
|
||||
def unregister(self, user_id: str) -> None:
|
||||
"""Remove the connection for *user_id* and cancel any pending futures."""
|
||||
conn = self._connections.pop(user_id, None)
|
||||
if conn is None:
|
||||
return
|
||||
for fut in conn.pending_calls.values():
|
||||
if not fut.done():
|
||||
fut.cancel()
|
||||
logger.info("device_manager: unregistered user=%s", user_id)
|
||||
|
||||
# ── Presence queries ──────────────────────────────────────────────
|
||||
|
||||
def get_ws(self, user_id: str) -> WebSocket | None:
|
||||
"""Return the active WebSocket for *user_id*, or ``None`` if offline."""
|
||||
conn = self._connections.get(user_id)
|
||||
return conn.ws if conn else None
|
||||
|
||||
def is_online(self, user_id: str, device_id: str | None = None) -> bool:
|
||||
"""Return ``True`` if the user has an active connection.
|
||||
|
||||
If *device_id* is provided also checks that it matches the connected device.
|
||||
"""
|
||||
conn = self._connections.get(user_id)
|
||||
if conn is None:
|
||||
return False
|
||||
if device_id is not None:
|
||||
return conn.device_id == device_id
|
||||
return True
|
||||
|
||||
# ── Frame sending ─────────────────────────────────────────────────
|
||||
|
||||
async def send_frame(self, user_id: str, frame: dict) -> None:
|
||||
"""Send *frame* as a JSON text message to the device.
|
||||
|
||||
Raises ``RuntimeError`` if the user is not connected.
|
||||
"""
|
||||
conn = self._connections.get(user_id)
|
||||
if conn is None:
|
||||
raise RuntimeError(
|
||||
f"send_frame: user {user_id!r} is not connected"
|
||||
)
|
||||
await conn.ws.send_text(json.dumps(frame))
|
||||
|
||||
# ── Tool-call round-trip ──────────────────────────────────────────
|
||||
|
||||
def create_pending_call(
|
||||
self, user_id: str, call_id: str
|
||||
) -> asyncio.Future[dict]:
|
||||
"""Register a Future that will be resolved when the tool_result arrives.
|
||||
|
||||
Raises ``RuntimeError`` if the user is not connected.
|
||||
"""
|
||||
conn = self._connections.get(user_id)
|
||||
if conn is None:
|
||||
raise RuntimeError(
|
||||
f"create_pending_call: user {user_id!r} is not connected"
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
fut: asyncio.Future[dict] = loop.create_future()
|
||||
conn.pending_calls[call_id] = fut
|
||||
return fut
|
||||
|
||||
def resolve_pending_call(
|
||||
self, user_id: str, call_id: str, result: dict
|
||||
) -> None:
|
||||
"""Fulfil the Future registered under *call_id* with the Electron result.
|
||||
|
||||
No-ops if the call_id is unknown (already timed out or cancelled).
|
||||
"""
|
||||
conn = self._connections.get(user_id)
|
||||
if conn is None:
|
||||
return
|
||||
fut = conn.pending_calls.pop(call_id, None)
|
||||
if fut is not None and not fut.done():
|
||||
fut.set_result(result)
|
||||
|
||||
|
||||
# Module-level singleton — import this everywhere.
|
||||
device_manager = DeviceConnectionManager()
|
||||
34
api/app/core/embeddings.py
Normal file
34
api/app/core/embeddings.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""OpenAI embedding helper for associative memory tier.
|
||||
|
||||
Single public function: ``embed_text(text) -> list[float] | None``.
|
||||
Returns None on any failure — callers must implement a keyword fallback.
|
||||
Never raises; all exceptions are logged as warnings.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MAX_INPUT_CHARS = 8000
|
||||
_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
|
||||
|
||||
async def embed_text(text: str) -> list[float] | None:
|
||||
"""Call OpenAI text-embedding-3-small. Return None on failure (caller falls back to keyword)."""
|
||||
try:
|
||||
client = AsyncOpenAI()
|
||||
truncated = text[:_MAX_INPUT_CHARS]
|
||||
response = await client.embeddings.create(
|
||||
input=truncated,
|
||||
model=_EMBEDDING_MODEL,
|
||||
)
|
||||
result: list[float] = response.data[0].embedding
|
||||
logger.debug("embeddings: embed_text dims=%d", len(result))
|
||||
return result
|
||||
except Exception as exc:
|
||||
logger.warning("embeddings: embed_text failed: %s", exc)
|
||||
return None
|
||||
183
api/app/core/folder_indexer.py
Normal file
183
api/app/core/folder_indexer.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Per-file summarisation for project folder integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from pypdf import PdfReader
|
||||
from docx import Document as DocxDocument
|
||||
|
||||
from app.core.langfuse_client import (
|
||||
compile_prompt,
|
||||
extract_usage,
|
||||
get_langfuse,
|
||||
get_prompt_or_fallback,
|
||||
)
|
||||
from app.core.llm import get_llm
|
||||
|
||||
_TEXT_FALLBACK = (
|
||||
"You are summarising a file for an AI assistant that helps the user manage a project.\n"
|
||||
"Produce a single sentence (<=30 words, <=200 chars) that captures the file's purpose "
|
||||
"and most important detail.\nFile extension: {ext}\nFile name: {name}\nContent (truncated if long):\n{content}"
|
||||
)
|
||||
_IMAGE_FALLBACK = (
|
||||
"You are summarising an image attached to a project folder.\n"
|
||||
"Produce a single sentence (<=30 words, <=200 chars) describing what the image shows "
|
||||
"and any obvious purpose (logo, screenshot, diagram, photo of a whiteboard, etc.)."
|
||||
)
|
||||
_MAX_INPUT_CHARS = 6000
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexResult:
|
||||
summary: str
|
||||
tokens_used: int
|
||||
|
||||
|
||||
async def _llm_text(messages: list) -> object:
|
||||
"""Make the LLM call for text summarisation.
|
||||
|
||||
Defined as a standalone async function so tests can patch it cleanly
|
||||
without needing to mock the LLM object itself.
|
||||
"""
|
||||
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
|
||||
return await llm.ainvoke(messages)
|
||||
|
||||
|
||||
async def _llm_vision(messages: list) -> object:
|
||||
"""Make the LLM call for vision (image) summarisation.
|
||||
|
||||
Accepts the message list and returns the response directly, mirroring
|
||||
the ``_llm_text`` caller pattern so tests can patch it at the module level.
|
||||
"""
|
||||
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
|
||||
return await llm.ainvoke(messages)
|
||||
|
||||
|
||||
async def summarize_image(*, image_b64: str, mime: str, file_name: str | None = None) -> IndexResult:
|
||||
"""Return a compact summary of an image file using vision.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image_b64:
|
||||
Base64-encoded image bytes.
|
||||
mime:
|
||||
MIME type of the image, e.g. ``"image/png"``.
|
||||
file_name:
|
||||
Optional file name, attached to the Langfuse trace as input metadata.
|
||||
"""
|
||||
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_image", _IMAGE_FALLBACK)
|
||||
messages = [
|
||||
SystemMessage(content=template),
|
||||
HumanMessage(content=[
|
||||
{"type": "text", "text": "Summarise this image."},
|
||||
{"type": "image_url", "image_url": {"url": f"data:{mime};base64,{image_b64}"}},
|
||||
]),
|
||||
]
|
||||
lf = get_langfuse()
|
||||
if lf is not None:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="folder-summarize-image",
|
||||
model="gpt-4o-mini",
|
||||
prompt=prompt_obj,
|
||||
input={"file_name": file_name, "mime": mime},
|
||||
) as gen:
|
||||
response = await _llm_vision(messages)
|
||||
usage = extract_usage(response)
|
||||
gen.update(output=response.content, usage_details=usage)
|
||||
else:
|
||||
response = await _llm_vision(messages)
|
||||
usage = extract_usage(response)
|
||||
summary = (response.content or "").strip()[:500]
|
||||
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
|
||||
|
||||
|
||||
async def summarize_text(*, content: str, ext: str, name: str) -> IndexResult:
|
||||
"""Return a compact summary of a text file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
content:
|
||||
Raw text content of the file (will be truncated to _MAX_INPUT_CHARS).
|
||||
ext:
|
||||
File extension including the leading dot, e.g. ``".md"``.
|
||||
name:
|
||||
File name, e.g. ``"kickoff.md"``.
|
||||
"""
|
||||
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_text", _TEXT_FALLBACK)
|
||||
truncated = content[:_MAX_INPUT_CHARS]
|
||||
compiled = compile_prompt(template, prompt_obj, ext=ext, name=name, content=truncated)
|
||||
messages = [
|
||||
SystemMessage(content=compiled),
|
||||
HumanMessage(content="Summarise this file."),
|
||||
]
|
||||
lf = get_langfuse()
|
||||
if lf is not None:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="folder-summarize-text",
|
||||
model="gpt-4o-mini",
|
||||
prompt=prompt_obj,
|
||||
input={"file_name": name, "ext": ext, "content_chars": len(truncated)},
|
||||
) as gen:
|
||||
response = await _llm_text(messages)
|
||||
usage = extract_usage(response)
|
||||
gen.update(output=response.content, usage_details=usage)
|
||||
else:
|
||||
response = await _llm_text(messages)
|
||||
usage = extract_usage(response)
|
||||
summary = (response.content or "").strip()[:500]
|
||||
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
|
||||
|
||||
|
||||
def _extract_pdf_text(pdf_b64: str) -> str:
|
||||
buf = io.BytesIO(base64.b64decode(pdf_b64))
|
||||
reader = PdfReader(buf)
|
||||
parts: list[str] = []
|
||||
for page in reader.pages:
|
||||
try:
|
||||
parts.append(page.extract_text() or "")
|
||||
except Exception:
|
||||
continue
|
||||
return "\n".join(parts).strip()
|
||||
|
||||
|
||||
def _extract_docx_text(docx_b64: str) -> str:
|
||||
buf = io.BytesIO(base64.b64decode(docx_b64))
|
||||
doc = DocxDocument(buf)
|
||||
return "\n".join(p.text for p in doc.paragraphs if p.text).strip()
|
||||
|
||||
|
||||
async def summarize_pdf(*, pdf_b64: str, name: str) -> IndexResult:
|
||||
"""Return a compact summary of a PDF file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pdf_b64:
|
||||
Base64-encoded PDF bytes.
|
||||
name:
|
||||
File name, e.g. ``"report.pdf"``.
|
||||
"""
|
||||
text = _extract_pdf_text(pdf_b64)
|
||||
if not text:
|
||||
return IndexResult(summary="Could not extract text", tokens_used=0)
|
||||
return await summarize_text(content=text, ext=".pdf", name=name)
|
||||
|
||||
|
||||
async def summarize_docx(*, docx_b64: str, name: str) -> IndexResult:
|
||||
"""Return a compact summary of a DOCX file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
docx_b64:
|
||||
Base64-encoded DOCX bytes.
|
||||
name:
|
||||
File name, e.g. ``"spec.docx"``.
|
||||
"""
|
||||
text = _extract_docx_text(docx_b64)
|
||||
if not text:
|
||||
return IndexResult(summary="Could not extract text", tokens_used=0)
|
||||
return await summarize_text(content=text, ext=".docx", name=name)
|
||||
190
api/app/core/langfuse_client.py
Normal file
190
api/app/core/langfuse_client.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""Langfuse observability — singleton client and prompt helpers.
|
||||
|
||||
If LANGFUSE_SECRET_KEY / LANGFUSE_PUBLIC_KEY are not set,
|
||||
all helpers are no-ops so the app works without Langfuse configured.
|
||||
|
||||
Usage
|
||||
-----
|
||||
Tracing::
|
||||
|
||||
from app.core.langfuse_client import get_langfuse
|
||||
|
||||
lf = get_langfuse()
|
||||
if lf:
|
||||
with lf.start_as_current_observation(as_type="span", name="my-agent") as span:
|
||||
span.update(input=user_message)
|
||||
# ... do work ...
|
||||
span.update(output=result)
|
||||
lf.flush()
|
||||
|
||||
Prompt management::
|
||||
|
||||
from app.core.langfuse_client import get_prompt_or_fallback
|
||||
|
||||
text, prompt_obj = get_prompt_or_fallback("home_system", FALLBACK_PROMPT)
|
||||
# Use text as the system prompt; pass prompt_obj to generations for linking.
|
||||
|
||||
Linking a prompt to a generation::
|
||||
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="llm-call",
|
||||
model="gpt-4o",
|
||||
prompt=prompt_obj, # links generation → prompt version in the UI
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=_usage(response))
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Generator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_client: Any = None
|
||||
_initialized: bool = False
|
||||
|
||||
|
||||
def get_langfuse() -> Any | None:
|
||||
"""Return the Langfuse singleton, or ``None`` when not configured."""
|
||||
global _client, _initialized
|
||||
if _initialized:
|
||||
return _client
|
||||
_initialized = True
|
||||
|
||||
from app.config.settings import settings # local import to avoid circular deps
|
||||
|
||||
if not settings.LANGFUSE_SECRET_KEY or not settings.LANGFUSE_PUBLIC_KEY:
|
||||
logger.debug("langfuse: not configured — observability disabled")
|
||||
return None
|
||||
|
||||
try:
|
||||
from langfuse import Langfuse
|
||||
|
||||
_client = Langfuse(
|
||||
secret_key=settings.LANGFUSE_SECRET_KEY,
|
||||
public_key=settings.LANGFUSE_PUBLIC_KEY,
|
||||
host=settings.LANGFUSE_BASE_URL,
|
||||
)
|
||||
logger.info("langfuse: client initialized host=%s", settings.LANGFUSE_BASE_URL)
|
||||
except Exception as exc:
|
||||
logger.warning("langfuse: failed to initialize: %s", exc)
|
||||
_client = None
|
||||
|
||||
return _client
|
||||
|
||||
|
||||
def get_prompt_or_fallback(name: str, fallback: str) -> tuple[str, Any]:
|
||||
"""Fetch a text prompt from Langfuse; fall back to ``fallback`` on any error.
|
||||
|
||||
Returns ``(raw_template, prompt_obj_or_None)``.
|
||||
|
||||
* ``raw_template`` — the uncompiled template string. Do NOT call ``.format()``
|
||||
on it directly; use :func:`compile_prompt` instead so the correct variable
|
||||
syntax is applied (``{{var}}`` for Langfuse, ``{var}`` for the fallback).
|
||||
* ``prompt_obj`` — the Langfuse prompt object, or ``None`` when Langfuse is
|
||||
unavailable / the fetch failed. Pass this to generation observations so
|
||||
Langfuse links the generation to the exact prompt version in the UI.
|
||||
"""
|
||||
lf = get_langfuse()
|
||||
if lf is None:
|
||||
return fallback, None
|
||||
|
||||
try:
|
||||
prompt = lf.get_prompt(name, label="production", fallback=fallback)
|
||||
# For text-type prompts .prompt holds the raw template string.
|
||||
raw = prompt.prompt if hasattr(prompt, "prompt") and isinstance(prompt.prompt, str) else fallback
|
||||
return raw, prompt
|
||||
except Exception as exc:
|
||||
logger.warning("langfuse: get_prompt %r failed: %s — using fallback", name, exc)
|
||||
return fallback, None
|
||||
|
||||
|
||||
def compile_prompt(template: str, prompt_obj: Any, **variables: Any) -> str:
|
||||
"""Compile *template* with *variables*, choosing the right syntax.
|
||||
|
||||
* When *prompt_obj* is a real Langfuse prompt object, calls
|
||||
``prompt_obj.compile(**variables)`` which handles ``{{variable}}``
|
||||
substitution as defined in the Langfuse UI.
|
||||
* When *prompt_obj* is ``None`` (Langfuse unavailable or fetch failed),
|
||||
falls back to ``template.format(**variables)`` which handles the
|
||||
``{variable}`` syntax used in the hardcoded fallback strings.
|
||||
|
||||
This keeps callers oblivious to which syntax is in use.
|
||||
"""
|
||||
if prompt_obj is not None:
|
||||
try:
|
||||
compiled = prompt_obj.compile(**variables)
|
||||
# compile() returns a string for text prompts.
|
||||
if isinstance(compiled, str):
|
||||
return compiled
|
||||
# Chat prompts return a list of dicts — join text parts.
|
||||
if isinstance(compiled, list):
|
||||
return "\n".join(
|
||||
m.get("content", "") for m in compiled if isinstance(m, dict)
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"langfuse: compile failed for prompt %r: %s — falling back to .format()",
|
||||
getattr(prompt_obj, "name", "?"),
|
||||
exc,
|
||||
)
|
||||
return template.format(**variables)
|
||||
|
||||
|
||||
def extract_usage(response: Any) -> dict[str, int]:
|
||||
"""Extract token usage from a LangChain AI message into Langfuse format."""
|
||||
meta = getattr(response, "usage_metadata", None)
|
||||
if not meta:
|
||||
return {}
|
||||
return {
|
||||
"input": int(meta.get("input_tokens", 0)),
|
||||
"output": int(meta.get("output_tokens", 0)),
|
||||
"total": int(meta.get("total_tokens", 0)),
|
||||
}
|
||||
|
||||
|
||||
def hash_user_id(user_id: str) -> str:
|
||||
"""Return a SHA-256 hash of *user_id* for use as Langfuse ``user_id``.
|
||||
|
||||
This avoids sending raw database UUIDs to external observability services
|
||||
while still providing a stable, deterministic identifier for per-user
|
||||
metrics in the Langfuse dashboard.
|
||||
"""
|
||||
return hashlib.sha256(user_id.encode()).hexdigest()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def langfuse_context(
|
||||
user_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""Propagate ``user_id`` (hashed) and ``session_id`` to all Langfuse observations.
|
||||
|
||||
No-op when Langfuse is not configured or parameters are empty.
|
||||
"""
|
||||
lf = get_langfuse()
|
||||
if lf is None or (not user_id and not session_id):
|
||||
yield
|
||||
return
|
||||
|
||||
try:
|
||||
from langfuse import propagate_attributes
|
||||
except ImportError:
|
||||
logger.debug("langfuse: propagate_attributes not available — skipping context")
|
||||
yield
|
||||
return
|
||||
|
||||
attrs: dict[str, str] = {}
|
||||
if user_id:
|
||||
attrs["user_id"] = hash_user_id(user_id)
|
||||
if session_id:
|
||||
attrs["session_id"] = session_id
|
||||
|
||||
with propagate_attributes(**attrs):
|
||||
yield
|
||||
156
api/app/core/llm.py
Normal file
156
api/app/core/llm.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||
|
||||
Every agent and the orchestrator call ``get_llm()``
|
||||
instead of directly constructing a provider-specific class. The model string
|
||||
follows the `LiteLLM model naming convention
|
||||
<https://docs.litellm.ai/docs/providers>`_:
|
||||
|
||||
* OpenAI: ``gpt-4o``, ``gpt-4o-mini``
|
||||
* Anthropic: ``anthropic/claude-3.5-sonnet``
|
||||
* Google: ``gemini/gemini-pro``
|
||||
* Ollama: ``ollama/llama3``
|
||||
* Bedrock: ``bedrock/anthropic.claude-v2``
|
||||
|
||||
Switch providers by changing **LLM_MODEL** in ``.env``
|
||||
— no code changes required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
import litellm
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_litellm import ChatLiteLLM
|
||||
from litellm import get_supported_openai_params # noqa: F401 – validates install
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
# Some models (e.g. gpt-5, o-series) reject unsupported params like temperature.
|
||||
# Drop them silently instead of raising UnsupportedParamsError.
|
||||
litellm.drop_params = True
|
||||
|
||||
# Some provider responses include a plain dict in the `usage` field where a
|
||||
# richer Pydantic model is expected. This warning is noisy but non-fatal.
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
||||
category=UserWarning,
|
||||
)
|
||||
|
||||
|
||||
def _api_key_for_model(model: str) -> str | None:
|
||||
"""Return the most appropriate API key for the given LiteLLM model string."""
|
||||
if model.startswith("anthropic/"):
|
||||
return settings.ANTHROPIC_API_KEY or None
|
||||
if model.startswith("gemini/") or model.startswith("google/"):
|
||||
return settings.GOOGLE_API_KEY or None
|
||||
if model.startswith("cerebras/"):
|
||||
return settings.CEREBRAS_API_KEY or None
|
||||
if model.startswith("groq/"):
|
||||
return settings.GROQ_API_KEY or None
|
||||
if model.startswith("deepseek/"):
|
||||
return settings.DEEPSEEK_API_KEY or None
|
||||
if model.startswith("github_copilot/"):
|
||||
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
|
||||
# No API key is required; returning None lets LiteLLM handle auth.
|
||||
return None
|
||||
# Default: OpenAI-compatible (covers plain model names like "gpt-4o")
|
||||
return settings.OPENAI_API_KEY or None
|
||||
|
||||
|
||||
def get_llm(
|
||||
*,
|
||||
model: str | None = None,
|
||||
temperature: float = 0,
|
||||
) -> ChatOpenAI | ChatLiteLLM:
|
||||
"""Return a LangChain chat model backed by LiteLLM.
|
||||
|
||||
LiteLLM exposes an OpenAI-compatible API, so we use ``ChatOpenAI`` pointed
|
||||
at the LiteLLM proxy endpoint. In practice, ``litellm`` patches the
|
||||
``openai`` client transparently when the model string contains a provider
|
||||
prefix (``anthropic/…``, ``gemini/…``, etc.).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model:
|
||||
LiteLLM model identifier. Defaults to ``settings.LLM_MODEL``.
|
||||
temperature:
|
||||
Sampling temperature. ``0`` = deterministic.
|
||||
"""
|
||||
model = model or settings.LLM_MODEL
|
||||
|
||||
# Point LiteLLM to the custom token directory when configured.
|
||||
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
||||
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
||||
|
||||
# Use ChatLiteLLM for provider-prefixed models (github_copilot/, anthropic/, etc.)
|
||||
# so LiteLLM handles routing and auth. ChatOpenAI for plain OpenAI model names.
|
||||
if "/" in model:
|
||||
return ChatLiteLLM(model=model, temperature=temperature)
|
||||
|
||||
return ChatOpenAI(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
api_key=_api_key_for_model(model),
|
||||
)
|
||||
|
||||
|
||||
_AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
||||
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
|
||||
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
|
||||
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
|
||||
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
||||
"brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL,
|
||||
"task-brief-agent": lambda: settings.LLM_MODEL_TASK_BRIEF_AGENT or settings.LLM_MODEL,
|
||||
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
||||
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
|
||||
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
|
||||
"memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL,
|
||||
"note-summarizer": lambda: "gpt-4o-mini",
|
||||
}
|
||||
|
||||
|
||||
def model_for_agent(agent_name: str) -> str:
|
||||
"""Return the resolved model string for *agent_name* (for Langfuse tracking)."""
|
||||
return _AGENT_MODEL_SETTINGS.get(agent_name, lambda: settings.LLM_MODEL)()
|
||||
|
||||
|
||||
def get_agent_llm(
|
||||
agent_name: str,
|
||||
*,
|
||||
temperature: float = 0,
|
||||
) -> ChatOpenAI | ChatLiteLLM:
|
||||
"""Return an LLM configured for *agent_name*, respecting per-agent overrides.
|
||||
|
||||
Falls back to ``settings.LLM_MODEL`` for unknown agent names or when the
|
||||
per-agent override is left empty in ``.env``.
|
||||
"""
|
||||
model = model_for_agent(agent_name)
|
||||
return get_llm(model=model, temperature=temperature)
|
||||
|
||||
|
||||
async def embed(text: str) -> list[float]:
|
||||
"""Return an embedding vector for *text*.
|
||||
|
||||
Uses ``settings.LLM_EMBED_MODEL`` so the same provider switch in ``.env``
|
||||
(e.g. ``github_copilot/text-embedding-3-small``) applies here without any
|
||||
code changes. Falls back to the raw AsyncOpenAI client for plain OpenAI
|
||||
model names to preserve existing behaviour.
|
||||
"""
|
||||
model = settings.LLM_EMBED_MODEL
|
||||
|
||||
if model.startswith("github_copilot/") or "/" in model:
|
||||
# Use LiteLLM for all provider-prefixed models (Copilot, Bedrock, etc.)
|
||||
# so the provider's auth mechanism is applied correctly.
|
||||
response = await litellm.aembedding(model=model, input=[text])
|
||||
return response.data[0]["embedding"]
|
||||
|
||||
# Plain OpenAI model name — use the raw AsyncOpenAI client (existing path).
|
||||
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
||||
response = await client.embeddings.create(model=model, input=text)
|
||||
return response.data[0].embedding
|
||||
450
api/app/core/memory_extraction.py
Normal file
450
api/app/core/memory_extraction.py
Normal file
@@ -0,0 +1,450 @@
|
||||
"""Mem0-style Extract/Update pipeline — Phase 2.
|
||||
|
||||
Runs after every ``store_episode`` call to distil durable facts, preferences,
|
||||
routines, and relations from the latest conversation turn.
|
||||
|
||||
Entry point: ``run_extraction(db, user_id, last_user_msg, last_assistant_msg, session_id)``
|
||||
|
||||
Design notes
|
||||
------------
|
||||
- Two gpt-4o-mini calls per turn: extract candidates, then decide action per candidate.
|
||||
- Short-circuit: if no existing neighbours → ADD without a second LLM call (cost saving).
|
||||
- Zero-trust: never logs decrypted user content; relation subject/object labels are
|
||||
treated as identifiers (safe to log per spec).
|
||||
- Must not raise into the request path — caller wraps in asyncio.create_task().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.langfuse_client import get_langfuse, get_prompt_or_fallback, extract_usage, langfuse_context
|
||||
from app.core.llm import get_agent_llm, model_for_agent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Fallback prompts (used when Langfuse unavailable) ─────────────────────────
|
||||
|
||||
_EXTRACTION_FALLBACK = (
|
||||
"You are a memory extractor for a personal AI secretary. Given the last conversation "
|
||||
"turn, the user's core memory, and recent episode summaries, identify durable facts, "
|
||||
"preferences, routines, and person/project relations worth remembering.\n\n"
|
||||
"Output JSON matching this schema exactly:\n"
|
||||
'{{"candidates": [{{"type": "<fact|preference|relation|routine>", '
|
||||
'"content": "<short canonical statement>", '
|
||||
'"target_tier": "<core|associative|relational|proactive>", '
|
||||
'"subject": null, "predicate": null, "object": null, "confidence": 0.7}}]}}\n\n'
|
||||
"Rules:\n"
|
||||
"- Skip small talk, greetings, one-off questions.\n"
|
||||
"- Max 5 candidates per call.\n"
|
||||
"- Only extract durable information (still true next week).\n"
|
||||
"- For type=relation: subject/predicate/object required.\n"
|
||||
"- Default confidence=0.7.\n\n"
|
||||
"## Last turn\n{last_turn}\n\n"
|
||||
"## Core memory (current)\n{core_memory}\n\n"
|
||||
"## Recent episodes\n{recent_episodes}"
|
||||
)
|
||||
|
||||
_DECIDE_FALLBACK = (
|
||||
"You are a memory update decision engine. Given a new memory candidate and a list of "
|
||||
"existing memories from the same tier, decide what action to take.\n\n"
|
||||
"Respond with exactly one word: ADD, UPDATE, DELETE, or NOOP.\n\n"
|
||||
"- ADD: new information not in existing memories.\n"
|
||||
"- UPDATE: contradicts or supersedes an existing memory.\n"
|
||||
"- DELETE: states something is no longer true.\n"
|
||||
"- NOOP: already captured accurately.\n\n"
|
||||
"## New candidate\n{candidate}\n\n"
|
||||
"## Existing memories (same tier, top neighbours)\n{existing_memories}"
|
||||
)
|
||||
|
||||
|
||||
# ── Pydantic schemas ───────────────────────────────────────────────────────────
|
||||
|
||||
class MemoryCandidate(BaseModel):
|
||||
type: Literal["fact", "preference", "relation", "routine"]
|
||||
content: str
|
||||
target_tier: Literal["core", "associative", "relational", "proactive"]
|
||||
subject: str | None = None
|
||||
predicate: str | None = None
|
||||
object: str | None = None
|
||||
confidence: float = Field(default=0.7, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class ExtractionResult(BaseModel):
|
||||
candidates: list[MemoryCandidate] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ── Task 2.1 — Extract candidates ─────────────────────────────────────────────
|
||||
|
||||
async def extract_candidates(
|
||||
last_turn: str,
|
||||
core_memory: dict[str, str],
|
||||
recent_episodes: list[str],
|
||||
) -> ExtractionResult:
|
||||
"""Call gpt-4o-mini to extract memory candidates from the latest turn.
|
||||
|
||||
Returns an ExtractionResult (may be empty on failure — never raises).
|
||||
"""
|
||||
core_str = "\n".join(f"{k}: {v}" for k, v in core_memory.items()) or "(empty)"
|
||||
episodes_str = "\n---\n".join(recent_episodes[-5:]) or "(none)"
|
||||
|
||||
template, prompt_obj = get_prompt_or_fallback("memory_extraction", _EXTRACTION_FALLBACK)
|
||||
|
||||
# Compile with Langfuse variable syntax ({{var}}) or fallback {var}
|
||||
if prompt_obj is not None:
|
||||
try:
|
||||
system_text = prompt_obj.compile(
|
||||
last_turn=last_turn,
|
||||
core_memory=core_str,
|
||||
recent_episodes=episodes_str,
|
||||
)
|
||||
if isinstance(system_text, list):
|
||||
system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict))
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: compile failed: %s", exc)
|
||||
system_text = template.format(
|
||||
last_turn=last_turn,
|
||||
core_memory=core_str,
|
||||
recent_episodes=episodes_str,
|
||||
)
|
||||
else:
|
||||
system_text = template.format(
|
||||
last_turn=last_turn,
|
||||
core_memory=core_str,
|
||||
recent_episodes=episodes_str,
|
||||
)
|
||||
|
||||
llm = get_agent_llm("memory-extractor", temperature=0)
|
||||
# Bind JSON mode so the model always returns parseable output.
|
||||
llm_json = llm.bind(response_format={"type": "json_object"}) # type: ignore[attr-defined]
|
||||
|
||||
lf = get_langfuse()
|
||||
try:
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Extract memory candidates as JSON."),
|
||||
]
|
||||
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="memory-extraction",
|
||||
model=model_for_agent("memory-extractor"),
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm_json.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm_json.ainvoke(messages)
|
||||
|
||||
raw = json.loads(response.content)
|
||||
result = ExtractionResult.model_validate(raw)
|
||||
logger.info("memory_extraction: extracted %d candidates", len(result.candidates))
|
||||
return result
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: extract_candidates failed: %s", exc)
|
||||
return ExtractionResult(candidates=[])
|
||||
|
||||
|
||||
# ── Task 2.2 — Decide action ──────────────────────────────────────────────────
|
||||
|
||||
async def decide_action(
|
||||
candidate: MemoryCandidate,
|
||||
existing: list[str],
|
||||
) -> Literal["ADD", "UPDATE", "DELETE", "NOOP"]:
|
||||
"""Decide what to do with a candidate given existing memories in the same tier.
|
||||
|
||||
Short-circuits to ADD without an LLM call when existing is empty (cost saving).
|
||||
Never raises.
|
||||
"""
|
||||
if not existing:
|
||||
return "ADD"
|
||||
|
||||
candidate_str = f"[{candidate.type}] {candidate.content}"
|
||||
existing_str = "\n".join(f"- {m}" for m in existing)
|
||||
|
||||
template, prompt_obj = get_prompt_or_fallback("memory_decide_action", _DECIDE_FALLBACK)
|
||||
|
||||
if prompt_obj is not None:
|
||||
try:
|
||||
system_text = prompt_obj.compile(
|
||||
candidate=candidate_str,
|
||||
existing_memories=existing_str,
|
||||
)
|
||||
if isinstance(system_text, list):
|
||||
system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict))
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: decide compile failed: %s", exc)
|
||||
system_text = template.format(candidate=candidate_str, existing_memories=existing_str)
|
||||
else:
|
||||
system_text = template.format(candidate=candidate_str, existing_memories=existing_str)
|
||||
|
||||
llm = get_agent_llm("memory-extractor", temperature=0)
|
||||
lf = get_langfuse()
|
||||
|
||||
try:
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Decide action."),
|
||||
]
|
||||
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="memory-decide-action",
|
||||
model=model_for_agent("memory-extractor"),
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm.ainvoke(messages)
|
||||
|
||||
verb = response.content.strip().upper()
|
||||
if verb in ("ADD", "UPDATE", "DELETE", "NOOP"):
|
||||
return verb # type: ignore[return-value]
|
||||
logger.warning("memory_extraction: unexpected decide verb=%r, defaulting ADD", verb)
|
||||
return "ADD"
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: decide_action failed: %s", exc)
|
||||
return "ADD"
|
||||
|
||||
|
||||
# ── Task 2.3 — Pipeline orchestrator ──────────────────────────────────────────
|
||||
|
||||
async def run_extraction(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
last_user_msg: str,
|
||||
last_assistant_msg: str,
|
||||
session_id: str | None,
|
||||
) -> None:
|
||||
"""Full Mem0-style extract/update pipeline for one conversation turn.
|
||||
|
||||
Steps:
|
||||
1. Load core memory + last 5 episodes.
|
||||
2. extract_candidates() → up to 5 MemoryCandidate objects.
|
||||
3. For each candidate: find top-3 neighbours → decide_action() → apply.
|
||||
4. Trace via Langfuse.
|
||||
|
||||
Never raises — wraps everything in try/except.
|
||||
"""
|
||||
try:
|
||||
await _run_extraction_inner(db, user_id, last_user_msg, last_assistant_msg, session_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: run_extraction failed user=%s: %s", user_id, exc)
|
||||
|
||||
|
||||
async def _run_extraction_inner(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
last_user_msg: str,
|
||||
last_assistant_msg: str,
|
||||
session_id: str | None,
|
||||
) -> None:
|
||||
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
|
||||
|
||||
middleware = MemoryMiddleware(db)
|
||||
fernet = await middleware._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
logger.warning("memory_extraction: no fernet for user=%s, skipping", user_id)
|
||||
return
|
||||
|
||||
# 1. Load context
|
||||
core: dict[str, str] = await middleware._load_core(user_id, fernet)
|
||||
episodes: list[str] = await middleware._load_episodic(user_id, fernet, session_id=session_id)
|
||||
|
||||
last_turn = f"User: {last_user_msg}\nAssistant: {last_assistant_msg}"
|
||||
|
||||
lf = get_langfuse()
|
||||
|
||||
async def _run(trace_id: str | None) -> dict[str, Any]:
|
||||
# 2. Extract candidates
|
||||
result = await extract_candidates(last_turn, core, episodes)
|
||||
if not result.candidates:
|
||||
logger.info("memory_extraction: no candidates user=%s", user_id)
|
||||
return {"candidates": 0, "applied": 0}
|
||||
|
||||
logger.info(
|
||||
"memory_extraction: processing %d candidates user=%s trace=%s",
|
||||
len(result.candidates),
|
||||
user_id,
|
||||
trace_id or "-",
|
||||
)
|
||||
|
||||
# 3. Apply each candidate
|
||||
applied = 0
|
||||
actions: list[str] = []
|
||||
for candidate in result.candidates:
|
||||
try:
|
||||
await _apply_candidate(middleware, db, user_id, fernet, candidate, trace_id)
|
||||
applied += 1
|
||||
actions.append(f"{candidate.type}:{candidate.target_tier}")
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_extraction: apply failed candidate=%r user=%s: %s",
|
||||
candidate.content[:80],
|
||||
user_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"memory_extraction: applied %d/%d candidates user=%s",
|
||||
applied,
|
||||
len(result.candidates),
|
||||
user_id,
|
||||
)
|
||||
return {"candidates": len(result.candidates), "applied": applied, "actions": actions}
|
||||
|
||||
with langfuse_context(user_id=user_id, session_id=session_id):
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="span",
|
||||
name="memory-extraction-pipeline",
|
||||
input={"last_turn_preview": last_turn[:200]},
|
||||
) as span:
|
||||
summary = await _run(trace_id=span.id)
|
||||
span.update(output=summary)
|
||||
try:
|
||||
lf.flush()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
await _run(trace_id=None)
|
||||
|
||||
|
||||
async def _apply_candidate(
|
||||
middleware: Any,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
fernet: Any,
|
||||
candidate: MemoryCandidate,
|
||||
trace_id: str | None,
|
||||
) -> None:
|
||||
"""Fetch neighbours, decide action, apply to the appropriate tier."""
|
||||
|
||||
neighbours: list[str] = []
|
||||
|
||||
if candidate.target_tier == "core":
|
||||
# For core tier: neighbours are existing core block values for similar keys.
|
||||
blocks = await middleware.list_core_blocks(user_id)
|
||||
neighbours = [b["value"] for b in blocks[:3]]
|
||||
|
||||
elif candidate.target_tier == "associative":
|
||||
neighbours = await middleware.search_archival(user_id, candidate.content, top_k=3)
|
||||
|
||||
elif candidate.target_tier == "relational":
|
||||
# Relation candidates handled specially — passed to upsert_relation directly.
|
||||
# Neighbours: search by subject label if available.
|
||||
neighbours = []
|
||||
|
||||
elif candidate.target_tier == "proactive":
|
||||
neighbours = await middleware.search_recall(user_id, candidate.content, top_k=3)
|
||||
|
||||
action = await decide_action(candidate, neighbours)
|
||||
logger.info(
|
||||
"memory_extraction: candidate type=%s tier=%s action=%s",
|
||||
candidate.type,
|
||||
candidate.target_tier,
|
||||
action,
|
||||
)
|
||||
|
||||
if action == "NOOP":
|
||||
return
|
||||
|
||||
if candidate.target_tier == "relational":
|
||||
# Always upsert relations — decide_action skipped (no neighbour search).
|
||||
if candidate.subject and candidate.predicate and candidate.object:
|
||||
await _upsert_relation(
|
||||
middleware, db, user_id, candidate, trace_id
|
||||
)
|
||||
return
|
||||
|
||||
if action in ("ADD", "UPDATE"):
|
||||
if candidate.target_tier == "core":
|
||||
# Derive a short key from the content (first 40 chars, snake_cased).
|
||||
key = _content_to_key(candidate.content)
|
||||
await middleware.update_core(user_id, key, candidate.content, trace_id=trace_id)
|
||||
|
||||
elif candidate.target_tier == "associative":
|
||||
await middleware.store_associative(user_id, candidate.content)
|
||||
|
||||
elif candidate.target_tier == "proactive":
|
||||
await _store_proactive_stub(middleware, db, user_id, candidate, fernet)
|
||||
|
||||
elif action == "DELETE":
|
||||
if candidate.target_tier == "core":
|
||||
key = _content_to_key(candidate.content)
|
||||
await middleware.delete_core(user_id, key)
|
||||
|
||||
|
||||
def _content_to_key(content: str) -> str:
|
||||
"""Derive a short snake_case key from a content string (first 40 chars)."""
|
||||
import re # noqa: PLC0415
|
||||
slug = re.sub(r"[^a-z0-9]+", "_", content[:40].lower()).strip("_")
|
||||
return slug or "memory"
|
||||
|
||||
|
||||
async def _upsert_relation(
|
||||
middleware: Any,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
candidate: MemoryCandidate,
|
||||
trace_id: str | None,
|
||||
) -> None:
|
||||
"""Upsert a relation row via MemoryMiddleware.upsert_relation (Phase 3)."""
|
||||
await middleware.upsert_relation(
|
||||
user_id=user_id,
|
||||
subject=candidate.subject or "unknown",
|
||||
subject_type="unknown",
|
||||
predicate=candidate.predicate or "related_to",
|
||||
object_=candidate.object or "unknown",
|
||||
object_type="unknown",
|
||||
confidence=candidate.confidence,
|
||||
)
|
||||
logger.info(
|
||||
"memory_extraction: upserted relation subject=%s predicate=%s object=%s",
|
||||
candidate.subject,
|
||||
candidate.predicate,
|
||||
candidate.object,
|
||||
)
|
||||
|
||||
|
||||
async def _store_proactive_stub(
|
||||
middleware: Any,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
candidate: MemoryCandidate,
|
||||
fernet: Any,
|
||||
) -> None:
|
||||
"""Store a proactive pattern row directly (MemoryProactive model)."""
|
||||
import uuid # noqa: PLC0415
|
||||
from app.models import MemoryProactive # noqa: PLC0415
|
||||
from app.core.memory_middleware import _encrypt # noqa: PLC0415
|
||||
|
||||
encrypted = _encrypt(fernet, candidate.content)
|
||||
row = MemoryProactive(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
pattern_encrypted=encrypted,
|
||||
confidence=candidate.confidence,
|
||||
source="inferred",
|
||||
)
|
||||
db.add(row)
|
||||
try:
|
||||
await db.commit()
|
||||
logger.info("memory_extraction: stored proactive pattern user=%s", user_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: store proactive failed: %s", exc)
|
||||
await db.rollback()
|
||||
581
api/app/core/memory_maintenance.py
Normal file
581
api/app/core/memory_maintenance.py
Normal file
@@ -0,0 +1,581 @@
|
||||
"""Memory maintenance jobs — Phase 3/5.
|
||||
|
||||
Three entrypoints called by the scheduler (APScheduler) registered in app/main.py:
|
||||
|
||||
drain_extraction_queue(db) — Free-tier batch extraction (Phase 2/5).
|
||||
mine_proactive_patterns(db, user_id) — Power+ pattern mining (Phase 5).
|
||||
decay_relations(db, user_id) — confidence decay + pruning for memory_relations (Phase 3).
|
||||
|
||||
All are safe to call manually or from tests; they never raise.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
|
||||
from app.models import MemoryAssociative, MemoryEpisodic, MemoryProactive, MemoryRelation, User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Decay parameters for relations
|
||||
_DECAY_FACTOR = 0.95
|
||||
_DECAY_PERIOD_DAYS = 30
|
||||
_PRUNE_THRESHOLD = 0.2
|
||||
|
||||
# Proactive pattern decay: 10 % per 7 days since last sighting
|
||||
_PROACTIVE_DECAY_FACTOR = 0.9
|
||||
_PROACTIVE_DECAY_PERIOD_DAYS = 7
|
||||
_PROACTIVE_PRUNE_THRESHOLD = 0.2
|
||||
|
||||
# Mining: require at least this many episodes to attempt pattern extraction
|
||||
_MIN_EPISODES_FOR_MINING = 3
|
||||
_MINING_LOOKBACK_DAYS = 30
|
||||
|
||||
# Audit: caps to control token cost
|
||||
_AUDIT_MAX_FACTS = 50
|
||||
_AUDIT_MAX_LABELS = 100
|
||||
|
||||
|
||||
async def decay_relations(db: AsyncSession, user_id: str) -> None:
|
||||
"""Apply confidence decay to all relation rows for a user.
|
||||
|
||||
Decay rule: confidence *= 0.95 for every 30 days since last_confirmed_at.
|
||||
Rows whose confidence falls below 0.2 are deleted.
|
||||
|
||||
Never raises — wraps in try/except.
|
||||
"""
|
||||
try:
|
||||
await _decay_relations_inner(db, user_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: decay_relations failed user=%s: %s", user_id, exc)
|
||||
|
||||
|
||||
async def _decay_relations_inner(db: AsyncSession, user_id: str) -> None:
|
||||
result = await db.execute(
|
||||
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
now = datetime.now(timezone.utc)
|
||||
deleted = 0
|
||||
decayed = 0
|
||||
|
||||
for row in rows:
|
||||
reference = row.last_confirmed_at or row.created_at
|
||||
if reference is None:
|
||||
continue
|
||||
if reference.tzinfo is None:
|
||||
reference = reference.replace(tzinfo=timezone.utc)
|
||||
|
||||
days_elapsed = (now - reference).days
|
||||
if days_elapsed < _DECAY_PERIOD_DAYS:
|
||||
continue
|
||||
|
||||
periods = days_elapsed // _DECAY_PERIOD_DAYS
|
||||
new_confidence = row.confidence * (_DECAY_FACTOR ** periods)
|
||||
|
||||
if new_confidence < _PRUNE_THRESHOLD:
|
||||
await db.delete(row)
|
||||
deleted += 1
|
||||
logger.info(
|
||||
"memory_maintenance: pruned relation id=%s user=%s subject=%s predicate=%s "
|
||||
"confidence=%.3f (below threshold)",
|
||||
row.id, user_id, row.subject_label, row.predicate, new_confidence,
|
||||
)
|
||||
else:
|
||||
row.confidence = new_confidence
|
||||
decayed += 1
|
||||
|
||||
try:
|
||||
await db.commit()
|
||||
logger.info(
|
||||
"memory_maintenance: decay_relations user=%s decayed=%d deleted=%d",
|
||||
user_id, decayed, deleted,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: decay_relations commit failed user=%s: %s", user_id, exc)
|
||||
await db.rollback()
|
||||
|
||||
|
||||
async def drain_extraction_queue(db: AsyncSession) -> None:
|
||||
"""Process pending ExtractionQueue rows for Free-tier users.
|
||||
|
||||
Each row corresponds to a stored episode that should be fed through the
|
||||
Mem0-style extraction pipeline. Rows are deleted after successful processing.
|
||||
Never raises — wraps in try/except.
|
||||
"""
|
||||
try:
|
||||
await _drain_extraction_queue_inner(db)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: drain_extraction_queue failed: %s", exc)
|
||||
|
||||
|
||||
async def _drain_extraction_queue_inner(db: AsyncSession) -> None:
|
||||
from app.models import ExtractionQueue # noqa: PLC0415
|
||||
|
||||
result = await db.execute(select(ExtractionQueue))
|
||||
rows = result.scalars().all()
|
||||
|
||||
if not rows:
|
||||
logger.debug("memory_maintenance: drain_extraction_queue nothing to drain")
|
||||
return
|
||||
|
||||
logger.info("memory_maintenance: drain_extraction_queue pending=%d", len(rows))
|
||||
|
||||
from app.core.memory_extraction import run_extraction # noqa: PLC0415
|
||||
|
||||
processed = 0
|
||||
for row in rows:
|
||||
try:
|
||||
await run_extraction(
|
||||
db=db,
|
||||
user_id=row.user_id,
|
||||
last_user_msg="",
|
||||
last_assistant_msg="",
|
||||
session_id=None,
|
||||
)
|
||||
await db.delete(row)
|
||||
await db.commit()
|
||||
processed += 1
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: drain failed row=%s user=%s: %s",
|
||||
row.id, row.user_id, exc,
|
||||
)
|
||||
await db.rollback()
|
||||
|
||||
logger.info("memory_maintenance: drain_extraction_queue processed=%d/%d", processed, len(rows))
|
||||
|
||||
|
||||
async def mine_proactive_patterns(db: AsyncSession, user_id: str) -> None:
|
||||
"""Mine recurring behavioral patterns from last 30 days of episodes (Power+ only).
|
||||
|
||||
Steps:
|
||||
1. Gate on proactive_mining tier feature.
|
||||
2. Load + decrypt last 30 days of episodic summaries.
|
||||
3. Call gpt-4o-mini to identify recurring patterns.
|
||||
4. Encrypt and store each pattern in memory_proactive.
|
||||
5. Apply decay to existing proactive rows.
|
||||
|
||||
Never raises — wraps in try/except.
|
||||
"""
|
||||
try:
|
||||
await _mine_proactive_patterns_inner(db, user_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: mine_proactive_patterns failed user=%s: %s", user_id, exc)
|
||||
|
||||
|
||||
async def _mine_proactive_patterns_inner(db: AsyncSession, user_id: str) -> None:
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
|
||||
tier = await tier_manager.get_tier(user_id, db)
|
||||
if not tier_manager.check_feature(tier, "proactive_mining"):
|
||||
logger.debug("memory_maintenance: mine_proactive_patterns skipped (tier=%s)", tier)
|
||||
return
|
||||
|
||||
# Load user Fernet key
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not user.encryption_key:
|
||||
logger.warning("memory_maintenance: mine_proactive_patterns no encryption_key user=%s", user_id)
|
||||
return
|
||||
|
||||
fernet = Fernet(user.encryption_key.encode())
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=_MINING_LOOKBACK_DAYS)
|
||||
|
||||
episodes_result = await db.execute(
|
||||
select(MemoryEpisodic)
|
||||
.where(
|
||||
MemoryEpisodic.user_id == user_id,
|
||||
MemoryEpisodic.created_at >= cutoff,
|
||||
)
|
||||
.order_by(MemoryEpisodic.created_at.asc())
|
||||
)
|
||||
episode_rows = episodes_result.scalars().all()
|
||||
|
||||
if len(episode_rows) < _MIN_EPISODES_FOR_MINING:
|
||||
logger.info(
|
||||
"memory_maintenance: mine_proactive_patterns skipped user=%s episodes=%d (< %d)",
|
||||
user_id, len(episode_rows), _MIN_EPISODES_FOR_MINING,
|
||||
)
|
||||
return
|
||||
|
||||
summaries: list[str] = []
|
||||
for ep in episode_rows:
|
||||
try:
|
||||
plaintext = fernet.decrypt(ep.summary_encrypted.encode()).decode()
|
||||
summaries.append(plaintext)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not summaries:
|
||||
return
|
||||
|
||||
patterns = await _extract_proactive_patterns(summaries)
|
||||
if not patterns:
|
||||
logger.info("memory_maintenance: mine_proactive_patterns user=%s no patterns extracted", user_id)
|
||||
return
|
||||
|
||||
stored = 0
|
||||
for pattern_text in patterns:
|
||||
try:
|
||||
encrypted = fernet.encrypt(pattern_text.encode()).decode()
|
||||
row = MemoryProactive(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
pattern_encrypted=encrypted,
|
||||
confidence=0.7,
|
||||
source="inferred",
|
||||
)
|
||||
db.add(row)
|
||||
stored += 1
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: failed to store pattern user=%s: %s", user_id, exc)
|
||||
|
||||
try:
|
||||
await db.commit()
|
||||
logger.info(
|
||||
"memory_maintenance: mine_proactive_patterns user=%s stored=%d",
|
||||
user_id, stored,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: mine_proactive_patterns commit failed user=%s: %s", user_id, exc)
|
||||
await db.rollback()
|
||||
return
|
||||
|
||||
await _decay_proactive_patterns(db, user_id, fernet)
|
||||
|
||||
|
||||
async def _extract_proactive_patterns(summaries: list[str]) -> list[str]:
|
||||
"""Call memory-miner LLM to identify recurring behavioral/temporal patterns."""
|
||||
from app.core.llm import get_agent_llm # noqa: PLC0415
|
||||
|
||||
llm = get_agent_llm("memory-miner", temperature=0)
|
||||
combined = "\n---\n".join(summaries[-20:]) # cap at last 20 to control token usage
|
||||
prompt = (
|
||||
"You are analyzing conversation history for a personal AI secretary. "
|
||||
"Identify 3-5 recurring temporal or behavioral patterns (e.g. 'always works late on Thursdays', "
|
||||
"'prefers bullet-point summaries', 'frequently asks about Project Acme status'). "
|
||||
"Return each pattern as a plain, short English sentence on its own line. "
|
||||
"No numbering, no bullet points, no extra text.\n\n"
|
||||
f"Conversation history:\n{combined}"
|
||||
)
|
||||
try:
|
||||
response = await llm.ainvoke(prompt)
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
lines = [line.strip() for line in str(text).splitlines() if line.strip()]
|
||||
return lines[:5]
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: _extract_proactive_patterns LLM failed: %s", exc)
|
||||
return []
|
||||
|
||||
|
||||
async def _decay_proactive_patterns(db: AsyncSession, user_id: str, fernet: Fernet) -> None:
|
||||
"""Decay confidence of existing proactive patterns; prune below threshold."""
|
||||
result = await db.execute(
|
||||
select(MemoryProactive).where(MemoryProactive.user_id == user_id)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
now = datetime.now(timezone.utc)
|
||||
deleted = 0
|
||||
decayed = 0
|
||||
|
||||
for row in rows:
|
||||
reference = row.created_at
|
||||
if reference is None:
|
||||
continue
|
||||
if reference.tzinfo is None:
|
||||
reference = reference.replace(tzinfo=timezone.utc)
|
||||
|
||||
days_elapsed = (now - reference).days
|
||||
if days_elapsed < _PROACTIVE_DECAY_PERIOD_DAYS:
|
||||
continue
|
||||
|
||||
periods = days_elapsed // _PROACTIVE_DECAY_PERIOD_DAYS
|
||||
new_confidence = row.confidence * (_PROACTIVE_DECAY_FACTOR ** periods)
|
||||
|
||||
if new_confidence < _PROACTIVE_PRUNE_THRESHOLD:
|
||||
await db.delete(row)
|
||||
deleted += 1
|
||||
else:
|
||||
row.confidence = new_confidence
|
||||
decayed += 1
|
||||
|
||||
try:
|
||||
await db.commit()
|
||||
logger.info(
|
||||
"memory_maintenance: decay_proactive user=%s decayed=%d deleted=%d",
|
||||
user_id, decayed, deleted,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: decay_proactive commit failed user=%s: %s", user_id, exc)
|
||||
await db.rollback()
|
||||
|
||||
|
||||
# ── Phase 7: weekly memory audit ──────────────────────────────────────────────
|
||||
|
||||
_AUDIT_CONTRADICTIONS_FALLBACK = (
|
||||
"You are auditing a personal AI assistant's memory bank. "
|
||||
"Each fact has an ID in brackets. "
|
||||
"Find pairs that directly contradict each other "
|
||||
"(e.g. 'prefers morning meetings' vs 'never schedules before noon'). "
|
||||
"For each contradiction, pick the ID to DELETE (the older or less specific one). "
|
||||
'Return ONLY a valid JSON array, no markdown fences: '
|
||||
'[{{"delete": "<id>", "reason": "<one line>"}}]. '
|
||||
"If no contradictions, return [].\n\n"
|
||||
"Facts:\n{facts}"
|
||||
)
|
||||
|
||||
_AUDIT_CANONICALIZE_FALLBACK = (
|
||||
"You are auditing entity labels in a personal AI assistant's relational memory. "
|
||||
"These are names of people, companies, projects, or topics. "
|
||||
"Group labels that clearly refer to the same real-world entity "
|
||||
"(e.g. 'giulia', 'Giulia', 'Giulia R.' → canonical 'Giulia'). "
|
||||
"Return ONLY a valid JSON array, no markdown fences: "
|
||||
'[{{"canonical": "<best label>", "variants": ["<v1>", "<v2>"]}}]. '
|
||||
"Only include groups with at least one variant. Singletons: omit.\n\n"
|
||||
"Labels:\n{labels}"
|
||||
)
|
||||
|
||||
|
||||
async def audit_memory(db: AsyncSession, user_id: str) -> None:
|
||||
"""Weekly audit: contradiction scan on associative facts + label canonicalization on relations.
|
||||
|
||||
Steps:
|
||||
1. Decrypt up to _AUDIT_MAX_FACTS associative rows; send list to memory-auditor LLM.
|
||||
2. LLM flags rows to delete (direct contradictions); hard-delete them.
|
||||
3. Collect unique subject/object labels from memory_relations; ask LLM to group duplicates.
|
||||
4. Rewrite variant labels to their canonical form in-place.
|
||||
|
||||
Never raises — wraps in try/except.
|
||||
"""
|
||||
try:
|
||||
await _audit_memory_inner(db, user_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: audit_memory failed user=%s: %s", user_id, exc)
|
||||
|
||||
|
||||
async def _audit_memory_inner(db: AsyncSession, user_id: str) -> None:
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not user.encryption_key:
|
||||
logger.warning("memory_maintenance: audit_memory no encryption_key user=%s", user_id)
|
||||
return
|
||||
|
||||
fernet = Fernet(user.encryption_key.encode())
|
||||
await _scan_associative_contradictions(db, user_id, fernet)
|
||||
await _canonicalize_relation_labels(db, user_id)
|
||||
|
||||
|
||||
async def _scan_associative_contradictions(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
fernet: Fernet,
|
||||
) -> None:
|
||||
"""Decrypt associative facts, ask LLM to flag contradictions, delete superseded rows."""
|
||||
result = await db.execute(
|
||||
select(MemoryAssociative)
|
||||
.where(MemoryAssociative.user_id == user_id)
|
||||
.order_by(MemoryAssociative.updated_at.desc())
|
||||
.limit(_AUDIT_MAX_FACTS)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
if len(rows) < 2:
|
||||
return
|
||||
|
||||
id_to_text: dict[str, str] = {}
|
||||
for row in rows:
|
||||
try:
|
||||
plaintext = fernet.decrypt(row.content_encrypted.encode()).decode()
|
||||
id_to_text[row.id] = plaintext
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if len(id_to_text) < 2:
|
||||
return
|
||||
|
||||
id_list = list(id_to_text.keys())
|
||||
numbered = "\n".join(
|
||||
f"{i + 1}. [{rid}] {id_to_text[rid]}" for i, rid in enumerate(id_list)
|
||||
)
|
||||
|
||||
template, prompt_obj = get_prompt_or_fallback(
|
||||
"memory_audit_contradictions", _AUDIT_CONTRADICTIONS_FALLBACK
|
||||
)
|
||||
system_text = compile_prompt(template, prompt_obj, facts=numbered)
|
||||
|
||||
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
|
||||
llm = get_agent_llm("memory-auditor", temperature=0)
|
||||
lf = get_langfuse()
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Audit facts for contradictions."),
|
||||
]
|
||||
try:
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="memory-audit-contradictions",
|
||||
model=model_for_agent("memory-auditor"),
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm.ainvoke(messages)
|
||||
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
deletions = json.loads(text.strip())
|
||||
if not isinstance(deletions, list):
|
||||
return
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: _scan_associative_contradictions LLM/parse failed user=%s: %s",
|
||||
user_id, exc,
|
||||
)
|
||||
return
|
||||
|
||||
deleted = 0
|
||||
for item in deletions:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
rid = item.get("delete")
|
||||
if not rid or rid not in id_to_text:
|
||||
continue
|
||||
result2 = await db.execute(
|
||||
select(MemoryAssociative).where(
|
||||
MemoryAssociative.id == rid,
|
||||
MemoryAssociative.user_id == user_id,
|
||||
)
|
||||
)
|
||||
target = result2.scalar_one_or_none()
|
||||
if target:
|
||||
await db.delete(target)
|
||||
deleted += 1
|
||||
logger.info(
|
||||
"memory_maintenance: audit deleted contradiction id=%s user=%s reason=%s",
|
||||
rid, user_id, item.get("reason", ""),
|
||||
)
|
||||
|
||||
if deleted:
|
||||
try:
|
||||
await db.commit()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: audit contradiction commit failed user=%s: %s", user_id, exc
|
||||
)
|
||||
await db.rollback()
|
||||
|
||||
logger.info(
|
||||
"memory_maintenance: _scan_associative_contradictions user=%s deleted=%d", user_id, deleted
|
||||
)
|
||||
|
||||
|
||||
async def _canonicalize_relation_labels(db: AsyncSession, user_id: str) -> None:
|
||||
"""Group near-duplicate entity labels in memory_relations and unify to canonical form."""
|
||||
result = await db.execute(
|
||||
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
if not rows:
|
||||
return
|
||||
|
||||
all_labels: set[str] = set()
|
||||
for row in rows:
|
||||
all_labels.add(row.subject_label)
|
||||
all_labels.add(row.object_label)
|
||||
|
||||
labels_list = sorted(all_labels)[:_AUDIT_MAX_LABELS]
|
||||
if len(labels_list) < 2:
|
||||
return
|
||||
|
||||
labels_block = "\n".join(f"- {lbl}" for lbl in labels_list)
|
||||
template, prompt_obj = get_prompt_or_fallback(
|
||||
"memory_audit_canonicalize", _AUDIT_CANONICALIZE_FALLBACK
|
||||
)
|
||||
system_text = compile_prompt(template, prompt_obj, labels=labels_block)
|
||||
|
||||
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
|
||||
llm = get_agent_llm("memory-auditor", temperature=0)
|
||||
lf = get_langfuse()
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Canonicalize entity labels."),
|
||||
]
|
||||
try:
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="memory-audit-canonicalize",
|
||||
model=model_for_agent("memory-auditor"),
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm.ainvoke(messages)
|
||||
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
groups = json.loads(text.strip())
|
||||
if not isinstance(groups, list):
|
||||
return
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: _canonicalize_relation_labels LLM/parse failed user=%s: %s",
|
||||
user_id, exc,
|
||||
)
|
||||
return
|
||||
|
||||
# Build variant → canonical map
|
||||
remap: dict[str, str] = {}
|
||||
for group in groups:
|
||||
if not isinstance(group, dict):
|
||||
continue
|
||||
canonical = group.get("canonical", "")
|
||||
variants = group.get("variants") or []
|
||||
if not canonical:
|
||||
continue
|
||||
for v in variants:
|
||||
if isinstance(v, str) and v != canonical:
|
||||
remap[v] = canonical
|
||||
|
||||
if not remap:
|
||||
return
|
||||
|
||||
updated = 0
|
||||
for row in rows:
|
||||
changed = False
|
||||
if row.subject_label in remap:
|
||||
row.subject_label = remap[row.subject_label]
|
||||
changed = True
|
||||
if row.object_label in remap:
|
||||
row.object_label = remap[row.object_label]
|
||||
changed = True
|
||||
if changed:
|
||||
updated += 1
|
||||
|
||||
if updated:
|
||||
try:
|
||||
await db.commit()
|
||||
logger.info(
|
||||
"memory_maintenance: _canonicalize_relation_labels user=%s updated=%d",
|
||||
user_id, updated,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: canonicalize commit failed user=%s: %s", user_id, exc
|
||||
)
|
||||
await db.rollback()
|
||||
733
api/app/core/memory_middleware.py
Normal file
733
api/app/core/memory_middleware.py
Normal file
@@ -0,0 +1,733 @@
|
||||
"""Memory Middleware — enrich requests with memory context and store interactions.
|
||||
|
||||
Four-tier memory model (MemGPT-style):
|
||||
core — persistent key/value user preferences, always injected
|
||||
associative — semantic similarity search via pgvector (top-k)
|
||||
episodic — recent session summaries (last N)
|
||||
proactive — behavioral patterns above confidence threshold
|
||||
|
||||
All memory content is encrypted at rest using the per-user Fernet key
|
||||
stored in User.encryption_key. Decryption happens in-memory only.
|
||||
|
||||
Usage:
|
||||
memory = MemoryMiddleware(db_session)
|
||||
context = await memory.enrich_context(user_id, message)
|
||||
# ... run agent ...
|
||||
await memory.store_episode(user_id, session_id, message, response)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import (
|
||||
ExtractionQueue,
|
||||
MemoryAssociative,
|
||||
MemoryCore,
|
||||
MemoryEpisodic,
|
||||
MemoryProactive,
|
||||
MemoryRelation,
|
||||
User,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
# Tuning constants
|
||||
_ASSOCIATIVE_TOP_K = 5
|
||||
_EPISODIC_RECENT_N = 10
|
||||
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
||||
|
||||
|
||||
class MemoryMiddleware:
|
||||
"""Enrich orchestrator context with memory and persist interactions after."""
|
||||
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self._db = db
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────
|
||||
|
||||
async def enrich_context(
|
||||
self,
|
||||
user_id: str,
|
||||
message: str,
|
||||
trace_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build memory context dict to inject into the orchestrator before LLM call.
|
||||
|
||||
Returns a dict with keys:
|
||||
core_memory — {key: plaintext_value, ...}
|
||||
associative_memory — [plaintext_content, ...] (top-k by keyword match)
|
||||
episodic_memory — [plaintext_summary, ...] (most recent N)
|
||||
proactive_hints — [plaintext_pattern, ...] (above threshold)
|
||||
relational_memory — ["subject --predicate--> object", ...] (top 10, Pro+)
|
||||
"""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return {}
|
||||
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
user_tier: str = user_dbg.get("tier") or "free"
|
||||
|
||||
core = await self._load_core(user_id, fernet)
|
||||
associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier)
|
||||
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
||||
proactive = await self._load_proactive(user_id, fernet)
|
||||
relational = await self._load_relational(user_id, user_tier=user_tier)
|
||||
|
||||
logger.info(
|
||||
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d relational=%d",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
user_tier,
|
||||
len(core),
|
||||
len(associative),
|
||||
len(episodic),
|
||||
len(proactive),
|
||||
len(relational),
|
||||
)
|
||||
|
||||
return {
|
||||
"core_memory": core,
|
||||
"associative_memory": associative,
|
||||
"episodic_memory": episodic,
|
||||
"proactive_hints": proactive,
|
||||
"relational_memory": relational,
|
||||
}
|
||||
|
||||
async def store_episode(
|
||||
self,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
message: str,
|
||||
response: str,
|
||||
trace_id: str | None = None,
|
||||
) -> None:
|
||||
"""Summarise and store a completed interaction in episodic memory.
|
||||
|
||||
The summary is a simple heuristic concatenation (no LLM call) to keep
|
||||
latency low. After committing the episode row, dispatches the Mem0-style
|
||||
extraction pipeline:
|
||||
- Pro/Power/Team → asyncio.create_task (fire-and-forget, realtime).
|
||||
- Free → enqueue an ExtractionQueue row for the daily cron.
|
||||
"""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return
|
||||
|
||||
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
||||
encrypted = _encrypt(fernet, summary)
|
||||
|
||||
episode = MemoryEpisodic(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
summary_encrypted=encrypted,
|
||||
session_id=session_id,
|
||||
)
|
||||
self._db.add(episode)
|
||||
episode_id: str = episode.id
|
||||
try:
|
||||
await self._db.commit()
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
tier = user_dbg.get("tier") or "free"
|
||||
logger.info(
|
||||
"memory: store_episode trace=%s user=%s tier=%s session=%s",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
tier,
|
||||
session_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||
await self._db.rollback()
|
||||
return
|
||||
|
||||
# ── Dispatch extraction pipeline (Phase 2) ────────────────────────────
|
||||
await self._dispatch_extraction(
|
||||
user_id=user_id,
|
||||
episode_id=episode_id,
|
||||
last_user_msg=message,
|
||||
last_assistant_msg=response,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _dispatch_extraction(
|
||||
self,
|
||||
user_id: str,
|
||||
episode_id: str,
|
||||
last_user_msg: str,
|
||||
last_assistant_msg: str,
|
||||
session_id: str | None,
|
||||
) -> None:
|
||||
"""Route extraction to realtime task or batch queue based on user tier."""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
|
||||
tier = await tier_manager.get_tier(user_id, self._db)
|
||||
|
||||
if tier_manager.check_feature(tier, "realtime_extraction"):
|
||||
# Pro/Power/Team: fire-and-forget in the background.
|
||||
# Must open a fresh session — request session closes after handler returns.
|
||||
from app.core.memory_extraction import run_extraction # noqa: PLC0415
|
||||
from app.db import async_session # noqa: PLC0415
|
||||
|
||||
async def _task() -> None:
|
||||
try:
|
||||
async with async_session() as fresh_db:
|
||||
await run_extraction(
|
||||
db=fresh_db,
|
||||
user_id=user_id,
|
||||
last_user_msg=last_user_msg,
|
||||
last_assistant_msg=last_assistant_msg,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory: extraction task failed user=%s: %s", user_id, exc
|
||||
)
|
||||
|
||||
asyncio.create_task(_task())
|
||||
logger.info("memory: realtime extraction dispatched user=%s", user_id)
|
||||
else:
|
||||
# Free tier: enqueue for daily batch cron.
|
||||
queue_row = ExtractionQueue(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
episode_id=episode_id,
|
||||
)
|
||||
self._db.add(queue_row)
|
||||
try:
|
||||
await self._db.commit()
|
||||
logger.info(
|
||||
"memory: extraction enqueued (batch) user=%s episode=%s",
|
||||
user_id,
|
||||
episode_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory: extraction queue insert failed user=%s: %s", user_id, exc
|
||||
)
|
||||
await self._db.rollback()
|
||||
|
||||
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
||||
"""Upsert a core memory key/value for a user."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return
|
||||
|
||||
encrypted = _encrypt(fernet, value)
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryCore).where(
|
||||
MemoryCore.user_id == user_id,
|
||||
MemoryCore.key == key,
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
if existing is not None:
|
||||
existing.value_encrypted = encrypted
|
||||
else:
|
||||
self._db.add(MemoryCore(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
key=key,
|
||||
value_encrypted=encrypted,
|
||||
))
|
||||
try:
|
||||
await self._db.commit()
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
logger.info(
|
||||
"memory: update_core trace=%s user=%s tier=%s key=%s",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
user_dbg.get("tier") or "-",
|
||||
key,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
||||
await self._db.rollback()
|
||||
|
||||
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
|
||||
"""Return core memory as editable blocks (label/value)."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return []
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryCore)
|
||||
.where(MemoryCore.user_id == user_id)
|
||||
.order_by(MemoryCore.key.asc())
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: list[dict[str, str]] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||
if plaintext is not None:
|
||||
out.append({"label": row.key, "value": plaintext})
|
||||
logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
|
||||
return out
|
||||
|
||||
async def get_core_block(self, user_id: str, label: str) -> str | None:
|
||||
"""Return a single core memory block value by label."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return None
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryCore).where(
|
||||
MemoryCore.user_id == user_id,
|
||||
MemoryCore.key == label,
|
||||
)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
|
||||
return None
|
||||
value = _safe_decrypt(fernet, row.value_encrypted)
|
||||
logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0)
|
||||
return value
|
||||
|
||||
async def delete_core(self, user_id: str, label: str) -> bool:
|
||||
"""Delete a core memory block by label. Returns True if deleted."""
|
||||
result = await self._db.execute(
|
||||
select(MemoryCore).where(
|
||||
MemoryCore.user_id == user_id,
|
||||
MemoryCore.key == label,
|
||||
)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
|
||||
return False
|
||||
|
||||
await self._db.delete(row)
|
||||
try:
|
||||
await self._db.commit()
|
||||
logger.info("memory: delete_core user=%s label=%s", user_id, label)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
|
||||
await self._db.rollback()
|
||||
return False
|
||||
|
||||
async def append_core(self, user_id: str, label: str, content: str) -> None:
|
||||
"""Append content to a core block, creating it if missing."""
|
||||
current = await self.get_core_block(user_id, label)
|
||||
if current is None:
|
||||
await self.update_core(user_id, label, content)
|
||||
logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
|
||||
return
|
||||
await self.update_core(user_id, label, f"{current}\n{content}")
|
||||
logger.info("memory: append_core user=%s label=%s created=0", user_id, label)
|
||||
|
||||
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
|
||||
"""Replace one exact string inside a core block. Returns False if not found."""
|
||||
current = await self.get_core_block(user_id, label)
|
||||
if current is None or old not in current:
|
||||
logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
|
||||
return False
|
||||
await self.update_core(user_id, label, current.replace(old, new, 1))
|
||||
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
||||
return True
|
||||
|
||||
async def store_associative(
|
||||
self,
|
||||
user_id: str,
|
||||
content: str,
|
||||
entity_type: str | None = None,
|
||||
entity_id: str | None = None,
|
||||
) -> None:
|
||||
"""Store associative memory; embed if user tier has real_embeddings."""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
from app.core.embeddings import embed_text # noqa: PLC0415
|
||||
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return
|
||||
|
||||
encrypted = _encrypt(fernet, content)
|
||||
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
user_tier = user_dbg.get("tier") or "free"
|
||||
|
||||
embedding: list[float] | None = None
|
||||
if tier_manager.check_feature(user_tier, "real_embeddings"):
|
||||
embedding = await embed_text(content)
|
||||
|
||||
row = MemoryAssociative(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
content_encrypted=encrypted,
|
||||
embedding=embedding,
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
self._db.add(row)
|
||||
try:
|
||||
await self._db.commit()
|
||||
logger.info(
|
||||
"memory: store_associative user=%s embedded=%s",
|
||||
user_id,
|
||||
embedding is not None,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("memory: store_associative failed user=%s: %s", user_id, exc)
|
||||
await self._db.rollback()
|
||||
|
||||
async def upsert_relation(
|
||||
self,
|
||||
user_id: str,
|
||||
subject: str,
|
||||
subject_type: str,
|
||||
predicate: str,
|
||||
object_: str,
|
||||
object_type: str,
|
||||
*,
|
||||
confidence: float = 0.7,
|
||||
source_episode_id: str | None = None,
|
||||
notes: str | None = None,
|
||||
) -> None:
|
||||
"""Insert or update a relation row. Matches on (user_id, subject_label, predicate, object_label).
|
||||
|
||||
subject_label / object_label are plaintext entity identifiers — not encrypted.
|
||||
notes is optional; encrypted with user Fernet if provided.
|
||||
"""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
user_tier = user_dbg.get("tier") or "free"
|
||||
if not tier_manager.check_feature(user_tier, "relational_memory"):
|
||||
logger.debug("memory: upsert_relation skipped (tier=%s no relational_memory)", user_tier)
|
||||
return
|
||||
|
||||
notes_encrypted: bytes | None = None
|
||||
if notes:
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet:
|
||||
notes_encrypted = fernet.encrypt(notes.encode())
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryRelation).where(
|
||||
MemoryRelation.user_id == user_id,
|
||||
MemoryRelation.subject_label == subject,
|
||||
MemoryRelation.predicate == predicate,
|
||||
MemoryRelation.object_label == object_,
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing is not None:
|
||||
existing.subject_type = subject_type
|
||||
existing.object_type = object_type
|
||||
existing.confidence = confidence
|
||||
existing.last_confirmed_at = _now()
|
||||
if notes_encrypted is not None:
|
||||
existing.notes_encrypted = notes_encrypted
|
||||
else:
|
||||
self._db.add(MemoryRelation(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
subject_label=subject,
|
||||
subject_type=subject_type,
|
||||
predicate=predicate,
|
||||
object_label=object_,
|
||||
object_type=object_type,
|
||||
confidence=confidence,
|
||||
source_episode_id=source_episode_id,
|
||||
notes_encrypted=notes_encrypted,
|
||||
))
|
||||
|
||||
try:
|
||||
await self._db.commit()
|
||||
logger.info(
|
||||
"memory: upsert_relation user=%s subject=%s predicate=%s object=%s",
|
||||
user_id, subject, predicate, object_,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("memory: upsert_relation failed user=%s: %s", user_id, exc)
|
||||
await self._db.rollback()
|
||||
|
||||
async def query_relations(
|
||||
self,
|
||||
user_id: str,
|
||||
subject: str | None = None,
|
||||
predicate: str | None = None,
|
||||
object_: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> list[MemoryRelation]:
|
||||
"""Query relation rows for a user with optional filters."""
|
||||
q = select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
||||
if subject is not None:
|
||||
q = q.where(MemoryRelation.subject_label == subject)
|
||||
if predicate is not None:
|
||||
q = q.where(MemoryRelation.predicate == predicate)
|
||||
if object_ is not None:
|
||||
q = q.where(MemoryRelation.object_label == object_)
|
||||
q = q.order_by(MemoryRelation.confidence.desc()).limit(limit)
|
||||
result = await self._db.execute(q)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
||||
"""Insert a long-term archival memory entry."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return
|
||||
|
||||
encrypted = _encrypt(fernet, content)
|
||||
row = MemoryAssociative(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
content_encrypted=encrypted,
|
||||
embedding=None,
|
||||
entity_type=source,
|
||||
entity_id=None,
|
||||
)
|
||||
self._db.add(row)
|
||||
try:
|
||||
await self._db.commit()
|
||||
logger.info("memory: insert_archival user=%s source=%s", user_id, source)
|
||||
except Exception as exc:
|
||||
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
|
||||
await self._db.rollback()
|
||||
|
||||
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||
"""Search archival memory (keyword fallback; semantic ranking can replace this)."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return []
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryAssociative)
|
||||
.where(MemoryAssociative.user_id == user_id)
|
||||
.order_by(MemoryAssociative.updated_at.desc())
|
||||
.limit(100)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
needle = query.strip().lower()
|
||||
out: list[str] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||
if plaintext is None:
|
||||
continue
|
||||
if not needle or needle in plaintext.lower():
|
||||
out.append(plaintext)
|
||||
if len(out) >= max(top_k, 1):
|
||||
break
|
||||
logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
||||
return out
|
||||
|
||||
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||
"""Search recall memory (episodic summaries) by keyword."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return []
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryEpisodic)
|
||||
.where(MemoryEpisodic.user_id == user_id)
|
||||
.order_by(MemoryEpisodic.created_at.desc())
|
||||
.limit(100)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
needle = query.strip().lower()
|
||||
out: list[str] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||
if plaintext is None:
|
||||
continue
|
||||
if not needle or needle in plaintext.lower():
|
||||
out.append(plaintext)
|
||||
if len(out) >= max(top_k, 1):
|
||||
break
|
||||
logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
||||
return out
|
||||
|
||||
# ── Private helpers ───────────────────────────────────────────────────────
|
||||
|
||||
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
||||
"""Load the user's Fernet key from DB. Returns None if missing."""
|
||||
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not user.encryption_key:
|
||||
logger.warning("memory: no encryption_key for user=%s", user_id)
|
||||
return None
|
||||
return Fernet(user.encryption_key.encode())
|
||||
|
||||
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
|
||||
"""Load lightweight user debug fields for trace logs."""
|
||||
from app.config.settings import settings # noqa: PLC0415
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
return {"tier": None}
|
||||
|
||||
sub_result = await self._db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub_tier: str | None = sub_result.scalar_one_or_none()
|
||||
if sub_tier:
|
||||
tier = sub_tier
|
||||
elif settings.ENV == "dev":
|
||||
tier = "power"
|
||||
else:
|
||||
tier = user.tier or "free"
|
||||
|
||||
return {"tier": tier}
|
||||
|
||||
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||
result = await self._db.execute(
|
||||
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: dict[str, str] = {}
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||
if plaintext is not None:
|
||||
out[row.key] = plaintext
|
||||
return out
|
||||
|
||||
async def _load_associative(
|
||||
self, user_id: str, message: str, fernet: Fernet, *, user_tier: str = "free"
|
||||
) -> list[str]:
|
||||
"""Load top-k associative memories.
|
||||
|
||||
Pro+: pgvector cosine similarity on the message embedding (real_embeddings feature).
|
||||
Free / embedding failure: keyword-ordered fallback (most recent rows).
|
||||
"""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
from app.core.embeddings import embed_text # noqa: PLC0415
|
||||
|
||||
if tier_manager.check_feature(user_tier, "real_embeddings"):
|
||||
vec = await embed_text(message)
|
||||
if vec is not None:
|
||||
try:
|
||||
result = await self._db.execute(
|
||||
select(MemoryAssociative)
|
||||
.where(
|
||||
MemoryAssociative.user_id == user_id,
|
||||
MemoryAssociative.embedding.isnot(None),
|
||||
)
|
||||
.order_by(MemoryAssociative.embedding.cosine_distance(vec))
|
||||
.limit(_ASSOCIATIVE_TOP_K)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: list[str] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||
if plaintext is not None:
|
||||
out.append(plaintext)
|
||||
logger.info(
|
||||
"memory: _load_associative user=%s mode=vector hits=%d",
|
||||
user_id,
|
||||
len(out),
|
||||
)
|
||||
return out
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory: vector search failed user=%s, falling back to keyword: %s",
|
||||
user_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
# Keyword fallback: most recent rows
|
||||
result = await self._db.execute(
|
||||
select(MemoryAssociative)
|
||||
.where(MemoryAssociative.user_id == user_id)
|
||||
.order_by(MemoryAssociative.updated_at.desc())
|
||||
.limit(_ASSOCIATIVE_TOP_K)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||
if plaintext is not None:
|
||||
out.append(plaintext)
|
||||
return out
|
||||
|
||||
async def _load_episodic(
|
||||
self,
|
||||
user_id: str,
|
||||
fernet: Fernet,
|
||||
session_id: str | None = None,
|
||||
) -> list[str]:
|
||||
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
|
||||
if session_id:
|
||||
query = query.where(MemoryEpisodic.session_id == session_id)
|
||||
result = await self._db.execute(
|
||||
query
|
||||
.order_by(MemoryEpisodic.created_at.desc())
|
||||
.limit(_EPISODIC_RECENT_N)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: list[str] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||
if plaintext is not None:
|
||||
out.append(plaintext)
|
||||
return out
|
||||
|
||||
async def _load_relational(self, user_id: str, *, user_tier: str = "free") -> list[str]:
|
||||
"""Return top-10 relation strings for Pro+ users; empty list for Free."""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
|
||||
if not tier_manager.check_feature(user_tier, "relational_memory"):
|
||||
return []
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryRelation)
|
||||
.where(MemoryRelation.user_id == user_id)
|
||||
.order_by(MemoryRelation.confidence.desc())
|
||||
.limit(10)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out = [
|
||||
f"{r.subject_label} --{r.predicate}--> {r.object_label}"
|
||||
for r in rows
|
||||
]
|
||||
return out
|
||||
|
||||
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||
result = await self._db.execute(
|
||||
select(MemoryProactive)
|
||||
.where(
|
||||
MemoryProactive.user_id == user_id,
|
||||
MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
|
||||
)
|
||||
.order_by(MemoryProactive.confidence.desc())
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: list[str] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.pattern_encrypted)
|
||||
if plaintext is not None:
|
||||
out.append(plaintext)
|
||||
return out
|
||||
|
||||
|
||||
# ── Encryption helpers ────────────────────────────────────────────────────────
|
||||
|
||||
def _encrypt(fernet: Fernet, plaintext: str) -> str:
|
||||
return fernet.encrypt(plaintext.encode()).decode()
|
||||
|
||||
|
||||
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
|
||||
"""Decrypt and return plaintext, or None on error (corrupted/wrong key)."""
|
||||
try:
|
||||
return fernet.decrypt(ciphertext.encode()).decode()
|
||||
except (InvalidToken, Exception) as exc:
|
||||
logger.warning("memory: decrypt failed: %s", exc)
|
||||
return None
|
||||
51
api/app/core/note_summarizer.py
Normal file
51
api/app/core/note_summarizer.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Note summarizer — generates a compact AI summary for a note.
|
||||
|
||||
Called fire-and-forget from create_note / update_note tools so the
|
||||
``notes.ai_summary`` column stays current without blocking the agent loop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from app.core.langfuse_client import get_prompt_or_fallback
|
||||
from app.core.llm import get_agent_llm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_FALLBACK_PROMPT = """\
|
||||
Summarize this note in <=250 characters. Be terse and dense.
|
||||
Keep proper nouns, dates, decisions, and action items.
|
||||
Do not start with "This note".
|
||||
Respond with the summary text only — no intro, no labels.
|
||||
|
||||
Title: {title}
|
||||
Content: {content}"""
|
||||
|
||||
_MAX_CONTENT_CHARS = 4000
|
||||
|
||||
|
||||
async def generate_note_summary(title: str, content: str) -> str:
|
||||
"""Return a <=250-char summary of *title* + *content*.
|
||||
|
||||
Uses the Langfuse ``note_summary`` prompt (hot-swappable) with a local
|
||||
fallback. Truncates *content* to 4000 chars before sending to avoid
|
||||
token waste on large notes.
|
||||
"""
|
||||
template, _ = get_prompt_or_fallback("note_summary", _FALLBACK_PROMPT)
|
||||
trimmed = content[:_MAX_CONTENT_CHARS]
|
||||
system_prompt = template.format(title=title, content=trimmed)
|
||||
|
||||
try:
|
||||
llm = get_agent_llm("note-summarizer")
|
||||
response = await llm.ainvoke([
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(content="Generate the summary."),
|
||||
])
|
||||
text = response.content if isinstance(response.content, str) else ""
|
||||
return text.strip()[:250]
|
||||
except Exception as exc:
|
||||
logger.warning("note_summarizer: failed to generate summary: %s", exc)
|
||||
return ""
|
||||
63
api/app/core/output_formatter.py
Normal file
63
api/app/core/output_formatter.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Output formatter for deep-agent stream events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from app.schemas import WsStreamEnd, WsStreamStart, WsStreamText
|
||||
|
||||
# Matches <canvas kind="...">...</canvas> blocks (single-line or multiline).
|
||||
_CANVAS_BLOCK_RE = re.compile(
|
||||
r'<canvas\s+kind=["\']([^"\']+)["\']>(.*?)</canvas>',
|
||||
re.DOTALL | re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def extract_canvas_block(text: str) -> tuple[str, str | None, str | None]:
|
||||
"""Strip the first <canvas kind="...">...</canvas> block from *text*.
|
||||
|
||||
Returns ``(visible_text, canvas_content, canvas_kind)``.
|
||||
``canvas_content`` and ``canvas_kind`` are ``None`` when no block is found.
|
||||
"""
|
||||
match = _CANVAS_BLOCK_RE.search(text)
|
||||
if not match:
|
||||
return text, None, None
|
||||
|
||||
canvas_kind = match.group(1).strip()
|
||||
canvas_content = match.group(2).strip()
|
||||
visible = text[: match.start()] + text[match.end() :]
|
||||
visible = visible.strip()
|
||||
return visible, canvas_content, canvas_kind
|
||||
|
||||
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd
|
||||
|
||||
|
||||
class StreamFormatter:
|
||||
"""Convert `(event_type, data)` stream events into websocket frame models."""
|
||||
|
||||
def __init__(self, request_id: str) -> None:
|
||||
self.request_id = request_id
|
||||
|
||||
async def format(
|
||||
self,
|
||||
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||
) -> AsyncGenerator[WsFrame, None]:
|
||||
started = False
|
||||
|
||||
async for event_type, data in event_stream:
|
||||
if event_type != "token":
|
||||
continue
|
||||
|
||||
if not started:
|
||||
yield WsStreamStart(request_id=self.request_id)
|
||||
started = True
|
||||
|
||||
text = str(data or "")
|
||||
if text:
|
||||
yield WsStreamText(request_id=self.request_id, chunk=text)
|
||||
|
||||
if not started:
|
||||
yield WsStreamStart(request_id=self.request_id)
|
||||
yield WsStreamEnd(request_id=self.request_id)
|
||||
104
api/app/core/preprocessors/__init__.py
Normal file
104
api/app/core/preprocessors/__init__.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Preprocessor registry: detect content type and dispatch to handlers.
|
||||
|
||||
Public API
|
||||
----------
|
||||
detect_content_type(filename, raw_content) -> str
|
||||
Heuristic detection based on file extension and content patterns.
|
||||
|
||||
preprocess(content_type, raw_content) -> PreprocessResult
|
||||
Dispatch to the appropriate handler.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from app.core.preprocessors.base import PreprocessResult
|
||||
|
||||
# ── Heuristics ────────────────────────────────────────────────────────
|
||||
|
||||
# Patterns that strongly suggest an email HTML file
|
||||
_EMAIL_SIGNALS = re.compile(
|
||||
r"(Subject:|From:|To:|Date:|Sent:|MIME-Version:|Content-Type:\s*text/html)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Patterns that suggest a generic HTML page (not an email)
|
||||
_GENERIC_HTML_SIGNALS = re.compile(
|
||||
r"<(nav|main|header|footer|article|section)\b",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def detect_content_type(filename: str, raw_content: str) -> str:
|
||||
"""Return a content-type string for the given file.
|
||||
|
||||
Supported types: ``"email_html"``, ``"generic_html"``,
|
||||
``"plain_text"``, ``"unknown"``.
|
||||
"""
|
||||
ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
|
||||
|
||||
if ext == "txt":
|
||||
return "plain_text"
|
||||
|
||||
if ext in ("html", "htm", "eml", "mhtml", "mht"):
|
||||
# Prefer email detection over generic HTML
|
||||
if _EMAIL_SIGNALS.search(raw_content[:4096]):
|
||||
return "email_html"
|
||||
if _GENERIC_HTML_SIGNALS.search(raw_content[:4096]) or "<html" in raw_content[:200].lower():
|
||||
return "generic_html"
|
||||
# .html without clear signals — check for any email header
|
||||
if re.search(r"^(From|To|Subject|Date):", raw_content[:2048], re.MULTILINE | re.IGNORECASE):
|
||||
return "email_html"
|
||||
return "generic_html"
|
||||
|
||||
# Plain text files with email headers
|
||||
if ext in ("", "txt") or not ext:
|
||||
if _EMAIL_SIGNALS.search(raw_content[:4096]):
|
||||
return "email_html"
|
||||
|
||||
# Detect binary content
|
||||
try:
|
||||
raw_content.encode("utf-8")
|
||||
except (UnicodeEncodeError, AttributeError):
|
||||
return "unknown"
|
||||
|
||||
# Non-text bytes heuristic: high ratio of non-printable chars
|
||||
sample = raw_content[:512]
|
||||
non_printable = sum(1 for c in sample if ord(c) < 32 and c not in "\r\n\t")
|
||||
if len(sample) > 0 and non_printable / len(sample) > 0.1:
|
||||
return "unknown"
|
||||
|
||||
return "unknown"
|
||||
|
||||
|
||||
# ── Generic fallback handler ──────────────────────────────────────────
|
||||
|
||||
def _preprocess_generic(raw_content: str, content_type: str) -> PreprocessResult:
|
||||
"""Strip HTML tags if present, return text as-is."""
|
||||
try:
|
||||
from bs4 import BeautifulSoup
|
||||
text = BeautifulSoup(raw_content, "html.parser").get_text(separator="\n")
|
||||
except ImportError:
|
||||
# No BeautifulSoup — strip tags with a simple regex
|
||||
text = re.sub(r"<[^>]+>", "", raw_content)
|
||||
|
||||
text = re.sub(r"\n{3,}", "\n\n", text).strip()
|
||||
return PreprocessResult(content_type=content_type, clean_text=text, metadata={})
|
||||
|
||||
|
||||
# ── Dispatch ──────────────────────────────────────────────────────────
|
||||
|
||||
def preprocess(content_type: str, raw_content: str) -> PreprocessResult:
|
||||
"""Dispatch *raw_content* to the handler registered for *content_type*.
|
||||
|
||||
Falls back to the generic handler for unknown types.
|
||||
"""
|
||||
if content_type == "email_html":
|
||||
from app.core.preprocessors.email_html import preprocess_email_html
|
||||
return preprocess_email_html(raw_content)
|
||||
|
||||
return _preprocess_generic(raw_content, content_type)
|
||||
|
||||
|
||||
__all__ = ["detect_content_type", "preprocess", "PreprocessResult"]
|
||||
25
api/app/core/preprocessors/base.py
Normal file
25
api/app/core/preprocessors/base.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Base types for the preprocessor system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessResult:
|
||||
"""Output of a preprocessor handler.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
content_type:
|
||||
The detected content type (e.g. ``"email_html"``, ``"plain_text"``).
|
||||
clean_text:
|
||||
Human-readable text stripped of markup/binary noise.
|
||||
metadata:
|
||||
Dict of extracted metadata (keys vary by handler).
|
||||
Common keys: ``subject``, ``from``, ``to``, ``date``, ``filename``.
|
||||
"""
|
||||
|
||||
content_type: str
|
||||
clean_text: str
|
||||
metadata: dict = field(default_factory=dict)
|
||||
111
api/app/core/preprocessors/email_html.py
Normal file
111
api/app/core/preprocessors/email_html.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Preprocessor for email HTML files.
|
||||
|
||||
Handles:
|
||||
- HTML stripping via BeautifulSoup
|
||||
- Metadata extraction (Subject, From, To, Date)
|
||||
- Thread splitting — isolates the latest reply
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.core.preprocessors.base import PreprocessResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
# ── Thread split markers ──────────────────────────────────────────────
|
||||
|
||||
# Matches patterns like:
|
||||
# "On Mon, Apr 7, 2026 at 10:00 AM, Alice <alice@co.com> wrote:"
|
||||
# "-----Original Message-----"
|
||||
# "> " (plain-text quote prefix)
|
||||
_THREAD_PATTERNS = [
|
||||
re.compile(r"^On\s+.+wrote\s*:", re.IGNORECASE | re.MULTILINE),
|
||||
re.compile(r"^-{3,}\s*(original message|forwarded message)\s*-{3,}", re.IGNORECASE | re.MULTILINE),
|
||||
re.compile(r"^>{1,}\s+\S", re.MULTILINE),
|
||||
re.compile(r"^From:\s+.+\nSent:\s+", re.IGNORECASE | re.MULTILINE),
|
||||
]
|
||||
|
||||
# ── Metadata patterns (applied on raw HTML / plain fallback) ──────────
|
||||
|
||||
_META_PATTERNS: dict[str, list[re.Pattern]] = {
|
||||
"subject": [
|
||||
re.compile(r"<title>(.+?)</title>", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"Subject:\s*(.+)", re.IGNORECASE),
|
||||
],
|
||||
"from": [
|
||||
re.compile(r'<meta[^>]+name=["\']?from["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
||||
re.compile(r"From:\s*(.+)", re.IGNORECASE),
|
||||
],
|
||||
"to": [
|
||||
re.compile(r'<meta[^>]+name=["\']?to["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
||||
re.compile(r"To:\s*(.+)", re.IGNORECASE),
|
||||
],
|
||||
"date": [
|
||||
re.compile(r'<meta[^>]+name=["\']?date["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
||||
re.compile(r"Date:\s*(.+)", re.IGNORECASE),
|
||||
re.compile(r"Sent:\s*(.+)", re.IGNORECASE),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _extract_metadata(raw_html: str, text: str) -> dict:
|
||||
"""Extract Subject/From/To/Date from raw HTML or plain text."""
|
||||
metadata: dict[str, str] = {}
|
||||
for field, patterns in _META_PATTERNS.items():
|
||||
for pat in patterns:
|
||||
m = pat.search(raw_html) or pat.search(text)
|
||||
if m:
|
||||
metadata[field] = m.group(1).strip()
|
||||
break
|
||||
return metadata
|
||||
|
||||
|
||||
def _split_thread(text: str) -> str:
|
||||
"""Return only the latest message in a threaded email."""
|
||||
earliest_pos: int | None = None
|
||||
for pat in _THREAD_PATTERNS:
|
||||
m = pat.search(text)
|
||||
if m and (earliest_pos is None or m.start() < earliest_pos):
|
||||
earliest_pos = m.start()
|
||||
|
||||
if earliest_pos is not None and earliest_pos > 0:
|
||||
return text[:earliest_pos].strip()
|
||||
return text.strip()
|
||||
|
||||
|
||||
def preprocess_email_html(raw_content: str) -> PreprocessResult:
|
||||
"""Strip HTML, extract metadata, split thread from an email HTML file."""
|
||||
try:
|
||||
from bs4 import BeautifulSoup # lazy import — optional dep
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"beautifulsoup4 is required for email_html preprocessing. "
|
||||
"Install it with: pip install beautifulsoup4"
|
||||
) from exc
|
||||
|
||||
# Parse with lxml if available, fall back to html.parser
|
||||
try:
|
||||
soup = BeautifulSoup(raw_content, "lxml")
|
||||
except Exception:
|
||||
soup = BeautifulSoup(raw_content, "html.parser")
|
||||
|
||||
# Remove noise tags
|
||||
for tag in soup(["style", "script", "head", "noscript"]):
|
||||
tag.decompose()
|
||||
|
||||
clean_text = soup.get_text(separator="\n")
|
||||
# Collapse excessive blank lines
|
||||
clean_text = re.sub(r"\n{3,}", "\n\n", clean_text).strip()
|
||||
|
||||
metadata = _extract_metadata(raw_content, clean_text)
|
||||
latest_message = _split_thread(clean_text)
|
||||
|
||||
return PreprocessResult(
|
||||
content_type="email_html",
|
||||
clean_text=latest_message,
|
||||
metadata=metadata,
|
||||
)
|
||||
30
api/app/core/scout_registry.py
Normal file
30
api/app/core/scout_registry.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Minimal agent base types retained for compatibility with batch runners."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
"""Common base for non-chat agents still using the old base contract."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str = "",
|
||||
shared_memory: dict[str, Any] | None = None,
|
||||
vector_store_context: list[str] | None = None,
|
||||
) -> None:
|
||||
self.user_id = user_id
|
||||
self.shared_memory: dict[str, Any] = shared_memory or {}
|
||||
self.vector_store_context: list[str] = vector_store_context or []
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_description(self) -> str: ...
|
||||
|
||||
@property
|
||||
def skills(self) -> list[str]:
|
||||
return []
|
||||
1051
api/app/core/scout_runner.py
Normal file
1051
api/app/core/scout_runner.py
Normal file
File diff suppressed because it is too large
Load Diff
96
api/app/core/scout_session_buffer.py
Normal file
96
api/app/core/scout_session_buffer.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""In-process TTL buffer for per-session LangChain message history.
|
||||
|
||||
Stores the full message list (including AIMessage with tool_calls and ToolMessage)
|
||||
keyed by (user_id, session_id), so agents can reconstruct tool-call context across
|
||||
conversation turns without it being lossy through the wire.
|
||||
|
||||
Single-process only. For multi-worker deployments, replace the _SessionBuffer
|
||||
implementation with one backed by Redis (serialize LangChain messages to dicts via
|
||||
message_to_dict / messages_from_dict from langchain_core.messages).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from threading import Lock
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
SESSION_TTL_SECONDS = 1800 # 30-minute idle expiry
|
||||
MAX_MESSAGES_PER_SESSION = 80 # cap to avoid unbounded memory growth
|
||||
|
||||
|
||||
class _SessionBuffer:
|
||||
def __init__(self) -> None:
|
||||
self._store: dict[tuple[str, str], tuple[float, list[BaseMessage]]] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _evict_stale(self) -> None:
|
||||
now = time.monotonic()
|
||||
stale = [k for k, (ts, _) in self._store.items() if now - ts > SESSION_TTL_SECONDS]
|
||||
for k in stale:
|
||||
del self._store[k]
|
||||
|
||||
def get(self, user_id: str, session_id: str) -> list[BaseMessage] | None:
|
||||
key = (user_id, session_id)
|
||||
with self._lock:
|
||||
entry = self._store.get(key)
|
||||
if entry is None:
|
||||
return None
|
||||
ts, msgs = entry
|
||||
if time.monotonic() - ts > SESSION_TTL_SECONDS:
|
||||
del self._store[key]
|
||||
return None
|
||||
self._store[key] = (time.monotonic(), msgs)
|
||||
return list(msgs)
|
||||
|
||||
def set(self, user_id: str, session_id: str, messages: list[BaseMessage]) -> None:
|
||||
key = (user_id, session_id)
|
||||
capped = messages[-MAX_MESSAGES_PER_SESSION:]
|
||||
with self._lock:
|
||||
self._evict_stale()
|
||||
self._store[key] = (time.monotonic(), capped)
|
||||
|
||||
def clear(self, user_id: str, session_id: str) -> None:
|
||||
with self._lock:
|
||||
self._store.pop((user_id, session_id), None)
|
||||
|
||||
def append_system_message(self, user_id: str, session_id: str, text: str) -> None:
|
||||
"""Append a synthetic system message to the buffer for the given session.
|
||||
|
||||
Creates the session slot if it does not yet exist. Used by the
|
||||
contextual_scope_update handler to inject navigation events without
|
||||
making an LLM call.
|
||||
"""
|
||||
from langchain_core.messages import SystemMessage # noqa: PLC0415
|
||||
|
||||
key = (user_id, session_id)
|
||||
with self._lock:
|
||||
entry = self._store.get(key)
|
||||
if entry is None:
|
||||
msgs: list[BaseMessage] = [SystemMessage(content=text)]
|
||||
else:
|
||||
_, existing = entry
|
||||
msgs = list(existing) + [SystemMessage(content=text)]
|
||||
capped = msgs[-MAX_MESSAGES_PER_SESSION:]
|
||||
self._store[key] = (time.monotonic(), capped)
|
||||
|
||||
|
||||
class ContextualBufferProxy:
|
||||
"""Thin wrapper around _SessionBuffer that closes over user_id + session_id.
|
||||
|
||||
Returned by get_session_buffer() so callers can call
|
||||
``proxy.append_system_message(text)`` without threading user_id/session_id
|
||||
through every call site.
|
||||
"""
|
||||
|
||||
def __init__(self, buf: "_SessionBuffer", user_id: str, session_id: str) -> None:
|
||||
self._buf = buf
|
||||
self._user_id = user_id
|
||||
self._session_id = session_id
|
||||
|
||||
def append_system_message(self, text: str) -> None:
|
||||
self._buf.append_system_message(self._user_id, self._session_id, text)
|
||||
|
||||
|
||||
# Module-level singleton — same pattern as _pending_states in api/app/api/routes/auth.py
|
||||
session_buffer = _SessionBuffer()
|
||||
115
api/app/core/ws_context.py
Normal file
115
api/app/core/ws_context.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""WebSocket client executor context.
|
||||
|
||||
Holds a per-request async callback that tools call to execute CRUD
|
||||
operations on the Electron client's local SQLite / LanceDB databases.
|
||||
The callback sends a `tool_call` WS frame and awaits the `tool_result`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Callable, Coroutine
|
||||
from uuid import uuid4
|
||||
|
||||
_SNAKE_TO_CAMEL_RE = re.compile(r"_([a-z])")
|
||||
|
||||
|
||||
def _key_to_camel(key: str) -> str:
|
||||
return _SNAKE_TO_CAMEL_RE.sub(lambda m: m.group(1).upper(), key)
|
||||
|
||||
|
||||
def _keys_to_camel(obj: Any) -> Any:
|
||||
"""Recursively convert dict keys from snake_case to camelCase.
|
||||
|
||||
Mirrors the JS-side ``toCamelCase`` applied to incoming WS frames in
|
||||
``adiuvAI/src/main/api/backend-client.ts``. The Electron executor wraps
|
||||
tool_result payloads in ``toSnakeCase`` before sending; this restores the
|
||||
camelCase schema property names that the tool code expects to read.
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
return {_key_to_camel(k): _keys_to_camel(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_keys_to_camel(v) for v in obj]
|
||||
return obj
|
||||
|
||||
# Holds the execute callback for the current WS session.
|
||||
# Set by the chat WS handler before the orchestrator runs; cleared after.
|
||||
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
||||
"_client_executor"
|
||||
)
|
||||
|
||||
# Optional collector that captures raw execute_on_client results.
|
||||
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
|
||||
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||
"_tool_result_collector", default=None
|
||||
)
|
||||
|
||||
|
||||
def set_tool_result_collector(lst: list[dict]) -> None:
|
||||
"""Register *lst* as the collector for this async context."""
|
||||
_tool_result_collector.set(lst)
|
||||
|
||||
|
||||
def clear_tool_result_collector() -> None:
|
||||
"""Clear the collector (best-effort)."""
|
||||
_tool_result_collector.set(None)
|
||||
|
||||
|
||||
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
|
||||
"""Bind *fn* as the executor for the current async context (task/coroutine)."""
|
||||
_client_executor.set(fn)
|
||||
|
||||
|
||||
def clear_client_executor() -> None:
|
||||
"""Remove the executor binding (best-effort; ContextVar resets on task exit)."""
|
||||
try:
|
||||
_client_executor.set(None) # type: ignore[arg-type]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def execute_on_client(
|
||||
action: str,
|
||||
table: str | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
filters: dict[str, Any] | None = None,
|
||||
vector: list[float] | None = None,
|
||||
limit: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Send a CRUD/vector operation to the Electron client and return the result.
|
||||
|
||||
Builds a ``tool_call`` payload, invokes the per-session WS callback,
|
||||
and returns the ``tool_result`` dict from Electron.
|
||||
|
||||
Raises ``RuntimeError`` if no executor is set (i.e. called outside a WS session).
|
||||
"""
|
||||
callback = _client_executor.get(None)
|
||||
if callback is None:
|
||||
raise RuntimeError(
|
||||
"execute_on_client() called outside a WebSocket session — "
|
||||
"no client executor is set."
|
||||
)
|
||||
|
||||
payload: dict[str, Any] = {"id": str(uuid4()), "action": action}
|
||||
if table is not None:
|
||||
payload["table"] = table
|
||||
if data is not None:
|
||||
payload["data"] = data
|
||||
if filters is not None:
|
||||
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
|
||||
if vector is not None:
|
||||
payload["vector"] = vector
|
||||
if limit is not None:
|
||||
payload["limit"] = limit
|
||||
|
||||
result = await callback(payload)
|
||||
result = _keys_to_camel(result)
|
||||
collector = _tool_result_collector.get(None)
|
||||
if collector is not None:
|
||||
collector.append({
|
||||
"action": action,
|
||||
"table": table,
|
||||
"data": result,
|
||||
})
|
||||
return result
|
||||
40
api/app/db.py
Normal file
40
api/app/db.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Database engine, session factory, and base model.
|
||||
|
||||
All app code uses the async SQLAlchemy API. Alembic migrations use the
|
||||
synchronous psycopg2 URL for the CLI (see alembic/env.py).
|
||||
|
||||
Usage in routes:
|
||||
from app.db import get_session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
async def my_route(db: AsyncSession = Depends(get_session)):
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
pool_pre_ping=True,
|
||||
echo=False,
|
||||
)
|
||||
|
||||
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Shared declarative base for all ORM models."""
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""FastAPI dependency that yields an async DB session per request."""
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
164
api/app/integrations/__init__.py
Normal file
164
api/app/integrations/__init__.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Cloud provider integration utilities.
|
||||
|
||||
Provides:
|
||||
* Shared message dataclasses (``EmailMessage``, ``ChatMessage``) used by
|
||||
both the Gmail and MS Graph clients and consumed by ``agent_runner``.
|
||||
* ``get_provider()`` — factory that returns the correct client given a
|
||||
provider name and decrypted OAuth credentials dict.
|
||||
* ``encrypt_token()`` / ``decrypt_token()`` — Fernet-based at-rest
|
||||
encryption for OAuth tokens stored in ``cloud_agent_configs``.
|
||||
|
||||
Encryption rationale
|
||||
--------------------
|
||||
Unlike user content (which is E2E-encrypted client-side and **never**
|
||||
decrypted server-side), OAuth tokens *must* be decrypted server-side
|
||||
because the backend makes provider API calls on behalf of the user.
|
||||
The Fernet key lives solely in ``OAUTH_ENCRYPTION_KEY`` env var — it
|
||||
is never returned to clients.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.integrations.gmail import GmailClient
|
||||
from app.integrations.ms_graph import MSGraphClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Shared message types ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmailMessage:
|
||||
"""A single email message fetched from Gmail or Outlook."""
|
||||
|
||||
id: str
|
||||
subject: str
|
||||
sender: str
|
||||
body_text: str
|
||||
date: datetime
|
||||
labels: list[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def as_text(self) -> str:
|
||||
"""Return a human-readable text representation for LLM extraction."""
|
||||
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||
labels_str = f" [{', '.join(self.labels)}]" if self.labels else ""
|
||||
return (
|
||||
f"From: {self.sender}\n"
|
||||
f"Date: {date_str}{labels_str}\n"
|
||||
f"Subject: {self.subject}\n\n"
|
||||
f"{self.body_text}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""A single Teams chat or channel message fetched from MS Graph."""
|
||||
|
||||
id: str
|
||||
content: str
|
||||
sender: str
|
||||
channel: str | None
|
||||
date: datetime
|
||||
|
||||
@property
|
||||
def as_text(self) -> str:
|
||||
"""Return a human-readable text representation for LLM extraction."""
|
||||
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||
channel_str = f" [channel: {self.channel}]" if self.channel else ""
|
||||
return (
|
||||
f"From: {self.sender}\n"
|
||||
f"Date: {date_str}{channel_str}\n\n"
|
||||
f"{self.content}"
|
||||
)
|
||||
|
||||
|
||||
# ── Fernet helpers ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet:
|
||||
"""Return a ``Fernet`` instance using ``settings.OAUTH_ENCRYPTION_KEY``.
|
||||
|
||||
Raises ``RuntimeError`` if ``OAUTH_ENCRYPTION_KEY`` is not set — callers
|
||||
must ensure this is configured before persisting OAuth tokens.
|
||||
"""
|
||||
key = settings.OAUTH_ENCRYPTION_KEY
|
||||
if not key:
|
||||
raise RuntimeError(
|
||||
"OAUTH_ENCRYPTION_KEY is not set. "
|
||||
"Generate one with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
|
||||
)
|
||||
return Fernet(key.encode() if isinstance(key, str) else key)
|
||||
|
||||
|
||||
def encrypt_token(token_info: dict) -> str:
|
||||
"""Fernet-encrypt an OAuth credential dict and return a base64 string.
|
||||
|
||||
Stores the full ``{access_token, refresh_token, token_uri, client_id,
|
||||
client_secret, scopes, expiry}`` dict (or equivalent MSAL shape).
|
||||
|
||||
Raises:
|
||||
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
|
||||
ValueError: ``token_info`` is not a non-empty dict.
|
||||
"""
|
||||
if not isinstance(token_info, dict) or not token_info:
|
||||
raise ValueError("token_info must be a non-empty dict")
|
||||
plaintext = json.dumps(token_info).encode("utf-8")
|
||||
return _get_fernet().encrypt(plaintext).decode("utf-8")
|
||||
|
||||
|
||||
def decrypt_token(encrypted: str) -> dict:
|
||||
"""Decrypt a Fernet-encrypted token string and return the credential dict.
|
||||
|
||||
Raises:
|
||||
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
|
||||
ValueError: The encrypted string is invalid or was encrypted with a
|
||||
different key.
|
||||
"""
|
||||
try:
|
||||
plaintext = _get_fernet().decrypt(encrypted.encode("utf-8"))
|
||||
return json.loads(plaintext)
|
||||
except (InvalidToken, json.JSONDecodeError) as exc:
|
||||
raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc
|
||||
|
||||
|
||||
# ── Provider factory ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_provider(
|
||||
provider: str,
|
||||
credentials_info: dict,
|
||||
) -> "GmailClient | MSGraphClient":
|
||||
"""Return the correct provider client for *provider*.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
provider:
|
||||
One of ``"gmail"``, ``"outlook"``, ``"teams"``.
|
||||
credentials_info:
|
||||
Decrypted OAuth credential dict (Google or Microsoft shape).
|
||||
|
||||
Raises:
|
||||
ValueError: Unknown provider name.
|
||||
"""
|
||||
if provider == "gmail":
|
||||
from app.integrations.gmail import GmailClient
|
||||
return GmailClient(credentials_info)
|
||||
if provider in {"outlook", "teams"}:
|
||||
from app.integrations.ms_graph import MSGraphClient
|
||||
return MSGraphClient(credentials_info)
|
||||
raise ValueError(
|
||||
f"Unknown cloud provider {provider!r}. "
|
||||
"Supported: 'gmail', 'outlook', 'teams'."
|
||||
)
|
||||
335
api/app/integrations/gmail.py
Normal file
335
api/app/integrations/gmail.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""Gmail API client for cloud agent integration.
|
||||
|
||||
Wraps the Google Gmail REST API to fetch email messages matching a
|
||||
``filter_config`` dict. Uses the official ``google-api-python-client``
|
||||
library (synchronous) wrapped in ``asyncio.to_thread()`` to avoid
|
||||
blocking the event loop.
|
||||
|
||||
Token refresh is handled transparently: when the stored access token has
|
||||
expired, ``google.auth.transport.requests.Request`` will use the refresh
|
||||
token to obtain a fresh one. The caller is responsible for persisting
|
||||
any refreshed credentials back to ``CloudScoutConfig.oauth_token_encrypted``
|
||||
(see ``agent_runner.run_cloud_agent``).
|
||||
|
||||
Credential dict shape (Google OAuth2):
|
||||
{
|
||||
"token": "<access_token>",
|
||||
"refresh_token": "<refresh_token>",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"client_id": "<client_id>",
|
||||
"client_secret": "<client_secret>",
|
||||
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
"expiry": "2025-01-01T00:00:00Z" # optional ISO-8601
|
||||
}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import email
|
||||
import html
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from app.integrations import EmailMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Gmail search date format — e.g. "after:2025/01/01"
|
||||
_GMAIL_DATE_FMT = "%Y/%m/%d"
|
||||
|
||||
# Maximum characters of body text forwarded to the LLM.
|
||||
_BODY_TRUNCATE = 8_000
|
||||
|
||||
# Maximum messages retrieved per run (prevents runaway quota usage).
|
||||
_MAX_MESSAGES = 200
|
||||
|
||||
|
||||
def _build_gmail_query(
|
||||
filter_config: dict[str, Any] | None,
|
||||
since: datetime | None,
|
||||
) -> str:
|
||||
"""Build a Gmail search query string from *filter_config* and *since*.
|
||||
|
||||
Supported ``filter_config`` keys:
|
||||
labels (list[str]): Gmail label names, e.g. ``["INBOX", "work"]``
|
||||
senders (list[str]): Sender addresses or domains to include
|
||||
date_range (dict): ``{from: "<YYYY-MM-DD>", to: "<YYYY-MM-DD>"}``
|
||||
|
||||
A hard ``since`` date (from last run) always overrides ``date_range.from``
|
||||
when it is earlier.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
cfg = filter_config or {}
|
||||
|
||||
# Labels — joined with OR when multiple given.
|
||||
labels: list[str] = cfg.get("labels", [])
|
||||
if labels:
|
||||
if len(labels) == 1:
|
||||
parts.append(f"label:{labels[0]}")
|
||||
else:
|
||||
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
|
||||
parts.append(f"({label_expr})")
|
||||
|
||||
# Senders — each prefixed with "from:".
|
||||
senders: list[str] = cfg.get("senders", [])
|
||||
for sender in senders:
|
||||
parts.append(f"from:{sender}")
|
||||
|
||||
# Date range.
|
||||
date_range: dict = cfg.get("date_range", {})
|
||||
from_str: str | None = date_range.get("from")
|
||||
to_str: str | None = date_range.get("to")
|
||||
|
||||
# Determine effective "from" date: most recent of filter_config.date_range.from and since.
|
||||
effective_since: datetime | None = since
|
||||
if from_str:
|
||||
try:
|
||||
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||
if cfg_since.tzinfo is None:
|
||||
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||
if effective_since is None or cfg_since > effective_since:
|
||||
effective_since = cfg_since
|
||||
except ValueError:
|
||||
logger.warning("gmail: invalid date_range.from %r — ignoring", from_str)
|
||||
|
||||
if effective_since:
|
||||
parts.append(f"after:{effective_since.strftime(_GMAIL_DATE_FMT)}")
|
||||
|
||||
if to_str:
|
||||
try:
|
||||
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||
parts.append(f"before:{to_dt.strftime(_GMAIL_DATE_FMT)}")
|
||||
except ValueError:
|
||||
logger.warning("gmail: invalid date_range.to %r — ignoring", to_str)
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def _strip_html(raw_html: str) -> str:
|
||||
"""Remove HTML tags and decode entities to get plain text."""
|
||||
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
|
||||
decoded = html.unescape(no_tags)
|
||||
return re.sub(r"\s+", " ", decoded).strip()
|
||||
|
||||
|
||||
def _parse_body(payload: dict[str, Any]) -> str:
|
||||
"""Recursively extract the plain-text body from a Gmail message payload.
|
||||
|
||||
Prefers ``text/plain``; falls back to ``text/html`` (stripped of tags).
|
||||
Returns an empty string if no body can be extracted.
|
||||
"""
|
||||
mime_type: str = payload.get("mimeType", "")
|
||||
body: dict = payload.get("body", {})
|
||||
parts: list[dict] = payload.get("parts", [])
|
||||
|
||||
if mime_type == "text/plain":
|
||||
data = body.get("data", "")
|
||||
if data:
|
||||
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||
return ""
|
||||
|
||||
if mime_type == "text/html":
|
||||
data = body.get("data", "")
|
||||
if data:
|
||||
raw = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||
return _strip_html(raw)
|
||||
return ""
|
||||
|
||||
# Multipart — prefer text/plain part, fall back to text/html.
|
||||
plain_fallback = ""
|
||||
for part in parts:
|
||||
part_mime = part.get("mimeType", "")
|
||||
if part_mime == "text/plain":
|
||||
return _parse_body(part)
|
||||
if part_mime == "text/html" and not plain_fallback:
|
||||
plain_fallback = _parse_body(part)
|
||||
if part_mime.startswith("multipart/"):
|
||||
nested = _parse_body(part)
|
||||
if nested:
|
||||
return nested
|
||||
return plain_fallback
|
||||
|
||||
|
||||
def _parse_date(raw: str) -> datetime:
|
||||
"""Parse an RFC 2822 email date header into a UTC ``datetime``."""
|
||||
try:
|
||||
parsed = email.utils.parsedate_to_datetime(raw)
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
return parsed.astimezone(timezone.utc)
|
||||
except Exception:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class GmailClient:
|
||||
"""Fetch email messages from a Gmail account via the Gmail REST API.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
credentials_info:
|
||||
Decrypted OAuth2 credential dict. Must contain at minimum
|
||||
``token`` (access token) or ``refresh_token`` + ``token_uri`` +
|
||||
``client_id`` + ``client_secret``.
|
||||
"""
|
||||
|
||||
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
self._credentials_info = credentials_info
|
||||
expiry_str: str | None = credentials_info.get("expiry")
|
||||
expiry: datetime | None = None
|
||||
if expiry_str:
|
||||
try:
|
||||
expiry = datetime.fromisoformat(
|
||||
expiry_str.replace("Z", "+00:00")
|
||||
).replace(tzinfo=timezone.utc)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
self._credentials = Credentials(
|
||||
token=credentials_info.get("token"),
|
||||
refresh_token=credentials_info.get("refresh_token"),
|
||||
token_uri=credentials_info.get("token_uri", "https://oauth2.googleapis.com/token"),
|
||||
client_id=credentials_info.get("client_id"),
|
||||
client_secret=credentials_info.get("client_secret"),
|
||||
scopes=credentials_info.get("scopes"),
|
||||
expiry=expiry,
|
||||
)
|
||||
|
||||
# ── Public API ─────────────────────────────────────────────────────────
|
||||
|
||||
async def fetch_messages(
|
||||
self,
|
||||
filter_config: dict[str, Any] | None = None,
|
||||
since: datetime | None = None,
|
||||
) -> list[EmailMessage]:
|
||||
"""Return up to ``_MAX_MESSAGES`` emails matching *filter_config*.
|
||||
|
||||
Runs the synchronous Google API calls inside ``asyncio.to_thread()``
|
||||
to avoid blocking the async event loop.
|
||||
|
||||
Token refresh is performed automatically when the access token has
|
||||
expired. After the call, ``self.refreshed_credentials`` may be
|
||||
consulted to detect whether new credentials should be persisted.
|
||||
"""
|
||||
query = _build_gmail_query(filter_config, since)
|
||||
logger.debug("gmail: executing search query %r", query)
|
||||
return await asyncio.to_thread(self._fetch_sync, query)
|
||||
|
||||
@property
|
||||
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||
"""Return updated credential dict if the access token was refreshed.
|
||||
|
||||
If the credentials were refreshed during ``fetch_messages()``, returns
|
||||
a new dict that should be re-encrypted and written back to the DB.
|
||||
Returns ``None`` if no refresh occurred.
|
||||
"""
|
||||
creds = self._credentials
|
||||
if not creds.valid and creds.expired:
|
||||
return None
|
||||
# Check whether the token changed from what was stored.
|
||||
if creds.token != self._credentials_info.get("token"):
|
||||
result = {
|
||||
"token": creds.token,
|
||||
"refresh_token": creds.refresh_token,
|
||||
"token_uri": creds.token_uri,
|
||||
"client_id": creds.client_id,
|
||||
"client_secret": creds.client_secret,
|
||||
"scopes": list(creds.scopes or []),
|
||||
}
|
||||
if creds.expiry:
|
||||
result["expiry"] = creds.expiry.isoformat()
|
||||
return result
|
||||
return None
|
||||
|
||||
# ── Internal sync worker ───────────────────────────────────────────────
|
||||
|
||||
def _fetch_sync(self, query: str) -> list[EmailMessage]:
|
||||
"""Synchronous worker — called inside ``asyncio.to_thread()``."""
|
||||
import googleapiclient.discovery
|
||||
import googleapiclient.errors
|
||||
from google.auth.transport.requests import Request
|
||||
|
||||
# Refresh token if needed before building the service.
|
||||
if self._credentials.expired and self._credentials.refresh_token:
|
||||
try:
|
||||
self._credentials.refresh(Request())
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"Gmail token refresh failed: {exc}") from exc
|
||||
|
||||
service = googleapiclient.discovery.build(
|
||||
"gmail", "v1", credentials=self._credentials, cache_discovery=False
|
||||
)
|
||||
user_api = service.users() # type: ignore[attr-defined]
|
||||
|
||||
# ── List matching message IDs ──────────────────────────────────────
|
||||
ids: list[str] = []
|
||||
page_token: str | None = None
|
||||
while len(ids) < _MAX_MESSAGES:
|
||||
batch_size = min(100, _MAX_MESSAGES - len(ids))
|
||||
kwargs: dict[str, Any] = {
|
||||
"userId": "me",
|
||||
"maxResults": batch_size,
|
||||
}
|
||||
if query:
|
||||
kwargs["q"] = query
|
||||
if page_token:
|
||||
kwargs["pageToken"] = page_token
|
||||
|
||||
try:
|
||||
resp = user_api.messages().list(**kwargs).execute()
|
||||
except googleapiclient.errors.HttpError as exc:
|
||||
raise RuntimeError(f"Gmail messages.list failed: {exc}") from exc
|
||||
|
||||
for msg in resp.get("messages", []):
|
||||
ids.append(msg["id"])
|
||||
|
||||
page_token = resp.get("nextPageToken")
|
||||
if not page_token:
|
||||
break
|
||||
|
||||
if not ids:
|
||||
logger.debug("gmail: no messages matched query %r", query)
|
||||
return []
|
||||
|
||||
logger.info("gmail: fetching %d message(s)", len(ids))
|
||||
|
||||
# ── Fetch individual message details ──────────────────────────────
|
||||
messages: list[EmailMessage] = []
|
||||
for msg_id in ids:
|
||||
try:
|
||||
msg = user_api.messages().get(
|
||||
userId="me", id=msg_id, format="full"
|
||||
).execute()
|
||||
|
||||
headers: dict[str, str] = {
|
||||
h["name"].lower(): h["value"]
|
||||
for h in msg.get("payload", {}).get("headers", [])
|
||||
}
|
||||
subject = headers.get("subject", "(no subject)")
|
||||
sender = headers.get("from", "unknown")
|
||||
date_raw = headers.get("date", "")
|
||||
date = _parse_date(date_raw) if date_raw else datetime.now(timezone.utc)
|
||||
|
||||
body_text = _parse_body(msg.get("payload", {}))[:_BODY_TRUNCATE]
|
||||
labels = msg.get("labelIds", [])
|
||||
|
||||
messages.append(EmailMessage(
|
||||
id=msg_id,
|
||||
subject=subject,
|
||||
sender=sender,
|
||||
body_text=body_text,
|
||||
date=date,
|
||||
labels=labels,
|
||||
))
|
||||
except googleapiclient.errors.HttpError as exc:
|
||||
logger.warning("gmail: skipping message %s — HTTP error: %s", msg_id, exc)
|
||||
except Exception as exc:
|
||||
logger.warning("gmail: skipping message %s — unexpected error: %s", msg_id, exc)
|
||||
|
||||
logger.info("gmail: returned %d message(s)", len(messages))
|
||||
return messages
|
||||
352
api/app/integrations/ms_graph.py
Normal file
352
api/app/integrations/ms_graph.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""Microsoft Graph API client for Outlook and Teams cloud agent integration.
|
||||
|
||||
Handles two data sources:
|
||||
|
||||
* **Outlook email** (``provider="outlook"``) — ``fetch_emails()`` calls
|
||||
``/me/messages`` with an OData ``$filter`` built from ``filter_config``.
|
||||
* **Teams messages** (``provider="teams"``) — ``fetch_messages()`` calls
|
||||
``/me/chats/getAllMessages`` filtered by date.
|
||||
|
||||
Authentication uses MSAL ``PublicClientApplication`` to acquire a token
|
||||
from a stored refresh token. The ``httpx.AsyncClient`` (already a project
|
||||
dependency) is used for all API calls.
|
||||
|
||||
Credential dict shape (Microsoft OAuth2 / MSAL):
|
||||
{
|
||||
"access_token": "<access_token>",
|
||||
"refresh_token": "<refresh_token>",
|
||||
"token_type": "Bearer",
|
||||
"scope": "Mail.Read ChannelMessage.Read.All offline_access",
|
||||
"expires_in": 3600
|
||||
}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.integrations import ChatMessage, EmailMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"
|
||||
|
||||
# Max items fetched per run.
|
||||
_MAX_EMAILS = 200
|
||||
_MAX_MESSAGES = 200
|
||||
|
||||
# Max characters of body forwarded to the LLM.
|
||||
_BODY_TRUNCATE = 8_000
|
||||
|
||||
|
||||
def _strip_html(raw: str) -> str:
|
||||
"""Strip HTML tags and collapse whitespace."""
|
||||
no_tags = re.sub(r"<[^>]+>", " ", raw)
|
||||
import html as _html
|
||||
decoded = _html.unescape(no_tags)
|
||||
return re.sub(r"\s+", " ", decoded).strip()
|
||||
|
||||
|
||||
def _odata_datetime(dt: datetime) -> str:
|
||||
"""Format a datetime as an OData datetime literal (UTC, ISO 8601)."""
|
||||
utc = dt.astimezone(timezone.utc)
|
||||
return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
|
||||
def _build_email_filter(
|
||||
filter_config: dict[str, Any] | None,
|
||||
since: datetime | None,
|
||||
) -> str:
|
||||
"""Build an OData ``$filter`` expression for the ``/me/messages`` endpoint.
|
||||
|
||||
Supported ``filter_config`` keys:
|
||||
senders (list[str]): Sender email addresses.
|
||||
date_range (dict): ``{from: "<ISO-8601>", to: "<ISO-8601>"}``
|
||||
folders (list[str]): Folder display names (not directly filterable
|
||||
via OData, so ignored here — callers iterate
|
||||
folder IDs separately if needed; listed for
|
||||
completeness).
|
||||
|
||||
A hard ``since`` date always overrides ``date_range.from`` when it is
|
||||
earlier.
|
||||
"""
|
||||
clauses: list[str] = []
|
||||
cfg = filter_config or {}
|
||||
|
||||
# Senders.
|
||||
senders: list[str] = cfg.get("senders", [])
|
||||
if senders:
|
||||
sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders]
|
||||
clauses.append("(" + " or ".join(sender_clauses) + ")")
|
||||
|
||||
# Date range.
|
||||
date_range: dict = cfg.get("date_range", {})
|
||||
from_str: str | None = date_range.get("from")
|
||||
|
||||
effective_since: datetime | None = since
|
||||
if from_str:
|
||||
try:
|
||||
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||
if cfg_since.tzinfo is None:
|
||||
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||
if effective_since is None or cfg_since > effective_since:
|
||||
effective_since = cfg_since
|
||||
except ValueError:
|
||||
logger.warning("ms_graph: invalid date_range.from %r — ignoring", from_str)
|
||||
|
||||
if effective_since:
|
||||
clauses.append(f"receivedDateTime ge {_odata_datetime(effective_since)}")
|
||||
|
||||
to_str: str | None = date_range.get("to")
|
||||
if to_str:
|
||||
try:
|
||||
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||
if to_dt.tzinfo is None:
|
||||
to_dt = to_dt.replace(tzinfo=timezone.utc)
|
||||
clauses.append(f"receivedDateTime le {_odata_datetime(to_dt)}")
|
||||
except ValueError:
|
||||
logger.warning("ms_graph: invalid date_range.to %r — ignoring", to_str)
|
||||
|
||||
return " and ".join(clauses)
|
||||
|
||||
|
||||
class MSGraphClient:
|
||||
"""Fetch emails and Teams messages via the Microsoft Graph REST API.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
credentials_info:
|
||||
Decrypted MSAL credential dict.
|
||||
"""
|
||||
|
||||
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||
self._credentials_info = credentials_info
|
||||
self._access_token: str = credentials_info.get("access_token", "")
|
||||
self._original_access_token: str = self._access_token
|
||||
self._refresh_token: str | None = credentials_info.get("refresh_token")
|
||||
|
||||
# ── Token management ───────────────────────────────────────────────────
|
||||
|
||||
def _auth_headers(self) -> dict[str, str]:
|
||||
return {"Authorization": f"Bearer {self._access_token}"}
|
||||
|
||||
async def _refresh_access_token(self) -> None:
|
||||
"""Use MSAL to exchange the refresh token for a fresh access token.
|
||||
|
||||
Updates ``self._access_token`` and ``self._credentials_info`` in-place.
|
||||
|
||||
Raises:
|
||||
RuntimeError: MSAL reports an auth error.
|
||||
"""
|
||||
import msal
|
||||
|
||||
app = msal.ConfidentialClientApplication(
|
||||
client_id=settings.MS_CLIENT_ID,
|
||||
client_credential=settings.MS_CLIENT_SECRET,
|
||||
authority=f"https://login.microsoftonline.com/{settings.MS_TENANT_ID}",
|
||||
)
|
||||
scopes: list[str] = self._credentials_info.get("scope", "").split()
|
||||
if not scopes:
|
||||
scopes = ["https://graph.microsoft.com/.default"]
|
||||
|
||||
result = app.acquire_token_by_refresh_token(
|
||||
self._refresh_token,
|
||||
scopes=scopes,
|
||||
)
|
||||
if "access_token" not in result:
|
||||
error = result.get("error_description", result.get("error", "unknown"))
|
||||
raise RuntimeError(f"MS Graph token refresh failed: {error}")
|
||||
|
||||
self._access_token = result["access_token"]
|
||||
# MSAL may issue a new refresh token.
|
||||
if "refresh_token" in result:
|
||||
self._refresh_token = result["refresh_token"]
|
||||
self._credentials_info["refresh_token"] = result["refresh_token"]
|
||||
self._credentials_info["access_token"] = self._access_token
|
||||
|
||||
@property
|
||||
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||
"""Return updated credential dict if the access token was refreshed.
|
||||
|
||||
Returns ``None`` if no change was made.
|
||||
"""
|
||||
if self._access_token != self._original_access_token:
|
||||
return {**self._credentials_info, "access_token": self._access_token}
|
||||
return None
|
||||
|
||||
# ── HTTP helpers ───────────────────────────────────────────────────────
|
||||
|
||||
async def _get(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
url: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
*,
|
||||
retry_on_401: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""GET *url* with auth; refresh token on 401 and retry once."""
|
||||
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||
if resp.status_code == 401 and retry_on_401 and self._refresh_token:
|
||||
logger.debug("ms_graph: 401 on %s — refreshing token", url)
|
||||
await self._refresh_access_token()
|
||||
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||
if resp.status_code == 429:
|
||||
raise RuntimeError("MS Graph rate limit hit (429). Try again later.")
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
# ── Public API ─────────────────────────────────────────────────────────
|
||||
|
||||
async def fetch_emails(
|
||||
self,
|
||||
filter_config: dict[str, Any] | None = None,
|
||||
since: datetime | None = None,
|
||||
) -> list[EmailMessage]:
|
||||
"""Return up to ``_MAX_EMAILS`` Outlook messages matching *filter_config*.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filter_config:
|
||||
Optional dict with ``senders``, ``date_range``, ``folders`` keys.
|
||||
since:
|
||||
Hard lower-bound on email date (from last agent run).
|
||||
"""
|
||||
odata_filter = _build_email_filter(filter_config, since)
|
||||
params: dict[str, Any] = {
|
||||
"$top": 50,
|
||||
"$select": "id,subject,from,receivedDateTime,body,bodyPreview",
|
||||
"$orderby": "receivedDateTime desc",
|
||||
}
|
||||
if odata_filter:
|
||||
params["$filter"] = odata_filter
|
||||
|
||||
emails: list[EmailMessage] = []
|
||||
url = f"{_GRAPH_BASE}/me/messages"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
while url and len(emails) < _MAX_EMAILS:
|
||||
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
||||
for item in data.get("value", []):
|
||||
emails.append(self._parse_email(item))
|
||||
if len(emails) >= _MAX_EMAILS:
|
||||
break
|
||||
url = data.get("@odata.nextLink", "")
|
||||
params = {} # nextLink already contains encoded params.
|
||||
|
||||
logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
|
||||
return emails
|
||||
|
||||
async def fetch_messages(
|
||||
self,
|
||||
filter_config: dict[str, Any] | None = None,
|
||||
since: datetime | None = None,
|
||||
) -> list[ChatMessage]:
|
||||
"""Return up to ``_MAX_MESSAGES`` Teams messages matching *filter_config*.
|
||||
|
||||
Fetches from ``/me/chats/getAllMessages`` (personal + group chats).
|
||||
The ``filter_config.channels`` key is checked as a text-filter on
|
||||
the channel name post-fetch (the API doesn't support channel OData
|
||||
filter directly on ``getAllMessages``).
|
||||
"""
|
||||
cfg = filter_config or {}
|
||||
channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])]
|
||||
params: dict[str, Any] = {"$top": 50}
|
||||
if since:
|
||||
params["$filter"] = f"createdDateTime ge {_odata_datetime(since)}"
|
||||
|
||||
messages: list[ChatMessage] = []
|
||||
url = f"{_GRAPH_BASE}/me/chats/getAllMessages"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
while url and len(messages) < _MAX_MESSAGES:
|
||||
try:
|
||||
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
||||
except httpx.HTTPStatusError as exc:
|
||||
# getAllMessages requires specific licensing; degrade gracefully.
|
||||
if exc.response.status_code in (403, 404):
|
||||
logger.warning(
|
||||
"ms_graph: /me/chats/getAllMessages not available (%d) — "
|
||||
"check Teams license or permissions",
|
||||
exc.response.status_code,
|
||||
)
|
||||
break
|
||||
raise
|
||||
|
||||
for item in data.get("value", []):
|
||||
msg = self._parse_teams_message(item)
|
||||
if channel_filter and msg.channel:
|
||||
if not any(c in msg.channel.lower() for c in channel_filter):
|
||||
continue
|
||||
messages.append(msg)
|
||||
if len(messages) >= _MAX_MESSAGES:
|
||||
break
|
||||
url = data.get("@odata.nextLink", "")
|
||||
params = {}
|
||||
|
||||
logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
|
||||
return messages
|
||||
|
||||
# ── Parsers ────────────────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _parse_email(item: dict[str, Any]) -> EmailMessage:
|
||||
subject: str = item.get("subject", "(no subject)") or "(no subject)"
|
||||
sender_block = item.get("from", {}) or {}
|
||||
sender_addr = (
|
||||
(sender_block.get("emailAddress") or {}).get("address", "unknown")
|
||||
)
|
||||
date_str: str = item.get("receivedDateTime", "")
|
||||
try:
|
||||
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||
except Exception:
|
||||
date = datetime.now(timezone.utc)
|
||||
|
||||
body_block = item.get("body", {}) or {}
|
||||
content_type: str = body_block.get("contentType", "text")
|
||||
raw_body: str = body_block.get("content", "")
|
||||
if content_type == "html":
|
||||
body_text = _strip_html(raw_body)
|
||||
else:
|
||||
body_text = raw_body or item.get("bodyPreview", "")
|
||||
body_text = body_text[:_BODY_TRUNCATE]
|
||||
|
||||
return EmailMessage(
|
||||
id=item.get("id", ""),
|
||||
subject=subject,
|
||||
sender=sender_addr,
|
||||
body_text=body_text,
|
||||
date=date,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_teams_message(item: dict[str, Any]) -> ChatMessage:
|
||||
msg_id: str = item.get("id", "")
|
||||
sender_block = (item.get("from") or {}).get("user") or {}
|
||||
sender: str = sender_block.get("displayName", "unknown")
|
||||
channel: str | None = (item.get("channelIdentity") or {}).get("channelId")
|
||||
|
||||
date_str: str = item.get("createdDateTime", "")
|
||||
try:
|
||||
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||
except Exception:
|
||||
date = datetime.now(timezone.utc)
|
||||
|
||||
body_block = item.get("body", {}) or {}
|
||||
content_type: str = body_block.get("contentType", "text")
|
||||
raw_content: str = body_block.get("content", "")
|
||||
content = _strip_html(raw_content) if content_type == "html" else raw_content
|
||||
content = content[:_BODY_TRUNCATE]
|
||||
|
||||
return ChatMessage(
|
||||
id=msg_id,
|
||||
content=content,
|
||||
sender=sender,
|
||||
channel=channel,
|
||||
date=date,
|
||||
)
|
||||
242
api/app/main.py
Normal file
242
api/app/main.py
Normal file
@@ -0,0 +1,242 @@
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||
from app.config.settings import settings
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
async def _memory_audit_cron_tick() -> None:
|
||||
"""Weekly cron: contradiction scan + label canonicalization for all users (Phase 7)."""
|
||||
import logging # noqa: PLC0415
|
||||
_log = logging.getLogger(__name__)
|
||||
_log.info("memory audit cron tick: starting")
|
||||
try:
|
||||
from app.db import async_session # noqa: PLC0415
|
||||
from app.core.memory_maintenance import audit_memory # noqa: PLC0415
|
||||
from app.models import User # noqa: PLC0415
|
||||
from sqlalchemy import select # noqa: PLC0415
|
||||
|
||||
async with async_session() as db:
|
||||
result = await db.execute(select(User.id))
|
||||
user_ids: list[str] = list(result.scalars().all())
|
||||
|
||||
for uid in user_ids:
|
||||
try:
|
||||
async with async_session() as db:
|
||||
await audit_memory(db, uid)
|
||||
except Exception as exc:
|
||||
_log.warning("memory audit cron tick: audit_memory failed user=%s: %s", uid, exc)
|
||||
|
||||
_log.info("memory audit cron tick: done users=%d", len(user_ids))
|
||||
except Exception as exc:
|
||||
_log.warning("memory audit cron tick: failed: %s", exc)
|
||||
|
||||
|
||||
async def _memory_cron_tick() -> None:
|
||||
"""Hourly cron: drain Free-tier extraction queue + mine proactive patterns for Power+ users."""
|
||||
import logging # noqa: PLC0415
|
||||
_log = logging.getLogger(__name__)
|
||||
_log.info("memory cron tick: starting")
|
||||
try:
|
||||
from app.db import async_session # noqa: PLC0415
|
||||
from app.core.memory_maintenance import drain_extraction_queue, mine_proactive_patterns # noqa: PLC0415
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
from app.models import User # noqa: PLC0415
|
||||
from sqlalchemy import select # noqa: PLC0415
|
||||
|
||||
async with async_session() as db:
|
||||
await drain_extraction_queue(db)
|
||||
|
||||
# mine proactive patterns for every Power+ user
|
||||
async with async_session() as db:
|
||||
result = await db.execute(select(User.id))
|
||||
user_ids: list[str] = list(result.scalars().all())
|
||||
|
||||
for uid in user_ids:
|
||||
try:
|
||||
async with async_session() as db:
|
||||
tier = await tier_manager.get_tier(uid, db)
|
||||
if tier_manager.check_feature(tier, "proactive_mining"):
|
||||
await mine_proactive_patterns(db, uid)
|
||||
except Exception as exc:
|
||||
_log.warning("memory cron tick: mine_proactive_patterns failed user=%s: %s", uid, exc)
|
||||
|
||||
_log.info("memory cron tick: done users=%d", len(user_ids))
|
||||
except Exception as exc:
|
||||
_log.warning("memory cron tick: failed: %s", exc)
|
||||
|
||||
|
||||
async def _scout_cron_tick() -> None:
|
||||
"""Every-15-min cron: poll enabled cloud scouts (cron-fallback; push is primary).
|
||||
|
||||
Skips any scout whose ``last_run_at`` is within the last 5 minutes so
|
||||
a push notification and the fallback cron don't double-fire within the
|
||||
same window.
|
||||
"""
|
||||
import logging # noqa: PLC0415
|
||||
import uuid # noqa: PLC0415
|
||||
from datetime import datetime, timezone # noqa: PLC0415
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
_log.info("scout cron tick: starting")
|
||||
try:
|
||||
from app.db import async_session # noqa: PLC0415
|
||||
from app.models import CloudScoutConfig # noqa: PLC0415
|
||||
from app.scouts.engine import ScoutEngine # noqa: PLC0415
|
||||
from sqlalchemy import select # noqa: PLC0415
|
||||
|
||||
async with async_session() as session:
|
||||
scouts = (await session.execute(
|
||||
select(CloudScoutConfig).where(CloudScoutConfig.enabled == True) # noqa: E712
|
||||
)).scalars().all()
|
||||
|
||||
engine = ScoutEngine()
|
||||
triggered = 0
|
||||
for scout in scouts:
|
||||
# Rate-limit guard: push is primary; skip if ran within 5 minutes.
|
||||
if scout.last_run_at:
|
||||
elapsed = (datetime.now(tz=timezone.utc) - scout.last_run_at).total_seconds()
|
||||
if elapsed < 300:
|
||||
continue
|
||||
try:
|
||||
await engine.trigger_scout(uuid.UUID(str(scout.id)))
|
||||
triggered += 1
|
||||
except Exception as exc:
|
||||
_log.warning("scout cron tick: trigger failed scout=%s: %s", scout.id, exc)
|
||||
|
||||
_log.info("scout cron tick: done triggered=%d total=%d", triggered, len(scouts))
|
||||
except Exception as exc:
|
||||
_log.warning("scout cron tick: failed: %s", exc)
|
||||
|
||||
|
||||
async def _scout_watch_renewal_tick() -> None:
|
||||
"""Every-24-hour cron: re-issue Gmail users.watch for scouts expiring within 24h.
|
||||
|
||||
Handles missing or misconfigured connectors gracefully — logs and continues.
|
||||
"""
|
||||
import logging # noqa: PLC0415
|
||||
from datetime import datetime, timedelta, timezone # noqa: PLC0415
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
_log.info("scout watch renewal tick: starting")
|
||||
try:
|
||||
from app.db import async_session # noqa: PLC0415
|
||||
from app.models import CloudScoutConfig # noqa: PLC0415
|
||||
from app.scouts.connectors.registry import get_connector # noqa: PLC0415
|
||||
from sqlalchemy import select # noqa: PLC0415
|
||||
|
||||
threshold = datetime.now(tz=timezone.utc) + timedelta(hours=24)
|
||||
renewed = 0
|
||||
async with async_session() as session:
|
||||
scouts = (await session.execute(
|
||||
select(CloudScoutConfig).where(
|
||||
CloudScoutConfig.enabled == True, # noqa: E712
|
||||
CloudScoutConfig.provider == "gmail",
|
||||
CloudScoutConfig.gmail_watch_expires_at <= threshold,
|
||||
)
|
||||
)).scalars().all()
|
||||
|
||||
for scout in scouts:
|
||||
try:
|
||||
connector = get_connector("gmail")
|
||||
await connector.renew_watch(scout)
|
||||
renewed += 1
|
||||
except Exception:
|
||||
_log.exception("scout watch renewal tick: renew failed scout=%s", scout.id)
|
||||
|
||||
await session.commit()
|
||||
|
||||
_log.info("scout watch renewal tick: done renewed=%d", renewed)
|
||||
except Exception as exc:
|
||||
_log.warning("scout watch renewal tick: failed: %s", exc)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup: register source connectors.
|
||||
from app.scouts.connectors.gmail import GmailConnector # noqa: PLC0415
|
||||
from app.scouts.connectors.registry import register_connector # noqa: PLC0415
|
||||
register_connector(GmailConnector())
|
||||
|
||||
# Startup: ensure agent tool modules are loaded.
|
||||
import app.agents # noqa: F401
|
||||
|
||||
scheduler = None
|
||||
if settings.SCHEDULER_ENABLED:
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler # noqa: PLC0415
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
scheduler.add_job(_memory_cron_tick, "interval", hours=1, id="memory_cron")
|
||||
scheduler.add_job(_memory_audit_cron_tick, "interval", weeks=1, id="memory_audit_cron")
|
||||
scheduler.add_job(
|
||||
_scout_cron_tick, "interval", minutes=15,
|
||||
id="scout_cron_tick", replace_existing=True,
|
||||
)
|
||||
scheduler.add_job(
|
||||
_scout_watch_renewal_tick, "interval", hours=24,
|
||||
id="scout_watch_renewal_tick", replace_existing=True,
|
||||
)
|
||||
scheduler.start()
|
||||
logging.getLogger(__name__).info("memory cron scheduler started (interval=1h)")
|
||||
|
||||
yield
|
||||
|
||||
if scheduler is not None:
|
||||
scheduler.shutdown(wait=False)
|
||||
|
||||
# Shutdown: dispose SQLAlchemy connection pool
|
||||
from app.db import engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(
|
||||
title="AdiuvAI Cloud API",
|
||||
version="0.1.0",
|
||||
docs_url="/docs" if settings.ENV == "dev" else None,
|
||||
redoc_url=None,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
# Middleware stack (Starlette inserts at position 0, so last-added = outermost).
|
||||
# Request flow: TierRateLimit → Sanitizer → CORS → Router
|
||||
# Response flow: Router → CORS → Sanitizer → TierRateLimit
|
||||
app.add_middleware(SanitizerMiddleware)
|
||||
app.add_middleware(TierRateLimitMiddleware)
|
||||
|
||||
from app.api.routes import scouts, auth, billing, chat, device_ws, memory, scout_webhooks
|
||||
|
||||
app.include_router(auth.router, prefix="/api/v1")
|
||||
app.include_router(chat.router, prefix="/api/v1")
|
||||
app.include_router(billing.router, prefix="/api/v1")
|
||||
app.include_router(scouts.router, prefix="/api/v1")
|
||||
app.include_router(scout_webhooks.router, prefix="/api/v1")
|
||||
app.include_router(device_ws.router, prefix="/api/v1")
|
||||
app.include_router(memory.router, prefix="/api/v1")
|
||||
|
||||
@app.get("/api/v1/health", tags=["health"])
|
||||
async def health() -> dict:
|
||||
return {"status": "ok", "version": app.version}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
474
api/app/models.py
Normal file
474
api/app/models.py
Normal file
@@ -0,0 +1,474 @@
|
||||
"""SQLAlchemy ORM models for all persistent tables.
|
||||
|
||||
Only auth, billing, scout config, and memory data live here.
|
||||
User content (notes, tasks, etc.) lives exclusively on the client.
|
||||
|
||||
Table inventory:
|
||||
users — account credentials + tier
|
||||
refresh_tokens — hashed refresh token store
|
||||
subscriptions — Stripe subscription records
|
||||
local_scout_configs — per-device batch scout configs
|
||||
cloud_scout_configs — OAuth-backed cloud scout configs
|
||||
scout_run_logs — execution history for all scouts
|
||||
memory_core — per-user persistent key/value preferences (encrypted)
|
||||
memory_associative — per-user semantic memory with embeddings (encrypted)
|
||||
memory_episodic — per-user session summaries (encrypted)
|
||||
memory_proactive — per-user behavioral patterns (encrypted)
|
||||
memory_relations — per-user entity/relation graph (Mem0g-light, Phase 3)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
DateTime,
|
||||
Enum,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
JSON,
|
||||
LargeBinary,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
Uuid,
|
||||
func,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db import Base
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _uuid() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
# ── Enum types ────────────────────────────────────────────────────────────
|
||||
|
||||
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
|
||||
AgentTypeEnum = Enum("local", "cloud", name="agent_type")
|
||||
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
|
||||
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
|
||||
|
||||
|
||||
# ── Models ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
password_hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
avatar_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
||||
# Used to encrypt/decrypt all memory rows for this user.
|
||||
encryption_key: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
onboarding_completed_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True, default=None
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
refresh_tokens: Mapped[list[RefreshToken]] = relationship(
|
||||
back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
subscription: Mapped[Subscription | None] = relationship(
|
||||
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
||||
)
|
||||
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
||||
back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class RefreshToken(Base):
|
||||
__tablename__ = "refresh_tokens"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
||||
|
||||
|
||||
class OAuthAccount(Base):
|
||||
__tablename__ = "oauth_accounts"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
provider: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
provider_user_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
provider_email: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship(back_populates="oauth_accounts")
|
||||
|
||||
|
||||
class Subscription(Base):
|
||||
__tablename__ = "subscriptions"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False, unique=True, index=True
|
||||
)
|
||||
stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)
|
||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||
status: Mapped[str] = mapped_column(String(50), nullable=False, default="free")
|
||||
current_period_end: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship(back_populates="subscription")
|
||||
|
||||
|
||||
class LocalScoutConfig(Base):
|
||||
__tablename__ = "local_scout_configs"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
device_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||
scout_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
run_logs: Mapped[list["ScoutRunLog"]] = relationship(
|
||||
back_populates="local_scout",
|
||||
primaryjoin="and_(ScoutRunLog.scout_id == LocalScoutConfig.id, ScoutRunLog.scout_type == 'local')",
|
||||
foreign_keys="ScoutRunLog.scout_id",
|
||||
cascade="all, delete-orphan",
|
||||
overlaps="run_logs,cloud_scout",
|
||||
)
|
||||
|
||||
|
||||
class CloudScoutConfig(Base):
|
||||
__tablename__ = "cloud_scout_configs"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
provider: Mapped[str] = mapped_column(CloudProviderEnum, nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||
oauth_token_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
filter_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
auto_trash_spam: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default=text("false"))
|
||||
gmail_history_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
gmail_watch_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
device_inactivity_pause_days: Mapped[int] = mapped_column(Integer, nullable=False, default=14, server_default="14")
|
||||
gmail_address: Mapped[str | None] = mapped_column(String(320), nullable=True)
|
||||
|
||||
run_logs: Mapped[list["ScoutRunLog"]] = relationship(
|
||||
back_populates="cloud_scout",
|
||||
primaryjoin="and_(ScoutRunLog.scout_id == CloudScoutConfig.id, ScoutRunLog.scout_type == 'cloud')",
|
||||
foreign_keys="ScoutRunLog.scout_id",
|
||||
cascade="all, delete-orphan",
|
||||
overlaps="run_logs,local_scout",
|
||||
)
|
||||
|
||||
|
||||
class ScoutTriageQueue(Base):
|
||||
__tablename__ = "scout_triage_queue"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("scout_id", "source_msg_ref", name="uq_scout_triage_queue_scout_msg"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||
user_id: Mapped[str] = mapped_column(Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
scout_id: Mapped[str] = mapped_column(Uuid(as_uuid=False), ForeignKey("cloud_scout_configs.id", ondelete="CASCADE"), nullable=False)
|
||||
source_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
source_msg_ref: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
triage_verdict: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
triage_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(20), nullable=False, default="queued", server_default="queued")
|
||||
triaged_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now())
|
||||
delivered_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
acked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
|
||||
class ScoutRunLog(Base):
|
||||
__tablename__ = "scout_run_logs"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
# Plain string — not a FK because it references either local_scout_configs or cloud_scout_configs
|
||||
# depending on scout_type. Query by (scout_id, scout_type) to locate the source config.
|
||||
scout_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
|
||||
scout_type: Mapped[str] = mapped_column(AgentTypeEnum, nullable=False)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
|
||||
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
|
||||
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||
started_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
local_scout: Mapped["LocalScoutConfig | None"] = relationship(
|
||||
back_populates="run_logs",
|
||||
primaryjoin="and_(ScoutRunLog.scout_id == LocalScoutConfig.id, ScoutRunLog.scout_type == 'local')",
|
||||
foreign_keys="ScoutRunLog.scout_id",
|
||||
overlaps="run_logs,cloud_scout",
|
||||
)
|
||||
cloud_scout: Mapped["CloudScoutConfig | None"] = relationship(
|
||||
back_populates="run_logs",
|
||||
primaryjoin="and_(ScoutRunLog.scout_id == CloudScoutConfig.id, ScoutRunLog.scout_type == 'cloud')",
|
||||
foreign_keys="ScoutRunLog.scout_id",
|
||||
overlaps="run_logs,local_scout",
|
||||
)
|
||||
|
||||
|
||||
class MonthlyTokenUsage(Base):
|
||||
__tablename__ = "monthly_token_usage"
|
||||
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
year_month: Mapped[str] = mapped_column(String(7), primary_key=True) # 'YYYY-MM'
|
||||
feature: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
|
||||
# ── Memory models ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class MemoryCore(Base):
|
||||
"""Per-user persistent key/value preferences, encrypted at rest.
|
||||
|
||||
Examples: preferred_language, timezone, work_style.
|
||||
Decrypted in-memory only using User.encryption_key.
|
||||
"""
|
||||
|
||||
__tablename__ = "memory_core"
|
||||
|
||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True,
|
||||
)
|
||||
key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
value_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class MemoryAssociative(Base):
|
||||
"""Per-user semantic memory: encrypted content + pgvector embedding for similarity search.
|
||||
|
||||
Production: ``embedding`` column is ``vector(1536)`` via pgvector.
|
||||
Tests (SQLite): stored as JSON list.
|
||||
"""
|
||||
|
||||
__tablename__ = "memory_associative"
|
||||
|
||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True,
|
||||
)
|
||||
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
# vector(1536) via pgvector; SQLite tests use NULL embeddings so no dialect issue.
|
||||
embedding: Mapped[list | None] = mapped_column(Vector(1536), nullable=True)
|
||||
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class MemoryEpisodic(Base):
|
||||
"""Per-user session summaries, encrypted at rest.
|
||||
|
||||
One row per session interaction; used to recall recent conversations.
|
||||
"""
|
||||
|
||||
__tablename__ = "memory_episodic"
|
||||
|
||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True,
|
||||
)
|
||||
summary_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
session_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
|
||||
class MemoryProactive(Base):
|
||||
"""Per-user inferred behavioral patterns, encrypted at rest.
|
||||
|
||||
Confidence in [0.0, 1.0]; only patterns above threshold are injected.
|
||||
Source: 'inferred' (from episodes) or 'explicit' (user-stated).
|
||||
"""
|
||||
|
||||
__tablename__ = "memory_proactive"
|
||||
|
||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True,
|
||||
)
|
||||
pattern_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.5)
|
||||
source: Mapped[str] = mapped_column(String(50), nullable=False, default="inferred")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
|
||||
class ExtractionQueue(Base):
|
||||
"""Batch extraction queue for Free-tier users (Phase 2).
|
||||
|
||||
Pro/Power/Team users get realtime asyncio.create_task() extraction.
|
||||
Free users get a queue row here; a daily cron (Phase 5) drains it.
|
||||
"""
|
||||
|
||||
__tablename__ = "extraction_queue"
|
||||
|
||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True,
|
||||
)
|
||||
episode_id: Mapped[str | None] = mapped_column(
|
||||
Uuid(as_uuid=False), nullable=True,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
|
||||
class MemoryRelation(Base):
|
||||
"""Per-user entity/relation graph row (Mem0g-light, Phase 3).
|
||||
|
||||
subject_label/object_label are plaintext entity identifiers (not user content).
|
||||
notes_encrypted is optional Fernet-encrypted per-user commentary.
|
||||
confidence in [0.0, 1.0] — decays 5 % per 30 days since last_confirmed_at.
|
||||
"""
|
||||
|
||||
__tablename__ = "memory_relations"
|
||||
|
||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True,
|
||||
)
|
||||
subject_label: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
subject_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
predicate: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
object_label: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
object_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.7)
|
||||
source_episode_id: Mapped[str | None] = mapped_column(
|
||||
Uuid(as_uuid=False),
|
||||
ForeignKey("memory_episodic.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
notes_encrypted: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
last_confirmed_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class Plugin(Base):
|
||||
"""Plugin marketplace catalog entry."""
|
||||
|
||||
__tablename__ = "plugins"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
version: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
author_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
category: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]")
|
||||
status: Mapped[str] = mapped_column(String(50), nullable=False, default="pending")
|
||||
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
342
api/app/schemas/__init__.py
Normal file
342
api/app/schemas/__init__.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""Pydantic schemas — API request/response contracts.
|
||||
|
||||
Mirrors the TypeScript types from the Electron app (src/shared/api-types.ts).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ── Billing ──────────────────────────────────────────────────────────
|
||||
|
||||
BillingTier = Literal["free", "pro", "power", "team"]
|
||||
|
||||
|
||||
# ── Auth ─────────────────────────────────────────────────────────────
|
||||
|
||||
class AuthTokens(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
expires_at: int
|
||||
|
||||
|
||||
class UserProfile(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
name: str | None = None
|
||||
surname: str | None = None
|
||||
tier: BillingTier
|
||||
avatar_url: str | None = None
|
||||
has_password: bool = True
|
||||
onboarding_completed_at: int | None = None # epoch ms, null = not onboarded
|
||||
memory: dict[str, str] = Field(default_factory=dict) # decrypted core memory k/v
|
||||
|
||||
|
||||
class OAuthAccountInfo(BaseModel):
|
||||
provider: str
|
||||
provider_email: str | None = None
|
||||
created_at: int # epoch ms
|
||||
|
||||
|
||||
# ── Chat ─────────────────────────────────────────────────────────────
|
||||
|
||||
class ChatContext(BaseModel):
|
||||
user_profile: dict[str, Any] = Field(default_factory=dict)
|
||||
relevant_documents: list[str] = Field(default_factory=list)
|
||||
recent_tasks: list[dict[str, Any]] = Field(default_factory=list)
|
||||
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
message: str
|
||||
context: ChatContext = Field(default_factory=ChatContext)
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
response: str
|
||||
|
||||
|
||||
# ── WebSocket Frame Protocol ──────────────────────────────────────────
|
||||
|
||||
class WsFrameType(str, Enum):
|
||||
# ── v2 frame types (kept for backward compat) ──────────────────────
|
||||
chat_request = "chat_request"
|
||||
text_chunk = "text_chunk"
|
||||
tool_call = "tool_call"
|
||||
tool_result = "tool_result"
|
||||
final = "final"
|
||||
ping = "ping"
|
||||
device_hello = "device_hello"
|
||||
# ── v3 frame types ─────────────────────────────────────────────────
|
||||
home_request = "home_request"
|
||||
stream_start = "stream_start"
|
||||
stream_text = "stream_text"
|
||||
stream_end = "stream_end"
|
||||
data_request = "data_request"
|
||||
data_response = "data_response"
|
||||
mutation = "mutation"
|
||||
# ── v4 journey frame types ────────────────────────────────────────
|
||||
journey_start = "journey_start"
|
||||
journey_message = "journey_message"
|
||||
journey_reply = "journey_reply"
|
||||
# ── v5 brief frame types ──────────────────────────────────────────
|
||||
brief_request = "brief_request"
|
||||
# ── v6 task brief frame types ─────────────────────────────────────
|
||||
task_brief_request = "task_brief_request"
|
||||
# ── v7 folder index frame types ───────────────────────────────────
|
||||
index_session_start = "index_session_start"
|
||||
index_file_batch = "index_file_batch"
|
||||
index_session_cancel = "index_session_cancel"
|
||||
index_file_result = "index_file_result"
|
||||
index_session_progress = "index_session_progress"
|
||||
index_session_done = "index_session_done"
|
||||
# ── v8 contextual sidebar frame types ────────────────────────────
|
||||
contextual_request = "contextual_request"
|
||||
contextual_scope_update = "contextual_scope_update"
|
||||
contextual_scope_ack = "contextual_scope_ack"
|
||||
# ── v9 scout proposal frame types ────────────────────────────────
|
||||
SCOUT_PROPOSAL = "scout_proposal"
|
||||
SCOUT_PROPOSAL_ACK = "scout_proposal_ack"
|
||||
|
||||
|
||||
class WsToolCall(BaseModel):
|
||||
"""Server → Client: requests a CRUD/vector operation on the local DB."""
|
||||
|
||||
type: Literal[WsFrameType.tool_call] = WsFrameType.tool_call
|
||||
id: str
|
||||
action: str
|
||||
table: str | None = None
|
||||
data: dict[str, Any] | None = None
|
||||
filters: dict[str, Any] | None = None
|
||||
vector: list[float] | None = None
|
||||
limit: int | None = None
|
||||
|
||||
|
||||
class WsToolResult(BaseModel):
|
||||
"""Client → Server: result of a CRUD/vector operation."""
|
||||
|
||||
type: Literal[WsFrameType.tool_result] = WsFrameType.tool_result
|
||||
id: str
|
||||
row: dict[str, Any] | None = None
|
||||
rows: list[dict[str, Any]] | None = None
|
||||
results: list[dict[str, Any]] | None = None
|
||||
deleted: bool | None = None
|
||||
ok: bool | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class WsTextChunk(BaseModel):
|
||||
"""Server → Client: incremental LLM response text."""
|
||||
|
||||
type: Literal[WsFrameType.text_chunk] = WsFrameType.text_chunk
|
||||
text: str
|
||||
|
||||
|
||||
class WsFinal(BaseModel):
|
||||
"""Server → Client: signals end of response with the complete text."""
|
||||
|
||||
type: Literal[WsFrameType.final] = WsFrameType.final
|
||||
response: str
|
||||
|
||||
|
||||
# ── WebSocket Agent Frame Protocol ────────────────────────────────────
|
||||
|
||||
class WsDeviceHello(BaseModel):
|
||||
"""Client → Server: device identification on WS connect."""
|
||||
|
||||
type: Literal[WsFrameType.device_hello] = WsFrameType.device_hello
|
||||
device_id: str
|
||||
scout_ids: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
|
||||
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
||||
|
||||
class FormatPrefsModel(BaseModel):
|
||||
"""User display preferences sent by Electron on each request."""
|
||||
|
||||
timezone: str = "UTC"
|
||||
date_format: str = "dd/MM/yyyy"
|
||||
time_format: str = "24h"
|
||||
locale: str = "en-US"
|
||||
now_iso: str = ""
|
||||
|
||||
|
||||
class WsHomeRequest(BaseModel):
|
||||
"""Client → Server: Home chat message."""
|
||||
|
||||
type: Literal[WsFrameType.home_request] = WsFrameType.home_request
|
||||
message: str
|
||||
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||
format_prefs: FormatPrefsModel | None = None
|
||||
|
||||
|
||||
class WsBriefRequest(BaseModel):
|
||||
"""Client → Server: Request a plain-text brief (home or project)."""
|
||||
|
||||
type: Literal[WsFrameType.brief_request] = WsFrameType.brief_request
|
||||
request_id: str | None = None
|
||||
session_id: str | None = None
|
||||
mode: Literal["home", "project"]
|
||||
project_id: str | None = None
|
||||
format_prefs: FormatPrefsModel | None = None
|
||||
|
||||
|
||||
class WsStreamStart(BaseModel):
|
||||
"""Server → Client: signals start of a streaming response."""
|
||||
|
||||
type: Literal[WsFrameType.stream_start] = WsFrameType.stream_start
|
||||
request_id: str
|
||||
|
||||
|
||||
class WsStreamText(BaseModel):
|
||||
"""Server → Client: streamed text token."""
|
||||
|
||||
type: Literal[WsFrameType.stream_text] = WsFrameType.stream_text
|
||||
request_id: str
|
||||
chunk: str
|
||||
|
||||
|
||||
class WsStreamEnd(BaseModel):
|
||||
"""Server → Client: signals end of a streaming response."""
|
||||
|
||||
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
||||
request_id: str
|
||||
error: str | None = None
|
||||
mutations: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
# ── Scout Config V2 ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ScoutContentTypeConfig(BaseModel):
|
||||
"""Per-type extraction config produced by the journey chatbot."""
|
||||
|
||||
id: str
|
||||
label: str = ""
|
||||
detection_hint: str = ""
|
||||
preprocessing: str = "generic" # handler name: "email_html", "plain_text", ...
|
||||
extraction_prompt: str
|
||||
|
||||
|
||||
class ScoutConfig(BaseModel):
|
||||
"""Structured scout configuration (replaces freeform prompt_template)."""
|
||||
|
||||
content_types: list[ScoutContentTypeConfig] = []
|
||||
global_rules: list[str] = []
|
||||
data_types: list[str] = []
|
||||
|
||||
|
||||
# ── Scout Catalog ─────────────────────────────────────────────────────
|
||||
|
||||
class ScoutCatalogItem(BaseModel):
|
||||
type: str
|
||||
name: str
|
||||
description: str
|
||||
|
||||
|
||||
class ScoutCreationCheckRequest(BaseModel):
|
||||
active_agents: int = Field(ge=0, default=0)
|
||||
|
||||
|
||||
class ScoutCreationCheckResponse(BaseModel):
|
||||
allowed: bool
|
||||
tier: BillingTier
|
||||
active_agents: int
|
||||
limit: int
|
||||
|
||||
|
||||
class ScoutTriggerRequest(BaseModel):
|
||||
directory: str = Field(min_length=1)
|
||||
device_id: str = Field(default="")
|
||||
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
|
||||
what_to_extract: list[str] = Field(min_length=1)
|
||||
batch_interval: str = Field(min_length=1)
|
||||
custom_agent_prompt: str | None = None
|
||||
agent_config: dict | None = None
|
||||
active_agents: int = Field(ge=0, default=0)
|
||||
last_run_at: int | None = None # epoch ms from FE — enables incremental scanning
|
||||
|
||||
|
||||
# ── Scout Run Log ─────────────────────────────────────────────────────
|
||||
|
||||
class ScoutRunLogResponse(BaseModel):
|
||||
id: str
|
||||
agent_id: str
|
||||
agent_type: Literal["local", "cloud"]
|
||||
status: Literal["running", "success", "error", "partial"]
|
||||
items_processed: int
|
||||
items_created: int
|
||||
errors: list[str]
|
||||
started_at: int
|
||||
completed_at: int | None
|
||||
|
||||
|
||||
# ── Cloud Scout CRUD ──────────────────────────────────────────────────
|
||||
|
||||
class CloudScoutCreateRequest(BaseModel):
|
||||
name: str
|
||||
provider: Literal["gmail", "teams", "outlook"]
|
||||
data_types: list[str] = Field(default_factory=list)
|
||||
prompt_template: str = ""
|
||||
schedule_cron: str | None = None # None → server default
|
||||
filter_config: dict | None = None
|
||||
auto_trash_spam: bool = False
|
||||
|
||||
|
||||
class CloudScoutUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
data_types: list[str] | None = None
|
||||
prompt_template: str | None = None
|
||||
schedule_cron: str | None = None
|
||||
filter_config: dict | None = None
|
||||
auto_trash_spam: bool | None = None
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
class CloudScoutResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
provider: str
|
||||
name: str
|
||||
data_types: list[str]
|
||||
prompt_template: str
|
||||
schedule_cron: str
|
||||
filter_config: dict | None
|
||||
auto_trash_spam: bool
|
||||
enabled: bool
|
||||
last_run_at: int | None
|
||||
gmail_address: str | None
|
||||
oauth_connected: bool
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
|
||||
# ── Chatbot Journey ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ── Scout Proposal Frame Models ───────────────────────────────────────
|
||||
|
||||
class ScoutProposalPayload(BaseModel):
|
||||
id: str
|
||||
scout_id: str
|
||||
source_type: str
|
||||
source_msg_ref: str
|
||||
raw_subject: str | None = None
|
||||
raw_snippet: str | None = None
|
||||
category: Literal["unprocessed"] = "unprocessed"
|
||||
payload: dict | None = None
|
||||
|
||||
|
||||
class ScoutProposalFrame(BaseModel):
|
||||
type: Literal[WsFrameType.SCOUT_PROPOSAL]
|
||||
proposal: ScoutProposalPayload
|
||||
|
||||
|
||||
class ScoutProposalAckFrame(BaseModel):
|
||||
type: Literal[WsFrameType.SCOUT_PROPOSAL_ACK]
|
||||
proposal_id: str
|
||||
73
api/app/schemas/contextual.py
Normal file
73
api/app/schemas/contextual.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Contextual sidebar scope schema and prompt block renderer.
|
||||
|
||||
ContextualScope mirrors the TypeScript ContextualScope type sent by the
|
||||
Electron renderer when the user opens the side chat anchored to a specific
|
||||
view. The renderer ships camelCase keys; Pydantic's alias_generator maps
|
||||
them to snake_case Python attributes automatically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic.alias_generators import to_camel
|
||||
|
||||
|
||||
PageType = Literal[
|
||||
"timeline",
|
||||
"tasks",
|
||||
"projects-list",
|
||||
"project",
|
||||
"note",
|
||||
]
|
||||
|
||||
EntityType = Literal["project", "note", "task", "timeline_event"]
|
||||
|
||||
|
||||
class ContextualScope(BaseModel):
|
||||
"""Scope payload sent by the Electron renderer for contextual chat.
|
||||
|
||||
The renderer ships camelCase keys (entityType, entityId, ...). Pydantic's
|
||||
alias generator maps them to snake_case Python attrs.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
|
||||
|
||||
page: PageType
|
||||
entity_type: Optional[EntityType] = None
|
||||
entity_id: Optional[str] = None
|
||||
entity_name: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
char_count: Optional[int] = None
|
||||
counts: Optional[dict[str, int]] = None
|
||||
filters: Optional[dict] = None
|
||||
|
||||
|
||||
def render_scope_block(scope: ContextualScope) -> str:
|
||||
"""Produce a single-paragraph human-readable summary of the current view
|
||||
for injection into the contextual agent system prompt.
|
||||
|
||||
Never emits internal ids — only names. The LLM is told to use names in
|
||||
prose; ids travel through tool calls.
|
||||
"""
|
||||
if scope.entity_type == "project":
|
||||
c = scope.counts or {}
|
||||
return (
|
||||
f"User is viewing the project {scope.entity_name!r}. "
|
||||
f"{c.get('tasks', 0)} tasks, "
|
||||
f"{c.get('notes', 0)} notes, "
|
||||
f"{c.get('milestones', 0)} milestones."
|
||||
)
|
||||
if scope.entity_type == "note":
|
||||
return (
|
||||
f"User is viewing the note {scope.entity_name!r} "
|
||||
f"({scope.char_count or 0} characters)."
|
||||
)
|
||||
if scope.page == "tasks":
|
||||
return "User is viewing the global Tasks list (all projects)."
|
||||
if scope.page == "timeline":
|
||||
return "User is viewing the global Timeline view."
|
||||
if scope.page == "projects-list":
|
||||
return "User is viewing the Projects list."
|
||||
return f"User is on page {scope.page}."
|
||||
0
api/app/scouts/__init__.py
Normal file
0
api/app/scouts/__init__.py
Normal file
0
api/app/scouts/connectors/__init__.py
Normal file
0
api/app/scouts/connectors/__init__.py
Normal file
56
api/app/scouts/connectors/base.py
Normal file
56
api/app/scouts/connectors/base.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Source connector Protocol and shared item types.
|
||||
|
||||
A SourceConnector adapts a third-party data source (Gmail, Slack, ...) to the
|
||||
shared ScoutEngine interface. Each connector owns:
|
||||
|
||||
* how to enumerate new items since the last poll (``list_new``)
|
||||
* how to fetch a single item's metadata cheaply (``fetch_metadata``)
|
||||
* how to fetch a single item's full content for in-memory triage
|
||||
(``fetch_content``) — this content MUST NOT be persisted by the engine
|
||||
* how to archive/trash an item (``archive``) for spam handling
|
||||
* optional push-notification setup (``setup_watch`` / ``renew_watch``)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ItemRef(BaseModel):
|
||||
source_msg_ref: str
|
||||
received_at: datetime | None = None
|
||||
|
||||
|
||||
class ItemMetadata(BaseModel):
|
||||
subject: str | None = None
|
||||
sender: str | None = None
|
||||
snippet: str | None = None
|
||||
received_at: datetime | None = None
|
||||
|
||||
|
||||
class ItemContent(BaseModel):
|
||||
metadata: ItemMetadata
|
||||
body_text: str
|
||||
raw_headers: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TriageVerdict(BaseModel):
|
||||
verdict: Literal["relevant", "spam"]
|
||||
reason: str
|
||||
confidence: float = Field(ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class SourceConnector(Protocol):
|
||||
"""Adapter for a third-party data source (Gmail, Slack, ...)."""
|
||||
|
||||
source_type: str # e.g. "gmail"
|
||||
|
||||
async def list_new(self, scout) -> list[ItemRef]: ...
|
||||
async def fetch_metadata(self, scout, ref: ItemRef) -> ItemMetadata: ...
|
||||
async def fetch_content(self, scout, ref: ItemRef) -> ItemContent: ...
|
||||
async def archive(self, scout, ref: ItemRef) -> None: ...
|
||||
async def setup_watch(self, scout) -> None: ...
|
||||
async def renew_watch(self, scout) -> None: ...
|
||||
248
api/app/scouts/connectors/gmail.py
Normal file
248
api/app/scouts/connectors/gmail.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""Gmail SourceConnector — wraps the existing GmailClient.
|
||||
|
||||
Responsibilities:
|
||||
* list_new: incremental fetch since the scout's stored gmail_history_id
|
||||
* fetch_metadata: subject + sender + snippet only (Gmail metadata format)
|
||||
* fetch_content: full body text — transient, never persisted by engine
|
||||
* archive: move a message to Gmail Trash (recoverable for 30 days)
|
||||
* setup_watch / renew_watch: Gmail push notifications via Pub/Sub
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.integrations import decrypt_token
|
||||
from app.scouts.connectors.base import ItemContent, ItemMetadata, ItemRef
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_plain_text_body(payload: dict) -> str:
|
||||
"""Recursively walk a Gmail message payload to find text/plain content."""
|
||||
import base64
|
||||
mime_type = payload.get("mimeType", "")
|
||||
if mime_type == "text/plain":
|
||||
data = payload.get("body", {}).get("data", "")
|
||||
if data:
|
||||
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||
return ""
|
||||
if mime_type.startswith("multipart/"):
|
||||
for part in payload.get("parts", []):
|
||||
text = _extract_plain_text_body(part)
|
||||
if text:
|
||||
return text
|
||||
# text/html fallback: strip tags rudimentarily if no text/plain part
|
||||
if mime_type == "text/html":
|
||||
data = payload.get("body", {}).get("data", "")
|
||||
if data:
|
||||
import re
|
||||
html = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||
return re.sub(r"<[^>]+>", " ", html)
|
||||
return ""
|
||||
|
||||
|
||||
def _gmail_service_from_token(creds_info: dict):
|
||||
"""Build a synchronous Gmail API client from a decrypted credentials dict.
|
||||
|
||||
Shared by ``_get_gmail_service`` (scout-backed) and the pending-session
|
||||
OAuth flow which has a raw token but no scout row yet.
|
||||
"""
|
||||
from googleapiclient.discovery import build
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
credentials = Credentials(
|
||||
token=creds_info.get("token"),
|
||||
refresh_token=creds_info.get("refresh_token"),
|
||||
token_uri=creds_info.get("token_uri", "https://oauth2.googleapis.com/token"),
|
||||
client_id=creds_info.get("client_id"),
|
||||
client_secret=creds_info.get("client_secret"),
|
||||
scopes=creds_info.get("scopes"),
|
||||
)
|
||||
return build("gmail", "v1", credentials=credentials, cache_discovery=False)
|
||||
|
||||
|
||||
def _get_gmail_service(scout):
|
||||
"""Return a synchronous Google API client for low-level metadata/history calls."""
|
||||
creds_info = decrypt_token(scout.oauth_token_encrypted)
|
||||
return _gmail_service_from_token(creds_info)
|
||||
|
||||
|
||||
class GmailConnector:
|
||||
source_type = "gmail"
|
||||
|
||||
# ── list_new ──────────────────────────────────────────────────────────
|
||||
|
||||
async def list_new(self, scout) -> list[ItemRef]:
|
||||
"""Return new message refs since scout.gmail_history_id.
|
||||
|
||||
On first run (gmail_history_id is None/empty), records the current
|
||||
historyId without backfilling — avoids flooding the user with old mail.
|
||||
Updates scout.gmail_history_id in-place (caller must persist to DB).
|
||||
"""
|
||||
def _sync() -> tuple[list[ItemRef], str | None]:
|
||||
service = _get_gmail_service(scout)
|
||||
history_id = scout.gmail_history_id
|
||||
refs: list[ItemRef] = []
|
||||
new_history_id = history_id
|
||||
|
||||
if history_id:
|
||||
resp = (
|
||||
service.users()
|
||||
.history()
|
||||
.list(
|
||||
userId="me",
|
||||
startHistoryId=history_id,
|
||||
historyTypes=["messageAdded"],
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
for entry in resp.get("history", []):
|
||||
for added in entry.get("messagesAdded", []):
|
||||
refs.append(ItemRef(source_msg_ref=added["message"]["id"]))
|
||||
new_history_id = resp.get("historyId", history_id)
|
||||
else:
|
||||
# First run: capture baseline history id without backfilling.
|
||||
profile = service.users().getProfile(userId="me").execute()
|
||||
new_history_id = profile["historyId"]
|
||||
|
||||
return refs, new_history_id
|
||||
|
||||
refs, new_history_id = await asyncio.to_thread(_sync)
|
||||
if new_history_id and new_history_id != scout.gmail_history_id:
|
||||
scout.gmail_history_id = new_history_id
|
||||
return refs
|
||||
|
||||
# ── fetch_metadata ────────────────────────────────────────────────────
|
||||
|
||||
async def fetch_metadata(self, scout, ref: ItemRef) -> ItemMetadata:
|
||||
"""Fetch subject, sender, snippet only — uses Gmail metadata format (no body)."""
|
||||
|
||||
def _sync() -> ItemMetadata:
|
||||
service = _get_gmail_service(scout)
|
||||
msg = (
|
||||
service.users()
|
||||
.messages()
|
||||
.get(
|
||||
userId="me",
|
||||
id=ref.source_msg_ref,
|
||||
format="metadata",
|
||||
metadataHeaders=["Subject", "From", "Date"],
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
headers = {
|
||||
h["name"]: h["value"]
|
||||
for h in msg.get("payload", {}).get("headers", [])
|
||||
}
|
||||
return ItemMetadata(
|
||||
subject=headers.get("Subject"),
|
||||
sender=headers.get("From"),
|
||||
snippet=msg.get("snippet"),
|
||||
received_at=None,
|
||||
)
|
||||
|
||||
return await asyncio.to_thread(_sync)
|
||||
|
||||
# ── fetch_content ─────────────────────────────────────────────────────
|
||||
|
||||
async def fetch_content(self, scout, ref: ItemRef) -> ItemContent:
|
||||
"""Fetch full body text for a single message — transient, must not be persisted."""
|
||||
|
||||
def _sync() -> ItemContent:
|
||||
service = _get_gmail_service(scout)
|
||||
msg = service.users().messages().get(
|
||||
userId="me", id=ref.source_msg_ref, format="full",
|
||||
).execute()
|
||||
headers = {h["name"]: h["value"] for h in msg.get("payload", {}).get("headers", [])}
|
||||
body_text = _extract_plain_text_body(msg.get("payload", {}))
|
||||
return ItemContent(
|
||||
metadata=ItemMetadata(
|
||||
subject=headers.get("Subject"),
|
||||
sender=headers.get("From"),
|
||||
snippet=msg.get("snippet"),
|
||||
received_at=None,
|
||||
),
|
||||
body_text=body_text,
|
||||
raw_headers=headers,
|
||||
)
|
||||
|
||||
return await asyncio.to_thread(_sync)
|
||||
|
||||
# ── archive ───────────────────────────────────────────────────────────
|
||||
|
||||
async def archive(self, scout, ref: ItemRef) -> None:
|
||||
"""Move the message to Gmail Trash (recoverable for 30 days)."""
|
||||
|
||||
def _sync() -> None:
|
||||
service = _get_gmail_service(scout)
|
||||
service.users().messages().trash(
|
||||
userId="me", id=ref.source_msg_ref
|
||||
).execute()
|
||||
|
||||
await asyncio.to_thread(_sync)
|
||||
|
||||
# ── watch management ──────────────────────────────────────────────────
|
||||
|
||||
async def setup_watch(self, scout) -> None:
|
||||
"""Register a Gmail Pub/Sub push watch for the INBOX label.
|
||||
|
||||
Requires ``settings.GMAIL_PUBSUB_TOPIC`` to be set to the full topic
|
||||
resource name (e.g. ``projects/my-project/topics/gmail-push``).
|
||||
Logs a warning and returns without error if the topic is not configured.
|
||||
"""
|
||||
topic = settings.GMAIL_PUBSUB_TOPIC
|
||||
if not topic:
|
||||
logger.warning(
|
||||
"setup_watch: GMAIL_PUBSUB_TOPIC is not configured — skipping watch setup"
|
||||
)
|
||||
return
|
||||
|
||||
def _sync() -> None:
|
||||
service = _get_gmail_service(scout)
|
||||
request_body = {
|
||||
"labelIds": ["INBOX"],
|
||||
"topicName": topic,
|
||||
}
|
||||
resp = service.users().watch(userId="me", body=request_body).execute()
|
||||
scout.gmail_history_id = resp.get("historyId")
|
||||
expiration_ms = resp.get("expiration")
|
||||
if expiration_ms:
|
||||
scout.gmail_watch_expires_at = datetime.fromtimestamp(
|
||||
int(expiration_ms) / 1000, tz=timezone.utc
|
||||
)
|
||||
|
||||
await asyncio.to_thread(_sync)
|
||||
|
||||
async def renew_watch(self, scout) -> None:
|
||||
"""Renew an existing Gmail Pub/Sub watch (same as setup_watch)."""
|
||||
await self.setup_watch(scout)
|
||||
|
||||
async def list_labels(self, scout) -> list[dict]:
|
||||
"""Return the account's Gmail labels as [{id, name}]. Empty if no token."""
|
||||
if not scout.oauth_token_encrypted:
|
||||
return []
|
||||
|
||||
def _sync() -> list[dict]:
|
||||
service = _get_gmail_service(scout)
|
||||
resp = service.users().labels().list(userId="me").execute()
|
||||
return [{"id": lbl["id"], "name": lbl["name"]} for lbl in resp.get("labels", [])]
|
||||
|
||||
return await asyncio.to_thread(_sync)
|
||||
|
||||
async def stop_watch(self, scout) -> None:
|
||||
"""Stop Gmail push notifications. Swallows errors (watch may be gone)."""
|
||||
if not scout.oauth_token_encrypted:
|
||||
return
|
||||
|
||||
def _sync() -> None:
|
||||
service = _get_gmail_service(scout)
|
||||
service.users().stop(userId="me").execute()
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_sync)
|
||||
except Exception:
|
||||
logger.exception("stop_watch failed for scout %s", scout.id)
|
||||
32
api/app/scouts/connectors/registry.py
Normal file
32
api/app/scouts/connectors/registry.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Connector registry — single source of truth for source_type -> connector."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
_CONNECTORS: dict[str, Any] = {}
|
||||
|
||||
|
||||
def register_connector(connector: Any) -> None:
|
||||
"""Register a SourceConnector instance under its ``source_type``.
|
||||
|
||||
Calling twice with the same ``source_type`` replaces the prior entry —
|
||||
useful for tests and hot-reload, but in production each connector
|
||||
should be registered exactly once at startup.
|
||||
"""
|
||||
if not getattr(connector, "source_type", None):
|
||||
raise ValueError("Connector must declare a non-empty source_type")
|
||||
_CONNECTORS[connector.source_type] = connector
|
||||
|
||||
|
||||
def get_connector(source_type: str) -> Any:
|
||||
"""Return the registered connector for ``source_type`` or raise KeyError."""
|
||||
try:
|
||||
return _CONNECTORS[source_type]
|
||||
except KeyError as exc:
|
||||
raise KeyError(f"No connector registered for source_type {source_type!r}") from exc
|
||||
|
||||
|
||||
def _reset_for_tests() -> None:
|
||||
"""Clear the registry — for use in pytest fixtures only."""
|
||||
_CONNECTORS.clear()
|
||||
273
api/app/scouts/engine.py
Normal file
273
api/app/scouts/engine.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""ScoutEngine — orchestrates triage, queueing, and delivery for cloud scouts.
|
||||
|
||||
Triage flow per scout:
|
||||
1. Resolve scout config from the DB.
|
||||
2. Skip if device hasn't connected within ``device_inactivity_pause_days``.
|
||||
3. Ask the connector to ``list_new`` — fresh items since last poll.
|
||||
4. For each item:
|
||||
- skip if already in the queue (idempotent on (scout_id, source_msg_ref))
|
||||
- fetch the full content via the connector (transient, never persisted)
|
||||
- run the triage LLM call → relevant | spam
|
||||
- spam + auto_trash_spam → connector.archive
|
||||
- relevant → INSERT scout_triage_queue row
|
||||
5. Update scout.last_run_at.
|
||||
|
||||
Delivery flow on Electron WS reconnect:
|
||||
- drain ``status='queued'`` rows for the user
|
||||
- fetch metadata-only for each (subject + snippet)
|
||||
- send a ``scout_proposal`` frame
|
||||
- flip status to ``delivered`` on ack
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback
|
||||
from app.core.llm import get_llm
|
||||
from app.db import async_session
|
||||
from app.models import CloudScoutConfig, ScoutTriageQueue
|
||||
from app.scouts.connectors.base import ItemContent, ItemRef, TriageVerdict
|
||||
from app.scouts.connectors.registry import get_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
QUEUE_TTL_DAYS = 30
|
||||
|
||||
|
||||
class ScoutEngine:
|
||||
def __init__(self, session_factory=None) -> None:
|
||||
self._session_factory = session_factory or async_session
|
||||
|
||||
async def trigger_scout(self, scout_id: uuid.UUID) -> None:
|
||||
async with self._session_factory() as session:
|
||||
scout = await session.get(CloudScoutConfig, str(scout_id))
|
||||
if scout is None:
|
||||
logger.warning("trigger_scout: no such scout id=%s", scout_id)
|
||||
return
|
||||
if not scout.enabled:
|
||||
return
|
||||
# Device-inactivity pause check is a simple heuristic on last_run_at —
|
||||
# the device-online signal lives in the DeviceConnectionManager and is
|
||||
# consulted at delivery time. For triage, we only check that the
|
||||
# configured pause threshold isn't suppressing the run.
|
||||
connector = get_connector(scout.provider)
|
||||
try:
|
||||
refs = await connector.list_new(scout)
|
||||
except Exception:
|
||||
logger.exception("scout %s: list_new failed", scout.id)
|
||||
return
|
||||
|
||||
for ref in refs:
|
||||
await self._process_item(session, scout, connector, ref)
|
||||
|
||||
scout.last_run_at = datetime.now(tz=timezone.utc)
|
||||
await session.commit()
|
||||
|
||||
async def _process_item(
|
||||
self,
|
||||
session,
|
||||
scout: CloudScoutConfig,
|
||||
connector,
|
||||
ref: ItemRef,
|
||||
) -> None:
|
||||
# Idempotency check
|
||||
existing = await session.execute(
|
||||
select(ScoutTriageQueue.id).where(
|
||||
ScoutTriageQueue.scout_id == scout.id,
|
||||
ScoutTriageQueue.source_msg_ref == ref.source_msg_ref,
|
||||
)
|
||||
)
|
||||
if existing.first() is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
content = await connector.fetch_content(scout, ref)
|
||||
except Exception:
|
||||
logger.exception("scout %s: fetch_content failed for %s", scout.id, ref.source_msg_ref)
|
||||
return
|
||||
|
||||
try:
|
||||
verdict = await self._triage_llm(scout, content)
|
||||
except Exception:
|
||||
logger.exception("scout %s: triage_llm failed for %s", scout.id, ref.source_msg_ref)
|
||||
return
|
||||
|
||||
if verdict.verdict == "spam":
|
||||
if scout.auto_trash_spam:
|
||||
try:
|
||||
await connector.archive(scout, ref)
|
||||
except Exception:
|
||||
logger.exception("scout %s: archive failed for %s", scout.id, ref.source_msg_ref)
|
||||
return
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
row = ScoutTriageQueue(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=scout.user_id,
|
||||
scout_id=scout.id,
|
||||
source_type=connector.source_type,
|
||||
source_msg_ref=ref.source_msg_ref,
|
||||
triage_verdict=verdict.verdict,
|
||||
triage_reason=verdict.reason,
|
||||
status="queued",
|
||||
triaged_at=now,
|
||||
expires_at=now + timedelta(days=QUEUE_TTL_DAYS),
|
||||
)
|
||||
session.add(row)
|
||||
try:
|
||||
# Use a savepoint so an IntegrityError on race doesn't poison the
|
||||
# outer session — works on both PostgreSQL (SAVEPOINT) and SQLite.
|
||||
async with session.begin_nested():
|
||||
await session.flush()
|
||||
except IntegrityError:
|
||||
# Race: another worker inserted between our SELECT and INSERT.
|
||||
# The unique constraint did its job; safe to ignore.
|
||||
logger.debug(
|
||||
"scout %s: idempotent skip for %s (race on unique constraint)",
|
||||
scout.id,
|
||||
ref.source_msg_ref,
|
||||
)
|
||||
|
||||
async def deliver_pending(self, user_id: uuid.UUID, ws) -> None:
|
||||
"""Drain status='queued' rows for user, send scout_proposal WS frames, flip to 'delivered'."""
|
||||
from app.scouts.connectors.base import ItemRef # noqa: PLC0415
|
||||
async with self._session_factory() as session:
|
||||
rows = (await session.execute(
|
||||
select(ScoutTriageQueue).where(
|
||||
ScoutTriageQueue.user_id == str(user_id),
|
||||
ScoutTriageQueue.status == "queued",
|
||||
)
|
||||
)).scalars().all()
|
||||
logger.info("deliver_pending: user=%s found %d queued rows", user_id, len(rows))
|
||||
|
||||
for row in rows:
|
||||
try:
|
||||
connector = get_connector(row.source_type)
|
||||
except KeyError:
|
||||
logger.warning("deliver_pending: no connector for %s", row.source_type)
|
||||
continue
|
||||
scout = await session.get(CloudScoutConfig, row.scout_id)
|
||||
if scout is None:
|
||||
continue
|
||||
try:
|
||||
meta = await connector.fetch_metadata(scout, ItemRef(source_msg_ref=row.source_msg_ref))
|
||||
except Exception:
|
||||
logger.exception("deliver_pending: fetch_metadata failed")
|
||||
continue
|
||||
|
||||
payload = {
|
||||
"type": "scout_proposal",
|
||||
"proposal": {
|
||||
"id": row.id,
|
||||
"scout_id": row.scout_id,
|
||||
"source_type": row.source_type,
|
||||
"source_msg_ref": row.source_msg_ref,
|
||||
"raw_subject": meta.subject,
|
||||
"raw_snippet": meta.snippet,
|
||||
"category": "unprocessed",
|
||||
"payload": None,
|
||||
},
|
||||
}
|
||||
logger.info("deliver_pending: sending proposal id=%s subject=%r", row.id, meta.subject)
|
||||
await ws.send_json(payload)
|
||||
logger.info("deliver_pending: send_json returned for proposal id=%s", row.id)
|
||||
row.status = "delivered"
|
||||
row.delivered_at = datetime.now(tz=timezone.utc)
|
||||
|
||||
await session.commit()
|
||||
|
||||
async def ack_proposal(self, proposal_id: str) -> None:
|
||||
"""Flip a delivered proposal to acked. Idempotent — no-op if already acked."""
|
||||
async with self._session_factory() as session:
|
||||
row = await session.get(ScoutTriageQueue, proposal_id)
|
||||
if row is None:
|
||||
return
|
||||
row.status = "acked"
|
||||
row.acked_at = datetime.now(tz=timezone.utc)
|
||||
await session.commit()
|
||||
|
||||
async def _triage_llm(self, scout: CloudScoutConfig, content: ItemContent) -> TriageVerdict:
|
||||
"""Call the scout-triage-system Langfuse prompt to classify an item as relevant or spam.
|
||||
|
||||
Uses gpt-4o-mini with JSON mode. Wraps the LLM call in a Langfuse generation
|
||||
observation when Langfuse is configured.
|
||||
"""
|
||||
import json # noqa: PLC0415
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
|
||||
_TRIAGE_FALLBACK = (
|
||||
"You are a triage classifier for an executive-assistant scout that watches a "
|
||||
"{source_type} feed.\n"
|
||||
'The scout\'s purpose is: "{scout_purpose}".\n\n'
|
||||
"Given one item, decide whether it is RELEVANT (worth surfacing to the user as a "
|
||||
"potential task / event / note / project) or SPAM (advertising, mass marketing, "
|
||||
"phishing, bulk notifications with no actionable content).\n\n"
|
||||
"Item:\n"
|
||||
" - Subject: {item_subject}\n"
|
||||
" - From: {item_sender}\n"
|
||||
" - Body (truncated): {item_body_truncated_2k}\n\n"
|
||||
'Return JSON only, matching this schema:\n'
|
||||
' {{"verdict": "relevant" | "spam", "reason": <short string>, "confidence": <0..1>}}\n\n'
|
||||
"Be conservative on \"spam\" — if a message could plausibly be a personal/work "
|
||||
"email, mark it relevant."
|
||||
)
|
||||
|
||||
template, prompt_obj = get_prompt_or_fallback("scout-triage-system", _TRIAGE_FALLBACK)
|
||||
|
||||
body_trunc = (content.body_text or "")[:2000]
|
||||
variables = dict(
|
||||
source_type=scout.provider,
|
||||
scout_purpose=scout.prompt_template or "",
|
||||
item_subject=content.metadata.subject or "",
|
||||
item_sender=content.metadata.sender or "",
|
||||
item_body_truncated_2k=body_trunc,
|
||||
)
|
||||
|
||||
if prompt_obj is not None:
|
||||
try:
|
||||
system_text = prompt_obj.compile(**variables)
|
||||
if isinstance(system_text, list):
|
||||
system_text = "\n".join(
|
||||
m.get("content", "") for m in system_text if isinstance(m, dict)
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("scout triage: compile failed: %s", exc)
|
||||
system_text = template.replace("{{source_type}}", variables["source_type"]) \
|
||||
.replace("{{scout_purpose}}", variables["scout_purpose"]) \
|
||||
.replace("{{item_subject}}", variables["item_subject"]) \
|
||||
.replace("{{item_sender}}", variables["item_sender"]) \
|
||||
.replace("{{item_body_truncated_2k}}", variables["item_body_truncated_2k"])
|
||||
else:
|
||||
system_text = template.format(**variables)
|
||||
|
||||
llm = get_llm(model="gpt-4o-mini", temperature=0)
|
||||
llm_json = llm.bind(response_format={"type": "json_object"}) # type: ignore[attr-defined]
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Classify this item."),
|
||||
]
|
||||
|
||||
lf = get_langfuse()
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="scout-triage",
|
||||
model="gpt-4o-mini",
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm_json.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm_json.ainvoke(messages)
|
||||
|
||||
data = json.loads(response.content)
|
||||
return TriageVerdict(**data)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user