ai-robot-core/ai-service/app/api/admin/sessions.py

295 lines
9.0 KiB
Python
Raw Normal View History

"""
Session monitoring and management endpoints.
[AC-ASA-07, AC-ASA-09] Session list and detail monitoring.
"""
import logging
from collections.abc import Sequence
from datetime import datetime, timezone
from typing import Annotated, Optional

from fastapi import APIRouter, Depends, Query
from fastapi.responses import JSONResponse
from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col

from app.core.database import get_session
from app.core.exceptions import MissingTenantIdException
from app.core.tenant import get_tenant_id
from app.models import ErrorResponse
from app.models.entities import ChatMessage, ChatSession, SessionStatus
# Module logger; log records are tagged with acceptance-criteria IDs such as [AC-ASA-09].
logger = logging.getLogger(name=__name__)
# Every route below is served under /admin/sessions and grouped under the
# "Session Monitoring" tag in the generated OpenAPI schema.
router = APIRouter(tags=["Session Monitoring"], prefix="/admin/sessions")
def get_current_tenant_id() -> str:
    """
    FastAPI dependency that resolves the tenant ID for the current request.

    Returns:
        The tenant ID reported by ``get_tenant_id()`` when it is truthy.

    Raises:
        MissingTenantIdException: if ``get_tenant_id()`` returns a falsy value
            (e.g. None or an empty string).
    """
    current_tenant = get_tenant_id()
    if current_tenant:
        return current_tenant
    raise MissingTenantIdException()
