refactor: deduplicate shared code into shared/ module
Move duplicated files from chat + batch-agent into shared/: - shared/ws_context.py — Redis-based tool call round-trip - shared/llm.py — LiteLLM factory (get_llm, embed) - shared/agents/ — 4 domain agents (task, note, project, timeline) Update all service imports to use shared.* instead of app.*. Delete 12 duplicated files across both services.
This commit is contained in:
1
shared/agents/__init__.py
Normal file
1
shared/agents/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Shared domain agents — tool definitions used by both Chat and Batch Agent services."""
|
||||
142
shared/agents/note_agent.py
Normal file
142
shared/agents/note_agent.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Note agent — Markdown note management (list, get, create, update, delete).
|
||||
|
||||
Shared tool definitions used by both Chat and Batch Agent services.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from shared.llm import embed
|
||||
from shared.ws_context import execute_on_client
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(value: str) -> bool:
|
||||
return bool(_UUID_RE.match(value))
|
||||
|
||||
NOTE_SYSTEM_PROMPT = (
|
||||
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
||||
"and delete Markdown notes in their workspace.\n\n"
|
||||
"Rules:\n"
|
||||
" - content is always Markdown; preserve formatting when updating\n"
|
||||
" - project_id is optional; link a note to a project when mentioned\n"
|
||||
" - When updating, call get_note first if you need to read existing content\n"
|
||||
" before appending or replacing sections\n"
|
||||
" - list_notes without project_id returns all notes; scope with project_id\n"
|
||||
" when the user is working within a specific project\n"
|
||||
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
|
||||
" - Do not fabricate note content — reflect what the user provides or what\n"
|
||||
" is already in the note (retrieved via get_note)."
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def list_notes(project_id: str = "") -> str:
|
||||
"""List notes, optionally scoped to a project by project_id."""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="notes",
|
||||
filters={"projectId": normalized_project_id or None},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No notes found."
|
||||
lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
|
||||
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def get_note(note_id: str) -> str:
|
||||
"""Fetch a single note by its UUID to read its full Markdown content."""
|
||||
result = await execute_on_client(action="get", table="notes", data={"id": note_id})
|
||||
row = result.get("row")
|
||||
if not row:
|
||||
return f"Note {note_id} not found."
|
||||
return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}"
|
||||
|
||||
|
||||
@tool
|
||||
async def create_note(
|
||||
title: str,
|
||||
content: str,
|
||||
project_id: str = "",
|
||||
) -> str:
|
||||
"""Create a new note.
|
||||
title: note heading (required)
|
||||
content: Markdown body text (required)
|
||||
project_id: optional UUID linking this note to a project
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="notes",
|
||||
data={
|
||||
"title": title,
|
||||
"content": content,
|
||||
"projectId": project_id or None,
|
||||
},
|
||||
)
|
||||
row = result["row"]
|
||||
# Index the note content in the vector store.
|
||||
vector = await embed(content)
|
||||
await execute_on_client(
|
||||
action="vector_upsert",
|
||||
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
|
||||
vector=vector,
|
||||
)
|
||||
return f"Note created: '{row['title']}' (id: {row['id']})."
|
||||
|
||||
|
||||
@tool
|
||||
async def update_note(
|
||||
note_id: str,
|
||||
title: str = "",
|
||||
content: str = "",
|
||||
) -> str:
|
||||
"""Update an existing note. Only pass fields that should change.
|
||||
note_id: UUID of the note (required)
|
||||
If you need to preserve existing content, call get_note first.
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if title:
|
||||
updates["title"] = title
|
||||
if content:
|
||||
updates["content"] = content
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="notes",
|
||||
data={"id": note_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
# Re-index if content changed.
|
||||
if content:
|
||||
vector = await embed(content)
|
||||
await execute_on_client(
|
||||
action="vector_upsert",
|
||||
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
|
||||
vector=vector,
|
||||
)
|
||||
return f"Note updated: '{row['title']}' (id: {row['id']})."
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_note(note_id: str) -> str:
|
||||
"""Delete a note permanently by its UUID."""
|
||||
await execute_on_client(action="delete", table="notes", data={"id": note_id})
|
||||
return f"Note {note_id} deleted."
|
||||
|
||||
|
||||
NOTE_TOOLS: list[Any] = [
|
||||
list_notes,
|
||||
get_note,
|
||||
create_note,
|
||||
update_note,
|
||||
delete_note,
|
||||
]
|
||||
146
shared/agents/project_agent.py
Normal file
146
shared/agents/project_agent.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Project agent — full lifecycle management (list, get, create, update, archive, delete).
|
||||
|
||||
Shared tool definitions used by both Chat and Batch Agent services.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from shared.ws_context import execute_on_client
|
||||
|
||||
PROJECT_SYSTEM_PROMPT = (
|
||||
"You are a project management assistant. You help users create, find,\n"
|
||||
"update, and archive projects in their workspace.\n\n"
|
||||
"Rules:\n"
|
||||
" - status must be one of: active, archived\n"
|
||||
" - client_id is optional; link to a client only when explicitly mentioned\n"
|
||||
" - ai_summary is populated only when the user asks for a project summary;\n"
|
||||
" derive it from context data — do not fabricate content\n"
|
||||
" - Use list_projects for scoped queries; list_all_projects only when the\n"
|
||||
" user wants a complete cross-client view including archived projects\n"
|
||||
" - get_project requires a project UUID; resolve the ID first by calling\n"
|
||||
" list_projects if you only have a project name\n"
|
||||
" - Prefer archiving (update_project status=archived) over deletion;\n"
|
||||
" only call delete_project when the user explicitly confirms deletion."
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def list_projects(
|
||||
client_id: str = "",
|
||||
include_archived: int = 0,
|
||||
) -> str:
|
||||
"""List projects, optionally filtered by client_id.
|
||||
include_archived: 1 to include archived projects, 0 for active only (default).
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="projects",
|
||||
filters={
|
||||
"clientId": client_id or None,
|
||||
"includeArchived": bool(include_archived),
|
||||
},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No projects found."
|
||||
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
|
||||
return f"Found {len(rows)} project(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def list_all_projects() -> str:
|
||||
"""List every project regardless of client or status.
|
||||
Use only when the user wants a complete cross-client overview.
|
||||
"""
|
||||
result = await execute_on_client(action="select", table="projects")
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No projects found."
|
||||
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
|
||||
return f"All projects ({len(rows)}):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def get_project(project_id: str) -> str:
|
||||
"""Fetch a single project by its UUID."""
|
||||
result = await execute_on_client(action="get", table="projects", data={"id": project_id})
|
||||
row = result.get("row")
|
||||
if not row:
|
||||
return f"Project {project_id} not found."
|
||||
return (
|
||||
f"Project: '{row['name']}' (id: {row['id']}, status: {row['status']}, "
|
||||
f"clientId: {row.get('clientId', 'none')})"
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def create_project(
|
||||
name: str,
|
||||
client_id: str = "",
|
||||
) -> str:
|
||||
"""Create a new project.
|
||||
name: human-readable project name (required)
|
||||
client_id: optional UUID of the owning client
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="projects",
|
||||
data={"name": name, "clientId": client_id or None},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Project created: '{row['name']}' (id: {row['id']})"
|
||||
|
||||
|
||||
@tool
|
||||
async def update_project(
|
||||
project_id: str,
|
||||
name: str = "",
|
||||
client_id: str = "",
|
||||
status: str = "",
|
||||
ai_summary: str = "",
|
||||
) -> str:
|
||||
"""Update a project. Only pass fields that should change.
|
||||
project_id: UUID of the project (required)
|
||||
status: active | archived
|
||||
ai_summary: AI-generated summary text (populate only when explicitly requested)
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if name:
|
||||
updates["name"] = name
|
||||
if client_id:
|
||||
updates["clientId"] = client_id
|
||||
if status:
|
||||
updates["status"] = status
|
||||
if ai_summary:
|
||||
updates["aiSummary"] = ai_summary
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="projects",
|
||||
data={"id": project_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Project updated: '{row['name']}' (id: {row['id']}, status: {row['status']})"
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_project(project_id: str) -> str:
|
||||
"""Permanently delete a project and orphan its tasks.
|
||||
IMPORTANT: prefer update_project(status='archived') unless the user
|
||||
has explicitly confirmed they want permanent deletion.
|
||||
"""
|
||||
await execute_on_client(action="delete", table="projects", data={"id": project_id})
|
||||
return f"Project {project_id} permanently deleted."
|
||||
|
||||
|
||||
PROJECT_TOOLS: list[Any] = [
|
||||
list_projects,
|
||||
list_all_projects,
|
||||
get_project,
|
||||
create_project,
|
||||
update_project,
|
||||
delete_project,
|
||||
]
|
||||
239
shared/agents/task_agent.py
Normal file
239
shared/agents/task_agent.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""Task agent — full CRUD for tasks and task comments.
|
||||
|
||||
Shared tool definitions used by both Chat and Batch Agent services.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from shared.ws_context import execute_on_client
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(value: str) -> bool:
|
||||
return bool(_UUID_RE.match(value))
|
||||
|
||||
TASK_SYSTEM_PROMPT = (
|
||||
"You are a task management assistant for a project workspace.\n"
|
||||
"You create, update, list, and track tasks and their comments.\n\n"
|
||||
"Rules:\n"
|
||||
" - status must be one of: todo, in_progress, done\n"
|
||||
" - priority must be one of: high, medium, low\n"
|
||||
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
|
||||
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
|
||||
" - project_id is optional; link to a project when the user mentions one\n"
|
||||
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
||||
" did not explicitly request; 0 otherwise\n"
|
||||
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
||||
" - For update_task, use -1 for integer fields you do not want to change\n"
|
||||
" - Always confirm the action in plain, user-friendly language."
|
||||
)
|
||||
|
||||
|
||||
# ── Task tools ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@tool
|
||||
async def list_tasks(
|
||||
project_id: str = "",
|
||||
status: str = "",
|
||||
search: str = "",
|
||||
order_by: str = "",
|
||||
) -> str:
|
||||
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
||||
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="tasks",
|
||||
filters={
|
||||
"projectId": normalized_project_id or None,
|
||||
"status": status or None,
|
||||
"search": search or None,
|
||||
"orderBy": order_by or None,
|
||||
},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No tasks found matching the given filters."
|
||||
lines = [
|
||||
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def create_task(
|
||||
title: str,
|
||||
description: str = "",
|
||||
status: str = "todo",
|
||||
priority: str = "medium",
|
||||
assignees: str = "[]",
|
||||
due_date: int = 0,
|
||||
project_id: str = "",
|
||||
is_ai_suggested: int = 0,
|
||||
) -> str:
|
||||
"""Create a new task.
|
||||
title: task title (required)
|
||||
description: optional details
|
||||
status: todo | in_progress | done (default: todo)
|
||||
priority: high | medium | low (default: medium)
|
||||
assignees: JSON-encoded array of assignee names, e.g. '["Alice"]'
|
||||
due_date: Unix timestamp in milliseconds; 0 means no due date
|
||||
project_id: optional UUID of the parent project
|
||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="tasks",
|
||||
data={
|
||||
"title": title,
|
||||
"description": description or None,
|
||||
"status": status,
|
||||
"priority": priority,
|
||||
"assignee": assignees,
|
||||
"dueDate": due_date or None,
|
||||
"projectId": project_id or None,
|
||||
"isAiSuggested": is_ai_suggested,
|
||||
},
|
||||
)
|
||||
row = result["row"]
|
||||
return (
|
||||
f"Task created: '{row['title']}' "
|
||||
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def update_task(
|
||||
task_id: str,
|
||||
title: str = "",
|
||||
description: str = "",
|
||||
status: str = "",
|
||||
priority: str = "",
|
||||
assignees: str = "",
|
||||
due_date: int = -1,
|
||||
project_id: str = "",
|
||||
) -> str:
|
||||
"""Update fields on an existing task. Only pass fields you want to change.
|
||||
task_id: the task's UUID (required)
|
||||
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if title:
|
||||
updates["title"] = title
|
||||
if description:
|
||||
updates["description"] = description
|
||||
if status:
|
||||
updates["status"] = status
|
||||
if priority:
|
||||
updates["priority"] = priority
|
||||
if assignees:
|
||||
updates["assignee"] = assignees
|
||||
if due_date != -1:
|
||||
updates["dueDate"] = due_date or None
|
||||
if project_id:
|
||||
updates["projectId"] = project_id
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="tasks",
|
||||
data={"id": task_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_task(task_id: str) -> str:
|
||||
"""Delete a task permanently by its UUID."""
|
||||
await execute_on_client(action="delete", table="tasks", data={"id": task_id})
|
||||
return f"Task {task_id} deleted."
|
||||
|
||||
|
||||
@tool
|
||||
async def list_tasks_due_today() -> str:
|
||||
"""List all tasks whose due date falls on today's date."""
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
|
||||
end_ms = start_ms + 86_400_000 - 1 # last ms of today
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="tasks",
|
||||
filters={"dueDateFrom": start_ms, "dueDateTo": end_ms},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No tasks are due today."
|
||||
lines = [
|
||||
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
# ── Task comment tools ────────────────────────────────────────────────
|
||||
|
||||
|
||||
@tool
|
||||
async def list_task_comments(task_id: str) -> str:
|
||||
"""List all comments on a task by its UUID."""
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="taskComments",
|
||||
filters={"taskId": task_id},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return f"No comments found for task {task_id}."
|
||||
lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows]
|
||||
return f"Found {len(rows)} comment(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
||||
"""Add a comment to a task.
|
||||
task_id: UUID of the task to comment on
|
||||
author: name or ID of the comment author
|
||||
content: comment text
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="taskComments",
|
||||
data={"taskId": task_id, "author": author, "content": content},
|
||||
)
|
||||
row = result.get("row", {})
|
||||
row_author = row.get("author", author)
|
||||
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
||||
row_comment_id = row.get("id", "unknown")
|
||||
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_task_comment(comment_id: str) -> str:
|
||||
"""Delete a task comment by its UUID."""
|
||||
await execute_on_client(action="delete", table="taskComments", data={"id": comment_id})
|
||||
return f"Comment {comment_id} deleted."
|
||||
|
||||
|
||||
# ── Exports ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
TASK_TOOLS: list[Any] = [
|
||||
list_tasks,
|
||||
create_task,
|
||||
update_task,
|
||||
delete_task,
|
||||
list_tasks_due_today,
|
||||
list_task_comments,
|
||||
add_task_comment,
|
||||
delete_task_comment,
|
||||
]
|
||||
116
shared/agents/timeline_agent.py
Normal file
116
shared/agents/timeline_agent.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Timeline agent — project milestone management (list, create, update, delete).
|
||||
|
||||
Shared tool definitions used by both Chat and Batch Agent services.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from shared.ws_context import execute_on_client
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(value: str) -> bool:
|
||||
return bool(_UUID_RE.match(value))
|
||||
|
||||
TIMELINE_SYSTEM_PROMPT = (
|
||||
"You are a project timeline assistant. Timelines are milestone dates that\n"
|
||||
"track progress on a project — they are not calendar events.\n\n"
|
||||
"Rules:\n"
|
||||
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
||||
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
|
||||
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
||||
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
||||
" - For update_timeline, use -1 for integer fields you do not want to change\n"
|
||||
" - Listing without a project_id returns all timelines across projects\n"
|
||||
" - Always echo the title and formatted date in your confirmation."
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def list_timelines(project_id: str = "") -> str:
|
||||
"""List timelines. Provide project_id to scope to a specific project."""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="timelines",
|
||||
filters={"projectId": normalized_project_id or None},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No timelines found."
|
||||
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
|
||||
return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def create_timeline(
|
||||
project_id: str,
|
||||
title: str,
|
||||
date: int,
|
||||
is_ai_suggested: int = 0,
|
||||
) -> str:
|
||||
"""Create a project timeline (milestone).
|
||||
project_id: REQUIRED UUID of the parent project
|
||||
title: descriptive name for the milestone
|
||||
date: Unix timestamp in milliseconds
|
||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="timelines",
|
||||
data={
|
||||
"projectId": project_id,
|
||||
"title": title,
|
||||
"date": date,
|
||||
"isAiSuggested": is_ai_suggested,
|
||||
},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
|
||||
|
||||
|
||||
@tool
|
||||
async def update_timeline(
|
||||
timeline_id: str,
|
||||
title: str = "",
|
||||
date: int = -1,
|
||||
) -> str:
|
||||
"""Update a timeline. Only pass fields that should change.
|
||||
timeline_id: UUID of the timeline (required)
|
||||
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if title:
|
||||
updates["title"] = title
|
||||
if date != -1:
|
||||
updates["date"] = date
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="timelines",
|
||||
data={"id": timeline_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Timeline updated: '{row['title']}' (id: {row['id']})"
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_timeline(timeline_id: str) -> str:
|
||||
"""Delete a timeline permanently by its UUID."""
|
||||
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
|
||||
return f"Timeline {timeline_id} deleted."
|
||||
|
||||
|
||||
TIMELINE_TOOLS: list[Any] = [
|
||||
list_timelines,
|
||||
create_timeline,
|
||||
update_timeline,
|
||||
delete_timeline,
|
||||
]
|
||||
72
shared/llm.py
Normal file
72
shared/llm.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||
|
||||
Shared by Chat and Batch Agent services.
|
||||
Uses shared.config.settings for all configuration.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
import litellm
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_litellm import ChatLiteLLM
|
||||
|
||||
from shared.config import settings
|
||||
|
||||
litellm.drop_params = True
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
||||
category=UserWarning,
|
||||
)
|
||||
|
||||
|
||||
def _api_key_for_model(model: str) -> str | None:
|
||||
if model.startswith("anthropic/"):
|
||||
return settings.ANTHROPIC_API_KEY or None
|
||||
if model.startswith("gemini/") or model.startswith("google/"):
|
||||
return settings.GOOGLE_API_KEY or None
|
||||
if model.startswith("cerebras/"):
|
||||
return settings.CEREBRAS_API_KEY or None
|
||||
if model.startswith("github_copilot/"):
|
||||
return None
|
||||
return settings.OPENAI_API_KEY or None
|
||||
|
||||
|
||||
def get_llm(
|
||||
*,
|
||||
model: str | None = None,
|
||||
temperature: float = 0,
|
||||
callbacks: list | None = None,
|
||||
) -> ChatOpenAI | ChatLiteLLM:
|
||||
model = model or settings.LLM_MODEL
|
||||
|
||||
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
||||
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
||||
|
||||
if "/" in model:
|
||||
return ChatLiteLLM(model=model, temperature=temperature, callbacks=callbacks)
|
||||
|
||||
return ChatOpenAI(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
api_key=_api_key_for_model(model),
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
|
||||
async def embed(text: str) -> list[float]:
|
||||
model = settings.LLM_EMBED_MODEL
|
||||
|
||||
if model.startswith("github_copilot/") or "/" in model:
|
||||
response = await litellm.aembedding(model=model, input=[text])
|
||||
return response.data[0]["embedding"]
|
||||
|
||||
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
||||
response = await client.embeddings.create(model=model, input=text)
|
||||
return response.data[0].embedding
|
||||
132
shared/ws_context.py
Normal file
132
shared/ws_context.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""WebSocket context — Redis-based tool call round-trip.
|
||||
|
||||
Shared by Chat and Batch Agent services. Publishes tool_call frames to
|
||||
Redis ``ws:out:{user_id}`` and awaits the result via BRPOP on
|
||||
``tool:result:{call_id}``.
|
||||
|
||||
Also provides ``set_client_executor`` / ``clear_client_executor`` no-op
|
||||
shims for backward compatibility with agent_runner code that originally
|
||||
used a DeviceConnectionManager callback.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Callable, Coroutine
|
||||
from uuid import uuid4
|
||||
|
||||
from shared.redis import redis_client, tool_result_key, ws_out_channel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TOOL_CALL_TIMEOUT = 30 # seconds — BRPOP timeout
|
||||
|
||||
# Per-request user_id context var (set before agent runs)
|
||||
_current_user_id: ContextVar[str | None] = ContextVar("_current_user_id", default=None)
|
||||
|
||||
# Optional collector for debug
|
||||
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||
"_tool_result_collector", default=None
|
||||
)
|
||||
|
||||
|
||||
def set_current_user(user_id: str) -> None:
|
||||
_current_user_id.set(user_id)
|
||||
|
||||
|
||||
def clear_current_user() -> None:
|
||||
_current_user_id.set(None)
|
||||
|
||||
|
||||
def set_tool_result_collector(lst: list[dict]) -> None:
|
||||
_tool_result_collector.set(lst)
|
||||
|
||||
|
||||
def clear_tool_result_collector() -> None:
|
||||
_tool_result_collector.set(None)
|
||||
|
||||
|
||||
# ── Compatibility shims ──────────────────────────────────────────────────
|
||||
# agent_runner.py originally called set_client_executor / clear_client_executor
|
||||
# with a DeviceConnectionManager callback. In the microservice world the
|
||||
# Redis-based execute_on_client replaces this, so these are no-ops.
|
||||
|
||||
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]] | None) -> None:
|
||||
"""No-op — kept for agent_runner compatibility."""
|
||||
|
||||
|
||||
def clear_client_executor() -> None:
|
||||
"""No-op — kept for agent_runner compatibility."""
|
||||
|
||||
|
||||
async def execute_on_client(
|
||||
action: str,
|
||||
table: str | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
filters: dict[str, Any] | None = None,
|
||||
vector: list[float] | None = None,
|
||||
limit: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Send a tool_call to Electron via Redis and await the result.
|
||||
|
||||
1. Build tool_call payload
|
||||
2. Publish to ws:out:{user_id} (WS Gateway forwards to Electron)
|
||||
3. BRPOP on tool:result:{call_id} (WS Gateway pushes when Electron replies)
|
||||
4. Return result dict
|
||||
|
||||
Raises RuntimeError if no user_id is set or if the call times out.
|
||||
"""
|
||||
user_id = _current_user_id.get()
|
||||
if not user_id:
|
||||
raise RuntimeError(
|
||||
"execute_on_client() called without a user_id — "
|
||||
"set_current_user() must be called first."
|
||||
)
|
||||
|
||||
call_id = str(uuid4())
|
||||
payload: dict[str, Any] = {
|
||||
"type": "tool_call",
|
||||
"id": call_id,
|
||||
"action": action,
|
||||
}
|
||||
if table is not None:
|
||||
payload["table"] = table
|
||||
if data is not None:
|
||||
payload["data"] = data
|
||||
if filters is not None:
|
||||
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
|
||||
if vector is not None:
|
||||
payload["vector"] = vector
|
||||
if limit is not None:
|
||||
payload["limit"] = limit
|
||||
|
||||
# Publish tool_call to WS Gateway → Electron
|
||||
channel = ws_out_channel(user_id)
|
||||
await redis_client.publish(channel, json.dumps(payload))
|
||||
|
||||
# Wait for Electron's tool_result
|
||||
result_key = tool_result_key(call_id)
|
||||
response = await redis_client.brpop(result_key, timeout=_TOOL_CALL_TIMEOUT)
|
||||
|
||||
if response is None:
|
||||
raise RuntimeError(
|
||||
f"Tool call {call_id} timed out after {_TOOL_CALL_TIMEOUT}s — "
|
||||
f"device may be offline or unresponsive."
|
||||
)
|
||||
|
||||
# response is (key, value) tuple
|
||||
_, raw = response
|
||||
result = json.loads(raw)
|
||||
|
||||
# Collect for debug if requested
|
||||
collector = _tool_result_collector.get(None)
|
||||
if collector is not None:
|
||||
collector.append({
|
||||
"action": action,
|
||||
"table": table,
|
||||
"data": result,
|
||||
})
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user