ai_scheduler/internal/biz/router.go

458 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package biz
import (
"ai_scheduler/internal/biz/llm_service"
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/constants"
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/mapstructure"
"ai_scheduler/internal/tools"
"ai_scheduler/tmpl/dataTemp"
"context"
"encoding/json"
"fmt"
"strings"
"time"
"gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/websocket/v2"
"xorm.io/builder"
)
// AiRouterBiz is the intelligent routing service: it holds the dependencies
// used to recognise the intent of a chat message and dispatch it to the
// matching task handler.
type AiRouterBiz struct {
// toolManager looks up and executes registered tools.
toolManager *tools.Manager
// sessionImpl reads and updates chat session records.
sessionImpl *impl.SessionImpl
// sysImpl looks up the system record for an app key.
sysImpl *impl.SysImpl
// taskImpl lists the tasks configured for a system.
taskImpl *impl.TaskImpl
// hisImpl reads and appends chat history records.
hisImpl *impl.ChatImpl
// conf is the application configuration (Sys.SessionLen is read from it).
conf *config.Config
// rds is not assigned by NewAiRouterBiz and is not used in the code visible
// here. NOTE(review): confirm whether it is still needed.
rds *pkg.Rdb
// langChain is stored by the constructor; no use is visible in this section.
langChain *llm_service.LangChainService
// Ollama performs intent recognition.
Ollama *llm_service.OllamaService
}
// NewAiRouterBiz builds an AiRouterBiz from the supplied dependencies.
//
// Every argument is stored as-is on the returned value. The rds field has no
// corresponding argument and is left at its zero value.
func NewAiRouterBiz(
	toolManager *tools.Manager,
	sessionImpl *impl.SessionImpl,
	sysImpl *impl.SysImpl,
	taskImpl *impl.TaskImpl,
	hisImpl *impl.ChatImpl,
	conf *config.Config,
	langChain *llm_service.LangChainService,
	Ollama *llm_service.OllamaService,
) *AiRouterBiz {
	biz := new(AiRouterBiz)

	// Tool execution and configuration.
	biz.toolManager = toolManager
	biz.conf = conf

	// Storage accessors.
	biz.sessionImpl = sessionImpl
	biz.sysImpl = sysImpl
	biz.taskImpl = taskImpl
	biz.hisImpl = hisImpl

	// LLM back ends.
	biz.langChain = langChain
	biz.Ollama = Ollama

	return biz
}
// RouteWithSocket handles one chat request received over a websocket.
//
// It reads the required connection parameters, starts the goroutine that
// forwards responses to the client, and then runs three steps in order:
// loading the data needed for the request, intent recognition, and dispatch
// to the matching handler. The first failing step is logged and its error is
// returned.
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
	// Validate and collect the mandatory connection parameters.
	var requireData entitys.RequireData
	if err = r.dataAuth(c, &requireData); err != nil {
		return err
	}

	// Response channel and context shared by every step below.
	requireData.Ch = make(chan entitys.Response)
	ctx, cancel := context.WithCancel(context.Background())

	// The forwarding goroutine runs until the response channel is closed.
	done := r.startMessageHandler(ctx, c, &requireData)
	defer func() {
		close(requireData.Ch) // no more responses will be produced
		<-done                // wait for the forwarder to finish
		cancel()
	}()

	steps := []struct {
		logFormat string
		run       func() error
	}{
		// Load system info, history and tasks.
		{"SQL error: %v", func() error { return r.getRequireData(req.Text, &requireData) }},
		// Intent recognition.
		{"LLM error: %v", func() error { return r.recognize(ctx, &requireData) }},
		// Dispatch to the matched handler.
		{"Handle error: %v", func() error { return r.handleMatch(ctx, &requireData) }},
	}
	for _, step := range steps {
		if err = step.run(); err != nil {
			log.Errorf(step.logFormat, err)
			return err
		}
	}
	return nil
}
// startMessageHandler starts the goroutine that forwards every response
// received on requireData.Ch to the websocket connection c, and returns a
// channel that is closed once that goroutine has stopped reading.
//
// The goroutine runs until requireData.Ch is closed. If sending to the client
// fails, the error is logged once and all further responses are read and
// discarded instead of being sent: requireData.Ch is unbuffered, so if the
// goroutine simply returned, every producer still writing to the channel would
// block forever and the caller would never get to close it.
//
// After done is closed the goroutine stores the chat history: one "user"
// record with empty content (the user input is handled elsewhere) and, when
// any text/stream/json response was delivered, one "assistant" record holding
// the concatenated content.
func (r *AiRouterBiz) startMessageHandler(
	ctx context.Context,
	c *websocket.Conn,
	requireData *entitys.RequireData,
) <-chan struct{} {
	done := make(chan struct{})

	go func() {
		// Content of the responses that were actually delivered to the client.
		var chat []string

		defer func() {
			// Signal completion first; history is written afterwards.
			close(done)

			// Save the history records.
			his := []*model.AiChatHi{
				{
					SessionID: requireData.Session,
					Role:      "user",
					Content:   "", // user input is handled elsewhere
				},
			}
			if len(chat) > 0 {
				his = append(his, &model.AiChatHi{
					SessionID: requireData.Session,
					Role:      "assistant",
					Content:   strings.Join(chat, ""),
				})
			}
			for _, hi := range his {
				r.hisImpl.Add(hi)
			}
		}()

		sendFailed := false
		for v := range requireData.Ch { // ends when the channel is closed
			if sendFailed {
				// The client is unreachable: keep draining so that
				// producers never block on the unbuffered channel.
				continue
			}
			if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
				log.Errorf("Send error: %v", err)
				sendFailed = true
				continue
			}
			if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson {
				chat = append(chat, v.Content)
			}
		}
	}()

	return done
}
// sendWithTimeout sends data to the websocket connection via entitys.MsgSend
// and waits at most timeout for the call to finish.
//
// It returns the error from MsgSend, an error describing a panic raised by
// MsgSend, or the deadline context's error when the timeout elapses first.
// On timeout the sending goroutine is left running until MsgSend returns.
func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error {
	deadline, release := context.WithTimeout(context.Background(), timeout)
	defer release()

	// Buffered so the sender can always deliver its single result and exit,
	// even when nobody is waiting any more.
	result := make(chan error, 1)
	go func() {
		// Deferred calls run last-in first-out: the panic is reported first,
		// then the channel is closed.
		defer close(result)
		defer func() {
			if p := recover(); p != nil {
				result <- fmt.Errorf("panic in MsgSend: %v", p)
			}
		}()
		// This blocks for as long as MsgSend blocks.
		result <- entitys.MsgSend(c, data)
	}()

	select {
	case <-deadline.Done():
		return deadline.Err()
	case sendErr := <-result:
		return sendErr
	}
}
// recognize runs intent recognition for the current request and stores the
// decoded result in requireData.Match.
//
// Progress messages and the raw recognition output are pushed to
// requireData.Ch as log responses. An error from the recognition call is
// returned unchanged; a decoding failure is returned as a system error.
func (r *AiRouterBiz) recognize(ctx context.Context, requireData *entitys.RequireData) (err error) {
	requireData.Ch <- entitys.Response{
		Index:   "",
		Content: "准备意图识别",
		Type:    entitys.ResponseLog,
	}

	// Intent recognition.
	recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData)
	if err != nil {
		return err
	}

	requireData.Ch <- entitys.Response{
		Index:   "",
		Content: recognizeMsg,
		Type:    entitys.ResponseLog,
	}
	requireData.Ch <- entitys.Response{
		Index:   "",
		Content: "意图识别结束",
		Type:    entitys.ResponseLog,
	}

	// Decode through the address of the field. json.Unmarshal rejects
	// non-pointer and nil-pointer targets, so passing the field itself only
	// works when it is an already-allocated pointer; taking its address is
	// valid whether Match is declared as a struct or as a pointer (a nil
	// pointer is allocated by the decoder).
	if err = json.Unmarshal([]byte(recognizeMsg), &requireData.Match); err != nil {
		return errors.SysErr("数据结构错误:%v", err.Error())
	}
	return nil
}
// getRequireData fills requireData with everything the later steps need: the
// user input, the system record for requireData.Key, the chat history of
// requireData.Session and the tasks of that system.
//
// The user input is validated first so that an empty message is rejected
// before any lookup is made. Each failure is returned as a system error.
func (r *AiRouterBiz) getRequireData(userInput string, requireData *entitys.RequireData) (err error) {
	requireData.UserInput = userInput
	if len(requireData.UserInput) == 0 {
		return errors.SysErr("获取用户输入失败")
	}

	if requireData.Sys, err = r.getSysInfo(requireData.Key); err != nil {
		return errors.SysErr("获取系统信息失败:%v", err.Error())
	}

	if requireData.Histories, err = r.getSessionChatHis(requireData.Session); err != nil {
		return errors.SysErr("获取历史记录失败:%v", err.Error())
	}

	if requireData.Tasks, err = r.getTasks(requireData.Sys.SysID); err != nil {
		return errors.SysErr("获取任务列表失败:%v", err.Error())
	}

	return nil
}
// dataAuth copies the mandatory connection parameters from the websocket
// query string into requireData and rejects the request when one is missing.
//
// The parameters are checked in this order: "x-session", "x-authorization",
// "x-app-key". The first empty one yields SessionNotFound, AuthNotFound or
// KeyNotFound respectively.
func (r *AiRouterBiz) dataAuth(c *websocket.Conn, requireData *entitys.RequireData) (err error) {
	// param returns the query value, or "" when it is absent.
	param := func(name string) string {
		return c.Query(name, "")
	}

	if requireData.Session = param("x-session"); requireData.Session == "" {
		return errors.SessionNotFound
	}
	if requireData.Auth = param("x-authorization"); requireData.Auth == "" {
		return errors.AuthNotFound
	}
	if requireData.Key = param("x-app-key"); requireData.Key == "" {
		return errors.KeyNotFound
	}
	return nil
}
// handleOtherTask is the fallback handler: it sends the reasoning text of the
// recognition result to the client as a text response. It always returns nil.
func (r *AiRouterBiz) handleOtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
	reply := entitys.Response{
		Index:   "",
		Content: requireData.Match.Reasoning,
		Type:    entitys.ResponseText,
	}
	requireData.Ch <- reply
	return nil
}
// handleMatch dispatches the request according to the recognition result.
//
// When nothing matched, when the matched index is not among the loaded tasks,
// when the task index is "other", or when the task type is unknown, the
// fallback handler replies with the reasoning text. Otherwise the handler for
// the task type (API, function or knowledge base) is invoked.
func (r *AiRouterBiz) handleMatch(ctx context.Context, requireData *entitys.RequireData) (err error) {
	// handleOtherTask sends exactly the reply used for the no-match case.
	if !requireData.Match.IsMatch {
		return r.handleOtherTask(ctx, requireData)
	}

	// Locate the first task with the recognised index. The handlers get a
	// pointer to a copy, so changes they make do not reach requireData.Tasks.
	var matched *model.AiTask
	for i := range requireData.Tasks {
		if requireData.Tasks[i].Index != requireData.Match.Index {
			continue
		}
		candidate := requireData.Tasks[i]
		matched = &candidate
		break
	}

	if matched == nil || matched.Index == "other" {
		return r.handleOtherTask(ctx, requireData)
	}

	if matched.Type == constants.TaskTypeApi {
		return r.handleApiTask(ctx, requireData, matched)
	}
	if matched.Type == constants.TaskTypeFunc {
		return r.handleTask(ctx, requireData, matched)
	}
	if matched.Type == constants.TaskTypeKnowle {
		return r.handleKnowle(ctx, requireData, matched)
	}
	return r.handleOtherTask(ctx, requireData)
}
// handleTask runs a function-type task: it decodes the task configuration to
// find the tool name and executes that tool through the tool manager.
// A decoding error or the tool's error is returned unchanged.
func (r *AiRouterBiz) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
	var toolConf entitys.ConfigDataTool
	if err = json.Unmarshal([]byte(task.Config), &toolConf); err != nil {
		return err
	}

	err = r.toolManager.ExecuteTool(ctx, toolConf.Tool, requireData)
	return err
}
// handleKnowle runs a knowledge-base task.
//
// Steps:
//  1. decode the task configuration to obtain the tool name;
//  2. load the chat session record for requireData.Session;
//  3. resolve the knowledge-base tool and read its base URL;
//  4. if the session has no knowledge-base session yet, request one and store
//     it on the session record so it can be reused next time;
//  5. read the "query" parameter from the recognition result;
//  6. fill requireData.KnowledgeConf and execute the tool.
//
// It returns SessionNotFound when the session id is empty or unknown, and an
// error when the tool is missing or of the wrong type, or when "query" is
// absent or empty.
func (r *AiRouterBiz) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
	var configData entitys.ConfigDataTool
	if err = json.Unmarshal([]byte(task.Config), &configData); err != nil {
		return err
	}

	// Find the chat session the knowledge-base session is bound to.
	if len(requireData.Session) == 0 {
		return errors.SessionNotFound
	}
	var has bool
	requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session))
	if err != nil {
		return err
	}
	if !has {
		return errors.SessionNotFound
	}

	// Resolve the knowledge-base host from the tool configuration.
	tool, exists := r.toolManager.GetTool(configData.Tool)
	if !exists {
		return fmt.Errorf("tool not found: %s", configData.Tool)
	}
	knowledgeTool, ok := tool.(*tools.KnowledgeBaseTool)
	if !ok {
		return fmt.Errorf("未找到知识库Tool: %s", configData.Tool)
	}
	host := knowledgeTool.GetConfig().BaseURL

	// No knowledge-base session yet: request one and bind it to the session.
	if requireData.SessionInfo.KnowlegeSessionID == "" {
		var sessionIDKnowledge string
		sessionIDKnowledge, err = tools.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey)
		if err != nil {
			return err
		}
		requireData.SessionInfo.KnowlegeSessionID = sessionIDKnowledge
		if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil {
			return err
		}
	}

	// Extract the query from the recognised parameters.
	input := make(map[string]string)
	if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil {
		return err
	}
	// The error text says the query must not be empty, so an empty value is
	// rejected as well as a missing key.
	query, ok := input["query"]
	if !ok || query == "" {
		return fmt.Errorf("query不能为空")
	}

	requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{
		Session: requireData.SessionInfo.KnowlegeSessionID,
		ApiKey:  requireData.Sys.KnowlegeTenantKey,
		Query:   query,
	}

	// Execute the tool.
	return r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
}
// handleApiTask runs an API-type task.
//
// The task configuration is a JSON template. The placeholder
// "${authorization}" is replaced with requireData.Auth and every "${name}"
// with the recognised parameter of that name. The result is decoded into the
// HTTP request description, the request is sent, and the response body is
// pushed to requireData.Ch as a JSON response.
//
// The substitution works on a local copy of the template; task.Config is not
// modified. It returns a 422 business error when the resulting request has no
// URL, and otherwise any decoding or transport error unchanged.
func (r *AiRouterBiz) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
	// Keep each parameter as raw JSON text so that numbers keep the exact
	// digits they were written with (decoding into interface{} would turn
	// them into float64 and print large ids in exponent form).
	var requestParam map[string]json.RawMessage
	if err = json.Unmarshal([]byte(requireData.Match.Parameters), &requestParam); err != nil {
		return err
	}

	// The authorization value must go into the template itself, not into
	// request.Url: otherwise the placeholder is never replaced and the URL
	// check below can never fail.
	config := strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth)
	for name, raw := range requestParam {
		config = strings.ReplaceAll(config, "${"+name+"}", apiTaskPlaceholderText(raw))
	}

	var configData entitys.ConfigDataHttp
	if err = json.Unmarshal([]byte(config), &configData); err != nil {
		return err
	}

	var request l_request.Request
	if err = mapstructure.Decode(configData.Request, &request); err != nil {
		return err
	}
	if len(request.Url) == 0 {
		return errors.NewBusinessErr(422, "api地址获取失败")
	}

	res, err := request.Send()
	if err != nil {
		return err
	}

	requireData.Ch <- entitys.Response{
		Index:   "",
		Content: pkg.JsonStringIgonErr(res.Text),
		Type:    entitys.ResponseJson,
	}
	return nil
}

// apiTaskPlaceholderText renders one raw JSON parameter value as the text that
// replaces its placeholder inside the JSON task template.
//
// For a JSON string the surrounding quotes are removed and the content is kept
// in its JSON-escaped form, so quotes, backslashes and control characters in
// the value cannot break the template. Every other value (number, boolean,
// null, object, array) is inserted as its JSON text.
func apiTaskPlaceholderText(raw json.RawMessage) string {
	text := strings.TrimSpace(string(raw))
	if n := len(text); n >= 2 && text[0] == '"' && text[n-1] == '"' {
		return text[1 : n-1]
	}
	return text
}
// getSessionChatHis loads the chat history records whose session_id equals
// sessionId. The query is limited to conf.Sys.SessionLen records and asks for
// "his_id desc" ordering.
func (r *AiRouterBiz) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
	filter := builder.NewCond().And(builder.Eq{"session_id": sessionId})
	page := &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}

	_, err = r.hisImpl.GetListToStruct(&filter, page, &his, "his_id desc")
	return his, err
}
// getSysInfo loads the system record for appKey. Only records with a NULL
// delete_at and status 1 are considered.
func (r *AiRouterBiz) getSysInfo(appKey string) (sysInfo model.AiSy, err error) {
	filter := builder.NewCond().
		And(builder.Eq{"app_key": appKey}).
		And(builder.IsNull{"delete_at"}).
		And(builder.Eq{"status": 1})

	err = r.sysImpl.GetOneBySearchToStrut(&filter, &sysInfo)
	return sysInfo, err
}
// getTasks loads the tasks that belong to the system sysId. Only records with
// a NULL delete_at and status 1 are considered; no paging or ordering is
// requested.
func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
	filter := builder.NewCond().
		And(builder.Eq{"sys_id": sysId}).
		And(builder.IsNull{"delete_at"}).
		And(builder.Eq{"status": 1})

	_, err = r.taskImpl.GetListToStruct(&filter, nil, &tasks, "")
	return tasks, err
}