package org.example.websocket.server; import cn.hutool.core.date.DateUnit; import cn.hutool.core.date.DateUtil; import cn.hutool.core.io.unit.DataUnit; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.core.toolkit.StringUtils; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.example.dao.DataServiceKeyMapper; import org.example.pojo.DataServiceKey; import org.example.util.ApplicationContextRegisterUtil; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.scheduling.annotation.Scheduled; import org.springframework.stereotype.Component; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import javax.websocket.*; import javax.websocket.server.PathParam; import javax.websocket.server.ServerEndpoint; import java.io.IOException; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; /** * @ClassDescription: websocket服务端 * @JdkVersion: 1.8 * @Created: 2023/8/31 14:59 */ @Slf4j @Component @ServerEndpoint("/websocket-server") public class WsServer { private Session session; /** * 记录在线连接客户端数量 */ private static AtomicInteger onlineCount = new AtomicInteger(0); /** * 存放每个连接进来的客户端对应的websocketServer对象,用于后面群发消息 */ private static ConcurrentHashMap webSocketMap = new ConcurrentHashMap<>(); private final Lock lock = new ReentrantLock(); /** * 关闭过期链接 */ @Scheduled(cron = "0 0/1 * * * ?") public void clearExpiration() { if (lock.tryLock()) { log.info("webSocket关闭过期链接--------->开始"); try { ApplicationContext act = ApplicationContextRegisterUtil.getApplicationContext(); DataServiceKeyMapper mapper = act.getBean(DataServiceKeyMapper.class); LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(); List keyList = mapper.selectList(queryWrapper.lt(DataServiceKey::getExpirationTime, DateUtil.date()) .or().eq(DataServiceKey::getIsAvailable, 0)); keyList.forEach(f -> { // 在锁内操作 Session sessions = webSocketMap.get(f.getTokenKey()); if (null != sessions) { try { sendInfo(sessions, 0); f.setIsAvailable(0); mapper.updateById(f); } catch (Exception e) { throw new RuntimeException(e); } } log.info("webSocket关闭过期链接---key:"+f.getTokenKey()); }); } catch (Exception e) { throw new RuntimeException(e); } finally { lock.unlock(); log.info("webSocket关闭过期链接--------->结束"); } } else { log.info("webSocket关闭过期链接--------->上次任务还未执行完成,本次任务忽略"); } } /** * 服务端与客户端连接成功时执行 * * @param session 会话 */ @OnOpen public void onOpen(Session session) throws Exception { try { this.session = session; //查询该用户有没有权限 Map> params = session.getRequestParameterMap(); List funcTypes = params.get("key"); // 取出funcType参数的值 if (null == funcTypes) { sendInfo(session, 1); } String key = funcTypes.get(0); if (StringUtils.isEmpty(key)) { sendInfo(session, 1); } ApplicationContext act = ApplicationContextRegisterUtil.getApplicationContext(); DataServiceKeyMapper mapper = act.getBean(DataServiceKeyMapper.class); DataServiceKey dataServiceKey = mapper.selectOne(new LambdaQueryWrapper() .eq(DataServiceKey::getTokenKey, key) .gt(DataServiceKey::getExpirationTime, DateUtil.date()) .eq(DataServiceKey::getIsAvailable, 1)); if (null != dataServiceKey) { if (webSocketMap.containsKey(key)) { webSocketMap.put(key, session); } else { //接入的客户端+1 onlineCount.incrementAndGet(); webSocketMap.put(key, session); } log.info("与客户端连接成功,当前连接的客户端数量为:{}", onlineCount); } else { sendInfo(session, 1); } } catch (Exception e) { sendInfo(session, 1); log.error("客户端连接错误:" + e.getMessage()); } } /** * 关闭连接 */ public void sendInfo(Session session, int type) throws Exception { if (session.getBasicRemote() != null) { if (type == 1) { session.getBasicRemote().sendText("订阅key不存在!"); } session.close(); } } /** * 收到客户端的消息时执行 * * @param message 消息 * @param session 会话 */ @OnMessage public void onMessage(Session session, String message) { log.info("收到来自客户端的消息,客户端地址:{},消息内容:{}", session.getMessageHandlers(), message); } /** * 连接发生报错时执行 * * @param session 会话 * @param throwable 报错 */ @OnError public void onError(Session session, @NonNull Throwable throwable) { log.error("连接发生报错"); throwable.printStackTrace(); } /** * 连接断开时执行 */ @OnClose public void onClose() { //接入客户端连接数-1 if (webSocketMap.get(this.session) != null) { onlineCount.decrementAndGet(); } //集合中的客户端对象-1 webSocketMap.remove(this); log.info("服务端断开连接,当前连接的客户端数量为:{}", onlineCount.get()); } /** * 群发消息 * * @param message 消息 */ public void sendMessageToAll(String message) { // 发送消息给所有客户端 webSocketMap.values().forEach(session -> { try { session.getBasicRemote().sendText(message); } catch (IOException e) { log.error("发送消息给客户端失败,客户端ID为:{}", session.getId(), e); } }); log.info("推送消息给所有客户端,消息内容为:{}", message); } @PostMapping("/send2AllC") public void sendMessageToAll1(@RequestBody String message) { // 发送消息给所有客户端 sendMessageToAll(message); } }