"""
Middleware for AI Service.

[AC-AISVC-10, AC-AISVC-12, AC-AISVC-50] X-Tenant-Id header validation, tenant context injection, and API Key authentication.
"""
|
|
|
|
import logging
|
|
import re
|
|
from typing import Callable
|
|
|
|
from fastapi import Request, Response, status
|
|
from fastapi.responses import JSONResponse
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
from app.core.exceptions import ErrorCode, ErrorResponse, MissingTenantIdException
|
|
from app.core.tenant import clear_tenant_context, set_tenant_context
|
|
|
|
logger = logging.getLogger(__name__)

# Request header names read by the middlewares and helpers in this module.
TENANT_ID_HEADER = "X-Tenant-Id"
API_KEY_HEADER = "X-API-Key"
ACCEPT_HEADER = "Accept"

# Media type whose presence in the Accept header selects SSE streaming.
SSE_CONTENT_TYPE = "text/event-stream"

# Tenant IDs look like "<name>@ash@<4 digits>", e.g. "szmp@ash@2026".
# The name segment may contain any character except "@".
TENANT_ID_PATTERN = re.compile(r'^[^@]+@ash@\d{4}$')

# Endpoints that ApiKeyMiddleware lets through without an X-API-Key header.
PATHS_SKIP_API_KEY = {"/health", "/ai/health", "/docs", "/redoc", "/openapi.json"}
|
|
|
|
|
|
def validate_tenant_id_format(tenant_id: str) -> bool:
    """
    [AC-AISVC-10] Validate tenant ID format: name@ash@year

    Examples: szmp@ash@2026, abc123@ash@2025

    The entire string must match TENANT_ID_PATTERN. ``fullmatch`` is used
    rather than ``match`` because the pattern's ``$`` anchor also matches
    just before a trailing newline, which would otherwise let a value such
    as ``"szmp@ash@2026\\n"`` pass and later yield a year of ``"2026\\n"``.

    Args:
        tenant_id: Candidate tenant ID string.

    Returns:
        True if ``tenant_id`` is exactly ``<name>@ash@<4 digits>``, else False.
    """
    return TENANT_ID_PATTERN.fullmatch(tenant_id) is not None
|
|
|
|
|
|
def parse_tenant_id(tenant_id: str) -> tuple[str, str]:
    """
    [AC-AISVC-10] Split a tenant ID into its name and year parts.

    Intended for IDs already accepted by ``validate_tenant_id_format``
    (``<name>@ash@<year>``). Only the first and third "@"-separated
    segments are returned; fewer than three segments raises IndexError.

    Returns: (name, year)
    """
    segments = tenant_id.split("@")
    name = segments[0]
    year = segments[2]
    return name, year
