结构修改
This commit is contained in:
parent
c130900ac6
commit
ed7591975b
|
@ -21,6 +21,7 @@ import (
|
|||
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/tmc/langchaingo/llms"
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
@ -34,7 +35,7 @@ type AiRouterBiz struct {
|
|||
taskImpl *impl.TaskImpl
|
||||
hisImpl *impl.ChatImpl
|
||||
conf *config.Config
|
||||
utilAgent *utils_ollama.UtilOllama
|
||||
ai *utils_ollama.Client
|
||||
}
|
||||
|
||||
// NewRouterService 创建路由服务
|
||||
|
@ -46,7 +47,7 @@ func NewAiRouterBiz(
|
|||
taskImpl *impl.TaskImpl,
|
||||
hisImpl *impl.ChatImpl,
|
||||
conf *config.Config,
|
||||
utilAgent *utils_ollama.UtilOllama,
|
||||
ai *utils_ollama.Client,
|
||||
|
||||
) *AiRouterBiz {
|
||||
return &AiRouterBiz{
|
||||
|
@ -57,7 +58,7 @@ func NewAiRouterBiz(
|
|||
sysImpl: sysImpl,
|
||||
hisImpl: hisImpl,
|
||||
taskImpl: taskImpl,
|
||||
utilAgent: utilAgent,
|
||||
ai: ai,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -103,22 +104,19 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe
|
|||
|
||||
//意图预测
|
||||
prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
|
||||
match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
|
||||
//llms.WithTools(toolDefinitions),
|
||||
//llms.WithToolChoice("tool_name"),
|
||||
llms.WithJSONMode(),
|
||||
)
|
||||
toolDefinitions := r.registerTools(task)
|
||||
match, err := r.ai.ToolSelect(context.TODO(), prompt, toolDefinitions)
|
||||
if err != nil {
|
||||
return errors.SystemError
|
||||
}
|
||||
log.Info(match)
|
||||
var matchJson entitys.Match
|
||||
err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson)
|
||||
if err != nil {
|
||||
return errors.SystemError
|
||||
}
|
||||
return r.handleMatch(c, &matchJson, task)
|
||||
//c.WriteMessage(1, []byte(msg.Choices[0].Content))
|
||||
//var matchJson entitys.Match
|
||||
//err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson)
|
||||
//if err != nil {
|
||||
// return errors.SystemError
|
||||
//}
|
||||
//return r.handleMatch(c, &matchJson, task)
|
||||
c.WriteMessage(1, []byte(match.Message.Content))
|
||||
// 构建消息
|
||||
//messages := []entitys.Message{
|
||||
// {
|
||||
|
@ -327,20 +325,25 @@ func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool {
|
||||
taskPrompt := make([]llms.Tool, 0)
|
||||
func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []api.Tool {
|
||||
taskPrompt := make([]api.Tool, 0)
|
||||
for _, task := range tasks {
|
||||
var taskConfig entitys.TaskConfig
|
||||
var taskConfig entitys.TaskConfigDetail
|
||||
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
taskPrompt = append(taskPrompt, llms.Tool{
|
||||
|
||||
taskPrompt = append(taskPrompt, api.Tool{
|
||||
Type: "function",
|
||||
Function: &llms.FunctionDefinition{
|
||||
Function: api.ToolFunction{
|
||||
Name: task.Index,
|
||||
Description: task.Desc,
|
||||
Parameters: taskConfig.Param,
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: taskConfig.Param.Type,
|
||||
Required: taskConfig.Param.Required,
|
||||
Properties: taskConfig.Param.Properties,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
|
@ -365,30 +368,22 @@ func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, re
|
|||
return prompt
|
||||
}
|
||||
|
||||
func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
||||
func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []api.Message {
|
||||
var (
|
||||
prompt = make([]llms.MessageContent, 0)
|
||||
prompt = make([]api.Message, 0)
|
||||
)
|
||||
prompt = append(prompt, llms.MessageContent{
|
||||
Role: llms.ChatMessageTypeSystem,
|
||||
Parts: []llms.ContentPart{
|
||||
llms.TextPart(r.buildSystemPrompt(sysInfo.SysPrompt)),
|
||||
},
|
||||
}, llms.MessageContent{
|
||||
Role: llms.ChatMessageTypeTool,
|
||||
Parts: []llms.ContentPart{
|
||||
llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))),
|
||||
},
|
||||
}, llms.MessageContent{
|
||||
Role: llms.ChatMessageTypeTool,
|
||||
Parts: []llms.ContentPart{
|
||||
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
|
||||
},
|
||||
}, llms.MessageContent{
|
||||
Role: llms.ChatMessageTypeHuman,
|
||||
Parts: []llms.ContentPart{
|
||||
llms.TextPart(reqInput),
|
||||
},
|
||||
prompt = append(prompt, api.Message{
|
||||
Role: string(llms.ChatMessageTypeSystem),
|
||||
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
|
||||
}, api.Message{
|
||||
Role: "chatHistory",
|
||||
Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
|
||||
}, api.Message{
|
||||
Role: string(llms.ChatMessageTypeTool),
|
||||
Content: pkg.JsonStringIgonErr(r.registerTools(tasks)),
|
||||
}, api.Message{
|
||||
Role: string(llms.ChatMessageTypeHuman),
|
||||
Content: reqInput,
|
||||
})
|
||||
return prompt
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// ChatRequest 聊天请求
|
||||
|
@ -93,6 +94,15 @@ type FuncApi struct {
|
|||
type TaskConfig struct {
|
||||
Param interface{} `json:"param"`
|
||||
}
|
||||
|
||||
type TaskConfigDetail struct {
|
||||
Param ConfigParam `json:"param"`
|
||||
}
|
||||
type ConfigParam struct {
|
||||
Properties map[string]api.ToolProperty
|
||||
Required []string `json:"required"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
type Match struct {
|
||||
Confidence float64 `json:"confidence"`
|
||||
Index string `json:"index"`
|
||||
|
|
|
@ -10,4 +10,5 @@ var ProviderSetClient = wire.NewSet(
|
|||
NewRdb,
|
||||
NewGormDb,
|
||||
utils_ollama.NewUtilOllama,
|
||||
utils_ollama.NewClient,
|
||||
)
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
package utils_ollama
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Client Ollama客户端适配器
|
||||
type Client struct {
|
||||
client *api.Client
|
||||
config *config.OllamaConfig
|
||||
}
|
||||
|
||||
// NewClient 创建新的Ollama客户端
|
||||
func NewClient(config *config.Config) (client *Client, cleanFunc func(), err error) {
|
||||
client = &Client{
|
||||
config: &config.Ollama,
|
||||
}
|
||||
url, err := client.getUrl()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
client.client = api.NewClient(url, http.DefaultClient)
|
||||
|
||||
cleanup := func() {
|
||||
if client != nil {
|
||||
client = nil
|
||||
}
|
||||
}
|
||||
|
||||
return client, cleanup, nil
|
||||
}
|
||||
|
||||
// ToolSelect 工具选择
|
||||
func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools []api.Tool) (res api.ChatResponse, err error) {
|
||||
// 构建聊天请求
|
||||
req := &api.ChatRequest{
|
||||
Model: c.config.Model,
|
||||
Messages: messages,
|
||||
Stream: new(bool), // 设置为false,不使用流式响应
|
||||
Think: &api.ThinkValue{Value: true},
|
||||
Tools: tools,
|
||||
}
|
||||
|
||||
err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
||||
res = resp
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// convertResponse 转换响应格式
|
||||
func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
|
||||
//result := &entitys.ChatResponse{
|
||||
// Message: resp.Message.Content,
|
||||
// Finished: resp.Done,
|
||||
//}
|
||||
//
|
||||
//// 转换工具调用
|
||||
//if len(resp.Message.ToolCalls) > 0 {
|
||||
// result.ToolCalls = make([]entitys.ToolCall, len(resp.Message.ToolCalls))
|
||||
// for i, toolCall := range resp.Message.ToolCalls {
|
||||
// // 转换函数参数
|
||||
// argBytes, _ := json.Marshal(toolCall.Function.Arguments)
|
||||
//
|
||||
// result.ToolCalls[i] = entitys.ToolCall{
|
||||
// ID: fmt.Sprintf("call_%d", i),
|
||||
// Type: "function",
|
||||
// Function: entitys.FunctionCall{
|
||||
// Name: toolCall.Function.Name,
|
||||
// Arguments: json.RawMessage(argBytes),
|
||||
// },
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
//return result
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) getUrl() (*url.URL, error) {
|
||||
baseURL := c.config.BaseURL
|
||||
envURL := os.Getenv("OLLAMA_BASE_URL")
|
||||
if envURL != "" {
|
||||
baseURL = envURL
|
||||
}
|
||||
|
||||
return url.Parse(baseURL)
|
||||
}
|
|
@ -1,124 +0,0 @@
|
|||
package utils_ollama
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/tmc/langchaingo/llms/ollama"
|
||||
)
|
||||
|
||||
// Client Ollama客户端适配器
|
||||
type Client struct {
|
||||
client *api.Client
|
||||
config *config.OllamaConfig
|
||||
}
|
||||
|
||||
// NewClient 创建新的Ollama客户端
|
||||
func NewClient(config *config.Config) (entitys.AIClient, func(), error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
cleanup := func() {
|
||||
if client != nil {
|
||||
client = nil
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, cleanup, fmt.Errorf("failed to create ollama client: %w", err)
|
||||
}
|
||||
|
||||
return &Client{
|
||||
client: client,
|
||||
config: &config.Ollama,
|
||||
}, cleanup, nil
|
||||
}
|
||||
|
||||
// Chat 实现聊天功能
|
||||
func (c *Client) Chat(ctx context.Context, messages []entitys.Message, tools []entitys.ToolDefinition) (*entitys.ChatResponse, error) {
|
||||
// 构建聊天请求
|
||||
req := &api.ChatRequest{
|
||||
Model: c.config.Model,
|
||||
Messages: make([]api.Message, len(messages)),
|
||||
Stream: new(bool), // 设置为false,不使用流式响应
|
||||
Think: &api.ThinkValue{Value: true},
|
||||
}
|
||||
|
||||
// 转换消息格式
|
||||
for i, msg := range messages {
|
||||
req.Messages[i] = api.Message{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
}
|
||||
}
|
||||
|
||||
// 添加工具定义
|
||||
if len(tools) > 0 {
|
||||
req.Tools = make([]api.Tool, len(tools))
|
||||
for i, tool := range tools {
|
||||
toolData, _ := json.Marshal(tool)
|
||||
var apiTool api.Tool
|
||||
json.Unmarshal(toolData, &apiTool)
|
||||
req.Tools[i] = apiTool
|
||||
}
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
responseChan := make(chan api.ChatResponse)
|
||||
errorChan := make(chan error)
|
||||
|
||||
go func() {
|
||||
err := c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
||||
responseChan <- resp
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
errorChan <- err
|
||||
}
|
||||
close(responseChan)
|
||||
close(errorChan)
|
||||
}()
|
||||
|
||||
// 等待响应
|
||||
select {
|
||||
case resp := <-responseChan:
|
||||
return c.convertResponse(&resp), nil
|
||||
case err := <-errorChan:
|
||||
return nil, fmt.Errorf("chat request failed: %w", err)
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(c.config.Timeout):
|
||||
return nil, fmt.Errorf("chat request timeout")
|
||||
}
|
||||
}
|
||||
|
||||
// convertResponse 转换响应格式
|
||||
func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
|
||||
//result := &entitys.ChatResponse{
|
||||
// Message: resp.Message.Content,
|
||||
// Finished: resp.Done,
|
||||
//}
|
||||
//
|
||||
//// 转换工具调用
|
||||
//if len(resp.Message.ToolCalls) > 0 {
|
||||
// result.ToolCalls = make([]entitys.ToolCall, len(resp.Message.ToolCalls))
|
||||
// for i, toolCall := range resp.Message.ToolCalls {
|
||||
// // 转换函数参数
|
||||
// argBytes, _ := json.Marshal(toolCall.Function.Arguments)
|
||||
//
|
||||
// result.ToolCalls[i] = entitys.ToolCall{
|
||||
// ID: fmt.Sprintf("call_%d", i),
|
||||
// Type: "function",
|
||||
// Function: entitys.FunctionCall{
|
||||
// Name: toolCall.Function.Name,
|
||||
// Arguments: json.RawMessage(argBytes),
|
||||
// },
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
//return result
|
||||
return nil
|
||||
}
|
|
@ -45,7 +45,7 @@ func NewManager(config *config.Config, llm *utils_ollama.UtilOllama) *Manager {
|
|||
|
||||
// 注册直连天下订单详情工具
|
||||
if config.Tools.ZltxOrderDetail.Enabled {
|
||||
zltxOrderDetailTool := NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail)
|
||||
zltxOrderDetailTool := NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm)
|
||||
m.tools[zltxOrderDetailTool.Name()] = zltxOrderDetailTool
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ package tools
|
|||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg/utils_ollama"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
|
@ -13,11 +14,12 @@ import (
|
|||
// ZltxOrderDetailTool 直连天下订单详情工具
|
||||
type ZltxOrderDetailTool struct {
|
||||
config config.ToolConfig
|
||||
llm *utils_ollama.UtilOllama
|
||||
}
|
||||
|
||||
// NewZltxOrderDetailTool 创建直连天下订单详情工具
|
||||
func NewZltxOrderDetailTool(config config.ToolConfig) *ZltxOrderDetailTool {
|
||||
return &ZltxOrderDetailTool{config: config}
|
||||
func NewZltxOrderDetailTool(config config.ToolConfig, llm *utils_ollama.UtilOllama) *ZltxOrderDetailTool {
|
||||
return &ZltxOrderDetailTool{config: config, llm: llm}
|
||||
}
|
||||
|
||||
// Name 返回工具名称
|
||||
|
|
Loading…
Reference in New Issue