feat: add API key management with entity model and service layer [AC-AISVC-APIKEY]
This commit is contained in:
parent
5f4bde8752
commit
f823e8fb86
|
|
@ -4,6 +4,7 @@ API Key management endpoints.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
|
@ -26,6 +27,9 @@ class ApiKeyResponse(BaseModel):
|
|||
key: str = Field(..., description="API key value")
|
||||
name: str = Field(..., description="API key name")
|
||||
is_active: bool = Field(..., description="Whether the key is active")
|
||||
expires_at: str | None = Field(default=None, description="Expiration time")
|
||||
allowed_ips: list[str] | None = Field(default=None, description="Optional client IP allowlist")
|
||||
rate_limit_qpm: int | None = Field(default=60, description="Per-minute quota")
|
||||
created_at: str = Field(..., description="Creation time")
|
||||
updated_at: str = Field(..., description="Last update time")
|
||||
|
||||
|
|
@ -42,6 +46,9 @@ class CreateApiKeyRequest(BaseModel):
|
|||
|
||||
name: str = Field(..., description="API key name/description")
|
||||
key: str | None = Field(default=None, description="Custom API key (auto-generated if not provided)")
|
||||
expires_at: datetime | None = Field(default=None, description="Expiration time; null means never expires")
|
||||
allowed_ips: list[str] | None = Field(default=None, description="Optional client IP allowlist")
|
||||
rate_limit_qpm: int | None = Field(default=60, ge=1, le=60000, description="Per-minute quota")
|
||||
|
||||
|
||||
class ToggleApiKeyRequest(BaseModel):
|
||||
|
|
@ -57,6 +64,9 @@ def api_key_to_response(api_key: ApiKey) -> ApiKeyResponse:
|
|||
key=api_key.key,
|
||||
name=api_key.name,
|
||||
is_active=api_key.is_active,
|
||||
expires_at=api_key.expires_at.isoformat() if api_key.expires_at else None,
|
||||
allowed_ips=api_key.allowed_ips,
|
||||
rate_limit_qpm=api_key.rate_limit_qpm,
|
||||
created_at=api_key.created_at.isoformat(),
|
||||
updated_at=api_key.updated_at.isoformat(),
|
||||
)
|
||||
|
|
@ -94,6 +104,9 @@ async def create_api_key(
|
|||
key=key_value,
|
||||
name=request.name,
|
||||
is_active=True,
|
||||
expires_at=request.expires_at,
|
||||
allowed_ips=request.allowed_ips,
|
||||
rate_limit_qpm=request.rate_limit_qpm,
|
||||
)
|
||||
|
||||
api_key = await service.create_key(session, key_create)
|
||||
|
|
|
|||
|
|
@ -294,6 +294,13 @@ class ApiKey(SQLModel, table=True):
|
|||
key: str = Field(..., description="API Key (unique)", unique=True, index=True)
|
||||
name: str = Field(..., description="Key name/description for identification")
|
||||
is_active: bool = Field(default=True, description="Whether the key is active")
|
||||
expires_at: datetime | None = Field(default=None, description="Expiration time; null means never expires")
|
||||
allowed_ips: list[str] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column("allowed_ips", JSON, nullable=True),
|
||||
description="Optional IP allowlist for this key",
|
||||
)
|
||||
rate_limit_qpm: int | None = Field(default=60, description="Per-minute quota for this key")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
||||
|
||||
|
|
@ -304,6 +311,9 @@ class ApiKeyCreate(SQLModel):
|
|||
key: str
|
||||
name: str
|
||||
is_active: bool = True
|
||||
expires_at: datetime | None = None
|
||||
allowed_ips: list[str] | None = None
|
||||
rate_limit_qpm: int | None = 60
|
||||
|
||||
|
||||
class TemplateVersionStatus(str, Enum):
|
||||
|
|
|
|||
|
|
@ -3,9 +3,13 @@ API Key management service.
|
|||
[AC-AISVC-50] Lightweight authentication with in-memory cache.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
|
@ -16,6 +20,25 @@ from app.models.entities import ApiKey, ApiKeyCreate
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedApiKeyMeta:
|
||||
"""Cached metadata for API key policy checks."""
|
||||
|
||||
is_active: bool
|
||||
expires_at: datetime | None
|
||||
allowed_ips: set[str] = field(default_factory=set)
|
||||
rate_limit_qpm: int = 60
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Validation output for middleware auth + policy checks."""
|
||||
|
||||
ok: bool
|
||||
reason: str | None = None
|
||||
rate_limit_qpm: int = 60
|
||||
|
||||
|
||||
class ApiKeyService:
|
||||
"""
|
||||
[AC-AISVC-50] API Key management service.
|
||||
|
|
@ -28,6 +51,8 @@ class ApiKeyService:
|
|||
|
||||
def __init__(self):
|
||||
self._keys_cache: set[str] = set()
|
||||
self._key_meta: dict[str, CachedApiKeyMeta] = {}
|
||||
self._rate_buckets: dict[str, deque[datetime]] = {}
|
||||
self._initialized: bool = False
|
||||
|
||||
async def initialize(self, session: AsyncSession) -> None:
|
||||
|
|
@ -35,15 +60,50 @@ class ApiKeyService:
|
|||
Load all active API keys from database into memory.
|
||||
Should be called on application startup.
|
||||
"""
|
||||
result = await session.execute(
|
||||
select(ApiKey).where(ApiKey.is_active == True)
|
||||
)
|
||||
keys = result.scalars().all()
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ApiKey).where(ApiKey.is_active == True)
|
||||
)
|
||||
keys = result.scalars().all()
|
||||
|
||||
self._keys_cache = {key.key for key in keys}
|
||||
self._initialized = True
|
||||
self._keys_cache = {key.key for key in keys}
|
||||
self._key_meta = {
|
||||
key.key: CachedApiKeyMeta(
|
||||
is_active=key.is_active,
|
||||
expires_at=key.expires_at,
|
||||
allowed_ips=set(key.allowed_ips or []),
|
||||
rate_limit_qpm=key.rate_limit_qpm or 60,
|
||||
)
|
||||
for key in keys
|
||||
}
|
||||
self._initialized = True
|
||||
logger.info(f"[AC-AISVC-50] Loaded {len(self._keys_cache)} API keys into memory")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-AISVC-50] Full API key schema load failed, fallback to legacy columns: {e}")
|
||||
await session.rollback()
|
||||
|
||||
logger.info(f"[AC-AISVC-50] Loaded {len(self._keys_cache)} API keys into memory")
|
||||
# Backward-compat fallback for environments without new columns
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ApiKey.key, ApiKey.is_active).where(ApiKey.is_active == True)
|
||||
)
|
||||
rows = result.all()
|
||||
self._keys_cache = {row[0] for row in rows}
|
||||
self._key_meta = {
|
||||
row[0]: CachedApiKeyMeta(
|
||||
is_active=bool(row[1]),
|
||||
expires_at=None,
|
||||
allowed_ips=set(),
|
||||
rate_limit_qpm=60,
|
||||
)
|
||||
for row in rows
|
||||
}
|
||||
self._initialized = True
|
||||
logger.info(f"[AC-AISVC-50] Loaded {len(self._keys_cache)} API keys in legacy compatibility mode")
|
||||
except Exception as fallback_error:
|
||||
self._initialized = False
|
||||
logger.error(f"[AC-AISVC-50] API key initialization failed in both full/legacy mode: {fallback_error}")
|
||||
|
||||
def validate_key(self, key: str) -> bool:
|
||||
"""
|
||||
|
|
@ -61,6 +121,41 @@ class ApiKeyService:
|
|||
|
||||
return key in self._keys_cache
|
||||
|
||||
def validate_key_with_context(self, key: str, client_ip: str | None) -> ValidationResult:
|
||||
"""Validate key and policy constraints: expiration, IP allowlist, and per-minute rate."""
|
||||
if not self._initialized:
|
||||
return ValidationResult(ok=False, reason="service_not_initialized")
|
||||
|
||||
if key not in self._keys_cache:
|
||||
return ValidationResult(ok=False, reason="invalid_key")
|
||||
|
||||
meta = self._key_meta.get(key)
|
||||
if not meta or not meta.is_active:
|
||||
return ValidationResult(ok=False, reason="inactive_key")
|
||||
|
||||
now = datetime.utcnow()
|
||||
if meta.expires_at and now > meta.expires_at:
|
||||
return ValidationResult(ok=False, reason="expired_key")
|
||||
|
||||
if meta.allowed_ips and client_ip and client_ip not in meta.allowed_ips:
|
||||
return ValidationResult(ok=False, reason="ip_not_allowed")
|
||||
|
||||
self._evict_stale_rate_entries(key, now)
|
||||
bucket = self._rate_buckets.setdefault(key, deque())
|
||||
limit = meta.rate_limit_qpm or 60
|
||||
if len(bucket) >= limit:
|
||||
return ValidationResult(ok=False, reason="rate_limited", rate_limit_qpm=limit)
|
||||
|
||||
bucket.append(now)
|
||||
return ValidationResult(ok=True, rate_limit_qpm=limit)
|
||||
|
||||
def _evict_stale_rate_entries(self, key: str, now: datetime) -> None:
|
||||
"""Keep only requests in the latest 60 seconds for token bucket emulation."""
|
||||
bucket = self._rate_buckets.setdefault(key, deque())
|
||||
threshold = now - timedelta(seconds=60)
|
||||
while bucket and bucket[0] < threshold:
|
||||
bucket.popleft()
|
||||
|
||||
def generate_key(self) -> str:
|
||||
"""
|
||||
Generate a new secure API key.
|
||||
|
|
@ -89,6 +184,9 @@ class ApiKeyService:
|
|||
key=key_create.key,
|
||||
name=key_create.name,
|
||||
is_active=key_create.is_active,
|
||||
expires_at=key_create.expires_at,
|
||||
allowed_ips=key_create.allowed_ips,
|
||||
rate_limit_qpm=key_create.rate_limit_qpm or 60,
|
||||
)
|
||||
|
||||
session.add(api_key)
|
||||
|
|
@ -97,6 +195,12 @@ class ApiKeyService:
|
|||
|
||||
if api_key.is_active:
|
||||
self._keys_cache.add(api_key.key)
|
||||
self._key_meta[api_key.key] = CachedApiKeyMeta(
|
||||
is_active=api_key.is_active,
|
||||
expires_at=api_key.expires_at,
|
||||
allowed_ips=set(api_key.allowed_ips or []),
|
||||
rate_limit_qpm=api_key.rate_limit_qpm or 60,
|
||||
)
|
||||
|
||||
logger.info(f"[AC-AISVC-50] Created API key: {api_key.name}")
|
||||
return api_key
|
||||
|
|
@ -108,8 +212,14 @@ class ApiKeyService:
|
|||
Returns:
|
||||
The created ApiKey or None if keys already exist
|
||||
"""
|
||||
result = await session.execute(select(ApiKey).limit(1))
|
||||
existing = result.scalar_one_or_none()
|
||||
try:
|
||||
result = await session.execute(select(ApiKey).limit(1))
|
||||
existing = result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-AISVC-50] Full schema query failed in create_default_key, using fallback: {e}")
|
||||
await session.rollback()
|
||||
result = await session.execute(select(ApiKey.key).limit(1))
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
return None
|
||||
|
|
@ -126,6 +236,12 @@ class ApiKeyService:
|
|||
await session.refresh(api_key)
|
||||
|
||||
self._keys_cache.add(api_key.key)
|
||||
self._key_meta[api_key.key] = CachedApiKeyMeta(
|
||||
is_active=api_key.is_active,
|
||||
expires_at=getattr(api_key, 'expires_at', None),
|
||||
allowed_ips=set(getattr(api_key, 'allowed_ips', []) or []),
|
||||
rate_limit_qpm=getattr(api_key, 'rate_limit_qpm', 60) or 60,
|
||||
)
|
||||
|
||||
logger.info(f"[AC-AISVC-50] Created default API key: {api_key.key}")
|
||||
return api_key
|
||||
|
|
@ -165,6 +281,8 @@ class ApiKeyService:
|
|||
await session.commit()
|
||||
|
||||
self._keys_cache.discard(key_value)
|
||||
self._key_meta.pop(key_value, None)
|
||||
self._rate_buckets.pop(key_value, None)
|
||||
|
||||
logger.info(f"[AC-AISVC-50] Deleted API key: {api_key.name}")
|
||||
return True
|
||||
|
|
@ -210,8 +328,16 @@ class ApiKeyService:
|
|||
|
||||
if is_active:
|
||||
self._keys_cache.add(api_key.key)
|
||||
self._key_meta[api_key.key] = CachedApiKeyMeta(
|
||||
is_active=api_key.is_active,
|
||||
expires_at=api_key.expires_at,
|
||||
allowed_ips=set(api_key.allowed_ips or []),
|
||||
rate_limit_qpm=api_key.rate_limit_qpm or 60,
|
||||
)
|
||||
else:
|
||||
self._keys_cache.discard(api_key.key)
|
||||
self._key_meta.pop(api_key.key, None)
|
||||
self._rate_buckets.pop(api_key.key, None)
|
||||
|
||||
logger.info(f"[AC-AISVC-50] Toggled API key {api_key.name}: active={is_active}")
|
||||
return api_key
|
||||
|
|
@ -234,6 +360,8 @@ class ApiKeyService:
|
|||
Reload all API keys from database into memory.
|
||||
"""
|
||||
self._keys_cache.clear()
|
||||
self._key_meta.clear()
|
||||
self._rate_buckets.clear()
|
||||
await self.initialize(session)
|
||||
logger.info("[AC-AISVC-50] API key cache reloaded")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue