From 3a4ca606fad5d286e8b0de99f39ffbea8ef3cc21 Mon Sep 17 00:00:00 2001
From: zj <1772600164@qq.com>
Date: Mon, 21 Oct 2024 10:33:30 +0800
Subject: [PATCH] 1

---
 websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java |  155 +++++++++++++++++++++++++++++++++++++++++----------
 1 files changed, 123 insertions(+), 32 deletions(-)

diff --git a/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java b/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java
index a6c451f..6808b3c 100644
--- a/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java
+++ b/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java
@@ -17,10 +17,13 @@
 import org.example.util.RedisUtil;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Qualifier;
+import org.springframework.context.annotation.Bean;
+import org.springframework.scheduling.annotation.Async;
 import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
 import org.springframework.stereotype.Component;
 import org.springframework.util.CollectionUtils;
 
+import javax.annotation.PreDestroy;
 import javax.websocket.*;
 import javax.websocket.server.ServerEndpoint;
 import java.io.IOException;
@@ -49,7 +52,7 @@
     private static final Map<String, WsBo> threadLocalData = new ConcurrentHashMap<>();
 
     @Autowired
-    @Qualifier("threadPoolTaskExecutor")
+    @Qualifier("markthreadPoolTaskExecutor")
     private ThreadPoolTaskExecutor threadPoolTaskExecutor;
 
     // 定义常量:任务检查的超时时间(秒)
@@ -142,36 +145,114 @@
         return sessionLocks.get(sessionId);
     }
 
