"""
Unit tests for Retrieval Strategy Service.

[AC-AISVC-RES-01~15] Tests for strategy management, switching, validation, and rollback.
"""

from datetime import datetime
from unittest.mock import MagicMock, patch

import pytest

from app.schemas.retrieval_strategy import (
    ReactMode,
    RolloutConfig,
    RolloutMode,
    StrategyType,
    RetrievalStrategyStatus,
    RetrievalStrategySwitchRequest,
    RetrievalStrategyValidationRequest,
    ValidationResult,
)
from app.services.retrieval.strategy_service import (
    RetrievalStrategyService,
    StrategyState,
    get_strategy_service,
)
from app.services.retrieval.strategy_audit import (
    StrategyAuditService,
    get_audit_service,
)
from app.services.retrieval.strategy_metrics import (
    StrategyMetricsService,
    get_metrics_service,
)


class TestRetrievalStrategySchemas:
    """[AC-AISVC-RES-01~15] Tests for strategy schema models."""

    def test_rollout_config_off_mode(self):
        """[AC-AISVC-RES-03] Off mode should not require percentage or allowlist."""
        config = RolloutConfig(mode=RolloutMode.OFF)

        assert config.mode == RolloutMode.OFF
        assert config.percentage is None
        assert config.allowlist is None

    def test_rollout_config_percentage_mode(self):
        """[AC-AISVC-RES-03] Percentage mode should require percentage."""
        config = RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=50.0)

        assert config.mode == RolloutMode.PERCENTAGE
        assert config.percentage == 50.0

    def test_rollout_config_percentage_mode_missing_value(self):
        """[AC-AISVC-RES-03] Percentage mode without percentage should raise error."""
        with pytest.raises(ValueError, match="percentage is required"):
            RolloutConfig(mode=RolloutMode.PERCENTAGE)

    def test_rollout_config_allowlist_mode(self):
        """[AC-AISVC-RES-03] Allowlist mode should require allowlist."""
        config = RolloutConfig(mode=RolloutMode.ALLOWLIST, allowlist=["tenant1", "tenant2"])

        assert config.mode == RolloutMode.ALLOWLIST
        assert config.allowlist == ["tenant1", "tenant2"]

    def test_rollout_config_allowlist_mode_missing_value(self):
        """[AC-AISVC-RES-03] Allowlist mode without allowlist should raise error."""
        with pytest.raises(ValueError, match="allowlist is required"):
            RolloutConfig(mode=RolloutMode.ALLOWLIST)

    def test_retrieval_strategy_status(self):
        """[AC-AISVC-RES-01] Status should contain all required fields."""
        rollout = RolloutConfig(mode=RolloutMode.OFF)
        status = RetrievalStrategyStatus(
            active_strategy=StrategyType.DEFAULT,
            react_mode=ReactMode.NON_REACT,
            rollout=rollout,
        )

        assert status.active_strategy == StrategyType.DEFAULT
        assert status.react_mode == ReactMode.NON_REACT
        assert status.rollout.mode == RolloutMode.OFF

    def test_switch_request_minimal(self):
        """[AC-AISVC-RES-02] Switch request should work with minimal fields."""
        request = RetrievalStrategySwitchRequest(target_strategy=StrategyType.ENHANCED)

        assert request.target_strategy == StrategyType.ENHANCED
        assert request.react_mode is None
        assert request.rollout is None
        assert request.reason is None

    def test_switch_request_full(self):
        """[AC-AISVC-RES-02,03,05] Switch request should accept all fields."""
        rollout = RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=30.0)
        request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
            rollout=rollout,
            reason="Testing enhanced strategy",
        )

        assert request.target_strategy == StrategyType.ENHANCED
        assert request.react_mode == ReactMode.REACT
        assert request.rollout.percentage == 30.0
        assert request.reason == "Testing enhanced strategy"


