Compare commits

..

3 Commits

Author SHA1 Message Date
renzhiyuan 115aab5b83 结构修改 2025-09-30 14:38:41 +08:00
renzhiyuan 935712308e 结构修改 2025-09-30 11:03:32 +08:00
renzhiyuan ed7591975b 结构修改 2025-09-22 17:04:47 +08:00
47 changed files with 480 additions and 2716 deletions

View File

@ -1,21 +0,0 @@
# 创建最终镜像用于运行编译后的Go程序
FROM alpine
RUN echo 'http://mirrors.ustc.edu.cn/alpine/v3.5/main' > /etc/apk/repositories \
&& echo 'http://mirrors.ustc.edu.cn/alpine/v3.5/community' >>/etc/apk/repositories \
&& apk update && apk add tzdata \
&& ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
&& echo "Asia/Shanghai" > /etc/timezone
# 设置工作目录
WORKDIR /app
# 将编译好的二进制文件从构建阶段复制到运行阶段
COPY ./ /app
ENV TZ=Asia/Shanghai
# 设置容器启动时运行的命令
CMD ["./bin/server"]

View File

@ -21,10 +21,3 @@ endif
# generate wire
wire:
cd ./cmd/server && wire
.PHONY: build
# build
build:
# make config;
make wire;
mkdir -p bin/ && go build -ldflags "-X main.Version=$(VERSION)" -o ./bin/ ./...

View File

@ -10,21 +10,18 @@ import (
)
func main() {
configPath := flag.String("config", "./config/config.yaml", "Path to configuration file")
flag.Parse()
bc, err := config.LoadConfig(*configPath)
if err != nil {
log.Fatalf("加载配置失败: %v", err)
}
app, cleanup, err := InitializeApp(bc, log.DefaultLogger())
if err != nil {
log.Fatalf("项目初始化失败: %v", err)
}
defer func() {
cleanup()
}()
defer cleanup()
log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port)))
}

35
config.yaml Normal file
View File

@ -0,0 +1,35 @@
server:
port: 8090
host: "0.0.0.0"
ollama:
base_url: "http://localhost:11434"
model: "qwen3:8b"
timeout: 30s
# 模型参数
modelParam:
temperature: 0.7
max_tokens: 2000
tools:
weather:
enabled: true
calculator:
enabled: true
zltxOrderDetail: # 直连天下订单详情
enabled: true
base_url: https://gateway.dev.cdlsxd.cn
biz_system: "zltx"
zltxOrderLog: # 直连天下订单日志
enabled: true
base_url: https://gateway.dev.cdlsxd.cn
biz_system: "zltx"
knowledge: # 知识库
enabled: true
base_url: http://117.175.169.61:8080
api_key: sk-EfnUANKMj3DUOiEPJZ5xS8SGMsbO6be_qYAg9uZ8T3zyoFM-
logging:
level: "info"
format: "json"

View File

@ -6,18 +6,13 @@ server:
ollama:
base_url: "http://127.0.0.1:11434"
model: "qwen3-coder:480b-cloud"
generate_model: "qwen3-coder:480b-cloud"
timeout: "120s"
level: "info"
format: "json"
sys:
session_len: 6
channel_pool_len: 100
channel_pool_size: 32
llm_pool_len: 5
session_len: 3
redis:
host: 47.97.27.195:6379
type: node
@ -35,18 +30,5 @@ db:
tools:
zltxOrderDetail:
enabled: true
base_url: "https://revcl.1688sup.com/api/admin/direct/ai/%s"
add_url: "https://revcl.1688sup.com/api/admin/direct/log/%s/%s"
api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU4MDkxOTU4LCJuYmYiOjE3NTgwOTAxNTgsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.Bjsx9f8yfcrV9EWxb0n6POwnXVOq9XPRD78JFZnnf1_VAVMN78W4W570SZL27PWuDnkD7E4oUg6RzeZwZgl7BZrNpNr-a-QpNC5qCptqrqXeNfVStmX7pxWA8GqnzI8ybkZgbhQ58Gje7DzdJtBq_8zte_LDaYhTYXdIc5EAG0AbCzAk22nPTl47nkMeHtmisXQVLEsdibl1hW3ViFJlXwfXvUrOENItmL1_mRYkggUB0MaTu2nHJOYM6PaOVGLHx-74eepnmK2rm6konFEb6ed-Ukc6gVR-nM9yWZaYLYNGNKJLwZoCX3tRuerq74n4kzQgWmUEJeaVI1yIGSw1zw"
zltxProduct:
enabled: true
base_url: "https://revcl.1688sup.com/api/admin/oursProduct"
add_url: "https://revcl.1688sup.com/api/admin/platformProduct/getProductsByOfficialProductId"
base_url: "https://gateway.dev.cdlsxd.cn/"
api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU2MTgyNTM1LCJuYmYiOjE3NTYxODA3MzUsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.N1xv1PYbcO8_jR5adaczc16YzGsr4z101gwEZdulkRaREBJNYTOnFrvRxTFx3RJTooXsqTqroE1MR84v_1WPX6BS6kKonA-kC1Jgot6yrt5rFWhGNGb2Cpr9rKIFCCQYmiGd3AUgDazEeaQ0_sodv3E-EXg9VfE1SX8nMcck9Yjnc8NCy7RTWaBIaSeOdZcEl-JfCD0S6GSx3oErp_hk-U9FKGwf60wAuDGTY1R0BP4BYpcEqS-C2LSnsSGyURi54Cuk5xH8r1WuF0Dm5bwAj5d7Hvs77-N_sUF-C5ONqyZJRAEhYLgcmN9RX_WQZfizdQJxizlTczdpzYfy-v-1eQ"
zltxOrderStatistics:
base_url: "https://revcl.1688sup.com/api/admin/direct/ai/search/"
enabled: true
api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU2MTgyNTM1LCJuYmYiOjE3NTYxODA3MzUsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.N1xv1PYbcO8_jR5adaczc16YzGsr4z101gwEZdulkRaREBJNYTOnFrvRxTFx3RJTooXsqTqroE1MR84v_1WPX6BS6kKonA-kC1Jgot6yrt5rFWhGNGb2Cpr9rKIFCCQYmiGd3AUgDazEeaQ0_sodv3E-EXg9VfE1SX8nMcck9Yjnc8NCy7RTWaBIaSeOdZcEl-JfCD0S6GSx3oErp_hk-U9FKGwf60wAuDGTY1R0BP4BYpcEqS-C2LSnsSGyURi54Cuk5xH8r1WuF0Dm5bwAj5d7Hvs77-N_sUF-C5ONqyZJRAEhYLgcmN9RX_WQZfizdQJxizlTczdpzYfy-v-1eQ"
knowledge:
base_url: "http://117.175.169.61:10000"
enabled: true

View File

@ -1,22 +0,0 @@
export GO111MODULE=on
export GOPROXY=https://goproxy.cn,direct
export GOPATH=/root/go
export GOCACHE=/root/.cache/go-build
export CONTAINER_NAME=ai_scheduler
export CGO_ENABLED='0'
git pull origin master
go mod tidy
make build
docker build -t ${CONTAINER_NAME} .
docker stop ${CONTAINER_NAME}
docker rm -f ${CONTAINER_NAME}
docker run -itd \
--name "${CONTAINER_NAME}" \
--restart=always \
-e "OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-http://host.docker.internal:11434}" \
-p 8090:8090 \
"${CONTAINER_NAME}"
docker logs -f ${CONTAINER_NAME}

4
go.sum
View File

@ -283,8 +283,8 @@ github.com/nlpodyssey/gopickle v0.3.0 h1:BLUE5gxFLyyNOPzlXxt6GoHEMMxD0qhsE4p0CIQ
github.com/nlpodyssey/gopickle v0.3.0/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/ollama/ollama v0.12.0 h1:BRry7G2Skz7Mu+E6rz40tzBXNbLTEhheGT8umc1zvxo=
github.com/ollama/ollama v0.12.0/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms=
github.com/ollama/ollama v0.11.11 h1:mErMiUGclp47rCDbSUmBiY2L76EpT0uIYRZVBO6qg/k=
github.com/ollama/ollama v0.11.11/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms=
github.com/ollama/ollama v0.12.3 h1:dHni+/BYDig8u8r7++FLdj6ebZaG95B2ZMqVTqqqYvc=
github.com/ollama/ollama v0.12.3/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms=
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c h1:GwiUUjKefgvSNmv3NCvI/BL0kDebW6Xa+kcdpdc1mTY=

View File

@ -1,38 +0,0 @@
package llm_service
import (
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"context"
)
type LlmService interface {
IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (string, error)
}
// buildSystemPrompt 构建系统提示词
func buildSystemPrompt(prompt string) string {
if len(prompt) == 0 {
prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式禁用markdown格式返回\n3.只返回json字符串不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"
}
return prompt
}
func buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
for _, item := range his {
if len(chatHis.SessionId) == 0 {
chatHis.SessionId = item.SessionID
}
chatHis.Messages = append(chatHis.Messages, entitys.HisMessage{
Role: item.Role,
Content: item.Content,
Timestamp: item.CreateAt.Format("2006-01-02 15:04:05"),
})
}
chatHis.Context = entitys.HisContext{
UserLanguage: "zh-CN",
SystemMode: "technical_support",
}
return
}

View File

@ -1,87 +0,0 @@
package llm_service
import (
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/utils_langchain"
"context"
"encoding/json"
"github.com/tmc/langchaingo/llms"
)
type LangChainService struct {
client *utils_langchain.UtilLangChain
}
func NewLangChainGenerate(
client *utils_langchain.UtilLangChain,
) *LangChainService {
return &LangChainService{
client: client,
}
}
func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) {
prompt := r.getPrompt(sysInfo, history, userInput, tasks)
AgentClient := r.client.Get()
defer r.client.Put(AgentClient)
match, err := AgentClient.Llm.GenerateContent(
ctx, // 使用可取消的上下文
prompt,
llms.WithJSONMode(),
)
msg = match.Choices[0].Content
return
}
func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
var (
prompt = make([]llms.MessageContent, 0)
)
prompt = append(prompt, llms.MessageContent{
Role: llms.ChatMessageTypeSystem,
Parts: []llms.ContentPart{
llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)),
},
}, llms.MessageContent{
Role: llms.ChatMessageTypeTool,
Parts: []llms.ContentPart{
llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))),
},
}, llms.MessageContent{
Role: llms.ChatMessageTypeTool,
Parts: []llms.ContentPart{
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
},
}, llms.MessageContent{
Role: llms.ChatMessageTypeHuman,
Parts: []llms.ContentPart{
llms.TextPart(reqInput),
},
})
return prompt
}
func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool {
taskPrompt := make([]llms.Tool, 0)
for _, task := range tasks {
var taskConfig entitys.TaskConfig
err := json.Unmarshal([]byte(task.Config), &taskConfig)
if err != nil {
continue
}
taskPrompt = append(taskPrompt, llms.Tool{
Type: "function",
Function: &llms.FunctionDefinition{
Name: task.Index,
Description: task.Desc,
Parameters: taskConfig.Param,
},
})
}
return taskPrompt
}

View File

@ -1,132 +0,0 @@
package llm_service
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/utils_ollama"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gofiber/fiber/v2/log"
"github.com/ollama/ollama/api"
)
type OllamaService struct {
client *utils_ollama.Client
config *config.Config
}
func NewOllamaGenerate(
client *utils_ollama.Client,
config *config.Config,
) *OllamaService {
return &OllamaService{
client: client,
config: config,
}
}
func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) {
prompt, err := r.getPrompt(ctx, requireData)
if err != nil {
return
}
toolDefinitions := r.registerToolsOllama(requireData.Tasks)
match, err := r.client.ToolSelect(context.TODO(), prompt, toolDefinitions)
if err != nil {
return
}
log.Info("意图识别结果: %v", pkg.JsonStringIgonErr(match))
if len(match.Message.Content) == 0 {
if match.Message.ToolCalls != nil {
var matchFromTools = &entitys.Match{
Confidence: 1,
Index: match.Message.ToolCalls[0].Function.Name,
Parameters: pkg.JsonStringIgonErr(match.Message.ToolCalls[0].Function.Arguments),
IsMatch: true,
}
match.Message.Content = pkg.JsonStringIgonErr(matchFromTools)
} else {
err = errors.New("不太明白你想表达的意思呢,可以在仔细描述一下您所需要的内容吗,感谢感谢")
return
}
}
msg = match.Message.Content
return
}
func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) {
var (
prompt = make([]api.Message, 0)
)
prompt = append(prompt, api.Message{
Role: "system",
Content: buildSystemPrompt(requireData.Sys.SysPrompt),
}, api.Message{
Role: "assistant",
Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(buildAssistant(requireData.Histories))),
}, api.Message{
Role: "user",
Content: requireData.UserInput,
})
if len(requireData.ImgByte) > 0 {
_, err := r.RecognizeWithImg(ctx, requireData)
if err != nil {
return nil, err
}
}
return prompt, nil
}
func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
if requireData.ImgByte == nil {
return
}
requireData.Ch <- entitys.Response{
Index: "",
Content: "图片识别中。。。",
Type: entitys.ResponseLoading,
}
desc, err = r.client.Generation(ctx, &api.GenerateRequest{
Model: r.config.Ollama.GenerateModel,
Stream: new(bool),
System: "识别图片内容",
Prompt: requireData.UserInput,
Images: requireData.ImgByte,
})
return
}
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
taskPrompt := make([]api.Tool, 0)
for _, task := range tasks {
var taskConfig entitys.TaskConfigDetail
err := json.Unmarshal([]byte(task.Config), &taskConfig)
if err != nil {
continue
}
taskPrompt = append(taskPrompt, api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: task.Index,
Description: task.Desc,
Parameters: api.ToolFunctionParameters{
Type: taskConfig.Param.Type,
Required: taskConfig.Param.Required,
Properties: taskConfig.Param.Properties,
},
},
})
}
return taskPrompt
}

