结构修改
This commit is contained in:
		
							parent
							
								
									8b57afd572
								
							
						
					
					
						commit
						3058a1e6b9
					
				| 
						 | 
					@ -4,8 +4,8 @@ server:
 | 
				
			||||||
  host: "0.0.0.0"
 | 
					  host: "0.0.0.0"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ollama:
 | 
					ollama:
 | 
				
			||||||
  base_url: "http://localhost:11434"
 | 
					  base_url: "http://127.0.0.1:11434"
 | 
				
			||||||
  model: "qwen3:8b"
 | 
					  model: "qwen3-coder:480b-cloud"
 | 
				
			||||||
  timeout: "120s"
 | 
					  timeout: "120s"
 | 
				
			||||||
  level: "info"
 | 
					  level: "info"
 | 
				
			||||||
  format: "json"
 | 
					  format: "json"
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,38 @@
 | 
				
			||||||
 | 
					package llm_service
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"ai_scheduler/internal/data/model"
 | 
				
			||||||
 | 
						"ai_scheduler/internal/entitys"
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type LlmService interface {
 | 
				
			||||||
 | 
						IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (string, error)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// buildSystemPrompt 构建系统提示词
 | 
				
			||||||
 | 
					func buildSystemPrompt(prompt string) string {
 | 
				
			||||||
 | 
						if len(prompt) == 0 {
 | 
				
			||||||
 | 
							prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON:{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式,禁用markdown格式返回\n3.只返回json字符串,不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return prompt
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
 | 
				
			||||||
 | 
						for _, item := range his {
 | 
				
			||||||
 | 
							if len(chatHis.SessionId) == 0 {
 | 
				
			||||||
 | 
								chatHis.SessionId = item.SessionID
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							chatHis.Messages = append(chatHis.Messages, entitys.HisMessage{
 | 
				
			||||||
 | 
								Role:      item.Role,
 | 
				
			||||||
 | 
								Content:   item.Content,
 | 
				
			||||||
 | 
								Timestamp: item.CreateAt.Format("2006-01-02 15:04:05"),
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						chatHis.Context = entitys.HisContext{
 | 
				
			||||||
 | 
							UserLanguage: "zh-CN",
 | 
				
			||||||
 | 
							SystemMode:   "technical_support",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,86 @@
 | 
				
			||||||
 | 
					package llm_service
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"ai_scheduler/internal/data/model"
 | 
				
			||||||
 | 
						"ai_scheduler/internal/entitys"
 | 
				
			||||||
 | 
						"ai_scheduler/internal/pkg"
 | 
				
			||||||
 | 
						"ai_scheduler/internal/pkg/utils_langchain"
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/tmc/langchaingo/llms"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type LangChainService struct {
 | 
				
			||||||
 | 
						client *utils_langchain.UtilLangChain
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewLangChainGenerate(
 | 
				
			||||||
 | 
						client *utils_langchain.UtilLangChain,
 | 
				
			||||||
 | 
					) *LangChainService {
 | 
				
			||||||
 | 
						return &LangChainService{
 | 
				
			||||||
 | 
							client: client,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) {
 | 
				
			||||||
 | 
						prompt := r.getPrompt(sysInfo, history, userInput, tasks)
 | 
				
			||||||
 | 
						AgentClient := r.client.Get()
 | 
				
			||||||
 | 
						defer r.client.Put(AgentClient)
 | 
				
			||||||
 | 
						match, err := AgentClient.Llm.GenerateContent(
 | 
				
			||||||
 | 
							ctx, // 使用可取消的上下文
 | 
				
			||||||
 | 
							prompt,
 | 
				
			||||||
 | 
							llms.WithJSONMode(),
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						msg = match.Choices[0].Content
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
 | 
				
			||||||
 | 
						var (
 | 
				
			||||||
 | 
							prompt = make([]llms.MessageContent, 0)
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						prompt = append(prompt, llms.MessageContent{
 | 
				
			||||||
 | 
							Role: llms.ChatMessageTypeSystem,
 | 
				
			||||||
 | 
							Parts: []llms.ContentPart{
 | 
				
			||||||
 | 
								llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}, llms.MessageContent{
 | 
				
			||||||
 | 
							Role: llms.ChatMessageTypeTool,
 | 
				
			||||||
 | 
							Parts: []llms.ContentPart{
 | 
				
			||||||
 | 
								llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}, llms.MessageContent{
 | 
				
			||||||
 | 
							Role: llms.ChatMessageTypeTool,
 | 
				
			||||||
 | 
							Parts: []llms.ContentPart{
 | 
				
			||||||
 | 
								llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}, llms.MessageContent{
 | 
				
			||||||
 | 
							Role: llms.ChatMessageTypeHuman,
 | 
				
			||||||
 | 
							Parts: []llms.ContentPart{
 | 
				
			||||||
 | 
								llms.TextPart(reqInput),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						return prompt
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool {
 | 
				
			||||||
 | 
						taskPrompt := make([]llms.Tool, 0)
 | 
				
			||||||
 | 
						for _, task := range tasks {
 | 
				
			||||||
 | 
							var taskConfig entitys.TaskConfig
 | 
				
			||||||
 | 
							err := json.Unmarshal([]byte(task.Config), &taskConfig)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							taskPrompt = append(taskPrompt, llms.Tool{
 | 
				
			||||||
 | 
								Type: "function",
 | 
				
			||||||
 | 
								Function: &llms.FunctionDefinition{
 | 
				
			||||||
 | 
									Name:        task.Index,
 | 
				
			||||||
 | 
									Description: task.Desc,
 | 
				
			||||||
 | 
									Parameters:  taskConfig.Param,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return taskPrompt
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,79 @@
 | 
				
			||||||
 | 
					package llm_service
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"ai_scheduler/internal/data/model"
 | 
				
			||||||
 | 
						"ai_scheduler/internal/entitys"
 | 
				
			||||||
 | 
						"ai_scheduler/internal/pkg"
 | 
				
			||||||
 | 
						"ai_scheduler/internal/pkg/utils_ollama"
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/ollama/ollama/api"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type OllamaService struct {
 | 
				
			||||||
 | 
						client *utils_ollama.Client
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewOllamaGenerate(
 | 
				
			||||||
 | 
						client *utils_ollama.Client,
 | 
				
			||||||
 | 
					) *OllamaService {
 | 
				
			||||||
 | 
						return &OllamaService{
 | 
				
			||||||
 | 
							client: client,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) {
 | 
				
			||||||
 | 
						prompt := r.getPrompt(requireData.Sys, requireData.Histories, requireData.UserInput, requireData.Tasks)
 | 
				
			||||||
 | 
						toolDefinitions := r.registerToolsOllama(requireData.Tasks)
 | 
				
			||||||
 | 
						match, err := r.client.ToolSelect(context.TODO(), prompt, toolDefinitions)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						msg = match.Message.Content
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *OllamaService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []api.Message {
 | 
				
			||||||
 | 
						var (
 | 
				
			||||||
 | 
							prompt = make([]api.Message, 0)
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						prompt = append(prompt, api.Message{
 | 
				
			||||||
 | 
							Role:    "system",
 | 
				
			||||||
 | 
							Content: buildSystemPrompt(sysInfo.SysPrompt),
 | 
				
			||||||
 | 
						}, api.Message{
 | 
				
			||||||
 | 
							Role:    "assistant",
 | 
				
			||||||
 | 
							Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(buildAssistant(history))),
 | 
				
			||||||
 | 
						}, api.Message{
 | 
				
			||||||
 | 
							Role:    "user",
 | 
				
			||||||
 | 
							Content: reqInput,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						return prompt
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
 | 
				
			||||||
 | 
						taskPrompt := make([]api.Tool, 0)
 | 
				
			||||||
 | 
						for _, task := range tasks {
 | 
				
			||||||
 | 
							var taskConfig entitys.TaskConfigDetail
 | 
				
			||||||
 | 
							err := json.Unmarshal([]byte(task.Config), &taskConfig)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							taskPrompt = append(taskPrompt, api.Tool{
 | 
				
			||||||
 | 
								Type: "function",
 | 
				
			||||||
 | 
								Function: api.ToolFunction{
 | 
				
			||||||
 | 
									Name:        task.Index,
 | 
				
			||||||
 | 
									Description: task.Desc,
 | 
				
			||||||
 | 
									Parameters: api.ToolFunctionParameters{
 | 
				
			||||||
 | 
										Type:       taskConfig.Param.Type,
 | 
				
			||||||
 | 
										Required:   taskConfig.Param.Required,
 | 
				
			||||||
 | 
										Properties: taskConfig.Param.Properties,
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return taskPrompt
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -1,5 +1,15 @@
 | 
				
			||||||
package biz
 | 
					package biz
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import "github.com/google/wire"
 | 
					import (
 | 
				
			||||||
 | 
						"ai_scheduler/internal/biz/llm_service"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var ProviderSetBiz = wire.NewSet(NewAiRouterBiz, NewSessionBiz, NewChatHistoryBiz)
 | 
						"github.com/google/wire"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var ProviderSetBiz = wire.NewSet(
 | 
				
			||||||
 | 
						NewAiRouterBiz,
 | 
				
			||||||
 | 
						NewSessionBiz,
 | 
				
			||||||
 | 
						NewChatHistoryBiz,
 | 
				
			||||||
 | 
						llm_service.NewLangChainGenerate,
 | 
				
			||||||
 | 
						llm_service.NewOllamaGenerate,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,6 +1,7 @@
 | 
				
			||||||
package biz
 | 
					package biz
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"ai_scheduler/internal/biz/llm_service"
 | 
				
			||||||
	"ai_scheduler/internal/config"
 | 
						"ai_scheduler/internal/config"
 | 
				
			||||||
	"ai_scheduler/internal/data/constants"
 | 
						"ai_scheduler/internal/data/constants"
 | 
				
			||||||
	errors "ai_scheduler/internal/data/error"
 | 
						errors "ai_scheduler/internal/data/error"
 | 
				
			||||||
| 
						 | 
					@ -9,8 +10,8 @@ import (
 | 
				
			||||||
	"ai_scheduler/internal/entitys"
 | 
						"ai_scheduler/internal/entitys"
 | 
				
			||||||
	"ai_scheduler/internal/pkg"
 | 
						"ai_scheduler/internal/pkg"
 | 
				
			||||||
	"ai_scheduler/internal/pkg/mapstructure"
 | 
						"ai_scheduler/internal/pkg/mapstructure"
 | 
				
			||||||
	"ai_scheduler/internal/pkg/utils_ollama"
 | 
					
 | 
				
			||||||
	tools "ai_scheduler/internal/tools"
 | 
						"ai_scheduler/internal/tools"
 | 
				
			||||||
	"ai_scheduler/tmpl/dataTemp"
 | 
						"ai_scheduler/tmpl/dataTemp"
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
| 
						 | 
					@ -21,96 +22,108 @@ import (
 | 
				
			||||||
	"gitea.cdlsxd.cn/self-tools/l_request"
 | 
						"gitea.cdlsxd.cn/self-tools/l_request"
 | 
				
			||||||
	"github.com/gofiber/fiber/v2/log"
 | 
						"github.com/gofiber/fiber/v2/log"
 | 
				
			||||||
	"github.com/gofiber/websocket/v2"
 | 
						"github.com/gofiber/websocket/v2"
 | 
				
			||||||
	"github.com/ollama/ollama/api"
 | 
					 | 
				
			||||||
	"github.com/tmc/langchaingo/llms"
 | 
					 | 
				
			||||||
	"xorm.io/builder"
 | 
						"xorm.io/builder"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// AiRouterBiz 智能路由服务
 | 
					// AiRouterBiz 智能路由服务
 | 
				
			||||||
type AiRouterBiz struct {
 | 
					type AiRouterBiz struct {
 | 
				
			||||||
	//aiClient    entitys.AIClient
 | 
					 | 
				
			||||||
	toolManager *tools.Manager
 | 
						toolManager *tools.Manager
 | 
				
			||||||
	sessionImpl *impl.SessionImpl
 | 
						sessionImpl *impl.SessionImpl
 | 
				
			||||||
	sysImpl     *impl.SysImpl
 | 
						sysImpl     *impl.SysImpl
 | 
				
			||||||
	taskImpl    *impl.TaskImpl
 | 
						taskImpl    *impl.TaskImpl
 | 
				
			||||||
	hisImpl     *impl.ChatImpl
 | 
						hisImpl     *impl.ChatImpl
 | 
				
			||||||
	conf        *config.Config
 | 
						conf        *config.Config
 | 
				
			||||||
	utilAgent   *utils_ollama.UtilOllama
 | 
					 | 
				
			||||||
	ollama      *utils_ollama.Client
 | 
					 | 
				
			||||||
	channelPool *pkg.SafeChannelPool
 | 
					 | 
				
			||||||
	rds         *pkg.Rdb
 | 
						rds         *pkg.Rdb
 | 
				
			||||||
 | 
						langChain   *llm_service.LangChainService
 | 
				
			||||||
 | 
						Ollama      *llm_service.OllamaService
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewRouterService 创建路由服务
 | 
					// NewRouterService 创建路由服务
 | 
				
			||||||
func NewAiRouterBiz(
 | 
					func NewAiRouterBiz(
 | 
				
			||||||
	//aiClient entitys.AIClient,
 | 
					 | 
				
			||||||
	toolManager *tools.Manager,
 | 
						toolManager *tools.Manager,
 | 
				
			||||||
	sessionImpl *impl.SessionImpl,
 | 
						sessionImpl *impl.SessionImpl,
 | 
				
			||||||
	sysImpl *impl.SysImpl,
 | 
						sysImpl *impl.SysImpl,
 | 
				
			||||||
	taskImpl *impl.TaskImpl,
 | 
						taskImpl *impl.TaskImpl,
 | 
				
			||||||
	hisImpl *impl.ChatImpl,
 | 
						hisImpl *impl.ChatImpl,
 | 
				
			||||||
	conf *config.Config,
 | 
						conf *config.Config,
 | 
				
			||||||
	utilAgent *utils_ollama.UtilOllama,
 | 
						langChain *llm_service.LangChainService,
 | 
				
			||||||
	channelPool *pkg.SafeChannelPool,
 | 
						Ollama *llm_service.OllamaService,
 | 
				
			||||||
	ollama *utils_ollama.Client,
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
) *AiRouterBiz {
 | 
					) *AiRouterBiz {
 | 
				
			||||||
	return &AiRouterBiz{
 | 
						return &AiRouterBiz{
 | 
				
			||||||
		//aiClient:    aiClient,
 | 
					 | 
				
			||||||
		toolManager: toolManager,
 | 
							toolManager: toolManager,
 | 
				
			||||||
		sessionImpl: sessionImpl,
 | 
							sessionImpl: sessionImpl,
 | 
				
			||||||
		conf:        conf,
 | 
							conf:        conf,
 | 
				
			||||||
		sysImpl:     sysImpl,
 | 
							sysImpl:     sysImpl,
 | 
				
			||||||
		hisImpl:     hisImpl,
 | 
							hisImpl:     hisImpl,
 | 
				
			||||||
		taskImpl:    taskImpl,
 | 
							taskImpl:    taskImpl,
 | 
				
			||||||
		utilAgent:   utilAgent,
 | 
							langChain:   langChain,
 | 
				
			||||||
		channelPool: channelPool,
 | 
							Ollama:      Ollama,
 | 
				
			||||||
		ollama:      ollama,
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Route 执行智能路由
 | 
					 | 
				
			||||||
func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return nil, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
 | 
					func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
 | 
				
			||||||
 | 
						//必要数据验证和获取
 | 
				
			||||||
	session := c.Query("x-session", "")
 | 
						var requireData entitys.RequireData
 | 
				
			||||||
	if len(session) == 0 {
 | 
						err = r.dataAuth(c, &requireData)
 | 
				
			||||||
		return errors.SessionNotFound
 | 
						if err != nil {
 | 
				
			||||||
	}
 | 
							return
 | 
				
			||||||
	auth := c.Query("x-authorization", "")
 | 
					 | 
				
			||||||
	if len(auth) == 0 {
 | 
					 | 
				
			||||||
		return errors.AuthNotFound
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	key := c.Query("x-app-key", "")
 | 
					 | 
				
			||||||
	if len(key) == 0 {
 | 
					 | 
				
			||||||
		return errors.KeyNotFound
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var chat = make([]string, 0)
 | 
						//初始化通道/上下文
 | 
				
			||||||
 | 
						requireData.Ch = make(chan entitys.Response)
 | 
				
			||||||
	ctx, cancel := context.WithCancel(context.Background())
 | 
						ctx, cancel := context.WithCancel(context.Background())
 | 
				
			||||||
	defer cancel()
 | 
						// 启动独立的消息处理协程
 | 
				
			||||||
	//ch := r.channelPool.Get()
 | 
						done := r.startMessageHandler(ctx, c, &requireData)
 | 
				
			||||||
	ch := make(chan entitys.Response)
 | 
						defer func() {
 | 
				
			||||||
 | 
							close(requireData.Ch) //关闭主通道
 | 
				
			||||||
 | 
							<-done                // 等待消息处理完成
 | 
				
			||||||
 | 
							cancel()
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						//获取系统信息
 | 
				
			||||||
 | 
						err = r.getRequireData(req.Text, &requireData)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Errorf("SQL error: %v", err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						//意图识别
 | 
				
			||||||
 | 
						err = r.recognize(ctx, &requireData)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Errorf("LLM error: %v", err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						//向下传递
 | 
				
			||||||
 | 
						if err = r.handleMatch(ctx, &requireData); err != nil {
 | 
				
			||||||
 | 
							log.Errorf("Handle error: %v", err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// startMessageHandler 启动独立的消息处理协程
 | 
				
			||||||
 | 
					func (r *AiRouterBiz) startMessageHandler(
 | 
				
			||||||
 | 
						ctx context.Context,
 | 
				
			||||||
 | 
						c *websocket.Conn,
 | 
				
			||||||
 | 
						requireData *entitys.RequireData,
 | 
				
			||||||
 | 
					) <-chan struct{} {
 | 
				
			||||||
	done := make(chan struct{})
 | 
						done := make(chan struct{})
 | 
				
			||||||
 | 
						var chat []string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	go func() {
 | 
						go func() {
 | 
				
			||||||
		defer func() {
 | 
							defer func() {
 | 
				
			||||||
			close(done)
 | 
								close(done)
 | 
				
			||||||
 | 
								// 保存历史记录
 | 
				
			||||||
			var his = []*model.AiChatHi{
 | 
								var his = []*model.AiChatHi{
 | 
				
			||||||
				{
 | 
									{
 | 
				
			||||||
					SessionID: session,
 | 
										SessionID: requireData.Session,
 | 
				
			||||||
					Role:      "user",
 | 
										Role:      "user",
 | 
				
			||||||
					Content:   req.Text,
 | 
										Content:   "", // 用户输入在外部处理
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if len(chat) > 0 {
 | 
								if len(chat) > 0 {
 | 
				
			||||||
				his = append(his, &model.AiChatHi{
 | 
									his = append(his, &model.AiChatHi{
 | 
				
			||||||
					SessionID: session,
 | 
										SessionID: requireData.Session,
 | 
				
			||||||
					Role:      "assistant",
 | 
										Role:      "assistant",
 | 
				
			||||||
					Content:   strings.Join(chat, ""),
 | 
										Content:   strings.Join(chat, ""),
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
| 
						 | 
					@ -119,92 +132,19 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe
 | 
				
			||||||
				r.hisImpl.Add(hi)
 | 
									r.hisImpl.Add(hi)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}()
 | 
							}()
 | 
				
			||||||
		for {
 | 
					 | 
				
			||||||
			select {
 | 
					 | 
				
			||||||
			case v, ok := <-ch:
 | 
					 | 
				
			||||||
				if !ok {
 | 
					 | 
				
			||||||
					return
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				// 带超时的发送,避免阻塞
 | 
					 | 
				
			||||||
				if err = sendWithTimeout(c, v, 2*time.Second); err != nil {
 | 
					 | 
				
			||||||
					log.Errorf("Send error: %v", err)
 | 
					 | 
				
			||||||
					cancel() // 通知主流程退出
 | 
					 | 
				
			||||||
					return
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson {
 | 
					 | 
				
			||||||
					chat = append(chat, v.Content)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
			case <-ctx.Done():
 | 
							for v := range requireData.Ch { // 自动检测通道关闭
 | 
				
			||||||
 | 
								if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
 | 
				
			||||||
 | 
									log.Errorf("Send error: %v", err)
 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
								if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson {
 | 
				
			||||||
 | 
									chat = append(chat, v.Content)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	defer func() {
 | 
						return done
 | 
				
			||||||
		close(ch)
 | 
					 | 
				
			||||||
	}()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	sysInfo, err := r.getSysInfo(key)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return errors.SysErr("获取系统信息失败:%v", err.Error())
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	history, err := r.getSessionChatHis(session)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return errors.SysErr("获取历史记录失败:%v", err.Error())
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	task, err := r.getTasks(sysInfo.SysID)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return errors.SysErr("获取任务列表失败:%v", err.Error())
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	AgentClient := r.utilAgent.Get()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	ch <- entitys.Response{
 | 
					 | 
				
			||||||
		Index:   "",
 | 
					 | 
				
			||||||
		Content: "准备意图识别",
 | 
					 | 
				
			||||||
		Type:    entitys.ResponseLog,
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	match, err := AgentClient.Llm.GenerateContent(
 | 
					 | 
				
			||||||
		ctx, // 使用可取消的上下文
 | 
					 | 
				
			||||||
		prompt,
 | 
					 | 
				
			||||||
		llms.WithJSONMode(),
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	resMsg := match.Choices[0].Content
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	r.utilAgent.Put(AgentClient)
 | 
					 | 
				
			||||||
	ch <- entitys.Response{
 | 
					 | 
				
			||||||
		Index:   "",
 | 
					 | 
				
			||||||
		Content: resMsg,
 | 
					 | 
				
			||||||
		Type:    entitys.ResponseLog,
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	ch <- entitys.Response{
 | 
					 | 
				
			||||||
		Index:   "",
 | 
					 | 
				
			||||||
		Content: "意图识别结束",
 | 
					 | 
				
			||||||
		Type:    entitys.ResponseLog,
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		log.Errorf("LLM error: %v", err)
 | 
					 | 
				
			||||||
		return errors.SystemError
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	var matchJson entitys.Match
 | 
					 | 
				
			||||||
	if err := json.Unmarshal([]byte(resMsg), &matchJson); err != nil {
 | 
					 | 
				
			||||||
		log.Info(resMsg)
 | 
					 | 
				
			||||||
		return errors.SysErr("数据结构错误:%v", err.Error())
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	matchJson.History = pkg.JsonByteIgonErr(history)
 | 
					 | 
				
			||||||
	matchJson.UserInput = req.Text
 | 
					 | 
				
			||||||
	if err := r.handleMatch(ctx, c, ch, &matchJson, task, sysInfo); err != nil {
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 辅助函数:带超时的 WebSocket 发送
 | 
					// 辅助函数:带超时的 WebSocket 发送
 | 
				
			||||||
| 
						 | 
					@ -218,8 +158,11 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura
 | 
				
			||||||
			if r := recover(); r != nil {
 | 
								if r := recover(); r != nil {
 | 
				
			||||||
				done <- fmt.Errorf("panic in MsgSend: %v", r)
 | 
									done <- fmt.Errorf("panic in MsgSend: %v", r)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
								close(done)
 | 
				
			||||||
		}()
 | 
							}()
 | 
				
			||||||
		done <- entitys.MsgSend(c, data)
 | 
							// 如果 MsgSend 阻塞,这里会卡住
 | 
				
			||||||
 | 
							err := entitys.MsgSend(c, data)
 | 
				
			||||||
 | 
							done <- err
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	select {
 | 
						select {
 | 
				
			||||||
| 
						 | 
					@ -229,58 +172,135 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura
 | 
				
			||||||
		return sendCtx.Err()
 | 
							return sendCtx.Err()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.Response, matchJson *entitys.Match) (err error) {
 | 
					
 | 
				
			||||||
	ch <- entitys.Response{
 | 
					func (r *AiRouterBiz) recognize(ctx context.Context, requireData *entitys.RequireData) (err error) {
 | 
				
			||||||
 | 
						requireData.Ch <- entitys.Response{
 | 
				
			||||||
		Index:   "",
 | 
							Index:   "",
 | 
				
			||||||
		Content: matchJson.Reasoning,
 | 
							Content: "准备意图识别",
 | 
				
			||||||
 | 
							Type:    entitys.ResponseLog,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						//意图识别
 | 
				
			||||||
 | 
						recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						requireData.Ch <- entitys.Response{
 | 
				
			||||||
 | 
							Index:   "",
 | 
				
			||||||
 | 
							Content: recognizeMsg,
 | 
				
			||||||
 | 
							Type:    entitys.ResponseLog,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						requireData.Ch <- entitys.Response{
 | 
				
			||||||
 | 
							Index:   "",
 | 
				
			||||||
 | 
							Content: "意图识别结束",
 | 
				
			||||||
 | 
							Type:    entitys.ResponseLog,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err = json.Unmarshal([]byte(recognizeMsg), requireData.Match); err != nil {
 | 
				
			||||||
 | 
							err = errors.SysErr("数据结构错误:%v", err.Error())
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *AiRouterBiz) getRequireData(userInput string, requireData *entitys.RequireData) (err error) {
 | 
				
			||||||
 | 
						requireData.Sys, err = r.getSysInfo(requireData.Key)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							err = errors.SysErr("获取系统信息失败:%v", err.Error())
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						requireData.Histories, err = r.getSessionChatHis(requireData.Session)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							err = errors.SysErr("获取历史记录失败:%v", err.Error())
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						requireData.Tasks, err = r.getTasks(requireData.Sys.SysID)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							err = errors.SysErr("获取任务列表失败:%v", err.Error())
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						requireData.UserInput = userInput
 | 
				
			||||||
 | 
						if len(requireData.UserInput) == 0 {
 | 
				
			||||||
 | 
							err = errors.SysErr("获取用户输入失败")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if len(requireData.UserInput) == 0 {
 | 
				
			||||||
 | 
							err = errors.SysErr("获取用户输入失败")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *AiRouterBiz) dataAuth(c *websocket.Conn, requireData *entitys.RequireData) (err error) {
 | 
				
			||||||
 | 
						requireData.Session = c.Query("x-session", "")
 | 
				
			||||||
 | 
						if len(requireData.Session) == 0 {
 | 
				
			||||||
 | 
							err = errors.SessionNotFound
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						requireData.Auth = c.Query("x-authorization", "")
 | 
				
			||||||
 | 
						if len(requireData.Auth) == 0 {
 | 
				
			||||||
 | 
							err = errors.AuthNotFound
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						requireData.Key = c.Query("x-app-key", "")
 | 
				
			||||||
 | 
						if len(requireData.Key) == 0 {
 | 
				
			||||||
 | 
							err = errors.KeyNotFound
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *AiRouterBiz) handleOtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
 | 
				
			||||||
 | 
						requireData.Ch <- entitys.Response{
 | 
				
			||||||
 | 
							Index:   "",
 | 
				
			||||||
 | 
							Content: requireData.Match.Reasoning,
 | 
				
			||||||
		Type:    entitys.ResponseText,
 | 
							Type:    entitys.ResponseText,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (r *AiRouterBiz) handleMatch(ctx context.Context, c *websocket.Conn, ch chan entitys.Response, matchJson *entitys.Match, tasks []model.AiTask, sysInfo model.AiSy) (err error) {
 | 
					func (r *AiRouterBiz) handleMatch(ctx context.Context, requireData *entitys.RequireData) (err error) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !matchJson.IsMatch {
 | 
						if !requireData.Match.IsMatch {
 | 
				
			||||||
		_ = entitys.MsgSend(c, entitys.Response{
 | 
							requireData.Ch <- entitys.Response{
 | 
				
			||||||
			Index:   "",
 | 
								Index:   "",
 | 
				
			||||||
			Content: matchJson.Reasoning,
 | 
								Content: requireData.Match.Reasoning,
 | 
				
			||||||
			Type:    entitys.ResponseText,
 | 
								Type:    entitys.ResponseText,
 | 
				
			||||||
		})
 | 
							}
 | 
				
			||||||
 | 
					 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var pointTask *model.AiTask
 | 
						var pointTask *model.AiTask
 | 
				
			||||||
	for _, task := range tasks {
 | 
						for _, task := range requireData.Tasks {
 | 
				
			||||||
		if task.Index == matchJson.Index {
 | 
							if task.Index == requireData.Match.Index {
 | 
				
			||||||
			pointTask = &task
 | 
								pointTask = &task
 | 
				
			||||||
			break
 | 
								break
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if pointTask == nil || pointTask.Index == "other" {
 | 
						if pointTask == nil || pointTask.Index == "other" {
 | 
				
			||||||
		return r.handleOtherTask(c, ch, matchJson)
 | 
							return r.handleOtherTask(ctx, requireData)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	switch pointTask.Type {
 | 
						switch pointTask.Type {
 | 
				
			||||||
	case constants.TaskTypeApi:
 | 
						case constants.TaskTypeApi:
 | 
				
			||||||
		return r.handleApiTask(ch, c, matchJson, pointTask)
 | 
							return r.handleApiTask(ctx, requireData, pointTask)
 | 
				
			||||||
	case constants.TaskTypeFunc:
 | 
						case constants.TaskTypeFunc:
 | 
				
			||||||
		return r.handleTask(ch, c, matchJson, pointTask)
 | 
							return r.handleTask(ctx, requireData, pointTask)
 | 
				
			||||||
	case constants.TaskTypeKnowle:
 | 
						case constants.TaskTypeKnowle:
 | 
				
			||||||
		return r.handleKnowle(ch, c, matchJson, pointTask, sysInfo)
 | 
							return r.handleKnowle(ctx, requireData, pointTask)
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		return r.handleOtherTask(c, ch, matchJson)
 | 
							return r.handleOtherTask(ctx, requireData)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (r *AiRouterBiz) handleTask(channel chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
 | 
					func (r *AiRouterBiz) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
 | 
				
			||||||
	var configData entitys.ConfigDataTool
 | 
						var configData entitys.ConfigDataTool
 | 
				
			||||||
	err = json.Unmarshal([]byte(task.Config), &configData)
 | 
						err = json.Unmarshal([]byte(task.Config), &configData)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = r.toolManager.ExecuteTool(channel, c, configData.Tool, []byte(matchJson.Parameters), matchJson)
 | 
						err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -289,7 +309,7 @@ func (r *AiRouterBiz) handleTask(channel chan entitys.Response, c *websocket.Con
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 知识库
 | 
					// 知识库
 | 
				
			||||||
func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask, sysInfo model.AiSy) (err error) {
 | 
					func (r *AiRouterBiz) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
		configData         entitys.ConfigDataTool
 | 
							configData         entitys.ConfigDataTool
 | 
				
			||||||
| 
						 | 
					@ -303,11 +323,11 @@ func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.C
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 通过session 找到知识库session
 | 
						// 通过session 找到知识库session
 | 
				
			||||||
	session := c.Query("x-session", "")
 | 
						var has bool
 | 
				
			||||||
	if len(session) == 0 {
 | 
						if len(requireData.Session) == 0 {
 | 
				
			||||||
		return errors.SessionNotFound
 | 
							return errors.SessionNotFound
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	sessionInfo, has, err := r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(session))
 | 
						requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	} else if !has {
 | 
						} else if !has {
 | 
				
			||||||
| 
						 | 
					@ -330,15 +350,15 @@ func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.C
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 知识库的session为空,请求知识库获取, 并绑定
 | 
						// 知识库的session为空,请求知识库获取, 并绑定
 | 
				
			||||||
	if sessionInfo.KnowlegeSessionID == "" {
 | 
						if requireData.SessionInfo.KnowlegeSessionID == "" {
 | 
				
			||||||
		// 请求知识库
 | 
							// 请求知识库
 | 
				
			||||||
		if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, sysInfo.KnowlegeBaseID, sysInfo.KnowlegeTenantKey); err != nil {
 | 
							if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil {
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// 绑定知识库session,下次可以使用
 | 
							// 绑定知识库session,下次可以使用
 | 
				
			||||||
		sessionInfo.KnowlegeSessionID = sessionIdKnowledge
 | 
							requireData.SessionInfo.KnowlegeSessionID = sessionIdKnowledge
 | 
				
			||||||
		if err = r.sessionImpl.Update(&sessionInfo, r.sessionImpl.WithSessionId(sessionInfo.SessionID)); err != nil {
 | 
							if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil {
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -346,25 +366,21 @@ func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.C
 | 
				
			||||||
	// 用户输入解析
 | 
						// 用户输入解析
 | 
				
			||||||
	var ok bool
 | 
						var ok bool
 | 
				
			||||||
	input := make(map[string]string)
 | 
						input := make(map[string]string)
 | 
				
			||||||
	if err = json.Unmarshal([]byte(matchJson.Parameters), &input); err != nil {
 | 
						if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if query, ok = input["query"]; !ok {
 | 
						if query, ok = input["query"]; !ok {
 | 
				
			||||||
		return fmt.Errorf("query不能为空")
 | 
							return fmt.Errorf("query不能为空")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	knowledgeConfig := tools.KnowledgeBaseRequest{
 | 
						requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{
 | 
				
			||||||
		Session: sessionInfo.KnowlegeSessionID,
 | 
							Session: requireData.SessionInfo.KnowlegeSessionID,
 | 
				
			||||||
		ApiKey:  sysInfo.KnowlegeTenantKey,
 | 
							ApiKey:  requireData.Sys.KnowlegeTenantKey,
 | 
				
			||||||
		Query:   query,
 | 
							Query:   query,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	b, err := json.Marshal(knowledgeConfig)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 执行工具
 | 
						// 执行工具
 | 
				
			||||||
	err = r.toolManager.ExecuteTool(channel, c, configData.Tool, b, nil)
 | 
						err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -372,17 +388,16 @@ func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.C
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (r *AiRouterBiz) handleApiTask(channels chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
 | 
					func (r *AiRouterBiz) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
		request      l_request.Request
 | 
							request      l_request.Request
 | 
				
			||||||
		auth         = c.Query("x-authorization", "")
 | 
					 | 
				
			||||||
		requestParam map[string]interface{}
 | 
							requestParam map[string]interface{}
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
	err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam)
 | 
						err = json.Unmarshal([]byte(requireData.Match.Parameters), &requestParam)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	request.Url = strings.ReplaceAll(task.Config, "${authorization}", auth)
 | 
						request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth)
 | 
				
			||||||
	for k, v := range requestParam {
 | 
						for k, v := range requestParam {
 | 
				
			||||||
		task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
 | 
							task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -403,7 +418,11 @@ func (r *AiRouterBiz) handleApiTask(channels chan entitys.Response, c *websocket
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.WriteMessage(1, res.Content)
 | 
						requireData.Ch <- entitys.Response{
 | 
				
			||||||
 | 
							Index:   "",
 | 
				
			||||||
 | 
							Content: pkg.JsonStringIgonErr(res.Text),
 | 
				
			||||||
 | 
							Type:    entitys.ResponseJson,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -436,119 +455,3 @@ func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool {
 | 
					 | 
				
			||||||
	taskPrompt := make([]llms.Tool, 0)
 | 
					 | 
				
			||||||
	for _, task := range tasks {
 | 
					 | 
				
			||||||
		var taskConfig entitys.TaskConfig
 | 
					 | 
				
			||||||
		err := json.Unmarshal([]byte(task.Config), &taskConfig)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			continue
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		taskPrompt = append(taskPrompt, llms.Tool{
 | 
					 | 
				
			||||||
			Type: "function",
 | 
					 | 
				
			||||||
			Function: &llms.FunctionDefinition{
 | 
					 | 
				
			||||||
				Name:        task.Index,
 | 
					 | 
				
			||||||
				Description: task.Desc,
 | 
					 | 
				
			||||||
				Parameters:  taskConfig.Param,
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return taskPrompt
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message {
 | 
					 | 
				
			||||||
	var (
 | 
					 | 
				
			||||||
		prompt = make([]entitys.Message, 0)
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	prompt = append(prompt, entitys.Message{
 | 
					 | 
				
			||||||
		Role:    "system",
 | 
					 | 
				
			||||||
		Content: r.buildSystemPrompt(sysInfo.SysPrompt),
 | 
					 | 
				
			||||||
	}, entitys.Message{
 | 
					 | 
				
			||||||
		Role:    "assistant",
 | 
					 | 
				
			||||||
		Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
 | 
					 | 
				
			||||||
	}, entitys.Message{
 | 
					 | 
				
			||||||
		Role:    "user",
 | 
					 | 
				
			||||||
		Content: reqInput,
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
	return prompt
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (r *AiRouterBiz) getPromptOllama(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []api.Message {
 | 
					 | 
				
			||||||
	var (
 | 
					 | 
				
			||||||
		prompt = make([]api.Message, 0)
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	prompt = append(prompt, api.Message{
 | 
					 | 
				
			||||||
		Role:    "system",
 | 
					 | 
				
			||||||
		Content: r.buildSystemPrompt(sysInfo.SysPrompt),
 | 
					 | 
				
			||||||
	}, api.Message{
 | 
					 | 
				
			||||||
		Role:    "assistant",
 | 
					 | 
				
			||||||
		Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
 | 
					 | 
				
			||||||
	}, api.Message{
 | 
					 | 
				
			||||||
		Role:    "user",
 | 
					 | 
				
			||||||
		Content: reqInput,
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
	return prompt
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
 | 
					 | 
				
			||||||
	var (
 | 
					 | 
				
			||||||
		prompt = make([]llms.MessageContent, 0)
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	prompt = append(prompt, llms.MessageContent{
 | 
					 | 
				
			||||||
		Role: llms.ChatMessageTypeSystem,
 | 
					 | 
				
			||||||
		Parts: []llms.ContentPart{
 | 
					 | 
				
			||||||
			llms.TextPart(r.buildSystemPrompt(sysInfo.SysPrompt)),
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}, llms.MessageContent{
 | 
					 | 
				
			||||||
		Role: llms.ChatMessageTypeTool,
 | 
					 | 
				
			||||||
		Parts: []llms.ContentPart{
 | 
					 | 
				
			||||||
			llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))),
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}, llms.MessageContent{
 | 
					 | 
				
			||||||
		Role: llms.ChatMessageTypeTool,
 | 
					 | 
				
			||||||
		Parts: []llms.ContentPart{
 | 
					 | 
				
			||||||
			llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}, llms.MessageContent{
 | 
					 | 
				
			||||||
		Role: llms.ChatMessageTypeHuman,
 | 
					 | 
				
			||||||
		Parts: []llms.ContentPart{
 | 
					 | 
				
			||||||
			llms.TextPart(reqInput),
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
	return prompt
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// buildSystemPrompt 构建系统提示词
 | 
					 | 
				
			||||||
func (r *AiRouterBiz) buildSystemPrompt(prompt string) string {
 | 
					 | 
				
			||||||
	if len(prompt) == 0 {
 | 
					 | 
				
			||||||
		prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON:{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式,禁用markdown格式返回\n3.只返回json字符串,不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return prompt
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (r *AiRouterBiz) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
 | 
					 | 
				
			||||||
	for _, item := range his {
 | 
					 | 
				
			||||||
		if len(chatHis.SessionId) == 0 {
 | 
					 | 
				
			||||||
			chatHis.SessionId = item.SessionID
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		chatHis.Messages = append(chatHis.Messages, entitys.HisMessage{
 | 
					 | 
				
			||||||
			Role:      item.Role,
 | 
					 | 
				
			||||||
			Content:   item.Content,
 | 
					 | 
				
			||||||
			Timestamp: item.CreateAt.Format("2006-01-02 15:04:05"),
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	chatHis.Context = entitys.HisContext{
 | 
					 | 
				
			||||||
		UserLanguage: "zh-CN",
 | 
					 | 
				
			||||||
		SystemMode:   "technical_support",
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// handleKnowledgeQA 处理知识问答意图
 | 
					 | 
				
			||||||
func (r *AiRouterBiz) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return nil, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,6 +1,8 @@
 | 
				
			||||||
package entitys
 | 
					package entitys
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"ai_scheduler/internal/data/model"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -73,7 +75,7 @@ type Tool interface {
 | 
				
			||||||
	Name() string
 | 
						Name() string
 | 
				
			||||||
	Description() string
 | 
						Description() string
 | 
				
			||||||
	Definition() ToolDefinition
 | 
						Definition() ToolDefinition
 | 
				
			||||||
	Execute(channel chan Response, c *websocket.Conn, args json.RawMessage, matchJson *Match) error
 | 
						Execute(ctx context.Context, requireData *RequireData) error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ConfigDataHttp struct {
 | 
					type ConfigDataHttp struct {
 | 
				
			||||||
| 
						 | 
					@ -118,6 +120,7 @@ type Match struct {
 | 
				
			||||||
	Reasoning  string `json:"reasoning"`
 | 
						Reasoning  string `json:"reasoning"`
 | 
				
			||||||
	History    []byte `json:"history"`
 | 
						History    []byte `json:"history"`
 | 
				
			||||||
	UserInput  string `json:"user_input"`
 | 
						UserInput  string `json:"user_input"`
 | 
				
			||||||
 | 
						Auth       string `json:"auth"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
type ChatHis struct {
 | 
					type ChatHis struct {
 | 
				
			||||||
	SessionId string       `json:"session_id"`
 | 
						SessionId string       `json:"session_id"`
 | 
				
			||||||
| 
						 | 
					@ -135,8 +138,27 @@ type HisContext struct {
 | 
				
			||||||
	SystemMode   string `json:"system_mode"`
 | 
						SystemMode   string `json:"system_mode"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequireData struct {
 | 
				
			||||||
 | 
						Session       string
 | 
				
			||||||
 | 
						Key           string
 | 
				
			||||||
 | 
						Sys           model.AiSy
 | 
				
			||||||
 | 
						Histories     []model.AiChatHi
 | 
				
			||||||
 | 
						SessionInfo   model.AiSession
 | 
				
			||||||
 | 
						Tasks         []model.AiTask
 | 
				
			||||||
 | 
						Match         *Match
 | 
				
			||||||
 | 
						UserInput     string
 | 
				
			||||||
 | 
						Auth          string
 | 
				
			||||||
 | 
						Ch            chan Response
 | 
				
			||||||
 | 
						KnowledgeConf KnowledgeBaseRequest
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type KnowledgeBaseRequest struct {
 | 
				
			||||||
 | 
						Session string // 知识库会话id
 | 
				
			||||||
 | 
						ApiKey  string // 知识库apiKey
 | 
				
			||||||
 | 
						Query   string // 用户输入
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// RouterService 路由服务接口
 | 
					// RouterService 路由服务接口
 | 
				
			||||||
type RouterService interface {
 | 
					type RouterService interface {
 | 
				
			||||||
	Route(ctx context.Context, req *ChatRequest) (*ChatResponse, error)
 | 
					 | 
				
			||||||
	RouteWithSocket(c *websocket.Conn, req *ChatSockRequest) error
 | 
						RouteWithSocket(c *websocket.Conn, req *ChatSockRequest) error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,6 +1,7 @@
 | 
				
			||||||
package pkg
 | 
					package pkg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"ai_scheduler/internal/pkg/utils_langchain"
 | 
				
			||||||
	"ai_scheduler/internal/pkg/utils_ollama"
 | 
						"ai_scheduler/internal/pkg/utils_ollama"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/google/wire"
 | 
						"github.com/google/wire"
 | 
				
			||||||
| 
						 | 
					@ -9,7 +10,7 @@ import (
 | 
				
			||||||
var ProviderSetClient = wire.NewSet(
 | 
					var ProviderSetClient = wire.NewSet(
 | 
				
			||||||
	NewRdb,
 | 
						NewRdb,
 | 
				
			||||||
	NewGormDb,
 | 
						NewGormDb,
 | 
				
			||||||
	utils_ollama.NewUtilOllama,
 | 
						utils_langchain.NewUtilLangChain,
 | 
				
			||||||
	utils_ollama.NewClient,
 | 
						utils_ollama.NewClient,
 | 
				
			||||||
	NewSafeChannelPool,
 | 
						NewSafeChannelPool,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
package utils_ollama
 | 
					package utils_langchain
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"ai_scheduler/internal/config"
 | 
						"ai_scheduler/internal/config"
 | 
				
			||||||
| 
						 | 
					@ -13,7 +13,7 @@ import (
 | 
				
			||||||
	"github.com/tmc/langchaingo/llms/ollama"
 | 
						"github.com/tmc/langchaingo/llms/ollama"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type UtilOllama struct {
 | 
					type UtilLangChain struct {
 | 
				
			||||||
	LlmClientPool *sync.Pool
 | 
						LlmClientPool *sync.Pool
 | 
				
			||||||
	poolSize      int // 记录池大小,用于调试
 | 
						poolSize      int // 记录池大小,用于调试
 | 
				
			||||||
	model         string
 | 
						model         string
 | 
				
			||||||
| 
						 | 
					@ -26,7 +26,7 @@ type LlmObj struct {
 | 
				
			||||||
	Llm    *ollama.LLM
 | 
						Llm    *ollama.LLM
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
 | 
					func NewUtilLangChain(c *config.Config, logger log.AllLogger) *UtilLangChain {
 | 
				
			||||||
	poolSize := c.Sys.LlmPoolLen
 | 
						poolSize := c.Sys.LlmPoolLen
 | 
				
			||||||
	if poolSize <= 0 {
 | 
						if poolSize <= 0 {
 | 
				
			||||||
		poolSize = 10 // 默认值
 | 
							poolSize = 10 // 默认值
 | 
				
			||||||
| 
						 | 
					@ -60,7 +60,7 @@ func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
 | 
				
			||||||
		pool.Put(pool.New())
 | 
							pool.Put(pool.New())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &UtilOllama{
 | 
						return &UtilLangChain{
 | 
				
			||||||
		LlmClientPool: pool,
 | 
							LlmClientPool: pool,
 | 
				
			||||||
		poolSize:      poolSize,
 | 
							poolSize:      poolSize,
 | 
				
			||||||
		model:         c.Ollama.Model,
 | 
							model:         c.Ollama.Model,
 | 
				
			||||||
| 
						 | 
					@ -69,7 +69,7 @@ func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (o *UtilOllama) NewClient() *ollama.LLM {
 | 
					func (o *UtilLangChain) NewClient() *ollama.LLM {
 | 
				
			||||||
	llm, _ := ollama.New(
 | 
						llm, _ := ollama.New(
 | 
				
			||||||
		ollama.WithModel(o.c.Ollama.Model),
 | 
							ollama.WithModel(o.c.Ollama.Model),
 | 
				
			||||||
		ollama.WithHTTPClient(&http.Client{
 | 
							ollama.WithHTTPClient(&http.Client{
 | 
				
			||||||
| 
						 | 
					@ -91,13 +91,13 @@ func (o *UtilOllama) NewClient() *ollama.LLM {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Get 返回一个可用的 LLM 客户端
 | 
					// Get 返回一个可用的 LLM 客户端
 | 
				
			||||||
func (o *UtilOllama) Get() *LlmObj {
 | 
					func (o *UtilLangChain) Get() *LlmObj {
 | 
				
			||||||
	client := o.LlmClientPool.Get().(*LlmObj)
 | 
						client := o.LlmClientPool.Get().(*LlmObj)
 | 
				
			||||||
	return client
 | 
						return client
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Put 归还客户端(可选:检查是否仍可用)
 | 
					// Put 归还客户端(可选:检查是否仍可用)
 | 
				
			||||||
func (o *UtilOllama) Put(llm *LlmObj) {
 | 
					func (o *UtilLangChain) Put(llm *LlmObj) {
 | 
				
			||||||
	if llm == nil {
 | 
						if llm == nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -105,7 +105,7 @@ func (o *UtilOllama) Put(llm *LlmObj) {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Stats 返回池的统计信息(用于监控)
 | 
					// Stats 返回池的统计信息(用于监控)
 | 
				
			||||||
func (o *UtilOllama) Stats() (current, max int) {
 | 
					func (o *UtilLangChain) Stats() (current, max int) {
 | 
				
			||||||
	return o.poolSize, o.poolSize
 | 
						return o.poolSize, o.poolSize
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -5,13 +5,11 @@ import (
 | 
				
			||||||
	"ai_scheduler/internal/entitys"
 | 
						"ai_scheduler/internal/entitys"
 | 
				
			||||||
	"ai_scheduler/internal/pkg/l_request"
 | 
						"ai_scheduler/internal/pkg/l_request"
 | 
				
			||||||
	"bufio"
 | 
						"bufio"
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
					 | 
				
			||||||
	"github.com/gofiber/fiber/v2/log"
 | 
					 | 
				
			||||||
	"github.com/gofiber/websocket/v2"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 知识库工具
 | 
					// 知识库工具
 | 
				
			||||||
| 
						 | 
					@ -60,22 +58,10 @@ func (k *KnowledgeBaseTool) Definition() entitys.ToolDefinition {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Execute 执行知识库查询
 | 
					// Execute 执行知识库查询
 | 
				
			||||||
func (k *KnowledgeBaseTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error {
 | 
					func (k *KnowledgeBaseTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var params KnowledgeBaseRequest
 | 
						return k.chat(requireData)
 | 
				
			||||||
	if err := json.Unmarshal(args, ¶ms); err != nil {
 | 
					 | 
				
			||||||
		return fmt.Errorf("unmarshal args failed: %w", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	log.Info("开始执行知识库 KnowledgeBaseTool Execute, params: %v", params)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return k.chat(channel, c, params)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type KnowledgeBaseRequest struct {
 | 
					 | 
				
			||||||
	Session string // 知识库会话id
 | 
					 | 
				
			||||||
	ApiKey  string // 知识库apiKey
 | 
					 | 
				
			||||||
	Query   string // 用户输入
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Message 表示解析后的 SSE 消息
 | 
					// Message 表示解析后的 SSE 消息
 | 
				
			||||||
| 
						 | 
					@ -110,20 +96,20 @@ func (this *KnowledgeBaseTool) msgContentParse(input string, channel chan entity
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 请求知识库聊天
 | 
					// 请求知识库聊天
 | 
				
			||||||
func (this *KnowledgeBaseTool) chat(channel chan entitys.Response, c *websocket.Conn, param KnowledgeBaseRequest) (err error) {
 | 
					func (this *KnowledgeBaseTool) chat(requireData *entitys.RequireData) (err error) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req := l_request.Request{
 | 
						req := l_request.Request{
 | 
				
			||||||
		Method: "post",
 | 
							Method: "post",
 | 
				
			||||||
		Url:    this.config.BaseURL + "/api/v1/knowledge-chat/" + param.Session,
 | 
							Url:    this.config.BaseURL + "/api/v1/knowledge-chat/" + requireData.KnowledgeConf.Session,
 | 
				
			||||||
		Params: nil,
 | 
							Params: nil,
 | 
				
			||||||
		Headers: map[string]string{
 | 
							Headers: map[string]string{
 | 
				
			||||||
			"Content-Type": "application/json",
 | 
								"Content-Type": "application/json",
 | 
				
			||||||
			"X-API-Key":    param.ApiKey,
 | 
								"X-API-Key":    requireData.KnowledgeConf.ApiKey,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		Cookies: nil,
 | 
							Cookies: nil,
 | 
				
			||||||
		Data:    nil,
 | 
							Data:    nil,
 | 
				
			||||||
		Json: map[string]interface{}{
 | 
							Json: map[string]interface{}{
 | 
				
			||||||
			"query": param.Query,
 | 
								"query": requireData.KnowledgeConf.Query,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		Files:    nil,
 | 
							Files:    nil,
 | 
				
			||||||
		Raw:      "",
 | 
							Raw:      "",
 | 
				
			||||||
| 
						 | 
					@ -137,7 +123,7 @@ func (this *KnowledgeBaseTool) chat(channel chan entitys.Response, c *websocket.
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	defer rsp.Body.Close()
 | 
						defer rsp.Body.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = this.connectAndReadSSE(rsp, channel)
 | 
						err = this.connectAndReadSSE(rsp, requireData.Ch)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -5,11 +5,9 @@ import (
 | 
				
			||||||
	"ai_scheduler/internal/data/constants"
 | 
						"ai_scheduler/internal/data/constants"
 | 
				
			||||||
	"ai_scheduler/internal/entitys"
 | 
						"ai_scheduler/internal/entitys"
 | 
				
			||||||
	"ai_scheduler/internal/pkg/utils_ollama"
 | 
						"ai_scheduler/internal/pkg/utils_ollama"
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"encoding/json"
 | 
					 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
					 | 
				
			||||||
	"github.com/gofiber/websocket/v2"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Manager 工具管理器
 | 
					// Manager 工具管理器
 | 
				
			||||||
| 
						 | 
					@ -100,13 +98,13 @@ func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefi
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ExecuteTool 执行工具
 | 
					// ExecuteTool 执行工具
 | 
				
			||||||
func (m *Manager) ExecuteTool(channel chan entitys.Response, c *websocket.Conn, name string, args json.RawMessage, matchJson *entitys.Match) error {
 | 
					func (m *Manager) ExecuteTool(ctx context.Context, name string, requireData *entitys.RequireData) error {
 | 
				
			||||||
	tool, exists := m.GetTool(name)
 | 
						tool, exists := m.GetTool(name)
 | 
				
			||||||
	if !exists {
 | 
						if !exists {
 | 
				
			||||||
		return fmt.Errorf("tool not found: %s", name)
 | 
							return fmt.Errorf("tool not found: %s", name)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return tool.Execute(channel, c, args, matchJson)
 | 
						return tool.Execute(ctx, requireData)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ExecuteToolCalls 执行多个工具调用
 | 
					// ExecuteToolCalls 执行多个工具调用
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,13 +3,13 @@ package tools
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"ai_scheduler/internal/config"
 | 
						"ai_scheduler/internal/config"
 | 
				
			||||||
	"ai_scheduler/internal/entitys"
 | 
						"ai_scheduler/internal/entitys"
 | 
				
			||||||
 | 
						"ai_scheduler/internal/pkg"
 | 
				
			||||||
	"ai_scheduler/internal/pkg/utils_ollama"
 | 
						"ai_scheduler/internal/pkg/utils_ollama"
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"gitea.cdlsxd.cn/self-tools/l_request"
 | 
						"gitea.cdlsxd.cn/self-tools/l_request"
 | 
				
			||||||
	"github.com/gofiber/websocket/v2"
 | 
					 | 
				
			||||||
	"github.com/ollama/ollama/api"
 | 
						"github.com/ollama/ollama/api"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -81,34 +81,26 @@ type ZltxOrderDetailData struct {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Execute 执行直连天下订单详情查询
 | 
					// Execute 执行直连天下订单详情查询
 | 
				
			||||||
func (w *ZltxOrderDetailTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error {
 | 
					func (w *ZltxOrderDetailTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
 | 
				
			||||||
	var req ZltxOrderDetailRequest
 | 
						var req ZltxOrderDetailRequest
 | 
				
			||||||
	if err := json.Unmarshal(args, &req); err != nil {
 | 
						if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
 | 
				
			||||||
		return fmt.Errorf("invalid zltxOrderDetail request: %w", err)
 | 
							return fmt.Errorf("invalid zltxOrderDetail request: %w", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					 | 
				
			||||||
	if req.OrderNumber == "" {
 | 
						if req.OrderNumber == "" {
 | 
				
			||||||
		return fmt.Errorf("number is required")
 | 
							return fmt.Errorf("number is required")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 这里可以集成真实的直连天下订单详情API
 | 
						// 这里可以集成真实的直连天下订单详情API
 | 
				
			||||||
	return w.getZltxOrderDetail(channel, c, req.OrderNumber, matchJson)
 | 
						return w.getZltxOrderDetail(requireData, req.OrderNumber)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据
 | 
					// getMockZltxOrderDetail 获取模拟直连天下订单详情数据
 | 
				
			||||||
func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *websocket.Conn, number string, matchJson *entitys.Match) (err error) {
 | 
					func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireData, number string) (err error) {
 | 
				
			||||||
	//查询订单详情
 | 
						//查询订单详情
 | 
				
			||||||
	var auth string
 | 
					 | 
				
			||||||
	if c != nil {
 | 
					 | 
				
			||||||
		auth = c.Query("x-authorization", "")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if len(auth) == 0 {
 | 
					 | 
				
			||||||
		auth = w.config.APIKey
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	req := l_request.Request{
 | 
						req := l_request.Request{
 | 
				
			||||||
		Url: fmt.Sprintf("%szltx_api/admin/direct/ai/%s", w.config.BaseURL, number),
 | 
							Url: fmt.Sprintf("%szltx_api/admin/direct/ai/%s", w.config.BaseURL, number),
 | 
				
			||||||
		Headers: map[string]string{
 | 
							Headers: map[string]string{
 | 
				
			||||||
			"Authorization": fmt.Sprintf("Bearer %s", auth),
 | 
								"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		Method: "GET",
 | 
							Method: "GET",
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -129,13 +121,13 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we
 | 
				
			||||||
	if err = json.Unmarshal(res.Content, &resData); err != nil {
 | 
						if err = json.Unmarshal(res.Content, &resData); err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	ch <- entitys.Response{
 | 
						requireData.Ch <- entitys.Response{
 | 
				
			||||||
		Index:   w.Name(),
 | 
							Index:   w.Name(),
 | 
				
			||||||
		Content: res.Text,
 | 
							Content: res.Text,
 | 
				
			||||||
		Type:    entitys.ResponseJson,
 | 
							Type:    entitys.ResponseJson,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if resData.Data.Direct != nil && resData.Data.Direct["needAi"].(bool) {
 | 
						if resData.Data.Direct != nil && resData.Data.Direct["needAi"].(bool) {
 | 
				
			||||||
		ch <- entitys.Response{
 | 
							requireData.Ch <- entitys.Response{
 | 
				
			||||||
			Index:   w.Name(),
 | 
								Index:   w.Name(),
 | 
				
			||||||
			Content: "正在分析订单日志",
 | 
								Content: "正在分析订单日志",
 | 
				
			||||||
			Type:    entitys.ResponseLoading,
 | 
								Type:    entitys.ResponseLoading,
 | 
				
			||||||
| 
						 | 
					@ -144,7 +136,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we
 | 
				
			||||||
		req = l_request.Request{
 | 
							req = l_request.Request{
 | 
				
			||||||
			Url: fmt.Sprintf("%szltx_api/admin/direct/log/%s/%s", w.config.BaseURL, resData.Data.Direct["orderOrderNumber"].(string), resData.Data.Direct["serialNumber"].(string)),
 | 
								Url: fmt.Sprintf("%szltx_api/admin/direct/log/%s/%s", w.config.BaseURL, resData.Data.Direct["orderOrderNumber"].(string), resData.Data.Direct["serialNumber"].(string)),
 | 
				
			||||||
			Headers: map[string]string{
 | 
								Headers: map[string]string{
 | 
				
			||||||
				"Authorization": fmt.Sprintf("Bearer %s", auth),
 | 
									"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			Method: "GET",
 | 
								Method: "GET",
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
| 
						 | 
					@ -164,14 +156,14 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we
 | 
				
			||||||
			return fmt.Errorf("订单日志解析失败:%s", err)
 | 
								return fmt.Errorf("订单日志解析失败:%s", err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		err = w.llm.ChatStream(context.TODO(), ch, []api.Message{
 | 
							err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{
 | 
				
			||||||
			{
 | 
								{
 | 
				
			||||||
				Role:    "system",
 | 
									Role:    "system",
 | 
				
			||||||
				Content: "你是一个订单日志助手。用户可能会提供订单日志,你需要分析订单日志,提取出订单失败的原因。",
 | 
									Content: "你是一个订单日志助手。用户可能会提供订单日志,你需要分析订单日志,提取出订单失败的原因。",
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			{
 | 
								{
 | 
				
			||||||
				Role:    "assistant",
 | 
									Role:    "assistant",
 | 
				
			||||||
				Content: fmt.Sprintf("聊天记录:%s", string(matchJson.History)),
 | 
									Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)),
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			{
 | 
								{
 | 
				
			||||||
				Role:    "assistant",
 | 
									Role:    "assistant",
 | 
				
			||||||
| 
						 | 
					@ -179,7 +171,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			{
 | 
								{
 | 
				
			||||||
				Role:    "user",
 | 
									Role:    "user",
 | 
				
			||||||
				Content: matchJson.UserInput,
 | 
									Content: requireData.UserInput,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
		}, w.Name())
 | 
							}, w.Name())
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -187,15 +179,11 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if resData.Data.Direct == nil {
 | 
						if resData.Data.Direct == nil {
 | 
				
			||||||
		ch <- entitys.Response{
 | 
							requireData.Ch <- entitys.Response{
 | 
				
			||||||
			Index:   w.Name(),
 | 
								Index:   w.Name(),
 | 
				
			||||||
			Content: "该订单无充值流水记录,不需要进行订单错误分析,建议检查订单收单模式以及扣款方式等原因😘",
 | 
								Content: "该订单无充值流水记录,不需要进行订单错误分析,建议检查订单收单模式以及扣款方式等原因😘",
 | 
				
			||||||
			Type:    entitys.ResponseText,
 | 
								Type:    entitys.ResponseText,
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	// else {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	//}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,11 +3,11 @@ package tools
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"ai_scheduler/internal/config"
 | 
						"ai_scheduler/internal/config"
 | 
				
			||||||
	"ai_scheduler/internal/entitys"
 | 
						"ai_scheduler/internal/entitys"
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"gitea.cdlsxd.cn/self-tools/l_request"
 | 
						"gitea.cdlsxd.cn/self-tools/l_request"
 | 
				
			||||||
	"github.com/gofiber/websocket/v2"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ZltxOrderLogTool struct {
 | 
					type ZltxOrderLogTool struct {
 | 
				
			||||||
| 
						 | 
					@ -67,31 +67,25 @@ type ZltxOrderDirectLogData struct {
 | 
				
			||||||
	Data     map[string]interface{} `json:"data"`
 | 
						Data     map[string]interface{} `json:"data"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *ZltxOrderLogTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error {
 | 
					func (t *ZltxOrderLogTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
 | 
				
			||||||
	var req ZltxOrderLogRequest
 | 
						var req ZltxOrderLogRequest
 | 
				
			||||||
	if err := json.Unmarshal(args, &req); err != nil {
 | 
						if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
 | 
				
			||||||
		return fmt.Errorf("invalid zltxOrderLog request: %w", err)
 | 
							return fmt.Errorf("invalid zltxOrderLog request: %w", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if req.OrderNumber == "" || req.SerialNumber == "" {
 | 
						if req.OrderNumber == "" || req.SerialNumber == "" {
 | 
				
			||||||
		return fmt.Errorf("orderNumber and serialNumber is required")
 | 
							return fmt.Errorf("orderNumber and serialNumber is required")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return t.getZltxOrderLog(channel, c, req.OrderNumber, req.SerialNumber)
 | 
						return t.getZltxOrderLog(req.OrderNumber, req.SerialNumber, requireData)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *ZltxOrderLogTool) getZltxOrderLog(channel chan entitys.Response, c *websocket.Conn, orderNumber, serialNumber string) (err error) {
 | 
					func (t *ZltxOrderLogTool) getZltxOrderLog(orderNumber, serialNumber string, requireData *entitys.RequireData) (err error) {
 | 
				
			||||||
	//查询订单详情
 | 
						//查询订单详情
 | 
				
			||||||
	var auth string
 | 
					
 | 
				
			||||||
	if c != nil {
 | 
					 | 
				
			||||||
		auth = c.Query("x-authorization", "")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if len(auth) == 0 {
 | 
					 | 
				
			||||||
		auth = t.config.APIKey
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	url := fmt.Sprintf("%s%s/%s", t.config.BaseURL, orderNumber, serialNumber)
 | 
						url := fmt.Sprintf("%s%s/%s", t.config.BaseURL, orderNumber, serialNumber)
 | 
				
			||||||
	req := l_request.Request{
 | 
						req := l_request.Request{
 | 
				
			||||||
		Url: url,
 | 
							Url: url,
 | 
				
			||||||
		Headers: map[string]string{
 | 
							Headers: map[string]string{
 | 
				
			||||||
			"Authorization": fmt.Sprintf("Bearer %s", auth),
 | 
								"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		Method: "GET",
 | 
							Method: "GET",
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -106,7 +100,7 @@ func (t *ZltxOrderLogTool) getZltxOrderLog(channel chan entitys.Response, c *web
 | 
				
			||||||
	if err = json.Unmarshal(res.Content, &resData); err != nil {
 | 
						if err = json.Unmarshal(res.Content, &resData); err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	channel <- entitys.Response{
 | 
						requireData.Ch <- entitys.Response{
 | 
				
			||||||
		Index:   t.Name(),
 | 
							Index:   t.Name(),
 | 
				
			||||||
		Content: res.Text,
 | 
							Content: res.Text,
 | 
				
			||||||
		Type:    entitys.ResponseJson,
 | 
							Type:    entitys.ResponseJson,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,13 +3,13 @@ package tools
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"ai_scheduler/internal/config"
 | 
						"ai_scheduler/internal/config"
 | 
				
			||||||
	"ai_scheduler/internal/entitys"
 | 
						"ai_scheduler/internal/entitys"
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"gitea.cdlsxd.cn/self-tools/l_request"
 | 
						"gitea.cdlsxd.cn/self-tools/l_request"
 | 
				
			||||||
	"github.com/gofiber/websocket/v2"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ZltxProductTool struct {
 | 
					type ZltxProductTool struct {
 | 
				
			||||||
| 
						 | 
					@ -53,12 +53,12 @@ type ZltxProductRequest struct {
 | 
				
			||||||
	Name string `json:"name"`
 | 
						Name string `json:"name"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (z ZltxProductTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error {
 | 
					func (z ZltxProductTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
 | 
				
			||||||
	var req ZltxProductRequest
 | 
						var req ZltxProductRequest
 | 
				
			||||||
	if err := json.Unmarshal(args, &req); err != nil {
 | 
						if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
 | 
				
			||||||
		return fmt.Errorf("invalid zltxProduct request: %w", err)
 | 
							return fmt.Errorf("invalid zltxProduct request: %w", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return z.getZltxProduct(channel, c, req.Id, req.Name)
 | 
						return z.getZltxProduct(&req, requireData)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ZltxProductResponse struct {
 | 
					type ZltxProductResponse struct {
 | 
				
			||||||
| 
						 | 
					@ -133,22 +133,16 @@ type ZltxProductData struct {
 | 
				
			||||||
	PlatformProductList interface{} `json:"platform_product_list"`
 | 
						PlatformProductList interface{} `json:"platform_product_list"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (z ZltxProductTool) getZltxProduct(channel chan entitys.Response, c *websocket.Conn, id string, name string) error {
 | 
					func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *entitys.RequireData) error {
 | 
				
			||||||
	var auth string
 | 
					
 | 
				
			||||||
	if c != nil {
 | 
					 | 
				
			||||||
		auth = c.Query("x-authorization", "")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if len(auth) == 0 {
 | 
					 | 
				
			||||||
		auth = z.config.APIKey
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	var Url string
 | 
						var Url string
 | 
				
			||||||
	var params map[string]string
 | 
						var params map[string]string
 | 
				
			||||||
	if id != "" {
 | 
						if body.Id != "" {
 | 
				
			||||||
		Url = fmt.Sprintf("%s/%s", z.config.BaseURL, id)
 | 
							Url = fmt.Sprintf("%s/%s", z.config.BaseURL, body.Id)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		Url = fmt.Sprintf("%s?keyword=%s&limit=10&page=1", z.config.BaseURL, name)
 | 
							Url = fmt.Sprintf("%s?keyword=%s&limit=10&page=1", z.config.BaseURL, body.Name)
 | 
				
			||||||
		params = map[string]string{
 | 
							params = map[string]string{
 | 
				
			||||||
			"keyword": name,
 | 
								"keyword": body.Name,
 | 
				
			||||||
			"limit":   "10",
 | 
								"limit":   "10",
 | 
				
			||||||
			"page":    "1",
 | 
								"page":    "1",
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
| 
						 | 
					@ -159,7 +153,7 @@ func (z ZltxProductTool) getZltxProduct(channel chan entitys.Response, c *websoc
 | 
				
			||||||
		//根据商品ID或名称走不同的接口查询
 | 
							//根据商品ID或名称走不同的接口查询
 | 
				
			||||||
		Url: Url,
 | 
							Url: Url,
 | 
				
			||||||
		Headers: map[string]string{
 | 
							Headers: map[string]string{
 | 
				
			||||||
			"Authorization": fmt.Sprintf("Bearer %s", auth),
 | 
								"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		Params: params,
 | 
							Params: params,
 | 
				
			||||||
		Method: "GET",
 | 
							Method: "GET",
 | 
				
			||||||
| 
						 | 
					@ -191,7 +185,7 @@ func (z ZltxProductTool) getZltxProduct(channel chan entitys.Response, c *websoc
 | 
				
			||||||
		for i := range resp.Data.List {
 | 
							for i := range resp.Data.List {
 | 
				
			||||||
			// 调用 平台商品列表
 | 
								// 调用 平台商品列表
 | 
				
			||||||
			if resp.Data.List[i].AuthProductIds != "" {
 | 
								if resp.Data.List[i].AuthProductIds != "" {
 | 
				
			||||||
				platformProductList := z.ExecutePlatformProductList(auth, resp.Data.List[i].AuthProductIds, resp.Data.List[i].OfficialProductID)
 | 
									platformProductList := z.ExecutePlatformProductList(requireData.Auth, resp.Data.List[i].AuthProductIds, resp.Data.List[i].OfficialProductID)
 | 
				
			||||||
				resp.Data.List[i].PlatformProductList = platformProductList
 | 
									resp.Data.List[i].PlatformProductList = platformProductList
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
| 
						 | 
					@ -200,7 +194,7 @@ func (z ZltxProductTool) getZltxProduct(channel chan entitys.Response, c *websoc
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	channel <- entitys.Response{
 | 
						requireData.Ch <- entitys.Response{
 | 
				
			||||||
		Index:   z.Name(),
 | 
							Index:   z.Name(),
 | 
				
			||||||
		Content: string(marshal),
 | 
							Content: string(marshal),
 | 
				
			||||||
		Type:    entitys.ResponseJson,
 | 
							Type:    entitys.ResponseJson,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,12 +3,12 @@ package tools
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"ai_scheduler/internal/config"
 | 
						"ai_scheduler/internal/config"
 | 
				
			||||||
	"ai_scheduler/internal/entitys"
 | 
						"ai_scheduler/internal/entitys"
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"sort"
 | 
						"sort"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"gitea.cdlsxd.cn/self-tools/l_request"
 | 
						"gitea.cdlsxd.cn/self-tools/l_request"
 | 
				
			||||||
	"github.com/gofiber/websocket/v2"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ZltxOrderStatisticsTool struct {
 | 
					type ZltxOrderStatisticsTool struct {
 | 
				
			||||||
| 
						 | 
					@ -47,15 +47,15 @@ type ZltxOrderStatisticsRequest struct {
 | 
				
			||||||
	Number string `json:"number"`
 | 
						Number string `json:"number"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (z ZltxOrderStatisticsTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error {
 | 
					func (z ZltxOrderStatisticsTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
 | 
				
			||||||
	var req ZltxOrderStatisticsRequest
 | 
						var req ZltxOrderStatisticsRequest
 | 
				
			||||||
	if err := json.Unmarshal(args, &req); err != nil {
 | 
						if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if req.Number == "" {
 | 
						if req.Number == "" {
 | 
				
			||||||
		return fmt.Errorf("number is required")
 | 
							return fmt.Errorf("number is required")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return z.getZltxOrderStatistics(channel, c, req.Number)
 | 
						return z.getZltxOrderStatistics(req.Number, requireData)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ZltxOrderStatisticsResponse struct {
 | 
					type ZltxOrderStatisticsResponse struct {
 | 
				
			||||||
| 
						 | 
					@ -75,20 +75,14 @@ type ZltxOrderStatisticsData struct {
 | 
				
			||||||
	Total   int    `json:"total"`
 | 
						Total   int    `json:"total"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(channel chan entitys.Response, c *websocket.Conn, number string) error {
 | 
					func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(number string, requireData *entitys.RequireData) error {
 | 
				
			||||||
	//查询订单详情
 | 
						//查询订单详情
 | 
				
			||||||
	var auth string
 | 
					
 | 
				
			||||||
	if c != nil {
 | 
					 | 
				
			||||||
		auth = c.Query("x-authorization", "")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if len(auth) == 0 {
 | 
					 | 
				
			||||||
		auth = z.config.APIKey
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	url := fmt.Sprintf("%s%s", z.config.BaseURL, number)
 | 
						url := fmt.Sprintf("%s%s", z.config.BaseURL, number)
 | 
				
			||||||
	req := l_request.Request{
 | 
						req := l_request.Request{
 | 
				
			||||||
		Url: url,
 | 
							Url: url,
 | 
				
			||||||
		Headers: map[string]string{
 | 
							Headers: map[string]string{
 | 
				
			||||||
			"Authorization": fmt.Sprintf("Bearer %s", auth),
 | 
								"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		Method: "GET",
 | 
							Method: "GET",
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -114,7 +108,7 @@ func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(channel chan entitys.Res
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	channel <- entitys.Response{
 | 
						requireData.Ch <- entitys.Response{
 | 
				
			||||||
		Index:   z.Name(),
 | 
							Index:   z.Name(),
 | 
				
			||||||
		Content: string(jsonByte),
 | 
							Content: string(jsonByte),
 | 
				
			||||||
		Type:    entitys.ResponseJson,
 | 
							Type:    entitys.ResponseJson,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue