ai-robot-core/ai-service/app/services/mid/trace_logger.py

270 lines
8.2 KiB
Python
Raw Normal View History

"""
Trace Logger for Mid Platform.
[AC-MARH-02, AC-MARH-03, AC-MARH-12] Trace collection and audit logging.
Audit Fields:
- session_id, request_id, generation_id
- mode, intent, tool_calls
- guardrail_triggered, guardrail_rule_id
- interrupt_consumed
"""
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any
from app.models.mid.schemas import (
ExecutionMode,
ToolCallTrace,
TraceInfo,
)
# Module-level logger named after this module's import path (__name__);
# every trace/audit log line emitted in this file goes through it.
logger = logging.getLogger(__name__)
@dataclass
class AuditRecord:
    """
    [AC-MARH-12] Audit record for database persistence.

    Flat snapshot of one finished request. The four identifiers and the
    execution mode are mandatory; every other field has a neutral default so a
    minimal record can be built from the identifiers alone.
    """

    # Mandatory identifiers and execution mode.
    tenant_id: str
    session_id: str
    request_id: str
    generation_id: str
    mode: ExecutionMode

    # Optional execution details.
    intent: str | None = None
    # default_factory gives each record its own list instead of a shared one.
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    guardrail_triggered: bool = False
    guardrail_rule_id: str | None = None
    interrupt_consumed: bool = False
    fallback_reason_code: str | None = None
    react_iterations: int = 0
    latency_ms: int = 0
    created_at: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """
        Return the record as a plain dict keyed by field name.

        Keys appear in field-declaration order. ``mode`` is emitted as
        ``self.mode.value`` rather than the enum member; all other values are
        passed through unchanged (``tool_calls`` is the same list object, not
        a copy).
        """
        serialized: dict[str, Any] = dict(
            tenant_id=self.tenant_id,
            session_id=self.session_id,
            request_id=self.request_id,
            generation_id=self.generation_id,
            mode=self.mode.value,
            intent=self.intent,
            tool_calls=self.tool_calls,
            guardrail_triggered=self.guardrail_triggered,
            guardrail_rule_id=self.guardrail_rule_id,
            interrupt_consumed=self.interrupt_consumed,
            fallback_reason_code=self.fallback_reason_code,
            react_iterations=self.react_iterations,
            latency_ms=self.latency_ms,
            created_at=self.created_at,
        )
        return serialized
