diff --git a/modules/i3plus-core-apiservice/src/main/java/cn/estsh/i3plus/core/apiservice/websocket/MessageWebSocket.java b/modules/i3plus-core-apiservice/src/main/java/cn/estsh/i3plus/core/apiservice/websocket/MessageWebSocket.java index 1b6aa3a..2c6ca61 100644 --- a/modules/i3plus-core-apiservice/src/main/java/cn/estsh/i3plus/core/apiservice/websocket/MessageWebSocket.java +++ b/modules/i3plus-core-apiservice/src/main/java/cn/estsh/i3plus/core/apiservice/websocket/MessageWebSocket.java @@ -1,6 +1,7 @@ package cn.estsh.i3plus.core.apiservice.websocket; import cn.estsh.i3plus.platform.common.util.PlatformConstWords; +import io.netty.util.internal.ConcurrentSet; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.stereotype.Component; @@ -8,7 +9,12 @@ import org.springframework.stereotype.Component; import javax.websocket.*; import javax.websocket.server.PathParam; import javax.websocket.server.ServerEndpoint; +import java.io.EOFException; import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -19,25 +25,38 @@ import java.util.concurrent.ConcurrentMap; * @CreateDate : 2018-11-24 16:57 * @Modify: **/ -@ServerEndpoint(value= PlatformConstWords.WEBSOCKET_URL + "/message-websocket/{userId}") +@ServerEndpoint(value = PlatformConstWords.WEBSOCKET_URL + "/message-websocket/{userId}/{userLoginSid}") @Component public class MessageWebSocket { private static final Logger LOGGER = LoggerFactory.getLogger(MessageWebSocket.class); - private long userId = 1L; + private long userId; + private String userLoginSid; + //websocket会话 private Session session; // 当前对象会话 private static int sendCount = 1; - //concurrent线程安全集合,存放客户端websocket对象 - private static ConcurrentMap webSocketSet = new ConcurrentHashMap(); + + // 用户会话消息 + private static ConcurrentMap> userSessionMap = new ConcurrentHashMap<>(); + // concurrent线程安全集合,存放客户端websocket对象 + private static ConcurrentMap webSocketMap = new ConcurrentHashMap<>(); @OnOpen - public void onOpen(@PathParam("userId")long userId, Session session){ + public void onOpen(@PathParam("userId")long userId,@PathParam("userLoginSid")String userLoginSid, Session session){ this.userId = userId; + this.userLoginSid = userLoginSid; this.session = session; - webSocketSet.put(userId,this); //在线人数添加 + ConcurrentSet sidSet = userSessionMap.get(userId); + if(sidSet == null){ + sidSet = new ConcurrentSet<>(); + } + sidSet.add(userLoginSid); + + userSessionMap.put(userId, sidSet); + webSocketMap.put(userLoginSid, this); //在线人数添加 LOGGER.info("{}加入!当前在线人数为{}",userId,getOnlineCount()); } @@ -46,7 +65,7 @@ public class MessageWebSocket { */ @OnClose public void onClose() { - subOnlineUser(this.userId); + subOnlineUser(this.userId, this.userLoginSid); LOGGER.info("有一连接关闭!当前在线人数为" + getOnlineCount()); } @@ -55,10 +74,10 @@ public class MessageWebSocket { * * @param message 客户端发送过来的消息*/ @OnMessage - public void onMessage(@PathParam("userId")Long userId,String message) { + public void onMessage(@PathParam("userLoginSid")String userLoginSid,String message) { // 心跳 if("heartBit".equals(message)){ - this.sendMessage(userId,"heartBit"); + this.sendMessage(userLoginSid,"heartBit"); }else{ LOGGER.info("来自客户端的消息:" , message); } @@ -70,25 +89,26 @@ public class MessageWebSocket { */ @OnError public void onError(Session session, Throwable error) { - LOGGER.info("发生错误"); - error.printStackTrace(); + if(error.getClass().equals(EOFException.class)){ + LOGGER.error("WebSocket连接已断开"); + }else { + LOGGER.error("发生错误",error.toString()); + error.printStackTrace(); + } } /** - * 发送消息 - * @param message - * @throws IOException + * 根据用户id发送消息 + * @param userId 用户id + * @param message 消息主体 */ public static void sendMessage(Long userId, String message){ try { - MessageWebSocket websocket = webSocketSet.get(userId); - synchronized (websocket){ - if(websocket != null) { - if (message.equals("heartBit")) { - websocket.session.getBasicRemote().sendText(message + "=" + sendCount); - sendCount++; - } else { - websocket.session.getBasicRemote().sendText(message); + ConcurrentSet websocket = userSessionMap.get(userId); + if (websocket != null && websocket.size() != 0) { + for (String sid : websocket) { + if(webSocketMap.get(sid).session.isOpen()){ + webSocketMap.get(sid).session.getBasicRemote().sendText(message); } } } @@ -97,11 +117,39 @@ public class MessageWebSocket { } } + /** + * 根据会话id发送消息 + * @param userLoginSid 会话id + * @param message 消息主体 + */ + public static void sendMessage(String userLoginSid, String message){ + try { + MessageWebSocket websocket = webSocketMap.get(userLoginSid); + if (websocket != null && websocket.session.isOpen()) { + if (message.equals("heartBit")) { + websocket.session.getBasicRemote().sendText(message + "=" + sendCount); + sendCount++; + } else { + websocket.session.getBasicRemote().sendText(message); + } + } + } catch (IOException e) { + e.printStackTrace(); + } + } + public static synchronized int getOnlineCount() { - return webSocketSet.size(); + return userSessionMap.size(); } - public synchronized void subOnlineUser(long userId) { - webSocketSet.remove(userId); + public synchronized void subOnlineUser(long userId, String userLoginSid) { + webSocketMap.remove(userId); + ConcurrentSet sidSet = userSessionMap.get(userId); + if(sidSet != null){ + sidSet.remove(userLoginSid); + if(sidSet.isEmpty()){ + userSessionMap.remove(userId); + } + } } }