View File

@ -1,15 +1,5 @@
package biz
import (
"ai_scheduler/internal/biz/llm_service"
import "github.com/google/wire"
"github.com/google/wire"
)
var ProviderSetBiz = wire.NewSet(
NewAiRouterBiz,
NewSessionBiz,
NewChatHistoryBiz,
llm_service.NewLangChainGenerate,
llm_service.NewOllamaGenerate,
)
var ProviderSetBiz = wire.NewSet(NewAiRouterBiz, NewSessionBiz, NewChatHistoryBiz)

View File

@ -1,341 +1,252 @@
package biz
import (
"ai_scheduler/internal/biz/llm_service"
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/constants"
"ai_scheduler/internal/data/constant"
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/mapstructure"
"ai_scheduler/internal/pkg/utils_ollama"
"ai_scheduler/internal/tools"
"ai_scheduler/tmpl/dataTemp"
"context"
"encoding/json"
"fmt"
"strings"
"time"
"gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/websocket/v2"
"github.com/ollama/ollama/api"
"xorm.io/builder"
)
// AiRouterBiz 智能路由服务
type AiRouterBiz struct {
//aiClient entitys.AIClient
toolManager *tools.Manager
sessionImpl *impl.SessionImpl
sysImpl *impl.SysImpl
taskImpl *impl.TaskImpl
hisImpl *impl.ChatImpl
conf *config.Config
rds *pkg.Rdb
langChain *llm_service.LangChainService
Ollama *llm_service.OllamaService
ai *utils_ollama.Client
}
// NewRouterService 创建路由服务
func NewAiRouterBiz(
//aiClient entitys.AIClient,
toolManager *tools.Manager,
sessionImpl *impl.SessionImpl,
sysImpl *impl.SysImpl,
taskImpl *impl.TaskImpl,
hisImpl *impl.ChatImpl,
conf *config.Config,
langChain *llm_service.LangChainService,
Ollama *llm_service.OllamaService,
ai *utils_ollama.Client,
) *AiRouterBiz {
return &AiRouterBiz{
//aiClient: aiClient,
toolManager: toolManager,
sessionImpl: sessionImpl,
conf: conf,
sysImpl: sysImpl,
hisImpl: hisImpl,
taskImpl: taskImpl,
langChain: langChain,
Ollama: Ollama,
ai: ai,
}
}
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
//必要数据验证和获取
var requireData entitys.RequireData
err = r.dataAuth(c, &requireData)
// Route 执行智能路由
func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) {
return nil, nil
}
// Route 执行智能路由
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error {
session := c.Query("x-session", "")
if len(session) == 0 {
return errors.SessionNotFound
}
auth := c.Query("x-authorization", "")
if len(auth) == 0 {
return errors.AuthNotFound
}
key := c.Query("x-app-key", "")
if len(key) == 0 {
return errors.KeyNotFound
}
sysInfo, err := r.getSysInfo(key)
if err != nil {
return
return errors.SysNotFound
}
//初始化通道/上下文
requireData.Ch = make(chan entitys.Response)
ctx, cancel := context.WithCancel(context.Background())
// 启动独立的消息处理协程
done := r.startMessageHandler(ctx, c, &requireData)
history, err := r.getSessionChatHis(session)
if err != nil {
return errors.SystemError
}
task, err := r.getTasks(sysInfo.SysID)
if err != nil {
return errors.SystemError
}
//toolDefinitions := r.registerTools(task)
//prompt := r.getPrompt(sysInfo, history, req.Text)
//意图预测
prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
toolDefinitions := r.registerTools(task)
match, err := r.ai.ToolSelect(context.TODO(), prompt, toolDefinitions)
if err != nil {
return errors.SystemError
}
log.Info(match)
//var matchJson entitys.Match
//err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson)
//if err != nil {
// return errors.SystemError
//}
//return r.handleMatch(c, &matchJson, task)
c.WriteMessage(1, []byte(match.Message.Content))
// 构建消息
//messages := []entitys.Message{
// {
// Role: "user",
// Content: req.UserInput,
// },
//}
//
//// 第1次调用AI获取用户意图
//intentResponse, err := r.aiClient.Chat(ctx, messages, nil)
//if err != nil {
// return nil, fmt.Errorf("AI响应失败: %w", err)
//}
//
//// 从AI响应中提取意图
//intent := r.extractIntent(intentResponse)
//if intent == "" {
// return nil, fmt.Errorf("未识别到用户意图")
//}
//
//switch intent {
//case "order_diagnosis":
// // 订单诊断意图
// return r.handleOrderDiagnosis(ctx, req, messages)
//case "knowledge_qa":
// // 知识问答意图
// return r.handleKnowledgeQA(ctx, req, messages)
//default:
// // 未知意图
// return nil, fmt.Errorf("意图识别失败,请明确您的需求呢,我可以为您")
//}
//
//// 获取工具定义
//toolDefinitions := r.toolManager.GetToolDefinitions(constants.Caller(req.Caller))
//
//// 第2次调用AI获取是否需要使用工具
//response, err := r.aiClient.Chat(ctx, messages, toolDefinitions)
//if err != nil {
// return nil, fmt.Errorf("failed to chat with AI: %w", err)
//}
//
//// 如果没有工具调用,直接返回
//if len(response.ToolCalls) == 0 {
// return response, nil
//}
//
//// 执行工具调用
//toolResults, err := r.toolManager.ExecuteToolCalls(ctx, response.ToolCalls)
//if err != nil {
// return nil, fmt.Errorf("failed to execute tools: %w", err)
//}
//
//// 构建包含工具结果的消息
//messages = append(messages, entitys.Message{
// Role: "assistant",
// Content: response.Message,
//})
//
//// 添加工具调用结果
//for _, toolResult := range toolResults {
// toolResultStr, _ := json.Marshal(toolResult.Result)
// messages = append(messages, entitys.Message{
// Role: "tool",
// Content: fmt.Sprintf("Tool %s result: %s", toolResult.Function.Name, string(toolResultStr)),
// })
//}
//
//// 第二次调用AI生成最终回复
//finalResponse, err := r.aiClient.Chat(ctx, messages, nil)
//if err != nil {
// return nil, fmt.Errorf("failed to generate final response: %w", err)
//}
//
//// 合并工具调用信息到最终响应
//finalResponse.ToolCalls = toolResults
//
//log.Printf("Router processed request: %s, used %d tools", req.UserInput, len(toolResults))
//return finalResponse, nil
return nil
}
func (r *AiRouterBiz) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) {
var resChan = make(chan []byte, 10)
defer func() {
close(requireData.Ch) //关闭主通道
<-done // 等待消息处理完成
cancel()
close(resChan)
if err != nil {
c.WriteMessage(websocket.TextMessage, []byte(err.Error()))
}
c.WriteMessage(websocket.TextMessage, []byte("EOF"))
}()
//获取图片信息
err = r.getImgData(req.Img, &requireData)
if err != nil {
log.Errorf("GetImgData error: %v", err)
return
}
//获取系统信息
err = r.getRequireData(req.Text, &requireData)
if err != nil {
log.Errorf("SQL error: %v", err)
return
}
//意图识别
err = r.recognize(ctx, &requireData)
if err != nil {
log.Errorf("LLM error: %v", err)
return
}
//向下传递
if err = r.handleMatch(ctx, &requireData); err != nil {
log.Errorf("Handle error: %v", err)
return
}
return
}
// startMessageHandler 启动独立的消息处理协程
func (r *AiRouterBiz) startMessageHandler(
ctx context.Context,
c *websocket.Conn,
requireData *entitys.RequireData,
) <-chan struct{} {
done := make(chan struct{})
var chat []string
go func() {
defer func() {
close(done)
// 保存历史记录
var his = []*model.AiChatHi{
{
SessionID: requireData.Session,
Role: "user",
Content: "", // 用户输入在外部处理
},
}
if len(chat) > 0 {
his = append(his, &model.AiChatHi{
SessionID: requireData.Session,
Role: "assistant",
Content: strings.Join(chat, ""),
})
}
for _, hi := range his {
r.hisImpl.Add(hi)
}
}()
for v := range requireData.Ch { // 自动检测通道关闭
if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
log.Errorf("Send error: %v", err)
return
}
if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson {
chat = append(chat, v.Content)
}
}
}()
return done
}
// 辅助函数:带超时的 WebSocket 发送
func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error {
sendCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
done := make(chan error, 1)
go func() {
defer func() {
if r := recover(); r != nil {
done <- fmt.Errorf("panic in MsgSend: %v", r)
}
close(done)
}()
// 如果 MsgSend 阻塞,这里会卡住
err := entitys.MsgSend(c, data)
done <- err
}()
select {
case err := <-done:
return err
case <-sendCtx.Done():
return sendCtx.Err()
}
}
func (r *AiRouterBiz) getImgData(imgUrl string, requireData *entitys.RequireData) (err error) {
if len(imgUrl) == 0 {
return
}
if err = pkg.ValidateImageURL(imgUrl); err != nil {
return err
}
req := l_request.Request{
Method: "GET",
Url: imgUrl,
}
res, err := req.Send()
if err != nil {
return
}
if _, ex := res.Headers["Content-Type"]; !ex {
return errors.ParamErr("图片格式错误:Content-Type未获取")
}
if !strings.HasPrefix(res.Headers["Content-Type"], "image/") {
return errors.ParamErr("expected image content, got %s", res.Headers["Content-Type"])
}
requireData.ImgByte = append(requireData.ImgByte, res.Content)
return
}
func (r *AiRouterBiz) recognize(ctx context.Context, requireData *entitys.RequireData) (err error) {
requireData.Ch <- entitys.Response{
Index: "",
Content: "准备意图识别",
Type: entitys.ResponseLog,
}
//意图识别
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData)
if err != nil {
return
}
requireData.Ch <- entitys.Response{
Index: "",
Content: recognizeMsg,
Type: entitys.ResponseLog,
}
requireData.Ch <- entitys.Response{
Index: "",
Content: "意图识别结束",
Type: entitys.ResponseLog,
}
var match entitys.Match
if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil {
err = errors.SysErr("数据结构错误:%v", err.Error())
return
}
requireData.Match = &match
return
}
func (r *AiRouterBiz) getRequireData(userInput string, requireData *entitys.RequireData) (err error) {
requireData.Sys, err = r.getSysInfo(requireData.Key)
if err != nil {
err = errors.SysErr("获取系统信息失败:%v", err.Error())
return
}
requireData.Histories, err = r.getSessionChatHis(requireData.Session)
if err != nil {
err = errors.SysErr("获取历史记录失败:%v", err.Error())
return
}
requireData.Tasks, err = r.getTasks(requireData.Sys.SysID)
if err != nil {
err = errors.SysErr("获取任务列表失败:%v", err.Error())
return
}
requireData.UserInput = userInput
if len(requireData.UserInput) == 0 {
err = errors.SysErr("获取用户输入失败")
return
}
if len(requireData.UserInput) == 0 {
err = errors.SysErr("获取用户输入失败")
return
}
return
}
func (r *AiRouterBiz) dataAuth(c *websocket.Conn, requireData *entitys.RequireData) (err error) {
requireData.Session = c.Query("x-session", "")
if len(requireData.Session) == 0 {
err = errors.SessionNotFound
return
}
requireData.Auth = c.Query("x-authorization", "")
if len(requireData.Auth) == 0 {
err = errors.AuthNotFound
return
}
requireData.Key = c.Query("x-app-key", "")
if len(requireData.Key) == 0 {
err = errors.KeyNotFound
return
}
return
}
func (r *AiRouterBiz) handleOtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
requireData.Ch <- entitys.Response{
Index: "",
Content: requireData.Match.Reasoning,
Type: entitys.ResponseText,
}
return
}
func (r *AiRouterBiz) handleMatch(ctx context.Context, requireData *entitys.RequireData) (err error) {
if !requireData.Match.IsMatch {
requireData.Ch <- entitys.Response{
Index: "",
Content: requireData.Match.Chat,
Type: entitys.ResponseText,
}
if !matchJson.IsMatch {
c.WriteMessage(websocket.TextMessage, []byte(matchJson.Reasoning))
return
}
var pointTask *model.AiTask
for _, task := range requireData.Tasks {
if task.Index == requireData.Match.Index {
for _, task := range tasks {
if task.Index == matchJson.Index {
pointTask = &task
break
}
}
if pointTask == nil || pointTask.Index == "other" {
return r.handleOtherTask(ctx, requireData)
return r.handleOtherTask(resChan, c, matchJson)
}
switch pointTask.Type {
case constants.TaskTypeApi:
return r.handleApiTask(ctx, requireData, pointTask)
case constants.TaskTypeFunc:
return r.handleTask(ctx, requireData, pointTask)
case constants.TaskTypeKnowle:
return r.handleKnowle(ctx, requireData, pointTask)
case constant.TaskTypeApi:
err = r.handleApiTask(resChan, c, matchJson, pointTask)
case constant.TaskTypeFunc:
err = r.handleTask(resChan, c, matchJson, pointTask)
default:
return r.handleOtherTask(ctx, requireData)
return r.handleOtherTask(resChan, c, matchJson)
}
select {
case v := <-resChan: // 尝试接收
fmt.Println("接收到值:", v)
default:
fmt.Println("无数据可接收")
}
return
}
func (r *AiRouterBiz) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
func (r *AiRouterBiz) handleTask(channel chan []byte, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
var configData entitys.ConfigDataTool
err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil {
return
}
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
err = r.toolManager.ExecuteTool(channel, c, configData.Tool, []byte(matchJson.Parameters))
if err != nil {
return
}
@ -343,96 +254,22 @@ func (r *AiRouterBiz) handleTask(ctx context.Context, requireData *entitys.Requi
return
}
// 知识库
func (r *AiRouterBiz) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
var (
configData entitys.ConfigDataTool
sessionIdKnowledge string
query string
host string
)
err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil {
func (r *AiRouterBiz) handleOtherTask(channel chan []byte, c *websocket.Conn, matchJson *entitys.Match) (err error) {
channel <- []byte(matchJson.Reasoning)
return
}
// 通过session 找到知识库session
var has bool
if len(requireData.Session) == 0 {
return errors.SessionNotFound
}
requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session))
if err != nil {
return
} else if !has {
return errors.SessionNotFound
}
// 找到知识库的host
{
tool, exists := r.toolManager.GetTool(configData.Tool)
if !exists {
return fmt.Errorf("tool not found: %s", configData.Tool)
}
if knowledgeTool, ok := tool.(*tools.KnowledgeBaseTool); !ok {
return fmt.Errorf("未找到知识库Tool: %s", configData.Tool)
} else {
host = knowledgeTool.GetConfig().BaseURL
}
}
// 知识库的session为空请求知识库获取, 并绑定
if requireData.SessionInfo.KnowlegeSessionID == "" {
// 请求知识库
if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil {
return
}
// 绑定知识库session下次可以使用
requireData.SessionInfo.KnowlegeSessionID = sessionIdKnowledge
if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil {
return
}
}
// 用户输入解析
var ok bool
input := make(map[string]string)
if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil {
return
}
if query, ok = input["query"]; !ok {
return fmt.Errorf("query不能为空")
}
requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{
Session: requireData.SessionInfo.KnowlegeSessionID,
ApiKey: requireData.Sys.KnowlegeTenantKey,
Query: query,
}
// 执行工具
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
if err != nil {
return
}
return
}
func (r *AiRouterBiz) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
func (r *AiRouterBiz) handleApiTask(channels chan []byte, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
var (
request l_request.Request
auth = c.Headers("X-Authorization", "")
requestParam map[string]interface{}
)
err = json.Unmarshal([]byte(requireData.Match.Parameters), &requestParam)
err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam)
if err != nil {
return
}
request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth)
request.Url = strings.ReplaceAll(task.Config, "${authorization}", auth)
for k, v := range requestParam {
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
}
@ -446,18 +283,14 @@ func (r *AiRouterBiz) handleApiTask(ctx context.Context, requireData *entitys.Re
return
}
if len(request.Url) == 0 {
err = errors.NewBusinessErr(422, "api地址获取失败")
err = errors.NewBusinessErr("00022", "api地址获取失败")
return
}
res, err := request.Send()
if err != nil {
return
}
requireData.Ch <- entitys.Response{
Index: "",
Content: pkg.JsonStringIgonErr(res.Text),
Type: entitys.ResponseJson,
}
c.WriteMessage(1, res.Content)
return
}
@ -466,7 +299,7 @@ func (r *AiRouterBiz) getSessionChatHis(sessionId string) (his []model.AiChatHi,
cond := builder.NewCond()
cond = cond.And(builder.Eq{"session_id": sessionId})
_, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his, "his_id desc")
_, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his, "his_id asc")
return
}
@ -490,3 +323,99 @@ func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
return
}
func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []api.Tool {
taskPrompt := make([]api.Tool, 0)
for _, task := range tasks {
var taskConfig entitys.TaskConfigDetail
err := json.Unmarshal([]byte(task.Config), &taskConfig)
if err != nil {
continue
}
taskPrompt = append(taskPrompt, api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: task.Index,
Description: task.Desc,
Parameters: api.ToolFunctionParameters{
Type: taskConfig.Param.Type,
Required: taskConfig.Param.Required,
Properties: taskConfig.Param.Properties,
},
},
})
}
return taskPrompt
}
func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message {
var (
prompt = make([]entitys.Message, 0)
)
prompt = append(prompt, entitys.Message{
Role: "system",
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
}, entitys.Message{
Role: "assistant",
Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
}, entitys.Message{
Role: "user",
Content: reqInput,
})
return prompt
}
func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []api.Message {
var (
prompt = make([]api.Message, 0)
)
prompt = append(prompt, api.Message{
Role: "system",
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
}, api.Message{
Role: "assistant",
Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
}, api.Message{
Role: "assistant",
Content: pkg.JsonStringIgonErr(r.registerTools(tasks)),
}, api.Message{
Role: "user",
Content: reqInput,
})
return prompt
}
// buildSystemPrompt 构建系统提示词
func (r *AiRouterBiz) buildSystemPrompt(prompt string) string {
if len(prompt) == 0 {
prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式禁用markdown格式返回\n3.只返回json字符串不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"
}
return prompt
}
func (r *AiRouterBiz) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
for _, item := range his {
if len(chatHis.SessionId) == 0 {
chatHis.SessionId = item.SessionID
}
chatHis.Messages = append(chatHis.Messages, entitys.HisMessage{
Role: item.Role,
Content: item.Content,
Timestamp: item.CreateAt.Format("2006-01-02 15:04:05"),
})
}
chatHis.Context = entitys.HisContext{
UserLanguage: "zh-CN",
SystemMode: "technical_support",
}
return
}
// handleKnowledgeQA 处理知识问答意图
func (r *AiRouterBiz) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
return nil, nil
}

