"""Response sanitizer middleware. Scans JSON responses from the /api/v1/chat endpoint and strips any fragments that could reveal server-side prompt IP: - System prompt openers ("You are a/an/the …") - Agent routing metadata ("Available agents:", "intent classifier", …) - LangChain tool schema fragments (``"type": "function"``) - Internal reasoning markers (, , [INST], …) - Exact-match known prompt fingerprints Binary responses (storage blobs, backup data) are never touched — the middleware only activates for paths under /api/v1/chat. Any sanitisation event is logged as a WARNING with the request path and the names of the fields that were modified. """ from __future__ import annotations import json import logging import re from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware from starlette.types import ASGIApp logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Detection patterns — order matters: fingerprints checked first (exact), # then compiled regexes. # --------------------------------------------------------------------------- _FINGERPRINTS: tuple[str, ...] = ( "You are an intent classifier", "Respond with just the agent name", "Summarize these agent results", "Available agents:", "route to:", ) _PATTERNS: tuple[re.Pattern[str], ...] = ( re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL), re.compile(r"Available agents\s*:", re.IGNORECASE), re.compile(r"\bintent classifier\b", re.IGNORECASE), re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE), re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers re.compile(r"route\s+to\s*:", re.IGNORECASE), re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE), ) def _sanitize_text(text: str) -> tuple[str, bool]: """Scan *text* for prompt fragments and replace matches with ``[REDACTED]``. Returns ``(cleaned_text, was_changed)``. """ # Fingerprint check — if any exact phrase is present, redact the whole string. for fp in _FINGERPRINTS: if fp in text: return "[REDACTED]", True changed = False for pattern in _PATTERNS: new_text, n = pattern.subn("[REDACTED]", text) if n: text = new_text changed = True return text, changed class SanitizerMiddleware(BaseHTTPMiddleware): """Strip prompt IP from /api/v1/chat JSON responses.""" def __init__(self, app: ASGIApp) -> None: super().__init__(app) async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override] response: Response = await call_next(request) # Only process chat endpoint responses. if not request.url.path.startswith("/api/v1/chat"): return response # Read body — collect streaming chunks. body_bytes = b"" async for chunk in response.body_iterator: body_bytes += chunk if isinstance(chunk, bytes) else chunk.encode() # Skip non-JSON bodies (shouldn't happen on /chat, but be safe). try: body = json.loads(body_bytes.decode("utf-8")) except (json.JSONDecodeError, UnicodeDecodeError): return Response( content=body_bytes, status_code=response.status_code, headers=dict(response.headers), media_type=response.media_type, ) if not isinstance(body, dict): return Response( content=body_bytes, status_code=response.status_code, headers=dict(response.headers), media_type=response.media_type, ) # Walk top-level string fields and sanitise. sanitised_fields: list[str] = [] for key, value in body.items(): if isinstance(value, str): cleaned, changed = _sanitize_text(value) if changed: body[key] = cleaned sanitised_fields.append(key) if sanitised_fields: logger.warning( "Sanitizer redacted prompt fragments", extra={ "path": request.url.path, "fields": sanitised_fields, }, ) new_body = json.dumps(body).encode("utf-8") headers = dict(response.headers) headers["content-length"] = str(len(new_body)) return Response( content=new_body, status_code=response.status_code, headers=headers, media_type="application/json", )