@router.get(
    "",
    operation_id="listSessions",
    summary="Query session list",
    description="[AC-ASA-09] Get list of sessions with pagination and filtering.",
    responses={
        200: {"description": "Session list with pagination"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def list_sessions(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    status: Annotated[Optional[str], Query()] = None,
    start_time: Annotated[Optional[str], Query(alias="startTime")] = None,
    end_time: Annotated[Optional[str], Query(alias="endTime")] = None,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
) -> JSONResponse:
    """
    [AC-ASA-09] List sessions with filtering and pagination.

    Args:
        tenant_id: Tenant whose sessions are listed.
        session: Async database session.
        status: Optional filter matched against ``metadata_["status"]``. Sessions
            with no stored status are reported as active, so the active filter
            matches them as well.
        start_time: Optional ISO-8601 lower bound (inclusive) on ``created_at``.
        end_time: Optional ISO-8601 upper bound (inclusive) on ``created_at``.
            Unparseable bounds are logged and ignored.
        page: 1-based page number.
        page_size: Number of sessions per page (1-100).

    Returns:
        JSONResponse with ``data``, a list of session summaries newest first, and
        ``pagination`` (page, pageSize, total, totalPages).
    """
    logger.info(
        "[AC-ASA-09] Listing sessions: tenant=%s, status=%s, "
        "start_time=%s, end_time=%s, page=%s, page_size=%s",
        tenant_id, status, start_time, end_time, page, page_size,
    )

    stmt = select(ChatSession).where(ChatSession.tenant_id == tenant_id)

    if status:
        status_expr = ChatSession.metadata_["status"].as_string()
        if status == SessionStatus.ACTIVE.value:
            # Sessions without a stored status are reported as active in the
            # response below, so they must also satisfy the "active" filter.
            stmt = stmt.where(or_(status_expr == status, status_expr.is_(None)))
        else:
            stmt = stmt.where(status_expr == status)

    start_dt = _parse_iso_datetime(start_time, "startTime")
    if start_dt is not None:
        stmt = stmt.where(ChatSession.created_at >= start_dt)
    end_dt = _parse_iso_datetime(end_time, "endTime")
    if end_dt is not None:
        stmt = stmt.where(ChatSession.created_at <= end_dt)

    # Count the filtered rows before pagination is applied.
    count_stmt = select(func.count()).select_from(stmt.subquery())
    total = (await session.execute(count_stmt)).scalar() or 0

    page_stmt = (
        stmt.order_by(col(ChatSession.created_at).desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    sessions = (await session.execute(page_stmt)).scalars().all()

    msg_counts: dict = {}
    session_ids = [s.session_id for s in sessions]
    if session_ids:
        msg_count_stmt = (
            select(
                ChatMessage.session_id,
                func.count(ChatMessage.id).label("message_count"),
            )
            .where(
                ChatMessage.tenant_id == tenant_id,
                ChatMessage.session_id.in_(session_ids),
            )
            .group_by(ChatMessage.session_id)
        )
        msg_count_result = await session.execute(msg_count_stmt)
        # Unpack positionally: a Row is tuple-like, so ``row.count`` would resolve
        # to the tuple ``count`` method rather than the labelled column value.
        msg_counts = {sid: count for sid, count in msg_count_result}

    data = []
    for s in sessions:
        meta = s.metadata_ or {}
        data.append({
            "sessionId": s.session_id,
            "tenantId": tenant_id,
            "status": meta.get("status", SessionStatus.ACTIVE.value),
            "startTime": s.created_at.isoformat() + "Z",
            "endTime": meta.get("endTime"),
            "messageCount": msg_counts.get(s.session_id, 0),
            "channelType": s.channel_type,
        })

    # Ceiling division; yields 0 when total is 0 because page_size >= 1.
    total_pages = (total + page_size - 1) // page_size
    return JSONResponse(
        content={
            "data": data,
            "pagination": {
                "page": page,
                "pageSize": page_size,
                "total": total,
                "totalPages": total_pages,
            },
        }
    )


def _parse_iso_datetime(value: Optional[str], param_name: str) -> Optional[datetime]:
    """
    Parse an optional ISO-8601 query value into a naive UTC datetime.

    A trailing "Z" is accepted. Timezone-aware values are converted to UTC and
    their tzinfo is dropped, so they compare consistently with ``created_at``.
    NOTE(review): this assumes ``created_at`` is stored as naive UTC. This module
    serializes it with a literal "Z" suffix and writes timestamps via naive UTC;
    confirm against the ChatSession model.

    Returns:
        The parsed datetime, or None if ``value`` is empty or unparseable. An
        unparseable value is logged and ignored rather than rejected.
    """
    if not value:
        return None
    try:
        parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
    except ValueError:
        logger.warning("[AC-ASA-09] Ignoring invalid %s filter: %r", param_name, value)
        return None
    if parsed.tzinfo is not None:
        parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
    return parsed
@router.get(
    "/{session_id}",
    operation_id="getSessionDetail",
    summary="Get session details",
    description="[AC-ASA-07] Get full session details with messages and trace.",
    responses={
        200: {"description": "Full session details with messages and trace"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
        404: {"description": "Session not found"},
    },
)
async def get_session_detail(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    session_id: str,
) -> JSONResponse:
    """
    [AC-ASA-07] Get session detail with messages and trace information.

    Args:
        tenant_id: Tenant that must own the session.
        session: Async database session.
        session_id: Identifier of the session to load.

    Returns:
        JSONResponse with ``sessionId``, ``messages`` (oldest first, each with
        role/content/timestamp), ``trace`` and the session's ``metadata``.
        Returns a 404 body ``{"code": "SESSION_NOT_FOUND", ...}`` when no session
        with this ID exists for the tenant.
    """
    logger.info(
        "[AC-ASA-07] Getting session detail: tenant=%s, session_id=%s",
        tenant_id,
        session_id,
    )

    session_stmt = select(ChatSession).where(
        ChatSession.tenant_id == tenant_id,
        ChatSession.session_id == session_id,
    )
    chat_session = (await session.execute(session_stmt)).scalar_one_or_none()
    if chat_session is None:
        return JSONResponse(
            status_code=404,
            content={
                "code": "SESSION_NOT_FOUND",
                "message": f"Session {session_id} not found",
            },
        )

    messages_stmt = (
        select(ChatMessage)
        .where(
            ChatMessage.tenant_id == tenant_id,
            ChatMessage.session_id == session_id,
        )
        .order_by(col(ChatMessage.created_at).asc())
    )
    messages = (await session.execute(messages_stmt)).scalars().all()

    # created_at is serialized with a literal "Z", matching list_sessions.
    messages_data = [
        {
            "role": msg.role,
            "content": msg.content,
            "timestamp": msg.created_at.isoformat() + "Z",
        }
        for msg in messages
    ]

    return JSONResponse(
        content={
            "sessionId": session_id,
            "messages": messages_data,
            "trace": _build_trace_info(messages),
            "metadata": chat_session.metadata_ or {},
        }
    )
def _build_trace_info(messages: Sequence[ChatMessage]) -> dict:
    """
    Assemble the ``trace`` section of a session-detail response.

    This is currently a placeholder. It walks the messages and singles out
    assistant messages, but it does not extract anything yet, so every list in
    the result is empty.

    Returns:
        A dict with ``retrieval``, ``tools`` and ``errors`` keys, each an empty list.
    """
    retrieval: list = []
    tools: list = []
    errors: list = []
    for message in messages:
        if message.role != "assistant":
            continue
        # TODO(review): populate retrieval/tools/errors from assistant message
        # metadata once its schema is defined.
    return {"retrieval": retrieval, "tools": tools, "errors": errors}
@router.put(
    "/{session_id}/status",
    operation_id="updateSessionStatus",
    summary="Update session status",
    description="[AC-ASA-09] Update session status (active, closed, expired).",
    responses={
        200: {"description": "Session status updated"},
        400: {"description": "Invalid status value"},
        404: {"description": "Session not found"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def update_session_status(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    db_session: Annotated[AsyncSession, Depends(get_session)],
    session_id: str,
    status: str = Query(..., description="New status: active, closed, expired"),
) -> JSONResponse:
    """
    [AC-ASA-09] Update session status.

    Stores ``status`` in the session's ``metadata_``. For closed/expired it also
    records ``endTime`` as a UTC ISO string with a "Z" suffix. ``updated_at`` is
    bumped as well.

    Args:
        tenant_id: Tenant that must own the session.
        db_session: Async database session.
        session_id: Identifier of the session to update.
        status: New status; must be one of the ``SessionStatus`` values.

    Returns:
        JSONResponse ``{"success": True, "sessionId": ..., "status": ...}``.
        Returns 400 (``INVALID_STATUS``) for an unknown status, and 404
        (``SESSION_NOT_FOUND``) when no session exists for the tenant.
    """
    logger.info(
        "[AC-ASA-09] Updating session status: tenant=%s, session_id=%s, status=%s",
        tenant_id,
        session_id,
        status,
    )

    allowed_statuses = {member.value for member in SessionStatus}
    if status not in allowed_statuses:
        return JSONResponse(
            status_code=400,
            content={
                "code": "INVALID_STATUS",
                "message": (
                    f"Invalid status '{status}'; expected one of: "
                    f"{', '.join(sorted(str(v) for v in allowed_statuses))}"
                ),
            },
        )

    stmt = select(ChatSession).where(
        ChatSession.tenant_id == tenant_id,
        ChatSession.session_id == session_id,
    )
    chat_session = (await db_session.execute(stmt)).scalar_one_or_none()
    if chat_session is None:
        return JSONResponse(
            status_code=404,
            content={
                "code": "SESSION_NOT_FOUND",
                "message": f"Session {session_id} not found",
            },
        )

    # Copy before modifying. With a plain JSON column, mutating the loaded dict in
    # place and assigning the same object back is not seen as a change by
    # SQLAlchemy, so the new status would silently not be persisted.
    metadata = dict(chat_session.metadata_ or {})
    metadata["status"] = status
    # Naive UTC, equivalent to the deprecated datetime.utcnow(). A single instant
    # keeps endTime and updated_at consistent.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    if status in (SessionStatus.CLOSED.value, SessionStatus.EXPIRED.value):
        metadata["endTime"] = now.isoformat() + "Z"
    chat_session.metadata_ = metadata
    chat_session.updated_at = now
    # NOTE(review): this only flushes. It assumes the get_session dependency
    # commits the transaction after a successful request; confirm.
    await db_session.flush()

    return JSONResponse(
        content={
            "success": True,
            "sessionId": session_id,
            "status": status,
        }
    )