1
zj
2024-04-02 4c155dc4b01f39c4c2bd06885ea202f77912a91f
websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java
@@ -1,15 +1,35 @@
package org.example.websocket.server;
import cn.hutool.core.date.DateUnit;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.io.unit.DataUnit;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.example.dao.DataServiceKeyMapper;
import org.example.pojo.DataServiceKey;
import org.example.util.ApplicationContextRegisterUtil;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
/**
 * @ClassDescription: websocket服务端
@@ -18,12 +38,11 @@
 */
@Slf4j
@Component
//@RestController
@ServerEndpoint("/websocket-server")
//@ServerEndpoint("/")
public class WsServer {
    private Session session;
    /**
     * 记录在线连接客户端数量
     */
@@ -31,41 +50,127 @@
    /**
     * 存放每个连接进来的客户端对应的websocketServer对象,用于后面群发消息
     */
    private static CopyOnWriteArrayList<WsServer> wsServers = new CopyOnWriteArrayList<>();
    private static ConcurrentHashMap<String, Session> webSocketMap = new ConcurrentHashMap<>();
    private final Lock lock = new ReentrantLock();
    /**
     * 关闭过期链接
     */
    @Scheduled(cron = "0  0/1  *  *  *  ?")
    public void clearExpiration() {
        if (lock.tryLock()) {
            log.info("webSocket关闭过期链接--------->开始");
            try {
                ApplicationContext act = ApplicationContextRegisterUtil.getApplicationContext();
                DataServiceKeyMapper mapper = act.getBean(DataServiceKeyMapper.class);
                LambdaQueryWrapper<DataServiceKey> queryWrapper = new LambdaQueryWrapper<>();
                List<DataServiceKey> keyList = mapper.selectList(queryWrapper.lt(DataServiceKey::getExpirationTime, DateUtil.date())
                        .or().eq(DataServiceKey::getIsAvailable, 0));
                keyList.forEach(f -> {
                    //  在锁内操作
                    Session sessions = webSocketMap.get(f.getTokenKey());
                    if (null != sessions) {
                        try {
                            sendInfo(sessions, 0);
                            f.setIsAvailable(0);
                            mapper.updateById(f);
                        } catch (Exception e) {
                            throw new RuntimeException(e);
                        }
                    }
                    log.info("webSocket关闭过期链接---key:"+f.getTokenKey());
                });
            } catch (Exception e) {
                throw new RuntimeException(e);
            } finally {
                lock.unlock();
                log.info("webSocket关闭过期链接--------->结束");
            }
        } else {
            log.info("webSocket关闭过期链接--------->上次任务还未执行完成,本次任务忽略");
        }
    }
    /**
     * 服务端与客户端连接成功时执行
     *
     * @param session 会话
     */
    @OnOpen
    public void onOpen(Session session){
        this.session = session;
        //接入的客户端+1
        int count = onlineCount.incrementAndGet();
        //集合中存入客户端对象+1
        wsServers.add(this);
        log.info("与客户端连接成功,当前连接的客户端数量为:{}", count);
    public void onOpen(Session session) throws Exception {
        try {
            this.session = session;
            //查询该用户有没有权限
            Map<String, List<String>> params = session.getRequestParameterMap();
            List<String> funcTypes = params.get("key");
            // 取出funcType参数的值
            if (null == funcTypes) {
                sendInfo(session, 1);
            }
            String key = funcTypes.get(0);
            if (StringUtils.isEmpty(key)) {
                sendInfo(session, 1);
            }
            ApplicationContext act = ApplicationContextRegisterUtil.getApplicationContext();
            DataServiceKeyMapper mapper = act.getBean(DataServiceKeyMapper.class);
            DataServiceKey dataServiceKey = mapper.selectOne(new LambdaQueryWrapper<DataServiceKey>()
                    .eq(DataServiceKey::getTokenKey, key)
                    .gt(DataServiceKey::getExpirationTime, DateUtil.date())
                    .eq(DataServiceKey::getIsAvailable, 1));
            if (null != dataServiceKey) {
                if (webSocketMap.containsKey(key)) {
                    webSocketMap.put(key, session);
                } else {
                    //接入的客户端+1
                    onlineCount.incrementAndGet();
                    webSocketMap.put(key, session);
                }
                log.info("与客户端连接成功,当前连接的客户端数量为:{}", onlineCount);
            } else {
                sendInfo(session, 1);
            }
        } catch (Exception e) {
            sendInfo(session, 1);
            log.error("客户端连接错误:" + e.getMessage());
        }
    }
    /**
     * 关闭连接
     */
    public void sendInfo(Session session, int type) throws Exception {
        if (session.getBasicRemote() != null) {
            if (type == 1) {
                session.getBasicRemote().sendText("订阅key不存在!");
            }
            session.close();
        }
    }
    /**
     * 收到客户端的消息时执行
     *
     * @param message 消息
     * @param session 会话
     */
    @OnMessage
    public void onMessage(String message, Session session){
    public void onMessage(Session session, String message) {
        log.info("收到来自客户端的消息,客户端地址:{},消息内容:{}", session.getMessageHandlers(), message);
        //业务逻辑,对消息的处理
//        sendMessageToAll("群发消息的内容");
    }
    /**
     * 连接发生报错时执行
     * @param session 会话
     *
     * @param session   会话
     * @param throwable 报错
     */
    @OnError
    public void onError(Session session, @NonNull Throwable throwable){
    public void onError(Session session, @NonNull Throwable throwable) {
        log.error("连接发生报错");
        throwable.printStackTrace();
    }
@@ -74,53 +179,36 @@
     * 连接断开时执行
     */
    @OnClose
    public void onClose(){
    public void onClose() {
        //接入客户端连接数-1
        int count = onlineCount.decrementAndGet();
        if (webSocketMap.get(this.session) != null) {
            onlineCount.decrementAndGet();
        }
        //集合中的客户端对象-1
        wsServers.remove(this);
        log.info("服务端断开连接,当前连接的客户端数量为:{}", count);
        webSocketMap.remove(this);
        log.info("服务端断开连接,当前连接的客户端数量为:{}", onlineCount.get());
    }
    /**
     * 向客户端推送消息
     * @param message 消息
     */
    public void sendMessage(String message){
        this.session.getAsyncRemote().sendText(message);
        log.info("推送消息给客户端:{},消息内容为:{}", this.session.getMessageHandlers(), message);
    }
//    @PostMapping("/send2c")
//    public void sendMessage1(@RequestBody String message){
//        this.session.getAsyncRemote().sendText(message);
//        try {
//            this.session.getBasicRemote().sendText(message);
//        } catch (IOException e) {
//            throw new RuntimeException(e);
//        }
//        log.info("推送消息给客户端,消息内容为:{}", message);
//    }
    /**
     * 群发消息
     *
     * @param message 消息
     */
    public void sendMessageToAll(String message){
        CopyOnWriteArrayList<WsServer> ws = wsServers;
        for (WsServer wsServer : ws){
            wsServer.sendMessage(message);
        }
    public void sendMessageToAll(String message) { // 发送消息给所有客户端
        webSocketMap.values().forEach(session -> {
            try {
                session.getBasicRemote().sendText(message);
            } catch (IOException e) {
                log.error("发送消息给客户端失败,客户端ID为:{}", session.getId(), e);
            }
        });
        log.info("推送消息给所有客户端,消息内容为:{}", message);
    }
    @PostMapping("/send2AllC")
    public void sendMessageToAll1(@RequestBody String message){
        CopyOnWriteArrayList<WsServer> ws = wsServers;
        for (WsServer wsServer : ws){
            wsServer.sendMessage(message);
        }
    public void sendMessageToAll1(@RequestBody String message) { // 发送消息给所有客户端
        sendMessageToAll(message);
    }