""" Middleware for AI Service. [AC-AISVC-10, AC-AISVC-12] X-Tenant-Id header validation and tenant context injection. """ import logging import re from typing import Callable from fastapi import Request, Response, status from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware from app.core.exceptions import ErrorCode, ErrorResponse, MissingTenantIdException from app.core.tenant import clear_tenant_context, set_tenant_context logger = logging.getLogger(__name__) TENANT_ID_HEADER = "X-Tenant-Id" ACCEPT_HEADER = "Accept" SSE_CONTENT_TYPE = "text/event-stream" # Tenant ID format: name@ash@year (e.g., szmp@ash@2026) TENANT_ID_PATTERN = re.compile(r'^[^@]+@ash@\d{4}$') def validate_tenant_id_format(tenant_id: str) -> bool: """ [AC-AISVC-10] Validate tenant ID format: name@ash@year Examples: szmp@ash@2026, abc123@ash@2025 """ return bool(TENANT_ID_PATTERN.match(tenant_id)) def parse_tenant_id(tenant_id: str) -> tuple[str, str]: """ [AC-AISVC-10] Parse tenant ID into name and year. Returns: (name, year) """ parts = tenant_id.split('@') return parts[0], parts[2] class TenantContextMiddleware(BaseHTTPMiddleware): """ [AC-AISVC-10, AC-AISVC-12] Middleware to extract and validate X-Tenant-Id header. Injects tenant context into request state for downstream processing. Validates tenant ID format and auto-creates tenant if not exists. """ async def dispatch(self, request: Request, call_next: Callable) -> Response: clear_tenant_context() if request.url.path == "/ai/health": return await call_next(request) tenant_id = request.headers.get(TENANT_ID_HEADER) if not tenant_id or not tenant_id.strip(): logger.warning("[AC-AISVC-12] Missing or empty X-Tenant-Id header") return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, content=ErrorResponse( code=ErrorCode.MISSING_TENANT_ID.value, message="Missing required header: X-Tenant-Id", ).model_dump(exclude_none=True), ) tenant_id = tenant_id.strip() # Validate tenant ID format if not validate_tenant_id_format(tenant_id): logger.warning(f"[AC-AISVC-10] Invalid tenant ID format: {tenant_id}") return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, content=ErrorResponse( code=ErrorCode.INVALID_TENANT_ID.value, message="Invalid tenant ID format. Expected: name@ash@year (e.g., szmp@ash@2026)", ).model_dump(exclude_none=True), ) # Auto-create tenant if not exists (for admin endpoints) if request.url.path.startswith("/admin/") or request.url.path.startswith("/ai/"): try: await self._ensure_tenant_exists(request, tenant_id) except Exception as e: logger.error(f"[AC-AISVC-10] Failed to ensure tenant exists: {e}") # Continue processing even if tenant creation fails set_tenant_context(tenant_id) request.state.tenant_id = tenant_id logger.info(f"[AC-AISVC-10] Tenant context set: tenant_id={tenant_id}") try: response = await call_next(request) finally: clear_tenant_context() return response async def _ensure_tenant_exists(self, request: Request, tenant_id: str) -> None: """ [AC-AISVC-10] Ensure tenant exists in database, create if not. """ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.database import async_session_maker from app.models.entities import Tenant name, year = parse_tenant_id(tenant_id) async with async_session_maker() as session: # Check if tenant exists stmt = select(Tenant).where(Tenant.tenant_id == tenant_id) result = await session.execute(stmt) existing_tenant = result.scalar_one_or_none() if existing_tenant: logger.debug(f"[AC-AISVC-10] Tenant already exists: {tenant_id}") return # Create new tenant new_tenant = Tenant( tenant_id=tenant_id, name=name, year=year, ) session.add(new_tenant) await session.commit() logger.info(f"[AC-AISVC-10] Auto-created new tenant: {tenant_id} (name={name}, year={year})") def is_sse_request(request: Request) -> bool: """ [AC-AISVC-06] Check if the request expects SSE streaming response. Based on Accept header: text/event-stream indicates SSE mode. """ accept_header = request.headers.get(ACCEPT_HEADER, "") return SSE_CONTENT_TYPE in accept_header def get_response_mode(request: Request) -> str: """ [AC-AISVC-06] Determine response mode based on Accept header. Returns 'streaming' for SSE, 'json' for regular JSON response. """ return "streaming" if is_sse_request(request) else "json"