Compare commits
39 Commits
feature/de
...
333bba6fdd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
333bba6fdd | ||
|
|
229e20d073 | ||
|
|
0b491b3643 | ||
|
|
0d5fa3e569 | ||
|
|
aff68a9051 | ||
|
|
5e9ef2809e | ||
|
|
90018af311 | ||
|
|
1e2e395676 | ||
|
|
59d3a53980 | ||
|
|
9feeaa79c8 | ||
|
|
aa219a4d08 | ||
|
|
552b8eb305 | ||
|
|
0d93b3960d | ||
|
|
f07580574b | ||
|
|
1a8bf11f90 | ||
|
|
e7cdce8287 | ||
|
|
58bc6efd4b | ||
|
|
6c450805cb | ||
|
|
f340d0fa3e | ||
|
|
edc53cb6eb | ||
|
|
725cece5c1 | ||
|
|
297e20ce8d | ||
|
|
5a03bd1cfb | ||
|
|
87b7a1c6c9 | ||
|
|
826f64d6bb | ||
| 5faa6b1d7c | |||
| 02a9684cd6 | |||
| fae9efee0d | |||
| 30b062dd4a | |||
| 2a0331d7ce | |||
| 13fd8677c1 | |||
| 9bd629cb59 | |||
| 9c97702daa | |||
| a1e364c9c0 | |||
| 5b55f1292a | |||
| 5bc9ea6cd6 | |||
| f7404b6f66 | |||
| d667e43c73 | |||
| fe085a7951 |
20
.env.example
20
.env.example
@@ -4,9 +4,17 @@ ENV=dev
|
|||||||
# ── Database ──────────────────────────────────────────────────────────────────
|
# ── Database ──────────────────────────────────────────────────────────────────
|
||||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
|
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
|
||||||
|
|
||||||
# ── Auth ──────────────────────────────────────────────────────────────────────
|
# ── Redis ─────────────────────────────────────────────────────────────────────
|
||||||
JWT_SECRET=replace-with-a-long-random-secret
|
REDIS_URL=redis://localhost:6379/0
|
||||||
JWT_ALGORITHM=HS256
|
|
||||||
|
# ── Auth (JWT RS256) ──────────────────────────────────────────────────────────
|
||||||
|
# Public key for optional local JWT verification (Traefik ForwardAuth handles
|
||||||
|
# this in production — services trust X-User-* headers from Traefik).
|
||||||
|
# Generate keypair:
|
||||||
|
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
||||||
|
# openssl rsa -in private.pem -pubout -out public.pem
|
||||||
|
# Paste PEM content with literal \n for newlines.
|
||||||
|
JWT_PUBLIC_KEY=
|
||||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||||
|
|
||||||
@@ -17,7 +25,6 @@ OPENAI_API_KEY=
|
|||||||
ANTHROPIC_API_KEY=
|
ANTHROPIC_API_KEY=
|
||||||
GOOGLE_API_KEY=
|
GOOGLE_API_KEY=
|
||||||
LLM_MODEL=gpt-4o
|
LLM_MODEL=gpt-4o
|
||||||
LLM_ROUTER_MODEL=gpt-4o-mini
|
|
||||||
|
|
||||||
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
||||||
STRIPE_SECRET_KEY=
|
STRIPE_SECRET_KEY=
|
||||||
@@ -42,3 +49,8 @@ QDRANT_API_KEY=
|
|||||||
# ── CORS ──────────────────────────────────────────────────────────────────────
|
# ── CORS ──────────────────────────────────────────────────────────────────────
|
||||||
# Comma-separated list parsed by Settings (override default if needed)
|
# Comma-separated list parsed by Settings (override default if needed)
|
||||||
# CORS_ORIGINS=["app://.","http://localhost:3000"]
|
# CORS_ORIGINS=["app://.","http://localhost:3000"]
|
||||||
|
|
||||||
|
# ── Langfuse (observability) ─────────────────────────────────────────────────
|
||||||
|
LANGFUSE_SECRET_KEY=sk-lf-...
|
||||||
|
LANGFUSE_PUBLIC_KEY=pk-lf-...
|
||||||
|
LANGFUSE_HOST=https://cloud.langfuse.com # or self-hosted URL
|
||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -13,6 +13,9 @@ env/
|
|||||||
# Environment variables
|
# Environment variables
|
||||||
.env
|
.env
|
||||||
|
|
||||||
|
# Cryptographic keys
|
||||||
|
*.pem
|
||||||
|
|
||||||
# IDE
|
# IDE
|
||||||
.vscode/
|
.vscode/
|
||||||
.idea/
|
.idea/
|
||||||
|
|||||||
@@ -739,7 +739,7 @@ adiuva-api/
|
|||||||
│ │
|
│ │
|
||||||
│ ├── core/ # Orchestration engine
|
│ ├── core/ # Orchestration engine
|
||||||
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
|
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
|
||||||
│ │ ├── llm.py # LiteLLM factory (get_llm, get_router_llm)
|
│ │ ├── llm.py # LiteLLM factory (get_llm)
|
||||||
│ │ ├── orchestrator.py # Intent classification & routing
|
│ │ ├── orchestrator.py # Intent classification & routing
|
||||||
│ │ └── execution_plan.py # Plan builder, templates, cache
|
│ │ └── execution_plan.py # Plan builder, templates, cache
|
||||||
│ │
|
│ │
|
||||||
|
|||||||
@@ -0,0 +1,92 @@
|
|||||||
|
"""Deprecate backend agent config tables.
|
||||||
|
|
||||||
|
The Electron client is now the source of truth for agent configuration
|
||||||
|
(directory, extract targets, batch interval, custom prompt). Backend keeps
|
||||||
|
billing checks and trigger/run logs only.
|
||||||
|
|
||||||
|
Revision ID: 9a1f2d0b6c7e
|
||||||
|
Revises: 818478c251dc
|
||||||
|
Create Date: 2026-03-16
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "9a1f2d0b6c7e"
|
||||||
|
down_revision: Union[str, None] = "818478c251dc"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
bind = op.get_bind()
|
||||||
|
inspector = sa.inspect(bind)
|
||||||
|
existing = set(inspector.get_table_names())
|
||||||
|
|
||||||
|
if "cloud_agent_configs" in existing:
|
||||||
|
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
|
||||||
|
op.drop_table("cloud_agent_configs")
|
||||||
|
|
||||||
|
if "local_agent_configs" in existing:
|
||||||
|
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
|
||||||
|
op.drop_table("local_agent_configs")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"local_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("device_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"cloud_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"provider",
|
||||||
|
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
||||||
|
sa.Column("filter_config", sa.JSON, nullable=True),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
"""Agent tool modules — imported by deep_agent.py to build sub-agent graphs."""
|
"""Expose tool modules used by deep orchestrator-worker graphs."""
|
||||||
|
|
||||||
from app.agents import timeline_agent, note_agent, project_agent, task_agent
|
from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent
|
||||||
|
|
||||||
__all__ = ["timeline_agent", "note_agent", "project_agent", "task_agent"]
|
__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"]
|
||||||
|
|||||||
85
app/agents/filesystem_agent.py
Normal file
85
app/agents/filesystem_agent.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""Filesystem agent — tools for reading local directories and files on Electron.
|
||||||
|
|
||||||
|
These tools delegate to the Electron client via ``execute_on_client()`` using
|
||||||
|
the same WS tool-call round-trip pattern as CRUD tools. The Electron app
|
||||||
|
handles actual disk I/O and responds with ``tool_result`` frames.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_directory(path: str) -> str:
|
||||||
|
"""List files and folders in a local directory on the user's device.
|
||||||
|
|
||||||
|
Returns a formatted listing of entries with name, type (file/directory),
|
||||||
|
and full path.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="list_directory",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
entries: list[dict[str, Any]] = result.get("entries", [])
|
||||||
|
if not entries:
|
||||||
|
return f"Directory '{path}' is empty or does not exist."
|
||||||
|
lines: list[str] = []
|
||||||
|
for entry in entries:
|
||||||
|
entry_type = entry.get("type", "unknown")
|
||||||
|
entry_name = entry.get("name", "")
|
||||||
|
entry_path = entry.get("path", "")
|
||||||
|
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
||||||
|
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def read_file_content(path: str) -> str:
|
||||||
|
"""Read the text content of a local file on the user's device.
|
||||||
|
|
||||||
|
Returns the file content as a string. Large files may be truncated
|
||||||
|
by the Electron client.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="read_file_content",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
content: str = result.get("content", "")
|
||||||
|
if not content:
|
||||||
|
return f"File '{path}' is empty or could not be read."
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_file_metadata(path: str) -> str:
|
||||||
|
"""Get metadata for a local file: size, creation date, modification date, extension.
|
||||||
|
|
||||||
|
Returns a formatted summary of the file's metadata.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="get_file_metadata",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
size = result.get("size", "unknown")
|
||||||
|
created = result.get("createdAt", "unknown")
|
||||||
|
modified = result.get("modifiedAt", "unknown")
|
||||||
|
extension = result.get("extension", "unknown")
|
||||||
|
name = result.get("name", path)
|
||||||
|
return (
|
||||||
|
f"File: {name}\n"
|
||||||
|
f" Extension: {extension}\n"
|
||||||
|
f" Size: {size} bytes\n"
|
||||||
|
f" Created: {created}\n"
|
||||||
|
f" Modified: {modified}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
FILESYSTEM_TOOLS: list[Any] = [
|
||||||
|
list_directory,
|
||||||
|
read_file_content,
|
||||||
|
get_file_metadata,
|
||||||
|
]
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
"""Note agent — tool definitions for Markdown note CRUD."""
|
"""Note agent — Markdown note management (list, get, create, update, delete)."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
@@ -9,14 +10,38 @@ from langchain_core.tools import tool
|
|||||||
from app.core.llm import embed
|
from app.core.llm import embed
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
NOTE_SYSTEM_PROMPT = (
|
||||||
|
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
||||||
|
"and delete Markdown notes in their workspace.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - content is always Markdown; preserve formatting when updating\n"
|
||||||
|
" - project_id is optional; link a note to a project when mentioned\n"
|
||||||
|
" - When updating, call get_note first if you need to read existing content\n"
|
||||||
|
" before appending or replacing sections\n"
|
||||||
|
" - list_notes without project_id returns all notes; scope with project_id\n"
|
||||||
|
" when the user is working within a specific project\n"
|
||||||
|
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
|
||||||
|
" - Do not fabricate note content — reflect what the user provides or what\n"
|
||||||
|
" is already in the note (retrieved via get_note)."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_notes(project_id: str = "") -> str:
|
async def list_notes(project_id: str = "") -> str:
|
||||||
"""List notes, optionally scoped to a project by project_id."""
|
"""List notes, optionally scoped to a project by project_id."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="notes",
|
table="notes",
|
||||||
filters={"projectId": project_id or None},
|
filters={"projectId": normalized_project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
@@ -105,4 +130,10 @@ async def delete_note(note_id: str) -> str:
|
|||||||
return f"Note {note_id} deleted."
|
return f"Note {note_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
NOTE_TOOLS: list[Any] = [
|
||||||
|
list_notes,
|
||||||
|
get_note,
|
||||||
|
create_note,
|
||||||
|
update_note,
|
||||||
|
delete_note,
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Project agent — tool definitions for project lifecycle CRUD."""
|
"""Project agent — full lifecycle management (list, get, create, update, archive, delete)."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -8,6 +8,22 @@ from langchain_core.tools import tool
|
|||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
PROJECT_SYSTEM_PROMPT = (
|
||||||
|
"You are a project management assistant. You help users create, find,\n"
|
||||||
|
"update, and archive projects in their workspace.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - status must be one of: active, archived\n"
|
||||||
|
" - client_id is optional; link to a client only when explicitly mentioned\n"
|
||||||
|
" - ai_summary is populated only when the user asks for a project summary;\n"
|
||||||
|
" derive it from context data — do not fabricate content\n"
|
||||||
|
" - Use list_projects for scoped queries; list_all_projects only when the\n"
|
||||||
|
" user wants a complete cross-client view including archived projects\n"
|
||||||
|
" - get_project requires a project UUID; resolve the ID first by calling\n"
|
||||||
|
" list_projects if you only have a project name\n"
|
||||||
|
" - Prefer archiving (update_project status=archived) over deletion;\n"
|
||||||
|
" only call delete_project when the user explicitly confirms deletion."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_projects(
|
async def list_projects(
|
||||||
@@ -117,4 +133,11 @@ async def delete_project(project_id: str) -> str:
|
|||||||
return f"Project {project_id} permanently deleted."
|
return f"Project {project_id} permanently deleted."
|
||||||
|
|
||||||
|
|
||||||
|
PROJECT_TOOLS: list[Any] = [
|
||||||
|
list_projects,
|
||||||
|
list_all_projects,
|
||||||
|
get_project,
|
||||||
|
create_project,
|
||||||
|
update_project,
|
||||||
|
delete_project,
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,14 +1,40 @@
|
|||||||
"""Task agent — tool definitions for task and task comment CRUD."""
|
"""Task agent — full CRUD for tasks and task comments."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
TASK_SYSTEM_PROMPT = (
|
||||||
|
"You are a task management assistant for a project workspace.\n"
|
||||||
|
"You create, update, list, and track tasks and their comments.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - status must be one of: todo, in_progress, done\n"
|
||||||
|
" - priority must be one of: high, medium, low\n"
|
||||||
|
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
|
||||||
|
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
|
||||||
|
" - project_id is optional; link to a project when the user mentions one\n"
|
||||||
|
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
||||||
|
" did not explicitly request; 0 otherwise\n"
|
||||||
|
" - is_ai_suggested: 1 only when proactively proposing a task the user did not explicitly request; 0 otherwise\n"
|
||||||
|
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
||||||
|
" - For update_task, use -1 for integer fields you do not want to change\n"
|
||||||
|
" - Always confirm the action in plain, user-friendly language."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Task tools ────────────────────────────────────────────────────────
|
# ── Task tools ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -22,11 +48,12 @@ async def list_tasks(
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
||||||
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="tasks",
|
table="tasks",
|
||||||
filters={
|
filters={
|
||||||
"projectId": project_id or None,
|
"projectId": normalized_project_id or None,
|
||||||
"status": status or None,
|
"status": status or None,
|
||||||
"search": search or None,
|
"search": search or None,
|
||||||
"orderBy": order_by or None,
|
"orderBy": order_by or None,
|
||||||
@@ -52,7 +79,6 @@ async def create_task(
|
|||||||
due_date: int = 0,
|
due_date: int = 0,
|
||||||
project_id: str = "",
|
project_id: str = "",
|
||||||
is_ai_suggested: int = 0,
|
is_ai_suggested: int = 0,
|
||||||
is_approved: int = 0,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a new task.
|
"""Create a new task.
|
||||||
title: task title (required)
|
title: task title (required)
|
||||||
@@ -63,7 +89,6 @@ async def create_task(
|
|||||||
due_date: Unix timestamp in milliseconds; 0 means no due date
|
due_date: Unix timestamp in milliseconds; 0 means no due date
|
||||||
project_id: optional UUID of the parent project
|
project_id: optional UUID of the parent project
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
is_approved: 0 until the user confirms; 1 when confirmed
|
|
||||||
"""
|
"""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="insert",
|
action="insert",
|
||||||
@@ -77,7 +102,6 @@ async def create_task(
|
|||||||
"dueDate": due_date or None,
|
"dueDate": due_date or None,
|
||||||
"projectId": project_id or None,
|
"projectId": project_id or None,
|
||||||
"isAiSuggested": is_ai_suggested,
|
"isAiSuggested": is_ai_suggested,
|
||||||
"isApproved": is_approved,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
@@ -97,12 +121,10 @@ async def update_task(
|
|||||||
assignees: str = "",
|
assignees: str = "",
|
||||||
due_date: int = -1,
|
due_date: int = -1,
|
||||||
project_id: str = "",
|
project_id: str = "",
|
||||||
is_approved: int = -1,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Update fields on an existing task. Only pass fields you want to change.
|
"""Update fields on an existing task. Only pass fields you want to change.
|
||||||
task_id: the task's UUID (required)
|
task_id: the task's UUID (required)
|
||||||
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
||||||
is_approved: -1 means unchanged; 0 or 1 sets the value
|
|
||||||
"""
|
"""
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if title:
|
if title:
|
||||||
@@ -119,8 +141,6 @@ async def update_task(
|
|||||||
updates["dueDate"] = due_date or None
|
updates["dueDate"] = due_date or None
|
||||||
if project_id:
|
if project_id:
|
||||||
updates["projectId"] = project_id
|
updates["projectId"] = project_id
|
||||||
if is_approved != -1:
|
|
||||||
updates["isApproved"] = is_approved
|
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="update",
|
action="update",
|
||||||
table="tasks",
|
table="tasks",
|
||||||
@@ -188,8 +208,12 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
|||||||
table="taskComments",
|
table="taskComments",
|
||||||
data={"taskId": task_id, "author": author, "content": content},
|
data={"taskId": task_id, "author": author, "content": content},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result.get("row", {})
|
||||||
return f"Comment added by {row['author']} on task {row['taskId']} (comment id: {row['id']})."
|
row_author = row.get("author", author)
|
||||||
|
# Electron payloads can vary (taskId vs task_id). Fall back to input task_id.
|
||||||
|
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
||||||
|
row_comment_id = row.get("id", "unknown")
|
||||||
|
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -199,4 +223,16 @@ async def delete_task_comment(comment_id: str) -> str:
|
|||||||
return f"Comment {comment_id} deleted."
|
return f"Comment {comment_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
# ── Agent ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
TASK_TOOLS: list[Any] = [
|
||||||
|
list_tasks,
|
||||||
|
create_task,
|
||||||
|
update_task,
|
||||||
|
delete_task,
|
||||||
|
list_tasks_due_today,
|
||||||
|
list_task_comments,
|
||||||
|
add_task_comment,
|
||||||
|
delete_task_comment,
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,21 +1,45 @@
|
|||||||
"""Timeline agent — tool definitions for project milestone CRUD."""
|
"""Timeline agent — project milestone management (list, create, update, delete)."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
TIMELINE_SYSTEM_PROMPT = (
|
||||||
|
"You are a project timeline assistant. Timelines are milestone dates that\n"
|
||||||
|
"track progress on a project — they are not calendar events.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
||||||
|
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
|
||||||
|
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
||||||
|
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
||||||
|
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
||||||
|
" - For update_timeline, use -1 for integer fields you do not want to change\n"
|
||||||
|
" - Listing without a project_id returns all timelines across projects\n"
|
||||||
|
" - Always echo the title and formatted date in your confirmation."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_timelines(project_id: str = "") -> str:
|
async def list_timelines(project_id: str = "") -> str:
|
||||||
"""List timelines. Provide project_id to scope to a specific project."""
|
"""List timelines. Provide project_id to scope to a specific project."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="timelines",
|
table="timelines",
|
||||||
filters={"projectId": project_id or None},
|
filters={"projectId": normalized_project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
@@ -30,14 +54,12 @@ async def create_timeline(
|
|||||||
title: str,
|
title: str,
|
||||||
date: int,
|
date: int,
|
||||||
is_ai_suggested: int = 0,
|
is_ai_suggested: int = 0,
|
||||||
is_approved: int = 0,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a project timeline (milestone).
|
"""Create a project timeline (milestone).
|
||||||
project_id: REQUIRED UUID of the parent project
|
project_id: REQUIRED UUID of the parent project
|
||||||
title: descriptive name for the milestone
|
title: descriptive name for the milestone
|
||||||
date: Unix timestamp in milliseconds
|
date: Unix timestamp in milliseconds
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
is_approved: 0 until the user confirms
|
|
||||||
"""
|
"""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="insert",
|
action="insert",
|
||||||
@@ -47,7 +69,6 @@ async def create_timeline(
|
|||||||
"title": title,
|
"title": title,
|
||||||
"date": date,
|
"date": date,
|
||||||
"isAiSuggested": is_ai_suggested,
|
"isAiSuggested": is_ai_suggested,
|
||||||
"isApproved": is_approved,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
@@ -59,20 +80,16 @@ async def update_timeline(
|
|||||||
timeline_id: str,
|
timeline_id: str,
|
||||||
title: str = "",
|
title: str = "",
|
||||||
date: int = -1,
|
date: int = -1,
|
||||||
is_approved: int = -1,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Update a timeline. Only pass fields that should change.
|
"""Update a timeline. Only pass fields that should change.
|
||||||
timeline_id: UUID of the timeline (required)
|
timeline_id: UUID of the timeline (required)
|
||||||
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
||||||
is_approved: -1 means unchanged; 0 or 1 sets the approval state
|
|
||||||
"""
|
"""
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if title:
|
if title:
|
||||||
updates["title"] = title
|
updates["title"] = title
|
||||||
if date != -1:
|
if date != -1:
|
||||||
updates["date"] = date
|
updates["date"] = date
|
||||||
if is_approved != -1:
|
|
||||||
updates["isApproved"] = is_approved
|
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="update",
|
action="update",
|
||||||
table="timelines",
|
table="timelines",
|
||||||
@@ -89,4 +106,9 @@ async def delete_timeline(timeline_id: str) -> str:
|
|||||||
return f"Timeline {timeline_id} deleted."
|
return f"Timeline {timeline_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
TIMELINE_TOOLS: list[Any] = [
|
||||||
|
list_timelines,
|
||||||
|
create_timeline,
|
||||||
|
update_timeline,
|
||||||
|
delete_timeline,
|
||||||
|
]
|
||||||
|
|||||||
@@ -55,12 +55,15 @@ async def get_current_user(
|
|||||||
raise credentials_exc
|
raise credentials_exc
|
||||||
|
|
||||||
# Live tier lookup — subscription row is the authoritative source.
|
# Live tier lookup — subscription row is the authoritative source.
|
||||||
|
# In dev, fall back to 'power' (unlimited) so quota limits don't
|
||||||
|
# block local development when no Stripe subscription exists.
|
||||||
from app.models import Subscription, User # noqa: PLC0415
|
from app.models import Subscription, User # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
tier: str = result.scalar_one_or_none() or "free"
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
|
tier: str = result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
# Fetch name/surname from user row.
|
# Fetch name/surname from user row.
|
||||||
user_result = await db.execute(
|
user_result = await db.execute(
|
||||||
|
|||||||
@@ -1,54 +1,40 @@
|
|||||||
"""Chatbot Journey endpoints — guided conversation to build an agent prompt_template.
|
"""Chatbot Journey — WS-based guided conversation to build an agent prompt_template.
|
||||||
|
|
||||||
Endpoints:
|
The journey is driven entirely through WebSocket frames (no REST endpoints).
|
||||||
POST /agents/journey/start — start a new journey session
|
The device WS handler dispatches ``journey_start`` and ``journey_message``
|
||||||
POST /agents/journey/message — continue the conversation
|
frames to the functions exported here.
|
||||||
|
|
||||||
Sessions are stored in-memory with a 30-minute TTL. Stale entries are
|
|
||||||
cleaned up lazily on access. Upgrade to Redis for multi-instance deployments.
|
|
||||||
|
|
||||||
Journey flow:
|
Journey flow:
|
||||||
1. Client sends ``{ agent_type, agent_id? }`` to ``/start``.
|
1. FE sends ``journey_start`` frame with basic agent config (directory,
|
||||||
2. Server creates a session, calls the LLM with a contextual system prompt,
|
data_types, schedule).
|
||||||
and returns the first question.
|
2. Server creates an in-memory session, sets up a WS executor so the
|
||||||
3. Client sends follow-up messages to ``/message``.
|
setup LLM can use file-system tools, does a first directory scrape,
|
||||||
4. After 3-5 turns the LLM wraps up by emitting a ``prompt_template`` block
|
and sends back a ``journey_reply`` with the first question.
|
||||||
delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
|
3. FE sends ``journey_message`` frames for each user reply.
|
||||||
5. Server parses the block, sets ``done=True``, and returns the template.
|
4. Server appends the user message, calls the LLM (which may read files
|
||||||
|
via tools), and sends back a ``journey_reply``.
|
||||||
The ``prompt_template`` from the final response is meant to be stored in
|
5. After 3-5 turns the LLM wraps up by emitting a ``prompt_template``
|
||||||
``LocalAgentConfig.prompt_template`` or ``CloudAgentConfig.prompt_template``
|
block delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
|
||||||
by the Electron client (via the agent CRUD endpoints).
|
6. Server parses the block, sends ``journey_reply`` with ``done=True``
|
||||||
|
and the template. FE stores it locally.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
||||||
from app.core.llm import get_llm
|
from app.core.llm import get_llm
|
||||||
from app.db import get_session
|
|
||||||
from app.models import CloudAgentConfig, LocalAgentConfig
|
|
||||||
from app.schemas import (
|
|
||||||
JourneyMessageRequest,
|
|
||||||
JourneyResponse,
|
|
||||||
JourneyStartRequest,
|
|
||||||
UserProfile,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/agents/journey", tags=["agents"])
|
|
||||||
|
|
||||||
# ── Session TTL ───────────────────────────────────────────────────────────
|
# ── Session TTL ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||||
@@ -57,18 +43,25 @@ _SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
|||||||
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
||||||
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
||||||
|
|
||||||
# Maximum number of conversation turns before the LLM is nudged to wrap up.
|
# Minimum turns before we consider nudging the LLM to wrap up.
|
||||||
_MAX_TURNS: int = 5
|
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
||||||
|
# Hard cap to avoid infinite loops (safety net, not the primary stopping criterion).
|
||||||
|
_MAX_TURNS: int = 15
|
||||||
|
# Max tool-calling steps per LLM invocation.
|
||||||
|
_MAX_TOOL_STEPS: int = 6
|
||||||
|
|
||||||
# ── In-memory session store ───────────────────────────────────────────────
|
# ── In-memory session store ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _JourneySession:
|
class JourneySession:
|
||||||
session_id: str
|
session_id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
agent_type: str # "local" | "cloud"
|
agent_type: str # "local" | "cloud"
|
||||||
|
directory: str
|
||||||
|
data_types: list[str]
|
||||||
history: list[dict[str, Any]] = field(default_factory=list)
|
history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
system_prompt: str = ""
|
||||||
created_at: float = field(default_factory=time.monotonic)
|
created_at: float = field(default_factory=time.monotonic)
|
||||||
|
|
||||||
def is_expired(self) -> bool:
|
def is_expired(self) -> bool:
|
||||||
@@ -76,67 +69,84 @@ class _JourneySession:
|
|||||||
|
|
||||||
|
|
||||||
# session_id → session
|
# session_id → session
|
||||||
_sessions: dict[str, _JourneySession] = {}
|
_sessions: dict[str, JourneySession] = {}
|
||||||
|
|
||||||
|
|
||||||
def _get_session(session_id: str, user_id: str) -> _JourneySession:
|
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
||||||
"""Retrieve session; raise 404 on missing, expired, or wrong owner."""
|
"""Retrieve session; return None on missing, expired, or wrong owner."""
|
||||||
s = _sessions.get(session_id)
|
s = _sessions.get(session_id)
|
||||||
if s is None or s.is_expired():
|
if s is None or s.is_expired():
|
||||||
_sessions.pop(session_id, None)
|
_sessions.pop(session_id, None)
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
return None
|
||||||
if s.user_id != user_id:
|
if s.user_id != user_id:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
return None
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
# ── System prompt builder ─────────────────────────────────────────────────
|
# ── System prompt builder ─────────────────────────────────────────────────
|
||||||
|
|
||||||
_LOCAL_PREAMBLE = """\
|
|
||||||
What kind of files are in the directories you want to monitor? \
|
|
||||||
(for example: emails saved as .eml, documents in .pdf or .txt, markdown notes, etc.)"""
|
|
||||||
|
|
||||||
_CLOUD_PREAMBLE = """\
|
|
||||||
What kind of emails or messages should I look for? \
|
|
||||||
(for example: client communications, invoices, meeting notes, project updates, etc.)"""
|
|
||||||
|
|
||||||
_SYSTEM_PROMPT_TEMPLATE = """\
|
_SYSTEM_PROMPT_TEMPLATE = """\
|
||||||
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||||
Your job is to understand exactly what data the user wants to extract from their {source_description} \
|
Your job is to understand exactly what data the user wants to extract from their
|
||||||
and produce a detailed prompt_template that a separate AI will use as its instruction set.
|
local directory and produce a detailed prompt_template that a separate AI will use
|
||||||
|
as its instruction set.
|
||||||
|
|
||||||
Ask concise, focused questions one at a time. Cover these topics (not necessarily in this order):
|
The extraction agent already has this base behaviour built in:
|
||||||
1. The type and format of the source content.
|
- Reads each file using file-system tools.
|
||||||
2. Which data types to extract: tasks, notes, timelines, and/or projects.
|
- Creates records (tasks, notes, timelines, projects) via CRUD tools.
|
||||||
3. How fields should be mapped (e.g. email subject → task title).
|
- Sets isAiSuggested=1 on every new record.
|
||||||
4. Priority or status rules (e.g. "urgent" keyword → high priority).
|
- Only extracts data explicitly present in the files — it never invents information.
|
||||||
5. Any special handling, date extraction, or exclusions.
|
The user's custom prompt is appended AFTER this base behaviour, so focus on
|
||||||
|
what to look for and how to map it — not on the general extraction mechanics.
|
||||||
|
|
||||||
After 3-5 questions (when you have enough information), output the final prompt_template between \
|
You have access to file-system tools to explore the user's directory:
|
||||||
these exact markers on their own lines:
|
- list_directory: to see folder structure
|
||||||
|
- read_file_content: to peek at file contents
|
||||||
|
- get_file_metadata: to check file info
|
||||||
|
|
||||||
|
The user's configured directory is: {directory}
|
||||||
|
Target data types: {data_types}
|
||||||
|
|
||||||
|
IMPORTANT — project assignment is handled automatically by the main agent runner
|
||||||
|
before the custom prompt is ever used. You MUST NOT ask the user about projects,
|
||||||
|
projectId, or how to link records to projects. Never include projectId logic or
|
||||||
|
project creation instructions in the generated prompt_template.
|
||||||
|
|
||||||
|
Start by exploring the directory to understand its structure. Then ask concise,
|
||||||
|
focused questions one at a time. Cover these topics (not necessarily in this order):
|
||||||
|
1. The type and format of the source content (confirmed by your exploration).
|
||||||
|
2. How fields should be mapped (e.g. filename → task title).
|
||||||
|
3. Priority or status rules (e.g. "urgent" keyword → high priority).
|
||||||
|
4. Any special handling, date extraction, or exclusions.
|
||||||
|
|
||||||
|
Once you reach 90% confidence, output the final prompt_template between these exact
|
||||||
|
markers on their own lines:
|
||||||
|
|
||||||
{template_start}
|
{template_start}
|
||||||
<the complete extraction prompt here>
|
<the complete extraction prompt here>
|
||||||
{template_end}
|
{template_end}
|
||||||
|
|
||||||
The prompt_template must be a self-contained instruction for an AI that receives a document/email/message \
|
The prompt_template must be a self-contained instruction for an AI that reads files
|
||||||
and must return a JSON array of records in this shape:
|
and must perform CRUD operations using tools to create records. It should specify:
|
||||||
[{{ "table": "<tasks|notes|timelines|projects>", "data": {{ <field: value> }} }}, ...]
|
- What entity types to create (tasks, notes, timelines) — never projects.
|
||||||
|
- How to map file content to record fields (camelCase: title, status, priority,
|
||||||
|
dueDate, content, etc.) — never include projectId.
|
||||||
|
- That isAiSuggested must be set to 1 on every new record.
|
||||||
|
- Concrete examples of mappings based on what you discovered in the directory.
|
||||||
|
|
||||||
Rules for the generated template:
|
|
||||||
- Be explicit about field names (camelCase: title, status, priority, dueDate, projectId, content, etc.).
|
|
||||||
- Include concrete examples of mappings.
|
|
||||||
- Mention that Electron adds id/createdAt/updatedAt automatically.
|
|
||||||
- Set isAiSuggested: true and isApproved: false on every record.
|
|
||||||
{existing_section}\
|
{existing_section}\
|
||||||
Do not ask more than {max_turns} questions total. Start with your first question now.\
|
Keep asking clarifying questions until you are at least 90% confident you have
|
||||||
|
enough information to generate an accurate prompt_template. Once you reach that
|
||||||
|
confidence level, stop asking and produce the final template immediately.
|
||||||
|
Begin by exploring the directory, then ask your first question.\
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
|
def _build_system_prompt(
|
||||||
source_description = (
|
directory: str,
|
||||||
"files in local directories" if agent_type == "local" else "emails and messages from cloud providers"
|
data_types: list[str],
|
||||||
)
|
existing_template: str | None = None,
|
||||||
|
) -> str:
|
||||||
existing_section = (
|
existing_section = (
|
||||||
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
||||||
f"---\n{existing_template}\n---\n"
|
f"---\n{existing_template}\n---\n"
|
||||||
@@ -144,18 +154,14 @@ def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
|
|||||||
else ""
|
else ""
|
||||||
)
|
)
|
||||||
return _SYSTEM_PROMPT_TEMPLATE.format(
|
return _SYSTEM_PROMPT_TEMPLATE.format(
|
||||||
source_description=source_description,
|
directory=directory,
|
||||||
|
data_types=", ".join(data_types),
|
||||||
template_start=_TEMPLATE_START,
|
template_start=_TEMPLATE_START,
|
||||||
template_end=_TEMPLATE_END,
|
template_end=_TEMPLATE_END,
|
||||||
existing_section=existing_section,
|
existing_section=existing_section,
|
||||||
max_turns=_MAX_TURNS,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _first_question(agent_type: str) -> str:
|
|
||||||
return _LOCAL_PREAMBLE if agent_type == "local" else _CLOUD_PREAMBLE
|
|
||||||
|
|
||||||
|
|
||||||
# ── Template extraction ───────────────────────────────────────────────────
|
# ── Template extraction ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -168,11 +174,37 @@ def _extract_template(text: str) -> str | None:
|
|||||||
return text[start_idx:end_idx].strip() or None
|
return text[start_idx:end_idx].strip() or None
|
||||||
|
|
||||||
|
|
||||||
# ── LLM call ─────────────────────────────────────────────────────────────
|
# ── LLM call with tool support ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def _call_llm(system_prompt: str, history: list[dict[str, Any]]) -> str:
|
def _as_text(content: Any) -> str:
|
||||||
"""Build LangChain messages from history and invoke the LLM."""
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_llm_with_tools(
|
||||||
|
system_prompt: str,
|
||||||
|
history: list[dict[str, Any]],
|
||||||
|
tools: list[Any],
|
||||||
|
) -> str:
|
||||||
|
"""Build LangChain messages from history and invoke the LLM with tools.
|
||||||
|
|
||||||
|
Handles tool-calling loops: if the LLM calls tools, execute them and
|
||||||
|
continue until a final text response is produced.
|
||||||
|
"""
|
||||||
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||||
for turn in history:
|
for turn in history:
|
||||||
if turn["role"] == "user":
|
if turn["role"] == "user":
|
||||||
@@ -181,137 +213,194 @@ async def _call_llm(system_prompt: str, history: list[dict[str, Any]]) -> str:
|
|||||||
messages.append(AIMessage(content=turn["content"]))
|
messages.append(AIMessage(content=turn["content"]))
|
||||||
|
|
||||||
llm = get_llm(model=None, temperature=0.4)
|
llm = get_llm(model=None, temperature=0.4)
|
||||||
response = await llm.ainvoke(messages)
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
return response.content # type: ignore[return-value]
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
|
for _ in range(_MAX_TOOL_STEPS):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
return _as_text(response.content)
|
||||||
|
|
||||||
|
for call in response.tool_calls:
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"agent_setup: journey tool_call name=%s args=%s",
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:500],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_setup: journey tool_result name=%s output=%s",
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:800],
|
||||||
|
)
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
# Fallback: exceeded max steps.
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
return _as_text(final.content)
|
||||||
|
|
||||||
|
|
||||||
# ── Existing-config loader ────────────────────────────────────────────────
|
# ── Journey handlers (called from device_ws.py) ──────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def _load_existing_template(
|
async def handle_journey_start(
|
||||||
agent_id: str,
|
|
||||||
user_id: str,
|
user_id: str,
|
||||||
db: AsyncSession,
|
frame: dict[str, Any],
|
||||||
) -> str | None:
|
) -> dict[str, Any]:
|
||||||
"""Return the prompt_template of an existing agent config, or None."""
|
"""Handle a ``journey_start`` WS frame.
|
||||||
# Try local first, then cloud.
|
|
||||||
local_result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(
|
|
||||||
LocalAgentConfig.id == agent_id,
|
|
||||||
LocalAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
local = local_result.scalar_one_or_none()
|
|
||||||
if local is not None:
|
|
||||||
return local.prompt_template
|
|
||||||
|
|
||||||
cloud_result = await db.execute(
|
Creates a session, runs the setup LLM with directory exploration,
|
||||||
select(CloudAgentConfig).where(
|
and returns the ``journey_reply`` payload.
|
||||||
CloudAgentConfig.id == agent_id,
|
|
||||||
CloudAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
cloud = cloud_result.scalar_one_or_none()
|
|
||||||
return cloud.prompt_template if cloud is not None else None
|
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/start", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
|
||||||
async def start_journey(
|
|
||||||
body: JourneyStartRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> JourneyResponse:
|
|
||||||
"""Start a new Chatbot Journey session.
|
|
||||||
|
|
||||||
If ``agent_id`` is provided the session is pre-seeded with the existing
|
|
||||||
agent's ``prompt_template`` so the user can refine it.
|
|
||||||
"""
|
"""
|
||||||
# Load existing template (may be None).
|
agent_type = frame.get("agent_type", "local")
|
||||||
existing_template: str | None = None
|
directory = frame.get("directory", "")
|
||||||
if body.agent_id:
|
data_types = frame.get("data_types", [])
|
||||||
existing_template = await _load_existing_template(body.agent_id, current_user.id, db)
|
existing_template = frame.get("existing_template")
|
||||||
# If agent_id was given but not found, proceed without seeding (don't 404 —
|
|
||||||
# the user may be starting a fresh journey for a not-yet-persisted config).
|
|
||||||
|
|
||||||
system_prompt = _build_system_prompt(body.agent_type, existing_template)
|
# Use the session_id provided by the FE so the reply matches the
|
||||||
first_question = _first_question(body.agent_type)
|
# listener key; fall back to a generated one if absent.
|
||||||
|
session_id = frame.get("session_id") or str(uuid.uuid4())
|
||||||
|
system_prompt = _build_system_prompt(directory, data_types, existing_template)
|
||||||
|
|
||||||
session_id = str(uuid.uuid4())
|
session = JourneySession(
|
||||||
session = _JourneySession(
|
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
user_id=current_user.id,
|
user_id=user_id,
|
||||||
agent_type=body.agent_type,
|
agent_type=agent_type,
|
||||||
# Seed history with the AI's first question so it stays consistent.
|
directory=directory,
|
||||||
history=[{"role": "assistant", "content": first_question}],
|
data_types=data_types,
|
||||||
|
system_prompt=system_prompt,
|
||||||
)
|
)
|
||||||
# Store the system prompt inside the session for reuse in /message.
|
|
||||||
session.__dict__["_system_prompt"] = system_prompt # type: ignore[index]
|
# The LLM will explore the directory using FILESYSTEM_TOOLS via the
|
||||||
|
# ws_context executor (already set by the WS handler before calling us).
|
||||||
|
# Seed with an initial user message — some providers (e.g. GitHub Copilot)
|
||||||
|
# require at least one user/input message to be present.
|
||||||
|
seed_history: list[dict[str, Any]] = [
|
||||||
|
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
||||||
|
]
|
||||||
|
ai_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history=seed_history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
)
|
||||||
|
|
||||||
|
session.history.extend(seed_history)
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
_sessions[session_id] = session
|
_sessions[session_id] = session
|
||||||
|
|
||||||
logger.info("Journey session %s started for user %s (agent_type=%s)", session_id, current_user.id, body.agent_type)
|
logger.info(
|
||||||
return JourneyResponse(session_id=session_id, message=first_question, done=False)
|
"agent_setup: journey session %s started for user %s (directory=%s)",
|
||||||
|
session_id,
|
||||||
|
user_id,
|
||||||
|
directory,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if the LLM produced the template on the first turn (unlikely but possible).
|
||||||
@router.post("/message", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
|
||||||
async def send_journey_message(
|
|
||||||
body: JourneyMessageRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> JourneyResponse:
|
|
||||||
"""Send a message in an existing Chatbot Journey session.
|
|
||||||
|
|
||||||
The server appends the user's message to the conversation history,
|
|
||||||
calls the LLM, and appends the AI reply. When the LLM wraps up with a
|
|
||||||
``prompt_template`` block the response includes ``done=True`` and the
|
|
||||||
extracted template.
|
|
||||||
"""
|
|
||||||
session = _get_session(body.session_id, current_user.id)
|
|
||||||
system_prompt: str = session.__dict__.get("_system_prompt", _build_system_prompt(session.agent_type, None)) # type: ignore[assignment]
|
|
||||||
|
|
||||||
# Append user turn to history.
|
|
||||||
session.history.append({"role": "user", "content": body.message})
|
|
||||||
|
|
||||||
# Call the LLM with the full conversation so far.
|
|
||||||
ai_reply = await _call_llm(system_prompt, session.history)
|
|
||||||
|
|
||||||
# Append AI turn.
|
|
||||||
session.history.append({"role": "assistant", "content": ai_reply})
|
|
||||||
|
|
||||||
# Check if the LLM produced the final template.
|
|
||||||
prompt_template = _extract_template(ai_reply)
|
prompt_template = _extract_template(ai_reply)
|
||||||
done = prompt_template is not None
|
done = prompt_template is not None
|
||||||
|
|
||||||
# Strip the sentinel markers from the message shown to the user.
|
|
||||||
display_message = ai_reply
|
display_message = ai_reply
|
||||||
if done:
|
if done:
|
||||||
display_message = (
|
display_message = (
|
||||||
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
or "Here is your agent configuration. You can save it or continue refining."
|
or "Here is your agent configuration. You can save it or continue refining."
|
||||||
)
|
)
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
|
||||||
if done:
|
return {
|
||||||
logger.info("Journey session %s completed for user %s", body.session_id, current_user.id)
|
"type": "journey_reply",
|
||||||
# Clean up the session immediately on completion.
|
"session_id": session_id,
|
||||||
_sessions.pop(body.session_id, None)
|
"message": display_message,
|
||||||
else:
|
"done": done,
|
||||||
# Nudge the LLM to wrap up after max turns.
|
"prompt_template": prompt_template,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_journey_message(
|
||||||
|
user_id: str,
|
||||||
|
frame: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Handle a ``journey_message`` WS frame.
|
||||||
|
|
||||||
|
Appends the user message, calls the LLM, and returns the
|
||||||
|
``journey_reply`` payload.
|
||||||
|
"""
|
||||||
|
session_id = frame.get("session_id", "")
|
||||||
|
message = frame.get("message", "")
|
||||||
|
|
||||||
|
session = get_journey_session(session_id, user_id)
|
||||||
|
if session is None:
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": "Journey session not found or expired. Please start a new setup.",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Append user turn.
|
||||||
|
session.history.append({"role": "user", "content": message})
|
||||||
|
|
||||||
|
# Call the LLM with tools.
|
||||||
|
ai_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=session.system_prompt,
|
||||||
|
history=session.history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
)
|
||||||
|
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
|
||||||
|
# Check if the LLM produced the final template.
|
||||||
|
prompt_template = _extract_template(ai_reply)
|
||||||
|
done = prompt_template is not None
|
||||||
|
|
||||||
|
# If the LLM didn't produce a template, nudge it once it has asked enough
|
||||||
|
# questions (>= _MIN_TURNS_BEFORE_NUDGE) or hits the hard safety cap.
|
||||||
|
if not done:
|
||||||
turns = sum(1 for t in session.history if t["role"] == "user")
|
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||||
if turns >= _MAX_TURNS:
|
if turns >= _MAX_TURNS:
|
||||||
# Add a system-level nudge as a hidden user message.
|
nudge_content = (
|
||||||
session.history.append({
|
|
||||||
"role": "user",
|
|
||||||
"content": (
|
|
||||||
"[System: You have enough information. Please generate the final "
|
"[System: You have enough information. Please generate the final "
|
||||||
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
||||||
),
|
|
||||||
})
|
|
||||||
|
|
||||||
return JourneyResponse(
|
|
||||||
session_id=body.session_id,
|
|
||||||
message=display_message,
|
|
||||||
done=done,
|
|
||||||
prompt_template=prompt_template,
|
|
||||||
)
|
)
|
||||||
|
session.history.append({"role": "user", "content": nudge_content})
|
||||||
|
|
||||||
|
nudge_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=session.system_prompt,
|
||||||
|
history=session.history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
)
|
||||||
|
session.history.append({"role": "assistant", "content": nudge_reply})
|
||||||
|
|
||||||
|
prompt_template = _extract_template(nudge_reply)
|
||||||
|
if prompt_template is not None:
|
||||||
|
done = True
|
||||||
|
ai_reply = nudge_reply
|
||||||
|
|
||||||
|
display_message = ai_reply
|
||||||
|
if done:
|
||||||
|
display_message = (
|
||||||
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
|
if _TEMPLATE_START in ai_reply
|
||||||
|
else "Here is your agent configuration. You can save it or continue refining."
|
||||||
|
)
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": display_message,
|
||||||
|
"done": done,
|
||||||
|
"prompt_template": prompt_template,
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,45 +1,36 @@
|
|||||||
"""Agent CRUD routes: local directory agents and cloud connector agents.
|
"""Agent routes.
|
||||||
|
|
||||||
Endpoints:
|
Backend responsibilities are intentionally minimal:
|
||||||
GET /agents/catalog — hardcoded agent type catalog
|
GET /agents/catalog — static catalog for UI display
|
||||||
GET /agents/local — list user's local agent configs
|
POST /agents/can-create — billing eligibility check
|
||||||
POST /agents/local — create local agent (tier-gated)
|
POST /agents/trigger — trigger a local agent run
|
||||||
PUT /agents/local/{agent_id} — partial update (ownership check)
|
|
||||||
DELETE /agents/local/{agent_id} — delete + cascade run logs
|
Agent configuration is owned by the Electron app and is not persisted
|
||||||
GET /agents/cloud — list user's cloud agent configs
|
in backend agent-config tables.
|
||||||
POST /agents/cloud — create cloud agent (tier-gated)
|
|
||||||
PUT /agents/cloud/{agent_id} — partial update (ownership check)
|
|
||||||
DELETE /agents/cloud/{agent_id} — delete + cascade run logs
|
|
||||||
GET /agents/runs — paginated run logs (agent_id, page, limit)
|
|
||||||
POST /agents/{agent_id}/run — manual trigger stub (dispatch in Step 3.4)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime
|
import uuid
|
||||||
from typing import Any
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from pydantic import BaseModel
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy import func, or_, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.billing.tier_manager import FEATURES
|
from app.billing.tier_manager import FEATURES
|
||||||
from app.core.agent_runner import run_cloud_agent, run_local_agent
|
from app.core.agent_runner import is_agent_running, run_local_agent
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
from app.models import AgentRunLog, LocalAgentConfig
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
AgentCatalogItem,
|
AgentCatalogItem,
|
||||||
|
AgentCreationCheckRequest,
|
||||||
|
AgentCreationCheckResponse,
|
||||||
AgentRunLogResponse,
|
AgentRunLogResponse,
|
||||||
CloudAgentConfigCreate,
|
AgentTriggerRequest,
|
||||||
CloudAgentConfigResponse,
|
|
||||||
CloudAgentConfigUpdate,
|
|
||||||
LocalAgentConfigCreate,
|
|
||||||
LocalAgentConfigResponse,
|
|
||||||
LocalAgentConfigUpdate,
|
|
||||||
UserProfile,
|
UserProfile,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -56,39 +47,21 @@ def _dt_ms_opt(dt: datetime | None) -> int | None:
|
|||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
# ── Model → schema converters ─────────────────────────────────────────
|
def _to_data_types(values: list[str]) -> list[str]:
|
||||||
|
normalize = {
|
||||||
def _to_local_response(a: LocalAgentConfig) -> LocalAgentConfigResponse:
|
"task": "tasks", "tasks": "tasks",
|
||||||
return LocalAgentConfigResponse(
|
"note": "notes", "notes": "notes",
|
||||||
id=a.id,
|
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
||||||
name=a.name,
|
"project": "projects", "projects": "projects",
|
||||||
device_id=a.device_id,
|
}
|
||||||
directory_paths=a.directory_paths,
|
seen: set[str] = set()
|
||||||
data_types=a.data_types,
|
result: list[str] = []
|
||||||
prompt_template=a.prompt_template,
|
for v in values:
|
||||||
file_extensions=a.file_extensions,
|
mapped = normalize.get(v)
|
||||||
schedule_cron=a.schedule_cron,
|
if mapped and mapped not in seen:
|
||||||
enabled=a.enabled,
|
seen.add(mapped)
|
||||||
last_run_at=_dt_ms_opt(a.last_run_at),
|
result.append(mapped)
|
||||||
created_at=_dt_ms(a.created_at),
|
return result
|
||||||
updated_at=_dt_ms(a.updated_at),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _to_cloud_response(a: CloudAgentConfig) -> CloudAgentConfigResponse:
|
|
||||||
return CloudAgentConfigResponse(
|
|
||||||
id=a.id,
|
|
||||||
provider=a.provider, # type: ignore[arg-type]
|
|
||||||
name=a.name,
|
|
||||||
data_types=a.data_types,
|
|
||||||
prompt_template=a.prompt_template,
|
|
||||||
schedule_cron=a.schedule_cron,
|
|
||||||
filter_config=a.filter_config,
|
|
||||||
enabled=a.enabled,
|
|
||||||
last_run_at=_dt_ms_opt(a.last_run_at),
|
|
||||||
created_at=_dt_ms(a.created_at),
|
|
||||||
updated_at=_dt_ms(a.updated_at),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
||||||
@@ -105,77 +78,42 @@ def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Ownership-checked lookups ─────────────────────────────────────────
|
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
||||||
|
|
||||||
async def _get_local_agent_for_user(
|
|
||||||
agent_id: str, user_id: str, db: AsyncSession
|
|
||||||
) -> LocalAgentConfig:
|
|
||||||
result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(
|
|
||||||
LocalAgentConfig.id == agent_id,
|
|
||||||
LocalAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
record = result.scalar_one_or_none()
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_cloud_agent_for_user(
|
|
||||||
agent_id: str, user_id: str, db: AsyncSession
|
|
||||||
) -> CloudAgentConfig:
|
|
||||||
result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(
|
|
||||||
CloudAgentConfig.id == agent_id,
|
|
||||||
CloudAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
record = result.scalar_one_or_none()
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tier limit helper ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _count_enabled_agents(user_id: str, db: AsyncSession) -> int:
|
|
||||||
"""Return combined enabled local + cloud agent count for the user."""
|
|
||||||
local_count = (
|
|
||||||
await db.execute(
|
|
||||||
select(func.count(LocalAgentConfig.id)).where(
|
|
||||||
LocalAgentConfig.user_id == user_id,
|
|
||||||
LocalAgentConfig.enabled == True, # noqa: E712
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).scalar_one()
|
|
||||||
cloud_count = (
|
|
||||||
await db.execute(
|
|
||||||
select(func.count(CloudAgentConfig.id)).where(
|
|
||||||
CloudAgentConfig.user_id == user_id,
|
|
||||||
CloudAgentConfig.enabled == True, # noqa: E712
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).scalar_one()
|
|
||||||
return local_count + cloud_count
|
|
||||||
|
|
||||||
|
|
||||||
def _enforce_agent_limit(tier: str, current_count: int) -> None:
|
|
||||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||||
if limit != -1 and current_count >= limit:
|
if limit != -1 and current_count >= limit:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||||
)
|
)
|
||||||
|
return limit
|
||||||
|
|
||||||
|
|
||||||
# ── Local page schema (used by runs endpoint) ─────────────────────────
|
async def _enforce_run_frequency(
|
||||||
|
tier: str,
|
||||||
|
user_id: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
|
||||||
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
||||||
|
if limit == -1:
|
||||||
|
return # unlimited
|
||||||
|
|
||||||
class _RunsPage(BaseModel):
|
today_start = datetime.now(timezone.utc).replace(
|
||||||
total: int
|
hour=0, minute=0, second=0, microsecond=0
|
||||||
page: int
|
)
|
||||||
limit: int
|
result = await db.execute(
|
||||||
items: list[AgentRunLogResponse]
|
select(func.count(AgentRunLog.id)).where(
|
||||||
|
AgentRunLog.user_id == user_id,
|
||||||
|
AgentRunLog.started_at >= today_start,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
runs_today: int = result.scalar_one()
|
||||||
|
|
||||||
|
if runs_today >= limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Catalog ───────────────────────────────────────────────────────────
|
# ── Catalog ───────────────────────────────────────────────────────────
|
||||||
@@ -209,229 +147,61 @@ async def get_agent_catalog(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# ── Local agent CRUD ──────────────────────────────────────────────────
|
@router.post("/can-create", response_model=AgentCreationCheckResponse)
|
||||||
|
async def can_create_agent(
|
||||||
@router.get("/local", response_model=list[LocalAgentConfigResponse])
|
body: AgentCreationCheckRequest,
|
||||||
async def list_local_agents(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_session),
|
) -> AgentCreationCheckResponse:
|
||||||
) -> list[LocalAgentConfigResponse]:
|
"""Check if the user can create one more agent based on billing tier.
|
||||||
"""List all local directory agent configs owned by the authenticated user."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(LocalAgentConfig.user_id == current_user.id)
|
|
||||||
)
|
|
||||||
return [_to_local_response(a) for a in result.scalars().all()]
|
|
||||||
|
|
||||||
|
Since configuration is client-owned, the Electron app sends its current
|
||||||
@router.post("/local", response_model=LocalAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
active agent count and the backend applies tier limits.
|
||||||
async def create_local_agent(
|
|
||||||
body: LocalAgentConfigCreate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> LocalAgentConfigResponse:
|
|
||||||
"""Create a new local directory agent config.
|
|
||||||
|
|
||||||
The combined count of enabled local and cloud agents for the user is
|
|
||||||
checked against the ``batch_active`` limit for their billing tier.
|
|
||||||
"""
|
"""
|
||||||
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
|
||||||
agent = LocalAgentConfig(
|
allowed = limit == -1 or body.active_agents < limit
|
||||||
user_id=current_user.id,
|
return AgentCreationCheckResponse(
|
||||||
name=body.name,
|
allowed=allowed,
|
||||||
device_id=body.device_id,
|
tier=current_user.tier,
|
||||||
directory_paths=body.directory_paths,
|
active_agents=body.active_agents,
|
||||||
data_types=body.data_types,
|
limit=limit,
|
||||||
prompt_template=body.prompt_template,
|
|
||||||
file_extensions=body.file_extensions,
|
|
||||||
schedule_cron=body.schedule_cron,
|
|
||||||
)
|
)
|
||||||
db.add(agent)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_local_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/local/{agent_id}", response_model=LocalAgentConfigResponse)
|
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
||||||
async def update_local_agent(
|
|
||||||
agent_id: str,
|
|
||||||
body: LocalAgentConfigUpdate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> LocalAgentConfigResponse:
|
|
||||||
"""Partially update a local agent config. Only provided fields are changed."""
|
|
||||||
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
for field, value in body.model_dump(exclude_unset=True).items():
|
|
||||||
setattr(agent, field, value)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_local_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/local/{agent_id}", response_model=dict)
|
|
||||||
async def delete_local_agent(
|
|
||||||
agent_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete a local agent config. Associated run logs are cascade-deleted."""
|
|
||||||
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
await db.delete(agent)
|
|
||||||
await db.commit()
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud agent CRUD ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/cloud", response_model=list[CloudAgentConfigResponse])
|
|
||||||
async def list_cloud_agents(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> list[CloudAgentConfigResponse]:
|
|
||||||
"""List all cloud connector agent configs owned by the authenticated user."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(CloudAgentConfig.user_id == current_user.id)
|
|
||||||
)
|
|
||||||
return [_to_cloud_response(a) for a in result.scalars().all()]
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/cloud", response_model=CloudAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
|
||||||
async def create_cloud_agent(
|
|
||||||
body: CloudAgentConfigCreate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> CloudAgentConfigResponse:
|
|
||||||
"""Create a new cloud connector agent config.
|
|
||||||
|
|
||||||
The combined count of enabled local and cloud agents for the user is
|
|
||||||
checked against the ``batch_active`` limit for their billing tier.
|
|
||||||
"""
|
|
||||||
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
|
||||||
agent = CloudAgentConfig(
|
|
||||||
user_id=current_user.id,
|
|
||||||
provider=body.provider,
|
|
||||||
name=body.name,
|
|
||||||
data_types=body.data_types,
|
|
||||||
prompt_template=body.prompt_template,
|
|
||||||
oauth_token_encrypted=body.oauth_token_encrypted,
|
|
||||||
schedule_cron=body.schedule_cron,
|
|
||||||
filter_config=body.filter_config,
|
|
||||||
)
|
|
||||||
db.add(agent)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_cloud_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/cloud/{agent_id}", response_model=CloudAgentConfigResponse)
|
|
||||||
async def update_cloud_agent(
|
|
||||||
agent_id: str,
|
|
||||||
body: CloudAgentConfigUpdate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> CloudAgentConfigResponse:
|
|
||||||
"""Partially update a cloud agent config. Only provided fields are changed."""
|
|
||||||
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
for field, value in body.model_dump(exclude_unset=True).items():
|
|
||||||
setattr(agent, field, value)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_cloud_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/cloud/{agent_id}", response_model=dict)
|
|
||||||
async def delete_cloud_agent(
|
|
||||||
agent_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete a cloud agent config. Associated run logs are cascade-deleted."""
|
|
||||||
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
await db.delete(agent)
|
|
||||||
await db.commit()
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Run logs ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/runs", response_model=_RunsPage)
|
|
||||||
async def list_run_logs(
|
|
||||||
agent_id: str | None = Query(default=None),
|
|
||||||
page: int = Query(default=1, ge=1),
|
|
||||||
limit: int = Query(default=20, ge=1, le=100),
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> _RunsPage:
|
|
||||||
"""Return paginated run logs for the authenticated user.
|
|
||||||
|
|
||||||
Optionally filter by ``agent_id``. Results are ordered from newest to oldest.
|
|
||||||
"""
|
|
||||||
base_filter = [AgentRunLog.user_id == current_user.id]
|
|
||||||
if agent_id:
|
|
||||||
base_filter.append(AgentRunLog.agent_id == agent_id)
|
|
||||||
|
|
||||||
total = (
|
|
||||||
await db.execute(select(func.count(AgentRunLog.id)).where(*base_filter))
|
|
||||||
).scalar_one()
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
select(AgentRunLog)
|
|
||||||
.where(*base_filter)
|
|
||||||
.order_by(AgentRunLog.started_at.desc())
|
|
||||||
.offset((page - 1) * limit)
|
|
||||||
.limit(limit)
|
|
||||||
)
|
|
||||||
items = [_to_run_log_response(log) for log in result.scalars().all()]
|
|
||||||
|
|
||||||
return _RunsPage(total=total, page=page, limit=limit, items=items)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Manual trigger stub ───────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.post("/{agent_id}/run", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
|
||||||
async def trigger_agent_run(
|
async def trigger_agent_run(
|
||||||
agent_id: str,
|
body: AgentTriggerRequest,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_session),
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> AgentRunLogResponse:
|
) -> AgentRunLogResponse:
|
||||||
"""Manually trigger an agent run.
|
"""Trigger a local agent run using client-provided configuration."""
|
||||||
|
_enforce_agent_limit(current_user.tier, body.active_agents)
|
||||||
|
await _enforce_run_frequency(current_user.tier, current_user.id, db)
|
||||||
|
|
||||||
Looks up the agent config (local or cloud) by ID with ownership check,
|
config = LocalAgentConfig(
|
||||||
creates a run log entry with ``status="running"``, and returns it.
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=current_user.id,
|
||||||
|
device_id=body.device_id,
|
||||||
|
name="Local Directory Monitor",
|
||||||
|
directory_paths=[body.directory],
|
||||||
|
data_types=_to_data_types(body.what_to_extract),
|
||||||
|
prompt_template=body.custom_agent_prompt,
|
||||||
|
file_extensions=[],
|
||||||
|
schedule_cron=body.batch_interval,
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
Actual dispatch to the agent runner is wired in Step 3.4 once
|
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
|
||||||
``DeviceConnectionManager`` and ``agent_runner`` are available.
|
stable_agent_id = body.agent_id or config.id
|
||||||
"""
|
|
||||||
# Determine agent type by trying local first, then cloud.
|
|
||||||
# Keep the full config object so we can pass it to the agent runner.
|
|
||||||
local_config: LocalAgentConfig | None = None
|
|
||||||
cloud_config: CloudAgentConfig | None = None
|
|
||||||
|
|
||||||
local_result = await db.execute(
|
if is_agent_running(stable_agent_id):
|
||||||
select(LocalAgentConfig).where(
|
raise HTTPException(
|
||||||
LocalAgentConfig.id == agent_id,
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
LocalAgentConfig.user_id == current_user.id,
|
detail="Agent is already running. Only one run per agent is allowed at a time.",
|
||||||
)
|
)
|
||||||
)
|
|
||||||
local_config = local_result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if local_config is not None:
|
|
||||||
agent_type = "local"
|
|
||||||
else:
|
|
||||||
cloud_result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(
|
|
||||||
CloudAgentConfig.id == agent_id,
|
|
||||||
CloudAgentConfig.user_id == current_user.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
cloud_config = cloud_result.scalar_one_or_none()
|
|
||||||
if cloud_config is not None:
|
|
||||||
agent_type = "cloud"
|
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
|
||||||
|
|
||||||
run_log = AgentRunLog(
|
run_log = AgentRunLog(
|
||||||
agent_id=agent_id,
|
agent_id=stable_agent_id,
|
||||||
agent_type=agent_type,
|
agent_type="local",
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
status="running",
|
status="running",
|
||||||
)
|
)
|
||||||
@@ -439,14 +209,14 @@ async def trigger_agent_run(
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(run_log)
|
await db.refresh(run_log)
|
||||||
|
|
||||||
# Dispatch the run as a background task — returns 202 immediately.
|
run_context = {
|
||||||
if agent_type == "local" and local_config is not None:
|
"type": "agent_batch",
|
||||||
|
"run_id": run_log.id,
|
||||||
|
"agent_id": stable_agent_id,
|
||||||
|
}
|
||||||
|
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
run_local_agent(current_user.id, local_config, run_log, device_manager)
|
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
|
||||||
)
|
|
||||||
elif agent_type == "cloud" and cloud_config is not None:
|
|
||||||
asyncio.create_task(
|
|
||||||
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return _to_run_log_response(run_log)
|
return _to_run_log_response(run_log)
|
||||||
|
|||||||
@@ -10,9 +10,7 @@ from fastapi.responses import JSONResponse
|
|||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.core.deep_agent import run_home
|
from app.core.deep_agent import run_home
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.schemas import ChatRequest, UserProfile
|
||||||
from app.db import async_session
|
|
||||||
from app.schemas import ChatRequest, ChatResponse, UserProfile
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
|
|
||||||
@@ -22,21 +20,10 @@ async def chat(
|
|||||||
body: ChatRequest,
|
body: ChatRequest,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
) -> JSONResponse:
|
) -> JSONResponse:
|
||||||
"""Route a chat message through the Home deep agent (non-streaming)."""
|
"""REST fallback for home chat when websocket streaming is unavailable."""
|
||||||
async with async_session() as db:
|
response = await run_home(
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
memory_context = await memory.enrich_context(current_user.id, body.message)
|
|
||||||
|
|
||||||
context = {
|
|
||||||
**body.context.model_dump(),
|
|
||||||
**memory_context,
|
|
||||||
}
|
|
||||||
|
|
||||||
response_text = await run_home(
|
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
message=body.message,
|
message=body.message,
|
||||||
context=context,
|
context=body.context.model_dump(),
|
||||||
db_session_factory=async_session,
|
|
||||||
)
|
)
|
||||||
result = ChatResponse(response=response_text)
|
return JSONResponse(content={"response": response})
|
||||||
return JSONResponse(content=result.model_dump())
|
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ Protocol:
|
|||||||
|
|
||||||
Incoming frame dispatch:
|
Incoming frame dispatch:
|
||||||
- ``tool_result`` → resolves a pending tool-call Future.
|
- ``tool_result`` → resolves a pending tool-call Future.
|
||||||
- ``agent_data`` → enqueued in the per-run agent data queue.
|
- ``journey_start`` → starts a guided setup journey session.
|
||||||
- ``agent_complete`` → sends None sentinel to close the queue stream.
|
- ``journey_message`` → continues a journey conversation.
|
||||||
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
||||||
- unknown types → logged, ignored.
|
- unknown types → logged, ignored.
|
||||||
|
|
||||||
@@ -39,12 +39,13 @@ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
|
|
||||||
|
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.core.agent_runner import trigger_pending_runs
|
from app.core.agent_runner import trigger_pending_runs
|
||||||
|
from app.core.deep_agent import run_floating_stream, run_home_stream
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.core.deep_agent import run_home_stream, run_floating_stream
|
from app.core.output_formatter import StreamFormatter
|
||||||
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
|
||||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog
|
from app.models import AgentRunLog
|
||||||
@@ -147,37 +148,6 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
"device_ws: tool_result missing id from user=%s", user_id
|
"device_ws: tool_result missing id from user=%s", user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
elif frame_type == WsFrameType.agent_data:
|
|
||||||
run_id = frame.get("run_id")
|
|
||||||
if run_id:
|
|
||||||
try:
|
|
||||||
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
|
||||||
await queue.put(frame)
|
|
||||||
except RuntimeError:
|
|
||||||
logger.warning(
|
|
||||||
"device_ws: agent_data for unknown run user=%s run=%s",
|
|
||||||
user_id,
|
|
||||||
run_id,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"device_ws: agent_data missing run_id from user=%s", user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.agent_complete:
|
|
||||||
run_id = frame.get("run_id")
|
|
||||||
if run_id:
|
|
||||||
try:
|
|
||||||
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
|
||||||
# Sentinel: signals the agent data stream is finished.
|
|
||||||
await queue.put(None)
|
|
||||||
except RuntimeError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"device_ws: agent_complete missing run_id from user=%s", user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.home_request:
|
elif frame_type == WsFrameType.home_request:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
_handle_home_request(websocket, user_id, frame)
|
_handle_home_request(websocket, user_id, frame)
|
||||||
@@ -188,6 +158,16 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
_handle_floating_request(websocket, user_id, frame)
|
_handle_floating_request(websocket, user_id, frame)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.journey_start:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_journey_start(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.journey_message:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_journey_message(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
elif frame_type == "pong":
|
elif frame_type == "pong":
|
||||||
# Heartbeat ack — nothing to do, connection is alive.
|
# Heartbeat ack — nothing to do, connection is alive.
|
||||||
pass
|
pass
|
||||||
@@ -200,35 +180,13 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
|
|
||||||
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
_WS_TOOL_CALL_TIMEOUT = 30 # seconds to wait for Electron tool_result
|
|
||||||
|
|
||||||
|
|
||||||
async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
||||||
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
||||||
async def _executor(payload: dict) -> dict:
|
async def _executor(payload: dict) -> dict:
|
||||||
payload["type"] = WsFrameType.tool_call
|
payload["type"] = WsFrameType.tool_call
|
||||||
call_id = payload["id"]
|
|
||||||
logger.info("ws_executor: sending tool_call id=%s action=%s", call_id, payload.get("action"))
|
|
||||||
await websocket.send_text(json.dumps(payload))
|
await websocket.send_text(json.dumps(payload))
|
||||||
future = device_manager.create_pending_call(user_id, call_id)
|
future = device_manager.create_pending_call(user_id, payload["id"])
|
||||||
try:
|
return await future
|
||||||
result = await asyncio.wait_for(future, timeout=_WS_TOOL_CALL_TIMEOUT)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.error(
|
|
||||||
"ws_executor: timeout waiting for tool_result id=%s action=%s user=%s",
|
|
||||||
call_id, payload.get("action"), user_id,
|
|
||||||
)
|
|
||||||
# Clean up the pending future so it doesn't leak
|
|
||||||
conn = device_manager._connections.get(user_id)
|
|
||||||
if conn:
|
|
||||||
conn.pending_calls.pop(call_id, None)
|
|
||||||
return {"error": f"Tool call timed out after {_WS_TOOL_CALL_TIMEOUT}s", "rows": []}
|
|
||||||
logger.info("ws_executor: tool_result id=%s result_type=%s result_keys=%s",
|
|
||||||
call_id, type(result).__name__,
|
|
||||||
list(result.keys()) if isinstance(result, dict) else "N/A")
|
|
||||||
if result is None:
|
|
||||||
logger.error("ws_executor: future resolved to None for call_id=%s user=%s", call_id, user_id)
|
|
||||||
return result
|
|
||||||
return _executor
|
return _executor
|
||||||
|
|
||||||
|
|
||||||
@@ -241,14 +199,27 @@ async def _handle_home_request(
|
|||||||
request_id = frame.get("request_id") or str(uuid4())
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
logger.info(
|
||||||
|
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
message[:200],
|
||||||
|
)
|
||||||
|
|
||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(user_id, message)
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id,
|
||||||
|
message,
|
||||||
|
trace_id=request_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
context: dict = {
|
context: dict = {
|
||||||
"conversation_history": frame.get("conversation_history", []),
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
**memory_context,
|
**memory_context,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -256,12 +227,11 @@ async def _handle_home_request(
|
|||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
try:
|
try:
|
||||||
event_stream = run_home_stream(
|
event_stream = run_home_stream(user_id, message, context)
|
||||||
user_id, message, context, db_session_factory=async_session
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
)
|
|
||||||
formatter = HomeFormatter(request_id=request_id)
|
|
||||||
async for ws_frame in formatter.format(event_stream):
|
async for ws_frame in formatter.format(event_stream):
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
|
# Collect text chunks to build the full response for episode storage
|
||||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -276,7 +246,14 @@ async def _handle_home_request(
|
|||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks)
|
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: home_request_end user=%s req=%s session=%s response_chars=%d",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
len("".join(response_chunks)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -290,23 +267,37 @@ async def _handle_floating_request(
|
|||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
scope: dict = frame.get("scope", {})
|
scope: dict = frame.get("scope", {})
|
||||||
|
logger.info(
|
||||||
|
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
json.dumps(scope, ensure_ascii=True)[:200],
|
||||||
|
message[:200],
|
||||||
|
)
|
||||||
|
|
||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(user_id, message)
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id,
|
||||||
|
message,
|
||||||
|
trace_id=request_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
context: dict = {"scope": scope, **memory_context}
|
context: dict = {
|
||||||
|
"scope": scope,
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
try:
|
try:
|
||||||
event_stream = run_floating_stream(
|
event_stream = run_floating_stream(user_id, message, context)
|
||||||
user_id, message, context, scope=scope,
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
db_session_factory=async_session,
|
|
||||||
)
|
|
||||||
formatter = FloatingFormatter(request_id=request_id)
|
|
||||||
async for ws_frame in formatter.format(event_stream):
|
async for ws_frame in formatter.format(event_stream):
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
@@ -323,8 +314,72 @@ async def _handle_floating_request(
|
|||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks)
|
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
len("".join(response_chunks)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_journey_start(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a journey_start frame — explores directory and sends first question."""
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
try:
|
||||||
|
reply = await handle_journey_start(user_id, frame)
|
||||||
|
await websocket.send_text(json.dumps(reply))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: journey_start failed user=%s: %s", user_id, exc
|
||||||
|
)
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": frame.get("session_id", ""),
|
||||||
|
"message": f"Failed to start journey: {exc}",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
}))
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_journey_message(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a journey_message frame — continues the journey conversation."""
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
try:
|
||||||
|
reply = await handle_journey_message(user_id, frame)
|
||||||
|
await websocket.send_text(json.dumps(reply))
|
||||||
|
except Exception as exc:
|
||||||
|
session_id = frame.get("session_id", "")
|
||||||
|
logger.error(
|
||||||
|
"device_ws: journey_message failed user=%s session=%s: %s",
|
||||||
|
user_id, session_id, exc,
|
||||||
|
)
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": f"Journey error: {exc}",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
}))
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
|
||||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||||
@@ -360,6 +415,3 @@ async def _mark_runs_disconnected(user_id: str) -> None:
|
|||||||
user_id,
|
user_id,
|
||||||
exc,
|
exc,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"free": {
|
"free": {
|
||||||
"agents": 3,
|
"agents": 3,
|
||||||
"batch_active": 2,
|
"batch_active": 2,
|
||||||
|
"batch_runs_per_day": 5,
|
||||||
"cloud_storage_gb": 0,
|
"cloud_storage_gb": 0,
|
||||||
"backup_gb": 0,
|
"backup_gb": 0,
|
||||||
"providers": 1,
|
"providers": 1,
|
||||||
@@ -31,6 +32,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"pro": {
|
"pro": {
|
||||||
"agents": -1, # unlimited
|
"agents": -1, # unlimited
|
||||||
"batch_active": 10,
|
"batch_active": 10,
|
||||||
|
"batch_runs_per_day": 50,
|
||||||
"cloud_storage_gb": 5,
|
"cloud_storage_gb": 5,
|
||||||
"backup_gb": 5,
|
"backup_gb": 5,
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
@@ -41,6 +43,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"power": {
|
"power": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1, # unlimited
|
"batch_active": -1, # unlimited
|
||||||
|
"batch_runs_per_day": -1, # unlimited
|
||||||
"cloud_storage_gb": 25,
|
"cloud_storage_gb": 25,
|
||||||
"backup_gb": 25,
|
"backup_gb": 25,
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
@@ -51,6 +54,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"team": {
|
"team": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1,
|
"batch_active": -1,
|
||||||
|
"batch_runs_per_day": -1, # unlimited
|
||||||
"cloud_storage_gb": -1, # unlimited
|
"cloud_storage_gb": -1, # unlimited
|
||||||
"backup_gb": -1, # unlimited
|
"backup_gb": -1, # unlimited
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
@@ -77,16 +81,18 @@ class TierManager:
|
|||||||
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||||
"""Return the current billing tier for ``user_id`` from the DB.
|
"""Return the current billing tier for ``user_id`` from the DB.
|
||||||
|
|
||||||
Falls back to ``'free'`` when no subscription row exists.
|
Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod
|
||||||
|
when no subscription row exists.
|
||||||
"""
|
"""
|
||||||
from app.models import Subscription # noqa: PLC0415
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
from app.config.settings import settings # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
tier: str | None = result.scalar_one_or_none()
|
tier: str | None = result.scalar_one_or_none()
|
||||||
if tier is None or tier not in FEATURES:
|
if tier is None or tier not in FEATURES:
|
||||||
return "free"
|
return "power" if settings.ENV == "dev" else "free"
|
||||||
return tier # type: ignore[return-value]
|
return tier # type: ignore[return-value]
|
||||||
|
|
||||||
# ── Feature access ───────────────────────────────────────────────────
|
# ── Feature access ───────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ class Settings(BaseSettings):
|
|||||||
CEREBRAS_API_KEY: str = ""
|
CEREBRAS_API_KEY: str = ""
|
||||||
|
|
||||||
LLM_MODEL: str = "gpt-4o"
|
LLM_MODEL: str = "gpt-4o"
|
||||||
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
|
|
||||||
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
||||||
|
|
||||||
# GitHub Copilot OAuth token storage directory.
|
# GitHub Copilot OAuth token storage directory.
|
||||||
@@ -54,7 +53,9 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
ENV: Literal["dev", "prod"] = "dev"
|
ENV: Literal["dev", "prod"] = "dev"
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
model_config = SettingsConfigDict(
|
||||||
|
env_file=".env", env_file_encoding="utf-8", extra="ignore"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
30
app/core/agent_registry.py
Normal file
30
app/core/agent_registry.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""Minimal agent base types retained for compatibility with batch runners."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAgent(ABC):
|
||||||
|
"""Common base for non-chat agents still using the old base contract."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
user_id: str = "",
|
||||||
|
shared_memory: dict[str, Any] | None = None,
|
||||||
|
vector_store_context: list[str] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.user_id = user_id
|
||||||
|
self.shared_memory: dict[str, Any] = shared_memory or {}
|
||||||
|
self.vector_store_context: list[str] = vector_store_context or []
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_name(self) -> str: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_description(self) -> str: ...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def skills(self) -> list[str]:
|
||||||
|
return []
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -3,20 +3,15 @@
|
|||||||
Maintains in-memory state for all active Electron → backend WebSocket
|
Maintains in-memory state for all active Electron → backend WebSocket
|
||||||
connections. One connection per user (latest replaces previous).
|
connections. One connection per user (latest replaces previous).
|
||||||
|
|
||||||
The manager participates in two interaction patterns:
|
The manager handles the **tool-call round-trip** pattern:
|
||||||
|
- Backend sends ``tool_call`` frame → Electron executes the action →
|
||||||
1. **Tool-call round-trip** (bidirectional CRUD):
|
returns ``tool_result`` frame.
|
||||||
- Backend sends ``tool_call`` frame → Electron executes CRUD → returns
|
|
||||||
``tool_result`` frame.
|
|
||||||
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
||||||
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
||||||
receive the result dict from Electron.
|
receive the result dict from Electron.
|
||||||
|
|
||||||
2. **Agent-data streaming** (local directory agent runs):
|
This pattern is used by all tools (CRUD, file-system, etc.) via
|
||||||
- Backend sends ``agent_run`` frame → Electron reads files and sends
|
``execute_on_client()`` in ``ws_context.py``.
|
||||||
back a stream of ``agent_data`` frames followed by ``agent_complete``.
|
|
||||||
- ``get_agent_data_queue`` returns (or creates) an asyncio.Queue for
|
|
||||||
a specific ``run_id`` so the agent runner can iterate frames.
|
|
||||||
|
|
||||||
The ``device_manager`` module-level singleton is imported by both the
|
The ``device_manager`` module-level singleton is imported by both the
|
||||||
device WS route and the agent runner.
|
device WS route and the agent runner.
|
||||||
@@ -42,8 +37,6 @@ class DeviceConnection:
|
|||||||
device_id: str
|
device_id: str
|
||||||
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
||||||
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||||
# Per-run queues for agent_data / agent_complete frames.
|
|
||||||
agent_data_queues: dict[str, asyncio.Queue[dict | None]] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class DeviceConnectionManager:
|
class DeviceConnectionManager:
|
||||||
@@ -153,31 +146,6 @@ class DeviceConnectionManager:
|
|||||||
if fut is not None and not fut.done():
|
if fut is not None and not fut.done():
|
||||||
fut.set_result(result)
|
fut.set_result(result)
|
||||||
|
|
||||||
# ── Agent-data queue ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
def get_agent_data_queue(
|
|
||||||
self, user_id: str, run_id: str
|
|
||||||
) -> asyncio.Queue[dict | None]:
|
|
||||||
"""Return (creating if absent) the queue for *run_id* agent frames.
|
|
||||||
|
|
||||||
The agent runner reads from this queue. The device WS handler writes
|
|
||||||
to it. ``None`` is the sentinel that signals the stream is finished.
|
|
||||||
"""
|
|
||||||
conn = self._connections.get(user_id)
|
|
||||||
if conn is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"get_agent_data_queue: user {user_id!r} is not connected"
|
|
||||||
)
|
|
||||||
if run_id not in conn.agent_data_queues:
|
|
||||||
conn.agent_data_queues[run_id] = asyncio.Queue()
|
|
||||||
return conn.agent_data_queues[run_id]
|
|
||||||
|
|
||||||
def cleanup_agent_data_queue(self, user_id: str, run_id: str) -> None:
|
|
||||||
"""Remove the queue for *run_id* once a run has completed."""
|
|
||||||
conn = self._connections.get(user_id)
|
|
||||||
if conn:
|
|
||||||
conn.agent_data_queues.pop(run_id, None)
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton — import this everywhere.
|
# Module-level singleton — import this everywhere.
|
||||||
device_manager = DeviceConnectionManager()
|
device_manager = DeviceConnectionManager()
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""LLM factory — centralised model instantiation via LiteLLM.
|
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||||
|
|
||||||
Every agent and the deep-agent supervisors call ``get_llm()`` or ``get_router_llm()``
|
Every agent and the orchestrator call ``get_llm()``
|
||||||
instead of directly constructing a provider-specific class. The model string
|
instead of directly constructing a provider-specific class. The model string
|
||||||
follows the `LiteLLM model naming convention
|
follows the `LiteLLM model naming convention
|
||||||
<https://docs.litellm.ai/docs/providers>`_:
|
<https://docs.litellm.ai/docs/providers>`_:
|
||||||
@@ -11,13 +11,14 @@ follows the `LiteLLM model naming convention
|
|||||||
* Ollama: ``ollama/llama3``
|
* Ollama: ``ollama/llama3``
|
||||||
* Bedrock: ``bedrock/anthropic.claude-v2``
|
* Bedrock: ``bedrock/anthropic.claude-v2``
|
||||||
|
|
||||||
Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
|
Switch providers by changing **LLM_MODEL** in ``.env``
|
||||||
— no code changes required.
|
— no code changes required.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
import litellm
|
import litellm
|
||||||
@@ -32,6 +33,14 @@ from app.config.settings import settings
|
|||||||
# Drop them silently instead of raising UnsupportedParamsError.
|
# Drop them silently instead of raising UnsupportedParamsError.
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
|
|
||||||
|
# Some provider responses include a plain dict in the `usage` field where a
|
||||||
|
# richer Pydantic model is expected. This warning is noisy but non-fatal.
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
||||||
|
category=UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _api_key_for_model(model: str) -> str | None:
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
"""Return the most appropriate API key for the given LiteLLM model string."""
|
"""Return the most appropriate API key for the given LiteLLM model string."""
|
||||||
@@ -86,14 +95,6 @@ def get_llm(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_router_llm(
|
|
||||||
*,
|
|
||||||
temperature: float = 0,
|
|
||||||
) -> ChatOpenAI | ChatLiteLLM:
|
|
||||||
"""Return the lighter model used for intent classification / routing."""
|
|
||||||
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
|
|
||||||
|
|
||||||
|
|
||||||
async def embed(text: str) -> list[float]:
|
async def embed(text: str) -> list[float]:
|
||||||
"""Return an embedding vector for *text*.
|
"""Return an embedding vector for *text*.
|
||||||
|
|
||||||
|
|||||||
@@ -43,15 +43,21 @@ _PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
|||||||
|
|
||||||
|
|
||||||
class MemoryMiddleware:
|
class MemoryMiddleware:
|
||||||
"""Enrich agent context with memory and persist interactions after."""
|
"""Enrich orchestrator context with memory and persist interactions after."""
|
||||||
|
|
||||||
def __init__(self, db: AsyncSession) -> None:
|
def __init__(self, db: AsyncSession) -> None:
|
||||||
self._db = db
|
self._db = db
|
||||||
|
|
||||||
# ── Public API ────────────────────────────────────────────────────────────
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
|
async def enrich_context(
|
||||||
"""Build memory context dict to inject into the agent before LLM call.
|
self,
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
trace_id: str | None = None,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Build memory context dict to inject into the orchestrator before LLM call.
|
||||||
|
|
||||||
Returns a dict with keys:
|
Returns a dict with keys:
|
||||||
core_memory — {key: plaintext_value, ...}
|
core_memory — {key: plaintext_value, ...}
|
||||||
@@ -65,9 +71,21 @@ class MemoryMiddleware:
|
|||||||
|
|
||||||
core = await self._load_core(user_id, fernet)
|
core = await self._load_core(user_id, fernet)
|
||||||
associative = await self._load_associative(user_id, message, fernet)
|
associative = await self._load_associative(user_id, message, fernet)
|
||||||
episodic = await self._load_episodic(user_id, fernet)
|
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
||||||
proactive = await self._load_proactive(user_id, fernet)
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
logger.info(
|
||||||
|
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
user_dbg.get("tier") or "-",
|
||||||
|
len(core),
|
||||||
|
len(associative),
|
||||||
|
len(episodic),
|
||||||
|
len(proactive),
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"core_memory": core,
|
"core_memory": core,
|
||||||
"associative_memory": associative,
|
"associative_memory": associative,
|
||||||
@@ -81,6 +99,7 @@ class MemoryMiddleware:
|
|||||||
session_id: str,
|
session_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
response: str,
|
response: str,
|
||||||
|
trace_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Summarise and store a completed interaction in episodic memory.
|
"""Summarise and store a completed interaction in episodic memory.
|
||||||
|
|
||||||
@@ -103,11 +122,19 @@ class MemoryMiddleware:
|
|||||||
self._db.add(row)
|
self._db.add(row)
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
logger.info(
|
||||||
|
"memory: store_episode trace=%s user=%s tier=%s session=%s",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
user_dbg.get("tier") or "-",
|
||||||
|
session_id,
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
|
||||||
async def update_core(self, user_id: str, key: str, value: str) -> None:
|
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
||||||
"""Upsert a core memory key/value for a user."""
|
"""Upsert a core memory key/value for a user."""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
if fernet is None:
|
if fernet is None:
|
||||||
@@ -133,10 +160,176 @@ class MemoryMiddleware:
|
|||||||
))
|
))
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
logger.info(
|
||||||
|
"memory: update_core trace=%s user=%s tier=%s key=%s",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
user_dbg.get("tier") or "-",
|
||||||
|
key,
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
|
||||||
|
"""Return core memory as editable blocks (label/value)."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore)
|
||||||
|
.where(MemoryCore.user_id == user_id)
|
||||||
|
.order_by(MemoryCore.key.asc())
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[dict[str, str]] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append({"label": row.key, "value": plaintext})
|
||||||
|
logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def get_core_block(self, user_id: str, label: str) -> str | None:
|
||||||
|
"""Return a single core memory block value by label."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(
|
||||||
|
MemoryCore.user_id == user_id,
|
||||||
|
MemoryCore.key == label,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
|
||||||
|
return None
|
||||||
|
value = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0)
|
||||||
|
return value
|
||||||
|
|
||||||
|
async def delete_core(self, user_id: str, label: str) -> bool:
|
||||||
|
"""Delete a core memory block by label. Returns True if deleted."""
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(
|
||||||
|
MemoryCore.user_id == user_id,
|
||||||
|
MemoryCore.key == label,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
|
||||||
|
return False
|
||||||
|
|
||||||
|
await self._db.delete(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info("memory: delete_core user=%s label=%s", user_id, label)
|
||||||
|
return True
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def append_core(self, user_id: str, label: str, content: str) -> None:
|
||||||
|
"""Append content to a core block, creating it if missing."""
|
||||||
|
current = await self.get_core_block(user_id, label)
|
||||||
|
if current is None:
|
||||||
|
await self.update_core(user_id, label, content)
|
||||||
|
logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
|
||||||
|
return
|
||||||
|
await self.update_core(user_id, label, f"{current}\n{content}")
|
||||||
|
logger.info("memory: append_core user=%s label=%s created=0", user_id, label)
|
||||||
|
|
||||||
|
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
|
||||||
|
"""Replace one exact string inside a core block. Returns False if not found."""
|
||||||
|
current = await self.get_core_block(user_id, label)
|
||||||
|
if current is None or old not in current:
|
||||||
|
logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
|
||||||
|
return False
|
||||||
|
await self.update_core(user_id, label, current.replace(old, new, 1))
|
||||||
|
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
||||||
|
"""Insert a long-term archival memory entry."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
encrypted = _encrypt(fernet, content)
|
||||||
|
row = MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
content_encrypted=encrypted,
|
||||||
|
embedding=None,
|
||||||
|
entity_type=source,
|
||||||
|
entity_id=None,
|
||||||
|
)
|
||||||
|
self._db.add(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info("memory: insert_archival user=%s source=%s", user_id, source)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||||
|
"""Search archival memory (keyword fallback; semantic ranking can replace this)."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryAssociative)
|
||||||
|
.where(MemoryAssociative.user_id == user_id)
|
||||||
|
.order_by(MemoryAssociative.updated_at.desc())
|
||||||
|
.limit(100)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
needle = query.strip().lower()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
|
if plaintext is None:
|
||||||
|
continue
|
||||||
|
if not needle or needle in plaintext.lower():
|
||||||
|
out.append(plaintext)
|
||||||
|
if len(out) >= max(top_k, 1):
|
||||||
|
break
|
||||||
|
logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||||
|
"""Search recall memory (episodic summaries) by keyword."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryEpisodic)
|
||||||
|
.where(MemoryEpisodic.user_id == user_id)
|
||||||
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
|
.limit(100)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
needle = query.strip().lower()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||||
|
if plaintext is None:
|
||||||
|
continue
|
||||||
|
if not needle or needle in plaintext.lower():
|
||||||
|
out.append(plaintext)
|
||||||
|
if len(out) >= max(top_k, 1):
|
||||||
|
break
|
||||||
|
logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
||||||
|
return out
|
||||||
|
|
||||||
# ── Private helpers ───────────────────────────────────────────────────────
|
# ── Private helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
||||||
@@ -148,6 +341,16 @@ class MemoryMiddleware:
|
|||||||
return None
|
return None
|
||||||
return Fernet(user.encryption_key.encode())
|
return Fernet(user.encryption_key.encode())
|
||||||
|
|
||||||
|
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
|
||||||
|
"""Load lightweight user debug fields for trace logs."""
|
||||||
|
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None:
|
||||||
|
return {"tier": None}
|
||||||
|
return {
|
||||||
|
"tier": user.tier,
|
||||||
|
}
|
||||||
|
|
||||||
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
||||||
@@ -183,10 +386,17 @@ class MemoryMiddleware:
|
|||||||
out.append(plaintext)
|
out.append(plaintext)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
|
async def _load_episodic(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
fernet: Fernet,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> list[str]:
|
||||||
|
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
|
||||||
|
if session_id:
|
||||||
|
query = query.where(MemoryEpisodic.session_id == session_id)
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryEpisodic)
|
query
|
||||||
.where(MemoryEpisodic.user_id == user_id)
|
|
||||||
.order_by(MemoryEpisodic.created_at.desc())
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
.limit(_EPISODIC_RECENT_N)
|
.limit(_EPISODIC_RECENT_N)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,141 +1,47 @@
|
|||||||
"""Output Formatter — transforms deep-agent event streams into WS frame sequences.
|
"""Output formatter for deep-agent stream events."""
|
||||||
|
|
||||||
Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``:
|
|
||||||
* ``("token", str)`` — supervisor text token
|
|
||||||
* ``("tool_end", dict)`` — sub-agent finished: ``{name, result}``
|
|
||||||
* ``("mutations", list)`` — collected CRUD mutations for ``stream_end``
|
|
||||||
|
|
||||||
HomeFormatter:
|
|
||||||
* Streams text tokens as-is → emits ``WsStreamText``
|
|
||||||
(text may contain inline ``<type>[id,...]</type>`` entity tags
|
|
||||||
for the frontend to parse and render as interactive components)
|
|
||||||
* Attaches mutations → injects into ``WsStreamEnd``
|
|
||||||
|
|
||||||
FloatingFormatter:
|
|
||||||
* Sniffs first ``tool_end`` name → emits ``WsFloatingDomain``
|
|
||||||
* Streams text tokens → emits ``WsStreamText``
|
|
||||||
* Attaches mutations → injects into ``WsStreamEnd``
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.schemas import (
|
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
||||||
WsFloatingDomain,
|
|
||||||
WsStreamEnd,
|
|
||||||
WsStreamStart,
|
|
||||||
WsStreamText,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Map sub-agent tool name → floating domain / entity type
|
|
||||||
_AGENT_DOMAIN: dict[str, str] = {
|
|
||||||
"task_agent": "tasks",
|
|
||||||
"timeline_agent": "timelines",
|
|
||||||
"note_agent": "notes",
|
|
||||||
"project_agent": "projects",
|
|
||||||
}
|
|
||||||
|
|
||||||
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||||
|
|
||||||
|
|
||||||
class HomeFormatter:
|
class StreamFormatter:
|
||||||
"""Consumes a deep-agent event stream and yields WS frames for the Home view.
|
"""Convert `(event_type, data)` stream events into websocket frame models."""
|
||||||
|
|
||||||
Text tokens are forwarded as-is via ``WsStreamText``. The supervisor
|
|
||||||
embeds ``<type>[id1,id2]</type>`` entity tags inline — the frontend
|
|
||||||
is responsible for parsing those and rendering interactive components.
|
|
||||||
Mutations are attached to ``WsStreamEnd``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, request_id: str) -> None:
|
def __init__(self, request_id: str) -> None:
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
self._mutations: list[dict] = []
|
|
||||||
|
|
||||||
async def format(
|
async def format(
|
||||||
self,
|
self,
|
||||||
event_stream: AsyncGenerator[tuple[str, Any], None],
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
) -> AsyncGenerator[WsFrame, None]:
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
started = False
|
||||||
|
|
||||||
async for event_type, data in event_stream:
|
async for event_type, data in event_stream:
|
||||||
if event_type == "token":
|
if event_type == "floating_domain":
|
||||||
if data:
|
if isinstance(data, dict):
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=data)
|
|
||||||
|
|
||||||
elif event_type == "mutations":
|
|
||||||
self._mutations = data or []
|
|
||||||
|
|
||||||
yield WsStreamEnd(
|
|
||||||
request_id=self.request_id,
|
|
||||||
mutations=[
|
|
||||||
{"action": m["action"], "table": m["table"], "data": m["data"]}
|
|
||||||
for m in self._mutations
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FloatingFormatter:
|
|
||||||
"""Consumes a deep-agent event stream and yields WS frames for the Floating view.
|
|
||||||
|
|
||||||
Sniffs the first ``tool_end`` event name to derive the domain (e.g.
|
|
||||||
``task_agent`` → ``"tasks"``), then streams text tokens as plain
|
|
||||||
``WsStreamText``. No block parsing for floating context.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, request_id: str) -> None:
|
|
||||||
self.request_id = request_id
|
|
||||||
self._mutations: list[dict] = []
|
|
||||||
|
|
||||||
async def format(
|
|
||||||
self,
|
|
||||||
event_stream: AsyncGenerator[tuple[str, Any], None],
|
|
||||||
) -> AsyncGenerator[WsFrame, None]:
|
|
||||||
domain_sent = False
|
|
||||||
|
|
||||||
async for event_type, data in event_stream:
|
|
||||||
if event_type == "tool_end" and not domain_sent:
|
|
||||||
# Sniff domain from the first sub-agent that completes
|
|
||||||
name = data.get("name", "")
|
|
||||||
domain = _AGENT_DOMAIN.get(name, "tasks")
|
|
||||||
yield WsFloatingDomain(
|
yield WsFloatingDomain(
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
domain=domain, # type: ignore[arg-type]
|
domain=data,
|
||||||
)
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if event_type != "token":
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not started:
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
domain_sent = True
|
started = True
|
||||||
|
|
||||||
elif event_type == "token":
|
text = str(data or "")
|
||||||
if not domain_sent:
|
if text:
|
||||||
# First token arrived before any tool_end — default domain
|
yield WsStreamText(request_id=self.request_id, chunk=text)
|
||||||
yield WsFloatingDomain(
|
|
||||||
request_id=self.request_id,
|
if not started:
|
||||||
domain="tasks", # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
domain_sent = True
|
yield WsStreamEnd(request_id=self.request_id)
|
||||||
if data:
|
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=data)
|
|
||||||
|
|
||||||
elif event_type == "mutations":
|
|
||||||
self._mutations = data or []
|
|
||||||
|
|
||||||
# If no events triggered domain_sent (edge case), still emit structure
|
|
||||||
if not domain_sent:
|
|
||||||
yield WsFloatingDomain(
|
|
||||||
request_id=self.request_id,
|
|
||||||
domain="tasks", # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
|
||||||
|
|
||||||
yield WsStreamEnd(
|
|
||||||
request_id=self.request_id,
|
|
||||||
mutations=[
|
|
||||||
{"action": m["action"], "table": m["table"], "data": m["data"]}
|
|
||||||
for m in self._mutations
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -7,21 +7,18 @@ The callback sends a `tool_call` WS frame and awaits the `tool_result`.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from typing import Any, Callable, Coroutine
|
from typing import Any, Callable, Coroutine
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Holds the execute callback for the current WS session.
|
# Holds the execute callback for the current WS session.
|
||||||
# Set by the chat WS handler before the deep agent runs; cleared after.
|
# Set by the chat WS handler before the orchestrator runs; cleared after.
|
||||||
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
||||||
"_client_executor"
|
"_client_executor"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optional collector that captures raw execute_on_client results.
|
# Optional collector that captures raw execute_on_client results.
|
||||||
# Set by the deep agent tool loop to capture CRUD mutations.
|
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
|
||||||
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||||
"_tool_result_collector", default=None
|
"_tool_result_collector", default=None
|
||||||
)
|
)
|
||||||
@@ -84,17 +81,12 @@ async def execute_on_client(
|
|||||||
if limit is not None:
|
if limit is not None:
|
||||||
payload["limit"] = limit
|
payload["limit"] = limit
|
||||||
|
|
||||||
logger.info("execute_on_client: sending payload action=%s table=%s id=%s", action, table, payload["id"])
|
|
||||||
result = await callback(payload)
|
result = await callback(payload)
|
||||||
if result is None:
|
|
||||||
logger.error("execute_on_client: callback returned None for action=%s table=%s id=%s", action, table, payload["id"])
|
|
||||||
else:
|
|
||||||
logger.info("execute_on_client: got result type=%s keys=%s", type(result).__name__, list(result.keys()) if isinstance(result, dict) else "N/A")
|
|
||||||
collector = _tool_result_collector.get(None)
|
collector = _tool_result_collector.get(None)
|
||||||
if collector is not None and action in ("insert", "update", "delete"):
|
if collector is not None:
|
||||||
collector.append({
|
collector.append({
|
||||||
"action": action,
|
"action": action,
|
||||||
"table": table,
|
"table": table,
|
||||||
"data": data or {},
|
"data": result,
|
||||||
})
|
})
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -18,7 +18,9 @@ from app.config.settings import settings
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Startup: initialise DB connection pool
|
# Startup: ensure agent tool modules are loaded.
|
||||||
|
import app.agents # noqa: F401
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Shutdown: dispose SQLAlchemy connection pool
|
# Shutdown: dispose SQLAlchemy connection pool
|
||||||
@@ -48,7 +50,7 @@ def create_app() -> FastAPI:
|
|||||||
app.add_middleware(SanitizerMiddleware)
|
app.add_middleware(SanitizerMiddleware)
|
||||||
app.add_middleware(TierRateLimitMiddleware)
|
app.add_middleware(TierRateLimitMiddleware)
|
||||||
|
|
||||||
from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
|
from app.api.routes import agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
|
||||||
|
|
||||||
app.include_router(auth.router, prefix="/api/v1")
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
app.include_router(chat.router, prefix="/api/v1")
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
@@ -58,7 +60,6 @@ def create_app() -> FastAPI:
|
|||||||
app.include_router(plugins.router, prefix="/api/v1")
|
app.include_router(plugins.router, prefix="/api/v1")
|
||||||
app.include_router(billing.router, prefix="/api/v1")
|
app.include_router(billing.router, prefix="/api/v1")
|
||||||
app.include_router(agents.router, prefix="/api/v1")
|
app.include_router(agents.router, prefix="/api/v1")
|
||||||
app.include_router(agent_setup.router, prefix="/api/v1")
|
|
||||||
app.include_router(device_ws.router, prefix="/api/v1")
|
app.include_router(device_ws.router, prefix="/api/v1")
|
||||||
|
|
||||||
@app.get("/api/v1/health", tags=["health"])
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
|
|||||||
146
app/schemas.py
146
app/schemas.py
@@ -142,9 +142,6 @@ class WsFrameType(str, Enum):
|
|||||||
tool_result = "tool_result"
|
tool_result = "tool_result"
|
||||||
final = "final"
|
final = "final"
|
||||||
ping = "ping"
|
ping = "ping"
|
||||||
agent_run = "agent_run"
|
|
||||||
agent_data = "agent_data"
|
|
||||||
agent_complete = "agent_complete"
|
|
||||||
device_hello = "device_hello"
|
device_hello = "device_hello"
|
||||||
# ── v3 frame types ─────────────────────────────────────────────────
|
# ── v3 frame types ─────────────────────────────────────────────────
|
||||||
home_request = "home_request"
|
home_request = "home_request"
|
||||||
@@ -156,6 +153,10 @@ class WsFrameType(str, Enum):
|
|||||||
data_request = "data_request"
|
data_request = "data_request"
|
||||||
data_response = "data_response"
|
data_response = "data_response"
|
||||||
mutation = "mutation"
|
mutation = "mutation"
|
||||||
|
# ── v4 journey frame types ────────────────────────────────────────
|
||||||
|
journey_start = "journey_start"
|
||||||
|
journey_message = "journey_message"
|
||||||
|
journey_reply = "journey_reply"
|
||||||
|
|
||||||
|
|
||||||
class WsToolCall(BaseModel):
|
class WsToolCall(BaseModel):
|
||||||
@@ -208,31 +209,6 @@ class WsDeviceHello(BaseModel):
|
|||||||
agent_ids: list[str] = Field(default_factory=list)
|
agent_ids: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class WsAgentRun(BaseModel):
|
|
||||||
"""Server → Client: trigger an agent run on the connected device."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.agent_run] = WsFrameType.agent_run
|
|
||||||
run_id: str
|
|
||||||
agent_id: str
|
|
||||||
config: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class WsAgentData(BaseModel):
|
|
||||||
"""Client → Server: files read by the local agent."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.agent_data] = WsFrameType.agent_data
|
|
||||||
run_id: str
|
|
||||||
files: list[dict[str, Any]]
|
|
||||||
|
|
||||||
|
|
||||||
class WsAgentComplete(BaseModel):
|
|
||||||
"""Client → Server: Electron signals it has finished reading files."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.agent_complete] = WsFrameType.agent_complete
|
|
||||||
run_id: str
|
|
||||||
files_read: int
|
|
||||||
errors: list[str] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
||||||
|
|
||||||
@@ -279,7 +255,14 @@ class WsStreamEnd(BaseModel):
|
|||||||
|
|
||||||
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
||||||
request_id: str
|
request_id: str
|
||||||
mutations: list[dict[str, Any]] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
class WsDomain(BaseModel):
|
||||||
|
"""Structured floating domain payload for UI routing decisions."""
|
||||||
|
|
||||||
|
type: Literal["task", "timeline", "project", "node"]
|
||||||
|
id: str | None = None
|
||||||
|
section: Literal["task", "timeline", "note"] | None = None
|
||||||
|
|
||||||
|
|
||||||
class WsFloatingDomain(BaseModel):
|
class WsFloatingDomain(BaseModel):
|
||||||
@@ -287,7 +270,7 @@ class WsFloatingDomain(BaseModel):
|
|||||||
|
|
||||||
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
||||||
request_id: str
|
request_id: str
|
||||||
domain: Literal["tasks", "timelines", "notes", "projects"]
|
domain: WsDomain
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Catalog ─────────────────────────────────────────────────────
|
# ── Agent Catalog ─────────────────────────────────────────────────────
|
||||||
@@ -296,84 +279,28 @@ class AgentCatalogItem(BaseModel):
|
|||||||
type: str
|
type: str
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
config_schema: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Local Agent Config ────────────────────────────────────────────────
|
class AgentCreationCheckRequest(BaseModel):
|
||||||
|
active_agents: int = Field(ge=0, default=0)
|
||||||
class LocalAgentConfigCreate(BaseModel):
|
|
||||||
name: str
|
|
||||||
device_id: str
|
|
||||||
directory_paths: list[str]
|
|
||||||
data_types: list[str]
|
|
||||||
prompt_template: str
|
|
||||||
file_extensions: list[str]
|
|
||||||
schedule_cron: str
|
|
||||||
|
|
||||||
|
|
||||||
class LocalAgentConfigUpdate(BaseModel):
|
class AgentCreationCheckResponse(BaseModel):
|
||||||
name: str | None = None
|
allowed: bool
|
||||||
device_id: str | None = None
|
tier: BillingTier
|
||||||
directory_paths: list[str] | None = None
|
active_agents: int
|
||||||
data_types: list[str] | None = None
|
limit: int
|
||||||
prompt_template: str | None = None
|
|
||||||
file_extensions: list[str] | None = None
|
|
||||||
schedule_cron: str | None = None
|
|
||||||
enabled: bool | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class LocalAgentConfigResponse(BaseModel):
|
class AgentTriggerRequest(BaseModel):
|
||||||
id: str
|
directory: str = Field(min_length=1)
|
||||||
name: str
|
device_id: str = Field(default="")
|
||||||
device_id: str
|
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
|
||||||
directory_paths: list[str]
|
what_to_extract: list[str] = Field(min_length=1)
|
||||||
data_types: list[str]
|
actions_by_type: dict[str, list[str]] | None = None
|
||||||
prompt_template: str
|
batch_interval: str = Field(min_length=1)
|
||||||
file_extensions: list[str]
|
custom_agent_prompt: str = Field(min_length=1)
|
||||||
schedule_cron: str
|
active_agents: int = Field(ge=0, default=0)
|
||||||
enabled: bool
|
|
||||||
last_run_at: int | None
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud Agent Config ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class CloudAgentConfigCreate(BaseModel):
|
|
||||||
provider: Literal["gmail", "teams", "outlook"]
|
|
||||||
name: str
|
|
||||||
data_types: list[str]
|
|
||||||
prompt_template: str
|
|
||||||
oauth_token_encrypted: str
|
|
||||||
schedule_cron: str
|
|
||||||
filter_config: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class CloudAgentConfigUpdate(BaseModel):
|
|
||||||
provider: Literal["gmail", "teams", "outlook"] | None = None
|
|
||||||
name: str | None = None
|
|
||||||
data_types: list[str] | None = None
|
|
||||||
prompt_template: str | None = None
|
|
||||||
oauth_token_encrypted: str | None = None
|
|
||||||
schedule_cron: str | None = None
|
|
||||||
filter_config: dict[str, Any] | None = None
|
|
||||||
enabled: bool | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class CloudAgentConfigResponse(BaseModel):
|
|
||||||
"""oauth_token_encrypted is intentionally excluded — never returned to clients."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
provider: Literal["gmail", "teams", "outlook"]
|
|
||||||
name: str
|
|
||||||
data_types: list[str]
|
|
||||||
prompt_template: str
|
|
||||||
schedule_cron: str
|
|
||||||
filter_config: dict[str, Any] | None
|
|
||||||
enabled: bool
|
|
||||||
last_run_at: int | None
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Run Log ─────────────────────────────────────────────────────
|
# ── Agent Run Log ─────────────────────────────────────────────────────
|
||||||
@@ -392,18 +319,3 @@ class AgentRunLogResponse(BaseModel):
|
|||||||
|
|
||||||
# ── Chatbot Journey ───────────────────────────────────────────────────
|
# ── Chatbot Journey ───────────────────────────────────────────────────
|
||||||
|
|
||||||
class JourneyStartRequest(BaseModel):
|
|
||||||
agent_type: Literal["local", "cloud"]
|
|
||||||
agent_id: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class JourneyMessageRequest(BaseModel):
|
|
||||||
session_id: str
|
|
||||||
message: str
|
|
||||||
|
|
||||||
|
|
||||||
class JourneyResponse(BaseModel):
|
|
||||||
session_id: str
|
|
||||||
message: str
|
|
||||||
done: bool
|
|
||||||
prompt_template: str | None = None
|
|
||||||
|
|||||||
941
docs/MICROSERVICES_ARCHITECTURE.md
Normal file
941
docs/MICROSERVICES_ARCHITECTURE.md
Normal file
@@ -0,0 +1,941 @@
|
|||||||
|
# Adiuva — Architettura Microservizi (MVP)
|
||||||
|
|
||||||
|
## Panoramica
|
||||||
|
|
||||||
|
Il monolite viene suddiviso in **4 servizi MVP** + un **API Gateway (Traefik)**, orchestrati con Docker Compose su un singolo VPS raggiungibile via Cloudflare.
|
||||||
|
|
||||||
|
> **Fuori dall'MVP**: Storage Service (S3/backup CRUD) e Plugin Service (marketplace). Verranno aggiunti come servizi indipendenti in una fase successiva.
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────┐
|
||||||
|
│ Cloudflare │
|
||||||
|
│ (DNS + CDN) │
|
||||||
|
└──────┬───────┘
|
||||||
|
│ HTTPS / WSS
|
||||||
|
┌──────▼───────┐
|
||||||
|
│ Traefik │
|
||||||
|
│ API Gateway │
|
||||||
|
│ (routing, │
|
||||||
|
│ TLS, rate │
|
||||||
|
│ limiting) │
|
||||||
|
└──────┬───────┘
|
||||||
|
│
|
||||||
|
┌──────────┬───────────┼───────────┐
|
||||||
|
│ │ │ │
|
||||||
|
┌─────▼────┐ ┌───▼───┐ ┌────▼────┐ ┌────▼───┐
|
||||||
|
│ Auth │ │ Chat │ │ Agent │ │Billing │
|
||||||
|
│ Service │ │Service│ │ Service │ │Service │
|
||||||
|
└─────┬────┘ └───┬───┘ └────┬────┘ └────┬───┘
|
||||||
|
│ │ │ │
|
||||||
|
┌─────▼──────────▼──────────▼───────────▼────┐
|
||||||
|
│ Infrastruttura │
|
||||||
|
│ PostgreSQL │ Redis │ Qdrant │
|
||||||
|
└─────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Suddivisione dei Servizi
|
||||||
|
|
||||||
|
### 1.1 Auth Service (`auth-service`)
|
||||||
|
|
||||||
|
**Responsabilità**: Registrazione, login, refresh token, profilo utente, encryption key.
|
||||||
|
|
||||||
|
| Endpoint originale | Metodo |
|
||||||
|
|---|---|
|
||||||
|
| `/api/v1/auth/register` | POST |
|
||||||
|
| `/api/v1/auth/login` | POST |
|
||||||
|
| `/api/v1/auth/refresh` | POST |
|
||||||
|
| `/api/v1/auth/me` | GET / PUT |
|
||||||
|
|
||||||
|
**Database**: Tabelle `users`, `refresh_tokens` (PostgreSQL condiviso, schema `auth`).
|
||||||
|
|
||||||
|
**Modifica chiave — JWT con RS256**:
|
||||||
|
Il monolite usa un `SECRET_KEY` simmetrico (HS256). Con i microservizi, passare a **RS256** (asimmetrico):
|
||||||
|
- L'Auth Service firma i JWT con la **chiave privata**.
|
||||||
|
- Tutti gli altri servizi verificano i JWT con la **chiave pubblica** senza mai contattare l'Auth Service.
|
||||||
|
- La chiave pubblica viene esposta via `GET /api/v1/auth/.well-known/jwks.json` oppure montata come volume condiviso.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# auth-service/app/auth/jwt.py
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
PRIVATE_KEY = ... # Da env/secret
|
||||||
|
PUBLIC_KEY = ... # Derivata o da env
|
||||||
|
|
||||||
|
def create_access_token(user_id: str, tier: str) -> str:
|
||||||
|
return jwt.encode(
|
||||||
|
{"sub": user_id, "tier": tier, "exp": ...},
|
||||||
|
PRIVATE_KEY,
|
||||||
|
algorithm="RS256",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# shared/auth.py (usato da tutti gli altri servizi)
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
PUBLIC_KEY = ... # Volume montato o fetched da JWKS endpoint
|
||||||
|
|
||||||
|
def verify_token(token: str) -> dict:
|
||||||
|
return jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
|
||||||
|
```
|
||||||
|
|
||||||
|
**Scaling**: 2 repliche sufficienti, stateless. Rate-limit dedicato su `/login` e `/register`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 1.2 Chat Service (`chat-service`) ⭐ Real-time
|
||||||
|
|
||||||
|
**Responsabilità**: WebSocket device connection, home chat, floating chat, memory middleware, streaming LLM responses verso il client.
|
||||||
|
|
||||||
|
Questo servizio gestisce la **connessione persistente** con l'app Electron e le interazioni **real-time** dell'utente (chat home, floating chat). È il proprietario della WebSocket.
|
||||||
|
|
||||||
|
| Endpoint | Tipo |
|
||||||
|
|---|---|
|
||||||
|
| `/api/v1/ws/device` | WebSocket (connessione persistente) |
|
||||||
|
| `/api/v1/chat` | POST (REST fallback) |
|
||||||
|
|
||||||
|
**Moduli inclusi**: `deep_agent`, `memory_middleware`, `ws_context`, `device_manager` (Redis-backed), `output_formatter`, `llm`, tutti gli agent tools (`task_agent`, `project_agent`, `note_agent`, `timeline_agent`).
|
||||||
|
|
||||||
|
**Perché separato dall'Agent Service**: Il Chat Service tiene la WebSocket aperta e risponde in tempo reale (streaming). Scalare aggiungendo repliche è semplice con sticky sessions + Redis pub/sub per il cross-instance routing dei tool_call.
|
||||||
|
|
||||||
|
**Scaling**: 2–N repliche. Sticky cookies per le WS + Redis per cross-instance.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 1.3 Agent Service (`agent-service`) ⭐ Batch
|
||||||
|
|
||||||
|
**Responsabilità**: Batch agent processing (directory scanning, file classification, entity extraction), agent setup journeys, agent configuration CRUD.
|
||||||
|
|
||||||
|
Questo servizio gestisce i processi **long-running** e **CPU-intensive**: scansione filesystem, classificazione file con LLM, estrazione entità in batch. Non possiede la WebSocket — comunica con il device dell'utente tramite **Redis pub/sub** passando per il Chat Service.
|
||||||
|
|
||||||
|
| Endpoint | Tipo |
|
||||||
|
|---|---|
|
||||||
|
| `/api/v1/agents/catalog` | GET |
|
||||||
|
| `/api/v1/agents/can-create` | POST |
|
||||||
|
| `/api/v1/agents/trigger` | POST |
|
||||||
|
| `/api/v1/agents/journey/start` | POST (o WS relay) |
|
||||||
|
| `/api/v1/agents/journey/message` | POST (o WS relay) |
|
||||||
|
|
||||||
|
**Moduli inclusi**: `agent_runner`, `agent_registry`, `filesystem_agent`, `llm`.
|
||||||
|
|
||||||
|
**Flusso tool-call cross-service** (l'Agent Service non ha la WS):
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────┐ ┌──────────────┐ ┌──────────┐
|
||||||
|
│ Agent Service│ │ Redis │ │ Chat │
|
||||||
|
│ (batch run) │ │ │ │ Service │
|
||||||
|
│ │ │ │ │ (ha WS) │
|
||||||
|
│ 1. Needs to │ PUBLISH │ │ SUBSCRIBE │ │
|
||||||
|
│ read file ├───────────►│tool_call:u123├───────────►│ 2. Invia │
|
||||||
|
│ from │ │ │ │ al │
|
||||||
|
│ device │ │ │ │ device│
|
||||||
|
│ │ │ │ │ via WS│
|
||||||
|
│ │ SUBSCRIBE │ │ PUBLISH │ │
|
||||||
|
│ 4. Riceve ◄────────────┤tool_result:id│◄───────────┤ 3. Device│
|
||||||
|
│ risultato │ │ │ │ reply │
|
||||||
|
└──────────────┘ └──────────────┘ └──────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
**Scaling**: 1–N repliche. Completamente stateless, scala indipendentemente dalla chat. Ogni replica processa batch job diversi. Può essere scalato a 0 se non ci sono agent attivi (risparmio risorse).
|
||||||
|
|
||||||
|
**Vantaggio dello split**: Se 50 utenti triggerano agenti batch contemporaneamente, il Chat Service non ne risente — le risposte real-time rimangono veloci.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 1.4 Billing Service (`billing-service`)
|
||||||
|
|
||||||
|
**Responsabilità**: Stripe checkout, webhook, subscription management.
|
||||||
|
|
||||||
|
| Endpoint originale | Metodo |
|
||||||
|
|---|---|
|
||||||
|
| `/api/v1/billing/checkout` | POST |
|
||||||
|
| `/api/v1/billing/webhook` | POST |
|
||||||
|
| `/api/v1/billing/subscription` | GET / DELETE |
|
||||||
|
|
||||||
|
**Database**: Tabelle `subscriptions` (schema `billing`).
|
||||||
|
|
||||||
|
**Comunicazione inter-servizio**: Quando Stripe invia un webhook e il tier cambia, il Billing Service pubblica un evento su **Redis pub/sub** channel `tier_changed:{user_id}`. L'Auth Service aggiorna il campo `tier` nella tabella users. Al prossimo token refresh il JWT conterrà il tier aggiornato.
|
||||||
|
|
||||||
|
**Scaling**: 1 replica sufficiente. Basso traffico.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 1.5 Servizi esclusi dall'MVP
|
||||||
|
|
||||||
|
I seguenti servizi verranno aggiunti post-MVP come servizi indipendenti:
|
||||||
|
|
||||||
|
| Servizio | Responsabilità | Note |
|
||||||
|
|---|---|---|
|
||||||
|
| **Storage Service** | S3 blobs CRUD, vector ops, backup | Le funzionalità vector/embed possono restare nel Chat Service per il MVP |
|
||||||
|
| **Plugin Service** | Marketplace, install, revenue split | Feature non critica per il lancio |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Tier Check — Dove e Come
|
||||||
|
|
||||||
|
Il tier dell'utente (free/pro/power/team) determina rate-limiting, quote e accesso a funzionalità. Con i microservizi, **ogni servizio controlla il tier autonomamente** senza chiamare l'Auth Service.
|
||||||
|
|
||||||
|
### Strategia: Tier nel JWT
|
||||||
|
|
||||||
|
L'Auth Service include il `tier` come claim nel JWT al momento del login/refresh:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"sub": "user_123",
|
||||||
|
"tier": "pro",
|
||||||
|
"exp": 1742515200,
|
||||||
|
"iat": 1742511600
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Ogni servizio:
|
||||||
|
1. Decodifica il JWT con la chiave pubblica (già lo fa per l'auth)
|
||||||
|
2. Legge `payload["tier"]` — **zero chiamate extra**
|
||||||
|
3. Applica le sue regole di enforcement localmente
|
||||||
|
|
||||||
|
```python
|
||||||
|
# shared/auth.py — dependency FastAPI condivisa
|
||||||
|
from fastapi import Depends, HTTPException, Request
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
PUBLIC_KEY = ...
|
||||||
|
|
||||||
|
class CurrentUser:
|
||||||
|
def __init__(self, user_id: str, tier: str):
|
||||||
|
self.user_id = user_id
|
||||||
|
self.tier = tier
|
||||||
|
|
||||||
|
async def get_current_user(request: Request) -> CurrentUser:
|
||||||
|
token = request.headers.get("Authorization", "").removeprefix("Bearer ")
|
||||||
|
payload = jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
|
||||||
|
return CurrentUser(user_id=payload["sub"], tier=payload["tier"])
|
||||||
|
|
||||||
|
def require_tier(*allowed_tiers: str):
|
||||||
|
"""Dependency che blocca se il tier non è tra quelli ammessi."""
|
||||||
|
async def check(user: CurrentUser = Depends(get_current_user)):
|
||||||
|
if user.tier not in allowed_tiers:
|
||||||
|
raise HTTPException(403, "Tier insufficient")
|
||||||
|
return user
|
||||||
|
return check
|
||||||
|
```
|
||||||
|
|
||||||
|
### Cosa succede quando il tier cambia (upgrade/downgrade)?
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────┐ Stripe webhook ┌──────────┐ tier_changed ┌──────────┐
|
||||||
|
│ Stripe │ ─────────────────►│ Billing │ ───────────────►│ Auth │
|
||||||
|
│ │ │ Service │ (Redis pub/sub) │ Service │
|
||||||
|
└──────────┘ └──────────┘ └────┬─────┘
|
||||||
|
│
|
||||||
|
UPDATE users
|
||||||
|
SET tier = 'power'
|
||||||
|
│
|
||||||
|
Al prossimo /refresh
|
||||||
|
il JWT conterrà tier='power'
|
||||||
|
```
|
||||||
|
|
||||||
|
**Latenza del cambio**: Il tier si propaga al prossimo token refresh (tipicamente 15–30 min, o il client può forzare un refresh immediato dopo il checkout). Per il billing webhook, il downgrade può essere forzato invalidando il refresh token su Redis → il client è obbligato a ri-autenticarsi.
|
||||||
|
|
||||||
|
### Dove si applica in ciascun servizio
|
||||||
|
|
||||||
|
| Servizio | Enforcement |
|
||||||
|
|---|---|
|
||||||
|
| **Auth Service** | Nessuno (è lui che scrive il tier) |
|
||||||
|
| **Chat Service** | Rate-limit per tier (req/min), quota messaggi |
|
||||||
|
| **Agent Service** | Max agent configs, max runs/day, max concurrent batches |
|
||||||
|
| **Billing Service** | Nessuno (gestisce i tier, non li consuma) |
|
||||||
|
|
||||||
|
### Rate-limit distribuito via Redis
|
||||||
|
|
||||||
|
Poiché ogni servizio ha le sue repliche, il rate-limiting deve essere **condiviso** via Redis:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# shared/middleware/rate_limit.py
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
|
||||||
|
class DistributedRateLimiter:
|
||||||
|
def __init__(self, redis: aioredis.Redis):
|
||||||
|
self._redis = redis
|
||||||
|
|
||||||
|
async def check(self, user_id: str, tier: str, service: str) -> bool:
|
||||||
|
limits = {"free": 20, "pro": 60, "power": 120, "team": 200}
|
||||||
|
max_req = limits.get(tier, 20)
|
||||||
|
key = f"rate:{service}:{user_id}"
|
||||||
|
|
||||||
|
pipe = self._redis.pipeline()
|
||||||
|
pipe.incr(key)
|
||||||
|
pipe.expire(key, 60)
|
||||||
|
count, _ = await pipe.execute()
|
||||||
|
|
||||||
|
return count <= max_req
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. WebSocket con Scaling Orizzontale — Il Problema Chiave
|
||||||
|
|
||||||
|
`DeviceConnectionManager` è un **singleton in-memory**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class DeviceConnectionManager:
|
||||||
|
def __init__(self):
|
||||||
|
self._connections: dict[str, DeviceConnection] = {} # ← In-memory!
|
||||||
|
```
|
||||||
|
|
||||||
|
Con N istanze del Chat Service, il device si connette a **una sola** istanza. Quando un'altra istanza deve inviare un `tool_call` a quel device (es. un agent trigger da un'API call), non trova la connessione.
|
||||||
|
|
||||||
|
### La soluzione: Redis Pub/Sub + Registry
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────────────────────────────────────────────────────┐
|
||||||
|
│ Redis │
|
||||||
|
│ │
|
||||||
|
│ Hash: ws:connections │
|
||||||
|
│ user_123 → instance_A │
|
||||||
|
│ user_456 → instance_B │
|
||||||
|
│ │
|
||||||
|
│ Pub/Sub channels: │
|
||||||
|
│ tool_call:{user_id} → tool call payloads │
|
||||||
|
│ tool_result:{call_id} → tool result payloads │
|
||||||
|
│ stream:{user_id} → text_chunk streaming │
|
||||||
|
└──────────────────────────────────────────────────────────────┘
|
||||||
|
|
||||||
|
Instance A (ha WS di user_123) Instance B (deve chiamare tool su user_123)
|
||||||
|
┌───────────────────────┐ ┌───────────────────────┐
|
||||||
|
│ 1. Sottoscrive a │ │ 1. Lookup Redis Hash │
|
||||||
|
│ tool_call:user_123│ │ → user_123 è su A │
|
||||||
|
│ │ │ │
|
||||||
|
│ 2. Riceve tool_call │◄─────────│ 2. PUBLISH │
|
||||||
|
│ da Redis channel │ │ tool_call:user_123 │
|
||||||
|
│ │ │ {id, action, ...} │
|
||||||
|
│ 3. Invia al device │ │ │
|
||||||
|
│ via WS │ │ 4. SUBSCRIBE │
|
||||||
|
│ │ │ tool_result:{id} │
|
||||||
|
│ 4. Device risponde │ │ │
|
||||||
|
│ tool_result │──────────│► 5. Riceve risultato │
|
||||||
|
│ │ │ │
|
||||||
|
│ 5. PUBLISH │ │ │
|
||||||
|
│ tool_result:{id} │ │ │
|
||||||
|
└───────────────────────┘ └───────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### Implementazione: `RedisDeviceManager`
|
||||||
|
|
||||||
|
```python
|
||||||
|
# chat-service/app/core/device_manager.py
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from fastapi import WebSocket
|
||||||
|
|
||||||
|
INSTANCE_ID = os.environ.get("INSTANCE_ID", os.urandom(8).hex())
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LocalConnection:
|
||||||
|
ws: WebSocket
|
||||||
|
device_id: str
|
||||||
|
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisDeviceManager:
|
||||||
|
"""Device manager backed by Redis for cross-instance communication."""
|
||||||
|
|
||||||
|
def __init__(self, redis_url: str = "redis://redis:6379"):
|
||||||
|
self._redis = aioredis.from_url(redis_url)
|
||||||
|
self._pubsub = self._redis.pubsub()
|
||||||
|
self._local: dict[str, LocalConnection] = {} # Solo connessioni locali
|
||||||
|
self._remote_futures: dict[str, asyncio.Future[dict]] = {}
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Avvia il listener Redis per tool_call in arrivo."""
|
||||||
|
asyncio.create_task(self._listen_tool_calls())
|
||||||
|
|
||||||
|
# ── Registrazione ──
|
||||||
|
|
||||||
|
async def register(self, user_id: str, device_id: str, ws: WebSocket):
|
||||||
|
# Registra localmente
|
||||||
|
self._local[user_id] = LocalConnection(ws=ws, device_id=device_id)
|
||||||
|
# Registra in Redis quale istanza ha la connessione
|
||||||
|
await self._redis.hset("ws:connections", user_id, INSTANCE_ID)
|
||||||
|
# Sottoscrivi ai tool_call per questo utente
|
||||||
|
await self._pubsub.subscribe(f"tool_call:{user_id}")
|
||||||
|
|
||||||
|
async def unregister(self, user_id: str):
|
||||||
|
conn = self._local.pop(user_id, None)
|
||||||
|
if conn:
|
||||||
|
for fut in conn.pending_calls.values():
|
||||||
|
if not fut.done():
|
||||||
|
fut.cancel()
|
||||||
|
await self._redis.hdel("ws:connections", user_id)
|
||||||
|
await self._pubsub.unsubscribe(f"tool_call:{user_id}")
|
||||||
|
|
||||||
|
# ── Presenza ──
|
||||||
|
|
||||||
|
async def is_online(self, user_id: str) -> bool:
|
||||||
|
return await self._redis.hexists("ws:connections", user_id)
|
||||||
|
|
||||||
|
# ── Tool-call round-trip (cross-instance) ──
|
||||||
|
|
||||||
|
async def execute_tool_call(self, user_id: str, payload: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Invia un tool_call al device dell'utente.
|
||||||
|
Funziona sia che la WS sia locale che su un'altra istanza.
|
||||||
|
"""
|
||||||
|
call_id = payload["id"]
|
||||||
|
|
||||||
|
# Caso 1: connessione locale → invio diretto
|
||||||
|
if user_id in self._local:
|
||||||
|
conn = self._local[user_id]
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
fut: asyncio.Future[dict] = loop.create_future()
|
||||||
|
conn.pending_calls[call_id] = fut
|
||||||
|
await conn.ws.send_text(json.dumps({"type": "tool_call", **payload}))
|
||||||
|
return await asyncio.wait_for(fut, timeout=30.0)
|
||||||
|
|
||||||
|
# Caso 2: connessione remota → Redis pub/sub
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
fut = loop.create_future()
|
||||||
|
self._remote_futures[call_id] = fut
|
||||||
|
|
||||||
|
# Sottoscrivi al canale di risposta
|
||||||
|
result_channel = f"tool_result:{call_id}"
|
||||||
|
await self._pubsub.subscribe(result_channel)
|
||||||
|
|
||||||
|
# Pubblica il tool_call
|
||||||
|
await self._redis.publish(
|
||||||
|
f"tool_call:{user_id}",
|
||||||
|
json.dumps(payload),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await asyncio.wait_for(fut, timeout=30.0)
|
||||||
|
finally:
|
||||||
|
self._remote_futures.pop(call_id, None)
|
||||||
|
await self._pubsub.unsubscribe(result_channel)
|
||||||
|
|
||||||
|
# ── Risoluzione tool_result (da WS locale) ──
|
||||||
|
|
||||||
|
def resolve_local(self, user_id: str, call_id: str, result: dict):
|
||||||
|
conn = self._local.get(user_id)
|
||||||
|
if conn:
|
||||||
|
fut = conn.pending_calls.pop(call_id, None)
|
||||||
|
if fut and not fut.done():
|
||||||
|
fut.set_result(result)
|
||||||
|
|
||||||
|
async def resolve_and_publish(self, user_id: str, call_id: str, result: dict):
|
||||||
|
"""Chiamato quando il device locale invia un tool_result."""
|
||||||
|
self.resolve_local(user_id, call_id, result)
|
||||||
|
# Pubblica anche su Redis per l'istanza remota che aspetta
|
||||||
|
await self._redis.publish(
|
||||||
|
f"tool_result:{call_id}",
|
||||||
|
json.dumps(result),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Listener Redis ──
|
||||||
|
|
||||||
|
async def _listen_tool_calls(self):
|
||||||
|
"""Loop che ascolta i tool_call in arrivo da altre istanze."""
|
||||||
|
async for message in self._pubsub.listen():
|
||||||
|
if message["type"] != "message":
|
||||||
|
continue
|
||||||
|
channel = message["channel"]
|
||||||
|
if isinstance(channel, bytes):
|
||||||
|
channel = channel.decode()
|
||||||
|
|
||||||
|
data = json.loads(message["data"])
|
||||||
|
|
||||||
|
if channel.startswith("tool_call:"):
|
||||||
|
# Un'altra istanza vuole che inviamo un tool_call al nostro device
|
||||||
|
user_id = channel.split(":", 1)[1]
|
||||||
|
conn = self._local.get(user_id)
|
||||||
|
if conn:
|
||||||
|
await conn.ws.send_text(json.dumps({"type": "tool_call", **data}))
|
||||||
|
|
||||||
|
elif channel.startswith("tool_result:"):
|
||||||
|
# Risposta a un tool_call che abbiamo inviato tramite Redis
|
||||||
|
call_id = channel.split(":", 1)[1]
|
||||||
|
fut = self._remote_futures.pop(call_id, None)
|
||||||
|
if fut and not fut.done():
|
||||||
|
fut.set_result(data)
|
||||||
|
|
||||||
|
# ── Stream cross-instance ──
|
||||||
|
|
||||||
|
async def publish_stream_chunk(self, user_id: str, chunk: dict):
|
||||||
|
"""Pubblica un chunk di streaming su Redis (per REST→WS relay)."""
|
||||||
|
await self._redis.publish(f"stream:{user_id}", json.dumps(chunk))
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Struttura Directory Proposta (MVP)
|
||||||
|
|
||||||
|
```
|
||||||
|
adiuva-api/
|
||||||
|
├── docker-compose.yml # Orchestrazione completa
|
||||||
|
├── docker-compose.dev.yml # Override per sviluppo locale
|
||||||
|
├── shared/ # Codice condiviso (montato come volume)
|
||||||
|
│ ├── auth.py # JWT verification (chiave pubblica)
|
||||||
|
│ ├── schemas.py # Pydantic schemas condivisi
|
||||||
|
│ ├── middleware/
|
||||||
|
│ │ ├── rate_limit.py # DistributedRateLimiter (Redis)
|
||||||
|
│ │ └── sanitizer.py
|
||||||
|
│ └── models/
|
||||||
|
│ └── base.py # SQLAlchemy base condivisa
|
||||||
|
│
|
||||||
|
├── auth-service/
|
||||||
|
│ ├── Dockerfile
|
||||||
|
│ ├── requirements.txt
|
||||||
|
│ └── app/
|
||||||
|
│ ├── main.py
|
||||||
|
│ ├── config.py
|
||||||
|
│ ├── db.py
|
||||||
|
│ ├── models.py # users, refresh_tokens
|
||||||
|
│ ├── routes/
|
||||||
|
│ │ └── auth.py
|
||||||
|
│ └── services/
|
||||||
|
│ ├── jwt_service.py # RS256 signing
|
||||||
|
│ └── user_service.py
|
||||||
|
│
|
||||||
|
├── chat-service/
|
||||||
|
│ ├── Dockerfile
|
||||||
|
│ ├── requirements.txt
|
||||||
|
│ └── app/
|
||||||
|
│ ├── main.py
|
||||||
|
│ ├── config.py
|
||||||
|
│ ├── db.py
|
||||||
|
│ ├── models.py # memory_*
|
||||||
|
│ ├── routes/
|
||||||
|
│ │ ├── device_ws.py # WS connection owner
|
||||||
|
│ │ └── chat.py # REST fallback
|
||||||
|
│ ├── core/
|
||||||
|
│ │ ├── device_manager.py # RedisDeviceManager
|
||||||
|
│ │ ├── deep_agent.py # Home + floating chat
|
||||||
|
│ │ ├── memory_middleware.py
|
||||||
|
│ │ ├── ws_context.py
|
||||||
|
│ │ ├── output_formatter.py
|
||||||
|
│ │ └── llm.py
|
||||||
|
│ └── agents/ # Tool definitions (used by deep_agent)
|
||||||
|
│ ├── task_agent.py
|
||||||
|
│ ├── project_agent.py
|
||||||
|
│ ├── note_agent.py
|
||||||
|
│ └── timeline_agent.py
|
||||||
|
│
|
||||||
|
├── agent-service/
|
||||||
|
│ ├── Dockerfile
|
||||||
|
│ ├── requirements.txt
|
||||||
|
│ └── app/
|
||||||
|
│ ├── main.py
|
||||||
|
│ ├── config.py
|
||||||
|
│ ├── db.py
|
||||||
|
│ ├── models.py # agent_run_logs, local/cloud_agent_configs
|
||||||
|
│ ├── routes/
|
||||||
|
│ │ ├── agents.py # catalog, can-create, trigger
|
||||||
|
│ │ └── agent_setup.py # journey start/message
|
||||||
|
│ ├── core/
|
||||||
|
│ │ ├── agent_runner.py # Batch classify → process
|
||||||
|
│ │ ├── agent_registry.py
|
||||||
|
│ │ ├── redis_executor.py # execute_on_client via Redis pub/sub
|
||||||
|
│ │ └── llm.py
|
||||||
|
│ └── agents/
|
||||||
|
│ ├── task_agent.py # Tool definitions (batch context)
|
||||||
|
│ ├── project_agent.py
|
||||||
|
│ ├── note_agent.py
|
||||||
|
│ ├── timeline_agent.py
|
||||||
|
│ └── filesystem_agent.py
|
||||||
|
│
|
||||||
|
├── billing-service/
|
||||||
|
│ ├── Dockerfile
|
||||||
|
│ ├── requirements.txt
|
||||||
|
│ └── app/
|
||||||
|
│ ├── main.py
|
||||||
|
│ ├── config.py
|
||||||
|
│ ├── db.py
|
||||||
|
│ ├── models.py # subscriptions
|
||||||
|
│ ├── routes/
|
||||||
|
│ │ └── billing.py
|
||||||
|
│ └── services/
|
||||||
|
│ ├── stripe_service.py
|
||||||
|
│ └── tier_manager.py
|
||||||
|
│
|
||||||
|
└── infra/
|
||||||
|
├── traefik/
|
||||||
|
│ └── traefik.yml
|
||||||
|
├── keys/
|
||||||
|
│ ├── jwt_private.pem # Solo auth-service
|
||||||
|
│ └── jwt_public.pem # Tutti i servizi
|
||||||
|
└── alembic/ # Migrazioni condivise o per-servizio
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Docker Compose — Configurazione MVP
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# docker-compose.yml
|
||||||
|
|
||||||
|
services:
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# API Gateway
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
traefik:
|
||||||
|
image: traefik:v3.2
|
||||||
|
command:
|
||||||
|
- "--api.insecure=true"
|
||||||
|
- "--providers.docker=true"
|
||||||
|
- "--providers.docker.exposedbydefault=false"
|
||||||
|
- "--entrypoints.web.address=:80"
|
||||||
|
- "--entrypoints.websecure.address=:443"
|
||||||
|
- "--entrypoints.web.http.redirections.entrypoint.to=websecure"
|
||||||
|
ports:
|
||||||
|
- "80:80"
|
||||||
|
- "443:443"
|
||||||
|
- "8080:8080" # Dashboard Traefik (disabilitare in prod)
|
||||||
|
volumes:
|
||||||
|
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||||
|
- ./infra/certs:/certs:ro
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Auth Service (2 repliche)
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
auth-service:
|
||||||
|
build: ./auth-service
|
||||||
|
deploy:
|
||||||
|
replicas: 2
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
REDIS_URL: redis://redis:6379
|
||||||
|
JWT_PRIVATE_KEY_FILE: /run/secrets/jwt_private_key
|
||||||
|
SERVICE_NAME: auth
|
||||||
|
secrets:
|
||||||
|
- jwt_private_key
|
||||||
|
- jwt_public_key
|
||||||
|
labels:
|
||||||
|
- "traefik.enable=true"
|
||||||
|
- "traefik.http.routers.auth.rule=PathPrefix(`/api/v1/auth`)"
|
||||||
|
- "traefik.http.services.auth.loadbalancer.server.port=8000"
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Chat Service — Real-time WS + Chat (scalabile)
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
chat-service:
|
||||||
|
build: ./chat-service
|
||||||
|
deploy:
|
||||||
|
replicas: 2
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
REDIS_URL: redis://redis:6379
|
||||||
|
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
||||||
|
SERVICE_NAME: chat
|
||||||
|
secrets:
|
||||||
|
- jwt_public_key
|
||||||
|
labels:
|
||||||
|
- "traefik.enable=true"
|
||||||
|
# REST chat endpoint
|
||||||
|
- "traefik.http.routers.chat.rule=PathPrefix(`/api/v1/chat`)"
|
||||||
|
- "traefik.http.services.chat.loadbalancer.server.port=8000"
|
||||||
|
# WebSocket route con sticky session
|
||||||
|
- "traefik.http.routers.ws.rule=PathPrefix(`/api/v1/ws`)"
|
||||||
|
- "traefik.http.routers.ws.service=chat-ws"
|
||||||
|
- "traefik.http.services.chat-ws.loadbalancer.server.port=8000"
|
||||||
|
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.name=ws_affinity"
|
||||||
|
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.httpOnly=true"
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Agent Service — Batch processing (scalabile indipendentemente)
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
agent-service:
|
||||||
|
build: ./agent-service
|
||||||
|
deploy:
|
||||||
|
replicas: 2
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
REDIS_URL: redis://redis:6379
|
||||||
|
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
||||||
|
SERVICE_NAME: agent
|
||||||
|
secrets:
|
||||||
|
- jwt_public_key
|
||||||
|
labels:
|
||||||
|
- "traefik.enable=true"
|
||||||
|
- "traefik.http.routers.agents.rule=PathPrefix(`/api/v1/agents`)"
|
||||||
|
- "traefik.http.services.agents.loadbalancer.server.port=8000"
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Billing Service (1 replica)
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
billing-service:
|
||||||
|
build: ./billing-service
|
||||||
|
deploy:
|
||||||
|
replicas: 1
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
REDIS_URL: redis://redis:6379
|
||||||
|
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
||||||
|
SERVICE_NAME: billing
|
||||||
|
secrets:
|
||||||
|
- jwt_public_key
|
||||||
|
labels:
|
||||||
|
- "traefik.enable=true"
|
||||||
|
- "traefik.http.routers.billing.rule=PathPrefix(`/api/v1/billing`)"
|
||||||
|
- "traefik.http.services.billing.loadbalancer.server.port=8000"
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Infrastruttura
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
db:
|
||||||
|
image: pgvector/pgvector:pg16
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: postgres
|
||||||
|
POSTGRES_PASSWORD: postgres
|
||||||
|
POSTGRES_DB: adiuva
|
||||||
|
volumes:
|
||||||
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
redis:
|
||||||
|
image: redis:7-alpine
|
||||||
|
command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
|
||||||
|
volumes:
|
||||||
|
- redis_data:/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "redis-cli", "ping"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 3s
|
||||||
|
retries: 5
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
qdrant:
|
||||||
|
image: qdrant/qdrant:latest
|
||||||
|
volumes:
|
||||||
|
- qdrant_data:/qdrant/storage
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
secrets:
|
||||||
|
jwt_private_key:
|
||||||
|
file: ./infra/keys/jwt_private.pem
|
||||||
|
jwt_public_key:
|
||||||
|
file: ./infra/keys/jwt_public.pem
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
|
redis_data:
|
||||||
|
qdrant_data:
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Configurazione Cloudflare + VPS
|
||||||
|
|
||||||
|
### 6.1 DNS
|
||||||
|
|
||||||
|
```
|
||||||
|
api.tuodominio.com → A record → IP del VPS
|
||||||
|
→ Proxy: ON (orange cloud)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.2 Cloudflare Settings
|
||||||
|
|
||||||
|
| Setting | Valore | Motivo |
|
||||||
|
|---------|--------|--------|
|
||||||
|
| SSL/TLS mode | **Full (Strict)** | Cloudflare ↔ VPS con certificato valido |
|
||||||
|
| WebSocket | **ON** | Necessario per `/api/v1/ws/device` |
|
||||||
|
| Proxy timeout | **100s** (Enterprise) o default | Le LLM calls possono durare 30s+ |
|
||||||
|
| Under Attack Mode | Off (attivare se necessario) | |
|
||||||
|
|
||||||
|
### 6.3 TLS sul VPS
|
||||||
|
|
||||||
|
Due opzioni:
|
||||||
|
- **Opzione A (consigliata)**: Cloudflare Origin Certificate → montato in Traefik
|
||||||
|
- **Opzione B**: Let's Encrypt via Traefik (con DNS challenge Cloudflare)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# traefik.yml — con Cloudflare Origin Certificate
|
||||||
|
entryPoints:
|
||||||
|
websecure:
|
||||||
|
address: ":443"
|
||||||
|
|
||||||
|
tls:
|
||||||
|
certificates:
|
||||||
|
- certFile: /certs/origin.pem
|
||||||
|
keyFile: /certs/origin-key.pem
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.4 Rete VPS
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# UFW firewall — solo Cloudflare può raggiungere le porte 80/443
|
||||||
|
# https://www.cloudflare.com/ips/
|
||||||
|
ufw default deny incoming
|
||||||
|
ufw allow from 173.245.48.0/20 to any port 443
|
||||||
|
ufw allow from 103.21.244.0/22 to any port 443
|
||||||
|
# ... (tutti gli IP range di Cloudflare)
|
||||||
|
ufw allow ssh
|
||||||
|
ufw enable
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. Comunicazione Inter-Servizio
|
||||||
|
|
||||||
|
### 7.1 Redis Pub/Sub — Event Bus
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────┐ tier_changed:user_123 ┌──────────┐
|
||||||
|
│ Billing │ ────────────────────────► │ Auth │
|
||||||
|
│ Service │ │ Service │
|
||||||
|
└──────────┘ └──────────┘
|
||||||
|
|
||||||
|
┌──────────┐ tool_call:user_123 ┌──────────┐
|
||||||
|
│ Agent │ ────────────────────────► │ Chat │
|
||||||
|
│ Service │ │ Service │
|
||||||
|
│ (batch) │ ◄────────────────────────│ (ha WS) │
|
||||||
|
└──────────┘ tool_result:{call_id} └──────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7.2 Health Checks e Service Discovery
|
||||||
|
|
||||||
|
Traefik gestisce automaticamente il service discovery via Docker labels. I servizi non devono conoscersi tra loro — comunicano solo via:
|
||||||
|
- **Redis pub/sub** (tool-call cross-instance, tier events)
|
||||||
|
- **Redis hash** (stato condiviso: `ws:connections`, rate-limit counters)
|
||||||
|
- **PostgreSQL** (dati persistenti condivisi)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. Piano di Migrazione Incrementale (MVP)
|
||||||
|
|
||||||
|
### Fase 1 — Preparazione (nel monolite attuale)
|
||||||
|
1. Aggiungere Redis al `docker-compose.yml` attuale
|
||||||
|
2. Migrare JWT da HS256 → RS256 (backward-compatible: accetta entrambi per un periodo)
|
||||||
|
3. Implementare `RedisDeviceManager` come drop-in replacement del singleton in-memory
|
||||||
|
4. Estrarre `shared/` con auth verification, schemas, middleware
|
||||||
|
|
||||||
|
### Fase 2 — Auth Service (primo split)
|
||||||
|
1. Estrarre `auth.py` routes + models in `auth-service/`
|
||||||
|
2. Verificare che i JWT firmati da `auth-service` vengano validati dal monolite
|
||||||
|
3. Aggiungere Traefik e routare `/api/v1/auth/*` al nuovo servizio
|
||||||
|
4. Il monolite continua a servire tutto il resto
|
||||||
|
|
||||||
|
### Fase 3 — Billing Service
|
||||||
|
1. Estrarre billing routes, Stripe service, tier manager
|
||||||
|
2. Configurare Redis pub/sub per `tier_changed` events
|
||||||
|
3. Routare via Traefik
|
||||||
|
|
||||||
|
### Fase 4 — Split Chat + Agent (il più delicato)
|
||||||
|
1. Il monolite residuo contiene WS + chat + agents
|
||||||
|
2. Separare Agent Service: estrarre `agent_runner`, `agent_registry`, `agent_setup`, route `/agents/*`
|
||||||
|
3. Implementare `redis_executor.py` nell'Agent Service per tool-call via Redis
|
||||||
|
4. Il Chat Service resta proprietario della WS e sottoscrive i canali `tool_call:{user_id}`
|
||||||
|
5. Testare: trigger agent dall'Agent Service → tool_call via Redis → Chat Service → WS → device → risposta
|
||||||
|
|
||||||
|
### Fase 5 — Scaling test
|
||||||
|
1. Scalare Chat Service a 2 repliche, verificare sticky sessions
|
||||||
|
2. Scalare Agent Service a 2 repliche, verificare batch processing distribuito
|
||||||
|
3. Monitoring (Prometheus + Grafana) per ogni servizio
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. Monitoraggio e Logging
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Aggiungere al docker-compose.yml
|
||||||
|
|
||||||
|
prometheus:
|
||||||
|
image: prom/prometheus:latest
|
||||||
|
volumes:
|
||||||
|
- ./infra/prometheus/prometheus.yml:/etc/prometheus/prometheus.yml
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
grafana:
|
||||||
|
image: grafana/grafana:latest
|
||||||
|
ports:
|
||||||
|
- "3000:3000"
|
||||||
|
volumes:
|
||||||
|
- grafana_data:/var/lib/grafana
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
loki:
|
||||||
|
image: grafana/loki:latest
|
||||||
|
restart: unless-stopped
|
||||||
|
```
|
||||||
|
|
||||||
|
Ogni servizio espone `/metrics` (Prometheus) e scrive log strutturati (JSON) raccolti da Loki.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 10. Sizing VPS Minimo Consigliato (MVP)
|
||||||
|
|
||||||
|
| Componente | CPU | RAM | Note |
|
||||||
|
|---|---|---|---|
|
||||||
|
| Traefik | 0.25 | 128MB | |
|
||||||
|
| Auth Service ×2 | 0.25 ×2 | 128MB ×2 | Stateless, leggero |
|
||||||
|
| Chat Service ×2 | 1.0 ×2 | 1GB ×2 | WS + streaming LLM |
|
||||||
|
| Agent Service ×2 | 0.75 ×2 | 512MB ×2 | Batch LLM, CPU-bound |
|
||||||
|
| Billing Service | 0.25 | 128MB | |
|
||||||
|
| PostgreSQL | 1.0 | 1GB | |
|
||||||
|
| Redis | 0.25 | 256MB | |
|
||||||
|
| Qdrant | 0.5 | 512MB | |
|
||||||
|
| **Totale MVP** | **~5.5 vCPU** | **~5 GB** | |
|
||||||
|
|
||||||
|
**Raccomandazione**: VPS con **8 vCPU / 16 GB RAM** per avere margine. Hetzner CPX41 (~€30/mese) o equivalente. Senza Storage/Plugin si risparmia ~1 vCPU e 512MB rispetto alla versione completa.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Riepilogo Architettura MVP
|
||||||
|
|
||||||
|
| Servizio | Repliche | Proprietario di |
|
||||||
|
|---|---|---|
|
||||||
|
| **Traefik** | 1 | Routing, TLS, sticky sessions |
|
||||||
|
| **Auth Service** | 2 | JWT RS256, registrazione, login, profilo |
|
||||||
|
| **Chat Service** | 2–N | WebSocket, home/floating chat, streaming |
|
||||||
|
| **Agent Service** | 2–N | Batch processing, directory scan, agent setup |
|
||||||
|
| **Billing Service** | 1 | Stripe, subscriptions, tier management |
|
||||||
|
|
||||||
|
| Decisione | Scelta | Motivazione |
|
||||||
|
|---|---|---|
|
||||||
|
| API Gateway | Traefik | Nativo Docker, WebSocket support, service discovery automatico |
|
||||||
|
| JWT | RS256 (asimmetrico) | Verifica distribuita senza contattare Auth Service |
|
||||||
|
| Tier check | Claim nel JWT | Ogni servizio verifica localmente, zero roundtrip |
|
||||||
|
| WebSocket scaling | Redis pub/sub + sticky cookies | Cross-instance tool-call routing |
|
||||||
|
| Chat ↔ Agent split | Servizi separati | Batch CPU-bound non impatta real-time chat |
|
||||||
|
| Agent → Device comms | Redis pub/sub via Chat Service | Agent non possiede la WS, usa un relay |
|
||||||
|
| Rate limiting | Redis contatori distribuiti | Sliding window condivisa tra repliche |
|
||||||
|
| Database | PostgreSQL condiviso | Semplicità MVP; split DB futuro facile |
|
||||||
|
| TLS | Cloudflare Origin Certificate | Zero maintenance |
|
||||||
|
| Orchestrazione | Docker Compose | Sufficiente per un singolo VPS |
|
||||||
|
| Storage / Plugin | Post-MVP | Non critici per il lancio |
|
||||||
@@ -4,8 +4,6 @@ gunicorn>=22.0.0
|
|||||||
langchain>=0.3.0
|
langchain>=0.3.0
|
||||||
langchain-openai>=0.3.0
|
langchain-openai>=0.3.0
|
||||||
langchain-litellm>=0.1.0
|
langchain-litellm>=0.1.0
|
||||||
langgraph>=0.3.0
|
|
||||||
deepagents>=0.4.10
|
|
||||||
litellm>=1.50.0
|
litellm>=1.50.0
|
||||||
pydantic>=2.10.0
|
pydantic>=2.10.0
|
||||||
pydantic-settings>=2.7.0
|
pydantic-settings>=2.7.0
|
||||||
@@ -34,4 +32,6 @@ google-auth-oauthlib>=1.2.0
|
|||||||
google-auth-httplib2>=0.2.0
|
google-auth-httplib2>=0.2.0
|
||||||
msal>=1.28.0
|
msal>=1.28.0
|
||||||
cryptography>=42.0.0
|
cryptography>=42.0.0
|
||||||
|
redis>=5.0.0
|
||||||
|
langfuse>=3.0.0
|
||||||
ruff>=0.8.0
|
ruff>=0.8.0
|
||||||
|
|||||||
19
services/auth/.env.example
Normal file
19
services/auth/.env.example
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# ── Auth Service ──────────────────────────────────────────────────────────────
|
||||||
|
# This file contains env vars specific to the Auth Service.
|
||||||
|
# Shared vars (DATABASE_URL, REDIS_URL, etc.) come from the root .env
|
||||||
|
# or from docker-compose environment.
|
||||||
|
|
||||||
|
# ── JWT RS256 Keys ────────────────────────────────────────────────────────────
|
||||||
|
# Generate keypair:
|
||||||
|
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
||||||
|
# openssl rsa -in private.pem -pubout -out public.pem
|
||||||
|
#
|
||||||
|
# Paste PEM content with literal \n for newlines:
|
||||||
|
# JWT_PRIVATE_KEY=-----BEGIN PRIVATE KEY-----\nMIIEvQ...
|
||||||
|
# JWT_PUBLIC_KEY=-----BEGIN PUBLIC KEY-----\nMIIBIj...
|
||||||
|
|
||||||
|
# PRIVATE KEY — used to SIGN JWTs. NEVER share outside this service.
|
||||||
|
JWT_PRIVATE_KEY=
|
||||||
|
|
||||||
|
# PUBLIC KEY — used to VERIFY JWTs.
|
||||||
|
JWT_PUBLIC_KEY=
|
||||||
36
services/auth/Dockerfile
Normal file
36
services/auth/Dockerfile
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# ── builder ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS builder
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
# Install shared + service deps in one layer
|
||||||
|
COPY services/auth/requirements.txt ./requirements.txt
|
||||||
|
RUN pip install --upgrade pip && \
|
||||||
|
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||||
|
|
||||||
|
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS runtime
|
||||||
|
|
||||||
|
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY --from=builder /install /usr/local
|
||||||
|
|
||||||
|
# Copy shared module (available to all services)
|
||||||
|
COPY shared/ shared/
|
||||||
|
|
||||||
|
# Copy service source
|
||||||
|
COPY services/auth/app/ app/
|
||||||
|
|
||||||
|
RUN chown -R appuser:appgroup /app
|
||||||
|
|
||||||
|
USER appuser
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
CMD ["gunicorn", "app.main:app", \
|
||||||
|
"-k", "uvicorn.workers.UvicornWorker", \
|
||||||
|
"--bind", "0.0.0.0:8000", \
|
||||||
|
"--workers", "2", \
|
||||||
|
"--timeout", "30"]
|
||||||
16
services/auth/README.md
Normal file
16
services/auth/README.md
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Auth Service
|
||||||
|
|
||||||
|
Owns: user registration, login, JWT RS256 issuance, token refresh, `/me` endpoint.
|
||||||
|
|
||||||
|
## Tables owned
|
||||||
|
- `users`
|
||||||
|
- `refresh_tokens`
|
||||||
|
- `subscriptions` (read; Billing Service writes)
|
||||||
|
|
||||||
|
## Endpoints
|
||||||
|
- `POST /auth/register`
|
||||||
|
- `POST /auth/login`
|
||||||
|
- `POST /auth/refresh`
|
||||||
|
- `GET /auth/me`
|
||||||
|
- `PUT /auth/me`
|
||||||
|
- `GET /auth/verify` (ForwardAuth for Traefik)
|
||||||
0
services/auth/app/__init__.py
Normal file
0
services/auth/app/__init__.py
Normal file
34
services/auth/app/config.py
Normal file
34
services/auth/app/config.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""Auth Service — local configuration.
|
||||||
|
|
||||||
|
Contains secrets that ONLY the Auth Service needs (e.g., JWT private key).
|
||||||
|
These are NOT in shared/config.py to prevent other services from accessing them.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pydantic import field_validator
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class AuthSettings(BaseSettings):
|
||||||
|
# RS256 private key (PEM format). Used to SIGN JWTs.
|
||||||
|
# Only the Auth Service has this. Generate with:
|
||||||
|
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
||||||
|
# Then set the env var (newlines as \n):
|
||||||
|
# JWT_PRIVATE_KEY="-----BEGIN PRIVATE KEY-----\nMIIEv..."
|
||||||
|
JWT_PRIVATE_KEY: str = ""
|
||||||
|
|
||||||
|
# RS256 public key (PEM format). Used to VERIFY JWTs.
|
||||||
|
# Derived from the private key:
|
||||||
|
# openssl rsa -in private.pem -pubout -out public.pem
|
||||||
|
JWT_PUBLIC_KEY: str = ""
|
||||||
|
|
||||||
|
@field_validator("JWT_PRIVATE_KEY", "JWT_PUBLIC_KEY", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _expand_pem_newlines(cls, v: str) -> str:
|
||||||
|
if isinstance(v, str) and r"\n" in v:
|
||||||
|
return v.replace(r"\n", "\n")
|
||||||
|
return v
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
auth_settings = AuthSettings()
|
||||||
69
services/auth/app/deps.py
Normal file
69
services/auth/app/deps.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
"""Auth dependencies — JWT validation for the Auth Service.
|
||||||
|
|
||||||
|
This is the canonical get_current_user used by protected endpoints
|
||||||
|
within the Auth Service itself (/me, /me PUT).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.db import get_session
|
||||||
|
from shared.models import Subscription, User
|
||||||
|
from shared.schemas import UserProfile
|
||||||
|
|
||||||
|
from app.config import auth_settings
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user(
|
||||||
|
token: str = Depends(oauth2_scheme),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> UserProfile:
|
||||||
|
"""Validate a Bearer JWT and return the authenticated user.
|
||||||
|
|
||||||
|
The JWT is used for identity and expiry. Tier is fetched live from the
|
||||||
|
subscriptions table so upgrades/downgrades take effect immediately.
|
||||||
|
"""
|
||||||
|
credentials_exc = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
|
||||||
|
)
|
||||||
|
user_id: str | None = payload.get("sub")
|
||||||
|
email: str | None = payload.get("email")
|
||||||
|
if not user_id or not email:
|
||||||
|
raise credentials_exc
|
||||||
|
except JWTError:
|
||||||
|
raise credentials_exc
|
||||||
|
|
||||||
|
# Live tier lookup
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
|
tier: str = result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
|
# Fetch name/surname
|
||||||
|
user_result = await db.execute(
|
||||||
|
select(User.name, User.surname).where(User.id == user_id)
|
||||||
|
)
|
||||||
|
user_row = user_result.one_or_none()
|
||||||
|
|
||||||
|
return UserProfile(
|
||||||
|
id=user_id,
|
||||||
|
email=email,
|
||||||
|
name=user_row.name if user_row else None,
|
||||||
|
surname=user_row.surname if user_row else None,
|
||||||
|
tier=tier,
|
||||||
|
) # type: ignore[arg-type]
|
||||||
62
services/auth/app/main.py
Normal file
62
services/auth/app/main.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""Auth Service — JWT issuance, user management, ForwardAuth verification.
|
||||||
|
|
||||||
|
Standalone FastAPI service extracted from the adiuva-api monolith.
|
||||||
|
Owns: users, refresh_tokens, subscriptions (read).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Ensure the repo root is on sys.path so "shared" is importable.
|
||||||
|
# In Docker, COPY shared/ puts it at /app/shared/ (already importable).
|
||||||
|
# In local dev, we need to add the repo root (two levels up from this file).
|
||||||
|
_repo_root = str(Path(__file__).resolve().parents[3])
|
||||||
|
if _repo_root not in sys.path:
|
||||||
|
sys.path.insert(0, _repo_root)
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
yield
|
||||||
|
from shared.db import engine
|
||||||
|
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
app = FastAPI(
|
||||||
|
title="Adiuva Auth Service",
|
||||||
|
version="0.1.0",
|
||||||
|
docs_url="/docs" if settings.ENV == "dev" else None,
|
||||||
|
redoc_url=None,
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=settings.CORS_ORIGINS,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.routes import router
|
||||||
|
from app.verify import router as verify_router
|
||||||
|
|
||||||
|
app.include_router(router, prefix="/api/v1")
|
||||||
|
app.include_router(verify_router, prefix="/api/v1")
|
||||||
|
|
||||||
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
|
async def health() -> dict:
|
||||||
|
return {"status": "ok", "service": "auth", "version": app.version}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
249
services/auth/app/routes.py
Normal file
249
services/auth/app/routes.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
"""Auth routes: register, login, refresh, me.
|
||||||
|
|
||||||
|
Extracted from app/api/routes/auth.py — uses shared.* imports instead of app.*.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from jose import jwt
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.db import get_session
|
||||||
|
from shared.models import RefreshToken, Subscription, User
|
||||||
|
from shared.schemas import AuthTokens, UserProfile
|
||||||
|
|
||||||
|
from app.config import auth_settings
|
||||||
|
from app.deps import get_current_user
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Internal helpers ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_password(password: str) -> str:
|
||||||
|
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_password(password: str, hashed: str) -> bool:
|
||||||
|
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_token(plain_token: str) -> str:
|
||||||
|
"""SHA-256 of the plain refresh token string."""
|
||||||
|
return hashlib.sha256(plain_token.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
||||||
|
"""Return (RS256-signed JWT, expires_at_ms)."""
|
||||||
|
now = int(time.time())
|
||||||
|
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||||
|
payload = {
|
||||||
|
"sub": user_id,
|
||||||
|
"email": email,
|
||||||
|
"tier": tier,
|
||||||
|
"exp": exp,
|
||||||
|
"iat": now,
|
||||||
|
}
|
||||||
|
token = jwt.encode(payload, auth_settings.JWT_PRIVATE_KEY, algorithm="RS256")
|
||||||
|
return token, exp * 1000 # ms for client
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_live_tier(db: AsyncSession, user_id: str) -> str:
|
||||||
|
"""Fetch authoritative tier from subscriptions table."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
|
return result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
|
|
||||||
|
# ── Request bodies ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _RegisterRequest(BaseModel):
|
||||||
|
email: str
|
||||||
|
password: str
|
||||||
|
name: str | None = None
|
||||||
|
surname: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class _LoginRequest(BaseModel):
|
||||||
|
email: str
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class _RefreshRequest(BaseModel):
|
||||||
|
refresh_token: str
|
||||||
|
|
||||||
|
|
||||||
|
class _UpdateProfileRequest(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
surname: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def register(
|
||||||
|
body: _RegisterRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
|
"""Create a new account and return JWT tokens."""
|
||||||
|
existing = await db.execute(select(User).where(User.email == body.email))
|
||||||
|
if existing.scalar_one_or_none() is not None:
|
||||||
|
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
email=body.email,
|
||||||
|
name=body.name,
|
||||||
|
surname=body.surname,
|
||||||
|
password_hash=_hash_password(body.password),
|
||||||
|
tier="free",
|
||||||
|
encryption_key=Fernet.generate_key().decode(),
|
||||||
|
)
|
||||||
|
db.add(user)
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
|
)
|
||||||
|
rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
db.add(rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", response_model=AuthTokens)
|
||||||
|
async def login(
|
||||||
|
body: _LoginRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
|
"""Validate credentials and return JWT tokens."""
|
||||||
|
result = await db.execute(select(User).where(User.email == body.email))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not _verify_password(body.password, user.password_hash):
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
||||||
|
|
||||||
|
# Fetch live tier for the JWT claim
|
||||||
|
tier = await _get_live_tier(db, user.id)
|
||||||
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
|
)
|
||||||
|
rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
db.add(rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=AuthTokens)
|
||||||
|
async def refresh(
|
||||||
|
body: _RefreshRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
|
"""Rotate a refresh token and return a new token pair."""
|
||||||
|
token_hash = _hash_token(body.refresh_token)
|
||||||
|
result = await db.execute(
|
||||||
|
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||||
|
)
|
||||||
|
rt = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
||||||
|
|
||||||
|
await db.delete(rt)
|
||||||
|
|
||||||
|
user_result = await db.execute(select(User).where(User.id == rt.user_id))
|
||||||
|
user = user_result.scalar_one_or_none()
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
||||||
|
|
||||||
|
# Fetch live tier for the new JWT
|
||||||
|
tier = await _get_live_tier(db, user.id)
|
||||||
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
|
new_rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=new_expires,
|
||||||
|
)
|
||||||
|
db.add(new_rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserProfile)
|
||||||
|
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
|
||||||
|
"""Return the profile for the authenticated user."""
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me", response_model=UserProfile)
|
||||||
|
async def update_profile(
|
||||||
|
body: _UpdateProfileRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> UserProfile:
|
||||||
|
"""Update the authenticated user's name and surname."""
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
|
||||||
|
if body.name is not None:
|
||||||
|
user.name = body.name
|
||||||
|
if body.surname is not None:
|
||||||
|
user.surname = body.surname
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
|
||||||
|
return UserProfile(
|
||||||
|
id=user.id,
|
||||||
|
email=user.email,
|
||||||
|
name=user.name,
|
||||||
|
surname=user.surname,
|
||||||
|
tier=current_user.tier,
|
||||||
|
)
|
||||||
66
services/auth/app/verify.py
Normal file
66
services/auth/app/verify.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""ForwardAuth verification endpoint for Traefik.
|
||||||
|
|
||||||
|
Traefik calls GET /api/v1/auth/verify on every request to a protected
|
||||||
|
service. This endpoint validates the JWT from the Authorization header
|
||||||
|
and returns identity headers that Traefik injects into downstream requests.
|
||||||
|
|
||||||
|
Downstream services NEVER validate JWTs themselves — they trust the
|
||||||
|
X-User-Id, X-User-Email, X-User-Tier headers injected by Traefik.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Request, Response
|
||||||
|
from fastapi import status as http_status
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.db import async_session
|
||||||
|
from shared.models import Subscription
|
||||||
|
|
||||||
|
from app.config import auth_settings
|
||||||
|
|
||||||
|
router = APIRouter(tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/auth/verify")
|
||||||
|
async def verify(request: Request) -> Response:
|
||||||
|
"""Validate JWT and return identity headers for Traefik ForwardAuth.
|
||||||
|
|
||||||
|
Returns 200 with X-User-* headers on success, 401 on failure.
|
||||||
|
Traefik copies response headers to the downstream request.
|
||||||
|
"""
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header.startswith("Bearer "):
|
||||||
|
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
|
||||||
|
|
||||||
|
token = auth_header[7:] # strip "Bearer "
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
|
||||||
|
)
|
||||||
|
user_id: str | None = payload.get("sub")
|
||||||
|
email: str | None = payload.get("email")
|
||||||
|
if not user_id or not email:
|
||||||
|
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
|
||||||
|
except JWTError:
|
||||||
|
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
|
||||||
|
|
||||||
|
# Live tier lookup from subscriptions table
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
|
tier: str = result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
status_code=http_status.HTTP_200_OK,
|
||||||
|
headers={
|
||||||
|
"X-User-Id": user_id,
|
||||||
|
"X-User-Email": email,
|
||||||
|
"X-User-Tier": tier,
|
||||||
|
},
|
||||||
|
)
|
||||||
11
services/auth/requirements.txt
Normal file
11
services/auth/requirements.txt
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
fastapi>=0.115.0
|
||||||
|
uvicorn[standard]>=0.34.0
|
||||||
|
gunicorn>=22.0.0
|
||||||
|
pydantic>=2.10.0
|
||||||
|
pydantic-settings>=2.7.0
|
||||||
|
python-jose[cryptography]>=3.3.0
|
||||||
|
sqlalchemy>=2.0.0
|
||||||
|
asyncpg>=0.30.0
|
||||||
|
bcrypt>=4.2.0
|
||||||
|
cryptography>=42.0.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
36
services/batch-agent/Dockerfile
Normal file
36
services/batch-agent/Dockerfile
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# ── builder ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS builder
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
COPY services/batch-agent/requirements.txt ./requirements.txt
|
||||||
|
RUN pip install --upgrade pip && \
|
||||||
|
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||||
|
|
||||||
|
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS runtime
|
||||||
|
|
||||||
|
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY --from=builder /install /usr/local
|
||||||
|
|
||||||
|
# Shared module
|
||||||
|
COPY shared/ shared/
|
||||||
|
|
||||||
|
# Service source
|
||||||
|
COPY services/batch-agent/app/ app/
|
||||||
|
|
||||||
|
RUN chown -R appuser:appgroup /app
|
||||||
|
|
||||||
|
USER appuser
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
# Batch runs are long-lived — use a longer timeout than chat (300s vs 120s)
|
||||||
|
CMD ["gunicorn", "app.main:app", \
|
||||||
|
"-k", "uvicorn.workers.UvicornWorker", \
|
||||||
|
"--bind", "0.0.0.0:8000", \
|
||||||
|
"--workers", "2", \
|
||||||
|
"--timeout", "300"]
|
||||||
23
services/batch-agent/README.md
Normal file
23
services/batch-agent/README.md
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# Batch Agent Service
|
||||||
|
|
||||||
|
Owns: agent_runner, journey builder, filesystem_agent, integrations (Gmail, MS Graph).
|
||||||
|
|
||||||
|
## Tables owned
|
||||||
|
- `local_agent_configs`
|
||||||
|
- `cloud_agent_configs`
|
||||||
|
- `agent_run_logs`
|
||||||
|
|
||||||
|
## Endpoints
|
||||||
|
- `GET /agents/catalog`
|
||||||
|
- `POST /agents/can-create`
|
||||||
|
- `POST /agents/trigger`
|
||||||
|
- `GET /agents/{id}/history`
|
||||||
|
|
||||||
|
## Redis channels
|
||||||
|
- Subscribe: `batch:request:{user_id}`
|
||||||
|
- Publish: `ws:out:{user_id}` (journey replies + tool calls)
|
||||||
|
- BRPOP: `tool:result:{call_id}` (30s timeout)
|
||||||
|
- SET+EX: `journey:{user_id}` (session state, TTL 1800s)
|
||||||
|
|
||||||
|
## TODO
|
||||||
|
- [ ] Integrate Langfuse tracing (reuse `services/chat/app/tracing.py` pattern — `trace_span()`, `get_langfuse_callback()`, prompt management). Each batch agent run should create a trace with input/output, link prompts, and pass the LangChain `CallbackHandler` to LLM calls.
|
||||||
0
services/batch-agent/app/__init__.py
Normal file
0
services/batch-agent/app/__init__.py
Normal file
884
services/batch-agent/app/agent_runner.py
Normal file
884
services/batch-agent/app/agent_runner.py
Normal file
@@ -0,0 +1,884 @@
|
|||||||
|
"""Agent run orchestrator — adapted for Batch Agent Service.
|
||||||
|
|
||||||
|
Key changes from monolith app/core/agent_runner.py:
|
||||||
|
- No DeviceConnectionManager — tool calls go through Redis ws_context.
|
||||||
|
- set_current_user / clear_current_user replace set_client_executor.
|
||||||
|
- run_local_agent accepts a serialized dict (from Redis / REST) instead
|
||||||
|
of SQLAlchemy model objects.
|
||||||
|
- _finalize_run writes to PostgreSQL via shared.db.async_session.
|
||||||
|
- Cloud agent import path changed to app.integrations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
||||||
|
from app.agents.note_agent import NOTE_TOOLS
|
||||||
|
from app.agents.project_agent import PROJECT_TOOLS
|
||||||
|
from app.agents.task_agent import TASK_TOOLS
|
||||||
|
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||||
|
from app.llm import get_llm
|
||||||
|
from app.ws_context import execute_on_client, set_current_user, clear_current_user
|
||||||
|
from shared.db import async_session
|
||||||
|
from shared.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
|
from shared.redis import redis_client, ws_out_channel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Concurrency guard ─────────────────────────────────────────────────────
|
||||||
|
_running_agents: set[str] = set()
|
||||||
|
|
||||||
|
|
||||||
|
def is_agent_running(agent_id: str) -> bool:
|
||||||
|
return agent_id in _running_agents
|
||||||
|
|
||||||
|
|
||||||
|
# ── Timeouts ───────────────────────────────────────────────────────────────
|
||||||
|
_TOOL_CALL_TIMEOUT: int = 30
|
||||||
|
_MAX_PROCESSING_STEPS: int = 12
|
||||||
|
_MAX_SCAN_DEPTH: int = 5
|
||||||
|
|
||||||
|
# ── Data-type to tool mapping ─────────────────────────────────────────────
|
||||||
|
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
|
||||||
|
"tasks": TASK_TOOLS,
|
||||||
|
"notes": NOTE_TOOLS,
|
||||||
|
"timelines": TIMELINE_TOOLS,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Step 1: Classification prompt ─────────────────────────────────────────
|
||||||
|
|
||||||
|
_DOMAIN_DESCRIPTIONS: dict[str, str] = {
|
||||||
|
"tasks": (
|
||||||
|
"Action items, to-dos, deliverables — anything that describes work to be done, "
|
||||||
|
"assigned to someone, or tracked with a due date or status."
|
||||||
|
),
|
||||||
|
"notes": (
|
||||||
|
"Documentation, meeting notes, summaries, reference material — "
|
||||||
|
"written content meant to be read and referenced rather than acted on."
|
||||||
|
),
|
||||||
|
"timelines": (
|
||||||
|
"Project milestones, deadlines, scheduled events — "
|
||||||
|
"specific dates that mark a point in the progress of a project."
|
||||||
|
),
|
||||||
|
"projects": (
|
||||||
|
"High-level project entities — only relevant if the file clearly introduces "
|
||||||
|
"a new project or updates the scope of an existing one."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
_STEP1_SYSTEM_PROMPT = """\
|
||||||
|
You are a file classifier for a freelance project management tool.
|
||||||
|
|
||||||
|
Your job is to match a file to an existing project and identify which data domains to extract.
|
||||||
|
|
||||||
|
## Project matching rules (STRICT — follow in order)
|
||||||
|
|
||||||
|
1. Search the file content for any mention of a project name, client name, acronym, or topic
|
||||||
|
that overlaps with the existing projects listed below.
|
||||||
|
2. The match does NOT need to be exact — partial name, abbreviation, or topic similarity is enough.
|
||||||
|
3. STRONGLY PREFER matching an existing project. Only return "new" as an absolute last resort
|
||||||
|
when the file has zero meaningful connection to any listed project.
|
||||||
|
4. When in doubt, pick the closest match from the list.
|
||||||
|
|
||||||
|
## Response format
|
||||||
|
|
||||||
|
Respond ONLY with a JSON object — no markdown, no explanation:
|
||||||
|
|
||||||
|
{{"project_id": "<exact id from the list below, or new>", "new_project_name": "<concise 2-5 word name, only when project_id is new>", "domains": ["tasks", "notes"]}}
|
||||||
|
|
||||||
|
## Domain definitions (only consider domains in the allowed list)
|
||||||
|
|
||||||
|
{domain_definitions}
|
||||||
|
|
||||||
|
## Existing projects
|
||||||
|
|
||||||
|
{projects_list}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── Step 2: Processing prompt ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
_PROCESSING_SYSTEM_PROMPT = """\
|
||||||
|
You are a data extraction assistant for a freelance project management tool.
|
||||||
|
|
||||||
|
Your task: extract structured data from the file content and persist it using the available tools.
|
||||||
|
|
||||||
|
## Mandatory process — follow this order for EVERY item you extract
|
||||||
|
|
||||||
|
1. READ the existing records listed below for the relevant domain.
|
||||||
|
2. SEARCH for a match by title, topic, or semantic similarity.
|
||||||
|
3. If a match exists → call the update_* tool with the existing record's id.
|
||||||
|
4. If no match exists → call the create_* tool and set isAiSuggested=1.
|
||||||
|
|
||||||
|
NEVER call create_* without first checking the existing records.
|
||||||
|
NEVER duplicate a record that already exists under a different wording.
|
||||||
|
|
||||||
|
## Existing records (source of truth)
|
||||||
|
|
||||||
|
{existing_context}
|
||||||
|
|
||||||
|
## Context
|
||||||
|
|
||||||
|
Project: {project_context}
|
||||||
|
Domains to extract: {data_types}
|
||||||
|
|
||||||
|
{custom_prompt_section}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── Cloud processing prompt ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
_CLOUD_PROCESSING_PROMPT = """\
|
||||||
|
You are a data extraction and management assistant for a freelance project
|
||||||
|
management tool.
|
||||||
|
|
||||||
|
Available tools:
|
||||||
|
Filesystem : read_file_content, list_directory, get_file_metadata
|
||||||
|
Tasks : list_tasks, create_task, update_task, add_task_comment
|
||||||
|
Notes : list_notes, get_note, create_note, update_note
|
||||||
|
Timelines : list_timelines, create_timeline, update_timeline
|
||||||
|
Projects : list_all_projects, get_project, create_project, update_project
|
||||||
|
|
||||||
|
Your task:
|
||||||
|
1. Read the full content of each file below using read_file_content.
|
||||||
|
2. For each piece of information found, ALWAYS try to match and update an
|
||||||
|
existing record before creating a new one.
|
||||||
|
3. ONLY act on these entity types: {data_types}.
|
||||||
|
4. Do NOT invent data. Only extract what is clearly present in the files.
|
||||||
|
5. If a file contains no relevant data for the target entity types, skip it.
|
||||||
|
|
||||||
|
{project_context}
|
||||||
|
|
||||||
|
Files to process:
|
||||||
|
{file_list}
|
||||||
|
|
||||||
|
{custom_prompt_section}
|
||||||
|
|
||||||
|
After processing all files, respond with a brief summary of what you updated
|
||||||
|
and what you created.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM tool-calling loop ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _as_text(content: Any) -> str:
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_agent_with_tools(
|
||||||
|
*,
|
||||||
|
system_prompt: str,
|
||||||
|
user_message: str,
|
||||||
|
tools: list[Any],
|
||||||
|
max_steps: int,
|
||||||
|
) -> str:
|
||||||
|
"""Run an LLM agent with tool-calling, returning the final text response."""
|
||||||
|
llm = get_llm()
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
messages: list[Any] = [
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(content=user_message),
|
||||||
|
]
|
||||||
|
|
||||||
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
|
for _ in range(max_steps):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
return _as_text(response.content)
|
||||||
|
|
||||||
|
for call in response.tool_calls:
|
||||||
|
call_id = str(call.get("id", ""))
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: tool_call name=%s args=%s",
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: tool_result name=%s output=%s",
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:200],
|
||||||
|
)
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
return _as_text(final.content)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tool list builder ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _build_processing_tools(data_types: list[str]) -> list[Any]:
|
||||||
|
tools: list[Any] = list(FILESYSTEM_TOOLS)
|
||||||
|
for dt in data_types:
|
||||||
|
dt_tools = _DATA_TYPE_TOOLS.get(dt)
|
||||||
|
if dt_tools:
|
||||||
|
tools.extend(dt_tools)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
|
||||||
|
# ── Code-based directory scanner ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _scan_directories(
|
||||||
|
paths: list[str],
|
||||||
|
extensions: list[str],
|
||||||
|
last_run_at: datetime | None,
|
||||||
|
) -> list[str]:
|
||||||
|
all_files: list[str] = []
|
||||||
|
ext_set = {e.lstrip(".").lower() for e in extensions} if extensions else set()
|
||||||
|
|
||||||
|
async def _walk(path: str, depth: int) -> None:
|
||||||
|
if depth > _MAX_SCAN_DEPTH:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(action="list_directory", data={"path": path})
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: list_directory failed %r: %s", path, exc)
|
||||||
|
return
|
||||||
|
for entry in result.get("entries", []):
|
||||||
|
entry_path = entry.get("path", "")
|
||||||
|
if not entry_path:
|
||||||
|
continue
|
||||||
|
if entry.get("type") == "directory":
|
||||||
|
await _walk(entry_path, depth + 1)
|
||||||
|
elif entry.get("type") == "file":
|
||||||
|
if ext_set:
|
||||||
|
dot_pos = entry_path.rfind(".")
|
||||||
|
file_ext = entry_path[dot_pos + 1:].lower() if dot_pos != -1 else ""
|
||||||
|
if file_ext not in ext_set:
|
||||||
|
continue
|
||||||
|
all_files.append(entry_path)
|
||||||
|
|
||||||
|
for root in paths:
|
||||||
|
await _walk(root, depth=0)
|
||||||
|
|
||||||
|
if last_run_at is None:
|
||||||
|
return all_files
|
||||||
|
|
||||||
|
last_run_ms = int(last_run_at.timestamp() * 1000)
|
||||||
|
filtered: list[str] = []
|
||||||
|
for file_path in all_files:
|
||||||
|
try:
|
||||||
|
meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
|
||||||
|
modified_at = meta.get("modifiedAt")
|
||||||
|
if modified_at is None:
|
||||||
|
filtered.append(file_path)
|
||||||
|
continue
|
||||||
|
if isinstance(modified_at, (int, float)):
|
||||||
|
mod_ms = int(modified_at)
|
||||||
|
else:
|
||||||
|
mod_ms = int(datetime.fromisoformat(str(modified_at)).timestamp() * 1000)
|
||||||
|
if mod_ms > last_run_ms:
|
||||||
|
filtered.append(file_path)
|
||||||
|
except Exception:
|
||||||
|
filtered.append(file_path)
|
||||||
|
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
|
# ── Code-based entity fetchers ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_projects() -> list[dict]:
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(action="select", table="projects")
|
||||||
|
return result.get("rows", [])
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: failed to fetch projects: %s", exc)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
_DOMAIN_TABLE: dict[str, str] = {
|
||||||
|
"tasks": "tasks",
|
||||||
|
"notes": "notes",
|
||||||
|
"timelines": "timelines",
|
||||||
|
"projects": "projects",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_domain_entities(domain: str, project_id: str) -> list[dict]:
|
||||||
|
table = _DOMAIN_TABLE.get(domain)
|
||||||
|
if not table:
|
||||||
|
return []
|
||||||
|
filters: dict[str, Any] = {}
|
||||||
|
if project_id != "standalone" and domain != "projects":
|
||||||
|
filters["projectId"] = project_id
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table=table,
|
||||||
|
filters=filters if filters else None,
|
||||||
|
)
|
||||||
|
return result.get("rows", [])
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: failed to fetch %s: %s", domain, exc)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _format_entities_for_context(domain: str, rows: list[dict]) -> str:
|
||||||
|
if not rows:
|
||||||
|
return f"No existing {domain}."
|
||||||
|
lines: list[str] = []
|
||||||
|
for r in rows:
|
||||||
|
if domain == "tasks":
|
||||||
|
desc = r.get("description") or ""
|
||||||
|
desc_part = f" — {desc[:120]}" if desc else ""
|
||||||
|
assignee = r.get("assignee") or r.get("assignees") or ""
|
||||||
|
due = r.get("dueDate") or r.get("due_date") or ""
|
||||||
|
meta = ", ".join(filter(None, [
|
||||||
|
f"priority: {r.get('priority', '')}" if r.get("priority") else "",
|
||||||
|
f"assignee: {assignee}" if assignee else "",
|
||||||
|
f"due: {due}" if due else "",
|
||||||
|
]))
|
||||||
|
lines.append(
|
||||||
|
f" - [{r.get('status', '?')}] {r.get('title', '')}{desc_part}"
|
||||||
|
f" ({meta}, id: {r['id']})"
|
||||||
|
)
|
||||||
|
elif domain == "notes":
|
||||||
|
snippet = (r.get("content") or "")[:200].replace("\n", " ")
|
||||||
|
snippet_part = f"\n Preview: {snippet}" if snippet else ""
|
||||||
|
lines.append(
|
||||||
|
f" - {r.get('title', '')} (id: {r['id']}){snippet_part}"
|
||||||
|
)
|
||||||
|
elif domain == "timelines":
|
||||||
|
lines.append(
|
||||||
|
f" - {r.get('title', '')} date={r.get('date', '')} (id: {r['id']})"
|
||||||
|
)
|
||||||
|
elif domain == "projects":
|
||||||
|
summary = (r.get("aiSummary") or r.get("ai_summary") or "")[:120]
|
||||||
|
summary_part = f" — {summary}" if summary else ""
|
||||||
|
lines.append(
|
||||||
|
f" - {r.get('name', '')} [{r.get('status', '')}]{summary_part}"
|
||||||
|
f" (id: {r['id']})"
|
||||||
|
)
|
||||||
|
return f"Existing {domain}:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 1: LLM file classifier ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _classify_file(
|
||||||
|
file_path: str,
|
||||||
|
file_content: str,
|
||||||
|
projects: list[dict],
|
||||||
|
config_data_types: list[str],
|
||||||
|
) -> tuple[str, list[str], str | None]:
|
||||||
|
fallback: tuple[str, list[str], str | None] = ("new", list(config_data_types), None)
|
||||||
|
|
||||||
|
if not file_content.strip():
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
valid_project_ids = {p["id"] for p in projects}
|
||||||
|
|
||||||
|
def _fmt_project(p: dict) -> str:
|
||||||
|
summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip()
|
||||||
|
summary_part = f" — {summary[:100]}" if summary else ""
|
||||||
|
return f" - id={p['id']} | name={p.get('name', '')} | status={p.get('status', '')}{summary_part}"
|
||||||
|
|
||||||
|
projects_list = "\n".join(_fmt_project(p) for p in projects) or " (none yet)"
|
||||||
|
|
||||||
|
domain_definitions = "\n".join(
|
||||||
|
f" - {d}: {_DOMAIN_DESCRIPTIONS[d]}"
|
||||||
|
for d in config_data_types
|
||||||
|
if d in _DOMAIN_DESCRIPTIONS
|
||||||
|
)
|
||||||
|
|
||||||
|
system = _STEP1_SYSTEM_PROMPT.format(
|
||||||
|
domain_definitions=domain_definitions,
|
||||||
|
projects_list=projects_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = get_llm()
|
||||||
|
try:
|
||||||
|
response = await llm.ainvoke([
|
||||||
|
SystemMessage(content=system),
|
||||||
|
HumanMessage(content=f"File: {file_path}\n\nContent:\n{file_content[:4000]}"),
|
||||||
|
])
|
||||||
|
raw = _as_text(response.content).strip()
|
||||||
|
if raw.startswith("```"):
|
||||||
|
raw = raw.split("```")[1]
|
||||||
|
if raw.startswith("json"):
|
||||||
|
raw = raw[4:]
|
||||||
|
parsed = json.loads(raw.strip())
|
||||||
|
raw_project_id: str = str(parsed.get("project_id") or "new")
|
||||||
|
project_id = raw_project_id if raw_project_id in valid_project_ids else "new"
|
||||||
|
new_project_name: str | None = (
|
||||||
|
str(parsed["new_project_name"]).strip() or None
|
||||||
|
if project_id == "new" and parsed.get("new_project_name")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
domains: list[str] = [
|
||||||
|
d for d in parsed.get("domains", [])
|
||||||
|
if d in config_data_types
|
||||||
|
]
|
||||||
|
if not domains:
|
||||||
|
domains = list(config_data_types)
|
||||||
|
return project_id, domains, new_project_name
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"agent_runner: step1 classification failed for %r: %s", file_path, exc
|
||||||
|
)
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local agent runner (two-step per file) ────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def run_local_agent(user_id: str, trigger_data: dict[str, Any]) -> None:
|
||||||
|
"""Execute a local directory agent run.
|
||||||
|
|
||||||
|
In the microservice world, trigger_data is a serialized dict from
|
||||||
|
the REST route (forwarded via Redis), containing the agent config
|
||||||
|
fields and run_context.
|
||||||
|
|
||||||
|
set_current_user() must be called BEFORE this function.
|
||||||
|
"""
|
||||||
|
run_context: dict = trigger_data.get("run_context", {})
|
||||||
|
agent_id = run_context.get("agent_id", str(uuid.uuid4()))
|
||||||
|
run_id = run_context.get("run_id")
|
||||||
|
|
||||||
|
_running_agents.add(agent_id)
|
||||||
|
|
||||||
|
# Extract config from trigger payload
|
||||||
|
directory_paths: list[str] = trigger_data.get("directory_paths", [])
|
||||||
|
if not directory_paths:
|
||||||
|
directory = trigger_data.get("directory", "")
|
||||||
|
if directory:
|
||||||
|
directory_paths = [directory]
|
||||||
|
|
||||||
|
data_types: list[str] = trigger_data.get("data_types", [])
|
||||||
|
file_extensions: list[str] = trigger_data.get("file_extensions", [])
|
||||||
|
prompt_template: str = trigger_data.get("prompt_template", "")
|
||||||
|
last_run_at_raw = trigger_data.get("last_run_at")
|
||||||
|
last_run_at: datetime | None = None
|
||||||
|
if last_run_at_raw:
|
||||||
|
if isinstance(last_run_at_raw, str):
|
||||||
|
last_run_at = datetime.fromisoformat(last_run_at_raw)
|
||||||
|
elif isinstance(last_run_at_raw, (int, float)):
|
||||||
|
last_run_at = datetime.fromtimestamp(last_run_at_raw / 1000, tz=timezone.utc)
|
||||||
|
|
||||||
|
errors: list[str] = []
|
||||||
|
items_processed = 0
|
||||||
|
items_created = 0
|
||||||
|
|
||||||
|
custom_section = (
|
||||||
|
f"User instructions:\n{prompt_template}"
|
||||||
|
if prompt_template
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create or load run log
|
||||||
|
run_log_id = run_id
|
||||||
|
if not run_log_id:
|
||||||
|
async with async_session() as db:
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent_type="local",
|
||||||
|
user_id=user_id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
db.add(run_log)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(run_log)
|
||||||
|
run_log_id = run_log.id
|
||||||
|
|
||||||
|
try:
|
||||||
|
# ── Scan directories ─────────────────────────────────────────
|
||||||
|
logger.info("agent_runner: run=%s scanning directories user=%s", run_log_id, user_id)
|
||||||
|
file_paths = await _scan_directories(
|
||||||
|
paths=directory_paths,
|
||||||
|
extensions=file_extensions,
|
||||||
|
last_run_at=last_run_at,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: run=%s found %d file(s) after filtering", run_log_id, len(file_paths)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not file_paths:
|
||||||
|
await _finalize_run(run_log_id, status="success", items_processed=0, items_created=0)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Fetch all projects once ──────────────────────────────────
|
||||||
|
projects = await _fetch_projects()
|
||||||
|
|
||||||
|
for file_path in file_paths:
|
||||||
|
try:
|
||||||
|
file_result = await execute_on_client(
|
||||||
|
action="read_file_content", data={"path": file_path}
|
||||||
|
)
|
||||||
|
file_content: str = file_result.get("content", "")
|
||||||
|
if not file_content:
|
||||||
|
continue
|
||||||
|
|
||||||
|
items_processed += 1
|
||||||
|
|
||||||
|
# Step 1 — classify file
|
||||||
|
project_id, domains, new_project_name = await _classify_file(
|
||||||
|
file_path=file_path,
|
||||||
|
file_content=file_content,
|
||||||
|
projects=projects,
|
||||||
|
config_data_types=data_types,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 2 — resolve project_id, fetch entities, process
|
||||||
|
if project_id == "new":
|
||||||
|
proj_name = new_project_name or "Untitled Project"
|
||||||
|
try:
|
||||||
|
proj_result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="projects",
|
||||||
|
data={"name": proj_name, "clientId": None},
|
||||||
|
)
|
||||||
|
created = proj_result.get("row", {})
|
||||||
|
effective_project_id = created.get("id", "standalone")
|
||||||
|
if "id" in created:
|
||||||
|
projects.append(created)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: run=%s create project failed: %s", run_log_id, exc)
|
||||||
|
effective_project_id = "standalone"
|
||||||
|
proj_name = "unknown"
|
||||||
|
project_context = (
|
||||||
|
f"Project: {proj_name} (id: {effective_project_id}). "
|
||||||
|
"Always set projectId to this id on every record you create."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
effective_project_id = project_id
|
||||||
|
proj = next((p for p in projects if p["id"] == project_id), None)
|
||||||
|
proj_name = proj.get("name", project_id) if proj else project_id
|
||||||
|
project_context = (
|
||||||
|
f"Project: {proj_name} (id: {project_id}). "
|
||||||
|
"Always set projectId to this id on every record you create."
|
||||||
|
)
|
||||||
|
|
||||||
|
domains = [d for d in domains if d != "projects"]
|
||||||
|
|
||||||
|
existing_blocks: list[str] = []
|
||||||
|
for domain in domains:
|
||||||
|
rows = await _fetch_domain_entities(domain, effective_project_id)
|
||||||
|
existing_blocks.append(_format_entities_for_context(domain, rows))
|
||||||
|
|
||||||
|
existing_context = "\n\n".join(existing_blocks)
|
||||||
|
|
||||||
|
system_prompt = _PROCESSING_SYSTEM_PROMPT.format(
|
||||||
|
existing_context=existing_context,
|
||||||
|
project_context=project_context,
|
||||||
|
data_types=", ".join(domains),
|
||||||
|
custom_prompt_section=custom_section,
|
||||||
|
)
|
||||||
|
|
||||||
|
processing_tools = _build_processing_tools(domains)
|
||||||
|
|
||||||
|
result_text = await _run_agent_with_tools(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
user_message=(
|
||||||
|
f"Process this file and extract relevant information.\n\n"
|
||||||
|
f"File: {file_path}\n\nContent:\n{file_content}"
|
||||||
|
),
|
||||||
|
tools=processing_tools,
|
||||||
|
max_steps=_MAX_PROCESSING_STEPS,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: run=%s file=%r result=%s",
|
||||||
|
run_log_id, file_path, result_text[:200],
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Error processing '{file_path}': {exc}")
|
||||||
|
logger.error("agent_runner: run=%s file=%r failed: %s", run_log_id, file_path, exc)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Agent run failed: {exc}")
|
||||||
|
logger.error("agent_runner: run=%s failed: %s", run_log_id, exc)
|
||||||
|
finally:
|
||||||
|
_running_agents.discard(agent_id)
|
||||||
|
|
||||||
|
# ── Finalise ────────────────────────────────────────────────────
|
||||||
|
if errors and items_processed == 0:
|
||||||
|
final_status = "error"
|
||||||
|
elif errors:
|
||||||
|
final_status = "partial"
|
||||||
|
else:
|
||||||
|
final_status = "success"
|
||||||
|
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status=final_status,
|
||||||
|
items_processed=items_processed,
|
||||||
|
items_created=items_created,
|
||||||
|
errors=errors,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Notify Electron that the run is complete via Redis
|
||||||
|
if run_context:
|
||||||
|
try:
|
||||||
|
channel = ws_out_channel(user_id)
|
||||||
|
await redis_client.publish(channel, json.dumps({
|
||||||
|
"type": "run_complete",
|
||||||
|
"run_context": run_context,
|
||||||
|
"status": final_status,
|
||||||
|
}))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: run=%s failed to send run_complete: %s", run_log_id, exc)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud agent runner ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
|
||||||
|
|
||||||
|
|
||||||
|
async def run_cloud_agent(user_id: str, config_id: str) -> None:
|
||||||
|
"""Execute a cloud connector agent run.
|
||||||
|
|
||||||
|
Loads the CloudAgentConfig from DB, decrypts OAuth tokens, fetches
|
||||||
|
messages from the provider, and runs LLM extraction.
|
||||||
|
|
||||||
|
set_current_user() must be called BEFORE this function.
|
||||||
|
"""
|
||||||
|
from app.integrations import decrypt_token, encrypt_token, get_provider
|
||||||
|
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
config = result.scalar_one_or_none()
|
||||||
|
if config is None:
|
||||||
|
logger.error("agent_runner: cloud config %s not found", config_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create run log
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
agent_id=config.id,
|
||||||
|
agent_type="cloud",
|
||||||
|
user_id=user_id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
db.add(run_log)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(run_log)
|
||||||
|
run_log_id = run_log.id
|
||||||
|
|
||||||
|
# ── Decrypt OAuth token ────────────────────────────────────────
|
||||||
|
if not config.oauth_token_encrypted:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status="error",
|
||||||
|
errors=[f"No OAuth token stored for cloud agent '{config.name}'"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
credentials_info = decrypt_token(config.oauth_token_encrypted)
|
||||||
|
except ValueError as exc:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Failed to decrypt OAuth token: {exc}"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Instantiate provider ──────────────────────────────────────
|
||||||
|
try:
|
||||||
|
provider = get_provider(config.provider, credentials_info)
|
||||||
|
except ValueError as exc:
|
||||||
|
await _finalize_run(run_log_id, status="error", errors=[str(exc)])
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Fetch messages ────────────────────────────────────────────
|
||||||
|
since: datetime | None = config.last_run_at
|
||||||
|
if since is None:
|
||||||
|
since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
|
||||||
|
if since.tzinfo is None:
|
||||||
|
since = since.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
errors: list[str] = []
|
||||||
|
items_processed = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
if config.provider == "gmail":
|
||||||
|
raw_messages = await provider.fetch_messages(
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
elif config.provider == "outlook":
|
||||||
|
raw_messages = await provider.fetch_emails(
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
elif config.provider == "teams":
|
||||||
|
raw_messages = await provider.fetch_messages(
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raw_messages = []
|
||||||
|
except RuntimeError as exc:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Provider fetch failed: {exc}"],
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: cloud agent %s fetched %d item(s) from %s",
|
||||||
|
config.id, len(raw_messages), config.provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Extract + insert via LLM ─────────────────────────────────
|
||||||
|
try:
|
||||||
|
processing_tools = _build_processing_tools(config.data_types)
|
||||||
|
custom_section = (
|
||||||
|
f"User instructions:\n{config.prompt_template}"
|
||||||
|
if config.prompt_template
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
for msg in raw_messages:
|
||||||
|
content_text = msg.as_text
|
||||||
|
if not content_text:
|
||||||
|
continue
|
||||||
|
items_processed += 1
|
||||||
|
|
||||||
|
processing_prompt = _CLOUD_PROCESSING_PROMPT.format(
|
||||||
|
data_types=", ".join(config.data_types),
|
||||||
|
project_context="Determine the appropriate project from the message context.",
|
||||||
|
file_list=f"Message from {config.provider} (id: {msg.id})",
|
||||||
|
custom_prompt_section=custom_section,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await _run_agent_with_tools(
|
||||||
|
system_prompt=processing_prompt,
|
||||||
|
user_message=f"Process this message content:\n\n{content_text[:8000]}",
|
||||||
|
tools=processing_tools,
|
||||||
|
max_steps=_MAX_PROCESSING_STEPS,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Agent run failed: {exc}")
|
||||||
|
|
||||||
|
# ── Persist refreshed token ───────────────────────────────────
|
||||||
|
refreshed = getattr(provider, "refreshed_credentials", None)
|
||||||
|
if refreshed:
|
||||||
|
try:
|
||||||
|
new_encrypted = encrypt_token(refreshed)
|
||||||
|
async with async_session() as db:
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id)
|
||||||
|
)
|
||||||
|
cfg_row = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg_row:
|
||||||
|
cfg_row.oauth_token_encrypted = new_encrypted
|
||||||
|
await db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: failed to persist refreshed token: %s", exc)
|
||||||
|
|
||||||
|
# ── Finalise ──────────────────────────────────────────────────
|
||||||
|
if errors and items_processed == 0:
|
||||||
|
final_status = "error"
|
||||||
|
elif errors:
|
||||||
|
final_status = "partial"
|
||||||
|
else:
|
||||||
|
final_status = "success"
|
||||||
|
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status=final_status,
|
||||||
|
items_processed=items_processed,
|
||||||
|
items_created=0,
|
||||||
|
errors=errors,
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Internal helper ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _finalize_run(
|
||||||
|
run_log_id: int | str,
|
||||||
|
*,
|
||||||
|
status: str,
|
||||||
|
items_processed: int = 0,
|
||||||
|
items_created: int = 0,
|
||||||
|
errors: list[str] | None = None,
|
||||||
|
update_config_last_run: bool = False,
|
||||||
|
config_id: str | None = None,
|
||||||
|
config_type: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Persist the run outcome and optionally update last_run_at on the config."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
try:
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(
|
||||||
|
select(AgentRunLog).where(AgentRunLog.id == run_log_id)
|
||||||
|
)
|
||||||
|
managed = result.scalar_one_or_none()
|
||||||
|
if managed is None:
|
||||||
|
logger.warning("agent_runner: run_log %s not found for finalization", run_log_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
managed.status = status
|
||||||
|
managed.items_processed = items_processed
|
||||||
|
managed.items_created = items_created
|
||||||
|
managed.errors = errors or []
|
||||||
|
managed.completed_at = now
|
||||||
|
|
||||||
|
if update_config_last_run and config_id:
|
||||||
|
if config_type == "local":
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg:
|
||||||
|
cfg.last_run_at = now
|
||||||
|
elif config_type == "cloud":
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg:
|
||||||
|
cfg.last_run_at = now
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("agent_runner: failed to finalize run_log=%s: %s", run_log_id, exc)
|
||||||
1
services/batch-agent/app/agents/__init__.py
Normal file
1
services/batch-agent/app/agents/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Batch Agent Service domain agents and filesystem tools."""
|
||||||
83
services/batch-agent/app/agents/filesystem_agent.py
Normal file
83
services/batch-agent/app/agents/filesystem_agent.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""Filesystem agent — tools for reading local directories and files on Electron.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from app.ws_context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.ws_context import execute_on_client
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_directory(path: str) -> str:
|
||||||
|
"""List files and folders in a local directory on the user's device.
|
||||||
|
|
||||||
|
Returns a formatted listing of entries with name, type (file/directory),
|
||||||
|
and full path.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="list_directory",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
entries: list[dict[str, Any]] = result.get("entries", [])
|
||||||
|
if not entries:
|
||||||
|
return f"Directory '{path}' is empty or does not exist."
|
||||||
|
lines: list[str] = []
|
||||||
|
for entry in entries:
|
||||||
|
entry_type = entry.get("type", "unknown")
|
||||||
|
entry_name = entry.get("name", "")
|
||||||
|
entry_path = entry.get("path", "")
|
||||||
|
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
||||||
|
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def read_file_content(path: str) -> str:
|
||||||
|
"""Read the text content of a local file on the user's device.
|
||||||
|
|
||||||
|
Returns the file content as a string. Large files may be truncated
|
||||||
|
by the Electron client.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="read_file_content",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
content: str = result.get("content", "")
|
||||||
|
if not content:
|
||||||
|
return f"File '{path}' is empty or could not be read."
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_file_metadata(path: str) -> str:
|
||||||
|
"""Get metadata for a local file: size, creation date, modification date, extension.
|
||||||
|
|
||||||
|
Returns a formatted summary of the file's metadata.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="get_file_metadata",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
size = result.get("size", "unknown")
|
||||||
|
created = result.get("createdAt", "unknown")
|
||||||
|
modified = result.get("modifiedAt", "unknown")
|
||||||
|
extension = result.get("extension", "unknown")
|
||||||
|
name = result.get("name", path)
|
||||||
|
return (
|
||||||
|
f"File: {name}\n"
|
||||||
|
f" Extension: {extension}\n"
|
||||||
|
f" Size: {size} bytes\n"
|
||||||
|
f" Created: {created}\n"
|
||||||
|
f" Modified: {modified}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
FILESYSTEM_TOOLS: list[Any] = [
|
||||||
|
list_directory,
|
||||||
|
read_file_content,
|
||||||
|
get_file_metadata,
|
||||||
|
]
|
||||||
110
services/batch-agent/app/agents/note_agent.py
Normal file
110
services/batch-agent/app/agents/note_agent.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
"""Note agent — Markdown note management.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from app.ws_context and app.llm.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.llm import embed
|
||||||
|
from app.ws_context import execute_on_client
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_notes(project_id: str = "") -> str:
|
||||||
|
"""List notes, optionally scoped to a project by project_id."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="notes",
|
||||||
|
filters={"projectId": normalized_project_id or None},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No notes found."
|
||||||
|
lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
|
||||||
|
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_note(note_id: str) -> str:
|
||||||
|
"""Fetch a single note by its UUID to read its full Markdown content."""
|
||||||
|
result = await execute_on_client(action="get", table="notes", data={"id": note_id})
|
||||||
|
row = result.get("row")
|
||||||
|
if not row:
|
||||||
|
return f"Note {note_id} not found."
|
||||||
|
return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def create_note(title: str, content: str, project_id: str = "") -> str:
|
||||||
|
"""Create a new note."""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="notes",
|
||||||
|
data={
|
||||||
|
"title": title,
|
||||||
|
"content": content,
|
||||||
|
"projectId": project_id or None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
vector = await embed(content)
|
||||||
|
await execute_on_client(
|
||||||
|
action="vector_upsert",
|
||||||
|
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
|
||||||
|
vector=vector,
|
||||||
|
)
|
||||||
|
return f"Note created: '{row['title']}' (id: {row['id']})."
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_note(note_id: str, title: str = "", content: str = "") -> str:
|
||||||
|
"""Update an existing note. Only pass fields that should change."""
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if title:
|
||||||
|
updates["title"] = title
|
||||||
|
if content:
|
||||||
|
updates["content"] = content
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="update",
|
||||||
|
table="notes",
|
||||||
|
data={"id": note_id, "updates": updates},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
if content:
|
||||||
|
vector = await embed(content)
|
||||||
|
await execute_on_client(
|
||||||
|
action="vector_upsert",
|
||||||
|
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
|
||||||
|
vector=vector,
|
||||||
|
)
|
||||||
|
return f"Note updated: '{row['title']}' (id: {row['id']})."
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_note(note_id: str) -> str:
|
||||||
|
"""Delete a note permanently by its UUID."""
|
||||||
|
await execute_on_client(action="delete", table="notes", data={"id": note_id})
|
||||||
|
return f"Note {note_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
NOTE_TOOLS: list[Any] = [
|
||||||
|
list_notes,
|
||||||
|
get_note,
|
||||||
|
create_note,
|
||||||
|
update_note,
|
||||||
|
delete_note,
|
||||||
|
]
|
||||||
110
services/batch-agent/app/agents/project_agent.py
Normal file
110
services/batch-agent/app/agents/project_agent.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
"""Project agent — full lifecycle management.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from app.ws_context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.ws_context import execute_on_client
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_projects(client_id: str = "", include_archived: int = 0) -> str:
|
||||||
|
"""List projects, optionally filtered by client_id."""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="projects",
|
||||||
|
filters={
|
||||||
|
"clientId": client_id or None,
|
||||||
|
"includeArchived": bool(include_archived),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No projects found."
|
||||||
|
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
|
||||||
|
return f"Found {len(rows)} project(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_all_projects() -> str:
|
||||||
|
"""List every project regardless of client or status."""
|
||||||
|
result = await execute_on_client(action="select", table="projects")
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No projects found."
|
||||||
|
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
|
||||||
|
return f"All projects ({len(rows)}):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_project(project_id: str) -> str:
|
||||||
|
"""Fetch a single project by its UUID."""
|
||||||
|
result = await execute_on_client(action="get", table="projects", data={"id": project_id})
|
||||||
|
row = result.get("row")
|
||||||
|
if not row:
|
||||||
|
return f"Project {project_id} not found."
|
||||||
|
return (
|
||||||
|
f"Project: '{row['name']}' (id: {row['id']}, status: {row['status']}, "
|
||||||
|
f"clientId: {row.get('clientId', 'none')})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def create_project(name: str, client_id: str = "") -> str:
|
||||||
|
"""Create a new project."""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="projects",
|
||||||
|
data={"name": name, "clientId": client_id or None},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Project created: '{row['name']}' (id: {row['id']})"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_project(
|
||||||
|
project_id: str,
|
||||||
|
name: str = "",
|
||||||
|
client_id: str = "",
|
||||||
|
status: str = "",
|
||||||
|
ai_summary: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Update a project. Only pass fields that should change."""
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if name:
|
||||||
|
updates["name"] = name
|
||||||
|
if client_id:
|
||||||
|
updates["clientId"] = client_id
|
||||||
|
if status:
|
||||||
|
updates["status"] = status
|
||||||
|
if ai_summary:
|
||||||
|
updates["aiSummary"] = ai_summary
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="update",
|
||||||
|
table="projects",
|
||||||
|
data={"id": project_id, "updates": updates},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Project updated: '{row['name']}' (id: {row['id']}, status: {row['status']})"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_project(project_id: str) -> str:
|
||||||
|
"""Permanently delete a project."""
|
||||||
|
await execute_on_client(action="delete", table="projects", data={"id": project_id})
|
||||||
|
return f"Project {project_id} permanently deleted."
|
||||||
|
|
||||||
|
|
||||||
|
PROJECT_TOOLS: list[Any] = [
|
||||||
|
list_projects,
|
||||||
|
list_all_projects,
|
||||||
|
get_project,
|
||||||
|
create_project,
|
||||||
|
update_project,
|
||||||
|
delete_project,
|
||||||
|
]
|
||||||
197
services/batch-agent/app/agents/task_agent.py
Normal file
197
services/batch-agent/app/agents/task_agent.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
"""Task agent — full CRUD for tasks and task comments.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from app.ws_context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.ws_context import execute_on_client
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_tasks(
|
||||||
|
project_id: str = "",
|
||||||
|
status: str = "",
|
||||||
|
search: str = "",
|
||||||
|
order_by: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""List tasks, optionally filtered by project_id, status, search, or order_by."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="tasks",
|
||||||
|
filters={
|
||||||
|
"projectId": normalized_project_id or None,
|
||||||
|
"status": status or None,
|
||||||
|
"search": search or None,
|
||||||
|
"orderBy": order_by or None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No tasks found matching the given filters."
|
||||||
|
lines = [
|
||||||
|
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})"
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def create_task(
|
||||||
|
title: str,
|
||||||
|
description: str = "",
|
||||||
|
status: str = "todo",
|
||||||
|
priority: str = "medium",
|
||||||
|
assignees: str = "[]",
|
||||||
|
due_date: int = 0,
|
||||||
|
project_id: str = "",
|
||||||
|
is_ai_suggested: int = 0,
|
||||||
|
) -> str:
|
||||||
|
"""Create a new task."""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="tasks",
|
||||||
|
data={
|
||||||
|
"title": title,
|
||||||
|
"description": description or None,
|
||||||
|
"status": status,
|
||||||
|
"priority": priority,
|
||||||
|
"assignee": assignees,
|
||||||
|
"dueDate": due_date or None,
|
||||||
|
"projectId": project_id or None,
|
||||||
|
"isAiSuggested": is_ai_suggested,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return (
|
||||||
|
f"Task created: '{row['title']}' "
|
||||||
|
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_task(
|
||||||
|
task_id: str,
|
||||||
|
title: str = "",
|
||||||
|
description: str = "",
|
||||||
|
status: str = "",
|
||||||
|
priority: str = "",
|
||||||
|
assignees: str = "",
|
||||||
|
due_date: int = -1,
|
||||||
|
project_id: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Update fields on an existing task. Only pass fields you want to change."""
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if title:
|
||||||
|
updates["title"] = title
|
||||||
|
if description:
|
||||||
|
updates["description"] = description
|
||||||
|
if status:
|
||||||
|
updates["status"] = status
|
||||||
|
if priority:
|
||||||
|
updates["priority"] = priority
|
||||||
|
if assignees:
|
||||||
|
updates["assignee"] = assignees
|
||||||
|
if due_date != -1:
|
||||||
|
updates["dueDate"] = due_date or None
|
||||||
|
if project_id:
|
||||||
|
updates["projectId"] = project_id
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="update",
|
||||||
|
table="tasks",
|
||||||
|
data={"id": task_id, "updates": updates},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_task(task_id: str) -> str:
|
||||||
|
"""Delete a task permanently by its UUID."""
|
||||||
|
await execute_on_client(action="delete", table="tasks", data={"id": task_id})
|
||||||
|
return f"Task {task_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_tasks_due_today() -> str:
|
||||||
|
"""List all tasks whose due date falls on today's date."""
|
||||||
|
now = datetime.now(tz=timezone.utc)
|
||||||
|
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
|
||||||
|
end_ms = start_ms + 86_400_000 - 1
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="tasks",
|
||||||
|
filters={"dueDateFrom": start_ms, "dueDateTo": end_ms},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No tasks are due today."
|
||||||
|
lines = [
|
||||||
|
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})"
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_task_comments(task_id: str) -> str:
|
||||||
|
"""List all comments on a task by its UUID."""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="taskComments",
|
||||||
|
filters={"taskId": task_id},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return f"No comments found for task {task_id}."
|
||||||
|
lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows]
|
||||||
|
return f"Found {len(rows)} comment(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
||||||
|
"""Add a comment to a task."""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="taskComments",
|
||||||
|
data={"taskId": task_id, "author": author, "content": content},
|
||||||
|
)
|
||||||
|
row = result.get("row", {})
|
||||||
|
row_author = row.get("author", author)
|
||||||
|
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
||||||
|
row_comment_id = row.get("id", "unknown")
|
||||||
|
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_task_comment(comment_id: str) -> str:
|
||||||
|
"""Delete a task comment by its UUID."""
|
||||||
|
await execute_on_client(action="delete", table="taskComments", data={"id": comment_id})
|
||||||
|
return f"Comment {comment_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
TASK_TOOLS: list[Any] = [
|
||||||
|
list_tasks,
|
||||||
|
create_task,
|
||||||
|
update_task,
|
||||||
|
delete_task,
|
||||||
|
list_tasks_due_today,
|
||||||
|
list_task_comments,
|
||||||
|
add_task_comment,
|
||||||
|
delete_task_comment,
|
||||||
|
]
|
||||||
88
services/batch-agent/app/agents/timeline_agent.py
Normal file
88
services/batch-agent/app/agents/timeline_agent.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
"""Timeline agent — project milestone management.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from app.ws_context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.ws_context import execute_on_client
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_timelines(project_id: str = "") -> str:
|
||||||
|
"""List timelines. Provide project_id to scope to a specific project."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="timelines",
|
||||||
|
filters={"projectId": normalized_project_id or None},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No timelines found."
|
||||||
|
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
|
||||||
|
return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def create_timeline(
|
||||||
|
project_id: str, title: str, date: int, is_ai_suggested: int = 0,
|
||||||
|
) -> str:
|
||||||
|
"""Create a project timeline (milestone)."""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="timelines",
|
||||||
|
data={
|
||||||
|
"projectId": project_id,
|
||||||
|
"title": title,
|
||||||
|
"date": date,
|
||||||
|
"isAiSuggested": is_ai_suggested,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_timeline(timeline_id: str, title: str = "", date: int = -1) -> str:
|
||||||
|
"""Update a timeline. Only pass fields that should change."""
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if title:
|
||||||
|
updates["title"] = title
|
||||||
|
if date != -1:
|
||||||
|
updates["date"] = date
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="update",
|
||||||
|
table="timelines",
|
||||||
|
data={"id": timeline_id, "updates": updates},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Timeline updated: '{row['title']}' (id: {row['id']})"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_timeline(timeline_id: str) -> str:
|
||||||
|
"""Delete a timeline permanently by its UUID."""
|
||||||
|
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
|
||||||
|
return f"Timeline {timeline_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
TIMELINE_TOOLS: list[Any] = [
|
||||||
|
list_timelines,
|
||||||
|
create_timeline,
|
||||||
|
update_timeline,
|
||||||
|
delete_timeline,
|
||||||
|
]
|
||||||
108
services/batch-agent/app/integrations/__init__.py
Normal file
108
services/batch-agent/app/integrations/__init__.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""Cloud provider integration utilities.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from shared.config instead of app.config.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
* Shared message dataclasses (EmailMessage, ChatMessage)
|
||||||
|
* get_provider() — factory for Gmail/MS Graph clients
|
||||||
|
* encrypt_token() / decrypt_token() — Fernet-based OAuth token encryption
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.integrations.gmail import GmailClient
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmailMessage:
|
||||||
|
id: str
|
||||||
|
subject: str
|
||||||
|
sender: str
|
||||||
|
body_text: str
|
||||||
|
date: datetime
|
||||||
|
labels: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def as_text(self) -> str:
|
||||||
|
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||||
|
labels_str = f" [{', '.join(self.labels)}]" if self.labels else ""
|
||||||
|
return (
|
||||||
|
f"From: {self.sender}\n"
|
||||||
|
f"Date: {date_str}{labels_str}\n"
|
||||||
|
f"Subject: {self.subject}\n\n"
|
||||||
|
f"{self.body_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatMessage:
|
||||||
|
id: str
|
||||||
|
content: str
|
||||||
|
sender: str
|
||||||
|
channel: str | None
|
||||||
|
date: datetime
|
||||||
|
|
||||||
|
@property
|
||||||
|
def as_text(self) -> str:
|
||||||
|
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||||
|
channel_str = f" [channel: {self.channel}]" if self.channel else ""
|
||||||
|
return (
|
||||||
|
f"From: {self.sender}\n"
|
||||||
|
f"Date: {date_str}{channel_str}\n\n"
|
||||||
|
f"{self.content}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_fernet() -> Fernet:
|
||||||
|
key = settings.OAUTH_ENCRYPTION_KEY
|
||||||
|
if not key:
|
||||||
|
raise RuntimeError(
|
||||||
|
"OAUTH_ENCRYPTION_KEY is not set. "
|
||||||
|
"Generate one with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
|
||||||
|
)
|
||||||
|
return Fernet(key.encode() if isinstance(key, str) else key)
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt_token(token_info: dict) -> str:
|
||||||
|
if not isinstance(token_info, dict) or not token_info:
|
||||||
|
raise ValueError("token_info must be a non-empty dict")
|
||||||
|
plaintext = json.dumps(token_info).encode("utf-8")
|
||||||
|
return _get_fernet().encrypt(plaintext).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_token(encrypted: str) -> dict:
|
||||||
|
try:
|
||||||
|
plaintext = _get_fernet().decrypt(encrypted.encode("utf-8"))
|
||||||
|
return json.loads(plaintext)
|
||||||
|
except (InvalidToken, json.JSONDecodeError) as exc:
|
||||||
|
raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider(
|
||||||
|
provider: str,
|
||||||
|
credentials_info: dict,
|
||||||
|
) -> "GmailClient | MSGraphClient":
|
||||||
|
if provider == "gmail":
|
||||||
|
from app.integrations.gmail import GmailClient
|
||||||
|
return GmailClient(credentials_info)
|
||||||
|
if provider in {"outlook", "teams"}:
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
return MSGraphClient(credentials_info)
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown cloud provider {provider!r}. "
|
||||||
|
"Supported: 'gmail', 'outlook', 'teams'."
|
||||||
|
)
|
||||||
252
services/batch-agent/app/integrations/gmail.py
Normal file
252
services/batch-agent/app/integrations/gmail.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
"""Gmail API client for cloud agent integration.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from app.integrations instead of
|
||||||
|
app.integrations (same relative path within the service).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import email
|
||||||
|
import html
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.integrations import EmailMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_GMAIL_DATE_FMT = "%Y/%m/%d"
|
||||||
|
_BODY_TRUNCATE = 8_000
|
||||||
|
_MAX_MESSAGES = 200
|
||||||
|
|
||||||
|
|
||||||
|
def _build_gmail_query(
|
||||||
|
filter_config: dict[str, Any] | None,
|
||||||
|
since: datetime | None,
|
||||||
|
) -> str:
|
||||||
|
parts: list[str] = []
|
||||||
|
cfg = filter_config or {}
|
||||||
|
|
||||||
|
labels: list[str] = cfg.get("labels", [])
|
||||||
|
if labels:
|
||||||
|
if len(labels) == 1:
|
||||||
|
parts.append(f"label:{labels[0]}")
|
||||||
|
else:
|
||||||
|
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
|
||||||
|
parts.append(f"({label_expr})")
|
||||||
|
|
||||||
|
senders: list[str] = cfg.get("senders", [])
|
||||||
|
for sender in senders:
|
||||||
|
parts.append(f"from:{sender}")
|
||||||
|
|
||||||
|
date_range: dict = cfg.get("date_range", {})
|
||||||
|
from_str: str | None = date_range.get("from")
|
||||||
|
to_str: str | None = date_range.get("to")
|
||||||
|
|
||||||
|
effective_since: datetime | None = since
|
||||||
|
if from_str:
|
||||||
|
try:
|
||||||
|
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||||
|
if cfg_since.tzinfo is None:
|
||||||
|
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||||
|
if effective_since is None or cfg_since > effective_since:
|
||||||
|
effective_since = cfg_since
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("gmail: invalid date_range.from %r — ignoring", from_str)
|
||||||
|
|
||||||
|
if effective_since:
|
||||||
|
parts.append(f"after:{effective_since.strftime(_GMAIL_DATE_FMT)}")
|
||||||
|
|
||||||
|
if to_str:
|
||||||
|
try:
|
||||||
|
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||||
|
parts.append(f"before:{to_dt.strftime(_GMAIL_DATE_FMT)}")
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("gmail: invalid date_range.to %r — ignoring", to_str)
|
||||||
|
|
||||||
|
return " ".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_html(raw_html: str) -> str:
|
||||||
|
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
|
||||||
|
decoded = html.unescape(no_tags)
|
||||||
|
return re.sub(r"\s+", " ", decoded).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_body(payload: dict[str, Any]) -> str:
|
||||||
|
mime_type: str = payload.get("mimeType", "")
|
||||||
|
body: dict = payload.get("body", {})
|
||||||
|
parts: list[dict] = payload.get("parts", [])
|
||||||
|
|
||||||
|
if mime_type == "text/plain":
|
||||||
|
data = body.get("data", "")
|
||||||
|
if data:
|
||||||
|
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if mime_type == "text/html":
|
||||||
|
data = body.get("data", "")
|
||||||
|
if data:
|
||||||
|
raw = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||||
|
return _strip_html(raw)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
plain_fallback = ""
|
||||||
|
for part in parts:
|
||||||
|
part_mime = part.get("mimeType", "")
|
||||||
|
if part_mime == "text/plain":
|
||||||
|
return _parse_body(part)
|
||||||
|
if part_mime == "text/html" and not plain_fallback:
|
||||||
|
plain_fallback = _parse_body(part)
|
||||||
|
if part_mime.startswith("multipart/"):
|
||||||
|
nested = _parse_body(part)
|
||||||
|
if nested:
|
||||||
|
return nested
|
||||||
|
return plain_fallback
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_date(raw: str) -> datetime:
|
||||||
|
try:
|
||||||
|
parsed = email.utils.parsedate_to_datetime(raw)
|
||||||
|
if parsed.tzinfo is None:
|
||||||
|
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||||
|
return parsed.astimezone(timezone.utc)
|
||||||
|
except Exception:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
class GmailClient:
|
||||||
|
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||||
|
from google.oauth2.credentials import Credentials
|
||||||
|
|
||||||
|
self._credentials_info = credentials_info
|
||||||
|
expiry_str: str | None = credentials_info.get("expiry")
|
||||||
|
expiry: datetime | None = None
|
||||||
|
if expiry_str:
|
||||||
|
try:
|
||||||
|
expiry = datetime.fromisoformat(
|
||||||
|
expiry_str.replace("Z", "+00:00")
|
||||||
|
).replace(tzinfo=timezone.utc)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._credentials = Credentials(
|
||||||
|
token=credentials_info.get("token"),
|
||||||
|
refresh_token=credentials_info.get("refresh_token"),
|
||||||
|
token_uri=credentials_info.get("token_uri", "https://oauth2.googleapis.com/token"),
|
||||||
|
client_id=credentials_info.get("client_id"),
|
||||||
|
client_secret=credentials_info.get("client_secret"),
|
||||||
|
scopes=credentials_info.get("scopes"),
|
||||||
|
expiry=expiry,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def fetch_messages(
|
||||||
|
self,
|
||||||
|
filter_config: dict[str, Any] | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> list[EmailMessage]:
|
||||||
|
query = _build_gmail_query(filter_config, since)
|
||||||
|
logger.debug("gmail: executing search query %r", query)
|
||||||
|
return await asyncio.to_thread(self._fetch_sync, query)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||||
|
creds = self._credentials
|
||||||
|
if not creds.valid and creds.expired:
|
||||||
|
return None
|
||||||
|
if creds.token != self._credentials_info.get("token"):
|
||||||
|
result = {
|
||||||
|
"token": creds.token,
|
||||||
|
"refresh_token": creds.refresh_token,
|
||||||
|
"token_uri": creds.token_uri,
|
||||||
|
"client_id": creds.client_id,
|
||||||
|
"client_secret": creds.client_secret,
|
||||||
|
"scopes": list(creds.scopes or []),
|
||||||
|
}
|
||||||
|
if creds.expiry:
|
||||||
|
result["expiry"] = creds.expiry.isoformat()
|
||||||
|
return result
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _fetch_sync(self, query: str) -> list[EmailMessage]:
|
||||||
|
import googleapiclient.discovery
|
||||||
|
import googleapiclient.errors
|
||||||
|
from google.auth.transport.requests import Request
|
||||||
|
|
||||||
|
if self._credentials.expired and self._credentials.refresh_token:
|
||||||
|
try:
|
||||||
|
self._credentials.refresh(Request())
|
||||||
|
except Exception as exc:
|
||||||
|
raise RuntimeError(f"Gmail token refresh failed: {exc}") from exc
|
||||||
|
|
||||||
|
service = googleapiclient.discovery.build(
|
||||||
|
"gmail", "v1", credentials=self._credentials, cache_discovery=False
|
||||||
|
)
|
||||||
|
user_api = service.users()
|
||||||
|
|
||||||
|
ids: list[str] = []
|
||||||
|
page_token: str | None = None
|
||||||
|
while len(ids) < _MAX_MESSAGES:
|
||||||
|
batch_size = min(100, _MAX_MESSAGES - len(ids))
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"userId": "me",
|
||||||
|
"maxResults": batch_size,
|
||||||
|
}
|
||||||
|
if query:
|
||||||
|
kwargs["q"] = query
|
||||||
|
if page_token:
|
||||||
|
kwargs["pageToken"] = page_token
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = user_api.messages().list(**kwargs).execute()
|
||||||
|
except googleapiclient.errors.HttpError as exc:
|
||||||
|
raise RuntimeError(f"Gmail messages.list failed: {exc}") from exc
|
||||||
|
|
||||||
|
for msg in resp.get("messages", []):
|
||||||
|
ids.append(msg["id"])
|
||||||
|
|
||||||
|
page_token = resp.get("nextPageToken")
|
||||||
|
if not page_token:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info("gmail: fetching %d message(s)", len(ids))
|
||||||
|
|
||||||
|
messages: list[EmailMessage] = []
|
||||||
|
for msg_id in ids:
|
||||||
|
try:
|
||||||
|
msg = user_api.messages().get(
|
||||||
|
userId="me", id=msg_id, format="full"
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
headers: dict[str, str] = {
|
||||||
|
h["name"].lower(): h["value"]
|
||||||
|
for h in msg.get("payload", {}).get("headers", [])
|
||||||
|
}
|
||||||
|
subject = headers.get("subject", "(no subject)")
|
||||||
|
sender = headers.get("from", "unknown")
|
||||||
|
date_raw = headers.get("date", "")
|
||||||
|
date = _parse_date(date_raw) if date_raw else datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
body_text = _parse_body(msg.get("payload", {}))[:_BODY_TRUNCATE]
|
||||||
|
labels = msg.get("labelIds", [])
|
||||||
|
|
||||||
|
messages.append(EmailMessage(
|
||||||
|
id=msg_id,
|
||||||
|
subject=subject,
|
||||||
|
sender=sender,
|
||||||
|
body_text=body_text,
|
||||||
|
date=date,
|
||||||
|
labels=labels,
|
||||||
|
))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("gmail: skipping message %s: %s", msg_id, exc)
|
||||||
|
|
||||||
|
logger.info("gmail: returned %d message(s)", len(messages))
|
||||||
|
return messages
|
||||||
266
services/batch-agent/app/integrations/ms_graph.py
Normal file
266
services/batch-agent/app/integrations/ms_graph.py
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
"""Microsoft Graph API client for Outlook and Teams.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import settings from shared.config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from app.integrations import ChatMessage, EmailMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"
|
||||||
|
|
||||||
|
_MAX_EMAILS = 200
|
||||||
|
_MAX_MESSAGES = 200
|
||||||
|
_BODY_TRUNCATE = 8_000
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_html(raw: str) -> str:
|
||||||
|
no_tags = re.sub(r"<[^>]+>", " ", raw)
|
||||||
|
import html as _html
|
||||||
|
decoded = _html.unescape(no_tags)
|
||||||
|
return re.sub(r"\s+", " ", decoded).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _odata_datetime(dt: datetime) -> str:
|
||||||
|
utc = dt.astimezone(timezone.utc)
|
||||||
|
return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_email_filter(
|
||||||
|
filter_config: dict[str, Any] | None,
|
||||||
|
since: datetime | None,
|
||||||
|
) -> str:
|
||||||
|
clauses: list[str] = []
|
||||||
|
cfg = filter_config or {}
|
||||||
|
|
||||||
|
senders: list[str] = cfg.get("senders", [])
|
||||||
|
if senders:
|
||||||
|
sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders]
|
||||||
|
clauses.append("(" + " or ".join(sender_clauses) + ")")
|
||||||
|
|
||||||
|
date_range: dict = cfg.get("date_range", {})
|
||||||
|
from_str: str | None = date_range.get("from")
|
||||||
|
|
||||||
|
effective_since: datetime | None = since
|
||||||
|
if from_str:
|
||||||
|
try:
|
||||||
|
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||||
|
if cfg_since.tzinfo is None:
|
||||||
|
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||||
|
if effective_since is None or cfg_since > effective_since:
|
||||||
|
effective_since = cfg_since
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("ms_graph: invalid date_range.from %r — ignoring", from_str)
|
||||||
|
|
||||||
|
if effective_since:
|
||||||
|
clauses.append(f"receivedDateTime ge {_odata_datetime(effective_since)}")
|
||||||
|
|
||||||
|
to_str: str | None = date_range.get("to")
|
||||||
|
if to_str:
|
||||||
|
try:
|
||||||
|
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||||
|
if to_dt.tzinfo is None:
|
||||||
|
to_dt = to_dt.replace(tzinfo=timezone.utc)
|
||||||
|
clauses.append(f"receivedDateTime le {_odata_datetime(to_dt)}")
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("ms_graph: invalid date_range.to %r — ignoring", to_str)
|
||||||
|
|
||||||
|
return " and ".join(clauses)
|
||||||
|
|
||||||
|
|
||||||
|
class MSGraphClient:
|
||||||
|
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||||
|
self._credentials_info = credentials_info
|
||||||
|
self._access_token: str = credentials_info.get("access_token", "")
|
||||||
|
self._original_access_token: str = self._access_token
|
||||||
|
self._refresh_token: str | None = credentials_info.get("refresh_token")
|
||||||
|
|
||||||
|
def _auth_headers(self) -> dict[str, str]:
|
||||||
|
return {"Authorization": f"Bearer {self._access_token}"}
|
||||||
|
|
||||||
|
async def _refresh_access_token(self) -> None:
|
||||||
|
import msal
|
||||||
|
|
||||||
|
app = msal.ConfidentialClientApplication(
|
||||||
|
client_id=settings.MS_CLIENT_ID,
|
||||||
|
client_credential=settings.MS_CLIENT_SECRET,
|
||||||
|
authority=f"https://login.microsoftonline.com/{settings.MS_TENANT_ID}",
|
||||||
|
)
|
||||||
|
scopes: list[str] = self._credentials_info.get("scope", "").split()
|
||||||
|
if not scopes:
|
||||||
|
scopes = ["https://graph.microsoft.com/.default"]
|
||||||
|
|
||||||
|
result = app.acquire_token_by_refresh_token(
|
||||||
|
self._refresh_token,
|
||||||
|
scopes=scopes,
|
||||||
|
)
|
||||||
|
if "access_token" not in result:
|
||||||
|
error = result.get("error_description", result.get("error", "unknown"))
|
||||||
|
raise RuntimeError(f"MS Graph token refresh failed: {error}")
|
||||||
|
|
||||||
|
self._access_token = result["access_token"]
|
||||||
|
if "refresh_token" in result:
|
||||||
|
self._refresh_token = result["refresh_token"]
|
||||||
|
self._credentials_info["refresh_token"] = result["refresh_token"]
|
||||||
|
self._credentials_info["access_token"] = self._access_token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||||
|
if self._access_token != self._original_access_token:
|
||||||
|
return {**self._credentials_info, "access_token": self._access_token}
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _get(
|
||||||
|
self,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
url: str,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
*,
|
||||||
|
retry_on_401: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||||
|
if resp.status_code == 401 and retry_on_401 and self._refresh_token:
|
||||||
|
await self._refresh_access_token()
|
||||||
|
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||||
|
if resp.status_code == 429:
|
||||||
|
raise RuntimeError("MS Graph rate limit hit (429). Try again later.")
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
async def fetch_emails(
|
||||||
|
self,
|
||||||
|
filter_config: dict[str, Any] | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> list[EmailMessage]:
|
||||||
|
odata_filter = _build_email_filter(filter_config, since)
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"$top": 50,
|
||||||
|
"$select": "id,subject,from,receivedDateTime,body,bodyPreview",
|
||||||
|
"$orderby": "receivedDateTime desc",
|
||||||
|
}
|
||||||
|
if odata_filter:
|
||||||
|
params["$filter"] = odata_filter
|
||||||
|
|
||||||
|
emails: list[EmailMessage] = []
|
||||||
|
url = f"{_GRAPH_BASE}/me/messages"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
while url and len(emails) < _MAX_EMAILS:
|
||||||
|
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
||||||
|
for item in data.get("value", []):
|
||||||
|
emails.append(self._parse_email(item))
|
||||||
|
if len(emails) >= _MAX_EMAILS:
|
||||||
|
break
|
||||||
|
url = data.get("@odata.nextLink", "")
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
|
||||||
|
return emails
|
||||||
|
|
||||||
|
async def fetch_messages(
|
||||||
|
self,
|
||||||
|
filter_config: dict[str, Any] | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> list[ChatMessage]:
|
||||||
|
cfg = filter_config or {}
|
||||||
|
channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])]
|
||||||
|
params: dict[str, Any] = {"$top": 50}
|
||||||
|
if since:
|
||||||
|
params["$filter"] = f"createdDateTime ge {_odata_datetime(since)}"
|
||||||
|
|
||||||
|
messages: list[ChatMessage] = []
|
||||||
|
url = f"{_GRAPH_BASE}/me/chats/getAllMessages"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
while url and len(messages) < _MAX_MESSAGES:
|
||||||
|
try:
|
||||||
|
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
if exc.response.status_code in (403, 404):
|
||||||
|
logger.warning(
|
||||||
|
"ms_graph: /me/chats/getAllMessages not available (%d)",
|
||||||
|
exc.response.status_code,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
raise
|
||||||
|
|
||||||
|
for item in data.get("value", []):
|
||||||
|
msg = self._parse_teams_message(item)
|
||||||
|
if channel_filter and msg.channel:
|
||||||
|
if not any(c in msg.channel.lower() for c in channel_filter):
|
||||||
|
continue
|
||||||
|
messages.append(msg)
|
||||||
|
if len(messages) >= _MAX_MESSAGES:
|
||||||
|
break
|
||||||
|
url = data.get("@odata.nextLink", "")
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
|
||||||
|
return messages
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_email(item: dict[str, Any]) -> EmailMessage:
|
||||||
|
subject: str = item.get("subject", "(no subject)") or "(no subject)"
|
||||||
|
sender_block = item.get("from", {}) or {}
|
||||||
|
sender_addr = (
|
||||||
|
(sender_block.get("emailAddress") or {}).get("address", "unknown")
|
||||||
|
)
|
||||||
|
date_str: str = item.get("receivedDateTime", "")
|
||||||
|
try:
|
||||||
|
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||||
|
except Exception:
|
||||||
|
date = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
body_block = item.get("body", {}) or {}
|
||||||
|
content_type: str = body_block.get("contentType", "text")
|
||||||
|
raw_body: str = body_block.get("content", "")
|
||||||
|
if content_type == "html":
|
||||||
|
body_text = _strip_html(raw_body)
|
||||||
|
else:
|
||||||
|
body_text = raw_body or item.get("bodyPreview", "")
|
||||||
|
body_text = body_text[:_BODY_TRUNCATE]
|
||||||
|
|
||||||
|
return EmailMessage(
|
||||||
|
id=item.get("id", ""),
|
||||||
|
subject=subject,
|
||||||
|
sender=sender_addr,
|
||||||
|
body_text=body_text,
|
||||||
|
date=date,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_teams_message(item: dict[str, Any]) -> ChatMessage:
|
||||||
|
msg_id: str = item.get("id", "")
|
||||||
|
sender_block = (item.get("from") or {}).get("user") or {}
|
||||||
|
sender: str = sender_block.get("displayName", "unknown")
|
||||||
|
channel: str | None = (item.get("channelIdentity") or {}).get("channelId")
|
||||||
|
|
||||||
|
date_str: str = item.get("createdDateTime", "")
|
||||||
|
try:
|
||||||
|
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||||
|
except Exception:
|
||||||
|
date = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
body_block = item.get("body", {}) or {}
|
||||||
|
content_type: str = body_block.get("contentType", "text")
|
||||||
|
raw_content: str = body_block.get("content", "")
|
||||||
|
content = _strip_html(raw_content) if content_type == "html" else raw_content
|
||||||
|
content = content[:_BODY_TRUNCATE]
|
||||||
|
|
||||||
|
return ChatMessage(
|
||||||
|
id=msg_id,
|
||||||
|
content=content,
|
||||||
|
sender=sender,
|
||||||
|
channel=channel,
|
||||||
|
date=date,
|
||||||
|
)
|
||||||
385
services/batch-agent/app/journey.py
Normal file
385
services/batch-agent/app/journey.py
Normal file
@@ -0,0 +1,385 @@
|
|||||||
|
"""Chatbot Journey — guided conversation to build an agent prompt_template.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: imports from app.agents.filesystem_agent
|
||||||
|
and app.llm instead of monolith paths. Session state is in-memory (could
|
||||||
|
be moved to Redis for horizontal scaling in the future).
|
||||||
|
|
||||||
|
Journey flow:
|
||||||
|
1. Redis consumer dispatches ``journey_start`` with basic agent config.
|
||||||
|
2. Server creates an in-memory session, runs the setup LLM with
|
||||||
|
file-system tools to explore the directory, returns first question.
|
||||||
|
3. ``journey_message`` frames drive the conversation.
|
||||||
|
4. After 3-5 turns the LLM emits PROMPT_TEMPLATE_START / _END block.
|
||||||
|
5. Server parses the block and returns ``journey_reply`` with ``done=True``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
|
|
||||||
|
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
||||||
|
from app.llm import get_llm
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Session TTL ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||||
|
|
||||||
|
# Sentinel strings used to delimit the LLM-produced prompt_template.
|
||||||
|
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
||||||
|
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
||||||
|
|
||||||
|
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
||||||
|
_MAX_TURNS: int = 15
|
||||||
|
_MAX_TOOL_STEPS: int = 6
|
||||||
|
|
||||||
|
# ── In-memory session store ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class JourneySession:
|
||||||
|
session_id: str
|
||||||
|
user_id: str
|
||||||
|
agent_type: str # "local" | "cloud"
|
||||||
|
directory: str
|
||||||
|
data_types: list[str]
|
||||||
|
history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
system_prompt: str = ""
|
||||||
|
created_at: float = field(default_factory=time.monotonic)
|
||||||
|
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS
|
||||||
|
|
||||||
|
|
||||||
|
# session_id → session
|
||||||
|
_sessions: dict[str, JourneySession] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
||||||
|
"""Retrieve session; return None on missing, expired, or wrong owner."""
|
||||||
|
s = _sessions.get(session_id)
|
||||||
|
if s is None or s.is_expired():
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
return None
|
||||||
|
if s.user_id != user_id:
|
||||||
|
return None
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
# ── System prompt builder ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT_TEMPLATE = """\
|
||||||
|
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||||
|
Your job is to understand exactly what data the user wants to extract from their
|
||||||
|
local directory and produce a detailed prompt_template that a separate AI will use
|
||||||
|
as its instruction set.
|
||||||
|
|
||||||
|
The extraction agent already has this base behaviour built in:
|
||||||
|
- Reads each file using file-system tools.
|
||||||
|
- Creates records (tasks, notes, timelines, projects) via CRUD tools.
|
||||||
|
- Sets isAiSuggested=1 on every new record.
|
||||||
|
- Only extracts data explicitly present in the files — it never invents information.
|
||||||
|
The user's custom prompt is appended AFTER this base behaviour, so focus on
|
||||||
|
what to look for and how to map it — not on the general extraction mechanics.
|
||||||
|
|
||||||
|
You have access to file-system tools to explore the user's directory:
|
||||||
|
- list_directory: to see folder structure
|
||||||
|
- read_file_content: to peek at file contents
|
||||||
|
- get_file_metadata: to check file info
|
||||||
|
|
||||||
|
The user's configured directory is: {directory}
|
||||||
|
Target data types: {data_types}
|
||||||
|
|
||||||
|
IMPORTANT — project assignment is handled automatically by the main agent runner
|
||||||
|
before the custom prompt is ever used. You MUST NOT ask the user about projects,
|
||||||
|
projectId, or how to link records to projects. Never include projectId logic or
|
||||||
|
project creation instructions in the generated prompt_template.
|
||||||
|
|
||||||
|
Start by exploring the directory to understand its structure. Then ask concise,
|
||||||
|
focused questions one at a time. Cover these topics (not necessarily in this order):
|
||||||
|
1. The type and format of the source content (confirmed by your exploration).
|
||||||
|
2. How fields should be mapped (e.g. filename → task title).
|
||||||
|
3. Priority or status rules (e.g. "urgent" keyword → high priority).
|
||||||
|
4. Any special handling, date extraction, or exclusions.
|
||||||
|
|
||||||
|
Once you reach 90% confidence, output the final prompt_template between these exact
|
||||||
|
markers on their own lines:
|
||||||
|
|
||||||
|
{template_start}
|
||||||
|
<the complete extraction prompt here>
|
||||||
|
{template_end}
|
||||||
|
|
||||||
|
The prompt_template must be a self-contained instruction for an AI that reads files
|
||||||
|
and must perform CRUD operations using tools to create records. It should specify:
|
||||||
|
- What entity types to create (tasks, notes, timelines) — never projects.
|
||||||
|
- How to map file content to record fields (camelCase: title, status, priority,
|
||||||
|
dueDate, content, etc.) — never include projectId.
|
||||||
|
- That isAiSuggested must be set to 1 on every new record.
|
||||||
|
- Concrete examples of mappings based on what you discovered in the directory.
|
||||||
|
|
||||||
|
{existing_section}\
|
||||||
|
Keep asking clarifying questions until you are at least 90% confident you have
|
||||||
|
enough information to generate an accurate prompt_template. Once you reach that
|
||||||
|
confidence level, stop asking and produce the final template immediately.
|
||||||
|
Begin by exploring the directory, then ask your first question.\
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _build_system_prompt(
|
||||||
|
directory: str,
|
||||||
|
data_types: list[str],
|
||||||
|
existing_template: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
existing_section = (
|
||||||
|
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
||||||
|
f"---\n{existing_template}\n---\n"
|
||||||
|
if existing_template
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
return _SYSTEM_PROMPT_TEMPLATE.format(
|
||||||
|
directory=directory,
|
||||||
|
data_types=", ".join(data_types),
|
||||||
|
template_start=_TEMPLATE_START,
|
||||||
|
template_end=_TEMPLATE_END,
|
||||||
|
existing_section=existing_section,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Template extraction ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_template(text: str) -> str | None:
|
||||||
|
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
|
||||||
|
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
|
||||||
|
return None
|
||||||
|
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
|
||||||
|
end_idx = text.index(_TEMPLATE_END)
|
||||||
|
return text[start_idx:end_idx].strip() or None
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM call with tool support ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _as_text(content: Any) -> str:
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_llm_with_tools(
|
||||||
|
system_prompt: str,
|
||||||
|
history: list[dict[str, Any]],
|
||||||
|
tools: list[Any],
|
||||||
|
) -> str:
|
||||||
|
"""Build LangChain messages from history and invoke the LLM with tools.
|
||||||
|
|
||||||
|
Handles tool-calling loops: if the LLM calls tools, execute them and
|
||||||
|
continue until a final text response is produced.
|
||||||
|
"""
|
||||||
|
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||||
|
for turn in history:
|
||||||
|
if turn["role"] == "user":
|
||||||
|
messages.append(HumanMessage(content=turn["content"]))
|
||||||
|
else:
|
||||||
|
messages.append(AIMessage(content=turn["content"]))
|
||||||
|
|
||||||
|
llm = get_llm(model=None, temperature=0.4)
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
|
for _ in range(_MAX_TOOL_STEPS):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
return _as_text(response.content)
|
||||||
|
|
||||||
|
for call in response.tool_calls:
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"journey: tool_call name=%s args=%s",
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:500],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"journey: tool_result name=%s output=%s",
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:800],
|
||||||
|
)
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
# Fallback: exceeded max tool steps.
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
return _as_text(final.content)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Journey handlers (called from redis_consumer) ────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_journey_start(
|
||||||
|
user_id: str,
|
||||||
|
frame: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Handle a ``journey_start`` request.
|
||||||
|
|
||||||
|
Creates a session, runs the setup LLM with directory exploration,
|
||||||
|
and returns the ``journey_reply`` payload.
|
||||||
|
"""
|
||||||
|
agent_type = frame.get("agent_type", "local")
|
||||||
|
directory = frame.get("directory", "")
|
||||||
|
data_types = frame.get("data_types", [])
|
||||||
|
existing_template = frame.get("existing_template")
|
||||||
|
|
||||||
|
session_id = frame.get("session_id") or str(uuid.uuid4())
|
||||||
|
system_prompt = _build_system_prompt(directory, data_types, existing_template)
|
||||||
|
|
||||||
|
session = JourneySession(
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
agent_type=agent_type,
|
||||||
|
directory=directory,
|
||||||
|
data_types=data_types,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
seed_history: list[dict[str, Any]] = [
|
||||||
|
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
||||||
|
]
|
||||||
|
ai_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history=seed_history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
)
|
||||||
|
|
||||||
|
session.history.extend(seed_history)
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
_sessions[session_id] = session
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"journey: session %s started for user %s (directory=%s)",
|
||||||
|
session_id,
|
||||||
|
user_id,
|
||||||
|
directory,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_template = _extract_template(ai_reply)
|
||||||
|
done = prompt_template is not None
|
||||||
|
|
||||||
|
display_message = ai_reply
|
||||||
|
if done:
|
||||||
|
display_message = (
|
||||||
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
|
or "Here is your agent configuration. You can save it or continue refining."
|
||||||
|
)
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": display_message,
|
||||||
|
"done": done,
|
||||||
|
"prompt_template": prompt_template,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_journey_message(
|
||||||
|
user_id: str,
|
||||||
|
frame: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Handle a ``journey_message`` request.
|
||||||
|
|
||||||
|
Appends the user message, calls the LLM, and returns the
|
||||||
|
``journey_reply`` payload.
|
||||||
|
"""
|
||||||
|
session_id = frame.get("session_id", "")
|
||||||
|
message = frame.get("message", "")
|
||||||
|
|
||||||
|
session = get_journey_session(session_id, user_id)
|
||||||
|
if session is None:
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": "Journey session not found or expired. Please start a new setup.",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
session.history.append({"role": "user", "content": message})
|
||||||
|
|
||||||
|
ai_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=session.system_prompt,
|
||||||
|
history=session.history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
)
|
||||||
|
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
|
||||||
|
prompt_template = _extract_template(ai_reply)
|
||||||
|
done = prompt_template is not None
|
||||||
|
|
||||||
|
if not done:
|
||||||
|
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||||
|
if turns >= _MAX_TURNS:
|
||||||
|
nudge_content = (
|
||||||
|
"[System: You have enough information. Please generate the final "
|
||||||
|
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
||||||
|
)
|
||||||
|
session.history.append({"role": "user", "content": nudge_content})
|
||||||
|
|
||||||
|
nudge_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=session.system_prompt,
|
||||||
|
history=session.history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
)
|
||||||
|
session.history.append({"role": "assistant", "content": nudge_reply})
|
||||||
|
|
||||||
|
prompt_template = _extract_template(nudge_reply)
|
||||||
|
if prompt_template is not None:
|
||||||
|
done = True
|
||||||
|
ai_reply = nudge_reply
|
||||||
|
|
||||||
|
display_message = ai_reply
|
||||||
|
if done:
|
||||||
|
display_message = (
|
||||||
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
|
if _TEMPLATE_START in ai_reply
|
||||||
|
else "Here is your agent configuration. You can save it or continue refining."
|
||||||
|
)
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
logger.info("journey: session %s completed for user %s", session_id, user_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": display_message,
|
||||||
|
"done": done,
|
||||||
|
"prompt_template": prompt_template,
|
||||||
|
}
|
||||||
76
services/batch-agent/app/llm.py
Normal file
76
services/batch-agent/app/llm.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||||
|
|
||||||
|
Identical to services/chat/app/llm.py. Uses shared.config.settings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain_litellm import ChatLiteLLM
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
litellm.drop_params = True
|
||||||
|
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
||||||
|
category=UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
|
if model.startswith("anthropic/"):
|
||||||
|
return settings.ANTHROPIC_API_KEY or None
|
||||||
|
if model.startswith("gemini/") or model.startswith("google/"):
|
||||||
|
return settings.GOOGLE_API_KEY or None
|
||||||
|
if model.startswith("cerebras/"):
|
||||||
|
return settings.CEREBRAS_API_KEY or None
|
||||||
|
if model.startswith("github_copilot/"):
|
||||||
|
return None
|
||||||
|
return settings.OPENAI_API_KEY or None
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm(
|
||||||
|
*,
|
||||||
|
model: str | None = None,
|
||||||
|
temperature: float = 0,
|
||||||
|
) -> ChatOpenAI | ChatLiteLLM:
|
||||||
|
model = model or settings.LLM_MODEL
|
||||||
|
|
||||||
|
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
||||||
|
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
||||||
|
|
||||||
|
if "/" in model:
|
||||||
|
return ChatLiteLLM(model=model, temperature=temperature)
|
||||||
|
|
||||||
|
return ChatOpenAI(
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
api_key=_api_key_for_model(model),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_router_llm(
|
||||||
|
*,
|
||||||
|
temperature: float = 0,
|
||||||
|
) -> ChatOpenAI | ChatLiteLLM:
|
||||||
|
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
|
||||||
|
|
||||||
|
|
||||||
|
async def embed(text: str) -> list[float]:
|
||||||
|
model = settings.LLM_EMBED_MODEL
|
||||||
|
|
||||||
|
if model.startswith("github_copilot/") or "/" in model:
|
||||||
|
response = await litellm.aembedding(model=model, input=[text])
|
||||||
|
return response.data[0]["embedding"]
|
||||||
|
|
||||||
|
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
||||||
|
response = await client.embeddings.create(model=model, input=text)
|
||||||
|
return response.data[0].embedding
|
||||||
57
services/batch-agent/app/main.py
Normal file
57
services/batch-agent/app/main.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""Batch Agent Service — FastAPI application.
|
||||||
|
|
||||||
|
Owns: agent_runner (local directory + cloud connectors), journey builder,
|
||||||
|
filesystem_agent, integrations (Gmail, MS Graph).
|
||||||
|
|
||||||
|
Communicates with WS Gateway via Redis:
|
||||||
|
- Subscribes to batch:request:{user_id} (journey_start, journey_message)
|
||||||
|
- Publishes to ws:out:{user_id} (journey replies + tool calls)
|
||||||
|
- BRPOP on tool:result:{call_id} (tool-call round-trip, 30s timeout)
|
||||||
|
- SET+EX on journey:{user_id} (journey session state, TTL 1800s)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.redis_consumer import start_consumer
|
||||||
|
from app.routes import router
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
|
logger.info("batch-agent: starting Redis consumer")
|
||||||
|
task = asyncio.create_task(start_consumer())
|
||||||
|
yield
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
logger.info("batch-agent: Redis consumer stopped")
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(title="Adiuva Batch Agent Service", lifespan=lifespan)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_methods=["GET", "POST"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health() -> dict[str, str]:
|
||||||
|
return {"status": "ok", "service": "batch-agent"}
|
||||||
141
services/batch-agent/app/redis_consumer.py
Normal file
141
services/batch-agent/app/redis_consumer.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""Redis consumer for the Batch Agent Service.
|
||||||
|
|
||||||
|
Subscribes to batch:request:* (pattern) and dispatches:
|
||||||
|
- journey_start → handle_journey_start
|
||||||
|
- journey_message → handle_journey_message
|
||||||
|
- agent_trigger → run_local_agent / run_cloud_agent
|
||||||
|
|
||||||
|
Results are published back to ws:out:{user_id} via Redis.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from shared.redis import redis_client, batch_request_channel, ws_out_channel
|
||||||
|
|
||||||
|
from app.ws_context import set_current_user, clear_current_user
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def _publish_to_user(user_id: str, payload: dict[str, Any]) -> None:
|
||||||
|
"""Publish a frame to the user's WS outbound channel."""
|
||||||
|
channel = ws_out_channel(user_id)
|
||||||
|
await redis_client.publish(channel, json.dumps(payload))
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_journey_start(user_id: str, data: dict[str, Any]) -> None:
|
||||||
|
"""Handle a journey_start request from WS Gateway."""
|
||||||
|
from app.journey import handle_journey_start
|
||||||
|
|
||||||
|
set_current_user(user_id)
|
||||||
|
try:
|
||||||
|
reply = await handle_journey_start(user_id, data)
|
||||||
|
await _publish_to_user(user_id, reply)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("batch-agent: journey_start failed user=%s: %s", user_id, exc)
|
||||||
|
await _publish_to_user(user_id, {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": data.get("session_id", ""),
|
||||||
|
"message": f"Journey setup failed: {exc}",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_journey_message(user_id: str, data: dict[str, Any]) -> None:
|
||||||
|
"""Handle a journey_message from WS Gateway."""
|
||||||
|
from app.journey import handle_journey_message
|
||||||
|
|
||||||
|
set_current_user(user_id)
|
||||||
|
try:
|
||||||
|
reply = await handle_journey_message(user_id, data)
|
||||||
|
await _publish_to_user(user_id, reply)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("batch-agent: journey_message failed user=%s: %s", user_id, exc)
|
||||||
|
await _publish_to_user(user_id, {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": data.get("session_id", ""),
|
||||||
|
"message": f"Journey processing failed: {exc}",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_agent_trigger(user_id: str, data: dict[str, Any]) -> None:
|
||||||
|
"""Handle an agent_trigger request from the REST route (forwarded via Redis)."""
|
||||||
|
from app.agent_runner import run_local_agent
|
||||||
|
|
||||||
|
set_current_user(user_id)
|
||||||
|
try:
|
||||||
|
await run_local_agent(user_id, data)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("batch-agent: agent_trigger failed user=%s: %s", user_id, exc)
|
||||||
|
await _publish_to_user(user_id, {
|
||||||
|
"type": "run_complete",
|
||||||
|
"status": "error",
|
||||||
|
"run_context": data.get("run_context", {}),
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
|
||||||
|
async def _dispatch(user_id: str, message_data: dict[str, Any]) -> None:
|
||||||
|
"""Route a batch request to the correct handler."""
|
||||||
|
msg_type = message_data.get("type", "")
|
||||||
|
|
||||||
|
if msg_type == "journey_start":
|
||||||
|
await _handle_journey_start(user_id, message_data)
|
||||||
|
elif msg_type == "journey_message":
|
||||||
|
await _handle_journey_message(user_id, message_data)
|
||||||
|
elif msg_type == "agent_trigger":
|
||||||
|
await _handle_agent_trigger(user_id, message_data)
|
||||||
|
else:
|
||||||
|
logger.warning("batch-agent: unknown message type %r from user=%s", msg_type, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def start_consumer() -> None:
|
||||||
|
"""Subscribe to batch:request:* and dispatch incoming frames."""
|
||||||
|
pubsub = redis_client.pubsub()
|
||||||
|
await pubsub.psubscribe("batch:request:*")
|
||||||
|
logger.info("batch-agent: subscribed to batch:request:*")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for message in pubsub.listen():
|
||||||
|
if message["type"] != "pmessage":
|
||||||
|
continue
|
||||||
|
|
||||||
|
channel: str = message["channel"]
|
||||||
|
if isinstance(channel, bytes):
|
||||||
|
channel = channel.decode()
|
||||||
|
|
||||||
|
# Extract user_id from channel: batch:request:{user_id}
|
||||||
|
parts = channel.split(":", 2)
|
||||||
|
if len(parts) < 3:
|
||||||
|
continue
|
||||||
|
user_id = parts[2]
|
||||||
|
|
||||||
|
raw = message["data"]
|
||||||
|
if isinstance(raw, bytes):
|
||||||
|
raw = raw.decode()
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("batch-agent: invalid JSON on channel %s", channel)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Dispatch in a separate task to avoid blocking the consumer
|
||||||
|
asyncio.create_task(_dispatch(user_id, data))
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("batch-agent: consumer shutting down")
|
||||||
|
finally:
|
||||||
|
await pubsub.punsubscribe("batch:request:*")
|
||||||
208
services/batch-agent/app/routes.py
Normal file
208
services/batch-agent/app/routes.py
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
"""Agent REST routes — catalog, billing checks, trigger.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: uses shared.db, shared.models, shared.schemas.
|
||||||
|
Agent trigger dispatches via Redis to the consumer instead of spawning
|
||||||
|
an in-process background task.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Header, HTTPException, status
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from shared.db import async_session
|
||||||
|
from shared.models import AgentRunLog
|
||||||
|
from shared.redis import redis_client, batch_request_channel
|
||||||
|
|
||||||
|
from app.agent_runner import is_agent_running
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/agents", tags=["agents"])
|
||||||
|
|
||||||
|
# ── Tier feature limits ───────────────────────────────────────────────
|
||||||
|
# Mirrors app/billing/tier_manager.py FEATURES dict.
|
||||||
|
FEATURES: dict[str, dict] = {
|
||||||
|
"free": {"batch_active": 1, "batch_runs_per_day": 3},
|
||||||
|
"pro": {"batch_active": 5, "batch_runs_per_day": 20},
|
||||||
|
"power": {"batch_active": 20, "batch_runs_per_day": 100},
|
||||||
|
"team": {"batch_active": -1, "batch_runs_per_day": -1},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _dt_ms(dt: datetime) -> int:
|
||||||
|
return int(dt.timestamp() * 1000)
|
||||||
|
|
||||||
|
|
||||||
|
def _dt_ms_opt(dt: datetime | None) -> int | None:
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
|
def _to_data_types(values: list[str]) -> list[str]:
|
||||||
|
normalize = {
|
||||||
|
"task": "tasks", "tasks": "tasks",
|
||||||
|
"note": "notes", "notes": "notes",
|
||||||
|
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
||||||
|
"project": "projects", "projects": "projects",
|
||||||
|
}
|
||||||
|
seen: set[str] = set()
|
||||||
|
result: list[str] = []
|
||||||
|
for v in values:
|
||||||
|
mapped = normalize.get(v)
|
||||||
|
if mapped and mapped not in seen:
|
||||||
|
seen.add(mapped)
|
||||||
|
result.append(mapped)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
||||||
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||||
|
if limit != -1 and current_count >= limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||||
|
)
|
||||||
|
return limit
|
||||||
|
|
||||||
|
|
||||||
|
async def _enforce_run_frequency(tier: str, user_id: str) -> None:
|
||||||
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
||||||
|
if limit == -1:
|
||||||
|
return
|
||||||
|
today_start = datetime.now(timezone.utc).replace(
|
||||||
|
hour=0, minute=0, second=0, microsecond=0
|
||||||
|
)
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(
|
||||||
|
select(func.count(AgentRunLog.id)).where(
|
||||||
|
AgentRunLog.user_id == user_id,
|
||||||
|
AgentRunLog.started_at >= today_start,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
runs_today: int = result.scalar_one()
|
||||||
|
|
||||||
|
if runs_today >= limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Daily batch run limit ({limit}) reached for your tier.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Catalog ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/catalog")
|
||||||
|
async def get_agent_catalog(
|
||||||
|
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||||
|
) -> list[dict]:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"type": "local_directory",
|
||||||
|
"name": "Local Directory Monitor",
|
||||||
|
"description": "Watches local directories, extracts data from files using AI",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "gmail",
|
||||||
|
"name": "Gmail Connector",
|
||||||
|
"description": "Scans Gmail inbox, extracts tasks/notes from emails",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "teams",
|
||||||
|
"name": "Microsoft Teams Connector",
|
||||||
|
"description": "Monitors Teams messages, extracts action items",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "outlook",
|
||||||
|
"name": "Outlook Connector",
|
||||||
|
"description": "Scans Outlook inbox, extracts tasks/notes",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Can-create check ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/can-create")
|
||||||
|
async def can_create_agent(
|
||||||
|
body: dict,
|
||||||
|
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||||
|
x_user_tier: str = Header("free", alias="X-User-Tier"),
|
||||||
|
) -> dict:
|
||||||
|
active_agents = body.get("active_agents", 0)
|
||||||
|
limit: int = FEATURES.get(x_user_tier, FEATURES["free"])["batch_active"]
|
||||||
|
allowed = limit == -1 or active_agents < limit
|
||||||
|
return {
|
||||||
|
"allowed": allowed,
|
||||||
|
"tier": x_user_tier,
|
||||||
|
"active_agents": active_agents,
|
||||||
|
"limit": limit,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Trigger ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/trigger", status_code=status.HTTP_202_ACCEPTED)
|
||||||
|
async def trigger_agent_run(
|
||||||
|
body: dict,
|
||||||
|
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||||
|
x_user_tier: str = Header("free", alias="X-User-Tier"),
|
||||||
|
) -> dict:
|
||||||
|
"""Trigger a local agent run — creates run log and dispatches via Redis."""
|
||||||
|
active_agents = body.get("active_agents", 0)
|
||||||
|
_enforce_agent_limit(x_user_tier, active_agents)
|
||||||
|
await _enforce_run_frequency(x_user_tier, x_user_id)
|
||||||
|
|
||||||
|
stable_agent_id = body.get("agent_id") or str(uuid.uuid4())
|
||||||
|
|
||||||
|
if is_agent_running(stable_agent_id):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="Agent is already running.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create run log in DB
|
||||||
|
async with async_session() as db:
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
agent_id=stable_agent_id,
|
||||||
|
agent_type="local",
|
||||||
|
user_id=x_user_id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
db.add(run_log)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(run_log)
|
||||||
|
run_log_id = run_log.id
|
||||||
|
|
||||||
|
run_context = {
|
||||||
|
"type": "agent_batch",
|
||||||
|
"run_id": run_log_id,
|
||||||
|
"agent_id": stable_agent_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Dispatch to the Redis consumer for processing
|
||||||
|
trigger_data = {
|
||||||
|
"type": "agent_trigger",
|
||||||
|
"directory": body.get("directory", ""),
|
||||||
|
"directory_paths": [body.get("directory", "")] if body.get("directory") else [],
|
||||||
|
"data_types": _to_data_types(body.get("what_to_extract", [])),
|
||||||
|
"file_extensions": body.get("file_extensions", []),
|
||||||
|
"prompt_template": body.get("custom_agent_prompt", ""),
|
||||||
|
"device_id": body.get("device_id", ""),
|
||||||
|
"run_context": run_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
channel = batch_request_channel(x_user_id)
|
||||||
|
await redis_client.publish(channel, json.dumps(trigger_data))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": run_log_id,
|
||||||
|
"agent_id": stable_agent_id,
|
||||||
|
"agent_type": "local",
|
||||||
|
"status": "running",
|
||||||
|
"items_processed": 0,
|
||||||
|
"items_created": 0,
|
||||||
|
"errors": [],
|
||||||
|
"started_at": _dt_ms(run_log.started_at),
|
||||||
|
"completed_at": None,
|
||||||
|
}
|
||||||
135
services/batch-agent/app/ws_context.py
Normal file
135
services/batch-agent/app/ws_context.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""WebSocket context for Batch Agent Service — Redis-based tool call round-trip.
|
||||||
|
|
||||||
|
Same pattern as services/chat/app/ws_context.py: publishes tool_call frames
|
||||||
|
to Redis ws:out:{user_id} and awaits BRPOP on tool:result:{call_id}.
|
||||||
|
|
||||||
|
Additionally provides set_client_executor / clear_client_executor stubs
|
||||||
|
for backward compatibility with the agent_runner code (which originally
|
||||||
|
used a DeviceConnectionManager callback). In the microservice world these
|
||||||
|
are no-ops — execute_on_client() always uses the Redis path.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Any, Callable, Coroutine
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from shared.redis import redis_client, tool_result_key, ws_out_channel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_TOOL_CALL_TIMEOUT = 30 # seconds — BRPOP timeout
|
||||||
|
|
||||||
|
# Per-request user_id context var (set before agent run)
|
||||||
|
_current_user_id: ContextVar[str | None] = ContextVar("_current_user_id", default=None)
|
||||||
|
|
||||||
|
# Optional collector for debug / logging
|
||||||
|
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||||
|
"_tool_result_collector", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_current_user(user_id: str) -> None:
|
||||||
|
_current_user_id.set(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_current_user() -> None:
|
||||||
|
_current_user_id.set(None)
|
||||||
|
|
||||||
|
|
||||||
|
def set_tool_result_collector(lst: list[dict]) -> None:
|
||||||
|
_tool_result_collector.set(lst)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_tool_result_collector() -> None:
|
||||||
|
_tool_result_collector.set(None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Compatibility shims ──────────────────────────────────────────────────
|
||||||
|
# agent_runner.py originally called set_client_executor / clear_client_executor
|
||||||
|
# with a DeviceConnectionManager callback. In the microservice world the
|
||||||
|
# Redis-based execute_on_client replaces this, so these are no-ops that
|
||||||
|
# keep the agent_runner code unchanged.
|
||||||
|
|
||||||
|
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]] | None) -> None:
|
||||||
|
"""No-op — kept for agent_runner compatibility."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def clear_client_executor() -> None:
|
||||||
|
"""No-op — kept for agent_runner compatibility."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_on_client(
|
||||||
|
action: str,
|
||||||
|
table: str | None = None,
|
||||||
|
data: dict[str, Any] | None = None,
|
||||||
|
filters: dict[str, Any] | None = None,
|
||||||
|
vector: list[float] | None = None,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Send a tool_call to Electron via Redis and await the result.
|
||||||
|
|
||||||
|
1. Build tool_call payload
|
||||||
|
2. Publish to ws:out:{user_id} (WS Gateway forwards to Electron)
|
||||||
|
3. BRPOP on tool:result:{call_id} (WS Gateway pushes when Electron replies)
|
||||||
|
4. Return result dict
|
||||||
|
|
||||||
|
Raises RuntimeError if no user_id is set or if the call times out.
|
||||||
|
"""
|
||||||
|
user_id = _current_user_id.get()
|
||||||
|
if not user_id:
|
||||||
|
raise RuntimeError(
|
||||||
|
"execute_on_client() called without a user_id — "
|
||||||
|
"set_current_user() must be called first."
|
||||||
|
)
|
||||||
|
|
||||||
|
call_id = str(uuid4())
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"type": "tool_call",
|
||||||
|
"id": call_id,
|
||||||
|
"action": action,
|
||||||
|
}
|
||||||
|
if table is not None:
|
||||||
|
payload["table"] = table
|
||||||
|
if data is not None:
|
||||||
|
payload["data"] = data
|
||||||
|
if filters is not None:
|
||||||
|
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
|
||||||
|
if vector is not None:
|
||||||
|
payload["vector"] = vector
|
||||||
|
if limit is not None:
|
||||||
|
payload["limit"] = limit
|
||||||
|
|
||||||
|
# Publish tool_call to WS Gateway → Electron
|
||||||
|
channel = ws_out_channel(user_id)
|
||||||
|
await redis_client.publish(channel, json.dumps(payload))
|
||||||
|
|
||||||
|
# Wait for Electron's tool_result
|
||||||
|
result_key = tool_result_key(call_id)
|
||||||
|
response = await redis_client.brpop(result_key, timeout=_TOOL_CALL_TIMEOUT)
|
||||||
|
|
||||||
|
if response is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Tool call {call_id} timed out after {_TOOL_CALL_TIMEOUT}s — "
|
||||||
|
f"device may be offline or unresponsive."
|
||||||
|
)
|
||||||
|
|
||||||
|
# response is (key, value) tuple
|
||||||
|
_, raw = response
|
||||||
|
result = json.loads(raw)
|
||||||
|
|
||||||
|
# Collect for debug if requested
|
||||||
|
collector = _tool_result_collector.get(None)
|
||||||
|
if collector is not None:
|
||||||
|
collector.append({
|
||||||
|
"action": action,
|
||||||
|
"table": table,
|
||||||
|
"data": result,
|
||||||
|
})
|
||||||
|
|
||||||
|
return result
|
||||||
20
services/batch-agent/requirements.txt
Normal file
20
services/batch-agent/requirements.txt
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
fastapi>=0.115.0
|
||||||
|
uvicorn[standard]>=0.34.0
|
||||||
|
gunicorn>=22.0.0
|
||||||
|
pydantic>=2.10.0
|
||||||
|
pydantic-settings>=2.7.0
|
||||||
|
sqlalchemy>=2.0.0
|
||||||
|
asyncpg>=0.30.0
|
||||||
|
redis>=5.0.0
|
||||||
|
cryptography>=42.0.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
langchain-core>=0.3.0
|
||||||
|
langchain-openai>=0.3.0
|
||||||
|
langchain-litellm>=0.3.0
|
||||||
|
litellm>=1.50.0
|
||||||
|
openai>=1.50.0
|
||||||
|
httpx>=0.27.0
|
||||||
|
croniter>=2.0.0
|
||||||
|
google-api-python-client>=2.130.0
|
||||||
|
google-auth>=2.30.0
|
||||||
|
msal>=1.28.0
|
||||||
15
services/billing/README.md
Normal file
15
services/billing/README.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# Billing Service
|
||||||
|
|
||||||
|
Owns: Stripe integration, tier management, subscription CRUD.
|
||||||
|
|
||||||
|
## Tables owned (write)
|
||||||
|
- `subscriptions`
|
||||||
|
|
||||||
|
## Endpoints
|
||||||
|
- `POST /billing/checkout`
|
||||||
|
- `POST /billing/webhook` (Stripe, no JWT auth)
|
||||||
|
- `GET /billing/subscription`
|
||||||
|
- `DELETE /billing/subscription`
|
||||||
|
|
||||||
|
## Redis channels
|
||||||
|
- Publish: `tier:changed:{user_id}` on tier change
|
||||||
0
services/billing/app/__init__.py
Normal file
0
services/billing/app/__init__.py
Normal file
36
services/chat/Dockerfile
Normal file
36
services/chat/Dockerfile
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# ── builder ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS builder
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
COPY services/chat/requirements.txt ./requirements.txt
|
||||||
|
RUN pip install --upgrade pip && \
|
||||||
|
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||||
|
|
||||||
|
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS runtime
|
||||||
|
|
||||||
|
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY --from=builder /install /usr/local
|
||||||
|
|
||||||
|
# Shared module
|
||||||
|
COPY shared/ shared/
|
||||||
|
|
||||||
|
# Service source
|
||||||
|
COPY services/chat/app/ app/
|
||||||
|
|
||||||
|
RUN chown -R appuser:appgroup /app
|
||||||
|
|
||||||
|
USER appuser
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
# Chat service is CPU-bound (LLM calls) — use multiple workers
|
||||||
|
CMD ["gunicorn", "app.main:app", \
|
||||||
|
"-k", "uvicorn.workers.UvicornWorker", \
|
||||||
|
"--bind", "0.0.0.0:8000", \
|
||||||
|
"--workers", "2", \
|
||||||
|
"--timeout", "120"]
|
||||||
21
services/chat/README.md
Normal file
21
services/chat/README.md
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# Chat Service
|
||||||
|
|
||||||
|
Owns: deep_agent (home + floating chat), memory middleware, domain agents
|
||||||
|
(task, note, project, timeline), LLM orchestration.
|
||||||
|
|
||||||
|
## Tables owned
|
||||||
|
- `memory_core`
|
||||||
|
- `memory_associative`
|
||||||
|
- `memory_episodic`
|
||||||
|
- `memory_proactive`
|
||||||
|
|
||||||
|
## Tables read (cross-service)
|
||||||
|
- `users` (for encryption_key — memory decryption)
|
||||||
|
|
||||||
|
## Endpoints
|
||||||
|
- `POST /chat` (REST fallback)
|
||||||
|
|
||||||
|
## Redis channels
|
||||||
|
- Subscribe: `chat:request:{user_id}`
|
||||||
|
- Publish: `ws:out:{user_id}` (stream frames + tool calls)
|
||||||
|
- BRPOP: `tool:result:{call_id}` (30s timeout)
|
||||||
0
services/chat/app/__init__.py
Normal file
0
services/chat/app/__init__.py
Normal file
1
services/chat/app/agents/__init__.py
Normal file
1
services/chat/app/agents/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Chat Service domain agents."""
|
||||||
142
services/chat/app/agents/note_agent.py
Normal file
142
services/chat/app/agents/note_agent.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
"""Note agent — Markdown note management (list, get, create, update, delete).
|
||||||
|
|
||||||
|
Adapted for Chat Service: import from app.ws_context and app.llm.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.llm import embed
|
||||||
|
from app.ws_context import execute_on_client
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
NOTE_SYSTEM_PROMPT = (
|
||||||
|
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
||||||
|
"and delete Markdown notes in their workspace.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - content is always Markdown; preserve formatting when updating\n"
|
||||||
|
" - project_id is optional; link a note to a project when mentioned\n"
|
||||||
|
" - When updating, call get_note first if you need to read existing content\n"
|
||||||
|
" before appending or replacing sections\n"
|
||||||
|
" - list_notes without project_id returns all notes; scope with project_id\n"
|
||||||
|
" when the user is working within a specific project\n"
|
||||||
|
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
|
||||||
|
" - Do not fabricate note content — reflect what the user provides or what\n"
|
||||||
|
" is already in the note (retrieved via get_note)."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_notes(project_id: str = "") -> str:
|
||||||
|
"""List notes, optionally scoped to a project by project_id."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="notes",
|
||||||
|
filters={"projectId": normalized_project_id or None},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No notes found."
|
||||||
|
lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
|
||||||
|
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_note(note_id: str) -> str:
|
||||||
|
"""Fetch a single note by its UUID to read its full Markdown content."""
|
||||||
|
result = await execute_on_client(action="get", table="notes", data={"id": note_id})
|
||||||
|
row = result.get("row")
|
||||||
|
if not row:
|
||||||
|
return f"Note {note_id} not found."
|
||||||
|
return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def create_note(
|
||||||
|
title: str,
|
||||||
|
content: str,
|
||||||
|
project_id: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Create a new note.
|
||||||
|
title: note heading (required)
|
||||||
|
content: Markdown body text (required)
|
||||||
|
project_id: optional UUID linking this note to a project
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="notes",
|
||||||
|
data={
|
||||||
|
"title": title,
|
||||||
|
"content": content,
|
||||||
|
"projectId": project_id or None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
# Index the note content in the vector store.
|
||||||
|
vector = await embed(content)
|
||||||
|
await execute_on_client(
|
||||||
|
action="vector_upsert",
|
||||||
|
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
|
||||||
|
vector=vector,
|
||||||
|
)
|
||||||
|
return f"Note created: '{row['title']}' (id: {row['id']})."
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_note(
|
||||||
|
note_id: str,
|
||||||
|
title: str = "",
|
||||||
|
content: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Update an existing note. Only pass fields that should change.
|
||||||
|
note_id: UUID of the note (required)
|
||||||
|
If you need to preserve existing content, call get_note first.
|
||||||
|
"""
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if title:
|
||||||
|
updates["title"] = title
|
||||||
|
if content:
|
||||||
|
updates["content"] = content
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="update",
|
||||||
|
table="notes",
|
||||||
|
data={"id": note_id, "updates": updates},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
# Re-index if content changed.
|
||||||
|
if content:
|
||||||
|
vector = await embed(content)
|
||||||
|
await execute_on_client(
|
||||||
|
action="vector_upsert",
|
||||||
|
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
|
||||||
|
vector=vector,
|
||||||
|
)
|
||||||
|
return f"Note updated: '{row['title']}' (id: {row['id']})."
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_note(note_id: str) -> str:
|
||||||
|
"""Delete a note permanently by its UUID."""
|
||||||
|
await execute_on_client(action="delete", table="notes", data={"id": note_id})
|
||||||
|
return f"Note {note_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
NOTE_TOOLS: list[Any] = [
|
||||||
|
list_notes,
|
||||||
|
get_note,
|
||||||
|
create_note,
|
||||||
|
update_note,
|
||||||
|
delete_note,
|
||||||
|
]
|
||||||
146
services/chat/app/agents/project_agent.py
Normal file
146
services/chat/app/agents/project_agent.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""Project agent — full lifecycle management (list, get, create, update, archive, delete).
|
||||||
|
|
||||||
|
Adapted for Chat Service: import from app.ws_context instead of app.core.ws_context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.ws_context import execute_on_client
|
||||||
|
|
||||||
|
PROJECT_SYSTEM_PROMPT = (
|
||||||
|
"You are a project management assistant. You help users create, find,\n"
|
||||||
|
"update, and archive projects in their workspace.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - status must be one of: active, archived\n"
|
||||||
|
" - client_id is optional; link to a client only when explicitly mentioned\n"
|
||||||
|
" - ai_summary is populated only when the user asks for a project summary;\n"
|
||||||
|
" derive it from context data — do not fabricate content\n"
|
||||||
|
" - Use list_projects for scoped queries; list_all_projects only when the\n"
|
||||||
|
" user wants a complete cross-client view including archived projects\n"
|
||||||
|
" - get_project requires a project UUID; resolve the ID first by calling\n"
|
||||||
|
" list_projects if you only have a project name\n"
|
||||||
|
" - Prefer archiving (update_project status=archived) over deletion;\n"
|
||||||
|
" only call delete_project when the user explicitly confirms deletion."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_projects(
|
||||||
|
client_id: str = "",
|
||||||
|
include_archived: int = 0,
|
||||||
|
) -> str:
|
||||||
|
"""List projects, optionally filtered by client_id.
|
||||||
|
include_archived: 1 to include archived projects, 0 for active only (default).
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="projects",
|
||||||
|
filters={
|
||||||
|
"clientId": client_id or None,
|
||||||
|
"includeArchived": bool(include_archived),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No projects found."
|
||||||
|
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
|
||||||
|
return f"Found {len(rows)} project(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_all_projects() -> str:
|
||||||
|
"""List every project regardless of client or status.
|
||||||
|
Use only when the user wants a complete cross-client overview.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(action="select", table="projects")
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No projects found."
|
||||||
|
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
|
||||||
|
return f"All projects ({len(rows)}):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_project(project_id: str) -> str:
|
||||||
|
"""Fetch a single project by its UUID."""
|
||||||
|
result = await execute_on_client(action="get", table="projects", data={"id": project_id})
|
||||||
|
row = result.get("row")
|
||||||
|
if not row:
|
||||||
|
return f"Project {project_id} not found."
|
||||||
|
return (
|
||||||
|
f"Project: '{row['name']}' (id: {row['id']}, status: {row['status']}, "
|
||||||
|
f"clientId: {row.get('clientId', 'none')})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def create_project(
|
||||||
|
name: str,
|
||||||
|
client_id: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Create a new project.
|
||||||
|
name: human-readable project name (required)
|
||||||
|
client_id: optional UUID of the owning client
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="projects",
|
||||||
|
data={"name": name, "clientId": client_id or None},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Project created: '{row['name']}' (id: {row['id']})"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_project(
|
||||||
|
project_id: str,
|
||||||
|
name: str = "",
|
||||||
|
client_id: str = "",
|
||||||
|
status: str = "",
|
||||||
|
ai_summary: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Update a project. Only pass fields that should change.
|
||||||
|
project_id: UUID of the project (required)
|
||||||
|
status: active | archived
|
||||||
|
ai_summary: AI-generated summary text (populate only when explicitly requested)
|
||||||
|
"""
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if name:
|
||||||
|
updates["name"] = name
|
||||||
|
if client_id:
|
||||||
|
updates["clientId"] = client_id
|
||||||
|
if status:
|
||||||
|
updates["status"] = status
|
||||||
|
if ai_summary:
|
||||||
|
updates["aiSummary"] = ai_summary
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="update",
|
||||||
|
table="projects",
|
||||||
|
data={"id": project_id, "updates": updates},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Project updated: '{row['name']}' (id: {row['id']}, status: {row['status']})"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_project(project_id: str) -> str:
|
||||||
|
"""Permanently delete a project and orphan its tasks.
|
||||||
|
IMPORTANT: prefer update_project(status='archived') unless the user
|
||||||
|
has explicitly confirmed they want permanent deletion.
|
||||||
|
"""
|
||||||
|
await execute_on_client(action="delete", table="projects", data={"id": project_id})
|
||||||
|
return f"Project {project_id} permanently deleted."
|
||||||
|
|
||||||
|
|
||||||
|
PROJECT_TOOLS: list[Any] = [
|
||||||
|
list_projects,
|
||||||
|
list_all_projects,
|
||||||
|
get_project,
|
||||||
|
create_project,
|
||||||
|
update_project,
|
||||||
|
delete_project,
|
||||||
|
]
|
||||||
240
services/chat/app/agents/task_agent.py
Normal file
240
services/chat/app/agents/task_agent.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
"""Task agent — full CRUD for tasks and task comments.
|
||||||
|
|
||||||
|
Adapted for Chat Service: import from app.ws_context instead of app.core.ws_context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.ws_context import execute_on_client
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
TASK_SYSTEM_PROMPT = (
|
||||||
|
"You are a task management assistant for a project workspace.\n"
|
||||||
|
"You create, update, list, and track tasks and their comments.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - status must be one of: todo, in_progress, done\n"
|
||||||
|
" - priority must be one of: high, medium, low\n"
|
||||||
|
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
|
||||||
|
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
|
||||||
|
" - project_id is optional; link to a project when the user mentions one\n"
|
||||||
|
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
||||||
|
" did not explicitly request; 0 otherwise\n"
|
||||||
|
" - is_ai_suggested: 1 only when proactively proposing a task the user did not explicitly request; 0 otherwise\n"
|
||||||
|
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
||||||
|
" - For update_task, use -1 for integer fields you do not want to change\n"
|
||||||
|
" - Always confirm the action in plain, user-friendly language."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Task tools ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_tasks(
|
||||||
|
project_id: str = "",
|
||||||
|
status: str = "",
|
||||||
|
search: str = "",
|
||||||
|
order_by: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
||||||
|
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="tasks",
|
||||||
|
filters={
|
||||||
|
"projectId": normalized_project_id or None,
|
||||||
|
"status": status or None,
|
||||||
|
"search": search or None,
|
||||||
|
"orderBy": order_by or None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No tasks found matching the given filters."
|
||||||
|
lines = [
|
||||||
|
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})"
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def create_task(
|
||||||
|
title: str,
|
||||||
|
description: str = "",
|
||||||
|
status: str = "todo",
|
||||||
|
priority: str = "medium",
|
||||||
|
assignees: str = "[]",
|
||||||
|
due_date: int = 0,
|
||||||
|
project_id: str = "",
|
||||||
|
is_ai_suggested: int = 0,
|
||||||
|
) -> str:
|
||||||
|
"""Create a new task.
|
||||||
|
title: task title (required)
|
||||||
|
description: optional details
|
||||||
|
status: todo | in_progress | done (default: todo)
|
||||||
|
priority: high | medium | low (default: medium)
|
||||||
|
assignees: JSON-encoded array of assignee names, e.g. '["Alice"]'
|
||||||
|
due_date: Unix timestamp in milliseconds; 0 means no due date
|
||||||
|
project_id: optional UUID of the parent project
|
||||||
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="tasks",
|
||||||
|
data={
|
||||||
|
"title": title,
|
||||||
|
"description": description or None,
|
||||||
|
"status": status,
|
||||||
|
"priority": priority,
|
||||||
|
"assignee": assignees,
|
||||||
|
"dueDate": due_date or None,
|
||||||
|
"projectId": project_id or None,
|
||||||
|
"isAiSuggested": is_ai_suggested,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return (
|
||||||
|
f"Task created: '{row['title']}' "
|
||||||
|
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_task(
|
||||||
|
task_id: str,
|
||||||
|
title: str = "",
|
||||||
|
description: str = "",
|
||||||
|
status: str = "",
|
||||||
|
priority: str = "",
|
||||||
|
assignees: str = "",
|
||||||
|
due_date: int = -1,
|
||||||
|
project_id: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Update fields on an existing task. Only pass fields you want to change.
|
||||||
|
task_id: the task's UUID (required)
|
||||||
|
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
||||||
|
"""
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if title:
|
||||||
|
updates["title"] = title
|
||||||
|
if description:
|
||||||
|
updates["description"] = description
|
||||||
|
if status:
|
||||||
|
updates["status"] = status
|
||||||
|
if priority:
|
||||||
|
updates["priority"] = priority
|
||||||
|
if assignees:
|
||||||
|
updates["assignee"] = assignees
|
||||||
|
if due_date != -1:
|
||||||
|
updates["dueDate"] = due_date or None
|
||||||
|
if project_id:
|
||||||
|
updates["projectId"] = project_id
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="update",
|
||||||
|
table="tasks",
|
||||||
|
data={"id": task_id, "updates": updates},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_task(task_id: str) -> str:
|
||||||
|
"""Delete a task permanently by its UUID."""
|
||||||
|
await execute_on_client(action="delete", table="tasks", data={"id": task_id})
|
||||||
|
return f"Task {task_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_tasks_due_today() -> str:
|
||||||
|
"""List all tasks whose due date falls on today's date."""
|
||||||
|
now = datetime.now(tz=timezone.utc)
|
||||||
|
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
|
||||||
|
end_ms = start_ms + 86_400_000 - 1 # last ms of today
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="tasks",
|
||||||
|
filters={"dueDateFrom": start_ms, "dueDateTo": end_ms},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No tasks are due today."
|
||||||
|
lines = [
|
||||||
|
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})"
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Task comment tools ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_task_comments(task_id: str) -> str:
|
||||||
|
"""List all comments on a task by its UUID."""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="taskComments",
|
||||||
|
filters={"taskId": task_id},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return f"No comments found for task {task_id}."
|
||||||
|
lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows]
|
||||||
|
return f"Found {len(rows)} comment(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
||||||
|
"""Add a comment to a task.
|
||||||
|
task_id: UUID of the task to comment on
|
||||||
|
author: name or ID of the comment author
|
||||||
|
content: comment text
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="taskComments",
|
||||||
|
data={"taskId": task_id, "author": author, "content": content},
|
||||||
|
)
|
||||||
|
row = result.get("row", {})
|
||||||
|
row_author = row.get("author", author)
|
||||||
|
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
||||||
|
row_comment_id = row.get("id", "unknown")
|
||||||
|
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_task_comment(comment_id: str) -> str:
|
||||||
|
"""Delete a task comment by its UUID."""
|
||||||
|
await execute_on_client(action="delete", table="taskComments", data={"id": comment_id})
|
||||||
|
return f"Comment {comment_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
# ── Agent ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
TASK_TOOLS: list[Any] = [
|
||||||
|
list_tasks,
|
||||||
|
create_task,
|
||||||
|
update_task,
|
||||||
|
delete_task,
|
||||||
|
list_tasks_due_today,
|
||||||
|
list_task_comments,
|
||||||
|
add_task_comment,
|
||||||
|
delete_task_comment,
|
||||||
|
]
|
||||||
117
services/chat/app/agents/timeline_agent.py
Normal file
117
services/chat/app/agents/timeline_agent.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
"""Timeline agent — project milestone management (list, create, update, delete).
|
||||||
|
|
||||||
|
Adapted for Chat Service: import from app.ws_context instead of app.core.ws_context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.ws_context import execute_on_client
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
TIMELINE_SYSTEM_PROMPT = (
|
||||||
|
"You are a project timeline assistant. Timelines are milestone dates that\n"
|
||||||
|
"track progress on a project — they are not calendar events.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
||||||
|
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
|
||||||
|
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
||||||
|
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
||||||
|
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
||||||
|
" - For update_timeline, use -1 for integer fields you do not want to change\n"
|
||||||
|
" - Listing without a project_id returns all timelines across projects\n"
|
||||||
|
" - Always echo the title and formatted date in your confirmation."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_timelines(project_id: str = "") -> str:
|
||||||
|
"""List timelines. Provide project_id to scope to a specific project."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="timelines",
|
||||||
|
filters={"projectId": normalized_project_id or None},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No timelines found."
|
||||||
|
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
|
||||||
|
return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def create_timeline(
|
||||||
|
project_id: str,
|
||||||
|
title: str,
|
||||||
|
date: int,
|
||||||
|
is_ai_suggested: int = 0,
|
||||||
|
) -> str:
|
||||||
|
"""Create a project timeline (milestone).
|
||||||
|
project_id: REQUIRED UUID of the parent project
|
||||||
|
title: descriptive name for the milestone
|
||||||
|
date: Unix timestamp in milliseconds
|
||||||
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="timelines",
|
||||||
|
data={
|
||||||
|
"projectId": project_id,
|
||||||
|
"title": title,
|
||||||
|
"date": date,
|
||||||
|
"isAiSuggested": is_ai_suggested,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_timeline(
|
||||||
|
timeline_id: str,
|
||||||
|
title: str = "",
|
||||||
|
date: int = -1,
|
||||||
|
) -> str:
|
||||||
|
"""Update a timeline. Only pass fields that should change.
|
||||||
|
timeline_id: UUID of the timeline (required)
|
||||||
|
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
||||||
|
"""
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if title:
|
||||||
|
updates["title"] = title
|
||||||
|
if date != -1:
|
||||||
|
updates["date"] = date
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="update",
|
||||||
|
table="timelines",
|
||||||
|
data={"id": timeline_id, "updates": updates},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Timeline updated: '{row['title']}' (id: {row['id']})"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_timeline(timeline_id: str) -> str:
|
||||||
|
"""Delete a timeline permanently by its UUID."""
|
||||||
|
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
|
||||||
|
return f"Timeline {timeline_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
TIMELINE_TOOLS: list[Any] = [
|
||||||
|
list_timelines,
|
||||||
|
create_timeline,
|
||||||
|
update_timeline,
|
||||||
|
delete_timeline,
|
||||||
|
]
|
||||||
883
services/chat/app/deep_agent.py
Normal file
883
services/chat/app/deep_agent.py
Normal file
@@ -0,0 +1,883 @@
|
|||||||
|
"""Single-agent runners for home and floating chat contexts.
|
||||||
|
|
||||||
|
Adapted from app/core/deep_agent.py for the Chat Service.
|
||||||
|
Import paths changed to use local app modules and shared/.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import date
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.agents.note_agent import NOTE_TOOLS
|
||||||
|
from app.agents.project_agent import PROJECT_TOOLS
|
||||||
|
from app.agents.task_agent import TASK_TOOLS
|
||||||
|
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||||
|
from app.llm import get_llm
|
||||||
|
from app.memory_middleware import MemoryMiddleware
|
||||||
|
from app.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
||||||
|
from app import tracing
|
||||||
|
from shared.db import async_session
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FloatingDomainType = Literal["task", "timeline", "project", "node"]
|
||||||
|
FloatingDomainSection = Literal["task", "timeline", "note"]
|
||||||
|
|
||||||
|
_HOME_SINGLE_AGENT_SYSTEM = (
|
||||||
|
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||||
|
"Always use tools for factual data retrieval before answering. "
|
||||||
|
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
||||||
|
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||||
|
"Return markdown and use tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
||||||
|
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>. "
|
||||||
|
"When listing tasks or timelines, each id tag must be on its own line with no prefix/suffix text. "
|
||||||
|
"Never put titles, priorities, or dates on the same line as <task> or <timeline> tags. "
|
||||||
|
"For questions about upcoming timelines (e.g. 'prossimi eventi'), include only future items in the current month unless the user asks a different range. "
|
||||||
|
"For upcoming tasks, after tag lines add a short recommendation based on due date and priority."
|
||||||
|
)
|
||||||
|
|
||||||
|
_FLOATING_SINGLE_AGENT_SYSTEM = (
|
||||||
|
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||||
|
"Stay focused on the floating scope in context.scope and answer concisely. "
|
||||||
|
"Return plain text only. Do not output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed id tag wrappers. "
|
||||||
|
"Always use tools for factual data retrieval before answering. "
|
||||||
|
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
||||||
|
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||||
|
)
|
||||||
|
|
||||||
|
_FLOATING_DOMAIN_CLASSIFIER_SYSTEM = (
|
||||||
|
"You are a strict domain classifier for websocket floating requests. "
|
||||||
|
"Return ONLY a JSON object with keys: type, id, section. "
|
||||||
|
"Allowed type values: task, timeline, project, node. "
|
||||||
|
"Allowed section values: task, timeline, note, or null. "
|
||||||
|
"Rules: infer from user message intent first; do not blindly trust scope.type. "
|
||||||
|
"If user asks tasks/timeline/notes for a project, set type=project and section accordingly. "
|
||||||
|
"If project id is unknown but context.resolved_project_id exists, use it as id. "
|
||||||
|
"If id is unknown, use null. "
|
||||||
|
"No markdown, no prose, JSON only."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _as_text(content: Any) -> str:
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
def _candidate_tokens(message: str) -> list[str]:
|
||||||
|
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
|
||||||
|
return [token for token in tokens if len(token) >= 3]
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_project_id_from_message(message: str) -> str | None:
|
||||||
|
"""Resolve likely project UUID from user message using client project list."""
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(action="select", table="projects")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("deep_agent: project resolve select failed: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not isinstance(rows, list) or not rows:
|
||||||
|
return None
|
||||||
|
|
||||||
|
tokens = _candidate_tokens(message)
|
||||||
|
scored: list[tuple[int, dict[str, Any]]] = []
|
||||||
|
for row in rows:
|
||||||
|
if not isinstance(row, dict):
|
||||||
|
continue
|
||||||
|
name = str(row.get("name", "")).lower()
|
||||||
|
score = sum(1 for token in tokens if token in name)
|
||||||
|
if score > 0:
|
||||||
|
scored.append((score, row))
|
||||||
|
|
||||||
|
if not scored:
|
||||||
|
return None
|
||||||
|
|
||||||
|
scored.sort(key=lambda item: item[0], reverse=True)
|
||||||
|
top_score = scored[0][0]
|
||||||
|
top_rows = [row for score, row in scored if score == top_score]
|
||||||
|
if len(top_rows) != 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
project_id = top_rows[0].get("id")
|
||||||
|
return project_id if isinstance(project_id, str) else None
|
||||||
|
|
||||||
|
|
||||||
|
def _needs_project_resolution(message: str) -> bool:
|
||||||
|
lowered = message.lower()
|
||||||
|
return any(keyword in lowered for keyword in ["project", "progetto", "progetti", "whitelist"])
|
||||||
|
|
||||||
|
|
||||||
|
async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
prepared = dict(context)
|
||||||
|
if _needs_project_resolution(message):
|
||||||
|
resolved_project_id = await _resolve_project_id_from_message(message)
|
||||||
|
if resolved_project_id:
|
||||||
|
prepared["resolved_project_id"] = resolved_project_id
|
||||||
|
logger.info("deep_agent: resolved_project_id=%s", resolved_project_id)
|
||||||
|
return prepared
|
||||||
|
|
||||||
|
|
||||||
|
def _all_tools() -> list[Any]:
|
||||||
|
return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
|
||||||
|
|
||||||
|
|
||||||
|
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
|
||||||
|
debug = context.get("_debug")
|
||||||
|
if isinstance(debug, dict):
|
||||||
|
request_id = debug.get("request_id")
|
||||||
|
if isinstance(request_id, str) and request_id:
|
||||||
|
return request_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
sanitized = dict(context)
|
||||||
|
sanitized.pop("_debug", None)
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
_TAG_LINE_RE = re.compile(r"<(task|timeline)>\[[^\]]+\]</\1>")
|
||||||
|
_TIMELINE_DMY_RE = re.compile(r"(?P<d>\d{2})/(?P<m>\d{2})/(?P<y>\d{4})")
|
||||||
|
|
||||||
|
|
||||||
|
def _is_upcoming_timeline_query(message: str) -> bool:
|
||||||
|
lowered = message.lower()
|
||||||
|
has_upcoming = "prossim" in lowered or "upcoming" in lowered or "next" in lowered
|
||||||
|
has_timeline_topic = any(
|
||||||
|
token in lowered
|
||||||
|
for token in ("event", "evento", "eventi", "timeline", "milestone", "scaden")
|
||||||
|
)
|
||||||
|
return has_upcoming and has_timeline_topic
|
||||||
|
|
||||||
|
|
||||||
|
def _timeline_date_in_current_month_or_future(dmy: str) -> bool:
|
||||||
|
match = _TIMELINE_DMY_RE.search(dmy)
|
||||||
|
if not match:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
parsed = date(
|
||||||
|
int(match.group("y")),
|
||||||
|
int(match.group("m")),
|
||||||
|
int(match.group("d")),
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
return True
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
return parsed >= today and parsed.year == today.year and parsed.month == today.month
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_tagged_list_lines(text: str, message: str) -> str:
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
|
||||||
|
upcoming_timeline_only = _is_upcoming_timeline_query(message)
|
||||||
|
output_lines: list[str] = []
|
||||||
|
|
||||||
|
for line in text.splitlines():
|
||||||
|
matches = list(_TAG_LINE_RE.finditer(line))
|
||||||
|
if not matches:
|
||||||
|
output_lines.append(line)
|
||||||
|
continue
|
||||||
|
|
||||||
|
had_non_tag_text = _TAG_LINE_RE.sub("", line).strip(" -\t0123456789.*:)")
|
||||||
|
if not had_non_tag_text and len(matches) == 1:
|
||||||
|
tag_text = matches[0].group(0)
|
||||||
|
if (
|
||||||
|
upcoming_timeline_only
|
||||||
|
and "<timeline>" in tag_text
|
||||||
|
and not _timeline_date_in_current_month_or_future(line)
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
output_lines.append(tag_text)
|
||||||
|
continue
|
||||||
|
|
||||||
|
for match in matches:
|
||||||
|
tag_text = match.group(0)
|
||||||
|
if (
|
||||||
|
upcoming_timeline_only
|
||||||
|
and "<timeline>" in tag_text
|
||||||
|
and not _timeline_date_in_current_month_or_future(line)
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
output_lines.append(tag_text)
|
||||||
|
|
||||||
|
return "\n".join(output_lines)
|
||||||
|
|
||||||
|
|
||||||
|
_GENERIC_TAG_RE = re.compile(r"</?(task|project|note|timeline|chart)>", re.IGNORECASE)
|
||||||
|
_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]")
|
||||||
|
_FLOATING_EMPTY_FALLBACK = "No results found."
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_floating_markup_fragment(text: str) -> str:
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
cleaned = _GENERIC_TAG_RE.sub("", text)
|
||||||
|
return _BRACKETED_ID_RE.sub("", cleaned)
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_floating_markup(text: str) -> str:
|
||||||
|
"""Ensure floating responses stay plain text with no XML-like tag wrappers."""
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
|
||||||
|
cleaned = _strip_floating_markup_fragment(text)
|
||||||
|
lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()]
|
||||||
|
return "\n".join(line for line in lines if line)
|
||||||
|
|
||||||
|
|
||||||
|
def _fallback_from_raw_floating_text(raw_text: str) -> str:
|
||||||
|
fallback = _strip_floating_markup_fragment(raw_text or "")
|
||||||
|
fallback = re.sub(r"[ \t]{2,}", " ", fallback).strip()
|
||||||
|
return fallback or _FLOATING_EMPTY_FALLBACK
|
||||||
|
|
||||||
|
|
||||||
|
class _FloatingStreamSanitizer:
|
||||||
|
"""Streaming sanitizer that removes floating markup without buffering the full answer."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._pending = ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_safe_boundary(text: str) -> tuple[str, str]:
|
||||||
|
boundary = len(text)
|
||||||
|
|
||||||
|
last_lt = text.rfind("<")
|
||||||
|
if last_lt != -1 and ">" not in text[last_lt:]:
|
||||||
|
boundary = min(boundary, last_lt)
|
||||||
|
|
||||||
|
last_lb = text.rfind("[")
|
||||||
|
if last_lb != -1 and "]" not in text[last_lb:]:
|
||||||
|
boundary = min(boundary, last_lb)
|
||||||
|
|
||||||
|
if boundary == len(text):
|
||||||
|
return text, ""
|
||||||
|
return text[:boundary], text[boundary:]
|
||||||
|
|
||||||
|
def feed(self, chunk: str) -> str:
|
||||||
|
combined = f"{self._pending}{chunk}"
|
||||||
|
safe_text, self._pending = self._split_safe_boundary(combined)
|
||||||
|
return _strip_floating_markup_fragment(safe_text)
|
||||||
|
|
||||||
|
def finalize(self) -> str:
|
||||||
|
tail = re.sub(r"<[^>\n]*$", "", self._pending)
|
||||||
|
tail = re.sub(r"\[[^\]\n]*$", "", tail)
|
||||||
|
self._pending = ""
|
||||||
|
return _strip_floating_markup_fragment(tail)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_memory_label(path_or_label: str) -> str:
|
||||||
|
value = path_or_label.strip()
|
||||||
|
if value.startswith("/memories/"):
|
||||||
|
value = value[len("/memories/"):]
|
||||||
|
value = value.strip("/")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||||
|
@tool
|
||||||
|
async def memory_list_blocks() -> str:
|
||||||
|
"""List all core memory blocks currently stored for the user."""
|
||||||
|
logger.info("deep_agent: memory_list_blocks trace=%s user=%s", trace_id or "-", user_id)
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
blocks = await memory.list_core_blocks(user_id)
|
||||||
|
if not blocks:
|
||||||
|
return "No memory blocks found."
|
||||||
|
lines = [f"- {b['label']}: {b['value']}" for b in blocks]
|
||||||
|
return "Memory blocks:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_get(path_or_label: str) -> str:
|
||||||
|
"""Get one memory block by label or /memories/<label> path."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_get trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
value = await memory.get_core_block(user_id, label)
|
||||||
|
if value is None:
|
||||||
|
return f"Memory block '{label}' not found."
|
||||||
|
return f"Memory block '{label}':\n{value}"
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_create(path_or_label: str, value: str) -> str:
|
||||||
|
"""Create or overwrite a memory block value by label or /memories/<label> path."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_create trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.update_core(user_id, label, value, trace_id=trace_id)
|
||||||
|
return f"Memory block '{label}' saved."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_append(path_or_label: str, content: str) -> str:
|
||||||
|
"""Append content to a memory block, creating it if missing."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_append trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.append_core(user_id, label, content)
|
||||||
|
return f"Memory block '{label}' appended."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_replace(path_or_label: str, old_string: str, new_string: str) -> str:
|
||||||
|
"""Replace one exact string in a memory block."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_replace trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
changed = await memory.replace_core(user_id, label, old_string, new_string)
|
||||||
|
if not changed:
|
||||||
|
return f"No replacement made in '{label}' (old string not found)."
|
||||||
|
return f"Memory block '{label}' updated."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_delete(path_or_label: str) -> str:
|
||||||
|
"""Delete a memory block by label or /memories/<label> path."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_delete trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
deleted = await memory.delete_core(user_id, label)
|
||||||
|
if not deleted:
|
||||||
|
return f"Memory block '{label}' not found."
|
||||||
|
return f"Memory block '{label}' deleted."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def archival_memory_insert(content: str) -> str:
|
||||||
|
"""Insert a long-term archival memory entry."""
|
||||||
|
logger.info("deep_agent: archival_memory_insert trace=%s user=%s", trace_id or "-", user_id)
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.insert_archival(user_id, content, source="assistant")
|
||||||
|
return "Archival memory saved."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def archival_memory_search(query: str, top_k: int = 5) -> str:
|
||||||
|
"""Search long-term archival memory by semantic fallback (keyword currently)."""
|
||||||
|
logger.info("deep_agent: archival_memory_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
results = await memory.search_archival(user_id, query, top_k=top_k)
|
||||||
|
if not results:
|
||||||
|
return "No archival memory results found."
|
||||||
|
lines = [f"- {item}" for item in results]
|
||||||
|
return "Archival memory results:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def conversation_search(query: str, top_k: int = 5) -> str:
|
||||||
|
"""Search recall memory from prior episodic conversation summaries."""
|
||||||
|
logger.info("deep_agent: conversation_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
results = await memory.search_recall(user_id, query, top_k=top_k)
|
||||||
|
if not results:
|
||||||
|
return "No recall memory results found."
|
||||||
|
lines = [f"- {item}" for item in results]
|
||||||
|
return "Recall memory results:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
return [
|
||||||
|
memory_list_blocks,
|
||||||
|
memory_get,
|
||||||
|
memory_create,
|
||||||
|
memory_append,
|
||||||
|
memory_replace,
|
||||||
|
memory_delete,
|
||||||
|
archival_memory_insert,
|
||||||
|
archival_memory_search,
|
||||||
|
conversation_search,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
|
||||||
|
return [*_all_tools(), *_memory_tools(user_id, trace_id)]
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_domain_section(message: str) -> FloatingDomainSection | None:
|
||||||
|
lowered = message.lower()
|
||||||
|
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
|
||||||
|
return "timeline"
|
||||||
|
if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]):
|
||||||
|
return "task"
|
||||||
|
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
|
||||||
|
return "note"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]:
|
||||||
|
type_raw = str(payload.get("type") or "").strip().lower()
|
||||||
|
domain_type: FloatingDomainType = "task"
|
||||||
|
if type_raw in {"task", "timeline", "project", "node"}:
|
||||||
|
domain_type = type_raw
|
||||||
|
|
||||||
|
id_value = payload.get("id")
|
||||||
|
domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None
|
||||||
|
if domain_type == "project" and not domain_id:
|
||||||
|
domain_id = fallback_id
|
||||||
|
|
||||||
|
section_raw = payload.get("section")
|
||||||
|
section: FloatingDomainSection | None = None
|
||||||
|
if isinstance(section_raw, str):
|
||||||
|
section_candidate = section_raw.strip().lower()
|
||||||
|
if section_candidate in {"task", "timeline", "note"}:
|
||||||
|
section = section_candidate
|
||||||
|
|
||||||
|
if domain_type != "project":
|
||||||
|
section = None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": domain_type,
|
||||||
|
"id": domain_id,
|
||||||
|
"section": section,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_json_object(text: str) -> dict[str, Any] | None:
|
||||||
|
raw = text.strip()
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = json.loads(raw)
|
||||||
|
return parsed if isinstance(parsed, dict) else None
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
||||||
|
if not match:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = json.loads(match.group(0))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
return parsed if isinstance(parsed, dict) else None
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]:
|
||||||
|
section = _detect_domain_section(message)
|
||||||
|
scope = context.get("scope") if isinstance(context, dict) else None
|
||||||
|
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
|
||||||
|
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
|
||||||
|
|
||||||
|
if isinstance(scope, dict):
|
||||||
|
scope_type = str(scope.get("type") or "").strip().lower()
|
||||||
|
scope_id = scope.get("id")
|
||||||
|
scope_id_value = scope_id if isinstance(scope_id, str) and scope_id else None
|
||||||
|
|
||||||
|
if scope_type in {"task", "tasks"}:
|
||||||
|
return {"type": "task", "id": scope_id_value, "section": None}
|
||||||
|
if scope_type in {"project", "projects"}:
|
||||||
|
project_scope_id = scope_id_value or project_id
|
||||||
|
return {
|
||||||
|
"type": "project",
|
||||||
|
"id": project_scope_id,
|
||||||
|
"section": section,
|
||||||
|
}
|
||||||
|
if scope_type in {"note", "notes"}:
|
||||||
|
return {
|
||||||
|
"type": "node",
|
||||||
|
"id": scope_id_value,
|
||||||
|
"section": None,
|
||||||
|
}
|
||||||
|
if scope_type in {"timeline", "timelines"}:
|
||||||
|
return {"type": "timeline", "id": scope_id_value, "section": None}
|
||||||
|
|
||||||
|
lowered = message.lower()
|
||||||
|
if any(keyword in lowered for keyword in ["project", "progetto", "client"]) or project_id:
|
||||||
|
return {
|
||||||
|
"type": "project",
|
||||||
|
"id": project_id,
|
||||||
|
"section": section,
|
||||||
|
}
|
||||||
|
if section == "timeline":
|
||||||
|
return {"type": "timeline", "id": None, "section": None}
|
||||||
|
if section == "note":
|
||||||
|
return {"type": "node", "id": None, "section": None}
|
||||||
|
return {"type": "task", "id": None, "section": None}
|
||||||
|
|
||||||
|
|
||||||
|
async def _infer_floating_domain(
|
||||||
|
message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None,
|
||||||
|
) -> dict[str, str | None]:
|
||||||
|
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
|
||||||
|
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
|
||||||
|
|
||||||
|
classifier_context = {
|
||||||
|
"scope": context.get("scope") if isinstance(context.get("scope"), dict) else None,
|
||||||
|
"resolved_project_id": project_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
classifier_prompt = _get_system_prompt(
|
||||||
|
"floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_SYSTEM,
|
||||||
|
)
|
||||||
|
callbacks = _build_callbacks(langfuse_handler)
|
||||||
|
llm = get_llm(callbacks=callbacks)
|
||||||
|
response = await llm.ainvoke(
|
||||||
|
[
|
||||||
|
SystemMessage(content=classifier_prompt),
|
||||||
|
HumanMessage(
|
||||||
|
content=(
|
||||||
|
f"Message:\n{message}\n\n"
|
||||||
|
f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
parsed = _parse_json_object(_as_text(response.content))
|
||||||
|
if parsed is not None:
|
||||||
|
domain = _normalize_domain_payload(parsed, project_id)
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: floating_domain_classified type=%s id=%s section=%s",
|
||||||
|
domain.get("type"),
|
||||||
|
domain.get("id"),
|
||||||
|
domain.get("section"),
|
||||||
|
)
|
||||||
|
return domain
|
||||||
|
logger.warning("deep_agent: floating_domain classifier returned non-json output")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("deep_agent: floating_domain classifier failed: %s", exc)
|
||||||
|
|
||||||
|
return _infer_floating_domain_rule_based(message, context)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_system_prompt(langfuse_name: str, fallback: str) -> str:
|
||||||
|
"""Fetch a managed prompt from Langfuse, falling back to the hardcoded string."""
|
||||||
|
managed = tracing.get_prompt(langfuse_name, fallback=None)
|
||||||
|
return managed if managed is not None else fallback
|
||||||
|
|
||||||
|
|
||||||
|
def _build_callbacks(langfuse_handler: Any | None) -> list[Any] | None:
|
||||||
|
"""Return a callbacks list if a Langfuse handler is available."""
|
||||||
|
if langfuse_handler is None:
|
||||||
|
return None
|
||||||
|
return [langfuse_handler]
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_single_agent(
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
system_prompt: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
max_steps: int = 6,
|
||||||
|
langfuse_handler: Any | None = None,
|
||||||
|
) -> str:
|
||||||
|
trace_id = _trace_id_from_context(context)
|
||||||
|
callbacks = _build_callbacks(langfuse_handler)
|
||||||
|
llm = get_llm(callbacks=callbacks)
|
||||||
|
tools = _all_tools_for_user(user_id, trace_id)
|
||||||
|
model_context = _context_for_model(context)
|
||||||
|
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
messages: list[Any] = [
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(
|
||||||
|
content=(
|
||||||
|
f"User message:\n{message}\n\n"
|
||||||
|
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
tool_calls_count = 0
|
||||||
|
collected: list[dict[str, Any]] = []
|
||||||
|
set_tool_result_collector(collected)
|
||||||
|
try:
|
||||||
|
for _ in range(max_steps):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
final_text = _as_text(response.content)
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
tool_calls_count,
|
||||||
|
len(final_text),
|
||||||
|
)
|
||||||
|
return final_text
|
||||||
|
|
||||||
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
for call in response.tool_calls:
|
||||||
|
tool_calls_count += 1
|
||||||
|
call_id = str(call.get("id", ""))
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
|
||||||
|
call_id,
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
|
||||||
|
call_id,
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:1200],
|
||||||
|
)
|
||||||
|
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
final_text = _as_text(final.content)
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
tool_calls_count,
|
||||||
|
len(final_text),
|
||||||
|
)
|
||||||
|
return final_text
|
||||||
|
finally:
|
||||||
|
clear_tool_result_collector()
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_single_agent_stream(
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
system_prompt: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
max_steps: int = 6,
|
||||||
|
langfuse_handler: Any | None = None,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
trace_id = _trace_id_from_context(context)
|
||||||
|
callbacks = _build_callbacks(langfuse_handler)
|
||||||
|
llm = get_llm(callbacks=callbacks)
|
||||||
|
tools = _all_tools_for_user(user_id, trace_id)
|
||||||
|
model_context = _context_for_model(context)
|
||||||
|
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
messages: list[Any] = [
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(
|
||||||
|
content=(
|
||||||
|
f"User message:\n{message}\n\n"
|
||||||
|
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
tool_calls_count = 0
|
||||||
|
streamed_chars = 0
|
||||||
|
collected: list[dict[str, Any]] = []
|
||||||
|
set_tool_result_collector(collected)
|
||||||
|
try:
|
||||||
|
for _ in range(max_steps):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
emitted_any = False
|
||||||
|
async for chunk in llm.astream(messages):
|
||||||
|
token = _as_text(getattr(chunk, "content", ""))
|
||||||
|
if token:
|
||||||
|
streamed_chars += len(token)
|
||||||
|
emitted_any = True
|
||||||
|
yield "token", token
|
||||||
|
|
||||||
|
if not emitted_any:
|
||||||
|
fallback_text = _as_text(response.content)
|
||||||
|
if fallback_text:
|
||||||
|
streamed_chars += len(fallback_text)
|
||||||
|
yield "token", fallback_text
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
tool_calls_count,
|
||||||
|
streamed_chars,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
for call in response.tool_calls:
|
||||||
|
tool_calls_count += 1
|
||||||
|
call_id = str(call.get("id", ""))
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
|
||||||
|
call_id,
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
|
||||||
|
call_id,
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:1200],
|
||||||
|
)
|
||||||
|
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
async for chunk in llm.astream(messages):
|
||||||
|
token = _as_text(getattr(chunk, "content", ""))
|
||||||
|
if token:
|
||||||
|
streamed_chars += len(token)
|
||||||
|
yield "token", token
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
tool_calls_count,
|
||||||
|
streamed_chars,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_tool_result_collector()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_home(user_id: str, message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None) -> str:
|
||||||
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
system_prompt = _get_system_prompt("home_system", _HOME_SINGLE_AGENT_SYSTEM)
|
||||||
|
response = await _run_single_agent(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
message=message,
|
||||||
|
context=prepared_context,
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
)
|
||||||
|
return _normalize_tagged_list_lines(response, message)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_floating(user_id: str, message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None) -> tuple[str, dict[str, str | None]]:
|
||||||
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
domain = await _infer_floating_domain(message, prepared_context, langfuse_handler=langfuse_handler)
|
||||||
|
system_prompt = _get_system_prompt("floating_system", _FLOATING_SINGLE_AGENT_SYSTEM)
|
||||||
|
response = await _run_single_agent(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
message=message,
|
||||||
|
context=prepared_context,
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
)
|
||||||
|
sanitized = _strip_floating_markup(response)
|
||||||
|
if not sanitized and response:
|
||||||
|
sanitized = _fallback_from_raw_floating_text(response)
|
||||||
|
return sanitized, domain
|
||||||
|
|
||||||
|
|
||||||
|
async def run_home_stream(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
*,
|
||||||
|
langfuse_handler: Any | None = None,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
system_prompt = _get_system_prompt("home_system", _HOME_SINGLE_AGENT_SYSTEM)
|
||||||
|
text_chunks: list[str] = []
|
||||||
|
async for event in _run_single_agent_stream(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
message=message,
|
||||||
|
context=prepared_context,
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
):
|
||||||
|
event_type, data = event
|
||||||
|
if event_type != "token":
|
||||||
|
yield event
|
||||||
|
continue
|
||||||
|
text_chunks.append(str(data or ""))
|
||||||
|
|
||||||
|
normalized = _normalize_tagged_list_lines("".join(text_chunks), message)
|
||||||
|
if normalized:
|
||||||
|
yield "token", normalized
|
||||||
|
|
||||||
|
|
||||||
|
async def run_floating_stream(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
*,
|
||||||
|
langfuse_handler: Any | None = None,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
domain = await _infer_floating_domain(message, prepared_context, langfuse_handler=langfuse_handler)
|
||||||
|
yield "floating_domain", domain
|
||||||
|
|
||||||
|
system_prompt = _get_system_prompt("floating_system", _FLOATING_SINGLE_AGENT_SYSTEM)
|
||||||
|
sanitizer = _FloatingStreamSanitizer()
|
||||||
|
emitted_sanitized = False
|
||||||
|
raw_chunks: list[str] = []
|
||||||
|
async for event in _run_single_agent_stream(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
message=message,
|
||||||
|
context=prepared_context,
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
):
|
||||||
|
event_type, data = event
|
||||||
|
if event_type != "token":
|
||||||
|
yield event
|
||||||
|
continue
|
||||||
|
|
||||||
|
raw_chunk = str(data or "")
|
||||||
|
raw_chunks.append(raw_chunk)
|
||||||
|
sanitized_chunk = sanitizer.feed(raw_chunk)
|
||||||
|
if sanitized_chunk:
|
||||||
|
emitted_sanitized = True
|
||||||
|
yield "token", sanitized_chunk
|
||||||
|
|
||||||
|
tail = sanitizer.finalize()
|
||||||
|
if tail:
|
||||||
|
emitted_sanitized = True
|
||||||
|
yield "token", tail
|
||||||
|
|
||||||
|
if not emitted_sanitized and raw_chunks:
|
||||||
|
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
|
||||||
|
|
||||||
|
|
||||||
|
async def update_core_memory(user_id: str, key: str, value: str) -> None:
|
||||||
|
"""Compatibility helper kept for callers that expect explicit memory update API."""
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.update_core(user_id, key, value)
|
||||||
72
services/chat/app/llm.py
Normal file
72
services/chat/app/llm.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||||
|
|
||||||
|
Adapted from app/core/llm.py for the Chat Service.
|
||||||
|
Uses shared.config.settings instead of app.config.settings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain_litellm import ChatLiteLLM
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
litellm.drop_params = True
|
||||||
|
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
||||||
|
category=UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
|
if model.startswith("anthropic/"):
|
||||||
|
return settings.ANTHROPIC_API_KEY or None
|
||||||
|
if model.startswith("gemini/") or model.startswith("google/"):
|
||||||
|
return settings.GOOGLE_API_KEY or None
|
||||||
|
if model.startswith("cerebras/"):
|
||||||
|
return settings.CEREBRAS_API_KEY or None
|
||||||
|
if model.startswith("github_copilot/"):
|
||||||
|
return None
|
||||||
|
return settings.OPENAI_API_KEY or None
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm(
|
||||||
|
*,
|
||||||
|
model: str | None = None,
|
||||||
|
temperature: float = 0,
|
||||||
|
callbacks: list | None = None,
|
||||||
|
) -> ChatOpenAI | ChatLiteLLM:
|
||||||
|
model = model or settings.LLM_MODEL
|
||||||
|
|
||||||
|
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
||||||
|
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
||||||
|
|
||||||
|
if "/" in model:
|
||||||
|
return ChatLiteLLM(model=model, temperature=temperature, callbacks=callbacks)
|
||||||
|
|
||||||
|
return ChatOpenAI(
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
api_key=_api_key_for_model(model),
|
||||||
|
callbacks=callbacks,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def embed(text: str) -> list[float]:
|
||||||
|
model = settings.LLM_EMBED_MODEL
|
||||||
|
|
||||||
|
if model.startswith("github_copilot/") or "/" in model:
|
||||||
|
response = await litellm.aembedding(model=model, input=[text])
|
||||||
|
return response.data[0]["embedding"]
|
||||||
|
|
||||||
|
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
||||||
|
response = await client.embeddings.create(model=model, input=text)
|
||||||
|
return response.data[0].embedding
|
||||||
87
services/chat/app/main.py
Normal file
87
services/chat/app/main.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
"""Chat Service — LLM orchestration, domain agents, memory.
|
||||||
|
|
||||||
|
Consumes chat requests from Redis, executes deep_agent (home/floating),
|
||||||
|
streams responses back via Redis pub/sub to WS Gateway.
|
||||||
|
|
||||||
|
Owns: memory_core, memory_associative, memory_episodic, memory_proactive tables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Ensure the repo root is on sys.path so "shared" is importable in local dev.
|
||||||
|
_repo_root = str(Path(__file__).resolve().parents[3])
|
||||||
|
if _repo_root not in sys.path:
|
||||||
|
sys.path.insert(0, _repo_root)
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
# Initialise Langfuse tracing (no-op if keys are missing)
|
||||||
|
from app.tracing import init_langfuse
|
||||||
|
|
||||||
|
init_langfuse()
|
||||||
|
|
||||||
|
# Start Redis consumer in background
|
||||||
|
from app.redis_consumer import start_consumer
|
||||||
|
|
||||||
|
consumer_task = start_consumer()
|
||||||
|
yield
|
||||||
|
consumer_task.cancel()
|
||||||
|
|
||||||
|
from app.tracing import shutdown as shutdown_langfuse
|
||||||
|
|
||||||
|
shutdown_langfuse()
|
||||||
|
|
||||||
|
from shared.db import engine
|
||||||
|
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
from shared.redis import redis_client
|
||||||
|
|
||||||
|
await redis_client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
app = FastAPI(
|
||||||
|
title="Adiuva Chat Service",
|
||||||
|
version="0.1.0",
|
||||||
|
docs_url="/docs" if settings.ENV == "dev" else None,
|
||||||
|
redoc_url=None,
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=settings.CORS_ORIGINS,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.routes import router
|
||||||
|
|
||||||
|
app.include_router(router, prefix="/api/v1")
|
||||||
|
|
||||||
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
|
async def health() -> dict:
|
||||||
|
return {"status": "ok", "service": "chat", "version": app.version}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
295
services/chat/app/memory_middleware.py
Normal file
295
services/chat/app/memory_middleware.py
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
"""Memory Middleware — adapted for Chat Service.
|
||||||
|
|
||||||
|
Uses shared.models instead of app.models. Otherwise identical to the
|
||||||
|
monolith's app/core/memory_middleware.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from shared.models import (
|
||||||
|
MemoryAssociative,
|
||||||
|
MemoryCore,
|
||||||
|
MemoryEpisodic,
|
||||||
|
MemoryProactive,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_ASSOCIATIVE_TOP_K = 5
|
||||||
|
_EPISODIC_RECENT_N = 10
|
||||||
|
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryMiddleware:
|
||||||
|
|
||||||
|
def __init__(self, db: AsyncSession) -> None:
|
||||||
|
self._db = db
|
||||||
|
|
||||||
|
async def enrich_context(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
trace_id: str | None = None,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
core = await self._load_core(user_id, fernet)
|
||||||
|
associative = await self._load_associative(user_id, message, fernet)
|
||||||
|
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
||||||
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"memory: enrich_context trace=%s user=%s core=%d assoc=%d episodic=%d proactive=%d",
|
||||||
|
trace_id or "-", user_id, len(core), len(associative), len(episodic), len(proactive),
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"core_memory": core,
|
||||||
|
"associative_memory": associative,
|
||||||
|
"episodic_memory": episodic,
|
||||||
|
"proactive_hints": proactive,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def store_episode(
|
||||||
|
self, user_id: str, session_id: str, message: str, response: str,
|
||||||
|
trace_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
||||||
|
encrypted = _encrypt(fernet, summary)
|
||||||
|
|
||||||
|
row = MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
summary_encrypted=encrypted,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
self._db.add(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
encrypted = _encrypt(fernet, value)
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == key)
|
||||||
|
)
|
||||||
|
existing = result.scalar_one_or_none()
|
||||||
|
if existing is not None:
|
||||||
|
existing.value_encrypted = encrypted
|
||||||
|
else:
|
||||||
|
self._db.add(MemoryCore(
|
||||||
|
id=str(uuid.uuid4()), user_id=user_id, key=key, value_encrypted=encrypted,
|
||||||
|
))
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == user_id).order_by(MemoryCore.key.asc())
|
||||||
|
)
|
||||||
|
out: list[dict[str, str]] = []
|
||||||
|
for row in result.scalars().all():
|
||||||
|
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append({"label": row.key, "value": plaintext})
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def get_core_block(self, user_id: str, label: str) -> str | None:
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return None
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == label)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
|
||||||
|
async def delete_core(self, user_id: str, label: str) -> bool:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == label)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
return False
|
||||||
|
await self._db.delete(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
return True
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def append_core(self, user_id: str, label: str, content: str) -> None:
|
||||||
|
current = await self.get_core_block(user_id, label)
|
||||||
|
if current is None:
|
||||||
|
await self.update_core(user_id, label, content)
|
||||||
|
return
|
||||||
|
await self.update_core(user_id, label, f"{current}\n{content}")
|
||||||
|
|
||||||
|
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
|
||||||
|
current = await self.get_core_block(user_id, label)
|
||||||
|
if current is None or old not in current:
|
||||||
|
return False
|
||||||
|
await self.update_core(user_id, label, current.replace(old, new, 1))
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
encrypted = _encrypt(fernet, content)
|
||||||
|
row = MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()), user_id=user_id,
|
||||||
|
content_encrypted=encrypted, embedding=None,
|
||||||
|
entity_type=source, entity_id=None,
|
||||||
|
)
|
||||||
|
self._db.add(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryAssociative).where(MemoryAssociative.user_id == user_id)
|
||||||
|
.order_by(MemoryAssociative.updated_at.desc()).limit(100)
|
||||||
|
)
|
||||||
|
needle = query.strip().lower()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in result.scalars().all():
|
||||||
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
|
if plaintext is None:
|
||||||
|
continue
|
||||||
|
if not needle or needle in plaintext.lower():
|
||||||
|
out.append(plaintext)
|
||||||
|
if len(out) >= max(top_k, 1):
|
||||||
|
break
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
|
||||||
|
.order_by(MemoryEpisodic.created_at.desc()).limit(100)
|
||||||
|
)
|
||||||
|
needle = query.strip().lower()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in result.scalars().all():
|
||||||
|
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||||
|
if plaintext is None:
|
||||||
|
continue
|
||||||
|
if not needle or needle in plaintext.lower():
|
||||||
|
out.append(plaintext)
|
||||||
|
if len(out) >= max(top_k, 1):
|
||||||
|
break
|
||||||
|
return out
|
||||||
|
|
||||||
|
# ── Private ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
||||||
|
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not user.encryption_key:
|
||||||
|
logger.warning("memory: no encryption_key for user=%s", user_id)
|
||||||
|
return None
|
||||||
|
return Fernet(user.encryption_key.encode())
|
||||||
|
|
||||||
|
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
||||||
|
)
|
||||||
|
out: dict[str, str] = {}
|
||||||
|
for row in result.scalars().all():
|
||||||
|
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out[row.key] = plaintext
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _load_associative(self, user_id: str, message: str, fernet: Fernet) -> list[str]:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryAssociative).where(MemoryAssociative.user_id == user_id)
|
||||||
|
.order_by(MemoryAssociative.updated_at.desc()).limit(_ASSOCIATIVE_TOP_K)
|
||||||
|
)
|
||||||
|
out: list[str] = []
|
||||||
|
for row in result.scalars().all():
|
||||||
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _load_episodic(self, user_id: str, fernet: Fernet, session_id: str | None = None) -> list[str]:
|
||||||
|
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
|
||||||
|
if session_id:
|
||||||
|
query = query.where(MemoryEpisodic.session_id == session_id)
|
||||||
|
result = await self._db.execute(
|
||||||
|
query.order_by(MemoryEpisodic.created_at.desc()).limit(_EPISODIC_RECENT_N)
|
||||||
|
)
|
||||||
|
out: list[str] = []
|
||||||
|
for row in result.scalars().all():
|
||||||
|
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryProactive).where(
|
||||||
|
MemoryProactive.user_id == user_id,
|
||||||
|
MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
|
||||||
|
).order_by(MemoryProactive.confidence.desc())
|
||||||
|
)
|
||||||
|
out: list[str] = []
|
||||||
|
for row in result.scalars().all():
|
||||||
|
plaintext = _safe_decrypt(fernet, row.pattern_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _encrypt(fernet: Fernet, plaintext: str) -> str:
|
||||||
|
return fernet.encrypt(plaintext.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
|
||||||
|
try:
|
||||||
|
return fernet.decrypt(ciphertext.encode()).decode()
|
||||||
|
except (InvalidToken, Exception) as exc:
|
||||||
|
logger.warning("memory: decrypt failed: %s", exc)
|
||||||
|
return None
|
||||||
50
services/chat/app/output_formatter.py
Normal file
50
services/chat/app/output_formatter.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""Output formatter for deep-agent stream events — Chat Service copy.
|
||||||
|
|
||||||
|
Converts (event_type, data) tuples into WebSocket frame Pydantic models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from shared.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
||||||
|
|
||||||
|
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||||
|
|
||||||
|
|
||||||
|
class StreamFormatter:
|
||||||
|
"""Convert `(event_type, data)` stream events into websocket frame models."""
|
||||||
|
|
||||||
|
def __init__(self, request_id: str) -> None:
|
||||||
|
self.request_id = request_id
|
||||||
|
|
||||||
|
async def format(
|
||||||
|
self,
|
||||||
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
|
started = False
|
||||||
|
|
||||||
|
async for event_type, data in event_stream:
|
||||||
|
if event_type == "floating_domain":
|
||||||
|
if isinstance(data, dict):
|
||||||
|
yield WsFloatingDomain(
|
||||||
|
request_id=self.request_id,
|
||||||
|
domain=data,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if event_type != "token":
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not started:
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
started = True
|
||||||
|
|
||||||
|
text = str(data or "")
|
||||||
|
if text:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=text)
|
||||||
|
|
||||||
|
if not started:
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
yield WsStreamEnd(request_id=self.request_id)
|
||||||
209
services/chat/app/redis_consumer.py
Normal file
209
services/chat/app/redis_consumer.py
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
"""Redis consumer — listens for chat requests and dispatches to deep_agent.
|
||||||
|
|
||||||
|
Subscribes to a Redis pattern channel chat:request:* so it receives
|
||||||
|
requests for ALL users. Each request is processed in a separate asyncio task.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from shared.db import async_session
|
||||||
|
from shared.redis import redis_client, ws_out_channel
|
||||||
|
|
||||||
|
from app.deep_agent import run_floating_stream, run_home_stream
|
||||||
|
from app.memory_middleware import MemoryMiddleware
|
||||||
|
from app.output_formatter import StreamFormatter
|
||||||
|
from app.ws_context import clear_current_user, set_current_user
|
||||||
|
from app import tracing
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def start_consumer() -> asyncio.Task:
|
||||||
|
"""Start the Redis consumer as a background asyncio task."""
|
||||||
|
return asyncio.create_task(_consumer_loop())
|
||||||
|
|
||||||
|
|
||||||
|
async def _consumer_loop() -> None:
|
||||||
|
"""Subscribe to chat:request:* and dispatch incoming frames."""
|
||||||
|
pubsub = redis_client.pubsub()
|
||||||
|
await pubsub.psubscribe("chat:request:*")
|
||||||
|
logger.info("redis_consumer: subscribed to chat:request:*")
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
message = await pubsub.get_message(
|
||||||
|
ignore_subscribe_messages=True, timeout=1.0
|
||||||
|
)
|
||||||
|
if message is not None and message["type"] == "pmessage":
|
||||||
|
frame = json.loads(message["data"])
|
||||||
|
asyncio.create_task(_dispatch(frame))
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("redis_consumer: shutting down")
|
||||||
|
finally:
|
||||||
|
await pubsub.punsubscribe()
|
||||||
|
await pubsub.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
async def _dispatch(frame: dict) -> None:
|
||||||
|
"""Route a chat request frame to the appropriate handler."""
|
||||||
|
frame_type = frame.get("type")
|
||||||
|
user_id = frame.get("user_id")
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
logger.warning("redis_consumer: frame missing user_id: %s", frame.get("type"))
|
||||||
|
return
|
||||||
|
|
||||||
|
if frame_type == "home_request":
|
||||||
|
await _handle_home_request(user_id, frame)
|
||||||
|
elif frame_type == "floating_request":
|
||||||
|
await _handle_floating_request(user_id, frame)
|
||||||
|
else:
|
||||||
|
logger.debug("redis_consumer: unknown frame type %r", frame_type)
|
||||||
|
|
||||||
|
|
||||||
|
async def _publish_frame(user_id: str, frame_data: str) -> None:
|
||||||
|
"""Publish a frame to ws:out:{user_id} for the WS Gateway to forward."""
|
||||||
|
channel = ws_out_channel(user_id)
|
||||||
|
await redis_client.publish(channel, frame_data)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_home_request(user_id: str, frame: dict) -> None:
|
||||||
|
"""Process a home_request — enrich with memory, run deep_agent, stream results."""
|
||||||
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
|
message: str = frame.get("message", "")
|
||||||
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"redis_consumer: home_request user=%s req=%s msg=%s",
|
||||||
|
user_id, request_id, message[:200],
|
||||||
|
)
|
||||||
|
|
||||||
|
response_chunks: list[str] = []
|
||||||
|
|
||||||
|
with tracing.trace_span(
|
||||||
|
name="home_request",
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
trace_id=request_id,
|
||||||
|
input=message,
|
||||||
|
metadata={"message_preview": message[:200]},
|
||||||
|
tags=["home"],
|
||||||
|
) as span:
|
||||||
|
langfuse_handler = tracing.get_langfuse_callback()
|
||||||
|
|
||||||
|
# Enrich with memory context
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id, message,
|
||||||
|
trace_id=request_id, session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
context: dict = {
|
||||||
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
set_current_user(user_id)
|
||||||
|
try:
|
||||||
|
event_stream = run_home_stream(user_id, message, context, langfuse_handler=langfuse_handler)
|
||||||
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
|
async for ws_frame in formatter.format(event_stream):
|
||||||
|
await _publish_frame(user_id, ws_frame.model_dump_json())
|
||||||
|
if hasattr(ws_frame, "chunk"):
|
||||||
|
response_chunks.append(ws_frame.chunk)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("redis_consumer: home_request failed user=%s req=%s: %s", user_id, request_id, exc)
|
||||||
|
finally:
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
# Link prompt and attach output preview
|
||||||
|
tracing.link_prompt_to_trace(span, "home_system")
|
||||||
|
response_text = "".join(response_chunks)
|
||||||
|
span.update(output=response_text[:500] if response_text else None)
|
||||||
|
|
||||||
|
tracing.flush()
|
||||||
|
|
||||||
|
# Store episode
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.store_episode(
|
||||||
|
user_id, session_id, message, "".join(response_chunks),
|
||||||
|
trace_id=request_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_floating_request(user_id: str, frame: dict) -> None:
|
||||||
|
"""Process a floating_request — enrich with memory, run deep_agent, stream results."""
|
||||||
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
|
message: str = frame.get("message", "")
|
||||||
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
scope: dict = frame.get("scope", {})
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"redis_consumer: floating_request user=%s req=%s scope=%s msg=%s",
|
||||||
|
user_id, request_id, json.dumps(scope)[:200], message[:200],
|
||||||
|
)
|
||||||
|
|
||||||
|
response_chunks: list[str] = []
|
||||||
|
|
||||||
|
with tracing.trace_span(
|
||||||
|
name="floating_request",
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
trace_id=request_id,
|
||||||
|
input=message,
|
||||||
|
metadata={"message_preview": message[:200], "scope": scope},
|
||||||
|
tags=["floating"],
|
||||||
|
) as span:
|
||||||
|
langfuse_handler = tracing.get_langfuse_callback()
|
||||||
|
|
||||||
|
# Enrich with memory context
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id, message,
|
||||||
|
trace_id=request_id, session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
context: dict = {
|
||||||
|
"scope": scope,
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
set_current_user(user_id)
|
||||||
|
try:
|
||||||
|
event_stream = run_floating_stream(user_id, message, context, langfuse_handler=langfuse_handler)
|
||||||
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
|
async for ws_frame in formatter.format(event_stream):
|
||||||
|
await _publish_frame(user_id, ws_frame.model_dump_json())
|
||||||
|
if hasattr(ws_frame, "chunk"):
|
||||||
|
response_chunks.append(ws_frame.chunk)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("redis_consumer: floating_request failed user=%s req=%s: %s", user_id, request_id, exc)
|
||||||
|
finally:
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
# Link prompt and attach output preview
|
||||||
|
tracing.link_prompt_to_trace(span, "floating_system")
|
||||||
|
response_text = "".join(response_chunks)
|
||||||
|
span.update(output=response_text[:500] if response_text else None)
|
||||||
|
|
||||||
|
tracing.flush()
|
||||||
|
|
||||||
|
# Store episode
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.store_episode(
|
||||||
|
user_id, session_id, message, "".join(response_chunks),
|
||||||
|
trace_id=request_id,
|
||||||
|
)
|
||||||
37
services/chat/app/routes.py
Normal file
37
services/chat/app/routes.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""Chat REST route — POST /chat fallback when WS is unavailable."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from shared.schemas import ChatRequest
|
||||||
|
|
||||||
|
from app.deep_agent import run_home
|
||||||
|
from app.ws_context import clear_current_user, set_current_user
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("")
|
||||||
|
async def chat(body: ChatRequest, request: Request) -> JSONResponse:
|
||||||
|
"""REST fallback for home chat.
|
||||||
|
|
||||||
|
In the microservices setup, Traefik ForwardAuth has already validated
|
||||||
|
the JWT and injected X-User-Id / X-User-Email / X-User-Tier headers.
|
||||||
|
"""
|
||||||
|
user_id = request.headers.get("X-User-Id", "")
|
||||||
|
if not user_id:
|
||||||
|
return JSONResponse(status_code=401, content={"detail": "Missing X-User-Id header"})
|
||||||
|
|
||||||
|
set_current_user(user_id)
|
||||||
|
try:
|
||||||
|
response = await run_home(
|
||||||
|
user_id=user_id,
|
||||||
|
message=body.message,
|
||||||
|
context=body.context.model_dump(),
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
return JSONResponse(content={"response": response})
|
||||||
264
services/chat/app/tracing.py
Normal file
264
services/chat/app/tracing.py
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
"""Langfuse tracing & prompt management for the Chat Service (v4 SDK).
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
- ``init_langfuse()`` — initialise the singleton client at startup
|
||||||
|
- ``trace_span()`` — context manager that creates a trace + span
|
||||||
|
- ``get_langfuse_callback()`` — LangChain callback handler (auto-inherits trace)
|
||||||
|
- ``get_prompt()`` — fetch a managed prompt from Langfuse by name
|
||||||
|
- ``flush()`` / ``shutdown()`` — lifecycle management
|
||||||
|
|
||||||
|
All functions gracefully degrade to no-ops when Langfuse is not configured,
|
||||||
|
so the service works identically with or without observability keys.
|
||||||
|
|
||||||
|
Requires ``langfuse >= 3.0.0`` (v4 / "Fast Preview" SDK).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── State ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_initialised: bool = False
|
||||||
|
_disabled: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_configured() -> bool:
|
||||||
|
return bool(settings.LANGFUSE_SECRET_KEY and settings.LANGFUSE_PUBLIC_KEY)
|
||||||
|
|
||||||
|
|
||||||
|
def init_langfuse() -> None:
|
||||||
|
"""Initialise the Langfuse singleton. Call once at startup."""
|
||||||
|
global _initialised, _disabled
|
||||||
|
|
||||||
|
if _initialised or _disabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not _is_configured():
|
||||||
|
_disabled = True
|
||||||
|
logger.info("tracing: Langfuse keys not set — tracing disabled")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse import Langfuse
|
||||||
|
|
||||||
|
Langfuse(
|
||||||
|
secret_key=settings.LANGFUSE_SECRET_KEY,
|
||||||
|
public_key=settings.LANGFUSE_PUBLIC_KEY,
|
||||||
|
host=settings.LANGFUSE_HOST,
|
||||||
|
)
|
||||||
|
_initialised = True
|
||||||
|
logger.info("tracing: Langfuse client initialised (host=%s)", settings.LANGFUSE_HOST)
|
||||||
|
except Exception as exc:
|
||||||
|
_disabled = True
|
||||||
|
logger.warning("tracing: failed to initialise Langfuse: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client() -> Any | None:
|
||||||
|
"""Return the singleton Langfuse client, or *None* if disabled."""
|
||||||
|
if _disabled:
|
||||||
|
return None
|
||||||
|
if not _initialised:
|
||||||
|
init_langfuse()
|
||||||
|
if _disabled:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
from langfuse import get_client
|
||||||
|
return get_client()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Null span (no-op when Langfuse is disabled) ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _NullSpan:
|
||||||
|
"""Drop-in replacement when Langfuse is disabled."""
|
||||||
|
|
||||||
|
def update(self, **_: Any) -> None: ...
|
||||||
|
def set_trace_io(self, **_: Any) -> None: ...
|
||||||
|
def score_trace(self, **_: Any) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
# ── Trace context manager ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def trace_span(
|
||||||
|
*,
|
||||||
|
name: str,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str | None = None,
|
||||||
|
trace_id: str | None = None,
|
||||||
|
input: Any = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
):
|
||||||
|
"""Context manager that creates a Langfuse trace/span.
|
||||||
|
|
||||||
|
Yields the span object (or a ``_NullSpan`` if Langfuse is disabled).
|
||||||
|
A ``CallbackHandler`` created inside this block auto-inherits the trace
|
||||||
|
context, so there is no need to pass trace IDs manually.
|
||||||
|
"""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None:
|
||||||
|
yield _NullSpan()
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse import Langfuse, propagate_attributes
|
||||||
|
|
||||||
|
trace_ctx: dict[str, str] = {}
|
||||||
|
if trace_id is not None:
|
||||||
|
trace_ctx["trace_id"] = Langfuse.create_trace_id(seed=trace_id)
|
||||||
|
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="span",
|
||||||
|
name=name,
|
||||||
|
input=input,
|
||||||
|
metadata=metadata or {},
|
||||||
|
**({"trace_context": trace_ctx} if trace_ctx else {}),
|
||||||
|
) as span:
|
||||||
|
with propagate_attributes(
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
tags=tags or [],
|
||||||
|
):
|
||||||
|
yield span
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: trace_span(%s) failed: %s", name, exc)
|
||||||
|
yield _NullSpan()
|
||||||
|
|
||||||
|
|
||||||
|
# ── LangChain callback handler ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def get_langfuse_callback() -> Any | None:
|
||||||
|
"""Return a LangChain ``CallbackHandler`` that auto-inherits the current trace.
|
||||||
|
|
||||||
|
Must be called inside a ``trace_span()`` block for proper linking.
|
||||||
|
Returns *None* when Langfuse is disabled.
|
||||||
|
"""
|
||||||
|
if _disabled and not _initialised:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse.langchain import CallbackHandler
|
||||||
|
return CallbackHandler()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: get_langfuse_callback failed: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Prompt management ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt(
|
||||||
|
name: str,
|
||||||
|
*,
|
||||||
|
version: int | None = None,
|
||||||
|
label: str | None = None,
|
||||||
|
fallback: str | None = None,
|
||||||
|
cache_ttl_seconds: int = 300,
|
||||||
|
) -> str | None:
|
||||||
|
"""Fetch a managed prompt from Langfuse by name.
|
||||||
|
|
||||||
|
Returns the compiled prompt string, or *fallback* if the prompt is not
|
||||||
|
found or Langfuse is disabled.
|
||||||
|
"""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None:
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
try:
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"name": name,
|
||||||
|
"cache_ttl_seconds": cache_ttl_seconds,
|
||||||
|
}
|
||||||
|
if version is not None:
|
||||||
|
kwargs["version"] = version
|
||||||
|
if label is not None:
|
||||||
|
kwargs["label"] = label
|
||||||
|
prompt = lf.get_prompt(**kwargs)
|
||||||
|
return prompt.prompt
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: get_prompt(%s) failed: %s", name, exc)
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
|
||||||
|
def link_prompt_to_trace(
|
||||||
|
span: Any,
|
||||||
|
prompt_name: str,
|
||||||
|
*,
|
||||||
|
version: int | None = None,
|
||||||
|
label: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Attach prompt metadata to a span/trace."""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None or isinstance(span, _NullSpan):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
kwargs: dict[str, Any] = {"name": prompt_name}
|
||||||
|
if version is not None:
|
||||||
|
kwargs["version"] = version
|
||||||
|
if label is not None:
|
||||||
|
kwargs["label"] = label
|
||||||
|
prompt = lf.get_prompt(**kwargs)
|
||||||
|
span.update(metadata={"prompt": {"name": prompt_name, "version": prompt.version}})
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: link_prompt_to_trace(%s) failed: %s", prompt_name, exc)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Scoring helper ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def score_trace(
|
||||||
|
trace_id: str,
|
||||||
|
name: str,
|
||||||
|
value: float,
|
||||||
|
*,
|
||||||
|
comment: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Post a score to a trace (e.g. user feedback, latency, quality)."""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
lf.create_score(trace_id=trace_id, name=name, value=value, comment=comment)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: score_trace failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Shutdown ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def flush() -> None:
|
||||||
|
"""Flush pending Langfuse events."""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is not None:
|
||||||
|
try:
|
||||||
|
lf.flush()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: flush failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
def shutdown() -> None:
|
||||||
|
"""Flush and close the Langfuse client."""
|
||||||
|
global _initialised, _disabled
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is not None:
|
||||||
|
try:
|
||||||
|
lf.flush()
|
||||||
|
lf.shutdown()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: shutdown failed: %s", exc)
|
||||||
|
_initialised = False
|
||||||
|
_disabled = False
|
||||||
115
services/chat/app/ws_context.py
Normal file
115
services/chat/app/ws_context.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
"""WebSocket context for Chat Service — Redis-based tool call round-trip.
|
||||||
|
|
||||||
|
Replaces the monolith's ws_context.py. Instead of calling Electron directly
|
||||||
|
via WebSocket, this publishes tool_call frames to Redis (ws:out:{user_id})
|
||||||
|
and awaits the result via BRPOP on tool:result:{call_id}.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from shared.redis import redis_client, tool_result_key, ws_out_channel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_TOOL_CALL_TIMEOUT = 30 # seconds — BRPOP timeout
|
||||||
|
|
||||||
|
# Per-request user_id context var (set before agent runs)
|
||||||
|
_current_user_id: ContextVar[str | None] = ContextVar("_current_user_id", default=None)
|
||||||
|
|
||||||
|
# Optional collector for debug
|
||||||
|
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||||
|
"_tool_result_collector", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_current_user(user_id: str) -> None:
|
||||||
|
_current_user_id.set(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_current_user() -> None:
|
||||||
|
_current_user_id.set(None)
|
||||||
|
|
||||||
|
|
||||||
|
def set_tool_result_collector(lst: list[dict]) -> None:
|
||||||
|
_tool_result_collector.set(lst)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_tool_result_collector() -> None:
|
||||||
|
_tool_result_collector.set(None)
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_on_client(
|
||||||
|
action: str,
|
||||||
|
table: str | None = None,
|
||||||
|
data: dict[str, Any] | None = None,
|
||||||
|
filters: dict[str, Any] | None = None,
|
||||||
|
vector: list[float] | None = None,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Send a tool_call to Electron via Redis and await the result.
|
||||||
|
|
||||||
|
1. Build tool_call payload
|
||||||
|
2. Publish to ws:out:{user_id} (WS Gateway forwards to Electron)
|
||||||
|
3. BRPOP on tool:result:{call_id} (WS Gateway pushes when Electron replies)
|
||||||
|
4. Return result dict
|
||||||
|
|
||||||
|
Raises RuntimeError if no user_id is set or if the call times out.
|
||||||
|
"""
|
||||||
|
user_id = _current_user_id.get()
|
||||||
|
if not user_id:
|
||||||
|
raise RuntimeError(
|
||||||
|
"execute_on_client() called without a user_id — "
|
||||||
|
"set_current_user() must be called first."
|
||||||
|
)
|
||||||
|
|
||||||
|
call_id = str(uuid4())
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"type": "tool_call",
|
||||||
|
"id": call_id,
|
||||||
|
"action": action,
|
||||||
|
}
|
||||||
|
if table is not None:
|
||||||
|
payload["table"] = table
|
||||||
|
if data is not None:
|
||||||
|
payload["data"] = data
|
||||||
|
if filters is not None:
|
||||||
|
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
|
||||||
|
if vector is not None:
|
||||||
|
payload["vector"] = vector
|
||||||
|
if limit is not None:
|
||||||
|
payload["limit"] = limit
|
||||||
|
|
||||||
|
# Publish tool_call to WS Gateway → Electron
|
||||||
|
channel = ws_out_channel(user_id)
|
||||||
|
await redis_client.publish(channel, json.dumps(payload))
|
||||||
|
|
||||||
|
# Wait for Electron's tool_result
|
||||||
|
result_key = tool_result_key(call_id)
|
||||||
|
response = await redis_client.brpop(result_key, timeout=_TOOL_CALL_TIMEOUT)
|
||||||
|
|
||||||
|
if response is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Tool call {call_id} timed out after {_TOOL_CALL_TIMEOUT}s — "
|
||||||
|
f"device may be offline or unresponsive."
|
||||||
|
)
|
||||||
|
|
||||||
|
# response is (key, value) tuple
|
||||||
|
_, raw = response
|
||||||
|
result = json.loads(raw)
|
||||||
|
|
||||||
|
# Collect for debug if requested
|
||||||
|
collector = _tool_result_collector.get(None)
|
||||||
|
if collector is not None:
|
||||||
|
collector.append({
|
||||||
|
"action": action,
|
||||||
|
"table": table,
|
||||||
|
"data": result,
|
||||||
|
})
|
||||||
|
|
||||||
|
return result
|
||||||
17
services/chat/requirements.txt
Normal file
17
services/chat/requirements.txt
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
fastapi>=0.115.0
|
||||||
|
uvicorn[standard]>=0.34.0
|
||||||
|
gunicorn>=22.0.0
|
||||||
|
pydantic>=2.10.0
|
||||||
|
pydantic-settings>=2.7.0
|
||||||
|
sqlalchemy>=2.0.0
|
||||||
|
asyncpg>=0.30.0
|
||||||
|
redis>=5.0.0
|
||||||
|
cryptography>=42.0.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
langchain-core>=0.3.0
|
||||||
|
langchain-openai>=0.3.0
|
||||||
|
langchain-litellm>=0.3.0
|
||||||
|
litellm>=1.50.0
|
||||||
|
openai>=1.50.0
|
||||||
|
httpx>=0.27.0
|
||||||
|
langfuse>=3.0.0
|
||||||
36
services/ws-gateway/Dockerfile
Normal file
36
services/ws-gateway/Dockerfile
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# ── builder ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS builder
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
COPY services/ws-gateway/requirements.txt ./requirements.txt
|
||||||
|
RUN pip install --upgrade pip && \
|
||||||
|
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||||
|
|
||||||
|
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS runtime
|
||||||
|
|
||||||
|
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY --from=builder /install /usr/local
|
||||||
|
|
||||||
|
# Shared module
|
||||||
|
COPY shared/ shared/
|
||||||
|
|
||||||
|
# Service source
|
||||||
|
COPY services/ws-gateway/app/ app/
|
||||||
|
|
||||||
|
RUN chown -R appuser:appgroup /app
|
||||||
|
|
||||||
|
USER appuser
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
# Single worker — each instance handles many WS connections via asyncio
|
||||||
|
CMD ["gunicorn", "app.main:app", \
|
||||||
|
"-k", "uvicorn.workers.UvicornWorker", \
|
||||||
|
"--bind", "0.0.0.0:8000", \
|
||||||
|
"--workers", "1", \
|
||||||
|
"--timeout", "0"]
|
||||||
17
services/ws-gateway/README.md
Normal file
17
services/ws-gateway/README.md
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# WS Gateway
|
||||||
|
|
||||||
|
Stateless WebSocket proxy. Accepts Electron connections, authenticates JWT,
|
||||||
|
routes frames to Chat/Batch services via Redis pub/sub.
|
||||||
|
|
||||||
|
## No business logic
|
||||||
|
This service does NOT know what tasks, notes, or agents are.
|
||||||
|
It only routes JSON frames between Electron and downstream services.
|
||||||
|
|
||||||
|
## Scaling
|
||||||
|
Sticky sessions on `user_id` (Traefik consistent hashing).
|
||||||
|
|
||||||
|
## Redis channels used
|
||||||
|
- Subscribe: `ws:out:{user_id}` (frames to send to client)
|
||||||
|
- Publish: `chat:request:{user_id}`, `batch:request:{user_id}`
|
||||||
|
- LPUSH: `tool:result:{call_id}` (from client tool_result frames)
|
||||||
|
- HSET/HDEL: `ws:devices:{user_id}` (device registry)
|
||||||
0
services/ws-gateway/app/__init__.py
Normal file
0
services/ws-gateway/app/__init__.py
Normal file
173
services/ws-gateway/app/handler.py
Normal file
173
services/ws-gateway/app/handler.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
"""WebSocket handler — device connection lifecycle.
|
||||||
|
|
||||||
|
Accepts Electron WS connections, authenticates JWT, registers device in Redis,
|
||||||
|
and runs two concurrent loops:
|
||||||
|
1. Message loop: receive frames from Electron, route to Redis
|
||||||
|
2. Outbound loop: subscribe to Redis ws:out:{user_id}, forward to Electron
|
||||||
|
3. Heartbeat loop: ping every 30s
|
||||||
|
|
||||||
|
No business logic lives here — the handler is a JSON frame router.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.schemas import WsFrameType
|
||||||
|
|
||||||
|
from app.redis_bridge import (
|
||||||
|
publish_batch_request,
|
||||||
|
publish_chat_request,
|
||||||
|
push_tool_result,
|
||||||
|
register_device,
|
||||||
|
set_gateway_id,
|
||||||
|
subscribe_outbound,
|
||||||
|
unregister_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/ws", tags=["ws-gateway"])
|
||||||
|
|
||||||
|
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||||
|
|
||||||
|
# Set a unique gateway instance ID on module load
|
||||||
|
set_gateway_id(str(uuid4()))
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/device")
|
||||||
|
async def device_ws(websocket: WebSocket) -> None:
|
||||||
|
"""Persistent WebSocket endpoint for Electron device connections."""
|
||||||
|
|
||||||
|
# ── 1. Authenticate via ?token= query parameter ──────────────────
|
||||||
|
token = websocket.query_params.get("token", "")
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
settings.JWT_PUBLIC_KEY,
|
||||||
|
algorithms=["RS256"],
|
||||||
|
)
|
||||||
|
user_id: str | None = payload.get("sub")
|
||||||
|
email: str | None = payload.get("email")
|
||||||
|
if not user_id:
|
||||||
|
raise JWTError("missing sub")
|
||||||
|
except JWTError:
|
||||||
|
await websocket.close(code=1008)
|
||||||
|
return
|
||||||
|
|
||||||
|
await websocket.accept()
|
||||||
|
|
||||||
|
# ── 2. Await device_hello frame ──────────────────────────────────
|
||||||
|
try:
|
||||||
|
raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
|
||||||
|
except (asyncio.TimeoutError, WebSocketDisconnect):
|
||||||
|
await websocket.close(code=1008)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
hello = json.loads(raw)
|
||||||
|
if hello.get("type") != WsFrameType.device_hello:
|
||||||
|
raise ValueError("expected device_hello as first frame")
|
||||||
|
device_id: str = hello["device_id"]
|
||||||
|
agent_ids: list[str] = hello.get("agent_ids", [])
|
||||||
|
except (KeyError, ValueError, json.JSONDecodeError) as exc:
|
||||||
|
logger.warning("handler: invalid device_hello user=%s: %s", user_id, exc)
|
||||||
|
await websocket.close(code=1008)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 3. Register device in Redis ──────────────────────────────────
|
||||||
|
await register_device(user_id, device_id)
|
||||||
|
logger.info("handler: connected user=%s device=%s agents=%s", user_id, device_id, agent_ids)
|
||||||
|
|
||||||
|
# Notify downstream services that device is online (for agent trigger)
|
||||||
|
await publish_batch_request(user_id, {
|
||||||
|
"type": "device_online",
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
"agent_ids": agent_ids,
|
||||||
|
})
|
||||||
|
|
||||||
|
# ── 4. Subscribe to outbound Redis channel ───────────────────────
|
||||||
|
pubsub = await subscribe_outbound(user_id)
|
||||||
|
|
||||||
|
# ── 5. Run concurrent loops ──────────────────────────────────────
|
||||||
|
try:
|
||||||
|
await asyncio.gather(
|
||||||
|
_inbound_loop(websocket, user_id),
|
||||||
|
_outbound_loop(websocket, pubsub),
|
||||||
|
_heartbeat_loop(websocket),
|
||||||
|
)
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
pass
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("handler: unhandled exception user=%s: %s", user_id, exc)
|
||||||
|
finally:
|
||||||
|
await pubsub.unsubscribe()
|
||||||
|
await pubsub.aclose()
|
||||||
|
await unregister_device(user_id)
|
||||||
|
logger.info("handler: disconnected user=%s device=%s", user_id, device_id)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Inbound: Electron → Redis ────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _inbound_loop(websocket: WebSocket, user_id: str) -> None:
|
||||||
|
"""Receive frames from Electron and route to the appropriate Redis channel."""
|
||||||
|
async for raw in websocket.iter_text():
|
||||||
|
try:
|
||||||
|
frame: dict = json.loads(raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("handler: invalid JSON from user=%s", user_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
frame_type = frame.get("type")
|
||||||
|
|
||||||
|
# Inject user_id so downstream services know who sent it
|
||||||
|
frame["user_id"] = user_id
|
||||||
|
|
||||||
|
if frame_type == WsFrameType.tool_result:
|
||||||
|
call_id = frame.get("id")
|
||||||
|
if call_id:
|
||||||
|
await push_tool_result(call_id, frame)
|
||||||
|
else:
|
||||||
|
logger.warning("handler: tool_result missing id user=%s", user_id)
|
||||||
|
|
||||||
|
elif frame_type in (WsFrameType.home_request, WsFrameType.floating_request):
|
||||||
|
await publish_chat_request(user_id, frame)
|
||||||
|
|
||||||
|
elif frame_type in (WsFrameType.journey_start, WsFrameType.journey_message):
|
||||||
|
await publish_batch_request(user_id, frame)
|
||||||
|
|
||||||
|
elif frame_type == "pong":
|
||||||
|
pass # heartbeat ack
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.debug("handler: unknown frame type %r user=%s", frame_type, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Outbound: Redis → Electron ───────────────────────────────────────
|
||||||
|
|
||||||
|
async def _outbound_loop(websocket: WebSocket, pubsub) -> None:
|
||||||
|
"""Subscribe to Redis ws:out:{user_id} and forward frames to Electron."""
|
||||||
|
while True:
|
||||||
|
message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
|
||||||
|
if message is not None and message["type"] == "message":
|
||||||
|
await websocket.send_text(message["data"])
|
||||||
|
else:
|
||||||
|
# Brief sleep to avoid busy-wait when no messages
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Heartbeat ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
||||||
|
"""Send ping frames every 30s to keep the connection alive."""
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(_HEARTBEAT_INTERVAL)
|
||||||
|
await websocket.send_text(json.dumps({"type": "ping"}))
|
||||||
56
services/ws-gateway/app/main.py
Normal file
56
services/ws-gateway/app/main.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""WS Gateway — stateless WebSocket proxy.
|
||||||
|
|
||||||
|
Accepts Electron device connections, authenticates JWT (RS256 public key),
|
||||||
|
and routes frames between Electron and downstream services via Redis pub/sub.
|
||||||
|
|
||||||
|
This service has NO business logic — it only routes JSON frames.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Ensure the repo root is on sys.path so "shared" is importable in local dev.
|
||||||
|
_repo_root = str(Path(__file__).resolve().parents[3])
|
||||||
|
if _repo_root not in sys.path:
|
||||||
|
sys.path.insert(0, _repo_root)
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
yield
|
||||||
|
from shared.redis import redis_client
|
||||||
|
|
||||||
|
await redis_client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
app = FastAPI(
|
||||||
|
title="Adiuva WS Gateway",
|
||||||
|
version="0.1.0",
|
||||||
|
docs_url="/docs" if settings.ENV == "dev" else None,
|
||||||
|
redoc_url=None,
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.handler import router
|
||||||
|
|
||||||
|
app.include_router(router, prefix="/api/v1")
|
||||||
|
|
||||||
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
|
async def health() -> dict:
|
||||||
|
return {"status": "ok", "service": "ws-gateway", "version": app.version}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
104
services/ws-gateway/app/redis_bridge.py
Normal file
104
services/ws-gateway/app/redis_bridge.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
"""Redis bridge — device registry + pub/sub routing.
|
||||||
|
|
||||||
|
All inter-service communication passes through Redis:
|
||||||
|
- Device registry: HSET/HDEL ws:devices:{user_id}
|
||||||
|
- Outbound frames: Subscribe ws:out:{user_id}
|
||||||
|
- Chat requests: Publish chat:request:{user_id}
|
||||||
|
- Batch requests: Publish batch:request:{user_id}
|
||||||
|
- Tool results: LPUSH tool:result:{call_id}
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from shared.redis import (
|
||||||
|
batch_request_channel,
|
||||||
|
chat_request_channel,
|
||||||
|
device_key,
|
||||||
|
redis_client,
|
||||||
|
tool_result_key,
|
||||||
|
ws_out_channel,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Instance ID for this gateway replica (set on startup)
|
||||||
|
_GATEWAY_ID: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
def set_gateway_id(gid: str) -> None:
|
||||||
|
global _GATEWAY_ID
|
||||||
|
_GATEWAY_ID = gid
|
||||||
|
|
||||||
|
|
||||||
|
def get_gateway_id() -> str:
|
||||||
|
return _GATEWAY_ID
|
||||||
|
|
||||||
|
|
||||||
|
# ── Device Registry ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def register_device(user_id: str, device_id: str) -> None:
|
||||||
|
"""Register a connected device in Redis."""
|
||||||
|
key = device_key(user_id)
|
||||||
|
await redis_client.hset(key, mapping={
|
||||||
|
"device_id": device_id,
|
||||||
|
"gateway_id": _GATEWAY_ID,
|
||||||
|
})
|
||||||
|
logger.info("redis_bridge: registered user=%s device=%s gateway=%s", user_id, device_id, _GATEWAY_ID)
|
||||||
|
|
||||||
|
|
||||||
|
async def unregister_device(user_id: str) -> None:
|
||||||
|
"""Remove device registration from Redis."""
|
||||||
|
key = device_key(user_id)
|
||||||
|
await redis_client.delete(key)
|
||||||
|
logger.info("redis_bridge: unregistered user=%s", user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def is_device_online(user_id: str) -> bool:
|
||||||
|
"""Check if a device is registered."""
|
||||||
|
key = device_key(user_id)
|
||||||
|
return await redis_client.exists(key) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── Frame Routing ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def publish_chat_request(user_id: str, frame: dict) -> None:
|
||||||
|
"""Forward a chat request frame to the Chat Service via Redis."""
|
||||||
|
channel = chat_request_channel(user_id)
|
||||||
|
await redis_client.publish(channel, json.dumps(frame))
|
||||||
|
logger.debug("redis_bridge: published chat_request user=%s", user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def publish_batch_request(user_id: str, frame: dict) -> None:
|
||||||
|
"""Forward a batch request frame to the Batch Agent Service via Redis."""
|
||||||
|
channel = batch_request_channel(user_id)
|
||||||
|
await redis_client.publish(channel, json.dumps(frame))
|
||||||
|
logger.debug("redis_bridge: published batch_request user=%s", user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def push_tool_result(call_id: str, result: dict) -> None:
|
||||||
|
"""Push a tool_result to the Redis list for the waiting service.
|
||||||
|
|
||||||
|
Chat/Batch services do BRPOP on this key with a 30s timeout.
|
||||||
|
"""
|
||||||
|
key = tool_result_key(call_id)
|
||||||
|
await redis_client.lpush(key, json.dumps(result))
|
||||||
|
# Auto-expire after 60s to prevent stale keys
|
||||||
|
await redis_client.expire(key, 60)
|
||||||
|
logger.debug("redis_bridge: pushed tool_result call_id=%s", call_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def subscribe_outbound(user_id: str):
|
||||||
|
"""Return an async pubsub subscription for frames to send to Electron.
|
||||||
|
|
||||||
|
Chat/Batch services publish to ws:out:{user_id} and this gateway
|
||||||
|
forwards them to the connected WebSocket.
|
||||||
|
"""
|
||||||
|
channel = ws_out_channel(user_id)
|
||||||
|
pubsub = redis_client.pubsub()
|
||||||
|
await pubsub.subscribe(channel)
|
||||||
|
return pubsub
|
||||||
8
services/ws-gateway/requirements.txt
Normal file
8
services/ws-gateway/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
fastapi>=0.115.0
|
||||||
|
uvicorn[standard]>=0.34.0
|
||||||
|
gunicorn>=22.0.0
|
||||||
|
pydantic>=2.10.0
|
||||||
|
pydantic-settings>=2.7.0
|
||||||
|
python-jose[cryptography]>=3.3.0
|
||||||
|
redis>=5.0.0
|
||||||
|
websockets>=14.0
|
||||||
5
shared/__init__.py
Normal file
5
shared/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Shared module — imported by all microservices.
|
||||||
|
|
||||||
|
Contains DB engine/session, ORM models, Pydantic schemas, config,
|
||||||
|
and Redis utilities. Changes here affect ALL services.
|
||||||
|
"""
|
||||||
98
shared/config.py
Normal file
98
shared/config.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""Shared configuration — Pydantic Settings loaded from environment.
|
||||||
|
|
||||||
|
All services import ``settings`` from here. Each service only uses a subset
|
||||||
|
of the vars, but keeping one Settings class avoids fragmentation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import field_validator
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
# Locate the repo root (adiuva-api/) so we can load its .env as a fallback.
|
||||||
|
# Works whether cwd is adiuva-api/ (monolith) or adiuva-api/services/xyz/ (microservice).
|
||||||
|
_this_dir = Path(__file__).resolve().parent # shared/
|
||||||
|
_repo_root = _this_dir.parent # adiuva-api/
|
||||||
|
_root_env = _repo_root / ".env"
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
# ── Database ─────────────────────────────────────────────────────
|
||||||
|
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva"
|
||||||
|
|
||||||
|
# ── JWT ────────────────────────────────────────────────────────
|
||||||
|
# RS256 public key (PEM). Used by any service that needs to verify
|
||||||
|
# JWTs locally (optional — Traefik ForwardAuth handles this in prod).
|
||||||
|
# The private key lives ONLY in the Auth Service config.
|
||||||
|
JWT_PUBLIC_KEY: str = ""
|
||||||
|
|
||||||
|
@field_validator("JWT_PUBLIC_KEY", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _expand_pem_newlines(cls, v: str) -> str:
|
||||||
|
if isinstance(v, str) and r"\n" in v:
|
||||||
|
return v.replace(r"\n", "\n")
|
||||||
|
return v
|
||||||
|
|
||||||
|
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||||
|
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 30
|
||||||
|
|
||||||
|
# ── Redis ────────────────────────────────────────────────────────
|
||||||
|
REDIS_URL: str = "redis://localhost:6379/0"
|
||||||
|
|
||||||
|
# ── Stripe ───────────────────────────────────────────────────────
|
||||||
|
STRIPE_SECRET_KEY: str = ""
|
||||||
|
STRIPE_WEBHOOK_SECRET: str = ""
|
||||||
|
|
||||||
|
# ── S3 ───────────────────────────────────────────────────────────
|
||||||
|
S3_BUCKET: str = ""
|
||||||
|
S3_REGION: str = "us-east-1"
|
||||||
|
S3_ENDPOINT_URL: str = ""
|
||||||
|
AWS_ACCESS_KEY_ID: str = ""
|
||||||
|
AWS_SECRET_ACCESS_KEY: str = ""
|
||||||
|
|
||||||
|
# ── Vector stores ────────────────────────────────────────────────
|
||||||
|
PINECONE_API_KEY: str = ""
|
||||||
|
PINECONE_INDEX: str = "adiuva"
|
||||||
|
QDRANT_URL: str = ""
|
||||||
|
QDRANT_API_KEY: str = ""
|
||||||
|
|
||||||
|
# ── LLM providers ────────────────────────────────────────────────
|
||||||
|
OPENAI_API_KEY: str = ""
|
||||||
|
ANTHROPIC_API_KEY: str = ""
|
||||||
|
GOOGLE_API_KEY: str = ""
|
||||||
|
CEREBRAS_API_KEY: str = ""
|
||||||
|
|
||||||
|
LLM_MODEL: str = "gpt-4o"
|
||||||
|
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
||||||
|
|
||||||
|
GITHUB_COPILOT_TOKEN_DIR: str = ""
|
||||||
|
|
||||||
|
# ── OAuth (integrations) ─────────────────────────────────────────
|
||||||
|
GMAIL_CLIENT_ID: str = ""
|
||||||
|
GMAIL_CLIENT_SECRET: str = ""
|
||||||
|
MS_CLIENT_ID: str = ""
|
||||||
|
MS_CLIENT_SECRET: str = ""
|
||||||
|
MS_TENANT_ID: str = "common"
|
||||||
|
OAUTH_ENCRYPTION_KEY: str = ""
|
||||||
|
|
||||||
|
# ── Langfuse (observability) ─────────────────────────────────────
|
||||||
|
LANGFUSE_SECRET_KEY: str = ""
|
||||||
|
LANGFUSE_PUBLIC_KEY: str = ""
|
||||||
|
LANGFUSE_HOST: str = "https://cloud.langfuse.com"
|
||||||
|
|
||||||
|
# ── CORS ─────────────────────────────────────────────────────────
|
||||||
|
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
||||||
|
|
||||||
|
# ── Environment ──────────────────────────────────────────────────
|
||||||
|
ENV: Literal["dev", "prod"] = "dev"
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
# Local .env (cwd) takes priority; root .env is fallback.
|
||||||
|
env_file=(".env", str(_root_env)),
|
||||||
|
env_file_encoding="utf-8",
|
||||||
|
extra="ignore",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
32
shared/db.py
Normal file
32
shared/db.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Database engine, session factory, and declarative base.
|
||||||
|
|
||||||
|
All services use the async SQLAlchemy API via ``get_session()``.
|
||||||
|
Alembic migrations use the synchronous psycopg2 URL (see alembic/env.py).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
engine = create_async_engine(
|
||||||
|
settings.DATABASE_URL,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
echo=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
"""Shared declarative base for all ORM models."""
|
||||||
|
|
||||||
|
|
||||||
|
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""FastAPI dependency that yields an async DB session per request."""
|
||||||
|
async with async_session() as session:
|
||||||
|
yield session
|
||||||
455
shared/models.py
Normal file
455
shared/models.py
Normal file
@@ -0,0 +1,455 @@
|
|||||||
|
"""SQLAlchemy ORM models for all persistent tables.
|
||||||
|
|
||||||
|
Centralized here so that Alembic migrations and all services share
|
||||||
|
the same model definitions. Each service only queries the tables it owns.
|
||||||
|
|
||||||
|
Ownership:
|
||||||
|
Auth Service → users, refresh_tokens, subscriptions
|
||||||
|
Chat Service → memory_core, memory_associative, memory_episodic, memory_proactive
|
||||||
|
Batch Agent → local_agent_configs, cloud_agent_configs, agent_run_logs
|
||||||
|
Billing Service → subscriptions (shared write with Auth)
|
||||||
|
(excluded MVP) → storage_records, backup_metadata, plugins, plugin_*, revenue_events
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
BigInteger,
|
||||||
|
Boolean,
|
||||||
|
DateTime,
|
||||||
|
Enum,
|
||||||
|
Float,
|
||||||
|
ForeignKey,
|
||||||
|
Integer,
|
||||||
|
JSON,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
UniqueConstraint,
|
||||||
|
Uuid,
|
||||||
|
func,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from shared.db import Base
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _uuid() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def _now() -> datetime:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Enum types ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
|
||||||
|
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
|
||||||
|
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
|
||||||
|
AgentTypeEnum = Enum("local", "cloud", name="agent_type")
|
||||||
|
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
|
||||||
|
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Auth models ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||||
|
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
|
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
|
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||||
|
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
encryption_key: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
refresh_tokens: Mapped[list[RefreshToken]] = relationship(
|
||||||
|
back_populates="user", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
subscription: Mapped[Subscription | None] = relationship(
|
||||||
|
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshToken(Base):
|
||||||
|
__tablename__ = "refresh_tokens"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||||
|
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
||||||
|
|
||||||
|
|
||||||
|
class Subscription(Base):
|
||||||
|
__tablename__ = "subscriptions"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, unique=True, index=True
|
||||||
|
)
|
||||||
|
stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)
|
||||||
|
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||||
|
status: Mapped[str] = mapped_column(String(50), nullable=False, default="free")
|
||||||
|
current_period_end: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
user: Mapped[User] = relationship(back_populates="subscription")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Storage models (excluded from MVP, kept for Alembic) ──────────────
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRecord(Base):
|
||||||
|
__tablename__ = "storage_records"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BackupMetadata(Base):
|
||||||
|
__tablename__ = "backup_metadata"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||||
|
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Plugin models (excluded from MVP, kept for Alembic) ───────────────
|
||||||
|
|
||||||
|
|
||||||
|
class Plugin(Base):
|
||||||
|
__tablename__ = "plugins"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
|
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
|
||||||
|
author_id: Mapped[str | None] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
|
||||||
|
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
|
||||||
|
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]")
|
||||||
|
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
|
||||||
|
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
|
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||||
|
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
submitted_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
installations: Mapped[list[PluginInstallation]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
reviews: Mapped[list[PluginReview]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
revenue_events: Mapped[list[RevenueEvent]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInstallation(Base):
|
||||||
|
__tablename__ = "plugin_installations"
|
||||||
|
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
installed_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="installations")
|
||||||
|
|
||||||
|
|
||||||
|
class PluginReview(Base):
|
||||||
|
__tablename__ = "plugin_reviews"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
reviewer_id: Mapped[str | None] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
|
||||||
|
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
reviewed_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
|
||||||
|
|
||||||
|
|
||||||
|
class RevenueEvent(Base):
|
||||||
|
__tablename__ = "revenue_events"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Agent models ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class LocalAgentConfig(Base):
|
||||||
|
__tablename__ = "local_agent_configs"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
device_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
|
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
|
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
|
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
|
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
||||||
|
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||||
|
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
run_logs: Mapped[list[AgentRunLog]] = relationship(
|
||||||
|
back_populates="local_agent",
|
||||||
|
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
|
||||||
|
foreign_keys="AgentRunLog.agent_id",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
overlaps="run_logs,cloud_agent",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CloudAgentConfig(Base):
|
||||||
|
__tablename__ = "cloud_agent_configs"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
provider: Mapped[str] = mapped_column(CloudProviderEnum, nullable=False)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
|
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
|
oauth_token_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
filter_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||||
|
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
||||||
|
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||||
|
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
run_logs: Mapped[list[AgentRunLog]] = relationship(
|
||||||
|
back_populates="cloud_agent",
|
||||||
|
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
|
||||||
|
foreign_keys="AgentRunLog.agent_id",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
overlaps="run_logs,local_agent",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRunLog(Base):
|
||||||
|
__tablename__ = "agent_run_logs"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
agent_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
|
||||||
|
agent_type: Mapped[str] = mapped_column(AgentTypeEnum, nullable=False)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
|
||||||
|
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||||
|
started_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
|
||||||
|
local_agent: Mapped[LocalAgentConfig | None] = relationship(
|
||||||
|
back_populates="run_logs",
|
||||||
|
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
|
||||||
|
foreign_keys="AgentRunLog.agent_id",
|
||||||
|
overlaps="run_logs,cloud_agent",
|
||||||
|
)
|
||||||
|
cloud_agent: Mapped[CloudAgentConfig | None] = relationship(
|
||||||
|
back_populates="run_logs",
|
||||||
|
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
|
||||||
|
foreign_keys="AgentRunLog.agent_id",
|
||||||
|
overlaps="run_logs,local_agent",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Memory models ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryCore(Base):
|
||||||
|
"""Per-user persistent key/value preferences, encrypted at rest."""
|
||||||
|
|
||||||
|
__tablename__ = "memory_core"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
value_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryAssociative(Base):
|
||||||
|
"""Per-user semantic memory: encrypted content + pgvector embedding."""
|
||||||
|
|
||||||
|
__tablename__ = "memory_associative"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||||
|
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
|
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryEpisodic(Base):
|
||||||
|
"""Per-user session summaries, encrypted at rest."""
|
||||||
|
|
||||||
|
__tablename__ = "memory_episodic"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
summary_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
session_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryProactive(Base):
|
||||||
|
"""Per-user inferred behavioral patterns, encrypted at rest."""
|
||||||
|
|
||||||
|
__tablename__ = "memory_proactive"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
pattern_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.5)
|
||||||
|
source: Mapped[str] = mapped_column(String(50), nullable=False, default="inferred")
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
53
shared/redis.py
Normal file
53
shared/redis.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""Redis client and pub/sub utilities for inter-service communication.
|
||||||
|
|
||||||
|
All services that need Redis import from here.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
redis_client: aioredis.Redis = aioredis.from_url(
|
||||||
|
settings.REDIS_URL,
|
||||||
|
decode_responses=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Channel naming conventions ────────────────────────────────────────
|
||||||
|
# See /memories/repo/microservices-architecture.md for full list.
|
||||||
|
|
||||||
|
def ws_out_channel(user_id: str) -> str:
|
||||||
|
"""Frames to forward to Electron via WS Gateway."""
|
||||||
|
return f"ws:out:{user_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def chat_request_channel(user_id: str) -> str:
|
||||||
|
"""Chat requests (home + floating) from WS Gateway → Chat Service."""
|
||||||
|
return f"chat:request:{user_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def batch_request_channel(user_id: str) -> str:
|
||||||
|
"""Batch requests (journey + triggers) from WS Gateway → Batch Agent."""
|
||||||
|
return f"batch:request:{user_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def tool_result_key(call_id: str) -> str:
|
||||||
|
"""Tool result list: LPUSH by WS Gateway, BRPOP by Chat/Batch."""
|
||||||
|
return f"tool:result:{call_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def device_key(user_id: str) -> str:
|
||||||
|
"""Device registry hash."""
|
||||||
|
return f"ws:devices:{user_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def tier_changed_channel(user_id: str) -> str:
|
||||||
|
"""Billing tier change notifications."""
|
||||||
|
return f"tier:changed:{user_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def journey_session_key(user_id: str) -> str:
|
||||||
|
"""Journey builder session (String + TTL 1800s)."""
|
||||||
|
return f"journey:{user_id}"
|
||||||
317
shared/schemas.py
Normal file
317
shared/schemas.py
Normal file
@@ -0,0 +1,317 @@
|
|||||||
|
"""Pydantic schemas — API request/response contracts.
|
||||||
|
|
||||||
|
Shared across all services. Mirrors the TypeScript types from
|
||||||
|
the Electron app (src/shared/api-types.ts).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
# ── Billing ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
BillingTier = Literal["free", "pro", "power", "team"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Auth ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class AuthTokens(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
refresh_token: str
|
||||||
|
expires_at: int
|
||||||
|
|
||||||
|
|
||||||
|
class UserProfile(BaseModel):
|
||||||
|
id: str
|
||||||
|
email: str
|
||||||
|
name: str | None = None
|
||||||
|
surname: str | None = None
|
||||||
|
tier: BillingTier
|
||||||
|
|
||||||
|
|
||||||
|
# ── Chat ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ChatContext(BaseModel):
|
||||||
|
user_profile: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
relevant_documents: list[str] = Field(default_factory=list)
|
||||||
|
recent_tasks: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
message: str
|
||||||
|
context: ChatContext = Field(default_factory=ChatContext)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatResponse(BaseModel):
|
||||||
|
response: str
|
||||||
|
|
||||||
|
|
||||||
|
# ── Backup ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class BackupMetadata(BaseModel):
|
||||||
|
version: int
|
||||||
|
timestamp: int
|
||||||
|
checksum: str
|
||||||
|
chunk_count: int
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud Storage (E2E encrypted blobs) ──────────────────────────────
|
||||||
|
|
||||||
|
class StorageRecord(BaseModel):
|
||||||
|
id: str
|
||||||
|
user_id: str
|
||||||
|
table: str
|
||||||
|
blob: bytes
|
||||||
|
checksum: str
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRecordCreate(BaseModel):
|
||||||
|
table: str
|
||||||
|
blob: bytes
|
||||||
|
checksum: str
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRecordUpdate(BaseModel):
|
||||||
|
blob: bytes
|
||||||
|
checksum: str
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud Vector Store (E2E encrypted vectors) ────────────────────────
|
||||||
|
|
||||||
|
class VectorItem(BaseModel):
|
||||||
|
id: str
|
||||||
|
blob: bytes
|
||||||
|
checksum: str
|
||||||
|
|
||||||
|
|
||||||
|
class VectorUpsertRequest(BaseModel):
|
||||||
|
vectors: list[VectorItem]
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSearchRequest(BaseModel):
|
||||||
|
query_blob: bytes
|
||||||
|
top_k: int = 10
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSearchResult(BaseModel):
|
||||||
|
id: str
|
||||||
|
score: float
|
||||||
|
blob: bytes
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSearchResponse(BaseModel):
|
||||||
|
results: list[VectorSearchResult]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Plugin Marketplace ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class PluginManifest(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
version: str
|
||||||
|
author: str
|
||||||
|
permissions: list[str]
|
||||||
|
category: str
|
||||||
|
price_cents: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class PluginListResponse(BaseModel):
|
||||||
|
plugins: list[PluginManifest]
|
||||||
|
total: int
|
||||||
|
page: int
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInstallRequest(BaseModel):
|
||||||
|
plugin_id: str
|
||||||
|
|
||||||
|
|
||||||
|
# ── WebSocket Frame Protocol ──────────────────────────────────────────
|
||||||
|
|
||||||
|
class WsFrameType(str, Enum):
|
||||||
|
# ── v2 frame types (kept for backward compat) ──────────────────────
|
||||||
|
chat_request = "chat_request"
|
||||||
|
text_chunk = "text_chunk"
|
||||||
|
tool_call = "tool_call"
|
||||||
|
tool_result = "tool_result"
|
||||||
|
final = "final"
|
||||||
|
ping = "ping"
|
||||||
|
device_hello = "device_hello"
|
||||||
|
# ── v3 frame types ─────────────────────────────────────────────────
|
||||||
|
home_request = "home_request"
|
||||||
|
floating_request = "floating_request"
|
||||||
|
stream_start = "stream_start"
|
||||||
|
stream_text = "stream_text"
|
||||||
|
stream_end = "stream_end"
|
||||||
|
floating_domain = "floating_domain"
|
||||||
|
data_request = "data_request"
|
||||||
|
data_response = "data_response"
|
||||||
|
mutation = "mutation"
|
||||||
|
# ── v4 journey frame types ────────────────────────────────────────
|
||||||
|
journey_start = "journey_start"
|
||||||
|
journey_message = "journey_message"
|
||||||
|
journey_reply = "journey_reply"
|
||||||
|
|
||||||
|
|
||||||
|
class WsToolCall(BaseModel):
|
||||||
|
"""Server → Client: requests a CRUD/vector operation on the local DB."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.tool_call] = WsFrameType.tool_call
|
||||||
|
id: str
|
||||||
|
action: str
|
||||||
|
table: str | None = None
|
||||||
|
data: dict[str, Any] | None = None
|
||||||
|
filters: dict[str, Any] | None = None
|
||||||
|
vector: list[float] | None = None
|
||||||
|
limit: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WsToolResult(BaseModel):
|
||||||
|
"""Client → Server: result of a CRUD/vector operation."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.tool_result] = WsFrameType.tool_result
|
||||||
|
id: str
|
||||||
|
row: dict[str, Any] | None = None
|
||||||
|
rows: list[dict[str, Any]] | None = None
|
||||||
|
results: list[dict[str, Any]] | None = None
|
||||||
|
deleted: bool | None = None
|
||||||
|
ok: bool | None = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WsTextChunk(BaseModel):
|
||||||
|
"""Server → Client: incremental LLM response text."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.text_chunk] = WsFrameType.text_chunk
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class WsFinal(BaseModel):
|
||||||
|
"""Server → Client: signals end of response with the complete text."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.final] = WsFrameType.final
|
||||||
|
response: str
|
||||||
|
|
||||||
|
|
||||||
|
# ── WebSocket Agent Frame Protocol ────────────────────────────────────
|
||||||
|
|
||||||
|
class WsDeviceHello(BaseModel):
|
||||||
|
"""Client → Server: device identification on WS connect."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.device_hello] = WsFrameType.device_hello
|
||||||
|
device_id: str
|
||||||
|
agent_ids: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
||||||
|
|
||||||
|
class WsFloatingScope(BaseModel):
|
||||||
|
"""Scope for a floating request."""
|
||||||
|
|
||||||
|
type: Literal["task", "project", "note", "timeline"]
|
||||||
|
id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WsHomeRequest(BaseModel):
|
||||||
|
"""Client → Server: Home chat message."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.home_request] = WsFrameType.home_request
|
||||||
|
message: str
|
||||||
|
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class WsFloatingRequest(BaseModel):
|
||||||
|
"""Client → Server: Floating chat message scoped to an entity."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
|
||||||
|
message: str
|
||||||
|
scope: WsFloatingScope
|
||||||
|
|
||||||
|
|
||||||
|
class WsStreamStart(BaseModel):
|
||||||
|
"""Server → Client: signals start of a streaming response."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.stream_start] = WsFrameType.stream_start
|
||||||
|
request_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class WsStreamText(BaseModel):
|
||||||
|
"""Server → Client: streamed text token."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.stream_text] = WsFrameType.stream_text
|
||||||
|
request_id: str
|
||||||
|
chunk: str
|
||||||
|
|
||||||
|
|
||||||
|
class WsStreamEnd(BaseModel):
|
||||||
|
"""Server → Client: signals end of a streaming response."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
||||||
|
request_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class WsDomain(BaseModel):
|
||||||
|
"""Structured floating domain payload for UI routing decisions."""
|
||||||
|
|
||||||
|
type: Literal["task", "timeline", "project", "node"]
|
||||||
|
id: str | None = None
|
||||||
|
section: Literal["task", "timeline", "note"] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WsFloatingDomain(BaseModel):
|
||||||
|
"""Server → Client: domain determined for a floating request."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
||||||
|
request_id: str
|
||||||
|
domain: WsDomain
|
||||||
|
|
||||||
|
|
||||||
|
# ── Agent Catalog ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class AgentCatalogItem(BaseModel):
|
||||||
|
type: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
|
||||||
|
class AgentCreationCheckRequest(BaseModel):
|
||||||
|
active_agents: int = Field(ge=0, default=0)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentCreationCheckResponse(BaseModel):
|
||||||
|
allowed: bool
|
||||||
|
tier: BillingTier
|
||||||
|
active_agents: int
|
||||||
|
limit: int
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTriggerRequest(BaseModel):
|
||||||
|
directory: str = Field(min_length=1)
|
||||||
|
device_id: str = Field(default="")
|
||||||
|
agent_id: str | None = None
|
||||||
|
what_to_extract: list[str] = Field(min_length=1)
|
||||||
|
actions_by_type: dict[str, list[str]] | None = None
|
||||||
|
batch_interval: str = Field(min_length=1)
|
||||||
|
custom_agent_prompt: str = Field(min_length=1)
|
||||||
|
active_agents: int = Field(ge=0, default=0)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Agent Run Log ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class AgentRunLogResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
agent_id: str
|
||||||
|
agent_type: Literal["local", "cloud"]
|
||||||
|
status: Literal["running", "success", "error", "partial"]
|
||||||
|
items_processed: int
|
||||||
|
items_created: int
|
||||||
|
errors: list[str]
|
||||||
|
started_at: int
|
||||||
|
completed_at: int | None
|
||||||
@@ -10,13 +10,13 @@ Coverage:
|
|||||||
- run_local_agent — file-read timeout path
|
- run_local_agent — file-read timeout path
|
||||||
- run_local_agent — LLM extraction error path
|
- run_local_agent — LLM extraction error path
|
||||||
- run_cloud_agent — stub returns error immediately
|
- run_cloud_agent — stub returns error immediately
|
||||||
- trigger_pending_runs — overdue local + cloud dispatched
|
- trigger_pending_runs — skipped when config is client-owned
|
||||||
- trigger_pending_runs — non-overdue skipped
|
- trigger_pending_runs — non-overdue skipped
|
||||||
- trigger_pending_runs — device_id filter for local agents
|
- trigger_pending_runs — device_id filter for local agents
|
||||||
|
|
||||||
Integration:
|
Integration:
|
||||||
- POST /agents/{id}/run — 404 on unknown agent
|
- POST /agents/can-create — billing eligibility check
|
||||||
- POST /agents/{id}/run — creates run log + dispatches background task
|
- POST /agents/trigger — creates run log + dispatches background task
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -373,7 +373,7 @@ async def test_run_local_agent_happy_path():
|
|||||||
assert kwargs["items_processed"] == 1
|
assert kwargs["items_processed"] == 1
|
||||||
assert kwargs["items_created"] == 1
|
assert kwargs["items_created"] == 1
|
||||||
assert kwargs["errors"] == []
|
assert kwargs["errors"] == []
|
||||||
assert kwargs["update_config_last_run"] is True
|
assert kwargs["update_config_last_run"] is False
|
||||||
|
|
||||||
# Verify agent_run frame was sent.
|
# Verify agent_run frame was sent.
|
||||||
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
|
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
|
||||||
@@ -690,31 +690,11 @@ async def test_finalize_run_updates_cloud_config_last_run_at():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_pending_runs_no_overdue():
|
async def test_trigger_pending_runs_no_overdue():
|
||||||
"""If no agents are overdue trigger_pending_runs does nothing."""
|
"""Pending-run scan is skipped because agent config is client-owned."""
|
||||||
from datetime import timedelta
|
|
||||||
|
|
||||||
config = _make_local_config()
|
|
||||||
config.last_run_at = datetime.now(timezone.utc) - timedelta(minutes=30) # ran 30m ago
|
|
||||||
config.schedule_cron = "0 */6 * * *" # every 6h — not due yet
|
|
||||||
|
|
||||||
mock_db_result_local = MagicMock()
|
|
||||||
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
|
||||||
|
|
||||||
mock_db_result_cloud = MagicMock()
|
|
||||||
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
|
||||||
|
|
||||||
mgr = _make_manager()
|
mgr = _make_manager()
|
||||||
|
|
||||||
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
|
||||||
mock_ctx = AsyncMock()
|
|
||||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
|
||||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_ctx.execute = AsyncMock(
|
|
||||||
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
|
||||||
)
|
|
||||||
mock_session_factory.return_value = mock_ctx
|
|
||||||
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
mock_run.assert_not_called()
|
mock_run.assert_not_called()
|
||||||
@@ -722,31 +702,11 @@ async def test_trigger_pending_runs_no_overdue():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_pending_runs_device_id_filter():
|
async def test_trigger_pending_runs_device_id_filter():
|
||||||
"""Local agents are only triggered for the matching device_id."""
|
"""Device filtering is no longer backend-managed in pending runs."""
|
||||||
# The DB query already filters by device_id, so we verify the SELECT
|
|
||||||
# includes the device_id filter by checking that a config bound to a
|
|
||||||
# different device is never dispatched.
|
|
||||||
#
|
|
||||||
# Since trigger_pending_runs queries with device_id == "dev-001",
|
|
||||||
# simulate the DB returning an empty list (as it would for a mismatch).
|
|
||||||
mock_db_result_local = MagicMock()
|
|
||||||
mock_db_result_local.scalars.return_value.all.return_value = [] # no match
|
|
||||||
|
|
||||||
mock_db_result_cloud = MagicMock()
|
|
||||||
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
|
||||||
|
|
||||||
mgr = _make_manager(device_id="dev-001")
|
mgr = _make_manager(device_id="dev-001")
|
||||||
|
|
||||||
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
|
||||||
mock_ctx = AsyncMock()
|
|
||||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
|
||||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_ctx.execute = AsyncMock(
|
|
||||||
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
|
||||||
)
|
|
||||||
mock_session_factory.return_value = mock_ctx
|
|
||||||
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
mock_run.assert_not_called()
|
mock_run.assert_not_called()
|
||||||
@@ -754,56 +714,18 @@ async def test_trigger_pending_runs_device_id_filter():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_pending_runs_dispatches_overdue():
|
async def test_trigger_pending_runs_dispatches_overdue():
|
||||||
"""Overdue local agent triggers run_local_agent sequentially."""
|
"""No pending runs are dispatched by backend after config deprecation."""
|
||||||
config = _make_local_config() # last_run_at=None → always overdue
|
|
||||||
|
|
||||||
mock_db_result_local = MagicMock()
|
|
||||||
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
|
||||||
|
|
||||||
mock_db_result_cloud = MagicMock()
|
|
||||||
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
|
||||||
|
|
||||||
mgr = _make_manager()
|
mgr = _make_manager()
|
||||||
|
|
||||||
call_order: list[str] = []
|
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
|
|
||||||
async def _mock_run_local(user_id, cfg, run_log, device_mgr):
|
|
||||||
call_order.append("run_local")
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
|
||||||
patch("app.core.agent_runner.run_local_agent", side_effect=_mock_run_local):
|
|
||||||
# First call: query configs. Subsequent calls: create run_log.
|
|
||||||
mock_query_ctx = AsyncMock()
|
|
||||||
mock_query_ctx.__aenter__ = AsyncMock(return_value=mock_query_ctx)
|
|
||||||
mock_query_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_query_ctx.execute = AsyncMock(
|
|
||||||
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
|
||||||
)
|
|
||||||
|
|
||||||
run_log_obj = AgentRunLog(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
agent_id=config.id,
|
|
||||||
agent_type="local",
|
|
||||||
user_id=_FREE_UID,
|
|
||||||
status="running",
|
|
||||||
started_at=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
mock_insert_ctx = AsyncMock()
|
|
||||||
mock_insert_ctx.__aenter__ = AsyncMock(return_value=mock_insert_ctx)
|
|
||||||
mock_insert_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_insert_ctx.add = MagicMock()
|
|
||||||
mock_insert_ctx.commit = AsyncMock()
|
|
||||||
mock_insert_ctx.refresh = AsyncMock(side_effect=lambda obj: None)
|
|
||||||
|
|
||||||
mock_session_factory.side_effect = [mock_query_ctx, mock_insert_ctx]
|
|
||||||
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
assert call_order == ["run_local"]
|
mock_run.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Integration: POST /agents/{id}/run
|
# Integration: POST /agents/can-create and /agents/trigger
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@@ -820,50 +742,67 @@ def _override_db(db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_run_unknown_agent(client):
|
async def test_can_create_agent_allows_when_under_limit(client):
|
||||||
"""POST /agents/{id}/run returns 404 for unknown agent id."""
|
"""POST /agents/can-create returns allowed=True when under tier limit."""
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
f"/api/v1/agents/{uuid.uuid4()}/run",
|
"/api/v1/agents/can-create",
|
||||||
headers=auth_header("power"),
|
json={"active_agents": 0},
|
||||||
|
headers=auth_header("free"),
|
||||||
)
|
)
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["allowed"] is True
|
||||||
|
assert body["tier"] == "free"
|
||||||
|
assert body["active_agents"] == 0
|
||||||
|
assert body["limit"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_can_create_agent_denies_when_at_limit(client):
|
||||||
|
"""POST /agents/can-create returns allowed=False at free-tier limit."""
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/agents/can-create",
|
||||||
|
json={"active_agents": 2},
|
||||||
|
headers=auth_header("free"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["allowed"] is False
|
||||||
|
assert body["limit"] == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
||||||
"""POST /agents/{id}/run creates a run log and dispatches a background task."""
|
"""POST /agents/trigger creates a local run log and dispatches background task."""
|
||||||
# Create the local agent config in the DB.
|
dispatched: list[tuple[str, str]] = []
|
||||||
config = LocalAgentConfig(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=TEST_USER_IDS["power"],
|
|
||||||
device_id="dev-001",
|
|
||||||
name="My Agent",
|
|
||||||
directory_paths=["/home/user/docs"],
|
|
||||||
data_types=["tasks"],
|
|
||||||
prompt_template="Extract tasks.",
|
|
||||||
file_extensions=[".txt"],
|
|
||||||
schedule_cron="0 */6 * * *",
|
|
||||||
enabled=True,
|
|
||||||
)
|
|
||||||
db_session.add(config)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
dispatched: list = []
|
|
||||||
|
|
||||||
async def _fake_run(user_id, cfg, run_log, device_mgr):
|
async def _fake_run(user_id, cfg, run_log, device_mgr):
|
||||||
dispatched.append((user_id, cfg.id))
|
dispatched.append((user_id, cfg.id))
|
||||||
|
|
||||||
|
def _fake_create_task(coro):
|
||||||
|
coro.close()
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
|
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
|
||||||
patch("app.api.routes.agents.run_cloud_agent", new_callable=AsyncMock), \
|
|
||||||
patch("asyncio.create_task") as mock_create_task:
|
patch("asyncio.create_task") as mock_create_task:
|
||||||
|
mock_create_task.side_effect = _fake_create_task
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
f"/api/v1/agents/{config.id}/run",
|
"/api/v1/agents/trigger",
|
||||||
|
json={
|
||||||
|
"directory": "/home/user/docs",
|
||||||
|
"what_to_extract": ["task", "note"],
|
||||||
|
"actions_by_type": {"task": ["add", "update"], "note": ["add"]},
|
||||||
|
"batch_interval": "0 */6 * * *",
|
||||||
|
"custom_agent_prompt": "Extract tasks and notes.",
|
||||||
|
"active_agents": 0,
|
||||||
|
},
|
||||||
headers=auth_header("power"),
|
headers=auth_header("power"),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert resp.status_code == 202
|
assert resp.status_code == 202
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
assert data["agent_id"] == config.id
|
assert isinstance(data["agent_id"], str)
|
||||||
|
assert data["agent_id"]
|
||||||
assert data["status"] == "running"
|
assert data["status"] == "running"
|
||||||
assert data["agent_type"] == "local"
|
assert data["agent_type"] == "local"
|
||||||
|
|
||||||
|
|||||||
184
tests/test_classify_file.py
Normal file
184
tests/test_classify_file.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
"""Unit tests for Step 1 file classification (_classify_file).
|
||||||
|
|
||||||
|
These tests call the real LLM so they require OPENAI_API_KEY / LLM env vars.
|
||||||
|
Run with: pytest tests/test_classify_file.py -v
|
||||||
|
|
||||||
|
To run a quick manual check against a real file without the full UI:
|
||||||
|
python -m tests.test_classify_file <path/to/file.txt> [project_name...]
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.agent_runner import _classify_file
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
PROJECTS_SAMPLE = [
|
||||||
|
{
|
||||||
|
"id": "aaaa-0001-0000-0000-000000000001",
|
||||||
|
"name": "ARPA Sicilia POC",
|
||||||
|
"status": "active",
|
||||||
|
"aiSummary": "Proof of concept for AI features targeting ARPA Sicilia agency.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "bbbb-0002-0000-0000-000000000002",
|
||||||
|
"name": "SNAM AI Meeting Prep",
|
||||||
|
"status": "active",
|
||||||
|
"aiSummary": "AI-assisted preparation of meeting materials for SNAM.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "cccc-0003-0000-0000-000000000003",
|
||||||
|
"name": "SFERA+ Wave 2",
|
||||||
|
"status": "active",
|
||||||
|
"aiSummary": "Second wave of the SFERA+ whitelist project.",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
ARPA_EMAIL = """\
|
||||||
|
to: roberto.musso@hpe.com; luca.tondin@hpecds.com
|
||||||
|
isImportance: normal
|
||||||
|
hasAttachment: True
|
||||||
|
---
|
||||||
|
## Body
|
||||||
|
Buongiorno,
|
||||||
|
|
||||||
|
In riferimento alla riunione di ieri sul POC ARPA Sicilia, vi invio il riassunto
|
||||||
|
dei deliverable concordati:
|
||||||
|
- Preparare demo entro il 30 marzo
|
||||||
|
- Condividere documentazione tecnica con il team ARPA
|
||||||
|
- Fissare call di follow-up la prossima settimana
|
||||||
|
|
||||||
|
Cordiali saluti
|
||||||
|
Roberto Marchetti
|
||||||
|
"""
|
||||||
|
|
||||||
|
SNAM_EMAIL = """\
|
||||||
|
to: roberto.musso@hpe.com
|
||||||
|
isImportance: high
|
||||||
|
hasAttachment: False
|
||||||
|
---
|
||||||
|
## Body
|
||||||
|
Ciao,
|
||||||
|
ti invio l'agenda per la riunione SNAM di domani.
|
||||||
|
Per favore conferma la tua presenza.
|
||||||
|
"""
|
||||||
|
|
||||||
|
UNRELATED_EMAIL = """\
|
||||||
|
to: roberto.musso@hpe.com
|
||||||
|
isImportance: normal
|
||||||
|
---
|
||||||
|
## Body
|
||||||
|
Benvenuto nel programma HPE Employee Learning Series.
|
||||||
|
Completa la formazione richiesta entro la fine del trimestre.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_classify_arpa_matches_existing():
|
||||||
|
project_id, domains, new_name = await _classify_file(
|
||||||
|
file_path="arpa_email.txt",
|
||||||
|
file_content=ARPA_EMAIL,
|
||||||
|
projects=PROJECTS_SAMPLE,
|
||||||
|
config_data_types=["tasks", "notes", "timelines"],
|
||||||
|
)
|
||||||
|
assert project_id == "aaaa-0001-0000-0000-000000000001", (
|
||||||
|
f"Expected ARPA project, got project_id={project_id!r} new_name={new_name!r}"
|
||||||
|
)
|
||||||
|
assert new_name is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_classify_snam_matches_existing():
|
||||||
|
project_id, domains, new_name = await _classify_file(
|
||||||
|
file_path="snam_email.txt",
|
||||||
|
file_content=SNAM_EMAIL,
|
||||||
|
projects=PROJECTS_SAMPLE,
|
||||||
|
config_data_types=["tasks", "notes"],
|
||||||
|
)
|
||||||
|
assert project_id == "bbbb-0002-0000-0000-000000000002", (
|
||||||
|
f"Expected SNAM project, got project_id={project_id!r} new_name={new_name!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_classify_unrelated_returns_new():
|
||||||
|
project_id, domains, new_name = await _classify_file(
|
||||||
|
file_path="learning_email.txt",
|
||||||
|
file_content=UNRELATED_EMAIL,
|
||||||
|
projects=PROJECTS_SAMPLE,
|
||||||
|
config_data_types=["tasks", "notes"],
|
||||||
|
)
|
||||||
|
assert project_id == "new"
|
||||||
|
assert new_name is not None # LLM should suggest a name
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_classify_empty_file_returns_new():
|
||||||
|
project_id, domains, new_name = await _classify_file(
|
||||||
|
file_path="empty.txt",
|
||||||
|
file_content=" ",
|
||||||
|
projects=PROJECTS_SAMPLE,
|
||||||
|
config_data_types=["tasks"],
|
||||||
|
)
|
||||||
|
assert project_id == "new"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_classify_no_projects_returns_new():
|
||||||
|
project_id, domains, new_name = await _classify_file(
|
||||||
|
file_path="arpa_email.txt",
|
||||||
|
file_content=ARPA_EMAIL,
|
||||||
|
projects=[],
|
||||||
|
config_data_types=["tasks", "notes"],
|
||||||
|
)
|
||||||
|
assert project_id == "new"
|
||||||
|
assert new_name is not None
|
||||||
|
|
||||||
|
|
||||||
|
# ── CLI quick-test runner ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _cli_test(file_path: str, project_names: list[str]) -> None:
|
||||||
|
"""Run Step 1 classification against a real file from the CLI."""
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
content = Path(file_path).read_text(encoding="utf-8", errors="replace")
|
||||||
|
projects = [
|
||||||
|
{"id": f"test-id-{i:04d}", "name": name, "status": "active", "aiSummary": ""}
|
||||||
|
for i, name in enumerate(project_names)
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f"\nClassifying: {file_path}")
|
||||||
|
print(f"Projects in context: {[p['name'] for p in projects]}\n")
|
||||||
|
|
||||||
|
project_id, domains, new_name = await _classify_file(
|
||||||
|
file_path=file_path,
|
||||||
|
file_content=content,
|
||||||
|
projects=projects,
|
||||||
|
config_data_types=["tasks", "notes", "timelines"],
|
||||||
|
)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"project_id": project_id,
|
||||||
|
"matched_name": next((p["name"] for p in projects if p["id"] == project_id), None),
|
||||||
|
"new_project_name": new_name,
|
||||||
|
"domains": domains,
|
||||||
|
}
|
||||||
|
print(json.dumps(result, indent=2, ensure_ascii=False))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
if len(sys.argv) < 2:
|
||||||
|
print("Usage: python -m tests.test_classify_file <file_path> [project_name ...]")
|
||||||
|
sys.exit(1)
|
||||||
|
asyncio.run(_cli_test(sys.argv[1], sys.argv[2:]))
|
||||||
288
tests/test_deep_agent.py
Normal file
288
tests/test_deep_agent.py
Normal file
@@ -0,0 +1,288 @@
|
|||||||
|
"""Unit tests for single-agent deep_agent flows with mocked tool results."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date, timedelta
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
|
||||||
|
from app.core.deep_agent import (
|
||||||
|
_infer_floating_domain,
|
||||||
|
_normalize_tagged_list_lines,
|
||||||
|
run_floating,
|
||||||
|
run_floating_stream,
|
||||||
|
run_home,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeTool:
|
||||||
|
name = "list_tasks"
|
||||||
|
|
||||||
|
async def ainvoke(self, args):
|
||||||
|
return {"rows": [{"id": "task-1", "title": "Mock Task"}], "echo": args}
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeLLM:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.agent_calls = 0
|
||||||
|
|
||||||
|
def bind_tools(self, _tools):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def ainvoke(self, messages):
|
||||||
|
system_prompt = str(getattr(messages[0], "content", "")) if messages else ""
|
||||||
|
if "strict domain classifier" in system_prompt:
|
||||||
|
return AIMessage(content='{"type":"timeline","id":"tl-1","section":null}')
|
||||||
|
|
||||||
|
self.agent_calls += 1
|
||||||
|
if self.agent_calls == 1:
|
||||||
|
return AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"id": "call-1",
|
||||||
|
"name": "list_tasks",
|
||||||
|
"args": {"project_id": "proj-1"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
|
||||||
|
assert tool_messages, "Expected at least one tool message"
|
||||||
|
return AIMessage(content=f"Final answer from mocked tool: {tool_messages[-1].content}")
|
||||||
|
|
||||||
|
async def astream(self, _messages):
|
||||||
|
yield SimpleNamespace(content="stream-")
|
||||||
|
yield SimpleNamespace(content="ok")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_home_uses_mocked_tool_result():
|
||||||
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||||
|
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
||||||
|
):
|
||||||
|
out = await run_home("user-1", "list my tasks", {})
|
||||||
|
|
||||||
|
assert "Final answer from mocked tool" in out
|
||||||
|
assert "Mock Task" in out
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
|
||||||
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||||
|
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
||||||
|
):
|
||||||
|
events = []
|
||||||
|
async for event in run_floating_stream(
|
||||||
|
"user-1",
|
||||||
|
"show me timeline updates",
|
||||||
|
{"scope": {"type": "timeline", "id": "tl-1"}},
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
assert events[0] == (
|
||||||
|
"floating_domain",
|
||||||
|
{"type": "timeline", "id": "tl-1", "section": None},
|
||||||
|
)
|
||||||
|
assert ("token", "stream-") in events
|
||||||
|
assert ("token", "ok") in events
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
|
||||||
|
class _ClassifierOnlyLLM:
|
||||||
|
async def ainvoke(self, _messages):
|
||||||
|
return AIMessage(
|
||||||
|
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=_ClassifierOnlyLLM()):
|
||||||
|
domain = await _infer_floating_domain(
|
||||||
|
"Quali sono i miei task per il progetto X",
|
||||||
|
{
|
||||||
|
"scope": {"type": "timeline"},
|
||||||
|
"resolved_project_id": "213213-312321-312312-421321",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert domain == {
|
||||||
|
"type": "project",
|
||||||
|
"id": "213213-312321-312312-421321",
|
||||||
|
"section": "task",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
|
||||||
|
raw = (
|
||||||
|
"Certo!\n\n"
|
||||||
|
"1. **Task A** — priorita high <task>[task-1]</task>\n"
|
||||||
|
"2. **Task B** — priorita medium <task>[task-2]</task>\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
out = _normalize_tagged_list_lines(raw, "quali sono le prossime attivita?")
|
||||||
|
|
||||||
|
assert "<task>[task-1]</task>" in out
|
||||||
|
assert "<task>[task-2]</task>" in out
|
||||||
|
assert "Task A" not in out
|
||||||
|
assert "Task B" not in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_tagged_list_lines_filters_upcoming_timeline_query_to_current_month_future_only():
|
||||||
|
today = date.today()
|
||||||
|
tomorrow = today + timedelta(days=1)
|
||||||
|
yesterday = today - timedelta(days=1)
|
||||||
|
next_month = (today.replace(day=28) + timedelta(days=5)).replace(day=1)
|
||||||
|
|
||||||
|
raw = "\n".join(
|
||||||
|
[
|
||||||
|
f"- Milestone old — {yesterday.strftime('%d/%m/%Y')} <timeline>[tl-old]</timeline>",
|
||||||
|
f"- Milestone next — {tomorrow.strftime('%d/%m/%Y')} <timeline>[tl-next]</timeline>",
|
||||||
|
f"- Milestone future — {next_month.strftime('%d/%m/%Y')} <timeline>[tl-future]</timeline>",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
out = _normalize_tagged_list_lines(raw, "invece i miei eventi prossimi?")
|
||||||
|
|
||||||
|
assert "<timeline>[tl-next]</timeline>" in out
|
||||||
|
assert "<timeline>[tl-old]</timeline>" not in out
|
||||||
|
assert "<timeline>[tl-future]</timeline>" not in out
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_floating_strips_xml_like_tags_from_final_text():
|
||||||
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
|
async def _fake_run_single_agent(**_kwargs):
|
||||||
|
return (
|
||||||
|
"Hai 1 task:\\n"
|
||||||
|
"Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||||
|
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
||||||
|
):
|
||||||
|
text, _domain = await run_floating(
|
||||||
|
"user-1",
|
||||||
|
"quali task ho?",
|
||||||
|
{"scope": {"type": "task"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "<task>" not in text
|
||||||
|
assert "</task>" not in text
|
||||||
|
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
|
||||||
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
|
async def _fake_stream(**_kwargs):
|
||||||
|
yield "token", "Hai 1 task:\\n"
|
||||||
|
yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||||
|
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
||||||
|
):
|
||||||
|
events = []
|
||||||
|
async for event in run_floating_stream(
|
||||||
|
"user-1",
|
||||||
|
"quali task ho?",
|
||||||
|
{"scope": {"type": "task"}},
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
token_events = [str(data) for event_type, data in events if event_type == "token"]
|
||||||
|
combined = "".join(token_events)
|
||||||
|
assert "<task>" not in combined
|
||||||
|
assert "</task>" not in combined
|
||||||
|
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in combined
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_floating_stream_falls_back_to_final_response_content_when_astream_is_empty():
|
||||||
|
class _NoChunkLLM:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.calls = 0
|
||||||
|
|
||||||
|
def bind_tools(self, _tools):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def ainvoke(self, _messages):
|
||||||
|
self.calls += 1
|
||||||
|
if self.calls == 1:
|
||||||
|
return AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"id": "call-1",
|
||||||
|
"name": "list_tasks",
|
||||||
|
"args": {},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return AIMessage(content="No notes found.")
|
||||||
|
|
||||||
|
async def astream(self, _messages):
|
||||||
|
if False:
|
||||||
|
yield None
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=_NoChunkLLM()), patch(
|
||||||
|
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
||||||
|
):
|
||||||
|
events = []
|
||||||
|
async for event in run_floating_stream(
|
||||||
|
"user-1",
|
||||||
|
"quali sono le note?",
|
||||||
|
{"scope": {"type": "note"}},
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
assert events[0][0] == "floating_domain"
|
||||||
|
assert ("token", "No notes found.") in events
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_floating_returns_fallback_when_sanitization_would_empty_text():
|
||||||
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
|
async def _fake_run_single_agent(**_kwargs):
|
||||||
|
return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||||
|
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
||||||
|
):
|
||||||
|
text, _domain = await run_floating(
|
||||||
|
"user-1",
|
||||||
|
"quali task ho?",
|
||||||
|
{"scope": {"type": "task"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert text == "No results found."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_floating_stream_returns_fallback_when_sanitization_would_empty_text():
|
||||||
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
|
async def _fake_stream(**_kwargs):
|
||||||
|
yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||||
|
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
||||||
|
):
|
||||||
|
events = []
|
||||||
|
async for event in run_floating_stream(
|
||||||
|
"user-1",
|
||||||
|
"quali task ho?",
|
||||||
|
{"scope": {"type": "task"}},
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
assert ("token", "No results found.") in events
|
||||||
124
tests/test_e2e_flow.py
Normal file
124
tests/test_e2e_flow.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
"""End-to-end test: Auth → WS Gateway → Chat Service round-trip.
|
||||||
|
|
||||||
|
Usage (from repo root, with venv activated):
|
||||||
|
python test_e2e_flow.py
|
||||||
|
|
||||||
|
Requires: Auth (8001), WS Gateway (8002), Chat (8003) all running.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
AUTH_URL = "http://127.0.0.1:8001/api/v1/auth"
|
||||||
|
WS_URL = "ws://127.0.0.1:8002/api/v1/ws/device"
|
||||||
|
|
||||||
|
# ── 1. Authenticate ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def get_token() -> str:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
# Try login first, register if user doesn't exist
|
||||||
|
resp = await client.post(
|
||||||
|
f"{AUTH_URL}/login",
|
||||||
|
json={"email": "e2e@test.com", "password": "Test1234!"},
|
||||||
|
)
|
||||||
|
if resp.status_code == 200:
|
||||||
|
print("[1/4] Logged in as e2e@test.com")
|
||||||
|
return resp.json()["access_token"]
|
||||||
|
|
||||||
|
resp = await client.post(
|
||||||
|
f"{AUTH_URL}/register",
|
||||||
|
json={
|
||||||
|
"email": "e2e@test.com",
|
||||||
|
"password": "Test1234!",
|
||||||
|
"name": "E2E",
|
||||||
|
"surname": "Test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
print("[1/4] Registered + logged in as e2e@test.com")
|
||||||
|
return resp.json()["access_token"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── 2. WebSocket flow ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def run_e2e():
|
||||||
|
token = await get_token()
|
||||||
|
|
||||||
|
uri = f"{WS_URL}?token={token}"
|
||||||
|
async with websockets.connect(uri) as ws:
|
||||||
|
# Send device_hello
|
||||||
|
await ws.send(json.dumps({
|
||||||
|
"type": "device_hello",
|
||||||
|
"device_id": str(uuid.uuid4()),
|
||||||
|
"agent_ids": ["task", "note", "project", "timeline"],
|
||||||
|
}))
|
||||||
|
print("[2/4] Device registered with WS Gateway")
|
||||||
|
|
||||||
|
# Send a home_request (simple greeting — unlikely to need tools)
|
||||||
|
await ws.send(json.dumps({
|
||||||
|
"type": "home_request",
|
||||||
|
"message": "Hello! How are you doing today?",
|
||||||
|
"context": {},
|
||||||
|
}))
|
||||||
|
print("[3/4] Sent home_request → waiting for Chat Service response...")
|
||||||
|
|
||||||
|
# Listen for response frames (text_chunk, tool_call, final)
|
||||||
|
full_response = []
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
raw = await asyncio.wait_for(ws.recv(), timeout=60)
|
||||||
|
frame = json.loads(raw)
|
||||||
|
ftype = frame.get("type")
|
||||||
|
|
||||||
|
if ftype == "text_chunk":
|
||||||
|
chunk = frame.get("chunk", frame.get("text", ""))
|
||||||
|
full_response.append(chunk)
|
||||||
|
print(f" ← text_chunk: {chunk[:80]}")
|
||||||
|
|
||||||
|
elif ftype == "tool_call":
|
||||||
|
# Respond with a mock tool_result so the agent doesn't hang
|
||||||
|
call_id = frame.get("id")
|
||||||
|
action = frame.get("action")
|
||||||
|
table = frame.get("table", "")
|
||||||
|
print(f" ← tool_call: {action} {table} (id={call_id})")
|
||||||
|
|
||||||
|
mock_result = {"rows": [], "row": None}
|
||||||
|
await ws.send(json.dumps({
|
||||||
|
"type": "tool_result",
|
||||||
|
"id": call_id,
|
||||||
|
**mock_result,
|
||||||
|
}))
|
||||||
|
print(f" → tool_result (mock) for {call_id}")
|
||||||
|
|
||||||
|
elif ftype == "final":
|
||||||
|
text = frame.get("text", "")
|
||||||
|
if text:
|
||||||
|
full_response.append(text)
|
||||||
|
print(f" ← final")
|
||||||
|
break
|
||||||
|
|
||||||
|
elif ftype == "ping":
|
||||||
|
# Ignore heartbeats
|
||||||
|
continue
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f" ← {ftype}: {json.dumps(frame)[:120]}")
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
print(" ⚠ Timed out waiting for response (60s)")
|
||||||
|
|
||||||
|
print()
|
||||||
|
if full_response:
|
||||||
|
print(f"[4/4] Full response: {''.join(full_response)}")
|
||||||
|
else:
|
||||||
|
print("[4/4] No text response received (check Chat Service logs)")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(run_e2e())
|
||||||
@@ -110,6 +110,32 @@ async def test_enrich_context_returns_episodic_memory(db_session, user_with_key)
|
|||||||
assert any("Q1 tasks" in s for s in ctx["episodic_memory"])
|
assert any("Q1 tasks" in s for s in ctx["episodic_memory"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_filters_episodic_by_session_id(db_session, user_with_key):
|
||||||
|
target_session = str(uuid.uuid4())
|
||||||
|
other_session = str(uuid.uuid4())
|
||||||
|
db_session.add(MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
summary_encrypted=_enc("Target session memory"),
|
||||||
|
session_id=target_session,
|
||||||
|
))
|
||||||
|
db_session.add(MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
summary_encrypted=_enc("Other session memory"),
|
||||||
|
session_id=other_session,
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "any message", session_id=target_session)
|
||||||
|
|
||||||
|
episodic = ctx.get("episodic_memory", [])
|
||||||
|
assert any("Target session" in s for s in episodic)
|
||||||
|
assert not any("Other session" in s for s in episodic)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
|
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
|
||||||
# Add one pattern above threshold and one below
|
# Add one pattern above threshold and one below
|
||||||
@@ -229,6 +255,40 @@ async def test_update_core_upsert(db_session, user_with_key):
|
|||||||
assert _dec(rows[0].value_encrypted) == "fr"
|
assert _dec(rows[0].value_encrypted) == "fr"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_core_block_edit_ops(db_session, user_with_key):
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
|
||||||
|
await middleware.update_core(USER_ID, "human", "Name: Roberto")
|
||||||
|
await middleware.append_core(USER_ID, "human", "Timezone: Europe/Rome")
|
||||||
|
replaced = await middleware.replace_core(USER_ID, "human", "Roberto", "Robert")
|
||||||
|
|
||||||
|
blocks = await middleware.list_core_blocks(USER_ID)
|
||||||
|
human = next(b for b in blocks if b["label"] == "human")
|
||||||
|
|
||||||
|
assert replaced is True
|
||||||
|
assert "Name: Robert" in human["value"]
|
||||||
|
assert "Timezone: Europe/Rome" in human["value"]
|
||||||
|
|
||||||
|
deleted = await middleware.delete_core(USER_ID, "human")
|
||||||
|
assert deleted is True
|
||||||
|
assert await middleware.get_core_block(USER_ID, "human") is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_archival_and_recall_search_helpers(db_session, user_with_key):
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
|
||||||
|
await middleware.insert_archival(USER_ID, "Project whitelist has release risk", source="assistant")
|
||||||
|
await middleware.store_episode(USER_ID, str(uuid.uuid4()), "How is whitelist?", "Whitelist is delayed")
|
||||||
|
|
||||||
|
arch = await middleware.search_archival(USER_ID, "whitelist", top_k=3)
|
||||||
|
rec = await middleware.search_recall(USER_ID, "delayed", top_k=3)
|
||||||
|
|
||||||
|
assert any("whitelist" in item.lower() for item in arch)
|
||||||
|
assert any("delayed" in item.lower() for item in rec)
|
||||||
|
|
||||||
|
|
||||||
# ── End-to-end WS: memory middleware is called during home_request ────────────
|
# ── End-to-end WS: memory middleware is called during home_request ────────────
|
||||||
|
|
||||||
def test_home_request_calls_memory_middleware(client):
|
def test_home_request_calls_memory_middleware(client):
|
||||||
@@ -240,21 +300,20 @@ def test_home_request_calls_memory_middleware(client):
|
|||||||
def __init__(self, db):
|
def __init__(self, db):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def enrich_context(self, user_id, message):
|
async def enrich_context(self, user_id, message, **kwargs):
|
||||||
enrich_calls.append((user_id, message))
|
enrich_calls.append((user_id, message))
|
||||||
return {"core_memory": {"tz": "UTC"}}
|
return {"core_memory": {"tz": "UTC"}}
|
||||||
|
|
||||||
async def store_episode(self, user_id, session_id, message, response):
|
async def store_episode(self, user_id, session_id, message, response, **kwargs):
|
||||||
store_calls.append((user_id, session_id, message, response))
|
store_calls.append((user_id, session_id, message, response))
|
||||||
|
|
||||||
token = make_jwt("power", user_id=USER_ID)
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
session_id = str(uuid.uuid4())
|
session_id = str(uuid.uuid4())
|
||||||
|
|
||||||
async def _mock_stream(user_id, message, context, db_session_factory=None):
|
async def _mock_stream(user_id, message, context):
|
||||||
# Verify memory context was injected
|
# Verify memory context was injected
|
||||||
assert context.get("core_memory") == {"tz": "UTC"}
|
assert context.get("core_memory") == {"tz": "UTC"}
|
||||||
yield ("token", "Done")
|
yield "token", "Done"
|
||||||
yield ("mutations", [])
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
|
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
|
||||||
|
|||||||
@@ -1,214 +1,82 @@
|
|||||||
"""Tests for app.core.output_formatter — HomeFormatter and FloatingFormatter."""
|
"""Tests for app.core.output_formatter.StreamFormatter."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
from app.core.output_formatter import StreamFormatter
|
||||||
from app.schemas import (
|
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
||||||
WsFloatingDomain,
|
|
||||||
WsStreamEnd,
|
|
||||||
WsStreamStart,
|
|
||||||
WsStreamText,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── helpers ───────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _stream(*events: tuple[str, object]):
|
async def _stream(*events: tuple[str, object]):
|
||||||
"""Async generator that yields (event_type, data) tuples."""
|
|
||||||
for event in events:
|
for event in events:
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
|
|
||||||
async def collect(formatter, event_stream):
|
async def _collect(formatter: StreamFormatter, event_stream):
|
||||||
frames = []
|
frames = []
|
||||||
async for frame in formatter.format(event_stream):
|
async for frame in formatter.format(event_stream):
|
||||||
frames.append(frame)
|
frames.append(frame)
|
||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
# ── HomeFormatter ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_home_formatter_plain_text():
|
async def test_stream_formatter_text_stream() -> None:
|
||||||
req_id = "req-1"
|
formatter = StreamFormatter(request_id="req-1")
|
||||||
events = [
|
frames = await _collect(
|
||||||
("token", "Hello world"),
|
formatter,
|
||||||
("mutations", []),
|
_stream(("token", "Hello"), ("token", " world")),
|
||||||
]
|
)
|
||||||
formatter = HomeFormatter(request_id=req_id)
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
|
|
||||||
assert isinstance(frames[0], WsStreamStart)
|
assert isinstance(frames[0], WsStreamStart)
|
||||||
assert frames[0].request_id == req_id
|
assert isinstance(frames[1], WsStreamText)
|
||||||
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
assert frames[1].chunk == "Hello"
|
||||||
assert any("Hello world" in f.chunk for f in text_frames)
|
assert isinstance(frames[2], WsStreamText)
|
||||||
|
assert frames[2].chunk == " world"
|
||||||
assert isinstance(frames[-1], WsStreamEnd)
|
assert isinstance(frames[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_home_formatter_entity_tags_passed_through():
|
async def test_stream_formatter_floating_domain_first() -> None:
|
||||||
"""Entity tags are streamed as-is — the frontend parses them."""
|
formatter = StreamFormatter(request_id="req-2")
|
||||||
req_id = "req-2"
|
frames = await _collect(
|
||||||
events = [
|
formatter,
|
||||||
("token", "Here is your project:\n<project>[abc-123]</project>\nAll good."),
|
_stream(
|
||||||
("mutations", []),
|
(
|
||||||
]
|
"floating_domain",
|
||||||
formatter = HomeFormatter(request_id=req_id)
|
{"type": "node", "id": "n-1", "section": None},
|
||||||
frames = await collect(formatter, _stream(*events))
|
),
|
||||||
|
|
||||||
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
|
||||||
assert "<project>[abc-123]</project>" in text
|
|
||||||
assert "Here is your project:" in text
|
|
||||||
assert "All good." in text
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_home_formatter_multiple_tags_passed_through():
|
|
||||||
req_id = "req-3"
|
|
||||||
events = [
|
|
||||||
("token", "<project>[p1]</project>\nText\n<task>[t1,t2]</task>"),
|
|
||||||
("mutations", []),
|
|
||||||
]
|
|
||||||
formatter = HomeFormatter(request_id=req_id)
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
|
|
||||||
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
|
||||||
assert "<project>[p1]</project>" in text
|
|
||||||
assert "<task>[t1,t2]</task>" in text
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_home_formatter_tool_end_ignored():
|
|
||||||
"""tool_end events are silently ignored by HomeFormatter."""
|
|
||||||
req_id = "req-4"
|
|
||||||
events = [
|
|
||||||
("tool_end", {"name": "task_agent", "result": "3 tasks"}),
|
|
||||||
("token", "No tags here."),
|
|
||||||
("mutations", []),
|
|
||||||
]
|
|
||||||
formatter = HomeFormatter(request_id=req_id)
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
|
|
||||||
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
|
||||||
assert text == "No tags here."
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_home_formatter_mutations_in_stream_end():
|
|
||||||
req_id = "req-5"
|
|
||||||
muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}]
|
|
||||||
events = [
|
|
||||||
("token", "Done"),
|
|
||||||
("mutations", muts),
|
|
||||||
]
|
|
||||||
formatter = HomeFormatter(request_id=req_id)
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
|
|
||||||
end_frame = frames[-1]
|
|
||||||
assert isinstance(end_frame, WsStreamEnd)
|
|
||||||
assert len(end_frame.mutations) == 1
|
|
||||||
assert end_frame.mutations[0]["action"] == "insert"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_home_formatter_frame_order():
|
|
||||||
"""stream_start is first, stream_end is last."""
|
|
||||||
req_id = "req-6"
|
|
||||||
formatter = HomeFormatter(request_id=req_id)
|
|
||||||
frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", [])))
|
|
||||||
assert isinstance(frames[0], WsStreamStart)
|
|
||||||
assert isinstance(frames[-1], WsStreamEnd)
|
|
||||||
|
|
||||||
|
|
||||||
# ── FloatingFormatter ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_floating_formatter_domain_from_tool_end():
|
|
||||||
req_id = "pop-1"
|
|
||||||
formatter = FloatingFormatter(request_id=req_id)
|
|
||||||
events = [
|
|
||||||
("tool_end", {"name": "task_agent", "result": "ok"}),
|
|
||||||
("token", "Hello"),
|
|
||||||
("mutations", []),
|
|
||||||
]
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
|
|
||||||
assert isinstance(frames[0], WsFloatingDomain)
|
|
||||||
assert frames[0].domain == "tasks"
|
|
||||||
assert frames[0].request_id == req_id
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_floating_formatter_text_only():
|
|
||||||
req_id = "pop-2"
|
|
||||||
formatter = FloatingFormatter(request_id=req_id)
|
|
||||||
events = [
|
|
||||||
("tool_end", {"name": "timeline_agent", "result": "done"}),
|
|
||||||
("token", "Summary"),
|
("token", "Summary"),
|
||||||
("mutations", []),
|
),
|
||||||
]
|
)
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
|
|
||||||
assert isinstance(frames[0], WsFloatingDomain)
|
assert isinstance(frames[0], WsFloatingDomain)
|
||||||
assert frames[0].domain == "timelines"
|
assert frames[0].domain.type == "node"
|
||||||
|
assert frames[0].domain.id == "n-1"
|
||||||
|
assert isinstance(frames[1], WsStreamStart)
|
||||||
|
assert isinstance(frames[2], WsStreamText)
|
||||||
|
assert frames[2].chunk == "Summary"
|
||||||
|
assert isinstance(frames[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_formatter_ignores_unknown_events() -> None:
|
||||||
|
formatter = StreamFormatter(request_id="req-3")
|
||||||
|
frames = await _collect(
|
||||||
|
formatter,
|
||||||
|
_stream(("tool_end", {"name": "x"}), ("token", "ok")),
|
||||||
|
)
|
||||||
|
|
||||||
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
||||||
assert len(text_frames) == 1
|
assert len(text_frames) == 1
|
||||||
assert text_frames[0].chunk == "Summary"
|
assert text_frames[0].chunk == "ok"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_floating_formatter_no_entity_tags():
|
async def test_stream_formatter_empty_stream_still_brackets() -> None:
|
||||||
"""FloatingFormatter never emits entity tag blocks."""
|
formatter = StreamFormatter(request_id="req-4")
|
||||||
req_id = "pop-3"
|
frames = await _collect(formatter, _stream())
|
||||||
formatter = FloatingFormatter(request_id=req_id)
|
|
||||||
events = [
|
|
||||||
("tool_end", {"name": "note_agent", "result": "data"}),
|
|
||||||
("token", "some text"),
|
|
||||||
("mutations", []),
|
|
||||||
]
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
# Only expected frame types
|
|
||||||
for f in frames:
|
|
||||||
assert isinstance(f, (WsFloatingDomain, WsStreamStart, WsStreamText, WsStreamEnd))
|
|
||||||
|
|
||||||
|
assert len(frames) == 2
|
||||||
@pytest.mark.asyncio
|
assert isinstance(frames[0], WsStreamStart)
|
||||||
async def test_floating_formatter_end_frame():
|
assert isinstance(frames[1], WsStreamEnd)
|
||||||
req_id = "pop-4"
|
|
||||||
formatter = FloatingFormatter(request_id=req_id)
|
|
||||||
events = [
|
|
||||||
("tool_end", {"name": "project_agent", "result": "ok"}),
|
|
||||||
("token", "Done"),
|
|
||||||
("mutations", []),
|
|
||||||
]
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
assert isinstance(frames[-1], WsStreamEnd)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_floating_formatter_default_domain_on_early_token():
|
|
||||||
"""When the first event is a token (no tool_end yet), default to 'tasks'."""
|
|
||||||
req_id = "pop-5"
|
|
||||||
formatter = FloatingFormatter(request_id=req_id)
|
|
||||||
events = [("token", "hi"), ("mutations", [])]
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
assert isinstance(frames[0], WsFloatingDomain)
|
|
||||||
assert frames[0].domain == "tasks"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_floating_formatter_mutations_in_stream_end():
|
|
||||||
req_id = "pop-6"
|
|
||||||
muts = [{"action": "update", "table": "tasks", "data": {"id": "t2"}}]
|
|
||||||
events = [
|
|
||||||
("token", "Updated"),
|
|
||||||
("mutations", muts),
|
|
||||||
]
|
|
||||||
formatter = FloatingFormatter(request_id=req_id)
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
|
|
||||||
end_frame = frames[-1]
|
|
||||||
assert isinstance(end_frame, WsStreamEnd)
|
|
||||||
assert len(end_frame.mutations) == 1
|
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ class TestPluginRegistry:
|
|||||||
async def test_list_filter_by_query(
|
async def test_list_filter_by_query(
|
||||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
) -> None:
|
) -> None:
|
||||||
result = await reg.list_plugins(db_session, query="time tracker")
|
result = await reg.list_plugins(db_session, query="time")
|
||||||
assert result.total == 1
|
assert result.total == 1
|
||||||
assert result.plugins[0].id == "plugin-time-tracker"
|
assert result.plugins[0].id == "plugin-time-tracker"
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import pytest
|
|||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
|
WsDomain,
|
||||||
WsFrameType,
|
WsFrameType,
|
||||||
WsHomeRequest,
|
WsHomeRequest,
|
||||||
WsFloatingDomain,
|
WsFloatingDomain,
|
||||||
@@ -178,23 +179,15 @@ def test_stream_text_deserializes():
|
|||||||
def test_stream_end_defaults():
|
def test_stream_end_defaults():
|
||||||
frame = WsStreamEnd(request_id="r1")
|
frame = WsStreamEnd(request_id="r1")
|
||||||
assert frame.type == WsFrameType.stream_end
|
assert frame.type == WsFrameType.stream_end
|
||||||
assert frame.mutations == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_end_with_mutations():
|
|
||||||
mutations = [{"action": "create", "table": "tasks", "data": {"title": "New task"}}]
|
|
||||||
frame = WsStreamEnd(request_id="r1", mutations=mutations)
|
|
||||||
assert len(frame.mutations) == 1
|
|
||||||
assert frame.mutations[0]["action"] == "create"
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_end_serializes():
|
def test_stream_end_serializes():
|
||||||
data = WsStreamEnd(request_id="r2").model_dump()
|
data = WsStreamEnd(request_id="r2").model_dump()
|
||||||
assert data == {"type": "stream_end", "request_id": "r2", "mutations": []}
|
assert data == {"type": "stream_end", "request_id": "r2"}
|
||||||
|
|
||||||
|
|
||||||
def test_stream_end_deserializes():
|
def test_stream_end_deserializes():
|
||||||
raw = {"type": "stream_end", "request_id": "r3", "mutations": []}
|
raw = {"type": "stream_end", "request_id": "r3"}
|
||||||
frame = WsStreamEnd.model_validate(raw)
|
frame = WsStreamEnd.model_validate(raw)
|
||||||
assert frame.request_id == "r3"
|
assert frame.request_id == "r3"
|
||||||
|
|
||||||
@@ -203,28 +196,47 @@ def test_stream_end_deserializes():
|
|||||||
|
|
||||||
|
|
||||||
def test_floating_domain_tasks():
|
def test_floating_domain_tasks():
|
||||||
frame = WsFloatingDomain(request_id="r1", domain="tasks")
|
frame = WsFloatingDomain(request_id="r1", domain=WsDomain(type="task"))
|
||||||
assert frame.type == WsFrameType.floating_domain
|
assert frame.type == WsFrameType.floating_domain
|
||||||
assert frame.domain == "tasks"
|
assert frame.domain.type == "task"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("domain", ["tasks", "timelines", "notes", "projects"])
|
def test_floating_domain_valid_domains():
|
||||||
def test_floating_domain_valid_domains(domain: str):
|
frame = WsFloatingDomain(
|
||||||
frame = WsFloatingDomain(request_id="r1", domain=domain) # type: ignore[arg-type]
|
request_id="r1",
|
||||||
assert frame.domain == domain
|
domain=WsDomain(type="project", id="213213-312321-312312-421321", section="task"),
|
||||||
|
)
|
||||||
|
assert frame.domain.type == "project"
|
||||||
|
assert frame.domain.id == "213213-312321-312312-421321"
|
||||||
|
assert frame.domain.section == "task"
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_invalid():
|
def test_floating_domain_object_valid():
|
||||||
with pytest.raises(ValidationError):
|
frame = WsFloatingDomain(
|
||||||
WsFloatingDomain(request_id="r1", domain="invalid") # type: ignore[arg-type]
|
request_id="r1",
|
||||||
|
domain=WsDomain(type="project", id="p1", section="task"),
|
||||||
|
)
|
||||||
|
assert frame.domain.type == "project"
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_serializes():
|
def test_floating_domain_serializes():
|
||||||
d = WsFloatingDomain(request_id="r1", domain="notes").model_dump()
|
d = WsFloatingDomain(
|
||||||
assert d == {"type": "floating_domain", "request_id": "r1", "domain": "notes"}
|
request_id="r1",
|
||||||
|
domain=WsDomain(type="timeline"),
|
||||||
|
).model_dump()
|
||||||
|
assert d == {
|
||||||
|
"type": "floating_domain",
|
||||||
|
"request_id": "r1",
|
||||||
|
"domain": {"type": "timeline", "id": None, "section": None},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_deserializes():
|
def test_floating_domain_deserializes():
|
||||||
raw = {"type": "floating_domain", "request_id": "r1", "domain": "projects"}
|
raw = {
|
||||||
|
"type": "floating_domain",
|
||||||
|
"request_id": "r1",
|
||||||
|
"domain": {"type": "node", "id": "n-1", "section": None},
|
||||||
|
}
|
||||||
frame = WsFloatingDomain.model_validate(raw)
|
frame = WsFloatingDomain.model_validate(raw)
|
||||||
assert frame.domain == "projects"
|
assert frame.domain.type == "node"
|
||||||
|
assert frame.domain.id == "n-1"
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user