package org.example.websocket.server; import cn.hutool.json.JSONUtil; import com.google.gson.Gson; import com.google.gson.reflect.TypeToken; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.example.pojo.bo.WsBo; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.stereotype.Component; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import javax.annotation.PostConstruct; import javax.websocket.*; import javax.websocket.server.ServerEndpoint; import java.io.IOException; import java.lang.reflect.Type; import java.nio.ByteBuffer; import java.text.SimpleDateFormat; import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.zip.Deflater; import java.util.zip.DeflaterOutputStream; /** * @ClassDescription: websocket服务端 * @JdkVersion: 1.8 * @Created: 2023/8/31 14:59 */ @Slf4j @Component @ServerEndpoint("/websocket-server") public class WsServer { private Session session; private static AtomicInteger onlineCount = new AtomicInteger(0); private static CopyOnWriteArraySet wsServers = new CopyOnWriteArraySet<>(); // 线程局部变量,用于存储每个线程的数据 private static final Map threadLocalData = new ConcurrentHashMap<>(); @Autowired @Qualifier("threadPoolTaskExecutor") private ThreadPoolTaskExecutor threadPoolTaskExecutor; // 定义常量:任务检查的超时时间(秒) private static final int SUBSCRIPTION_TIMEOUT_SECONDS = 30; private ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1); private Map> scheduledTasks = new ConcurrentHashMap<>(); @OnOpen public void onOpen(Session session) { this.session = session; int count = onlineCount.incrementAndGet(); wsServers.add(this); log.info("与客户端连接成功,当前连接的客户端数量为:{}", count); // 设置定时任务,在SUBSCRIPTION_TIMEOUT_SECONDS秒后检查是否收到订阅消息 ScheduledFuture timeoutTask = scheduler.schedule(() -> { if (!hasReceivedSubscription(session)) { closeSession(session, "未及时发送订阅消息"); } }, SUBSCRIPTION_TIMEOUT_SECONDS, TimeUnit.SECONDS); scheduledTasks.put(session.getId(), timeoutTask); } private boolean hasReceivedSubscription(Session session) { WsBo wsBo = getWsBoForSession(session.getId()); return wsBo != null; // 简化逻辑 } @OnError public void onError(Session session, @NonNull Throwable throwable) { log.error("连接发生报错: {}", throwable.getMessage()); throwable.printStackTrace(); } @OnClose public void onClose() { int count = onlineCount.decrementAndGet(); wsServers.remove(this); cancelScheduledTasks(); // 取消定时任务 log.info("服务端断开连接,当前连接的客户端数量为:{}", count); } private void cancelScheduledTasks() { ScheduledFuture future = scheduledTasks.remove(this.session.getId()); if (future != null) { future.cancel(true); // 取消定时任务 } ScheduledFuture task = pushTasks.remove(this.session.getId()); if (task != null) { task.cancel(true); // 取消推送任务 } } @OnMessage public void onMessage(String message, Session session) throws IOException { try { WsBo bean = JSONUtil.toBean(message, WsBo.class); threadLocalData.put(session.getId(), bean); }catch (Exception e){ log.error("客户段订阅消息格式错误"); } } private Map sessionLocks = new ConcurrentHashMap<>(); private Lock getSessionLock(String sessionId) { sessionLocks.putIfAbsent(sessionId, new ReentrantLock()); return sessionLocks.get(sessionId); } public void sendMessageToAll(String message) { List> futures = new ArrayList<>(); wsServers.forEach(ws -> { futures.add(CompletableFuture.runAsync(() -> { try { Session session = ws.session; if (session != null && session.isOpen()) { Lock sessionLock = getSessionLock(session.getId()); sessionLock.lock(); try { WsBo wsBo = getWsBoForSession(session.getId()); if (wsBo != null) { int intervalSeconds = wsBo.getTime(); pushMessageWithInterval(session, message, intervalSeconds); } } catch (Exception e) { log.error("发送消息时出现异常: {}", e.getMessage()); } finally { sessionLock.unlock(); } } else { log.error("会话不存在或已关闭,无法发送消息"); } } catch (Exception e) { log.error("处理消息失败: {}", e.getMessage()); } }, threadPoolTaskExecutor)); }); // 等待所有任务执行完成 CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); } private WsBo getWsBoForSession(String sessionId) { return threadLocalData.get(sessionId); } private Map> pushTasks = new ConcurrentHashMap<>(); ScheduledExecutorService pushScheduler = Executors.newScheduledThreadPool(1); private void pushMessageWithInterval(Session session, String message, int intervalSeconds) { // 创建一个可以控制任务状态的 AtomicBoolean intervalSeconds = 5; AtomicBoolean isActive = new AtomicBoolean(true); ScheduledFuture future = pushScheduler.scheduleAtFixedRate(() -> { if (isActive.get()) { // 检查是否应该继续执行 try { pushMessage(session, message); } catch (Exception e) { log.error("推送消息时出现异常: {}", e.getMessage()); isActive.set(false); // 出现异常则停止任务 } } }, 0, intervalSeconds, TimeUnit.SECONDS); // 保存任务的引用,以便后续取消 pushTasks.put(session.getId(), future); } public void pushMessage(Session session, String message) throws IOException { session.getBasicRemote().sendText(message); } // 关闭会话的方法 private void closeSession(Session session, String reason) { try { session.close(new CloseReason(CloseReason.CloseCodes.UNEXPECTED_CONDITION, reason)); } catch (IOException e) { log.error("强制断开连接----异常: {}", e.getMessage()); } wsServers.remove(this); log.info("客户端未及时发送订阅消息,断开连接"); } }