class TestRetrievalStrategyService:
    """[AC-AISVC-RES-01~15] Tests for strategy service."""

    @pytest.fixture
    def service(self):
        """Create a fresh service instance for each test."""
        return RetrievalStrategyService()

    def test_get_current_status_default(self, service):
        """[AC-AISVC-RES-01] Default status should be default strategy and non_react mode."""
        status = service.get_current_status()

        assert status.active_strategy == StrategyType.DEFAULT
        assert status.react_mode == ReactMode.NON_REACT
        assert status.rollout.mode == RolloutMode.OFF

    def test_switch_strategy_to_enhanced(self, service):
        """[AC-AISVC-RES-02] Should switch to enhanced strategy."""
        request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
        )

        response = service.switch_strategy(request)

        assert response.previous.active_strategy == StrategyType.DEFAULT
        assert response.current.active_strategy == StrategyType.ENHANCED
        assert response.current.react_mode == ReactMode.REACT

    def test_switch_strategy_with_grayscale_percentage(self, service):
        """[AC-AISVC-RES-03] Should switch with grayscale percentage."""
        rollout = RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=50.0)
        request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            rollout=rollout,
        )

        response = service.switch_strategy(request)

        assert response.current.active_strategy == StrategyType.ENHANCED
        assert response.current.rollout.mode == RolloutMode.PERCENTAGE
        assert response.current.rollout.percentage == 50.0

    def test_switch_strategy_with_allowlist(self, service):
        """[AC-AISVC-RES-03] Should switch with allowlist grayscale."""
        rollout = RolloutConfig(
            mode=RolloutMode.ALLOWLIST,
            allowlist=["tenant_a", "tenant_b"],
        )
        request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            rollout=rollout,
        )

        response = service.switch_strategy(request)

        assert response.current.rollout.mode == RolloutMode.ALLOWLIST
        assert "tenant_a" in response.current.rollout.allowlist

    def test_rollback_strategy(self, service):
        """[AC-AISVC-RES-07] Should rollback to previous strategy."""
        request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
        )
        service.switch_strategy(request)

        response = service.rollback_strategy()

        assert response.rollback_to.active_strategy == StrategyType.DEFAULT
        assert response.rollback_to.react_mode == ReactMode.NON_REACT

    def test_rollback_without_previous_returns_default(self, service):
        """[AC-AISVC-RES-07] Rollback without previous should return default."""
        response = service.rollback_strategy()

        assert response.rollback_to.active_strategy == StrategyType.DEFAULT

    def test_should_use_enhanced_strategy_default(self, service):
        """[AC-AISVC-RES-01] Default strategy should not use enhanced."""
        assert service.should_use_enhanced_strategy("tenant_a") is False

    def test_should_use_enhanced_strategy_with_allowlist(self, service):
        """[AC-AISVC-RES-03] Should use enhanced for tenants in allowlist."""
        rollout = RolloutConfig(
            mode=RolloutMode.ALLOWLIST,
            allowlist=["tenant_a"],
        )
        request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            rollout=rollout,
        )
        service.switch_strategy(request)

        assert service.should_use_enhanced_strategy("tenant_a") is True
        assert service.should_use_enhanced_strategy("tenant_b") is False

    def test_get_route_mode_react(self, service):
        """[AC-AISVC-RES-10] React mode should return react route."""
        request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
        )
        service.switch_strategy(request)

        route = service.get_route_mode("test query")

        assert route == "react"

    def test_get_route_mode_direct(self, service):
        """[AC-AISVC-RES-09] Non-react mode should return direct route."""
        request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.DEFAULT,
            react_mode=ReactMode.NON_REACT,
        )
        service.switch_strategy(request)

        route = service.get_route_mode("test query")

        assert route == "direct"

    def test_get_route_mode_auto_short_query(self, service):
        """[AC-AISVC-RES-12] Short query with high confidence should use direct route."""
        # Clear the explicit react mode on the service's private state before
        # exercising the private auto-router directly.
        service._state.react_mode = None

        # "短问题" is Chinese for "short question".
        route = service._auto_route("短问题", confidence=0.8)

        assert route == "direct"

    def test_get_route_mode_auto_multiple_conditions(self, service):
        """[AC-AISVC-RES-13] Query with multiple conditions should use react route."""
        # The query reads "query order status and logistics information".
        # NOTE(review): presumably the conjunction "和" ("and") is what marks it
        # as multi-condition -- confirm against _auto_route.
        route = service._auto_route("查询订单状态和物流信息")

        assert route == "react"

    def test_get_route_mode_auto_low_confidence(self, service):
        """[AC-AISVC-RES-13] Low confidence should use react route."""
        route = service._auto_route("test query", confidence=0.3)

        assert route == "react"

    def test_get_switch_history(self, service):
        """Should track switch history."""
        request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            reason="Testing",
        )
        service.switch_strategy(request)

        history = service.get_switch_history()

        assert len(history) == 1
        assert history[0]["to_strategy"] == "enhanced"


