1
zj
2024-07-23 22c359ea29b5ab086369f73ea3ed93529dbdc362
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package org.example.websocket.server;
 
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Component;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;

import javax.websocket.*;
import javax.websocket.server.ServerEndpoint;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.reflect.Type;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.zip.Deflater;
import java.util.zip.DeflaterOutputStream;
 
/**
 * @ClassDescription: websocket服务端
 * @JdkVersion: 1.8
 * @Created: 2023/8/31 14:59
 */
@Slf4j
@Component
@ServerEndpoint("/websocket-server")
public class WsServer {
 
    private Session session;
    private static AtomicInteger onlineCount = new AtomicInteger(0);
    private static CopyOnWriteArraySet<WsServer> wsServers = new CopyOnWriteArraySet<>();
    @Autowired
    @Qualifier("threadPoolTaskExecutor")
    private ThreadPoolTaskExecutor threadPoolTaskExecutor;
 
    @OnOpen
    public void onOpen(Session session) {
        this.session = session;
        int count = onlineCount.incrementAndGet();
        wsServers.add(this);
        log.info("与客户端连接成功,当前连接的客户端数量为:{}", count);
    }
 
    @OnError
    public void onError(Session session, @NonNull Throwable throwable) {
        log.error("连接发生报错");
        throwable.printStackTrace();
    }
 
    @OnClose
    public void onClose() {
        int count = onlineCount.decrementAndGet();
        wsServers.remove(this);
        log.info("服务端断开连接,当前连接的客户端数量为:{}", count);
    }
 
    private Map<String, Lock> sessionLocks = new ConcurrentHashMap<>();
 
    private Lock getSessionLock(String sessionId) {
        sessionLocks.putIfAbsent(sessionId, new ReentrantLock());
        return sessionLocks.get(sessionId);
    }
 
    public void sendMessageToAll(String message) {
//        Map<String, Object> map = jsonToMap(message);
//        if (map.get("pid").equals("00000001")) {
//            System.out.println(message);
//        }
        try {
            List<Future<?>> futures = new ArrayList<>();
            wsServers.forEach(ws -> {
                Future<?> future = threadPoolTaskExecutor.submit(() -> {
                    Session session = ws.session;
                    if (session != null && session.isOpen()) {
                        Lock sessionLock = getSessionLock(session.getId());
                        sessionLock.lock();
                        try {
                            synchronized (session){
                                // 压缩消息
                                byte[] compressedData = compress(message);
 
                                // 发送压缩后的消息
                                session.getBasicRemote().sendBinary(ByteBuffer.wrap(compressedData));
//                                session.getBasicRemote().sendText(message);
                            }
                        } catch (Exception e) {
                            log.error("发送消息时出现异常: " + e.getMessage());
                        } finally {
                            sessionLock.unlock();
                        }
                    } else {
                        log.error("会话不存在或已关闭,无法发送消息");
                    }
                });
                futures.add(future);
            });
 
            //等待所有任务执行完成
            for (Future<?> future : futures) {
                try {
                    future.get();
                } catch (InterruptedException | ExecutionException e) {
                    log.error("发送消息时出现异常: " + e.getMessage());
                }
            }
        } catch (Exception e) {
            log.error("发送消息时出现异常: " + e.getMessage());
        }
    }
 
    private byte[] compress(String data) throws IOException {
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        Deflater deflater = new Deflater(Deflater.BEST_COMPRESSION, true);
        try (DeflaterOutputStream deflaterOutputStream = new DeflaterOutputStream(byteArrayOutputStream, deflater)) {
            deflaterOutputStream.write(data.getBytes("UTF-8"));
        }
        deflater.end();
        return byteArrayOutputStream.toByteArray();
    }
 
//
//    public static Map<String, Object> jsonToMap(String json) {
//        Gson gson = new Gson();
//        Type type = new TypeToken<Map<String, Object>>() {
//        }.getType();
//        return gson.fromJson(json, type);
//    }
}