"""Tests for the slot state cache.

[AC-MRS-SLOT-CACHE-01] Multi-turn state persistence tests.
"""

import json
import time
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.services.cache.slot_state_cache import (
    CachedSlotState,
    CachedSlotValue,
    SlotStateCache,
    get_slot_state_cache,
)

# Identifiers shared by every cache test, plus the two key formats the
# assertions below rely on.
_TENANT_ID = "tenant_1"
_SESSION_ID = "session_1"
_LOCAL_KEY = f"{_TENANT_ID}:{_SESSION_ID}"
_REDIS_KEY = f"slot_state:{_TENANT_ID}:{_SESSION_ID}"


def _local_only_cache():
    """Build a cache with no Redis client and ``_enabled`` set to False."""
    cache = SlotStateCache(redis_client=None)
    cache._enabled = False
    return cache


def _region_state():
    """Build a state holding one user-confirmed ``region`` slot."""
    return CachedSlotState(
        filled_slots={
            "region": CachedSlotValue(value="北京", source="user_confirmed"),
        },
    )


def _region_and_product_slots():
    """Build the two-slot mapping (``region`` and ``product``) used by several tests."""
    return {
        "region": CachedSlotValue(value="北京", source="user_confirmed"),
        "product": CachedSlotValue(value="手机", source="rule_extracted"),
    }


class TestCachedSlotValue:
    """Tests for CachedSlotValue."""

    def test_init(self):
        """Constructor stores its arguments and sets a positive timestamp."""
        slot = CachedSlotValue(
            value="test_value",
            source="user_confirmed",
            confidence=0.9,
        )

        assert slot.value == "test_value"
        assert slot.source == "user_confirmed"
        assert slot.confidence == 0.9
        assert slot.updated_at > 0

    def test_to_dict(self):
        """to_dict exposes value, source, confidence and updated_at."""
        slot = CachedSlotValue(
            value="test_value",
            source="rule_extracted",
            confidence=0.8,
        )

        payload = slot.to_dict()

        assert payload["value"] == "test_value"
        assert payload["source"] == "rule_extracted"
        assert payload["confidence"] == 0.8
        assert "updated_at" in payload

    def test_from_dict(self):
        """from_dict restores every field from a plain dictionary."""
        payload = {
            "value": "test_value",
            "source": "llm_inferred",
            "confidence": 0.7,
            "updated_at": 12345.0,
        }

        slot = CachedSlotValue.from_dict(payload)

        assert slot.value == "test_value"
        assert slot.source == "llm_inferred"
        assert slot.confidence == 0.7
        assert slot.updated_at == 12345.0


class TestCachedSlotState:
    """Tests for CachedSlotState."""

    def test_init(self):
        """A default state has empty maps and positive timestamps."""
        state = CachedSlotState()

        assert state.filled_slots == {}
        assert state.slot_to_field_map == {}
        assert state.created_at > 0
        assert state.updated_at > 0

    def test_with_slots(self):
        """Slots and the field map passed to the constructor are kept."""
        state = CachedSlotState(
            filled_slots=_region_and_product_slots(),
            slot_to_field_map={"region": "region_field"},
        )

        assert len(state.filled_slots) == 2
        assert state.slot_to_field_map["region"] == "region_field"

    def test_to_dict_and_from_dict(self):
        """A state survives a to_dict / from_dict round trip."""
        original = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(value="北京", source="user_confirmed"),
            },
            slot_to_field_map={"region": "region_field"},
        )

        restored = CachedSlotState.from_dict(original.to_dict())

        assert len(restored.filled_slots) == 1
        assert restored.filled_slots["region"].value == "北京"
        assert restored.filled_slots["region"].source == "user_confirmed"
        assert restored.slot_to_field_map["region"] == "region_field"

    def test_get_simple_filled_slots(self):
        """get_simple_filled_slots maps each slot name to its bare value."""
        state = CachedSlotState(filled_slots=_region_and_product_slots())

        assert state.get_simple_filled_slots() == {
            "region": "北京",
            "product": "手机",
        }

    def test_get_slot_sources(self):
        """get_slot_sources maps each slot name to its source."""
        state = CachedSlotState(filled_slots=_region_and_product_slots())

        assert state.get_slot_sources() == {
            "region": "user_confirmed",
            "product": "rule_extracted",
        }

    def test_get_slot_confidence(self):
        """get_slot_confidence maps each slot name to its confidence."""
        state = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(
                    value="北京", source="user_confirmed", confidence=1.0
                ),
                "product": CachedSlotValue(
                    value="手机", source="rule_extracted", confidence=0.8
                ),
            },
        )

        assert state.get_slot_confidence() == {"region": 1.0, "product": 0.8}


class TestSlotStateCache:
    """Tests for SlotStateCache."""

    @pytest.mark.parametrize(
        ("source", "expected"),
        [
            ("user_confirmed", 100),
            ("rule_extracted", 80),
            ("llm_inferred", 60),
            ("context", 40),
            ("default", 20),
            ("unknown", 0),
        ],
    )
    def test_source_priority(self, source, expected):
        """Each source name yields its expected priority, 0 for ``unknown``."""
        # Parametrized so one wrong priority does not hide the others.
        assert SlotStateCache()._get_source_priority(source) == expected

    def test_make_key(self):
        """The cache key is ``slot_state:<tenant>:<session>``."""
        cache = SlotStateCache()

        assert (
            cache._make_key("tenant_123", "session_456")
            == "slot_state:tenant_123:session_456"
        )

    @pytest.mark.asyncio
    async def test_l1_cache_hit(self):
        """A fresh entry placed in the local cache is returned by get."""
        cache = SlotStateCache()
        cache._local_cache[_LOCAL_KEY] = (_region_state(), time.time())

        result = await cache.get(_TENANT_ID, _SESSION_ID)

        assert result is not None
        assert result.filled_slots["region"].value == "北京"

    @pytest.mark.asyncio
    async def test_l1_cache_expired(self):
        """An entry older than the local TTL is dropped and get returns None."""
        # NOTE(review): built with default arguments, unlike the other L1 tests
        # that disable Redis explicitly -- confirm a miss here cannot reach a
        # real Redis instance.
        cache = SlotStateCache()
        # Derive the age from the configured TTL instead of a magic number, so
        # the entry stays expired even if the default TTL is raised.
        stale_at = time.time() - (cache._local_cache_ttl + 100)
        cache._local_cache[_LOCAL_KEY] = (_region_state(), stale_at)

        result = await cache.get(_TENANT_ID, _SESSION_ID)

        assert result is None
        assert _LOCAL_KEY not in cache._local_cache

    @pytest.mark.asyncio
    async def test_set_and_get_l1(self):
        """set stores the state in the local cache and get reads it back."""
        cache = _local_only_cache()

        await cache.set(_TENANT_ID, _SESSION_ID, _region_state())

        assert _LOCAL_KEY in cache._local_cache
        result = await cache.get(_TENANT_ID, _SESSION_ID)
        assert result is not None
        assert result.filled_slots["region"].value == "北京"

    @pytest.mark.asyncio
    async def test_delete(self):
        """After delete, get no longer returns the stored state."""
        cache = _local_only_cache()
        await cache.set(_TENANT_ID, _SESSION_ID, _region_state())

        await cache.delete(_TENANT_ID, _SESSION_ID)

        assert await cache.get(_TENANT_ID, _SESSION_ID) is None

    @pytest.mark.asyncio
    async def test_clear_slot(self):
        """clear_slot removes only the named slot and keeps the others."""
        cache = _local_only_cache()
        state = CachedSlotState(filled_slots=_region_and_product_slots())
        await cache.set(_TENANT_ID, _SESSION_ID, state)

        await cache.clear_slot(_TENANT_ID, _SESSION_ID, "region")

        result = await cache.get(_TENANT_ID, _SESSION_ID)
        assert result is not None
        assert "region" not in result.filled_slots
        assert "product" in result.filled_slots

    @pytest.mark.asyncio
    async def test_merge_and_set_priority(self):
        """A user_confirmed value replaces an existing llm_inferred value."""
        cache = _local_only_cache()
        existing_state = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(
                    value="上海", source="llm_inferred", confidence=0.6
                ),
            },
        )
        await cache.set(_TENANT_ID, _SESSION_ID, existing_state)

        new_slots = {
            "region": CachedSlotValue(
                value="北京", source="user_confirmed", confidence=1.0
            ),
        }
        result = await cache.merge_and_set(_TENANT_ID, _SESSION_ID, new_slots)

        assert result.filled_slots["region"].value == "北京"
        assert result.filled_slots["region"].source == "user_confirmed"

    @pytest.mark.asyncio
    async def test_merge_and_set_lower_priority_ignored(self):
        """An llm_inferred value does not replace a user_confirmed value."""
        cache = _local_only_cache()
        existing_state = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(
                    value="北京", source="user_confirmed", confidence=1.0
                ),
            },
        )
        await cache.set(_TENANT_ID, _SESSION_ID, existing_state)

        new_slots = {
            "region": CachedSlotValue(
                value="上海", source="llm_inferred", confidence=0.6
            ),
        }
        result = await cache.merge_and_set(_TENANT_ID, _SESSION_ID, new_slots)

        assert result.filled_slots["region"].value == "北京"
        assert result.filled_slots["region"].source == "user_confirmed"


