From 2c9a5fa4b53955033b1bd2ac637dea0728808551 Mon Sep 17 00:00:00 2001
From: zj <1772600164@qq.com>
Date: Wed, 27 Mar 2024 20:01:13 +0800
Subject: [PATCH] 多链接
---
websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java | 188 ++++++++++++++++++++++++++++++++++------------
1 files changed, 138 insertions(+), 50 deletions(-)
diff --git a/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java b/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java
index 0984ec6..95d9667 100644
--- a/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java
+++ b/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java
@@ -1,15 +1,35 @@
package org.example.websocket.server;
+import cn.hutool.core.date.DateUnit;
+import cn.hutool.core.date.DateUtil;
+import cn.hutool.core.io.unit.DataUnit;
+import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
+import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
+import org.example.dao.DataServiceKeyMapper;
+import org.example.pojo.DataServiceKey;
+import org.example.util.ApplicationContextRegisterUtil;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.ApplicationContext;
+import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import javax.websocket.*;
+import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
/**
* @ClassDescription: websocket服务端
@@ -18,12 +38,11 @@
*/
@Slf4j
@Component
-//@RestController
@ServerEndpoint("/websocket-server")
-//@ServerEndpoint("/")
public class WsServer {
private Session session;
+
/**
* 记录在线连接客户端数量
*/
@@ -31,41 +50,127 @@
/**
* 存放每个连接进来的客户端对应的websocketServer对象,用于后面群发消息
*/
- private static CopyOnWriteArrayList<WsServer> wsServers = new CopyOnWriteArrayList<>();
+ private static ConcurrentHashMap<String, Session> webSocketMap = new ConcurrentHashMap<>();
+
+ private final Lock lock = new ReentrantLock();
+
+ /**
+ * 关闭过期链接
+ */
+ @Scheduled(cron = "0 0/1 * * * ?")
+ public void clearExpiration() {
+ if (lock.tryLock()) {
+ log.info("webSocket关闭过期链接--------->开始");
+ try {
+ ApplicationContext act = ApplicationContextRegisterUtil.getApplicationContext();
+ DataServiceKeyMapper mapper = act.getBean(DataServiceKeyMapper.class);
+ LambdaQueryWrapper<DataServiceKey> queryWrapper = new LambdaQueryWrapper<>();
+ List<DataServiceKey> keyList = mapper.selectList(queryWrapper.lt(DataServiceKey::getExpirationTime, DateUtil.date())
+ .or().eq(DataServiceKey::getIsAvailable, 0));
+
+ keyList.forEach(f -> {
+ // 在锁内操作
+ Session sessions = webSocketMap.get(f.getTokenKey());
+ if (null != sessions) {
+ try {
+ sendInfo(sessions, 0);
+ f.setIsAvailable(0);
+ mapper.updateById(f);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+ log.info("webSocket关闭过期链接---key:"+f.getTokenKey());
+ });
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ } finally {
+ lock.unlock();
+ log.info("webSocket关闭过期链接--------->结束");
+ }
+ } else {
+ log.info("webSocket关闭过期链接--------->上次任务还未执行完成,本次任务忽略");
+ }
+ }
+
/**
* 服务端与客户端连接成功时执行
+ *
* @param session 会话
*/
@OnOpen
- public void onOpen(Session session){
- this.session = session;
- //接入的客户端+1
- int count = onlineCount.incrementAndGet();
- //集合中存入客户端对象+1
- wsServers.add(this);
- log.info("与客户端连接成功,当前连接的客户端数量为:{}", count);
+ public void onOpen(Session session) throws Exception {
+ try {
+ this.session = session;
+ //查询该用户有没有权限
+ Map<String, List<String>> params = session.getRequestParameterMap();
+ List<String> funcTypes = params.get("key");
+ // 取出funcType参数的值
+ if (null == funcTypes) {
+ sendInfo(session, 1);
+ }
+ String key = funcTypes.get(0);
+ if (StringUtils.isEmpty(key)) {
+ sendInfo(session, 1);
+ }
+ ApplicationContext act = ApplicationContextRegisterUtil.getApplicationContext();
+ DataServiceKeyMapper mapper = act.getBean(DataServiceKeyMapper.class);
+ DataServiceKey dataServiceKey = mapper.selectOne(new LambdaQueryWrapper<DataServiceKey>()
+ .eq(DataServiceKey::getTokenKey, key)
+ .gt(DataServiceKey::getExpirationTime, DateUtil.date())
+ .eq(DataServiceKey::getIsAvailable, 1));
+ if (null != dataServiceKey) {
+ if (webSocketMap.containsKey(key)) {
+ webSocketMap.put(key, session);
+ } else {
+ //接入的客户端+1
+ onlineCount.incrementAndGet();
+ webSocketMap.put(key, session);
+ }
+
+ log.info("与客户端连接成功,当前连接的客户端数量为:{}", onlineCount);
+ } else {
+ sendInfo(session, 1);
+ }
+ } catch (Exception e) {
+ sendInfo(session, 1);
+ log.error("客户端连接错误:" + e.getMessage());
+ }
+ }
+
+
+ /**
+ * 关闭连接
+ */
+ public void sendInfo(Session session, int type) throws Exception {
+ if (session.getBasicRemote() != null) {
+ if (type == 1) {
+ session.getBasicRemote().sendText("订阅key不存在!");
+ }
+ session.close();
+ }
}
/**
* 收到客户端的消息时执行
+ *
* @param message 消息
* @param session 会话
*/
@OnMessage
- public void onMessage(String message, Session session){
+ public void onMessage(Session session, String message) {
log.info("收到来自客户端的消息,客户端地址:{},消息内容:{}", session.getMessageHandlers(), message);
- //业务逻辑,对消息的处理
-// sendMessageToAll("群发消息的内容");
}
/**
* 连接发生报错时执行
- * @param session 会话
+ *
+ * @param session 会话
* @param throwable 报错
*/
@OnError
- public void onError(Session session, @NonNull Throwable throwable){
+ public void onError(Session session, @NonNull Throwable throwable) {
log.error("连接发生报错");
throwable.printStackTrace();
}
@@ -74,53 +179,36 @@
* 连接断开时执行
*/
@OnClose
- public void onClose(){
+ public void onClose() {
//接入客户端连接数-1
- int count = onlineCount.decrementAndGet();
+ if (webSocketMap.get(this.session) != null) {
+ onlineCount.decrementAndGet();
+ }
//集合中的客户端对象-1
- wsServers.remove(this);
- log.info("服务端断开连接,当前连接的客户端数量为:{}", count);
+ webSocketMap.remove(this);
+ log.info("服务端断开连接,当前连接的客户端数量为:{}", onlineCount.get());
}
-
-
- /**
- * 向客户端推送消息
- * @param message 消息
- */
- public void sendMessage(String message){
- this.session.getAsyncRemote().sendText(message);
- log.info("推送消息给客户端:{},消息内容为:{}", this.session.getMessageHandlers(), message);
- }
-
-// @PostMapping("/send2c")
-// public void sendMessage1(@RequestBody String message){
-// this.session.getAsyncRemote().sendText(message);
-// try {
-// this.session.getBasicRemote().sendText(message);
-// } catch (IOException e) {
-// throw new RuntimeException(e);
-// }
-// log.info("推送消息给客户端,消息内容为:{}", message);
-// }
/**
* 群发消息
+ *
* @param message 消息
*/
- public void sendMessageToAll(String message){
- CopyOnWriteArrayList<WsServer> ws = wsServers;
- for (WsServer wsServer : ws){
- wsServer.sendMessage(message);
- }
+ public void sendMessageToAll(String message) { // 发送消息给所有客户端
+ webSocketMap.values().forEach(session -> {
+ try {
+ session.getBasicRemote().sendText(message);
+ } catch (IOException e) {
+ log.error("发送消息给客户端失败,客户端ID为:{}", session.getId(), e);
+ }
+ });
+ log.info("推送消息给所有客户端,消息内容为:{}", message);
}
@PostMapping("/send2AllC")
- public void sendMessageToAll1(@RequestBody String message){
- CopyOnWriteArrayList<WsServer> ws = wsServers;
- for (WsServer wsServer : ws){
- wsServer.sendMessage(message);
- }
+ public void sendMessageToAll1(@RequestBody String message) { // 发送消息给所有客户端
+ sendMessageToAll(message);
}
--
Gitblit v1.9.3