View File

@ -5,7 +5,6 @@ import (
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/utils_ollama"
"ai_scheduler/internal/tools"
"ai_scheduler/utils"
@ -33,7 +32,7 @@ type configData struct {
func Test_Order(t *testing.T) {
routerBiz := in()
ch := make(chan entitys.Response, 5)
ch := make(chan []byte, 5)
defer close(ch)
err := routerBiz.handleTask(ch, nil, &entitys.Match{Index: "order_diagnosis", Parameters: `{"order_number":"822895927188791297"}`}, &model.AiTask{Config: `{"tool": "zltxOrderDetail", "param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}}`})
select {
@ -45,55 +44,28 @@ func Test_Order(t *testing.T) {
t.Log(err)
}
func Test_OrderLog(t *testing.T) {
routerBiz := in()
ch := make(chan entitys.Response, 5)
defer close(ch)
err := routerBiz.handleTask(ch, nil, &entitys.Match{Index: "order_diagnosis", Parameters: `{"order_number":"822979421673758721","serial_number":"822979421979938817"}`}, &model.AiTask{Config: `{"tool": "zltxOrderDirectLog", "param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}}`})
t.Log(err)
}
func Test_ProductLog(t *testing.T) {
routerBiz := in()
ch := make(chan entitys.Response, 5)
defer close(ch)
err := routerBiz.handleTask(ch, nil, &entitys.Match{Index: "order_diagnosis", Parameters: `{"name":"利楚测试"}`}, &model.AiTask{Config: `{"tool": "zltxProduct", "param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}}`})
t.Log(err)
}
func Test_ZltxStatistics(t *testing.T) {
routerBiz := in()
ch := make(chan entitys.Response, 5)
defer close(ch)
err := routerBiz.handleTask(ch, nil, &entitys.Match{Index: "order_diagnosis", Parameters: `{"number":"13737882067"}`}, &model.AiTask{Config: `{"tool": "zltxOrderStatistics", "param": {"type": "object", "optional": [], "required": ["number"], "properties": {"number": {"type": "string", "description": "充值账号/分销商ID"}}}}`})
t.Log(err)
}
func in() *AiRouterBiz {
modDir, err := getModuleDir()
if err != nil {
panic("1")
}
configPath := flag.String("config", fmt.Sprintf("%s/config/config.yaml", modDir), "Path to configuration file")
flag.Parse()
configConfig, err := config.LoadConfig(*configPath)
if err != nil {
panic("加载配置失败")
}
client, _, err := utils_ollama.NewClient(configConfig)
allLogger := log.DefaultLogger()
utilOllama := utils_ollama.NewUtilOllama(configConfig, allLogger)
manager := tools.NewManager(configConfig, client)
manager := tools.NewManager(configConfig, utilOllama)
db, _ := utils.NewGormDb(configConfig)
sessionImpl := impl.NewSessionImpl(db)
sysImpl := impl.NewSysImpl(db)
taskImpl := impl.NewTaskImpl(db)
chatImpl := impl.NewChatImpl(db)
safeChannelPool, _ := pkg.NewSafeChannelPool(configConfig)
routerBiz := NewAiRouterBiz(manager, sessionImpl, sysImpl, taskImpl, chatImpl, configConfig, utilOllama, safeChannelPool, client)
routerBiz := NewAiRouterBiz(manager, sessionImpl, sysImpl, taskImpl, chatImpl, configConfig, utilOllama)
return routerBiz
}

View File

@ -2,14 +2,13 @@ package biz
import (
"ai_scheduler/internal/data/constants"
errorcode "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"context"
"time"
"fmt"
"github.com/gofiber/fiber/v2/utils"
"time"
"ai_scheduler/internal/config"
)
@ -22,11 +21,10 @@ type SessionBiz struct {
conf *config.Config
}
func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, chatImpl *impl.ChatImpl) *SessionBiz {
func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl) *SessionBiz {
return &SessionBiz{
sessionRepo: sessionImpl,
sysRepo: sysImpl,
chatRepo: chatImpl,
conf: conf,
}
}
@ -39,7 +37,7 @@ func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRe
if err != nil {
return
} else if !has {
err = errorcode.SysNotFound
err = fmt.Errorf("sys not found")
return
}
@ -74,8 +72,6 @@ func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRe
Content: sysConfig.Prologue,
}
result.Chat = append(result.Chat, chat)
result.SessionId = session.SessionID
result.Prologue = sysConfig.Prologue
// 开场白写入会话历史
s.chatRepo.AsyncCreate(ctx, model.AiChatHi{
@ -85,8 +81,6 @@ func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRe
})
} else {
result.SessionId = session.SessionID
result.Prologue = sysConfig.Prologue
// 存在,返回会话历史
var chatList []model.AiChatHi
chatList, err = s.chatRepo.FindAll(
@ -104,7 +98,6 @@ func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRe
SessionID: chat.SessionID,
Role: constants.Caller(chat.Role),
Content: chat.Content,
Prologue: sysConfig.Prologue,
})
}
}

View File

