PHASE 5 — Proactive mining (Power tier only)
This commit is contained in:
@@ -57,6 +57,13 @@ LLM_MODEL_SETUP_AGENT=
|
||||
# Defaults to gpt-4o-mini when empty (fast + cheap, temperature=0).
|
||||
LLM_MODEL_MEMORY_EXTRACTOR=
|
||||
|
||||
# Memory-miner — proactive pattern mining from episodic history (Phase 5, Power+ only).
|
||||
# Defaults to gpt-4o-mini when empty.
|
||||
LLM_MODEL_MEMORY_MINER=
|
||||
|
||||
# Scheduler — set to false to disable memory cron jobs (automatically false in tests).
|
||||
SCHEDULER_ENABLED=true
|
||||
|
||||
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
||||
STRIPE_SECRET_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
|
||||
@@ -28,6 +28,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"real_embeddings": False, # keyword fallback only
|
||||
"realtime_extraction": False, # batch queue (Phase 2)
|
||||
"relational_memory": False, # relational tier (Phase 3) — Pro+
|
||||
"proactive_mining": False, # Power+ only (Phase 5)
|
||||
},
|
||||
"pro": {
|
||||
"agents": -1, # unlimited
|
||||
@@ -39,6 +40,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"real_embeddings": True, # pgvector cosine search
|
||||
"realtime_extraction": True, # fire-and-forget asyncio.create_task
|
||||
"relational_memory": True, # person/project predicates
|
||||
"proactive_mining": False, # Power+ only (Phase 5)
|
||||
},
|
||||
"power": {
|
||||
"agents": -1,
|
||||
@@ -50,6 +52,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"real_embeddings": True,
|
||||
"realtime_extraction": True,
|
||||
"relational_memory": True, # all predicates incl. custom
|
||||
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
||||
},
|
||||
"team": {
|
||||
"agents": -1,
|
||||
@@ -61,6 +64,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"real_embeddings": True,
|
||||
"realtime_extraction": True,
|
||||
"relational_memory": True, # all predicates incl. custom
|
||||
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ class Settings(BaseSettings):
|
||||
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
|
||||
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
||||
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
|
||||
LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)
|
||||
|
||||
# GitHub Copilot OAuth token storage directory.
|
||||
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
||||
@@ -70,6 +71,8 @@ class Settings(BaseSettings):
|
||||
LANGFUSE_PUBLIC_KEY: str = ""
|
||||
LANGFUSE_BASE_URL: str = "https://cloud.langfuse.com"
|
||||
|
||||
SCHEDULER_ENABLED: bool = True
|
||||
|
||||
ENV: Literal["dev", "prod"] = "dev"
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
||||
|
||||
@@ -55,6 +55,22 @@ def _language_instruction(context: dict[str, Any]) -> str:
|
||||
f"All your output text must be written in {lang}."
|
||||
)
|
||||
|
||||
def _proactive_hints_injection(context: dict[str, Any]) -> str:
    """Build the system-prompt suffix that lists proactive behavioral hints.

    The hints are read from ``context["proactive_hints"]``. A missing key,
    ``None`` or an empty collection yields an empty string. This function
    performs no filtering of its own: every hint present in the context is
    rendered as one ``- hint`` line under an "I noticed" heading.

    The returned text is at most 600 characters long; anything longer is cut
    to 597 characters and terminated with ``...``.
    """
    max_chars = 600
    ellipsis = "..."

    observed = context.get("proactive_hints") or []
    if not observed:
        return ""

    bullet_lines = [f"- {hint}" for hint in observed]
    paragraph = "\n\nI noticed (behavioral patterns):\n" + "\n".join(bullet_lines)

    if len(paragraph) <= max_chars:
        return paragraph
    # Reserve room for the ellipsis so the cap includes it.
    return paragraph[: max_chars - len(ellipsis)] + ellipsis
|
||||
|
||||
|
||||
def _relational_memory_injection(context: dict[str, Any]) -> str:
|
||||
"""Return a system-prompt paragraph listing known people/projects from relational memory.
|
||||
|
||||
@@ -921,6 +937,7 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
||||
"home_system", _HOME_SYSTEM_PROMPT
|
||||
)
|
||||
system_prompt += _relational_memory_injection(context)
|
||||
system_prompt += _proactive_hints_injection(context)
|
||||
system_prompt += _language_instruction(context)
|
||||
response = await _run_single_agent(
|
||||
user_id=user_id,
|
||||
@@ -940,6 +957,7 @@ async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> t
|
||||
"floating_system", _FLOATING_SYSTEM_PROMPT
|
||||
)
|
||||
system_prompt += _relational_memory_injection(context)
|
||||
system_prompt += _proactive_hints_injection(context)
|
||||
system_prompt += _language_instruction(context)
|
||||
response = await _run_single_agent(
|
||||
user_id=user_id,
|
||||
@@ -965,6 +983,7 @@ async def run_home_stream(
|
||||
"home_system", _HOME_SYSTEM_PROMPT
|
||||
)
|
||||
system_prompt += _relational_memory_injection(context)
|
||||
system_prompt += _proactive_hints_injection(context)
|
||||
system_prompt += _language_instruction(context)
|
||||
text_chunks: list[str] = []
|
||||
async for event in _run_single_agent_stream(
|
||||
@@ -999,6 +1018,7 @@ async def run_floating_stream(
|
||||
"floating_system", _FLOATING_SYSTEM_PROMPT
|
||||
)
|
||||
system_prompt += _relational_memory_injection(context)
|
||||
system_prompt += _proactive_hints_injection(context)
|
||||
system_prompt += _language_instruction(context)
|
||||
sanitizer = _FloatingStreamSanitizer()
|
||||
emitted_sanitized = False
|
||||
|
||||
@@ -104,6 +104,7 @@ _AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
||||
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
||||
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
||||
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
|
||||
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,29 +1,41 @@
|
||||
"""Memory maintenance jobs — Phase 3/5.
|
||||
|
||||
Two entrypoints called by the scheduler (APScheduler) registered in app/main.py:
|
||||
Three entrypoints called by the scheduler (APScheduler) registered in app/main.py:
|
||||
|
||||
drain_extraction_queue(db) — Free-tier batch extraction (Phase 2/5).
|
||||
mine_proactive_patterns(db, user_id) — Power+ pattern mining (Phase 5).
|
||||
decay_relations(db, user_id) — confidence decay + pruning for memory_relations (Phase 3).
|
||||
|
||||
Both are safe to call manually or from tests; they never raise.
|
||||
All are safe to call manually or from tests; they never raise.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import MemoryRelation
|
||||
from app.models import MemoryEpisodic, MemoryProactive, MemoryRelation, User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Decay parameters
|
||||
_DECAY_FACTOR = 0.95 # multiply confidence by this every _DECAY_PERIOD days
|
||||
_DECAY_PERIOD_DAYS = 30 # period for one decay step
|
||||
_PRUNE_THRESHOLD = 0.2 # rows below this confidence are deleted
|
||||
# Decay parameters for relations
|
||||
_DECAY_FACTOR = 0.95
|
||||
_DECAY_PERIOD_DAYS = 30
|
||||
_PRUNE_THRESHOLD = 0.2
|
||||
|
||||
# Proactive pattern decay: 10 % per 7 days since last sighting
|
||||
_PROACTIVE_DECAY_FACTOR = 0.9
|
||||
_PROACTIVE_DECAY_PERIOD_DAYS = 7
|
||||
_PROACTIVE_PRUNE_THRESHOLD = 0.2
|
||||
|
||||
# Mining: require at least this many episodes to attempt pattern extraction
|
||||
_MIN_EPISODES_FOR_MINING = 3
|
||||
_MINING_LOOKBACK_DAYS = 30
|
||||
|
||||
|
||||
async def decay_relations(db: AsyncSession, user_id: str) -> None:
|
||||
@@ -53,7 +65,6 @@ async def _decay_relations_inner(db: AsyncSession, user_id: str) -> None:
|
||||
reference = row.last_confirmed_at or row.created_at
|
||||
if reference is None:
|
||||
continue
|
||||
# Ensure timezone-aware comparison
|
||||
if reference.tzinfo is None:
|
||||
reference = reference.replace(tzinfo=timezone.utc)
|
||||
|
||||
@@ -88,15 +99,215 @@ async def _decay_relations_inner(db: AsyncSession, user_id: str) -> None:
|
||||
|
||||
|
||||
async def drain_extraction_queue(db: AsyncSession) -> None:
|
||||
"""Process pending ExtractionQueue rows for Free-tier users (Phase 5 stub).
|
||||
"""Process pending ExtractionQueue rows for Free-tier users.
|
||||
|
||||
Full implementation wired in Phase 5 when APScheduler is registered.
|
||||
Currently logs count and returns.
|
||||
Each row corresponds to a stored episode that should be fed through the
|
||||
Mem0-style extraction pipeline. Rows are deleted after successful processing.
|
||||
Never raises — wraps in try/except.
|
||||
"""
|
||||
try:
|
||||
from app.models import ExtractionQueue # noqa: PLC0415
|
||||
result = await db.execute(select(ExtractionQueue))
|
||||
rows = result.scalars().all()
|
||||
logger.info("memory_maintenance: drain_extraction_queue pending=%d (Phase 5 cron)", len(rows))
|
||||
await _drain_extraction_queue_inner(db)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: drain_extraction_queue failed: %s", exc)
|
||||
|
||||
|
||||
async def _drain_extraction_queue_inner(db: AsyncSession) -> None:
    """Run the extraction pipeline for every queued row and delete the rows that succeed.

    All ``ExtractionQueue`` rows are loaded in one query. Each row is processed
    in its own commit; a failure on one row is logged, rolled back, and does not
    stop the remaining rows. Exceptions from the initial query propagate to the
    caller.
    """
    from app.models import ExtractionQueue  # noqa: PLC0415

    result = await db.execute(select(ExtractionQueue))
    rows = result.scalars().all()

    if not rows:
        logger.debug("memory_maintenance: drain_extraction_queue nothing to drain")
        return

    logger.info("memory_maintenance: drain_extraction_queue pending=%d", len(rows))

    from app.core.memory_extraction import run_extraction  # noqa: PLC0415

    # Snapshot the identifiers while the rows are freshly loaded. Each commit()
    # or rollback() below may expire every instance in the session, and reading
    # an expired attribute afterwards would trigger implicit IO, which
    # AsyncSession does not allow outside an awaited call.
    pending = [(row, row.id, row.user_id) for row in rows]

    processed = 0
    for row, row_id, row_user_id in pending:
        try:
            # NOTE(review): the queued episode's text is not forwarded here; the
            # pipeline receives empty messages — confirm run_extraction loads the
            # episode content itself.
            await run_extraction(
                db=db,
                user_id=row_user_id,
                last_user_msg="",
                last_assistant_msg="",
                session_id=None,
            )
            await db.delete(row)
            await db.commit()
            processed += 1
        except Exception as exc:
            logger.warning(
                "memory_maintenance: drain failed row=%s user=%s: %s",
                row_id, row_user_id, exc,
            )
            await db.rollback()

    logger.info("memory_maintenance: drain_extraction_queue processed=%d/%d", processed, len(pending))
|
||||
|
||||
|
||||
async def mine_proactive_patterns(db: AsyncSession, user_id: str) -> None:
    """Mine recurring behavioral patterns from recent episodes (Power+ only).

    Thin never-raising wrapper around ``_mine_proactive_patterns_inner``, which:
      1. Gates on the ``proactive_mining`` tier feature.
      2. Loads and decrypts the episodic summaries inside the lookback window
         (``_MINING_LOOKBACK_DAYS``).
      3. Asks the ``memory-miner`` agent LLM to identify recurring patterns.
      4. Encrypts and stores each pattern in memory_proactive.
      5. Applies decay to existing proactive rows.

    Any exception is logged as a warning and swallowed. The session is then
    rolled back so a caller that keeps using ``db`` does not inherit a failed
    transaction.
    """
    try:
        await _mine_proactive_patterns_inner(db, user_id)
    except Exception as exc:
        logger.warning("memory_maintenance: mine_proactive_patterns failed user=%s: %s", user_id, exc)
        try:
            await db.rollback()
        except Exception as rollback_exc:
            # The rollback itself is best-effort; this function must not raise.
            logger.warning(
                "memory_maintenance: mine_proactive_patterns rollback failed user=%s: %s",
                user_id, rollback_exc,
            )
|
||||
|
||||
|
||||
async def _mine_proactive_patterns_inner(db: AsyncSession, user_id: str) -> None:
    """Mine, de-duplicate and store proactive patterns for one user, then decay old ones.

    Returns early (without storing or decaying anything) when the user's tier
    lacks ``proactive_mining``, the user row or its encryption key is missing,
    fewer than ``_MIN_EPISODES_FOR_MINING`` episodes exist in the lookback
    window, no summary could be decrypted, or the LLM returned no patterns.

    Patterns whose normalized text (whitespace-collapsed, case-folded) matches
    a pattern already stored for the user, or an earlier pattern in the same
    batch, are skipped so repeated runs do not pile up identical rows.

    NOTE(review): decay only runs after the commit of newly mined patterns
    succeeds, so users that hit one of the early returns never have their
    existing patterns decayed — confirm this is intended.

    Exceptions from the tier lookup, queries or key construction propagate to
    the caller.
    """
    from app.billing.tier_manager import tier_manager  # noqa: PLC0415

    tier = await tier_manager.get_tier(user_id, db)
    if not tier_manager.check_feature(tier, "proactive_mining"):
        logger.debug("memory_maintenance: mine_proactive_patterns skipped (tier=%s)", tier)
        return

    # Load user Fernet key
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None or not user.encryption_key:
        logger.warning("memory_maintenance: mine_proactive_patterns no encryption_key user=%s", user_id)
        return

    fernet = Fernet(user.encryption_key.encode())
    cutoff = datetime.now(timezone.utc) - timedelta(days=_MINING_LOOKBACK_DAYS)

    episodes_result = await db.execute(
        select(MemoryEpisodic)
        .where(
            MemoryEpisodic.user_id == user_id,
            MemoryEpisodic.created_at >= cutoff,
        )
        .order_by(MemoryEpisodic.created_at.asc())
    )
    episode_rows = episodes_result.scalars().all()

    if len(episode_rows) < _MIN_EPISODES_FOR_MINING:
        logger.info(
            "memory_maintenance: mine_proactive_patterns skipped user=%s episodes=%d (< %d)",
            user_id, len(episode_rows), _MIN_EPISODES_FOR_MINING,
        )
        return

    summaries: list[str] = []
    undecryptable = 0
    for ep in episode_rows:
        try:
            summaries.append(fernet.decrypt(ep.summary_encrypted.encode()).decode())
        except Exception:
            # Best-effort: one unreadable episode must not abort the run, but
            # the skip is counted so it is visible in the logs.
            undecryptable += 1

    if undecryptable:
        logger.warning(
            "memory_maintenance: mine_proactive_patterns user=%s skipped %d undecryptable episodes",
            user_id, undecryptable,
        )

    if not summaries:
        return

    patterns = await _extract_proactive_patterns(summaries)
    if not patterns:
        logger.info("memory_maintenance: mine_proactive_patterns user=%s no patterns extracted", user_id)
        return

    # Collect the normalized text of patterns already stored for this user so
    # the same sentence is not inserted again on every scheduled run.
    existing_result = await db.execute(
        select(MemoryProactive).where(MemoryProactive.user_id == user_id)
    )
    known_patterns: set[str] = set()
    for existing in existing_result.scalars().all():
        try:
            existing_text = fernet.decrypt(existing.pattern_encrypted.encode()).decode()
        except Exception:
            # An unreadable stored row cannot be compared; ignore it for de-duplication.
            continue
        known_patterns.add(" ".join(existing_text.split()).casefold())

    stored = 0
    for pattern_text in patterns:
        normalized = " ".join(pattern_text.split()).casefold()
        if not normalized or normalized in known_patterns:
            continue
        known_patterns.add(normalized)
        try:
            encrypted = fernet.encrypt(pattern_text.encode()).decode()
            row = MemoryProactive(
                id=str(uuid.uuid4()),
                user_id=user_id,
                pattern_encrypted=encrypted,
                confidence=0.7,
                source="inferred",
            )
            db.add(row)
            stored += 1
        except Exception as exc:
            logger.warning("memory_maintenance: failed to store pattern user=%s: %s", user_id, exc)

    try:
        await db.commit()
        logger.info(
            "memory_maintenance: mine_proactive_patterns user=%s stored=%d",
            user_id, stored,
        )
    except Exception as exc:
        logger.warning("memory_maintenance: mine_proactive_patterns commit failed user=%s: %s", user_id, exc)
        await db.rollback()
        return

    await _decay_proactive_patterns(db, user_id, fernet)
|
||||
|
||||
|
||||
async def _extract_proactive_patterns(summaries: list[str]) -> list[str]:
    """Ask the memory-miner LLM for recurring behavioral/temporal patterns.

    Only the last 20 summaries are sent, to bound token usage. The reply is
    split into lines; blank lines are dropped, and any leading bullet or
    numbering marker the model adds despite the prompt is stripped. At most
    five patterns are returned.

    Returns an empty list when ``summaries`` is empty (the LLM is not called)
    or when the LLM call or parsing fails (logged as a warning). An exception
    raised while constructing the LLM client propagates to the caller.
    """
    if not summaries:
        # Nothing to analyze: avoid building a client and paying for a call.
        return []

    import re  # noqa: PLC0415

    from app.core.llm import get_agent_llm  # noqa: PLC0415

    llm = get_agent_llm("memory-miner", temperature=0)
    combined = "\n---\n".join(summaries[-20:])  # cap at last 20 to control token usage
    prompt = (
        "You are analyzing conversation history for a personal AI secretary. "
        "Identify 3-5 recurring temporal or behavioral patterns (e.g. 'always works late on Thursdays', "
        "'prefers bullet-point summaries', 'frequently asks about Project Acme status'). "
        "Return each pattern as a plain, short English sentence on its own line. "
        "No numbering, no bullet points, no extra text.\n\n"
        f"Conversation history:\n{combined}"
    )
    # Matches a leading "-", "*", "•" run or "1." / "1)" followed by whitespace.
    # The mandatory whitespace keeps sentences such as "3.5 hours ..." intact.
    list_marker = re.compile(r"^(?:[-*\u2022]+|\d+[.)])\s+")
    try:
        response = await llm.ainvoke(prompt)
        text = response.content if hasattr(response, "content") else str(response)
        cleaned: list[str] = []
        for raw_line in str(text).splitlines():
            line = list_marker.sub("", raw_line.strip()).strip()
            if line:
                cleaned.append(line)
        return cleaned[:5]
    except Exception as exc:
        logger.warning("memory_maintenance: _extract_proactive_patterns LLM failed: %s", exc)
        return []
|
||||
|
||||
|
||||
async def _decay_proactive_patterns(db: AsyncSession, user_id: str, fernet: Fernet) -> None:
    """Lower the confidence of a user's proactive patterns by age and prune weak ones.

    For every ``MemoryProactive`` row of the user that is at least
    ``_PROACTIVE_DECAY_PERIOD_DAYS`` old (measured from ``created_at``; naive
    timestamps are treated as UTC), the stored confidence is multiplied by
    ``_PROACTIVE_DECAY_FACTOR`` once per full period elapsed. Rows whose new
    confidence falls below ``_PROACTIVE_PRUNE_THRESHOLD`` are deleted. Rows
    without ``created_at`` are left untouched.

    A failing commit is logged and rolled back rather than raised. The
    ``fernet`` argument is accepted but not used here.

    NOTE(review): the factor is recomputed from ``created_at`` and applied to
    the already-decayed ``confidence`` on every call, so repeated runs compound
    the decay — confirm whether a per-row "last decayed" marker is needed.
    """
    query = select(MemoryProactive).where(MemoryProactive.user_id == user_id)
    patterns = (await db.execute(query)).scalars().all()
    now = datetime.now(timezone.utc)

    pruned_count = 0
    decayed_count = 0

    for pattern in patterns:
        created = pattern.created_at
        if created is None:
            continue
        if created.tzinfo is None:
            created = created.replace(tzinfo=timezone.utc)

        age_days = (now - created).days
        full_periods = age_days // _PROACTIVE_DECAY_PERIOD_DAYS
        if full_periods < 1:
            # Younger than one decay period: nothing to apply yet.
            continue

        reduced = pattern.confidence * (_PROACTIVE_DECAY_FACTOR ** full_periods)
        if reduced < _PROACTIVE_PRUNE_THRESHOLD:
            await db.delete(pattern)
            pruned_count += 1
            continue

        pattern.confidence = reduced
        decayed_count += 1

    try:
        await db.commit()
    except Exception as exc:
        logger.warning("memory_maintenance: decay_proactive commit failed user=%s: %s", user_id, exc)
        await db.rollback()
    else:
        logger.info(
            "memory_maintenance: decay_proactive user=%s decayed=%d deleted=%d",
            user_id, decayed_count, pruned_count,
        )
|
||||
|
||||
46
app/main.py
46
app/main.py
@@ -16,13 +16,59 @@ from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
async def _memory_cron_tick() -> None:
    """Run one pass of the memory cron job.

    The pass drains the extraction queue in one session, lists every user id
    in a second session, then opens a fresh session per user and mines
    proactive patterns for those whose tier has the ``proactive_mining``
    feature. A failure for one user is logged and the loop continues; any
    other failure (including the deferred imports) is logged as well, so this
    coroutine never raises.
    """
    import logging  # noqa: PLC0415

    cron_log = logging.getLogger(__name__)
    cron_log.info("memory cron tick: starting")
    try:
        from app.db import async_session  # noqa: PLC0415
        from app.core.memory_maintenance import drain_extraction_queue, mine_proactive_patterns  # noqa: PLC0415
        from app.billing.tier_manager import tier_manager  # noqa: PLC0415
        from app.models import User  # noqa: PLC0415
        from sqlalchemy import select  # noqa: PLC0415

        async with async_session() as drain_session:
            await drain_extraction_queue(drain_session)

        # mine proactive patterns for every Power+ user
        async with async_session() as listing_session:
            id_rows = await listing_session.execute(select(User.id))
            user_ids: list[str] = list(id_rows.scalars().all())

        for uid in user_ids:
            try:
                async with async_session() as user_session:
                    tier = await tier_manager.get_tier(uid, user_session)
                    if not tier_manager.check_feature(tier, "proactive_mining"):
                        continue
                    await mine_proactive_patterns(user_session, uid)
            except Exception as exc:
                cron_log.warning("memory cron tick: mine_proactive_patterns failed user=%s: %s", uid, exc)

        cron_log.info("memory cron tick: done users=%d", len(user_ids))
    except Exception as exc:
        cron_log.warning("memory cron tick: failed: %s", exc)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load agent modules, run the memory cron, clean up on exit.

    Startup imports ``app.agents`` and, when ``settings.SCHEDULER_ENABLED`` is
    true, starts an ``AsyncIOScheduler`` that runs ``_memory_cron_tick`` every
    hour. Shutdown stops the scheduler (without waiting for running jobs) and
    disposes the SQLAlchemy engine. The shutdown steps sit in ``finally``
    blocks so they also run when an exception is thrown into the generator at
    the ``yield``.
    """
    # Startup: ensure agent tool modules are loaded.
    import app.agents  # noqa: F401

    scheduler = None
    if settings.SCHEDULER_ENABLED:
        # Imported locally so this block does not rely on a module-level
        # ``logging`` import.
        import logging  # noqa: PLC0415

        from apscheduler.schedulers.asyncio import AsyncIOScheduler  # noqa: PLC0415

        scheduler = AsyncIOScheduler()
        scheduler.add_job(_memory_cron_tick, "interval", hours=1, id="memory_cron")
        scheduler.start()
        logging.getLogger(__name__).info("memory cron scheduler started (interval=1h)")

    try:
        yield
    finally:
        try:
            if scheduler is not None:
                scheduler.shutdown(wait=False)
        finally:
            # Shutdown: dispose SQLAlchemy connection pool
            from app.db import engine

            await engine.dispose()
