11 Commits

Author SHA1 Message Date
Roberto Musso
333bba6fdd feat(batch-agent): extract Batch Agent Service (Step 3)
- agent_runner: local directory + cloud agent orchestration via Redis
- 5 domain agents: filesystem, task, note, project, timeline
- integrations: Gmail, MS Graph (Outlook + Teams)
- journey: guided chatbot conversation to build prompt_template
- routes: REST endpoints (catalog, can-create, trigger)
- redis_consumer: subscribes to batch:request:* pattern
- ws_context: Redis-based execute_on_client for tool round-trip
- Dockerfile with 300s timeout for long-running batch jobs
2026-03-23 07:19:02 +01:00
Roberto Musso
229e20d073 docs: add Langfuse integration TODO for batch-agent service 2026-03-23 00:25:42 +01:00
Roberto Musso
0b491b3643 fix: langfuse v4 SDK compatibility and pass user message as trace input 2026-03-23 00:23:59 +01:00
Roberto Musso
0d5fa3e569 feat(chat): integrate Langfuse tracing, prompt management & generation tracking
- shared/config.py: add LANGFUSE_SECRET_KEY, LANGFUSE_PUBLIC_KEY, LANGFUSE_HOST
- services/chat/app/tracing.py: new module — Langfuse client singleton,
  create_trace(), get_langfuse_callback(), get_prompt(), link_prompt_to_trace(),
  score_trace(), flush/shutdown helpers. Gracefully no-ops when keys are missing.
- services/chat/app/llm.py: add callbacks param to get_llm() for LangChain
  callback handler injection
- services/chat/app/deep_agent.py: accept langfuse_handler in all run_* and
  _run_single_agent* functions, pipe callbacks to LLM calls, fetch managed
  prompts from Langfuse with fallback to hardcoded system prompts
- services/chat/app/redis_consumer.py: create Langfuse trace per request
  (home_request/floating_request), pass callback handler to deep_agent,
  link prompt name to trace, attach output preview, flush after each request
- services/chat/app/main.py: shutdown Langfuse client in lifespan teardown
- services/chat/requirements.txt: add langfuse>=2.0.0

Langfuse prompt names: 'home_system', 'floating_system' — create these in
the Langfuse dashboard to manage prompts. Without them, hardcoded defaults
are used transparently.
2026-03-22 23:15:04 +01:00
Roberto Musso
aff68a9051 fix: shared config loads root .env as fallback for microservices 2026-03-22 22:42:54 +01:00
Roberto Musso
5e9ef2809e fix: add extra=ignore to monolith Settings for strangler fig compat 2026-03-22 22:28:50 +01:00
Roberto Musso
90018af311 feat: add WS Gateway and Chat Service (Step 2)
WS Gateway:
- WebSocket lifecycle handler with RS256 JWT auth
- Redis bridge: device registry, frame publishing, tool_result routing
- Inbound routing: tool_result→LPUSH, home/floating→chat pub/sub
- Outbound: subscribes to ws:out:{user_id}, forwards to Electron
- Single-worker Dockerfile (long-lived WS connections)

Chat Service:
- Redis consumer: subscribes to chat:request:* pattern
- Redis-based ws_context: tool_call→publish, BRPOP tool_result (30s timeout)
- deep_agent: single-agent runner with home/floating/stream variants
- memory_middleware: core/associative/episodic/proactive memory with Fernet
- Domain agents: task (8 tools), note (5), project (6), timeline (4)
- LLM factory via LiteLLM (100+ providers)
- Output formatter (StreamFormatter)
- POST /chat REST fallback with Traefik header auth
- Multi-worker Dockerfile with 120s timeout for LLM calls
2026-03-22 01:20:11 +01:00
Roberto Musso
1e2e395676 fix: PEM newline parsing + shared config extra=ignore
- Add field_validator to expand literal \n in PEM keys (auth config + shared config)
- Set extra='ignore' on shared Settings so service-specific .env vars don't cause ValidationError
- Add *.pem to .gitignore
2026-03-22 01:03:28 +01:00
Roberto Musso
59d3a53980 chore: update .env.example files for RS256 + Redis
- Root .env.example: replace JWT_SECRET/JWT_ALGORITHM with JWT_PUBLIC_KEY, add REDIS_URL
- Auth Service .env.example: JWT_PRIVATE_KEY + JWT_PUBLIC_KEY with generation instructions
2026-03-22 00:51:54 +01:00
Roberto Musso
9feeaa79c8 feat(auth): migrate JWT from HS256 to RS256
- Add services/auth/app/config.py with JWT_PRIVATE_KEY and JWT_PUBLIC_KEY
  (Auth Service local config - private key never leaves this service)
- Update routes.py: sign tokens with RS256 private key
- Update deps.py + verify.py: verify tokens with RS256 public key
- Update shared/config.py: replace JWT_SECRET/JWT_ALGORITHM with
  JWT_PUBLIC_KEY (for optional local verification by other services)
- Add sys.path fix in main.py for local dev without PYTHONPATH
2026-03-22 00:50:36 +01:00
Roberto Musso
aa219a4d08 feat: microservices scaffold + Auth Service (Step 1)
- Add shared/ module: config, db, models, schemas, redis utilities
- Add Auth Service (services/auth/): register, login, refresh, me,
  ForwardAuth /verify endpoint for Traefik
- Add Traefik config: ACME/Cloudflare DNS-01, dynamic routing,
  ForwardAuth middleware, sticky sessions for WS Gateway
- Add service scaffolds: ws-gateway, chat, batch-agent, billing (READMEs)
- Add redis>=5.0.0 to requirements.txt
- Monolith app/ is untouched — strangler fig migration
2026-03-22 00:29:51 +01:00
72 changed files with 8176 additions and 17 deletions

View File

@@ -4,9 +4,17 @@ ENV=dev
# ── Database ──────────────────────────────────────────────────────────────────
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
# ── Auth ──────────────────────────────────────────────────────────────────────
JWT_SECRET=replace-with-a-long-random-secret
JWT_ALGORITHM=HS256
# ── Redis ─────────────────────────────────────────────────────────────────────
REDIS_URL=redis://localhost:6379/0
# ── Auth (JWT RS256) ──────────────────────────────────────────────────────────
# Public key for optional local JWT verification (Traefik ForwardAuth handles
# this in production — services trust X-User-* headers from Traefik).
# Generate keypair:
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
# openssl rsa -in private.pem -pubout -out public.pem
# Paste PEM content with literal \n for newlines.
JWT_PUBLIC_KEY=
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
@@ -17,7 +25,6 @@ OPENAI_API_KEY=
ANTHROPIC_API_KEY=
GOOGLE_API_KEY=
LLM_MODEL=gpt-4o
LLM_ROUTER_MODEL=gpt-4o-mini
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
STRIPE_SECRET_KEY=
@@ -42,3 +49,8 @@ QDRANT_API_KEY=
# ── CORS ──────────────────────────────────────────────────────────────────────
# Comma-separated list parsed by Settings (override default if needed)
# CORS_ORIGINS=["app://.","http://localhost:3000"]
# ── Langfuse (observability) ─────────────────────────────────────────────────
LANGFUSE_SECRET_KEY=sk-lf-...
LANGFUSE_PUBLIC_KEY=pk-lf-...
LANGFUSE_HOST=https://cloud.langfuse.com # or self-hosted URL

3
.gitignore vendored
View File

@@ -13,6 +13,9 @@ env/
# Environment variables
.env
# Cryptographic keys
*.pem
# IDE
.vscode/
.idea/

View File

@@ -739,7 +739,7 @@ adiuva-api/
│ │
│ ├── core/ # Orchestration engine
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
│ │ ├── llm.py # LiteLLM factory (get_llm, get_router_llm)
│ │ ├── llm.py # LiteLLM factory (get_llm)
│ │ ├── orchestrator.py # Intent classification & routing
│ │ └── execution_plan.py # Plan builder, templates, cache
│ │

View File

@@ -29,7 +29,6 @@ class Settings(BaseSettings):
CEREBRAS_API_KEY: str = ""
LLM_MODEL: str = "gpt-4o"
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
LLM_EMBED_MODEL: str = "text-embedding-3-small"
# GitHub Copilot OAuth token storage directory.
@@ -54,7 +53,9 @@ class Settings(BaseSettings):
ENV: Literal["dev", "prod"] = "dev"
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", extra="ignore"
)
settings = Settings()

View File

@@ -1,6 +1,6 @@
"""LLM factory — centralised model instantiation via LiteLLM.
Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()``
Every agent and the orchestrator call ``get_llm()``
instead of directly constructing a provider-specific class. The model string
follows the `LiteLLM model naming convention
<https://docs.litellm.ai/docs/providers>`_:
@@ -11,7 +11,7 @@ follows the `LiteLLM model naming convention
* Ollama: ``ollama/llama3``
* Bedrock: ``bedrock/anthropic.claude-v2``
Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
Switch providers by changing **LLM_MODEL** in ``.env``
— no code changes required.
"""
@@ -95,14 +95,6 @@ def get_llm(
)
def get_router_llm(
*,
temperature: float = 0,
) -> ChatOpenAI | ChatLiteLLM:
"""Return the lighter model used for intent classification / routing."""
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
async def embed(text: str) -> list[float]:
"""Return an embedding vector for *text*.

View File

@@ -32,4 +32,6 @@ google-auth-oauthlib>=1.2.0
google-auth-httplib2>=0.2.0
msal>=1.28.0
cryptography>=42.0.0
redis>=5.0.0
langfuse>=3.0.0
ruff>=0.8.0

View File

@@ -0,0 +1,19 @@
# ── Auth Service ──────────────────────────────────────────────────────────────
# This file contains env vars specific to the Auth Service.
# Shared vars (DATABASE_URL, REDIS_URL, etc.) come from the root .env
# or from docker-compose environment.
# ── JWT RS256 Keys ────────────────────────────────────────────────────────────
# Generate keypair:
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
# openssl rsa -in private.pem -pubout -out public.pem
#
# Paste PEM content with literal \n for newlines:
# JWT_PRIVATE_KEY=-----BEGIN PRIVATE KEY-----\nMIIEvQ...
# JWT_PUBLIC_KEY=-----BEGIN PUBLIC KEY-----\nMIIBIj...
# PRIVATE KEY — used to SIGN JWTs. NEVER share outside this service.
JWT_PRIVATE_KEY=
# PUBLIC KEY — used to VERIFY JWTs.
JWT_PUBLIC_KEY=

36
services/auth/Dockerfile Normal file
View File

@@ -0,0 +1,36 @@
# ── builder ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS builder
WORKDIR /build
# Install shared + service deps in one layer
COPY services/auth/requirements.txt ./requirements.txt
RUN pip install --upgrade pip && \
pip install --no-cache-dir --prefix=/install -r requirements.txt
# ── runtime ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS runtime
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
WORKDIR /app
COPY --from=builder /install /usr/local
# Copy shared module (available to all services)
COPY shared/ shared/
# Copy service source
COPY services/auth/app/ app/
RUN chown -R appuser:appgroup /app
USER appuser
EXPOSE 8000
CMD ["gunicorn", "app.main:app", \
"-k", "uvicorn.workers.UvicornWorker", \
"--bind", "0.0.0.0:8000", \
"--workers", "2", \
"--timeout", "30"]

16
services/auth/README.md Normal file
View File

@@ -0,0 +1,16 @@
# Auth Service
Owns: user registration, login, JWT RS256 issuance, token refresh, `/me` endpoint.
## Tables owned
- `users`
- `refresh_tokens`
- `subscriptions` (read; Billing Service writes)
## Endpoints
- `POST /auth/register`
- `POST /auth/login`
- `POST /auth/refresh`
- `GET /auth/me`
- `PUT /auth/me`
- `GET /auth/verify` (ForwardAuth for Traefik)

View File

View File

@@ -0,0 +1,34 @@
"""Auth Service — local configuration.
Contains secrets that ONLY the Auth Service needs (e.g., JWT private key).
These are NOT in shared/config.py to prevent other services from accessing them.
"""
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class AuthSettings(BaseSettings):
# RS256 private key (PEM format). Used to SIGN JWTs.
# Only the Auth Service has this. Generate with:
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
# Then set the env var (newlines as \n):
# JWT_PRIVATE_KEY="-----BEGIN PRIVATE KEY-----\nMIIEv..."
JWT_PRIVATE_KEY: str = ""
# RS256 public key (PEM format). Used to VERIFY JWTs.
# Derived from the private key:
# openssl rsa -in private.pem -pubout -out public.pem
JWT_PUBLIC_KEY: str = ""
@field_validator("JWT_PRIVATE_KEY", "JWT_PUBLIC_KEY", mode="before")
@classmethod
def _expand_pem_newlines(cls, v: str) -> str:
if isinstance(v, str) and r"\n" in v:
return v.replace(r"\n", "\n")
return v
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
auth_settings = AuthSettings()

69
services/auth/app/deps.py Normal file
View File

@@ -0,0 +1,69 @@
"""Auth dependencies — JWT validation for the Auth Service.
This is the canonical get_current_user used by protected endpoints
within the Auth Service itself (/me, /me PUT).
"""
from __future__ import annotations
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from shared.config import settings
from shared.db import get_session
from shared.models import Subscription, User
from shared.schemas import UserProfile
from app.config import auth_settings
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Validate a Bearer JWT and return the authenticated user.
The JWT is used for identity and expiry. Tier is fetched live from the
subscriptions table so upgrades/downgrades take effect immediately.
"""
credentials_exc = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
)
user_id: str | None = payload.get("sub")
email: str | None = payload.get("email")
if not user_id or not email:
raise credentials_exc
except JWTError:
raise credentials_exc
# Live tier lookup
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
default_tier = "power" if settings.ENV == "dev" else "free"
tier: str = result.scalar_one_or_none() or default_tier
# Fetch name/surname
user_result = await db.execute(
select(User.name, User.surname).where(User.id == user_id)
)
user_row = user_result.one_or_none()
return UserProfile(
id=user_id,
email=email,
name=user_row.name if user_row else None,
surname=user_row.surname if user_row else None,
tier=tier,
) # type: ignore[arg-type]

62
services/auth/app/main.py Normal file
View File

