494 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			494 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
package biz
 | 
						||
 | 
						||
import (
 | 
						||
	"ai_scheduler/internal/biz/llm_service"
 | 
						||
	"ai_scheduler/internal/config"
 | 
						||
	"ai_scheduler/internal/data/constants"
 | 
						||
	errors "ai_scheduler/internal/data/error"
 | 
						||
	"ai_scheduler/internal/data/impl"
 | 
						||
	"ai_scheduler/internal/data/model"
 | 
						||
	"ai_scheduler/internal/entitys"
 | 
						||
	"ai_scheduler/internal/pkg"
 | 
						||
	"ai_scheduler/internal/pkg/mapstructure"
 | 
						||
 | 
						||
	"ai_scheduler/internal/tools"
 | 
						||
	"ai_scheduler/tmpl/dataTemp"
 | 
						||
	"context"
 | 
						||
	"encoding/json"
 | 
						||
	"fmt"
 | 
						||
	"strings"
 | 
						||
	"time"
 | 
						||
 | 
						||
	"gitea.cdlsxd.cn/self-tools/l_request"
 | 
						||
	"github.com/gofiber/fiber/v2/log"
 | 
						||
	"github.com/gofiber/websocket/v2"
 | 
						||
	"xorm.io/builder"
 | 
						||
)
 | 
						||
 | 
						||
// AiRouterBiz 智能路由服务
 | 
						||
type AiRouterBiz struct {
 | 
						||
	toolManager *tools.Manager
 | 
						||
	sessionImpl *impl.SessionImpl
 | 
						||
	sysImpl     *impl.SysImpl
 | 
						||
	taskImpl    *impl.TaskImpl
 | 
						||
	hisImpl     *impl.ChatImpl
 | 
						||
	conf        *config.Config
 | 
						||
	rds         *pkg.Rdb
 | 
						||
	langChain   *llm_service.LangChainService
 | 
						||
	Ollama      *llm_service.OllamaService
 | 
						||
}
 | 
						||
 | 
						||
// NewRouterService 创建路由服务
 | 
						||