@ -26,9 +26,6 @@ type LLM struct {
// SysConfig 系统配置
type SysConfig struct {
SessionLen int `mapstructure:"session_len"`
ChannelPoolLen int `mapstructure:"channel_pool_len"`
ChannelPoolSize int `mapstructure:"channel_pool_size"`
LlmPoolLen int `mapstructure:"llm_pool_len"`
}
// ServerConfig 服务器配置
@ -41,7 +38,6 @@ type ServerConfig struct {
type OllamaConfig struct {
BaseURL string `mapstructure:"base_url"`
Model string `mapstructure:"model"`
GenerateModel string `mapstructure:"generate_model"`
Timeout time.Duration `mapstructure:"timeout"`
}
@ -71,12 +67,8 @@ type ToolsConfig struct {
Weather ToolConfig `mapstructure:"weather"`
Calculator ToolConfig `mapstructure:"calculator"`
ZltxOrderDetail ToolConfig `mapstructure:"zltxOrderDetail"`
ZltxOrderDirectLog ToolConfig `mapstructure:"zltxOrderDirectLog"`
ZltxOrderLog ToolConfig `mapstructure:"zltxOrderLog"`
Knowledge ToolConfig `mapstructure:"knowledge"`
//通过ID获取我们的商品信息
ZltxProduct ToolConfig `mapstructure:"zltxProduct"`
//通过账号获取订单统计信息
ZltxOrderStatistics ToolConfig `mapstructure:"zltxOrderStatistics"`
}
// ToolConfig 单个工具配置
@ -84,8 +76,6 @@ type ToolConfig struct {
Enabled bool `mapstructure:"enabled"`
BaseURL string `mapstructure:"base_url"`
APIKey string `mapstructure:"api_key"`
//附加地址
AddURL string `mapstructure:"add_url"`
}
// LoggingConfig 日志配置

View File

@ -1,4 +1,4 @@
package constants
package constant
type ConnStatus int8

View File

@ -0,0 +1,3 @@
package constant
const ()

View File

@ -1,3 +0,0 @@
package constants
const ()

View File

@ -1,34 +1,32 @@
package errorcode
import "fmt"
var (
Success = &BusinessErr{code: 200, message: "成功"}
ParamError = &BusinessErr{code: 401, message: "参数错误"}
NotFoundError = &BusinessErr{code: 404, message: "请求地址未找到"}
SystemError = &BusinessErr{code: 405, message: "系统错误"}
Success = &BusinessErr{code: "0000", message: "成功"}
ParamError = &BusinessErr{code: "0001", message: "参数错误"}
NotFoundError = &BusinessErr{code: "0004", message: "请求地址未找到"}
SystemError = &BusinessErr{code: "0005", message: "系统错误"}
ClientNotFound = &BusinessErr{code: 406, message: "未找到client_id"}
SessionNotFound = &BusinessErr{code: 407, message: "未找到会话信息"}
AuthNotFound = &BusinessErr{code: 408, message: "身份验证失败"}
KeyNotFound = &BusinessErr{code: 409, message: "身份验证失败"}
SysNotFound = &BusinessErr{code: 410, message: "未找到系统信息"}
SupplierNotFound = &BusinessErr{code: "0006", message: "供应商不存在"}
SessionNotFound = &BusinessErr{code: "0007", message: "未找到会话信息"}
AuthNotFound = &BusinessErr{code: "0008", message: "身份验证失败"}
KeyNotFound = &BusinessErr{code: "0009", message: "身份验证失败"}
SysNotFound = &BusinessErr{code: "0010", message: "未找到系统信息"}
InvalidParam = &BusinessErr{code: InvalidParamCode, message: "无效参数"}
)
const (
InvalidParamCode = 408
InvalidParamCode = "0008"
)
type BusinessErr struct {
code int
code string
message string
}
func (e *BusinessErr) Error() string {
return e.message
}
func (e *BusinessErr) Code() int {
func (e *BusinessErr) Code() string {
return e.code
}
@ -38,18 +36,10 @@ func (e *BusinessErr) Is(target error) bool {
}
// CustomErr 自定义错误
func NewBusinessErr(code int, message string) *BusinessErr {
func NewBusinessErr(code string, message string) *BusinessErr {
return &BusinessErr{code: code, message: message}
}
func SysErr(message string, arg ...any) *BusinessErr {
return &BusinessErr{code: SystemError.code, message: fmt.Sprintf(message, arg)}
}
func ParamErr(message string, arg ...any) *BusinessErr {
return &BusinessErr{code: ParamError.code, message: fmt.Sprintf(message, arg)}
}
func (e *BusinessErr) Wrap(err error) *BusinessErr {
return NewBusinessErr(e.code, err.Error())
}

View File

@ -9,13 +9,13 @@ import (
type SysImpl struct {
dataTemp.DataTemp
BaseRepository[model.AiSy]
BaseModel[model.AiSy]
}
func NewSysImpl(db *utils.Db) *SysImpl {
return &SysImpl{
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiSy)),
BaseRepository: NewBaseModel[model.AiSy](db.Client),
BaseModel: BaseModel[model.AiSy]{},
}
}

View File

@ -8,5 +8,4 @@ type ChatHistory struct {
SessionID string `json:"session_id"`
Role constants.Caller `json:"role"`
Content string `json:"content"`
Prologue string `json:"prologue"`
}

View File

@ -1,61 +0,0 @@
package entitys
import (
"encoding/json"
"github.com/gofiber/websocket/v2"
)
type ResponseType string
const (
ResponseJson ResponseType = "json"
ResponseLoading ResponseType = "loading"
ResponseEnd ResponseType = "end"
ResponseStream ResponseType = "stream"
ResponseText ResponseType = "txt"
ResponseImg ResponseType = "img"
ResponseFile ResponseType = "file"
ResponseErr ResponseType = "error"
ResponseLog ResponseType = "log"
ResponseAuth ResponseType = "auth"
)
type ResponseData struct {
Done bool
Content string
Type ResponseType
}
type Response struct {
Content string
Type ResponseType
Index string
}
func MsgSet(msgType ResponseType, msg string, done bool) []byte {
jsonByte, err := json.Marshal(ResponseData{
Done: done,
Content: msg,
Type: msgType,
})
if err != nil {
return nil
}
return jsonByte
}
func MsgSend(c *websocket.Conn, msg Response) error {
// 检查上下文是否已取消
if msg.Type == ResponseText {
}
jsonByte, _ := json.Marshal(msg)
return c.WriteMessage(websocket.TextMessage, jsonByte)
}
func MsgSendByte(c *websocket.Conn, msg []byte) {
_ = c.WriteMessage(websocket.TextMessage, msg)
}

View File

@ -6,9 +6,7 @@ type SessionInitRequest struct {
}
type SessionInitResponse struct {
SessionId string `json:"session_id"`
Chat []ChatHistory `json:"chat"`
Prologue string `json:"prologue"`
}
type SessionListRequest struct {

View File

@ -1,8 +1,6 @@
package entitys
import (
"ai_scheduler/internal/data/model"
"context"
"encoding/json"
@ -22,12 +20,6 @@ type ChatRequestMeta struct {
Authorization string `json:"authorization"`
}
type FirstSockRequest struct {
Authorization string `json:"authorization"`
SessionID string `json:"session_id"`
AppKey string `json:"app_key"`
}
type ChatSockRequest struct {
Text string `json:"text" binding:"required"`
Img string `json:"img" binding:"required"`
@ -75,7 +67,7 @@ type Tool interface {
Name() string
Description() string
Definition() ToolDefinition
Execute(ctx context.Context, requireData *RequireData) error
Execute(channel chan []byte, c *websocket.Conn, args json.RawMessage) error
}
type ConfigDataHttp struct {
@ -85,7 +77,6 @@ type ConfigDataHttp struct {
type ConfigDataTool struct {
Param map[string]interface{} `json:"param"`
Request map[string]interface{} `json:"request"`
Tool string `json:"tool"`
}
@ -113,15 +104,11 @@ type ConfigParam struct {
Type string `json:"type"`
}
type Match struct {
Confidence any `json:"confidence"`
Confidence float64 `json:"confidence"`
Index string `json:"index"`
IsMatch bool `json:"is_match"`
Parameters string `json:"parameters"`
Reasoning string `json:"reasoning"`
History []byte `json:"history"`
UserInput string `json:"user_input"`
Auth string `json:"auth"`
Chat string `json:"chat"`
}
type ChatHis struct {
SessionId string `json:"session_id"`
@ -139,28 +126,8 @@ type HisContext struct {
SystemMode string `json:"system_mode"`
}
type RequireData struct {
Session string
Key string
Sys model.AiSy
Histories []model.AiChatHi
SessionInfo model.AiSession
Tasks []model.AiTask
Match *Match
UserInput string
Auth string
Ch chan Response
KnowledgeConf KnowledgeBaseRequest
ImgByte []api.ImageData
}
type KnowledgeBaseRequest struct {
Session string // 知识库会话id
ApiKey string // 知识库apiKey
Query string // 用户输入
}
// RouterService 路由服务接口
type RouterService interface {
Route(ctx context.Context, req *ChatRequest) (*ChatResponse, error)
RouteWithSocket(c *websocket.Conn, req *ChatSockRequest) error
}

View File

@ -1,75 +0,0 @@
package pkg
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/entitys"
"sync"
)
type SafeChannelPool struct {
pool chan chan entitys.ResponseData // 存储空闲 channel 的队列
bufSize int // channel 缓冲大小
mu sync.Mutex
closed bool
}
func NewSafeChannelPool(c *config.Config) (*SafeChannelPool, func()) {
pool := &SafeChannelPool{
pool: make(chan chan entitys.ResponseData, c.Sys.ChannelPoolLen),
bufSize: c.Sys.ChannelPoolSize,
}
cleanup := pool.Close
return pool, cleanup
}
// 从池中获取 channel若无空闲则创建新 channel
func (p *SafeChannelPool) Get() chan entitys.ResponseData {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return make(chan entitys.ResponseData, p.bufSize)
}
select {
case ch := <-p.pool: // 从池中取
return ch
default: // 池为空,创建新 channel
return make(chan entitys.ResponseData, p.bufSize)
}
}
// 将 channel 放回池中(必须确保 channel 已清空!)
func (p *SafeChannelPool) Put(ch chan entitys.ResponseData) {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return
}
// 清空 channel防止复用时读取旧数据
go func() {
for range ch {
// 丢弃所有数据(或根据业务需求处理)
}
}()
select {
case p.pool <- ch: // 尝试放回池中
default: // 池已满,直接关闭 channel避免泄漏
close(ch)
}
return
}
// 关闭池(释放所有资源)
func (p *SafeChannelPool) Close() {
p.mu.Lock()
defer p.mu.Unlock()
p.closed = true
close(p.pool) // 关闭池队列
// 需额外逻辑关闭所有内部 channel此处简化
}

View File

@ -1,57 +1,8 @@
package pkg
import (
"ai_scheduler/internal/entitys"
"encoding/json"
"errors"
"fmt"
"net/url"
"strings"
)
import "encoding/json"
func JsonStringIgonErr(data interface{}) string {
return string(JsonByteIgonErr(data))
}
func JsonByteIgonErr(data interface{}) []byte {
dataByte, _ := json.Marshal(data)
return dataByte
}
// IsChannelClosed 检查给定的 channel 是否已经关闭
// 参数 ch: 要检查的 channel类型为 chan entitys.ResponseData
// 返回值: bool 类型true 表示 channel 已关闭false 表示未关闭
func IsChannelClosed(ch chan entitys.ResponseData) bool {
select {
case _, ok := <-ch: // 尝试从 channel 中读取数据
return !ok // 如果 ok=false说明 channel 已关闭
default: // 如果 channel 暂时无数据可读(但不一定关闭)
return false // channel 未关闭(但可能有数据未读取)
}
}
// ValidateImageURL 验证图片 URL 是否有效
func ValidateImageURL(rawURL string) error {
// 1. 基础格式验证
parsed, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("invalid URL format: %v", err)
}
// 2. 检查协议是否为 http/https
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return errors.New("URL must use http or https protocol")
}
// 3. 检查是否有空的主机名
if parsed.Host == "" {
return errors.New("URL missing host")
}
// 4. 检查路径是否为空(可选)
if strings.TrimSpace(parsed.Path) == "" {
return errors.New("URL path is empty")
}
return nil
return string(dataByte)
}

View File

@ -1,25 +0,0 @@
## 安装
```bash
$ go get -u gitea.cdlsxd.cn/rzy_tools/request/tags
```
## 正常使用
```go
req := request.Request{
Method: "POST",
Url: reqUrl,
Json: RequestBody,
Headers: header,
}
resp, _ := req.Send()
```
## 同时大量请求或者在协程中使用建议使用
```go
r := RequestPools.Get()
defer RequestPools.ClearAndPut(r)
...
```

View File

@ -1,44 +0,0 @@
package l_request
import (
"sync"
)
type RequestPool struct {
pool sync.Pool
}
var RequestPools = &RequestPool{
pool: sync.Pool{
New: func() interface{} {
return new(Request)
},
},
}
func (re *RequestPool) Get() *Request {
return re.pool.Get().(*Request)
}
func (re *RequestPool) Put(r *Request) {
re.pool.Put(r)
}
// 重置对象
func (re *RequestPool) Reset(r *Request) {
r.Method = ""
r.Url = ""
r.Params = nil
r.Headers = nil
r.Cookies = nil
r.Data = nil
r.Json = nil
r.Files = nil
r.Raw = ""
r.JsonByte = nil
}
func (re *RequestPool) ClearAndPut(r *Request) {
re.Reset(r)
re.Put(r)
}

View File

@ -1,11 +0,0 @@
package l_request
import "testing"
func TestPool(t *testing.T) {
r := RequestPools.Get()
r.Url = "http://www.baidu.com"
RequestPools.ClearAndPut(r)
a := RequestPools.Get()
t.Log(a.Url)
}

View File

@ -1,169 +0,0 @@
package l_request
import (
"crypto/tls"
"encoding/json"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// 请求结构体
type Request struct {
Method string `json:"method"` // 请求方法
Url string `json:"url"` // 请求url
Params map[string]string `json:"params"` // Query参数
Headers map[string]string `json:"headers"` // 请求头
Cookies map[string]string `json:"cookies"` // todo 处理 Cookies
Data map[string]string `json:"data"` // 表单格式请求数据
Json map[string]interface{} `json:"json"` // JSON格式请求数据 todo 多层 嵌套
Files map[string]string `json:"files"` // todo 处理 Files
Raw string `json:"raw"` // 原始请求数据
JsonByte []byte `json:"json_raw"` // JSON格式请求数据 todo 多层 嵌套
Xml []byte `json:"xml"` // xml
}
// 响应结构体
type Response struct {
StatusCode int `json:"status_code"` // 状态码
Reason string `json:"reason"` // 状态码说明
Elapsed float64 `json:"elapsed"` // 请求耗时(秒)
Content []byte `json:"content"` // 响应二进制内容
Text string `json:"text"` // 响应文本
Headers map[string]string `json:"headers"` // 响应头
Cookies map[string]string `json:"cookies"` // todo 添加响应Cookies
Request *Request `json:"request"` // 原始请求
}
// 处理请求方法
func (r *Request) getMethod() string {
return strings.ToUpper(r.Method) // 必须转为全部大写
}
// 组装URL
func (r *Request) getUrl() string {
if r.Params != nil {
urlValues := url.Values{}
Url, _ := url.Parse(r.Url) // todo 处理err
for key, value := range r.Params {
urlValues.Set(key, value)
}
Url.RawQuery = urlValues.Encode()
return Url.String()
}
return r.Url
}
// 组装请求数据
func (r *Request) getData() io.Reader {
var reqBody string
if r.Headers == nil {
r.Headers = make(map[string]string, 1)
}
if r.Raw != "" {
reqBody = r.Raw
} else if r.Data != nil {
urlValues := url.Values{}
for key, value := range r.Data {
urlValues.Add(key, value)
}
reqBody = urlValues.Encode()
if _, ex := r.Headers["Content-Type"]; !ex {
r.Headers["Content-Type"] = "application/x-www-form-urlencoded"
}
} else if r.Json != nil {
bytesData, _ := json.Marshal(r.Json)
reqBody = string(bytesData)
if _, ex := r.Headers["Content-Type"]; !ex {
r.Headers["Content-Type"] = "application/json"
}
} else if r.JsonByte != nil {
reqBody = string(r.JsonByte)
if _, ex := r.Headers["Content-Type"]; !ex {
r.Headers["Content-Type"] = "application/json"
}
} else if r.Xml != nil {
reqBody = string(r.Xml)
if _, ex := r.Headers["Content-Type"]; !ex {
r.Headers["Content-Type"] = "application/xml"
}
}
return strings.NewReader(reqBody)
}
// 添加请求头-需要在getData后使用
func (r *Request) addHeaders(req *http.Request) {
if r.Headers != nil {
for key, value := range r.Headers {
req.Header.Add(key, value)
}
}
}
// 准备请求
func (r *Request) prepare() *http.Request {
Method := r.getMethod()
Url := r.getUrl()
Data := r.getData()
req, _ := http.NewRequest(Method, Url, Data)
r.addHeaders(req)
return req
}
// 组装响应对象
func (r *Request) packResponse(res *http.Response, elapsed float64) Response {
var resp Response
resBody, _ := io.ReadAll(res.Body)
resp.Content = resBody
resp.Text = string(resBody)
resp.StatusCode = res.StatusCode
resp.Reason = strings.Split(res.Status, " ")[1]
resp.Elapsed = elapsed
resp.Headers = map[string]string{}
for key, value := range res.Header {
resp.Headers[key] = strings.Join(value, ";")
}
return resp
}
// 发送请求
func (r *Request) Send() (Response, error) {
req := r.prepare()
client := &http.Client{}
start := time.Now()
res, err := client.Do(req)
if err != nil {
return Response{}, err
}
defer res.Body.Close()
elapsed := time.Since(start).Seconds()
return r.packResponse(res, elapsed), nil
}
// 跳过证书发送请求
func (r *Request) SendWithoutSsl() (Response, error) {
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
req := r.prepare()
client := &http.Client{Transport: tr}
start := time.Now()
res, err := client.Do(req)
if err != nil {
return Response{}, err
}
defer res.Body.Close()
elapsed := time.Since(start).Seconds()
return r.packResponse(res, elapsed), nil
}
// 发送请求
func (r *Request) SendNoParseResponse() (*http.Response, error) {
req := r.prepare()
client := &http.Client{}
res, err := client.Do(req)
return res, err
}

View File

@ -1,7 +1,6 @@
package pkg
import (
"ai_scheduler/internal/pkg/utils_langchain"
"ai_scheduler/internal/pkg/utils_ollama"
"github.com/google/wire"
@ -10,7 +9,6 @@ import (
var ProviderSetClient = wire.NewSet(
NewRdb,
NewGormDb,
utils_langchain.NewUtilLangChain,
utils_ollama.NewUtilOllama,
utils_ollama.NewClient,
NewSafeChannelPool,
)

View File

@ -1,168 +0,0 @@
package util
import (
"ai_scheduler/internal/pkg/l_request"
"bufio"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
)
type KnowledgeBase struct {
session string
url string
apiKey string
}
func NewKnowledgeBase(url, apiKey, session string) *KnowledgeBase {
return &KnowledgeBase{
session: session,
url: url,
apiKey: apiKey,
}
}
// 请求知识库聊天
func (this *KnowledgeBase) Chat(ctx context.Context, query string) (text string, err error) {
req := l_request.Request{
Method: "post",
Url: this.url + "/api/v1/knowledge-chat/" + this.session,
Params: nil,
Headers: map[string]string{
"Content-Type": "application/json",
"X-API-Key": this.apiKey,
},
Cookies: nil,
Data: nil,
Json: map[string]interface{}{
"query": query,
},
Files: nil,
Raw: "",
JsonByte: nil,
Xml: nil,
}
rsp, err := req.SendNoParseResponse()
if err != nil {
return
}
defer rsp.Body.Close()
err = connectAndReadSSE(rsp)
if err != nil {
return
}
return
}
// Message 表示解析后的 SSE 消息
type Message struct {
Event string // 事件类型(默认 "message"
Data string // 消息内容(可能多行)
ID string // 消息 ID可选
}
// 连接 SSE 并读取数据
func connectAndReadSSE(resp *http.Response) error {
// 验证响应状态和格式
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("非 200 状态码: %d", resp.StatusCode)
}
contentType := resp.Header.Get("Content-Type")
if !strings.Contains(contentType, "text/event-stream") {
return fmt.Errorf("不支持的 Content-Type: %s", contentType)
}
// 逐行读取响应流
scanner := bufio.NewScanner(resp.Body)
var currentMsg Message // 当前正在组装的消息
for scanner.Scan() {
line := scanner.Text()
if line == "" {
// 空行表示一条消息结束,处理当前消息
if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" {
printMessage(currentMsg)
currentMsg = Message{} // 重置消息
}
continue
}
// 解析字段(格式:"field: value"
parts := strings.SplitN(line, ":", 2)
if len(parts) < 2 {
continue // 无效行(无冒号),跳过
}
field := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
switch field {
case "event":
currentMsg.Event = value
case "data":
// data 可能多行,用换行符拼接(最后一条消息可能无结尾空行)
currentMsg.Data += value + ""
//case "id":
// currentMsg.ID = value
// 可选:处理 "retry" 字段(服务器建议的重连时间,单位秒)
}
}
// 检查扫描错误(如连接断开)
if err := scanner.Err(); err != nil {
return fmt.Errorf("读取流失败: %w", err)
}
// 处理最后一条未结束的消息(无结尾空行)
if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" {
printMessage(currentMsg)
}
return nil
}
type MegContent struct {
Id string `json:"id"` // 消息 ID
ResponseType string `json:"response_type"` // 响应类型answer 或 references
Content string `json:"content"` // 消息内容
Done bool `json:"done"` // 是否完成
KnowledgeReferences interface{} `json:"knowledge_references"`
}
// printMessage 打印解析后的 SSE 消息
func printMessage(msg Message) {
//fmt.Printf("--- 收到 SSE 消息 ---")
//fmt.Printf("事件类型: %s", msg.Event)
//fmt.Printf("消息 ID: %s", msg.ID)
//fmt.Printf("内容:%s", strings.TrimSpace(msg.Data)) // 去除末尾多余换行
var content MegContent
_ = json.Unmarshal([]byte(msg.Data), &content)
fmt.Println(msg.Data)
//if content.ResponseType == "answer" {
// //fmt.Printf("%s", content.Content)
// fmt.Println(content)
//} else {
// fmt.Printf("--- 收到 SSE 消息 ---")
// fmt.Printf("事件类型: %s", msg.Event)
// fmt.Printf("消息 ID: %s", msg.ID)
// fmt.Printf("内容:%s", strings.TrimSpace(msg.Data)) // 去除末尾多余换行
//}
}
// getRetryAfter 从响应头获取重连时间(示例,需根据实际响应头调整)
func getRetryAfter(url string) int {
// 实际需重新请求并获取响应头(此处简化为固定值)
// 正确做法:在 connectAndReadSSE 中记录响应头的 Retry-After 字段
return 5 // 示例:等待 5 秒
}

View File

@ -1,129 +0,0 @@
package utils_langchain
import (
"ai_scheduler/internal/config"
"math/rand"
"net"
"net/http"
"os"
"sync"
"time"
"github.com/gofiber/fiber/v2/log"
"github.com/tmc/langchaingo/llms/ollama"
)
type UtilLangChain struct {
LlmClientPool *sync.Pool
poolSize int // 记录池大小,用于调试
model string
serverURL string
c *config.Config
}
type LlmObj struct {
Number string
Llm *ollama.LLM
}
func NewUtilLangChain(c *config.Config, logger log.AllLogger) *UtilLangChain {
poolSize := c.Sys.LlmPoolLen
if poolSize <= 0 {
poolSize = 10 // 默认值
logger.Warnf("LlmPoolLen not set, using default: %d", poolSize)
}
// 初始化 Pool
pool := &sync.Pool{
New: func() interface{} {
llm, err := ollama.New(
ollama.WithModel(c.Ollama.Model),
ollama.WithHTTPClient(http.DefaultClient),
ollama.WithServerURL(getUrl(c)),
ollama.WithKeepAlive("-1s"),
)
if err != nil {
logger.Fatalf("Failed to create Ollama client: %v", err)
panic(err) // 或者返回 nil + 错误处理
}
number := randStr(5)
log.Info(number)
return &LlmObj{
Number: number,
Llm: llm,
}
},
}
// 预填充 Pool
for i := 0; i < poolSize; i++ {
pool.Put(pool.New())
}
return &UtilLangChain{
LlmClientPool: pool,
poolSize: poolSize,
model: c.Ollama.Model,
serverURL: getUrl(c),
}
}
func (o *UtilLangChain) NewClient() *ollama.LLM {
llm, _ := ollama.New(
ollama.WithModel(o.c.Ollama.Model),
ollama.WithHTTPClient(&http.Client{
Transport: &http.Transport{
MaxIdleConns: 100, // 最大空闲连接数(默认 2太小
MaxIdleConnsPerHost: 100, // 每个 Host 的最大空闲连接数(默认 2
IdleConnTimeout: 90 * time.Second, // 空闲连接超时时间
DialContext: (&net.Dialer{
Timeout: 30 * time.Second, // 连接超时
KeepAlive: 30 * time.Second, // TCP Keep-Alive
}).DialContext,
},
Timeout: 60 * time.Second, // 整体请求超时(避免无限等待)
}),
ollama.WithServerURL(getUrl(o.c)),
ollama.WithKeepAlive("-1s"),
)
return llm
}
// Get 返回一个可用的 LLM 客户端
func (o *UtilLangChain) Get() *LlmObj {
client := o.LlmClientPool.Get().(*LlmObj)
return client
}
// Put 归还客户端(可选:检查是否仍可用)
func (o *UtilLangChain) Put(llm *LlmObj) {
if llm == nil {
return
}
o.LlmClientPool.Put(llm)
}
// Stats 返回池的统计信息(用于监控)
func (o *UtilLangChain) Stats() (current, max int) {
return o.poolSize, o.poolSize
}
func getUrl(c *config.Config) string {
baseURL := c.Ollama.BaseURL
envURL := os.Getenv("OLLAMA_BASE_URL")
if envURL != "" {
baseURL = envURL
}
return baseURL
}
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
func randStr(n int) string {
b := make([]rune, n)
for i := range b {
b[i] = letters[rand.Intn(len(letters))]
}
return string(b)
}

View File

@ -7,7 +7,6 @@ import (
"net/http"
"net/url"
"os"
"sync"
"github.com/ollama/ollama/api"
)
@ -48,6 +47,7 @@ func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools [
Think: &api.ThinkValue{Value: false},
Tools: tools,
}
err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
res = resp
return nil
@ -59,48 +59,6 @@ func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools [
return
}
func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messages []api.Message, index string, model string) (err error) {
if len(model) == 0 {
model = c.config.Model
}
// 构建聊天请求
req := &api.ChatRequest{
Model: model,
Messages: messages,
Stream: nil,
Think: &api.ThinkValue{Value: true},
}
var w sync.WaitGroup
w.Add(1)
go func() {
defer w.Done()
err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
if resp.Message.Content != "" {
ch <- entitys.Response{
Index: index,
Content: resp.Message.Content,
Type: entitys.ResponseStream,
}
}
return nil
})
if err != nil {
return
}
}()
w.Wait()
return
}
func (c *Client) Generation(ctx context.Context, generateRequest *api.GenerateRequest) (result api.GenerateResponse, err error) {
err = c.client.Generate(ctx, generateRequest, func(resp api.GenerateResponse) error {
result = resp
return nil
})
return
}
// convertResponse 转换响应格式
func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
//result := &entitys.ChatResponse{

View File

@ -0,0 +1,47 @@
package utils_ollama
import (
"ai_scheduler/internal/config"
"net/http"
"os"
"github.com/gofiber/fiber/v2/log"
"github.com/tmc/langchaingo/llms/ollama"
)
type UtilOllama struct {
Llm *ollama.LLM
}
func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
llm, err := ollama.New(
ollama.WithModel(c.Ollama.Model),
ollama.WithHTTPClient(http.DefaultClient),
ollama.WithServerURL(getUrl(c)),
ollama.WithKeepAlive("-1s"),
)
if err != nil {
logger.Fatal(err)
panic(err)
}
return &UtilOllama{
Llm: llm,
}
}
//func (o *UtilOllama) a() {
// var agent agents.Agent
// agent = agents.NewOneShotAgent(llm, tools, opts...)
//
// agents.NewExecutor()
//}
func getUrl(c *config.Config) string {
baseURL := c.Ollama.BaseURL
envURL := os.Getenv("OLLAMA_BASE_URL")
if envURL != "" {
baseURL = envURL
}
return baseURL
}

