package net

import (
	"context"
	"errors"
	"log/slog"
	"net/http"
	"os"
	"os/signal"
	"strings"
	"syscall"
	"time"

	"github.com/gorilla/websocket"
	"slgserver/config"
)

const (
	// readHeaderTimeout bounds how long a client may take to send the HTTP
	// upgrade request headers, so slow clients cannot hold sockets open
	// indefinitely before the handshake. It only covers header reading.
	readHeaderTimeout = 10 * time.Second
	// shutdownTimeout bounds how long graceful shutdown waits for in-flight
	// HTTP requests before giving up.
	shutdownTimeout = 30 * time.Second
)

// wsUpgrader upgrades incoming HTTP requests to the WebSocket protocol.
var wsUpgrader = websocket.Upgrader{
	// Only allow cross-origin handshakes from configured origins; see checkOrigin.
	CheckOrigin: checkOrigin,
}

// allowedOrigins is the set of normalized origins permitted to open a
// WebSocket. An empty or nil set means every origin is allowed. It is read
// once, at package initialization, from the "server.allowed_origins" key.
//
// NOTE(review): this relies on the config package having loaded its values
// before this package initializes — confirm against config's init.
var allowedOrigins = loadAllowedOrigins()

// server accepts WebSocket clients on addr and hands each one to ConnMgr.
type server struct {
	addr        string       // listen address for the HTTP server
	router      *Router      // attached to every connection via SetRouter
	needSecret  bool         // forwarded to ConnMgr.NewConn for each connection
	beforeClose func(WSConn) // attached to every connection via SetOnBeforeClose
	httpServer  *http.Server // created by Start
}

// NewServer returns a server that will listen on addr. needSecret is
// forwarded to ConnMgr.NewConn for every accepted connection. Attach a
// router with Router and begin serving with Start.
func NewServer(addr string, needSecret bool) *server {
	return &server{
		addr:       addr,
		needSecret: needSecret,
	}
}

// Router sets the router attached to every connection accepted afterwards.
func (s *server) Router(router *Router) {
	s.router = router
}

// Start begins listening on the configured address and blocks until either
// the process receives SIGINT/SIGTERM (after which the HTTP server is shut
// down gracefully) or the listener itself fails, e.g. because the address is
// already in use. In the latter case the error is logged and Start returns
// instead of waiting forever for a signal.
func (s *server) Start() {
	mux := http.NewServeMux()
	mux.HandleFunc("/", s.wsHandler)

	s.httpServer = &http.Server{
		Addr:              s.addr,
		Handler:           mux,
		ReadHeaderTimeout: readHeaderTimeout,
	}

	slog.Info("server starting", "addr", s.addr)

	// Buffered so the serving goroutine can always deliver its result and
	// exit, even when waitForShutdown returned via the signal path and no
	// longer reads from the channel.
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- s.httpServer.ListenAndServe()
	}()

	s.waitForShutdown(serveErr)
}

// waitForShutdown blocks until either SIGINT/SIGTERM arrives or serveErr
// reports that ListenAndServe returned. On a signal it shuts the HTTP server
// down, waiting up to shutdownTimeout for in-flight requests. If the listener
// stopped on its own there is nothing to shut down, so it logs and returns.
//
// NOTE(review): http.Server.Shutdown does not close or wait for hijacked
// connections, which includes upgraded WebSockets; they are not closed here.
func (s *server) waitForShutdown(serveErr <-chan error) {
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(quit)

	select {
	case err := <-serveErr:
		// ListenAndServe always returns a non-nil error; ErrServerClosed
		// only means Shutdown/Close was already called elsewhere.
		if !errors.Is(err, http.ErrServerClosed) {
			slog.Error("server start failed", "error", err, "addr", s.addr)
		}
		return
	case <-quit:
	}

	slog.Info("server shutting down...")

	ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
	defer cancel()

	if err := s.httpServer.Shutdown(ctx); err != nil {
		slog.Error("server shutdown error", "error", err)
		return
	}
	slog.Info("server shutdown gracefully")
}

// SetOnBeforeClose registers hookFunc; it is attached to every connection
// accepted afterwards via conn.SetOnBeforeClose.
func (s *server) SetOnBeforeClose(hookFunc func(WSConn)) {
	s.beforeClose = hookFunc
}

// wsHandler upgrades the request to a WebSocket, creates the connection via
// ConnMgr.NewConn, wires up the router and close hooks (ConnMgr.RemoveConn
// on close), then starts the connection and performs the handshake.
// Upgrade failures are logged and the request is dropped; gorilla's Upgrade
// has already replied to the client with an HTTP error in that case.
func (s *server) wsHandler(resp http.ResponseWriter, req *http.Request) {
	wsSocket, err := wsUpgrader.Upgrade(resp, req, nil)
	if err != nil {
		slog.Warn("websocket upgrade failed", "error", err, "remote", req.RemoteAddr)
		return
	}

	conn := ConnMgr.NewConn(wsSocket, s.needSecret)
	slog.Info("client connect", "addr", wsSocket.RemoteAddr().String())

	conn.SetRouter(s.router)
	conn.SetOnClose(ConnMgr.RemoveConn)
	conn.SetOnBeforeClose(s.beforeClose)
	conn.Start()
	conn.Handshake()
}

// checkOrigin is the Upgrader's CheckOrigin hook. The handshake is allowed
// when:
//   - no origin allow-list is configured;
//   - the request has no Origin header (non-browser clients typically omit
//     it);
//   - the normalized Origin is in allowedOrigins.
//
// Rejected origins are logged.
func checkOrigin(r *http.Request) bool {
	if len(allowedOrigins) == 0 {
		return true
	}
	origin := normalizeOrigin(r.Header.Get("Origin"))
	if origin == "" {
		return true
	}
	if _, ok := allowedOrigins[origin]; ok {
		return true
	}
	slog.Warn("origin not allowed", "origin", origin)
	return false
}

// normalizeOrigin canonicalizes an origin for comparison. It trims
// surrounding whitespace and any trailing "/", then lowercases the result.
// Browsers send Origin without a trailing slash. A config entry such as
// "https://example.com/" is easy to write and would otherwise never match.
func normalizeOrigin(origin string) string {
	return strings.ToLower(strings.TrimRight(strings.TrimSpace(origin), "/"))
}

// loadAllowedOrigins parses the comma-separated "server.allowed_origins"
// config value into a set of normalized origins. The result is nil or empty,
// meaning every origin is allowed, when the value is empty, has only blank
// entries, or contains the wildcard "*".
func loadAllowedOrigins() map[string]struct{} {
	raw := config.GetString("server.allowed_origins", "")
	if raw == "" {
		return nil
	}

	originMap := make(map[string]struct{})
	for _, item := range strings.Split(raw, ",") {
		val := normalizeOrigin(item)
		if val == "" {
			continue
		}
		if val == "*" {
			return nil
		}
		originMap[val] = struct{}{}
	}
	return originMap
}