From 570e13e52711a71155c366729e9921b648f4fcda Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Mon, 22 Sep 2025 11:46:06 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BB=93=E6=9E=84=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/router.go | 38 +++++++------- internal/biz/router_test.go | 54 ++++++++++++++++++++ internal/services/chat.go | 15 +++--- internal/test/{chat_test.go => task_test.go} | 6 ++- internal/tools/zltx_order_detail.go | 23 ++++----- 5 files changed, 97 insertions(+), 39 deletions(-) create mode 100644 internal/biz/router_test.go rename internal/test/{chat_test.go => task_test.go} (74%) diff --git a/internal/biz/router.go b/internal/biz/router.go index 45ad515..a127e27 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -25,8 +25,8 @@ import ( "xorm.io/builder" ) -// AiRouterService 智能路由服务 -type AiRouterService struct { +// AiRouterBiz 智能路由服务 +type AiRouterBiz struct { //aiClient entitys.AIClient toolManager *tools.Manager sessionImpl *impl.SessionImpl @@ -48,8 +48,8 @@ func NewAiRouterBiz( conf *config.Config, utilAgent *utils_ollama.UtilOllama, -) entitys.RouterService { - return &AiRouterService{ +) *AiRouterBiz { + return &AiRouterBiz{ //aiClient: aiClient, toolManager: toolManager, sessionImpl: sessionImpl, @@ -62,13 +62,13 @@ func NewAiRouterBiz( } // Route 执行智能路由 -func (r *AiRouterService) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) { +func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) { return nil, nil } // Route 执行智能路由 -func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error { +func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error { session := c.Headers("X-Session", "") if len(session) == 0 { @@ -201,7 +201,7 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo return nil } -func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) { +func (r *AiRouterBiz) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) { defer func() { if err != nil { c.WriteMessage(websocket.TextMessage, []byte(err.Error())) @@ -235,7 +235,7 @@ func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Matc return } -func (r *AiRouterService) handleTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { +func (r *AiRouterBiz) handleTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { var configData entitys.ConfigDataTool err = json.Unmarshal([]byte(task.Config), &configData) @@ -250,14 +250,14 @@ func (r *AiRouterService) handleTask(c *websocket.Conn, matchJson *entitys.Match return } -func (r *AiRouterService) handleOtherTask(c *websocket.Conn, matchJson *entitys.Match) (err error) { +func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, matchJson *entitys.Match) (err error) { c.WriteMessage(1, []byte(matchJson.Reasoning)) return } -func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { +func (r *AiRouterBiz) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { var ( request l_request.Request auth = c.Headers("X-Authorization", "") @@ -292,7 +292,7 @@ func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Ma return } -func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) { +func (r *AiRouterBiz) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"session_id": sessionId}) @@ -302,7 +302,7 @@ func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiCha return } -func (r *AiRouterService) getSysInfo(appKey string) (sysInfo model.AiSy, err error) { +func (r *AiRouterBiz) getSysInfo(appKey string) (sysInfo model.AiSy, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"app_key": appKey}) cond = cond.And(builder.IsNull{"delete_at"}) @@ -311,7 +311,7 @@ func (r *AiRouterService) getSysInfo(appKey string) (sysInfo model.AiSy, err err return } -func (r *AiRouterService) getTasks(sysId int32) (tasks []model.AiTask, err error) { +func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"sys_id": sysId}) @@ -322,7 +322,7 @@ func (r *AiRouterService) getTasks(sysId int32) (tasks []model.AiTask, err error return } -func (r *AiRouterService) registerTools(tasks []model.AiTask) []llms.Tool { +func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool { taskPrompt := make([]llms.Tool, 0) for _, task := range tasks { var taskConfig entitys.TaskConfig @@ -343,7 +343,7 @@ func (r *AiRouterService) registerTools(tasks []model.AiTask) []llms.Tool { return taskPrompt } -func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message { +func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message { var ( prompt = make([]entitys.Message, 0) ) @@ -360,7 +360,7 @@ func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi return prompt } -func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { +func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { var ( prompt = make([]llms.MessageContent, 0) ) @@ -389,7 +389,7 @@ func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiCha } // buildSystemPrompt 构建系统提示词 -func (r *AiRouterService) buildSystemPrompt(prompt string) string { +func (r *AiRouterBiz) buildSystemPrompt(prompt string) string { if len(prompt) == 0 { prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON:{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式,禁用markdown格式返回\n3.只返回json字符串,不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容" } @@ -397,7 +397,7 @@ func (r *AiRouterService) buildSystemPrompt(prompt string) string { return prompt } -func (r *AiRouterService) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) { +func (r *AiRouterBiz) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) { for _, item := range his { if len(chatHis.SessionId) == 0 { chatHis.SessionId = item.SessionID @@ -416,7 +416,7 @@ func (r *AiRouterService) buildAssistant(his []model.AiChatHi) (chatHis entitys. } // handleKnowledgeQA 处理知识问答意图 -func (r *AiRouterService) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) { +func (r *AiRouterBiz) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) { return nil, nil } diff --git a/internal/biz/router_test.go b/internal/biz/router_test.go new file mode 100644 index 0000000..0f43807 --- /dev/null +++ b/internal/biz/router_test.go @@ -0,0 +1,54 @@ +package biz + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/utils_ollama" + "ai_scheduler/internal/tools" + "ai_scheduler/utils" + "encoding/json" + "flag" + "testing" + + "github.com/gofiber/fiber/v2/log" +) + +func Test_task(t *testing.T) { + var c entitys.TaskConfig + config := `{"param": {"type": "object", "required": ["number"], "properties": {"number": {"type": "string", "description": "订单编号/流水号"}}}, "request": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}` + err := json.Unmarshal([]byte(config), &c) + t.Log(err) +} + +type configData struct { + Param map[string]interface{} `json:"param"` + Do map[string]interface{} `json:"do"` +} + +func Test_Order(t *testing.T) { + routerBiz := in() + err := routerBiz.handleTask(nil, &entitys.Match{Index: "order_diagnosis"}, &model.AiTask{Config: `{"tool": "zltxOrderDetail", "param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}}`}) + t.Log(err) +} + +func in() *AiRouterBiz { + configPath := flag.String("config", "./config/config.yaml", "Path to configuration file") + flag.Parse() + configConfig, err := config.LoadConfig(*configPath) + if err != nil { + panic("加载配置失败") + } + allLogger := log.DefaultLogger() + manager := tools.NewManager(configConfig) + db, _ := utils.NewGormDb(configConfig) + sessionImpl := impl.NewSessionImpl(db) + sysImpl := impl.NewSysImpl(db) + taskImpl := impl.NewTaskImpl(db) + chatImpl := impl.NewChatImpl(db) + utilOllama := utils_ollama.NewUtilOllama(configConfig, allLogger) + routerBiz := NewAiRouterBiz(manager, sessionImpl, sysImpl, taskImpl, chatImpl, configConfig, utilOllama) + + return routerBiz +} diff --git a/internal/services/chat.go b/internal/services/chat.go index 43804b8..82995e1 100644 --- a/internal/services/chat.go +++ b/internal/services/chat.go @@ -1,6 +1,7 @@ package services import ( + "ai_scheduler/internal/biz" "ai_scheduler/internal/data/constant" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" @@ -17,16 +18,16 @@ import ( // ChatHandler 聊天处理器 type ChatService struct { - routerService entitys.RouterService - Gw *gateway.Gateway - mu sync.Mutex + routerBiz *biz.AiRouterBiz + Gw *gateway.Gateway + mu sync.Mutex } // NewChatHandler 创建聊天处理器 -func NewChatService(routerService entitys.RouterService, gw *gateway.Gateway) *ChatService { +func NewChatService(routerService *biz.AiRouterBiz, gw *gateway.Gateway) *ChatService { return &ChatService{ - routerService: routerService, - Gw: gw, + routerBiz: routerService, + Gw: gw, } } @@ -100,7 +101,7 @@ func (h *ChatService) Chat(c *websocket.Conn) { log.Println("JSON parse error:", err) continue } - err = h.routerService.RouteWithSocket(c, &req) + err = h.routerBiz.RouteWithSocket(c, &req) if err != nil { log.Println("处理失败:", err) continue diff --git a/internal/test/chat_test.go b/internal/test/task_test.go similarity index 74% rename from internal/test/chat_test.go rename to internal/test/task_test.go index 489eb4f..f46832f 100644 --- a/internal/test/chat_test.go +++ b/internal/test/task_test.go @@ -27,7 +27,7 @@ func Test_task2(t *testing.T) { config configData ) - configJson := `{"param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}, "do": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}` + configJson := `{"tool": "zltxOrderDetail", "param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}}` err := json.Unmarshal([]byte(configJson), &config) if err != nil { panic(err) @@ -35,3 +35,7 @@ func Test_task2(t *testing.T) { mapstructure.Decode(config.Do, &c) t.Log(err) } + +func in() { + +} diff --git a/internal/tools/zltx_order_detail.go b/internal/tools/zltx_order_detail.go index 417a409..8a3c1c8 100644 --- a/internal/tools/zltx_order_detail.go +++ b/internal/tools/zltx_order_detail.go @@ -6,8 +6,8 @@ import ( "encoding/json" "fmt" - "net/http" + "gitea.cdlsxd.cn/self-tools/l_request" "github.com/gofiber/websocket/v2" ) @@ -87,21 +87,20 @@ func (w *ZltxOrderDetailTool) Execute(c *websocket.Conn, args json.RawMessage) e // getMockZltxOrderDetail 获取模拟直连天下订单详情数据 func (w *ZltxOrderDetailTool) getZltxOrderDetail(c *websocket.Conn, number string) (err error) { - url := fmt.Sprintf("%s/admin/direct/ai/%s", w.config.BaseURL, number) - authorization := fmt.Sprintf("Bearer %s", w.config.APIKey) + //查询订单详情 + req := l_request.Request{ + Url: fmt.Sprintf("%s/admin/direct/ai/%s", w.config.BaseURL, number), + Headers: map[string]string{ + "Authorization": fmt.Sprintf("Bearer %s", w.config.APIKey), + }, + Method: "GET", + } + res, err := req.Send() - // 发送http请求 - req, err := http.NewRequest("GET", url, nil) if err != nil { return } - req.Header.Set("Authorization", authorization) - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return - } - defer resp.Body.Close() + c.WriteMessage(websocket.TextMessage, res.Content) return }