feat: add WS Gateway and Chat Service (Step 2)
WS Gateway:
- WebSocket lifecycle handler with RS256 JWT auth
- Redis bridge: device registry, frame publishing, tool_result routing
- Inbound routing: tool_result→LPUSH, home/floating→chat pub/sub
- Outbound: subscribes to ws:out:{user_id}, forwards to Electron
- Single-worker Dockerfile (long-lived WS connections)
Chat Service:
- Redis consumer: subscribes to chat:request:* pattern
- Redis-based ws_context: tool_call→publish, BRPOP tool_result (30s timeout)
- deep_agent: single-agent runner with home/floating/stream variants
- memory_middleware: core/associative/episodic/proactive memory with Fernet
- Domain agents: task (8 tools), note (5), project (6), timeline (4)
- LLM factory via LiteLLM (100+ providers)
- Output formatter (StreamFormatter)
- POST /chat REST fallback with Traefik header auth
- Multi-worker Dockerfile with 120s timeout for LLM calls
This commit is contained in:
36
services/chat/Dockerfile
Normal file
36
services/chat/Dockerfile
Normal file
@@ -0,0 +1,36 @@
|
||||
# ── builder ──────────────────────────────────────────────────────────────────
|
||||
FROM python:3.12-slim AS builder
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
COPY services/chat/requirements.txt ./requirements.txt
|
||||
RUN pip install --upgrade pip && \
|
||||
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||
|
||||
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||
FROM python:3.12-slim AS runtime
|
||||
|
||||
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=builder /install /usr/local
|
||||
|
||||
# Shared module
|
||||
COPY shared/ shared/
|
||||
|
||||
# Service source
|
||||
COPY services/chat/app/ app/
|
||||
|
||||
RUN chown -R appuser:appgroup /app
|
||||
|
||||
USER appuser
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
# Chat requests are long-running (each one waits on LLM calls) — use multiple workers
|
||||
CMD ["gunicorn", "app.main:app", \
|
||||
"-k", "uvicorn.workers.UvicornWorker", \
|
||||
"--bind", "0.0.0.0:8000", \
|
||||
"--workers", "2", \
|
||||
"--timeout", "120"]
|
||||
1
services/chat/app/agents/__init__.py
Normal file
1
services/chat/app/agents/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Chat Service domain agents."""
|
||||
142
services/chat/app/agents/note_agent.py
Normal file
142
services/chat/app/agents/note_agent.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Note agent — Markdown note management (list, get, create, update, delete).
|
||||
|
||||
Adapted for Chat Service: import from app.ws_context and app.llm.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.llm import embed
|
||||
from app.ws_context import execute_on_client
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(value: str) -> bool:
|
||||
return bool(_UUID_RE.match(value))
|
||||
|
||||
NOTE_SYSTEM_PROMPT = (
|
||||
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
||||
"and delete Markdown notes in their workspace.\n\n"
|
||||
"Rules:\n"
|
||||
" - content is always Markdown; preserve formatting when updating\n"
|
||||
" - project_id is optional; link a note to a project when mentioned\n"
|
||||
" - When updating, call get_note first if you need to read existing content\n"
|
||||
" before appending or replacing sections\n"
|
||||
" - list_notes without project_id returns all notes; scope with project_id\n"
|
||||
" when the user is working within a specific project\n"
|
||||
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
|
||||
" - Do not fabricate note content — reflect what the user provides or what\n"
|
||||
" is already in the note (retrieved via get_note)."
|
||||
)
|
||||
|
||||
|
||||
@tool
async def list_notes(project_id: str = "") -> str:
    """List notes, optionally scoped to a project by project_id."""
    # Only a well-formed UUID is forwarded as a filter; an empty string or a
    # non-UUID value (e.g. a project name) results in an unfiltered listing.
    project_filter = project_id if _is_uuid(project_id) else None
    response = await execute_on_client(
        action="select",
        table="notes",
        filters={"projectId": project_filter},
    )
    notes = response.get("rows", [])
    if notes:
        listing = "\n".join(f"- {note['title']} (id: {note['id']})" for note in notes)
        return f"Found {len(notes)} note(s):\n" + listing
    return "No notes found."
|
||||
|
||||
|
||||
@tool
async def get_note(note_id: str) -> str:
    """Fetch a single note by its UUID to read its full Markdown content."""
    lookup = {"id": note_id}
    response = await execute_on_client(action="get", table="notes", data=lookup)
    note = response.get("row")
    if note:
        # Title/id header, blank line, then the raw note content.
        return f"Note '{note['title']}' (id: {note['id']}):\n\n{note['content']}"
    # A missing or empty "row" is reported as not found.
    return f"Note {note_id} not found."
|
||||
|
||||
|
||||
@tool
async def create_note(
    title: str,
    content: str,
    project_id: str = "",
) -> str:
    """Create a new note.
    title: note heading (required)
    content: Markdown body text (required)
    project_id: optional UUID linking this note to a project
    """
    # Step 1: insert the note; an empty project_id is sent as None.
    payload = {
        "title": title,
        "content": content,
        "projectId": project_id or None,
    }
    inserted = await execute_on_client(action="insert", table="notes", data=payload)
    note = inserted["row"]

    # Step 2: embed the content and upsert it into the vector store, keyed by
    # the id of the row returned from the insert.
    embedding = await embed(content)
    index_entry = {"id": note["id"], "projectId": note.get("projectId"), "content": content}
    await execute_on_client(action="vector_upsert", data=index_entry, vector=embedding)

    return f"Note created: '{note['title']}' (id: {note['id']})."
|
||||
|
||||
|
||||
@tool
async def update_note(
    note_id: str,
    title: str = "",
    content: str = "",
) -> str:
    """Update an existing note. Only pass fields that should change.
    note_id: UUID of the note (required)
    If you need to preserve existing content, call get_note first.
    """
    # Empty strings mean "leave this field unchanged".
    updates: dict[str, Any] = {}
    if title:
        updates["title"] = title
    if content:
        updates["content"] = content

    # Nothing to change: skip the client round-trip instead of sending an
    # empty update and then reporting the note as "updated".
    if not updates:
        return f"No changes specified for note {note_id}; nothing was updated."

    result = await execute_on_client(
        action="update",
        table="notes",
        data={"id": note_id, "updates": updates},
    )
    row = result["row"]

    # Re-index only when the content changed; a title-only edit leaves the
    # stored vector untouched.
    if content:
        vector = await embed(content)
        await execute_on_client(
            action="vector_upsert",
            data={"id": note_id, "projectId": row.get("projectId"), "content": content},
            vector=vector,
        )
    return f"Note updated: '{row['title']}' (id: {row['id']})."
|
||||
|
||||
|
||||
@tool
async def delete_note(note_id: str) -> str:
    """Delete a note permanently by its UUID."""
    target = {"id": note_id}
    # The client's response is not inspected; the confirmation echoes the id.
    await execute_on_client(action="delete", table="notes", data=target)
    return f"Note {note_id} deleted."
|
||||
|
||||
|
||||
NOTE_TOOLS: list[Any] = [
|
||||
list_notes,
|
||||
get_note,
|
||||
create_note,
|
||||
update_note,
|
||||
delete_note,
|
||||
]
|
||||
146
services/chat/app/agents/project_agent.py
Normal file
146
services/chat/app/agents/project_agent.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Project agent — full lifecycle management (list, get, create, update, archive, delete).
|
||||
|
||||
Adapted for Chat Service: import from app.ws_context instead of app.core.ws_context.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.ws_context import execute_on_client
|
||||
|
||||
PROJECT_SYSTEM_PROMPT = (
|
||||
"You are a project management assistant. You help users create, find,\n"
|
||||
"update, and archive projects in their workspace.\n\n"
|
||||
"Rules:\n"
|
||||
" - status must be one of: active, archived\n"
|
||||
" - client_id is optional; link to a client only when explicitly mentioned\n"
|
||||
" - ai_summary is populated only when the user asks for a project summary;\n"
|
||||
" derive it from context data — do not fabricate content\n"
|
||||
" - Use list_projects for scoped queries; list_all_projects only when the\n"
|
||||
" user wants a complete cross-client view including archived projects\n"
|
||||
" - get_project requires a project UUID; resolve the ID first by calling\n"
|
||||
" list_projects if you only have a project name\n"
|
||||
" - Prefer archiving (update_project status=archived) over deletion;\n"
|
||||
" only call delete_project when the user explicitly confirms deletion."
|
||||
)
|
||||
|
||||
|
||||
@tool
async def list_projects(
    client_id: str = "",
    include_archived: int = 0,
) -> str:
    """List projects, optionally filtered by client_id.
    include_archived: 1 to include archived projects, 0 for active only (default).
    """
    # An empty client_id becomes None; the integer flag is sent as a bool.
    query = {
        "clientId": client_id or None,
        "includeArchived": bool(include_archived),
    }
    response = await execute_on_client(action="select", table="projects", filters=query)
    projects = response.get("rows", [])
    if not projects:
        return "No projects found."
    summary = "\n".join(
        f"- {project['name']} (status: {project['status']}, id: {project['id']})"
        for project in projects
    )
    return f"Found {len(projects)} project(s):\n" + summary
|
||||
|
||||
|
||||
@tool
async def list_all_projects() -> str:
    """List every project regardless of client or status.
    Use only when the user wants a complete cross-client overview.
    """
    # No filters are passed to the client for this select.
    response = await execute_on_client(action="select", table="projects")
    projects = response.get("rows", [])
    if projects:
        summary = "\n".join(
            f"- {project['name']} (status: {project['status']}, id: {project['id']})"
            for project in projects
        )
        return f"All projects ({len(projects)}):\n" + summary
    return "No projects found."
|
||||
|
||||
|
||||
@tool
async def get_project(project_id: str) -> str:
    """Fetch a single project by its UUID."""
    lookup = {"id": project_id}
    response = await execute_on_client(action="get", table="projects", data=lookup)
    project = response.get("row")
    if project:
        # 'none' is shown only when the row has no clientId key at all.
        owner = project.get('clientId', 'none')
        return (
            f"Project: '{project['name']}' (id: {project['id']}, status: {project['status']}, "
            f"clientId: {owner})"
        )
    return f"Project {project_id} not found."
|
||||
|
||||
|
||||
@tool
async def create_project(
    name: str,
    client_id: str = "",
) -> str:
    """Create a new project.
    name: human-readable project name (required)
    client_id: optional UUID of the owning client
    """
    # An empty client_id is sent as None.
    payload = {"name": name, "clientId": client_id or None}
    inserted = await execute_on_client(action="insert", table="projects", data=payload)
    project = inserted["row"]
    return f"Project created: '{project['name']}' (id: {project['id']})"
|
||||
|
||||
|
||||
@tool
async def update_project(
    project_id: str,
    name: str = "",
    client_id: str = "",
    status: str = "",
    ai_summary: str = "",
) -> str:
    """Update a project. Only pass fields that should change.
    project_id: UUID of the project (required)
    status: active | archived
    ai_summary: AI-generated summary text (populate only when explicitly requested)
    """
    # Map each tool parameter to the key sent to the client; empty strings
    # mean "leave this field unchanged" and are dropped.
    candidates = {
        "name": name,
        "clientId": client_id,
        "status": status,
        "aiSummary": ai_summary,
    }
    updates: dict[str, Any] = {key: value for key, value in candidates.items() if value}

    # Nothing to change: skip the client round-trip instead of sending an
    # empty update and then reporting the project as "updated".
    if not updates:
        return f"No changes specified for project {project_id}; nothing was updated."

    result = await execute_on_client(
        action="update",
        table="projects",
        data={"id": project_id, "updates": updates},
    )
    row = result["row"]
    return f"Project updated: '{row['name']}' (id: {row['id']}, status: {row['status']})"
|
||||
|
||||
|
||||
@tool
async def delete_project(project_id: str) -> str:
    """Permanently delete a project and orphan its tasks.
    IMPORTANT: prefer update_project(status='archived') unless the user
    has explicitly confirmed they want permanent deletion.
    """
    target = {"id": project_id}
    # The client's response is not inspected; the confirmation echoes the id.
    await execute_on_client(action="delete", table="projects", data=target)
    return f"Project {project_id} permanently deleted."
|
||||
|
||||
|
||||
PROJECT_TOOLS: list[Any] = [
|
||||
list_projects,
|
||||
list_all_projects,
|
||||
get_project,
|
||||
create_project,
|
||||
update_project,
|
||||
delete_project,
|
||||
]
|
||||
240
services/chat/app/agents/task_agent.py
Normal file
240
services/chat/app/agents/task_agent.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Task agent — full CRUD for tasks and task comments.
|
||||
|
||||
Adapted for Chat Service: import from app.ws_context instead of app.core.ws_context.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.ws_context import execute_on_client
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(value: str) -> bool:
|
||||
return bool(_UUID_RE.match(value))
|
||||
|
||||
# System prompt for the task agent. Each rule is stated exactly once so the
# model is not sent the same instruction twice.
TASK_SYSTEM_PROMPT = (
    "You are a task management assistant for a project workspace.\n"
    "You create, update, list, and track tasks and their comments.\n\n"
    "Rules:\n"
    " - status must be one of: todo, in_progress, done\n"
    " - priority must be one of: high, medium, low\n"
    " - due_date is a Unix timestamp in milliseconds; convert human dates\n"
    " - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
    " - project_id is optional; link to a project when the user mentions one\n"
    " - is_ai_suggested: 1 only when proactively proposing a task the user\n"
    " did not explicitly request; 0 otherwise\n"
    " - Use list_tasks_due_today for 'what's due today' queries\n"
    " - For update_task, use -1 for integer fields you do not want to change\n"
    " - Always confirm the action in plain, user-friendly language."
)
|
||||
|
||||
|
||||
# ── Task tools ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@tool
async def list_tasks(
    project_id: str = "",
    status: str = "",
    search: str = "",
    order_by: str = "",
) -> str:
    """List tasks, optionally filtered by project_id, status (todo|in_progress|done),
    a search string, or an order_by field name (dueDate|priority|createdAt)."""
    # Only a well-formed UUID is used as the project filter; every other
    # empty value is sent to the client as None.
    query = {
        "projectId": project_id if _is_uuid(project_id) else None,
        "status": status or None,
        "search": search or None,
        "orderBy": order_by or None,
    }
    response = await execute_on_client(action="select", table="tasks", filters=query)
    tasks = response.get("rows", [])
    if not tasks:
        return "No tasks found matching the given filters."
    listing = "\n".join(
        f"- {task['title']} (status: {task['status']}, priority: {task['priority']}, id: {task['id']})"
        for task in tasks
    )
    return f"Found {len(tasks)} task(s):\n" + listing
|
||||
|
||||
|
||||
@tool
async def create_task(
    title: str,
    description: str = "",
    status: str = "todo",
    priority: str = "medium",
    assignees: str = "[]",
    due_date: int = 0,
    project_id: str = "",
    is_ai_suggested: int = 0,
) -> str:
    """Create a new task.
    title: task title (required)
    description: optional details
    status: todo | in_progress | done (default: todo)
    priority: high | medium | low (default: medium)
    assignees: JSON-encoded array of assignee names, e.g. '["Alice"]'
    due_date: Unix timestamp in milliseconds; 0 means no due date
    project_id: optional UUID of the parent project
    is_ai_suggested: 1 if proactively suggested, 0 if user-requested
    """
    # Falsy optional values (empty string, 0) are sent as None. The
    # `assignees` argument is forwarded under the "assignee" key.
    payload = {
        "title": title,
        "description": description or None,
        "status": status,
        "priority": priority,
        "assignee": assignees,
        "dueDate": due_date or None,
        "projectId": project_id or None,
        "isAiSuggested": is_ai_suggested,
    }
    inserted = await execute_on_client(action="insert", table="tasks", data=payload)
    task = inserted["row"]
    details = f"(id: {task['id']}, status: {task['status']}, priority: {task['priority']})"
    return f"Task created: '{task['title']}' " + details
|
||||
|
||||
|
||||
@tool
async def update_task(
    task_id: str,
    title: str = "",
    description: str = "",
    status: str = "",
    priority: str = "",
    assignees: str = "",
    due_date: int = -1,
    project_id: str = "",
) -> str:
    """Update fields on an existing task. Only pass fields you want to change.
    task_id: the task's UUID (required)
    due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
    """
    # Empty strings mean "leave this field unchanged" and are dropped. The
    # `assignees` argument is forwarded under the "assignee" key.
    text_fields = {
        "title": title,
        "description": description,
        "status": status,
        "priority": priority,
        "assignee": assignees,
    }
    updates: dict[str, Any] = {key: value for key, value in text_fields.items() if value}
    # -1 is the "unchanged" sentinel; 0 is sent as None to clear the date.
    if due_date != -1:
        updates["dueDate"] = due_date or None
    if project_id:
        updates["projectId"] = project_id

    # Nothing to change: skip the client round-trip instead of sending an
    # empty update and then reporting the task as "updated".
    if not updates:
        return f"No changes specified for task {task_id}; nothing was updated."

    result = await execute_on_client(
        action="update",
        table="tasks",
        data={"id": task_id, "updates": updates},
    )
    row = result["row"]
    return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
|
||||
|
||||
|
||||
@tool
async def delete_task(task_id: str) -> str:
    """Delete a task permanently by its UUID."""
    target = {"id": task_id}
    # The client's response is not inspected; the confirmation echoes the id.
    await execute_on_client(action="delete", table="tasks", data=target)
    return f"Task {task_id} deleted."
|
||||
|
||||
|
||||
@tool
async def list_tasks_due_today() -> str:
    """List all tasks whose due date falls on today's date."""
    # "Today" is the current UTC calendar day, expressed as an inclusive
    # millisecond range [midnight, midnight + 24h - 1ms].
    # NOTE(review): the user's local timezone is not considered here — confirm
    # this is intended.
    midnight_utc = datetime.now(tz=timezone.utc).replace(
        hour=0, minute=0, second=0, microsecond=0
    )
    window_start = int(midnight_utc.timestamp() * 1000)
    window_end = window_start + 86_400_000 - 1
    response = await execute_on_client(
        action="select",
        table="tasks",
        filters={"dueDateFrom": window_start, "dueDateTo": window_end},
    )
    tasks = response.get("rows", [])
    if not tasks:
        return "No tasks are due today."
    listing = "\n".join(
        f"- {task['title']} (priority: {task['priority']}, status: {task['status']}, id: {task['id']})"
        for task in tasks
    )
    return f"Tasks due today ({len(tasks)}):\n" + listing