func NewAiRouterBiz(
 | 
						||
	toolManager *tools.Manager,
 | 
						||
	sessionImpl *impl.SessionImpl,
 | 
						||
	sysImpl *impl.SysImpl,
 | 
						||
	taskImpl *impl.TaskImpl,
 | 
						||
	hisImpl *impl.ChatImpl,
 | 
						||
	conf *config.Config,
 | 
						||
	langChain *llm_service.LangChainService,
 | 
						||
	Ollama *llm_service.OllamaService,
 | 
						||
) *AiRouterBiz {
 | 
						||
	return &AiRouterBiz{
 | 
						||
		toolManager: toolManager,
 | 
						||
		sessionImpl: sessionImpl,
 | 
						||
		conf:        conf,
 | 
						||
		sysImpl:     sysImpl,
 | 
						||
		hisImpl:     hisImpl,
 | 
						||
		taskImpl:    taskImpl,
 | 
						||
		langChain:   langChain,
 | 
						||
		Ollama:      Ollama,
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// RouteWithSocket handles one chat message received on websocket c.
//
// It reads the mandatory connection parameters, starts the goroutine that
// forwards responses to the client, and then runs the pipeline stages in
// order: image download, loading of system/history/task data, intent
// recognition, and dispatch of the recognised intent. The first stage that
// fails is logged and its error returned.
//
// On every path past authentication, the response channel is closed and the
// forwarding goroutine is awaited before the context is released.
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
	var requireData entitys.RequireData
	if err = r.dataAuth(c, &requireData); err != nil {
		return err
	}

	requireData.Ch = make(chan entitys.Response)
	ctx, cancel := context.WithCancel(context.Background())
	done := r.startMessageHandler(ctx, c, &requireData, req.Text)
	defer func() {
		// Closing the channel ends the forwarding loop; wait for it to
		// finish before cancelling the context.
		close(requireData.Ch)
		<-done
		cancel()
	}()

	stages := []struct {
		logFormat string
		run       func() error
	}{
		{"GetImgData error: %v", func() error { return r.getImgData(req.Img, &requireData) }},
		{"SQL error: %v", func() error { return r.getRequireData(req.Text, &requireData) }},
		{"LLM error: %v", func() error { return r.recognize(ctx, &requireData) }},
		{"Handle error: %v", func() error { return r.handleMatch(ctx, &requireData) }},
	}
	for _, stage := range stages {
		if err = stage.run(); err != nil {
			log.Errorf(stage.logFormat, err)
			return err
		}
	}
	return nil
}
 | 
						||
 | 
						||
// startMessageHandler starts the goroutine that forwards every Response
// written to requireData.Ch to the websocket client. It returns a channel
// that is closed once that goroutine has completely finished.
//
// The goroutine runs until requireData.Ch is closed. Text, stream and JSON
// responses that were delivered successfully are concatenated into the
// assistant reply. When the loop ends, the user's input is appended to the
// session history via hisImpl.Add, followed by the reply if it is non-empty.
// done is closed only after that, so a caller waiting on it knows the turn
// has been recorded.
//
// After the first failed send the goroutine stops writing but keeps
// receiving. Producers send on this unbuffered channel without a select, so
// abandoning it would block them, and the whole request, forever. Not
// writing again also avoids overlapping a write that timed out but may still
// be in progress.
//
// ctx is currently unused.
func (r *AiRouterBiz) startMessageHandler(
	ctx context.Context,
	c *websocket.Conn,
	requireData *entitys.RequireData,
	userInput string,
) <-chan struct{} {
	done := make(chan struct{})

	go func() {
		var chat []string

		// Deferred calls run LIFO: history is saved first, then done is closed.
		defer close(done)
		defer func() {
			his := []*model.AiChatHi{{
				SessionID: requireData.Session,
				Role:      "user",
				Content:   userInput,
			}}
			if len(chat) > 0 {
				his = append(his, &model.AiChatHi{
					SessionID: requireData.Session,
					Role:      "assistant",
					Content:   strings.Join(chat, ""),
				})
			}
			for _, hi := range his {
				r.hisImpl.Add(hi)
			}
		}()

		var sendErr error
		for v := range requireData.Ch { // ends when the producer closes the channel
			if sendErr != nil {
				// Client unreachable: discard, but keep draining (see doc comment).
				continue
			}
			if sendErr = sendWithTimeout(c, v, 2*time.Second); sendErr != nil {
				log.Errorf("Send error: %v", sendErr)
				continue
			}
			if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson {
				chat = append(chat, v.Content)
			}
		}
	}()

	return done
}
 | 
						||
 | 
						||
// sendWithTimeout writes data to the websocket via entitys.MsgSend, giving up
// after timeout.
//
// The send runs in its own goroutine. A panic inside MsgSend is recovered and
// returned as an error. On timeout, context.DeadlineExceeded is returned. The
// send goroutine is not interrupted and may still complete later; its result
// goes into a buffered channel, so the goroutine does not block forever.
func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error {
	result := make(chan error, 1)

	go func() {
		defer func() {
			if p := recover(); p != nil {
				result <- fmt.Errorf("panic in MsgSend: %v", p)
			}
			close(result)
		}()
		// If MsgSend blocks, this goroutine blocks with it.
		result <- entitys.MsgSend(c, data)
	}()

	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case err := <-result:
		return err
	case <-timer.C:
		return context.DeadlineExceeded
	}
}
 | 
						||
 | 
						||
// getImgData downloads the image at imgUrl and appends its body to
// requireData.ImgByte. An empty imgUrl is a no-op.
//
// The URL must pass pkg.ValidateImageURL before any request is made. The
// response must carry a Content-Type header that starts with "image/";
// otherwise a parameter error is returned.
func (r *AiRouterBiz) getImgData(imgUrl string, requireData *entitys.RequireData) (err error) {
	if imgUrl == "" {
		return nil
	}
	if err = pkg.ValidateImageURL(imgUrl); err != nil {
		return err
	}

	imgReq := l_request.Request{
		Url:    imgUrl,
		Method: "GET",
	}
	res, err := imgReq.Send()
	if err != nil {
		return err
	}

	contentType, found := res.Headers["Content-Type"]
	if !found {
		return errors.ParamErr("图片格式错误:Content-Type未获取")
	}
	if !strings.HasPrefix(contentType, "image/") {
		return errors.ParamErr("expected image content, got %s", contentType)
	}

	requireData.ImgByte = append(requireData.ImgByte, res.Content)
	return nil
}
 | 
						||
 | 
						||
