"""
Context management utilities for AI Service.

[AC-AISVC-14, AC-AISVC-15] Context merging and truncation strategies.

Design reference: design.md Section 7 - context merging rules (上下文合并规则)

- H_local: Memory layer history (sorted by time)
- H_ext: External history from Java request (in passed order)
- Deduplication: fingerprint = hash(role + "|" + normalized(content))
- Truncation: Keep most recent N messages within token budget
"""
|
||
|
|
|
||
|
|
import hashlib
|
||
|
|
import logging
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
import tiktoken
|
||
|
|
|
||
|
|
from app.core.config import get_settings
|
||
|
|
from app.models import ChatMessage, Role
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class MergedContext:
    """
    Result of context merging.

    [AC-AISVC-14, AC-AISVC-15] Contains merged messages and diagnostics.
    """

    # Final merged (and possibly truncated) messages as {"role", "content"} dicts.
    messages: list[dict[str, str]] = field(default_factory=list)
    # Total token count of `messages` (as measured by the merger's encoding).
    total_tokens: int = 0
    # Number of messages taken from local (Memory layer) history.
    local_count: int = 0
    # Number of external (request-supplied) messages kept after deduplication.
    external_count: int = 0
    # External messages dropped because their fingerprint matched an earlier message.
    duplicates_skipped: int = 0
    # Messages removed by truncation to fit the token budget.
    truncated_count: int = 0
    # Structured observability notes (e.g. "duplicate_skipped" entries).
    diagnostics: list[dict[str, Any]] = field(default_factory=list)