|
||||
|
||||
|
||||
# ── Task comment tools ────────────────────────────────────────────────
|
||||
|
||||
|
||||
@tool
async def list_task_comments(task_id: str) -> str:
    """List all comments on a task by its UUID."""
    query = {"taskId": task_id}
    response = await execute_on_client(action="select", table="taskComments", filters=query)
    comments = response.get("rows", [])
    if comments:
        listing = "\n".join(
            f"- [{comment['author']}]: {comment['content']} (id: {comment['id']})"
            for comment in comments
        )
        return f"Found {len(comments)} comment(s):\n" + listing
    return f"No comments found for task {task_id}."
|
||||
|
||||
|
||||
@tool
async def add_task_comment(task_id: str, author: str, content: str) -> str:
    """Add a comment to a task.
    task_id: UUID of the task to comment on
    author: name or ID of the comment author
    content: comment text
    """
    payload = {"taskId": task_id, "author": author, "content": content}
    response = await execute_on_client(action="insert", table="taskComments", data=payload)
    comment = response.get("row", {})
    # Prefer values echoed back by the client, falling back to the arguments;
    # the task id is looked up under both "taskId" and "task_id".
    shown_task = comment.get("taskId") or comment.get("task_id") or task_id
    shown_author = comment.get("author", author)
    shown_id = comment.get("id", "unknown")
    return f"Comment added by {shown_author} on task {shown_task} (comment id: {shown_id})."
|
||||
|
||||
|
||||
@tool
async def delete_task_comment(comment_id: str) -> str:
    """Delete a task comment by its UUID."""
    target = {"id": comment_id}
    # The client's response is not inspected; the confirmation echoes the id.
    await execute_on_client(action="delete", table="taskComments", data=target)
    return f"Comment {comment_id} deleted."
|
||||
|
||||
|
||||
# ── Agent ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
TASK_TOOLS: list[Any] = [
|
||||
list_tasks,
|
||||
create_task,
|
||||
update_task,
|
||||
delete_task,
|
||||
list_tasks_due_today,
|
||||
list_task_comments,
|
||||
add_task_comment,
|
||||
delete_task_comment,
|
||||
]
|
||||
117
services/chat/app/agents/timeline_agent.py
Normal file
117
services/chat/app/agents/timeline_agent.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Timeline agent — project milestone management (list, create, update, delete).
|
||||
|
||||
Adapted for Chat Service: import from app.ws_context instead of app.core.ws_context.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.ws_context import execute_on_client
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(value: str) -> bool:
|
||||
return bool(_UUID_RE.match(value))
|
||||
|
||||
# System prompt for the timeline agent. Each rule is stated exactly once so
# the model is not sent the same instruction twice.
TIMELINE_SYSTEM_PROMPT = (
    "You are a project timeline assistant. Timelines are milestone dates that\n"
    "track progress on a project — they are not calendar events.\n\n"
    "Rules:\n"
    " - project_id is REQUIRED for every create; confirm with the user if unknown\n"
    " - For listing, project_id must be a UUID; never pass plain names as project_id\n"
    " - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
    " - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
    " - For update_timeline, use -1 for integer fields you do not want to change\n"
    " - Listing without a project_id returns all timelines across projects\n"
    " - Always echo the title and formatted date in your confirmation."
)
|
||||
|
||||
|
||||
@tool
async def list_timelines(project_id: str = "") -> str:
    """List timelines. Provide project_id to scope to a specific project."""
    # Only a well-formed UUID is forwarded as a filter; an empty string or a
    # non-UUID value results in an unfiltered listing.
    project_filter = project_id if _is_uuid(project_id) else None
    response = await execute_on_client(
        action="select",
        table="timelines",
        filters={"projectId": project_filter},
    )
    timelines = response.get("rows", [])
    if timelines:
        listing = "\n".join(
            f"- {item['title']} (date: {item['date']}, id: {item['id']})" for item in timelines
        )
        return f"Found {len(timelines)} timeline(s):\n" + listing
    return "No timelines found."
|
||||
|
||||
|
||||
@tool
async def create_timeline(
    project_id: str,
    title: str,
    date: int,
    is_ai_suggested: int = 0,
) -> str:
    """Create a project timeline (milestone).
    project_id: REQUIRED UUID of the parent project
    title: descriptive name for the milestone
    date: Unix timestamp in milliseconds
    is_ai_suggested: 1 if proactively suggested, 0 if user-requested
    """
    # All arguments are forwarded to the client as given.
    payload = {
        "projectId": project_id,
        "title": title,
        "date": date,
        "isAiSuggested": is_ai_suggested,
    }
    inserted = await execute_on_client(action="insert", table="timelines", data=payload)
    timeline = inserted["row"]
    return f"Timeline created: '{timeline['title']}' (id: {timeline['id']}, date: {timeline['date']})"
|
||||
|
||||
|
||||
@tool
async def update_timeline(
    timeline_id: str,
    title: str = "",
    date: int = -1,
) -> str:
    """Update a timeline. Only pass fields that should change.
    timeline_id: UUID of the timeline (required)
    date: -1 means unchanged; any other value sets the new date (ms timestamp)
    """
    # An empty title and the -1 date sentinel both mean "leave unchanged".
    updates: dict[str, Any] = {}
    if title:
        updates["title"] = title
    if date != -1:
        updates["date"] = date

    # Nothing to change: skip the client round-trip instead of sending an
    # empty update and then reporting the timeline as "updated".
    if not updates:
        return f"No changes specified for timeline {timeline_id}; nothing was updated."

    result = await execute_on_client(
        action="update",
        table="timelines",
        data={"id": timeline_id, "updates": updates},
    )
    row = result["row"]
    return f"Timeline updated: '{row['title']}' (id: {row['id']})"
|
||||
|
||||
|
||||
@tool
async def delete_timeline(timeline_id: str) -> str:
    """Delete a timeline permanently by its UUID."""
    target = {"id": timeline_id}
    # The client's response is not inspected; the confirmation echoes the id.
    await execute_on_client(action="delete", table="timelines", data=target)
    return f"Timeline {timeline_id} deleted."
