结构修改
This commit is contained in:
parent
b81c9ef137
commit
32866d59a1
|
@ -15,7 +15,7 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"strings"
|
||||
|
||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||
|
@ -203,10 +203,13 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
|
|||
|
||||
func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) {
|
||||
defer func() {
|
||||
c.WriteMessage(1, []byte("EOF"))
|
||||
if err != nil {
|
||||
c.WriteMessage(websocket.TextMessage, []byte(err.Error()))
|
||||
}
|
||||
c.WriteMessage(websocket.TextMessage, []byte("EOF"))
|
||||
}()
|
||||
if !matchJson.IsMatch {
|
||||
c.WriteMessage(1, []byte(matchJson.Reasoning))
|
||||
c.WriteMessage(websocket.TextMessage, []byte(matchJson.Reasoning))
|
||||
return
|
||||
}
|
||||
var pointTask *model.AiTask
|
||||
|
@ -219,33 +222,32 @@ func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Matc
|
|||
if pointTask == nil || pointTask.Index == "other" {
|
||||
return r.handleOtherTask(c, matchJson)
|
||||
}
|
||||
var res []byte
|
||||
|
||||
switch pointTask.Type {
|
||||
case constant.TaskTypeApi:
|
||||
res, err = r.handleApiTask(c, matchJson, pointTask)
|
||||
err = r.handleApiTask(c, matchJson, pointTask)
|
||||
case constant.TaskTypeFunc:
|
||||
ctx := context.TODO()
|
||||
res, err = r.handleTask(ctx, matchJson, pointTask)
|
||||
err = r.handleTask(c, matchJson, pointTask)
|
||||
default:
|
||||
return r.handleOtherTask(c, matchJson)
|
||||
}
|
||||
fmt.Println(res)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (r *AiRouterService) handleTask(c context.Context, matchJson *entitys.Match, task *model.AiTask) (res []byte, err error) {
|
||||
func (r *AiRouterService) handleTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
|
||||
|
||||
var configData entitys.ConfigDataTool
|
||||
err = json.Unmarshal([]byte(task.Config), &configData)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resInterface, err := r.toolManager.ExecuteTool(c, configData.Tool, []byte(matchJson.Parameters))
|
||||
err = r.toolManager.ExecuteTool(c, configData.Tool, []byte(matchJson.Parameters))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return
|
||||
}
|
||||
|
||||
return json.Marshal(resInterface)
|
||||
return
|
||||
}
|
||||
|
||||
func (r *AiRouterService) handleOtherTask(c *websocket.Conn, matchJson *entitys.Match) (err error) {
|
||||
|
@ -255,7 +257,7 @@ func (r *AiRouterService) handleOtherTask(c *websocket.Conn, matchJson *entitys.
|
|||
return
|
||||
}
|
||||
|
||||
func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (resByte []byte, err error) {
|
||||
func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
|
||||
var (
|
||||
request l_request.Request
|
||||
auth = c.Headers("X-Authorization", "")
|
||||
|
@ -286,7 +288,8 @@ func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Ma
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
return res.Content, nil
|
||||
c.WriteMessage(1, res.Content)
|
||||
return
|
||||
}
|
||||
|
||||
func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
|
||||
|
@ -412,59 +415,6 @@ func (r *AiRouterService) buildAssistant(his []model.AiChatHi) (chatHis entitys.
|
|||
return
|
||||
}
|
||||
|
||||
// extractIntent 从AI响应中提取意图
|
||||
func (r *AiRouterService) extractIntent(response *entitys.ChatResponse) string {
|
||||
if response == nil || response.Message == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 尝试解析JSON
|
||||
var intent struct {
|
||||
Intent string `json:"intent"`
|
||||
Confidence string `json:"confidence"`
|
||||
Reasoning string `json:"reasoning"`
|
||||
}
|
||||
err := json.Unmarshal([]byte(response.Message), &intent)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse intent JSON: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
return intent.Intent
|
||||
}
|
||||
|
||||
// handleOrderDiagnosis 处理订单诊断意图
|
||||
func (r *AiRouterService) handleOrderDiagnosis(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
|
||||
// 调用订单详情工具
|
||||
//orderDetailTool, ok := r.toolManager.GetTool("zltxOrderDetail")
|
||||
//if orderDetailTool == nil || !ok {
|
||||
// return nil, fmt.Errorf("order detail tool not found")
|
||||
//}
|
||||
//orderDetailTool.Execute(ctx, json.RawMessage{})
|
||||
//
|
||||
//// 获取相关工具定义
|
||||
//toolDefinitions := r.toolManager.GetToolDefinitions(constants.Caller(req.Caller))
|
||||
//
|
||||
//// 调用AI,获取是否需要使用工具
|
||||
//response, err := r.aiClient.Chat(ctx, messages, toolDefinitions)
|
||||
//if err != nil {
|
||||
// return nil, fmt.Errorf("failed to chat with AI: %w", err)
|
||||
//}
|
||||
//
|
||||
//// 如果没有工具调用,直接返回
|
||||
//if len(response.ToolCalls) == 0 {
|
||||
// return response, nil
|
||||
//}
|
||||
//
|
||||
//// 执行工具调用
|
||||
//toolResults, err := r.toolManager.ExecuteToolCalls(ctx, response.ToolCalls)
|
||||
//if err != nil {
|
||||
// return nil, fmt.Errorf("failed to execute tools: %w", err)
|
||||
//}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// handleKnowledgeQA 处理知识问答意图
|
||||
func (r *AiRouterService) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
|
||||
|
||||
|
|
|
@ -66,7 +66,7 @@ type Tool interface {
|
|||
Name() string
|
||||
Description() string
|
||||
Definition() ToolDefinition
|
||||
Execute(ctx context.Context, args json.RawMessage) (interface{}, error)
|
||||
Execute(c *websocket.Conn, args json.RawMessage) error
|
||||
}
|
||||
|
||||
type ConfigDataHttp struct {
|
||||
|
|
|
@ -4,10 +4,10 @@ import (
|
|||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/constants"
|
||||
"ai_scheduler/internal/entitys"
|
||||
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
)
|
||||
|
||||
// Manager 工具管理器
|
||||
|
@ -22,16 +22,16 @@ func NewManager(config *config.Config) *Manager {
|
|||
}
|
||||
|
||||
// 注册天气工具
|
||||
if config.Tools.Weather.Enabled {
|
||||
weatherTool := NewWeatherTool()
|
||||
m.tools[weatherTool.Name()] = weatherTool
|
||||
}
|
||||
|
||||
// 注册计算器工具
|
||||
if config.Tools.Calculator.Enabled {
|
||||
calcTool := NewCalculatorTool()
|
||||
m.tools[calcTool.Name()] = calcTool
|
||||
}
|
||||
//if config.Tools.Weather.Enabled {
|
||||
// weatherTool := NewWeatherTool()
|
||||
// m.tools[weatherTool.Name()] = weatherTool
|
||||
//}
|
||||
//
|
||||
//// 注册计算器工具
|
||||
//if config.Tools.Calculator.Enabled {
|
||||
// calcTool := NewCalculatorTool()
|
||||
// m.tools[calcTool.Name()] = calcTool
|
||||
//}
|
||||
|
||||
// 注册知识库工具
|
||||
// if config.Knowledge.Enabled {
|
||||
|
@ -80,43 +80,43 @@ func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefi
|
|||
}
|
||||
|
||||
// ExecuteTool 执行工具
|
||||
func (m *Manager) ExecuteTool(ctx context.Context, name string, args json.RawMessage) (interface{}, error) {
|
||||
func (m *Manager) ExecuteTool(c *websocket.Conn, name string, args json.RawMessage) error {
|
||||
tool, exists := m.GetTool(name)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("tool not found: %s", name)
|
||||
return fmt.Errorf("tool not found: %s", name)
|
||||
}
|
||||
|
||||
return tool.Execute(ctx, args)
|
||||
return tool.Execute(c, args)
|
||||
}
|
||||
|
||||
// ExecuteToolCalls 执行多个工具调用
|
||||
func (m *Manager) ExecuteToolCalls(ctx context.Context, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) {
|
||||
results := make([]entitys.ToolCall, len(toolCalls))
|
||||
|
||||
for i, toolCall := range toolCalls {
|
||||
results[i] = toolCall
|
||||
|
||||
// 执行工具
|
||||
result, err := m.ExecuteTool(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
|
||||
if err != nil {
|
||||
// 将错误信息作为结果返回
|
||||
errorResult := map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
}
|
||||
resultBytes, _ := json.Marshal(errorResult)
|
||||
results[i].Result = resultBytes
|
||||
} else {
|
||||
// 将成功结果序列化
|
||||
resultBytes, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
errorResult := map[string]interface{}{
|
||||
"error": fmt.Sprintf("failed to serialize result: %v", err),
|
||||
}
|
||||
resultBytes, _ = json.Marshal(errorResult)
|
||||
}
|
||||
results[i].Result = resultBytes
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
//func (m *Manager) ExecuteToolCalls(c *websocket.Conn, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) {
|
||||
// results := make([]entitys.ToolCall, len(toolCalls))
|
||||
//
|
||||
// for i, toolCall := range toolCalls {
|
||||
// results[i] = toolCall
|
||||
//
|
||||
// // 执行工具
|
||||
// err := m.ExecuteTool(c, toolCall.Function.Name, toolCall.Function.Arguments)
|
||||
// if err != nil {
|
||||
// // 将错误信息作为结果返回
|
||||
// errorResult := map[string]interface{}{
|
||||
// "error": err.Error(),
|
||||
// }
|
||||
// resultBytes, _ := json.Marshal(errorResult)
|
||||
// results[i].Result = resultBytes
|
||||
// } else {
|
||||
// // 将成功结果序列化
|
||||
// resultBytes, err := json.Marshal(result)
|
||||
// if err != nil {
|
||||
// errorResult := map[string]interface{}{
|
||||
// "error": fmt.Sprintf("failed to serialize result: %v", err),
|
||||
// }
|
||||
// resultBytes, _ = json.Marshal(errorResult)
|
||||
// }
|
||||
// results[i].Result = resultBytes
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// return results, nil
|
||||
//}
|
||||
|
|
|
@ -4,10 +4,11 @@ import (
|
|||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/entitys"
|
||||
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gofiber/websocket/v2"
|
||||
)
|
||||
|
||||
// ZltxOrderDetailTool 直连天下订单详情工具
|
||||
|
@ -70,37 +71,37 @@ type ZltxOrderDetailData struct {
|
|||
}
|
||||
|
||||
// Execute 执行直连天下订单详情查询
|
||||
func (w *ZltxOrderDetailTool) Execute(ctx context.Context, args json.RawMessage) (interface{}, error) {
|
||||
func (w *ZltxOrderDetailTool) Execute(c *websocket.Conn, args json.RawMessage) error {
|
||||
var req ZltxOrderDetailRequest
|
||||
if err := json.Unmarshal(args, &req); err != nil {
|
||||
return nil, fmt.Errorf("invalid zltxOrderDetail request: %w", err)
|
||||
return fmt.Errorf("invalid zltxOrderDetail request: %w", err)
|
||||
}
|
||||
|
||||
if req.Number == "" {
|
||||
return nil, fmt.Errorf("number is required")
|
||||
return fmt.Errorf("number is required")
|
||||
}
|
||||
|
||||
// 这里可以集成真实的直连天下订单详情API
|
||||
return w.getZltxOrderDetail(ctx, req.Number), nil
|
||||
return w.getZltxOrderDetail(c, req.Number)
|
||||
}
|
||||
|
||||
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据
|
||||
func (w *ZltxOrderDetailTool) getZltxOrderDetail(ctx context.Context, number string) *ZltxOrderDetailResponse {
|
||||
func (w *ZltxOrderDetailTool) getZltxOrderDetail(c *websocket.Conn, number string) (err error) {
|
||||
url := fmt.Sprintf("%s/admin/direct/ai/%s", w.config.BaseURL, number)
|
||||
authorization := fmt.Sprintf("Bearer %s", w.config.APIKey)
|
||||
|
||||
// 发送http请求
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return &ZltxOrderDetailResponse{}
|
||||
return
|
||||
}
|
||||
req.Header.Set("Authorization", authorization)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return &ZltxOrderDetailResponse{}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return &ZltxOrderDetailResponse{}
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue