From 66fa2d2677704c46825b1b81f63cb98b0b3b3b36 Mon Sep 17 00:00:00 2001 From: MerCry Date: Tue, 24 Feb 2026 13:31:42 +0800 Subject: [PATCH] feat(ai-service): implement confidence calculation for T3.3 [AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] - Add ConfidenceCalculator class for confidence scoring - Implement retrieval insufficiency detection (hit count, score threshold, evidence tokens) - Implement confidence calculation based on retrieval scores - Implement shouldTransfer logic with configurable threshold - Add transferReason for low confidence scenarios - Add unit tests for the confidence calculator (19 test cases) - Update config: rename confidence_threshold_low to confidence_low_threshold; add confidence_high_threshold and confidence_insufficient_penalty - Enforce the SSE event sequence (message* -> final | error) in the chat streaming handler via SSEStateMachine [AC-AISVC-08, AC-AISVC-09] - Add SSE state machine tests (test_sse_state_machine.py) --- ai-service/app/api/chat.py | 55 ++- ai-service/app/core/config.py | 4 +- ai-service/app/services/confidence.py | 224 ++++++++++++ ai-service/tests/test_confidence.py | 302 +++++++++++++++++ ai-service/tests/test_sse_state_machine.py | 376 +++++++++++++++++++++ 5 files changed, 952 insertions(+), 9 deletions(-) create mode 100644 ai-service/app/services/confidence.py create mode 100644 ai-service/tests/test_confidence.py create mode 100644 ai-service/tests/test_sse_state_machine.py diff --git a/ai-service/app/api/chat.py b/ai-service/app/api/chat.py index fd58bd8..4eac671 100644 --- a/ai-service/app/api/chat.py +++ b/ai-service/app/api/chat.py @@ -1,6 +1,6 @@ """ Chat endpoint for AI Service. -[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06] Main chat endpoint with streaming/non-streaming modes. +[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-08, AC-AISVC-09] Main chat endpoint with streaming/non-streaming modes. 
""" import logging @@ -11,7 +11,7 @@ from fastapi.responses import JSONResponse from sse_starlette.sse import EventSourceResponse from app.core.middleware import get_response_mode, is_sse_request -from app.core.sse import create_error_event +from app.core.sse import SSEStateMachine, create_error_event from app.core.tenant import get_tenant_id from app.models import ChatRequest, ChatResponse, ErrorResponse from app.services.orchestrator import OrchestratorService, get_orchestrator_service @@ -109,19 +109,58 @@ async def _handle_streaming_request( ) -> EventSourceResponse: """ [AC-AISVC-06, AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] Handle SSE streaming request. - Yields message events followed by final or error event. + + SSE Event Sequence (per design.md Section 6.2): + - message* (0 or more) -> final (exactly 1) -> close + - OR message* (0 or more) -> error (exactly 1) -> close + + State machine ensures: + - No events after final/error + - Only one final OR one error event + - Proper connection close """ logger.info(f"[AC-AISVC-06] Processing SSE request for tenant={tenant_id}") + state_machine = SSEStateMachine() + async def event_generator(): + """ + [AC-AISVC-08, AC-AISVC-09] Event generator with state machine enforcement. + Ensures proper event sequence and error handling. 
+ """ + await state_machine.transition_to_streaming() + try: async for event in orchestrator.generate_stream(tenant_id, chat_request): - yield event + if not state_machine.can_send_message(): + logger.warning("[AC-AISVC-08] Received event after state machine closed, ignoring") + break + + if event.event == "final": + if await state_machine.transition_to_final(): + logger.info("[AC-AISVC-08] Sending final event and closing stream") + yield event + break + + elif event.event == "error": + if await state_machine.transition_to_error(): + logger.info("[AC-AISVC-09] Sending error event and closing stream") + yield event + break + + elif event.event == "message": + yield event + except Exception as e: logger.error(f"[AC-AISVC-09] Streaming error: {e}") - yield create_error_event( - code="STREAMING_ERROR", - message=str(e), - ) + if await state_machine.transition_to_error(): + yield create_error_event( + code="STREAMING_ERROR", + message=str(e), + ) + + finally: + await state_machine.close() + logger.debug("[AC-AISVC-08] SSE connection closed") return EventSourceResponse(event_generator(), ping=15) diff --git a/ai-service/app/core/config.py b/ai-service/app/core/config.py index 913df53..27d282c 100644 --- a/ai-service/app/core/config.py +++ b/ai-service/app/core/config.py @@ -45,7 +45,9 @@ class Settings(BaseSettings): rag_min_hits: int = 1 rag_max_evidence_tokens: int = 2000 - confidence_threshold_low: float = 0.5 + confidence_low_threshold: float = 0.5 + confidence_high_threshold: float = 0.8 + confidence_insufficient_penalty: float = 0.3 max_history_tokens: int = 4000 diff --git a/ai-service/app/services/confidence.py b/ai-service/app/services/confidence.py new file mode 100644 index 0000000..8cfdc25 --- /dev/null +++ b/ai-service/app/services/confidence.py @@ -0,0 +1,224 @@ +""" +Confidence calculation for AI Service. +[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Confidence scoring and transfer suggestion logic. 
+ +Design reference: design.md Section 4.3 - 检索不中兜底与置信度策略 +- Retrieval insufficiency detection +- Confidence calculation based on retrieval scores +- shouldTransfer logic with threshold T_low +""" + +import logging +from dataclasses import dataclass, field +from typing import Any + +from app.core.config import get_settings +from app.services.retrieval.base import RetrievalResult + +logger = logging.getLogger(__name__) + + +@dataclass +class ConfidenceConfig: + """ + Configuration for confidence calculation. + [AC-AISVC-17, AC-AISVC-18] Configurable thresholds. + """ + score_threshold: float = 0.7 + min_hits: int = 1 + confidence_low_threshold: float = 0.5 + confidence_high_threshold: float = 0.8 + insufficient_penalty: float = 0.3 + max_evidence_tokens: int = 2000 + + +@dataclass +class ConfidenceResult: + """ + Result of confidence calculation. + [AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Contains confidence and transfer suggestion. + """ + confidence: float + should_transfer: bool + transfer_reason: str | None = None + is_retrieval_insufficient: bool = False + diagnostics: dict[str, Any] = field(default_factory=dict) + + +class ConfidenceCalculator: + """ + [AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Calculator for response confidence. 
+ + Design reference: design.md Section 4.3 + - MVP: confidence based on RAG retrieval scores + - Insufficient retrieval triggers confidence downgrade + - shouldTransfer when confidence < T_low + """ + + def __init__(self, config: ConfidenceConfig | None = None): + settings = get_settings() + self._config = config or ConfidenceConfig( + score_threshold=getattr(settings, "rag_score_threshold", 0.7), + min_hits=getattr(settings, "rag_min_hits", 1), + confidence_low_threshold=getattr(settings, "confidence_low_threshold", 0.5), + confidence_high_threshold=getattr(settings, "confidence_high_threshold", 0.8), + insufficient_penalty=getattr(settings, "confidence_insufficient_penalty", 0.3), + max_evidence_tokens=getattr(settings, "rag_max_evidence_tokens", 2000), + ) + + def is_retrieval_insufficient( + self, + retrieval_result: RetrievalResult, + evidence_tokens: int | None = None, + ) -> tuple[bool, str]: + """ + [AC-AISVC-17] Determine if retrieval results are insufficient. + + Conditions for insufficiency: + 1. hits.size < min_hits + 2. max(score) < score_threshold + 3. 
evidence tokens exceed limit (optional) + + Args: + retrieval_result: Result from retrieval operation + evidence_tokens: Optional token count for evidence + + Returns: + Tuple of (is_insufficient, reason) + """ + reasons = [] + + if retrieval_result.hit_count < self._config.min_hits: + reasons.append( + f"hit_count({retrieval_result.hit_count}) < min_hits({self._config.min_hits})" + ) + + if retrieval_result.max_score < self._config.score_threshold: + reasons.append( + f"max_score({retrieval_result.max_score:.3f}) < threshold({self._config.score_threshold})" + ) + + if evidence_tokens is not None and evidence_tokens > self._config.max_evidence_tokens: + reasons.append( + f"evidence_tokens({evidence_tokens}) > max({self._config.max_evidence_tokens})" + ) + + is_insufficient = len(reasons) > 0 + reason = "; ".join(reasons) if reasons else "sufficient" + + return is_insufficient, reason + + def calculate_confidence( + self, + retrieval_result: RetrievalResult, + evidence_tokens: int | None = None, + additional_factors: dict[str, float] | None = None, + ) -> ConfidenceResult: + """ + [AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Calculate confidence and transfer suggestion. + + MVP Strategy: + 1. Base confidence from max retrieval score + 2. Adjust for hit count (more hits = higher confidence) + 3. Penalize if retrieval is insufficient + 4. 
Determine shouldTransfer based on T_low threshold + + Args: + retrieval_result: Result from retrieval operation + evidence_tokens: Optional token count for evidence + additional_factors: Optional additional confidence factors + + Returns: + ConfidenceResult with confidence and transfer suggestion + """ + is_insufficient, insufficiency_reason = self.is_retrieval_insufficient( + retrieval_result, evidence_tokens + ) + + base_confidence = retrieval_result.max_score + + hit_count_factor = min(1.0, retrieval_result.hit_count / 5.0) + confidence = base_confidence * 0.7 + hit_count_factor * 0.3 + + if is_insufficient: + confidence -= self._config.insufficient_penalty + logger.info( + f"[AC-AISVC-17] Retrieval insufficient: {insufficiency_reason}, " + f"applying penalty -{self._config.insufficient_penalty}" + ) + + if additional_factors: + for factor_name, factor_value in additional_factors.items(): + confidence += factor_value * 0.1 + + confidence = max(0.0, min(1.0, confidence)) + + should_transfer = confidence < self._config.confidence_low_threshold + transfer_reason = None + + if should_transfer: + if is_insufficient: + transfer_reason = "检索结果不足,无法提供高置信度回答" + else: + transfer_reason = "置信度低于阈值,建议转人工" + elif confidence < self._config.confidence_high_threshold and is_insufficient: + transfer_reason = "检索结果有限,回答可能不够准确" + + diagnostics = { + "base_confidence": base_confidence, + "hit_count": retrieval_result.hit_count, + "max_score": retrieval_result.max_score, + "is_insufficient": is_insufficient, + "insufficiency_reason": insufficiency_reason if is_insufficient else None, + "penalty_applied": self._config.insufficient_penalty if is_insufficient else 0.0, + "threshold_low": self._config.confidence_low_threshold, + "threshold_high": self._config.confidence_high_threshold, + } + + logger.info( + f"[AC-AISVC-17, AC-AISVC-18] Confidence calculated: " + f"{confidence:.3f}, should_transfer={should_transfer}, " + f"insufficient={is_insufficient}" + ) + + return ConfidenceResult( 
+ confidence=round(confidence, 3), + should_transfer=should_transfer, + transfer_reason=transfer_reason, + is_retrieval_insufficient=is_insufficient, + diagnostics=diagnostics, + ) + + def calculate_confidence_no_retrieval(self) -> ConfidenceResult: + """ + [AC-AISVC-17] Calculate confidence when no retrieval was performed. + + Returns a low confidence result suggesting transfer. + """ + return ConfidenceResult( + confidence=0.3, + should_transfer=True, + transfer_reason="未进行知识库检索,建议转人工", + is_retrieval_insufficient=True, + diagnostics={ + "base_confidence": 0.0, + "hit_count": 0, + "max_score": 0.0, + "is_insufficient": True, + "insufficiency_reason": "no_retrieval", + "penalty_applied": 0.0, + "threshold_low": self._config.confidence_low_threshold, + "threshold_high": self._config.confidence_high_threshold, + }, + ) + + +_confidence_calculator: ConfidenceCalculator | None = None + + +def get_confidence_calculator() -> ConfidenceCalculator: + """Get or create confidence calculator instance.""" + global _confidence_calculator + if _confidence_calculator is None: + _confidence_calculator = ConfidenceCalculator() + return _confidence_calculator diff --git a/ai-service/tests/test_confidence.py b/ai-service/tests/test_confidence.py new file mode 100644 index 0000000..3a12c9b --- /dev/null +++ b/ai-service/tests/test_confidence.py @@ -0,0 +1,302 @@ +""" +Unit tests for Confidence Calculator. +[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Tests for confidence scoring and transfer logic. 
+ +Tests cover: +- Retrieval insufficiency detection +- Confidence calculation based on retrieval scores +- shouldTransfer logic with threshold T_low +- Edge cases (no retrieval, empty results) +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from app.services.retrieval.base import RetrievalHit, RetrievalResult +from app.services.confidence import ( + ConfidenceCalculator, + ConfidenceConfig, + ConfidenceResult, + get_confidence_calculator, +) + + +@pytest.fixture +def mock_settings(): + """Mock settings for testing.""" + settings = MagicMock() + settings.rag_score_threshold = 0.7 + settings.rag_min_hits = 1 + settings.confidence_low_threshold = 0.5 + settings.confidence_high_threshold = 0.8 + settings.confidence_insufficient_penalty = 0.3 + settings.rag_max_evidence_tokens = 2000 + return settings + + +@pytest.fixture +def confidence_calculator(mock_settings): + """Create confidence calculator with mocked settings.""" + with patch("app.services.confidence.get_settings", return_value=mock_settings): + calculator = ConfidenceCalculator() + yield calculator + + +@pytest.fixture +def good_retrieval_result(): + """Sample retrieval result with good hits.""" + return RetrievalResult( + hits=[ + RetrievalHit(text="Result 1", score=0.9, source="kb"), + RetrievalHit(text="Result 2", score=0.85, source="kb"), + RetrievalHit(text="Result 3", score=0.8, source="kb"), + ], + diagnostics={"query_length": 50}, + ) + + +@pytest.fixture +def poor_retrieval_result(): + """Sample retrieval result with poor hits.""" + return RetrievalResult( + hits=[ + RetrievalHit(text="Result 1", score=0.5, source="kb"), + ], + diagnostics={"query_length": 50}, + ) + + +@pytest.fixture +def empty_retrieval_result(): + """Sample empty retrieval result.""" + return RetrievalResult( + hits=[], + diagnostics={"query_length": 50}, + ) + + +class TestRetrievalInsufficiency: + """Tests for retrieval insufficiency detection. 
[AC-AISVC-17]""" + + def test_sufficient_retrieval(self, confidence_calculator, good_retrieval_result): + """[AC-AISVC-17] Test sufficient retrieval detection.""" + is_insufficient, reason = confidence_calculator.is_retrieval_insufficient( + good_retrieval_result + ) + + assert is_insufficient is False + assert reason == "sufficient" + + def test_insufficient_hit_count(self, confidence_calculator): + """[AC-AISVC-17] Test insufficiency due to low hit count.""" + config = ConfidenceConfig(min_hits=3) + calculator = ConfidenceCalculator(config=config) + + result = RetrievalResult( + hits=[ + RetrievalHit(text="Result 1", score=0.9, source="kb"), + ] + ) + + is_insufficient, reason = calculator.is_retrieval_insufficient(result) + + assert is_insufficient is True + assert "hit_count" in reason.lower() + + def test_insufficient_score(self, confidence_calculator, poor_retrieval_result): + """[AC-AISVC-17] Test insufficiency due to low score.""" + is_insufficient, reason = confidence_calculator.is_retrieval_insufficient( + poor_retrieval_result + ) + + assert is_insufficient is True + assert "max_score" in reason.lower() + + def test_insufficient_empty_result(self, confidence_calculator, empty_retrieval_result): + """[AC-AISVC-17] Test insufficiency with empty result.""" + is_insufficient, reason = confidence_calculator.is_retrieval_insufficient( + empty_retrieval_result + ) + + assert is_insufficient is True + + def test_insufficient_evidence_tokens(self, confidence_calculator, good_retrieval_result): + """[AC-AISVC-17] Test insufficiency due to evidence token limit.""" + is_insufficient, reason = confidence_calculator.is_retrieval_insufficient( + good_retrieval_result, evidence_tokens=3000 + ) + + assert is_insufficient is True + assert "evidence_tokens" in reason.lower() + + +class TestConfidenceCalculation: + """Tests for confidence calculation. 
[AC-AISVC-17, AC-AISVC-19]""" + + def test_high_confidence_with_good_retrieval( + self, confidence_calculator, good_retrieval_result + ): + """[AC-AISVC-19] Test high confidence with good retrieval results.""" + result = confidence_calculator.calculate_confidence(good_retrieval_result) + + assert isinstance(result, ConfidenceResult) + assert result.confidence >= 0.5 + assert result.should_transfer is False + assert result.is_retrieval_insufficient is False + + def test_low_confidence_with_poor_retrieval( + self, confidence_calculator, poor_retrieval_result + ): + """[AC-AISVC-17] Test low confidence with poor retrieval results.""" + result = confidence_calculator.calculate_confidence(poor_retrieval_result) + + assert isinstance(result, ConfidenceResult) + assert result.confidence < 0.7 + assert result.is_retrieval_insufficient is True + + def test_confidence_with_empty_result( + self, confidence_calculator, empty_retrieval_result + ): + """[AC-AISVC-17] Test confidence with empty retrieval result.""" + result = confidence_calculator.calculate_confidence(empty_retrieval_result) + + assert result.confidence < 0.5 + assert result.should_transfer is True + assert result.is_retrieval_insufficient is True + + def test_confidence_includes_diagnostics( + self, confidence_calculator, good_retrieval_result + ): + """[AC-AISVC-17] Test that confidence result includes diagnostics.""" + result = confidence_calculator.calculate_confidence(good_retrieval_result) + + assert "base_confidence" in result.diagnostics + assert "hit_count" in result.diagnostics + assert "max_score" in result.diagnostics + assert "threshold_low" in result.diagnostics + + def test_confidence_with_additional_factors( + self, confidence_calculator, good_retrieval_result + ): + """[AC-AISVC-17] Test confidence with additional factors.""" + additional = {"model_certainty": 0.5} + result = confidence_calculator.calculate_confidence( + good_retrieval_result, additional_factors=additional + ) + + assert 
result.confidence > 0 + + def test_confidence_bounded_to_range(self, confidence_calculator): + """[AC-AISVC-17] Test that confidence is bounded to [0, 1].""" + result_with_high_score = RetrievalResult( + hits=[RetrievalHit(text="Result", score=1.0, source="kb")] + ) + + result = confidence_calculator.calculate_confidence(result_with_high_score) + + assert 0.0 <= result.confidence <= 1.0 + + +class TestShouldTransfer: + """Tests for shouldTransfer logic. [AC-AISVC-18]""" + + def test_no_transfer_with_high_confidence( + self, confidence_calculator, good_retrieval_result + ): + """[AC-AISVC-18] Test no transfer when confidence is high.""" + result = confidence_calculator.calculate_confidence(good_retrieval_result) + + assert result.should_transfer is False + assert result.transfer_reason is None + + def test_transfer_with_low_confidence( + self, confidence_calculator, empty_retrieval_result + ): + """[AC-AISVC-18] Test transfer when confidence is low.""" + result = confidence_calculator.calculate_confidence(empty_retrieval_result) + + assert result.should_transfer is True + assert result.transfer_reason is not None + + def test_transfer_reason_for_insufficient_retrieval( + self, confidence_calculator, poor_retrieval_result + ): + """[AC-AISVC-18] Test transfer reason for insufficient retrieval.""" + result = confidence_calculator.calculate_confidence(poor_retrieval_result) + + assert result.is_retrieval_insufficient is True + if result.should_transfer: + assert "检索" in result.transfer_reason or "置信度" in result.transfer_reason + + def test_custom_threshold(self): + """[AC-AISVC-18] Test custom low threshold for transfer.""" + config = ConfidenceConfig( + confidence_low_threshold=0.7, + score_threshold=0.7, + min_hits=1, + ) + calculator = ConfidenceCalculator(config=config) + + result = RetrievalResult( + hits=[RetrievalHit(text="Result", score=0.6, source="kb")] + ) + + conf_result = calculator.calculate_confidence(result) + + assert conf_result.should_transfer is 
True + + +class TestNoRetrieval: + """Tests for no retrieval scenario. [AC-AISVC-17]""" + + def test_no_retrieval_confidence(self, confidence_calculator): + """[AC-AISVC-17] Test confidence when no retrieval was performed.""" + result = confidence_calculator.calculate_confidence_no_retrieval() + + assert result.confidence == 0.3 + assert result.should_transfer is True + assert result.transfer_reason is not None + assert result.is_retrieval_insufficient is True + + +class TestConfidenceConfig: + """Tests for confidence configuration.""" + + def test_default_config(self, mock_settings): + """Test default configuration values.""" + with patch("app.services.confidence.get_settings", return_value=mock_settings): + calculator = ConfidenceCalculator() + + assert calculator._config.score_threshold == 0.7 + assert calculator._config.min_hits == 1 + assert calculator._config.confidence_low_threshold == 0.5 + + def test_custom_config(self): + """Test custom configuration values.""" + config = ConfidenceConfig( + score_threshold=0.8, + min_hits=2, + confidence_low_threshold=0.6, + ) + calculator = ConfidenceCalculator(config=config) + + assert calculator._config.score_threshold == 0.8 + assert calculator._config.min_hits == 2 + assert calculator._config.confidence_low_threshold == 0.6 + + +class TestConfidenceCalculatorSingleton: + """Tests for singleton pattern.""" + + def test_get_confidence_calculator_singleton(self, mock_settings): + """Test that get_confidence_calculator returns singleton.""" + with patch("app.services.confidence.get_settings", return_value=mock_settings): + from app.services.confidence import _confidence_calculator + import app.services.confidence as confidence_module + confidence_module._confidence_calculator = None + + calculator1 = get_confidence_calculator() + calculator2 = get_confidence_calculator() + + assert calculator1 is calculator2 diff --git a/ai-service/tests/test_sse_state_machine.py b/ai-service/tests/test_sse_state_machine.py new file 
mode 100644 index 0000000..6e23583 --- /dev/null +++ b/ai-service/tests/test_sse_state_machine.py @@ -0,0 +1,376 @@ +""" +Tests for SSE state machine and error handling. +[AC-AISVC-08, AC-AISVC-09] Tests for proper event sequence and error handling. +""" + +import json +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from fastapi.testclient import TestClient +from sse_starlette.sse import ServerSentEvent + +from app.core.sse import ( + SSEState, + SSEStateMachine, + create_error_event, + create_final_event, + create_message_event, +) +from app.main import app +from app.models import ChatRequest, ChannelType + + +class TestSSEStateMachineTransitions: + """ + [AC-AISVC-08, AC-AISVC-09] Test cases for SSE state machine transitions. + """ + + @pytest.mark.asyncio + async def test_init_to_streaming_transition(self): + """ + [AC-AISVC-08] Test INIT -> STREAMING transition. + """ + state_machine = SSEStateMachine() + assert state_machine.state == SSEState.INIT + + success = await state_machine.transition_to_streaming() + assert success is True + assert state_machine.state == SSEState.STREAMING + + @pytest.mark.asyncio + async def test_streaming_to_final_transition(self): + """ + [AC-AISVC-08] Test STREAMING -> FINAL_SENT transition. + """ + state_machine = SSEStateMachine() + await state_machine.transition_to_streaming() + + success = await state_machine.transition_to_final() + assert success is True + assert state_machine.state == SSEState.FINAL_SENT + + @pytest.mark.asyncio + async def test_streaming_to_error_transition(self): + """ + [AC-AISVC-09] Test STREAMING -> ERROR_SENT transition. 
+ """ + state_machine = SSEStateMachine() + await state_machine.transition_to_streaming() + + success = await state_machine.transition_to_error() + assert success is True + assert state_machine.state == SSEState.ERROR_SENT + + @pytest.mark.asyncio + async def test_init_to_error_transition(self): + """ + [AC-AISVC-09] Test INIT -> ERROR_SENT transition (error before streaming starts). + """ + state_machine = SSEStateMachine() + + success = await state_machine.transition_to_error() + assert success is True + assert state_machine.state == SSEState.ERROR_SENT + + @pytest.mark.asyncio + async def test_cannot_transition_from_final(self): + """ + [AC-AISVC-08] Test that no transitions are possible after FINAL_SENT. + """ + state_machine = SSEStateMachine() + await state_machine.transition_to_streaming() + await state_machine.transition_to_final() + + assert await state_machine.transition_to_streaming() is False + assert await state_machine.transition_to_error() is False + assert state_machine.state == SSEState.FINAL_SENT + + @pytest.mark.asyncio + async def test_cannot_transition_from_error(self): + """ + [AC-AISVC-09] Test that no transitions are possible after ERROR_SENT. + """ + state_machine = SSEStateMachine() + await state_machine.transition_to_streaming() + await state_machine.transition_to_error() + + assert await state_machine.transition_to_streaming() is False + assert await state_machine.transition_to_final() is False + assert state_machine.state == SSEState.ERROR_SENT + + @pytest.mark.asyncio + async def test_cannot_send_message_after_final(self): + """ + [AC-AISVC-08] Test that can_send_message returns False after FINAL_SENT. 
+ """ + state_machine = SSEStateMachine() + await state_machine.transition_to_streaming() + await state_machine.transition_to_final() + + assert state_machine.can_send_message() is False + + @pytest.mark.asyncio + async def test_cannot_send_message_after_error(self): + """ + [AC-AISVC-09] Test that can_send_message returns False after ERROR_SENT. + """ + state_machine = SSEStateMachine() + await state_machine.transition_to_streaming() + await state_machine.transition_to_error() + + assert state_machine.can_send_message() is False + + @pytest.mark.asyncio + async def test_close_transition(self): + """ + [AC-AISVC-08] Test that close() transitions to CLOSED state. + """ + state_machine = SSEStateMachine() + await state_machine.transition_to_streaming() + await state_machine.transition_to_final() + + await state_machine.close() + assert state_machine.state == SSEState.CLOSED + + +class TestSSEEventSequence: + """ + [AC-AISVC-08, AC-AISVC-09] Test cases for SSE event sequence enforcement. + """ + + @pytest.fixture + def client(self): + return TestClient(app) + + @pytest.fixture + def valid_headers(self): + return {"X-Tenant-Id": "tenant_001", "Accept": "text/event-stream"} + + @pytest.fixture + def valid_body(self): + return { + "sessionId": "test_session", + "currentMessage": "Hello", + "channelType": "wechat", + } + + def test_sse_sequence_message_then_final(self, client, valid_headers, valid_body): + """ + [AC-AISVC-08] Test that SSE events follow: message* -> final -> close. 
+ """ + response = client.post("/ai/chat", json=valid_body, headers=valid_headers) + + assert response.status_code == 200 + content = response.text + + assert "event:message" in content or "event: message" in content + assert "event:final" in content or "event: final" in content + + message_idx = content.find("event:message") + if message_idx == -1: + message_idx = content.find("event: message") + final_idx = content.find("event:final") + if final_idx == -1: + final_idx = content.find("event: final") + + assert final_idx > message_idx, "final should come after message events" + + def test_sse_only_one_final_event(self, client, valid_headers, valid_body): + """ + [AC-AISVC-08] Test that there is exactly one final event. + """ + response = client.post("/ai/chat", json=valid_body, headers=valid_headers) + + content = response.text + final_count = content.count("event:final") + content.count("event: final") + + assert final_count == 1, f"Expected exactly 1 final event, got {final_count}" + + def test_sse_no_events_after_final(self, client, valid_headers, valid_body): + """ + [AC-AISVC-08] Test that no message events appear after final event. + """ + response = client.post("/ai/chat", json=valid_body, headers=valid_headers) + + content = response.text + lines = content.split("\n") + + final_found = False + for line in lines: + if "event:final" in line or "event: final" in line: + final_found = True + elif final_found and ("event:message" in line or "event: message" in line): + pytest.fail("Found message event after final event") + + +class TestSSEErrorHandling: + """ + [AC-AISVC-09] Test cases for SSE error handling. + """ + + @pytest.mark.asyncio + async def test_error_event_format(self): + """ + [AC-AISVC-09] Test error event format. 
+ """ + event = create_error_event( + code="TEST_ERROR", + message="Test error message", + details=[{"field": "test"}], + ) + + assert event.event == "error" + data = json.loads(event.data) + assert data["code"] == "TEST_ERROR" + assert data["message"] == "Test error message" + assert data["details"] == [{"field": "test"}] + + @pytest.mark.asyncio + async def test_error_event_without_details(self): + """ + [AC-AISVC-09] Test error event without details. + """ + event = create_error_event( + code="SIMPLE_ERROR", + message="Simple error", + ) + + assert event.event == "error" + data = json.loads(event.data) + assert data["code"] == "SIMPLE_ERROR" + assert data["message"] == "Simple error" + assert "details" not in data + + def test_missing_tenant_id_returns_400(self): + """ + [AC-AISVC-12] Test that missing X-Tenant-Id returns 400 error. + """ + client = TestClient(app) + headers = {"Accept": "text/event-stream"} + body = { + "sessionId": "test_session", + "currentMessage": "Hello", + "channelType": "wechat", + } + + response = client.post("/ai/chat", json=body, headers=headers) + + assert response.status_code == 400 + data = response.json() + assert data["code"] == "MISSING_TENANT_ID" + + +class TestSSEStateConcurrency: + """ + [AC-AISVC-08, AC-AISVC-09] Test cases for state machine thread safety. + """ + + @pytest.mark.asyncio + async def test_concurrent_transitions(self): + """ + [AC-AISVC-08] Test that concurrent transitions are handled correctly. 
+ """ + import asyncio + + state_machine = SSEStateMachine() + results = [] + + async def try_transition(): + success = await state_machine.transition_to_streaming() + results.append(success) + + await asyncio.gather( + try_transition(), + try_transition(), + try_transition(), + ) + + assert sum(results) == 1, "Only one transition should succeed" + assert state_machine.state == SSEState.STREAMING + + @pytest.mark.asyncio + async def test_concurrent_final_transitions(self): + """ + [AC-AISVC-08] Test that only one final transition succeeds. + """ + import asyncio + + state_machine = SSEStateMachine() + await state_machine.transition_to_streaming() + results = [] + + async def try_final(): + success = await state_machine.transition_to_final() + results.append(success) + + await asyncio.gather( + try_final(), + try_final(), + ) + + assert sum(results) == 1, "Only one final transition should succeed" + assert state_machine.state == SSEState.FINAL_SENT + + +class TestSSEIntegrationWithOrchestrator: + """ + [AC-AISVC-08, AC-AISVC-09] Integration tests for SSE with Orchestrator. + """ + + @pytest.mark.asyncio + async def test_orchestrator_stream_with_error(self): + """ + [AC-AISVC-09] Test that orchestrator errors are properly handled. 
+ """ + from app.services.orchestrator import OrchestratorService + + mock_llm = MagicMock() + + async def failing_stream(*args, **kwargs): + yield MagicMock(delta="Hello", finish_reason=None) + raise Exception("LLM connection lost") + + mock_llm.stream_generate = failing_stream + + orchestrator = OrchestratorService(llm_client=mock_llm) + request = ChatRequest( + session_id="test", + current_message="Hi", + channel_type=ChannelType.WECHAT, + ) + + events = [] + async for event in orchestrator.generate_stream("tenant", request): + events.append(event) + + event_types = [e.event for e in events] + assert "message" in event_types + assert "error" in event_types + + @pytest.mark.asyncio + async def test_orchestrator_stream_normal_flow(self): + """ + [AC-AISVC-08] Test normal streaming flow ends with final event. + """ + from app.services.orchestrator import OrchestratorService + + orchestrator = OrchestratorService() + request = ChatRequest( + session_id="test", + current_message="Hi", + channel_type=ChannelType.WECHAT, + ) + + events = [] + async for event in orchestrator.generate_stream("tenant", request): + events.append(event) + + event_types = [e.event for e in events] + assert "message" in event_types + assert "final" in event_types + + final_index = event_types.index("final") + for i, t in enumerate(event_types): + if t == "message": + assert i < final_index, "message events should come before final"