|
||||
|
||||
@@ -37,4 +37,5 @@ langfuse>=2.0.0
|
||||
beautifulsoup4>=4.12.0
|
||||
lxml>=5.0.0
|
||||
PyYAML>=6.0.0
|
||||
apscheduler>=3.10.0
|
||||
ruff>=0.8.0
|
||||
|
||||
1
results.xml
Normal file
1
results.xml
Normal file
File diff suppressed because one or more lines are too long
153
tests/test_memory_proactive.py
Normal file
153
tests/test_memory_proactive.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Tests for Phase 5 — proactive hints surfacing.
|
||||
|
||||
Coverage:
|
||||
1. _proactive_hints_injection returns correct section for seeded hints
|
||||
2. _proactive_hints_injection returns empty string when no hints
|
||||
3. enrich_context includes proactive_hints key from MemoryProactive row
|
||||
4. System prompt includes proactive line when row exists + confidence >= threshold
|
||||
5. TierManager.check_feature returns True for power/team, False for free/pro
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.billing.tier_manager import tier_manager
|
||||
from app.core.deep_agent import _proactive_hints_injection
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.models import MemoryProactive, User
|
||||
from tests.conftest import TEST_USER_IDS
|
||||
|
||||
|
||||
# Id of the "power" entry in the shared conftest user mapping; all tests below act as this user.
USER_ID = TEST_USER_IDS["power"]
# Fernet key generated once at import time; assigned to the user and used to encrypt seeded patterns.
_FERNET_KEY = Fernet.generate_key().decode()
|
||||
|
||||
|
||||
# ── DB override ───────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Route the app's ``get_session`` dependency to the test session for each test.

    The override is installed before the test runs and removed afterwards.
    """
    async def _yield_test_session():
        yield db_session

    overrides = app.dependency_overrides
    overrides[get_session] = _yield_test_session
    yield
    overrides.pop(get_session, None)
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest_asyncio.fixture
async def user_with_key(db_session):
    """Return the test user after assigning it the module's Fernet key.

    The user row for ``USER_ID`` must already exist (``scalar_one`` raises
    otherwise); the key assignment is committed before the user is returned.
    """
    lookup = select(User).where(User.id == USER_ID)
    power_user = (await db_session.execute(lookup)).scalar_one()
    power_user.encryption_key = _FERNET_KEY
    await db_session.commit()
    return power_user
|
||||
|
||||
|
||||
def _enc(plaintext: str) -> str:
    """Encrypt ``plaintext`` with the module's Fernet key and return the token as text."""
    cipher = Fernet(_FERNET_KEY.encode())
    token = cipher.encrypt(plaintext.encode())
    return token.decode()
|
||||
|
||||
|
||||
# ── _proactive_hints_injection unit tests ─────────────────────────────────────
|
||||
|
||||
def test_proactive_hints_injection_with_hints():
    """The rendered section carries the heading and every seeded hint."""
    seeded = ["Works late on Thursdays", "Prefers bullet points"]
    rendered = _proactive_hints_injection({"proactive_hints": seeded})
    assert "I noticed" in rendered
    for hint in seeded:
        assert hint in rendered
|
||||
|
||||
|
||||
def test_proactive_hints_injection_empty():
    """A missing key, an empty list and ``None`` all render as an empty string."""
    for empty_context in ({}, {"proactive_hints": []}, {"proactive_hints": None}):
        assert _proactive_hints_injection(empty_context) == ""
|
||||
|
||||
|
||||
def test_proactive_hints_injection_truncates_long_hints():
    """Oversized input is capped at 600 characters and ends with an ellipsis."""
    oversized = {"proactive_hints": ["x" * 200 for _ in range(10)]}
    rendered = _proactive_hints_injection(oversized)
    assert len(rendered) <= 600
    assert rendered.endswith("...")
|
||||
|
||||
|
||||
# ── enrich_context includes proactive hints ───────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
    """A stored pattern with confidence 0.8 shows up under ``proactive_hints``."""
    pattern = "Always checks tasks before meetings"
    seeded_row = MemoryProactive(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        pattern_encrypted=_enc(pattern),
        confidence=0.8,
        source="inferred",
    )
    db_session.add(seeded_row)
    await db_session.commit()

    enriched = await MemoryMiddleware(db_session).enrich_context(USER_ID, "test message")

    assert "proactive_hints" in enriched
    assert pattern in enriched["proactive_hints"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_enrich_context_excludes_low_confidence_proactive(db_session, user_with_key):
    """A stored pattern with confidence 0.1 is not surfaced as a hint."""
    pattern = "Low confidence pattern"
    weak_row = MemoryProactive(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        pattern_encrypted=_enc(pattern),
        confidence=0.1,
        source="inferred",
    )
    db_session.add(weak_row)
    await db_session.commit()

    enriched = await MemoryMiddleware(db_session).enrich_context(USER_ID, "test message")

    surfaced = enriched.get("proactive_hints", [])
    assert pattern not in surfaced
|
||||
|
||||
|
||||
# ── proactive hints appear in system prompt string ───────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
async def test_proactive_hints_in_system_prompt_string(db_session, user_with_key):
    """A stored pattern travels through ``enrich_context`` into the rendered prompt suffix."""
    pattern = "Frequently requests end-of-day summaries"
    seeded_row = MemoryProactive(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        pattern_encrypted=_enc(pattern),
        confidence=0.75,
        source="inferred",
    )
    db_session.add(seeded_row)
    await db_session.commit()

    enriched = await MemoryMiddleware(db_session).enrich_context(USER_ID, "summarize my day")

    system_prompt_suffix = _proactive_hints_injection(enriched)
    assert pattern in system_prompt_suffix
|
||||
|
||||
|
||||
# ── Tier gate ─────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.parametrize("tier,expected", [
    ("free", False),
    ("pro", False),
    ("power", True),
    ("team", True),
])
def test_proactive_mining_tier_gate(tier, expected):
    """``proactive_mining`` is enabled for power/team and disabled for free/pro."""
    allowed = tier_manager.check_feature(tier, "proactive_mining")
    assert allowed == expected
|
||||
Reference in New Issue
Block a user