package do import ( "ai_scheduler/internal/biz/llm_service" "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg/l_request" "ai_scheduler/internal/pkg/mapstructure" "ai_scheduler/internal/tools" "ai_scheduler/internal/tools_bot" "context" "encoding/json" "fmt" "github.com/gofiber/fiber/v2/log" "gorm.io/gorm/utils" "net/http" "strings" ) type Handle struct { Ollama *llm_service.OllamaService toolManager *tools.Manager Bot *tools_bot.BotTool conf *config.Config sessionImpl *impl.SessionImpl } func NewHandle( Ollama *llm_service.OllamaService, toolManager *tools.Manager, conf *config.Config, sessionImpl *impl.SessionImpl, dTalkBot *tools_bot.BotTool, ) *Handle { return &Handle{ Ollama: Ollama, toolManager: toolManager, conf: conf, sessionImpl: sessionImpl, Bot: dTalkBot, } } func (r *Handle) Recognize(ctx context.Context, requireData *entitys.RequireData) (err error) { entitys.ResLog(requireData.Ch, "recognize_start", "准备意图识别") //意图识别 recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData) if err != nil { return } entitys.ResLog(requireData.Ch, "recognize", recognizeMsg) entitys.ResLog(requireData.Ch, "recognize_end", "意图识别结束") var match entitys.Match if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil { err = errors.SysErr("数据结构错误:%v", err.Error()) return } requireData.Match = &match return } func (r *Handle) handleOtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) { entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning) return } func (r *Handle) HandleMatch(ctx context.Context, requireData *entitys.RequireData) (err error) { if !requireData.Match.IsMatch { if len(requireData.Match.Chat) != 0 { entitys.ResText(requireData.Ch, "", requireData.Match.Chat) } else { entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning) } return } var pointTask *model.AiTask for _, task := range requireData.Tasks { if task.Index == requireData.Match.Index { pointTask = &task break } } if pointTask == nil || pointTask.Index == "other" { return r.OtherTask(ctx, requireData) } // 校验用户权限 if err = r.PermissionAuth(requireData, pointTask); err != nil { log.Errorf("权限验证失败: %s", err.Error()) entitys.ResLog(requireData.Ch, "", "权限验证失败:"+err.Error()) return } switch constants.TaskType(pointTask.Type) { case constants.TaskTypeApi: return r.handleApiTask(ctx, requireData, pointTask) case constants.TaskTypeFunc: return r.handleTask(ctx, requireData, pointTask) case constants.TaskTypeKnowle: return r.handleKnowle(ctx, requireData, pointTask) case constants.TaskTypeBot: return r.handleBot(ctx, requireData, pointTask) default: return r.handleOtherTask(ctx, requireData) } } func (r *Handle) OtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) { entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning) return } func (r *Handle) handleBot(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { var configData entitys.ConfigDataTool err = json.Unmarshal([]byte(task.Config), &configData) if err != nil { return } err = r.Bot.Execute(ctx, configData.Tool, requireData) if err != nil { return } return } func (r *Handle) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { var configData entitys.ConfigDataTool err = json.Unmarshal([]byte(task.Config), &configData) if err != nil { return } err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData) if err != nil { return } return } // 知识库 func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { var ( configData entitys.ConfigDataTool sessionIdKnowledge string query string host string ) err = json.Unmarshal([]byte(task.Config), &configData) if err != nil { return } // 通过session 找到知识库session var has bool if len(requireData.Session) == 0 { return errors.SessionNotFound } requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session)) if err != nil { return } else if !has { return errors.SessionNotFound } // 找到知识库的host { tool, exists := r.toolManager.GetTool(configData.Tool) if !exists { return fmt.Errorf("tool not found: %s", configData.Tool) } if knowledgeTool, ok := tool.(*tools.KnowledgeBaseTool); !ok { return fmt.Errorf("未找到知识库Tool: %s", configData.Tool) } else { host = knowledgeTool.GetConfig().BaseURL } } // 知识库的session为空,请求知识库获取, 并绑定 if requireData.SessionInfo.KnowlegeSessionID == "" { // 请求知识库 if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil { return } // 绑定知识库session,下次可以使用 requireData.SessionInfo.KnowlegeSessionID = sessionIdKnowledge if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil { return } } // 用户输入解析 var ok bool input := make(map[string]string) if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil { return } if query, ok = input["query"]; !ok { return fmt.Errorf("query不能为空") } requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{ Session: requireData.SessionInfo.KnowlegeSessionID, ApiKey: requireData.Sys.KnowlegeTenantKey, Query: query, } // 执行工具 err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData) if err != nil { return } return } func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { var ( request l_request.Request requestParam map[string]interface{} ) err = json.Unmarshal([]byte(requireData.Match.Parameters), &requestParam) if err != nil { return } request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth) for k, v := range requestParam { task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v)) } var configData entitys.ConfigDataHttp err = json.Unmarshal([]byte(task.Config), &configData) if err != nil { return } err = mapstructure.Decode(configData.Request, &request) if err != nil { return } if len(request.Url) == 0 { err = errors.NewBusinessErr(422, "api地址获取失败") return } res, err := request.Send() if err != nil { return } entitys.ResJson(requireData.Ch, "", pkg.JsonStringIgonErr(res.Text)) return } // 权限验证 func (r *Handle) PermissionAuth(requireData *entitys.RequireData, pointTask *model.AiTask) (err error) { // 白名单接口不要校验权限 if utils.Contains(r.conf.PermissionConfig.WhiteList, pointTask.Index) { return nil } // 查询用户权限 var ( request l_request.Request ) request.Url = r.conf.PermissionConfig.UnifiedLoginPlatformBaseURL request.Method = "GET" request.Headers = map[string]string{ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", "Accept": "application/json, text/plain, */*", "Authorization": "Bearer " + requireData.Auth, } // 发送请求 res, err := request.Send() if err != nil { return err } // 检查响应状态码 if res.StatusCode != http.StatusOK { return fmt.Errorf("unexpected status code: %d", res.StatusCode) } type resp struct { Codes []string `json:"codes"` } // 解析响应体 var respBody resp err = json.Unmarshal([]byte(res.Text), &respBody) if err != nil { return err } // 检查权限 if !utils.Contains(respBody.Codes, pointTask.Index) { return fmt.Errorf("用户权限不足: %s", pointTask.Name) } return nil }