结构修改

This commit is contained in:
renzhiyuan 2025-09-22 17:04:47 +08:00
parent c130900ac6
commit ed7591975b
7 changed files with 153 additions and 170 deletions

View File

@ -21,6 +21,7 @@ import (
"gitea.cdlsxd.cn/self-tools/l_request" "gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/fiber/v2/log" "github.com/gofiber/fiber/v2/log"
"github.com/gofiber/websocket/v2" "github.com/gofiber/websocket/v2"
"github.com/ollama/ollama/api"
"github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms"
"xorm.io/builder" "xorm.io/builder"
) )
@ -34,7 +35,7 @@ type AiRouterBiz struct {
taskImpl *impl.TaskImpl taskImpl *impl.TaskImpl
hisImpl *impl.ChatImpl hisImpl *impl.ChatImpl
conf *config.Config conf *config.Config
utilAgent *utils_ollama.UtilOllama ai *utils_ollama.Client
} }
// NewRouterService 创建路由服务 // NewRouterService 创建路由服务
@ -46,7 +47,7 @@ func NewAiRouterBiz(
taskImpl *impl.TaskImpl, taskImpl *impl.TaskImpl,
hisImpl *impl.ChatImpl, hisImpl *impl.ChatImpl,
conf *config.Config, conf *config.Config,
utilAgent *utils_ollama.UtilOllama, ai *utils_ollama.Client,
) *AiRouterBiz { ) *AiRouterBiz {
return &AiRouterBiz{ return &AiRouterBiz{
@ -57,7 +58,7 @@ func NewAiRouterBiz(
sysImpl: sysImpl, sysImpl: sysImpl,
hisImpl: hisImpl, hisImpl: hisImpl,
taskImpl: taskImpl, taskImpl: taskImpl,
utilAgent: utilAgent, ai: ai,
} }
} }
@ -103,22 +104,19 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe
//意图预测 //意图预测
prompt := r.getPromptLLM(sysInfo, history, req.Text, task) prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt, toolDefinitions := r.registerTools(task)
//llms.WithTools(toolDefinitions), match, err := r.ai.ToolSelect(context.TODO(), prompt, toolDefinitions)
//llms.WithToolChoice("tool_name"),
llms.WithJSONMode(),
)
if err != nil { if err != nil {
return errors.SystemError return errors.SystemError
} }
log.Info(match) log.Info(match)
var matchJson entitys.Match //var matchJson entitys.Match
err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson) //err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson)
if err != nil { //if err != nil {
return errors.SystemError // return errors.SystemError
} //}
return r.handleMatch(c, &matchJson, task) //return r.handleMatch(c, &matchJson, task)
//c.WriteMessage(1, []byte(msg.Choices[0].Content)) c.WriteMessage(1, []byte(match.Message.Content))
// 构建消息 // 构建消息
//messages := []entitys.Message{ //messages := []entitys.Message{
// { // {
@ -327,20 +325,25 @@ func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
return return
} }
func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool { func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []api.Tool {
taskPrompt := make([]llms.Tool, 0) taskPrompt := make([]api.Tool, 0)
for _, task := range tasks { for _, task := range tasks {
var taskConfig entitys.TaskConfig var taskConfig entitys.TaskConfigDetail
err := json.Unmarshal([]byte(task.Config), &taskConfig) err := json.Unmarshal([]byte(task.Config), &taskConfig)
if err != nil { if err != nil {
continue continue
} }
taskPrompt = append(taskPrompt, llms.Tool{
taskPrompt = append(taskPrompt, api.Tool{
Type: "function", Type: "function",
Function: &llms.FunctionDefinition{ Function: api.ToolFunction{
Name: task.Index, Name: task.Index,
Description: task.Desc, Description: task.Desc,
Parameters: taskConfig.Param, Parameters: api.ToolFunctionParameters{
Type: taskConfig.Param.Type,
Required: taskConfig.Param.Required,
Properties: taskConfig.Param.Properties,
},
}, },
}) })
@ -365,30 +368,22 @@ func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, re
return prompt return prompt
} }
func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []api.Message {
var ( var (
prompt = make([]llms.MessageContent, 0) prompt = make([]api.Message, 0)
) )
prompt = append(prompt, llms.MessageContent{ prompt = append(prompt, api.Message{
Role: llms.ChatMessageTypeSystem, Role: string(llms.ChatMessageTypeSystem),
Parts: []llms.ContentPart{ Content: r.buildSystemPrompt(sysInfo.SysPrompt),
llms.TextPart(r.buildSystemPrompt(sysInfo.SysPrompt)), }, api.Message{
}, Role: "chatHistory",
}, llms.MessageContent{ Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
Role: llms.ChatMessageTypeTool, }, api.Message{
Parts: []llms.ContentPart{ Role: string(llms.ChatMessageTypeTool),
llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))), Content: pkg.JsonStringIgonErr(r.registerTools(tasks)),
}, }, api.Message{
}, llms.MessageContent{ Role: string(llms.ChatMessageTypeHuman),
Role: llms.ChatMessageTypeTool, Content: reqInput,
Parts: []llms.ContentPart{
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
},
}, llms.MessageContent{
Role: llms.ChatMessageTypeHuman,
Parts: []llms.ContentPart{
llms.TextPart(reqInput),
},
}) })
return prompt return prompt
} }

View File

@ -6,6 +6,7 @@ import (
"gitea.cdlsxd.cn/self-tools/l_request" "gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/websocket/v2" "github.com/gofiber/websocket/v2"
"github.com/ollama/ollama/api"
) )
// ChatRequest 聊天请求 // ChatRequest 聊天请求
@ -93,6 +94,15 @@ type FuncApi struct {
type TaskConfig struct { type TaskConfig struct {
Param interface{} `json:"param"` Param interface{} `json:"param"`
} }
type TaskConfigDetail struct {
Param ConfigParam `json:"param"`
}
type ConfigParam struct {
Properties map[string]api.ToolProperty
Required []string `json:"required"`
Type string `json:"type"`
}
type Match struct { type Match struct {
Confidence float64 `json:"confidence"` Confidence float64 `json:"confidence"`
Index string `json:"index"` Index string `json:"index"`

View File

@ -10,4 +10,5 @@ var ProviderSetClient = wire.NewSet(
NewRdb, NewRdb,
NewGormDb, NewGormDb,
utils_ollama.NewUtilOllama, utils_ollama.NewUtilOllama,
utils_ollama.NewClient,
) )

View File

@ -0,0 +1,99 @@
package utils_ollama
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/entitys"
"context"
"net/http"
"net/url"
"os"
"github.com/ollama/ollama/api"
)
// Client Ollama客户端适配器
type Client struct {
client *api.Client
config *config.OllamaConfig
}
// NewClient 创建新的Ollama客户端
func NewClient(config *config.Config) (client *Client, cleanFunc func(), err error) {
client = &Client{
config: &config.Ollama,
}
url, err := client.getUrl()
if err != nil {
return
}
client.client = api.NewClient(url, http.DefaultClient)
cleanup := func() {
if client != nil {
client = nil
}
}
return client, cleanup, nil
}
// ToolSelect 工具选择
func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools []api.Tool) (res api.ChatResponse, err error) {
// 构建聊天请求
req := &api.ChatRequest{
Model: c.config.Model,
Messages: messages,
Stream: new(bool), // 设置为false不使用流式响应
Think: &api.ThinkValue{Value: true},
Tools: tools,
}
err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
res = resp
return nil
})
if err != nil {
return
}
return
}
// convertResponse 转换响应格式
func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
//result := &entitys.ChatResponse{
// Message: resp.Message.Content,
// Finished: resp.Done,
//}
//
//// 转换工具调用
//if len(resp.Message.ToolCalls) > 0 {
// result.ToolCalls = make([]entitys.ToolCall, len(resp.Message.ToolCalls))
// for i, toolCall := range resp.Message.ToolCalls {
// // 转换函数参数
// argBytes, _ := json.Marshal(toolCall.Function.Arguments)
//
// result.ToolCalls[i] = entitys.ToolCall{
// ID: fmt.Sprintf("call_%d", i),
// Type: "function",
// Function: entitys.FunctionCall{
// Name: toolCall.Function.Name,
// Arguments: json.RawMessage(argBytes),
// },
// }
// }
//}
//return result
return nil
}
func (c *Client) getUrl() (*url.URL, error) {
baseURL := c.config.BaseURL
envURL := os.Getenv("OLLAMA_BASE_URL")
if envURL != "" {
baseURL = envURL
}
return url.Parse(baseURL)
}

View File

