first commit
This commit is contained in:
137
net/server.go
Normal file
137
net/server.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/gorilla/websocket"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"slgserver/config"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// http升级websocket协议的配置
|
||||
var wsUpgrader = websocket.Upgrader{
|
||||
// 允许配置的 CORS 请求
|
||||
CheckOrigin: checkOrigin,
|
||||
}
|
||||
|
||||
var allowedOrigins = loadAllowedOrigins()
|
||||
|
||||
type server struct {
|
||||
addr string
|
||||
router *Router
|
||||
needSecret bool
|
||||
beforeClose func(WSConn)
|
||||
httpServer *http.Server
|
||||
}
|
||||
|
||||
func NewServer(addr string, needSecret bool) *server {
|
||||
s := server{
|
||||
addr: addr,
|
||||
needSecret: needSecret,
|
||||
}
|
||||
return &s
|
||||
}
|
||||
|
||||
func (this *server) Router(router *Router) {
|
||||
this.router = router
|
||||
}
|
||||
|
||||
func (this *server) Start() {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", this.wsHandler)
|
||||
|
||||
this.httpServer = &http.Server{
|
||||
Addr: this.addr,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
slog.Info("server starting", "addr", this.addr)
|
||||
|
||||
go func() {
|
||||
if err := this.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
slog.Error("server start failed", "error", err, "addr", this.addr)
|
||||
}
|
||||
}()
|
||||
|
||||
this.waitForShutdown()
|
||||
}
|
||||
|
||||
func (this *server) waitForShutdown() {
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
|
||||
slog.Info("server shutting down...")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := this.httpServer.Shutdown(ctx); err != nil {
|
||||
slog.Error("server shutdown error", "error", err)
|
||||
} else {
|
||||
slog.Info("server shutdown gracefully")
|
||||
}
|
||||
}
|
||||
|
||||
func (this *server) SetOnBeforeClose(hookFunc func(WSConn)) {
|
||||
this.beforeClose = hookFunc
|
||||
}
|
||||
|
||||
func (this *server) wsHandler(resp http.ResponseWriter, req *http.Request) {
|
||||
|
||||
wsSocket, err := wsUpgrader.Upgrade(resp, req, nil)
|
||||
if err != nil {
|
||||
slog.Warn("websocket upgrade failed", "error", err, "remote", req.RemoteAddr)
|
||||
return
|
||||
}
|
||||
|
||||
conn := ConnMgr.NewConn(wsSocket, this.needSecret)
|
||||
slog.Info("client connect", "addr", wsSocket.RemoteAddr().String())
|
||||
|
||||
conn.SetRouter(this.router)
|
||||
conn.SetOnClose(ConnMgr.RemoveConn)
|
||||
conn.SetOnBeforeClose(this.beforeClose)
|
||||
conn.Start()
|
||||
conn.Handshake()
|
||||
|
||||
}
|
||||
|
||||
func checkOrigin(r *http.Request) bool {
|
||||
if len(allowedOrigins) == 0 {
|
||||
return true
|
||||
}
|
||||
origin := strings.ToLower(strings.TrimSpace(r.Header.Get("Origin")))
|
||||
if origin == "" {
|
||||
return true
|
||||
}
|
||||
if _, ok := allowedOrigins[origin]; ok {
|
||||
return true
|
||||
}
|
||||
slog.Warn("origin not allowed", "origin", origin)
|
||||
return false
|
||||
}
|
||||
|
||||
func loadAllowedOrigins() map[string]struct{} {
|
||||
origins := config.GetString("server.allowed_origins", "")
|
||||
if origins == "" {
|
||||
return nil
|
||||
}
|
||||
originMap := make(map[string]struct{})
|
||||
items := strings.Split(origins, ",")
|
||||
for _, item := range items {
|
||||
val := strings.ToLower(strings.TrimSpace(item))
|
||||
if val == "" {
|
||||
continue
|
||||
}
|
||||
if val == "*" {
|
||||
return nil
|
||||
}
|
||||
originMap[val] = struct{}{}
|
||||
}
|
||||
return originMap
|
||||
}
|
||||
Reference in New Issue
Block a user