class TestRetrievalStrategyValidation:
    """[AC-AISVC-RES-04,06,08] Tests for strategy validation."""

    @pytest.fixture
    def service(self):
        """Create a fresh service instance for each test."""
        return RetrievalStrategyService()

    def test_validate_default_strategy(self, service):
        """[AC-AISVC-RES-06] Default strategy should pass validation."""
        request = RetrievalStrategyValidationRequest(
            strategy=StrategyType.DEFAULT,
        )

        response = service.validate_strategy(request)

        assert response.passed is True

    def test_validate_enhanced_strategy(self, service):
        """[AC-AISVC-RES-06] Enhanced strategy validation."""
        request = RetrievalStrategyValidationRequest(
            strategy=StrategyType.ENHANCED,
        )

        response = service.validate_strategy(request)

        # Only the response shape is pinned here, not the pass/fail outcome.
        assert isinstance(response.passed, bool)
        assert len(response.results) > 0

    def test_validate_specific_checks(self, service):
        """[AC-AISVC-RES-06] Should run specific validation checks."""
        request = RetrievalStrategyValidationRequest(
            strategy=StrategyType.ENHANCED,
            checks=["metadata_consistency", "performance_budget"],
        )

        response = service.validate_strategy(request)

        check_names = [r.check for r in response.results]
        assert "metadata_consistency" in check_names
        assert "performance_budget" in check_names

    def test_check_metadata_consistency(self, service):
        """[AC-AISVC-RES-04] Metadata consistency check."""
        result = service._check_metadata_consistency(StrategyType.DEFAULT)

        assert result.check == "metadata_consistency"
        assert result.passed is True

    def test_check_rrf_config(self, service):
        """[AC-AISVC-RES-02] RRF config check."""
        result = service._check_rrf_config(StrategyType.DEFAULT)

        assert result.check == "rrf_config"
        assert isinstance(result.passed, bool)

    def test_check_performance_budget(self, service):
        """[AC-AISVC-RES-08] Performance budget check."""
        result = service._check_performance_budget(
            StrategyType.ENHANCED,
            ReactMode.REACT,
        )

        assert result.check == "performance_budget"
        assert isinstance(result.passed, bool)


class TestStrategyAuditService:
    """[AC-AISVC-RES-07] Tests for audit service."""

    @pytest.fixture
    def audit_service(self):
        """Create a fresh audit service (max_entries=100) for each test."""
        return StrategyAuditService(max_entries=100)

    def test_log_switch_operation(self, audit_service):
        """[AC-AISVC-RES-07] Should log switch operation."""
        audit_service.log(
            operation="switch",
            previous_strategy="default",
            new_strategy="enhanced",
            reason="Testing",
            operator="admin",
        )

        entries = audit_service.get_audit_log()

        assert len(entries) == 1
        assert entries[0].operation == "switch"
        assert entries[0].previous_strategy == "default"
        assert entries[0].new_strategy == "enhanced"

    def test_log_rollback_operation(self, audit_service):
        """[AC-AISVC-RES-07] Should log rollback operation."""
        audit_service.log_rollback(
            previous_strategy="enhanced",
            new_strategy="default",
            reason="Performance issue",
            operator="admin",
        )

        entries = audit_service.get_audit_log(operation="rollback")

        assert len(entries) == 1
        assert entries[0].operation == "rollback"

    def test_log_validation_operation(self, audit_service):
        """[AC-AISVC-RES-06] Should log validation operation."""
        audit_service.log_validation(
            strategy="enhanced",
            checks=["metadata_consistency"],
            passed=True,
        )

        entries = audit_service.get_audit_log(operation="validate")

        assert len(entries) == 1
        assert entries[0].operation == "validate"

    def test_get_audit_log_with_limit(self, audit_service):
        """Should limit audit log entries."""
        for i in range(10):
            audit_service.log(operation="switch", new_strategy=f"strategy_{i}")

        entries = audit_service.get_audit_log(limit=5)

        assert len(entries) == 5

    def test_get_audit_stats(self, audit_service):
        """Should return audit statistics."""
        audit_service.log(operation="switch", new_strategy="enhanced")
        audit_service.log(operation="rollback", new_strategy="default")

        stats = audit_service.get_audit_stats()

        assert stats["total_entries"] == 2
        assert stats["operation_counts"]["switch"] == 1
        assert stats["operation_counts"]["rollback"] == 1

    def test_clear_audit_log(self, audit_service):
        """Should clear audit log."""
        audit_service.log(operation="switch", new_strategy="enhanced")
        assert len(audit_service.get_audit_log()) == 1

        count = audit_service.clear_audit_log()

        assert count == 1
        assert len(audit_service.get_audit_log()) == 0


