287 lines
7.5 KiB
Go
287 lines
7.5 KiB
Go
package do
|
||
|
||
import (
|
||
"ai_scheduler/internal/biz/llm_service"
|
||
"ai_scheduler/internal/config"
|
||
"ai_scheduler/internal/data/constants"
|
||
errors "ai_scheduler/internal/data/error"
|
||
"ai_scheduler/internal/data/impl"
|
||
"ai_scheduler/internal/data/model"
|
||
"ai_scheduler/internal/entitys"
|
||
"ai_scheduler/internal/gateway"
|
||
"ai_scheduler/internal/pkg"
|
||
"ai_scheduler/internal/pkg/l_request"
|
||
"ai_scheduler/internal/pkg/mapstructure"
|
||
"ai_scheduler/internal/tools"
|
||
"ai_scheduler/internal/tools_bot"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"strings"
|
||
|
||
"gorm.io/gorm/utils"
|
||
)
|
||
|
||
type Handle struct {
|
||
Ollama *llm_service.OllamaService
|
||
toolManager *tools.Manager
|
||
Bot *tools_bot.BotTool
|
||
conf *config.Config
|
||
sessionImpl *impl.SessionImpl
|
||
}
|
||
|
||
func NewHandle(
|
||
Ollama *llm_service.OllamaService,
|
||
toolManager *tools.Manager,
|
||
conf *config.Config,
|
||
sessionImpl *impl.SessionImpl,
|
||
dTalkBot *tools_bot.BotTool,
|
||
) *Handle {
|
||
return &Handle{
|
||
Ollama: Ollama,
|
||
toolManager: toolManager,
|
||
conf: conf,
|
||
sessionImpl: sessionImpl,
|
||
Bot: dTalkBot,
|
||
}
|
||
}
|
||
|
||
func (r *Handle) Recognize(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
||
entitys.ResLog(requireData.Ch, "recognize_start", "准备意图识别")
|
||
|
||
//意图识别
|
||
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
entitys.ResLog(requireData.Ch, "recognize", recognizeMsg)
|
||
entitys.ResLog(requireData.Ch, "recognize_end", "意图识别结束")
|
||
|
||
var match entitys.Match
|
||
if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil {
|
||
err = errors.SysErr("数据结构错误:%v", err.Error())
|
||
return
|
||
}
|
||
requireData.Match = &match
|
||
return
|
||
}
|
||
|
||
func (r *Handle) handleOtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
||
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
|
||
return
|
||
}
|
||
|
||
func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) (err error) {
|
||
|
||
if !requireData.Match.IsMatch {
|
||
if len(requireData.Match.Chat) != 0 {
|
||
entitys.ResText(requireData.Ch, "", requireData.Match.Chat)
|
||
} else {
|
||
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
|
||
}
|
||
|
||
return
|
||
}
|
||
var pointTask *model.AiTask
|
||
for _, task := range requireData.Tasks {
|
||
if task.Index == requireData.Match.Index {
|
||
pointTask = &task
|
||
break
|
||
}
|
||
}
|
||
|
||
if pointTask == nil || pointTask.Index == "other" {
|
||
return r.OtherTask(ctx, requireData)
|
||
}
|
||
|
||
// 校验用户权限
|
||
// if err = r.PermissionAuth(client, pointTask); err != nil {
|
||
// log.Errorf("权限验证失败: %s", err.Error())
|
||
// return
|
||
// }
|
||
|
||
switch constants.TaskType(pointTask.Type) {
|
||
case constants.TaskTypeApi:
|
||
return r.handleApiTask(ctx, requireData, pointTask)
|
||
case constants.TaskTypeFunc:
|
||
return r.handleTask(ctx, requireData, pointTask)
|
||
case constants.TaskTypeKnowle:
|
||
return r.handleKnowle(ctx, requireData, pointTask)
|
||
case constants.TaskTypeBot:
|
||
return r.handleBot(ctx, requireData, pointTask)
|
||
default:
|
||
return r.handleOtherTask(ctx, requireData)
|
||
}
|
||
}
|
||
|
||
func (r *Handle) OtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
||
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
|
||
return
|
||
}
|
||
|
||
func (r *Handle) handleBot(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
||
var configData entitys.ConfigDataTool
|
||
err = json.Unmarshal([]byte(task.Config), &configData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
err = r.Bot.Execute(ctx, configData.Tool, requireData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
func (r *Handle) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
||
var configData entitys.ConfigDataTool
|
||
err = json.Unmarshal([]byte(task.Config), &configData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
// 知识库
|
||
func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
||
|
||
var (
|
||
configData entitys.ConfigDataTool
|
||
sessionIdKnowledge string
|
||
query string
|
||
host string
|
||
)
|
||
err = json.Unmarshal([]byte(task.Config), &configData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
// 通过session 找到知识库session
|
||
var has bool
|
||
if len(requireData.Session) == 0 {
|
||
return errors.SessionNotFound
|
||
}
|
||
requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session))
|
||
if err != nil {
|
||
return
|
||
} else if !has {
|
||
return errors.SessionNotFound
|
||
}
|
||
|
||
// 找到知识库的host
|
||
{
|
||
tool, exists := r.toolManager.GetTool(configData.Tool)
|
||
if !exists {
|
||
return fmt.Errorf("tool not found: %s", configData.Tool)
|
||
}
|
||
|
||
if knowledgeTool, ok := tool.(*tools.KnowledgeBaseTool); !ok {
|
||
return fmt.Errorf("未找到知识库Tool: %s", configData.Tool)
|
||
} else {
|
||
host = knowledgeTool.GetConfig().BaseURL
|
||
}
|
||
|
||
}
|
||
|
||
// 知识库的session为空,请求知识库获取, 并绑定
|
||
if requireData.SessionInfo.KnowlegeSessionID == "" {
|
||
// 请求知识库
|
||
if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil {
|
||
return
|
||
}
|
||
|
||
// 绑定知识库session,下次可以使用
|
||
requireData.SessionInfo.KnowlegeSessionID = sessionIdKnowledge
|
||
if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil {
|
||
return
|
||
}
|
||
}
|
||
|
||
// 用户输入解析
|
||
var ok bool
|
||
input := make(map[string]string)
|
||
if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil {
|
||
return
|
||
}
|
||
if query, ok = input["query"]; !ok {
|
||
return fmt.Errorf("query不能为空")
|
||
}
|
||
|
||
requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{
|
||
Session: requireData.SessionInfo.KnowlegeSessionID,
|
||
ApiKey: requireData.Sys.KnowlegeTenantKey,
|
||
Query: query,
|
||
}
|
||
|
||
// 执行工具
|
||
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
||
var (
|
||
request l_request.Request
|
||
requestParam map[string]interface{}
|
||
)
|
||
err = json.Unmarshal([]byte(requireData.Match.Parameters), &requestParam)
|
||
if err != nil {
|
||
return
|
||
}
|
||
request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth)
|
||
for k, v := range requestParam {
|
||
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
|
||
}
|
||
var configData entitys.ConfigDataHttp
|
||
err = json.Unmarshal([]byte(task.Config), &configData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
err = mapstructure.Decode(configData.Request, &request)
|
||
if err != nil {
|
||
return
|
||
}
|
||
if len(request.Url) == 0 {
|
||
err = errors.NewBusinessErr(422, "api地址获取失败")
|
||
return
|
||
}
|
||
res, err := request.Send()
|
||
if err != nil {
|
||
return
|
||
}
|
||
entitys.ResJson(requireData.Ch, "", pkg.JsonStringIgonErr(res.Text))
|
||
|
||
return
|
||
}
|
||
|
||
// 权限验证
|
||
func (r *Handle) PermissionAuth(client *gateway.Client, pointTask *model.AiTask) (err error) {
|
||
// 通用权限校验
|
||
if utils.Contains(r.conf.PermissionConfig.WhiteList, pointTask.Index) {
|
||
return nil
|
||
}
|
||
|
||
// 系统权限校验
|
||
if v, ok := r.conf.PermissionConfig.SysPermission[client.GetSysCode()]; !ok {
|
||
return fmt.Errorf("未配置系统权限校验: %s", client.GetSysCode())
|
||
} else if utils.Contains(v.WhiteList, pointTask.Index) {
|
||
return nil
|
||
}
|
||
|
||
// 授权检查权限
|
||
if !utils.Contains(client.GetCodes(), pointTask.Index) {
|
||
return fmt.Errorf("用户权限不足: %s", pointTask.Name)
|
||
}
|
||
|
||
return nil
|
||
}
|