package org.example.websocket.server;

import cn.hutool.json.JSONUtil;
import com.google.common.reflect.TypeToken;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.example.pojo.MarketDataOut;
import org.example.pojo.bo.WsBo;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Component;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;

import javax.annotation.PostConstruct;
import javax.websocket.*;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.lang.reflect.Type;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors;
import java.util.zip.Deflater;
import java.util.zip.DeflaterOutputStream;

/**
 * WebSocket server endpoint published at {@code /websocket-server}.
 *
 * <p>A client connects, sends one JSON subscription message (mapped onto {@link WsBo}) and then
 * receives market data pushed through {@link #sendMessageToAll(String)}. Each payload is filtered
 * against that client's subscription and throttled to the interval the subscription asks for.
 * A client that has not subscribed within {@link #SUBSCRIPTION_TIMEOUT_SECONDS} seconds is
 * disconnected.
 *
 * <p>All per-session bookkeeping lives in static maps keyed by session id, so it is shared by
 * every instance of this class and is released when the session closes.
 * NOTE(review): this assumes the WebSocket container creates its own endpoint instances in
 * addition to the Spring-managed bean — confirm against the endpoint configurator in use.
 *
 * <p>JDK version: 1.8. Created: 2023/8/31 14:59.
 */
@Slf4j
@Component
@ServerEndpoint("/websocket-server")
public class WsServer {

    /** Seconds a freshly opened session may stay connected without sending a subscription. */
    private static final int SUBSCRIPTION_TIMEOUT_SECONDS = 30;

    /** Parser for incoming market-data payloads. */
    private static final Gson gson = new Gson();

    /** Pretty-printing serializer for outgoing payloads; built once instead of once per message. */
    private static final Gson PRETTY_GSON = new GsonBuilder().setPrettyPrinting().create();

    /** Target type used to parse a payload into a list of {@link MarketDataOut}. */
    private static final Type MARKET_DATA_LIST_TYPE = new TypeToken<List<MarketDataOut>>() {}.getType();

    /** Number of currently open connections. */
    private static final AtomicInteger onlineCount = new AtomicInteger(0);

    /** Endpoint instances that currently own an open connection. */
    private static final CopyOnWriteArraySet<WsServer> wsServers = new CopyOnWriteArraySet<>();

    /** Latest subscription per session id (this is a plain shared map, not thread-local storage). */
    private static final Map<String, WsBo> subscriptions = new ConcurrentHashMap<>();

    /** Pending "no subscription received" timeout task per session id. */
    private static final Map<String, ScheduledFuture<?>> scheduledTasks = new ConcurrentHashMap<>();

    /** Lock per session id; held for the whole throttle-check, filter and send sequence. */
    private static final Map<String, Lock> sessionLocks = new ConcurrentHashMap<>();

    /** Epoch milliseconds of the last push attempt per session id. */
    private static final Map<String, Long> lastMessageTimeMap = new ConcurrentHashMap<>();

    /**
     * Single shared scheduler for subscription timeouts. It is static so that endpoint instances
     * do not each start a pool thread of their own, and daemon so it never blocks JVM shutdown.
     */
    private static final ScheduledExecutorService scheduler =
            Executors.newScheduledThreadPool(1, runnable -> {
                Thread thread = new Thread(runnable, "ws-subscription-timeout");
                thread.setDaemon(true);
                return thread;
            });

    @Autowired
    @Qualifier("threadPoolTaskExecutor")
    private ThreadPoolTaskExecutor threadPoolTaskExecutor;

    /** Session owned by this endpoint instance; assigned in {@link #onOpen(Session)}. */
    private Session session;

    /**
     * Registers the new connection and schedules the subscription-timeout check.
     *
     * @param session the session that was just opened
     */
    @OnOpen
    public void onOpen(Session session) {
        this.session = session;
        int count = onlineCount.incrementAndGet();
        wsServers.add(this);
        log.info("与客户端连接成功,当前连接的客户端数量为:{}", count);

        // After SUBSCRIPTION_TIMEOUT_SECONDS, drop the connection if no subscription has arrived.
        ScheduledFuture<?> timeoutTask = scheduler.schedule(() -> {
            if (!hasReceivedSubscription(session)) {
                closeSession(session, "未及时发送订阅消息");
            }
        }, SUBSCRIPTION_TIMEOUT_SECONDS, TimeUnit.SECONDS);
        scheduledTasks.put(session.getId(), timeoutTask);
    }

    /**
     * Tells whether a subscription has been stored for the given session.
     *
     * @param session the session to look up
     * @return {@code true} when a subscription is stored under the session's id
     */
    private boolean hasReceivedSubscription(Session session) {
        return getWsBoForSession(session.getId()) != null;
    }

    /**
     * Logs a connection error together with its stack trace.
     *
     * @param session   the session the error occurred on
     * @param throwable the reported error
     */
    @OnError
    public void onError(Session session, @NonNull Throwable throwable) {
        // Passing the throwable to the logger replaces printStackTrace(), which bypassed logging.
        log.error("连接发生报错: {}", throwable.getMessage(), throwable);
    }

