""" Context management utilities for AI Service. [AC-AISVC-14, AC-AISVC-15] Context merging and truncation strategies. Design reference: design.md Section 7 - 上下文合并规则 - H_local: Memory layer history (sorted by time) - H_ext: External history from Java request (in passed order) - Deduplication: fingerprint = hash(role + "|" + normalized(content)) - Truncation: Keep most recent N messages within token budget """ import hashlib import logging from dataclasses import dataclass, field from typing import Any import tiktoken from app.core.config import get_settings from app.models import ChatMessage logger = logging.getLogger(__name__) @dataclass class MergedContext: """ Result of context merging. [AC-AISVC-14, AC-AISVC-15] Contains merged messages and diagnostics. """ messages: list[dict[str, str]] = field(default_factory=list) total_tokens: int = 0 local_count: int = 0 external_count: int = 0 duplicates_skipped: int = 0 truncated_count: int = 0 diagnostics: list[dict[str, Any]] = field(default_factory=list) class ContextMerger: """ [AC-AISVC-14, AC-AISVC-15] Context merger for combining local and external history. Design reference: design.md Section 7 - Deduplication based on message fingerprint - Priority: local history takes precedence - Token-based truncation using tiktoken """ def __init__( self, max_history_tokens: int | None = None, encoding_name: str = "cl100k_base", ): settings = get_settings() self._max_history_tokens = max_history_tokens or 4096 self._encoding = tiktoken.get_encoding(encoding_name) def compute_fingerprint(self, role: str, content: str) -> str: """ Compute message fingerprint for deduplication. [AC-AISVC-15] fingerprint = hash(role + "|" + normalized(content)) Args: role: Message role (user/assistant) content: Message content Returns: SHA256 hash of the normalized message """ normalized_content = content.strip() fingerprint_input = f"{role}|{normalized_content}" return hashlib.sha256(fingerprint_input.encode("utf-8")).hexdigest() def _message_to_dict(self, message: ChatMessage | dict[str, str]) -> dict[str, str]: """Convert ChatMessage or dict to standard dict format.""" if isinstance(message, ChatMessage): return {"role": message.role.value, "content": message.content} return message def _count_tokens(self, messages: list[dict[str, str]]) -> int: """ Count total tokens in messages using tiktoken. [AC-AISVC-14] Token counting for history truncation. """ total = 0 for msg in messages: total += len(self._encoding.encode(msg.get("role", ""))) total += len(self._encoding.encode(msg.get("content", ""))) total += 4 # Approximate overhead for message structure return total def merge_context( self, local_history: list[ChatMessage] | list[dict[str, str]] | None, external_history: list[ChatMessage] | list[dict[str, str]] | None, ) -> MergedContext: """ Merge local and external history with deduplication. [AC-AISVC-14, AC-AISVC-15] Implements context merging strategy. Design reference: design.md Section 7.2 1. Build seen set from H_local 2. Traverse H_ext, append if fingerprint not seen 3. Local history takes priority Args: local_history: History from Memory layer (H_local) external_history: History from Java request (H_ext) Returns: MergedContext with merged messages and diagnostics """ result = MergedContext() seen_fingerprints: set[str] = set() merged_messages: list[dict[str, str]] = [] diagnostics: list[dict[str, Any]] = [] local_messages = [self._message_to_dict(m) for m in (local_history or [])] external_messages = [self._message_to_dict(m) for m in (external_history or [])] for msg in local_messages: fingerprint = self.compute_fingerprint(msg["role"], msg["content"]) seen_fingerprints.add(fingerprint) merged_messages.append(msg) result.local_count += 1 for msg in external_messages: fingerprint = self.compute_fingerprint(msg["role"], msg["content"]) if fingerprint not in seen_fingerprints: seen_fingerprints.add(fingerprint) merged_messages.append(msg) result.external_count += 1 else: result.duplicates_skipped += 1 diagnostics.append({ "type": "duplicate_skipped", "role": msg["role"], "content_preview": msg["content"][:50] + "..." if len(msg["content"]) > 50 else msg["content"], }) result.messages = merged_messages result.diagnostics = diagnostics result.total_tokens = self._count_tokens(merged_messages) logger.info( f"[AC-AISVC-14, AC-AISVC-15] Context merged: " f"local={result.local_count}, external={result.external_count}, " f"duplicates_skipped={result.duplicates_skipped}, " f"total_tokens={result.total_tokens}" ) return result def truncate_context( self, messages: list[dict[str, str]], max_tokens: int | None = None, ) -> tuple[list[dict[str, str]], int]: """ Truncate context to fit within token budget. [AC-AISVC-14] Keep most recent N messages within budget. Design reference: design.md Section 7.4 - Budget = maxHistoryTokens (configurable) - Strategy: Keep most recent messages (from tail backward) Args: messages: List of messages to truncate max_tokens: Maximum token budget (uses default if not provided) Returns: Tuple of (truncated messages, truncated count) """ budget = max_tokens or self._max_history_tokens if not messages: return [], 0 total_tokens = self._count_tokens(messages) if total_tokens <= budget: return messages, 0 truncated_messages: list[dict[str, str]] = [] current_tokens = 0 truncated_count = 0 for msg in reversed(messages): msg_tokens = len(self._encoding.encode(msg.get("role", ""))) msg_tokens += len(self._encoding.encode(msg.get("content", ""))) msg_tokens += 4 if current_tokens + msg_tokens <= budget: truncated_messages.insert(0, msg) current_tokens += msg_tokens else: truncated_count += 1 logger.info( f"[AC-AISVC-14] Context truncated: " f"original={len(messages)}, truncated={len(truncated_messages)}, " f"removed={truncated_count}, tokens={current_tokens}/{budget}" ) return truncated_messages, truncated_count def merge_and_truncate( self, local_history: list[ChatMessage] | list[dict[str, str]] | None, external_history: list[ChatMessage] | list[dict[str, str]] | None, max_tokens: int | None = None, ) -> MergedContext: """ Merge and truncate context in one operation. [AC-AISVC-14, AC-AISVC-15] Complete context preparation pipeline. Args: local_history: History from Memory layer (H_local) external_history: History from Java request (H_ext) max_tokens: Maximum token budget Returns: MergedContext with final messages after merge and truncate """ merged = self.merge_context(local_history, external_history) truncated_messages, truncated_count = self.truncate_context( merged.messages, max_tokens ) merged.messages = truncated_messages merged.truncated_count = truncated_count merged.total_tokens = self._count_tokens(truncated_messages) return merged _context_merger: ContextMerger | None = None def get_context_merger() -> ContextMerger: """Get or create context merger instance.""" global _context_merger if _context_merger is None: _context_merger = ContextMerger() return _context_merger