func (r *AiRouterBiz) recognize(ctx context.Context, requireData *entitys.RequireData) (err error) {
 | 
						||
	requireData.Ch <- entitys.Response{
 | 
						||
		Index:   "",
 | 
						||
		Content: "准备意图识别",
 | 
						||
		Type:    entitys.ResponseLog,
 | 
						||
	}
 | 
						||
	//意图识别
 | 
						||
	recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData)
 | 
						||
	if err != nil {
 | 
						||
		return
 | 
						||
	}
 | 
						||
	requireData.Ch <- entitys.Response{
 | 
						||
		Index:   "",
 | 
						||
		Content: recognizeMsg,
 | 
						||
		Type:    entitys.ResponseLog,
 | 
						||
	}
 | 
						||
	requireData.Ch <- entitys.Response{
 | 
						||
		Index:   "",
 | 
						||
		Content: "意图识别结束",
 | 
						||
		Type:    entitys.ResponseLog,
 | 
						||
	}
 | 
						||
	var match entitys.Match
 | 
						||
	if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil {
 | 
						||
		err = errors.SysErr("数据结构错误:%v", err.Error())
 | 
						||
		return
 | 
						||
	}
 | 
						||
	requireData.Match = &match
 | 
						||
	return
 | 
						||
}
 | 
						||
 | 
						||
func (r *AiRouterBiz) getRequireData(userInput string, requireData *entitys.RequireData) (err error) {
 | 
						||
	requireData.Sys, err = r.getSysInfo(requireData.Key)
 | 
						||
	if err != nil {
 | 
						||
		err = errors.SysErr("获取系统信息失败:%v", err.Error())
 | 
						||
		return
 | 
						||
	}
 | 
						||
	requireData.Histories, err = r.getSessionChatHis(requireData.Session)
 | 
						||
	if err != nil {
 | 
						||
		err = errors.SysErr("获取历史记录失败:%v", err.Error())
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	requireData.Tasks, err = r.getTasks(requireData.Sys.SysID)
 | 
						||
	if err != nil {
 | 
						||
		err = errors.SysErr("获取任务列表失败:%v", err.Error())
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	requireData.UserInput = userInput
 | 
						||
	if len(requireData.UserInput) == 0 {
 | 
						||
		err = errors.SysErr("获取用户输入失败")
 | 
						||
		return
 | 
						||
	}
 | 
						||
	if len(requireData.UserInput) == 0 {
 | 
						||
		err = errors.SysErr("获取用户输入失败")
 | 
						||
		return
 | 
						||
	}
 | 
						||
	return
 | 
						||
}
 | 
						||
 | 
						||
func (r *AiRouterBiz) dataAuth(c *websocket.Conn, requireData *entitys.RequireData) (err error) {
 | 
						||
	requireData.Session = c.Query("x-session", "")
 | 
						||
	if len(requireData.Session) == 0 {
 | 
						||
		err = errors.SessionNotFound
 | 
						||
		return
 | 
						||
	}
 | 
						||
	requireData.Auth = c.Query("x-authorization", "")
 | 
						||
	if len(requireData.Auth) == 0 {
 | 
						||
		err = errors.AuthNotFound
 | 
						||
		return
 | 
						||
	}
 | 
						||
	requireData.Key = c.Query("x-app-key", "")
 | 
						||
	if len(requireData.Key) == 0 {
 | 
						||
		err = errors.KeyNotFound
 | 
						||
		return
 | 
						||
	}
 | 
						||
	return
 | 
						||
}
 | 
						||
 | 
						||
// handleOtherTask is the fallback handler. It sends the model's Reasoning
// text to the client as a ResponseText message. It always returns nil.
func (r *AiRouterBiz) handleOtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
	reply := entitys.Response{
		Index:   "",
		Content: requireData.Match.Reasoning,
		Type:    entitys.ResponseText,
	}
	requireData.Ch <- reply
	return nil
}
 | 
						||
 | 
						||
// handleMatch dispatches the recognised intent stored in requireData.Match.
//
// When the model reports no match, its Chat text is sent to the client as a
// ResponseText message. Otherwise, the first task in requireData.Tasks whose
// Index equals Match.Index is handled according to its Type: API, function
// tool, or knowledge base. All of the following fall back to
// handleOtherTask:
//   - no task matches,
//   - the matched task's index is "other",
//   - the task type is unrecognised.
func (r *AiRouterBiz) handleMatch(ctx context.Context, requireData *entitys.RequireData) (err error) {
	match := requireData.Match
	if !match.IsMatch {
		requireData.Ch <- entitys.Response{
			Index:   "",
			Content: match.Chat,
			Type:    entitys.ResponseText,
		}
		return nil
	}

	var pointTask *model.AiTask
	for i := range requireData.Tasks {
		if requireData.Tasks[i].Index != match.Index {
			continue
		}
		// Handlers get a pointer to a copy, not into requireData.Tasks.
		// handleApiTask may modify the task it receives.
		found := requireData.Tasks[i]
		pointTask = &found
		break
	}

	switch {
	case pointTask == nil, pointTask.Index == "other":
		return r.handleOtherTask(ctx, requireData)
	case pointTask.Type == constants.TaskTypeApi:
		return r.handleApiTask(ctx, requireData, pointTask)
	case pointTask.Type == constants.TaskTypeFunc:
		return r.handleTask(ctx, requireData, pointTask)
	case pointTask.Type == constants.TaskTypeKnowle:
		return r.handleKnowle(ctx, requireData, pointTask)
	default:
		return r.handleOtherTask(ctx, requireData)
	}
}
 | 
						||
 | 
						||
// handleTask runs a function-type task. It decodes task.Config as
// entitys.ConfigDataTool and executes the named tool through r.toolManager.
// Decoding and execution errors are returned unchanged.
func (r *AiRouterBiz) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
	var configData entitys.ConfigDataTool
	if err = json.Unmarshal([]byte(task.Config), &configData); err != nil {
		return err
	}
	return r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
}
 | 
						||
 | 
						||
