"""
Memory service for AI Service.
[AC-AISVC-13] Session-based memory management with tenant isolation.
"""

import logging
from collections.abc import Sequence

from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col

from app.models.entities import ChatMessage, ChatSession

logger = logging.getLogger(__name__)


class MemoryService:
    """
    [AC-AISVC-13] Memory service for session-based conversation history.
    All operations are scoped by (tenant_id, session_id) for multi-tenant isolation.

    This class never commits: it only adds, flushes, or executes statements on
    the injected AsyncSession. Committing or rolling back is the caller's job.
    """

    def __init__(self, session: AsyncSession):
        """
        Args:
            session: Async SQLAlchemy session used for every query and write.
        """
        self._session = session

    async def get_or_create_session(
        self,
        tenant_id: str,
        session_id: str,
        channel_type: str | None = None,
        metadata: dict | None = None,
    ) -> ChatSession:
        """
        [AC-AISVC-13] Get existing session or create a new one.
        Ensures tenant isolation by querying with tenant_id.

        Args:
            tenant_id: Tenant that owns the session.
            session_id: Session identifier within the tenant.
            channel_type: Stored on the session only when it is newly created.
            metadata: Stored as ``metadata_`` only when the session is newly created.

        Returns:
            The existing ChatSession, or a newly added and flushed one.
        """
        stmt = select(ChatSession).where(
            ChatSession.tenant_id == tenant_id,
            ChatSession.session_id == session_id,
        )
        result = await self._session.execute(stmt)
        existing_session = result.scalar_one_or_none()

        if existing_session:
            logger.info(
                "[AC-AISVC-13] Found existing session: tenant=%s, session=%s",
                tenant_id,
                session_id,
            )
            return existing_session

        # NOTE(review): two concurrent callers can both miss the SELECT above
        # and both insert; presumably a unique constraint on
        # (tenant_id, session_id) guards against duplicates — confirm.
        new_session = ChatSession(
            tenant_id=tenant_id,
            session_id=session_id,
            channel_type=channel_type,
            metadata_=metadata,
        )
        self._session.add(new_session)
        # Flush so the INSERT is issued now, inside the caller's transaction.
        await self._session.flush()

        logger.info(
            "[AC-AISVC-13] Created new session: tenant=%s, session=%s",
            tenant_id,
            session_id,
        )
        return new_session

    async def load_history(
        self,
        tenant_id: str,
        session_id: str,
        limit: int | None = None,
    ) -> Sequence[ChatMessage]:
        """
        [AC-AISVC-13] Load conversation history for a session.
        All queries are filtered by tenant_id to ensure isolation.

        Args:
            tenant_id: Tenant that owns the session.
            session_id: Session whose messages are loaded.
            limit: Maximum number of messages to return. When set, the most
                recent ``limit`` messages are returned. A falsy value
                (``None`` or ``0``) loads the full history.

        Returns:
            Messages in chronological order (oldest first, by ``created_at``).
        """
        stmt = select(ChatMessage).where(
            ChatMessage.tenant_id == tenant_id,
            ChatMessage.session_id == session_id,
        )

        if limit:
            # Take the newest `limit` rows. Ordering ascending before LIMIT
            # would return the oldest turns instead of the recent context.
            stmt = stmt.order_by(col(ChatMessage.created_at).desc()).limit(limit)
        else:
            stmt = stmt.order_by(col(ChatMessage.created_at).asc())

        result = await self._session.execute(stmt)
        messages = list(result.scalars().all())
        if limit:
            # Restore chronological (oldest-first) order for the caller.
            messages.reverse()

        logger.info(
            "[AC-AISVC-13] Loaded %d messages for tenant=%s, session=%s",
            len(messages),
            tenant_id,
            session_id,
        )
        return messages

    async def append_message(
        self,
        tenant_id: str,
        session_id: str,
        role: str,
        content: str,
    ) -> ChatMessage:
        """
        [AC-AISVC-13] Append a message to the session history.
        Message is scoped by tenant_id for isolation.

        Args:
            tenant_id: Tenant that owns the session.
            session_id: Session the message belongs to.
            role: Role label stored on the message.
            content: Message text.

        Returns:
            The added and flushed ChatMessage.
        """
        message = ChatMessage(
            tenant_id=tenant_id,
            session_id=session_id,
            role=role,
            content=content,
        )
        self._session.add(message)
        await self._session.flush()

        logger.info(
            "[AC-AISVC-13] Appended message: tenant=%s, session=%s, role=%s",
            tenant_id,
            session_id,
            role,
        )
        return message

    async def append_messages(
        self,
        tenant_id: str,
        session_id: str,
        messages: list[dict[str, str]],
    ) -> list[ChatMessage]:
        """
        [AC-AISVC-13] Append multiple messages to the session history.
        Used for batch insertion of conversation turns.

        Args:
            tenant_id: Tenant that owns the session.
            session_id: Session the messages belong to.
            messages: Dicts that must each contain ``"role"`` and ``"content"``
                keys. A missing key raises KeyError.

        Returns:
            The added and flushed ChatMessage objects, in input order.
        """
        chat_messages = [
            ChatMessage(
                tenant_id=tenant_id,
                session_id=session_id,
                role=msg["role"],
                content=msg["content"],
            )
            for msg in messages
        ]
        self._session.add_all(chat_messages)
        # A single flush issues all INSERTs for the batch.
        await self._session.flush()

        logger.info(
            "[AC-AISVC-13] Appended %d messages for tenant=%s, session=%s",
            len(chat_messages),
            tenant_id,
            session_id,
        )
        return chat_messages

    async def clear_history(self, tenant_id: str, session_id: str) -> int:
        """
        [AC-AISVC-13] Clear all messages for a session.
        Only affects messages within the tenant's scope.

        Args:
            tenant_id: Tenant that owns the session.
            session_id: Session whose messages are deleted.

        Returns:
            Number of deleted messages, as reported by the DELETE's rowcount.
        """
        # One bulk DELETE instead of loading every row and deleting it
        # individually. The ORM-enabled delete also synchronizes matching
        # objects already in the session.
        stmt = delete(ChatMessage).where(
            ChatMessage.tenant_id == tenant_id,
            ChatMessage.session_id == session_id,
        )
        result = await self._session.execute(stmt)
        count = result.rowcount

        logger.info(
            "[AC-AISVC-13] Cleared %d messages for tenant=%s, session=%s",
            count,
            tenant_id,
            session_id,
        )
        return count