ai_scheduler/internal/biz/router.go

422 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package biz
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/constant"
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/mapstructure"
"ai_scheduler/internal/pkg/utils_ollama"
"ai_scheduler/internal/tools"
"ai_scheduler/tmpl/dataTemp"
"context"
"encoding/json"
"fmt"
"strings"
"gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/websocket/v2"
"github.com/ollama/ollama/api"
"xorm.io/builder"
)
// AiRouterBiz is the intelligent-routing service: it resolves the calling
// system and session, asks the LLM to select a task (tool) for the user's
// input, and dispatches the selected task.
type AiRouterBiz struct {
// Legacy generic AI client, superseded by the Ollama client in `ai`.
//aiClient entitys.AIClient
// Executes "func"-type tasks (see handleTask).
toolManager *tools.Manager
// Session data access. NOTE(review): not referenced by the methods visible in
// this file — confirm whether it is still needed.
sessionImpl *impl.SessionImpl
// Looks up the calling system by app key (see getSysInfo).
sysImpl *impl.SysImpl
// Loads the system's enabled tasks (see getTasks).
taskImpl *impl.TaskImpl
// Loads a session's chat history (see getSessionChatHis).
hisImpl *impl.ChatImpl
// Application config; Sys.SessionLen bounds the history loaded per request.
conf *config.Config
// LLM client used for tool/intent selection (ToolSelect).
ai *utils_ollama.Client
}
// NewAiRouterBiz builds the intelligent-routing service from its
// dependencies. Every dependency is stored as-is; none is validated or
// invoked here.
func NewAiRouterBiz(
	toolManager *tools.Manager,
	sessionImpl *impl.SessionImpl,
	sysImpl *impl.SysImpl,
	taskImpl *impl.TaskImpl,
	hisImpl *impl.ChatImpl,
	conf *config.Config,
	ai *utils_ollama.Client,
) *AiRouterBiz {
	biz := new(AiRouterBiz)
	biz.toolManager = toolManager
	biz.sessionImpl = sessionImpl
	biz.sysImpl = sysImpl
	biz.taskImpl = taskImpl
	biz.hisImpl = hisImpl
	biz.conf = conf
	biz.ai = ai
	return biz
}
// Route performs intelligent routing for a plain (non-websocket) chat request.
//
// NOTE(review): not implemented yet — it always yields a nil response together
// with a nil error, so callers must not dereference the response.
func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) {
	var resp *entitys.ChatResponse
	return resp, nil
}
// RouteWithSocket performs intelligent routing for one websocket chat request.
//
// It reads the session id, authorization token and app key from the
// connection's query string ("x-session", "x-authorization", "x-app-key").
// It then loads the calling system, the session's chat history and the
// system's enabled tasks, and asks the LLM to select a task (tool) for
// req.Text. The model's raw reply content is written back to the client as a
// single text frame.
//
// Errors:
//   - errors.SessionNotFound / AuthNotFound / KeyNotFound when the matching
//     query parameter is missing or empty;
//   - errors.SysNotFound when the app key cannot be resolved to a system;
//   - errors.SystemError when loading history/tasks or the LLM call fails;
//   - a wrapped websocket error when the reply cannot be written.
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error {
	session := c.Query("x-session", "")
	if session == "" {
		return errors.SessionNotFound
	}
	if c.Query("x-authorization", "") == "" {
		return errors.AuthNotFound
	}
	key := c.Query("x-app-key", "")
	if key == "" {
		return errors.KeyNotFound
	}

	sysInfo, err := r.getSysInfo(key)
	if err != nil {
		return errors.SysNotFound
	}
	history, err := r.getSessionChatHis(session)
	if err != nil {
		return errors.SystemError
	}
	tasks, err := r.getTasks(sysInfo.SysID)
	if err != nil {
		return errors.SystemError
	}

	// Intent prediction: the system's tasks are offered to the model both in
	// the prompt and as tool definitions.
	prompt := r.getPromptLLM(sysInfo, history, req.Text, tasks)
	toolDefinitions := r.registerTools(tasks)
	// NOTE(review): no request context is available from the websocket
	// connection here, so the LLM call is not cancelled if the client leaves.
	match, err := r.ai.ToolSelect(context.TODO(), prompt, toolDefinitions)
	if err != nil {
		return errors.SystemError
	}
	log.Info(match)

	// TODO(review): an earlier flow parsed this reply into entitys.Match and
	// dispatched it via handleMatch (see history); for now the raw model
	// output is forwarded to the client unchanged.
	if err = c.WriteMessage(websocket.TextMessage, []byte(match.Message.Content)); err != nil {
		return fmt.Errorf("write routing result to websocket: %w", err)
	}
	return nil
}
// handleMatch dispatches a parsed intent-match result to the task it selected
// and forwards the handler's output to the websocket client.
//
// When matchJson.IsMatch is false, the model's reasoning is sent as-is.
// Otherwise the task whose Index equals matchJson.Index is looked up in tasks:
//   - "api" tasks go to handleApiTask;
//   - "func" tasks go to handleTask;
//   - unknown, unmatched or "other" tasks fall back to handleOtherTask.
//
// Every message a handler queued on the result channel is then written to the
// client. Regardless of outcome, the reply ends with an "EOF" text frame,
// preceded by the error text when err is non-nil.
func (r *AiRouterBiz) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) {
	// Buffered so handlers can queue a few messages before the channel is
	// drained below.
	resChan := make(chan []byte, 10)
	defer func() {
		// NOTE(review): if a tool keeps sending on resChan from another
		// goroutine after this function returns, this close makes it panic.
		close(resChan)
		if err != nil {
			_ = c.WriteMessage(websocket.TextMessage, []byte(err.Error()))
		}
		_ = c.WriteMessage(websocket.TextMessage, []byte("EOF"))
	}()

	if !matchJson.IsMatch {
		return c.WriteMessage(websocket.TextMessage, []byte(matchJson.Reasoning))
	}

	pointTask := findTaskByIndex(tasks, matchJson.Index)
	switch {
	case pointTask == nil || pointTask.Index == "other":
		err = r.handleOtherTask(resChan, c, matchJson)
	case pointTask.Type == constant.TaskTypeApi:
		err = r.handleApiTask(resChan, c, matchJson, pointTask)
	case pointTask.Type == constant.TaskTypeFunc:
		err = r.handleTask(resChan, c, matchJson, pointTask)
	default:
		err = r.handleOtherTask(resChan, c, matchJson)
	}

	// Forward queued output even when the handler failed so partial results
	// still reach the client; the handler's own error takes precedence.
	if flushErr := flushChannelToSocket(c, resChan); err == nil {
		err = flushErr
	}
	return err
}

// findTaskByIndex returns a copy of the first task whose Index equals index,
// or nil when none matches. A copy is returned so handlers that modify the
// task cannot change the caller's slice element.
func findTaskByIndex(tasks []model.AiTask, index string) *model.AiTask {
	for i := range tasks {
		if tasks[i].Index == index {
			found := tasks[i]
			return &found
		}
	}
	return nil
}

// flushChannelToSocket writes every message currently buffered in ch to c as
// a text frame, without waiting for new ones. It stops at the first write
// error or when ch has been closed.
func flushChannelToSocket(c *websocket.Conn, ch chan []byte) error {
	for {
		select {
		case msg, ok := <-ch:
			if !ok {
				return nil
			}
			if err := c.WriteMessage(websocket.TextMessage, msg); err != nil {
				return err
			}
		default:
			return nil
		}
	}
}
// handleTask runs a "func"-type task. It decodes the task's JSON config to
// learn which tool to invoke, then hands the model-supplied parameters
// (matchJson.Parameters) to the tool manager along with the result channel
// and the connection. Any decode or execution error is returned unchanged.
func (r *AiRouterBiz) handleTask(channel chan []byte, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
	var cfg entitys.ConfigDataTool
	if err = json.Unmarshal([]byte(task.Config), &cfg); err != nil {
		return err
	}
	return r.toolManager.ExecuteTool(channel, c, cfg.Tool, []byte(matchJson.Parameters))
}
// handleOtherTask handles a match that maps to no concrete task by queuing the
// model's reasoning on channel. The send blocks while channel's buffer is full
// and no receiver is ready. It always returns nil.
func (r *AiRouterBiz) handleOtherTask(channel chan []byte, c *websocket.Conn, matchJson *entitys.Match) (err error) {
	reply := []byte(matchJson.Reasoning)
	channel <- reply
	return nil
}
// handleApiTask runs an "api"-type task. It fills the placeholders in the
// task's JSON config, decodes the resulting HTTP request description, sends
// the request and writes the raw response body to the websocket client.
//
// Placeholders have the form ${name}:
//   - ${authorization} is replaced with the caller's token, taken from the
//     "X-Authorization" header and falling back to the "x-authorization"
//     query parameter that RouteWithSocket requires;
//   - every other ${key} is replaced with the matching value from
//     matchJson.Parameters, which must be a JSON object.
//
// Values are escaped as JSON string content, so quotes, backslashes or
// newlines in model-supplied values cannot break or inject into the config.
// Substitution is a single pass, so placeholder text inside a substituted
// value is never expanded. task itself is not modified.
//
// Returns a business error ("00022") when the decoded request has no URL,
// otherwise any JSON, decode, HTTP or websocket error encountered.
func (r *AiRouterBiz) handleApiTask(channels chan []byte, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
	var requestParam map[string]interface{}
	if err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam); err != nil {
		return err
	}

	auth := c.Headers("X-Authorization", "")
	if auth == "" {
		auth = c.Query("x-authorization", "")
	}

	// ${authorization} is listed first so it takes precedence over a model
	// parameter that happens to be named "authorization".
	pairs := make([]string, 0, 2+2*len(requestParam))
	pairs = append(pairs, "${authorization}", escapeJSONStringContent(auth))
	for k, v := range requestParam {
		pairs = append(pairs, "${"+k+"}", escapeJSONStringContent(fmt.Sprintf("%v", v)))
	}
	config := strings.NewReplacer(pairs...).Replace(task.Config)

	var configData entitys.ConfigDataHttp
	if err = json.Unmarshal([]byte(config), &configData); err != nil {
		return err
	}
	var request l_request.Request
	if err = mapstructure.Decode(configData.Request, &request); err != nil {
		return err
	}
	if len(request.Url) == 0 {
		return errors.NewBusinessErr("00022", "api地址获取失败")
	}

	res, err := request.Send()
	if err != nil {
		return err
	}
	return c.WriteMessage(websocket.TextMessage, res.Content)
}

// escapeJSONStringContent returns s escaped for use inside a JSON string
// literal, without the surrounding quotes.
func escapeJSONStringContent(s string) string {
	quoted, err := json.Marshal(s)
	if err != nil {
		// json.Marshal does not fail for string input; keep s as a fallback.
		return s
	}
	return string(quoted[1 : len(quoted)-1])
}
// getSessionChatHis loads the most recent chat-history rows of sessionId, at
// most r.conf.Sys.SessionLen of them, and returns them in chronological
// (his_id ascending) order.
//
// The rows are fetched newest-first so the limit keeps the latest messages.
// An ascending query with a limit would keep only the oldest ones, freezing
// the context window at the start of a long session.
func (r *AiRouterBiz) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
	cond := builder.NewCond()
	cond = cond.And(builder.Eq{"session_id": sessionId})
	page := &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}
	if _, err = r.hisImpl.GetListToStruct(&cond, page, &his, "his_id desc"); err != nil {
		return nil, err
	}
	// Restore chronological order for prompt construction.
	for i, j := 0, len(his)-1; i < j; i, j = i+1, j-1 {
		his[i], his[j] = his[j], his[i]
	}
	return his, nil
}
// getSysInfo fetches the single system row registered under appKey that is
// not soft-deleted (delete_at IS NULL) and has status 1. Errors from the data
// layer are returned unchanged.
func (r *AiRouterBiz) getSysInfo(appKey string) (sysInfo model.AiSy, err error) {
	where := builder.NewCond().
		And(builder.Eq{"app_key": appKey}).
		And(builder.IsNull{"delete_at"}).
		And(builder.Eq{"status": 1})
	err = r.sysImpl.GetOneBySearchToStrut(&where, &sysInfo)
	return sysInfo, err
}
// getTasks lists every task of system sysId that is not soft-deleted
// (delete_at IS NULL) and has status 1. No paging or explicit ordering is
// applied.
func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
	where := builder.NewCond().
		And(builder.Eq{"sys_id": sysId}).
		And(builder.IsNull{"delete_at"}).
		And(builder.Eq{"status": 1})
	_, err = r.taskImpl.GetListToStruct(&where, nil, &tasks, "")
	return tasks, err
}
// registerTools converts the system's tasks into Ollama tool definitions.
//
// Each task becomes a "function" tool named after task.Index and described by
// task.Desc. Its parameter schema (type, required, properties) is taken from
// the Param section of the task's JSON config (entitys.TaskConfigDetail).
// Tasks whose config cannot be decoded are skipped; each skip is logged
// rather than dropped silently. The result is never nil.
func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []api.Tool {
	taskPrompt := make([]api.Tool, 0, len(tasks))
	for i := range tasks {
		task := &tasks[i]
		var taskConfig entitys.TaskConfigDetail
		if err := json.Unmarshal([]byte(task.Config), &taskConfig); err != nil {
			log.Warnf("registerTools: skip task %q, invalid config: %v", task.Index, err)
			continue
		}
		taskPrompt = append(taskPrompt, api.Tool{
			Type: "function",
			Function: api.ToolFunction{
				Name:        task.Index,
				Description: task.Desc,
				Parameters: api.ToolFunctionParameters{
					Type:       taskConfig.Param.Type,
					Required:   taskConfig.Param.Required,
					Properties: taskConfig.Param.Properties,
				},
			},
		})
	}
	return taskPrompt
}
// getPrompt assembles the chat messages for the generic entitys.Message
// client, in this order:
//   1. the system prompt (sysInfo.SysPrompt or the built-in default);
//   2. the session history serialized to JSON as one assistant message;
//   3. the user's input.
func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message {
	systemMsg := entitys.Message{Role: "system", Content: r.buildSystemPrompt(sysInfo.SysPrompt)}
	historyMsg := entitys.Message{Role: "assistant", Content: pkg.JsonStringIgonErr(r.buildAssistant(history))}
	userMsg := entitys.Message{Role: "user", Content: reqInput}
	return []entitys.Message{systemMsg, historyMsg, userMsg}
}
// getPromptLLM assembles the Ollama chat messages used for intent prediction,
// in this order:
//   1. the system prompt;
//   2. the session history serialized to JSON as an assistant message;
//   3. the tool definitions built from tasks, serialized to JSON as a second
//      assistant message;
//   4. the user's input.
func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []api.Message {
	systemMsg := api.Message{Role: "system", Content: r.buildSystemPrompt(sysInfo.SysPrompt)}
	historyMsg := api.Message{Role: "assistant", Content: pkg.JsonStringIgonErr(r.buildAssistant(history))}
	toolsMsg := api.Message{Role: "assistant", Content: pkg.JsonStringIgonErr(r.registerTools(tasks))}
	userMsg := api.Message{Role: "user", Content: reqInput}
	return []api.Message{systemMsg, historyMsg, toolsMsg, userMsg}
}
// buildSystemPrompt returns the system prompt to send to the model: prompt
// itself when it is non-empty, otherwise a built-in default that asks the
// model to answer with a bare JSON object (index, confidence, reasoning).
func (r *AiRouterBiz) buildSystemPrompt(prompt string) string {
	const defaultPrompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式禁用markdown格式返回\n3.只返回json字符串不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"
	if prompt != "" {
		return prompt
	}
	return defaultPrompt
}
// buildAssistant packs chat-history rows into an entitys.ChatHis:
//   - SessionId comes from the first row that carries a non-empty one;
//   - every row becomes a HisMessage (role, content, and CreateAt formatted as
//     "2006-01-02 15:04:05");
//   - a fixed context (zh-CN, technical_support) is attached.
//
// Messages stays nil when his is empty.
func (r *AiRouterBiz) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
	chatHis.Context = entitys.HisContext{
		UserLanguage: "zh-CN",
		SystemMode:   "technical_support",
	}
	for i := range his {
		row := &his[i]
		if len(chatHis.SessionId) == 0 {
			chatHis.SessionId = row.SessionID
		}
		msg := entitys.HisMessage{
			Role:      row.Role,
			Content:   row.Content,
			Timestamp: row.CreateAt.Format("2006-01-02 15:04:05"),
		}
		chatHis.Messages = append(chatHis.Messages, msg)
	}
	return chatHis
}
// handleKnowledgeQA handles the knowledge Q&A intent.
//
// NOTE(review): not implemented yet — it ignores its arguments and always
// yields a nil response together with a nil error.
func (r *AiRouterBiz) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
	var resp *entitys.ChatResponse
	return resp, nil
}