结构修改

This commit is contained in:
renzhiyuan 2025-09-22 11:46:06 +08:00
parent 32866d59a1
commit 570e13e527
5 changed files with 97 additions and 39 deletions

View File

@ -25,8 +25,8 @@ import (
"xorm.io/builder" "xorm.io/builder"
) )
// AiRouterService 智能路由服务 // AiRouterBiz 智能路由服务
type AiRouterService struct { type AiRouterBiz struct {
//aiClient entitys.AIClient //aiClient entitys.AIClient
toolManager *tools.Manager toolManager *tools.Manager
sessionImpl *impl.SessionImpl sessionImpl *impl.SessionImpl
@ -48,8 +48,8 @@ func NewAiRouterBiz(
conf *config.Config, conf *config.Config,
utilAgent *utils_ollama.UtilOllama, utilAgent *utils_ollama.UtilOllama,
) entitys.RouterService { ) *AiRouterBiz {
return &AiRouterService{ return &AiRouterBiz{
//aiClient: aiClient, //aiClient: aiClient,
toolManager: toolManager, toolManager: toolManager,
sessionImpl: sessionImpl, sessionImpl: sessionImpl,
@ -62,13 +62,13 @@ func NewAiRouterBiz(
} }
// Route 执行智能路由 // Route 执行智能路由
func (r *AiRouterService) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) { func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) {
return nil, nil return nil, nil
} }
// Route 执行智能路由 // Route 执行智能路由
func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error { func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error {
session := c.Headers("X-Session", "") session := c.Headers("X-Session", "")
if len(session) == 0 { if len(session) == 0 {
@ -201,7 +201,7 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
return nil return nil
} }
func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) { func (r *AiRouterBiz) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) {
defer func() { defer func() {
if err != nil { if err != nil {
c.WriteMessage(websocket.TextMessage, []byte(err.Error())) c.WriteMessage(websocket.TextMessage, []byte(err.Error()))
@ -235,7 +235,7 @@ func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Matc
return return
} }
func (r *AiRouterService) handleTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { func (r *AiRouterBiz) handleTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
var configData entitys.ConfigDataTool var configData entitys.ConfigDataTool
err = json.Unmarshal([]byte(task.Config), &configData) err = json.Unmarshal([]byte(task.Config), &configData)
@ -250,14 +250,14 @@ func (r *AiRouterService) handleTask(c *websocket.Conn, matchJson *entitys.Match
return return
} }
func (r *AiRouterService) handleOtherTask(c *websocket.Conn, matchJson *entitys.Match) (err error) { func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, matchJson *entitys.Match) (err error) {
c.WriteMessage(1, []byte(matchJson.Reasoning)) c.WriteMessage(1, []byte(matchJson.Reasoning))
return return
} }
func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { func (r *AiRouterBiz) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
var ( var (
request l_request.Request request l_request.Request
auth = c.Headers("X-Authorization", "") auth = c.Headers("X-Authorization", "")
@ -292,7 +292,7 @@ func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Ma
return return
} }
func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) { func (r *AiRouterBiz) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
cond := builder.NewCond() cond := builder.NewCond()
cond = cond.And(builder.Eq{"session_id": sessionId}) cond = cond.And(builder.Eq{"session_id": sessionId})
@ -302,7 +302,7 @@ func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiCha
return return
} }
func (r *AiRouterService) getSysInfo(appKey string) (sysInfo model.AiSy, err error) { func (r *AiRouterBiz) getSysInfo(appKey string) (sysInfo model.AiSy, err error) {
cond := builder.NewCond() cond := builder.NewCond()
cond = cond.And(builder.Eq{"app_key": appKey}) cond = cond.And(builder.Eq{"app_key": appKey})
cond = cond.And(builder.IsNull{"delete_at"}) cond = cond.And(builder.IsNull{"delete_at"})
@ -311,7 +311,7 @@ func (r *AiRouterService) getSysInfo(appKey string) (sysInfo model.AiSy, err err
return return
} }
func (r *AiRouterService) getTasks(sysId int32) (tasks []model.AiTask, err error) { func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
cond := builder.NewCond() cond := builder.NewCond()
cond = cond.And(builder.Eq{"sys_id": sysId}) cond = cond.And(builder.Eq{"sys_id": sysId})
@ -322,7 +322,7 @@ func (r *AiRouterService) getTasks(sysId int32) (tasks []model.AiTask, err error
return return
} }
func (r *AiRouterService) registerTools(tasks []model.AiTask) []llms.Tool { func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool {
taskPrompt := make([]llms.Tool, 0) taskPrompt := make([]llms.Tool, 0)
for _, task := range tasks { for _, task := range tasks {
var taskConfig entitys.TaskConfig var taskConfig entitys.TaskConfig
@ -343,7 +343,7 @@ func (r *AiRouterService) registerTools(tasks []model.AiTask) []llms.Tool {
return taskPrompt return taskPrompt
} }
func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message { func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message {
var ( var (
prompt = make([]entitys.Message, 0) prompt = make([]entitys.Message, 0)
) )
@ -360,7 +360,7 @@ func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi
return prompt return prompt
} }
func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
var ( var (
prompt = make([]llms.MessageContent, 0) prompt = make([]llms.MessageContent, 0)
) )
@ -389,7 +389,7 @@ func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiCha
} }
// buildSystemPrompt 构建系统提示词 // buildSystemPrompt 构建系统提示词
func (r *AiRouterService) buildSystemPrompt(prompt string) string { func (r *AiRouterBiz) buildSystemPrompt(prompt string) string {
if len(prompt) == 0 { if len(prompt) == 0 {
prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式禁用markdown格式返回\n3.只返回json字符串不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容" prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式禁用markdown格式返回\n3.只返回json字符串不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"
} }
@ -397,7 +397,7 @@ func (r *AiRouterService) buildSystemPrompt(prompt string) string {
return prompt return prompt
} }
func (r *AiRouterService) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) { func (r *AiRouterBiz) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
for _, item := range his { for _, item := range his {
if len(chatHis.SessionId) == 0 { if len(chatHis.SessionId) == 0 {
chatHis.SessionId = item.SessionID chatHis.SessionId = item.SessionID
@ -416,7 +416,7 @@ func (r *AiRouterService) buildAssistant(his []model.AiChatHi) (chatHis entitys.
} }
// handleKnowledgeQA 处理知识问答意图 // handleKnowledgeQA 处理知识问答意图
func (r *AiRouterService) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) { func (r *AiRouterBiz) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
return nil, nil return nil, nil
} }

View File

@ -0,0 +1,54 @@
package biz
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/utils_ollama"
"ai_scheduler/internal/tools"
"ai_scheduler/utils"
"encoding/json"
"flag"
"testing"
"github.com/gofiber/fiber/v2/log"
)
func Test_task(t *testing.T) {
var c entitys.TaskConfig
config := `{"param": {"type": "object", "required": ["number"], "properties": {"number": {"type": "string", "description": "订单编号/流水号"}}}, "request": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}`
err := json.Unmarshal([]byte(config), &c)
t.Log(err)
}
type configData struct {
Param map[string]interface{} `json:"param"`
Do map[string]interface{} `json:"do"`
}
func Test_Order(t *testing.T) {
routerBiz := in()
err := routerBiz.handleTask(nil, &entitys.Match{Index: "order_diagnosis"}, &model.AiTask{Config: `{"tool": "zltxOrderDetail", "param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}}`})
t.Log(err)
}
func in() *AiRouterBiz {
configPath := flag.String("config", "./config/config.yaml", "Path to configuration file")
flag.Parse()
configConfig, err := config.LoadConfig(*configPath)
if err != nil {
panic("加载配置失败")
}
allLogger := log.DefaultLogger()
manager := tools.NewManager(configConfig)
db, _ := utils.NewGormDb(configConfig)
sessionImpl := impl.NewSessionImpl(db)
sysImpl := impl.NewSysImpl(db)
taskImpl := impl.NewTaskImpl(db)
chatImpl := impl.NewChatImpl(db)
utilOllama := utils_ollama.NewUtilOllama(configConfig, allLogger)
routerBiz := NewAiRouterBiz(manager, sessionImpl, sysImpl, taskImpl, chatImpl, configConfig, utilOllama)
return routerBiz
}

View File

@ -1,6 +1,7 @@
package services package services
import ( import (
"ai_scheduler/internal/biz"
"ai_scheduler/internal/data/constant" "ai_scheduler/internal/data/constant"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/gateway" "ai_scheduler/internal/gateway"
@ -17,15 +18,15 @@ import (
// ChatHandler 聊天处理器 // ChatHandler 聊天处理器
type ChatService struct { type ChatService struct {
routerService entitys.RouterService routerBiz *biz.AiRouterBiz
Gw *gateway.Gateway Gw *gateway.Gateway
mu sync.Mutex mu sync.Mutex
} }
// NewChatHandler 创建聊天处理器 // NewChatHandler 创建聊天处理器
func NewChatService(routerService entitys.RouterService, gw *gateway.Gateway) *ChatService { func NewChatService(routerService *biz.AiRouterBiz, gw *gateway.Gateway) *ChatService {
return &ChatService{ return &ChatService{
routerService: routerService, routerBiz: routerService,
Gw: gw, Gw: gw,
} }
} }
@ -100,7 +101,7 @@ func (h *ChatService) Chat(c *websocket.Conn) {
log.Println("JSON parse error:", err) log.Println("JSON parse error:", err)
continue continue
} }
err = h.routerService.RouteWithSocket(c, &req) err = h.routerBiz.RouteWithSocket(c, &req)
if err != nil { if err != nil {
log.Println("处理失败:", err) log.Println("处理失败:", err)
continue continue

View File

@ -27,7 +27,7 @@ func Test_task2(t *testing.T) {
config configData config configData
) )
configJson := `{"param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}, "do": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}` configJson := `{"tool": "zltxOrderDetail", "param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}}`
err := json.Unmarshal([]byte(configJson), &config) err := json.Unmarshal([]byte(configJson), &config)
if err != nil { if err != nil {
panic(err) panic(err)
@ -35,3 +35,7 @@ func Test_task2(t *testing.T) {
mapstructure.Decode(config.Do, &c) mapstructure.Decode(config.Do, &c)
t.Log(err) t.Log(err)
} }
func in() {
}

View File

@ -6,8 +6,8 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/websocket/v2" "github.com/gofiber/websocket/v2"
) )
@ -87,21 +87,20 @@ func (w *ZltxOrderDetailTool) Execute(c *websocket.Conn, args json.RawMessage) e
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据 // getMockZltxOrderDetail 获取模拟直连天下订单详情数据
func (w *ZltxOrderDetailTool) getZltxOrderDetail(c *websocket.Conn, number string) (err error) { func (w *ZltxOrderDetailTool) getZltxOrderDetail(c *websocket.Conn, number string) (err error) {
url := fmt.Sprintf("%s/admin/direct/ai/%s", w.config.BaseURL, number) //查询订单详情
authorization := fmt.Sprintf("Bearer %s", w.config.APIKey) req := l_request.Request{
Url: fmt.Sprintf("%s/admin/direct/ai/%s", w.config.BaseURL, number),
Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", w.config.APIKey),
},
Method: "GET",
}
res, err := req.Send()
// 发送http请求
req, err := http.NewRequest("GET", url, nil)
if err != nil { if err != nil {
return return
} }
req.Header.Set("Authorization", authorization) c.WriteMessage(websocket.TextMessage, res.Content)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
defer resp.Body.Close()
return return
} }