|
|
|
|
|
|
class ApiKeyMiddleware(BaseHTTPMiddleware):
    """
    [AC-AISVC-50] Middleware to validate API Key for all requests.

    Features:
    - Validates the X-API-Key header via the API key service (``validate_key``)
    - Skips validation for health/docs endpoints listed in PATHS_SKIP_API_KEY
    - Returns 401 for missing or invalid API key
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        path = request.url.path
        if self._should_skip_api_key(path):
            return await call_next(request)

        # A missing header and a whitespace-only header are treated the same.
        api_key = (request.headers.get(API_KEY_HEADER) or "").strip()
        if not api_key:
            logger.warning("[AC-AISVC-50] Missing X-API-Key header for %s", path)
            return self._unauthorized("Missing required header: X-API-Key")

        # Imported lazily (as before); presumably to avoid an import cycle at
        # module load time — NOTE(review): confirm.
        from app.services.api_key import get_api_key_service

        service = get_api_key_service()
        if not service.validate_key(api_key):
            logger.warning("[AC-AISVC-50] Invalid API key for %s", path)
            return self._unauthorized("Invalid API key")

        return await call_next(request)

    @staticmethod
    def _unauthorized(message: str) -> JSONResponse:
        """Build the 401 JSON error response shared by both rejection paths."""
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content=ErrorResponse(
                code=ErrorCode.UNAUTHORIZED.value,
                message=message,
            ).model_dump(exclude_none=True),
        )

    def _should_skip_api_key(self, path: str) -> bool:
        """
        Check if the path should skip API key validation.

        A path is exempt when it equals an entry of PATHS_SKIP_API_KEY or is
        nested beneath one (``<entry>/...``). The "/" boundary matters: a bare
        ``startswith`` would also exempt unrelated routes that merely share a
        prefix (e.g. "/healthcheck" or "/ai/health-admin"), letting them bypass
        authentication.
        """
        return any(
            path == skip_path or path.startswith(skip_path + "/")
            for skip_path in PATHS_SKIP_API_KEY
        )
|
|
|
|
|
|
class TenantContextMiddleware(BaseHTTPMiddleware):
    """
    [AC-AISVC-10, AC-AISVC-12] Middleware to extract and validate X-Tenant-Id header.

    Injects tenant context (``set_tenant_context`` and ``request.state.tenant_id``)
    for downstream processing. Validates the tenant ID format and, for
    /admin/ and /ai/ paths, auto-creates the tenant row if it does not exist.
    """

    # Paths served without any tenant header.
    # NOTE(review): ApiKeyMiddleware also exempts /docs, /redoc and
    # /openapi.json, but this middleware still requires X-Tenant-Id there —
    # confirm that is intended.
    _TENANT_EXEMPT_PATHS = ("/health", "/ai/health")
    # Path prefixes for which the tenant is auto-provisioned in the database.
    _AUTO_CREATE_PREFIXES = ("/admin/", "/ai/")

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Start from a clean context so nothing carries over from earlier work.
        clear_tenant_context()

        path = request.url.path
        if path in self._TENANT_EXEMPT_PATHS:
            return await call_next(request)

        # A missing header and a whitespace-only header are treated the same.
        tenant_id = (request.headers.get(TENANT_ID_HEADER) or "").strip()
        if not tenant_id:
            logger.warning("[AC-AISVC-12] Missing or empty X-Tenant-Id header")
            return self._bad_request(
                ErrorCode.MISSING_TENANT_ID.value,
                "Missing required header: X-Tenant-Id",
            )

        if not validate_tenant_id_format(tenant_id):
            logger.warning("[AC-AISVC-10] Invalid tenant ID format: %s", tenant_id)
            return self._bad_request(
                ErrorCode.INVALID_TENANT_ID.value,
                "Invalid tenant ID format. Expected: name@ash@year (e.g., szmp@ash@2026)",
            )

        if path.startswith(self._AUTO_CREATE_PREFIXES):
            try:
                await self._ensure_tenant_exists(request, tenant_id)
            except Exception as e:
                # Deliberately best-effort: provisioning failures must not block
                # the request, but keep the traceback for diagnosis.
                logger.exception("[AC-AISVC-10] Failed to ensure tenant exists: %s", e)

        set_tenant_context(tenant_id)
        request.state.tenant_id = tenant_id

        logger.info("[AC-AISVC-10] Tenant context set: tenant_id=%s", tenant_id)

        try:
            return await call_next(request)
        finally:
            clear_tenant_context()

    @staticmethod
    def _bad_request(code: str, message: str) -> JSONResponse:
        """Build a 400 JSON error response with the given error code and message."""
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=ErrorResponse(
                code=code,
                message=message,
            ).model_dump(exclude_none=True),
        )

    async def _ensure_tenant_exists(self, request: Request, tenant_id: str) -> None:
        """
        [AC-AISVC-10] Ensure tenant exists in database, create if not.

        Args:
            request: The incoming request (not used by this implementation).
            tenant_id: A tenant ID that has passed ``validate_tenant_id_format``.

        Raises:
            Any database error other than a lost insert race (see below).
        """
        from sqlalchemy import select
        from sqlalchemy.exc import IntegrityError

        from app.core.database import async_session_maker
        from app.models.entities import Tenant

        name, year = parse_tenant_id(tenant_id)

        async with async_session_maker() as session:
            stmt = select(Tenant).where(Tenant.tenant_id == tenant_id)
            existing_tenant = (await session.execute(stmt)).scalar_one_or_none()

            if existing_tenant is not None:
                logger.debug("[AC-AISVC-10] Tenant already exists: %s", tenant_id)
                return

            session.add(Tenant(tenant_id=tenant_id, name=name, year=year))
            try:
                await session.commit()
            except IntegrityError:
                # Two concurrent first requests can both see "missing" and both
                # insert; if tenant_id is uniquely constrained (assumed — confirm
                # against the model) the loser's commit fails here. Roll back
                # and re-check: if the row is now present, another request
                # created it and there is nothing left to do.
                await session.rollback()
                if (await session.execute(stmt)).scalar_one_or_none() is not None:
                    logger.debug(
                        "[AC-AISVC-10] Tenant created concurrently: %s", tenant_id
                    )
                    return
                raise

        logger.info(
            "[AC-AISVC-10] Auto-created new tenant: %s (name=%s, year=%s)",
            tenant_id,
            name,
            year,
        )
|
|
|
|
|
|
def is_sse_request(request: Request) -> bool:
    """
    [AC-AISVC-06] Check if the request expects SSE streaming response.

    Based on Accept header: text/event-stream indicates SSE mode. Media type
    names are case-insensitive (RFC 9110 §8.3.1), so the header value is
    lower-cased before the substring check. A missing Accept header is
    treated as non-SSE.

    Returns:
        True if the Accept header mentions text/event-stream, else False.
    """
    accept_header = request.headers.get(ACCEPT_HEADER, "")
    return SSE_CONTENT_TYPE in accept_header.lower()
|
|
|
|
|
|
def get_response_mode(request: Request) -> str:
    """
    [AC-AISVC-06] Pick the response mode for a request.

    Delegates to ``is_sse_request``: a request it classifies as SSE yields
    ``"streaming"``; every other request yields ``"json"``.
    """
    if is_sse_request(request):
        return "streaming"
    return "json"
|