@@ -0,0 +1,62 @@
"""Auth Service — JWT issuance, user management, ForwardAuth verification.
Standalone FastAPI service extracted from the adiuva-api monolith.
Owns: users, refresh_tokens, subscriptions (read).
"""
import sys
from contextlib import asynccontextmanager
from pathlib import Path
# Ensure the repo root is on sys.path so "shared" is importable.
# In Docker, COPY shared/ puts it at /app/shared/ (already importable).
# In local dev, we need to add the repo root (two levels up from this file).
_repo_root = str(Path(__file__).resolve().parents[3])
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from shared.config import settings
@asynccontextmanager
async def lifespan(app: FastAPI):
yield
from shared.db import engine
await engine.dispose()
def create_app() -> FastAPI:
app = FastAPI(
title="Adiuva Auth Service",
version="0.1.0",
docs_url="/docs" if settings.ENV == "dev" else None,
redoc_url=None,
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
from app.routes import router
from app.verify import router as verify_router
app.include_router(router, prefix="/api/v1")
app.include_router(verify_router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"])
async def health() -> dict:
return {"status": "ok", "service": "auth", "version": app.version}
return app
app = create_app()

249
services/auth/app/routes.py Normal file
View File

@@ -0,0 +1,249 @@
"""Auth routes: register, login, refresh, me.
Extracted from app/api/routes/auth.py — uses shared.* imports instead of app.*.
"""
from __future__ import annotations
import hashlib
import time
import uuid
from datetime import datetime, timedelta, timezone
import bcrypt
from cryptography.fernet import Fernet
from fastapi import APIRouter, Depends, HTTPException, status
from jose import jwt
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from shared.config import settings
from shared.db import get_session
from shared.models import RefreshToken, Subscription, User
from shared.schemas import AuthTokens, UserProfile
from app.config import auth_settings
from app.deps import get_current_user
router = APIRouter(prefix="/auth", tags=["auth"])
# ── Internal helpers ─────────────────────────────────────────────────
def _hash_password(password: str) -> str:
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
def _verify_password(password: str, hashed: str) -> bool:
return bcrypt.checkpw(password.encode(), hashed.encode())
def _hash_token(plain_token: str) -> str:
"""SHA-256 of the plain refresh token string."""
return hashlib.sha256(plain_token.encode()).hexdigest()
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
"""Return (RS256-signed JWT, expires_at_ms)."""
now = int(time.time())
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
payload = {
"sub": user_id,
"email": email,
"tier": tier,
"exp": exp,
"iat": now,
}
token = jwt.encode(payload, auth_settings.JWT_PRIVATE_KEY, algorithm="RS256")
return token, exp * 1000 # ms for client
async def _get_live_tier(db: AsyncSession, user_id: str) -> str:
"""Fetch authoritative tier from subscriptions table."""
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
default_tier = "power" if settings.ENV == "dev" else "free"
return result.scalar_one_or_none() or default_tier
# ── Request bodies ────────────────────────────────────────────────────
class _RegisterRequest(BaseModel):
email: str
password: str
name: str | None = None
surname: str | None = None
class _LoginRequest(BaseModel):
email: str
password: str
class _RefreshRequest(BaseModel):
refresh_token: str
class _UpdateProfileRequest(BaseModel):
name: str | None = None
surname: str | None = None
# ── Routes ────────────────────────────────────────────────────────────
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
async def register(
body: _RegisterRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Create a new account and return JWT tokens."""
existing = await db.execute(select(User).where(User.email == body.email))
if existing.scalar_one_or_none() is not None:
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
user = User(
id=str(uuid.uuid4()),
email=body.email,
name=body.name,
surname=body.surname,
password_hash=_hash_password(body.password),
tier="free",
encryption_key=Fernet.generate_key().decode(),
)
db.add(user)
await db.flush()
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.post("/login", response_model=AuthTokens)
async def login(
body: _LoginRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Validate credentials and return JWT tokens."""
result = await db.execute(select(User).where(User.email == body.email))
user = result.scalar_one_or_none()
if user is None or not _verify_password(body.password, user.password_hash):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
# Fetch live tier for the JWT claim
tier = await _get_live_tier(db, user.id)
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.post("/refresh", response_model=AuthTokens)
async def refresh(
body: _RefreshRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Rotate a refresh token and return a new token pair."""
token_hash = _hash_token(body.refresh_token)
result = await db.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)
rt = result.scalar_one_or_none()
now = datetime.now(timezone.utc)
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
await db.delete(rt)
user_result = await db.execute(select(User).where(User.id == rt.user_id))
user = user_result.scalar_one_or_none()
if user is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
# Fetch live tier for the new JWT
tier = await _get_live_tier(db, user.id)
plain_token = str(uuid.uuid4())
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
new_rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=new_expires,
)
db.add(new_rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.get("/me", response_model=UserProfile)
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
"""Return the profile for the authenticated user."""
return current_user
@router.put("/me", response_model=UserProfile)
async def update_profile(
body: _UpdateProfileRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Update the authenticated user's name and surname."""
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
if body.name is not None:
user.name = body.name
if body.surname is not None:
user.surname = body.surname
await db.commit()
await db.refresh(user)
return UserProfile(
id=user.id,
email=user.email,
name=user.name,
surname=user.surname,
tier=current_user.tier,
)

View File

@@ -0,0 +1,66 @@
"""ForwardAuth verification endpoint for Traefik.
Traefik calls GET /api/v1/auth/verify on every request to a protected
service. This endpoint validates the JWT from the Authorization header
and returns identity headers that Traefik injects into downstream requests.
Downstream services NEVER validate JWTs themselves — they trust the
X-User-Id, X-User-Email, X-User-Tier headers injected by Traefik.
"""
from __future__ import annotations
from fastapi import APIRouter, Request, Response
from fastapi import status as http_status
from jose import JWTError, jwt
from sqlalchemy import select
from shared.config import settings
from shared.db import async_session
from shared.models import Subscription
from app.config import auth_settings
router = APIRouter(tags=["auth"])
@router.get("/auth/verify")
async def verify(request: Request) -> Response:
"""Validate JWT and return identity headers for Traefik ForwardAuth.
Returns 200 with X-User-* headers on success, 401 on failure.
Traefik copies response headers to the downstream request.
"""
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
token = auth_header[7:] # strip "Bearer "
try:
payload = jwt.decode(
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
)
user_id: str | None = payload.get("sub")
email: str | None = payload.get("email")
if not user_id or not email:
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
except JWTError:
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
# Live tier lookup from subscriptions table
async with async_session() as db:
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
default_tier = "power" if settings.ENV == "dev" else "free"
tier: str = result.scalar_one_or_none() or default_tier
return Response(
status_code=http_status.HTTP_200_OK,
headers={
"X-User-Id": user_id,
"X-User-Email": email,
"X-User-Tier": tier,
},
)

View File

@@ -0,0 +1,11 @@
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
gunicorn>=22.0.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
python-jose[cryptography]>=3.3.0
sqlalchemy>=2.0.0
asyncpg>=0.30.0
bcrypt>=4.2.0
cryptography>=42.0.0
python-dotenv>=1.0.0

View File

@@ -0,0 +1,36 @@
# ── builder ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS builder
WORKDIR /build
COPY services/batch-agent/requirements.txt ./requirements.txt
RUN pip install --upgrade pip && \
pip install --no-cache-dir --prefix=/install -r requirements.txt
# ── runtime ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS runtime
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
WORKDIR /app
COPY --from=builder /install /usr/local
# Shared module
COPY shared/ shared/
# Service source
COPY services/batch-agent/app/ app/
RUN chown -R appuser:appgroup /app
USER appuser
EXPOSE 8000
# Batch runs are long-lived — use a longer timeout than chat (300s vs 120s)
CMD ["gunicorn", "app.main:app", \
"-k", "uvicorn.workers.UvicornWorker", \
"--bind", "0.0.0.0:8000", \
"--workers", "2", \
"--timeout", "300"]

View File

@@ -0,0 +1,23 @@
# Batch Agent Service
Owns: agent_runner, journey builder, filesystem_agent, integrations (Gmail, MS Graph).
## Tables owned
- `local_agent_configs`
- `cloud_agent_configs`
- `agent_run_logs`
## Endpoints
- `GET /agents/catalog`
- `POST /agents/can-create`
- `POST /agents/trigger`
- `GET /agents/{id}/history`
## Redis channels
- Subscribe: `batch:request:{user_id}`
- Publish: `ws:out:{user_id}` (journey replies + tool calls)
- BRPOP: `tool:result:{call_id}` (30s timeout)
- SET+EX: `journey:{user_id}` (session state, TTL 1800s)
## TODO
- [ ] Integrate Langfuse tracing (reuse `services/chat/app/tracing.py` pattern — `trace_span()`, `get_langfuse_callback()`, prompt management). Each batch agent run should create a trace with input/output, link prompts, and pass the LangChain `CallbackHandler` to LLM calls.

View File

View File

@@ -0,0 +1,884 @@
"""Agent run orchestrator — adapted for Batch Agent Service.
Key changes from monolith app/core/agent_runner.py:
- No DeviceConnectionManager — tool calls go through Redis ws_context.
- set_current_user / clear_current_user replace set_client_executor.
- run_local_agent accepts a serialized dict (from Redis / REST) instead
of SQLAlchemy model objects.
- _finalize_run writes to PostgreSQL via shared.db.async_session.
- Cloud agent import path changed to app.integrations.
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from sqlalchemy import select
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
from app.agents.note_agent import NOTE_TOOLS
from app.agents.project_agent import PROJECT_TOOLS
from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS
from app.llm import get_llm
from app.ws_context import execute_on_client, set_current_user, clear_current_user
from shared.db import async_session
from shared.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
from shared.redis import redis_client, ws_out_channel
logger = logging.getLogger(__name__)
# ── Concurrency guard ─────────────────────────────────────────────────────
_running_agents: set[str] = set()
def is_agent_running(agent_id: str) -> bool:
return agent_id in _running_agents
# ── Timeouts ───────────────────────────────────────────────────────────────
_TOOL_CALL_TIMEOUT: int = 30
_MAX_PROCESSING_STEPS: int = 12
_MAX_SCAN_DEPTH: int = 5
# ── Data-type to tool mapping ─────────────────────────────────────────────
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
"tasks": TASK_TOOLS,
"notes": NOTE_TOOLS,
"timelines": TIMELINE_TOOLS,
}
# ── Step 1: Classification prompt ─────────────────────────────────────────
_DOMAIN_DESCRIPTIONS: dict[str, str] = {
"tasks": (
"Action items, to-dos, deliverables — anything that describes work to be done, "
"assigned to someone, or tracked with a due date or status."
),
"notes": (
"Documentation, meeting notes, summaries, reference material — "
"written content meant to be read and referenced rather than acted on."
),
"timelines": (
"Project milestones, deadlines, scheduled events — "
"specific dates that mark a point in the progress of a project."
),
"projects": (
"High-level project entities — only relevant if the file clearly introduces "
"a new project or updates the scope of an existing one."
),
}
_STEP1_SYSTEM_PROMPT = """\
You are a file classifier for a freelance project management tool.
Your job is to match a file to an existing project and identify which data domains to extract.
## Project matching rules (STRICT — follow in order)
1. Search the file content for any mention of a project name, client name, acronym, or topic
that overlaps with the existing projects listed below.
2. The match does NOT need to be exact — partial name, abbreviation, or topic similarity is enough.
3. STRONGLY PREFER matching an existing project. Only return "new" as an absolute last resort
when the file has zero meaningful connection to any listed project.
4. When in doubt, pick the closest match from the list.
## Response format
Respond ONLY with a JSON object — no markdown, no explanation:
{{"project_id": "<exact id from the list below, or new>", "new_project_name": "<concise 2-5 word name, only when project_id is new>", "domains": ["tasks", "notes"]}}
## Domain definitions (only consider domains in the allowed list)
{domain_definitions}
## Existing projects
{projects_list}
"""
# ── Step 2: Processing prompt ─────────────────────────────────────────────
_PROCESSING_SYSTEM_PROMPT = """\
You are a data extraction assistant for a freelance project management tool.
Your task: extract structured data from the file content and persist it using the available tools.
## Mandatory process — follow this order for EVERY item you extract
1. READ the existing records listed below for the relevant domain.
2. SEARCH for a match by title, topic, or semantic similarity.
3. If a match exists → call the update_* tool with the existing record's id.
4. If no match exists → call the create_* tool and set isAiSuggested=1.
NEVER call create_* without first checking the existing records.
NEVER duplicate a record that already exists under a different wording.
## Existing records (source of truth)
{existing_context}
## Context
Project: {project_context}
Domains to extract: {data_types}
{custom_prompt_section}
"""
# ── Cloud processing prompt ───────────────────────────────────────────────
_CLOUD_PROCESSING_PROMPT = """\
You are a data extraction and management assistant for a freelance project
management tool.
Available tools:
Filesystem : read_file_content, list_directory, get_file_metadata
Tasks : list_tasks, create_task, update_task, add_task_comment
Notes : list_notes, get_note, create_note, update_note
Timelines : list_timelines, create_timeline, update_timeline
Projects : list_all_projects, get_project, create_project, update_project
Your task:
1. Read the full content of each file below using read_file_content.
2. For each piece of information found, ALWAYS try to match and update an
existing record before creating a new one.
3. ONLY act on these entity types: {data_types}.
4. Do NOT invent data. Only extract what is clearly present in the files.
5. If a file contains no relevant data for the target entity types, skip it.
{project_context}
Files to process:
{file_list}
{custom_prompt_section}
After processing all files, respond with a brief summary of what you updated
and what you created.
"""
# ── LLM tool-calling loop ─────────────────────────────────────────────────
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
async def _run_agent_with_tools(
*,
system_prompt: str,
user_message: str,
tools: list[Any],
max_steps: int,
) -> str:
"""Run an LLM agent with tool-calling, returning the final text response."""
llm = get_llm()
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=system_prompt),
HumanMessage(content=user_message),
]
tool_map = {tool_def.name: tool_def for tool_def in tools}
for _ in range(max_steps):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
return _as_text(response.content)
for call in response.tool_calls:
call_id = str(call.get("id", ""))
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"agent_runner: tool_call name=%s args=%s",
call_name,
json.dumps(call_args, ensure_ascii=True)[:800],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"agent_runner: tool_result name=%s output=%s",
call_name,
str(tool_output)[:200],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
final = await llm.ainvoke(messages)
return _as_text(final.content)
# ── Tool list builder ─────────────────────────────────────────────────────
def _build_processing_tools(data_types: list[str]) -> list[Any]:
tools: list[Any] = list(FILESYSTEM_TOOLS)
for dt in data_types:
dt_tools = _DATA_TYPE_TOOLS.get(dt)
if dt_tools:
tools.extend(dt_tools)
return tools
# ── Code-based directory scanner ─────────────────────────────────────────
async def _scan_directories(
paths: list[str],
extensions: list[str],
last_run_at: datetime | None,
) -> list[str]:
all_files: list[str] = []
ext_set = {e.lstrip(".").lower() for e in extensions} if extensions else set()
async def _walk(path: str, depth: int) -> None:
if depth > _MAX_SCAN_DEPTH:
return
try:
result = await execute_on_client(action="list_directory", data={"path": path})
except Exception as exc:
logger.warning("agent_runner: list_directory failed %r: %s", path, exc)
return
for entry in result.get("entries", []):
entry_path = entry.get("path", "")
if not entry_path:
continue
if entry.get("type") == "directory":
await _walk(entry_path, depth + 1)
elif entry.get("type") == "file":
if ext_set:
dot_pos = entry_path.rfind(".")
file_ext = entry_path[dot_pos + 1:].lower() if dot_pos != -1 else ""
if file_ext not in ext_set:
continue
all_files.append(entry_path)
for root in paths:
await _walk(root, depth=0)
if last_run_at is None:
return all_files
last_run_ms = int(last_run_at.timestamp() * 1000)
filtered: list[str] = []
for file_path in all_files:
try:
meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
modified_at = meta.get("modifiedAt")
if modified_at is None:
filtered.append(file_path)
continue
if isinstance(modified_at, (int, float)):
mod_ms = int(modified_at)
else:
mod_ms = int(datetime.fromisoformat(str(modified_at)).timestamp() * 1000)
if mod_ms > last_run_ms:
filtered.append(file_path)
except Exception:
filtered.append(file_path)
return filtered
# ── Code-based entity fetchers ────────────────────────────────────────────
async def _fetch_projects() -> list[dict]:
try:
result = await execute_on_client(action="select", table="projects")
return result.get("rows", [])
except Exception as exc:
logger.warning("agent_runner: failed to fetch projects: %s", exc)
return []
_DOMAIN_TABLE: dict[str, str] = {
"tasks": "tasks",
"notes": "notes",
"timelines": "timelines",
"projects": "projects",
}
async def _fetch_domain_entities(domain: str, project_id: str) -> list[dict]:
table = _DOMAIN_TABLE.get(domain)
if not table:
return []
filters: dict[str, Any] = {}
if project_id != "standalone" and domain != "projects":
filters["projectId"] = project_id
try:
result = await execute_on_client(
action="select",
table=table,
filters=filters if filters else None,
)
return result.get("rows", [])
except Exception as exc:
logger.warning("agent_runner: failed to fetch %s: %s", domain, exc)
return []
def _format_entities_for_context(domain: str, rows: list[dict]) -> str:
if not rows:
return f"No existing {domain}."
lines: list[str] = []
for r in rows:
if domain == "tasks":
desc = r.get("description") or ""
desc_part = f"{desc[:120]}" if desc else ""
assignee = r.get("assignee") or r.get("assignees") or ""
due = r.get("dueDate") or r.get("due_date") or ""
meta = ", ".join(filter(None, [
f"priority: {r.get('priority', '')}" if r.get("priority") else "",
f"assignee: {assignee}" if assignee else "",
f"due: {due}" if due else "",
]))
lines.append(
f" - [{r.get('status', '?')}] {r.get('title', '')}{desc_part}"
f" ({meta}, id: {r['id']})"
)
elif domain == "notes":
snippet = (r.get("content") or "")[:200].replace("\n", " ")
snippet_part = f"\n Preview: {snippet}" if snippet else ""
lines.append(
f" - {r.get('title', '')} (id: {r['id']}){snippet_part}"
)
elif domain == "timelines":
lines.append(
f" - {r.get('title', '')} date={r.get('date', '')} (id: {r['id']})"
)
elif domain == "projects":
summary = (r.get("aiSummary") or r.get("ai_summary") or "")[:120]
summary_part = f"{summary}" if summary else ""
lines.append(
f" - {r.get('name', '')} [{r.get('status', '')}]{summary_part}"
f" (id: {r['id']})"
)
return f"Existing {domain}:\n" + "\n".join(lines)
# ── Step 1: LLM file classifier ───────────────────────────────────────────
async def _classify_file(
file_path: str,
file_content: str,
projects: list[dict],
config_data_types: list[str],
) -> tuple[str, list[str], str | None]:
fallback: tuple[str, list[str], str | None] = ("new", list(config_data_types), None)
if not file_content.strip():
return fallback
valid_project_ids = {p["id"] for p in projects}
def _fmt_project(p: dict) -> str:
summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip()
summary_part = f"{summary[:100]}" if summary else ""
return f" - id={p['id']} | name={p.get('name', '')} | status={p.get('status', '')}{summary_part}"
projects_list = "\n".join(_fmt_project(p) for p in projects) or " (none yet)"
domain_definitions = "\n".join(
f" - {d}: {_DOMAIN_DESCRIPTIONS[d]}"
for d in config_data_types
if d in _DOMAIN_DESCRIPTIONS
)
system = _STEP1_SYSTEM_PROMPT.format(
domain_definitions=domain_definitions,
projects_list=projects_list,
)
llm = get_llm()
try:
response = await llm.ainvoke([
SystemMessage(content=system),
HumanMessage(content=f"File: {file_path}\n\nContent:\n{file_content[:4000]}"),
])
raw = _as_text(response.content).strip()
if raw.startswith("```"):
raw = raw.split("```")[1]
if raw.startswith("json"):
raw = raw[4:]
parsed = json.loads(raw.strip())
raw_project_id: str = str(parsed.get("project_id") or "new")
project_id = raw_project_id if raw_project_id in valid_project_ids else "new"
new_project_name: str | None = (
str(parsed["new_project_name"]).strip() or None
if project_id == "new" and parsed.get("new_project_name")
else None
)
domains: list[str] = [
d for d in parsed.get("domains", [])
if d in config_data_types
]
if not domains:
domains = list(config_data_types)
return project_id, domains, new_project_name
except Exception as exc:
logger.warning(
"agent_runner: step1 classification failed for %r: %s", file_path, exc
)
return fallback
# ── Local agent runner (two-step per file) ────────────────────────────────
async def run_local_agent(user_id: str, trigger_data: dict[str, Any]) -> None:
"""Execute a local directory agent run.
In the microservice world, trigger_data is a serialized dict from
the REST route (forwarded via Redis), containing the agent config
fields and run_context.
set_current_user() must be called BEFORE this function.
"""
run_context: dict = trigger_data.get("run_context", {})
agent_id = run_context.get("agent_id", str(uuid.uuid4()))
run_id = run_context.get("run_id")
_running_agents.add(agent_id)
# Extract config from trigger payload
directory_paths: list[str] = trigger_data.get("directory_paths", [])
if not directory_paths:
directory = trigger_data.get("directory", "")
if directory:
directory_paths = [directory]
data_types: list[str] = trigger_data.get("data_types", [])
file_extensions: list[str] = trigger_data.get("file_extensions", [])
prompt_template: str = trigger_data.get("prompt_template", "")
last_run_at_raw = trigger_data.get("last_run_at")
last_run_at: datetime | None = None
if last_run_at_raw:
if isinstance(last_run_at_raw, str):
last_run_at = datetime.fromisoformat(last_run_at_raw)
elif isinstance(last_run_at_raw, (int, float)):
last_run_at = datetime.fromtimestamp(last_run_at_raw / 1000, tz=timezone.utc)
errors: list[str] = []
items_processed = 0
items_created = 0
custom_section = (
f"User instructions:\n{prompt_template}"
if prompt_template
else ""
)
# Create or load run log
run_log_id = run_id
if not run_log_id:
async with async_session() as db:
run_log = AgentRunLog(
agent_id=agent_id,
agent_type="local",
user_id=user_id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_log_id = run_log.id
try:
# ── Scan directories ─────────────────────────────────────────
logger.info("agent_runner: run=%s scanning directories user=%s", run_log_id, user_id)
file_paths = await _scan_directories(
paths=directory_paths,
extensions=file_extensions,
last_run_at=last_run_at,
)
logger.info(
"agent_runner: run=%s found %d file(s) after filtering", run_log_id, len(file_paths)
)
if not file_paths:
await _finalize_run(run_log_id, status="success", items_processed=0, items_created=0)
return
# ── Fetch all projects once ──────────────────────────────────
projects = await _fetch_projects()
for file_path in file_paths:
try:
file_result = await execute_on_client(
action="read_file_content", data={"path": file_path}
)
file_content: str = file_result.get("content", "")
if not file_content:
continue
items_processed += 1
# Step 1 — classify file
project_id, domains, new_project_name = await _classify_file(
file_path=file_path,
file_content=file_content,
projects=projects,
config_data_types=data_types,
)
# Step 2 — resolve project_id, fetch entities, process
if project_id == "new":
proj_name = new_project_name or "Untitled Project"
try:
proj_result = await execute_on_client(
action="insert",
table="projects",
data={"name": proj_name, "clientId": None},
)
created = proj_result.get("row", {})
effective_project_id = created.get("id", "standalone")
if "id" in created:
projects.append(created)
except Exception as exc:
logger.warning("agent_runner: run=%s create project failed: %s", run_log_id, exc)
effective_project_id = "standalone"
proj_name = "unknown"
project_context = (
f"Project: {proj_name} (id: {effective_project_id}). "
"Always set projectId to this id on every record you create."
)
else:
effective_project_id = project_id
proj = next((p for p in projects if p["id"] == project_id), None)
proj_name = proj.get("name", project_id) if proj else project_id
project_context = (
f"Project: {proj_name} (id: {project_id}). "
"Always set projectId to this id on every record you create."
)
domains = [d for d in domains if d != "projects"]
existing_blocks: list[str] = []
for domain in domains:
rows = await _fetch_domain_entities(domain, effective_project_id)
existing_blocks.append(_format_entities_for_context(domain, rows))
existing_context = "\n\n".join(existing_blocks)
system_prompt = _PROCESSING_SYSTEM_PROMPT.format(
existing_context=existing_context,
project_context=project_context,
data_types=", ".join(domains),
custom_prompt_section=custom_section,
)
processing_tools = _build_processing_tools(domains)
result_text = await _run_agent_with_tools(
system_prompt=system_prompt,
user_message=(
f"Process this file and extract relevant information.\n\n"
f"File: {file_path}\n\nContent:\n{file_content}"
),
tools=processing_tools,
max_steps=_MAX_PROCESSING_STEPS,
)
logger.info(
"agent_runner: run=%s file=%r result=%s",
run_log_id, file_path, result_text[:200],
)
except Exception as exc:
errors.append(f"Error processing '{file_path}': {exc}")
logger.error("agent_runner: run=%s file=%r failed: %s", run_log_id, file_path, exc)
except Exception as exc:
errors.append(f"Agent run failed: {exc}")
logger.error("agent_runner: run=%s failed: %s", run_log_id, exc)
finally:
_running_agents.discard(agent_id)
# ── Finalise ────────────────────────────────────────────────────
if errors and items_processed == 0:
final_status = "error"
elif errors:
final_status = "partial"
else:
final_status = "success"
await _finalize_run(
run_log_id,
status=final_status,
items_processed=items_processed,
items_created=items_created,
errors=errors,
)
# Notify Electron that the run is complete via Redis
if run_context:
try:
channel = ws_out_channel(user_id)
await redis_client.publish(channel, json.dumps({
"type": "run_complete",
"run_context": run_context,
"status": final_status,
}))
except Exception as exc:
logger.warning("agent_runner: run=%s failed to send run_complete: %s", run_log_id, exc)
# ── Cloud agent runner ─────────────────────────────────────────────────────
_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
async def run_cloud_agent(user_id: str, config_id: str) -> None:
"""Execute a cloud connector agent run.
Loads the CloudAgentConfig from DB, decrypts OAuth tokens, fetches
messages from the provider, and runs LLM extraction.
set_current_user() must be called BEFORE this function.
"""
from app.integrations import decrypt_token, encrypt_token, get_provider
async with async_session() as db:
result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
)
config = result.scalar_one_or_none()
if config is None:
logger.error("agent_runner: cloud config %s not found", config_id)
return
# Create run log
run_log = AgentRunLog(
agent_id=config.id,
agent_type="cloud",
user_id=user_id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_log_id = run_log.id
# ── Decrypt OAuth token ────────────────────────────────────────
if not config.oauth_token_encrypted:
await _finalize_run(
run_log_id,
status="error",
errors=[f"No OAuth token stored for cloud agent '{config.name}'"],
)
return
try:
credentials_info = decrypt_token(config.oauth_token_encrypted)
except ValueError as exc:
await _finalize_run(
run_log_id,
status="error",
errors=[f"Failed to decrypt OAuth token: {exc}"],
)
return
# ── Instantiate provider ──────────────────────────────────────
try:
provider = get_provider(config.provider, credentials_info)
except ValueError as exc:
await _finalize_run(run_log_id, status="error", errors=[str(exc)])
return
# ── Fetch messages ────────────────────────────────────────────
since: datetime | None = config.last_run_at
if since is None:
since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
if since.tzinfo is None:
since = since.replace(tzinfo=timezone.utc)
errors: list[str] = []
items_processed = 0
try:
if config.provider == "gmail":
raw_messages = await provider.fetch_messages(
filter_config=config.filter_config,
since=since,
)
elif config.provider == "outlook":
raw_messages = await provider.fetch_emails(
filter_config=config.filter_config,
since=since,
)
elif config.provider == "teams":
raw_messages = await provider.fetch_messages(
filter_config=config.filter_config,
since=since,
)
else:
raw_messages = []
except RuntimeError as exc:
await _finalize_run(
run_log_id,
status="error",
errors=[f"Provider fetch failed: {exc}"],
update_config_last_run=True,
config_id=config.id,
config_type="cloud",
)
return
logger.info(
"agent_runner: cloud agent %s fetched %d item(s) from %s",
config.id, len(raw_messages), config.provider,
)
# ── Extract + insert via LLM ─────────────────────────────────
try:
processing_tools = _build_processing_tools(config.data_types)
custom_section = (
f"User instructions:\n{config.prompt_template}"
if config.prompt_template
else ""
)
for msg in raw_messages:
content_text = msg.as_text
if not content_text:
continue
items_processed += 1
processing_prompt = _CLOUD_PROCESSING_PROMPT.format(
data_types=", ".join(config.data_types),
project_context="Determine the appropriate project from the message context.",
file_list=f"Message from {config.provider} (id: {msg.id})",
custom_prompt_section=custom_section,
)
try:
await _run_agent_with_tools(
system_prompt=processing_prompt,
user_message=f"Process this message content:\n\n{content_text[:8000]}",
tools=processing_tools,
max_steps=_MAX_PROCESSING_STEPS,
)
except Exception as exc:
errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
except Exception as exc:
errors.append(f"Agent run failed: {exc}")
# ── Persist refreshed token ───────────────────────────────────
refreshed = getattr(provider, "refreshed_credentials", None)
if refreshed:
try:
new_encrypted = encrypt_token(refreshed)
async with async_session() as db:
cfg_result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id)
)
cfg_row = cfg_result.scalar_one_or_none()
if cfg_row:
cfg_row.oauth_token_encrypted = new_encrypted
await db.commit()
except Exception as exc:
logger.warning("agent_runner: failed to persist refreshed token: %s", exc)
# ── Finalise ──────────────────────────────────────────────────
if errors and items_processed == 0:
final_status = "error"
elif errors:
final_status = "partial"
else:
final_status = "success"
await _finalize_run(
run_log_id,
status=final_status,
items_processed=items_processed,
items_created=0,
errors=errors,
update_config_last_run=True,
config_id=config.id,
config_type="cloud",
)
# ── Internal helper ─────────────────────────────────────────────────────────
async def _finalize_run(
run_log_id: int | str,
*,
status: str,
items_processed: int = 0,
items_created: int = 0,
errors: list[str] | None = None,
update_config_last_run: bool = False,
config_id: str | None = None,
config_type: str | None = None,
) -> None:
"""Persist the run outcome and optionally update last_run_at on the config."""
now = datetime.now(timezone.utc)
try:
async with async_session() as db:
result = await db.execute(
select(AgentRunLog).where(AgentRunLog.id == run_log_id)
)
managed = result.scalar_one_or_none()
if managed is None:
logger.warning("agent_runner: run_log %s not found for finalization", run_log_id)
return
managed.status = status
managed.items_processed = items_processed
managed.items_created = items_created
managed.errors = errors or []
managed.completed_at = now
if update_config_last_run and config_id:
if config_type == "local":
cfg_result = await db.execute(
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
)
cfg = cfg_result.scalar_one_or_none()
if cfg:
cfg.last_run_at = now
elif config_type == "cloud":
cfg_result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
)
cfg = cfg_result.scalar_one_or_none()
if cfg:
cfg.last_run_at = now
await db.commit()
except Exception as exc:
logger.error("agent_runner: failed to finalize run_log=%s: %s", run_log_id, exc)

View File

@@ -0,0 +1 @@
"""Batch Agent Service domain agents and filesystem tools."""

View File

@@ -0,0 +1,83 @@
"""Filesystem agent — tools for reading local directories and files on Electron.
Adapted for Batch Agent Service: import from app.ws_context.
"""
from __future__ import annotations
from typing import Any
from langchain_core.tools import tool
from app.ws_context import execute_on_client
@tool
async def list_directory(path: str) -> str:
"""List files and folders in a local directory on the user's device.
Returns a formatted listing of entries with name, type (file/directory),
and full path.
"""
result = await execute_on_client(
action="list_directory",
data={"path": path},
)
entries: list[dict[str, Any]] = result.get("entries", [])
if not entries:
return f"Directory '{path}' is empty or does not exist."
lines: list[str] = []
for entry in entries:
entry_type = entry.get("type", "unknown")
entry_name = entry.get("name", "")
entry_path = entry.get("path", "")
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
@tool
async def read_file_content(path: str) -> str:
"""Read the text content of a local file on the user's device.
Returns the file content as a string. Large files may be truncated
by the Electron client.
"""
result = await execute_on_client(
action="read_file_content",
data={"path": path},
)
content: str = result.get("content", "")
if not content:
return f"File '{path}' is empty or could not be read."
return content
@tool
async def get_file_metadata(path: str) -> str:
"""Get metadata for a local file: size, creation date, modification date, extension.
Returns a formatted summary of the file's metadata.
"""
result = await execute_on_client(
action="get_file_metadata",
data={"path": path},
)
size = result.get("size", "unknown")
created = result.get("createdAt", "unknown")
modified = result.get("modifiedAt", "unknown")
extension = result.get("extension", "unknown")
name = result.get("name", path)
return (
f"File: {name}\n"
f" Extension: {extension}\n"
f" Size: {size} bytes\n"
f" Created: {created}\n"
f" Modified: {modified}"
)
FILESYSTEM_TOOLS: list[Any] = [
list_directory,
read_file_content,
get_file_metadata,
]

View File

@@ -0,0 +1,110 @@
"""Note agent — Markdown note management.
Adapted for Batch Agent Service: import from app.ws_context and app.llm.
"""
from __future__ import annotations
import re
from typing import Any
from langchain_core.tools import tool
from app.llm import embed
from app.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
@tool
async def list_notes(project_id: str = "") -> str:
"""List notes, optionally scoped to a project by project_id."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="notes",
filters={"projectId": normalized_project_id or None},
)
rows = result.get("rows", [])
if not rows:
return "No notes found."
lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
@tool
async def get_note(note_id: str) -> str:
"""Fetch a single note by its UUID to read its full Markdown content."""
result = await execute_on_client(action="get", table="notes", data={"id": note_id})
row = result.get("row")
if not row:
return f"Note {note_id} not found."
return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}"
@tool
async def create_note(title: str, content: str, project_id: str = "") -> str:
"""Create a new note."""
result = await execute_on_client(
action="insert",
table="notes",
data={
"title": title,
"content": content,
"projectId": project_id or None,
},
)
row = result["row"]
vector = await embed(content)
await execute_on_client(
action="vector_upsert",
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
vector=vector,
)
return f"Note created: '{row['title']}' (id: {row['id']})."
@tool
async def update_note(note_id: str, title: str = "", content: str = "") -> str:
"""Update an existing note. Only pass fields that should change."""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if content:
updates["content"] = content
result = await execute_on_client(
action="update",
table="notes",
data={"id": note_id, "updates": updates},
)
row = result["row"]
if content:
vector = await embed(content)
await execute_on_client(
action="vector_upsert",
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
vector=vector,
)
return f"Note updated: '{row['title']}' (id: {row['id']})."
@tool
async def delete_note(note_id: str) -> str:
"""Delete a note permanently by its UUID."""
await execute_on_client(action="delete", table="notes", data={"id": note_id})
return f"Note {note_id} deleted."
NOTE_TOOLS: list[Any] = [
list_notes,
get_note,
create_note,
update_note,
delete_note,
]

View File

@@ -0,0 +1,110 @@
"""Project agent — full lifecycle management.
Adapted for Batch Agent Service: import from app.ws_context.
"""
from __future__ import annotations
from typing import Any
from langchain_core.tools import tool
from app.ws_context import execute_on_client
@tool
async def list_projects(client_id: str = "", include_archived: int = 0) -> str:
"""List projects, optionally filtered by client_id."""
result = await execute_on_client(
action="select",
table="projects",
filters={
"clientId": client_id or None,
"includeArchived": bool(include_archived),
},
)
rows = result.get("rows", [])
if not rows:
return "No projects found."
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
return f"Found {len(rows)} project(s):\n" + "\n".join(lines)
@tool
async def list_all_projects() -> str:
"""List every project regardless of client or status."""
result = await execute_on_client(action="select", table="projects")
rows = result.get("rows", [])
if not rows:
return "No projects found."
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
return f"All projects ({len(rows)}):\n" + "\n".join(lines)
@tool
async def get_project(project_id: str) -> str:
"""Fetch a single project by its UUID."""
result = await execute_on_client(action="get", table="projects", data={"id": project_id})
row = result.get("row")
if not row:
return f"Project {project_id} not found."
return (
f"Project: '{row['name']}' (id: {row['id']}, status: {row['status']}, "
f"clientId: {row.get('clientId', 'none')})"
)
@tool
async def create_project(name: str, client_id: str = "") -> str:
"""Create a new project."""
result = await execute_on_client(
action="insert",
table="projects",
data={"name": name, "clientId": client_id or None},
)
row = result["row"]
return f"Project created: '{row['name']}' (id: {row['id']})"
@tool
async def update_project(
project_id: str,
name: str = "",
client_id: str = "",
status: str = "",
ai_summary: str = "",
) -> str:
"""Update a project. Only pass fields that should change."""
updates: dict[str, Any] = {}
if name:
updates["name"] = name
if client_id:
updates["clientId"] = client_id
if status:
updates["status"] = status
if ai_summary:
updates["aiSummary"] = ai_summary
result = await execute_on_client(
action="update",
table="projects",
data={"id": project_id, "updates": updates},
)
row = result["row"]
return f"Project updated: '{row['name']}' (id: {row['id']}, status: {row['status']})"
@tool
async def delete_project(project_id: str) -> str:
"""Permanently delete a project."""
await execute_on_client(action="delete", table="projects", data={"id": project_id})
return f"Project {project_id} permanently deleted."
PROJECT_TOOLS: list[Any] = [
list_projects,
list_all_projects,
get_project,
create_project,
update_project,
delete_project,
]

View File

@@ -0,0 +1,197 @@
"""Task agent — full CRUD for tasks and task comments.
Adapted for Batch Agent Service: import from app.ws_context.
"""
from __future__ import annotations
from datetime import datetime, timezone
import re
from typing import Any
from langchain_core.tools import tool
from app.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
@tool
async def list_tasks(
project_id: str = "",
status: str = "",
search: str = "",
order_by: str = "",
) -> str:
"""List tasks, optionally filtered by project_id, status, search, or order_by."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="tasks",
filters={
"projectId": normalized_project_id or None,
"status": status or None,
"search": search or None,
"orderBy": order_by or None,
},
)
rows = result.get("rows", [])
if not rows:
return "No tasks found matching the given filters."
lines = [
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})"
for r in rows
]
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
@tool
async def create_task(
title: str,
description: str = "",
status: str = "todo",
priority: str = "medium",
assignees: str = "[]",
due_date: int = 0,
project_id: str = "",
is_ai_suggested: int = 0,
) -> str:
"""Create a new task."""
result = await execute_on_client(
action="insert",
table="tasks",
data={
"title": title,
"description": description or None,
"status": status,
"priority": priority,
"assignee": assignees,
"dueDate": due_date or None,
"projectId": project_id or None,
"isAiSuggested": is_ai_suggested,
},
)
row = result["row"]
return (
f"Task created: '{row['title']}' "
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
)
@tool
async def update_task(
task_id: str,
title: str = "",
description: str = "",
status: str = "",
priority: str = "",
assignees: str = "",
due_date: int = -1,
project_id: str = "",
) -> str:
"""Update fields on an existing task. Only pass fields you want to change."""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if description:
updates["description"] = description
if status:
updates["status"] = status
if priority:
updates["priority"] = priority
if assignees:
updates["assignee"] = assignees
if due_date != -1:
updates["dueDate"] = due_date or None
if project_id:
updates["projectId"] = project_id
result = await execute_on_client(
action="update",
table="tasks",
data={"id": task_id, "updates": updates},
)
row = result["row"]
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
@tool
async def delete_task(task_id: str) -> str:
"""Delete a task permanently by its UUID."""
await execute_on_client(action="delete", table="tasks", data={"id": task_id})
return f"Task {task_id} deleted."
@tool
async def list_tasks_due_today() -> str:
"""List all tasks whose due date falls on today's date."""
now = datetime.now(tz=timezone.utc)
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
end_ms = start_ms + 86_400_000 - 1
result = await execute_on_client(
action="select",
table="tasks",
filters={"dueDateFrom": start_ms, "dueDateTo": end_ms},
)
rows = result.get("rows", [])
if not rows:
return "No tasks are due today."
lines = [
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})"
for r in rows
]
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
@tool
async def list_task_comments(task_id: str) -> str:
"""List all comments on a task by its UUID."""
result = await execute_on_client(
action="select",
table="taskComments",
filters={"taskId": task_id},
)
rows = result.get("rows", [])
if not rows:
return f"No comments found for task {task_id}."
lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows]
return f"Found {len(rows)} comment(s):\n" + "\n".join(lines)
@tool
async def add_task_comment(task_id: str, author: str, content: str) -> str:
"""Add a comment to a task."""
result = await execute_on_client(
action="insert",
table="taskComments",
data={"taskId": task_id, "author": author, "content": content},
)
row = result.get("row", {})
row_author = row.get("author", author)
row_task_id = row.get("taskId") or row.get("task_id") or task_id
row_comment_id = row.get("id", "unknown")
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
@tool
async def delete_task_comment(comment_id: str) -> str:
"""Delete a task comment by its UUID."""
await execute_on_client(action="delete", table="taskComments", data={"id": comment_id})
return f"Comment {comment_id} deleted."
TASK_TOOLS: list[Any] = [
list_tasks,
create_task,
update_task,
delete_task,
list_tasks_due_today,
list_task_comments,
add_task_comment,
delete_task_comment,
]

View File

@@ -0,0 +1,88 @@
"""Timeline agent — project milestone management.
Adapted for Batch Agent Service: import from app.ws_context.
"""
from __future__ import annotations
import re
from typing import Any
from langchain_core.tools import tool
from app.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
@tool
async def list_timelines(project_id: str = "") -> str:
"""List timelines. Provide project_id to scope to a specific project."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="timelines",
filters={"projectId": normalized_project_id or None},
)
rows = result.get("rows", [])
if not rows:
return "No timelines found."
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines)
@tool
async def create_timeline(
project_id: str, title: str, date: int, is_ai_suggested: int = 0,
) -> str:
"""Create a project timeline (milestone)."""
result = await execute_on_client(
action="insert",
table="timelines",
data={
"projectId": project_id,
"title": title,
"date": date,
"isAiSuggested": is_ai_suggested,
},
)
row = result["row"]
return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
@tool
async def update_timeline(timeline_id: str, title: str = "", date: int = -1) -> str:
"""Update a timeline. Only pass fields that should change."""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if date != -1:
updates["date"] = date
result = await execute_on_client(
action="update",
table="timelines",
data={"id": timeline_id, "updates": updates},
)
row = result["row"]
return f"Timeline updated: '{row['title']}' (id: {row['id']})"
@tool
async def delete_timeline(timeline_id: str) -> str:
"""Delete a timeline permanently by its UUID."""
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
return f"Timeline {timeline_id} deleted."
TIMELINE_TOOLS: list[Any] = [
list_timelines,
create_timeline,
update_timeline,
delete_timeline,
]

View File

@@ -0,0 +1,108 @@
"""Cloud provider integration utilities.
Adapted for Batch Agent Service: import from shared.config instead of app.config.
Provides:
* Shared message dataclasses (EmailMessage, ChatMessage)
* get_provider() — factory for Gmail/MS Graph clients
* encrypt_token() / decrypt_token() — Fernet-based OAuth token encryption
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING
from cryptography.fernet import Fernet, InvalidToken
from shared.config import settings
if TYPE_CHECKING:
from app.integrations.gmail import GmailClient
from app.integrations.ms_graph import MSGraphClient
logger = logging.getLogger(__name__)
@dataclass
class EmailMessage:
id: str
subject: str
sender: str
body_text: str
date: datetime
labels: list[str] = field(default_factory=list)
@property
def as_text(self) -> str:
date_str = self.date.strftime("%Y-%m-%d %H:%M")
labels_str = f" [{', '.join(self.labels)}]" if self.labels else ""
return (
f"From: {self.sender}\n"
f"Date: {date_str}{labels_str}\n"
f"Subject: {self.subject}\n\n"
f"{self.body_text}"
)
@dataclass
class ChatMessage:
id: str
content: str
sender: str
channel: str | None
date: datetime
@property
def as_text(self) -> str:
date_str = self.date.strftime("%Y-%m-%d %H:%M")
channel_str = f" [channel: {self.channel}]" if self.channel else ""
return (
f"From: {self.sender}\n"
f"Date: {date_str}{channel_str}\n\n"
f"{self.content}"
)
def _get_fernet() -> Fernet:
key = settings.OAUTH_ENCRYPTION_KEY
if not key:
raise RuntimeError(
"OAUTH_ENCRYPTION_KEY is not set. "
"Generate one with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
)
return Fernet(key.encode() if isinstance(key, str) else key)
def encrypt_token(token_info: dict) -> str:
if not isinstance(token_info, dict) or not token_info:
raise ValueError("token_info must be a non-empty dict")
plaintext = json.dumps(token_info).encode("utf-8")
return _get_fernet().encrypt(plaintext).decode("utf-8")
def decrypt_token(encrypted: str) -> dict:
try:
plaintext = _get_fernet().decrypt(encrypted.encode("utf-8"))
return json.loads(plaintext)
except (InvalidToken, json.JSONDecodeError) as exc:
raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc
def get_provider(
provider: str,
credentials_info: dict,
) -> "GmailClient | MSGraphClient":
if provider == "gmail":
from app.integrations.gmail import GmailClient
return GmailClient(credentials_info)
if provider in {"outlook", "teams"}:
from app.integrations.ms_graph import MSGraphClient
return MSGraphClient(credentials_info)
raise ValueError(
f"Unknown cloud provider {provider!r}. "
"Supported: 'gmail', 'outlook', 'teams'."
)

View File

@@ -0,0 +1,252 @@
"""Gmail API client for cloud agent integration.
Adapted for Batch Agent Service: import from app.integrations instead of
app.integrations (same relative path within the service).
"""
from __future__ import annotations
import asyncio
import base64
import email
import html
import logging
import re
from datetime import datetime, timezone
from typing import Any
from app.integrations import EmailMessage
logger = logging.getLogger(__name__)
_GMAIL_DATE_FMT = "%Y/%m/%d"
_BODY_TRUNCATE = 8_000
_MAX_MESSAGES = 200
def _build_gmail_query(
filter_config: dict[str, Any] | None,
since: datetime | None,
) -> str:
parts: list[str] = []
cfg = filter_config or {}
labels: list[str] = cfg.get("labels", [])
if labels:
if len(labels) == 1:
parts.append(f"label:{labels[0]}")
else:
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
parts.append(f"({label_expr})")
senders: list[str] = cfg.get("senders", [])
for sender in senders:
parts.append(f"from:{sender}")
date_range: dict = cfg.get("date_range", {})
from_str: str | None = date_range.get("from")
to_str: str | None = date_range.get("to")
effective_since: datetime | None = since
if from_str:
try:
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
if cfg_since.tzinfo is None:
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
if effective_since is None or cfg_since > effective_since:
effective_since = cfg_since
except ValueError:
logger.warning("gmail: invalid date_range.from %r — ignoring", from_str)
if effective_since:
parts.append(f"after:{effective_since.strftime(_GMAIL_DATE_FMT)}")
if to_str:
try:
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
parts.append(f"before:{to_dt.strftime(_GMAIL_DATE_FMT)}")
except ValueError:
logger.warning("gmail: invalid date_range.to %r — ignoring", to_str)
return " ".join(parts)
def _strip_html(raw_html: str) -> str:
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
decoded = html.unescape(no_tags)
return re.sub(r"\s+", " ", decoded).strip()
def _parse_body(payload: dict[str, Any]) -> str:
mime_type: str = payload.get("mimeType", "")
body: dict = payload.get("body", {})
parts: list[dict] = payload.get("parts", [])
if mime_type == "text/plain":
data = body.get("data", "")
if data:
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
return ""
if mime_type == "text/html":
data = body.get("data", "")
if data:
raw = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
return _strip_html(raw)
return ""
plain_fallback = ""
for part in parts:
part_mime = part.get("mimeType", "")
if part_mime == "text/plain":
return _parse_body(part)
if part_mime == "text/html" and not plain_fallback:
plain_fallback = _parse_body(part)
if part_mime.startswith("multipart/"):
nested = _parse_body(part)
if nested:
return nested
return plain_fallback
def _parse_date(raw: str) -> datetime:
try:
parsed = email.utils.parsedate_to_datetime(raw)
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
return parsed.astimezone(timezone.utc)
except Exception:
return datetime.now(timezone.utc)
class GmailClient:
def __init__(self, credentials_info: dict[str, Any]) -> None:
from google.oauth2.credentials import Credentials
self._credentials_info = credentials_info
expiry_str: str | None = credentials_info.get("expiry")
expiry: datetime | None = None
if expiry_str:
try:
expiry = datetime.fromisoformat(
expiry_str.replace("Z", "+00:00")
).replace(tzinfo=timezone.utc)
except ValueError:
pass
self._credentials = Credentials(
token=credentials_info.get("token"),
refresh_token=credentials_info.get("refresh_token"),
token_uri=credentials_info.get("token_uri", "https://oauth2.googleapis.com/token"),
client_id=credentials_info.get("client_id"),
client_secret=credentials_info.get("client_secret"),
scopes=credentials_info.get("scopes"),
expiry=expiry,
)
async def fetch_messages(
self,
filter_config: dict[str, Any] | None = None,
since: datetime | None = None,
) -> list[EmailMessage]:
query = _build_gmail_query(filter_config, since)
logger.debug("gmail: executing search query %r", query)
return await asyncio.to_thread(self._fetch_sync, query)
@property
def refreshed_credentials(self) -> dict[str, Any] | None:
creds = self._credentials
if not creds.valid and creds.expired:
return None
if creds.token != self._credentials_info.get("token"):
result = {
"token": creds.token,
"refresh_token": creds.refresh_token,
"token_uri": creds.token_uri,
"client_id": creds.client_id,
"client_secret": creds.client_secret,
"scopes": list(creds.scopes or []),
}
if creds.expiry:
result["expiry"] = creds.expiry.isoformat()
return result
return None
def _fetch_sync(self, query: str) -> list[EmailMessage]:
import googleapiclient.discovery
import googleapiclient.errors
from google.auth.transport.requests import Request
if self._credentials.expired and self._credentials.refresh_token:
try:
self._credentials.refresh(Request())
except Exception as exc:
raise RuntimeError(f"Gmail token refresh failed: {exc}") from exc
service = googleapiclient.discovery.build(
"gmail", "v1", credentials=self._credentials, cache_discovery=False
)
user_api = service.users()
ids: list[str] = []
page_token: str | None = None
while len(ids) < _MAX_MESSAGES:
batch_size = min(100, _MAX_MESSAGES - len(ids))
kwargs: dict[str, Any] = {
"userId": "me",
"maxResults": batch_size,
}
if query:
kwargs["q"] = query
if page_token:
kwargs["pageToken"] = page_token
try:
resp = user_api.messages().list(**kwargs).execute()
except googleapiclient.errors.HttpError as exc:
raise RuntimeError(f"Gmail messages.list failed: {exc}") from exc
for msg in resp.get("messages", []):
ids.append(msg["id"])
page_token = resp.get("nextPageToken")
if not page_token:
break
if not ids:
return []
logger.info("gmail: fetching %d message(s)", len(ids))
messages: list[EmailMessage] = []
for msg_id in ids:
try:
msg = user_api.messages().get(
userId="me", id=msg_id, format="full"
).execute()
headers: dict[str, str] = {
h["name"].lower(): h["value"]
for h in msg.get("payload", {}).get("headers", [])
}
subject = headers.get("subject", "(no subject)")
sender = headers.get("from", "unknown")
date_raw = headers.get("date", "")
date = _parse_date(date_raw) if date_raw else datetime.now(timezone.utc)
body_text = _parse_body(msg.get("payload", {}))[:_BODY_TRUNCATE]
labels = msg.get("labelIds", [])
messages.append(EmailMessage(
id=msg_id,
subject=subject,
sender=sender,
body_text=body_text,
date=date,
labels=labels,
))
except Exception as exc:
logger.warning("gmail: skipping message %s: %s", msg_id, exc)
logger.info("gmail: returned %d message(s)", len(messages))
return messages

View File

@@ -0,0 +1,266 @@
"""Microsoft Graph API client for Outlook and Teams.
Adapted for Batch Agent Service: import settings from shared.config.
"""
from __future__ import annotations
import logging
import re
from datetime import datetime, timezone
from typing import Any
import httpx
from shared.config import settings
from app.integrations import ChatMessage, EmailMessage
logger = logging.getLogger(__name__)
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"
_MAX_EMAILS = 200
_MAX_MESSAGES = 200
_BODY_TRUNCATE = 8_000
def _strip_html(raw: str) -> str:
no_tags = re.sub(r"<[^>]+>", " ", raw)
import html as _html
decoded = _html.unescape(no_tags)
return re.sub(r"\s+", " ", decoded).strip()
def _odata_datetime(dt: datetime) -> str:
utc = dt.astimezone(timezone.utc)
return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
def _build_email_filter(
filter_config: dict[str, Any] | None,
since: datetime | None,
) -> str:
clauses: list[str] = []
cfg = filter_config or {}
senders: list[str] = cfg.get("senders", [])
if senders:
sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders]
clauses.append("(" + " or ".join(sender_clauses) + ")")
date_range: dict = cfg.get("date_range", {})
from_str: str | None = date_range.get("from")
effective_since: datetime | None = since
if from_str:
try:
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
if cfg_since.tzinfo is None:
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
if effective_since is None or cfg_since > effective_since:
effective_since = cfg_since
except ValueError:
logger.warning("ms_graph: invalid date_range.from %r — ignoring", from_str)
if effective_since:
clauses.append(f"receivedDateTime ge {_odata_datetime(effective_since)}")
to_str: str | None = date_range.get("to")
if to_str:
try:
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
if to_dt.tzinfo is None:
to_dt = to_dt.replace(tzinfo=timezone.utc)
clauses.append(f"receivedDateTime le {_odata_datetime(to_dt)}")
except ValueError:
logger.warning("ms_graph: invalid date_range.to %r — ignoring", to_str)
return " and ".join(clauses)
class MSGraphClient:
def __init__(self, credentials_info: dict[str, Any]) -> None:
self._credentials_info = credentials_info
self._access_token: str = credentials_info.get("access_token", "")
self._original_access_token: str = self._access_token
self._refresh_token: str | None = credentials_info.get("refresh_token")
def _auth_headers(self) -> dict[str, str]:
return {"Authorization": f"Bearer {self._access_token}"}
async def _refresh_access_token(self) -> None:
import msal
app = msal.ConfidentialClientApplication(
client_id=settings.MS_CLIENT_ID,
client_credential=settings.MS_CLIENT_SECRET,
authority=f"https://login.microsoftonline.com/{settings.MS_TENANT_ID}",
)
scopes: list[str] = self._credentials_info.get("scope", "").split()
if not scopes:
scopes = ["https://graph.microsoft.com/.default"]
result = app.acquire_token_by_refresh_token(
self._refresh_token,
scopes=scopes,
)
if "access_token" not in result:
error = result.get("error_description", result.get("error", "unknown"))
raise RuntimeError(f"MS Graph token refresh failed: {error}")
self._access_token = result["access_token"]
if "refresh_token" in result:
self._refresh_token = result["refresh_token"]
self._credentials_info["refresh_token"] = result["refresh_token"]
self._credentials_info["access_token"] = self._access_token
@property
def refreshed_credentials(self) -> dict[str, Any] | None:
if self._access_token != self._original_access_token:
return {**self._credentials_info, "access_token": self._access_token}
return None
async def _get(
self,
client: httpx.AsyncClient,
url: str,
params: dict[str, Any] | None = None,
*,
retry_on_401: bool = True,
) -> dict[str, Any]:
resp = await client.get(url, params=params, headers=self._auth_headers())
if resp.status_code == 401 and retry_on_401 and self._refresh_token:
await self._refresh_access_token()
resp = await client.get(url, params=params, headers=self._auth_headers())
if resp.status_code == 429:
raise RuntimeError("MS Graph rate limit hit (429). Try again later.")
resp.raise_for_status()
return resp.json()
async def fetch_emails(
self,
filter_config: dict[str, Any] | None = None,
since: datetime | None = None,
) -> list[EmailMessage]:
odata_filter = _build_email_filter(filter_config, since)
params: dict[str, Any] = {
"$top": 50,
"$select": "id,subject,from,receivedDateTime,body,bodyPreview",
"$orderby": "receivedDateTime desc",
}
if odata_filter:
params["$filter"] = odata_filter
emails: list[EmailMessage] = []
url = f"{_GRAPH_BASE}/me/messages"
async with httpx.AsyncClient(timeout=30.0) as client:
while url and len(emails) < _MAX_EMAILS:
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
for item in data.get("value", []):
emails.append(self._parse_email(item))
if len(emails) >= _MAX_EMAILS:
break
url = data.get("@odata.nextLink", "")
params = {}
logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
return emails
async def fetch_messages(
self,
filter_config: dict[str, Any] | None = None,
since: datetime | None = None,
) -> list[ChatMessage]:
cfg = filter_config or {}
channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])]
params: dict[str, Any] = {"$top": 50}
if since:
params["$filter"] = f"createdDateTime ge {_odata_datetime(since)}"
messages: list[ChatMessage] = []
url = f"{_GRAPH_BASE}/me/chats/getAllMessages"
async with httpx.AsyncClient(timeout=30.0) as client:
while url and len(messages) < _MAX_MESSAGES:
try:
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
except httpx.HTTPStatusError as exc:
if exc.response.status_code in (403, 404):
logger.warning(
"ms_graph: /me/chats/getAllMessages not available (%d)",
exc.response.status_code,
)
break
raise
for item in data.get("value", []):
msg = self._parse_teams_message(item)
if channel_filter and msg.channel:
if not any(c in msg.channel.lower() for c in channel_filter):
continue
messages.append(msg)
if len(messages) >= _MAX_MESSAGES:
break
url = data.get("@odata.nextLink", "")
params = {}
logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
return messages
@staticmethod
def _parse_email(item: dict[str, Any]) -> EmailMessage:
subject: str = item.get("subject", "(no subject)") or "(no subject)"
sender_block = item.get("from", {}) or {}
sender_addr = (
(sender_block.get("emailAddress") or {}).get("address", "unknown")
)
date_str: str = item.get("receivedDateTime", "")
try:
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
except Exception:
date = datetime.now(timezone.utc)
body_block = item.get("body", {}) or {}
content_type: str = body_block.get("contentType", "text")
raw_body: str = body_block.get("content", "")
if content_type == "html":
body_text = _strip_html(raw_body)
else:
body_text = raw_body or item.get("bodyPreview", "")
body_text = body_text[:_BODY_TRUNCATE]
return EmailMessage(
id=item.get("id", ""),
subject=subject,
sender=sender_addr,
body_text=body_text,
date=date,
)
@staticmethod
def _parse_teams_message(item: dict[str, Any]) -> ChatMessage:
msg_id: str = item.get("id", "")
sender_block = (item.get("from") or {}).get("user") or {}
sender: str = sender_block.get("displayName", "unknown")
channel: str | None = (item.get("channelIdentity") or {}).get("channelId")
date_str: str = item.get("createdDateTime", "")
try:
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
except Exception:
date = datetime.now(timezone.utc)
body_block = item.get("body", {}) or {}
content_type: str = body_block.get("contentType", "text")
raw_content: str = body_block.get("content", "")
content = _strip_html(raw_content) if content_type == "html" else raw_content
content = content[:_BODY_TRUNCATE]
return ChatMessage(
id=msg_id,
content=content,
sender=sender,
channel=channel,
date=date,
)

View File

@@ -0,0 +1,385 @@
"""Chatbot Journey — guided conversation to build an agent prompt_template.
Adapted for Batch Agent Service: imports from app.agents.filesystem_agent
and app.llm instead of monolith paths. Session state is in-memory (could
be moved to Redis for horizontal scaling in the future).
Journey flow:
1. Redis consumer dispatches ``journey_start`` with basic agent config.
2. Server creates an in-memory session, runs the setup LLM with
file-system tools to explore the directory, returns first question.
3. ``journey_message`` frames drive the conversation.
4. After 3-5 turns the LLM emits PROMPT_TEMPLATE_START / _END block.
5. Server parses the block and returns ``journey_reply`` with ``done=True``.
"""
from __future__ import annotations
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
from app.llm import get_llm
logger = logging.getLogger(__name__)
# ── Session TTL ───────────────────────────────────────────────────────────
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
# Sentinel strings used to delimit the LLM-produced prompt_template.
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
_MIN_TURNS_BEFORE_NUDGE: int = 3
_MAX_TURNS: int = 15
_MAX_TOOL_STEPS: int = 6
# ── In-memory session store ───────────────────────────────────────────────
@dataclass
class JourneySession:
session_id: str
user_id: str
agent_type: str # "local" | "cloud"
directory: str
data_types: list[str]
history: list[dict[str, Any]] = field(default_factory=list)
system_prompt: str = ""
created_at: float = field(default_factory=time.monotonic)
def is_expired(self) -> bool:
return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS
# session_id → session
_sessions: dict[str, JourneySession] = {}
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
"""Retrieve session; return None on missing, expired, or wrong owner."""
s = _sessions.get(session_id)
if s is None or s.is_expired():
_sessions.pop(session_id, None)
return None
if s.user_id != user_id:
return None
return s
# ── System prompt builder ─────────────────────────────────────────────────
_SYSTEM_PROMPT_TEMPLATE = """\
You are a friendly assistant helping a freelancer configure a data-extraction agent.
Your job is to understand exactly what data the user wants to extract from their
local directory and produce a detailed prompt_template that a separate AI will use
as its instruction set.
The extraction agent already has this base behaviour built in:
- Reads each file using file-system tools.
- Creates records (tasks, notes, timelines, projects) via CRUD tools.
- Sets isAiSuggested=1 on every new record.
- Only extracts data explicitly present in the files — it never invents information.
The user's custom prompt is appended AFTER this base behaviour, so focus on
what to look for and how to map it — not on the general extraction mechanics.
You have access to file-system tools to explore the user's directory:
- list_directory: to see folder structure
- read_file_content: to peek at file contents
- get_file_metadata: to check file info
The user's configured directory is: {directory}
Target data types: {data_types}
IMPORTANT — project assignment is handled automatically by the main agent runner
before the custom prompt is ever used. You MUST NOT ask the user about projects,
projectId, or how to link records to projects. Never include projectId logic or
project creation instructions in the generated prompt_template.
Start by exploring the directory to understand its structure. Then ask concise,
focused questions one at a time. Cover these topics (not necessarily in this order):
1. The type and format of the source content (confirmed by your exploration).
2. How fields should be mapped (e.g. filename → task title).
3. Priority or status rules (e.g. "urgent" keyword → high priority).
4. Any special handling, date extraction, or exclusions.
Once you reach 90% confidence, output the final prompt_template between these exact
markers on their own lines:
{template_start}
<the complete extraction prompt here>
{template_end}
The prompt_template must be a self-contained instruction for an AI that reads files
and must perform CRUD operations using tools to create records. It should specify:
- What entity types to create (tasks, notes, timelines) — never projects.
- How to map file content to record fields (camelCase: title, status, priority,
dueDate, content, etc.) — never include projectId.
- That isAiSuggested must be set to 1 on every new record.
- Concrete examples of mappings based on what you discovered in the directory.
{existing_section}\
Keep asking clarifying questions until you are at least 90% confident you have
enough information to generate an accurate prompt_template. Once you reach that
confidence level, stop asking and produce the final template immediately.
Begin by exploring the directory, then ask your first question.\
"""
def _build_system_prompt(
directory: str,
data_types: list[str],
existing_template: str | None = None,
) -> str:
existing_section = (
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
f"---\n{existing_template}\n---\n"
if existing_template
else ""
)
return _SYSTEM_PROMPT_TEMPLATE.format(
directory=directory,
data_types=", ".join(data_types),
template_start=_TEMPLATE_START,
template_end=_TEMPLATE_END,
existing_section=existing_section,
)
# ── Template extraction ───────────────────────────────────────────────────
def _extract_template(text: str) -> str | None:
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
return None
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
end_idx = text.index(_TEMPLATE_END)
return text[start_idx:end_idx].strip() or None
# ── LLM call with tool support ───────────────────────────────────────────
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
async def _call_llm_with_tools(
system_prompt: str,
history: list[dict[str, Any]],
tools: list[Any],
) -> str:
"""Build LangChain messages from history and invoke the LLM with tools.
Handles tool-calling loops: if the LLM calls tools, execute them and
continue until a final text response is produced.
"""
messages: list[Any] = [SystemMessage(content=system_prompt)]
for turn in history:
if turn["role"] == "user":
messages.append(HumanMessage(content=turn["content"]))
else:
messages.append(AIMessage(content=turn["content"]))
llm = get_llm(model=None, temperature=0.4)
llm_with_tools = llm.bind_tools(tools)
tool_map = {tool_def.name: tool_def for tool_def in tools}
for _ in range(_MAX_TOOL_STEPS):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
return _as_text(response.content)
for call in response.tool_calls:
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"journey: tool_call name=%s args=%s",
call_name,
json.dumps(call_args, ensure_ascii=True)[:500],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"journey: tool_result name=%s output=%s",
call_name,
str(tool_output)[:800],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
# Fallback: exceeded max tool steps.
final = await llm.ainvoke(messages)
return _as_text(final.content)
# ── Journey handlers (called from redis_consumer) ────────────────────────
async def handle_journey_start(
user_id: str,
frame: dict[str, Any],
) -> dict[str, Any]:
"""Handle a ``journey_start`` request.
Creates a session, runs the setup LLM with directory exploration,
and returns the ``journey_reply`` payload.
"""
agent_type = frame.get("agent_type", "local")
directory = frame.get("directory", "")
data_types = frame.get("data_types", [])
existing_template = frame.get("existing_template")
session_id = frame.get("session_id") or str(uuid.uuid4())
system_prompt = _build_system_prompt(directory, data_types, existing_template)
session = JourneySession(
session_id=session_id,
user_id=user_id,
agent_type=agent_type,
directory=directory,
data_types=data_types,
system_prompt=system_prompt,
)
seed_history: list[dict[str, Any]] = [
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
]
ai_reply = await _call_llm_with_tools(
system_prompt=system_prompt,
history=seed_history,
tools=list(FILESYSTEM_TOOLS),
)
session.history.extend(seed_history)
session.history.append({"role": "assistant", "content": ai_reply})
_sessions[session_id] = session
logger.info(
"journey: session %s started for user %s (directory=%s)",
session_id,
user_id,
directory,
)
prompt_template = _extract_template(ai_reply)
done = prompt_template is not None
display_message = ai_reply
if done:
display_message = (
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
or "Here is your agent configuration. You can save it or continue refining."
)
_sessions.pop(session_id, None)
return {
"type": "journey_reply",
"session_id": session_id,
"message": display_message,
"done": done,
"prompt_template": prompt_template,
}
async def handle_journey_message(
user_id: str,
frame: dict[str, Any],
) -> dict[str, Any]:
"""Handle a ``journey_message`` request.
Appends the user message, calls the LLM, and returns the
``journey_reply`` payload.
"""
session_id = frame.get("session_id", "")
message = frame.get("message", "")
session = get_journey_session(session_id, user_id)
if session is None:
return {
"type": "journey_reply",
"session_id": session_id,
"message": "Journey session not found or expired. Please start a new setup.",
"done": True,
"prompt_template": None,
}
session.history.append({"role": "user", "content": message})
ai_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt,
history=session.history,
tools=list(FILESYSTEM_TOOLS),
)
session.history.append({"role": "assistant", "content": ai_reply})
prompt_template = _extract_template(ai_reply)
done = prompt_template is not None
if not done:
turns = sum(1 for t in session.history if t["role"] == "user")
if turns >= _MAX_TURNS:
nudge_content = (
"[System: You have enough information. Please generate the final "
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
)
session.history.append({"role": "user", "content": nudge_content})
nudge_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt,
history=session.history,
tools=list(FILESYSTEM_TOOLS),
)
session.history.append({"role": "assistant", "content": nudge_reply})
prompt_template = _extract_template(nudge_reply)
if prompt_template is not None:
done = True
ai_reply = nudge_reply
display_message = ai_reply
if done:
display_message = (
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
if _TEMPLATE_START in ai_reply
else "Here is your agent configuration. You can save it or continue refining."
)
_sessions.pop(session_id, None)
logger.info("journey: session %s completed for user %s", session_id, user_id)
return {
"type": "journey_reply",
"session_id": session_id,
"message": display_message,
"done": done,
"prompt_template": prompt_template,
}

View File

@@ -0,0 +1,76 @@
"""LLM factory — centralised model instantiation via LiteLLM.
Identical to services/chat/app/llm.py. Uses shared.config.settings.
"""
from __future__ import annotations
import os
import warnings
from openai import AsyncOpenAI
import litellm
from langchain_openai import ChatOpenAI
from langchain_litellm import ChatLiteLLM
from shared.config import settings
litellm.drop_params = True
warnings.filterwarnings(
"ignore",
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
category=UserWarning,
)
def _api_key_for_model(model: str) -> str | None:
if model.startswith("anthropic/"):
return settings.ANTHROPIC_API_KEY or None
if model.startswith("gemini/") or model.startswith("google/"):
return settings.GOOGLE_API_KEY or None
if model.startswith("cerebras/"):
return settings.CEREBRAS_API_KEY or None
if model.startswith("github_copilot/"):
return None
return settings.OPENAI_API_KEY or None
def get_llm(
*,
model: str | None = None,
temperature: float = 0,
) -> ChatOpenAI | ChatLiteLLM:
model = model or settings.LLM_MODEL
if settings.GITHUB_COPILOT_TOKEN_DIR:
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
if "/" in model:
return ChatLiteLLM(model=model, temperature=temperature)
return ChatOpenAI(
model=model,
temperature=temperature,
api_key=_api_key_for_model(model),
)
def get_router_llm(
*,
temperature: float = 0,
) -> ChatOpenAI | ChatLiteLLM:
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
async def embed(text: str) -> list[float]:
model = settings.LLM_EMBED_MODEL
if model.startswith("github_copilot/") or "/" in model:
response = await litellm.aembedding(model=model, input=[text])
return response.data[0]["embedding"]
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
response = await client.embeddings.create(model=model, input=text)
return response.data[0].embedding

View File

@@ -0,0 +1,57 @@
"""Batch Agent Service — FastAPI application.
Owns: agent_runner (local directory + cloud connectors), journey builder,
filesystem_agent, integrations (Gmail, MS Graph).
Communicates with WS Gateway via Redis:
- Subscribes to batch:request:{user_id} (journey_start, journey_message)
- Publishes to ws:out:{user_id} (journey replies + tool calls)
- BRPOP on tool:result:{call_id} (tool-call round-trip, 30s timeout)
- SET+EX on journey:{user_id} (journey session state, TTL 1800s)
"""
from __future__ import annotations
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.redis_consumer import start_consumer
from app.routes import router
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
logger.info("batch-agent: starting Redis consumer")
task = asyncio.create_task(start_consumer())
yield
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
logger.info("batch-agent: Redis consumer stopped")
app = FastAPI(title="Adiuva Batch Agent Service", lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["GET", "POST"],
allow_headers=["*"],
)
app.include_router(router)
@app.get("/health")
async def health() -> dict[str, str]:
return {"status": "ok", "service": "batch-agent"}

View File

@@ -0,0 +1,141 @@
"""Redis consumer for the Batch Agent Service.
Subscribes to batch:request:* (pattern) and dispatches:
- journey_start → handle_journey_start
- journey_message → handle_journey_message
- agent_trigger → run_local_agent / run_cloud_agent
Results are published back to ws:out:{user_id} via Redis.
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any
from shared.redis import redis_client, batch_request_channel, ws_out_channel
from app.ws_context import set_current_user, clear_current_user
logger = logging.getLogger(__name__)
async def _publish_to_user(user_id: str, payload: dict[str, Any]) -> None:
"""Publish a frame to the user's WS outbound channel."""
channel = ws_out_channel(user_id)
await redis_client.publish(channel, json.dumps(payload))
async def _handle_journey_start(user_id: str, data: dict[str, Any]) -> None:
"""Handle a journey_start request from WS Gateway."""
from app.journey import handle_journey_start
set_current_user(user_id)
try:
reply = await handle_journey_start(user_id, data)
await _publish_to_user(user_id, reply)
except Exception as exc:
logger.error("batch-agent: journey_start failed user=%s: %s", user_id, exc)
await _publish_to_user(user_id, {
"type": "journey_reply",
"session_id": data.get("session_id", ""),
"message": f"Journey setup failed: {exc}",
"done": True,
"prompt_template": None,
})
finally:
clear_current_user()
async def _handle_journey_message(user_id: str, data: dict[str, Any]) -> None:
"""Handle a journey_message from WS Gateway."""
from app.journey import handle_journey_message
set_current_user(user_id)
try:
reply = await handle_journey_message(user_id, data)
await _publish_to_user(user_id, reply)
except Exception as exc:
logger.error("batch-agent: journey_message failed user=%s: %s", user_id, exc)
await _publish_to_user(user_id, {
"type": "journey_reply",
"session_id": data.get("session_id", ""),
"message": f"Journey processing failed: {exc}",
"done": True,
"prompt_template": None,
})
finally:
clear_current_user()
async def _handle_agent_trigger(user_id: str, data: dict[str, Any]) -> None:
"""Handle an agent_trigger request from the REST route (forwarded via Redis)."""
from app.agent_runner import run_local_agent
set_current_user(user_id)
try:
await run_local_agent(user_id, data)
except Exception as exc:
logger.error("batch-agent: agent_trigger failed user=%s: %s", user_id, exc)
await _publish_to_user(user_id, {
"type": "run_complete",
"status": "error",
"run_context": data.get("run_context", {}),
})
finally:
clear_current_user()
async def _dispatch(user_id: str, message_data: dict[str, Any]) -> None:
"""Route a batch request to the correct handler."""
msg_type = message_data.get("type", "")
if msg_type == "journey_start":
await _handle_journey_start(user_id, message_data)
elif msg_type == "journey_message":
await _handle_journey_message(user_id, message_data)
elif msg_type == "agent_trigger":
await _handle_agent_trigger(user_id, message_data)
else:
logger.warning("batch-agent: unknown message type %r from user=%s", msg_type, user_id)
async def start_consumer() -> None:
"""Subscribe to batch:request:* and dispatch incoming frames."""
pubsub = redis_client.pubsub()
await pubsub.psubscribe("batch:request:*")
logger.info("batch-agent: subscribed to batch:request:*")
try:
async for message in pubsub.listen():
if message["type"] != "pmessage":
continue
channel: str = message["channel"]
if isinstance(channel, bytes):
channel = channel.decode()
# Extract user_id from channel: batch:request:{user_id}
parts = channel.split(":", 2)
if len(parts) < 3:
continue
user_id = parts[2]
raw = message["data"]
if isinstance(raw, bytes):
raw = raw.decode()
try:
data = json.loads(raw)
except json.JSONDecodeError:
logger.warning("batch-agent: invalid JSON on channel %s", channel)
continue
# Dispatch in a separate task to avoid blocking the consumer
asyncio.create_task(_dispatch(user_id, data))
except asyncio.CancelledError:
logger.info("batch-agent: consumer shutting down")
finally:
await pubsub.punsubscribe("batch:request:*")

View File

@@ -0,0 +1,208 @@
"""Agent REST routes — catalog, billing checks, trigger.
Adapted for Batch Agent Service: uses shared.db, shared.models, shared.schemas.
Agent trigger dispatches via Redis to the consumer instead of spawning
an in-process background task.
"""
from __future__ import annotations
import json
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Header, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from shared.db import async_session
from shared.models import AgentRunLog
from shared.redis import redis_client, batch_request_channel
from app.agent_runner import is_agent_running
router = APIRouter(prefix="/agents", tags=["agents"])
# ── Tier feature limits ───────────────────────────────────────────────
# Mirrors app/billing/tier_manager.py FEATURES dict.
FEATURES: dict[str, dict] = {
"free": {"batch_active": 1, "batch_runs_per_day": 3},
"pro": {"batch_active": 5, "batch_runs_per_day": 20},
"power": {"batch_active": 20, "batch_runs_per_day": 100},
"team": {"batch_active": -1, "batch_runs_per_day": -1},
}
def _dt_ms(dt: datetime) -> int:
return int(dt.timestamp() * 1000)
def _dt_ms_opt(dt: datetime | None) -> int | None:
return int(dt.timestamp() * 1000) if dt else None
def _to_data_types(values: list[str]) -> list[str]:
normalize = {
"task": "tasks", "tasks": "tasks",
"note": "notes", "notes": "notes",
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
"project": "projects", "projects": "projects",
}
seen: set[str] = set()
result: list[str] = []
for v in values:
mapped = normalize.get(v)
if mapped and mapped not in seen:
seen.add(mapped)
result.append(mapped)
return result
def _enforce_agent_limit(tier: str, current_count: int) -> int:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
if limit != -1 and current_count >= limit:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
)
return limit
async def _enforce_run_frequency(tier: str, user_id: str) -> None:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
if limit == -1:
return
today_start = datetime.now(timezone.utc).replace(
hour=0, minute=0, second=0, microsecond=0
)
async with async_session() as db:
result = await db.execute(
select(func.count(AgentRunLog.id)).where(
AgentRunLog.user_id == user_id,
AgentRunLog.started_at >= today_start,
)
)
runs_today: int = result.scalar_one()
if runs_today >= limit:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Daily batch run limit ({limit}) reached for your tier.",
)
# ── Catalog ───────────────────────────────────────────────────────────
@router.get("/catalog")
async def get_agent_catalog(
x_user_id: str = Header(..., alias="X-User-Id"),
) -> list[dict]:
return [
{
"type": "local_directory",
"name": "Local Directory Monitor",
"description": "Watches local directories, extracts data from files using AI",
},
{
"type": "gmail",
"name": "Gmail Connector",
"description": "Scans Gmail inbox, extracts tasks/notes from emails",
},
{
"type": "teams",
"name": "Microsoft Teams Connector",
"description": "Monitors Teams messages, extracts action items",
},
{
"type": "outlook",
"name": "Outlook Connector",
"description": "Scans Outlook inbox, extracts tasks/notes",
},
]
# ── Can-create check ─────────────────────────────────────────────────
@router.post("/can-create")
async def can_create_agent(
body: dict,
x_user_id: str = Header(..., alias="X-User-Id"),
x_user_tier: str = Header("free", alias="X-User-Tier"),
) -> dict:
active_agents = body.get("active_agents", 0)
limit: int = FEATURES.get(x_user_tier, FEATURES["free"])["batch_active"]
allowed = limit == -1 or active_agents < limit
return {
"allowed": allowed,
"tier": x_user_tier,
"active_agents": active_agents,
"limit": limit,
}
# ── Trigger ──────────────────────────────────────────────────────────
@router.post("/trigger", status_code=status.HTTP_202_ACCEPTED)
async def trigger_agent_run(
body: dict,
x_user_id: str = Header(..., alias="X-User-Id"),
x_user_tier: str = Header("free", alias="X-User-Tier"),
) -> dict:
"""Trigger a local agent run — creates run log and dispatches via Redis."""
active_agents = body.get("active_agents", 0)
_enforce_agent_limit(x_user_tier, active_agents)
await _enforce_run_frequency(x_user_tier, x_user_id)
stable_agent_id = body.get("agent_id") or str(uuid.uuid4())
if is_agent_running(stable_agent_id):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Agent is already running.",
)
# Create run log in DB
async with async_session() as db:
run_log = AgentRunLog(
agent_id=stable_agent_id,
agent_type="local",
user_id=x_user_id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_log_id = run_log.id
run_context = {
"type": "agent_batch",
"run_id": run_log_id,
"agent_id": stable_agent_id,
}
# Dispatch to the Redis consumer for processing
trigger_data = {
"type": "agent_trigger",
"directory": body.get("directory", ""),
"directory_paths": [body.get("directory", "")] if body.get("directory") else [],
"data_types": _to_data_types(body.get("what_to_extract", [])),
"file_extensions": body.get("file_extensions", []),
"prompt_template": body.get("custom_agent_prompt", ""),
"device_id": body.get("device_id", ""),
"run_context": run_context,
}
channel = batch_request_channel(x_user_id)
await redis_client.publish(channel, json.dumps(trigger_data))
return {
"id": run_log_id,
"agent_id": stable_agent_id,
"agent_type": "local",
"status": "running",
"items_processed": 0,
"items_created": 0,
"errors": [],
"started_at": _dt_ms(run_log.started_at),
"completed_at": None,
}

View File

@@ -0,0 +1,135 @@
"""WebSocket context for Batch Agent Service — Redis-based tool call round-trip.
Same pattern as services/chat/app/ws_context.py: publishes tool_call frames
to Redis ws:out:{user_id} and awaits BRPOP on tool:result:{call_id}.
Additionally provides set_client_executor / clear_client_executor stubs
for backward compatibility with the agent_runner code (which originally
used a DeviceConnectionManager callback). In the microservice world these
are no-ops — execute_on_client() always uses the Redis path.
"""
from __future__ import annotations
import json
import logging
from contextvars import ContextVar
from typing import Any, Callable, Coroutine
from uuid import uuid4
from shared.redis import redis_client, tool_result_key, ws_out_channel
logger = logging.getLogger(__name__)
_TOOL_CALL_TIMEOUT = 30 # seconds — BRPOP timeout
# Per-request user_id context var (set before agent run)
_current_user_id: ContextVar[str | None] = ContextVar("_current_user_id", default=None)
# Optional collector for debug / logging
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
"_tool_result_collector", default=None
)
def set_current_user(user_id: str) -> None:
_current_user_id.set(user_id)
def clear_current_user() -> None:
_current_user_id.set(None)
def set_tool_result_collector(lst: list[dict]) -> None:
_tool_result_collector.set(lst)
def clear_tool_result_collector() -> None:
_tool_result_collector.set(None)
# ── Compatibility shims ──────────────────────────────────────────────────
# agent_runner.py originally called set_client_executor / clear_client_executor
# with a DeviceConnectionManager callback. In the microservice world the
# Redis-based execute_on_client replaces this, so these are no-ops that
# keep the agent_runner code unchanged.
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]] | None) -> None:
"""No-op — kept for agent_runner compatibility."""
pass
def clear_client_executor() -> None:
"""No-op — kept for agent_runner compatibility."""
pass
async def execute_on_client(
action: str,
table: str | None = None,
data: dict[str, Any] | None = None,
filters: dict[str, Any] | None = None,
vector: list[float] | None = None,
limit: int | None = None,
) -> dict[str, Any]:
"""Send a tool_call to Electron via Redis and await the result.
1. Build tool_call payload
2. Publish to ws:out:{user_id} (WS Gateway forwards to Electron)
3. BRPOP on tool:result:{call_id} (WS Gateway pushes when Electron replies)
4. Return result dict
Raises RuntimeError if no user_id is set or if the call times out.
"""
user_id = _current_user_id.get()
if not user_id:
raise RuntimeError(
"execute_on_client() called without a user_id — "
"set_current_user() must be called first."
)
call_id = str(uuid4())
payload: dict[str, Any] = {
"type": "tool_call",
"id": call_id,
"action": action,
}
if table is not None:
payload["table"] = table
if data is not None:
payload["data"] = data
if filters is not None:
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
if vector is not None:
payload["vector"] = vector
if limit is not None:
payload["limit"] = limit
# Publish tool_call to WS Gateway → Electron
channel = ws_out_channel(user_id)
await redis_client.publish(channel, json.dumps(payload))
# Wait for Electron's tool_result
result_key = tool_result_key(call_id)
response = await redis_client.brpop(result_key, timeout=_TOOL_CALL_TIMEOUT)
if response is None:
raise RuntimeError(
f"Tool call {call_id} timed out after {_TOOL_CALL_TIMEOUT}s — "
f"device may be offline or unresponsive."
)
# response is (key, value) tuple
_, raw = response
result = json.loads(raw)
# Collect for debug if requested
collector = _tool_result_collector.get(None)
if collector is not None:
collector.append({
"action": action,
"table": table,
"data": result,
})
return result

View File

@@ -0,0 +1,20 @@
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
gunicorn>=22.0.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
sqlalchemy>=2.0.0
asyncpg>=0.30.0
redis>=5.0.0
cryptography>=42.0.0
python-dotenv>=1.0.0
langchain-core>=0.3.0
langchain-openai>=0.3.0
langchain-litellm>=0.3.0
litellm>=1.50.0
openai>=1.50.0
httpx>=0.27.0
croniter>=2.0.0
google-api-python-client>=2.130.0
google-auth>=2.30.0
msal>=1.28.0

