结构修改
This commit is contained in:
parent
32866d59a1
commit
570e13e527
|
@ -25,8 +25,8 @@ import (
|
|||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// AiRouterService 智能路由服务
|
||||
type AiRouterService struct {
|
||||
// AiRouterBiz 智能路由服务
|
||||
type AiRouterBiz struct {
|
||||
//aiClient entitys.AIClient
|
||||
toolManager *tools.Manager
|
||||
sessionImpl *impl.SessionImpl
|
||||
|
@ -48,8 +48,8 @@ func NewAiRouterBiz(
|
|||
conf *config.Config,
|
||||
utilAgent *utils_ollama.UtilOllama,
|
||||
|
||||
) entitys.RouterService {
|
||||
return &AiRouterService{
|
||||
) *AiRouterBiz {
|
||||
return &AiRouterBiz{
|
||||
//aiClient: aiClient,
|
||||
toolManager: toolManager,
|
||||
sessionImpl: sessionImpl,
|
||||
|
@ -62,13 +62,13 @@ func NewAiRouterBiz(
|
|||
}
|
||||
|
||||
// Route 执行智能路由
|
||||
func (r *AiRouterService) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) {
|
||||
func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) {
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Route 执行智能路由
|
||||
func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error {
|
||||
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error {
|
||||
|
||||
session := c.Headers("X-Session", "")
|
||||
if len(session) == 0 {
|
||||
|
@ -201,7 +201,7 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) {
|
||||
func (r *AiRouterBiz) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
c.WriteMessage(websocket.TextMessage, []byte(err.Error()))
|
||||
|
@ -235,7 +235,7 @@ func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Matc
|
|||
return
|
||||
}
|
||||
|
||||
func (r *AiRouterService) handleTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
|
||||
func (r *AiRouterBiz) handleTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
|
||||
|
||||
var configData entitys.ConfigDataTool
|
||||
err = json.Unmarshal([]byte(task.Config), &configData)
|
||||
|
@ -250,14 +250,14 @@ func (r *AiRouterService) handleTask(c *websocket.Conn, matchJson *entitys.Match
|
|||
return
|
||||
}
|
||||
|
||||
func (r *AiRouterService) handleOtherTask(c *websocket.Conn, matchJson *entitys.Match) (err error) {
|
||||
func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, matchJson *entitys.Match) (err error) {
|
||||
|
||||
c.WriteMessage(1, []byte(matchJson.Reasoning))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
|
||||
func (r *AiRouterBiz) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
|
||||
var (
|
||||
request l_request.Request
|
||||
auth = c.Headers("X-Authorization", "")
|
||||
|
@ -292,7 +292,7 @@ func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Ma
|
|||
return
|
||||
}
|
||||
|
||||
func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
|
||||
func (r *AiRouterBiz) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
|
||||
|
||||
cond := builder.NewCond()
|
||||
cond = cond.And(builder.Eq{"session_id": sessionId})
|
||||
|
@ -302,7 +302,7 @@ func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiCha
|
|||
return
|
||||
}
|
||||
|
||||
func (r *AiRouterService) getSysInfo(appKey string) (sysInfo model.AiSy, err error) {
|
||||
func (r *AiRouterBiz) getSysInfo(appKey string) (sysInfo model.AiSy, err error) {
|
||||
cond := builder.NewCond()
|
||||
cond = cond.And(builder.Eq{"app_key": appKey})
|
||||
cond = cond.And(builder.IsNull{"delete_at"})
|
||||
|
@ -311,7 +311,7 @@ func (r *AiRouterService) getSysInfo(appKey string) (sysInfo model.AiSy, err err
|
|||
return
|
||||
}
|
||||
|
||||
func (r *AiRouterService) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
||||
func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
||||
|
||||
cond := builder.NewCond()
|
||||
cond = cond.And(builder.Eq{"sys_id": sysId})
|
||||
|
@ -322,7 +322,7 @@ func (r *AiRouterService) getTasks(sysId int32) (tasks []model.AiTask, err error
|
|||
return
|
||||
}
|
||||
|
||||
func (r *AiRouterService) registerTools(tasks []model.AiTask) []llms.Tool {
|
||||
func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool {
|
||||
taskPrompt := make([]llms.Tool, 0)
|
||||
for _, task := range tasks {
|
||||
var taskConfig entitys.TaskConfig
|
||||
|
@ -343,7 +343,7 @@ func (r *AiRouterService) registerTools(tasks []model.AiTask) []llms.Tool {
|
|||
return taskPrompt
|
||||
}
|
||||
|
||||
func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message {
|
||||
func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message {
|
||||
var (
|
||||
prompt = make([]entitys.Message, 0)
|
||||
)
|
||||
|
@ -360,7 +360,7 @@ func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi
|
|||
return prompt
|
||||
}
|
||||
|
||||
func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
||||
func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
||||
var (
|
||||
prompt = make([]llms.MessageContent, 0)
|
||||
)
|
||||
|
@ -389,7 +389,7 @@ func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiCha
|
|||
}
|
||||
|
||||
// buildSystemPrompt 构建系统提示词
|
||||
func (r *AiRouterService) buildSystemPrompt(prompt string) string {
|
||||
func (r *AiRouterBiz) buildSystemPrompt(prompt string) string {
|
||||
if len(prompt) == 0 {
|
||||
prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON:{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式,禁用markdown格式返回\n3.只返回json字符串,不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"
|
||||
}
|
||||
|
@ -397,7 +397,7 @@ func (r *AiRouterService) buildSystemPrompt(prompt string) string {
|
|||
return prompt
|
||||
}
|
||||
|
||||
func (r *AiRouterService) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
|
||||
func (r *AiRouterBiz) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
|
||||
for _, item := range his {
|
||||
if len(chatHis.SessionId) == 0 {
|
||||
chatHis.SessionId = item.SessionID
|
||||
|
@ -416,7 +416,7 @@ func (r *AiRouterService) buildAssistant(his []model.AiChatHi) (chatHis entitys.
|
|||
}
|
||||
|
||||
// handleKnowledgeQA 处理知识问答意图
|
||||
func (r *AiRouterService) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
|
||||
func (r *AiRouterBiz) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
package biz
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/data/impl"
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg/utils_ollama"
|
||||
"ai_scheduler/internal/tools"
|
||||
"ai_scheduler/utils"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
func Test_task(t *testing.T) {
|
||||
var c entitys.TaskConfig
|
||||
config := `{"param": {"type": "object", "required": ["number"], "properties": {"number": {"type": "string", "description": "订单编号/流水号"}}}, "request": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}`
|
||||
err := json.Unmarshal([]byte(config), &c)
|
||||
t.Log(err)
|
||||
}
|
||||
|
||||
type configData struct {
|
||||
Param map[string]interface{} `json:"param"`
|
||||
Do map[string]interface{} `json:"do"`
|
||||
}
|
||||
|
||||
func Test_Order(t *testing.T) {
|
||||
routerBiz := in()
|
||||
err := routerBiz.handleTask(nil, &entitys.Match{Index: "order_diagnosis"}, &model.AiTask{Config: `{"tool": "zltxOrderDetail", "param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}}`})
|
||||
t.Log(err)
|
||||
}
|
||||
|
||||
func in() *AiRouterBiz {
|
||||
configPath := flag.String("config", "./config/config.yaml", "Path to configuration file")
|
||||
flag.Parse()
|
||||
configConfig, err := config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
panic("加载配置失败")
|
||||
}
|
||||
allLogger := log.DefaultLogger()
|
||||
manager := tools.NewManager(configConfig)
|
||||
db, _ := utils.NewGormDb(configConfig)
|
||||
sessionImpl := impl.NewSessionImpl(db)
|
||||
sysImpl := impl.NewSysImpl(db)
|
||||
taskImpl := impl.NewTaskImpl(db)
|
||||
chatImpl := impl.NewChatImpl(db)
|
||||
utilOllama := utils_ollama.NewUtilOllama(configConfig, allLogger)
|
||||
routerBiz := NewAiRouterBiz(manager, sessionImpl, sysImpl, taskImpl, chatImpl, configConfig, utilOllama)
|
||||
|
||||
return routerBiz
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/biz"
|
||||
"ai_scheduler/internal/data/constant"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/gateway"
|
||||
|
@ -17,16 +18,16 @@ import (
|
|||
|
||||
// ChatHandler 聊天处理器
|
||||
type ChatService struct {
|
||||
routerService entitys.RouterService
|
||||
Gw *gateway.Gateway
|
||||
mu sync.Mutex
|
||||
routerBiz *biz.AiRouterBiz
|
||||
Gw *gateway.Gateway
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewChatHandler 创建聊天处理器
|
||||
func NewChatService(routerService entitys.RouterService, gw *gateway.Gateway) *ChatService {
|
||||
func NewChatService(routerService *biz.AiRouterBiz, gw *gateway.Gateway) *ChatService {
|
||||
return &ChatService{
|
||||
routerService: routerService,
|
||||
Gw: gw,
|
||||
routerBiz: routerService,
|
||||
Gw: gw,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -100,7 +101,7 @@ func (h *ChatService) Chat(c *websocket.Conn) {
|
|||
log.Println("JSON parse error:", err)
|
||||
continue
|
||||
}
|
||||
err = h.routerService.RouteWithSocket(c, &req)
|
||||
err = h.routerBiz.RouteWithSocket(c, &req)
|
||||
if err != nil {
|
||||
log.Println("处理失败:", err)
|
||||
continue
|
||||
|
|
|
@ -27,7 +27,7 @@ func Test_task2(t *testing.T) {
|
|||
config configData
|
||||
)
|
||||
|
||||
configJson := `{"param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}, "do": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}`
|
||||
configJson := `{"tool": "zltxOrderDetail", "param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}}`
|
||||
err := json.Unmarshal([]byte(configJson), &config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
@ -35,3 +35,7 @@ func Test_task2(t *testing.T) {
|
|||
mapstructure.Decode(config.Do, &c)
|
||||
t.Log(err)
|
||||
}
|
||||
|
||||
func in() {
|
||||
|
||||
}
|
|
@ -6,8 +6,8 @@ import (
|
|||
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
)
|
||||
|
||||
|
@ -87,21 +87,20 @@ func (w *ZltxOrderDetailTool) Execute(c *websocket.Conn, args json.RawMessage) e
|
|||
|
||||
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据
|
||||
func (w *ZltxOrderDetailTool) getZltxOrderDetail(c *websocket.Conn, number string) (err error) {
|
||||
url := fmt.Sprintf("%s/admin/direct/ai/%s", w.config.BaseURL, number)
|
||||
authorization := fmt.Sprintf("Bearer %s", w.config.APIKey)
|
||||
//查询订单详情
|
||||
req := l_request.Request{
|
||||
Url: fmt.Sprintf("%s/admin/direct/ai/%s", w.config.BaseURL, number),
|
||||
Headers: map[string]string{
|
||||
"Authorization": fmt.Sprintf("Bearer %s", w.config.APIKey),
|
||||
},
|
||||
Method: "GET",
|
||||
}
|
||||
res, err := req.Send()
|
||||
|
||||
// 发送http请求
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Authorization", authorization)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
c.WriteMessage(websocket.TextMessage, res.Content)
|
||||
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue