132 lines
2.8 KiB
Go
132 lines
2.8 KiB
Go
package utils_ollama
|
||
|
||
import (
	"context"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"sync"

	"ai_scheduler/internal/config"
	"ai_scheduler/internal/entitys"

	"github.com/ollama/ollama/api"
)
|
||
|
||
// Client Ollama客户端适配器
|
||
type Client struct {
|
||
client *api.Client
|
||
config *config.OllamaConfig
|
||
}
|
||
|
||
// NewClient 创建新的Ollama客户端
|
||
func NewClient(config *config.Config) (client *Client, cleanFunc func(), err error) {
|
||
client = &Client{
|
||
config: &config.Ollama,
|
||
}
|
||
url, err := client.getUrl()
|
||
if err != nil {
|
||
return
|
||
}
|
||
client.client = api.NewClient(url, http.DefaultClient)
|
||
|
||
cleanup := func() {
|
||
if client != nil {
|
||
client = nil
|
||
}
|
||
}
|
||
|
||
return client, cleanup, nil
|
||
}
|
||
|
||
// ToolSelect 工具选择
|
||
func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools []api.Tool) (res api.ChatResponse, err error) {
|
||
// 构建聊天请求
|
||
req := &api.ChatRequest{
|
||
Model: c.config.Model,
|
||
Messages: messages,
|
||
Stream: new(bool), // 设置为false,不使用流式响应
|
||
Think: &api.ThinkValue{Value: true},
|
||
Tools: tools,
|
||
}
|
||
|
||
err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
||
res = resp
|
||
return nil
|
||
})
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
func (c *Client) ChatStream(ctx context.Context, ch chan entitys.ResponseData, messages []api.Message) (err error) {
|
||
// 构建聊天请求
|
||
req := &api.ChatRequest{
|
||
Model: c.config.Model,
|
||
Messages: messages,
|
||
Stream: nil,
|
||
Think: &api.ThinkValue{Value: true},
|
||
}
|
||
var w sync.WaitGroup
|
||
w.Add(1)
|
||
go func() {
|
||
defer w.Done()
|
||
err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
||
if resp.Message.Content != "" {
|
||
ch <- entitys.ResponseData{
|
||
Done: false,
|
||
Content: resp.Message.Content,
|
||
Type: entitys.ResponseStream,
|
||
}
|
||
}
|
||
return nil
|
||
})
|
||
if err != nil {
|
||
return
|
||
}
|
||
}()
|
||
w.Wait()
|
||
|
||
return
|
||
}
|
||
|
||
// convertResponse 转换响应格式
|
||
func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
|
||
//result := &entitys.ChatResponse{
|
||
// Message: resp.Message.Content,
|
||
// Finished: resp.Done,
|
||
//}
|
||
//
|
||
//// 转换工具调用
|
||
//if len(resp.Message.ToolCalls) > 0 {
|
||
// result.ToolCalls = make([]entitys.ToolCall, len(resp.Message.ToolCalls))
|
||
// for i, toolCall := range resp.Message.ToolCalls {
|
||
// // 转换函数参数
|
||
// argBytes, _ := json.Marshal(toolCall.Function.Arguments)
|
||
//
|
||
// result.ToolCalls[i] = entitys.ToolCall{
|
||
// ID: fmt.Sprintf("call_%d", i),
|
||
// Type: "function",
|
||
// Function: entitys.FunctionCall{
|
||
// Name: toolCall.Function.Name,
|
||
// Arguments: json.RawMessage(argBytes),
|
||
// },
|
||
// }
|
||
// }
|
||
//}
|
||
|
||
//return result
|
||
return nil
|
||
}
|
||
|
||
func (c *Client) getUrl() (*url.URL, error) {
|
||
baseURL := c.config.BaseURL
|
||
envURL := os.Getenv("OLLAMA_BASE_URL")
|
||
if envURL != "" {
|
||
baseURL = envURL
|
||
}
|
||
|
||
return url.Parse(baseURL)
|
||
}
|