l_ai_knowledge/internal/models/chat/remote_api.go

package chat

import (
	"context"
	"errors"
	"fmt"

	"github.com/sashabaranov/go-openai"

	"knowlege-lsxd/internal/types"
)

// RemoteAPIChat implements chat backed by a remote, OpenAI-compatible API.
type RemoteAPIChat struct {
	modelName string         // model name sent in each request
	client    *openai.Client // OpenAI-compatible API client
	modelID   string         // internal model identifier
}

// NewRemoteAPIChat creates a chat instance that calls a remote API.
func NewRemoteAPIChat(chatConfig *ChatConfig) (*RemoteAPIChat, error) {
	apiKey := chatConfig.APIKey
	config := openai.DefaultConfig(apiKey)
	if baseURL := chatConfig.BaseURL; baseURL != "" {
		config.BaseURL = baseURL
	}
	return &RemoteAPIChat{
		modelName: chatConfig.ModelName,
		client:    openai.NewClientWithConfig(config),
		modelID:   chatConfig.ModelID,
	}, nil
}
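
// A minimal usage sketch (illustrative values; the field names match the
// ChatConfig reads above, everything else is hypothetical):
//
//	chat, err := NewRemoteAPIChat(&ChatConfig{
//		APIKey:    "sk-...",
//		BaseURL:   "https://llm.internal.example/v1", // optional; empty keeps the OpenAI default
//		ModelName: "qwen2.5-7b-instruct",
//		ModelID:   "model-001",
//	})
//	if err != nil {
//		// handle construction error
//	}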

// convertMessages converts messages into the OpenAI wire format.
func (c *RemoteAPIChat) convertMessages(messages []Message) []openai.ChatCompletionMessage {
	openaiMessages := make([]openai.ChatCompletionMessage, len(messages))
	for i, msg := range messages {
		openaiMessages[i] = openai.ChatCompletionMessage{
			Role:    msg.Role,
			Content: msg.Content,
		}
	}
	return openaiMessages
}
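
// Note that roles are forwarded verbatim, so callers are expected to supply
// OpenAI-compatible role strings. A sketch of the assumed input:
//
//	msgs := []Message{
//		{Role: openai.ChatMessageRoleSystem, Content: "You are a helpful assistant."},
//		{Role: openai.ChatMessageRoleUser, Content: "Hello"},
//	}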

// buildChatCompletionRequest builds the chat completion request parameters.
func (c *RemoteAPIChat) buildChatCompletionRequest(messages []Message,
	opts *ChatOptions, isStream bool,
) openai.ChatCompletionRequest {
	req := openai.ChatCompletionRequest{
		Model:    c.modelName,
		Messages: c.convertMessages(messages),
		Stream:   isStream,
	}
	// Apply optional parameters; zero values are treated as unset.
	if opts != nil {
		if opts.Temperature > 0 {
			req.Temperature = float32(opts.Temperature)
		}
		if opts.TopP > 0 {
			req.TopP = float32(opts.TopP)
		}
		if opts.MaxTokens > 0 {
			req.MaxTokens = opts.MaxTokens
		}
		if opts.MaxCompletionTokens > 0 {
			req.MaxCompletionTokens = opts.MaxCompletionTokens
		}
		if opts.FrequencyPenalty > 0 {
			req.FrequencyPenalty = float32(opts.FrequencyPenalty)
		}
		if opts.PresencePenalty > 0 {
			req.PresencePenalty = float32(opts.PresencePenalty)
		}
		if opts.Thinking != nil {
			req.ChatTemplateKwargs = map[string]any{
				"enable_thinking": *opts.Thinking,
			}
		}
	}
	return req
}
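
// For example (a sketch with illustrative values), enabling thinking and
// capping output length would yield a request roughly like:
//
//	enable := true
//	req := c.buildChatCompletionRequest(msgs, &ChatOptions{
//		Temperature: 0.7,
//		MaxTokens:   512,
//		Thinking:    &enable,
//	}, false)
//	// req.Temperature == 0.7, req.MaxTokens == 512,
//	// req.ChatTemplateKwargs["enable_thinking"] == true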

// Chat performs a non-streaming chat completion.
func (c *RemoteAPIChat) Chat(ctx context.Context, messages []Message, opts *ChatOptions) (*types.ChatResponse, error) {
	// Build the request parameters.
	req := c.buildChatCompletionRequest(messages, opts, false)
	// Send the request.
	resp, err := c.client.CreateChatCompletion(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("create chat completion: %w", err)
	}
	if len(resp.Choices) == 0 {
		return nil, errors.New("no response from OpenAI")
	}
	// Convert the response into the internal format.
	return &types.ChatResponse{
		Content: resp.Choices[0].Message.Content,
		Usage: struct {
			PromptTokens     int `json:"prompt_tokens"`
			CompletionTokens int `json:"completion_tokens"`
			TotalTokens      int `json:"total_tokens"`
		}{
			PromptTokens:     resp.Usage.PromptTokens,
			CompletionTokens: resp.Usage.CompletionTokens,
			TotalTokens:      resp.Usage.TotalTokens,
		},
	}, nil
}
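
// Usage sketch (hypothetical prompt; assumes a client built with
// NewRemoteAPIChat above):
//
//	resp, err := chat.Chat(ctx, []Message{
//		{Role: openai.ChatMessageRoleUser, Content: "Summarize this document."},
//	}, nil)
//	if err != nil {
//		// handle request error
//	}
//	fmt.Println(resp.Content, resp.Usage.TotalTokens)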

// ChatStream performs a streaming chat completion. The returned channel is
// closed after the final Done chunk.
func (c *RemoteAPIChat) ChatStream(ctx context.Context,
	messages []Message, opts *ChatOptions,
) (<-chan types.StreamResponse, error) {
	// Build the request parameters.
	req := c.buildChatCompletionRequest(messages, opts, true)
	// Channel that delivers streamed response chunks.
	streamChan := make(chan types.StreamResponse)
	// Start the streaming request.
	stream, err := c.client.CreateChatCompletionStream(ctx, req)
	if err != nil {
		close(streamChan)
		return nil, fmt.Errorf("create chat completion stream: %w", err)
	}
	// Consume the stream in a background goroutine.
	go func() {
		defer close(streamChan)
		defer stream.Close()
		for {
			response, err := stream.Recv()
			if err != nil {
				// io.EOF marks the normal end of the stream; any other error
				// also ends it. types.StreamResponse carries no error field,
				// so both cases are signalled by a final Done chunk.
				select {
				case streamChan <- types.StreamResponse{
					ResponseType: types.ResponseTypeAnswer,
					Done:         true,
				}:
				case <-ctx.Done():
				}
				return
			}
			if len(response.Choices) > 0 {
				// Guard the send with ctx.Done() so a consumer that stops
				// reading cannot leak this goroutine.
				select {
				case streamChan <- types.StreamResponse{
					ResponseType: types.ResponseTypeAnswer,
					Content:      response.Choices[0].Delta.Content,
					Done:         false,
				}:
				case <-ctx.Done():
					return
				}
			}
		}
	}()
	return streamChan, nil
}
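
// Consumption sketch: the channel is closed after the final Done chunk, so a
// plain range loop drains it (values are illustrative):
//
//	ch, err := chat.ChatStream(ctx, msgs, nil)
//	if err != nil {
//		// handle request error
//	}
//	for chunk := range ch {
//		if chunk.Done {
//			break
//		}
//		fmt.Print(chunk.Content)
//	}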

// GetModelName returns the model name.
func (c *RemoteAPIChat) GetModelName() string {
	return c.modelName
}

// GetModelID returns the model ID.
func (c *RemoteAPIChat) GetModelID() string {
	return c.modelID
}