Compare commits
7 Commits
7ccdad431f
...
0b5ef48463
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0b5ef48463 | ||
|
|
ca8721e1ac | ||
|
|
f658e5e6a3 | ||
|
|
341ee140e5 | ||
|
|
741b9b87fb | ||
|
|
2d8abb6311 | ||
|
|
e668e3fd20 |
15
.env.example
15
.env.example
@@ -53,6 +53,21 @@ LLM_MODEL_CLOUD_PROCESSOR=
|
|||||||
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
|
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
|
||||||
LLM_MODEL_SETUP_AGENT=
|
LLM_MODEL_SETUP_AGENT=
|
||||||
|
|
||||||
|
# Memory-extractor — Mem0-style extract/decide pipeline (Phase 2).
|
||||||
|
# Defaults to gpt-4o-mini when empty (fast + cheap, temperature=0).
|
||||||
|
LLM_MODEL_MEMORY_EXTRACTOR=
|
||||||
|
|
||||||
|
# Memory-miner — proactive pattern mining from episodic history (Phase 5, Power+ only).
|
||||||
|
# Defaults to gpt-4o-mini when empty.
|
||||||
|
LLM_MODEL_MEMORY_MINER=
|
||||||
|
|
||||||
|
# Memory-auditor — weekly contradiction scan + relation label canonicalization (Phase 7).
|
||||||
|
# Defaults to LLM_MODEL when empty (a reasoning-capable model is recommended).
|
||||||
|
LLM_MODEL_MEMORY_AUDITOR=
|
||||||
|
|
||||||
|
# Scheduler — set to false to disable memory cron jobs (automatically false in tests).
|
||||||
|
SCHEDULER_ENABLED=true
|
||||||
|
|
||||||
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
||||||
STRIPE_SECRET_KEY=
|
STRIPE_SECRET_KEY=
|
||||||
STRIPE_WEBHOOK_SECRET=
|
STRIPE_WEBHOOK_SECRET=
|
||||||
|
|||||||
54
alembic/versions/005_associative_pgvector.py
Normal file
54
alembic/versions/005_associative_pgvector.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""Phase 1 — confirm pgvector activation on memory_associative.
|
||||||
|
|
||||||
|
Migration 004 created the embedding column as vector(1536) and added the
|
||||||
|
IVFFlat index. This migration is the Phase-1 checkpoint:
|
||||||
|
1. Ensures the pgvector extension is enabled (idempotent).
|
||||||
|
2. Ensures the canonical Phase-1 IVFFlat index exists under the name
|
||||||
|
memory_associative_embedding_idx (creates it only if absent).
|
||||||
|
|
||||||
|
Revision ID: 005
|
||||||
|
Revises: 9a1f2d0b6c7e
|
||||||
|
Create Date: 2026-04-15
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "005"
|
||||||
|
down_revision: Union[str, None] = "e04100e88ace"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Ensure pgvector extension is enabled (also done in 004, idempotent).
|
||||||
|
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
|
||||||
|
|
||||||
|
# Ensure the canonical Phase-1 IVFFlat index exists.
|
||||||
|
# 004 may have created ix_memory_associative_embedding; this adds the
|
||||||
|
# Phase-1 name memory_associative_embedding_idx if it is missing.
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM pg_indexes
|
||||||
|
WHERE tablename = 'memory_associative'
|
||||||
|
AND indexname = 'memory_associative_embedding_idx'
|
||||||
|
) THEN
|
||||||
|
CREATE INDEX memory_associative_embedding_idx
|
||||||
|
ON memory_associative
|
||||||
|
USING ivfflat (embedding vector_cosine_ops)
|
||||||
|
WITH (lists = 100);
|
||||||
|
END IF;
|
||||||
|
END $$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute("DROP INDEX IF EXISTS memory_associative_embedding_idx;")
|
||||||
74
alembic/versions/006_memory_relations.py
Normal file
74
alembic/versions/006_memory_relations.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""Add memory_relations table (Phase 3 — relational tier).
|
||||||
|
|
||||||
|
Revision ID: 006
|
||||||
|
Revises: 1f5975a4f3f4
|
||||||
|
Create Date: 2026-04-16
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "006"
|
||||||
|
down_revision: Union[str, None] = "1f5975a4f3f4"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"memory_relations",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
postgresql.UUID(as_uuid=False),
|
||||||
|
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("subject_label", sa.String(128), nullable=False),
|
||||||
|
sa.Column("subject_type", sa.String(32), nullable=False),
|
||||||
|
sa.Column("predicate", sa.String(64), nullable=False),
|
||||||
|
sa.Column("object_label", sa.String(128), nullable=False),
|
||||||
|
sa.Column("object_type", sa.String(32), nullable=False),
|
||||||
|
sa.Column("confidence", sa.Float, nullable=False, server_default="0.7"),
|
||||||
|
sa.Column(
|
||||||
|
"source_episode_id",
|
||||||
|
postgresql.UUID(as_uuid=False),
|
||||||
|
sa.ForeignKey("memory_episodic.id", ondelete="SET NULL"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.Column("notes_encrypted", sa.LargeBinary, nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
sa.Column("last_confirmed_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"memory_relations_user_subject_idx",
|
||||||
|
"memory_relations",
|
||||||
|
["user_id", "subject_label"],
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"memory_relations_user_predicate_idx",
|
||||||
|
"memory_relations",
|
||||||
|
["user_id", "predicate"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("memory_relations_user_predicate_idx", "memory_relations")
|
||||||
|
op.drop_index("memory_relations_user_subject_idx", "memory_relations")
|
||||||
|
op.drop_table("memory_relations")
|
||||||
38
alembic/versions/1f5975a4f3f4_add_extraction_queue.py
Normal file
38
alembic/versions/1f5975a4f3f4_add_extraction_queue.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""add extraction_queue
|
||||||
|
|
||||||
|
Revision ID: 1f5975a4f3f4
|
||||||
|
Revises: 005
|
||||||
|
Create Date: 2026-04-16 17:26:25.790870
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '1f5975a4f3f4'
|
||||||
|
down_revision: Union[str, None] = '005'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
'extraction_queue',
|
||||||
|
sa.Column('id', sa.Uuid(as_uuid=False), nullable=False),
|
||||||
|
sa.Column('user_id', sa.Uuid(as_uuid=False), nullable=False),
|
||||||
|
sa.Column('episode_id', sa.Uuid(as_uuid=False), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||||
|
sa.PrimaryKeyConstraint('id'),
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_extraction_queue_user_id'), 'extraction_queue', ['user_id'], unique=False)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index(op.f('ix_extraction_queue_user_id'), table_name='extraction_queue')
|
||||||
|
op.drop_table('extraction_queue')
|
||||||
34
alembic/versions/e04100e88ace_avatar_url_varchar_to_text.py
Normal file
34
alembic/versions/e04100e88ace_avatar_url_varchar_to_text.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""avatar_url_varchar_to_text
|
||||||
|
|
||||||
|
Revision ID: e04100e88ace
|
||||||
|
Revises: c5d1e2f3a4b5
|
||||||
|
Create Date: 2026-04-13 09:13:06.733674
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'e04100e88ace'
|
||||||
|
down_revision: Union[str, None] = 'c5d1e2f3a4b5'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.alter_column('users', 'avatar_url',
|
||||||
|
existing_type=sa.VARCHAR(length=2048),
|
||||||
|
type_=sa.Text(),
|
||||||
|
existing_nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.alter_column('users', 'avatar_url',
|
||||||
|
existing_type=sa.Text(),
|
||||||
|
type_=sa.VARCHAR(length=2048),
|
||||||
|
existing_nullable=True)
|
||||||
@@ -65,10 +65,11 @@ async def get_current_user(
|
|||||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
tier: str = result.scalar_one_or_none() or default_tier
|
tier: str = result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
# Fetch name/surname/avatar_url/onboarding_completed_at from user row.
|
# Fetch name/surname/avatar_url/onboarding_completed_at/password_hash from user row.
|
||||||
user_result = await db.execute(
|
user_result = await db.execute(
|
||||||
select(
|
select(
|
||||||
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
||||||
|
User.password_hash,
|
||||||
).where(User.id == user_id)
|
).where(User.id == user_id)
|
||||||
)
|
)
|
||||||
user_row = user_result.one_or_none()
|
user_row = user_result.one_or_none()
|
||||||
@@ -95,6 +96,7 @@ async def get_current_user(
|
|||||||
name=user_row.name if user_row else None,
|
name=user_row.name if user_row else None,
|
||||||
surname=user_row.surname if user_row else None,
|
surname=user_row.surname if user_row else None,
|
||||||
avatar_url=user_row.avatar_url if user_row else None,
|
avatar_url=user_row.avatar_url if user_row else None,
|
||||||
|
has_password=bool(user_row.password_hash) if user_row else False,
|
||||||
tier=tier,
|
tier=tier,
|
||||||
onboarding_completed_at=onboarding_ms,
|
onboarding_completed_at=onboarding_ms,
|
||||||
memory=memory_dict,
|
memory=memory_dict,
|
||||||
|
|||||||
@@ -519,6 +519,7 @@ async def _build_profile(user_id: str, email: str, db: AsyncSession) -> UserProf
|
|||||||
user_result = await db.execute(
|
user_result = await db.execute(
|
||||||
select(
|
select(
|
||||||
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
||||||
|
User.password_hash,
|
||||||
).where(User.id == user_id)
|
).where(User.id == user_id)
|
||||||
)
|
)
|
||||||
user_row = user_result.one_or_none()
|
user_row = user_result.one_or_none()
|
||||||
@@ -541,6 +542,7 @@ async def _build_profile(user_id: str, email: str, db: AsyncSession) -> UserProf
|
|||||||
name=user_row.name if user_row else None,
|
name=user_row.name if user_row else None,
|
||||||
surname=user_row.surname if user_row else None,
|
surname=user_row.surname if user_row else None,
|
||||||
avatar_url=user_row.avatar_url if user_row else None,
|
avatar_url=user_row.avatar_url if user_row else None,
|
||||||
|
has_password=bool(user_row.password_hash) if user_row else False,
|
||||||
tier=tier,
|
tier=tier,
|
||||||
onboarding_completed_at=onboarding_ms,
|
onboarding_completed_at=onboarding_ms,
|
||||||
memory=memory_dict,
|
memory=memory_dict,
|
||||||
@@ -621,3 +623,173 @@ async def normalize_onboarding(
|
|||||||
except Exception:
|
except Exception:
|
||||||
# LLM failure must never block onboarding — return inputs unchanged
|
# LLM failure must never block onboarding — return inputs unchanged
|
||||||
return _NormalizeResponse(normalized=body.inputs)
|
return _NormalizeResponse(normalized=body.inputs)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Password management ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _ChangePasswordRequest(BaseModel):
|
||||||
|
current_password: str = Field(min_length=1)
|
||||||
|
new_password: str = Field(min_length=8)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me/password", status_code=status.HTTP_200_OK)
|
||||||
|
async def change_password(
|
||||||
|
body: _ChangePasswordRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Change the authenticated user's password.
|
||||||
|
|
||||||
|
Requires the current password for verification.
|
||||||
|
Returns 400 for social-only users (no password set).
|
||||||
|
"""
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
|
||||||
|
if user.password_hash is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_400_BAD_REQUEST,
|
||||||
|
"This account uses social login and has no password to change",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not _verify_password(body.current_password, user.password_hash):
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Current password is incorrect")
|
||||||
|
|
||||||
|
user.password_hash = _hash_password(body.new_password)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── OAuth account management ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me/oauth-accounts", response_model=list[dict])
|
||||||
|
async def list_oauth_accounts(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[dict]:
|
||||||
|
"""List all OAuth providers linked to the authenticated user."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
accounts = result.scalars().all()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"provider": a.provider,
|
||||||
|
"provider_email": a.provider_email,
|
||||||
|
"created_at": int(a.created_at.timestamp() * 1000),
|
||||||
|
}
|
||||||
|
for a in accounts
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/me/oauth-accounts/{provider}", status_code=status.HTTP_200_OK)
|
||||||
|
async def unlink_oauth_account(
|
||||||
|
provider: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Unlink an OAuth provider from the authenticated user.
|
||||||
|
|
||||||
|
Refuses if the user has no password and this is their only login method.
|
||||||
|
"""
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
|
||||||
|
oauth_result = await db.execute(
|
||||||
|
select(OAuthAccount).where(
|
||||||
|
OAuthAccount.user_id == current_user.id,
|
||||||
|
OAuthAccount.provider == provider,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
account = oauth_result.scalar_one_or_none()
|
||||||
|
if account is None:
|
||||||
|
raise HTTPException(status.HTTP_404_NOT_FOUND, f"No linked {provider} account found")
|
||||||
|
|
||||||
|
# Safety: don't let users lock themselves out.
|
||||||
|
all_oauth = await db.execute(
|
||||||
|
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
oauth_count = len(all_oauth.scalars().all())
|
||||||
|
|
||||||
|
if user.password_hash is None and oauth_count <= 1:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_400_BAD_REQUEST,
|
||||||
|
"Cannot unlink the only login method. Set a password first.",
|
||||||
|
)
|
||||||
|
|
||||||
|
await db.delete(account)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Avatar update ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _UpdateAvatarRequest(BaseModel):
|
||||||
|
avatar_url: str = Field(min_length=1)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me/avatar", response_model=UserProfile)
|
||||||
|
async def update_avatar(
|
||||||
|
body: _UpdateAvatarRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> UserProfile:
|
||||||
|
"""Update the authenticated user's avatar URL.
|
||||||
|
|
||||||
|
Accepts {"avatar_url": "https://..."} — the client uploads the image
|
||||||
|
to its own storage and passes the resulting URL here.
|
||||||
|
"""
|
||||||
|
if not body.avatar_url.startswith(("https://", "http://", "data:image/")):
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid avatar URL")
|
||||||
|
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.avatar_url = body.avatar_url
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return await _build_profile(current_user.id, current_user.email, db)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Account deletion ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/me", status_code=status.HTTP_200_OK)
|
||||||
|
async def delete_account(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Permanently delete the authenticated user's account.
|
||||||
|
|
||||||
|
Cascades: refresh tokens, OAuth accounts, subscription, and all memory
|
||||||
|
rows are deleted via SQLAlchemy relationship cascades. Stripe subscription
|
||||||
|
is cancelled if active.
|
||||||
|
"""
|
||||||
|
# Cancel Stripe subscription if present.
|
||||||
|
try:
|
||||||
|
from app.billing.stripe_service import stripe_service # noqa: PLC0415
|
||||||
|
await stripe_service.cancel_subscription(current_user.id, db)
|
||||||
|
except HTTPException:
|
||||||
|
pass # No subscription — that's fine
|
||||||
|
|
||||||
|
# Delete all memory rows (core, associative, episodic, proactive).
|
||||||
|
try:
|
||||||
|
from app.models import ( # noqa: PLC0415
|
||||||
|
MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive,
|
||||||
|
)
|
||||||
|
for model in (MemoryCore, MemoryAssociative, MemoryEpisodic, MemoryProactive):
|
||||||
|
await db.execute(
|
||||||
|
model.__table__.delete().where(model.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Non-critical — cascade on User will handle most
|
||||||
|
|
||||||
|
# Delete the user row — cascades handle refresh_tokens, oauth_accounts, subscription.
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
await db.delete(user)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"ok": True}
|
||||||
|
|||||||
@@ -83,3 +83,16 @@ async def cancel_subscription(
|
|||||||
"""Cancel the active subscription."""
|
"""Cancel the active subscription."""
|
||||||
await stripe_service.cancel_subscription(current_user.id, db)
|
await stripe_service.cancel_subscription(current_user.id, db)
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/invoices", response_model=list[dict])
|
||||||
|
async def list_invoices(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Return billing history (invoices) from Stripe.
|
||||||
|
|
||||||
|
Returns an empty list when Stripe is not configured.
|
||||||
|
"""
|
||||||
|
invoices = await stripe_service.list_invoices(current_user.id, db)
|
||||||
|
return invoices
|
||||||
|
|||||||
225
app/api/routes/memory.py
Normal file
225
app/api/routes/memory.py
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
"""Memory management routes — view/edit/delete user memory tiers.
|
||||||
|
|
||||||
|
All routes require authentication. Data is always user-scoped.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import delete, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import (
|
||||||
|
ExtractionQueue,
|
||||||
|
MemoryAssociative,
|
||||||
|
MemoryCore,
|
||||||
|
MemoryEpisodic,
|
||||||
|
MemoryProactive,
|
||||||
|
MemoryRelation,
|
||||||
|
)
|
||||||
|
from app.schemas import UserProfile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/memory", tags=["memory"])
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_ALLOWED_PREDICATES = {
|
||||||
|
"works_at",
|
||||||
|
"reports_to",
|
||||||
|
"stakeholder_of",
|
||||||
|
"last_contacted_on",
|
||||||
|
"owes_followup",
|
||||||
|
"manages",
|
||||||
|
"collaborates_with",
|
||||||
|
"owns",
|
||||||
|
"member_of",
|
||||||
|
"custom",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Response schemas ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class RelationOut(BaseModel):
|
||||||
|
id: str
|
||||||
|
subject_label: str
|
||||||
|
subject_type: str
|
||||||
|
predicate: str
|
||||||
|
object_label: str
|
||||||
|
object_type: str
|
||||||
|
confidence: float
|
||||||
|
last_confirmed_at: int | None = None # epoch ms
|
||||||
|
|
||||||
|
|
||||||
|
class RelationPatch(BaseModel):
|
||||||
|
subject_label: str | None = None
|
||||||
|
object_label: str | None = None
|
||||||
|
predicate: str | None = None
|
||||||
|
confidence: float | None = Field(None, ge=0.0, le=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
class CoreAddBody(BaseModel):
|
||||||
|
key: str = Field(..., min_length=1, max_length=255)
|
||||||
|
value: str = Field(..., min_length=1)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _relation_to_out(row: MemoryRelation) -> RelationOut:
|
||||||
|
last_ms: int | None = None
|
||||||
|
if row.last_confirmed_at is not None:
|
||||||
|
last_ms = int(row.last_confirmed_at.timestamp() * 1000)
|
||||||
|
return RelationOut(
|
||||||
|
id=row.id,
|
||||||
|
subject_label=row.subject_label,
|
||||||
|
subject_type=row.subject_type,
|
||||||
|
predicate=row.predicate,
|
||||||
|
object_label=row.object_label,
|
||||||
|
object_type=row.object_type,
|
||||||
|
confidence=row.confidence,
|
||||||
|
last_confirmed_at=last_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/core", response_model=dict[str, str])
|
||||||
|
async def get_core_memory(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Return all core memory k/v pairs (plaintext) for the current user."""
|
||||||
|
mw = MemoryMiddleware(db)
|
||||||
|
blocks = await mw.list_core_blocks(current_user.id)
|
||||||
|
return {b["label"]: b["value"] for b in blocks}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/core/{key}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_core_key(
|
||||||
|
key: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> None:
|
||||||
|
"""Delete a single core memory key (GDPR Art. 17)."""
|
||||||
|
mw = MemoryMiddleware(db)
|
||||||
|
deleted = await mw.delete_core(current_user.id, key)
|
||||||
|
if not deleted:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Key not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/core", status_code=status.HTTP_201_CREATED, response_model=dict[str, str])
|
||||||
|
async def add_core_key(
|
||||||
|
body: CoreAddBody,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Add or overwrite a core memory key/value pair."""
|
||||||
|
mw = MemoryMiddleware(db)
|
||||||
|
await mw.update_core(current_user.id, body.key, body.value)
|
||||||
|
return {body.key: body.value}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/relational", response_model=list[RelationOut])
|
||||||
|
async def get_relational_memory(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[RelationOut]:
|
||||||
|
"""Return all relational memory rows for the current user."""
|
||||||
|
mw = MemoryMiddleware(db)
|
||||||
|
rows = await mw.query_relations(current_user.id, limit=200)
|
||||||
|
return [_relation_to_out(r) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/relational/{relation_id}", response_model=RelationOut)
|
||||||
|
async def patch_relation(
|
||||||
|
relation_id: str,
|
||||||
|
body: RelationPatch,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> RelationOut:
|
||||||
|
"""Edit a relation row's labels, predicate, or confidence."""
|
||||||
|
if body.predicate is not None and body.predicate not in _ALLOWED_PREDICATES:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
detail=f"predicate must be one of: {sorted(_ALLOWED_PREDICATES)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(MemoryRelation).where(
|
||||||
|
MemoryRelation.id == relation_id,
|
||||||
|
MemoryRelation.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
|
||||||
|
|
||||||
|
if body.subject_label is not None:
|
||||||
|
row.subject_label = body.subject_label
|
||||||
|
if body.object_label is not None:
|
||||||
|
row.object_label = body.object_label
|
||||||
|
if body.predicate is not None:
|
||||||
|
row.predicate = body.predicate
|
||||||
|
if body.confidence is not None:
|
||||||
|
row.confidence = body.confidence
|
||||||
|
row.last_confirmed_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(row)
|
||||||
|
logger.info("memory: patch_relation user=%s relation=%s", current_user.id, relation_id)
|
||||||
|
return _relation_to_out(row)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/relational/{relation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_relation(
|
||||||
|
relation_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> None:
|
||||||
|
"""Hard-delete a relation row (GDPR Art. 17)."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(MemoryRelation).where(
|
||||||
|
MemoryRelation.id == relation_id,
|
||||||
|
MemoryRelation.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
|
||||||
|
await db.delete(row)
|
||||||
|
await db.commit()
|
||||||
|
logger.info("memory: delete_relation user=%s relation=%s", current_user.id, relation_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/forget-all", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def forget_all(
|
||||||
|
x_confirm: Annotated[str | None, Header(alias="X-Confirm")] = None,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> None:
|
||||||
|
"""Wipe all memory tiers for the current user (GDPR Art. 17).
|
||||||
|
|
||||||
|
Requires ``X-Confirm: true`` header. Does NOT delete the user account.
|
||||||
|
"""
|
||||||
|
if x_confirm != "true":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Missing or invalid X-Confirm header. Send X-Confirm: true to confirm.",
|
||||||
|
)
|
||||||
|
|
||||||
|
uid = current_user.id
|
||||||
|
await db.execute(delete(MemoryCore).where(MemoryCore.user_id == uid))
|
||||||
|
await db.execute(delete(MemoryAssociative).where(MemoryAssociative.user_id == uid))
|
||||||
|
await db.execute(delete(MemoryEpisodic).where(MemoryEpisodic.user_id == uid))
|
||||||
|
await db.execute(delete(MemoryProactive).where(MemoryProactive.user_id == uid))
|
||||||
|
await db.execute(delete(MemoryRelation).where(MemoryRelation.user_id == uid))
|
||||||
|
await db.execute(delete(ExtractionQueue).where(ExtractionQueue.user_id == uid))
|
||||||
|
await db.commit()
|
||||||
|
logger.warning("memory: forget_all GDPR wipe user=%s", uid)
|
||||||
@@ -200,6 +200,45 @@ class StripeService:
|
|||||||
sub.status = "canceled"
|
sub.status = "canceled"
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
|
async def list_invoices(
|
||||||
|
self, user_id: str, db: AsyncSession, limit: int = 24
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Return recent invoices for the user from Stripe.
|
||||||
|
|
||||||
|
Returns an empty list when Stripe is not configured or the user has
|
||||||
|
no ``stripe_customer_id``.
|
||||||
|
"""
|
||||||
|
if not self._configured():
|
||||||
|
return []
|
||||||
|
|
||||||
|
from app.models import User # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(User.stripe_customer_id).where(User.id == user_id)
|
||||||
|
)
|
||||||
|
customer_id = result.scalar_one_or_none()
|
||||||
|
if not customer_id:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
s = self._client()
|
||||||
|
invoices = s.Invoice.list(customer=customer_id, limit=limit)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": inv.id,
|
||||||
|
"amount_due": inv.amount_due,
|
||||||
|
"amount_paid": inv.amount_paid,
|
||||||
|
"currency": inv.currency,
|
||||||
|
"status": inv.status,
|
||||||
|
"created": inv.created * 1000, # epoch ms
|
||||||
|
"invoice_url": inv.hosted_invoice_url,
|
||||||
|
"invoice_pdf": inv.invoice_pdf,
|
||||||
|
}
|
||||||
|
for inv in invoices.auto_paging_iter()
|
||||||
|
]
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
|
||||||
# ── Private DB helpers ───────────────────────────────────────────────
|
# ── Private DB helpers ───────────────────────────────────────────────
|
||||||
|
|
||||||
async def _upsert_subscription(
|
async def _upsert_subscription(
|
||||||
|
|||||||
@@ -25,6 +25,10 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"providers": 1,
|
"providers": 1,
|
||||||
"batch_builder": False,
|
"batch_builder": False,
|
||||||
"sso": False,
|
"sso": False,
|
||||||
|
"real_embeddings": False, # keyword fallback only
|
||||||
|
"realtime_extraction": False, # batch queue (Phase 2)
|
||||||
|
"relational_memory": False, # relational tier (Phase 3) — Pro+
|
||||||
|
"proactive_mining": False, # Power+ only (Phase 5)
|
||||||
},
|
},
|
||||||
"pro": {
|
"pro": {
|
||||||
"agents": -1, # unlimited
|
"agents": -1, # unlimited
|
||||||
@@ -33,6 +37,10 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": False,
|
"batch_builder": False,
|
||||||
"sso": False,
|
"sso": False,
|
||||||
|
"real_embeddings": True, # pgvector cosine search
|
||||||
|
"realtime_extraction": True, # fire-and-forget asyncio.create_task
|
||||||
|
"relational_memory": True, # person/project predicates
|
||||||
|
"proactive_mining": False, # Power+ only (Phase 5)
|
||||||
},
|
},
|
||||||
"power": {
|
"power": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
@@ -41,6 +49,10 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": True,
|
"batch_builder": True,
|
||||||
"sso": False,
|
"sso": False,
|
||||||
|
"real_embeddings": True,
|
||||||
|
"realtime_extraction": True,
|
||||||
|
"relational_memory": True, # all predicates incl. custom
|
||||||
|
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
||||||
},
|
},
|
||||||
"team": {
|
"team": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
@@ -49,6 +61,10 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": True,
|
"batch_builder": True,
|
||||||
"sso": True,
|
"sso": True,
|
||||||
|
"real_embeddings": True,
|
||||||
|
"realtime_extraction": True,
|
||||||
|
"relational_memory": True, # all predicates incl. custom
|
||||||
|
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,9 @@ class Settings(BaseSettings):
|
|||||||
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
|
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
|
||||||
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
|
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
|
||||||
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
||||||
|
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
|
||||||
|
LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)
|
||||||
|
LLM_MODEL_MEMORY_AUDITOR: str = "" # memory-auditor (Phase 7 weekly audit)
|
||||||
|
|
||||||
# GitHub Copilot OAuth token storage directory.
|
# GitHub Copilot OAuth token storage directory.
|
||||||
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
||||||
@@ -57,15 +60,23 @@ class Settings(BaseSettings):
|
|||||||
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
||||||
OAUTH_ENCRYPTION_KEY: str = ""
|
OAUTH_ENCRYPTION_KEY: str = ""
|
||||||
|
|
||||||
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
CORS_ORIGINS: list[str] = [
|
||||||
|
"app://.",
|
||||||
|
"http://localhost:3000",
|
||||||
|
"http://localhost:5173",
|
||||||
|
"http://localhost:4173", # Vite preview (web SPA)
|
||||||
|
"https://app.adiuvai.com", # Production web portal
|
||||||
|
]
|
||||||
|
|
||||||
LANGFUSE_SECRET_KEY: str = ""
|
LANGFUSE_SECRET_KEY: str = ""
|
||||||
LANGFUSE_PUBLIC_KEY: str = ""
|
LANGFUSE_PUBLIC_KEY: str = ""
|
||||||
LANGFUSE_BASE_URL: str = "https://cloud.langfuse.com"
|
LANGFUSE_BASE_URL: str = "https://cloud.langfuse.com"
|
||||||
|
|
||||||
|
SCHEDULER_ENABLED: bool = True
|
||||||
|
|
||||||
ENV: Literal["dev", "prod"] = "dev"
|
ENV: Literal["dev", "prod"] = "dev"
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -55,6 +55,38 @@ def _language_instruction(context: dict[str, Any]) -> str:
|
|||||||
f"All your output text must be written in {lang}."
|
f"All your output text must be written in {lang}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _proactive_hints_injection(context: dict[str, Any]) -> str:
|
||||||
|
"""Return a system-prompt paragraph listing proactive behavioral hints.
|
||||||
|
|
||||||
|
Returns empty string when no hints or confidence below threshold.
|
||||||
|
Capped at 600 chars.
|
||||||
|
"""
|
||||||
|
hints: list[str] = context.get("proactive_hints") or []
|
||||||
|
if not hints:
|
||||||
|
return ""
|
||||||
|
body = "\n".join(f"- {h}" for h in hints)
|
||||||
|
section = f"\n\nI noticed (behavioral patterns):\n{body}"
|
||||||
|
if len(section) > 600:
|
||||||
|
section = section[:597] + "..."
|
||||||
|
return section
|
||||||
|
|
||||||
|
|
||||||
|
def _relational_memory_injection(context: dict[str, Any]) -> str:
|
||||||
|
"""Return a system-prompt paragraph listing known people/projects from relational memory.
|
||||||
|
|
||||||
|
Returns empty string when no relational rows or tier is Free.
|
||||||
|
Capped at 800 chars to control token spend.
|
||||||
|
"""
|
||||||
|
relations: list[str] = context.get("relational_memory") or []
|
||||||
|
if not relations:
|
||||||
|
return ""
|
||||||
|
body = "\n".join(f"- {r}" for r in relations)
|
||||||
|
section = f"\n\nKnown people & projects:\n{body}"
|
||||||
|
if len(section) > 800:
|
||||||
|
section = section[:797] + "..."
|
||||||
|
return section
|
||||||
|
|
||||||
|
|
||||||
_HOME_SYSTEM_PROMPT = (
|
_HOME_SYSTEM_PROMPT = (
|
||||||
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||||
"Always use tools for factual data retrieval before answering. "
|
"Always use tools for factual data retrieval before answering. "
|
||||||
@@ -904,6 +936,8 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
|||||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||||
"home_system", _HOME_SYSTEM_PROMPT
|
"home_system", _HOME_SYSTEM_PROMPT
|
||||||
)
|
)
|
||||||
|
system_prompt += _relational_memory_injection(context)
|
||||||
|
system_prompt += _proactive_hints_injection(context)
|
||||||
system_prompt += _language_instruction(context)
|
system_prompt += _language_instruction(context)
|
||||||
response = await _run_single_agent(
|
response = await _run_single_agent(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -922,6 +956,8 @@ async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> t
|
|||||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||||
"floating_system", _FLOATING_SYSTEM_PROMPT
|
"floating_system", _FLOATING_SYSTEM_PROMPT
|
||||||
)
|
)
|
||||||
|
system_prompt += _relational_memory_injection(context)
|
||||||
|
system_prompt += _proactive_hints_injection(context)
|
||||||
system_prompt += _language_instruction(context)
|
system_prompt += _language_instruction(context)
|
||||||
response = await _run_single_agent(
|
response = await _run_single_agent(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -946,6 +982,8 @@ async def run_home_stream(
|
|||||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||||
"home_system", _HOME_SYSTEM_PROMPT
|
"home_system", _HOME_SYSTEM_PROMPT
|
||||||
)
|
)
|
||||||
|
system_prompt += _relational_memory_injection(context)
|
||||||
|
system_prompt += _proactive_hints_injection(context)
|
||||||
system_prompt += _language_instruction(context)
|
system_prompt += _language_instruction(context)
|
||||||
text_chunks: list[str] = []
|
text_chunks: list[str] = []
|
||||||
async for event in _run_single_agent_stream(
|
async for event in _run_single_agent_stream(
|
||||||
@@ -979,6 +1017,8 @@ async def run_floating_stream(
|
|||||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||||
"floating_system", _FLOATING_SYSTEM_PROMPT
|
"floating_system", _FLOATING_SYSTEM_PROMPT
|
||||||
)
|
)
|
||||||
|
system_prompt += _relational_memory_injection(context)
|
||||||
|
system_prompt += _proactive_hints_injection(context)
|
||||||
system_prompt += _language_instruction(context)
|
system_prompt += _language_instruction(context)
|
||||||
sanitizer = _FloatingStreamSanitizer()
|
sanitizer = _FloatingStreamSanitizer()
|
||||||
emitted_sanitized = False
|
emitted_sanitized = False
|
||||||
|
|||||||
34
app/core/embeddings.py
Normal file
34
app/core/embeddings.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""OpenAI embedding helper for associative memory tier.
|
||||||
|
|
||||||
|
Single public function: ``embed_text(text) -> list[float] | None``.
|
||||||
|
Returns None on any failure — callers must implement a keyword fallback.
|
||||||
|
Never raises; all exceptions are logged as warnings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_MAX_INPUT_CHARS = 8000
|
||||||
|
_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||||
|
|
||||||
|
|
||||||
|
async def embed_text(text: str) -> list[float] | None:
|
||||||
|
"""Call OpenAI text-embedding-3-small. Return None on failure (caller falls back to keyword)."""
|
||||||
|
try:
|
||||||
|
client = AsyncOpenAI()
|
||||||
|
truncated = text[:_MAX_INPUT_CHARS]
|
||||||
|
response = await client.embeddings.create(
|
||||||
|
input=truncated,
|
||||||
|
model=_EMBEDDING_MODEL,
|
||||||
|
)
|
||||||
|
result: list[float] = response.data[0].embedding
|
||||||
|
logger.debug("embeddings: embed_text dims=%d", len(result))
|
||||||
|
return result
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("embeddings: embed_text failed: %s", exc)
|
||||||
|
return None
|
||||||
@@ -103,6 +103,9 @@ _AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
|||||||
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
|
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
|
||||||
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
||||||
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
||||||
|
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
|
||||||
|
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
|
||||||
|
"memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
450
app/core/memory_extraction.py
Normal file
450
app/core/memory_extraction.py
Normal file
@@ -0,0 +1,450 @@
|
|||||||
|
"""Mem0-style Extract/Update pipeline — Phase 2.
|
||||||
|
|
||||||
|
Runs after every ``store_episode`` call to distil durable facts, preferences,
|
||||||
|
routines, and relations from the latest conversation turn.
|
||||||
|
|
||||||
|
Entry point: ``run_extraction(db, user_id, last_user_msg, last_assistant_msg, session_id)``
|
||||||
|
|
||||||
|
Design notes
|
||||||
|
------------
|
||||||
|
- Two gpt-4o-mini calls per turn: extract candidates, then decide action per candidate.
|
||||||
|
- Short-circuit: if no existing neighbours → ADD without a second LLM call (cost saving).
|
||||||
|
- Zero-trust: never logs decrypted user content; relation subject/object labels are
|
||||||
|
treated as identifiers (safe to log per spec).
|
||||||
|
- Must not raise into the request path — caller wraps in asyncio.create_task().
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.langfuse_client import get_langfuse, get_prompt_or_fallback, extract_usage, langfuse_context
|
||||||
|
from app.core.llm import get_agent_llm, model_for_agent
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Fallback prompts (used when Langfuse unavailable) ─────────────────────────
|
||||||
|
|
||||||
|
_EXTRACTION_FALLBACK = (
|
||||||
|
"You are a memory extractor for a personal AI secretary. Given the last conversation "
|
||||||
|
"turn, the user's core memory, and recent episode summaries, identify durable facts, "
|
||||||
|
"preferences, routines, and person/project relations worth remembering.\n\n"
|
||||||
|
"Output JSON matching this schema exactly:\n"
|
||||||
|
'{{"candidates": [{{"type": "<fact|preference|relation|routine>", '
|
||||||
|
'"content": "<short canonical statement>", '
|
||||||
|
'"target_tier": "<core|associative|relational|proactive>", '
|
||||||
|
'"subject": null, "predicate": null, "object": null, "confidence": 0.7}}]}}\n\n'
|
||||||
|
"Rules:\n"
|
||||||
|
"- Skip small talk, greetings, one-off questions.\n"
|
||||||
|
"- Max 5 candidates per call.\n"
|
||||||
|
"- Only extract durable information (still true next week).\n"
|
||||||
|
"- For type=relation: subject/predicate/object required.\n"
|
||||||
|
"- Default confidence=0.7.\n\n"
|
||||||
|
"## Last turn\n{last_turn}\n\n"
|
||||||
|
"## Core memory (current)\n{core_memory}\n\n"
|
||||||
|
"## Recent episodes\n{recent_episodes}"
|
||||||
|
)
|
||||||
|
|
||||||
|
_DECIDE_FALLBACK = (
|
||||||
|
"You are a memory update decision engine. Given a new memory candidate and a list of "
|
||||||
|
"existing memories from the same tier, decide what action to take.\n\n"
|
||||||
|
"Respond with exactly one word: ADD, UPDATE, DELETE, or NOOP.\n\n"
|
||||||
|
"- ADD: new information not in existing memories.\n"
|
||||||
|
"- UPDATE: contradicts or supersedes an existing memory.\n"
|
||||||
|
"- DELETE: states something is no longer true.\n"
|
||||||
|
"- NOOP: already captured accurately.\n\n"
|
||||||
|
"## New candidate\n{candidate}\n\n"
|
||||||
|
"## Existing memories (same tier, top neighbours)\n{existing_memories}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Pydantic schemas ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class MemoryCandidate(BaseModel):
|
||||||
|
type: Literal["fact", "preference", "relation", "routine"]
|
||||||
|
content: str
|
||||||
|
target_tier: Literal["core", "associative", "relational", "proactive"]
|
||||||
|
subject: str | None = None
|
||||||
|
predicate: str | None = None
|
||||||
|
object: str | None = None
|
||||||
|
confidence: float = Field(default=0.7, ge=0.0, le=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractionResult(BaseModel):
|
||||||
|
candidates: list[MemoryCandidate] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Task 2.1 — Extract candidates ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def extract_candidates(
|
||||||
|
last_turn: str,
|
||||||
|
core_memory: dict[str, str],
|
||||||
|
recent_episodes: list[str],
|
||||||
|
) -> ExtractionResult:
|
||||||
|
"""Call gpt-4o-mini to extract memory candidates from the latest turn.
|
||||||
|
|
||||||
|
Returns an ExtractionResult (may be empty on failure — never raises).
|
||||||
|
"""
|
||||||
|
core_str = "\n".join(f"{k}: {v}" for k, v in core_memory.items()) or "(empty)"
|
||||||
|
episodes_str = "\n---\n".join(recent_episodes[-5:]) or "(none)"
|
||||||
|
|
||||||
|
template, prompt_obj = get_prompt_or_fallback("memory_extraction", _EXTRACTION_FALLBACK)
|
||||||
|
|
||||||
|
# Compile with Langfuse variable syntax ({{var}}) or fallback {var}
|
||||||
|
if prompt_obj is not None:
|
||||||
|
try:
|
||||||
|
system_text = prompt_obj.compile(
|
||||||
|
last_turn=last_turn,
|
||||||
|
core_memory=core_str,
|
||||||
|
recent_episodes=episodes_str,
|
||||||
|
)
|
||||||
|
if isinstance(system_text, list):
|
||||||
|
system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_extraction: compile failed: %s", exc)
|
||||||
|
system_text = template.format(
|
||||||
|
last_turn=last_turn,
|
||||||
|
core_memory=core_str,
|
||||||
|
recent_episodes=episodes_str,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
system_text = template.format(
|
||||||
|
last_turn=last_turn,
|
||||||
|
core_memory=core_str,
|
||||||
|
recent_episodes=episodes_str,
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = get_agent_llm("memory-extractor", temperature=0)
|
||||||
|
# Bind JSON mode so the model always returns parseable output.
|
||||||
|
llm_json = llm.bind(response_format={"type": "json_object"}) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
lf = get_langfuse()
|
||||||
|
try:
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=system_text),
|
||||||
|
HumanMessage(content="Extract memory candidates as JSON."),
|
||||||
|
]
|
||||||
|
|
||||||
|
if lf:
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name="memory-extraction",
|
||||||
|
model=model_for_agent("memory-extractor"),
|
||||||
|
prompt=prompt_obj,
|
||||||
|
input=messages,
|
||||||
|
) as gen:
|
||||||
|
response = await llm_json.ainvoke(messages)
|
||||||
|
gen.update(output=response.content, usage=extract_usage(response))
|
||||||
|
else:
|
||||||
|
response = await llm_json.ainvoke(messages)
|
||||||
|
|
||||||
|
raw = json.loads(response.content)
|
||||||
|
result = ExtractionResult.model_validate(raw)
|
||||||
|
logger.info("memory_extraction: extracted %d candidates", len(result.candidates))
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_extraction: extract_candidates failed: %s", exc)
|
||||||
|
return ExtractionResult(candidates=[])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Task 2.2 — Decide action ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def decide_action(
|
||||||
|
candidate: MemoryCandidate,
|
||||||
|
existing: list[str],
|
||||||
|
) -> Literal["ADD", "UPDATE", "DELETE", "NOOP"]:
|
||||||
|
"""Decide what to do with a candidate given existing memories in the same tier.
|
||||||
|
|
||||||
|
Short-circuits to ADD without an LLM call when existing is empty (cost saving).
|
||||||
|
Never raises.
|
||||||
|
"""
|
||||||
|
if not existing:
|
||||||
|
return "ADD"
|
||||||
|
|
||||||
|
candidate_str = f"[{candidate.type}] {candidate.content}"
|
||||||
|
existing_str = "\n".join(f"- {m}" for m in existing)
|
||||||
|
|
||||||
|
template, prompt_obj = get_prompt_or_fallback("memory_decide_action", _DECIDE_FALLBACK)
|
||||||
|
|
||||||
|
if prompt_obj is not None:
|
||||||
|
try:
|
||||||
|
system_text = prompt_obj.compile(
|
||||||
|
candidate=candidate_str,
|
||||||
|
existing_memories=existing_str,
|
||||||
|
)
|
||||||
|
if isinstance(system_text, list):
|
||||||
|
system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_extraction: decide compile failed: %s", exc)
|
||||||
|
system_text = template.format(candidate=candidate_str, existing_memories=existing_str)
|
||||||
|
else:
|
||||||
|
system_text = template.format(candidate=candidate_str, existing_memories=existing_str)
|
||||||
|
|
||||||
|
llm = get_agent_llm("memory-extractor", temperature=0)
|
||||||
|
lf = get_langfuse()
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=system_text),
|
||||||
|
HumanMessage(content="Decide action."),
|
||||||
|
]
|
||||||
|
|
||||||
|
if lf:
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name="memory-decide-action",
|
||||||
|
model=model_for_agent("memory-extractor"),
|
||||||
|
prompt=prompt_obj,
|
||||||
|
input=messages,
|
||||||
|
) as gen:
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
gen.update(output=response.content, usage=extract_usage(response))
|
||||||
|
else:
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
|
||||||
|
verb = response.content.strip().upper()
|
||||||
|
if verb in ("ADD", "UPDATE", "DELETE", "NOOP"):
|
||||||
|
return verb # type: ignore[return-value]
|
||||||
|
logger.warning("memory_extraction: unexpected decide verb=%r, defaulting ADD", verb)
|
||||||
|
return "ADD"
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_extraction: decide_action failed: %s", exc)
|
||||||
|
return "ADD"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Task 2.3 — Pipeline orchestrator ──────────────────────────────────────────
|
||||||
|
|
||||||
|
async def run_extraction(
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
last_user_msg: str,
|
||||||
|
last_assistant_msg: str,
|
||||||
|
session_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Full Mem0-style extract/update pipeline for one conversation turn.
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
1. Load core memory + last 5 episodes.
|
||||||
|
2. extract_candidates() → up to 5 MemoryCandidate objects.
|
||||||
|
3. For each candidate: find top-3 neighbours → decide_action() → apply.
|
||||||
|
4. Trace via Langfuse.
|
||||||
|
|
||||||
|
Never raises — wraps everything in try/except.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await _run_extraction_inner(db, user_id, last_user_msg, last_assistant_msg, session_id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_extraction: run_extraction failed user=%s: %s", user_id, exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_extraction_inner(
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
last_user_msg: str,
|
||||||
|
last_assistant_msg: str,
|
||||||
|
session_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db)
|
||||||
|
fernet = await middleware._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
logger.warning("memory_extraction: no fernet for user=%s, skipping", user_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 1. Load context
|
||||||
|
core: dict[str, str] = await middleware._load_core(user_id, fernet)
|
||||||
|
episodes: list[str] = await middleware._load_episodic(user_id, fernet, session_id=session_id)
|
||||||
|
|
||||||
|
last_turn = f"User: {last_user_msg}\nAssistant: {last_assistant_msg}"
|
||||||
|
|
||||||
|
lf = get_langfuse()
|
||||||
|
|
||||||
|
async def _run(trace_id: str | None) -> dict[str, Any]:
|
||||||
|
# 2. Extract candidates
|
||||||
|
result = await extract_candidates(last_turn, core, episodes)
|
||||||
|
if not result.candidates:
|
||||||
|
logger.info("memory_extraction: no candidates user=%s", user_id)
|
||||||
|
return {"candidates": 0, "applied": 0}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"memory_extraction: processing %d candidates user=%s trace=%s",
|
||||||
|
len(result.candidates),
|
||||||
|
user_id,
|
||||||
|
trace_id or "-",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Apply each candidate
|
||||||
|
applied = 0
|
||||||
|
actions: list[str] = []
|
||||||
|
for candidate in result.candidates:
|
||||||
|
try:
|
||||||
|
await _apply_candidate(middleware, db, user_id, fernet, candidate, trace_id)
|
||||||
|
applied += 1
|
||||||
|
actions.append(f"{candidate.type}:{candidate.target_tier}")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory_extraction: apply failed candidate=%r user=%s: %s",
|
||||||
|
candidate.content[:80],
|
||||||
|
user_id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"memory_extraction: applied %d/%d candidates user=%s",
|
||||||
|
applied,
|
||||||
|
len(result.candidates),
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
return {"candidates": len(result.candidates), "applied": applied, "actions": actions}
|
||||||
|
|
||||||
|
with langfuse_context(user_id=user_id, session_id=session_id):
|
||||||
|
if lf:
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="span",
|
||||||
|
name="memory-extraction-pipeline",
|
||||||
|
input={"last_turn_preview": last_turn[:200]},
|
||||||
|
) as span:
|
||||||
|
summary = await _run(trace_id=span.id)
|
||||||
|
span.update(output=summary)
|
||||||
|
try:
|
||||||
|
lf.flush()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
await _run(trace_id=None)
|
||||||
|
|
||||||
|
|
||||||
|
async def _apply_candidate(
|
||||||
|
middleware: Any,
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
fernet: Any,
|
||||||
|
candidate: MemoryCandidate,
|
||||||
|
trace_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Fetch neighbours, decide action, apply to the appropriate tier."""
|
||||||
|
|
||||||
|
neighbours: list[str] = []
|
||||||
|
|
||||||
|
if candidate.target_tier == "core":
|
||||||
|
# For core tier: neighbours are existing core block values for similar keys.
|
||||||
|
blocks = await middleware.list_core_blocks(user_id)
|
||||||
|
neighbours = [b["value"] for b in blocks[:3]]
|
||||||
|
|
||||||
|
elif candidate.target_tier == "associative":
|
||||||
|
neighbours = await middleware.search_archival(user_id, candidate.content, top_k=3)
|
||||||
|
|
||||||
|
elif candidate.target_tier == "relational":
|
||||||
|
# Relation candidates handled specially — passed to upsert_relation directly.
|
||||||
|
# Neighbours: search by subject label if available.
|
||||||
|
neighbours = []
|
||||||
|
|
||||||
|
elif candidate.target_tier == "proactive":
|
||||||
|
neighbours = await middleware.search_recall(user_id, candidate.content, top_k=3)
|
||||||
|
|
||||||
|
action = await decide_action(candidate, neighbours)
|
||||||
|
logger.info(
|
||||||
|
"memory_extraction: candidate type=%s tier=%s action=%s",
|
||||||
|
candidate.type,
|
||||||
|
candidate.target_tier,
|
||||||
|
action,
|
||||||
|
)
|
||||||
|
|
||||||
|
if action == "NOOP":
|
||||||
|
return
|
||||||
|
|
||||||
|
if candidate.target_tier == "relational":
|
||||||
|
# Always upsert relations — decide_action skipped (no neighbour search).
|
||||||
|
if candidate.subject and candidate.predicate and candidate.object:
|
||||||
|
await _upsert_relation(
|
||||||
|
middleware, db, user_id, candidate, trace_id
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if action in ("ADD", "UPDATE"):
|
||||||
|
if candidate.target_tier == "core":
|
||||||
|
# Derive a short key from the content (first 40 chars, snake_cased).
|
||||||
|
key = _content_to_key(candidate.content)
|
||||||
|
await middleware.update_core(user_id, key, candidate.content, trace_id=trace_id)
|
||||||
|
|
||||||
|
elif candidate.target_tier == "associative":
|
||||||
|
await middleware.store_associative(user_id, candidate.content)
|
||||||
|
|
||||||
|
elif candidate.target_tier == "proactive":
|
||||||
|
await _store_proactive_stub(middleware, db, user_id, candidate, fernet)
|
||||||
|
|
||||||
|
elif action == "DELETE":
|
||||||
|
if candidate.target_tier == "core":
|
||||||
|
key = _content_to_key(candidate.content)
|
||||||
|
await middleware.delete_core(user_id, key)
|
||||||
|
|
||||||
|
|
||||||
|
def _content_to_key(content: str) -> str:
|
||||||
|
"""Derive a short snake_case key from a content string (first 40 chars)."""
|
||||||
|
import re # noqa: PLC0415
|
||||||
|
slug = re.sub(r"[^a-z0-9]+", "_", content[:40].lower()).strip("_")
|
||||||
|
return slug or "memory"
|
||||||
|
|
||||||
|
|
||||||
|
async def _upsert_relation(
|
||||||
|
middleware: Any,
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
candidate: MemoryCandidate,
|
||||||
|
trace_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Upsert a relation row via MemoryMiddleware.upsert_relation (Phase 3)."""
|
||||||
|
await middleware.upsert_relation(
|
||||||
|
user_id=user_id,
|
||||||
|
subject=candidate.subject or "unknown",
|
||||||
|
subject_type="unknown",
|
||||||
|
predicate=candidate.predicate or "related_to",
|
||||||
|
object_=candidate.object or "unknown",
|
||||||
|
object_type="unknown",
|
||||||
|
confidence=candidate.confidence,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"memory_extraction: upserted relation subject=%s predicate=%s object=%s",
|
||||||
|
candidate.subject,
|
||||||
|
candidate.predicate,
|
||||||
|
candidate.object,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _store_proactive_stub(
|
||||||
|
middleware: Any,
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
candidate: MemoryCandidate,
|
||||||
|
fernet: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Store a proactive pattern row directly (MemoryProactive model)."""
|
||||||
|
import uuid # noqa: PLC0415
|
||||||
|
from app.models import MemoryProactive # noqa: PLC0415
|
||||||
|
from app.core.memory_middleware import _encrypt # noqa: PLC0415
|
||||||
|
|
||||||
|
encrypted = _encrypt(fernet, candidate.content)
|
||||||
|
row = MemoryProactive(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
pattern_encrypted=encrypted,
|
||||||
|
confidence=candidate.confidence,
|
||||||
|
source="inferred",
|
||||||
|
)
|
||||||
|
db.add(row)
|
||||||
|
try:
|
||||||
|
await db.commit()
|
||||||
|
logger.info("memory_extraction: stored proactive pattern user=%s", user_id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_extraction: store proactive failed: %s", exc)
|
||||||
|
await db.rollback()
|
||||||
581
app/core/memory_maintenance.py
Normal file
581
app/core/memory_maintenance.py
Normal file
@@ -0,0 +1,581 @@
|
|||||||
|
"""Memory maintenance jobs — Phase 3/5.
|
||||||
|
|
||||||
|
Three entrypoints called by the scheduler (APScheduler) registered in app/main.py:
|
||||||
|
|
||||||
|
drain_extraction_queue(db) — Free-tier batch extraction (Phase 2/5).
|
||||||
|
mine_proactive_patterns(db, user_id) — Power+ pattern mining (Phase 5).
|
||||||
|
decay_relations(db, user_id) — confidence decay + pruning for memory_relations (Phase 3).
|
||||||
|
|
||||||
|
All are safe to call manually or from tests; they never raise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
|
||||||
|
from app.models import MemoryAssociative, MemoryEpisodic, MemoryProactive, MemoryRelation, User
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Decay parameters for relations
|
||||||
|
_DECAY_FACTOR = 0.95
|
||||||
|
_DECAY_PERIOD_DAYS = 30
|
||||||
|
_PRUNE_THRESHOLD = 0.2
|
||||||
|
|
||||||
|
# Proactive pattern decay: 10 % per 7 days since last sighting
|
||||||
|
_PROACTIVE_DECAY_FACTOR = 0.9
|
||||||
|
_PROACTIVE_DECAY_PERIOD_DAYS = 7
|
||||||
|
_PROACTIVE_PRUNE_THRESHOLD = 0.2
|
||||||
|
|
||||||
|
# Mining: require at least this many episodes to attempt pattern extraction
|
||||||
|
_MIN_EPISODES_FOR_MINING = 3
|
||||||
|
_MINING_LOOKBACK_DAYS = 30
|
||||||
|
|
||||||
|
# Audit: caps to control token cost
|
||||||
|
_AUDIT_MAX_FACTS = 50
|
||||||
|
_AUDIT_MAX_LABELS = 100
|
||||||
|
|
||||||
|
|
||||||
|
async def decay_relations(db: AsyncSession, user_id: str) -> None:
|
||||||
|
"""Apply confidence decay to all relation rows for a user.
|
||||||
|
|
||||||
|
Decay rule: confidence *= 0.95 for every 30 days since last_confirmed_at.
|
||||||
|
Rows whose confidence falls below 0.2 are deleted.
|
||||||
|
|
||||||
|
Never raises — wraps in try/except.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await _decay_relations_inner(db, user_id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_maintenance: decay_relations failed user=%s: %s", user_id, exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _decay_relations_inner(db: AsyncSession, user_id: str) -> None:
|
||||||
|
result = await db.execute(
|
||||||
|
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
deleted = 0
|
||||||
|
decayed = 0
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
reference = row.last_confirmed_at or row.created_at
|
||||||
|
if reference is None:
|
||||||
|
continue
|
||||||
|
if reference.tzinfo is None:
|
||||||
|
reference = reference.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
days_elapsed = (now - reference).days
|
||||||
|
if days_elapsed < _DECAY_PERIOD_DAYS:
|
||||||
|
continue
|
||||||
|
|
||||||
|
periods = days_elapsed // _DECAY_PERIOD_DAYS
|
||||||
|
new_confidence = row.confidence * (_DECAY_FACTOR ** periods)
|
||||||
|
|
||||||
|
if new_confidence < _PRUNE_THRESHOLD:
|
||||||
|
await db.delete(row)
|
||||||
|
deleted += 1
|
||||||
|
logger.info(
|
||||||
|
"memory_maintenance: pruned relation id=%s user=%s subject=%s predicate=%s "
|
||||||
|
"confidence=%.3f (below threshold)",
|
||||||
|
row.id, user_id, row.subject_label, row.predicate, new_confidence,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
row.confidence = new_confidence
|
||||||
|
decayed += 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
await db.commit()
|
||||||
|
logger.info(
|
||||||
|
"memory_maintenance: decay_relations user=%s decayed=%d deleted=%d",
|
||||||
|
user_id, decayed, deleted,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_maintenance: decay_relations commit failed user=%s: %s", user_id, exc)
|
||||||
|
await db.rollback()
|
||||||
|
|
||||||
|
|
||||||
|
async def drain_extraction_queue(db: AsyncSession) -> None:
|
||||||
|
"""Process pending ExtractionQueue rows for Free-tier users.
|
||||||
|
|
||||||
|
Each row corresponds to a stored episode that should be fed through the
|
||||||
|
Mem0-style extraction pipeline. Rows are deleted after successful processing.
|
||||||
|
Never raises — wraps in try/except.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await _drain_extraction_queue_inner(db)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_maintenance: drain_extraction_queue failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _drain_extraction_queue_inner(db: AsyncSession) -> None:
|
||||||
|
from app.models import ExtractionQueue # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(select(ExtractionQueue))
|
||||||
|
rows = result.scalars().all()
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
logger.debug("memory_maintenance: drain_extraction_queue nothing to drain")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("memory_maintenance: drain_extraction_queue pending=%d", len(rows))
|
||||||
|
|
||||||
|
from app.core.memory_extraction import run_extraction # noqa: PLC0415
|
||||||
|
|
||||||
|
processed = 0
|
||||||
|
for row in rows:
|
||||||
|
try:
|
||||||
|
await run_extraction(
|
||||||
|
db=db,
|
||||||
|
user_id=row.user_id,
|
||||||
|
last_user_msg="",
|
||||||
|
last_assistant_msg="",
|
||||||
|
session_id=None,
|
||||||
|
)
|
||||||
|
await db.delete(row)
|
||||||
|
await db.commit()
|
||||||
|
processed += 1
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory_maintenance: drain failed row=%s user=%s: %s",
|
||||||
|
row.id, row.user_id, exc,
|
||||||
|
)
|
||||||
|
await db.rollback()
|
||||||
|
|
||||||
|
logger.info("memory_maintenance: drain_extraction_queue processed=%d/%d", processed, len(rows))
|
||||||
|
|
||||||
|
|
||||||
|
async def mine_proactive_patterns(db: AsyncSession, user_id: str) -> None:
|
||||||
|
"""Mine recurring behavioral patterns from last 30 days of episodes (Power+ only).
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
1. Gate on proactive_mining tier feature.
|
||||||
|
2. Load + decrypt last 30 days of episodic summaries.
|
||||||
|
3. Call gpt-4o-mini to identify recurring patterns.
|
||||||
|
4. Encrypt and store each pattern in memory_proactive.
|
||||||
|
5. Apply decay to existing proactive rows.
|
||||||
|
|
||||||
|
Never raises — wraps in try/except.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await _mine_proactive_patterns_inner(db, user_id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_maintenance: mine_proactive_patterns failed user=%s: %s", user_id, exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _mine_proactive_patterns_inner(db: AsyncSession, user_id: str) -> None:
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
|
||||||
|
tier = await tier_manager.get_tier(user_id, db)
|
||||||
|
if not tier_manager.check_feature(tier, "proactive_mining"):
|
||||||
|
logger.debug("memory_maintenance: mine_proactive_patterns skipped (tier=%s)", tier)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Load user Fernet key
|
||||||
|
result = await db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not user.encryption_key:
|
||||||
|
logger.warning("memory_maintenance: mine_proactive_patterns no encryption_key user=%s", user_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
fernet = Fernet(user.encryption_key.encode())
|
||||||
|
cutoff = datetime.now(timezone.utc) - timedelta(days=_MINING_LOOKBACK_DAYS)
|
||||||
|
|
||||||
|
episodes_result = await db.execute(
|
||||||
|
select(MemoryEpisodic)
|
||||||
|
.where(
|
||||||
|
MemoryEpisodic.user_id == user_id,
|
||||||
|
MemoryEpisodic.created_at >= cutoff,
|
||||||
|
)
|
||||||
|
.order_by(MemoryEpisodic.created_at.asc())
|
||||||
|
)
|
||||||
|
episode_rows = episodes_result.scalars().all()
|
||||||
|
|
||||||
|
if len(episode_rows) < _MIN_EPISODES_FOR_MINING:
|
||||||
|
logger.info(
|
||||||
|
"memory_maintenance: mine_proactive_patterns skipped user=%s episodes=%d (< %d)",
|
||||||
|
user_id, len(episode_rows), _MIN_EPISODES_FOR_MINING,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
summaries: list[str] = []
|
||||||
|
for ep in episode_rows:
|
||||||
|
try:
|
||||||
|
plaintext = fernet.decrypt(ep.summary_encrypted.encode()).decode()
|
||||||
|
summaries.append(plaintext)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not summaries:
|
||||||
|
return
|
||||||
|
|
||||||
|
patterns = await _extract_proactive_patterns(summaries)
|
||||||
|
if not patterns:
|
||||||
|
logger.info("memory_maintenance: mine_proactive_patterns user=%s no patterns extracted", user_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
stored = 0
|
||||||
|
for pattern_text in patterns:
|
||||||
|
try:
|
||||||
|
encrypted = fernet.encrypt(pattern_text.encode()).decode()
|
||||||
|
row = MemoryProactive(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
pattern_encrypted=encrypted,
|
||||||
|
confidence=0.7,
|
||||||
|
source="inferred",
|
||||||
|
)
|
||||||
|
db.add(row)
|
||||||
|
stored += 1
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_maintenance: failed to store pattern user=%s: %s", user_id, exc)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await db.commit()
|
||||||
|
logger.info(
|
||||||
|
"memory_maintenance: mine_proactive_patterns user=%s stored=%d",
|
||||||
|
user_id, stored,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_maintenance: mine_proactive_patterns commit failed user=%s: %s", user_id, exc)
|
||||||
|
await db.rollback()
|
||||||
|
return
|
||||||
|
|
||||||
|
await _decay_proactive_patterns(db, user_id, fernet)
|
||||||
|
|
||||||
|
|
||||||
|
async def _extract_proactive_patterns(summaries: list[str]) -> list[str]:
|
||||||
|
"""Call memory-miner LLM to identify recurring behavioral/temporal patterns."""
|
||||||
|
from app.core.llm import get_agent_llm # noqa: PLC0415
|
||||||
|
|
||||||
|
llm = get_agent_llm("memory-miner", temperature=0)
|
||||||
|
combined = "\n---\n".join(summaries[-20:]) # cap at last 20 to control token usage
|
||||||
|
prompt = (
|
||||||
|
"You are analyzing conversation history for a personal AI secretary. "
|
||||||
|
"Identify 3-5 recurring temporal or behavioral patterns (e.g. 'always works late on Thursdays', "
|
||||||
|
"'prefers bullet-point summaries', 'frequently asks about Project Acme status'). "
|
||||||
|
"Return each pattern as a plain, short English sentence on its own line. "
|
||||||
|
"No numbering, no bullet points, no extra text.\n\n"
|
||||||
|
f"Conversation history:\n{combined}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
response = await llm.ainvoke(prompt)
|
||||||
|
text = response.content if hasattr(response, "content") else str(response)
|
||||||
|
lines = [line.strip() for line in str(text).splitlines() if line.strip()]
|
||||||
|
return lines[:5]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_maintenance: _extract_proactive_patterns LLM failed: %s", exc)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
async def _decay_proactive_patterns(db: AsyncSession, user_id: str, fernet: Fernet) -> None:
|
||||||
|
"""Decay confidence of existing proactive patterns; prune below threshold."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(MemoryProactive).where(MemoryProactive.user_id == user_id)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
deleted = 0
|
||||||
|
decayed = 0
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
reference = row.created_at
|
||||||
|
if reference is None:
|
||||||
|
continue
|
||||||
|
if reference.tzinfo is None:
|
||||||
|
reference = reference.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
days_elapsed = (now - reference).days
|
||||||
|
if days_elapsed < _PROACTIVE_DECAY_PERIOD_DAYS:
|
||||||
|
continue
|
||||||
|
|
||||||
|
periods = days_elapsed // _PROACTIVE_DECAY_PERIOD_DAYS
|
||||||
|
new_confidence = row.confidence * (_PROACTIVE_DECAY_FACTOR ** periods)
|
||||||
|
|
||||||
|
if new_confidence < _PROACTIVE_PRUNE_THRESHOLD:
|
||||||
|
await db.delete(row)
|
||||||
|
deleted += 1
|
||||||
|
else:
|
||||||
|
row.confidence = new_confidence
|
||||||
|
decayed += 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
await db.commit()
|
||||||
|
logger.info(
|
||||||
|
"memory_maintenance: decay_proactive user=%s decayed=%d deleted=%d",
|
||||||
|
user_id, decayed, deleted,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_maintenance: decay_proactive commit failed user=%s: %s", user_id, exc)
|
||||||
|
await db.rollback()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Phase 7: weekly memory audit ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
_AUDIT_CONTRADICTIONS_FALLBACK = (
|
||||||
|
"You are auditing a personal AI assistant's memory bank. "
|
||||||
|
"Each fact has an ID in brackets. "
|
||||||
|
"Find pairs that directly contradict each other "
|
||||||
|
"(e.g. 'prefers morning meetings' vs 'never schedules before noon'). "
|
||||||
|
"For each contradiction, pick the ID to DELETE (the older or less specific one). "
|
||||||
|
'Return ONLY a valid JSON array, no markdown fences: '
|
||||||
|
'[{{"delete": "<id>", "reason": "<one line>"}}]. '
|
||||||
|
"If no contradictions, return [].\n\n"
|
||||||
|
"Facts:\n{facts}"
|
||||||
|
)
|
||||||
|
|
||||||
|
_AUDIT_CANONICALIZE_FALLBACK = (
|
||||||
|
"You are auditing entity labels in a personal AI assistant's relational memory. "
|
||||||
|
"These are names of people, companies, projects, or topics. "
|
||||||
|
"Group labels that clearly refer to the same real-world entity "
|
||||||
|
"(e.g. 'giulia', 'Giulia', 'Giulia R.' → canonical 'Giulia'). "
|
||||||
|
"Return ONLY a valid JSON array, no markdown fences: "
|
||||||
|
'[{{"canonical": "<best label>", "variants": ["<v1>", "<v2>"]}}]. '
|
||||||
|
"Only include groups with at least one variant. Singletons: omit.\n\n"
|
||||||
|
"Labels:\n{labels}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def audit_memory(db: AsyncSession, user_id: str) -> None:
|
||||||
|
"""Weekly audit: contradiction scan on associative facts + label canonicalization on relations.
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
1. Decrypt up to _AUDIT_MAX_FACTS associative rows; send list to memory-auditor LLM.
|
||||||
|
2. LLM flags rows to delete (direct contradictions); hard-delete them.
|
||||||
|
3. Collect unique subject/object labels from memory_relations; ask LLM to group duplicates.
|
||||||
|
4. Rewrite variant labels to their canonical form in-place.
|
||||||
|
|
||||||
|
Never raises — wraps in try/except.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await _audit_memory_inner(db, user_id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_maintenance: audit_memory failed user=%s: %s", user_id, exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _audit_memory_inner(db: AsyncSession, user_id: str) -> None:
|
||||||
|
result = await db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not user.encryption_key:
|
||||||
|
logger.warning("memory_maintenance: audit_memory no encryption_key user=%s", user_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
fernet = Fernet(user.encryption_key.encode())
|
||||||
|
await _scan_associative_contradictions(db, user_id, fernet)
|
||||||
|
await _canonicalize_relation_labels(db, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _scan_associative_contradictions(
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
fernet: Fernet,
|
||||||
|
) -> None:
|
||||||
|
"""Decrypt associative facts, ask LLM to flag contradictions, delete superseded rows."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(MemoryAssociative)
|
||||||
|
.where(MemoryAssociative.user_id == user_id)
|
||||||
|
.order_by(MemoryAssociative.updated_at.desc())
|
||||||
|
.limit(_AUDIT_MAX_FACTS)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
if len(rows) < 2:
|
||||||
|
return
|
||||||
|
|
||||||
|
id_to_text: dict[str, str] = {}
|
||||||
|
for row in rows:
|
||||||
|
try:
|
||||||
|
plaintext = fernet.decrypt(row.content_encrypted.encode()).decode()
|
||||||
|
id_to_text[row.id] = plaintext
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if len(id_to_text) < 2:
|
||||||
|
return
|
||||||
|
|
||||||
|
id_list = list(id_to_text.keys())
|
||||||
|
numbered = "\n".join(
|
||||||
|
f"{i + 1}. [{rid}] {id_to_text[rid]}" for i, rid in enumerate(id_list)
|
||||||
|
)
|
||||||
|
|
||||||
|
template, prompt_obj = get_prompt_or_fallback(
|
||||||
|
"memory_audit_contradictions", _AUDIT_CONTRADICTIONS_FALLBACK
|
||||||
|
)
|
||||||
|
system_text = compile_prompt(template, prompt_obj, facts=numbered)
|
||||||
|
|
||||||
|
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||||
|
|
||||||
|
llm = get_agent_llm("memory-auditor", temperature=0)
|
||||||
|
lf = get_langfuse()
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=system_text),
|
||||||
|
HumanMessage(content="Audit facts for contradictions."),
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
if lf:
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name="memory-audit-contradictions",
|
||||||
|
model=model_for_agent("memory-auditor"),
|
||||||
|
prompt=prompt_obj,
|
||||||
|
input=messages,
|
||||||
|
) as gen:
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
gen.update(output=response.content, usage=extract_usage(response))
|
||||||
|
else:
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
|
||||||
|
text = response.content if hasattr(response, "content") else str(response)
|
||||||
|
deletions = json.loads(text.strip())
|
||||||
|
if not isinstance(deletions, list):
|
||||||
|
return
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory_maintenance: _scan_associative_contradictions LLM/parse failed user=%s: %s",
|
||||||
|
user_id, exc,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
deleted = 0
|
||||||
|
for item in deletions:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
rid = item.get("delete")
|
||||||
|
if not rid or rid not in id_to_text:
|
||||||
|
continue
|
||||||
|
result2 = await db.execute(
|
||||||
|
select(MemoryAssociative).where(
|
||||||
|
MemoryAssociative.id == rid,
|
||||||
|
MemoryAssociative.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
target = result2.scalar_one_or_none()
|
||||||
|
if target:
|
||||||
|
await db.delete(target)
|
||||||
|
deleted += 1
|
||||||
|
logger.info(
|
||||||
|
"memory_maintenance: audit deleted contradiction id=%s user=%s reason=%s",
|
||||||
|
rid, user_id, item.get("reason", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
if deleted:
|
||||||
|
try:
|
||||||
|
await db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory_maintenance: audit contradiction commit failed user=%s: %s", user_id, exc
|
||||||
|
)
|
||||||
|
await db.rollback()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"memory_maintenance: _scan_associative_contradictions user=%s deleted=%d", user_id, deleted
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _canonicalize_relation_labels(db: AsyncSession, user_id: str) -> None:
|
||||||
|
"""Group near-duplicate entity labels in memory_relations and unify to canonical form."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
if not rows:
|
||||||
|
return
|
||||||
|
|
||||||
|
all_labels: set[str] = set()
|
||||||
|
for row in rows:
|
||||||
|
all_labels.add(row.subject_label)
|
||||||
|
all_labels.add(row.object_label)
|
||||||
|
|
||||||
|
labels_list = sorted(all_labels)[:_AUDIT_MAX_LABELS]
|
||||||
|
if len(labels_list) < 2:
|
||||||
|
return
|
||||||
|
|
||||||
|
labels_block = "\n".join(f"- {lbl}" for lbl in labels_list)
|
||||||
|
template, prompt_obj = get_prompt_or_fallback(
|
||||||
|
"memory_audit_canonicalize", _AUDIT_CANONICALIZE_FALLBACK
|
||||||
|
)
|
||||||
|
system_text = compile_prompt(template, prompt_obj, labels=labels_block)
|
||||||
|
|
||||||
|
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||||
|
|
||||||
|
llm = get_agent_llm("memory-auditor", temperature=0)
|
||||||
|
lf = get_langfuse()
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=system_text),
|
||||||
|
HumanMessage(content="Canonicalize entity labels."),
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
if lf:
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name="memory-audit-canonicalize",
|
||||||
|
model=model_for_agent("memory-auditor"),
|
||||||
|
prompt=prompt_obj,
|
||||||
|
input=messages,
|
||||||
|
) as gen:
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
gen.update(output=response.content, usage=extract_usage(response))
|
||||||
|
else:
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
|
||||||
|
text = response.content if hasattr(response, "content") else str(response)
|
||||||
|
groups = json.loads(text.strip())
|
||||||
|
if not isinstance(groups, list):
|
||||||
|
return
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory_maintenance: _canonicalize_relation_labels LLM/parse failed user=%s: %s",
|
||||||
|
user_id, exc,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Build variant → canonical map
|
||||||
|
remap: dict[str, str] = {}
|
||||||
|
for group in groups:
|
||||||
|
if not isinstance(group, dict):
|
||||||
|
continue
|
||||||
|
canonical = group.get("canonical", "")
|
||||||
|
variants = group.get("variants") or []
|
||||||
|
if not canonical:
|
||||||
|
continue
|
||||||
|
for v in variants:
|
||||||
|
if isinstance(v, str) and v != canonical:
|
||||||
|
remap[v] = canonical
|
||||||
|
|
||||||
|
if not remap:
|
||||||
|
return
|
||||||
|
|
||||||
|
updated = 0
|
||||||
|
for row in rows:
|
||||||
|
changed = False
|
||||||
|
if row.subject_label in remap:
|
||||||
|
row.subject_label = remap[row.subject_label]
|
||||||
|
changed = True
|
||||||
|
if row.object_label in remap:
|
||||||
|
row.object_label = remap[row.object_label]
|
||||||
|
changed = True
|
||||||
|
if changed:
|
||||||
|
updated += 1
|
||||||
|
|
||||||
|
if updated:
|
||||||
|
try:
|
||||||
|
await db.commit()
|
||||||
|
logger.info(
|
||||||
|
"memory_maintenance: _canonicalize_relation_labels user=%s updated=%d",
|
||||||
|
user_id, updated,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory_maintenance: canonicalize commit failed user=%s: %s", user_id, exc
|
||||||
|
)
|
||||||
|
await db.rollback()
|
||||||
@@ -18,8 +18,10 @@ Usage:
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from cryptography.fernet import Fernet, InvalidToken
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
@@ -27,15 +29,22 @@ from sqlalchemy import select
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.models import (
|
from app.models import (
|
||||||
|
ExtractionQueue,
|
||||||
MemoryAssociative,
|
MemoryAssociative,
|
||||||
MemoryCore,
|
MemoryCore,
|
||||||
MemoryEpisodic,
|
MemoryEpisodic,
|
||||||
MemoryProactive,
|
MemoryProactive,
|
||||||
|
MemoryRelation,
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _now() -> datetime:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
# Tuning constants
|
# Tuning constants
|
||||||
_ASSOCIATIVE_TOP_K = 5
|
_ASSOCIATIVE_TOP_K = 5
|
||||||
_EPISODIC_RECENT_N = 10
|
_EPISODIC_RECENT_N = 10
|
||||||
@@ -64,26 +73,31 @@ class MemoryMiddleware:
|
|||||||
associative_memory — [plaintext_content, ...] (top-k by keyword match)
|
associative_memory — [plaintext_content, ...] (top-k by keyword match)
|
||||||
episodic_memory — [plaintext_summary, ...] (most recent N)
|
episodic_memory — [plaintext_summary, ...] (most recent N)
|
||||||
proactive_hints — [plaintext_pattern, ...] (above threshold)
|
proactive_hints — [plaintext_pattern, ...] (above threshold)
|
||||||
|
relational_memory — ["subject --predicate--> object", ...] (top 10, Pro+)
|
||||||
"""
|
"""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
if fernet is None:
|
if fernet is None:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
user_tier: str = user_dbg.get("tier") or "free"
|
||||||
|
|
||||||
core = await self._load_core(user_id, fernet)
|
core = await self._load_core(user_id, fernet)
|
||||||
associative = await self._load_associative(user_id, message, fernet)
|
associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier)
|
||||||
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
||||||
proactive = await self._load_proactive(user_id, fernet)
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
relational = await self._load_relational(user_id, user_tier=user_tier)
|
||||||
|
|
||||||
user_dbg = await self._get_user_debug(user_id)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
|
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d relational=%d",
|
||||||
trace_id or "-",
|
trace_id or "-",
|
||||||
user_id,
|
user_id,
|
||||||
user_dbg.get("tier") or "-",
|
user_tier,
|
||||||
len(core),
|
len(core),
|
||||||
len(associative),
|
len(associative),
|
||||||
len(episodic),
|
len(episodic),
|
||||||
len(proactive),
|
len(proactive),
|
||||||
|
len(relational),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -91,6 +105,7 @@ class MemoryMiddleware:
|
|||||||
"associative_memory": associative,
|
"associative_memory": associative,
|
||||||
"episodic_memory": episodic,
|
"episodic_memory": episodic,
|
||||||
"proactive_hints": proactive,
|
"proactive_hints": proactive,
|
||||||
|
"relational_memory": relational,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def store_episode(
|
async def store_episode(
|
||||||
@@ -104,7 +119,10 @@ class MemoryMiddleware:
|
|||||||
"""Summarise and store a completed interaction in episodic memory.
|
"""Summarise and store a completed interaction in episodic memory.
|
||||||
|
|
||||||
The summary is a simple heuristic concatenation (no LLM call) to keep
|
The summary is a simple heuristic concatenation (no LLM call) to keep
|
||||||
latency low. Full LLM summarisation can be added in a later step.
|
latency low. After committing the episode row, dispatches the Mem0-style
|
||||||
|
extraction pipeline:
|
||||||
|
- Pro/Power/Team → asyncio.create_task (fire-and-forget, realtime).
|
||||||
|
- Free → enqueue an ExtractionQueue row for the daily cron.
|
||||||
"""
|
"""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
if fernet is None:
|
if fernet is None:
|
||||||
@@ -113,26 +131,95 @@ class MemoryMiddleware:
|
|||||||
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
||||||
encrypted = _encrypt(fernet, summary)
|
encrypted = _encrypt(fernet, summary)
|
||||||
|
|
||||||
row = MemoryEpisodic(
|
episode = MemoryEpisodic(
|
||||||
id=str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
summary_encrypted=encrypted,
|
summary_encrypted=encrypted,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
self._db.add(row)
|
self._db.add(episode)
|
||||||
|
episode_id: str = episode.id
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
user_dbg = await self._get_user_debug(user_id)
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
tier = user_dbg.get("tier") or "free"
|
||||||
logger.info(
|
logger.info(
|
||||||
"memory: store_episode trace=%s user=%s tier=%s session=%s",
|
"memory: store_episode trace=%s user=%s tier=%s session=%s",
|
||||||
trace_id or "-",
|
trace_id or "-",
|
||||||
user_id,
|
user_id,
|
||||||
user_dbg.get("tier") or "-",
|
tier,
|
||||||
session_id,
|
session_id,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Dispatch extraction pipeline (Phase 2) ────────────────────────────
|
||||||
|
await self._dispatch_extraction(
|
||||||
|
user_id=user_id,
|
||||||
|
episode_id=episode_id,
|
||||||
|
last_user_msg=message,
|
||||||
|
last_assistant_msg=response,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _dispatch_extraction(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
episode_id: str,
|
||||||
|
last_user_msg: str,
|
||||||
|
last_assistant_msg: str,
|
||||||
|
session_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Route extraction to realtime task or batch queue based on user tier."""
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
|
||||||
|
tier = await tier_manager.get_tier(user_id, self._db)
|
||||||
|
|
||||||
|
if tier_manager.check_feature(tier, "realtime_extraction"):
|
||||||
|
# Pro/Power/Team: fire-and-forget in the background.
|
||||||
|
# Must open a fresh session — request session closes after handler returns.
|
||||||
|
from app.core.memory_extraction import run_extraction # noqa: PLC0415
|
||||||
|
from app.db import async_session # noqa: PLC0415
|
||||||
|
|
||||||
|
async def _task() -> None:
|
||||||
|
try:
|
||||||
|
async with async_session() as fresh_db:
|
||||||
|
await run_extraction(
|
||||||
|
db=fresh_db,
|
||||||
|
user_id=user_id,
|
||||||
|
last_user_msg=last_user_msg,
|
||||||
|
last_assistant_msg=last_assistant_msg,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory: extraction task failed user=%s: %s", user_id, exc
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.create_task(_task())
|
||||||
|
logger.info("memory: realtime extraction dispatched user=%s", user_id)
|
||||||
|
else:
|
||||||
|
# Free tier: enqueue for daily batch cron.
|
||||||
|
queue_row = ExtractionQueue(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
episode_id=episode_id,
|
||||||
|
)
|
||||||
|
self._db.add(queue_row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info(
|
||||||
|
"memory: extraction enqueued (batch) user=%s episode=%s",
|
||||||
|
user_id,
|
||||||
|
episode_id,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory: extraction queue insert failed user=%s: %s", user_id, exc
|
||||||
|
)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
||||||
"""Upsert a core memory key/value for a user."""
|
"""Upsert a core memory key/value for a user."""
|
||||||
@@ -255,6 +342,143 @@ class MemoryMiddleware:
|
|||||||
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def store_associative(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
content: str,
|
||||||
|
entity_type: str | None = None,
|
||||||
|
entity_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Store associative memory; embed if user tier has real_embeddings."""
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
from app.core.embeddings import embed_text # noqa: PLC0415
|
||||||
|
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
encrypted = _encrypt(fernet, content)
|
||||||
|
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
user_tier = user_dbg.get("tier") or "free"
|
||||||
|
|
||||||
|
embedding: list[float] | None = None
|
||||||
|
if tier_manager.check_feature(user_tier, "real_embeddings"):
|
||||||
|
embedding = await embed_text(content)
|
||||||
|
|
||||||
|
row = MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
content_encrypted=encrypted,
|
||||||
|
embedding=embedding,
|
||||||
|
entity_type=entity_type,
|
||||||
|
entity_id=entity_id,
|
||||||
|
)
|
||||||
|
self._db.add(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info(
|
||||||
|
"memory: store_associative user=%s embedded=%s",
|
||||||
|
user_id,
|
||||||
|
embedding is not None,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: store_associative failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def upsert_relation(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
subject: str,
|
||||||
|
subject_type: str,
|
||||||
|
predicate: str,
|
||||||
|
object_: str,
|
||||||
|
object_type: str,
|
||||||
|
*,
|
||||||
|
confidence: float = 0.7,
|
||||||
|
source_episode_id: str | None = None,
|
||||||
|
notes: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Insert or update a relation row. Matches on (user_id, subject_label, predicate, object_label).
|
||||||
|
|
||||||
|
subject_label / object_label are plaintext entity identifiers — not encrypted.
|
||||||
|
notes is optional; encrypted with user Fernet if provided.
|
||||||
|
"""
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
user_tier = user_dbg.get("tier") or "free"
|
||||||
|
if not tier_manager.check_feature(user_tier, "relational_memory"):
|
||||||
|
logger.debug("memory: upsert_relation skipped (tier=%s no relational_memory)", user_tier)
|
||||||
|
return
|
||||||
|
|
||||||
|
notes_encrypted: bytes | None = None
|
||||||
|
if notes:
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet:
|
||||||
|
notes_encrypted = fernet.encrypt(notes.encode())
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryRelation).where(
|
||||||
|
MemoryRelation.user_id == user_id,
|
||||||
|
MemoryRelation.subject_label == subject,
|
||||||
|
MemoryRelation.predicate == predicate,
|
||||||
|
MemoryRelation.object_label == object_,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
existing = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if existing is not None:
|
||||||
|
existing.subject_type = subject_type
|
||||||
|
existing.object_type = object_type
|
||||||
|
existing.confidence = confidence
|
||||||
|
existing.last_confirmed_at = _now()
|
||||||
|
if notes_encrypted is not None:
|
||||||
|
existing.notes_encrypted = notes_encrypted
|
||||||
|
else:
|
||||||
|
self._db.add(MemoryRelation(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
subject_label=subject,
|
||||||
|
subject_type=subject_type,
|
||||||
|
predicate=predicate,
|
||||||
|
object_label=object_,
|
||||||
|
object_type=object_type,
|
||||||
|
confidence=confidence,
|
||||||
|
source_episode_id=source_episode_id,
|
||||||
|
notes_encrypted=notes_encrypted,
|
||||||
|
))
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info(
|
||||||
|
"memory: upsert_relation user=%s subject=%s predicate=%s object=%s",
|
||||||
|
user_id, subject, predicate, object_,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: upsert_relation failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def query_relations(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
subject: str | None = None,
|
||||||
|
predicate: str | None = None,
|
||||||
|
object_: str | None = None,
|
||||||
|
limit: int = 20,
|
||||||
|
) -> list[MemoryRelation]:
|
||||||
|
"""Query relation rows for a user with optional filters."""
|
||||||
|
q = select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
||||||
|
if subject is not None:
|
||||||
|
q = q.where(MemoryRelation.subject_label == subject)
|
||||||
|
if predicate is not None:
|
||||||
|
q = q.where(MemoryRelation.predicate == predicate)
|
||||||
|
if object_ is not None:
|
||||||
|
q = q.where(MemoryRelation.object_label == object_)
|
||||||
|
q = q.order_by(MemoryRelation.confidence.desc()).limit(limit)
|
||||||
|
result = await self._db.execute(q)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
||||||
"""Insert a long-term archival memory entry."""
|
"""Insert a long-term archival memory entry."""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
@@ -343,13 +567,26 @@ class MemoryMiddleware:
|
|||||||
|
|
||||||
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
|
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
|
||||||
"""Load lightweight user debug fields for trace logs."""
|
"""Load lightweight user debug fields for trace logs."""
|
||||||
|
from app.config.settings import settings # noqa: PLC0415
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
result = await self._db.execute(select(User).where(User.id == user_id))
|
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||||
user = result.scalar_one_or_none()
|
user = result.scalar_one_or_none()
|
||||||
if user is None:
|
if user is None:
|
||||||
return {"tier": None}
|
return {"tier": None}
|
||||||
return {
|
|
||||||
"tier": user.tier,
|
sub_result = await self._db.execute(
|
||||||
}
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub_tier: str | None = sub_result.scalar_one_or_none()
|
||||||
|
if sub_tier:
|
||||||
|
tier = sub_tier
|
||||||
|
elif settings.ENV == "dev":
|
||||||
|
tier = "power"
|
||||||
|
else:
|
||||||
|
tier = user.tier or "free"
|
||||||
|
|
||||||
|
return {"tier": tier}
|
||||||
|
|
||||||
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
@@ -364,14 +601,49 @@ class MemoryMiddleware:
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
async def _load_associative(
|
async def _load_associative(
|
||||||
self, user_id: str, message: str, fernet: Fernet
|
self, user_id: str, message: str, fernet: Fernet, *, user_tier: str = "free"
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Load top-k associative memories.
|
"""Load top-k associative memories.
|
||||||
|
|
||||||
Production: uses pgvector cosine similarity on the message embedding.
|
Pro+: pgvector cosine similarity on the message embedding (real_embeddings feature).
|
||||||
Current implementation: keyword-based fallback (no external embedding call)
|
Free / embedding failure: keyword-ordered fallback (most recent rows).
|
||||||
so tests pass without a live OpenAI key.
|
|
||||||
"""
|
"""
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
from app.core.embeddings import embed_text # noqa: PLC0415
|
||||||
|
|
||||||
|
if tier_manager.check_feature(user_tier, "real_embeddings"):
|
||||||
|
vec = await embed_text(message)
|
||||||
|
if vec is not None:
|
||||||
|
try:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryAssociative)
|
||||||
|
.where(
|
||||||
|
MemoryAssociative.user_id == user_id,
|
||||||
|
MemoryAssociative.embedding.isnot(None),
|
||||||
|
)
|
||||||
|
.order_by(MemoryAssociative.embedding.cosine_distance(vec))
|
||||||
|
.limit(_ASSOCIATIVE_TOP_K)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
logger.info(
|
||||||
|
"memory: _load_associative user=%s mode=vector hits=%d",
|
||||||
|
user_id,
|
||||||
|
len(out),
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory: vector search failed user=%s, falling back to keyword: %s",
|
||||||
|
user_id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Keyword fallback: most recent rows
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryAssociative)
|
select(MemoryAssociative)
|
||||||
.where(MemoryAssociative.user_id == user_id)
|
.where(MemoryAssociative.user_id == user_id)
|
||||||
@@ -379,7 +651,7 @@ class MemoryMiddleware:
|
|||||||
.limit(_ASSOCIATIVE_TOP_K)
|
.limit(_ASSOCIATIVE_TOP_K)
|
||||||
)
|
)
|
||||||
rows = result.scalars().all()
|
rows = result.scalars().all()
|
||||||
out: list[str] = []
|
out = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
if plaintext is not None:
|
if plaintext is not None:
|
||||||
@@ -408,6 +680,26 @@ class MemoryMiddleware:
|
|||||||
out.append(plaintext)
|
out.append(plaintext)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
async def _load_relational(self, user_id: str, *, user_tier: str = "free") -> list[str]:
|
||||||
|
"""Return top-10 relation strings for Pro+ users; empty list for Free."""
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
|
||||||
|
if not tier_manager.check_feature(user_tier, "relational_memory"):
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryRelation)
|
||||||
|
.where(MemoryRelation.user_id == user_id)
|
||||||
|
.order_by(MemoryRelation.confidence.desc())
|
||||||
|
.limit(10)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out = [
|
||||||
|
f"{r.subject_label} --{r.predicate}--> {r.object_label}"
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
return out
|
||||||
|
|
||||||
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryProactive)
|
select(MemoryProactive)
|
||||||
|
|||||||
77
app/main.py
77
app/main.py
@@ -16,13 +16,87 @@ from app.api.middleware.sanitizer import SanitizerMiddleware
|
|||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
async def _memory_audit_cron_tick() -> None:
|
||||||
|
"""Weekly cron: contradiction scan + label canonicalization for all users (Phase 7)."""
|
||||||
|
import logging # noqa: PLC0415
|
||||||
|
_log = logging.getLogger(__name__)
|
||||||
|
_log.info("memory audit cron tick: starting")
|
||||||
|
try:
|
||||||
|
from app.db import async_session # noqa: PLC0415
|
||||||
|
from app.core.memory_maintenance import audit_memory # noqa: PLC0415
|
||||||
|
from app.models import User # noqa: PLC0415
|
||||||
|
from sqlalchemy import select # noqa: PLC0415
|
||||||
|
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(select(User.id))
|
||||||
|
user_ids: list[str] = list(result.scalars().all())
|
||||||
|
|
||||||
|
for uid in user_ids:
|
||||||
|
try:
|
||||||
|
async with async_session() as db:
|
||||||
|
await audit_memory(db, uid)
|
||||||
|
except Exception as exc:
|
||||||
|
_log.warning("memory audit cron tick: audit_memory failed user=%s: %s", uid, exc)
|
||||||
|
|
||||||
|
_log.info("memory audit cron tick: done users=%d", len(user_ids))
|
||||||
|
except Exception as exc:
|
||||||
|
_log.warning("memory audit cron tick: failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _memory_cron_tick() -> None:
|
||||||
|
"""Hourly cron: drain Free-tier extraction queue + mine proactive patterns for Power+ users."""
|
||||||
|
import logging # noqa: PLC0415
|
||||||
|
_log = logging.getLogger(__name__)
|
||||||
|
_log.info("memory cron tick: starting")
|
||||||
|
try:
|
||||||
|
from app.db import async_session # noqa: PLC0415
|
||||||
|
from app.core.memory_maintenance import drain_extraction_queue, mine_proactive_patterns # noqa: PLC0415
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
from app.models import User # noqa: PLC0415
|
||||||
|
from sqlalchemy import select # noqa: PLC0415
|
||||||
|
|
||||||
|
async with async_session() as db:
|
||||||
|
await drain_extraction_queue(db)
|
||||||
|
|
||||||
|
# mine proactive patterns for every Power+ user
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(select(User.id))
|
||||||
|
user_ids: list[str] = list(result.scalars().all())
|
||||||
|
|
||||||
|
for uid in user_ids:
|
||||||
|
try:
|
||||||
|
async with async_session() as db:
|
||||||
|
tier = await tier_manager.get_tier(uid, db)
|
||||||
|
if tier_manager.check_feature(tier, "proactive_mining"):
|
||||||
|
await mine_proactive_patterns(db, uid)
|
||||||
|
except Exception as exc:
|
||||||
|
_log.warning("memory cron tick: mine_proactive_patterns failed user=%s: %s", uid, exc)
|
||||||
|
|
||||||
|
_log.info("memory cron tick: done users=%d", len(user_ids))
|
||||||
|
except Exception as exc:
|
||||||
|
_log.warning("memory cron tick: failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Startup: ensure agent tool modules are loaded.
|
# Startup: ensure agent tool modules are loaded.
|
||||||
import app.agents # noqa: F401
|
import app.agents # noqa: F401
|
||||||
|
|
||||||
|
scheduler = None
|
||||||
|
if settings.SCHEDULER_ENABLED:
|
||||||
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler # noqa: PLC0415
|
||||||
|
|
||||||
|
scheduler = AsyncIOScheduler()
|
||||||
|
scheduler.add_job(_memory_cron_tick, "interval", hours=1, id="memory_cron")
|
||||||
|
scheduler.add_job(_memory_audit_cron_tick, "interval", weeks=1, id="memory_audit_cron")
|
||||||
|
scheduler.start()
|
||||||
|
logging.getLogger(__name__).info("memory cron scheduler started (interval=1h)")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
if scheduler is not None:
|
||||||
|
scheduler.shutdown(wait=False)
|
||||||
|
|
||||||
# Shutdown: dispose SQLAlchemy connection pool
|
# Shutdown: dispose SQLAlchemy connection pool
|
||||||
from app.db import engine
|
from app.db import engine
|
||||||
await engine.dispose()
|
await engine.dispose()
|
||||||
@@ -50,13 +124,14 @@ def create_app() -> FastAPI:
|
|||||||
app.add_middleware(SanitizerMiddleware)
|
app.add_middleware(SanitizerMiddleware)
|
||||||
app.add_middleware(TierRateLimitMiddleware)
|
app.add_middleware(TierRateLimitMiddleware)
|
||||||
|
|
||||||
from app.api.routes import agents, auth, billing, chat, device_ws
|
from app.api.routes import agents, auth, billing, chat, device_ws, memory
|
||||||
|
|
||||||
app.include_router(auth.router, prefix="/api/v1")
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
app.include_router(chat.router, prefix="/api/v1")
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
app.include_router(billing.router, prefix="/api/v1")
|
app.include_router(billing.router, prefix="/api/v1")
|
||||||
app.include_router(agents.router, prefix="/api/v1")
|
app.include_router(agents.router, prefix="/api/v1")
|
||||||
app.include_router(device_ws.router, prefix="/api/v1")
|
app.include_router(device_ws.router, prefix="/api/v1")
|
||||||
|
app.include_router(memory.router, prefix="/api/v1")
|
||||||
|
|
||||||
@app.get("/api/v1/health", tags=["health"])
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
async def health() -> dict:
|
async def health() -> dict:
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ Table inventory:
|
|||||||
memory_associative — per-user semantic memory with embeddings (encrypted)
|
memory_associative — per-user semantic memory with embeddings (encrypted)
|
||||||
memory_episodic — per-user session summaries (encrypted)
|
memory_episodic — per-user session summaries (encrypted)
|
||||||
memory_proactive — per-user behavioral patterns (encrypted)
|
memory_proactive — per-user behavioral patterns (encrypted)
|
||||||
|
memory_relations — per-user entity/relation graph (Mem0g-light, Phase 3)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -21,6 +22,7 @@ from __future__ import annotations
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from pgvector.sqlalchemy import Vector
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Boolean,
|
Boolean,
|
||||||
DateTime,
|
DateTime,
|
||||||
@@ -29,6 +31,7 @@ from sqlalchemy import (
|
|||||||
ForeignKey,
|
ForeignKey,
|
||||||
Integer,
|
Integer,
|
||||||
JSON,
|
JSON,
|
||||||
|
LargeBinary,
|
||||||
String,
|
String,
|
||||||
Text,
|
Text,
|
||||||
Uuid,
|
Uuid,
|
||||||
@@ -70,7 +73,7 @@ class User(Base):
|
|||||||
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
password_hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
password_hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
avatar_url: Mapped[str | None] = mapped_column(String(2048), nullable=True)
|
avatar_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||||
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
||||||
@@ -299,8 +302,8 @@ class MemoryAssociative(Base):
|
|||||||
nullable=False, index=True,
|
nullable=False, index=True,
|
||||||
)
|
)
|
||||||
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
# JSON-encoded float list in SQLite tests; vector(1536) in Postgres via migration.
|
# vector(1536) via pgvector; SQLite tests use NULL embeddings so no dialect issue.
|
||||||
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
embedding: Mapped[list | None] = mapped_column(Vector(1536), nullable=True)
|
||||||
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
@@ -348,3 +351,85 @@ class MemoryProactive(Base):
|
|||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractionQueue(Base):
|
||||||
|
"""Batch extraction queue for Free-tier users (Phase 2).
|
||||||
|
|
||||||
|
Pro/Power/Team users get realtime asyncio.create_task() extraction.
|
||||||
|
Free users get a queue row here; a daily cron (Phase 5) drains it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "extraction_queue"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
episode_id: Mapped[str | None] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), nullable=True,
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRelation(Base):
|
||||||
|
"""Per-user entity/relation graph row (Mem0g-light, Phase 3).
|
||||||
|
|
||||||
|
subject_label/object_label are plaintext entity identifiers (not user content).
|
||||||
|
notes_encrypted is optional Fernet-encrypted per-user commentary.
|
||||||
|
confidence in [0.0, 1.0] — decays 5 % per 30 days since last_confirmed_at.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "memory_relations"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
subject_label: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||||
|
subject_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||||
|
predicate: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
object_label: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||||
|
object_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||||
|
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.7)
|
||||||
|
source_episode_id: Mapped[str | None] = mapped_column(
|
||||||
|
Uuid(as_uuid=False),
|
||||||
|
ForeignKey("memory_episodic.id", ondelete="SET NULL"),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
notes_encrypted: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
last_confirmed_at: Mapped[datetime | None] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Plugin(Base):
|
||||||
|
"""Plugin marketplace catalog entry."""
|
||||||
|
|
||||||
|
__tablename__ = "plugins"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
description: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
version: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||||
|
author_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
category: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]")
|
||||||
|
status: Mapped[str] = mapped_column(String(50), nullable=False, default="pending")
|
||||||
|
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
|
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|||||||
@@ -31,10 +31,17 @@ class UserProfile(BaseModel):
|
|||||||
surname: str | None = None
|
surname: str | None = None
|
||||||
tier: BillingTier
|
tier: BillingTier
|
||||||
avatar_url: str | None = None
|
avatar_url: str | None = None
|
||||||
|
has_password: bool = True
|
||||||
onboarding_completed_at: int | None = None # epoch ms, null = not onboarded
|
onboarding_completed_at: int | None = None # epoch ms, null = not onboarded
|
||||||
memory: dict[str, str] = Field(default_factory=dict) # decrypted core memory k/v
|
memory: dict[str, str] = Field(default_factory=dict) # decrypted core memory k/v
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccountInfo(BaseModel):
|
||||||
|
provider: str
|
||||||
|
provider_email: str | None = None
|
||||||
|
created_at: int # epoch ms
|
||||||
|
|
||||||
|
|
||||||
# ── Chat ─────────────────────────────────────────────────────────────
|
# ── Chat ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
class ChatContext(BaseModel):
|
class ChatContext(BaseModel):
|
||||||
|
|||||||
@@ -32,8 +32,10 @@ google-auth-oauthlib>=1.2.0
|
|||||||
google-auth-httplib2>=0.2.0
|
google-auth-httplib2>=0.2.0
|
||||||
msal>=1.28.0
|
msal>=1.28.0
|
||||||
cryptography>=42.0.0
|
cryptography>=42.0.0
|
||||||
|
pgvector>=0.2.5
|
||||||
langfuse>=2.0.0
|
langfuse>=2.0.0
|
||||||
beautifulsoup4>=4.12.0
|
beautifulsoup4>=4.12.0
|
||||||
lxml>=5.0.0
|
lxml>=5.0.0
|
||||||
PyYAML>=6.0.0
|
PyYAML>=6.0.0
|
||||||
|
apscheduler>=3.10.0
|
||||||
ruff>=0.8.0
|
ruff>=0.8.0
|
||||||
|
|||||||
1
results.xml
Normal file
1
results.xml
Normal file
File diff suppressed because one or more lines are too long
@@ -1,808 +0,0 @@
|
|||||||
"""Tests for Step 3.4: agent_runner module.
|
|
||||||
|
|
||||||
Coverage:
|
|
||||||
Unit:
|
|
||||||
- _is_overdue — cron schedule overdue detection
|
|
||||||
- _extract_items_from_content — LLM extraction + JSON parsing + validation
|
|
||||||
- _send_insert_to_client — tool_call frame construction + timeout
|
|
||||||
- run_local_agent — end-to-end local agent happy path
|
|
||||||
- run_local_agent — device offline path
|
|
||||||
- run_local_agent — file-read timeout path
|
|
||||||
- run_local_agent — LLM extraction error path
|
|
||||||
- run_cloud_agent — stub returns error immediately
|
|
||||||
- trigger_pending_runs — skipped when config is client-owned
|
|
||||||
- trigger_pending_runs — non-overdue skipped
|
|
||||||
- trigger_pending_runs — device_id filter for local agents
|
|
||||||
|
|
||||||
Integration:
|
|
||||||
- POST /agents/can-create — billing eligibility check
|
|
||||||
- POST /agents/trigger — creates run log + dispatches background task
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.core.agent_runner import (
|
|
||||||
_extract_items_from_content,
|
|
||||||
_is_overdue,
|
|
||||||
_send_insert_to_client,
|
|
||||||
run_cloud_agent,
|
|
||||||
run_local_agent,
|
|
||||||
trigger_pending_runs,
|
|
||||||
)
|
|
||||||
from app.core.device_manager import DeviceConnectionManager
|
|
||||||
from app.db import get_session
|
|
||||||
from app.main import app
|
|
||||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
|
||||||
from tests.conftest import TEST_USER_IDS, auth_header
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
_FREE_UID = TEST_USER_IDS["free"]
|
|
||||||
_PRO_UID = TEST_USER_IDS["pro"]
|
|
||||||
|
|
||||||
|
|
||||||
def _make_local_config(user_id: str = _FREE_UID, device_id: str = "dev-001") -> LocalAgentConfig:
|
|
||||||
return LocalAgentConfig(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=user_id,
|
|
||||||
device_id=device_id,
|
|
||||||
name="Test Local Agent",
|
|
||||||
directory_paths=["/home/user/emails"],
|
|
||||||
data_types=["tasks", "notes"],
|
|
||||||
prompt_template="Extract tasks and notes from this document.",
|
|
||||||
file_extensions=[".txt", ".eml"],
|
|
||||||
schedule_cron="0 */6 * * *",
|
|
||||||
enabled=True,
|
|
||||||
last_run_at=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_cloud_config(user_id: str = _FREE_UID) -> CloudAgentConfig:
|
|
||||||
return CloudAgentConfig(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=user_id,
|
|
||||||
provider="gmail",
|
|
||||||
name="Test Gmail Agent",
|
|
||||||
data_types=["tasks"],
|
|
||||||
prompt_template="Extract tasks from email.",
|
|
||||||
schedule_cron="0 */6 * * *",
|
|
||||||
enabled=True,
|
|
||||||
last_run_at=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_run_log(agent_id: str, agent_type: str = "local", user_id: str = _FREE_UID) -> AgentRunLog:
|
|
||||||
return AgentRunLog(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
agent_id=agent_id,
|
|
||||||
agent_type=agent_type,
|
|
||||||
user_id=user_id,
|
|
||||||
status="running",
|
|
||||||
started_at=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_manager(user_id: str = _FREE_UID, device_id: str = "dev-001") -> DeviceConnectionManager:
|
|
||||||
mgr = DeviceConnectionManager()
|
|
||||||
ws = MagicMock()
|
|
||||||
ws.send_text = AsyncMock()
|
|
||||||
mgr.register(user_id, device_id, ws)
|
|
||||||
return mgr
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _is_overdue
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def test_is_overdue_never_run():
|
|
||||||
"""An agent that has never run is always overdue."""
|
|
||||||
assert _is_overdue("0 */6 * * *", None) is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_overdue_very_recently_run():
|
|
||||||
"""An agent that just ran is not overdue."""
|
|
||||||
last = datetime.now(timezone.utc)
|
|
||||||
assert _is_overdue("0 */6 * * *", last) is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_overdue_long_ago():
|
|
||||||
"""An agent last run 2 days ago with a 6-hour schedule is overdue."""
|
|
||||||
from datetime import timedelta
|
|
||||||
last = datetime.now(timezone.utc) - timedelta(days=2)
|
|
||||||
assert _is_overdue("0 */6 * * *", last) is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_overdue_invalid_cron_returns_false():
|
|
||||||
"""Unparseable cron must not raise and should return False (fail-safe)."""
|
|
||||||
assert _is_overdue("not a cron", None) is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_overdue_naive_datetime():
|
|
||||||
"""Naive datetime objects are handled without raising."""
|
|
||||||
from datetime import timedelta
|
|
||||||
last = datetime.utcnow() - timedelta(days=1) # naive
|
|
||||||
# Should not raise.
|
|
||||||
result = _is_overdue("0 */6 * * *", last)
|
|
||||||
assert isinstance(result, bool)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _extract_items_from_content
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_extract_items_happy_path():
|
|
||||||
"""LLM returns valid JSON array; items with allowed tables are returned."""
|
|
||||||
mock_llm = MagicMock()
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.content = json.dumps([
|
|
||||||
{"table": "tasks", "data": {"title": "Buy milk", "priority": "high"}},
|
|
||||||
{"table": "notes", "data": {"title": "Meeting recap", "content": "Discussed roadmap"}},
|
|
||||||
])
|
|
||||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
|
||||||
items = await _extract_items_from_content(
|
|
||||||
"Extract tasks and notes.",
|
|
||||||
"Email body: Buy milk urgently. Notes from meeting: discussed roadmap.",
|
|
||||||
["tasks", "notes"],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(items) == 2
|
|
||||||
assert items[0]["table"] == "tasks"
|
|
||||||
assert items[0]["data"]["title"] == "Buy milk"
|
|
||||||
assert items[1]["table"] == "notes"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_extract_items_strips_forbidden_fields():
|
|
||||||
"""Fields like id, createdAt, isAiSuggested must be stripped from extracted data."""
|
|
||||||
mock_llm = MagicMock()
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.content = json.dumps([
|
|
||||||
{
|
|
||||||
"table": "tasks",
|
|
||||||
"data": {
|
|
||||||
"title": "Review PR",
|
|
||||||
"id": "should-be-removed",
|
|
||||||
"createdAt": 99999,
|
|
||||||
"isAiSuggested": 0,
|
|
||||||
"isApproved": 1,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
])
|
|
||||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
|
||||||
items = await _extract_items_from_content("Extract tasks.", "Review the PR.", ["tasks"])
|
|
||||||
|
|
||||||
assert len(items) == 1
|
|
||||||
data = items[0]["data"]
|
|
||||||
assert "id" not in data
|
|
||||||
assert "createdAt" not in data
|
|
||||||
assert "isAiSuggested" not in data
|
|
||||||
assert "isApproved" not in data
|
|
||||||
assert data["title"] == "Review PR"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_extract_items_invalid_json_returns_empty():
|
|
||||||
"""LLM returning invalid JSON must return empty list without raising."""
|
|
||||||
mock_llm = MagicMock()
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.content = "Sorry, I cannot extract anything."
|
|
||||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
|
||||||
items = await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
|
|
||||||
|
|
||||||
assert items == []
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_extract_items_disallowed_table_filtered():
|
|
||||||
"""Items whose table is not in data_types are discarded."""
|
|
||||||
mock_llm = MagicMock()
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.content = json.dumps([
|
|
||||||
{"table": "tasks", "data": {"title": "Valid task"}},
|
|
||||||
{"table": "projects", "data": {"name": "Should be filtered"}},
|
|
||||||
])
|
|
||||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
|
||||||
# Only "tasks" is in data_types — "projects" should be filtered.
|
|
||||||
items = await _extract_items_from_content("Extract.", "content", ["tasks"])
|
|
||||||
|
|
||||||
assert len(items) == 1
|
|
||||||
assert items[0]["table"] == "tasks"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_extract_items_empty_data_types_returns_empty():
|
|
||||||
"""If no allowed data_types match, skip LLM call and return immediately."""
|
|
||||||
mock_llm = MagicMock()
|
|
||||||
mock_llm.ainvoke = AsyncMock()
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
|
||||||
items = await _extract_items_from_content("Extract.", "content", [])
|
|
||||||
|
|
||||||
mock_llm.ainvoke.assert_not_called()
|
|
||||||
assert items == []
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_extract_items_llm_error_propagates():
|
|
||||||
"""LLM API errors propagate so the caller (run_local_agent) can record them."""
|
|
||||||
mock_llm = MagicMock()
|
|
||||||
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("API unavailable"))
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
|
||||||
with pytest.raises(RuntimeError, match="API unavailable"):
|
|
||||||
await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _send_insert_to_client
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_send_insert_to_client_happy_path():
|
|
||||||
"""Frame is sent with isAiSuggested/isApproved added; result is returned."""
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
sent_payloads: list[dict] = []
|
|
||||||
original_send = mgr.send_frame
|
|
||||||
|
|
||||||
async def _capture_send(uid: str, frame: dict) -> None:
|
|
||||||
sent_payloads.append(frame)
|
|
||||||
# Immediately resolve the pending call with a success result.
|
|
||||||
call_id = frame["id"]
|
|
||||||
mgr.resolve_pending_call(uid, call_id, {"row": {"id": "new-id", "title": "Buy milk"}})
|
|
||||||
|
|
||||||
mgr.send_frame = _capture_send # type: ignore[method-assign]
|
|
||||||
|
|
||||||
result = await _send_insert_to_client(
|
|
||||||
_FREE_UID, "tasks", {"title": "Buy milk", "priority": "high"}, mgr
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(sent_payloads) == 1
|
|
||||||
payload = sent_payloads[0]
|
|
||||||
assert payload["action"] == "insert"
|
|
||||||
assert payload["table"] == "tasks"
|
|
||||||
assert payload["data"]["title"] == "Buy milk"
|
|
||||||
assert payload["data"]["isAiSuggested"] == 1
|
|
||||||
assert payload["data"]["isApproved"] == 0
|
|
||||||
assert result["row"]["title"] == "Buy milk"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_send_insert_to_client_timeout():
|
|
||||||
"""asyncio.TimeoutError is raised when Electron does not respond."""
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
async def _slow_send(uid: str, frame: dict) -> None:
|
|
||||||
# Never resolve the pending call.
|
|
||||||
pass
|
|
||||||
|
|
||||||
mgr.send_frame = _slow_send # type: ignore[method-assign]
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._INSERT_TIMEOUT", 0.05):
|
|
||||||
with pytest.raises(asyncio.TimeoutError):
|
|
||||||
await _send_insert_to_client(_FREE_UID, "tasks", {"title": "X"}, mgr)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# run_local_agent
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_local_agent_device_offline():
|
|
||||||
"""run_local_agent marks run as error when device is offline."""
|
|
||||||
config = _make_local_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = DeviceConnectionManager() # Empty — no device registered.
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
|
||||||
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
mock_finalize.assert_called_once()
|
|
||||||
_args, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "error"
|
|
||||||
assert any("not connected" in e for e in kwargs["errors"])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_local_agent_happy_path():
|
|
||||||
"""End-to-end: files received, LLM extracts one task, insert sent + ack'd."""
|
|
||||||
config = _make_local_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
# Build a fake agent_data frame (will be queued after send).
|
|
||||||
file_frame = {
|
|
||||||
"type": "agent_data",
|
|
||||||
"run_id": run_log.id,
|
|
||||||
"files": [{"path": "/email.eml", "content": "Urgent: fix the bug by Friday."}],
|
|
||||||
}
|
|
||||||
agent_complete_frame = None # sentinel
|
|
||||||
|
|
||||||
sent_frames: list[dict] = []
|
|
||||||
|
|
||||||
async def _mock_send(uid: str, frame: dict) -> None:
|
|
||||||
sent_frames.append(frame)
|
|
||||||
if frame.get("type") == "agent_run":
|
|
||||||
# Simulate Electron responding with file data then agent_complete.
|
|
||||||
q = mgr.get_agent_data_queue(uid, frame["run_id"])
|
|
||||||
await q.put(file_frame)
|
|
||||||
await q.put(agent_complete_frame)
|
|
||||||
elif frame.get("type") == "tool_call":
|
|
||||||
# Resolve the pending insert immediately.
|
|
||||||
mgr.resolve_pending_call(uid, frame["id"], {"row": {"id": "new-task", "title": "Fix the bug"}})
|
|
||||||
|
|
||||||
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
|
||||||
|
|
||||||
mock_llm = MagicMock()
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.content = json.dumps([
|
|
||||||
{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}
|
|
||||||
])
|
|
||||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
|
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
|
||||||
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
mock_finalize.assert_called_once()
|
|
||||||
_args, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "success"
|
|
||||||
assert kwargs["items_processed"] == 1
|
|
||||||
assert kwargs["items_created"] == 1
|
|
||||||
assert kwargs["errors"] == []
|
|
||||||
assert kwargs["update_config_last_run"] is False
|
|
||||||
|
|
||||||
# Verify agent_run frame was sent.
|
|
||||||
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
|
|
||||||
assert len(agent_run_frames) == 1
|
|
||||||
assert agent_run_frames[0]["agent_id"] == config.id
|
|
||||||
assert "paths" in agent_run_frames[0]["config"]
|
|
||||||
|
|
||||||
# Verify insert frame was sent with AI flags.
|
|
||||||
insert_frames = [f for f in sent_frames if f.get("type") == "tool_call"]
|
|
||||||
assert len(insert_frames) == 1
|
|
||||||
assert insert_frames[0]["data"]["isAiSuggested"] == 1
|
|
||||||
assert insert_frames[0]["data"]["isApproved"] == 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_local_agent_file_read_timeout():
|
|
||||||
"""run_local_agent marks run as partial/error when device stops sending files."""
|
|
||||||
config = _make_local_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
async def _mock_send(uid: str, frame: dict) -> None:
|
|
||||||
# Don't put anything in the queue — simulate stalled device.
|
|
||||||
pass
|
|
||||||
|
|
||||||
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._FILE_READ_TIMEOUT", 0.1), \
|
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
|
||||||
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
mock_finalize.assert_called_once()
|
|
||||||
_args, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "error" # No items created, so error (not partial).
|
|
||||||
assert any("timed out" in e.lower() for e in kwargs["errors"])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_local_agent_llm_extraction_error():
|
|
||||||
"""LLM errors per-file are recorded; run continues for remaining files."""
|
|
||||||
config = _make_local_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
file_frame = {
|
|
||||||
"type": "agent_data",
|
|
||||||
"run_id": run_log.id,
|
|
||||||
"files": [
|
|
||||||
{"path": "/file1.eml", "content": "Email one."},
|
|
||||||
{"path": "/file2.eml", "content": "Email two."},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _mock_send(uid: str, frame: dict) -> None:
|
|
||||||
if frame.get("type") == "agent_run":
|
|
||||||
q = mgr.get_agent_data_queue(uid, frame["run_id"])
|
|
||||||
await q.put(file_frame)
|
|
||||||
await q.put(None) # agent_complete sentinel
|
|
||||||
|
|
||||||
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
|
||||||
|
|
||||||
mock_llm = MagicMock()
|
|
||||||
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM boom"))
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
|
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
|
||||||
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_args, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "error"
|
|
||||||
assert kwargs["items_processed"] == 2 # Both files attempted.
|
|
||||||
assert kwargs["items_created"] == 0
|
|
||||||
assert len(kwargs["errors"]) == 2 # One error per file.
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# run_cloud_agent (stub)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_cloud_agent_device_offline():
|
|
||||||
"""Cloud agent aborts immediately when no device is connected."""
|
|
||||||
config = _make_cloud_config()
|
|
||||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
|
||||||
mgr = DeviceConnectionManager() # empty — no devices registered
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
|
||||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
mock_finalize.assert_called_once()
|
|
||||||
_, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "error"
|
|
||||||
assert any("device" in e.lower() or "connected" in e.lower() for e in kwargs["errors"])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_cloud_agent_no_oauth_token():
|
|
||||||
"""Cloud agent errors when no OAuth token is stored."""
|
|
||||||
config = _make_cloud_config()
|
|
||||||
config.oauth_token_encrypted = None
|
|
||||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
|
||||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "error"
|
|
||||||
assert any("oauth" in e.lower() or "token" in e.lower() for e in kwargs["errors"])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_cloud_agent_token_decrypt_failure():
|
|
||||||
"""Cloud agent errors gracefully when the stored token cannot be decrypted."""
|
|
||||||
config = _make_cloud_config()
|
|
||||||
config.oauth_token_encrypted = "this-is-not-valid-fernet-ciphertext"
|
|
||||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
from cryptography.fernet import Fernet as _Fernet
|
|
||||||
valid_key = _Fernet.generate_key().decode()
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
|
||||||
patch("app.integrations.settings") as mock_settings:
|
|
||||||
mock_settings.OAUTH_ENCRYPTION_KEY = valid_key
|
|
||||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "error"
|
|
||||||
assert any("decrypt" in e.lower() for e in kwargs["errors"])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_cloud_agent_happy_path_gmail():
|
|
||||||
"""Cloud agent happy path: Gmail fetch → LLM extraction → inserts → success."""
|
|
||||||
from app.integrations import EmailMessage, encrypt_token
|
|
||||||
from cryptography.fernet import Fernet as _Fernet
|
|
||||||
|
|
||||||
fernet_key = _Fernet.generate_key().decode()
|
|
||||||
credentials = {
|
|
||||||
"token": "access_abc",
|
|
||||||
"refresh_token": "refresh_xyz",
|
|
||||||
"token_uri": "https://oauth2.googleapis.com/token",
|
|
||||||
"client_id": "cid",
|
|
||||||
"client_secret": "csec",
|
|
||||||
}
|
|
||||||
|
|
||||||
config = _make_cloud_config()
|
|
||||||
config.provider = "gmail"
|
|
||||||
config.prompt_template = "Extract tasks from this email."
|
|
||||||
config.data_types = ["tasks"]
|
|
||||||
|
|
||||||
with patch("app.integrations.settings") as ms:
|
|
||||||
ms.OAUTH_ENCRYPTION_KEY = fernet_key
|
|
||||||
config.oauth_token_encrypted = encrypt_token(credentials)
|
|
||||||
|
|
||||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
sample_email = EmailMessage(
|
|
||||||
id="msg001",
|
|
||||||
subject="Action required",
|
|
||||||
sender="boss@company.com",
|
|
||||||
body_text="Please fix the bug by Friday.",
|
|
||||||
date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
|
|
||||||
)
|
|
||||||
|
|
||||||
extracted_items = [{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}]
|
|
||||||
|
|
||||||
with patch("app.integrations.settings") as mock_int_settings, \
|
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
|
||||||
patch("app.core.agent_runner._extract_items_from_content", new_callable=AsyncMock, return_value=extracted_items) as mock_extract, \
|
|
||||||
patch("app.core.agent_runner._send_insert_to_client", new_callable=AsyncMock, return_value={"ok": True}) as mock_insert, \
|
|
||||||
patch("app.core.agent_runner.async_session"):
|
|
||||||
mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
|
|
||||||
|
|
||||||
mock_gmail = AsyncMock()
|
|
||||||
mock_gmail.fetch_messages = AsyncMock(return_value=[sample_email])
|
|
||||||
mock_gmail.refreshed_credentials = None
|
|
||||||
|
|
||||||
with patch("app.integrations.decrypt_token", return_value=credentials), \
|
|
||||||
patch("app.integrations.get_provider", return_value=mock_gmail):
|
|
||||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
mock_extract.assert_called_once()
|
|
||||||
mock_insert.assert_called_once()
|
|
||||||
_, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "success"
|
|
||||||
assert kwargs["items_processed"] == 1
|
|
||||||
assert kwargs["items_created"] == 1
|
|
||||||
assert kwargs["config_type"] == "cloud"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_cloud_agent_provider_fetch_error():
|
|
||||||
"""Cloud agent records error status when provider fetch raises RuntimeError."""
|
|
||||||
credentials = {"token": "abc"}
|
|
||||||
config = _make_cloud_config()
|
|
||||||
config.oauth_token_encrypted = "some_encrypted_value" # non-empty so decrypt step is reached
|
|
||||||
config.prompt_template = "Extract tasks."
|
|
||||||
config.data_types = ["tasks"]
|
|
||||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
mock_provider = AsyncMock()
|
|
||||||
mock_provider.fetch_messages = AsyncMock(side_effect=RuntimeError("API quota exceeded"))
|
|
||||||
mock_provider.refreshed_credentials = None
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
|
||||||
patch("app.integrations.decrypt_token", return_value=credentials), \
|
|
||||||
patch("app.integrations.get_provider", return_value=mock_provider), \
|
|
||||||
patch("app.core.agent_runner.async_session"):
|
|
||||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "error"
|
|
||||||
assert any("quota" in e.lower() or "fetch" in e.lower() for e in kwargs["errors"])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_cloud_agent_refreshed_token_persisted():
|
|
||||||
"""When the provider refreshes its token, the new ciphertext is written to DB."""
|
|
||||||
from app.integrations import encrypt_token
|
|
||||||
from cryptography.fernet import Fernet as _Fernet
|
|
||||||
|
|
||||||
fernet_key = _Fernet.generate_key().decode()
|
|
||||||
credentials = {"token": "old_token", "refresh_token": "rt_old"}
|
|
||||||
fresh_credentials = {"token": "new_token", "refresh_token": "rt_new"}
|
|
||||||
|
|
||||||
config = _make_cloud_config()
|
|
||||||
config.prompt_template = "Extract tasks."
|
|
||||||
config.data_types = ["tasks"]
|
|
||||||
|
|
||||||
with patch("app.integrations.settings") as ms:
|
|
||||||
ms.OAUTH_ENCRYPTION_KEY = fernet_key
|
|
||||||
config.oauth_token_encrypted = encrypt_token(credentials)
|
|
||||||
|
|
||||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
mock_provider = AsyncMock()
|
|
||||||
mock_provider.fetch_messages = AsyncMock(return_value=[])
|
|
||||||
mock_provider.refreshed_credentials = fresh_credentials # token was refreshed
|
|
||||||
|
|
||||||
# Track DB writes via mock async_session.
|
|
||||||
mock_cfg_row = MagicMock()
|
|
||||||
mock_cfg_row.oauth_token_encrypted = None
|
|
||||||
|
|
||||||
mock_db = AsyncMock()
|
|
||||||
mock_db.__aenter__ = AsyncMock(return_value=mock_db)
|
|
||||||
mock_db.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_db.scalar_one_or_none = AsyncMock(return_value=mock_cfg_row)
|
|
||||||
cfg_result = MagicMock()
|
|
||||||
cfg_result.scalar_one_or_none.return_value = mock_cfg_row
|
|
||||||
mock_db.execute = AsyncMock(return_value=cfg_result)
|
|
||||||
mock_db.commit = AsyncMock()
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock), \
|
|
||||||
patch("app.integrations.decrypt_token", return_value=credentials), \
|
|
||||||
patch("app.integrations.get_provider", return_value=mock_provider), \
|
|
||||||
patch("app.integrations.encrypt_token", return_value="new_encrypted") as mock_encrypt, \
|
|
||||||
patch("app.core.agent_runner.async_session", return_value=mock_db), \
|
|
||||||
patch("app.integrations.settings") as mock_int_settings:
|
|
||||||
mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
|
|
||||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
# The new encrypted token should have been written to the config row.
|
|
||||||
mock_encrypt.assert_called_once_with(fresh_credentials)
|
|
||||||
assert mock_cfg_row.oauth_token_encrypted == "new_encrypted"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_finalize_run_updates_cloud_config_last_run_at():
|
|
||||||
"""_finalize_run with config_type='cloud' updates CloudAgentConfig.last_run_at."""
|
|
||||||
from app.core.agent_runner import _finalize_run
|
|
||||||
|
|
||||||
run_log = _make_run_log(str(uuid.uuid4()), agent_type="cloud")
|
|
||||||
run_log.id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
mock_cfg = MagicMock()
|
|
||||||
mock_cfg.last_run_at = None
|
|
||||||
|
|
||||||
cfg_result = MagicMock()
|
|
||||||
cfg_result.scalar_one_or_none.return_value = mock_cfg
|
|
||||||
|
|
||||||
mock_db = AsyncMock()
|
|
||||||
mock_db.__aenter__ = AsyncMock(return_value=mock_db)
|
|
||||||
mock_db.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_db.merge = AsyncMock(return_value=run_log)
|
|
||||||
mock_db.execute = AsyncMock(return_value=cfg_result)
|
|
||||||
mock_db.commit = AsyncMock()
|
|
||||||
|
|
||||||
config_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.async_session", return_value=mock_db):
|
|
||||||
await _finalize_run(
|
|
||||||
run_log,
|
|
||||||
status="success",
|
|
||||||
update_config_last_run=True,
|
|
||||||
config_id=config_id,
|
|
||||||
config_type="cloud",
|
|
||||||
)
|
|
||||||
|
|
||||||
# CloudAgentConfig.last_run_at should have been set.
|
|
||||||
assert mock_cfg.last_run_at is not None
|
|
||||||
mock_db.commit.assert_called()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# trigger_pending_runs
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trigger_pending_runs_no_overdue():
|
|
||||||
"""Pending-run scan is skipped because agent config is client-owned."""
|
|
||||||
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
|
||||||
|
|
||||||
mock_run.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trigger_pending_runs_device_id_filter():
|
|
||||||
"""Device filtering is no longer backend-managed in pending runs."""
|
|
||||||
|
|
||||||
mgr = _make_manager(device_id="dev-001")
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
|
||||||
|
|
||||||
mock_run.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trigger_pending_runs_dispatches_overdue():
|
|
||||||
"""No pending runs are dispatched by backend after config deprecation."""
|
|
||||||
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
|
||||||
|
|
||||||
mock_run.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Integration: POST /agents/can-create and /agents/trigger
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def _override_db(db_session):
|
|
||||||
"""Route all get_session calls to the test SQLite session."""
|
|
||||||
|
|
||||||
async def _gen():
|
|
||||||
yield db_session
|
|
||||||
|
|
||||||
app.dependency_overrides[get_session] = _gen
|
|
||||||
yield
|
|
||||||
app.dependency_overrides.pop(get_session, None)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_can_create_agent_allows_when_under_limit(client):
|
|
||||||
"""POST /agents/can-create returns allowed=True when under tier limit."""
|
|
||||||
resp = client.post(
|
|
||||||
"/api/v1/agents/can-create",
|
|
||||||
json={"active_agents": 0},
|
|
||||||
headers=auth_header("free"),
|
|
||||||
)
|
|
||||||
assert resp.status_code == 200
|
|
||||||
body = resp.json()
|
|
||||||
assert body["allowed"] is True
|
|
||||||
assert body["tier"] == "free"
|
|
||||||
assert body["active_agents"] == 0
|
|
||||||
assert body["limit"] == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_can_create_agent_denies_when_at_limit(client):
|
|
||||||
"""POST /agents/can-create returns allowed=False at free-tier limit."""
|
|
||||||
resp = client.post(
|
|
||||||
"/api/v1/agents/can-create",
|
|
||||||
json={"active_agents": 2},
|
|
||||||
headers=auth_header("free"),
|
|
||||||
)
|
|
||||||
assert resp.status_code == 200
|
|
||||||
body = resp.json()
|
|
||||||
assert body["allowed"] is False
|
|
||||||
assert body["limit"] == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
|
||||||
"""POST /agents/trigger creates a local run log and dispatches background task."""
|
|
||||||
dispatched: list[tuple[str, str]] = []
|
|
||||||
|
|
||||||
async def _fake_run(user_id, cfg, run_log, device_mgr):
|
|
||||||
dispatched.append((user_id, cfg.id))
|
|
||||||
|
|
||||||
def _fake_create_task(coro):
|
|
||||||
coro.close()
|
|
||||||
return MagicMock()
|
|
||||||
|
|
||||||
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
|
|
||||||
patch("asyncio.create_task") as mock_create_task:
|
|
||||||
mock_create_task.side_effect = _fake_create_task
|
|
||||||
resp = client.post(
|
|
||||||
"/api/v1/agents/trigger",
|
|
||||||
json={
|
|
||||||
"directory": "/home/user/docs",
|
|
||||||
"what_to_extract": ["task", "note"],
|
|
||||||
"batch_interval": "0 */6 * * *",
|
|
||||||
"custom_agent_prompt": "Extract tasks and notes.",
|
|
||||||
"active_agents": 0,
|
|
||||||
},
|
|
||||||
headers=auth_header("power"),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert resp.status_code == 202
|
|
||||||
data = resp.json()
|
|
||||||
assert isinstance(data["agent_id"], str)
|
|
||||||
assert data["agent_id"]
|
|
||||||
assert data["status"] == "running"
|
|
||||||
assert data["agent_type"] == "local"
|
|
||||||
|
|
||||||
# Verify create_task was called (dispatching background run).
|
|
||||||
mock_create_task.assert_called_once()
|
|
||||||
@@ -1,242 +0,0 @@
|
|||||||
"""Tests for the Chatbot Journey endpoints.
|
|
||||||
|
|
||||||
Covers:
|
|
||||||
1. Start journey for local agent → session_id + first question, done=False
|
|
||||||
2. Start journey for cloud agent → contextual email-focused question
|
|
||||||
3. Start journey with existing agent_id → session seeded, first question returned
|
|
||||||
4. Start journey with non-existent agent_id → still succeeds (graceful fallback)
|
|
||||||
5. Message: continue conversation → done=False, follow-up question returned
|
|
||||||
6. Message: LLM wraps up → done=True + prompt_template extracted correctly
|
|
||||||
7. Message with max-turns nudge → no crash, returns response
|
|
||||||
8. Invalid session_id → 404
|
|
||||||
9. Expired session → 404
|
|
||||||
10. Session ownership: user B cannot access user A's session
|
|
||||||
11. No JWT on /start → 401
|
|
||||||
12. No JWT on /message → 401
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.api.routes.agent_setup import (
|
|
||||||
_SESSION_TTL_SECONDS,
|
|
||||||
_TEMPLATE_END,
|
|
||||||
_TEMPLATE_START,
|
|
||||||
_extract_template,
|
|
||||||
_sessions,
|
|
||||||
)
|
|
||||||
from app.models import LocalAgentConfig
|
|
||||||
from tests.conftest import TEST_USER_IDS, auth_header
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _start(client: TestClient, agent_type: str = "local", agent_id: str | None = None, tier: str = "power") -> dict:
|
|
||||||
body: dict = {"agent_type": agent_type}
|
|
||||||
if agent_id:
|
|
||||||
body["agent_id"] = agent_id
|
|
||||||
resp = client.post("/api/v1/agents/journey/start", json=body, headers=auth_header(tier))
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
def _message(client: TestClient, session_id: str, message: str, tier: str = "power") -> dict:
|
|
||||||
return client.post(
|
|
||||||
"/api/v1/agents/journey/message",
|
|
||||||
json={"session_id": session_id, "message": message},
|
|
||||||
headers=auth_header(tier),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit: _extract_template ───────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_template_present():
|
|
||||||
text = f"Some preamble.\n{_TEMPLATE_START}\nExtract tasks from emails.\n{_TEMPLATE_END}\nTrailing text."
|
|
||||||
result = _extract_template(text)
|
|
||||||
assert result == "Extract tasks from emails."
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_template_absent():
|
|
||||||
assert _extract_template("No markers here.") is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_template_empty_content():
|
|
||||||
text = f"{_TEMPLATE_START}\n{_TEMPLATE_END}"
|
|
||||||
assert _extract_template(text) is None
|
|
||||||
|
|
||||||
|
|
||||||
# ── Start journey ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_start_journey_local(client: TestClient):
|
|
||||||
resp = _start(client, agent_type="local")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
body = resp.json()
|
|
||||||
assert "session_id" in body
|
|
||||||
assert body["done"] is False
|
|
||||||
assert body["prompt_template"] is None
|
|
||||||
assert len(body["message"]) > 0
|
|
||||||
# Local question should be about files/directories
|
|
||||||
assert any(w in body["message"].lower() for w in ("file", "director", "document", "monitor"))
|
|
||||||
|
|
||||||
|
|
||||||
def test_start_journey_cloud(client: TestClient):
|
|
||||||
resp = _start(client, agent_type="cloud")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
body = resp.json()
|
|
||||||
assert body["done"] is False
|
|
||||||
# Cloud question should mention emails or messages
|
|
||||||
assert any(w in body["message"].lower() for w in ("email", "message", "communication"))
|
|
||||||
|
|
||||||
|
|
||||||
def test_start_journey_with_agent_id(client: TestClient, db_session: AsyncSession):
|
|
||||||
"""When agent_id is provided, session should be created even if agent doesn't exist."""
|
|
||||||
fake_agent_id = str(uuid.uuid4())
|
|
||||||
resp = _start(client, agent_type="local", agent_id=fake_agent_id)
|
|
||||||
# Should succeed gracefully even if the agent_id doesn't exist
|
|
||||||
assert resp.status_code == 200
|
|
||||||
body = resp.json()
|
|
||||||
assert body["done"] is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_start_journey_with_existing_agent(client: TestClient, db_session: AsyncSession):
|
|
||||||
"""When a real local agent is provided, session is seeded with its prompt_template."""
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
user_id = TEST_USER_IDS["power"]
|
|
||||||
agent = LocalAgentConfig(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=user_id,
|
|
||||||
name="Test Agent",
|
|
||||||
device_id="device-1",
|
|
||||||
directory_paths=["/home/user/emails"],
|
|
||||||
data_types=["tasks"],
|
|
||||||
prompt_template="Extract tasks from .eml files.",
|
|
||||||
file_extensions=[".eml"],
|
|
||||||
schedule_cron="0 */6 * * *",
|
|
||||||
enabled=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _seed():
|
|
||||||
db_session.add(agent)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(_seed())
|
|
||||||
|
|
||||||
resp = _start(client, agent_type="local", agent_id=agent.id)
|
|
||||||
assert resp.status_code == 200
|
|
||||||
body = resp.json()
|
|
||||||
assert body["done"] is False
|
|
||||||
# The session should be stored
|
|
||||||
assert body["session_id"] in _sessions
|
|
||||||
|
|
||||||
|
|
||||||
def test_start_journey_requires_auth(client: TestClient):
|
|
||||||
resp = client.post("/api/v1/agents/journey/start", json={"agent_type": "local"})
|
|
||||||
assert resp.status_code == 401
|
|
||||||
|
|
||||||
|
|
||||||
# ── Message ───────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_message_continues_conversation(client: TestClient):
|
|
||||||
"""A mid-journey reply (no template markers) returns done=False."""
|
|
||||||
follow_up = "That looks good. Can you tell me more about priority rules?"
|
|
||||||
|
|
||||||
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=follow_up)):
|
|
||||||
start_resp = _start(client, agent_type="local")
|
|
||||||
assert start_resp.status_code == 200
|
|
||||||
session_id = start_resp.json()["session_id"]
|
|
||||||
|
|
||||||
msg_resp = _message(client, session_id, "I have .eml and .txt files")
|
|
||||||
assert msg_resp.status_code == 200
|
|
||||||
body = msg_resp.json()
|
|
||||||
assert body["done"] is False
|
|
||||||
assert body["prompt_template"] is None
|
|
||||||
assert body["message"] == follow_up
|
|
||||||
assert body["session_id"] == session_id
|
|
||||||
|
|
||||||
|
|
||||||
def test_message_produces_template(client: TestClient):
|
|
||||||
"""When the LLM includes PROMPT_TEMPLATE markers, done=True and prompt_template is set."""
|
|
||||||
final_template = "Extract tasks from email. Subject → title. 'urgent' → high priority."
|
|
||||||
llm_response = (
|
|
||||||
"Great, I have all the information I need.\n"
|
|
||||||
f"{_TEMPLATE_START}\n{final_template}\n{_TEMPLATE_END}\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=llm_response)):
|
|
||||||
start_resp = _start(client, agent_type="cloud")
|
|
||||||
assert start_resp.status_code == 200
|
|
||||||
session_id = start_resp.json()["session_id"]
|
|
||||||
|
|
||||||
msg_resp = _message(client, session_id, "Only invoices from clients")
|
|
||||||
assert msg_resp.status_code == 200
|
|
||||||
body = msg_resp.json()
|
|
||||||
assert body["done"] is True
|
|
||||||
assert body["prompt_template"] == final_template
|
|
||||||
# Session should be cleaned up
|
|
||||||
assert session_id not in _sessions
|
|
||||||
|
|
||||||
|
|
||||||
def test_message_invalid_session(client: TestClient):
|
|
||||||
resp = _message(client, "nonexistent-session-id", "hello")
|
|
||||||
assert resp.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
def test_message_wrong_owner(client: TestClient):
|
|
||||||
"""User B cannot access user A's session."""
|
|
||||||
start_resp = _start(client, agent_type="local", tier="power")
|
|
||||||
session_id = start_resp.json()["session_id"]
|
|
||||||
|
|
||||||
# user with "pro" tier (different user_id) tries to send a message
|
|
||||||
resp = client.post(
|
|
||||||
"/api/v1/agents/journey/message",
|
|
||||||
json={"session_id": session_id, "message": "hello"},
|
|
||||||
headers=auth_header("pro"), # different user
|
|
||||||
)
|
|
||||||
assert resp.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
def test_message_expired_session(client: TestClient):
|
|
||||||
"""Expired sessions return 404."""
|
|
||||||
start_resp = _start(client, agent_type="local")
|
|
||||||
session_id = start_resp.json()["session_id"]
|
|
||||||
|
|
||||||
# Manually expire the session
|
|
||||||
_sessions[session_id].created_at = time.monotonic() - _SESSION_TTL_SECONDS - 1
|
|
||||||
|
|
||||||
resp = _message(client, session_id, "hello")
|
|
||||||
assert resp.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
def test_message_requires_auth(client: TestClient):
|
|
||||||
resp = client.post(
|
|
||||||
"/api/v1/agents/journey/message",
|
|
||||||
json={"session_id": "any", "message": "hello"},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 401
|
|
||||||
|
|
||||||
|
|
||||||
def test_message_max_turns_nudge(client: TestClient):
|
|
||||||
"""After _MAX_TURNS user messages, a system nudge is appended but no crash occurs."""
|
|
||||||
from app.api.routes.agent_setup import _MAX_TURNS
|
|
||||||
|
|
||||||
follow_up = "Tell me more about priority rules."
|
|
||||||
|
|
||||||
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=follow_up)):
|
|
||||||
start_resp = _start(client, agent_type="local")
|
|
||||||
session_id = start_resp.json()["session_id"]
|
|
||||||
|
|
||||||
for i in range(_MAX_TURNS):
|
|
||||||
resp = _message(client, session_id, f"Answer {i + 1}")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
# While no template produced, session must still exist
|
|
||||||
if resp.json()["done"]:
|
|
||||||
break # LLM decided to wrap up early — also fine
|
|
||||||
@@ -1,184 +0,0 @@
|
|||||||
"""Unit tests for Step 1 file classification (_classify_file).
|
|
||||||
|
|
||||||
These tests call the real LLM so they require OPENAI_API_KEY / LLM env vars.
|
|
||||||
Run with: pytest tests/test_classify_file.py -v
|
|
||||||
|
|
||||||
To run a quick manual check against a real file without the full UI:
|
|
||||||
python -m tests.test_classify_file <path/to/file.txt> [project_name...]
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.core.agent_runner import _classify_file
|
|
||||||
|
|
||||||
|
|
||||||
# ── Fixtures ──────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
PROJECTS_SAMPLE = [
|
|
||||||
{
|
|
||||||
"id": "aaaa-0001-0000-0000-000000000001",
|
|
||||||
"name": "ARPA Sicilia POC",
|
|
||||||
"status": "active",
|
|
||||||
"aiSummary": "Proof of concept for AI features targeting ARPA Sicilia agency.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "bbbb-0002-0000-0000-000000000002",
|
|
||||||
"name": "SNAM AI Meeting Prep",
|
|
||||||
"status": "active",
|
|
||||||
"aiSummary": "AI-assisted preparation of meeting materials for SNAM.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "cccc-0003-0000-0000-000000000003",
|
|
||||||
"name": "SFERA+ Wave 2",
|
|
||||||
"status": "active",
|
|
||||||
"aiSummary": "Second wave of the SFERA+ whitelist project.",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
ARPA_EMAIL = """\
|
|
||||||
to: roberto.musso@hpe.com; luca.tondin@hpecds.com
|
|
||||||
isImportance: normal
|
|
||||||
hasAttachment: True
|
|
||||||
---
|
|
||||||
## Body
|
|
||||||
Buongiorno,
|
|
||||||
|
|
||||||
In riferimento alla riunione di ieri sul POC ARPA Sicilia, vi invio il riassunto
|
|
||||||
dei deliverable concordati:
|
|
||||||
- Preparare demo entro il 30 marzo
|
|
||||||
- Condividere documentazione tecnica con il team ARPA
|
|
||||||
- Fissare call di follow-up la prossima settimana
|
|
||||||
|
|
||||||
Cordiali saluti
|
|
||||||
Roberto Marchetti
|
|
||||||
"""
|
|
||||||
|
|
||||||
SNAM_EMAIL = """\
|
|
||||||
to: roberto.musso@hpe.com
|
|
||||||
isImportance: high
|
|
||||||
hasAttachment: False
|
|
||||||
---
|
|
||||||
## Body
|
|
||||||
Ciao,
|
|
||||||
ti invio l'agenda per la riunione SNAM di domani.
|
|
||||||
Per favore conferma la tua presenza.
|
|
||||||
"""
|
|
||||||
|
|
||||||
UNRELATED_EMAIL = """\
|
|
||||||
to: roberto.musso@hpe.com
|
|
||||||
isImportance: normal
|
|
||||||
---
|
|
||||||
## Body
|
|
||||||
Benvenuto nel programma HPE Employee Learning Series.
|
|
||||||
Completa la formazione richiesta entro la fine del trimestre.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tests ─────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_classify_arpa_matches_existing():
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path="arpa_email.txt",
|
|
||||||
file_content=ARPA_EMAIL,
|
|
||||||
projects=PROJECTS_SAMPLE,
|
|
||||||
config_data_types=["tasks", "notes", "timelines"],
|
|
||||||
)
|
|
||||||
assert project_id == "aaaa-0001-0000-0000-000000000001", (
|
|
||||||
f"Expected ARPA project, got project_id={project_id!r} new_name={new_name!r}"
|
|
||||||
)
|
|
||||||
assert new_name is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_classify_snam_matches_existing():
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path="snam_email.txt",
|
|
||||||
file_content=SNAM_EMAIL,
|
|
||||||
projects=PROJECTS_SAMPLE,
|
|
||||||
config_data_types=["tasks", "notes"],
|
|
||||||
)
|
|
||||||
assert project_id == "bbbb-0002-0000-0000-000000000002", (
|
|
||||||
f"Expected SNAM project, got project_id={project_id!r} new_name={new_name!r}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_classify_unrelated_returns_new():
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path="learning_email.txt",
|
|
||||||
file_content=UNRELATED_EMAIL,
|
|
||||||
projects=PROJECTS_SAMPLE,
|
|
||||||
config_data_types=["tasks", "notes"],
|
|
||||||
)
|
|
||||||
assert project_id == "new"
|
|
||||||
assert new_name is not None # LLM should suggest a name
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_classify_empty_file_returns_new():
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path="empty.txt",
|
|
||||||
file_content=" ",
|
|
||||||
projects=PROJECTS_SAMPLE,
|
|
||||||
config_data_types=["tasks"],
|
|
||||||
)
|
|
||||||
assert project_id == "new"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_classify_no_projects_returns_new():
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path="arpa_email.txt",
|
|
||||||
file_content=ARPA_EMAIL,
|
|
||||||
projects=[],
|
|
||||||
config_data_types=["tasks", "notes"],
|
|
||||||
)
|
|
||||||
assert project_id == "new"
|
|
||||||
assert new_name is not None
|
|
||||||
|
|
||||||
|
|
||||||
# ── CLI quick-test runner ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _cli_test(file_path: str, project_names: list[str]) -> None:
|
|
||||||
"""Run Step 1 classification against a real file from the CLI."""
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
content = Path(file_path).read_text(encoding="utf-8", errors="replace")
|
|
||||||
projects = [
|
|
||||||
{"id": f"test-id-{i:04d}", "name": name, "status": "active", "aiSummary": ""}
|
|
||||||
for i, name in enumerate(project_names)
|
|
||||||
]
|
|
||||||
|
|
||||||
print(f"\nClassifying: {file_path}")
|
|
||||||
print(f"Projects in context: {[p['name'] for p in projects]}\n")
|
|
||||||
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path=file_path,
|
|
||||||
file_content=content,
|
|
||||||
projects=projects,
|
|
||||||
config_data_types=["tasks", "notes", "timelines"],
|
|
||||||
)
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"project_id": project_id,
|
|
||||||
"matched_name": next((p["name"] for p in projects if p["id"] == project_id), None),
|
|
||||||
"new_project_name": new_name,
|
|
||||||
"domains": domains,
|
|
||||||
}
|
|
||||||
print(json.dumps(result, indent=2, ensure_ascii=False))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
if len(sys.argv) < 2:
|
|
||||||
print("Usage: python -m tests.test_classify_file <file_path> [project_name ...]")
|
|
||||||
sys.exit(1)
|
|
||||||
asyncio.run(_cli_test(sys.argv[1], sys.argv[2:]))
|
|
||||||
@@ -63,7 +63,7 @@ class _FakeLLM:
|
|||||||
async def test_run_home_uses_mocked_tool_result():
|
async def test_run_home_uses_mocked_tool_result():
|
||||||
fake_llm = _FakeLLM()
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
|
||||||
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
||||||
):
|
):
|
||||||
out = await run_home("user-1", "list my tasks", {})
|
out = await run_home("user-1", "list my tasks", {})
|
||||||
@@ -76,7 +76,7 @@ async def test_run_home_uses_mocked_tool_result():
|
|||||||
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
|
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
|
||||||
fake_llm = _FakeLLM()
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
|
||||||
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
||||||
):
|
):
|
||||||
events = []
|
events = []
|
||||||
@@ -103,7 +103,7 @@ async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
|
|||||||
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
|
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=_ClassifierOnlyLLM()):
|
with patch("app.core.deep_agent.get_agent_llm", return_value=_ClassifierOnlyLLM()):
|
||||||
domain = await _infer_floating_domain(
|
domain = await _infer_floating_domain(
|
||||||
"Quali sono i miei task per il progetto X",
|
"Quali sono i miei task per il progetto X",
|
||||||
{
|
{
|
||||||
@@ -165,7 +165,7 @@ async def test_run_floating_strips_xml_like_tags_from_final_text():
|
|||||||
"Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
"Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
|
||||||
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
||||||
):
|
):
|
||||||
text, _domain = await run_floating(
|
text, _domain = await run_floating(
|
||||||
@@ -187,7 +187,7 @@ async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
|
|||||||
yield "token", "Hai 1 task:\\n"
|
yield "token", "Hai 1 task:\\n"
|
||||||
yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
|
||||||
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
||||||
):
|
):
|
||||||
events = []
|
events = []
|
||||||
@@ -233,7 +233,7 @@ async def test_run_floating_stream_falls_back_to_final_response_content_when_ast
|
|||||||
if False:
|
if False:
|
||||||
yield None
|
yield None
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=_NoChunkLLM()), patch(
|
with patch("app.core.deep_agent.get_agent_llm", return_value=_NoChunkLLM()), patch(
|
||||||
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
||||||
):
|
):
|
||||||
events = []
|
events = []
|
||||||
@@ -255,7 +255,7 @@ async def test_run_floating_returns_fallback_when_sanitization_would_empty_text(
|
|||||||
async def _fake_run_single_agent(**_kwargs):
|
async def _fake_run_single_agent(**_kwargs):
|
||||||
return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
|
||||||
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
||||||
):
|
):
|
||||||
text, _domain = await run_floating(
|
text, _domain = await run_floating(
|
||||||
@@ -274,7 +274,7 @@ async def test_run_floating_stream_returns_fallback_when_sanitization_would_empt
|
|||||||
async def _fake_stream(**_kwargs):
|
async def _fake_stream(**_kwargs):
|
||||||
yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
|
||||||
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
||||||
):
|
):
|
||||||
events = []
|
events = []
|
||||||
|
|||||||
@@ -156,40 +156,6 @@ async def test_manager_unregister_cancels_pending_calls(manager, mock_ws):
|
|||||||
assert fut.cancelled()
|
assert fut.cancelled()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_manager_agent_data_queue(manager, mock_ws):
|
|
||||||
manager.register("user1", "dev-A", mock_ws)
|
|
||||||
q = manager.get_agent_data_queue("user1", "run-xyz")
|
|
||||||
# Put a frame and get it back.
|
|
||||||
frame = {"type": "agent_data", "run_id": "run-xyz", "files": []}
|
|
||||||
await q.put(frame)
|
|
||||||
assert await q.get() == frame
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_manager_agent_data_queue_creates_once(manager, mock_ws):
|
|
||||||
manager.register("user1", "dev-A", mock_ws)
|
|
||||||
q1 = manager.get_agent_data_queue("user1", "run-1")
|
|
||||||
q2 = manager.get_agent_data_queue("user1", "run-1")
|
|
||||||
assert q1 is q2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_manager_agent_data_queue_raises_when_offline(manager):
|
|
||||||
with pytest.raises(RuntimeError, match="not connected"):
|
|
||||||
manager.get_agent_data_queue("ghost", "run-1")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_manager_cleanup_agent_data_queue(manager, mock_ws):
|
|
||||||
manager.register("user1", "dev-A", mock_ws)
|
|
||||||
manager.get_agent_data_queue("user1", "run-1")
|
|
||||||
manager.cleanup_agent_data_queue("user1", "run-1")
|
|
||||||
# After cleanup a new queue is created (not the same object).
|
|
||||||
q_new = manager.get_agent_data_queue("user1", "run-1")
|
|
||||||
assert q_new is not None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Integration tests — /api/v1/ws/device endpoint
|
# Integration tests — /api/v1/ws/device endpoint
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -266,43 +232,6 @@ def test_ws_device_tool_result_dispatched(client):
|
|||||||
assert any(c["call_id"] == "call-123" for c in captured)
|
assert any(c["call_id"] == "call-123" for c in captured)
|
||||||
|
|
||||||
|
|
||||||
def test_ws_device_agent_data_enqueued(client):
|
|
||||||
"""agent_data frame is placed in the per-run queue by the message loop."""
|
|
||||||
from app.core.device_manager import device_manager as dm
|
|
||||||
|
|
||||||
token = make_jwt(tier="free")
|
|
||||||
user_id = TEST_USER_IDS["free"]
|
|
||||||
|
|
||||||
# Capture the queue object the message loop accesses.
|
|
||||||
captured_queue: list[asyncio.Queue] = []
|
|
||||||
original_get_queue = dm.get_agent_data_queue
|
|
||||||
|
|
||||||
def _spy_get_queue(uid, run_id):
|
|
||||||
q = original_get_queue(uid, run_id)
|
|
||||||
if not captured_queue:
|
|
||||||
captured_queue.append(q)
|
|
||||||
return q
|
|
||||||
|
|
||||||
with patch.object(dm, "get_agent_data_queue", side_effect=_spy_get_queue):
|
|
||||||
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999):
|
|
||||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
|
||||||
ws.send_text(_device_hello("dev-001"))
|
|
||||||
ws.send_text(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"type": "agent_data",
|
|
||||||
"run_id": "run-XYZ",
|
|
||||||
"files": [{"path": "/tmp/file.txt", "content": "hello"}],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
ws.close()
|
|
||||||
|
|
||||||
# The queue should have received exactly one frame.
|
|
||||||
assert captured_queue, "queue was never accessed"
|
|
||||||
assert not captured_queue[0].empty()
|
|
||||||
|
|
||||||
|
|
||||||
def test_ws_device_disconnect_marks_run_logs_as_error(client, db_session):
|
def test_ws_device_disconnect_marks_run_logs_as_error(client, db_session):
|
||||||
"""On disconnect, _mark_runs_disconnected is called with the correct user_id."""
|
"""On disconnect, _mark_runs_disconnected is called with the correct user_id."""
|
||||||
from app.api.routes import device_ws as _dws
|
from app.api.routes import device_ws as _dws
|
||||||
|
|||||||
405
tests/test_memory_audit.py
Normal file
405
tests/test_memory_audit.py
Normal file
@@ -0,0 +1,405 @@
|
|||||||
|
"""Tests for Phase 7 — weekly audit_memory job.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
1. audit_memory never raises even if inner work fails.
|
||||||
|
2. _scan_associative_contradictions skips when < 2 decryptable facts.
|
||||||
|
3. _scan_associative_contradictions calls LLM and deletes flagged rows.
|
||||||
|
4. _scan_associative_contradictions is a no-op when LLM fails.
|
||||||
|
5. _scan_associative_contradictions is a no-op when LLM returns non-list.
|
||||||
|
6. _canonicalize_relation_labels skips when no relation rows.
|
||||||
|
7. _canonicalize_relation_labels rewrites variant labels to canonical form.
|
||||||
|
8. _canonicalize_relation_labels is a no-op when LLM fails.
|
||||||
|
9. _canonicalize_relation_labels is a no-op when remap is empty.
|
||||||
|
10. Both helpers work correctly when Langfuse is unavailable (lf=None).
|
||||||
|
11. get_prompt_or_fallback called with correct Langfuse prompt names.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from contextlib import contextmanager, ExitStack
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.core.memory_maintenance import (
|
||||||
|
_canonicalize_relation_labels,
|
||||||
|
_scan_associative_contradictions,
|
||||||
|
audit_memory,
|
||||||
|
)
|
||||||
|
from app.db import get_session
|
||||||
|
from app.main import app
|
||||||
|
from app.models import MemoryAssociative, MemoryRelation, User
|
||||||
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
|
PRO_USER_ID = TEST_USER_IDS["pro"]
|
||||||
|
_FERNET_KEY = Fernet.generate_key().decode()
|
||||||
|
_FERNET = Fernet(_FERNET_KEY.encode())
|
||||||
|
|
||||||
|
|
||||||
|
# ── DB override ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _override_db(db_session):
|
||||||
|
async def _gen():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _gen
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def pro_user(db_session):
|
||||||
|
result = await db_session.execute(select(User).where(User.id == PRO_USER_ID))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.encryption_key = _FERNET_KEY
|
||||||
|
await db_session.commit()
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def _enc(text: str) -> str:
|
||||||
|
return _FERNET.encrypt(text.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _assoc_row(user_id: str, text: str) -> MemoryAssociative:
|
||||||
|
return MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
content_encrypted=_enc(text),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _relation_row(user_id: str, subject: str, predicate: str, obj: str) -> MemoryRelation:
|
||||||
|
return MemoryRelation(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
subject_label=subject,
|
||||||
|
subject_type="person",
|
||||||
|
predicate=predicate,
|
||||||
|
object_label=obj,
|
||||||
|
object_type="company",
|
||||||
|
confidence=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _llm_response(content: str) -> MagicMock:
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.content = content
|
||||||
|
msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_llm(content: str) -> MagicMock:
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.ainvoke = AsyncMock(return_value=_llm_response(content))
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _patch_audit(llm_mock, lf=None, prompt_text: str = "fallback {facts}"):
|
||||||
|
"""Context manager that patches all external deps for audit helpers."""
|
||||||
|
with ExitStack() as stack:
|
||||||
|
stack.enter_context(
|
||||||
|
patch("app.core.llm.get_agent_llm", return_value=llm_mock)
|
||||||
|
)
|
||||||
|
stack.enter_context(
|
||||||
|
patch("app.core.llm.model_for_agent", return_value="memory-auditor")
|
||||||
|
)
|
||||||
|
stack.enter_context(
|
||||||
|
patch("app.core.memory_maintenance.get_langfuse", return_value=lf)
|
||||||
|
)
|
||||||
|
stack.enter_context(
|
||||||
|
patch(
|
||||||
|
"app.core.memory_maintenance.get_prompt_or_fallback",
|
||||||
|
return_value=(prompt_text, None),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
stack.enter_context(
|
||||||
|
patch(
|
||||||
|
"app.core.memory_maintenance.compile_prompt",
|
||||||
|
side_effect=lambda tmpl, obj, **kw: tmpl.format(**kw) if "{" in tmpl else tmpl,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test 1: audit_memory never raises ────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_audit_memory_never_raises_on_missing_user(db_session):
|
||||||
|
"""audit_memory with a non-existent user_id must not raise."""
|
||||||
|
await audit_memory(db_session, str(uuid.uuid4()))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_audit_memory_never_raises_on_llm_failure(db_session, pro_user):
|
||||||
|
"""audit_memory must swallow inner exceptions."""
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.core.llm.get_agent_llm", return_value=llm),
|
||||||
|
patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
|
||||||
|
patch("app.core.memory_maintenance.get_langfuse", return_value=None),
|
||||||
|
patch(
|
||||||
|
"app.core.memory_maintenance.get_prompt_or_fallback",
|
||||||
|
return_value=("p {facts}", None),
|
||||||
|
),
|
||||||
|
patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
|
||||||
|
):
|
||||||
|
await audit_memory(db_session, PRO_USER_ID)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test 2: _scan skips when < 2 facts ───────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scan_contradictions_skips_with_one_fact(db_session, pro_user):
|
||||||
|
row = _assoc_row(PRO_USER_ID, "Prefers morning meetings")
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.ainvoke = AsyncMock(return_value=_llm_response("[]"))
|
||||||
|
|
||||||
|
with _patch_audit(llm):
|
||||||
|
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||||
|
|
||||||
|
llm.ainvoke.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test 3: _scan deletes flagged contradiction ───────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scan_contradictions_deletes_flagged_row(db_session, pro_user):
|
||||||
|
keep = _assoc_row(PRO_USER_ID, "Prefers morning meetings")
|
||||||
|
drop = _assoc_row(PRO_USER_ID, "Never schedules before noon")
|
||||||
|
db_session.add(keep)
|
||||||
|
db_session.add(drop)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
deletion_payload = json.dumps([{"delete": drop.id, "reason": "contradicts morning pref"}])
|
||||||
|
llm = _mock_llm(deletion_payload)
|
||||||
|
|
||||||
|
with _patch_audit(llm, prompt_text="p {facts}"):
|
||||||
|
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
|
||||||
|
)
|
||||||
|
remaining = result.scalars().all()
|
||||||
|
remaining_ids = {r.id for r in remaining}
|
||||||
|
assert keep.id in remaining_ids
|
||||||
|
assert drop.id not in remaining_ids
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test 4: _scan is no-op on LLM failure ────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scan_contradictions_noop_on_llm_failure(db_session, pro_user):
|
||||||
|
for text in ("Fact A", "Fact B"):
|
||||||
|
db_session.add(_assoc_row(PRO_USER_ID, text))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||||
|
|
||||||
|
with _patch_audit(llm, prompt_text="p {facts}"):
|
||||||
|
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
|
||||||
|
)
|
||||||
|
assert len(result.scalars().all()) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test 5: _scan is no-op when LLM returns non-list ─────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scan_contradictions_noop_on_non_list_response(db_session, pro_user):
|
||||||
|
for text in ("Fact A", "Fact B"):
|
||||||
|
db_session.add(_assoc_row(PRO_USER_ID, text))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
llm = _mock_llm('"unexpected string"')
|
||||||
|
|
||||||
|
with _patch_audit(llm, prompt_text="p {facts}"):
|
||||||
|
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
|
||||||
|
)
|
||||||
|
assert len(result.scalars().all()) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test 6: _canonicalize skips when no relations ────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_canonicalize_skips_when_no_relations(db_session, pro_user):
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.ainvoke = AsyncMock(return_value=_llm_response("[]"))
|
||||||
|
|
||||||
|
with _patch_audit(llm, prompt_text="p {labels}"):
|
||||||
|
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||||
|
|
||||||
|
llm.ainvoke.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test 7: _canonicalize rewrites variant labels ────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_canonicalize_rewrites_variant_labels(db_session, pro_user):
|
||||||
|
row_a = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
|
||||||
|
row_b = _relation_row(PRO_USER_ID, "Giulia R.", "reports_to", "Marco")
|
||||||
|
row_c = _relation_row(PRO_USER_ID, "Marco", "manages", "Giulia")
|
||||||
|
db_session.add(row_a)
|
||||||
|
db_session.add(row_b)
|
||||||
|
db_session.add(row_c)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
groups = json.dumps([
|
||||||
|
{"canonical": "Giulia", "variants": ["giulia", "Giulia R."]}
|
||||||
|
])
|
||||||
|
llm = _mock_llm(groups)
|
||||||
|
|
||||||
|
with _patch_audit(llm, prompt_text="p {labels}"):
|
||||||
|
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||||
|
|
||||||
|
await db_session.refresh(row_a)
|
||||||
|
await db_session.refresh(row_b)
|
||||||
|
await db_session.refresh(row_c)
|
||||||
|
|
||||||
|
assert row_a.subject_label == "Giulia"
|
||||||
|
assert row_b.subject_label == "Giulia"
|
||||||
|
assert row_c.object_label == "Giulia"
|
||||||
|
assert row_c.subject_label == "Marco"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test 8: _canonicalize is no-op on LLM failure ────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_canonicalize_noop_on_llm_failure(db_session, pro_user):
|
||||||
|
row = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||||
|
|
||||||
|
with _patch_audit(llm, prompt_text="p {labels}"):
|
||||||
|
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||||
|
|
||||||
|
await db_session.refresh(row)
|
||||||
|
assert row.subject_label == "giulia"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test 9: _canonicalize is no-op when remap is empty ───────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_canonicalize_noop_when_remap_empty(db_session, pro_user):
|
||||||
|
row = _relation_row(PRO_USER_ID, "Giulia", "works_at", "Acme")
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
llm = _mock_llm("[]")
|
||||||
|
|
||||||
|
with _patch_audit(llm, prompt_text="p {labels}"):
|
||||||
|
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||||
|
|
||||||
|
await db_session.refresh(row)
|
||||||
|
assert row.subject_label == "Giulia"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test 10: both helpers work without Langfuse ───────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scan_works_without_langfuse(db_session, pro_user):
|
||||||
|
keep = _assoc_row(PRO_USER_ID, "Prefers dark mode")
|
||||||
|
drop = _assoc_row(PRO_USER_ID, "Prefers light mode")
|
||||||
|
db_session.add(keep)
|
||||||
|
db_session.add(drop)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
deletion_payload = json.dumps([{"delete": drop.id, "reason": "contradicts dark mode"}])
|
||||||
|
llm = _mock_llm(deletion_payload)
|
||||||
|
|
||||||
|
with _patch_audit(llm, lf=None, prompt_text="p {facts}"):
|
||||||
|
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
|
||||||
|
)
|
||||||
|
remaining_ids = {r.id for r in result.scalars().all()}
|
||||||
|
assert keep.id in remaining_ids
|
||||||
|
assert drop.id not in remaining_ids
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_canonicalize_works_without_langfuse(db_session, pro_user):
|
||||||
|
row = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
|
||||||
|
db_session.add(row)
|
||||||
|
db_session.add(_relation_row(PRO_USER_ID, "Marco", "manages", "Giulia"))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
groups = json.dumps([{"canonical": "Giulia", "variants": ["giulia"]}])
|
||||||
|
llm = _mock_llm(groups)
|
||||||
|
|
||||||
|
with _patch_audit(llm, lf=None, prompt_text="p {labels}"):
|
||||||
|
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||||
|
|
||||||
|
await db_session.refresh(row)
|
||||||
|
assert row.subject_label == "Giulia"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test 11: correct Langfuse prompt names used ───────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scan_uses_correct_langfuse_prompt_name(db_session, pro_user):
|
||||||
|
for text in ("Fact A", "Fact B"):
|
||||||
|
db_session.add(_assoc_row(PRO_USER_ID, text))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
llm = _mock_llm("[]")
|
||||||
|
mock_get_prompt = MagicMock(return_value=("p {facts}", None))
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.core.llm.get_agent_llm", return_value=llm),
|
||||||
|
patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
|
||||||
|
patch("app.core.memory_maintenance.get_langfuse", return_value=None),
|
||||||
|
patch("app.core.memory_maintenance.get_prompt_or_fallback", mock_get_prompt),
|
||||||
|
patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
|
||||||
|
):
|
||||||
|
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||||
|
|
||||||
|
mock_get_prompt.assert_called_once()
|
||||||
|
assert mock_get_prompt.call_args[0][0] == "memory_audit_contradictions"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_canonicalize_uses_correct_langfuse_prompt_name(db_session, pro_user):
|
||||||
|
db_session.add(_relation_row(PRO_USER_ID, "Giulia", "works_at", "Acme"))
|
||||||
|
db_session.add(_relation_row(PRO_USER_ID, "Marco", "manages", "Acme"))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
llm = _mock_llm("[]")
|
||||||
|
mock_get_prompt = MagicMock(return_value=("p {labels}", None))
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.core.llm.get_agent_llm", return_value=llm),
|
||||||
|
patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
|
||||||
|
patch("app.core.memory_maintenance.get_langfuse", return_value=None),
|
||||||
|
patch("app.core.memory_maintenance.get_prompt_or_fallback", mock_get_prompt),
|
||||||
|
patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
|
||||||
|
):
|
||||||
|
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||||
|
|
||||||
|
mock_get_prompt.assert_called_once()
|
||||||
|
assert mock_get_prompt.call_args[0][0] == "memory_audit_canonicalize"
|
||||||
345
tests/test_memory_extraction.py
Normal file
345
tests/test_memory_extraction.py
Normal file
@@ -0,0 +1,345 @@
|
|||||||
|
"""Tests for Phase 2 — Mem0-style Extract/Update pipeline.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
2.1 extract_candidates returns valid ExtractionResult with mocked LLM.
|
||||||
|
2.2 decide_action — all 4 branches (ADD/UPDATE/DELETE/NOOP + empty existing).
|
||||||
|
2.3 run_extraction end-to-end with mocked LLM writes expected rows.
|
||||||
|
2.4 _dispatch_extraction — Pro user triggers realtime task; Free enqueues row.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.core.memory_extraction import (
|
||||||
|
ExtractionResult,
|
||||||
|
MemoryCandidate,
|
||||||
|
decide_action,
|
||||||
|
extract_candidates,
|
||||||
|
run_extraction,
|
||||||
|
)
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.db import get_session
|
||||||
|
from app.main import app
|
||||||
|
from app.models import ExtractionQueue, MemoryCore, User
|
||||||
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
|
|
||||||
|
PRO_USER_ID = TEST_USER_IDS["pro"]
|
||||||
|
FREE_USER_ID = TEST_USER_IDS["free"]
|
||||||
|
_FERNET_KEY = Fernet.generate_key().decode()
|
||||||
|
|
||||||
|
|
||||||
|
# ── DB override ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _override_db(db_session):
|
||||||
|
async def _gen():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _gen
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def pro_user(db_session):
|
||||||
|
"""Update the seeded pro user to have an encryption_key."""
|
||||||
|
result = await db_session.execute(select(User).where(User.id == PRO_USER_ID))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.encryption_key = _FERNET_KEY
|
||||||
|
await db_session.commit()
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def free_user(db_session):
|
||||||
|
"""Update the seeded free user to have an encryption_key."""
|
||||||
|
result = await db_session.execute(select(User).where(User.id == FREE_USER_ID))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.encryption_key = _FERNET_KEY
|
||||||
|
await db_session.commit()
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def _make_llm_response(content: str) -> MagicMock:
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.content = content
|
||||||
|
msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
# ── TASK 2.1 — extract_candidates ────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_candidates_returns_valid_result():
|
||||||
|
payload = {
|
||||||
|
"candidates": [
|
||||||
|
{
|
||||||
|
"type": "fact",
|
||||||
|
"content": "User's CFO is Giulia",
|
||||||
|
"target_tier": "core",
|
||||||
|
"subject": None,
|
||||||
|
"predicate": None,
|
||||||
|
"object": None,
|
||||||
|
"confidence": 0.85,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
mock_response = _make_llm_response(json.dumps(payload))
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||||
|
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||||
|
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
|
||||||
|
):
|
||||||
|
mock_prompt.return_value = (
|
||||||
|
"system prompt {last_turn} {core_memory} {recent_episodes}",
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
llm_instance = MagicMock()
|
||||||
|
llm_instance.bind.return_value = llm_instance
|
||||||
|
llm_instance.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
mock_get_llm.return_value = llm_instance
|
||||||
|
|
||||||
|
result = await extract_candidates(
|
||||||
|
last_turn="User: My CFO is Giulia\nAssistant: Noted.",
|
||||||
|
core_memory={},
|
||||||
|
recent_episodes=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, ExtractionResult)
|
||||||
|
assert len(result.candidates) == 1
|
||||||
|
assert result.candidates[0].type == "fact"
|
||||||
|
assert "Giulia" in result.candidates[0].content
|
||||||
|
assert result.candidates[0].confidence == 0.85
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_candidates_returns_empty_on_llm_failure():
|
||||||
|
with (
|
||||||
|
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||||
|
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||||
|
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
|
||||||
|
):
|
||||||
|
mock_prompt.return_value = ("prompt {last_turn} {core_memory} {recent_episodes}", None)
|
||||||
|
llm_instance = MagicMock()
|
||||||
|
llm_instance.bind.return_value = llm_instance
|
||||||
|
llm_instance.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||||
|
mock_get_llm.return_value = llm_instance
|
||||||
|
|
||||||
|
result = await extract_candidates("turn", {}, [])
|
||||||
|
|
||||||
|
assert isinstance(result, ExtractionResult)
|
||||||
|
assert result.candidates == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── TASK 2.2 — decide_action ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decide_action_add_when_no_existing():
|
||||||
|
candidate = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core")
|
||||||
|
action = await decide_action(candidate, existing=[])
|
||||||
|
assert action == "ADD"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decide_action_noop():
|
||||||
|
candidate = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core")
|
||||||
|
mock_response = _make_llm_response("NOOP")
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||||
|
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||||
|
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
|
||||||
|
):
|
||||||
|
mock_prompt.return_value = ("p {candidate} {existing_memories}", None)
|
||||||
|
llm_instance = MagicMock()
|
||||||
|
llm_instance.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
mock_get_llm.return_value = llm_instance
|
||||||
|
|
||||||
|
action = await decide_action(candidate, existing=["CFO is Giulia"])
|
||||||
|
|
||||||
|
assert action == "NOOP"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decide_action_update():
|
||||||
|
candidate = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core")
|
||||||
|
mock_response = _make_llm_response("UPDATE")
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||||
|
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||||
|
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
|
||||||
|
):
|
||||||
|
mock_prompt.return_value = ("p {candidate} {existing_memories}", None)
|
||||||
|
llm_instance = MagicMock()
|
||||||
|
llm_instance.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
mock_get_llm.return_value = llm_instance
|
||||||
|
|
||||||
|
action = await decide_action(candidate, existing=["CFO is Giulia"])
|
||||||
|
|
||||||
|
assert action == "UPDATE"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decide_action_delete():
|
||||||
|
candidate = MemoryCandidate(type="fact", content="No longer have a CFO", target_tier="core")
|
||||||
|
mock_response = _make_llm_response("DELETE")
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||||
|
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||||
|
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
|
||||||
|
):
|
||||||
|
mock_prompt.return_value = ("p {candidate} {existing_memories}", None)
|
||||||
|
llm_instance = MagicMock()
|
||||||
|
llm_instance.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
mock_get_llm.return_value = llm_instance
|
||||||
|
|
||||||
|
action = await decide_action(candidate, existing=["CFO is Giulia"])
|
||||||
|
|
||||||
|
assert action == "DELETE"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decide_action_defaults_add_on_llm_failure():
|
||||||
|
candidate = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core")
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||||
|
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||||
|
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
|
||||||
|
):
|
||||||
|
mock_prompt.return_value = ("p {candidate} {existing_memories}", None)
|
||||||
|
llm_instance = MagicMock()
|
||||||
|
llm_instance.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||||
|
mock_get_llm.return_value = llm_instance
|
||||||
|
|
||||||
|
action = await decide_action(candidate, existing=["old memory"])
|
||||||
|
|
||||||
|
assert action == "ADD"
|
||||||
|
|
||||||
|
|
||||||
|
# ── TASK 2.3 — run_extraction end-to-end ─────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_extraction_writes_core_candidate(db_session, pro_user):
|
||||||
|
"""'My CFO is Giulia' → fact candidate → core row written."""
|
||||||
|
fact_payload = {
|
||||||
|
"candidates": [
|
||||||
|
{
|
||||||
|
"type": "fact",
|
||||||
|
"content": "User prefers morning meetings",
|
||||||
|
"target_tier": "core",
|
||||||
|
"confidence": 0.8,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
def _mock_llm_response(content: str):
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.content = content
|
||||||
|
msg.usage_metadata = {}
|
||||||
|
return msg
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def _ainvoke_side_effect(messages):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
# extract_candidates call
|
||||||
|
return _mock_llm_response(json.dumps(fact_payload))
|
||||||
|
# decide_action — no existing → short-circuits to ADD without LLM
|
||||||
|
return _mock_llm_response("ADD")
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||||
|
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||||
|
patch(
|
||||||
|
"app.core.memory_extraction.get_prompt_or_fallback",
|
||||||
|
side_effect=lambda name, fb: (
|
||||||
|
("p {last_turn} {core_memory} {recent_episodes}", None)
|
||||||
|
if name == "memory_extraction"
|
||||||
|
else ("p {candidate} {existing_memories}", None)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
llm_instance = MagicMock()
|
||||||
|
llm_instance.bind.return_value = llm_instance
|
||||||
|
llm_instance.ainvoke = AsyncMock(side_effect=_ainvoke_side_effect)
|
||||||
|
mock_get_llm.return_value = llm_instance
|
||||||
|
|
||||||
|
await run_extraction(
|
||||||
|
db=db_session,
|
||||||
|
user_id=PRO_USER_ID,
|
||||||
|
last_user_msg="My CFO is Giulia",
|
||||||
|
last_assistant_msg="Noted, I will remember that.",
|
||||||
|
session_id="test-session",
|
||||||
|
)
|
||||||
|
|
||||||
|
# core row should exist
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == PRO_USER_ID)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
assert len(rows) >= 1
|
||||||
|
fernet = Fernet(_FERNET_KEY.encode())
|
||||||
|
values = [fernet.decrypt(r.value_encrypted.encode()).decode() for r in rows]
|
||||||
|
assert any("morning meetings" in v for v in values)
|
||||||
|
|
||||||
|
|
||||||
|
# ── TASK 2.4 — dispatch ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_realtime_for_pro(db_session, pro_user):
|
||||||
|
"""Pro user: asyncio.create_task called (not queue row)."""
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.core.memory_middleware.asyncio.create_task") as mock_task,
|
||||||
|
patch("app.billing.tier_manager.tier_manager.check_feature", return_value=True),
|
||||||
|
):
|
||||||
|
await middleware._dispatch_extraction(
|
||||||
|
user_id=PRO_USER_ID,
|
||||||
|
episode_id=str(uuid.uuid4()),
|
||||||
|
last_user_msg="hello",
|
||||||
|
last_assistant_msg="hi",
|
||||||
|
session_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_task.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_queue_for_free(db_session, free_user):
|
||||||
|
"""Free user: ExtractionQueue row inserted."""
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ep_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
with patch("app.billing.tier_manager.tier_manager.check_feature", return_value=False):
|
||||||
|
await middleware._dispatch_extraction(
|
||||||
|
user_id=FREE_USER_ID,
|
||||||
|
episode_id=ep_id,
|
||||||
|
last_user_msg="hello",
|
||||||
|
last_assistant_msg="hi",
|
||||||
|
session_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(ExtractionQueue).where(ExtractionQueue.user_id == FREE_USER_ID)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
assert len(rows) == 1
|
||||||
|
assert rows[0].episode_id == ep_id
|
||||||
@@ -12,13 +12,14 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from unittest.mock import patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.core.embeddings import embed_text
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.main import app
|
from app.main import app
|
||||||
@@ -341,3 +342,33 @@ def test_home_request_calls_memory_middleware(client):
|
|||||||
stored_session_id, stored_message = store_calls[0][1], store_calls[0][2]
|
stored_session_id, stored_message = store_calls[0][1], store_calls[0][2]
|
||||||
assert stored_session_id == session_id
|
assert stored_session_id == session_id
|
||||||
assert stored_message == "Show tasks"
|
assert stored_message == "Show tasks"
|
||||||
|
|
||||||
|
|
||||||
|
# ── embed_text ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_embed_text_returns_1536_floats():
|
||||||
|
"""embed_text returns a 1536-dim float list when OpenAI responds successfully."""
|
||||||
|
fake_embedding = [0.1] * 1536
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.data = [MagicMock(embedding=fake_embedding)]
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("app.core.embeddings.AsyncOpenAI", return_value=mock_client):
|
||||||
|
result = await embed_text("test text")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert len(result) == 1536
|
||||||
|
assert all(isinstance(x, float) for x in result)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_embed_text_returns_none_on_failure():
|
||||||
|
"""embed_text returns None when OpenAI raises; must not propagate the exception."""
|
||||||
|
with patch("app.core.embeddings.AsyncOpenAI", side_effect=Exception("no key")):
|
||||||
|
result = await embed_text("test text")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|||||||
153
tests/test_memory_proactive.py
Normal file
153
tests/test_memory_proactive.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
"""Tests for Phase 5 — proactive hints surfacing.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
1. _proactive_hints_injection returns correct section for seeded hints
|
||||||
|
2. _proactive_hints_injection returns empty string when no hints
|
||||||
|
3. enrich_context includes proactive_hints key from MemoryProactive row
|
||||||
|
4. System prompt includes proactive line when row exists + confidence >= threshold
|
||||||
|
5. TierManager.check_feature returns True for power/team, False for free/pro
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.billing.tier_manager import tier_manager
|
||||||
|
from app.core.deep_agent import _proactive_hints_injection
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.db import get_session
|
||||||
|
from app.main import app
|
||||||
|
from app.models import MemoryProactive, User
|
||||||
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
|
|
||||||
|
USER_ID = TEST_USER_IDS["power"]
|
||||||
|
_FERNET_KEY = Fernet.generate_key().decode()
|
||||||
|
|
||||||
|
|
||||||
|
# ── DB override ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _override_db(db_session):
|
||||||
|
async def _gen():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _gen
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def user_with_key(db_session):
|
||||||
|
result = await db_session.execute(select(User).where(User.id == USER_ID))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.encryption_key = _FERNET_KEY
|
||||||
|
await db_session.commit()
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def _enc(plaintext: str) -> str:
|
||||||
|
return Fernet(_FERNET_KEY.encode()).encrypt(plaintext.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
# ── _proactive_hints_injection unit tests ─────────────────────────────────────
|
||||||
|
|
||||||
|
def test_proactive_hints_injection_with_hints():
|
||||||
|
context = {"proactive_hints": ["Works late on Thursdays", "Prefers bullet points"]}
|
||||||
|
result = _proactive_hints_injection(context)
|
||||||
|
assert "I noticed" in result
|
||||||
|
assert "Works late on Thursdays" in result
|
||||||
|
assert "Prefers bullet points" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_proactive_hints_injection_empty():
|
||||||
|
assert _proactive_hints_injection({}) == ""
|
||||||
|
assert _proactive_hints_injection({"proactive_hints": []}) == ""
|
||||||
|
assert _proactive_hints_injection({"proactive_hints": None}) == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_proactive_hints_injection_truncates_long_hints():
|
||||||
|
hints = ["x" * 200] * 10
|
||||||
|
result = _proactive_hints_injection({"proactive_hints": hints})
|
||||||
|
assert len(result) <= 600
|
||||||
|
assert result.endswith("...")
|
||||||
|
|
||||||
|
|
||||||
|
# ── enrich_context includes proactive hints ───────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
|
||||||
|
pattern = "Always checks tasks before meetings"
|
||||||
|
db_session.add(MemoryProactive(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
pattern_encrypted=_enc(pattern),
|
||||||
|
confidence=0.8,
|
||||||
|
source="inferred",
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "test message")
|
||||||
|
|
||||||
|
assert "proactive_hints" in ctx
|
||||||
|
assert pattern in ctx["proactive_hints"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_excludes_low_confidence_proactive(db_session, user_with_key):
|
||||||
|
pattern = "Low confidence pattern"
|
||||||
|
db_session.add(MemoryProactive(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
pattern_encrypted=_enc(pattern),
|
||||||
|
confidence=0.1,
|
||||||
|
source="inferred",
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "test message")
|
||||||
|
|
||||||
|
hints = ctx.get("proactive_hints", [])
|
||||||
|
assert pattern not in hints
|
||||||
|
|
||||||
|
|
||||||
|
# ── proactive hints appear in system prompt string ───────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_proactive_hints_in_system_prompt_string(db_session, user_with_key):
|
||||||
|
pattern = "Frequently requests end-of-day summaries"
|
||||||
|
db_session.add(MemoryProactive(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
pattern_encrypted=_enc(pattern),
|
||||||
|
confidence=0.75,
|
||||||
|
source="inferred",
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "summarize my day")
|
||||||
|
|
||||||
|
system_prompt_suffix = _proactive_hints_injection(ctx)
|
||||||
|
assert pattern in system_prompt_suffix
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tier gate ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("tier,expected", [
|
||||||
|
("free", False),
|
||||||
|
("pro", False),
|
||||||
|
("power", True),
|
||||||
|
("team", True),
|
||||||
|
])
|
||||||
|
def test_proactive_mining_tier_gate(tier, expected):
|
||||||
|
assert tier_manager.check_feature(tier, "proactive_mining") == expected
|
||||||
220
tests/test_memory_relations.py
Normal file
220
tests/test_memory_relations.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
"""Tests for Phase 3 — relational tier (Mem0g-light).
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
1. upsert_relation inserts a row and query_relations returns it
|
||||||
|
2. upsert_relation updates existing row on duplicate (subject/predicate/object)
|
||||||
|
3. tier gating: Free user gets empty list from query_relations + enrich_context
|
||||||
|
4. enrich_context includes relational_memory key for Pro user
|
||||||
|
5. decay_relations decays confidence and prunes rows below threshold
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.core.memory_maintenance import decay_relations
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.db import get_session
|
||||||
|
from app.main import app
|
||||||
|
from app.models import MemoryRelation, User
|
||||||
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
|
PRO_USER_ID = TEST_USER_IDS["pro"]
|
||||||
|
FREE_USER_ID = TEST_USER_IDS["free"]
|
||||||
|
_FERNET_KEY = Fernet.generate_key().decode()
|
||||||
|
|
||||||
|
|
||||||
|
# ── DB override ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _override_db(db_session):
|
||||||
|
async def _gen():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _gen
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def pro_user_with_key(db_session):
|
||||||
|
"""Set encryption_key on the pro test user so Fernet works."""
|
||||||
|
result = await db_session.execute(select(User).where(User.id == PRO_USER_ID))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.encryption_key = _FERNET_KEY
|
||||||
|
await db_session.commit()
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def free_user_with_key(db_session):
|
||||||
|
"""Set encryption_key on the free test user."""
|
||||||
|
result = await db_session.execute(select(User).where(User.id == FREE_USER_ID))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.encryption_key = _FERNET_KEY
|
||||||
|
await db_session.commit()
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upsert_relation_inserts_and_queries(db_session, pro_user_with_key):
|
||||||
|
"""upsert_relation inserts a row; query_relations returns it."""
|
||||||
|
mm = MemoryMiddleware(db_session)
|
||||||
|
await mm.upsert_relation(
|
||||||
|
PRO_USER_ID,
|
||||||
|
subject="Giulia",
|
||||||
|
subject_type="person",
|
||||||
|
predicate="works_at",
|
||||||
|
object_="Acme Corp",
|
||||||
|
object_type="company",
|
||||||
|
confidence=0.9,
|
||||||
|
)
|
||||||
|
rows = await mm.query_relations(PRO_USER_ID, subject="Giulia")
|
||||||
|
assert len(rows) == 1
|
||||||
|
assert rows[0].subject_label == "Giulia"
|
||||||
|
assert rows[0].predicate == "works_at"
|
||||||
|
assert rows[0].object_label == "Acme Corp"
|
||||||
|
assert abs(rows[0].confidence - 0.9) < 0.001
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upsert_relation_updates_on_duplicate(db_session, pro_user_with_key):
|
||||||
|
"""Second upsert on same triple updates confidence and last_confirmed_at."""
|
||||||
|
mm = MemoryMiddleware(db_session)
|
||||||
|
await mm.upsert_relation(
|
||||||
|
PRO_USER_ID,
|
||||||
|
subject="Marco",
|
||||||
|
subject_type="person",
|
||||||
|
predicate="stakeholder_of",
|
||||||
|
object_="Project Nexus",
|
||||||
|
object_type="project",
|
||||||
|
confidence=0.7,
|
||||||
|
)
|
||||||
|
await mm.upsert_relation(
|
||||||
|
PRO_USER_ID,
|
||||||
|
subject="Marco",
|
||||||
|
subject_type="person",
|
||||||
|
predicate="stakeholder_of",
|
||||||
|
object_="Project Nexus",
|
||||||
|
object_type="project",
|
||||||
|
confidence=0.95,
|
||||||
|
)
|
||||||
|
rows = await mm.query_relations(PRO_USER_ID, subject="Marco")
|
||||||
|
# Only one row despite two upserts
|
||||||
|
assert len(rows) == 1
|
||||||
|
assert abs(rows[0].confidence - 0.95) < 0.001
|
||||||
|
assert rows[0].last_confirmed_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_free_tier_relation_skipped(db_session, free_user_with_key):
|
||||||
|
"""Free user: upsert_relation is silently skipped (no row created)."""
|
||||||
|
mm = MemoryMiddleware(db_session)
|
||||||
|
await mm.upsert_relation(
|
||||||
|
FREE_USER_ID,
|
||||||
|
subject="Alice",
|
||||||
|
subject_type="person",
|
||||||
|
predicate="reports_to",
|
||||||
|
object_="Bob",
|
||||||
|
object_type="person",
|
||||||
|
confidence=0.8,
|
||||||
|
)
|
||||||
|
rows = await mm.query_relations(FREE_USER_ID, subject="Alice")
|
||||||
|
assert rows == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_includes_relational_memory(db_session, pro_user_with_key):
|
||||||
|
"""enrich_context includes relational_memory key for Pro user."""
|
||||||
|
mm = MemoryMiddleware(db_session)
|
||||||
|
await mm.upsert_relation(
|
||||||
|
PRO_USER_ID,
|
||||||
|
subject="Elena",
|
||||||
|
subject_type="person",
|
||||||
|
predicate="cfo_of",
|
||||||
|
object_="StartupXYZ",
|
||||||
|
object_type="company",
|
||||||
|
confidence=0.85,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.core.memory_middleware.MemoryMiddleware._load_associative", return_value=[]):
|
||||||
|
ctx = await mm.enrich_context(PRO_USER_ID, "who is Elena?")
|
||||||
|
|
||||||
|
assert "relational_memory" in ctx
|
||||||
|
assert any("Elena" in r for r in ctx["relational_memory"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_relational_empty_for_free(db_session, free_user_with_key):
|
||||||
|
"""Free user: relational_memory is empty list in enrich_context."""
|
||||||
|
mm = MemoryMiddleware(db_session)
|
||||||
|
|
||||||
|
with patch("app.core.memory_middleware.MemoryMiddleware._load_associative", return_value=[]):
|
||||||
|
ctx = await mm.enrich_context(FREE_USER_ID, "test message")
|
||||||
|
|
||||||
|
assert ctx.get("relational_memory") == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decay_relations_reduces_confidence(db_session, pro_user_with_key):
|
||||||
|
"""decay_relations reduces confidence on stale rows."""
|
||||||
|
old_date = datetime.now(timezone.utc) - timedelta(days=35)
|
||||||
|
row = MemoryRelation(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=PRO_USER_ID,
|
||||||
|
subject_label="OldContact",
|
||||||
|
subject_type="person",
|
||||||
|
predicate="knows",
|
||||||
|
object_label="SomeProject",
|
||||||
|
object_type="project",
|
||||||
|
confidence=0.8,
|
||||||
|
last_confirmed_at=old_date,
|
||||||
|
)
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
await decay_relations(db_session, PRO_USER_ID)
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryRelation).where(MemoryRelation.subject_label == "OldContact")
|
||||||
|
)
|
||||||
|
updated = result.scalar_one_or_none()
|
||||||
|
assert updated is not None
|
||||||
|
assert updated.confidence < 0.8
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decay_relations_prunes_low_confidence(db_session, pro_user_with_key):
|
||||||
|
"""decay_relations deletes rows whose confidence drops below 0.2 threshold."""
|
||||||
|
# Start at 0.21 with 60-day-old last_confirmed_at → two decay periods → 0.21 * 0.95^2 ≈ 0.19 → pruned
|
||||||
|
old_date = datetime.now(timezone.utc) - timedelta(days=65)
|
||||||
|
row = MemoryRelation(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=PRO_USER_ID,
|
||||||
|
subject_label="ExpiredContact",
|
||||||
|
subject_type="person",
|
||||||
|
predicate="used_to_work_with",
|
||||||
|
object_label="OldCorp",
|
||||||
|
object_type="company",
|
||||||
|
confidence=0.21,
|
||||||
|
last_confirmed_at=old_date,
|
||||||
|
)
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
await decay_relations(db_session, PRO_USER_ID)
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryRelation).where(MemoryRelation.subject_label == "ExpiredContact")
|
||||||
|
)
|
||||||
|
pruned = result.scalar_one_or_none()
|
||||||
|
assert pruned is None
|
||||||
@@ -45,9 +45,6 @@ def test_v2_frame_types_still_exist():
|
|||||||
"tool_result",
|
"tool_result",
|
||||||
"final",
|
"final",
|
||||||
"ping",
|
"ping",
|
||||||
"agent_run",
|
|
||||||
"agent_data",
|
|
||||||
"agent_complete",
|
|
||||||
"device_hello",
|
"device_hello",
|
||||||
]
|
]
|
||||||
for name in v2_types:
|
for name in v2_types:
|
||||||
|
|||||||
Reference in New Issue
Block a user