View File

@@ -0,0 +1,15 @@
# Billing Service
Owns: Stripe integration, tier management, subscription CRUD.
## Tables owned (write)
- `subscriptions`
## Endpoints
- `POST /billing/checkout`
- `POST /billing/webhook` (Stripe, no JWT auth)
- `GET /billing/subscription`
- `DELETE /billing/subscription`
## Redis channels
- Publish: `tier:changed:{user_id}` on tier change

View File

36
services/chat/Dockerfile Normal file
View File

@@ -0,0 +1,36 @@
# ── builder ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS builder
WORKDIR /build
COPY services/chat/requirements.txt ./requirements.txt
RUN pip install --upgrade pip && \
pip install --no-cache-dir --prefix=/install -r requirements.txt
# ── runtime ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS runtime
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
WORKDIR /app
COPY --from=builder /install /usr/local
# Shared module
COPY shared/ shared/
# Service source
COPY services/chat/app/ app/
RUN chown -R appuser:appgroup /app
USER appuser
EXPOSE 8000
# Chat service is CPU-bound (LLM calls) — use multiple workers
CMD ["gunicorn", "app.main:app", \
"-k", "uvicorn.workers.UvicornWorker", \
"--bind", "0.0.0.0:8000", \
"--workers", "2", \
"--timeout", "120"]

21
services/chat/README.md Normal file
View File

@@ -0,0 +1,21 @@
# Chat Service
Owns: deep_agent (home + floating chat), memory middleware, domain agents
(task, note, project, timeline), LLM orchestration.
## Tables owned
- `memory_core`
- `memory_associative`
- `memory_episodic`
- `memory_proactive`
## Tables read (cross-service)
- `users` (for encryption_key — memory decryption)
## Endpoints
- `POST /chat` (REST fallback)
## Redis channels
- Subscribe: `chat:request:{user_id}`
- Publish: `ws:out:{user_id}` (stream frames + tool calls)
- BRPOP: `tool:result:{call_id}` (30s timeout)

View File

View File

@@ -0,0 +1 @@
"""Chat Service domain agents."""

View File

@@ -0,0 +1,142 @@
"""Note agent — Markdown note management (list, get, create, update, delete).
Adapted for Chat Service: import from app.ws_context and app.llm.
"""
from __future__ import annotations
import re
from typing import Any
from langchain_core.tools import tool
from app.llm import embed
from app.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
NOTE_SYSTEM_PROMPT = (
"You are a note-taking assistant. You help users create, retrieve, update,\n"
"and delete Markdown notes in their workspace.\n\n"
"Rules:\n"
" - content is always Markdown; preserve formatting when updating\n"
" - project_id is optional; link a note to a project when mentioned\n"
" - When updating, call get_note first if you need to read existing content\n"
" before appending or replacing sections\n"
" - list_notes without project_id returns all notes; scope with project_id\n"
" when the user is working within a specific project\n"
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
" - Do not fabricate note content — reflect what the user provides or what\n"
" is already in the note (retrieved via get_note)."
)
@tool
async def list_notes(project_id: str = "") -> str:
"""List notes, optionally scoped to a project by project_id."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="notes",
filters={"projectId": normalized_project_id or None},
)
rows = result.get("rows", [])
if not rows:
return "No notes found."
lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
@tool
async def get_note(note_id: str) -> str:
"""Fetch a single note by its UUID to read its full Markdown content."""
result = await execute_on_client(action="get", table="notes", data={"id": note_id})
row = result.get("row")
if not row:
return f"Note {note_id} not found."
return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}"
@tool
async def create_note(
title: str,
content: str,
project_id: str = "",
) -> str:
"""Create a new note.
title: note heading (required)
content: Markdown body text (required)
project_id: optional UUID linking this note to a project
"""
result = await execute_on_client(
action="insert",
table="notes",
data={
"title": title,
"content": content,
"projectId": project_id or None,
},
)
row = result["row"]
# Index the note content in the vector store.
vector = await embed(content)
await execute_on_client(
action="vector_upsert",
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
vector=vector,
)
return f"Note created: '{row['title']}' (id: {row['id']})."
@tool
async def update_note(
note_id: str,
title: str = "",
content: str = "",
) -> str:
"""Update an existing note. Only pass fields that should change.
note_id: UUID of the note (required)
If you need to preserve existing content, call get_note first.
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if content:
updates["content"] = content
result = await execute_on_client(
action="update",
table="notes",
data={"id": note_id, "updates": updates},
)
row = result["row"]
# Re-index if content changed.
if content:
vector = await embed(content)
await execute_on_client(
action="vector_upsert",
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
vector=vector,
)
return f"Note updated: '{row['title']}' (id: {row['id']})."
@tool
async def delete_note(note_id: str) -> str:
"""Delete a note permanently by its UUID."""
await execute_on_client(action="delete", table="notes", data={"id": note_id})
return f"Note {note_id} deleted."
NOTE_TOOLS: list[Any] = [
list_notes,
get_note,
create_note,
update_note,
delete_note,
]

View File

@@ -0,0 +1,146 @@
"""Project agent — full lifecycle management (list, get, create, update, archive, delete).
Adapted for Chat Service: import from app.ws_context instead of app.core.ws_context.
"""
from __future__ import annotations
from typing import Any
from langchain_core.tools import tool
from app.ws_context import execute_on_client
PROJECT_SYSTEM_PROMPT = (
"You are a project management assistant. You help users create, find,\n"
"update, and archive projects in their workspace.\n\n"
"Rules:\n"
" - status must be one of: active, archived\n"
" - client_id is optional; link to a client only when explicitly mentioned\n"
" - ai_summary is populated only when the user asks for a project summary;\n"
" derive it from context data — do not fabricate content\n"
" - Use list_projects for scoped queries; list_all_projects only when the\n"
" user wants a complete cross-client view including archived projects\n"
" - get_project requires a project UUID; resolve the ID first by calling\n"
" list_projects if you only have a project name\n"
" - Prefer archiving (update_project status=archived) over deletion;\n"
" only call delete_project when the user explicitly confirms deletion."
)
@tool
async def list_projects(
client_id: str = "",
include_archived: int = 0,
) -> str:
"""List projects, optionally filtered by client_id.
include_archived: 1 to include archived projects, 0 for active only (default).
"""
result = await execute_on_client(
action="select",
table="projects",
filters={
"clientId": client_id or None,
"includeArchived": bool(include_archived),
},
)
rows = result.get("rows", [])
if not rows:
return "No projects found."
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
return f"Found {len(rows)} project(s):\n" + "\n".join(lines)
@tool
async def list_all_projects() -> str:
"""List every project regardless of client or status.
Use only when the user wants a complete cross-client overview.
"""
result = await execute_on_client(action="select", table="projects")
rows = result.get("rows", [])
if not rows:
return "No projects found."
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
return f"All projects ({len(rows)}):\n" + "\n".join(lines)
@tool
async def get_project(project_id: str) -> str:
"""Fetch a single project by its UUID."""
result = await execute_on_client(action="get", table="projects", data={"id": project_id})
row = result.get("row")
if not row:
return f"Project {project_id} not found."
return (
f"Project: '{row['name']}' (id: {row['id']}, status: {row['status']}, "
f"clientId: {row.get('clientId', 'none')})"
)
@tool
async def create_project(
name: str,
client_id: str = "",
) -> str:
"""Create a new project.
name: human-readable project name (required)
client_id: optional UUID of the owning client
"""
result = await execute_on_client(
action="insert",
table="projects",
data={"name": name, "clientId": client_id or None},
)
row = result["row"]
return f"Project created: '{row['name']}' (id: {row['id']})"
@tool
async def update_project(
project_id: str,
name: str = "",
client_id: str = "",
status: str = "",
ai_summary: str = "",
) -> str:
"""Update a project. Only pass fields that should change.
project_id: UUID of the project (required)
status: active | archived
ai_summary: AI-generated summary text (populate only when explicitly requested)
"""
updates: dict[str, Any] = {}
if name:
updates["name"] = name
if client_id:
updates["clientId"] = client_id
if status:
updates["status"] = status
if ai_summary:
updates["aiSummary"] = ai_summary
result = await execute_on_client(
action="update",
table="projects",
data={"id": project_id, "updates": updates},
)
row = result["row"]
return f"Project updated: '{row['name']}' (id: {row['id']}, status: {row['status']})"
@tool
async def delete_project(project_id: str) -> str:
"""Permanently delete a project and orphan its tasks.
IMPORTANT: prefer update_project(status='archived') unless the user
has explicitly confirmed they want permanent deletion.
"""
await execute_on_client(action="delete", table="projects", data={"id": project_id})
return f"Project {project_id} permanently deleted."
PROJECT_TOOLS: list[Any] = [
list_projects,
list_all_projects,
get_project,
create_project,
update_project,
delete_project,
]

View File

@@ -0,0 +1,240 @@
"""Task agent — full CRUD for tasks and task comments.
Adapted for Chat Service: import from app.ws_context instead of app.core.ws_context.
"""
from __future__ import annotations
from datetime import datetime, timezone
import re
from typing import Any
from langchain_core.tools import tool
from app.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
TASK_SYSTEM_PROMPT = (
"You are a task management assistant for a project workspace.\n"
"You create, update, list, and track tasks and their comments.\n\n"
"Rules:\n"
" - status must be one of: todo, in_progress, done\n"
" - priority must be one of: high, medium, low\n"
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
" - project_id is optional; link to a project when the user mentions one\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
" did not explicitly request; 0 otherwise\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user did not explicitly request; 0 otherwise\n"
" - Use list_tasks_due_today for 'what's due today' queries\n"
" - For update_task, use -1 for integer fields you do not want to change\n"
" - Always confirm the action in plain, user-friendly language."
)
# ── Task tools ────────────────────────────────────────────────────────
@tool
async def list_tasks(
project_id: str = "",
status: str = "",
search: str = "",
order_by: str = "",
) -> str:
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
a search string, or an order_by field name (dueDate|priority|createdAt)."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="tasks",
filters={
"projectId": normalized_project_id or None,
"status": status or None,
"search": search or None,
"orderBy": order_by or None,
},
)
rows = result.get("rows", [])
if not rows:
return "No tasks found matching the given filters."
lines = [
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})"
for r in rows
]
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
@tool
async def create_task(
title: str,
description: str = "",
status: str = "todo",
priority: str = "medium",
assignees: str = "[]",
due_date: int = 0,
project_id: str = "",
is_ai_suggested: int = 0,
) -> str:
"""Create a new task.
title: task title (required)
description: optional details
status: todo | in_progress | done (default: todo)
priority: high | medium | low (default: medium)
assignees: JSON-encoded array of assignee names, e.g. '["Alice"]'
due_date: Unix timestamp in milliseconds; 0 means no due date
project_id: optional UUID of the parent project
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
"""
result = await execute_on_client(
action="insert",
table="tasks",
data={
"title": title,
"description": description or None,
"status": status,
"priority": priority,
"assignee": assignees,
"dueDate": due_date or None,
"projectId": project_id or None,
"isAiSuggested": is_ai_suggested,
},
)
row = result["row"]
return (
f"Task created: '{row['title']}' "
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
)
@tool
async def update_task(
task_id: str,
title: str = "",
description: str = "",
status: str = "",
priority: str = "",
assignees: str = "",
due_date: int = -1,
project_id: str = "",
) -> str:
"""Update fields on an existing task. Only pass fields you want to change.
task_id: the task's UUID (required)
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if description:
updates["description"] = description
if status:
updates["status"] = status
if priority:
updates["priority"] = priority
if assignees:
updates["assignee"] = assignees
if due_date != -1:
updates["dueDate"] = due_date or None
if project_id:
updates["projectId"] = project_id
result = await execute_on_client(
action="update",
table="tasks",
data={"id": task_id, "updates": updates},
)
row = result["row"]
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
@tool
async def delete_task(task_id: str) -> str:
"""Delete a task permanently by its UUID."""
await execute_on_client(action="delete", table="tasks", data={"id": task_id})
return f"Task {task_id} deleted."
@tool
async def list_tasks_due_today() -> str:
"""List all tasks whose due date falls on today's date."""
now = datetime.now(tz=timezone.utc)
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
end_ms = start_ms + 86_400_000 - 1 # last ms of today
result = await execute_on_client(
action="select",
table="tasks",
filters={"dueDateFrom": start_ms, "dueDateTo": end_ms},
)
rows = result.get("rows", [])
if not rows:
return "No tasks are due today."
lines = [
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})"
for r in rows
]
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
# ── Task comment tools ────────────────────────────────────────────────
@tool
async def list_task_comments(task_id: str) -> str:
"""List all comments on a task by its UUID."""
result = await execute_on_client(
action="select",
table="taskComments",
filters={"taskId": task_id},
)
rows = result.get("rows", [])
if not rows:
return f"No comments found for task {task_id}."
lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows]
return f"Found {len(rows)} comment(s):\n" + "\n".join(lines)
@tool
async def add_task_comment(task_id: str, author: str, content: str) -> str:
"""Add a comment to a task.
task_id: UUID of the task to comment on
author: name or ID of the comment author
content: comment text
"""
result = await execute_on_client(
action="insert",
table="taskComments",
data={"taskId": task_id, "author": author, "content": content},
)
row = result.get("row", {})
row_author = row.get("author", author)
row_task_id = row.get("taskId") or row.get("task_id") or task_id
row_comment_id = row.get("id", "unknown")
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
@tool
async def delete_task_comment(comment_id: str) -> str:
"""Delete a task comment by its UUID."""
await execute_on_client(action="delete", table="taskComments", data={"id": comment_id})
return f"Comment {comment_id} deleted."
# ── Agent ─────────────────────────────────────────────────────────────
TASK_TOOLS: list[Any] = [
list_tasks,
create_task,
update_task,
delete_task,
list_tasks_due_today,
list_task_comments,
add_task_comment,
delete_task_comment,
]

View File

@@ -0,0 +1,117 @@
"""Timeline agent — project milestone management (list, create, update, delete).
Adapted for Chat Service: import from app.ws_context instead of app.core.ws_context.
"""
from __future__ import annotations
import re
from typing import Any
from langchain_core.tools import tool
from app.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
TIMELINE_SYSTEM_PROMPT = (
"You are a project timeline assistant. Timelines are milestone dates that\n"
"track progress on a project — they are not calendar events.\n\n"
"Rules:\n"
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
" - For update_timeline, use -1 for integer fields you do not want to change\n"
" - Listing without a project_id returns all timelines across projects\n"
" - Always echo the title and formatted date in your confirmation."
)
@tool
async def list_timelines(project_id: str = "") -> str:
"""List timelines. Provide project_id to scope to a specific project."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="timelines",
filters={"projectId": normalized_project_id or None},
)
rows = result.get("rows", [])
if not rows:
return "No timelines found."
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines)
@tool
async def create_timeline(
project_id: str,
title: str,
date: int,
is_ai_suggested: int = 0,
) -> str:
"""Create a project timeline (milestone).
project_id: REQUIRED UUID of the parent project
title: descriptive name for the milestone
date: Unix timestamp in milliseconds
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
"""
result = await execute_on_client(
action="insert",
table="timelines",
data={
"projectId": project_id,
"title": title,
"date": date,
"isAiSuggested": is_ai_suggested,
},
)
row = result["row"]
return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
@tool
async def update_timeline(
timeline_id: str,
title: str = "",
date: int = -1,
) -> str:
"""Update a timeline. Only pass fields that should change.
timeline_id: UUID of the timeline (required)
date: -1 means unchanged; any other value sets the new date (ms timestamp)
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if date != -1:
updates["date"] = date
result = await execute_on_client(
action="update",
table="timelines",
data={"id": timeline_id, "updates": updates},
)
row = result["row"]
return f"Timeline updated: '{row['title']}' (id: {row['id']})"
@tool
async def delete_timeline(timeline_id: str) -> str:
"""Delete a timeline permanently by its UUID."""
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
return f"Timeline {timeline_id} deleted."
TIMELINE_TOOLS: list[Any] = [
list_timelines,
create_timeline,
update_timeline,
delete_timeline,
]

View File

