493 lines
12 KiB
Go
493 lines
12 KiB
Go
package biz
|
||
|
||
import (
|
||
"ai_scheduler/internal/biz/llm_service"
|
||
"ai_scheduler/internal/config"
|
||
"ai_scheduler/internal/data/constants"
|
||
errors "ai_scheduler/internal/data/error"
|
||
"ai_scheduler/internal/data/impl"
|
||
"ai_scheduler/internal/data/model"
|
||
"ai_scheduler/internal/entitys"
|
||
"ai_scheduler/internal/pkg"
|
||
"ai_scheduler/internal/pkg/mapstructure"
|
||
|
||
"ai_scheduler/internal/tools"
|
||
"ai_scheduler/tmpl/dataTemp"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"strings"
|
||
"time"
|
||
|
||
"gitea.cdlsxd.cn/self-tools/l_request"
|
||
"github.com/gofiber/fiber/v2/log"
|
||
"github.com/gofiber/websocket/v2"
|
||
"xorm.io/builder"
|
||
)
|
||
|
||
// AiRouterBiz 智能路由服务
|
||
type AiRouterBiz struct {
|
||
toolManager *tools.Manager
|
||
sessionImpl *impl.SessionImpl
|
||
sysImpl *impl.SysImpl
|
||
taskImpl *impl.TaskImpl
|
||
hisImpl *impl.ChatImpl
|
||
conf *config.Config
|
||
rds *pkg.Rdb
|
||
langChain *llm_service.LangChainService
|
||
Ollama *llm_service.OllamaService
|
||
}
|
||
|
||
// NewRouterService 创建路由服务
|
||
func NewAiRouterBiz(
|
||
toolManager *tools.Manager,
|
||
sessionImpl *impl.SessionImpl,
|
||
sysImpl *impl.SysImpl,
|
||
taskImpl *impl.TaskImpl,
|
||
hisImpl *impl.ChatImpl,
|
||
conf *config.Config,
|
||
langChain *llm_service.LangChainService,
|
||
Ollama *llm_service.OllamaService,
|
||
) *AiRouterBiz {
|
||
return &AiRouterBiz{
|
||
toolManager: toolManager,
|
||
sessionImpl: sessionImpl,
|
||
conf: conf,
|
||
sysImpl: sysImpl,
|
||
hisImpl: hisImpl,
|
||
taskImpl: taskImpl,
|
||
langChain: langChain,
|
||
Ollama: Ollama,
|
||
}
|
||
}
|
||
|
||
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
|
||
//必要数据验证和获取
|
||
var requireData entitys.RequireData
|
||
err = r.dataAuth(c, &requireData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
//初始化通道/上下文
|
||
requireData.Ch = make(chan entitys.Response)
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
// 启动独立的消息处理协程
|
||
done := r.startMessageHandler(ctx, c, &requireData)
|
||
defer func() {
|
||
close(requireData.Ch) //关闭主通道
|
||
<-done // 等待消息处理完成
|
||
cancel()
|
||
}()
|
||
|
||
//获取图片信息
|
||
err = r.getImgData(req.Img, &requireData)
|
||
if err != nil {
|
||
log.Errorf("GetImgData error: %v", err)
|
||
return
|
||
}
|
||
|
||
//获取系统信息
|
||
err = r.getRequireData(req.Text, &requireData)
|
||
if err != nil {
|
||
log.Errorf("SQL error: %v", err)
|
||
return
|
||
}
|
||
//意图识别
|
||
err = r.recognize(ctx, &requireData)
|
||
if err != nil {
|
||
log.Errorf("LLM error: %v", err)
|
||
return
|
||
}
|
||
//向下传递
|
||
if err = r.handleMatch(ctx, &requireData); err != nil {
|
||
log.Errorf("Handle error: %v", err)
|
||
return
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
// startMessageHandler 启动独立的消息处理协程
|
||
func (r *AiRouterBiz) startMessageHandler(
|
||
ctx context.Context,
|
||
c *websocket.Conn,
|
||
requireData *entitys.RequireData,
|
||
) <-chan struct{} {
|
||
done := make(chan struct{})
|
||
var chat []string
|
||
|
||
go func() {
|
||
defer func() {
|
||
close(done)
|
||
// 保存历史记录
|
||
var his = []*model.AiChatHi{
|
||
{
|
||
SessionID: requireData.Session,
|
||
Role: "user",
|
||
Content: "", // 用户输入在外部处理
|
||
},
|
||
}
|
||
if len(chat) > 0 {
|
||
his = append(his, &model.AiChatHi{
|
||
SessionID: requireData.Session,
|
||
Role: "assistant",
|
||
Content: strings.Join(chat, ""),
|
||
})
|
||
}
|
||
for _, hi := range his {
|
||
r.hisImpl.Add(hi)
|
||
}
|
||
}()
|
||
|
||
for v := range requireData.Ch { // 自动检测通道关闭
|
||
if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
|
||
log.Errorf("Send error: %v", err)
|
||
return
|
||
}
|
||
if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson {
|
||
chat = append(chat, v.Content)
|
||
}
|
||
}
|
||
}()
|
||
|
||
return done
|
||
}
|
||
|
||
// 辅助函数:带超时的 WebSocket 发送
|
||
func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error {
|
||
sendCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||
defer cancel()
|
||
|
||
done := make(chan error, 1)
|
||
go func() {
|
||
defer func() {
|
||
if r := recover(); r != nil {
|
||
done <- fmt.Errorf("panic in MsgSend: %v", r)
|
||
}
|
||
close(done)
|
||
}()
|
||
// 如果 MsgSend 阻塞,这里会卡住
|
||
err := entitys.MsgSend(c, data)
|
||
done <- err
|
||
}()
|
||
|
||
select {
|
||
case err := <-done:
|
||
return err
|
||
case <-sendCtx.Done():
|
||
return sendCtx.Err()
|
||
}
|
||
}
|
||
|
||
func (r *AiRouterBiz) getImgData(imgUrl string, requireData *entitys.RequireData) (err error) {
|
||
return
|
||
if len(imgUrl) == 0 {
|
||
return
|
||
}
|
||
if err = pkg.ValidateImageURL(imgUrl); err != nil {
|
||
return err
|
||
}
|
||
req := l_request.Request{
|
||
Method: "GET",
|
||
Url: imgUrl,
|
||
}
|
||
res, err := req.Send()
|
||
if err != nil {
|
||
return
|
||
}
|
||
if _, ex := res.Headers["Content-Type"]; !ex {
|
||
return errors.ParamErr("图片格式错误:Content-Type未获取")
|
||
}
|
||
if !strings.HasPrefix(res.Headers["Content-Type"], "image/") {
|
||
return errors.ParamErr("expected image content, got %s", res.Headers["Content-Type"])
|
||
}
|
||
requireData.ImgByte = append(requireData.ImgByte, res.Content)
|
||
return
|
||
}
|
||
|
||
func (r *AiRouterBiz) recognize(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
||
requireData.Ch <- entitys.Response{
|
||
Index: "",
|
||
Content: "准备意图识别",
|
||
Type: entitys.ResponseLog,
|
||
}
|
||
//意图识别
|
||
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
requireData.Ch <- entitys.Response{
|
||
Index: "",
|
||
Content: recognizeMsg,
|
||
Type: entitys.ResponseLog,
|
||
}
|
||
requireData.Ch <- entitys.Response{
|
||
Index: "",
|
||
Content: "意图识别结束",
|
||
Type: entitys.ResponseLog,
|
||
}
|
||
var match entitys.Match
|
||
if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil {
|
||
err = errors.SysErr("数据结构错误:%v", err.Error())
|
||
return
|
||
}
|
||
requireData.Match = &match
|
||
return
|
||
}
|
||
|
||
func (r *AiRouterBiz) getRequireData(userInput string, requireData *entitys.RequireData) (err error) {
|
||
requireData.Sys, err = r.getSysInfo(requireData.Key)
|
||
if err != nil {
|
||
err = errors.SysErr("获取系统信息失败:%v", err.Error())
|
||
return
|
||
}
|
||
requireData.Histories, err = r.getSessionChatHis(requireData.Session)
|
||
if err != nil {
|
||
err = errors.SysErr("获取历史记录失败:%v", err.Error())
|
||
return
|
||
}
|
||
|
||
requireData.Tasks, err = r.getTasks(requireData.Sys.SysID)
|
||
if err != nil {
|
||
err = errors.SysErr("获取任务列表失败:%v", err.Error())
|
||
return
|
||
}
|
||
|
||
requireData.UserInput = userInput
|
||
if len(requireData.UserInput) == 0 {
|
||
err = errors.SysErr("获取用户输入失败")
|
||
return
|
||
}
|
||
if len(requireData.UserInput) == 0 {
|
||
err = errors.SysErr("获取用户输入失败")
|
||
return
|
||
}
|
||
return
|
||
}
|
||
|
||
func (r *AiRouterBiz) dataAuth(c *websocket.Conn, requireData *entitys.RequireData) (err error) {
|
||
requireData.Session = c.Query("x-session", "")
|
||
if len(requireData.Session) == 0 {
|
||
err = errors.SessionNotFound
|
||
return
|
||
}
|
||
requireData.Auth = c.Query("x-authorization", "")
|
||
if len(requireData.Auth) == 0 {
|
||
err = errors.AuthNotFound
|
||
return
|
||
}
|
||
requireData.Key = c.Query("x-app-key", "")
|
||
if len(requireData.Key) == 0 {
|
||
err = errors.KeyNotFound
|
||
return
|
||
}
|
||
return
|
||
}
|
||
|
||
func (r *AiRouterBiz) handleOtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
||
requireData.Ch <- entitys.Response{
|
||
Index: "",
|
||
Content: requireData.Match.Reasoning,
|
||
Type: entitys.ResponseText,
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
func (r *AiRouterBiz) handleMatch(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
||
|
||
if !requireData.Match.IsMatch {
|
||
requireData.Ch <- entitys.Response{
|
||
Index: "",
|
||
Content: requireData.Match.Chat,
|
||
Type: entitys.ResponseText,
|
||
}
|
||
return
|
||
}
|
||
var pointTask *model.AiTask
|
||
for _, task := range requireData.Tasks {
|
||
if task.Index == requireData.Match.Index {
|
||
pointTask = &task
|
||
break
|
||
}
|
||
}
|
||
|
||
if pointTask == nil || pointTask.Index == "other" {
|
||
return r.handleOtherTask(ctx, requireData)
|
||
}
|
||
switch pointTask.Type {
|
||
case constants.TaskTypeApi:
|
||
return r.handleApiTask(ctx, requireData, pointTask)
|
||
case constants.TaskTypeFunc:
|
||
return r.handleTask(ctx, requireData, pointTask)
|
||
case constants.TaskTypeKnowle:
|
||
return r.handleKnowle(ctx, requireData, pointTask)
|
||
default:
|
||
return r.handleOtherTask(ctx, requireData)
|
||
}
|
||
}
|
||
|
||
func (r *AiRouterBiz) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
||
var configData entitys.ConfigDataTool
|
||
err = json.Unmarshal([]byte(task.Config), &configData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
// 知识库
|
||
func (r *AiRouterBiz) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
||
|
||
var (
|
||
configData entitys.ConfigDataTool
|
||
sessionIdKnowledge string
|
||
query string
|
||
host string
|
||
)
|
||
err = json.Unmarshal([]byte(task.Config), &configData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
// 通过session 找到知识库session
|
||
var has bool
|
||
if len(requireData.Session) == 0 {
|
||
return errors.SessionNotFound
|
||
}
|
||
requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session))
|
||
if err != nil {
|
||
return
|
||
} else if !has {
|
||
return errors.SessionNotFound
|
||
}
|
||
|
||
// 找到知识库的host
|
||
{
|
||
tool, exists := r.toolManager.GetTool(configData.Tool)
|
||
if !exists {
|
||
return fmt.Errorf("tool not found: %s", configData.Tool)
|
||
}
|
||
|
||
if knowledgeTool, ok := tool.(*tools.KnowledgeBaseTool); !ok {
|
||
return fmt.Errorf("未找到知识库Tool: %s", configData.Tool)
|
||
} else {
|
||
host = knowledgeTool.GetConfig().BaseURL
|
||
}
|
||
|
||
}
|
||
|
||
// 知识库的session为空,请求知识库获取, 并绑定
|
||
if requireData.SessionInfo.KnowlegeSessionID == "" {
|
||
// 请求知识库
|
||
if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil {
|
||
return
|
||
}
|
||
|
||
// 绑定知识库session,下次可以使用
|
||
requireData.SessionInfo.KnowlegeSessionID = sessionIdKnowledge
|
||
if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil {
|
||
return
|
||
}
|
||
}
|
||
|
||
// 用户输入解析
|
||
var ok bool
|
||
input := make(map[string]string)
|
||
if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil {
|
||
return
|
||
}
|
||
if query, ok = input["query"]; !ok {
|
||
return fmt.Errorf("query不能为空")
|
||
}
|
||
|
||
requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{
|
||
Session: requireData.SessionInfo.KnowlegeSessionID,
|
||
ApiKey: requireData.Sys.KnowlegeTenantKey,
|
||
Query: query,
|
||
}
|
||
|
||
// 执行工具
|
||
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
func (r *AiRouterBiz) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
||
var (
|
||
request l_request.Request
|
||
requestParam map[string]interface{}
|
||
)
|
||
err = json.Unmarshal([]byte(requireData.Match.Parameters), &requestParam)
|
||
if err != nil {
|
||
return
|
||
}
|
||
request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth)
|
||
for k, v := range requestParam {
|
||
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
|
||
}
|
||
var configData entitys.ConfigDataHttp
|
||
err = json.Unmarshal([]byte(task.Config), &configData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
err = mapstructure.Decode(configData.Request, &request)
|
||
if err != nil {
|
||
return
|
||
}
|
||
if len(request.Url) == 0 {
|
||
err = errors.NewBusinessErr(422, "api地址获取失败")
|
||
return
|
||
}
|
||
res, err := request.Send()
|
||
if err != nil {
|
||
return
|
||
}
|
||
requireData.Ch <- entitys.Response{
|
||
Index: "",
|
||
Content: pkg.JsonStringIgonErr(res.Text),
|
||
Type: entitys.ResponseJson,
|
||
}
|
||
return
|
||
}
|
||
|
||
func (r *AiRouterBiz) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
|
||
|
||
cond := builder.NewCond()
|
||
cond = cond.And(builder.Eq{"session_id": sessionId})
|
||
|
||
_, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his, "his_id desc")
|
||
|
||
return
|
||
}
|
||
|
||
func (r *AiRouterBiz) getSysInfo(appKey string) (sysInfo model.AiSy, err error) {
|
||
cond := builder.NewCond()
|
||
cond = cond.And(builder.Eq{"app_key": appKey})
|
||
cond = cond.And(builder.IsNull{"delete_at"})
|
||
cond = cond.And(builder.Eq{"status": 1})
|
||
err = r.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo)
|
||
return
|
||
}
|
||
|
||
func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
||
|
||
cond := builder.NewCond()
|
||
cond = cond.And(builder.Eq{"sys_id": sysId})
|
||
cond = cond.And(builder.IsNull{"delete_at"})
|
||
cond = cond.And(builder.Eq{"status": 1})
|
||
_, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks, "")
|
||
|
||
return
|
||
}
|