ai_scheduler/internal/biz/router.go

555 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package biz
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/constants"
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/mapstructure"
"ai_scheduler/internal/pkg/utils_ollama"
tools "ai_scheduler/internal/tools"
"ai_scheduler/tmpl/dataTemp"
"context"
"encoding/json"
"fmt"
"strings"
"time"
"gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/websocket/v2"
"github.com/ollama/ollama/api"
"github.com/tmc/langchaingo/llms"
"xorm.io/builder"
)
// AiRouterBiz is the intelligent routing service. It holds the collaborators
// used to recognise a user's intent and dispatch it to a matching task.
type AiRouterBiz struct {
//aiClient entitys.AIClient
toolManager *tools.Manager // looked up and executed by the task handlers
sessionImpl *impl.SessionImpl // session lookup/update (used for knowledge-base session binding)
sysImpl *impl.SysImpl // system lookup by app key
taskImpl *impl.TaskImpl // task list lookup by system id
hisImpl *impl.ChatImpl // chat history read/write
conf *config.Config // application configuration (e.g. Sys.SessionLen)
utilAgent *utils_ollama.UtilOllama // source of LLM clients via Get/Put
ollama *utils_ollama.Client
channelPool *pkg.SafeChannelPool
rds *pkg.Rdb // NOTE(review): not assigned by NewAiRouterBiz in this file — confirm whether it is set elsewhere
}
// NewAiRouterBiz builds an AiRouterBiz from its collaborators.
//
// Every argument is stored as-is on the returned value. The rds field of
// AiRouterBiz has no corresponding parameter and is left at its zero value.
func NewAiRouterBiz(
	// aiClient entitys.AIClient, (currently disabled)
	toolManager *tools.Manager,
	sessionImpl *impl.SessionImpl,
	sysImpl *impl.SysImpl,
	taskImpl *impl.TaskImpl,
	hisImpl *impl.ChatImpl,
	conf *config.Config,
	utilAgent *utils_ollama.UtilOllama,
	channelPool *pkg.SafeChannelPool,
	ollama *utils_ollama.Client,
) *AiRouterBiz {
	biz := new(AiRouterBiz)

	// Tooling and LLM access.
	biz.toolManager = toolManager
	biz.utilAgent = utilAgent
	biz.ollama = ollama
	biz.channelPool = channelPool

	// Data access layers.
	biz.sessionImpl = sessionImpl
	biz.sysImpl = sysImpl
	biz.taskImpl = taskImpl
	biz.hisImpl = hisImpl

	// Configuration.
	biz.conf = conf

	return biz
}
// Route is the request/response entry point for intelligent routing.
// It is currently a stub: it ignores its arguments and returns a nil
// response together with a nil error.
func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) {
	var resp *entitys.ChatResponse // intentionally nil: nothing is produced yet
	return resp, nil
}
// RouteWithSocket handles one chat request received over a websocket.
//
// Flow:
//  1. Read the session, authorization and app key from the connection query
//     string; a missing value yields SessionNotFound / AuthNotFound / KeyNotFound.
//  2. Start a single writer goroutine that forwards everything sent on the
//     response channel to the websocket and collects text/stream/json content.
//  3. Load system info, chat history and the task list, build the prompt and
//     ask the LLM (JSON mode) to recognise the intent.
//  4. Decode the LLM answer into entitys.Match and dispatch it via handleMatch.
//
// When the function returns, the response channel has been closed, the writer
// goroutine has finished, and the user message (plus the collected assistant
// reply, if any) has been handed to hisImpl.Add.
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
	session := c.Query("x-session", "")
	if len(session) == 0 {
		return errors.SessionNotFound
	}
	auth := c.Query("x-authorization", "")
	if len(auth) == 0 {
		return errors.AuthNotFound
	}
	key := c.Query("x-app-key", "")
	if len(key) == 0 {
		return errors.KeyNotFound
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	ch := make(chan entitys.Response)
	done := make(chan struct{})

	// Writer goroutine: the only place that forwards channel messages to the
	// websocket. It owns the collected reply, so no state is shared with the
	// caller's goroutine.
	go func() {
		defer close(done)

		var (
			chat       []string
			sendFailed bool
		)
		for v := range ch {
			if sendFailed {
				// The connection is unusable. Keep draining so producers that
				// send on the unbuffered channel never block forever.
				continue
			}
			// Send with a timeout so a stalled peer cannot block the loop.
			if sendErr := sendWithTimeout(c, v, 2*time.Second); sendErr != nil {
				log.Errorf("Send error: %v", sendErr)
				sendFailed = true
				cancel() // tell the main flow (e.g. the LLM call) to stop
				continue
			}
			if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson {
				chat = append(chat, v.Content)
			}
		}

		// Persist the exchange: always the user message, and the assistant
		// reply only when something was actually delivered.
		his := []*model.AiChatHi{
			{
				SessionID: session,
				Role:      "user",
				Content:   req.Text,
			},
		}
		if len(chat) > 0 {
			his = append(his, &model.AiChatHi{
				SessionID: session,
				Role:      "assistant",
				Content:   strings.Join(chat, ""),
			})
		}
		for _, hi := range his {
			r.hisImpl.Add(hi)
		}
	}()

	// Runs before the deferred cancel above: close the channel so the writer
	// terminates, then wait until it has flushed and persisted everything.
	defer func() {
		close(ch)
		<-done
	}()

	sysInfo, err := r.getSysInfo(key)
	if err != nil {
		return errors.SysErr("获取系统信息失败:%v", err.Error())
	}
	history, err := r.getSessionChatHis(session)
	if err != nil {
		return errors.SysErr("获取历史记录失败:%v", err.Error())
	}
	task, err := r.getTasks(sysInfo.SysID)
	if err != nil {
		return errors.SysErr("获取任务列表失败:%v", err.Error())
	}

	prompt := r.getPromptLLM(sysInfo, history, req.Text, task)

	agentClient := r.utilAgent.Get()
	ch <- entitys.Response{
		Index:   "",
		Content: "准备意图识别",
		Type:    entitys.ResponseLog,
	}
	match, err := agentClient.Llm.GenerateContent(
		ctx, // cancellable: aborted when the websocket send fails
		prompt,
		llms.WithJSONMode(),
	)
	r.utilAgent.Put(agentClient)

	// Check the error before touching the response; on failure match may be nil.
	if err != nil {
		log.Errorf("LLM error: %v", err)
		return errors.SystemError
	}
	if match == nil || len(match.Choices) == 0 {
		log.Errorf("LLM error: response contains no choices")
		return errors.SystemError
	}
	resMsg := match.Choices[0].Content

	ch <- entitys.Response{
		Index:   "",
		Content: resMsg,
		Type:    entitys.ResponseLog,
	}
	ch <- entitys.Response{
		Index:   "",
		Content: "意图识别结束",
		Type:    entitys.ResponseLog,
	}

	var matchJson entitys.Match
	if decodeErr := json.Unmarshal([]byte(resMsg), &matchJson); decodeErr != nil {
		log.Info(resMsg)
		return errors.SysErr("数据结构错误:%v", decodeErr.Error())
	}
	matchJson.History = pkg.JsonByteIgonErr(history)
	matchJson.UserInput = req.Text

	return r.handleMatch(ctx, c, ch, &matchJson, task, sysInfo)
}
// sendWithTimeout writes data to the websocket via entitys.MsgSend, giving up
// after timeout.
//
// The send runs in its own goroutine. A panic raised by MsgSend is recovered
// and reported as an error. If the timeout elapses first, the context error is
// returned; the result channel is buffered so the sending goroutine can still
// deliver its result and exit.
func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error {
	deadline, stop := context.WithTimeout(context.Background(), timeout)
	defer stop()

	result := make(chan error, 1)
	go func() {
		defer func() {
			if p := recover(); p != nil {
				result <- fmt.Errorf("panic in MsgSend: %v", p)
			}
		}()
		result <- entitys.MsgSend(c, data)
	}()

	select {
	case <-deadline.Done():
		return deadline.Err()
	case sendErr := <-result:
		return sendErr
	}
}
// handleOtherTask is the fallback handler: it publishes the model's reasoning
// as a text response on ch. The websocket connection c is not used. The send
// blocks until a receiver takes the message; the returned error is always nil.
func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.Response, matchJson *entitys.Match) (err error) {
	reply := entitys.Response{
		Type:    entitys.ResponseText,
		Content: matchJson.Reasoning,
		Index:   "",
	}
	ch <- reply
	return nil
}
// handleMatch dispatches a recognised intent to the handler for its task type.
//
//   - If the model reported no match, its reasoning is sent to the client as a
//     text response and nil is returned (or ctx.Err() if ctx was cancelled
//     before the message could be handed over).
//   - Otherwise the task whose Index equals matchJson.Index is looked up. A
//     missing task, the "other" task, or an unknown task type falls back to
//     handleOtherTask.
//   - API, function and knowledge-base tasks go to handleApiTask, handleTask
//     and handleKnowle respectively.
func (r *AiRouterBiz) handleMatch(ctx context.Context, c *websocket.Conn, ch chan entitys.Response, matchJson *entitys.Match, tasks []model.AiTask, sysInfo model.AiSy) (err error) {
	if !matchJson.IsMatch {
		// Go through the response channel instead of writing to the websocket
		// directly: the channel's consumer is the component that writes to the
		// connection, and a second concurrent writer is not safe. The select
		// keeps this from blocking forever once the consumer has stopped.
		select {
		case ch <- entitys.Response{
			Index:   "",
			Content: matchJson.Reasoning,
			Type:    entitys.ResponseText,
		}:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	var pointTask *model.AiTask
	for i := range tasks {
		if tasks[i].Index == matchJson.Index {
			// Work on a copy so handlers cannot modify the caller's slice.
			matched := tasks[i]
			pointTask = &matched
			break
		}
	}
	if pointTask == nil || pointTask.Index == "other" {
		return r.handleOtherTask(c, ch, matchJson)
	}

	switch pointTask.Type {
	case constants.TaskTypeApi:
		return r.handleApiTask(ch, c, matchJson, pointTask)
	case constants.TaskTypeFunc:
		return r.handleTask(ch, c, matchJson, pointTask)
	case constants.TaskTypeKnowle:
		return r.handleKnowle(ch, c, matchJson, pointTask, sysInfo)
	default:
		return r.handleOtherTask(c, ch, matchJson)
	}
}
// handleTask runs a function-type task: it decodes the task's tool
// configuration and executes the named tool with the parameters extracted by
// the model. It returns the JSON decoding error or whatever ExecuteTool returns.
func (r *AiRouterBiz) handleTask(channel chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
	var toolCfg entitys.ConfigDataTool
	if err = json.Unmarshal([]byte(task.Config), &toolCfg); err != nil {
		return err
	}

	params := []byte(matchJson.Parameters)
	return r.toolManager.ExecuteTool(channel, c, toolCfg.Tool, params, matchJson)
}
// handleKnowle runs a knowledge-base task.
//
// Steps:
//  1. Decode the task's tool configuration.
//  2. Load the chat session named by the "x-session" query parameter.
//  3. Resolve the configured tool, which must be a *tools.KnowledgeBaseTool,
//     and take its BaseURL as the knowledge-base host.
//  4. If the chat session has no knowledge-base session yet, request one and
//     store it on the chat session so later calls can reuse it.
//  5. Read the "query" parameter produced by the model and execute the tool.
//
// It returns SessionNotFound when the session parameter or record is missing,
// an error when the tool is absent or of the wrong kind, an error when the
// query is missing or empty, and otherwise any error from the calls above.
func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask, sysInfo model.AiSy) (err error) {
	var configData entitys.ConfigDataTool
	if err = json.Unmarshal([]byte(task.Config), &configData); err != nil {
		return err
	}

	// Find the chat session that the knowledge-base session is bound to.
	session := c.Query("x-session", "")
	if len(session) == 0 {
		return errors.SessionNotFound
	}
	sessionInfo, has, err := r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(session))
	if err != nil {
		return err
	}
	if !has {
		return errors.SessionNotFound
	}

	// Resolve the knowledge-base host from the configured tool.
	tool, exists := r.toolManager.GetTool(configData.Tool)
	if !exists {
		return fmt.Errorf("tool not found: %s", configData.Tool)
	}
	knowledgeTool, isKnowledgeTool := tool.(*tools.KnowledgeBaseTool)
	if !isKnowledgeTool {
		return fmt.Errorf("未找到知识库Tool: %s", configData.Tool)
	}
	host := knowledgeTool.GetConfig().BaseURL

	// No knowledge-base session yet: request one and bind it to the chat
	// session so it can be reused next time.
	if sessionInfo.KnowlegeSessionID == "" {
		knowledgeSessionID, sessErr := tools.GetKnowledgeBaseSession(host, sysInfo.KnowlegeBaseID, sysInfo.KnowlegeTenantKey)
		if sessErr != nil {
			return sessErr
		}
		sessionInfo.KnowlegeSessionID = knowledgeSessionID
		if err = r.sessionImpl.Update(&sessionInfo, r.sessionImpl.WithSessionId(sessionInfo.SessionID)); err != nil {
			return err
		}
	}

	// Parse the parameters extracted by the model.
	input := make(map[string]string)
	if err = json.Unmarshal([]byte(matchJson.Parameters), &input); err != nil {
		return err
	}
	// A missing key and an empty value are both rejected: the lookup yields ""
	// in either case, and an empty query cannot be answered.
	query := input["query"]
	if query == "" {
		return fmt.Errorf("query不能为空")
	}

	payload, err := json.Marshal(tools.KnowledgeBaseRequest{
		Session: sessionInfo.KnowlegeSessionID,
		ApiKey:  sysInfo.KnowlegeTenantKey,
		Query:   query,
	})
	if err != nil {
		return err
	}

	// Execute the knowledge-base tool with the assembled request.
	return r.toolManager.ExecuteTool(channel, c, configData.Tool, payload, nil)
}
// handleApiTask runs an API-type task.
//
// The task configuration is a JSON template. "${authorization}" is replaced
// with the connection's "x-authorization" query value, and "${name}" with the
// value of each parameter extracted by the model. The resulting JSON is
// decoded into an HTTP request description, the request is sent, and the
// response body is written to the websocket as a text message.
//
// The template is expanded on a local copy; task.Config is not modified. The
// channels argument is not used.
//
// It returns an error when the parameters or the expanded configuration are
// not valid JSON, when the request description cannot be decoded, a 422
// business error when no URL is configured, and otherwise any error from
// sending the request or writing the response.
func (r *AiRouterBiz) handleApiTask(channels chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
	var requestParam map[string]interface{}
	if err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam); err != nil {
		return err
	}

	// Expand placeholders. The authorization placeholder is filled first, from
	// the connection, so a model-supplied parameter with the same name cannot
	// override it.
	// NOTE(review): values are inserted verbatim into the JSON text; a value
	// containing quotes or backslashes can produce invalid or altered JSON.
	auth := c.Query("x-authorization", "")
	cfg := strings.ReplaceAll(task.Config, "${authorization}", auth)
	for k, v := range requestParam {
		cfg = strings.ReplaceAll(cfg, "${"+k+"}", fmt.Sprintf("%v", v))
	}

	var configData entitys.ConfigDataHttp
	if err = json.Unmarshal([]byte(cfg), &configData); err != nil {
		return err
	}

	var request l_request.Request
	if err = mapstructure.Decode(configData.Request, &request); err != nil {
		return err
	}
	if len(request.Url) == 0 {
		return errors.NewBusinessErr(422, "api地址获取失败")
	}

	res, err := request.Send()
	if err != nil {
		return err
	}

	// NOTE(review): this writes to the connection directly rather than through
	// the response channel — confirm no other goroutine writes concurrently.
	return c.WriteMessage(websocket.TextMessage, res.Content)
}
// getSessionChatHis loads chat history rows for sessionId, ordered by
// "his_id desc" and limited to conf.Sys.SessionLen entries.
func (r *AiRouterBiz) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
	cond := builder.NewCond().And(builder.Eq{"session_id": sessionId})
	page := &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}

	_, err = r.hisImpl.GetListToStruct(&cond, page, &his, "his_id desc")
	return his, err
}
// getSysInfo looks up the system registered under appKey. Only rows with a
// NULL delete_at and status = 1 are considered.
func (r *AiRouterBiz) getSysInfo(appKey string) (sysInfo model.AiSy, err error) {
	cond := builder.NewCond().
		And(builder.Eq{"app_key": appKey}).
		And(builder.IsNull{"delete_at"}).
		And(builder.Eq{"status": 1})

	err = r.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo)
	return sysInfo, err
}
// getTasks lists the tasks that belong to system sysId. Only rows with a NULL
// delete_at and status = 1 are returned; no paging or ordering is applied.
func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
	cond := builder.NewCond().
		And(builder.Eq{"sys_id": sysId}).
		And(builder.IsNull{"delete_at"}).
		And(builder.Eq{"status": 1})

	_, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks, "")
	return tasks, err
}
// registerTools converts tasks into llms.Tool function definitions: the task
// Index becomes the function name, Desc the description, and the Param section
// of the task's JSON config the parameter schema. Tasks whose config is not
// valid JSON are skipped silently. The result is never nil.
func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool {
	defs := make([]llms.Tool, 0)
	for i := range tasks {
		current := tasks[i]

		var cfg entitys.TaskConfig
		if json.Unmarshal([]byte(current.Config), &cfg) != nil {
			continue
		}

		fn := &llms.FunctionDefinition{
			Name:        current.Index,
			Description: current.Desc,
			Parameters:  cfg.Param,
		}
		defs = append(defs, llms.Tool{Type: "function", Function: fn})
	}
	return defs
}
// getPrompt assembles a three-message prompt: the system prompt, the chat
// history serialised as JSON under the "assistant" role, and the user's input.
func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message {
	systemText := r.buildSystemPrompt(sysInfo.SysPrompt)
	historyText := pkg.JsonStringIgonErr(r.buildAssistant(history))

	return []entitys.Message{
		{Role: "system", Content: systemText},
		{Role: "assistant", Content: historyText},
		{Role: "user", Content: reqInput},
	}
}
// getPromptOllama is the Ollama-typed counterpart of getPrompt: system prompt,
// JSON-serialised chat history under the "assistant" role, then the user's
// input, as api.Message values.
func (r *AiRouterBiz) getPromptOllama(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []api.Message {
	systemText := r.buildSystemPrompt(sysInfo.SysPrompt)
	historyText := pkg.JsonStringIgonErr(r.buildAssistant(history))

	return []api.Message{
		{Role: "system", Content: systemText},
		{Role: "assistant", Content: historyText},
		{Role: "user", Content: reqInput},
	}
}
// getPromptLLM assembles the langchaingo prompt used for intent recognition.
// It produces four messages, in order: the system prompt, the JSON-serialised
// chat history (tool role), the JSON-serialised tool definitions derived from
// tasks (tool role), and the user's input (human role).
func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
	// textMessage wraps a single text part in a message with the given role.
	textMessage := func(role llms.ChatMessageType, text string) llms.MessageContent {
		return llms.MessageContent{
			Role:  role,
			Parts: []llms.ContentPart{llms.TextPart(text)},
		}
	}

	systemText := r.buildSystemPrompt(sysInfo.SysPrompt)
	historyText := pkg.JsonStringIgonErr(r.buildAssistant(history))
	toolsText := pkg.JsonStringIgonErr(r.registerTools(tasks))

	return []llms.MessageContent{
		textMessage(llms.ChatMessageTypeSystem, systemText),
		textMessage(llms.ChatMessageTypeTool, historyText),
		textMessage(llms.ChatMessageTypeTool, toolsText),
		textMessage(llms.ChatMessageTypeHuman, reqInput),
	}
}
// buildSystemPrompt returns the system prompt to use. A non-empty prompt is
// returned unchanged; an empty one is replaced by the built-in default, which
// instructs the model to act as an intent router and answer with a bare JSON
// object containing index, confidence and reasoning.
func (r *AiRouterBiz) buildSystemPrompt(prompt string) string {
	if prompt != "" {
		return prompt
	}
	return "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式禁用markdown格式返回\n3.只返回json字符串不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"
}
// buildAssistant converts stored history rows into the ChatHis structure that
// is serialised into the prompt. The session id is taken from the first row
// that has a non-empty one, every row becomes a message with its creation time
// formatted as "2006-01-02 15:04:05", and the context is always set to
// language "zh-CN" and mode "technical_support".
func (r *AiRouterBiz) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
	chatHis.Context = entitys.HisContext{
		UserLanguage: "zh-CN",
		SystemMode:   "technical_support",
	}

	for i := range his {
		row := &his[i]
		if chatHis.SessionId == "" {
			chatHis.SessionId = row.SessionID
		}
		msg := entitys.HisMessage{
			Role:      row.Role,
			Content:   row.Content,
			Timestamp: row.CreateAt.Format("2006-01-02 15:04:05"),
		}
		chatHis.Messages = append(chatHis.Messages, msg)
	}
	return chatHis
}
// handleKnowledgeQA is the handler for the knowledge question-answering
// intent. It is currently a stub: it ignores its arguments and returns a nil
// response together with a nil error.
func (r *AiRouterBiz) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
	var resp *entitys.ChatResponse // intentionally nil: nothing is produced yet
	return resp, nil
}