125 lines
2.9 KiB
Go
125 lines
2.9 KiB
Go
package utils_ollama
|
||
|
||
import (
|
||
"ai_scheduler/internal/config"
|
||
"ai_scheduler/internal/entitys"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"time"
|
||
|
||
"github.com/ollama/ollama/api"
|
||
"github.com/tmc/langchaingo/llms/ollama"
|
||
)
|
||
|
||
// Client Ollama客户端适配器
|
||
type Client struct {
|
||
client *api.Client
|
||
config *config.OllamaConfig
|
||
}
|
||
|
||
// NewClient creates an Ollama-backed entitys.AIClient.
//
// The underlying api.Client is built with api.ClientFromEnvironment, i.e. it
// is configured from the process environment rather than from cfg; cfg.Ollama
// supplies the settings Chat reads (model name and request timeout).
//
// The returned cleanup func is always non-nil, including on error, so callers
// may defer it unconditionally. It currently has nothing to release.
//
// An error is returned if cfg is nil or the environment-based client cannot
// be created.
func NewClient(cfg *config.Config) (entitys.AIClient, func(), error) {
	// The parameter is named cfg (not config) so it does not shadow the
	// imported config package.
	cleanup := func() {}

	if cfg == nil {
		return nil, cleanup, fmt.Errorf("failed to create ollama client: nil config")
	}

	client, err := api.ClientFromEnvironment()
	if err != nil {
		return nil, cleanup, fmt.Errorf("failed to create ollama client: %w", err)
	}

	return &Client{
		client: client,
		config: &cfg.Ollama,
	}, cleanup, nil
}
|
||
|
||
// Chat sends messages, plus any tool definitions, to the configured Ollama
// model as a single non-streaming chat request and returns the converted
// response.
//
// The call is bounded by ctx and, when c.config.Timeout is positive, by an
// additional per-call timeout; a non-positive Timeout adds no extra limit.
// It returns:
//   - ctx.Err() unchanged if the caller's context is cancelled or expires;
//   - a "chat request timeout" error wrapping the derived context's error if
//     the per-call timeout expires first;
//   - a "chat request failed" error for any other failure, including a tool
//     definition that cannot be converted, no reply from the server, or a
//     reply that convertResponse turns into nil.
func (c *Client) Chat(ctx context.Context, messages []entitys.Message, tools []entitys.ToolDefinition) (*entitys.ChatResponse, error) {
	req := &api.ChatRequest{
		Model:    c.config.Model,
		Messages: make([]api.Message, len(messages)),
		Stream:   new(bool), // false: one complete reply instead of chunks
		Think:    &api.ThinkValue{Value: true},
	}

	for i, msg := range messages {
		req.Messages[i] = api.Message{
			Role:    msg.Role,
			Content: msg.Content,
		}
	}

	if len(tools) > 0 {
		req.Tools = make([]api.Tool, len(tools))
		for i, tool := range tools {
			// Convert via a JSON round-trip. This assumes entitys.ToolDefinition
			// serializes to the shape api.Tool decodes — TODO confirm.
			data, err := json.Marshal(tool)
			if err != nil {
				return nil, fmt.Errorf("chat request failed: encoding tool %d: %w", i, err)
			}
			if err := json.Unmarshal(data, &req.Tools[i]); err != nil {
				return nil, fmt.Errorf("chat request failed: decoding tool %d: %w", i, err)
			}
		}
	}

	// Enforce the configured timeout through the context so the HTTP request
	// itself is cancelled, instead of racing a timer and abandoning a
	// goroutine that would block forever on an unread channel.
	parent := ctx
	var timeout time.Duration = c.config.Timeout
	if timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}

	// The ollama client invokes the callback before Chat returns, so plain
	// locals are sufficient to capture the reply.
	var (
		final    api.ChatResponse
		received bool
	)
	err := c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
		final = resp
		received = true
		return nil
	})
	if err != nil {
		// Caller cancellation / deadline takes precedence and is returned as-is.
		if parentErr := parent.Err(); parentErr != nil {
			return nil, parentErr
		}
		// Parent still live but derived context done: our own timeout fired.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, fmt.Errorf("chat request timeout: %w", ctxErr)
		}
		return nil, fmt.Errorf("chat request failed: %w", err)
	}
	if !received {
		return nil, fmt.Errorf("chat request failed: no response received")
	}

	result := c.convertResponse(&final)
	if result == nil {
		// Never report success with a nil response; callers would dereference it.
		return nil, fmt.Errorf("chat request failed: could not convert ollama response")
	}
	return result, nil
}
|
||
|
||
// convertResponse 转换响应格式
|
||
func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
|
||
//result := &entitys.ChatResponse{
|
||
// Message: resp.Message.Content,
|
||
// Finished: resp.Done,
|
||
//}
|
||
//
|
||
//// 转换工具调用
|
||
//if len(resp.Message.ToolCalls) > 0 {
|
||
// result.ToolCalls = make([]entitys.ToolCall, len(resp.Message.ToolCalls))
|
||
// for i, toolCall := range resp.Message.ToolCalls {
|
||
// // 转换函数参数
|
||
// argBytes, _ := json.Marshal(toolCall.Function.Arguments)
|
||
//
|
||
// result.ToolCalls[i] = entitys.ToolCall{
|
||
// ID: fmt.Sprintf("call_%d", i),
|
||
// Type: "function",
|
||
// Function: entitys.FunctionCall{
|
||
// Name: toolCall.Function.Name,
|
||
// Arguments: json.RawMessage(argBytes),
|
||
// },
|
||
// }
|
||
// }
|
||
//}
|
||
|
||
//return result
|
||
return nil
|
||
}
|