View File

@ -2,8 +2,6 @@ package services
import (
errorcode "ai_scheduler/internal/data/error"
"net/http"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
)
@ -20,16 +18,16 @@ func handRes(c *fiber.Ctx, _err error, rsp interface{}) error {
err = e
} else {
log.Error(c.UserContext(), "系统错误 error: ", _err)
err = errorcode.NewBusinessErr(http.StatusInternalServerError, _err.Error())
err = errorcode.NewBusinessErr("500", _err.Error())
}
}
body := fiber.Map{
"code": err.Code(),
"code": err.Code,
"msg": err.Error(),
"data": rsp,
}
log.Info(c.UserContext(), c.Path(), "请求参数=", string(c.BodyRaw()), "响应=", body)
log.Info(c.UserContext(), c.Path(), "请求参数=", c.BodyRaw(), "响应=", body)
return c.JSON(body)
}

View File

@ -2,7 +2,7 @@ package services
import (
"ai_scheduler/internal/biz"
"ai_scheduler/internal/data/constants"
"ai_scheduler/internal/data/constant"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/gateway"
"encoding/hex"
@ -74,7 +74,6 @@ func (h *ChatService) Chat(c *websocket.Conn) {
h.Gw.AddClient(client)
log.Println("client connected:", clientID)
log.Println("客户端已连接")
// 循环读取客户端消息
for {
messageType, message, err := c.ReadMessage()
@ -82,12 +81,17 @@ func (h *ChatService) Chat(c *websocket.Conn) {
log.Println("读取错误:", err)
break
}
//简单协议bind:<uid>
if c.Headers("Sec-Websocket-Protocol") == "bind" && c.Headers("X-Session") != "" {
uid := c.Headers("X-Session")
_ = h.Gw.BindUid(clientID, uid)
log.Printf("bind %s -> uid:%s\n", clientID, uid)
}
msg, chatType := h.handleMessageToString(c, messageType, message)
if chatType == constants.ConnStatusClosed {
if chatType == constant.ConnStatusClosed {
break
}
if chatType == constants.ConnStatusIgnore {
if chatType == constant.ConnStatusIgnore {
continue
}
@ -97,49 +101,34 @@ func (h *ChatService) Chat(c *websocket.Conn) {
log.Println("JSON parse error:", err)
continue
}
//简单协议bind:<uid>
if c.Headers("Sec-Websocket-Protocol") == "bind" && req.SessionID != "" {
uid := c.Query("x-session")
_ = h.Gw.BindUid(clientID, req.SessionID)
log.Printf("bind %s -> uid:%s\n", clientID, uid)
}
err = h.routerBiz.RouteWithSocket(c, &req)
if err != nil {
log.Println("处理失败:", err)
entitys.MsgSend(c, entitys.Response{
Content: err.Error(),
Type: entitys.ResponseText,
})
continue
}
_ = entitys.MsgSend(c, entitys.Response{
Content: "",
Type: entitys.ResponseEnd,
})
}
h.Gw.RemoveClient(clientID)
_ = c.Close()
log.Println("client disconnected:", clientID)
}
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constants.ConnStatus) {
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constant.ConnStatus) {
switch msgType {
case websocket.TextMessage:
return msg.([]byte), constants.ConnStatusNormal
return msg.([]byte), constant.ConnStatusNormal
case websocket.BinaryMessage:
return msg.([]byte), constants.ConnStatusNormal
return msg.([]byte), constant.ConnStatusNormal
case websocket.CloseMessage:
return nil, constants.ConnStatusClosed
return nil, constant.ConnStatusClosed
case websocket.PingMessage:
// 可选:回复 Pong
c.WriteMessage(websocket.PongMessage, nil)
return nil, constants.ConnStatusIgnore
return nil, constant.ConnStatusIgnore
case websocket.PongMessage:
return nil, constants.ConnStatusIgnore
return nil, constant.ConnStatusIgnore
default:
return nil, constants.ConnStatusIgnore
return nil, constant.ConnStatusIgnore
}
return msg.([]byte), constants.ConnStatusIgnore
return msg.([]byte), constant.ConnStatusIgnore
}

View File

@ -27,11 +27,7 @@ func (s *SessionService) SessionInit(c *fiber.Ctx) error {
result, err := s.sessionBiz.SessionInit(c.Context(), req)
if err != nil {
return err
}
return c.JSON(result)
return handRes(c, err, result)
}
// SessionList 获取会话列表
@ -43,11 +39,7 @@ func (s *SessionService) SessionList(c *fiber.Ctx) error {
sessionList, err := s.sessionBiz.SessionList(c.Context(), req)
if err != nil {
return err
}
return c.JSON(fiber.Map{
return handRes(c, err, fiber.Map{
"session_list": sessionList,
})
}

View File

@ -4,9 +4,7 @@ import (
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/mapstructure"
"encoding/json"
"fmt"
"testing"
"time"
"gitea.cdlsxd.cn/self-tools/l_request"
)
@ -38,24 +36,6 @@ func Test_task2(t *testing.T) {
t.Log(err)
}
func producer(ch chan<- int) {
for i := 0; i < 100; i++ {
ch <- i // 发送数据到通道
fmt.Printf("Sent: %d\n", i)
time.Sleep(500 * time.Millisecond) // 模拟生产延迟
}
close(ch) // 关闭通道,通知接收方数据发送完毕
}
func in() {
func consumer(ch <-chan int) {
for v := range ch { // 阻塞等待数据,有数据立即处理
fmt.Printf("Received: %d\n", v)
}
}
func Test_a(t *testing.T) {
ch := make(chan int, 3) // 有缓冲通道(可选)
go producer(ch)
consumer(ch) // 主线程阻塞,直到通道关闭
}

View File

@ -1,251 +0,0 @@
package tools
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/l_request"
"bufio"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
)
// 知识库工具
type KnowledgeBaseTool struct {
config config.ToolConfig
}
// NewKnowledgeBaseTool 创建知识库工具
func NewKnowledgeBaseTool(config config.ToolConfig) *KnowledgeBaseTool {
return &KnowledgeBaseTool{config: config}
}
func (k *KnowledgeBaseTool) GetConfig() config.ToolConfig {
return k.config
}
// Name 返回工具名称
func (k *KnowledgeBaseTool) Name() string {
return "knowledgeBase"
}
// Description 返回工具描述
func (k *KnowledgeBaseTool) Description() string {
return "请求知识库"
}
// Definition 返回工具定义
func (k *KnowledgeBaseTool) Definition() entitys.ToolDefinition {
return entitys.ToolDefinition{
Type: "function",
Function: entitys.FunctionDef{
Name: k.Name(),
Description: k.Description(),
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"query": map[string]interface{}{
"type": "string",
"description": "知识库查询条件",
},
},
"required": []string{"query"},
},
},
}
}
// Execute 执行知识库查询
func (k *KnowledgeBaseTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
return k.chat(requireData)
}
// Message 表示解析后的 SSE 消息
type Message struct {
Event string // 事件类型(默认 "message"
Data string // 消息内容(可能多行)
ID string // 消息 ID可选
}
type MsgContent struct {
Id string `json:"id"`
ResponseType string `json:"response_type"`
Content string `json:"content"`
Done bool `json:"done"`
KnowledgeReferences interface{} `json:"knowledge_references"`
}
// 解析知识库响应内容,并把通过channel结果返回
func (this *KnowledgeBaseTool) msgContentParse(input string, channel chan entitys.Response) (msgContent MsgContent, err error) {
err = json.Unmarshal([]byte(input), &msgContent)
if err != nil {
err = fmt.Errorf("unmarshal input failed: %w", err)
}
channel <- entitys.Response{
Index: this.Name(),
Content: msgContent.Content,
Type: entitys.ResponseStream,
}
return
}
// 请求知识库聊天
func (this *KnowledgeBaseTool) chat(requireData *entitys.RequireData) (err error) {
req := l_request.Request{
Method: "post",
Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + requireData.KnowledgeConf.Session,
Params: nil,
Headers: map[string]string{
"Content-Type": "application/json",
"X-API-Key": requireData.KnowledgeConf.ApiKey,
},
Cookies: nil,
Data: nil,
Json: map[string]interface{}{
"query": requireData.KnowledgeConf.Query,
},
Files: nil,
Raw: "",
JsonByte: nil,
Xml: nil,
}
rsp, err := req.SendNoParseResponse()
if err != nil {
return
}
defer rsp.Body.Close()
err = this.connectAndReadSSE(rsp, requireData.Ch)
if err != nil {
return
}
return
}
// 连接 SSE 并读取数据
func (this *KnowledgeBaseTool) connectAndReadSSE(resp *http.Response, channel chan entitys.Response) error {
// 验证响应状态和格式
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("非 200 状态码: %d", resp.StatusCode)
}
contentType := resp.Header.Get("Content-Type")
if !strings.Contains(contentType, "text/event-stream") {
return fmt.Errorf("不支持的 Content-Type: %s", contentType)
}
// 逐行读取响应流
scanner := bufio.NewScanner(resp.Body)
var currentMsg Message // 当前正在组装的消息
for scanner.Scan() {
line := scanner.Text()
if line == "" {
// 空行表示一条消息结束,处理当前消息
if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" {
_, err := this.msgContentParse(currentMsg.Data, channel)
if err != nil {
return fmt.Errorf("msgContentParse failed: %w", err)
}
currentMsg = Message{} // 重置消息
}
continue
}
// 解析字段(格式:"field: value"
parts := strings.SplitN(line, ":", 2)
if len(parts) < 2 {
continue // 无效行(无冒号),跳过
}
field := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
switch field {
case "event":
currentMsg.Event = value
case "data":
// data 可能多行,用换行符拼接(最后一条消息可能无结尾空行)
currentMsg.Data += value + ""
default:
// 忽略未知字段
}
}
// 检查扫描错误(如连接断开)
if err := scanner.Err(); err != nil {
return fmt.Errorf("读取流失败: %w", err)
}
// 处理最后一条未结束的消息(无结尾空行)
if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" {
_, err := this.msgContentParse(currentMsg.Data, channel)
if err != nil {
return fmt.Errorf("msgContentParse failed: %w", err)
}
}
return nil
}
// 获取知识库 session
func GetKnowledgeBaseSession(host, baseId, apiKey string) (string, error) {
req := l_request.Request{
Method: "post",
Url: host + "/api/v1/sessions",
Params: nil,
Headers: map[string]string{
"Content-Type": "application/json",
"X-API-Key": apiKey,
},
Cookies: nil,
Data: nil,
Json: map[string]interface{}{
"knowledge_base_id": baseId,
},
Files: nil,
Raw: "",
JsonByte: nil,
Xml: nil,
}
rsp, err := req.Send()
if err != nil {
return "", err
}
var result sessionRsp
err = json.Unmarshal(rsp.Content, &result)
return result.Data.Id, err
}
type sessionRsp struct {
Data struct {
Id string `json:"id"`
Title string `json:"title"`
Description string `json:"description"`
TenantId int `json:"tenant_id"`
KnowledgeBaseId string `json:"knowledge_base_id"`
MaxRounds int `json:"max_rounds"`
EnableRewrite bool `json:"enable_rewrite"`
FallbackStrategy string `json:"fallback_strategy"`
FallbackResponse string `json:"fallback_response"`
EmbeddingTopK int `json:"embedding_top_k"`
KeywordThreshold float64 `json:"keyword_threshold"`
VectorThreshold float64 `json:"vector_threshold"`
RerankModelId string `json:"rerank_model_id"`
RerankTopK int `json:"rerank_top_k"`
RerankThreshold float64 `json:"rerank_threshold"`
SummaryModelId string `json:"summary_model_id"`
} `json:"data"`
Success bool `json:"success"`
}

