84 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			84 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			Go
		
	
	
	
package middleware
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/Tencent/WeKnora/internal/logger"
 | 
						|
	"github.com/Tencent/WeKnora/internal/types"
 | 
						|
	"github.com/gin-gonic/gin"
 | 
						|
	"github.com/google/uuid"
 | 
						|
)
 | 
						|
 | 
						|
// RequestID middleware adds a unique request ID to the context
 | 
						|
func RequestID() gin.HandlerFunc {
 | 
						|
	return func(c *gin.Context) {
 | 
						|
		// Get request ID from header or generate a new one
 | 
						|
		requestID := c.GetHeader("X-Request-ID")
 | 
						|
		if requestID == "" {
 | 
						|
			requestID = uuid.New().String()
 | 
						|
		}
 | 
						|
		// Set request ID in header
 | 
						|
		c.Header("X-Request-ID", requestID)
 | 
						|
 | 
						|
		// Set request ID in context
 | 
						|
		c.Set(types.RequestIDContextKey.String(), requestID)
 | 
						|
 | 
						|
		// Set logger in context
 | 
						|
		requestLogger := logger.GetLogger(c)
 | 
						|
		requestLogger = requestLogger.WithField("request_id", requestID)
 | 
						|
		c.Set(types.LoggerContextKey.String(), requestLogger)
 | 
						|
 | 
						|
		// Set request ID in the global context for logging
 | 
						|
		c.Request = c.Request.WithContext(
 | 
						|
			context.WithValue(
 | 
						|
				context.WithValue(c.Request.Context(), types.RequestIDContextKey, requestID),
 | 
						|
				types.LoggerContextKey, requestLogger,
 | 
						|
			),
 | 
						|
		)
 | 
						|
 | 
						|
		c.Next()
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// Logger middleware logs request details with request ID
 | 
						|
func Logger() gin.HandlerFunc {
 | 
						|
	return func(c *gin.Context) {
 | 
						|
		start := time.Now()
 | 
						|
		path := c.Request.URL.Path
 | 
						|
		raw := c.Request.URL.RawQuery
 | 
						|
 | 
						|
		// Process request
 | 
						|
		c.Next()
 | 
						|
 | 
						|
		// Get request ID from context
 | 
						|
		requestID, exists := c.Get(types.RequestIDContextKey.String())
 | 
						|
		if !exists {
 | 
						|
			requestID = "unknown"
 | 
						|
		}
 | 
						|
 | 
						|
		// Calculate latency
 | 
						|
		latency := time.Since(start)
 | 
						|
 | 
						|
		// Get client IP and status code
 | 
						|
		clientIP := c.ClientIP()
 | 
						|
		statusCode := c.Writer.Status()
 | 
						|
		method := c.Request.Method
 | 
						|
 | 
						|
		if raw != "" {
 | 
						|
			path = path + "?" + raw
 | 
						|
		}
 | 
						|
 | 
						|
		// Log with request ID
 | 
						|
		logger.GetLogger(c).Infof("[%s] %d | %3d | %13v | %15s | %s %s",
 | 
						|
			requestID,
 | 
						|
			statusCode,
 | 
						|
			c.Writer.Size(),
 | 
						|
			latency,
 | 
						|
			clientIP,
 | 
						|
			method,
 | 
						|
			path,
 | 
						|
		)
 | 
						|
	}
 | 
						|
}
 |