    /**
     * Unregisters this endpoint, cancels its timeout task and releases its per-session state.
     */
    @OnClose
    public void onClose() {
        int count = onlineCount.decrementAndGet();
        wsServers.remove(this);
        cancelScheduledTasks();
        if (this.session != null) {
            releaseSessionState(this.session.getId());
        }
        log.info("服务端断开连接,当前连接的客户端数量为:{}", count);
    }

    /** Cancels the pending subscription-timeout task of this endpoint's session, if any. */
    private void cancelScheduledTasks() {
        if (this.session == null) {
            return;
        }
        cancelTimeout(this.session.getId());
    }

    /**
     * Removes and cancels the timeout task registered for a session id.
     *
     * @param sessionId id of the session whose timeout should be dropped
     */
    private static void cancelTimeout(String sessionId) {
        ScheduledFuture<?> future = scheduledTasks.remove(sessionId);
        if (future != null) {
            // No interruption: this may run on the scheduler thread itself, when the timeout task
            // closes the session and the close callback lands here.
            future.cancel(false);
        }
    }

    /**
     * Drops the subscription, send lock and last-send timestamp stored for a session id.
     *
     * @param sessionId id of the session that is gone
     */
    private static void releaseSessionState(String sessionId) {
        subscriptions.remove(sessionId);
        sessionLocks.remove(sessionId);
        lastMessageTimeMap.remove(sessionId);
    }

    /**
     * Parses an incoming text message as a subscription and stores it for the session. A message
     * that cannot be parsed is logged and otherwise ignored.
     *
     * @param message raw JSON text sent by the client
     * @param session the session the message arrived on
     * @throws IOException declared for interface compatibility; not thrown by this implementation
     */
    @OnMessage
    public void onMessage(String message, Session session) throws IOException {
        try {
            WsBo bean = JSONUtil.toBean(message, WsBo.class);
            subscriptions.put(session.getId(), bean);
            // The client has subscribed, so the timeout check is no longer needed.
            cancelTimeout(session.getId());
        } catch (Exception e) {
            log.error("客户段订阅消息格式错误", e);
        }
    }

    /**
     * Returns the lock guarding sends to the given session, creating it atomically on first use.
     *
     * @param sessionId id of the session
     * @return the lock for that session
     */
    private Lock getSessionLock(String sessionId) {
        return sessionLocks.computeIfAbsent(sessionId, id -> new ReentrantLock());
    }

    /**
     * Offers a market-data payload to every connected client and blocks until every per-client
     * task has finished. Each client gets its own filtered copy, subject to its throttle interval.
     *
     * @param message JSON array of market data to distribute
     */
    public void sendMessageToAll(String message) {
        // Fall back to the common pool when no executor was injected into this instance.
        Executor executor = threadPoolTaskExecutor != null
                ? threadPoolTaskExecutor
                : ForkJoinPool.commonPool();

        List<CompletableFuture<Void>> futures = new ArrayList<>();
        for (WsServer ws : wsServers) {
            futures.add(CompletableFuture.runAsync(() -> {
                try {
                    sendToOne(ws, message);
                } catch (Exception e) {
                    log.error("处理消息失败: {}", e.getMessage(), e);
                }
            }, executor));
        }

        // Wait for all per-client tasks to complete.
        CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
    }

    /**
     * Sends a payload to one endpoint's session while holding that session's lock.
     *
     * @param ws      endpoint whose session should receive the payload
     * @param message unfiltered JSON payload
     */
    private void sendToOne(WsServer ws, String message) {
        Session target = ws.session;
        if (target == null || !target.isOpen()) {
            log.error("会话不存在或已关闭,无法发送消息");
            return;
        }

        String sessionId = target.getId();
        Lock sessionLock = getSessionLock(sessionId);
        sessionLock.lock();
        try {
            schedulePushMessage(target, message);
        } catch (Exception e) {
            log.error("发送消息时出现异常: {}", e.getMessage(), e);
        } finally {
            sessionLock.unlock();
        }

        // If the session closed while we were sending, the close callback may already have run;
        // release again so the entries re-created above are not left behind.
        if (!target.isOpen()) {
            releaseSessionState(sessionId);
        }
    }

    /**
     * Looks up the stored subscription of a session.
     *
     * @param sessionId id of the session
     * @return the stored subscription, or {@code null} when the session has not subscribed
     */
    private WsBo getWsBoForSession(String sessionId) {
        return subscriptions.get(sessionId);
    }

