""" Middleware for AI Service. [AC-AISVC-10, AC-AISVC-12] X-Tenant-Id header validation and tenant context injection. """ import logging from typing import Callable from fastapi import Request, Response, status from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware from app.core.exceptions import ErrorCode, ErrorResponse, MissingTenantIdException from app.core.tenant import clear_tenant_context, set_tenant_context logger = logging.getLogger(__name__) TENANT_ID_HEADER = "X-Tenant-Id" ACCEPT_HEADER = "Accept" SSE_CONTENT_TYPE = "text/event-stream" class TenantContextMiddleware(BaseHTTPMiddleware): """ [AC-AISVC-10, AC-AISVC-12] Middleware to extract and validate X-Tenant-Id header. Injects tenant context into request state for downstream processing. """ async def dispatch(self, request: Request, call_next: Callable) -> Response: clear_tenant_context() if request.url.path == "/ai/health": return await call_next(request) tenant_id = request.headers.get(TENANT_ID_HEADER) if not tenant_id or not tenant_id.strip(): logger.warning("[AC-AISVC-12] Missing or empty X-Tenant-Id header") return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, content=ErrorResponse( code=ErrorCode.MISSING_TENANT_ID.value, message="Missing required header: X-Tenant-Id", ).model_dump(exclude_none=True), ) set_tenant_context(tenant_id.strip()) request.state.tenant_id = tenant_id.strip() logger.info(f"[AC-AISVC-10] Tenant context set: tenant_id={tenant_id.strip()}") try: response = await call_next(request) finally: clear_tenant_context() return response def is_sse_request(request: Request) -> bool: """ [AC-AISVC-06] Check if the request expects SSE streaming response. Based on Accept header: text/event-stream indicates SSE mode. """ accept_header = request.headers.get(ACCEPT_HEADER, "") return SSE_CONTENT_TYPE in accept_header def get_response_mode(request: Request) -> str: """ [AC-AISVC-06] Determine response mode based on Accept header. Returns 'streaming' for SSE, 'json' for regular JSON response. """ return "streaming" if is_sse_request(request) else "json"