结构修改
This commit is contained in:
parent
8b57afd572
commit
3058a1e6b9
|
|
@ -4,8 +4,8 @@ server:
|
||||||
host: "0.0.0.0"
|
host: "0.0.0.0"
|
||||||
|
|
||||||
ollama:
|
ollama:
|
||||||
base_url: "http://localhost:11434"
|
base_url: "http://127.0.0.1:11434"
|
||||||
model: "qwen3:8b"
|
model: "qwen3-coder:480b-cloud"
|
||||||
timeout: "120s"
|
timeout: "120s"
|
||||||
level: "info"
|
level: "info"
|
||||||
format: "json"
|
format: "json"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,38 @@
|
||||||
|
package llm_service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/data/model"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LlmService interface {
|
||||||
|
IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildSystemPrompt 构建系统提示词
|
||||||
|
func buildSystemPrompt(prompt string) string {
|
||||||
|
if len(prompt) == 0 {
|
||||||
|
prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON:{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式,禁用markdown格式返回\n3.只返回json字符串,不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"
|
||||||
|
}
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
|
||||||
|
for _, item := range his {
|
||||||
|
if len(chatHis.SessionId) == 0 {
|
||||||
|
chatHis.SessionId = item.SessionID
|
||||||
|
}
|
||||||
|
chatHis.Messages = append(chatHis.Messages, entitys.HisMessage{
|
||||||
|
Role: item.Role,
|
||||||
|
Content: item.Content,
|
||||||
|
Timestamp: item.CreateAt.Format("2006-01-02 15:04:05"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
chatHis.Context = entitys.HisContext{
|
||||||
|
UserLanguage: "zh-CN",
|
||||||
|
SystemMode: "technical_support",
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,86 @@
|
||||||
|
package llm_service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/data/model"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"ai_scheduler/internal/pkg"
|
||||||
|
"ai_scheduler/internal/pkg/utils_langchain"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/tmc/langchaingo/llms"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LangChainService struct {
|
||||||
|
client *utils_langchain.UtilLangChain
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLangChainGenerate(
|
||||||
|
client *utils_langchain.UtilLangChain,
|
||||||
|
) *LangChainService {
|
||||||
|
return &LangChainService{
|
||||||
|
client: client,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) {
|
||||||
|
prompt := r.getPrompt(sysInfo, history, userInput, tasks)
|
||||||
|
AgentClient := r.client.Get()
|
||||||
|
defer r.client.Put(AgentClient)
|
||||||
|
match, err := AgentClient.Llm.GenerateContent(
|
||||||
|
ctx, // 使用可取消的上下文
|
||||||
|
prompt,
|
||||||
|
llms.WithJSONMode(),
|
||||||
|
)
|
||||||
|
msg = match.Choices[0].Content
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
||||||
|
var (
|
||||||
|
prompt = make([]llms.MessageContent, 0)
|
||||||
|
)
|
||||||
|
prompt = append(prompt, llms.MessageContent{
|
||||||
|
Role: llms.ChatMessageTypeSystem,
|
||||||
|
Parts: []llms.ContentPart{
|
||||||
|
llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)),
|
||||||
|
},
|
||||||
|
}, llms.MessageContent{
|
||||||
|
Role: llms.ChatMessageTypeTool,
|
||||||
|
Parts: []llms.ContentPart{
|
||||||
|
llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))),
|
||||||
|
},
|
||||||
|
}, llms.MessageContent{
|
||||||
|
Role: llms.ChatMessageTypeTool,
|
||||||
|
Parts: []llms.ContentPart{
|
||||||
|
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
|
||||||
|
},
|
||||||
|
}, llms.MessageContent{
|
||||||
|
Role: llms.ChatMessageTypeHuman,
|
||||||
|
Parts: []llms.ContentPart{
|
||||||
|
llms.TextPart(reqInput),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool {
|
||||||
|
taskPrompt := make([]llms.Tool, 0)
|
||||||
|
for _, task := range tasks {
|
||||||
|
var taskConfig entitys.TaskConfig
|
||||||
|
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
taskPrompt = append(taskPrompt, llms.Tool{
|
||||||
|
Type: "function",
|
||||||
|
Function: &llms.FunctionDefinition{
|
||||||
|
Name: task.Index,
|
||||||
|
Description: task.Desc,
|
||||||
|
Parameters: taskConfig.Param,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
return taskPrompt
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,79 @@
|
||||||
|
package llm_service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/data/model"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"ai_scheduler/internal/pkg"
|
||||||
|
"ai_scheduler/internal/pkg/utils_ollama"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OllamaService struct {
|
||||||
|
client *utils_ollama.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOllamaGenerate(
|
||||||
|
client *utils_ollama.Client,
|
||||||
|
) *OllamaService {
|
||||||
|
return &OllamaService{
|
||||||
|
client: client,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) {
|
||||||
|
prompt := r.getPrompt(requireData.Sys, requireData.Histories, requireData.UserInput, requireData.Tasks)
|
||||||
|
toolDefinitions := r.registerToolsOllama(requireData.Tasks)
|
||||||
|
match, err := r.client.ToolSelect(context.TODO(), prompt, toolDefinitions)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
msg = match.Message.Content
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OllamaService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []api.Message {
|
||||||
|
var (
|
||||||
|
prompt = make([]api.Message, 0)
|
||||||
|
)
|
||||||
|
prompt = append(prompt, api.Message{
|
||||||
|
Role: "system",
|
||||||
|
Content: buildSystemPrompt(sysInfo.SysPrompt),
|
||||||
|
}, api.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(buildAssistant(history))),
|
||||||
|
}, api.Message{
|
||||||
|
Role: "user",
|
||||||
|
Content: reqInput,
|
||||||
|
})
|
||||||
|
return prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
|
||||||
|
taskPrompt := make([]api.Tool, 0)
|
||||||
|
for _, task := range tasks {
|
||||||
|
var taskConfig entitys.TaskConfigDetail
|
||||||
|
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
taskPrompt = append(taskPrompt, api.Tool{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: task.Index,
|
||||||
|
Description: task.Desc,
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: taskConfig.Param.Type,
|
||||||
|
Required: taskConfig.Param.Required,
|
||||||
|
Properties: taskConfig.Param.Properties,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
return taskPrompt
|
||||||
|
}
|
||||||
|
|
@ -1,5 +1,15 @@
|
||||||
package biz
|
package biz
|
||||||
|
|
||||||
import "github.com/google/wire"
|
import (
|
||||||
|
"ai_scheduler/internal/biz/llm_service"
|
||||||
|
|
||||||
var ProviderSetBiz = wire.NewSet(NewAiRouterBiz, NewSessionBiz, NewChatHistoryBiz)
|
"github.com/google/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ProviderSetBiz = wire.NewSet(
|
||||||
|
NewAiRouterBiz,
|
||||||
|
NewSessionBiz,
|
||||||
|
NewChatHistoryBiz,
|
||||||
|
llm_service.NewLangChainGenerate,
|
||||||
|
llm_service.NewOllamaGenerate,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package biz
|
package biz
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"ai_scheduler/internal/biz/llm_service"
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/data/constants"
|
"ai_scheduler/internal/data/constants"
|
||||||
errors "ai_scheduler/internal/data/error"
|
errors "ai_scheduler/internal/data/error"
|
||||||
|
|
@ -9,8 +10,8 @@ import (
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"ai_scheduler/internal/pkg"
|
"ai_scheduler/internal/pkg"
|
||||||
"ai_scheduler/internal/pkg/mapstructure"
|
"ai_scheduler/internal/pkg/mapstructure"
|
||||||
"ai_scheduler/internal/pkg/utils_ollama"
|
|
||||||
tools "ai_scheduler/internal/tools"
|
"ai_scheduler/internal/tools"
|
||||||
"ai_scheduler/tmpl/dataTemp"
|
"ai_scheduler/tmpl/dataTemp"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
@ -21,96 +22,108 @@ import (
|
||||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
"github.com/gofiber/fiber/v2/log"
|
"github.com/gofiber/fiber/v2/log"
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
"github.com/tmc/langchaingo/llms"
|
|
||||||
"xorm.io/builder"
|
"xorm.io/builder"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AiRouterBiz 智能路由服务
|
// AiRouterBiz 智能路由服务
|
||||||
type AiRouterBiz struct {
|
type AiRouterBiz struct {
|
||||||
//aiClient entitys.AIClient
|
|
||||||
toolManager *tools.Manager
|
toolManager *tools.Manager
|
||||||
sessionImpl *impl.SessionImpl
|
sessionImpl *impl.SessionImpl
|
||||||
sysImpl *impl.SysImpl
|
sysImpl *impl.SysImpl
|
||||||
taskImpl *impl.TaskImpl
|
taskImpl *impl.TaskImpl
|
||||||
hisImpl *impl.ChatImpl
|
hisImpl *impl.ChatImpl
|
||||||
conf *config.Config
|
conf *config.Config
|
||||||
utilAgent *utils_ollama.UtilOllama
|
|
||||||
ollama *utils_ollama.Client
|
|
||||||
channelPool *pkg.SafeChannelPool
|
|
||||||
rds *pkg.Rdb
|
rds *pkg.Rdb
|
||||||
|
langChain *llm_service.LangChainService
|
||||||
|
Ollama *llm_service.OllamaService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRouterService 创建路由服务
|
// NewRouterService 创建路由服务
|
||||||
func NewAiRouterBiz(
|
func NewAiRouterBiz(
|
||||||
//aiClient entitys.AIClient,
|
|
||||||
toolManager *tools.Manager,
|
toolManager *tools.Manager,
|
||||||
sessionImpl *impl.SessionImpl,
|
sessionImpl *impl.SessionImpl,
|
||||||
sysImpl *impl.SysImpl,
|
sysImpl *impl.SysImpl,
|
||||||
taskImpl *impl.TaskImpl,
|
taskImpl *impl.TaskImpl,
|
||||||
hisImpl *impl.ChatImpl,
|
hisImpl *impl.ChatImpl,
|
||||||
conf *config.Config,
|
conf *config.Config,
|
||||||
utilAgent *utils_ollama.UtilOllama,
|
langChain *llm_service.LangChainService,
|
||||||
channelPool *pkg.SafeChannelPool,
|
Ollama *llm_service.OllamaService,
|
||||||
ollama *utils_ollama.Client,
|
|
||||||
|
|
||||||
) *AiRouterBiz {
|
) *AiRouterBiz {
|
||||||
return &AiRouterBiz{
|
return &AiRouterBiz{
|
||||||
//aiClient: aiClient,
|
|
||||||
toolManager: toolManager,
|
toolManager: toolManager,
|
||||||
sessionImpl: sessionImpl,
|
sessionImpl: sessionImpl,
|
||||||
conf: conf,
|
conf: conf,
|
||||||
sysImpl: sysImpl,
|
sysImpl: sysImpl,
|
||||||
hisImpl: hisImpl,
|
hisImpl: hisImpl,
|
||||||
taskImpl: taskImpl,
|
taskImpl: taskImpl,
|
||||||
utilAgent: utilAgent,
|
langChain: langChain,
|
||||||
channelPool: channelPool,
|
Ollama: Ollama,
|
||||||
ollama: ollama,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Route 执行智能路由
|
|
||||||
func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) {
|
|
||||||
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
|
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
|
||||||
|
//必要数据验证和获取
|
||||||
session := c.Query("x-session", "")
|
var requireData entitys.RequireData
|
||||||
if len(session) == 0 {
|
err = r.dataAuth(c, &requireData)
|
||||||
return errors.SessionNotFound
|
if err != nil {
|
||||||
}
|
return
|
||||||
auth := c.Query("x-authorization", "")
|
|
||||||
if len(auth) == 0 {
|
|
||||||
return errors.AuthNotFound
|
|
||||||
}
|
|
||||||
key := c.Query("x-app-key", "")
|
|
||||||
if len(key) == 0 {
|
|
||||||
return errors.KeyNotFound
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var chat = make([]string, 0)
|
//初始化通道/上下文
|
||||||
|
requireData.Ch = make(chan entitys.Response)
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
// 启动独立的消息处理协程
|
||||||
//ch := r.channelPool.Get()
|
done := r.startMessageHandler(ctx, c, &requireData)
|
||||||
ch := make(chan entitys.Response)
|
defer func() {
|
||||||
|
close(requireData.Ch) //关闭主通道
|
||||||
|
<-done // 等待消息处理完成
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
//获取系统信息
|
||||||
|
err = r.getRequireData(req.Text, &requireData)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("SQL error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
//意图识别
|
||||||
|
err = r.recognize(ctx, &requireData)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("LLM error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
//向下传递
|
||||||
|
if err = r.handleMatch(ctx, &requireData); err != nil {
|
||||||
|
log.Errorf("Handle error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// startMessageHandler 启动独立的消息处理协程
|
||||||
|
func (r *AiRouterBiz) startMessageHandler(
|
||||||
|
ctx context.Context,
|
||||||
|
c *websocket.Conn,
|
||||||
|
requireData *entitys.RequireData,
|
||||||
|
) <-chan struct{} {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
var chat []string
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer func() {
|
defer func() {
|
||||||
close(done)
|
close(done)
|
||||||
|
// 保存历史记录
|
||||||
var his = []*model.AiChatHi{
|
var his = []*model.AiChatHi{
|
||||||
{
|
{
|
||||||
SessionID: session,
|
SessionID: requireData.Session,
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: req.Text,
|
Content: "", // 用户输入在外部处理
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if len(chat) > 0 {
|
if len(chat) > 0 {
|
||||||
his = append(his, &model.AiChatHi{
|
his = append(his, &model.AiChatHi{
|
||||||
SessionID: session,
|
SessionID: requireData.Session,
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: strings.Join(chat, ""),
|
Content: strings.Join(chat, ""),
|
||||||
})
|
})
|
||||||
|
|
@ -119,92 +132,19 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe
|
||||||
r.hisImpl.Add(hi)
|
r.hisImpl.Add(hi)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
for {
|
|
||||||
select {
|
for v := range requireData.Ch { // 自动检测通道关闭
|
||||||
case v, ok := <-ch:
|
if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 带超时的发送,避免阻塞
|
|
||||||
if err = sendWithTimeout(c, v, 2*time.Second); err != nil {
|
|
||||||
log.Errorf("Send error: %v", err)
|
log.Errorf("Send error: %v", err)
|
||||||
cancel() // 通知主流程退出
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson {
|
if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson {
|
||||||
chat = append(chat, v.Content)
|
chat = append(chat, v.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
defer func() {
|
return done
|
||||||
close(ch)
|
|
||||||
}()
|
|
||||||
|
|
||||||
sysInfo, err := r.getSysInfo(key)
|
|
||||||
if err != nil {
|
|
||||||
return errors.SysErr("获取系统信息失败:%v", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
history, err := r.getSessionChatHis(session)
|
|
||||||
if err != nil {
|
|
||||||
return errors.SysErr("获取历史记录失败:%v", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
task, err := r.getTasks(sysInfo.SysID)
|
|
||||||
if err != nil {
|
|
||||||
return errors.SysErr("获取任务列表失败:%v", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
|
|
||||||
|
|
||||||
AgentClient := r.utilAgent.Get()
|
|
||||||
|
|
||||||
ch <- entitys.Response{
|
|
||||||
Index: "",
|
|
||||||
Content: "准备意图识别",
|
|
||||||
Type: entitys.ResponseLog,
|
|
||||||
}
|
|
||||||
|
|
||||||
match, err := AgentClient.Llm.GenerateContent(
|
|
||||||
ctx, // 使用可取消的上下文
|
|
||||||
prompt,
|
|
||||||
llms.WithJSONMode(),
|
|
||||||
)
|
|
||||||
resMsg := match.Choices[0].Content
|
|
||||||
|
|
||||||
r.utilAgent.Put(AgentClient)
|
|
||||||
ch <- entitys.Response{
|
|
||||||
Index: "",
|
|
||||||
Content: resMsg,
|
|
||||||
Type: entitys.ResponseLog,
|
|
||||||
}
|
|
||||||
ch <- entitys.Response{
|
|
||||||
Index: "",
|
|
||||||
Content: "意图识别结束",
|
|
||||||
Type: entitys.ResponseLog,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("LLM error: %v", err)
|
|
||||||
return errors.SystemError
|
|
||||||
}
|
|
||||||
var matchJson entitys.Match
|
|
||||||
if err := json.Unmarshal([]byte(resMsg), &matchJson); err != nil {
|
|
||||||
log.Info(resMsg)
|
|
||||||
return errors.SysErr("数据结构错误:%v", err.Error())
|
|
||||||
}
|
|
||||||
matchJson.History = pkg.JsonByteIgonErr(history)
|
|
||||||
matchJson.UserInput = req.Text
|
|
||||||
if err := r.handleMatch(ctx, c, ch, &matchJson, task, sysInfo); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 辅助函数:带超时的 WebSocket 发送
|
// 辅助函数:带超时的 WebSocket 发送
|
||||||
|
|
@ -218,8 +158,11 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
done <- fmt.Errorf("panic in MsgSend: %v", r)
|
done <- fmt.Errorf("panic in MsgSend: %v", r)
|
||||||
}
|
}
|
||||||
|
close(done)
|
||||||
}()
|
}()
|
||||||
done <- entitys.MsgSend(c, data)
|
// 如果 MsgSend 阻塞,这里会卡住
|
||||||
|
err := entitys.MsgSend(c, data)
|
||||||
|
done <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
|
@ -229,58 +172,135 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura
|
||||||
return sendCtx.Err()
|
return sendCtx.Err()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.Response, matchJson *entitys.Match) (err error) {
|
|
||||||
ch <- entitys.Response{
|
func (r *AiRouterBiz) recognize(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
||||||
|
requireData.Ch <- entitys.Response{
|
||||||
Index: "",
|
Index: "",
|
||||||
Content: matchJson.Reasoning,
|
Content: "准备意图识别",
|
||||||
|
Type: entitys.ResponseLog,
|
||||||
|
}
|
||||||
|
//意图识别
|
||||||
|
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
requireData.Ch <- entitys.Response{
|
||||||
|
Index: "",
|
||||||
|
Content: recognizeMsg,
|
||||||
|
Type: entitys.ResponseLog,
|
||||||
|
}
|
||||||
|
requireData.Ch <- entitys.Response{
|
||||||
|
Index: "",
|
||||||
|
Content: "意图识别结束",
|
||||||
|
Type: entitys.ResponseLog,
|
||||||
|
}
|
||||||
|
if err = json.Unmarshal([]byte(recognizeMsg), requireData.Match); err != nil {
|
||||||
|
err = errors.SysErr("数据结构错误:%v", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterBiz) getRequireData(userInput string, requireData *entitys.RequireData) (err error) {
|
||||||
|
requireData.Sys, err = r.getSysInfo(requireData.Key)
|
||||||
|
if err != nil {
|
||||||
|
err = errors.SysErr("获取系统信息失败:%v", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
requireData.Histories, err = r.getSessionChatHis(requireData.Session)
|
||||||
|
if err != nil {
|
||||||
|
err = errors.SysErr("获取历史记录失败:%v", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
requireData.Tasks, err = r.getTasks(requireData.Sys.SysID)
|
||||||
|
if err != nil {
|
||||||
|
err = errors.SysErr("获取任务列表失败:%v", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
requireData.UserInput = userInput
|
||||||
|
if len(requireData.UserInput) == 0 {
|
||||||
|
err = errors.SysErr("获取用户输入失败")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(requireData.UserInput) == 0 {
|
||||||
|
err = errors.SysErr("获取用户输入失败")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterBiz) dataAuth(c *websocket.Conn, requireData *entitys.RequireData) (err error) {
|
||||||
|
requireData.Session = c.Query("x-session", "")
|
||||||
|
if len(requireData.Session) == 0 {
|
||||||
|
err = errors.SessionNotFound
|
||||||
|
return
|
||||||
|
}
|
||||||
|
requireData.Auth = c.Query("x-authorization", "")
|
||||||
|
if len(requireData.Auth) == 0 {
|
||||||
|
err = errors.AuthNotFound
|
||||||
|
return
|
||||||
|
}
|
||||||
|
requireData.Key = c.Query("x-app-key", "")
|
||||||
|
if len(requireData.Key) == 0 {
|
||||||
|
err = errors.KeyNotFound
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterBiz) handleOtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
||||||
|
requireData.Ch <- entitys.Response{
|
||||||
|
Index: "",
|
||||||
|
Content: requireData.Match.Reasoning,
|
||||||
Type: entitys.ResponseText,
|
Type: entitys.ResponseText,
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *AiRouterBiz) handleMatch(ctx context.Context, c *websocket.Conn, ch chan entitys.Response, matchJson *entitys.Match, tasks []model.AiTask, sysInfo model.AiSy) (err error) {
|
func (r *AiRouterBiz) handleMatch(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
||||||
|
|
||||||
if !matchJson.IsMatch {
|
if !requireData.Match.IsMatch {
|
||||||
_ = entitys.MsgSend(c, entitys.Response{
|
requireData.Ch <- entitys.Response{
|
||||||
Index: "",
|
Index: "",
|
||||||
Content: matchJson.Reasoning,
|
Content: requireData.Match.Reasoning,
|
||||||
Type: entitys.ResponseText,
|
Type: entitys.ResponseText,
|
||||||
})
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var pointTask *model.AiTask
|
var pointTask *model.AiTask
|
||||||
for _, task := range tasks {
|
for _, task := range requireData.Tasks {
|
||||||
if task.Index == matchJson.Index {
|
if task.Index == requireData.Match.Index {
|
||||||
pointTask = &task
|
pointTask = &task
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if pointTask == nil || pointTask.Index == "other" {
|
if pointTask == nil || pointTask.Index == "other" {
|
||||||
return r.handleOtherTask(c, ch, matchJson)
|
return r.handleOtherTask(ctx, requireData)
|
||||||
}
|
}
|
||||||
switch pointTask.Type {
|
switch pointTask.Type {
|
||||||
case constants.TaskTypeApi:
|
case constants.TaskTypeApi:
|
||||||
return r.handleApiTask(ch, c, matchJson, pointTask)
|
return r.handleApiTask(ctx, requireData, pointTask)
|
||||||
case constants.TaskTypeFunc:
|
case constants.TaskTypeFunc:
|
||||||
return r.handleTask(ch, c, matchJson, pointTask)
|
return r.handleTask(ctx, requireData, pointTask)
|
||||||
case constants.TaskTypeKnowle:
|
case constants.TaskTypeKnowle:
|
||||||
return r.handleKnowle(ch, c, matchJson, pointTask, sysInfo)
|
return r.handleKnowle(ctx, requireData, pointTask)
|
||||||
default:
|
default:
|
||||||
return r.handleOtherTask(c, ch, matchJson)
|
return r.handleOtherTask(ctx, requireData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *AiRouterBiz) handleTask(channel chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
|
func (r *AiRouterBiz) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
||||||
var configData entitys.ConfigDataTool
|
var configData entitys.ConfigDataTool
|
||||||
err = json.Unmarshal([]byte(task.Config), &configData)
|
err = json.Unmarshal([]byte(task.Config), &configData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.toolManager.ExecuteTool(channel, c, configData.Tool, []byte(matchJson.Parameters), matchJson)
|
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -289,7 +309,7 @@ func (r *AiRouterBiz) handleTask(channel chan entitys.Response, c *websocket.Con
|
||||||
}
|
}
|
||||||
|
|
||||||
// 知识库
|
// 知识库
|
||||||
func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask, sysInfo model.AiSy) (err error) {
|
func (r *AiRouterBiz) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
configData entitys.ConfigDataTool
|
configData entitys.ConfigDataTool
|
||||||
|
|
@ -303,11 +323,11 @@ func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.C
|
||||||
}
|
}
|
||||||
|
|
||||||
// 通过session 找到知识库session
|
// 通过session 找到知识库session
|
||||||
session := c.Query("x-session", "")
|
var has bool
|
||||||
if len(session) == 0 {
|
if len(requireData.Session) == 0 {
|
||||||
return errors.SessionNotFound
|
return errors.SessionNotFound
|
||||||
}
|
}
|
||||||
sessionInfo, has, err := r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(session))
|
requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
} else if !has {
|
} else if !has {
|
||||||
|
|
@ -330,15 +350,15 @@ func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.C
|
||||||
}
|
}
|
||||||
|
|
||||||
// 知识库的session为空,请求知识库获取, 并绑定
|
// 知识库的session为空,请求知识库获取, 并绑定
|
||||||
if sessionInfo.KnowlegeSessionID == "" {
|
if requireData.SessionInfo.KnowlegeSessionID == "" {
|
||||||
// 请求知识库
|
// 请求知识库
|
||||||
if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, sysInfo.KnowlegeBaseID, sysInfo.KnowlegeTenantKey); err != nil {
|
if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 绑定知识库session,下次可以使用
|
// 绑定知识库session,下次可以使用
|
||||||
sessionInfo.KnowlegeSessionID = sessionIdKnowledge
|
requireData.SessionInfo.KnowlegeSessionID = sessionIdKnowledge
|
||||||
if err = r.sessionImpl.Update(&sessionInfo, r.sessionImpl.WithSessionId(sessionInfo.SessionID)); err != nil {
|
if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -346,25 +366,21 @@ func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.C
|
||||||
// 用户输入解析
|
// 用户输入解析
|
||||||
var ok bool
|
var ok bool
|
||||||
input := make(map[string]string)
|
input := make(map[string]string)
|
||||||
if err = json.Unmarshal([]byte(matchJson.Parameters), &input); err != nil {
|
if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if query, ok = input["query"]; !ok {
|
if query, ok = input["query"]; !ok {
|
||||||
return fmt.Errorf("query不能为空")
|
return fmt.Errorf("query不能为空")
|
||||||
}
|
}
|
||||||
|
|
||||||
knowledgeConfig := tools.KnowledgeBaseRequest{
|
requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{
|
||||||
Session: sessionInfo.KnowlegeSessionID,
|
Session: requireData.SessionInfo.KnowlegeSessionID,
|
||||||
ApiKey: sysInfo.KnowlegeTenantKey,
|
ApiKey: requireData.Sys.KnowlegeTenantKey,
|
||||||
Query: query,
|
Query: query,
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(knowledgeConfig)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 执行工具
|
// 执行工具
|
||||||
err = r.toolManager.ExecuteTool(channel, c, configData.Tool, b, nil)
|
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -372,17 +388,16 @@ func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.C
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *AiRouterBiz) handleApiTask(channels chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
|
func (r *AiRouterBiz) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
||||||
var (
|
var (
|
||||||
request l_request.Request
|
request l_request.Request
|
||||||
auth = c.Query("x-authorization", "")
|
|
||||||
requestParam map[string]interface{}
|
requestParam map[string]interface{}
|
||||||
)
|
)
|
||||||
err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam)
|
err = json.Unmarshal([]byte(requireData.Match.Parameters), &requestParam)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
request.Url = strings.ReplaceAll(task.Config, "${authorization}", auth)
|
request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth)
|
||||||
for k, v := range requestParam {
|
for k, v := range requestParam {
|
||||||
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
|
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
|
||||||
}
|
}
|
||||||
|
|
@ -403,7 +418,11 @@ func (r *AiRouterBiz) handleApiTask(channels chan entitys.Response, c *websocket
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.WriteMessage(1, res.Content)
|
requireData.Ch <- entitys.Response{
|
||||||
|
Index: "",
|
||||||
|
Content: pkg.JsonStringIgonErr(res.Text),
|
||||||
|
Type: entitys.ResponseJson,
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -436,119 +455,3 @@ func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool {
|
|
||||||
taskPrompt := make([]llms.Tool, 0)
|
|
||||||
for _, task := range tasks {
|
|
||||||
var taskConfig entitys.TaskConfig
|
|
||||||
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
taskPrompt = append(taskPrompt, llms.Tool{
|
|
||||||
Type: "function",
|
|
||||||
Function: &llms.FunctionDefinition{
|
|
||||||
Name: task.Index,
|
|
||||||
Description: task.Desc,
|
|
||||||
Parameters: taskConfig.Param,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
return taskPrompt
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message {
|
|
||||||
var (
|
|
||||||
prompt = make([]entitys.Message, 0)
|
|
||||||
)
|
|
||||||
prompt = append(prompt, entitys.Message{
|
|
||||||
Role: "system",
|
|
||||||
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
|
|
||||||
}, entitys.Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
|
|
||||||
}, entitys.Message{
|
|
||||||
Role: "user",
|
|
||||||
Content: reqInput,
|
|
||||||
})
|
|
||||||
return prompt
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterBiz) getPromptOllama(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []api.Message {
|
|
||||||
var (
|
|
||||||
prompt = make([]api.Message, 0)
|
|
||||||
)
|
|
||||||
prompt = append(prompt, api.Message{
|
|
||||||
Role: "system",
|
|
||||||
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
|
|
||||||
}, api.Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
|
|
||||||
}, api.Message{
|
|
||||||
Role: "user",
|
|
||||||
Content: reqInput,
|
|
||||||
})
|
|
||||||
return prompt
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
|
||||||
var (
|
|
||||||
prompt = make([]llms.MessageContent, 0)
|
|
||||||
)
|
|
||||||
prompt = append(prompt, llms.MessageContent{
|
|
||||||
Role: llms.ChatMessageTypeSystem,
|
|
||||||
Parts: []llms.ContentPart{
|
|
||||||
llms.TextPart(r.buildSystemPrompt(sysInfo.SysPrompt)),
|
|
||||||
},
|
|
||||||
}, llms.MessageContent{
|
|
||||||
Role: llms.ChatMessageTypeTool,
|
|
||||||
Parts: []llms.ContentPart{
|
|
||||||
llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))),
|
|
||||||
},
|
|
||||||
}, llms.MessageContent{
|
|
||||||
Role: llms.ChatMessageTypeTool,
|
|
||||||
Parts: []llms.ContentPart{
|
|
||||||
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
|
|
||||||
},
|
|
||||||
}, llms.MessageContent{
|
|
||||||
Role: llms.ChatMessageTypeHuman,
|
|
||||||
Parts: []llms.ContentPart{
|
|
||||||
llms.TextPart(reqInput),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return prompt
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildSystemPrompt 构建系统提示词
|
|
||||||
func (r *AiRouterBiz) buildSystemPrompt(prompt string) string {
|
|
||||||
if len(prompt) == 0 {
|
|
||||||
prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON:{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式,禁用markdown格式返回\n3.只返回json字符串,不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"
|
|
||||||
}
|
|
||||||
|
|
||||||
return prompt
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterBiz) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
|
|
||||||
for _, item := range his {
|
|
||||||
if len(chatHis.SessionId) == 0 {
|
|
||||||
chatHis.SessionId = item.SessionID
|
|
||||||
}
|
|
||||||
chatHis.Messages = append(chatHis.Messages, entitys.HisMessage{
|
|
||||||
Role: item.Role,
|
|
||||||
Content: item.Content,
|
|
||||||
Timestamp: item.CreateAt.Format("2006-01-02 15:04:05"),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
chatHis.Context = entitys.HisContext{
|
|
||||||
UserLanguage: "zh-CN",
|
|
||||||
SystemMode: "technical_support",
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleKnowledgeQA 处理知识问答意图
|
|
||||||
func (r *AiRouterBiz) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
|
|
||||||
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
package entitys
|
package entitys
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"ai_scheduler/internal/data/model"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
|
|
@ -73,7 +75,7 @@ type Tool interface {
|
||||||
Name() string
|
Name() string
|
||||||
Description() string
|
Description() string
|
||||||
Definition() ToolDefinition
|
Definition() ToolDefinition
|
||||||
Execute(channel chan Response, c *websocket.Conn, args json.RawMessage, matchJson *Match) error
|
Execute(ctx context.Context, requireData *RequireData) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type ConfigDataHttp struct {
|
type ConfigDataHttp struct {
|
||||||
|
|
@ -118,6 +120,7 @@ type Match struct {
|
||||||
Reasoning string `json:"reasoning"`
|
Reasoning string `json:"reasoning"`
|
||||||
History []byte `json:"history"`
|
History []byte `json:"history"`
|
||||||
UserInput string `json:"user_input"`
|
UserInput string `json:"user_input"`
|
||||||
|
Auth string `json:"auth"`
|
||||||
}
|
}
|
||||||
type ChatHis struct {
|
type ChatHis struct {
|
||||||
SessionId string `json:"session_id"`
|
SessionId string `json:"session_id"`
|
||||||
|
|
@ -135,8 +138,27 @@ type HisContext struct {
|
||||||
SystemMode string `json:"system_mode"`
|
SystemMode string `json:"system_mode"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RequireData struct {
|
||||||
|
Session string
|
||||||
|
Key string
|
||||||
|
Sys model.AiSy
|
||||||
|
Histories []model.AiChatHi
|
||||||
|
SessionInfo model.AiSession
|
||||||
|
Tasks []model.AiTask
|
||||||
|
Match *Match
|
||||||
|
UserInput string
|
||||||
|
Auth string
|
||||||
|
Ch chan Response
|
||||||
|
KnowledgeConf KnowledgeBaseRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
type KnowledgeBaseRequest struct {
|
||||||
|
Session string // 知识库会话id
|
||||||
|
ApiKey string // 知识库apiKey
|
||||||
|
Query string // 用户输入
|
||||||
|
}
|
||||||
|
|
||||||
// RouterService 路由服务接口
|
// RouterService 路由服务接口
|
||||||
type RouterService interface {
|
type RouterService interface {
|
||||||
Route(ctx context.Context, req *ChatRequest) (*ChatResponse, error)
|
|
||||||
RouteWithSocket(c *websocket.Conn, req *ChatSockRequest) error
|
RouteWithSocket(c *websocket.Conn, req *ChatSockRequest) error
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package pkg
|
package pkg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"ai_scheduler/internal/pkg/utils_langchain"
|
||||||
"ai_scheduler/internal/pkg/utils_ollama"
|
"ai_scheduler/internal/pkg/utils_ollama"
|
||||||
|
|
||||||
"github.com/google/wire"
|
"github.com/google/wire"
|
||||||
|
|
@ -9,7 +10,7 @@ import (
|
||||||
var ProviderSetClient = wire.NewSet(
|
var ProviderSetClient = wire.NewSet(
|
||||||
NewRdb,
|
NewRdb,
|
||||||
NewGormDb,
|
NewGormDb,
|
||||||
utils_ollama.NewUtilOllama,
|
utils_langchain.NewUtilLangChain,
|
||||||
utils_ollama.NewClient,
|
utils_ollama.NewClient,
|
||||||
NewSafeChannelPool,
|
NewSafeChannelPool,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
package utils_ollama
|
package utils_langchain
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
|
|
@ -13,7 +13,7 @@ import (
|
||||||
"github.com/tmc/langchaingo/llms/ollama"
|
"github.com/tmc/langchaingo/llms/ollama"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UtilOllama struct {
|
type UtilLangChain struct {
|
||||||
LlmClientPool *sync.Pool
|
LlmClientPool *sync.Pool
|
||||||
poolSize int // 记录池大小,用于调试
|
poolSize int // 记录池大小,用于调试
|
||||||
model string
|
model string
|
||||||
|
|
@ -26,7 +26,7 @@ type LlmObj struct {
|
||||||
Llm *ollama.LLM
|
Llm *ollama.LLM
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
|
func NewUtilLangChain(c *config.Config, logger log.AllLogger) *UtilLangChain {
|
||||||
poolSize := c.Sys.LlmPoolLen
|
poolSize := c.Sys.LlmPoolLen
|
||||||
if poolSize <= 0 {
|
if poolSize <= 0 {
|
||||||
poolSize = 10 // 默认值
|
poolSize = 10 // 默认值
|
||||||
|
|
@ -60,7 +60,7 @@ func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
|
||||||
pool.Put(pool.New())
|
pool.Put(pool.New())
|
||||||
}
|
}
|
||||||
|
|
||||||
return &UtilOllama{
|
return &UtilLangChain{
|
||||||
LlmClientPool: pool,
|
LlmClientPool: pool,
|
||||||
poolSize: poolSize,
|
poolSize: poolSize,
|
||||||
model: c.Ollama.Model,
|
model: c.Ollama.Model,
|
||||||
|
|
@ -69,7 +69,7 @@ func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *UtilOllama) NewClient() *ollama.LLM {
|
func (o *UtilLangChain) NewClient() *ollama.LLM {
|
||||||
llm, _ := ollama.New(
|
llm, _ := ollama.New(
|
||||||
ollama.WithModel(o.c.Ollama.Model),
|
ollama.WithModel(o.c.Ollama.Model),
|
||||||
ollama.WithHTTPClient(&http.Client{
|
ollama.WithHTTPClient(&http.Client{
|
||||||
|
|
@ -91,13 +91,13 @@ func (o *UtilOllama) NewClient() *ollama.LLM {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get 返回一个可用的 LLM 客户端
|
// Get 返回一个可用的 LLM 客户端
|
||||||
func (o *UtilOllama) Get() *LlmObj {
|
func (o *UtilLangChain) Get() *LlmObj {
|
||||||
client := o.LlmClientPool.Get().(*LlmObj)
|
client := o.LlmClientPool.Get().(*LlmObj)
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
|
|
||||||
// Put 归还客户端(可选:检查是否仍可用)
|
// Put 归还客户端(可选:检查是否仍可用)
|
||||||
func (o *UtilOllama) Put(llm *LlmObj) {
|
func (o *UtilLangChain) Put(llm *LlmObj) {
|
||||||
if llm == nil {
|
if llm == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -105,7 +105,7 @@ func (o *UtilOllama) Put(llm *LlmObj) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stats 返回池的统计信息(用于监控)
|
// Stats 返回池的统计信息(用于监控)
|
||||||
func (o *UtilOllama) Stats() (current, max int) {
|
func (o *UtilLangChain) Stats() (current, max int) {
|
||||||
return o.poolSize, o.poolSize
|
return o.poolSize, o.poolSize
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -5,13 +5,11 @@ import (
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"ai_scheduler/internal/pkg/l_request"
|
"ai_scheduler/internal/pkg/l_request"
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2/log"
|
|
||||||
"github.com/gofiber/websocket/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// 知识库工具
|
// 知识库工具
|
||||||
|
|
@ -60,22 +58,10 @@ func (k *KnowledgeBaseTool) Definition() entitys.ToolDefinition {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute 执行知识库查询
|
// Execute 执行知识库查询
|
||||||
func (k *KnowledgeBaseTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error {
|
func (k *KnowledgeBaseTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
|
||||||
|
|
||||||
var params KnowledgeBaseRequest
|
return k.chat(requireData)
|
||||||
if err := json.Unmarshal(args, ¶ms); err != nil {
|
|
||||||
return fmt.Errorf("unmarshal args failed: %w", err)
|
|
||||||
}
|
|
||||||
log.Info("开始执行知识库 KnowledgeBaseTool Execute, params: %v", params)
|
|
||||||
|
|
||||||
return k.chat(channel, c, params)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
type KnowledgeBaseRequest struct {
|
|
||||||
Session string // 知识库会话id
|
|
||||||
ApiKey string // 知识库apiKey
|
|
||||||
Query string // 用户输入
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Message 表示解析后的 SSE 消息
|
// Message 表示解析后的 SSE 消息
|
||||||
|
|
@ -110,20 +96,20 @@ func (this *KnowledgeBaseTool) msgContentParse(input string, channel chan entity
|
||||||
}
|
}
|
||||||
|
|
||||||
// 请求知识库聊天
|
// 请求知识库聊天
|
||||||
func (this *KnowledgeBaseTool) chat(channel chan entitys.Response, c *websocket.Conn, param KnowledgeBaseRequest) (err error) {
|
func (this *KnowledgeBaseTool) chat(requireData *entitys.RequireData) (err error) {
|
||||||
|
|
||||||
req := l_request.Request{
|
req := l_request.Request{
|
||||||
Method: "post",
|
Method: "post",
|
||||||
Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + param.Session,
|
Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + requireData.KnowledgeConf.Session,
|
||||||
Params: nil,
|
Params: nil,
|
||||||
Headers: map[string]string{
|
Headers: map[string]string{
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"X-API-Key": param.ApiKey,
|
"X-API-Key": requireData.KnowledgeConf.ApiKey,
|
||||||
},
|
},
|
||||||
Cookies: nil,
|
Cookies: nil,
|
||||||
Data: nil,
|
Data: nil,
|
||||||
Json: map[string]interface{}{
|
Json: map[string]interface{}{
|
||||||
"query": param.Query,
|
"query": requireData.KnowledgeConf.Query,
|
||||||
},
|
},
|
||||||
Files: nil,
|
Files: nil,
|
||||||
Raw: "",
|
Raw: "",
|
||||||
|
|
@ -137,7 +123,7 @@ func (this *KnowledgeBaseTool) chat(channel chan entitys.Response, c *websocket.
|
||||||
}
|
}
|
||||||
defer rsp.Body.Close()
|
defer rsp.Body.Close()
|
||||||
|
|
||||||
err = this.connectAndReadSSE(rsp, channel)
|
err = this.connectAndReadSSE(rsp, requireData.Ch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,9 @@ import (
|
||||||
"ai_scheduler/internal/data/constants"
|
"ai_scheduler/internal/data/constants"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"ai_scheduler/internal/pkg/utils_ollama"
|
"ai_scheduler/internal/pkg/utils_ollama"
|
||||||
|
"context"
|
||||||
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/gofiber/websocket/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Manager 工具管理器
|
// Manager 工具管理器
|
||||||
|
|
@ -100,13 +98,13 @@ func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefi
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteTool 执行工具
|
// ExecuteTool 执行工具
|
||||||
func (m *Manager) ExecuteTool(channel chan entitys.Response, c *websocket.Conn, name string, args json.RawMessage, matchJson *entitys.Match) error {
|
func (m *Manager) ExecuteTool(ctx context.Context, name string, requireData *entitys.RequireData) error {
|
||||||
tool, exists := m.GetTool(name)
|
tool, exists := m.GetTool(name)
|
||||||
if !exists {
|
if !exists {
|
||||||
return fmt.Errorf("tool not found: %s", name)
|
return fmt.Errorf("tool not found: %s", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
return tool.Execute(channel, c, args, matchJson)
|
return tool.Execute(ctx, requireData)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteToolCalls 执行多个工具调用
|
// ExecuteToolCalls 执行多个工具调用
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,13 @@ package tools
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
|
"ai_scheduler/internal/pkg"
|
||||||
"ai_scheduler/internal/pkg/utils_ollama"
|
"ai_scheduler/internal/pkg/utils_ollama"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
"github.com/gofiber/websocket/v2"
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -81,34 +81,26 @@ type ZltxOrderDetailData struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute 执行直连天下订单详情查询
|
// Execute 执行直连天下订单详情查询
|
||||||
func (w *ZltxOrderDetailTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error {
|
func (w *ZltxOrderDetailTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
|
||||||
var req ZltxOrderDetailRequest
|
var req ZltxOrderDetailRequest
|
||||||
if err := json.Unmarshal(args, &req); err != nil {
|
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
|
||||||
return fmt.Errorf("invalid zltxOrderDetail request: %w", err)
|
return fmt.Errorf("invalid zltxOrderDetail request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.OrderNumber == "" {
|
if req.OrderNumber == "" {
|
||||||
return fmt.Errorf("number is required")
|
return fmt.Errorf("number is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 这里可以集成真实的直连天下订单详情API
|
// 这里可以集成真实的直连天下订单详情API
|
||||||
return w.getZltxOrderDetail(channel, c, req.OrderNumber, matchJson)
|
return w.getZltxOrderDetail(requireData, req.OrderNumber)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据
|
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据
|
||||||
func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *websocket.Conn, number string, matchJson *entitys.Match) (err error) {
|
func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireData, number string) (err error) {
|
||||||
//查询订单详情
|
//查询订单详情
|
||||||
var auth string
|
|
||||||
if c != nil {
|
|
||||||
auth = c.Query("x-authorization", "")
|
|
||||||
}
|
|
||||||
if len(auth) == 0 {
|
|
||||||
auth = w.config.APIKey
|
|
||||||
}
|
|
||||||
req := l_request.Request{
|
req := l_request.Request{
|
||||||
Url: fmt.Sprintf("%szltx_api/admin/direct/ai/%s", w.config.BaseURL, number),
|
Url: fmt.Sprintf("%szltx_api/admin/direct/ai/%s", w.config.BaseURL, number),
|
||||||
Headers: map[string]string{
|
Headers: map[string]string{
|
||||||
"Authorization": fmt.Sprintf("Bearer %s", auth),
|
"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
|
||||||
},
|
},
|
||||||
Method: "GET",
|
Method: "GET",
|
||||||
}
|
}
|
||||||
|
|
@ -129,13 +121,13 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we
|
||||||
if err = json.Unmarshal(res.Content, &resData); err != nil {
|
if err = json.Unmarshal(res.Content, &resData); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ch <- entitys.Response{
|
requireData.Ch <- entitys.Response{
|
||||||
Index: w.Name(),
|
Index: w.Name(),
|
||||||
Content: res.Text,
|
Content: res.Text,
|
||||||
Type: entitys.ResponseJson,
|
Type: entitys.ResponseJson,
|
||||||
}
|
}
|
||||||
if resData.Data.Direct != nil && resData.Data.Direct["needAi"].(bool) {
|
if resData.Data.Direct != nil && resData.Data.Direct["needAi"].(bool) {
|
||||||
ch <- entitys.Response{
|
requireData.Ch <- entitys.Response{
|
||||||
Index: w.Name(),
|
Index: w.Name(),
|
||||||
Content: "正在分析订单日志",
|
Content: "正在分析订单日志",
|
||||||
Type: entitys.ResponseLoading,
|
Type: entitys.ResponseLoading,
|
||||||
|
|
@ -144,7 +136,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we
|
||||||
req = l_request.Request{
|
req = l_request.Request{
|
||||||
Url: fmt.Sprintf("%szltx_api/admin/direct/log/%s/%s", w.config.BaseURL, resData.Data.Direct["orderOrderNumber"].(string), resData.Data.Direct["serialNumber"].(string)),
|
Url: fmt.Sprintf("%szltx_api/admin/direct/log/%s/%s", w.config.BaseURL, resData.Data.Direct["orderOrderNumber"].(string), resData.Data.Direct["serialNumber"].(string)),
|
||||||
Headers: map[string]string{
|
Headers: map[string]string{
|
||||||
"Authorization": fmt.Sprintf("Bearer %s", auth),
|
"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
|
||||||
},
|
},
|
||||||
Method: "GET",
|
Method: "GET",
|
||||||
}
|
}
|
||||||
|
|
@ -164,14 +156,14 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we
|
||||||
return fmt.Errorf("订单日志解析失败:%s", err)
|
return fmt.Errorf("订单日志解析失败:%s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = w.llm.ChatStream(context.TODO(), ch, []api.Message{
|
err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{
|
||||||
{
|
{
|
||||||
Role: "system",
|
Role: "system",
|
||||||
Content: "你是一个订单日志助手。用户可能会提供订单日志,你需要分析订单日志,提取出订单失败的原因。",
|
Content: "你是一个订单日志助手。用户可能会提供订单日志,你需要分析订单日志,提取出订单失败的原因。",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: fmt.Sprintf("聊天记录:%s", string(matchJson.History)),
|
Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
|
|
@ -179,7 +171,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: matchJson.UserInput,
|
Content: requireData.UserInput,
|
||||||
},
|
},
|
||||||
}, w.Name())
|
}, w.Name())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -187,15 +179,11 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if resData.Data.Direct == nil {
|
if resData.Data.Direct == nil {
|
||||||
ch <- entitys.Response{
|
requireData.Ch <- entitys.Response{
|
||||||
Index: w.Name(),
|
Index: w.Name(),
|
||||||
Content: "该订单无充值流水记录,不需要进行订单错误分析,建议检查订单收单模式以及扣款方式等原因😘",
|
Content: "该订单无充值流水记录,不需要进行订单错误分析,建议检查订单收单模式以及扣款方式等原因😘",
|
||||||
Type: entitys.ResponseText,
|
Type: entitys.ResponseText,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// else {
|
|
||||||
|
|
||||||
//}
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,11 @@ package tools
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
"github.com/gofiber/websocket/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ZltxOrderLogTool struct {
|
type ZltxOrderLogTool struct {
|
||||||
|
|
@ -67,31 +67,25 @@ type ZltxOrderDirectLogData struct {
|
||||||
Data map[string]interface{} `json:"data"`
|
Data map[string]interface{} `json:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ZltxOrderLogTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error {
|
func (t *ZltxOrderLogTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
|
||||||
var req ZltxOrderLogRequest
|
var req ZltxOrderLogRequest
|
||||||
if err := json.Unmarshal(args, &req); err != nil {
|
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
|
||||||
return fmt.Errorf("invalid zltxOrderLog request: %w", err)
|
return fmt.Errorf("invalid zltxOrderLog request: %w", err)
|
||||||
}
|
}
|
||||||
if req.OrderNumber == "" || req.SerialNumber == "" {
|
if req.OrderNumber == "" || req.SerialNumber == "" {
|
||||||
return fmt.Errorf("orderNumber and serialNumber is required")
|
return fmt.Errorf("orderNumber and serialNumber is required")
|
||||||
}
|
}
|
||||||
return t.getZltxOrderLog(channel, c, req.OrderNumber, req.SerialNumber)
|
return t.getZltxOrderLog(req.OrderNumber, req.SerialNumber, requireData)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ZltxOrderLogTool) getZltxOrderLog(channel chan entitys.Response, c *websocket.Conn, orderNumber, serialNumber string) (err error) {
|
func (t *ZltxOrderLogTool) getZltxOrderLog(orderNumber, serialNumber string, requireData *entitys.RequireData) (err error) {
|
||||||
//查询订单详情
|
//查询订单详情
|
||||||
var auth string
|
|
||||||
if c != nil {
|
|
||||||
auth = c.Query("x-authorization", "")
|
|
||||||
}
|
|
||||||
if len(auth) == 0 {
|
|
||||||
auth = t.config.APIKey
|
|
||||||
}
|
|
||||||
url := fmt.Sprintf("%s%s/%s", t.config.BaseURL, orderNumber, serialNumber)
|
url := fmt.Sprintf("%s%s/%s", t.config.BaseURL, orderNumber, serialNumber)
|
||||||
req := l_request.Request{
|
req := l_request.Request{
|
||||||
Url: url,
|
Url: url,
|
||||||
Headers: map[string]string{
|
Headers: map[string]string{
|
||||||
"Authorization": fmt.Sprintf("Bearer %s", auth),
|
"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
|
||||||
},
|
},
|
||||||
Method: "GET",
|
Method: "GET",
|
||||||
}
|
}
|
||||||
|
|
@ -106,7 +100,7 @@ func (t *ZltxOrderLogTool) getZltxOrderLog(channel chan entitys.Response, c *web
|
||||||
if err = json.Unmarshal(res.Content, &resData); err != nil {
|
if err = json.Unmarshal(res.Content, &resData); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel <- entitys.Response{
|
requireData.Ch <- entitys.Response{
|
||||||
Index: t.Name(),
|
Index: t.Name(),
|
||||||
Content: res.Text,
|
Content: res.Text,
|
||||||
Type: entitys.ResponseJson,
|
Type: entitys.ResponseJson,
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,13 @@ package tools
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
"github.com/gofiber/websocket/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ZltxProductTool struct {
|
type ZltxProductTool struct {
|
||||||
|
|
@ -53,12 +53,12 @@ type ZltxProductRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (z ZltxProductTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error {
|
func (z ZltxProductTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
|
||||||
var req ZltxProductRequest
|
var req ZltxProductRequest
|
||||||
if err := json.Unmarshal(args, &req); err != nil {
|
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
|
||||||
return fmt.Errorf("invalid zltxProduct request: %w", err)
|
return fmt.Errorf("invalid zltxProduct request: %w", err)
|
||||||
}
|
}
|
||||||
return z.getZltxProduct(channel, c, req.Id, req.Name)
|
return z.getZltxProduct(&req, requireData)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ZltxProductResponse struct {
|
type ZltxProductResponse struct {
|
||||||
|
|
@ -133,22 +133,16 @@ type ZltxProductData struct {
|
||||||
PlatformProductList interface{} `json:"platform_product_list"`
|
PlatformProductList interface{} `json:"platform_product_list"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (z ZltxProductTool) getZltxProduct(channel chan entitys.Response, c *websocket.Conn, id string, name string) error {
|
func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *entitys.RequireData) error {
|
||||||
var auth string
|
|
||||||
if c != nil {
|
|
||||||
auth = c.Query("x-authorization", "")
|
|
||||||
}
|
|
||||||
if len(auth) == 0 {
|
|
||||||
auth = z.config.APIKey
|
|
||||||
}
|
|
||||||
var Url string
|
var Url string
|
||||||
var params map[string]string
|
var params map[string]string
|
||||||
if id != "" {
|
if body.Id != "" {
|
||||||
Url = fmt.Sprintf("%s/%s", z.config.BaseURL, id)
|
Url = fmt.Sprintf("%s/%s", z.config.BaseURL, body.Id)
|
||||||
} else {
|
} else {
|
||||||
Url = fmt.Sprintf("%s?keyword=%s&limit=10&page=1", z.config.BaseURL, name)
|
Url = fmt.Sprintf("%s?keyword=%s&limit=10&page=1", z.config.BaseURL, body.Name)
|
||||||
params = map[string]string{
|
params = map[string]string{
|
||||||
"keyword": name,
|
"keyword": body.Name,
|
||||||
"limit": "10",
|
"limit": "10",
|
||||||
"page": "1",
|
"page": "1",
|
||||||
}
|
}
|
||||||
|
|
@ -159,7 +153,7 @@ func (z ZltxProductTool) getZltxProduct(channel chan entitys.Response, c *websoc
|
||||||
//根据商品ID或名称走不同的接口查询
|
//根据商品ID或名称走不同的接口查询
|
||||||
Url: Url,
|
Url: Url,
|
||||||
Headers: map[string]string{
|
Headers: map[string]string{
|
||||||
"Authorization": fmt.Sprintf("Bearer %s", auth),
|
"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
|
||||||
},
|
},
|
||||||
Params: params,
|
Params: params,
|
||||||
Method: "GET",
|
Method: "GET",
|
||||||
|
|
@ -191,7 +185,7 @@ func (z ZltxProductTool) getZltxProduct(channel chan entitys.Response, c *websoc
|
||||||
for i := range resp.Data.List {
|
for i := range resp.Data.List {
|
||||||
// 调用 平台商品列表
|
// 调用 平台商品列表
|
||||||
if resp.Data.List[i].AuthProductIds != "" {
|
if resp.Data.List[i].AuthProductIds != "" {
|
||||||
platformProductList := z.ExecutePlatformProductList(auth, resp.Data.List[i].AuthProductIds, resp.Data.List[i].OfficialProductID)
|
platformProductList := z.ExecutePlatformProductList(requireData.Auth, resp.Data.List[i].AuthProductIds, resp.Data.List[i].OfficialProductID)
|
||||||
resp.Data.List[i].PlatformProductList = platformProductList
|
resp.Data.List[i].PlatformProductList = platformProductList
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -200,7 +194,7 @@ func (z ZltxProductTool) getZltxProduct(channel chan entitys.Response, c *websoc
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
channel <- entitys.Response{
|
requireData.Ch <- entitys.Response{
|
||||||
Index: z.Name(),
|
Index: z.Name(),
|
||||||
Content: string(marshal),
|
Content: string(marshal),
|
||||||
Type: entitys.ResponseJson,
|
Type: entitys.ResponseJson,
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,12 @@ package tools
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
"github.com/gofiber/websocket/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ZltxOrderStatisticsTool struct {
|
type ZltxOrderStatisticsTool struct {
|
||||||
|
|
@ -47,15 +47,15 @@ type ZltxOrderStatisticsRequest struct {
|
||||||
Number string `json:"number"`
|
Number string `json:"number"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (z ZltxOrderStatisticsTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error {
|
func (z ZltxOrderStatisticsTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
|
||||||
var req ZltxOrderStatisticsRequest
|
var req ZltxOrderStatisticsRequest
|
||||||
if err := json.Unmarshal(args, &req); err != nil {
|
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if req.Number == "" {
|
if req.Number == "" {
|
||||||
return fmt.Errorf("number is required")
|
return fmt.Errorf("number is required")
|
||||||
}
|
}
|
||||||
return z.getZltxOrderStatistics(channel, c, req.Number)
|
return z.getZltxOrderStatistics(req.Number, requireData)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ZltxOrderStatisticsResponse struct {
|
type ZltxOrderStatisticsResponse struct {
|
||||||
|
|
@ -75,20 +75,14 @@ type ZltxOrderStatisticsData struct {
|
||||||
Total int `json:"total"`
|
Total int `json:"total"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(channel chan entitys.Response, c *websocket.Conn, number string) error {
|
func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(number string, requireData *entitys.RequireData) error {
|
||||||
//查询订单详情
|
//查询订单详情
|
||||||
var auth string
|
|
||||||
if c != nil {
|
|
||||||
auth = c.Query("x-authorization", "")
|
|
||||||
}
|
|
||||||
if len(auth) == 0 {
|
|
||||||
auth = z.config.APIKey
|
|
||||||
}
|
|
||||||
url := fmt.Sprintf("%s%s", z.config.BaseURL, number)
|
url := fmt.Sprintf("%s%s", z.config.BaseURL, number)
|
||||||
req := l_request.Request{
|
req := l_request.Request{
|
||||||
Url: url,
|
Url: url,
|
||||||
Headers: map[string]string{
|
Headers: map[string]string{
|
||||||
"Authorization": fmt.Sprintf("Bearer %s", auth),
|
"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
|
||||||
},
|
},
|
||||||
Method: "GET",
|
Method: "GET",
|
||||||
}
|
}
|
||||||
|
|
@ -114,7 +108,7 @@ func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(channel chan entitys.Res
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
channel <- entitys.Response{
|
requireData.Ch <- entitys.Response{
|
||||||
Index: z.Name(),
|
Index: z.Name(),
|
||||||
Content: string(jsonByte),
|
Content: string(jsonByte),
|
||||||
Type: entitys.ResponseJson,
|
Type: entitys.ResponseJson,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue