结构修改

This commit is contained in:
renzhiyuan 2025-09-22 11:23:28 +08:00
parent b81c9ef137
commit 32866d59a1
4 changed files with 72 additions and 121 deletions

View File

@ -15,7 +15,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"strings" "strings"
"gitea.cdlsxd.cn/self-tools/l_request" "gitea.cdlsxd.cn/self-tools/l_request"
@ -203,10 +203,13 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) { func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) {
defer func() { defer func() {
c.WriteMessage(1, []byte("EOF")) if err != nil {
c.WriteMessage(websocket.TextMessage, []byte(err.Error()))
}
c.WriteMessage(websocket.TextMessage, []byte("EOF"))
}() }()
if !matchJson.IsMatch { if !matchJson.IsMatch {
c.WriteMessage(1, []byte(matchJson.Reasoning)) c.WriteMessage(websocket.TextMessage, []byte(matchJson.Reasoning))
return return
} }
var pointTask *model.AiTask var pointTask *model.AiTask
@ -219,33 +222,32 @@ func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Matc
if pointTask == nil || pointTask.Index == "other" { if pointTask == nil || pointTask.Index == "other" {
return r.handleOtherTask(c, matchJson) return r.handleOtherTask(c, matchJson)
} }
var res []byte
switch pointTask.Type { switch pointTask.Type {
case constant.TaskTypeApi: case constant.TaskTypeApi:
res, err = r.handleApiTask(c, matchJson, pointTask) err = r.handleApiTask(c, matchJson, pointTask)
case constant.TaskTypeFunc: case constant.TaskTypeFunc:
ctx := context.TODO() err = r.handleTask(c, matchJson, pointTask)
res, err = r.handleTask(ctx, matchJson, pointTask)
default: default:
return r.handleOtherTask(c, matchJson) return r.handleOtherTask(c, matchJson)
} }
fmt.Println(res)
return return
} }
func (r *AiRouterService) handleTask(c context.Context, matchJson *entitys.Match, task *model.AiTask) (res []byte, err error) { func (r *AiRouterService) handleTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
var configData entitys.ConfigDataTool var configData entitys.ConfigDataTool
err = json.Unmarshal([]byte(task.Config), &configData) err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil { if err != nil {
return return
} }
resInterface, err := r.toolManager.ExecuteTool(c, configData.Tool, []byte(matchJson.Parameters)) err = r.toolManager.ExecuteTool(c, configData.Tool, []byte(matchJson.Parameters))
if err != nil { if err != nil {
return nil, err return
} }
return json.Marshal(resInterface) return
} }
func (r *AiRouterService) handleOtherTask(c *websocket.Conn, matchJson *entitys.Match) (err error) { func (r *AiRouterService) handleOtherTask(c *websocket.Conn, matchJson *entitys.Match) (err error) {
@ -255,7 +257,7 @@ func (r *AiRouterService) handleOtherTask(c *websocket.Conn, matchJson *entitys.
return return
} }
func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (resByte []byte, err error) { func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
var ( var (
request l_request.Request request l_request.Request
auth = c.Headers("X-Authorization", "") auth = c.Headers("X-Authorization", "")
@ -286,7 +288,8 @@ func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Ma
if err != nil { if err != nil {
return return
} }
return res.Content, nil c.WriteMessage(1, res.Content)
return
} }
func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) { func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
@ -412,59 +415,6 @@ func (r *AiRouterService) buildAssistant(his []model.AiChatHi) (chatHis entitys.
return return
} }
// extractIntent 从AI响应中提取意图
func (r *AiRouterService) extractIntent(response *entitys.ChatResponse) string {
if response == nil || response.Message == "" {
return ""
}
// 尝试解析JSON
var intent struct {
Intent string `json:"intent"`
Confidence string `json:"confidence"`
Reasoning string `json:"reasoning"`
}
err := json.Unmarshal([]byte(response.Message), &intent)
if err != nil {
log.Printf("Failed to parse intent JSON: %v", err)
return ""
}
return intent.Intent
}
// handleOrderDiagnosis 处理订单诊断意图
func (r *AiRouterService) handleOrderDiagnosis(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
// 调用订单详情工具
//orderDetailTool, ok := r.toolManager.GetTool("zltxOrderDetail")
//if orderDetailTool == nil || !ok {
// return nil, fmt.Errorf("order detail tool not found")
//}
//orderDetailTool.Execute(ctx, json.RawMessage{})
//
//// 获取相关工具定义
//toolDefinitions := r.toolManager.GetToolDefinitions(constants.Caller(req.Caller))
//
//// 调用AI获取是否需要使用工具
//response, err := r.aiClient.Chat(ctx, messages, toolDefinitions)
//if err != nil {
// return nil, fmt.Errorf("failed to chat with AI: %w", err)
//}
//
//// 如果没有工具调用,直接返回
//if len(response.ToolCalls) == 0 {
// return response, nil
//}
//
//// 执行工具调用
//toolResults, err := r.toolManager.ExecuteToolCalls(ctx, response.ToolCalls)
//if err != nil {
// return nil, fmt.Errorf("failed to execute tools: %w", err)
//}
return nil, nil
}
// handleKnowledgeQA 处理知识问答意图 // handleKnowledgeQA 处理知识问答意图
func (r *AiRouterService) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) { func (r *AiRouterService) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {

View File

@ -66,7 +66,7 @@ type Tool interface {
Name() string Name() string
Description() string Description() string
Definition() ToolDefinition Definition() ToolDefinition
Execute(ctx context.Context, args json.RawMessage) (interface{}, error) Execute(c *websocket.Conn, args json.RawMessage) error
} }
type ConfigDataHttp struct { type ConfigDataHttp struct {

View File

@ -4,10 +4,10 @@ import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/constants" "ai_scheduler/internal/constants"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gofiber/websocket/v2"
) )
// Manager 工具管理器 // Manager 工具管理器
@ -22,16 +22,16 @@ func NewManager(config *config.Config) *Manager {
} }
// 注册天气工具 // 注册天气工具
if config.Tools.Weather.Enabled { //if config.Tools.Weather.Enabled {
weatherTool := NewWeatherTool() // weatherTool := NewWeatherTool()
m.tools[weatherTool.Name()] = weatherTool // m.tools[weatherTool.Name()] = weatherTool
} //}
//
// 注册计算器工具 //// 注册计算器工具
if config.Tools.Calculator.Enabled { //if config.Tools.Calculator.Enabled {
calcTool := NewCalculatorTool() // calcTool := NewCalculatorTool()
m.tools[calcTool.Name()] = calcTool // m.tools[calcTool.Name()] = calcTool
} //}
// 注册知识库工具 // 注册知识库工具
// if config.Knowledge.Enabled { // if config.Knowledge.Enabled {
@ -80,43 +80,43 @@ func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefi
} }
// ExecuteTool 执行工具 // ExecuteTool 执行工具
func (m *Manager) ExecuteTool(ctx context.Context, name string, args json.RawMessage) (interface{}, error) { func (m *Manager) ExecuteTool(c *websocket.Conn, name string, args json.RawMessage) error {
tool, exists := m.GetTool(name) tool, exists := m.GetTool(name)
if !exists { if !exists {
return nil, fmt.Errorf("tool not found: %s", name) return fmt.Errorf("tool not found: %s", name)
} }
return tool.Execute(ctx, args) return tool.Execute(c, args)
} }
// ExecuteToolCalls 执行多个工具调用 // ExecuteToolCalls 执行多个工具调用
func (m *Manager) ExecuteToolCalls(ctx context.Context, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) { //func (m *Manager) ExecuteToolCalls(c *websocket.Conn, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) {
results := make([]entitys.ToolCall, len(toolCalls)) // results := make([]entitys.ToolCall, len(toolCalls))
//
for i, toolCall := range toolCalls { // for i, toolCall := range toolCalls {
results[i] = toolCall // results[i] = toolCall
//
// 执行工具 // // 执行工具
result, err := m.ExecuteTool(ctx, toolCall.Function.Name, toolCall.Function.Arguments) // err := m.ExecuteTool(c, toolCall.Function.Name, toolCall.Function.Arguments)
if err != nil { // if err != nil {
// 将错误信息作为结果返回 // // 将错误信息作为结果返回
errorResult := map[string]interface{}{ // errorResult := map[string]interface{}{
"error": err.Error(), // "error": err.Error(),
} // }
resultBytes, _ := json.Marshal(errorResult) // resultBytes, _ := json.Marshal(errorResult)
results[i].Result = resultBytes // results[i].Result = resultBytes
} else { // } else {
// 将成功结果序列化 // // 将成功结果序列化
resultBytes, err := json.Marshal(result) // resultBytes, err := json.Marshal(result)
if err != nil { // if err != nil {
errorResult := map[string]interface{}{ // errorResult := map[string]interface{}{
"error": fmt.Sprintf("failed to serialize result: %v", err), // "error": fmt.Sprintf("failed to serialize result: %v", err),
} // }
resultBytes, _ = json.Marshal(errorResult) // resultBytes, _ = json.Marshal(errorResult)
} // }
results[i].Result = resultBytes // results[i].Result = resultBytes
} // }
} // }
//
return results, nil // return results, nil
} //}

