// ai-robot-channel/src/main/java/com/wecom/robot/websocket/CsWebSocketHandler.java
//
// 116 lines
// 4.0 KiB
// Java

package com.wecom.robot.websocket;
import com.alibaba.fastjson.JSON;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.*;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * WebSocket handler that keeps track of customer-service (cs) agent connections and
 * pushes JSON-serialized messages to them.
 *
 * <p>Two registries are maintained:
 * <ul>
 *   <li>{@code csSessions}: csId -&gt; the WebSocket session currently registered for it;</li>
 *   <li>{@code sessionToCsMap}: the {@code sessionId} carried by a {@code bind_session}
 *       message -&gt; the csId of the connection that sent it.</li>
 * </ul>
 * Both registries are {@code static}, so they are shared by every instance of this class.
 *
 * <p>The csId is read from the connection URI path: the path is split on {@code "/"} and
 * the element at index {@value #CS_ID_PATH_INDEX} is used (for a path such as
 * {@code /a/b/x} that is {@code x}).
 */
@Slf4j
@Component
public class CsWebSocketHandler implements WebSocketHandler {

    /** Index of the csId inside {@code path.split("/")}; index 0 is the empty string before the leading slash. */
    private static final int CS_ID_PATH_INDEX = 3;

    /** Value of the {@code type} field that asks to bind a conversation to the sending agent. */
    private static final String TYPE_BIND_SESSION = "bind_session";

    private static final Map<String, WebSocketSession> csSessions = new ConcurrentHashMap<>();
    private static final Map<String, String> sessionToCsMap = new ConcurrentHashMap<>();

    /**
     * Registers the new connection under the csId found in its URI path.
     *
     * <p>A connection whose path carries no csId is left open but is not registered;
     * a warning is logged so the situation is visible.
     *
     * @param session the newly established WebSocket session
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        String csId = extractCsId(session);
        if (csId == null) {
            log.warn("无法从连接路径解析csId, 连接未登记: sessionId={}, uri={}", session.getId(), session.getUri());
            return;
        }
        csSessions.put(csId, session);
        log.info("客服WebSocket连接建立: csId={}, sessionId={}", csId, session.getId());
    }

    /**
     * Handles an incoming message. Only text messages are processed, and only the
     * {@code bind_session} type has an effect: it maps the message's {@code sessionId}
     * to the csId of the sending connection.
     *
     * <p>Payloads that cannot be parsed, or whose fields have unexpected types, are
     * ignored instead of raising an exception out of the handler.
     *
     * @param session the session the message arrived on
     * @param message the received message
     */
    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
        if (!(message instanceof TextMessage)) {
            return;
        }
        String payload = ((TextMessage) message).getPayload();
        log.debug("收到WebSocket消息: {}", payload);

        Map<String, Object> msgMap = parsePayload(payload);
        if (msgMap == null) {
            return;
        }
        // equals(Object) is false for a missing or non-string "type", so no cast is needed.
        if (!TYPE_BIND_SESSION.equals(msgMap.get("type"))) {
            return;
        }