|
||
|
|
|
||
|
|
|
||
|
|
class ContextMerger:
    """
    [AC-AISVC-14, AC-AISVC-15] Context merger for combining local and external history.

    Design reference: design.md Section 7
    - Deduplication based on message fingerprint
    - Priority: local history takes precedence
    - Token-based truncation using tiktoken
    """

    # Class-level logger keeps all logging self-contained within the class.
    _log = logging.getLogger(__name__)

    def __init__(
        self,
        max_history_tokens: int | None = None,
        encoding_name: str = "cl100k_base",
    ):
        """
        Args:
            max_history_tokens: Token budget for history truncation (defaults to 4096).
            encoding_name: tiktoken encoding used for token counting.
        """
        # NOTE(review): the original fetched get_settings() into an unused local;
        # the dead assignment was removed. If the default budget should come from
        # settings, wire that up explicitly — TODO confirm with config owners.
        self._max_history_tokens = max_history_tokens or 4096
        self._encoding = tiktoken.get_encoding(encoding_name)

    def compute_fingerprint(self, role: str, content: str) -> str:
        """
        Compute message fingerprint for deduplication.

        [AC-AISVC-15] fingerprint = hash(role + "|" + normalized(content))
        Normalization is a plain strip(): leading/trailing whitespace does not
        affect identity, but internal whitespace and case do.

        Args:
            role: Message role (user/assistant)
            content: Message content

        Returns:
            SHA256 hex digest of the normalized message
        """
        normalized_content = content.strip()
        fingerprint_input = f"{role}|{normalized_content}"
        return hashlib.sha256(fingerprint_input.encode("utf-8")).hexdigest()

    def _message_to_dict(self, message: ChatMessage | dict[str, str]) -> dict[str, str]:
        """Normalize a ChatMessage (or pass through a dict) to {"role", "content"} form."""
        if isinstance(message, dict):
            return message
        # Anything non-dict is treated as a ChatMessage: role is a Role enum.
        return {"role": message.role.value, "content": message.content}

    def _message_tokens(self, msg: dict[str, str]) -> int:
        """Token cost of one message: encoded role + encoded content + fixed overhead."""
        return (
            len(self._encoding.encode(msg.get("role", "")))
            + len(self._encoding.encode(msg.get("content", "")))
            + 4  # approximate per-message structural overhead
        )

    def _count_tokens(self, messages: list[dict[str, str]]) -> int:
        """
        Count total tokens in messages using tiktoken.

        [AC-AISVC-14] Token counting for history truncation.
        """
        return sum(self._message_tokens(msg) for msg in messages)

    def merge_context(
        self,
        local_history: list[ChatMessage] | list[dict[str, str]] | None,
        external_history: list[ChatMessage] | list[dict[str, str]] | None,
    ) -> MergedContext:
        """
        Merge local and external history with deduplication.

        [AC-AISVC-14, AC-AISVC-15] Implements context merging strategy.

        Design reference: design.md Section 7.2
        1. Build seen set from H_local
        2. Traverse H_ext, append if fingerprint not seen
        3. Local history takes priority

        Args:
            local_history: History from Memory layer (H_local)
            external_history: History from Java request (H_ext)

        Returns:
            MergedContext with merged messages and diagnostics
        """
        result = MergedContext()
        seen_fingerprints: set[str] = set()
        merged_messages: list[dict[str, str]] = []
        diagnostics: list[dict[str, Any]] = []

        local_messages = [self._message_to_dict(m) for m in (local_history or [])]
        external_messages = [self._message_to_dict(m) for m in (external_history or [])]

        # Local history is authoritative: always kept, and it seeds the
        # deduplication set that external messages are checked against.
        for msg in local_messages:
            seen_fingerprints.add(self.compute_fingerprint(msg["role"], msg["content"]))
            merged_messages.append(msg)
            result.local_count += 1

        # External messages are appended only when not already seen.
        for msg in external_messages:
            fingerprint = self.compute_fingerprint(msg["role"], msg["content"])
            if fingerprint not in seen_fingerprints:
                seen_fingerprints.add(fingerprint)
                merged_messages.append(msg)
                result.external_count += 1
            else:
                result.duplicates_skipped += 1
                diagnostics.append({
                    "type": "duplicate_skipped",
                    "role": msg["role"],
                    "content_preview": msg["content"][:50] + "..." if len(msg["content"]) > 50 else msg["content"],
                })

        result.messages = merged_messages
        result.diagnostics = diagnostics
        result.total_tokens = self._count_tokens(merged_messages)

        self._log.info(
            "[AC-AISVC-14, AC-AISVC-15] Context merged: "
            "local=%d, external=%d, duplicates_skipped=%d, total_tokens=%d",
            result.local_count,
            result.external_count,
            result.duplicates_skipped,
            result.total_tokens,
        )

        return result

    def truncate_context(
        self,
        messages: list[dict[str, str]],
        max_tokens: int | None = None,
    ) -> tuple[list[dict[str, str]], int]:
        """
        Truncate context to fit within token budget.

        [AC-AISVC-14] Keep most recent N messages within budget.

        Design reference: design.md Section 7.4
        - Budget = maxHistoryTokens (configurable)
        - Strategy: Keep most recent messages (from tail backward)

        Args:
            messages: List of messages to truncate
            max_tokens: Maximum token budget (uses default if not provided)

        Returns:
            Tuple of (truncated messages, truncated count)
        """
        budget = max_tokens or self._max_history_tokens
        if not messages:
            return [], 0

        total_tokens = self._count_tokens(messages)
        if total_tokens <= budget:
            return messages, 0

        kept: list[dict[str, str]] = []
        current_tokens = 0

        for msg in reversed(messages):
            msg_tokens = self._message_tokens(msg)
            if current_tokens + msg_tokens > budget:
                # BUGFIX: stop at the first over-budget message. The original
                # kept scanning older messages and could splice non-contiguous
                # history (an older small message after a skipped large one),
                # violating the "keep most recent" strategy.
                break
            kept.insert(0, msg)
            current_tokens += msg_tokens

        truncated_count = len(messages) - len(kept)

        self._log.info(
            "[AC-AISVC-14] Context truncated: "
            "original=%d, truncated=%d, removed=%d, tokens=%d/%d",
            len(messages),
            len(kept),
            truncated_count,
            current_tokens,
            budget,
        )

        return kept, truncated_count

    def merge_and_truncate(
        self,
        local_history: list[ChatMessage] | list[dict[str, str]] | None,
        external_history: list[ChatMessage] | list[dict[str, str]] | None,
        max_tokens: int | None = None,
    ) -> MergedContext:
        """
        Merge and truncate context in one operation.

        [AC-AISVC-14, AC-AISVC-15] Complete context preparation pipeline.

        Args:
            local_history: History from Memory layer (H_local)
            external_history: History from Java request (H_ext)
            max_tokens: Maximum token budget

        Returns:
            MergedContext with final messages after merge and truncate
        """
        merged = self.merge_context(local_history, external_history)

        truncated_messages, truncated_count = self.truncate_context(
            merged.messages, max_tokens
        )

        # Recompute token total against the final (possibly shorter) message list.
        merged.messages = truncated_messages
        merged.truncated_count = truncated_count
        merged.total_tokens = self._count_tokens(truncated_messages)

        return merged
|
||
|
|
|
||
|
|
|
||
|
|
# Process-wide singleton, created lazily by get_context_merger().
_context_merger: ContextMerger | None = None


def get_context_merger() -> ContextMerger:
    """Return the shared ContextMerger, creating it on first use."""
    global _context_merger
    if _context_merger is not None:
        return _context_merger
    _context_merger = ContextMerger()
    return _context_merger
|