package org.example.websocket.server; import cn.hutool.json.JSONUtil; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.module.SimpleModule; import com.google.common.reflect.TypeToken; import com.google.gson.Gson; import com.google.gson.GsonBuilder; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.example.pojo.ConfigCurrency; import org.example.pojo.MarketDataOut; import org.example.pojo.bo.WsBo; import org.example.util.RedisUtil; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; import javax.websocket.*; import javax.websocket.server.ServerEndpoint; import java.io.IOException; import java.math.BigDecimal; import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.stream.Collectors; /** * @ClassDescription: websocket服务端 * @JdkVersion: 1.8 * @Created: 2023/8/31 14:59 */ @Slf4j @Component @ServerEndpoint("/websocket-server") public class WsServer { private Session session; private static AtomicInteger onlineCount = new AtomicInteger(0); private static CopyOnWriteArraySet wsServers = new CopyOnWriteArraySet<>(); // 线程局部变量,用于存储每个线程的数据 private static final Map threadLocalData = new ConcurrentHashMap<>(); @Autowired @Qualifier("threadPoolTaskExecutor") private ThreadPoolTaskExecutor threadPoolTaskExecutor; // 定义常量:任务检查的超时时间(秒) private static final int SUBSCRIPTION_TIMEOUT_SECONDS = 30; private ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1); private Map> scheduledTasks = new ConcurrentHashMap<>(); @OnOpen public void onOpen(Session session) { this.session = session; int count = onlineCount.incrementAndGet(); wsServers.add(this); log.info("与客户端连接成功,当前连接的客户端数量为:{}", count); // 设置定时任务,在SUBSCRIPTION_TIMEOUT_SECONDS秒后检查是否收到订阅消息 ScheduledFuture timeoutTask = scheduler.schedule(() -> { if (!hasReceivedSubscription(session)) { closeSession(session, "未及时发送订阅消息"); } }, SUBSCRIPTION_TIMEOUT_SECONDS, TimeUnit.SECONDS); scheduledTasks.put(session.getId(), timeoutTask); } private boolean hasReceivedSubscription(Session session) { WsBo wsBo = getWsBoForSession(session.getId()); String s = RedisUtil.get(wsBo.getUserId().toString()); if(StringUtils.isEmpty(s) || !wsBo.getToken().equals(s)){ closeSession(session, "用户未登录"); Map map = new HashMap<>(); map.put("status",1); pushMessage(session,JSONUtil.toJsonStr(map)); return false; } return wsBo != null; } @OnError public void onError(Session session, @NonNull Throwable throwable) { onClose(); log.error("连接发生报错: {}", throwable.getMessage()); throwable.printStackTrace(); } @OnClose public void onClose() { int count = onlineCount.decrementAndGet(); threadLocalData.remove(session.getId()); wsServers.remove(this); cancelScheduledTasks(); // 取消定时任务 log.info("服务端断开连接,当前连接的客户端数量为:{}", count); } private void cancelScheduledTasks() { ScheduledFuture future = scheduledTasks.remove(this.session.getId()); if (future != null) { future.cancel(true); // 取消定时任务 } } @OnMessage public void onMessage(String message, Session session) throws IOException { try { if(!message.equals("ping")){ WsBo bean = JSONUtil.toBean(message, WsBo.class); if(null == bean){ log.error("没有订阅消息"); closeSession(session,null); } String s = RedisUtil.get(bean.getUserId().toString()); if(StringUtils.isEmpty(s)){ log.error("未登录"); Map map = new HashMap<>(); map.put("status",1); pushMessage(session,JSONUtil.toJsonStr(map)); closeSession(session,null); } threadLocalData.put(session.getId(), bean); } }catch (Exception e){ log.error("客户段订阅消息格式错误"); } } private Map sessionLocks = new ConcurrentHashMap<>(); private Lock getSessionLock(String sessionId) { sessionLocks.putIfAbsent(sessionId, new ReentrantLock()); return sessionLocks.get(sessionId); } public void sendMessageToAll(String message) { List> futures = new ArrayList<>(); wsServers.forEach(ws -> { futures.add(CompletableFuture.runAsync(() -> { try { Session session = ws.session; if (session != null && session.isOpen()) { Lock sessionLock = getSessionLock(session.getId()); sessionLock.lock(); try { schedulePushMessage(session, message); } catch (Exception e) { e.printStackTrace(); closeSession(session, "发送消息异常,断开链接"); log.error("发送消息时出现异常: {}", e.getMessage()); } finally { sessionLock.unlock(); } } else { closeSession(session, "会话不存在或已关闭"); log.error("会话不存在或已关闭,无法发送消息"); } } catch (Exception e) { log.error("处理消息失败: {}", e.getMessage()); } }, threadPoolTaskExecutor)); }); // 等待所有任务执行完成 CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); } private WsBo getWsBoForSession(String sessionId) { return threadLocalData.get(sessionId); } private Map lastMessageTimeMap = new ConcurrentHashMap<>(); private void schedulePushMessage(Session session, String message) throws JsonProcessingException { WsBo wsBo = getWsBoForSession(session.getId()); if (wsBo != null) { String s = RedisUtil.get(wsBo.getUserId().toString()); if(StringUtils.isEmpty(s) || !s.equals(wsBo.getToken())){ Map map = new HashMap<>(); map.put("status",1); pushMessage(session,JSONUtil.toJsonStr(map)); closeSession(session,"登录状态失效"); } long currentTime = System.currentTimeMillis(); long lastMessageTime = lastMessageTimeMap.getOrDefault(session, 0L); int time = wsBo.getTime(); if(wsBo.getPushNow()){ message = megFiltration(wsBo,message); pushMessage(session, message); lastMessageTimeMap.put(session, currentTime); // 更新最后发送时间 wsBo.setPushNow(false); }else{ message = megFiltration(wsBo,message); if (currentTime - lastMessageTime >= time * 1000) { // 时间间隔达到要求,可以发送消息 pushMessage(session, message); lastMessageTimeMap.put(session, currentTime); // 更新最后发送时间 } } } } private static final Gson gson = new Gson(); private String megFiltration(WsBo wsBo,String message) throws JsonProcessingException { List redisValueMap = gson.fromJson(message, new TypeToken>() {}.getType()); Map map = new HashMap<>(); String key = "config_"; String value = RedisUtil.get(key + wsBo.getUserId()); List currencies = null; if(null != value && !value.isEmpty()){ ObjectMapper objectMapper = new ObjectMapper(); currencies = objectMapper.readValue(value, new TypeReference>() {}); } if (!CollectionUtils.isEmpty(currencies)) { Set filtrationSet = currencies.stream() .map(f -> f.getCurrency() + f.getBuy() + f.getSell()) .collect(Collectors.toSet()); redisValueMap.removeIf(data -> filtrationSet.contains(data.getBuyAndSell())); } //查询币种 if(StringUtils.isNotEmpty(wsBo.getCurrency())){ redisValueMap = redisValueMap.stream() .filter(data -> wsBo.getCurrency().equals(data.getBaseAsset())) .collect(Collectors.toList()); } //价差 redisValueMap = redisValueMap.stream() .filter(data -> Double.parseDouble(data.getSpread()) >= wsBo.getSpread() && Double.parseDouble(data.getSpread()) <= 1000) .collect(Collectors.toList()); //最低金额 if(null != wsBo.getMinAmount()){ redisValueMap = redisValueMap.stream() .filter(data -> new BigDecimal(data.getSellTotalPrice()).compareTo(new BigDecimal(wsBo.getMinAmount())) >= 0 ) .collect(Collectors.toList()); } //过滤平台 if(null != wsBo.getPlatformList()){ List list = Arrays.asList(wsBo.getPlatformList().split(",")); redisValueMap = redisValueMap.stream() .filter(data -> !list.contains(data.getBuyingPlatform()) && !list.contains(data.getSellPlatform())) .collect(Collectors.toList()); } //过滤数据 if(null != wsBo.getBuyAndSell()){ List list = Arrays.asList(wsBo.getBuyAndSell().split(",")); redisValueMap = redisValueMap.stream() .filter(data -> !list.contains(data.getBuyAndSell())) .collect(Collectors.toList()); } //自选标记 String mark = RedisUtil.get(wsBo.getUserId() + "_mark"); if(StringUtils.isNoneEmpty(mark)){ List list = Arrays.asList(mark.split(",")); redisValueMap.stream() .filter(data -> list.contains(data.getBuyAndSell())) .forEach(data -> data.setMarker(true)); } map.put("uuid",wsBo.getUuid()); map.put("current",wsBo.getCurrent()); map.put("sizes",wsBo.getSizes()); map.put("total",redisValueMap.size()); sortBySpread(redisValueMap); Integer current = 0; if(wsBo.getCurrent() != 1){ current = (wsBo.getCurrent() - 1) * wsBo.getSizes(); } // 确保 startIndex 在有效范围内 current = Math.min(current, redisValueMap.size()); // 计算子列表的结束索引 int endIndex = Math.min(current + wsBo.getSizes(), redisValueMap.size()); // 根据计算出的索引获取子列表 redisValueMap = redisValueMap.subList(current, endIndex); map.put("data",redisValueMap); Gson gson = new GsonBuilder().setPrettyPrinting().create(); String json = gson.toJson(map); return json; } public static void sortBySpread(List marketDataList) { Collections.sort(marketDataList, new Comparator() { @Override public int compare(MarketDataOut a, MarketDataOut b) { // 将买入总价和卖出总价转换为 BigDecimal BigDecimal buyTotalPriceA = new BigDecimal(a.getBuyTotalPrice()); BigDecimal buyTotalPriceB = new BigDecimal(b.getBuyTotalPrice()); BigDecimal sellTotalPriceA = new BigDecimal(a.getSellTotalPrice()); BigDecimal sellTotalPriceB = new BigDecimal(b.getSellTotalPrice()); BigDecimal spreadA = new BigDecimal(a.getSpread()); BigDecimal spreadB = new BigDecimal(b.getSpread()); // 检查 a 和 b 的买入总价或卖出总价是否大于 1000 boolean aBuyOrSellAbove1000 = buyTotalPriceA.compareTo(new BigDecimal("1000")) > 0 || sellTotalPriceA.compareTo(new BigDecimal("1000")) > 0; boolean bBuyOrSellAbove1000 = buyTotalPriceB.compareTo(new BigDecimal("1000")) > 0 || sellTotalPriceB.compareTo(new BigDecimal("1000")) > 0; if (aBuyOrSellAbove1000 && !bBuyOrSellAbove1000) { return -1; // a 应排在 b 前面 } else if (!aBuyOrSellAbove1000 && bBuyOrSellAbove1000) { return 1; // b 应排在 a 前面 } else { // a 和 b 都大于 1000 或都小于等于 1000 // 先按照 spread 的降序排列 int spreadComparison = spreadB.compareTo(spreadA); if (spreadComparison != 0) { return spreadComparison; // 先按照 spread 排序 } else { // 如果 spread 相同,再按照买入总价或卖出总价的升序排列 if (aBuyOrSellAbove1000) { // 对于买入总价或卖出总价大于 1000 的记录 return buyTotalPriceA.compareTo(buyTotalPriceB); // 按照买入总价升序排序 } else { // 对于买入总价或卖出总价小于等于 1000 的记录 return buyTotalPriceA.compareTo(buyTotalPriceB); // 可以选择卖出总价或买入总价的升序排序 } } } } }); } private void pushMessage(Session session, String message) { try { if (session != null && session.isOpen()) { session.getBasicRemote().sendText(message); } else { log.error("会话不存在或已关闭,无法推送消息"); } } catch (IOException e) { log.error("推送消息时出现IO异常: {}", e.getMessage()); } } // 关闭会话的方法 private void closeSession(Session session, String reason) { threadLocalData.remove(session.getId()); wsServers.remove(this); log.info(reason); onClose(); } }