@ -1,124 +0,0 @@
package utils_ollama
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/entitys"
"context"
"encoding/json"
"fmt"
"time"
"github.com/ollama/ollama/api"
"github.com/tmc/langchaingo/llms/ollama"
)
// Client Ollama客户端适配器
type Client struct {
client *api.Client
config *config.OllamaConfig
}
// NewClient 创建新的Ollama客户端
func NewClient(config *config.Config) (entitys.AIClient, func(), error) {
client, err := api.ClientFromEnvironment()
cleanup := func() {
if client != nil {
client = nil
}
}
if err != nil {
return nil, cleanup, fmt.Errorf("failed to create ollama client: %w", err)
}
return &Client{
client: client,
config: &config.Ollama,
}, cleanup, nil
}
// Chat 实现聊天功能
func (c *Client) Chat(ctx context.Context, messages []entitys.Message, tools []entitys.ToolDefinition) (*entitys.ChatResponse, error) {
// 构建聊天请求
req := &api.ChatRequest{
Model: c.config.Model,
Messages: make([]api.Message, len(messages)),
Stream: new(bool), // 设置为false不使用流式响应
Think: &api.ThinkValue{Value: true},
}
// 转换消息格式
for i, msg := range messages {
req.Messages[i] = api.Message{
Role: msg.Role,
Content: msg.Content,
}
}
// 添加工具定义
if len(tools) > 0 {
req.Tools = make([]api.Tool, len(tools))
for i, tool := range tools {
toolData, _ := json.Marshal(tool)
var apiTool api.Tool
json.Unmarshal(toolData, &apiTool)
req.Tools[i] = apiTool
}
}
// 发送请求
responseChan := make(chan api.ChatResponse)
errorChan := make(chan error)
go func() {
err := c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
responseChan <- resp
return nil
})
if err != nil {
errorChan <- err
}
close(responseChan)
close(errorChan)
}()
// 等待响应
select {
case resp := <-responseChan:
return c.convertResponse(&resp), nil
case err := <-errorChan:
return nil, fmt.Errorf("chat request failed: %w", err)
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(c.config.Timeout):
return nil, fmt.Errorf("chat request timeout")
}
}
// convertResponse 转换响应格式
func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
//result := &entitys.ChatResponse{
// Message: resp.Message.Content,
// Finished: resp.Done,
//}
//
//// 转换工具调用
//if len(resp.Message.ToolCalls) > 0 {
// result.ToolCalls = make([]entitys.ToolCall, len(resp.Message.ToolCalls))
// for i, toolCall := range resp.Message.ToolCalls {
// // 转换函数参数
// argBytes, _ := json.Marshal(toolCall.Function.Arguments)
//
// result.ToolCalls[i] = entitys.ToolCall{
// ID: fmt.Sprintf("call_%d", i),
// Type: "function",
// Function: entitys.FunctionCall{
// Name: toolCall.Function.Name,
// Arguments: json.RawMessage(argBytes),
// },
// }
// }
//}
//return result
return nil
}

View File

@ -45,7 +45,7 @@ func NewManager(config *config.Config, llm *utils_ollama.UtilOllama) *Manager {
// 注册直连天下订单详情工具 // 注册直连天下订单详情工具
if config.Tools.ZltxOrderDetail.Enabled { if config.Tools.ZltxOrderDetail.Enabled {
zltxOrderDetailTool := NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail) zltxOrderDetailTool := NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm)
m.tools[zltxOrderDetailTool.Name()] = zltxOrderDetailTool m.tools[zltxOrderDetailTool.Name()] = zltxOrderDetailTool
} }

View File

@ -3,6 +3,7 @@ package tools
import ( import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/utils_ollama"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -13,11 +14,12 @@ import (
// ZltxOrderDetailTool 直连天下订单详情工具 // ZltxOrderDetailTool 直连天下订单详情工具
type ZltxOrderDetailTool struct { type ZltxOrderDetailTool struct {
config config.ToolConfig config config.ToolConfig
llm *utils_ollama.UtilOllama
} }
// NewZltxOrderDetailTool 创建直连天下订单详情工具 // NewZltxOrderDetailTool 创建直连天下订单详情工具
func NewZltxOrderDetailTool(config config.ToolConfig) *ZltxOrderDetailTool { func NewZltxOrderDetailTool(config config.ToolConfig, llm *utils_ollama.UtilOllama) *ZltxOrderDetailTool {
return &ZltxOrderDetailTool{config: config} return &ZltxOrderDetailTool{config: config, llm: llm}
} }
// Name 返回工具名称 // Name 返回工具名称