View File

@ -4,10 +4,11 @@ import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"github.com/gofiber/websocket/v2"
) )
// ZltxOrderDetailTool 直连天下订单详情工具 // ZltxOrderDetailTool 直连天下订单详情工具
@ -70,37 +71,37 @@ type ZltxOrderDetailData struct {
} }
// Execute 执行直连天下订单详情查询 // Execute 执行直连天下订单详情查询
func (w *ZltxOrderDetailTool) Execute(ctx context.Context, args json.RawMessage) (interface{}, error) { func (w *ZltxOrderDetailTool) Execute(c *websocket.Conn, args json.RawMessage) error {
var req ZltxOrderDetailRequest var req ZltxOrderDetailRequest
if err := json.Unmarshal(args, &req); err != nil { if err := json.Unmarshal(args, &req); err != nil {
return nil, fmt.Errorf("invalid zltxOrderDetail request: %w", err) return fmt.Errorf("invalid zltxOrderDetail request: %w", err)
} }
if req.Number == "" { if req.Number == "" {
return nil, fmt.Errorf("number is required") return fmt.Errorf("number is required")
} }
// 这里可以集成真实的直连天下订单详情API // 这里可以集成真实的直连天下订单详情API
return w.getZltxOrderDetail(ctx, req.Number), nil return w.getZltxOrderDetail(c, req.Number)
} }
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据 // getMockZltxOrderDetail 获取模拟直连天下订单详情数据
func (w *ZltxOrderDetailTool) getZltxOrderDetail(ctx context.Context, number string) *ZltxOrderDetailResponse { func (w *ZltxOrderDetailTool) getZltxOrderDetail(c *websocket.Conn, number string) (err error) {
url := fmt.Sprintf("%s/admin/direct/ai/%s", w.config.BaseURL, number) url := fmt.Sprintf("%s/admin/direct/ai/%s", w.config.BaseURL, number)
authorization := fmt.Sprintf("Bearer %s", w.config.APIKey) authorization := fmt.Sprintf("Bearer %s", w.config.APIKey)
// 发送http请求 // 发送http请求
req, err := http.NewRequest("GET", url, nil) req, err := http.NewRequest("GET", url, nil)
if err != nil { if err != nil {
return &ZltxOrderDetailResponse{} return
} }
req.Header.Set("Authorization", authorization) req.Header.Set("Authorization", authorization)
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return &ZltxOrderDetailResponse{} return
} }
defer resp.Body.Close() defer resp.Body.Close()
return &ZltxOrderDetailResponse{} return
} }