335 lines
9.5 KiB
Go
335 lines
9.5 KiB
Go
package biz
|
||
|
||
import (
|
||
"ai_scheduler/internal/config"
|
||
errors "ai_scheduler/internal/data/error"
|
||
"ai_scheduler/internal/data/impl"
|
||
"ai_scheduler/internal/data/model"
|
||
"ai_scheduler/internal/entitys"
|
||
"ai_scheduler/internal/pkg"
|
||
"ai_scheduler/internal/pkg/utils_ollama"
|
||
"ai_scheduler/internal/tools"
|
||
"ai_scheduler/tmpl/dataTemp"
|
||
"context"
|
||
"encoding/json"
|
||
"log"
|
||
|
||
"github.com/tmc/langchaingo/llms"
|
||
|
||
"github.com/gofiber/websocket/v2"
|
||
"xorm.io/builder"
|
||
)
|
||
|
||
// AiRouterService 智能路由服务
|
||
type AiRouterService struct {
|
||
//aiClient entitys.AIClient
|
||
toolManager *tools.Manager
|
||
sessionImpl *impl.SessionImpl
|
||
sysImpl *impl.SysImpl
|
||
taskImpl *impl.TaskImpl
|
||
hisImpl *impl.ChatHisImpl
|
||
conf *config.Config
|
||
utilAgent *utils_ollama.UtilOllama
|
||
}
|
||
|
||
// NewRouterService 创建路由服务
|
||
func NewAiRouterBiz(
|
||
//aiClient entitys.AIClient,
|
||
toolManager *tools.Manager,
|
||
sessionImpl *impl.SessionImpl,
|
||
sysImpl *impl.SysImpl,
|
||
taskImpl *impl.TaskImpl,
|
||
hisImpl *impl.ChatHisImpl,
|
||
conf *config.Config,
|
||
utilAgent *utils_ollama.UtilOllama,
|
||
) entitys.RouterService {
|
||
return &AiRouterService{
|
||
//aiClient: aiClient,
|
||
toolManager: toolManager,
|
||
sessionImpl: sessionImpl,
|
||
conf: conf,
|
||
sysImpl: sysImpl,
|
||
hisImpl: hisImpl,
|
||
taskImpl: taskImpl,
|
||
utilAgent: utilAgent,
|
||
}
|
||
}
|
||
|
||
// Route 执行智能路由
|
||
func (r *AiRouterService) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) {
|
||
|
||
return nil, nil
|
||
}
|
||
|
||
// Route 执行智能路由
|
||
func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error {
|
||
|
||
session := c.Headers("X-Session", "")
|
||
if len(session) == 0 {
|
||
return errors.SessionNotFound
|
||
}
|
||
auth := c.Headers("X-Authorization", "")
|
||
if len(auth) == 0 {
|
||
return errors.AuthNotFound
|
||
}
|
||
key := c.Headers("X-App-Key", "")
|
||
if len(key) == 0 {
|
||
return errors.KeyNotFound
|
||
}
|
||
|
||
sysInfo, err := r.getSysInfo(key)
|
||
if err != nil {
|
||
return errors.SysNotFound
|
||
}
|
||
|
||
history, err := r.getSessionChatHis(session)
|
||
if err != nil {
|
||
return errors.SystemError
|
||
}
|
||
|
||
task, err := r.getTasks(sysInfo.SysID)
|
||
if err != nil {
|
||
return errors.SystemError
|
||
}
|
||
|
||
toolDefinitions := r.registerTools(task)
|
||
prompt := r.getPrompt(sysInfo, history, req.Text)
|
||
msg, err := r.utilAgent.Llm.Call(context.TODO(), pkg.JsonStringIgonErr(prompt),
|
||
llms.WithTools(toolDefinitions),
|
||
llms.WithToolChoice(llms.FunctionCallBehaviorAuto),
|
||
llms.WithJSONMode(),
|
||
)
|
||
c.WriteMessage(1, []byte(msg))
|
||
// 构建消息
|
||
//messages := []entitys.Message{
|
||
// {
|
||
// Role: "user",
|
||
// Content: req.UserInput,
|
||
// },
|
||
//}
|
||
//
|
||
//// 第1次调用AI,获取用户意图
|
||
//intentResponse, err := r.aiClient.Chat(ctx, messages, nil)
|
||
//if err != nil {
|
||
// return nil, fmt.Errorf("AI响应失败: %w", err)
|
||
//}
|
||
//
|
||
//// 从AI响应中提取意图
|
||
//intent := r.extractIntent(intentResponse)
|
||
//if intent == "" {
|
||
// return nil, fmt.Errorf("未识别到用户意图")
|
||
//}
|
||
//
|
||
//switch intent {
|
||
//case "order_diagnosis":
|
||
// // 订单诊断意图
|
||
// return r.handleOrderDiagnosis(ctx, req, messages)
|
||
//case "knowledge_qa":
|
||
// // 知识问答意图
|
||
// return r.handleKnowledgeQA(ctx, req, messages)
|
||
//default:
|
||
// // 未知意图
|
||
// return nil, fmt.Errorf("意图识别失败,请明确您的需求呢,我可以为您")
|
||
//}
|
||
//
|
||
//// 获取工具定义
|
||
//toolDefinitions := r.toolManager.GetToolDefinitions(constants.Caller(req.Caller))
|
||
//
|
||
//// 第2次调用AI,获取是否需要使用工具
|
||
//response, err := r.aiClient.Chat(ctx, messages, toolDefinitions)
|
||
//if err != nil {
|
||
// return nil, fmt.Errorf("failed to chat with AI: %w", err)
|
||
//}
|
||
//
|
||
//// 如果没有工具调用,直接返回
|
||
//if len(response.ToolCalls) == 0 {
|
||
// return response, nil
|
||
//}
|
||
//
|
||
//// 执行工具调用
|
||
//toolResults, err := r.toolManager.ExecuteToolCalls(ctx, response.ToolCalls)
|
||
//if err != nil {
|
||
// return nil, fmt.Errorf("failed to execute tools: %w", err)
|
||
//}
|
||
//
|
||
//// 构建包含工具结果的消息
|
||
//messages = append(messages, entitys.Message{
|
||
// Role: "assistant",
|
||
// Content: response.Message,
|
||
//})
|
||
//
|
||
//// 添加工具调用结果
|
||
//for _, toolResult := range toolResults {
|
||
// toolResultStr, _ := json.Marshal(toolResult.Result)
|
||
// messages = append(messages, entitys.Message{
|
||
// Role: "tool",
|
||
// Content: fmt.Sprintf("Tool %s result: %s", toolResult.Function.Name, string(toolResultStr)),
|
||
// })
|
||
//}
|
||
//
|
||
//// 第二次调用AI,生成最终回复
|
||
//finalResponse, err := r.aiClient.Chat(ctx, messages, nil)
|
||
//if err != nil {
|
||
// return nil, fmt.Errorf("failed to generate final response: %w", err)
|
||
//}
|
||
//
|
||
//// 合并工具调用信息到最终响应
|
||
//finalResponse.ToolCalls = toolResults
|
||
//
|
||
//log.Printf("Router processed request: %s, used %d tools", req.UserInput, len(toolResults))
|
||
|
||
//return finalResponse, nil
|
||
return nil
|
||
}
|
||
|
||
func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
|
||
|
||
cond := builder.NewCond()
|
||
cond = cond.And(builder.Eq{"session_id": sessionId})
|
||
|
||
_, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his)
|
||
|
||
return
|
||
}
|
||
|
||
func (r *AiRouterService) getSysInfo(appKey string) (sysInfo model.AiSy, err error) {
|
||
cond := builder.NewCond()
|
||
cond = cond.And(builder.Eq{"app_key": appKey})
|
||
cond = cond.And(builder.IsNull{"delete_at"})
|
||
cond = cond.And(builder.Eq{"status": 1})
|
||
err = r.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo)
|
||
return
|
||
}
|
||
|
||
func (r *AiRouterService) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
||
|
||
cond := builder.NewCond()
|
||
cond = cond.And(builder.Eq{"sys_id": sysId})
|
||
cond = cond.And(builder.IsNull{"delete_at"})
|
||
cond = cond.And(builder.Eq{"status": 1})
|
||
_, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks)
|
||
|
||
return
|
||
}
|
||
|
||
func (r *AiRouterService) registerTools(tasks []model.AiTask) []llms.Tool {
|
||
taskPrompt := make([]llms.Tool, len(tasks))
|
||
for k, task := range tasks {
|
||
var taskConfig entitys.TaskConfig
|
||
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
taskPrompt[k].Type = "function"
|
||
taskPrompt[k].Function = &llms.FunctionDefinition{
|
||
Name: task.Name,
|
||
Description: task.Desc,
|
||
Parameters: taskConfig.Param,
|
||
}
|
||
}
|
||
return taskPrompt
|
||
}
|
||
|
||
func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message {
|
||
var (
|
||
prompt = make([]entitys.Message, 0)
|
||
)
|
||
prompt = append(prompt, entitys.Message{}, entitys.Message{
|
||
Role: "system",
|
||
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
|
||
}, entitys.Message{
|
||
Role: "assistant",
|
||
Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
|
||
}, entitys.Message{
|
||
Role: "user",
|
||
Content: reqInput,
|
||
})
|
||
return prompt
|
||
}
|
||
|
||
// buildSystemPrompt 构建系统提示词
|
||
func (r *AiRouterService) buildSystemPrompt(prompt string) string {
|
||
if len(prompt) == 0 {
|
||
prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON:{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式,禁用markdown格式返回\n3.只返回json字符串,不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"
|
||
}
|
||
|
||
return prompt
|
||
}
|
||
|
||
func (r *AiRouterService) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
|
||
for _, item := range his {
|
||
if len(chatHis.SessionId) == 0 {
|
||
chatHis.SessionId = item.SessionID
|
||
}
|
||
chatHis.Messages = append(chatHis.Messages, entitys.HisMessage{
|
||
Role: item.Role,
|
||
Content: item.Content,
|
||
Timestamp: item.CreateAt.Format("2006-01-02 15:04:05"),
|
||
})
|
||
}
|
||
chatHis.Context = entitys.HisContext{
|
||
UserLanguage: "zh-CN",
|
||
SystemMode: "technical_support",
|
||
}
|
||
return
|
||
}
|
||
|
||
// extractIntent 从AI响应中提取意图
|
||
func (r *AiRouterService) extractIntent(response *entitys.ChatResponse) string {
|
||
if response == nil || response.Message == "" {
|
||
return ""
|
||
}
|
||
|
||
// 尝试解析JSON
|
||
var intent struct {
|
||
Intent string `json:"intent"`
|
||
Confidence string `json:"confidence"`
|
||
Reasoning string `json:"reasoning"`
|
||
}
|
||
err := json.Unmarshal([]byte(response.Message), &intent)
|
||
if err != nil {
|
||
log.Printf("Failed to parse intent JSON: %v", err)
|
||
return ""
|
||
}
|
||
|
||
return intent.Intent
|
||
}
|
||
|
||
// handleOrderDiagnosis 处理订单诊断意图
|
||
func (r *AiRouterService) handleOrderDiagnosis(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
|
||
// 调用订单详情工具
|
||
//orderDetailTool, ok := r.toolManager.GetTool("zltxOrderDetail")
|
||
//if orderDetailTool == nil || !ok {
|
||
// return nil, fmt.Errorf("order detail tool not found")
|
||
//}
|
||
//orderDetailTool.Execute(ctx, json.RawMessage{})
|
||
//
|
||
//// 获取相关工具定义
|
||
//toolDefinitions := r.toolManager.GetToolDefinitions(constants.Caller(req.Caller))
|
||
//
|
||
//// 调用AI,获取是否需要使用工具
|
||
//response, err := r.aiClient.Chat(ctx, messages, toolDefinitions)
|
||
//if err != nil {
|
||
// return nil, fmt.Errorf("failed to chat with AI: %w", err)
|
||
//}
|
||
//
|
||
//// 如果没有工具调用,直接返回
|
||
//if len(response.ToolCalls) == 0 {
|
||
// return response, nil
|
||
//}
|
||
//
|
||
//// 执行工具调用
|
||
//toolResults, err := r.toolManager.ExecuteToolCalls(ctx, response.ToolCalls)
|
||
//if err != nil {
|
||
// return nil, fmt.Errorf("failed to execute tools: %w", err)
|
||
//}
|
||
|
||
return nil, nil
|
||
}
|
||
|
||
// handleKnowledgeQA 处理知识问答意图
|
||
func (r *AiRouterService) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
|
||
|
||
return nil, nil
|
||
}
|