package com.wecom.robot.websocket; import com.alibaba.fastjson.JSON; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; import org.springframework.web.socket.*; import java.io.IOException; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @Slf4j @Component public class CsWebSocketHandler implements WebSocketHandler { private static final Map csSessions = new ConcurrentHashMap<>(); private static final Map sessionToCsMap = new ConcurrentHashMap<>(); @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { String csId = extractCsId(session); if (csId != null) { csSessions.put(csId, session); log.info("客服WebSocket连接建立: csId={}, sessionId={}", csId, session.getId()); } } @Override public void handleMessage(WebSocketSession session, WebSocketMessage message) throws Exception { if (message instanceof TextMessage) { String payload = ((TextMessage) message).getPayload(); log.debug("收到WebSocket消息: {}", payload); Map msgMap = JSON.parseObject(payload, Map.class); String type = (String) msgMap.get("type"); if ("bind_session".equals(type)) { String sessionId = (String) msgMap.get("sessionId"); String csId = extractCsId(session); if (sessionId != null && csId != null) { sessionToCsMap.put(sessionId, csId); log.info("绑定会话: sessionId={}, csId={}", sessionId, csId); } } } } @Override public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { log.error("WebSocket传输错误: sessionId={}", session.getId(), exception); } @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { String csId = extractCsId(session); if (csId != null) { csSessions.remove(csId); sessionToCsMap.entrySet().removeIf(entry -> csId.equals(entry.getValue())); log.info("客服WebSocket连接关闭: csId={}, status={}", csId, status); } } @Override public boolean supportsPartialMessages() { return false; } public void sendMessageToCs(String csId, Object message) { WebSocketSession session = csSessions.get(csId); if (session != null && session.isOpen()) { try { String json = JSON.toJSONString(message); session.sendMessage(new TextMessage(json)); log.debug("发送消息给客服: csId={}, message={}", csId, json); } catch (IOException e) { log.error("发送WebSocket消息失败: csId={}", csId, e); } } else { log.warn("客服不在线: csId={}", csId); } } public void broadcastToAll(Object message) { String json = JSON.toJSONString(message); TextMessage textMessage = new TextMessage(json); csSessions.values().forEach(session -> { if (session.isOpen()) { try { session.sendMessage(textMessage); } catch (IOException e) { log.error("广播消息失败: sessionId={}", session.getId(), e); } } }); } public void sendMessageToSession(String sessionId, Object message) { String csId = sessionToCsMap.get(sessionId); if (csId != null) { sendMessageToCs(csId, message); } else { log.warn("会话未绑定客服: sessionId={}", sessionId); } } private String extractCsId(WebSocketSession session) { String path = session.getUri().getPath(); String[] parts = path.split("/"); if (parts.length >= 4) { return parts[3]; } return null; } }