|
||||
|
||||
|
||||
TIMELINE_TOOLS: list[Any] = [
|
||||
list_timelines,
|
||||
create_timeline,
|
||||
update_timeline,
|
||||
delete_timeline,
|
||||
]
|
||||
847
services/chat/app/deep_agent.py
Normal file
847
services/chat/app/deep_agent.py
Normal file
@@ -0,0 +1,847 @@
|
||||
"""Single-agent runners for home and floating chat contexts.
|
||||
|
||||
Adapted from app/core/deep_agent.py for the Chat Service.
|
||||
Import paths changed to use local app modules and shared/.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import date
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, Literal
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.agents.note_agent import NOTE_TOOLS
|
||||
from app.agents.project_agent import PROJECT_TOOLS
|
||||
from app.agents.task_agent import TASK_TOOLS
|
||||
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||
from app.llm import get_llm
|
||||
from app.memory_middleware import MemoryMiddleware
|
||||
from app.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
||||
from shared.db import async_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FloatingDomainType = Literal["task", "timeline", "project", "node"]
|
||||
FloatingDomainSection = Literal["task", "timeline", "note"]
|
||||
|
||||
_HOME_SINGLE_AGENT_SYSTEM = (
|
||||
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||
"Always use tools for factual data retrieval before answering. "
|
||||
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
||||
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||
"Return markdown and use tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
||||
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>. "
|
||||
"When listing tasks or timelines, each id tag must be on its own line with no prefix/suffix text. "
|
||||
"Never put titles, priorities, or dates on the same line as <task> or <timeline> tags. "
|
||||
"For questions about upcoming timelines (e.g. 'prossimi eventi'), include only future items in the current month unless the user asks a different range. "
|
||||
"For upcoming tasks, after tag lines add a short recommendation based on due date and priority."
|
||||
)
|
||||
|
||||
_FLOATING_SINGLE_AGENT_SYSTEM = (
|
||||
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||
"Stay focused on the floating scope in context.scope and answer concisely. "
|
||||
"Return plain text only. Do not output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed id tag wrappers. "
|
||||
"Always use tools for factual data retrieval before answering. "
|
||||
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
||||
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||
)
|
||||
|
||||
_FLOATING_DOMAIN_CLASSIFIER_SYSTEM = (
|
||||
"You are a strict domain classifier for websocket floating requests. "
|
||||
"Return ONLY a JSON object with keys: type, id, section. "
|
||||
"Allowed type values: task, timeline, project, node. "
|
||||
"Allowed section values: task, timeline, note, or null. "
|
||||
"Rules: infer from user message intent first; do not blindly trust scope.type. "
|
||||
"If user asks tasks/timeline/notes for a project, set type=project and section accordingly. "
|
||||
"If project id is unknown but context.resolved_project_id exists, use it as id. "
|
||||
"If id is unknown, use null. "
|
||||
"No markdown, no prose, JSON only."
|
||||
)
|
||||
|
||||
|
||||
def _as_text(content: Any) -> str:
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict):
|
||||
text = item.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
return "".join(parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _candidate_tokens(message: str) -> list[str]:
|
||||
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
|
||||
return [token for token in tokens if len(token) >= 3]
|
||||
|
||||
|
||||
async def _resolve_project_id_from_message(message: str) -> str | None:
    """Guess which project a message refers to and return its id.

    Fetches the project list from the client, scores each project by how many
    message tokens occur as substrings of its lowercased name, and returns the
    id of the single best-scoring project. Returns None when the fetch fails,
    the list is empty or malformed, nothing matches, the best score is tied,
    or the winning id is not a string.
    """
    try:
        response = await execute_on_client(action="select", table="projects")
    except Exception as exc:
        # Best effort: resolution is optional, so a failed fetch is only logged.
        logger.warning("deep_agent: project resolve select failed: %s", exc)
        return None

    projects = response.get("rows", [])
    if not (isinstance(projects, list) and projects):
        return None

    tokens = _candidate_tokens(message)
    matches: list[tuple[int, dict[str, Any]]] = []
    for project in projects:
        if not isinstance(project, dict):
            continue
        project_name = str(project.get("name", "")).lower()
        hits = sum(1 for token in tokens if token in project_name)
        if hits > 0:
            matches.append((hits, project))

    if not matches:
        return None

    # Accept the best score only when exactly one project reaches it.
    best = max(hits for hits, _ in matches)
    winners = [project for hits, project in matches if hits == best]
    if len(winners) != 1:
        return None

    resolved = winners[0].get("id")
    if isinstance(resolved, str):
        return resolved
    return None
|
||||
|
||||
|
||||
def _needs_project_resolution(message: str) -> bool:
|
||||
lowered = message.lower()
|
||||
return any(keyword in lowered for keyword in ["project", "progetto", "progetti", "whitelist"])
|
||||
|
||||
|
||||
async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of *context*, enriched with ``resolved_project_id``.

    The project lookup only runs when the message looks project-related, and
    the key is only added when a project id was actually resolved. The caller's
    dict is never mutated.
    """
    enriched = dict(context)
    if not _needs_project_resolution(message):
        return enriched

    project_id = await _resolve_project_id_from_message(message)
    if project_id:
        enriched["resolved_project_id"] = project_id
        logger.info("deep_agent: resolved_project_id=%s", project_id)
    return enriched
|
||||
|
||||
|
||||
def _all_tools() -> list[Any]:
    """Return a new list with every domain tool: task, project, note, then timeline."""
    combined: list[Any] = []
    for group in (TASK_TOOLS, PROJECT_TOOLS, NOTE_TOOLS, TIMELINE_TOOLS):
        combined.extend(group)
    return combined
|
||||
|
||||
|
||||
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
|
||||
debug = context.get("_debug")
|
||||
if isinstance(debug, dict):
|
||||
request_id = debug.get("request_id")
|
||||
if isinstance(request_id, str) and request_id:
|
||||
return request_id
|
||||
return None
|
||||
|
||||
|
||||
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
|
||||
sanitized = dict(context)
|
||||
sanitized.pop("_debug", None)
|
||||
return sanitized
|
||||
|
||||
|
||||
_TAG_LINE_RE = re.compile(r"<(task|timeline)>\[[^\]]+\]</\1>")
|
||||
_TIMELINE_DMY_RE = re.compile(r"(?P<d>\d{2})/(?P<m>\d{2})/(?P<y>\d{4})")
|
||||
|
||||
|
||||
def _is_upcoming_timeline_query(message: str) -> bool:
|
||||
lowered = message.lower()
|
||||
has_upcoming = "prossim" in lowered or "upcoming" in lowered or "next" in lowered
|
||||
has_timeline_topic = any(
|
||||
token in lowered
|
||||
for token in ("event", "evento", "eventi", "timeline", "milestone", "scaden")
|
||||
)
|
||||
return has_upcoming and has_timeline_topic
|
||||
|
||||
|
||||
def _timeline_date_in_current_month_or_future(dmy: str) -> bool:
|
||||
match = _TIMELINE_DMY_RE.search(dmy)
|
||||
if not match:
|
||||
return True
|
||||
try:
|
||||
parsed = date(
|
||||
int(match.group("y")),
|
||||
int(match.group("m")),
|
||||
int(match.group("d")),
|
||||
)
|
||||
except ValueError:
|
||||
return True
|
||||
|
||||
today = date.today()
|
||||
return parsed >= today and parsed.year == today.year and parsed.month == today.month
|
||||
|
||||
|
||||
def _normalize_tagged_list_lines(text: str, message: str) -> str:
|
||||
if not text:
|
||||
return text
|
||||
|
||||
upcoming_timeline_only = _is_upcoming_timeline_query(message)
|
||||
output_lines: list[str] = []
|
||||
|
||||
for line in text.splitlines():
|
||||
matches = list(_TAG_LINE_RE.finditer(line))
|
||||
if not matches:
|
||||
output_lines.append(line)
|
||||
continue
|
||||
|
||||
had_non_tag_text = _TAG_LINE_RE.sub("", line).strip(" -\t0123456789.*:)")
|
||||
if not had_non_tag_text and len(matches) == 1:
|
||||
tag_text = matches[0].group(0)
|
||||
if (
|
||||
upcoming_timeline_only
|
||||
and "<timeline>" in tag_text
|
||||
and not _timeline_date_in_current_month_or_future(line)
|
||||
):
|
||||
continue
|
||||
output_lines.append(tag_text)
|
||||
continue
|
||||
|
||||
for match in matches:
|
||||
tag_text = match.group(0)
|
||||
if (
|
||||
upcoming_timeline_only
|
||||
and "<timeline>" in tag_text
|
||||
and not _timeline_date_in_current_month_or_future(line)
|
||||
):
|
||||
continue
|
||||
output_lines.append(tag_text)
|
||||
|
||||
return "\n".join(output_lines)
|
||||
|
||||
|
||||
_GENERIC_TAG_RE = re.compile(r"</?(task|project|note|timeline|chart)>", re.IGNORECASE)
|
||||
_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]")
|
||||
_FLOATING_EMPTY_FALLBACK = "No results found."
|
||||
|
||||
|
||||
def _strip_floating_markup_fragment(text: str) -> str:
|
||||
if not text:
|
||||
return text
|
||||
cleaned = _GENERIC_TAG_RE.sub("", text)
|
||||
return _BRACKETED_ID_RE.sub("", cleaned)
|
||||
|
||||
|
||||
def _strip_floating_markup(text: str) -> str:
|
||||
"""Ensure floating responses stay plain text with no XML-like tag wrappers."""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
cleaned = _strip_floating_markup_fragment(text)
|
||||
lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()]
|
||||
return "\n".join(line for line in lines if line)
|
||||
|
||||
|
||||
def _fallback_from_raw_floating_text(raw_text: str) -> str:
    """Build a last-resort floating answer from the raw model text.

    Markup is stripped and space/tab runs are collapsed; if nothing is left the
    fixed empty-result message is returned instead.
    """
    stripped = _strip_floating_markup_fragment(raw_text if raw_text else "")
    collapsed = re.sub(r"[ \t]{2,}", " ", stripped).strip()
    if collapsed:
        return collapsed
    return _FLOATING_EMPTY_FALLBACK