View File

@ -1,34 +0,0 @@
package tools
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/entitys"
"testing"
)
func TestKnowledgeBaseTool_Execute(t *testing.T) {
kb := NewKnowledgeBaseTool(config.ToolConfig{})
channel := make(chan entitys.ResponseData)
err := kb.Execute(channel, nil, nil)
if err != nil {
t.Errorf("Execute() error = %v", err)
}
}
// session
func TestKnowledgeBaseTool_Submit(t *testing.T) {
apiKey := "sk-EfnUANKMj3DUOiEPJZ5xS8SGMsbO6be_qYAg9uZ8T3zyoFM-"
baseId := "kb-00000001"
host := "http://117.175.169.61:10000"
sessionId, err := GetKnowledgeBaseSession(host, baseId, apiKey)
if err != nil {
t.Errorf("GetKnowledgeBaseSession() error = %v", err)
}
t.Log("sessionId:", sessionId)
}

View File

@ -5,19 +5,21 @@ import (
"ai_scheduler/internal/data/constants"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/utils_ollama"
"context"
"encoding/json"
"fmt"
"github.com/gofiber/websocket/v2"
)
// Manager 工具管理器
type Manager struct {
tools map[string]entitys.Tool
llm *utils_ollama.Client
llm *utils_ollama.UtilOllama
}
// NewManager 创建工具管理器
func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager {
func NewManager(config *config.Config, llm *utils_ollama.UtilOllama) *Manager {
m := &Manager{
tools: make(map[string]entitys.Tool),
llm: llm,
@ -48,31 +50,10 @@ func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager {
}
// 注册直连天下订单日志工具
if config.Tools.ZltxOrderDirectLog.Enabled {
zltxOrderLogTool := NewZltxOrderLogTool(config.Tools.ZltxOrderDirectLog)
m.tools[zltxOrderLogTool.Name()] = zltxOrderLogTool
}
//注册直连天下商品工具
if config.Tools.ZltxProduct.Enabled {
zltxProductTool := NewZltxProductTool(config.Tools.ZltxProduct)
m.tools[zltxProductTool.Name()] = zltxProductTool
}
//注册直连天下订单统计工具
if config.Tools.ZltxOrderStatistics.Enabled {
zltxOrderStatisticsTool := NewZltxOrderStatisticsTool(config.Tools.ZltxOrderStatistics)
m.tools[zltxOrderStatisticsTool.Name()] = zltxOrderStatisticsTool
}
// 注册知识库工具
if config.Tools.Knowledge.Enabled {
knowledgeTool := NewKnowledgeBaseTool(config.Tools.Knowledge)
m.tools[knowledgeTool.Name()] = knowledgeTool
}
// 普通对话
chat := NewNormalChatTool(m.llm, config)
m.tools[chat.Name()] = chat
// if config.ZltxOrderLog.Enabled {
// zltxOrderLogTool := NewZltxOrderLogTool(config.ZltxOrderLog)
// m.tools[zltxOrderLogTool.Name()] = zltxOrderLogTool
// }
return m
}
@ -103,13 +84,13 @@ func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefi
}
// ExecuteTool 执行工具
func (m *Manager) ExecuteTool(ctx context.Context, name string, requireData *entitys.RequireData) error {
func (m *Manager) ExecuteTool(channel chan []byte, c *websocket.Conn, name string, args json.RawMessage) error {
tool, exists := m.GetTool(name)
if !exists {
return fmt.Errorf("tool not found: %s", name)
}
return tool.Execute(ctx, requireData)
return tool.Execute(channel, c, args)
}
// ExecuteToolCalls 执行多个工具调用

View File

@ -1,86 +0,0 @@
package tools
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/utils_ollama"
"context"
"encoding/json"
"fmt"
"github.com/ollama/ollama/api"
)
// NormalChatTool 普通对话
type NormalChatTool struct {
llm *utils_ollama.Client
config *config.Config
}
// NewNormalChatTool 实例普通对话
func NewNormalChatTool(llm *utils_ollama.Client, config *config.Config) *NormalChatTool {
return &NormalChatTool{llm: llm, config: config}
}
// Name 返回工具名称
func (w *NormalChatTool) Name() string {
return "normalChat"
}
// Description 返回工具描述
func (w *NormalChatTool) Description() string {
return "用户想进行一般性问答"
}
type NormalChat struct {
ChatContent string `json:"chat_content"`
}
// Definition 返回工具定义
func (w *NormalChatTool) Definition() entitys.ToolDefinition {
return entitys.ToolDefinition{}
}
// Execute 执行直连天下订单详情查询
func (w *NormalChatTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
var req NormalChat
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
return fmt.Errorf("invalid zltxOrderDetail request: %w", err)
}
if req.ChatContent == "" {
req.ChatContent = "介绍一下你能做什么"
}
// 这里可以集成真实的直连天下订单详情API
return w.chat(requireData, &req)
}
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据
func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat) (err error) {
//requireData.Ch <- entitys.Response{
// Index: w.Name(),
// Content: "<think></think>",
// Type: entitys.ResponseStream,
//}
err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{
{
Role: "system",
Content: "你是一个聊天助手",
},
{
Role: "assistant",
Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)),
},
{
Role: "user",
Content: chat.ChatContent,
},
}, w.Name(), w.config.Ollama.GenerateModel)
if err != nil {
return fmt.Errorf("%s", err)
}
return
}