class TestStrategyMetricsService:
    """[AC-AISVC-RES-03,08] Tests for metrics service."""

    @pytest.fixture
    def metrics_service(self):
        """Create a fresh metrics service instance for each test."""
        return StrategyMetricsService()

    def test_record_request(self, metrics_service):
        """[AC-AISVC-RES-08] Should record request metrics."""
        metrics_service.record_request(
            latency_ms=100.0,
            success=True,
            route_mode="direct",
        )

        metrics = metrics_service.get_metrics()

        assert metrics.total_requests == 1
        assert metrics.successful_requests == 1
        assert metrics.avg_latency_ms == 100.0

    def test_record_failed_request(self, metrics_service):
        """[AC-AISVC-RES-08] Should record failed request."""
        metrics_service.record_request(latency_ms=50.0, success=False)

        metrics = metrics_service.get_metrics()

        assert metrics.failed_requests == 1

    def test_record_fallback(self, metrics_service):
        """[AC-AISVC-RES-08] Should record fallback count."""
        metrics_service.record_request(
            latency_ms=100.0,
            success=True,
            fallback=True,
        )

        metrics = metrics_service.get_metrics()

        assert metrics.fallback_count == 1

    def test_record_route_metrics(self, metrics_service):
        """[AC-AISVC-RES-08] Should track route mode metrics."""
        metrics_service.record_request(latency_ms=100.0, success=True, route_mode="react")
        metrics_service.record_request(latency_ms=50.0, success=True, route_mode="direct")

        route_metrics = metrics_service.get_route_metrics()

        assert "react" in route_metrics
        assert "direct" in route_metrics

    def test_get_all_metrics(self, metrics_service):
        """Should get metrics for all strategies."""
        metrics_service.set_current_strategy(StrategyType.ENHANCED, ReactMode.REACT)
        metrics_service.record_request(latency_ms=100.0, success=True)

        all_metrics = metrics_service.get_all_metrics()

        assert StrategyType.DEFAULT.value in all_metrics
        assert StrategyType.ENHANCED.value in all_metrics

    def test_get_performance_summary(self, metrics_service):
        """[AC-AISVC-RES-08] Should get performance summary."""
        metrics_service.record_request(latency_ms=100.0, success=True)
        metrics_service.record_request(latency_ms=200.0, success=True)
        metrics_service.record_request(latency_ms=50.0, success=False)

        summary = metrics_service.get_performance_summary()

        assert summary["total_requests"] == 3
        assert summary["successful_requests"] == 2
        assert summary["failed_requests"] == 1
        # 2 successes out of 3 requests.
        assert summary["success_rate"] == pytest.approx(0.6667, rel=0.01)

    def test_check_performance_threshold_ok(self, metrics_service):
        """[AC-AISVC-RES-08] Should pass performance threshold check."""
        metrics_service.record_request(latency_ms=100.0, success=True)

        result = metrics_service.check_performance_threshold(
            strategy=StrategyType.DEFAULT,
            max_latency_ms=5000.0,
            max_error_rate=0.1,
        )

        assert result["latency_ok"] is True
        assert result["error_rate_ok"] is True
        assert result["overall_ok"] is True

    def test_check_performance_threshold_exceeded(self, metrics_service):
        """[AC-AISVC-RES-08] Should fail when threshold exceeded."""
        metrics_service.record_request(latency_ms=6000.0, success=True)
        metrics_service.record_request(latency_ms=100.0, success=False)

        result = metrics_service.check_performance_threshold(
            strategy=StrategyType.DEFAULT,
            max_latency_ms=5000.0,
            max_error_rate=0.1,
        )

        # NOTE(review): this only requires one of the two flags to be False, so
        # it does not pin which threshold tripped; tighten once the latency
        # aggregation used by check_performance_threshold is confirmed.
        assert result["latency_ok"] is False or result["error_rate_ok"] is False

    def test_reset_metrics(self, metrics_service):
        """Should reset metrics."""
        metrics_service.record_request(latency_ms=100.0, success=True)

        metrics_service.reset_metrics()

        metrics = metrics_service.get_metrics()
        assert metrics.total_requests == 0


class TestSingletonInstances:
    """Tests for singleton instance getters.

    Each test clears the module-level cached instance through ``monkeypatch``
    so the original value is put back at teardown and the reset does not leak
    into other tests.
    """

    def test_get_strategy_service_singleton(self, monkeypatch):
        """Should return same strategy service instance."""
        import app.services.retrieval.strategy_service as strategy_module

        # setattr (raising=True by default) also fails if the attribute is missing.
        monkeypatch.setattr(strategy_module, "_strategy_service", None)

        service1 = get_strategy_service()
        service2 = get_strategy_service()

        assert service1 is service2

    def test_get_audit_service_singleton(self, monkeypatch):
        """Should return same audit service instance."""
        import app.services.retrieval.strategy_audit as audit_module

        monkeypatch.setattr(audit_module, "_audit_service", None)

        service1 = get_audit_service()
        service2 = get_audit_service()

        assert service1 is service2

    def test_get_metrics_service_singleton(self, monkeypatch):
        """Should return same metrics service instance."""
        import app.services.retrieval.strategy_metrics as metrics_module

        monkeypatch.setattr(metrics_module, "_metrics_service", None)

        service1 = get_metrics_service()
        service2 = get_metrics_service()

        assert service1 is service2