|
||||
|
||||
|
||||
class _FloatingStreamSanitizer:
|
||||
"""Streaming sanitizer that removes floating markup without buffering the full answer."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pending = ""
|
||||
|
||||
@staticmethod
|
||||
def _split_safe_boundary(text: str) -> tuple[str, str]:
|
||||
boundary = len(text)
|
||||
|
||||
last_lt = text.rfind("<")
|
||||
if last_lt != -1 and ">" not in text[last_lt:]:
|
||||
boundary = min(boundary, last_lt)
|
||||
|
||||
last_lb = text.rfind("[")
|
||||
if last_lb != -1 and "]" not in text[last_lb:]:
|
||||
boundary = min(boundary, last_lb)
|
||||
|
||||
if boundary == len(text):
|
||||
return text, ""
|
||||
return text[:boundary], text[boundary:]
|
||||
|
||||
def feed(self, chunk: str) -> str:
|
||||
combined = f"{self._pending}{chunk}"
|
||||
safe_text, self._pending = self._split_safe_boundary(combined)
|
||||
return _strip_floating_markup_fragment(safe_text)
|
||||
|
||||
def finalize(self) -> str:
|
||||
tail = re.sub(r"<[^>\n]*$", "", self._pending)
|
||||
tail = re.sub(r"\[[^\]\n]*$", "", tail)
|
||||
self._pending = ""
|
||||
return _strip_floating_markup_fragment(tail)
|
||||
|
||||
|
||||
def _normalize_memory_label(path_or_label: str) -> str:
|
||||
value = path_or_label.strip()
|
||||
if value.startswith("/memories/"):
|
||||
value = value[len("/memories/"):]
|
||||
value = value.strip("/")
|
||||
return value
|
||||
|
||||
|
||||
def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
    """Build the memory tools bound to one user.

    Each tool is a closure over *user_id* / *trace_id*, opens its own
    ``async_session()`` and delegates to ``MemoryMiddleware``. Tool names,
    parameter names and the nested docstrings are deliberately left verbatim:
    they are presumably what the ``@tool`` decorator exposes to the model.

    Returns:
        The nine tool objects, core-memory tools first, then archival and
        recall search.
    """
    trace_label = trace_id or "-"

    @tool
    async def memory_list_blocks() -> str:
        """List all core memory blocks currently stored for the user."""
        logger.info("deep_agent: memory_list_blocks trace=%s user=%s", trace_label, user_id)
        async with async_session() as session:
            blocks = await MemoryMiddleware(session).list_core_blocks(user_id)
            if blocks:
                rendered = "\n".join(f"- {block['label']}: {block['value']}" for block in blocks)
                return "Memory blocks:\n" + rendered
            return "No memory blocks found."

    @tool
    async def memory_get(path_or_label: str) -> str:
        """Get one memory block by label or /memories/<label> path."""
        label = _normalize_memory_label(path_or_label)
        logger.info("deep_agent: memory_get trace=%s user=%s label=%s", trace_label, user_id, label)
        if not label:
            return "Invalid memory label."
        async with async_session() as session:
            stored = await MemoryMiddleware(session).get_core_block(user_id, label)
            if stored is not None:
                return f"Memory block '{label}':\n{stored}"
            return f"Memory block '{label}' not found."

    @tool
    async def memory_create(path_or_label: str, value: str) -> str:
        """Create or overwrite a memory block value by label or /memories/<label> path."""
        label = _normalize_memory_label(path_or_label)
        logger.info("deep_agent: memory_create trace=%s user=%s label=%s", trace_label, user_id, label)
        if not label:
            return "Invalid memory label."
        async with async_session() as session:
            await MemoryMiddleware(session).update_core(user_id, label, value, trace_id=trace_id)
            return f"Memory block '{label}' saved."

    @tool
    async def memory_append(path_or_label: str, content: str) -> str:
        """Append content to a memory block, creating it if missing."""
        label = _normalize_memory_label(path_or_label)
        logger.info("deep_agent: memory_append trace=%s user=%s label=%s", trace_label, user_id, label)
        if not label:
            return "Invalid memory label."
        async with async_session() as session:
            await MemoryMiddleware(session).append_core(user_id, label, content)
            return f"Memory block '{label}' appended."

    @tool
    async def memory_replace(path_or_label: str, old_string: str, new_string: str) -> str:
        """Replace one exact string in a memory block."""
        label = _normalize_memory_label(path_or_label)
        logger.info("deep_agent: memory_replace trace=%s user=%s label=%s", trace_label, user_id, label)
        if not label:
            return "Invalid memory label."
        async with async_session() as session:
            replaced = await MemoryMiddleware(session).replace_core(user_id, label, old_string, new_string)
            if replaced:
                return f"Memory block '{label}' updated."
            return f"No replacement made in '{label}' (old string not found)."

    @tool
    async def memory_delete(path_or_label: str) -> str:
        """Delete a memory block by label or /memories/<label> path."""
        label = _normalize_memory_label(path_or_label)
        logger.info("deep_agent: memory_delete trace=%s user=%s label=%s", trace_label, user_id, label)
        if not label:
            return "Invalid memory label."
        async with async_session() as session:
            removed = await MemoryMiddleware(session).delete_core(user_id, label)
            if removed:
                return f"Memory block '{label}' deleted."
            return f"Memory block '{label}' not found."

    @tool
    async def archival_memory_insert(content: str) -> str:
        """Insert a long-term archival memory entry."""
        logger.info("deep_agent: archival_memory_insert trace=%s user=%s", trace_label, user_id)
        async with async_session() as session:
            await MemoryMiddleware(session).insert_archival(user_id, content, source="assistant")
            return "Archival memory saved."

    @tool
    async def archival_memory_search(query: str, top_k: int = 5) -> str:
        """Search long-term archival memory by semantic fallback (keyword currently)."""
        logger.info("deep_agent: archival_memory_search trace=%s user=%s query=%s", trace_label, user_id, query[:80])
        async with async_session() as session:
            hits = await MemoryMiddleware(session).search_archival(user_id, query, top_k=top_k)
            if hits:
                return "Archival memory results:\n" + "\n".join(f"- {hit}" for hit in hits)
            return "No archival memory results found."

    @tool
    async def conversation_search(query: str, top_k: int = 5) -> str:
        """Search recall memory from prior episodic conversation summaries."""
        logger.info("deep_agent: conversation_search trace=%s user=%s query=%s", trace_label, user_id, query[:80])
        async with async_session() as session:
            hits = await MemoryMiddleware(session).search_recall(user_id, query, top_k=top_k)
            if hits:
                return "Recall memory results:\n" + "\n".join(f"- {hit}" for hit in hits)
            return "No recall memory results found."

    return [
        memory_list_blocks,
        memory_get,
        memory_create,
        memory_append,
        memory_replace,
        memory_delete,
        archival_memory_insert,
        archival_memory_search,
        conversation_search,
    ]
|
||||
|
||||
|
||||
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
    """Return the domain tools followed by the memory tools bound to *user_id*."""
    toolbox = list(_all_tools())
    toolbox.extend(_memory_tools(user_id, trace_id))
    return toolbox
|
||||
|
||||
|
||||
def _detect_domain_section(message: str) -> FloatingDomainSection | None:
|
||||
lowered = message.lower()
|
||||
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
|
||||
return "timeline"
|
||||
if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]):
|
||||
return "task"
|
||||
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
|
||||
return "note"
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]:
|
||||
type_raw = str(payload.get("type") or "").strip().lower()
|
||||
domain_type: FloatingDomainType = "task"
|
||||
if type_raw in {"task", "timeline", "project", "node"}:
|
||||
domain_type = type_raw
|
||||
|
||||
id_value = payload.get("id")
|
||||
domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None
|
||||
if domain_type == "project" and not domain_id:
|
||||
domain_id = fallback_id
|
||||
|
||||
section_raw = payload.get("section")
|
||||
section: FloatingDomainSection | None = None
|
||||
if isinstance(section_raw, str):
|
||||
section_candidate = section_raw.strip().lower()
|
||||
if section_candidate in {"task", "timeline", "note"}:
|
||||
section = section_candidate
|
||||
|
||||
if domain_type != "project":
|
||||
section = None
|
||||
|
||||
return {
|
||||
"type": domain_type,
|
||||
"id": domain_id,
|
||||
"section": section,
|
||||
}
|
||||
|
||||
|
||||
def _parse_json_object(text: str) -> dict[str, Any] | None:
|
||||
raw = text.strip()
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
return parsed if isinstance(parsed, dict) else None
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
||||
if not match:
|
||||
return None
|
||||
try:
|
||||
parsed = json.loads(match.group(0))
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
return parsed if isinstance(parsed, dict) else None
|
||||
|
||||
|
||||
def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]:
    """Infer the floating domain from the context scope and message keywords.

    Precedence:
      1. An explicit ``context["scope"]`` of a known type (task, project,
         note, timeline; singular or plural) wins. Notes map to type "node".
      2. Otherwise a project keyword in the message, or a resolved project id,
         selects "project".
      3. Otherwise the detected section selects "timeline" or "node".
      4. Otherwise the result is "task".

    Only project results carry a ``section``.
    """
    section = _detect_domain_section(message)
    has_context = isinstance(context, dict)
    scope = context.get("scope") if has_context else None
    resolved = context.get("resolved_project_id") if has_context else None
    project_id = resolved if isinstance(resolved, str) and resolved else None

    if isinstance(scope, dict):
        scope_kind = str(scope.get("type") or "").strip().lower()
        raw_scope_id = scope.get("id")
        scope_id = raw_scope_id if isinstance(raw_scope_id, str) and raw_scope_id else None

        if scope_kind in ("project", "projects"):
            return {"type": "project", "id": scope_id or project_id, "section": section}

        scoped_type = {
            "task": "task",
            "tasks": "task",
            "note": "node",
            "notes": "node",
            "timeline": "timeline",
            "timelines": "timeline",
        }.get(scope_kind)
        if scoped_type is not None:
            return {"type": scoped_type, "id": scope_id, "section": None}
        # Unknown scope type: fall through to the keyword rules.

    text = message.lower()
    mentions_project = any(word in text for word in ("project", "progetto", "client"))
    if mentions_project or project_id:
        return {"type": "project", "id": project_id, "section": section}

    fallback_type = {"timeline": "timeline", "note": "node"}.get(section, "task")
    return {"type": fallback_type, "id": None, "section": None}
|
||||
|
||||
|
||||
async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[str, str | None]:
    """Classify the floating domain with the LLM, falling back to keyword rules.

    The model receives the message plus a reduced context (scope and resolved
    project id) and is expected to answer with a JSON object. Any exception in
    the LLM path, or a non-JSON answer, is logged as a warning and the
    rule-based inference is used instead.
    """
    resolved = context.get("resolved_project_id") if isinstance(context, dict) else None
    project_id = resolved if isinstance(resolved, str) and resolved else None

    scope_value = context.get("scope")
    classifier_context = {
        "scope": scope_value if isinstance(scope_value, dict) else None,
        "resolved_project_id": project_id,
    }
    prompt = (
        f"Message:\n{message}\n\n"
        f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}"
    )

    try:
        classifier = get_llm()
        reply = await classifier.ainvoke(
            [
                SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_SYSTEM),
                HumanMessage(content=prompt),
            ]
        )
        payload = _parse_json_object(_as_text(reply.content))
        if payload is None:
            logger.warning("deep_agent: floating_domain classifier returned non-json output")
        else:
            domain = _normalize_domain_payload(payload, project_id)
            logger.info(
                "deep_agent: floating_domain_classified type=%s id=%s section=%s",
                domain.get("type"),
                domain.get("id"),
                domain.get("section"),
            )
            return domain
    except Exception as exc:
        logger.warning("deep_agent: floating_domain classifier failed: %s", exc)

    return _infer_floating_domain_rule_based(message, context)
|
||||
|
||||
|
||||
async def _run_single_agent(
    *,
    user_id: str,
    system_prompt: str,
    message: str,
    context: dict[str, Any],
    max_steps: int = 6,
) -> str:
    """Run the tool-calling agent loop and return the final answer text.

    The tool-bound model is invoked up to *max_steps* times. A reply without
    tool calls ends the loop and its text is returned; otherwise every
    requested tool is executed and its output appended as a ``ToolMessage``.
    If the step budget runs out, one last call is made to the model without
    tools to force a textual answer.

    The tool-result collector is installed for the duration of the run and
    always cleared afterwards.
    """
    trace_id = _trace_id_from_context(context)
    llm = get_llm()
    tools = _all_tools_for_user(user_id, trace_id)
    model_context = _context_for_model(context)
    logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
    llm_with_tools = llm.bind_tools(tools)

    # The serialized context is capped at 3500 characters.
    context_json = json.dumps({'context': model_context}, ensure_ascii=True)[:3500]
    messages: list[Any] = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=f"User message:\n{message}\n\nContext:\n{context_json}"),
    ]

    tool_calls_count = 0
    collected: list[dict[str, Any]] = []
    set_tool_result_collector(collected)
    try:
        tools_by_name = {tool_def.name: tool_def for tool_def in tools}
        for _step in range(max_steps):
            reply: AIMessage = await llm_with_tools.ainvoke(messages)
            messages.append(reply)

            if not reply.tool_calls:
                answer = _as_text(reply.content)
                logger.info(
                    "deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d",
                    trace_id or "-",
                    user_id,
                    tool_calls_count,
                    len(answer),
                )
                return answer

            for request in reply.tool_calls:
                tool_calls_count += 1
                request_id = str(request.get("id", ""))
                tool_name = str(request.get("name", ""))
                tool_args = request.get("args", {})
                logger.info(
                    "deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
                    request_id,
                    tool_name,
                    json.dumps(tool_args, ensure_ascii=True)[:800],
                )

                handler = tools_by_name.get(tool_name)
                if handler is not None:
                    tool_output = await handler.ainvoke(tool_args)
                else:
                    # Report the problem to the model instead of failing the run.
                    tool_output = f"Unknown tool: {tool_name}"

                logger.info(
                    "deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
                    request_id,
                    tool_name,
                    str(tool_output)[:1200],
                )
                messages.append(ToolMessage(content=str(tool_output), tool_call_id=request["id"]))

        # Step budget exhausted: ask the plain model (no tools) for an answer.
        closing = await llm.ainvoke(messages)
        answer = _as_text(closing.content)
        logger.info(
            "deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
            trace_id or "-",
            user_id,
            tool_calls_count,
            len(answer),
        )
        return answer
    finally:
        clear_tool_result_collector()
|
||||
|
||||
|
||||
async def _run_single_agent_stream(
    *,
    user_id: str,
    system_prompt: str,
    message: str,
    context: dict[str, Any],
    max_steps: int = 6,
) -> AsyncGenerator[tuple[str, Any], None]:
    """Run the tool-calling agent loop and yield the answer as ``("token", text)`` events.

    Tool rounds work as in the non-streaming runner. Once the tool-bound model
    replies without tool calls, the answer is streamed from the plain model;
    if that stream yields no text, the text of the non-streamed reply is
    emitted as a single token instead. When the step budget runs out, the plain
    model is streamed once over the accumulated messages.

    The tool-result collector is installed for the duration of the run and
    always cleared afterwards.
    """
    trace_id = _trace_id_from_context(context)
    llm = get_llm()
    tools = _all_tools_for_user(user_id, trace_id)
    model_context = _context_for_model(context)
    logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
    llm_with_tools = llm.bind_tools(tools)

    # The serialized context is capped at 3500 characters.
    context_json = json.dumps({'context': model_context}, ensure_ascii=True)[:3500]
    messages: list[Any] = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=f"User message:\n{message}\n\nContext:\n{context_json}"),
    ]

    tool_calls_count = 0
    streamed_chars = 0
    collected: list[dict[str, Any]] = []
    set_tool_result_collector(collected)
    try:
        tools_by_name = {tool_def.name: tool_def for tool_def in tools}
        for _step in range(max_steps):
            reply: AIMessage = await llm_with_tools.ainvoke(messages)
            messages.append(reply)

            if not reply.tool_calls:
                # NOTE(review): `messages` already ends with `reply`, so this
                # looks like a second generation of the final answer - confirm
                # that the extra model call is intended.
                got_stream_text = False
                async for piece in llm.astream(messages):
                    text_piece = _as_text(getattr(piece, "content", ""))
                    if not text_piece:
                        continue
                    streamed_chars += len(text_piece)
                    got_stream_text = True
                    yield "token", text_piece

                if not got_stream_text:
                    whole_answer = _as_text(reply.content)
                    if whole_answer:
                        streamed_chars += len(whole_answer)
                        yield "token", whole_answer

                logger.info(
                    "deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
                    trace_id or "-",
                    user_id,
                    tool_calls_count,
                    streamed_chars,
                )
                return

            for request in reply.tool_calls:
                tool_calls_count += 1
                request_id = str(request.get("id", ""))
                tool_name = str(request.get("name", ""))
                tool_args = request.get("args", {})
                logger.info(
                    "deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
                    request_id,
                    tool_name,
                    json.dumps(tool_args, ensure_ascii=True)[:800],
                )

                handler = tools_by_name.get(tool_name)
                if handler is not None:
                    tool_output = await handler.ainvoke(tool_args)
                else:
                    # Report the problem to the model instead of failing the run.
                    tool_output = f"Unknown tool: {tool_name}"

                logger.info(
                    "deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
                    request_id,
                    tool_name,
                    str(tool_output)[:1200],
                )
                messages.append(ToolMessage(content=str(tool_output), tool_call_id=request["id"]))

        # Step budget exhausted: stream an answer from the plain model.
        async for piece in llm.astream(messages):
            text_piece = _as_text(getattr(piece, "content", ""))
            if text_piece:
                streamed_chars += len(text_piece)
                yield "token", text_piece
        logger.info(
            "deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
            trace_id or "-",
            user_id,
            tool_calls_count,
            streamed_chars,
        )
    finally:
        clear_tool_result_collector()
|
||||
|
||||
|
||||
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
    """Answer a home-chat message and return the text with normalized tag lines."""
    enriched_context = await _prepare_context(message, context)
    raw_answer = await _run_single_agent(
        user_id=user_id,
        system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
        message=message,
        context=enriched_context,
    )
    return _normalize_tagged_list_lines(raw_answer, message)
|
||||
|
||||
|
||||
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]:
    """Answer a floating-chat message.

    Returns:
        ``(plain_text_answer, domain)``. The answer has all markup stripped;
        if stripping leaves nothing from a non-empty reply, the raw-text
        fallback is used instead.
    """
    enriched_context = await _prepare_context(message, context)
    domain = await _infer_floating_domain(message, enriched_context)
    raw_answer = await _run_single_agent(
        user_id=user_id,
        system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
        message=message,
        context=enriched_context,
    )
    plain_answer = _strip_floating_markup(raw_answer)
    if raw_answer and not plain_answer:
        plain_answer = _fallback_from_raw_floating_text(raw_answer)
    return plain_answer, domain
|
||||
|
||||
|
||||
async def run_home_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Stream a home-chat answer.

    Non-token events are forwarded immediately. Token text is buffered until
    the agent finishes, because tag-line normalization needs the complete
    answer; the normalized text is then emitted as one ``("token", text)``
    event (nothing is emitted when it is empty).
    """
    enriched_context = await _prepare_context(message, context)
    agent_events = _run_single_agent_stream(
        user_id=user_id,
        system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
        message=message,
        context=enriched_context,
    )

    buffered: list[str] = []
    async for event in agent_events:
        kind, payload = event
        if kind == "token":
            buffered.append(str(payload or ""))
        else:
            yield event

    final_text = _normalize_tagged_list_lines("".join(buffered), message)
    if final_text:
        yield "token", final_text
