Files
api/tests/test_memory_proactive.py
2026-04-17 17:58:30 +02:00

154 lines
5.3 KiB
Python

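"""Tests for Phase 5 — proactive hints surfacing.

Coverage:
1. _proactive_hints_injection returns a section containing every seeded hint
2. _proactive_hints_injection returns an empty string when there are no hints
   (missing key, empty list, or None)
3. _proactive_hints_injection caps a long hint list at 600 characters, ending
   with "..."
4. enrich_context includes a proactive_hints key populated from a MemoryProactive row
5. enrich_context excludes MemoryProactive rows whose confidence is below the threshold
6. The system-prompt section built from enrich_context output includes the stored
   pattern when the row exists and its confidence is >= the threshold
7. TierManager.check_feature for "proactive_mining" returns True for power/team
   and False for free/pro
"""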
from __future__ import annotations
import uuid
import pytest
import pytest_asyncio
from cryptography.fernet import Fernet
from sqlalchemy import select
from app.billing.tier_manager import tier_manager
from app.core.deep_agent import _proactive_hints_injection
from app.core.memory_middleware import MemoryMiddleware
from app.db import get_session
from app.main import app
from app.models import MemoryProactive, User
from tests.conftest import TEST_USER_IDS
# Test user for every DB-backed test in this module: the "power" entry of the
# shared TEST_USER_IDS map (test_proactive_mining_tier_gate expects the
# "proactive_mining" feature to be enabled for the "power" tier).
USER_ID = TEST_USER_IDS["power"]
# Fresh Fernet key generated at module import. user_with_key stores it as the
# user's encryption_key and _enc encrypts seeded patterns with it (presumably
# so MemoryMiddleware can decrypt them with the user's key — confirm).
_FERNET_KEY = Fernet.generate_key().decode()
# ── DB override ───────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Route the app's ``get_session`` dependency to the per-test ``db_session``.

    Applied automatically to every test in this module. On teardown, whatever
    override was registered for ``get_session`` before this fixture ran is
    restored; if there was none, the key is removed. This avoids discarding an
    override installed by an outer fixture.
    """
    async def _gen():
        # Dependency override shaped as an async generator yielding the session
        # (presumably FastAPI's dependency_overrides contract — confirm).
        yield db_session

    overrides = app.dependency_overrides
    missing = object()  # sentinel: distinguishes "no override" from any stored value
    previous = overrides.get(get_session, missing)
    overrides[get_session] = _gen
    yield
    if previous is missing:
        overrides.pop(get_session, None)
    else:
        overrides[get_session] = previous
# ── Fixtures ──────────────────────────────────────────────────────────────────
@pytest_asyncio.fixture
async def user_with_key(db_session):
    """Attach the module's Fernet key to the test user and return that user.

    Loads the ``User`` row whose id is ``USER_ID`` (``scalar_one`` raises if it
    is absent), stores ``_FERNET_KEY`` as its ``encryption_key``, and commits.
    Tests depend on this fixture mainly for that side effect.
    """
    query = select(User).where(User.id == USER_ID)
    user = (await db_session.execute(query)).scalar_one()
    user.encryption_key = _FERNET_KEY
    await db_session.commit()
    return user
def _enc(plaintext: str) -> str:
    """Encrypt *plaintext* with ``_FERNET_KEY`` and return the token as text.

    Used by the tests in this module to fill ``MemoryProactive.pattern_encrypted``.
    """
    cipher = Fernet(_FERNET_KEY.encode())
    token = cipher.encrypt(plaintext.encode())
    return token.decode()
# ── _proactive_hints_injection unit tests ─────────────────────────────────────
def test_proactive_hints_injection_with_hints():
    """A context carrying hints yields a section that mentions every hint."""
    hints = ["Works late on Thursdays", "Prefers bullet points"]
    section = _proactive_hints_injection({"proactive_hints": hints})
    # The section must contain the "I noticed" lead-in plus each hint verbatim.
    assert "I noticed" in section
    for hint in hints:
        assert hint in section
def test_proactive_hints_injection_empty():
    """A missing key, an empty list, and None all produce an empty string."""
    no_hint_contexts = (
        {},
        {"proactive_hints": []},
        {"proactive_hints": None},
    )
    for context in no_hint_contexts:
        assert _proactive_hints_injection(context) == ""
def test_proactive_hints_injection_truncates_long_hints():
    """Ten 200-char hints give a section of at most 600 chars ending in '...'."""
    long_hint = "x" * 200
    oversized = [long_hint for _ in range(10)]
    section = _proactive_hints_injection({"proactive_hints": oversized})
    assert len(section) <= 600
    assert section.endswith("...")
# ── enrich_context includes proactive hints ───────────────────────────────────
@pytest.mark.asyncio
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
    """An encrypted 0.8-confidence pattern surfaces in plaintext in ``proactive_hints``."""
    pattern = "Always checks tasks before meetings"
    row = MemoryProactive(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        pattern_encrypted=_enc(pattern),
        confidence=0.8,
        source="inferred",
    )
    db_session.add(row)
    await db_session.commit()

    enriched = await MemoryMiddleware(db_session).enrich_context(USER_ID, "test message")
    assert "proactive_hints" in enriched
    assert pattern in enriched["proactive_hints"]
@pytest.mark.asyncio
async def test_enrich_context_excludes_low_confidence_proactive(db_session, user_with_key):
    """Patterns below the confidence threshold are not surfaced as hints.

    A 0.8-confidence control row is stored alongside the 0.1-confidence row.
    test_enrich_context_returns_proactive_hints expects 0.8 to be surfaced, so
    requiring the control to appear makes this test fail when the proactive
    path is broken entirely, instead of passing vacuously.
    """
    low_pattern = "Low confidence pattern"
    control_pattern = "High confidence control pattern"
    for text, confidence in ((low_pattern, 0.1), (control_pattern, 0.8)):
        db_session.add(MemoryProactive(
            id=str(uuid.uuid4()),
            user_id=USER_ID,
            pattern_encrypted=_enc(text),
            confidence=confidence,
            source="inferred",
        ))
    await db_session.commit()
    middleware = MemoryMiddleware(db_session)
    ctx = await middleware.enrich_context(USER_ID, "test message")
    # Treat a missing key or a None value as "no hints".
    hints = ctx.get("proactive_hints") or []
    assert control_pattern in hints
    assert low_pattern not in hints
# ── proactive hints appear in system prompt string ───────────────────────────
@pytest.mark.asyncio
async def test_proactive_hints_in_system_prompt_string(db_session, user_with_key):
    """A stored 0.75-confidence pattern reaches the system-prompt section.

    Exercises the path from a MemoryProactive row through enrich_context to
    the text built by _proactive_hints_injection.
    """
    pattern = "Frequently requests end-of-day summaries"
    db_session.add(
        MemoryProactive(
            id=str(uuid.uuid4()),
            user_id=USER_ID,
            pattern_encrypted=_enc(pattern),
            confidence=0.75,
            source="inferred",
        )
    )
    await db_session.commit()

    context = await MemoryMiddleware(db_session).enrich_context(USER_ID, "summarize my day")
    assert pattern in _proactive_hints_injection(context)
# ── Tier gate ─────────────────────────────────────────────────────────────────
@pytest.mark.parametrize(
    ["tier", "expected"],
    (
        ("free", False),
        ("pro", False),
        ("power", True),
        ("team", True),
    ),
)
def test_proactive_mining_tier_gate(tier, expected):
    """Of the four tiers checked, only power and team have ``proactive_mining``."""
    enabled = tier_manager.check_feature(tier, "proactive_mining")
    assert enabled == expected