From 2c9a5fa4b53955033b1bd2ac637dea0728808551 Mon Sep 17 00:00:00 2001
From: zj <1772600164@qq.com>
Date: Wed, 27 Mar 2024 20:01:13 +0800
Subject: [PATCH] 多链接

---
 websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java |  188 ++++++++++++++++++++++++++++++++++------------
 1 files changed, 138 insertions(+), 50 deletions(-)

diff --git a/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java b/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java
index 0984ec6..95d9667 100644
--- a/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java
+++ b/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java
@@ -1,15 +1,35 @@
 package org.example.websocket.server;
 
+import cn.hutool.core.date.DateUnit;
+import cn.hutool.core.date.DateUtil;
+import cn.hutool.core.io.unit.DataUnit;
+import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
+import com.baomidou.mybatisplus.core.toolkit.StringUtils;
 import lombok.NonNull;
 import lombok.extern.slf4j.Slf4j;
+import org.example.dao.DataServiceKeyMapper;
+import org.example.pojo.DataServiceKey;
+import org.example.util.ApplicationContextRegisterUtil;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.ApplicationContext;
+import org.springframework.scheduling.annotation.Scheduled;
 import org.springframework.stereotype.Component;
 import org.springframework.web.bind.annotation.PostMapping;
 import org.springframework.web.bind.annotation.RequestBody;
 
 import javax.websocket.*;
+import javax.websocket.server.PathParam;
 import javax.websocket.server.ServerEndpoint;
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
 
 /**
  * @ClassDescription: websocket服务端
@@ -18,12 +38,11 @@
  */
 @Slf4j
 @Component
-//@RestController
 @ServerEndpoint("/websocket-server")
-//@ServerEndpoint("/")
 public class WsServer {
 
     private Session session;
+
     /**
      * 记录在线连接客户端数量
      */
@@ -31,41 +50,127 @@
     /**
      * 存放每个连接进来的客户端对应的websocketServer对象,用于后面群发消息
      */
-    private static CopyOnWriteArrayList<WsServer> wsServers = new CopyOnWriteArrayList<>();
+    private static ConcurrentHashMap<String, Session> webSocketMap = new ConcurrentHashMap<>();
+
+    private final Lock lock = new ReentrantLock();
+
+    /**
+     * 关闭过期链接
+     */
+    @Scheduled(cron = "0  0/1  *  *  *  ?")
+    public void clearExpiration() {
+        if (lock.tryLock()) {
+            log.info("webSocket关闭过期链接--------->开始");
+            try {
+                ApplicationContext act = ApplicationContextRegisterUtil.getApplicationContext();
+                DataServiceKeyMapper mapper = act.getBean(DataServiceKeyMapper.class);
+                LambdaQueryWrapper<DataServiceKey> queryWrapper = new LambdaQueryWrapper<>();
+                List<DataServiceKey> keyList = mapper.selectList(queryWrapper.lt(DataServiceKey::getExpirationTime, DateUtil.date())
+                        .or().eq(DataServiceKey::getIsAvailable, 0));
+
+                keyList.forEach(f -> {
+                    //  在锁内操作
+                    Session sessions = webSocketMap.get(f.getTokenKey());
+                    if (null != sessions) {
+                        try {
+                            sendInfo(sessions, 0);
+                            f.setIsAvailable(0);
+                            mapper.updateById(f);
+                        } catch (Exception e) {
+                            throw new RuntimeException(e);
+                        }
+                    }
+                    log.info("webSocket关闭过期链接---key:"+f.getTokenKey());
+                });
+            } catch (Exception e) {
+                throw new RuntimeException(e);
+            } finally {
+                lock.unlock();
+                log.info("webSocket关闭过期链接--------->结束");
+            }
+        } else {
+            log.info("webSocket关闭过期链接--------->上次任务还未执行完成,本次任务忽略");
+        }
+    }
+
 
     /**
      * 服务端与客户端连接成功时执行
+     *
      * @param session 会话
      */
     @OnOpen
-    public void onOpen(Session session){
-        this.session = session;
-        //接入的客户端+1
-        int count = onlineCount.incrementAndGet();
-        //集合中存入客户端对象+1
-        wsServers.add(this);
-        log.info("与客户端连接成功,当前连接的客户端数量为:{}", count);
+    public void onOpen(Session session) throws Exception {
+        try {
+            this.session = session;
+            //查询该用户有没有权限
+            Map<String, List<String>> params = session.getRequestParameterMap();
+            List<String> funcTypes = params.get("key");
+            // 取出funcType参数的值
+            if (null == funcTypes) {
+                sendInfo(session, 1);
+            }
+            String key = funcTypes.get(0);
+            if (StringUtils.isEmpty(key)) {
+                sendInfo(session, 1);
+            }
+            ApplicationContext act = ApplicationContextRegisterUtil.getApplicationContext();
+            DataServiceKeyMapper mapper = act.getBean(DataServiceKeyMapper.class);
+            DataServiceKey dataServiceKey = mapper.selectOne(new LambdaQueryWrapper<DataServiceKey>()
+                    .eq(DataServiceKey::getTokenKey, key)
+                    .gt(DataServiceKey::getExpirationTime, DateUtil.date())
+                    .eq(DataServiceKey::getIsAvailable, 1));
+            if (null != dataServiceKey) {
+                if (webSocketMap.containsKey(key)) {
+                    webSocketMap.put(key, session);
+                } else {
+                    //接入的客户端+1
+                    onlineCount.incrementAndGet();
+                    webSocketMap.put(key, session);
+                }
+
+                log.info("与客户端连接成功,当前连接的客户端数量为:{}", onlineCount);
+            } else {
+                sendInfo(session, 1);
+            }
+        } catch (Exception e) {
+            sendInfo(session, 1);
+            log.error("客户端连接错误:" + e.getMessage());
+        }
+    }
+
+
+    /**
+     * 关闭连接
+     */
+    public void sendInfo(Session session, int type) throws Exception {
+        if (session.getBasicRemote() != null) {
+            if (type == 1) {
+                session.getBasicRemote().sendText("订阅key不存在!");
+            }
+            session.close();
+        }
     }
 
     /**
      * 收到客户端的消息时执行
+     *
      * @param message 消息
      * @param session 会话
      */
     @OnMessage
-    public void onMessage(String message, Session session){
+    public void onMessage(Session session, String message) {
         log.info("收到来自客户端的消息,客户端地址:{},消息内容:{}", session.getMessageHandlers(), message);
-        //业务逻辑,对消息的处理
-//        sendMessageToAll("群发消息的内容");
     }
 
     /**
      * 连接发生报错时执行
-     * @param session 会话
+     *
+     * @param session   会话
      * @param throwable 报错
      */
     @OnError
-    public void onError(Session session, @NonNull Throwable throwable){
+    public void onError(Session session, @NonNull Throwable throwable) {
         log.error("连接发生报错");
         throwable.printStackTrace();
     }
@@ -74,53 +179,36 @@
      * 连接断开时执行
      */
     @OnClose
-    public void onClose(){
+    public void onClose() {
         //接入客户端连接数-1
-        int count = onlineCount.decrementAndGet();
+        if (webSocketMap.get(this.session) != null) {
+            onlineCount.decrementAndGet();
+        }
         //集合中的客户端对象-1
-        wsServers.remove(this);
-        log.info("服务端断开连接,当前连接的客户端数量为:{}", count);
+        webSocketMap.remove(this);
+        log.info("服务端断开连接,当前连接的客户端数量为:{}", onlineCount.get());
     }
-
-
-    /**
-     * 向客户端推送消息
-     * @param message 消息
-     */
-    public void sendMessage(String message){
-        this.session.getAsyncRemote().sendText(message);
-        log.info("推送消息给客户端:{},消息内容为:{}", this.session.getMessageHandlers(), message);
-    }
-
-//    @PostMapping("/send2c")
-//    public void sendMessage1(@RequestBody String message){
-//        this.session.getAsyncRemote().sendText(message);
-//        try {
-//            this.session.getBasicRemote().sendText(message);
-//        } catch (IOException e) {
-//            throw new RuntimeException(e);
-//        }
-//        log.info("推送消息给客户端,消息内容为:{}", message);
-//    }
 
 
     /**
      * 群发消息
+     *
      * @param message 消息
      */
-    public void sendMessageToAll(String message){
-        CopyOnWriteArrayList<WsServer> ws = wsServers;
-        for (WsServer wsServer : ws){
-            wsServer.sendMessage(message);
-        }
+    public void sendMessageToAll(String message) { // 发送消息给所有客户端
+        webSocketMap.values().forEach(session -> {
+            try {
+                session.getBasicRemote().sendText(message);
+            } catch (IOException e) {
+                log.error("发送消息给客户端失败,客户端ID为:{}", session.getId(), e);
+            }
+        });
+        log.info("推送消息给所有客户端,消息内容为:{}", message);
     }
 
     @PostMapping("/send2AllC")
-    public void sendMessageToAll1(@RequestBody String message){
-        CopyOnWriteArrayList<WsServer> ws = wsServers;
-        for (WsServer wsServer : ws){
-            wsServer.sendMessage(message);
-        }
+    public void sendMessageToAll1(@RequestBody String message) { // 发送消息给所有客户端
+        sendMessageToAll(message);
     }
 
 

--
Gitblit v1.9.3