package org.example.websocket.server;
|
|
import cn.hutool.core.date.DateUnit;
|
import cn.hutool.core.date.DateUtil;
|
import cn.hutool.core.io.unit.DataUnit;
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
|
import lombok.NonNull;
|
import lombok.extern.slf4j.Slf4j;
|
import org.example.dao.DataServiceKeyMapper;
|
import org.example.pojo.DataServiceKey;
|
import org.example.util.ApplicationContextRegisterUtil;
|
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.context.ApplicationContext;
|
import org.springframework.scheduling.annotation.Scheduled;
|
import org.springframework.stereotype.Component;
|
import org.springframework.web.bind.annotation.PostMapping;
|
import org.springframework.web.bind.annotation.RequestBody;
|
|
import javax.websocket.*;
|
import javax.websocket.server.PathParam;
|
import javax.websocket.server.ServerEndpoint;
|
import java.io.IOException;
|
import java.util.HashSet;
|
import java.util.List;
|
import java.util.Map;
|
import java.util.Set;
|
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.CopyOnWriteArrayList;
|
import java.util.concurrent.atomic.AtomicInteger;
|
import java.util.concurrent.locks.Lock;
|
import java.util.concurrent.locks.ReentrantLock;
|
|
/**
|
* @ClassDescription: websocket服务端
|
* @JdkVersion: 1.8
|
* @Created: 2023/8/31 14:59
|
*/
|
@Slf4j
|
@Component
|
@ServerEndpoint("/websocket-server")
|
public class WsServer {
|
|
private Session session;
|
|
/**
|
* 记录在线连接客户端数量
|
*/
|
private static AtomicInteger onlineCount = new AtomicInteger(0);
|
/**
|
* 存放每个连接进来的客户端对应的websocketServer对象,用于后面群发消息
|
*/
|
private static ConcurrentHashMap<String, Session> webSocketMap = new ConcurrentHashMap<>();
|
|
private final Lock lock = new ReentrantLock();
|
|
/**
|
* 关闭过期链接
|
*/
|
@Scheduled(cron = "0 0/1 * * * ?")
|
public void clearExpiration() {
|
if (lock.tryLock()) {
|
log.info("webSocket关闭过期链接--------->开始");
|
try {
|
ApplicationContext act = ApplicationContextRegisterUtil.getApplicationContext();
|
DataServiceKeyMapper mapper = act.getBean(DataServiceKeyMapper.class);
|
LambdaQueryWrapper<DataServiceKey> queryWrapper = new LambdaQueryWrapper<>();
|
List<DataServiceKey> keyList = mapper.selectList(queryWrapper.lt(DataServiceKey::getExpirationTime, DateUtil.date())
|
.or().eq(DataServiceKey::getIsAvailable, 0));
|
|
keyList.forEach(f -> {
|
// 在锁内操作
|
Session sessions = webSocketMap.get(f.getTokenKey());
|
if (null != sessions) {
|
try {
|
sendInfo(sessions, 0);
|
f.setIsAvailable(0);
|
mapper.updateById(f);
|
} catch (Exception e) {
|
throw new RuntimeException(e);
|
}
|
}
|
log.info("webSocket关闭过期链接---key:"+f.getTokenKey());
|
});
|
} catch (Exception e) {
|
throw new RuntimeException(e);
|
} finally {
|
lock.unlock();
|
log.info("webSocket关闭过期链接--------->结束");
|
}
|
} else {
|
log.info("webSocket关闭过期链接--------->上次任务还未执行完成,本次任务忽略");
|
}
|
}
|
|
|
/**
|
* 服务端与客户端连接成功时执行
|
*
|
* @param session 会话
|
*/
|
@OnOpen
|
public void onOpen(Session session) throws Exception {
|
try {
|
this.session = session;
|
//查询该用户有没有权限
|
Map<String, List<String>> params = session.getRequestParameterMap();
|
List<String> funcTypes = params.get("key");
|
// 取出funcType参数的值
|
if (null == funcTypes) {
|
sendInfo(session, 1);
|
}
|
String key = funcTypes.get(0);
|
if (StringUtils.isEmpty(key)) {
|
sendInfo(session, 1);
|
}
|
ApplicationContext act = ApplicationContextRegisterUtil.getApplicationContext();
|
DataServiceKeyMapper mapper = act.getBean(DataServiceKeyMapper.class);
|
DataServiceKey dataServiceKey = mapper.selectOne(new LambdaQueryWrapper<DataServiceKey>()
|
.eq(DataServiceKey::getTokenKey, key)
|
.gt(DataServiceKey::getExpirationTime, DateUtil.date())
|
.eq(DataServiceKey::getIsAvailable, 1));
|
if (null != dataServiceKey) {
|
if (webSocketMap.containsKey(key)) {
|
webSocketMap.put(key, session);
|
} else {
|
//接入的客户端+1
|
onlineCount.incrementAndGet();
|
webSocketMap.put(key, session);
|
}
|
|
log.info("与客户端连接成功,当前连接的客户端数量为:{}", onlineCount);
|
} else {
|
sendInfo(session, 1);
|
}
|
} catch (Exception e) {
|
sendInfo(session, 1);
|
log.error("客户端连接错误:" + e.getMessage());
|
}
|
}
|
|
|
/**
|
* 关闭连接
|
*/
|
public void sendInfo(Session session, int type) throws Exception {
|
if (session.getBasicRemote() != null) {
|
if (type == 1) {
|
session.getBasicRemote().sendText("订阅key不存在!");
|
}
|
session.close();
|
}
|
}
|
|
/**
|
* 收到客户端的消息时执行
|
*
|
* @param message 消息
|
* @param session 会话
|
*/
|
@OnMessage
|
public void onMessage(Session session, String message) {
|
log.info("收到来自客户端的消息,客户端地址:{},消息内容:{}", session.getMessageHandlers(), message);
|
}
|
|
/**
|
* 连接发生报错时执行
|
*
|
* @param session 会话
|
* @param throwable 报错
|
*/
|
@OnError
|
public void onError(Session session, @NonNull Throwable throwable) {
|
log.error("连接发生报错");
|
throwable.printStackTrace();
|
}
|
|
/**
|
* 连接断开时执行
|
*/
|
@OnClose
|
public void onClose() {
|
//接入客户端连接数-1
|
if (webSocketMap.get(this.session) != null) {
|
onlineCount.decrementAndGet();
|
}
|
//集合中的客户端对象-1
|
webSocketMap.remove(this);
|
log.info("服务端断开连接,当前连接的客户端数量为:{}", onlineCount.get());
|
}
|
|
|
/**
|
* 群发消息
|
*
|
* @param message 消息
|
*/
|
public void sendMessageToAll(String message) { // 发送消息给所有客户端
|
webSocketMap.values().forEach(session -> {
|
try {
|
session.getBasicRemote().sendText(message);
|
} catch (IOException e) {
|
log.error("发送消息给客户端失败,客户端ID为:{}", session.getId(), e);
|
}
|
});
|
log.info("推送消息给所有客户端,消息内容为:{}", message);
|
}
|
|
@PostMapping("/send2AllC")
|
public void sendMessageToAll1(@RequestBody String message) { // 发送消息给所有客户端
|
sendMessageToAll(message);
|
}
|
|
|
}
|