"""
Unit tests for Retrieval Strategy Integration.
[AC-AISVC-RES-01~15] Tests for integrated strategy and mode routing.
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from app.services.retrieval.routing_config import (
    RagRuntimeMode,
    StrategyType,
    RoutingConfig,
    StrategyContext,
)
from app.services.retrieval.strategy_integration import (
    RetrievalStrategyResult,
    RetrievalStrategyIntegration,
    get_retrieval_strategy_integration,
    reset_retrieval_strategy_integration,
)


@pytest.fixture(autouse=True)
def _reset_integration_singleton():
    """Reset the shared integration instance around every test.

    ``get_retrieval_strategy_integration`` hands out a shared instance (see
    ``TestSingletonInstances``). Resetting both before and after each test
    keeps one test from observing an instance created by another, and keeps
    this module from leaking that shared state into modules that run later.
    """
    reset_retrieval_strategy_integration()
    yield
    reset_retrieval_strategy_integration()


def _make_route_result(
    mode,
    *,
    should_fallback_to_react=False,
    fallback_reason=None,
    diagnostics=None,
):
    """Build a mode-router route-result mock with explicit fallback fields.

    A bare ``MagicMock`` auto-creates any attribute on access, and those
    auto-created attributes are truthy. Pinning ``should_fallback_to_react``,
    ``fallback_reason`` and ``diagnostics`` makes a "no fallback" route result
    actually read as "no fallback" to the code under test.

    Args:
        mode: The ``RagRuntimeMode`` the router claims to have selected.
        should_fallback_to_react: Whether the route result requests a
            direct -> react fallback.
        fallback_reason: Reason string attached to a fallback, if any.
        diagnostics: Mode diagnostics dict; defaults to an empty dict.

    Returns:
        A configured ``MagicMock`` standing in for a route result.
    """
    route_result = MagicMock()
    route_result.mode = mode
    route_result.should_fallback_to_react = should_fallback_to_react
    route_result.fallback_reason = fallback_reason
    route_result.diagnostics = {} if diagnostics is None else diagnostics
    return route_result


def _make_retrieval_result():
    """Build a retrieval-result mock with an empty ``hits`` list."""
    retrieval_result = MagicMock()
    retrieval_result.hits = []
    return retrieval_result


class TestRetrievalStrategyResult:
    """Tests for RetrievalStrategyResult."""

    def test_result_creation(self):
        """Should create result with all fields."""
        result = RetrievalStrategyResult(
            retrieval_result=None,
            final_answer="Test answer",
            strategy=StrategyType.ENHANCED,
            mode=RagRuntimeMode.REACT,
            should_fallback=True,
            fallback_reason="Low confidence",
            diagnostics={"key": "value"},
            duration_ms=100,
        )

        assert result.retrieval_result is None
        assert result.final_answer == "Test answer"
        assert result.strategy == StrategyType.ENHANCED
        assert result.mode == RagRuntimeMode.REACT
        assert result.should_fallback is True
        assert result.fallback_reason == "Low confidence"
        assert result.diagnostics == {"key": "value"}
        assert result.duration_ms == 100

    def test_result_defaults(self):
        """Should create result with default values."""
        result = RetrievalStrategyResult(
            retrieval_result=None,
            final_answer=None,
            strategy=StrategyType.DEFAULT,
            mode=RagRuntimeMode.DIRECT,
        )

        assert result.should_fallback is False
        assert result.fallback_reason is None
        assert result.mode_route_result is None
        assert result.diagnostics == {}
        assert result.duration_ms == 0


class TestRetrievalStrategyIntegration:
    """[AC-AISVC-RES-01~15] Tests for RetrievalStrategyIntegration."""

    @pytest.fixture
    def integration(self):
        # Shared-instance state is already reset by the module-level autouse
        # fixture; these tests use a dedicated, freshly constructed instance.
        return RetrievalStrategyIntegration()

    def test_initial_state(self, integration):
        """Should initialize with default configuration."""
        assert integration.config.strategy == StrategyType.DEFAULT
        assert integration.config.rag_runtime_mode == RagRuntimeMode.AUTO

    def test_update_config(self, integration):
        """[AC-AISVC-RES-15] Should update all configurations."""
        new_config = RoutingConfig(
            strategy=StrategyType.ENHANCED,
            rag_runtime_mode=RagRuntimeMode.REACT,
            react_max_steps=7,
        )

        integration.update_config(new_config)

        assert integration.config.strategy == StrategyType.ENHANCED
        assert integration.config.rag_runtime_mode == RagRuntimeMode.REACT

    def test_get_current_strategy(self, integration):
        """Should return current strategy from router."""
        strategy = integration.get_current_strategy()

        assert strategy == StrategyType.DEFAULT

    def test_get_rollback_records(self, integration):
        """Should return rollback records from router."""
        records = integration.get_rollback_records()

        assert isinstance(records, list)

    def test_validate_config(self, integration):
        """[AC-AISVC-RES-06] Should validate configuration."""
        is_valid, errors = integration.validate_config()

        assert is_valid is True
        assert len(errors) == 0

    @pytest.mark.asyncio
    async def test_execute_direct_mode(self, integration):
        """[AC-AISVC-RES-09] Should execute direct mode."""
        integration._config.rag_runtime_mode = RagRuntimeMode.DIRECT
        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
        retrieval_result = _make_retrieval_result()
        # Explicitly a non-fallback route result (see _make_route_result).
        route_result = _make_route_result(RagRuntimeMode.DIRECT)
        router = integration._mode_router

        with patch.object(
            router, "route", return_value=route_result
        ), patch.object(
            router, "execute_with_fallback", new_callable=AsyncMock
        ) as mock_execute:
            mock_execute.return_value = (retrieval_result, None, route_result)
            result = await integration.execute(ctx)

        assert result.retrieval_result == retrieval_result
        assert result.final_answer is None
        assert result.mode == RagRuntimeMode.DIRECT

    @pytest.mark.asyncio
    async def test_execute_react_mode(self, integration):
        """[AC-AISVC-RES-10] Should execute react mode."""
        integration._config.rag_runtime_mode = RagRuntimeMode.REACT
        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
        route_result = _make_route_result(RagRuntimeMode.REACT)
        router = integration._mode_router

        with patch.object(
            router, "route", return_value=route_result
        ), patch.object(
            router, "execute_react", new_callable=AsyncMock
        ) as mock_execute:
            mock_execute.return_value = ("Final answer", None, {})
            result = await integration.execute(ctx)

        assert result.retrieval_result is None
        assert result.final_answer == "Final answer"
        assert result.mode == RagRuntimeMode.REACT

    @pytest.mark.asyncio
    async def test_execute_with_fallback(self, integration):
        """[AC-AISVC-RES-14] Should handle fallback from direct to react."""
        integration._config.rag_runtime_mode = RagRuntimeMode.DIRECT
        integration._config.direct_fallback_on_low_confidence = True
        integration._config.direct_fallback_confidence_threshold = 0.4
        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
        route_result = _make_route_result(
            RagRuntimeMode.DIRECT,
            should_fallback_to_react=True,
            fallback_reason="low_confidence",
        )
        router = integration._mode_router

        with patch.object(
            router, "route", return_value=route_result
        ), patch.object(
            router, "execute_with_fallback", new_callable=AsyncMock
        ) as mock_execute:
            mock_execute.return_value = (None, "Fallback answer", route_result)
            result = await integration.execute(ctx)

        assert result.retrieval_result is None
        assert result.final_answer == "Fallback answer"
        assert result.should_fallback is True

    @pytest.mark.asyncio
    async def test_execute_includes_diagnostics(self, integration):
        """Should include diagnostics in result."""
        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
        retrieval_result = _make_retrieval_result()
        route_result = _make_route_result(
            RagRuntimeMode.DIRECT,
            diagnostics={"mode_key": "mode_value"},
        )
        router = integration._mode_router

        with patch.object(
            router, "route", return_value=route_result
        ), patch.object(
            router, "execute_with_fallback", new_callable=AsyncMock
        ) as mock_execute:
            mock_execute.return_value = (retrieval_result, None, route_result)
            result = await integration.execute(ctx)

        assert "strategy_diagnostics" in result.diagnostics
        assert "mode_diagnostics" in result.diagnostics
        assert "duration_ms" in result.diagnostics

    @pytest.mark.asyncio
    async def test_execute_tracks_duration(self, integration):
        """Should track execution duration."""
        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
        retrieval_result = _make_retrieval_result()
        route_result = _make_route_result(RagRuntimeMode.DIRECT)
        router = integration._mode_router

        with patch.object(
            router, "route", return_value=route_result
        ), patch.object(
            router, "execute_with_fallback", new_callable=AsyncMock
        ) as mock_execute:
            mock_execute.return_value = (retrieval_result, None, route_result)
            result = await integration.execute(ctx)

        assert result.duration_ms >= 0


class TestSingletonInstances:
    """Tests for singleton instance getters."""

    def test_get_retrieval_strategy_integration_singleton(self):
        """Should return same integration instance."""
        reset_retrieval_strategy_integration()
        integration1 = get_retrieval_strategy_integration()
        integration2 = get_retrieval_strategy_integration()

        assert integration1 is integration2

    def test_reset_retrieval_strategy_integration(self):
        """Should create new instance after reset."""
        integration1 = get_retrieval_strategy_integration()
        reset_retrieval_strategy_integration()
        integration2 = get_retrieval_strategy_integration()

        assert integration1 is not integration2