123 lines
4.0 KiB
Go
123 lines
4.0 KiB
Go
package rerank
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
|
|
"knowlege-lsxd/internal/logger"
|
|
)
|
|
|
|
// OpenAIReranker implements a reranking system based on OpenAI models
|
|
type OpenAIReranker struct {
|
|
modelName string // Name of the model used for reranking
|
|
modelID string // Unique identifier of the model
|
|
apiKey string // API key for authentication
|
|
baseURL string // Base URL for API requests
|
|
client *http.Client // HTTP client for making API requests
|
|
}
|
|
|
|
// RerankRequest is the JSON body posted to the rerank endpoint.
type RerankRequest struct {
	// Model selects the model that performs the reranking.
	Model string `json:"model"`
	// Query is the text every document is compared against.
	Query string `json:"query"`
	// Documents holds the candidate texts to be reranked.
	Documents []string `json:"documents"`
	// AdditionalData carries optional extra data for the model. It has no
	// omitempty tag, so a nil map is always serialized as null.
	AdditionalData map[string]interface{} `json:"additional_data"`
	// TruncatePromptTokens is the maximum number of prompt tokens to use.
	TruncatePromptTokens int `json:"truncate_prompt_tokens"`
}
|
|
|
|
// RerankResponse represents the response from a reranking request
|
|
type RerankResponse struct {
|
|
ID string `json:"id"` // Request ID
|
|
Model string `json:"model"` // Model used for reranking
|
|
Usage UsageInfo `json:"usage"` // Token usage information
|
|
Results []RankResult `json:"results"` // Ranked results with relevance scores
|
|
}
|
|
|
|
// UsageInfo reports how many tokens an API request consumed.
type UsageInfo struct {
	// TotalTokens is the total number of tokens consumed.
	TotalTokens int `json:"total_tokens"`
}
|
|
|
|
// NewOpenAIReranker creates an OpenAIReranker from config.
//
// config.BaseURL defaults to "https://api.openai.com/v1" when empty. Trailing
// slashes are stripped so that appending "/rerank" yields a single separator.
// The returned reranker's http.Client has no client-level timeout; callers
// should bound each call through the ctx passed to Rerank.
//
// It returns an error if config is nil.
func NewOpenAIReranker(config *RerankerConfig) (*OpenAIReranker, error) {
	if config == nil {
		return nil, fmt.Errorf("new openai reranker: nil config")
	}

	// Strip trailing slashes: "https://host/v1/" + "/rerank" would otherwise
	// become "https://host/v1//rerank", which servers may not route to /rerank.
	baseURL := config.BaseURL
	for len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
		baseURL = baseURL[:len(baseURL)-1]
	}
	if baseURL == "" {
		baseURL = "https://api.openai.com/v1"
	}

	return &OpenAIReranker{
		modelName: config.ModelName,
		modelID:   config.ModelID,
		apiKey:    config.APIKey,
		baseURL:   baseURL,
		client:    &http.Client{},
	}, nil
}
|
|
|
|
// Rerank scores documents against query. It POSTs a RerankRequest to
// "<baseURL>/rerank" and returns the "results" array from the response.
//
// Every request sets truncate_prompt_tokens to 511. The HTTP request is bound
// to ctx, so cancelling ctx or letting its deadline pass aborts the call. A
// non-200 status is returned as an error containing the HTTP status and up to
// the first 512 bytes of the response body.
func (r *OpenAIReranker) Rerank(ctx context.Context, query string, documents []string) ([]RankResult, error) {
	// maxErrorBody limits how much of a failed response is copied into the
	// returned error, so a large error page cannot bloat logs.
	const maxErrorBody = 512

	payload, err := json.Marshal(&RerankRequest{
		Model:                r.modelName,
		Query:                query,
		Documents:            documents,
		TruncatePromptTokens: 511,
	})
	if err != nil {
		return nil, fmt.Errorf("marshal request body: %w", err)
	}

	endpoint := r.baseURL + "/rerank"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+r.apiKey)

	// Log a curl equivalent for debugging. The bearer token is redacted so
	// the API key never reaches the logs.
	logger.GetLogger(ctx).Infof(
		"curl -X POST %s -H \"Content-Type: application/json\" -H \"Authorization: Bearer <redacted>\" -d '%s'",
		endpoint, string(payload),
	)

	resp, err := r.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("do request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read response body: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		// Include the server's explanation; the body has already been read.
		snippet := body
		if len(snippet) > maxErrorBody {
			snippet = snippet[:maxErrorBody]
		}
		return nil, fmt.Errorf("rerank API error: http status %s: %s", resp.Status, snippet)
	}

	var response RerankResponse
	if err := json.Unmarshal(body, &response); err != nil {
		return nil, fmt.Errorf("unmarshal response: %w", err)
	}
	return response.Results, nil
}
|
|
|
|
// GetModelName returns the name of the reranking model
|
|
func (r *OpenAIReranker) GetModelName() string {
|
|
return r.modelName
|
|
}
|
|
|
|
// GetModelID returns the unique identifier of the reranking model
|
|
func (r *OpenAIReranker) GetModelID() string {
|
|
return r.modelID
|
|
}
|