549 lines
14 KiB
Go
549 lines
14 KiB
Go
package biz
|
||
|
||
import (
|
||
"ai_scheduler/internal/config"
|
||
"ai_scheduler/internal/data/constants"
|
||
errors "ai_scheduler/internal/data/error"
|
||
"ai_scheduler/internal/data/impl"
|
||
"ai_scheduler/internal/data/model"
|
||
"ai_scheduler/internal/entitys"
|
||
"ai_scheduler/internal/pkg"
|
||
"ai_scheduler/internal/pkg/mapstructure"
|
||
"ai_scheduler/internal/pkg/utils_ollama"
|
||
tools "ai_scheduler/internal/tools"
|
||
"ai_scheduler/tmpl/dataTemp"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"strings"
|
||
"time"
|
||
|
||
"gitea.cdlsxd.cn/self-tools/l_request"
|
||
"github.com/gofiber/fiber/v2/log"
|
||
"github.com/gofiber/websocket/v2"
|
||
"github.com/ollama/ollama/api"
|
||
"github.com/tmc/langchaingo/llms"
|
||
"xorm.io/builder"
|
||
)
|
||
|
||
// AiRouterBiz 智能路由服务
|
||
type AiRouterBiz struct {
|
||
//aiClient entitys.AIClient
|
||
toolManager *tools.Manager
|
||
sessionImpl *impl.SessionImpl
|
||
sysImpl *impl.SysImpl
|
||
taskImpl *impl.TaskImpl
|
||
hisImpl *impl.ChatImpl
|
||
conf *config.Config
|
||
utilAgent *utils_ollama.UtilOllama
|
||
ollama *utils_ollama.Client
|
||
channelPool *pkg.SafeChannelPool
|
||
rds *pkg.Rdb
|
||
}
|
||
|
||
// NewRouterService 创建路由服务
|
||
func NewAiRouterBiz(
|
||
//aiClient entitys.AIClient,
|
||
toolManager *tools.Manager,
|
||
sessionImpl *impl.SessionImpl,
|
||
sysImpl *impl.SysImpl,
|
||
taskImpl *impl.TaskImpl,
|
||
hisImpl *impl.ChatImpl,
|
||
conf *config.Config,
|
||
utilAgent *utils_ollama.UtilOllama,
|
||
channelPool *pkg.SafeChannelPool,
|
||
ollama *utils_ollama.Client,
|
||
|
||
) *AiRouterBiz {
|
||
return &AiRouterBiz{
|
||
//aiClient: aiClient,
|
||
toolManager: toolManager,
|
||
sessionImpl: sessionImpl,
|
||
conf: conf,
|
||
sysImpl: sysImpl,
|
||
hisImpl: hisImpl,
|
||
taskImpl: taskImpl,
|
||
utilAgent: utilAgent,
|
||
channelPool: channelPool,
|
||
ollama: ollama,
|
||
}
|
||
}
|
||
|
||
// Route 执行智能路由
|
||
func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) {
|
||
|
||
return nil, nil
|
||
}
|
||
|
||
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
|
||
|
||
session := c.Query("x-session", "")
|
||
if len(session) == 0 {
|
||
return errors.SessionNotFound
|
||
}
|
||
auth := c.Query("x-authorization", "")
|
||
if len(auth) == 0 {
|
||
return errors.AuthNotFound
|
||
}
|
||
key := c.Query("x-app-key", "")
|
||
if len(key) == 0 {
|
||
return errors.KeyNotFound
|
||
}
|
||
|
||
var chat = make([]string, 0)
|
||
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
defer cancel()
|
||
//ch := r.channelPool.Get()
|
||
ch := make(chan entitys.Response)
|
||
done := make(chan struct{})
|
||
go func() {
|
||
defer func() {
|
||
close(done)
|
||
if len(chat) > 0 {
|
||
|
||
}
|
||
var his = []*model.AiChatHi{
|
||
{
|
||
SessionID: session,
|
||
Role: "user",
|
||
Content: req.Text,
|
||
},
|
||
}
|
||
if len(chat) > 0 {
|
||
his = append(his, &model.AiChatHi{
|
||
SessionID: session,
|
||
Role: "assistant",
|
||
Content: strings.Join(chat, ""),
|
||
})
|
||
}
|
||
for _, hi := range his {
|
||
r.hisImpl.Add(hi)
|
||
}
|
||
}()
|
||
for {
|
||
select {
|
||
case v, ok := <-ch:
|
||
if !ok {
|
||
return
|
||
}
|
||
// 带超时的发送,避免阻塞
|
||
if err = sendWithTimeout(c, v, 2*time.Second); err != nil {
|
||
log.Errorf("Send error: %v", err)
|
||
cancel() // 通知主流程退出
|
||
return
|
||
}
|
||
if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream {
|
||
chat = append(chat, v.Content)
|
||
}
|
||
|
||
case <-ctx.Done():
|
||
return
|
||
}
|
||
}
|
||
}()
|
||
|
||
defer func() {
|
||
close(ch)
|
||
}()
|
||
|
||
sysInfo, err := r.getSysInfo(key)
|
||
if err != nil {
|
||
return errors.SysNotFound
|
||
}
|
||
|
||
history, err := r.getSessionChatHis(session)
|
||
if err != nil {
|
||
return errors.SystemError
|
||
}
|
||
|
||
task, err := r.getTasks(sysInfo.SysID)
|
||
if err != nil {
|
||
return errors.SystemError
|
||
}
|
||
|
||
prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
|
||
|
||
AgentClient := r.utilAgent.Get()
|
||
|
||
ch <- entitys.Response{
|
||
Index: "",
|
||
Content: "准备意图识别",
|
||
Type: entitys.ResponseLog,
|
||
}
|
||
|
||
match, err := AgentClient.Llm.GenerateContent(
|
||
ctx, // 使用可取消的上下文
|
||
prompt,
|
||
llms.WithJSONMode(),
|
||
)
|
||
resMsg := match.Choices[0].Content
|
||
|
||
r.utilAgent.Put(AgentClient)
|
||
ch <- entitys.Response{
|
||
Index: "",
|
||
Content: resMsg,
|
||
Type: entitys.ResponseLog,
|
||
}
|
||
ch <- entitys.Response{
|
||
Index: "",
|
||
Content: "意图识别结束",
|
||
Type: entitys.ResponseLog,
|
||
}
|
||
|
||
if err != nil {
|
||
log.Errorf("LLM error: %v", err)
|
||
return errors.SystemError
|
||
}
|
||
var matchJson entitys.Match
|
||
if err := json.Unmarshal([]byte(resMsg), &matchJson); err != nil {
|
||
return errors.SystemError
|
||
}
|
||
|
||
if err := r.handleMatch(ctx, c, ch, &matchJson, task, sysInfo); err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// 辅助函数:带超时的 WebSocket 发送
|
||
func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error {
|
||
sendCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||
defer cancel()
|
||
|
||
done := make(chan error, 1)
|
||
go func() {
|
||
done <- entitys.MsgSend(c, data)
|
||
}()
|
||
|
||
select {
|
||
case err := <-done:
|
||
return err
|
||
case <-sendCtx.Done():
|
||
return sendCtx.Err()
|
||
}
|
||
}
|
||
func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.Response, matchJson *entitys.Match) (err error) {
|
||
ch <- entitys.Response{
|
||
Index: "",
|
||
Content: matchJson.Reasoning,
|
||
Type: entitys.ResponseText,
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
func (r *AiRouterBiz) handleMatch(ctx context.Context, c *websocket.Conn, ch chan entitys.Response, matchJson *entitys.Match, tasks []model.AiTask, sysInfo model.AiSy) (err error) {
|
||
|
||
if !matchJson.IsMatch {
|
||
ch <- entitys.Response{
|
||
Index: "",
|
||
Content: matchJson.Reasoning,
|
||
Type: entitys.ResponseText,
|
||
}
|
||
return
|
||
}
|
||
var pointTask *model.AiTask
|
||
for _, task := range tasks {
|
||
if task.Index == matchJson.Index {
|
||
pointTask = &task
|
||
break
|
||
}
|
||
}
|
||
|
||
if pointTask == nil || pointTask.Index == "other" {
|
||
return r.handleOtherTask(c, ch, matchJson)
|
||
}
|
||
switch pointTask.Type {
|
||
case constants.TaskTypeApi:
|
||
return r.handleApiTask(ch, c, matchJson, pointTask)
|
||
case constants.TaskTypeFunc:
|
||
return r.handleTask(ch, c, matchJson, pointTask)
|
||
case constants.TaskTypeKnowle:
|
||
return r.handleKnowle(ch, c, matchJson, pointTask, sysInfo)
|
||
default:
|
||
return r.handleOtherTask(c, ch, matchJson)
|
||
}
|
||
}
|
||
|
||
func (r *AiRouterBiz) handleTask(channel chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
|
||
|
||
var configData entitys.ConfigDataTool
|
||
err = json.Unmarshal([]byte(task.Config), &configData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
err = r.toolManager.ExecuteTool(channel, c, configData.Tool, []byte(matchJson.Parameters))
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
// 知识库
|
||
func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask, sysInfo model.AiSy) (err error) {
|
||
|
||
var (
|
||
configData entitys.ConfigDataTool
|
||
sessionIdKnowledge string
|
||
query string
|
||
host string
|
||
)
|
||
err = json.Unmarshal([]byte(task.Config), &configData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
// 通过session 找到知识库session
|
||
session := c.Headers("X-Session", "")
|
||
if len(session) == 0 {
|
||
return errors.SessionNotFound
|
||
}
|
||
sessionInfo, has, err := r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(session))
|
||
if err != nil {
|
||
return
|
||
} else if !has {
|
||
return errors.SessionNotFound
|
||
}
|
||
|
||
// 找到知识库的host
|
||
{
|
||
tool, exists := r.toolManager.GetTool(configData.Tool)
|
||
if !exists {
|
||
return fmt.Errorf("tool not found: %s", configData.Tool)
|
||
}
|
||
|
||
if knowledgeTool, ok := tool.(*tools.KnowledgeBaseTool); !ok {
|
||
return fmt.Errorf("未找到知识库Tool: %s", configData.Tool)
|
||
} else {
|
||
host = knowledgeTool.GetConfig().BaseURL
|
||
}
|
||
|
||
}
|
||
|
||
// 知识库的session为空,请求知识库获取, 并绑定
|
||
if sessionInfo.KnowlegeSessionID == "" {
|
||
// 请求知识库
|
||
if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, sysInfo.KnowlegeBaseID, sysInfo.KnowlegeTenantKey); err != nil {
|
||
return
|
||
}
|
||
|
||
// 绑定知识库session,下次可以使用
|
||
sessionInfo.KnowlegeSessionID = sessionIdKnowledge
|
||
if err = r.sessionImpl.Update(&sessionInfo, r.sessionImpl.WithSessionId(sessionInfo.SessionID)); err != nil {
|
||
return
|
||
}
|
||
}
|
||
|
||
// 用户输入解析
|
||
var ok bool
|
||
input := make(map[string]string)
|
||
if err = json.Unmarshal([]byte(matchJson.Parameters), &input); err != nil {
|
||
return
|
||
}
|
||
if query, ok = input["query"]; !ok {
|
||
return fmt.Errorf("query不能为空")
|
||
}
|
||
|
||
knowledgeConfig := tools.KnowledgeBaseRequest{
|
||
Session: sessionInfo.KnowlegeSessionID,
|
||
ApiKey: sysInfo.KnowlegeTenantKey,
|
||
Query: query,
|
||
}
|
||
b, err := json.Marshal(knowledgeConfig)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
// 执行工具
|
||
err = r.toolManager.ExecuteTool(channel, c, configData.Tool, b)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
func (r *AiRouterBiz) handleApiTask(channels chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
|
||
var (
|
||
request l_request.Request
|
||
auth = c.Query("x-authorization", "")
|
||
requestParam map[string]interface{}
|
||
)
|
||
err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam)
|
||
if err != nil {
|
||
return
|
||
}
|
||
request.Url = strings.ReplaceAll(task.Config, "${authorization}", auth)
|
||
for k, v := range requestParam {
|
||
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
|
||
}
|
||
var configData entitys.ConfigDataHttp
|
||
err = json.Unmarshal([]byte(task.Config), &configData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
err = mapstructure.Decode(configData.Request, &request)
|
||
if err != nil {
|
||
return
|
||
}
|
||
if len(request.Url) == 0 {
|
||
err = errors.NewBusinessErr(422, "api地址获取失败")
|
||
return
|
||
}
|
||
res, err := request.Send()
|
||
if err != nil {
|
||
return
|
||
}
|
||
c.WriteMessage(1, res.Content)
|
||
return
|
||
}
|
||
|
||
func (r *AiRouterBiz) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
|
||
|
||
cond := builder.NewCond()
|
||
cond = cond.And(builder.Eq{"session_id": sessionId})
|
||
|
||
_, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his, "his_id asc")
|
||
|
||
return
|
||
}
|
||
|
||
func (r *AiRouterBiz) getSysInfo(appKey string) (sysInfo model.AiSy, err error) {
|
||
cond := builder.NewCond()
|
||
cond = cond.And(builder.Eq{"app_key": appKey})
|
||
cond = cond.And(builder.IsNull{"delete_at"})
|
||
cond = cond.And(builder.Eq{"status": 1})
|
||
err = r.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo)
|
||
return
|
||
}
|
||
|
||
func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
||
|
||
cond := builder.NewCond()
|
||
cond = cond.And(builder.Eq{"sys_id": sysId})
|
||
cond = cond.And(builder.IsNull{"delete_at"})
|
||
cond = cond.And(builder.Eq{"status": 1})
|
||
_, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks, "")
|
||
|
||
return
|
||
}
|
||
|
||
func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool {
|
||
taskPrompt := make([]llms.Tool, 0)
|
||
for _, task := range tasks {
|
||
var taskConfig entitys.TaskConfig
|
||
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
taskPrompt = append(taskPrompt, llms.Tool{
|
||
Type: "function",
|
||
Function: &llms.FunctionDefinition{
|
||
Name: task.Index,
|
||
Description: task.Desc,
|
||
Parameters: taskConfig.Param,
|
||
},
|
||
})
|
||
|
||
}
|
||
return taskPrompt
|
||
}
|
||
|
||
func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message {
|
||
var (
|
||
prompt = make([]entitys.Message, 0)
|
||
)
|
||
prompt = append(prompt, entitys.Message{
|
||
Role: "system",
|
||
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
|
||
}, entitys.Message{
|
||
Role: "assistant",
|
||
Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
|
||
}, entitys.Message{
|
||
Role: "user",
|
||
Content: reqInput,
|
||
})
|
||
return prompt
|
||
}
|
||
|
||
func (r *AiRouterBiz) getPromptOllama(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []api.Message {
|
||
var (
|
||
prompt = make([]api.Message, 0)
|
||
)
|
||
prompt = append(prompt, api.Message{
|
||
Role: "system",
|
||
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
|
||
}, api.Message{
|
||
Role: "assistant",
|
||
Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
|
||
}, api.Message{
|
||
Role: "user",
|
||
Content: reqInput,
|
||
})
|
||
return prompt
|
||
}
|
||
|
||
func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
||
var (
|
||
prompt = make([]llms.MessageContent, 0)
|
||
)
|
||
prompt = append(prompt, llms.MessageContent{
|
||
Role: llms.ChatMessageTypeSystem,
|
||
Parts: []llms.ContentPart{
|
||
llms.TextPart(r.buildSystemPrompt(sysInfo.SysPrompt)),
|
||
},
|
||
}, llms.MessageContent{
|
||
Role: llms.ChatMessageTypeTool,
|
||
Parts: []llms.ContentPart{
|
||
llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))),
|
||
},
|
||
}, llms.MessageContent{
|
||
Role: llms.ChatMessageTypeTool,
|
||
Parts: []llms.ContentPart{
|
||
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
|
||
},
|
||
}, llms.MessageContent{
|
||
Role: llms.ChatMessageTypeHuman,
|
||
Parts: []llms.ContentPart{
|
||
llms.TextPart(reqInput),
|
||
},
|
||
})
|
||
return prompt
|
||
}
|
||
|
||
// buildSystemPrompt 构建系统提示词
|
||
func (r *AiRouterBiz) buildSystemPrompt(prompt string) string {
|
||
if len(prompt) == 0 {
|
||
prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON:{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式,禁用markdown格式返回\n3.只返回json字符串,不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"
|
||
}
|
||
|
||
return prompt
|
||
}
|
||
|
||
func (r *AiRouterBiz) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
|
||
for _, item := range his {
|
||
if len(chatHis.SessionId) == 0 {
|
||
chatHis.SessionId = item.SessionID
|
||
}
|
||
chatHis.Messages = append(chatHis.Messages, entitys.HisMessage{
|
||
Role: item.Role,
|
||
Content: item.Content,
|
||
Timestamp: item.CreateAt.Format("2006-01-02 15:04:05"),
|
||
})
|
||
}
|
||
chatHis.Context = entitys.HisContext{
|
||
UserLanguage: "zh-CN",
|
||
SystemMode: "technical_support",
|
||
}
|
||
return
|
||
}
|
||
|
||
// handleKnowledgeQA 处理知识问答意图
|
||
func (r *AiRouterBiz) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
|
||
|
||
return nil, nil
|
||
}
|