class TraceLogger:
    """
    [AC-MARH-02, AC-MARH-03, AC-MARH-12] Trace logger for observability and audit.

    Features:
    - Request-scoped trace context
    - Tool call tracing
    - Guardrail event logging with rule_id
    - Interrupt consumption tracking
    - Audit record generation

    All state lives in memory on the instance: active traces are keyed by
    request_id, and finished traces are appended to a list of audit records
    that only shrinks when clear_audit_records() is called.
    """

    # Layout of AuditRecord.created_at. It is rendered with time.strftime()
    # without an explicit time tuple, i.e. local time with no UTC offset.
    _TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self):
        """Create an empty logger with no active traces or audit records."""
        self._traces: dict[str, TraceInfo] = {}
        self._audit_records: list[AuditRecord] = []

    def start_trace(
        self,
        tenant_id: str,
        session_id: str,
        request_id: str | None = None,
        generation_id: str | None = None,
    ) -> TraceInfo:
        """
        [AC-MARH-12] Start a new trace context.

        The trace is created in ExecutionMode.AGENT and registered under
        request_id; an existing trace with the same request_id is replaced.

        Args:
            tenant_id: Tenant ID (accepted for API symmetry; not stored on the trace)
            session_id: Session ID (only used in the log line)
            request_id: Request ID (auto-generated if not provided)
            generation_id: Generation ID for interrupt handling
                (auto-generated if not provided)

        Returns:
            TraceInfo for the new trace
        """
        request_id = request_id or str(uuid.uuid4())
        generation_id = generation_id or str(uuid.uuid4())

        trace = TraceInfo(
            mode=ExecutionMode.AGENT,
            request_id=request_id,
            generation_id=generation_id,
        )
        self._traces[request_id] = trace

        logger.info(
            f"[AC-MARH-12] Trace started: request_id={request_id}, "
            f"session_id={session_id}, generation_id={generation_id}"
        )
        return trace

    def get_trace(self, request_id: str) -> TraceInfo | None:
        """Get the active trace for a request ID, or None if there is none."""
        return self._traces.get(request_id)

    def update_trace(
        self,
        request_id: str,
        mode: ExecutionMode | None = None,
        intent: str | None = None,
        guardrail_triggered: bool | None = None,
        guardrail_rule_id: str | None = None,
        interrupt_consumed: bool | None = None,
        fallback_reason_code: str | None = None,
        react_iterations: int | None = None,
        tool_calls: list[ToolCallTrace] | None = None,
    ) -> TraceInfo | None:
        """
        [AC-MARH-02, AC-MARH-03, AC-MARH-12] Update trace with execution details.

        Only arguments that are not None are written, so a field cannot be
        reset to None through this method. ``tool_calls`` replaces the whole
        list; use add_tool_call() to append.

        Returns:
            The updated TraceInfo, or None (with a warning logged) when no
            trace is registered for request_id.
        """
        trace = self._traces.get(request_id)
        if trace is None:
            logger.warning(f"[AC-MARH-12] Trace not found: {request_id}")
            return None

        updates = {
            "mode": mode,
            "intent": intent,
            "guardrail_triggered": guardrail_triggered,
            "guardrail_rule_id": guardrail_rule_id,
            "interrupt_consumed": interrupt_consumed,
            "fallback_reason_code": fallback_reason_code,
            "react_iterations": react_iterations,
            "tool_calls": tool_calls,
        }
        for attr_name, new_value in updates.items():
            # None means "not supplied"; explicit False / 0 / "" are applied.
            if new_value is not None:
                setattr(trace, attr_name, new_value)
        return trace

    def add_tool_call(
        self,
        request_id: str,
        tool_call: ToolCallTrace,
    ) -> None:
        """
        [AC-MARH-12] Add tool call trace to request.

        Appends the call to trace.tool_calls and records its tool name in
        trace.tools_used (each name once, in first-seen order). Logs a warning
        and does nothing when no trace is registered for request_id.
        """
        trace = self._traces.get(request_id)
        if trace is None:
            logger.warning(f"[AC-MARH-12] Trace not found for tool call: {request_id}")
            return

        # Both lists may still be None on the trace; create them on first use.
        if trace.tool_calls is None:
            trace.tool_calls = []
        trace.tool_calls.append(tool_call)

        if trace.tools_used is None:
            trace.tools_used = []
        if tool_call.tool_name not in trace.tools_used:
            trace.tools_used.append(tool_call.tool_name)

        logger.debug(
            f"[AC-MARH-12] Tool call recorded: request_id={request_id}, "
            f"tool={tool_call.tool_name}, status={tool_call.status.value}"
        )

    def end_trace(
        self,
        request_id: str,
        tenant_id: str,
        session_id: str,
        latency_ms: int,
    ) -> AuditRecord:
        """
        [AC-MARH-12] End trace and create audit record.

        When a trace exists, its details are copied into an AuditRecord, the
        record is stored for get_audit_records(), and the trace is removed.

        When no trace exists, a warning is logged and a minimal record
        (ExecutionMode.FIXED, fresh generation_id) is returned. That fallback
        record is NOT stored.

        Returns:
            The AuditRecord for this request; created_at is always set.
        """
        created_at = time.strftime(self._TIMESTAMP_FORMAT)

        trace = self._traces.get(request_id)
        if trace is None:
            logger.warning(f"[AC-MARH-12] Trace not found for end: {request_id}")
            return AuditRecord(
                tenant_id=tenant_id,
                session_id=session_id,
                request_id=request_id,
                generation_id=str(uuid.uuid4()),
                mode=ExecutionMode.FIXED,
                latency_ms=latency_ms,
                # Stamp the fallback too, so every returned record carries a
                # timestamp.
                created_at=created_at,
            )

        audit = AuditRecord(
            tenant_id=tenant_id,
            session_id=session_id,
            request_id=request_id,
            generation_id=trace.generation_id or str(uuid.uuid4()),
            mode=trace.mode,
            intent=trace.intent,
            tool_calls=[tc.model_dump() for tc in trace.tool_calls] if trace.tool_calls else [],
            # Optional trace fields are normalised to concrete values.
            guardrail_triggered=trace.guardrail_triggered or False,
            guardrail_rule_id=trace.guardrail_rule_id,
            interrupt_consumed=trace.interrupt_consumed or False,
            fallback_reason_code=trace.fallback_reason_code,
            react_iterations=trace.react_iterations or 0,
            latency_ms=latency_ms,
            created_at=created_at,
        )
        self._audit_records.append(audit)

        # Remove the trace only after the record was built and stored, so a
        # failure above leaves the trace in place.
        self._traces.pop(request_id, None)

        logger.info(
            f"[AC-MARH-12] Trace ended: request_id={request_id}, "
            f"mode={trace.mode.value}, latency_ms={latency_ms}"
        )
        return audit

    def get_audit_records(
        self,
        tenant_id: str,
        session_id: str | None = None,
        limit: int = 100,
    ) -> list[AuditRecord]:
        """
        Get audit records for a tenant/session.

        Args:
            tenant_id: Only records with exactly this tenant_id are returned.
            session_id: Optional session filter; None or an empty string
                means "all sessions of the tenant".
            limit: Maximum number of records to return. The most recently
                stored records win. A limit of zero or less returns an empty
                list.

        Returns:
            Up to ``limit`` matching records, oldest first.
        """
        # Guard explicitly: records[-0:] would be the WHOLE list, and a
        # negative limit would silently drop the oldest entries instead.
        if limit <= 0:
            return []

        records = [
            r for r in self._audit_records
            if r.tenant_id == tenant_id
        ]
        if session_id:
            records = [r for r in records if r.session_id == session_id]
        return records[-limit:]

    def clear_audit_records(self, tenant_id: str | None = None) -> int:
        """
        Clear audit records, optionally filtered by tenant.

        Args:
            tenant_id: When given, only that tenant's records are removed.
                When None, every stored record is removed.

        Returns:
            Number of records removed.
        """
        original_count = len(self._audit_records)

        # Compare against None, not truthiness: an empty-string tenant id must
        # not widen this destructive call to all tenants.
        if tenant_id is not None:
            self._audit_records = [
                r for r in self._audit_records
                if r.tenant_id != tenant_id
            ]
        else:
            self._audit_records = []

        return original_count - len(self._audit_records)