// Package utils_ollama wraps the Ollama Go API client for use by the scheduler.
package utils_ollama

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"os"

	"ai_scheduler/internal/config"
	"ai_scheduler/internal/entitys"

	"github.com/ollama/ollama/api"
)

// Client is an adapter around the Ollama API client, paired with the Ollama
// section of the application configuration.
type Client struct {
	client *api.Client
	config *config.OllamaConfig
}

// NewClient creates a new Ollama client from cfg.Ollama.
//
// The server address is cfg.Ollama.BaseURL, overridden by the
// OLLAMA_BASE_URL environment variable when that is set and non-empty.
// It returns the client, a cleanup function, and an error. On error both the
// client and the cleanup function are nil.
func NewClient(cfg *config.Config) (client *Client, cleanFunc func(), err error) {
	if cfg == nil {
		return nil, nil, errors.New("ollama: nil config")
	}

	c := &Client{config: &cfg.Ollama}

	baseURL, err := c.getUrl()
	if err != nil {
		return nil, nil, err
	}

	// http.DefaultClient has no client-level timeout; request lifetime is
	// bounded only by the ctx passed to each call.
	c.client = api.NewClient(baseURL, http.DefaultClient)

	// The wrapped api.Client uses the shared http.DefaultClient, so this
	// adapter owns nothing that needs releasing; the cleanup function is kept
	// to preserve the constructor's (client, cleanup, error) contract.
	cleanup := func() {}

	return c, cleanup, nil
}

// ToolSelect sends messages to the configured model as a single,
// non-streamed chat request with thinking enabled, and returns the response.
//
// NOTE(review): tools is currently ignored. The Tools field of the request is
// commented out, so the model is not offered any tools. TODO confirm whether
// this is intentional.
func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools []api.Tool) (res api.ChatResponse, err error) {
	req := &api.ChatRequest{
		Model:    c.config.Model,
		Messages: messages,
		Stream:   new(bool), // false: request one complete (non-streamed) response
		Think:    &api.ThinkValue{Value: true},
		//Tools: tools,
	}

	// With streaming disabled the callback is expected to fire once; the last
	// response it sees is the one returned.
	err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
		res = resp
		return nil
	})
	return res, err
}

// ChatStream runs a streaming chat request for messages against the
// configured model, with thinking enabled.
//
// Every chunk whose resp.Message.Content is non-empty is sent to ch as an
// entitys.Response tagged with index and type entitys.ResponseStream; other
// parts of the chunk are not forwarded. The call blocks until the stream
// ends, ctx is cancelled, or an error occurs, and returns the error from the
// underlying Chat call (ctx.Err() if cancelled while delivering a chunk).
// ch is never closed here; the caller owns it.
func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messages []api.Message, index string) (err error) {
	req := &api.ChatRequest{
		Model:    c.config.Model,
		Messages: messages,
		Stream:   nil, // leave streaming at the Ollama default (enabled)
		Think:    &api.ThinkValue{Value: true},
	}

	return c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
		if resp.Message.Content == "" {
			return nil
		}
		// A bare send would block forever if the consumer stops reading, and
		// ctx cancellation could not interrupt it because this callback would
		// never return. Select on ctx so cancellation always unblocks us.
		select {
		case ch <- entitys.Response{
			Index:   index,
			Content: resp.Message.Content,
			Type:    entitys.ResponseStream,
		}:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	})
}

// convertResponse is meant to convert an Ollama chat response into the
// project's entitys.ChatResponse format.
//
// NOTE(review): this is an unfinished stub. The conversion below is commented
// out and the method always returns nil; it has no caller in this file.
// Confirm whether it is still needed.
func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
	//result := &entitys.ChatResponse{
	//	Message:  resp.Message.Content,
	//	Finished: resp.Done,
	//}
	//
	//// Convert tool calls.
	//if len(resp.Message.ToolCalls) > 0 {
	//	result.ToolCalls = make([]entitys.ToolCall, len(resp.Message.ToolCalls))
	//	for i, toolCall := range resp.Message.ToolCalls {
	//		// Convert the function arguments.
	//		argBytes, _ := json.Marshal(toolCall.Function.Arguments)
	//
	//		result.ToolCalls[i] = entitys.ToolCall{
	//			ID:   fmt.Sprintf("call_%d", i),
	//			Type: "function",
	//			Function: entitys.FunctionCall{
	//				Name:      toolCall.Function.Name,
	//				Arguments: json.RawMessage(argBytes),
	//			},
	//		}
	//	}
	//}
	//return result
	return nil
}

// getUrl resolves the Ollama server base URL.
//
// OLLAMA_BASE_URL takes precedence when set and non-empty; otherwise
// c.config.BaseURL is used. The value must parse as an absolute URL with both
// a scheme and a host (for example "http://127.0.0.1:11434"), so that a
// misconfiguration surfaces here rather than on the first request.
func (c *Client) getUrl() (*url.URL, error) {
	raw := c.config.BaseURL
	if env := os.Getenv("OLLAMA_BASE_URL"); env != "" {
		raw = env
	}

	u, err := url.Parse(raw)
	if err != nil {
		return nil, fmt.Errorf("parse ollama base url %q: %w", raw, err)
	}
	// url.Parse accepts "" and scheme-less inputs like "localhost:11434"
	// without error, but neither can be used to reach a server.
	if u.Scheme == "" || u.Host == "" {
		return nil, fmt.Errorf("ollama base url %q must be absolute (scheme and host)", raw)
	}
	return u, nil
}