l_ai_knowledge/internal/middleware/trace.go

125 lines
3.5 KiB
Go

package middleware
import (
"bytes"
"fmt"
"io"
"strings"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"knowlege-lsxd/internal/tracing"
"knowlege-lsxd/internal/types"
)
// responseBodyWriter wraps a gin.ResponseWriter and keeps a copy of the bytes
// written through its Write method, so the response body can be read back
// after the handler chain has run.
type responseBodyWriter struct {
// Embedded writer that actually sends the response; every method that
// responseBodyWriter does not define itself is promoted from it.
gin.ResponseWriter
// body accumulates a copy of the data passed to Write.
body *bytes.Buffer
}
// Write copies b into the capture buffer and then forwards it to the wrapped
// gin.ResponseWriter, returning the wrapped writer's result.
//
// The buffer receives the whole slice even when the underlying write is short
// or fails; bytes.Buffer.Write never returns a non-nil error, so its result is
// intentionally not checked.
func (w responseBodyWriter) Write(b []byte) (int, error) {
	w.body.Write(b)
	return w.ResponseWriter.Write(b)
}

// WriteString is the string counterpart of Write.
//
// gin.ResponseWriter exposes WriteString in addition to Write, and
// io.WriteString prefers it when available. Without this override such writes
// would be served by the embedded writer directly and never reach the capture
// buffer.
func (w responseBodyWriter) WriteString(s string) (int, error) {
	w.body.WriteString(s)
	return w.ResponseWriter.WriteString(s)
}
// TracingMiddleware returns a Gin middleware that wraps each request in a
// trace span named "<METHOD> <route>", where <route> is c.FullPath().
//
// If tracing.GetTracer() returns nil the request is passed through untouched.
// Otherwise the span records:
//   - http.method, http.url, http.path
//   - http.request.id (from the Gin context key types.RequestIDContextKey,
//     falling back to the X-Request-ID header), when non-empty
//   - request headers, the raw query string and, for POST/PUT/PATCH, the
//     request body
//   - http.status_code, response headers and the captured response body
//
// Credential-bearing headers (Authorization, Proxy-Authorization, Cookie,
// Set-Cookie) are never recorded. Responses with status >= 400 mark the span
// as an error and attach the last Gin error, if any.
//
// The span and its context are exposed to later handlers through the request
// context and the Gin keys "trace.span" and "trace.ctx".
//
// NOTE(review): request and response bodies are buffered and recorded in
// full, with no size limit; consider capping them if large payloads are
// expected.
func TracingMiddleware() gin.HandlerFunc {
	// sensitive reports whether a header must be kept out of span attributes.
	// Matching is case-insensitive so non-canonical keys are covered as well.
	sensitive := func(name string) bool {
		switch strings.ToLower(name) {
		case "authorization", "proxy-authorization", "cookie", "set-cookie":
			return true
		}
		return false
	}

	return func(c *gin.Context) {
		// Nothing to record when no tracer is configured.
		if tracing.GetTracer() == nil {
			c.Next()
			return
		}

		req := c.Request
		route := c.FullPath()

		ctx, span := tracing.ContextWithSpan(req.Context(), fmt.Sprintf("%s %s", req.Method, route))
		defer span.End()

		attrs := []attribute.KeyValue{
			attribute.String("http.method", req.Method),
			attribute.String("http.url", req.URL.String()),
			attribute.String("http.path", route),
		}

		// Correlate the span with the request ID, preferring the value stored
		// in the Gin context over the raw header.
		requestID := c.GetString(string(types.RequestIDContextKey))
		if requestID == "" {
			requestID = c.GetHeader("X-Request-ID")
		}
		if requestID != "" {
			attrs = append(attrs, attribute.String("http.request.id", requestID))
		}

		for name, values := range req.Header {
			if sensitive(name) {
				continue
			}
			attrs = append(attrs, attribute.String("http.request.header."+name, strings.Join(values, ";")))
		}

		// Capture the request body for methods that usually carry one.
		switch req.Method {
		case "POST", "PUT", "PATCH":
			if req.Body != nil {
				original := req.Body
				payload, err := io.ReadAll(original)
				attrs = append(attrs, attribute.String("http.request.body", string(payload)))

				// ReadAll consumed the body, so hand the handlers a replay of
				// what was read.
				replay := io.Reader(bytes.NewReader(payload))
				if err != nil {
					// Keep the failure visible: record it, and chain the
					// original reader after the partial data so downstream
					// reads can hit the error instead of a clean EOF.
					span.RecordError(err)
					replay = io.MultiReader(replay, original)
				}
				req.Body = io.NopCloser(replay)
			}
		}

		if req.URL.RawQuery != "" {
			attrs = append(attrs, attribute.String("http.request.query", req.URL.RawQuery))
		}
		span.SetAttributes(attrs...)

		// Propagate the span to downstream handlers.
		c.Request = req.WithContext(ctx)
		c.Set("trace.span", span)
		c.Set("trace.ctx", ctx)

		// Tee everything the handlers write into a buffer.
		captured := &bytes.Buffer{}
		c.Writer = &responseBodyWriter{
			ResponseWriter: c.Writer,
			body:           captured,
		}

		c.Next()

		status := c.Writer.Status()
		respAttrs := []attribute.KeyValue{attribute.Int("http.status_code", status)}
		if captured.Len() > 0 {
			respAttrs = append(respAttrs, attribute.String("http.response.body", captured.String()))
		}
		for name, values := range c.Writer.Header() {
			if sensitive(name) {
				continue
			}
			respAttrs = append(respAttrs, attribute.String("http.response.header."+name, strings.Join(values, ";")))
		}
		span.SetAttributes(respAttrs...)

		if status < 400 {
			span.SetStatus(codes.Ok, "")
			return
		}

		// Any 4xx/5xx response marks the span as failed.
		span.SetStatus(codes.Error, fmt.Sprintf("HTTP %d", status))
		if last := c.Errors.Last(); last != nil {
			span.RecordError(last.Err)
		}
	}
}