package logger

import (
	"context"
	"errors"
	"log/slog"
	"regexp"
	"strings"
	"time"

	"gorm.io/gorm"
	gormlogger "gorm.io/gorm/logger"
)

// pbkdf2Re matches single-quoted values starting with "pbkdf2:" (presumably
// password hashes) inside SQL text so they can be masked before logging.
// The pattern is a constant, so MustCompile cannot fail at runtime.
var pbkdf2Re = regexp.MustCompile(`'pbkdf2:\S+?'`)

// desensitize returns str with every single-quoted "pbkdf2:..." value
// replaced by '******'.
func desensitize(str string) string {
	// Cheap literal pre-check: statements without the prefix skip the regex.
	if !strings.Contains(str, "'pbkdf2:") {
		return str
	}
	return pbkdf2Re.ReplaceAllLiteralString(str, "'******'")
}

// GormLogger adapts a *slog.Logger to gorm's logger.Interface.
type GormLogger struct {
	// Logger receives all output; when nil, slog.Default() is used.
	Logger *slog.Logger
	// LogLevel is the most verbose gorm level that is emitted.
	LogLevel gormlogger.LogLevel
	// SlowThreshold: statements slower than this are logged at warn level.
	// Zero disables slow-statement detection.
	SlowThreshold time.Duration
	// SkipCallerLookup is preserved across LogMode but not consulted here.
	SkipCallerLookup bool
	// IgnoreRecordNotFoundError suppresses error-level trace output for
	// gorm.ErrRecordNotFound.
	IgnoreRecordNotFoundError bool
}

// Compile-time check that GormLogger satisfies gorm's logger interface.
var _ gormlogger.Interface = GormLogger{}

// NewGormLogger returns a GormLogger writing to logger2 at Info level, with a
// 100ms slow-statement threshold and record-not-found errors ignored.
func NewGormLogger(logger2 *slog.Logger) *GormLogger {
	return &GormLogger{
		Logger:                    logger2,
		LogLevel:                  gormlogger.Info,
		SlowThreshold:             100 * time.Millisecond,
		SkipCallerLookup:          false,
		IgnoreRecordNotFoundError: true,
	}
}

// SetAsDefault installs a copy of l as gorm's package-level default logger.
func (l GormLogger) SetAsDefault() {
	gormlogger.Default = l
}

// LogMode returns a copy of l with its level set to level. All other
// settings, including the underlying slog.Logger, are preserved.
func (l GormLogger) LogMode(level gormlogger.LogLevel) gormlogger.Interface {
	derived := l // value copy; l itself is not modified
	derived.LogLevel = level
	return derived
}

// slogger returns the configured logger, falling back to slog.Default() so
// that a zero-value GormLogger (or one built with a nil logger) is usable.
func (l GormLogger) slogger() *slog.Logger {
	if l.Logger != nil {
		return l.Logger
	}
	return slog.Default()
}

// Info logs str with args attached under "data" when LogLevel >= Info.
func (l GormLogger) Info(ctx context.Context, str string, args ...interface{}) {
	if l.LogLevel < gormlogger.Info {
		return
	}
	l.slogger().InfoContext(ctx, str, slog.Any("data", args))
}

// Warn logs str with args attached under "data" when LogLevel >= Warn.
func (l GormLogger) Warn(ctx context.Context, str string, args ...interface{}) {
	if l.LogLevel < gormlogger.Warn {
		return
	}
	l.slogger().WarnContext(ctx, str, slog.Any("data", args))
}

// Error logs str with args attached under "data" when LogLevel >= Error.
func (l GormLogger) Error(ctx context.Context, str string, args ...interface{}) {
	if l.LogLevel < gormlogger.Error {
		return
	}
	l.slogger().ErrorContext(ctx, str, slog.Any("data", args))
}

// Trace logs one executed statement. Precedence:
//   - error level, if err is non-nil (unless it is ErrRecordNotFound and
//     IgnoreRecordNotFoundError is set);
//   - otherwise warn level, if the statement exceeded SlowThreshold;
//   - otherwise info level, if LogLevel >= Info.
//
// fc is called only when a record is actually emitted, and the SQL text
// is passed through desensitize first.
func (l GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
	if l.LogLevel <= gormlogger.Silent {
		return
	}
	elapsed := time.Since(begin)
	lg := l.slogger()
	switch {
	case err != nil && l.LogLevel >= gormlogger.Error && (!l.IgnoreRecordNotFoundError || !errors.Is(err, gorm.ErrRecordNotFound)):
		sql, rows := fc()
		lg.ErrorContext(ctx, "gorm trace error", "err", err, "elapsed", elapsed, "rows", rows, "sql", desensitize(sql))
	case l.SlowThreshold != 0 && elapsed > l.SlowThreshold && l.LogLevel >= gormlogger.Warn:
		sql, rows := fc()
		lg.WarnContext(ctx, "gorm trace warn", "elapsed", elapsed, "rows", rows, "sql", desensitize(sql))
	case l.LogLevel >= gormlogger.Info:
		sql, rows := fc()
		lg.InfoContext(ctx, "gorm trace info", "elapsed", elapsed, "rows", rows, "sql", desensitize(sql))
	}
}