ai_scheduler/internal/biz/router.go

335 lines
9.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package biz
import (
"ai_scheduler/internal/config"
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/utils_ollama"
"ai_scheduler/internal/tools"
"ai_scheduler/tmpl/dataTemp"
"context"
"encoding/json"
"log"
"github.com/tmc/langchaingo/llms"
"github.com/gofiber/websocket/v2"
"xorm.io/builder"
)
// AiRouterService is the intelligent routing service. NewAiRouterBiz returns
// it as an entitys.RouterService. It holds the data-access objects, the
// configuration and the LLM client that the routing methods in this file use.
type AiRouterService struct {
//aiClient entitys.AIClient
// toolManager is only referenced from commented-out code in this file.
toolManager *tools.Manager
// sessionImpl is stored by the constructor; no method in this file reads it.
sessionImpl *impl.SessionImpl
// sysImpl is queried by getSysInfo to resolve a system from its app key.
sysImpl *impl.SysImpl
// taskImpl is queried by getTasks to list the tasks of a system.
taskImpl *impl.TaskImpl
// hisImpl is queried by getSessionChatHis to load a session's chat history.
hisImpl *impl.ChatImpl
// conf supplies Sys.SessionLen, used as the row limit for the history query.
conf *config.Config
// utilAgent exposes the Llm client that RouteWithSocket calls.
utilAgent *utils_ollama.UtilOllama
}
// NewAiRouterBiz builds an AiRouterService from the given dependencies and
// returns it as an entitys.RouterService. Every argument is stored as-is; no
// validation or copying is performed.
func NewAiRouterBiz(
	toolManager *tools.Manager,
	sessionImpl *impl.SessionImpl,
	sysImpl *impl.SysImpl,
	taskImpl *impl.TaskImpl,
	hisImpl *impl.ChatImpl,
	conf *config.Config,
	utilAgent *utils_ollama.UtilOllama,
) entitys.RouterService {
	svc := new(AiRouterService)

	// Data-access dependencies.
	svc.sessionImpl = sessionImpl
	svc.sysImpl = sysImpl
	svc.taskImpl = taskImpl
	svc.hisImpl = hisImpl

	// Tooling, configuration and the LLM client.
	svc.toolManager = toolManager
	svc.conf = conf
	svc.utilAgent = utilAgent

	return svc
}
// Route is the non-websocket routing entry point. It is currently a stub: the
// arguments are ignored and it always returns a nil response with a nil error,
// so callers must not dereference the response without a nil check.
func (r *AiRouterService) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) {
	var resp *entitys.ChatResponse // no response is produced yet
	return resp, nil
}
// RouteWithSocket handles one chat request received over a websocket
// connection. It validates the connection headers, loads the calling system,
// the session's chat history and the system's tasks, asks the LLM to route the
// request with those tasks exposed as tools, and writes the model's answer
// back to the socket as a text frame.
//
// The X-Session, X-Authorization and X-App-Key headers are required. Only the
// presence of X-Authorization is checked here; its value is not verified by
// this method.
//
// Returned errors:
//   - errors.SessionNotFound, errors.AuthNotFound, errors.KeyNotFound when the
//     corresponding header is missing;
//   - errors.SysNotFound when the app key lookup fails;
//   - errors.SystemError when loading history or tasks, or calling the model,
//     fails (the underlying cause is logged);
//   - the error from the socket write when the reply cannot be sent.
func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error {
	session := c.Headers("X-Session", "")
	if len(session) == 0 {
		return errors.SessionNotFound
	}
	if auth := c.Headers("X-Authorization", ""); len(auth) == 0 {
		return errors.AuthNotFound
	}
	key := c.Headers("X-App-Key", "")
	if len(key) == 0 {
		return errors.KeyNotFound
	}

	// The callers receive a sentinel error, so log the underlying cause here;
	// otherwise it would be lost.
	sysInfo, err := r.getSysInfo(key)
	if err != nil {
		log.Printf("RouteWithSocket: system lookup failed: %v", err)
		return errors.SysNotFound
	}
	history, err := r.getSessionChatHis(session)
	if err != nil {
		log.Printf("RouteWithSocket: loading chat history for session %q failed: %v", session, err)
		return errors.SystemError
	}
	task, err := r.getTasks(sysInfo.SysID)
	if err != nil {
		log.Printf("RouteWithSocket: loading tasks for system %v failed: %v", sysInfo.SysID, err)
		return errors.SystemError
	}

	toolDefinitions := r.registerTools(task)
	prompt := r.getPrompt(sysInfo, history, req.Text)

	// No request-scoped context is available from the websocket connection,
	// hence context.TODO().
	msg, err := r.utilAgent.Llm.Call(context.TODO(), pkg.JsonStringIgonErr(prompt),
		llms.WithTools(toolDefinitions),
		llms.WithToolChoice(llms.FunctionCallBehaviorAuto),
		llms.WithJSONMode(),
	)
	if err != nil {
		// Do not forward an empty or partial answer when the model call failed.
		log.Printf("RouteWithSocket: llm call failed: %v", err)
		return errors.SystemError
	}

	// TODO: the model's answer is forwarded verbatim. The planned follow-up
	// (previously sketched here as commented-out code) is to extract the intent
	// from the answer, execute any requested tool calls through toolManager,
	// and make a second model call that produces the final reply.
	return c.WriteMessage(websocket.TextMessage, []byte(msg))
}
// getSessionChatHis loads the chat history rows whose session_id equals
// sessionId. The number of rows is limited to conf.Sys.SessionLen. It returns
// the rows together with any error reported by the data layer.
func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
	bySession := builder.NewCond().And(builder.Eq{"session_id": sessionId})
	page := &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}

	// The first return value of GetListToStruct is not needed here.
	_, err = r.hisImpl.GetListToStruct(&bySession, page, &his)
	return his, err
}
// getSysInfo looks up a single system record by its app key. Only rows whose
// delete_at column is NULL and whose status equals 1 are considered
// (presumably "not deleted" and "enabled" — confirm against the schema).
func (r *AiRouterService) getSysInfo(appKey string) (sysInfo model.AiSy, err error) {
	filter := builder.NewCond().
		And(builder.Eq{"app_key": appKey}).
		And(builder.IsNull{"delete_at"}).
		And(builder.Eq{"status": 1})

	err = r.sysImpl.GetOneBySearchToStrut(&filter, &sysInfo)
	return sysInfo, err
}
// getTasks lists the tasks that belong to the system identified by sysId. Only
// rows whose delete_at column is NULL and whose status equals 1 are returned
// (presumably "not deleted" and "enabled" — confirm against the schema). No
// paging is applied: a nil page argument is passed to the data layer.
func (r *AiRouterService) getTasks(sysId int32) (tasks []model.AiTask, err error) {
	filter := builder.NewCond().
		And(builder.Eq{"sys_id": sysId}).
		And(builder.IsNull{"delete_at"}).
		And(builder.Eq{"status": 1})

	// The first return value of GetListToStruct is not needed here.
	_, err = r.taskImpl.GetListToStruct(&filter, nil, &tasks)
	return tasks, err
}
// registerTools converts the given tasks into LLM tool definitions. Each task
// becomes a tool of type "function" whose name and description come from the
// task and whose parameters come from the Param field of the task's JSON
// config.
//
// A task whose config is not valid JSON is skipped and logged. The result
// therefore contains only fully populated tools and may be shorter than tasks;
// it is never nil.
func (r *AiRouterService) registerTools(tasks []model.AiTask) []llms.Tool {
	// Append instead of indexing a pre-sized slice: a skipped task must not
	// leave a zero-valued tool (empty type, nil function) in the result.
	taskPrompt := make([]llms.Tool, 0, len(tasks))
	for _, task := range tasks {
		var taskConfig entitys.TaskConfig
		if err := json.Unmarshal([]byte(task.Config), &taskConfig); err != nil {
			log.Printf("registerTools: skipping task %q: invalid config: %v", task.Name, err)
			continue
		}
		taskPrompt = append(taskPrompt, llms.Tool{
			Type: "function",
			Function: &llms.FunctionDefinition{
				Name:        task.Name,
				Description: task.Desc,
				Parameters:  taskConfig.Param,
			},
		})
	}
	return taskPrompt
}
// getPrompt assembles the message list sent to the model, in this order:
//
//  1. a "system" message holding the system's prompt (or the built-in default
//     when the system has none),
//  2. an "assistant" message holding the session history serialized as JSON,
//  3. a "user" message holding the current request text.
//
// The list contains exactly these three messages; no empty placeholder message
// is emitted.
func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message {
	return []entitys.Message{
		{
			Role:    "system",
			Content: r.buildSystemPrompt(sysInfo.SysPrompt),
		},
		{
			Role:    "assistant",
			Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
		},
		{
			Role:    "user",
			Content: reqInput,
		},
	}
}
// buildSystemPrompt returns the system prompt to use for a request. A
// non-empty prompt is returned unchanged; an empty prompt is replaced by the
// built-in default routing prompt.
func (r *AiRouterService) buildSystemPrompt(prompt string) string {
	// Default prompt, used when the system record does not define its own.
	const fallback = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式禁用markdown格式返回\n3.只返回json字符串不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"

	if prompt != "" {
		return prompt
	}
	return fallback
}
// buildAssistant converts stored chat history rows into an entitys.ChatHis.
// The session id is taken from the first row that has a non-empty one, every
// row becomes one message with its creation time formatted as
// "2006-01-02 15:04:05", and the context is always set to language "zh-CN"
// and mode "technical_support". With no rows, Messages is left nil.
func (r *AiRouterService) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
	const timeLayout = "2006-01-02 15:04:05"

	// The context does not depend on the rows, so fill it in up front.
	chatHis.Context = entitys.HisContext{
		UserLanguage: "zh-CN",
		SystemMode:   "technical_support",
	}

	for i := range his {
		record := &his[i]
		if len(chatHis.SessionId) == 0 {
			chatHis.SessionId = record.SessionID
		}
		entry := entitys.HisMessage{
			Role:      record.Role,
			Content:   record.Content,
			Timestamp: record.CreateAt.Format(timeLayout),
		}
		chatHis.Messages = append(chatHis.Messages, entry)
	}
	return chatHis
}
// extractIntent extracts the "intent" field from an AI response whose Message
// is expected to be a JSON object. It returns the empty string when the
// response is nil, its message is empty, or the message cannot be parsed; parse
// failures are logged.
func (r *AiRouterService) extractIntent(response *entitys.ChatResponse) string {
	if response == nil || response.Message == "" {
		return ""
	}

	// Decode only the field that is actually used. Declaring "confidence" as a
	// string made json.Unmarshal fail whenever the model returned it as a
	// number (the system prompt asks for 0.0-1.0), which discarded an otherwise
	// valid intent. Unknown fields are ignored by encoding/json.
	var payload struct {
		Intent string `json:"intent"`
	}
	if err := json.Unmarshal([]byte(response.Message), &payload); err != nil {
		log.Printf("Failed to parse intent JSON: %v", err)
		return ""
	}
	return payload.Intent
}
// handleOrderDiagnosis is the handler for the order-diagnosis intent. It is
// currently a stub: the arguments are ignored and it always returns a nil
// response with a nil error.
func (r *AiRouterService) handleOrderDiagnosis(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
	// Sketch of the intended implementation, kept for reference:
	//
	//	// Call the order detail tool.
	//	orderDetailTool, ok := r.toolManager.GetTool("zltxOrderDetail")
	//	if orderDetailTool == nil || !ok {
	//		return nil, fmt.Errorf("order detail tool not found")
	//	}
	//	orderDetailTool.Execute(ctx, json.RawMessage{})
	//
	//	// Fetch the relevant tool definitions.
	//	toolDefinitions := r.toolManager.GetToolDefinitions(constants.Caller(req.Caller))
	//
	//	// Ask the AI whether any tool needs to be used.
	//	response, err := r.aiClient.Chat(ctx, messages, toolDefinitions)
	//	if err != nil {
	//		return nil, fmt.Errorf("failed to chat with AI: %w", err)
	//	}
	//
	//	// No tool calls: return the response directly.
	//	if len(response.ToolCalls) == 0 {
	//		return response, nil
	//	}
	//
	//	// Execute the tool calls.
	//	toolResults, err := r.toolManager.ExecuteToolCalls(ctx, response.ToolCalls)
	//	if err != nil {
	//		return nil, fmt.Errorf("failed to execute tools: %w", err)
	//	}
	var resp *entitys.ChatResponse // no response is produced yet
	return resp, nil
}
// handleKnowledgeQA is the handler for the knowledge question-answering
// intent. It is currently a stub: the arguments are ignored and it always
// returns a nil response with a nil error.
func (r *AiRouterService) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
	var resp *entitys.ChatResponse // no response is produced yet
	return resp, nil
}