// handleKnowle answers the request through a knowledge-base tool.
//
// It performs these steps in order:
//  1. Decode the tool name from task.Config.
//  2. Read the "query" parameter extracted by intent recognition.
//  3. Load the chat session.
//  4. Resolve the knowledge-base host from the tool's config.
//  5. If the chat session has no knowledge-base session yet, create one and
//     store it.
//  6. Fill requireData.KnowledgeConf.
//  7. Execute the tool.
//
// The query is validated before any remote call or database write. A
// request without a usable query therefore has no side effects. A query that
// is present but blank is rejected just like a missing one.
func (r *AiRouterBiz) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
	var (
		configData         entitys.ConfigDataTool
		sessionIdKnowledge string
		query              string
		host               string
		has                bool
	)

	if err = json.Unmarshal([]byte(task.Config), &configData); err != nil {
		return err
	}

	// Parse the user input first: it is cheap and needs no I/O.
	input := make(map[string]string)
	if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil {
		return err
	}
	if query = input["query"]; strings.TrimSpace(query) == "" {
		return fmt.Errorf("query不能为空")
	}

	// Load the chat session the knowledge-base session is bound to.
	if len(requireData.Session) == 0 {
		return errors.SessionNotFound
	}
	requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session))
	if err != nil {
		return err
	}
	if !has {
		return errors.SessionNotFound
	}

	// Resolve the knowledge-base host from the configured tool.
	tool, exists := r.toolManager.GetTool(configData.Tool)
	if !exists {
		return fmt.Errorf("tool not found: %s", configData.Tool)
	}
	knowledgeTool, ok := tool.(*tools.KnowledgeBaseTool)
	if !ok {
		return fmt.Errorf("未找到知识库Tool: %s", configData.Tool)
	}
	host = knowledgeTool.GetConfig().BaseURL

	// No knowledge-base session bound yet: request one and store it on the
	// chat session so later requests reuse it.
	if requireData.SessionInfo.KnowlegeSessionID == "" {
		if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil {
			return err
		}
		requireData.SessionInfo.KnowlegeSessionID = sessionIdKnowledge
		if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil {
			return err
		}
	}

	requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{
		Session: requireData.SessionInfo.KnowlegeSessionID,
		ApiKey:  requireData.Sys.KnowlegeTenantKey,
		Query:   query,
	}

	return r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
}
 | 
						||
 | 
						||