|
||||
|
||||
|
||||
async def run_floating_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Stream a floating-chat answer as sanitized plain text.

    Emits one ``("floating_domain", domain)`` event first, then forwards
    non-token events unchanged and token events after incremental markup
    removal. If sanitizing swallowed every token although the model produced
    text, a single fallback token built from the raw text is emitted.
    """
    enriched_context = await _prepare_context(message, context)
    domain = await _infer_floating_domain(message, enriched_context)
    yield "floating_domain", domain

    sanitizer = _FloatingStreamSanitizer()
    raw_parts: list[str] = []
    anything_emitted = False
    agent_events = _run_single_agent_stream(
        user_id=user_id,
        system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
        message=message,
        context=enriched_context,
    )
    async for event in agent_events:
        kind, payload = event
        if kind != "token":
            yield event
            continue

        raw_piece = str(payload or "")
        raw_parts.append(raw_piece)
        clean_piece = sanitizer.feed(raw_piece)
        if clean_piece:
            anything_emitted = True
            yield "token", clean_piece

    # Flush whatever the sanitizer was still holding back.
    remainder = sanitizer.finalize()
    if remainder:
        anything_emitted = True
        yield "token", remainder

    if raw_parts and not anything_emitted:
        yield "token", _fallback_from_raw_floating_text("".join(raw_parts))
|
||||
|
||||
|
||||
async def update_core_memory(user_id: str, key: str, value: str) -> None:
    """Store *value* under core-memory block *key* for *user_id*.

    Thin compatibility wrapper: opens a session and delegates to
    ``MemoryMiddleware.update_core``.
    """
    async with async_session() as session:
        await MemoryMiddleware(session).update_core(user_id, key, value)
|
||||
77
services/chat/app/llm.py
Normal file
77
services/chat/app/llm.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||
|
||||
Adapted from app/core/llm.py for the Chat Service.
|
||||
Uses shared.config.settings instead of app.config.settings.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
import litellm
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_litellm import ChatLiteLLM
|
||||
|
||||
from shared.config import settings
|
||||
|
||||
# LiteLLM setting `drop_params`; presumably makes LiteLLM drop request
# parameters a provider does not support - confirm against the LiteLLM docs.
litellm.drop_params = True

# Silence one specific pydantic serialization UserWarning (matched by message).
warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
)
|
||||
|
||||
|
||||
def _api_key_for_model(model: str) -> str | None:
    """Return the configured API key for the provider prefix of *model*.

    ``anthropic/``, ``gemini/``/``google/`` and ``cerebras/`` map to their own
    settings; ``github_copilot/`` never uses a key; everything else uses the
    OpenAI key. An empty setting is reported as ``None``.
    """
    if model.startswith("anthropic/"):
        key = settings.ANTHROPIC_API_KEY
    elif model.startswith(("gemini/", "google/")):
        key = settings.GOOGLE_API_KEY
    elif model.startswith("cerebras/"):
        key = settings.CEREBRAS_API_KEY
    elif model.startswith("github_copilot/"):
        return None
    else:
        key = settings.OPENAI_API_KEY
    return key or None
|
||||
|
||||
|
||||
def get_llm(
    *,
    model: str | None = None,
    temperature: float = 0,
) -> ChatOpenAI | ChatLiteLLM:
    """Create a chat model for *model* (default: ``settings.LLM_MODEL``).

    Model names containing a ``/`` (provider-prefixed) are served through
    ``ChatLiteLLM``; bare names go to ``ChatOpenAI`` with the matching API key.
    When configured, the GitHub Copilot token directory is exported to the
    environment without overriding an existing value.
    """
    chosen_model = model or settings.LLM_MODEL

    copilot_dir = settings.GITHUB_COPILOT_TOKEN_DIR
    if copilot_dir:
        os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", copilot_dir)

    if "/" not in chosen_model:
        return ChatOpenAI(
            model=chosen_model,
            temperature=temperature,
            api_key=_api_key_for_model(chosen_model),
        )
    return ChatLiteLLM(model=chosen_model, temperature=temperature)
|
||||
|
||||
|
||||
def get_router_llm(
    *,
    temperature: float = 0,
) -> ChatOpenAI | ChatLiteLLM:
    """Create the chat model configured as ``settings.LLM_ROUTER_MODEL``."""
    router_model = settings.LLM_ROUTER_MODEL
    return get_llm(model=router_model, temperature=temperature)
|
||||
|
||||
|
||||
async def embed(text: str) -> list[float]:
    """Return the embedding vector of *text* using ``settings.LLM_EMBED_MODEL``.

    Provider-prefixed model names (containing ``/``, which includes every
    ``github_copilot/`` model) go through LiteLLM; bare names use the OpenAI
    client directly.
    """
    embed_model = settings.LLM_EMBED_MODEL

    if "/" not in embed_model:
        openai_client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
        result = await openai_client.embeddings.create(model=embed_model, input=text)
        return result.data[0].embedding

    result = await litellm.aembedding(model=embed_model, input=[text])
    return result.data[0]["embedding"]
|
||||
71
services/chat/app/main.py
Normal file
71
services/chat/app/main.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Chat Service — LLM orchestration, domain agents, memory.
|
||||
|
||||
Consumes chat requests from Redis, executes deep_agent (home/floating),
|
||||
streams responses back via Redis pub/sub to WS Gateway.
|
||||
|
||||
Owns: memory_core, memory_associative, memory_episodic, memory_proactive tables.
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from shared.config import settings
|
||||
|
||||
# Root logger: INFO level, "time level logger: message" lines.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    level=logging.INFO,
)
# Keep SQLAlchemy engine and pool loggers at WARNING.
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: run the Redis consumer while the app is up.

    Startup starts the consumer task. Shutdown cancels it, waits for it to
    finish, disposes the database engine and closes the Redis client. The
    shutdown steps sit in ``finally`` blocks so they also run when an error is
    propagated into the context manager, and a failing ``engine.dispose()``
    does not prevent the Redis client from being closed.
    """
    import asyncio

    # Imported lazily, as in the rest of this module.
    from app.redis_consumer import start_consumer

    consumer_task = start_consumer()
    try:
        yield
    finally:
        consumer_task.cancel()
        try:
            # Wait for the cancellation to complete so the consumer is not
            # still running while its connections are being closed.
            await consumer_task
        except asyncio.CancelledError:
            pass
        except Exception:
            logging.getLogger(__name__).exception("chat: redis consumer stopped with an error")

        from shared.db import engine
        from shared.redis import redis_client

        try:
            await engine.dispose()
        finally:
            await redis_client.aclose()
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
    """Build the Chat Service FastAPI application.

    Interactive docs are served at ``/docs`` only when ``settings.ENV`` is
    ``"dev"``. CORS is configured from ``settings.CORS_ORIGINS``, the service
    router is mounted under ``/api/v1`` and a health endpoint is added.
    """
    docs_url = "/docs" if settings.ENV == "dev" else None
    application = FastAPI(
        title="Adiuva Chat Service",
        version="0.1.0",
        docs_url=docs_url,
        redoc_url=None,
        lifespan=lifespan,
    )

    application.add_middleware(
        CORSMiddleware,
        allow_origins=settings.CORS_ORIGINS,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Imported here rather than at module import time, as in the original.
    from app.routes import router

    application.include_router(router, prefix="/api/v1")

    @application.get("/api/v1/health", tags=["health"])
    async def health() -> dict:
        return {"status": "ok", "service": "chat", "version": application.version}

    return application
|
||||
|
||||
|
||||
# Module-level application instance, built once at import time.
app = create_app()
|
||||
295
services/chat/app/memory_middleware.py
Normal file
295
services/chat/app/memory_middleware.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""Memory Middleware — adapted for Chat Service.
|
||||
|
||||
Uses shared.models instead of app.models. Otherwise identical to the
|
||||
monolith's app/core/memory_middleware.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from shared.models import (
|
||||
MemoryAssociative,
|
||||
MemoryCore,
|
||||
MemoryEpisodic,
|
||||
MemoryProactive,
|
||||
User,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ASSOCIATIVE_TOP_K = 5
|
||||
_EPISODIC_RECENT_N = 10
|
||||
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
||||
|
||||
|
||||
class MemoryMiddleware:
    """Encrypted per-user memory access on top of an ``AsyncSession``.

    Four kinds of rows are handled: core key/value blocks (``MemoryCore``),
    associative/archival entries (``MemoryAssociative``), episodic
    conversation summaries (``MemoryEpisodic``) and proactive patterns
    (``MemoryProactive``).  Payload columns are encrypted with a Fernet key
    taken from ``User.encryption_key``.  When the user has no usable key,
    read methods return empty results and write methods do nothing.
    """

    # Maximum number of rows scanned by the substring searches.
    _SEARCH_SCAN_LIMIT = 100

    def __init__(self, db: AsyncSession) -> None:
        self._db = db

    # ── Context enrichment ────────────────────────────────────────────

    async def enrich_context(
        self,
        user_id: str,
        message: str,
        trace_id: str | None = None,
        session_id: str | None = None,
    ) -> dict[str, Any]:
        """Load every memory tier for *user_id*.

        Returns ``{}`` when no Fernet key is available; otherwise a dict
        with the keys ``core_memory``, ``associative_memory``,
        ``episodic_memory`` and ``proactive_hints``.  ``trace_id`` is used
        for logging only; ``session_id`` narrows the episodic lookup.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return {}

        core = await self._load_core(user_id, fernet)
        associative = await self._load_associative(user_id, message, fernet)
        episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
        proactive = await self._load_proactive(user_id, fernet)

        logger.info(
            "memory: enrich_context trace=%s user=%s core=%d assoc=%d episodic=%d proactive=%d",
            trace_id or "-", user_id, len(core), len(associative), len(episodic), len(proactive),
        )

        return {
            "core_memory": core,
            "associative_memory": associative,
            "episodic_memory": episodic,
            "proactive_hints": proactive,
        }

    # ── Episodic ──────────────────────────────────────────────────────

    async def store_episode(
        self, user_id: str, session_id: str, message: str, response: str,
        trace_id: str | None = None,
    ) -> None:
        """Persist an encrypted summary of one exchange.

        The summary keeps at most the first 200 characters of *message* and
        of *response*.  Commit failures are logged and rolled back, not
        raised.  ``trace_id`` is accepted for interface compatibility and is
        not used.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return

        summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
        self._db.add(MemoryEpisodic(
            id=str(uuid.uuid4()),
            user_id=user_id,
            summary_encrypted=_encrypt(fernet, summary),
            session_id=session_id,
        ))
        await self._commit("memory: store_episode failed user=%s: %s", user_id)

    # ── Core blocks ───────────────────────────────────────────────────

    async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
        """Create or overwrite the core block *key* with the encrypted *value*.

        Commit failures are logged and rolled back, not raised.  ``trace_id``
        is accepted for interface compatibility and is not used.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return

        encrypted = _encrypt(fernet, value)
        existing = await self._find_core_row(user_id, key)
        if existing is not None:
            existing.value_encrypted = encrypted
        else:
            self._db.add(MemoryCore(
                id=str(uuid.uuid4()), user_id=user_id, key=key, value_encrypted=encrypted,
            ))
        await self._commit("memory: update_core failed user=%s key=%s: %s", user_id, key)

    async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
        """Return all decryptable core blocks as ``{"label", "value"}`` dicts, ordered by key."""
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return []
        result = await self._db.execute(
            select(MemoryCore).where(MemoryCore.user_id == user_id).order_by(MemoryCore.key.asc())
        )
        blocks: list[dict[str, str]] = []
        for row in result.scalars().all():
            plaintext = _safe_decrypt(fernet, row.value_encrypted)
            if plaintext is not None:
                blocks.append({"label": row.key, "value": plaintext})
        return blocks

    async def get_core_block(self, user_id: str, label: str) -> str | None:
        """Return the decrypted core block *label*, or ``None`` if missing or undecryptable."""
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return None
        row = await self._find_core_row(user_id, label)
        if row is None:
            return None
        return _safe_decrypt(fernet, row.value_encrypted)

    async def delete_core(self, user_id: str, label: str) -> bool:
        """Delete the core block *label*.

        Returns ``True`` when a row was deleted and committed, ``False`` when
        the row does not exist or the commit failed (logged and rolled back).
        No encryption key is required for deletion.
        """
        row = await self._find_core_row(user_id, label)
        if row is None:
            return False
        await self._db.delete(row)
        return await self._commit(
            "memory: delete_core failed user=%s label=%s: %s", user_id, label
        )

    async def append_core(self, user_id: str, label: str, content: str) -> None:
        """Append *content* on a new line to block *label*, creating the block if absent."""
        current = await self.get_core_block(user_id, label)
        if current is None:
            await self.update_core(user_id, label, content)
            return
        await self.update_core(user_id, label, f"{current}\n{content}")

    async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
        """Replace the first occurrence of *old* with *new* in block *label*.

        Returns ``False`` when the block is missing or does not contain *old*.
        """
        current = await self.get_core_block(user_id, label)
        if current is None or old not in current:
            return False
        await self.update_core(user_id, label, current.replace(old, new, 1))
        return True

    # ── Archival / recall ─────────────────────────────────────────────

    async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
        """Store *content* as an encrypted associative row tagged with *source*.

        The row is written without an embedding or entity id.  Commit
        failures are logged and rolled back, not raised.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return
        self._db.add(MemoryAssociative(
            id=str(uuid.uuid4()), user_id=user_id,
            content_encrypted=_encrypt(fernet, content), embedding=None,
            entity_type=source, entity_id=None,
        ))
        await self._commit("memory: insert_archival failed user=%s: %s", user_id)

    async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
        """Case-insensitive substring search over the most recently updated archival rows.

        At most ``_SEARCH_SCAN_LIMIT`` rows are scanned and at most
        ``max(top_k, 1)`` matches are returned.  A blank *query* matches
        every decryptable row.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return []
        result = await self._db.execute(
            select(MemoryAssociative).where(MemoryAssociative.user_id == user_id)
            .order_by(MemoryAssociative.updated_at.desc()).limit(self._SEARCH_SCAN_LIMIT)
        )
        ciphertexts = [row.content_encrypted for row in result.scalars().all()]
        return self._matching_plaintexts(fernet, ciphertexts, query, top_k)

    async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
        """Case-insensitive substring search over the newest episodic summaries.

        At most ``_SEARCH_SCAN_LIMIT`` rows are scanned and at most
        ``max(top_k, 1)`` matches are returned.  A blank *query* matches
        every decryptable row.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return []
        result = await self._db.execute(
            select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
            .order_by(MemoryEpisodic.created_at.desc()).limit(self._SEARCH_SCAN_LIMIT)
        )
        ciphertexts = [row.summary_encrypted for row in result.scalars().all()]
        return self._matching_plaintexts(fernet, ciphertexts, query, top_k)

    # ── Private ───────────────────────────────────────────────────────

    async def _get_fernet(self, user_id: str) -> Fernet | None:
        """Build the user's Fernet cipher, or return ``None`` if no usable key exists."""
        result = await self._db.execute(select(User).where(User.id == user_id))
        user = result.scalar_one_or_none()
        if user is None or not user.encryption_key:
            logger.warning("memory: no encryption_key for user=%s", user_id)
            return None
        try:
            return Fernet(user.encryption_key.encode())
        except ValueError:
            # Fernet rejects keys that are not 32 url-safe base64-encoded
            # bytes.  Treat a malformed key like a missing one so a single
            # bad row degrades to "no memory" rather than failing the request.
            logger.error("memory: invalid encryption_key for user=%s", user_id)
            return None

    async def _find_core_row(self, user_id: str, key: str) -> MemoryCore | None:
        """Fetch the core row for (*user_id*, *key*), or ``None``."""
        result = await self._db.execute(
            select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == key)
        )
        return result.scalar_one_or_none()

    async def _commit(self, failure_msg: str, *log_args: object) -> bool:
        """Commit the session; on failure log *failure_msg*, roll back and return ``False``.

        The caught exception is appended to *log_args* as the final
        ``%s`` argument of *failure_msg*.
        """
        try:
            await self._db.commit()
        except Exception as exc:
            logger.error(failure_msg, *log_args, exc)
            await self._db.rollback()
            return False
        return True

    @staticmethod
    def _decrypt_all(fernet: Fernet, ciphertexts: list[str]) -> list[str]:
        """Decrypt *ciphertexts* in order, silently dropping entries that fail."""
        decrypted = (_safe_decrypt(fernet, ciphertext) for ciphertext in ciphertexts)
        return [plaintext for plaintext in decrypted if plaintext is not None]

    @staticmethod
    def _matching_plaintexts(
        fernet: Fernet, ciphertexts: list[str], query: str, top_k: int,
    ) -> list[str]:
        """Return up to ``max(top_k, 1)`` decrypted entries containing *query*.

        Decryption is done lazily, so scanning stops as soon as enough
        matches have been collected.
        """
        needle = query.strip().lower()
        wanted = max(top_k, 1)
        hits: list[str] = []
        for ciphertext in ciphertexts:
            plaintext = _safe_decrypt(fernet, ciphertext)
            if plaintext is None:
                continue
            if not needle or needle in plaintext.lower():
                hits.append(plaintext)
                if len(hits) >= wanted:
                    break
        return hits

    async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
        """Return every decryptable core block as a ``key -> plaintext`` mapping."""
        result = await self._db.execute(
            select(MemoryCore).where(MemoryCore.user_id == user_id)
        )
        blocks: dict[str, str] = {}
        for row in result.scalars().all():
            plaintext = _safe_decrypt(fernet, row.value_encrypted)
            if plaintext is not None:
                blocks[row.key] = plaintext
        return blocks

    async def _load_associative(self, user_id: str, message: str, fernet: Fernet) -> list[str]:
        """Return the ``_ASSOCIATIVE_TOP_K`` most recently updated associative entries.

        *message* is not used for ranking here: selection is by recency only.
        """
        result = await self._db.execute(
            select(MemoryAssociative).where(MemoryAssociative.user_id == user_id)
            .order_by(MemoryAssociative.updated_at.desc()).limit(_ASSOCIATIVE_TOP_K)
        )
        return self._decrypt_all(
            fernet, [row.content_encrypted for row in result.scalars().all()]
        )

    async def _load_episodic(self, user_id: str, fernet: Fernet, session_id: str | None = None) -> list[str]:
        """Return the ``_EPISODIC_RECENT_N`` newest summaries, optionally limited to one session."""
        query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
        if session_id:
            query = query.where(MemoryEpisodic.session_id == session_id)
        result = await self._db.execute(
            query.order_by(MemoryEpisodic.created_at.desc()).limit(_EPISODIC_RECENT_N)
        )
        return self._decrypt_all(
            fernet, [row.summary_encrypted for row in result.scalars().all()]
        )

    async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
        """Return proactive patterns at or above the confidence threshold, most confident first."""
        result = await self._db.execute(
            select(MemoryProactive).where(
                MemoryProactive.user_id == user_id,
                MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
            ).order_by(MemoryProactive.confidence.desc())
        )
        return self._decrypt_all(
            fernet, [row.pattern_encrypted for row in result.scalars().all()]
        )
