结构修改

This commit is contained in:
renzhiyuan 2025-09-30 17:53:55 +08:00
parent 8b57afd572
commit 3058a1e6b9
15 changed files with 488 additions and 395 deletions

View File

@ -4,8 +4,8 @@ server:
host: "0.0.0.0" host: "0.0.0.0"
ollama: ollama:
base_url: "http://localhost:11434" base_url: "http://127.0.0.1:11434"
model: "qwen3:8b" model: "qwen3-coder:480b-cloud"
timeout: "120s" timeout: "120s"
level: "info" level: "info"
format: "json" format: "json"

View File

@ -0,0 +1,38 @@
package llm_service
import (
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"context"
)
type LlmService interface {
IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (string, error)
}
// buildSystemPrompt 构建系统提示词
func buildSystemPrompt(prompt string) string {
if len(prompt) == 0 {
prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式禁用markdown格式返回\n3.只返回json字符串不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"
}
return prompt
}
func buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
for _, item := range his {
if len(chatHis.SessionId) == 0 {
chatHis.SessionId = item.SessionID
}
chatHis.Messages = append(chatHis.Messages, entitys.HisMessage{
Role: item.Role,
Content: item.Content,
Timestamp: item.CreateAt.Format("2006-01-02 15:04:05"),
})
}
chatHis.Context = entitys.HisContext{
UserLanguage: "zh-CN",
SystemMode: "technical_support",
}
return
}

View File

@ -0,0 +1,86 @@
package llm_service
import (
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/utils_langchain"
"context"
"encoding/json"
"github.com/tmc/langchaingo/llms"
)
type LangChainService struct {
client *utils_langchain.UtilLangChain
}
func NewLangChainGenerate(
client *utils_langchain.UtilLangChain,
) *LangChainService {
return &LangChainService{
client: client,
}
}
func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) {
prompt := r.getPrompt(sysInfo, history, userInput, tasks)
AgentClient := r.client.Get()
defer r.client.Put(AgentClient)
match, err := AgentClient.Llm.GenerateContent(
ctx, // 使用可取消的上下文
prompt,
llms.WithJSONMode(),
)
msg = match.Choices[0].Content
return
}
func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
var (
prompt = make([]llms.MessageContent, 0)
)
prompt = append(prompt, llms.MessageContent{
Role: llms.ChatMessageTypeSystem,
Parts: []llms.ContentPart{
llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)),
},
}, llms.MessageContent{
Role: llms.ChatMessageTypeTool,
Parts: []llms.ContentPart{
llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))),
},
}, llms.MessageContent{
Role: llms.ChatMessageTypeTool,
Parts: []llms.ContentPart{
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
},
}, llms.MessageContent{
Role: llms.ChatMessageTypeHuman,
Parts: []llms.ContentPart{
llms.TextPart(reqInput),
},
})
return prompt
}
func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool {
taskPrompt := make([]llms.Tool, 0)
for _, task := range tasks {
var taskConfig entitys.TaskConfig
err := json.Unmarshal([]byte(task.Config), &taskConfig)
if err != nil {
continue
}
taskPrompt = append(taskPrompt, llms.Tool{
Type: "function",
Function: &llms.FunctionDefinition{
Name: task.Index,
Description: task.Desc,
Parameters: taskConfig.Param,
},
})
}
return taskPrompt
}

View File

@ -0,0 +1,79 @@
package llm_service
import (
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/utils_ollama"
"context"
"encoding/json"
"fmt"
"github.com/ollama/ollama/api"
)
type OllamaService struct {
client *utils_ollama.Client
}
func NewOllamaGenerate(
client *utils_ollama.Client,
) *OllamaService {
return &OllamaService{
client: client,
}
}
func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) {
prompt := r.getPrompt(requireData.Sys, requireData.Histories, requireData.UserInput, requireData.Tasks)
toolDefinitions := r.registerToolsOllama(requireData.Tasks)
match, err := r.client.ToolSelect(context.TODO(), prompt, toolDefinitions)
if err != nil {
return
}
msg = match.Message.Content
return
}
func (r *OllamaService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []api.Message {
var (
prompt = make([]api.Message, 0)
)
prompt = append(prompt, api.Message{
Role: "system",
Content: buildSystemPrompt(sysInfo.SysPrompt),
}, api.Message{
Role: "assistant",
Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(buildAssistant(history))),
}, api.Message{
Role: "user",
Content: reqInput,
})
return prompt
}
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
taskPrompt := make([]api.Tool, 0)
for _, task := range tasks {
var taskConfig entitys.TaskConfigDetail
err := json.Unmarshal([]byte(task.Config), &taskConfig)
if err != nil {
continue
}
taskPrompt = append(taskPrompt, api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: task.Index,
Description: task.Desc,
Parameters: api.ToolFunctionParameters{
Type: taskConfig.Param.Type,
Required: taskConfig.Param.Required,
Properties: taskConfig.Param.Properties,
},
},
})
}
return taskPrompt
}

View File

@ -1,5 +1,15 @@
package biz package biz
import "github.com/google/wire" import (
"ai_scheduler/internal/biz/llm_service"
var ProviderSetBiz = wire.NewSet(NewAiRouterBiz, NewSessionBiz, NewChatHistoryBiz) "github.com/google/wire"
)
var ProviderSetBiz = wire.NewSet(
NewAiRouterBiz,
NewSessionBiz,
NewChatHistoryBiz,
llm_service.NewLangChainGenerate,
llm_service.NewOllamaGenerate,
)

View File

@ -1,6 +1,7 @@
package biz package biz
import ( import (
"ai_scheduler/internal/biz/llm_service"
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/constants"
errors "ai_scheduler/internal/data/error" errors "ai_scheduler/internal/data/error"
@ -9,8 +10,8 @@ import (
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/mapstructure" "ai_scheduler/internal/pkg/mapstructure"
"ai_scheduler/internal/pkg/utils_ollama"
tools "ai_scheduler/internal/tools" "ai_scheduler/internal/tools"
"ai_scheduler/tmpl/dataTemp" "ai_scheduler/tmpl/dataTemp"
"context" "context"
"encoding/json" "encoding/json"
@ -21,96 +22,108 @@ import (
"gitea.cdlsxd.cn/self-tools/l_request" "gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/fiber/v2/log" "github.com/gofiber/fiber/v2/log"
"github.com/gofiber/websocket/v2" "github.com/gofiber/websocket/v2"
"github.com/ollama/ollama/api"
"github.com/tmc/langchaingo/llms"
"xorm.io/builder" "xorm.io/builder"
) )
// AiRouterBiz 智能路由服务 // AiRouterBiz 智能路由服务
type AiRouterBiz struct { type AiRouterBiz struct {
//aiClient entitys.AIClient
toolManager *tools.Manager toolManager *tools.Manager
sessionImpl *impl.SessionImpl sessionImpl *impl.SessionImpl
sysImpl *impl.SysImpl sysImpl *impl.SysImpl
taskImpl *impl.TaskImpl taskImpl *impl.TaskImpl
hisImpl *impl.ChatImpl hisImpl *impl.ChatImpl
conf *config.Config conf *config.Config
utilAgent *utils_ollama.UtilOllama
ollama *utils_ollama.Client
channelPool *pkg.SafeChannelPool
rds *pkg.Rdb rds *pkg.Rdb
langChain *llm_service.LangChainService
Ollama *llm_service.OllamaService
} }
// NewRouterService 创建路由服务 // NewRouterService 创建路由服务
func NewAiRouterBiz( func NewAiRouterBiz(
//aiClient entitys.AIClient,
toolManager *tools.Manager, toolManager *tools.Manager,
sessionImpl *impl.SessionImpl, sessionImpl *impl.SessionImpl,
sysImpl *impl.SysImpl, sysImpl *impl.SysImpl,
taskImpl *impl.TaskImpl, taskImpl *impl.TaskImpl,
hisImpl *impl.ChatImpl, hisImpl *impl.ChatImpl,
conf *config.Config, conf *config.Config,
utilAgent *utils_ollama.UtilOllama, langChain *llm_service.LangChainService,
channelPool *pkg.SafeChannelPool, Ollama *llm_service.OllamaService,
ollama *utils_ollama.Client,
) *AiRouterBiz { ) *AiRouterBiz {
return &AiRouterBiz{ return &AiRouterBiz{
//aiClient: aiClient,
toolManager: toolManager, toolManager: toolManager,
sessionImpl: sessionImpl, sessionImpl: sessionImpl,
conf: conf, conf: conf,
sysImpl: sysImpl, sysImpl: sysImpl,
hisImpl: hisImpl, hisImpl: hisImpl,
taskImpl: taskImpl, taskImpl: taskImpl,
utilAgent: utilAgent, langChain: langChain,
channelPool: channelPool, Ollama: Ollama,
ollama: ollama,
} }
} }
// Route 执行智能路由
func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) {
return nil, nil
}
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) { func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
//必要数据验证和获取
session := c.Query("x-session", "") var requireData entitys.RequireData
if len(session) == 0 { err = r.dataAuth(c, &requireData)
return errors.SessionNotFound if err != nil {
} return
auth := c.Query("x-authorization", "")
if len(auth) == 0 {
return errors.AuthNotFound
}
key := c.Query("x-app-key", "")
if len(key) == 0 {
return errors.KeyNotFound
} }
var chat = make([]string, 0) //初始化通道/上下文
requireData.Ch = make(chan entitys.Response)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() // 启动独立的消息处理协程
//ch := r.channelPool.Get() done := r.startMessageHandler(ctx, c, &requireData)
ch := make(chan entitys.Response) defer func() {
close(requireData.Ch) //关闭主通道
<-done // 等待消息处理完成
cancel()
}()
//获取系统信息
err = r.getRequireData(req.Text, &requireData)
if err != nil {
log.Errorf("SQL error: %v", err)
return
}
//意图识别
err = r.recognize(ctx, &requireData)
if err != nil {
log.Errorf("LLM error: %v", err)
return
}
//向下传递
if err = r.handleMatch(ctx, &requireData); err != nil {
log.Errorf("Handle error: %v", err)
return
}
return
}
// startMessageHandler 启动独立的消息处理协程
func (r *AiRouterBiz) startMessageHandler(
ctx context.Context,
c *websocket.Conn,
requireData *entitys.RequireData,
) <-chan struct{} {
done := make(chan struct{}) done := make(chan struct{})
var chat []string
go func() { go func() {
defer func() { defer func() {
close(done) close(done)
// 保存历史记录
var his = []*model.AiChatHi{ var his = []*model.AiChatHi{
{ {
SessionID: session, SessionID: requireData.Session,
Role: "user", Role: "user",
Content: req.Text, Content: "", // 用户输入在外部处理
}, },
} }
if len(chat) > 0 { if len(chat) > 0 {
his = append(his, &model.AiChatHi{ his = append(his, &model.AiChatHi{
SessionID: session, SessionID: requireData.Session,
Role: "assistant", Role: "assistant",
Content: strings.Join(chat, ""), Content: strings.Join(chat, ""),
}) })
@ -119,92 +132,19 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe
r.hisImpl.Add(hi) r.hisImpl.Add(hi)
} }
}() }()
for {
select {
case v, ok := <-ch:
if !ok {
return
}
// 带超时的发送,避免阻塞
if err = sendWithTimeout(c, v, 2*time.Second); err != nil {
log.Errorf("Send error: %v", err)
cancel() // 通知主流程退出
return
}
if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson {
chat = append(chat, v.Content)
}
case <-ctx.Done(): for v := range requireData.Ch { // 自动检测通道关闭
if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
log.Errorf("Send error: %v", err)
return return
} }
if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson {
chat = append(chat, v.Content)
}
} }
}() }()
defer func() { return done
close(ch)
}()
sysInfo, err := r.getSysInfo(key)
if err != nil {
return errors.SysErr("获取系统信息失败:%v", err.Error())
}
history, err := r.getSessionChatHis(session)
if err != nil {
return errors.SysErr("获取历史记录失败:%v", err.Error())
}
task, err := r.getTasks(sysInfo.SysID)
if err != nil {
return errors.SysErr("获取任务列表失败:%v", err.Error())
}
prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
AgentClient := r.utilAgent.Get()
ch <- entitys.Response{
Index: "",
Content: "准备意图识别",
Type: entitys.ResponseLog,
}
match, err := AgentClient.Llm.GenerateContent(
ctx, // 使用可取消的上下文
prompt,
llms.WithJSONMode(),
)
resMsg := match.Choices[0].Content
r.utilAgent.Put(AgentClient)
ch <- entitys.Response{
Index: "",
Content: resMsg,
Type: entitys.ResponseLog,
}
ch <- entitys.Response{
Index: "",
Content: "意图识别结束",
Type: entitys.ResponseLog,
}
if err != nil {
log.Errorf("LLM error: %v", err)
return errors.SystemError
}
var matchJson entitys.Match
if err := json.Unmarshal([]byte(resMsg), &matchJson); err != nil {
log.Info(resMsg)
return errors.SysErr("数据结构错误:%v", err.Error())
}
matchJson.History = pkg.JsonByteIgonErr(history)
matchJson.UserInput = req.Text
if err := r.handleMatch(ctx, c, ch, &matchJson, task, sysInfo); err != nil {
return err
}
return nil
} }
// 辅助函数:带超时的 WebSocket 发送 // 辅助函数:带超时的 WebSocket 发送
@ -218,8 +158,11 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura
if r := recover(); r != nil { if r := recover(); r != nil {
done <- fmt.Errorf("panic in MsgSend: %v", r) done <- fmt.Errorf("panic in MsgSend: %v", r)
} }
close(done)
}() }()
done <- entitys.MsgSend(c, data) // 如果 MsgSend 阻塞,这里会卡住
err := entitys.MsgSend(c, data)
done <- err
}() }()
select { select {
@ -229,58 +172,135 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura
return sendCtx.Err() return sendCtx.Err()
} }
} }
func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.Response, matchJson *entitys.Match) (err error) {
ch <- entitys.Response{ func (r *AiRouterBiz) recognize(ctx context.Context, requireData *entitys.RequireData) (err error) {
requireData.Ch <- entitys.Response{
Index: "", Index: "",
Content: matchJson.Reasoning, Content: "准备意图识别",
Type: entitys.ResponseLog,
}
//意图识别
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData)
if err != nil {
return
}
requireData.Ch <- entitys.Response{
Index: "",
Content: recognizeMsg,
Type: entitys.ResponseLog,
}
requireData.Ch <- entitys.Response{
Index: "",
Content: "意图识别结束",
Type: entitys.ResponseLog,
}
if err = json.Unmarshal([]byte(recognizeMsg), requireData.Match); err != nil {
err = errors.SysErr("数据结构错误:%v", err.Error())
return
}
return
}
func (r *AiRouterBiz) getRequireData(userInput string, requireData *entitys.RequireData) (err error) {
requireData.Sys, err = r.getSysInfo(requireData.Key)
if err != nil {
err = errors.SysErr("获取系统信息失败:%v", err.Error())
return
}
requireData.Histories, err = r.getSessionChatHis(requireData.Session)
if err != nil {
err = errors.SysErr("获取历史记录失败:%v", err.Error())
return
}
requireData.Tasks, err = r.getTasks(requireData.Sys.SysID)
if err != nil {
err = errors.SysErr("获取任务列表失败:%v", err.Error())
return
}
requireData.UserInput = userInput
if len(requireData.UserInput) == 0 {
err = errors.SysErr("获取用户输入失败")
return
}
if len(requireData.UserInput) == 0 {
err = errors.SysErr("获取用户输入失败")
return
}
return
}
func (r *AiRouterBiz) dataAuth(c *websocket.Conn, requireData *entitys.RequireData) (err error) {
requireData.Session = c.Query("x-session", "")
if len(requireData.Session) == 0 {
err = errors.SessionNotFound
return
}
requireData.Auth = c.Query("x-authorization", "")
if len(requireData.Auth) == 0 {
err = errors.AuthNotFound
return
}
requireData.Key = c.Query("x-app-key", "")
if len(requireData.Key) == 0 {
err = errors.KeyNotFound
return
}
return
}
func (r *AiRouterBiz) handleOtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
requireData.Ch <- entitys.Response{
Index: "",
Content: requireData.Match.Reasoning,
Type: entitys.ResponseText, Type: entitys.ResponseText,
} }
return return
} }
func (r *AiRouterBiz) handleMatch(ctx context.Context, c *websocket.Conn, ch chan entitys.Response, matchJson *entitys.Match, tasks []model.AiTask, sysInfo model.AiSy) (err error) { func (r *AiRouterBiz) handleMatch(ctx context.Context, requireData *entitys.RequireData) (err error) {
if !matchJson.IsMatch { if !requireData.Match.IsMatch {
_ = entitys.MsgSend(c, entitys.Response{ requireData.Ch <- entitys.Response{
Index: "", Index: "",
Content: matchJson.Reasoning, Content: requireData.Match.Reasoning,
Type: entitys.ResponseText, Type: entitys.ResponseText,
}) }
return return
} }
var pointTask *model.AiTask var pointTask *model.AiTask
for _, task := range tasks { for _, task := range requireData.Tasks {
if task.Index == matchJson.Index { if task.Index == requireData.Match.Index {
pointTask = &task pointTask = &task
break break
} }
} }
if pointTask == nil || pointTask.Index == "other" { if pointTask == nil || pointTask.Index == "other" {
return r.handleOtherTask(c, ch, matchJson) return r.handleOtherTask(ctx, requireData)
} }
switch pointTask.Type { switch pointTask.Type {
case constants.TaskTypeApi: case constants.TaskTypeApi:
return r.handleApiTask(ch, c, matchJson, pointTask) return r.handleApiTask(ctx, requireData, pointTask)
case constants.TaskTypeFunc: case constants.TaskTypeFunc:
return r.handleTask(ch, c, matchJson, pointTask) return r.handleTask(ctx, requireData, pointTask)
case constants.TaskTypeKnowle: case constants.TaskTypeKnowle:
return r.handleKnowle(ch, c, matchJson, pointTask, sysInfo) return r.handleKnowle(ctx, requireData, pointTask)
default: default:
return r.handleOtherTask(c, ch, matchJson) return r.handleOtherTask(ctx, requireData)
} }
} }
func (r *AiRouterBiz) handleTask(channel chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { func (r *AiRouterBiz) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
var configData entitys.ConfigDataTool var configData entitys.ConfigDataTool
err = json.Unmarshal([]byte(task.Config), &configData) err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil { if err != nil {
return return
} }
err = r.toolManager.ExecuteTool(channel, c, configData.Tool, []byte(matchJson.Parameters), matchJson) err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
if err != nil { if err != nil {
return return
} }
@ -289,7 +309,7 @@ func (r *AiRouterBiz) handleTask(channel chan entitys.Response, c *websocket.Con
} }
// 知识库 // 知识库
func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask, sysInfo model.AiSy) (err error) { func (r *AiRouterBiz) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
var ( var (
configData entitys.ConfigDataTool configData entitys.ConfigDataTool
@ -303,11 +323,11 @@ func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.C
} }
// 通过session 找到知识库session // 通过session 找到知识库session
session := c.Query("x-session", "") var has bool
if len(session) == 0 { if len(requireData.Session) == 0 {
return errors.SessionNotFound return errors.SessionNotFound
} }
sessionInfo, has, err := r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(session)) requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session))
if err != nil { if err != nil {
return return
} else if !has { } else if !has {
@ -330,15 +350,15 @@ func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.C
} }
// 知识库的session为空请求知识库获取, 并绑定 // 知识库的session为空请求知识库获取, 并绑定
if sessionInfo.KnowlegeSessionID == "" { if requireData.SessionInfo.KnowlegeSessionID == "" {
// 请求知识库 // 请求知识库
if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, sysInfo.KnowlegeBaseID, sysInfo.KnowlegeTenantKey); err != nil { if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil {
return return
} }
// 绑定知识库session下次可以使用 // 绑定知识库session下次可以使用
sessionInfo.KnowlegeSessionID = sessionIdKnowledge requireData.SessionInfo.KnowlegeSessionID = sessionIdKnowledge
if err = r.sessionImpl.Update(&sessionInfo, r.sessionImpl.WithSessionId(sessionInfo.SessionID)); err != nil { if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil {
return return
} }
} }
@ -346,25 +366,21 @@ func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.C
// 用户输入解析 // 用户输入解析
var ok bool var ok bool
input := make(map[string]string) input := make(map[string]string)
if err = json.Unmarshal([]byte(matchJson.Parameters), &input); err != nil { if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil {
return return
} }
if query, ok = input["query"]; !ok { if query, ok = input["query"]; !ok {
return fmt.Errorf("query不能为空") return fmt.Errorf("query不能为空")
} }
knowledgeConfig := tools.KnowledgeBaseRequest{ requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{
Session: sessionInfo.KnowlegeSessionID, Session: requireData.SessionInfo.KnowlegeSessionID,
ApiKey: sysInfo.KnowlegeTenantKey, ApiKey: requireData.Sys.KnowlegeTenantKey,
Query: query, Query: query,
} }
b, err := json.Marshal(knowledgeConfig)
if err != nil {
return
}
// 执行工具 // 执行工具
err = r.toolManager.ExecuteTool(channel, c, configData.Tool, b, nil) err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
if err != nil { if err != nil {
return return
} }
@ -372,17 +388,16 @@ func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.C
return return
} }
func (r *AiRouterBiz) handleApiTask(channels chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { func (r *AiRouterBiz) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
var ( var (
request l_request.Request request l_request.Request
auth = c.Query("x-authorization", "")
requestParam map[string]interface{} requestParam map[string]interface{}
) )
err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam) err = json.Unmarshal([]byte(requireData.Match.Parameters), &requestParam)
if err != nil { if err != nil {
return return
} }
request.Url = strings.ReplaceAll(task.Config, "${authorization}", auth) request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth)
for k, v := range requestParam { for k, v := range requestParam {
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v)) task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
} }
@ -403,7 +418,11 @@ func (r *AiRouterBiz) handleApiTask(channels chan entitys.Response, c *websocket
if err != nil { if err != nil {
return return
} }
c.WriteMessage(1, res.Content) requireData.Ch <- entitys.Response{
Index: "",
Content: pkg.JsonStringIgonErr(res.Text),
Type: entitys.ResponseJson,
}
return return
} }
@ -436,119 +455,3 @@ func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
return return
} }
func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool {
taskPrompt := make([]llms.Tool, 0)
for _, task := range tasks {
var taskConfig entitys.TaskConfig
err := json.Unmarshal([]byte(task.Config), &taskConfig)
if err != nil {
continue
}
taskPrompt = append(taskPrompt, llms.Tool{
Type: "function",
Function: &llms.FunctionDefinition{
Name: task.Index,
Description: task.Desc,
Parameters: taskConfig.Param,
},
})
}
return taskPrompt
}
func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message {
var (
prompt = make([]entitys.Message, 0)
)
prompt = append(prompt, entitys.Message{
Role: "system",
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
}, entitys.Message{
Role: "assistant",
Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
}, entitys.Message{
Role: "user",
Content: reqInput,
})
return prompt
}
func (r *AiRouterBiz) getPromptOllama(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []api.Message {
var (
prompt = make([]api.Message, 0)
)
prompt = append(prompt, api.Message{
Role: "system",
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
}, api.Message{
Role: "assistant",
Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
}, api.Message{
Role: "user",
Content: reqInput,
})
return prompt
}
func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
var (
prompt = make([]llms.MessageContent, 0)
)
prompt = append(prompt, llms.MessageContent{
Role: llms.ChatMessageTypeSystem,
Parts: []llms.ContentPart{
llms.TextPart(r.buildSystemPrompt(sysInfo.SysPrompt)),
},
}, llms.MessageContent{
Role: llms.ChatMessageTypeTool,
Parts: []llms.ContentPart{
llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))),
},
}, llms.MessageContent{
Role: llms.ChatMessageTypeTool,
Parts: []llms.ContentPart{
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
},
}, llms.MessageContent{
Role: llms.ChatMessageTypeHuman,
Parts: []llms.ContentPart{
llms.TextPart(reqInput),
},
})
return prompt
}
// buildSystemPrompt 构建系统提示词
func (r *AiRouterBiz) buildSystemPrompt(prompt string) string {
if len(prompt) == 0 {
prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式禁用markdown格式返回\n3.只返回json字符串不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"
}
return prompt
}
func (r *AiRouterBiz) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
for _, item := range his {
if len(chatHis.SessionId) == 0 {
chatHis.SessionId = item.SessionID
}
chatHis.Messages = append(chatHis.Messages, entitys.HisMessage{
Role: item.Role,
Content: item.Content,
Timestamp: item.CreateAt.Format("2006-01-02 15:04:05"),
})
}
chatHis.Context = entitys.HisContext{
UserLanguage: "zh-CN",
SystemMode: "technical_support",
}
return
}
// handleKnowledgeQA 处理知识问答意图
func (r *AiRouterBiz) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
return nil, nil
}

