From 42c55aa0fb937c45d28f02bed25dd93888d5a6e0 Mon Sep 17 00:00:00 2001
From: zj <1772600164@qq.com>
Date: Fri, 28 Jun 2024 17:04:42 +0800
Subject: [PATCH] 1

---
 websocketClient/src/main/java/org/example/config/AsyncConfiguration.java            |   28 +++++++++
 websocketSerivce/src/main/java/org/example/websocket/config/AsyncConfiguration.java |   28 +++++++++
 websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java           |   74 +++++++++++++++---------
 3 files changed, 102 insertions(+), 28 deletions(-)

diff --git a/websocketClient/src/main/java/org/example/config/AsyncConfiguration.java b/websocketClient/src/main/java/org/example/config/AsyncConfiguration.java
new file mode 100644
index 0000000..f77bada
--- /dev/null
+++ b/websocketClient/src/main/java/org/example/config/AsyncConfiguration.java
@@ -0,0 +1,28 @@
+package org.example.config;
+
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
+
+import java.util.concurrent.ThreadPoolExecutor;
+
+/**
+ * @program: dabaogp
+ * @description:
+ * @create: 2024-06-25 16:37
+ **/
+@Configuration
+public class AsyncConfiguration {
+
+    @Bean(name = "threadPoolTaskExecutor")
+    public ThreadPoolTaskExecutor threadPoolTaskExecutor() {
+        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
+        executor.setCorePoolSize(50);    //  核心线程数
+        executor.setMaxPoolSize(100);    //  最大线程数
+        executor.setQueueCapacity(300);    //  队列容量
+        executor.setKeepAliveSeconds(60);    //  线程空闲时的存活时间为60秒
+        executor.setThreadNamePrefix("MyThread-");    //  线程名称的前缀
+        executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());    //  使用  CallerRunsPolicy  拒绝策略
+        return executor;
+    }
+}
diff --git a/websocketSerivce/src/main/java/org/example/websocket/config/AsyncConfiguration.java b/websocketSerivce/src/main/java/org/example/websocket/config/AsyncConfiguration.java
new file mode 100644
index 0000000..a764f03
--- /dev/null
+++ b/websocketSerivce/src/main/java/org/example/websocket/config/AsyncConfiguration.java
@@ -0,0 +1,28 @@
+package org.example.websocket.config;
+
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
+
+import java.util.concurrent.ThreadPoolExecutor;
+
+/**
+ * @program: dabaogp
+ * @description:
+ * @create: 2024-06-25 16:37
+ **/
+@Configuration
+public class AsyncConfiguration {
+
+    @Bean(name = "threadPoolTaskExecutor")
+    public ThreadPoolTaskExecutor threadPoolTaskExecutor() {
+        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
+        executor.setCorePoolSize(5);    //  核心线程数
+        executor.setMaxPoolSize(100);    //  最大线程数
+        executor.setQueueCapacity(300);    //  队列容量
+        executor.setKeepAliveSeconds(60);    //  线程空闲时的存活时间为60秒
+        executor.setThreadNamePrefix("MyThread-");    //  线程名称的前缀
+        executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());    //  使用  CallerRunsPolicy  拒绝策略
+        return executor;
+    }
+}
diff --git a/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java b/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java
index 08c8a2d..d27d0a5 100644
--- a/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java
+++ b/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java
@@ -4,6 +4,9 @@
 import com.google.gson.reflect.TypeToken;
 import lombok.NonNull;
 import lombok.extern.slf4j.Slf4j;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.annotation.Qualifier;
+import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
 import org.springframework.stereotype.Component;
 import org.springframework.web.bind.annotation.PostMapping;
 import org.springframework.web.bind.annotation.RequestBody;
@@ -12,9 +15,14 @@
 import javax.websocket.server.ServerEndpoint;
 import java.io.IOException;
 import java.lang.reflect.Type;
+import java.text.SimpleDateFormat;
+import java.util.Date;
+import java.util.HashMap;
 import java.util.Map;
-import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.*;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
 
 /**
  * @ClassDescription: websocket服务端
@@ -28,7 +36,10 @@
 
     private Session session;
     private static AtomicInteger onlineCount = new AtomicInteger(0);
-    private static CopyOnWriteArrayList<WsServer> wsServers = new CopyOnWriteArrayList<>();
+    private static CopyOnWriteArraySet<WsServer> wsServers = new CopyOnWriteArraySet<>();
+    @Autowired
+    @Qualifier("threadPoolTaskExecutor")
+    private ThreadPoolTaskExecutor threadPoolTaskExecutor;
 
     @OnOpen
     public void onOpen(Session session) {
@@ -51,41 +62,48 @@
         log.info("服务端断开连接,当前连接的客户端数量为:{}", count);
     }
 
-    @OnMessage
-    public void sendMessage(String message) throws IOException {
+    private Map<String, Lock> sessionLocks = new ConcurrentHashMap<>();
+
+    private Lock getSessionLock(String sessionId) {
+        sessionLocks.putIfAbsent(sessionId, new ReentrantLock());
+        return sessionLocks.get(sessionId);
+    }
+
+    public void sendMessageToAll(String message) {
         Map<String, Object> map = jsonToMap(message);
-        if(map.get("pid").equals("00000001")){
+        if (map.get("pid").equals("00000001")) {
             System.out.println(message);
         }
         try {
-            if (session.isOpen()) {
-                this.session.getBasicRemote().sendText(message);
-            }
-        } catch (IOException e) {
-            throw new IOException("消息发送失败", e);
-        }
-    }
-
-    public void sendMessageToAll(String message) throws IOException {
-        for (WsServer wsServer : wsServers) {
-            wsServer.sendMessage(message);
-        }
-    }
-
-
-    @PostMapping("/send2AllC")
-    public void sendMessageToAll1(@RequestBody String message) throws  IOException {
-        CopyOnWriteArrayList<WsServer> ws = wsServers;
-        for (WsServer wsServer : ws){
-            wsServer.sendMessage(message);
+            wsServers.forEach(ws -> {
+                threadPoolTaskExecutor.execute(() -> {
+                    Session session = ws.session;
+                    if (session != null && session.isOpen()) {
+                        Lock sessionLock = getSessionLock(session.getId());
+                        sessionLock.lock();
+                        try {
+                            synchronized (session){
+                                session.getAsyncRemote().sendText(message);
+                            }
+                        } catch (Exception e) {
+                            log.error("发送消息时出现异常: " + e.getMessage());
+                        } finally {
+                            sessionLock.unlock();
+                        }
+                    } else {
+                        log.error("会话不存在或已关闭,无法发送消息");
+                    }
+                });
+            });
+        } catch (Exception e) {
+            log.error("发送消息时出现异常: " + e.getMessage());
         }
     }
 
     public static Map<String, Object> jsonToMap(String json) {
         Gson gson = new Gson();
-        Type type = new TypeToken<Map<String, Object>>(){}.getType();
+        Type type = new TypeToken<Map<String, Object>>() {
+        }.getType();
         return gson.fromJson(json, type);
     }
-
 }
-

--
Gitblit v1.9.3