// handleApiTask executes an HTTP API task. It sends the response body to the
// client as a ResponseJson message.
//
// task.Config is a JSON template for entitys.ConfigDataHttp. Two kinds of
// placeholder are substituted before the template is decoded:
//   - "${authorization}" becomes the caller's auth token.
//   - "${<name>}" becomes the matching parameter from
//     requireData.Match.Parameters, which must be a JSON object.
//
// The decoded Request section is mapped onto an l_request.Request and sent.
// An empty URL is rejected with a 422 business error. task is not modified.
func (r *AiRouterBiz) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
	// Decode parameters with UseNumber so numbers keep their literal text.
	// As float64, %v would render e.g. 20240101 as "2.0240101e+07".
	var requestParam map[string]interface{}
	dec := json.NewDecoder(strings.NewReader(string(requireData.Match.Parameters)))
	dec.UseNumber()
	if err = dec.Decode(&requestParam); err != nil {
		return err
	}
	if dec.More() {
		return fmt.Errorf("invalid api parameters: unexpected data after JSON object")
	}

	// Placeholders sit inside a JSON document, and parameter values come from
	// model output derived from user text. Escape each value as JSON string
	// content so a quote or backslash cannot break the document or inject
	// extra fields.
	escape := func(s string) string {
		quoted, mErr := json.Marshal(s)
		if mErr != nil {
			return s
		}
		return string(quoted[1 : len(quoted)-1])
	}

	// A single Replacer pass substitutes all placeholders at once, so a value
	// that itself contains "${...}" is never expanded again. The authorization
	// pair is listed first, so it takes precedence over a parameter of the
	// same name.
	pairs := make([]string, 0, 2*(len(requestParam)+1))
	pairs = append(pairs, "${authorization}", escape(requireData.Auth))
	for k, v := range requestParam {
		pairs = append(pairs, "${"+k+"}", escape(fmt.Sprintf("%v", v)))
	}
	configJSON := strings.NewReplacer(pairs...).Replace(task.Config)

	var configData entitys.ConfigDataHttp
	if err = json.Unmarshal([]byte(configJSON), &configData); err != nil {
		return err
	}

	var request l_request.Request
	if err = mapstructure.Decode(configData.Request, &request); err != nil {
		return err
	}
	if len(request.Url) == 0 {
		return errors.NewBusinessErr(422, "api地址获取失败")
	}

	res, err := request.Send()
	if err != nil {
		return err
	}

	requireData.Ch <- entitys.Response{
		Index:   "",
		Content: pkg.JsonStringIgonErr(res.Text),
		Type:    entitys.ResponseJson,
	}
	return nil
}
 | 
						||
 | 
						||
func (r *AiRouterBiz) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
 | 
						||
 | 
						||
	cond := builder.NewCond()
 | 
						||
	cond = cond.And(builder.Eq{"session_id": sessionId})
 | 
						||
 | 
						||
	_, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his, "his_id desc")
 | 
						||
 | 
						||
	return
 | 
						||
}
 | 
						||
 | 
						||
func (r *AiRouterBiz) getSysInfo(appKey string) (sysInfo model.AiSy, err error) {
 | 
						||
	cond := builder.NewCond()
 | 
						||
	cond = cond.And(builder.Eq{"app_key": appKey})
 | 
						||
	cond = cond.And(builder.IsNull{"delete_at"})
 | 
						||
	cond = cond.And(builder.Eq{"status": 1})
 | 
						||
	err = r.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo)
 | 
						||
	return
 | 
						||
}
 | 
						||
 | 
						||
func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
 | 
						||
 | 
						||
	cond := builder.NewCond()
 | 
						||
	cond = cond.And(builder.Eq{"sys_id": sysId})
 | 
						||
	cond = cond.And(builder.IsNull{"delete_at"})
 | 
						||
	cond = cond.And(builder.Eq{"status": 1})
 | 
						||
	_, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks, "")
 | 
						||
 | 
						||
	return
 | 
						||
}
 |