View File

@ -1,6 +1,8 @@
package entitys package entitys
import ( import (
"ai_scheduler/internal/data/model"
"context" "context"
"encoding/json" "encoding/json"
@ -73,7 +75,7 @@ type Tool interface {
Name() string Name() string
Description() string Description() string
Definition() ToolDefinition Definition() ToolDefinition
Execute(channel chan Response, c *websocket.Conn, args json.RawMessage, matchJson *Match) error Execute(ctx context.Context, requireData *RequireData) error
} }
type ConfigDataHttp struct { type ConfigDataHttp struct {
@ -118,6 +120,7 @@ type Match struct {
Reasoning string `json:"reasoning"` Reasoning string `json:"reasoning"`
History []byte `json:"history"` History []byte `json:"history"`
UserInput string `json:"user_input"` UserInput string `json:"user_input"`
Auth string `json:"auth"`
} }
type ChatHis struct { type ChatHis struct {
SessionId string `json:"session_id"` SessionId string `json:"session_id"`
@ -135,8 +138,27 @@ type HisContext struct {
SystemMode string `json:"system_mode"` SystemMode string `json:"system_mode"`
} }
type RequireData struct {
Session string
Key string
Sys model.AiSy
Histories []model.AiChatHi
SessionInfo model.AiSession
Tasks []model.AiTask
Match *Match
UserInput string
Auth string
Ch chan Response
KnowledgeConf KnowledgeBaseRequest
}
type KnowledgeBaseRequest struct {
Session string // 知识库会话id
ApiKey string // 知识库apiKey
Query string // 用户输入
}
// RouterService 路由服务接口 // RouterService 路由服务接口
type RouterService interface { type RouterService interface {
Route(ctx context.Context, req *ChatRequest) (*ChatResponse, error)
RouteWithSocket(c *websocket.Conn, req *ChatSockRequest) error RouteWithSocket(c *websocket.Conn, req *ChatSockRequest) error
} }

View File

@ -1,6 +1,7 @@
package pkg package pkg
import ( import (
"ai_scheduler/internal/pkg/utils_langchain"
"ai_scheduler/internal/pkg/utils_ollama" "ai_scheduler/internal/pkg/utils_ollama"
"github.com/google/wire" "github.com/google/wire"
@ -9,7 +10,7 @@ import (
var ProviderSetClient = wire.NewSet( var ProviderSetClient = wire.NewSet(
NewRdb, NewRdb,
NewGormDb, NewGormDb,
utils_ollama.NewUtilOllama, utils_langchain.NewUtilLangChain,
utils_ollama.NewClient, utils_ollama.NewClient,
NewSafeChannelPool, NewSafeChannelPool,
) )

View File

@ -1,4 +1,4 @@
package utils_ollama package utils_langchain
import ( import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
@ -13,7 +13,7 @@ import (
"github.com/tmc/langchaingo/llms/ollama" "github.com/tmc/langchaingo/llms/ollama"
) )
type UtilOllama struct { type UtilLangChain struct {
LlmClientPool *sync.Pool LlmClientPool *sync.Pool
poolSize int // 记录池大小,用于调试 poolSize int // 记录池大小,用于调试
model string model string
@ -26,7 +26,7 @@ type LlmObj struct {
Llm *ollama.LLM Llm *ollama.LLM
} }
func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama { func NewUtilLangChain(c *config.Config, logger log.AllLogger) *UtilLangChain {
poolSize := c.Sys.LlmPoolLen poolSize := c.Sys.LlmPoolLen
if poolSize <= 0 { if poolSize <= 0 {
poolSize = 10 // 默认值 poolSize = 10 // 默认值
@ -60,7 +60,7 @@ func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
pool.Put(pool.New()) pool.Put(pool.New())
} }
return &UtilOllama{ return &UtilLangChain{
LlmClientPool: pool, LlmClientPool: pool,
poolSize: poolSize, poolSize: poolSize,
model: c.Ollama.Model, model: c.Ollama.Model,
@ -69,7 +69,7 @@ func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
} }
func (o *UtilOllama) NewClient() *ollama.LLM { func (o *UtilLangChain) NewClient() *ollama.LLM {
llm, _ := ollama.New( llm, _ := ollama.New(
ollama.WithModel(o.c.Ollama.Model), ollama.WithModel(o.c.Ollama.Model),
ollama.WithHTTPClient(&http.Client{ ollama.WithHTTPClient(&http.Client{
@ -91,13 +91,13 @@ func (o *UtilOllama) NewClient() *ollama.LLM {
} }
// Get 返回一个可用的 LLM 客户端 // Get 返回一个可用的 LLM 客户端
func (o *UtilOllama) Get() *LlmObj { func (o *UtilLangChain) Get() *LlmObj {
client := o.LlmClientPool.Get().(*LlmObj) client := o.LlmClientPool.Get().(*LlmObj)
return client return client
} }
// Put 归还客户端(可选:检查是否仍可用) // Put 归还客户端(可选:检查是否仍可用)
func (o *UtilOllama) Put(llm *LlmObj) { func (o *UtilLangChain) Put(llm *LlmObj) {
if llm == nil { if llm == nil {
return return
} }
@ -105,7 +105,7 @@ func (o *UtilOllama) Put(llm *LlmObj) {
} }
// Stats 返回池的统计信息(用于监控) // Stats 返回池的统计信息(用于监控)
func (o *UtilOllama) Stats() (current, max int) { func (o *UtilLangChain) Stats() (current, max int) {
return o.poolSize, o.poolSize return o.poolSize, o.poolSize
} }

View File

@ -5,13 +5,11 @@ import (
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/l_request" "ai_scheduler/internal/pkg/l_request"
"bufio" "bufio"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/websocket/v2"
) )
// 知识库工具 // 知识库工具
@ -60,22 +58,10 @@ func (k *KnowledgeBaseTool) Definition() entitys.ToolDefinition {
} }
// Execute 执行知识库查询 // Execute 执行知识库查询
func (k *KnowledgeBaseTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error { func (k *KnowledgeBaseTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
var params KnowledgeBaseRequest return k.chat(requireData)
if err := json.Unmarshal(args, &params); err != nil {
return fmt.Errorf("unmarshal args failed: %w", err)
}
log.Info("开始执行知识库 KnowledgeBaseTool Execute, params: %v", params)
return k.chat(channel, c, params)
}
type KnowledgeBaseRequest struct {
Session string // 知识库会话id
ApiKey string // 知识库apiKey
Query string // 用户输入
} }
// Message 表示解析后的 SSE 消息 // Message 表示解析后的 SSE 消息
@ -110,20 +96,20 @@ func (this *KnowledgeBaseTool) msgContentParse(input string, channel chan entity
} }
// 请求知识库聊天 // 请求知识库聊天
func (this *KnowledgeBaseTool) chat(channel chan entitys.Response, c *websocket.Conn, param KnowledgeBaseRequest) (err error) { func (this *KnowledgeBaseTool) chat(requireData *entitys.RequireData) (err error) {
req := l_request.Request{ req := l_request.Request{
Method: "post", Method: "post",
Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + param.Session, Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + requireData.KnowledgeConf.Session,
Params: nil, Params: nil,
Headers: map[string]string{ Headers: map[string]string{
"Content-Type": "application/json", "Content-Type": "application/json",
"X-API-Key": param.ApiKey, "X-API-Key": requireData.KnowledgeConf.ApiKey,
}, },
Cookies: nil, Cookies: nil,
Data: nil, Data: nil,
Json: map[string]interface{}{ Json: map[string]interface{}{
"query": param.Query, "query": requireData.KnowledgeConf.Query,
}, },
Files: nil, Files: nil,
Raw: "", Raw: "",
@ -137,7 +123,7 @@ func (this *KnowledgeBaseTool) chat(channel chan entitys.Response, c *websocket.
} }
defer rsp.Body.Close() defer rsp.Body.Close()
err = this.connectAndReadSSE(rsp, channel) err = this.connectAndReadSSE(rsp, requireData.Ch)
if err != nil { if err != nil {
return return
} }

View File

@ -5,11 +5,9 @@ import (
"ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/constants"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/utils_ollama" "ai_scheduler/internal/pkg/utils_ollama"
"context"
"encoding/json"
"fmt" "fmt"
"github.com/gofiber/websocket/v2"
) )
// Manager 工具管理器 // Manager 工具管理器
@ -100,13 +98,13 @@ func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefi
} }
// ExecuteTool 执行工具 // ExecuteTool 执行工具
func (m *Manager) ExecuteTool(channel chan entitys.Response, c *websocket.Conn, name string, args json.RawMessage, matchJson *entitys.Match) error { func (m *Manager) ExecuteTool(ctx context.Context, name string, requireData *entitys.RequireData) error {
tool, exists := m.GetTool(name) tool, exists := m.GetTool(name)
if !exists { if !exists {
return fmt.Errorf("tool not found: %s", name) return fmt.Errorf("tool not found: %s", name)
} }
return tool.Execute(channel, c, args, matchJson) return tool.Execute(ctx, requireData)
} }
// ExecuteToolCalls 执行多个工具调用 // ExecuteToolCalls 执行多个工具调用

View File

@ -3,13 +3,13 @@ package tools
import ( import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/utils_ollama" "ai_scheduler/internal/pkg/utils_ollama"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"gitea.cdlsxd.cn/self-tools/l_request" "gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/websocket/v2"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
@ -81,34 +81,26 @@ type ZltxOrderDetailData struct {
} }
// Execute 执行直连天下订单详情查询 // Execute 执行直连天下订单详情查询
func (w *ZltxOrderDetailTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error { func (w *ZltxOrderDetailTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
var req ZltxOrderDetailRequest var req ZltxOrderDetailRequest
if err := json.Unmarshal(args, &req); err != nil { if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
return fmt.Errorf("invalid zltxOrderDetail request: %w", err) return fmt.Errorf("invalid zltxOrderDetail request: %w", err)
} }
if req.OrderNumber == "" { if req.OrderNumber == "" {
return fmt.Errorf("number is required") return fmt.Errorf("number is required")
} }
// 这里可以集成真实的直连天下订单详情API // 这里可以集成真实的直连天下订单详情API
return w.getZltxOrderDetail(channel, c, req.OrderNumber, matchJson) return w.getZltxOrderDetail(requireData, req.OrderNumber)
} }
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据 // getMockZltxOrderDetail 获取模拟直连天下订单详情数据
func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *websocket.Conn, number string, matchJson *entitys.Match) (err error) { func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireData, number string) (err error) {
//查询订单详情 //查询订单详情
var auth string
if c != nil {
auth = c.Query("x-authorization", "")
}
if len(auth) == 0 {
auth = w.config.APIKey
}
req := l_request.Request{ req := l_request.Request{
Url: fmt.Sprintf("%szltx_api/admin/direct/ai/%s", w.config.BaseURL, number), Url: fmt.Sprintf("%szltx_api/admin/direct/ai/%s", w.config.BaseURL, number),
Headers: map[string]string{ Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", auth), "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
}, },
Method: "GET", Method: "GET",
} }
@ -129,13 +121,13 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we
if err = json.Unmarshal(res.Content, &resData); err != nil { if err = json.Unmarshal(res.Content, &resData); err != nil {
return return
} }
ch <- entitys.Response{ requireData.Ch <- entitys.Response{
Index: w.Name(), Index: w.Name(),
Content: res.Text, Content: res.Text,
Type: entitys.ResponseJson, Type: entitys.ResponseJson,
} }
if resData.Data.Direct != nil && resData.Data.Direct["needAi"].(bool) { if resData.Data.Direct != nil && resData.Data.Direct["needAi"].(bool) {
ch <- entitys.Response{ requireData.Ch <- entitys.Response{
Index: w.Name(), Index: w.Name(),
Content: "正在分析订单日志", Content: "正在分析订单日志",
Type: entitys.ResponseLoading, Type: entitys.ResponseLoading,
@ -144,7 +136,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we
req = l_request.Request{ req = l_request.Request{
Url: fmt.Sprintf("%szltx_api/admin/direct/log/%s/%s", w.config.BaseURL, resData.Data.Direct["orderOrderNumber"].(string), resData.Data.Direct["serialNumber"].(string)), Url: fmt.Sprintf("%szltx_api/admin/direct/log/%s/%s", w.config.BaseURL, resData.Data.Direct["orderOrderNumber"].(string), resData.Data.Direct["serialNumber"].(string)),
Headers: map[string]string{ Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", auth), "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
}, },
Method: "GET", Method: "GET",
} }
@ -164,14 +156,14 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we
return fmt.Errorf("订单日志解析失败:%s", err) return fmt.Errorf("订单日志解析失败:%s", err)
} }
err = w.llm.ChatStream(context.TODO(), ch, []api.Message{ err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{
{ {
Role: "system", Role: "system",
Content: "你是一个订单日志助手。用户可能会提供订单日志,你需要分析订单日志,提取出订单失败的原因。", Content: "你是一个订单日志助手。用户可能会提供订单日志,你需要分析订单日志,提取出订单失败的原因。",
}, },
{ {
Role: "assistant", Role: "assistant",
Content: fmt.Sprintf("聊天记录:%s", string(matchJson.History)), Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)),
}, },
{ {
Role: "assistant", Role: "assistant",
@ -179,7 +171,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we
}, },
{ {
Role: "user", Role: "user",
Content: matchJson.UserInput, Content: requireData.UserInput,
}, },
}, w.Name()) }, w.Name())
if err != nil { if err != nil {
@ -187,15 +179,11 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we
} }
} }
if resData.Data.Direct == nil { if resData.Data.Direct == nil {
ch <- entitys.Response{ requireData.Ch <- entitys.Response{
Index: w.Name(), Index: w.Name(),
Content: "该订单无充值流水记录,不需要进行订单错误分析,建议检查订单收单模式以及扣款方式等原因😘", Content: "该订单无充值流水记录,不需要进行订单错误分析,建议检查订单收单模式以及扣款方式等原因😘",
Type: entitys.ResponseText, Type: entitys.ResponseText,
} }
} }
// else {
//}
return return
} }

View File

@ -3,11 +3,11 @@ package tools
import ( import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"gitea.cdlsxd.cn/self-tools/l_request" "gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/websocket/v2"
) )
type ZltxOrderLogTool struct { type ZltxOrderLogTool struct {
@ -67,31 +67,25 @@ type ZltxOrderDirectLogData struct {
Data map[string]interface{} `json:"data"` Data map[string]interface{} `json:"data"`
} }
func (t *ZltxOrderLogTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error { func (t *ZltxOrderLogTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
var req ZltxOrderLogRequest var req ZltxOrderLogRequest
if err := json.Unmarshal(args, &req); err != nil { if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
return fmt.Errorf("invalid zltxOrderLog request: %w", err) return fmt.Errorf("invalid zltxOrderLog request: %w", err)
} }
if req.OrderNumber == "" || req.SerialNumber == "" { if req.OrderNumber == "" || req.SerialNumber == "" {
return fmt.Errorf("orderNumber and serialNumber is required") return fmt.Errorf("orderNumber and serialNumber is required")
} }
return t.getZltxOrderLog(channel, c, req.OrderNumber, req.SerialNumber) return t.getZltxOrderLog(req.OrderNumber, req.SerialNumber, requireData)
} }
func (t *ZltxOrderLogTool) getZltxOrderLog(channel chan entitys.Response, c *websocket.Conn, orderNumber, serialNumber string) (err error) { func (t *ZltxOrderLogTool) getZltxOrderLog(orderNumber, serialNumber string, requireData *entitys.RequireData) (err error) {
//查询订单详情 //查询订单详情
var auth string
if c != nil {
auth = c.Query("x-authorization", "")
}
if len(auth) == 0 {
auth = t.config.APIKey
}
url := fmt.Sprintf("%s%s/%s", t.config.BaseURL, orderNumber, serialNumber) url := fmt.Sprintf("%s%s/%s", t.config.BaseURL, orderNumber, serialNumber)
req := l_request.Request{ req := l_request.Request{
Url: url, Url: url,
Headers: map[string]string{ Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", auth), "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
}, },
Method: "GET", Method: "GET",
} }
@ -106,7 +100,7 @@ func (t *ZltxOrderLogTool) getZltxOrderLog(channel chan entitys.Response, c *web
if err = json.Unmarshal(res.Content, &resData); err != nil { if err = json.Unmarshal(res.Content, &resData); err != nil {
return return
} }
channel <- entitys.Response{ requireData.Ch <- entitys.Response{
Index: t.Name(), Index: t.Name(),
Content: res.Text, Content: res.Text,
Type: entitys.ResponseJson, Type: entitys.ResponseJson,

View File

@ -3,13 +3,13 @@ package tools
import ( import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"gitea.cdlsxd.cn/self-tools/l_request" "gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/websocket/v2"
) )
type ZltxProductTool struct { type ZltxProductTool struct {
@ -53,12 +53,12 @@ type ZltxProductRequest struct {
Name string `json:"name"` Name string `json:"name"`
} }
func (z ZltxProductTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error { func (z ZltxProductTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
var req ZltxProductRequest var req ZltxProductRequest
if err := json.Unmarshal(args, &req); err != nil { if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
return fmt.Errorf("invalid zltxProduct request: %w", err) return fmt.Errorf("invalid zltxProduct request: %w", err)
} }
return z.getZltxProduct(channel, c, req.Id, req.Name) return z.getZltxProduct(&req, requireData)
} }
type ZltxProductResponse struct { type ZltxProductResponse struct {
@ -133,22 +133,16 @@ type ZltxProductData struct {
PlatformProductList interface{} `json:"platform_product_list"` PlatformProductList interface{} `json:"platform_product_list"`
} }
func (z ZltxProductTool) getZltxProduct(channel chan entitys.Response, c *websocket.Conn, id string, name string) error { func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *entitys.RequireData) error {
var auth string
if c != nil {
auth = c.Query("x-authorization", "")
}
if len(auth) == 0 {
auth = z.config.APIKey
}
var Url string var Url string
var params map[string]string var params map[string]string
if id != "" { if body.Id != "" {
Url = fmt.Sprintf("%s/%s", z.config.BaseURL, id) Url = fmt.Sprintf("%s/%s", z.config.BaseURL, body.Id)
} else { } else {
Url = fmt.Sprintf("%s?keyword=%s&limit=10&page=1", z.config.BaseURL, name) Url = fmt.Sprintf("%s?keyword=%s&limit=10&page=1", z.config.BaseURL, body.Name)
params = map[string]string{ params = map[string]string{
"keyword": name, "keyword": body.Name,
"limit": "10", "limit": "10",
"page": "1", "page": "1",
} }
@ -159,7 +153,7 @@ func (z ZltxProductTool) getZltxProduct(channel chan entitys.Response, c *websoc
//根据商品ID或名称走不同的接口查询 //根据商品ID或名称走不同的接口查询
Url: Url, Url: Url,
Headers: map[string]string{ Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", auth), "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
}, },
Params: params, Params: params,
Method: "GET", Method: "GET",
@ -191,7 +185,7 @@ func (z ZltxProductTool) getZltxProduct(channel chan entitys.Response, c *websoc
for i := range resp.Data.List { for i := range resp.Data.List {
// 调用 平台商品列表 // 调用 平台商品列表
if resp.Data.List[i].AuthProductIds != "" { if resp.Data.List[i].AuthProductIds != "" {
platformProductList := z.ExecutePlatformProductList(auth, resp.Data.List[i].AuthProductIds, resp.Data.List[i].OfficialProductID) platformProductList := z.ExecutePlatformProductList(requireData.Auth, resp.Data.List[i].AuthProductIds, resp.Data.List[i].OfficialProductID)
resp.Data.List[i].PlatformProductList = platformProductList resp.Data.List[i].PlatformProductList = platformProductList
} }
} }
@ -200,7 +194,7 @@ func (z ZltxProductTool) getZltxProduct(channel chan entitys.Response, c *websoc
if err != nil { if err != nil {
return err return err
} }
channel <- entitys.Response{ requireData.Ch <- entitys.Response{
Index: z.Name(), Index: z.Name(),
Content: string(marshal), Content: string(marshal),
Type: entitys.ResponseJson, Type: entitys.ResponseJson,

View File

@ -3,12 +3,12 @@ package tools
import ( import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"sort" "sort"
"gitea.cdlsxd.cn/self-tools/l_request" "gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/websocket/v2"
) )
type ZltxOrderStatisticsTool struct { type ZltxOrderStatisticsTool struct {
@ -47,15 +47,15 @@ type ZltxOrderStatisticsRequest struct {
Number string `json:"number"` Number string `json:"number"`
} }
func (z ZltxOrderStatisticsTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error { func (z ZltxOrderStatisticsTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
var req ZltxOrderStatisticsRequest var req ZltxOrderStatisticsRequest
if err := json.Unmarshal(args, &req); err != nil { if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
return err return err
} }
if req.Number == "" { if req.Number == "" {
return fmt.Errorf("number is required") return fmt.Errorf("number is required")
} }
return z.getZltxOrderStatistics(channel, c, req.Number) return z.getZltxOrderStatistics(req.Number, requireData)
} }
type ZltxOrderStatisticsResponse struct { type ZltxOrderStatisticsResponse struct {
@ -75,20 +75,14 @@ type ZltxOrderStatisticsData struct {
Total int `json:"total"` Total int `json:"total"`
} }
func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(channel chan entitys.Response, c *websocket.Conn, number string) error { func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(number string, requireData *entitys.RequireData) error {
//查询订单详情 //查询订单详情
var auth string
if c != nil {
auth = c.Query("x-authorization", "")
}
if len(auth) == 0 {
auth = z.config.APIKey
}
url := fmt.Sprintf("%s%s", z.config.BaseURL, number) url := fmt.Sprintf("%s%s", z.config.BaseURL, number)
req := l_request.Request{ req := l_request.Request{
Url: url, Url: url,
Headers: map[string]string{ Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", auth), "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
}, },
Method: "GET", Method: "GET",
} }
@ -114,7 +108,7 @@ func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(channel chan entitys.Res
if err != nil { if err != nil {
return err return err
} }
channel <- entitys.Response{ requireData.Ch <- entitys.Response{
Index: z.Name(), Index: z.Name(),
Content: string(jsonByte), Content: string(jsonByte),
Type: entitys.ResponseJson, Type: entitys.ResponseJson,