View File

@ -3,24 +3,22 @@ package tools
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/utils_ollama"
"context"
"encoding/json"
"fmt"
"gitea.cdlsxd.cn/self-tools/l_request"
"github.com/ollama/ollama/api"
"github.com/gofiber/websocket/v2"
)
// ZltxOrderDetailTool 直连天下订单详情工具
type ZltxOrderDetailTool struct {
config config.ToolConfig
llm *utils_ollama.Client
llm *utils_ollama.UtilOllama
}
// NewZltxOrderDetailTool 创建直连天下订单详情工具
func NewZltxOrderDetailTool(config config.ToolConfig, llm *utils_ollama.Client) *ZltxOrderDetailTool {
func NewZltxOrderDetailTool(config config.ToolConfig, llm *utils_ollama.UtilOllama) *ZltxOrderDetailTool {
return &ZltxOrderDetailTool{config: config, llm: llm}
}
@ -65,7 +63,6 @@ type ZltxOrderDetailResponse struct {
Code int `json:"code"`
Error string `json:"error"`
Data ZltxOrderDetailData `json:"data"`
Mes string `json:"mes"`
}
type ZltxOrderLogResponse struct {
@ -81,62 +78,56 @@ type ZltxOrderDetailData struct {
}
// Execute 执行直连天下订单详情查询
func (w *ZltxOrderDetailTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
func (w *ZltxOrderDetailTool) Execute(channel chan []byte, c *websocket.Conn, args json.RawMessage) error {
var req ZltxOrderDetailRequest
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
if err := json.Unmarshal(args, &req); err != nil {
return fmt.Errorf("invalid zltxOrderDetail request: %w", err)
}
if req.OrderNumber == "" {
return fmt.Errorf("number is required")
}
// 这里可以集成真实的直连天下订单详情API
return w.getZltxOrderDetail(requireData, req.OrderNumber)
return w.getZltxOrderDetail(channel, c, req.OrderNumber)
}
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据
func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireData, number string) (err error) {
func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan []byte, c *websocket.Conn, number string) (err error) {
//查询订单详情
var auth string
if c != nil {
auth = c.Headers("X-Authorization", "")
}
if len(auth) == 0 {
auth = w.config.APIKey
}
req := l_request.Request{
Url: fmt.Sprintf(w.config.BaseURL, number),
Url: fmt.Sprintf("%szltx_api/admin/direct/ai/%s", w.config.BaseURL, number),
Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
"Authorization": fmt.Sprintf("Bearer %s", auth),
},
Method: "GET",
}
res, err := req.Send()
if err != nil {
return fmt.Errorf("订单查询失败:网络请求错误:%s", err.Error())
}
var codeMap map[string]interface{}
if err = json.Unmarshal(res.Content, &codeMap); err != nil {
return
}
if codeMap["code"].(float64) != 200 {
return fmt.Errorf("订单查询失败:状态码错误:%s", string(res.Content))
}
var resData ZltxOrderDetailResponse
if err = json.Unmarshal(res.Content, &resData); err != nil {
return
}
requireData.Ch <- entitys.Response{
Index: w.Name(),
Content: res.Text,
Type: entitys.ResponseJson,
if resData.Code != 200 {
return fmt.Errorf("订单查询失败:%s", resData.Error)
}
ch <- res.Content
if resData.Data.Direct != nil && resData.Data.Direct["needAi"].(bool) {
requireData.Ch <- entitys.Response{
Index: w.Name(),
Content: "正在分析订单日志",
Type: entitys.ResponseLoading,
}
ch <- []byte("orderErrorChecking")
req = l_request.Request{
Url: fmt.Sprintf(w.config.AddURL, resData.Data.Direct["orderOrderNumber"].(string), resData.Data.Direct["serialNumber"].(string)),
Url: fmt.Sprintf("%szltx_api/admin/direct/log/%s/%s", w.config.BaseURL, resData.Data.Direct["orderOrderNumber"].(string), resData.Data.Direct["serialNumber"].(string)),
Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
"Authorization": fmt.Sprintf("Bearer %s", auth),
},
Method: "GET",
}
@ -151,39 +142,8 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat
if orderLog.Code != 200 {
return fmt.Errorf("订单日志查询失败:%s", orderLog.Error)
}
dataJson, err := json.Marshal(orderLog.Data)
if err != nil {
return fmt.Errorf("订单日志解析失败:%s", err)
}
err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{
{
Role: "system",
Content: "你是一个订单日志助手。用户可能会提供订单日志,你需要分析订单日志,提取出订单失败的原因。",
},
{
Role: "assistant",
Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)),
},
{
Role: "assistant",
Content: fmt.Sprintf("需要分析的订单日志:%s", string(dataJson)),
},
{
Role: "user",
Content: requireData.UserInput,
},
}, w.Name(), "")
if err != nil {
return fmt.Errorf("订单日志解析失败:%s", err)
}
}
if resData.Data.Direct == nil {
requireData.Ch <- entitys.Response{
Index: w.Name(),
Content: "该订单无充值流水记录,不需要进行订单错误分析,建议检查订单收单模式以及扣款方式等原因😘",
Type: entitys.ResponseText,
}
}
return
}

View File

@ -1,115 +0,0 @@
package tools
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/entitys"
"context"
"encoding/json"
"fmt"
"gitea.cdlsxd.cn/self-tools/l_request"
)
type ZltxOrderLogTool struct {
config config.ToolConfig
}
func (t *ZltxOrderLogTool) Name() string {
return "zltxOrderDirectLog"
}
func (t *ZltxOrderLogTool) Description() string {
return "查询订单日志"
}
func (t *ZltxOrderLogTool) Definition() entitys.ToolDefinition {
return entitys.ToolDefinition{
Type: "function",
Function: entitys.FunctionDef{
Name: t.Name(),
Description: t.Description(),
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"order_number": map[string]interface{}{
"type": "string",
"description": "订单编号",
},
"serial_number": map[string]interface{}{
"type": "string",
"description": "流水号",
},
},
"required": []string{"order_number", "serial_number"},
},
},
}
}
// ZltxOrderDetailRequest 直连天下订单详情请求参数
type ZltxOrderLogRequest struct {
OrderNumber string `json:"order_number"`
SerialNumber string `json:"serial_number"`
}
// ZltxOrderDetailResponse 直连天下订单详情响应
type ZltxOrderDirectLogResponse struct {
Code int `json:"code"`
Error string `json:"error"`
Data []ZltxOrderDirectLogData `json:"data"`
}
// ZltxOrderLogData 直连天下订单详情数据
type ZltxOrderDirectLogData struct {
Datetime string `json:"datetime"`
ServerID string `json:"serverId"`
Mes string `json:"mes"`
Data map[string]interface{} `json:"data"`
}
func (t *ZltxOrderLogTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
var req ZltxOrderLogRequest
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
return fmt.Errorf("invalid zltxOrderLog request: %w", err)
}
if req.OrderNumber == "" || req.SerialNumber == "" {
return fmt.Errorf("orderNumber and serialNumber is required")
}
return t.getZltxOrderLog(req.OrderNumber, req.SerialNumber, requireData)
}
func (t *ZltxOrderLogTool) getZltxOrderLog(orderNumber, serialNumber string, requireData *entitys.RequireData) (err error) {
//查询订单详情
url := fmt.Sprintf("%s%s/%s", t.config.BaseURL, orderNumber, serialNumber)
req := l_request.Request{
Url: url,
Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
},
Method: "GET",
}
res, err := req.Send()
var resData ZltxOrderDirectLogResponse
if err != nil {
return
}
if resData.Code != 200 {
return fmt.Errorf("订单查询失败:日志查询失败:%s", resData.Error)
}
if err = json.Unmarshal(res.Content, &resData); err != nil {
return
}
requireData.Ch <- entitys.Response{
Index: t.Name(),
Content: res.Text,
Type: entitys.ResponseJson,
}
return
}
func NewZltxOrderLogTool(config config.ToolConfig) *ZltxOrderLogTool {
return &ZltxOrderLogTool{
config: config,
}
}