|
||||
|
||||
|
||||
def _encrypt(fernet: Fernet, plaintext: str) -> str:
    """Encrypt *plaintext* with *fernet* and return the token as text."""
    token = fernet.encrypt(plaintext.encode())
    return token.decode()
|
||||
|
||||
|
||||
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
    """Decrypt *ciphertext* with *fernet*; return ``None`` on any failure.

    Deliberately best-effort: every exception (an invalid token included) is
    logged as a warning and turned into ``None`` so one unreadable row does
    not abort the caller.
    """
    try:
        plaintext_bytes = fernet.decrypt(ciphertext.encode())
        return plaintext_bytes.decode()
    except Exception as exc:
        logger.warning("memory: decrypt failed: %s", exc)
    return None
|
||||
50
services/chat/app/output_formatter.py
Normal file
50
services/chat/app/output_formatter.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Output formatter for deep-agent stream events — Chat Service copy.
|
||||
|
||||
Converts (event_type, data) tuples into WebSocket frame Pydantic models.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from shared.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
||||
|
||||
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||
|
||||
|
||||
class StreamFormatter:
    """Turn `(event_type, data)` agent events into WebSocket frame models.

    Every frame produced carries the ``request_id`` given at construction.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id

    async def format(
        self,
        event_stream: AsyncGenerator[tuple[str, Any], None],
    ) -> AsyncGenerator[WsFrame, None]:
        """Yield frames for *event_stream*.

        * ``"floating_domain"`` events with a dict payload become
          ``WsFloatingDomain`` frames; other payloads are dropped.
        * ``"token"`` events become ``WsStreamText`` frames (empty text is
          skipped); the first token event is preceded by ``WsStreamStart``.
        * All other event types are ignored.

        After the stream is exhausted a ``WsStreamStart`` is emitted if none
        was sent yet, followed by ``WsStreamEnd``.
        """
        stream_open = False

        async for kind, payload in event_stream:
            if kind == "floating_domain":
                if isinstance(payload, dict):
                    yield WsFloatingDomain(
                        request_id=self.request_id,
                        domain=payload,
                    )
            elif kind == "token":
                if not stream_open:
                    yield WsStreamStart(request_id=self.request_id)
                    stream_open = True

                chunk = str(payload or "")
                if chunk:
                    yield WsStreamText(request_id=self.request_id, chunk=chunk)

        # Always bracket the response, even when no token was ever produced.
        if not stream_open:
            yield WsStreamStart(request_id=self.request_id)
        yield WsStreamEnd(request_id=self.request_id)
|
||||
170
services/chat/app/redis_consumer.py
Normal file
170
services/chat/app/redis_consumer.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""Redis consumer — listens for chat requests and dispatches to deep_agent.
|
||||
|
||||
Subscribes to a Redis pattern channel chat:request:* so it receives
|
||||
requests for ALL users. Each request is processed in a separate asyncio task.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from uuid import uuid4
|
||||
|
||||
from shared.db import async_session
|
||||
from shared.redis import redis_client, ws_out_channel
|
||||
|
||||
from app.deep_agent import run_floating_stream, run_home_stream
|
||||
from app.memory_middleware import MemoryMiddleware
|
||||
from app.output_formatter import StreamFormatter
|
||||
from app.ws_context import clear_current_user, set_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def start_consumer() -> asyncio.Task:
    """Schedule the Redis consumer loop on the running event loop and return its task."""
    consumer = _consumer_loop()
    return asyncio.create_task(consumer)
|
||||
|
||||
|
||||
async def _consumer_loop() -> None:
    """Subscribe to chat:request:* and dispatch incoming frames.

    Each well-formed frame (a JSON object) is handed to ``_dispatch`` in its
    own task.  Payloads that cannot be decoded, or that decode to something
    other than an object, are logged and dropped so that one bad message
    cannot stop the consumer.  Cancellation ends the loop; the pattern
    subscription is always released on exit.
    """
    pubsub = redis_client.pubsub()
    await pubsub.psubscribe("chat:request:*")
    logger.info("redis_consumer: subscribed to chat:request:*")

    # The event loop keeps only weak references to tasks, so hold strong ones
    # here until each dispatch finishes.
    in_flight: set[asyncio.Task] = set()

    try:
        while True:
            message = await pubsub.get_message(
                ignore_subscribe_messages=True, timeout=1.0
            )
            if message is None or message["type"] != "pmessage":
                await asyncio.sleep(0.01)
                continue

            try:
                frame = json.loads(message["data"])
            except (ValueError, TypeError) as exc:
                # ValueError covers JSONDecodeError and UnicodeDecodeError.
                logger.warning("redis_consumer: dropping undecodable frame: %s", exc)
                continue

            if not isinstance(frame, dict):
                logger.warning(
                    "redis_consumer: dropping non-object frame of type %s",
                    type(frame).__name__,
                )
                continue

            task = asyncio.create_task(_dispatch(frame))
            in_flight.add(task)
            task.add_done_callback(in_flight.discard)
    except asyncio.CancelledError:
        logger.info("redis_consumer: shutting down")
    finally:
        await pubsub.punsubscribe()
        await pubsub.aclose()
|
||||
|
||||
|
||||
async def _dispatch(frame: dict) -> None:
    """Hand a chat request frame to the handler matching its ``type``.

    Frames without a ``user_id`` are logged and dropped; unrecognised types
    are logged at debug level.
    """
    kind = frame.get("type")
    sender = frame.get("user_id")

    if not sender:
        logger.warning("redis_consumer: frame missing user_id: %s", kind)
        return

    if kind == "home_request":
        await _handle_home_request(sender, frame)
        return

    if kind == "floating_request":
        await _handle_floating_request(sender, frame)
        return

    logger.debug("redis_consumer: unknown frame type %r", kind)
|
||||
|
||||
|
||||
async def _publish_frame(user_id: str, frame_data: str) -> None:
    """Publish an already-serialised frame on the user's ws:out channel."""
    await redis_client.publish(ws_out_channel(user_id), frame_data)