@@ -0,0 +1,883 @@
"""Single-agent runners for home and floating chat contexts.
Adapted from app/core/deep_agent.py for the Chat Service.
Import paths changed to use local app modules and shared/.
"""
from __future__ import annotations
import json
import logging
import re
from datetime import date
from collections.abc import AsyncGenerator
from typing import Any, Literal
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from app.agents.note_agent import NOTE_TOOLS
from app.agents.project_agent import PROJECT_TOOLS
from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS
from app.llm import get_llm
from app.memory_middleware import MemoryMiddleware
from app.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
from app import tracing
from shared.db import async_session
logger = logging.getLogger(__name__)
FloatingDomainType = Literal["task", "timeline", "project", "node"]
FloatingDomainSection = Literal["task", "timeline", "note"]
_HOME_SINGLE_AGENT_SYSTEM = (
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
"Always use tools for factual data retrieval before answering. "
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
"Return markdown and use tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>. "
"When listing tasks or timelines, each id tag must be on its own line with no prefix/suffix text. "
"Never put titles, priorities, or dates on the same line as <task> or <timeline> tags. "
"For questions about upcoming timelines (e.g. 'prossimi eventi'), include only future items in the current month unless the user asks a different range. "
"For upcoming tasks, after tag lines add a short recommendation based on due date and priority."
)
_FLOATING_SINGLE_AGENT_SYSTEM = (
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
"Stay focused on the floating scope in context.scope and answer concisely. "
"Return plain text only. Do not output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed id tag wrappers. "
"Always use tools for factual data retrieval before answering. "
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
)
_FLOATING_DOMAIN_CLASSIFIER_SYSTEM = (
"You are a strict domain classifier for websocket floating requests. "
"Return ONLY a JSON object with keys: type, id, section. "
"Allowed type values: task, timeline, project, node. "
"Allowed section values: task, timeline, note, or null. "
"Rules: infer from user message intent first; do not blindly trust scope.type. "
"If user asks tasks/timeline/notes for a project, set type=project and section accordingly. "
"If project id is unknown but context.resolved_project_id exists, use it as id. "
"If id is unknown, use null. "
"No markdown, no prose, JSON only."
)
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
def _candidate_tokens(message: str) -> list[str]:
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
return [token for token in tokens if len(token) >= 3]
async def _resolve_project_id_from_message(message: str) -> str | None:
"""Resolve likely project UUID from user message using client project list."""
try:
result = await execute_on_client(action="select", table="projects")
except Exception as exc:
logger.warning("deep_agent: project resolve select failed: %s", exc)
return None
rows = result.get("rows", [])
if not isinstance(rows, list) or not rows:
return None
tokens = _candidate_tokens(message)
scored: list[tuple[int, dict[str, Any]]] = []
for row in rows:
if not isinstance(row, dict):
continue
name = str(row.get("name", "")).lower()
score = sum(1 for token in tokens if token in name)
if score > 0:
scored.append((score, row))
if not scored:
return None
scored.sort(key=lambda item: item[0], reverse=True)
top_score = scored[0][0]
top_rows = [row for score, row in scored if score == top_score]
if len(top_rows) != 1:
return None
project_id = top_rows[0].get("id")
return project_id if isinstance(project_id, str) else None
def _needs_project_resolution(message: str) -> bool:
lowered = message.lower()
return any(keyword in lowered for keyword in ["project", "progetto", "progetti", "whitelist"])
async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
prepared = dict(context)
if _needs_project_resolution(message):
resolved_project_id = await _resolve_project_id_from_message(message)
if resolved_project_id:
prepared["resolved_project_id"] = resolved_project_id
logger.info("deep_agent: resolved_project_id=%s", resolved_project_id)
return prepared
def _all_tools() -> list[Any]:
return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
debug = context.get("_debug")
if isinstance(debug, dict):
request_id = debug.get("request_id")
if isinstance(request_id, str) and request_id:
return request_id
return None
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
sanitized = dict(context)
sanitized.pop("_debug", None)
return sanitized
_TAG_LINE_RE = re.compile(r"<(task|timeline)>\[[^\]]+\]</\1>")
_TIMELINE_DMY_RE = re.compile(r"(?P<d>\d{2})/(?P<m>\d{2})/(?P<y>\d{4})")
def _is_upcoming_timeline_query(message: str) -> bool:
lowered = message.lower()
has_upcoming = "prossim" in lowered or "upcoming" in lowered or "next" in lowered
has_timeline_topic = any(
token in lowered
for token in ("event", "evento", "eventi", "timeline", "milestone", "scaden")
)
return has_upcoming and has_timeline_topic
def _timeline_date_in_current_month_or_future(dmy: str) -> bool:
match = _TIMELINE_DMY_RE.search(dmy)
if not match:
return True
try:
parsed = date(
int(match.group("y")),
int(match.group("m")),
int(match.group("d")),
)
except ValueError:
return True
today = date.today()
return parsed >= today and parsed.year == today.year and parsed.month == today.month
def _normalize_tagged_list_lines(text: str, message: str) -> str:
if not text:
return text
upcoming_timeline_only = _is_upcoming_timeline_query(message)
output_lines: list[str] = []
for line in text.splitlines():
matches = list(_TAG_LINE_RE.finditer(line))
if not matches:
output_lines.append(line)
continue
had_non_tag_text = _TAG_LINE_RE.sub("", line).strip(" -\t0123456789.*:)")
if not had_non_tag_text and len(matches) == 1:
tag_text = matches[0].group(0)
if (
upcoming_timeline_only
and "<timeline>" in tag_text
and not _timeline_date_in_current_month_or_future(line)
):
continue
output_lines.append(tag_text)
continue
for match in matches:
tag_text = match.group(0)
if (
upcoming_timeline_only
and "<timeline>" in tag_text
and not _timeline_date_in_current_month_or_future(line)
):
continue
output_lines.append(tag_text)
return "\n".join(output_lines)
_GENERIC_TAG_RE = re.compile(r"</?(task|project|note|timeline|chart)>", re.IGNORECASE)
_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]")
_FLOATING_EMPTY_FALLBACK = "No results found."
def _strip_floating_markup_fragment(text: str) -> str:
if not text:
return text
cleaned = _GENERIC_TAG_RE.sub("", text)
return _BRACKETED_ID_RE.sub("", cleaned)
def _strip_floating_markup(text: str) -> str:
"""Ensure floating responses stay plain text with no XML-like tag wrappers."""
if not text:
return text
cleaned = _strip_floating_markup_fragment(text)
lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()]
return "\n".join(line for line in lines if line)
def _fallback_from_raw_floating_text(raw_text: str) -> str:
fallback = _strip_floating_markup_fragment(raw_text or "")
fallback = re.sub(r"[ \t]{2,}", " ", fallback).strip()
return fallback or _FLOATING_EMPTY_FALLBACK
class _FloatingStreamSanitizer:
"""Streaming sanitizer that removes floating markup without buffering the full answer."""
def __init__(self) -> None:
self._pending = ""
@staticmethod
def _split_safe_boundary(text: str) -> tuple[str, str]:
boundary = len(text)
last_lt = text.rfind("<")
if last_lt != -1 and ">" not in text[last_lt:]:
boundary = min(boundary, last_lt)
last_lb = text.rfind("[")
if last_lb != -1 and "]" not in text[last_lb:]:
boundary = min(boundary, last_lb)
if boundary == len(text):
return text, ""
return text[:boundary], text[boundary:]
def feed(self, chunk: str) -> str:
combined = f"{self._pending}{chunk}"
safe_text, self._pending = self._split_safe_boundary(combined)
return _strip_floating_markup_fragment(safe_text)
def finalize(self) -> str:
tail = re.sub(r"<[^>\n]*$", "", self._pending)
tail = re.sub(r"\[[^\]\n]*$", "", tail)
self._pending = ""
return _strip_floating_markup_fragment(tail)
def _normalize_memory_label(path_or_label: str) -> str:
value = path_or_label.strip()
if value.startswith("/memories/"):
value = value[len("/memories/"):]
value = value.strip("/")
return value
def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
@tool
async def memory_list_blocks() -> str:
"""List all core memory blocks currently stored for the user."""
logger.info("deep_agent: memory_list_blocks trace=%s user=%s", trace_id or "-", user_id)
async with async_session() as db:
memory = MemoryMiddleware(db)
blocks = await memory.list_core_blocks(user_id)
if not blocks:
return "No memory blocks found."
lines = [f"- {b['label']}: {b['value']}" for b in blocks]
return "Memory blocks:\n" + "\n".join(lines)
@tool
async def memory_get(path_or_label: str) -> str:
"""Get one memory block by label or /memories/<label> path."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_get trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
value = await memory.get_core_block(user_id, label)
if value is None:
return f"Memory block '{label}' not found."
return f"Memory block '{label}':\n{value}"
@tool
async def memory_create(path_or_label: str, value: str) -> str:
"""Create or overwrite a memory block value by label or /memories/<label> path."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_create trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.update_core(user_id, label, value, trace_id=trace_id)
return f"Memory block '{label}' saved."
@tool
async def memory_append(path_or_label: str, content: str) -> str:
"""Append content to a memory block, creating it if missing."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_append trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.append_core(user_id, label, content)
return f"Memory block '{label}' appended."
@tool
async def memory_replace(path_or_label: str, old_string: str, new_string: str) -> str:
"""Replace one exact string in a memory block."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_replace trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
changed = await memory.replace_core(user_id, label, old_string, new_string)
if not changed:
return f"No replacement made in '{label}' (old string not found)."
return f"Memory block '{label}' updated."
@tool
async def memory_delete(path_or_label: str) -> str:
"""Delete a memory block by label or /memories/<label> path."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_delete trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
deleted = await memory.delete_core(user_id, label)
if not deleted:
return f"Memory block '{label}' not found."
return f"Memory block '{label}' deleted."
@tool
async def archival_memory_insert(content: str) -> str:
"""Insert a long-term archival memory entry."""
logger.info("deep_agent: archival_memory_insert trace=%s user=%s", trace_id or "-", user_id)
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.insert_archival(user_id, content, source="assistant")
return "Archival memory saved."
@tool
async def archival_memory_search(query: str, top_k: int = 5) -> str:
"""Search long-term archival memory by semantic fallback (keyword currently)."""
logger.info("deep_agent: archival_memory_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
async with async_session() as db:
memory = MemoryMiddleware(db)
results = await memory.search_archival(user_id, query, top_k=top_k)
if not results:
return "No archival memory results found."
lines = [f"- {item}" for item in results]
return "Archival memory results:\n" + "\n".join(lines)
@tool
async def conversation_search(query: str, top_k: int = 5) -> str:
"""Search recall memory from prior episodic conversation summaries."""
logger.info("deep_agent: conversation_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
async with async_session() as db:
memory = MemoryMiddleware(db)
results = await memory.search_recall(user_id, query, top_k=top_k)
if not results:
return "No recall memory results found."
lines = [f"- {item}" for item in results]
return "Recall memory results:\n" + "\n".join(lines)
return [
memory_list_blocks,
memory_get,
memory_create,
memory_append,
memory_replace,
memory_delete,
archival_memory_insert,
archival_memory_search,
conversation_search,
]
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
return [*_all_tools(), *_memory_tools(user_id, trace_id)]
def _detect_domain_section(message: str) -> FloatingDomainSection | None:
lowered = message.lower()
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
return "timeline"
if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]):
return "task"
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
return "note"
return None
def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]:
type_raw = str(payload.get("type") or "").strip().lower()
domain_type: FloatingDomainType = "task"
if type_raw in {"task", "timeline", "project", "node"}:
domain_type = type_raw
id_value = payload.get("id")
domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None
if domain_type == "project" and not domain_id:
domain_id = fallback_id
section_raw = payload.get("section")
section: FloatingDomainSection | None = None
if isinstance(section_raw, str):
section_candidate = section_raw.strip().lower()
if section_candidate in {"task", "timeline", "note"}:
section = section_candidate
if domain_type != "project":
section = None
return {
"type": domain_type,
"id": domain_id,
"section": section,
}
def _parse_json_object(text: str) -> dict[str, Any] | None:
raw = text.strip()
if not raw:
return None
try:
parsed = json.loads(raw)
return parsed if isinstance(parsed, dict) else None
except json.JSONDecodeError:
pass
match = re.search(r"\{.*\}", raw, re.DOTALL)
if not match:
return None
try:
parsed = json.loads(match.group(0))
except json.JSONDecodeError:
return None
return parsed if isinstance(parsed, dict) else None
def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]:
section = _detect_domain_section(message)
scope = context.get("scope") if isinstance(context, dict) else None
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
if isinstance(scope, dict):
scope_type = str(scope.get("type") or "").strip().lower()
scope_id = scope.get("id")
scope_id_value = scope_id if isinstance(scope_id, str) and scope_id else None
if scope_type in {"task", "tasks"}:
return {"type": "task", "id": scope_id_value, "section": None}
if scope_type in {"project", "projects"}:
project_scope_id = scope_id_value or project_id
return {
"type": "project",
"id": project_scope_id,
"section": section,
}
if scope_type in {"note", "notes"}:
return {
"type": "node",
"id": scope_id_value,
"section": None,
}
if scope_type in {"timeline", "timelines"}:
return {"type": "timeline", "id": scope_id_value, "section": None}
lowered = message.lower()
if any(keyword in lowered for keyword in ["project", "progetto", "client"]) or project_id:
return {
"type": "project",
"id": project_id,
"section": section,
}
if section == "timeline":
return {"type": "timeline", "id": None, "section": None}
if section == "note":
return {"type": "node", "id": None, "section": None}
return {"type": "task", "id": None, "section": None}
async def _infer_floating_domain(
message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None,
) -> dict[str, str | None]:
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
classifier_context = {
"scope": context.get("scope") if isinstance(context.get("scope"), dict) else None,
"resolved_project_id": project_id,
}
try:
classifier_prompt = _get_system_prompt(
"floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_SYSTEM,
)
callbacks = _build_callbacks(langfuse_handler)
llm = get_llm(callbacks=callbacks)
response = await llm.ainvoke(
[
SystemMessage(content=classifier_prompt),
HumanMessage(
content=(
f"Message:\n{message}\n\n"
f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}"
)
),
]
)
parsed = _parse_json_object(_as_text(response.content))
if parsed is not None:
domain = _normalize_domain_payload(parsed, project_id)
logger.info(
"deep_agent: floating_domain_classified type=%s id=%s section=%s",
domain.get("type"),
domain.get("id"),
domain.get("section"),
)
return domain
logger.warning("deep_agent: floating_domain classifier returned non-json output")
except Exception as exc:
logger.warning("deep_agent: floating_domain classifier failed: %s", exc)
return _infer_floating_domain_rule_based(message, context)
def _get_system_prompt(langfuse_name: str, fallback: str) -> str:
"""Fetch a managed prompt from Langfuse, falling back to the hardcoded string."""
managed = tracing.get_prompt(langfuse_name, fallback=None)
return managed if managed is not None else fallback
def _build_callbacks(langfuse_handler: Any | None) -> list[Any] | None:
"""Return a callbacks list if a Langfuse handler is available."""
if langfuse_handler is None:
return None
return [langfuse_handler]
async def _run_single_agent(
*,
user_id: str,
system_prompt: str,
message: str,
context: dict[str, Any],
max_steps: int = 6,
langfuse_handler: Any | None = None,
) -> str:
trace_id = _trace_id_from_context(context)
callbacks = _build_callbacks(langfuse_handler)
llm = get_llm(callbacks=callbacks)
tools = _all_tools_for_user(user_id, trace_id)
model_context = _context_for_model(context)
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=system_prompt),
HumanMessage(
content=(
f"User message:\n{message}\n\n"
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
)
),
]
tool_calls_count = 0
collected: list[dict[str, Any]] = []
set_tool_result_collector(collected)
try:
for _ in range(max_steps):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
final_text = _as_text(response.content)
logger.info(
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d",
trace_id or "-",
user_id,
tool_calls_count,
len(final_text),
)
return final_text
tool_map = {tool_def.name: tool_def for tool_def in tools}
for call in response.tool_calls:
tool_calls_count += 1
call_id = str(call.get("id", ""))
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
call_id,
call_name,
json.dumps(call_args, ensure_ascii=True)[:800],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
call_id,
call_name,
str(tool_output)[:1200],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
final = await llm.ainvoke(messages)
final_text = _as_text(final.content)
logger.info(
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
trace_id or "-",
user_id,
tool_calls_count,
len(final_text),
)
return final_text
finally:
clear_tool_result_collector()
async def _run_single_agent_stream(
*,
user_id: str,
system_prompt: str,
message: str,
context: dict[str, Any],
max_steps: int = 6,
langfuse_handler: Any | None = None,
) -> AsyncGenerator[tuple[str, Any], None]:
trace_id = _trace_id_from_context(context)
callbacks = _build_callbacks(langfuse_handler)
llm = get_llm(callbacks=callbacks)
tools = _all_tools_for_user(user_id, trace_id)
model_context = _context_for_model(context)
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=system_prompt),
HumanMessage(
content=(
f"User message:\n{message}\n\n"
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
)
),
]
tool_calls_count = 0
streamed_chars = 0
collected: list[dict[str, Any]] = []
set_tool_result_collector(collected)
try:
for _ in range(max_steps):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
emitted_any = False
async for chunk in llm.astream(messages):
token = _as_text(getattr(chunk, "content", ""))
if token:
streamed_chars += len(token)
emitted_any = True
yield "token", token
if not emitted_any:
fallback_text = _as_text(response.content)
if fallback_text:
streamed_chars += len(fallback_text)
yield "token", fallback_text
logger.info(
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
trace_id or "-",
user_id,
tool_calls_count,
streamed_chars,
)
return
tool_map = {tool_def.name: tool_def for tool_def in tools}
for call in response.tool_calls:
tool_calls_count += 1
call_id = str(call.get("id", ""))
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
call_id,
call_name,
json.dumps(call_args, ensure_ascii=True)[:800],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
call_id,
call_name,
str(tool_output)[:1200],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
async for chunk in llm.astream(messages):
token = _as_text(getattr(chunk, "content", ""))
if token:
streamed_chars += len(token)
yield "token", token
logger.info(
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
trace_id or "-",
user_id,
tool_calls_count,
streamed_chars,
)
finally:
clear_tool_result_collector()
async def run_home(user_id: str, message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None) -> str:
prepared_context = await _prepare_context(message, context)
system_prompt = _get_system_prompt("home_system", _HOME_SINGLE_AGENT_SYSTEM)
response = await _run_single_agent(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_handler=langfuse_handler,
)
return _normalize_tagged_list_lines(response, message)
async def run_floating(user_id: str, message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None) -> tuple[str, dict[str, str | None]]:
prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context, langfuse_handler=langfuse_handler)
system_prompt = _get_system_prompt("floating_system", _FLOATING_SINGLE_AGENT_SYSTEM)
response = await _run_single_agent(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_handler=langfuse_handler,
)
sanitized = _strip_floating_markup(response)
if not sanitized and response:
sanitized = _fallback_from_raw_floating_text(response)
return sanitized, domain
async def run_home_stream(
user_id: str,
message: str,
context: dict[str, Any],
*,
langfuse_handler: Any | None = None,
) -> AsyncGenerator[tuple[str, Any], None]:
prepared_context = await _prepare_context(message, context)
system_prompt = _get_system_prompt("home_system", _HOME_SINGLE_AGENT_SYSTEM)
text_chunks: list[str] = []
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_handler=langfuse_handler,
):
event_type, data = event
if event_type != "token":
yield event
continue
text_chunks.append(str(data or ""))
normalized = _normalize_tagged_list_lines("".join(text_chunks), message)
if normalized:
yield "token", normalized
async def run_floating_stream(
user_id: str,
message: str,
context: dict[str, Any],
*,
langfuse_handler: Any | None = None,
) -> AsyncGenerator[tuple[str, Any], None]:
prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context, langfuse_handler=langfuse_handler)
yield "floating_domain", domain
system_prompt = _get_system_prompt("floating_system", _FLOATING_SINGLE_AGENT_SYSTEM)
sanitizer = _FloatingStreamSanitizer()
emitted_sanitized = False
raw_chunks: list[str] = []
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_handler=langfuse_handler,
):
event_type, data = event
if event_type != "token":
yield event
continue
raw_chunk = str(data or "")
raw_chunks.append(raw_chunk)
sanitized_chunk = sanitizer.feed(raw_chunk)
if sanitized_chunk:
emitted_sanitized = True
yield "token", sanitized_chunk
tail = sanitizer.finalize()
if tail:
emitted_sanitized = True
yield "token", tail
if not emitted_sanitized and raw_chunks:
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
async def update_core_memory(user_id: str, key: str, value: str) -> None:
"""Compatibility helper kept for callers that expect explicit memory update API."""
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.update_core(user_id, key, value)

72
services/chat/app/llm.py Normal file
View File

@@ -0,0 +1,72 @@
"""LLM factory — centralised model instantiation via LiteLLM.
Adapted from app/core/llm.py for the Chat Service.
Uses shared.config.settings instead of app.config.settings.
"""
from __future__ import annotations
import os
import warnings
from openai import AsyncOpenAI
import litellm
from langchain_openai import ChatOpenAI
from langchain_litellm import ChatLiteLLM
from shared.config import settings
litellm.drop_params = True
warnings.filterwarnings(
"ignore",
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
category=UserWarning,
)
def _api_key_for_model(model: str) -> str | None:
if model.startswith("anthropic/"):
return settings.ANTHROPIC_API_KEY or None
if model.startswith("gemini/") or model.startswith("google/"):
return settings.GOOGLE_API_KEY or None
if model.startswith("cerebras/"):
return settings.CEREBRAS_API_KEY or None
if model.startswith("github_copilot/"):
return None
return settings.OPENAI_API_KEY or None
def get_llm(
*,
model: str | None = None,
temperature: float = 0,
callbacks: list | None = None,
) -> ChatOpenAI | ChatLiteLLM:
model = model or settings.LLM_MODEL
if settings.GITHUB_COPILOT_TOKEN_DIR:
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
if "/" in model:
return ChatLiteLLM(model=model, temperature=temperature, callbacks=callbacks)
return ChatOpenAI(
model=model,
temperature=temperature,
api_key=_api_key_for_model(model),
callbacks=callbacks,
)
async def embed(text: str) -> list[float]:
model = settings.LLM_EMBED_MODEL
if model.startswith("github_copilot/") or "/" in model:
response = await litellm.aembedding(model=model, input=[text])
return response.data[0]["embedding"]
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
response = await client.embeddings.create(model=model, input=text)
return response.data[0].embedding

87
services/chat/app/main.py Normal file
View File

@@ -0,0 +1,87 @@
"""Chat Service — LLM orchestration, domain agents, memory.
Consumes chat requests from Redis, executes deep_agent (home/floating),
streams responses back via Redis pub/sub to WS Gateway.
Owns: memory_core, memory_associative, memory_episodic, memory_proactive tables.
"""
import sys
from contextlib import asynccontextmanager
import logging
from pathlib import Path
# Ensure the repo root is on sys.path so "shared" is importable in local dev.
_repo_root = str(Path(__file__).resolve().parents[3])
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from shared.config import settings
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
@asynccontextmanager
async def lifespan(app: FastAPI):
# Initialise Langfuse tracing (no-op if keys are missing)
from app.tracing import init_langfuse
init_langfuse()
# Start Redis consumer in background
from app.redis_consumer import start_consumer
consumer_task = start_consumer()
yield
consumer_task.cancel()
from app.tracing import shutdown as shutdown_langfuse
shutdown_langfuse()
from shared.db import engine
await engine.dispose()
from shared.redis import redis_client
await redis_client.aclose()
def create_app() -> FastAPI:
app = FastAPI(
title="Adiuva Chat Service",
version="0.1.0",
docs_url="/docs" if settings.ENV == "dev" else None,
redoc_url=None,
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
from app.routes import router
app.include_router(router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"])
async def health() -> dict:
return {"status": "ok", "service": "chat", "version": app.version}
return app
app = create_app()

View File

@@ -0,0 +1,295 @@
"""Memory Middleware — adapted for Chat Service.
Uses shared.models instead of app.models. Otherwise identical to the
monolith's app/core/memory_middleware.py.
"""
from __future__ import annotations
import logging
import uuid
from typing import Any
from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from shared.models import (
MemoryAssociative,
MemoryCore,
MemoryEpisodic,
MemoryProactive,
User,
)
logger = logging.getLogger(__name__)
_ASSOCIATIVE_TOP_K = 5
_EPISODIC_RECENT_N = 10
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
class MemoryMiddleware:
def __init__(self, db: AsyncSession) -> None:
self._db = db
async def enrich_context(
self,
user_id: str,
message: str,
trace_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
fernet = await self._get_fernet(user_id)
if fernet is None:
return {}
core = await self._load_core(user_id, fernet)
associative = await self._load_associative(user_id, message, fernet)
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
proactive = await self._load_proactive(user_id, fernet)
logger.info(
"memory: enrich_context trace=%s user=%s core=%d assoc=%d episodic=%d proactive=%d",
trace_id or "-", user_id, len(core), len(associative), len(episodic), len(proactive),
)
return {
"core_memory": core,
"associative_memory": associative,
"episodic_memory": episodic,
"proactive_hints": proactive,
}
async def store_episode(
self, user_id: str, session_id: str, message: str, response: str,
trace_id: str | None = None,
) -> None:
fernet = await self._get_fernet(user_id)
if fernet is None:
return
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
encrypted = _encrypt(fernet, summary)
row = MemoryEpisodic(
id=str(uuid.uuid4()),
user_id=user_id,
summary_encrypted=encrypted,
session_id=session_id,
)
self._db.add(row)
try:
await self._db.commit()
except Exception as exc:
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
fernet = await self._get_fernet(user_id)
if fernet is None:
return
encrypted = _encrypt(fernet, value)
result = await self._db.execute(
select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == key)
)
existing = result.scalar_one_or_none()
if existing is not None:
existing.value_encrypted = encrypted
else:
self._db.add(MemoryCore(
id=str(uuid.uuid4()), user_id=user_id, key=key, value_encrypted=encrypted,
))
try:
await self._db.commit()
except Exception as exc:
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
await self._db.rollback()
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryCore).where(MemoryCore.user_id == user_id).order_by(MemoryCore.key.asc())
)
out: list[dict[str, str]] = []
for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.value_encrypted)
if plaintext is not None:
out.append({"label": row.key, "value": plaintext})
return out
async def get_core_block(self, user_id: str, label: str) -> str | None:
fernet = await self._get_fernet(user_id)
if fernet is None:
return None
result = await self._db.execute(
select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == label)
)
row = result.scalar_one_or_none()
if row is None:
return None
return _safe_decrypt(fernet, row.value_encrypted)
async def delete_core(self, user_id: str, label: str) -> bool:
result = await self._db.execute(
select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == label)
)
row = result.scalar_one_or_none()
if row is None:
return False
await self._db.delete(row)
try:
await self._db.commit()
return True
except Exception as exc:
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
await self._db.rollback()
return False
async def append_core(self, user_id: str, label: str, content: str) -> None:
current = await self.get_core_block(user_id, label)
if current is None:
await self.update_core(user_id, label, content)
return
await self.update_core(user_id, label, f"{current}\n{content}")
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
current = await self.get_core_block(user_id, label)
if current is None or old not in current:
return False
await self.update_core(user_id, label, current.replace(old, new, 1))
return True
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
fernet = await self._get_fernet(user_id)
if fernet is None:
return
encrypted = _encrypt(fernet, content)
row = MemoryAssociative(
id=str(uuid.uuid4()), user_id=user_id,
content_encrypted=encrypted, embedding=None,
entity_type=source, entity_id=None,
)
self._db.add(row)
try:
await self._db.commit()
except Exception as exc:
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryAssociative).where(MemoryAssociative.user_id == user_id)
.order_by(MemoryAssociative.updated_at.desc()).limit(100)
)
needle = query.strip().lower()
out: list[str] = []
for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.content_encrypted)
if plaintext is None:
continue
if not needle or needle in plaintext.lower():
out.append(plaintext)
if len(out) >= max(top_k, 1):
break
return out
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
.order_by(MemoryEpisodic.created_at.desc()).limit(100)
)
needle = query.strip().lower()
out: list[str] = []
for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
if plaintext is None:
continue
if not needle or needle in plaintext.lower():
out.append(plaintext)
if len(out) >= max(top_k, 1):
break
return out
# ── Private ───────────────────────────────────────────────────────
async def _get_fernet(self, user_id: str) -> Fernet | None:
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None or not user.encryption_key:
logger.warning("memory: no encryption_key for user=%s", user_id)
return None
return Fernet(user.encryption_key.encode())
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
result = await self._db.execute(
select(MemoryCore).where(MemoryCore.user_id == user_id)
)
out: dict[str, str] = {}
for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.value_encrypted)
if plaintext is not None:
out[row.key] = plaintext
return out
async def _load_associative(self, user_id: str, message: str, fernet: Fernet) -> list[str]:
result = await self._db.execute(
select(MemoryAssociative).where(MemoryAssociative.user_id == user_id)
.order_by(MemoryAssociative.updated_at.desc()).limit(_ASSOCIATIVE_TOP_K)
)
out: list[str] = []
for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.content_encrypted)
if plaintext is not None:
out.append(plaintext)
return out
async def _load_episodic(self, user_id: str, fernet: Fernet, session_id: str | None = None) -> list[str]:
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
if session_id:
query = query.where(MemoryEpisodic.session_id == session_id)
result = await self._db.execute(
query.order_by(MemoryEpisodic.created_at.desc()).limit(_EPISODIC_RECENT_N)
)
out: list[str] = []
for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
if plaintext is not None:
out.append(plaintext)
return out
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
result = await self._db.execute(
select(MemoryProactive).where(
MemoryProactive.user_id == user_id,
MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
).order_by(MemoryProactive.confidence.desc())
)
out: list[str] = []
for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.pattern_encrypted)
if plaintext is not None:
out.append(plaintext)
return out
def _encrypt(fernet: Fernet, plaintext: str) -> str:
return fernet.encrypt(plaintext.encode()).decode()
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
try:
return fernet.decrypt(ciphertext.encode()).decode()
except (InvalidToken, Exception) as exc:
logger.warning("memory: decrypt failed: %s", exc)
return None

View File

@@ -0,0 +1,50 @@
"""Output formatter for deep-agent stream events — Chat Service copy.
Converts (event_type, data) tuples into WebSocket frame Pydantic models.
"""
from __future__ import annotations
from collections.abc import AsyncGenerator
from typing import Any
from shared.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
class StreamFormatter:
"""Convert `(event_type, data)` stream events into websocket frame models."""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
async def format(
self,
event_stream: AsyncGenerator[tuple[str, Any], None],
) -> AsyncGenerator[WsFrame, None]:
started = False
async for event_type, data in event_stream:
if event_type == "floating_domain":
if isinstance(data, dict):
yield WsFloatingDomain(
request_id=self.request_id,
domain=data,
)
continue
if event_type != "token":
continue
if not started:
yield WsStreamStart(request_id=self.request_id)
started = True
text = str(data or "")
if text:
yield WsStreamText(request_id=self.request_id, chunk=text)
if not started:
yield WsStreamStart(request_id=self.request_id)
yield WsStreamEnd(request_id=self.request_id)

View File

@@ -0,0 +1,209 @@
"""Redis consumer — listens for chat requests and dispatches to deep_agent.
Subscribes to a Redis pattern channel chat:request:* so it receives
requests for ALL users. Each request is processed in a separate asyncio task.
"""
from __future__ import annotations
import asyncio
import json
import logging
from uuid import uuid4
from shared.db import async_session
from shared.redis import redis_client, ws_out_channel
from app.deep_agent import run_floating_stream, run_home_stream
from app.memory_middleware import MemoryMiddleware
from app.output_formatter import StreamFormatter
from app.ws_context import clear_current_user, set_current_user
from app import tracing
logger = logging.getLogger(__name__)
def start_consumer() -> asyncio.Task:
"""Start the Redis consumer as a background asyncio task."""
return asyncio.create_task(_consumer_loop())
async def _consumer_loop() -> None:
"""Subscribe to chat:request:* and dispatch incoming frames."""
pubsub = redis_client.pubsub()
await pubsub.psubscribe("chat:request:*")
logger.info("redis_consumer: subscribed to chat:request:*")
try:
while True:
message = await pubsub.get_message(
ignore_subscribe_messages=True, timeout=1.0
)
if message is not None and message["type"] == "pmessage":
frame = json.loads(message["data"])
asyncio.create_task(_dispatch(frame))
else:
await asyncio.sleep(0.01)
except asyncio.CancelledError:
logger.info("redis_consumer: shutting down")
finally:
await pubsub.punsubscribe()
await pubsub.aclose()
async def _dispatch(frame: dict) -> None:
"""Route a chat request frame to the appropriate handler."""
frame_type = frame.get("type")
user_id = frame.get("user_id")
if not user_id:
logger.warning("redis_consumer: frame missing user_id: %s", frame.get("type"))
return
if frame_type == "home_request":
await _handle_home_request(user_id, frame)
elif frame_type == "floating_request":
await _handle_floating_request(user_id, frame)
else:
logger.debug("redis_consumer: unknown frame type %r", frame_type)
async def _publish_frame(user_id: str, frame_data: str) -> None:
"""Publish a frame to ws:out:{user_id} for the WS Gateway to forward."""
channel = ws_out_channel(user_id)
await redis_client.publish(channel, frame_data)
async def _handle_home_request(user_id: str, frame: dict) -> None:
"""Process a home_request — enrich with memory, run deep_agent, stream results."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
logger.info(
"redis_consumer: home_request user=%s req=%s msg=%s",
user_id, request_id, message[:200],
)
response_chunks: list[str] = []
with tracing.trace_span(
name="home_request",
user_id=user_id,
session_id=session_id,
trace_id=request_id,
input=message,
metadata={"message_preview": message[:200]},
tags=["home"],
) as span:
langfuse_handler = tracing.get_langfuse_callback()
# Enrich with memory context
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id, message,
trace_id=request_id, session_id=session_id,
)
context: dict = {
"conversation_history": frame.get("conversation_history", []),
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
set_current_user(user_id)
try:
event_stream = run_home_stream(user_id, message, context, langfuse_handler=langfuse_handler)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await _publish_frame(user_id, ws_frame.model_dump_json())
if hasattr(ws_frame, "chunk"):
response_chunks.append(ws_frame.chunk)
except Exception as exc:
logger.error("redis_consumer: home_request failed user=%s req=%s: %s", user_id, request_id, exc)
finally:
clear_current_user()
# Link prompt and attach output preview
tracing.link_prompt_to_trace(span, "home_system")
response_text = "".join(response_chunks)
span.update(output=response_text[:500] if response_text else None)
tracing.flush()
# Store episode
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks),
trace_id=request_id,
)
async def _handle_floating_request(user_id: str, frame: dict) -> None:
"""Process a floating_request — enrich with memory, run deep_agent, stream results."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
scope: dict = frame.get("scope", {})
logger.info(
"redis_consumer: floating_request user=%s req=%s scope=%s msg=%s",
user_id, request_id, json.dumps(scope)[:200], message[:200],
)
response_chunks: list[str] = []
with tracing.trace_span(
name="floating_request",
user_id=user_id,
session_id=session_id,
trace_id=request_id,
input=message,
metadata={"message_preview": message[:200], "scope": scope},
tags=["floating"],
) as span:
langfuse_handler = tracing.get_langfuse_callback()
# Enrich with memory context
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id, message,
trace_id=request_id, session_id=session_id,
)
context: dict = {
"scope": scope,
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
set_current_user(user_id)
try:
event_stream = run_floating_stream(user_id, message, context, langfuse_handler=langfuse_handler)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await _publish_frame(user_id, ws_frame.model_dump_json())
if hasattr(ws_frame, "chunk"):
response_chunks.append(ws_frame.chunk)
except Exception as exc:
logger.error("redis_consumer: floating_request failed user=%s req=%s: %s", user_id, request_id, exc)
finally:
clear_current_user()
# Link prompt and attach output preview
tracing.link_prompt_to_trace(span, "floating_system")
response_text = "".join(response_chunks)
span.update(output=response_text[:500] if response_text else None)
tracing.flush()
# Store episode
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks),
trace_id=request_id,
)

View File

@@ -0,0 +1,37 @@
"""Chat REST route — POST /chat fallback when WS is unavailable."""
from __future__ import annotations
from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse
from shared.schemas import ChatRequest
from app.deep_agent import run_home
from app.ws_context import clear_current_user, set_current_user
router = APIRouter(prefix="/chat", tags=["chat"])
@router.post("")
async def chat(body: ChatRequest, request: Request) -> JSONResponse:
"""REST fallback for home chat.
In the microservices setup, Traefik ForwardAuth has already validated
the JWT and injected X-User-Id / X-User-Email / X-User-Tier headers.
"""
user_id = request.headers.get("X-User-Id", "")
if not user_id:
return JSONResponse(status_code=401, content={"detail": "Missing X-User-Id header"})
set_current_user(user_id)
try:
response = await run_home(
user_id=user_id,
message=body.message,
context=body.context.model_dump(),
)
finally:
clear_current_user()
return JSONResponse(content={"response": response})

View File

@@ -0,0 +1,264 @@
"""Langfuse tracing & prompt management for the Chat Service (v4 SDK).
Provides:
- ``init_langfuse()`` — initialise the singleton client at startup
- ``trace_span()`` — context manager that creates a trace + span
- ``get_langfuse_callback()`` — LangChain callback handler (auto-inherits trace)
- ``get_prompt()`` — fetch a managed prompt from Langfuse by name
- ``flush()`` / ``shutdown()`` — lifecycle management
All functions gracefully degrade to no-ops when Langfuse is not configured,
so the service works identically with or without observability keys.
Requires ``langfuse >= 3.0.0`` (v4 / "Fast Preview" SDK).
"""
from __future__ import annotations
import logging
from contextlib import contextmanager
from typing import Any
from shared.config import settings
logger = logging.getLogger(__name__)
# ── State ────────────────────────────────────────────────────────────────
_initialised: bool = False
_disabled: bool = False
def _is_configured() -> bool:
return bool(settings.LANGFUSE_SECRET_KEY and settings.LANGFUSE_PUBLIC_KEY)
def init_langfuse() -> None:
"""Initialise the Langfuse singleton. Call once at startup."""
global _initialised, _disabled
if _initialised or _disabled:
return
if not _is_configured():
_disabled = True
logger.info("tracing: Langfuse keys not set — tracing disabled")
return
try:
from langfuse import Langfuse
Langfuse(
secret_key=settings.LANGFUSE_SECRET_KEY,
public_key=settings.LANGFUSE_PUBLIC_KEY,
host=settings.LANGFUSE_HOST,
)
_initialised = True
logger.info("tracing: Langfuse client initialised (host=%s)", settings.LANGFUSE_HOST)
except Exception as exc:
_disabled = True
logger.warning("tracing: failed to initialise Langfuse: %s", exc)
def _get_client() -> Any | None:
"""Return the singleton Langfuse client, or *None* if disabled."""
if _disabled:
return None
if not _initialised:
init_langfuse()
if _disabled:
return None
try:
from langfuse import get_client
return get_client()
except Exception:
return None
# ── Null span (no-op when Langfuse is disabled) ─────────────────────────
class _NullSpan:
"""Drop-in replacement when Langfuse is disabled."""
def update(self, **_: Any) -> None: ...
def set_trace_io(self, **_: Any) -> None: ...
def score_trace(self, **_: Any) -> None: ...
# ── Trace context manager ───────────────────────────────────────────────
@contextmanager
def trace_span(
*,
name: str,
user_id: str,
session_id: str | None = None,
trace_id: str | None = None,
input: Any = None,
metadata: dict[str, Any] | None = None,
tags: list[str] | None = None,
):
"""Context manager that creates a Langfuse trace/span.
Yields the span object (or a ``_NullSpan`` if Langfuse is disabled).
A ``CallbackHandler`` created inside this block auto-inherits the trace
context, so there is no need to pass trace IDs manually.
"""
lf = _get_client()
if lf is None:
yield _NullSpan()
return
try:
from langfuse import Langfuse, propagate_attributes
trace_ctx: dict[str, str] = {}
if trace_id is not None:
trace_ctx["trace_id"] = Langfuse.create_trace_id(seed=trace_id)
with lf.start_as_current_observation(
as_type="span",
name=name,
input=input,
metadata=metadata or {},
**({"trace_context": trace_ctx} if trace_ctx else {}),
) as span:
with propagate_attributes(
user_id=user_id,
session_id=session_id,
tags=tags or [],
):
yield span
except Exception as exc:
logger.warning("tracing: trace_span(%s) failed: %s", name, exc)
yield _NullSpan()
# ── LangChain callback handler ──────────────────────────────────────────
def get_langfuse_callback() -> Any | None:
"""Return a LangChain ``CallbackHandler`` that auto-inherits the current trace.
Must be called inside a ``trace_span()`` block for proper linking.
Returns *None* when Langfuse is disabled.
"""
if _disabled and not _initialised:
return None
try:
from langfuse.langchain import CallbackHandler
return CallbackHandler()
except Exception as exc:
logger.warning("tracing: get_langfuse_callback failed: %s", exc)
return None
# ── Prompt management ────────────────────────────────────────────────────
def get_prompt(
name: str,
*,
version: int | None = None,
label: str | None = None,
fallback: str | None = None,
cache_ttl_seconds: int = 300,
) -> str | None:
"""Fetch a managed prompt from Langfuse by name.
Returns the compiled prompt string, or *fallback* if the prompt is not
found or Langfuse is disabled.
"""
lf = _get_client()
if lf is None:
return fallback
try:
kwargs: dict[str, Any] = {
"name": name,
"cache_ttl_seconds": cache_ttl_seconds,
}
if version is not None:
kwargs["version"] = version
if label is not None:
kwargs["label"] = label
prompt = lf.get_prompt(**kwargs)
return prompt.prompt
except Exception as exc:
logger.warning("tracing: get_prompt(%s) failed: %s", name, exc)
return fallback
def link_prompt_to_trace(
span: Any,
prompt_name: str,
*,
version: int | None = None,
label: str | None = None,
) -> None:
"""Attach prompt metadata to a span/trace."""
lf = _get_client()
if lf is None or isinstance(span, _NullSpan):
return
try:
kwargs: dict[str, Any] = {"name": prompt_name}
if version is not None:
kwargs["version"] = version
if label is not None:
kwargs["label"] = label
prompt = lf.get_prompt(**kwargs)
span.update(metadata={"prompt": {"name": prompt_name, "version": prompt.version}})
except Exception as exc:
logger.warning("tracing: link_prompt_to_trace(%s) failed: %s", prompt_name, exc)
# ── Scoring helper ───────────────────────────────────────────────────────
def score_trace(
trace_id: str,
name: str,
value: float,
*,
comment: str | None = None,
) -> None:
"""Post a score to a trace (e.g. user feedback, latency, quality)."""
lf = _get_client()
if lf is None:
return
try:
lf.create_score(trace_id=trace_id, name=name, value=value, comment=comment)
except Exception as exc:
logger.warning("tracing: score_trace failed: %s", exc)
# ── Shutdown ─────────────────────────────────────────────────────────────
def flush() -> None:
"""Flush pending Langfuse events."""
lf = _get_client()
if lf is not None:
try:
lf.flush()
except Exception as exc:
logger.warning("tracing: flush failed: %s", exc)
def shutdown() -> None:
"""Flush and close the Langfuse client."""
global _initialised, _disabled
lf = _get_client()
if lf is not None:
try:
lf.flush()
lf.shutdown()
except Exception as exc:
logger.warning("tracing: shutdown failed: %s", exc)
_initialised = False
_disabled = False

View File

@@ -0,0 +1,115 @@
"""WebSocket context for Chat Service — Redis-based tool call round-trip.
Replaces the monolith's ws_context.py. Instead of calling Electron directly
via WebSocket, this publishes tool_call frames to Redis (ws:out:{user_id})
and awaits the result via BRPOP on tool:result:{call_id}.
"""
from __future__ import annotations
import json
import logging
from contextvars import ContextVar
from typing import Any
from uuid import uuid4
from shared.redis import redis_client, tool_result_key, ws_out_channel
logger = logging.getLogger(__name__)
_TOOL_CALL_TIMEOUT = 30 # seconds — BRPOP timeout
# Per-request user_id context var (set before agent runs)
_current_user_id: ContextVar[str | None] = ContextVar("_current_user_id", default=None)
# Optional collector for debug
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
"_tool_result_collector", default=None
)
def set_current_user(user_id: str) -> None:
_current_user_id.set(user_id)
def clear_current_user() -> None:
_current_user_id.set(None)
def set_tool_result_collector(lst: list[dict]) -> None:
_tool_result_collector.set(lst)
def clear_tool_result_collector() -> None:
_tool_result_collector.set(None)
async def execute_on_client(
action: str,
table: str | None = None,
data: dict[str, Any] | None = None,
filters: dict[str, Any] | None = None,
vector: list[float] | None = None,
limit: int | None = None,
) -> dict[str, Any]:
"""Send a tool_call to Electron via Redis and await the result.
1. Build tool_call payload
2. Publish to ws:out:{user_id} (WS Gateway forwards to Electron)
3. BRPOP on tool:result:{call_id} (WS Gateway pushes when Electron replies)
4. Return result dict
Raises RuntimeError if no user_id is set or if the call times out.
"""
user_id = _current_user_id.get()
if not user_id:
raise RuntimeError(
"execute_on_client() called without a user_id — "
"set_current_user() must be called first."
)
call_id = str(uuid4())
payload: dict[str, Any] = {
"type": "tool_call",
"id": call_id,
"action": action,
}
if table is not None:
payload["table"] = table
if data is not None:
payload["data"] = data
if filters is not None:
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
if vector is not None:
payload["vector"] = vector
if limit is not None:
payload["limit"] = limit
# Publish tool_call to WS Gateway → Electron
channel = ws_out_channel(user_id)
await redis_client.publish(channel, json.dumps(payload))
# Wait for Electron's tool_result
result_key = tool_result_key(call_id)
response = await redis_client.brpop(result_key, timeout=_TOOL_CALL_TIMEOUT)
if response is None:
raise RuntimeError(
f"Tool call {call_id} timed out after {_TOOL_CALL_TIMEOUT}s — "
f"device may be offline or unresponsive."
)
# response is (key, value) tuple
_, raw = response
result = json.loads(raw)
# Collect for debug if requested
collector = _tool_result_collector.get(None)
if collector is not None:
collector.append({
"action": action,
"table": table,
"data": result,
})
return result

View File

@@ -0,0 +1,17 @@
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
gunicorn>=22.0.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
sqlalchemy>=2.0.0
asyncpg>=0.30.0
redis>=5.0.0
cryptography>=42.0.0
python-dotenv>=1.0.0
langchain-core>=0.3.0
langchain-openai>=0.3.0
langchain-litellm>=0.3.0
litellm>=1.50.0
openai>=1.50.0
httpx>=0.27.0
langfuse>=3.0.0

View File

@@ -0,0 +1,36 @@
# ── builder ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS builder
WORKDIR /build
COPY services/ws-gateway/requirements.txt ./requirements.txt
RUN pip install --upgrade pip && \
pip install --no-cache-dir --prefix=/install -r requirements.txt
# ── runtime ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS runtime
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
WORKDIR /app
COPY --from=builder /install /usr/local
# Shared module
COPY shared/ shared/
# Service source
COPY services/ws-gateway/app/ app/
RUN chown -R appuser:appgroup /app
USER appuser
EXPOSE 8000
# Single worker — each instance handles many WS connections via asyncio
CMD ["gunicorn", "app.main:app", \
"-k", "uvicorn.workers.UvicornWorker", \
"--bind", "0.0.0.0:8000", \
"--workers", "1", \
"--timeout", "0"]

View File

@@ -0,0 +1,17 @@
# WS Gateway
Stateless WebSocket proxy. Accepts Electron connections, authenticates JWT,
routes frames to Chat/Batch services via Redis pub/sub.
## No business logic
This service does NOT know what tasks, notes, or agents are.
It only routes JSON frames between Electron and downstream services.
## Scaling
Sticky sessions on `user_id` (Traefik consistent hashing).
## Redis channels used
- Subscribe: `ws:out:{user_id}` (frames to send to client)
- Publish: `chat:request:{user_id}`, `batch:request:{user_id}`
- LPUSH: `tool:result:{call_id}` (from client tool_result frames)
- HSET/HDEL: `ws:devices:{user_id}` (device registry)

View File

View File

@@ -0,0 +1,173 @@
"""WebSocket handler — device connection lifecycle.
Accepts Electron WS connections, authenticates JWT, registers device in Redis,
and runs two concurrent loops:
1. Message loop: receive frames from Electron, route to Redis
2. Outbound loop: subscribe to Redis ws:out:{user_id}, forward to Electron
3. Heartbeat loop: ping every 30s
No business logic lives here — the handler is a JSON frame router.
"""
from __future__ import annotations
import asyncio
import json
import logging
from uuid import uuid4
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from jose import JWTError, jwt
from shared.config import settings
from shared.schemas import WsFrameType
from app.redis_bridge import (
publish_batch_request,
publish_chat_request,
push_tool_result,
register_device,
set_gateway_id,
subscribe_outbound,
unregister_device,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ws", tags=["ws-gateway"])
_HEARTBEAT_INTERVAL = 30 # seconds
# Set a unique gateway instance ID on module load
set_gateway_id(str(uuid4()))
@router.websocket("/device")
async def device_ws(websocket: WebSocket) -> None:
"""Persistent WebSocket endpoint for Electron device connections."""
# ── 1. Authenticate via ?token= query parameter ──────────────────
token = websocket.query_params.get("token", "")
try:
payload = jwt.decode(
token,
settings.JWT_PUBLIC_KEY,
algorithms=["RS256"],
)
user_id: str | None = payload.get("sub")
email: str | None = payload.get("email")
if not user_id:
raise JWTError("missing sub")
except JWTError:
await websocket.close(code=1008)
return
await websocket.accept()
# ── 2. Await device_hello frame ──────────────────────────────────
try:
raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
except (asyncio.TimeoutError, WebSocketDisconnect):
await websocket.close(code=1008)
return
try:
hello = json.loads(raw)
if hello.get("type") != WsFrameType.device_hello:
raise ValueError("expected device_hello as first frame")
device_id: str = hello["device_id"]
agent_ids: list[str] = hello.get("agent_ids", [])
except (KeyError, ValueError, json.JSONDecodeError) as exc:
logger.warning("handler: invalid device_hello user=%s: %s", user_id, exc)
await websocket.close(code=1008)
return
# ── 3. Register device in Redis ──────────────────────────────────
await register_device(user_id, device_id)
logger.info("handler: connected user=%s device=%s agents=%s", user_id, device_id, agent_ids)
# Notify downstream services that device is online (for agent trigger)
await publish_batch_request(user_id, {
"type": "device_online",
"user_id": user_id,
"device_id": device_id,
"agent_ids": agent_ids,
})
# ── 4. Subscribe to outbound Redis channel ───────────────────────
pubsub = await subscribe_outbound(user_id)
# ── 5. Run concurrent loops ──────────────────────────────────────
try:
await asyncio.gather(
_inbound_loop(websocket, user_id),
_outbound_loop(websocket, pubsub),
_heartbeat_loop(websocket),
)
except WebSocketDisconnect:
pass
except Exception as exc:
logger.warning("handler: unhandled exception user=%s: %s", user_id, exc)
finally:
await pubsub.unsubscribe()
await pubsub.aclose()
await unregister_device(user_id)
logger.info("handler: disconnected user=%s device=%s", user_id, device_id)
# ── Inbound: Electron → Redis ────────────────────────────────────────
async def _inbound_loop(websocket: WebSocket, user_id: str) -> None:
"""Receive frames from Electron and route to the appropriate Redis channel."""
async for raw in websocket.iter_text():
try:
frame: dict = json.loads(raw)
except json.JSONDecodeError:
logger.warning("handler: invalid JSON from user=%s", user_id)
continue
frame_type = frame.get("type")
# Inject user_id so downstream services know who sent it
frame["user_id"] = user_id
if frame_type == WsFrameType.tool_result:
call_id = frame.get("id")
if call_id:
await push_tool_result(call_id, frame)
else:
logger.warning("handler: tool_result missing id user=%s", user_id)
elif frame_type in (WsFrameType.home_request, WsFrameType.floating_request):
await publish_chat_request(user_id, frame)
elif frame_type in (WsFrameType.journey_start, WsFrameType.journey_message):
await publish_batch_request(user_id, frame)
elif frame_type == "pong":
pass # heartbeat ack
else:
logger.debug("handler: unknown frame type %r user=%s", frame_type, user_id)
# ── Outbound: Redis → Electron ───────────────────────────────────────
async def _outbound_loop(websocket: WebSocket, pubsub) -> None:
"""Subscribe to Redis ws:out:{user_id} and forward frames to Electron."""
while True:
message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
if message is not None and message["type"] == "message":
await websocket.send_text(message["data"])
else:
# Brief sleep to avoid busy-wait when no messages
await asyncio.sleep(0.01)
# ── Heartbeat ────────────────────────────────────────────────────────
async def _heartbeat_loop(websocket: WebSocket) -> None:
"""Send ping frames every 30s to keep the connection alive."""
while True:
await asyncio.sleep(_HEARTBEAT_INTERVAL)
await websocket.send_text(json.dumps({"type": "ping"}))

View File

@@ -0,0 +1,56 @@
"""WS Gateway — stateless WebSocket proxy.
Accepts Electron device connections, authenticates JWT (RS256 public key),
and routes frames between Electron and downstream services via Redis pub/sub.
This service has NO business logic — it only routes JSON frames.
"""
import sys
from contextlib import asynccontextmanager
import logging
from pathlib import Path
# Ensure the repo root is on sys.path so "shared" is importable in local dev.
_repo_root = str(Path(__file__).resolve().parents[3])
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from fastapi import FastAPI
from shared.config import settings
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
@asynccontextmanager
async def lifespan(app: FastAPI):
yield
from shared.redis import redis_client
await redis_client.aclose()
def create_app() -> FastAPI:
app = FastAPI(
title="Adiuva WS Gateway",
version="0.1.0",
docs_url="/docs" if settings.ENV == "dev" else None,
redoc_url=None,
lifespan=lifespan,
)
from app.handler import router
app.include_router(router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"])
async def health() -> dict:
return {"status": "ok", "service": "ws-gateway", "version": app.version}
return app
app = create_app()

View File

@@ -0,0 +1,104 @@
"""Redis bridge — device registry + pub/sub routing.
All inter-service communication passes through Redis:
- Device registry: HSET/HDEL ws:devices:{user_id}
- Outbound frames: Subscribe ws:out:{user_id}
- Chat requests: Publish chat:request:{user_id}
- Batch requests: Publish batch:request:{user_id}
- Tool results: LPUSH tool:result:{call_id}
"""
from __future__ import annotations
import json
import logging
from shared.redis import (
batch_request_channel,
chat_request_channel,
device_key,
redis_client,
tool_result_key,
ws_out_channel,
)
logger = logging.getLogger(__name__)
# Instance ID for this gateway replica (set on startup)
_GATEWAY_ID: str = ""
def set_gateway_id(gid: str) -> None:
global _GATEWAY_ID
_GATEWAY_ID = gid
def get_gateway_id() -> str:
return _GATEWAY_ID
# ── Device Registry ──────────────────────────────────────────────────
async def register_device(user_id: str, device_id: str) -> None:
"""Register a connected device in Redis."""
key = device_key(user_id)
await redis_client.hset(key, mapping={
"device_id": device_id,
"gateway_id": _GATEWAY_ID,
})
logger.info("redis_bridge: registered user=%s device=%s gateway=%s", user_id, device_id, _GATEWAY_ID)
async def unregister_device(user_id: str) -> None:
"""Remove device registration from Redis."""
key = device_key(user_id)
await redis_client.delete(key)
logger.info("redis_bridge: unregistered user=%s", user_id)
async def is_device_online(user_id: str) -> bool:
"""Check if a device is registered."""
key = device_key(user_id)
return await redis_client.exists(key) > 0
# ── Frame Routing ────────────────────────────────────────────────────
async def publish_chat_request(user_id: str, frame: dict) -> None:
"""Forward a chat request frame to the Chat Service via Redis."""
channel = chat_request_channel(user_id)
await redis_client.publish(channel, json.dumps(frame))
logger.debug("redis_bridge: published chat_request user=%s", user_id)
async def publish_batch_request(user_id: str, frame: dict) -> None:
"""Forward a batch request frame to the Batch Agent Service via Redis."""
channel = batch_request_channel(user_id)
await redis_client.publish(channel, json.dumps(frame))
logger.debug("redis_bridge: published batch_request user=%s", user_id)
async def push_tool_result(call_id: str, result: dict) -> None:
"""Push a tool_result to the Redis list for the waiting service.
Chat/Batch services do BRPOP on this key with a 30s timeout.
"""
key = tool_result_key(call_id)
await redis_client.lpush(key, json.dumps(result))
# Auto-expire after 60s to prevent stale keys
await redis_client.expire(key, 60)
logger.debug("redis_bridge: pushed tool_result call_id=%s", call_id)
async def subscribe_outbound(user_id: str):
"""Return an async pubsub subscription for frames to send to Electron.
Chat/Batch services publish to ws:out:{user_id} and this gateway
forwards them to the connected WebSocket.
"""
channel = ws_out_channel(user_id)
pubsub = redis_client.pubsub()
await pubsub.subscribe(channel)
return pubsub

View File

@@ -0,0 +1,8 @@
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
gunicorn>=22.0.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
python-jose[cryptography]>=3.3.0
redis>=5.0.0
websockets>=14.0

5
shared/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
"""Shared module — imported by all microservices.
Contains DB engine/session, ORM models, Pydantic schemas, config,
and Redis utilities. Changes here affect ALL services.
"""

98
shared/config.py Normal file
View File

@@ -0,0 +1,98 @@
"""Shared configuration — Pydantic Settings loaded from environment.
All services import ``settings`` from here. Each service only uses a subset
of the vars, but keeping one Settings class avoids fragmentation.
"""
from pathlib import Path
from typing import Literal
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
# Locate the repo root (adiuva-api/) so we can load its .env as a fallback.
# Works whether cwd is adiuva-api/ (monolith) or adiuva-api/services/xyz/ (microservice).
_this_dir = Path(__file__).resolve().parent # shared/
_repo_root = _this_dir.parent # adiuva-api/
_root_env = _repo_root / ".env"
class Settings(BaseSettings):
# ── Database ─────────────────────────────────────────────────────
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva"
# ── JWT ────────────────────────────────────────────────────────
# RS256 public key (PEM). Used by any service that needs to verify
# JWTs locally (optional — Traefik ForwardAuth handles this in prod).
# The private key lives ONLY in the Auth Service config.
JWT_PUBLIC_KEY: str = ""
@field_validator("JWT_PUBLIC_KEY", mode="before")
@classmethod
def _expand_pem_newlines(cls, v: str) -> str:
if isinstance(v, str) and r"\n" in v:
return v.replace(r"\n", "\n")
return v
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 30
# ── Redis ────────────────────────────────────────────────────────
REDIS_URL: str = "redis://localhost:6379/0"
# ── Stripe ───────────────────────────────────────────────────────
STRIPE_SECRET_KEY: str = ""
STRIPE_WEBHOOK_SECRET: str = ""
# ── S3 ───────────────────────────────────────────────────────────
S3_BUCKET: str = ""
S3_REGION: str = "us-east-1"
S3_ENDPOINT_URL: str = ""
AWS_ACCESS_KEY_ID: str = ""
AWS_SECRET_ACCESS_KEY: str = ""
# ── Vector stores ────────────────────────────────────────────────
PINECONE_API_KEY: str = ""
PINECONE_INDEX: str = "adiuva"
QDRANT_URL: str = ""
QDRANT_API_KEY: str = ""
# ── LLM providers ────────────────────────────────────────────────
OPENAI_API_KEY: str = ""
ANTHROPIC_API_KEY: str = ""
GOOGLE_API_KEY: str = ""
CEREBRAS_API_KEY: str = ""
LLM_MODEL: str = "gpt-4o"
LLM_EMBED_MODEL: str = "text-embedding-3-small"
GITHUB_COPILOT_TOKEN_DIR: str = ""
# ── OAuth (integrations) ─────────────────────────────────────────
GMAIL_CLIENT_ID: str = ""
GMAIL_CLIENT_SECRET: str = ""
MS_CLIENT_ID: str = ""
MS_CLIENT_SECRET: str = ""
MS_TENANT_ID: str = "common"
OAUTH_ENCRYPTION_KEY: str = ""
# ── Langfuse (observability) ─────────────────────────────────────
LANGFUSE_SECRET_KEY: str = ""
LANGFUSE_PUBLIC_KEY: str = ""
LANGFUSE_HOST: str = "https://cloud.langfuse.com"
# ── CORS ─────────────────────────────────────────────────────────
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
# ── Environment ──────────────────────────────────────────────────
ENV: Literal["dev", "prod"] = "dev"
model_config = SettingsConfigDict(
# Local .env (cwd) takes priority; root .env is fallback.
env_file=(".env", str(_root_env)),
env_file_encoding="utf-8",
extra="ignore",
)
settings = Settings()

32
shared/db.py Normal file
View File

@@ -0,0 +1,32 @@
"""Database engine, session factory, and declarative base.
All services use the async SQLAlchemy API via ``get_session()``.
Alembic migrations use the synchronous psycopg2 URL (see alembic/env.py).
"""
from __future__ import annotations
from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from shared.config import settings
engine = create_async_engine(
settings.DATABASE_URL,
pool_pre_ping=True,
echo=False,
)
async_session = async_sessionmaker(engine, expire_on_commit=False)
class Base(DeclarativeBase):
"""Shared declarative base for all ORM models."""
async def get_session() -> AsyncGenerator[AsyncSession, None]:
"""FastAPI dependency that yields an async DB session per request."""
async with async_session() as session:
yield session

455
shared/models.py Normal file
View File

@@ -0,0 +1,455 @@
"""SQLAlchemy ORM models for all persistent tables.
Centralized here so that Alembic migrations and all services share
the same model definitions. Each service only queries the tables it owns.
Ownership:
Auth Service → users, refresh_tokens, subscriptions
Chat Service → memory_core, memory_associative, memory_episodic, memory_proactive
Batch Agent → local_agent_configs, cloud_agent_configs, agent_run_logs
Billing Service → subscriptions (shared write with Auth)
(excluded MVP) → storage_records, backup_metadata, plugins, plugin_*, revenue_events
"""
from __future__ import annotations
import uuid
from datetime import datetime, timezone
from sqlalchemy import (
BigInteger,
Boolean,
DateTime,
Enum,
Float,
ForeignKey,
Integer,
JSON,
String,
Text,
UniqueConstraint,
Uuid,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from shared.db import Base
# ── Helpers ──────────────────────────────────────────────────────────────
def _uuid() -> str:
return str(uuid.uuid4())
def _now() -> datetime:
return datetime.now(timezone.utc)
# ── Enum types ────────────────────────────────────────────────────────────
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
AgentTypeEnum = Enum("local", "cloud", name="agent_type")
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
# ── Auth models ───────────────────────────────────────────────────────────
class User(Base):
__tablename__ = "users"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
encryption_key: Mapped[str | None] = mapped_column(String(64), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
refresh_tokens: Mapped[list[RefreshToken]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)
subscription: Mapped[Subscription | None] = relationship(
back_populates="user", uselist=False, cascade="all, delete-orphan"
)
class RefreshToken(Base):
__tablename__ = "refresh_tokens"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
user: Mapped[User] = relationship(back_populates="refresh_tokens")
class Subscription(Base):
__tablename__ = "subscriptions"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, unique=True, index=True
)
stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
status: Mapped[str] = mapped_column(String(50), nullable=False, default="free")
current_period_end: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
user: Mapped[User] = relationship(back_populates="subscription")
# ── Storage models (excluded from MVP, kept for Alembic) ──────────────
class StorageRecord(Base):
__tablename__ = "storage_records"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
class BackupMetadata(Base):
__tablename__ = "backup_metadata"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
version: Mapped[int] = mapped_column(Integer, nullable=False)
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
# ── Plugin models (excluded from MVP, kept for Alembic) ───────────────
class Plugin(Base):
__tablename__ = "plugins"
id: Mapped[str] = mapped_column(String(255), primary_key=True)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
author_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]")
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
submitted_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
installations: Mapped[list[PluginInstallation]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
reviews: Mapped[list[PluginReview]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
revenue_events: Mapped[list[RevenueEvent]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
class PluginInstallation(Base):
__tablename__ = "plugin_installations"
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
installed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="installations")
class PluginReview(Base):
__tablename__ = "plugin_reviews"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
reviewer_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
reviewed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
class RevenueEvent(Base):
__tablename__ = "revenue_events"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
# ── Agent models ──────────────────────────────────────────────────────────
class LocalAgentConfig(Base):
__tablename__ = "local_agent_configs"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
device_id: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
run_logs: Mapped[list[AgentRunLog]] = relationship(
back_populates="local_agent",
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
foreign_keys="AgentRunLog.agent_id",
cascade="all, delete-orphan",
overlaps="run_logs,cloud_agent",
)
class CloudAgentConfig(Base):
__tablename__ = "cloud_agent_configs"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
provider: Mapped[str] = mapped_column(CloudProviderEnum, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
oauth_token_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True)
filter_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
run_logs: Mapped[list[AgentRunLog]] = relationship(
back_populates="cloud_agent",
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
foreign_keys="AgentRunLog.agent_id",
cascade="all, delete-orphan",
overlaps="run_logs,local_agent",
)
class AgentRunLog(Base):
__tablename__ = "agent_run_logs"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
agent_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
agent_type: Mapped[str] = mapped_column(AgentTypeEnum, nullable=False)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
started_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
local_agent: Mapped[LocalAgentConfig | None] = relationship(
back_populates="run_logs",
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
foreign_keys="AgentRunLog.agent_id",
overlaps="run_logs,cloud_agent",
)
cloud_agent: Mapped[CloudAgentConfig | None] = relationship(
back_populates="run_logs",
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
foreign_keys="AgentRunLog.agent_id",
overlaps="run_logs,local_agent",
)
# ── Memory models ─────────────────────────────────────────────────────────
class MemoryCore(Base):
"""Per-user persistent key/value preferences, encrypted at rest."""
__tablename__ = "memory_core"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
key: Mapped[str] = mapped_column(String(255), nullable=False)
value_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
class MemoryAssociative(Base):
"""Per-user semantic memory: encrypted content + pgvector embedding."""
__tablename__ = "memory_associative"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
class MemoryEpisodic(Base):
"""Per-user session summaries, encrypted at rest."""
__tablename__ = "memory_episodic"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
summary_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
session_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
class MemoryProactive(Base):
"""Per-user inferred behavioral patterns, encrypted at rest."""
__tablename__ = "memory_proactive"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
pattern_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.5)
source: Mapped[str] = mapped_column(String(50), nullable=False, default="inferred")
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)

53
shared/redis.py Normal file
View File

@@ -0,0 +1,53 @@
"""Redis client and pub/sub utilities for inter-service communication.
All services that need Redis import from here.
"""
from __future__ import annotations
import redis.asyncio as aioredis
from shared.config import settings
redis_client: aioredis.Redis = aioredis.from_url(
settings.REDIS_URL,
decode_responses=True,
)
# ── Channel naming conventions ────────────────────────────────────────
# See /memories/repo/microservices-architecture.md for full list.
def ws_out_channel(user_id: str) -> str:
"""Frames to forward to Electron via WS Gateway."""
return f"ws:out:{user_id}"
def chat_request_channel(user_id: str) -> str:
"""Chat requests (home + floating) from WS Gateway → Chat Service."""
return f"chat:request:{user_id}"
def batch_request_channel(user_id: str) -> str:
"""Batch requests (journey + triggers) from WS Gateway → Batch Agent."""
return f"batch:request:{user_id}"
def tool_result_key(call_id: str) -> str:
"""Tool result list: LPUSH by WS Gateway, BRPOP by Chat/Batch."""
return f"tool:result:{call_id}"
def device_key(user_id: str) -> str:
"""Device registry hash."""
return f"ws:devices:{user_id}"
def tier_changed_channel(user_id: str) -> str:
"""Billing tier change notifications."""
return f"tier:changed:{user_id}"
def journey_session_key(user_id: str) -> str:
"""Journey builder session (String + TTL 1800s)."""
return f"journey:{user_id}"

317
shared/schemas.py Normal file
View File

@@ -0,0 +1,317 @@
"""Pydantic schemas — API request/response contracts.
Shared across all services. Mirrors the TypeScript types from
the Electron app (src/shared/api-types.ts).
"""
from __future__ import annotations
from enum import Enum
from typing import Any, Literal
from pydantic import BaseModel, Field
# ── Billing ──────────────────────────────────────────────────────────
BillingTier = Literal["free", "pro", "power", "team"]
# ── Auth ─────────────────────────────────────────────────────────────
class AuthTokens(BaseModel):
access_token: str
refresh_token: str
expires_at: int
class UserProfile(BaseModel):
id: str
email: str
name: str | None = None
surname: str | None = None
tier: BillingTier
# ── Chat ─────────────────────────────────────────────────────────────
class ChatContext(BaseModel):
user_profile: dict[str, Any] = Field(default_factory=dict)
relevant_documents: list[str] = Field(default_factory=list)
recent_tasks: list[dict[str, Any]] = Field(default_factory=list)
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
class ChatRequest(BaseModel):
message: str
context: ChatContext = Field(default_factory=ChatContext)
class ChatResponse(BaseModel):
response: str
# ── Backup ───────────────────────────────────────────────────────────
class BackupMetadata(BaseModel):
version: int
timestamp: int
checksum: str
chunk_count: int
# ── Cloud Storage (E2E encrypted blobs) ──────────────────────────────
class StorageRecord(BaseModel):
id: str
user_id: str
table: str
blob: bytes
checksum: str
created_at: int
updated_at: int
class StorageRecordCreate(BaseModel):
table: str
blob: bytes
checksum: str
class StorageRecordUpdate(BaseModel):
blob: bytes
checksum: str
# ── Cloud Vector Store (E2E encrypted vectors) ────────────────────────
class VectorItem(BaseModel):
id: str
blob: bytes
checksum: str
class VectorUpsertRequest(BaseModel):
vectors: list[VectorItem]
class VectorSearchRequest(BaseModel):
query_blob: bytes
top_k: int = 10
class VectorSearchResult(BaseModel):
id: str
score: float
blob: bytes
class VectorSearchResponse(BaseModel):
results: list[VectorSearchResult]
# ── Plugin Marketplace ────────────────────────────────────────────────
class PluginManifest(BaseModel):
id: str
name: str
description: str
version: str
author: str
permissions: list[str]
category: str
price_cents: int = 0
class PluginListResponse(BaseModel):
plugins: list[PluginManifest]
total: int
page: int
class PluginInstallRequest(BaseModel):
plugin_id: str
# ── WebSocket Frame Protocol ──────────────────────────────────────────
class WsFrameType(str, Enum):
# ── v2 frame types (kept for backward compat) ──────────────────────
chat_request = "chat_request"
text_chunk = "text_chunk"
tool_call = "tool_call"
tool_result = "tool_result"
final = "final"
ping = "ping"
device_hello = "device_hello"
# ── v3 frame types ─────────────────────────────────────────────────
home_request = "home_request"
floating_request = "floating_request"
stream_start = "stream_start"
stream_text = "stream_text"
stream_end = "stream_end"
floating_domain = "floating_domain"
data_request = "data_request"
data_response = "data_response"
mutation = "mutation"
# ── v4 journey frame types ────────────────────────────────────────
journey_start = "journey_start"
journey_message = "journey_message"
journey_reply = "journey_reply"
class WsToolCall(BaseModel):
"""Server → Client: requests a CRUD/vector operation on the local DB."""
type: Literal[WsFrameType.tool_call] = WsFrameType.tool_call
id: str
action: str
table: str | None = None
data: dict[str, Any] | None = None
filters: dict[str, Any] | None = None
vector: list[float] | None = None
limit: int | None = None
class WsToolResult(BaseModel):
"""Client → Server: result of a CRUD/vector operation."""
type: Literal[WsFrameType.tool_result] = WsFrameType.tool_result
id: str
row: dict[str, Any] | None = None
rows: list[dict[str, Any]] | None = None
results: list[dict[str, Any]] | None = None
deleted: bool | None = None
ok: bool | None = None
error: str | None = None
class WsTextChunk(BaseModel):
"""Server → Client: incremental LLM response text."""
type: Literal[WsFrameType.text_chunk] = WsFrameType.text_chunk
text: str
class WsFinal(BaseModel):
"""Server → Client: signals end of response with the complete text."""
type: Literal[WsFrameType.final] = WsFrameType.final
response: str
# ── WebSocket Agent Frame Protocol ────────────────────────────────────
class WsDeviceHello(BaseModel):
"""Client → Server: device identification on WS connect."""
type: Literal[WsFrameType.device_hello] = WsFrameType.device_hello
device_id: str
agent_ids: list[str] = Field(default_factory=list)
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
class WsFloatingScope(BaseModel):
"""Scope for a floating request."""
type: Literal["task", "project", "note", "timeline"]
id: str | None = None
class WsHomeRequest(BaseModel):
"""Client → Server: Home chat message."""
type: Literal[WsFrameType.home_request] = WsFrameType.home_request
message: str
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
class WsFloatingRequest(BaseModel):
"""Client → Server: Floating chat message scoped to an entity."""
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
message: str
scope: WsFloatingScope
class WsStreamStart(BaseModel):
"""Server → Client: signals start of a streaming response."""
type: Literal[WsFrameType.stream_start] = WsFrameType.stream_start
request_id: str
class WsStreamText(BaseModel):
"""Server → Client: streamed text token."""
type: Literal[WsFrameType.stream_text] = WsFrameType.stream_text
request_id: str
chunk: str
class WsStreamEnd(BaseModel):
"""Server → Client: signals end of a streaming response."""
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
request_id: str
class WsDomain(BaseModel):
"""Structured floating domain payload for UI routing decisions."""
type: Literal["task", "timeline", "project", "node"]
id: str | None = None
section: Literal["task", "timeline", "note"] | None = None
class WsFloatingDomain(BaseModel):
"""Server → Client: domain determined for a floating request."""
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
request_id: str
domain: WsDomain
# ── Agent Catalog ─────────────────────────────────────────────────────
class AgentCatalogItem(BaseModel):
type: str
name: str
description: str
class AgentCreationCheckRequest(BaseModel):
active_agents: int = Field(ge=0, default=0)
class AgentCreationCheckResponse(BaseModel):
allowed: bool
tier: BillingTier
active_agents: int
limit: int
class AgentTriggerRequest(BaseModel):
directory: str = Field(min_length=1)
device_id: str = Field(default="")
agent_id: str | None = None
what_to_extract: list[str] = Field(min_length=1)
actions_by_type: dict[str, list[str]] | None = None
batch_interval: str = Field(min_length=1)
custom_agent_prompt: str = Field(min_length=1)
active_agents: int = Field(ge=0, default=0)
# ── Agent Run Log ─────────────────────────────────────────────────────
class AgentRunLogResponse(BaseModel):
id: str
agent_id: str
agent_type: Literal["local", "cloud"]
status: Literal["running", "success", "error", "partial"]
items_processed: int
items_created: int
errors: list[str]
started_at: int
completed_at: int | None

124
tests/test_e2e_flow.py Normal file
View File

@@ -0,0 +1,124 @@
"""End-to-end test: Auth → WS Gateway → Chat Service round-trip.
Usage (from repo root, with venv activated):
python test_e2e_flow.py
Requires: Auth (8001), WS Gateway (8002), Chat (8003) all running.
"""
import asyncio
import json
import uuid
import httpx
import websockets
AUTH_URL = "http://127.0.0.1:8001/api/v1/auth"
WS_URL = "ws://127.0.0.1:8002/api/v1/ws/device"
# ── 1. Authenticate ─────────────────────────────────────────────────
async def get_token() -> str:
async with httpx.AsyncClient() as client:
# Try login first, register if user doesn't exist
resp = await client.post(
f"{AUTH_URL}/login",
json={"email": "e2e@test.com", "password": "Test1234!"},
)
if resp.status_code == 200:
print("[1/4] Logged in as e2e@test.com")
return resp.json()["access_token"]
resp = await client.post(
f"{AUTH_URL}/register",
json={
"email": "e2e@test.com",
"password": "Test1234!",
"name": "E2E",
"surname": "Test",
},
)
resp.raise_for_status()
print("[1/4] Registered + logged in as e2e@test.com")
return resp.json()["access_token"]
# ── 2. WebSocket flow ───────────────────────────────────────────────
async def run_e2e():
token = await get_token()
uri = f"{WS_URL}?token={token}"
async with websockets.connect(uri) as ws:
# Send device_hello
await ws.send(json.dumps({
"type": "device_hello",
"device_id": str(uuid.uuid4()),
"agent_ids": ["task", "note", "project", "timeline"],
}))
print("[2/4] Device registered with WS Gateway")
# Send a home_request (simple greeting — unlikely to need tools)
await ws.send(json.dumps({
"type": "home_request",
"message": "Hello! How are you doing today?",
"context": {},
}))
print("[3/4] Sent home_request → waiting for Chat Service response...")
# Listen for response frames (text_chunk, tool_call, final)
full_response = []
try:
while True:
raw = await asyncio.wait_for(ws.recv(), timeout=60)
frame = json.loads(raw)
ftype = frame.get("type")
if ftype == "text_chunk":
chunk = frame.get("chunk", frame.get("text", ""))
full_response.append(chunk)
print(f" ← text_chunk: {chunk[:80]}")
elif ftype == "tool_call":
# Respond with a mock tool_result so the agent doesn't hang
call_id = frame.get("id")
action = frame.get("action")
table = frame.get("table", "")
print(f" ← tool_call: {action} {table} (id={call_id})")
mock_result = {"rows": [], "row": None}
await ws.send(json.dumps({
"type": "tool_result",
"id": call_id,
**mock_result,
}))
print(f" → tool_result (mock) for {call_id}")
elif ftype == "final":
text = frame.get("text", "")
if text:
full_response.append(text)
print(f" ← final")
break
elif ftype == "ping":
# Ignore heartbeats
continue
else:
print(f"{ftype}: {json.dumps(frame)[:120]}")
except asyncio.TimeoutError:
print(" ⚠ Timed out waiting for response (60s)")
print()
if full_response:
print(f"[4/4] Full response: {''.join(full_response)}")
else:
print("[4/4] No text response received (check Chat Service logs)")
if __name__ == "__main__":
asyncio.run(run_e2e())

143
traefik/dynamic/routers.yml Normal file
View File

@@ -0,0 +1,143 @@
# Dynamic routing configuration
http:
middlewares:
# ForwardAuth: validates JWT via Auth Service, injects identity headers
auth-forward:
forwardAuth:
address: "http://auth:8000/api/v1/auth/verify"
trustForwardHeader: true
authResponseHeaders:
- "X-User-Id"
- "X-User-Email"
- "X-User-Tier"
# Rate limiting (basic — per-client IP; upgrade to per-tier later)
rate-limit:
rateLimit:
average: 60
burst: 20
period: "1m"
# Strip /api/v1 prefix before forwarding to services
strip-api-prefix:
stripPrefix:
prefixes:
- "/api/v1"
routers:
# ── Auth (no ForwardAuth on public endpoints) ──────────────
auth-public:
rule: "PathPrefix(`/api/v1/auth/register`) || PathPrefix(`/api/v1/auth/login`) || PathPrefix(`/api/v1/auth/refresh`)"
entryPoints:
- websecure
middlewares:
- rate-limit
- strip-api-prefix
service: auth-svc
tls: {}
auth-protected:
rule: "PathPrefix(`/api/v1/auth`)"
entryPoints:
- websecure
middlewares:
- auth-forward
- rate-limit
- strip-api-prefix
service: auth-svc
tls: {}
# ── WebSocket Gateway (sticky sessions) ────────────────────
ws-gateway:
rule: "PathPrefix(`/api/v1/ws`)"
entryPoints:
- websecure
middlewares:
- rate-limit
service: ws-gateway-svc
tls: {}
# ── Chat Service ───────────────────────────────────────────
chat:
rule: "PathPrefix(`/api/v1/chat`)"
entryPoints:
- websecure
middlewares:
- auth-forward
- rate-limit
- strip-api-prefix
service: chat-svc
tls: {}
# ── Batch Agent Service ────────────────────────────────────
batch-agent:
rule: "PathPrefix(`/api/v1/agents`)"
entryPoints:
- websecure
middlewares:
- auth-forward
- rate-limit
- strip-api-prefix
service: batch-agent-svc
tls: {}
# ── Billing Service ────────────────────────────────────────
billing-webhook:
rule: "PathPrefix(`/api/v1/billing/webhook`)"
entryPoints:
- websecure
middlewares:
- rate-limit
- strip-api-prefix
service: billing-svc
tls: {}
priority: 10
billing:
rule: "PathPrefix(`/api/v1/billing`)"
entryPoints:
- websecure
middlewares:
- auth-forward
- rate-limit
- strip-api-prefix
service: billing-svc
tls: {}
# ── Health (no auth) ───────────────────────────────────────
health:
rule: "Path(`/api/v1/health`)"
entryPoints:
- websecure
service: auth-svc
tls: {}
services:
auth-svc:
loadBalancer:
servers:
- url: "http://auth:8000"
ws-gateway-svc:
loadBalancer:
sticky:
cookie:
name: "ws_affinity"
servers:
- url: "http://ws-gateway:8000"
chat-svc:
loadBalancer:
servers:
- url: "http://chat:8000"
batch-agent-svc:
loadBalancer:
servers:
- url: "http://batch-agent:8000"
billing-svc:
loadBalancer:
servers:
- url: "http://billing:8000"

39
traefik/traefik.yml Normal file
View File

@@ -0,0 +1,39 @@
# Traefik static configuration for microservices gateway
api:
dashboard: true
insecure: true # Dashboard on :8080 (internal only in prod)
entryPoints:
web:
address: ":80"
http:
redirections:
entryPoint:
to: websecure
scheme: https
websecure:
address: ":443"
http:
tls:
certResolver: cloudflare
providers:
docker:
exposedByDefault: false
file:
directory: /etc/traefik/dynamic
watch: true
# Automatic TLS via Let's Encrypt + Cloudflare DNS-01 challenge
certificatesResolvers:
cloudflare:
acme:
email: "${ACME_EMAIL}"
storage: /etc/traefik/acme/acme.json
dnsChallenge:
provider: cloudflare
delayBeforeCheck: "10"
resolvers:
- "1.1.1.1:53"
- "8.8.8.8:53"