View File

@ -1,334 +0,0 @@
package tools
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/entitys"
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"gitea.cdlsxd.cn/self-tools/l_request"
)
type ZltxProductTool struct {
config config.ToolConfig
}
func (z ZltxProductTool) Name() string {
return "zltxProduct"
}
func (z ZltxProductTool) Description() string {
return "获取直连天下商品信息"
}
func (z ZltxProductTool) Definition() entitys.ToolDefinition {
return entitys.ToolDefinition{
Type: "function",
Function: entitys.FunctionDef{
Name: z.Name(),
Description: z.Description(),
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"id": map[string]interface{}{
"type": "string",
"description": "商品ID",
},
"name": map[string]interface{}{
"type": "string",
"description": "商品名称",
},
},
"required": []string{"id", "name"},
},
},
}
}
type ZltxProductRequest struct {
Id string `json:"id"`
Name string `json:"name"`
}
func (z ZltxProductTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
var req ZltxProductRequest
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
return fmt.Errorf("invalid zltxProduct request: %w", err)
}
return z.getZltxProduct(&req, requireData)
}
type ZltxProductResponse struct {
Code int `json:"code"`
Data struct {
DataCount int `json:"dataCount"`
List []ZltxProductData `json:"list"`
} `json:"data"`
Error string `json:"error"`
}
type ZltxProductDataById struct {
Code int `json:"code"`
Data ZltxProductData `json:"data"`
Error string `json:"error"`
}
type ZltxProductData struct {
ID int `json:"id"`
OursProductCategoryID int `json:"ours_product_category_id"`
OfficialProductID int `json:"official_product_id"`
Tag string `json:"tag"`
Name string `json:"name"`
Type int `json:"type"`
Discount string `json:"discount"`
Preview string `json:"preview"`
Describe string `json:"describe"`
Price string `json:"price"`
Status int `json:"status"`
CreateTime string `json:"create_time"`
UpdateTime string `json:"update_time"`
Extend string `json:"extend"`
Wight int `json:"wight"`
Property int `json:"property"`
AuthProductInfo []any `json:"auth_product_info"`
AuthProductIds string `json:"auth_product_ids"`
Category struct {
ID int `json:"id"`
Name string `json:"name"`
Status int `json:"status"`
CreateTime string `json:"create_time"`
Pid int `json:"pid"`
} `json:"category"`
OfficialProduct struct {
ID int `json:"id"`
OfficialID int `json:"official_id"`
Name string `json:"name"`
Describe string `json:"describe"`
Preview string `json:"preview"`
Price float64 `json:"price"`
Status int `json:"status"`
CreateTime string `json:"create_time"`
UpdateTime string `json:"update_time"`
Type int `json:"type"`
Daies int `json:"daies"`
PreviewURL string `json:"preview_url"`
Official struct {
ID int `json:"id"`
Name string `json:"name"`
Describe string `json:"describe"`
Num int `json:"num"`
Status int `json:"status"`
WebURL string `json:"web_url"`
RechargeURL string `json:"recharge_url"`
CreateTime string `json:"create_time"`
UpdateTime string `json:"update_time"`
Type int `json:"type"`
Tag string `json:"tag"`
} `json:"official"`
} `json:"official_product"`
Statistics interface{} `json:"statistics"`
PlatformProductList interface{} `json:"platform_product_list"`
}
func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *entitys.RequireData) error {
var Url string
var params map[string]string
if body.Id != "" {
Url = fmt.Sprintf("%s/%s", z.config.BaseURL, body.Id)
} else {
Url = fmt.Sprintf("%s?keyword=%s&limit=10&page=1", z.config.BaseURL, body.Name)
params = map[string]string{
"keyword": body.Name,
"limit": "10",
"page": "1",
}
}
req := l_request.Request{
//get /admin/oursProduct/{product_id} 通过商品id获取我们的商品信息
//get /admin/oursProduct?keyword={name}&limit=10&page=1 通过商品name获取我们的商品列表
//根据商品ID或名称走不同的接口查询
Url: Url,
Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
},
Params: params,
Method: "GET",
}
res, err := req.Send()
if err != nil {
return err
}
var resp ZltxProductResponse
if err = json.Unmarshal(res.Content, &resp); err != nil {
return fmt.Errorf("解析商品数据失败:%w", err)
}
if resp.Code != 200 {
return fmt.Errorf("商品查询失败:%s", string(res.Content))
}
if resp.Data.List == nil || len(resp.Data.List) == 0 {
var respData ZltxProductDataById
if err := json.Unmarshal(res.Content, &respData); err != nil {
return fmt.Errorf("解析商品数据失败:%w", err)
}
if respData.Data.ID == 0 {
return fmt.Errorf("商品查询失败:暂无数据")
}
resp.Data.List = []ZltxProductData{respData.Data}
resp.Data.DataCount = 1
}
//调用 平台商品列表
if resp.Data.List != nil && len(resp.Data.List) > 0 {
for i := range resp.Data.List {
// 调用 平台商品列表
if resp.Data.List[i].AuthProductIds != "" {
platformProductList := z.ExecutePlatformProductList(requireData.Auth, resp.Data.List[i].AuthProductIds, resp.Data.List[i].OfficialProductID)
resp.Data.List[i].PlatformProductList = platformProductList
}
}
}
marshal, err := json.Marshal(resp)
if err != nil {
return err
}
requireData.Ch <- entitys.Response{
Index: z.Name(),
Content: string(marshal),
Type: entitys.ResponseJson,
}
return nil
}
type PlatformProductResponse struct {
Code int `json:"code"`
Data []struct {
ID int `json:"id"`
PlatformID int `json:"platform_id"`
OfficialProductID int `json:"official_product_id"`
Weight int `json:"weight"`
Code string `json:"code"`
Name string `json:"name"`
Discount string `json:"discount"`
Price string `json:"price"`
Extra string `json:"extra"`
AccountType []string `json:"account_type"`
Status int `json:"status"`
CreateTime string `json:"create_time"`
UpdateTime string `json:"update_time"`
RebatePrice string `json:"rebate_price"`
ExchangeType int `json:"exchange_type"`
IsReportInfo int `json:"is_report_info"`
Describe string `json:"describe"`
PlatformAliasID int `json:"platform_alias_id"`
Platform struct {
ID int `json:"id"`
Tag string `json:"tag"`
Name string `json:"name"`
Status int `json:"status"`
Weight int `json:"weight"`
UpdateTime string `json:"update_time"`
SearchTime string `json:"search_time"`
Balance string `json:"balance"`
BalanceWarning string `json:"balance_warning"`
StockType int `json:"stock_type"`
IsIndependent int `json:"is_independent"`
IndependentDeadline string `json:"independent_deadline"`
CreateTime string `json:"create_time"`
} `json:"platform"`
} `json:"data"`
Error string `json:"error"`
}
type PlatformData struct {
ID int `json:"id"`
PlatformID int `json:"platform_id"`
OfficialProductID int `json:"official_product_id"`
Weight int `json:"weight"`
Code string `json:"code"`
Name string `json:"name"`
Discount string `json:"discount"`
Price string `json:"price"`
Extra string `json:"extra"`
AccountType []string `json:"account_type"`
Status int `json:"status"`
CreateTime string `json:"create_time"`
UpdateTime string `json:"update_time"`
RebatePrice string `json:"rebate_price"`
ExchangeType int `json:"exchange_type"`
IsReportInfo int `json:"is_report_info"`
Describe string `json:"describe"`
PlatformAliasID int `json:"platform_alias_id"`
Platform struct {
ID int `json:"id"`
Tag string `json:"tag"`
Name string `json:"name"`
Status int `json:"status"`
Weight int `json:"weight"`
UpdateTime string `json:"update_time"`
SearchTime string `json:"search_time"`
Balance string `json:"balance"`
BalanceWarning string `json:"balance_warning"`
StockType int `json:"stock_type"`
IsIndependent int `json:"is_independent"`
IndependentDeadline string `json:"independent_deadline"`
CreateTime string `json:"create_time"`
} `json:"platform"`
}
func (z ZltxProductTool) ExecutePlatformProductList(auth string, authProductIds string, officialProductID int) []PlatformData {
if authProductIds == "" {
return nil
}
//authProductIds 以逗号分割
authProductIdList := strings.Split(authProductIds, ",")
for _, authProductId := range authProductIdList {
var platformProductResponse PlatformProductResponse
req := l_request.Request{
//https://gateway.dev.cdlsxd.cn/zltx_api/admin/platformProduct/getProductsByOfficialProductId?official_product_id=282&is_all=true
Url: fmt.Sprintf("%s?official_product_id=%d&is_all=true", z.config.AddURL, officialProductID),
Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", auth),
},
Method: "GET",
}
res, err := req.Send()
if err != nil {
// 可以考虑记录日志而不是直接跳过
continue
}
if err := json.Unmarshal(res.Content, &platformProductResponse); err != nil {
// 可以考虑记录日志而不是直接跳过
continue
}
authProductIdInt, err := strconv.Atoi(authProductId)
if err != nil {
// 可以考虑记录日志而不是直接跳过
continue
}
for _, platformProduct := range platformProductResponse.Data {
if platformProduct.ID == authProductIdInt {
return []PlatformData{platformProduct}
}
}
}
return nil
}
type ZltxProductPlatformProductList struct {
Code int `json:"code"`
Data map[string]any `json:"data"`
Error string `json:"error"`
}
func NewZltxProductTool(config config.ToolConfig) *ZltxProductTool {
return &ZltxProductTool{
config: config,
}
}

View File

@ -1,123 +0,0 @@
package tools
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/entitys"
"context"
"encoding/json"
"fmt"
"sort"
"gitea.cdlsxd.cn/self-tools/l_request"
)
type ZltxOrderStatisticsTool struct {
config config.ToolConfig
}
func (z ZltxOrderStatisticsTool) Name() string {
return "zltxOrderStatistics"
}
func (z ZltxOrderStatisticsTool) Description() string {
return "通过账号获取订单统计信息"
}
func (z ZltxOrderStatisticsTool) Definition() entitys.ToolDefinition {
return entitys.ToolDefinition{
Type: "function",
Function: entitys.FunctionDef{
Name: z.Name(),
Description: z.Description(),
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"number": map[string]interface{}{
"type": "string",
"description": "账号或分销商号",
},
},
"required": []string{"number"},
},
},
}
}
type ZltxOrderStatisticsRequest struct {
Number string `json:"number"`
}
func (z ZltxOrderStatisticsTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
var req ZltxOrderStatisticsRequest
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
return err
}
if req.Number == "" {
return fmt.Errorf("number is required")
}
return z.getZltxOrderStatistics(req.Number, requireData)
}
type ZltxOrderStatisticsResponse struct {
Code int `json:"code"`
Data struct {
RecentThreeDays []ZltxOrderStatisticsData `json:"recentThreeDays"`
RecentOneMonth []ZltxOrderStatisticsData `json:"recentOneMonth"`
} `json:"data"`
Error string `json:"error"`
}
type ZltxOrderStatisticsData struct {
Date string `json:"date"`
Number string `json:"number"`
Success int `json:"success"`
Fail int `json:"fail"`
Total int `json:"total"`
}
func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(number string, requireData *entitys.RequireData) error {
//查询订单详情
url := fmt.Sprintf("%s%s", z.config.BaseURL, number)
req := l_request.Request{
Url: url,
Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth),
},
Method: "GET",
}
res, err := req.Send()
var resData ZltxOrderStatisticsResponse
if err != nil {
return err
}
if err := json.Unmarshal(res.Content, &resData); err != nil {
return err
}
if resData.Code != 200 {
return fmt.Errorf("为获取到数据,请检查权限: %s", string(res.Content))
}
//按照日期排序
sort.Slice(resData.Data.RecentThreeDays, func(i, j int) bool {
return resData.Data.RecentThreeDays[i].Date < resData.Data.RecentThreeDays[j].Date
})
sort.Slice(resData.Data.RecentOneMonth, func(i, j int) bool {
return resData.Data.RecentOneMonth[i].Date < resData.Data.RecentOneMonth[j].Date
})
jsonByte, err := json.Marshal(resData)
if err != nil {
return err
}
requireData.Ch <- entitys.Response{
Index: z.Name(),
Content: string(jsonByte),
Type: entitys.ResponseJson,
}
return nil
}
func NewZltxOrderStatisticsTool(config config.ToolConfig) *ZltxOrderStatisticsTool {
return &ZltxOrderStatisticsTool{
config: config,
}
}