ai-robot-core/ai-service/app/services/prompt/template_service.py

416 lines
13 KiB
Python
Raw Normal View History

"""
Prompt template service for AI Service.
[AC-AISVC-51~AC-AISVC-58] Template CRUD, version management, publish/rollback, and caching.
"""
import logging
import time
import uuid
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col
from app.core.prompts import SYSTEM_PROMPT
from app.models.entities import (
PromptTemplate,
PromptTemplateCreate,
PromptTemplateUpdate,
PromptTemplateVersion,
TemplateVersionStatus,
)
from app.services.prompt.variable_resolver import VariableResolver
logger = logging.getLogger(__name__)

# Default lifetime, in seconds, of an entry in TemplateCache.
CACHE_TTL_SECONDS = 300


class TemplateCache:
    """
    [AC-AISVC-51] In-memory cache for published templates.

    Key: (tenant_id, scene)
    Value: (template_version, cached_at)
    TTL: ``ttl_seconds`` (CACHE_TTL_SECONDS = 300 seconds by default)

    ``cached_at`` is taken from ``time.monotonic()`` rather than ``time.time()``
    so that wall-clock adjustments (NTP sync, manual changes) cannot stretch or
    shorten an entry's lifetime.

    NOTE(review): the cached values are ORM instances loaded by whichever
    session first populated the entry, and they are then shared across later
    requests. If the session factory expires attributes on commit, reading them
    afterwards may fail or trigger a lazy load — confirm that sessions are
    created with ``expire_on_commit=False``.
    """

    def __init__(self, ttl_seconds: int = CACHE_TTL_SECONDS):
        self._cache: dict[tuple[str, str], tuple[PromptTemplateVersion, float]] = {}
        self._ttl = ttl_seconds

    def get(self, tenant_id: str, scene: str) -> PromptTemplateVersion | None:
        """Return the cached version for (tenant_id, scene), or None if absent or expired.

        An expired entry is evicted as a side effect of the lookup.
        """
        key = (tenant_id, scene)
        entry = self._cache.get(key)
        if entry is None:
            return None
        version, cached_at = entry
        if time.monotonic() - cached_at < self._ttl:
            return version
        del self._cache[key]
        return None

    def set(self, tenant_id: str, scene: str, version: PromptTemplateVersion) -> None:
        """Cache ``version`` for (tenant_id, scene), restarting its TTL."""
        self._cache[(tenant_id, scene)] = (version, time.monotonic())

    def invalidate(self, tenant_id: str, scene: str | None = None) -> None:
        """Drop cached entries for a tenant.

        Args:
            tenant_id: Tenant whose entries are dropped.
            scene: When given, only that scene's entry is dropped. When None,
                every scene cached for the tenant is dropped. An empty-string
                scene is a real key (``set``/``get`` accept it), so it no longer
                means "all scenes".
        """
        if scene is not None:
            self._cache.pop((tenant_id, scene), None)
            return
        # Collect the keys first: the dict cannot be resized while iterating it.
        for key in [k for k in self._cache if k[0] == tenant_id]:
            del self._cache[key]


