| | |
| | | package org.example.websocket.server; |
| | | |
import cn.hutool.json.JSONUtil;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.example.pojo.bo.WsBo;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Component;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;

import javax.annotation.PostConstruct;
import javax.websocket.*;
import javax.websocket.server.ServerEndpoint;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.zip.Deflater;
import java.util.zip.DeflaterOutputStream;
| | |
| | | private Session session; |
| | | private static AtomicInteger onlineCount = new AtomicInteger(0); |
| | | private static CopyOnWriteArraySet<WsServer> wsServers = new CopyOnWriteArraySet<>(); |
| | | // 线程局部变量,用于存储每个线程的数据 |
| | | private static final Map<String, WsBo> threadLocalData = new ConcurrentHashMap<>(); |
| | | |
| | | @Autowired |
| | | @Qualifier("threadPoolTaskExecutor") |
| | | private ThreadPoolTaskExecutor threadPoolTaskExecutor; |
| | | |
| | | // 定义常量:任务检查的超时时间(秒) |
| | | private static final int SUBSCRIPTION_TIMEOUT_SECONDS = 30; |
| | | |
| | | private ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1); |
| | | private Map<String, ScheduledFuture<?>> scheduledTasks = new ConcurrentHashMap<>(); |
| | | |
| | | @OnOpen |
| | | public void onOpen(Session session) { |
| | |
| | | int count = onlineCount.incrementAndGet(); |
| | | wsServers.add(this); |
| | | log.info("与客户端连接成功,当前连接的客户端数量为:{}", count); |
| | | |
| | | // 设置定时任务,在SUBSCRIPTION_TIMEOUT_SECONDS秒后检查是否收到订阅消息 |
| | | ScheduledFuture<?> timeoutTask = scheduler.schedule(() -> { |
| | | if (!hasReceivedSubscription(session)) { |
| | | closeSession(session, "未及时发送订阅消息"); |
| | | } |
| | | }, SUBSCRIPTION_TIMEOUT_SECONDS, TimeUnit.SECONDS); |
| | | scheduledTasks.put(session.getId(), timeoutTask); |
| | | } |
| | | |
| | | private boolean hasReceivedSubscription(Session session) { |
| | | WsBo wsBo = getWsBoForSession(session.getId()); |
| | | return wsBo != null; // 简化逻辑 |
| | | } |
| | | |
| | | @OnError |
| | | public void onError(Session session, @NonNull Throwable throwable) { |
| | | log.error("连接发生报错"); |
| | | log.error("连接发生报错: {}", throwable.getMessage()); |
| | | throwable.printStackTrace(); |
| | | } |
| | | |
| | |
| | | public void onClose() { |
| | | int count = onlineCount.decrementAndGet(); |
| | | wsServers.remove(this); |
| | | cancelScheduledTasks(); // 取消定时任务 |
| | | log.info("服务端断开连接,当前连接的客户端数量为:{}", count); |
| | | } |
| | | |
| | | private void cancelScheduledTasks() { |
| | | ScheduledFuture<?> future = scheduledTasks.remove(this.session.getId()); |
| | | if (future != null) { |
| | | future.cancel(true); // 取消定时任务 |
| | | } |
| | | ScheduledFuture<?> task = pushTasks.remove(this.session.getId()); |
| | | if (task != null) { |
| | | task.cancel(true); // 取消推送任务 |
| | | } |
| | | } |
| | | |
| | | @OnMessage |
| | | public void onMessage(String message, Session session) throws IOException { |
| | | try { |
| | | WsBo bean = JSONUtil.toBean(message, WsBo.class); |
| | | threadLocalData.put(session.getId(), bean); |
| | | }catch (Exception e){ |
| | | log.error("客户段订阅消息格式错误"); |
| | | } |
| | | } |
| | | |
| | | private Map<String, Lock> sessionLocks = new ConcurrentHashMap<>(); |
| | |
| | | } |
| | | |
| | | public void sendMessageToAll(String message) { |
| | | // Map<String, Object> map = jsonToMap(message); |
| | | // if (map.get("pid").equals("00000001")) { |
| | | // System.out.println(message); |
| | | // } |
| | | try { |
| | | List<Future<?>> futures = new ArrayList<>(); |
| | | wsServers.forEach(ws -> { |
| | | Future<?> future = threadPoolTaskExecutor.submit(() -> { |
| | | List<CompletableFuture<Void>> futures = new ArrayList<>(); |
| | | wsServers.forEach(ws -> { |
| | | futures.add(CompletableFuture.runAsync(() -> { |
| | | try { |
| | | Session session = ws.session; |
| | | if (session != null && session.isOpen()) { |
| | | Lock sessionLock = getSessionLock(session.getId()); |
| | | sessionLock.lock(); |
| | | try { |
| | | synchronized (session){ |
| | | // 压缩消息 |
| | | byte[] compressedData = compress(message); |
| | | |
| | | // 发送压缩后的消息 |
| | | session.getBasicRemote().sendBinary(ByteBuffer.wrap(compressedData)); |
| | | // session.getBasicRemote().sendText(message); |
| | | WsBo wsBo = getWsBoForSession(session.getId()); |
| | | if (wsBo != null) { |
| | | int intervalSeconds = wsBo.getTime(); |
| | | pushMessageWithInterval(session, message, intervalSeconds); |
| | | } |
| | | } catch (Exception e) { |
| | | log.error("发送消息时出现异常: " + e.getMessage()); |
| | | log.error("发送消息时出现异常: {}", e.getMessage()); |
| | | } finally { |
| | | sessionLock.unlock(); |
| | | } |
| | | } else { |
| | | log.error("会话不存在或已关闭,无法发送消息"); |
| | | } |
| | | }); |
| | | futures.add(future); |
| | | }); |
| | | } catch (Exception e) { |
| | | log.error("处理消息失败: {}", e.getMessage()); |
| | | } |
| | | }, threadPoolTaskExecutor)); |
| | | }); |
| | | |
| | | //等待所有任务执行完成 |
| | | for (Future<?> future : futures) { |
| | | // 等待所有任务执行完成 |
| | | CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); |
| | | } |
| | | |
| | | private WsBo getWsBoForSession(String sessionId) { |
| | | return threadLocalData.get(sessionId); |
| | | } |
| | | |
| | | private Map<String, ScheduledFuture<?>> pushTasks = new ConcurrentHashMap<>(); |
| | | ScheduledExecutorService pushScheduler = Executors.newScheduledThreadPool(1); |
| | | |
| | | private void pushMessageWithInterval(Session session, String message, int intervalSeconds) { |
| | | // 创建一个可以控制任务状态的 AtomicBoolean |
| | | intervalSeconds = 5; |
| | | AtomicBoolean isActive = new AtomicBoolean(true); |
| | | ScheduledFuture<?> future = pushScheduler.scheduleAtFixedRate(() -> { |
| | | if (isActive.get()) { // 检查是否应该继续执行 |
| | | try { |
| | | future.get(); |
| | | } catch (InterruptedException | ExecutionException e) { |
| | | log.error("发送消息时出现异常: " + e.getMessage()); |
| | | pushMessage(session, message); |
| | | } catch (Exception e) { |
| | | log.error("推送消息时出现异常: {}", e.getMessage()); |
| | | isActive.set(false); // 出现异常则停止任务 |
| | | } |
| | | } |
| | | } catch (Exception e) { |
| | | log.error("发送消息时出现异常: " + e.getMessage()); |
| | | } |
| | | }, 0, intervalSeconds, TimeUnit.SECONDS); |
| | | |
| | | // 保存任务的引用,以便后续取消 |
| | | pushTasks.put(session.getId(), future); |
| | | } |
| | | |
| | | private byte[] compress(String data) throws IOException { |
| | | ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); |
| | | Deflater deflater = new Deflater(Deflater.BEST_COMPRESSION, true); |
| | | try (DeflaterOutputStream deflaterOutputStream = new DeflaterOutputStream(byteArrayOutputStream, deflater)) { |
| | | deflaterOutputStream.write(data.getBytes("UTF-8")); |
| | | } |
| | | deflater.end(); |
| | | return byteArrayOutputStream.toByteArray(); |
| | | public void pushMessage(Session session, String message) throws IOException { |
| | | session.getBasicRemote().sendText(message); |
| | | } |
| | | |
| | | // |
| | | // public static Map<String, Object> jsonToMap(String json) { |
| | | // Gson gson = new Gson(); |
| | | // Type type = new TypeToken<Map<String, Object>>() { |
| | | // }.getType(); |
| | | // return gson.fromJson(json, type); |
| | | // } |
| | | // 关闭会话的方法 |
| | | private void closeSession(Session session, String reason) { |
| | | try { |
| | | session.close(new CloseReason(CloseReason.CloseCodes.UNEXPECTED_CONDITION, reason)); |
| | | } catch (IOException e) { |
| | | log.error("强制断开连接----异常: {}", e.getMessage()); |
| | | } |
| | | wsServers.remove(this); |
| | | log.info("客户端未及时发送订阅消息,断开连接"); |
| | | } |
| | | } |