+//    public void sendMessageToAll(String message) {
+//        List<CompletableFuture<Void>> futures = new ArrayList<>();
+//        wsServers.forEach(ws -> {
+//            futures.add(CompletableFuture.runAsync(() -> {
+//                try {
+//                    Session session = ws.session;
+//                    if (session != null && session.isOpen()) {
+//                        Lock sessionLock = getSessionLock(session.getId());
+//                        sessionLock.lock();
+//                        try {
+//                            schedulePushMessage(session, message);
+//                        } catch (Exception e) {
+//                            e.printStackTrace();
+//                            closeSession(session, "发送消息异常,断开链接");
+//                            log.error("发送消息时出现异常: {}", e.getMessage());
+//                        } finally {
+//                            sessionLock.unlock();
+//                        }
+//                    } else {
+//                        closeSession(session, "会话不存在或已关闭");
+//                        log.error("会话不存在或已关闭,无法发送消息");
+//                    }
+//                } catch (Exception e) {
+//                    log.error("处理消息失败: {}", e.getMessage());
+//                }
+//            }, threadPoolTaskExecutor));
+//        });
+//
+//        // 等待所有任务执行完成
+//        CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
+//    }
     public void sendMessageToAll(String message) {
-        List<CompletableFuture<Void>> futures = new ArrayList<>();
+        List<Future<Void>> futures = new ArrayList<>();
+        ExecutorService executorService = Executors.newFixedThreadPool(100); // 使用固定大小的线程池
+
+        // 收集所有活动的会话
+        List<Session> activeSessions = new ArrayList<>();
         wsServers.forEach(ws -> {
-            futures.add(CompletableFuture.runAsync(() -> {
-                try {
-                    Session session = ws.session;
-                    if (session != null && session.isOpen()) {
-                        Lock sessionLock = getSessionLock(session.getId());
-                        sessionLock.lock();
-                        try {
-                            schedulePushMessage(session, message);
-                        } catch (Exception e) {
-                            e.printStackTrace();
-                            closeSession(session, "发送消息异常,断开链接");
-                            log.error("发送消息时出现异常: {}", e.getMessage());
-                        } finally {
-                            sessionLock.unlock();
-                        }
-                    } else {
-                        closeSession(session, "会话不存在或已关闭");
-                        log.error("会话不存在或已关闭,无法发送消息");
-                    }
-                } catch (Exception e) {
-                    log.error("处理消息失败: {}", e.getMessage());
-                }
-            }, threadPoolTaskExecutor));
+            Session session = ws.session;
+            if (session != null && session.isOpen()) {
+                activeSessions.add(session);
+            } else {
+                closeSession(session, "会话不存在或已关闭");
+                log.error("会话不存在或已关闭,无法发送消息");
+            }
         });
 
-        // 等待所有任务执行完成
-        CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
+        // 并发处理所有活动的会话
+        for (Session session : activeSessions) {
+            Future<Void> future = executorService.submit(() -> {
+                Lock sessionLock = getSessionLock(session.getId());
+                try {
+                    if (sessionLock.tryLock(100, TimeUnit.MILLISECONDS)) {
+                        try {
+                            schedulePushMessage(session, message);
+                        } finally {
+                            sessionLock.unlock(); // 确保锁只在这里释放
+                        }
+                    } else {
+                        log.error("无法获取锁,放弃对会话 {} 的操作", session.getId());
+                    }
+                } catch (InterruptedException e) {
+                    Thread.currentThread().interrupt(); // 重新设置中断状态
+                    log.error("线程被中断: {}", e.getMessage());
+                } catch (JsonProcessingException e) {
+                    log.error("JSON处理异常: {}", e.getMessage());
+                }
+                return null; // 需要返回一个值,Void 类型的
+            });
+
+            futures.add(future);
+        }
+
+        // 等待所有任务完成或超时
+        for (Future<Void> future : futures) {
+            try {
+                future.get(60, TimeUnit.SECONDS); // 指定超时时间
+            } catch (TimeoutException e) {
+                log.error("某个任务超时,可能未能完成: {}", e.getMessage());
+                // 这里可以选择是否取消该任务
+                future.cancel(true); // 取消任务,设置为 true 则会中断正在执行的线程
+            } catch (InterruptedException e) {
+                Thread.currentThread().interrupt(); // 重新设置中断状态
+                log.error("线程被中断: {}", e.getMessage());
+            } catch (ExecutionException e) {
+                log.error("处理会话时发生异常: {}", e.getCause().getMessage());
+            }
+        }
+
+        // 关闭线程池
+        executorService.shutdown();
+    }
+
+
+    @PreDestroy
+    public void shutdownExecutor() {
+        threadPoolTaskExecutor.shutdown(); // 先关闭线程池
+        try {
+            // 等待现有任务在指定时间内完成
+            if (!threadPoolTaskExecutor.getThreadPoolExecutor().awaitTermination(60, TimeUnit.SECONDS)) {
+                // 如果未能完成,则强制关闭
+                threadPoolTaskExecutor.getThreadPoolExecutor().shutdownNow();
+            }
+        } catch (InterruptedException e) {
+            // 如果当前线程被中断,则强制关闭
+            threadPoolTaskExecutor.getThreadPoolExecutor().shutdownNow();
+            Thread.currentThread().interrupt(); // 还原中断状态
+        }
     }
 
     private WsBo getWsBoForSession(String sessionId) {
@@ -213,6 +294,9 @@
             }
         }
     }
+
+
+
     private static final Gson gson = new Gson();
     private String megFiltration(WsBo wsBo,String message) throws JsonProcessingException {
         List<MarketDataOut> redisValueMap = gson.fromJson(message, new TypeToken<List<MarketDataOut>>() {}.getType());
@@ -229,7 +313,8 @@
                     .map(f -> f.getCurrency() + f.getBuy() + f.getSell()) //组合过滤 ,暂时不使用,直接过滤整个币种
 //                    .map(f -> f.getCurrency())
                     .collect(Collectors.toSet());
-            redisValueMap.removeIf(data -> filtrationSet.contains(data.getBaseAsset()));
+//            redisValueMap.removeIf(data -> filtrationSet.contains(data.getBaseAsset()));
+            redisValueMap.removeIf(data -> filtrationSet.contains(data.getBuyAndSell()));
         }
 
 
@@ -347,10 +432,16 @@
         }
     }
 
-    // 关闭会话的方法
     private void closeSession(Session session, String reason) {
-        wsServers.remove(this);
-        log.info(reason);
-        onClose();
+        try {
+            if (session != null && session.isOpen()) {
+                wsServers.remove(this);
+                log.info(reason);
+                session.close(new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, reason));
+            }
+        } catch (IOException e) {
+            log.error("关闭会话时出现异常: {}", e.getMessage());
+        }
     }
+
 }

--
Gitblit v1.9.3