ai_scheduler/internal/pkg/utils_ollama/client.go.bak

125 lines
2.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package utils_ollama
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/entitys"
"context"
"encoding/json"
"fmt"
"time"
"github.com/ollama/ollama/api"
"github.com/tmc/langchaingo/llms/ollama"
)
// Client adapts the Ollama API client so it can be returned as the project's
// entitys.AIClient (see NewClient).
type Client struct {
// client is the underlying Ollama API client that Chat sends requests through.
client *api.Client
// config points at the Ollama section of the application configuration;
// Chat reads the model name and request timeout from it.
config *config.OllamaConfig
}
// NewClient builds an Ollama-backed entitys.AIClient.
//
// The underlying API client is created with api.ClientFromEnvironment, so the
// connection settings come from the process environment; config only supplies
// the Ollama section (model, timeout) that Chat reads later.
//
// It returns the client, a cleanup function, and an error. The cleanup
// function is never nil, including on the error paths, so callers may defer
// it unconditionally. An error is returned when config is nil or when the
// Ollama API client cannot be created.
func NewClient(config *config.Config) (entitys.AIClient, func(), error) {
	// cleanup is intentionally a no-op: nothing created here is explicitly
	// released, and clearing a closure-local pointer would not affect the
	// returned Client anyway.
	cleanup := func() {}

	// Guard against a nil configuration instead of panicking on &config.Ollama.
	if config == nil {
		return nil, cleanup, fmt.Errorf("failed to create ollama client: config is nil")
	}

	client, err := api.ClientFromEnvironment()
	if err != nil {
		return nil, cleanup, fmt.Errorf("failed to create ollama client: %w", err)
	}

	return &Client{
		client: client,
		config: &config.Ollama,
	}, cleanup, nil
}
// Chat sends messages (and optional tool definitions) to the configured
// Ollama model as a single non-streaming request and returns the converted
// response.
//
// The call is bounded by ctx and by c.config.Timeout. A non-positive Timeout
// applies no deadline beyond ctx. Whenever Chat returns, the in-flight request
// is cancelled, so no goroutine or connection is left waiting.
//
// Errors:
//   - ctx.Err() when the caller's context is cancelled or expires;
//   - "chat request timeout" when c.config.Timeout elapses first;
//   - "chat request failed: ..." wrapping any error from the Ollama client,
//     or when the client finishes without delivering a response;
//   - an encoding/decoding error when a tool definition cannot be converted.
func (c *Client) Chat(ctx context.Context, messages []entitys.Message, tools []entitys.ToolDefinition) (*entitys.ChatResponse, error) {
	// Build the chat request.
	req := &api.ChatRequest{
		Model:    c.config.Model,
		Messages: make([]api.Message, len(messages)),
		Stream:   new(bool), // pointer to false: disable streaming
		Think:    &api.ThinkValue{Value: true},
	}

	// Convert messages to the Ollama wire format.
	for i, msg := range messages {
		req.Messages[i] = api.Message{
			Role:    msg.Role,
			Content: msg.Content,
		}
	}

	// Convert tool definitions through a JSON round trip, surfacing any
	// conversion failure instead of silently sending an empty tool.
	if len(tools) > 0 {
		req.Tools = make([]api.Tool, len(tools))
		for i, tool := range tools {
			raw, err := json.Marshal(tool)
			if err != nil {
				return nil, fmt.Errorf("encoding tool definition %d: %w", i, err)
			}
			if err := json.Unmarshal(raw, &req.Tools[i]); err != nil {
				return nil, fmt.Errorf("decoding tool definition %d: %w", i, err)
			}
		}
	}

	// The request runs under a derived context that is cancelled as soon as
	// Chat returns, so a timeout aborts the request rather than abandoning it.
	reqCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	// A nil channel never fires, which disables the timeout case below when
	// no positive timeout is configured.
	var timeoutC <-chan time.Time
	if c.config.Timeout > 0 {
		timer := time.NewTimer(c.config.Timeout)
		defer timer.Stop()
		timeoutC = timer.C
	}

	type chatOutcome struct {
		resp     api.ChatResponse
		received bool
		err      error
	}

	// Buffered so the worker can always deliver its outcome and exit, even
	// when Chat has already returned because of cancellation or timeout.
	done := make(chan chatOutcome, 1)
	go func() {
		var out chatOutcome
		out.err = c.client.Chat(reqCtx, req, func(resp api.ChatResponse) error {
			out.resp = resp
			out.received = true
			return nil
		})
		done <- out
	}()

	// Wait for the response, cancellation, or timeout.
	select {
	case out := <-done:
		if out.err != nil {
			// Prefer the caller's own cancellation error when that is what
			// made the request fail.
			if ctxErr := ctx.Err(); ctxErr != nil {
				return nil, ctxErr
			}
			return nil, fmt.Errorf("chat request failed: %w", out.err)
		}
		if !out.received {
			return nil, fmt.Errorf("chat request failed: no response received")
		}
		return c.convertResponse(&out.resp), nil
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-timeoutC:
		return nil, fmt.Errorf("chat request timeout")
	}
}
// convertResponse is meant to translate an Ollama chat response into the
// project's entitys.ChatResponse.
//
// NOTE(review): the conversion logic below is commented out, so this method
// currently ignores resp and always returns nil. Chat passes this result
// straight through, so a successful chat yields a nil response with a nil
// error. Confirm whether the conversion should be restored against the
// current entitys types.
func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
//result := &entitys.ChatResponse{
// Message: resp.Message.Content,
// Finished: resp.Done,
//}
//
//// Convert tool calls
//if len(resp.Message.ToolCalls) > 0 {
// result.ToolCalls = make([]entitys.ToolCall, len(resp.Message.ToolCalls))
// for i, toolCall := range resp.Message.ToolCalls {
// // Convert the function arguments
// argBytes, _ := json.Marshal(toolCall.Function.Arguments)
//
// result.ToolCalls[i] = entitys.ToolCall{
// ID: fmt.Sprintf("call_%d", i),
// Type: "function",
// Function: entitys.FunctionCall{
// Name: toolCall.Function.Name,
// Arguments: json.RawMessage(argBytes),
// },
// }
// }
//}
//return result
return nil
}