feat(ASA-P5): PDF智能分块处理,使用tiktoken按token分块并保留页码元数据 [AC-ASA-01]
This commit is contained in:
parent
e9fee2f80e
commit
559d8c0c53
|
|
@ -6,8 +6,10 @@ Knowledge Base management endpoints.
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Annotated, Optional
|
from typing import Annotated, Optional
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
from fastapi import APIRouter, BackgroundTasks, Depends, Query, UploadFile, File, Form
|
from fastapi import APIRouter, BackgroundTasks, Depends, Query, UploadFile, File, Form
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
@ -25,6 +27,59 @@ logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/admin/kb", tags=["KB Management"])
|
router = APIRouter(prefix="/admin/kb", tags=["KB Management"])
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TextChunk:
|
||||||
|
"""Text chunk with metadata."""
|
||||||
|
text: str
|
||||||
|
start_token: int
|
||||||
|
end_token: int
|
||||||
|
page: int | None = None
|
||||||
|
source: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_text_with_tiktoken(
|
||||||
|
text: str,
|
||||||
|
chunk_size: int = 512,
|
||||||
|
overlap: int = 100,
|
||||||
|
page: int | None = None,
|
||||||
|
source: str | None = None,
|
||||||
|
) -> list[TextChunk]:
|
||||||
|
"""
|
||||||
|
使用 tiktoken 按 token 数分块,支持重叠分块。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 要分块的文本
|
||||||
|
chunk_size: 每个块的最大 token 数
|
||||||
|
overlap: 块之间的重叠 token 数
|
||||||
|
page: 页码(可选)
|
||||||
|
source: 来源文件路径(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
分块列表,每个块包含文本及起始/结束位置
|
||||||
|
"""
|
||||||
|
encoding = tiktoken.get_encoding("cl100k_base")
|
||||||
|
tokens = encoding.encode(text)
|
||||||
|
chunks: list[TextChunk] = []
|
||||||
|
start = 0
|
||||||
|
|
||||||
|
while start < len(tokens):
|
||||||
|
end = min(start + chunk_size, len(tokens))
|
||||||
|
chunk_tokens = tokens[start:end]
|
||||||
|
chunk_text = encoding.decode(chunk_tokens)
|
||||||
|
chunks.append(TextChunk(
|
||||||
|
text=chunk_text,
|
||||||
|
start_token=start,
|
||||||
|
end_token=end,
|
||||||
|
page=page,
|
||||||
|
source=source,
|
||||||
|
))
|
||||||
|
if end == len(tokens):
|
||||||
|
break
|
||||||
|
start += chunk_size - overlap
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
def get_current_tenant_id() -> str:
|
def get_current_tenant_id() -> str:
|
||||||
"""Dependency to get current tenant ID or raise exception."""
|
"""Dependency to get current tenant ID or raise exception."""
|
||||||
tenant_id = get_tenant_id()
|
tenant_id = get_tenant_id()
|
||||||
|
|
@ -238,12 +293,13 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
|
||||||
from app.services.kb import KBService
|
from app.services.kb import KBService
|
||||||
from app.core.qdrant_client import get_qdrant_client
|
from app.core.qdrant_client import get_qdrant_client
|
||||||
from app.services.embedding import get_embedding_provider
|
from app.services.embedding import get_embedding_provider
|
||||||
from app.services.document import parse_document, UnsupportedFormatError, DocumentParseException
|
from app.services.document import parse_document, UnsupportedFormatError, DocumentParseException, PageText
|
||||||
from qdrant_client.models import PointStruct
|
from qdrant_client.models import PointStruct
|
||||||
import asyncio
|
import asyncio
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
logger.info(f"[INDEX] Starting indexing: tenant={tenant_id}, job_id={job_id}, doc_id={doc_id}, filename={filename}")
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
async with async_session_maker() as session:
|
async with async_session_maker() as session:
|
||||||
|
|
@ -254,14 +310,18 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
|
||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
parse_result = None
|
||||||
text = None
|
text = None
|
||||||
file_ext = Path(filename or "").suffix.lower()
|
file_ext = Path(filename or "").suffix.lower()
|
||||||
|
logger.info(f"[INDEX] File extension: {file_ext}, content size: {len(content)} bytes")
|
||||||
|
|
||||||
text_extensions = {".txt", ".md", ".markdown", ".rst", ".log", ".json", ".xml", ".yaml", ".yml"}
|
text_extensions = {".txt", ".md", ".markdown", ".rst", ".log", ".json", ".xml", ".yaml", ".yml"}
|
||||||
|
|
||||||
if file_ext in text_extensions or not file_ext:
|
if file_ext in text_extensions or not file_ext:
|
||||||
|
logger.info(f"[INDEX] Treating as text file, decoding with UTF-8")
|
||||||
text = content.decode("utf-8", errors="ignore")
|
text = content.decode("utf-8", errors="ignore")
|
||||||
else:
|
else:
|
||||||
|
logger.info(f"[INDEX] Binary file detected, will parse with document parser")
|
||||||
await kb_service.update_job_status(
|
await kb_service.update_job_status(
|
||||||
tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=15
|
tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=15
|
||||||
)
|
)
|
||||||
|
|
@ -271,45 +331,95 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
|
||||||
tmp_file.write(content)
|
tmp_file.write(content)
|
||||||
tmp_path = tmp_file.name
|
tmp_path = tmp_file.name
|
||||||
|
|
||||||
|
logger.info(f"[INDEX] Temp file created: {tmp_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.info(f"[INDEX] Starting document parsing for {file_ext}...")
|
||||||
parse_result = parse_document(tmp_path)
|
parse_result = parse_document(tmp_path)
|
||||||
text = parse_result.text
|
text = parse_result.text
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[AC-AISVC-33] Parsed document: {filename}, "
|
f"[INDEX] Parsed document SUCCESS: {filename}, "
|
||||||
f"chars={len(text)}, format={parse_result.metadata.get('format')}"
|
f"chars={len(text)}, format={parse_result.metadata.get('format')}, "
|
||||||
|
f"pages={len(parse_result.pages) if parse_result.pages else 'N/A'}, "
|
||||||
|
f"metadata={parse_result.metadata}"
|
||||||
)
|
)
|
||||||
except (UnsupportedFormatError, DocumentParseException) as e:
|
if len(text) < 100:
|
||||||
logger.warning(f"Failed to parse document {filename}: {e}, falling back to text decode")
|
logger.warning(f"[INDEX] Parsed text is very short, preview: {text[:200]}")
|
||||||
|
except UnsupportedFormatError as e:
|
||||||
|
logger.error(f"[INDEX] UnsupportedFormatError: {e}")
|
||||||
|
text = content.decode("utf-8", errors="ignore")
|
||||||
|
except DocumentParseException as e:
|
||||||
|
logger.error(f"[INDEX] DocumentParseException: {e}, details={getattr(e, 'details', {})}")
|
||||||
|
text = content.decode("utf-8", errors="ignore")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[INDEX] Unexpected parsing error: {type(e).__name__}: {e}")
|
||||||
text = content.decode("utf-8", errors="ignore")
|
text = content.decode("utf-8", errors="ignore")
|
||||||
finally:
|
finally:
|
||||||
Path(tmp_path).unlink(missing_ok=True)
|
Path(tmp_path).unlink(missing_ok=True)
|
||||||
|
logger.info(f"[INDEX] Temp file cleaned up")
|
||||||
|
|
||||||
|
logger.info(f"[INDEX] Final text length: {len(text)} chars")
|
||||||
|
if len(text) < 50:
|
||||||
|
logger.warning(f"[INDEX] Text too short, preview: {repr(text[:200])}")
|
||||||
|
|
||||||
await kb_service.update_job_status(
|
await kb_service.update_job_status(
|
||||||
tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=20
|
tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=20
|
||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
logger.info(f"[INDEX] Getting embedding provider...")
|
||||||
embedding_provider = await get_embedding_provider()
|
embedding_provider = await get_embedding_provider()
|
||||||
|
logger.info(f"[INDEX] Embedding provider: {type(embedding_provider).__name__}")
|
||||||
|
|
||||||
chunks = [text[i:i+500] for i in range(0, len(text), 500)]
|
all_chunks: list[TextChunk] = []
|
||||||
|
|
||||||
|
if parse_result and parse_result.pages:
|
||||||
|
logger.info(f"[INDEX] PDF with {len(parse_result.pages)} pages, using tiktoken chunking with page metadata")
|
||||||
|
for page in parse_result.pages:
|
||||||
|
page_chunks = chunk_text_with_tiktoken(
|
||||||
|
page.text,
|
||||||
|
chunk_size=512,
|
||||||
|
overlap=100,
|
||||||
|
page=page.page,
|
||||||
|
source=filename,
|
||||||
|
)
|
||||||
|
all_chunks.extend(page_chunks)
|
||||||
|
logger.info(f"[INDEX] Total chunks from PDF: {len(all_chunks)}")
|
||||||
|
else:
|
||||||
|
logger.info(f"[INDEX] Using tiktoken chunking without page metadata")
|
||||||
|
all_chunks = chunk_text_with_tiktoken(
|
||||||
|
text,
|
||||||
|
chunk_size=512,
|
||||||
|
overlap=100,
|
||||||
|
source=filename,
|
||||||
|
)
|
||||||
|
logger.info(f"[INDEX] Total chunks: {len(all_chunks)}")
|
||||||
|
|
||||||
qdrant = await get_qdrant_client()
|
qdrant = await get_qdrant_client()
|
||||||
await qdrant.ensure_collection_exists(tenant_id)
|
await qdrant.ensure_collection_exists(tenant_id)
|
||||||
|
|
||||||
points = []
|
points = []
|
||||||
total_chunks = len(chunks)
|
total_chunks = len(all_chunks)
|
||||||
for i, chunk in enumerate(chunks):
|
for i, chunk in enumerate(all_chunks):
|
||||||
embedding = await embedding_provider.embed(chunk)
|
embedding = await embedding_provider.embed(chunk.text)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"text": chunk.text,
|
||||||
|
"source": doc_id,
|
||||||
|
"chunk_index": i,
|
||||||
|
"start_token": chunk.start_token,
|
||||||
|
"end_token": chunk.end_token,
|
||||||
|
}
|
||||||
|
if chunk.page is not None:
|
||||||
|
payload["page"] = chunk.page
|
||||||
|
if chunk.source:
|
||||||
|
payload["filename"] = chunk.source
|
||||||
|
|
||||||
points.append(
|
points.append(
|
||||||
PointStruct(
|
PointStruct(
|
||||||
id=str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
vector=embedding,
|
vector=embedding,
|
||||||
payload={
|
payload=payload,
|
||||||
"text": chunk,
|
|
||||||
"source": doc_id,
|
|
||||||
"chunk_index": i,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -321,6 +431,7 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
if points:
|
if points:
|
||||||
|
logger.info(f"[INDEX] Upserting {len(points)} vectors to Qdrant...")
|
||||||
await qdrant.upsert_vectors(tenant_id, points)
|
await qdrant.upsert_vectors(tenant_id, points)
|
||||||
|
|
||||||
await kb_service.update_job_status(
|
await kb_service.update_job_status(
|
||||||
|
|
@ -329,12 +440,13 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[AC-ASA-01] Indexing completed: tenant={tenant_id}, "
|
f"[INDEX] COMPLETED: tenant={tenant_id}, "
|
||||||
f"job_id={job_id}, chunks={len(chunks)}"
|
f"job_id={job_id}, chunks={len(all_chunks)}, text_len={len(text)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[AC-ASA-01] Indexing failed: {e}")
|
import traceback
|
||||||
|
logger.error(f"[INDEX] FAILED: {e}\n{traceback.format_exc()}")
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
async with async_session_maker() as error_session:
|
async with async_session_maker() as error_session:
|
||||||
kb_service = KBService(error_session)
|
kb_service = KBService(error_session)
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ Document parsing services package.
|
||||||
from app.services.document.base import (
|
from app.services.document.base import (
|
||||||
DocumentParseException,
|
DocumentParseException,
|
||||||
DocumentParser,
|
DocumentParser,
|
||||||
|
PageText,
|
||||||
ParseResult,
|
ParseResult,
|
||||||
UnsupportedFormatError,
|
UnsupportedFormatError,
|
||||||
)
|
)
|
||||||
|
|
@ -22,6 +23,7 @@ from app.services.document.word_parser import WordParser
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DocumentParseException",
|
"DocumentParseException",
|
||||||
"DocumentParser",
|
"DocumentParser",
|
||||||
|
"PageText",
|
||||||
"ParseResult",
|
"ParseResult",
|
||||||
"UnsupportedFormatError",
|
"UnsupportedFormatError",
|
||||||
"DocumentParserFactory",
|
"DocumentParserFactory",
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,15 @@ from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PageText:
|
||||||
|
"""
|
||||||
|
Text content from a single page.
|
||||||
|
"""
|
||||||
|
page: int
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ParseResult:
|
class ParseResult:
|
||||||
"""
|
"""
|
||||||
|
|
@ -24,6 +33,7 @@ class ParseResult:
|
||||||
file_size: int
|
file_size: int
|
||||||
page_count: int | None = None
|
page_count: int | None = None
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
pages: list[PageText] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class DocumentParser(ABC):
|
class DocumentParser(ABC):
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ from typing import Any
|
||||||
from app.services.document.base import (
|
from app.services.document.base import (
|
||||||
DocumentParseException,
|
DocumentParseException,
|
||||||
DocumentParser,
|
DocumentParser,
|
||||||
|
PageText,
|
||||||
ParseResult,
|
ParseResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -68,13 +69,15 @@ class PDFParser(DocumentParser):
|
||||||
try:
|
try:
|
||||||
doc = fitz.open(path)
|
doc = fitz.open(path)
|
||||||
|
|
||||||
|
pages: list[PageText] = []
|
||||||
text_parts = []
|
text_parts = []
|
||||||
page_count = len(doc)
|
page_count = len(doc)
|
||||||
|
|
||||||
for page_num in range(page_count):
|
for page_num in range(page_count):
|
||||||
page = doc[page_num]
|
page = doc[page_num]
|
||||||
text = page.get_text()
|
text = page.get_text().strip()
|
||||||
if text.strip():
|
if text:
|
||||||
|
pages.append(PageText(page=page_num + 1, text=text))
|
||||||
text_parts.append(f"[Page {page_num + 1}]\n{text}")
|
text_parts.append(f"[Page {page_num + 1}]\n{text}")
|
||||||
|
|
||||||
doc.close()
|
doc.close()
|
||||||
|
|
@ -95,7 +98,8 @@ class PDFParser(DocumentParser):
|
||||||
metadata={
|
metadata={
|
||||||
"format": "pdf",
|
"format": "pdf",
|
||||||
"page_count": page_count,
|
"page_count": page_count,
|
||||||
}
|
},
|
||||||
|
pages=pages,
|
||||||
)
|
)
|
||||||
|
|
||||||
except DocumentParseException:
|
except DocumentParseException:
|
||||||
|
|
@ -156,6 +160,7 @@ class PDFPlumberParser(DocumentParser):
|
||||||
pdfplumber = self._get_pdfplumber()
|
pdfplumber = self._get_pdfplumber()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
pages: list[PageText] = []
|
||||||
text_parts = []
|
text_parts = []
|
||||||
page_count = 0
|
page_count = 0
|
||||||
|
|
||||||
|
|
@ -171,7 +176,9 @@ class PDFPlumberParser(DocumentParser):
|
||||||
table_text = self._format_table(table)
|
table_text = self._format_table(table)
|
||||||
text += f"\n\n{table_text}"
|
text += f"\n\n{table_text}"
|
||||||
|
|
||||||
if text.strip():
|
text = text.strip()
|
||||||
|
if text:
|
||||||
|
pages.append(PageText(page=page_num + 1, text=text))
|
||||||
text_parts.append(f"[Page {page_num + 1}]\n{text}")
|
text_parts.append(f"[Page {page_num + 1}]\n{text}")
|
||||||
|
|
||||||
full_text = "\n\n".join(text_parts)
|
full_text = "\n\n".join(text_parts)
|
||||||
|
|
@ -191,7 +198,8 @@ class PDFPlumberParser(DocumentParser):
|
||||||
"format": "pdf",
|
"format": "pdf",
|
||||||
"parser": "pdfplumber",
|
"parser": "pdfplumber",
|
||||||
"page_count": page_count,
|
"page_count": page_count,
|
||||||
}
|
},
|
||||||
|
pages=pages,
|
||||||
)
|
)
|
||||||
|
|
||||||
except DocumentParseException:
|
except DocumentParseException:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue