package biz import ( "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg/mapstructure" "ai_scheduler/internal/pkg/utils_ollama" "ai_scheduler/internal/tools" "ai_scheduler/tmpl/dataTemp" "context" "encoding/json" "fmt" "strings" "gitea.cdlsxd.cn/self-tools/l_request" "github.com/gofiber/fiber/v2/log" "github.com/gofiber/websocket/v2" "github.com/tmc/langchaingo/llms" "xorm.io/builder" ) // AiRouterBiz 智能路由服务 type AiRouterBiz struct { //aiClient entitys.AIClient toolManager *tools.Manager sessionImpl *impl.SessionImpl sysImpl *impl.SysImpl taskImpl *impl.TaskImpl hisImpl *impl.ChatImpl conf *config.Config utilAgent *utils_ollama.UtilOllama channelPool *pkg.SafeChannelPool } // NewRouterService 创建路由服务 func NewAiRouterBiz( //aiClient entitys.AIClient, toolManager *tools.Manager, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, taskImpl *impl.TaskImpl, hisImpl *impl.ChatImpl, conf *config.Config, utilAgent *utils_ollama.UtilOllama, channelPool *pkg.SafeChannelPool, ) *AiRouterBiz { return &AiRouterBiz{ //aiClient: aiClient, toolManager: toolManager, sessionImpl: sessionImpl, conf: conf, sysImpl: sysImpl, hisImpl: hisImpl, taskImpl: taskImpl, utilAgent: utilAgent, channelPool: channelPool, } } // Route 执行智能路由 func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) { return nil, nil } // Route 执行智能路由 func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) { ch, err := r.channelPool.Get() if err != nil { return err } defer func() { if err != nil { entitys.MsgSend(c, entitys.ResponseData{ Done: false, Content: err.Error(), Type: entitys.ResponseErr, }) } entitys.MsgSend(c, entitys.ResponseData{ Done: true, Content: "", Type: entitys.ResponseEnd, }) err = r.channelPool.Put(ch) if err != nil { close(ch) } }() session := c.Headers("X-Session", "") if len(session) == 0 { return errors.SessionNotFound } auth := c.Headers("X-Authorization", "") if len(auth) == 0 { return errors.AuthNotFound } key := c.Headers("X-App-Key", "") if len(key) == 0 { return errors.KeyNotFound } sysInfo, err := r.getSysInfo(key) if err != nil { return errors.SysNotFound } history, err := r.getSessionChatHis(session) if err != nil { return errors.SystemError } task, err := r.getTasks(sysInfo.SysID) if err != nil { return errors.SystemError } //意图预测 prompt := r.getPromptLLM(sysInfo, history, req.Text, task) match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt, llms.WithJSONMode(), ) if err != nil { return errors.SystemError } log.Info(match.Choices[0].Content) ch <- entitys.ResponseData{ Done: false, Content: match.Choices[0].Content, Type: entitys.ResponseLog, } var matchJson entitys.Match err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson) if err != nil { return errors.SystemError } go func() { defer func() { if r := recover(); r != nil { err = fmt.Errorf("recovered from panic: %v", r) } }() defer close(ch) err = r.handleMatch(c, ch, &matchJson, task) if err != nil { return } }() for v := range ch { if err := entitys.MsgSend(c, v); err != nil { return err } } return } func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match) (err error) { ch <- entitys.ResponseData{ Done: false, Content: matchJson.Reasoning, Type: entitys.ResponseText, } return } func (r *AiRouterBiz) handleMatch(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match, tasks []model.AiTask) (err error) { if !matchJson.IsMatch { ch <- entitys.ResponseData{ Done: false, Content: matchJson.Reasoning, Type: entitys.ResponseText, } return } var pointTask *model.AiTask for _, task := range tasks { if task.Index == matchJson.Index { pointTask = &task break } } if pointTask == nil || pointTask.Index == "other" { return r.handleOtherTask(c, ch, matchJson) } switch pointTask.Type { case constants.TaskTypeApi: return r.handleApiTask(ch, c, matchJson, pointTask) case constants.TaskTypeFunc: return r.handleTask(ch, c, matchJson, pointTask) default: return r.handleOtherTask(c, ch, matchJson) } } func (r *AiRouterBiz) handleTask(channel chan entitys.ResponseData, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { var configData entitys.ConfigDataTool err = json.Unmarshal([]byte(task.Config), &configData) if err != nil { return } err = r.toolManager.ExecuteTool(channel, c, configData.Tool, []byte(matchJson.Parameters)) if err != nil { return } return } func (r *AiRouterBiz) handleApiTask(channels chan entitys.ResponseData, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { var ( request l_request.Request auth = c.Headers("X-Authorization", "") requestParam map[string]interface{} ) err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam) if err != nil { return } request.Url = strings.ReplaceAll(task.Config, "${authorization}", auth) for k, v := range requestParam { task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v)) } var configData entitys.ConfigDataHttp err = json.Unmarshal([]byte(task.Config), &configData) if err != nil { return } err = mapstructure.Decode(configData.Request, &request) if err != nil { return } if len(request.Url) == 0 { err = errors.NewBusinessErr("00022", "api地址获取失败") return } res, err := request.Send() if err != nil { return } c.WriteMessage(1, res.Content) return } func (r *AiRouterBiz) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"session_id": sessionId}) _, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his, "his_id asc") return } func (r *AiRouterBiz) getSysInfo(appKey string) (sysInfo model.AiSy, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"app_key": appKey}) cond = cond.And(builder.IsNull{"delete_at"}) cond = cond.And(builder.Eq{"status": 1}) err = r.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo) return } func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"sys_id": sysId}) cond = cond.And(builder.IsNull{"delete_at"}) cond = cond.And(builder.Eq{"status": 1}) _, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks, "") return } func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool { taskPrompt := make([]llms.Tool, 0) for _, task := range tasks { var taskConfig entitys.TaskConfig err := json.Unmarshal([]byte(task.Config), &taskConfig) if err != nil { continue } taskPrompt = append(taskPrompt, llms.Tool{ Type: "function", Function: &llms.FunctionDefinition{ Name: task.Index, Description: task.Desc, Parameters: taskConfig.Param, }, }) } return taskPrompt } func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message { var ( prompt = make([]entitys.Message, 0) ) prompt = append(prompt, entitys.Message{ Role: "system", Content: r.buildSystemPrompt(sysInfo.SysPrompt), }, entitys.Message{ Role: "assistant", Content: pkg.JsonStringIgonErr(r.buildAssistant(history)), }, entitys.Message{ Role: "user", Content: reqInput, }) return prompt } func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { var ( prompt = make([]llms.MessageContent, 0) ) prompt = append(prompt, llms.MessageContent{ Role: llms.ChatMessageTypeSystem, Parts: []llms.ContentPart{ llms.TextPart(r.buildSystemPrompt(sysInfo.SysPrompt)), }, }, llms.MessageContent{ Role: llms.ChatMessageTypeTool, Parts: []llms.ContentPart{ llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))), }, }, llms.MessageContent{ Role: llms.ChatMessageTypeTool, Parts: []llms.ContentPart{ llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))), }, }, llms.MessageContent{ Role: llms.ChatMessageTypeHuman, Parts: []llms.ContentPart{ llms.TextPart(reqInput), }, }) return prompt } // buildSystemPrompt 构建系统提示词 func (r *AiRouterBiz) buildSystemPrompt(prompt string) string { if len(prompt) == 0 { prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON:{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式,禁用markdown格式返回\n3.只返回json字符串,不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容" } return prompt } func (r *AiRouterBiz) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) { for _, item := range his { if len(chatHis.SessionId) == 0 { chatHis.SessionId = item.SessionID } chatHis.Messages = append(chatHis.Messages, entitys.HisMessage{ Role: item.Role, Content: item.Content, Timestamp: item.CreateAt.Format("2006-01-02 15:04:05"), }) } chatHis.Context = entitys.HisContext{ UserLanguage: "zh-CN", SystemMode: "technical_support", } return } // handleKnowledgeQA 处理知识问答意图 func (r *AiRouterBiz) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) { return nil, nil }