154 lines
5.3 KiB
Python
154 lines
5.3 KiB
Python
"""Tests for Phase 5 — proactive hints surfacing.

Coverage:
1. _proactive_hints_injection returns correct section for seeded hints
2. _proactive_hints_injection returns empty string when no hints
3. enrich_context includes proactive_hints key from MemoryProactive row
4. System prompt includes proactive line when row exists + confidence >= threshold
5. TierManager.check_feature returns True for power/team, False for free/pro
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from cryptography.fernet import Fernet
|
|
from sqlalchemy import select
|
|
|
|
from app.billing.tier_manager import tier_manager
|
|
from app.core.deep_agent import _proactive_hints_injection
|
|
from app.core.memory_middleware import MemoryMiddleware
|
|
from app.db import get_session
|
|
from app.main import app
|
|
from app.models import MemoryProactive, User
|
|
from tests.conftest import TEST_USER_IDS
|
|
|
|
|
|
# User id that every DB-backed test in this module reads and writes under;
# taken from the "power" entry of the shared conftest mapping.
USER_ID = TEST_USER_IDS["power"]
|
|
# Fernet key generated once at import time and kept as ``str``; it is stored on
# the test user and used by ``_enc`` to encrypt the seeded patterns.
_FERNET_KEY = Fernet.generate_key().decode()
|
|
|
|
|
|
# ── DB override ───────────────────────────────────────────────────────────────
|
|
|
|
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Point the app's ``get_session`` dependency at the test ``db_session``.

    Runs automatically for every test in this module. The override is
    installed before the test body and removed again afterwards.
    """

    async def _session_override():
        # Async generator that hands out the shared test session.
        yield db_session

    app.dependency_overrides[get_session] = _session_override
    yield
    # pop() with a default so teardown does not raise if the entry is already gone.
    app.dependency_overrides.pop(get_session, None)
|
|
|
|
|
|
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
|
|
|
@pytest_asyncio.fixture
async def user_with_key(db_session):
    """Return the ``USER_ID`` user after committing ``_FERNET_KEY`` onto it.

    The user row must already exist: ``scalar_one()`` is used, which requires
    exactly one matching row.
    """
    query = select(User).where(User.id == USER_ID)
    user = (await db_session.execute(query)).scalar_one()
    user.encryption_key = _FERNET_KEY
    await db_session.commit()
    return user
|
|
|
|
|
|
def _enc(plaintext: str) -> str:
    """Encrypt ``plaintext`` with the module Fernet key; return the token as ``str``."""
    cipher = Fernet(_FERNET_KEY.encode())
    token = cipher.encrypt(plaintext.encode())
    return token.decode()
|
|
|
|
|
|
# ── _proactive_hints_injection unit tests ─────────────────────────────────────
|
|
|
|
def test_proactive_hints_injection_with_hints():
    """A context carrying hints yields a section that contains every hint."""
    hints = ["Works late on Thursdays", "Prefers bullet points"]
    section = _proactive_hints_injection({"proactive_hints": hints})

    # The section is expected to carry the "I noticed" lead-in text.
    assert "I noticed" in section
    for hint in hints:
        assert hint in section
|
|
|
|
|
|
def test_proactive_hints_injection_empty():
    """A missing, empty or ``None`` hints entry all produce an empty string."""
    hintless_contexts = (
        {},
        {"proactive_hints": []},
        {"proactive_hints": None},
    )
    for context in hintless_contexts:
        assert _proactive_hints_injection(context) == ""
|
|
|
|
|
|
def test_proactive_hints_injection_truncates_long_hints():
    """Oversized input is capped at 600 characters and ends with an ellipsis."""
    # Ten hints of 200 characters each: 2000 characters of raw hint text.
    oversized = ["x" * 200 for _ in range(10)]
    section = _proactive_hints_injection({"proactive_hints": oversized})

    assert len(section) <= 600
    assert section.endswith("...")
|
|
|
|
|
|
# ── enrich_context includes proactive hints ───────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
    """A stored pattern with confidence 0.8 shows up in ``ctx["proactive_hints"]``.

    The row holds the pattern encrypted via ``_enc``; the assertion expects
    the plaintext pattern in the enriched context.
    """
    pattern = "Always checks tasks before meetings"
    row = MemoryProactive(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        pattern_encrypted=_enc(pattern),
        confidence=0.8,
        source="inferred",
    )
    db_session.add(row)
    await db_session.commit()

    ctx = await MemoryMiddleware(db_session).enrich_context(USER_ID, "test message")

    assert "proactive_hints" in ctx
    assert pattern in ctx["proactive_hints"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_enrich_context_excludes_low_confidence_proactive(db_session, user_with_key):
    """A stored pattern with confidence 0.1 must not appear in the hints.

    0.1 is assumed to sit below the middleware's confidence threshold; the
    threshold itself is not visible from this module.
    """
    pattern = "Low confidence pattern"
    row = MemoryProactive(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        pattern_encrypted=_enc(pattern),
        confidence=0.1,
        source="inferred",
    )
    db_session.add(row)
    await db_session.commit()

    ctx = await MemoryMiddleware(db_session).enrich_context(USER_ID, "test message")

    # `or []` covers both a missing key and an explicit None value. With
    # ctx.get(key, []) a None value would make the membership check raise
    # TypeError instead of reporting "no hints".
    hints = ctx.get("proactive_hints") or []
    assert pattern not in hints
|
|
|
|
|
|
# ── proactive hints appear in system prompt string ───────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_proactive_hints_in_system_prompt_string(db_session, user_with_key):
    """End to end: a stored pattern (confidence 0.75) reaches the prompt suffix.

    The context produced by ``enrich_context`` is passed straight into
    ``_proactive_hints_injection`` and the plaintext pattern must appear in
    the resulting string.
    """
    pattern = "Frequently requests end-of-day summaries"
    row = MemoryProactive(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        pattern_encrypted=_enc(pattern),
        confidence=0.75,
        source="inferred",
    )
    db_session.add(row)
    await db_session.commit()

    ctx = await MemoryMiddleware(db_session).enrich_context(USER_ID, "summarize my day")
    system_prompt_suffix = _proactive_hints_injection(ctx)

    assert pattern in system_prompt_suffix
|
|
|
|
|
|
# ── Tier gate ─────────────────────────────────────────────────────────────────
|
|
|
|
@pytest.mark.parametrize(
    "tier,expected",
    [("free", False), ("pro", False), ("power", True), ("team", True)],
)
def test_proactive_mining_tier_gate(tier, expected):
    """``proactive_mining`` is enabled for power/team and disabled for free/pro."""
    allowed = tier_manager.check_feature(tier, "proactive_mining")
    assert allowed == expected
|