package chat

import (
	"context"
	"fmt"

	ollamaapi "github.com/ollama/ollama/api"

	"knowlege-lsxd/internal/logger"
	"knowlege-lsxd/internal/models/utils/ollama"
	"knowlege-lsxd/internal/types"
)

// OllamaChat implements chat backed by an Ollama model.
type OllamaChat struct {
	modelName     string
	modelID       string
	ollamaService *ollama.OllamaService
}

// NewOllamaChat creates an Ollama-backed chat instance.
func NewOllamaChat(config *ChatConfig, ollamaService *ollama.OllamaService) (*OllamaChat, error) {
	return &OllamaChat{
		modelName:     config.ModelName,
		modelID:       config.ModelID,
		ollamaService: ollamaService,
	}, nil
}
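
// A minimal construction sketch. The ChatConfig values and the way the
// OllamaService instance is obtained are assumptions for illustration,
// not taken from this file:
//
//	svc := &ollama.OllamaService{} // however the service is wired up elsewhere
//	chat, err := NewOllamaChat(&ChatConfig{
//		ModelName: "qwen2.5:7b", // hypothetical model name
//		ModelID:   "model-001",  // hypothetical model ID
//	}, svc)
//	if err != nil {
//		// handle construction error
//	}
//	_ = chat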

// convertMessages converts messages into the Ollama API message format.
func (c *OllamaChat) convertMessages(messages []Message) []ollamaapi.Message {
	ollamaMessages := make([]ollamaapi.Message, len(messages))
	for i, msg := range messages {
		ollamaMessages[i] = ollamaapi.Message{
			Role:    msg.Role,
			Content: msg.Content,
		}
	}
	return ollamaMessages
}

// buildChatRequest builds the chat request parameters.
func (c *OllamaChat) buildChatRequest(messages []Message, opts *ChatOptions, isStream bool) *ollamaapi.ChatRequest {
	// Set the streaming flag.
	streamFlag := isStream

	// Build the request parameters.
	chatReq := &ollamaapi.ChatRequest{
		Model:    c.modelName,
		Messages: c.convertMessages(messages),
		Stream:   &streamFlag,
		Options:  make(map[string]interface{}),
	}

	// Add optional parameters.
	if opts != nil {
		if opts.Temperature > 0 {
			chatReq.Options["temperature"] = opts.Temperature
		}
		if opts.TopP > 0 {
			chatReq.Options["top_p"] = opts.TopP
		}
		if opts.MaxTokens > 0 {
			chatReq.Options["num_predict"] = opts.MaxTokens
		}
		if opts.Thinking != nil {
			chatReq.Think = &ollamaapi.ThinkValue{
				Value: *opts.Thinking,
			}
		}
	}

	return chatReq
}
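
// Note on the zero-value checks above: options are only forwarded when they
// are greater than zero, so an explicit temperature, top_p, or max-token
// value of 0 is not sent and the Ollama defaults apply instead.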

// Chat performs a non-streaming chat request.
func (c *OllamaChat) Chat(ctx context.Context, messages []Message, opts *ChatOptions) (*types.ChatResponse, error) {
	// Make sure the model is available.
	if err := c.ensureModelAvailable(ctx); err != nil {
		return nil, err
	}

	// Build the request parameters.
	chatReq := c.buildChatRequest(messages, opts, false)

	// Log the request.
	logger.GetLogger(ctx).Infof("sending chat request to model %s", c.modelName)

	var responseContent string
	var promptTokens, completionTokens int

	// Send the request via the Ollama client.
	err := c.ollamaService.Chat(ctx, chatReq, func(resp ollamaapi.ChatResponse) error {
		responseContent = resp.Message.Content

		// Collect token counts: PromptEvalCount is the number of prompt
		// tokens and EvalCount is the number of generated tokens.
		if resp.EvalCount > 0 {
			promptTokens = resp.PromptEvalCount
			completionTokens = resp.EvalCount
		}

		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("chat request failed: %w", err)
	}

	// Build the response.
	return &types.ChatResponse{
		Content: responseContent,
		Usage: struct {
			PromptTokens     int `json:"prompt_tokens"`
			CompletionTokens int `json:"completion_tokens"`
			TotalTokens      int `json:"total_tokens"`
		}{
			PromptTokens:     promptTokens,
			CompletionTokens: completionTokens,
			TotalTokens:      promptTokens + completionTokens,
		},
	}, nil
}
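
// Sketch of a non-streaming call. The message role string, option values,
// and error handling are illustrative assumptions:
//
//	resp, err := chat.Chat(ctx, []Message{
//		{Role: "user", Content: "Hello"},
//	}, &ChatOptions{Temperature: 0.7})
//	if err != nil {
//		// handle request error
//	}
//	fmt.Println(resp.Content, resp.Usage.TotalTokens)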

// ChatStream performs a streaming chat request.
func (c *OllamaChat) ChatStream(
	ctx context.Context,
	messages []Message,
	opts *ChatOptions,
) (<-chan types.StreamResponse, error) {
	// Make sure the model is available.
	if err := c.ensureModelAvailable(ctx); err != nil {
		return nil, err
	}

	// Build the request parameters.
	chatReq := c.buildChatRequest(messages, opts, true)

	// Log the request.
	logger.GetLogger(ctx).Infof("sending streaming chat request to model %s", c.modelName)

	// Create the streaming response channel.
	streamChan := make(chan types.StreamResponse)

	// Handle the streamed responses in a separate goroutine.
	go func() {
		defer close(streamChan)

		err := c.ollamaService.Chat(ctx, chatReq, func(resp ollamaapi.ChatResponse) error {
			if resp.Message.Content != "" {
				streamChan <- types.StreamResponse{
					ResponseType: types.ResponseTypeAnswer,
					Content:      resp.Message.Content,
					Done:         false,
				}
			}

			if resp.Done {
				streamChan <- types.StreamResponse{
					ResponseType: types.ResponseTypeAnswer,
					Done:         true,
				}
			}

			return nil
		})
		if err != nil {
			logger.GetLogger(ctx).Errorf("streaming chat request failed: %v", err)
			// Signal the end of the stream so the consumer does not block.
			streamChan <- types.StreamResponse{
				ResponseType: types.ResponseTypeAnswer,
				Done:         true,
			}
		}
	}()

	return streamChan, nil
}
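
// Sketch of consuming the stream: the channel is closed by the goroutine
// above and a Done response marks the end of the answer, so a caller can
// simply range over it (the printing is illustrative only):
//
//	stream, err := chat.ChatStream(ctx, msgs, nil)
//	if err != nil {
//		// handle request error
//	}
//	for chunk := range stream {
//		if chunk.Done {
//			break
//		}
//		fmt.Print(chunk.Content)
//	}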

// ensureModelAvailable makes sure the model is available before use.
func (c *OllamaChat) ensureModelAvailable(ctx context.Context) error {
	logger.GetLogger(ctx).Infof("ensuring model %s is available", c.modelName)
	return c.ollamaService.EnsureModelAvailable(ctx, c.modelName)
}

// GetModelName returns the model name.
func (c *OllamaChat) GetModelName() string {
	return c.modelName
}

// GetModelID returns the model ID.
func (c *OllamaChat) GetModelID() string {
	return c.modelID
}