138 lines
2.8 KiB
Go
138 lines
2.8 KiB
Go
package net
|
|
|
|
import (
	"context"
	"log/slog"
	stdnet "net" // aliased: this package is itself named net
	"net/http"
	"os"
	"os/signal"
	"strings"
	"syscall"
	"time"

	"github.com/gorilla/websocket"

	"slgserver/config"
)
|
|
|
|
// http升级websocket协议的配置
|
|
var wsUpgrader = websocket.Upgrader{
|
|
// 允许配置的 CORS 请求
|
|
CheckOrigin: checkOrigin,
|
|
}
|
|
|
|
var allowedOrigins = loadAllowedOrigins()
|
|
|
|
type server struct {
|
|
addr string
|
|
router *Router
|
|
needSecret bool
|
|
beforeClose func(WSConn)
|
|
httpServer *http.Server
|
|
}
|
|
|
|
func NewServer(addr string, needSecret bool) *server {
|
|
s := server{
|
|
addr: addr,
|
|
needSecret: needSecret,
|
|
}
|
|
return &s
|
|
}
|
|
|
|
func (this *server) Router(router *Router) {
|
|
this.router = router
|
|
}
|
|
|
|
func (this *server) Start() {
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/", this.wsHandler)
|
|
|
|
this.httpServer = &http.Server{
|
|
Addr: this.addr,
|
|
Handler: mux,
|
|
}
|
|
|
|
slog.Info("server starting", "addr", this.addr)
|
|
|
|
go func() {
|
|
if err := this.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
slog.Error("server start failed", "error", err, "addr", this.addr)
|
|
}
|
|
}()
|
|
|
|
this.waitForShutdown()
|
|
}
|
|
|
|
func (this *server) waitForShutdown() {
|
|
quit := make(chan os.Signal, 1)
|
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
|
<-quit
|
|
|
|
slog.Info("server shutting down...")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
if err := this.httpServer.Shutdown(ctx); err != nil {
|
|
slog.Error("server shutdown error", "error", err)
|
|
} else {
|
|
slog.Info("server shutdown gracefully")
|
|
}
|
|
}
|
|
|
|
func (this *server) SetOnBeforeClose(hookFunc func(WSConn)) {
|
|
this.beforeClose = hookFunc
|
|
}
|
|
|
|
func (this *server) wsHandler(resp http.ResponseWriter, req *http.Request) {
|
|
|
|
wsSocket, err := wsUpgrader.Upgrade(resp, req, nil)
|
|
if err != nil {
|
|
slog.Warn("websocket upgrade failed", "error", err, "remote", req.RemoteAddr)
|
|
return
|
|
}
|
|
|
|
conn := ConnMgr.NewConn(wsSocket, this.needSecret)
|
|
slog.Info("client connect", "addr", wsSocket.RemoteAddr().String())
|
|
|
|
conn.SetRouter(this.router)
|
|
conn.SetOnClose(ConnMgr.RemoveConn)
|
|
conn.SetOnBeforeClose(this.beforeClose)
|
|
conn.Start()
|
|
conn.Handshake()
|
|
|
|
}
|
|
|
|
func checkOrigin(r *http.Request) bool {
|
|
if len(allowedOrigins) == 0 {
|
|
return true
|
|
}
|
|
origin := strings.ToLower(strings.TrimSpace(r.Header.Get("Origin")))
|
|
if origin == "" {
|
|
return true
|
|
}
|
|
if _, ok := allowedOrigins[origin]; ok {
|
|
return true
|
|
}
|
|
slog.Warn("origin not allowed", "origin", origin)
|
|
return false
|
|
}
|
|
|
|
func loadAllowedOrigins() map[string]struct{} {
|
|
origins := config.GetString("server.allowed_origins", "")
|
|
if origins == "" {
|
|
return nil
|
|
}
|
|
originMap := make(map[string]struct{})
|
|
items := strings.Split(origins, ",")
|
|
for _, item := range items {
|
|
val := strings.ToLower(strings.TrimSpace(item))
|
|
if val == "" {
|
|
continue
|
|
}
|
|
if val == "*" {
|
|
return nil
|
|
}
|
|
originMap[val] = struct{}{}
|
|
}
|
|
return originMap
|
|
}
|