142 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			142 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Go
		
	
	
	
package utils_ollama
 | 
						||
 | 
						||
import (
	"ai_scheduler/internal/config"
	"ai_scheduler/internal/entitys"
	"context"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"sync"

	"github.com/ollama/ollama/api"
)
 | 
						||
 | 
						||
// Client Ollama客户端适配器
 | 
						||
type Client struct {
 | 
						||
	client *api.Client
 | 
						||
	config *config.OllamaConfig
 | 
						||
}
 | 
						||
 | 
						||
// NewClient 创建新的Ollama客户端
 | 
						||
func NewClient(config *config.Config) (client *Client, cleanFunc func(), err error) {
 | 
						||
	client = &Client{
 | 
						||
		config: &config.Ollama,
 | 
						||
	}
 | 
						||
	url, err := client.getUrl()
 | 
						||
	if err != nil {
 | 
						||
		return
 | 
						||
	}
 | 
						||
	client.client = api.NewClient(url, http.DefaultClient)
 | 
						||
 | 
						||
	cleanup := func() {
 | 
						||
		if client != nil {
 | 
						||
			client = nil
 | 
						||
		}
 | 
						||
	}
 | 
						||
 | 
						||
	return client, cleanup, nil
 | 
						||
}
 | 
						||
 | 
						||
// ToolSelect 工具选择
 | 
						||
func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools []api.Tool) (res api.ChatResponse, err error) {
 | 
						||
	// 构建聊天请求
 | 
						||
	req := &api.ChatRequest{
 | 
						||
		Model:    c.config.Model,
 | 
						||
		Messages: messages,
 | 
						||
		Stream:   new(bool), // 设置为false,不使用流式响应
 | 
						||
		Think:    &api.ThinkValue{Value: false},
 | 
						||
		Tools:    tools,
 | 
						||
	}
 | 
						||
	err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
 | 
						||
		res = resp
 | 
						||
		return nil
 | 
						||
	})
 | 
						||
	if err != nil {
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	return
 | 
						||
}
 | 
						||
 | 
						||
func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messages []api.Message, index string, model string) (err error) {
 | 
						||
	if len(model) == 0 {
 | 
						||
		model = c.config.Model
 | 
						||
	}
 | 
						||
	// 构建聊天请求
 | 
						||
	req := &api.ChatRequest{
 | 
						||
		Model:    model,
 | 
						||
		Messages: messages,
 | 
						||
		Stream:   nil,
 | 
						||
		Think:    &api.ThinkValue{Value: true},
 | 
						||
	}
 | 
						||
	var w sync.WaitGroup
 | 
						||
	w.Add(1)
 | 
						||
	go func() {
 | 
						||
		defer w.Done()
 | 
						||
		err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
 | 
						||
			if resp.Message.Content != "" {
 | 
						||
				ch <- entitys.Response{
 | 
						||
					Index:   index,
 | 
						||
					Content: resp.Message.Content,
 | 
						||
					Type:    entitys.ResponseStream,
 | 
						||
				}
 | 
						||
			}
 | 
						||
			return nil
 | 
						||
		})
 | 
						||
		if err != nil {
 | 
						||
			return
 | 
						||
		}
 | 
						||
	}()
 | 
						||
	w.Wait()
 | 
						||
 | 
						||
	return
 | 
						||
}
 | 
						||
 | 
						||
func (c *Client) Generation(ctx context.Context, generateRequest *api.GenerateRequest) (result api.GenerateResponse, err error) {
 | 
						||
	err = c.client.Generate(ctx, generateRequest, func(resp api.GenerateResponse) error {
 | 
						||
		result = resp
 | 
						||
		return nil
 | 
						||
	})
 | 
						||
	return
 | 
						||
}
 | 
						||
 | 
						||
// convertResponse 转换响应格式
 | 
						||
func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
 | 
						||
	//result := &entitys.ChatResponse{
 | 
						||
	//	Message:  resp.Message.Content,
 | 
						||
	//	Finished: resp.Done,
 | 
						||
	//}
 | 
						||
	//
 | 
						||
	//// 转换工具调用
 | 
						||
	//if len(resp.Message.ToolCalls) > 0 {
 | 
						||
	//	result.ToolCalls = make([]entitys.ToolCall, len(resp.Message.ToolCalls))
 | 
						||
	//	for i, toolCall := range resp.Message.ToolCalls {
 | 
						||
	//		// 转换函数参数
 | 
						||
	//		argBytes, _ := json.Marshal(toolCall.Function.Arguments)
 | 
						||
	//
 | 
						||
	//		result.ToolCalls[i] = entitys.ToolCall{
 | 
						||
	//			ID:   fmt.Sprintf("call_%d", i),
 | 
						||
	//			Type: "function",
 | 
						||
	//			Function: entitys.FunctionCall{
 | 
						||
	//				Name:      toolCall.Function.Name,
 | 
						||
	//				Arguments: json.RawMessage(argBytes),
 | 
						||
	//			},
 | 
						||
	//		}
 | 
						||
	//	}
 | 
						||
	//}
 | 
						||
 | 
						||
	//return result
 | 
						||
	return nil
 | 
						||
}
 | 
						||
 | 
						||
func (c *Client) getUrl() (*url.URL, error) {
 | 
						||
	baseURL := c.config.BaseURL
 | 
						||
	envURL := os.Getenv("OLLAMA_BASE_URL")
 | 
						||
	if envURL != "" {
 | 
						||
		baseURL = envURL
 | 
						||
	}
 | 
						||
 | 
						||
	return url.Parse(baseURL)
 | 
						||
}
 |