        Object sessionId = msgMap.get("sessionId");
        String csId = extractCsId(session);
        if (sessionId instanceof String && csId != null) {
            sessionToCsMap.put((String) sessionId, csId);
            log.info("绑定会话: sessionId={}, csId={}", sessionId, csId);
        }
    }

    /**
     * Logs a transport error reported for the given session.
     *
     * @param session   the session on which the error occurred
     * @param exception the reported error
     */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        log.error("WebSocket传输错误: sessionId={}", session.getId(), exception);
    }

    /**
     * Unregisters the closed connection and drops every conversation binding that
     * pointed at its csId.
     *
     * <p>The registry entry and the bindings are removed only when the registry still
     * holds <em>this</em> connection. If the same csId has reconnected in the meantime,
     * the newer connection and its bindings are left untouched.
     *
     * @param session the session that was closed
     * @param status  the close status
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        String csId = extractCsId(session);
        if (csId == null) {
            return;
        }
        WebSocketSession registered = csSessions.get(csId);
        boolean ownsRegistration = registered != null
                && registered.getId().equals(session.getId())
                // Conditional remove: a reconnect racing with this close must not be evicted.
                && csSessions.remove(csId, registered);
        if (ownsRegistration) {
            sessionToCsMap.values().removeIf(csId::equals);
        } else {
            log.debug("关闭的连接已被新连接替换, 保留当前登记: csId={}, sessionId={}", csId, session.getId());
        }
        log.info("客服WebSocket连接关闭: csId={}, status={}", csId, status);
    }

    /**
     * @return {@code false}; partial messages are not supported
     */
    @Override
    public boolean supportsPartialMessages() {
        return false;
    }

    /**
     * Serializes {@code message} to JSON and sends it to the agent registered under
     * {@code csId}. If no open connection is registered (or {@code csId} is
     * {@code null}) a warning is logged and nothing is sent. Send failures are logged,
     * not rethrown.
     *
     * @param csId    id of the target agent
     * @param message object to serialize and send
     */
    public void sendMessageToCs(String csId, Object message) {
        // ConcurrentHashMap rejects null keys, so guard before the lookup.
        WebSocketSession session = (csId == null) ? null : csSessions.get(csId);
        if (session == null || !session.isOpen()) {
            log.warn("客服不在线: csId={}", csId);
            return;
        }
        String json = JSON.toJSONString(message);
        try {
            send(session, new TextMessage(json));
            log.debug("发送消息给客服: csId={}, message={}", csId, json);
        } catch (IOException | IllegalStateException e) {
            log.error("发送WebSocket消息失败: csId={}", csId, e);
        }
    }

    /**
     * Serializes {@code message} to JSON once and sends it to every registered
     * connection that is open. A failure on one connection is logged and does not
     * prevent delivery to the remaining ones.
     *
     * @param message object to serialize and broadcast
     */
    public void broadcastToAll(Object message) {
        TextMessage textMessage = new TextMessage(JSON.toJSONString(message));
        for (WebSocketSession session : csSessions.values()) {
            if (!session.isOpen()) {
                continue;
            }
            try {
                send(session, textMessage);
            } catch (IOException | IllegalStateException e) {
                // The session may close between isOpen() and the send; keep broadcasting.
                log.error("广播消息失败: sessionId={}", session.getId(), e);
            }
        }
    }

    /**
     * Sends {@code message} to the agent bound to the given conversation via a previous
     * {@code bind_session} message. If no binding exists (or {@code sessionId} is
     * {@code null}) a warning is logged and nothing is sent.
     *
     * @param sessionId conversation id used in the {@code bind_session} message
     * @param message   object to serialize and send
     */
    public void sendMessageToSession(String sessionId, Object message) {
        // ConcurrentHashMap rejects null keys, so guard before the lookup.
        String csId = (sessionId == null) ? null : sessionToCsMap.get(sessionId);
        if (csId == null) {
            log.warn("会话未绑定客服: sessionId={}", sessionId);
            return;
        }
        sendMessageToCs(csId, message);
    }

    /**
     * Sends one message while holding the session's monitor so that sends issued by this
     * handler from different threads to the same session do not overlap.
     * NOTE(review): Spring's documentation states that the standard WebSocket session
     * does not allow concurrent sending; confirm whether callers already serialize sends.
     */
    private void send(WebSocketSession session, TextMessage message) throws IOException {
        synchronized (session) {
            session.sendMessage(message);
        }
    }

    /**
     * Parses a text payload into a map.
     *
     * @return the parsed map, or {@code null} when the parser returns {@code null} or
     *         the payload cannot be parsed
     */
    @SuppressWarnings("unchecked") // Map.class is raw; the result is only read via get(String).
    private Map<String, Object> parsePayload(String payload) {
        try {
            return JSON.parseObject(payload, Map.class);
        } catch (RuntimeException e) {
            // Boundary with untrusted client input: a bad frame must not escape the handler.
            log.warn("忽略无法解析的WebSocket消息: {}", payload, e);
            return null;
        }
    }

    /**
     * Extracts the csId from the session's URI path.
     *
     * @return the path element at index {@value #CS_ID_PATH_INDEX} of the "/"-split path,
     *         or {@code null} when the session has no URI/path, the path is too short,
     *         or that element is empty
     */
    private String extractCsId(WebSocketSession session) {
        java.net.URI uri = session.getUri();
        String path = (uri == null) ? null : uri.getPath();
        if (path == null) {
            return null;
        }
        String[] parts = path.split("/");
        if (parts.length <= CS_ID_PATH_INDEX || parts[CS_ID_PATH_INDEX].isEmpty()) {
            return null;
        }
        return parts[CS_ID_PATH_INDEX];
    }
}