From 2c9a5fa4b53955033b1bd2ac637dea0728808551 Mon Sep 17 00:00:00 2001
From: zj <1772600164@qq.com>
Date: Wed, 27 Mar 2024 20:01:13 +0800
Subject: [PATCH] 多链接

---
 websocketSerivce/src/main/java/org/example/controller/JournalismController.java     |   37 +++
 websocketSerivce/src/main/java/org/example/controller/StockNewShareController.java  |   68 ++++++
 websocketSerivce/src/main/java/org/example/pojo/StockNewShare.java                  |   59 +++++
 websocketSerivce/src/main/java/org/example/util/ApplicationContextRegisterUtil.java |   27 ++
 websocketSerivce/src/main/java/org/example/pojo/DataServiceKey.java                 |   58 +++++
 websocketSerivce/src/main/java/org/example/timedTask/NewShareTask.java              |   80 +++++++
 websocketSerivce/src/main/java/org/example/dao/DataServiceKeyMapper.java            |    9 
 .idea/vcs.xml                                                                       |    6 
 websocketSerivce/src/main/java/org/example/constant/StockConstant.java              |    1 
 websocketSerivce/src/main/java/org/example/websocket/controller/GenerateKey.java    |   50 ++++
 websocketSerivce/src/main/java/org/example/enums/EStockType.java                    |   10 
 websocketSerivce/src/main/java/org/example/timedTask/JournalismTask.java            |    5 
 websocketSerivce/src/main/java/org/example/dao/StockNewShareMapper.java             |    9 
 websocketSerivce/src/main/java/org/example/controller/StockMarketNewController.java |   31 ++
 websocketSerivce/src/main/java/org/example/controller/ApiController.java            |   22 -
 websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java           |  188 ++++++++++++----
 16 files changed, 587 insertions(+), 73 deletions(-)

diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="" vcs="Git" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/websocketSerivce/src/main/java/org/example/constant/StockConstant.java b/websocketSerivce/src/main/java/org/example/constant/StockConstant.java
index 5184d4c..be00447 100644
--- a/websocketSerivce/src/main/java/org/example/constant/StockConstant.java
+++ b/websocketSerivce/src/main/java/org/example/constant/StockConstant.java
@@ -22,6 +22,7 @@
 
     public  static String US_KEY = "F03fXyNJKeFiTGsaoXHg";
 
+    public static String IPO_HTTP_API = "http://test.js-stock.top/";
 
 
 }