    /**
     * Pushes a payload to a session if it has subscribed and its throttle interval has elapsed.
     * Filtering is done only when the message is actually going to be sent.
     *
     * @param session target session
     * @param message unfiltered JSON payload
     */
    private void schedulePushMessage(Session session, String message) {
        String sessionId = session.getId();
        WsBo wsBo = getWsBoForSession(sessionId);
        if (wsBo == null) {
            return;
        }

        long currentTime = System.currentTimeMillis();
        long lastMessageTime = lastMessageTimeMap.getOrDefault(sessionId, 0L);
        int time = wsBo.getTime();

        // 1000L forces long arithmetic so a large interval cannot overflow int.
        if (currentTime - lastMessageTime < time * 1000L) {
            log.info("距离上次发送消息时间未达到指定间隔,不发送消息。");
            return;
        }

        pushMessage(session, megFiltration(wsBo, message));
        lastMessageTimeMap.put(sessionId, currentTime);
    }

    /**
     * Filters a market-data payload against one subscription and serializes the result.
     *
     * <p>Filters applied, each only when the subscription sets it: base asset equals the
     * subscribed currency; spread is at least the subscribed spread; sell total price is at
     * least the minimum amount; neither the buying nor the selling platform is in the excluded
     * platform list. Rows whose base asset is in the marker list are flagged via
     * {@code setMarker(true)}.
     *
     * <p>The payload is parsed afresh on every call because marking mutates the parsed rows, so
     * they must not be shared between subscribers.
     *
     * @param wsBo    subscription to filter against
     * @param message JSON array of market data
     * @return pretty-printed JSON array of the rows that passed every filter
     */
    private String megFiltration(WsBo wsBo, String message) {
        List<MarketDataOut> rows = gson.fromJson(message, MARKET_DATA_LIST_TYPE);
        if (rows == null) {
            // Gson yields null for an empty or "null" payload; treat it as no data.
            rows = Collections.emptyList();
        }

        // Currency filter.
        if (null != wsBo.getCurrency()) {
            rows = rows.stream()
                    .filter(data -> wsBo.getCurrency().equals(data.getBaseAsset()))
                    .collect(Collectors.toList());
        }

        // Spread filter.
        if (wsBo.getSpread() > 0) {
            double minSpread = wsBo.getSpread();
            rows = rows.stream()
                    .filter(data -> spreadAtLeast(data.getSpread(), minSpread))
                    .collect(Collectors.toList());
        }

        // Minimum amount filter.
        if (null != wsBo.getMinAmount()) {
            BigDecimal minAmount = new BigDecimal(wsBo.getMinAmount());
            rows = rows.stream()
                    .filter(data -> new BigDecimal(data.getSellTotalPrice()).compareTo(minAmount) >= 0)
                    .collect(Collectors.toList());
        }

        // Excluded platforms.
        if (null != wsBo.getPlatformList()) {
            Set<String> excluded = new HashSet<>(Arrays.asList(wsBo.getPlatformList().split(",")));
            rows = rows.stream()
                    .filter(data -> !excluded.contains(data.getBuyingPlatform())
                            && !excluded.contains(data.getSellPlatform()))
                    .collect(Collectors.toList());
        }

        // Watch-list markers.
        if (null != wsBo.getIsMarker()) {
            Set<String> marked = new HashSet<>(Arrays.asList(wsBo.getIsMarker().split(",")));
            for (MarketDataOut data : rows) {
                if (marked.contains(data.getBaseAsset())) {
                    data.setMarker(true);
                }
            }
        }

        return PRETTY_GSON.toJson(rows);
    }

    /**
     * Tells whether a textual spread is a number at or above a threshold. A missing or
     * non-numeric spread counts as not matching, so one bad row cannot abort the whole push.
     *
     * @param rawSpread spread as text, possibly {@code null}
     * @param minSpread inclusive lower bound
     * @return {@code true} when {@code rawSpread} parses to a value {@code >= minSpread}
     */
    private static boolean spreadAtLeast(String rawSpread, double minSpread) {
        if (rawSpread == null) {
            return false;
        }
        try {
            return Double.parseDouble(rawSpread) >= minSpread;
        } catch (NumberFormatException e) {
            return false;
        }
    }

    /**
     * Sends a text message over the session's blocking remote endpoint if the session is open.
     *
     * @param session target session
     * @param message text to send
     */
    private void pushMessage(Session session, String message) {
        try {
            if (session != null && session.isOpen()) {
                session.getBasicRemote().sendText(message);
            } else {
                log.error("会话不存在或已关闭,无法推送消息");
            }
        } catch (IOException e) {
            log.error("推送消息时出现IO异常: {}", e.getMessage(), e);
        }
    }

    /**
     * Closes a session with close code {@code UNEXPECTED_CONDITION} and unregisters this endpoint.
     *
     * @param session session to close
     * @param reason  reason phrase sent to the client
     */
    private void closeSession(Session session, String reason) {
        try {
            session.close(new CloseReason(CloseReason.CloseCodes.UNEXPECTED_CONDITION, reason));
        } catch (IOException e) {
            log.error("强制断开连接----异常: {}", e.getMessage(), e);
        }
        wsServers.remove(this);
        log.info("客户端未及时发送订阅消息,断开连接");
    }
}