package router

import (
	"fmt"
	"log/slog"
	"net/http"
	"runtime/debug"
	"strings"
	"time"
)

import (
	"github.com/gin-gonic/gin"
)

import (
	"auth-server/common"
	"auth-server/logger"
)

// InitMiddleware registers the global middleware chain on r.
//
// Middleware runs in registration order: response-header middleware first,
// then request logging and panic recovery, then the request-header checks,
// and finally the token check.
func InitMiddleware(r *gin.Engine) {
	// Headers that tell the client not to cache the response.
	r.Use(NoCache)
	// CORS handling: answers OPTIONS requests and ends the chain for them.
	r.Use(Options)
	// Security and resource access headers.
	r.Use(Secure)
	// Request logging through slog.
	r.Use(GinLogger(logger.WithGroup("gin")))
	// Global panic recovery.
	r.Use(GinRecovery(logger.WithGroup("ginRecovery")))
	// Request header checks.
	r.Use(CheckLanguage)
	r.Use(CheckSource)
	r.Use(CheckProduct)
	// Token check.
	r.Use(CheckAuth)
}

// NoCache is a middleware function that appends headers
// to prevent the client from caching the HTTP response.
func NoCache(c *gin.Context) {
	// Only standard directives are sent; the stray "value" token that used to
	// trail this list is not a Cache-Control directive.
	c.Header("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate")
	c.Header("Expires", "Thu, 01 Jan 1970 00:00:00 GMT")
	c.Header("Last-Modified", time.Now().UTC().Format(http.TimeFormat))
	c.Next()
}

// Options is a middleware function that appends headers
// for options requests and aborts then exits the middleware
// chain and ends the request.
//
// Requests using any other method are passed down the chain untouched.
func Options(c *gin.Context) {
	if c.Request.Method != http.MethodOptions {
		c.Next()
		return
	}

	c.Header("Access-Control-Allow-Origin", "*")
	c.Header("Access-Control-Allow-Methods", "GET,POST,PUT,PATCH,DELETE,OPTIONS")
	// Language, Source and Product are read by the header-check middleware
	// registered in InitMiddleware, so a cross-origin preflight has to allow
	// them alongside the standard headers.
	c.Header("Access-Control-Allow-Headers", "authorization, origin, content-type, accept, language, source, product")
	c.Header("Allow", "HEAD,GET,POST,PUT,PATCH,DELETE,OPTIONS")
	c.Header("Content-Type", "application/json")
	c.AbortWithStatus(http.StatusOK)
}

// Secure is a middleware function that appends security
// and resource access headers.
func Secure(c *gin.Context) { c.Header("Access-Control-Allow-Origin", "*") //c.Header("X-Frame-Options", "DENY") c.Header("X-Content-Type-Options", "nosniff") c.Header("X-XSS-Protection", "1; mode=block") if c.Request.TLS != nil { c.Header("Strict-Transport-Security", "max-age=31536000") } // Also consider adding Content-Security-Policy headers // c.Header("Content-Security-Policy", "script-src 'self' https://cdnjs.cloudflare.com") } func CheckLanguage(c *gin.Context) { if strings.Contains(c.FullPath(), "/api/v1/auth") { ok := false language := c.Request.Header.Get("Language") for _, l := range common.MetadataConfig.GetLanguages() { if language == l.ToString() { ok = true } } if !ok { c.AbortWithStatusJSON(200, common.ErrToH(common.InvalidLanguage, c.GetHeader("Language"))) } else { c.Set("language", language) } } c.Next() } func CheckSource(c *gin.Context) { if strings.Contains(c.FullPath(), "/api/v1/auth") { source := c.Request.Header.Get("Source") ok := false for _, s := range common.MetadataConfig.GetSources() { if source == s.ToString() { ok = true } } if !ok { c.AbortWithStatusJSON(200, common.ErrToH(common.InvalidSource, c.GetHeader("Source"))) } else { c.Set("source", source) } } c.Next() } func CheckProduct(c *gin.Context) { if strings.Contains(c.FullPath(), "/api/v1/auth") { product := c.Request.Header.Get("Product") if strings.ToUpper(product) != common.MetadataConfig.GetProduct().ToString() { c.AbortWithStatusJSON(200, common.ErrToH(common.InvalidProduct, product)) } else { c.Set("product", product) } } c.Next() } func CheckAuth(c *gin.Context) { if strings.Contains(c.FullPath(), "/api/v1/auth") { token := c.Request.Header.Get("Authorization") uid, username, err := common.ParseToken(strings.TrimPrefix(token, "Bearer ")) if err != nil { c.AbortWithStatusJSON(200, common.ErrToH(common.InvalidToken, c.GetHeader("locale"))) } c.Set("uid", uid) c.Set("username", username) } c.Next() } func GinLogger(logger *slog.Logger) gin.HandlerFunc { return func(c 
*gin.Context) { start := time.Now() path := c.Request.URL.Path if len(c.Request.URL.RawQuery) > 0 { path += "?" + c.Request.URL.RawQuery } c.Next() cost := time.Since(start) logger.Info(fmt.Sprintf("[%s]%s, header[%s-%s-%s] ip[%s], resp[%d] %s errors[%s]", c.Request.Method, path, c.GetHeader("Product"), c.GetHeader("Source"), c.GetHeader("Language"), c.ClientIP(), c.Writer.Status(), cost.String(), c.Errors.ByType(gin.ErrorTypePrivate).String(), )) } } func GinRecovery(logger *slog.Logger) gin.HandlerFunc { return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) { logger.Error("Recovery from panic", "recoverd", recovered, "stack", string(debug.Stack())) common.HttpErr(c, common.Unknown) c.AbortWithStatus(http.StatusOK) }) }