first commit

This commit is contained in:
ytc1012
2025-11-18 18:08:48 +08:00
commit de90ad79ea
162 changed files with 28098 additions and 0 deletions

137
net/server.go Normal file
View File

@@ -0,0 +1,137 @@
package net
import (
"context"
"github.com/gorilla/websocket"
"log/slog"
"net/http"
"os"
"os/signal"
"slgserver/config"
"strings"
"syscall"
"time"
)
// http升级websocket协议的配置
var wsUpgrader = websocket.Upgrader{
// 允许配置的 CORS 请求
CheckOrigin: checkOrigin,
}
var allowedOrigins = loadAllowedOrigins()
type server struct {
addr string
router *Router
needSecret bool
beforeClose func(WSConn)
httpServer *http.Server
}
func NewServer(addr string, needSecret bool) *server {
s := server{
addr: addr,
needSecret: needSecret,
}
return &s
}
func (this *server) Router(router *Router) {
this.router = router
}
func (this *server) Start() {
mux := http.NewServeMux()
mux.HandleFunc("/", this.wsHandler)
this.httpServer = &http.Server{
Addr: this.addr,
Handler: mux,
}
slog.Info("server starting", "addr", this.addr)
go func() {
if err := this.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
slog.Error("server start failed", "error", err, "addr", this.addr)
}
}()
this.waitForShutdown()
}
func (this *server) waitForShutdown() {
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
slog.Info("server shutting down...")
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := this.httpServer.Shutdown(ctx); err != nil {
slog.Error("server shutdown error", "error", err)
} else {
slog.Info("server shutdown gracefully")
}
}
func (this *server) SetOnBeforeClose(hookFunc func(WSConn)) {
this.beforeClose = hookFunc
}
func (this *server) wsHandler(resp http.ResponseWriter, req *http.Request) {
wsSocket, err := wsUpgrader.Upgrade(resp, req, nil)
if err != nil {
slog.Warn("websocket upgrade failed", "error", err, "remote", req.RemoteAddr)
return
}
conn := ConnMgr.NewConn(wsSocket, this.needSecret)
slog.Info("client connect", "addr", wsSocket.RemoteAddr().String())
conn.SetRouter(this.router)
conn.SetOnClose(ConnMgr.RemoveConn)
conn.SetOnBeforeClose(this.beforeClose)
conn.Start()
conn.Handshake()
}
func checkOrigin(r *http.Request) bool {
if len(allowedOrigins) == 0 {
return true
}
origin := strings.ToLower(strings.TrimSpace(r.Header.Get("Origin")))
if origin == "" {
return true
}
if _, ok := allowedOrigins[origin]; ok {
return true
}
slog.Warn("origin not allowed", "origin", origin)
return false
}
func loadAllowedOrigins() map[string]struct{} {
origins := config.GetString("server.allowed_origins", "")
if origins == "" {
return nil
}
originMap := make(map[string]struct{})
items := strings.Split(origins, ",")
for _, item := range items {
val := strings.ToLower(strings.TrimSpace(item))
if val == "" {
continue
}
if val == "*" {
return nil
}
originMap[val] = struct{}{}
}
return originMap
}