ai_scheduler/internal/biz/router.go

359 lines
9.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package biz
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/constants"
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/mapstructure"
"ai_scheduler/internal/pkg/utils_ollama"
"ai_scheduler/internal/tools"
"ai_scheduler/tmpl/dataTemp"
"context"
"encoding/json"
"fmt"
"strings"
"gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/websocket/v2"
"github.com/tmc/langchaingo/llms"
"xorm.io/builder"
)
// AiRouterBiz is the intelligent routing service: it asks an LLM which
// configured task matches a user message and dispatches to that task.
type AiRouterBiz struct {
//aiClient entitys.AIClient
// toolManager executes function-type tasks (see handleTask).
toolManager *tools.Manager
// sessionImpl is injected by the constructor; it is not read by the methods in this file.
sessionImpl *impl.SessionImpl
// sysImpl looks up the system record for an app key (see getSysInfo).
sysImpl *impl.SysImpl
// taskImpl lists the tasks configured for a system (see getTasks).
taskImpl *impl.TaskImpl
// hisImpl loads chat history rows for a session (see getSessionChatHis).
hisImpl *impl.ChatImpl
// conf supplies Sys.SessionLen, used as the history page limit.
conf *config.Config
// utilAgent exposes the Llm client used for intent prediction.
utilAgent *utils_ollama.UtilOllama
// channelPool hands out the []byte channels used to stream task output.
channelPool *pkg.SafeChannelPool
}
// NewAiRouterBiz builds an AiRouterBiz from its injected dependencies.
// Every argument is stored as-is on the returned value; nothing is validated.
func NewAiRouterBiz(
	//aiClient entitys.AIClient,
	toolManager *tools.Manager,
	sessionImpl *impl.SessionImpl,
	sysImpl *impl.SysImpl,
	taskImpl *impl.TaskImpl,
	hisImpl *impl.ChatImpl,
	conf *config.Config,
	utilAgent *utils_ollama.UtilOllama,
	channelPool *pkg.SafeChannelPool,
) *AiRouterBiz {
	biz := new(AiRouterBiz)
	//biz.aiClient = aiClient
	biz.toolManager = toolManager
	biz.sessionImpl = sessionImpl
	biz.sysImpl = sysImpl
	biz.taskImpl = taskImpl
	biz.hisImpl = hisImpl
	biz.conf = conf
	biz.utilAgent = utilAgent
	biz.channelPool = channelPool
	return biz
}
// Route is the request/response routing entry point.
// It is currently a stub: it ignores its arguments and returns a nil
// response together with a nil error.
func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) {
	var resp *entitys.ChatResponse
	return resp, nil
}
// RouteWithSocket runs intent routing for one websocket chat request and
// streams the outcome back over the connection.
//
// Flow:
//  1. Borrow a channel from the pool.
//  2. Read the X-Session, X-Authorization and X-App-Key headers; each must be
//     non-empty.
//  3. Load the system record, the session history and the system's tasks.
//  4. Ask the LLM (JSON mode) which task matches req.Text and echo its raw
//     answer to the client.
//  5. Run the matched task in a goroutine and forward everything it sends on
//     the channel to the client as text frames.
//
// On exit the deferred handler writes the error text (if any) followed by
// "EOF", then returns the channel to the pool. If the pool itself fails,
// the error is returned without writing anything to the socket.
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
	ch, err := r.channelPool.Get()
	if err != nil {
		return err
	}
	defer func() {
		if err != nil {
			// Best effort: the socket may already be broken.
			_ = c.WriteMessage(websocket.TextMessage, []byte(err.Error()))
		}
		_ = c.WriteMessage(websocket.TextMessage, []byte("EOF"))
		// NOTE(review): once the worker below has run, ch is closed when it is
		// put back; confirm SafeChannelPool tolerates/replaces closed channels.
		r.channelPool.Put(ch)
	}()

	session := c.Headers("X-Session", "")
	if len(session) == 0 {
		return errors.SessionNotFound
	}
	auth := c.Headers("X-Authorization", "")
	if len(auth) == 0 {
		return errors.AuthNotFound
	}
	key := c.Headers("X-App-Key", "")
	if len(key) == 0 {
		return errors.KeyNotFound
	}

	sysInfo, err := r.getSysInfo(key)
	if err != nil {
		return errors.SysNotFound
	}
	history, err := r.getSessionChatHis(session)
	if err != nil {
		return errors.SystemError
	}
	task, err := r.getTasks(sysInfo.SysID)
	if err != nil {
		return errors.SystemError
	}

	// Intent prediction.
	prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
	match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
		llms.WithJSONMode(),
	)
	if err != nil {
		return errors.SystemError
	}
	// Guard against an empty response; indexing Choices[0] would panic.
	if match == nil || len(match.Choices) == 0 {
		return errors.SystemError
	}
	content := match.Choices[0].Content
	log.Info(content)
	_ = c.WriteMessage(websocket.TextMessage, []byte(content))

	var matchJson entitys.Match
	if err = json.Unmarshal([]byte(content), &matchJson); err != nil {
		return errors.SystemError
	}

	// workErr is written only by the worker and only before close(ch), so the
	// read after the range loop below is ordered after the write.
	var workErr error
	go func() {
		defer func() {
			// Recover first, then close: the consumer must not observe the
			// closed channel before the error has been recorded.
			if p := recover(); p != nil {
				workErr = fmt.Errorf("recovered from panic: %v", p)
			}
			close(ch)
		}()
		workErr = r.handleMatch(c, ch, &matchJson, task)
	}()

	// Forward worker output. After the first write failure keep draining so
	// the worker is never left blocked on a send and the channel is not
	// returned to the pool while still in use.
	var writeErr error
	for v := range ch {
		if writeErr != nil {
			continue
		}
		writeErr = c.WriteMessage(websocket.TextMessage, v)
	}
	if writeErr != nil {
		return writeErr
	}

	_ = c.WriteMessage(websocket.TextMessage, []byte("结束"))
	return workErr
}
// handleOtherTask is the fallback handler: it sends the LLM's reasoning text
// on ch and always returns nil. The connection c is not used.
func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan []byte, matchJson *entitys.Match) (err error) {
	reasoning := []byte(matchJson.Reasoning)
	ch <- reasoning
	return nil
}
// handleMatch dispatches an LLM match result to the matching task handler.
//
// When matchJson.IsMatch is false the reasoning is sent on ch and nil is
// returned. Otherwise the first task whose Index equals matchJson.Index is
// selected; a missing task, the "other" index, or an unknown task type all
// fall back to handleOtherTask.
func (r *AiRouterBiz) handleMatch(c *websocket.Conn, ch chan []byte, matchJson *entitys.Match, tasks []model.AiTask) (err error) {
	if !matchJson.IsMatch {
		ch <- []byte(matchJson.Reasoning)
		return nil
	}

	var pointTask *model.AiTask
	for i := range tasks {
		if tasks[i].Index != matchJson.Index {
			continue
		}
		// Work on a copy so handlers never touch the caller's slice element.
		selected := tasks[i]
		pointTask = &selected
		break
	}

	if pointTask == nil || pointTask.Index == "other" {
		return r.handleOtherTask(c, ch, matchJson)
	}
	if pointTask.Type == constants.TaskTypeApi {
		return r.handleApiTask(ch, c, matchJson, pointTask)
	}
	if pointTask.Type == constants.TaskTypeFunc {
		return r.handleTask(ch, c, matchJson, pointTask)
	}
	return r.handleOtherTask(c, ch, matchJson)
}
// handleTask runs a function-type task: it decodes task.Config as a tool
// configuration and asks the tool manager to execute the named tool with the
// raw parameters taken from the match. Decode and execution errors are
// returned unchanged.
func (r *AiRouterBiz) handleTask(channel chan []byte, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
	var toolCfg entitys.ConfigDataTool
	if err = json.Unmarshal([]byte(task.Config), &toolCfg); err != nil {
		return err
	}
	params := []byte(matchJson.Parameters)
	err = r.toolManager.ExecuteTool(channel, c, toolCfg.Tool, params)
	return err
}
// handleApiTask runs an API-type task.
//
// task.Config is treated as a JSON template. The "${authorization}"
// placeholder is replaced with the connection's X-Authorization header, and
// every "${name}" placeholder is replaced with the matching entry of
// matchJson.Parameters (formatted with %v). The rendered config is decoded,
// its Request section is mapped onto an l_request.Request, the request is
// sent, and the response content is written to the client as a text frame.
//
// The channels parameter is accepted for signature parity with the other
// handlers but is not used. task itself is not modified.
//
// Returns any JSON/decoding error, a business error "00022" when the rendered
// request has no URL, the transport error from Send, or the websocket write
// error.
func (r *AiRouterBiz) handleApiTask(channels chan []byte, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
	var requestParam map[string]interface{}
	if err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam); err != nil {
		return err
	}

	// Render into a local copy. The authorization header is substituted first
	// so a parameter named "authorization" cannot override it.
	// NOTE(review): values are spliced in verbatim; a value containing quotes
	// or braces can break the JSON template.
	auth := c.Headers("X-Authorization", "")
	cfg := strings.ReplaceAll(task.Config, "${authorization}", auth)
	for k, v := range requestParam {
		cfg = strings.ReplaceAll(cfg, "${"+k+"}", fmt.Sprintf("%v", v))
	}

	var configData entitys.ConfigDataHttp
	if err = json.Unmarshal([]byte(cfg), &configData); err != nil {
		return err
	}

	var request l_request.Request
	if err = mapstructure.Decode(configData.Request, &request); err != nil {
		return err
	}
	if len(request.Url) == 0 {
		err = errors.NewBusinessErr("00022", "api地址获取失败")
		return err
	}

	res, err := request.Send()
	if err != nil {
		return err
	}
	return c.WriteMessage(websocket.TextMessage, res.Content)
}
// getSessionChatHis loads the chat history rows whose session_id equals
// sessionId, ordered by his_id ascending and limited to conf.Sys.SessionLen.
func (r *AiRouterBiz) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
	cond := builder.NewCond().And(builder.Eq{"session_id": sessionId})
	page := &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}
	_, err = r.hisImpl.GetListToStruct(&cond, page, &his, "his_id asc")
	return his, err
}
// getSysInfo fetches the system record for appKey, restricted to rows with
// a NULL delete_at and status = 1.
func (r *AiRouterBiz) getSysInfo(appKey string) (sysInfo model.AiSy, err error) {
	cond := builder.NewCond().
		And(builder.Eq{"app_key": appKey}).
		And(builder.IsNull{"delete_at"}).
		And(builder.Eq{"status": 1})
	err = r.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo)
	return sysInfo, err
}
// getTasks lists the tasks of system sysId, restricted to rows with a NULL
// delete_at and status = 1. No paging or ordering is applied.
func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
	cond := builder.NewCond().
		And(builder.Eq{"sys_id": sysId}).
		And(builder.IsNull{"delete_at"}).
		And(builder.Eq{"status": 1})
	_, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks, "")
	return tasks, err
}
// registerTools converts tasks into LLM tool definitions of type "function".
// The task's Index becomes the function name, Desc its description, and the
// Param section of the task's JSON config its parameters. Tasks whose config
// cannot be decoded are skipped. The result is never nil.
func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool {
	toolDefs := make([]llms.Tool, 0)
	for _, task := range tasks {
		var cfg entitys.TaskConfig
		if json.Unmarshal([]byte(task.Config), &cfg) != nil {
			continue
		}
		fn := &llms.FunctionDefinition{
			Name:        task.Index,
			Description: task.Desc,
			Parameters:  cfg.Param,
		}
		toolDefs = append(toolDefs, llms.Tool{Type: "function", Function: fn})
	}
	return toolDefs
}
// getPrompt assembles a three-message prompt: the system prompt, the
// JSON-encoded chat history as an assistant message, and the user's input.
func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message {
	systemMsg := entitys.Message{
		Role:    "system",
		Content: r.buildSystemPrompt(sysInfo.SysPrompt),
	}
	historyMsg := entitys.Message{
		Role:    "assistant",
		Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
	}
	userMsg := entitys.Message{
		Role:    "user",
		Content: reqInput,
	}
	return []entitys.Message{systemMsg, historyMsg, userMsg}
}
// getPromptLLM assembles the langchaingo prompt used for intent prediction.
// The messages are, in order: the system prompt (system role), the
// JSON-encoded chat history (tool role), the JSON-encoded tool definitions
// built from tasks (tool role), and the user's input (human role).
func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
	// textMessage wraps a single text part in a message of the given role.
	textMessage := func(role llms.ChatMessageType, text string) llms.MessageContent {
		return llms.MessageContent{
			Role:  role,
			Parts: []llms.ContentPart{llms.TextPart(text)},
		}
	}

	systemText := r.buildSystemPrompt(sysInfo.SysPrompt)
	historyText := pkg.JsonStringIgonErr(r.buildAssistant(history))
	toolsText := pkg.JsonStringIgonErr(r.registerTools(tasks))

	return []llms.MessageContent{
		textMessage(llms.ChatMessageTypeSystem, systemText),
		textMessage(llms.ChatMessageTypeTool, historyText),
		textMessage(llms.ChatMessageTypeTool, toolsText),
		textMessage(llms.ChatMessageTypeHuman, reqInput),
	}
}
// buildSystemPrompt returns prompt unchanged when it is non-empty; otherwise
// it returns the built-in default routing prompt, which instructs the model
// to answer with a bare JSON object (index, confidence, reasoning).
func (r *AiRouterBiz) buildSystemPrompt(prompt string) string {
	if prompt != "" {
		return prompt
	}
	return "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式禁用markdown格式返回\n3.只返回json字符串不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"
}
// buildAssistant converts stored history rows into the ChatHis structure that
// is serialized into the prompt. SessionId is taken from the first row that
// sets it; each row becomes one message with its timestamp formatted as
// "2006-01-02 15:04:05". The context is always zh-CN / technical_support.
func (r *AiRouterBiz) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
	for i := range his {
		row := his[i]
		if len(chatHis.SessionId) == 0 {
			chatHis.SessionId = row.SessionID
		}
		msg := entitys.HisMessage{
			Role:      row.Role,
			Content:   row.Content,
			Timestamp: row.CreateAt.Format("2006-01-02 15:04:05"),
		}
		chatHis.Messages = append(chatHis.Messages, msg)
	}
	chatHis.Context = entitys.HisContext{
		UserLanguage: "zh-CN",
		SystemMode:   "technical_support",
	}
	return chatHis
}
// handleKnowledgeQA is the handler for the knowledge question-answering
// intent. It is currently a stub: it ignores its arguments and returns a nil
// response together with a nil error.
func (r *AiRouterBiz) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
	var resp *entitys.ChatResponse
	return resp, nil
}