l_ai_knowledge/internal/models/chat/ollama.go

package chat

import (
	"context"
	"fmt"

	ollamaapi "github.com/ollama/ollama/api"

	"knowlege-lsxd/internal/logger"
	"knowlege-lsxd/internal/models/utils/ollama"
	"knowlege-lsxd/internal/types"
)

// OllamaChat implements chat backed by an Ollama server.
type OllamaChat struct {
	modelName     string
	modelID       string
	ollamaService *ollama.OllamaService
}

// NewOllamaChat creates an Ollama chat instance.
func NewOllamaChat(config *ChatConfig, ollamaService *ollama.OllamaService) (*OllamaChat, error) {
	return &OllamaChat{
		modelName:     config.ModelName,
		modelID:       config.ModelID,
		ollamaService: ollamaService,
	}, nil
}
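
// A minimal construction sketch, assuming ChatConfig carries the ModelName
// and ModelID fields used above, and that an *ollama.OllamaService has been
// wired up elsewhere; the model tag is only an example:
//
//	var svc *ollama.OllamaService // obtained from application wiring
//	chat, err := NewOllamaChat(&ChatConfig{
//		ModelName: "qwen2.5:7b", // any tag known to the Ollama server
//		ModelID:   "model-001",
//	}, svc)
//	if err != nil {
//		// handle construction error
//	}
//	_ = chat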

// convertMessages converts messages into the Ollama API format.
func (c *OllamaChat) convertMessages(messages []Message) []ollamaapi.Message {
	ollamaMessages := make([]ollamaapi.Message, len(messages))
	for i, msg := range messages {
		ollamaMessages[i] = ollamaapi.Message{
			Role:    msg.Role,
			Content: msg.Content,
		}
	}
	return ollamaMessages
}

// buildChatRequest builds the chat request parameters.
func (c *OllamaChat) buildChatRequest(messages []Message, opts *ChatOptions, isStream bool) *ollamaapi.ChatRequest {
	// Set the streaming flag.
	streamFlag := isStream

	// Build the request parameters.
	chatReq := &ollamaapi.ChatRequest{
		Model:    c.modelName,
		Messages: c.convertMessages(messages),
		Stream:   &streamFlag,
		Options:  make(map[string]interface{}),
	}

	// Apply optional parameters.
	if opts != nil {
		if opts.Temperature > 0 {
			chatReq.Options["temperature"] = opts.Temperature
		}
		if opts.TopP > 0 {
			chatReq.Options["top_p"] = opts.TopP
		}
		if opts.MaxTokens > 0 {
			// Ollama names the completion-length limit "num_predict".
			chatReq.Options["num_predict"] = opts.MaxTokens
		}
		if opts.Thinking != nil {
			chatReq.Think = &ollamaapi.ThinkValue{
				Value: *opts.Thinking,
			}
		}
	}
	return chatReq
}
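
// For illustration, a ChatOptions value like the following (field names taken
// from the usage above; the struct itself is defined elsewhere) yields
// Options{"temperature": 0.7, "top_p": 0.9, "num_predict": 1024} plus a
// populated Think field:
//
//	thinking := true
//	opts := &ChatOptions{
//		Temperature: 0.7,
//		TopP:        0.9,
//		MaxTokens:   1024,
//		Thinking:    &thinking,
//	}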

// Chat performs a non-streaming chat request.
func (c *OllamaChat) Chat(ctx context.Context, messages []Message, opts *ChatOptions) (*types.ChatResponse, error) {
	// Make sure the model is available.
	if err := c.ensureModelAvailable(ctx); err != nil {
		return nil, err
	}

	// Build the request parameters.
	chatReq := c.buildChatRequest(messages, opts, false)

	// Log the request.
	logger.GetLogger(ctx).Infof("sending chat request to model %s", c.modelName)

	var responseContent string
	var promptTokens, completionTokens int

	// Send the request through the Ollama client.
	err := c.ollamaService.Chat(ctx, chatReq, func(resp ollamaapi.ChatResponse) error {
		responseContent = resp.Message.Content
		// Record token counts. In the Ollama API, PromptEvalCount is the
		// number of prompt tokens and EvalCount is the number of generated
		// tokens, so EvalCount is used directly rather than subtracting.
		if resp.EvalCount > 0 {
			promptTokens = resp.PromptEvalCount
			completionTokens = resp.EvalCount
		}
		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("chat request failed: %w", err)
	}

	// Build the response.
	return &types.ChatResponse{
		Content: responseContent,
		Usage: struct {
			PromptTokens     int `json:"prompt_tokens"`
			CompletionTokens int `json:"completion_tokens"`
			TotalTokens      int `json:"total_tokens"`
		}{
			PromptTokens:     promptTokens,
			CompletionTokens: completionTokens,
			TotalTokens:      promptTokens + completionTokens,
		},
	}, nil
}
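
// A hedged usage sketch for the non-streaming path, assuming Message has the
// Role/Content fields consumed by convertMessages above:
//
//	resp, err := chat.Chat(ctx, []Message{
//		{Role: "user", Content: "Summarize this document."},
//	}, nil)
//	if err != nil {
//		// handle request error
//	}
//	fmt.Println(resp.Content, resp.Usage.TotalTokens)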

// ChatStream performs a streaming chat request.
func (c *OllamaChat) ChatStream(
	ctx context.Context,
	messages []Message,
	opts *ChatOptions,
) (<-chan types.StreamResponse, error) {
	// Make sure the model is available.
	if err := c.ensureModelAvailable(ctx); err != nil {
		return nil, err
	}

	// Build the request parameters.
	chatReq := c.buildChatRequest(messages, opts, true)

	// Log the request.
	logger.GetLogger(ctx).Infof("sending streaming chat request to model %s", c.modelName)

	// Create the streaming response channel.
	streamChan := make(chan types.StreamResponse)

	// Handle the streamed responses in a goroutine.
	go func() {
		defer close(streamChan)

		// send forwards a response on the channel, giving up if the context
		// is cancelled so the goroutine cannot leak on a blocked send when
		// the consumer stops reading.
		send := func(resp types.StreamResponse) error {
			select {
			case streamChan <- resp:
				return nil
			case <-ctx.Done():
				return ctx.Err()
			}
		}

		err := c.ollamaService.Chat(ctx, chatReq, func(resp ollamaapi.ChatResponse) error {
			if resp.Message.Content != "" {
				if err := send(types.StreamResponse{
					ResponseType: types.ResponseTypeAnswer,
					Content:      resp.Message.Content,
					Done:         false,
				}); err != nil {
					return err
				}
			}
			if resp.Done {
				return send(types.StreamResponse{
					ResponseType: types.ResponseTypeAnswer,
					Done:         true,
				})
			}
			return nil
		})
		if err != nil {
			logger.GetLogger(ctx).Errorf("streaming chat request failed: %v", err)
			// Signal termination to the consumer.
			_ = send(types.StreamResponse{
				ResponseType: types.ResponseTypeAnswer,
				Done:         true,
			})
		}
	}()

	return streamChan, nil
}
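
// A consumption sketch for the streaming path: the goroutine above closes the
// channel when the request finishes or fails, so ranging over it terminates:
//
//	ch, err := chat.ChatStream(ctx, messages, nil)
//	if err != nil {
//		// handle request error
//	}
//	for chunk := range ch {
//		if chunk.Done {
//			break
//		}
//		fmt.Print(chunk.Content)
//	}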

// ensureModelAvailable verifies that the model is available before a request
// is sent.
func (c *OllamaChat) ensureModelAvailable(ctx context.Context) error {
	logger.GetLogger(ctx).Infof("ensuring model %s is available", c.modelName)
	return c.ollamaService.EnsureModelAvailable(ctx, c.modelName)
}

// GetModelName returns the model name.
func (c *OllamaChat) GetModelName() string {
	return c.modelName
}

// GetModelID returns the model ID.
func (c *OllamaChat) GetModelID() string {
	return c.modelID
}