结构修改

This commit is contained in:
renzhiyuan 2025-09-22 11:23:28 +08:00
parent b81c9ef137
commit 32866d59a1
4 changed files with 72 additions and 121 deletions

View File

@ -15,7 +15,7 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"gitea.cdlsxd.cn/self-tools/l_request"
@ -203,10 +203,13 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) {
defer func() {
c.WriteMessage(1, []byte("EOF"))
if err != nil {
c.WriteMessage(websocket.TextMessage, []byte(err.Error()))
}
c.WriteMessage(websocket.TextMessage, []byte("EOF"))
}()
if !matchJson.IsMatch {
c.WriteMessage(1, []byte(matchJson.Reasoning))
c.WriteMessage(websocket.TextMessage, []byte(matchJson.Reasoning))
return
}
var pointTask *model.AiTask
@ -219,33 +222,32 @@ func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Matc
if pointTask == nil || pointTask.Index == "other" {
return r.handleOtherTask(c, matchJson)
}
var res []byte
switch pointTask.Type {
case constant.TaskTypeApi:
res, err = r.handleApiTask(c, matchJson, pointTask)
err = r.handleApiTask(c, matchJson, pointTask)
case constant.TaskTypeFunc:
ctx := context.TODO()
res, err = r.handleTask(ctx, matchJson, pointTask)
err = r.handleTask(c, matchJson, pointTask)
default:
return r.handleOtherTask(c, matchJson)
}
fmt.Println(res)
return
}
func (r *AiRouterService) handleTask(c context.Context, matchJson *entitys.Match, task *model.AiTask) (res []byte, err error) {
func (r *AiRouterService) handleTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
var configData entitys.ConfigDataTool
err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil {
return
}
resInterface, err := r.toolManager.ExecuteTool(c, configData.Tool, []byte(matchJson.Parameters))
err = r.toolManager.ExecuteTool(c, configData.Tool, []byte(matchJson.Parameters))
if err != nil {
return nil, err
return
}
return json.Marshal(resInterface)
return
}
func (r *AiRouterService) handleOtherTask(c *websocket.Conn, matchJson *entitys.Match) (err error) {
@ -255,7 +257,7 @@ func (r *AiRouterService) handleOtherTask(c *websocket.Conn, matchJson *entitys.
return
}
func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (resByte []byte, err error) {
func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
var (
request l_request.Request
auth = c.Headers("X-Authorization", "")
@ -286,7 +288,8 @@ func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Ma
if err != nil {
return
}
return res.Content, nil
c.WriteMessage(1, res.Content)
return
}
func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
@ -412,59 +415,6 @@ func (r *AiRouterService) buildAssistant(his []model.AiChatHi) (chatHis entitys.
return
}
// extractIntent 从AI响应中提取意图
func (r *AiRouterService) extractIntent(response *entitys.ChatResponse) string {
if response == nil || response.Message == "" {
return ""
}
// 尝试解析JSON
var intent struct {
Intent string `json:"intent"`
Confidence string `json:"confidence"`
Reasoning string `json:"reasoning"`
}
err := json.Unmarshal([]byte(response.Message), &intent)
if err != nil {
log.Printf("Failed to parse intent JSON: %v", err)
return ""
}
return intent.Intent
}
// handleOrderDiagnosis 处理订单诊断意图
func (r *AiRouterService) handleOrderDiagnosis(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
// 调用订单详情工具
//orderDetailTool, ok := r.toolManager.GetTool("zltxOrderDetail")
//if orderDetailTool == nil || !ok {
// return nil, fmt.Errorf("order detail tool not found")
//}
//orderDetailTool.Execute(ctx, json.RawMessage{})
//
//// 获取相关工具定义
//toolDefinitions := r.toolManager.GetToolDefinitions(constants.Caller(req.Caller))
//
//// 调用AI获取是否需要使用工具
//response, err := r.aiClient.Chat(ctx, messages, toolDefinitions)
//if err != nil {
// return nil, fmt.Errorf("failed to chat with AI: %w", err)
//}
//
//// 如果没有工具调用,直接返回
//if len(response.ToolCalls) == 0 {
// return response, nil
//}
//
//// 执行工具调用
//toolResults, err := r.toolManager.ExecuteToolCalls(ctx, response.ToolCalls)
//if err != nil {
// return nil, fmt.Errorf("failed to execute tools: %w", err)
//}
return nil, nil
}
// handleKnowledgeQA 处理知识问答意图
func (r *AiRouterService) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {

View File

@ -66,7 +66,7 @@ type Tool interface {
Name() string
Description() string
Definition() ToolDefinition
Execute(ctx context.Context, args json.RawMessage) (interface{}, error)
Execute(c *websocket.Conn, args json.RawMessage) error
}
type ConfigDataHttp struct {

View File

@ -4,10 +4,10 @@ import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/constants"
"ai_scheduler/internal/entitys"
"context"
"encoding/json"
"fmt"
"github.com/gofiber/websocket/v2"
)
// Manager 工具管理器
@ -22,16 +22,16 @@ func NewManager(config *config.Config) *Manager {
}
// 注册天气工具
if config.Tools.Weather.Enabled {
weatherTool := NewWeatherTool()
m.tools[weatherTool.Name()] = weatherTool
}
// 注册计算器工具
if config.Tools.Calculator.Enabled {
calcTool := NewCalculatorTool()
m.tools[calcTool.Name()] = calcTool
}
//if config.Tools.Weather.Enabled {
// weatherTool := NewWeatherTool()
// m.tools[weatherTool.Name()] = weatherTool
//}
//
//// 注册计算器工具
//if config.Tools.Calculator.Enabled {
// calcTool := NewCalculatorTool()
// m.tools[calcTool.Name()] = calcTool
//}
// 注册知识库工具
// if config.Knowledge.Enabled {
@ -80,43 +80,43 @@ func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefi
}
// ExecuteTool 执行工具
func (m *Manager) ExecuteTool(ctx context.Context, name string, args json.RawMessage) (interface{}, error) {
func (m *Manager) ExecuteTool(c *websocket.Conn, name string, args json.RawMessage) error {
tool, exists := m.GetTool(name)
if !exists {
return nil, fmt.Errorf("tool not found: %s", name)
return fmt.Errorf("tool not found: %s", name)
}
return tool.Execute(ctx, args)
return tool.Execute(c, args)
}
// ExecuteToolCalls 执行多个工具调用
func (m *Manager) ExecuteToolCalls(ctx context.Context, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) {
results := make([]entitys.ToolCall, len(toolCalls))
for i, toolCall := range toolCalls {
results[i] = toolCall
// 执行工具
result, err := m.ExecuteTool(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
if err != nil {
// 将错误信息作为结果返回
errorResult := map[string]interface{}{
"error": err.Error(),
}
resultBytes, _ := json.Marshal(errorResult)
results[i].Result = resultBytes
} else {
// 将成功结果序列化
resultBytes, err := json.Marshal(result)
if err != nil {
errorResult := map[string]interface{}{
"error": fmt.Sprintf("failed to serialize result: %v", err),
}
resultBytes, _ = json.Marshal(errorResult)
}
results[i].Result = resultBytes
}
}
return results, nil
}
//func (m *Manager) ExecuteToolCalls(c *websocket.Conn, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) {
// results := make([]entitys.ToolCall, len(toolCalls))
//
// for i, toolCall := range toolCalls {
// results[i] = toolCall
//
// // 执行工具
// err := m.ExecuteTool(c, toolCall.Function.Name, toolCall.Function.Arguments)
// if err != nil {
// // 将错误信息作为结果返回
// errorResult := map[string]interface{}{
// "error": err.Error(),
// }
// resultBytes, _ := json.Marshal(errorResult)
// results[i].Result = resultBytes
// } else {
// // 将成功结果序列化
// resultBytes, err := json.Marshal(result)
// if err != nil {
// errorResult := map[string]interface{}{
// "error": fmt.Sprintf("failed to serialize result: %v", err),
// }
// resultBytes, _ = json.Marshal(errorResult)
// }
// results[i].Result = resultBytes
// }
// }
//
// return results, nil
//}

View File

@ -4,10 +4,11 @@ import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/entitys"
"context"
"encoding/json"
"fmt"
"net/http"
"github.com/gofiber/websocket/v2"
)
// ZltxOrderDetailTool 直连天下订单详情工具
@ -70,37 +71,37 @@ type ZltxOrderDetailData struct {
}
// Execute 执行直连天下订单详情查询
func (w *ZltxOrderDetailTool) Execute(ctx context.Context, args json.RawMessage) (interface{}, error) {
func (w *ZltxOrderDetailTool) Execute(c *websocket.Conn, args json.RawMessage) error {
var req ZltxOrderDetailRequest
if err := json.Unmarshal(args, &req); err != nil {
return nil, fmt.Errorf("invalid zltxOrderDetail request: %w", err)
return fmt.Errorf("invalid zltxOrderDetail request: %w", err)
}
if req.Number == "" {
return nil, fmt.Errorf("number is required")
return fmt.Errorf("number is required")
}
// 这里可以集成真实的直连天下订单详情API
return w.getZltxOrderDetail(ctx, req.Number), nil
return w.getZltxOrderDetail(c, req.Number)
}
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据
func (w *ZltxOrderDetailTool) getZltxOrderDetail(ctx context.Context, number string) *ZltxOrderDetailResponse {
func (w *ZltxOrderDetailTool) getZltxOrderDetail(c *websocket.Conn, number string) (err error) {
url := fmt.Sprintf("%s/admin/direct/ai/%s", w.config.BaseURL, number)
authorization := fmt.Sprintf("Bearer %s", w.config.APIKey)
// 发送http请求
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return &ZltxOrderDetailResponse{}
return
}
req.Header.Set("Authorization", authorization)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return &ZltxOrderDetailResponse{}
return
}
defer resp.Body.Close()
return &ZltxOrderDetailResponse{}
return
}