# Process-wide cache instance; PromptTemplateService instances all use this one.
_template_cache = TemplateCache()
class PromptTemplateService:
    """
    [AC-AISVC-52~AC-AISVC-58] Service for managing prompt templates.

    Features:
    - Template CRUD with tenant isolation
    - Version management (auto-create new version on update)
    - Publish/rollback functionality
    - In-memory caching with TTL
    - Fallback to hardcoded SYSTEM_PROMPT

    Methods only ``flush()`` the session; committing or rolling back is left to
    the caller that owns ``session``.
    """

    def __init__(self, session: AsyncSession):
        self._session = session
        # Module-level cache, shared with every other instance of this service.
        self._cache = _template_cache

    async def create_template(
        self,
        tenant_id: str,
        create_data: PromptTemplateCreate,
    ) -> PromptTemplate:
        """
        [AC-AISVC-52] Create a new prompt template with initial version.

        The initial version is version 1 in DRAFT status; it has to be
        published before get_published_template() will serve it.
        """
        template = PromptTemplate(
            tenant_id=tenant_id,
            name=create_data.name,
            scene=create_data.scene,
            description=create_data.description,
            is_default=create_data.is_default,
        )
        self._session.add(template)
        # Flush so the template row exists (and template.id is populated)
        # before the version row references it.
        await self._session.flush()

        initial_version = PromptTemplateVersion(
            template_id=template.id,
            version=1,
            status=TemplateVersionStatus.DRAFT.value,
            system_instruction=create_data.system_instruction,
            variables=create_data.variables,
        )
        self._session.add(initial_version)
        await self._session.flush()

        logger.info(
            "[AC-AISVC-52] Created prompt template: tenant=%s, id=%s, name=%s",
            tenant_id,
            template.id,
            template.name,
        )
        return template

    async def list_templates(
        self,
        tenant_id: str,
        scene: str | None = None,
    ) -> Sequence[PromptTemplate]:
        """
        [AC-AISVC-57] List templates for a tenant, optionally filtered by scene.

        Results are ordered newest first by ``created_at``.
        """
        stmt = select(PromptTemplate).where(PromptTemplate.tenant_id == tenant_id)
        if scene:
            stmt = stmt.where(PromptTemplate.scene == scene)
        stmt = stmt.order_by(col(PromptTemplate.created_at).desc())
        result = await self._session.execute(stmt)
        return result.scalars().all()

    async def get_template(
        self,
        tenant_id: str,
        template_id: uuid.UUID,
    ) -> PromptTemplate | None:
        """
        [AC-AISVC-58] Get template by ID with tenant isolation.

        Returns None when the id does not exist or belongs to another tenant.
        """
        stmt = select(PromptTemplate).where(
            PromptTemplate.tenant_id == tenant_id,
            PromptTemplate.id == template_id,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_template_detail(
        self,
        tenant_id: str,
        template_id: uuid.UUID,
    ) -> dict[str, Any] | None:
        """
        [AC-AISVC-58] Get template detail with all versions.

        Returns None when the template does not exist for this tenant.
        ``current_version`` describes the PUBLISHED version, or is None if no
        version is published. ``versions`` lists every version, newest first.
        """
        template = await self.get_template(tenant_id, template_id)
        if template is None:
            return None

        versions = await self._get_versions(template_id)
        current_version = next(
            (v for v in versions if v.status == TemplateVersionStatus.PUBLISHED.value),
            None,
        )

        return {
            "id": str(template.id),
            "name": template.name,
            "scene": template.scene,
            "description": template.description,
            "is_default": template.is_default,
            "current_version": {
                "version": current_version.version,
                "status": current_version.status,
                "system_instruction": current_version.system_instruction,
                "variables": current_version.variables or [],
            } if current_version else None,
            "versions": [
                {
                    "version": v.version,
                    "status": v.status,
                    "created_at": v.created_at.isoformat(),
                }
                for v in versions
            ],
            "created_at": template.created_at.isoformat(),
            "updated_at": template.updated_at.isoformat(),
        }

    async def update_template(
        self,
        tenant_id: str,
        template_id: uuid.UUID,
        update_data: PromptTemplateUpdate,
    ) -> PromptTemplate | None:
        """
        [AC-AISVC-53] Update template and create a new version.

        Metadata fields are overwritten only when the corresponding field of
        ``update_data`` is not None. A new ``system_instruction`` is appended
        as a new DRAFT version; the currently published version is untouched.
        Returns None when the template does not exist for this tenant.
        """
        template = await self.get_template(tenant_id, template_id)
        if template is None:
            return None

        # Remember the scene this template may currently be cached under. If
        # the update moves it to another scene, the old key must be evicted as
        # well; otherwise the old scene keeps serving this template's
        # published version until the TTL expires.
        previous_scene = template.scene

        if update_data.name is not None:
            template.name = update_data.name
        if update_data.scene is not None:
            template.scene = update_data.scene
        if update_data.description is not None:
            template.description = update_data.description
        if update_data.is_default is not None:
            template.is_default = update_data.is_default
        template.updated_at = datetime.utcnow()

        if update_data.system_instruction is not None:
            latest_version = await self._get_latest_version(template_id)
            new_version_num = (latest_version.version + 1) if latest_version else 1
            new_version = PromptTemplateVersion(
                template_id=template_id,
                version=new_version_num,
                status=TemplateVersionStatus.DRAFT.value,
                system_instruction=update_data.system_instruction,
                variables=update_data.variables,
            )
            self._session.add(new_version)

        await self._session.flush()

        self._cache.invalidate(tenant_id, previous_scene)
        if template.scene != previous_scene:
            self._cache.invalidate(tenant_id, template.scene)

        logger.info(
            "[AC-AISVC-53] Updated prompt template: tenant=%s, id=%s",
            tenant_id,
            template_id,
        )
        return template

    async def publish_version(
        self,
        tenant_id: str,
        template_id: uuid.UUID,
        version: int,
    ) -> bool:
        """
        [AC-AISVC-54] Publish a specific version of the template.
        Old published version will be archived.

        Returns False, without modifying any row, when the template does not
        exist for this tenant or has no such version.

        The scene's cache entry is only invalidated, not re-populated. The next
        get_published_template() call reloads it from the database, so a
        publish that the caller later rolls back is not left in the cache, and
        the cache follows the same resolution rules as a normal lookup.
        """
        template = await self.get_template(tenant_id, template_id)
        if template is None:
            return False

        versions = await self._get_versions(template_id)
        target_version = self._find_version(versions, version)
        if target_version is None:
            # Bail out before touching anything: archiving first would leave the
            # template with no published version once the caller commits.
            return False

        for v in versions:
            if v is not target_version and v.status == TemplateVersionStatus.PUBLISHED.value:
                v.status = TemplateVersionStatus.ARCHIVED.value
        target_version.status = TemplateVersionStatus.PUBLISHED.value

        await self._session.flush()

        self._cache.invalidate(tenant_id, template.scene)

        logger.info(
            "[AC-AISVC-54] Published template version: tenant=%s, template_id=%s, version=%s",
            tenant_id,
            template_id,
            version,
        )
        return True

    async def rollback_version(
        self,
        tenant_id: str,
        template_id: uuid.UUID,
        version: int,
    ) -> bool:
        """
        [AC-AISVC-55] Rollback to a specific historical version.

        Implemented as re-publishing ``version``; see publish_version().
        """
        return await self.publish_version(tenant_id, template_id, version)

    async def get_published_template(
        self,
        tenant_id: str,
        scene: str,
        resolver: VariableResolver | None = None,
    ) -> str:
        """
        [AC-AISVC-51, AC-AISVC-56] Get the published template for a scene.

        Resolution order:
        1. Check in-memory cache
        2. Query database for published version
        3. Fallback to hardcoded SYSTEM_PROMPT

        If several templates of the tenant share ``scene`` and each has a
        published version, the default template (``is_default``) wins, then
        the most recently updated one. When ``resolver`` is given, the
        instruction is passed through ``resolver.resolve``. The SYSTEM_PROMPT
        fallback is returned as-is.
        """
        cached = self._cache.get(tenant_id, scene)
        if cached is not None:
            logger.debug(
                "[AC-AISVC-51] Cache hit for template: tenant=%s, scene=%s",
                tenant_id,
                scene,
            )
            return self._render(cached, resolver)

        stmt = (
            select(PromptTemplateVersion)
            .join(PromptTemplate, PromptTemplateVersion.template_id == PromptTemplate.id)
            .where(
                PromptTemplate.tenant_id == tenant_id,
                PromptTemplate.scene == scene,
                PromptTemplateVersion.status == TemplateVersionStatus.PUBLISHED.value,
            )
            # Nothing prevents several templates from sharing a scene, so pick
            # one deterministically instead of letting scalar_one_or_none()
            # raise MultipleResultsFound.
            .order_by(
                col(PromptTemplate.is_default).desc(),
                col(PromptTemplate.updated_at).desc(),
            )
            .limit(1)
        )
        result = await self._session.execute(stmt)
        published_version = result.scalars().first()

        if published_version is None:
            logger.info(
                "[AC-AISVC-51] No published template found, using fallback: tenant=%s, scene=%s",
                tenant_id,
                scene,
            )
            return SYSTEM_PROMPT

        self._cache.set(tenant_id, scene, published_version)
        logger.info(
            "[AC-AISVC-51] Loaded published template from DB: tenant=%s, scene=%s",
            tenant_id,
            scene,
        )
        return self._render(published_version, resolver)

    async def get_published_version_info(
        self,
        tenant_id: str,
        template_id: uuid.UUID,
    ) -> int | None:
        """
        Get the published version number for a template.

        Scoped to ``tenant_id`` like every other lookup in this service, so a
        template id belonging to another tenant yields None. Also returns None
        when no version is published.
        """
        stmt = (
            select(PromptTemplateVersion.version)
            .join(PromptTemplate, PromptTemplateVersion.template_id == PromptTemplate.id)
            .where(
                PromptTemplate.tenant_id == tenant_id,
                PromptTemplateVersion.template_id == template_id,
                PromptTemplateVersion.status == TemplateVersionStatus.PUBLISHED.value,
            )
            .order_by(col(PromptTemplateVersion.version).desc())
            .limit(1)
        )
        result = await self._session.execute(stmt)
        return result.scalars().first()

    async def _get_versions(
        self,
        template_id: uuid.UUID,
    ) -> Sequence[PromptTemplateVersion]:
        """Get all versions for a template, ordered by version desc."""
        stmt = (
            select(PromptTemplateVersion)
            .where(PromptTemplateVersion.template_id == template_id)
            .order_by(col(PromptTemplateVersion.version).desc())
        )
        result = await self._session.execute(stmt)
        return result.scalars().all()

    async def _get_latest_version(
        self,
        template_id: uuid.UUID,
    ) -> PromptTemplateVersion | None:
        """Get the latest (highest-numbered) version for a template, or None."""
        stmt = (
            select(PromptTemplateVersion)
            .where(PromptTemplateVersion.template_id == template_id)
            .order_by(col(PromptTemplateVersion.version).desc())
            .limit(1)
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def delete_template(
        self,
        tenant_id: str,
        template_id: uuid.UUID,
    ) -> bool:
        """
        Delete a template and all its versions.

        Returns False when the template does not exist for this tenant.
        """
        template = await self.get_template(tenant_id, template_id)
        if template is None:
            return False

        # Capture the cache key before the row is deleted.
        scene = template.scene

        versions = await self._get_versions(template_id)
        for v in versions:
            await self._session.delete(v)
        await self._session.delete(template)
        await self._session.flush()

        self._cache.invalidate(tenant_id, scene)

        logger.info(
            "Deleted prompt template: tenant=%s, id=%s",
            tenant_id,
            template_id,
        )
        return True

    @staticmethod
    def _find_version(
        versions: Sequence[PromptTemplateVersion],
        version: int,
    ) -> PromptTemplateVersion | None:
        """Return the item of ``versions`` whose version number is ``version``, or None."""
        return next((v for v in versions if v.version == version), None)

    @staticmethod
    def _render(
        version: PromptTemplateVersion,
        resolver: VariableResolver | None,
    ) -> str:
        """Return the version's instruction, passed through ``resolver`` when one is given."""
        if resolver is not None:
            return resolver.resolve(version.system_instruction, version.variables)
        return version.system_instruction