package biz import ( "ai_scheduler/internal/biz/llm_service" "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg/mapstructure" "ai_scheduler/internal/tools" "ai_scheduler/tmpl/dataTemp" "context" "encoding/json" "fmt" "strings" "time" "gitea.cdlsxd.cn/self-tools/l_request" "github.com/gofiber/fiber/v2/log" "github.com/gofiber/websocket/v2" "xorm.io/builder" ) // AiRouterBiz 智能路由服务 type AiRouterBiz struct { toolManager *tools.Manager sessionImpl *impl.SessionImpl sysImpl *impl.SysImpl taskImpl *impl.TaskImpl hisImpl *impl.ChatImpl conf *config.Config rds *pkg.Rdb langChain *llm_service.LangChainService Ollama *llm_service.OllamaService } // NewRouterService 创建路由服务 func NewAiRouterBiz( toolManager *tools.Manager, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, taskImpl *impl.TaskImpl, hisImpl *impl.ChatImpl, conf *config.Config, langChain *llm_service.LangChainService, Ollama *llm_service.OllamaService, ) *AiRouterBiz { return &AiRouterBiz{ toolManager: toolManager, sessionImpl: sessionImpl, conf: conf, sysImpl: sysImpl, hisImpl: hisImpl, taskImpl: taskImpl, langChain: langChain, Ollama: Ollama, } } func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) { //必要数据验证和获取 var requireData entitys.RequireData err = r.dataAuth(c, &requireData) if err != nil { return } //初始化通道/上下文 requireData.Ch = make(chan entitys.Response) ctx, cancel := context.WithCancel(context.Background()) // 启动独立的消息处理协程 done := r.startMessageHandler(ctx, c, &requireData) defer func() { close(requireData.Ch) //关闭主通道 <-done // 等待消息处理完成 cancel() }() //获取系统信息 err = r.getRequireData(req.Text, &requireData) if err != nil { log.Errorf("SQL error: %v", err) return } //意图识别 err = r.recognize(ctx, &requireData) if err != nil { log.Errorf("LLM error: %v", err) return } //向下传递 if err = r.handleMatch(ctx, &requireData); err != nil { log.Errorf("Handle error: %v", err) return } return } // startMessageHandler 启动独立的消息处理协程 func (r *AiRouterBiz) startMessageHandler( ctx context.Context, c *websocket.Conn, requireData *entitys.RequireData, ) <-chan struct{} { done := make(chan struct{}) var chat []string go func() { defer func() { close(done) // 保存历史记录 var his = []*model.AiChatHi{ { SessionID: requireData.Session, Role: "user", Content: "", // 用户输入在外部处理 }, } if len(chat) > 0 { his = append(his, &model.AiChatHi{ SessionID: requireData.Session, Role: "assistant", Content: strings.Join(chat, ""), }) } for _, hi := range his { r.hisImpl.Add(hi) } }() for v := range requireData.Ch { // 自动检测通道关闭 if err := sendWithTimeout(c, v, 2*time.Second); err != nil { log.Errorf("Send error: %v", err) return } if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson { chat = append(chat, v.Content) } } }() return done } // 辅助函数:带超时的 WebSocket 发送 func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error { sendCtx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() done := make(chan error, 1) go func() { defer func() { if r := recover(); r != nil { done <- fmt.Errorf("panic in MsgSend: %v", r) } close(done) }() // 如果 MsgSend 阻塞,这里会卡住 err := entitys.MsgSend(c, data) done <- err }() select { case err := <-done: return err case <-sendCtx.Done(): return sendCtx.Err() } } func (r *AiRouterBiz) recognize(ctx context.Context, requireData *entitys.RequireData) (err error) { requireData.Ch <- entitys.Response{ Index: "", Content: "准备意图识别", Type: entitys.ResponseLog, } //意图识别 recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData) if err != nil { return } requireData.Ch <- entitys.Response{ Index: "", Content: recognizeMsg, Type: entitys.ResponseLog, } requireData.Ch <- entitys.Response{ Index: "", Content: "意图识别结束", Type: entitys.ResponseLog, } var match entitys.Match if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil { err = errors.SysErr("数据结构错误:%v", err.Error()) return } requireData.Match = &match return } func (r *AiRouterBiz) getRequireData(userInput string, requireData *entitys.RequireData) (err error) { requireData.Sys, err = r.getSysInfo(requireData.Key) if err != nil { err = errors.SysErr("获取系统信息失败:%v", err.Error()) return } requireData.Histories, err = r.getSessionChatHis(requireData.Session) if err != nil { err = errors.SysErr("获取历史记录失败:%v", err.Error()) return } requireData.Tasks, err = r.getTasks(requireData.Sys.SysID) if err != nil { err = errors.SysErr("获取任务列表失败:%v", err.Error()) return } requireData.UserInput = userInput if len(requireData.UserInput) == 0 { err = errors.SysErr("获取用户输入失败") return } if len(requireData.UserInput) == 0 { err = errors.SysErr("获取用户输入失败") return } return } func (r *AiRouterBiz) dataAuth(c *websocket.Conn, requireData *entitys.RequireData) (err error) { requireData.Session = c.Query("x-session", "") if len(requireData.Session) == 0 { err = errors.SessionNotFound return } requireData.Auth = c.Query("x-authorization", "") if len(requireData.Auth) == 0 { err = errors.AuthNotFound return } requireData.Key = c.Query("x-app-key", "") if len(requireData.Key) == 0 { err = errors.KeyNotFound return } return } func (r *AiRouterBiz) handleOtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) { requireData.Ch <- entitys.Response{ Index: "", Content: requireData.Match.Reasoning, Type: entitys.ResponseText, } return } func (r *AiRouterBiz) handleMatch(ctx context.Context, requireData *entitys.RequireData) (err error) { if !requireData.Match.IsMatch { requireData.Ch <- entitys.Response{ Index: "", Content: requireData.Match.Chat, Type: entitys.ResponseText, } return } var pointTask *model.AiTask for _, task := range requireData.Tasks { if task.Index == requireData.Match.Index { pointTask = &task break } } if pointTask == nil || pointTask.Index == "other" { return r.handleOtherTask(ctx, requireData) } switch pointTask.Type { case constants.TaskTypeApi: return r.handleApiTask(ctx, requireData, pointTask) case constants.TaskTypeFunc: return r.handleTask(ctx, requireData, pointTask) case constants.TaskTypeKnowle: return r.handleKnowle(ctx, requireData, pointTask) default: return r.handleOtherTask(ctx, requireData) } } func (r *AiRouterBiz) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { var configData entitys.ConfigDataTool err = json.Unmarshal([]byte(task.Config), &configData) if err != nil { return } err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData) if err != nil { return } return } // 知识库 func (r *AiRouterBiz) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { var ( configData entitys.ConfigDataTool sessionIdKnowledge string query string host string ) err = json.Unmarshal([]byte(task.Config), &configData) if err != nil { return } // 通过session 找到知识库session var has bool if len(requireData.Session) == 0 { return errors.SessionNotFound } requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session)) if err != nil { return } else if !has { return errors.SessionNotFound } // 找到知识库的host { tool, exists := r.toolManager.GetTool(configData.Tool) if !exists { return fmt.Errorf("tool not found: %s", configData.Tool) } if knowledgeTool, ok := tool.(*tools.KnowledgeBaseTool); !ok { return fmt.Errorf("未找到知识库Tool: %s", configData.Tool) } else { host = knowledgeTool.GetConfig().BaseURL } } // 知识库的session为空,请求知识库获取, 并绑定 if requireData.SessionInfo.KnowlegeSessionID == "" { // 请求知识库 if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil { return } // 绑定知识库session,下次可以使用 requireData.SessionInfo.KnowlegeSessionID = sessionIdKnowledge if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil { return } } // 用户输入解析 var ok bool input := make(map[string]string) if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil { return } if query, ok = input["query"]; !ok { return fmt.Errorf("query不能为空") } requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{ Session: requireData.SessionInfo.KnowlegeSessionID, ApiKey: requireData.Sys.KnowlegeTenantKey, Query: query, } // 执行工具 err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData) if err != nil { return } return } func (r *AiRouterBiz) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { var ( request l_request.Request requestParam map[string]interface{} ) err = json.Unmarshal([]byte(requireData.Match.Parameters), &requestParam) if err != nil { return } request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth) for k, v := range requestParam { task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v)) } var configData entitys.ConfigDataHttp err = json.Unmarshal([]byte(task.Config), &configData) if err != nil { return } err = mapstructure.Decode(configData.Request, &request) if err != nil { return } if len(request.Url) == 0 { err = errors.NewBusinessErr(422, "api地址获取失败") return } res, err := request.Send() if err != nil { return } requireData.Ch <- entitys.Response{ Index: "", Content: pkg.JsonStringIgonErr(res.Text), Type: entitys.ResponseJson, } return } func (r *AiRouterBiz) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"session_id": sessionId}) _, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his, "his_id desc") return } func (r *AiRouterBiz) getSysInfo(appKey string) (sysInfo model.AiSy, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"app_key": appKey}) cond = cond.And(builder.IsNull{"delete_at"}) cond = cond.And(builder.Eq{"status": 1}) err = r.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo) return } func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"sys_id": sysId}) cond = cond.And(builder.IsNull{"delete_at"}) cond = cond.And(builder.Eq{"status": 1}) _, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks, "") return }