226 lines
7.9 KiB
Python
226 lines
7.9 KiB
Python
"""Memory management routes — view/edit/delete user memory tiers.
|
|
|
|
All routes require authentication. Data is always user-scoped.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
from typing import Annotated
|
|
|
|
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
|
from pydantic import BaseModel, Field
|
|
from sqlalchemy import delete, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.api.deps import get_current_user
|
|
from app.core.memory_middleware import MemoryMiddleware
|
|
from app.db import get_session
|
|
from app.models import (
|
|
ExtractionQueue,
|
|
MemoryAssociative,
|
|
MemoryCore,
|
|
MemoryEpisodic,
|
|
MemoryProactive,
|
|
MemoryRelation,
|
|
)
|
|
from app.schemas import UserProfile
|
|
|
|
logger = logging.getLogger(__name__)

# Every endpoint in this module is mounted under /memory and grouped under the
# "memory" tag in the generated OpenAPI docs.
router = APIRouter(tags=["memory"], prefix="/memory")
|
|
|
|
# Closed vocabulary of predicates a client may set when editing a relation.
# A frozenset so this allowlist cannot be mutated at runtime by accident;
# membership tests and sorted() behave exactly as with a plain set.
_ALLOWED_PREDICATES: frozenset[str] = frozenset(
    {
        "works_at",
        "reports_to",
        "stakeholder_of",
        "last_contacted_on",
        "owes_followup",
        "manages",
        "collaborates_with",
        "owns",
        "member_of",
        "custom",
    }
)
|
|
|
|
|
|
# ── Request / response schemas ───────────────────────────────────────────────
|
|
|
|
class RelationOut(BaseModel):
    """API representation of a single relational-memory row.

    Instances are built from ``MemoryRelation`` ORM rows by ``_relation_to_out``.
    """

    id: str
    subject_label: str
    subject_type: str
    predicate: str
    object_label: str
    object_type: str
    # NOTE(review): presumably within 0.0-1.0; RelationPatch enforces that range
    # on edits, but rows written elsewhere are not checked here.
    confidence: float
    # Epoch milliseconds; None when the row's last_confirmed_at is NULL.
    last_confirmed_at: int | None = None  # epoch ms
|
|
|
|
|
|
class RelationPatch(BaseModel):
    """Partial-update body for ``PATCH /memory/relational/{relation_id}``.

    Every field is optional; a field that is omitted or ``None`` leaves the
    stored value unchanged. ``predicate`` is checked against the allowed
    predicate set by the route handler, not by this model.
    """

    # Labels, when supplied, must be non-empty (consistent with CoreAddBody's
    # min_length=1): an empty label would leave an unreadable relation row.
    subject_label: str | None = Field(None, min_length=1)
    object_label: str | None = Field(None, min_length=1)
    predicate: str | None = None
    # Confidence is bounded to the closed interval [0.0, 1.0].
    confidence: float | None = Field(None, ge=0.0, le=1.0)
|
|
|
|
|
|
class CoreAddBody(BaseModel):
    """Request body for ``POST /memory/core``: one core memory key/value pair to add or overwrite."""

    # Required; between 1 and 255 characters.
    key: str = Field(..., min_length=1, max_length=255)
    # Required and non-empty; no maximum length is enforced at this layer.
    value: str = Field(..., min_length=1)
|
|
|
|
|
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
|
|
|
def _to_epoch_ms(moment: datetime | None) -> int | None:
    """Convert a datetime to integer epoch milliseconds; ``None`` passes through.

    Naive datetimes are treated as UTC. This module writes timestamps with
    ``datetime.now(timezone.utc)``, but a database column without timezone
    support may hand them back naive, and ``datetime.timestamp()`` would then
    interpret the value in the server's *local* timezone.
    """
    if moment is None:
        return None
    if moment.utcoffset() is None:
        # Naive value: pin it to UTC before converting (see docstring).
        moment = moment.replace(tzinfo=timezone.utc)
    return int(moment.timestamp() * 1000)


def _relation_to_out(row: MemoryRelation) -> RelationOut:
    """Convert a ``MemoryRelation`` ORM row into its ``RelationOut`` API shape.

    ``last_confirmed_at`` goes from a datetime to epoch milliseconds; a NULL
    timestamp stays ``None``.
    """
    return RelationOut(
        id=row.id,
        subject_label=row.subject_label,
        subject_type=row.subject_type,
        predicate=row.predicate,
        object_label=row.object_label,
        object_type=row.object_type,
        confidence=row.confidence,
        last_confirmed_at=_to_epoch_ms(row.last_confirmed_at),
    )
|
|
|
|
|
|
# ── Routes ───────────────────────────────────────────────────────────────────
|
|
|
|
@router.get("/core", response_model=dict[str, str])
async def get_core_memory(
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> dict[str, str]:
    """List the caller's core memory as a plaintext ``label -> value`` mapping.

    Each block from ``MemoryMiddleware.list_core_blocks`` contributes one
    entry; if two blocks share a label, the later one wins.
    """
    core_blocks = await MemoryMiddleware(db).list_core_blocks(current_user.id)
    label_to_value: dict[str, str] = {}
    for entry in core_blocks:
        label_to_value[entry["label"]] = entry["value"]
    return label_to_value
|
|
|
|
|
|
@router.delete("/core/{key:path}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_core_key(
    key: str,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> None:
    """Delete a single core memory key (GDPR Art. 17).

    The ``:path`` converter lets ``key`` contain ``/``. ``CoreAddBody`` accepts
    such keys on create, but a plain ``{key}`` segment never matches a slash,
    so those keys would otherwise be impossible to delete individually. Keys
    without a slash match exactly as before.

    Raises:
        HTTPException: 404 when ``MemoryMiddleware.delete_core`` reports that
            nothing was deleted.
    """
    mw = MemoryMiddleware(db)
    deleted = await mw.delete_core(current_user.id, key)
    if not deleted:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Key not found")
|
|
|
|
|
|
@router.post("/core", status_code=status.HTTP_201_CREATED, response_model=dict[str, str])
async def add_core_key(
    body: CoreAddBody,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> dict[str, str]:
    """Create or replace one core memory entry and echo it back as ``{key: value}``."""
    key, value = body.key, body.value
    middleware = MemoryMiddleware(db)
    await middleware.update_core(current_user.id, key, value)
    return {key: value}
|
|
|
|
|
|
@router.get("/relational", response_model=list[RelationOut])
async def get_relational_memory(
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> list[RelationOut]:
    """Return the current user's relational memory rows, at most 200 of them.

    ``MemoryMiddleware.query_relations`` is called with ``limit=200``, so a
    user with more relations receives only a subset. Which rows are kept
    depends on that method's ordering, which is not visible here.
    """
    mw = MemoryMiddleware(db)
    rows = await mw.query_relations(current_user.id, limit=200)
    return [_relation_to_out(r) for r in rows]
|
|
|
|
|
|
@router.patch("/relational/{relation_id}", response_model=RelationOut)
async def patch_relation(
    relation_id: str,
    body: RelationPatch,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> RelationOut:
    """Apply a partial edit (labels, predicate, confidence) to one of the caller's relations.

    Body fields left as ``None`` are not touched. The row's
    ``last_confirmed_at`` is set to the current UTC time on every successful
    call.

    Raises:
        HTTPException: 422 for a predicate outside ``_ALLOWED_PREDICATES``;
            404 when no relation with this id belongs to the caller.
    """
    bad_predicate = body.predicate is not None and body.predicate not in _ALLOWED_PREDICATES
    if bad_predicate:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"predicate must be one of: {sorted(_ALLOWED_PREDICATES)}",
        )

    # Filter on user_id as well as id so a caller can only edit their own rows.
    owned_relation = select(MemoryRelation).where(
        MemoryRelation.id == relation_id,
        MemoryRelation.user_id == current_user.id,
    )
    row = (await db.execute(owned_relation)).scalar_one_or_none()
    if row is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")

    # Copy over only the fields the client actually supplied.
    for field_name in ("subject_label", "object_label", "predicate", "confidence"):
        new_value = getattr(body, field_name)
        if new_value is not None:
            setattr(row, field_name, new_value)
    row.last_confirmed_at = datetime.now(timezone.utc)

    await db.commit()
    await db.refresh(row)
    logger.info("memory: patch_relation user=%s relation=%s", current_user.id, relation_id)
    return _relation_to_out(row)
|
|
|
|
|
|
@router.delete("/relational/{relation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_relation(
    relation_id: str,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> None:
    """Permanently remove one of the caller's relation rows (GDPR Art. 17).

    Raises:
        HTTPException: 404 when no relation with this id belongs to the caller.
    """
    # Ownership is part of the lookup: another user's id behaves as "not found".
    lookup = select(MemoryRelation).where(
        MemoryRelation.id == relation_id,
        MemoryRelation.user_id == current_user.id,
    )
    relation = (await db.execute(lookup)).scalar_one_or_none()
    if relation is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")

    await db.delete(relation)
    await db.commit()
    logger.info("memory: delete_relation user=%s relation=%s", current_user.id, relation_id)
|
|
|
|
|
|
@router.post("/forget-all", status_code=status.HTTP_204_NO_CONTENT)
async def forget_all(
    x_confirm: Annotated[str | None, Header(alias="X-Confirm")] = None,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> None:
    """Erase every memory tier belonging to the caller (GDPR Art. 17).

    The request must carry the header ``X-Confirm: true`` (exact lower-case
    value); anything else is rejected with 400. Only memory data is removed;
    the user account itself is left in place.
    """
    confirmed = x_confirm == "true"
    if not confirmed:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing or invalid X-Confirm header. Send X-Confirm: true to confirm.",
        )

    user_id = current_user.id
    memory_tables = (
        MemoryCore,
        MemoryAssociative,
        MemoryEpisodic,
        MemoryProactive,
        MemoryRelation,
        ExtractionQueue,
    )
    # One DELETE per table, scoped to this user, followed by a single commit.
    for table in memory_tables:
        await db.execute(delete(table).where(table.user_id == user_id))
    await db.commit()

    logger.warning("memory: forget_all GDPR wipe user=%s", user_id)
|