middleware.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. package router
  2. import (
  3. "fmt"
  4. "log/slog"
  5. "net/http"
  6. "runtime/debug"
  7. "strings"
  8. "time"
  9. )
  10. import (
  11. "github.com/gin-gonic/gin"
  12. )
  13. import (
  14. "auth-server/common"
  15. "auth-server/logger"
  16. )
  17. func InitMiddleware(r *gin.Engine) {
  18. // NoCache is a middleware function that appends headers
  19. r.Use(NoCache)
  20. // 跨域处理
  21. r.Use(Options)
  22. // Secure is a middleware function that appends security
  23. r.Use(Secure)
  24. // Use Slog Logger
  25. r.Use(GinLogger(logger.WithGroup("gin")))
  26. // Global Recover
  27. r.Use(GinRecovery(logger.WithGroup("ginRecovery")))
  28. // check header
  29. r.Use(CheckLanguage)
  30. r.Use(CheckSource)
  31. r.Use(CheckProduct)
  32. // check token
  33. r.Use(CheckAuth)
  34. }
  35. // NoCache is a middleware function that appends headers
  36. // to prevent the client from caching the HTTP response.
  37. func NoCache(c *gin.Context) {
  38. c.Header("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate, value")
  39. c.Header("Expires", "Thu, 01 Jan 1970 00:00:00 GMT")
  40. c.Header("Last-Modified", time.Now().UTC().Format(http.TimeFormat))
  41. c.Next()
  42. }
  43. // Options is a middleware function that appends headers
  44. // for options requests and aborts then exits the middleware
  45. // chain and ends the request.
  46. func Options(c *gin.Context) {
  47. if c.Request.Method != "OPTIONS" {
  48. c.Next()
  49. } else {
  50. c.Header("Access-Control-Allow-Origin", "*")
  51. c.Header("Access-Control-Allow-Methods", "GET,POST,PUT,PATCH,DELETE,OPTIONS")
  52. c.Header("Access-Control-Allow-Headers", "authorization, origin, content-type, accept")
  53. c.Header("Allow", "HEAD,GET,POST,PUT,PATCH,DELETE,OPTIONS")
  54. c.Header("Content-Type", "application/json")
  55. c.AbortWithStatus(200)
  56. }
  57. }
  58. // Secure is a middleware function that appends security
  59. // and resource access headers.
  60. func Secure(c *gin.Context) {
  61. c.Header("Access-Control-Allow-Origin", "*")
  62. //c.Header("X-Frame-Options", "DENY")
  63. c.Header("X-Content-Type-Options", "nosniff")
  64. c.Header("X-XSS-Protection", "1; mode=block")
  65. if c.Request.TLS != nil {
  66. c.Header("Strict-Transport-Security", "max-age=31536000")
  67. }
  68. // Also consider adding Content-Security-Policy headers
  69. // c.Header("Content-Security-Policy", "script-src 'self' https://cdnjs.cloudflare.com")
  70. }
  71. func CheckLanguage(c *gin.Context) {
  72. if strings.Contains(c.FullPath(), "/api/v1/auth") {
  73. ok := false
  74. language := c.Request.Header.Get("Language")
  75. for _, l := range common.MetadataConfig.GetLanguages() {
  76. if language == l.ToString() {
  77. ok = true
  78. }
  79. }
  80. if !ok {
  81. c.AbortWithStatusJSON(200, common.ErrToH(common.InvalidLanguage, c.GetHeader("Language")))
  82. } else {
  83. c.Set("language", language)
  84. }
  85. }
  86. c.Next()
  87. }
  88. func CheckSource(c *gin.Context) {
  89. if strings.Contains(c.FullPath(), "/api/v1/auth") {
  90. source := c.Request.Header.Get("Source")
  91. ok := false
  92. for _, s := range common.MetadataConfig.GetSources() {
  93. if source == s.ToString() {
  94. ok = true
  95. }
  96. }
  97. if !ok {
  98. c.AbortWithStatusJSON(200, common.ErrToH(common.InvalidSource, c.GetHeader("Source")))
  99. } else {
  100. c.Set("source", source)
  101. }
  102. }
  103. c.Next()
  104. }
  105. func CheckProduct(c *gin.Context) {
  106. if strings.Contains(c.FullPath(), "/api/v1/auth") {
  107. product := c.Request.Header.Get("Product")
  108. if strings.ToUpper(product) != common.MetadataConfig.GetProduct().ToString() {
  109. c.AbortWithStatusJSON(200, common.ErrToH(common.InvalidProduct, product))
  110. } else {
  111. c.Set("product", product)
  112. }
  113. }
  114. c.Next()
  115. }
  116. func CheckAuth(c *gin.Context) {
  117. if strings.Contains(c.FullPath(), "/api/v1/auth") {
  118. token := c.Request.Header.Get("Authorization")
  119. uid, username, err := common.ParseToken(strings.TrimPrefix(token, "Bearer "))
  120. if err != nil {
  121. c.AbortWithStatusJSON(200, common.ErrToH(common.InvalidToken, c.GetHeader("locale")))
  122. }
  123. c.Set("uid", uid)
  124. c.Set("username", username)
  125. }
  126. c.Next()
  127. }
  128. func GinLogger(logger *slog.Logger) gin.HandlerFunc {
  129. return func(c *gin.Context) {
  130. start := time.Now()
  131. path := c.Request.URL.Path
  132. if len(c.Request.URL.RawQuery) > 0 {
  133. path += "?" + c.Request.URL.RawQuery
  134. }
  135. c.Next()
  136. cost := time.Since(start)
  137. logger.Info(fmt.Sprintf("[%s]%s, header[%s-%s-%s] ip[%s], resp[%d] %s errors[%s]",
  138. c.Request.Method,
  139. path,
  140. c.GetHeader("Product"),
  141. c.GetHeader("Source"),
  142. c.GetHeader("Language"),
  143. c.ClientIP(),
  144. c.Writer.Status(),
  145. cost.String(),
  146. c.Errors.ByType(gin.ErrorTypePrivate).String(),
  147. ))
  148. }
  149. }
  150. func GinRecovery(logger *slog.Logger) gin.HandlerFunc {
  151. return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
  152. logger.Error("Recovery from panic", "recoverd", recovered, "stack", string(debug.Stack()))
  153. common.HttpErr(c, common.Unknown)
  154. c.AbortWithStatus(http.StatusOK)
  155. })
  156. }