|
||||
|
||||
|
||||
async def _handle_home_request(user_id: str, frame: dict) -> None:
    """Process a home_request — enrich with memory, run deep_agent, stream results.

    Steps: load the user's memory context, run ``run_home_stream`` with that
    context, publish every formatted frame to the user's outbound channel,
    then store an episodic summary of the exchange.  Missing or ``null``
    ``request_id`` / ``session_id`` are replaced by fresh UUIDs, and a missing
    or ``null`` ``message`` / ``conversation_history`` by an empty value.
    """
    request_id = frame.get("request_id") or str(uuid4())
    # ``or`` (not a .get default) so an explicit JSON null is normalised too.
    message: str = frame.get("message") or ""
    session_id: str = frame.get("session_id") or str(uuid4())

    logger.info(
        "redis_consumer: home_request user=%s req=%s msg=%s",
        user_id, request_id, message[:200],
    )

    # Enrich with memory context
    async with async_session() as db:
        memory = MemoryMiddleware(db)
        memory_context = await memory.enrich_context(
            user_id, message,
            trace_id=request_id, session_id=session_id,
        )

    context: dict = {
        "conversation_history": frame.get("conversation_history") or [],
        "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
        **memory_context,
    }

    set_current_user(user_id)
    response_chunks: list[str] = []
    try:
        event_stream = run_home_stream(user_id, message, context)
        formatter = StreamFormatter(request_id=request_id)
        async for ws_frame in formatter.format(event_stream):
            await _publish_frame(user_id, ws_frame.model_dump_json())
            # Only text frames carry a ``chunk``; collect them for the episode.
            if hasattr(ws_frame, "chunk"):
                response_chunks.append(ws_frame.chunk)
    except Exception as exc:
        # logger.exception keeps the traceback alongside the original message.
        # NOTE(review): no terminal frame is published on failure, so the
        # client may keep waiting for the end of the stream — confirm.
        logger.exception(
            "redis_consumer: home_request failed user=%s req=%s: %s", user_id, request_id, exc
        )
    finally:
        clear_current_user()

    # Store episode (also after a failure, with whatever text was produced)
    async with async_session() as db:
        memory = MemoryMiddleware(db)
        await memory.store_episode(
            user_id, session_id, message, "".join(response_chunks),
            trace_id=request_id,
        )
|
||||
|
||||
|
||||
async def _handle_floating_request(user_id: str, frame: dict) -> None:
    """Process a floating_request — enrich with memory, run deep_agent, stream results.

    Steps: load the user's memory context, run ``run_floating_stream`` with
    that context plus the frame's ``scope``, publish every formatted frame to
    the user's outbound channel, then store an episodic summary.  Missing or
    ``null`` ``request_id`` / ``session_id`` are replaced by fresh UUIDs, and
    a missing or ``null`` ``message`` / ``scope`` by an empty value.
    """
    request_id = frame.get("request_id") or str(uuid4())
    # ``or`` (not a .get default) so an explicit JSON null is normalised too.
    message: str = frame.get("message") or ""
    session_id: str = frame.get("session_id") or str(uuid4())
    scope: dict = frame.get("scope") or {}

    logger.info(
        "redis_consumer: floating_request user=%s req=%s scope=%s msg=%s",
        user_id, request_id, json.dumps(scope)[:200], message[:200],
    )

    # Enrich with memory context
    async with async_session() as db:
        memory = MemoryMiddleware(db)
        memory_context = await memory.enrich_context(
            user_id, message,
            trace_id=request_id, session_id=session_id,
        )

    context: dict = {
        "scope": scope,
        "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
        **memory_context,
    }

    set_current_user(user_id)
    response_chunks: list[str] = []
    try:
        event_stream = run_floating_stream(user_id, message, context)
        formatter = StreamFormatter(request_id=request_id)
        async for ws_frame in formatter.format(event_stream):
            await _publish_frame(user_id, ws_frame.model_dump_json())
            # Only text frames carry a ``chunk``; collect them for the episode.
            if hasattr(ws_frame, "chunk"):
                response_chunks.append(ws_frame.chunk)
    except Exception as exc:
        # logger.exception keeps the traceback alongside the original message.
        # NOTE(review): no terminal frame is published on failure, so the
        # client may keep waiting for the end of the stream — confirm.
        logger.exception(
            "redis_consumer: floating_request failed user=%s req=%s: %s", user_id, request_id, exc
        )
    finally:
        clear_current_user()

    # Store episode (also after a failure, with whatever text was produced)
    async with async_session() as db:
        memory = MemoryMiddleware(db)
        await memory.store_episode(
            user_id, session_id, message, "".join(response_chunks),
            trace_id=request_id,
        )
|
||||
37
services/chat/app/routes.py
Normal file
37
services/chat/app/routes.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Chat REST route — POST /chat fallback when WS is unavailable."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from shared.schemas import ChatRequest
|
||||
|
||||
from app.deep_agent import run_home
|
||||
from app.ws_context import clear_current_user, set_current_user
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
|
||||
|
||||
@router.post("")
async def chat(body: ChatRequest, request: Request) -> JSONResponse:
    """REST fallback for home chat.

    In the microservices setup, Traefik ForwardAuth has already validated
    the JWT and injected X-User-Id / X-User-Email / X-User-Tier headers.
    """
    caller_id = request.headers.get("X-User-Id", "")
    if not caller_id:
        # No identity header present: refuse before touching the agent.
        missing_header = {"detail": "Missing X-User-Id header"}
        return JSONResponse(status_code=401, content=missing_header)

    # Bind the caller for the duration of the agent run; always unbind after.
    set_current_user(caller_id)
    try:
        reply = await run_home(
            user_id=caller_id,
            message=body.message,
            context=body.context.model_dump(),
        )
    finally:
        clear_current_user()

    result = {"response": reply}
    return JSONResponse(content=result)
|
||||
115
services/chat/app/ws_context.py
Normal file
115
services/chat/app/ws_context.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""WebSocket context for Chat Service — Redis-based tool call round-trip.
|
||||
|
||||
Replaces the monolith's ws_context.py. Instead of calling Electron directly
|
||||
via WebSocket, this publishes tool_call frames to Redis (ws:out:{user_id})
|
||||
and awaits the result via BRPOP on tool:result:{call_id}.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from contextvars import ContextVar
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from shared.redis import redis_client, tool_result_key, ws_out_channel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TOOL_CALL_TIMEOUT = 30 # seconds — BRPOP timeout
|
||||
|
||||
# Per-request user_id context var (set before agent runs)
|
||||
_current_user_id: ContextVar[str | None] = ContextVar("_current_user_id", default=None)
|
||||
|
||||
# Optional collector for debug
|
||||
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||
"_tool_result_collector", default=None
|
||||
)
|
||||
|
||||
|
||||
def set_current_user(user_id: str) -> None:
    """Bind *user_id* to the current context for later ``execute_on_client`` calls."""
    _current_user_id.set(user_id)
|
||||
|
||||
|
||||
def clear_current_user() -> None:
    """Unbind the current user by setting the context variable back to ``None``."""
    _current_user_id.set(None)
|
||||
|
||||
|
||||
def set_tool_result_collector(lst: list[dict]) -> None:
    """Register *lst* as the list that ``execute_on_client`` appends tool results to."""
    _tool_result_collector.set(lst)
|
||||
|
||||
|
||||
def clear_tool_result_collector() -> None:
    """Stop collecting tool results in the current context."""
    _tool_result_collector.set(None)
|
||||
|
||||
|
||||
async def execute_on_client(
    action: str,
    table: str | None = None,
    data: dict[str, Any] | None = None,
    filters: dict[str, Any] | None = None,
    vector: list[float] | None = None,
    limit: int | None = None,
) -> dict[str, Any]:
    """Run a tool call on the user's device via Redis and return its result.

    A ``tool_call`` payload with a fresh id is published on the current
    user's ws:out channel; the reply is then awaited with BRPOP on the
    matching tool-result key for up to ``_TOOL_CALL_TIMEOUT`` seconds.
    Optional arguments left as ``None`` are omitted from the payload, and
    ``None`` values inside *filters* are dropped.

    Raises RuntimeError if no user is bound to the current context or if no
    result arrives before the timeout.
    """
    target_user = _current_user_id.get()
    if not target_user:
        raise RuntimeError(
            "execute_on_client() called without a user_id — "
            "set_current_user() must be called first."
        )

    call_id = str(uuid4())

    # Optional fields in their wire order; anything still None is left out.
    optional_fields: dict[str, Any] = {
        "table": table,
        "data": data,
        "filters": (
            None if filters is None
            else {k: v for k, v in filters.items() if v is not None}
        ),
        "vector": vector,
        "limit": limit,
    }
    payload: dict[str, Any] = {"type": "tool_call", "id": call_id, "action": action}
    payload.update(
        (name, value) for name, value in optional_fields.items() if value is not None
    )

    # Outbound leg: the frame travels through the user's ws:out channel.
    await redis_client.publish(ws_out_channel(target_user), json.dumps(payload))

    # Return leg: block until the result is pushed for this call id.
    reply = await redis_client.brpop(tool_result_key(call_id), timeout=_TOOL_CALL_TIMEOUT)
    if reply is None:
        raise RuntimeError(
            f"Tool call {call_id} timed out after {_TOOL_CALL_TIMEOUT}s — "
            f"device may be offline or unresponsive."
        )

    # BRPOP answers with a (key, value) pair; only the value matters here.
    _key, raw_result = reply
    result = json.loads(raw_result)

    sink = _tool_result_collector.get(None)
    if sink is not None:
        sink.append({"action": action, "table": table, "data": result})

    return result
|
||||
16
services/chat/requirements.txt
Normal file
16
services/chat/requirements.txt
Normal file
@@ -0,0 +1,16 @@
|
||||
fastapi>=0.115.0
|
||||
uvicorn[standard]>=0.34.0
|
||||
gunicorn>=22.0.0
|
||||
pydantic>=2.10.0
|
||||
pydantic-settings>=2.7.0
|
||||
sqlalchemy>=2.0.0
|
||||
asyncpg>=0.30.0
|
||||
redis>=5.0.0
|
||||
cryptography>=42.0.0
|
||||
python-dotenv>=1.0.0
|
||||
langchain-core>=0.3.0
|
||||
langchain-openai>=0.3.0
|
||||
langchain-litellm>=0.3.0
|
||||
litellm>=1.50.0
|
||||
openai>=1.50.0
|
||||
httpx>=0.27.0
|
||||
36
services/ws-gateway/Dockerfile
Normal file
36
services/ws-gateway/Dockerfile
Normal file
@@ -0,0 +1,36 @@
|
||||
# ── builder ──────────────────────────────────────────────────────────────────
# Installs the Python dependencies under /install so the runtime stage can
# copy just the installed packages.
FROM python:3.12-slim AS builder

WORKDIR /build

# Source paths are relative to the build context (presumably the repository
# root, since both services/ and shared/ are copied — confirm).
COPY services/ws-gateway/requirements.txt ./requirements.txt
RUN pip install --upgrade pip && \
    pip install --no-cache-dir --prefix=/install -r requirements.txt

# ── runtime ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS runtime

# Dedicated system account; the container runs as appuser (see USER below).
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser

WORKDIR /app

# Packages installed by the builder stage
COPY --from=builder /install /usr/local

# Shared module
COPY shared/ shared/

# Service source
COPY services/ws-gateway/app/ app/

RUN chown -R appuser:appgroup /app

USER appuser

EXPOSE 8000

# Single worker — each instance handles many WS connections via asyncio.
# "--timeout 0" disables gunicorn's worker timeout (see the gunicorn
# `timeout` setting).
CMD ["gunicorn", "app.main:app", \
     "-k", "uvicorn.workers.UvicornWorker", \
     "--bind", "0.0.0.0:8000", \
     "--workers", "1", \
     "--timeout", "0"]
