77 Commits

Author SHA1 Message Date
Roberto Musso
cc94194fd1 update app name 2026-04-08 23:27:34 +02:00
Roberto Musso
96c91e386d remove deprecated docs 2026-04-08 23:23:14 +02:00
Roberto Musso
c0aef71141 refactor(tests): remove non-deterministic journey eval cases 4.2–4.5
Keep only 4.1 (first reply contains question) as automated eval.
Multi-turn cases (4.2–4.5) are non-deterministic and tested manually
with results tracked in Langfuse.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 09:41:43 +02:00
Roberto Musso
467abc8d42 Merge branch 'develop' into feature/batch-agent-v2 2026-04-08 00:48:23 +02:00
Roberto Musso
5753f8def9 refactor: remove storage, backup, plugin/marketplace features
- Delete app/storage/ (blob_store, vector_store, encryption)
- Delete app/marketplace/ (plugin_registry, plugin_review, revenue_share)
- Delete routes: backup.py, plugins.py, storage.py, vectors.py
- Relocate embed endpoint to POST /chat/embed
- Rewrite migration 001 (remove storage/plugin tables)
- Delete migration 002 (seed_plugins)
- Remove S3/Pinecone/Qdrant env vars from settings
- Remove storage/backup quotas from tier_manager
- Remove MinIO and Qdrant from docker-compose
- Delete tests: test_backup, test_plugins, test_storage
- Update README.md and clean .env.example
2026-04-08 00:47:37 +02:00
Roberto Musso
e672b58b6f fix(langfuse): remove invalid user_id/session_id kwargs from start_as_current_observation
Langfuse V3 does not accept user_id/session_id on observation-level calls.
Moved to metadata dict in agent_runner, deep_agent, and agent_setup.

refactor(tests): fixture-based pattern for agent_runner_v2 eval tests

- cases.yaml + data/ fixtures under tests/fixtures/agent_runner_v2/
- pytest_generate_tests parametrizes test_eval_runner from YAML
- _resolve_projects() handles symbolic names and inline dicts
- _evaluate_case() centralizes all assertion logic
- --runner-dir CLI option for custom fixture folders

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 00:45:15 +02:00
Roberto Musso
d8add7e8cb feat(local-agent-v2): step 4 — journey produces structured AgentConfig JSON
Replace freeform prompt_template output with validated AgentConfig JSON:
- agent_setup.py: new system prompt (journey_system_v2), AGENT_CONFIG_START/END
  markers, _extract_agent_config() with Pydantic validation, updated handlers
  returning agent_config key; import AgentConfig from schemas
- tests/test_journey_v2.py: 6 unit tests + 5 parametrized LLM eval cases
  following test_agent_runner_v2.py pattern; _run_journey uses
  set_client_executor/clear_client_executor mirroring device_ws
- tests/fixtures/journey_v2/: cases.yaml + email_action.html + email_info.html
- tests/conftest.py: add --journey-dir CLI option; remove S3/plugin fixtures
  (cleanup from microservices migration, already present in working tree)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 00:23:58 +02:00
Roberto Musso
c6c4578f9a fix(tests): migrate eval tests to Langfuse V3 API
lf.trace() and lf.score(trace_id=...) are V2 API removed in V3.

V3 pattern:
  lf.start_as_current_observation(name=...) as context manager → obs
  obs.score(name=..., value=...)
  contextlib.nullcontext() when lf is None so structure stays the same

Updated tests 2.1–2.7 in test_agent_runner_v2.py accordingly.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 23:04:24 +02:00
Roberto Musso
3aa0b36a6c fix(langfuse): use compile() instead of .format() for prompt variable injection
Langfuse uses {{variable}} syntax in its prompt management UI, while the
hardcoded fallbacks use {variable} (Python str.format). The previous code
always called .format() which silently failed/errored when a real Langfuse
prompt was fetched.

- langfuse_client.py: add compile_prompt(template, prompt_obj, **vars)
  → uses prompt_obj.compile(**vars) when Langfuse is available
  → falls back to template.format(**vars) when using the hardcoded fallback
- agent_runner.py: replace .format() with compile_prompt() for
  unified_processing (V2 local) and batch_cloud_processing (cloud agent)
- agent_setup.py: replace .format() with compile_prompt() for journey_system

deep_agent.py prompts have no variables, so no change needed there.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 16:49:26 +02:00
Roberto Musso
fa231a3642 feat(local-agent-v2): step 2+3 — unified runner + AgentConfig schema
Step 3 (prerequisite):
- app/schemas.py: add ContentTypeConfig + AgentConfig Pydantic models
- app/models.py: add agent_config (JSON, nullable) to LocalAgentConfig
- alembic migration a3b9c0d1e2f3: ADD COLUMN agent_config

Step 2 (runner refactor):
- Remove _classify_file() and _BATCH_FILE_CLASSIFIER_PROMPT (LLM classification step)
- Add Phase A: detect_content_type + preprocess (zero LLM, per file)
- Add _UNIFIED_PROCESSING_PROMPT (hot-swappable via Langfuse "unified_processing")
- Add helper functions: _format_projects, _format_metadata, _get_extraction_rules,
  _get_no_match_behavior
- Single LLM call per file with tools (classify + extract + create)
- Fix items_created: count create_* tool calls via _tool_calls_out param
- test_agent_runner_v2.py: 10 cases (2.1-2.10) with Langfuse eval scoring

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 15:00:32 +02:00
Roberto Musso
d91c98f86d chore(tests): remove Langfuse from all preprocessor tests
I test del preprocessor sono deterministici — nessun LLM coinvolto,
nessuno score da tracciare.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 14:26:33 +02:00
Roberto Musso
c0619f5c4d fix(tests): move pytest_addoption after __future__ import in conftest
SyntaxError: from __future__ imports must occur at the beginning of the file.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 14:21:50 +02:00
Roberto Musso
da282229ff refactor(tests): remove redundant filename field
file: serve sia come path da leggere che come nome passato a detect_content_type.
Non c'è motivo di averli separati.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 14:13:14 +02:00
Roberto Musso
7fa6ad5760 feat(tests): add --preprocess-dir CLI option to pytest
- conftest.py: registra --preprocess-dir via pytest_addoption
- test_preprocessors.py: usa pytest_generate_tests per leggere i casi
  a collection time con accesso a config; _content e _fixtures_dir
  accettano path dinamico

Usage: pytest tests/test_preprocessors.py --preprocess-dir /my/folder

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 13:59:32 +02:00
Roberto Musso
dcd14220ca refactor(tests): simplify YAML fixture schema and test runner
YAML: rimosse op/description/score_name/assertions block — ora detect/process
come chiave diretta, assertions piatte sullo stesso livello del caso.

Runner: eliminato _run_assertions engine, assertions inline in test_preprocess.
Riduzione da ~170 a ~75 righe totali tra YAML + test.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 11:30:38 +02:00
Roberto Musso
3cc32569d9 chore(tests): remove Langfuse scoring from preprocess tests
Scoring is only meaningful for LLM-backed steps. Preprocess tests are
deterministic Python, so scores add no value. Kept only for detect tests.

- test_preprocess: drop _lf_score call, simplify _run_assertions return type
- cases.yaml: remove score_name from all op=preprocess entries

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 11:21:42 +02:00
Roberto Musso
bf445ac2ce refactor(tests): YAML-driven fixtures for preprocessor tests
- cases.yaml: 10 test cases con schema dichiarativo (op, assertions)
- data/: 7 file reali (email_action.html, email_thread.html, email_single.html,
  email_heavy.html, generic_page.html, notes.txt, fallback.txt)
- test_preprocessors.py: parametrize da YAML via test_detect / test_preprocess;
  assertion engine generico (no_html_tags, min_length, compression_ratio,
  metadata_keys, contains, not_contains, content_type)
- requirements.txt: add PyYAML

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 10:44:41 +02:00
Roberto Musso
a2d6d689e4 feat: add preprocessor system (Step 1 — Local Agent V2)
- app/core/preprocessors/__init__.py: detect_content_type + preprocess dispatcher
- app/core/preprocessors/base.py: PreprocessResult dataclass
- app/core/preprocessors/email_html.py: BeautifulSoup HTML stripping, metadata extraction, thread splitting
- requirements.txt: add beautifulsoup4 and lxml
- tests/test_preprocessors.py: 10 tests with Langfuse scoring (preprocess.* scores)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 10:19:02 +02:00
Roberto Musso
aa8bcbf0d8 Refactor system prompt variables for clarity and consistency across agent setup and runner modules 2026-04-07 00:23:41 +02:00
Roberto Musso
1ce1d492b0 Add Langfuse observability: traces, prompt management, prompt-to-generation linking
- New app/core/langfuse_client.py: lazy singleton client, get_prompt_or_fallback()
  helper (returns raw template + prompt obj for linking), extract_usage() for token
  counts. No-ops when LANGFUSE_* env vars are not set.
- deep_agent.py: home-agent and floating-agent runs wrapped in spans; each ainvoke
  wrapped in a generation with model/input/output/usage; prompts fetched from
  Langfuse (adiuva-home-agent, adiuva-floating-agent, adiuva-floating-classifier)
  with hardcoded fallback.
- agent_runner.py: step1-classifier and step2-processor LLM calls traced; batch
  agent _run_agent_with_tools spans + generations; cloud-processor included.
  Prompts: adiuva-step1-classifier, adiuva-step2-processor, adiuva-cloud-processor.
- agent_setup.py: journey-setup span + generation per ainvoke; prompt_obj stored
  on JourneySession and reused across turns. Prompt: journey_system.
- settings.py: LANGFUSE_SECRET_KEY, LANGFUSE_PUBLIC_KEY, LANGFUSE_HOST added.
- .env.example: Langfuse section with EU/US/self-hosted host comments.
- requirements.txt: langfuse>=2.0.0.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 00:19:20 +02:00
Roberto Musso
552b8eb305 Fix project creation: code-based in runner, not delegated to Step 2 LLM
Root causes fixed:
1. PROJECT_TOOLS removed from Step 2 tool set — project assignment is now
   exclusively handled by the runner in code, never by the LLM.
2. When Step 1 returns "new", runner calls execute_on_client insert/projects
   directly (before Step 2), gets the created id, and passes it as context.
3. Newly created projects are appended to the local `projects` list so that
   subsequent files in the same run can match to them via Step 1 — prevents
   one project per file when multiple files share the same topic.

Also add tests/test_classify_file.py with pytest cases for _classify_file
and a CLI runner: python -m tests.test_classify_file <file> [project...]

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-21 23:40:38 +01:00
Roberto Musso
0d93b3960d Exclude project/projectId questions from agent setup journey
- Add explicit MUST NOT instruction: never ask about projects, projectId,
  or how to link records; project assignment is handled by the agent runner
- Remove projectId from template field list; remove projects from entity types
- Remove stale isApproved=0 reference (already removed from the data model)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-21 22:58:05 +01:00
Roberto Musso
f07580574b Replace max_turns cap with 90% confidence stopping criterion in agent setup
- Remove fixed _MAX_TURNS=5 instruction from system prompt; LLM now decides
  when to stop based on self-assessed confidence (>= 90%)
- Add _MIN_TURNS_BEFORE_NUDGE=3 and raise safety cap to _MAX_TURNS=15
- Nudge message and hard cap still act as a safety net for infinite loops

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-21 22:54:34 +01:00
Roberto Musso
1a8bf11f90 update migration plan 2026-03-20 23:48:36 +01:00
Roberto Musso
e7cdce8287 Improve Step 1 project matching and Step 2 update-first enforcement
- Rewrite _STEP1_SYSTEM_PROMPT: lower matching threshold (no longer requires
  "clear" match), strongly prefer existing projects over creating new ones,
  use structured id=|name=|status= format with aiSummary for richer context
- Add code-level UUID validation: reject hallucinated ids not in the fetched
  projects list, fall back to "new" instead of creating a bad link
- Rewrite _PROCESSING_SYSTEM_PROMPT: enforce explicit scan-before-create
  process (read existing → search → update if found → create only if not)
  with hard rule against calling create_* without checking existing records

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-20 23:45:29 +01:00
Roberto Musso
58bc6efd4b Rewrite run_local_agent: code-based flow, concurrency guard, remove isApproved
- Replace LLM-driven triage with code-based directory scan and project fetch
- Two-step LLM approach: Step 1 classifies file→project+domains, Step 2 processes with tools
- Add domain descriptions to Step 1 prompt for better extraction accuracy
- Add _running_agents set for per-agent concurrency guard (one running instance per agent)
- Return 409 from route before DB write when agent already running
- Remove is_approved from task_agent create/update tools and system prompt
- Remove is_approved from timeline_agent create/update tools and system prompt

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-20 22:21:30 +01:00
Roberto Musso
6c450805cb possibile evoluzione 2026-03-20 20:57:03 +01:00
Roberto Musso
f340d0fa3e Fix dev tier: default to power when no subscription exists
The tier is resolved live from the subscriptions table in get_current_user.
Previously fell back to 'free' unconditionally, hitting the 5 runs/day
limit immediately in dev. Now falls back to 'power' (unlimited) when
ENV=dev and no subscription row exists.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-20 12:32:36 +01:00
Roberto Musso
edc53cb6eb Default to power tier (unlimited) in dev when no subscription exists
Users without a subscription row in dev get power tier so rate limits
and quota checks don't block local development. In prod the fallback
remains free tier as before.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-20 12:12:43 +01:00
Roberto Musso
725cece5c1 Add run_context to agent tool calls for FE run logging
- AgentTriggerRequest accepts optional agent_id (FE's stable electron-store UUID)
- _make_agent_executor injects run_context into every tool_call frame
  so Electron can attribute actions to the correct agent run
- run_local_agent accepts run_context and sends a run_complete WS frame
  when the run finishes so the FE can close the run record
- trigger_agent_run builds run_context with run_id=run_log.id and the
  stable agent_id, passes it through to run_local_agent

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-20 09:46:17 +01:00
Roberto Musso
297e20ce8d Fix 422 on agent trigger: accept plural data type names
AgentTriggerRequest.what_to_extract now accepts list[str] instead of
strict Literal values. _to_data_types normalises all FE variants
(tasks/task, notes/note, timelines/timeline/timelineEvents,
projects/project) to the canonical plural form the runner expects,
with deduplication.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-18 00:04:29 +01:00
Roberto Musso
5a03bd1cfb Clean up agent catalog and improve extraction agent prompts
- Remove unused config_schema from AgentCatalogItem (schema + route)
- Fix agent_setup system prompt: add extraction agent base behaviour
  context so journey LLM knows what is already handled and focuses on
  field mappings only; remove redundant data-types question (already
  known from user selection); derive data types list dynamically
- Rewrite processing base prompt to use actual tool names
  (list_tasks, update_task, add_task_comment, list_notes, update_note,
  list_timelines, update_timeline, list_all_projects, create_project)
  and enforce update-first strategy before falling back to creation

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 23:52:54 +01:00
Roberto Musso
87b7a1c6c9 fix journey setup: honor FE session_id, seed LLM history, and force template on max turns
- Use session_id from the FE frame so replies match the listener key
- Seed conversation with a user message for LLM provider compatibility
- On max turns, nudge the LLM and immediately re-invoke to force
  prompt_template generation instead of deferring to next message
- Fix display_message extraction to safely check for template markers

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-17 16:25:53 +01:00
Roberto Musso
826f64d6bb refactor local directory agent to two-phase LLM-with-tools architecture
Replace the single-pass FE-driven agent_run/agent_data flow with a
BE-orchestrated two-phase execution using LangChain tool-calling:
- Phase 1 (Triage): explores directory via new filesystem tools, matches
  files to existing projects using PROJECT_TOOLS
- Phase 2 (Processing): reads files and performs CRUD per project group
  with clean LLM context windows

Key changes:
- Add filesystem_agent.py with list_directory, read_file_content,
  get_file_metadata tools using execute_on_client()
- Move setup journey from REST to WebSocket (journey_start/message frames)
- Add batch_runs_per_day billing limit and enforce in /trigger
- Remove deprecated agent_data/agent_complete frame handlers and queues

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-17 08:50:46 +01:00
5faa6b1d7c refactor agents to client-owned config flow 2026-03-16 22:35:46 +01:00
02a9684cd6 scope episodic memory enrichment by session_id 2026-03-16 00:33:11 +01:00
fae9efee0d removed old plan files 2026-03-13 16:58:43 +01:00
30b062dd4a fix floating stream empty responses with sanitizer-safe fallbacks 2026-03-13 16:57:30 +01:00
2a0331d7ce refactor floating_domain to structured object-only payload 2026-03-13 16:09:24 +01:00
13fd8677c1 fix: normalize home task/timeline responses to tag-only lines 2026-03-13 12:16:58 +01:00
9bd629cb59 chore: add interaction tracing and remove personal fields from logs 2026-03-13 10:23:47 +01:00
9c97702daa feat: add letta-style memory tools with request/user debug tracing 2026-03-13 09:34:23 +01:00
a1e364c9c0 refactor: switch to single-agent deep runner and add mock memory/tool tests 2026-03-13 08:20:42 +01:00
5b55f1292a make a single agent 2026-03-13 07:42:36 +01:00
5bc9ea6cd6 fix: make planner schema copilot-compatible and silence usage warning 2026-03-12 23:17:31 +01:00
f7404b6f66 refactor: move memory updates from synthesizer to orchestrator node 2026-03-12 23:03:38 +01:00
d667e43c73 refactor: use native LangGraph streaming and enforce structured summary on workers 2026-03-12 22:50:32 +01:00
fe085a7951 feat: migrate chat orchestration to deep langgraph workers 2026-03-12 22:25:36 +01:00
2de67213f8 rename from checkpoint to timeline agent 2026-03-10 23:17:38 +01:00
f6ed383b3a add user name and surname 2026-03-10 16:14:00 +01:00
9332e29e53 bug fix sending component 2026-03-10 09:11:24 +01:00
618076193a update alembic 2026-03-08 23:17:01 +01:00
34f01234c9 rename popup chat to floating chat 2026-03-08 22:53:31 +01:00
0bd46937d3 fix: add missing json imports and update agent tool tests
Code bugs fixed:
- checkpoint_agent.py, project_agent.py, note_agent.py: add missing
  'import json' (used in handle() for context serialization)

Test fixes:
- test_agents.py: add autouse ws_executor fixture that sets a fake
  execute_on_client so tools can run in unit tests without a WS session
- Rewrite all TestXxxAgentTools tests: patch execute_on_client per-test,
  assert on call_args (what payload was sent to the client) and on the
  formatted string return value — matching actual tool behavior

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-08 22:25:06 +01:00
e6b5bc2e7d step-7: add memory middleware (memory_middleware.py, device_ws.py)
MemoryMiddleware class:
- enrich_context(): loads core prefs, associative (top-k), episodic (last-N),
  and proactive hints (above 0.6 confidence) — all decrypted in-memory only
- store_episode(): encrypts and persists interaction summary to memory_episodic
- update_core(): upserts encrypted key/value to memory_core

device_ws.py home_request + popup_request handlers:
- enrich_context() called before orchestrate_v3_stream (memory injected into context)
- store_episode() called after stream completes (non-blocking)

10 unit + integration tests pass; pre-existing test_agents.py failures unrelated.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-08 22:14:28 +01:00
c90ed58078 step-6: add memory models and migration (models.py, alembic)
- User.encryption_key: per-user Fernet key generated on registration
- MemoryCore: encrypted key/value preferences
- MemoryAssociative: encrypted semantic memory + pgvector(1536) embedding
- MemoryEpisodic: encrypted session summaries
- MemoryProactive: encrypted behavioral patterns with confidence score
- Migration 004: enables pgvector extension, creates all 4 tables + ivfflat index
- auth.py register: generates Fernet key for new users
- 8 unit tests pass (SQLite in-memory, JSON embedding fallback)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-08 22:05:58 +01:00
76c8f2bdad step-5: unify ws handler (device_ws.py, chat.py)
- device_ws.py: dispatch home_request/popup_request to HomeFormatter/PopupFormatter
  via async tasks; each request gets a UUID request_id for frame correlation
- chat.py: remove chat_stream WS endpoint (superseded by unified device WS);
  keep POST /chat REST fallback unchanged
- 5 new integration tests pass; all 22 existing device_ws tests still pass

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-08 22:01:11 +01:00
393b3befd6 step-4: add output formatting layer (output_formatter.py)
HomeFormatter parses JSON block stream from orchestrator tokens and emits
stream_start / stream_text / stream_block / stream_end frames.
PopupFormatter emits popup_domain then plain stream_text.
All 13 unit tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-08 21:51:20 +01:00
2c08275934 step-3: add router refactor with streaming support (orchestrator.py)
- orchestrate_v3(user_id, message, context): classifies intent, returns
  (agent_name, agent_instance) — caller drives execution
- orchestrate_v3_stream(user_id, message, context): yields (agent_name, token)
  pairs; first yield is always (agent_name, "") as a domain-detection signal
- ChatAgent.handle_stream(): default implementation yields handle() result as
  one chunk; subclasses override for true token-level streaming
- Fix stale test_orchestrator.py assertions that expected a JSON final frame
  that orchestrate_stream never emitted

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-08 21:42:46 +01:00
7cb384fa63 step-2: add agent streaming and tool result capture (agent_registry.py)
- ChatAgent.__init__: adds tool_results: list[dict] = []
- _tool_loop: wraps execution in a result collector; populates
  self.tool_results with raw execute_on_client dicts after each run
- _tool_loop_stream: streaming variant — uses ainvoke for tool-call
  iterations, llm.astream() for the final answer; same result capture
- ws_context.py: adds _tool_result_collector ContextVar +
  set/clear helpers; execute_on_client appends to collector when set

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-08 21:37:15 +01:00
7efaeba283 chore: migrate Settings to Pydantic v2 ConfigDict
Replace deprecated Pydantic v1 `class Config:` inner class with
`model_config = SettingsConfigDict(...)` to eliminate the deprecation
warning emitted on every test run.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-08 21:25:45 +01:00
b61ded8458 step-1: add v3 ws frame protocol (schemas.py)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-08 21:21:03 +01:00
ac71d99f9a add cerebras models 2026-03-08 00:53:25 +01:00
3b3b3baf25 update memory implementation strategy 2026-03-08 00:47:24 +01:00
45415bb9ee Update plan 2026-03-05 23:54:45 +01:00
a775a2da18 feat(step-3.6): cloud provider integrations (Gmail, Outlook, Teams)
- Add app/integrations/__init__.py: Fernet token encryption helpers,
  EmailMessage/ChatMessage dataclasses, get_provider() factory
- Add app/integrations/gmail.py: GmailClient with async fetch_messages(),
  token refresh, configurable label/sender/date filters
- Add app/integrations/ms_graph.py: MSGraphClient with fetch_emails()
  (Outlook) and fetch_messages() (Teams), MSAL token refresh, OData filters
- Update app/core/agent_runner.py: replace run_cloud_agent() stub with
  full 8-step implementation; extend _finalize_run() for cloud config type
- Update app/config/settings.py: add OAuth + Fernet encryption settings
- Update requirements.txt: google-api-python-client, google-auth-*,
  msal, cryptography
- Add tests/test_integrations.py: 47 tests covering all integration code
- Update tests/test_agent_runner.py: replace stub test with 7 real tests

All 76 new/updated tests pass.
2026-03-05 18:05:07 +01:00
24772f2b67 step 3.5 complete: chatbot journey endpoint 2026-03-05 17:35:37 +01:00
fd1396a710 update plan 2026-03-05 16:15:24 +01:00
914f70bd85 step 3.4 complete: agent run orchestrator — local/cloud runner + trigger_pending_runs + 23 tests 2026-03-05 16:13:21 +01:00
608d6c784f step 3.3 complete: device WS endpoint + DeviceConnectionManager 2026-03-05 15:51:58 +01:00
19ad5be97f step 3.2 complete: agent CRUD API routes
- Add app/api/routes/agents.py with 11 endpoints:
  GET/POST/PUT/DELETE /agents/local (local directory agent configs)
  GET/POST/PUT/DELETE /agents/cloud (cloud connector agent configs)
  GET /agents/catalog (hardcoded agent type catalog)
  GET /agents/runs (paginated run logs with agent_id/page/limit filters)
  POST /agents/{id}/run (manual trigger stub, dispatch wired in step 3.4)
- Tier-gate creation via combined local+cloud batch_active limit
- Ownership checks on all mutations (404 on mismatch)
- Cascade delete of run logs via SQLAlchemy relationship
- Register agents router in app/main.py
- Fix missing import json in app/agents/task_agent.py
2026-03-05 15:33:53 +01:00
1dfd088e18 step 3.1 complete: agent config tables + schemas + migration 2026-03-05 15:14:43 +01:00
c6e1e4e7fd fix: migration enum creation — use DO/EXCEPTION instead of broken checkfirst 2026-03-05 00:24:31 +01:00
cc603aba06 step B.6 complete: POST /api/v1/storage/vectors/embed endpoint 2026-03-05 00:07:06 +01:00
6d9a16e513 steps B.3/B.4/B.5 complete: bidirectional WS handler, _tool_loop verified, clean final frame 2026-03-05 00:06:11 +01:00
27c087d5d8 step B.2 complete: all 23 tools use execute_on_client(); add embed() to llm 2026-03-05 00:03:01 +01:00
rmusso
4d7fd519c5 step B.1 complete: WS context + frame schemas 2026-03-04 23:59:31 +01:00
106 changed files with 11557 additions and 6566 deletions

View File

@@ -2,7 +2,7 @@
ENV=dev
# ── Database ──────────────────────────────────────────────────────────────────
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai
# ── Auth ──────────────────────────────────────────────────────────────────────
JWT_SECRET=replace-with-a-long-random-secret
@@ -23,21 +23,13 @@ LLM_ROUTER_MODEL=gpt-4o-mini
STRIPE_SECRET_KEY=
STRIPE_WEBHOOK_SECRET=
# ── AWS / S3 ──────────────────────────────────────────────────────────────────
S3_BUCKET=adiuva
S3_REGION=us-east-1
S3_ENDPOINT_URL=
AWS_ACCESS_KEY_ID=
AWS_SECRET_ACCESS_KEY=
# For MinIO (homelab): S3_ENDPOINT_URL=http://minio:9000
# ── Vector Store ──────────────────────────────────────────────────────────────
# Pinecone is used when PINECONE_API_KEY is set; otherwise falls back to Qdrant.
PINECONE_API_KEY=
PINECONE_INDEX=adiuva
QDRANT_URL=
QDRANT_API_KEY=
# For local Qdrant (homelab): QDRANT_URL=http://qdrant:6333
# ── Langfuse (leave empty to disable observability) ───────────────────────────
LANGFUSE_SECRET_KEY=
LANGFUSE_PUBLIC_KEY=
# LANGFUSE_HOST=https://cloud.langfuse.com # EU (default)
# LANGFUSE_HOST=https://us.cloud.langfuse.com # US
# LANGFUSE_HOST=http://localhost:3000 # Self-hosted
# ── CORS ──────────────────────────────────────────────────────────────────────
# Comma-separated list parsed by Settings (override default if needed)

View File

@@ -48,23 +48,23 @@ jobs:
key: ${{ secrets.SSH_KEY }}
script: |
set -e
DEPLOY_DIR="/opt/adiuva-api"
DEPLOY_DIR="/opt/adiuvai-api"
REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git"
TAG="${{ gitea.ref_name }}"
# ── Pull latest code ──
cd /tmp && rm -rf adiuva-api-deploy
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuva-api-deploy
cd /tmp && rm -rf adiuvai-api-deploy
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuvai-api-deploy
# ── Sync source (preserve .env) ──
cp -rf /tmp/adiuva-api-deploy/app/ \
/tmp/adiuva-api-deploy/alembic/ \
/tmp/adiuva-api-deploy/alembic.ini \
/tmp/adiuva-api-deploy/Dockerfile \
/tmp/adiuva-api-deploy/docker-compose.yml \
/tmp/adiuva-api-deploy/requirements.txt \
cp -rf /tmp/adiuvai-api-deploy/app/ \
/tmp/adiuvai-api-deploy/alembic/ \
/tmp/adiuvai-api-deploy/alembic.ini \
/tmp/adiuvai-api-deploy/Dockerfile \
/tmp/adiuvai-api-deploy/docker-compose.yml \
/tmp/adiuvai-api-deploy/requirements.txt \
"$DEPLOY_DIR/"
rm -rf /tmp/adiuva-api-deploy
rm -rf /tmp/adiuvai-api-deploy
# ── Verify .env ──
if [ ! -f "$DEPLOY_DIR/.env" ]; then

View File

@@ -58,7 +58,7 @@ jobs:
- uses: actions/checkout@v4
- name: Build image
run: docker build -t adiuva-api:ci .
run: docker build -t adiuvai-api:ci .
- name: Verify gunicorn installed
run: docker run --rm adiuva-api:ci gunicorn --version
run: docker run --rm adiuvai-api:ci gunicorn --version

2
.gitignore vendored
View File

@@ -21,6 +21,7 @@ env/
.pytest_cache/
htmlcov/
.coverage
tests/fixtures/private*/
# Docker
*.log
@@ -31,3 +32,4 @@ Thumbs.db
# Claude Code
.claude/
logs/

View File

@@ -1,533 +0,0 @@
# Backend Plan — Adiuva Cloud API
> **Separate repository.** This document defines the FastAPI backend that the Electron app communicates with.
>
> The backend owns: orchestration logic, chat agent intelligence, prompt IP, auth, billing, E2E backup blob storage, cloud storage (encrypted blobs), cloud vector store, and plugin marketplace.
> The backend NEVER persists user data in plaintext. Cloud storage blobs are E2E encrypted before upload — the backend only verifies integrity, never decrypts.
---
## Project Structure
```
adiuva-api/
├── app/
│ ├── __init__.py
│ ├── main.py # FastAPI entry + CORS + lifespan + router includes
│ ├── core/
│ │ ├── __init__.py
│ │ ├── agent_registry.py # Base classes + singleton registry
│ │ ├── orchestrator.py # LLM-based intent router
│ │ ├── execution_plan.py # Plan builder + cache
│ │ └── plugin_loader.py # Dynamic agent loading
│ ├── agents/ # Chat agents (proprietary logic + prompts)
│ │ ├── __init__.py # Auto-registers all agents
│ │ ├── task_agent.py
│ │ ├── calendar_agent.py
│ │ ├── email_agent.py
│ │ └── analytics_agent.py
│ ├── api/
│ │ ├── __init__.py
│ │ ├── routes/
│ │ │ ├── __init__.py
│ │ │ ├── chat.py # POST /chat + WS /chat/stream
│ │ │ ├── plans.py # GET /plans/playbook
│ │ │ ├── storage.py # CRUD cloud storage (E2E encrypted blobs)
│ │ │ ├── vectors.py # Upsert/search cloud vector store
│ │ │ ├── backup.py # PUT/GET /backup
│ │ │ ├── plugins.py # Plugin marketplace
│ │ │ ├── auth.py # Register/login/refresh
│ │ │ └── billing.py # Checkout/webhook/subscription
│ │ └── middleware/
│ │ ├── __init__.py
│ │ ├── auth.py # JWT validation
│ │ ├── rate_limit.py # Tier-aware rate limiting
│ │ └── sanitizer.py # Strip prompt metadata from responses
│ ├── storage/
│ │ ├── __init__.py
│ │ ├── blob_store.py # S3 for E2E encrypted blobs
│ │ ├── vector_store.py # Cloud vector store (Pinecone/Qdrant)
│ │ └── encryption.py # Integrity verification only — NO decryption
│ ├── marketplace/
│ │ ├── __init__.py
│ │ ├── plugin_registry.py # Plugin catalog (metadata, versions, ratings)
│ │ ├── plugin_review.py # Review queue + approval workflow
│ │ └── revenue_share.py # 70/30 split tracking with Stripe Connect
│ ├── billing/
│ │ ├── __init__.py
│ │ ├── stripe_service.py # Stripe checkout + webhooks
│ │ └── tier_manager.py # Feature matrix per tier
│ └── config/
│ ├── __init__.py
│ └── settings.py # Pydantic BaseSettings (env-based)
├── tests/
│ ├── __init__.py
│ ├── conftest.py # Fixtures: test client, mock agents, mock LLM
│ ├── test_orchestrator.py
│ ├── test_agents.py
│ ├── test_auth.py
│ ├── test_backup.py
│ ├── test_storage.py
│ └── test_plugins.py
├── alembic/ # DB migrations (auth/billing/marketplace tables only)
│ ├── alembic.ini
│ └── versions/
├── requirements.txt
├── Dockerfile
├── docker-compose.yml # App + PostgreSQL + Redis (dev)
├── .env.example
└── README.md
```
---
## Step-by-Step Implementation
### Step 1 — Project scaffolding ✅
- [x] Initialize repo with the directory structure above
- [x] Write `requirements.txt`:
```
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
langchain>=0.3.0
langchain-openai>=0.3.0
pydantic>=2.10.0
python-jose[cryptography]>=3.3.0
stripe>=11.0.0
boto3>=1.35.0
slowapi>=0.1.9
sqlalchemy>=2.0.0
asyncpg>=0.30.0
alembic>=1.14.0
bcrypt>=4.2.0
python-dotenv>=1.0.0
httpx>=0.28.0
websockets>=14.0
pytest>=8.0.0
pytest-asyncio>=0.24.0
```
- [x] Write `app/main.py`: FastAPI app with CORS (allow `app://`, `http://localhost:*`), lifespan (init DB pool, init agent registry), include all routers under `/api/v1`
- [x] Write `app/config/settings.py`: `Settings(BaseSettings)` with fields: `DATABASE_URL`, `JWT_SECRET`, `JWT_ALGORITHM` (default HS256), `STRIPE_SECRET_KEY`, `STRIPE_WEBHOOK_SECRET`, `S3_BUCKET`, `S3_REGION`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `OPENAI_API_KEY`, `CORS_ORIGINS`, `ENV` (dev/prod), `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
- [x] Write `Dockerfile`: Python 3.12 slim, multi-stage (builder + runtime), non-root user
- [x] Write `docker-compose.yml`: app, postgres:16, optional redis
- [x] Write `.env.example`
- **Outcome:** Runnable FastAPI skeleton (returns 404 on all routes).
### Step 2 — Pydantic schemas (API contracts) ✅
- [x] Create `app/schemas.py` (mirrors `src/shared/api-types.ts` from Electron repo):
- `ChatRequest`: `message: str`, `context: ChatContext`, `execution_mode: Literal['direct', 'plan']`
- `ChatContext`: `user_profile: dict`, `relevant_documents: list[str]`, `recent_tasks: list[dict]`, `conversation_history: list[dict]`
- `ChatResponse`: `response: str`, `actions: list[PlanAction]`
- `PlanAction`: `type: Literal['create_record', 'update_record', 'delete_record', 'index_document', 'send_notification', 'call_agent']`, `table: str | None`, `data: dict | None`, `agent: str | None`
- `ExecutionPlan`: `agent: str`, `steps: list[PlanStep]`
- `PlanStep`: `action: str`, `prompt_template: str | None`, `variables: dict | None`, `data_from_step: int | None`
- `BackupMetadata`: `version: int`, `timestamp: int`, `checksum: str`, `chunk_count: int`
- `BillingTier`: `Literal['free', 'pro', 'power', 'team']`
- `AuthTokens`: `access_token: str`, `refresh_token: str`, `expires_at: int`
- `UserProfile`: `id: str`, `email: str`, `tier: BillingTier`
- `StorageRecord`: `id: str`, `user_id: str`, `table: str`, `blob: bytes`, `checksum: str`, `created_at: int`, `updated_at: int` — blob is always E2E encrypted by client
- `StorageRecordCreate`: `table: str`, `blob: bytes`, `checksum: str`
- `StorageRecordUpdate`: `blob: bytes`, `checksum: str`
- `VectorUpsertRequest`: `vectors: list[VectorItem]`
- `VectorItem`: `id: str`, `blob: bytes`, `checksum: str` — vector + metadata encrypted by client
- `VectorSearchRequest`: `query_blob: bytes`, `top_k: int = 10`
- `VectorSearchResponse`: `results: list[VectorSearchResult]`
- `VectorSearchResult`: `id: str`, `score: float`, `blob: bytes`
- `PluginManifest`: `id: str`, `name: str`, `description: str`, `version: str`, `author: str`, `permissions: list[str]`, `category: str`, `price_cents: int = 0`
- `PluginListResponse`: `plugins: list[PluginManifest]`, `total: int`, `page: int`
- `PluginInstallRequest`: `plugin_id: str`
- **Outcome:** All request/response models defined and validated.
### Step 3 — Agent Registry + base classes ✅
- [x] `app/core/agent_registry.py`:
- `BaseAgent(ABC)`:
- `user_id: str`, `shared_memory: dict`, `vector_store_context: list[str]`, `skills: list[str]`
- Abstract `get_name() -> str`, `get_description() -> str`
- `ChatAgent(BaseAgent)`:
- Abstract `async handle(query: str, context: dict) -> str`
- Abstract `get_tools() -> list` (LangChain tool definitions)
- Concrete `_tool_loop(llm, messages, tools, max_iter=5) -> str` — shared tool-calling loop
- `AgentRegistry` (singleton):
- `_agents: dict[str, ChatAgent]`
- `register(agent_class)` — decorator pattern
- `get(name) -> ChatAgent`
- `list_agents() -> list[dict]` — returns `[{name, description}]` for orchestrator prompt
- `async call_agent(name, query, context) -> str` — for inter-agent calls
- [x] Unit tests: register, get, list, call_agent with mock
- **Outcome:** Pluggable agent framework.
### Step 4 — Orchestrator ✅
- [x] `app/core/orchestrator.py`:
- `async classify_intent(message, context, registry) -> str`:
- System prompt: "You are an intent classifier. Given the user message and context, decide which agent to route to. Available agents: {registry.list_agents()}. Respond with just the agent name."
- Uses gpt-4o-mini via LangChain for low latency
- Falls back to `task_agent` if no clear match
- `async route_single(agent_name, message, context) -> ChatResponse`:
- Instantiates agent from registry
- Calls `agent.handle(message, context)`
- Returns response + any actions the agent produced
- `async route_pipeline(agent_names, message, context) -> ChatResponse`:
- Executes agents in sequence
- Each agent receives `{...context, previous_results: [...]}`
- Final synthesis via LLM: "Summarize these agent results into a coherent response"
- `async orchestrate(request: ChatRequest) -> ChatResponse | ExecutionPlan`:
- Main entry point
- Context is transparent to orchestrator — data may originate from local or cloud storage on the client side
- Classifies intent
- If `execution_mode == 'direct'`: route + return response
- If `execution_mode == 'plan'`: route + return execution plan with template IDs
- `async orchestrate_stream(request: ChatRequest) -> AsyncGenerator[str, None]`:
- Same as orchestrate but yields tokens for WebSocket streaming
- [x] Integration tests with mocked LLM and mocked agents
- **Outcome:** Intelligent routing with single-agent and pipeline modes.
### Step 5 — Execution Plan generator ✅
- [x] `app/core/execution_plan.py`:
- `PromptTemplateRegistry`: dict of `template_id -> prompt_text`. Templates are server-side only — client receives IDs.
- `ExecutionPlanBuilder`:
- `add_step(action, params) -> self`
- `add_llm_step(template_id, variables) -> self`
- `add_data_step(action, data_from_step) -> self`
- `build() -> ExecutionPlan` — validates step references
- `PlanCache`:
- In-memory LRU (maxsize=1000)
- `cache_plan(key, plan)`, `get_plan(key)`, `get_all_playbooks() -> list[ExecutionPlan]`
- Playbooks are pre-built plans for common operations (e.g., "create task from email", "generate weekly report")
- **Outcome:** Plans are cacheable as playbooks. Prompt IP never leaves the server.
### Step 6 — Chat Agents ✅
- [x] `app/agents/task_agent.py` — `@registry.register`:
- Description: "Manages tasks and comments: list, create, update, delete, due-today, comments"
- Tools (8): `list_tasks(project_id, status, search, order_by)`, `create_task(title, description, status, priority, assignees, due_date, project_id, is_ai_suggested, is_approved)`, `update_task(task_id, ...)`, `delete_task(task_id)`, `list_tasks_due_today()`, `list_task_comments(task_id)`, `add_task_comment(task_id, author, content)`, `delete_task_comment(comment_id)`
- status: `todo|in_progress|done`; priority: `high|medium|low`; assignees: JSON-encoded string; due_date: ms timestamp
- Accepts flexible context; sentinel `-1` for optional integer update fields
- [x] `app/agents/checkpoint_agent.py` — `@registry.register`:
- Description: "Manages project checkpoints (milestones): list, create, update, delete"
- Tools (4): `list_checkpoints(project_id)`, `create_checkpoint(project_id, title, date, is_ai_suggested, is_approved)`, `update_checkpoint(checkpoint_id, ...)`, `delete_checkpoint(checkpoint_id)`
- `project_id` is required for create; date is a ms timestamp; supports AI-suggestion + approval workflow
- [x] `app/agents/project_agent.py` — `@registry.register`:
- Description: "Manages projects: list, get, create, update, archive, delete"
- Tools (6): `list_projects(client_id, include_archived)`, `list_all_projects()`, `get_project(project_id)`, `create_project(name, client_id)`, `update_project(project_id, ...)`, `delete_project(project_id)`
- status: `active|archived`; prefers archive over deletion (docstring guard on delete)
- [x] `app/agents/note_agent.py` — `@registry.register`:
- Description: "Manages notes: list, get, create, update, delete"
- Tools (5): `list_notes(project_id)`, `get_note(note_id)`, `create_note(title, content, project_id)`, `update_note(note_id, ...)`, `delete_note(note_id)`
- content is Markdown; `get_note` should be called before update to preserve existing content
- [x] `app/agents/__init__.py`: imports all four agent modules to trigger `@registry.register` decorators
- [x] Unit tests per agent with mocked LLM (registration, names, tool counts, handle(), direct tool invocation)
- **Outcome:** Four domain-specific agents matching the UI data model (Tasks, Checkpoints, Projects, Notes), all registered and tested.
### Step 7 — Storage Layer ✅
- [x] `app/storage/blob_store.py`:
- `BlobStore`: `async upload`, `async download`, `async delete` (idempotent), `async list_keys`
- Keys: `{user_id}/{table}/{record_id}` — backend never inspects blob content
- boto3 S3 with SSE-S3 at-rest encryption; client checksum stored in S3 object metadata
- [x] `app/storage/vector_store.py`:
- `VectorStore`: `async upsert`, `async search`, `async delete`
- Pinecone (default, `namespace=user_id`) or Qdrant (`user_id` payload filter) — runtime-configurable
- 32-dim SHA-256-derived float vector; blob stored as base64 in metadata/payload
- ANN on encrypted data: known accuracy trade-off, documented
- [x] `app/storage/encryption.py`:
- `verify_checksum(blob, checksum) -> bool` — SHA-256 + `hmac.compare_digest` (constant-time)
- `reject_if_tampered(blob, checksum)` — raises `HTTP 400` on mismatch
- Backend NEVER holds decryption keys
- [x] `app/schemas.py`: added `StorageRecord*`, `VectorItem`, `VectorUpsertRequest`, `VectorSearch*`, `Plugin*` schemas
- [x] `app/config/settings.py`: added `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
- [x] `requirements.txt`: added `moto[s3]`, `pinecone`, `qdrant-client`
- [x] 37 unit tests covering encryption, BlobStore (moto), VectorStore Pinecone, VectorStore Qdrant
- **Outcome:** Cloud storage layer that handles E2E encrypted blobs without ever accessing plaintext.
### Step 8 — API Routes ✅
#### 8a — Chat endpoint
- [x] `app/api/routes/chat.py`:
- `POST /api/v1/chat`:
- Request: `ChatRequest`
- Calls `orchestrate(request)` or `orchestrate()` + `build_plan()`
- Response: `ChatResponse` or `ExecutionPlan`
- `WebSocket /api/v1/chat/stream`:
- Client sends `ChatRequest` as first JSON frame
- Server yields token strings via `orchestrate_stream()`
- Final frame: JSON `ChatResponse` with `{"done": true, "response": "...", "actions": [...]}`
- Heartbeat ping every 30s to keep connection alive
#### 8b — Plans endpoint
- [x] `app/api/routes/plans.py`:
- `GET /api/v1/plans/playbook`: Returns all playbooks available for the user's tier
- `GET /api/v1/plans/playbook/{plan_id}`: Returns a specific plan
#### 8c — Storage endpoint (cloud records)
- [x] `app/api/routes/storage.py`:
- `POST /api/v1/storage/records`: Create encrypted record
- Request: `StorageRecordCreate`
- Verifies checksum, stores blob in S3, inserts metadata row in PostgreSQL
- Response: `{id: str, created_at: int}`
- `GET /api/v1/storage/records`: List record metadata (no blobs)
- Query params: `table: str`, `page: int`, `limit: int`
- Response: `list[{id, table, checksum, created_at, updated_at}]`
- `GET /api/v1/storage/records/{id}`: Download encrypted blob
- Response: blob bytes + `X-Checksum` header
- `PUT /api/v1/storage/records/{id}`: Update encrypted blob
- Request: `StorageRecordUpdate`
- `DELETE /api/v1/storage/records/{id}`: Delete record + S3 blob
- All routes enforce tier cloud_storage_gb quota via `TierManager.check_quota(user_id)`
#### 8d — Vectors endpoint (cloud vector store)
- [x] `app/api/routes/vectors.py`:
- `POST /api/v1/storage/vectors/upsert`:
- Request: `VectorUpsertRequest`
- Verifies checksums, delegates to `VectorStore.upsert()`
- Response: `{upserted: int}`
- `POST /api/v1/storage/vectors/search`:
- Request: `VectorSearchRequest`
- Delegates to `VectorStore.search()`
- Response: `VectorSearchResponse`
- `DELETE /api/v1/storage/vectors`:
- Request: `{ids: list[str]}`
#### 8e — Backup endpoint
- [x] `app/api/routes/backup.py`:
- `PUT /api/v1/backup`: Accepts binary blob + metadata headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Stores in S3 keyed by `{user_id}/{timestamp}`. Enforces tier limits:
- Free: 0 (no backup)
- Pro: 5 GB
- Power: 25 GB
- Team: unlimited
- `GET /api/v1/backup`: Returns latest blob for authenticated user. Supports `If-Modified-Since`.
- `GET /api/v1/backup/history`: Returns list of `BackupMetadata` (no blobs).
- `DELETE /api/v1/backup/{backup_id}`: Delete specific backup.
#### 8f — Plugins endpoint
- [x] `app/api/routes/plugins.py`:
- `GET /api/v1/plugins`:
- Query params: `category: str | None`, `q: str | None`, `page: int`, `sort: Literal['rating', 'installs', 'newest']`
- Response: `PluginListResponse`
- Available from Power tier and above
- `GET /api/v1/plugins/{id}`:
- Response: `PluginManifest` + ratings + install count
- `POST /api/v1/plugins/{id}/install`:
- Request: `PluginInstallRequest`
- Records installation for the user (billing tracking, analytics)
- If plugin is paid: triggers Stripe Connect charge + revenue split (70% developer, 30% platform)
- Response: `{ok: true, download_url: str}` — signed S3 URL for plugin package
- `DELETE /api/v1/plugins/{id}/install`:
- Unregisters installation
#### 8g — Auth endpoint
- [x] `app/api/routes/auth.py`:
- `POST /api/v1/auth/register`: `{email, password}` → bcrypt hash → insert user → return `AuthTokens`
- `POST /api/v1/auth/login`: Validate credentials → return `AuthTokens`
- `POST /api/v1/auth/refresh`: Rotate refresh token → return new `AuthTokens`
- `GET /api/v1/auth/me`: Return `UserProfile` for current JWT
#### 8h — Billing endpoint
- [x] `app/api/routes/billing.py`:
- `POST /api/v1/billing/checkout`: Creates Stripe checkout session → returns URL
- `POST /api/v1/billing/webhook`: Handles Stripe webhooks (subscription lifecycle)
- `GET /api/v1/billing/subscription`: Returns current subscription info
- `DELETE /api/v1/billing/subscription`: Cancels subscription
- **Outcome:** Complete REST + WebSocket API covering orchestration, storage, vectors, backup, marketplace.
### Step 9 — Middleware
#### 9a — Auth middleware
- [x] `app/api/middleware/auth.py`:
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
- Validates JWT signature, expiry, extracts `user_id` and `tier`
- Raises `401` on invalid/expired token
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
#### 9b — Rate limiter
- [x] `app/api/middleware/rate_limit.py`:
- Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)`
- Tier-based limits:
- Free: 20 req/min
- Pro: 60 req/min
- Power: 120 req/min
- Team: 200 req/seat/min
- Custom 429 response with `Retry-After` header
#### 9c — Sanitizer
- [x] `app/api/middleware/sanitizer.py`:
- Response middleware that scans response bodies
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
- Pattern-based detection + exact match against known prompt fingerprints
- Logs sanitization events for monitoring
- **Outcome:** Secure, rate-limited API with prompt IP protection.
### Step 10 — Plugin Marketplace ✅
- [x] `app/marketplace/plugin_registry.py`:
- `PluginRegistry`:
- `async list_plugins(category, query, page, sort) -> PluginListResponse`
- `async get_plugin(plugin_id) -> PluginManifest | None`
- `async submit_plugin(manifest: PluginManifest, package_s3_key: str) -> str` — returns plugin_id, sets status = 'pending_review'
- `async approve_plugin(plugin_id) -> None` — admin only, sets status = 'approved'
- `async reject_plugin(plugin_id, reason: str) -> None`
- [x] `app/marketplace/plugin_review.py`:
- `ReviewQueue`:
- `async get_pending() -> list[dict]`
- `async submit_review(plugin_id, reviewer_id, decision, notes) -> None`
- Security checklist enforced before approval: manifest schema valid, permissions are from allowed set, no binary blobs in manifest
- [x] `app/marketplace/revenue_share.py`:
- `RevenueShare`:
- `async record_install(plugin_id, user_id, amount_cents) -> None`
- `async payout_developer(plugin_id, period) -> None` — Stripe Connect transfer: 70% to developer
- `async get_earnings(developer_id, period) -> dict`
- **Outcome:** Plugin marketplace with catalog, review workflow, and revenue split.
### Step 11 — Billing & Tier management ✅
- [x] `app/billing/stripe_service.py`:
- `create_checkout_session(user_id, tier) -> str`
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
- `get_subscription(user_id) -> dict | None`
- `cancel_subscription(user_id) -> None`
- [x] `app/billing/tier_manager.py`:
- `TierManager`:
- Feature matrix:
```python
FEATURES = {
'free': {
'agents': 3,
'batch_active': 2,
'cloud_storage_gb': 0,
'backup_gb': 0,
'providers': 1,
'batch_builder': False,
'plugin_marketplace': False,
'sso': False,
},
'pro': {
'agents': -1, # unlimited
'batch_active': 10,
'cloud_storage_gb': 5,
'backup_gb': 5,
'providers': -1,
'batch_builder': False,
'plugin_marketplace': False,
'sso': False,
},
'power': {
'agents': -1,
'batch_active': -1, # unlimited
'cloud_storage_gb': 25,
'backup_gb': 25,
'providers': -1,
'batch_builder': True,
'plugin_marketplace': True,
'sso': False,
},
'team': {
'agents': -1,
'batch_active': -1,
'cloud_storage_gb': -1,
'backup_gb': -1,
'providers': -1,
'batch_builder': True,
'plugin_marketplace': True,
'sso': True,
},
}
```
- `get_tier(user_id) -> BillingTier`
- `check_feature(user_id, feature) -> bool`
- `get_rate_limit(tier) -> int`
- `check_quota(user_id) -> bool` — checks cloud_storage_gb current usage vs limit
- [x] `app/billing/__init__.py`: exports `stripe_service` and `tier_manager` singletons
- [x] `app/api/routes/billing.py`: refactored to delegate to `StripeService`
- [x] `app/api/routes/storage.py` and `backup.py`: `_check_quota` now delegates to `tier_manager.enforce_quota` / `enforce_backup_quota`
- **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat).
### Step 12 — Database (auth/billing/marketplace only)
- [x] PostgreSQL schema via Alembic:
- `users`: `id UUID PK`, `email UNIQUE`, `password_hash`, `tier` (default 'free'), `stripe_customer_id`, `created_at`, `updated_at`
- `refresh_tokens`: `id UUID PK`, `user_id FK`, `token_hash`, `expires_at`, `created_at`
- `subscriptions`: `id UUID PK`, `user_id FK`, `stripe_subscription_id`, `tier`, `status`, `current_period_end`, `created_at`
- `backup_metadata`: `id UUID PK`, `user_id FK`, `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes`, `created_at`
- `storage_records`: `id UUID PK`, `user_id FK`, `table_name VARCHAR`, `s3_key`, `checksum`, `size_bytes`, `created_at`, `updated_at` — metadata only, no plaintext
- `plugins`: `id UUID PK`, `name`, `description`, `version`, `author_id FK`, `category`, `status` (pending_review/approved/rejected), `price_cents`, `s3_package_key`, `install_count`, `avg_rating`, `created_at`
- `plugin_installations`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `installed_at`
- `plugin_reviews`: `id UUID PK`, `plugin_id FK`, `reviewer_id FK`, `decision`, `notes`, `reviewed_at`
- `revenue_events`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `amount_cents`, `developer_share_cents`, `stripe_transfer_id`, `created_at`
- [x] Initial Alembic migration
- [x] SQLAlchemy models in `app/models.py`
- **Outcome:** Auth, billing, storage metadata, and marketplace persistence. Zero user data in plaintext.
### Step 13 — Testing & deployment ✅
- [x] `tests/conftest.py`: TestClient fixture, mock LLM fixture (`AsyncMock` returning canned responses), mock agent fixture, test DB (SQLite in-memory for speed), mock S3 (moto), mock Pinecone
- [x] `tests/test_orchestrator.py`: classify_intent routing, single agent, pipeline, plan mode
- [x] `tests/test_agents.py`: each agent with mocked tools
- [x] `tests/test_auth.py`: register → login → access protected → refresh → expired token
- [x] `tests/test_backup.py`: upload → download → history → delete, tier limit enforcement
- [x] `tests/test_storage.py`: create record → list → download → update → delete, checksum rejection, quota enforcement
- [x] `tests/test_plugins.py`: list plugins, install, uninstall, revenue event creation, tier gate (free user blocked)
- [x] `Dockerfile` optimized for production (gunicorn + uvicorn workers)
- [x] GitHub Actions CI: lint (ruff), test (pytest), build Docker image
- **Outcome:** Fully tested, deployable backend.
---
## API Contract Summary
| Method | Endpoint | Auth | Request | Response |
|--------|----------|------|---------|----------|
| POST | `/api/v1/auth/register` | No | `{email, password}` | `AuthTokens` |
| POST | `/api/v1/auth/login` | No | `{email, password}` | `AuthTokens` |
| POST | `/api/v1/auth/refresh` | No | `{refresh_token}` | `AuthTokens` |
| GET | `/api/v1/auth/me` | JWT | — | `UserProfile` |
| POST | `/api/v1/chat` | JWT | `ChatRequest` | `ChatResponse \| ExecutionPlan` |
| WS | `/api/v1/chat/stream` | JWT | `ChatRequest` (first frame) | Token stream + final JSON |
| GET | `/api/v1/plans/playbook` | JWT | — | `ExecutionPlan[]` |
| GET | `/api/v1/plans/playbook/:id` | JWT | — | `ExecutionPlan` |
| POST | `/api/v1/storage/records` | JWT | `StorageRecordCreate` | `{id, created_at}` |
| GET | `/api/v1/storage/records` | JWT | `?table&page&limit` | `RecordMeta[]` |
| GET | `/api/v1/storage/records/:id` | JWT | — | Binary blob |
| PUT | `/api/v1/storage/records/:id` | JWT | `StorageRecordUpdate` | `{ok: true}` |
| DELETE | `/api/v1/storage/records/:id` | JWT | — | `{ok: true}` |
| POST | `/api/v1/storage/vectors/upsert` | JWT | `VectorUpsertRequest` | `{upserted: int}` |
| POST | `/api/v1/storage/vectors/search` | JWT | `VectorSearchRequest` | `VectorSearchResponse` |
| DELETE | `/api/v1/storage/vectors` | JWT | `{ids: list[str]}` | `{ok: true}` |
| PUT | `/api/v1/backup` | JWT | Binary blob + headers | `{ok: true}` |
| GET | `/api/v1/backup` | JWT | — | Binary blob |
| GET | `/api/v1/backup/history` | JWT | — | `BackupMetadata[]` |
| DELETE | `/api/v1/backup/:id` | JWT | — | `{ok: true}` |
| GET | `/api/v1/plugins` | JWT | `?category&q&page&sort` | `PluginListResponse` |
| GET | `/api/v1/plugins/:id` | JWT | — | `PluginManifest` + stats |
| POST | `/api/v1/plugins/:id/install` | JWT | `PluginInstallRequest` | `{ok, download_url}` |
| DELETE | `/api/v1/plugins/:id/install` | JWT | — | `{ok: true}` |
| POST | `/api/v1/billing/checkout` | JWT | `{tier}` | `{checkout_url}` |
| POST | `/api/v1/billing/webhook` | Stripe sig | Stripe event | `{ok: true}` |
| GET | `/api/v1/billing/subscription` | JWT | — | Subscription info |
| DELETE | `/api/v1/billing/subscription` | JWT | — | `{ok: true}` |
| GET | `/api/v1/health` | No | — | `{status, version}` |
---
## Stack
| Layer | Technology |
|-------|-----------|
| Framework | FastAPI + Uvicorn |
| LLM | LangChain + langchain-openai |
| Auth | PyJWT + bcrypt + OAuth2 |
| Billing | stripe-python + Stripe Connect |
| Blob storage | boto3 (S3) |
| Vector store | Pinecone or Qdrant (configurable) |
| Database | PostgreSQL + SQLAlchemy + Alembic |
| Rate limiting | slowapi |
| Testing | pytest + pytest-asyncio + httpx + moto (S3 mock) |
| Deployment | Docker → fly.io / Railway / AWS ECS |
---
## Development Rules
1. **NEVER persist user data in plaintext.** The DB stores only auth, billing, storage metadata, and marketplace data. User context arrives in requests and is discarded. Cloud blobs are E2E encrypted client-side — backend only stores opaque bytes.
2. **NEVER expose prompts.** System prompts are composed server-side from fragments. Responses are sanitized before sending. In plan mode, `prompt_template` fields are reference IDs only.
3. **NEVER decrypt user blobs.** `app/storage/encryption.py` only verifies checksums. No decryption key ever reaches the backend.
4. **Stateless request handling.** No server-side session state. All context comes from the client + JWT.
5. **Type hints everywhere.** All functions have full type annotations.
6. **Test every agent.** Each chat agent has unit tests with mocked LLM responses.
7. **Structured logging.** JSON logs with request ID correlation.
8. **Tier gates are enforced server-side.** Never trust client-reported tier. Always fetch from DB via `TierManager.get_tier(user_id)`.
9. **One step at a time.** Implement one numbered step per session. When the step is fully done, mark all its checkboxes as `[x]` in this file and commit with message `step N complete: <outcome line>`.

316
README.md
View File

@@ -1,8 +1,8 @@
# Adiuva Cloud API
# AdiuvAI Cloud API
**AI-powered project management backend with E2E encrypted cloud storage, LLM orchestration, and a plugin marketplace.**
**AI-powered project management backend with LLM orchestration and subscription billing.**
Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe · AWS S3
Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe
---
@@ -20,9 +20,7 @@ Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe · AWS S3
- [AI Agent System](#ai-agent-system)
- [Orchestration & Execution Plans](#orchestration--execution-plans)
- [Middleware](#middleware)
- [Storage Layer](#storage-layer)
- [Billing & Tiers](#billing--tiers)
- [Plugin Marketplace](#plugin-marketplace)
- [Testing](#testing)
- [Project Structure](#project-structure)
- [License](#license)
@@ -31,15 +29,13 @@ Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe · AWS S3
## Overview
Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron desktop app**. It provides LLM-powered chat orchestration, end-to-end encrypted cloud storage, a vector search engine, an encrypted backup system, a plugin marketplace with revenue sharing, and Stripe-based subscription billing across four tiers.
AdiuvAI Cloud API is the FastAPI backend that powers the **AdiuvAI Electron desktop app**. It provides LLM-powered chat orchestration, text embedding generation, and Stripe-based subscription billing across four tiers.
### Design Principles
1. **Never persist user data in plaintext** — the database stores only auth, billing, storage metadata, and marketplace data. All user content is E2E encrypted by the client before reaching the server.
2. **Never expose prompts** — system prompts stay server-side; responses are sanitized to strip any leaked prompt fragments.
3. **Never decrypt user blobs** — the backend performs only checksum verification; no decryption keys ever reach the server.
4. **Stateless request handling** — all context comes from the client and JWT; no server-side session state.
5. **Tier gates enforced server-side** — the server always reads the current tier from the database, never trusting client-reported values.
1. **Never expose prompts** — system prompts stay server-side; responses are sanitized to strip any leaked prompt fragments.
2. **Stateless request handling** — all context comes from the client and JWT; no server-side session state.
3. **Tier gates enforced server-side** — the server always reads the current tier from the database, never trusting client-reported values.
---
@@ -54,27 +50,26 @@ Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron deskto
│ ┌──────────────────┐ ┌────────────────────────────┐ │
│ │ Auth Routes │ │ Chat Routes │ │
│ │ Billing Routes │ │ ↓ │ │
│ │ Storage Routes │ │ Orchestrator (GPT-4o-mini)│ │
│ │ Backup Routes │ │ ↓ classify intent │ │
│ Plugin Routes │ │ Agent Registry │ │
Vector Routes │ │ ↓ │ │
Plans Routes │ │ TaskAgent | ProjectAgent │ │
└──────────────────┘ │ NoteAgent | CheckptAgent │ │
│ │ Agent Routes │ │ Orchestrator (GPT-4o-mini)│ │
│ │ Device WS │ │ ↓ classify intent │ │
└──────────────────┘ │ Agent Registry │ │
│ ↓ │ │
│ TaskAgent | ProjectAgent │ │
│ NoteAgent | CheckptAgent │ │
│ │ (GPT-4o + LangChain) │ │
│ └────────────────────────────┘ │
└────────────────────────────────────────────────────────┘
│ │
┌────────▼───┐ ┌───────▼───────┐ ┌──▼─────────────┐
│ PostgreSQL │ │ AWS S3 │ │ Pinecone / │
│ (Auth, │ │ (E2E blobs, │ │ Qdrant │
│ Billing, │ │ backups) │ │ (Vectors) │
Metadata) └───────────────┘ └────────────────┘
┌────────▼───┐
│ PostgreSQL │
│ (Auth, │
│ Billing, │
Agents)
└────────────┘
┌────────▼───┐
│ Stripe │
│ (Billing,
│ Connect) │
│ (Billing)
└────────────┘
```
@@ -83,20 +78,16 @@ Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron deskto
## Key Features
1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent.
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Checkpoints (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Timelines (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts.
4. **E2E encrypted cloud storage** — The backend never decrypts user data; SHA-256 checksum verification uses constant-time comparison to prevent timing attacks.
5. **Cloud vector store** — Pinecone or Qdrant with user-isolated namespaces and encrypted blob payloads.
6. **Encrypted backup system** — Tiered storage limits with `If-Modified-Since` support for efficient syncing.
7. **Plugin marketplace** — Catalog, admin review/approval workflow, security checklist, and 70/30 revenue sharing via Stripe Connect.
8. **Stripe billing** — Four-tier subscription model (Free / Pro / Power / Team) with checkout sessions and full webhook lifecycle handling.
9. **JWT authentication** — Access + refresh tokens with bcrypt password hashing, SHA-256 token hashing, and automatic rotation.
10. **Prompt IP protection**Sanitizer middleware strips system prompts, reasoning markers, tool schemas, and agent routing metadata from all chat responses.
11. **Tier-based rate limiting** — Sliding-window per-user limiter scaling from 20 to 200 requests/min by subscription tier.
12. **Zero-trust data model** — User content is never stored in plaintext; the database holds only authentication, billing, and metadata records.
13. **WebSocket streaming** — Real-time chat with 30-second heartbeat keep-alive and chunked text delivery.
14. **Alembic migrations** — Versioned schema management with seed data for the plugin marketplace.
15. **Comprehensive test suite** — In-memory SQLite + moto S3 mocks, per-tier test fixtures, and full API coverage without external dependencies.
4. **Text embeddings** — Generates text-embedding-3-small vectors for local client-side note search.
5. **Stripe billing** — Four-tier subscription model (Free / Pro / Power / Team) with checkout sessions and full webhook lifecycle handling.
6. **JWT authentication** — Access + refresh tokens with bcrypt password hashing, SHA-256 token hashing, and automatic rotation.
7. **Prompt IP protection** — Sanitizer middleware strips system prompts, reasoning markers, tool schemas, and agent routing metadata from all chat responses.
8. **Tier-based rate limiting** — Sliding-window per-user limiter scaling from 20 to 200 requests/min by subscription tier.
9. **WebSocket streaming** — Real-time chat with 30-second heartbeat keep-alive and chunked text delivery.
10. **Alembic migrations**Versioned schema management.
11. **Comprehensive test suite** — In-memory SQLite, per-tier test fixtures, and full API coverage without external dependencies.
---
@@ -114,7 +105,6 @@ Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron deskto
| `pydantic-settings` | ≥ 2.7.0 | Environment-based configuration |
| `python-jose[cryptography]` | ≥ 3.3.0 | JWT encoding and decoding |
| `stripe` | ≥ 11.0.0 | Billing and payment integration |
| `boto3` | ≥ 1.35.0 | AWS S3 client |
| `slowapi` | ≥ 0.1.9 | Rate limiting utilities |
| `sqlalchemy` | ≥ 2.0.0 | Async ORM and query builder |
| `asyncpg` | ≥ 0.30.0 | PostgreSQL async driver |
@@ -124,12 +114,9 @@ Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron deskto
| `httpx` | ≥ 0.28.0 | Async HTTP client (used in tests) |
| `websockets` | ≥ 14.0 | WebSocket protocol support |
| `psycopg2-binary` | ≥ 2.9.0 | Synchronous PostgreSQL driver (Alembic) |
| `pinecone` | ≥ 5.0.0 | Pinecone vector store client |
| `qdrant-client` | ≥ 1.7.0 | Qdrant vector store client |
| `pytest` | ≥ 8.0.0 | Test framework |
| `pytest-asyncio` | ≥ 0.24.0 | Async test support |
| `aiosqlite` | ≥ 0.20.0 | In-memory SQLite for tests |
| `moto[s3]` | ≥ 5.0.0 | AWS S3 mock for tests |
| `ruff` | ≥ 0.8.0 | Linter and formatter |
---
@@ -142,13 +129,12 @@ Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron deskto
- PostgreSQL 16+
- An OpenAI API key (for LLM features)
- Stripe API keys (optional — billing stubs gracefully when unconfigured)
- AWS credentials (optional — needed for S3 storage in production)
### Installation
```bash
# Clone the repository
git clone <repo-url> && cd adiuva-api
git clone <repo-url> && cd adiuvai-api
# Create a virtual environment
python -m venv .venv && source .venv/bin/activate
@@ -194,11 +180,6 @@ This starts two services:
- **app** — FastAPI server on port `8000`
- **db** — PostgreSQL 16 (Alpine) on port `5432` with a persistent volume and health checks
The compose file also includes optional services for fully local deployments:
- **minio** — S3-compatible object storage on ports `9000` (API) and `9001` (console)
- **qdrant** — Vector search engine on ports `6333` (HTTP) and `6334` (gRPC)
### Dockerfile Details
The Dockerfile uses a multi-stage build:
@@ -216,7 +197,7 @@ gunicorn app.main:app -k uvicorn.workers.UvicornWorker -w 4 --timeout 120 -b 0.0
## Homelab / Self-Hosted Deployment
You can run the entire stack locally on a homelab with **no cloud dependencies except the LLM provider**. The compose file includes MinIO (S3 replacement) and Qdrant (vector store) out of the box.
You can run the entire stack locally on a homelab with **no cloud dependencies except the LLM provider**.
### 1. Start all services
@@ -224,34 +205,13 @@ You can run the entire stack locally on a homelab with **no cloud dependencies e
docker compose up -d
```
This starts PostgreSQL, MinIO, and Qdrant alongside the app.
This starts PostgreSQL alongside the app.
### 2. Create the MinIO bucket
Open the MinIO console at [http://localhost:9001](http://localhost:9001) (login: `minioadmin` / `minioadmin`) and create a bucket named `adiuva`, or use the CLI:
```bash
docker compose exec minio mc alias set local http://localhost:9000 minioadmin minioadmin
docker compose exec minio mc mb local/adiuva
```
### 3. Configure your `.env`
### 2. Configure your `.env`
```bash
# Database (uses the compose PostgreSQL)
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/adiuva
# S3 → MinIO
S3_BUCKET=adiuva
S3_REGION=us-east-1
S3_ENDPOINT_URL=http://minio:9000
AWS_ACCESS_KEY_ID=minioadmin
AWS_SECRET_ACCESS_KEY=minioadmin
# Vector store → local Qdrant (leave PINECONE_API_KEY empty)
QDRANT_URL=http://qdrant:6333
QDRANT_API_KEY=
PINECONE_API_KEY=
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/adiuvai
# Billing — leave empty to stub (no Stripe needed)
STRIPE_SECRET_KEY=
@@ -267,7 +227,7 @@ JWT_SECRET=your-secret-here
ENV=dev
```
### 4. Run migrations
### 3. Run migrations
```bash
docker compose exec app alembic upgrade head
@@ -278,9 +238,7 @@ docker compose exec app alembic upgrade head
| Service | Runs on | Port | Notes |
|---|---|---|---|
| FastAPI app | Docker | 8000 | API server |
| PostgreSQL | Docker | 5432 | Auth, billing, metadata |
| MinIO | Docker | 9000 / 9001 | S3-compatible blob & backup storage |
| Qdrant | Docker | 6333 / 6334 | Vector search (replaces Pinecone) |
| PostgreSQL | Docker | 5432 | Auth, billing, agents |
| Stripe | — | — | Stubbed when keys are empty |
| OpenAI / LLM | Cloud | — | Only external dependency |
@@ -294,23 +252,13 @@ All variables are loaded from a `.env` file via Pydantic Settings. Source: `app/
| Variable | Type | Default | Description |
|---|---|---|---|
| `DATABASE_URL` | `str` | `postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva` | Async SQLAlchemy connection string |
| `DATABASE_URL` | `str` | `postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai` | Async SQLAlchemy connection string |
| `JWT_SECRET` | `str` | `change-me-in-production` | HMAC secret for JWT signing |
| `JWT_ALGORITHM` | `str` | `HS256` | JWT signing algorithm |
| `JWT_ACCESS_TOKEN_EXPIRE_MINUTES` | `int` | `30` | Access token time-to-live |
| `JWT_REFRESH_TOKEN_EXPIRE_DAYS` | `int` | `30` | Refresh token time-to-live |
| `STRIPE_SECRET_KEY` | `str` | `""` | Stripe API key (empty = stub mode) |
| `STRIPE_WEBHOOK_SECRET` | `str` | `""` | Stripe webhook signature secret |
| `S3_BUCKET` | `str` | `""` | S3 bucket for encrypted blobs and backups |
| `S3_REGION` | `str` | `us-east-1` | AWS region |
| `S3_ENDPOINT_URL` | `str` | `""` | Custom S3 endpoint (e.g. `http://minio:9000` for MinIO). Leave empty for AWS. |
| `AWS_ACCESS_KEY_ID` | `str` | `""` | AWS credentials |
| `AWS_SECRET_ACCESS_KEY` | `str` | `""` | AWS credentials |
| `PINECONE_API_KEY` | `str` | `""` | Pinecone API key (if set, Pinecone is used for vectors) |
| `PINECONE_INDEX` | `str` | `adiuva` | Pinecone index name |
| `QDRANT_URL` | `str` | `""` | Qdrant URL (used when Pinecone is not configured) |
| `QDRANT_API_KEY` | `str` | `""` | Qdrant API key |
| `OPENAI_API_KEY` | `str` | `""` | OpenAI key for LLM agent calls |
| `STRIPE_WEBHOOK_SECRET` | `str` | `\"\"` | Stripe webhook signature secret |\n| `OPENAI_API_KEY` | `str` | `\"\"` | OpenAI key for LLM agent calls |
| `LLM_MODEL` | `str` | `gpt-4o` | LiteLLM model identifier for agents (e.g. `anthropic/claude-3.5-sonnet`, `gemini/gemini-pro`, `ollama/llama3`) |
| `LLM_ROUTER_MODEL` | `str` | `gpt-4o-mini` | Lighter model used for intent classification / routing |
| `CORS_ORIGINS` | `list[str]` | `["app://.", "http://localhost:3000", "http://localhost:5173"]` | Allowed CORS origins |
@@ -342,6 +290,7 @@ All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebS
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/chat` | JWT | Route message through the orchestrator; returns `ChatResponse` or `ExecutionPlan` depending on execution mode |
| `POST` | `/api/v1/chat/embed` | JWT | Generate a 1536-dim text embedding vector (`text-embedding-3-small`). Used by Electron for local note search. |
| `WS` | `/api/v1/chat/stream` | JWT (query param `?token=`) | Streaming chat — first frame is a `ChatRequest`, server yields text chunks, final frame is `{"done": true, "response": "...", "actions": [...]}`. 30-second heartbeat ping. |
### Plans
@@ -351,42 +300,6 @@ All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebS
| `GET` | `/api/v1/plans/playbook` | JWT | List all cached execution plan playbooks |
| `GET` | `/api/v1/plans/playbook/{plan_id}` | JWT | Retrieve a specific playbook by ID |
### Storage (Cloud Records)
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/storage/records` | JWT | Upload an E2E encrypted record (verifies checksum, enforces storage quota) |
| `GET` | `/api/v1/storage/records` | JWT | List record metadata with pagination (`?table`, `?page`, `?limit`); no blob bytes returned |
| `GET` | `/api/v1/storage/records/{id}` | JWT | Download encrypted blob with `X-Checksum` response header |
| `PUT` | `/api/v1/storage/records/{id}` | JWT | Replace an existing blob (verifies checksum, enforces quota) |
| `DELETE` | `/api/v1/storage/records/{id}` | JWT | Delete a record and its S3 blob |
### Vectors (Cloud Vector Store)
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/storage/vectors/upsert` | JWT | Verify checksums and upsert encrypted vectors |
| `POST` | `/api/v1/storage/vectors/search` | JWT | Search user-scoped vector namespace |
| `DELETE` | `/api/v1/storage/vectors` | JWT | Delete vectors by ID list |
### Backup
| Method | Path | Auth | Description |
|---|---|---|---|
| `PUT` | `/api/v1/backup` | JWT | Upload encrypted backup blob with custom headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Tier quota enforced. |
| `GET` | `/api/v1/backup` | JWT | Download latest backup blob. Supports `If-Modified-Since`. |
| `GET` | `/api/v1/backup/history` | JWT | List backup metadata (no blob content) |
| `DELETE` | `/api/v1/backup/{backup_id}` | JWT | Delete a specific backup |
### Plugins (Marketplace)
| Method | Path | Auth | Description |
|---|---|---|---|
| `GET` | `/api/v1/plugins` | JWT (Power+) | Browse the marketplace (`?category`, `?q`, `?page`, `?sort=rating\|installs\|newest`) |
| `GET` | `/api/v1/plugins/{id}` | JWT (Power+) | Plugin detail with install count and ratings |
| `POST` | `/api/v1/plugins/{id}/install` | JWT (Power+) | Install plugin; triggers Stripe Connect revenue split for paid plugins |
| `DELETE` | `/api/v1/plugins/{id}/install` | JWT | Uninstall plugin |
### Billing
| Method | Path | Auth | Description |
@@ -400,7 +313,7 @@ All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebS
## Data Model
9 tables managed by Alembic migrations. Source: `app/models.py`
3 tables managed by Alembic migrations. Source: `app/models.py`
### Tables
@@ -409,27 +322,18 @@ All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebS
| `users` | `id` (UUID) | `email` (unique), `password_hash`, `tier`, `stripe_customer_id`, timestamps | User accounts |
| `refresh_tokens` | `id` (UUID) | `user_id` (FK), `token_hash` (SHA-256, unique), `expires_at` | Hashed refresh tokens for rotation |
| `subscriptions` | `id` (UUID) | `user_id` (FK, unique), `stripe_subscription_id`, `tier`, `status`, `current_period_end` | Stripe subscription records |
| `storage_records` | `id` (UUID) | `user_id` (FK), `table_name`, `s3_key`, `checksum`, `size_bytes`, timestamps | S3 blob metadata (no plaintext content) |
| `backup_metadata` | `id` (UUID) | `user_id` (FK), `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes` | Backup manifests |
| `plugins` | `id` (String) | `name`, `description`, `version`, `author_id` (FK), `category`, `price_cents`, `permissions` (JSON), `status`, `s3_package_key`, `install_count`, `avg_rating` | Marketplace plugin catalog |
| `plugin_installations` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), unique constraint on (`plugin_id`, `user_id`) | Per-user install tracking |
| `plugin_reviews` | `id` (UUID) | `plugin_id` (FK), `reviewer_id` (FK), `decision`, `notes`, `reviewed_at` | Admin review decisions |
| `revenue_events` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), `amount_cents`, `developer_share_cents`, `stripe_transfer_id` | 70/30 revenue split ledger |
### Enum Types
| Enum | Values |
|---|---|
| `billing_tier` | `free`, `pro`, `power`, `team` |
| `plugin_status` | `pending_review`, `approved`, `rejected` |
| `review_decision` | `approved`, `rejected` |
### Migrations
| Version | Description |
|---|---|
| `001_initial_schema` | Creates all 9 tables with indexes and foreign key constraints |
| `002_seed_plugins` | Seeds 3 approved plugins: GitHub Sync (free), Slack Notifier (€4.99), Time Tracker (€9.99) |
| `001_initial_schema` | Creates core auth and billing tables with indexes and foreign key constraints |
---
@@ -439,7 +343,7 @@ The agent system uses a registry pattern with LangChain tool-calling agents powe
### Architecture
- **`BaseAgent`** — Abstract base with `user_id`, `shared_memory`, and `vector_store_context`.
- **`BaseAgent`** — Abstract base with `user_id` and `shared_memory`.
- **`ChatAgent(BaseAgent)`** — Abstract `handle(query, context)` and `get_tools()` methods, plus a shared `_tool_loop(llm, messages, tools, max_iter=5)` for iterative tool calling.
- **`AgentRegistry`** — Singleton registry with `@register` decorator, `get(name)`, `list_agents()`, and `call_agent(name, query, context)`.
@@ -449,7 +353,7 @@ The agent system uses a registry pattern with LangChain tool-calling agents powe
|---|---|---|---|
| **TaskAgent** | `task_agent` | 8 | Full task and comment CRUD. Status: `todo` / `in_progress` / `done`. Priority: `high` / `medium` / `low`. Tools: `list_tasks`, `create_task`, `update_task`, `delete_task`, `list_tasks_due_today`, `list_task_comments`, `add_task_comment`, `delete_task_comment` |
| **ProjectAgent** | `project_agent` | 6 | Project lifecycle management. Status: `active` / `archived`. Prefers archiving over deletion. Tools: `list_projects`, `list_all_projects`, `get_project`, `create_project`, `update_project`, `delete_project` |
| **CheckpointAgent** | `checkpoint_agent` | 4 | Project milestones. Requires `project_id` for creation. Supports AI-suggestion and approval workflows. Tools: `list_checkpoints`, `create_checkpoint`, `update_checkpoint`, `delete_checkpoint` |
| **TimelineAgent** | `timeline_agent` | 4 | Project milestones. Requires `project_id` for creation. Supports AI-suggestion and approval workflows. Tools: `list_timelines`, `create_timeline`, `update_timeline`, `delete_timeline` |
| **NoteAgent** | `note_agent` | 5 | Markdown note management. Optionally linked to projects. Tools: `list_notes`, `get_note`, `create_note`, `update_note`, `delete_note` |
All agents use the model configured by `LLM_MODEL` (default: GPT-4o) with `temperature=0` via LiteLLM. Tools return JSON action descriptors that the Electron client interprets and applies locally.
@@ -504,7 +408,7 @@ Source: `app/core/orchestrator.py`, `app/core/execution_plan.py`
### Built-in Templates (6)
`tpl_task_agent_default`, `tpl_checkpoint_agent_default`, `tpl_project_agent_default`, `tpl_note_agent_default`, `tpl_task_extract_from_project`, `tpl_note_weekly_summary`
`tpl_task_agent_default`, `tpl_timeline_agent_default`, `tpl_project_agent_default`, `tpl_note_agent_default`, `tpl_task_extract_from_project`, `tpl_note_weekly_summary`
### Built-in Playbooks (2)
@@ -554,39 +458,6 @@ Source: `app/api/middleware/sanitizer.py`
- Scans JSON response bodies and replaces leaked prompt IP fragments with `[REDACTED]`.
- Detects: system prompt openers, agent routing metadata, LangChain tool schemas, internal reasoning markers (`<thinking>`, `[INST]`), and known prompt fingerprints.
- Logs sanitization events as `WARNING`.
- Binary responses (storage, backup) are never touched.
---
## Storage Layer
### Blob Store
Source: `app/storage/blob_store.py`
- S3-backed storage for E2E encrypted blobs.
- Object keys follow the pattern: `{user_id}/{table}/{record_id}`
- Server-side SSE-S3 encryption at rest (additional layer on top of client-side E2E encryption).
- Methods: `upload()`, `download()`, `delete()` (idempotent), `list_keys()`
- The backend **never inspects or decrypts blob content**.
### Vector Store
Source: `app/storage/vector_store.py`
- Runtime-configurable: **Pinecone** (when `PINECONE_API_KEY` is set) or **Qdrant** (fallback).
- User isolation: Pinecone uses `namespace=user_id`; Qdrant filters by `user_id` payload field.
- 32-dimensional SHA-256-derived float vectors (deterministic, not semantically meaningful on encrypted data — a documented trade-off for privacy).
- Encrypted blobs are stored as base64 in metadata/payload for verbatim retrieval.
- Methods: `upsert()`, `search()`, `delete()`
### Encryption Utilities
Source: `app/storage/encryption.py`
- `verify_checksum(blob, checksum)` — SHA-256 hash comparison using `hmac.compare_digest` (constant-time to prevent timing attacks).
- `reject_if_tampered(blob, checksum)` — Raises HTTP 400 on checksum mismatch.
- **No decryption key ever reaches the backend.**
---
@@ -600,11 +471,8 @@ Source: `app/billing/stripe_service.py`, `app/billing/tier_manager.py`
|---|---|---|---|---|
| AI Agents | 3 | Unlimited | Unlimited | Unlimited |
| Batch Active | 2 | 10 | Unlimited | Unlimited |
| Cloud Storage | 0 GB | 5 GB | 25 GB | Unlimited |
| Backup Storage | 0 GB | 5 GB | 25 GB | Unlimited |
| LLM Providers | 1 | Unlimited | Unlimited | Unlimited |
| Batch Builder | — | — | ✓ | ✓ |
| Plugin Marketplace | — | — | ✓ | ✓ |
| SSO | — | — | — | ✓ |
| Rate Limit | 20 req/min | 60 req/min | 120 req/min | 200 req/min |
@@ -620,47 +488,6 @@ Source: `app/billing/stripe_service.py`, `app/billing/tier_manager.py`
- `get_tier(user_id)` — Returns the user's current billing tier.
- `check_feature(tier, feature)` — Boolean feature gate check.
- `require_feature(tier, feature)` — Raises HTTP 403 if the feature is not available.
- `enforce_quota(user_id, tier)` / `enforce_backup_quota(user_id, tier)` — Raises HTTP 402 if storage limits are exceeded.
---
## Plugin Marketplace
Source: `app/marketplace/`
### Plugin Registry
- PostgreSQL-backed catalog of submitted and approved plugins.
- `list_plugins(db, category, query, page, sort)` — Paginated listing (page size: 20) with optional filtering by category, text search, and sorting by `rating`, `installs`, or `newest`.
- `get_plugin(db, plugin_id)` — Full manifest with install count and ratings.
- `submit_plugin(db, manifest, s3_key)` — Submits a plugin with `pending_review` status.
- `approve_plugin()` / `reject_plugin(reason)` — Admin workflow for plugin approval.
- `record_install()` / `record_uninstall()` — Tracks per-user installations and updates install counts.
### Review Queue
- Automated security checklist before human review:
- Plugin ID must match `^[a-z0-9-]+$`
- Permissions must be from the allowed set only
- No binary blobs in the manifest
- **Allowed permissions:** `read:tasks`, `write:tasks`, `read:projects`, `write:projects`, `read:notes`, `write:notes`, `read:checkpoints`, `write:checkpoints`, `read:calendar`, `write:calendar`
- `get_pending(db)` — Lists plugins awaiting review.
- `submit_review(db, plugin_id, reviewer_id, decision, notes)` — Records the review decision.
### Revenue Sharing
- **70% developer / 30% platform** split on all paid plugin sales.
- `record_install(db, plugin_id, user_id, amount_cents)` — Records the revenue event and triggers a Stripe Connect transfer for the developer share.
- `get_earnings(db, developer_id, period)` — Aggregated earnings report for plugin developers.
- Gracefully stubs transfers when Stripe is not configured.
### Seed Plugins
| Plugin | Category | Price |
|---|---|---|
| GitHub Sync | Productivity | Free |
| Slack Notifier | Communication | €4.99 |
| Time Tracker | Productivity | €9.99 |
---
@@ -682,10 +509,8 @@ pytest -v
### Test Infrastructure
- **Database:** Async SQLite in-memory via `aiosqlite` + `StaticPool` — fast, no PostgreSQL needed.
- **S3 mock:** `moto[s3]` with a fixture that patches `BlobStore` settings.
- **Auth helpers:** `make_jwt(tier)` and `auth_header(tier)` generate per-tier test tokens.
- **Seed data:** Auto-creates one `User` + `Subscription` per tier (free/pro/power/team) before each test.
- **Plugin seeds:** Fixture adds 3 approved plugins for marketplace tests.
- **FK enforcement:** SQLite `PRAGMA foreign_keys=ON`.
- **No external dependencies** — all tests run fully offline.
@@ -694,13 +519,6 @@ pytest -v
| File | Coverage |
|---|---|
| `test_auth.py` | Register, login, token access, refresh, expiration |
| `test_orchestrator.py` | Intent classification, single agent routing, pipeline, plan mode |
| `test_agents.py` | Each agent with mocked LLM: registration, tools, handle method |
| `test_storage.py` | Create, list, download, update, delete records; checksum rejection; quota enforcement |
| `test_backup.py` | Upload, download, history, delete; tier-based storage limits |
| `test_plugins.py` | List, install, uninstall, revenue events, tier gate enforcement |
| `test_agent_registry.py` | Registry singleton, registration, lookup, listing |
| `test_execution_plan.py` | Plan builder, template registry, plan cache |
| `test_middleware.py` | Rate limiting by tier, sanitizer prompt leak detection |
---
@@ -708,9 +526,8 @@ pytest -v
## Project Structure
```
adiuva-api/
adiuvai-api/
├── alembic.ini # Alembic configuration
├── BACKEND_PLAN.md # Architecture & design decisions
├── docker-compose.yml # Docker Compose (app + PostgreSQL)
├── Dockerfile # Multi-stage production build
├── requirements.txt # Python dependencies
@@ -719,13 +536,12 @@ adiuva-api/
│ ├── env.py # Alembic environment config
│ ├── script.py.mako # Migration template
│ └── versions/
── 001_initial_schema.py # Tables, indexes, FKs
│ └── 002_seed_plugins.py # Seed marketplace plugins
── 001_initial_schema.py # Tables, indexes, FKs
├── app/ # Application source
│ ├── main.py # FastAPI app factory, middleware, routes
│ ├── db.py # Async SQLAlchemy engine & session
│ ├── models.py # SQLAlchemy ORM models (9 tables)
│ ├── models.py # SQLAlchemy ORM models
│ ├── schemas.py # Pydantic request/response schemas
│ │
│ ├── config/
@@ -734,53 +550,35 @@ adiuva-api/
│ ├── agents/ # LLM-powered domain agents
│ │ ├── task_agent.py # Task & comment CRUD (8 tools)
│ │ ├── project_agent.py # Project lifecycle (6 tools)
│ │ ├── checkpoint_agent.py # Milestones (4 tools)
│ │ ├── timeline_agent.py # Milestones (4 tools)
│ │ └── note_agent.py # Markdown notes (5 tools)
│ │
│ ├── core/ # Orchestration engine
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
│ │ ├── llm.py # LiteLLM factory (get_llm, get_router_llm)
│ │ ── orchestrator.py # Intent classification & routing
│ │ └── execution_plan.py # Plan builder, templates, cache
│ │ ── deep_agent.py # Deep agent orchestration
│ │
│ ├── api/ # HTTP layer
│ │ ├── deps.py # Shared FastAPI dependencies
│ │ ├── middleware/
│ │ │ ├── auth.py # JWT validation, live tier lookup
│ │ │ ├── rate_limit.py # Sliding-window tier rate limiter
│ │ │ └── sanitizer.py # Prompt IP leak protection
│ │ └── routes/
│ │ ├── auth.py # Register, login, refresh, me
│ │ ├── chat.py # Chat + WebSocket streaming
│ │ ├── plans.py # Execution plan playbooks
│ │ ├── storage.py # E2E encrypted record CRUD
│ │ ── vectors.py # Vector upsert, search, delete
│ │ ├── backup.py # Encrypted backup management
│ │ ├── plugins.py # Marketplace browse & install
│ │ └── billing.py # Stripe checkout & webhooks
│ │ ├── chat.py # Chat + embed endpoint
│ │ ├── billing.py # Stripe checkout, webhooks, subscription
│ │ ├── agents.py # Agent catalog, config, runs
│ │ ── device_ws.py # Persistent device WebSocket
│ │
── storage/ # Storage backends
├── blob_store.py # S3 blob storage
── vector_store.py # Pinecone / Qdrant vector store
│ │ └── encryption.py # Checksum verification utilities
│ │
│ ├── billing/ # Subscription management
│ │ ├── stripe_service.py # Stripe API integration
│ │ └── tier_manager.py # Feature matrix & quota enforcement
│ │
│ └── marketplace/ # Plugin ecosystem
│ ├── plugin_registry.py # Catalog CRUD & search
│ ├── plugin_review.py # Security checklist & review queue
│ └── revenue_share.py # 70/30 split & Stripe Connect
── billing/
├── stripe_service.py # Stripe API wrapper
── tier_manager.py # Feature matrix, rate limits
└── tests/ # Test suite
├── conftest.py # Fixtures: DB, S3, auth, seeds
├── conftest.py # Fixtures: DB, auth, seeds
├── test_auth.py
├── test_orchestrator.py
├── test_agents.py
├── test_storage.py
├── test_backup.py
├── test_plugins.py
├── test_agent_registry.py
├── test_execution_plan.py
└── test_middleware.py

View File

@@ -1,5 +1,4 @@
"""Initial schema: users, refresh_tokens, subscriptions, storage_records,
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
"""Initial schema: users, refresh_tokens, subscriptions.
Revision ID: 001
Revises:
@@ -21,18 +20,13 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ── Enum types ────────────────────────────────────────────────────────
billing_tier = postgresql.ENUM(
"free", "pro", "power", "team", name="billing_tier", create_type=False
)
plugin_status = postgresql.ENUM(
"pending_review", "approved", "rejected", name="plugin_status", create_type=False
)
review_decision = postgresql.ENUM(
"approved", "rejected", name="review_decision", create_type=False
)
for enum in (billing_tier, plugin_status, review_decision):
enum.create(op.get_bind(), checkfirst=True)
# ── Enum types — idempotent creation via exception handling ───────────
op.execute("""
DO $$ BEGIN
CREATE TYPE billing_tier AS ENUM ('free', 'pro', 'power', 'team');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
# ── users ─────────────────────────────────────────────────────────────
op.create_table(
@@ -40,7 +34,7 @@ def upgrade() -> None:
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("email", sa.String(255), nullable=False),
sa.Column("password_hash", sa.String(255), nullable=False),
sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier", create_type=False), nullable=False, server_default="free"),
sa.Column("tier", postgresql.ENUM("free", "pro", "power", "team", name="billing_tier", create_type=False), nullable=False, server_default="free"),
sa.Column("stripe_customer_id", sa.String(255), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
@@ -70,7 +64,7 @@ def upgrade() -> None:
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("stripe_subscription_id", sa.String(255), nullable=True),
sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier", create_type=False), nullable=False, server_default="free"),
sa.Column("tier", postgresql.ENUM("free", "pro", "power", "team", name="billing_tier", create_type=False), nullable=False, server_default="free"),
sa.Column("status", sa.String(50), nullable=False, server_default="free"),
sa.Column("current_period_end", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
@@ -81,122 +75,10 @@ def upgrade() -> None:
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
# ── storage_records ───────────────────────────────────────────────────
op.create_table(
"storage_records",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("table_name", sa.String(100), nullable=False),
sa.Column("s3_key", sa.String(500), nullable=False),
sa.Column("checksum", sa.String(64), nullable=False),
sa.Column("size_bytes", sa.Integer, nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
# ── backup_metadata ───────────────────────────────────────────────────
op.create_table(
"backup_metadata",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("s3_key", sa.String(500), nullable=False),
sa.Column("version", sa.Integer, nullable=False),
sa.Column("timestamp", sa.BigInteger, nullable=False),
sa.Column("checksum", sa.String(64), nullable=False),
sa.Column("size_bytes", sa.Integer, nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
# ── plugins ───────────────────────────────────────────────────────────
op.create_table(
"plugins",
sa.Column("id", sa.String(255), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("description", sa.Text, nullable=False, server_default=""),
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
sa.Column("category", sa.String(100), nullable=False, server_default=""),
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
sa.Column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status", create_type=False), nullable=False, server_default="pending_review"),
sa.Column("s3_package_key", sa.String(500), nullable=True),
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
sa.Column("rejection_reason", sa.Text, nullable=True),
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
)
# ── plugin_installations ──────────────────────────────────────────────
op.create_table(
"plugin_installations",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
)
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
# ── plugin_reviews ────────────────────────────────────────────────────
op.create_table(
"plugin_reviews",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
sa.Column("decision", sa.Enum("approved", "rejected", name="review_decision", create_type=False), nullable=False),
sa.Column("notes", sa.Text, nullable=True),
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
)
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
# ── revenue_events ────────────────────────────────────────────────────
op.create_table(
"revenue_events",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
def downgrade() -> None:
op.drop_table("revenue_events")
op.drop_table("plugin_reviews")
op.drop_table("plugin_installations")
op.drop_table("plugins")
op.drop_table("backup_metadata")
op.drop_table("storage_records")
op.drop_table("subscriptions")
op.drop_table("refresh_tokens")
op.drop_table("users")
op.execute("DROP TYPE IF EXISTS review_decision")
op.execute("DROP TYPE IF EXISTS plugin_status")
op.execute("DROP TYPE IF EXISTS billing_tier")

View File

@@ -1,92 +0,0 @@
"""Seed approved plugins: GitHub Sync, Slack Notifier, Time Tracker.
Revision ID: 002
Revises: 001
Create Date: 2026-03-03
"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "002"
down_revision: Union[str, None] = "001"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
_SEED_PLUGINS = [
{
"id": "plugin-github-sync",
"name": "GitHub Sync",
"description": "Sync tasks with GitHub Issues and pull requests.",
"version": "1.0.0",
"author_name": "Adiuva",
"category": "productivity",
"price_cents": 0,
"permissions": json.dumps(["read:tasks", "write:tasks"]),
"status": "approved",
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
{
"id": "plugin-slack-notify",
"name": "Slack Notifier",
"description": "Post task and checkpoint updates to Slack channels.",
"version": "1.2.0",
"author_name": "Adiuva",
"category": "communication",
"price_cents": 499,
"permissions": json.dumps(["read:tasks", "read:checkpoints"]),
"status": "approved",
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
{
"id": "plugin-time-tracker",
"name": "Time Tracker",
"description": "Track time spent on tasks with automatic reporting.",
"version": "0.9.1",
"author_name": "Third Party",
"category": "productivity",
"price_cents": 999,
"permissions": json.dumps(["read:tasks", "write:tasks"]),
"status": "approved",
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
]
def upgrade() -> None:
plugins = sa.table(
"plugins",
sa.column("id", sa.String),
sa.column("name", sa.String),
sa.column("description", sa.Text),
sa.column("version", sa.String),
sa.column("author_name", sa.String),
sa.column("category", sa.String),
sa.column("price_cents", sa.Integer),
sa.column("permissions", sa.Text),
sa.column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status")),
sa.column("s3_package_key", sa.String),
sa.column("install_count", sa.Integer),
sa.column("avg_rating", sa.Float),
)
op.bulk_insert(plugins, _SEED_PLUGINS)
def downgrade() -> None:
op.execute(
"DELETE FROM plugins WHERE id IN ("
"'plugin-github-sync', 'plugin-slack-notify', 'plugin-time-tracker'"
")"
)

View File

@@ -0,0 +1,127 @@
"""Add agent config and run log tables: local_agent_configs, cloud_agent_configs, agent_run_logs.
Revision ID: 003
Revises: 002
Create Date: 2026-03-05
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "003"
down_revision: Union[str, None] = "002"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ── Enum types — idempotent creation ──────────────────────────────────
op.execute("""
DO $$ BEGIN
CREATE TYPE agent_type AS ENUM ('local', 'cloud');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
op.execute("""
DO $$ BEGIN
CREATE TYPE agent_run_status AS ENUM ('running', 'success', 'error', 'partial');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
op.execute("""
DO $$ BEGIN
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
# ── local_agent_configs ───────────────────────────────────────────────
op.create_table(
"local_agent_configs",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("device_id", sa.String(255), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
# ── cloud_agent_configs ───────────────────────────────────────────────
op.create_table(
"cloud_agent_configs",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column(
"provider",
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
nullable=False,
),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
sa.Column("filter_config", sa.JSON, nullable=True),
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
# ── agent_run_logs ─────────────────────────────────────────────────────
op.create_table(
"agent_run_logs",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
# Plain string — not a FK because it references either local_agent_configs or
# cloud_agent_configs depending on agent_type.
sa.Column("agent_id", sa.String(255), nullable=False),
sa.Column(
"agent_type",
postgresql.ENUM("local", "cloud", name="agent_type", create_type=False),
nullable=False,
),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column(
"status",
postgresql.ENUM("running", "success", "error", "partial", name="agent_run_status", create_type=False),
nullable=False,
server_default="running",
),
sa.Column("items_processed", sa.Integer, nullable=False, server_default="0"),
sa.Column("items_created", sa.Integer, nullable=False, server_default="0"),
sa.Column("errors", sa.JSON, nullable=True),
sa.Column("started_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_agent_run_logs_user_id", "agent_run_logs", ["user_id"])
op.create_index("ix_agent_run_logs_agent_id", "agent_run_logs", ["agent_id"])
def downgrade() -> None:
op.drop_table("agent_run_logs")
op.drop_table("cloud_agent_configs")
op.drop_table("local_agent_configs")
op.execute("DROP TYPE IF EXISTS cloud_provider;")
op.execute("DROP TYPE IF EXISTS agent_run_status;")
op.execute("DROP TYPE IF EXISTS agent_type;")

View File

@@ -0,0 +1,144 @@
"""Add memory tables and user encryption_key column.
Memory tables:
memory_core — per-user key/value preferences (encrypted)
memory_associative — semantic memory with pgvector embedding (encrypted)
memory_episodic — session summaries (encrypted)
memory_proactive — behavioral patterns (encrypted)
Also adds encryption_key column to users table.
Revision ID: 004
Revises: 003
Create Date: 2026-03-08
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "004"
down_revision: Union[str, None] = "003"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ── Enable pgvector extension (idempotent) ────────────────────────────────
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
# ── Add encryption_key to users ───────────────────────────────────────────
op.add_column(
"users",
sa.Column("encryption_key", sa.String(64), nullable=True),
)
# ── memory_core ───────────────────────────────────────────────────────────
op.create_table(
"memory_core",
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=False),
sa.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("key", sa.String(255), nullable=False),
sa.Column("value_encrypted", sa.Text, nullable=False),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
op.create_index("ix_memory_core_user_id", "memory_core", ["user_id"])
# ── memory_associative ────────────────────────────────────────────────────
# The embedding column uses pgvector's vector(1536) type.
op.create_table(
"memory_associative",
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=False),
sa.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("content_encrypted", sa.Text, nullable=False),
sa.Column("entity_type", sa.String(100), nullable=True),
sa.Column("entity_id", sa.String(255), nullable=True),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
# Add the pgvector column separately (not supported by generic sa types)
op.execute(
"ALTER TABLE memory_associative ADD COLUMN embedding vector(1536);"
)
op.create_index("ix_memory_associative_user_id", "memory_associative", ["user_id"])
# IVFFlat index for approximate nearest-neighbour search
op.execute(
"CREATE INDEX ix_memory_associative_embedding "
"ON memory_associative USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);"
)
# ── memory_episodic ───────────────────────────────────────────────────────
op.create_table(
"memory_episodic",
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=False),
sa.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("summary_encrypted", sa.Text, nullable=False),
sa.Column("session_id", sa.String(255), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
op.create_index("ix_memory_episodic_user_id", "memory_episodic", ["user_id"])
op.create_index("ix_memory_episodic_session_id", "memory_episodic", ["session_id"])
# ── memory_proactive ──────────────────────────────────────────────────────
op.create_table(
"memory_proactive",
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=False),
sa.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("pattern_encrypted", sa.Text, nullable=False),
sa.Column("confidence", sa.Float, nullable=False, server_default="0.5"),
sa.Column("source", sa.String(50), nullable=False, server_default="inferred"),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
op.create_index("ix_memory_proactive_user_id", "memory_proactive", ["user_id"])
def downgrade() -> None:
op.drop_table("memory_proactive")
op.drop_table("memory_episodic")
op.drop_index("ix_memory_associative_embedding", "memory_associative")
op.drop_table("memory_associative")
op.drop_table("memory_core")
op.drop_column("users", "encryption_key")

View File

@@ -0,0 +1,30 @@
"""add name and surname to users table
Revision ID: 818478c251dc
Revises: 004
Create Date: 2026-03-10 15:10:42.811947
"""
from __future__ import annotations
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '818478c251dc'
down_revision: Union[str, None] = '004'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column('users', sa.Column('name', sa.String(length=100), nullable=True))
op.add_column('users', sa.Column('surname', sa.String(length=100), nullable=True))
def downgrade() -> None:
op.drop_column('users', 'surname')
op.drop_column('users', 'name')

View File

@@ -0,0 +1,92 @@
"""Deprecate backend agent config tables.
The Electron client is now the source of truth for agent configuration
(directory, extract targets, batch interval, custom prompt). Backend keeps
billing checks and trigger/run logs only.
Revision ID: 9a1f2d0b6c7e
Revises: 818478c251dc
Create Date: 2026-03-16
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "9a1f2d0b6c7e"
down_revision: Union[str, None] = "818478c251dc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
bind = op.get_bind()
inspector = sa.inspect(bind)
existing = set(inspector.get_table_names())
if "cloud_agent_configs" in existing:
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
op.drop_table("cloud_agent_configs")
if "local_agent_configs" in existing:
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
op.drop_table("local_agent_configs")
def downgrade() -> None:
op.create_table(
"local_agent_configs",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("device_id", sa.String(255), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
op.execute(
"""
DO $$ BEGIN
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
"""
)
op.create_table(
"cloud_agent_configs",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column(
"provider",
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
nullable=False,
),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
sa.Column("filter_config", sa.JSON, nullable=True),
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])

View File

@@ -0,0 +1,31 @@
"""add agent_config to local_agent_configs
Revision ID: a3b9c0d1e2f3
Revises: 9a1f2d0b6c7e
Create Date: 2026-04-07 00:00:00.000000
"""
from __future__ import annotations
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "a3b9c0d1e2f3"
down_revision: Union[str, None] = "9a1f2d0b6c7e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"local_agent_configs",
sa.Column("agent_config", sa.JSON(), nullable=True),
)
def downgrade() -> None:
op.drop_column("local_agent_configs", "agent_config")

View File

@@ -1,5 +1,5 @@
"""Import all agent modules to trigger @registry.register decorators."""
"""Expose tool modules used by deep orchestrator-worker graphs."""
from app.agents import checkpoint_agent, note_agent, project_agent, task_agent
from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent
__all__ = ["checkpoint_agent", "note_agent", "project_agent", "task_agent"]
__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"]

View File

@@ -1,121 +0,0 @@
"""Checkpoint agent — project milestone management (list, create, update, delete)."""
from __future__ import annotations
import json
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import get_llm
_SYSTEM_PROMPT = (
"You are a project checkpoint assistant. Checkpoints are milestone dates that\n"
"track progress on a project — they are not calendar events.\n\n"
"Rules:\n"
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
" - is_ai_suggested: 1 when proactively proposing a checkpoint, 0 otherwise\n"
" - is_approved: 0 until the user explicitly confirms; then 1\n"
" - For update_checkpoint, use -1 for integer fields you do not want to change\n"
" - Listing without a project_id returns all checkpoints across projects\n"
" - Always echo the title and formatted date in your confirmation."
)
@tool
async def list_checkpoints(project_id: str = "") -> str:
"""List checkpoints. Provide project_id to scope to a specific project."""
return json.dumps({
"action": "list",
"table": "checkpoints",
"filters": {"projectId": project_id or None},
})
@tool
async def create_checkpoint(
project_id: str,
title: str,
date: int,
is_ai_suggested: int = 0,
is_approved: int = 0,
) -> str:
"""Create a project checkpoint (milestone).
project_id: REQUIRED UUID of the parent project
title: descriptive name for the milestone
date: Unix timestamp in milliseconds
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
is_approved: 0 until the user confirms
"""
return json.dumps({
"action": "create_record",
"table": "checkpoints",
"data": {
"projectId": project_id,
"title": title,
"date": date,
"isAiSuggested": is_ai_suggested,
"isApproved": is_approved,
},
})
@tool
async def update_checkpoint(
checkpoint_id: str,
title: str = "",
date: int = -1,
is_approved: int = -1,
) -> str:
"""Update a checkpoint. Only pass fields that should change.
checkpoint_id: UUID of the checkpoint (required)
date: -1 means unchanged; any other value sets the new date (ms timestamp)
is_approved: -1 means unchanged; 0 or 1 sets the approval state
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if date != -1:
updates["date"] = date
if is_approved != -1:
updates["isApproved"] = is_approved
return json.dumps({
"action": "update_record",
"table": "checkpoints",
"data": {"id": checkpoint_id, "updates": updates},
})
@tool
async def delete_checkpoint(checkpoint_id: str) -> str:
"""Delete a checkpoint permanently by its UUID."""
return json.dumps({
"action": "delete_record",
"table": "checkpoints",
"data": {"id": checkpoint_id},
})
@registry.register
class CheckpointAgent(ChatAgent):
def get_name(self) -> str:
return "checkpoint_agent"
def get_description(self) -> str:
return "Manages project checkpoints (milestones): list, create, update, delete"
def get_tools(self) -> list[Any]:
return [list_checkpoints, create_checkpoint, update_checkpoint, delete_checkpoint]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

View File

@@ -0,0 +1,85 @@
"""Filesystem agent — tools for reading local directories and files on Electron.
These tools delegate to the Electron client via ``execute_on_client()`` using
the same WS tool-call round-trip pattern as CRUD tools. The Electron app
handles actual disk I/O and responds with ``tool_result`` frames.
"""
from __future__ import annotations
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
@tool
async def list_directory(path: str) -> str:
"""List files and folders in a local directory on the user's device.
Returns a formatted listing of entries with name, type (file/directory),
and full path.
"""
result = await execute_on_client(
action="list_directory",
data={"path": path},
)
entries: list[dict[str, Any]] = result.get("entries", [])
if not entries:
return f"Directory '{path}' is empty or does not exist."
lines: list[str] = []
for entry in entries:
entry_type = entry.get("type", "unknown")
entry_name = entry.get("name", "")
entry_path = entry.get("path", "")
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
@tool
async def read_file_content(path: str) -> str:
"""Read the text content of a local file on the user's device.
Returns the file content as a string. Large files may be truncated
by the Electron client.
"""
result = await execute_on_client(
action="read_file_content",
data={"path": path},
)
content: str = result.get("content", "")
if not content:
return f"File '{path}' is empty or could not be read."
return content
@tool
async def get_file_metadata(path: str) -> str:
"""Get metadata for a local file: size, creation date, modification date, extension.
Returns a formatted summary of the file's metadata.
"""
result = await execute_on_client(
action="get_file_metadata",
data={"path": path},
)
size = result.get("size", "unknown")
created = result.get("createdAt", "unknown")
modified = result.get("modifiedAt", "unknown")
extension = result.get("extension", "unknown")
name = result.get("name", path)
return (
f"File: {name}\n"
f" Extension: {extension}\n"
f" Size: {size} bytes\n"
f" Created: {created}\n"
f" Modified: {modified}"
)
FILESYSTEM_TOOLS: list[Any] = [
list_directory,
read_file_content,
get_file_metadata,
]

View File

@@ -2,16 +2,23 @@
from __future__ import annotations
import json
import re
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import get_llm
from app.core.llm import embed
from app.core.ws_context import execute_on_client
_SYSTEM_PROMPT = (
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
NOTE_SYSTEM_PROMPT = (
"You are a note-taking assistant. You help users create, retrieve, update,\n"
"and delete Markdown notes in their workspace.\n\n"
"Rules:\n"
@@ -21,6 +28,7 @@ _SYSTEM_PROMPT = (
" before appending or replacing sections\n"
" - list_notes without project_id returns all notes; scope with project_id\n"
" when the user is working within a specific project\n"
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
" - Do not fabricate note content — reflect what the user provides or what\n"
" is already in the note (retrieved via get_note)."
)
@@ -29,21 +37,27 @@ _SYSTEM_PROMPT = (
@tool
async def list_notes(project_id: str = "") -> str:
"""List notes, optionally scoped to a project by project_id."""
return json.dumps({
"action": "list",
"table": "notes",
"filters": {"projectId": project_id or None},
})
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="notes",
filters={"projectId": normalized_project_id or None},
)
rows = result.get("rows", [])
if not rows:
return "No notes found."
lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
@tool
async def get_note(note_id: str) -> str:
"""Fetch a single note by its UUID to read its full Markdown content."""
return json.dumps({
"action": "get",
"table": "notes",
"data": {"id": note_id},
})
result = await execute_on_client(action="get", table="notes", data={"id": note_id})
row = result.get("row")
if not row:
return f"Note {note_id} not found."
return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}"
@tool
@@ -57,15 +71,24 @@ async def create_note(
content: Markdown body text (required)
project_id: optional UUID linking this note to a project
"""
return json.dumps({
"action": "create_record",
"table": "notes",
"data": {
result = await execute_on_client(
action="insert",
table="notes",
data={
"title": title,
"content": content,
"projectId": project_id or None,
},
})
)
row = result["row"]
# Index the note content in the vector store.
vector = await embed(content)
await execute_on_client(
action="vector_upsert",
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
vector=vector,
)
return f"Note created: '{row['title']}' (id: {row['id']})."
@tool
@@ -83,40 +106,34 @@ async def update_note(
updates["title"] = title
if content:
updates["content"] = content
return json.dumps({
"action": "update_record",
"table": "notes",
"data": {"id": note_id, "updates": updates},
})
result = await execute_on_client(
action="update",
table="notes",
data={"id": note_id, "updates": updates},
)
row = result["row"]
# Re-index if content changed.
if content:
vector = await embed(content)
await execute_on_client(
action="vector_upsert",
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
vector=vector,
)
return f"Note updated: '{row['title']}' (id: {row['id']})."
@tool
async def delete_note(note_id: str) -> str:
"""Delete a note permanently by its UUID."""
return json.dumps({
"action": "delete_record",
"table": "notes",
"data": {"id": note_id},
})
await execute_on_client(action="delete", table="notes", data={"id": note_id})
return f"Note {note_id} deleted."
@registry.register
class NoteAgent(ChatAgent):
def get_name(self) -> str:
return "note_agent"
def get_description(self) -> str:
return "Manages notes: list, get, create, update, delete"
def get_tools(self) -> list[Any]:
return [list_notes, get_note, create_note, update_note, delete_note]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())
NOTE_TOOLS: list[Any] = [
list_notes,
get_note,
create_note,
update_note,
delete_note,
]

View File

@@ -2,16 +2,13 @@
from __future__ import annotations
import json
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import get_llm
from app.core.ws_context import execute_on_client
_SYSTEM_PROMPT = (
PROJECT_SYSTEM_PROMPT = (
"You are a project management assistant. You help users create, find,\n"
"update, and archive projects in their workspace.\n\n"
"Rules:\n"
@@ -36,14 +33,19 @@ async def list_projects(
"""List projects, optionally filtered by client_id.
include_archived: 1 to include archived projects, 0 for active only (default).
"""
return json.dumps({
"action": "list",
"table": "projects",
"filters": {
result = await execute_on_client(
action="select",
table="projects",
filters={
"clientId": client_id or None,
"includeArchived": bool(include_archived),
},
})
)
rows = result.get("rows", [])
if not rows:
return "No projects found."
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
return f"Found {len(rows)} project(s):\n" + "\n".join(lines)
@tool
@@ -51,20 +53,25 @@ async def list_all_projects() -> str:
"""List every project regardless of client or status.
Use only when the user wants a complete cross-client overview.
"""
return json.dumps({
"action": "list_all",
"table": "projects",
})
result = await execute_on_client(action="select", table="projects")
rows = result.get("rows", [])
if not rows:
return "No projects found."
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
return f"All projects ({len(rows)}):\n" + "\n".join(lines)
@tool
async def get_project(project_id: str) -> str:
"""Fetch a single project by its UUID."""
return json.dumps({
"action": "get",
"table": "projects",
"data": {"id": project_id},
})
result = await execute_on_client(action="get", table="projects", data={"id": project_id})
row = result.get("row")
if not row:
return f"Project {project_id} not found."
return (
f"Project: '{row['name']}' (id: {row['id']}, status: {row['status']}, "
f"clientId: {row.get('clientId', 'none')})"
)
@tool
@@ -76,14 +83,13 @@ async def create_project(
name: human-readable project name (required)
client_id: optional UUID of the owning client
"""
return json.dumps({
"action": "create_record",
"table": "projects",
"data": {
"name": name,
"clientId": client_id or None,
},
})
result = await execute_on_client(
action="insert",
table="projects",
data={"name": name, "clientId": client_id or None},
)
row = result["row"]
return f"Project created: '{row['name']}' (id: {row['id']})"
@tool
@@ -108,11 +114,13 @@ async def update_project(
updates["status"] = status
if ai_summary:
updates["aiSummary"] = ai_summary
return json.dumps({
"action": "update_record",
"table": "projects",
"data": {"id": project_id, "updates": updates},
})
result = await execute_on_client(
action="update",
table="projects",
data={"id": project_id, "updates": updates},
)
row = result["row"]
return f"Project updated: '{row['name']}' (id: {row['id']}, status: {row['status']})"
@tool
@@ -121,37 +129,15 @@ async def delete_project(project_id: str) -> str:
IMPORTANT: prefer update_project(status='archived') unless the user
has explicitly confirmed they want permanent deletion.
"""
return json.dumps({
"action": "delete_record",
"table": "projects",
"data": {"id": project_id},
})
await execute_on_client(action="delete", table="projects", data={"id": project_id})
return f"Project {project_id} permanently deleted."
@registry.register
class ProjectAgent(ChatAgent):
def get_name(self) -> str:
return "project_agent"
def get_description(self) -> str:
return "Manages projects: list, get, create, update, archive, delete"
def get_tools(self) -> list[Any]:
return [
list_projects,
list_all_projects,
get_project,
create_project,
update_project,
delete_project,
]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())
PROJECT_TOOLS: list[Any] = [
list_projects,
list_all_projects,
get_project,
create_project,
update_project,
delete_project,
]

View File

@@ -2,16 +2,23 @@
from __future__ import annotations
import json
from datetime import datetime, timezone
import re
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import get_llm
from app.core.ws_context import execute_on_client
_SYSTEM_PROMPT = (
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
TASK_SYSTEM_PROMPT = (
"You are a task management assistant for a project workspace.\n"
"You create, update, list, and track tasks and their comments.\n\n"
"Rules:\n"
@@ -22,7 +29,7 @@ _SYSTEM_PROMPT = (
" - project_id is optional; link to a project when the user mentions one\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
" did not explicitly request; 0 otherwise\n"
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user did not explicitly request; 0 otherwise\n"
" - Use list_tasks_due_today for 'what's due today' queries\n"
" - For update_task, use -1 for integer fields you do not want to change\n"
" - Always confirm the action in plain, user-friendly language."
@@ -41,16 +48,25 @@ async def list_tasks(
) -> str:
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
a search string, or an order_by field name (dueDate|priority|createdAt)."""
return json.dumps({
"action": "list",
"table": "tasks",
"filters": {
"projectId": project_id or None,
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="tasks",
filters={
"projectId": normalized_project_id or None,
"status": status or None,
"search": search or None,
"orderBy": order_by or None,
},
})
)
rows = result.get("rows", [])
if not rows:
return "No tasks found matching the given filters."
lines = [
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})"
for r in rows
]
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
@tool
@@ -63,7 +79,6 @@ async def create_task(
due_date: int = 0,
project_id: str = "",
is_ai_suggested: int = 0,
is_approved: int = 0,
) -> str:
"""Create a new task.
title: task title (required)
@@ -74,12 +89,11 @@ async def create_task(
due_date: Unix timestamp in milliseconds; 0 means no due date
project_id: optional UUID of the parent project
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
is_approved: 0 until the user confirms; 1 when confirmed
"""
return json.dumps({
"action": "create_record",
"table": "tasks",
"data": {
result = await execute_on_client(
action="insert",
table="tasks",
data={
"title": title,
"description": description or None,
"status": status,
@@ -88,9 +102,13 @@ async def create_task(
"dueDate": due_date or None,
"projectId": project_id or None,
"isAiSuggested": is_ai_suggested,
"isApproved": is_approved,
},
})
)
row = result["row"]
return (
f"Task created: '{row['title']}' "
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
)
@tool
@@ -103,12 +121,10 @@ async def update_task(
assignees: str = "",
due_date: int = -1,
project_id: str = "",
is_approved: int = -1,
) -> str:
"""Update fields on an existing task. Only pass fields you want to change.
task_id: the task's UUID (required)
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
is_approved: -1 means unchanged; 0 or 1 sets the value
"""
updates: dict[str, Any] = {}
if title:
@@ -125,32 +141,41 @@ async def update_task(
updates["dueDate"] = due_date or None
if project_id:
updates["projectId"] = project_id
if is_approved != -1:
updates["isApproved"] = is_approved
return json.dumps({
"action": "update_record",
"table": "tasks",
"data": {"id": task_id, "updates": updates},
})
result = await execute_on_client(
action="update",
table="tasks",
data={"id": task_id, "updates": updates},
)
row = result["row"]
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
@tool
async def delete_task(task_id: str) -> str:
"""Delete a task permanently by its UUID."""
return json.dumps({
"action": "delete_record",
"table": "tasks",
"data": {"id": task_id},
})
await execute_on_client(action="delete", table="tasks", data={"id": task_id})
return f"Task {task_id} deleted."
@tool
async def list_tasks_due_today() -> str:
"""List all tasks whose due date falls on today's date."""
return json.dumps({
"action": "list_due_today",
"table": "tasks",
})
now = datetime.now(tz=timezone.utc)
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
end_ms = start_ms + 86_400_000 - 1 # last ms of today
result = await execute_on_client(
action="select",
table="tasks",
filters={"dueDateFrom": start_ms, "dueDateTo": end_ms},
)
rows = result.get("rows", [])
if not rows:
return "No tasks are due today."
lines = [
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})"
for r in rows
]
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
# ── Task comment tools ────────────────────────────────────────────────
@@ -159,11 +184,16 @@ async def list_tasks_due_today() -> str:
@tool
async def list_task_comments(task_id: str) -> str:
"""List all comments on a task by its UUID."""
return json.dumps({
"action": "list",
"table": "taskComments",
"filters": {"taskId": task_id},
})
result = await execute_on_client(
action="select",
table="taskComments",
filters={"taskId": task_id},
)
rows = result.get("rows", [])
if not rows:
return f"No comments found for task {task_id}."
lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows]
return f"Found {len(rows)} comment(s):\n" + "\n".join(lines)
@tool
@@ -173,56 +203,36 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
author: name or ID of the comment author
content: comment text
"""
return json.dumps({
"action": "create_record",
"table": "taskComments",
"data": {
"taskId": task_id,
"author": author,
"content": content,
},
})
result = await execute_on_client(
action="insert",
table="taskComments",
data={"taskId": task_id, "author": author, "content": content},
)
row = result.get("row", {})
row_author = row.get("author", author)
# Electron payloads can vary (taskId vs task_id). Fall back to input task_id.
row_task_id = row.get("taskId") or row.get("task_id") or task_id
row_comment_id = row.get("id", "unknown")
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
@tool
async def delete_task_comment(comment_id: str) -> str:
"""Delete a task comment by its UUID."""
return json.dumps({
"action": "delete_record",
"table": "taskComments",
"data": {"id": comment_id},
})
await execute_on_client(action="delete", table="taskComments", data={"id": comment_id})
return f"Comment {comment_id} deleted."
# ── Agent ─────────────────────────────────────────────────────────────
@registry.register
class TaskAgent(ChatAgent):
def get_name(self) -> str:
return "task_agent"
def get_description(self) -> str:
return "Manages tasks and comments: list, create, update, delete, due-today, comments"
def get_tools(self) -> list[Any]:
return [
list_tasks,
create_task,
update_task,
delete_task,
list_tasks_due_today,
list_task_comments,
add_task_comment,
delete_task_comment,
]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())
TASK_TOOLS: list[Any] = [
list_tasks,
create_task,
update_task,
delete_task,
list_tasks_due_today,
list_task_comments,
add_task_comment,
delete_task_comment,
]

View File

@@ -0,0 +1,114 @@
"""Timeline agent — project milestone management (list, create, update, delete)."""
from __future__ import annotations
import re
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
TIMELINE_SYSTEM_PROMPT = (
"You are a project timeline assistant. Timelines are milestone dates that\n"
"track progress on a project — they are not calendar events.\n\n"
"Rules:\n"
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
" - For update_timeline, use -1 for integer fields you do not want to change\n"
" - Listing without a project_id returns all timelines across projects\n"
" - Always echo the title and formatted date in your confirmation."
)
@tool
async def list_timelines(project_id: str = "") -> str:
"""List timelines. Provide project_id to scope to a specific project."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="timelines",
filters={"projectId": normalized_project_id or None},
)
rows = result.get("rows", [])
if not rows:
return "No timelines found."
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines)
@tool
async def create_timeline(
project_id: str,
title: str,
date: int,
is_ai_suggested: int = 0,
) -> str:
"""Create a project timeline (milestone).
project_id: REQUIRED UUID of the parent project
title: descriptive name for the milestone
date: Unix timestamp in milliseconds
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
"""
result = await execute_on_client(
action="insert",
table="timelines",
data={
"projectId": project_id,
"title": title,
"date": date,
"isAiSuggested": is_ai_suggested,
},
)
row = result["row"]
return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
@tool
async def update_timeline(
timeline_id: str,
title: str = "",
date: int = -1,
) -> str:
"""Update a timeline. Only pass fields that should change.
timeline_id: UUID of the timeline (required)
date: -1 means unchanged; any other value sets the new date (ms timestamp)
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if date != -1:
updates["date"] = date
result = await execute_on_client(
action="update",
table="timelines",
data={"id": timeline_id, "updates": updates},
)
row = result["row"]
return f"Timeline updated: '{row['title']}' (id: {row['id']})"
@tool
async def delete_timeline(timeline_id: str) -> str:
"""Delete a timeline permanently by its UUID."""
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
return f"Timeline {timeline_id} deleted."
TIMELINE_TOOLS: list[Any] = [
list_timelines,
create_timeline,
update_timeline,
delete_timeline,
]

View File

@@ -55,11 +55,26 @@ async def get_current_user(
raise credentials_exc
# Live tier lookup — subscription row is the authoritative source.
from app.models import Subscription # noqa: PLC0415
# In dev, fall back to 'power' (unlimited) so quota limits don't
# block local development when no Stripe subscription exists.
from app.models import Subscription, User # noqa: PLC0415
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
tier: str = result.scalar_one_or_none() or "free"
default_tier = "power" if settings.ENV == "dev" else "free"
tier: str = result.scalar_one_or_none() or default_tier
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]
# Fetch name/surname from user row.
user_result = await db.execute(
select(User.name, User.surname).where(User.id == user_id)
)
user_row = user_result.one_or_none()
return UserProfile(
id=user_id,
email=email,
name=user_row.name if user_row else None,
surname=user_row.surname if user_row else None,
tier=tier,
) # type: ignore[arg-type]

View File

@@ -8,8 +8,7 @@ that could reveal server-side prompt IP:
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
- Exact-match known prompt fingerprints
Binary responses (storage blobs, backup data) are never touched — the
middleware only activates for paths under /api/v1/chat.
The middleware only activates for paths under /api/v1/chat.
Any sanitisation event is logged as a WARNING with the request path and the
names of the fields that were modified.

View File

@@ -0,0 +1,495 @@
"""Chatbot Journey — WS-based guided conversation to build an AgentConfig.
The journey is driven entirely through WebSocket frames (no REST endpoints).
The device WS handler dispatches ``journey_start`` and ``journey_message``
frames to the functions exported here.
Journey flow:
1. FE sends ``journey_start`` frame with basic agent info (directory,
data_types, schedule).
2. Server creates an in-memory session, sets up a WS executor so the
setup LLM can use file-system tools, does a first directory scrape,
and sends back a ``journey_reply`` with the first question.
3. FE sends ``journey_message`` frames for each user reply.
4. Server appends the user message, calls the LLM (which may read files
via tools), and sends back a ``journey_reply``.
5. After 3-5 turns the LLM wraps up by emitting an ``AgentConfig`` JSON
block delimited by ``AGENT_CONFIG_START`` / ``AGENT_CONFIG_END``.
6. Server parses and validates the JSON with Pydantic, sends
``journey_reply`` with ``done=True`` and the serialised config.
FE stores it locally.
"""
from __future__ import annotations
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
from app.config.settings import settings
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
from app.core.llm import get_llm
from app.schemas import AgentConfig
logger = logging.getLogger(__name__)
# ── Session TTL ───────────────────────────────────────────────────────────
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
# Sentinel strings used to delimit the LLM-produced AgentConfig JSON.
_CONFIG_START = "AGENT_CONFIG_START"
_CONFIG_END = "AGENT_CONFIG_END"
# Minimum turns before we consider nudging the LLM to wrap up.
_MIN_TURNS_BEFORE_NUDGE: int = 3
# Hard cap to avoid infinite loops (safety net, not the primary stopping criterion).
_MAX_TURNS: int = 15
# Max tool-calling steps per LLM invocation.
_MAX_TOOL_STEPS: int = 6
# ── In-memory session store ───────────────────────────────────────────────
@dataclass
class JourneySession:
session_id: str
user_id: str
agent_type: str # "local" | "cloud"
directory: str
data_types: list[str]
history: list[dict[str, Any]] = field(default_factory=list)
system_prompt: str = ""
langfuse_prompt: Any = None
created_at: float = field(default_factory=time.monotonic)
def is_expired(self) -> bool:
return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS
# session_id → session
_sessions: dict[str, JourneySession] = {}
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
"""Retrieve session; return None on missing, expired, or wrong owner."""
s = _sessions.get(session_id)
if s is None or s.is_expired():
_sessions.pop(session_id, None)
return None
if s.user_id != user_id:
return None
return s
# ── System prompt ─────────────────────────────────────────────────────────
_JOURNEY_SYSTEM_PROMPT = """\
You are a friendly assistant helping a freelancer configure a data-extraction agent.
Your job is to understand what files the user has in their directory and produce a
structured AgentConfig JSON that the extraction agent will use as its instruction set.
You have access to file-system tools to explore the user's directory:
- list_directory: see folder structure and file names
- read_file_content: peek at a file's content
- get_file_metadata: check file size, extension, dates
The user's configured directory is: {directory}
Target data types: {data_types}
## Your process
### Step 1 — Explore the directory
Use list_directory and read_file_content to understand what types of files are present
(HTML emails, plain-text documents, CSVs, etc.).
### Step 2 — Identify content types
For each distinct file type found, decide:
- A short id (e.g. "email_html", "plain_text", "csv")
- Which preprocessing handler to use: "email_html" for HTML emails, "generic" for everything else
- A human-readable label and optional detection_hint
### Step 3 — Ask focused questions (one at a time)
Cover these topics based on what you discovered:
1. How to map content to entity types (task / note / timeline entry)
2. Field mapping rules (e.g. email Subject → task title, filename → note title)
3. Priority or status rules (e.g. "urgent" in subject → high priority)
4. Date extraction (e.g. "by Friday" → dueDate)
5. Exclusion rules (e.g. skip newsletters, skip files with no project match)
### Step 4 — Produce the AgentConfig JSON
Once you are ≥ 90% confident, output the final config between these exact markers
(each on its own line):
{config_start}
{{
"content_types": [
{{
"id": "email_html",
"label": "Email HTML",
"detection_hint": "HTML file with From/To/Subject headers",
"preprocessing": "email_html",
"extraction_prompt": "Detailed extraction instructions for this content type..."
}}
],
"global_rules": [
"If the file cannot be matched to any project, do not create any entity."
],
"data_types": {data_types_json}
}}
{config_end}
## Rules for the extraction_prompt field
- Describe when to create a task vs note vs timeline entry (be specific and concrete)
- Include field mapping rules based on what you found in the directory
- Include priority/status/date rules if applicable
- Do NOT include projectId logic — the runner handles project assignment automatically
- Do NOT mention isAiSuggested — the runner always sets it to 1
## Constraints
- Never ask about projects, projectId, or how to link records to projects
- Never include projectId or project creation logic in the generated config
- Keep asking questions until ≥ 90% confident, then output the JSON immediately
{existing_section}\
Begin by exploring the directory, then ask your first question.\
"""
def _build_system_prompt(
directory: str,
data_types: list[str],
existing_config: str | None = None,
) -> tuple[str, Any]:
"""Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``."""
existing_section = (
"\nThe user already has the following AgentConfig — refine it based on their answers:\n"
f"```json\n{existing_config}\n```\n"
if existing_config
else ""
)
template, prompt_obj = get_prompt_or_fallback(
"journey_system", _JOURNEY_SYSTEM_PROMPT
)
compiled = compile_prompt(
template,
prompt_obj,
directory=directory,
data_types=", ".join(data_types),
data_types_json=json.dumps(data_types),
config_start=_CONFIG_START,
config_end=_CONFIG_END,
existing_section=existing_section,
)
return compiled, prompt_obj
# ── AgentConfig extraction ────────────────────────────────────────────────
def _extract_agent_config(text: str) -> str | None:
"""Return validated AgentConfig JSON string from between markers, or None.
Parses the JSON with Pydantic to ensure it conforms to the schema before
returning. Returns None if markers are absent or JSON is invalid.
"""
if _CONFIG_START not in text or _CONFIG_END not in text:
return None
start_idx = text.index(_CONFIG_START) + len(_CONFIG_START)
end_idx = text.index(_CONFIG_END)
raw = text[start_idx:end_idx].strip()
if not raw:
return None
try:
parsed = AgentConfig.model_validate_json(raw)
return parsed.model_dump_json()
except Exception as exc:
logger.warning("agent_setup: failed to parse AgentConfig JSON: %s", exc)
return None
# ── LLM call with tool support ───────────────────────────────────────────
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
async def _call_llm_with_tools(
system_prompt: str,
history: list[dict[str, Any]],
tools: list[Any],
*,
user_id: str = "",
session_id: str = "",
langfuse_prompt: Any = None,
) -> str:
"""Build LangChain messages from history and invoke the LLM with tools.
Handles tool-calling loops: if the LLM calls tools, execute them and
continue until a final text response is produced.
"""
lf = get_langfuse()
messages: list[Any] = [SystemMessage(content=system_prompt)]
for turn in history:
if turn["role"] == "user":
messages.append(HumanMessage(content=turn["content"]))
else:
messages.append(AIMessage(content=turn["content"]))
llm = get_llm(model=None, temperature=0.4)
llm_with_tools = llm.bind_tools(tools)
tool_map = {tool_def.name: tool_def for tool_def in tools}
_span_ctx = (
lf.start_as_current_observation(
as_type="span",
name="journey-setup",
metadata={"user_id": user_id or None, "session_id": session_id or None},
input=history[-1]["content"] if history else "",
)
if lf else None
)
_span = _span_ctx.__enter__() if _span_ctx else None
try:
for _ in range(_MAX_TOOL_STEPS):
_gen_ctx = (
lf.start_as_current_observation(
as_type="generation",
name="journey-setup-llm",
model=settings.LLM_MODEL,
prompt=langfuse_prompt,
input=messages,
)
if lf else None
)
_gen = _gen_ctx.__enter__() if _gen_ctx else None
response: AIMessage = await llm_with_tools.ainvoke(messages)
if _gen_ctx:
_gen.update(output=_as_text(response.content), usage=extract_usage(response))
_gen_ctx.__exit__(None, None, None)
messages.append(response)
if not response.tool_calls:
if _span:
_span.update(output=_as_text(response.content))
return _as_text(response.content)
for call in response.tool_calls:
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"agent_setup: journey tool_call name=%s args=%s",
call_name,
json.dumps(call_args, ensure_ascii=True)[:500],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"agent_setup: journey tool_result name=%s output=%s",
call_name,
str(tool_output)[:800],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
# Fallback: exceeded max steps.
final = await llm.ainvoke(messages)
final_text = _as_text(final.content)
if _span:
_span.update(output=final_text)
return final_text
finally:
if _span_ctx:
_span_ctx.__exit__(None, None, None)
if lf:
lf.flush()
# ── Journey handlers (called from device_ws.py) ──────────────────────────
async def handle_journey_start(
user_id: str,
frame: dict[str, Any],
) -> dict[str, Any]:
"""Handle a ``journey_start`` WS frame.
Creates a session, runs the setup LLM with directory exploration,
and returns the ``journey_reply`` payload.
"""
agent_type = frame.get("agent_type", "local")
directory = frame.get("directory", "")
data_types = frame.get("data_types", [])
existing_config = frame.get("existing_config")
# Use the session_id provided by the FE so the reply matches the
# listener key; fall back to a generated one if absent.
session_id = frame.get("session_id") or str(uuid.uuid4())
system_prompt, langfuse_prompt = _build_system_prompt(directory, data_types, existing_config)
session = JourneySession(
session_id=session_id,
user_id=user_id,
agent_type=agent_type,
directory=directory,
data_types=data_types,
system_prompt=system_prompt,
langfuse_prompt=langfuse_prompt,
)
# Seed with an initial user message — some providers require at least one
# user/input message to be present.
seed_history: list[dict[str, Any]] = [
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
]
ai_reply = await _call_llm_with_tools(
system_prompt=system_prompt,
history=seed_history,
tools=list(FILESYSTEM_TOOLS),
user_id=user_id,
session_id=session_id,
langfuse_prompt=langfuse_prompt,
)
session.history.extend(seed_history)
session.history.append({"role": "assistant", "content": ai_reply})
_sessions[session_id] = session
logger.info(
"agent_setup: journey session %s started for user %s (directory=%s)",
session_id,
user_id,
directory,
)
# Check if the LLM produced the config on the first turn (unlikely but possible).
agent_config = _extract_agent_config(ai_reply)
done = agent_config is not None
display_message = ai_reply
if done:
display_message = (
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
or "Here is your agent configuration. You can save it or continue refining."
)
_sessions.pop(session_id, None)
return {
"type": "journey_reply",
"session_id": session_id,
"message": display_message,
"done": done,
"agent_config": agent_config,
}
async def handle_journey_message(
user_id: str,
frame: dict[str, Any],
) -> dict[str, Any]:
"""Handle a ``journey_message`` WS frame.
Appends the user message, calls the LLM, and returns the
``journey_reply`` payload.
"""
session_id = frame.get("session_id", "")
message = frame.get("message", "")
session = get_journey_session(session_id, user_id)
if session is None:
return {
"type": "journey_reply",
"session_id": session_id,
"message": "Journey session not found or expired. Please start a new setup.",
"done": True,
"agent_config": None,
}
# Append user turn.
session.history.append({"role": "user", "content": message})
# Call the LLM with tools.
ai_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt,
history=session.history,
tools=list(FILESYSTEM_TOOLS),
user_id=session.user_id,
session_id=session_id,
langfuse_prompt=session.langfuse_prompt,
)
session.history.append({"role": "assistant", "content": ai_reply})
# Check if the LLM produced the final config.
agent_config = _extract_agent_config(ai_reply)
done = agent_config is not None
# If the LLM didn't produce a config, nudge it once it hits the hard safety cap.
if not done:
turns = sum(1 for t in session.history if t["role"] == "user")
if turns >= _MAX_TURNS:
nudge_content = (
"[System: You have enough information. Please generate the final "
f"AgentConfig JSON now, wrapped in {_CONFIG_START} / {_CONFIG_END} markers.]"
)
session.history.append({"role": "user", "content": nudge_content})
nudge_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt,
history=session.history,
tools=list(FILESYSTEM_TOOLS),
user_id=session.user_id,
session_id=session_id,
langfuse_prompt=session.langfuse_prompt,
)
session.history.append({"role": "assistant", "content": nudge_reply})
agent_config = _extract_agent_config(nudge_reply)
if agent_config is not None:
done = True
ai_reply = nudge_reply
display_message = ai_reply
if done:
display_message = (
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
if _CONFIG_START in ai_reply
else "Here is your agent configuration. You can save it or continue refining."
)
_sessions.pop(session_id, None)
logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id)
return {
"type": "journey_reply",
"session_id": session_id,
"message": display_message,
"done": done,
"agent_config": agent_config,
}

222
app/api/routes/agents.py Normal file
View File

@@ -0,0 +1,222 @@
"""Agent routes.
Backend responsibilities are intentionally minimal:
GET /agents/catalog — static catalog for UI display
POST /agents/can-create — billing eligibility check
POST /agents/trigger — trigger a local agent run
Agent configuration is owned by the Electron app and is not persisted
in backend agent-config tables.
"""
from __future__ import annotations
import asyncio
import uuid
from datetime import datetime, timedelta, timezone
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import FEATURES
from app.core.agent_runner import is_agent_running, run_local_agent
from app.core.device_manager import device_manager
from app.db import get_session
from app.models import AgentRunLog, LocalAgentConfig
from app.schemas import (
AgentCatalogItem,
AgentCreationCheckRequest,
AgentCreationCheckResponse,
AgentRunLogResponse,
AgentTriggerRequest,
UserProfile,
)
router = APIRouter(prefix="/agents", tags=["agents"])
# ── Datetime helpers ──────────────────────────────────────────────────
def _dt_ms(dt: datetime) -> int:
return int(dt.timestamp() * 1000)
def _dt_ms_opt(dt: datetime | None) -> int | None:
return int(dt.timestamp() * 1000) if dt else None
def _to_data_types(values: list[str]) -> list[str]:
normalize = {
"task": "tasks", "tasks": "tasks",
"note": "notes", "notes": "notes",
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
"project": "projects", "projects": "projects",
}
seen: set[str] = set()
result: list[str] = []
for v in values:
mapped = normalize.get(v)
if mapped and mapped not in seen:
seen.add(mapped)
result.append(mapped)
return result
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
return AgentRunLogResponse(
id=log.id,
agent_id=log.agent_id,
agent_type=log.agent_type, # type: ignore[arg-type]
status=log.status, # type: ignore[arg-type]
items_processed=log.items_processed,
items_created=log.items_created,
errors=log.errors or [],
started_at=_dt_ms(log.started_at),
completed_at=_dt_ms_opt(log.completed_at),
)
def _enforce_agent_limit(tier: str, current_count: int) -> int:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
if limit != -1 and current_count >= limit:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
)
return limit
async def _enforce_run_frequency(
tier: str,
user_id: str,
db: AsyncSession,
) -> None:
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
if limit == -1:
return # unlimited
today_start = datetime.now(timezone.utc).replace(
hour=0, minute=0, second=0, microsecond=0
)
result = await db.execute(
select(func.count(AgentRunLog.id)).where(
AgentRunLog.user_id == user_id,
AgentRunLog.started_at >= today_start,
)
)
runs_today: int = result.scalar_one()
if runs_today >= limit:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
)
# ── Catalog ───────────────────────────────────────────────────────────
@router.get("/catalog", response_model=list[AgentCatalogItem])
async def get_agent_catalog(
current_user: UserProfile = Depends(get_current_user),
) -> list[AgentCatalogItem]:
"""Return the static list of available agent types and their descriptions."""
return [
AgentCatalogItem(
type="local_directory",
name="Local Directory Monitor",
description="Watches local directories, extracts data from files using AI",
),
AgentCatalogItem(
type="gmail",
name="Gmail Connector",
description="Scans Gmail inbox, extracts tasks/notes from emails",
),
AgentCatalogItem(
type="teams",
name="Microsoft Teams Connector",
description="Monitors Teams messages, extracts action items",
),
AgentCatalogItem(
type="outlook",
name="Outlook Connector",
description="Scans Outlook inbox, extracts tasks/notes",
),
]
@router.post("/can-create", response_model=AgentCreationCheckResponse)
async def can_create_agent(
body: AgentCreationCheckRequest,
current_user: UserProfile = Depends(get_current_user),
) -> AgentCreationCheckResponse:
"""Check if the user can create one more agent based on billing tier.
Since configuration is client-owned, the Electron app sends its current
active agent count and the backend applies tier limits.
"""
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
allowed = limit == -1 or body.active_agents < limit
return AgentCreationCheckResponse(
allowed=allowed,
tier=current_user.tier,
active_agents=body.active_agents,
limit=limit,
)
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
async def trigger_agent_run(
body: AgentTriggerRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> AgentRunLogResponse:
"""Trigger a local agent run using client-provided configuration."""
_enforce_agent_limit(current_user.tier, body.active_agents)
await _enforce_run_frequency(current_user.tier, current_user.id, db)
config = LocalAgentConfig(
id=str(uuid.uuid4()),
user_id=current_user.id,
device_id=body.device_id,
name="Local Directory Monitor",
directory_paths=[body.directory],
data_types=_to_data_types(body.what_to_extract),
prompt_template=body.custom_agent_prompt,
file_extensions=[],
schedule_cron=body.batch_interval,
enabled=True,
)
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
stable_agent_id = body.agent_id or config.id
if is_agent_running(stable_agent_id):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Agent is already running. Only one run per agent is allowed at a time.",
)
run_log = AgentRunLog(
agent_id=stable_agent_id,
agent_type="local",
user_id=current_user.id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_context = {
"type": "agent_batch",
"run_id": run_log.id,
"agent_id": stable_agent_id,
}
asyncio.create_task(
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
)
return _to_run_log_response(run_log)

View File

@@ -13,6 +13,7 @@ import uuid
from datetime import datetime, timedelta, timezone
import bcrypt
from cryptography.fernet import Fernet
from fastapi import APIRouter, Depends, HTTPException, status
from jose import jwt
from pydantic import BaseModel
@@ -65,6 +66,8 @@ def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
class _RegisterRequest(BaseModel):
email: str
password: str
name: str | None = None
surname: str | None = None
class _LoginRequest(BaseModel):
@@ -92,8 +95,11 @@ async def register(
user = User(
id=str(uuid.uuid4()),
email=body.email,
name=body.name,
surname=body.surname,
password_hash=_hash_password(body.password),
tier="free",
encryption_key=Fernet.generate_key().decode(),
)
db.add(user)
await db.flush() # get user.id without committing
@@ -191,7 +197,39 @@ async def refresh(
)
class _UpdateProfileRequest(BaseModel):
name: str | None = None
surname: str | None = None
@router.get("/me", response_model=UserProfile)
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
"""Return the profile for the authenticated user."""
return current_user
@router.put("/me", response_model=UserProfile)
async def update_profile(
body: _UpdateProfileRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Update the authenticated user's name and surname."""
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
if body.name is not None:
user.name = body.name
if body.surname is not None:
user.surname = body.surname
await db.commit()
await db.refresh(user)
return UserProfile(
id=user.id,
email=user.email,
name=user.name,
surname=user.surname,
tier=current_user.tier,
)

View File

@@ -1,171 +0,0 @@
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the
PostgreSQL ``backup_metadata`` table.
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
treating "history" as a ``{backup_id}`` path parameter.
"""
from __future__ import annotations
import uuid
from email.utils import parsedate_to_datetime
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.db import get_session
from app.models import BackupMetadata as BackupMetadataModel
from app.schemas import BackupMetadata, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
router = APIRouter(prefix="/backup", tags=["backup"])
_blob_store = BlobStore()
async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int:
"""Return total backup bytes stored by *user_id*."""
result = await db.execute(
select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where(
BackupMetadataModel.user_id == user_id
)
)
return int(result.scalar_one())
async def _check_backup_quota(
user: UserProfile, size_bytes: int, db: AsyncSession
) -> None:
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
current = await _current_backup_bytes(user.id, db)
tier_manager.enforce_backup_quota(
user.tier, current_bytes=current, additional_bytes=size_bytes
)
@router.put("")
async def upload_backup(
request: Request,
x_backup_version: int = Header(..., alias="X-Backup-Version"),
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Upload an E2E-encrypted backup blob.
Metadata is passed via custom headers; the raw body is the encrypted blob.
"""
blob = await request.body()
reject_if_tampered(blob, x_backup_checksum)
await _check_backup_quota(current_user, len(blob), db)
s3_key = await _blob_store.upload(
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
)
row = BackupMetadataModel(
id=str(uuid.uuid4()),
user_id=current_user.id,
s3_key=s3_key,
version=x_backup_version,
timestamp=x_backup_timestamp,
checksum=x_backup_checksum,
size_bytes=len(blob),
)
db.add(row)
await db.commit()
return {"ok": True}
@router.get("/history", response_model=list[BackupMetadata])
async def backup_history(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[BackupMetadata]:
"""Return backup metadata records for the authenticated user (no blob bytes)."""
result = await db.execute(
select(BackupMetadataModel)
.where(BackupMetadataModel.user_id == current_user.id)
.order_by(BackupMetadataModel.timestamp.desc())
)
rows = result.scalars().all()
return [
BackupMetadata(
version=r.version,
timestamp=r.timestamp,
checksum=r.checksum,
chunk_count=1,
)
for r in rows
]
@router.get("")
async def download_backup(
request: Request,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> Response:
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
result = await db.execute(
select(BackupMetadataModel)
.where(BackupMetadataModel.user_id == current_user.id)
.order_by(BackupMetadataModel.timestamp.desc())
.limit(1)
)
latest = result.scalar_one_or_none()
if latest is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
ims_header = request.headers.get("If-Modified-Since")
if ims_header:
try:
ims_dt = parsedate_to_datetime(ims_header)
ims_ms = int(ims_dt.timestamp() * 1000)
if latest.timestamp <= ims_ms:
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
except Exception:
pass # malformed header — ignore and serve the blob
blob = await _blob_store.download(current_user.id, latest.s3_key)
return Response(
content=blob,
media_type="application/octet-stream",
headers={
"X-Backup-Version": str(latest.version),
"X-Backup-Timestamp": str(latest.timestamp),
"X-Checksum": latest.checksum,
},
)
@router.delete("/{backup_id}", response_model=dict)
async def delete_backup(
backup_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a specific backup by ID."""
result = await db.execute(
select(BackupMetadataModel).where(
BackupMetadataModel.id == backup_id,
BackupMetadataModel.user_id == current_user.id,
)
)
target = result.scalar_one_or_none()
if target is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
await _blob_store.delete(current_user.id, target.s3_key)
await db.delete(target)
await db.commit()
return {"ok": True}

View File

@@ -1,22 +1,34 @@
"""Chat routes: POST /chat and WebSocket /chat/stream."""
"""Chat routes: POST /chat (REST fallback) and POST /chat/embed (text → vector).
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
"""
from __future__ import annotations
import asyncio
import json
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from jose import JWTError, jwt
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.config.settings import settings
from app.core.orchestrator import orchestrate, orchestrate_stream
from app.core.deep_agent import run_home
from app.core.llm import embed
from app.schemas import ChatRequest, UserProfile
router = APIRouter(prefix="/chat", tags=["chat"])
_HEARTBEAT_INTERVAL = 30 # seconds
# ── Embed helpers ─────────────────────────────────────────────────────────
class _EmbedRequest(BaseModel):
text: str
class _EmbedResponse(BaseModel):
vector: list[float]
# ── Endpoints ─────────────────────────────────────────────────────────────
@router.post("")
@@ -24,55 +36,24 @@ async def chat(
body: ChatRequest,
current_user: UserProfile = Depends(get_current_user),
) -> JSONResponse:
"""Route a chat message through the orchestrator.
"""REST fallback for home chat when websocket streaming is unavailable."""
response = await run_home(
user_id=current_user.id,
message=body.message,
context=body.context.model_dump(),
)
return JSONResponse(content={"response": response})
Returns ``ChatResponse`` for ``execution_mode='direct'``,
or ``ExecutionPlan`` for ``execution_mode='plan'``.
@router.post("/embed", response_model=_EmbedResponse)
async def embed_text(
body: _EmbedRequest,
current_user: UserProfile = Depends(get_current_user),
) -> _EmbedResponse:
"""Generate a 1536-dim embedding vector for the given text.
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
Used by Electron (vectordb.ts) for local note search.
"""
result = await orchestrate(body)
return JSONResponse(content=result.model_dump())
@router.websocket("/stream")
async def chat_stream(websocket: WebSocket) -> None:
"""Streaming chat via WebSocket.
Auth: ``?token=<jwt>`` query param (Bearer not possible during WS handshake).
Protocol:
1. Client sends ``ChatRequest`` as the first JSON text frame.
2. Server streams response text chunks.
3. Final frame: JSON ``{"done": true, "response": "...", "actions": [...]}``.
4. Server pings every 30 s to keep the connection alive.
"""
# Authenticate before accepting the connection
token = websocket.query_params.get("token", "")
try:
payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
user_id: str | None = payload.get("sub")
if not user_id:
raise JWTError("missing sub")
except JWTError:
await websocket.close(code=1008) # 1008 = Policy Violation
return
await websocket.accept()
try:
raw = await websocket.receive_text()
body = ChatRequest.model_validate_json(raw)
async def _heartbeat() -> None:
while True:
await asyncio.sleep(_HEARTBEAT_INTERVAL)
await websocket.send_text(json.dumps({"ping": True}))
heartbeat_task = asyncio.create_task(_heartbeat())
try:
async for chunk in orchestrate_stream(body):
await websocket.send_text(chunk)
finally:
heartbeat_task.cancel()
except WebSocketDisconnect:
pass
vector = await embed(body.text)
return _EmbedResponse(vector=vector)

417
app/api/routes/device_ws.py Normal file
View File

@@ -0,0 +1,417 @@
"""Device WebSocket endpoint.
Persistent connection from Electron devices to the backend.
WS /api/v1/ws/device?token=<jwt>
Auth: JWT passed as ``?token=`` query parameter (Bearer header is not
available during the WebSocket handshake).
Protocol:
1. Client connects → JWT validated → connection accepted.
2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``.
3. Backend registers the connection in ``DeviceConnectionManager``.
4. Session enters message dispatch loop + heartbeat.
Incoming frame dispatch:
- ``tool_result`` → resolves a pending tool-call Future.
- ``journey_start`` → starts a guided setup journey session.
- ``journey_message`` → continues a journey conversation.
- ``pong`` → heartbeat acknowledgement (updates last-seen).
- unknown types → logged, ignored.
Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.
On disconnect:
- Unregisters from DeviceConnectionManager.
- Marks all in-progress AgentRunLog rows for this user as ``error``
with message "device disconnected".
"""
from __future__ import annotations
import asyncio
import json
import logging
from uuid import uuid4
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from jose import JWTError, jwt
from sqlalchemy import update
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
from app.config.settings import settings
from app.core.agent_runner import trigger_pending_runs
from app.core.deep_agent import run_floating_stream, run_home_stream
from app.core.device_manager import device_manager
from app.core.memory_middleware import MemoryMiddleware
from app.core.output_formatter import StreamFormatter
from app.core.ws_context import clear_client_executor, set_client_executor
from app.db import async_session
from app.models import AgentRunLog
from app.schemas import WsFrameType
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ws", tags=["device-ws"])
_HEARTBEAT_INTERVAL = 30 # seconds
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
@router.websocket("/device")
async def device_ws(websocket: WebSocket) -> None:
"""Persistent WebSocket endpoint for Electron device connections.
Authentication is via ``?token=<jwt>`` query parameter.
"""
# ── 1. Authenticate before accepting ─────────────────────────────
token = websocket.query_params.get("token", "")
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
user_id: str | None = payload.get("sub")
if not user_id:
raise JWTError("missing sub")
except JWTError:
await websocket.close(code=1008) # Policy Violation
return
await websocket.accept()
# ── 2. Await device_hello frame ───────────────────────────────────
try:
raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
except (asyncio.TimeoutError, WebSocketDisconnect):
await websocket.close(code=1008)
return
try:
hello = json.loads(raw)
if hello.get("type") != WsFrameType.device_hello:
raise ValueError("expected device_hello as first frame")
device_id: str = hello["device_id"]
agent_ids: list[str] = hello.get("agent_ids", [])
except (KeyError, ValueError, json.JSONDecodeError) as exc:
logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
await websocket.close(code=1008)
return
# ── 3. Register connection ────────────────────────────────────────
device_manager.register(user_id, device_id, websocket)
logger.info(
"device_ws: connected user=%s device=%s agents=%s",
user_id,
device_id,
agent_ids,
)
# Trigger any overdue agent runs now that the device is connected.
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
# ── 4. Concurrent message loop + heartbeat ────────────────────────
try:
await asyncio.gather(
_message_loop(websocket, user_id),
_heartbeat_loop(websocket),
)
except WebSocketDisconnect:
pass
except Exception as exc:
logger.warning("device_ws: unhandled exception user=%s: %s", user_id, exc)
finally:
device_manager.unregister(user_id)
logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id)
await _mark_runs_disconnected(user_id)
# ── Message dispatch loop ─────────────────────────────────────────────
async def _message_loop(websocket: WebSocket, user_id: str) -> None:
"""Receive frames from Electron and dispatch to the appropriate handler."""
async for raw in websocket.iter_text():
try:
frame: dict = json.loads(raw)
except json.JSONDecodeError:
logger.warning("device_ws: invalid JSON from user=%s", user_id)
continue
frame_type = frame.get("type")
if frame_type == WsFrameType.tool_result:
call_id = frame.get("id")
if call_id:
device_manager.resolve_pending_call(user_id, call_id, frame)
else:
logger.warning(
"device_ws: tool_result missing id from user=%s", user_id
)
elif frame_type == WsFrameType.home_request:
asyncio.create_task(
_handle_home_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.floating_request:
asyncio.create_task(
_handle_floating_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.journey_start:
asyncio.create_task(
_handle_journey_start(websocket, user_id, frame)
)
elif frame_type == WsFrameType.journey_message:
asyncio.create_task(
_handle_journey_message(websocket, user_id, frame)
)
elif frame_type == "pong":
# Heartbeat ack — nothing to do, connection is alive.
pass
else:
logger.debug(
"device_ws: unknown frame type %r from user=%s", frame_type, user_id
)
# ── v3 Chat Handlers ──────────────────────────────────────────────────
async def _make_ws_executor(websocket: WebSocket, user_id: str):
"""Return a callback that sends tool_call frames and awaits tool_result."""
async def _executor(payload: dict) -> dict:
payload["type"] = WsFrameType.tool_call
await websocket.send_text(json.dumps(payload))
future = device_manager.create_pending_call(user_id, payload["id"])
return await future
return _executor
async def _handle_home_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a home_request frame — streams HomeFormatter output back on the socket."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
logger.info(
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
user_id,
request_id,
session_id,
message[:200],
)
# ── Memory: enrich context before LLM call ────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id,
message,
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"conversation_history": frame.get("conversation_history", []),
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_home_stream(user_id, message, context)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
# Collect text chunks to build the full response for episode storage
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc:
logger.error(
"device_ws: home_request failed user=%s req=%s: %s",
user_id, request_id, exc,
)
finally:
clear_client_executor()
# ── Memory: store episode after response ──────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
)
logger.info(
"device_ws: home_request_end user=%s req=%s session=%s response_chars=%d",
user_id,
request_id,
session_id,
len("".join(response_chunks)),
)
async def _handle_floating_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a floating_request frame — streams FloatingFormatter output back on the socket."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
scope: dict = frame.get("scope", {})
logger.info(
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
user_id,
request_id,
session_id,
json.dumps(scope, ensure_ascii=True)[:200],
message[:200],
)
# ── Memory: enrich context before LLM call ────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id,
message,
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"scope": scope,
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_floating_stream(user_id, message, context)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc:
logger.error(
"device_ws: floating_request failed user=%s req=%s: %s",
user_id, request_id, exc,
)
finally:
clear_client_executor()
# ── Memory: store episode after response ──────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
)
logger.info(
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
user_id,
request_id,
session_id,
len("".join(response_chunks)),
)
# ── v4 Journey Handlers ─────────────────────────────────────────────
async def _handle_journey_start(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a journey_start frame — explores directory and sends first question."""
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
try:
reply = await handle_journey_start(user_id, frame)
await websocket.send_text(json.dumps(reply))
except Exception as exc:
logger.error(
"device_ws: journey_start failed user=%s: %s", user_id, exc
)
await websocket.send_text(json.dumps({
"type": "journey_reply",
"session_id": frame.get("session_id", ""),
"message": f"Failed to start journey: {exc}",
"done": True,
"prompt_template": None,
}))
finally:
clear_client_executor()
async def _handle_journey_message(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a journey_message frame — continues the journey conversation."""
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
try:
reply = await handle_journey_message(user_id, frame)
await websocket.send_text(json.dumps(reply))
except Exception as exc:
session_id = frame.get("session_id", "")
logger.error(
"device_ws: journey_message failed user=%s session=%s: %s",
user_id, session_id, exc,
)
await websocket.send_text(json.dumps({
"type": "journey_reply",
"session_id": session_id,
"message": f"Journey error: {exc}",
"done": True,
"prompt_template": None,
}))
finally:
clear_client_executor()
# ── Heartbeat ─────────────────────────────────────────────────────────
async def _heartbeat_loop(websocket: WebSocket) -> None:
"""Send a ping frame every 30 s to keep the connection alive."""
while True:
await asyncio.sleep(_HEARTBEAT_INTERVAL)
await websocket.send_text(json.dumps({"type": "ping"}))
# ── Disconnect cleanup ────────────────────────────────────────────────
async def _mark_runs_disconnected(user_id: str) -> None:
"""Mark all in-progress AgentRunLog rows as 'error' for this user."""
try:
async with async_session() as db:
await db.execute(
update(AgentRunLog)
.where(
AgentRunLog.user_id == user_id,
AgentRunLog.status == "running",
)
.values(
status="error",
errors=["device disconnected"],
)
)
await db.commit()
except Exception as exc:
logger.error(
"device_ws: failed to mark runs as disconnected for user=%s: %s",
user_id,
exc,
)

View File

@@ -1,37 +0,0 @@
"""Plans routes: GET /plans/playbook and GET /plans/playbook/{plan_id}."""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, status
from app.api.deps import get_current_user
from app.core.execution_plan import plan_cache
from app.schemas import ExecutionPlan, UserProfile
router = APIRouter(prefix="/plans", tags=["plans"])
@router.get("/playbook", response_model=list[ExecutionPlan])
async def list_playbooks(
current_user: UserProfile = Depends(get_current_user),
) -> list[ExecutionPlan]:
"""Return all cached execution plan playbooks for the authenticated user.
TODO(Step11): filter by tier — power+ plans gated behind batch_builder feature.
"""
return plan_cache.get_all_playbooks()
@router.get("/playbook/{plan_id}", response_model=ExecutionPlan)
async def get_playbook(
plan_id: str,
current_user: UserProfile = Depends(get_current_user),
) -> ExecutionPlan:
"""Return a specific execution plan playbook by ID."""
plan = plan_cache.get_plan(plan_id)
if plan is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Plan not found: {plan_id}",
)
return plan

View File

@@ -1,148 +0,0 @@
"""Plugins routes: browse and install plugins from the marketplace.
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that
persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables.
"""
from __future__ import annotations
from typing import Any, Literal
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.db import get_session
from app.marketplace.plugin_registry import registry
from app.marketplace.revenue_share import revenue_share
from app.models import PluginInstallation, PluginReview as PluginReviewModel
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
router = APIRouter(prefix="/plugins", tags=["plugins"])
# ── Tier gate ─────────────────────────────────────────────────────────
def _require_plugin_tier(user: UserProfile) -> None:
"""Raise HTTP 403 for users below Power tier."""
if user.tier not in ("power", "team"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Plugin marketplace requires Power tier or above",
)
# ── Local detail schema ────────────────────────────────────────────────
class _PluginDetail(BaseModel):
plugin: PluginManifest
install_count: int
ratings: list[Any]
# ── Routes ────────────────────────────────────────────────────────────
@router.get("", response_model=PluginListResponse)
async def list_plugins(
category: str | None = Query(default=None),
q: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> PluginListResponse:
"""Browse the plugin marketplace. Requires Power tier or above."""
_require_plugin_tier(current_user)
return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort)
@router.get("/{plugin_id}", response_model=_PluginDetail)
async def get_plugin(
plugin_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _PluginDetail:
"""Get full plugin details including install count. Requires Power tier or above."""
_require_plugin_tier(current_user)
entry = await registry.get_plugin(db, plugin_id)
if entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
# Fetch review ratings for this plugin
review_result = await db.execute(
select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id)
)
reviews = review_result.scalars().all()
ratings = [
{
"reviewer_id": r.reviewer_id,
"decision": r.decision,
"notes": r.notes,
"reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None,
}
for r in reviews
]
return _PluginDetail(
plugin=entry["manifest"],
install_count=entry["install_count"],
ratings=ratings,
)
@router.post("/{plugin_id}/install", response_model=dict)
async def install_plugin(
plugin_id: str,
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
Requires Power tier or above.
"""
_require_plugin_tier(current_user)
entry = await registry.get_plugin(db, plugin_id)
if entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
# Record the installation in plugin_installations
installation = PluginInstallation(
plugin_id=plugin_id,
user_id=current_user.id,
)
db.add(installation)
await db.flush()
await revenue_share.record_install(
db,
plugin_id=plugin_id,
user_id=current_user.id,
amount_cents=entry["manifest"].price_cents,
)
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
return {"ok": True, "download_url": download_url}
@router.delete("/{plugin_id}/install", response_model=dict)
async def uninstall_plugin(
plugin_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Unregister a plugin installation."""
result = await db.execute(
select(PluginInstallation).where(
PluginInstallation.plugin_id == plugin_id,
PluginInstallation.user_id == current_user.id,
)
)
installation = result.scalar_one_or_none()
if installation is not None:
await db.delete(installation)
await db.commit()
await registry.record_uninstall(db, plugin_id)
return {"ok": True}

View File

@@ -1,195 +0,0 @@
"""Storage routes: CRUD for E2E-encrypted cloud records.
Blobs are stored in S3 via BlobStore. Record metadata is persisted in the
PostgreSQL ``storage_records`` table.
"""
from __future__ import annotations
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.db import get_session
from app.models import StorageRecord
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
router = APIRouter(prefix="/storage", tags=["storage"])
_blob_store = BlobStore()
# ── Local response schemas ─────────────────────────────────────────────
class _CreateResponse(BaseModel):
id: str
created_at: int
class _RecordMeta(BaseModel):
id: str
table: str
checksum: str
created_at: int
updated_at: int
# ── Helpers ────────────────────────────────────────────────────────────
async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int:
"""Return total bytes stored by *user_id*."""
result = await db.execute(
select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where(
StorageRecord.user_id == user_id
)
)
return int(result.scalar_one())
async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None:
"""Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit."""
current = await _current_usage_bytes(user.id, db)
tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes)
async def _get_record_for_user(
record_id: str, user_id: str, db: AsyncSession
) -> StorageRecord:
"""Look up a record and verify ownership. Returns 404 on mismatch
to prevent user enumeration attacks."""
result = await db.execute(
select(StorageRecord).where(
StorageRecord.id == record_id, StorageRecord.user_id == user_id
)
)
record = result.scalar_one_or_none()
if record is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
return record
# ── Routes ─────────────────────────────────────────────────────────────
@router.post("/records", response_model=_CreateResponse, status_code=status.HTTP_201_CREATED)
async def create_record(
body: StorageRecordCreate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _CreateResponse:
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
reject_if_tampered(body.blob, body.checksum)
await _check_quota(current_user, len(body.blob), db)
record_id = str(uuid.uuid4())
s3_key = await _blob_store.upload(
current_user.id, body.table, record_id, body.blob, body.checksum
)
record = StorageRecord(
id=record_id,
user_id=current_user.id,
table_name=body.table,
s3_key=s3_key,
checksum=body.checksum,
size_bytes=len(body.blob),
)
db.add(record)
await db.commit()
await db.refresh(record)
created_at_ms = int(record.created_at.timestamp() * 1000)
return _CreateResponse(id=record_id, created_at=created_at_ms)
@router.get("/records", response_model=list[_RecordMeta])
async def list_records(
table: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
limit: int = Query(default=50, ge=1, le=200),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[_RecordMeta]:
"""List record metadata for the authenticated user. Blob bytes are never returned."""
query = select(StorageRecord).where(StorageRecord.user_id == current_user.id)
if table is not None:
query = query.where(StorageRecord.table_name == table)
query = query.offset((page - 1) * limit).limit(limit)
result = await db.execute(query)
rows = result.scalars().all()
return [
_RecordMeta(
id=r.id,
table=r.table_name,
checksum=r.checksum,
created_at=int(r.created_at.timestamp() * 1000),
updated_at=int(r.updated_at.timestamp() * 1000),
)
for r in rows
]
@router.get("/records/{record_id}")
async def download_record(
record_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> Response:
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
record = await _get_record_for_user(record_id, current_user.id, db)
blob = await _blob_store.download(current_user.id, record.s3_key)
return Response(
content=blob,
media_type="application/octet-stream",
headers={"X-Checksum": record.checksum},
)
@router.put("/records/{record_id}", response_model=dict)
async def update_record(
record_id: str,
body: StorageRecordUpdate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Replace the blob for an existing record. Verifies checksum before storing."""
record = await _get_record_for_user(record_id, current_user.id, db)
reject_if_tampered(body.blob, body.checksum)
delta = len(body.blob) - record.size_bytes
if delta > 0:
await _check_quota(current_user, delta, db)
s3_key = await _blob_store.upload(
current_user.id, record.table_name, record_id, body.blob, body.checksum
)
record.s3_key = s3_key
record.checksum = body.checksum
record.size_bytes = len(body.blob)
await db.commit()
return {"ok": True}
@router.delete("/records/{record_id}", response_model=dict)
async def delete_record(
record_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a record and its S3 blob."""
record = await _get_record_for_user(record_id, current_user.id, db)
await _blob_store.delete(current_user.id, record.s3_key)
await db.delete(record)
await db.commit()
return {"ok": True}

View File

@@ -1,56 +0,0 @@
"""Vectors routes: upsert, search, and delete cloud vector store entries."""
from __future__ import annotations
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.schemas import (
UserProfile,
VectorSearchRequest,
VectorSearchResponse,
VectorUpsertRequest,
)
from app.storage.encryption import reject_if_tampered
from app.storage.vector_store import VectorStore
router = APIRouter(prefix="/storage", tags=["vectors"])
_vector_store = VectorStore()
class _VectorDeleteRequest(BaseModel):
ids: list[str]
@router.post("/vectors/upsert", response_model=dict)
async def upsert_vectors(
body: VectorUpsertRequest,
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, int]:
"""Verify checksums and store encrypted vectors in the user-scoped namespace."""
for item in body.vectors:
reject_if_tampered(item.blob, item.checksum)
await _vector_store.upsert(current_user.id, body.vectors)
return {"upserted": len(body.vectors)}
@router.post("/vectors/search", response_model=VectorSearchResponse)
async def search_vectors(
body: VectorSearchRequest,
current_user: UserProfile = Depends(get_current_user),
) -> VectorSearchResponse:
"""Search the user-scoped vector namespace with an encrypted query blob."""
results = await _vector_store.search(current_user.id, body.query_blob, body.top_k)
return VectorSearchResponse(results=results)
@router.delete("/vectors", response_model=dict)
async def delete_vectors(
body: _VectorDeleteRequest,
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, bool]:
"""Delete vectors by ID, scoped to the authenticated user."""
await _vector_store.delete(current_user.id, body.ids)
return {"ok": True}

View File

@@ -43,8 +43,8 @@ class StripeService:
self,
user_id: str,
tier: str,
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
cancel_url: str = "https://app.adiuva.app/billing/cancel",
success_url: str = "https://app.adiuvai.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
cancel_url: str = "https://app.adiuvai.app/billing/cancel",
) -> str:
"""Create a Stripe checkout session and return the URL.

View File

@@ -21,41 +21,33 @@ FEATURES: dict[str, dict[str, Any]] = {
"free": {
"agents": 3,
"batch_active": 2,
"cloud_storage_gb": 0,
"backup_gb": 0,
"batch_runs_per_day": 5,
"providers": 1,
"batch_builder": False,
"plugin_marketplace": False,
"sso": False,
},
"pro": {
"agents": -1, # unlimited
"batch_active": 10,
"cloud_storage_gb": 5,
"backup_gb": 5,
"batch_runs_per_day": 50,
"providers": -1,
"batch_builder": False,
"plugin_marketplace": False,
"sso": False,
},
"power": {
"agents": -1,
"batch_active": -1, # unlimited
"cloud_storage_gb": 25,
"backup_gb": 25,
"batch_runs_per_day": -1, # unlimited
"providers": -1,
"batch_builder": True,
"plugin_marketplace": True,
"sso": False,
},
"team": {
"agents": -1,
"batch_active": -1,
"cloud_storage_gb": -1, # unlimited
"backup_gb": -1, # unlimited
"batch_runs_per_day": -1, # unlimited
"providers": -1,
"batch_builder": True,
"plugin_marketplace": True,
"sso": True,
},
}
@@ -77,16 +69,18 @@ class TierManager:
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
"""Return the current billing tier for ``user_id`` from the DB.
Falls back to ``'free'`` when no subscription row exists.
Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod
when no subscription row exists.
"""
from app.models import Subscription # noqa: PLC0415
from app.config.settings import settings # noqa: PLC0415
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
tier: str | None = result.scalar_one_or_none()
if tier is None or tier not in FEATURES:
return "free"
return "power" if settings.ENV == "dev" else "free"
return tier # type: ignore[return-value]
# ── Feature access ───────────────────────────────────────────────────
@@ -119,71 +113,6 @@ class TierManager:
"""Return the requests-per-minute limit for ``tier``."""
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
# ── Storage quota ────────────────────────────────────────────────────
def enforce_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> None:
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
``tier`` is the caller's current tier (from ``current_user.tier``).
``current_bytes`` is the total bytes already stored (queried by caller).
"""
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Cloud storage is not available on the '{tier}' tier",
)
if limit_gb == -1:
return # unlimited
limit_bytes = limit_gb * 1024 ** 3
if current_bytes + additional_bytes > limit_bytes:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Storage quota exceeded for tier '{tier}'",
)
def enforce_backup_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> None:
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
limit_gb: int = FEATURES[tier]["backup_gb"]
if limit_gb == 0:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Backup is not available on the '{tier}' tier",
)
if limit_gb == -1:
return # unlimited
limit_bytes = limit_gb * 1024 ** 3
if current_bytes + additional_bytes > limit_bytes:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Backup quota exceeded for tier '{tier}'",
)
def check_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> bool:
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0:
return False
if limit_gb == -1:
return True
limit_bytes = limit_gb * 1024 ** 3
return current_bytes + additional_bytes <= limit_bytes
# Module-level singleton shared across the app.
tier_manager = TierManager()

View File

@@ -1,9 +1,9 @@
from typing import Literal
from pydantic_settings import BaseSettings
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva"
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai"
JWT_SECRET: str = "change-me-in-production"
JWT_ALGORITHM: str = "HS256"
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
@@ -12,31 +12,42 @@ class Settings(BaseSettings):
STRIPE_SECRET_KEY: str = ""
STRIPE_WEBHOOK_SECRET: str = ""
S3_BUCKET: str = ""
S3_REGION: str = "us-east-1"
S3_ENDPOINT_URL: str = ""
AWS_ACCESS_KEY_ID: str = ""
AWS_SECRET_ACCESS_KEY: str = ""
PINECONE_API_KEY: str = ""
PINECONE_INDEX: str = "adiuva"
QDRANT_URL: str = ""
QDRANT_API_KEY: str = ""
OPENAI_API_KEY: str = ""
ANTHROPIC_API_KEY: str = ""
GOOGLE_API_KEY: str = ""
CEREBRAS_API_KEY: str = ""
LLM_MODEL: str = "gpt-4o"
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
LLM_EMBED_MODEL: str = "text-embedding-3-small"
# GitHub Copilot OAuth token storage directory.
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
GITHUB_COPILOT_TOKEN_DIR: str = ""
# OAuth client credentials — used for Gmail and Microsoft (Outlook/Teams) flows.
GMAIL_CLIENT_ID: str = ""
GMAIL_CLIENT_SECRET: str = ""
MS_CLIENT_ID: str = ""
MS_CLIENT_SECRET: str = ""
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
MS_TENANT_ID: str = "common"
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
OAUTH_ENCRYPTION_KEY: str = ""
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
LANGFUSE_SECRET_KEY: str = ""
LANGFUSE_PUBLIC_KEY: str = ""
LANGFUSE_HOST: str = "https://cloud.langfuse.com"
ENV: Literal["dev", "prod"] = "dev"
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
settings = Settings()

View File

@@ -1,4 +1,4 @@
"""Agent Registry — base classes and singleton registry for chat agents."""
"""Minimal agent base types retained for compatibility with batch runners."""
from __future__ import annotations
@@ -7,7 +7,7 @@ from typing import Any
class BaseAgent(ABC):
"""Common base for all agents."""
"""Common base for non-chat agents still using the old base contract."""
def __init__(
self,
@@ -27,111 +27,4 @@ class BaseAgent(ABC):
@property
def skills(self) -> list[str]:
"""Override in subclasses to advertise capabilities."""
return []
class ChatAgent(BaseAgent):
"""Base class for LLM-powered chat agents."""
@abstractmethod
async def handle(self, query: str, context: dict[str, Any]) -> str:
"""Process a user query and return a text response."""
...
@abstractmethod
def get_tools(self) -> list[Any]:
"""Return LangChain tool definitions available to this agent."""
...
async def _tool_loop(
self,
llm: Any,
messages: list[Any],
tools: list[Any],
max_iter: int = 5,
) -> str:
"""Shared tool-calling loop.
Binds *tools* to *llm*, invokes iteratively until the model stops
requesting tool calls or *max_iter* is reached, and returns the
final text response.
"""
from langchain_core.messages import AIMessage, ToolMessage
llm_with_tools = llm.bind_tools(tools) if tools else llm
for _ in range(max_iter):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
return str(response.content)
# Execute each requested tool call
tool_map = {t.name: t for t in tools}
for call in response.tool_calls:
tool_fn = tool_map.get(call["name"])
if tool_fn is None:
result = f"Unknown tool: {call['name']}"
else:
result = await tool_fn.ainvoke(call["args"])
messages.append(
ToolMessage(content=str(result), tool_call_id=call["id"])
)
# Exhausted iterations — ask model for a final answer without tools
response = await llm.ainvoke(messages)
return str(response.content)
class AgentRegistry:
"""Singleton registry for ChatAgent subclasses."""
_instance: AgentRegistry | None = None
def __init__(self) -> None:
self._agents: dict[str, type[ChatAgent]] = {}
def __new__(cls) -> AgentRegistry:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._agents = {}
return cls._instance
# ── public API ───────────────────────────────────────────────────
def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
"""Class decorator — registers an agent by its name."""
instance = agent_class()
name = instance.get_name()
self._agents[name] = agent_class
return agent_class
def get(self, name: str) -> ChatAgent:
"""Return a fresh instance of the named agent."""
cls = self._agents.get(name)
if cls is None:
raise KeyError(f"Agent not found: {name}")
return cls()
def list_agents(self) -> list[dict[str, str]]:
"""Return ``[{name, description}]`` for the orchestrator prompt."""
result: list[dict[str, str]] = []
for cls in self._agents.values():
inst = cls()
result.append(
{"name": inst.get_name(), "description": inst.get_description()}
)
return result
async def call_agent(
self, name: str, query: str, context: dict[str, Any]
) -> str:
"""Instantiate the named agent and call its ``handle`` method."""
agent = self.get(name)
return await agent.handle(query, context)
# Module-level singleton
registry = AgentRegistry()

1036
app/core/agent_runner.py Normal file

File diff suppressed because it is too large Load Diff

962
app/core/deep_agent.py Normal file
View File

@@ -0,0 +1,962 @@
"""Single-agent runners for home and floating chat contexts."""
from __future__ import annotations
import json
import logging
import re
from datetime import date
from collections.abc import AsyncGenerator
from typing import Any, Literal
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from app.agents.note_agent import NOTE_TOOLS
from app.agents.project_agent import PROJECT_TOOLS
from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback
from app.core.llm import get_llm
from app.config.settings import settings
from app.core.memory_middleware import MemoryMiddleware
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
from app.db import async_session
logger = logging.getLogger(__name__)
FloatingDomainType = Literal["task", "timeline", "project", "node"]
FloatingDomainSection = Literal["task", "timeline", "note"]
_HOME_SYSTEM_PROMPT = (
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
"Always use tools for factual data retrieval before answering. "
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
"Return markdown and use tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>. "
"When listing tasks or timelines, each id tag must be on its own line with no prefix/suffix text. "
"Never put titles, priorities, or dates on the same line as <task> or <timeline> tags. "
"For questions about upcoming timelines (e.g. 'prossimi eventi'), include only future items in the current month unless the user asks a different range. "
"For upcoming tasks, after tag lines add a short recommendation based on due date and priority."
)
_FLOATING_SYSTEM_PROMPT = (
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
"Stay focused on the floating scope in context.scope and answer concisely. "
"Return plain text only. Do not output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed id tag wrappers. "
"Always use tools for factual data retrieval before answering. "
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
)
_FLOATING_DOMAIN_CLASSIFIER_PROMPT = (
"You are a strict domain classifier for websocket floating requests. "
"Return ONLY a JSON object with keys: type, id, section. "
"Allowed type values: task, timeline, project, node. "
"Allowed section values: task, timeline, note, or null. "
"Rules: infer from user message intent first; do not blindly trust scope.type. "
"If user asks tasks/timeline/notes for a project, set type=project and section accordingly. "
"If project id is unknown but context.resolved_project_id exists, use it as id. "
"If id is unknown, use null. "
"No markdown, no prose, JSON only."
)
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
def _candidate_tokens(message: str) -> list[str]:
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
return [token for token in tokens if len(token) >= 3]
async def _resolve_project_id_from_message(message: str) -> str | None:
"""Resolve likely project UUID from user message using client project list."""
try:
result = await execute_on_client(action="select", table="projects")
except Exception as exc:
logger.warning("deep_agent: project resolve select failed: %s", exc)
return None
rows = result.get("rows", [])
if not isinstance(rows, list) or not rows:
return None
tokens = _candidate_tokens(message)
scored: list[tuple[int, dict[str, Any]]] = []
for row in rows:
if not isinstance(row, dict):
continue
name = str(row.get("name", "")).lower()
score = sum(1 for token in tokens if token in name)
if score > 0:
scored.append((score, row))
if not scored:
return None
scored.sort(key=lambda item: item[0], reverse=True)
top_score = scored[0][0]
top_rows = [row for score, row in scored if score == top_score]
if len(top_rows) != 1:
return None
project_id = top_rows[0].get("id")
return project_id if isinstance(project_id, str) else None
def _needs_project_resolution(message: str) -> bool:
lowered = message.lower()
return any(keyword in lowered for keyword in ["project", "progetto", "progetti", "whitelist"])
async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
prepared = dict(context)
if _needs_project_resolution(message):
resolved_project_id = await _resolve_project_id_from_message(message)
if resolved_project_id:
prepared["resolved_project_id"] = resolved_project_id
logger.info("deep_agent: resolved_project_id=%s", resolved_project_id)
return prepared
def _all_tools() -> list[Any]:
return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
debug = context.get("_debug")
if isinstance(debug, dict):
request_id = debug.get("request_id")
if isinstance(request_id, str) and request_id:
return request_id
return None
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
sanitized = dict(context)
sanitized.pop("_debug", None)
return sanitized
_TAG_LINE_RE = re.compile(r"<(task|timeline)>\[[^\]]+\]</\1>")
_TIMELINE_DMY_RE = re.compile(r"(?P<d>\d{2})/(?P<m>\d{2})/(?P<y>\d{4})")
def _is_upcoming_timeline_query(message: str) -> bool:
lowered = message.lower()
has_upcoming = "prossim" in lowered or "upcoming" in lowered or "next" in lowered
has_timeline_topic = any(
token in lowered
for token in ("event", "evento", "eventi", "timeline", "milestone", "scaden")
)
return has_upcoming and has_timeline_topic
def _timeline_date_in_current_month_or_future(dmy: str) -> bool:
match = _TIMELINE_DMY_RE.search(dmy)
if not match:
return True
try:
parsed = date(
int(match.group("y")),
int(match.group("m")),
int(match.group("d")),
)
except ValueError:
return True
today = date.today()
return parsed >= today and parsed.year == today.year and parsed.month == today.month
def _normalize_tagged_list_lines(text: str, message: str) -> str:
if not text:
return text
upcoming_timeline_only = _is_upcoming_timeline_query(message)
output_lines: list[str] = []
for line in text.splitlines():
matches = list(_TAG_LINE_RE.finditer(line))
if not matches:
output_lines.append(line)
continue
had_non_tag_text = _TAG_LINE_RE.sub("", line).strip(" -\t0123456789.*:)")
if not had_non_tag_text and len(matches) == 1:
tag_text = matches[0].group(0)
if (
upcoming_timeline_only
and "<timeline>" in tag_text
and not _timeline_date_in_current_month_or_future(line)
):
continue
output_lines.append(tag_text)
continue
for match in matches:
tag_text = match.group(0)
if (
upcoming_timeline_only
and "<timeline>" in tag_text
and not _timeline_date_in_current_month_or_future(line)
):
continue
output_lines.append(tag_text)
return "\n".join(output_lines)
_GENERIC_TAG_RE = re.compile(r"</?(task|project|note|timeline|chart)>", re.IGNORECASE)
_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]")
_FLOATING_EMPTY_FALLBACK = "No results found."
def _strip_floating_markup_fragment(text: str) -> str:
if not text:
return text
cleaned = _GENERIC_TAG_RE.sub("", text)
return _BRACKETED_ID_RE.sub("", cleaned)
def _strip_floating_markup(text: str) -> str:
"""Ensure floating responses stay plain text with no XML-like tag wrappers."""
if not text:
return text
cleaned = _strip_floating_markup_fragment(text)
# Collapse excessive spaces introduced by tag/id removal while preserving lines.
lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()]
return "\n".join(line for line in lines if line)
def _fallback_from_raw_floating_text(raw_text: str) -> str:
fallback = _strip_floating_markup_fragment(raw_text or "")
fallback = re.sub(r"[ \t]{2,}", " ", fallback).strip()
return fallback or _FLOATING_EMPTY_FALLBACK
class _FloatingStreamSanitizer:
"""Streaming sanitizer that removes floating markup without buffering the full answer."""
def __init__(self) -> None:
self._pending = ""
@staticmethod
def _split_safe_boundary(text: str) -> tuple[str, str]:
boundary = len(text)
last_lt = text.rfind("<")
if last_lt != -1 and ">" not in text[last_lt:]:
boundary = min(boundary, last_lt)
last_lb = text.rfind("[")
if last_lb != -1 and "]" not in text[last_lb:]:
boundary = min(boundary, last_lb)
if boundary == len(text):
return text, ""
return text[:boundary], text[boundary:]
def feed(self, chunk: str) -> str:
combined = f"{self._pending}{chunk}"
safe_text, self._pending = self._split_safe_boundary(combined)
return _strip_floating_markup_fragment(safe_text)
def finalize(self) -> str:
# Drop dangling unfinished wrappers at the very end.
tail = re.sub(r"<[^>\n]*$", "", self._pending)
tail = re.sub(r"\[[^\]\n]*$", "", tail)
self._pending = ""
return _strip_floating_markup_fragment(tail)
def _normalize_memory_label(path_or_label: str) -> str:
value = path_or_label.strip()
if value.startswith("/memories/"):
value = value[len("/memories/"):]
value = value.strip("/")
return value
def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
@tool
async def memory_list_blocks() -> str:
"""List all core memory blocks currently stored for the user."""
logger.info("deep_agent: memory_list_blocks trace=%s user=%s", trace_id or "-", user_id)
async with async_session() as db:
memory = MemoryMiddleware(db)
blocks = await memory.list_core_blocks(user_id)
if not blocks:
return "No memory blocks found."
lines = [f"- {b['label']}: {b['value']}" for b in blocks]
return "Memory blocks:\n" + "\n".join(lines)
@tool
async def memory_get(path_or_label: str) -> str:
"""Get one memory block by label or /memories/<label> path."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_get trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
value = await memory.get_core_block(user_id, label)
if value is None:
return f"Memory block '{label}' not found."
return f"Memory block '{label}':\n{value}"
@tool
async def memory_create(path_or_label: str, value: str) -> str:
"""Create or overwrite a memory block value by label or /memories/<label> path."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_create trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.update_core(user_id, label, value, trace_id=trace_id)
return f"Memory block '{label}' saved."
@tool
async def memory_append(path_or_label: str, content: str) -> str:
"""Append content to a memory block, creating it if missing."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_append trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.append_core(user_id, label, content)
return f"Memory block '{label}' appended."
@tool
async def memory_replace(path_or_label: str, old_string: str, new_string: str) -> str:
"""Replace one exact string in a memory block."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_replace trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
changed = await memory.replace_core(user_id, label, old_string, new_string)
if not changed:
return f"No replacement made in '{label}' (old string not found)."
return f"Memory block '{label}' updated."
@tool
async def memory_delete(path_or_label: str) -> str:
"""Delete a memory block by label or /memories/<label> path."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_delete trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
deleted = await memory.delete_core(user_id, label)
if not deleted:
return f"Memory block '{label}' not found."
return f"Memory block '{label}' deleted."
@tool
async def archival_memory_insert(content: str) -> str:
"""Insert a long-term archival memory entry."""
logger.info("deep_agent: archival_memory_insert trace=%s user=%s", trace_id or "-", user_id)
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.insert_archival(user_id, content, source="assistant")
return "Archival memory saved."
@tool
async def archival_memory_search(query: str, top_k: int = 5) -> str:
"""Search long-term archival memory by semantic fallback (keyword currently)."""
logger.info("deep_agent: archival_memory_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
async with async_session() as db:
memory = MemoryMiddleware(db)
results = await memory.search_archival(user_id, query, top_k=top_k)
if not results:
return "No archival memory results found."
lines = [f"- {item}" for item in results]
return "Archival memory results:\n" + "\n".join(lines)
@tool
async def conversation_search(query: str, top_k: int = 5) -> str:
"""Search recall memory from prior episodic conversation summaries."""
logger.info("deep_agent: conversation_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
async with async_session() as db:
memory = MemoryMiddleware(db)
results = await memory.search_recall(user_id, query, top_k=top_k)
if not results:
return "No recall memory results found."
lines = [f"- {item}" for item in results]
return "Recall memory results:\n" + "\n".join(lines)
return [
memory_list_blocks,
memory_get,
memory_create,
memory_append,
memory_replace,
memory_delete,
archival_memory_insert,
archival_memory_search,
conversation_search,
]
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
return [*_all_tools(), *_memory_tools(user_id, trace_id)]
def _detect_domain_section(message: str) -> FloatingDomainSection | None:
lowered = message.lower()
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
return "timeline"
if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]):
return "task"
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
return "note"
return None
def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]:
type_raw = str(payload.get("type") or "").strip().lower()
domain_type: FloatingDomainType = "task"
if type_raw in {"task", "timeline", "project", "node"}:
domain_type = type_raw
id_value = payload.get("id")
domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None
if domain_type == "project" and not domain_id:
domain_id = fallback_id
section_raw = payload.get("section")
section: FloatingDomainSection | None = None
if isinstance(section_raw, str):
section_candidate = section_raw.strip().lower()
if section_candidate in {"task", "timeline", "note"}:
section = section_candidate
if domain_type != "project":
section = None
return {
"type": domain_type,
"id": domain_id,
"section": section,
}
def _parse_json_object(text: str) -> dict[str, Any] | None:
raw = text.strip()
if not raw:
return None
try:
parsed = json.loads(raw)
return parsed if isinstance(parsed, dict) else None
except json.JSONDecodeError:
pass
match = re.search(r"\{.*\}", raw, re.DOTALL)
if not match:
return None
try:
parsed = json.loads(match.group(0))
except json.JSONDecodeError:
return None
return parsed if isinstance(parsed, dict) else None
def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]:
section = _detect_domain_section(message)
scope = context.get("scope") if isinstance(context, dict) else None
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
if isinstance(scope, dict):
scope_type = str(scope.get("type") or "").strip().lower()
scope_id = scope.get("id")
scope_id_value = scope_id if isinstance(scope_id, str) and scope_id else None
if scope_type in {"task", "tasks"}:
return {"type": "task", "id": scope_id_value, "section": None}
if scope_type in {"project", "projects"}:
project_scope_id = scope_id_value or project_id
return {
"type": "project",
"id": project_scope_id,
"section": section,
}
if scope_type in {"note", "notes"}:
return {
"type": "node",
"id": scope_id_value,
"section": None,
}
if scope_type in {"timeline", "timelines"}:
return {"type": "timeline", "id": scope_id_value, "section": None}
lowered = message.lower()
if any(keyword in lowered for keyword in ["project", "progetto", "client"]) or project_id:
return {
"type": "project",
"id": project_id,
"section": section,
}
if section == "timeline":
return {"type": "timeline", "id": None, "section": None}
if section == "note":
return {"type": "node", "id": None, "section": None}
return {"type": "task", "id": None, "section": None}
async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[str, str | None]:
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
classifier_context = {
"scope": context.get("scope") if isinstance(context.get("scope"), dict) else None,
"resolved_project_id": project_id,
}
try:
llm = get_llm()
classifier_messages = [
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_PROMPT),
HumanMessage(
content=(
f"Message:\n{message}\n\n"
f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}"
)
),
]
lf = get_langfuse()
_, classifier_prompt_obj = get_prompt_or_fallback(
"floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_PROMPT
)
if lf:
with lf.start_as_current_observation(
as_type="generation",
name="floating-classifier",
model=settings.LLM_MODEL,
prompt=classifier_prompt_obj,
input=classifier_messages,
) as gen:
response = await llm.ainvoke(classifier_messages)
gen.update(output=_as_text(response.content), usage=extract_usage(response))
else:
response = await llm.ainvoke(classifier_messages)
parsed = _parse_json_object(_as_text(response.content))
if parsed is not None:
domain = _normalize_domain_payload(parsed, project_id)
logger.info(
"deep_agent: floating_domain_classified type=%s id=%s section=%s",
domain.get("type"),
domain.get("id"),
domain.get("section"),
)
return domain
logger.warning("deep_agent: floating_domain classifier returned non-json output")
except Exception as exc:
logger.warning("deep_agent: floating_domain classifier failed: %s", exc)
return _infer_floating_domain_rule_based(message, context)
async def _run_single_agent(
*,
user_id: str,
system_prompt: str,
message: str,
context: dict[str, Any],
max_steps: int = 6,
langfuse_prompt: Any = None,
agent_name: str = "agent",
) -> str:
trace_id = _trace_id_from_context(context)
lf = get_langfuse()
llm = get_llm()
tools = _all_tools_for_user(user_id, trace_id)
model_context = _context_for_model(context)
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=system_prompt),
HumanMessage(
content=(
f"User message:\n{message}\n\n"
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
)
),
]
tool_calls_count = 0
collected: list[dict[str, Any]] = []
set_tool_result_collector(collected)
_span_ctx = (
lf.start_as_current_observation(
as_type="span",
name=agent_name,
metadata={"user_id": user_id, "session_id": trace_id},
input=message,
)
if lf else None
)
_span = _span_ctx.__enter__() if _span_ctx else None
try:
for _ in range(max_steps):
_gen_ctx = (
lf.start_as_current_observation(
as_type="generation",
name=f"{agent_name}-llm",
model=settings.LLM_MODEL,
prompt=langfuse_prompt,
input=messages,
)
if lf else None
)
_gen = _gen_ctx.__enter__() if _gen_ctx else None
response: AIMessage = await llm_with_tools.ainvoke(messages)
if _gen_ctx:
_gen.update(output=_as_text(response.content), usage=extract_usage(response))
_gen_ctx.__exit__(None, None, None)
messages.append(response)
if not response.tool_calls:
final_text = _as_text(response.content)
logger.info(
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d",
trace_id or "-",
user_id,
tool_calls_count,
len(final_text),
)
if _span:
_span.update(output=final_text)
return final_text
tool_map = {tool_def.name: tool_def for tool_def in tools}
for call in response.tool_calls:
tool_calls_count += 1
call_id = str(call.get("id", ""))
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
call_id,
call_name,
json.dumps(call_args, ensure_ascii=True)[:800],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
call_id,
call_name,
str(tool_output)[:1200],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
final = await llm.ainvoke(messages)
final_text = _as_text(final.content)
logger.info(
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
trace_id or "-",
user_id,
tool_calls_count,
len(final_text),
)
if _span:
_span.update(output=final_text)
return final_text
finally:
clear_tool_result_collector()
if _span_ctx:
_span_ctx.__exit__(None, None, None)
if lf:
lf.flush()
async def _run_single_agent_stream(
*,
user_id: str,
system_prompt: str,
message: str,
context: dict[str, Any],
max_steps: int = 6,
langfuse_prompt: Any = None,
agent_name: str = "agent",
) -> AsyncGenerator[tuple[str, Any], None]:
trace_id = _trace_id_from_context(context)
lf = get_langfuse()
llm = get_llm()
tools = _all_tools_for_user(user_id, trace_id)
model_context = _context_for_model(context)
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=system_prompt),
HumanMessage(
content=(
f"User message:\n{message}\n\n"
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
)
),
]
tool_calls_count = 0
streamed_chars = 0
collected: list[dict[str, Any]] = []
set_tool_result_collector(collected)
_span_ctx = (
lf.start_as_current_observation(
as_type="span",
name=f"{agent_name}-stream",
metadata={"user_id": user_id, "session_id": trace_id},
input=message,
)
if lf else None
)
_span = _span_ctx.__enter__() if _span_ctx else None
streamed_text: list[str] = []
try:
for _ in range(max_steps):
_gen_ctx = (
lf.start_as_current_observation(
as_type="generation",
name=f"{agent_name}-llm",
model=settings.LLM_MODEL,
prompt=langfuse_prompt,
input=messages,
)
if lf else None
)
_gen = _gen_ctx.__enter__() if _gen_ctx else None
response: AIMessage = await llm_with_tools.ainvoke(messages)
if _gen_ctx:
_gen.update(output=_as_text(response.content), usage=extract_usage(response))
_gen_ctx.__exit__(None, None, None)
messages.append(response)
if not response.tool_calls:
emitted_any = False
async for chunk in llm.astream(messages):
token = _as_text(getattr(chunk, "content", ""))
if token:
streamed_chars += len(token)
streamed_text.append(token)
emitted_any = True
yield "token", token
# Some providers return final text in `response.content` but stream no chunks.
if not emitted_any:
fallback_text = _as_text(response.content)
if fallback_text:
streamed_chars += len(fallback_text)
streamed_text.append(fallback_text)
yield "token", fallback_text
logger.info(
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
trace_id or "-",
user_id,
tool_calls_count,
streamed_chars,
)
if _span:
_span.update(output="".join(streamed_text))
return
tool_map = {tool_def.name: tool_def for tool_def in tools}
for call in response.tool_calls:
tool_calls_count += 1
call_id = str(call.get("id", ""))
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
call_id,
call_name,
json.dumps(call_args, ensure_ascii=True)[:800],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
call_id,
call_name,
str(tool_output)[:1200],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
async for chunk in llm.astream(messages):
token = _as_text(getattr(chunk, "content", ""))
if token:
streamed_chars += len(token)
streamed_text.append(token)
yield "token", token
logger.info(
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
trace_id or "-",
user_id,
tool_calls_count,
streamed_chars,
)
if _span:
_span.update(output="".join(streamed_text))
finally:
clear_tool_result_collector()
if _span_ctx:
_span_ctx.__exit__(None, None, None)
if lf:
lf.flush()
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
prepared_context = await _prepare_context(message, context)
system_prompt, langfuse_prompt = get_prompt_or_fallback(
"home_system", _HOME_SYSTEM_PROMPT
)
response = await _run_single_agent(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_prompt=langfuse_prompt,
agent_name="home-agent",
)
return _normalize_tagged_list_lines(response, message)
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]:
prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context)
system_prompt, langfuse_prompt = get_prompt_or_fallback(
"floating_system", _FLOATING_SYSTEM_PROMPT
)
response = await _run_single_agent(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_prompt=langfuse_prompt,
agent_name="floating-agent",
)
sanitized = _strip_floating_markup(response)
if not sanitized and response:
sanitized = _fallback_from_raw_floating_text(response)
return sanitized, domain
async def run_home_stream(
user_id: str,
message: str,
context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
prepared_context = await _prepare_context(message, context)
system_prompt, langfuse_prompt = get_prompt_or_fallback(
"home_system", _HOME_SYSTEM_PROMPT
)
text_chunks: list[str] = []
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_prompt=langfuse_prompt,
agent_name="home-agent",
):
event_type, data = event
if event_type != "token":
yield event
continue
text_chunks.append(str(data or ""))
normalized = _normalize_tagged_list_lines("".join(text_chunks), message)
if normalized:
yield "token", normalized
async def run_floating_stream(
user_id: str,
message: str,
context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context)
yield "floating_domain", domain
system_prompt, langfuse_prompt = get_prompt_or_fallback(
"floating_system", _FLOATING_SYSTEM_PROMPT
)
sanitizer = _FloatingStreamSanitizer()
emitted_sanitized = False
raw_chunks: list[str] = []
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_prompt=langfuse_prompt,
agent_name="floating-agent",
):
event_type, data = event
if event_type != "token":
yield event
continue
raw_chunk = str(data or "")
raw_chunks.append(raw_chunk)
sanitized_chunk = sanitizer.feed(raw_chunk)
if sanitized_chunk:
emitted_sanitized = True
yield "token", sanitized_chunk
tail = sanitizer.finalize()
if tail:
emitted_sanitized = True
yield "token", tail
if not emitted_sanitized and raw_chunks:
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
async def update_core_memory(user_id: str, key: str, value: str) -> None:
"""Compatibility helper kept for callers that expect explicit memory update API."""
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.update_core(user_id, key, value)

151
app/core/device_manager.py Normal file
View File

@@ -0,0 +1,151 @@
"""Device connection manager.
Maintains in-memory state for all active Electron → backend WebSocket
connections. One connection per user (latest replaces previous).
The manager handles the **tool-call round-trip** pattern:
- Backend sends ``tool_call`` frame → Electron executes the action →
returns ``tool_result`` frame.
- ``create_pending_call`` registers a Future keyed by ``call_id``.
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
receive the result dict from Electron.
This pattern is used by all tools (CRUD, file-system, etc.) via
``execute_on_client()`` in ``ws_context.py``.
The ``device_manager`` module-level singleton is imported by both the
device WS route and the agent runner.
"""
from __future__ import annotations
import asyncio
import json
import logging
from dataclasses import dataclass, field
from fastapi import WebSocket
logger = logging.getLogger(__name__)
@dataclass
class DeviceConnection:
"""State for a single connected Electron device."""
ws: WebSocket
device_id: str
# Futures indexed by tool_call id — resolved when tool_result arrives.
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
class DeviceConnectionManager:
"""Singleton registry of active Electron WebSocket connections.
Thread/task safety note: asyncio is single-threaded by design. All
mutations happen inside await-points on the main event loop, so no
locking is required for the in-memory dicts.
"""
def __init__(self) -> None:
self._connections: dict[str, DeviceConnection] = {}
# ── Registration ──────────────────────────────────────────────────
def register(self, user_id: str, device_id: str, ws: WebSocket) -> None:
"""Store the active connection for *user_id*, replacing any previous one."""
if user_id in self._connections:
old = self._connections[user_id]
logger.info(
"device_manager: replacing existing connection for user=%s device=%s",
user_id,
old.device_id,
)
# Cancel any futures that were waiting on the old connection.
for fut in old.pending_calls.values():
if not fut.done():
fut.cancel()
self._connections[user_id] = DeviceConnection(ws=ws, device_id=device_id)
logger.info(
"device_manager: registered user=%s device=%s", user_id, device_id
)
def unregister(self, user_id: str) -> None:
"""Remove the connection for *user_id* and cancel any pending futures."""
conn = self._connections.pop(user_id, None)
if conn is None:
return
for fut in conn.pending_calls.values():
if not fut.done():
fut.cancel()
logger.info("device_manager: unregistered user=%s", user_id)
# ── Presence queries ──────────────────────────────────────────────
def get_ws(self, user_id: str) -> WebSocket | None:
"""Return the active WebSocket for *user_id*, or ``None`` if offline."""
conn = self._connections.get(user_id)
return conn.ws if conn else None
def is_online(self, user_id: str, device_id: str | None = None) -> bool:
"""Return ``True`` if the user has an active connection.
If *device_id* is provided also checks that it matches the connected device.
"""
conn = self._connections.get(user_id)
if conn is None:
return False
if device_id is not None:
return conn.device_id == device_id
return True
# ── Frame sending ─────────────────────────────────────────────────
async def send_frame(self, user_id: str, frame: dict) -> None:
"""Send *frame* as a JSON text message to the device.
Raises ``RuntimeError`` if the user is not connected.
"""
conn = self._connections.get(user_id)
if conn is None:
raise RuntimeError(
f"send_frame: user {user_id!r} is not connected"
)
await conn.ws.send_text(json.dumps(frame))
# ── Tool-call round-trip ──────────────────────────────────────────
def create_pending_call(
self, user_id: str, call_id: str
) -> asyncio.Future[dict]:
"""Register a Future that will be resolved when the tool_result arrives.
Raises ``RuntimeError`` if the user is not connected.
"""
conn = self._connections.get(user_id)
if conn is None:
raise RuntimeError(
f"create_pending_call: user {user_id!r} is not connected"
)
loop = asyncio.get_event_loop()
fut: asyncio.Future[dict] = loop.create_future()
conn.pending_calls[call_id] = fut
return fut
def resolve_pending_call(
self, user_id: str, call_id: str, result: dict
) -> None:
"""Fulfil the Future registered under *call_id* with the Electron result.
No-ops if the call_id is unknown (already timed out or cancelled).
"""
conn = self._connections.get(user_id)
if conn is None:
return
fut = conn.pending_calls.pop(call_id, None)
if fut is not None and not fut.done():
fut.set_result(result)
# Module-level singleton — import this everywhere.
device_manager = DeviceConnectionManager()

View File

@@ -1,222 +0,0 @@
"""Execution Plan generator — builder, template registry, and LRU plan cache."""
from __future__ import annotations
from collections import OrderedDict
from typing import Any
from app.schemas import ExecutionPlan, PlanStep
# ── Prompt Template Registry ──────────────────────────────────────────
class PromptTemplateRegistry:
"""Server-side store mapping template IDs to prompt text.
Clients only ever receive template IDs (e.g. ``"tpl_task_agent_default"``).
The actual prompt text is resolved here on the server, keeping prompt IP
out of API responses.
"""
def __init__(self) -> None:
self._templates: dict[str, str] = {}
def register(self, template_id: str, prompt_text: str) -> None:
self._templates[template_id] = prompt_text
def get(self, template_id: str) -> str:
"""Resolve a template ID to its prompt text.
Raises ``KeyError`` if the template is not registered.
"""
text = self._templates.get(template_id)
if text is None:
raise KeyError(f"Template not found: {template_id!r}")
return text
def has(self, template_id: str) -> bool:
return template_id in self._templates
def list_ids(self) -> list[str]:
"""Return all registered template IDs (never the text)."""
return list(self._templates.keys())
# ── Execution Plan Builder ────────────────────────────────────────────
class ExecutionPlanBuilder:
"""Fluent builder for ``ExecutionPlan`` objects.
Example::
plan = (
ExecutionPlanBuilder("task_agent")
.add_llm_step("tpl_task_agent_default", {"message": user_msg})
.add_data_step("create_record", data_from_step=0)
.build()
)
"""
def __init__(self, agent: str) -> None:
self._agent = agent
self._steps: list[PlanStep] = []
# ── step adders ──────────────────────────────────────────────────
def add_step(
self, action: str, params: dict[str, Any] | None = None
) -> ExecutionPlanBuilder:
"""Append a generic action step with optional parameters."""
self._steps.append(PlanStep(action=action, variables=params))
return self
def add_llm_step(
self, template_id: str, variables: dict[str, Any] | None = None
) -> ExecutionPlanBuilder:
"""Append an LLM step referencing a server-side template by ID."""
self._steps.append(
PlanStep(action="llm", prompt_template=template_id, variables=variables)
)
return self
def add_data_step(self, action: str, data_from_step: int) -> ExecutionPlanBuilder:
"""Append a step whose input comes from the output of an earlier step."""
self._steps.append(PlanStep(action=action, data_from_step=data_from_step))
return self
# ── build ────────────────────────────────────────────────────────
def build(self) -> ExecutionPlan:
"""Validate step references and return the ``ExecutionPlan``.
Raises ``ValueError`` if any ``data_from_step`` references a
non-existent or future step index.
"""
for i, step in enumerate(self._steps):
if step.data_from_step is not None:
if not (0 <= step.data_from_step < i):
raise ValueError(
f"Step {i}: data_from_step={step.data_from_step} must "
f"reference a preceding step index in range 0..{i - 1}"
)
return ExecutionPlan(agent=self._agent, steps=list(self._steps))
# ── Plan Cache (LRU) ──────────────────────────────────────────────────
class PlanCache:
"""In-memory LRU cache for ``ExecutionPlan`` objects.
Plans stored here are accessible as playbooks via ``get_all_playbooks()``.
The cache also serves as a runtime memoisation layer so that repeated
identical intent classifications can skip re-building the plan.
"""
def __init__(self, maxsize: int = 1000) -> None:
self._maxsize = maxsize
self._cache: OrderedDict[str, ExecutionPlan] = OrderedDict()
def cache_plan(self, key: str, plan: ExecutionPlan) -> None:
"""Store *plan* under *key*, evicting the LRU entry if at capacity."""
if key in self._cache:
del self._cache[key] # remove so re-insertion places it at the end
elif len(self._cache) >= self._maxsize:
self._cache.popitem(last=False) # evict least-recently-used
self._cache[key] = plan
def get_plan(self, key: str) -> ExecutionPlan | None:
"""Return the cached plan for *key*, or ``None`` if not present.
Accessing a plan marks it as most-recently used.
"""
if key not in self._cache:
return None
self._cache.move_to_end(key)
return self._cache[key]
def get_all_playbooks(self) -> list[ExecutionPlan]:
"""Return all cached plans (most-recently used last)."""
return list(self._cache.values())
# ── Module-level singletons ───────────────────────────────────────────
template_registry = PromptTemplateRegistry()
plan_cache = PlanCache()
def _register_builtin_templates() -> None:
"""Register the built-in server-side prompt templates.
These strings never leave the server. Clients only receive the IDs.
"""
_tpls: dict[str, str] = {
"tpl_task_agent_default": (
"You are a task management assistant. Help the user create, update, "
"list, and track tasks. Use correct status values (todo, in_progress, "
"done) and priority values (high, medium, low) from the workspace model."
),
"tpl_checkpoint_agent_default": (
"You are a project checkpoint assistant. Help the user create and manage "
"milestone checkpoints on their projects. Every checkpoint requires a "
"project_id and a date expressed as a Unix timestamp in milliseconds."
),
"tpl_project_agent_default": (
"You are a project management assistant. Help the user create, find, "
"update, and archive projects. Projects have a name, an optional client, "
"and a status of either active or archived."
),
"tpl_note_agent_default": (
"You are a note-taking assistant. Help the user create, retrieve, update, "
"and delete Markdown notes. Notes can optionally be linked to a project."
),
"tpl_task_extract_from_project": (
"Extract all actionable tasks from the provided project context. "
"Return a structured list of tasks, each with a title, inferred priority "
"(high, medium, or low), suggested status (todo), and a due_date in "
"milliseconds where a deadline can be inferred."
),
"tpl_note_weekly_summary": (
"Generate a weekly project summary note from the provided workspace data. "
"Include: tasks completed this week, tasks due soon, active projects, "
"and upcoming checkpoints. Format the output as clean Markdown."
),
}
for tid, text in _tpls.items():
template_registry.register(tid, text)
def _load_playbooks() -> None:
"""Pre-build and cache the built-in playbooks."""
playbooks: list[tuple[str, ExecutionPlan]] = [
(
"create_tasks_from_project",
ExecutionPlanBuilder("project_agent")
.add_llm_step(
"tpl_task_extract_from_project",
{"source": "project_context"},
)
.add_data_step("create_record", data_from_step=0)
.build(),
),
(
"generate_weekly_note",
ExecutionPlanBuilder("note_agent")
.add_llm_step(
"tpl_note_weekly_summary",
{"period": "last_7_days"},
)
.add_data_step("create_record", data_from_step=0)
.build(),
),
]
for key, plan in playbooks:
plan_cache.cache_plan(key, plan)
# Initialise on module load
_register_builtin_templates()
_load_playbooks()

147
app/core/langfuse_client.py Normal file
View File

@@ -0,0 +1,147 @@
"""Langfuse observability — singleton client and prompt helpers.
If LANGFUSE_SECRET_KEY / LANGFUSE_PUBLIC_KEY are not set,
all helpers are no-ops so the app works without Langfuse configured.
Usage
-----
Tracing::
from app.core.langfuse_client import get_langfuse
lf = get_langfuse()
if lf:
with lf.start_as_current_observation(as_type="span", name="my-agent") as span:
span.update(input=user_message)
# ... do work ...
span.update(output=result)
lf.flush()
Prompt management::
from app.core.langfuse_client import get_prompt_or_fallback
text, prompt_obj = get_prompt_or_fallback("home_system", FALLBACK_PROMPT)
# Use text as the system prompt; pass prompt_obj to generations for linking.
Linking a prompt to a generation::
with lf.start_as_current_observation(
as_type="generation",
name="llm-call",
model="gpt-4o",
prompt=prompt_obj, # links generation → prompt version in the UI
input=messages,
) as gen:
response = await llm.ainvoke(messages)
gen.update(output=response.content, usage=_usage(response))
"""
from __future__ import annotations
import logging
from typing import Any
logger = logging.getLogger(__name__)
_client: Any = None
_initialized: bool = False
def get_langfuse() -> Any | None:
"""Return the Langfuse singleton, or ``None`` when not configured."""
global _client, _initialized
if _initialized:
return _client
_initialized = True
from app.config.settings import settings # local import to avoid circular deps
if not settings.LANGFUSE_SECRET_KEY or not settings.LANGFUSE_PUBLIC_KEY:
logger.debug("langfuse: not configured — observability disabled")
return None
try:
from langfuse import Langfuse
_client = Langfuse(
secret_key=settings.LANGFUSE_SECRET_KEY,
public_key=settings.LANGFUSE_PUBLIC_KEY,
host=settings.LANGFUSE_HOST,
)
logger.info("langfuse: client initialized host=%s", settings.LANGFUSE_HOST)
except Exception as exc:
logger.warning("langfuse: failed to initialize: %s", exc)
_client = None
return _client
def get_prompt_or_fallback(name: str, fallback: str) -> tuple[str, Any]:
"""Fetch a text prompt from Langfuse; fall back to ``fallback`` on any error.
Returns ``(raw_template, prompt_obj_or_None)``.
* ``raw_template`` — the uncompiled template string. Do NOT call ``.format()``
on it directly; use :func:`compile_prompt` instead so the correct variable
syntax is applied (``{{var}}`` for Langfuse, ``{var}`` for the fallback).
* ``prompt_obj`` — the Langfuse prompt object, or ``None`` when Langfuse is
unavailable / the fetch failed. Pass this to generation observations so
Langfuse links the generation to the exact prompt version in the UI.
"""
lf = get_langfuse()
if lf is None:
return fallback, None
try:
prompt = lf.get_prompt(name, label="production", fallback=fallback)
# For text-type prompts .prompt holds the raw template string.
raw = prompt.prompt if hasattr(prompt, "prompt") and isinstance(prompt.prompt, str) else fallback
return raw, prompt
except Exception as exc:
logger.warning("langfuse: get_prompt %r failed: %s — using fallback", name, exc)
return fallback, None
def compile_prompt(template: str, prompt_obj: Any, **variables: Any) -> str:
"""Compile *template* with *variables*, choosing the right syntax.
* When *prompt_obj* is a real Langfuse prompt object, calls
``prompt_obj.compile(**variables)`` which handles ``{{variable}}``
substitution as defined in the Langfuse UI.
* When *prompt_obj* is ``None`` (Langfuse unavailable or fetch failed),
falls back to ``template.format(**variables)`` which handles the
``{variable}`` syntax used in the hardcoded fallback strings.
This keeps callers oblivious to which syntax is in use.
"""
if prompt_obj is not None:
try:
compiled = prompt_obj.compile(**variables)
# compile() returns a string for text prompts.
if isinstance(compiled, str):
return compiled
# Chat prompts return a list of dicts — join text parts.
if isinstance(compiled, list):
return "\n".join(
m.get("content", "") for m in compiled if isinstance(m, dict)
)
except Exception as exc:
logger.warning(
"langfuse: compile failed for prompt %r: %s — falling back to .format()",
getattr(prompt_obj, "name", "?"),
exc,
)
return template.format(**variables)
def extract_usage(response: Any) -> dict[str, int]:
"""Extract token usage from a LangChain AI message into Langfuse format."""
meta = getattr(response, "usage_metadata", None)
if not meta:
return {}
return {
"input": int(meta.get("input_tokens", 0)),
"output": int(meta.get("output_tokens", 0)),
"total": int(meta.get("total_tokens", 0)),
}

View File

@@ -17,11 +17,30 @@ Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
from __future__ import annotations
import os
import warnings
from openai import AsyncOpenAI
import litellm
from langchain_openai import ChatOpenAI
from langchain_litellm import ChatLiteLLM
from litellm import get_supported_openai_params # noqa: F401 validates install
from app.config.settings import settings
# Some models (e.g. gpt-5, o-series) reject unsupported params like temperature.
# Drop them silently instead of raising UnsupportedParamsError.
litellm.drop_params = True
# Some provider responses include a plain dict in the `usage` field where a
# richer Pydantic model is expected. This warning is noisy but non-fatal.
warnings.filterwarnings(
"ignore",
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
category=UserWarning,
)
def _api_key_for_model(model: str) -> str | None:
"""Return the most appropriate API key for the given LiteLLM model string."""
@@ -29,6 +48,12 @@ def _api_key_for_model(model: str) -> str | None:
return settings.ANTHROPIC_API_KEY or None
if model.startswith("gemini/") or model.startswith("google/"):
return settings.GOOGLE_API_KEY or None
if model.startswith("cerebras/"):
return settings.CEREBRAS_API_KEY or None
if model.startswith("github_copilot/"):
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
# No API key is required; returning None lets LiteLLM handle auth.
return None
# Default: OpenAI-compatible (covers plain model names like "gpt-4o")
return settings.OPENAI_API_KEY or None
@@ -37,7 +62,7 @@ def get_llm(
*,
model: str | None = None,
temperature: float = 0,
) -> ChatOpenAI:
) -> ChatOpenAI | ChatLiteLLM:
"""Return a LangChain chat model backed by LiteLLM.
LiteLLM exposes an OpenAI-compatible API, so we use ``ChatOpenAI`` pointed
@@ -53,6 +78,16 @@ def get_llm(
Sampling temperature. ``0`` = deterministic.
"""
model = model or settings.LLM_MODEL
# Point LiteLLM to the custom token directory when configured.
if settings.GITHUB_COPILOT_TOKEN_DIR:
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
# Use ChatLiteLLM for provider-prefixed models (github_copilot/, anthropic/, etc.)
# so LiteLLM handles routing and auth. ChatOpenAI for plain OpenAI model names.
if "/" in model:
return ChatLiteLLM(model=model, temperature=temperature)
return ChatOpenAI(
model=model,
temperature=temperature,
@@ -63,6 +98,28 @@ def get_llm(
def get_router_llm(
*,
temperature: float = 0,
) -> ChatOpenAI:
) -> ChatOpenAI | ChatLiteLLM:
"""Return the lighter model used for intent classification / routing."""
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
async def embed(text: str) -> list[float]:
"""Return an embedding vector for *text*.
Uses ``settings.LLM_EMBED_MODEL`` so the same provider switch in ``.env``
(e.g. ``github_copilot/text-embedding-3-small``) applies here without any
code changes. Falls back to the raw AsyncOpenAI client for plain OpenAI
model names to preserve existing behaviour.
"""
model = settings.LLM_EMBED_MODEL
if model.startswith("github_copilot/") or "/" in model:
# Use LiteLLM for all provider-prefixed models (Copilot, Bedrock, etc.)
# so the provider's auth mechanism is applied correctly.
response = await litellm.aembedding(model=model, input=[text])
return response.data[0]["embedding"]
# Plain OpenAI model name — use the raw AsyncOpenAI client (existing path).
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
response = await client.embeddings.create(model=model, input=text)
return response.data[0].embedding

View File

@@ -0,0 +1,441 @@
"""Memory Middleware — enrich requests with memory context and store interactions.
Four-tier memory model (MemGPT-style):
core — persistent key/value user preferences, always injected
associative — semantic similarity search via pgvector (top-k)
episodic — recent session summaries (last N)
proactive — behavioral patterns above confidence threshold
All memory content is encrypted at rest using the per-user Fernet key
stored in User.encryption_key. Decryption happens in-memory only.
Usage:
memory = MemoryMiddleware(db_session)
context = await memory.enrich_context(user_id, message)
# ... run agent ...
await memory.store_episode(user_id, session_id, message, response)
"""
from __future__ import annotations
import logging
import uuid
from typing import Any
from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import (
MemoryAssociative,
MemoryCore,
MemoryEpisodic,
MemoryProactive,
User,
)
logger = logging.getLogger(__name__)
# Tuning constants
_ASSOCIATIVE_TOP_K = 5
_EPISODIC_RECENT_N = 10
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
class MemoryMiddleware:
"""Enrich orchestrator context with memory and persist interactions after."""
def __init__(self, db: AsyncSession) -> None:
self._db = db
# ── Public API ────────────────────────────────────────────────────────────
async def enrich_context(
self,
user_id: str,
message: str,
trace_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Build memory context dict to inject into the orchestrator before LLM call.
Returns a dict with keys:
core_memory — {key: plaintext_value, ...}
associative_memory — [plaintext_content, ...] (top-k by keyword match)
episodic_memory — [plaintext_summary, ...] (most recent N)
proactive_hints — [plaintext_pattern, ...] (above threshold)
"""
fernet = await self._get_fernet(user_id)
if fernet is None:
return {}
core = await self._load_core(user_id, fernet)
associative = await self._load_associative(user_id, message, fernet)
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
proactive = await self._load_proactive(user_id, fernet)
user_dbg = await self._get_user_debug(user_id)
logger.info(
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
trace_id or "-",
user_id,
user_dbg.get("tier") or "-",
len(core),
len(associative),
len(episodic),
len(proactive),
)
return {
"core_memory": core,
"associative_memory": associative,
"episodic_memory": episodic,
"proactive_hints": proactive,
}
async def store_episode(
self,
user_id: str,
session_id: str,
message: str,
response: str,
trace_id: str | None = None,
) -> None:
"""Summarise and store a completed interaction in episodic memory.
The summary is a simple heuristic concatenation (no LLM call) to keep
latency low. Full LLM summarisation can be added in a later step.
"""
fernet = await self._get_fernet(user_id)
if fernet is None:
return
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
encrypted = _encrypt(fernet, summary)
row = MemoryEpisodic(
id=str(uuid.uuid4()),
user_id=user_id,
summary_encrypted=encrypted,
session_id=session_id,
)
self._db.add(row)
try:
await self._db.commit()
user_dbg = await self._get_user_debug(user_id)
logger.info(
"memory: store_episode trace=%s user=%s tier=%s session=%s",
trace_id or "-",
user_id,
user_dbg.get("tier") or "-",
session_id,
)
except Exception as exc:
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
"""Upsert a core memory key/value for a user."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return
encrypted = _encrypt(fernet, value)
result = await self._db.execute(
select(MemoryCore).where(
MemoryCore.user_id == user_id,
MemoryCore.key == key,
)
)
existing = result.scalar_one_or_none()
if existing is not None:
existing.value_encrypted = encrypted
else:
self._db.add(MemoryCore(
id=str(uuid.uuid4()),
user_id=user_id,
key=key,
value_encrypted=encrypted,
))
try:
await self._db.commit()
user_dbg = await self._get_user_debug(user_id)
logger.info(
"memory: update_core trace=%s user=%s tier=%s key=%s",
trace_id or "-",
user_id,
user_dbg.get("tier") or "-",
key,
)
except Exception as exc:
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
await self._db.rollback()
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
"""Return core memory as editable blocks (label/value)."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryCore)
.where(MemoryCore.user_id == user_id)
.order_by(MemoryCore.key.asc())
)
rows = result.scalars().all()
out: list[dict[str, str]] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.value_encrypted)
if plaintext is not None:
out.append({"label": row.key, "value": plaintext})
logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
return out
async def get_core_block(self, user_id: str, label: str) -> str | None:
"""Return a single core memory block value by label."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return None
result = await self._db.execute(
select(MemoryCore).where(
MemoryCore.user_id == user_id,
MemoryCore.key == label,
)
)
row = result.scalar_one_or_none()
if row is None:
logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
return None
value = _safe_decrypt(fernet, row.value_encrypted)
logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0)
return value
async def delete_core(self, user_id: str, label: str) -> bool:
"""Delete a core memory block by label. Returns True if deleted."""
result = await self._db.execute(
select(MemoryCore).where(
MemoryCore.user_id == user_id,
MemoryCore.key == label,
)
)
row = result.scalar_one_or_none()
if row is None:
logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
return False
await self._db.delete(row)
try:
await self._db.commit()
logger.info("memory: delete_core user=%s label=%s", user_id, label)
return True
except Exception as exc:
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
await self._db.rollback()
return False
async def append_core(self, user_id: str, label: str, content: str) -> None:
"""Append content to a core block, creating it if missing."""
current = await self.get_core_block(user_id, label)
if current is None:
await self.update_core(user_id, label, content)
logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
return
await self.update_core(user_id, label, f"{current}\n{content}")
logger.info("memory: append_core user=%s label=%s created=0", user_id, label)
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
"""Replace one exact string inside a core block. Returns False if not found."""
current = await self.get_core_block(user_id, label)
if current is None or old not in current:
logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
return False
await self.update_core(user_id, label, current.replace(old, new, 1))
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
return True
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
"""Insert a long-term archival memory entry."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return
encrypted = _encrypt(fernet, content)
row = MemoryAssociative(
id=str(uuid.uuid4()),
user_id=user_id,
content_encrypted=encrypted,
embedding=None,
entity_type=source,
entity_id=None,
)
self._db.add(row)
try:
await self._db.commit()
logger.info("memory: insert_archival user=%s source=%s", user_id, source)
except Exception as exc:
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
"""Search archival memory (keyword fallback; semantic ranking can replace this)."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryAssociative)
.where(MemoryAssociative.user_id == user_id)
.order_by(MemoryAssociative.updated_at.desc())
.limit(100)
)
rows = result.scalars().all()
needle = query.strip().lower()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.content_encrypted)
if plaintext is None:
continue
if not needle or needle in plaintext.lower():
out.append(plaintext)
if len(out) >= max(top_k, 1):
break
logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out))
return out
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
"""Search recall memory (episodic summaries) by keyword."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryEpisodic)
.where(MemoryEpisodic.user_id == user_id)
.order_by(MemoryEpisodic.created_at.desc())
.limit(100)
)
rows = result.scalars().all()
needle = query.strip().lower()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
if plaintext is None:
continue
if not needle or needle in plaintext.lower():
out.append(plaintext)
if len(out) >= max(top_k, 1):
break
logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out))
return out
# ── Private helpers ───────────────────────────────────────────────────────
async def _get_fernet(self, user_id: str) -> Fernet | None:
"""Load the user's Fernet key from DB. Returns None if missing."""
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None or not user.encryption_key:
logger.warning("memory: no encryption_key for user=%s", user_id)
return None
return Fernet(user.encryption_key.encode())
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
"""Load lightweight user debug fields for trace logs."""
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
return {"tier": None}
return {
"tier": user.tier,
}
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
result = await self._db.execute(
select(MemoryCore).where(MemoryCore.user_id == user_id)
)
rows = result.scalars().all()
out: dict[str, str] = {}
for row in rows:
plaintext = _safe_decrypt(fernet, row.value_encrypted)
if plaintext is not None:
out[row.key] = plaintext
return out
async def _load_associative(
self, user_id: str, message: str, fernet: Fernet
) -> list[str]:
"""Load top-k associative memories.
Production: uses pgvector cosine similarity on the message embedding.
Current implementation: keyword-based fallback (no external embedding call)
so tests pass without a live OpenAI key.
"""
result = await self._db.execute(
select(MemoryAssociative)
.where(MemoryAssociative.user_id == user_id)
.order_by(MemoryAssociative.updated_at.desc())
.limit(_ASSOCIATIVE_TOP_K)
)
rows = result.scalars().all()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.content_encrypted)
if plaintext is not None:
out.append(plaintext)
return out
async def _load_episodic(
self,
user_id: str,
fernet: Fernet,
session_id: str | None = None,
) -> list[str]:
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
if session_id:
query = query.where(MemoryEpisodic.session_id == session_id)
result = await self._db.execute(
query
.order_by(MemoryEpisodic.created_at.desc())
.limit(_EPISODIC_RECENT_N)
)
rows = result.scalars().all()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
if plaintext is not None:
out.append(plaintext)
return out
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
result = await self._db.execute(
select(MemoryProactive)
.where(
MemoryProactive.user_id == user_id,
MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
)
.order_by(MemoryProactive.confidence.desc())
)
rows = result.scalars().all()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.pattern_encrypted)
if plaintext is not None:
out.append(plaintext)
return out
# ── Encryption helpers ────────────────────────────────────────────────────────
def _encrypt(fernet: Fernet, plaintext: str) -> str:
return fernet.encrypt(plaintext.encode()).decode()
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
"""Decrypt and return plaintext, or None on error (corrupted/wrong key)."""
try:
return fernet.decrypt(ciphertext.encode()).decode()
except (InvalidToken, Exception) as exc:
logger.warning("memory: decrypt failed: %s", exc)
return None

View File

@@ -1,168 +0,0 @@
"""Orchestrator — LLM-based intent router and agent pipeline."""
from __future__ import annotations
import json
from typing import Any, AsyncGenerator
from langchain_core.messages import HumanMessage, SystemMessage
from app.core.agent_registry import AgentRegistry
from app.core.llm import get_router_llm
from app.core.agent_registry import registry as _default_registry
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
_FALLBACK_AGENT = "task_agent"
_CLASSIFY_SYSTEM = (
"You are an intent classifier. Given the user message and context, decide "
"which agent to route to.\n"
"Available agents: {agents}\n"
"Respond with just the agent name, nothing else."
)
_SYNTHESIZE_HUMAN = (
"Combine the following agent results into one coherent response.\n\n"
"Agent results:\n{results}\n\n"
"Original message: {message}"
)
def _make_llm():
return get_router_llm()
async def classify_intent(
message: str,
context: dict[str, Any],
reg: AgentRegistry,
) -> str:
"""Use gpt-4o-mini to classify intent and return the matching agent name.
Falls back to ``task_agent`` when the registry is empty or the model
returns a name that is not registered.
"""
agents = reg.list_agents()
if not agents:
return _FALLBACK_AGENT
system = _CLASSIFY_SYSTEM.format(agents=json.dumps(agents))
# Truncate context to keep the classification prompt short
human = f"Message: {message}\nContext summary: {json.dumps(context)[:500]}"
llm = _make_llm()
response = await llm.ainvoke(
[SystemMessage(content=system), HumanMessage(content=human)]
)
agent_name = str(response.content).strip().lower()
known = {a["name"] for a in agents}
return agent_name if agent_name in known else _FALLBACK_AGENT
async def route_single(
agent_name: str,
message: str,
context: dict[str, Any],
reg: AgentRegistry,
) -> ChatResponse:
"""Route to a single agent and wrap the result in a ``ChatResponse``."""
response_text = await reg.call_agent(agent_name, message, context)
return ChatResponse(response=response_text)
async def route_pipeline(
agent_names: list[str],
message: str,
context: dict[str, Any],
reg: AgentRegistry,
) -> ChatResponse:
"""Execute agents sequentially; each agent receives previous results in context.
A final LLM synthesis call merges all results into one coherent response.
"""
previous_results: list[str] = []
for agent_name in agent_names:
ctx = {**context, "previous_results": list(previous_results)}
result = await reg.call_agent(agent_name, message, ctx)
previous_results.append(result)
results_str = "\n\n".join(
f"[{name}]: {res}" for name, res in zip(agent_names, previous_results)
)
human = _SYNTHESIZE_HUMAN.format(results=results_str, message=message)
llm = _make_llm()
synthesis = await llm.ainvoke([HumanMessage(content=human)])
return ChatResponse(response=str(synthesis.content))
def _build_plan(agent_name: str, message: str) -> ExecutionPlan:
"""Build an ``ExecutionPlan`` for the resolved agent.
Uses ``ExecutionPlanBuilder`` with the server-side template registry.
If a default template exists for the agent, an LLM step is emitted;
otherwise a plain ``handle`` action step is used.
"""
from app.core.execution_plan import ExecutionPlanBuilder, template_registry
template_id = f"tpl_{agent_name}_default"
builder = ExecutionPlanBuilder(agent_name)
if template_registry.has(template_id):
builder.add_llm_step(template_id, {"message": message})
else:
builder.add_step("handle", {"message": message})
return builder.build()
async def orchestrate(
request: ChatRequest,
reg: AgentRegistry | None = None,
) -> ChatResponse | ExecutionPlan:
"""Main orchestration entry point.
* Classifies the user's intent to select an agent.
* ``execution_mode == 'direct'``: routes to the agent and returns a
``ChatResponse``.
* ``execution_mode == 'plan'``: returns an ``ExecutionPlan`` with the
resolved agent and a template-ID-only step (prompt IP stays server-side).
"""
if reg is None:
reg = _default_registry
context = request.context.model_dump()
agent_name = await classify_intent(request.message, context, reg)
if request.execution_mode == "direct":
return await route_single(agent_name, request.message, context, reg)
# plan mode — return plan, do not execute
return _build_plan(agent_name, request.message)
async def orchestrate_stream(
request: ChatRequest,
reg: AgentRegistry | None = None,
) -> AsyncGenerator[str, None]:
"""Streaming orchestration — yields text chunks then a final JSON frame.
The final frame is a JSON object:
``{"done": true, "response": "...", "actions": []}``.
Agents do not yet support token-level streaming; the full response is
fetched first, then emitted in fixed-size chunks. Token-level streaming
will be wired in Step 6 when agents expose ``astream()``.
"""
if reg is None:
reg = _default_registry
context = request.context.model_dump()
agent_name = await classify_intent(request.message, context, reg)
response_text = await reg.call_agent(agent_name, request.message, context)
chunk_size = 50
for i in range(0, len(response_text), chunk_size):
yield response_text[i : i + chunk_size]
final = ChatResponse(response=response_text)
yield json.dumps({"done": True, **final.model_dump()})

View File

@@ -0,0 +1,47 @@
"""Output formatter for deep-agent stream events."""
from __future__ import annotations
from collections.abc import AsyncGenerator
from typing import Any
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
class StreamFormatter:
"""Convert `(event_type, data)` stream events into websocket frame models."""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
async def format(
self,
event_stream: AsyncGenerator[tuple[str, Any], None],
) -> AsyncGenerator[WsFrame, None]:
started = False
async for event_type, data in event_stream:
if event_type == "floating_domain":
if isinstance(data, dict):
yield WsFloatingDomain(
request_id=self.request_id,
domain=data,
)
continue
if event_type != "token":
continue
if not started:
yield WsStreamStart(request_id=self.request_id)
started = True
text = str(data or "")
if text:
yield WsStreamText(request_id=self.request_id, chunk=text)
if not started:
yield WsStreamStart(request_id=self.request_id)
yield WsStreamEnd(request_id=self.request_id)

View File

@@ -0,0 +1,104 @@
"""Preprocessor registry: detect content type and dispatch to handlers.
Public API
----------
detect_content_type(filename, raw_content) -> str
Heuristic detection based on file extension and content patterns.
preprocess(content_type, raw_content) -> PreprocessResult
Dispatch to the appropriate handler.
"""
from __future__ import annotations
import re
from app.core.preprocessors.base import PreprocessResult
# ── Heuristics ────────────────────────────────────────────────────────
# Patterns that strongly suggest an email HTML file
_EMAIL_SIGNALS = re.compile(
r"(Subject:|From:|To:|Date:|Sent:|MIME-Version:|Content-Type:\s*text/html)",
re.IGNORECASE,
)
# Patterns that suggest a generic HTML page (not an email)
_GENERIC_HTML_SIGNALS = re.compile(
r"<(nav|main|header|footer|article|section)\b",
re.IGNORECASE,
)
def detect_content_type(filename: str, raw_content: str) -> str:
"""Return a content-type string for the given file.
Supported types: ``"email_html"``, ``"generic_html"``,
``"plain_text"``, ``"unknown"``.
"""
ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
if ext == "txt":
return "plain_text"
if ext in ("html", "htm", "eml", "mhtml", "mht"):
# Prefer email detection over generic HTML
if _EMAIL_SIGNALS.search(raw_content[:4096]):
return "email_html"
if _GENERIC_HTML_SIGNALS.search(raw_content[:4096]) or "<html" in raw_content[:200].lower():
return "generic_html"
# .html without clear signals — check for any email header
if re.search(r"^(From|To|Subject|Date):", raw_content[:2048], re.MULTILINE | re.IGNORECASE):
return "email_html"
return "generic_html"
# Plain text files with email headers
if ext in ("", "txt") or not ext:
if _EMAIL_SIGNALS.search(raw_content[:4096]):
return "email_html"
# Detect binary content
try:
raw_content.encode("utf-8")
except (UnicodeEncodeError, AttributeError):
return "unknown"
# Non-text bytes heuristic: high ratio of non-printable chars
sample = raw_content[:512]
non_printable = sum(1 for c in sample if ord(c) < 32 and c not in "\r\n\t")
if len(sample) > 0 and non_printable / len(sample) > 0.1:
return "unknown"
return "unknown"
# ── Generic fallback handler ──────────────────────────────────────────
def _preprocess_generic(raw_content: str, content_type: str) -> PreprocessResult:
"""Strip HTML tags if present, return text as-is."""
try:
from bs4 import BeautifulSoup
text = BeautifulSoup(raw_content, "html.parser").get_text(separator="\n")
except ImportError:
# No BeautifulSoup — strip tags with a simple regex
text = re.sub(r"<[^>]+>", "", raw_content)
text = re.sub(r"\n{3,}", "\n\n", text).strip()
return PreprocessResult(content_type=content_type, clean_text=text, metadata={})
# ── Dispatch ──────────────────────────────────────────────────────────
def preprocess(content_type: str, raw_content: str) -> PreprocessResult:
"""Dispatch *raw_content* to the handler registered for *content_type*.
Falls back to the generic handler for unknown types.
"""
if content_type == "email_html":
from app.core.preprocessors.email_html import preprocess_email_html
return preprocess_email_html(raw_content)
return _preprocess_generic(raw_content, content_type)
__all__ = ["detect_content_type", "preprocess", "PreprocessResult"]

View File

@@ -0,0 +1,25 @@
"""Base types for the preprocessor system."""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass
class PreprocessResult:
"""Output of a preprocessor handler.
Attributes
----------
content_type:
The detected content type (e.g. ``"email_html"``, ``"plain_text"``).
clean_text:
Human-readable text stripped of markup/binary noise.
metadata:
Dict of extracted metadata (keys vary by handler).
Common keys: ``subject``, ``from``, ``to``, ``date``, ``filename``.
"""
content_type: str
clean_text: str
metadata: dict = field(default_factory=dict)

View File

@@ -0,0 +1,111 @@
"""Preprocessor for email HTML files.
Handles:
- HTML stripping via BeautifulSoup
- Metadata extraction (Subject, From, To, Date)
- Thread splitting — isolates the latest reply
"""
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from app.core.preprocessors.base import PreprocessResult
if TYPE_CHECKING:
pass
# ── Thread split markers ──────────────────────────────────────────────
# Matches patterns like:
# "On Mon, Apr 7, 2026 at 10:00 AM, Alice <alice@co.com> wrote:"
# "-----Original Message-----"
# "> " (plain-text quote prefix)
_THREAD_PATTERNS = [
re.compile(r"^On\s+.+wrote\s*:", re.IGNORECASE | re.MULTILINE),
re.compile(r"^-{3,}\s*(original message|forwarded message)\s*-{3,}", re.IGNORECASE | re.MULTILINE),
re.compile(r"^>{1,}\s+\S", re.MULTILINE),
re.compile(r"^From:\s+.+\nSent:\s+", re.IGNORECASE | re.MULTILINE),
]
# ── Metadata patterns (applied on raw HTML / plain fallback) ──────────
_META_PATTERNS: dict[str, list[re.Pattern]] = {
"subject": [
re.compile(r"<title>(.+?)</title>", re.IGNORECASE | re.DOTALL),
re.compile(r"Subject:\s*(.+)", re.IGNORECASE),
],
"from": [
re.compile(r'<meta[^>]+name=["\']?from["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
re.compile(r"From:\s*(.+)", re.IGNORECASE),
],
"to": [
re.compile(r'<meta[^>]+name=["\']?to["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
re.compile(r"To:\s*(.+)", re.IGNORECASE),
],
"date": [
re.compile(r'<meta[^>]+name=["\']?date["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
re.compile(r"Date:\s*(.+)", re.IGNORECASE),
re.compile(r"Sent:\s*(.+)", re.IGNORECASE),
],
}
def _extract_metadata(raw_html: str, text: str) -> dict:
"""Extract Subject/From/To/Date from raw HTML or plain text."""
metadata: dict[str, str] = {}
for field, patterns in _META_PATTERNS.items():
for pat in patterns:
m = pat.search(raw_html) or pat.search(text)
if m:
metadata[field] = m.group(1).strip()
break
return metadata
def _split_thread(text: str) -> str:
"""Return only the latest message in a threaded email."""
earliest_pos: int | None = None
for pat in _THREAD_PATTERNS:
m = pat.search(text)
if m and (earliest_pos is None or m.start() < earliest_pos):
earliest_pos = m.start()
if earliest_pos is not None and earliest_pos > 0:
return text[:earliest_pos].strip()
return text.strip()
def preprocess_email_html(raw_content: str) -> PreprocessResult:
"""Strip HTML, extract metadata, split thread from an email HTML file."""
try:
from bs4 import BeautifulSoup # lazy import — optional dep
except ImportError as exc:
raise ImportError(
"beautifulsoup4 is required for email_html preprocessing. "
"Install it with: pip install beautifulsoup4"
) from exc
# Parse with lxml if available, fall back to html.parser
try:
soup = BeautifulSoup(raw_content, "lxml")
except Exception:
soup = BeautifulSoup(raw_content, "html.parser")
# Remove noise tags
for tag in soup(["style", "script", "head", "noscript"]):
tag.decompose()
clean_text = soup.get_text(separator="\n")
# Collapse excessive blank lines
clean_text = re.sub(r"\n{3,}", "\n\n", clean_text).strip()
metadata = _extract_metadata(raw_content, clean_text)
latest_message = _split_thread(clean_text)
return PreprocessResult(
content_type="email_html",
clean_text=latest_message,
metadata=metadata,
)

92
app/core/ws_context.py Normal file
View File

@@ -0,0 +1,92 @@
"""WebSocket client executor context.
Holds a per-request async callback that tools call to execute CRUD
operations on the Electron client's local SQLite / LanceDB databases.
The callback sends a `tool_call` WS frame and awaits the `tool_result`.
"""
from __future__ import annotations
from contextvars import ContextVar
from typing import Any, Callable, Coroutine
from uuid import uuid4
# Holds the execute callback for the current WS session.
# Set by the chat WS handler before the orchestrator runs; cleared after.
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
"_client_executor"
)
# Optional collector that captures raw execute_on_client results.
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
"_tool_result_collector", default=None
)
def set_tool_result_collector(lst: list[dict]) -> None:
"""Register *lst* as the collector for this async context."""
_tool_result_collector.set(lst)
def clear_tool_result_collector() -> None:
"""Clear the collector (best-effort)."""
_tool_result_collector.set(None)
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
"""Bind *fn* as the executor for the current async context (task/coroutine)."""
_client_executor.set(fn)
def clear_client_executor() -> None:
"""Remove the executor binding (best-effort; ContextVar resets on task exit)."""
try:
_client_executor.set(None) # type: ignore[arg-type]
except Exception:
pass
async def execute_on_client(
action: str,
table: str | None = None,
data: dict[str, Any] | None = None,
filters: dict[str, Any] | None = None,
vector: list[float] | None = None,
limit: int | None = None,
) -> dict[str, Any]:
"""Send a CRUD/vector operation to the Electron client and return the result.
Builds a ``tool_call`` payload, invokes the per-session WS callback,
and returns the ``tool_result`` dict from Electron.
Raises ``RuntimeError`` if no executor is set (i.e. called outside a WS session).
"""
callback = _client_executor.get(None)
if callback is None:
raise RuntimeError(
"execute_on_client() called outside a WebSocket session — "
"no client executor is set."
)
payload: dict[str, Any] = {"id": str(uuid4()), "action": action}
if table is not None:
payload["table"] = table
if data is not None:
payload["data"] = data
if filters is not None:
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
if vector is not None:
payload["vector"] = vector
if limit is not None:
payload["limit"] = limit
result = await callback(payload)
collector = _tool_result_collector.get(None)
if collector is not None:
collector.append({
"action": action,
"table": table,
"data": result,
})
return result

View File

@@ -24,7 +24,7 @@ from app.config.settings import settings
engine = create_async_engine(
settings.DATABASE_URL,
pool_pre_ping=True,
echo=settings.ENV == "dev",
echo=False,
)
async_session = async_sessionmaker(engine, expire_on_commit=False)

View File

@@ -0,0 +1,164 @@
"""Cloud provider integration utilities.
Provides:
* Shared message dataclasses (``EmailMessage``, ``ChatMessage``) used by
both the Gmail and MS Graph clients and consumed by ``agent_runner``.
* ``get_provider()`` — factory that returns the correct client given a
provider name and decrypted OAuth credentials dict.
* ``encrypt_token()`` / ``decrypt_token()`` — Fernet-based at-rest
encryption for OAuth tokens stored in ``cloud_agent_configs``.
Encryption rationale
--------------------
Unlike user content (which is E2E-encrypted client-side and **never**
decrypted server-side), OAuth tokens *must* be decrypted server-side
because the backend makes provider API calls on behalf of the user.
The Fernet key lives solely in ``OAUTH_ENCRYPTION_KEY`` env var — it
is never returned to clients.
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING
from cryptography.fernet import Fernet, InvalidToken
from app.config.settings import settings
if TYPE_CHECKING:
from app.integrations.gmail import GmailClient
from app.integrations.ms_graph import MSGraphClient
logger = logging.getLogger(__name__)
# ── Shared message types ──────────────────────────────────────────────────
@dataclass
class EmailMessage:
"""A single email message fetched from Gmail or Outlook."""
id: str
subject: str
sender: str
body_text: str
date: datetime
labels: list[str] = field(default_factory=list)
@property
def as_text(self) -> str:
"""Return a human-readable text representation for LLM extraction."""
date_str = self.date.strftime("%Y-%m-%d %H:%M")
labels_str = f" [{', '.join(self.labels)}]" if self.labels else ""
return (
f"From: {self.sender}\n"
f"Date: {date_str}{labels_str}\n"
f"Subject: {self.subject}\n\n"
f"{self.body_text}"
)
@dataclass
class ChatMessage:
"""A single Teams chat or channel message fetched from MS Graph."""
id: str
content: str
sender: str
channel: str | None
date: datetime
@property
def as_text(self) -> str:
"""Return a human-readable text representation for LLM extraction."""
date_str = self.date.strftime("%Y-%m-%d %H:%M")
channel_str = f" [channel: {self.channel}]" if self.channel else ""
return (
f"From: {self.sender}\n"
f"Date: {date_str}{channel_str}\n\n"
f"{self.content}"
)
# ── Fernet helpers ────────────────────────────────────────────────────────
def _get_fernet() -> Fernet:
"""Return a ``Fernet`` instance using ``settings.OAUTH_ENCRYPTION_KEY``.
Raises ``RuntimeError`` if ``OAUTH_ENCRYPTION_KEY`` is not set — callers
must ensure this is configured before persisting OAuth tokens.
"""
key = settings.OAUTH_ENCRYPTION_KEY
if not key:
raise RuntimeError(
"OAUTH_ENCRYPTION_KEY is not set. "
"Generate one with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
)
return Fernet(key.encode() if isinstance(key, str) else key)
def encrypt_token(token_info: dict) -> str:
"""Fernet-encrypt an OAuth credential dict and return a base64 string.
Stores the full ``{access_token, refresh_token, token_uri, client_id,
client_secret, scopes, expiry}`` dict (or equivalent MSAL shape).
Raises:
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
ValueError: ``token_info`` is not a non-empty dict.
"""
if not isinstance(token_info, dict) or not token_info:
raise ValueError("token_info must be a non-empty dict")
plaintext = json.dumps(token_info).encode("utf-8")
return _get_fernet().encrypt(plaintext).decode("utf-8")
def decrypt_token(encrypted: str) -> dict:
"""Decrypt a Fernet-encrypted token string and return the credential dict.
Raises:
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
ValueError: The encrypted string is invalid or was encrypted with a
different key.
"""
try:
plaintext = _get_fernet().decrypt(encrypted.encode("utf-8"))
return json.loads(plaintext)
except (InvalidToken, json.JSONDecodeError) as exc:
raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc
# ── Provider factory ──────────────────────────────────────────────────────
def get_provider(
provider: str,
credentials_info: dict,
) -> "GmailClient | MSGraphClient":
"""Return the correct provider client for *provider*.
Parameters
----------
provider:
One of ``"gmail"``, ``"outlook"``, ``"teams"``.
credentials_info:
Decrypted OAuth credential dict (Google or Microsoft shape).
Raises:
ValueError: Unknown provider name.
"""
if provider == "gmail":
from app.integrations.gmail import GmailClient
return GmailClient(credentials_info)
if provider in {"outlook", "teams"}:
from app.integrations.ms_graph import MSGraphClient
return MSGraphClient(credentials_info)
raise ValueError(
f"Unknown cloud provider {provider!r}. "
"Supported: 'gmail', 'outlook', 'teams'."
)

335
app/integrations/gmail.py Normal file
View File

@@ -0,0 +1,335 @@
"""Gmail API client for cloud agent integration.
Wraps the Google Gmail REST API to fetch email messages matching a
``filter_config`` dict. Uses the official ``google-api-python-client``
library (synchronous) wrapped in ``asyncio.to_thread()`` to avoid
blocking the event loop.
Token refresh is handled transparently: when the stored access token has
expired, ``google.auth.transport.requests.Request`` will use the refresh
token to obtain a fresh one. The caller is responsible for persisting
any refreshed credentials back to ``CloudAgentConfig.oauth_token_encrypted``
(see ``agent_runner.run_cloud_agent``).
Credential dict shape (Google OAuth2):
{
"token": "<access_token>",
"refresh_token": "<refresh_token>",
"token_uri": "https://oauth2.googleapis.com/token",
"client_id": "<client_id>",
"client_secret": "<client_secret>",
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
"expiry": "2025-01-01T00:00:00Z" # optional ISO-8601
}
"""
from __future__ import annotations
import asyncio
import base64
import email
import html
import logging
import re
from datetime import datetime, timezone
from typing import Any
from app.integrations import EmailMessage
logger = logging.getLogger(__name__)
# Gmail search date format — e.g. "after:2025/01/01"
_GMAIL_DATE_FMT = "%Y/%m/%d"
# Maximum characters of body text forwarded to the LLM.
_BODY_TRUNCATE = 8_000
# Maximum messages retrieved per run (prevents runaway quota usage).
_MAX_MESSAGES = 200
def _build_gmail_query(
filter_config: dict[str, Any] | None,
since: datetime | None,
) -> str:
"""Build a Gmail search query string from *filter_config* and *since*.
Supported ``filter_config`` keys:
labels (list[str]): Gmail label names, e.g. ``["INBOX", "work"]``
senders (list[str]): Sender addresses or domains to include
date_range (dict): ``{from: "<YYYY-MM-DD>", to: "<YYYY-MM-DD>"}``
A hard ``since`` date (from last run) always overrides ``date_range.from``
when it is earlier.
"""
parts: list[str] = []
cfg = filter_config or {}
# Labels — joined with OR when multiple given.
labels: list[str] = cfg.get("labels", [])
if labels:
if len(labels) == 1:
parts.append(f"label:{labels[0]}")
else:
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
parts.append(f"({label_expr})")
# Senders — each prefixed with "from:".
senders: list[str] = cfg.get("senders", [])
for sender in senders:
parts.append(f"from:{sender}")
# Date range.
date_range: dict = cfg.get("date_range", {})
from_str: str | None = date_range.get("from")
to_str: str | None = date_range.get("to")
# Determine effective "from" date: most recent of filter_config.date_range.from and since.
effective_since: datetime | None = since
if from_str:
try:
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
if cfg_since.tzinfo is None:
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
if effective_since is None or cfg_since > effective_since:
effective_since = cfg_since
except ValueError:
logger.warning("gmail: invalid date_range.from %r — ignoring", from_str)
if effective_since:
parts.append(f"after:{effective_since.strftime(_GMAIL_DATE_FMT)}")
if to_str:
try:
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
parts.append(f"before:{to_dt.strftime(_GMAIL_DATE_FMT)}")
except ValueError:
logger.warning("gmail: invalid date_range.to %r — ignoring", to_str)
return " ".join(parts)
def _strip_html(raw_html: str) -> str:
"""Remove HTML tags and decode entities to get plain text."""
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
decoded = html.unescape(no_tags)
return re.sub(r"\s+", " ", decoded).strip()
def _parse_body(payload: dict[str, Any]) -> str:
"""Recursively extract the plain-text body from a Gmail message payload.
Prefers ``text/plain``; falls back to ``text/html`` (stripped of tags).
Returns an empty string if no body can be extracted.
"""
mime_type: str = payload.get("mimeType", "")
body: dict = payload.get("body", {})
parts: list[dict] = payload.get("parts", [])
if mime_type == "text/plain":
data = body.get("data", "")
if data:
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
return ""
if mime_type == "text/html":
data = body.get("data", "")
if data:
raw = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
return _strip_html(raw)
return ""
# Multipart — prefer text/plain part, fall back to text/html.
plain_fallback = ""
for part in parts:
part_mime = part.get("mimeType", "")
if part_mime == "text/plain":
return _parse_body(part)
if part_mime == "text/html" and not plain_fallback:
plain_fallback = _parse_body(part)
if part_mime.startswith("multipart/"):
nested = _parse_body(part)
if nested:
return nested
return plain_fallback
def _parse_date(raw: str) -> datetime:
"""Parse an RFC 2822 email date header into a UTC ``datetime``."""
try:
parsed = email.utils.parsedate_to_datetime(raw)
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
return parsed.astimezone(timezone.utc)
except Exception:
return datetime.now(timezone.utc)
class GmailClient:
"""Fetch email messages from a Gmail account via the Gmail REST API.
Parameters
----------
credentials_info:
Decrypted OAuth2 credential dict. Must contain at minimum
``token`` (access token) or ``refresh_token`` + ``token_uri`` +
``client_id`` + ``client_secret``.
"""
def __init__(self, credentials_info: dict[str, Any]) -> None:
from google.oauth2.credentials import Credentials
self._credentials_info = credentials_info
expiry_str: str | None = credentials_info.get("expiry")
expiry: datetime | None = None
if expiry_str:
try:
expiry = datetime.fromisoformat(
expiry_str.replace("Z", "+00:00")
).replace(tzinfo=timezone.utc)
except ValueError:
pass
self._credentials = Credentials(
token=credentials_info.get("token"),
refresh_token=credentials_info.get("refresh_token"),
token_uri=credentials_info.get("token_uri", "https://oauth2.googleapis.com/token"),
client_id=credentials_info.get("client_id"),
client_secret=credentials_info.get("client_secret"),
scopes=credentials_info.get("scopes"),
expiry=expiry,
)
# ── Public API ─────────────────────────────────────────────────────────
async def fetch_messages(
self,
filter_config: dict[str, Any] | None = None,
since: datetime | None = None,
) -> list[EmailMessage]:
"""Return up to ``_MAX_MESSAGES`` emails matching *filter_config*.
Runs the synchronous Google API calls inside ``asyncio.to_thread()``
to avoid blocking the async event loop.
Token refresh is performed automatically when the access token has
expired. After the call, ``self.refreshed_credentials`` may be
consulted to detect whether new credentials should be persisted.
"""
query = _build_gmail_query(filter_config, since)
logger.debug("gmail: executing search query %r", query)
return await asyncio.to_thread(self._fetch_sync, query)
@property
def refreshed_credentials(self) -> dict[str, Any] | None:
"""Return updated credential dict if the access token was refreshed.
If the credentials were refreshed during ``fetch_messages()``, returns
a new dict that should be re-encrypted and written back to the DB.
Returns ``None`` if no refresh occurred.
"""
creds = self._credentials
if not creds.valid and creds.expired:
return None
# Check whether the token changed from what was stored.
if creds.token != self._credentials_info.get("token"):
result = {
"token": creds.token,
"refresh_token": creds.refresh_token,
"token_uri": creds.token_uri,
"client_id": creds.client_id,
"client_secret": creds.client_secret,
"scopes": list(creds.scopes or []),
}
if creds.expiry:
result["expiry"] = creds.expiry.isoformat()
return result
return None
# ── Internal sync worker ───────────────────────────────────────────────
def _fetch_sync(self, query: str) -> list[EmailMessage]:
"""Synchronous worker — called inside ``asyncio.to_thread()``."""
import googleapiclient.discovery
import googleapiclient.errors
from google.auth.transport.requests import Request
# Refresh token if needed before building the service.
if self._credentials.expired and self._credentials.refresh_token:
try:
self._credentials.refresh(Request())
except Exception as exc:
raise RuntimeError(f"Gmail token refresh failed: {exc}") from exc
service = googleapiclient.discovery.build(
"gmail", "v1", credentials=self._credentials, cache_discovery=False
)
user_api = service.users() # type: ignore[attr-defined]
# ── List matching message IDs ──────────────────────────────────────
ids: list[str] = []
page_token: str | None = None
while len(ids) < _MAX_MESSAGES:
batch_size = min(100, _MAX_MESSAGES - len(ids))
kwargs: dict[str, Any] = {
"userId": "me",
"maxResults": batch_size,
}
if query:
kwargs["q"] = query
if page_token:
kwargs["pageToken"] = page_token
try:
resp = user_api.messages().list(**kwargs).execute()
except googleapiclient.errors.HttpError as exc:
raise RuntimeError(f"Gmail messages.list failed: {exc}") from exc
for msg in resp.get("messages", []):
ids.append(msg["id"])
page_token = resp.get("nextPageToken")
if not page_token:
break
if not ids:
logger.debug("gmail: no messages matched query %r", query)
return []
logger.info("gmail: fetching %d message(s)", len(ids))
# ── Fetch individual message details ──────────────────────────────
messages: list[EmailMessage] = []
for msg_id in ids:
try:
msg = user_api.messages().get(
userId="me", id=msg_id, format="full"
).execute()
headers: dict[str, str] = {
h["name"].lower(): h["value"]
for h in msg.get("payload", {}).get("headers", [])
}
subject = headers.get("subject", "(no subject)")
sender = headers.get("from", "unknown")
date_raw = headers.get("date", "")
date = _parse_date(date_raw) if date_raw else datetime.now(timezone.utc)
body_text = _parse_body(msg.get("payload", {}))[:_BODY_TRUNCATE]
labels = msg.get("labelIds", [])
messages.append(EmailMessage(
id=msg_id,
subject=subject,
sender=sender,
body_text=body_text,
date=date,
labels=labels,
))
except googleapiclient.errors.HttpError as exc:
logger.warning("gmail: skipping message %s — HTTP error: %s", msg_id, exc)
except Exception as exc:
logger.warning("gmail: skipping message %s — unexpected error: %s", msg_id, exc)
logger.info("gmail: returned %d message(s)", len(messages))
return messages

View File

@@ -0,0 +1,352 @@
"""Microsoft Graph API client for Outlook and Teams cloud agent integration.
Handles two data sources:
* **Outlook email** (``provider="outlook"``) — ``fetch_emails()`` calls
``/me/messages`` with an OData ``$filter`` built from ``filter_config``.
* **Teams messages** (``provider="teams"``) — ``fetch_messages()`` calls
``/me/chats/getAllMessages`` filtered by date.
Authentication uses MSAL ``PublicClientApplication`` to acquire a token
from a stored refresh token. The ``httpx.AsyncClient`` (already a project
dependency) is used for all API calls.
Credential dict shape (Microsoft OAuth2 / MSAL):
{
"access_token": "<access_token>",
"refresh_token": "<refresh_token>",
"token_type": "Bearer",
"scope": "Mail.Read ChannelMessage.Read.All offline_access",
"expires_in": 3600
}
"""
from __future__ import annotations
import logging
import re
from datetime import datetime, timedelta, timezone
from typing import Any
import httpx
from app.config.settings import settings
from app.integrations import ChatMessage, EmailMessage
logger = logging.getLogger(__name__)
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"
# Max items fetched per run.
_MAX_EMAILS = 200
_MAX_MESSAGES = 200
# Max characters of body forwarded to the LLM.
_BODY_TRUNCATE = 8_000
def _strip_html(raw: str) -> str:
"""Strip HTML tags and collapse whitespace."""
no_tags = re.sub(r"<[^>]+>", " ", raw)
import html as _html
decoded = _html.unescape(no_tags)
return re.sub(r"\s+", " ", decoded).strip()
def _odata_datetime(dt: datetime) -> str:
"""Format a datetime as an OData datetime literal (UTC, ISO 8601)."""
utc = dt.astimezone(timezone.utc)
return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
def _build_email_filter(
filter_config: dict[str, Any] | None,
since: datetime | None,
) -> str:
"""Build an OData ``$filter`` expression for the ``/me/messages`` endpoint.
Supported ``filter_config`` keys:
senders (list[str]): Sender email addresses.
date_range (dict): ``{from: "<ISO-8601>", to: "<ISO-8601>"}``
folders (list[str]): Folder display names (not directly filterable
via OData, so ignored here — callers iterate
folder IDs separately if needed; listed for
completeness).
A hard ``since`` date always overrides ``date_range.from`` when it is
earlier.
"""
clauses: list[str] = []
cfg = filter_config or {}
# Senders.
senders: list[str] = cfg.get("senders", [])
if senders:
sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders]
clauses.append("(" + " or ".join(sender_clauses) + ")")
# Date range.
date_range: dict = cfg.get("date_range", {})
from_str: str | None = date_range.get("from")
effective_since: datetime | None = since
if from_str:
try:
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
if cfg_since.tzinfo is None:
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
if effective_since is None or cfg_since > effective_since:
effective_since = cfg_since
except ValueError:
logger.warning("ms_graph: invalid date_range.from %r — ignoring", from_str)
if effective_since:
clauses.append(f"receivedDateTime ge {_odata_datetime(effective_since)}")
to_str: str | None = date_range.get("to")
if to_str:
try:
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
if to_dt.tzinfo is None:
to_dt = to_dt.replace(tzinfo=timezone.utc)
clauses.append(f"receivedDateTime le {_odata_datetime(to_dt)}")
except ValueError:
logger.warning("ms_graph: invalid date_range.to %r — ignoring", to_str)
return " and ".join(clauses)
class MSGraphClient:
"""Fetch emails and Teams messages via the Microsoft Graph REST API.
Parameters
----------
credentials_info:
Decrypted MSAL credential dict.
"""
def __init__(self, credentials_info: dict[str, Any]) -> None:
self._credentials_info = credentials_info
self._access_token: str = credentials_info.get("access_token", "")
self._original_access_token: str = self._access_token
self._refresh_token: str | None = credentials_info.get("refresh_token")
# ── Token management ───────────────────────────────────────────────────
def _auth_headers(self) -> dict[str, str]:
return {"Authorization": f"Bearer {self._access_token}"}
async def _refresh_access_token(self) -> None:
"""Use MSAL to exchange the refresh token for a fresh access token.
Updates ``self._access_token`` and ``self._credentials_info`` in-place.
Raises:
RuntimeError: MSAL reports an auth error.
"""
import msal
app = msal.ConfidentialClientApplication(
client_id=settings.MS_CLIENT_ID,
client_credential=settings.MS_CLIENT_SECRET,
authority=f"https://login.microsoftonline.com/{settings.MS_TENANT_ID}",
)
scopes: list[str] = self._credentials_info.get("scope", "").split()
if not scopes:
scopes = ["https://graph.microsoft.com/.default"]
result = app.acquire_token_by_refresh_token(
self._refresh_token,
scopes=scopes,
)
if "access_token" not in result:
error = result.get("error_description", result.get("error", "unknown"))
raise RuntimeError(f"MS Graph token refresh failed: {error}")
self._access_token = result["access_token"]
# MSAL may issue a new refresh token.
if "refresh_token" in result:
self._refresh_token = result["refresh_token"]
self._credentials_info["refresh_token"] = result["refresh_token"]
self._credentials_info["access_token"] = self._access_token
@property
def refreshed_credentials(self) -> dict[str, Any] | None:
"""Return updated credential dict if the access token was refreshed.
Returns ``None`` if no change was made.
"""
if self._access_token != self._original_access_token:
return {**self._credentials_info, "access_token": self._access_token}
return None
# ── HTTP helpers ───────────────────────────────────────────────────────
async def _get(
self,
client: httpx.AsyncClient,
url: str,
params: dict[str, Any] | None = None,
*,
retry_on_401: bool = True,
) -> dict[str, Any]:
"""GET *url* with auth; refresh token on 401 and retry once."""
resp = await client.get(url, params=params, headers=self._auth_headers())
if resp.status_code == 401 and retry_on_401 and self._refresh_token:
logger.debug("ms_graph: 401 on %s — refreshing token", url)
await self._refresh_access_token()
resp = await client.get(url, params=params, headers=self._auth_headers())
if resp.status_code == 429:
raise RuntimeError("MS Graph rate limit hit (429). Try again later.")
resp.raise_for_status()
return resp.json()
# ── Public API ─────────────────────────────────────────────────────────
async def fetch_emails(
self,
filter_config: dict[str, Any] | None = None,
since: datetime | None = None,
) -> list[EmailMessage]:
"""Return up to ``_MAX_EMAILS`` Outlook messages matching *filter_config*.
Parameters
----------
filter_config:
Optional dict with ``senders``, ``date_range``, ``folders`` keys.
since:
Hard lower-bound on email date (from last agent run).
"""
odata_filter = _build_email_filter(filter_config, since)
params: dict[str, Any] = {
"$top": 50,
"$select": "id,subject,from,receivedDateTime,body,bodyPreview",
"$orderby": "receivedDateTime desc",
}
if odata_filter:
params["$filter"] = odata_filter
emails: list[EmailMessage] = []
url = f"{_GRAPH_BASE}/me/messages"
async with httpx.AsyncClient(timeout=30.0) as client:
while url and len(emails) < _MAX_EMAILS:
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
for item in data.get("value", []):
emails.append(self._parse_email(item))
if len(emails) >= _MAX_EMAILS:
break
url = data.get("@odata.nextLink", "")
params = {} # nextLink already contains encoded params.
logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
return emails
async def fetch_messages(
self,
filter_config: dict[str, Any] | None = None,
since: datetime | None = None,
) -> list[ChatMessage]:
"""Return up to ``_MAX_MESSAGES`` Teams messages matching *filter_config*.
Fetches from ``/me/chats/getAllMessages`` (personal + group chats).
The ``filter_config.channels`` key is checked as a text-filter on
the channel name post-fetch (the API doesn't support channel OData
filter directly on ``getAllMessages``).
"""
cfg = filter_config or {}
channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])]
params: dict[str, Any] = {"$top": 50}
if since:
params["$filter"] = f"createdDateTime ge {_odata_datetime(since)}"
messages: list[ChatMessage] = []
url = f"{_GRAPH_BASE}/me/chats/getAllMessages"
async with httpx.AsyncClient(timeout=30.0) as client:
while url and len(messages) < _MAX_MESSAGES:
try:
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
except httpx.HTTPStatusError as exc:
# getAllMessages requires specific licensing; degrade gracefully.
if exc.response.status_code in (403, 404):
logger.warning(
"ms_graph: /me/chats/getAllMessages not available (%d) — "
"check Teams license or permissions",
exc.response.status_code,
)
break
raise
for item in data.get("value", []):
msg = self._parse_teams_message(item)
if channel_filter and msg.channel:
if not any(c in msg.channel.lower() for c in channel_filter):
continue
messages.append(msg)
if len(messages) >= _MAX_MESSAGES:
break
url = data.get("@odata.nextLink", "")
params = {}
logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
return messages
# ── Parsers ────────────────────────────────────────────────────────────
@staticmethod
def _parse_email(item: dict[str, Any]) -> EmailMessage:
subject: str = item.get("subject", "(no subject)") or "(no subject)"
sender_block = item.get("from", {}) or {}
sender_addr = (
(sender_block.get("emailAddress") or {}).get("address", "unknown")
)
date_str: str = item.get("receivedDateTime", "")
try:
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
except Exception:
date = datetime.now(timezone.utc)
body_block = item.get("body", {}) or {}
content_type: str = body_block.get("contentType", "text")
raw_body: str = body_block.get("content", "")
if content_type == "html":
body_text = _strip_html(raw_body)
else:
body_text = raw_body or item.get("bodyPreview", "")
body_text = body_text[:_BODY_TRUNCATE]
return EmailMessage(
id=item.get("id", ""),
subject=subject,
sender=sender_addr,
body_text=body_text,
date=date,
)
@staticmethod
def _parse_teams_message(item: dict[str, Any]) -> ChatMessage:
msg_id: str = item.get("id", "")
sender_block = (item.get("from") or {}).get("user") or {}
sender: str = sender_block.get("displayName", "unknown")
channel: str | None = (item.get("channelIdentity") or {}).get("channelId")
date_str: str = item.get("createdDateTime", "")
try:
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
except Exception:
date = datetime.now(timezone.utc)
body_block = item.get("body", {}) or {}
content_type: str = body_block.get("contentType", "text")
raw_content: str = body_block.get("content", "")
content = _strip_html(raw_content) if content_type == "html" else raw_content
content = content[:_BODY_TRUNCATE]
return ChatMessage(
id=msg_id,
content=content,
sender=sender,
channel=channel,
date=date,
)

View File

@@ -1,8 +1,16 @@
from contextlib import asynccontextmanager
import logging
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
from app.api.middleware.rate_limit import TierRateLimitMiddleware
from app.api.middleware.sanitizer import SanitizerMiddleware
from app.config.settings import settings
@@ -10,9 +18,8 @@ from app.config.settings import settings
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup: initialise DB connection pool and agent registry
from app.core.agent_registry import registry # noqa: F401 — triggers module load
import app.agents # noqa: F401 — triggers @registry.register decorators
# Startup: ensure agent tool modules are loaded.
import app.agents # noqa: F401
yield
@@ -23,7 +30,7 @@ async def lifespan(app: FastAPI):
def create_app() -> FastAPI:
app = FastAPI(
title="Adiuva Cloud API",
title="AdiuvAI Cloud API",
version="0.1.0",
docs_url="/docs" if settings.ENV == "dev" else None,
redoc_url=None,
@@ -43,16 +50,13 @@ def create_app() -> FastAPI:
app.add_middleware(SanitizerMiddleware)
app.add_middleware(TierRateLimitMiddleware)
from app.api.routes import auth, backup, billing, chat, plans, plugins, storage, vectors
from app.api.routes import agents, auth, billing, chat, device_ws
app.include_router(auth.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1")
app.include_router(plans.router, prefix="/api/v1")
app.include_router(storage.router, prefix="/api/v1")
app.include_router(vectors.router, prefix="/api/v1")
app.include_router(backup.router, prefix="/api/v1")
app.include_router(plugins.router, prefix="/api/v1")
app.include_router(billing.router, prefix="/api/v1")
app.include_router(auth.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1")
app.include_router(billing.router, prefix="/api/v1")
app.include_router(agents.router, prefix="/api/v1")
app.include_router(device_ws.router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"])
async def health() -> dict:

View File

@@ -1,7 +0,0 @@
"""Plugin marketplace package.
Three service classes introduced in Step 10:
- ``PluginRegistry`` — catalog, submit/approve/reject, install counts
- ``ReviewQueue`` — approval workflow + security checklist
- ``RevenueShare`` — 70/30 split tracking and Stripe Connect payouts
"""

View File

@@ -1,212 +0,0 @@
"""Plugin catalog registry backed by PostgreSQL.
Maintains the authoritative list of plugins, their review status, and
aggregate install counts. All data is persisted in the ``plugins`` table.
Module-level singleton::
from app.marketplace.plugin_registry import registry
"""
from __future__ import annotations
import json
from typing import Any, Literal
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import Plugin
from app.schemas import PluginListResponse, PluginManifest
_PAGE_SIZE = 20
def _plugin_to_manifest(p: Plugin) -> PluginManifest:
"""Convert an ORM ``Plugin`` row to a Pydantic ``PluginManifest``."""
try:
permissions = json.loads(p.permissions) if p.permissions else []
except (json.JSONDecodeError, TypeError):
permissions = []
return PluginManifest(
id=p.id,
name=p.name,
description=p.description,
version=p.version,
author=p.author_name,
permissions=permissions,
category=p.category,
price_cents=p.price_cents,
)
class PluginRegistry:
"""PostgreSQL-backed plugin catalog.
All methods accept an ``AsyncSession`` parameter so the calling route
controls the session lifecycle.
"""
# ── Queries ──────────────────────────────────────────────────────
async def list_plugins(
self,
db: AsyncSession,
category: str | None = None,
query: str | None = None,
page: int = 1,
sort: Literal["rating", "installs", "newest"] = "newest",
) -> PluginListResponse:
"""Return a page of approved plugins, optionally filtered and sorted."""
base = select(Plugin).where(Plugin.status == "approved")
if category:
base = base.where(Plugin.category == category)
if query:
pattern = f"%{query}%"
base = base.where(
Plugin.name.ilike(pattern) | Plugin.description.ilike(pattern)
)
# Count
count_q = select(func.count()).select_from(base.subquery())
total = (await db.execute(count_q)).scalar_one()
# Sort
if sort == "installs":
base = base.order_by(Plugin.install_count.desc())
elif sort == "rating":
base = base.order_by(Plugin.avg_rating.desc())
else: # newest
base = base.order_by(Plugin.created_at.desc())
base = base.offset((page - 1) * _PAGE_SIZE).limit(_PAGE_SIZE)
rows = (await db.execute(base)).scalars().all()
return PluginListResponse(
plugins=[_plugin_to_manifest(r) for r in rows],
total=total,
page=page,
)
async def get_plugin(self, db: AsyncSession, plugin_id: str) -> dict[str, Any] | None:
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
p = result.scalar_one_or_none()
if p is None:
return None
return {
"manifest": _plugin_to_manifest(p),
"status": p.status,
"install_count": p.install_count,
"avg_rating": p.avg_rating,
}
# ── Mutations ────────────────────────────────────────────────────
async def submit_plugin(
self,
db: AsyncSession,
manifest: PluginManifest,
package_s3_key: str,
) -> str:
"""Add *manifest* to the catalog with ``status='pending_review'``.
Returns the plugin_id. If a plugin with the same id already exists
it is overwritten (re-submission after rejection).
"""
plugin_id = manifest.id
existing = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = existing.scalar_one_or_none()
if row is not None:
row.name = manifest.name
row.description = manifest.description
row.version = manifest.version
row.author_name = manifest.author
row.category = manifest.category
row.price_cents = manifest.price_cents
row.permissions = json.dumps(manifest.permissions)
row.status = "pending_review"
row.s3_package_key = package_s3_key
row.rejection_reason = None
else:
row = Plugin(
id=plugin_id,
name=manifest.name,
description=manifest.description,
version=manifest.version,
author_name=manifest.author,
category=manifest.category,
price_cents=manifest.price_cents,
permissions=json.dumps(manifest.permissions),
status="pending_review",
s3_package_key=package_s3_key,
install_count=0,
avg_rating=0.0,
)
db.add(row)
await db.commit()
return plugin_id
async def approve_plugin(self, db: AsyncSession, plugin_id: str) -> None:
"""Set *plugin_id* status to ``'approved'``.
Raises ``KeyError`` if the plugin is not found.
"""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is None:
raise KeyError(f"Plugin not found: {plugin_id}")
row.status = "approved"
row.rejection_reason = None
await db.commit()
async def reject_plugin(self, db: AsyncSession, plugin_id: str, reason: str) -> None:
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
Raises ``KeyError`` if the plugin is not found.
"""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is None:
raise KeyError(f"Plugin not found: {plugin_id}")
row.status = "rejected"
row.rejection_reason = reason
await db.commit()
async def record_install(self, db: AsyncSession, plugin_id: str) -> None:
"""Increment the install count for *plugin_id* (no-op if not found)."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is not None:
row.install_count = row.install_count + 1
await db.commit()
async def record_uninstall(self, db: AsyncSession, plugin_id: str) -> None:
"""Decrement the install count for *plugin_id*, floored at 0."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is not None:
row.install_count = max(0, row.install_count - 1)
await db.commit()
# ── Internal helpers used by ReviewQueue ─────────────────────────
async def get_pending_entries(self, db: AsyncSession) -> list[dict[str, Any]]:
"""Return all entries with status='pending_review'."""
result = await db.execute(
select(Plugin).where(Plugin.status == "pending_review")
)
rows = result.scalars().all()
return [
{
"manifest": _plugin_to_manifest(r),
"submitted_at": int(r.submitted_at.timestamp()) if r.submitted_at else 0,
}
for r in rows
]
# Module-level singleton
registry = PluginRegistry()

View File

@@ -1,125 +0,0 @@
"""Plugin review workflow backed by PostgreSQL.
Manages the approval queue for newly submitted plugins and enforces a
security checklist before any plugin is made visible in the marketplace.
Module-level singleton::
from app.marketplace.plugin_review import review_queue
"""
from __future__ import annotations
import re
from typing import Any, Literal
from sqlalchemy.ext.asyncio import AsyncSession
from app.marketplace.plugin_registry import registry
from app.models import PluginReview as PluginReviewModel
from app.schemas import PluginManifest
# ── Security policy ───────────────────────────────────────────────────
ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
{
"read:tasks",
"write:tasks",
"read:projects",
"write:projects",
"read:notes",
"write:notes",
"read:checkpoints",
"write:checkpoints",
"read:calendar",
"write:calendar",
}
)
_PLUGIN_ID_RE = re.compile(r"^[a-z0-9-]+$")
def validate_manifest(manifest: PluginManifest) -> None:
"""Enforce the plugin security checklist.
Raises:
``ValueError`` on the first violation found. Callers should catch
this and return HTTP 422 / reject the submission.
Checks:
1. Plugin id matches ``^[a-z0-9-]+$``
2. All declared permissions are in ``ALLOWED_PERMISSIONS``
3. No manifest field contains raw binary data
"""
if not _PLUGIN_ID_RE.match(manifest.id):
raise ValueError(
f"Invalid plugin id format: '{manifest.id}'. "
"Only lowercase letters, digits, and hyphens are allowed."
)
for perm in manifest.permissions:
if perm not in ALLOWED_PERMISSIONS:
raise ValueError(
f"Unknown permission: '{perm}'. "
f"Allowed permissions: {sorted(ALLOWED_PERMISSIONS)}"
)
for field_name, value in manifest.model_dump().items():
if isinstance(value, (bytes, bytearray)):
raise ValueError(
f"Binary content is not allowed in manifest field '{field_name}'."
)
class ReviewQueue:
"""Approval queue for pending plugin submissions.
Delegates status changes to the shared ``PluginRegistry`` singleton.
Review records are persisted in the ``plugin_reviews`` table.
"""
async def get_pending(self, db: AsyncSession) -> list[dict[str, Any]]:
"""Return all plugins currently awaiting review.
Each item is ``{plugin_id, manifest, submitted_at}``.
"""
entries = await registry.get_pending_entries(db)
return [
{
"plugin_id": e["manifest"].id,
"manifest": e["manifest"],
"submitted_at": e["submitted_at"],
}
for e in entries
]
async def submit_review(
self,
db: AsyncSession,
plugin_id: str,
reviewer_id: str,
decision: Literal["approved", "rejected"],
notes: str = "",
) -> None:
"""Record a review decision and update the plugin's status.
Raises:
``KeyError`` if *plugin_id* is not found in the registry.
"""
if decision == "approved":
await registry.approve_plugin(db, plugin_id)
else:
await registry.reject_plugin(db, plugin_id, reason=notes)
review = PluginReviewModel(
plugin_id=plugin_id,
reviewer_id=reviewer_id,
decision=decision,
notes=notes,
)
db.add(review)
await db.commit()
# Module-level singleton
review_queue = ReviewQueue()

View File

@@ -1,233 +0,0 @@
"""Revenue share tracking and Stripe Connect payouts backed by PostgreSQL.
Records every plugin installation as a revenue event and facilitates
70 % / 30 % payouts to developers via Stripe Connect. Data is persisted
in the ``revenue_events`` table.
Module-level singleton::
from app.marketplace.revenue_share import revenue_share
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from typing import Any
import stripe as stripe_lib
from sqlalchemy import extract, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
from app.marketplace.plugin_registry import registry
from app.models import Plugin, RevenueEvent
logger = logging.getLogger(__name__)
# ── Revenue split constants ───────────────────────────────────────────
DEVELOPER_SHARE: float = 0.70
PLATFORM_SHARE: float = 0.30
class RevenueShare:
"""Records installation revenue events and coordinates developer payouts.
Stripe Connect calls are gracefully stubbed when ``STRIPE_SECRET_KEY``
is not configured, consistent with the rest of the billing layer.
"""
# ── Helpers ──────────────────────────────────────────────────────
@staticmethod
def _stripe_configured() -> bool:
return bool(settings.STRIPE_SECRET_KEY)
@staticmethod
def _stripe() -> Any:
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
return stripe_lib
# ── Core operations ──────────────────────────────────────────────
async def record_install(
self,
db: AsyncSession,
plugin_id: str,
user_id: str,
amount_cents: int,
) -> None:
"""Record a plugin installation and trigger a Stripe Connect charge if paid.
For free plugins (``amount_cents == 0``) no payment is initiated but
the event is still recorded for analytics.
For paid plugins the developer receives 70 % via a Stripe Connect
destination charge. If Stripe is not configured or the charge fails
the installation still succeeds (the event is recorded and the install
count is incremented) — a warning is logged for monitoring.
"""
developer_share_cents = int(amount_cents * DEVELOPER_SHARE)
stripe_transfer_id: str | None = None
if amount_cents > 0 and self._stripe_configured():
# Look up the plugin's author Stripe account from the DB
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
plugin_row = result.scalar_one_or_none()
developer_stripe_account: str | None = None
if plugin_row and plugin_row.author_id:
# Future: look up user.stripe_connect_account_id
developer_stripe_account = None # no real account yet
if developer_stripe_account:
try:
s = self._stripe()
transfer = s.Transfer.create(
amount=developer_share_cents,
currency="eur",
destination=developer_stripe_account,
description=f"Revenue share for plugin {plugin_id}",
metadata={"plugin_id": plugin_id, "user_id": user_id},
)
stripe_transfer_id = transfer["id"]
except Exception as exc:
logger.warning(
"Stripe Connect transfer failed for plugin %s: %s",
plugin_id,
exc,
)
else:
logger.debug(
"No Stripe account on file for plugin %s developer; "
"skipping transfer.",
plugin_id,
)
event = RevenueEvent(
plugin_id=plugin_id,
user_id=user_id,
amount_cents=amount_cents,
developer_share_cents=developer_share_cents,
stripe_transfer_id=stripe_transfer_id,
)
db.add(event)
await db.commit()
await registry.record_install(db, plugin_id)
async def get_earnings(
self,
db: AsyncSession,
developer_id: str,
period: str | None = None,
) -> dict[str, Any]:
"""Return aggregated earnings for *developer_id*.
``period`` is an optional ``YYYY-MM`` string to restrict the window.
Returns::
{
"developer_id": str,
"period": str | None,
"total_installs": int,
"total_revenue_cents": int,
"developer_share_cents": int,
}
"""
# Find plugin ids belonging to this developer (by author_name match)
plugin_q = select(Plugin.id).where(Plugin.author_name == developer_id)
plugin_result = await db.execute(plugin_q)
developer_plugin_ids = [row[0] for row in plugin_result.all()]
if not developer_plugin_ids:
return {
"developer_id": developer_id,
"period": period,
"total_installs": 0,
"total_revenue_cents": 0,
"developer_share_cents": 0,
}
query = select(
func.count().label("total_installs"),
func.coalesce(func.sum(RevenueEvent.amount_cents), 0).label("total_revenue"),
func.coalesce(func.sum(RevenueEvent.developer_share_cents), 0).label("dev_share"),
).where(RevenueEvent.plugin_id.in_(developer_plugin_ids))
if period:
# Filter by YYYY-MM: extract year and month from created_at
try:
year, month = period.split("-")
query = query.where(
extract("year", RevenueEvent.created_at) == int(year),
extract("month", RevenueEvent.created_at) == int(month),
)
except ValueError:
pass # invalid period format — return all
result = await db.execute(query)
row = result.one()
return {
"developer_id": developer_id,
"period": period,
"total_installs": row.total_installs,
"total_revenue_cents": row.total_revenue,
"developer_share_cents": row.dev_share,
}
async def payout_developer(self, db: AsyncSession, plugin_id: str, period: str) -> None:
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
Marks processed events with ``paid_at`` timestamp.
Stubs gracefully when Stripe is not configured.
"""
try:
year, month = period.split("-")
year_int, month_int = int(year), int(month)
except ValueError:
logger.warning("Invalid period format: %s", period)
return
result = await db.execute(
select(RevenueEvent).where(
RevenueEvent.plugin_id == plugin_id,
RevenueEvent.paid_at.is_(None),
extract("year", RevenueEvent.created_at) == year_int,
extract("month", RevenueEvent.created_at) == month_int,
)
)
unpaid = list(result.scalars().all())
total_dev_share = sum(e.developer_share_cents for e in unpaid)
if total_dev_share <= 0 or not unpaid:
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
return
if self._stripe_configured():
plugin_result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
plugin_row = plugin_result.scalar_one_or_none()
developer_stripe_account: str | None = None # Future: fetch from DB
if plugin_row and developer_stripe_account:
try:
s = self._stripe()
s.Transfer.create(
amount=total_dev_share,
currency="eur",
destination=developer_stripe_account,
description=f"Payout for plugin {plugin_id} period {period}",
)
except Exception as exc:
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
return
paid_ts = datetime.now(timezone.utc)
for event in unpaid:
event.paid_at = paid_ts
await db.commit()
# Module-level singleton
revenue_share = RevenueShare()

View File

@@ -1,19 +1,19 @@
"""SQLAlchemy ORM models for all persistent tables.
Only auth, billing, storage metadata, and marketplace data live here.
User content (notes, tasks, etc.) is NEVER persisted server-side —
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
Only auth, billing, agent config, and memory data live here.
User content (notes, tasks, etc.) lives exclusively on the client.
Table inventory:
users — account credentials + tier
refresh_tokens — hashed refresh token store
subscriptions — Stripe subscription records
storage_records — S3 blob metadata (no plaintext)
backup_metadata — encrypted backup manifests
plugins — marketplace plugin catalog
plugin_installations — per-user install records
plugin_reviews — admin review decisions
revenue_events — Stripe Connect 70/30 split ledger
local_agent_configs — per-device batch agent configs
cloud_agent_configs — OAuth-backed cloud agent configs
agent_run_logs — execution history for all agents
memory_core — per-user persistent key/value preferences (encrypted)
memory_associative — per-user semantic memory with embeddings (encrypted)
memory_episodic — per-user session summaries (encrypted)
memory_proactive — per-user behavioral patterns (encrypted)
"""
from __future__ import annotations
@@ -22,15 +22,15 @@ import uuid
from datetime import datetime, timezone
from sqlalchemy import (
BigInteger,
Boolean,
DateTime,
Enum,
Float,
ForeignKey,
Integer,
JSON,
String,
Text,
UniqueConstraint,
Uuid,
func,
)
@@ -52,8 +52,9 @@ def _now() -> datetime:
# ── Enum types ────────────────────────────────────────────────────────────
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
AgentTypeEnum = Enum("local", "cloud", name="agent_type")
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
# ── Models ────────────────────────────────────────────────────────────────
@@ -66,9 +67,14 @@ class User(Base):
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
# Used to encrypt/decrypt all memory rows for this user.
encryption_key: Mapped[str | None] = mapped_column(String(64), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
@@ -123,8 +129,8 @@ class Subscription(Base):
user: Mapped[User] = relationship(back_populates="subscription")
class StorageRecord(Base):
__tablename__ = "storage_records"
class LocalAgentConfig(Base):
__tablename__ = "local_agent_configs"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
@@ -132,10 +138,16 @@ class StorageRecord(Base):
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
device_id: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
agent_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
@@ -143,9 +155,17 @@ class StorageRecord(Base):
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
run_logs: Mapped[list[AgentRunLog]] = relationship(
back_populates="local_agent",
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
foreign_keys="AgentRunLog.agent_id",
cascade="all, delete-orphan",
overlaps="run_logs,cloud_agent",
)
class BackupMetadata(Base):
__tablename__ = "backup_metadata"
class CloudAgentConfig(Base):
__tablename__ = "cloud_agent_configs"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
@@ -153,116 +173,152 @@ class BackupMetadata(Base):
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
version: Mapped[int] = mapped_column(Integer, nullable=False)
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
class Plugin(Base):
__tablename__ = "plugins"
id: Mapped[str] = mapped_column(String(255), primary_key=True)
provider: Mapped[str] = mapped_column(CloudProviderEnum, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
# nullable until developer account system is built
author_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
submitted_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
oauth_token_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True)
filter_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
installations: Mapped[list[PluginInstallation]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
reviews: Mapped[list[PluginReview]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
revenue_events: Mapped[list[RevenueEvent]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
run_logs: Mapped[list[AgentRunLog]] = relationship(
back_populates="cloud_agent",
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
foreign_keys="AgentRunLog.agent_id",
cascade="all, delete-orphan",
overlaps="run_logs,local_agent",
)
class PluginInstallation(Base):
__tablename__ = "plugin_installations"
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
class AgentRunLog(Base):
__tablename__ = "agent_run_logs"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
# Plain string — not a FK because it references either local_agent_configs or cloud_agent_configs
# depending on agent_type. Query by (agent_id, agent_type) to locate the source config.
agent_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
agent_type: Mapped[str] = mapped_column(AgentTypeEnum, nullable=False)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
installed_at: Mapped[datetime] = mapped_column(
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
started_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
plugin: Mapped[Plugin] = relationship(back_populates="installations")
class PluginReview(Base):
__tablename__ = "plugin_reviews"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
local_agent: Mapped[LocalAgentConfig | None] = relationship(
back_populates="run_logs",
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
foreign_keys="AgentRunLog.agent_id",
overlaps="run_logs,cloud_agent",
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
reviewer_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
reviewed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
cloud_agent: Mapped[CloudAgentConfig | None] = relationship(
back_populates="run_logs",
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
foreign_keys="AgentRunLog.agent_id",
overlaps="run_logs,local_agent",
)
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
# ── Memory models ─────────────────────────────────────────────────────────────
class RevenueEvent(Base):
__tablename__ = "revenue_events"
class MemoryCore(Base):
"""Per-user persistent key/value preferences, encrypted at rest.
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
Examples: preferred_language, timezone, work_style.
Decrypted in-memory only using User.encryption_key.
"""
__tablename__ = "memory_core"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
key: Mapped[str] = mapped_column(String(255), nullable=False)
value_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
class MemoryAssociative(Base):
"""Per-user semantic memory: encrypted content + pgvector embedding for similarity search.
Production: ``embedding`` column is ``vector(1536)`` via pgvector.
Tests (SQLite): stored as JSON list.
"""
__tablename__ = "memory_associative"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
# JSON-encoded float list in SQLite tests; vector(1536) in Postgres via migration.
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
class MemoryEpisodic(Base):
"""Per-user session summaries, encrypted at rest.
One row per session interaction; used to recall recent conversations.
"""
__tablename__ = "memory_episodic"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
summary_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
session_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
class MemoryProactive(Base):
"""Per-user inferred behavioral patterns, encrypted at rest.
Confidence in [0.0, 1.0]; only patterns above threshold are injected.
Source: 'inferred' (from episodes) or 'explicit' (user-stated).
"""
__tablename__ = "memory_proactive"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
pattern_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.5)
source: Mapped[str] = mapped_column(String(50), nullable=False, default="inferred")
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)

View File

@@ -5,6 +5,7 @@ Mirrors the TypeScript types from the Electron app (src/shared/api-types.ts).
from __future__ import annotations
from enum import Enum
from typing import Any, Literal
from pydantic import BaseModel, Field
@@ -26,6 +27,8 @@ class AuthTokens(BaseModel):
class UserProfile(BaseModel):
id: str
email: str
name: str | None = None
surname: str | None = None
tier: BillingTier
@@ -38,120 +41,220 @@ class ChatContext(BaseModel):
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
class PlanAction(BaseModel):
type: Literal[
"create_record",
"update_record",
"delete_record",
"index_document",
"send_notification",
]
table: str | None = None
data: dict[str, Any] | None = None
class ChatRequest(BaseModel):
message: str
context: ChatContext = Field(default_factory=ChatContext)
execution_mode: Literal["direct", "plan"] = "direct"
class ChatResponse(BaseModel):
response: str
actions: list[PlanAction] = Field(default_factory=list)
# ── Execution Plans ──────────────────────────────────────────────────
# ── WebSocket Frame Protocol ──────────────────────────────────────────
class PlanStep(BaseModel):
class WsFrameType(str, Enum):
# ── v2 frame types (kept for backward compat) ──────────────────────
chat_request = "chat_request"
text_chunk = "text_chunk"
tool_call = "tool_call"
tool_result = "tool_result"
final = "final"
ping = "ping"
device_hello = "device_hello"
# ── v3 frame types ─────────────────────────────────────────────────
home_request = "home_request"
floating_request = "floating_request"
stream_start = "stream_start"
stream_text = "stream_text"
stream_end = "stream_end"
floating_domain = "floating_domain"
data_request = "data_request"
data_response = "data_response"
mutation = "mutation"
# ── v4 journey frame types ────────────────────────────────────────
journey_start = "journey_start"
journey_message = "journey_message"
journey_reply = "journey_reply"
class WsToolCall(BaseModel):
"""Server → Client: requests a CRUD/vector operation on the local DB."""
type: Literal[WsFrameType.tool_call] = WsFrameType.tool_call
id: str
action: str
prompt_template: str | None = None
variables: dict[str, Any] | None = None
data_from_step: int | None = None
table: str | None = None
data: dict[str, Any] | None = None
filters: dict[str, Any] | None = None
vector: list[float] | None = None
limit: int | None = None
class ExecutionPlan(BaseModel):
agent: str
steps: list[PlanStep] = Field(default_factory=list)
class WsToolResult(BaseModel):
"""Client → Server: result of a CRUD/vector operation."""
# ── Backup ───────────────────────────────────────────────────────────
class BackupMetadata(BaseModel):
version: int
timestamp: int
checksum: str
chunk_count: int
# ── Cloud Storage (E2E encrypted blobs) ──────────────────────────────
class StorageRecord(BaseModel):
type: Literal[WsFrameType.tool_result] = WsFrameType.tool_result
id: str
user_id: str
table: str
blob: bytes
checksum: str
created_at: int
updated_at: int
row: dict[str, Any] | None = None
rows: list[dict[str, Any]] | None = None
results: list[dict[str, Any]] | None = None
deleted: bool | None = None
ok: bool | None = None
error: str | None = None
class StorageRecordCreate(BaseModel):
table: str
blob: bytes
checksum: str
class WsTextChunk(BaseModel):
"""Server → Client: incremental LLM response text."""
type: Literal[WsFrameType.text_chunk] = WsFrameType.text_chunk
text: str
class StorageRecordUpdate(BaseModel):
blob: bytes
checksum: str
class WsFinal(BaseModel):
"""Server → Client: signals end of response with the complete text."""
type: Literal[WsFrameType.final] = WsFrameType.final
response: str
# ── Cloud Vector Store (E2E encrypted vectors) ────────────────────────
# ── WebSocket Agent Frame Protocol ────────────────────────────────────
class WsDeviceHello(BaseModel):
"""Client → Server: device identification on WS connect."""
type: Literal[WsFrameType.device_hello] = WsFrameType.device_hello
device_id: str
agent_ids: list[str] = Field(default_factory=list)
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
class WsFloatingScope(BaseModel):
"""Scope for a floating request — narrows the agent to a specific entity."""
type: Literal["task", "project", "note", "timeline"]
id: str | None = None
class WsHomeRequest(BaseModel):
"""Client → Server: Home chat message."""
type: Literal[WsFrameType.home_request] = WsFrameType.home_request
message: str
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
class WsFloatingRequest(BaseModel):
"""Client → Server: Floating chat message scoped to an entity."""
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
message: str
scope: WsFloatingScope
class WsStreamStart(BaseModel):
"""Server → Client: signals start of a streaming response."""
type: Literal[WsFrameType.stream_start] = WsFrameType.stream_start
request_id: str
class WsStreamText(BaseModel):
"""Server → Client: streamed text token."""
type: Literal[WsFrameType.stream_text] = WsFrameType.stream_text
request_id: str
chunk: str
class WsStreamEnd(BaseModel):
"""Server → Client: signals end of a streaming response."""
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
request_id: str
class WsDomain(BaseModel):
"""Structured floating domain payload for UI routing decisions."""
type: Literal["task", "timeline", "project", "node"]
id: str | None = None
section: Literal["task", "timeline", "note"] | None = None
class WsFloatingDomain(BaseModel):
"""Server → Client: domain determined for a floating request."""
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
request_id: str
domain: WsDomain
# ── Agent Config V2 ───────────────────────────────────────────────────
class ContentTypeConfig(BaseModel):
"""Per-type extraction config produced by the journey chatbot."""
class VectorItem(BaseModel):
id: str
blob: bytes # encrypted vector + metadata — backend never decrypts
checksum: str
label: str = ""
detection_hint: str = ""
preprocessing: str = "generic" # handler name: "email_html", "plain_text", ...
extraction_prompt: str
class VectorUpsertRequest(BaseModel):
vectors: list[VectorItem]
class AgentConfig(BaseModel):
"""Structured agent configuration (replaces freeform prompt_template)."""
content_types: list[ContentTypeConfig] = []
global_rules: list[str] = []
data_types: list[str] = []
class VectorSearchRequest(BaseModel):
query_blob: bytes # encrypted query — backend never decrypts
top_k: int = 10
# ── Agent Catalog ─────────────────────────────────────────────────────
class VectorSearchResult(BaseModel):
id: str
score: float
blob: bytes
class VectorSearchResponse(BaseModel):
results: list[VectorSearchResult]
# ── Plugin Marketplace ────────────────────────────────────────────────
class PluginManifest(BaseModel):
id: str
class AgentCatalogItem(BaseModel):
type: str
name: str
description: str
version: str
author: str
permissions: list[str]
category: str
price_cents: int = 0
class PluginListResponse(BaseModel):
plugins: list[PluginManifest]
total: int
page: int
class AgentCreationCheckRequest(BaseModel):
active_agents: int = Field(ge=0, default=0)
class PluginInstallRequest(BaseModel):
plugin_id: str
class AgentCreationCheckResponse(BaseModel):
allowed: bool
tier: BillingTier
active_agents: int
limit: int
class AgentTriggerRequest(BaseModel):
directory: str = Field(min_length=1)
device_id: str = Field(default="")
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
what_to_extract: list[str] = Field(min_length=1)
actions_by_type: dict[str, list[str]] | None = None
batch_interval: str = Field(min_length=1)
custom_agent_prompt: str = Field(min_length=1)
active_agents: int = Field(ge=0, default=0)
# ── Agent Run Log ─────────────────────────────────────────────────────
class AgentRunLogResponse(BaseModel):
id: str
agent_id: str
agent_type: Literal["local", "cloud"]
status: Literal["running", "success", "error", "partial"]
items_processed: int
items_created: int
errors: list[str]
started_at: int
completed_at: int | None
# ── Chatbot Journey ───────────────────────────────────────────────────

View File

@@ -1 +0,0 @@
"""Cloud storage layer — E2E encrypted blobs and vectors."""

View File

@@ -1,106 +0,0 @@
"""S3-backed store for E2E-encrypted blobs.
Keys are structured as ``{user_id}/{table}/{record_id}``.
The backend never inspects blob content — it stores and retrieves opaque bytes.
"""
from __future__ import annotations
from typing import Any
import boto3
from app.config.settings import settings
class BlobStore:
"""Thin wrapper around boto3 S3.
All blobs must be E2E encrypted by the client before upload.
The backend adds SSE-S3 as an extra layer of at-rest encryption
but cannot decrypt the inner client-side payload.
"""
def _client(self) -> Any:
kwargs: dict[str, Any] = {
"region_name": settings.S3_REGION,
"aws_access_key_id": settings.AWS_ACCESS_KEY_ID,
"aws_secret_access_key": settings.AWS_SECRET_ACCESS_KEY,
}
if settings.S3_ENDPOINT_URL and isinstance(settings.S3_ENDPOINT_URL, str):
kwargs["endpoint_url"] = settings.S3_ENDPOINT_URL
return boto3.client("s3", **kwargs)
@staticmethod
def _key(user_id: str, table: str, record_id: str) -> str:
return f"{user_id}/{table}/{record_id}"
async def upload(
self,
user_id: str,
table: str,
record_id: str,
blob: bytes,
checksum: str,
) -> str:
"""Store *blob* in S3 and return the S3 key.
Args:
user_id: Owner of the blob (used as key prefix).
table: Logical table name (e.g. ``"tasks"``).
record_id: Record UUID.
blob: Raw bytes (pre-encrypted by client).
checksum: SHA-256 hex digest supplied by the client; stored as
object metadata for download-time verification.
Returns:
The S3 key under which the blob was stored.
"""
key = self._key(user_id, table, record_id)
self._client().put_object(
Bucket=settings.S3_BUCKET,
Key=key,
Body=blob,
ServerSideEncryption="AES256", # SSE-S3 at rest
Metadata={"checksum": checksum},
)
return key
async def download(self, user_id: str, s3_key: str) -> bytes:
"""Retrieve the blob stored at *s3_key*.
*user_id* is retained in the signature so higher-level code can
enforce ownership without re-parsing the key.
Raises:
``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the
object does not exist.
"""
response = self._client().get_object(
Bucket=settings.S3_BUCKET,
Key=s3_key,
)
return response["Body"].read()
async def delete(self, user_id: str, s3_key: str) -> None:
"""Delete the object at *s3_key*.
S3 ``delete_object`` is idempotent — it succeeds even if the key does
not exist.
"""
self._client().delete_object(
Bucket=settings.S3_BUCKET,
Key=s3_key,
)
async def list_keys(self, user_id: str, table: str) -> list[str]:
"""Return all S3 keys for a given user + table combination.
Uses the prefix ``{user_id}/{table}/`` to scope the listing.
"""
prefix = f"{user_id}/{table}/"
response = self._client().list_objects_v2(
Bucket=settings.S3_BUCKET,
Prefix=prefix,
)
return [obj["Key"] for obj in response.get("Contents", [])]

View File

@@ -1,32 +0,0 @@
"""Integrity verification only — the backend NEVER decrypts user data."""
from __future__ import annotations
import hashlib
import hmac
from fastapi import HTTPException
def verify_checksum(blob: bytes, checksum: str) -> bool:
"""Return ``True`` if SHA-256(blob) matches *checksum*.
Uses ``hmac.compare_digest`` for constant-time comparison to prevent
timing-based side-channel attacks.
"""
computed = hashlib.sha256(blob).hexdigest()
return hmac.compare_digest(computed, checksum)
def reject_if_tampered(blob: bytes, checksum: str) -> None:
"""Raise ``HTTP 400`` if the blob does not match its checksum.
Call this before storing or forwarding any client-provided blob.
The backend never holds decryption keys — this check only verifies
that the opaque bytes arrived intact.
"""
if not verify_checksum(blob, checksum):
raise HTTPException(
status_code=400,
detail="Checksum mismatch: blob integrity check failed",
)

View File

@@ -1,205 +0,0 @@
"""Cloud vector store — wraps Pinecone (default) or Qdrant.
Vectors are pre-encrypted blobs from the client. The backend stores them
alongside a deterministic 32-dim float representation derived from the blob's
SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this
is a known trade-off documented in the backend plan.
Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by
``user_id`` payload field on a shared collection.
"""
from __future__ import annotations
import base64
import hashlib
from typing import Any
from pinecone import Pinecone
from qdrant_client import QdrantClient
from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct
from app.config.settings import settings
from app.schemas import VectorItem, VectorSearchResult
_QDRANT_COLLECTION = "adiuva_vectors"
def _blob_to_vector(blob: bytes) -> list[float]:
"""Derive a 32-dim float vector from *blob* for storage purposes only.
Uses SHA-256 to produce a deterministic 32-byte fingerprint, then
normalises each byte to the range [-1.0, 1.0]. This vector carries no
semantic meaning on encrypted data.
"""
return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()]
class VectorStore:
"""Thin wrapper around Pinecone or Qdrant.
The backend to use is selected at runtime:
- Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty.
- Qdrant: otherwise (requires ``settings.QDRANT_URL``).
"""
def _use_pinecone(self) -> bool:
return bool(settings.PINECONE_API_KEY)
# ── Pinecone helpers ──────────────────────────────────────────────
def _pinecone_index(self) -> Any:
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
return pc.Index(settings.PINECONE_INDEX)
# ── Qdrant helpers ────────────────────────────────────────────────
def _qdrant_client(self) -> Any:
return QdrantClient(
url=settings.QDRANT_URL,
api_key=settings.QDRANT_API_KEY or None,
)
# ── Public API ────────────────────────────────────────────────────
async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
"""Store encrypted vectors in the backend.
Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload
so it can be returned verbatim during search.
Args:
user_id: Used as Pinecone namespace or Qdrant payload field.
vectors: List of encrypted vector items from the client.
"""
if self._use_pinecone():
await self._pinecone_upsert(user_id, vectors)
else:
await self._qdrant_upsert(user_id, vectors)
async def search(
self,
user_id: str,
query_blob: bytes,
top_k: int,
) -> list[VectorSearchResult]:
"""Query the vector store and return encrypted result blobs.
The query vector is derived from *query_blob* using the same
deterministic mapping as upsert.
Args:
user_id: Scopes the search to this user's namespace.
query_blob: Encrypted query from the client.
top_k: Maximum number of results to return.
Returns:
List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``.
"""
if self._use_pinecone():
return await self._pinecone_search(user_id, query_blob, top_k)
return await self._qdrant_search(user_id, query_blob, top_k)
async def delete(self, user_id: str, vector_ids: list[str]) -> None:
"""Remove vectors by ID, scoped to *user_id*.
Args:
user_id: Namespace / payload filter to prevent cross-user deletion.
vector_ids: List of vector IDs to remove.
"""
if self._use_pinecone():
await self._pinecone_delete(user_id, vector_ids)
else:
await self._qdrant_delete(user_id, vector_ids)
# ── Pinecone implementation ───────────────────────────────────────
async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
index = self._pinecone_index()
records = [
{
"id": v.id,
"values": _blob_to_vector(v.blob),
"metadata": {
"blob": base64.b64encode(v.blob).decode(),
"checksum": v.checksum,
"user_id": user_id,
},
}
for v in vectors
]
index.upsert(vectors=records, namespace=user_id)
async def _pinecone_search(
self, user_id: str, query_blob: bytes, top_k: int
) -> list[VectorSearchResult]:
index = self._pinecone_index()
query_vector = _blob_to_vector(query_blob)
response = index.query(
vector=query_vector,
top_k=top_k,
namespace=user_id,
include_metadata=True,
)
results: list[VectorSearchResult] = []
for match in response.get("matches", []):
blob_bytes = base64.b64decode(match["metadata"]["blob"])
results.append(
VectorSearchResult(
id=match["id"],
score=match["score"],
blob=blob_bytes,
)
)
return results
async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None:
index = self._pinecone_index()
index.delete(ids=vector_ids, namespace=user_id)
# ── Qdrant implementation ─────────────────────────────────────────
async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
client = self._qdrant_client()
points = [
PointStruct(
id=v.id,
vector=_blob_to_vector(v.blob),
payload={
"blob": base64.b64encode(v.blob).decode(),
"checksum": v.checksum,
"user_id": user_id,
},
)
for v in vectors
]
client.upsert(collection_name=_QDRANT_COLLECTION, points=points)
async def _qdrant_search(
self, user_id: str, query_blob: bytes, top_k: int
) -> list[VectorSearchResult]:
client = self._qdrant_client()
query_vector = _blob_to_vector(query_blob)
hits = client.search(
collection_name=_QDRANT_COLLECTION,
query_vector=query_vector,
query_filter=Filter(
must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
),
limit=top_k,
)
return [
VectorSearchResult(
id=str(hit.id),
score=hit.score,
blob=base64.b64decode(hit.payload["blob"]),
)
for hit in hits
]
async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None:
client = self._qdrant_client()
client.delete(
collection_name=_QDRANT_COLLECTION,
points_selector=PointIdsList(points=vector_ids),
)

View File

@@ -7,18 +7,21 @@ services:
- path: .env
required: false
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuvai
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
volumes:
- copilot_tokens:/root/.config/litellm/github_copilot
depends_on:
db:
condition: service_healthy
restart: unless-stopped
db:
image: postgres:16-alpine
image: pgvector/pgvector:pg16
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: adiuva
POSTGRES_DB: adiuvai
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
@@ -33,36 +36,6 @@ services:
# image: redis:7-alpine
# restart: unless-stopped
# ── Local S3-compatible storage (MinIO) ──
minio:
image: minio/minio:latest
command: server /data --console-address ":9001"
ports:
- "9000:9000"
- "9001:9001"
environment:
MINIO_ROOT_USER: minioadmin
MINIO_ROOT_PASSWORD: minioadmin
volumes:
- minio_data:/data
healthcheck:
test: ["CMD", "mc", "ready", "local"]
interval: 5s
timeout: 5s
retries: 5
restart: unless-stopped
# ── Local vector store (Qdrant) ──
qdrant:
image: qdrant/qdrant:latest
ports:
- "6333:6333"
- "6334:6334"
volumes:
- qdrant_data:/qdrant/storage
restart: unless-stopped
volumes:
postgres_data:
minio_data:
qdrant_data:
copilot_tokens:

56
logging.conf Normal file
View File

@@ -0,0 +1,56 @@
[loggers]
keys=root,uvicorn,uvicorn.error,uvicorn.access,sqlalchemy,watchfiles
[handlers]
keys=console,file
[formatters]
keys=default
[logger_root]
level=INFO
handlers=console,file
[logger_uvicorn]
level=INFO
handlers=
qualname=uvicorn
propagate=1
[logger_uvicorn.error]
level=INFO
handlers=
qualname=uvicorn.error
propagate=1
[logger_uvicorn.access]
level=INFO
handlers=
qualname=uvicorn.access
propagate=1
[logger_sqlalchemy]
level=WARNING
handlers=
qualname=sqlalchemy
propagate=1
[logger_watchfiles]
level=WARNING
handlers=
qualname=watchfiles
propagate=1
[handler_console]
class=StreamHandler
formatter=default
args=(sys.stderr,)
[handler_file]
class=logging.handlers.RotatingFileHandler
formatter=default
args=('logs/app.log', 'a', 10485760, 5, 'utf-8')
[formatter_default]
format=%(asctime)s %(levelname)s %(name)s: %(message)s
datefmt=%Y-%m-%d %H:%M:%S

View File

@@ -3,6 +3,7 @@ uvicorn[standard]>=0.34.0
gunicorn>=22.0.0
langchain>=0.3.0
langchain-openai>=0.3.0
langchain-litellm>=0.1.0
litellm>=1.50.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
@@ -24,4 +25,15 @@ aiosqlite>=0.20.0
moto[s3]>=5.0.0
pinecone>=5.0.0
qdrant-client>=1.7.0
croniter>=3.0.0
google-api-python-client>=2.130.0
google-auth>=2.29.0
google-auth-oauthlib>=1.2.0
google-auth-httplib2>=0.2.0
msal>=1.28.0
cryptography>=42.0.0
langfuse>=2.0.0
beautifulsoup4>=4.12.0
lxml>=5.0.0
PyYAML>=6.0.0
ruff>=0.8.0

View File

@@ -6,26 +6,21 @@ a per-test session, and a FastAPI ``TestClient`` wired to use it.
from __future__ import annotations
import json
import os
import time
import uuid
from collections.abc import AsyncGenerator, Generator
from unittest.mock import patch
import boto3
import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from jose import jwt
from moto import mock_aws
from sqlalchemy import StaticPool, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.config.settings import settings
from app.db import Base, get_session
from app.main import app
from app.models import Plugin, Subscription, User
from app.models import Subscription, User
# ── Fixed test user IDs (one per tier) ───────────────────────────────
@@ -109,79 +104,6 @@ def client(db_session: AsyncSession) -> Generator[TestClient, None, None]: # n
app.dependency_overrides.pop(get_session, None)
# ── Seed data helpers ────────────────────────────────────────────────
_SEED_PLUGINS = [
Plugin(
id="plugin-github-sync",
name="GitHub Sync",
description="Sync tasks with GitHub Issues and pull requests.",
version="1.0.0",
author_name="Adiuva",
category="productivity",
price_cents=0,
permissions=json.dumps(["read:tasks", "write:tasks"]),
status="approved",
s3_package_key="plugins/plugin-github-sync/1.0.0/package.zip",
install_count=0,
avg_rating=0.0,
),
Plugin(
id="plugin-slack-notify",
name="Slack Notifier",
description="Post task and checkpoint updates to Slack channels.",
version="1.2.0",
author_name="Adiuva",
category="communication",
price_cents=499,
permissions=json.dumps(["read:tasks", "read:checkpoints"]),
status="approved",
s3_package_key="plugins/plugin-slack-notify/1.2.0/package.zip",
install_count=0,
avg_rating=0.0,
),
Plugin(
id="plugin-time-tracker",
name="Time Tracker",
description="Track time spent on tasks with automatic reporting.",
version="0.9.1",
author_name="Third Party",
category="productivity",
price_cents=999,
permissions=json.dumps(["read:tasks", "write:tasks"]),
status="approved",
s3_package_key="plugins/plugin-time-tracker/0.9.1/package.zip",
install_count=0,
avg_rating=0.0,
),
]
@pytest_asyncio.fixture
async def seed_plugins(db_session: AsyncSession) -> list[Plugin]:
"""Insert the 3 default approved plugins and return them."""
plugins = []
for template in _SEED_PLUGINS:
p = Plugin(
id=template.id,
name=template.name,
description=template.description,
version=template.version,
author_name=template.author_name,
category=template.category,
price_cents=template.price_cents,
permissions=template.permissions,
status=template.status,
s3_package_key=template.s3_package_key,
install_count=template.install_count,
avg_rating=template.avg_rating,
)
db_session.add(p)
plugins.append(p)
await db_session.commit()
return plugins
# ── JWT helpers ──────────────────────────────────────────────────────
@@ -212,24 +134,21 @@ def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, st
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
# ── S3 mock fixture ──────────────────────────────────────────────────
# ── CLI options ───────────────────────────────────────────────────────
S3_TEST_BUCKET = "test-bucket"
S3_TEST_REGION = "us-east-1"
@pytest.fixture
def s3_bucket():
"""Create a mocked S3 bucket via moto and patch BlobStore settings."""
with mock_aws():
os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing")
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing")
os.environ.setdefault("AWS_DEFAULT_REGION", S3_TEST_REGION)
client = boto3.client("s3", region_name=S3_TEST_REGION)
client.create_bucket(Bucket=S3_TEST_BUCKET)
with patch("app.storage.blob_store.settings") as mock_settings:
mock_settings.S3_BUCKET = S3_TEST_BUCKET
mock_settings.S3_REGION = S3_TEST_REGION
mock_settings.AWS_ACCESS_KEY_ID = "testing"
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
yield S3_TEST_BUCKET
def pytest_addoption(parser):
parser.addoption(
"--preprocess-dir",
default=None,
help="Override fixture folder for preprocessor tests (must contain cases.yaml + data/)",
)
parser.addoption(
"--runner-dir",
default=None,
help="Override fixture folder for agent_runner_v2 eval tests (must contain cases.yaml + data/)",
)
parser.addoption(
"--journey-dir",
default=None,
help="Override fixture folder for journey_v2 eval tests (must contain cases.yaml + data/)",
)

View File

@@ -0,0 +1,86 @@
# Agent Runner V2 — eval test cases (Step 2, requires real LLM)
#
# Each case drives one parametrized `test_eval_runner` invocation.
#
# Keys
# ----
# id: str unique identifier shown in pytest output
# description: str human-readable label
# file: str filename inside data/
# file_path: str path reported to the executor (affects project-matching via filename)
# projects: [alpha|beta] symbolic project names resolved by the test helper
#
# Optional pre-existing records (dedup tests)
# existing_tasks: list of {id, title, status, priority}
# existing_notes: list of {id, title, content}
# existing_timelines: list of {id, title, date}
#
# Assertions (one or more)
# expect_insert: <table> at least 1 insert row in this table (tasks|notes|timelines)
# expect_no_insert: true zero inserts in any table
# expect_project_id: <id> any insert must carry this projectId
# expect_dedup: true task inserts == 0 OR task updates >= 1 (dedup check)
#
# Langfuse
# score_name: str observation score name
- id: "2.1"
description: "Action email → create_task"
file: email_action.html
file_path: /emails/ProjectAlpha_action.html
projects: [alpha, beta]
expect_insert: tasks
score_name: runner.email_to_task
- id: "2.2"
description: "Informational email → create_note"
file: email_info.html
file_path: /emails/ProjectAlpha_info.html
projects: [alpha, beta]
expect_insert: notes
score_name: runner.email_to_note
- id: "2.3"
description: "Email with meeting date → create_timeline"
file: email_date.html
file_path: /emails/ProjectAlpha_kickoff.html
projects: [alpha, beta]
expect_insert: timelines
score_name: runner.email_to_timeline
- id: "2.4"
description: "Filename contains project name → correct project assigned"
file: email_action.html
file_path: /emails/ProjectAlpha_report.html
projects: [alpha, beta]
expect_project_id: proj-alpha
score_name: runner.project_filename
- id: "2.5"
description: "Email body mentions project → correct project assigned"
file: email_action.html
file_path: /emails/email_001.html
projects: [alpha, beta]
expect_project_id: proj-alpha
score_name: runner.project_content
- id: "2.6"
description: "Newsletter + global rule no-project → no creates"
file: email_no_project.html
file_path: /emails/newsletter.html
projects: [alpha, beta]
expect_no_insert: true
score_name: runner.no_project
- id: "2.7"
description: "Existing task with same title → dedup (update not create)"
file: email_action.html
file_path: /emails/ProjectAlpha_followup.html
projects: [alpha]
existing_tasks:
- id: task-existing
title: Fix the login bug
status: todo
priority: medium
expect_dedup: true
score_name: runner.dedup

View File

@@ -0,0 +1,7 @@
<html><head></head><body>
<p><b>From:</b> boss@company.com</p>
<p><b>To:</b> dev@company.com</p>
<p><b>Subject:</b> Fix the login bug</p>
<p><b>Date:</b> 2026-04-07</p>
<p>Hi,<br>Please fix the login bug in Project Alpha by Friday. High priority!</p>
</body></html>

View File

@@ -0,0 +1,5 @@
<html><head></head><body>
<p><b>From:</b> pm@company.com</p>
<p><b>Subject:</b> Project Alpha kick-off meeting</p>
<p>The kick-off meeting for Project Alpha is scheduled for 2026-04-15 at 10:00.</p>
</body></html>

View File

@@ -0,0 +1,7 @@
<html><head></head><body>
<p><b>From:</b> pm@company.com</p>
<p><b>To:</b> team@company.com</p>
<p><b>Subject:</b> FYI: New policy for Project Alpha</p>
<p>Just a heads-up that starting next week all code reviews must be done
within 24 hours for Project Alpha. No action needed from you now.</p>
</body></html>

View File

@@ -0,0 +1,5 @@
<html><head></head><body>
<p><b>From:</b> newsletter@ads.com</p>
<p><b>Subject:</b> Weekly newsletter</p>
<p>Check out our latest deals on electronics!</p>
</body></html>

19
tests/fixtures/journey_v2/cases.yaml vendored Normal file
View File

@@ -0,0 +1,19 @@
# Journey V2 eval test cases — Step 4
#
# Only case 4.1 is kept as an automated eval. Cases 4.24.5 (multi-turn
# conversations that expect the LLM to produce a complete AgentConfig)
# are non-deterministic and tested manually — results tracked in Langfuse.
#
# Assertion keys:
# expect_question: true → first reply must contain "?"
- id: "4.1"
description: "Journey start explores directory, first reply contains a question"
directory: "/test/emails"
data_types: ["tasks", "notes", "timelines"]
directory_files:
- path: "/test/emails/outlook_export_2024.html"
content_file: "email_action.html"
user_messages: []
score_name: "journey.start"
expect_question: true

View File

@@ -0,0 +1,23 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Email: Fix the login bug</title>
<style>body { font-family: Arial; } .header { color: #666; }</style>
</head>
<body>
<div class="header">
<p><strong>From:</strong> boss@company.com</p>
<p><strong>To:</strong> dev@company.com</p>
<p><strong>Subject:</strong> Fix the login bug</p>
<p><strong>Date:</strong> Mon, 7 Apr 2026 09:15:00 +0000</p>
</div>
<div class="body">
<p>Hi,</p>
<p>Please fix the login bug in Project Alpha as soon as possible.
Users are reporting that they can't log in with their Google accounts.
This is blocking the whole team. Please resolve it by Friday.</p>
<p>Thanks,<br>Boss</p>
</div>
</body>
</html>

View File

@@ -0,0 +1,23 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Email: New policy update</title>
<style>body { font-family: Arial; }</style>
</head>
<body>
<div class="header">
<p><strong>From:</strong> hr@company.com</p>
<p><strong>To:</strong> all@company.com</p>
<p><strong>Subject:</strong> FYI: New remote work policy effective May 1</p>
<p><strong>Date:</strong> Tue, 8 Apr 2026 10:00:00 +0000</p>
</div>
<div class="body">
<p>Hi everyone,</p>
<p>Just a heads-up that starting May 1, 2026 the company will be moving to
a hybrid work model. You will be expected to come into the office at least
two days per week. More details will follow in the employee handbook.</p>
<p>Best,<br>HR Team</p>
</div>
</body>
</html>

68
tests/fixtures/preprocessors/cases.yaml vendored Normal file
View File

@@ -0,0 +1,68 @@
# Preprocessor test cases
#
# detect: <expected_type> → chiama detect_content_type(filename, content)
# process: <content_type> → chiama preprocess(content_type, content)
#
# Sorgente: file: <nome in data/> oppure generate: binary_noise
#
# Assertions piatte (solo per process):
# no_html: true clean_text senza tag HTML
# min_chars: N len(clean_text) >= N
# ratio_lt: F len(clean) / len(raw) < F
# has_meta: [k, ...] chiavi presenti in metadata
# contains: str | [str] substring(s) presenti in clean_text
# excludes: str | [str] substring(s) assenti da clean_text
# content_type: str result.content_type == questo valore
- id: "1.1"
file: email_action.html
detect: email_html
- id: "1.2"
file: generic_page.html
detect: generic_html
- id: "1.3"
file: notes.txt
detect: plain_text
- id: "1.4"
file: archive.xyz
generate: binary_noise
detect: unknown
- id: "1.5"
file: email_action.html
process: email_html
no_html: true
min_chars: 50
ratio_lt: 0.8
- id: "1.6"
file: email_action.html
process: email_html
has_meta: [subject, from]
- id: "1.7"
file: email_thread.html
process: email_html
contains: "Sure, I'll handle the deploy"
excludes: "Let's plan the deploy"
- id: "1.8"
file: email_single.html
process: email_html
contains: "deploy is done"
- id: "1.9"
file: email_heavy.html
process: email_html
no_html: true
min_chars: 30
excludes: [border-collapse, font-size]
- id: "1.10"
file: fallback.txt
process: unknown
min_chars: 1
content_type: unknown

View File

@@ -0,0 +1,25 @@
<!DOCTYPE html>
<html>
<head>
<title>Fix the login bug</title>
<style>
body { font-family: Arial, sans-serif; color: #333; margin: 0; padding: 20px; }
.header { background: #f5f5f5; padding: 10px; border-bottom: 1px solid #ddd; }
.body { padding: 20px; }
</style>
</head>
<body>
<div class="header">
<p><strong>From:</strong> boss@company.com</p>
<p><strong>To:</strong> dev@company.com</p>
<p><strong>Subject:</strong> Fix the login bug</p>
<p><strong>Date:</strong> Mon, 7 Apr 2026 09:00:00 +0200</p>
</div>
<div class="body">
<p>Hi,</p>
<p>Please fix the login bug by Friday. It is blocking the release.</p>
<p>Priority: high. Let me know if you need anything.</p>
<p>Thanks,<br>Boss</p>
</div>
</body>
</html>

View File

@@ -0,0 +1,49 @@
<!DOCTYPE html>
<html>
<head>
<style>
table { border-collapse: collapse; width: 100%; max-width: 600px; margin: 0 auto; }
td { padding: 8px 12px; border: 1px solid #dddddd; font-size: 12px; color: #444444; }
.header-row { background-color: #003366; color: #ffffff; font-weight: bold; }
.label-col { background-color: #f0f0f0; width: 80px; font-weight: bold; }
.footer-row { font-size: 10px; color: #999999; text-align: center; }
</style>
</head>
<body bgcolor="#eeeeee">
<center>
<table cellpadding="0" cellspacing="0">
<tr class="header-row">
<td colspan="2">Company Internal Update</td>
</tr>
<tr>
<td class="label-col">From:</td>
<td>newsletter@corp.com</td>
</tr>
<tr>
<td class="label-col">Subject:</td>
<td>Q1 Results Update</td>
</tr>
<tr>
<td class="label-col">Date:</td>
<td>Apr 7, 2026</td>
</tr>
<tr>
<td colspan="2">
<table width="100%" cellpadding="10">
<tr>
<td>
<p style="font-size:14px; font-weight:bold;">Dear Team,</p>
<p>Q1 results are in. Revenue up 15% year-over-year.</p>
<p>Please review the attached report and share any feedback by EOW.</p>
</td>
</tr>
</table>
</td>
</tr>
<tr class="footer-row">
<td colspan="2">Confidential — do not forward outside the company.</td>
</tr>
</table>
</center>
</body>
</html>

View File

@@ -0,0 +1,8 @@
<!DOCTYPE html>
<html><body>
<p><strong>From:</strong> alice@co.com</p>
<p><strong>To:</strong> team@co.com</p>
<p><strong>Subject:</strong> Quick update</p>
<p><strong>Date:</strong> Tue, 7 Apr 2026 10:30:00 +0200</p>
<p>The deploy is done. Everything looks good. No issues so far.</p>
</body></html>

View File

@@ -0,0 +1,24 @@
<!DOCTYPE html>
<html><body>
<div class="message-latest">
<p><strong>From:</strong> alice@co.com</p>
<p><strong>Subject:</strong> Re: Re: Deploy plan</p>
<p>Sure, I'll handle the deploy.</p>
</div>
<p>On Mon, Apr 6, 2026 at 3:00 PM, Bob &lt;bob@co.com&gt; wrote:</p>
<blockquote>
<p>From: bob@co.com</p>
<p>Can you handle the deploy?</p>
<p>On Sun, Apr 5, 2026 at 1:00 PM, Alice &lt;alice@co.com&gt; wrote:</p>
<blockquote>
<p>From: alice@co.com</p>
<p>Let's plan the deploy for Monday.</p>
<p>On Sat, Apr 4, 2026 at 11:00 AM, Charlie &lt;charlie@co.com&gt; wrote:</p>
<blockquote>
<p>From: charlie@co.com</p>
<p>We need to schedule the deploy. What day works?</p>
</blockquote>
</blockquote>
</blockquote>
</body></html>

View File

@@ -0,0 +1,3 @@
random text content without any structure
line two with some words
line three and more content here

View File

@@ -0,0 +1,35 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>My Web App</title>
<link rel="stylesheet" href="styles.css">
</head>
<body>
<nav>
<a href="/">Home</a>
<a href="/about">About</a>
<a href="/contact">Contact</a>
</nav>
<main>
<header>
<h1>Welcome to My App</h1>
</header>
<article>
<p>This is a generic web page with no email headers.</p>
<p>It has navigation, main content, and a footer.</p>
</article>
<section>
<h2>Features</h2>
<ul>
<li>Fast</li>
<li>Reliable</li>
<li>Secure</li>
</ul>
</section>
</main>
<footer>
<p>&copy; 2026 My App</p>
</footer>
</body>
</html>

View File

@@ -0,0 +1,15 @@
Meeting notes - April 7, 2026
Attendees: Alice, Bob, Charlie
Discussion points:
- Deploy scheduled for Friday
- Bug fix for login must be completed by Thursday
- Review Q1 numbers before EOW
Action items:
- Alice: fix login bug
- Bob: prepare deploy checklist
- Charlie: send Q1 report
Next meeting: April 14, 2026

View File

@@ -1,214 +0,0 @@
"""Unit tests for the agent registry, base classes, and tool loop."""
from __future__ import annotations
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.core.agent_registry import AgentRegistry, ChatAgent
# ── Helpers ──────────────────────────────────────────────────────────
class _StubAgent(ChatAgent):
"""Minimal concrete agent for testing."""
def get_name(self) -> str:
return "stub"
def get_description(self) -> str:
return "A stub agent for tests"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return f"echo: {query}"
class _AnotherAgent(ChatAgent):
def get_name(self) -> str:
return "another"
def get_description(self) -> str:
return "Another stub"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return "another"
# ── Fixtures ─────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _fresh_registry():
"""Reset the singleton between tests."""
AgentRegistry._instance = None
yield
AgentRegistry._instance = None
@pytest.fixture()
def reg() -> AgentRegistry:
return AgentRegistry()
# ── Tests ────────────────────────────────────────────────────────────
class TestRegisterAndGet:
def test_register_decorator(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
agent = reg.get("stub")
assert isinstance(agent, _StubAgent)
def test_get_unknown_raises(self, reg: AgentRegistry) -> None:
with pytest.raises(KeyError, match="not found"):
reg.get("nonexistent")
def test_register_multiple(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
reg.register(_AnotherAgent)
assert reg.get("stub").get_name() == "stub"
assert reg.get("another").get_name() == "another"
class TestListAgents:
def test_empty(self, reg: AgentRegistry) -> None:
assert reg.list_agents() == []
def test_list_after_register(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
agents = reg.list_agents()
assert len(agents) == 1
assert agents[0] == {"name": "stub", "description": "A stub agent for tests"}
def test_list_multiple(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
reg.register(_AnotherAgent)
names = {a["name"] for a in reg.list_agents()}
assert names == {"stub", "another"}
class TestCallAgent:
@pytest.mark.asyncio
async def test_call_agent(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
result = await reg.call_agent("stub", "hello", {})
assert result == "echo: hello"
@pytest.mark.asyncio
async def test_call_unknown_raises(self, reg: AgentRegistry) -> None:
with pytest.raises(KeyError):
await reg.call_agent("nope", "hi", {})
class TestSingleton:
def test_singleton_identity(self) -> None:
a = AgentRegistry()
b = AgentRegistry()
assert a is b
class TestToolLoop:
@pytest.mark.asyncio
async def test_no_tool_calls(self) -> None:
"""When the LLM responds without tool calls, return content directly."""
agent = _StubAgent()
ai_msg = MagicMock()
ai_msg.content = "final answer"
ai_msg.tool_calls = []
llm = AsyncMock()
llm.bind_tools = MagicMock(return_value=llm)
llm.ainvoke = AsyncMock(return_value=ai_msg)
result = await agent._tool_loop(llm, [], [])
assert result == "final answer"
@pytest.mark.asyncio
async def test_tool_call_then_answer(self) -> None:
"""LLM requests one tool call, gets result, then answers."""
agent = _StubAgent()
# First response: tool call
tool_call_msg = MagicMock()
tool_call_msg.content = ""
tool_call_msg.tool_calls = [
{"id": "call_1", "name": "my_tool", "args": {"x": 1}}
]
# Second response: final answer
final_msg = MagicMock()
final_msg.content = "done"
final_msg.tool_calls = []
llm = AsyncMock()
llm.bind_tools = MagicMock(return_value=llm)
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
# Mock tool
tool = AsyncMock()
tool.name = "my_tool"
tool.ainvoke = AsyncMock(return_value="tool_result")
result = await agent._tool_loop(llm, [], [tool])
assert result == "done"
tool.ainvoke.assert_called_once_with({"x": 1})
@pytest.mark.asyncio
async def test_unknown_tool_handled(self) -> None:
"""Unknown tool names produce an error message instead of crashing."""
agent = _StubAgent()
tool_call_msg = MagicMock()
tool_call_msg.content = ""
tool_call_msg.tool_calls = [
{"id": "call_1", "name": "missing", "args": {}}
]
final_msg = MagicMock()
final_msg.content = "recovered"
final_msg.tool_calls = []
llm = AsyncMock()
llm.bind_tools = MagicMock(return_value=llm)
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
result = await agent._tool_loop(llm, [], [])
assert result == "recovered"
@pytest.mark.asyncio
async def test_max_iter_reached(self) -> None:
"""When max iterations are exhausted, a final no-tools call is made."""
agent = _StubAgent()
# Every response requests a tool call
loop_msg = MagicMock()
loop_msg.content = ""
loop_msg.tool_calls = [
{"id": "call_x", "name": "t", "args": {}}
]
final_msg = MagicMock()
final_msg.content = "gave up"
final_msg.tool_calls = []
tool = AsyncMock()
tool.name = "t"
tool.ainvoke = AsyncMock(return_value="ok")
llm_with_tools = AsyncMock()
llm_with_tools.ainvoke = AsyncMock(return_value=loop_msg)
llm = AsyncMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm.ainvoke = AsyncMock(return_value=final_msg)
result = await agent._tool_loop(llm, [], [tool], max_iter=2)
assert result == "gave up"
assert llm_with_tools.ainvoke.call_count == 2

810
tests/test_agent_runner.py Normal file
View File

@@ -0,0 +1,810 @@
"""Tests for Step 3.4: agent_runner module.
Coverage:
Unit:
- _is_overdue — cron schedule overdue detection
- _extract_items_from_content — LLM extraction + JSON parsing + validation
- _send_insert_to_client — tool_call frame construction + timeout
- run_local_agent — end-to-end local agent happy path
- run_local_agent — device offline path
- run_local_agent — file-read timeout path
- run_local_agent — LLM extraction error path
- run_cloud_agent — stub returns error immediately
- trigger_pending_runs — skipped when config is client-owned
- trigger_pending_runs — non-overdue skipped
- trigger_pending_runs — device_id filter for local agents
Integration:
- POST /agents/can-create — billing eligibility check
- POST /agents/trigger — creates run log + dispatches background task
"""
from __future__ import annotations
import asyncio
import json
import uuid
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from app.core.agent_runner import (
_extract_items_from_content,
_is_overdue,
_send_insert_to_client,
run_cloud_agent,
run_local_agent,
trigger_pending_runs,
)
from app.core.device_manager import DeviceConnectionManager
from app.db import get_session
from app.main import app
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
from tests.conftest import TEST_USER_IDS, auth_header
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_FREE_UID = TEST_USER_IDS["free"]
_PRO_UID = TEST_USER_IDS["pro"]
def _make_local_config(user_id: str = _FREE_UID, device_id: str = "dev-001") -> LocalAgentConfig:
return LocalAgentConfig(
id=str(uuid.uuid4()),
user_id=user_id,
device_id=device_id,
name="Test Local Agent",
directory_paths=["/home/user/emails"],
data_types=["tasks", "notes"],
prompt_template="Extract tasks and notes from this document.",
file_extensions=[".txt", ".eml"],
schedule_cron="0 */6 * * *",
enabled=True,
last_run_at=None,
)
def _make_cloud_config(user_id: str = _FREE_UID) -> CloudAgentConfig:
return CloudAgentConfig(
id=str(uuid.uuid4()),
user_id=user_id,
provider="gmail",
name="Test Gmail Agent",
data_types=["tasks"],
prompt_template="Extract tasks from email.",
schedule_cron="0 */6 * * *",
enabled=True,
last_run_at=None,
)
def _make_run_log(agent_id: str, agent_type: str = "local", user_id: str = _FREE_UID) -> AgentRunLog:
return AgentRunLog(
id=str(uuid.uuid4()),
agent_id=agent_id,
agent_type=agent_type,
user_id=user_id,
status="running",
started_at=datetime.now(timezone.utc),
)
def _make_manager(user_id: str = _FREE_UID, device_id: str = "dev-001") -> DeviceConnectionManager:
mgr = DeviceConnectionManager()
ws = MagicMock()
ws.send_text = AsyncMock()
mgr.register(user_id, device_id, ws)
return mgr
# ---------------------------------------------------------------------------
# _is_overdue
# ---------------------------------------------------------------------------
def test_is_overdue_never_run():
"""An agent that has never run is always overdue."""
assert _is_overdue("0 */6 * * *", None) is True
def test_is_overdue_very_recently_run():
"""An agent that just ran is not overdue."""
last = datetime.now(timezone.utc)
assert _is_overdue("0 */6 * * *", last) is False
def test_is_overdue_long_ago():
"""An agent last run 2 days ago with a 6-hour schedule is overdue."""
from datetime import timedelta
last = datetime.now(timezone.utc) - timedelta(days=2)
assert _is_overdue("0 */6 * * *", last) is True
def test_is_overdue_invalid_cron_returns_false():
"""Unparseable cron must not raise and should return False (fail-safe)."""
assert _is_overdue("not a cron", None) is False
def test_is_overdue_naive_datetime():
"""Naive datetime objects are handled without raising."""
from datetime import timedelta
last = datetime.utcnow() - timedelta(days=1) # naive
# Should not raise.
result = _is_overdue("0 */6 * * *", last)
assert isinstance(result, bool)
# ---------------------------------------------------------------------------
# _extract_items_from_content
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_extract_items_happy_path():
"""LLM returns valid JSON array; items with allowed tables are returned."""
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.content = json.dumps([
{"table": "tasks", "data": {"title": "Buy milk", "priority": "high"}},
{"table": "notes", "data": {"title": "Meeting recap", "content": "Discussed roadmap"}},
])
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
items = await _extract_items_from_content(
"Extract tasks and notes.",
"Email body: Buy milk urgently. Notes from meeting: discussed roadmap.",
["tasks", "notes"],
)
assert len(items) == 2
assert items[0]["table"] == "tasks"
assert items[0]["data"]["title"] == "Buy milk"
assert items[1]["table"] == "notes"
@pytest.mark.asyncio
async def test_extract_items_strips_forbidden_fields():
"""Fields like id, createdAt, isAiSuggested must be stripped from extracted data."""
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.content = json.dumps([
{
"table": "tasks",
"data": {
"title": "Review PR",
"id": "should-be-removed",
"createdAt": 99999,
"isAiSuggested": 0,
"isApproved": 1,
},
}
])
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
items = await _extract_items_from_content("Extract tasks.", "Review the PR.", ["tasks"])
assert len(items) == 1
data = items[0]["data"]
assert "id" not in data
assert "createdAt" not in data
assert "isAiSuggested" not in data
assert "isApproved" not in data
assert data["title"] == "Review PR"
@pytest.mark.asyncio
async def test_extract_items_invalid_json_returns_empty():
"""LLM returning invalid JSON must return empty list without raising."""
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.content = "Sorry, I cannot extract anything."
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
items = await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
assert items == []
@pytest.mark.asyncio
async def test_extract_items_disallowed_table_filtered():
"""Items whose table is not in data_types are discarded."""
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.content = json.dumps([
{"table": "tasks", "data": {"title": "Valid task"}},
{"table": "projects", "data": {"name": "Should be filtered"}},
])
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
# Only "tasks" is in data_types — "projects" should be filtered.
items = await _extract_items_from_content("Extract.", "content", ["tasks"])
assert len(items) == 1
assert items[0]["table"] == "tasks"
@pytest.mark.asyncio
async def test_extract_items_empty_data_types_returns_empty():
"""If no allowed data_types match, skip LLM call and return immediately."""
mock_llm = MagicMock()
mock_llm.ainvoke = AsyncMock()
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
items = await _extract_items_from_content("Extract.", "content", [])
mock_llm.ainvoke.assert_not_called()
assert items == []
@pytest.mark.asyncio
async def test_extract_items_llm_error_propagates():
"""LLM API errors propagate so the caller (run_local_agent) can record them."""
mock_llm = MagicMock()
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("API unavailable"))
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
with pytest.raises(RuntimeError, match="API unavailable"):
await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
# ---------------------------------------------------------------------------
# _send_insert_to_client
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_send_insert_to_client_happy_path():
"""Frame is sent with isAiSuggested/isApproved added; result is returned."""
mgr = _make_manager()
sent_payloads: list[dict] = []
original_send = mgr.send_frame
async def _capture_send(uid: str, frame: dict) -> None:
sent_payloads.append(frame)
# Immediately resolve the pending call with a success result.
call_id = frame["id"]
mgr.resolve_pending_call(uid, call_id, {"row": {"id": "new-id", "title": "Buy milk"}})
mgr.send_frame = _capture_send # type: ignore[method-assign]
result = await _send_insert_to_client(
_FREE_UID, "tasks", {"title": "Buy milk", "priority": "high"}, mgr
)
assert len(sent_payloads) == 1
payload = sent_payloads[0]
assert payload["action"] == "insert"
assert payload["table"] == "tasks"
assert payload["data"]["title"] == "Buy milk"
assert payload["data"]["isAiSuggested"] == 1
assert payload["data"]["isApproved"] == 0
assert result["row"]["title"] == "Buy milk"
@pytest.mark.asyncio
async def test_send_insert_to_client_timeout():
"""asyncio.TimeoutError is raised when Electron does not respond."""
mgr = _make_manager()
async def _slow_send(uid: str, frame: dict) -> None:
# Never resolve the pending call.
pass
mgr.send_frame = _slow_send # type: ignore[method-assign]
with patch("app.core.agent_runner._INSERT_TIMEOUT", 0.05):
with pytest.raises(asyncio.TimeoutError):
await _send_insert_to_client(_FREE_UID, "tasks", {"title": "X"}, mgr)
# ---------------------------------------------------------------------------
# run_local_agent
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_run_local_agent_device_offline():
"""run_local_agent marks run as error when device is offline."""
config = _make_local_config()
run_log = _make_run_log(config.id)
mgr = DeviceConnectionManager() # Empty — no device registered.
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
await run_local_agent(_FREE_UID, config, run_log, mgr)
mock_finalize.assert_called_once()
_args, kwargs = mock_finalize.call_args
assert kwargs["status"] == "error"
assert any("not connected" in e for e in kwargs["errors"])
@pytest.mark.asyncio
async def test_run_local_agent_happy_path():
"""End-to-end: files received, LLM extracts one task, insert sent + ack'd."""
config = _make_local_config()
run_log = _make_run_log(config.id)
mgr = _make_manager()
# Build a fake agent_data frame (will be queued after send).
file_frame = {
"type": "agent_data",
"run_id": run_log.id,
"files": [{"path": "/email.eml", "content": "Urgent: fix the bug by Friday."}],
}
agent_complete_frame = None # sentinel
sent_frames: list[dict] = []
async def _mock_send(uid: str, frame: dict) -> None:
sent_frames.append(frame)
if frame.get("type") == "agent_run":
# Simulate Electron responding with file data then agent_complete.
q = mgr.get_agent_data_queue(uid, frame["run_id"])
await q.put(file_frame)
await q.put(agent_complete_frame)
elif frame.get("type") == "tool_call":
# Resolve the pending insert immediately.
mgr.resolve_pending_call(uid, frame["id"], {"row": {"id": "new-task", "title": "Fix the bug"}})
mgr.send_frame = _mock_send # type: ignore[method-assign]
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.content = json.dumps([
{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}
])
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
await run_local_agent(_FREE_UID, config, run_log, mgr)
mock_finalize.assert_called_once()
_args, kwargs = mock_finalize.call_args
assert kwargs["status"] == "success"
assert kwargs["items_processed"] == 1
assert kwargs["items_created"] == 1
assert kwargs["errors"] == []
assert kwargs["update_config_last_run"] is False
# Verify agent_run frame was sent.
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
assert len(agent_run_frames) == 1
assert agent_run_frames[0]["agent_id"] == config.id
assert "paths" in agent_run_frames[0]["config"]
# Verify insert frame was sent with AI flags.
insert_frames = [f for f in sent_frames if f.get("type") == "tool_call"]
assert len(insert_frames) == 1
assert insert_frames[0]["data"]["isAiSuggested"] == 1
assert insert_frames[0]["data"]["isApproved"] == 0
@pytest.mark.asyncio
async def test_run_local_agent_file_read_timeout():
"""run_local_agent marks run as partial/error when device stops sending files."""
config = _make_local_config()
run_log = _make_run_log(config.id)
mgr = _make_manager()
async def _mock_send(uid: str, frame: dict) -> None:
# Don't put anything in the queue — simulate stalled device.
pass
mgr.send_frame = _mock_send # type: ignore[method-assign]
with patch("app.core.agent_runner._FILE_READ_TIMEOUT", 0.1), \
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
await run_local_agent(_FREE_UID, config, run_log, mgr)
mock_finalize.assert_called_once()
_args, kwargs = mock_finalize.call_args
assert kwargs["status"] == "error" # No items created, so error (not partial).
assert any("timed out" in e.lower() for e in kwargs["errors"])
@pytest.mark.asyncio
async def test_run_local_agent_llm_extraction_error():
"""LLM errors per-file are recorded; run continues for remaining files."""
config = _make_local_config()
run_log = _make_run_log(config.id)
mgr = _make_manager()
file_frame = {
"type": "agent_data",
"run_id": run_log.id,
"files": [
{"path": "/file1.eml", "content": "Email one."},
{"path": "/file2.eml", "content": "Email two."},
],
}
async def _mock_send(uid: str, frame: dict) -> None:
if frame.get("type") == "agent_run":
q = mgr.get_agent_data_queue(uid, frame["run_id"])
await q.put(file_frame)
await q.put(None) # agent_complete sentinel
mgr.send_frame = _mock_send # type: ignore[method-assign]
mock_llm = MagicMock()
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM boom"))
with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
await run_local_agent(_FREE_UID, config, run_log, mgr)
_args, kwargs = mock_finalize.call_args
assert kwargs["status"] == "error"
assert kwargs["items_processed"] == 2 # Both files attempted.
assert kwargs["items_created"] == 0
assert len(kwargs["errors"]) == 2 # One error per file.
# ---------------------------------------------------------------------------
# run_cloud_agent (stub)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_run_cloud_agent_device_offline():
"""Cloud agent aborts immediately when no device is connected."""
config = _make_cloud_config()
run_log = _make_run_log(config.id, agent_type="cloud")
mgr = DeviceConnectionManager() # empty — no devices registered
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
mock_finalize.assert_called_once()
_, kwargs = mock_finalize.call_args
assert kwargs["status"] == "error"
assert any("device" in e.lower() or "connected" in e.lower() for e in kwargs["errors"])
@pytest.mark.asyncio
async def test_run_cloud_agent_no_oauth_token():
"""Cloud agent errors when no OAuth token is stored."""
config = _make_cloud_config()
config.oauth_token_encrypted = None
run_log = _make_run_log(config.id, agent_type="cloud")
mgr = _make_manager()
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
_, kwargs = mock_finalize.call_args
assert kwargs["status"] == "error"
assert any("oauth" in e.lower() or "token" in e.lower() for e in kwargs["errors"])
@pytest.mark.asyncio
async def test_run_cloud_agent_token_decrypt_failure():
"""Cloud agent errors gracefully when the stored token cannot be decrypted."""
config = _make_cloud_config()
config.oauth_token_encrypted = "this-is-not-valid-fernet-ciphertext"
run_log = _make_run_log(config.id, agent_type="cloud")
mgr = _make_manager()
from cryptography.fernet import Fernet as _Fernet
valid_key = _Fernet.generate_key().decode()
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
patch("app.integrations.settings") as mock_settings:
mock_settings.OAUTH_ENCRYPTION_KEY = valid_key
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
_, kwargs = mock_finalize.call_args
assert kwargs["status"] == "error"
assert any("decrypt" in e.lower() for e in kwargs["errors"])
@pytest.mark.asyncio
async def test_run_cloud_agent_happy_path_gmail():
"""Cloud agent happy path: Gmail fetch → LLM extraction → inserts → success."""
from app.integrations import EmailMessage, encrypt_token
from cryptography.fernet import Fernet as _Fernet
fernet_key = _Fernet.generate_key().decode()
credentials = {
"token": "access_abc",
"refresh_token": "refresh_xyz",
"token_uri": "https://oauth2.googleapis.com/token",
"client_id": "cid",
"client_secret": "csec",
}
config = _make_cloud_config()
config.provider = "gmail"
config.prompt_template = "Extract tasks from this email."
config.data_types = ["tasks"]
with patch("app.integrations.settings") as ms:
ms.OAUTH_ENCRYPTION_KEY = fernet_key
config.oauth_token_encrypted = encrypt_token(credentials)
run_log = _make_run_log(config.id, agent_type="cloud")
mgr = _make_manager()
sample_email = EmailMessage(
id="msg001",
subject="Action required",
sender="boss@company.com",
body_text="Please fix the bug by Friday.",
date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
)
extracted_items = [{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}]
with patch("app.integrations.settings") as mock_int_settings, \
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
patch("app.core.agent_runner._extract_items_from_content", new_callable=AsyncMock, return_value=extracted_items) as mock_extract, \
patch("app.core.agent_runner._send_insert_to_client", new_callable=AsyncMock, return_value={"ok": True}) as mock_insert, \
patch("app.core.agent_runner.async_session"):
mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
mock_gmail = AsyncMock()
mock_gmail.fetch_messages = AsyncMock(return_value=[sample_email])
mock_gmail.refreshed_credentials = None
with patch("app.integrations.decrypt_token", return_value=credentials), \
patch("app.integrations.get_provider", return_value=mock_gmail):
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
mock_extract.assert_called_once()
mock_insert.assert_called_once()
_, kwargs = mock_finalize.call_args
assert kwargs["status"] == "success"
assert kwargs["items_processed"] == 1
assert kwargs["items_created"] == 1
assert kwargs["config_type"] == "cloud"
@pytest.mark.asyncio
async def test_run_cloud_agent_provider_fetch_error():
"""Cloud agent records error status when provider fetch raises RuntimeError."""
credentials = {"token": "abc"}
config = _make_cloud_config()
config.oauth_token_encrypted = "some_encrypted_value" # non-empty so decrypt step is reached
config.prompt_template = "Extract tasks."
config.data_types = ["tasks"]
run_log = _make_run_log(config.id, agent_type="cloud")
mgr = _make_manager()
mock_provider = AsyncMock()
mock_provider.fetch_messages = AsyncMock(side_effect=RuntimeError("API quota exceeded"))
mock_provider.refreshed_credentials = None
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
patch("app.integrations.decrypt_token", return_value=credentials), \
patch("app.integrations.get_provider", return_value=mock_provider), \
patch("app.core.agent_runner.async_session"):
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
_, kwargs = mock_finalize.call_args
assert kwargs["status"] == "error"
assert any("quota" in e.lower() or "fetch" in e.lower() for e in kwargs["errors"])
@pytest.mark.asyncio
async def test_run_cloud_agent_refreshed_token_persisted():
"""When the provider refreshes its token, the new ciphertext is written to DB."""
from app.integrations import EmailMessage, encrypt_token
from cryptography.fernet import Fernet as _Fernet
fernet_key = _Fernet.generate_key().decode()
credentials = {"token": "old_token", "refresh_token": "rt_old"}
fresh_credentials = {"token": "new_token", "refresh_token": "rt_new"}
config = _make_cloud_config()
config.prompt_template = "Extract tasks."
config.data_types = ["tasks"]
with patch("app.integrations.settings") as ms:
ms.OAUTH_ENCRYPTION_KEY = fernet_key
config.oauth_token_encrypted = encrypt_token(credentials)
run_log = _make_run_log(config.id, agent_type="cloud")
mgr = _make_manager()
mock_provider = AsyncMock()
mock_provider.fetch_messages = AsyncMock(return_value=[])
mock_provider.refreshed_credentials = fresh_credentials # token was refreshed
# Track DB writes via mock async_session.
mock_cfg_row = MagicMock()
mock_cfg_row.oauth_token_encrypted = None
mock_db = AsyncMock()
mock_db.__aenter__ = AsyncMock(return_value=mock_db)
mock_db.__aexit__ = AsyncMock(return_value=False)
mock_db.scalar_one_or_none = AsyncMock(return_value=mock_cfg_row)
cfg_result = MagicMock()
cfg_result.scalar_one_or_none.return_value = mock_cfg_row
mock_db.execute = AsyncMock(return_value=cfg_result)
mock_db.commit = AsyncMock()
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock), \
patch("app.integrations.decrypt_token", return_value=credentials), \
patch("app.integrations.get_provider", return_value=mock_provider), \
patch("app.integrations.encrypt_token", return_value="new_encrypted") as mock_encrypt, \
patch("app.core.agent_runner.async_session", return_value=mock_db), \
patch("app.integrations.settings") as mock_int_settings:
mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
# The new encrypted token should have been written to the config row.
mock_encrypt.assert_called_once_with(fresh_credentials)
assert mock_cfg_row.oauth_token_encrypted == "new_encrypted"
@pytest.mark.asyncio
async def test_finalize_run_updates_cloud_config_last_run_at():
"""_finalize_run with config_type='cloud' updates CloudAgentConfig.last_run_at."""
from app.core.agent_runner import _finalize_run
run_log = _make_run_log(str(uuid.uuid4()), agent_type="cloud")
run_log.id = str(uuid.uuid4())
mock_cfg = MagicMock()
mock_cfg.last_run_at = None
cfg_result = MagicMock()
cfg_result.scalar_one_or_none.return_value = mock_cfg
mock_db = AsyncMock()
mock_db.__aenter__ = AsyncMock(return_value=mock_db)
mock_db.__aexit__ = AsyncMock(return_value=False)
mock_db.merge = AsyncMock(return_value=run_log)
mock_db.execute = AsyncMock(return_value=cfg_result)
mock_db.commit = AsyncMock()
config_id = str(uuid.uuid4())
with patch("app.core.agent_runner.async_session", return_value=mock_db):
await _finalize_run(
run_log,
status="success",
update_config_last_run=True,
config_id=config_id,
config_type="cloud",
)
# CloudAgentConfig.last_run_at should have been set.
assert mock_cfg.last_run_at is not None
mock_db.commit.assert_called()
# ---------------------------------------------------------------------------
# trigger_pending_runs
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_trigger_pending_runs_no_overdue():
"""Pending-run scan is skipped because agent config is client-owned."""
mgr = _make_manager()
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
mock_run.assert_not_called()
@pytest.mark.asyncio
async def test_trigger_pending_runs_device_id_filter():
"""Device filtering is no longer backend-managed in pending runs."""
mgr = _make_manager(device_id="dev-001")
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
mock_run.assert_not_called()
@pytest.mark.asyncio
async def test_trigger_pending_runs_dispatches_overdue():
"""No pending runs are dispatched by backend after config deprecation."""
mgr = _make_manager()
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
mock_run.assert_not_called()
# ---------------------------------------------------------------------------
# Integration: POST /agents/can-create and /agents/trigger
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _override_db(db_session):
"""Route all get_session calls to the test SQLite session."""
async def _gen():
yield db_session
app.dependency_overrides[get_session] = _gen
yield
app.dependency_overrides.pop(get_session, None)
@pytest.mark.asyncio
async def test_can_create_agent_allows_when_under_limit(client):
"""POST /agents/can-create returns allowed=True when under tier limit."""
resp = client.post(
"/api/v1/agents/can-create",
json={"active_agents": 0},
headers=auth_header("free"),
)
assert resp.status_code == 200
body = resp.json()
assert body["allowed"] is True
assert body["tier"] == "free"
assert body["active_agents"] == 0
assert body["limit"] == 2
@pytest.mark.asyncio
async def test_can_create_agent_denies_when_at_limit(client):
"""POST /agents/can-create returns allowed=False at free-tier limit."""
resp = client.post(
"/api/v1/agents/can-create",
json={"active_agents": 2},
headers=auth_header("free"),
)
assert resp.status_code == 200
body = resp.json()
assert body["allowed"] is False
assert body["limit"] == 2
@pytest.mark.asyncio
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
"""POST /agents/trigger creates a local run log and dispatches background task."""
dispatched: list[tuple[str, str]] = []
async def _fake_run(user_id, cfg, run_log, device_mgr):
dispatched.append((user_id, cfg.id))
def _fake_create_task(coro):
coro.close()
return MagicMock()
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
patch("asyncio.create_task") as mock_create_task:
mock_create_task.side_effect = _fake_create_task
resp = client.post(
"/api/v1/agents/trigger",
json={
"directory": "/home/user/docs",
"what_to_extract": ["task", "note"],
"actions_by_type": {"task": ["add", "update"], "note": ["add"]},
"batch_interval": "0 */6 * * *",
"custom_agent_prompt": "Extract tasks and notes.",
"active_agents": 0,
},
headers=auth_header("power"),
)
assert resp.status_code == 202
data = resp.json()
assert isinstance(data["agent_id"], str)
assert data["agent_id"]
assert data["status"] == "running"
assert data["agent_type"] == "local"
# Verify create_task was called (dispatching background run).
mock_create_task.assert_called_once()

View File

@@ -0,0 +1,432 @@
"""Tests for Local Agent V2 runner (Step 2).
Covers the unified per-file flow:
Phase A — detect + preprocess (Python, zero LLM)
Phase B — single LLM call with tools (classify + extract + create)
Fixture-based eval tests (2.12.7)
-----------------------------------
Cases are defined in tests/fixtures/agent_runner_v2/cases.yaml.
Email HTML files live in tests/fixtures/agent_runner_v2/data/.
Use --runner-dir to point at a custom folder (same structure required).
Unit tests (no LLM)
--------------------
2.8 items_created count → items_created == N create_* calls
2.9 Device offline → status=error
2.10 Empty file → items_processed=0, status=success
Run:
pytest tests/test_agent_runner_v2.py -v
pytest tests/test_agent_runner_v2.py -v -k "2_9 or 2_10 or 2_8" # unit only
pytest tests/test_agent_runner_v2.py -v -k "eval" # LLM evals only
pytest tests/test_agent_runner_v2.py -v --runner-dir /path/to/dir # custom fixtures
"""
from __future__ import annotations
import uuid
from contextlib import nullcontext
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import yaml
from app.core.agent_runner import (
_format_metadata,
_format_projects,
_get_extraction_rules,
_get_no_match_behavior,
_is_overdue,
run_local_agent,
)
from app.core.device_manager import DeviceConnectionManager
from app.core.langfuse_client import get_langfuse
from app.models import AgentRunLog, LocalAgentConfig
from tests.conftest import TEST_USER_IDS
# ── Constants ─────────────────────────────────────────────────────────────
_USER_ID = TEST_USER_IDS["power"]
_DEFAULT_FIXTURE_DIR = Path(__file__).parent / "fixtures" / "agent_runner_v2"
_AGENT_CONFIG = {
"content_types": [
{
"id": "email_html",
"label": "Email HTML",
"detection_hint": "HTML file with From/To/Subject headers",
"preprocessing": "email_html",
"extraction_prompt": (
"If the email contains a direct action request or task assignment → create a task. "
"If the email contains informational content, updates, or FYI → create a note. "
"If the email mentions a specific date for a meeting or deadline → create a timeline entry."
),
}
],
"global_rules": [
"Se il file non è riconducibile a nessun progetto, non creare alcuna entità."
],
"data_types": ["tasks", "notes", "timelines"],
}
# Canonical project definitions, referenced symbolically in cases.yaml.
_PROJECTS: dict[str, dict] = {
"alpha": {"id": "proj-alpha", "name": "Project Alpha", "status": "active"},
"beta": {"id": "proj-beta", "name": "Project Beta", "status": "active"},
}
# ── Fixture loading ───────────────────────────────────────────────────────
def _fixtures_dir(config) -> Path:
override = config.getoption("--runner-dir")
return Path(override) if override else _DEFAULT_FIXTURE_DIR
def _load_cases(config) -> list[dict]:
return yaml.safe_load(
(_fixtures_dir(config) / "cases.yaml").read_text(encoding="utf-8")
)
def _read_case_file(case: dict, data_dir: Path) -> str:
return (data_dir / case["file"]).read_text(encoding="utf-8")
def _resolve_projects(entries: list[str | dict]) -> list[dict]:
"""Resolve project list from YAML: symbolic names and/or inline dicts."""
result = []
for entry in entries:
if isinstance(entry, str):
if entry in _PROJECTS:
result.append(_PROJECTS[entry])
elif isinstance(entry, dict):
result.append(entry)
return result
# ── pytest_generate_tests — parametrize eval tests from YAML ─────────────
def pytest_generate_tests(metafunc):
if "runner_case" not in metafunc.fixturenames:
return
cases = _load_cases(metafunc.config)
metafunc.parametrize("runner_case", cases, ids=[c["id"] for c in cases])
# ── Test helpers ──────────────────────────────────────────────────────────
def _make_config(
agent_config: dict | None = None,
directory: str = "/emails",
device_id: str = "dev-001",
) -> LocalAgentConfig:
return LocalAgentConfig(
id=str(uuid.uuid4()),
user_id=_USER_ID,
device_id=device_id,
name="Test V2 Agent",
directory_paths=[directory],
data_types=["tasks", "notes", "timelines"],
prompt_template="",
agent_config=agent_config or _AGENT_CONFIG,
file_extensions=[".html", ".eml"],
schedule_cron="0 */6 * * *",
enabled=True,
last_run_at=None,
)
def _make_run_log(agent_id: str) -> AgentRunLog:
return AgentRunLog(
id=str(uuid.uuid4()),
agent_id=agent_id,
agent_type="local",
user_id=_USER_ID,
status="running",
started_at=datetime.now(timezone.utc),
)
def _make_manager(online: bool = True) -> DeviceConnectionManager:
mgr = DeviceConnectionManager()
if online:
ws = MagicMock()
ws.send_text = AsyncMock()
mgr.register(_USER_ID, "dev-001", ws)
return mgr
def _make_executor(
file_path: str,
file_content: str,
projects: list[dict] | None = None,
existing_tasks: list[dict] | None = None,
existing_notes: list[dict] | None = None,
existing_timelines: list[dict] | None = None,
) -> tuple[Any, list[dict]]:
"""Return (async_executor, captured_calls).
The executor handles all ``execute_on_client`` payloads:
directory listing, file reading, project/entity fetching, and CRUD.
"""
calls: list[dict] = []
_projects = projects if projects is not None else list(_PROJECTS.values())
async def _executor(payload: dict) -> dict:
action = payload.get("action", "")
table = payload.get("table", "")
data = payload.get("data") or {}
calls.append({"action": action, "table": table, "data": data})
if action == "list_directory":
return {"entries": [{"type": "file", "path": file_path}]}
if action == "get_file_metadata":
return {"modifiedAt": None}
if action == "read_file_content":
return {"content": file_content}
if action == "select":
if table == "projects":
return {"rows": _projects}
if table == "tasks":
return {"rows": existing_tasks or []}
if table == "notes":
return {"rows": existing_notes or []}
if table == "timelines":
return {"rows": existing_timelines or []}
return {"rows": []}
if action == "insert":
return {"row": {"id": str(uuid.uuid4()), **data}}
if action == "update":
return {"success": True}
return {}
return _executor, calls
# ── Unit: helper functions ────────────────────────────────────────────────
def test_format_projects_empty():
assert "(no projects" in _format_projects([])
def test_format_projects_with_data():
result = _format_projects([_PROJECTS["alpha"]])
assert "proj-alpha" in result
assert "Project Alpha" in result
def test_format_metadata_empty():
assert _format_metadata({}) == ""
def test_format_metadata_email():
meta = {"subject": "Fix bug", "from": "boss@co.com", "date": "2026-04-07"}
result = _format_metadata(meta)
assert "Fix bug" in result
assert "boss@co.com" in result
def test_get_extraction_rules_match():
rules = _get_extraction_rules(_AGENT_CONFIG, "email_html")
assert "task" in rules.lower()
def test_get_extraction_rules_fallback():
rules = _get_extraction_rules(_AGENT_CONFIG, "plain_text")
assert "extract" in rules.lower()
def test_get_no_match_behavior_from_global_rules():
behavior = _get_no_match_behavior(_AGENT_CONFIG)
assert behavior # non-empty
def test_get_no_match_behavior_default():
behavior = _get_no_match_behavior({})
assert "project" in behavior.lower()
# ── Unit: 2.9 — device offline ───────────────────────────────────────────
@pytest.mark.asyncio
async def test_2_9_device_offline():
"""2.9 No device online → status=error, no executor created."""
config = _make_config()
run_log = _make_run_log(config.id)
mgr = _make_manager(online=False)
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
await run_local_agent(_USER_ID, config, run_log, mgr)
_, kwargs = mock_fin.call_args
assert kwargs["status"] == "error"
assert any("not connected" in e for e in kwargs.get("errors", []))
# ── Unit: 2.10 — empty file ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_2_10_empty_file():
"""2.10 File with empty content → skipped, items_processed=0, success."""
config = _make_config()
run_log = _make_run_log(config.id)
mgr = _make_manager()
executor, calls = _make_executor(
file_path="/emails/empty.html",
file_content="",
projects=[_PROJECTS["alpha"]],
)
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
await run_local_agent(_USER_ID, config, run_log, mgr)
_, kwargs = mock_fin.call_args
assert kwargs["items_processed"] == 0
assert kwargs["status"] == "success"
assert kwargs["items_created"] == 0
# ── Unit: 2.8 — items_created count ─────────────────────────────────────
@pytest.mark.asyncio
async def test_2_8_items_created_count():
"""2.8 items_created == number of create_* tool calls per run."""
config = _make_config()
run_log = _make_run_log(config.id)
mgr = _make_manager()
executor, _calls = _make_executor(
file_path="/emails/action.html",
file_content="<html><body><p>Fix the login bug in Project Alpha.</p></body></html>",
projects=[_PROJECTS["alpha"]],
)
async def mock_run_agent(*, _tool_calls_out=None, **kw) -> str:
if _tool_calls_out is not None:
_tool_calls_out.extend(["create_task", "create_note", "update_task"])
return "Done."
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
patch("app.core.agent_runner._run_agent_with_tools", side_effect=mock_run_agent), \
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
await run_local_agent(_USER_ID, config, run_log, mgr)
_, kwargs = mock_fin.call_args
# Only create_task + create_note count (not update_task).
assert kwargs["items_created"] == 2
assert kwargs["items_processed"] == 1
# ── Eval: 2.12.7 — fixture-driven, real LLM + Langfuse scoring ──────────
#
# Cases loaded from tests/fixtures/agent_runner_v2/cases.yaml.
# Supported assertions (from YAML):
# expect_insert: <table> → at least 1 insert in that table
# expect_no_insert: true → zero inserts in any table
# expect_project_id: <id> → any insert carries this projectId
# expect_dedup: true → task inserts == 0 OR task updates >= 1
# ─────────────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
@pytest.mark.eval
async def test_eval_runner(runner_case, pytestconfig):
"""Parametrized eval test — one invocation per YAML case."""
case: dict = runner_case
data_dir = _fixtures_dir(pytestconfig) / "data"
file_content = _read_case_file(case, data_dir)
projects = _resolve_projects(case.get("projects", []))
config = _make_config()
run_log = _make_run_log(config.id)
mgr = _make_manager()
executor, calls = _make_executor(
file_path=case["file_path"],
file_content=file_content,
projects=projects,
existing_tasks=case.get("existing_tasks"),
existing_notes=case.get("existing_notes"),
existing_timelines=case.get("existing_timelines"),
)
lf = get_langfuse()
obs_ctx = lf.start_as_current_observation(
name=f"eval-runner-{case['id']}-{case.get('score_name', 'unknown').replace('.', '-')}",
metadata={"step": "2", "case_id": case["id"]},
) if lf else nullcontext()
with obs_ctx as obs:
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
await run_local_agent(_USER_ID, config, run_log, mgr)
_, kwargs = mock_fin.call_args
inserts = [c for c in calls if c["action"] == "insert"]
score, comment = _evaluate_case(case, calls, kwargs)
if obs is not None:
obs.score(
name=case.get("score_name", f"runner.case_{case['id']}"),
value=score,
comment=comment,
)
if lf:
lf.flush()
assert score == 1.0, f"[{case['id']}] {case.get('description', '')}{comment}"
def _evaluate_case(case: dict, calls: list[dict], finalize_kwargs: dict) -> tuple[float, str]:
"""Return (score, comment) for a YAML case given the captured executor calls."""
inserts = [c for c in calls if c["action"] == "insert"]
if case.get("expect_no_insert"):
score = 1.0 if len(inserts) == 0 else 0.0
return score, f"inserts={len(inserts)} (expected 0)"
if "expect_insert" in case:
tables = case["expect_insert"]
if isinstance(tables, str):
tables = [tables]
missing = [t for t in tables if not any(c["table"] == t for c in inserts)]
score = 1.0 if not missing else 0.0
counts = {t: sum(1 for c in inserts if c["table"] == t) for t in tables}
return score, f"inserts={counts}" + (f" missing={missing}" if missing else "")
if "expect_project_id" in case:
expected_pid = case["expect_project_id"]
correct = any(c.get("data", {}).get("projectId") == expected_pid for c in inserts)
score = 1.0 if correct else 0.0
all_pids = [c.get("data", {}).get("projectId") for c in inserts]
return score, f"projectIds={all_pids} (expected {expected_pid!r})"
if case.get("expect_dedup"):
task_creates = [c for c in inserts if c["table"] == "tasks"]
task_updates = [c for c in calls if c["action"] == "update" and c["table"] == "tasks"]
score = 1.0 if len(task_creates) == 0 or len(task_updates) >= 1 else 0.0
return score, f"task_creates={len(task_creates)} task_updates={len(task_updates)}"
return 0.0, "no assertion defined in case"

243
tests/test_agent_setup.py Normal file
View File

@@ -0,0 +1,243 @@
"""Tests for the Chatbot Journey endpoints.
Covers:
1. Start journey for local agent → session_id + first question, done=False
2. Start journey for cloud agent → contextual email-focused question
3. Start journey with existing agent_id → session seeded, first question returned
4. Start journey with non-existent agent_id → still succeeds (graceful fallback)
5. Message: continue conversation → done=False, follow-up question returned
6. Message: LLM wraps up → done=True + prompt_template extracted correctly
7. Message with max-turns nudge → no crash, returns response
8. Invalid session_id → 404
9. Expired session → 404
10. Session ownership: user B cannot access user A's session
11. No JWT on /start → 401
12. No JWT on /message → 401
"""
from __future__ import annotations
import time
import uuid
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.routes.agent_setup import (
_SESSION_TTL_SECONDS,
_TEMPLATE_END,
_TEMPLATE_START,
_extract_template,
_sessions,
)
from app.models import LocalAgentConfig
from tests.conftest import TEST_USER_IDS, auth_header
# ── Helpers ──────────────────────────────────────────────────────────────
def _start(client: TestClient, agent_type: str = "local", agent_id: str | None = None, tier: str = "power") -> dict:
body: dict = {"agent_type": agent_type}
if agent_id:
body["agent_id"] = agent_id
resp = client.post("/api/v1/agents/journey/start", json=body, headers=auth_header(tier))
return resp
def _message(client: TestClient, session_id: str, message: str, tier: str = "power") -> dict:
return client.post(
"/api/v1/agents/journey/message",
json={"session_id": session_id, "message": message},
headers=auth_header(tier),
)
# ── Unit: _extract_template ───────────────────────────────────────────────
def test_extract_template_present():
text = f"Some preamble.\n{_TEMPLATE_START}\nExtract tasks from emails.\n{_TEMPLATE_END}\nTrailing text."
result = _extract_template(text)
assert result == "Extract tasks from emails."
def test_extract_template_absent():
assert _extract_template("No markers here.") is None
def test_extract_template_empty_content():
text = f"{_TEMPLATE_START}\n{_TEMPLATE_END}"
assert _extract_template(text) is None
# ── Start journey ─────────────────────────────────────────────────────────
def test_start_journey_local(client: TestClient):
resp = _start(client, agent_type="local")
assert resp.status_code == 200
body = resp.json()
assert "session_id" in body
assert body["done"] is False
assert body["prompt_template"] is None
assert len(body["message"]) > 0
# Local question should be about files/directories
assert any(w in body["message"].lower() for w in ("file", "director", "document", "monitor"))
def test_start_journey_cloud(client: TestClient):
resp = _start(client, agent_type="cloud")
assert resp.status_code == 200
body = resp.json()
assert body["done"] is False
# Cloud question should mention emails or messages
assert any(w in body["message"].lower() for w in ("email", "message", "communication"))
def test_start_journey_with_agent_id(client: TestClient, db_session: AsyncSession):
"""When agent_id is provided, session should be created even if agent doesn't exist."""
fake_agent_id = str(uuid.uuid4())
resp = _start(client, agent_type="local", agent_id=fake_agent_id)
# Should succeed gracefully even if the agent_id doesn't exist
assert resp.status_code == 200
body = resp.json()
assert body["done"] is False
def test_start_journey_with_existing_agent(client: TestClient, db_session: AsyncSession):
"""When a real local agent is provided, session is seeded with its prompt_template."""
import asyncio
user_id = TEST_USER_IDS["power"]
agent = LocalAgentConfig(
id=str(uuid.uuid4()),
user_id=user_id,
name="Test Agent",
device_id="device-1",
directory_paths=["/home/user/emails"],
data_types=["tasks"],
prompt_template="Extract tasks from .eml files.",
file_extensions=[".eml"],
schedule_cron="0 */6 * * *",
enabled=True,
)
async def _seed():
db_session.add(agent)
await db_session.commit()
asyncio.get_event_loop().run_until_complete(_seed())
resp = _start(client, agent_type="local", agent_id=agent.id)
assert resp.status_code == 200
body = resp.json()
assert body["done"] is False
# The session should be stored
assert body["session_id"] in _sessions
def test_start_journey_requires_auth(client: TestClient):
resp = client.post("/api/v1/agents/journey/start", json={"agent_type": "local"})
assert resp.status_code == 401
# ── Message ───────────────────────────────────────────────────────────────
def test_message_continues_conversation(client: TestClient):
"""A mid-journey reply (no template markers) returns done=False."""
follow_up = "That looks good. Can you tell me more about priority rules?"
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=follow_up)):
start_resp = _start(client, agent_type="local")
assert start_resp.status_code == 200
session_id = start_resp.json()["session_id"]
msg_resp = _message(client, session_id, "I have .eml and .txt files")
assert msg_resp.status_code == 200
body = msg_resp.json()
assert body["done"] is False
assert body["prompt_template"] is None
assert body["message"] == follow_up
assert body["session_id"] == session_id
def test_message_produces_template(client: TestClient):
"""When the LLM includes PROMPT_TEMPLATE markers, done=True and prompt_template is set."""
final_template = "Extract tasks from email. Subject → title. 'urgent' → high priority."
llm_response = (
"Great, I have all the information I need.\n"
f"{_TEMPLATE_START}\n{final_template}\n{_TEMPLATE_END}\n"
)
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=llm_response)):
start_resp = _start(client, agent_type="cloud")
assert start_resp.status_code == 200
session_id = start_resp.json()["session_id"]
msg_resp = _message(client, session_id, "Only invoices from clients")
assert msg_resp.status_code == 200
body = msg_resp.json()
assert body["done"] is True
assert body["prompt_template"] == final_template
# Session should be cleaned up
assert session_id not in _sessions
def test_message_invalid_session(client: TestClient):
resp = _message(client, "nonexistent-session-id", "hello")
assert resp.status_code == 404
def test_message_wrong_owner(client: TestClient):
"""User B cannot access user A's session."""
start_resp = _start(client, agent_type="local", tier="power")
session_id = start_resp.json()["session_id"]
# user with "pro" tier (different user_id) tries to send a message
resp = client.post(
"/api/v1/agents/journey/message",
json={"session_id": session_id, "message": "hello"},
headers=auth_header("pro"), # different user
)
assert resp.status_code == 404
def test_message_expired_session(client: TestClient):
"""Expired sessions return 404."""
start_resp = _start(client, agent_type="local")
session_id = start_resp.json()["session_id"]
# Manually expire the session
_sessions[session_id].created_at = time.monotonic() - _SESSION_TTL_SECONDS - 1
resp = _message(client, session_id, "hello")
assert resp.status_code == 404
def test_message_requires_auth(client: TestClient):
resp = client.post(
"/api/v1/agents/journey/message",
json={"session_id": "any", "message": "hello"},
)
assert resp.status_code == 401
def test_message_max_turns_nudge(client: TestClient):
"""After _MAX_TURNS user messages, a system nudge is appended but no crash occurs."""
from app.api.routes.agent_setup import _MAX_TURNS
follow_up = "Tell me more about priority rules."
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=follow_up)):
start_resp = _start(client, agent_type="local")
session_id = start_resp.json()["session_id"]
for i in range(_MAX_TURNS):
resp = _message(client, session_id, f"Answer {i + 1}")
assert resp.status_code == 200
# While no template produced, session must still exist
if resp.json()["done"]:
break # LLM decided to wrap up early — also fine

View File

@@ -1,620 +0,0 @@
"""Unit tests for the four domain-specific chat agents with mocked LLM."""
from __future__ import annotations
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import app.agents # noqa: F401 — triggers @registry.register decorators
from app.agents.checkpoint_agent import CheckpointAgent
from app.agents.note_agent import NoteAgent
from app.agents.project_agent import ProjectAgent
from app.agents.task_agent import TaskAgent
from app.core.agent_registry import registry
# ── Helpers ──────────────────────────────────────────────────────────
def _mock_llm(response_text: str) -> MagicMock:
"""Return a mock LLM that responds with *response_text* (no tool calls)."""
msg = MagicMock()
msg.content = response_text
msg.tool_calls = []
llm = MagicMock()
bound = MagicMock()
bound.ainvoke = AsyncMock(return_value=msg)
llm.bind_tools = MagicMock(return_value=bound)
llm.ainvoke = AsyncMock(return_value=msg)
return llm
def _mock_llm_with_tool_call(
tool_name: str, tool_args: dict[str, Any], final_text: str
) -> MagicMock:
"""Mock LLM that fires one tool call then returns *final_text*."""
tool_msg = MagicMock()
tool_msg.content = ""
tool_msg.tool_calls = [{"id": "call_1", "name": tool_name, "args": tool_args}]
final_msg = MagicMock()
final_msg.content = final_text
final_msg.tool_calls = []
bound = MagicMock()
bound.ainvoke = AsyncMock(side_effect=[tool_msg, final_msg])
llm = MagicMock()
llm.bind_tools = MagicMock(return_value=bound)
llm.ainvoke = AsyncMock(return_value=final_msg)
return llm
# ── Registration ──────────────────────────────────────────────────────
class TestAgentRegistration:
def test_all_agents_registered(self) -> None:
names = {a["name"] for a in registry.list_agents()}
assert {
"task_agent", "checkpoint_agent", "project_agent", "note_agent"
}.issubset(names)
def test_registry_returns_correct_types(self) -> None:
assert isinstance(registry.get("task_agent"), TaskAgent)
assert isinstance(registry.get("checkpoint_agent"), CheckpointAgent)
assert isinstance(registry.get("project_agent"), ProjectAgent)
assert isinstance(registry.get("note_agent"), NoteAgent)
def test_descriptions_present(self) -> None:
for agent_info in registry.list_agents():
assert agent_info["description"], f"Empty description: {agent_info['name']}"
# ── TaskAgent ─────────────────────────────────────────────────────────
class TestTaskAgent:
def test_name(self) -> None:
assert TaskAgent().get_name() == "task_agent"
def test_description(self) -> None:
assert TaskAgent().get_description() == "Manages tasks and comments: list, create, update, delete, due-today, comments"
def test_get_tools_count(self) -> None:
assert len(TaskAgent().get_tools()) == 8
def test_tool_names(self) -> None:
names = {t.name for t in TaskAgent().get_tools()}
assert names == {
"list_tasks",
"create_task",
"update_task",
"delete_task",
"list_tasks_due_today",
"list_task_comments",
"add_task_comment",
"delete_task_comment",
}
@pytest.mark.asyncio
async def test_handle_returns_string(self) -> None:
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Task created.")
result = await TaskAgent().handle("create a task", {})
assert isinstance(result, str)
@pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Here are your tasks.")
result = await TaskAgent().handle("list my tasks", {})
assert result == "Here are your tasks."
@pytest.mark.asyncio
async def test_handle_with_create_task_tool_call(self) -> None:
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call(
"create_task",
{"title": "Buy groceries", "priority": "low"},
"Task 'Buy groceries' created.",
)
result = await TaskAgent().handle("add a grocery task", {})
assert result == "Task 'Buy groceries' created."
@pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await TaskAgent().handle("help", {})
assert isinstance(result, str)
@pytest.mark.asyncio
async def test_handle_accepts_rich_context(self) -> None:
context = {
"user_profile": {"id": "u1", "tier": "pro"},
"recent_tasks": [{"id": "t1", "title": "Old task"}],
}
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Tasks listed.")
result = await TaskAgent().handle("show tasks", context)
assert isinstance(result, str)
class TestTaskAgentTools:
@pytest.mark.asyncio
async def test_list_tasks_defaults(self) -> None:
from app.agents.task_agent import list_tasks
result = await list_tasks.ainvoke({})
data = json.loads(result)
assert data["action"] == "list"
assert data["table"] == "tasks"
@pytest.mark.asyncio
async def test_list_tasks_with_status_filter(self) -> None:
from app.agents.task_agent import list_tasks
result = await list_tasks.ainvoke({"status": "done"})
data = json.loads(result)
assert data["filters"]["status"] == "done"
@pytest.mark.asyncio
async def test_create_task_defaults(self) -> None:
from app.agents.task_agent import create_task
result = await create_task.ainvoke({"title": "Test task"})
data = json.loads(result)
assert data["action"] == "create_record"
assert data["table"] == "tasks"
assert data["data"]["title"] == "Test task"
assert data["data"]["status"] == "todo"
assert data["data"]["priority"] == "medium"
@pytest.mark.asyncio
async def test_create_task_with_all_fields(self) -> None:
from app.agents.task_agent import create_task
result = await create_task.ainvoke({
"title": "Deploy",
"priority": "high",
"status": "in_progress",
"project_id": "p1",
"is_ai_suggested": 1,
})
data = json.loads(result)
assert data["data"]["priority"] == "high"
assert data["data"]["status"] == "in_progress"
assert data["data"]["projectId"] == "p1"
assert data["data"]["isAiSuggested"] == 1
@pytest.mark.asyncio
async def test_update_task_with_status(self) -> None:
from app.agents.task_agent import update_task
result = await update_task.ainvoke({"task_id": "t1", "status": "done"})
data = json.loads(result)
assert data["action"] == "update_record"
assert data["data"]["id"] == "t1"
assert data["data"]["updates"]["status"] == "done"
@pytest.mark.asyncio
async def test_update_task_empty_updates(self) -> None:
from app.agents.task_agent import update_task
result = await update_task.ainvoke({"task_id": "t1"})
data = json.loads(result)
assert data["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_task(self) -> None:
from app.agents.task_agent import delete_task
result = await delete_task.ainvoke({"task_id": "t1"})
data = json.loads(result)
assert data["action"] == "delete_record"
assert data["table"] == "tasks"
assert data["data"]["id"] == "t1"
@pytest.mark.asyncio
async def test_list_tasks_due_today(self) -> None:
from app.agents.task_agent import list_tasks_due_today
result = await list_tasks_due_today.ainvoke({})
data = json.loads(result)
assert data["action"] == "list_due_today"
assert data["table"] == "tasks"
@pytest.mark.asyncio
async def test_list_task_comments(self) -> None:
from app.agents.task_agent import list_task_comments
result = await list_task_comments.ainvoke({"task_id": "t1"})
data = json.loads(result)
assert data["action"] == "list"
assert data["table"] == "taskComments"
assert data["filters"]["taskId"] == "t1"
@pytest.mark.asyncio
async def test_add_task_comment(self) -> None:
from app.agents.task_agent import add_task_comment
result = await add_task_comment.ainvoke({
"task_id": "t1",
"author": "Alice",
"content": "Looks good!",
})
data = json.loads(result)
assert data["action"] == "create_record"
assert data["table"] == "taskComments"
assert data["data"]["taskId"] == "t1"
assert data["data"]["author"] == "Alice"
assert data["data"]["content"] == "Looks good!"
@pytest.mark.asyncio
async def test_delete_task_comment(self) -> None:
from app.agents.task_agent import delete_task_comment
result = await delete_task_comment.ainvoke({"comment_id": "c1"})
data = json.loads(result)
assert data["action"] == "delete_record"
assert data["table"] == "taskComments"
assert data["data"]["id"] == "c1"
# ── CheckpointAgent ───────────────────────────────────────────────────
class TestCheckpointAgent:
def test_name(self) -> None:
assert CheckpointAgent().get_name() == "checkpoint_agent"
def test_description(self) -> None:
assert CheckpointAgent().get_description() == "Manages project checkpoints (milestones): list, create, update, delete"
def test_get_tools_count(self) -> None:
assert len(CheckpointAgent().get_tools()) == 4
def test_tool_names(self) -> None:
names = {t.name for t in CheckpointAgent().get_tools()}
assert names == {"list_checkpoints", "create_checkpoint", "update_checkpoint", "delete_checkpoint"}
@pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.checkpoint_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("No checkpoints found.")
result = await CheckpointAgent().handle("list checkpoints", {})
assert result == "No checkpoints found."
@pytest.mark.asyncio
async def test_handle_with_create_tool_call(self) -> None:
with patch("app.agents.checkpoint_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call(
"create_checkpoint",
{"project_id": "p1", "title": "MVP Launch", "date": 1700000000000},
"Checkpoint 'MVP Launch' created.",
)
result = await CheckpointAgent().handle("add MVP checkpoint", {})
assert result == "Checkpoint 'MVP Launch' created."
@pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.checkpoint_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await CheckpointAgent().handle("show milestones", {})
assert isinstance(result, str)
class TestCheckpointAgentTools:
@pytest.mark.asyncio
async def test_list_checkpoints_no_project(self) -> None:
from app.agents.checkpoint_agent import list_checkpoints
result = await list_checkpoints.ainvoke({})
data = json.loads(result)
assert data["action"] == "list"
assert data["table"] == "checkpoints"
assert data["filters"]["projectId"] is None
@pytest.mark.asyncio
async def test_list_checkpoints_with_project(self) -> None:
from app.agents.checkpoint_agent import list_checkpoints
result = await list_checkpoints.ainvoke({"project_id": "p1"})
data = json.loads(result)
assert data["filters"]["projectId"] == "p1"
@pytest.mark.asyncio
async def test_create_checkpoint(self) -> None:
from app.agents.checkpoint_agent import create_checkpoint
result = await create_checkpoint.ainvoke({
"project_id": "p1",
"title": "Beta release",
"date": 1700000000000,
})
data = json.loads(result)
assert data["action"] == "create_record"
assert data["table"] == "checkpoints"
assert data["data"]["projectId"] == "p1"
assert data["data"]["title"] == "Beta release"
assert data["data"]["date"] == 1700000000000
@pytest.mark.asyncio
async def test_create_checkpoint_ai_suggested(self) -> None:
from app.agents.checkpoint_agent import create_checkpoint
result = await create_checkpoint.ainvoke({
"project_id": "p1",
"title": "Review",
"date": 1700000000000,
"is_ai_suggested": 1,
})
data = json.loads(result)
assert data["data"]["isAiSuggested"] == 1
assert data["data"]["isApproved"] == 0
@pytest.mark.asyncio
async def test_update_checkpoint_approve(self) -> None:
from app.agents.checkpoint_agent import update_checkpoint
result = await update_checkpoint.ainvoke({
"checkpoint_id": "c1",
"is_approved": 1,
})
data = json.loads(result)
assert data["action"] == "update_record"
assert data["data"]["id"] == "c1"
assert data["data"]["updates"]["isApproved"] == 1
@pytest.mark.asyncio
async def test_update_checkpoint_empty_updates(self) -> None:
from app.agents.checkpoint_agent import update_checkpoint
result = await update_checkpoint.ainvoke({"checkpoint_id": "c1"})
data = json.loads(result)
assert data["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_checkpoint(self) -> None:
from app.agents.checkpoint_agent import delete_checkpoint
result = await delete_checkpoint.ainvoke({"checkpoint_id": "c1"})
data = json.loads(result)
assert data["action"] == "delete_record"
assert data["table"] == "checkpoints"
assert data["data"]["id"] == "c1"
# ── ProjectAgent ──────────────────────────────────────────────────────
class TestProjectAgent:
def test_name(self) -> None:
assert ProjectAgent().get_name() == "project_agent"
def test_description(self) -> None:
assert ProjectAgent().get_description() == "Manages projects: list, get, create, update, archive, delete"
def test_get_tools_count(self) -> None:
assert len(ProjectAgent().get_tools()) == 6
def test_tool_names(self) -> None:
names = {t.name for t in ProjectAgent().get_tools()}
assert names == {
"list_projects",
"list_all_projects",
"get_project",
"create_project",
"update_project",
"delete_project",
}
@pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.project_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Project Alpha is active.")
result = await ProjectAgent().handle("show my projects", {})
assert result == "Project Alpha is active."
@pytest.mark.asyncio
async def test_handle_with_create_project_tool_call(self) -> None:
with patch("app.agents.project_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call(
"create_project",
{"name": "Pippo"},
"Project 'Pippo' created.",
)
result = await ProjectAgent().handle("create project Pippo", {})
assert result == "Project 'Pippo' created."
@pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.project_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await ProjectAgent().handle("archive old project", {})
assert isinstance(result, str)
class TestProjectAgentTools:
@pytest.mark.asyncio
async def test_list_projects_defaults(self) -> None:
from app.agents.project_agent import list_projects
result = await list_projects.ainvoke({})
data = json.loads(result)
assert data["action"] == "list"
assert data["table"] == "projects"
assert data["filters"]["includeArchived"] is False
@pytest.mark.asyncio
async def test_list_projects_include_archived(self) -> None:
from app.agents.project_agent import list_projects
result = await list_projects.ainvoke({"include_archived": 1})
data = json.loads(result)
assert data["filters"]["includeArchived"] is True
@pytest.mark.asyncio
async def test_list_all_projects(self) -> None:
from app.agents.project_agent import list_all_projects
result = await list_all_projects.ainvoke({})
data = json.loads(result)
assert data["action"] == "list_all"
assert data["table"] == "projects"
@pytest.mark.asyncio
async def test_get_project(self) -> None:
from app.agents.project_agent import get_project
result = await get_project.ainvoke({"project_id": "p1"})
data = json.loads(result)
assert data["action"] == "get"
assert data["table"] == "projects"
assert data["data"]["id"] == "p1"
@pytest.mark.asyncio
async def test_create_project_name_only(self) -> None:
from app.agents.project_agent import create_project
result = await create_project.ainvoke({"name": "Alpha"})
data = json.loads(result)
assert data["action"] == "create_record"
assert data["data"]["name"] == "Alpha"
assert data["data"]["clientId"] is None
@pytest.mark.asyncio
async def test_create_project_with_client(self) -> None:
from app.agents.project_agent import create_project
result = await create_project.ainvoke({"name": "Beta", "client_id": "cl1"})
data = json.loads(result)
assert data["data"]["clientId"] == "cl1"
@pytest.mark.asyncio
async def test_update_project_archive(self) -> None:
from app.agents.project_agent import update_project
result = await update_project.ainvoke({"project_id": "p1", "status": "archived"})
data = json.loads(result)
assert data["action"] == "update_record"
assert data["data"]["id"] == "p1"
assert data["data"]["updates"]["status"] == "archived"
@pytest.mark.asyncio
async def test_update_project_empty_updates(self) -> None:
from app.agents.project_agent import update_project
result = await update_project.ainvoke({"project_id": "p1"})
data = json.loads(result)
assert data["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_project(self) -> None:
from app.agents.project_agent import delete_project
result = await delete_project.ainvoke({"project_id": "p1"})
data = json.loads(result)
assert data["action"] == "delete_record"
assert data["data"]["id"] == "p1"
# ── NoteAgent ─────────────────────────────────────────────────────────
class TestNoteAgent:
def test_name(self) -> None:
assert NoteAgent().get_name() == "note_agent"
def test_description(self) -> None:
assert NoteAgent().get_description() == "Manages notes: list, get, create, update, delete"
def test_get_tools_count(self) -> None:
assert len(NoteAgent().get_tools()) == 5
def test_tool_names(self) -> None:
names = {t.name for t in NoteAgent().get_tools()}
assert names == {"list_notes", "get_note", "create_note", "update_note", "delete_note"}
@pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.note_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Note created.")
result = await NoteAgent().handle("create a note", {})
assert result == "Note created."
@pytest.mark.asyncio
async def test_handle_with_create_note_tool_call(self) -> None:
with patch("app.agents.note_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call(
"create_note",
{"title": "Daily log", "content": "# Today\nAll good."},
"Note 'Daily log' created.",
)
result = await NoteAgent().handle("log today's progress", {})
assert result == "Note 'Daily log' created."
@pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.note_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await NoteAgent().handle("show notes", {})
assert isinstance(result, str)
class TestNoteAgentTools:
@pytest.mark.asyncio
async def test_list_notes_no_project(self) -> None:
from app.agents.note_agent import list_notes
result = await list_notes.ainvoke({})
data = json.loads(result)
assert data["action"] == "list"
assert data["table"] == "notes"
assert data["filters"]["projectId"] is None
@pytest.mark.asyncio
async def test_list_notes_with_project(self) -> None:
from app.agents.note_agent import list_notes
result = await list_notes.ainvoke({"project_id": "p1"})
data = json.loads(result)
assert data["filters"]["projectId"] == "p1"
@pytest.mark.asyncio
async def test_get_note(self) -> None:
from app.agents.note_agent import get_note
result = await get_note.ainvoke({"note_id": "n1"})
data = json.loads(result)
assert data["action"] == "get"
assert data["table"] == "notes"
assert data["data"]["id"] == "n1"
@pytest.mark.asyncio
async def test_create_note_minimal(self) -> None:
from app.agents.note_agent import create_note
result = await create_note.ainvoke({
"title": "Daily log",
"content": "# Today\nAll good.",
})
data = json.loads(result)
assert data["action"] == "create_record"
assert data["table"] == "notes"
assert data["data"]["title"] == "Daily log"
assert data["data"]["content"] == "# Today\nAll good."
assert data["data"]["projectId"] is None
@pytest.mark.asyncio
async def test_create_note_with_project(self) -> None:
from app.agents.note_agent import create_note
result = await create_note.ainvoke({
"title": "Sprint notes",
"content": "## Sprint 1",
"project_id": "p1",
})
data = json.loads(result)
assert data["data"]["projectId"] == "p1"
@pytest.mark.asyncio
async def test_update_note_content_only(self) -> None:
from app.agents.note_agent import update_note
result = await update_note.ainvoke({
"note_id": "n1",
"content": "# Updated content",
})
data = json.loads(result)
assert data["action"] == "update_record"
assert data["data"]["id"] == "n1"
assert data["data"]["updates"]["content"] == "# Updated content"
assert "title" not in data["data"]["updates"]
@pytest.mark.asyncio
async def test_update_note_empty_updates(self) -> None:
from app.agents.note_agent import update_note
result = await update_note.ainvoke({"note_id": "n1"})
data = json.loads(result)
assert data["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_note(self) -> None:
from app.agents.note_agent import delete_note
result = await delete_note.ainvoke({"note_id": "n1"})
data = json.loads(result)
assert data["action"] == "delete_record"
assert data["table"] == "notes"
assert data["data"]["id"] == "n1"

View File

@@ -1,243 +0,0 @@
"""Tests for backup routes: upload, download, history, delete.
Exercises the backup lifecycle through the FastAPI TestClient against the
in-memory SQLite test database and moto-mocked S3 bucket.
"""
from __future__ import annotations
import hashlib
from tests.conftest import auth_header, TEST_USER_IDS
# ── Helpers ───────────────────────────────────────────────────────────
_BLOB = b"encrypted-backup-blob-opaque-bytes"
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
_VERSION = 1
_TIMESTAMP = 1700000000000 # arbitrary ms timestamp
def _backup_headers(tier: str = "power", **overrides) -> dict[str, str]:
"""Return auth + backup metadata headers."""
headers = auth_header(tier)
headers["X-Backup-Version"] = str(overrides.get("version", _VERSION))
headers["X-Backup-Timestamp"] = str(overrides.get("timestamp", _TIMESTAMP))
headers["X-Backup-Checksum"] = overrides.get("checksum", _CHECKSUM)
headers["Content-Type"] = "application/octet-stream"
return headers
def _upload(client, tier="power", **overrides) -> "Response": # noqa: F821
"""Upload a backup blob and return the response."""
return client.put(
"/api/v1/backup",
content=overrides.pop("blob", _BLOB),
headers=_backup_headers(tier, **overrides),
)
# ── TestUploadBackup ──────────────────────────────────────────────────
class TestUploadBackup:
"""PUT /api/v1/backup"""
def test_upload_success(self, client, s3_bucket) -> None:
resp = _upload(client, tier="power")
assert resp.status_code == 200
assert resp.json() == {"ok": True}
def test_upload_creates_history_entry(self, client, s3_bucket) -> None:
_upload(client, tier="power")
history = client.get(
"/api/v1/backup/history", headers=auth_header("power")
).json()
assert len(history) == 1
assert history[0]["version"] == _VERSION
assert history[0]["timestamp"] == _TIMESTAMP
assert history[0]["checksum"] == _CHECKSUM
def test_upload_bad_checksum(self, client, s3_bucket) -> None:
resp = _upload(client, tier="power", checksum="0" * 64)
assert resp.status_code == 400
def test_upload_free_tier_blocked(self, client, s3_bucket) -> None:
"""Free tier has backup_gb=0 → should return 402."""
resp = _upload(client, tier="free")
assert resp.status_code == 402
def test_upload_pro_tier_allowed(self, client, s3_bucket) -> None:
"""Pro tier has backup_gb=5 → small blob succeeds."""
resp = _upload(client, tier="pro")
assert resp.status_code == 200
# ── TestDownloadBackup ────────────────────────────────────────────────
class TestDownloadBackup:
"""GET /api/v1/backup"""
def test_download_latest(self, client, s3_bucket) -> None:
_upload(client, tier="power")
resp = client.get("/api/v1/backup", headers=auth_header("power"))
assert resp.status_code == 200
assert resp.content == _BLOB
assert resp.headers["X-Checksum"] == _CHECKSUM
assert resp.headers["X-Backup-Version"] == str(_VERSION)
def test_download_no_backup_returns_404(self, client, s3_bucket) -> None:
resp = client.get("/api/v1/backup", headers=auth_header("power"))
assert resp.status_code == 404
def test_download_if_modified_since_returns_304(self, client, s3_bucket) -> None:
"""When If-Modified-Since is after the backup timestamp → 304."""
_upload(client, tier="power", timestamp=1700000000000)
resp = client.get(
"/api/v1/backup",
headers={
**auth_header("power"),
"If-Modified-Since": "Thu, 01 Jan 2099 00:00:00 GMT",
},
)
assert resp.status_code == 304
def test_download_if_modified_since_returns_200(self, client, s3_bucket) -> None:
"""When If-Modified-Since is before the backup timestamp → serve blob."""
_upload(client, tier="power", timestamp=1700000000000)
resp = client.get(
"/api/v1/backup",
headers={
**auth_header("power"),
"If-Modified-Since": "Thu, 01 Jan 2000 00:00:00 GMT",
},
)
assert resp.status_code == 200
assert resp.content == _BLOB
def test_download_multiple_returns_latest(self, client, s3_bucket) -> None:
"""When multiple backups exist, GET returns the one with the highest timestamp."""
_upload(client, tier="power", timestamp=1000)
blob2 = b"second-encrypted-backup"
checksum2 = hashlib.sha256(blob2).hexdigest()
_upload(client, tier="power", timestamp=2000, blob=blob2, checksum=checksum2)
resp = client.get("/api/v1/backup", headers=auth_header("power"))
assert resp.status_code == 200
assert resp.content == blob2
# ── TestBackupHistory ─────────────────────────────────────────────────
class TestBackupHistory:
"""GET /api/v1/backup/history"""
def test_history_empty(self, client, s3_bucket) -> None:
resp = client.get("/api/v1/backup/history", headers=auth_header("power"))
assert resp.status_code == 200
assert resp.json() == []
def test_history_returns_entries(self, client, s3_bucket) -> None:
_upload(client, tier="power", timestamp=1000)
_upload(client, tier="power", timestamp=2000)
history = client.get(
"/api/v1/backup/history", headers=auth_header("power")
).json()
assert len(history) == 2
# Ordered by timestamp descending
assert history[0]["timestamp"] == 2000
assert history[1]["timestamp"] == 1000
def test_history_isolated_per_user(self, client, s3_bucket) -> None:
"""One user's backups should not appear in another user's history."""
_upload(client, tier="power")
resp = client.get("/api/v1/backup/history", headers=auth_header("team"))
assert resp.json() == []
# ── TestDeleteBackup ──────────────────────────────────────────────────
class TestDeleteBackup:
"""DELETE /api/v1/backup/{backup_id}"""
def _get_backup_id(self, client, tier="power") -> str:
"""Upload a backup and return its DB id from history."""
_upload(client, tier=tier)
client.get(
"/api/v1/backup/history", headers=auth_header(tier)
).json()
# History returns BackupMetadata schema which doesn't have `id`.
# We need to look it up via a different means.
# Since there's only 1 backup, find via history length.
# Actually the schema doesn't return id — let's verify via re-download.
# We'll use a workaround: upload, then list history to confirm it exists,
# then try to delete — but we need the id...
# Let's check if history includes an id field.
# The schema is: version, timestamp, checksum, chunk_count — no id.
# We'll need to query the DB directly or use a known ID.
# For testing, we'll search history then use the DB.
return None # pragma: no cover — overridden below
def test_delete_success(self, client, s3_bucket, db_session) -> None:
_upload(client, tier="power")
# Discover the backup_id via direct DB query
import asyncio
from sqlalchemy import select
from app.models import BackupMetadata
async def _get_id():
result = await db_session.execute(
select(BackupMetadata.id).where(
BackupMetadata.user_id == TEST_USER_IDS["power"]
)
)
return result.scalar_one()
backup_id = asyncio.get_event_loop().run_until_complete(_get_id())
resp = client.delete(
f"/api/v1/backup/{backup_id}", headers=auth_header("power")
)
assert resp.status_code == 200
assert resp.json() == {"ok": True}
# History should now be empty
history = client.get(
"/api/v1/backup/history", headers=auth_header("power")
).json()
assert history == []
def test_delete_nonexistent(self, client, s3_bucket) -> None:
resp = client.delete(
"/api/v1/backup/no-such-id", headers=auth_header("power")
)
assert resp.status_code == 404
def test_delete_other_users_backup(self, client, s3_bucket, db_session) -> None:
"""Cannot delete another user's backup (ownership check returns 404)."""
_upload(client, tier="power")
import asyncio
from sqlalchemy import select
from app.models import BackupMetadata
async def _get_id():
result = await db_session.execute(
select(BackupMetadata.id).where(
BackupMetadata.user_id == TEST_USER_IDS["power"]
)
)
return result.scalar_one()
backup_id = asyncio.get_event_loop().run_until_complete(_get_id())
# team user tries to delete power user's backup → 404
resp = client.delete(
f"/api/v1/backup/{backup_id}", headers=auth_header("team")
)
assert resp.status_code == 404

184
tests/test_classify_file.py Normal file
View File

@@ -0,0 +1,184 @@
"""Unit tests for Step 1 file classification (_classify_file).
These tests call the real LLM so they require OPENAI_API_KEY / LLM env vars.
Run with: pytest tests/test_classify_file.py -v
To run a quick manual check against a real file without the full UI:
python -m tests.test_classify_file <path/to/file.txt> [project_name...]
"""
from __future__ import annotations
import asyncio
import sys
import pytest
from app.core.agent_runner import _classify_file
# ── Fixtures ──────────────────────────────────────────────────────────────
PROJECTS_SAMPLE = [
{
"id": "aaaa-0001-0000-0000-000000000001",
"name": "ARPA Sicilia POC",
"status": "active",
"aiSummary": "Proof of concept for AI features targeting ARPA Sicilia agency.",
},
{
"id": "bbbb-0002-0000-0000-000000000002",
"name": "SNAM AI Meeting Prep",
"status": "active",
"aiSummary": "AI-assisted preparation of meeting materials for SNAM.",
},
{
"id": "cccc-0003-0000-0000-000000000003",
"name": "SFERA+ Wave 2",
"status": "active",
"aiSummary": "Second wave of the SFERA+ whitelist project.",
},
]
ARPA_EMAIL = """\
to: roberto.musso@hpe.com; luca.tondin@hpecds.com
isImportance: normal
hasAttachment: True
---
## Body
Buongiorno,
In riferimento alla riunione di ieri sul POC ARPA Sicilia, vi invio il riassunto
dei deliverable concordati:
- Preparare demo entro il 30 marzo
- Condividere documentazione tecnica con il team ARPA
- Fissare call di follow-up la prossima settimana
Cordiali saluti
Roberto Marchetti
"""
SNAM_EMAIL = """\
to: roberto.musso@hpe.com
isImportance: high
hasAttachment: False
---
## Body
Ciao,
ti invio l'agenda per la riunione SNAM di domani.
Per favore conferma la tua presenza.
"""
UNRELATED_EMAIL = """\
to: roberto.musso@hpe.com
isImportance: normal
---
## Body
Benvenuto nel programma HPE Employee Learning Series.
Completa la formazione richiesta entro la fine del trimestre.
"""
# ── Tests ─────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_classify_arpa_matches_existing():
project_id, domains, new_name = await _classify_file(
file_path="arpa_email.txt",
file_content=ARPA_EMAIL,
projects=PROJECTS_SAMPLE,
config_data_types=["tasks", "notes", "timelines"],
)
assert project_id == "aaaa-0001-0000-0000-000000000001", (
f"Expected ARPA project, got project_id={project_id!r} new_name={new_name!r}"
)
assert new_name is None
@pytest.mark.asyncio
async def test_classify_snam_matches_existing():
project_id, domains, new_name = await _classify_file(
file_path="snam_email.txt",
file_content=SNAM_EMAIL,
projects=PROJECTS_SAMPLE,
config_data_types=["tasks", "notes"],
)
assert project_id == "bbbb-0002-0000-0000-000000000002", (
f"Expected SNAM project, got project_id={project_id!r} new_name={new_name!r}"
)
@pytest.mark.asyncio
async def test_classify_unrelated_returns_new():
project_id, domains, new_name = await _classify_file(
file_path="learning_email.txt",
file_content=UNRELATED_EMAIL,
projects=PROJECTS_SAMPLE,
config_data_types=["tasks", "notes"],
)
assert project_id == "new"
assert new_name is not None # LLM should suggest a name
@pytest.mark.asyncio
async def test_classify_empty_file_returns_new():
project_id, domains, new_name = await _classify_file(
file_path="empty.txt",
file_content=" ",
projects=PROJECTS_SAMPLE,
config_data_types=["tasks"],
)
assert project_id == "new"
@pytest.mark.asyncio
async def test_classify_no_projects_returns_new():
project_id, domains, new_name = await _classify_file(
file_path="arpa_email.txt",
file_content=ARPA_EMAIL,
projects=[],
config_data_types=["tasks", "notes"],
)
assert project_id == "new"
assert new_name is not None
# ── CLI quick-test runner ─────────────────────────────────────────────────
async def _cli_test(file_path: str, project_names: list[str]) -> None:
"""Run Step 1 classification against a real file from the CLI."""
import json
from pathlib import Path
content = Path(file_path).read_text(encoding="utf-8", errors="replace")
projects = [
{"id": f"test-id-{i:04d}", "name": name, "status": "active", "aiSummary": ""}
for i, name in enumerate(project_names)
]
print(f"\nClassifying: {file_path}")
print(f"Projects in context: {[p['name'] for p in projects]}\n")
project_id, domains, new_name = await _classify_file(
file_path=file_path,
file_content=content,
projects=projects,
config_data_types=["tasks", "notes", "timelines"],
)
result = {
"project_id": project_id,
"matched_name": next((p["name"] for p in projects if p["id"] == project_id), None),
"new_project_name": new_name,
"domains": domains,
}
print(json.dumps(result, indent=2, ensure_ascii=False))
if __name__ == "__main__":
if len(sys.argv) < 2:
print("Usage: python -m tests.test_classify_file <file_path> [project_name ...]")
sys.exit(1)
asyncio.run(_cli_test(sys.argv[1], sys.argv[2:]))

288
tests/test_deep_agent.py Normal file
View File

@@ -0,0 +1,288 @@
"""Unit tests for single-agent deep_agent flows with mocked tool results."""
from __future__ import annotations
from datetime import date, timedelta
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from langchain_core.messages import AIMessage, ToolMessage
from app.core.deep_agent import (
_infer_floating_domain,
_normalize_tagged_list_lines,
run_floating,
run_floating_stream,
run_home,
)
class _FakeTool:
name = "list_tasks"
async def ainvoke(self, args):
return {"rows": [{"id": "task-1", "title": "Mock Task"}], "echo": args}
class _FakeLLM:
def __init__(self) -> None:
self.agent_calls = 0
def bind_tools(self, _tools):
return self
async def ainvoke(self, messages):
system_prompt = str(getattr(messages[0], "content", "")) if messages else ""
if "strict domain classifier" in system_prompt:
return AIMessage(content='{"type":"timeline","id":"tl-1","section":null}')
self.agent_calls += 1
if self.agent_calls == 1:
return AIMessage(
content="",
tool_calls=[
{
"id": "call-1",
"name": "list_tasks",
"args": {"project_id": "proj-1"},
}
],
)
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
assert tool_messages, "Expected at least one tool message"
return AIMessage(content=f"Final answer from mocked tool: {tool_messages[-1].content}")
async def astream(self, _messages):
yield SimpleNamespace(content="stream-")
yield SimpleNamespace(content="ok")
@pytest.mark.asyncio
async def test_run_home_uses_mocked_tool_result():
fake_llm = _FakeLLM()
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
):
out = await run_home("user-1", "list my tasks", {})
assert "Final answer from mocked tool" in out
assert "Mock Task" in out
@pytest.mark.asyncio
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
fake_llm = _FakeLLM()
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
):
events = []
async for event in run_floating_stream(
"user-1",
"show me timeline updates",
{"scope": {"type": "timeline", "id": "tl-1"}},
):
events.append(event)
assert events[0] == (
"floating_domain",
{"type": "timeline", "id": "tl-1", "section": None},
)
assert ("token", "stream-") in events
assert ("token", "ok") in events
@pytest.mark.asyncio
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
class _ClassifierOnlyLLM:
async def ainvoke(self, _messages):
return AIMessage(
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
)
with patch("app.core.deep_agent.get_llm", return_value=_ClassifierOnlyLLM()):
domain = await _infer_floating_domain(
"Quali sono i miei task per il progetto X",
{
"scope": {"type": "timeline"},
"resolved_project_id": "213213-312321-312312-421321",
},
)
assert domain == {
"type": "project",
"id": "213213-312321-312312-421321",
"section": "task",
}
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
raw = (
"Certo!\n\n"
"1. **Task A** — priorita high <task>[task-1]</task>\n"
"2. **Task B** — priorita medium <task>[task-2]</task>\n"
)
out = _normalize_tagged_list_lines(raw, "quali sono le prossime attivita?")
assert "<task>[task-1]</task>" in out
assert "<task>[task-2]</task>" in out
assert "Task A" not in out
assert "Task B" not in out
def test_normalize_tagged_list_lines_filters_upcoming_timeline_query_to_current_month_future_only():
today = date.today()
tomorrow = today + timedelta(days=1)
yesterday = today - timedelta(days=1)
next_month = (today.replace(day=28) + timedelta(days=5)).replace(day=1)
raw = "\n".join(
[
f"- Milestone old — {yesterday.strftime('%d/%m/%Y')} <timeline>[tl-old]</timeline>",
f"- Milestone next — {tomorrow.strftime('%d/%m/%Y')} <timeline>[tl-next]</timeline>",
f"- Milestone future — {next_month.strftime('%d/%m/%Y')} <timeline>[tl-future]</timeline>",
]
)
out = _normalize_tagged_list_lines(raw, "invece i miei eventi prossimi?")
assert "<timeline>[tl-next]</timeline>" in out
assert "<timeline>[tl-old]</timeline>" not in out
assert "<timeline>[tl-future]</timeline>" not in out
@pytest.mark.asyncio
async def test_run_floating_strips_xml_like_tags_from_final_text():
fake_llm = _FakeLLM()
async def _fake_run_single_agent(**_kwargs):
return (
"Hai 1 task:\\n"
"Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
)
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
):
text, _domain = await run_floating(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
)
assert "<task>" not in text
assert "</task>" not in text
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in text
@pytest.mark.asyncio
async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
fake_llm = _FakeLLM()
async def _fake_stream(**_kwargs):
yield "token", "Hai 1 task:\\n"
yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
):
events = []
async for event in run_floating_stream(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
):
events.append(event)
token_events = [str(data) for event_type, data in events if event_type == "token"]
combined = "".join(token_events)
assert "<task>" not in combined
assert "</task>" not in combined
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in combined
@pytest.mark.asyncio
async def test_run_floating_stream_falls_back_to_final_response_content_when_astream_is_empty():
class _NoChunkLLM:
def __init__(self) -> None:
self.calls = 0
def bind_tools(self, _tools):
return self
async def ainvoke(self, _messages):
self.calls += 1
if self.calls == 1:
return AIMessage(
content="",
tool_calls=[
{
"id": "call-1",
"name": "list_tasks",
"args": {},
}
],
)
return AIMessage(content="No notes found.")
async def astream(self, _messages):
if False:
yield None
with patch("app.core.deep_agent.get_llm", return_value=_NoChunkLLM()), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
):
events = []
async for event in run_floating_stream(
"user-1",
"quali sono le note?",
{"scope": {"type": "note"}},
):
events.append(event)
assert events[0][0] == "floating_domain"
assert ("token", "No notes found.") in events
@pytest.mark.asyncio
async def test_run_floating_returns_fallback_when_sanitization_would_empty_text():
fake_llm = _FakeLLM()
async def _fake_run_single_agent(**_kwargs):
return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
):
text, _domain = await run_floating(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
)
assert text == "No results found."
@pytest.mark.asyncio
async def test_run_floating_stream_returns_fallback_when_sanitization_would_empty_text():
fake_llm = _FakeLLM()
async def _fake_stream(**_kwargs):
yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
):
events = []
async for event in run_floating_stream(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
):
events.append(event)
assert ("token", "No results found.") in events

362
tests/test_device_ws.py Normal file
View File

@@ -0,0 +1,362 @@
"""Tests for Step 3.3: DeviceConnectionManager and device WS endpoint.
Coverage:
Unit tests — DeviceConnectionManager register/unregister/is_online/
get_ws/send_frame/pending-call round-trip/agent-data queue
Integration — /api/v1/ws/device endpoint via TestClient WebSocket:
auth rejection, happy-path connect, tool_result dispatch,
agent_data queue routing, agent_complete sentinel, disconnect
cleanup (AgentRunLog marked as error)
"""
from __future__ import annotations
import asyncio
import json
import uuid
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from app.core.device_manager import DeviceConnection, DeviceConnectionManager
from app.db import get_session
from app.main import app
from app.models import AgentRunLog
from tests.conftest import TEST_USER_IDS, auth_header, make_jwt
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_FREE_UID = TEST_USER_IDS["free"]
_PRO_UID = TEST_USER_IDS["pro"]
def _device_hello(device_id: str = "dev-001", agent_ids: list[str] | None = None) -> str:
return json.dumps(
{"type": "device_hello", "device_id": device_id, "agent_ids": agent_ids or []}
)
# ---------------------------------------------------------------------------
# DB override (shared across integration tests)
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _override_db(db_session):
"""Route all get_session calls to the test SQLite session."""
async def _gen():
yield db_session
app.dependency_overrides[get_session] = _gen
yield
app.dependency_overrides.pop(get_session, None)
# ---------------------------------------------------------------------------
# DeviceConnectionManager unit tests
# ---------------------------------------------------------------------------
@pytest.fixture()
def manager() -> DeviceConnectionManager:
"""Fresh manager instance for each test."""
return DeviceConnectionManager()
@pytest.fixture()
def mock_ws() -> MagicMock:
ws = MagicMock()
ws.send_text = AsyncMock()
return ws
def test_manager_register_and_is_online(manager, mock_ws):
assert not manager.is_online("user1")
manager.register("user1", "dev-A", mock_ws)
assert manager.is_online("user1")
assert manager.is_online("user1", "dev-A")
assert not manager.is_online("user1", "dev-B")
def test_manager_get_ws_returns_none_when_offline(manager):
assert manager.get_ws("no-such-user") is None
def test_manager_unregister(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
assert manager.is_online("user1")
manager.unregister("user1")
assert not manager.is_online("user1")
assert manager.get_ws("user1") is None
def test_manager_unregister_unknown_is_noop(manager):
# Must not raise.
manager.unregister("ghost")
def test_manager_replace_connection_cancels_old_futures(manager):
ws_a = MagicMock()
ws_a.send_text = AsyncMock()
ws_b = MagicMock()
ws_b.send_text = AsyncMock()
# Create event loop context for Future.
loop = asyncio.new_event_loop()
try:
async def _run():
manager.register("user1", "dev-A", ws_a)
fut = manager.create_pending_call("user1", "call-1")
# Replace connection — old future should be cancelled.
manager.register("user1", "dev-B", ws_b)
assert fut.cancelled()
loop.run_until_complete(_run())
finally:
loop.close()
@pytest.mark.asyncio
async def test_manager_send_frame(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
await manager.send_frame("user1", {"type": "ping"})
mock_ws.send_text.assert_called_once_with(json.dumps({"type": "ping"}))
@pytest.mark.asyncio
async def test_manager_send_frame_raises_when_offline(manager):
with pytest.raises(RuntimeError, match="not connected"):
await manager.send_frame("ghost", {"type": "ping"})
@pytest.mark.asyncio
async def test_manager_pending_call_round_trip(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
fut = manager.create_pending_call("user1", "call-42")
result = {"type": "tool_result", "id": "call-42", "rows": [{"id": "row1"}]}
manager.resolve_pending_call("user1", "call-42", result)
assert fut.done()
assert await fut == result
@pytest.mark.asyncio
async def test_manager_resolve_unknown_call_is_noop(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
# Should not raise.
manager.resolve_pending_call("user1", "no-such-call", {})
@pytest.mark.asyncio
async def test_manager_unregister_cancels_pending_calls(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
fut = manager.create_pending_call("user1", "call-1")
manager.unregister("user1")
assert fut.cancelled()
@pytest.mark.asyncio
async def test_manager_agent_data_queue(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
q = manager.get_agent_data_queue("user1", "run-xyz")
# Put a frame and get it back.
frame = {"type": "agent_data", "run_id": "run-xyz", "files": []}
await q.put(frame)
assert await q.get() == frame
@pytest.mark.asyncio
async def test_manager_agent_data_queue_creates_once(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
q1 = manager.get_agent_data_queue("user1", "run-1")
q2 = manager.get_agent_data_queue("user1", "run-1")
assert q1 is q2
@pytest.mark.asyncio
async def test_manager_agent_data_queue_raises_when_offline(manager):
with pytest.raises(RuntimeError, match="not connected"):
manager.get_agent_data_queue("ghost", "run-1")
@pytest.mark.asyncio
async def test_manager_cleanup_agent_data_queue(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
manager.get_agent_data_queue("user1", "run-1")
manager.cleanup_agent_data_queue("user1", "run-1")
# After cleanup a new queue is created (not the same object).
q_new = manager.get_agent_data_queue("user1", "run-1")
assert q_new is not None
# ---------------------------------------------------------------------------
# Integration tests — /api/v1/ws/device endpoint
# ---------------------------------------------------------------------------
def test_ws_device_rejects_without_token(client):
with pytest.raises(Exception):
# TestClient will raise or close when the server rejects.
with client.websocket_connect("/api/v1/ws/device") as ws:
ws.receive_text()
def test_ws_device_rejects_invalid_token(client):
with pytest.raises(Exception):
with client.websocket_connect("/api/v1/ws/device?token=badtoken") as ws:
ws.receive_text()
def test_ws_device_happy_path(client):
"""Connect, send device_hello, receive ping, then close."""
token = make_jwt(tier="free")
# Patch the heartbeat sleep so the test doesn't block 30 s.
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.01):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(_device_hello("dev-001"))
# Next message from server should be a heartbeat ping (interval=0.01s).
msg = ws.receive_text()
data = json.loads(msg)
assert data["type"] == "ping"
# Close gracefully.
ws.close()
def test_ws_device_invalid_first_frame_closes(client):
"""Non-device_hello first frame should close the connection."""
token = make_jwt(tier="free")
with pytest.raises(Exception):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({"type": "chat_request", "message": "hi"}))
ws.receive_text() # server should close after bad frame
def test_ws_device_tool_result_dispatched(client):
"""tool_result frame is routed to the DeviceConnectionManager."""
token = make_jwt(tier="free")
user_id = TEST_USER_IDS["free"]
from app.core.device_manager import device_manager as dm
captured: list[dict] = []
original_resolve = dm.resolve_pending_call
def _spy(uid, call_id, result):
captured.append({"uid": uid, "call_id": call_id, "result": result})
original_resolve(uid, call_id, result)
with patch.object(dm, "resolve_pending_call", side_effect=_spy):
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(_device_hello("dev-001"))
# Send a tool_result frame.
ws.send_text(
json.dumps(
{
"type": "tool_result",
"id": "call-123",
"rows": [{"id": "task-1", "title": "Buy milk"}],
}
)
)
ws.close()
assert any(c["call_id"] == "call-123" for c in captured)
def test_ws_device_agent_data_enqueued(client):
"""agent_data frame is placed in the per-run queue by the message loop."""
from app.core.device_manager import device_manager as dm
token = make_jwt(tier="free")
user_id = TEST_USER_IDS["free"]
# Capture the queue object the message loop accesses.
captured_queue: list[asyncio.Queue] = []
original_get_queue = dm.get_agent_data_queue
def _spy_get_queue(uid, run_id):
q = original_get_queue(uid, run_id)
if not captured_queue:
captured_queue.append(q)
return q
with patch.object(dm, "get_agent_data_queue", side_effect=_spy_get_queue):
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(_device_hello("dev-001"))
ws.send_text(
json.dumps(
{
"type": "agent_data",
"run_id": "run-XYZ",
"files": [{"path": "/tmp/file.txt", "content": "hello"}],
}
)
)
ws.close()
# The queue should have received exactly one frame.
assert captured_queue, "queue was never accessed"
assert not captured_queue[0].empty()
def test_ws_device_disconnect_marks_run_logs_as_error(client, db_session):
"""On disconnect, _mark_runs_disconnected is called with the correct user_id."""
from app.api.routes import device_ws as _dws
token = make_jwt(tier="free")
user_id = TEST_USER_IDS["free"]
cleanup_calls: list[str] = []
async def _fake_cleanup(uid: str) -> None:
cleanup_calls.append(uid)
with patch.object(_dws, "_mark_runs_disconnected", side_effect=_fake_cleanup):
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(_device_hello("dev-001"))
ws.close()
assert user_id in cleanup_calls
@pytest.mark.asyncio
async def test_mark_runs_disconnected_updates_db(db_session):
"""_mark_runs_disconnected marks in-progress runs as error in the DB."""
from sqlalchemy import select
from app.api.routes.device_ws import _mark_runs_disconnected
from tests.conftest import _TestSessionLocal
user_id = TEST_USER_IDS["free"]
run_log = AgentRunLog(
id=str(uuid.uuid4()),
agent_id=str(uuid.uuid4()),
agent_type="local",
user_id=user_id,
status="running",
started_at=datetime.now(timezone.utc),
)
db_session.add(run_log)
await db_session.commit()
# Route the function to the same test-DB session factory.
with patch("app.api.routes.device_ws.async_session", _TestSessionLocal):
await _mark_runs_disconnected(user_id)
# Verify through the same session factory.
async with _TestSessionLocal() as s:
result = await s.execute(
select(AgentRunLog).where(AgentRunLog.id == run_log.id)
)
updated = result.scalar_one_or_none()
assert updated is not None
assert updated.status == "error"
assert updated.errors and "device disconnected" in updated.errors

View File

@@ -1,286 +0,0 @@
"""Tests for execution_plan: PromptTemplateRegistry, ExecutionPlanBuilder, PlanCache."""
from __future__ import annotations
import pytest
from app.core.execution_plan import (
ExecutionPlanBuilder,
PlanCache,
PromptTemplateRegistry,
plan_cache,
template_registry,
)
from app.schemas import ExecutionPlan
# ── PromptTemplateRegistry ────────────────────────────────────────────
class TestPromptTemplateRegistry:
def test_register_and_get(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_foo", "You are a foo agent.")
assert reg.get("tpl_foo") == "You are a foo agent."
def test_get_unknown_raises_key_error(self) -> None:
reg = PromptTemplateRegistry()
with pytest.raises(KeyError, match="tpl_missing"):
reg.get("tpl_missing")
def test_has_returns_true_for_registered(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_x", "prompt text")
assert reg.has("tpl_x") is True
def test_has_returns_false_for_unregistered(self) -> None:
reg = PromptTemplateRegistry()
assert reg.has("tpl_missing") is False
def test_list_ids_returns_all_registered_ids(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_a", "a")
reg.register("tpl_b", "b")
assert set(reg.list_ids()) == {"tpl_a", "tpl_b"}
def test_list_ids_does_not_return_prompt_text(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_secret", "top secret prompt")
ids = reg.list_ids()
assert "top secret prompt" not in ids
def test_overwrite_existing_template(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_x", "v1")
reg.register("tpl_x", "v2")
assert reg.get("tpl_x") == "v2"
def test_empty_registry_has_no_ids(self) -> None:
reg = PromptTemplateRegistry()
assert reg.list_ids() == []
# ── ExecutionPlanBuilder ──────────────────────────────────────────────
class TestExecutionPlanBuilder:
def test_builds_empty_plan(self) -> None:
plan = ExecutionPlanBuilder("task_agent").build()
assert plan.agent == "task_agent"
assert plan.steps == []
def test_add_step_basic(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_step("create_task", {"priority": "high"})
.build()
)
assert len(plan.steps) == 1
assert plan.steps[0].action == "create_task"
assert plan.steps[0].variables == {"priority": "high"}
assert plan.steps[0].prompt_template is None
assert plan.steps[0].data_from_step is None
def test_add_step_no_params(self) -> None:
plan = ExecutionPlanBuilder("task_agent").add_step("fetch").build()
assert plan.steps[0].variables is None
def test_add_llm_step(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_llm_step("tpl_task_default", {"message": "hi"})
.build()
)
assert plan.steps[0].action == "llm"
assert plan.steps[0].prompt_template == "tpl_task_default"
assert plan.steps[0].variables == {"message": "hi"}
def test_add_llm_step_no_variables(self) -> None:
plan = ExecutionPlanBuilder("task_agent").add_llm_step("tpl_x").build()
assert plan.steps[0].variables is None
def test_add_data_step(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_step("fetch_data")
.add_data_step("transform", data_from_step=0)
.build()
)
assert plan.steps[1].action == "transform"
assert plan.steps[1].data_from_step == 0
def test_fluent_chaining_returns_builder(self) -> None:
builder = ExecutionPlanBuilder("analytics_agent")
result = builder.add_step("a")
assert result is builder
def test_fluent_chain_multiple_steps(self) -> None:
plan = (
ExecutionPlanBuilder("analytics_agent")
.add_llm_step("tpl_analytics_default")
.add_step("format_output")
.add_data_step("store", data_from_step=0)
.build()
)
assert len(plan.steps) == 3
def test_build_validates_data_from_step_out_of_range(self) -> None:
with pytest.raises(ValueError, match="data_from_step"):
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=5).build()
def test_build_validates_data_from_step_self_reference(self) -> None:
"""data_from_step=0 on the first step (index 0) is invalid."""
with pytest.raises(ValueError, match="data_from_step"):
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=0).build()
def test_build_validates_data_from_step_negative(self) -> None:
with pytest.raises(ValueError, match="data_from_step"):
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=-1).build()
def test_valid_data_from_step_at_index_two(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_step("step0")
.add_step("step1")
.add_data_step("step2", data_from_step=1)
.build()
)
assert plan.steps[2].data_from_step == 1
def test_data_from_step_zero_valid_at_index_one(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_step("step0")
.add_data_step("step1", data_from_step=0)
.build()
)
assert plan.steps[1].data_from_step == 0
def test_build_returns_new_plan_each_call(self) -> None:
builder = ExecutionPlanBuilder("task_agent").add_step("do_thing")
plan1 = builder.build()
plan2 = builder.build()
assert plan1 is not plan2
assert plan1.steps == plan2.steps
def test_plan_is_execution_plan_instance(self) -> None:
plan = ExecutionPlanBuilder("task_agent").build()
assert isinstance(plan, ExecutionPlan)
# ── PlanCache ─────────────────────────────────────────────────────────
class TestPlanCache:
def _plan(self, agent: str = "a") -> ExecutionPlan:
return ExecutionPlanBuilder(agent).build()
def test_cache_and_get(self) -> None:
cache = PlanCache()
plan = self._plan()
cache.cache_plan("key1", plan)
assert cache.get_plan("key1") is plan
def test_get_missing_returns_none(self) -> None:
cache = PlanCache()
assert cache.get_plan("nonexistent") is None
def test_get_all_playbooks_empty(self) -> None:
cache = PlanCache()
assert cache.get_all_playbooks() == []
def test_get_all_playbooks_returns_all_stored(self) -> None:
cache = PlanCache()
p1, p2 = self._plan("a"), self._plan("b")
cache.cache_plan("k1", p1)
cache.cache_plan("k2", p2)
playbooks = cache.get_all_playbooks()
assert len(playbooks) == 2
assert p1 in playbooks
assert p2 in playbooks
def test_lru_evicts_oldest_entry(self) -> None:
cache = PlanCache(maxsize=2)
p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
cache.cache_plan("k1", p1)
cache.cache_plan("k2", p2)
cache.cache_plan("k3", p3) # k1 should be evicted
assert cache.get_plan("k1") is None
assert cache.get_plan("k2") is p2
assert cache.get_plan("k3") is p3
def test_lru_access_updates_recency(self) -> None:
cache = PlanCache(maxsize=2)
p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
cache.cache_plan("k1", p1)
cache.cache_plan("k2", p2)
cache.get_plan("k1") # k1 is now most-recently used
cache.cache_plan("k3", p3) # k2 should be evicted (LRU)
assert cache.get_plan("k1") is p1
assert cache.get_plan("k2") is None
assert cache.get_plan("k3") is p3
def test_overwrite_existing_key(self) -> None:
cache = PlanCache()
p1, p2 = self._plan("a"), self._plan("b")
cache.cache_plan("same_key", p1)
cache.cache_plan("same_key", p2)
assert cache.get_plan("same_key") is p2
assert len(cache.get_all_playbooks()) == 1
def test_overwrite_does_not_consume_capacity(self) -> None:
cache = PlanCache(maxsize=2)
p1, p2 = self._plan("a"), self._plan("b")
cache.cache_plan("k1", p1)
cache.cache_plan("k1", p2) # overwrite, not a new slot
cache.cache_plan("k2", p1) # should fit without eviction
assert cache.get_plan("k1") is p2
assert cache.get_plan("k2") is p1
# ── Module-level singletons ───────────────────────────────────────────
class TestModuleSingletons:
def test_template_registry_has_all_agent_defaults(self) -> None:
for agent in ("task_agent", "checkpoint_agent", "project_agent", "note_agent"):
assert template_registry.has(f"tpl_{agent}_default"), (
f"Missing template: tpl_{agent}_default"
)
def test_template_registry_has_operation_templates(self) -> None:
assert template_registry.has("tpl_task_extract_from_project")
assert template_registry.has("tpl_note_weekly_summary")
def test_template_registry_get_returns_non_empty_string(self) -> None:
text = template_registry.get("tpl_task_agent_default")
assert isinstance(text, str)
assert len(text) > 0
def test_plan_cache_has_prebuilt_playbooks(self) -> None:
assert len(plan_cache.get_all_playbooks()) >= 2
def test_playbook_create_tasks_from_project(self) -> None:
plan = plan_cache.get_plan("create_tasks_from_project")
assert plan is not None
assert plan.agent == "project_agent"
assert len(plan.steps) == 2
assert plan.steps[0].prompt_template == "tpl_task_extract_from_project"
assert plan.steps[1].data_from_step == 0
def test_playbook_generate_weekly_note(self) -> None:
plan = plan_cache.get_plan("generate_weekly_note")
assert plan is not None
assert plan.agent == "note_agent"
assert len(plan.steps) == 2
assert plan.steps[0].prompt_template == "tpl_note_weekly_summary"
assert plan.steps[1].data_from_step == 0
def test_playbook_steps_have_no_raw_prompt_text(self) -> None:
"""Plans must not embed prompt text — only template IDs."""
for plan in plan_cache.get_all_playbooks():
for step in plan.steps:
if step.prompt_template is not None:
assert step.prompt_template.startswith("tpl_"), (
f"prompt_template looks like raw text: {step.prompt_template!r}"
)

729
tests/test_integrations.py Normal file
View File

@@ -0,0 +1,729 @@
"""Tests for Step 3.6: cloud provider integration clients.
Coverage:
Unit \u2014 app/integrations/__init__.py:
- encrypt_token / decrypt_token round-trip
- decrypt_token raises ValueError on invalid ciphertext
- encrypt_token raises ValueError on empty/non-dict input
- _get_fernet raises RuntimeError when OAUTH_ENCRYPTION_KEY not set
- get_provider returns GmailClient for 'gmail'
- get_provider returns MSGraphClient for 'outlook' and 'teams'
- get_provider raises ValueError for unknown provider
Unit \u2014 app/integrations/gmail.py:
- _build_gmail_query with no filter returns empty string
- _build_gmail_query with labels builds label: expr
- _build_gmail_query with senders builds from: expr
- _build_gmail_query with date_range builds after:/before: exprs
- _build_gmail_query since overrides date_range.from when more recent
- _build_gmail_query date_range.from overrides since when more recent
- _parse_body extracts text/plain part
- _parse_body extracts text/html part (stripped)
- _parse_body recurses into multipart, prefers text/plain
- GmailClient.fetch_messages: happy path with mocked service
- GmailClient.fetch_messages: no messages returns empty list
- GmailClient.fetch_messages: HTTP error on messages.list raises RuntimeError
- GmailClient.refreshed_credentials: None when token unchanged
- GmailClient.refreshed_credentials: returns dict when token changes
Unit \u2014 app/integrations/ms_graph.py:
- _build_email_filter with no filter returns empty string
- _build_email_filter with senders builds OData from clause
- _build_email_filter with since builds receivedDateTime ge clause
- MSGraphClient.fetch_emails: happy path with mocked httpx
- MSGraphClient.fetch_emails: 401 triggers token refresh and retries
- MSGraphClient.fetch_messages: happy path with mocked httpx
- MSGraphClient.fetch_messages: 403 from getAllMessages degrades gracefully
- MSGraphClient.refreshed_credentials: None when token unchanged
- MSGraphClient._refresh_access_token: MSAL error raises RuntimeError
"""
from __future__ import annotations
import asyncio
import json
import uuid
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
import pytest
from app.integrations import (
ChatMessage,
EmailMessage,
decrypt_token,
encrypt_token,
get_provider,
)
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
# Helpers
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
_FERNET_KEY = "eW91LXNob3VsZC1ub3QtdXNlLXRoaXMta2V5LWluLXByb2Q="
# ^ 32-char URL-safe base64 (generated for tests only; not a real Fernet key length,
# so we generate a proper one below)
from cryptography.fernet import Fernet as _Fernet # noqa: E402
_VALID_KEY = _Fernet.generate_key().decode("utf-8")
_TOKEN_DICT = {
"token": "access_abc",
"refresh_token": "refresh_xyz",
"token_uri": "https://oauth2.googleapis.com/token",
"client_id": "client_id_123",
"client_secret": "client_secret_456",
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
}
_MS_TOKEN_DICT = {
"access_token": "ms_access_abc",
"refresh_token": "ms_refresh_xyz",
"token_type": "Bearer",
"scope": "Mail.Read offline_access",
}
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
# encrypt_token / decrypt_token
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
class TestTokenEncryption:
"""encrypt_token / decrypt_token round-trip tests."""
def test_round_trip(self):
with patch("app.integrations.settings") as mock_settings:
mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
encrypted = encrypt_token(_TOKEN_DICT)
assert isinstance(encrypted, str)
assert encrypted != json.dumps(_TOKEN_DICT) # must be ciphertext, not plaintext
recovered = decrypt_token(encrypted)
assert recovered == _TOKEN_DICT
def test_decrypt_invalid_ciphertext_raises_value_error(self):
with patch("app.integrations.settings") as mock_settings:
mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
with pytest.raises(ValueError, match="Failed to decrypt"):
decrypt_token("this-is-not-valid-fernet-ciphertext")
def test_decrypt_wrong_key_raises_value_error(self):
"""Decrypting with a different key must fail with ValueError."""
other_key = _Fernet.generate_key().decode("utf-8")
with patch("app.integrations.settings") as mock_settings:
mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
encrypted = encrypt_token(_TOKEN_DICT)
with patch("app.integrations.settings") as mock_settings2:
mock_settings2.OAUTH_ENCRYPTION_KEY = other_key
with pytest.raises(ValueError, match="Failed to decrypt"):
decrypt_token(encrypted)
def test_encrypt_empty_dict_raises_value_error(self):
with patch("app.integrations.settings") as mock_settings:
mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
with pytest.raises(ValueError, match="non-empty dict"):
encrypt_token({})
def test_encrypt_non_dict_raises_value_error(self):
with patch("app.integrations.settings") as mock_settings:
mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
with pytest.raises(ValueError, match="non-empty dict"):
encrypt_token("not-a-dict") # type: ignore[arg-type]
def test_missing_key_raises_runtime_error(self):
with patch("app.integrations.settings") as mock_settings:
mock_settings.OAUTH_ENCRYPTION_KEY = ""
with pytest.raises(RuntimeError, match="OAUTH_ENCRYPTION_KEY"):
encrypt_token(_TOKEN_DICT)
def test_email_message_as_text(self):
msg = EmailMessage(
id="m1",
subject="Hello",
sender="alice@example.com",
body_text="Test body",
date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
)
text = msg.as_text
assert "From: alice@example.com" in text
assert "Subject: Hello" in text
assert "Test body" in text
def test_chat_message_as_text(self):
msg = ChatMessage(
id="c1",
content="Buy milk",
sender="bob",
channel="general",
date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
)
text = msg.as_text
assert "From: bob" in text
assert "channel: general" in text
assert "Buy milk" in text
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
# get_provider factory
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
class TestGetProvider:
def test_gmail_returns_gmail_client(self):
from app.integrations.gmail import GmailClient
client = get_provider("gmail", _TOKEN_DICT)
assert isinstance(client, GmailClient)
def test_outlook_returns_ms_graph_client(self):
from app.integrations.ms_graph import MSGraphClient
client = get_provider("outlook", _MS_TOKEN_DICT)
assert isinstance(client, MSGraphClient)
def test_teams_returns_ms_graph_client(self):
from app.integrations.ms_graph import MSGraphClient
client = get_provider("teams", _MS_TOKEN_DICT)
assert isinstance(client, MSGraphClient)
def test_unknown_provider_raises_value_error(self):
with pytest.raises(ValueError, match="Unknown cloud provider"):
get_provider("slack", {})
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
# Gmail client \u2014 query builder
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
class TestBuildGmailQuery:
"""Unit tests for gmail._build_gmail_query."""
def setup_method(self):
from app.integrations.gmail import _build_gmail_query
self._fn = _build_gmail_query
def test_empty_returns_empty_string(self):
assert self._fn(None, None) == ""
def test_single_label(self):
q = self._fn({"labels": ["INBOX"]}, None)
assert "label:INBOX" in q
def test_multiple_labels_joined_with_or(self):
q = self._fn({"labels": ["INBOX", "work"]}, None)
assert "label:INBOX OR label:work" in q
def test_senders(self):
q = self._fn({"senders": ["alice@example.com"]}, None)
assert "from:alice@example.com" in q
def test_date_range_from(self):
q = self._fn({"date_range": {"from": "2025-01-15"}}, None)
assert "after:2025/01/15" in q
def test_date_range_to(self):
q = self._fn({"date_range": {"to": "2025-03-01"}}, None)
assert "before:2025/03/01" in q
def test_since_overrides_earlier_date_range_from(self):
"""since=Feb is more recent than date_range.from=Jan, so after: should be Feb."""
since = datetime(2025, 2, 1, tzinfo=timezone.utc)
q = self._fn({"date_range": {"from": "2025-01-01"}}, since)
assert "after:2025/02/01" in q
assert "after:2025/01/01" not in q
def test_date_range_from_overrides_earlier_since(self):
"""date_range.from=Feb is more recent than since=Jan, so after: should be Feb."""
since = datetime(2025, 1, 1, tzinfo=timezone.utc)
q = self._fn({"date_range": {"from": "2025-02-01"}}, since)
assert "after:2025/02/01" in q
def test_invalid_date_ignored(self):
"""An invalid date string in filter_config must not raise, just be skipped."""
q = self._fn({"date_range": {"from": "not-a-date"}}, None)
assert "after:" not in q
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
# Gmail client \u2014 body parsing
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
class TestParseBody:
"""Unit tests for gmail._parse_body."""
def setup_method(self):
from app.integrations.gmail import _parse_body
self._fn = _parse_body
def _encode(self, text: str) -> str:
import base64
return base64.urlsafe_b64encode(text.encode()).decode()
def test_text_plain_extracted(self):
payload = {
"mimeType": "text/plain",
"body": {"data": self._encode("Hello world")},
}
assert self._fn(payload) == "Hello world"
def test_text_html_stripped(self):
payload = {
"mimeType": "text/html",
"body": {"data": self._encode("<p>Hello <b>world</b></p>")},
}
result = self._fn(payload)
assert "Hello" in result
assert "<p>" not in result
def test_multipart_prefers_plain_over_html(self):
plain_data = self._encode("Plain text")
html_data = self._encode("<p>HTML text</p>")
payload = {
"mimeType": "multipart/alternative",
"body": {},
"parts": [
{"mimeType": "text/html", "body": {"data": html_data}},
{"mimeType": "text/plain", "body": {"data": plain_data}},
],
}
result = self._fn(payload)
assert result == "Plain text"
def test_empty_payload_returns_empty_string(self):
assert self._fn({}) == ""
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
# GmailClient.fetch_messages
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
def _make_gmail_message(
msg_id: str = "msg001",
subject: str = "Test email",
sender: str = "alice@example.com",
body_text: str = "Hello world",
date: str = "Mon, 01 Jan 2025 10:00:00 +0000",
) -> dict:
"""Build a minimal Gmail API message response dict."""
import base64
body_data = base64.urlsafe_b64encode(body_text.encode()).decode()
return {
"id": msg_id,
"labelIds": ["INBOX"],
"payload": {
"mimeType": "text/plain",
"headers": [
{"name": "Subject", "value": subject},
{"name": "From", "value": sender},
{"name": "Date", "value": date},
],
"body": {"data": body_data},
},
}
class TestGmailClientFetchMessages:
"""GmailClient.fetch_messages tests with mocked Google API."""
def _make_client(self) -> "GmailClient":
from app.integrations.gmail import GmailClient
return GmailClient(_TOKEN_DICT)
@pytest.mark.asyncio
async def test_happy_path_returns_email_messages(self):
client = self._make_client()
msg = _make_gmail_message()
mock_service = MagicMock()
mock_users = mock_service.users.return_value
mock_messages = mock_users.messages.return_value
mock_messages.list.return_value.execute.return_value = {
"messages": [{"id": "msg001"}]
}
mock_messages.get.return_value.execute.return_value = msg
with patch("app.integrations.gmail.asyncio.to_thread") as mock_thread:
# Simulate to_thread running the sync function and returning results.
async def fake_to_thread(fn, *args, **kwargs):
return fn(*args, **kwargs)
mock_thread.side_effect = fake_to_thread
with patch("googleapiclient.discovery.build", return_value=mock_service), \
patch("google.auth.transport.requests.Request"), \
patch.object(type(client._credentials), "expired", new_callable=PropertyMock, return_value=False):
results = await client.fetch_messages()
assert len(results) == 1
assert results[0].subject == "Test email"
assert results[0].sender == "alice@example.com"
assert results[0].body_text == "Hello world"
@pytest.mark.asyncio
async def test_no_messages_returns_empty_list(self):
client = self._make_client()
mock_service = MagicMock()
mock_users = mock_service.users.return_value
mock_messages = mock_users.messages.return_value
mock_messages.list.return_value.execute.return_value = {"messages": []}
with patch("app.integrations.gmail.asyncio.to_thread") as mock_thread:
async def fake_to_thread(fn, *args, **kwargs):
return fn(*args, **kwargs)
mock_thread.side_effect = fake_to_thread
with patch("googleapiclient.discovery.build", return_value=mock_service), \
patch("google.auth.transport.requests.Request"), \
patch.object(type(client._credentials), "expired", new_callable=PropertyMock, return_value=False):
results = await client.fetch_messages()
assert results == []
@pytest.mark.asyncio
async def test_list_http_error_raises_runtime_error(self):
import googleapiclient.errors
client = self._make_client()
mock_service = MagicMock()
mock_users = mock_service.users.return_value
mock_messages = mock_users.messages.return_value
mock_resp = MagicMock()
mock_resp.status = 403
mock_resp.reason = "Forbidden"
mock_messages.list.return_value.execute.side_effect = (
googleapiclient.errors.HttpError(mock_resp, b"Forbidden")
)
with patch("app.integrations.gmail.asyncio.to_thread") as mock_thread:
async def fake_to_thread(fn, *args, **kwargs):
return fn(*args, **kwargs)
mock_thread.side_effect = fake_to_thread
with patch("googleapiclient.discovery.build", return_value=mock_service), \
patch("google.auth.transport.requests.Request"), \
patch.object(type(client._credentials), "expired", new_callable=PropertyMock, return_value=False):
with pytest.raises(RuntimeError, match="Gmail messages.list failed"):
await client.fetch_messages()
def test_refreshed_credentials_none_when_unchanged(self):
client = self._make_client()
# Token unchanged — should return None.
assert client.refreshed_credentials is None
def test_refreshed_credentials_returns_dict_when_token_changes(self):
client = self._make_client()
# Simulate a token refresh by changing the access token on the credentials object.
client._credentials.token = "new_access_token_xyz"
refreshed = client.refreshed_credentials
assert refreshed is not None
assert refreshed["token"] == "new_access_token_xyz"
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
# MS Graph client \u2014 email filter builder
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
class TestBuildEmailFilter:
"""Unit tests for ms_graph._build_email_filter."""
def setup_method(self):
from app.integrations.ms_graph import _build_email_filter
self._fn = _build_email_filter
def test_empty_returns_empty_string(self):
assert self._fn(None, None) == ""
def test_single_sender(self):
result = self._fn({"senders": ["alice@example.com"]}, None)
assert "from/emailAddress/address eq 'alice@example.com'" in result
def test_multiple_senders_joined_with_or(self):
result = self._fn({"senders": ["a@x.com", "b@x.com"]}, None)
assert " or " in result
assert "a@x.com" in result
assert "b@x.com" in result
def test_since_adds_received_date_ge_clause(self):
since = datetime(2025, 3, 1, tzinfo=timezone.utc)
result = self._fn(None, since)
assert "receivedDateTime ge 2025-03-01T00:00:00Z" in result
def test_date_range_to_adds_received_date_le_clause(self):
result = self._fn({"date_range": {"to": "2025-06-30"}}, None)
assert "receivedDateTime le" in result
def test_since_overrides_earlier_date_range_from(self):
since = datetime(2025, 2, 1, tzinfo=timezone.utc)
result = self._fn({"date_range": {"from": "2025-01-01"}}, since)
assert "2025-02-01T00:00:00Z" in result
assert "2025-01-01" not in result
def test_invalid_date_ignored(self):
result = self._fn({"date_range": {"from": "bad-date"}}, None)
assert "receivedDateTime" not in result
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
# MSGraphClient.fetch_emails
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
def _make_graph_email(
msg_id: str = "email001",
subject: str = "Meeting tomorrow",
sender_address: str = "boss@company.com",
body_content: str = "Please prepare the report.",
received: str = "2025-06-01T10:00:00Z",
) -> dict:
"""Build a minimal MS Graph message item dict."""
return {
"id": msg_id,
"subject": subject,
"from": {"emailAddress": {"address": sender_address}},
"receivedDateTime": received,
"body": {"contentType": "text", "content": body_content},
"bodyPreview": body_content[:100],
}
def _make_graph_teams_message(
msg_id: str = "teams001",
content: str = "Stand-up at 9am",
sender: str = "alice",
channel_id: str = "chan001",
created: str = "2025-06-01T08:00:00Z",
) -> dict:
return {
"id": msg_id,
"body": {"contentType": "text", "content": content},
"from": {"user": {"displayName": sender}},
"channelIdentity": {"channelId": channel_id},
"createdDateTime": created,
}
class TestMSGraphClientFetchEmails:
"""MSGraphClient.fetch_emails tests with mocked httpx."""
def _make_client(self) -> "MSGraphClient":
from app.integrations.ms_graph import MSGraphClient
return MSGraphClient(_MS_TOKEN_DICT)
@pytest.mark.asyncio
async def test_happy_path_returns_email_messages(self):
client = self._make_client()
graph_email = _make_graph_email()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"value": [graph_email]}
mock_response.raise_for_status = MagicMock()
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_response)
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
results = await client.fetch_emails()
assert len(results) == 1
assert results[0].subject == "Meeting tomorrow"
assert results[0].sender == "boss@company.com"
assert results[0].body_text == "Please prepare the report."
@pytest.mark.asyncio
async def test_pagination_stops_at_max_emails(self):
"""No nextLink in first page \u2014 only one batch returned."""
client = self._make_client()
emails_batch = [_make_graph_email(msg_id=str(i)) for i in range(3)]
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"value": emails_batch} # no @odata.nextLink
mock_response.raise_for_status = MagicMock()
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_response)
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
results = await client.fetch_emails()
assert len(results) == 3
@pytest.mark.asyncio
async def test_401_triggers_token_refresh_and_retries(self):
"""On first 401, token refresh is attempted and the request retried."""
from app.integrations.ms_graph import MSGraphClient
client = MSGraphClient(_MS_TOKEN_DICT)
graph_email = _make_graph_email()
response_401 = MagicMock()
response_401.status_code = 401
response_200 = MagicMock()
response_200.status_code = 200
response_200.json.return_value = {"value": [graph_email]}
response_200.raise_for_status = MagicMock()
call_count = 0
async def fake_get(url, params=None, headers=None):
nonlocal call_count
call_count += 1
if call_count == 1:
return response_401
return response_200
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls, \
patch.object(client, "_refresh_access_token", new_callable=AsyncMock) as mock_refresh:
mock_http = AsyncMock()
mock_http.get = fake_get
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
results = await client.fetch_emails()
mock_refresh.assert_called_once()
assert len(results) == 1
def test_refreshed_credentials_none_when_token_unchanged(self):
client = self._make_client()
assert client.refreshed_credentials is None
def test_refreshed_credentials_returns_dict_when_token_changes(self):
client = self._make_client()
client._access_token = "new_token_abc"
assert client.refreshed_credentials is not None
assert client.refreshed_credentials["access_token"] == "new_token_abc"
class TestMSGraphClientFetchMessages:
"""MSGraphClient.fetch_messages (Teams) tests."""
def _make_client(self) -> "MSGraphClient":
from app.integrations.ms_graph import MSGraphClient
return MSGraphClient(_MS_TOKEN_DICT)
@pytest.mark.asyncio
async def test_happy_path_returns_chat_messages(self):
client = self._make_client()
teams_msg = _make_graph_teams_message()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"value": [teams_msg]}
mock_response.raise_for_status = MagicMock()
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_response)
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
results = await client.fetch_messages()
assert len(results) == 1
assert results[0].content == "Stand-up at 9am"
assert results[0].sender == "alice"
@pytest.mark.asyncio
async def test_403_degrades_gracefully(self):
"""getAllMessages returning 403 (license issue) returns empty list, no exception."""
import httpx as _httpx
client = self._make_client()
error_response = MagicMock()
error_response.status_code = 403
http_error = _httpx.HTTPStatusError(
"Forbidden", request=MagicMock(), response=error_response
)
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
mock_http = AsyncMock()
mock_http.get = AsyncMock(side_effect=http_error)
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
results = await client.fetch_messages()
assert results == []
@pytest.mark.asyncio
async def test_channel_filter_applied(self):
"""Messages from non-matching channels are filtered out."""
client = self._make_client()
matching = _make_graph_teams_message(channel_id="dev-channel", content="Deploy today")
non_matching = _make_graph_teams_message(msg_id="t2", channel_id="random", content="Lunch?")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"value": [matching, non_matching]}
mock_response.raise_for_status = MagicMock()
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_response)
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
results = await client.fetch_messages(
filter_config={"channels": ["dev-channel"]}
)
assert len(results) == 1
assert results[0].content == "Deploy today"
class TestMSGraphClientRefreshToken:
"""MSGraphClient._refresh_access_token with mocked MSAL."""
@pytest.mark.asyncio
async def test_msal_error_raises_runtime_error(self):
from app.integrations.ms_graph import MSGraphClient
client = MSGraphClient({**_MS_TOKEN_DICT, "refresh_token": "rt_test"})
mock_app = MagicMock()
mock_app.acquire_token_by_refresh_token.return_value = {
"error": "invalid_grant",
"error_description": "Refresh token expired",
}
with patch("msal.ConfidentialClientApplication", return_value=mock_app), \
patch("app.integrations.ms_graph.settings") as mock_settings:
mock_settings.MS_CLIENT_ID = "client_id"
mock_settings.MS_CLIENT_SECRET = "secret"
mock_settings.MS_TENANT_ID = "common"
with pytest.raises(RuntimeError, match="MS Graph token refresh failed"):
await client._refresh_access_token()
@pytest.mark.asyncio
async def test_successful_refresh_updates_access_token(self):
from app.integrations.ms_graph import MSGraphClient
client = MSGraphClient({**_MS_TOKEN_DICT, "refresh_token": "rt_old"})
mock_app = MagicMock()
mock_app.acquire_token_by_refresh_token.return_value = {
"access_token": "new_access_token",
"refresh_token": "new_refresh_token",
}
with patch("msal.ConfidentialClientApplication", return_value=mock_app), \
patch("app.integrations.ms_graph.settings") as mock_settings:
mock_settings.MS_CLIENT_ID = "client_id"
mock_settings.MS_CLIENT_SECRET = "secret"
mock_settings.MS_TENANT_ID = "common"
await client._refresh_access_token()
assert client._access_token == "new_access_token"
assert client._refresh_token == "new_refresh_token"

299
tests/test_journey_v2.py Normal file
View File

@@ -0,0 +1,299 @@
"""Tests for Local Agent V2 journey setup (Step 4).
Covers the chatbot journey that produces a structured AgentConfig JSON
instead of a freeform prompt_template string.
Unit tests (no LLM)
--------------------
4.6a _extract_agent_config: valid JSON → returns serialised config
4.6b _extract_agent_config: invalid JSON → returns None
4.6c _extract_agent_config: markers absent → returns None
4.6d _extract_agent_config: only START marker → returns None
4.6e Session not found → done=True, agent_config=None
4.6f Nudge uses AGENT_CONFIG_START/END markers (not old PROMPT_TEMPLATE)
Eval test (real LLM + Langfuse scoring)
----------------------------------------
4.1 Journey start explores directory → first reply contains a question
Cases 4.24.5 (multi-turn conversations producing a full AgentConfig) are
non-deterministic and tested manually — results tracked in Langfuse.
Run:
pytest tests/test_journey_v2.py -v
pytest tests/test_journey_v2.py -v -k "4_6" # unit only
pytest tests/test_journey_v2.py -v -k "eval" # single LLM eval
pytest tests/test_journey_v2.py -v --journey-dir /p # custom fixtures
"""
from __future__ import annotations
import uuid
from contextlib import nullcontext
from pathlib import Path
from typing import Any
from unittest.mock import patch
import pytest
import yaml
from app.api.routes.agent_setup import (
_CONFIG_END,
_CONFIG_START,
_MAX_TURNS,
_extract_agent_config,
_sessions,
handle_journey_message,
handle_journey_start,
)
from app.core.langfuse_client import get_langfuse
from app.core.ws_context import clear_client_executor, set_client_executor
from app.schemas import AgentConfig
from tests.conftest import TEST_USER_IDS
# ── Constants ─────────────────────────────────────────────────────────────
_USER_ID = TEST_USER_IDS["power"]
_DEFAULT_FIXTURE_DIR = Path(__file__).parent / "fixtures" / "journey_v2"
# ── Fixture loading ───────────────────────────────────────────────────────
def _fixtures_dir(config) -> Path:
override = config.getoption("--journey-dir")
return Path(override) if override else _DEFAULT_FIXTURE_DIR
def _load_cases(config) -> list[dict]:
return yaml.safe_load(
(_fixtures_dir(config) / "cases.yaml").read_text(encoding="utf-8")
)
def _read_data_file(filename: str, fixtures_dir: Path) -> str:
return (fixtures_dir / "data" / filename).read_text(encoding="utf-8")
# ── pytest_generate_tests ─────────────────────────────────────────────────
def pytest_generate_tests(metafunc):
if "journey_case" not in metafunc.fixturenames:
return
cases = _load_cases(metafunc.config)
metafunc.parametrize("journey_case", cases, ids=[c["id"] for c in cases])
# ── Executor builder ──────────────────────────────────────────────────────
def _make_fs_executor(directory_files: list[dict], fixtures_dir: Path):
"""Return an async callback that simulates filesystem tool responses.
Matches the signature expected by ``set_client_executor`` / ``execute_on_client``:
receives the full ``payload`` dict and returns a result dict.
``directory_files`` is a list of ``{path, content_file}`` dicts;
``content_file`` is relative to ``fixtures_dir/data/``.
"""
file_map: dict[str, str] = {
entry["path"]: _read_data_file(entry["content_file"], fixtures_dir)
for entry in directory_files
}
async def _executor(payload: dict) -> dict:
action = payload.get("action", "")
data = payload.get("data") or {}
if action == "list_directory":
return {"entries": [
{"type": "file", "name": p.split("/")[-1], "path": p}
for p in file_map
]}
if action == "read_file_content":
path = data.get("path", "")
return {"content": file_map.get(path, "")}
if action == "get_file_metadata":
path = data.get("path", "")
name = path.split("/")[-1]
ext = "." + name.rsplit(".", 1)[-1] if "." in name else ""
return {"name": name, "extension": ext, "size": 1024,
"createdAt": None, "modifiedAt": None}
return {}
return _executor
# ── Journey runner helper ─────────────────────────────────────────────────
async def _run_journey(user_id: str, case: dict, executor) -> dict[str, Any]:
"""Drive start + all user_messages for a case. Returns the final reply dict.
Mirrors ``device_ws._handle_journey_start/message``: sets the client
executor (so filesystem tools work) before each handler call.
"""
session_id = str(uuid.uuid4())
try:
set_client_executor(executor)
reply = await handle_journey_start(user_id, {
"agent_type": "local",
"directory": case["directory"],
"data_types": case["data_types"],
"session_id": session_id,
})
for msg in case.get("user_messages", []):
if reply.get("done"):
break
set_client_executor(executor)
reply = await handle_journey_message(user_id, {
"session_id": reply["session_id"],
"message": msg,
})
finally:
clear_client_executor()
_sessions.pop(session_id, None)
return reply
# ── Assertion helper ──────────────────────────────────────────────────────
def _evaluate_case(case: dict, reply: dict) -> tuple[float, str]:
"""Return (score, comment) for a journey case given the final reply dict."""
if case.get("expect_question"):
has_q = "?" in reply.get("message", "")
return (1.0 if has_q else 0.0), f"first_reply_has_question={has_q}"
return 1.0, "no specific assertion"
# ── Unit tests ────────────────────────────────────────────────────────────
def test_4_6a_extract_valid_json():
"""_extract_agent_config: valid JSON between markers → returns serialised config."""
config = AgentConfig(
content_types=[],
global_rules=["No project = no entity"],
data_types=["tasks"],
)
text = f"Some preamble\n{_CONFIG_START}\n{config.model_dump_json()}\n{_CONFIG_END}\nTrailing"
result = _extract_agent_config(text)
assert result is not None
parsed = AgentConfig.model_validate_json(result)
assert parsed.global_rules == ["No project = no entity"]
def test_4_6b_extract_invalid_json():
"""_extract_agent_config: malformed JSON between markers → returns None."""
text = f"{_CONFIG_START}\n{{not: valid json\n{_CONFIG_END}"
assert _extract_agent_config(text) is None
def test_4_6c_extract_markers_absent():
"""_extract_agent_config: no markers at all → returns None."""
assert _extract_agent_config("No markers here at all") is None
def test_4_6d_extract_only_start_marker():
"""_extract_agent_config: START without END → returns None."""
assert _extract_agent_config(f"text {_CONFIG_START} no end marker") is None
@pytest.mark.asyncio
async def test_4_6e_session_not_found():
"""4.6e Session not found → done=True, agent_config=None, informative message."""
reply = await handle_journey_message(_USER_ID, {
"session_id": "nonexistent-session-id",
"message": "Hello",
})
assert reply["done"] is True
assert reply["agent_config"] is None
assert "not found" in reply["message"].lower() or "expired" in reply["message"].lower()
@pytest.mark.asyncio
async def test_4_6f_nudge_uses_new_markers():
"""4.6f Nudge injected after max turns uses AGENT_CONFIG markers, not PROMPT_TEMPLATE."""
session_id = str(uuid.uuid4())
captured_histories: list[list[dict]] = []
async def _mock_llm(system_prompt, history, tools, **kwargs) -> str:
captured_histories.append(list(history))
# Return plain text — no markers — to trigger the nudge path.
return "I still need more information from you."
from app.api.routes.agent_setup import JourneySession
fake_session = JourneySession(
session_id=session_id,
user_id=_USER_ID,
agent_type="local",
directory="/test",
data_types=["tasks"],
system_prompt="system",
langfuse_prompt=None,
)
# Fill history to the turn limit so the next message triggers the nudge.
for i in range(_MAX_TURNS):
fake_session.history.append({"role": "user", "content": f"msg {i}"})
fake_session.history.append({"role": "assistant", "content": "ok"})
_sessions[session_id] = fake_session
try:
with patch("app.api.routes.agent_setup._call_llm_with_tools", side_effect=_mock_llm):
await handle_journey_message(_USER_ID, {
"session_id": session_id,
"message": "one more message to trigger nudge",
})
finally:
_sessions.pop(session_id, None)
# Second LLM call receives the nudge appended to history.
assert len(captured_histories) >= 2, "Expected ≥ 2 LLM calls (main reply + nudge)"
nudge_history = captured_histories[1]
user_msgs = " ".join(t["content"] for t in nudge_history if t["role"] == "user")
assert _CONFIG_START in user_msgs, f"Nudge must reference {_CONFIG_START}"
assert _CONFIG_END in user_msgs, f"Nudge must reference {_CONFIG_END}"
assert "PROMPT_TEMPLATE" not in user_msgs, "Old PROMPT_TEMPLATE markers must not appear in nudge"
# ── Eval tests (real LLM + Langfuse) ─────────────────────────────────────
@pytest.mark.asyncio
@pytest.mark.eval
async def test_eval_journey(journey_case, pytestconfig):
"""Parametrized eval test — one invocation per YAML case."""
case: dict = journey_case
fixtures_dir = _fixtures_dir(pytestconfig)
executor = _make_fs_executor(case.get("directory_files", []), fixtures_dir)
lf = get_langfuse()
obs_ctx = lf.start_as_current_observation(
name=f"eval-journey-{case['id']}-{case.get('score_name', 'unknown').replace('.', '-')}",
metadata={"step": "4", "case_id": case["id"]},
) if lf else nullcontext()
with obs_ctx as obs:
reply = await _run_journey(_USER_ID, case, executor)
score, comment = _evaluate_case(case, reply)
if obs is not None:
obs.score(
name=case.get("score_name", f"journey.case_{case['id']}"),
value=score,
comment=comment,
)
if lf:
lf.flush()
assert score == 1.0, f"[{case['id']}] {case.get('description', '')}{comment}"

View File

@@ -0,0 +1,343 @@
"""Tests for Step 7 — MemoryMiddleware.
Coverage:
1. enrich_context returns core prefs + associative + episodic + proactive
2. store_episode creates an encrypted row decryptable with the user's key
3. update_core upserts correctly
4. User with no encryption_key returns empty context (no crash)
5. End-to-end: home_request WS frame results in an episodic row being stored
"""
from __future__ import annotations
import json
import uuid
from unittest.mock import patch
import pytest
import pytest_asyncio
from cryptography.fernet import Fernet
from sqlalchemy import select
from app.core.memory_middleware import MemoryMiddleware, _PROACTIVE_CONFIDENCE_THRESHOLD
from app.db import get_session
from app.main import app
from app.models import (
MemoryAssociative,
MemoryCore,
MemoryEpisodic,
MemoryProactive,
User,
)
from tests.conftest import TEST_USER_IDS, make_jwt
USER_ID = TEST_USER_IDS["power"]
_FERNET_KEY = Fernet.generate_key().decode()
# ── DB override ───────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _override_db(db_session):
async def _gen():
yield db_session
app.dependency_overrides[get_session] = _gen
yield
app.dependency_overrides.pop(get_session, None)
# ── Fixtures ──────────────────────────────────────────────────────────────────
@pytest_asyncio.fixture
async def user_with_key(db_session):
"""Set encryption_key on the seeded power user."""
result = await db_session.execute(select(User).where(User.id == USER_ID))
user = result.scalar_one()
user.encryption_key = _FERNET_KEY
await db_session.commit()
return user
def _fernet():
return Fernet(_FERNET_KEY.encode())
def _enc(plaintext: str) -> str:
return _fernet().encrypt(plaintext.encode()).decode()
def _dec(ciphertext: str) -> str:
return _fernet().decrypt(ciphertext.encode()).decode()
# ── enrich_context ────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_enrich_context_returns_core_memory(db_session, user_with_key):
# Seed a core memory row
db_session.add(MemoryCore(
id=str(uuid.uuid4()),
user_id=USER_ID,
key="timezone",
value_encrypted=_enc("UTC"),
))
await db_session.commit()
middleware = MemoryMiddleware(db_session)
ctx = await middleware.enrich_context(USER_ID, "What are my tasks?")
assert "core_memory" in ctx
assert ctx["core_memory"]["timezone"] == "UTC"
@pytest.mark.asyncio
async def test_enrich_context_returns_episodic_memory(db_session, user_with_key):
session_id = str(uuid.uuid4())
db_session.add(MemoryEpisodic(
id=str(uuid.uuid4()),
user_id=USER_ID,
summary_encrypted=_enc("User asked about Q1 tasks"),
session_id=session_id,
))
await db_session.commit()
middleware = MemoryMiddleware(db_session)
ctx = await middleware.enrich_context(USER_ID, "any message")
assert "episodic_memory" in ctx
assert any("Q1 tasks" in s for s in ctx["episodic_memory"])
@pytest.mark.asyncio
async def test_enrich_context_filters_episodic_by_session_id(db_session, user_with_key):
target_session = str(uuid.uuid4())
other_session = str(uuid.uuid4())
db_session.add(MemoryEpisodic(
id=str(uuid.uuid4()),
user_id=USER_ID,
summary_encrypted=_enc("Target session memory"),
session_id=target_session,
))
db_session.add(MemoryEpisodic(
id=str(uuid.uuid4()),
user_id=USER_ID,
summary_encrypted=_enc("Other session memory"),
session_id=other_session,
))
await db_session.commit()
middleware = MemoryMiddleware(db_session)
ctx = await middleware.enrich_context(USER_ID, "any message", session_id=target_session)
episodic = ctx.get("episodic_memory", [])
assert any("Target session" in s for s in episodic)
assert not any("Other session" in s for s in episodic)
@pytest.mark.asyncio
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
# Add one pattern above threshold and one below
db_session.add(MemoryProactive(
id=str(uuid.uuid4()),
user_id=USER_ID,
pattern_encrypted=_enc("User prefers short summaries"),
confidence=0.9,
source="inferred",
))
db_session.add(MemoryProactive(
id=str(uuid.uuid4()),
user_id=USER_ID,
pattern_encrypted=_enc("User likes dark mode"),
confidence=0.1,
source="inferred",
))
await db_session.commit()
middleware = MemoryMiddleware(db_session)
ctx = await middleware.enrich_context(USER_ID, "any message")
assert "proactive_hints" in ctx
hints = ctx["proactive_hints"]
assert any("short summaries" in h for h in hints)
assert not any("dark mode" in h for h in hints)
@pytest.mark.asyncio
async def test_enrich_context_returns_associative_memory(db_session, user_with_key):
db_session.add(MemoryAssociative(
id=str(uuid.uuid4()),
user_id=USER_ID,
content_encrypted=_enc("Related memory about meetings"),
embedding=None,
entity_type="note",
))
await db_session.commit()
middleware = MemoryMiddleware(db_session)
ctx = await middleware.enrich_context(USER_ID, "meetings")
assert "associative_memory" in ctx
assert any("meetings" in m for m in ctx["associative_memory"])
@pytest.mark.asyncio
async def test_enrich_context_empty_for_user_without_key(db_session):
"""User with no encryption_key → empty context, no crash."""
result = await db_session.execute(select(User).where(User.id == USER_ID))
user = result.scalar_one()
user.encryption_key = None
await db_session.commit()
middleware = MemoryMiddleware(db_session)
ctx = await middleware.enrich_context(USER_ID, "hello")
assert ctx == {}
# ── store_episode ─────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_store_episode_creates_encrypted_row(db_session, user_with_key):
session_id = str(uuid.uuid4())
middleware = MemoryMiddleware(db_session)
await middleware.store_episode(USER_ID, session_id, "hello", "world")
result = await db_session.execute(
select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
)
row = result.scalar_one()
plaintext = _dec(row.summary_encrypted)
assert "hello" in plaintext
assert "world" in plaintext
@pytest.mark.asyncio
async def test_store_episode_decryptable(db_session, user_with_key):
session_id = str(uuid.uuid4())
middleware = MemoryMiddleware(db_session)
await middleware.store_episode(USER_ID, session_id, "msg", "resp")
result = await db_session.execute(
select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
)
row = result.scalar_one()
# Decrypt using the same key — must not raise
decrypted = _dec(row.summary_encrypted)
assert len(decrypted) > 0
# ── update_core ───────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_update_core_insert(db_session, user_with_key):
middleware = MemoryMiddleware(db_session)
await middleware.update_core(USER_ID, "lang", "en")
result = await db_session.execute(
select(MemoryCore).where(MemoryCore.user_id == USER_ID, MemoryCore.key == "lang")
)
row = result.scalar_one()
assert _dec(row.value_encrypted) == "en"
@pytest.mark.asyncio
async def test_update_core_upsert(db_session, user_with_key):
middleware = MemoryMiddleware(db_session)
await middleware.update_core(USER_ID, "lang", "en")
await middleware.update_core(USER_ID, "lang", "fr")
result = await db_session.execute(
select(MemoryCore).where(MemoryCore.user_id == USER_ID, MemoryCore.key == "lang")
)
rows = result.scalars().all()
assert len(rows) == 1
assert _dec(rows[0].value_encrypted) == "fr"
@pytest.mark.asyncio
async def test_core_block_edit_ops(db_session, user_with_key):
middleware = MemoryMiddleware(db_session)
await middleware.update_core(USER_ID, "human", "Name: Roberto")
await middleware.append_core(USER_ID, "human", "Timezone: Europe/Rome")
replaced = await middleware.replace_core(USER_ID, "human", "Roberto", "Robert")
blocks = await middleware.list_core_blocks(USER_ID)
human = next(b for b in blocks if b["label"] == "human")
assert replaced is True
assert "Name: Robert" in human["value"]
assert "Timezone: Europe/Rome" in human["value"]
deleted = await middleware.delete_core(USER_ID, "human")
assert deleted is True
assert await middleware.get_core_block(USER_ID, "human") is None
@pytest.mark.asyncio
async def test_archival_and_recall_search_helpers(db_session, user_with_key):
middleware = MemoryMiddleware(db_session)
await middleware.insert_archival(USER_ID, "Project whitelist has release risk", source="assistant")
await middleware.store_episode(USER_ID, str(uuid.uuid4()), "How is whitelist?", "Whitelist is delayed")
arch = await middleware.search_archival(USER_ID, "whitelist", top_k=3)
rec = await middleware.search_recall(USER_ID, "delayed", top_k=3)
assert any("whitelist" in item.lower() for item in arch)
assert any("delayed" in item.lower() for item in rec)
# ── End-to-end WS: memory middleware is called during home_request ────────────
def test_home_request_calls_memory_middleware(client):
"""home_request triggers enrich_context before and store_episode after the LLM."""
enrich_calls: list[tuple] = []
store_calls: list[tuple] = []
class _MockMiddleware:
def __init__(self, db):
pass
async def enrich_context(self, user_id, message, **kwargs):
enrich_calls.append((user_id, message))
return {"core_memory": {"tz": "UTC"}}
async def store_episode(self, user_id, session_id, message, response, **kwargs):
store_calls.append((user_id, session_id, message, response))
token = make_jwt("power", user_id=USER_ID)
session_id = str(uuid.uuid4())
async def _mock_stream(user_id, message, context):
# Verify memory context was injected
assert context.get("core_memory") == {"tz": "UTC"}
yield "token", "Done"
with (
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_stream),
):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-mem", "agent_ids": []
}))
ws.send_text(json.dumps({
"type": "home_request",
"request_id": "r-mem",
"session_id": session_id,
"message": "Show tasks",
}))
for _ in range(20):
raw = ws.receive_text()
frame = json.loads(raw)
if frame.get("type") == "stream_end":
break
assert len(enrich_calls) == 1
assert enrich_calls[0] == (USER_ID, "Show tasks")
assert len(store_calls) == 1
stored_session_id, stored_message = store_calls[0][1], store_calls[0][2]
assert stored_session_id == session_id
assert stored_message == "Show tasks"

205
tests/test_memory_models.py Normal file
View File

@@ -0,0 +1,205 @@
"""Tests for Step 6 — memory ORM models and User.encryption_key.
Uses the SQLite in-memory test DB (from conftest). The pgvector embedding
column is stored as JSON in tests (SQLite-compatible).
"""
from __future__ import annotations
import uuid
from datetime import datetime, timezone
import pytest
import pytest_asyncio
from cryptography.fernet import Fernet
from sqlalchemy import select
from app.models import MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive, User
from tests.conftest import TEST_USER_IDS
USER_ID = TEST_USER_IDS["power"]
# ── helpers ───────────────────────────────────────────────────────────────────
def _fernet_key() -> str:
return Fernet.generate_key().decode()
def _encrypt(key: str, plaintext: str) -> str:
return Fernet(key.encode()).encrypt(plaintext.encode()).decode()
def _decrypt(key: str, ciphertext: str) -> str:
return Fernet(key.encode()).decrypt(ciphertext.encode()).decode()
# ── User.encryption_key ───────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_user_encryption_key_column_exists(db_session):
"""User model has encryption_key column and it can be set."""
result = await db_session.execute(select(User).where(User.id == USER_ID))
user = result.scalar_one()
# Column exists (may be None for seeded users)
assert hasattr(user, "encryption_key")
@pytest.mark.asyncio
async def test_user_encryption_key_can_be_set(db_session):
key = _fernet_key()
result = await db_session.execute(select(User).where(User.id == USER_ID))
user = result.scalar_one()
user.encryption_key = key
await db_session.commit()
result2 = await db_session.execute(select(User).where(User.id == USER_ID))
user2 = result2.scalar_one()
assert user2.encryption_key == key
# ── MemoryCore ────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_memory_core_create_and_read(db_session):
key = _fernet_key()
encrypted_val = _encrypt(key, "UTC")
row = MemoryCore(
id=str(uuid.uuid4()),
user_id=USER_ID,
key="timezone",
value_encrypted=encrypted_val,
)
db_session.add(row)
await db_session.commit()
result = await db_session.execute(
select(MemoryCore).where(MemoryCore.user_id == USER_ID)
)
fetched = result.scalar_one()
assert fetched.key == "timezone"
assert _decrypt(key, fetched.value_encrypted) == "UTC"
@pytest.mark.asyncio
async def test_memory_core_cascade_delete(db_session):
"""Deleting a user cascades to memory_core."""
row = MemoryCore(
id=str(uuid.uuid4()),
user_id=USER_ID,
key="lang",
value_encrypted="enc",
)
db_session.add(row)
await db_session.commit()
user = (await db_session.execute(select(User).where(User.id == USER_ID))).scalar_one()
await db_session.delete(user)
await db_session.commit()
remaining = (
await db_session.execute(select(MemoryCore).where(MemoryCore.user_id == USER_ID))
).scalars().all()
assert remaining == []
# ── MemoryAssociative ─────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_memory_associative_create_and_read(db_session):
key = _fernet_key()
content = _encrypt(key, "User prefers morning meetings")
embedding = [0.1] * 1536 # fake embedding
row = MemoryAssociative(
id=str(uuid.uuid4()),
user_id=USER_ID,
content_encrypted=content,
embedding=embedding,
entity_type="preference",
entity_id=None,
)
db_session.add(row)
await db_session.commit()
result = await db_session.execute(
select(MemoryAssociative).where(MemoryAssociative.user_id == USER_ID)
)
fetched = result.scalar_one()
assert fetched.entity_type == "preference"
assert _decrypt(key, fetched.content_encrypted) == "User prefers morning meetings"
assert len(fetched.embedding) == 1536
# ── MemoryEpisodic ────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_memory_episodic_create_and_read(db_session):
key = _fernet_key()
session_id = str(uuid.uuid4())
summary = _encrypt(key, "User asked about Q1 tasks")
row = MemoryEpisodic(
id=str(uuid.uuid4()),
user_id=USER_ID,
summary_encrypted=summary,
session_id=session_id,
)
db_session.add(row)
await db_session.commit()
result = await db_session.execute(
select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
)
fetched = result.scalar_one()
assert _decrypt(key, fetched.summary_encrypted) == "User asked about Q1 tasks"
assert isinstance(fetched.created_at, datetime)
# ── MemoryProactive ───────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_memory_proactive_create_and_read(db_session):
key = _fernet_key()
pattern = _encrypt(key, "User always assigns tasks to self")
row = MemoryProactive(
id=str(uuid.uuid4()),
user_id=USER_ID,
pattern_encrypted=pattern,
confidence=0.85,
source="inferred",
)
db_session.add(row)
await db_session.commit()
result = await db_session.execute(
select(MemoryProactive).where(MemoryProactive.user_id == USER_ID)
)
fetched = result.scalar_one()
assert fetched.confidence == pytest.approx(0.85)
assert fetched.source == "inferred"
assert _decrypt(key, fetched.pattern_encrypted) == "User always assigns tasks to self"
# ── Auth registration generates encryption_key ───────────────────────────────
def test_register_sets_encryption_key(client):
"""POST /api/v1/auth/register creates a user with a valid Fernet key."""
resp = client.post(
"/api/v1/auth/register",
json={"email": "newuser@test.com", "password": "testpassword123"},
)
assert resp.status_code == 201
# Fetch the newly created user via the access token
token = resp.json()["access_token"]
me_resp = client.get(
"/api/v1/auth/me",
headers={"Authorization": f"Bearer {token}"},
)
assert me_resp.status_code == 200
# We can't see encryption_key in the API response (not in UserProfile),
# but we verify registration didn't crash — key generation is implicit.

View File

@@ -20,7 +20,6 @@ from jose import jwt
from app.config.settings import settings
from app.db import get_session
from app.main import app
from app.schemas import ChatResponse
from tests.conftest import TEST_USER_IDS
# ---------------------------------------------------------------------------
@@ -50,7 +49,6 @@ _CHAT_BODY = {
"recent_tasks": [],
"conversation_history": [],
},
"execution_mode": "direct",
}
@@ -240,7 +238,7 @@ class TestRateLimitMiddleware:
class TestSanitizerMiddleware:
"""Mock ``orchestrate`` to inject controlled strings into chat responses."""
"""Mock ``run_home`` to inject controlled strings into chat responses."""
_CHAT_PATH = "/api/v1/chat"
@@ -248,11 +246,10 @@ class TestSanitizerMiddleware:
return _make_jwt(user_id=str(uuid.uuid4()), tier="pro")
def _post_chat(self, client: TestClient, response_text: str) -> dict:
mock_response = ChatResponse(response=response_text, actions=[])
with patch(
"app.api.routes.chat.orchestrate",
"app.api.routes.chat.run_home",
new_callable=AsyncMock,
return_value=mock_response,
return_value=response_text,
):
resp = client.post(
self._CHAT_PATH,

View File

@@ -1,348 +0,0 @@
"""Integration tests for the orchestrator module."""
from __future__ import annotations
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.core.agent_registry import AgentRegistry, ChatAgent
from app.core.orchestrator import (
classify_intent,
orchestrate,
orchestrate_stream,
route_pipeline,
route_single,
)
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
# ── Stub agents ──────────────────────────────────────────────────────
class _TaskAgent(ChatAgent):
def get_name(self) -> str:
return "task_agent"
def get_description(self) -> str:
return "Manages tasks: create, update, list, suggest"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return f"task: {query}"
class _CalendarAgent(ChatAgent):
def get_name(self) -> str:
return "calendar_agent"
def get_description(self) -> str:
return "Calendar management: events, conflicts, scheduling"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return f"calendar: {query}"
# ── Helpers ──────────────────────────────────────────────────────────
def _mock_llm(response_text: str) -> MagicMock:
"""Return a mock LLM that always produces *response_text*."""
msg = MagicMock()
msg.content = response_text
llm = MagicMock()
llm.ainvoke = AsyncMock(return_value=msg)
return llm
# ── Fixtures ─────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _fresh_registry():
"""Reset the AgentRegistry singleton between tests."""
AgentRegistry._instance = None
yield
AgentRegistry._instance = None
@pytest.fixture()
def reg() -> AgentRegistry:
r = AgentRegistry()
r.register(_TaskAgent)
r.register(_CalendarAgent)
return r
# ── classify_intent ───────────────────────────────────────────────────
class TestClassifyIntent:
@pytest.mark.asyncio
async def test_routes_to_known_agent(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
result = await classify_intent("add a task", {}, reg)
assert result == "task_agent"
@pytest.mark.asyncio
async def test_routes_to_calendar_agent(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("calendar_agent")
result = await classify_intent("schedule a meeting", {}, reg)
assert result == "calendar_agent"
@pytest.mark.asyncio
async def test_falls_back_on_unknown_name(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("nonexistent_agent")
result = await classify_intent("do something", {}, reg)
assert result == "task_agent"
@pytest.mark.asyncio
async def test_empty_registry_returns_fallback_without_llm_call(self) -> None:
empty_reg = AgentRegistry()
# No LLM should be instantiated — early return path
with patch("app.core.orchestrator._make_llm") as mock_cls:
result = await classify_intent("anything", {}, empty_reg)
mock_cls.assert_not_called()
assert result == "task_agent"
@pytest.mark.asyncio
async def test_whitespace_stripped_from_response(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm(" task_agent \n")
result = await classify_intent("create task", {}, reg)
assert result == "task_agent"
# ── route_single ─────────────────────────────────────────────────────
class TestRouteSingle:
@pytest.mark.asyncio
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
result = await route_single("task_agent", "create a task", {}, reg)
assert isinstance(result, ChatResponse)
@pytest.mark.asyncio
async def test_response_contains_agent_output(self, reg: AgentRegistry) -> None:
result = await route_single("task_agent", "create a task", {}, reg)
assert result.response == "task: create a task"
@pytest.mark.asyncio
async def test_unknown_agent_raises_key_error(self, reg: AgentRegistry) -> None:
with pytest.raises(KeyError):
await route_single("nonexistent", "hello", {}, reg)
@pytest.mark.asyncio
async def test_actions_default_empty(self, reg: AgentRegistry) -> None:
result = await route_single("task_agent", "hi", {}, reg)
assert result.actions == []
# ── route_pipeline ────────────────────────────────────────────────────
class TestRoutePipeline:
@pytest.mark.asyncio
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("synthesized result")
result = await route_pipeline(
["task_agent", "calendar_agent"], "plan my week", {}, reg
)
assert isinstance(result, ChatResponse)
@pytest.mark.asyncio
async def test_response_is_synthesis_output(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("synthesized result")
result = await route_pipeline(
["task_agent", "calendar_agent"], "plan my week", {}, reg
)
assert result.response == "synthesized result"
@pytest.mark.asyncio
async def test_passes_previous_results_to_subsequent_agents(
self, reg: AgentRegistry
) -> None:
"""Each agent after the first should receive prior outputs in context."""
received_contexts: list[dict[str, Any]] = []
class _CapturingAgent(ChatAgent):
def get_name(self) -> str:
return "capture"
def get_description(self) -> str:
return "captures context for testing"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
received_contexts.append(dict(context))
return "captured"
reg.register(_CapturingAgent)
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("done")
await route_pipeline(["task_agent", "capture"], "hi", {}, reg)
# The second agent (capture) must have received previous results
assert len(received_contexts) == 1
assert "previous_results" in received_contexts[0]
assert received_contexts[0]["previous_results"] == ["task: hi"]
@pytest.mark.asyncio
async def test_single_agent_pipeline(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("single result")
result = await route_pipeline(["task_agent"], "one agent", {}, reg)
assert result.response == "single result"
# ── orchestrate ───────────────────────────────────────────────────────
class TestOrchestrate:
@pytest.mark.asyncio
async def test_direct_mode_returns_chat_response(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="add a task", execution_mode="direct")
result = await orchestrate(request, reg)
assert isinstance(result, ChatResponse)
@pytest.mark.asyncio
async def test_direct_mode_response_content(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="add a task", execution_mode="direct")
result = await orchestrate(request, reg)
assert isinstance(result, ChatResponse)
assert result.response == "task: add a task"
@pytest.mark.asyncio
async def test_plan_mode_returns_execution_plan(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="plan my tasks", execution_mode="plan")
result = await orchestrate(request, reg)
assert isinstance(result, ExecutionPlan)
@pytest.mark.asyncio
async def test_plan_mode_agent_matches_classified(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("calendar_agent")
request = ChatRequest(
message="schedule something", execution_mode="plan"
)
result = await orchestrate(request, reg)
assert isinstance(result, ExecutionPlan)
assert result.agent == "calendar_agent"
@pytest.mark.asyncio
async def test_plan_mode_has_steps(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="plan tasks", execution_mode="plan")
result = await orchestrate(request, reg)
assert isinstance(result, ExecutionPlan)
assert len(result.steps) >= 1
@pytest.mark.asyncio
async def test_plan_mode_template_id_contains_agent_name(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="plan tasks", execution_mode="plan")
result = await orchestrate(request, reg)
assert isinstance(result, ExecutionPlan)
assert result.steps[0].prompt_template is not None
assert "task_agent" in result.steps[0].prompt_template
@pytest.mark.asyncio
async def test_default_execution_mode_is_direct(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
# execution_mode defaults to "direct"
request = ChatRequest(message="help me")
result = await orchestrate(request, reg)
assert isinstance(result, ChatResponse)
# ── orchestrate_stream ────────────────────────────────────────────────
class TestOrchestrateStream:
@pytest.mark.asyncio
async def test_yields_at_least_one_chunk(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="add a task", execution_mode="direct")
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
assert len(chunks) >= 1
@pytest.mark.asyncio
async def test_last_chunk_is_final_json_frame(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="add a task", execution_mode="direct")
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
last = json.loads(chunks[-1])
assert last["done"] is True
assert "response" in last
assert "actions" in last
@pytest.mark.asyncio
async def test_final_frame_response_matches_agent_output(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="create a task", execution_mode="direct")
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
final = json.loads(chunks[-1])
assert final["response"] == "task: create a task"
@pytest.mark.asyncio
async def test_text_chunks_before_final_frame(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(
message="x" * 200, execution_mode="direct"
) # long enough to produce multiple chunks
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
# All but the last chunk should be plain text (not valid final JSON)
non_final = chunks[:-1]
for chunk in non_final:
try:
parsed = json.loads(chunk)
assert parsed.get("done") is not True
except json.JSONDecodeError:
pass # plain text chunk — expected

Some files were not shown because too many files have changed in this diff Show More