338 lines
11 KiB
Python
338 lines
11 KiB
Python
"""
|
|
Tool Registry for Mid Platform.
|
|
[AC-IDMP-19] Unified tool registration, auth, timeout, version, and enable/disable governance.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Callable, Coroutine
|
|
|
|
from app.models.mid.schemas import (
|
|
ToolCallStatus,
|
|
ToolCallTrace,
|
|
ToolType,
|
|
)
|
|
from app.services.mid.timeout_governor import TimeoutGovernor
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class ToolDefinition:
    """Registry record describing a single tool and its governance settings."""

    # Unique key the registry stores the tool under.
    name: str
    # Human-readable summary of what the tool does.
    description: str
    # Tool category; defaults to the INTERNAL member of ToolType.
    tool_type: ToolType = ToolType.INTERNAL
    # Version string of the tool itself.
    version: str = "1.0.0"
    # Whether the tool is currently enabled.
    enabled: bool = True
    # Whether the tool requires an authentication context.
    auth_required: bool = False
    # Per-call timeout budget in milliseconds.
    timeout_ms: int = 2000
    # Async callable implementing the tool; may be None.
    handler: Callable[..., Coroutine[Any, Any, dict[str, Any]]] | None = None
    # Free-form extra data; each instance gets its own fresh dict by default.
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass
|
|
class ToolExecutionResult:
|
|
"""Tool execution result."""
|
|
success: bool
|
|
output: Any = None
|
|
error: str | None = None
|
|
duration_ms: int = 0
|
|
auth_applied: bool = False
|
|
registry_version: str | None = None
|
|
|
|
|
|
class ToolRegistry:
    """
    [AC-IDMP-19] Unified tool registry for governance.

    Features:
    - Tool registration with metadata
    - Auth policy enforcement
    - Timeout governance
    - Version management
    - Enable/disable control
    """

    # Upper bound (milliseconds) applied to every tool's timeout at registration.
    MAX_TIMEOUT_MS = 2000

    # Prefixes of the errors execute() produces when a tool cannot be looked up
    # or is disabled. They embed the tool name, so they must be excluded from
    # text-based timeout detection (a tool may be called e.g. "timeout_probe").
    _LOOKUP_ERROR_PREFIXES = ("Tool not found: ", "Tool disabled: ")

    def __init__(
        self,
        timeout_governor: TimeoutGovernor | None = None,
    ):
        """
        Create an empty registry.

        Args:
            timeout_governor: Governor instance to keep on the registry; a
                default TimeoutGovernor() is created when omitted.
        """
        self._tools: dict[str, ToolDefinition] = {}
        # NOTE(review): the governor is stored but no method of this class
        # consults it; per-tool timeouts come from ToolDefinition.timeout_ms.
        self._timeout_governor = timeout_governor or TimeoutGovernor()
        self._version = "1.0.0"

    @property
    def version(self) -> str:
        """Get registry version."""
        return self._version

    def register(
        self,
        name: str,
        description: str,
        handler: Callable[..., Coroutine[Any, Any, dict[str, Any]]],
        tool_type: ToolType = ToolType.INTERNAL,
        version: str = "1.0.0",
        auth_required: bool = False,
        timeout_ms: int = 2000,
        enabled: bool = True,
        metadata: dict[str, Any] | None = None,
    ) -> ToolDefinition:
        """
        [AC-IDMP-19] Register a tool.

        Registering a name that already exists logs a warning and replaces
        the previous definition.

        Args:
            name: Tool name (unique identifier)
            description: Tool description
            handler: Async handler function
            tool_type: Tool type (internal/mcp)
            version: Tool version
            auth_required: Whether auth is required
            timeout_ms: Tool-specific timeout; capped at MAX_TIMEOUT_MS
            enabled: Whether tool is enabled
            metadata: Additional metadata

        Returns:
            ToolDefinition for the registered tool
        """
        if name in self._tools:
            logger.warning(
                "[AC-IDMP-19] Tool already registered, overwriting: %s", name
            )

        tool = ToolDefinition(
            name=name,
            description=description,
            tool_type=tool_type,
            version=version,
            enabled=enabled,
            auth_required=auth_required,
            # A tool may ask for less than the cap, never more.
            timeout_ms=min(timeout_ms, self.MAX_TIMEOUT_MS),
            handler=handler,
            metadata=metadata or {},
        )
        self._tools[name] = tool

        logger.info(
            "[AC-IDMP-19] Tool registered: name=%s, type=%s, "
            "version=%s, auth_required=%s",
            name,
            tool_type.value,
            version,
            auth_required,
        )
        return tool

    def unregister(self, name: str) -> bool:
        """Unregister a tool; return True if it existed, False otherwise."""
        if name in self._tools:
            del self._tools[name]
            logger.info("[AC-IDMP-19] Tool unregistered: %s", name)
            return True
        return False

    def get_tool(self, name: str) -> ToolDefinition | None:
        """Get tool definition by name, or None when it is not registered."""
        return self._tools.get(name)

    def list_tools(
        self,
        tool_type: ToolType | None = None,
        enabled_only: bool = True,
    ) -> list[ToolDefinition]:
        """
        List registered tools, optionally filtered.

        Args:
            tool_type: When given, keep only tools of this type.
            enabled_only: When True (default), drop disabled tools.

        Returns:
            A new list of matching definitions in registration order.
        """
        tools = list(self._tools.values())

        # Compare against None explicitly so an enum member that happens to be
        # falsy still filters.
        if tool_type is not None:
            tools = [t for t in tools if t.tool_type == tool_type]

        if enabled_only:
            tools = [t for t in tools if t.enabled]

        return tools

    def enable_tool(self, name: str) -> bool:
        """Enable a tool; return False when the name is not registered."""
        tool = self._tools.get(name)
        if tool:
            tool.enabled = True
            logger.info("[AC-IDMP-19] Tool enabled: %s", name)
            return True
        return False

    def disable_tool(self, name: str) -> bool:
        """Disable a tool; return False when the name is not registered."""
        tool = self._tools.get(name)
        if tool:
            tool.enabled = False
            logger.info("[AC-IDMP-19] Tool disabled: %s", name)
            return True
        return False

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Return whole milliseconds elapsed since a time.monotonic() reading."""
        return int((time.monotonic() - start) * 1000)

    @classmethod
    def _is_timeout_error(cls, error: str | None) -> bool:
        """Return True when an error message describes a timeout.

        Lookup errors generated by execute() are excluded because they embed
        the tool name, which may itself contain the word "timeout".
        """
        if not error or error.startswith(cls._LOOKUP_ERROR_PREFIXES):
            return False
        return "timeout" in error.lower()

    async def execute(
        self,
        tool_name: str,
        args: dict[str, Any],
        auth_context: dict[str, Any] | None = None,
    ) -> ToolExecutionResult:
        """
        [AC-IDMP-19] Execute a tool with governance.

        Checks run in order: the tool exists, it is enabled, an auth context
        is present when the tool requires one. The handler is then awaited
        under the tool's timeout. Failures derived from Exception are reported
        in the returned result rather than raised.

        Args:
            tool_name: Tool name to execute
            args: Tool arguments, passed to the handler as keyword arguments
            auth_context: Authentication context

        Returns:
            ToolExecutionResult with output and metadata
        """
        # Monotonic clock: durations stay correct across wall-clock changes.
        start_time = time.monotonic()

        tool = self._tools.get(tool_name)
        if tool is None:
            logger.warning("[AC-IDMP-19] Tool not found: %s", tool_name)
            return ToolExecutionResult(
                success=False,
                error=f"Tool not found: {tool_name}",
                duration_ms=0,
            )

        if not tool.enabled:
            logger.warning("[AC-IDMP-19] Tool disabled: %s", tool_name)
            return ToolExecutionResult(
                success=False,
                error=f"Tool disabled: {tool_name}",
                duration_ms=0,
                registry_version=tool.version,
            )

        auth_applied = False
        if tool.auth_required:
            if not auth_context:
                logger.warning(
                    "[AC-IDMP-19] Auth required but no context: %s", tool_name
                )
                return ToolExecutionResult(
                    success=False,
                    error="Authentication required",
                    duration_ms=self._elapsed_ms(start_time),
                    auth_applied=False,
                    registry_version=tool.version,
                )
            # Only the presence of a context is checked here.
            auth_applied = True

        try:
            timeout_seconds = tool.timeout_ms / 1000.0

            # A tool without a handler completes immediately with output None.
            pending = tool.handler(**args) if tool.handler else asyncio.sleep(0)
            result = await asyncio.wait_for(pending, timeout=timeout_seconds)

            duration_ms = self._elapsed_ms(start_time)
            logger.info(
                "[AC-IDMP-19] Tool executed: name=%s, "
                "duration_ms=%s, success=True",
                tool_name,
                duration_ms,
            )
            return ToolExecutionResult(
                success=True,
                output=result,
                duration_ms=duration_ms,
                auth_applied=auth_applied,
                registry_version=tool.version,
            )

        except asyncio.TimeoutError:
            duration_ms = self._elapsed_ms(start_time)
            logger.warning(
                "[AC-IDMP-19] Tool timeout: name=%s, duration_ms=%s",
                tool_name,
                duration_ms,
            )
            return ToolExecutionResult(
                success=False,
                error=f"Tool timeout after {tool.timeout_ms}ms",
                duration_ms=duration_ms,
                auth_applied=auth_applied,
                registry_version=tool.version,
            )

        except Exception as e:
            duration_ms = self._elapsed_ms(start_time)
            logger.error(
                "[AC-IDMP-19] Tool error: name=%s, error=%s", tool_name, e
            )
            return ToolExecutionResult(
                success=False,
                # Some exceptions carry no message; fall back to the class
                # name so the failure is never reported as an empty string.
                error=str(e) or type(e).__name__,
                duration_ms=duration_ms,
                auth_applied=auth_applied,
                registry_version=tool.version,
            )

    def create_trace(
        self,
        tool_name: str,
        result: ToolExecutionResult,
        args_digest: str | None = None,
    ) -> ToolCallTrace:
        """
        [AC-IDMP-19] Create ToolCallTrace from execution result.

        Args:
            tool_name: Name of the tool that was called.
            result: Result returned for that call.
            args_digest: Optional caller-supplied digest of the arguments.

        Returns:
            ToolCallTrace whose status is OK on success, TIMEOUT when the
            error describes a timeout, and ERROR otherwise. Unregistered
            tools are traced with ToolType.INTERNAL.
        """
        tool = self._tools.get(tool_name)

        if result.success:
            status = ToolCallStatus.OK
        elif self._is_timeout_error(result.error):
            status = ToolCallStatus.TIMEOUT
        else:
            status = ToolCallStatus.ERROR

        # Digest any produced value, including falsy ones such as {} or 0;
        # only a missing output (None) yields no digest.
        if result.output is None:
            result_digest = None
        else:
            result_digest = str(result.output)[:100]

        return ToolCallTrace(
            tool_name=tool_name,
            tool_type=tool.tool_type if tool else ToolType.INTERNAL,
            registry_version=result.registry_version,
            auth_applied=result.auth_applied,
            duration_ms=result.duration_ms,
            status=status,
            error_code=None if result.success else result.error,
            args_digest=args_digest,
            result_digest=result_digest,
        )

    def get_governance_report(self) -> dict[str, Any]:
        """Get governance report for all tools.

        Returns:
            Dict with the registry version, aggregate counts (total, enabled,
            disabled, auth-required, MCP, internal) and a per-tool summary
            list under "tools".
        """
        tools = list(self._tools.values())
        enabled_count = sum(1 for t in tools if t.enabled)

        return {
            "registry_version": self._version,
            "total_tools": len(tools),
            "enabled_tools": enabled_count,
            "disabled_tools": len(tools) - enabled_count,
            "auth_required_tools": sum(1 for t in tools if t.auth_required),
            "mcp_tools": sum(1 for t in tools if t.tool_type == ToolType.MCP),
            "internal_tools": sum(
                1 for t in tools if t.tool_type == ToolType.INTERNAL
            ),
            "tools": [
                {
                    "name": t.name,
                    "type": t.tool_type.value,
                    "version": t.version,
                    "enabled": t.enabled,
                    "auth_required": t.auth_required,
                    "timeout_ms": t.timeout_ms,
                }
                for t in tools
            ],
        }
|
|
|
|
|
|
# Process-wide registry instance; stays None until first requested or
# explicitly initialized.
_registry: ToolRegistry | None = None


def get_tool_registry() -> ToolRegistry:
    """Return the shared ToolRegistry, creating a default one on first use."""
    global _registry
    if _registry is not None:
        return _registry
    _registry = ToolRegistry()
    return _registry
|
|
|
|
|
|
def init_tool_registry(timeout_governor: TimeoutGovernor | None = None) -> ToolRegistry:
    """Build a fresh ToolRegistry, install it as the shared instance, and return it.

    Any previously installed shared instance is replaced.
    """
    global _registry
    fresh_registry = ToolRegistry(timeout_governor=timeout_governor)
    _registry = fresh_registry
    return fresh_registry
|