ai-robot-core/ai-service/app/services/embedding/ollama_provider.py

"""
Ollama embedding provider implementation.
[AC-AISVC-29, AC-AISVC-30] Ollama-based embedding provider.
Uses Ollama API for generating text embeddings.
"""
import logging
import time
from typing import Any

import httpx

from app.services.embedding.base import (
    EmbeddingConfig,
    EmbeddingException,
    EmbeddingProvider,
)

logger = logging.getLogger(__name__)


class OllamaEmbeddingProvider(EmbeddingProvider):
"""
Embedding provider using Ollama API.
[AC-AISVC-29, AC-AISVC-30] Supports local embedding models via Ollama.
"""
PROVIDER_NAME = "ollama"

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        model: str = "nomic-embed-text",
        dimension: int = 768,
        timeout_seconds: int = 60,
        **kwargs: Any,
    ):
        """Initialize the provider with Ollama connection and model settings."""
        self._base_url = base_url.rstrip("/")
        self._model = model
        self._dimension = dimension
        self._timeout = timeout_seconds
        self._client: httpx.AsyncClient | None = None
        self._extra_config = kwargs

    async def _get_client(self) -> httpx.AsyncClient:
        """Lazily create, and then reuse, the shared async HTTP client."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=self._timeout)
        return self._client

    async def embed(self, text: str) -> list[float]:
        """
        Generate embedding vector for a single text using the Ollama API.

        [AC-AISVC-29] Returns embedding vector.
        """
        start_time = time.perf_counter()
        try:
            client = await self._get_client()
            response = await client.post(
                f"{self._base_url}/api/embeddings",
                json={
                    "model": self._model,
                    "prompt": text,
                },
            )
            response.raise_for_status()
            data = response.json()
            embedding = data.get("embedding", [])
            if not embedding:
                raise EmbeddingException(
                    "Empty embedding returned",
                    provider=self.PROVIDER_NAME,
                    details={"text_length": len(text)},
                )
            latency_ms = (time.perf_counter() - start_time) * 1000
            logger.debug(
                f"Generated embedding via Ollama: dim={len(embedding)}, "
                f"latency={latency_ms:.2f}ms"
            )
            return embedding
        except httpx.HTTPStatusError as e:
            raise EmbeddingException(
                f"Ollama API error: {e.response.status_code}",
                provider=self.PROVIDER_NAME,
                details={"status_code": e.response.status_code, "response": e.response.text},
            ) from e
        except httpx.RequestError as e:
            raise EmbeddingException(
                f"Ollama connection error: {e}",
                provider=self.PROVIDER_NAME,
                details={"base_url": self._base_url},
            ) from e
        except EmbeddingException:
            raise
        except Exception as e:
            raise EmbeddingException(
                f"Embedding generation failed: {e}",
                provider=self.PROVIDER_NAME,
            ) from e

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embedding vectors for multiple texts.

        [AC-AISVC-29] Sequential embedding generation.
        """
        embeddings = []
        for text in texts:
            embedding = await self.embed(text)
            embeddings.append(embedding)
        return embeddings

    def get_dimension(self) -> int:
        """Get the dimension of embedding vectors."""
        return self._dimension

    def get_provider_name(self) -> str:
        """Get the name of this embedding provider."""
        return self.PROVIDER_NAME

    def get_config_schema(self) -> dict[str, Any]:
        """
        Get the configuration schema for the Ollama provider.

        [AC-AISVC-38] Returns JSON Schema for configuration parameters.
        """
        return {
            "base_url": {
                "type": "string",
                "description": "Ollama API base URL",
                "default": "http://localhost:11434",
            },
            "model": {
                "type": "string",
                "description": "Embedding model name",
                "default": "nomic-embed-text",
            },
            "dimension": {
                "type": "integer",
                "description": "Embedding vector dimension",
                "default": 768,
            },
            "timeout_seconds": {
                "type": "integer",
                "description": "Request timeout in seconds",
                "default": 60,
            },
        }

    async def close(self) -> None:
        """Close the HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None
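

# Usage sketch: a minimal example of exercising the provider, assuming an
# Ollama server is reachable at http://localhost:11434 and the
# "nomic-embed-text" model has already been pulled. Running this module
# directly performs a single embedding call and prints the vector dimension;
# the `_demo` helper below is illustrative only.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        provider = OllamaEmbeddingProvider()
        try:
            vector = await provider.embed("hello world")
            print(f"provider={provider.get_provider_name()}, dim={len(vector)}")
        finally:
            # Always release the underlying HTTP client.
            await provider.close()

    asyncio.run(_demo())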