|
||||
173
services/ws-gateway/app/handler.py
Normal file
173
services/ws-gateway/app/handler.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""WebSocket handler — device connection lifecycle.
|
||||
|
||||
Accepts Electron WS connections, authenticates JWT, registers device in Redis,
|
||||
and runs three concurrent loops:
|
||||
1. Message loop: receive frames from Electron, route to Redis
|
||||
2. Outbound loop: subscribe to Redis ws:out:{user_id}, forward to Electron
|
||||
3. Heartbeat loop: ping every 30s
|
||||
|
||||
No business logic lives here — the handler is a JSON frame router.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from shared.config import settings
|
||||
from shared.schemas import WsFrameType
|
||||
|
||||
from app.redis_bridge import (
|
||||
publish_batch_request,
|
||||
publish_chat_request,
|
||||
push_tool_result,
|
||||
register_device,
|
||||
set_gateway_id,
|
||||
subscribe_outbound,
|
||||
unregister_device,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/ws", tags=["ws-gateway"])
|
||||
|
||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||
|
||||
# Set a unique gateway instance ID on module load
|
||||
set_gateway_id(str(uuid4()))
|
||||
|
||||
|
||||
@router.websocket("/device")
async def device_ws(websocket: WebSocket) -> None:
    """Persistent WebSocket endpoint for Electron device connections.

    Lifecycle:
      1. Validate the RS256 JWT passed as ``?token=``; the ``sub`` claim is
         the user id.
      2. Accept the socket and require a ``device_hello`` frame (a JSON
         object carrying ``device_id``) within 15 seconds.
      3. Register the device in Redis and announce ``device_online`` on the
         batch channel.
      4. Subscribe to the user's outbound channel.
      5. Run the inbound, outbound and heartbeat loops until the first one
         stops, then cancel the others, release the subscription and
         unregister the device.

    Authentication and handshake failures close the socket with code 1008.
    """

    # ── 1. Authenticate via ?token= query parameter ──────────────────
    token = websocket.query_params.get("token", "")
    try:
        payload = jwt.decode(
            token,
            settings.JWT_PUBLIC_KEY,
            algorithms=["RS256"],
        )
        user_id: str | None = payload.get("sub")
        if not user_id:
            raise JWTError("missing sub")
    except JWTError:
        await websocket.close(code=1008)
        return

    await websocket.accept()

    # ── 2. Await device_hello frame ──────────────────────────────────
    try:
        raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
    except asyncio.TimeoutError:
        await websocket.close(code=1008)
        return
    except WebSocketDisconnect:
        # The peer already closed the connection; nothing is left to close.
        return

    try:
        hello = json.loads(raw)
        # A valid hello must be a JSON object; other JSON values have no .get().
        if not isinstance(hello, dict) or hello.get("type") != WsFrameType.device_hello:
            raise ValueError("expected device_hello as first frame")
        device_id: str = hello["device_id"]
        agent_ids: list[str] = hello.get("agent_ids", [])
    except (KeyError, ValueError, json.JSONDecodeError) as exc:
        logger.warning("handler: invalid device_hello user=%s: %s", user_id, exc)
        await websocket.close(code=1008)
        return

    # ── 3. Register device in Redis ──────────────────────────────────
    await register_device(user_id, device_id)
    logger.info("handler: connected user=%s device=%s agents=%s", user_id, device_id, agent_ids)

    # Notify downstream services that device is online (for agent trigger)
    await publish_batch_request(user_id, {
        "type": "device_online",
        "user_id": user_id,
        "device_id": device_id,
        "agent_ids": agent_ids,
    })

    # ── 4. Subscribe to outbound Redis channel ───────────────────────
    pubsub = await subscribe_outbound(user_id)

    # ── 5. Run concurrent loops ──────────────────────────────────────
    # The inbound loop returns when the client disconnects, while the other
    # two loop forever.  Wait for the FIRST loop to stop and cancel the rest:
    # gather() would keep waiting on the never-ending loops and would not
    # cancel them when a sibling fails.
    loops = [
        asyncio.create_task(_inbound_loop(websocket, user_id)),
        asyncio.create_task(_outbound_loop(websocket, pubsub)),
        asyncio.create_task(_heartbeat_loop(websocket)),
    ]
    try:
        finished, _ = await asyncio.wait(loops, return_when=asyncio.FIRST_COMPLETED)
        for loop_task in finished:
            loop_task.result()  # re-raise whatever stopped the loop
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        logger.warning("handler: unhandled exception user=%s: %s", user_id, exc)
    finally:
        for loop_task in loops:
            loop_task.cancel()
        # Let the cancellations settle (and consume any leftover exceptions)
        # before the pub/sub object the outbound loop reads from is closed.
        await asyncio.gather(*loops, return_exceptions=True)
        await pubsub.unsubscribe()
        await pubsub.aclose()
        await unregister_device(user_id)
        logger.info("handler: disconnected user=%s device=%s", user_id, device_id)
|
||||
|
||||
|
||||
# ── Inbound: Electron → Redis ────────────────────────────────────────
|
||||
|
||||
async def _inbound_loop(websocket: WebSocket, user_id: str) -> None:
    """Receive frames from Electron and route to the appropriate Redis channel.

    Routing by frame ``type``:
      * ``tool_result``                     → pushed to the tool-result list for its ``id``
      * ``home_request`` / ``floating_request`` → published as a chat request
      * ``journey_start`` / ``journey_message`` → published as a batch request
      * ``pong``                            → ignored (heartbeat ack)

    Text that is not valid JSON, or that is valid JSON but not an object, is
    logged and skipped without closing the connection.
    """
    async for raw in websocket.iter_text():
        try:
            frame = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning("handler: invalid JSON from user=%s", user_id)
            continue

        # Only JSON objects can be routed; arrays, strings, numbers, etc.
        # would otherwise fail on .get() below and drop the connection.
        if not isinstance(frame, dict):
            logger.warning("handler: non-object frame from user=%s", user_id)
            continue

        frame_type = frame.get("type")

        # Inject user_id so downstream services know who sent it
        frame["user_id"] = user_id

        if frame_type == WsFrameType.tool_result:
            call_id = frame.get("id")
            if call_id:
                await push_tool_result(call_id, frame)
            else:
                logger.warning("handler: tool_result missing id user=%s", user_id)

        elif frame_type in (WsFrameType.home_request, WsFrameType.floating_request):
            await publish_chat_request(user_id, frame)

        elif frame_type in (WsFrameType.journey_start, WsFrameType.journey_message):
            await publish_batch_request(user_id, frame)

        elif frame_type == "pong":
            pass  # heartbeat ack

        else:
            logger.debug("handler: unknown frame type %r user=%s", frame_type, user_id)
|
||||
|
||||
|
||||
# ── Outbound: Redis → Electron ───────────────────────────────────────
|
||||
|
||||
async def _outbound_loop(websocket: WebSocket, pubsub) -> None:
    """Forward every message received on *pubsub* to the WebSocket, forever."""
    while True:
        incoming = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
        if incoming is None or incoming["type"] != "message":
            # Nothing to forward: pause briefly before polling again.
            await asyncio.sleep(0.01)
            continue
        await websocket.send_text(incoming["data"])
|
||||
|
||||
|
||||
# ── Heartbeat ────────────────────────────────────────────────────────
|
||||
|
||||
async def _heartbeat_loop(websocket: WebSocket) -> None:
    """Send a ``{"type": "ping"}`` frame every ``_HEARTBEAT_INTERVAL`` seconds, forever."""
    while True:
        await asyncio.sleep(_HEARTBEAT_INTERVAL)
        ping_frame = json.dumps({"type": "ping"})
        await websocket.send_text(ping_frame)
|
||||
49
services/ws-gateway/app/main.py
Normal file
49
services/ws-gateway/app/main.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""WS Gateway — stateless WebSocket proxy.
|
||||
|
||||
Accepts Electron device connections, authenticates JWT (RS256 public key),
|
||||
and routes frames between Electron and downstream services via Redis pub/sub.
|
||||
|
||||
This service has NO business logic — it only routes JSON frames.
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
|
||||
from fastapi import FastAPI
|
||||
from shared.config import settings
|
||||
|
||||
# Root logging setup: INFO level, "<time> <level> <logger>: <message>" lines.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    level=logging.INFO,
)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: no startup work, close Redis on shutdown.

    Everything before the ``yield`` would run at startup (nothing here);
    the code after it runs when the application shuts down.
    """
    # No startup preparation is needed; hand control to the application.
    yield

    # Shutdown: the shared Redis client is imported only at this point,
    # then closed.
    from shared.redis import redis_client as shared_client

    await shared_client.aclose()
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
    """Build and configure the WS Gateway FastAPI application.

    Interactive docs are served at ``/docs`` only when ``settings.ENV`` is
    ``"dev"``; ReDoc is always disabled. The handler router is mounted under
    ``/api/v1`` and a health endpoint is registered at ``/api/v1/health``.

    Returns:
        The configured ``FastAPI`` instance.
    """
    docs_path = "/docs" if settings.ENV == "dev" else None
    application = FastAPI(
        title="Adiuva WS Gateway",
        version="0.1.0",
        docs_url=docs_path,
        redoc_url=None,
        lifespan=lifespan,
    )

    # The router is imported inside the factory, after the app object exists.
    from app.handler import router as ws_router

    application.include_router(ws_router, prefix="/api/v1")

    # Health probe: static status plus the version declared on the app.
    @application.get("/api/v1/health", tags=["health"])
    async def health() -> dict:
        return {
            "status": "ok",
            "service": "ws-gateway",
            "version": application.version,
        }

    return application


# Module-level application object, built once at import time.
app = create_app()
|
||||
104
services/ws-gateway/app/redis_bridge.py
Normal file
104
services/ws-gateway/app/redis_bridge.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Redis bridge — device registry + pub/sub routing.
|
||||
|
||||
All inter-service communication passes through Redis:
|
||||
- Device registry: HSET/DEL ws:devices:{user_id}
|
||||
- Outbound frames: Subscribe ws:out:{user_id}
|
||||
- Chat requests: Publish chat:request:{user_id}
|
||||
- Batch requests: Publish batch:request:{user_id}
|
||||
- Tool results: LPUSH tool:result:{call_id}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from shared.redis import (
|
||||
batch_request_channel,
|
||||
chat_request_channel,
|
||||
device_key,
|
||||
redis_client,
|
||||
tool_result_key,
|
||||
ws_out_channel,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Identifier of this gateway replica; stays empty until set_gateway_id() runs.
_GATEWAY_ID: str = ""


def set_gateway_id(gid: str) -> None:
    """Store ``gid`` as the module-wide gateway instance ID."""
    global _GATEWAY_ID
    _GATEWAY_ID = gid


def get_gateway_id() -> str:
    """Return the stored gateway instance ID (``""`` if none was set)."""
    return _GATEWAY_ID
|
||||
|
||||
|
||||
# ── Device Registry ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def register_device(user_id: str, device_id: str) -> None:
    """Record a connected device for ``user_id`` in the Redis registry.

    Writes ``device_id`` and this replica's gateway ID as fields of the hash
    stored at ``device_key(user_id)``, then logs the registration.
    """
    registration = {"device_id": device_id, "gateway_id": _GATEWAY_ID}
    await redis_client.hset(device_key(user_id), mapping=registration)
    logger.info(
        "redis_bridge: registered user=%s device=%s gateway=%s",
        user_id,
        device_id,
        _GATEWAY_ID,
    )
|
||||
|
||||
|
||||
async def unregister_device(user_id: str) -> None:
    """Drop the registration for ``user_id`` by deleting its registry key."""
    await redis_client.delete(device_key(user_id))
    logger.info("redis_bridge: unregistered user=%s", user_id)
|
||||
|
||||
|
||||
async def is_device_online(user_id: str) -> bool:
    """Report whether a device registry key exists for ``user_id``."""
    matches = await redis_client.exists(device_key(user_id))
    return matches > 0
|
||||
|
||||
|
||||
# ── Frame Routing ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def publish_chat_request(user_id: str, frame: dict) -> None:
    """Publish ``frame``, JSON-encoded, on the chat request channel of ``user_id``."""
    await redis_client.publish(chat_request_channel(user_id), json.dumps(frame))
    logger.debug("redis_bridge: published chat_request user=%s", user_id)
|
||||
|
||||
|
||||
async def publish_batch_request(user_id: str, frame: dict) -> None:
    """Publish ``frame``, JSON-encoded, on the batch request channel of ``user_id``."""
    await redis_client.publish(batch_request_channel(user_id), json.dumps(frame))
    logger.debug("redis_bridge: published batch_request user=%s", user_id)
|
||||
|
||||
|
||||
async def push_tool_result(call_id: str, result: dict) -> None:
    """Push a tool_result onto the Redis list read by the waiting service.

    Chat/Batch services do BRPOP on this key with a 30s timeout.

    The LPUSH and the 60s EXPIRE are sent together in a single MULTI/EXEC
    transaction. Issuing them as two separate commands would leave the list
    without a TTL whenever the second command fails or never gets sent
    (dropped connection, cancelled task), producing a stale key.

    Args:
        call_id: Identifier of the tool call this result answers.
        result: The tool_result frame; it is stored JSON-encoded.
    """
    key = tool_result_key(call_id)
    payload = json.dumps(result)
    # Commands on a pipeline are only queued; execute() sends them as one
    # atomic unit (and in one round trip).
    async with redis_client.pipeline(transaction=True) as pipe:
        # Auto-expire after 60s to prevent stale keys.
        await pipe.lpush(key, payload).expire(key, 60).execute()
    logger.debug("redis_bridge: pushed tool_result call_id=%s", call_id)
|
||||
|
||||
|
||||
async def subscribe_outbound(user_id: str):
    """Open a pub/sub subscription on the outbound channel of ``user_id``.

    Per this module's channel layout, Chat/Batch services publish frames to
    ws:out:{user_id}; the gateway forwards them to the connected WebSocket.

    Returns:
        The pubsub object, already subscribed to the channel.
    """
    outbound_channel = ws_out_channel(user_id)
    subscription = redis_client.pubsub()
    await subscription.subscribe(outbound_channel)
    return subscription
|
||||
8
services/ws-gateway/requirements.txt
Normal file
8
services/ws-gateway/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
fastapi>=0.115.0
|
||||
uvicorn[standard]>=0.34.0
|
||||
gunicorn>=22.0.0
|
||||
pydantic>=2.10.0
|
||||
pydantic-settings>=2.7.0
|
||||
python-jose[cryptography]>=3.3.0
|
||||
redis>=5.0.0
|
||||
websockets>=14.0
|
||||
@@ -33,7 +33,7 @@ certificatesResolvers:
|
||||
storage: /etc/traefik/acme/acme.json
|
||||
dnsChallenge:
|
||||
provider: cloudflare
|
||||
delayBeforeCheck: 10
|
||||
delayBeforeCheck: "10"
|
||||
resolvers:
|
||||
- "1.1.1.1:53"
|
||||
- "8.8.8.8:53"
|
||||
|
||||
Reference in New Issue
Block a user