package org.example.websocket.server; import com.google.gson.Gson; import com.google.gson.reflect.TypeToken; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.stereotype.Component; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import javax.websocket.*; import javax.websocket.server.ServerEndpoint; import java.io.IOException; import java.lang.reflect.Type; import java.text.SimpleDateFormat; import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; /** * @ClassDescription: websocket服务端 * @JdkVersion: 1.8 * @Created: 2023/8/31 14:59 */ @Slf4j @Component @ServerEndpoint("/websocket-server") public class WsServer { private Session session; private static AtomicInteger onlineCount = new AtomicInteger(0); private static CopyOnWriteArraySet wsServers = new CopyOnWriteArraySet<>(); @Autowired @Qualifier("threadPoolTaskExecutor") private ThreadPoolTaskExecutor threadPoolTaskExecutor; @OnOpen public void onOpen(Session session) { this.session = session; int count = onlineCount.incrementAndGet(); wsServers.add(this); log.info("与客户端连接成功,当前连接的客户端数量为:{}", count); } @OnError public void onError(Session session, @NonNull Throwable throwable) { log.error("连接发生报错"); throwable.printStackTrace(); } @OnClose public void onClose() { int count = onlineCount.decrementAndGet(); wsServers.remove(this); log.info("服务端断开连接,当前连接的客户端数量为:{}", count); } private Map sessionLocks = new ConcurrentHashMap<>(); private Lock getSessionLock(String sessionId) { sessionLocks.putIfAbsent(sessionId, new ReentrantLock()); return sessionLocks.get(sessionId); } public void sendMessageToAll(String message) { Map map = jsonToMap(message); if (map.get("pid").equals("00000001")) { System.out.println(message); } try { List> futures = new ArrayList<>(); wsServers.forEach(ws -> { Future future = threadPoolTaskExecutor.submit(() -> { Session session = ws.session; if (session != null && session.isOpen()) { Lock sessionLock = getSessionLock(session.getId()); sessionLock.lock(); try { synchronized (session){ session.getBasicRemote().sendText(message); } } catch (Exception e) { log.error("发送消息时出现异常: " + e.getMessage()); } finally { sessionLock.unlock(); } } else { log.error("会话不存在或已关闭,无法发送消息"); } }); futures.add(future); }); //等待所有任务执行完成 for (Future future : futures) { try { future.get(); } catch (InterruptedException | ExecutionException e) { log.error("发送消息时出现异常: " + e.getMessage()); } } } catch (Exception e) { log.error("发送消息时出现异常: " + e.getMessage()); } } public static Map jsonToMap(String json) { Gson gson = new Gson(); Type type = new TypeToken>() { }.getType(); return gson.fromJson(json, type); } }