"""Response sanitizer middleware.

Scans JSON responses from the /api/v1/chat endpoint and strips any fragments
that could reveal server-side prompt IP:

- System prompt openers ("You are a/an/the …")
- Agent routing metadata ("Available agents:", "intent classifier", …)
- LangChain tool schema fragments (``"type": "function"``)
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
- Exact-match known prompt fingerprints

Binary responses (storage blobs, backup data) are never touched — the
middleware only activates for paths under /api/v1/chat.

Any sanitisation event is logged as a WARNING with the request path and the
names of the fields that were modified.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
|
|
from fastapi import Request, Response
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.types import ASGIApp
|
|
|
|
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Detection tables.
#
# ``_FINGERPRINTS`` are literal, case-sensitive phrases; ``_PATTERNS`` are
# compiled regular expressions.  ``_sanitize_text`` consults the fingerprints
# first and only then runs the regexes, in the order listed here.
# ---------------------------------------------------------------------------

# Literal phrases that are looked up with a plain substring test.
_FINGERPRINTS: tuple[str, ...] = (
    "You are an intent classifier",
    "Respond with just the agent name",
    "Summarize these agent results",
    "Available agents:",
    "route to:",
)

# Each entry is (regex source, flags); a flags value of 0 means "no extra
# flags", which is what ``re.compile(source)`` uses by default.
_PATTERNS: tuple[re.Pattern[str], ...] = tuple(
    re.compile(source, flags)
    for source, flags in (
        # "You are a/an/the ..." plus up to 200 following characters
        # (DOTALL lets the tail run across newlines).
        (r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL),
        (r"Available agents\s*:", re.IGNORECASE),
        (r"\bintent classifier\b", re.IGNORECASE),
        (r'"type"\s*:\s*"function"', 0),  # LangChain tool schema
        # Opening tags only; closing tags are not matched by this pattern.
        (r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE),
        (r"\[INST\]|\[/INST\]", 0),  # Llama instruct markers
        (r"route\s+to\s*:", re.IGNORECASE),
        (r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE),
    )
)
|
|
|
|
|
|
def _sanitize_text(text: str) -> tuple[str, bool]:
    """Redact prompt fragments found in *text*.

    Two passes are made:

    1. If any phrase from ``_FINGERPRINTS`` occurs verbatim in *text*, the
       entire string is replaced by ``[REDACTED]``.
    2. Otherwise every regex in ``_PATTERNS`` is applied in turn, each one
       working on the output of the previous one, and every match is
       replaced by ``[REDACTED]``.

    Returns ``(cleaned_text, was_changed)``; ``was_changed`` is ``True`` when
    at least one replacement was made.
    """
    # A single verbatim fingerprint is enough to blank the whole value.
    if any(phrase in text for phrase in _FINGERPRINTS):
        return "[REDACTED]", True

    cleaned = text
    total_hits = 0
    for regex in _PATTERNS:
        # ``subn`` hands back the input unchanged when nothing matches, so
        # the result can be carried forward unconditionally.
        cleaned, hits = regex.subn("[REDACTED]", cleaned)
        total_hits += hits

    return cleaned, total_hits > 0
|
|
|
|
|
|
class SanitizerMiddleware(BaseHTTPMiddleware):
    """Strip prompt IP from /api/v1/chat JSON responses.

    For every response whose request path starts with ``/api/v1/chat`` the
    body is buffered and parsed as JSON.  Every string value in the parsed
    document — at any nesting depth inside objects and arrays — is passed
    through ``_sanitize_text``.  Object keys are never rewritten.

    * Bodies that are not valid UTF-8 JSON are returned byte-for-byte.
    * Bodies in which nothing was redacted are returned byte-for-byte too
      (they are not re-serialised).
    * When something was redacted, the document is re-serialised, the
      ``content-length`` header is recomputed and a WARNING is logged with
      the request path and the locations of the redacted values.
    """

    def __init__(self, app: ASGIApp) -> None:
        super().__init__(app)

    async def dispatch(self, request: Request, call_next) -> Response:  # type: ignore[override]
        """Run the downstream app and sanitise its response if it is a chat reply.

        Args:
            request: The incoming request; only its URL path is inspected.
            call_next: Callable that forwards the request and returns the
                downstream response.

        Returns:
            The downstream response itself for non-chat paths, otherwise a
            new ``Response`` carrying either the original or the redacted body.
        """
        response: Response = await call_next(request)

        # Only process chat endpoint responses.
        if not request.url.path.startswith("/api/v1/chat"):
            return response

        # Drain the body iterator.  Chunks are collected and joined once
        # instead of growing a bytes object chunk by chunk.
        body_bytes = b"".join(
            [
                chunk if isinstance(chunk, bytes) else chunk.encode()
                async for chunk in response.body_iterator
            ]
        )

        # Skip non-JSON bodies (shouldn't happen on /chat, but be safe).
        try:
            body = json.loads(body_bytes.decode("utf-8"))
        except (json.JSONDecodeError, UnicodeDecodeError):
            return self._passthrough(response, body_bytes)

        body, sanitised_fields = self._sanitize_json(body)

        if not sanitised_fields:
            # Nothing matched: return the upstream bytes verbatim so that
            # formatting and escaping of clean responses are left alone.
            return self._passthrough(response, body_bytes)

        logger.warning(
            "Sanitizer redacted prompt fragments",
            extra={
                "path": request.url.path,
                "fields": sanitised_fields,
            },
        )

        new_body = json.dumps(body).encode("utf-8")
        headers = dict(response.headers)
        # The redacted body differs in size from the upstream one.
        headers["content-length"] = str(len(new_body))

        return Response(
            content=new_body,
            status_code=response.status_code,
            headers=headers,
            media_type="application/json",
        )

    @staticmethod
    def _passthrough(upstream: Response, content: bytes) -> Response:
        """Wrap the already-read *content* in a new response mirroring *upstream*.

        The upstream body iterator has been drained by the caller, so the
        bytes are re-packaged with the same status code, headers and media
        type.
        """
        # NOTE(review): dict() keeps a single value per header name — confirm
        # that no repeated headers (e.g. Set-Cookie) are expected on this route.
        return Response(
            content=content,
            status_code=upstream.status_code,
            headers=dict(upstream.headers),
            media_type=upstream.media_type,
        )

    @staticmethod
    def _sanitize_json(payload: object) -> tuple[object, list[str]]:
        """Sanitise every string value reachable from a parsed JSON document.

        Dicts and lists are modified in place; a bare string document is
        replaced by its cleaned version.  Other scalars are left untouched.

        Returns:
            ``(document, locations)`` where *locations* names each redacted
            value: a plain key for top-level fields (``"reply"``), a dotted /
            indexed path for nested ones (``"data.items[0].text"``) and
            ``"<root>"`` when the whole document is a single string.
        """
        touched: list[str] = []

        if isinstance(payload, str):
            cleaned, changed = _sanitize_text(payload)
            if changed:
                touched.append("<root>")
            return cleaned, touched

        if not isinstance(payload, (dict, list)):
            return payload, touched

        # Explicit work list instead of recursion, so nesting depth is not
        # limited by the interpreter's recursion limit.
        pending: list[tuple[dict | list, str]] = [(payload, "")]
        while pending:
            container, prefix = pending.pop()

            # Snapshot (key, label) pairs up front; only values are replaced
            # below, so the container's size never changes while it is walked.
            if isinstance(container, dict):
                entries = [
                    (key, f"{prefix}.{key}" if prefix else key)
                    for key in container
                ]
            else:
                entries = [
                    (index, f"{prefix}[{index}]")
                    for index in range(len(container))
                ]

            for key, label in entries:
                value = container[key]
                if isinstance(value, str):
                    cleaned, changed = _sanitize_text(value)
                    if changed:
                        container[key] = cleaned
                        touched.append(label)
                elif isinstance(value, (dict, list)):
                    pending.append((value, label))

        return payload, touched
|