diff --git a/websocketSerivce/src/main/java/org/example/controller/ApiController.java b/websocketSerivce/src/main/java/org/example/controller/ApiController.java
index 8550326..4f9e6fe 100644
--- a/websocketSerivce/src/main/java/org/example/controller/ApiController.java
+++ b/websocketSerivce/src/main/java/org/example/controller/ApiController.java
@@ -20,24 +20,6 @@
 @RequestMapping("/api/all")
 public class ApiController {
 
-    @Autowired
-    JournalismMapper journalismMapper;
-
-    @Autowired
-    StockMarketNewMapper stockMarketNewMapper;
-
-
-    @GetMapping("JournalismAll")
-    public ServerResponse JournalismAll(){
-        LambdaQueryWrapper<Journalism> queryWrapper = new LambdaQueryWrapper<>();
-        return ServerResponse.createBySuccess(journalismMapper.selectList(queryWrapper));
-    }
-
-    @GetMapping("StockMarketNew")
-    public ServerResponse StockMarketNew(){
-        LambdaQueryWrapper<StockMarketNew> queryWrapper = new LambdaQueryWrapper<>();
-        return ServerResponse.createBySuccess(stockMarketNewMapper.selectList(queryWrapper));
-    }
 
     /*查询股票日线*/
     @RequestMapping({"getKData.do"})
@@ -48,9 +30,9 @@
             @RequestParam("stockType") String stockType
     ) {
         EStockType eStockType = null;
-        if(stockType.equals("US")){
+        if (stockType.equals("US")) {
             eStockType = EStockType.US;
-        }else{
+        } else {
             eStockType = EStockType.IN;
         }
         return HttpUtil.get(eStockType.stockUrl + "kline?pid=" + pid + "&interval=" + interval + "&key=" + eStockType.stockKey);
diff --git a/websocketSerivce/src/main/java/org/example/controller/JournalismController.java b/websocketSerivce/src/main/java/org/example/controller/JournalismController.java
new file mode 100644
index 0000000..48a543c
--- /dev/null
+++ b/websocketSerivce/src/main/java/org/example/controller/JournalismController.java
@@ -0,0 +1,37 @@
+package org.example.controller;
+
+import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
+import org.example.common.ServerResponse;
+import org.example.dao.JournalismMapper;
+import org.example.pojo.Journalism;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RestController;
+
+/**
+ * @program: webSocketProject
+ * @description: 新闻
+ * @create: 2024-03-27 14:03
+ **/
+@RestController
+@RequestMapping("/api/news/")
+public class JournalismController {
+
+    @Autowired
+    JournalismMapper journalismMapper;
+
+    @GetMapping("getNewsListAll.do")
+    public ServerResponse JournalismAll(){
+        LambdaQueryWrapper<Journalism> queryWrapper = new LambdaQueryWrapper<>();
+        return ServerResponse.createBySuccess(journalismMapper.selectList(queryWrapper));
+    }
+
+    @GetMapping("getNewsList.do")
+    public ServerResponse JournalismNumber(){
+        LambdaQueryWrapper<Journalism>  queryWrapper  =  new  LambdaQueryWrapper<>();
+        queryWrapper.orderByDesc(Journalism::getTime).last("limit  20");
+        return ServerResponse.createBySuccess(journalismMapper.selectList(queryWrapper));
+    }
+
+}
diff --git a/websocketSerivce/src/main/java/org/example/controller/StockMarketNewController.java b/websocketSerivce/src/main/java/org/example/controller/StockMarketNewController.java
new file mode 100644
index 0000000..a03b22c
--- /dev/null
+++ b/websocketSerivce/src/main/java/org/example/controller/StockMarketNewController.java
@@ -0,0 +1,31 @@
+//package org.example.controller;
+//
+//import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
+//import org.example.common.ServerResponse;
+//import org.example.dao.StockMarketNewMapper;
+//import org.example.pojo.StockMarketNew;
+//import org.springframework.beans.factory.annotation.Autowired;
+//import org.springframework.web.bind.annotation.GetMapping;
+//import org.springframework.web.bind.annotation.RequestMapping;
+//import org.springframework.web.bind.annotation.RestController;
+//
+///**
+// * @program: webSocketProject
+// * @description:
+// * @create: 2024-03-27 14:07
+// **/
+//@RestController
+//@RequestMapping("/api/StockMarketNew")
+//public class StockMarketNewController {
+//
+//
+//    @Autowired
+//    StockMarketNewMapper stockMarketNewMapper;
+//
+//    @GetMapping("StockMarketNewAll")
+//    public ServerResponse StockMarketNew(){
+//        LambdaQueryWrapper<StockMarketNew> queryWrapper = new LambdaQueryWrapper<>();
+//        return ServerResponse.createBySuccess(stockMarketNewMapper.selectList(queryWrapper));
+//    }
+//
+//}
diff --git a/websocketSerivce/src/main/java/org/example/controller/StockNewShareController.java b/websocketSerivce/src/main/java/org/example/controller/StockNewShareController.java
new file mode 100644
index 0000000..217d1f7
--- /dev/null
+++ b/websocketSerivce/src/main/java/org/example/controller/StockNewShareController.java
@@ -0,0 +1,68 @@
+package org.example.controller;
+
+import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
+import com.baomidou.mybatisplus.core.mapper.BaseMapper;
+import org.example.common.ServerResponse;
+import org.example.dao.StockMarketNewMapper;
+import org.example.dao.StockNewShareMapper;
+import org.example.enums.EStockType;
+import org.example.pojo.StockMarketNew;
+import org.example.pojo.StockNewShare;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RequestParam;
+import org.springframework.web.bind.annotation.RestController;
+
+import java.util.List;
+
+/**
+ * @program: webSocketProject
+ * @description: 新股
+ * @create: 2024-03-27 15:27
+ **/
+@RestController
+@RequestMapping("/api/stock")
+public class StockNewShareController {
+
+    @Autowired
+    StockNewShareMapper stockNewShareMapper;
+
+    @Autowired
+    StockMarketNewMapper stockMarketNewMapper;
+
+    @GetMapping("getStock.do")
+    public ServerResponse StockMarketNew(@RequestParam(value = "stockType", required = false) String stockType){
+        // 将输入的股票类型转换为大写
+        String upperCase = stockType.toUpperCase();
+        // 根据代码获取对应的枚举类型
+        EStockType code = EStockType.getEsByCode(upperCase);
+        if(code == null){
+            return ServerResponse.createBySuccessMsg("请输入正确的stockType");
+        }
+
+        // 根据枚举类型进行不同的操作
+        switch(code){
+            case XG:
+                // 查询新股市场数据
+                return getStockData(stockNewShareMapper);
+            case IN:
+                // 查询股票数据
+                return getStockData(stockMarketNewMapper);
+            default:
+                return ServerResponse.createBySuccessMsg("未找到对应的股票数据");
+        }
+    }
+
+    // 通用方法,根据传入的mapper查询数据
+    private <T> ServerResponse getStockData(BaseMapper<T> mapper){
+        LambdaQueryWrapper<T> wrapper = new LambdaQueryWrapper<>();
+        List<T> list = mapper.selectList(wrapper);
+        if(list.isEmpty()){
+            return ServerResponse.createByErrorMsg("查询结果为空");
+        }
+        return ServerResponse.createBySuccess(list);
+    }
+
+
+}
diff --git a/websocketSerivce/src/main/java/org/example/dao/DataServiceKeyMapper.java b/websocketSerivce/src/main/java/org/example/dao/DataServiceKeyMapper.java
new file mode 100644
index 0000000..a89d006
--- /dev/null
+++ b/websocketSerivce/src/main/java/org/example/dao/DataServiceKeyMapper.java
@@ -0,0 +1,9 @@
+package org.example.dao;
+
+import com.baomidou.mybatisplus.core.mapper.BaseMapper;
+import org.apache.ibatis.annotations.Mapper;
+import org.example.pojo.DataServiceKey;
+
+@Mapper
+public interface DataServiceKeyMapper extends BaseMapper<DataServiceKey> {
+}
diff --git a/websocketSerivce/src/main/java/org/example/dao/StockNewShareMapper.java b/websocketSerivce/src/main/java/org/example/dao/StockNewShareMapper.java
new file mode 100644
index 0000000..7c163ad
--- /dev/null
+++ b/websocketSerivce/src/main/java/org/example/dao/StockNewShareMapper.java
@@ -0,0 +1,9 @@
+package org.example.dao;
+
+import com.baomidou.mybatisplus.core.mapper.BaseMapper;
+import org.apache.ibatis.annotations.Mapper;
+import org.example.pojo.StockNewShare;
+
+@Mapper
+public interface StockNewShareMapper extends BaseMapper<StockNewShare> {
+}
diff --git a/websocketSerivce/src/main/java/org/example/enums/EStockType.java b/websocketSerivce/src/main/java/org/example/enums/EStockType.java
index c8b70fb..1a8d9b7 100644
--- a/websocketSerivce/src/main/java/org/example/enums/EStockType.java
+++ b/websocketSerivce/src/main/java/org/example/enums/EStockType.java
@@ -11,6 +11,7 @@
 
 
     IN("IN","印度股票","14", StockConstant.HTTP_API, StockConstant.KEY),
+    XG("XG","新股","14", StockConstant.HTTP_API, StockConstant.KEY),
     US("US","美国股票","5",StockConstant.US_API_URL,StockConstant.US_KEY);
     private String code;
     private String typeDesc;
@@ -36,6 +37,15 @@
         }
     }
 
+    public static EStockType getEsByCode(String code) {
+        for (EStockType type : EStockType.values()) {
+            if (type.getCode().equals(code)) {
+                return type;
+            }
+        }
+        return null;
+    }
+
     public String getContryId() {
         return contryId;
     }
diff --git a/websocketSerivce/src/main/java/org/example/pojo/DataServiceKey.java b/websocketSerivce/src/main/java/org/example/pojo/DataServiceKey.java
new file mode 100644
index 0000000..9bdce6b
--- /dev/null
+++ b/websocketSerivce/src/main/java/org/example/pojo/DataServiceKey.java
@@ -0,0 +1,58 @@
+package org.example.pojo;
+
+import com.baomidou.mybatisplus.annotation.*;
+import lombok.Data;
+
+import java.io.Serializable;
+import java.util.Date;
+
+/**
+ * @program: webSocketProject
+ * @description:
+ * @create: 2024-03-27 17:01
+ **/
+@Data
+@TableName("data_service_key")
+public class DataServiceKey implements Serializable {
+
+    private static final long serialVersionUID = 1L;
+
+
+    /**
+     * id
+     */
+    @TableId(type = IdType.AUTO)
+    private Integer id;
+
+    /**
+     * key
+     */
+    private String tokenKey;
+
+    /**
+     * 是否可用 0:否 1: 是
+     */
+    private Integer isAvailable;
+
+    /**
+     * 到期时间
+     */
+    private Date expirationTime;
+
+    /**
+     * 添加时间
+     */
+    private Date startTime;
+
+    /**
+     * 修改时间
+     */
+    private Date updateTime;
+
+    /**
+     * 备注
+     */
+    private String remark;
+
+    public DataServiceKey() {}
+}
diff --git a/websocketSerivce/src/main/java/org/example/pojo/StockNewShare.java b/websocketSerivce/src/main/java/org/example/pojo/StockNewShare.java
new file mode 100644
index 0000000..139fad1
--- /dev/null
+++ b/websocketSerivce/src/main/java/org/example/pojo/StockNewShare.java
@@ -0,0 +1,59 @@
+package org.example.pojo;
+
+import com.baomidou.mybatisplus.annotation.IdType;
+import com.baomidou.mybatisplus.annotation.TableId;
+import com.baomidou.mybatisplus.annotation.TableName;
+import lombok.Data;
+
+import java.io.Serializable;
+
+/**
+ * @program: webSocketProject
+ * @description:
+ * @create: 2024-03-27 13:34
+ **/
+@Data
+@TableName("stock_new_share")
+public class StockNewShare implements Serializable {
+
+    private static final long serialVersionUID = 1L;
+
+
+    /**
+     * id
+     */
+    @TableId(type = IdType.AUTO)
+    private Integer id;
+
+    /**
+     * 产品id
+     */
+    private String pid;
+
+    /**
+     * 发现价格
+     */
+    private String ipoPrice;
+
+    /**
+     * 发行市价
+     */
+    private String ipoValue;
+
+    /**
+     * 交易平台
+     */
+    private String exchange;
+
+    /**
+     * 公司名称
+     */
+    private String company;
+
+    /**
+     * 上市时间
+     */
+    private String iopListing;
+
+    public StockNewShare() {}
+}
\ No newline at end of file
diff --git a/websocketSerivce/src/main/java/org/example/timedTask/NewsTask.java b/websocketSerivce/src/main/java/org/example/timedTask/JournalismTask.java
similarity index 88%
rename from websocketSerivce/src/main/java/org/example/timedTask/NewsTask.java
rename to websocketSerivce/src/main/java/org/example/timedTask/JournalismTask.java
index 128bfee..a6ec46f 100644
--- a/websocketSerivce/src/main/java/org/example/timedTask/NewsTask.java
+++ b/websocketSerivce/src/main/java/org/example/timedTask/JournalismTask.java
@@ -1,6 +1,5 @@
 package org.example.timedTask;
 
-import org.example.enums.EStockType;
 import org.example.server.ISiteNewsService;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -14,8 +13,8 @@
 
 
 @Component
-public class NewsTask {
-    private static final Logger log = LoggerFactory.getLogger(NewsTask.class);
+public class JournalismTask {
+    private static final Logger log = LoggerFactory.getLogger(JournalismTask.class);
 
     @Autowired
     ISiteNewsService iSiteNewsService;
diff --git a/websocketSerivce/src/main/java/org/example/timedTask/NewShareTask.java b/websocketSerivce/src/main/java/org/example/timedTask/NewShareTask.java
new file mode 100644
index 0000000..aba120a
--- /dev/null
+++ b/websocketSerivce/src/main/java/org/example/timedTask/NewShareTask.java
@@ -0,0 +1,80 @@
+package org.example.timedTask;
+
+import cn.hutool.json.JSONUtil;
+import com.alibaba.fastjson.JSONArray;
+import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
+import com.google.gson.Gson;
+import lombok.extern.slf4j.Slf4j;
+import org.example.dao.StockNewShareMapper;
+import org.example.enums.EStockType;
+import org.example.pojo.StockNewShare;
+import org.example.util.HttpClientRequest;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.scheduling.annotation.Scheduled;
+import org.springframework.stereotype.Component;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+/**
+ * @program: webSocketProject
+ * @description: 新股
+ * @create: 2024-03-27 13:14
+ **/
+@Component
+@Slf4j
+public class NewShareTask {
+
+    private final Lock lock = new ReentrantLock();
+
+    @Autowired
+    StockNewShareMapper stockNewShareMapper;
+
+    /*
+     * ipo、新股日历抓取
+     * */
+    @Scheduled(cron = "0 0/1 * * * ? ")
+    public void get() {
+        if (lock.tryLock()) {
+            log.info("ipo、新股日历抓取--------->开始");
+            try {
+                newShare(EStockType.IN);
+            } finally {
+                lock.unlock();
+                log.info("ipo、新股日历抓取--------->结束");
+            }
+        } else {
+            log.info("ipo、新股日历抓取--------->上次任务还未执行完成,本次任务忽略");
+        }
+    }
+
+    public void newShare(EStockType e) {
+        String result = HttpClientRequest.doGet(e.stockUrl + "new-stock?country_id=" + e.getContryId() + "&key=" + e.stockKey);
+        JSONArray jsonArray = JSONArray.parseArray(result);
+
+        List<StockNewShare> list = jsonArray.stream().map(stock -> JSONUtil.toBean(stock.toString(), StockNewShare.class)).collect(Collectors.toList());
+        List<String> pidList = list.stream().map(StockNewShare::getPid).collect(Collectors.toList());
+        List<StockNewShare> shareList = stockNewShareMapper.selectList(new LambdaQueryWrapper<StockNewShare>().in(StockNewShare::getPid, pidList));
+        Map<String, StockNewShare> resultMap = shareList.stream()
+                .collect(Collectors.toMap(StockNewShare::getPid, Function.identity()));
+        list.forEach(f -> {
+            StockNewShare share = resultMap.get(f.getPid());
+            if (share == null) {
+                stockNewShareMapper.insert(f);
+            } else {
+                f.setId(share.getId());
+                stockNewShareMapper.updateById(f);
+            }
+        });
+
+    }
+
+}
diff --git a/websocketSerivce/src/main/java/org/example/util/ApplicationContextRegisterUtil.java b/websocketSerivce/src/main/java/org/example/util/ApplicationContextRegisterUtil.java
new file mode 100644
index 0000000..6b07b6b
--- /dev/null
+++ b/websocketSerivce/src/main/java/org/example/util/ApplicationContextRegisterUtil.java
@@ -0,0 +1,27 @@
+package org.example.util;
+
+import org.springframework.beans.BeansException;
+import org.springframework.context.ApplicationContext;
+import org.springframework.context.ApplicationContextAware;
+import org.springframework.context.annotation.Lazy;
+import org.springframework.stereotype.Component;
+
+/**
+ * @program: webSocketProject
+ * @description:
+ * @create: 2024-03-27 17:55
+ **/
+@Component
+@Lazy(false)
+public class ApplicationContextRegisterUtil  implements ApplicationContextAware {
+
+    private static ApplicationContext APPLICATION_CONTEXT;
+
+    @Override
+    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
+        APPLICATION_CONTEXT = applicationContext;
+    }
+    public static ApplicationContext getApplicationContext() {
+        return APPLICATION_CONTEXT;
+    }
+}
\ No newline at end of file
diff --git a/websocketSerivce/src/main/java/org/example/websocket/controller/GenerateKey.java b/websocketSerivce/src/main/java/org/example/websocket/controller/GenerateKey.java
new file mode 100644
index 0000000..34c4205
--- /dev/null
+++ b/websocketSerivce/src/main/java/org/example/websocket/controller/GenerateKey.java
@@ -0,0 +1,50 @@
+package org.example.websocket.controller;
+
+import cn.hutool.core.date.DateUtil;
+import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
+import org.example.common.ServerResponse;
+import org.example.dao.DataServiceKeyMapper;
+import org.example.pojo.DataServiceKey;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.web.bind.annotation.PostMapping;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RequestParam;
+import org.springframework.web.bind.annotation.RestController;
+
+import java.util.Date;
+import java.util.UUID;
+
+
+/**
+ * @program: webSocketProject
+ * @description: 生成key
+ * @create: 2024-03-27 19:38
+ **/
+@RestController
+@RequestMapping("/api")
+public class GenerateKey {
+
+    @Autowired
+    DataServiceKeyMapper dataServiceKeyMapper;
+
+    @PostMapping("/creationKey")
+    public ServerResponse sendNotification(@RequestParam("time") Date time) {
+        String randomKey = UUID.randomUUID().toString();
+        try {
+            Long count = dataServiceKeyMapper.selectCount(new LambdaQueryWrapper<DataServiceKey>().eq(DataServiceKey::getTokenKey, randomKey));
+            if(count > 0){
+                return ServerResponse.createByErrorMsg("请重新生成");
+            }
+            System.out.println(randomKey);
+            DataServiceKey dataServiceKey = new DataServiceKey();
+            dataServiceKey.setTokenKey(randomKey);
+            dataServiceKey.setExpirationTime(time);
+            dataServiceKey.setStartTime(DateUtil.date());
+            dataServiceKeyMapper.insert(dataServiceKey);
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+        return ServerResponse.createBySuccessMsg(randomKey);
+    }
+
+}
diff --git a/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java b/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java
index 0984ec6..95d9667 100644
--- a/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java
+++ b/websocketSerivce/src/main/java/org/example/websocket/server/WsServer.java
@@ -1,15 +1,35 @@
 package org.example.websocket.server;
 
+import cn.hutool.core.date.DateUnit;
+import cn.hutool.core.date.DateUtil;
+import cn.hutool.core.io.unit.DataUnit;
+import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
+import com.baomidou.mybatisplus.core.toolkit.StringUtils;
 import lombok.NonNull;
 import lombok.extern.slf4j.Slf4j;
+import org.example.dao.DataServiceKeyMapper;
+import org.example.pojo.DataServiceKey;
+import org.example.util.ApplicationContextRegisterUtil;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.ApplicationContext;
+import org.springframework.scheduling.annotation.Scheduled;
 import org.springframework.stereotype.Component;
 import org.springframework.web.bind.annotation.PostMapping;
 import org.springframework.web.bind.annotation.RequestBody;
 
 import javax.websocket.*;
+import javax.websocket.server.PathParam;
 import javax.websocket.server.ServerEndpoint;
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
 
 /**
  * @ClassDescription: websocket服务端
@@ -18,12 +38,11 @@
  */
 @Slf4j
 @Component
-//@RestController
 @ServerEndpoint("/websocket-server")
-//@ServerEndpoint("/")
 public class WsServer {
 
     private Session session;
+
     /**
      * 记录在线连接客户端数量
      */
@@ -31,41 +50,127 @@
     /**
      * 存放每个连接进来的客户端对应的websocketServer对象,用于后面群发消息
      */
-    private static CopyOnWriteArrayList<WsServer> wsServers = new CopyOnWriteArrayList<>();
+    private static ConcurrentHashMap<String, Session> webSocketMap = new ConcurrentHashMap<>();
+
+    private final Lock lock = new ReentrantLock();
+
+    /**
+     * 关闭过期链接
+     */
+    @Scheduled(cron = "0  0/1  *  *  *  ?")
+    public void clearExpiration() {
+        if (lock.tryLock()) {
+            log.info("webSocket关闭过期链接--------->开始");
+            try {
+                ApplicationContext act = ApplicationContextRegisterUtil.getApplicationContext();
+                DataServiceKeyMapper mapper = act.getBean(DataServiceKeyMapper.class);
+                LambdaQueryWrapper<DataServiceKey> queryWrapper = new LambdaQueryWrapper<>();
+                List<DataServiceKey> keyList = mapper.selectList(queryWrapper.lt(DataServiceKey::getExpirationTime, DateUtil.date())
+                        .or().eq(DataServiceKey::getIsAvailable, 0));
+
+                keyList.forEach(f -> {
+                    //  在锁内操作
+                    Session sessions = webSocketMap.get(f.getTokenKey());
+                    if (null != sessions) {
+                        try {
+                            sendInfo(sessions, 0);
+                            f.setIsAvailable(0);
+                            mapper.updateById(f);
+                        } catch (Exception e) {
+                            throw new RuntimeException(e);
+                        }
+                    }
+                    log.info("webSocket关闭过期链接---key:"+f.getTokenKey());
+                });
+            } catch (Exception e) {
+                throw new RuntimeException(e);
+            } finally {
+                lock.unlock();
+                log.info("webSocket关闭过期链接--------->结束");
+            }
+        } else {
+            log.info("webSocket关闭过期链接--------->上次任务还未执行完成,本次任务忽略");
+        }
+    }
+
 
     /**
      * 服务端与客户端连接成功时执行
+     *
      * @param session 会话
      */
     @OnOpen
-    public void onOpen(Session session){
-        this.session = session;
-        //接入的客户端+1
-        int count = onlineCount.incrementAndGet();
-        //集合中存入客户端对象+1
-        wsServers.add(this);
-        log.info("与客户端连接成功,当前连接的客户端数量为:{}", count);
+    public void onOpen(Session session) throws Exception {
+        try {
+            this.session = session;
+            //查询该用户有没有权限
+            Map<String, List<String>> params = session.getRequestParameterMap();
+            List<String> funcTypes = params.get("key");
+            // 取出funcType参数的值
+            if (null == funcTypes) {
+                sendInfo(session, 1);
+            }
+            String key = funcTypes.get(0);
+            if (StringUtils.isEmpty(key)) {
+                sendInfo(session, 1);
+            }
+            ApplicationContext act = ApplicationContextRegisterUtil.getApplicationContext();
+            DataServiceKeyMapper mapper = act.getBean(DataServiceKeyMapper.class);
+            DataServiceKey dataServiceKey = mapper.selectOne(new LambdaQueryWrapper<DataServiceKey>()
+                    .eq(DataServiceKey::getTokenKey, key)
+                    .gt(DataServiceKey::getExpirationTime, DateUtil.date())
+                    .eq(DataServiceKey::getIsAvailable, 1));
+            if (null != dataServiceKey) {
+                if (webSocketMap.containsKey(key)) {
+                    webSocketMap.put(key, session);
+                } else {
+                    //接入的客户端+1
+                    onlineCount.incrementAndGet();
+                    webSocketMap.put(key, session);
+                }
+
+                log.info("与客户端连接成功,当前连接的客户端数量为:{}", onlineCount);
+            } else {
+                sendInfo(session, 1);
+            }
+        } catch (Exception e) {
+            sendInfo(session, 1);
+            log.error("客户端连接错误:" + e.getMessage());
+        }
+    }
+
+
+    /**
+     * 关闭连接
+     */
+    public void sendInfo(Session session, int type) throws Exception {
+        if (session.getBasicRemote() != null) {
+            if (type == 1) {
+                session.getBasicRemote().sendText("订阅key不存在!");
+            }
+            session.close();
+        }
     }
 
     /**
      * 收到客户端的消息时执行
+     *
      * @param message 消息
      * @param session 会话
      */
     @OnMessage
-    public void onMessage(String message, Session session){
+    public void onMessage(Session session, String message) {
         log.info("收到来自客户端的消息,客户端地址:{},消息内容:{}", session.getMessageHandlers(), message);
-        //业务逻辑,对消息的处理
-//        sendMessageToAll("群发消息的内容");
     }
 
     /**
      * 连接发生报错时执行
-     * @param session 会话
+     *
+     * @param session   会话
      * @param throwable 报错
      */
     @OnError
-    public void onError(Session session, @NonNull Throwable throwable){
+    public void onError(Session session, @NonNull Throwable throwable) {
         log.error("连接发生报错");
         throwable.printStackTrace();
     }
@@ -74,53 +179,36 @@
      * 连接断开时执行
      */
     @OnClose
-    public void onClose(){
+    public void onClose() {
         //接入客户端连接数-1
-        int count = onlineCount.decrementAndGet();
+        if (webSocketMap.get(this.session) != null) {
+            onlineCount.decrementAndGet();
+        }
         //集合中的客户端对象-1
-        wsServers.remove(this);
-        log.info("服务端断开连接,当前连接的客户端数量为:{}", count);
+        webSocketMap.remove(this);
+        log.info("服务端断开连接,当前连接的客户端数量为:{}", onlineCount.get());
     }
-
-
-    /**
-     * 向客户端推送消息
-     * @param message 消息
-     */
-    public void sendMessage(String message){
-        this.session.getAsyncRemote().sendText(message);
-        log.info("推送消息给客户端:{},消息内容为:{}", this.session.getMessageHandlers(), message);
-    }
-
-//    @PostMapping("/send2c")
-//    public void sendMessage1(@RequestBody String message){
-//        this.session.getAsyncRemote().sendText(message);
-//        try {
-//            this.session.getBasicRemote().sendText(message);
-//        } catch (IOException e) {
-//            throw new RuntimeException(e);
-//        }
-//        log.info("推送消息给客户端,消息内容为:{}", message);
-//    }
 
 
     /**
      * 群发消息
+     *
      * @param message 消息
      */
-    public void sendMessageToAll(String message){
-        CopyOnWriteArrayList<WsServer> ws = wsServers;
-        for (WsServer wsServer : ws){
-            wsServer.sendMessage(message);
-        }
+    public void sendMessageToAll(String message) { // 发送消息给所有客户端
+        webSocketMap.values().forEach(session -> {
+            try {
+                session.getBasicRemote().sendText(message);
+            } catch (IOException e) {
+                log.error("发送消息给客户端失败,客户端ID为:{}", session.getId(), e);
+            }
+        });
+        log.info("推送消息给所有客户端,消息内容为:{}", message);
     }
 
     @PostMapping("/send2AllC")
-    public void sendMessageToAll1(@RequestBody String message){
-        CopyOnWriteArrayList<WsServer> ws = wsServers;
-        for (WsServer wsServer : ws){
-            wsServer.sendMessage(message);
-        }
+    public void sendMessageToAll1(@RequestBody String message) { // 发送消息给所有客户端
+        sendMessageToAll(message);
     }
 
 

--
Gitblit v1.9.3