middleware.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. package router
  2. import (
  3. "fmt"
  4. "log/slog"
  5. "net/http"
  6. "runtime/debug"
  7. "strings"
  8. "time"
  9. )
  10. import (
  11. "github.com/gin-gonic/gin"
  12. )
  13. import (
  14. "auth-server/common"
  15. "auth-server/logger"
  16. )
  17. func InitMiddleware(r *gin.Engine) {
  18. // NoCache is a middleware function that appends headers
  19. r.Use(NoCache)
  20. // 跨域处理
  21. r.Use(Options)
  22. // Secure is a middleware function that appends security
  23. r.Use(Secure)
  24. // Use Slog Logger
  25. r.Use(GinLogger(logger.WithGroup("gin")))
  26. // Global Recover
  27. r.Use(GinRecovery(logger.WithGroup("ginRecovery")))
  28. // check header
  29. r.Use(CheckLanguage)
  30. r.Use(CheckSource)
  31. r.Use(CheckProduct)
  32. // check token
  33. r.Use(CheckAuth)
  34. }
  35. // NoCache is a middleware function that appends headers
  36. // to prevent the client from caching the HTTP response.
  37. func NoCache(c *gin.Context) {
  38. c.Header("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate, value")
  39. c.Header("Expires", "Thu, 01 Jan 1970 00:00:00 GMT")
  40. c.Header("Last-Modified", time.Now().UTC().Format(http.TimeFormat))
  41. c.Next()
  42. }
  43. // Options is a middleware function that appends headers
  44. // for options requests and aborts then exits the middleware
  45. // chain and ends the request.
  46. func Options(c *gin.Context) {
  47. if c.Request.Method != "OPTIONS" {
  48. c.Next()
  49. } else {
  50. c.Header("Access-Control-Allow-Origin", "*")
  51. c.Header("Access-Control-Allow-Methods", "GET,POST,PUT,PATCH,DELETE,OPTIONS")
  52. c.Header("Access-Control-Allow-Headers", "authorization, origin, content-type, accept")
  53. c.Header("Allow", "HEAD,GET,POST,PUT,PATCH,DELETE,OPTIONS")
  54. c.Header("Content-Type", "application/json")
  55. c.AbortWithStatus(200)
  56. }
  57. }
  58. // Secure is a middleware function that appends security
  59. // and resource access headers.
  60. func Secure(c *gin.Context) {
  61. c.Header("Access-Control-Allow-Origin", "*")
  62. //c.Header("X-Frame-Options", "DENY")
  63. c.Header("X-Content-Type-Options", "nosniff")
  64. c.Header("X-XSS-Protection", "1; mode=block")
  65. if c.Request.TLS != nil {
  66. c.Header("Strict-Transport-Security", "max-age=31536000")
  67. }
  68. // Also consider adding Content-Security-Policy headers
  69. // c.Header("Content-Security-Policy", "script-src 'self' https://cdnjs.cloudflare.com")
  70. }
  71. func CheckLanguage(c *gin.Context) {
  72. ok := false
  73. language := c.Request.Header.Get("Language")
  74. for _, l := range common.MetadataConfig.GetLanguages() {
  75. if language == l.ToString() {
  76. ok = true
  77. }
  78. }
  79. if !ok {
  80. c.AbortWithStatusJSON(200, common.ErrToH(common.InvalidLanguage, c.GetHeader("Language")))
  81. } else {
  82. c.Set("language", language)
  83. }
  84. c.Next()
  85. }
  86. func CheckSource(c *gin.Context) {
  87. source := c.Request.Header.Get("Source")
  88. ok := false
  89. for _, s := range common.MetadataConfig.GetSources() {
  90. if source == s.ToString() {
  91. ok = true
  92. }
  93. }
  94. if !ok {
  95. c.AbortWithStatusJSON(200, common.ErrToH(common.InvalidSource, c.GetHeader("Source")))
  96. } else {
  97. c.Set("source", source)
  98. }
  99. c.Next()
  100. }
  101. func CheckProduct(c *gin.Context) {
  102. product := c.Request.Header.Get("Product")
  103. if strings.ToUpper(product) != common.MetadataConfig.GetProduct().ToString() {
  104. c.AbortWithStatusJSON(200, common.ErrToH(common.InvalidProduct, product))
  105. } else {
  106. c.Set("product", product)
  107. }
  108. c.Next()
  109. }
  110. func CheckAuth(c *gin.Context) {
  111. if strings.HasPrefix(c.FullPath(), "/dr/api/v1/auth") {
  112. token := c.Request.Header.Get("Authorization")
  113. uid, username, err := common.ParseToken(strings.TrimPrefix(token, "Bearer "))
  114. if err != nil {
  115. c.AbortWithStatusJSON(200, common.ErrToH(common.InvalidToken, c.GetHeader("locale")))
  116. }
  117. c.Set("uid", uid)
  118. c.Set("username", username)
  119. }
  120. c.Next()
  121. }
  122. func GinLogger(logger *slog.Logger) gin.HandlerFunc {
  123. return func(c *gin.Context) {
  124. start := time.Now()
  125. path := c.Request.URL.Path
  126. if len(c.Request.URL.RawQuery) > 0 {
  127. path += "?" + c.Request.URL.RawQuery
  128. }
  129. c.Next()
  130. cost := time.Since(start)
  131. logger.Info(fmt.Sprintf("[%s]%s, header[%s-%s-%s] ip[%s], resp[%d] %s errors[%s]",
  132. c.Request.Method,
  133. path,
  134. c.GetHeader("Product"),
  135. c.GetHeader("Source"),
  136. c.GetHeader("Language"),
  137. c.ClientIP(),
  138. c.Writer.Status(),
  139. cost.String(),
  140. c.Errors.ByType(gin.ErrorTypePrivate).String(),
  141. ))
  142. }
  143. }
  144. func GinRecovery(logger *slog.Logger) gin.HandlerFunc {
  145. return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
  146. logger.Error("Recovery from panic", "recoverd", recovered, "stack", string(debug.Stack()))
  147. common.HttpErr(c, common.Unknown)
  148. c.AbortWithStatus(http.StatusOK)
  149. })
  150. }