class TestGetSlotStateCache:
    """Singleton tests for get_slot_state_cache."""

    def test_singleton(self):
        """Two calls return the very same object."""
        assert get_slot_state_cache() is get_slot_state_cache()


class TestSlotStateCacheWithRedis:
    """SlotStateCache tests that run against a mocked Redis client."""

    @pytest.mark.asyncio
    async def test_redis_set_and_get(self):
        """set calls setex exactly once, with the slot_state key first."""
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(return_value=None)
        mock_redis.setex = AsyncMock(return_value=True)
        cache = SlotStateCache(redis_client=mock_redis)

        await cache.set(_TENANT_ID, _SESSION_ID, _region_state())

        mock_redis.setex.assert_called_once()
        # Only the first positional argument (the key) is pinned here.
        assert mock_redis.setex.call_args.args[0] == _REDIS_KEY

    @pytest.mark.asyncio
    async def test_redis_get_hit(self):
        """A JSON payload returned by Redis is decoded into a state."""
        state_dict = {
            "filled_slots": {
                "region": {
                    "value": "北京",
                    "source": "user_confirmed",
                    "confidence": 1.0,
                    "updated_at": 12345.0,
                }
            },
            "slot_to_field_map": {"region": "region_field"},
            "created_at": 12340.0,
            "updated_at": 12345.0,
        }
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(return_value=json.dumps(state_dict))
        cache = SlotStateCache(redis_client=mock_redis)

        result = await cache.get(_TENANT_ID, _SESSION_ID)

        assert result is not None
        assert result.filled_slots["region"].value == "北京"
        assert result.filled_slots["region"].source == "user_confirmed"

    @pytest.mark.asyncio
    async def test_redis_delete(self):
        """delete forwards the slot_state key to the Redis client."""
        mock_redis = AsyncMock()
        mock_redis.delete = AsyncMock(return_value=1)
        cache = SlotStateCache(redis_client=mock_redis)

        await cache.delete(_TENANT_ID, _SESSION_ID)

        mock_redis.delete.assert_called_once_with(_REDIS_KEY)


class TestCacheTTL:
    """Tests for the TTL configuration."""

    def test_default_ttl(self):
        """The default cache TTL is 1800."""
        assert SlotStateCache()._cache_ttl == 1800

    def test_local_cache_ttl(self):
        """The local cache TTL is 300."""
        assert SlotStateCache()._local_cache_ttl == 300