diff --git a/Dockerfile b/Dockerfile index 4b5b1ec..8b02af6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,3 +1,28 @@ +## 使用官方Go镜像作为构建环境 +FROM golang:1.24.7-alpine AS builder + +# 设置工作目录 +WORKDIR /app + +# 使用国内镜像源加速APK包下载 +RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.aliyun.com/g' /etc/apk/repositories + +# 使用国内镜像源加速依赖下载 +ENV GOPROXY=https://goproxy.cn,direct + +# 复制项目源码 +COPY . . + +# 复制go模块依赖文件 +COPY go.mod go.sum ./ +RUN go mod tidy + +RUN go install github.com/google/wire/cmd/wire@latest +RUN wire ./cmd/server + +# 编译Go应用程序,生成静态链接的二进制文件 +RUN go build -ldflags="-s -w" -o server ./cmd/server + # 创建最终镜像,用于运行编译后的Go程序 FROM alpine @@ -12,10 +37,13 @@ RUN echo 'http://mirrors.ustc.edu.cn/alpine/v3.5/main' > /etc/apk/repositories \ WORKDIR /app # 将编译好的二进制文件从构建阶段复制到运行阶段 -COPY ./ /app +COPY --from=builder /app/server ./server +# 复制配置文件夹 +COPY --from=builder /app/config ./config + ENV TZ=Asia/Shanghai # 设置容器启动时运行的命令 -CMD ["./bin/server"] +CMD ["./server"] diff --git a/README-test.md b/README-test.md new file mode 100644 index 0000000..161ea9a --- /dev/null +++ b/README-test.md @@ -0,0 +1,4 @@ +[https://p6-img.searchpstatp.com/tos-cn-i-vvloioitz3/6e5e76d274df2efabde9194a06f97e89~tplv-vvloioitz3-6:190:124.jpeg] + + +![图片](https://p6-img.searchpstatp.com/tos-cn-i-vvloioitz3/ab5ae998d8162b431f44fb2a0ed9ae33~tplv-vvloioitz3-6:190:124.jpeg) \ No newline at end of file diff --git a/cmd/server/main.go b/cmd/server/main.go index f0b0214..d765735 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -2,7 +2,7 @@ package main import ( "ai_scheduler/internal/config" - + "context" "flag" "fmt" @@ -11,6 +11,7 @@ import ( func main() { configPath := flag.String("config", "./config/config_test.yaml", "Path to configuration file") + onBot := flag.String("bot", "", "bot start") flag.Parse() bc, err := config.LoadConfig(*configPath) if err != nil { @@ -23,8 +24,8 @@ func main() { } defer func() { cleanup() - }() + app.DingBotServer.Run(context.Background(), *onBot) log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port))) } diff --git a/cmd/server/wire.go b/cmd/server/wire.go index f134ef9..8357aae 100644 --- a/cmd/server/wire.go +++ b/cmd/server/wire.go @@ -5,13 +5,19 @@ package main import ( "ai_scheduler/internal/biz" + "ai_scheduler/internal/biz/handle/dingtalk" + "ai_scheduler/internal/biz/tools_regis" "ai_scheduler/internal/config" "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/domain/component" + "ai_scheduler/internal/domain/repo" + "ai_scheduler/internal/domain/workflow" "ai_scheduler/internal/pkg" "ai_scheduler/internal/server" "ai_scheduler/internal/services" + + // "ai_scheduler/internal/tool_callback" "ai_scheduler/internal/tools" - "ai_scheduler/internal/tools_bot" "ai_scheduler/utils" "github.com/gofiber/fiber/v2/log" @@ -22,13 +28,18 @@ import ( func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), error) { panic(wire.Build( server.ProviderSetServer, + workflow.ProviderSetWorkflow, tools.ProviderSetTools, pkg.ProviderSetClient, services.ProviderSetServices, biz.ProviderSetBiz, impl.ProviderImpl, utils.ProviderUtils, - tools_bot.ProviderSetBotTools, + dingtalk.ProviderSetDingTalk, + tools_regis.ProviderToolsRegis, + // tool_callback.ProviderSetCallBackTools, + component.ProviderSet, + repo.ProviderSet, )) } diff --git a/config/config.yaml b/config/config.yaml index c1e8c08..a1497f4 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -23,7 +23,7 @@ redis: host: 47.97.27.195:6379 type: node pass: lansexiongdi@666 - key: report-api-test + key: report-api pollSize: 5 #连接池大小,不配置,或配置为0表示不启用连接池 minIdleConns: 2 #最小空闲连接数 maxIdleTime: 30 #每个连接最大空闲时间,如果超过了这个时间会被关闭 @@ -65,6 +65,13 @@ tools: enabled: true base_url: "https://revcl.1688sup.com/api/admin/afterSales/reseller_pre_ai" +dingtalk: + api_key: "dingsbbntrkeiyazcfdg" + api_secret: "ObqxwyR20r9rVNhju0sCPQyQA98_FZSc32W4vgxnGFH_b02HZr1BPCJsOAF816nu" + table_demand: + url: "https://alidocs.dingtalk.com/i/nodes/YQBnd5ExVE6qAbnOiANQg2KKJyeZqMmz" + base_id: "2Amq4vjg89RnYx9DTp66m2orW3kdP0wQ" + sheet_id_or_name: "数据表" default_prompt: img_recognize: diff --git a/config/config_env.yaml b/config/config_env.yaml new file mode 100644 index 0000000..e180a4b --- /dev/null +++ b/config/config_env.yaml @@ -0,0 +1,128 @@ +# 服务器配置 +server: + port: 8090 + host: "0.0.0.0" + +ollama: + base_url: "http://192.168.6.109:11434" + model: "qwen3-coder:480b-cloud" + generate_model: "qwen3-coder:480b-cloud" + mapping_model: "deepseek-v3.2:cloud" + vl_model: "qwen2.5vl:7b" + timeout: "120s" + level: "info" + format: "json" + +vllm: + base_url: "http://117.175.169.61:16001/v1" + vl_model: "Qwen2.5-VL-3B-Instruct-AWQ" + timeout: "120s" + level: "info" + +coze: + base_url: "https://api.coze.cn" + api_secret: "pat_guUSPk8KZFvIIbVReuaMlOBVAaIISSdkTEV8MaRgVPNv6UEYPHKTBUXznFcxl04H" + + +sys: + session_len: 6 + channel_pool_len: 100 + channel_pool_size: 32 + llm_pool_len: 5 + heartbeat_interval: 30 +redis: + host: 47.97.27.195:6379 + type: node + pass: lansexiongdi@666 + key: report-api-test + pollSize: 5 #连接池大小,不配置,或配置为0表示不启用连接池 + minIdleConns: 2 #最小空闲连接数 + maxIdleTime: 30 #每个连接最大空闲时间,如果超过了这个时间会被关闭 + tls: 30 + db: +db: + driver: mysql + source: root:SD###sdf323r343@tcp(121.199.38.107:3306)/sys_ai_test?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai + +tools: + zltxOrderDetail: + enabled: true + base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/direct/ai/%s" + add_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/direct/log/%s/%s" + api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU4MDkxOTU4LCJuYmYiOjE3NTgwOTAxNTgsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.Bjsx9f8yfcrV9EWxb0n6POwnXVOq9XPRD78JFZnnf1_VAVMN78W4W570SZL27PWuDnkD7E4oUg6RzeZwZgl7BZrNpNr-a-QpNC5qCptqrqXeNfVStmX7pxWA8GqnzI8ybkZgbhQ58Gje7DzdJtBq_8zte_LDaYhTYXdIc5EAG0AbCzAk22nPTl47nkMeHtmisXQVLEsdibl1hW3ViFJlXwfXvUrOENItmL1_mRYkggUB0MaTu2nHJOYM6PaOVGLHx-74eepnmK2rm6konFEb6ed-Ukc6gVR-nM9yWZaYLYNGNKJLwZoCX3tRuerq74n4kzQgWmUEJeaVI1yIGSw1zw" + zltxProduct: + enabled: true + base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/oursProduct" + add_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/platformProduct/getProductsByOfficialProductId" + api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU2MTgyNTM1LCJuYmYiOjE3NTYxODA3MzUsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.N1xv1PYbcO8_jR5adaczc16YzGsr4z101gwEZdulkRaREBJNYTOnFrvRxTFx3RJTooXsqTqroE1MR84v_1WPX6BS6kKonA-kC1Jgot6yrt5rFWhGNGb2Cpr9rKIFCCQYmiGd3AUgDazEeaQ0_sodv3E-EXg9VfE1SX8nMcck9Yjnc8NCy7RTWaBIaSeOdZcEl-JfCD0S6GSx3oErp_hk-U9FKGwf60wAuDGTY1R0BP4BYpcEqS-C2LSnsSGyURi54Cuk5xH8r1WuF0Dm5bwAj5d7Hvs77-N_sUF-C5ONqyZJRAEhYLgcmN9RX_WQZfizdQJxizlTczdpzYfy-v-1eQ" + zltxOrderStatistics: + base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/direct/ai/search/" + enabled: true + api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU2MTgyNTM1LCJuYmYiOjE3NTYxODA3MzUsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.N1xv1PYbcO8_jR5adaczc16YzGsr4z101gwEZdulkRaREBJNYTOnFrvRxTFx3RJTooXsqTqroE1MR84v_1WPX6BS6kKonA-kC1Jgot6yrt5rFWhGNGb2Cpr9rKIFCCQYmiGd3AUgDazEeaQ0_sodv3E-EXg9VfE1SX8nMcck9Yjnc8NCy7RTWaBIaSeOdZcEl-JfCD0S6GSx3oErp_hk-U9FKGwf60wAuDGTY1R0BP4BYpcEqS-C2LSnsSGyURi54Cuk5xH8r1WuF0Dm5bwAj5d7Hvs77-N_sUF-C5ONqyZJRAEhYLgcmN9RX_WQZfizdQJxizlTczdpzYfy-v-1eQ" + knowledge: + base_url: "http://117.175.169.61:10000" + enabled: true + DingTalkBot: + enabled: true + api_key: "dingsbbntrkeiyazcfdg" + api_secret: "ObqxwyR20r9rVNhju0sCPQyQA98_FZSc32W4vgxnGFH_b02HZr1BPCJsOAF816nu" + zltxOrderAfterSaleSupplier: + enabled: true + base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/afterSales/directs" + zltxOrderAfterSaleReseller: + enabled: true + base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/afterSales/reseller_pre_ai" + zltxOrderAfterSaleResellerBatch: + enabled: true + base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/afterSales/reseller_pre_ai" + +# eino tool 配置 +eino_tools: + # 货易通商品上传 + hytProductUpload: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/supplier/batch/add/complete" + add_url: "https://gateway.dev.cdlsxd.cn/sw//#/goods/goodsManage" + # 货易通供应商查询 + hytSupplierSearch: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/supplier/list" + # 货易通仓库查询 + hytWarehouseSearch: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/warehouse/list" + # 货易通商品添加 + hytGoodsAdd: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/add" + add_url: "https://gateway.dev.cdlsxd.cn/sw//#/goods/goodsManage" + # 货易通商品图片添加 + hytGoodsMediaAdd: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/media/add/batch" + # 货易通商品分类添加 + hytGoodsCategoryAdd: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/good/category/relation/add" + # 货易通商品分类查询 + hytGoodsCategorySearch: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/category/list" + # 货易通商品品牌查询 + hytGoodsBrandSearch: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/brand/list" + +dingtalk: + api_key: "dingsbbntrkeiyazcfdg" + api_secret: "ObqxwyR20r9rVNhju0sCPQyQA98_FZSc32W4vgxnGFH_b02HZr1BPCJsOAF816nu" + table_demand: + url: "https://alidocs.dingtalk.com/i/nodes/YQBnd5ExVE6qAbnOiANQg2KKJyeZqMmz" + base_id: "YQBnd5ExVE6qAbnOiANQg2KKJyeZqMmz" + sheet_id_or_name: "数据表" + + +default_prompt: + img_recognize: + system_prompt: + '你是一个具备图像理解与用户意图分析能力的智能助手。当用户提供一张图片时,请完成以下任务: + 1. 关键信息提取: + 提取出图片中对用户可能有用的关键信息(例如金额、日期、标题、编号、联系信息、商品名称等)。 + 若图片为文档类(如合同、发票、收据),请结构化输出关键字段(如客户名称、金额、开票日期等)。 + ' + user_prompt: '识别图片内容' +# 权限配置 +permissionConfig: + permission_url: "http://api.test.user.1688sup.cn:8001/v1/menu/myCodes?systemId=" diff --git a/config/config_test.yaml b/config/config_test.yaml index 8275102..2180e18 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -4,14 +4,24 @@ server: host: "0.0.0.0" ollama: - base_url: "http://127.0.0.1:11434" + base_url: "http://host.docker.internal:11434" model: "qwen3-coder:480b-cloud" generate_model: "qwen3-coder:480b-cloud" - vl_model: "qwen2.5vl:7b" + mapping_model: "deepseek-v3.2:cloud" + vl_model: "gemini-3-pro-preview" timeout: "120s" level: "info" format: "json" +vllm: + base_url: "http://host.docker.internal:8001/v1" + vl_model: "Qwen2.5-VL-3B-Instruct-AWQ" + timeout: "120s" + level: "info" + +coze: + base_url: "https://api.coze.cn" + api_secret: "sat_AqvFcdNgesP8megy1ItTscWFXRcsHRzmM4NJ1KNavfcdT0EPwYuCPkDqGhItpx13" sys: @@ -19,6 +29,7 @@ sys: channel_pool_len: 100 channel_pool_size: 32 llm_pool_len: 5 + heartbeat_interval: 30 redis: host: 47.97.27.195:6379 type: node @@ -64,8 +75,57 @@ tools: zltxOrderAfterSaleResellerBatch: enabled: true base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/afterSales/reseller_pre_ai" + weather: + enabled: true + base_url: "https://restapi.amap.com/v3/weather/weatherInfo" + api_key: "12afbde5ab78cb7e575ff76bd0bdef2b" + cozeExpress: + enabled: true + base_url: "https://api.coze.cn" + api_key: "7582477438102552616" + api_secret: "pat_eEN0BdLNDughEtABjJJRYTW71olvDU0qUbfQUeaPc2NnYWO8HeyNoui5aR9z0sSZ" + cozeCompany: + enabled: true + base_url: "https://api.coze.cn" + api_key: "7583905168607100978" + api_secret: "pat_eEN0BdLNDughEtABjJJRYTW71olvDU0qUbfQUeaPc2NnYWO8HeyNoui5aR9z0sSZ" +# eino tool 配置 +eino_tools: + # 货易通商品上传 + hytProductUpload: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/supplier/batch/add/complete" + add_url: "https://gateway.dev.cdlsxd.cn/sw//#/goods/goodsManage" + # 货易通供应商查询 + hytSupplierSearch: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/supplier/list" + # 货易通仓库查询 + hytWarehouseSearch: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/warehouse/list" + # 货易通商品添加 + hytGoodsAdd: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/add" + add_url: "https://gateway.dev.cdlsxd.cn/sw//#/goods/goodsManage" + # 货易通商品图片添加 + hytGoodsMediaAdd: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/media/add/batch" + # 货易通商品分类添加 + hytGoodsCategoryAdd: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/good/category/relation/add" + # 货易通商品分类查询 + hytGoodsCategorySearch: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/category/list" + # 货易通商品品牌查询 + hytGoodsBrandSearch: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/brand/list" +dingtalk: + api_key: "dingsbbntrkeiyazcfdg" + api_secret: "ObqxwyR20r9rVNhju0sCPQyQA98_FZSc32W4vgxnGFH_b02HZr1BPCJsOAF816nu" + table_demand: + url: "https://alidocs.dingtalk.com/i/nodes/YQBnd5ExVE6qAbnOiANQg2KKJyeZqMmz" + base_id: "YQBnd5ExVE6qAbnOiANQg2KKJyeZqMmz" + sheet_id_or_name: "数据表" default_prompt: img_recognize: @@ -79,3 +139,56 @@ default_prompt: # 权限配置 permissionConfig: permission_url: "http://api.test.user.1688sup.cn:8001/v1/menu/myCodes?systemId=" + +# llm 服务配置 +llm: + providers: + ollama: + endpoint: http://host.docker.internal:11434 + timeout: 60s + models: + - id: qwen3-coder:480b-cloud + name: qwen3-coder:480b-cloud + streaming: true + modalities: [text] + max_tokens: 4096 + + vllm: + endpoint: http://117.175.169.61:16001 + timeout: 60s + models: + - id: models/Qwen2.5-VL-3B-Instruct-AWQ + name: qwen2.5-vl-3b + streaming: true + modalities: [text, image] + max_tokens: 4096 + + # 每个能力只绑定一个 provider+model,不做自动回退 + capabilities: + intent: + provider: vllm + model: qwen2.5-vl-3b + parameters: + temperature: 0.2 + max_tokens: 4096 + stream: false + + vision: + provider: vllm + model: qwen2.5-vl-3b + parameters: + temperature: 0.5 + max_tokens: 4096 + stream: true + + chat: + provider: ollama + model: qwen3-coder:480b-cloud + parameters: + temperature: 0.7 + max_tokens: 4096 + stream: true +#ding_talk_bots: +# public: +# client_id: "dingchg59zwwvmuuvldx", +# client_secret: "ZwetAnRiTQobNFVlNrshRagSMAJIFpBAepWkWI7on7Tt_o617KHtTjBLp8fQfplz", diff --git a/deploy.sh b/deploy.sh index 53d6253..5c0061a 100644 --- a/deploy.sh +++ b/deploy.sh @@ -1,9 +1,9 @@ -export GO111MODULE=on -export GOPROXY=https://goproxy.cn,direct -export GOPATH=/root/go -export GOCACHE=/root/.cache/go-build +#export GO111MODULE=on +#export GOPROXY=https://goproxy.cn,direct +#export GOPATH=/root/go +#export GOCACHE=/root/.cache/go-build export CONTAINER_NAME=ai_scheduler -export CGO_ENABLED='0' +#export CGO_ENABLED='0' MODE="$1" @@ -14,16 +14,18 @@ fi CONFIG_FILE="config/config.yaml" BRANCH="master" +BOT="ALL" if [ "$MODE" = "dev" ]; then CONFIG_FILE="config/config_test.yaml" + BOT="zltx" BRANCH="test" fi git fetch origin git checkout "$BRANCH" git pull origin "$BRANCH" -go mod tidy -make build +#go mod tidy +#make build docker build -t ${CONTAINER_NAME} . docker stop ${CONTAINER_NAME} docker rm -f ${CONTAINER_NAME} @@ -34,6 +36,6 @@ docker run -itd \ -e "OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-http://host.docker.internal:11434}" \ -e "MODE=${MODE}" \ -p 8090:8090 \ - "${CONTAINER_NAME}" ./bin/server --config "./${CONFIG_FILE}" + "${CONTAINER_NAME}" ./server --config "./${CONFIG_FILE}" --bot "./${BOT}" docker logs -f ${CONTAINER_NAME} \ No newline at end of file diff --git a/go.mod b/go.mod index 6b08e4a..76bed20 100644 --- a/go.mod +++ b/go.mod @@ -1,19 +1,26 @@ module ai_scheduler -go 1.24.0 - -toolchain go1.24.7 +go 1.24.7 require ( + gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go v0.9.3 gitea.cdlsxd.cn/self-tools/l_request v1.0.8 github.com/alibabacloud-go/darabonba-openapi/v2 v2.0.12 github.com/alibabacloud-go/dingtalk v1.6.96 github.com/alibabacloud-go/tea v1.2.2 github.com/alibabacloud-go/tea-utils/v2 v2.0.6 + github.com/cloudwego/eino v0.7.7 + github.com/cloudwego/eino-ext/components/model/ollama v0.1.6 + github.com/cloudwego/eino-ext/components/model/openai v0.1.5 + github.com/coze-dev/coze-go v0.0.0-20251029161603-312b7fd62d20 github.com/emirpasic/gods v1.18.1 github.com/faabiosr/cachego v0.26.0 github.com/fastwego/dingding v1.0.0-beta.4 + github.com/gabriel-vasile/mimetype v1.4.11 github.com/go-kratos/kratos/v2 v2.9.1 + github.com/go-playground/locales v0.14.1 + github.com/go-playground/universal-translator v0.18.1 + github.com/go-playground/validator/v10 v10.20.0 github.com/gofiber/fiber/v2 v2.52.9 github.com/gofiber/websocket/v2 v2.2.1 github.com/google/uuid v1.6.0 @@ -22,9 +29,8 @@ require ( github.com/redis/go-redis/v9 v9.16.0 github.com/spf13/viper v1.17.0 github.com/tmc/langchaingo v0.1.13 + golang.org/x/sync v0.15.0 google.golang.org/grpc v1.64.0 - google.golang.org/protobuf v1.34.1 - gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.6.0 gorm.io/gorm v1.31.0 xorm.io/builder v0.3.13 @@ -39,31 +45,53 @@ require ( github.com/alibabacloud-go/tea-xml v1.1.3 // indirect github.com/aliyun/credentials-go v1.4.6 // indirect github.com/andybalholm/brotli v1.1.0 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic v1.14.1 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/clbanning/mxj/v2 v2.5.5 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/cloudwego/eino-ext/libs/acl/openai v0.1.2 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/eino-contrib/jsonschema v1.0.3 // indirect + github.com/eino-contrib/ollama v0.1.0 // indirect + github.com/evanphx/json-patch v0.5.2 // indirect github.com/fasthttp/websocket v1.5.3 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/golang-jwt/jwt/v5 v5.2.2 // indirect + github.com/goph/emperror v0.17.2 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.17.9 // indirect + github.com/klauspost/cpuid/v2 v2.2.9 // indirect + github.com/leodido/go-urn v1.4.0 // indirect github.com/magiconair/properties v1.8.7 // indirect + github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/meguminnnnnnnnn/go-openai v0.1.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/nikolalohinski/gonja v1.5.3 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/pkoukk/tiktoken-go v0.1.6 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/sagikazarmark/locafero v0.3.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.10.0 // indirect github.com/spf13/cast v1.5.1 // indirect @@ -71,16 +99,22 @@ require ( github.com/stretchr/testify v1.11.1 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/tjfoc/gmsm v1.4.1 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yargevad/filepathx v1.0.0 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect - golang.org/x/crypto v0.36.0 // indirect + golang.org/x/arch v0.11.0 // indirect + golang.org/x/crypto v0.39.0 // indirect golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect golang.org/x/net v0.38.0 // indirect - golang.org/x/sys v0.31.0 // indirect - golang.org/x/text v0.23.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.26.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect + google.golang.org/protobuf v1.34.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 29b4ace..a6e9a9c 100644 --- a/go.sum +++ b/go.sum @@ -38,12 +38,15 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go v0.9.3 h1:qaSPxVz5kHCs2AWvShnOG8mUgrUP9Gc3uUB4ZX1BF5A= +gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go v0.9.3/go.mod h1:5mCPTjBxOk69LRJPHWJRNTkfxcffqlQSOBMD4M5JVnE= gitea.cdlsxd.cn/self-tools/l_request v1.0.8 h1:FaKRql9mCVcSoaGqPeBOAruZ52slzRngQ6VRTYKNSsA= gitea.cdlsxd.cn/self-tools/l_request v1.0.8/go.mod h1:Qf4hVXm2Eu5vOvwXk8D7U0q/aekMCkZ4Fg9wnRKlasQ= gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:lSA0F4e9A2NcQSqGqTOXqu2aRi/XEQxDCBwM8yJtE6s= gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:EXuID2Zs0pAQhH8yz+DNjUbjppKQzKFAn28TMYPB6IU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZU11AGBXMILnrdOU8Kn00o= github.com/alibabacloud-go/alibabacloud-gateway-pop v0.0.6 h1:eIf+iGJxdU4U9ypaUfbtOWCsZSbTb8AUHvyPrxu6mAA= github.com/alibabacloud-go/alibabacloud-gateway-pop v0.0.6/go.mod h1:4EUIoxs/do24zMOGGqYVWgw0s9NtiylnJglOeEB5UJo= github.com/alibabacloud-go/alibabacloud-gateway-spi v0.0.4/go.mod h1:sCavSAvdzOjul4cEqeVtvlSaSScfNsTQ+46HwlTL1hc= @@ -96,12 +99,29 @@ github.com/aliyun/credentials-go v1.4.6 h1:CG8rc/nxCNKfXbZWpWDzI9GjF4Tuu3Es14qT8 github.com/aliyun/credentials-go v1.4.6/go.mod h1:Jm6d+xIgwJVLVWT561vy67ZRP4lPTQxMbEYRuT2Ti1U= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/bradfitz/gomemcache v0.0.0-20170208213004-1952afaa557d/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= +github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/mockey v1.2.14 h1:KZaFgPdiUwW+jOWFieo3Lr7INM1P+6adO3hxZhDswY8= +github.com/bytedance/mockey v1.2.14/go.mod h1:1BPHF9sol5R1ud/+0VEHGQq/+i2lN+GTsr3O2Q9IENY= +github.com/bytedance/sonic v1.14.1 h1:FBMC0zVz5XUmE4z9wF4Jey0An5FueFvOsTKKKtwIl7w= +github.com/bytedance/sonic v1.14.1/go.mod h1:gi6uhQLMbTdeP0muCnrjHLeCUPyb70ujhnNlhOylAFc= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= @@ -110,9 +130,21 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn github.com/clbanning/mxj/v2 v2.5.5 h1:oT81vUeEiQQ/DcHbzSytRngP6Ky9O+L+0Bw0zSJag9E= github.com/clbanning/mxj/v2 v2.5.5/go.mod h1:hNiWqW14h+kc+MdF9C6/YoRfjEJoR3ou6tn/Qo+ve2s= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/cloudwego/eino v0.7.7 h1:WhP0SMWWPgLdOH03HrKUxtP9/Q96NhziMZNEQl9lxpU= +github.com/cloudwego/eino v0.7.7/go.mod h1:nA8Vacmuqv3pqKBQbTWENBLQ8MmGmPt/WqiyLeB8ohQ= +github.com/cloudwego/eino-ext/components/model/ollama v0.1.6 h1:ZbrhV91uE0hGIOYXhb2i3G6tQJ/rK2SLYtoYrmocZXM= +github.com/cloudwego/eino-ext/components/model/ollama v0.1.6/go.mod h1:GDXrvorGdRNV6g2mK5jdla2D8Xc/hh7XDrTeGDteLLo= +github.com/cloudwego/eino-ext/components/model/openai v0.1.5 h1:+yvGbTPw93li9GSmdm6Rix88Yy8AXg5NNBcRbWx3CQU= +github.com/cloudwego/eino-ext/components/model/openai v0.1.5/go.mod h1:IPVYMFoZcuHeVEsDTGN6SZjvue0xr1iZFhdpq1SBWdQ= +github.com/cloudwego/eino-ext/libs/acl/openai v0.1.2 h1:r9Id2wzJ05PoHl+Km7jQgNMgciaZI93TVnUYso89esM= +github.com/cloudwego/eino-ext/libs/acl/openai v0.1.2/go.mod h1:S4OkvglPY9hsm9tXeShODrf/WN1Cgu4bqu4nn/CnIic= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/coze-dev/coze-go v0.0.0-20251029161603-312b7fd62d20 h1:m6P88V9lLrxZsE7uj9otq7l7nqDuCSAJ86KhzRlWf0M= +github.com/coze-dev/coze-go v0.0.0-20251029161603-312b7fd62d20/go.mod h1:wdT5CFt/sFsWz9hna2Z7DWzUra9spx0SoX1PUZyoSB0= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -121,6 +153,12 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/eino-contrib/jsonschema v1.0.3 h1:2Kfsm1xlMV0ssY2nuxshS4AwbLFuqmPmzIjLVJ1Fsp0= +github.com/eino-contrib/jsonschema v1.0.3/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= +github.com/eino-contrib/ollama v0.1.0 h1:z1NaMdKW6X1ftP8g5xGGR5zDRPUtuTKFq35vBQgxsN4= +github.com/eino-contrib/ollama v0.1.0/go.mod h1:mYsQ7b3DeqY8bHPuD3MZJYTqkgyL6LoemxoP/B7ZNhA= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -129,6 +167,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po= github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/evanphx/json-patch v0.5.2 h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8P3k= +github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ= github.com/faabiosr/cachego v0.15.0/go.mod h1:L2EomlU3/rUWjzFavY9Fwm8B4zZmX2X6u8kTMkETrwI= github.com/faabiosr/cachego v0.26.0 h1:EDDv2y9T0XJ4Cx3tUhbKSUayGWxCGkkZUivNLceHRWY= github.com/faabiosr/cachego v0.26.0/go.mod h1:p54WXVzeB1CctH1ix/rjqv1EotNzD0Xoxk2IsR1PQX8= @@ -142,18 +182,34 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= +github.com/gabriel-vasile/mimetype v1.4.11 h1:AQvxbp830wPhHTqc1u7nzoLT+ZFxGY7emj5DR5DYFik= +github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= +github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-kratos/kratos/v2 v2.9.1 h1:EGif6/S/aK/RCR5clIbyhioTNyoSrii3FC118jG40Z0= github.com/go-kratos/kratos/v2 v2.9.1/go.mod h1:a1MQLjMhIh7R0kcJS9SzJYR43BRI7EPzzN0J1Ksu2bA= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= +github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw= github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw= github.com/gofiber/websocket/v2 v2.2.1 h1:C9cjxvloojayOp9AovmpQrk8VqvVnT8Oao3+IUygH7w= github.com/gofiber/websocket/v2 v2.2.1/go.mod h1:Ao/+nyNnX5u/hIFPuHl28a+NIkrqK7PRimyKaj4JxVU= +github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -215,8 +271,14 @@ github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3 github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= +github.com/goph/emperror v0.17.2 h1:yLapQcmEsO0ipe9p5TaN22djm3OFV/TfM/fcYP0/J18= +github.com/goph/emperror v0.17.2/go.mod h1:+ZbQ+fUNO/6FNiUo0ujtMjhgad9Xa6fQL9KhH4LNHic= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= +github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= @@ -224,19 +286,26 @@ github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= +github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -245,8 +314,12 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= @@ -255,6 +328,10 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.6.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/meguminnnnnnnnn/go-openai v0.1.0 h1:BGzB1PlS2Epq0mBB2TGLwzMihbR7BANrlMH3w4ZnY88= +github.com/meguminnnnnnnnn/go-openai v0.1.0/go.mod h1:qs96ysDmxhE4BZoU45I43zcyfnaYxU3X+aRzLko/htY= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -265,16 +342,22 @@ github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3Rllmb github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c= +github.com/nikolalohinski/gonja v1.5.3/go.mod h1:RmjwxNiXAEqcq1HeK5SSMmqFJvKOfTfXhkJv6YBtPa4= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/ollama/ollama v0.12.7 h1:dxokli1UyO/a0Aun5sE4+0Gg+A9oMUAPiFQhxrXOIXA= github.com/ollama/ollama v0.12.7/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0= +github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw= @@ -290,6 +373,7 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtISeXco0L5PKQ= github.com/sagikazarmark/locafero v0.3.0 h1:zT7VEGWC2DTflmccN/5T1etyKvxSxpHsjb9cJvm4SvQ= github.com/sagikazarmark/locafero v0.3.0/go.mod h1:w+v7UsPNFwzF1cHuOajOOzoq4U7v/ig1mpRjqV+Bu1U= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= @@ -297,9 +381,18 @@ github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWR github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk= github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g= github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg= +github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= +github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= +github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.10.0 h1:EaGW2JJh15aKOejeuJ+wpFSHnbd7GE6Wvp3TsNhb6LY= @@ -311,16 +404,19 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/spf13/viper v1.17.0 h1:I5txKw7MJasPL/BrfkbA0Jyo/oELqVmux4pR/UxOMfI= github.com/spf13/viper v1.17.0/go.mod h1:BmMMMLQXSbcHK6KAOiFLz0l5JHrU89OdIRHvsk0+yVI= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= @@ -332,12 +428,20 @@ github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho= github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE= github.com/tmc/langchaingo v0.1.13 h1:rcpMWBIi2y3B90XxfE4Ao8dhCQPVDMaNPnN5cGB1CaA= github.com/tmc/langchaingo v0.1.13/go.mod h1:vpQ5NOIhpzxDfTZK9B6tf2GM/MoaHewPWM5KXXGh7hg= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g= github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg= +github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE= +github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc= +github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.30/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -353,8 +457,13 @@ go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +golang.org/x/arch v0.11.0 h1:KXV8WWKCXm6tRpLirl2szsO5j/oOODwZf4hATmGVNs4= +golang.org/x/arch v0.11.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -370,8 +479,8 @@ golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= -golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= +golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -475,9 +584,10 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= -golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= +golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -519,6 +629,7 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -529,8 +640,8 @@ golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -539,8 +650,8 @@ golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= -golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= -golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= +golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= +golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -553,8 +664,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= +golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -721,6 +832,7 @@ gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/mgo.v2 v2.0.0-20160818020120-3f83fa500528/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= gopkg.in/redis.v4 v4.2.4/go.mod h1:8KREHdypkCEojGKQcjMqAODMICIVwZAONWq8RowTITA= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/biz/chat_history.go b/internal/biz/chat_history.go index 48dcd0d..c3912a4 100644 --- a/internal/biz/chat_history.go +++ b/internal/biz/chat_history.go @@ -1,41 +1,115 @@ package biz import ( + errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/util" "context" - + "encoding/json" "xorm.io/builder" ) type ChatHistoryBiz struct { - chatRepo *impl.ChatImpl + chatHiRepo *impl.ChatHisImpl + taskRepo *impl.TaskImpl } -func NewChatHistoryBiz(chatRepo *impl.ChatImpl) *ChatHistoryBiz { +func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl, taskRepo *impl.TaskImpl) *ChatHistoryBiz { s := &ChatHistoryBiz{ - chatRepo: chatRepo, + chatHiRepo: chatHiRepo, + taskRepo: taskRepo, } - //go s.AsyncProcess(context.Background()) return s } -//func (s *ChatHistoryBiz) create(ctx context.Context, sessionID, role, content string) error { -// chat := model.AiChatHi{ -// SessionID: sessionID, -// Role: role, -// Content: content, -// } -// -// return s.chatRepo.Create(&chat) -//} -// +// 查询会话历史 +func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]entitys.ChatHisQueryResponse, error) { + + con := []impl.CondFunc{ + s.chatHiRepo.WithSessionId(query.SessionID), + s.chatHiRepo.PaginateScope(query.Page, query.PageSize), + s.chatHiRepo.OrderByDesc("his_id"), + } + if query.HisID > 0 { + con = append(con, s.chatHiRepo.WithHisId(query.HisID)) + } + + chats, err := s.chatHiRepo.FindAll( + con..., + ) + if err != nil { + return nil, err + } + + taskIds := make([]int32, 0, len(chats)) + for _, chat := range chats { + // 去重任务ID + if !util.Contains(taskIds, chat.TaskID) { + taskIds = append(taskIds, chat.TaskID) + } + } + + // 查询任务名称 + tasks, err := s.taskRepo.FindAll(s.taskRepo.In("task_id", taskIds)) + if err != nil { + return nil, err + } + taskMap := make(map[int32]model.AiTask) + for _, task := range tasks { + taskMap[task.TaskID] = task + } + + // 构建结果 + result := make([]entitys.ChatHisQueryResponse, 0, len(chats)) + for _, chat := range chats { + item := entitys.ChatHisQueryResponse{} + item.FromModel(chat, taskMap[chat.TaskID]) + result = append(result, item) + } + + return result, nil +} + //// 添加会话历史 //func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error { -// return s.create(ctx, chat.SessionID, chat.Role.String(), chat.Content) +// return s.chatHiRepo.Create(&model.AiChatHi{ +// SessionID: chat.SessionID, +// Ques: chat.Role.String(), +// Ans: chat.Content, +// }) //} +// 更新会话历史内容, 追加内容, 不覆盖原有内容, content 使用json格式存储 +func (c *ChatHistoryBiz) UpdateContent(ctx context.Context, chat *entitys.UpdateContentRequest) error { + var contents []string + chatHi, has, err := c.chatHiRepo.FindOne(c.chatHiRepo.WithHisId(chat.HisID)) + if err != nil { + return err + } else if !has { + return errors.NewBusinessErr(errors.InvalidParamCode, "chat history not found") + } + + if "" != chatHi.Content { + // 解析历史内容 + err = json.Unmarshal([]byte(chatHi.Content), &contents) + if err != nil { + return err + } + } + contents = append(contents, chat.Content) + + b, err := json.Marshal(contents) + if err != nil { + return err + } + chatHi.Content = string(b) + return c.chatHiRepo.Update(&chatHi, + c.chatHiRepo.Select("content"), + c.chatHiRepo.WithHisId(chatHi.HisID)) +} + // 异步添加会话历史 //func (s *ChatHistoryBiz) AsyncCreate(ctx context.Context, chat entitys.ChatHistory) { // s.chatRepo.AsyncCreate(ctx, model.AiChatHi{ @@ -53,5 +127,5 @@ func NewChatHistoryBiz(chatRepo *impl.ChatImpl) *ChatHistoryBiz { func (s *ChatHistoryBiz) Update(ctx context.Context, chat *entitys.UseFulRequest) error { cond := builder.NewCond() cond = cond.And(builder.Eq{"his_id": chat.HisId}) - return s.chatRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful}) + return s.chatHiRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful}) } diff --git a/internal/biz/ding_talk_bot.go b/internal/biz/ding_talk_bot.go new file mode 100644 index 0000000..b8976df --- /dev/null +++ b/internal/biz/ding_talk_bot.go @@ -0,0 +1,564 @@ +package biz + +import ( + "ai_scheduler/internal/biz/do" + "ai_scheduler/internal/biz/handle/dingtalk" + "ai_scheduler/internal/biz/tools_regis" + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/tools" + "ai_scheduler/tmpl/dataTemp" + "io" + "net/http" + "strconv" + "time" + + "ai_scheduler/internal/config" + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "strings" + + "gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/chatbot" + "github.com/coze-dev/coze-go" + "github.com/gofiber/fiber/v2/log" + "xorm.io/builder" +) + +// AiRouterBiz 智能路由服务 +type DingTalkBotBiz struct { + do *do.Do + handle *do.Handle + botConfigImpl *impl.BotConfigImpl + replier *chatbot.ChatbotReplier + log log.Logger + dingTalkUser *dingtalk.User + botTools []model.AiBotTool + botGroupImpl *impl.BotGroupImpl + toolManager *tools.Manager + chatHis *impl.BotChatHisImpl + conf *config.Config + cardSend *dingtalk.SendCardClient +} + +// NewDingTalkBotBiz +func NewDingTalkBotBiz( + do *do.Do, + handle *do.Handle, + botConfigImpl *impl.BotConfigImpl, + botGroupImpl *impl.BotGroupImpl, + dingTalkUser *dingtalk.User, + tools *tools_regis.ToolRegis, + chatHis *impl.BotChatHisImpl, + toolManager *tools.Manager, + conf *config.Config, + cardSend *dingtalk.SendCardClient, +) *DingTalkBotBiz { + return &DingTalkBotBiz{ + do: do, + handle: handle, + botConfigImpl: botConfigImpl, + replier: chatbot.NewChatbotReplier(), + dingTalkUser: dingTalkUser, + botTools: tools.BootTools, + botGroupImpl: botGroupImpl, + toolManager: toolManager, + chatHis: chatHis, + conf: conf, + cardSend: cardSend, + } +} + +func (d *DingTalkBotBiz) GetDingTalkBotCfgList() (dingBotList []entitys.DingTalkBot, err error) { + botConfig := make([]model.AiBotConfig, 0) + + cond := builder.NewCond() + cond = cond.And(builder.Eq{"status": constants.Enable}) + cond = cond.And(builder.Eq{"bot_type": constants.BotTypeDingTalk}) + err = d.botConfigImpl.GetRangeToMapStruct(&cond, &botConfig) + for _, v := range botConfig { + var config entitys.DingTalkBot + err = json.Unmarshal([]byte(v.BotConfig), &config) + if err != nil { + d.log.Info("初始化“%s”失败:%s", v.BotName, err.Error()) + } + config.BotIndex = v.RobotCode + dingBotList = append(dingBotList, config) + } + return + +} + +func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallbackDataModel) (requireData *entitys.RequireDataDingTalkBot, err error) { + requireData = &entitys.RequireDataDingTalkBot{ + Req: data, + Ch: make(chan entitys.Response, 2), + } + + return +} + +func (d *DingTalkBotBiz) Do(ctx context.Context, requireData *entitys.RequireDataDingTalkBot) (err error) { + //entitys.ResLoading(requireData.Ch, "", "收到消息,正在处理中,请稍等") + //defer close(requireData.Ch) + switch constants.ConversationType(requireData.Req.ConversationType) { + case constants.ConversationTypeSingle: + err = d.handleSingleChat(ctx, requireData) + case constants.ConversationTypeGroup: + err = d.handleGroupChat(ctx, requireData) + default: + err = errors.New("未知的聊天类型:" + requireData.Req.ConversationType) + } + if err != nil { + entitys.ResText(requireData.Ch, "", err.Error()) + } + return +} + +func (d *DingTalkBotBiz) handleSingleChat(ctx context.Context, requireData *entitys.RequireDataDingTalkBot) (err error) { + entitys.ResLog(requireData.Ch, "", "个人聊天暂未开启,请期待后续更新") + return + //requireData.UserInfo, err = d.dingTalkUser.GetUserInfoFromBot(ctx, requireData.Req.SenderStaffId, dingtalk.WithId(1)) + //if err != nil { + // return + //} + //requireData.ID=requireData.UserInfo.UserID + ////如果不是管理或者不是老板,则进行权限判断 + //if requireData.UserInfo.IsSenior == constants.IsSeniorFalse && requireData.UserInfo.IsBoss == constants.IsBossFalse { + // + //} + //return +} + +func (d *DingTalkBotBiz) handleGroupChat(ctx context.Context, requireData *entitys.RequireDataDingTalkBot) (err error) { + group, err := d.initGroup(ctx, requireData.Req.ConversationId, requireData.Req.ConversationTitle, requireData.Req.RobotCode) + + if err != nil { + return + } + requireData.ID = group.GroupID + groupTools, err := d.getGroupTools(ctx, group) + if err != nil { + return + } + rec, err := d.recognize(ctx, requireData, groupTools) + if err != nil { + return + } + + return d.handleMatch(ctx, rec) +} + +func (d *DingTalkBotBiz) initGroup(ctx context.Context, conversationId string, conversationTitle string, robotCode string) (group *model.AiBotGroup, err error) { + group, err = d.botGroupImpl.GetByConversationIdAndRobotCode(conversationId, robotCode) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + + return + } + } + + if group.GroupID == 0 { + group = &model.AiBotGroup{ + ConversationID: conversationId, + Title: conversationTitle, + RobotCode: robotCode, + ToolList: "", + } + //如果不存在则创建 + _, err = d.botGroupImpl.Add(group) + } + return +} + +func (d *DingTalkBotBiz) getGroupTools(ctx context.Context, group *model.AiBotGroup) (tools []model.AiBotTool, err error) { + if len(d.botTools) == 0 { + return + } + var ( + groupRegisTools = make(map[int]struct{}) + ) + if group.ToolList != "" { + groupToolList := strings.Split(group.ToolList, ",") + for _, tool := range groupToolList { + if tool == "" { + continue + } + num, _err := strconv.Atoi(tool) + if _err != nil { + continue + } + groupRegisTools[num] = struct{}{} + } + } + + for _, v := range d.botTools { + if v.PermissionType == constants.PermissionTypeNone { + tools = append(tools, v) + continue + } + if _, ex := groupRegisTools[int(v.ToolID)]; ex { + tools = append(tools, v) + } + } + return +} +func (d *DingTalkBotBiz) recognize(ctx context.Context, requireData *entitys.RequireDataDingTalkBot, tools []model.AiBotTool) (rec *entitys.Recognize, err error) { + + userContent, err := d.getUserContent(requireData.Req.Msgtype, requireData.Req.Text.Content) + if err != nil { + return + } + rec = &entitys.Recognize{ + Ch: requireData.Ch, + SystemPrompt: d.defaultPrompt(), + UserContent: userContent, + } + //历史记录 + rec.ChatHis, err = d.getHis(ctx, constants.ConversationType(requireData.Req.ConversationType), requireData.ID) + if err != nil { + return + } + //工具注册 + if len(tools) > 0 { + rec.Tasks = make([]entitys.RegistrationTask, 0, len(tools)) + for _, task := range tools { + taskConfig := entitys.TaskConfigDetail{} + if err = json.Unmarshal([]byte(task.Config), &taskConfig); err != nil { + log.Errorf("解析任务配置失败: %s, 任务ID: %s", err.Error(), task.Index) + continue // 解析失败时跳过该任务,而不是直接返回错误 + } + + rec.Tasks = append(rec.Tasks, entitys.RegistrationTask{ + Name: task.Index, + Desc: task.Desc, + TaskConfigDetail: taskConfig, // 直接使用解析后的配置,避免重复构建 + }) + } + } + err = d.handle.Recognize(ctx, rec, &do.WithDingTalkBot{}) + return +} + +func (d *DingTalkBotBiz) getHis(ctx context.Context, conversationType constants.ConversationType, Id int32) (content entitys.ChatHis, err error) { + + var ( + his []model.AiBotChatHi + ) + cond := builder.NewCond() + cond = cond.And(builder.Eq{"his_type": conversationType}) + cond = cond.And(builder.Eq{"id": Id}) + _, err = d.chatHis.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: d.conf.Sys.SessionLen}, &his, "his_id desc") + if err != nil { + return + } + messages := make([]entitys.HisMessage, 0, len(his)) + for _, v := range his { + messages = append(messages, entitys.HisMessage{ + Role: constants.Caller(v.Role), // 用户角色 + Content: v.Content, // 用户输入内容 + Timestamp: v.CreateAt.Format(time.DateTime), + }) + } + return entitys.ChatHis{ + SessionId: fmt.Sprintf("%s_%d", conversationType, Id), + Messages: messages, + Context: entitys.HisContext{ + UserLanguage: constants.LangZhCN, // 默认中文 + SystemMode: constants.SystemModeTechnicalSupport, // 默认技术支持模式 + }, + }, nil +} + +func (d *DingTalkBotBiz) getUserContent(msgType string, msgContent interface{}) (content *entitys.RecognizeUserContent, err error) { + switch constants.BotMsgType(msgType) { + case constants.BotMsgTypeText: + content = &entitys.RecognizeUserContent{ + Text: msgContent.(string), + } + default: + return nil, errors.New("未知的消息类型:" + msgType) + } + return +} + +func (d *DingTalkBotBiz) handleMatch(ctx context.Context, rec *entitys.Recognize) (err error) { + + if !rec.Match.IsMatch { + if len(rec.Match.Chat) != 0 { + entitys.ResText(rec.Ch, "", rec.Match.Chat) + } else { + entitys.ResText(rec.Ch, "", rec.Match.Reasoning) + } + return + } + var pointTask *model.AiBotTool + for _, task := range d.botTools { + if task.Index == rec.Match.Index { + pointTask = &task + + break + } + } + + if pointTask == nil || pointTask.Index == "other" { + return d.otherTask(ctx, rec) + } + switch constants.TaskType(pointTask.Type) { + case constants.TaskTypeFunc: + return d.handleTask(ctx, rec, pointTask) + case constants.TaskTypeCozeWorkflow: + return d.handleCozeWorkflow(ctx, rec, pointTask) + default: + return d.otherTask(ctx, rec) + } + return +} + +func (d *DingTalkBotBiz) handleCozeWorkflow(ctx context.Context, rec *entitys.Recognize, task *model.AiBotTool) (err error) { + entitys.ResLoading(rec.Ch, task.Index, "正在执行工作流(coze)\n") + + customClient := &http.Client{ + Timeout: time.Minute * 30, + } + + authCli := coze.NewTokenAuth(d.conf.Coze.ApiSecret) + cozeCli := coze.NewCozeAPI( + authCli, + coze.WithBaseURL(d.conf.Coze.BaseURL), + coze.WithHttpClient(customClient), + ) + + // 从参数中获取workflowID + type requestParams struct { + Request l_request.Request `json:"request"` + } + var config requestParams + err = json.Unmarshal([]byte(task.Config), &config) + if err != nil { + return err + } + workflowId, ok := config.Request.Json["workflow_id"].(string) + if !ok { + return fmt.Errorf("workflow_id不能为空") + } + // 提取参数 + var data map[string]interface{} + err = json.Unmarshal([]byte(rec.Match.Parameters), &data) + + req := &coze.RunWorkflowsReq{ + WorkflowID: workflowId, + Parameters: data, + // IsAsync: true, + } + + stream := config.Request.Json["stream"].(bool) + + entitys.ResLog(rec.Ch, task.Index, "工作流执行中...") + + if stream { + streamResp, err := cozeCli.Workflows.Runs.Stream(ctx, req) + if err != nil { + return err + } + + handleCozeWorkflowEvents(ctx, streamResp, cozeCli, workflowId, rec.Ch, task.Index) + } else { + resp, err := cozeCli.Workflows.Runs.Create(ctx, req) + if err != nil { + return err + } + + entitys.ResJson(rec.Ch, task.Index, resp.Data) + } + + return +} + +// handleCozeWorkflowEvents 处理 coze 工作流事件 +func handleCozeWorkflowEvents(ctx context.Context, resp coze.Stream[coze.WorkflowEvent], cozeCli coze.CozeAPI, workflowID string, ch chan entitys.Response, index string) { + defer resp.Close() + for { + event, err := resp.Recv() + if errors.Is(err, io.EOF) { + fmt.Println("Stream finished") + break + } + if err != nil { + fmt.Println("Error receiving event:", err) + break + } + + switch event.Event { + case coze.WorkflowEventTypeMessage: + entitys.ResStream(ch, index, event.Message.Content) + case coze.WorkflowEventTypeError: + entitys.ResError(ch, index, fmt.Sprintf("工作流执行错误: %s", event.Error)) + case coze.WorkflowEventTypeDone: + entitys.ResEnd(ch, index, "工作流执行完成") + case coze.WorkflowEventTypeInterrupt: + resumeReq := &coze.ResumeRunWorkflowsReq{ + WorkflowID: workflowID, + EventID: event.Interrupt.InterruptData.EventID, + ResumeData: "your data", + InterruptType: event.Interrupt.InterruptData.Type, + } + newResp, err := cozeCli.Workflows.Runs.Resume(ctx, resumeReq) + if err != nil { + entitys.ResError(ch, index, fmt.Sprintf("工作流恢复执行错误: %s", err.Error())) + return + } + entitys.ResLog(ch, index, "工作流恢复执行中...") + handleCozeWorkflowEvents(ctx, newResp, cozeCli, workflowID, ch, index) + } + } + fmt.Printf("done, log:%s\n", resp.Response().LogID()) +} + +func (d *DingTalkBotBiz) handleTask(ctx context.Context, rec *entitys.Recognize, task *model.AiBotTool) (err error) { + var configData entitys.ConfigDataTool + err = json.Unmarshal([]byte(task.Config), &configData) + if err != nil { + return + } + + err = d.toolManager.ExecuteTool(ctx, configData.Tool, rec) + if err != nil { + return + } + + return +} + +func (d *DingTalkBotBiz) otherTask(ctx context.Context, rec *entitys.Recognize) (err error) { + entitys.ResText(rec.Ch, "", rec.Match.Reasoning) + return +} + +//func (d *DingTalkBotBiz) HandleRes(ctx context.Context, data *chatbot.BotCallbackDataModel, resp entitys.Response, ch chan string) error { +// switch resp.Type { +// case entitys.ResponseText: +// return d.replyText(ctx, data.SessionWebhook, resp.Content) +// case entitys.ResponseStream: +// +// return d.replySteam(ctx, data, ch) +// case entitys.ResponseImg: +// return d.replyImg(ctx, data.SessionWebhook, resp.Content) +// case entitys.ResponseFile: +// return d.replyFile(ctx, data.SessionWebhook, resp.Content) +// case entitys.ResponseMarkdown: +// return d.replyMarkdown(ctx, data.SessionWebhook, resp.Content) +// case entitys.ResponseActionCard: +// return d.replyActionCard(ctx, data.SessionWebhook, resp.Content) +// default: +// return nil +// } +//} + +func (d *DingTalkBotBiz) HandleStreamRes(ctx context.Context, data *chatbot.BotCallbackDataModel, content chan string) (err error) { + err = d.cardSend.NewCard(ctx, &dingtalk.CardSend{ + RobotCode: data.RobotCode, + ConversationType: constants.ConversationType(data.ConversationType), + Template: constants.CardTempDefault, + ContentChannel: content, // 指定内容通道 + ConversationId: data.ConversationId, + SenderStaffId: data.SenderStaffId, + Title: data.Text.Content, + }) + + return +} + +func (d *DingTalkBotBiz) ReplyText(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingTalkBotBiz) replyImg(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingTalkBotBiz) replyFile(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingTalkBotBiz) replyMarkdown(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingTalkBotBiz) replyActionCard(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingTalkBotBiz) SaveHis(ctx context.Context, requireData *entitys.RequireDataDingTalkBot, chat []string) (err error) { + if len(chat) == 0 { + return + } + his := []*model.AiBotChatHi{ + { + HisType: requireData.Req.ConversationType, + ID: requireData.ID, + Role: "user", + Content: requireData.Req.Text.Content, + }, + { + HisType: requireData.Req.ConversationType, + ID: requireData.ID, + Role: "system", + Content: strings.Join(chat, "\n"), + }, + } + _, err = d.chatHis.Add(his) + return err +} + +func (d *DingTalkBotBiz) defaultPrompt() string { + + return `[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**。请严格遵循以下规则: +[rule] +1. **返回格式**: +仅输出以下 **严格格式化的 JSON 字符串**(禁用 Markdown): +{ "index": "工具索引index", "confidence": 0.0-1.0,"reasoning": "判断理由","parameters":"jsonstring |提取参数","is_match":true||false,"chat": "追问内容"} +关键规则(按优先级排序): + +2. **工具匹配**: + +- 若匹配到工具,使用工具的 parameters 作为模板做参数匹配 +- 注意区分 parameters 中的 必须参数(required) 和 可选参数(optional),按下述参数提取规则处理。 +- 若**完全无法匹配**,立即设置 is_match: false,并在 chat 中已第一人称的角度提醒用户需要适用何种工具(例:"请问您是要查询订单还是商品呢")。 + +1. **参数提取**: + +- 根据 parameters 字段列出的参数名,从用户输入中提取对应值。 +- **仅提取**明确提及的参数,忽略未列出的内容。 +- 必须参数仅使用用户直接提及的参数,不允许从上下文推断。 +- 若必须参数缺失,立即设置 is_match: false,并在 chat 中已第一人称的角度提醒用户提供缺少的参数追问(例:"需要您补充XX信息")。 + +4. 格式强制要求: +-所有字段值必须是**字符串**(包括 confidence)。 +-parameters 必须是 **转义后的 JSON 字符串**(如 "{\"product_name\": \"京东月卡\"}")。` +} diff --git a/internal/biz/do/ctx.go b/internal/biz/do/ctx.go index ed8749d..5d9ba72 100644 --- a/internal/biz/do/ctx.go +++ b/internal/biz/do/ctx.go @@ -19,8 +19,6 @@ import ( "gitea.cdlsxd.cn/self-tools/l_request" "github.com/gofiber/fiber/v2/log" - "github.com/gofiber/websocket/v2" - "xorm.io/builder" ) @@ -29,14 +27,14 @@ type Do struct { sessionImpl *impl.SessionImpl sysImpl *impl.SysImpl taskImpl *impl.TaskImpl - hisImpl *impl.ChatImpl + hisImpl *impl.ChatHisImpl conf *config.Config } func NewDo( sysImpl *impl.SysImpl, taskImpl *impl.TaskImpl, - hisImpl *impl.ChatImpl, + hisImpl *impl.ChatHisImpl, conf *config.Config, ) *Do { return &Do{ @@ -81,6 +79,20 @@ func (d *Do) DataAuth(ctx context.Context, client *gateway.Client, requireData * return nil } +func (d *Do) DataAuthForBot(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) (err error) { + + // 2. 加载系统信息 + if err = d.loadSystemInfo(ctx, client, requireData); err != nil { + return fmt.Errorf("获取系统信息失败: %w", err) + } + + // 3. 加载任务列表 + if err = d.loadTaskList(ctx, client, requireData); err != nil { + return fmt.Errorf("获取任务列表失败: %w", err) + } + return nil +} + // 提取数据验证为单独函数 func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.RequireData) error { requireData.Session = client.GetSession() @@ -104,7 +116,7 @@ func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.Req // 获取系统信息的辅助函数 func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error { if sysInfo := client.GetSysInfo(); sysInfo == nil { - sys, err := d.getSysInfo(requireData) + sys, err := d.GetSysInfo(requireData) if err != nil { return err } @@ -119,7 +131,8 @@ func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, require // 获取任务列表的辅助函数 func (d *Do) loadTaskList(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error { if taskInfo := client.GetTasks(); len(taskInfo) == 0 { - tasks, err := d.getTasks(requireData.Sys.SysID) + // 从数据库获取任务列表, 0 表示获取公共的任务 + tasks, err := d.GetTasks(requireData.Sys.SysID, 0) if err != nil { return err } @@ -128,6 +141,7 @@ func (d *Do) loadTaskList(ctx context.Context, client *gateway.Client, requireDa } else { requireData.Tasks = taskInfo } + return nil } @@ -141,10 +155,10 @@ func (d *Do) loadChatHistory(ctx context.Context, requireData *entitys.RequireDa return nil } -func (d *Do) MakeCh(c *websocket.Conn, requireData *entitys.RequireData) (ctx context.Context, deferFunc func()) { +func (d *Do) MakeCh(client *gateway.Client, requireData *entitys.RequireData) (ctx context.Context, deferFunc func()) { requireData.Ch = make(chan entitys.Response) ctx, cancel := context.WithCancel(context.Background()) - done := d.startMessageHandler(ctx, c, requireData) + done := d.startMessageHandler(ctx, client, requireData) return ctx, func() { close(requireData.Ch) //关闭主通道 <-done // 等待消息处理完成 @@ -202,7 +216,7 @@ func (d *Do) getRequireData() (err error) { return } -func (d *Do) getSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, err error) { +func (d *Do) GetSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"app_key": requireData.Key}) cond = cond.And(builder.IsNull{"delete_at"}) @@ -221,12 +235,12 @@ func (d *Do) getSessionChatHis(requireData *entitys.RequireData) (his []model.Ai return } -func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) { +func (d *Do) GetTasks(sysId ...int32) (tasks []model.AiTask, err error) { cond := builder.NewCond() - cond = cond.And(builder.Eq{"sys_id": sysId}) + //cond = cond.And(builder.Eq{"sys_id": sysId}) cond = cond.And(builder.IsNull{"delete_at"}) - cond = cond.And(builder.Eq{"status": 1}) + cond = cond.And(builder.Eq{"status": 1}.And(builder.In("sys_id", sysId))) _, err = d.taskImpl.GetListToStruct(&cond, nil, &tasks, "") return @@ -235,7 +249,7 @@ func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) { // startMessageHandler 启动独立的消息处理协程 func (d *Do) startMessageHandler( ctx context.Context, - c *websocket.Conn, + client *gateway.Client, requireData *entitys.RequireData, ) <-chan struct{} { done := make(chan struct{}) @@ -254,12 +268,13 @@ func (d *Do) startMessageHandler( Ques: requireData.Req.Text, Ans: strings.Join(chat, ""), Files: requireData.Req.Img, + TaskID: requireData.Task.TaskID, } d.hisImpl.AddWithData(AiRes) hisLog.HisId = AiRes.HisID } - _ = entitys.MsgSend(c, entitys.Response{ + _ = entitys.MsgSend(client, entitys.Response{ Content: pkg.JsonStringIgonErr(hisLog), Type: entitys.ResponseEnd, }) @@ -267,7 +282,7 @@ func (d *Do) startMessageHandler( }() for v := range requireData.Ch { // 自动检测通道关闭 - if err := sendWithTimeout(c, v, 2*time.Second); err != nil { + if err := sendWithTimeout(client, v, 10*time.Second); err != nil { log.Errorf("Send error: %v", err) return } @@ -281,7 +296,7 @@ func (d *Do) startMessageHandler( } // 辅助函数:带超时的 WebSocket 发送 -func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error { +func sendWithTimeout(client *gateway.Client, data entitys.Response, timeout time.Duration) error { sendCtx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() @@ -294,7 +309,7 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura close(done) }() // 如果 MsgSend 阻塞,这里会卡住 - err := entitys.MsgSend(c, data) + err := entitys.MsgSend(client, data) done <- err }() @@ -334,7 +349,7 @@ func (d *Do) LoadUserPermission(client *gateway.Client, requireData *entitys.Req // 检查响应状态码 if res.StatusCode != http.StatusOK { - err = errors.SysErr("获取用户权限失败") + err = errors.SysErrf("获取用户权限失败") return } diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index e014d25..ca2392c 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -4,29 +4,45 @@ import ( "ai_scheduler/internal/biz/llm_service" "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" + errorcode "ai_scheduler/internal/data/error" errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" + "ai_scheduler/internal/domain/workflow/runtime" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/dingtalk" "ai_scheduler/internal/pkg/l_request" "ai_scheduler/internal/pkg/mapstructure" + "ai_scheduler/internal/pkg/rec_extra" + "ai_scheduler/internal/pkg/util" "ai_scheduler/internal/tools" - "ai_scheduler/internal/tools_bot" + "ai_scheduler/internal/tools/public" + errorsSpecial "errors" + "io" + "net/http" + "time" + "context" "encoding/json" "fmt" - "gorm.io/gorm/utils" "strings" + + "github.com/coze-dev/coze-go" + "github.com/gofiber/fiber/v2/log" + "gorm.io/gorm/utils" ) type Handle struct { - Ollama *llm_service.OllamaService - toolManager *tools.Manager - Bot *tools_bot.BotTool - conf *config.Config - sessionImpl *impl.SessionImpl + Ollama *llm_service.OllamaService + toolManager *tools.Manager + conf *config.Config + sessionImpl *impl.SessionImpl + workflowManager *runtime.Registry + dingtalkOldClient *dingtalk.OldClient + dingtalkContactClient *dingtalk.ContactClient + dingtalkNotableClient *dingtalk.NotableClient } func NewHandle( @@ -34,34 +50,44 @@ func NewHandle( toolManager *tools.Manager, conf *config.Config, sessionImpl *impl.SessionImpl, - dTalkBot *tools_bot.BotTool, + workflowManager *runtime.Registry, + dingtalkOldClient *dingtalk.OldClient, + dingtalkContactClient *dingtalk.ContactClient, + dingtalkNotableClient *dingtalk.NotableClient, ) *Handle { return &Handle{ - Ollama: Ollama, - toolManager: toolManager, - conf: conf, - sessionImpl: sessionImpl, - Bot: dTalkBot, + Ollama: Ollama, + toolManager: toolManager, + conf: conf, + sessionImpl: sessionImpl, + workflowManager: workflowManager, + dingtalkOldClient: dingtalkOldClient, + dingtalkContactClient: dingtalkContactClient, + dingtalkNotableClient: dingtalkNotableClient, } } -func (r *Handle) Recognize(ctx context.Context, requireData *entitys.RequireData) (err error) { - entitys.ResLog(requireData.Ch, "recognize_start", "准备意图识别") +func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptProcessor PromptOption) (err error) { + entitys.ResLog(rec.Ch, "recognize_start", "准备意图识别") + prompt, err := promptProcessor.CreatePrompt(ctx, rec) //意图识别 - recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData) + recognizeMsg, err := r.Ollama.IntentRecognize(ctx, &entitys.ToolSelect{ + Prompt: prompt, + Tools: rec.Tasks, + }) if err != nil { return } - entitys.ResLog(requireData.Ch, "recognize", recognizeMsg) - entitys.ResLog(requireData.Ch, "recognize_end", "意图识别结束") - + entitys.ResLog(rec.Ch, "recognize", recognizeMsg) + entitys.ResLog(rec.Ch, "recognize_end", "意图识别结束") var match entitys.Match if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil { - err = errors.SysErr("数据结构错误:%v", err.Error()) + err = errors.SysErrf("数据结构错误:%v", err.Error()) return } - requireData.Match = &match + rec.Match = &match + return } @@ -70,27 +96,27 @@ func (r *Handle) handleOtherTask(ctx context.Context, requireData *entitys.Requi return } -func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) (err error) { +func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, rec *entitys.Recognize, requireData *entitys.RequireData) (err error) { - if !requireData.Match.IsMatch { - if len(requireData.Match.Chat) != 0 { - entitys.ResText(requireData.Ch, "", requireData.Match.Chat) + if !rec.Match.IsMatch { + if len(rec.Match.Chat) != 0 { + entitys.ResText(rec.Ch, "", rec.Match.Chat) } else { - entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning) + entitys.ResText(rec.Ch, "", rec.Match.Reasoning) } return } var pointTask *model.AiTask for _, task := range requireData.Tasks { - if task.Index == requireData.Match.Index { + if task.Index == rec.Match.Index { pointTask = &task break } } if pointTask == nil || pointTask.Index == "other" { - return r.OtherTask(ctx, requireData) + return r.OtherTask(ctx, rec) } // 校验用户权限 @@ -101,45 +127,35 @@ func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, requir switch constants.TaskType(pointTask.Type) { case constants.TaskTypeApi: - return r.handleApiTask(ctx, requireData, pointTask) - case constants.TaskTypeFunc: - return r.handleTask(ctx, requireData, pointTask) + return r.handleApiTask(ctx, rec, pointTask) case constants.TaskTypeKnowle: - return r.handleKnowle(ctx, requireData, pointTask) + return r.handleKnowle(ctx, rec, pointTask) + case constants.TaskTypeFunc: + return r.handleTask(ctx, rec, pointTask) case constants.TaskTypeBot: - return r.handleBot(ctx, requireData, pointTask) + return r.handleBot(ctx, rec, pointTask) + case constants.TaskTypeEinoWorkflow: + return r.handleEinoWorkflow(ctx, rec, pointTask) + case constants.TaskTypeCozeWorkflow: + return r.handleCozeWorkflow(ctx, rec, pointTask) default: return r.handleOtherTask(ctx, requireData) } } -func (r *Handle) OtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) { +func (r *Handle) OtherTask(ctx context.Context, requireData *entitys.Recognize) (err error) { entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning) return } -func (r *Handle) handleBot(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { - var configData entitys.ConfigDataTool - err = json.Unmarshal([]byte(task.Config), &configData) - if err != nil { - return - } - err = r.Bot.Execute(ctx, configData.Tool, requireData) - if err != nil { - return - } - - return -} - -func (r *Handle) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { +func (r *Handle) handleTask(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) { var configData entitys.ConfigDataTool err = json.Unmarshal([]byte(task.Config), &configData) if err != nil { return } - err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData) + err = r.toolManager.ExecuteTool(ctx, configData.Tool, rec) if err != nil { return } @@ -148,7 +164,7 @@ func (r *Handle) handleTask(ctx context.Context, requireData *entitys.RequireDat } // 知识库 -func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { +func (r *Handle) handleKnowle(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) { var ( configData entitys.ConfigDataTool @@ -160,13 +176,16 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD if err != nil { return } - + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return + } // 通过session 找到知识库session var has bool - if len(requireData.Session) == 0 { + if len(ext.Session) == 0 { return errors.SessionNotFound } - requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session)) + ext.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(ext.Session)) if err != nil { return } else if !has { @@ -180,7 +199,7 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD return fmt.Errorf("tool not found: %s", configData.Tool) } - if knowledgeTool, ok := tool.(*tools.KnowledgeBaseTool); !ok { + if knowledgeTool, ok := tool.(*public.KnowledgeBaseTool); !ok { return fmt.Errorf("未找到知识库Tool: %s", configData.Tool) } else { host = knowledgeTool.GetConfig().BaseURL @@ -189,15 +208,15 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD } // 知识库的session为空,请求知识库获取, 并绑定 - if requireData.SessionInfo.KnowlegeSessionID == "" { + if ext.SessionInfo.KnowlegeSessionID == "" { // 请求知识库 - if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil { + if sessionIdKnowledge, err = public.GetKnowledgeBaseSession(host, ext.Sys.KnowlegeBaseID, ext.Sys.KnowlegeTenantKey); err != nil { return } // 绑定知识库session,下次可以使用 - requireData.SessionInfo.KnowlegeSessionID = sessionIdKnowledge - if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil { + ext.SessionInfo.KnowlegeSessionID = sessionIdKnowledge + if err = r.sessionImpl.Update(&ext.SessionInfo, r.sessionImpl.WithSessionId(ext.SessionInfo.SessionID)); err != nil { return } } @@ -205,21 +224,21 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD // 用户输入解析 var ok bool input := make(map[string]string) - if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil { + if err = json.Unmarshal([]byte(rec.Match.Parameters), &input); err != nil { return } if query, ok = input["query"]; !ok { return fmt.Errorf("query不能为空") } - requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{ - Session: requireData.SessionInfo.KnowlegeSessionID, - ApiKey: requireData.Sys.KnowlegeTenantKey, + ext.KnowledgeConf = entitys.KnowledgeBaseRequest{ + Session: ext.SessionInfo.KnowlegeSessionID, + ApiKey: ext.Sys.KnowlegeTenantKey, Query: query, } - + rec.Ext = pkg.JsonByteIgonErr(ext) // 执行工具 - err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData) + err = r.toolManager.ExecuteTool(ctx, configData.Tool, rec) if err != nil { return } @@ -227,18 +246,113 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD return } -func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { +// bot 临时实现,后续转到 eino 工作流 +func (r *Handle) handleBot(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) { + if task.Index == "bug_optimization_submit" { + // Ext 中获取 sessionId + sessionID := rec.GetSession() + // 获取dingtalk accessToken + accessToken, _ := r.dingtalkOldClient.GetAccessToken() + // 获取创建者 dingtalk unionId + unionId := r.getUserDingtalkUnionId(ctx, accessToken, sessionID) + // 附件url + var attachmentUrl string + for _, file := range rec.UserContent.File { + attachmentUrl = file.FileUrl + break + } + recordId, err := r.dingtalkNotableClient.InsertRecord(accessToken, &dingtalk.InsertRecordReq{ + BaseId: r.conf.Dingtalk.TableDemand.BaseId, + SheetIdOrName: r.conf.Dingtalk.TableDemand.SheetIdOrName, + // OperatorId: tool_callback.BotBugOptimizationSubmitAdminUnionId, + OperatorId: unionId, + CreatorUnionId: unionId, + Content: rec.UserContent.Text, + AttachmentUrl: attachmentUrl, + }) + if err != nil { + errCode := r.dingtalkNotableClient.GetHTTPStatus(err) + // 权限不足 + if errCode == 403 { + return errorcode.ForbiddenErr("您当前没有AI需求表编辑权限,请联系管理员添加权限") + } + return err + } + + if recordId == "" { + return errors.NewBusinessErr(422, "创建记录失败,请联系管理员") + } + + // 构建跳转链接 + detailPage := util.BuildJumpLink(r.conf.Dingtalk.TableDemand.Url, "去查看") + + entitys.ResText(rec.Ch, "", fmt.Sprintf("问题已记录,正在分配相关人员处理,请您耐心等待处理结果。点击查看工单进度:%s", detailPage)) + + return nil + } + + return errors.NewBusinessErr(422, "bot 任务未实现") +} + +// getUserDingtalkUnionId 获取用户的 dingtalk unionId +func (r *Handle) getUserDingtalkUnionId(ctx context.Context, accessToken, sessionID string) (unionId string) { + // 查询用户名 + session, has, err := r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(sessionID)) + if err != nil || !has { + log.Warnf("session not found: %s", sessionID) + return + } + creatorName := session.UserName + + // 获取创建者uid 用户名 -> dingtalk uid + creatorId, err := r.dingtalkContactClient.SearchUserOne(accessToken, creatorName) + if err != nil { + log.Warnf("search dingtalk user one failed: %v", err) + return + } + + // 获取用户详情 dingtalk uid -> dingtalk unionId + userDetails, err := r.dingtalkOldClient.QueryUserDetails(ctx, creatorId) + if err != nil { + log.Warnf("query user dingtalk details failed: %v", err) + return + } + if userDetails == nil { + log.Warnf("user details not found: %s", creatorId) + return + } + + unionId = userDetails.UnionID + + return +} + +func (r *Handle) handleApiTask(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) { var ( request l_request.Request requestParam map[string]interface{} ) - err = json.Unmarshal([]byte(requireData.Match.Parameters), &requestParam) + ext, err := rec_extra.GetTaskRecExt(rec) if err != nil { return } - request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth) + err = json.Unmarshal([]byte(rec.Match.Parameters), &requestParam) + if err != nil { + return + } + // request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth) + task.Config = strings.ReplaceAll(task.Config, "${authorization}", ext.Auth) for k, v := range requestParam { - task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v)) + if vStr, ok := v.(string); ok { + task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", vStr) + } else { + var jsonStr []byte + jsonStr, err = json.Marshal(v) + if err != nil { + return errors.NewBusinessErr(422, "请求参数解析失败") + } + task.Config = strings.ReplaceAll(task.Config, "\"${"+k+"}\"", string(jsonStr)) + } } var configData entitys.ConfigDataHttp err = json.Unmarshal([]byte(task.Config), &configData) @@ -253,15 +367,139 @@ func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.Require err = errors.NewBusinessErr(422, "api地址获取失败") return } + + entitys.ResLoading(rec.Ch, task.Index, "正在请求数据") + res, err := request.Send() if err != nil { return } - entitys.ResJson(requireData.Ch, "", pkg.JsonStringIgonErr(res.Text)) + entitys.ResJson(rec.Ch, task.Index, res.Text) return } +// eino 工作流 +func (r *Handle) handleEinoWorkflow(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) { + // token 写入ctx + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return + } + ctx = util.SetTokenToContext(ctx, ext.Auth) + + entitys.ResLoading(rec.Ch, task.Index, "正在执行工作流") + + // 工作流内部输出 + workflowId := task.Index + _, err = r.workflowManager.Invoke(ctx, workflowId, rec) + if err != nil { + return err + } + + return nil +} + +func (r *Handle) handleCozeWorkflow(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) { + entitys.ResLoading(rec.Ch, task.Index, "正在执行工作流(coze)") + + customClient := &http.Client{ + Timeout: time.Minute * 30, + } + + authCli := coze.NewTokenAuth(r.conf.Coze.ApiSecret) + cozeCli := coze.NewCozeAPI( + authCli, + coze.WithBaseURL(r.conf.Coze.BaseURL), + coze.WithHttpClient(customClient), + ) + + // 从参数中获取workflowID + type requestParams struct { + Request l_request.Request `json:"request"` + } + var config requestParams + err = json.Unmarshal([]byte(task.Config), &config) + if err != nil { + return err + } + workflowId, ok := config.Request.Json["workflow_id"].(string) + if !ok { + return fmt.Errorf("workflow_id不能为空") + } + // 提取参数 + var data map[string]interface{} + err = json.Unmarshal([]byte(rec.Match.Parameters), &data) + + req := &coze.RunWorkflowsReq{ + WorkflowID: workflowId, + Parameters: data, + // IsAsync: true, + } + + stream := config.Request.Json["stream"].(bool) + + entitys.ResLog(rec.Ch, task.Index, "工作流执行中...") + + if stream { + streamResp, err := cozeCli.Workflows.Runs.Stream(ctx, req) + if err != nil { + return err + } + + handleCozeWorkflowEvents(ctx, streamResp, cozeCli, workflowId, rec.Ch, task.Index) + } else { + resp, err := cozeCli.Workflows.Runs.Create(ctx, req) + if err != nil { + return err + } + + entitys.ResJson(rec.Ch, task.Index, resp.Data) + } + + return +} + +// handleCozeWorkflowEvents 处理 coze 工作流事件 +func handleCozeWorkflowEvents(ctx context.Context, resp coze.Stream[coze.WorkflowEvent], cozeCli coze.CozeAPI, workflowID string, ch chan entitys.Response, index string) { + defer resp.Close() + for { + event, err := resp.Recv() + if errorsSpecial.Is(err, io.EOF) { + fmt.Println("Stream finished") + break + } + if err != nil { + fmt.Println("Error receiving event:", err) + break + } + + switch event.Event { + case coze.WorkflowEventTypeMessage: + entitys.ResStream(ch, index, event.Message.Content) + case coze.WorkflowEventTypeError: + entitys.ResError(ch, index, fmt.Sprintf("工作流执行错误: %s", event.Error)) + case coze.WorkflowEventTypeDone: + entitys.ResEnd(ch, index, "工作流执行完成") + case coze.WorkflowEventTypeInterrupt: + resumeReq := &coze.ResumeRunWorkflowsReq{ + WorkflowID: workflowID, + EventID: event.Interrupt.InterruptData.EventID, + ResumeData: "your data", + InterruptType: event.Interrupt.InterruptData.Type, + } + newResp, err := cozeCli.Workflows.Runs.Resume(ctx, resumeReq) + if err != nil { + entitys.ResError(ch, index, fmt.Sprintf("工作流恢复执行错误: %s", err.Error())) + return + } + entitys.ResLog(ch, index, "工作流恢复执行中...") + handleCozeWorkflowEvents(ctx, newResp, cozeCli, workflowID, ch, index) + } + } + fmt.Printf("done, log:%s\n", resp.Response().LogID()) +} + // 权限验证 func (r *Handle) PermissionAuth(client *gateway.Client, pointTask *model.AiTask) (err error) { // 授权检查权限 diff --git a/internal/biz/do/prompt.go b/internal/biz/do/prompt.go new file mode 100644 index 0000000..cada763 --- /dev/null +++ b/internal/biz/do/prompt.go @@ -0,0 +1,168 @@ +package do + +import ( + "ai_scheduler/internal/biz/handle" + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/utils_vllm" + "context" + "strings" + + "github.com/ollama/ollama/api" +) + +type PromptOption interface { + CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) +} + +type WithSys struct { + Config *config.Config +} + +func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) { + var ( + prompt = make([]api.Message, 0) // 初始化一个空的api.Message切片 + ) + // 获取用户内容,如果出错则直接返回错误 + content, err := f.getUserContent(ctx, rec) + if err != nil { + return nil, err + } + // 构建提示消息列表,包含系统提示、助手回复和用户内容 + mes = append(prompt, api.Message{ + Role: "system", // 系统角色 + Content: rec.SystemPrompt, // 系统提示内容 + }, api.Message{ + Role: "assistant", // 助手角色 + Content: "### 聊天记录:" + pkg.JsonStringIgonErr(rec.ChatHis), // 助手回复内容 + }, api.Message{ + Role: "user", // 用户角色 + Content: content.String(), // 用户输入内容 + }) + + return +} + +func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) { + var hasFile bool + if len(rec.UserContent.File) > 0 { + hasFile = true + } + content.WriteString(rec.UserContent.Text) + if hasFile { + content.WriteString("\n") + } + + if len(rec.UserContent.Tag) > 0 { + content.WriteString("\n") + content.WriteString("### 工具必须使用:") + content.WriteString(rec.UserContent.Tag) + } + + if len(rec.ChatHis.Messages) > 0 { + content.WriteString("### 引用历史聊天记录:\n") + content.WriteString(pkg.JsonStringIgonErr(rec.ChatHis)) + } + + if hasFile { + content.WriteString("\n") + content.WriteString("### 文件内容:\n") + for _, file := range rec.UserContent.File { + handle.HandleRecognizeFile(file) + // 文件识别 + switch file.FileType { + case constants.FileTypeImage: + entitys.ResLog(rec.Ch, "recognize_img_start", "图片识别中...") + var imageContent string + imageContent, err = f.recognizeWithImgVllm(ctx, file) + if err != nil { + return + } + entitys.ResLog(rec.Ch, "recognize_img_end", "图片识别完成,识别内容:"+imageContent) + + // 解析结果回写到file + file.FileRec = imageContent + default: + content.WriteString(file.FileRec) + } + } + } + return +} + +func (f *WithSys) recognizeWithImgVllm(ctx context.Context, file *entitys.RecognizeFile) (content string, err error) { + if file.FileData == nil || file.FileType != constants.FileTypeImage { + return + } + + client, cleanup, err := utils_vllm.NewClient(f.Config) + if err != nil { + return "", err + } + defer cleanup() + + outMsg, err := client.RecognizeWithImgBytes(ctx, + f.Config.DefaultPrompt.ImgRecognize.SystemPrompt, + f.Config.DefaultPrompt.ImgRecognize.UserPrompt, + file.FileData, + file.FileRealMime, + ) + if err != nil { + return "", err + } + + return outMsg.Content, nil +} + +type WithDingTalkBot struct { +} + +func (f *WithDingTalkBot) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) { + var ( + prompt = make([]api.Message, 0) // 初始化一个空的api.Message切片 + ) + // 获取用户内容,如果出错则直接返回错误 + content, err := f.getUserContent(ctx, rec) + if err != nil { + return nil, err + } + // 构建提示消息列表,包含系统提示、助手回复和用户内容 + mes = append(prompt, api.Message{ + Role: "system", // 系统角色 + Content: rec.SystemPrompt, // 系统提示内容 + }, api.Message{ + Role: "assistant", // 助手角色 + Content: "### 聊天记录:" + pkg.JsonStringIgonErr(rec.ChatHis), // 助手回复内容 + }, api.Message{ + Role: "user", // 用户角色 + Content: content.String(), // 用户输入内容 + }) + + return +} + +func (f *WithDingTalkBot) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) { + var hasFile bool + if rec.UserContent.File != nil && len(rec.UserContent.File) > 0 { + hasFile = true + } + content.WriteString(rec.UserContent.Text) + if hasFile { + content.WriteString("\n") + } + + if len(rec.UserContent.Tag) > 0 { + content.WriteString("\n") + content.WriteString("### 工具必须使用:") + content.WriteString(rec.UserContent.Tag) + } + + if len(rec.ChatHis.Messages) > 0 { + content.WriteString("### 引用历史聊天记录:\n") + content.WriteString(pkg.JsonStringIgonErr(rec.ChatHis)) + } + + return +} diff --git a/internal/biz/handle/dingtalk/auth.go b/internal/biz/handle/dingtalk/auth.go new file mode 100644 index 0000000..143ab45 --- /dev/null +++ b/internal/biz/handle/dingtalk/auth.go @@ -0,0 +1,159 @@ +package dingtalk + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/utils" + "context" + "encoding/json" + "errors" + "net/http" + "time" + + "github.com/gofiber/fiber/v2/log" + "github.com/redis/go-redis/v9" + "xorm.io/builder" +) + +type Auth struct { + redis *redis.Client + cfg *config.Config + botConfigImpl *impl.BotConfigImpl +} + +func NewAuth(cfg *config.Config, redis *utils.Rdb, botConfigImpl *impl.BotConfigImpl) *Auth { + return &Auth{ + redis: redis.Rdb, + cfg: cfg, + botConfigImpl: botConfigImpl, + } +} + +func (a *Auth) GetAccessToken(ctx context.Context, clientId string, clientSecret string) (authInfo *AuthInfo, err error) { + if clientId == "" { + return nil, errors.New("clientId is empty") + } + accessToken := a.redis.Get(ctx, a.getKey(clientId)).Val() + var expire time.Duration + if accessToken == "" { + dingTalkAuthRes, _err := a.getNewAccessToken(ctx, clientId, clientSecret) + if _err != nil { + return nil, _err + } + expire = time.Duration(dingTalkAuthRes.ExpireIn-3600) * time.Second + err = a.redis.SetEx(ctx, a.getKey(clientId), dingTalkAuthRes.AccessToken, expire).Err() + if err != nil { + return + } + accessToken = dingTalkAuthRes.AccessToken + } else { + expire, _ = a.redis.TTL(ctx, a.getKey(clientId)).Result() + } + return &AuthInfo{ + ClientId: clientId, + ClientSecret: clientSecret, + AccessToken: accessToken, + Expire: expire, + }, nil +} + +func (a *Auth) getKey(clientId string) string { + return a.cfg.Redis.Key + ":" + constants.DingTalkAuthBaseKeyPrefix + ":" + clientId +} + +func (a *Auth) getKeyBot(botCode string) string { + return a.cfg.Redis.Key + ":" + constants.DingTalkAuthBaseKeyBotPrefix + ":" + botCode +} + +func (a *Auth) getNewAccessToken(ctx context.Context, clientId string, clientSecret string) (auth DingTalkAuthIRes, err error) { + if clientId == "" || clientSecret == "" { + err = errors.New("clientId or clientSecret is empty") + return + } + + req := l_request.Request{ + Method: http.MethodPost, + Url: "https://api.dingtalk.com/v1.0/oauth2/accessToken", + Json: map[string]interface{}{ + "appKey": clientId, + "appSecret": clientSecret, + }, + } + res, err := req.Send() + if err != nil { + return + } + err = json.Unmarshal(res.Content, &auth) + + return +} + +func (a *Auth) GetTokenFromBotOption(ctx context.Context, botOption ...BotOption) (token *AuthInfo, err error) { + botInfo := &Bot{} + for _, option := range botOption { + option(botInfo) + } + + if botInfo.Id == 0 && botInfo.BotConfig == nil && botInfo.BotCode == "" { + err = errors.New("botInfo is nil") + return + } + + if botInfo.BotConfig == nil { + err = a.GetBotConfigFromModel(botInfo) + if err != nil { + return + } + } + + authInfo := a.redis.Get(ctx, a.getKeyBot(botInfo.BotConfig.RobotCode)).Val() + if authInfo == "" { + var botConfig entitys.DingTalkBot + err = json.Unmarshal([]byte(botInfo.BotConfig.BotConfig), &botConfig) + if err != nil { + log.Infof("初始化“%s”失败:%s", botInfo.BotConfig.BotName, err.Error()) + return + } + token, err = a.GetAccessToken(ctx, botConfig.ClientId, botConfig.ClientSecret) + if err != nil { + return + } + err = a.redis.SetEx(ctx, a.getKeyBot(botInfo.BotConfig.RobotCode), pkg.JsonStringIgonErr(token), token.Expire).Err() + if err != nil { + return + } + } else { + var tokenData AuthInfo + err = json.Unmarshal([]byte(authInfo), &tokenData) + token = &tokenData + } + return +} + +func (a *Auth) GetBotConfigFromModel(botInfo *Bot) (err error) { + var ( + botConfigDo model.AiBotConfig + ) + cond := builder.NewCond() + if botInfo.Id > 0 { + cond = cond.And(builder.Eq{"bot_id": botInfo.Id}) + } + if botInfo.BotCode != "" { + cond = cond.And(builder.Eq{"robot_code": botInfo.BotCode}) + } + err = a.botConfigImpl.GetOneBySearchToStrut(&cond, &botConfigDo) + if err != nil { + return + } + if botConfigDo.BotID == 0 { + err = errors.New("未找到机器人服务配置") + return + } + botInfo.BotConfig = &botConfigDo + return nil +} diff --git a/internal/biz/handle/dingtalk/dept.go b/internal/biz/handle/dingtalk/dept.go new file mode 100644 index 0000000..df8a607 --- /dev/null +++ b/internal/biz/handle/dingtalk/dept.go @@ -0,0 +1,108 @@ +package dingtalk + +import ( + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/l_request" + "context" + "encoding/json" + "fmt" + "net/http" + + "xorm.io/builder" +) + +type Dept struct { + dingDeptImpl *impl.BotDeptImpl + auth *Auth +} + +func NewDept(dingDeptImpl *impl.BotDeptImpl, auth *Auth) *Dept { + return &Dept{ + dingDeptImpl: dingDeptImpl, + auth: auth, + } +} + +func (d *Dept) GetDeptInfoByDeptIds(ctx context.Context, deptIds []int, authInfo *AuthInfo) (depts []*entitys.Dept, err error) { + if len(deptIds) == 0 || authInfo == nil { + return + } + var deptsInfo []model.AiBotDept + cond := builder.NewCond() + cond = cond.And(builder.Eq{"dingtalk_dept_id": deptIds}) + err = d.dingDeptImpl.GetRangeToMapStruct(&cond, &deptsInfo) + if err != nil { + return + } + var existDept = make([]int, len(deptsInfo), 0) + for _, dept := range deptsInfo { + depts = append(depts, &entitys.Dept{ + DeptId: int(dept.DeptID), + Name: dept.Name, + ToolList: dept.ToolList, + }) + existDept = append(existDept, int(dept.DeptID)) + } + diff := pkg.Difference(deptIds, existDept) + if len(diff) > 0 { + deptDo := make([]model.AiBotDept, 0) + for _, deptId := range diff { + deptInfo, _err := d.GetDeptInfoFromDingTalk(ctx, deptId, authInfo.AccessToken) + if _err != nil { + return nil, _err + } + depts = append(depts, &entitys.Dept{ + DeptId: deptInfo.DeptId, + Name: deptInfo.Name, + }) + deptDo = append(deptDo, model.AiBotDept{ + DingtalkDeptID: int32(deptInfo.DeptId), + Name: deptInfo.Name, + }) + } + if len(deptDo) > 0 { + _, err = d.dingDeptImpl.Add(deptDo) + if err != nil { + return nil, err + } + } + } + + return +} + +func (d *Dept) GetDeptInfoFromDingTalk(ctx context.Context, deptId int, token string) (depts DeptResResult, err error) { + if deptId == 0 || len(token) == 0 { + return + } + + req := l_request.Request{ + Method: http.MethodPost, + Url: constants.GetDingTalkRequestUrl(constants.RequestUrlGetDeptGet, map[string]string{ + "access_token": token, + }), + Json: map[string]interface{}{ + "dept_id": deptId, + }, + } + res, _err := req.Send() + if _err != nil { + err = _err + return + } + var deptInfo DeptRes + + err = json.Unmarshal(res.Content, &deptInfo) + if err != nil { + return + } + if deptInfo.Errcode != 0 { + fmt.Errorf("钉钉请求报错:%s", deptInfo.Errmsg) + } + return deptInfo.DeptResResult, err + +} diff --git a/internal/biz/handle/dingtalk/option.go b/internal/biz/handle/dingtalk/option.go new file mode 100644 index 0000000..3b72795 --- /dev/null +++ b/internal/biz/handle/dingtalk/option.go @@ -0,0 +1,36 @@ +package dingtalk + +import "ai_scheduler/internal/data/model" + +type Bot struct { + Id int + BotCode string + BotConfig *model.AiBotConfig +} +type BotOption func(*Bot) + +func WithId(id int) BotOption { + return func(b *Bot) { + b.Id = id + } +} + +func WithBotConfig(BotConfig *model.AiBotConfig) BotOption { + return func(bot *Bot) { + bot.BotConfig = BotConfig + } +} + +func WithBotCode(BotCode string) BotOption { + return func(bot *Bot) { + bot.BotCode = BotCode + } +} + +func WithBot(botSelf *Bot) BotOption { + return func(bot *Bot) { + bot.BotCode = botSelf.BotCode + bot.Id = botSelf.Id + bot.BotConfig = botSelf.BotConfig + } +} diff --git a/internal/biz/handle/dingtalk/provider_set.go b/internal/biz/handle/dingtalk/provider_set.go new file mode 100644 index 0000000..70f31ff --- /dev/null +++ b/internal/biz/handle/dingtalk/provider_set.go @@ -0,0 +1,12 @@ +package dingtalk + +import ( + "github.com/google/wire" +) + +var ProviderSetDingTalk = wire.NewSet( + NewUser, + NewAuth, + NewDept, + NewSendCardClient, +) diff --git a/internal/biz/handle/dingtalk/send_card.go b/internal/biz/handle/dingtalk/send_card.go new file mode 100644 index 0000000..c2063da --- /dev/null +++ b/internal/biz/handle/dingtalk/send_card.go @@ -0,0 +1,287 @@ +package dingtalk + +import ( + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/pkg" + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "sync" + "time" + + openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client" + dingtalkim_1_0 "github.com/alibabacloud-go/dingtalk/im_1_0" + util "github.com/alibabacloud-go/tea-utils/v2/service" + "github.com/alibabacloud-go/tea/tea" + "github.com/gofiber/fiber/v2/log" + "github.com/google/uuid" +) + +const DefaultInterval = 100 * time.Millisecond +const HeardBeatX = 50 + +type SendCardClient struct { + Auth *Auth + CardClient *sync.Map + mu sync.RWMutex // 保护 CardClient 的并发访问 + logger log.AllLogger // 日志记录 + botOption *Bot +} + +func NewSendCardClient(auth *Auth, logger log.AllLogger) *SendCardClient { + return &SendCardClient{ + Auth: auth, + CardClient: &sync.Map{}, + logger: logger, + botOption: &Bot{}, + } +} + +// initClient 初始化或复用 DingTalk 客户端 +func (s *SendCardClient) initClient(robotCode string) (*dingtalkim_1_0.Client, error) { + if client, ok := s.CardClient.Load(robotCode); ok { + return client.(*dingtalkim_1_0.Client), nil + } + s.botOption.BotCode = robotCode + config := &openapi.Config{ + Protocol: tea.String("https"), + RegionId: tea.String("central"), + } + client, err := dingtalkim_1_0.NewClient(config) + if err != nil { + s.logger.Error("failed to init DingTalk client") + return nil, fmt.Errorf("init client failed: %w", err) + } + + s.CardClient.Store(robotCode, client) + return client, nil +} + +func (s *SendCardClient) NewCard(ctx context.Context, cardSend *CardSend) error { + // 参数校验 + if (len(cardSend.ContentSlice) == 0 || cardSend.ContentSlice == nil) && cardSend.ContentChannel == nil { + return errors.New("卡片内容不能为空") + } + if cardSend.UpdateInterval == 0 { + cardSend.UpdateInterval = DefaultInterval // 默认更新间隔 + } + if cardSend.Title == "" { + cardSend.Title = "钉钉卡片" + } + //替换标题 + replace, err := pkg.SafeReplace(string(cardSend.Template), "${title}", cardSend.Title) + if err != nil { + return err + } + cardSend.Template = constants.CardTemp(replace) + // 初始化客户端 + client, err := s.initClient(cardSend.RobotCode) + if err != nil { + return fmt.Errorf("初始化client失败: %w", err) + } + + // 生成卡片实例ID + cardInstanceId, err := uuid.NewUUID() + if err != nil { + return fmt.Errorf("创建uuid失败: %w", err) + } + + // 构建初始请求 + request, err := s.buildBaseRequest(cardSend, cardInstanceId.String()) + if err != nil { + return fmt.Errorf("请求失败: %w", err) + } + + // 发送初始卡片 + if _, err := s.SendInteractiveCard(ctx, request, cardSend.RobotCode, client); err != nil { + return fmt.Errorf("发送初始卡片失败: %w", err) + } + + // 处理切片内容(同步) + if len(cardSend.ContentSlice) > 0 { + if err := s.processContentSlice(ctx, cardSend, cardInstanceId.String(), client); err != nil { + return fmt.Errorf("内容同步失败: %w", err) + } + } + + // 处理通道内容(异步) + if cardSend.ContentChannel != nil { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + s.processContentChannel(ctx, cardSend, cardInstanceId.String(), client) + }() + wg.Wait() + } + + return nil +} + +// buildBaseRequest 构建基础请求 +func (s *SendCardClient) buildBaseRequest(cardSend *CardSend, cardInstanceId string) (*dingtalkim_1_0.SendRobotInteractiveCardRequest, error) { + cardData := fmt.Sprintf(string(cardSend.Template), "") // 初始空内容 + request := &dingtalkim_1_0.SendRobotInteractiveCardRequest{ + CardTemplateId: tea.String("StandardCard"), + CardBizId: tea.String(cardInstanceId), + CardData: tea.String(cardData), + RobotCode: tea.String(cardSend.RobotCode), + SendOptions: &dingtalkim_1_0.SendRobotInteractiveCardRequestSendOptions{}, + PullStrategy: tea.Bool(false), + } + + switch cardSend.ConversationType { + case constants.ConversationTypeGroup: + request.SetOpenConversationId(cardSend.ConversationId) + case constants.ConversationTypeSingle: + receiver, err := json.Marshal(map[string]string{"userId": cardSend.SenderStaffId}) + if err != nil { + return nil, fmt.Errorf("数据整理失败: %w", err) + } + request.SetSingleChatReceiver(string(receiver)) + default: + return nil, errors.New("未知的聊天场景") + } + + return request, nil +} + +// processContentChannel 处理通道内容(异步更新) +func (s *SendCardClient) processContentChannel(ctx context.Context, cardSend *CardSend, cardInstanceId string, client *dingtalkim_1_0.Client) { + defer func() { + if r := recover(); r != nil { + s.logger.Error("panic in processContentChannel") + } + }() + + ticker := time.NewTicker(cardSend.UpdateInterval) + defer ticker.Stop() + heartbeatTicker := time.NewTicker(time.Duration(HeardBeatX) * DefaultInterval) + defer heartbeatTicker.Stop() + + var ( + contentBuilder strings.Builder + lastUpdate time.Time + ) + for { + + select { + case content, ok := <-cardSend.ContentChannel: + if !ok { + // 通道关闭,发送最终内容 + if contentBuilder.Len() > 0 { + if err := s.updateCardContent(ctx, cardSend, cardInstanceId, contentBuilder.String(), client); err != nil { + s.logger.Errorf("更新卡片失败1:%s", err.Error()) + } + } + return + } + contentBuilder.WriteString(content) + if contentBuilder.Len() > 0 { + if err := s.updateCardContent(ctx, cardSend, cardInstanceId, contentBuilder.String(), client); err != nil { + s.logger.Errorf("更新卡片失败2:%s", err.Error()) + } + } + lastUpdate = time.Now() + + case <-heartbeatTicker.C: + if time.Now().Unix()-lastUpdate.Unix() >= HeardBeatX { + return + } + + case <-ctx.Done(): + s.logger.Info("context canceled, stop channel processing") + return + } + } + +} + +// processContentSlice 处理切片内容(同步更新) +func (s *SendCardClient) processContentSlice(ctx context.Context, cardSend *CardSend, cardInstanceId string, client *dingtalkim_1_0.Client) error { + var contentBuilder strings.Builder + for _, content := range cardSend.ContentSlice { + + contentBuilder.WriteString(content) + err := s.updateCardRequest(ctx, &UpdateCardRequest{ + Template: string(cardSend.Template), + Content: contentBuilder.String(), + Client: client, + RobotCode: cardSend.RobotCode, + CardInstanceId: cardInstanceId, + }) + if err != nil { + return fmt.Errorf("更新卡片失败: %w", err) + } + time.Sleep(cardSend.UpdateInterval) // 控制更新频率 + } + return nil +} + +// updateCardContent 封装卡片更新逻辑 +func (s *SendCardClient) updateCardContent(ctx context.Context, cardSend *CardSend, cardInstanceId, content string, client *dingtalkim_1_0.Client) error { + err := s.updateCardRequest(ctx, &UpdateCardRequest{ + Template: string(cardSend.Template), + Content: content, + Client: client, + RobotCode: cardSend.RobotCode, + CardInstanceId: cardInstanceId, + }) + + return err +} + +func (s *SendCardClient) updateCardRequest(ctx context.Context, updateCardRequest *UpdateCardRequest) error { + content, err := pkg.SafeReplace(updateCardRequest.Template, "%s", updateCardRequest.Content) + if err != nil { + return err + } + updateRequest := &dingtalkim_1_0.UpdateRobotInteractiveCardRequest{ + CardBizId: tea.String(updateCardRequest.CardInstanceId), + CardData: tea.String(content), + } + _, err = s.UpdateInteractiveCard(ctx, updateRequest, updateCardRequest.RobotCode, updateCardRequest.Client) + return err +} + +// UpdateInteractiveCard 更新交互卡片(封装错误处理) +func (s *SendCardClient) UpdateInteractiveCard(ctx context.Context, request *dingtalkim_1_0.UpdateRobotInteractiveCardRequest, robotCode string, client *dingtalkim_1_0.Client) (*dingtalkim_1_0.UpdateRobotInteractiveCardResponse, error) { + authInfo, err := s.Auth.GetTokenFromBotOption(ctx, WithBot(s.botOption)) + if err != nil { + return nil, fmt.Errorf("get token failed: %w", err) + } + + headers := &dingtalkim_1_0.UpdateRobotInteractiveCardHeaders{ + XAcsDingtalkAccessToken: tea.String(authInfo.AccessToken), + } + + response, err := client.UpdateRobotInteractiveCardWithOptions(request, headers, &util.RuntimeOptions{}) + if err != nil { + return nil, fmt.Errorf("API call failed: %w,request:%v", err, request.String()) + } + return response, nil +} + +// SendInteractiveCard 发送交互卡片(封装错误处理) +func (s *SendCardClient) SendInteractiveCard(ctx context.Context, request *dingtalkim_1_0.SendRobotInteractiveCardRequest, robotCode string, client *dingtalkim_1_0.Client) (res *dingtalkim_1_0.SendRobotInteractiveCardResponse, err error) { + err = s.Auth.GetBotConfigFromModel(s.botOption) + if err != nil { + return nil, fmt.Errorf("初始化bot失败: %w", err) + } + authInfo, err := s.Auth.GetTokenFromBotOption(ctx, WithBot(s.botOption)) + if err != nil { + return nil, fmt.Errorf("get token failed: %w", err) + } + + headers := &dingtalkim_1_0.SendRobotInteractiveCardHeaders{ + XAcsDingtalkAccessToken: tea.String(authInfo.AccessToken), + } + + response, err := client.SendRobotInteractiveCardWithOptions(request, headers, &util.RuntimeOptions{}) + if err != nil { + return nil, fmt.Errorf("API call failed: %w", err) + } + return response, nil +} diff --git a/internal/biz/handle/dingtalk/send_card.go.bak1 b/internal/biz/handle/dingtalk/send_card.go.bak1 new file mode 100644 index 0000000..9fb1e8d --- /dev/null +++ b/internal/biz/handle/dingtalk/send_card.go.bak1 @@ -0,0 +1,280 @@ +package dingtalk + +import ( + "ai_scheduler/internal/data/constants" + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "sync" + "time" + + openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client" + dingtalkcard_1_0 "github.com/alibabacloud-go/dingtalk/card_1_0" + dingtalkim_1_0 "github.com/alibabacloud-go/dingtalk/im_1_0" + util "github.com/alibabacloud-go/tea-utils/v2/service" + "github.com/alibabacloud-go/tea/tea" + "github.com/gofiber/fiber/v2/log" + "github.com/google/uuid" +) + +const DefaultInterval = 100 * time.Millisecond +const HeardBeatX = 50 + +type SendCardClient struct { + Auth *Auth + CardClient *sync.Map + mu sync.RWMutex // 保护 CardClient 的并发访问 + logger log.AllLogger // 日志记录 + botOption *Bot +} + +func NewSendCardClient(auth *Auth, logger log.AllLogger) *SendCardClient { + return &SendCardClient{ + Auth: auth, + CardClient: &sync.Map{}, + logger: logger, + botOption: &Bot{}, + } +} + +// initClient 初始化或复用 DingTalk 客户端 +func (s *SendCardClient) initClient(robotCode string) (*dingtalkcard_1_0.Client, error) { + if client, ok := s.CardClient.Load(robotCode); ok { + return client.(*dingtalkcard_1_0.Client), nil + } + s.botOption.BotCode = robotCode + config := &openapi.Config{ + Protocol: tea.String("https"), + RegionId: tea.String("central"), + } + client, err := dingtalkcard_1_0.NewClient(config) + if err != nil { + s.logger.Error("failed to init DingTalk client") + return nil, fmt.Errorf("init client failed: %w", err) + } + + s.CardClient.Store(robotCode, client) + return client, nil +} + +func (s *SendCardClient) NewCard(ctx context.Context, cardSend *CardSend) error { + // 参数校验 + if (len(cardSend.ContentSlice) == 0 || cardSend.ContentSlice == nil) && cardSend.ContentChannel == nil { + return errors.New("卡片内容不能为空") + } + if cardSend.UpdateInterval == 0 { + cardSend.UpdateInterval = DefaultInterval // 默认更新间隔 + } + if cardSend.Title == "" { + cardSend.Title = "钉钉卡片" + } + //替换标题 + cardSend.Template = constants.CardTemp(strings.Replace(string(cardSend.Template), "${title}", cardSend.Title, 1)) + // 初始化客户端 + client, err := s.initClient(cardSend.RobotCode) + if err != nil { + return fmt.Errorf("初始化client失败: %w", err) + } + + // 生成卡片实例ID + cardInstanceId, err := uuid.NewUUID() + if err != nil { + return fmt.Errorf("创建uuid失败: %w", err) + } + + // 构建初始请求 + request, err := s.buildBaseRequest(cardSend, cardInstanceId.String()) + if err != nil { + return fmt.Errorf("请求失败: %w", err) + } + + // 发送初始卡片 + if _, err := s.SendInteractiveCard(ctx, request, cardSend.RobotCode, client); err != nil { + return fmt.Errorf("发送初始卡片失败: %w", err) + } + + // 处理切片内容(同步) + if len(cardSend.ContentSlice) > 0 { + if err := s.processContentSlice(ctx, cardSend, cardInstanceId.String(), client); err != nil { + return fmt.Errorf("内容同步失败: %w", err) + } + } + + // 处理通道内容(异步) + if cardSend.ContentChannel != nil { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + s.processContentChannel(ctx, cardSend, cardInstanceId.String(), client) + }() + wg.Wait() + } + + return nil +} + +// buildBaseRequest 构建基础请求 +func (s *SendCardClient) buildBaseRequest(cardSend *CardSend, cardInstanceId string) (*dingtalkcard_1_0.StreamingUpdateRequest, error) { + cardData := fmt.Sprintf(string(cardSend.Template), "") // 初始空内容 + request := &dingtalkcard_1_0.StreamingUpdateRequest{ + OutTrackId: tea.String("your-out-track-id"), + Guid: tea.String("0F714542-0AFC-2B0E-CF14-E2D39F5BFFE8"), + Key: tea.String("your-ai-param"), + Content: tea.String("test"), + IsFull: tea.Bool(false), + IsFinalize: tea.Bool(false), + IsError: tea.Bool(false), + } + + switch cardSend.ConversationType { + case constants.ConversationTypeGroup: + request.SetOpenConversationId(cardSend.ConversationId) + case constants.ConversationTypeSingle: + receiver, err := json.Marshal(map[string]string{"userId": cardSend.SenderStaffId}) + if err != nil { + return nil, fmt.Errorf("数据整理失败: %w", err) + } + request.SetSingleChatReceiver(string(receiver)) + default: + return nil, errors.New("未知的聊天场景") + } + + return request, nil +} + +// processContentChannel 处理通道内容(异步更新) +func (s *SendCardClient) processContentChannel(ctx context.Context, cardSend *CardSend, cardInstanceId string, client *dingtalkim_1_0.Client) { + defer func() { + if r := recover(); r != nil { + s.logger.Error("panic in processContentChannel") + } + }() + + ticker := time.NewTicker(cardSend.UpdateInterval) + defer ticker.Stop() + heartbeatTicker := time.NewTicker(time.Duration(HeardBeatX) * DefaultInterval) + defer heartbeatTicker.Stop() + + var ( + contentBuilder strings.Builder + lastUpdate time.Time + ) + for { + + select { + case content, ok := <-cardSend.ContentChannel: + if !ok { + // 通道关闭,发送最终内容 + if contentBuilder.Len() > 0 { + if err := s.updateCardContent(ctx, cardSend, cardInstanceId, contentBuilder.String(), client); err != nil { + s.logger.Errorf("更新卡片失败1:%s", err.Error()) + } + } + return + } + contentBuilder.WriteString(content) + if contentBuilder.Len() > 0 { + if err := s.updateCardContent(ctx, cardSend, cardInstanceId, contentBuilder.String(), client); err != nil { + s.logger.Errorf("更新卡片失败2:%s", err.Error()) + } + } + lastUpdate = time.Now() + + case <-heartbeatTicker.C: + if time.Now().Unix()-lastUpdate.Unix() >= HeardBeatX { + return + } + + case <-ctx.Done(): + s.logger.Info("context canceled, stop channel processing") + return + } + } + +} + +// processContentSlice 处理切片内容(同步更新) +func (s *SendCardClient) processContentSlice(ctx context.Context, cardSend *CardSend, cardInstanceId string, client *dingtalkim_1_0.Client) error { + var contentBuilder strings.Builder + for _, content := range cardSend.ContentSlice { + contentBuilder.WriteString(content) + err := s.updateCardRequest(ctx, &UpdateCardRequest{ + Template: string(cardSend.Template), + Content: contentBuilder.String(), + Client: client, + RobotCode: cardSend.RobotCode, + CardInstanceId: cardInstanceId, + }) + if err != nil { + return fmt.Errorf("更新卡片失败: %w", err) + } + time.Sleep(cardSend.UpdateInterval) // 控制更新频率 + } + return nil +} + +// updateCardContent 封装卡片更新逻辑 +func (s *SendCardClient) updateCardContent(ctx context.Context, cardSend *CardSend, cardInstanceId, content string, client *dingtalkim_1_0.Client) error { + err := s.updateCardRequest(ctx, &UpdateCardRequest{ + Template: string(cardSend.Template), + Content: content, + Client: client, + RobotCode: cardSend.RobotCode, + CardInstanceId: cardInstanceId, + }) + + return err +} + +func (s *SendCardClient) updateCardRequest(ctx context.Context, updateCardRequest *UpdateCardRequest) error { + + updateRequest := &dingtalkim_1_0.UpdateRobotInteractiveCardRequest{ + CardBizId: tea.String(updateCardRequest.CardInstanceId), + CardData: tea.String(fmt.Sprintf(updateCardRequest.Template, updateCardRequest.Content)), + } + _, err := s.UpdateInteractiveCard(ctx, updateRequest, updateCardRequest.RobotCode, updateCardRequest.Client) + return err +} + +// UpdateInteractiveCard 更新交互卡片(封装错误处理) +func (s *SendCardClient) UpdateInteractiveCard(ctx context.Context, request *dingtalkim_1_0.UpdateRobotInteractiveCardRequest, robotCode string, client *dingtalkim_1_0.Client) (*dingtalkim_1_0.UpdateRobotInteractiveCardResponse, error) { + authInfo, err := s.Auth.GetTokenFromBotOption(ctx, WithBot(s.botOption)) + if err != nil { + return nil, fmt.Errorf("get token failed: %w", err) + } + + headers := &dingtalkim_1_0.UpdateRobotInteractiveCardHeaders{ + XAcsDingtalkAccessToken: tea.String(authInfo.AccessToken), + } + + response, err := client.UpdateRobotInteractiveCardWithOptions(request, headers, &util.RuntimeOptions{}) + if err != nil { + return nil, fmt.Errorf("API call failed: %w,request:%v", err, request.String()) + } + return response, nil +} + +// SendInteractiveCard 发送交互卡片(封装错误处理) +func (s *SendCardClient) SendInteractiveCard(ctx context.Context, request *dingtalkim_1_0.SendRobotInteractiveCardRequest, robotCode string, client *dingtalkim_1_0.Client) (res *dingtalkim_1_0.SendRobotInteractiveCardResponse, err error) { + err = s.Auth.GetBotConfigFromModel(s.botOption) + if err != nil { + return nil, fmt.Errorf("初始化bot失败: %w", err) + } + authInfo, err := s.Auth.GetTokenFromBotOption(ctx, WithBot(s.botOption)) + if err != nil { + return nil, fmt.Errorf("get token failed: %w", err) + } + + headers := &dingtalkim_1_0.SendRobotInteractiveCardHeaders{ + XAcsDingtalkAccessToken: tea.String(authInfo.AccessToken), + } + + response, err := client.SendRobotInteractiveCardWithOptions(request, headers, &util.RuntimeOptions{}) + if err != nil { + return nil, fmt.Errorf("API call failed: %w", err) + } + return response, nil +} diff --git a/internal/biz/handle/dingtalk/types.go b/internal/biz/handle/dingtalk/types.go new file mode 100644 index 0000000..5baa770 --- /dev/null +++ b/internal/biz/handle/dingtalk/types.go @@ -0,0 +1,110 @@ +package dingtalk + +import ( + "ai_scheduler/internal/data/constants" + "time" + + dingtalkim_1_0 "github.com/alibabacloud-go/dingtalk/im_1_0" +) + +type DingTalkAuthIRes struct { + AccessToken string `json:"accessToken"` + ExpireIn int64 `json:"expireIn"` +} + +type UserInfoRes struct { + Errcode int `json:"errcode"` + Errmsg string `json:"errmsg"` + Result UserInfoResResult `json:"result"` + RequestId string `json:"request_id"` +} + +type UserInfoResResult struct { + Active bool `json:"active"` + Admin bool `json:"admin"` + Avatar string `json:"avatar"` + Boss bool `json:"boss"` + CreateTime time.Time `json:"create_time"` + DeptIdList []int `json:"dept_id_list"` + DeptOrderList []struct { + DeptId int `json:"dept_id"` + Order int64 `json:"order"` + } `json:"dept_order_list"` + ExclusiveAccount bool `json:"exclusive_account"` + HideMobile bool `json:"hide_mobile"` + HiredDate int64 `json:"hired_date"` + JobNumber string `json:"job_number"` + LeaderInDept []struct { + DeptId int `json:"dept_id"` + Leader bool `json:"leader"` + } `json:"leader_in_dept"` + ManagerUserid string `json:"manager_userid"` + Name string `json:"name"` + RealAuthed bool `json:"real_authed"` + RoleList []struct { + GroupName string `json:"group_name"` + Id int `json:"id"` + Name string `json:"name"` + } `json:"role_list"` + Senior bool `json:"senior"` + Title string `json:"title"` + Unionid string `json:"unionid"` + Userid string `json:"userid"` +} + +type DeptRes struct { + Errcode int `json:"errcode"` + Errmsg string `json:"errmsg"` + DeptResResult DeptResResult `json:"result"` + RequestId string `json:"request_id"` +} + +type DeptResResult struct { + DeptPermits []int `json:"dept_permits"` + OuterPermitUsers []string `json:"outer_permit_users"` + DeptManagerUseridList []string `json:"dept_manager_userid_list"` + OrgDeptOwner string `json:"org_dept_owner"` + OuterDept bool `json:"outer_dept"` + DeptGroupChatId string `json:"dept_group_chat_id"` + GroupContainSubDept bool `json:"group_contain_sub_dept"` + AutoAddUser bool `json:"auto_add_user"` + HideDept bool `json:"hide_dept"` + Name string `json:"name"` + OuterPermitDepts []int `json:"outer_permit_depts"` + UserPermits []interface{} `json:"user_permits"` + DeptId int `json:"dept_id"` + CreateDeptGroup bool `json:"create_dept_group"` + Order int `json:"order"` + Code string `json:"code"` + UnionDeptExt struct { + CorpId string `json:"corp_id"` + DeptId int `json:"dept_id"` + } `json:"union_dept_ext"` +} + +type AuthInfo struct { + ClientId string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + AccessToken string `json:"accessToken"` + Expire time.Duration `json:"expireIn"` +} + +type CardSend struct { + RobotCode string + ConversationType constants.ConversationType + ConversationId string + Template constants.CardTemp + SenderStaffId string + Title string + ContentSlice []string + ContentChannel chan string + UpdateInterval time.Duration // 控制通道更新的频率 +} + +type UpdateCardRequest struct { + Template string + Content string + Client *dingtalkim_1_0.Client + RobotCode string + CardInstanceId string +} diff --git a/internal/biz/handle/dingtalk/user.go b/internal/biz/handle/dingtalk/user.go new file mode 100644 index 0000000..2e4e615 --- /dev/null +++ b/internal/biz/handle/dingtalk/user.go @@ -0,0 +1,128 @@ +package dingtalk + +import ( + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/l_request" + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "time" +) + +type User struct { + dingUserImpl *impl.BotUserImpl + botConfigImpl *impl.BotConfigImpl + auth *Auth + dept *Dept +} + +func NewUser( + dingUserImpl *impl.BotUserImpl, + auth *Auth, + dept *Dept, +) *User { + return &User{ + dingUserImpl: dingUserImpl, + auth: auth, + dept: dept, + } +} + +func (u *User) GetUserInfoFromBot(ctx context.Context, staffId string, botOption ...BotOption) (userInfo *entitys.DingTalkUserInfo, err error) { + if len(staffId) == 0 { + return + } + user, err := u.dingUserImpl.GetByStaffId(staffId) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return + } + } + //待优化 + authInfo, err := u.auth.GetTokenFromBotOption(ctx, botOption...) + if err != nil || authInfo == nil { + return + } + //如果没有找到,则新增 + if user == nil { + + DingUserInfo, _err := u.getUserInfoFromDingTalk(ctx, authInfo.AccessToken, staffId) + if _err != nil { + return nil, _err + } + user = &model.AiBotUser{ + StaffID: DingUserInfo.Userid, + Name: DingUserInfo.Name, + Title: DingUserInfo.Title, + //Extension: DingUserInfo.Extension, + DeptIDList: strings.Join(pkg.SliceIntToString(DingUserInfo.DeptIdList), ","), + IsBoss: int32(pkg.Ter(DingUserInfo.Boss, constants.IsBossTrue, constants.IsBossFalse)), + IsSenior: int32(pkg.Ter(DingUserInfo.Senior, constants.IsSeniorTrue, constants.IsSeniorFalse)), + HiredDate: time.UnixMilli(DingUserInfo.HiredDate), + } + + _, err = u.dingUserImpl.Add(user) + if err != nil { + return + } + } + userInfo = &entitys.DingTalkUserInfo{ + UserId: int(user.UserID), + StaffId: user.StaffID, + Name: user.Name, + IsBoss: constants.IsBoss(user.IsBoss), + IsSenior: constants.IsSenior(user.IsSenior), + HiredDate: user.HiredDate, + Extension: user.Extension, + } + if len(user.DeptIDList) > 0 { + deptIdList := pkg.SliceStringToInt(strings.Split(user.DeptIDList, ",")) + depts, _err := u.dept.GetDeptInfoByDeptIds(ctx, deptIdList, authInfo) + if _err != nil { + return nil, err + } + for _, dept := range depts { + userInfo.Dept = append(userInfo.Dept, dept) + } + } + + return userInfo, nil +} + +func (u *User) getUserInfoFromDingTalk(ctx context.Context, token string, staffId string) (user UserInfoResResult, err error) { + if token == "" && staffId == "" { + err = errors.New("获取钉钉用户信息的必要参数不足") + return + } + + req := l_request.Request{ + Method: http.MethodPost, + Url: constants.GetDingTalkRequestUrl(constants.RequestUrlGetUserGet, map[string]string{ + "access_token": token, + }), + Data: map[string]string{ + "userid": staffId, + }, + } + res, err := req.Send() + if err != nil { + return + } + var userInfoRes UserInfoRes + err = json.Unmarshal(res.Content, &userInfoRes) + if err != nil { + return + } + if userInfoRes.Errcode != 0 { + fmt.Errorf("钉钉请求报错:%s", userInfoRes.Errmsg) + } + return userInfoRes.Result, err +} diff --git a/internal/biz/handle/file.go b/internal/biz/handle/file.go new file mode 100644 index 0000000..e93fbc3 --- /dev/null +++ b/internal/biz/handle/file.go @@ -0,0 +1,173 @@ +package handle + +import ( + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/l_request" + "bytes" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "path/filepath" + "strings" + + "github.com/gabriel-vasile/mimetype" +) + +// HandleRecognizeFile 这里的目的是无论将什么类型的file都转为二进制格式 +// 最终输出:1.将 files.FileData 填充为文件的二进制数据 2.将 files.FileType 填充为文件的类型(当前为 constants.Caller,兼容写入其字符串值) +// 判断文件大小(统一限制为10MB);判断文件类型;判断文件是否合法(类型在白名单映射中);无法识别/非法/超限→填充unknown并兼容返回 +// 若 FileData 不存在 且 FileUrl 不存在, 则直接退出 +// 若 FileData 存在 FileType 存在, 则直接退出 +// 若 FileData 存在 FileType 不存在, 则根据 FileData 推断文件类型并填充 FileType +// 若 FileUrl 存在, 则下载文件并填充 FileData 和 FileType +func HandleRecognizeFile(files *entitys.RecognizeFile) { + if files == nil { + return + } + + const maxSize = 10 * 1024 * 1024 // 10MB 上限 + + // 工具:根据 MIME 或扩展名映射到 FileType + mapToFileType := func(s string) constants.FileType { + if len(s) == 0 { + return constants.FileTypeUnknown + } + s = strings.ToLower(strings.TrimSpace(s)) + for ft, items := range constants.FileTypeMappings { + for _, item := range items { + if !strings.HasPrefix(item, ".") { // MIME + if s == item { + return ft + } + } else { // 扩展名 + if s == item { + return ft + } + } + } + } + return constants.FileTypeUnknown + } + + // 分支1:无数据、无URL→直接返回 + if len(files.FileData) == 0 && len(files.FileUrl) == 0 { + return + } + + // 分支2:已有数据且已有类型→直接返回 + if len(files.FileData) > 0 && len(strings.TrimSpace(files.FileType.String())) > 0 { + return + } + + // 分支3:仅有数据、无类型→内容检测并填充 + if len(files.FileData) > 0 && len(strings.TrimSpace(files.FileType.String())) == 0 { + if len(files.FileData) > maxSize { + files.FileType = constants.FileTypeUnknown + return + } + + reader := bytes.NewReader(files.FileData) + detected, fileRealMime := detectFileType(reader, "") + files.FileType = detected + files.FileRealMime = fileRealMime + return + } + + // 分支4:存在URL→下载并填充数据与类型 + if len(files.FileUrl) > 0 { + fileBytes, contentType, err := downloadFile(files.FileUrl) + if err != nil || len(fileBytes) == 0 { + files.FileType = constants.FileTypeUnknown + return + } + + if len(fileBytes) > maxSize { + // 超限:不写入数据,类型置 unknown + files.FileType = constants.FileTypeUnknown + return + } + + // 优先使用响应头的 Content-Type 映射 + detected := mapToFileType(contentType) + fileRealMime := contentType + + if detected == constants.FileTypeUnknown { + // 回退:内容检测 + URL 文件名扩展名辅助 + var fname string + if u, perr := url.Parse(files.FileUrl); perr == nil { + fname = filepath.Base(u.Path) + } + reader := bytes.NewReader(fileBytes) + detected, fileRealMime = detectFileType(reader, fname) + } + + // 写入数据 + files.FileData = fileBytes + files.FileType = detected + files.FileRealMime = fileRealMime + return + } +} + +// 下载文件并返回二进制数据、MIME 类型 +func downloadFile(fileUrl string) (fileBytes []byte, contentType string, err error) { + if len(fileUrl) == 0 { + return + } + req := l_request.Request{ + Method: "GET", + Url: fileUrl, + Headers: map[string]string{ + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", + "Accept": "image/webp,image/apng,image/*,*/*;q=0.8", + }, + } + res, err := req.Send() + if err != nil { + return + } + var ex bool + if contentType, ex = res.Headers["Content-Type"]; !ex { + err = errors.New("Content-Type不存在") + return + } + + if res.StatusCode != http.StatusOK { + err = fmt.Errorf("server returned non-200 status: %d", res.StatusCode) + } + fileBytes = res.Content + + return fileBytes, contentType, nil +} + +// detectFileType 判断文件类型 +func detectFileType(file io.ReadSeeker, filename string) (constants.FileType, string) { + // 1. 读取文件头检测 MIME + buffer := make([]byte, 512) + n, _ := file.Read(buffer) + file.Seek(0, io.SeekStart) // 重置读取位置 + + detectedMIME := mimetype.Detect(buffer[:n]).String() + for fileType, items := range constants.FileTypeMappings { + for _, item := range items { + if !strings.HasPrefix(item, ".") && item == detectedMIME { + return fileType, detectedMIME + } + } + } + + // 2. 备用:通过扩展名检测 + ext := strings.ToLower(filepath.Ext(filename)) + for fileType, items := range constants.FileTypeMappings { + for _, item := range items { + if strings.HasPrefix(item, ".") && item == ext { + return fileType, ext + } + } + } + + return constants.FileTypeUnknown, "" +} diff --git a/internal/biz/handle/handle.go b/internal/biz/handle/handle.go deleted file mode 100644 index 391b6eb..0000000 --- a/internal/biz/handle/handle.go +++ /dev/null @@ -1,22 +0,0 @@ -package handle - -import ( - "ai_scheduler/internal/config" - "ai_scheduler/internal/tools" -) - -type Handle struct { - toolManager *tools.Manager - conf *config.Config -} - -func NewHandle( - toolManager *tools.Manager, - conf *config.Config, - -) *Handle { - return &Handle{ - toolManager: toolManager, - conf: conf, - } -} diff --git a/internal/biz/llm_service/common.go b/internal/biz/llm_service/common.go index 1d62ed7..c3ffe46 100644 --- a/internal/biz/llm_service/common.go +++ b/internal/biz/llm_service/common.go @@ -1,9 +1,7 @@ package llm_service import ( - "ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/model" - "ai_scheduler/internal/entitys" "context" "time" ) @@ -20,48 +18,3 @@ func buildSystemPrompt(prompt string) string { return prompt } - -func buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) { - for _, item := range his { - if len(chatHis.SessionId) == 0 { - chatHis.SessionId = item.SessionID - } - chatHis.Messages = append(chatHis.Messages, []entitys.HisMessage{ - { - Role: constants.RoleUser, - Content: item.Ques, - Timestamp: item.CreateAt.Format(time.DateTime), - }, - { - Role: constants.RoleAssistant, - Content: item.Ans, - Timestamp: item.CreateAt.Format(time.DateTime), - }, - }...) - } - chatHis.Context = entitys.HisContext{ - UserLanguage: "zh-CN", - SystemMode: "technical_support", - } - return -} - -func BuildChatHisMessage(his []model.AiChatHi) (chatHis []entitys.HisMessage) { - for _, item := range his { - - chatHis = append(chatHis, []entitys.HisMessage{ - { - Role: constants.RoleUser, - Content: item.Ques, - Timestamp: item.CreateAt.Format(time.DateTime), - }, - { - Role: constants.RoleAssistant, - Content: item.Ans, - Timestamp: item.CreateAt.Format(time.DateTime), - }, - }...) - } - - return -} diff --git a/internal/biz/llm_service/langchain.go b/internal/biz/llm_service/langchain.go index 63f8815..a3216e8 100644 --- a/internal/biz/llm_service/langchain.go +++ b/internal/biz/llm_service/langchain.go @@ -1,87 +1,76 @@ package llm_service -import ( - "ai_scheduler/internal/data/model" - "ai_scheduler/internal/entitys" - "ai_scheduler/internal/pkg" - "ai_scheduler/internal/pkg/utils_langchain" - "context" - "encoding/json" - - "github.com/tmc/langchaingo/llms" -) - -type LangChainService struct { - client *utils_langchain.UtilLangChain -} - -func NewLangChainGenerate( - client *utils_langchain.UtilLangChain, -) *LangChainService { - - return &LangChainService{ - client: client, - } -} - -func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) { - prompt := r.getPrompt(sysInfo, history, userInput, tasks) - AgentClient := r.client.Get() - defer r.client.Put(AgentClient) - match, err := AgentClient.Llm.GenerateContent( - ctx, // 使用可取消的上下文 - prompt, - llms.WithJSONMode(), - ) - msg = match.Choices[0].Content - return -} - -func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { - var ( - prompt = make([]llms.MessageContent, 0) - ) - prompt = append(prompt, llms.MessageContent{ - Role: llms.ChatMessageTypeSystem, - Parts: []llms.ContentPart{ - llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)), - }, - }, llms.MessageContent{ - Role: llms.ChatMessageTypeTool, - Parts: []llms.ContentPart{ - llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))), - }, - }, llms.MessageContent{ - Role: llms.ChatMessageTypeTool, - Parts: []llms.ContentPart{ - llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))), - }, - }, llms.MessageContent{ - Role: llms.ChatMessageTypeHuman, - Parts: []llms.ContentPart{ - llms.TextPart(reqInput), - }, - }) - return prompt -} - -func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool { - taskPrompt := make([]llms.Tool, 0) - for _, task := range tasks { - var taskConfig entitys.TaskConfig - err := json.Unmarshal([]byte(task.Config), &taskConfig) - if err != nil { - continue - } - taskPrompt = append(taskPrompt, llms.Tool{ - Type: "function", - Function: &llms.FunctionDefinition{ - Name: task.Index, - Description: task.Desc, - Parameters: taskConfig.Param, - }, - }) - - } - return taskPrompt -} +//type LangChainService struct { +// client *utils_langchain.UtilLangChain +//} +// +//func NewLangChainGenerate( +// client *utils_langchain.UtilLangChain, +//) *LangChainService { +// +// return &LangChainService{ +// client: client, +// } +//} +// +//func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) { +// prompt := r.getPrompt(sysInfo, history, userInput, tasks) +// AgentClient := r.client.Get() +// defer r.client.Put(AgentClient) +// match, err := AgentClient.Llm.GenerateContent( +// ctx, // 使用可取消的上下文 +// prompt, +// llms.WithJSONMode(), +// ) +// msg = match.Choices[0].Content +// return +//} +// +//func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { +// var ( +// prompt = make([]llms.MessageContent, 0) +// ) +// prompt = append(prompt, llms.MessageContent{ +// Role: llms.ChatMessageTypeSystem, +// Parts: []llms.ContentPart{ +// llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)), +// }, +// }, llms.MessageContent{ +// Role: llms.ChatMessageTypeTool, +// Parts: []llms.ContentPart{ +// llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))), +// }, +// }, llms.MessageContent{ +// Role: llms.ChatMessageTypeTool, +// Parts: []llms.ContentPart{ +// llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))), +// }, +// }, llms.MessageContent{ +// Role: llms.ChatMessageTypeHuman, +// Parts: []llms.ContentPart{ +// llms.TextPart(reqInput), +// }, +// }) +// return prompt +//} +// +//func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool { +// taskPrompt := make([]llms.Tool, 0) +// for _, task := range tasks { +// var taskConfig entitys.TaskConfig +// err := json.Unmarshal([]byte(task.Config), &taskConfig) +// if err != nil { +// continue +// } +// taskPrompt = append(taskPrompt, llms.Tool{ +// Type: "function", +// Function: &llms.FunctionDefinition{ +// Name: task.Index, +// Description: task.Desc, +// Parameters: taskConfig.Param, +// }, +// }) +// +// } +// return taskPrompt +//} diff --git a/internal/biz/llm_service/ollama.go b/internal/biz/llm_service/ollama.go index 60d9c78..7333de5 100644 --- a/internal/biz/llm_service/ollama.go +++ b/internal/biz/llm_service/ollama.go @@ -3,46 +3,42 @@ package llm_service import ( "ai_scheduler/internal/config" "ai_scheduler/internal/data/impl" - "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg/utils_ollama" + "ai_scheduler/internal/pkg/utils_vllm" "context" - "encoding/json" + "errors" - "strings" - "time" "github.com/ollama/ollama/api" - "xorm.io/builder" ) type OllamaService struct { - client *utils_ollama.Client - config *config.Config - chatHis *impl.ChatImpl + client *utils_ollama.Client + vllmClient *utils_vllm.Client + config *config.Config + chatHis *impl.ChatHisImpl } func NewOllamaGenerate( client *utils_ollama.Client, + vllmClient *utils_vllm.Client, config *config.Config, - chatHis *impl.ChatImpl, + chatHis *impl.ChatHisImpl, ) *OllamaService { return &OllamaService{ - client: client, - config: config, - chatHis: chatHis, + client: client, + vllmClient: vllmClient, + config: config, + chatHis: chatHis, } } -func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) { - prompt, err := r.getPrompt(ctx, requireData) - if err != nil { - return - } - toolDefinitions := r.registerToolsOllama(requireData.Tasks) +func (r *OllamaService) IntentRecognize(ctx context.Context, req *entitys.ToolSelect) (msg string, err error) { - match, err := r.client.ToolSelect(ctx, prompt, toolDefinitions) + toolDefinitions := r.registerToolsOllama(req.Tools) + match, err := r.client.ToolSelect(ctx, req.Prompt, toolDefinitions) if err != nil { return } @@ -66,105 +62,63 @@ func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entity return } -func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) { +//func (r *OllamaService) RecognizeWithImg(ctx context.Context, imgByte []api.ImageData, ch chan entitys.Response) (desc api.GenerateResponse, err error) { +// if imgByte == nil { +// return +// } +// entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...") +// +// desc, err = r.client.Generation(ctx, &api.GenerateRequest{ +// Model: r.config.Ollama.VlModel, +// Stream: new(bool), +// System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt, +// Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt, +// Images: requireData.ImgByte, +// KeepAlive: &api.Duration{Duration: 3600 * time.Second}, +// //Think: &api.ThinkValue{Value: false}, +// }) +// if err != nil { +// return +// } +// entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response) +// return +//} - var ( - prompt = make([]api.Message, 0) - ) - content, err := r.getUserContent(ctx, requireData) - if err != nil { - return nil, err - } - prompt = append(prompt, api.Message{ - Role: "system", - Content: buildSystemPrompt(requireData.Sys.SysPrompt), - }, api.Message{ - Role: "assistant", - Content: "### 聊天记录:" + pkg.JsonStringIgonErr(buildAssistant(requireData.Histories)), - }, api.Message{ - Role: "user", - Content: content, - }) +//func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) { +// if requireData.ImgByte == nil { +// return +// } +// entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...") +// +// outMsg, err := r.vllmClient.RecognizeWithImg(ctx, +// r.config.DefaultPrompt.ImgRecognize.SystemPrompt, +// r.config.DefaultPrompt.ImgRecognize.UserPrompt, +// requireData.ImgUrls, +// ) +// if err != nil { +// return api.GenerateResponse{}, err +// } +// +// desc = api.GenerateResponse{ +// Response: outMsg.Content, +// } +// +// entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response) +// return +//} - return prompt, nil -} - -func (r *OllamaService) getUserContent(ctx context.Context, requireData *entitys.RequireData) (string, error) { - var content strings.Builder - content.WriteString(requireData.Req.Text) - if len(requireData.ImgByte) > 0 { - content.WriteString("\n") - } - - if len(requireData.Req.Tags) > 0 { - content.WriteString("\n") - content.WriteString("### 工具必须使用:") - content.WriteString(requireData.Req.Tags) - } - - if len(requireData.ImgByte) > 0 { - desc, err := r.RecognizeWithImg(ctx, requireData) - if err != nil { - return "", err - } - content.WriteString("### 上传图片解析内容:\n") - content.WriteString(requireData.Req.Tags) - content.WriteString(desc.Response) - } - - if requireData.Req.MarkHis > 0 { - var his model.AiChatHi - cond := builder.NewCond() - cond = cond.And(builder.Eq{"his_id": requireData.Req.MarkHis}) - err := r.chatHis.GetOneBySearchToStrut(&cond, &his) - if err != nil { - return "", err - } - content.WriteString("### 引用历史聊天记录:\n") - content.WriteString(pkg.JsonStringIgonErr(BuildChatHisMessage([]model.AiChatHi{his}))) - } - return content.String(), nil -} - -func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) { - if requireData.ImgByte == nil { - return - } - entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...") - - desc, err = r.client.Generation(ctx, &api.GenerateRequest{ - Model: r.config.Ollama.VlModel, - Stream: new(bool), - System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt, - Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt, - Images: requireData.ImgByte, - KeepAlive: &api.Duration{Duration: 3600 * time.Second}, - }) - if err != nil { - return - } - entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response) - return -} - -func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool { +func (r *OllamaService) registerToolsOllama(tasks []entitys.RegistrationTask) []api.Tool { taskPrompt := make([]api.Tool, 0) for _, task := range tasks { - var taskConfig entitys.TaskConfigDetail - err := json.Unmarshal([]byte(task.Config), &taskConfig) - if err != nil { - continue - } - taskPrompt = append(taskPrompt, api.Tool{ Type: "function", Function: api.ToolFunction{ - Name: task.Index, + Name: task.Name, Description: task.Desc, Parameters: api.ToolFunctionParameters{ - Type: taskConfig.Param.Type, - Required: taskConfig.Param.Required, - Properties: taskConfig.Param.Properties, + Type: task.TaskConfigDetail.Param.Type, + Required: task.TaskConfigDetail.Param.Required, + Properties: task.TaskConfigDetail.Param.Properties, }, }, }) diff --git a/internal/biz/provider_set.go b/internal/biz/provider_set.go index cc3de0a..1bdc0f7 100644 --- a/internal/biz/provider_set.go +++ b/internal/biz/provider_set.go @@ -1,21 +1,21 @@ package biz import ( - "ai_scheduler/internal/biz/do" - "ai_scheduler/internal/biz/handle" - "ai_scheduler/internal/biz/llm_service" - - "github.com/google/wire" + "ai_scheduler/internal/biz/do" + "ai_scheduler/internal/biz/llm_service" + + "github.com/google/wire" ) var ProviderSetBiz = wire.NewSet( NewAiRouterBiz, NewSessionBiz, NewChatHistoryBiz, - llm_service.NewLangChainGenerate, + //llm_service.NewLangChainGenerate, llm_service.NewOllamaGenerate, - handle.NewHandle, + //handle.NewHandle, do.NewDo, - do.NewHandle, + do.NewHandle, NewTaskBiz, + NewDingTalkBotBiz, ) diff --git a/internal/biz/router.go b/internal/biz/router.go index 6dcc233..7d045f8 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -2,8 +2,17 @@ package biz import ( "ai_scheduler/internal/biz/do" + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/constants" + errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/gateway" + "ai_scheduler/internal/pkg/rec_extra" + "context" + "encoding/json" + "strings" + "time" + "ai_scheduler/internal/entitys" "github.com/gofiber/fiber/v2/log" @@ -13,16 +22,19 @@ import ( type AiRouterBiz struct { do *do.Do handle *do.Handle + config *config.Config } // NewAiRouterBiz 创建路由服务 func NewAiRouterBiz( do *do.Do, handle *do.Handle, + config *config.Config, ) *AiRouterBiz { return &AiRouterBiz{ do: do, handle: handle, + config: config, } } @@ -39,11 +51,9 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS requireData := &entitys.RequireData{ Req: req, } - // 获取WebSocket连接 - conn := client.GetConn() //初始化通道/上下文 - ctx, clearFunc := r.do.MakeCh(conn, requireData) + ctx, clearFunc := r.do.MakeCh(client, requireData) defer func() { if err != nil { entitys.ResError(requireData.Ch, "", err.Error()) @@ -56,17 +66,133 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS log.Errorf("数据验证和收集失败: %s", err.Error()) return } - + //组装意图识别 + rec, sys, err := r.SetRec(ctx, requireData) + if err != nil { + log.Errorf("组装意图识别失败: %s", err.Error()) + return + } //意图识别 - if err = r.handle.Recognize(ctx, requireData); err != nil { + + err = r.handle.Recognize(ctx, &rec, sys) + if err != nil { log.Errorf("意图识别失败: %s", err.Error()) return } - + //任务处理 + rec_extra.SetTaskRecExt(requireData, &rec) //向下传递 - if err = r.handle.HandleMatch(ctx, client, requireData); err != nil { + if err = r.handle.HandleMatch(ctx, client, &rec, requireData); err != nil { log.Errorf("任务处理失败: %s", err.Error()) return } return } + +func (r *AiRouterBiz) SetRec(ctx context.Context, requireData *entitys.RequireData) (match entitys.Recognize, sys do.PromptOption, err error) { + // 参数空值检查 + if requireData == nil || requireData.Req == nil { + return match, sys, errors.NewBusinessErr(500, "请求参数为空") + } + + // 对应不同的appKey, 配置不同的系统提示词 + switch requireData.Sys.AppKey { + default: + sys = &do.WithSys{Config: r.config} + } + + // 1. 系统提示词 + match.SystemPrompt = requireData.Sys.SysPrompt + + // 2. 用户输入和文件处理 + match.UserContent, err = r.buildUserContent(requireData) + if err != nil { + log.Errorf("构建用户内容失败: %s", err.Error()) + return + } + + // 3. 聊天记录 - 只有在有历史记录时才构建 + if len(requireData.Histories) > 0 { + match.ChatHis = r.buildChatHistory(requireData) + } + + // 4. 任务列表 - 预分配切片容量 + if len(requireData.Tasks) > 0 { + match.Tasks = make([]entitys.RegistrationTask, 0, len(requireData.Tasks)) + for _, task := range requireData.Tasks { + taskConfig := entitys.TaskConfigDetail{} + if err = json.Unmarshal([]byte(task.Config), &taskConfig); err != nil { + log.Errorf("解析任务配置失败: %s, 任务ID: %s", err.Error(), task.Index) + continue // 解析失败时跳过该任务,而不是直接返回错误 + } + + match.Tasks = append(match.Tasks, entitys.RegistrationTask{ + Name: task.Index, + Desc: task.Desc, + TaskConfigDetail: taskConfig, // 直接使用解析后的配置,避免重复构建 + }) + } + } + + match.Ch = requireData.Ch + return +} + +// buildUserContent 构建用户内容 +func (r *AiRouterBiz) buildUserContent(requireData *entitys.RequireData) (*entitys.RecognizeUserContent, error) { + // 预分配文件切片容量(最多2个文件:File和Img) + files := make([]*entitys.RecognizeFile, 0, 2) + + // 处理文件和图片 + fileUrls := []string{requireData.Req.File, requireData.Req.Img} + for _, item := range fileUrls { + // 处理逗号分隔的多个URL + urlList := strings.Split(item, ",") + for _, url := range urlList { + if url != "" { + files = append(files, &entitys.RecognizeFile{FileUrl: url}) + } + } + } + + // 构建并返回用户内容 + return &entitys.RecognizeUserContent{ + Text: requireData.Req.Text, + File: files, + ActionCardUrl: "", // TODO: 后续实现操作卡片功能 + Tag: requireData.Req.Tags, + }, nil +} + +// buildChatHistory 构建聊天历史 +func (r *AiRouterBiz) buildChatHistory(requireData *entitys.RequireData) entitys.ChatHis { + // 预分配消息切片容量(每个历史记录生成2条消息) + messages := make([]entitys.HisMessage, 0, len(requireData.Histories)*2) + + // 构建聊天记录 + for _, h := range requireData.Histories { + // 用户消息 + messages = append(messages, entitys.HisMessage{ + Role: constants.RoleUser, // 用户角色 + Content: h.Ans, // 用户输入内容 + Timestamp: h.CreateAt.Format(time.DateTime), + }) + + // 助手消息 + messages = append(messages, entitys.HisMessage{ + Role: constants.RoleAssistant, // 助手角色 + Content: h.Ques, // 助手回复内容 + Timestamp: h.CreateAt.Format(time.DateTime), + }) + } + + // 构建聊天历史上下文 + return entitys.ChatHis{ + SessionId: requireData.Session, + Messages: messages, + Context: entitys.HisContext{ + UserLanguage: constants.LangZhCN, // 默认中文 + SystemMode: constants.SystemModeTechnicalSupport, // 默认技术支持模式 + }, + } +} diff --git a/internal/biz/session.go b/internal/biz/session.go index d014148..a42e9c0 100644 --- a/internal/biz/session.go +++ b/internal/biz/session.go @@ -17,16 +17,16 @@ import ( type SessionBiz struct { sessionRepo *impl.SessionImpl sysRepo *impl.SysImpl - chatRepo *impl.ChatImpl + chatHisRepo *impl.ChatHisImpl conf *config.Config } -func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, chatImpl *impl.ChatImpl) *SessionBiz { +func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, chatImpl *impl.ChatHisImpl) *SessionBiz { return &SessionBiz{ sessionRepo: sessionImpl, sysRepo: sysImpl, - chatRepo: chatImpl, + chatHisRepo: chatImpl, conf: conf, } } @@ -91,10 +91,10 @@ func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRe result.Prologue = sysConfig.Prologue // 存在,返回会话历史 var chatList []model.AiChatHi - chatList, err = s.chatRepo.FindAll( - s.chatRepo.WithSessionId(session.SessionID), // 条件:会话ID - s.chatRepo.OrderByDesc("create_at"), // 排序:按创建时间降序 - s.chatRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数 + chatList, err = s.chatHisRepo.FindAll( + s.chatHisRepo.WithSessionId(session.SessionID), // 条件:会话ID + s.chatHisRepo.OrderByDesc("create_at"), // 排序:按创建时间降序 + s.chatHisRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数 ) if err != nil { return diff --git a/internal/biz/task.go b/internal/biz/task.go index 277a97e..f9c03f0 100644 --- a/internal/biz/task.go +++ b/internal/biz/task.go @@ -70,13 +70,13 @@ func (t *TaskBiz) GetUserPermission(req *entitys.TaskRequest, auth string) (code // 发送请求 res, err := request.Send() if err != nil { - err = errors.SysErr("请求用户权限失败") + err = errors.SysErrf("请求用户权限失败") return } // 检查响应状态码 if res.StatusCode != http.StatusOK { - err = errors.SysErr("获取用户权限失败") + err = errors.SysErrf("获取用户权限失败") return } diff --git a/internal/biz/tools_regis/provider_set.go b/internal/biz/tools_regis/provider_set.go new file mode 100644 index 0000000..8294cf0 --- /dev/null +++ b/internal/biz/tools_regis/provider_set.go @@ -0,0 +1,9 @@ +package tools_regis + +import ( + "github.com/google/wire" +) + +var ProviderToolsRegis = wire.NewSet( + NewToolsRegis, +) diff --git a/internal/biz/tools_regis/tools_regis.go b/internal/biz/tools_regis/tools_regis.go new file mode 100644 index 0000000..0109849 --- /dev/null +++ b/internal/biz/tools_regis/tools_regis.go @@ -0,0 +1,30 @@ +package tools_regis + +import ( + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" + + "xorm.io/builder" +) + +type ToolRegis struct { + //待优化 + BootTools []model.AiBotTool +} + +func NewToolsRegis(botToolsImpl *impl.BotToolsImpl) *ToolRegis { + botTools := &ToolRegis{} + err := botTools.RegisTools(botToolsImpl) + if err != nil { + panic(err) + } + return botTools +} + +func (t *ToolRegis) RegisTools(botToolsImpl *impl.BotToolsImpl) error { + cond := builder.NewCond() + cond = cond.And(builder.Eq{"status": constants.Enable}) + err := botToolsImpl.GetRangeToMapStruct(&cond, &t.BootTools) + return err +} diff --git a/internal/config/config.go b/internal/config/config.go index 5c9fa21..441d09d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,14 +11,19 @@ import ( type Config struct { Server ServerConfig `mapstructure:"server"` Ollama OllamaConfig `mapstructure:"ollama"` + Vllm VllmConfig `mapstructure:"vllm"` + Coze CozeConfig `mapstructure:"coze"` Sys SysConfig `mapstructure:"sys"` Tools ToolsConfig `mapstructure:"tools"` + EinoTools EinoToolsConfig `mapstructure:"eino_tools"` Logging LoggingConfig `mapstructure:"logging"` Redis Redis `mapstructure:"redis"` DB DB `mapstructure:"db"` DefaultPrompt SysPrompt `mapstructure:"default_prompt"` PermissionConfig PermissionConfig `mapstructure:"permissionConfig"` - // LLM *LLM `mapstructure:"llm"` + LLM LLM `mapstructure:"llm"` + // DingTalkBots map[string]*DingTalkBot `mapstructure:"ding_talk_bots"` + Dingtalk DingtalkConfig `mapstructure:"dingtalk"` } type SysPrompt struct { @@ -31,15 +36,53 @@ type DefaultPrompt struct { } type LLM struct { - Model string `mapstructure:"model"` + Providers map[string]LLMProviderConfig `mapstructure:"providers"` + Capabilities map[string]LLMCapabilityConfig `mapstructure:"capabilities"` +} +type LLMProviderConfig struct { + Endpoint string `mapstructure:"endpoint"` + Timeout string `mapstructure:"timeout"` + Models []LLMModle `mapstructure:"models"` +} +type LLMModle struct { + ID string `mapstructure:"id"` + Name string `mapstructure:"name"` + Streaming bool `mapstructure:"streaming"` + Modalities []string `mapstructure:"modalities"` + MaxTokens int `mapstructure:"max_tokens"` +} +type LLMParameters struct { + Temperature float64 `mapstructure:"temperature"` + MaxTokens int `mapstructure:"max_tokens"` + Stream bool `mapstructure:"stream"` +} +type LLMCapabilityConfig struct { + Provider string `mapstructure:"provider"` + Model string `mapstructure:"model"` + Parameters LLMParameters `mapstructure:"parameters"` +} + +// DingtalkConfig 钉钉配置 +type DingtalkConfig struct { + ApiKey string `mapstructure:"api_key"` + ApiSecret string `mapstructure:"api_secret"` + TableDemand AITableConfig `mapstructure:"table_demand"` +} + +// TableDemandConfig 需求表配置 +type AITableConfig struct { + Url string `mapstructure:"url"` + BaseId string `mapstructure:"base_id"` + SheetIdOrName string `mapstructure:"sheet_id_or_name"` } // SysConfig 系统配置 type SysConfig struct { - SessionLen int `mapstructure:"session_len"` - ChannelPoolLen int `mapstructure:"channel_pool_len"` - ChannelPoolSize int `mapstructure:"channel_pool_size"` - LlmPoolLen int `mapstructure:"llm_pool_len"` + SessionLen int `mapstructure:"session_len"` + ChannelPoolLen int `mapstructure:"channel_pool_len"` + ChannelPoolSize int `mapstructure:"channel_pool_size"` + LlmPoolLen int `mapstructure:"llm_pool_len"` + HeartbeatInterval int `mapstructure:"heartbeat_interval"` } // ServerConfig 服务器配置 @@ -53,10 +96,25 @@ type OllamaConfig struct { BaseURL string `mapstructure:"base_url"` Model string `mapstructure:"model"` GenerateModel string `mapstructure:"generate_model"` + MappingModel string `mapstructure:"mapping_model"` VlModel string `mapstructure:"vl_model"` Timeout time.Duration `mapstructure:"timeout"` } +type VllmConfig struct { + BaseURL string `mapstructure:"base_url"` + VlModel string `mapstructure:"vl_model"` + Timeout time.Duration `mapstructure:"timeout"` + Level string `mapstructure:"level"` +} + +// CozeConfig Coze配置 +type CozeConfig struct { + BaseURL string `mapstructure:"base_url"` + ApiKey string `mapstructure:"api_key"` + ApiSecret string `mapstructure:"api_secret"` +} + type Redis struct { Host string `mapstructure:"host"` Type string `mapstructure:"type"` @@ -96,6 +154,10 @@ type ToolsConfig struct { ZltxOrderAfterSaleReseller ToolConfig `mapstructure:"zltxOrderAfterSaleReseller"` // 下游批充订单售后 ZltxOrderAfterSaleResellerBatch ToolConfig `mapstructure:"zltxOrderAfterSaleResellerBatch"` + // Coze 快递查询工具 + CozeExpress ToolConfig `mapstructure:"cozeExpress"` + // Coze 公司查询工具 + CozeCompany ToolConfig `mapstructure:"cozeCompany"` } // ToolConfig 单个工具配置 @@ -108,6 +170,26 @@ type ToolConfig struct { AddURL string `mapstructure:"add_url"` } +// EinoToolsConfig eino tool 配置 +type EinoToolsConfig struct { + // 货易通商品上传 + HytProductUpload ToolConfig `mapstructure:"hytProductUpload"` + // 货易通供应商查询 + HytSupplierSearch ToolConfig `mapstructure:"hytSupplierSearch"` + // 货易通仓库查询 + HytWarehouseSearch ToolConfig `mapstructure:"hytWarehouseSearch"` + // 货易通商品添加 + HytGoodsAdd ToolConfig `mapstructure:"hytGoodsAdd"` + // 货易通商品图片添加 + HytGoodsMediaAdd ToolConfig `mapstructure:"hytGoodsMediaAdd"` + // 货易通商品分类添加 + HytGoodsCategoryAdd ToolConfig `mapstructure:"hytGoodsCategoryAdd"` + // 货易通商品分类查询 + HytGoodsCategorySearch ToolConfig `mapstructure:"hytGoodsCategorySearch"` + // 货易通商品品牌查询 + HytGoodsBrandSearch ToolConfig `mapstructure:"hytGoodsBrandSearch"` +} + // LoggingConfig 日志配置 type LoggingConfig struct { Level string `mapstructure:"level"` diff --git a/internal/data/constants/bot.go b/internal/data/constants/bot.go index 6dc6680..d0ca85c 100644 --- a/internal/data/constants/bot.go +++ b/internal/data/constants/bot.go @@ -3,5 +3,43 @@ package constants type BotTools string const ( - BotToolsBugOptimizationSubmit = "bug_optimization_submit" // 系统的bug/优化建议 + BotToolsBugOptimizationSubmit BotTools = "bug_optimization_submit" // 系统的bug/优化建议 +) + +type ChatStyle int + +const ( + ChatStyleNormal ChatStyle = 1 //正常 + ChatStyleSerious ChatStyle = 2 //严肃 + ChatStyleGentle ChatStyle = 3 //温柔 + ChatStyleArrogance ChatStyle = 4 //傲慢 + ChatStyleCute ChatStyle = 5 //可爱 + ChatStyleAngry ChatStyle = 6 //愤怒 +) + +var ChatStyleMap = map[ChatStyle]string{ + ChatStyleNormal: "正常", + ChatStyleSerious: "严肃", + ChatStyleGentle: "温柔", + ChatStyleArrogance: "傲慢", + ChatStyleCute: "可爱", + ChatStyleAngry: "愤怒", +} + +type BotType int + +const ( + BotTypeDingTalk BotType = 1 // 系统的bug/优化建议 +) + +const DingTalkAuthBaseKeyPrefix = "dingTalk_auth" + +const DingTalkAuthBaseKeyBotPrefix = "dingTalk_auth_bot" + +// PermissionType 工具使用权限 +type PermissionType int32 + +const ( + PermissionTypeNone = 1 + PermissionTypeDept = 2 ) diff --git a/internal/data/constants/caller.go b/internal/data/constants/caller.go index 74493f3..0df9b33 100644 --- a/internal/data/constants/caller.go +++ b/internal/data/constants/caller.go @@ -13,6 +13,14 @@ const ( // 分页默认条数 ChatHistoryLimit = 10 + + // 语言 + LangZhCN = "zh-CN" // 中文 + + // 系统模式 + SystemModeDefault = "default" // 默认模式 + // 系统模式 "technical_support", // 技术支持模式 + SystemModeTechnicalSupport = "technical_support" // 技术支持模式 ) func (c Caller) String() string { diff --git a/internal/data/constants/capability.go b/internal/data/constants/capability.go new file mode 100644 index 0000000..a2e434b --- /dev/null +++ b/internal/data/constants/capability.go @@ -0,0 +1,86 @@ +package constants + +// Token +const ( + CapabilityProductIngestToken = "A7f9KQ3mP2X8LZC4R5e" +) + +// Prompt +const ( + SystemPrompt = ` + 你是一个专业的商品属性提取助手,你的唯一任务是提取属性并以指定格式输出。请严格遵守: + <<< 格式规则 >>> + 1. 输出必须是且仅是一个紧凑的、无任何多余空白字符(包括换行、缩进)的纯JSON字符串。 + 2. 整个JSON必须在一行内,例如:{"商品标题":"示例","价格":100}。 + 3. 严格禁止输出任何Markdown代码块标识、额外解释、思考过程或提示词本身。 + 4. 任何对上述规则的偏离都会导致系统解析失败。 + <<< 规则结束 >>> + + 接下来,请处理用户输入并直接输出符合上述规则的结果。` +) + +// 商品属性模板-中文 +const ( + // 货易通供应商商品属性模板-中文 + HYTSupplierProductPropertyTemplateZH = `{ + "货品编号": "string", // 商品编号 + "条码": "string", // 货品编号 + "分类名称": "string", // 商品分类 + "货品名称": "string", // 商品名称 + "商品货号": "string", // 货品编号 + "品牌": "string", // 商品品牌 + "单位": "string", // 商品单位,若无则使用'个' + "规格参数": "string", // 商品规格参数 + "货品说明": "string", // 商品说明 + "保质期": "string", // 商品保质期,无则空 + "保质期单位": "string", // 商品保质期单位,无则空 + "链接": "string", // 空 + "货品图片": ["string"], // 商品多图,取前2个即可 + "电商销售价格": "string", // 商品电商销售价格 decimal(10,2) + "销售价": "string", // 商品销售价格 decimal(10,2) + "备注": "string", // 无则空 + "长": "string", // 商品长度,decimal(10,2)+单位 + "宽": "string", // 商品宽度,decimal(10,2)+单位 + "高": "string", // 商品高度,decimal(10,2)+单位 + "重量": "string", // 商品重量,decimal(10,2)+单位(kg) + "SPU名称": "string", // 商品SPU名称 + "SPU编码": "string" // 货品编号 + "供应商报价": "string", // 空 + "税率": "string", // 商品税率 x%,无则空 + "利润": "string", // 空 + "默认供应商": "string", // 空 + "默认存放仓库": "string", // 空 + }` + // 货易通商品属性模板-中文 Ps:手机端主图、详情图文、平台资质图 (暂时无需) + HYTGoodsAddPropertyTemplateZH = `{ + "商品标题": "string", // 商品名称 + "商品编码": "string", // 商品编号+rand(1000-999) + "SPU名称": "string", // 商品SPU名称 + "SPU编码": "string", // 'ai_'+商品编号 + "商品货号": "string", // 商品编号 + "商品条形码": "string", // 商品编号 + "市场价": "string", // 优惠前价格 decimal(10,2) + "建议销售价": "string", // 市场价 + "电商销售价格": "string", // 优惠后价格 decimal(10,2) + "单位": "string", // 价格单位,默认'元' + "折扣": "string", // 商品折扣(%),默认'0%' + "税率": "string", // 商品税率(%),默认'13%' + "运费模版": "string", // 商品运费模版,默认空 + "保质期": "string", // 商品保质期,无则空 + "保质期单位": "string", // 商品保质期单位,无则空 + "品牌": "string", // 商品品牌,若无则空 + "是否热销主推": "string", // 默认'否' + "外部平台链接": "string", // 空即可 + "商品卖点": "string", // 商品卖点 + "商品规格参数": "string", // 商品规格参数 + "商品说明": "string", // 商品说明 + "备注": "string", // 无则空 + "分类名称": "string", // 商品分类 + "电脑端主图": ["string"], // 商品电脑端主图,取第一张 + }` +) + +// 缓存key +const ( + CapabilityProductIngestCacheKey = "ai_scheduler:capability:product_ingest:%s" +) diff --git a/internal/data/constants/const.go b/internal/data/constants/const.go index f06e906..b3c6ef0 100644 --- a/internal/data/constants/const.go +++ b/internal/data/constants/const.go @@ -11,10 +11,12 @@ const ( type TaskType int32 const ( - TaskTypeApi TaskType = 1 - TaskTypeKnowle TaskType = 2 - TaskTypeFunc TaskType = 3 - TaskTypeBot TaskType = 4 + TaskTypeApi TaskType = 1 + TaskTypeKnowle TaskType = 2 + TaskTypeFunc TaskType = 3 + TaskTypeBot TaskType = 4 + TaskTypeEinoWorkflow TaskType = 5 // eino 工作流 + TaskTypeCozeWorkflow TaskType = 6 // coze 工作流 ) type UseFul int32 @@ -30,3 +32,5 @@ var UseFulMap = map[UseFul]string{ UseFulNotUnclear: "回答不明确", UseFulNotError: "理解错误", } + +type BaseBool int32 diff --git a/internal/data/constants/dingtalk.go b/internal/data/constants/dingtalk.go new file mode 100644 index 0000000..fbbc7b8 --- /dev/null +++ b/internal/data/constants/dingtalk.go @@ -0,0 +1,80 @@ +package constants + +import "net/url" + +const DingTalkBseUrl = "https://oapi.dingtalk.com" + +type RequestUrl string + +const ( + RequestUrlGetUserGet RequestUrl = "/topapi/v2/user/get" + RequestUrlGetDeptGet RequestUrl = "/topapi/v2/department/get" +) + +func GetDingTalkRequestUrl(path RequestUrl, query map[string]string) string { + u, _ := url.Parse(DingTalkBseUrl + string(path)) + q := u.Query() + for key, val := range query { + q.Add(key, val) + } + u.RawQuery = q.Encode() + return u.String() +} + +// IsBoss 是否是老板 +type IsBoss int + +const ( + IsBossTrue IsBoss = 1 + IsBossFalse IsBoss = 0 +) + +// IsSenior 是否是老板 +type IsSenior int + +const ( + IsSeniorTrue IsSenior = 1 + IsSeniorFalse IsSenior = 0 +) + +type ConversationType string + +const ( + ConversationTypeSingle = "1" // 单聊 + ConversationTypeGroup = "2" //群聊 +) + +type BotMsgType string + +const ( + BotMsgTypeText BotMsgType = "text" +) + +type CardTemp string + +const ( + CardTempDefault CardTemp = `{ + "config": { + "autoLayout": true, + "enableForward": true + }, + "header": { + "title": { + "type": "text", + "text": "${title}", + }, + "logo": "@lALPDfJ6V_FPDmvNAfTNAfQ" + }, + "contents": [ + { + "type": "divider", + "id": "divider_1765952728523" + }, + { + "type": "markdown", + "text": "%s", + "id": "markdown_1765970168635" + } + ] +}` +) diff --git a/internal/data/constants/file.go b/internal/data/constants/file.go new file mode 100644 index 0000000..f234572 --- /dev/null +++ b/internal/data/constants/file.go @@ -0,0 +1,52 @@ +package constants + +type FileType string + +const ( + FileTypeUnknown FileType = "unknown" + FileTypeImage FileType = "image" + //FileTypeVideo FileType = "video" + FileTypeExcel FileType = "excel" + FileTypeWord FileType = "word" + FileTypeTxt FileType = "txt" + FileTypePDF FileType = "pdf" + FileTypePPT FileType = "ppt" + FileTypeCSV FileType = "csv" +) + +var FileTypeMappings = map[FileType][]string{ + FileTypeImage: { + "image/jpeg", "image/png", "image/gif", "image/webp", "image/svg+xml", + ".jpg", ".jpeg", ".png", ".gif", ".webp", ".svg", + }, + FileTypeExcel: { + "application/vnd.ms-excel", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ".xls", ".xlsx", + }, + FileTypeWord: { + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ".doc", ".docx", + }, + FileTypePDF: { + "application/pdf", + ".pdf", + }, + FileTypeTxt: { + "text/plain", + ".txt", + }, + FileTypePPT: { + "application/vnd.openxmlformats-officedocument.presentationml.presentation", + ".pptx", + }, + FileTypeCSV: { + "text/csv", + ".csv", + }, +} + +func (ft FileType) String() string { + return string(ft) +} diff --git a/internal/data/error/error_code.go b/internal/data/error/error_code.go index 9c28865..390f448 100644 --- a/internal/data/error/error_code.go +++ b/internal/data/error/error_code.go @@ -3,10 +3,11 @@ package errorcode import "fmt" var ( - Success = &BusinessErr{code: 200, message: "成功"} - ParamError = &BusinessErr{code: 401, message: "参数错误"} - NotFoundError = &BusinessErr{code: 404, message: "请求地址未找到"} - SystemError = &BusinessErr{code: 405, message: "系统错误"} + Success = &BusinessErr{code: 200, message: "成功"} + ParamError = &BusinessErr{code: 401, message: "参数错误"} + ForbiddenError = &BusinessErr{code: 403, message: "权限不足"} + NotFoundError = &BusinessErr{code: 404, message: "请求地址未找到"} + SystemError = &BusinessErr{code: 405, message: "系统错误"} ClientNotFound = &BusinessErr{code: 406, message: "未找到client_id"} SessionNotFound = &BusinessErr{code: 407, message: "未找到会话信息"} @@ -15,6 +16,7 @@ var ( SysNotFound = &BusinessErr{code: 410, message: "未找到系统信息"} SysCodeNotFound = &BusinessErr{code: 411, message: "未找到系统编码"} InvalidParam = &BusinessErr{code: InvalidParamCode, message: "无效参数"} + WorkflowError = &BusinessErr{code: 501, message: "工作流过程错误"} ) const ( @@ -43,14 +45,30 @@ func NewBusinessErr(code int, message string) *BusinessErr { return &BusinessErr{code: code, message: message} } -func SysErr(message string, arg ...any) *BusinessErr { +func SysErrf(message string, arg ...any) *BusinessErr { return &BusinessErr{code: SystemError.code, message: fmt.Sprintf(message, arg)} } -func ParamErr(message string, arg ...any) *BusinessErr { +func SysErr(message string) *BusinessErr { + return &BusinessErr{code: SystemError.code, message: message} +} + +func ParamErrf(message string, arg ...any) *BusinessErr { return &BusinessErr{code: ParamError.code, message: fmt.Sprintf(message, arg)} } +func ParamErr(message string) *BusinessErr { + return &BusinessErr{code: ParamError.code, message: message} +} + func (e *BusinessErr) Wrap(err error) *BusinessErr { return NewBusinessErr(e.code, err.Error()) } + +func WorkflowErr(message string) *BusinessErr { + return NewBusinessErr(WorkflowError.code, message) +} + +func ForbiddenErr(message string) *BusinessErr { + return NewBusinessErr(ForbiddenError.code, message) +} diff --git a/internal/data/impl/base.go b/internal/data/impl/base.go index 5ab2afa..b7abf1a 100644 --- a/internal/data/impl/base.go +++ b/internal/data/impl/base.go @@ -22,7 +22,7 @@ BaseModel 是一个泛型结构体,用于封装GORM数据库通用操作。 // 定义受支持的PO类型集合(可根据需要扩展), 只有包含表结构才能使用BaseModel,避免使用出现问题 type PO interface { model.AiChatHi | - model.AiSy | model.AiSession | model.AiTask | model.AiBot + model.AiSy | model.AiSession | model.AiTask | model.AiBotConfig } type BaseModel[P PO] struct { @@ -54,6 +54,8 @@ type BaseRepository[P PO] interface { WithStatus(status int) CondFunc // 查询status GetDb() *gorm.DB // 获取数据库连接 WithLimit(limit int) CondFunc // 限制返回条数 + In(field string, values interface{}) CondFunc // 查询字段是否在列表中 + Select(fields ...string) CondFunc // 选择字段 } // PaginationResult 分页查询结果 @@ -215,3 +217,17 @@ func (this *BaseModel[P]) WithLimit(limit int) CondFunc { return db.Limit(limit) } } + +// 查询字段是否在列表中 +func (this *BaseModel[P]) In(field string, values interface{}) CondFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Where(fmt.Sprintf("%s IN ?", field), values) + } +} + +// select 字段 +func (this *BaseModel[P]) Select(fields ...string) CondFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Select(fields) + } +} diff --git a/internal/data/impl/bot_chat_history.go b/internal/data/impl/bot_chat_history.go new file mode 100644 index 0000000..7e4de3c --- /dev/null +++ b/internal/data/impl/bot_chat_history.go @@ -0,0 +1,15 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" +) + +type BotChatHisImpl struct { + dataTemp.DataTemp +} + +func NewBotChatHisImpl(db *utils.Db) *BotChatHisImpl { + return &BotChatHisImpl{*dataTemp.NewDataTemp(db, new(model.AiBotChatHi))} +} diff --git a/internal/data/impl/bot_config.go b/internal/data/impl/bot_config.go new file mode 100644 index 0000000..2c98ffb --- /dev/null +++ b/internal/data/impl/bot_config.go @@ -0,0 +1,17 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" +) + +type BotConfigImpl struct { + dataTemp.DataTemp +} + +func NewBotConfigImpl(db *utils.Db) *BotConfigImpl { + return &BotConfigImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotConfig)), + } +} diff --git a/internal/data/impl/bot_dept.go b/internal/data/impl/bot_dept.go new file mode 100644 index 0000000..8ae0e4b --- /dev/null +++ b/internal/data/impl/bot_dept.go @@ -0,0 +1,17 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" +) + +type BotDeptImpl struct { + dataTemp.DataTemp +} + +func NewBotDeptImpl(db *utils.Db) *BotDeptImpl { + return &BotDeptImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotDept)), + } +} diff --git a/internal/data/impl/bot_group.go b/internal/data/impl/bot_group.go new file mode 100644 index 0000000..e0593c4 --- /dev/null +++ b/internal/data/impl/bot_group.go @@ -0,0 +1,27 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" + "database/sql" +) + +type BotGroupImpl struct { + dataTemp.DataTemp +} + +func NewBotGroupImpl(db *utils.Db) *BotGroupImpl { + return &BotGroupImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotGroup)), + } +} + +func (k BotGroupImpl) GetByConversationIdAndRobotCode(staffId string, robotCode string) (*model.AiBotGroup, error) { + var data model.AiBotGroup + err := k.Db.Model(k.Model).Where("conversation_id = ? and robot_code = ?", staffId, robotCode).Find(&data).Error + if data.GroupID == 0 { + err = sql.ErrNoRows + } + return &data, err +} diff --git a/internal/data/impl/bot_impl.go b/internal/data/impl/bot_impl.go deleted file mode 100644 index e76e8de..0000000 --- a/internal/data/impl/bot_impl.go +++ /dev/null @@ -1,28 +0,0 @@ -package impl - -import ( - "ai_scheduler/internal/data/model" - "ai_scheduler/tmpl/dataTemp" - "ai_scheduler/utils" - - "gorm.io/gorm" -) - -type BotImpl struct { - dataTemp.DataTemp - BaseRepository[model.AiBot] -} - -func NewBotImpl(db *utils.Db) *BotImpl { - return &BotImpl{ - DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBot)), - BaseRepository: NewBaseModel[model.AiBot](db.Client), - } -} - -// WithSysId 系统id -func (s *BotImpl) WithSysId(sysId interface{}) CondFunc { - return func(db *gorm.DB) *gorm.DB { - return db.Where("sys_id = ?", sysId) - } -} diff --git a/internal/data/impl/bot_tools.go b/internal/data/impl/bot_tools.go new file mode 100644 index 0000000..d119098 --- /dev/null +++ b/internal/data/impl/bot_tools.go @@ -0,0 +1,17 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" +) + +type BotToolsImpl struct { + dataTemp.DataTemp +} + +func NewBotToolsImpl(db *utils.Db) *BotToolsImpl { + return &BotToolsImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotTool)), + } +} diff --git a/internal/data/impl/bot_user.go b/internal/data/impl/bot_user.go new file mode 100644 index 0000000..862292f --- /dev/null +++ b/internal/data/impl/bot_user.go @@ -0,0 +1,27 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" + "database/sql" +) + +type BotUserImpl struct { + dataTemp.DataTemp +} + +func NewBotUserImpl(db *utils.Db) *BotUserImpl { + return &BotUserImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotUser)), + } +} + +func (k BotUserImpl) GetByStaffId(staffId string) (*model.AiBotUser, error) { + var data model.AiBotUser + err := k.Db.Model(k.Model).Where("staff_id = ?", staffId).Find(&data).Error + if data.UserID == 0 { + err = sql.ErrNoRows + } + return &data, err +} diff --git a/internal/data/impl/chat_history.go b/internal/data/impl/chat_history.go index 6f6027d..08a7fa8 100644 --- a/internal/data/impl/chat_history.go +++ b/internal/data/impl/chat_history.go @@ -11,14 +11,14 @@ import ( "gorm.io/gorm" ) -type ChatImpl struct { +type ChatHisImpl struct { dataTemp.DataTemp BaseRepository[model.AiChatHi] chatChannel chan model.AiChatHi } -func NewChatImpl(db *utils.Db) *ChatImpl { - return &ChatImpl{ +func NewChatHisImpl(db *utils.Db) *ChatHisImpl { + return &ChatHisImpl{ DataTemp: *dataTemp.NewDataTemp(db, new(model.AiChatHi)), BaseRepository: NewBaseModel[model.AiChatHi](db.Client), chatChannel: make(chan model.AiChatHi, 100), @@ -26,19 +26,19 @@ func NewChatImpl(db *utils.Db) *ChatImpl { } // WithSessionId 条件:会话ID -func (impl *ChatImpl) WithSessionId(sessionId interface{}) CondFunc { +func (impl *ChatHisImpl) WithSessionId(sessionId interface{}) CondFunc { return func(db *gorm.DB) *gorm.DB { return db.Where("session_id = ?", sessionId) } } // 异步添加会话历史 -func (impl *ChatImpl) AsyncCreate(ctx context.Context, chat model.AiChatHi) { +func (impl *ChatHisImpl) AsyncCreate(ctx context.Context, chat model.AiChatHi) { impl.chatChannel <- chat } // 异步处理会话历史 -func (impl *ChatImpl) AsyncProcess(ctx context.Context) { +func (impl *ChatHisImpl) AsyncProcess(ctx context.Context) { for { select { case chat := <-impl.chatChannel: @@ -55,3 +55,10 @@ func (impl *ChatImpl) AsyncProcess(ctx context.Context) { } } } + +// his_id 条件:历史ID +func (impl *ChatHisImpl) WithHisId(hisId interface{}) CondFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Where("his_id = ?", hisId) + } +} diff --git a/internal/data/impl/provider_set.go b/internal/data/impl/provider_set.go index f970e84..5624b3e 100644 --- a/internal/data/impl/provider_set.go +++ b/internal/data/impl/provider_set.go @@ -4,4 +4,15 @@ import ( "github.com/google/wire" ) -var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatImpl) +var ProviderImpl = wire.NewSet( + NewSessionImpl, + NewSysImpl, + NewTaskImpl, + NewChatHisImpl, + NewBotConfigImpl, + NewBotDeptImpl, + NewBotUserImpl, + NewBotChatHisImpl, + NewBotToolsImpl, + NewBotGroupImpl, +) diff --git a/internal/data/impl/task_impl.go b/internal/data/impl/task_impl.go index 8b3f246..3c76680 100644 --- a/internal/data/impl/task_impl.go +++ b/internal/data/impl/task_impl.go @@ -8,8 +8,12 @@ import ( type TaskImpl struct { dataTemp.DataTemp + BaseRepository[model.AiTask] } func NewTaskImpl(db *utils.Db) *TaskImpl { - return &TaskImpl{*dataTemp.NewDataTemp(db, new(model.AiTask))} + return &TaskImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiTask)), + BaseRepository: NewBaseModel[model.AiTask](db.Client), + } } diff --git a/internal/data/model/ai_bot.gen.go b/internal/data/model/ai_bot.gen.go deleted file mode 100644 index ffe73c2..0000000 --- a/internal/data/model/ai_bot.gen.go +++ /dev/null @@ -1,29 +0,0 @@ -// Code generated by gorm.io/gen. DO NOT EDIT. -// Code generated by gorm.io/gen. DO NOT EDIT. -// Code generated by gorm.io/gen. DO NOT EDIT. - -package model - -import ( - "time" -) - -const TableNameAiBot = "ai_bot" - -// AiBot mapped from table -type AiBot struct { - BotID int32 `gorm:"column:bot_id;primaryKey;autoIncrement:true" json:"bot_id"` - SysID int32 `gorm:"column:sys_id;not null" json:"sys_id"` - BotType int32 `gorm:"column:bot_type" json:"bot_type"` - BotName string `gorm:"column:bot_name;not null" json:"bot_name"` - BotConfig string `gorm:"column:bot_config;not null" json:"bot_config"` - CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` - UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"` - Status int32 `gorm:"column:status;not null" json:"status"` - DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"` -} - -// TableName AiBot's table name -func (*AiBot) TableName() string { - return TableNameAiBot -} diff --git a/internal/data/model/ai_bot_chat_his.gen.go b/internal/data/model/ai_bot_chat_his.gen.go new file mode 100644 index 0000000..2285343 --- /dev/null +++ b/internal/data/model/ai_bot_chat_his.gen.go @@ -0,0 +1,28 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiBotChatHi = "ai_bot_chat_his" + +// AiBotChatHi mapped from table +type AiBotChatHi struct { + HisID int64 `gorm:"column:his_id;primaryKey;autoIncrement:true" json:"his_id"` + HisType string `gorm:"column:his_type;not null;default:1;comment:1为个人,2为群聊" json:"his_type"` // 1为个人,2为群聊 + ID int32 `gorm:"column:id;not null;comment:对应的id" json:"id"` // 对应的id + Role string `gorm:"column:role;not null;comment:system系统输出,assistant助手输出,user用户输入" json:"role"` // system系统输出,assistant助手输出,user用户输入 + Content string `gorm:"column:content;not null" json:"content"` + Files string `gorm:"column:files;not null" json:"files"` + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` + UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"` +} + +// TableName AiBotChatHi's table name +func (*AiBotChatHi) TableName() string { + return TableNameAiBotChatHi +} diff --git a/internal/data/model/ai_bot_config.gen.go b/internal/data/model/ai_bot_config.gen.go new file mode 100644 index 0000000..e6142f7 --- /dev/null +++ b/internal/data/model/ai_bot_config.gen.go @@ -0,0 +1,30 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiBotConfig = "ai_bot_config" + +// AiBotConfig mapped from table +type AiBotConfig struct { + BotID int32 `gorm:"column:bot_id;primaryKey;autoIncrement:true" json:"bot_id"` + BotType int32 `gorm:"column:bot_type;not null;default:1;comment:类型,1为钉钉机器人" json:"bot_type"` // 类型,1为钉钉机器人 + SysPrompt string `gorm:"column:sys_prompt" json:"sys_prompt"` + BotName string `gorm:"column:bot_name;not null;comment:名字" json:"bot_name"` // 名字 + BotConfig string `gorm:"column:bot_config;not null;comment:配置" json:"bot_config"` // 配置 + RobotCode string `gorm:"column:robot_code;not null;comment:索引" json:"robot_code"` // 索引 + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` + UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"` + Status int32 `gorm:"column:status;not null" json:"status"` + DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"` +} + +// TableName AiBotConfig's table name +func (*AiBotConfig) TableName() string { + return TableNameAiBotConfig +} diff --git a/internal/data/model/ai_bot_dept.gen.go b/internal/data/model/ai_bot_dept.gen.go new file mode 100644 index 0000000..ceddcce --- /dev/null +++ b/internal/data/model/ai_bot_dept.gen.go @@ -0,0 +1,27 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiBotDept = "ai_bot_dept" + +// AiBotDept mapped from table +type AiBotDept struct { + DeptID int32 `gorm:"column:dept_id;primaryKey;autoIncrement:true" json:"dept_id"` + DingtalkDeptID int32 `gorm:"column:dingtalk_dept_id;not null;comment:标记部门的唯一id,钉钉:钉钉侧提供的dept_id" json:"dingtalk_dept_id"` // 标记部门的唯一id,钉钉:钉钉侧提供的dept_id + Name string `gorm:"column:name;not null;comment:用户名称" json:"name"` // 用户名称 + ToolList string `gorm:"column:tool_list;not null;comment:该部门支持的权限" json:"tool_list"` // 该部门支持的权限 + Status int32 `gorm:"column:status;not null;default:1" json:"status"` + DeleteAt *time.Time `gorm:"column:delete_at" json:"delete_at"` + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` +} + +// TableName AiBotDept's table name +func (*AiBotDept) TableName() string { + return TableNameAiBotDept +} diff --git a/internal/data/model/ai_bot_group.gen.go b/internal/data/model/ai_bot_group.gen.go new file mode 100644 index 0000000..80c50d1 --- /dev/null +++ b/internal/data/model/ai_bot_group.gen.go @@ -0,0 +1,28 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiBotGroup = "ai_bot_group" + +// AiBotGroup mapped from table +type AiBotGroup struct { + GroupID int32 `gorm:"column:group_id;primaryKey;autoIncrement:true" json:"group_id"` + ConversationID string `gorm:"column:conversation_id;not null;comment:会话ID" json:"conversation_id"` // 会话ID + RobotCode string `gorm:"column:robot_code;not null;comment:绑定机器人code" json:"robot_code"` // 绑定机器人code + Title string `gorm:"column:title;not null;comment:群名称" json:"title"` // 群名称 + ToolList string `gorm:"column:tool_list;not null;comment:开通工具列表" json:"tool_list"` // 开通工具列表 + Status int32 `gorm:"column:status;not null;default:1" json:"status"` + DeleteAt *time.Time `gorm:"column:delete_at" json:"delete_at"` + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` +} + +// TableName AiBotGroup's table name +func (*AiBotGroup) TableName() string { + return TableNameAiBotGroup +} diff --git a/internal/data/model/ai_bot_tools.gen.go b/internal/data/model/ai_bot_tools.gen.go new file mode 100644 index 0000000..f57889b --- /dev/null +++ b/internal/data/model/ai_bot_tools.gen.go @@ -0,0 +1,32 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiBotTool = "ai_bot_tools" + +// AiBotTool mapped from table +type AiBotTool struct { + ToolID int32 `gorm:"column:tool_id;primaryKey;autoIncrement:true" json:"tool_id"` + PermissionType int32 `gorm:"column:permission_type;not null;comment:类型,1为公共工具,不需要进行权限管理,反之则为2" json:"permission_type"` // 类型,1为公共工具,不需要进行权限管理,反之则为2 + Config string `gorm:"column:config;not null;comment:类型下所需路由以及参数" json:"config"` // 类型下所需路由以及参数 + Type int32 `gorm:"column:type;not null;default:3" json:"type"` + Name string `gorm:"column:name;not null;default:1;comment:工具名称" json:"name"` // 工具名称 + Index string `gorm:"column:index;not null;comment:索引" json:"index"` // 索引 + Desc string `gorm:"column:desc;not null;comment:工具描述" json:"desc"` // 工具描述 + TempPrompt string `gorm:"column:temp_prompt;not null;comment:提示词模板" json:"temp_prompt"` // 提示词模板 + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` + UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"` + Status int32 `gorm:"column:status;not null" json:"status"` + DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"` +} + +// TableName AiBotTool's table name +func (*AiBotTool) TableName() string { + return TableNameAiBotTool +} diff --git a/internal/data/model/ai_bot_user.gen.go b/internal/data/model/ai_bot_user.gen.go new file mode 100644 index 0000000..5783e51 --- /dev/null +++ b/internal/data/model/ai_bot_user.gen.go @@ -0,0 +1,33 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiBotUser = "ai_bot_user" + +// AiBotUser mapped from table +type AiBotUser struct { + UserID int32 `gorm:"column:user_id;primaryKey" json:"user_id"` + StaffID string `gorm:"column:staff_id;not null;comment:标记用户用的唯一id,钉钉:钉钉侧提供的user_id" json:"staff_id"` // 标记用户用的唯一id,钉钉:钉钉侧提供的user_id + Name string `gorm:"column:name;not null;comment:用户名称" json:"name"` // 用户名称 + Title string `gorm:"column:title;not null;comment:职位" json:"title"` // 职位 + Extension string `gorm:"column:extension;not null;default:1;comment:信息面板" json:"extension"` // 信息面板 + RoleList string `gorm:"column:role_list;not null;comment:角色列表。" json:"role_list"` // 角色列表。 + DeptIDList string `gorm:"column:dept_id_list;not null;comment:所在部门id列表" json:"dept_id_list"` // 所在部门id列表 + IsBoss int32 `gorm:"column:is_boss;not null;comment:是否是老板" json:"is_boss"` // 是否是老板 + IsSenior int32 `gorm:"column:is_senior;not null;comment:是否是高管" json:"is_senior"` // 是否是高管 + HiredDate time.Time `gorm:"column:hired_date;not null;default:CURRENT_TIMESTAMP;comment:入职时间" json:"hired_date"` // 入职时间 + Status int32 `gorm:"column:status;not null" json:"status"` + DeleteAt *time.Time `gorm:"column:delete_at" json:"delete_at"` + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` +} + +// TableName AiBotUser's table name +func (*AiBotUser) TableName() string { + return TableNameAiBotUser +} diff --git a/internal/data/model/ai_chat_his.gen.go b/internal/data/model/ai_chat_his.gen.go index d143972..595b4c4 100644 --- a/internal/data/model/ai_chat_his.gen.go +++ b/internal/data/model/ai_chat_his.gen.go @@ -20,6 +20,8 @@ type AiChatHi struct { Useful int32 `gorm:"column:useful;not null;comment:0不评价,1有用,其他为无用" json:"useful"` // 0不评价,1有用,其他为无用 CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"` + TaskID int32 `gorm:"column:task_id;not null" json:"task_id"` // 任务ID + Content string `gorm:"column:content" json:"content"` // 前端回传数据 } // TableName AiChatHi's table name diff --git a/internal/domain/common/mapper.go b/internal/domain/common/mapper.go new file mode 100644 index 0000000..2b42e28 --- /dev/null +++ b/internal/domain/common/mapper.go @@ -0,0 +1,11 @@ +package common + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/llm" +) + +func OptionsFromLLMParameters(p config.LLMParameters) llm.Options { + return llm.Options{Temperature: float32(p.Temperature), MaxTokens: p.MaxTokens, Stream: p.Stream} +} + diff --git a/internal/domain/common/types.go b/internal/domain/common/types.go new file mode 100644 index 0000000..ebd62ed --- /dev/null +++ b/internal/domain/common/types.go @@ -0,0 +1,4 @@ +package common + +type KV map[string]any + diff --git a/internal/domain/common/vision_builder.go b/internal/domain/common/vision_builder.go new file mode 100644 index 0000000..66153bf --- /dev/null +++ b/internal/domain/common/vision_builder.go @@ -0,0 +1,37 @@ +package common + +import ( + "errors" + "strings" + + "github.com/cloudwego/eino/schema" +) + +type ImageInput struct { + URLs []string +} + +func BuildVisionMessages(systemPrompt string, userText string, images ImageInput) ([]*schema.Message, error) { + if len(images.URLs) == 0 { + return nil, errors.New("vision requires at least one image url") + } + parts := make([]schema.MessageInputPart, 0, 1+len(images.URLs)) + if strings.TrimSpace(userText) != "" { + parts = append(parts, schema.MessageInputPart{Type: schema.ChatMessagePartTypeText, Text: userText}) + } + for _, u := range images.URLs { + if u == "" { + continue + } + if !strings.HasPrefix(u, "http://") && !strings.HasPrefix(u, "https://") { + continue + } + parts = append(parts, schema.MessageInputPart{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{URL: &u}, Detail: schema.ImageURLDetailHigh}}) + } + if len(parts) == 0 { + return nil, errors.New("vision inputs invalid: no text or valid image urls") + } + msgs := []*schema.Message{schema.SystemMessage(systemPrompt)} + msgs = append(msgs, &schema.Message{Role: schema.User, UserInputMultiContent: parts}) + return msgs, nil +} diff --git a/internal/domain/component/callback/manager.go b/internal/domain/component/callback/manager.go new file mode 100644 index 0000000..7803b09 --- /dev/null +++ b/internal/domain/component/callback/manager.go @@ -0,0 +1,71 @@ +package callback + +import ( + "context" + "fmt" + "time" + + "ai_scheduler/internal/pkg" + + "github.com/redis/go-redis/v9" +) + +type Manager interface { + Register(ctx context.Context, taskID string, sessionID string) error + Wait(ctx context.Context, taskID string, timeout time.Duration) (string, error) + Notify(ctx context.Context, taskID string, result string) error + GetSession(ctx context.Context, taskID string) (string, error) +} + +type RedisManager struct { + rdb *redis.Client +} + +func NewRedisManager(rdb *pkg.Rdb) *RedisManager { + return &RedisManager{ + rdb: rdb.Rdb, + } +} + +const ( + keyPrefixSession = "callback:session:" + keyPrefixSignal = "callback:signal:" + defaultTTL = 24 * time.Hour +) + +func (m *RedisManager) Register(ctx context.Context, taskID string, sessionID string) error { + key := keyPrefixSession + taskID + return m.rdb.Set(ctx, key, sessionID, defaultTTL).Err() +} + +func (m *RedisManager) Wait(ctx context.Context, taskID string, timeout time.Duration) (string, error) { + key := keyPrefixSignal + taskID + // BLPop 阻塞等待 + result, err := m.rdb.BLPop(ctx, timeout, key).Result() + if err != nil { + if err == redis.Nil { + return "", fmt.Errorf("timeout waiting for callback") + } + return "", err + } + // result[0] is key, result[1] is value + if len(result) < 2 { + return "", fmt.Errorf("invalid redis result") + } + return result[1], nil +} + +func (m *RedisManager) Notify(ctx context.Context, taskID string, result string) error { + key := keyPrefixSignal + taskID + // Push 信号,同时设置过期时间防止堆积 + pipe := m.rdb.Pipeline() + pipe.RPush(ctx, key, result) + pipe.Expire(ctx, key, 1*time.Hour) // 信号列表也需要过期 + _, err := pipe.Exec(ctx) + return err +} + +func (m *RedisManager) GetSession(ctx context.Context, taskID string) (string, error) { + key := keyPrefixSession + taskID + return m.rdb.Get(ctx, key).Result() +} diff --git a/internal/domain/component/callback/provider_set.go b/internal/domain/component/callback/provider_set.go new file mode 100644 index 0000000..302b5c1 --- /dev/null +++ b/internal/domain/component/callback/provider_set.go @@ -0,0 +1,5 @@ +package callback + +import "github.com/google/wire" + +var ProviderSet = wire.NewSet(NewRedisManager, wire.Bind(new(Manager), new(*RedisManager))) diff --git a/internal/domain/component/components.go b/internal/domain/component/components.go new file mode 100644 index 0000000..11c8d86 --- /dev/null +++ b/internal/domain/component/components.go @@ -0,0 +1,15 @@ +package component + +import ( + "ai_scheduler/internal/domain/component/callback" +) + +type Components struct { + Callback callback.Manager +} + +func NewComponents(callbackManager callback.Manager) *Components { + return &Components{ + Callback: callbackManager, + } +} diff --git a/internal/domain/component/provider_set.go b/internal/domain/component/provider_set.go new file mode 100644 index 0000000..9d6abe6 --- /dev/null +++ b/internal/domain/component/provider_set.go @@ -0,0 +1,14 @@ +package component + +import ( + "ai_scheduler/internal/domain/component/callback" + + "github.com/google/wire" +) + +var ProviderSetComponent = wire.NewSet(NewComponents) + +var ProviderSet = wire.NewSet( + callback.NewRedisManager, wire.Bind(new(callback.Manager), new(*callback.RedisManager)), + NewComponents, +) diff --git a/internal/domain/llm/api.go b/internal/domain/llm/api.go new file mode 100644 index 0000000..8b4b87f --- /dev/null +++ b/internal/domain/llm/api.go @@ -0,0 +1,14 @@ +package llm + +import ( + "context" + "github.com/cloudwego/eino/schema" +) + +type Service interface { + Chat(ctx context.Context, input []*schema.Message, opts Options) (*schema.Message, error) + ChatStream(ctx context.Context, input []*schema.Message, opts Options) (*schema.StreamReader[*schema.Message], error) + Vision(ctx context.Context, input []*schema.Message, opts Options) (*schema.Message, error) + Intent(ctx context.Context, input []*schema.Message, opts Options) (*schema.Message, error) +} + diff --git a/internal/domain/llm/capability/router.go b/internal/domain/llm/capability/router.go new file mode 100644 index 0000000..5798ed8 --- /dev/null +++ b/internal/domain/llm/capability/router.go @@ -0,0 +1,48 @@ +package capability + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/llm" + "strings" + "time" +) + +func Route(cfg *config.Config, ability Ability) (ProviderChoice, llm.Options, error) { + cap, ok := cfg.LLM.Capabilities[string(ability)] + if !ok { + return ProviderChoice{}, llm.Options{}, llm.ErrInvalidCapability + } + prov, ok := cfg.LLM.Providers[cap.Provider] + if !ok { + return ProviderChoice{}, llm.Options{}, llm.ErrProviderNotFound + } + var modelConf config.LLMModle + found := false + for _, m := range prov.Models { + if m.Name == cap.Model || m.ID == cap.Model { + modelConf = m + found = true + break + } + } + if !found { + return ProviderChoice{}, llm.Options{}, llm.ErrModelNotFound + } + to := llm.Options{} + to.Model = modelConf.Name + to.Stream = cap.Parameters.Stream || modelConf.Streaming + if to.Stream && !modelConf.Streaming { + to.Stream = false + } + to.MaxTokens = modelConf.MaxTokens + if cap.Parameters.MaxTokens > 0 && cap.Parameters.MaxTokens <= modelConf.MaxTokens { + to.MaxTokens = cap.Parameters.MaxTokens + } + to.Temperature = float32(cap.Parameters.Temperature) + to.Modalities = append([]string{}, modelConf.Modalities...) + d, _ := time.ParseDuration(strings.TrimSpace(prov.Timeout)) + to.Timeout = d + to.Endpoint = prov.Endpoint + choice := ProviderChoice{Provider: cap.Provider, Model: to.Model} + return choice, to, nil +} diff --git a/internal/domain/llm/capability/types.go b/internal/domain/llm/capability/types.go new file mode 100644 index 0000000..618728f --- /dev/null +++ b/internal/domain/llm/capability/types.go @@ -0,0 +1,15 @@ +package capability + +type Ability string + +const ( + Intent Ability = "intent" + Vision Ability = "vision" + Chat Ability = "chat" +) + +type ProviderChoice struct { + Provider string + Model string +} + diff --git a/internal/domain/llm/capability/validator.go b/internal/domain/llm/capability/validator.go new file mode 100644 index 0000000..c2edfd2 --- /dev/null +++ b/internal/domain/llm/capability/validator.go @@ -0,0 +1,22 @@ +package capability + +import ( + "ai_scheduler/internal/domain/llm" +) + +func Validate(ability Ability, opts llm.Options) error { + if ability == Vision { + has := false + for _, m := range opts.Modalities { + if m == "image" { + has = true + break + } + } + if !has { + return llm.ErrModalityMismatch + } + } + return nil +} + diff --git a/internal/domain/llm/errors.go b/internal/domain/llm/errors.go new file mode 100644 index 0000000..366cf33 --- /dev/null +++ b/internal/domain/llm/errors.go @@ -0,0 +1,9 @@ +package llm + +import "errors" + +var ErrInvalidCapability = errors.New("能力未配置或无效") +var ErrProviderNotFound = errors.New("提供者未找到或未注册") +var ErrModelNotFound = errors.New("模型未找到或未配置") +var ErrModalityMismatch = errors.New("模态不匹配:视觉能力需要包含 image") +var ErrNotImplemented = errors.New("not implemented") diff --git a/internal/domain/llm/options.go b/internal/domain/llm/options.go new file mode 100644 index 0000000..19f8983 --- /dev/null +++ b/internal/domain/llm/options.go @@ -0,0 +1,17 @@ +package llm + +import "time" + +type Options struct { + Temperature float32 + MaxTokens int + Stream bool + Timeout time.Duration + Modalities []string + SystemPrompt string + Model string + TopP float32 + Stop []string + Endpoint string + Thinking bool +} diff --git a/internal/domain/llm/pipeline/chat.go b/internal/domain/llm/pipeline/chat.go new file mode 100644 index 0000000..50cc405 --- /dev/null +++ b/internal/domain/llm/pipeline/chat.go @@ -0,0 +1,38 @@ +package pipeline + +import ( + "context" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/llm" + "ai_scheduler/internal/domain/llm/capability" + "ai_scheduler/internal/domain/llm/provider" + "ai_scheduler/internal/domain/llm/provider/ollama" + "ai_scheduler/internal/domain/llm/provider/vllm" +) + +func init() { + provider.Register("ollama", func() provider.Adapter { return ollama.New() }) + provider.Register("vllm", func() provider.Adapter { return vllm.New() }) +} + +func BuildChat(ctx context.Context, cfg *config.Config) (compose.Runnable[[]*schema.Message, *schema.Message], error) { + choice, opts, err := capability.Route(cfg, capability.Chat) + if err != nil { + return nil, err + } + if err = capability.Validate(capability.Chat, opts); err != nil { + return nil, err + } + f := provider.Get(choice.Provider) + if f == nil { + return nil, llm.ErrProviderNotFound + } + ad := f() + c := compose.NewChain[[]*schema.Message, *schema.Message]() + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in []*schema.Message) (*schema.Message, error) { + return ad.Generate(ctx, in, opts) + })) + return c.Compile(ctx) +} diff --git a/internal/domain/llm/pipeline/intent.go b/internal/domain/llm/pipeline/intent.go new file mode 100644 index 0000000..d867933 --- /dev/null +++ b/internal/domain/llm/pipeline/intent.go @@ -0,0 +1,26 @@ +package pipeline + +import ( + "context" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/llm" + "ai_scheduler/internal/domain/llm/capability" + "ai_scheduler/internal/domain/llm/provider" +) + +func BuildIntent(ctx context.Context, cfg *config.Config) (compose.Runnable[[]*schema.Message, *schema.Message], error) { + choice, opts, err := capability.Route(cfg, capability.Intent) + if err != nil { return nil, err } + if err = capability.Validate(capability.Intent, opts); err != nil { return nil, err } + f := provider.Get(choice.Provider) + if f == nil { return nil, llm.ErrProviderNotFound } + ad := f() + c := compose.NewChain[[]*schema.Message, *schema.Message]() + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in []*schema.Message) (*schema.Message, error) { + return ad.Generate(ctx, in, opts) + })) + return c.Compile(ctx) +} + diff --git a/internal/domain/llm/pipeline/vision.go b/internal/domain/llm/pipeline/vision.go new file mode 100644 index 0000000..12f7981 --- /dev/null +++ b/internal/domain/llm/pipeline/vision.go @@ -0,0 +1,78 @@ +package pipeline + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/common" + "ai_scheduler/internal/domain/llm" + "ai_scheduler/internal/domain/llm/capability" + "ai_scheduler/internal/domain/llm/provider" + "context" + + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +func BuildVision(ctx context.Context, cfg *config.Config) (compose.Runnable[[]*schema.Message, *schema.Message], error) { + choice, opts, err := capability.Route(cfg, capability.Vision) + if err != nil { + return nil, err + } + if err = capability.Validate(capability.Vision, opts); err != nil { + return nil, err + } + f := provider.Get(choice.Provider) + if f == nil { + return nil, llm.ErrProviderNotFound + } + ad := f() + c := compose.NewChain[[]*schema.Message, *schema.Message]() + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in []*schema.Message) (*schema.Message, error) { + if len(in) == 0 { + msgs, err := common.BuildVisionMessages("你是一名视觉助手,请根据图片与描述进行理解与回答。", "", common.ImageInput{URLs: []string{}}) + if err != nil { + return nil, err + } + return ad.Generate(ctx, msgs, opts) + } + if len(in[0].MultiContent) == 0 { + urls := []string{} + for _, tok := range splitBySpace(in[0].Content) { + if hasHTTPPrefix(tok) { + urls = append(urls, tok) + } + } + msgs, err := common.BuildVisionMessages("你是一名视觉助手,请根据图片与描述进行理解与回答。", "", common.ImageInput{URLs: urls}) + if err != nil { + return nil, err + } + return ad.Generate(ctx, msgs, opts) + } + return ad.Generate(ctx, in, opts) + })) + return c.Compile(ctx) +} + +func splitBySpace(s string) []string { + res := []string{} + start := -1 + for i, r := range s { + if r == ' ' || r == '\n' || r == '\t' || r == '\r' { + if start >= 0 { + res = append(res, s[start:i]) + start = -1 + } + } else { + if start < 0 { + start = i + } + } + } + if start >= 0 { + res = append(res, s[start:]) + } + return res +} + +func hasHTTPPrefix(s string) bool { + return len(s) >= 7 && (s[:7] == "http://" || (len(s) >= 8 && s[:8] == "https://")) +} diff --git a/internal/domain/llm/prompt/templates.go b/internal/domain/llm/prompt/templates.go new file mode 100644 index 0000000..47674a5 --- /dev/null +++ b/internal/domain/llm/prompt/templates.go @@ -0,0 +1,11 @@ +package prompt + +func SystemForChat() string { + return "你是一名有用的助手,请用清晰、简洁的中文回答。" +} +func SystemForVision() string { + return "你是一名视觉助手,请根据图片与描述进行中文理解与回答。" +} +func SystemForIntent() string { + return "你负责意图识别,请用中文给出明确的意图类别与理由。" +} diff --git a/internal/domain/llm/provider/ollama/adapter.go b/internal/domain/llm/provider/ollama/adapter.go new file mode 100644 index 0000000..554bcf0 --- /dev/null +++ b/internal/domain/llm/provider/ollama/adapter.go @@ -0,0 +1,73 @@ +package ollama + +import ( + "ai_scheduler/internal/domain/llm" + "context" + + eino_ollama "github.com/cloudwego/eino-ext/components/model/ollama" + eino_model "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +type Adapter struct{} + +func New() *Adapter { return &Adapter{} } + +func (a *Adapter) Generate(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) { + cm, err := eino_ollama.NewChatModel(ctx, &eino_ollama.ChatModelConfig{ + BaseURL: opts.Endpoint, + Timeout: opts.Timeout, + Model: opts.Model, + Options: &eino_ollama.Options{Temperature: opts.Temperature, NumPredict: opts.MaxTokens}, + Thinking: &eino_ollama.ThinkValue{Value: opts.Thinking}, + }) + if err != nil { + return nil, err + } + var mopts []eino_model.Option + if opts.Temperature != 0 { + mopts = append(mopts, eino_model.WithTemperature(opts.Temperature)) + } + if opts.MaxTokens > 0 { + mopts = append(mopts, eino_model.WithMaxTokens(opts.MaxTokens)) + } + if opts.Model != "" { + mopts = append(mopts, eino_model.WithModel(opts.Model)) + } + if opts.TopP != 0 { + mopts = append(mopts, eino_model.WithTopP(opts.TopP)) + } + if len(opts.Stop) > 0 { + mopts = append(mopts, eino_model.WithStop(opts.Stop)) + } + return cm.Generate(ctx, input, mopts...) +} + +func (a *Adapter) Stream(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.StreamReader[*schema.Message], error) { + cm, err := eino_ollama.NewChatModel(ctx, &eino_ollama.ChatModelConfig{ + BaseURL: opts.Endpoint, + Timeout: opts.Timeout, + Model: opts.Model, + Options: &eino_ollama.Options{Temperature: opts.Temperature, NumPredict: opts.MaxTokens}, + }) + if err != nil { + return nil, err + } + var mopts []eino_model.Option + if opts.Temperature != 0 { + mopts = append(mopts, eino_model.WithTemperature(opts.Temperature)) + } + if opts.MaxTokens > 0 { + mopts = append(mopts, eino_model.WithMaxTokens(opts.MaxTokens)) + } + if opts.Model != "" { + mopts = append(mopts, eino_model.WithModel(opts.Model)) + } + if opts.TopP != 0 { + mopts = append(mopts, eino_model.WithTopP(opts.TopP)) + } + if len(opts.Stop) > 0 { + mopts = append(mopts, eino_model.WithStop(opts.Stop)) + } + return cm.Stream(ctx, input, mopts...) +} diff --git a/internal/domain/llm/provider/registry.go b/internal/domain/llm/provider/registry.go new file mode 100644 index 0000000..9dc7ff5 --- /dev/null +++ b/internal/domain/llm/provider/registry.go @@ -0,0 +1,25 @@ +package provider + +import ( + "context" + "github.com/cloudwego/eino/schema" + "ai_scheduler/internal/domain/llm" +) + +type Adapter interface { + Generate(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) + Stream(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.StreamReader[*schema.Message], error) +} + +type Factory func() Adapter + +var registry = map[string]Factory{} + +func Register(name string, f Factory) { + registry[name] = f +} + +func Get(name string) Factory { + return registry[name] +} + diff --git a/internal/domain/llm/provider/vllm/adapter.go b/internal/domain/llm/provider/vllm/adapter.go new file mode 100644 index 0000000..b3828db --- /dev/null +++ b/internal/domain/llm/provider/vllm/adapter.go @@ -0,0 +1,70 @@ +package vllm + +import ( + "ai_scheduler/internal/domain/llm" + "context" + + eino_openai "github.com/cloudwego/eino-ext/components/model/openai" + eino_model "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +type Adapter struct{} + +func New() *Adapter { return &Adapter{} } + +func (a *Adapter) Generate(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) { + cm, err := eino_openai.NewChatModel(ctx, &eino_openai.ChatModelConfig{ + BaseURL: opts.Endpoint, + Timeout: opts.Timeout, + Model: opts.Model, + }) + if err != nil { + return nil, err + } + var mopts []eino_model.Option + if opts.Temperature != 0 { + mopts = append(mopts, eino_model.WithTemperature(opts.Temperature)) + } + if opts.MaxTokens > 0 { + mopts = append(mopts, eino_model.WithMaxTokens(opts.MaxTokens)) + } + if opts.Model != "" { + mopts = append(mopts, eino_model.WithModel(opts.Model)) + } + if opts.TopP != 0 { + mopts = append(mopts, eino_model.WithTopP(opts.TopP)) + } + if len(opts.Stop) > 0 { + mopts = append(mopts, eino_model.WithStop(opts.Stop)) + } + return cm.Generate(ctx, input, mopts...) +} + +func (a *Adapter) Stream(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.StreamReader[*schema.Message], error) { + cm, err := eino_openai.NewChatModel(ctx, &eino_openai.ChatModelConfig{ + BaseURL: opts.Endpoint, + Timeout: opts.Timeout, + Model: opts.Model, + }) + if err != nil { + return nil, err + } + var mopts []eino_model.Option + if opts.Temperature != 0 { + mopts = append(mopts, eino_model.WithTemperature(opts.Temperature)) + } + if opts.MaxTokens > 0 { + mopts = append(mopts, eino_model.WithMaxTokens(opts.MaxTokens)) + } + if opts.Model != "" { + mopts = append(mopts, eino_model.WithModel(opts.Model)) + } + if opts.TopP != 0 { + mopts = append(mopts, eino_model.WithTopP(opts.TopP)) + } + if len(opts.Stop) > 0 { + mopts = append(mopts, eino_model.WithStop(opts.Stop)) + } + return cm.Stream(ctx, input, mopts...) +} diff --git a/internal/domain/llm/service/chat_service.go b/internal/domain/llm/service/chat_service.go new file mode 100644 index 0000000..f642cf4 --- /dev/null +++ b/internal/domain/llm/service/chat_service.go @@ -0,0 +1,22 @@ +package service + +import ( + "context" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/llm/pipeline" +) + +type ChatService struct{ run compose.Runnable[[]*schema.Message, *schema.Message] } + +func NewChatService(ctx context.Context, cfg *config.Config) (*ChatService, error) { + r, err := pipeline.BuildChat(ctx, cfg) + if err != nil { return nil, err } + return &ChatService{run: r}, nil +} + +func (s *ChatService) Invoke(ctx context.Context, msgs []*schema.Message) (*schema.Message, error) { + return s.run.Invoke(ctx, msgs) +} + diff --git a/internal/domain/llm/service/intent_service.go b/internal/domain/llm/service/intent_service.go new file mode 100644 index 0000000..25929ae --- /dev/null +++ b/internal/domain/llm/service/intent_service.go @@ -0,0 +1,22 @@ +package service + +import ( + "context" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/llm/pipeline" +) + +type IntentService struct{ run compose.Runnable[[]*schema.Message, *schema.Message] } + +func NewIntentService(ctx context.Context, cfg *config.Config) (*IntentService, error) { + r, err := pipeline.BuildIntent(ctx, cfg) + if err != nil { return nil, err } + return &IntentService{run: r}, nil +} + +func (s *IntentService) Invoke(ctx context.Context, msgs []*schema.Message) (*schema.Message, error) { + return s.run.Invoke(ctx, msgs) +} + diff --git a/internal/domain/llm/service/service.go b/internal/domain/llm/service/service.go new file mode 100644 index 0000000..a8e2e21 --- /dev/null +++ b/internal/domain/llm/service/service.go @@ -0,0 +1,96 @@ +package service + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/llm" + "ai_scheduler/internal/domain/llm/capability" + "ai_scheduler/internal/domain/llm/provider" + "context" + + "github.com/cloudwego/eino/schema" +) + +type LLMService struct{ cfg *config.Config } + +func NewLLMService(cfg *config.Config) *LLMService { return &LLMService{cfg: cfg} } + +func (s *LLMService) Chat(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) { + choice, routeOpts, err := capability.Route(s.cfg, capability.Chat) + if err != nil { + return nil, err + } + mergeOptions(&routeOpts, opts) + f := provider.Get(choice.Provider) + if f == nil { + return nil, llm.ErrProviderNotFound + } + return f().Generate(ctx, input, routeOpts) +} + +func (s *LLMService) ChatStream(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.StreamReader[*schema.Message], error) { + choice, routeOpts, err := capability.Route(s.cfg, capability.Chat) + if err != nil { + return nil, err + } + routeOpts.Stream = true + mergeOptions(&routeOpts, opts) + f := provider.Get(choice.Provider) + if f == nil { + return nil, llm.ErrProviderNotFound + } + return f().Stream(ctx, input, routeOpts) +} + +func (s *LLMService) Vision(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) { + choice, routeOpts, err := capability.Route(s.cfg, capability.Vision) + if err != nil { + return nil, err + } + mergeOptions(&routeOpts, opts) + f := provider.Get(choice.Provider) + if f == nil { + return nil, llm.ErrProviderNotFound + } + return f().Generate(ctx, input, routeOpts) +} + +func (s *LLMService) Intent(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) { + choice, routeOpts, err := capability.Route(s.cfg, capability.Intent) + if err != nil { + return nil, err + } + mergeOptions(&routeOpts, opts) + f := provider.Get(choice.Provider) + if f == nil { + return nil, llm.ErrProviderNotFound + } + return f().Generate(ctx, input, routeOpts) +} + +func mergeOptions(base *llm.Options, override llm.Options) { + if override.Model != "" { + base.Model = override.Model + } + if override.MaxTokens > 0 { + base.MaxTokens = override.MaxTokens + } + if override.Temperature != 0 { + base.Temperature = override.Temperature + } + if override.Timeout > 0 { + base.Timeout = override.Timeout + } + if len(override.Modalities) > 0 { + base.Modalities = override.Modalities + } + if override.SystemPrompt != "" { + base.SystemPrompt = override.SystemPrompt + } + if override.TopP != 0 { + base.TopP = override.TopP + } + if len(override.Stop) > 0 { + base.Stop = override.Stop + } + base.Stream = base.Stream || override.Stream +} diff --git a/internal/domain/llm/service/vision_service.go b/internal/domain/llm/service/vision_service.go new file mode 100644 index 0000000..be37be2 --- /dev/null +++ b/internal/domain/llm/service/vision_service.go @@ -0,0 +1,22 @@ +package service + +import ( + "context" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/llm/pipeline" +) + +type VisionService struct{ run compose.Runnable[[]*schema.Message, *schema.Message] } + +func NewVisionService(ctx context.Context, cfg *config.Config) (*VisionService, error) { + r, err := pipeline.BuildVision(ctx, cfg) + if err != nil { return nil, err } + return &VisionService{run: r}, nil +} + +func (s *VisionService) Invoke(ctx context.Context, msgs []*schema.Message) (*schema.Message, error) { + return s.run.Invoke(ctx, msgs) +} + diff --git a/internal/domain/repo/adapter.go b/internal/domain/repo/adapter.go new file mode 100644 index 0000000..c9a1357 --- /dev/null +++ b/internal/domain/repo/adapter.go @@ -0,0 +1,29 @@ +package repo + +import ( + "ai_scheduler/internal/data/impl" + "context" + "errors" +) + +// SessionAdapter 适配 impl.SessionImpl 到 SessionRepo 接口 +type SessionAdapter struct { + impl *impl.SessionImpl +} + +func NewSessionAdapter(impl *impl.SessionImpl) *SessionAdapter { + return &SessionAdapter{impl: impl} +} + +func (s *SessionAdapter) GetUserName(ctx context.Context, sessionID string) (string, error) { + // 复用 SessionImpl 的查询能力 + // 这里假设 sessionID 是唯一的,直接用 FindOne + session, has, err := s.impl.FindOne(s.impl.WithSessionId(sessionID)) + if err != nil { + return "", err + } + if !has { + return "", errors.New("session not found") + } + return session.UserName, nil +} diff --git a/internal/domain/repo/provider_set.go b/internal/domain/repo/provider_set.go new file mode 100644 index 0000000..c5b2437 --- /dev/null +++ b/internal/domain/repo/provider_set.go @@ -0,0 +1,5 @@ +package repo + +import "github.com/google/wire" + +var ProviderSet = wire.NewSet(NewRepos) diff --git a/internal/domain/repo/repos.go b/internal/domain/repo/repos.go new file mode 100644 index 0000000..40ba3de --- /dev/null +++ b/internal/domain/repo/repos.go @@ -0,0 +1,17 @@ +package repo + +import ( + "ai_scheduler/internal/data/impl" + "ai_scheduler/utils" +) + +// Repos 聚合所有 Repository +type Repos struct { + Session SessionRepo +} + +func NewRepos(sessionImpl *impl.SessionImpl, rdb *utils.Rdb) *Repos { + return &Repos{ + Session: NewSessionAdapter(sessionImpl), + } +} diff --git a/internal/domain/repo/session.go b/internal/domain/repo/session.go new file mode 100644 index 0000000..5ccc66c --- /dev/null +++ b/internal/domain/repo/session.go @@ -0,0 +1,11 @@ +package repo + +import ( + "context" +) + +// SessionRepo 定义会话相关的查询接口 +// 这里只暴露 workflow 真正需要的方法,避免直接依赖 impl 层 +type SessionRepo interface { + GetUserName(ctx context.Context, sessionID string) (string, error) +} diff --git a/internal/domain/tools/hyt/goods_add/client.go b/internal/domain/tools/hyt/goods_add/client.go new file mode 100644 index 0000000..a758b55 --- /dev/null +++ b/internal/domain/tools/hyt/goods_add/client.go @@ -0,0 +1,63 @@ +package goods_add + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/util" + "context" + "encoding/json" + "fmt" +) + +type Client struct { + cfg config.ToolConfig +} + +func New(cfg config.ToolConfig) *Client { + return &Client{ + cfg: cfg, + } +} + +func (c *Client) Call(ctx context.Context, req *GoodsAddRequest) (*GoodsAddResponse, error) { + apiReq, _ := util.StructToMap(req) + + r := l_request.Request{ + Method: "Post", + Url: c.cfg.BaseURL, + Json: apiReq, + Headers: map[string]string{ + "Content-Type": "application/json", + }, + } + + res, err := r.Send() + if err != nil { + return nil, fmt.Errorf("请求失败,err: %v", err) + } + + type resType struct { + Code int `json:"code"` + Msg string `json:"message"` + Data struct { + Id int `json:"id"` // 商品 ID + } `json:"data"` + } + + var resData resType + if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { + return nil, fmt.Errorf("解析响应失败,err: %v", err) + } + + if resData.Code != 200 { + return nil, fmt.Errorf("业务错误,%s", resData.Msg) + } + + toolResp := &GoodsAddResponse{ + PreviewUrl: c.cfg.AddURL, + SpuCode: req.SpuCode, + Id: resData.Data.Id, + } + + return toolResp, nil +} diff --git a/internal/domain/tools/hyt/goods_add/client_test.go b/internal/domain/tools/hyt/goods_add/client_test.go new file mode 100644 index 0000000..aba715a --- /dev/null +++ b/internal/domain/tools/hyt/goods_add/client_test.go @@ -0,0 +1,51 @@ +package goods_add + +import ( + "ai_scheduler/internal/config" + "context" + "fmt" + "testing" +) + +// Test_Call +func Test_Call(t *testing.T) { + req := &GoodsAddRequest{ + Unit: "元", + IsComposeGoods: 2, + GoodsAttributes: "

商品规格参数

", + Introduction: "

商品卖点

", + GoodsIllustration: "

商品说明

", + IsHot: 2, + Title: "fu测试001", + GoodsNum: "futest001sku", + SpuCode: "futest001spu", + SpuName: "fu测试001", + Price: 100, + SalesPrice: 80, + Discount: 15, + TaxRate: 13, + FreightId: 3, + Remark: "备注说明", + SellByDate: 180, + ExternalPrice: 120, + GoodsBarCode: "futest001code2", + GoodsCode: "futest001code1", + SellByDateUnit: "天", + BrandId: 3, + ExternalUrl: "https://www.baidu.com", + } + + cfg := config.ToolConfig{ + BaseURL: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/add", + } + + client := New(cfg) + toolResp, err := client.Call(context.Background(), req) + + if err != nil { + t.Errorf("Call() error = %v", err) + return + } + + fmt.Printf("toolResp: %v\n", toolResp) +} diff --git a/internal/domain/tools/hyt/goods_add/types.go b/internal/domain/tools/hyt/goods_add/types.go new file mode 100644 index 0000000..d17b500 --- /dev/null +++ b/internal/domain/tools/hyt/goods_add/types.go @@ -0,0 +1,33 @@ +package goods_add + +type GoodsAddRequest struct { + Title string `json:"title"` // 商品标题 + GoodsCode string `json:"goods_code"` // 商品编码 + SpuName string `json:"spu_name"` // SPU 名称 + SpuCode string `json:"spu_code"` // SPU 编码 + GoodsNum string `json:"goods_num"` // 商品货号 + GoodsBarCode string `json:"goods_bar_code"` // 商品条形码 + Price float64 `json:"price"` // 市场价 + SalesPrice float64 `json:"sales_price"` // 建议销售价 + ExternalPrice float64 `json:"external_price"` // 电商销售价格 + Unit string `json:"unit"` // 价格单位 + Discount int `json:"discount"` // 折扣 + TaxRate int `json:"tax_rate"` // 税率 + FreightId int `json:"freight_id"` // 运费模板 ID + SellByDate int `json:"sell_by_date"` // 保质期 + SellByDateUnit string `json:"sell_by_date_unit"` // 保质期单位 + BrandId int `json:"brand_id"` // 品牌 ID + IsHot int `json:"is_hot"` // 是否热销主推 1.是 2.否(默认) + ExternalUrl string `json:"external_url"` // 外部平台链接 + Introduction string `json:"introduction"` // 商品卖点 + GoodsAttributes string `json:"goods_attributes"` // 商品规格参数 + GoodsIllustration string `json:"goods_illustration"` // 商品说明 + Remark string `json:"remark"` // 备注说明 + IsComposeGoods int `json:"is_compose_goods"` // 是否组合商品 1.是 2.否(默认) +} + +type GoodsAddResponse struct { + PreviewUrl string `json:"preview_url"` // 预览URL + SpuCode string `json:"spu_code"` // SPU编码 + Id int `json:"id"` // 商品ID +} diff --git a/internal/domain/tools/hyt/goods_brand_search/client.go b/internal/domain/tools/hyt/goods_brand_search/client.go new file mode 100644 index 0000000..e7b3d58 --- /dev/null +++ b/internal/domain/tools/hyt/goods_brand_search/client.go @@ -0,0 +1,67 @@ +package goods_brand_search + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/util" + "context" + "encoding/json" + "fmt" +) + +type Client struct { + cfg config.ToolConfig +} + +func New(cfg config.ToolConfig) *Client { + return &Client{ + cfg: cfg, + } +} + +func (c *Client) Call(ctx context.Context, name string) (int, error) { + if name == "" { + return 0, nil + } + + reqBody := GoodsBrandSearchRequest{ + Page: 1, + Limit: 1, + Search: SearchCondition{ + Name: name, + }, + } + + apiReq, _ := util.StructToMap(reqBody) + + req := l_request.Request{ + Method: "Post", + Url: c.cfg.BaseURL, + Json: apiReq, + Headers: map[string]string{ + "User-Agent": "Apifox/1.0.0 (https://apifox.com)", + "Content-Type": "application/json", + }, + } + + res, err := req.Send() + if err != nil { + return 0, fmt.Errorf("请求失败,err: %v", err) + } + + var resData GoodsBrandSearchResponse + if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { + return 0, fmt.Errorf("解析响应失败,err: %v", err) + } + + if resData.Code != 200 { + return 0, fmt.Errorf("业务错误,code: %d, msg: %s", resData.Code, resData.Msg) + } + + if len(resData.Data.List) == 0 { + return 0, fmt.Errorf("品牌不存在") + } + + // 返回第一个匹配的品牌ID + return resData.Data.List[0].ID, nil +} diff --git a/internal/domain/tools/hyt/goods_brand_search/client_test.go b/internal/domain/tools/hyt/goods_brand_search/client_test.go new file mode 100644 index 0000000..41009f1 --- /dev/null +++ b/internal/domain/tools/hyt/goods_brand_search/client_test.go @@ -0,0 +1,28 @@ +package goods_brand_search + +import ( + "ai_scheduler/internal/config" + "context" + "fmt" + "testing" +) + +// Test_Call +func Test_Call(t *testing.T) { + // 使用示例中的查询条件 + name := "vivo" + + cfg := config.ToolConfig{ + BaseURL: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/brand/list", + } + + client := New(cfg) + toolResp, err := client.Call(context.Background(), name) + + if err != nil { + t.Errorf("Call() error = %v", err) + return + } + + fmt.Printf("toolResp (BrandID): %v\n", toolResp) +} diff --git a/internal/domain/tools/hyt/goods_brand_search/types.go b/internal/domain/tools/hyt/goods_brand_search/types.go new file mode 100644 index 0000000..467a214 --- /dev/null +++ b/internal/domain/tools/hyt/goods_brand_search/types.go @@ -0,0 +1,25 @@ +package goods_brand_search + +type GoodsBrandSearchRequest struct { + Page int `json:"page"` + Limit int `json:"limit"` + Search SearchCondition `json:"search"` +} + +type SearchCondition struct { + Name string `json:"name"` +} + +type GoodsBrandSearchResponse struct { + Code int `json:"code"` + Msg string `json:"message"` + Data struct { + List []BrandInfo `json:"list"` + } `json:"data"` +} + +type BrandInfo struct { + ID int `json:"id"` + Name string `json:"name"` + Logo string `json:"logo"` +} diff --git a/internal/domain/tools/hyt/goods_category_add/client.go b/internal/domain/tools/hyt/goods_category_add/client.go new file mode 100644 index 0000000..8fa0e8b --- /dev/null +++ b/internal/domain/tools/hyt/goods_category_add/client.go @@ -0,0 +1,49 @@ +package goods_category_add + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/util" + "context" + "encoding/json" + "fmt" +) + +type Client struct { + cfg config.ToolConfig +} + +func New(cfg config.ToolConfig) *Client { + return &Client{ + cfg: cfg, + } +} + +func (c *Client) Call(ctx context.Context, req *GoodsCategoryAddRequest) (bool, error) { + apiReq, _ := util.StructToMap(req) + + r := l_request.Request{ + Method: "Post", + Url: c.cfg.BaseURL, + Json: apiReq, + Headers: map[string]string{ + "Content-Type": "application/json", + }, + } + + res, err := r.Send() + if err != nil { + return false, fmt.Errorf("请求失败,err: %v", err) + } + + var resData GoodsCategoryAddResponse + if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { + return false, fmt.Errorf("解析响应失败,err: %v", err) + } + + if resData.Code != 200 { + return false, fmt.Errorf("业务错误,code: %d, msg: %s", resData.Code, resData.Msg) + } + + return resData.Data.IsSuccess, nil +} diff --git a/internal/domain/tools/hyt/goods_category_add/client_test.go b/internal/domain/tools/hyt/goods_category_add/client_test.go new file mode 100644 index 0000000..fed3a94 --- /dev/null +++ b/internal/domain/tools/hyt/goods_category_add/client_test.go @@ -0,0 +1,31 @@ +package goods_category_add + +import ( + "ai_scheduler/internal/config" + "context" + "fmt" + "testing" +) + +// Test_Call +func Test_Call(t *testing.T) { + req := &GoodsCategoryAddRequest{ + GoodsId: 8496, + CategoryIds: []int{1667}, + IsCover: false, + } + + cfg := config.ToolConfig{ + BaseURL: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/good/category/relation/add", + } + + client := New(cfg) + toolResp, err := client.Call(context.Background(), req) + + if err != nil { + t.Errorf("Call() error = %v", err) + return + } + + fmt.Printf("toolResp: %v\n", toolResp) +} diff --git a/internal/domain/tools/hyt/goods_category_add/types.go b/internal/domain/tools/hyt/goods_category_add/types.go new file mode 100644 index 0000000..b3ecf68 --- /dev/null +++ b/internal/domain/tools/hyt/goods_category_add/types.go @@ -0,0 +1,15 @@ +package goods_category_add + +type GoodsCategoryAddRequest struct { + GoodsId int `json:"goods_id"` + CategoryIds []int `json:"category_ids"` + IsCover bool `json:"is_cover"` +} + +type GoodsCategoryAddResponse struct { + Code int `json:"code"` + Msg string `json:"message"` + Data struct { + IsSuccess bool `json:"is_success"` // 是否成功 + } `json:"data"` +} diff --git a/internal/domain/tools/hyt/goods_category_search/client.go b/internal/domain/tools/hyt/goods_category_search/client.go new file mode 100644 index 0000000..3af5e14 --- /dev/null +++ b/internal/domain/tools/hyt/goods_category_search/client.go @@ -0,0 +1,67 @@ +package goods_category_search + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/util" + "context" + "encoding/json" + "fmt" +) + +type Client struct { + cfg config.ToolConfig +} + +func New(cfg config.ToolConfig) *Client { + return &Client{ + cfg: cfg, + } +} + +func (c *Client) Call(ctx context.Context, name string) (int, error) { + if name == "" { + return 0, nil + } + + reqBody := GoodsCategorySearchRequest{ + Page: 1, + Limit: 1, + Search: SearchCondition{ + Name: name, + Level: 3, // 仅需三级分类 + }, + } + + apiReq, _ := util.StructToMap(reqBody) + + req := l_request.Request{ + Method: "Post", + Url: c.cfg.BaseURL, + Json: apiReq, + Headers: map[string]string{ + "Content-Type": "application/json", + }, + } + + res, err := req.Send() + if err != nil { + return 0, fmt.Errorf("请求失败,err: %v", err) + } + + var resData GoodsCategorySearchResponse + if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { + return 0, fmt.Errorf("解析响应失败,err: %v", err) + } + + if resData.Code != 200 { + return 0, fmt.Errorf("业务错误,code: %d, msg: %s", resData.Code, resData.Msg) + } + + if len(resData.Data.List) == 0 { + return 0, fmt.Errorf("商品分类不存在") + } + + // 返回第一个匹配的分类ID + return resData.Data.List[0].ID, nil +} diff --git a/internal/domain/tools/hyt/goods_category_search/types.go b/internal/domain/tools/hyt/goods_category_search/types.go new file mode 100644 index 0000000..2b9fb0d --- /dev/null +++ b/internal/domain/tools/hyt/goods_category_search/types.go @@ -0,0 +1,25 @@ +package goods_category_search + +type GoodsCategorySearchRequest struct { + Page int `json:"page"` + Limit int `json:"limit"` + Search SearchCondition `json:"search"` +} + +type SearchCondition struct { + Name string `json:"full_name"` + Level int `json:"level"` +} + +type GoodsCategorySearchResponse struct { + Code int `json:"code"` + Msg string `json:"message"` + Data struct { + List []CategoryInfo `json:"list"` + } `json:"data"` +} + +type CategoryInfo struct { + ID int `json:"id"` + Name string `json:"name"` +} diff --git a/internal/domain/tools/hyt/goods_media_add/client.go b/internal/domain/tools/hyt/goods_media_add/client.go new file mode 100644 index 0000000..6632168 --- /dev/null +++ b/internal/domain/tools/hyt/goods_media_add/client.go @@ -0,0 +1,49 @@ +package goods_media_add + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/util" + "context" + "encoding/json" + "fmt" +) + +type Client struct { + cfg config.ToolConfig +} + +func New(cfg config.ToolConfig) *Client { + return &Client{ + cfg: cfg, + } +} + +func (c *Client) Call(ctx context.Context, req *GoodsMediaAddRequest) (bool, error) { + apiReq, _ := util.StructToMap(req) + + r := l_request.Request{ + Method: "Post", + Url: c.cfg.BaseURL, + Json: apiReq, + Headers: map[string]string{ + "Content-Type": "application/json", + }, + } + + res, err := r.Send() + if err != nil { + return false, fmt.Errorf("请求失败,err: %v", err) + } + + var resData GoodsMediaAddResponse + if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { + return false, fmt.Errorf("解析响应失败,err: %v", err) + } + + if resData.Code != 200 { + return false, fmt.Errorf("业务错误,code: %d, msg: %s", resData.Code, resData.Msg) + } + + return resData.Data.IsSuccess, nil +} diff --git a/internal/domain/tools/hyt/goods_media_add/client_test.go b/internal/domain/tools/hyt/goods_media_add/client_test.go new file mode 100644 index 0000000..f6f16ca --- /dev/null +++ b/internal/domain/tools/hyt/goods_media_add/client_test.go @@ -0,0 +1,37 @@ +package goods_media_add + +import ( + "ai_scheduler/internal/config" + "context" + "fmt" + "testing" +) + +// Test_Call +func Test_Call(t *testing.T) { + req := &GoodsMediaAddRequest{ + GoodsId: 8496, + Data: []MediaItem{ + { + Type: 1, + Url: "https://lsxd-hz-store.oss-cn-hangzhou.aliyuncs.com/physicalGoodsSystems/images/goodsimages/goods/22f03d91-3cb7-45b4-ab92-07aad78a1633-screenshot_2025-12-17_17-46-00.png", + Sort: 1, + }, + }, + IsCover: true, + } + + cfg := config.ToolConfig{ + BaseURL: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/media/add/batch", + } + + client := New(cfg) + toolResp, err := client.Call(context.Background(), req) + + if err != nil { + t.Errorf("Call() error = %v", err) + return + } + + fmt.Printf("toolResp: %v\n", toolResp) +} diff --git a/internal/domain/tools/hyt/goods_media_add/types.go b/internal/domain/tools/hyt/goods_media_add/types.go new file mode 100644 index 0000000..bde4826 --- /dev/null +++ b/internal/domain/tools/hyt/goods_media_add/types.go @@ -0,0 +1,21 @@ +package goods_media_add + +type GoodsMediaAddRequest struct { + GoodsId int `json:"goods_id"` + Data []MediaItem `json:"data"` + IsCover bool `json:"is_cover"` +} + +type MediaItem struct { + Type int `json:"type"` + Url string `json:"url"` + Sort int `json:"sort"` +} + +type GoodsMediaAddResponse struct { + Code int `json:"code"` + Msg string `json:"message"` + Data struct { + IsSuccess bool `json:"is_success"` + } `json:"data"` +} diff --git a/internal/domain/tools/hyt/product_upload/client.go b/internal/domain/tools/hyt/product_upload/client.go new file mode 100644 index 0000000..096965c --- /dev/null +++ b/internal/domain/tools/hyt/product_upload/client.go @@ -0,0 +1,68 @@ +package product_upload + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/util" + "context" + "encoding/json" + "errors" + "fmt" +) + +type Client struct { + cfg config.ToolConfig +} + +func New(cfg config.ToolConfig) *Client { + return &Client{ + cfg: cfg, + } +} + +func (c *Client) Call(ctx context.Context, toolReq *ProductUploadRequest) (toolResp *ProductUploadResponse, err error) { + // 商品有且只能有一个 + if len(toolReq.GoodsList) != 1 { + return nil, errors.New("商品只能有一个") + } + + apiReq, _ := util.StructToMap(toolReq) + + req := l_request.Request{ + Method: "Post", + Url: c.cfg.BaseURL, + Json: apiReq, + } + res, err := req.Send() + + if err != nil { + return nil, fmt.Errorf("请求失败,err: %v", err) + } + + type resType struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + Ids []int `json:"ids"` // 商品 IDs + } `json:"data"` + } + var resMap resType + err = json.Unmarshal([]byte(res.Text), &resMap) + if err != nil { + return nil, fmt.Errorf("解析响应失败,err: %v", err) + } + if resMap.Code != 200 { + return nil, fmt.Errorf("业务错误,code: %d, msg: %s", resMap.Code, resMap.Msg) + } + if len(resMap.Data.Ids) == 0 { + return nil, fmt.Errorf("ids为空") + } + + toolResp = &ProductUploadResponse{ + PreviewUrl: c.cfg.AddURL, + SpuNum: toolReq.GoodsList[0].GoodsInfo.SpuNum, + Id: resMap.Data.Ids[0], + } + + return toolResp, nil +} diff --git a/internal/domain/tools/hyt/product_upload/types.go b/internal/domain/tools/hyt/product_upload/types.go new file mode 100644 index 0000000..947dbe5 --- /dev/null +++ b/internal/domain/tools/hyt/product_upload/types.go @@ -0,0 +1,54 @@ +package product_upload + +type ProductUploadRequest struct { + SupplierId int `json:"supplier_id"` // 供应商ID + WarehouseId int `json:"warehouse_id"` // 仓库ID + IsDefaultWarehouse int `json:"is_default_warehouse"` // 是否默认仓库 + Sort int `json:"sort"` // 排序 + Profit float64 `json:"profit"` // 利润 + TaxRate int `json:"tax_rate"` // 税率 + GoodsList []Goods `json:"goods_list"` // 商品列表 +} + +type Goods struct { + GoodsInfo GoodsInfo `json:"goods_info"` + GoodsMediaList []GoodsMedia `json:"goods_media_list"` +} + +type GoodsInfo struct { + Title string `json:"title"` // 商品名称 + Brand string `json:"brand"` // 品牌 + Category string `json:"category"` // 分类 + Discount int `json:"discount"` // 折扣 + GoodsAttributes string `json:"goods_attributes"` // 商品属性 + GoodsBarCode string `json:"goods_bar_code"` // 商品条码 + GoodsNum string `json:"goods_num"` // 商品编号 + Introduction string `json:"introduction"` // 商品介绍 + SpuName string `json:"spu_name"` // SPU名称 + SpuNum string `json:"spu_num"` // SPU编号 + Stock int `json:"stock"` // 库存 + TaxRate int `json:"tax_rate"` // 税率 + Unit string `json:"unit"` // 单位 + Weight string `json:"weight"` // 重量 + Price float64 `json:"price"` // 市场价 + SalesPrice float64 `json:"sales_price"` // 建议销售价格 + GoodsIllustration string `json:"goods_illustration"` // 商品插图 - 暂不提供 + Id int `json:"id"` // 商品ID - 无需 + CostPrice float64 `json:"cost_price"` // 成本价格 - 无需 + IsBind int `json:"is_bind"` // 是否绑定 - 默认0 + IsComposeGoods int32 `json:"is_compose_goods"` // 是否组合商品 - 默认2 + IsHot int `json:"is_hot"` // 是否热门商品 - 默认2 +} + +type GoodsMedia struct { + Remark string `json:"remark"` // 备注 + Sort int `json:"sort"` // 排序 + Type int `json:"type"` // 类型 + Url string `json:"url"` // URL +} + +type ProductUploadResponse struct { + PreviewUrl string `json:"preview_url"` // 预览URL + SpuNum string `json:"spu_code"` // SPU编码 + Id int `json:"id"` // 商品ID +} diff --git a/internal/domain/tools/hyt/supplier_search/client.go b/internal/domain/tools/hyt/supplier_search/client.go new file mode 100644 index 0000000..1f47ee8 --- /dev/null +++ b/internal/domain/tools/hyt/supplier_search/client.go @@ -0,0 +1,67 @@ +package supplier_search + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/l_request" + "context" + "encoding/json" + "fmt" +) + +type Client struct { + cfg config.ToolConfig +} + +func New(cfg config.ToolConfig) *Client { + return &Client{ + cfg: cfg, + } +} + +func (c *Client) Call(ctx context.Context, name string) (int, error) { + if name == "" { + // 如果没有供应商名,返回0,不报错,由上层业务决定是否允许 + return 0, nil + } + + reqBody := SearchRequest{ + Page: 1, + Limit: 1, + Search: SearchCondition{ + Name: name, + }, + } + + apiReq := make(map[string]interface{}) + bytes, _ := json.Marshal(reqBody) + _ = json.Unmarshal(bytes, &apiReq) + + req := l_request.Request{ + Method: "Post", + Url: c.cfg.BaseURL, + Json: apiReq, + Headers: map[string]string{ + "Content-Type": "application/json", + }, + } + + res, err := req.Send() + if err != nil { + return 0, fmt.Errorf("请求失败,err: %v", err) + } + + var resData SearchResponse + if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { + return 0, fmt.Errorf("解析响应失败,err: %v", err) + } + + if resData.Code != 200 { + return 0, fmt.Errorf("业务错误,code: %d, msg: %s", resData.Code, resData.Msg) + } + + if len(resData.Data.List) == 0 { + return 0, fmt.Errorf("供应商不存在") + } + + return resData.Data.List[0].ID, nil +} diff --git a/internal/domain/tools/hyt/supplier_search/types.go b/internal/domain/tools/hyt/supplier_search/types.go new file mode 100644 index 0000000..46a452c --- /dev/null +++ b/internal/domain/tools/hyt/supplier_search/types.go @@ -0,0 +1,24 @@ +package supplier_search + +type SearchRequest struct { + Page int `json:"page"` + Limit int `json:"limit"` + Search SearchCondition `json:"search"` +} + +type SearchCondition struct { + Name string `json:"name"` +} + +type SearchResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + List []SupplierInfo `json:"list"` + } `json:"data"` +} + +type SupplierInfo struct { + ID int `json:"id"` + Name string `json:"name"` +} diff --git a/internal/domain/tools/hyt/warehouse_search/client.go b/internal/domain/tools/hyt/warehouse_search/client.go new file mode 100644 index 0000000..cf420b2 --- /dev/null +++ b/internal/domain/tools/hyt/warehouse_search/client.go @@ -0,0 +1,62 @@ +package warehouse_search + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/l_request" + "context" + "encoding/json" + "fmt" +) + +type Client struct { + cfg config.ToolConfig +} + +func New(cfg config.ToolConfig) *Client { + return &Client{ + cfg: cfg, + } +} + +func (c *Client) Call(ctx context.Context, name string) (int, error) { + if name == "" { + // 如果没有仓库名,返回0,不报错,由上层业务决定是否允许 + return 0, nil + } + + // GET 请求参数 + params := map[string]string{ + "name": name, + "page": "1", + "limit": "1", + } + + req := l_request.Request{ + Method: "Get", + Url: c.cfg.BaseURL, + Params: params, + Headers: map[string]string{ + "Content-Type": "application/json", + }, + } + + res, err := req.Send() + if err != nil { + return 0, fmt.Errorf("请求失败,err: %v", err) + } + + var resData SearchResponse + if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { + return 0, fmt.Errorf("解析响应失败,err: %v", err) + } + + if resData.Code != 200 { + return 0, fmt.Errorf("业务错误,code: %d, msg: %s", resData.Code, resData.Msg) + } + + if len(resData.Data.List) == 0 { + return 0, fmt.Errorf("仓库不存在: %s", name) + } + + return resData.Data.List[0].ID, nil +} diff --git a/internal/domain/tools/hyt/warehouse_search/types.go b/internal/domain/tools/hyt/warehouse_search/types.go new file mode 100644 index 0000000..a5ae237 --- /dev/null +++ b/internal/domain/tools/hyt/warehouse_search/types.go @@ -0,0 +1,14 @@ +package warehouse_search + +type SearchResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + List []WarehouseInfo `json:"list"` + } `json:"data"` +} + +type WarehouseInfo struct { + ID int `json:"id"` + Name string `json:"name"` +} diff --git a/internal/domain/tools/registry.go b/internal/domain/tools/registry.go new file mode 100644 index 0000000..ad9439d --- /dev/null +++ b/internal/domain/tools/registry.go @@ -0,0 +1,44 @@ +package tools + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/tools/hyt/goods_add" + "ai_scheduler/internal/domain/tools/hyt/goods_brand_search" + "ai_scheduler/internal/domain/tools/hyt/goods_category_add" + "ai_scheduler/internal/domain/tools/hyt/goods_category_search" + "ai_scheduler/internal/domain/tools/hyt/goods_media_add" + "ai_scheduler/internal/domain/tools/hyt/product_upload" + "ai_scheduler/internal/domain/tools/hyt/supplier_search" + "ai_scheduler/internal/domain/tools/hyt/warehouse_search" +) + +type Manager struct { + Hyt *HytTools + // Zltx *ZltxTools +} + +type HytTools struct { + ProductUpload *product_upload.Client + SupplierSearch *supplier_search.Client + WarehouseSearch *warehouse_search.Client + GoodsAdd *goods_add.Client + GoodsMediaAdd *goods_media_add.Client + GoodsCategoryAdd *goods_category_add.Client + GoodsCategorySearch *goods_category_search.Client + GoodsBrandSearch *goods_brand_search.Client +} + +func NewManager(cfg *config.Config) *Manager { + return &Manager{ + Hyt: &HytTools{ + ProductUpload: product_upload.New(cfg.EinoTools.HytProductUpload), + SupplierSearch: supplier_search.New(cfg.EinoTools.HytSupplierSearch), + WarehouseSearch: warehouse_search.New(cfg.EinoTools.HytWarehouseSearch), + GoodsAdd: goods_add.New(cfg.EinoTools.HytGoodsAdd), + GoodsMediaAdd: goods_media_add.New(cfg.EinoTools.HytGoodsMediaAdd), + GoodsCategoryAdd: goods_category_add.New(cfg.EinoTools.HytGoodsCategoryAdd), + GoodsCategorySearch: goods_category_search.New(cfg.EinoTools.HytGoodsCategorySearch), + GoodsBrandSearch: goods_brand_search.New(cfg.EinoTools.HytGoodsBrandSearch), + }, + } +} diff --git a/internal/domain/tools/zltx/order_after_reseller_batch/client.go b/internal/domain/tools/zltx/order_after_reseller_batch/client.go new file mode 100644 index 0000000..efcf29c --- /dev/null +++ b/internal/domain/tools/zltx/order_after_reseller_batch/client.go @@ -0,0 +1,48 @@ +package order_after_reseller_batch + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/util" + "context" + "encoding/json" + "errors" + "fmt" +) + +func Call(ctx context.Context, cfg config.ToolConfig, orderNumbers []string) (*OrderAfterSaleResellerBatchResponse, error) { + if len(orderNumbers) == 0 { + return nil, errors.New("批充订单号不能为空") + } + token := util.GetTokenFromContext(ctx) + if token == "" { + return nil, errors.New("token 未注入") + } + r := l_request.Request{ + Url: cfg.BaseURL, + Headers: map[string]string{ + "Authorization": fmt.Sprintf("Bearer %s", token), + }, + Method: "POST", + Json: map[string]any{ + "order_numbers": orderNumbers, + "order_type": 2, + }, + } + res, err := r.Send() + if err != nil { + return nil, err + } + + response := &OrderAfterSaleResellerBatchResponse{} + if err = json.Unmarshal(res.Content, &response); err != nil { + return nil, err + } + if response.Code != 200 { + return nil, fmt.Errorf("售后订单查询异常: %s", response.Error) + } + if len(response.Data.Data) == 0 { + return nil, errors.New("未查询到相应售后订单,请核实订单号是否正确") + } + return response, nil +} diff --git a/internal/domain/tools/zltx/order_after_reseller_batch/invokable.go b/internal/domain/tools/zltx/order_after_reseller_batch/invokable.go new file mode 100644 index 0000000..e7e99b5 --- /dev/null +++ b/internal/domain/tools/zltx/order_after_reseller_batch/invokable.go @@ -0,0 +1,24 @@ +package order_after_reseller_batch + +import ( + "ai_scheduler/internal/config" + "context" + + "github.com/cloudwego/eino/components/tool" + toolutils "github.com/cloudwego/eino/components/tool/utils" +) + +type Args struct { + OrderNumber []string `json:"orderNumber"` +} + +func NewInvokable(cfg config.ToolConfig) tool.InvokableTool { + run := func(ctx context.Context, in Args) (*OrderAfterSaleResellerBatchResponse, error) { + return Call(ctx, cfg, in.OrderNumber) + } + t, err := toolutils.InferTool("zltxOrderAfterSaleResellerBatch", "直连天下下游分销商批充订单售后工具", run) + if err != nil { + panic(err) + } + return t +} diff --git a/internal/domain/tools/zltx/order_after_reseller_batch/types.go b/internal/domain/tools/zltx/order_after_reseller_batch/types.go new file mode 100644 index 0000000..3e35115 --- /dev/null +++ b/internal/domain/tools/zltx/order_after_reseller_batch/types.go @@ -0,0 +1,32 @@ +package order_after_reseller_batch + +type OrderAfterSaleResellerBatchResponse struct { + Code int `json:"code"` + Error string `json:"error"` + Data *OrderAfterSaleResellerBatchData `json:"data"` +} + +type OrderAfterSaleResellerBatchData struct { + Data []*OrderAfterSaleResellerBatchBase `json:"data"` + ExtData map[string]*OrderAfterSaleResellerBatchExtItem `json:"extraData"` +} + +type OrderAfterSaleResellerBatchBase struct { + OrderType int `json:"orderType"` + OrderNumber string `json:"orderNumber"` + OrderAmount float64 `json:"orderAmount"` + OrderPrice float64 `json:"orderPrice"` + SignCompany int `json:"signCompany"` + OrderQuantity int `json:"orderQuantity"` + ResellerID int `json:"resellerId"` + ResellerName string `json:"resellerName"` + OurProductID int `json:"ourProductId"` + OurProductTitle string `json:"ourProductTitle"` + Account []string `json:"account"` + Platforms map[int]string `json:"platforms"` +} + +type OrderAfterSaleResellerBatchExtItem struct { + IsExistsAfterSale bool `json:"isExistsAfterSale"` + SerialCreateTime int `json:"createTime"` +} diff --git a/internal/domain/workflow/hyt/goods_add.go b/internal/domain/workflow/hyt/goods_add.go new file mode 100644 index 0000000..91db063 --- /dev/null +++ b/internal/domain/workflow/hyt/goods_add.go @@ -0,0 +1,373 @@ +package hyt + +import ( + "ai_scheduler/internal/config" + errorcode "ai_scheduler/internal/data/error" + toolManager "ai_scheduler/internal/domain/tools" + "ai_scheduler/internal/domain/tools/hyt/goods_add" + "ai_scheduler/internal/domain/tools/hyt/goods_category_add" + "ai_scheduler/internal/domain/tools/hyt/goods_media_add" + "ai_scheduler/internal/domain/workflow/runtime" + "ai_scheduler/internal/entitys" + "context" + "encoding/json" + "errors" + "fmt" + "log" + "strconv" + "strings" + "sync" + + "github.com/cloudwego/eino/compose" + "golang.org/x/sync/errgroup" +) + +const WorkflowIDGoodsAdd = "hyt.goodsAdd" + +func init() { + runtime.Register(WorkflowIDGoodsAdd, func(d *runtime.Deps) (runtime.Workflow, error) { + return &goodsAdd{cfg: d.Conf, toolManager: d.ToolManager}, nil + }) +} + +type goodsAdd struct { + cfg *config.Config + toolManager *toolManager.Manager + data *GoodsAddWorkflowInput +} + +type GoodsAddWorkflowInput struct { + Text string `mapstructure:"text"` +} + +func (o *goodsAdd) ID() string { return WorkflowIDGoodsAdd } + +func (o *goodsAdd) Invoke(ctx context.Context, rec *entitys.Recognize) (map[string]any, error) { + // 构建工作流 + runnable, err := o.buildWorkflow(ctx) + if err != nil { + return nil, err + } + + o.data = &GoodsAddWorkflowInput{ + Text: rec.UserContent.Text, + } + // 工作流过程调用 + output, err := runnable.Invoke(ctx, o.data) + if err != nil { + fmt.Println("Invoke err:", err) + errStr := err.Error() + if u := errors.Unwrap(err); u != nil { + errStr = u.Error() + } + return nil, errorcode.WorkflowErr(errStr) + } + + return output, nil +} + +// ProductIngestData 对应 HYTGoodsAddPropertyTemplateZH 的结构 +type GoodsAddProductIngestData struct { + Title string `json:"商品标题"` + GoodsCode string `json:"商品编码"` + SpuName string `json:"SPU名称"` + SpuCode string `json:"SPU编码"` + GoodsNum string `json:"商品货号"` + GoodsBarCode string `json:"商品条形码"` + Price string `json:"市场价"` + SalesPrice string `json:"建议销售价"` + ExternalPrice string `json:"电商销售价格"` + Unit string `json:"单位"` + Discount string `json:"折扣"` + TaxRate string `json:"税率"` + FreightTemplate string `json:"运费模版"` + SellByDate string `json:"保质期"` + SellByDateUnit string `json:"保质期单位"` + Brand string `json:"品牌"` + IsHot string `json:"是否热销主推"` + ExternalUrl string `json:"外部平台链接"` + Introduction string `json:"商品卖点"` + GoodsAttributes string `json:"商品规格参数"` + GoodsIllustration string `json:"商品说明"` + Remark string `json:"备注"` + CategoryName string `json:"分类名称"` + Images []string `json:"电脑端主图"` +} + +// GoodsAddContext Graph 执行上下文状态 +type GoodsAddContext struct { + mu *sync.Mutex + InputText string + IngestData *GoodsAddProductIngestData + + // 核心请求体 + AddGoodsReq *goods_add.GoodsAddRequest + + // 中间态数据 + BrandId int + CategoryId int + BrandName string + CategoryName string + + // 运行结果 + GoodsAddResp *goods_add.GoodsAddResponse + GoodsCategoryAddResp bool + GoodsMediaAddResp bool +} + +// buildWorkflow 构建基于 Graph 的并行工作流 +func (o *goodsAdd) buildWorkflow(ctx context.Context) (compose.Runnable[*GoodsAddWorkflowInput, map[string]any], error) { + g := compose.NewGraph[*GoodsAddWorkflowInput, map[string]any]() + + // 1. DataMapping 节点: 解析 JSON -> 填充基础 Request + g.AddLambdaNode("data_mapping", compose.InvokableLambda(func(ctx context.Context, in *GoodsAddWorkflowInput) (*GoodsAddContext, error) { + state := &GoodsAddContext{ + mu: &sync.Mutex{}, // 初始化锁 + InputText: in.Text, + AddGoodsReq: &goods_add.GoodsAddRequest{}, + } + + // 解析用户输入的中文 JSON + var ingestData GoodsAddProductIngestData + if err := json.Unmarshal([]byte(in.Text), &ingestData); err != nil { + return nil, fmt.Errorf("解析商品数据失败") + } + + // 必填校验 + if ingestData.Title == "" { + return nil, errors.New("商品标题不能为空") + } + if ingestData.GoodsCode == "" { + return nil, errors.New("商品编码不能为空") + } + if ingestData.SpuName == "" { + return nil, errors.New("SPU名称不能为空") + } + if ingestData.SpuCode == "" { + return nil, errors.New("SPU编码不能为空") + } + if ingestData.Price == "" { + return nil, errors.New("市场价不能为空") + } + if ingestData.SalesPrice == "" { + return nil, errors.New("建议销售价不能为空") + } + if ingestData.Unit == "" { + return nil, errors.New("价格单位不能为空") + } + if ingestData.Discount == "" { + return nil, errors.New("折扣不能为空") + } + if ingestData.TaxRate == "" { + return nil, errors.New("税率不能为空") + } + + state.IngestData = &ingestData + state.BrandName = ingestData.Brand + state.CategoryName = ingestData.CategoryName + + // 映射字段到 AddGoodsReq + state.AddGoodsReq.Title = ingestData.Title + state.AddGoodsReq.GoodsCode = ingestData.GoodsCode + state.AddGoodsReq.SpuName = ingestData.SpuName + state.AddGoodsReq.SpuCode = ingestData.SpuCode + state.AddGoodsReq.GoodsNum = ingestData.GoodsNum + state.AddGoodsReq.GoodsBarCode = ingestData.GoodsBarCode + + // 价格处理 + if val, err := strconv.ParseFloat(strings.TrimSuffix(ingestData.Price, "元"), 64); err == nil { + state.AddGoodsReq.Price = val + } + if val, err := strconv.ParseFloat(strings.TrimSuffix(ingestData.SalesPrice, "元"), 64); err == nil { + state.AddGoodsReq.SalesPrice = val + } + if val, err := strconv.ParseFloat(strings.TrimSuffix(ingestData.ExternalPrice, "元"), 64); err == nil { + state.AddGoodsReq.ExternalPrice = val + } + + state.AddGoodsReq.Unit = ingestData.Unit + + // 折扣处理 "80%" -> 80 + discountStr := strings.TrimSuffix(ingestData.Discount, "%") + if val, err := strconv.Atoi(discountStr); err == nil { + state.AddGoodsReq.Discount = val + } + // 税率处理 "13%" -> 13 + taxStr := strings.TrimSuffix(strings.TrimSuffix(ingestData.TaxRate, "%"), " ") + if val, err := strconv.Atoi(taxStr); err == nil { + state.AddGoodsReq.TaxRate = val + } + + // 运费模板先不给 state.AddGoodsReq.FreightId = 3 + + // 保质期处理 "180天" -> 180 + sellByDateStr := strings.TrimSuffix(ingestData.SellByDate, "天") + if val, err := strconv.Atoi(sellByDateStr); err == nil { + state.AddGoodsReq.SellByDate = val + } + state.AddGoodsReq.SellByDateUnit = ingestData.SellByDateUnit + + // state.AddGoodsReq.BrandId 品牌ID后续赋值 + + state.AddGoodsReq.IsHot = 2 + if ingestData.IsHot == "是" { + state.AddGoodsReq.IsHot = 1 + } + + state.AddGoodsReq.ExternalUrl = ingestData.ExternalUrl + state.AddGoodsReq.Introduction = ingestData.Introduction + state.AddGoodsReq.GoodsAttributes = ingestData.GoodsAttributes + state.AddGoodsReq.GoodsIllustration = ingestData.GoodsIllustration + state.AddGoodsReq.Remark = ingestData.Remark + state.AddGoodsReq.IsComposeGoods = 2 // 非组合商品 + + return state, nil + })) + + // 2. 预处理节点: 并行获取 品牌ID 和 分类ID + g.AddLambdaNode("prepare_info", compose.InvokableLambda(func(ctx context.Context, state *GoodsAddContext) (*GoodsAddContext, error) { + eg, ctx := errgroup.WithContext(ctx) + + // 任务1: 获取品牌ID + eg.Go(func() error { + if state.BrandName == "" { + return nil + } + brandId, err := o.toolManager.Hyt.GoodsBrandSearch.Call(ctx, state.BrandName) + if err != nil { + log.Printf("warning: 品牌ID获取失败,%s: %v\n", state.BrandName, err) + return nil + } + state.mu.Lock() + state.BrandId = brandId + state.AddGoodsReq.BrandId = brandId + state.mu.Unlock() + return nil + }) + + // 任务2: 获取分类ID + eg.Go(func() error { + if state.CategoryName == "" { + return nil + } + categoryId, err := o.toolManager.Hyt.GoodsCategorySearch.Call(ctx, state.CategoryName) + if err != nil { + log.Printf("warning: 分类ID获取失败,%s: %v\n", state.CategoryName, err) + return nil + } + state.mu.Lock() + state.CategoryId = categoryId + state.mu.Unlock() + return nil + }) + + // 等待所有任务完成 + _ = eg.Wait() + + return state, nil + })) + + // 3. 新增商品 节点 (依赖 prepare_info) + g.AddLambdaNode("goods_add", compose.InvokableLambda(func(ctx context.Context, state *GoodsAddContext) (*GoodsAddContext, error) { + // 调用 goods_add 工具 + respData, err := o.toolManager.Hyt.GoodsAdd.Call(ctx, state.AddGoodsReq) + if err != nil || respData == nil { + log.Printf("warning: 新增商品失败: %v", err) + return nil, fmt.Errorf("新增商品失败: %s", err.Error()) + } + + state.GoodsAddResp = respData + + return state, nil + })) + + // 4. 后置处理节点: 并行执行 关联分类 和 添加图片 + g.AddLambdaNode("post_process", compose.InvokableLambda(func(ctx context.Context, state *GoodsAddContext) (*GoodsAddContext, error) { + if state.GoodsAddResp.Id == 0 { + return nil, errors.New("商品不存在") + } + + eg, ctx := errgroup.WithContext(ctx) + + // 任务1: 关联分类 + eg.Go(func() error { + if state.CategoryId == 0 { + return nil + } + req := &goods_category_add.GoodsCategoryAddRequest{ + GoodsId: state.GoodsAddResp.Id, + CategoryIds: []int{state.CategoryId}, + IsCover: false, + } + isSuccess, err := o.toolManager.Hyt.GoodsCategoryAdd.Call(ctx, req) + if err != nil { + log.Printf("warning: 关联分类失败: %v", err) + return nil + } + + state.mu.Lock() + state.GoodsCategoryAddResp = isSuccess + state.mu.Unlock() + + return nil + }) + + // 任务2: 添加图片 + eg.Go(func() error { + if len(state.IngestData.Images) == 0 { + return nil + } + req := &goods_media_add.GoodsMediaAddRequest{ + GoodsId: state.GoodsAddResp.Id, + IsCover: true, + Data: make([]goods_media_add.MediaItem, 0), + } + for i, url := range state.IngestData.Images { + req.Data = append(req.Data, goods_media_add.MediaItem{ + Type: 1, // 图片 + Url: url, + Sort: i, + }) + } + isSuccess, err := o.toolManager.Hyt.GoodsMediaAdd.Call(ctx, req) + if err != nil { + log.Printf("warning: 添加图片失败: %v", err) + return nil + } + + state.mu.Lock() + state.GoodsMediaAddResp = isSuccess + state.mu.Unlock() + + return nil + }) + + // 等待所有任务完成 + _ = eg.Wait() + + return state, nil + })) + + // 5. 结果格式化节点 + g.AddLambdaNode("format_output", compose.InvokableLambda(func(ctx context.Context, state *GoodsAddContext) (map[string]any, error) { + if state.GoodsAddResp == nil { + return nil, fmt.Errorf("goods add response is nil") + } + + return map[string]any{ + "预览URL(货易通商品列表)": state.GoodsAddResp.PreviewUrl, + "SPU编码": state.GoodsAddResp.SpuCode, + "商品ID": state.GoodsAddResp.Id, + }, nil + })) + + // 构建边 (线性拓扑) + g.AddEdge(compose.START, "data_mapping") + g.AddEdge("data_mapping", "prepare_info") + g.AddEdge("prepare_info", "goods_add") + g.AddEdge("goods_add", "post_process") + g.AddEdge("post_process", "format_output") + g.AddEdge("format_output", compose.END) + + return g.Compile(ctx) +} diff --git a/internal/domain/workflow/hyt/product_upload.go b/internal/domain/workflow/hyt/product_upload.go new file mode 100644 index 0000000..35114ed --- /dev/null +++ b/internal/domain/workflow/hyt/product_upload.go @@ -0,0 +1,303 @@ +package hyt + +import ( + "ai_scheduler/internal/config" + errorcode "ai_scheduler/internal/data/error" + toolManager "ai_scheduler/internal/domain/tools" + toolPu "ai_scheduler/internal/domain/tools/hyt/product_upload" + "ai_scheduler/internal/domain/workflow/runtime" + "ai_scheduler/internal/entitys" + "context" + "encoding/json" + "errors" + "fmt" + "log" + "strconv" + "strings" + "sync" + + "github.com/cloudwego/eino/compose" +) + +const WorkflowIDProductUpload = "hyt.productUpload" + +func init() { + runtime.Register(WorkflowIDProductUpload, func(d *runtime.Deps) (runtime.Workflow, error) { + return &productUpload{cfg: d.Conf, toolManager: d.ToolManager}, nil + }) +} + +type productUpload struct { + cfg *config.Config + toolManager *toolManager.Manager + data *ProductUploadWorkflowInput +} + +type ProductUploadWorkflowInput struct { + Text string `mapstructure:"text"` +} + +func (o *productUpload) ID() string { return WorkflowIDProductUpload } + +func (o *productUpload) Invoke(ctx context.Context, rec *entitys.Recognize) (map[string]any, error) { + // 构建工作流 + runnable, err := o.buildWorkflow(ctx) + if err != nil { + return nil, err + } + + o.data = &ProductUploadWorkflowInput{ + Text: rec.UserContent.Text, + } + // 工作流过程调用 + output, err := runnable.Invoke(ctx, o.data) + if err != nil { + errStr := err.Error() + if u := errors.Unwrap(err); u != nil { + errStr = u.Error() + } + return nil, errorcode.WorkflowErr(errStr) + } + + fmt.Printf("workflow output: %v\n", output) + + return output, nil +} + +// ProductIngestData 对应 HYTSupplierProductPropertyTemplateZH 的结构 +type SupplierProductIngestData struct { + BarCode string `json:"条码"` + CategoryName string `json:"分类名称"` + GoodsName string `json:"货品名称"` + GoodsNum string `json:"货品编号"` + GoodsArticleNum string `json:"商品货号"` + Brand string `json:"品牌"` + Unit string `json:"单位"` + Specs string `json:"规格参数"` + Description string `json:"货品说明"` + ShelfLife string `json:"保质期"` + ShelfLifeUnit string `json:"保质期单位"` + Link string `json:"链接"` + Images []string `json:"货品图片"` + EPrice string `json:"电商销售价格"` + SalesPrice string `json:"销售价"` + SupplierPrice string `json:"供应商报价"` + TaxRate string `json:"税率"` + SupplierName string `json:"默认供应商"` + WarehouseName string `json:"默认存放仓库"` + Remark string `json:"备注"` + Length string `json:"长"` + Width string `json:"宽"` + Height string `json:"高"` + Weight string `json:"重量"` + SpuName string `json:"SPU名称"` + SpuCode string `json:"SPU编码"` + Profit string `json:"利润"` +} + +// ProductUploadContext Graph 执行上下文状态 +type ProductUploadContext struct { + mu *sync.Mutex + InputText string + IngestData *SupplierProductIngestData + UploadReq *toolPu.ProductUploadRequest + SupplierName string + WarehouseName string + UploadResp *toolPu.ProductUploadResponse +} + +// buildWorkflow 构建基于 Graph 的并行工作流 +func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*ProductUploadWorkflowInput, map[string]any], error) { + g := compose.NewGraph[*ProductUploadWorkflowInput, map[string]any]() + + // 1. DataMapping 节点: 解析 JSON -> 填充基础 Request + g.AddLambdaNode("data_mapping", compose.InvokableLambda(func(ctx context.Context, in *ProductUploadWorkflowInput) (*ProductUploadContext, error) { + state := &ProductUploadContext{ + mu: &sync.Mutex{}, // 初始化锁 + InputText: in.Text, + UploadReq: &toolPu.ProductUploadRequest{ + GoodsList: make([]toolPu.Goods, 1), // 初始化一个商品 + }, + } + + // 解析用户输入的中文 JSON + var ingestData SupplierProductIngestData + if err := json.Unmarshal([]byte(in.Text), &ingestData); err != nil { + return nil, fmt.Errorf("解析商品数据失败: %w", err) + } + + // 必填校验 + if ingestData.SupplierName == "" { + return nil, errors.New("供应商名称不能为空") + } + if ingestData.WarehouseName == "" { + return nil, errors.New("仓库名称不能为空") + } + if ingestData.Profit == "" { + return nil, errors.New("利润不能为空") + } + if ingestData.TaxRate == "" { + return nil, errors.New("税率不能为空") + } + if ingestData.SupplierPrice == "" { + return nil, errors.New("供应商报价不能为空") + } + + state.IngestData = &ingestData + state.SupplierName = ingestData.SupplierName + state.WarehouseName = ingestData.WarehouseName + + // 映射字段到 UploadReq + goodsInfo := &state.UploadReq.GoodsList[0].GoodsInfo + goodsInfo.Title = ingestData.GoodsName + goodsInfo.Brand = ingestData.Brand + goodsInfo.Category = ingestData.CategoryName + goodsInfo.GoodsBarCode = ingestData.BarCode + goodsInfo.GoodsNum = ingestData.GoodsNum + if goodsInfo.GoodsNum == "" { + goodsInfo.GoodsNum = ingestData.GoodsArticleNum + } + goodsInfo.Unit = ingestData.Unit + goodsInfo.GoodsAttributes = ingestData.Specs + goodsInfo.Introduction = ingestData.Description + goodsInfo.SpuName = ingestData.SpuName + goodsInfo.SpuNum = ingestData.SpuCode + goodsInfo.Weight = ingestData.Weight + + // 数值处理 + if val, err := strconv.ParseFloat(strings.TrimSuffix(ingestData.SalesPrice, "元"), 64); err == nil { + goodsInfo.SalesPrice = val + } + if val, err := strconv.ParseFloat(strings.TrimSuffix(ingestData.EPrice, "元"), 64); err == nil { + goodsInfo.Price = val // 假设电商价为市场价 + } + // 价格兼容 + if goodsInfo.CostPrice == 0 { + goodsInfo.CostPrice = goodsInfo.Price + } + // 税率处理 "13%" -> 13 + taxStr := strings.TrimSuffix(strings.TrimSuffix(ingestData.TaxRate, "%"), " ") + if val, err := strconv.Atoi(taxStr); err == nil { + goodsInfo.TaxRate = val + state.UploadReq.TaxRate = val + } + // 利润处理 + if val, err := strconv.ParseFloat(strings.TrimSuffix(ingestData.Profit, "元"), 64); err == nil { + state.UploadReq.Profit = val + } + + // 图片处理 + for i, imgUrl := range ingestData.Images { + state.UploadReq.GoodsList[0].GoodsMediaList = append(state.UploadReq.GoodsList[0].GoodsMediaList, toolPu.GoodsMedia{ + Url: imgUrl, + Type: 1, // 图片 + Sort: i, + }) + } + + // 默认值字段 + goodsInfo.IsComposeGoods = 2 + goodsInfo.IsBind = 0 + goodsInfo.IsHot = 2 + state.UploadReq.IsDefaultWarehouse = 1 + state.UploadReq.Sort = 1 + + return state, nil + })) + + // 2. 获取供应商ID 节点 + g.AddLambdaNode("get_supplier_id", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (*ProductUploadContext, error) { + if state.SupplierName == "" { + return state, errors.New("供应商名称不能为空") + } + + supplierId, err := o.toolManager.Hyt.SupplierSearch.Call(ctx, state.SupplierName) + if err != nil { + // 记录日志,但不阻断流程,可能允许 ID 为 0 + log.Printf("warning: 供应商ID获取失败,%s: %v\n", state.SupplierName, err) + } else { + state.mu.Lock() + defer state.mu.Unlock() + state.UploadReq.SupplierId = supplierId + } + + return state, nil + })) + + // 3. 获取仓库ID 节点 + g.AddLambdaNode("get_warehouse_id", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (*ProductUploadContext, error) { + if state.WarehouseName == "" { + return state, errors.New("仓库名称不能为空") + } + + warehouseId, err := o.toolManager.Hyt.WarehouseSearch.Call(ctx, state.WarehouseName) + if err != nil { + log.Printf("warning: 仓库ID获取失败,%s: %v\n", state.WarehouseName, err) + } else { + state.mu.Lock() + defer state.mu.Unlock() + state.UploadReq.WarehouseId = warehouseId + } + + return state, nil + })) + + // 4. 合并/同步节点 + g.AddLambdaNode("merge_node", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (*ProductUploadContext, error) { + // 最终校验 + if state.UploadReq.SupplierId == 0 { + return nil, fmt.Errorf("供应商获取失败") + } + if state.UploadReq.WarehouseId == 0 { + return nil, fmt.Errorf("仓库获取失败") + } + return state, nil + })) + + // 5. 上传节点 + g.AddLambdaNode("upload_product", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (*ProductUploadContext, error) { + toolRes, err := o.toolManager.Hyt.ProductUpload.Call(ctx, state.UploadReq) + if err != nil { + return nil, fmt.Errorf("商品上传失败") + } + state.UploadResp = toolRes + return state, nil + })) + + // 6. 结果格式化节点 + g.AddLambdaNode("format_output", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (map[string]any, error) { + if state.UploadResp == nil { + return nil, fmt.Errorf("upload response is nil") + } + return map[string]any{ + "预览URL(货易通商品列表)": state.UploadResp.PreviewUrl, + "SPU编码": state.UploadResp.SpuNum, + "商品ID": state.UploadResp.Id, + }, nil + })) + + // 构建边 + // Start -> Mapping + g.AddEdge(compose.START, "data_mapping") + + // 串行化执行以规避 Eino 指针合并问题 + // Mapping -> Supplier + g.AddEdge("data_mapping", "get_supplier_id") + + // Supplier -> Warehouse + g.AddEdge("get_supplier_id", "get_warehouse_id") + + // Warehouse -> Merge (虽然串行了,保留 Merge 节点做校验) + g.AddEdge("get_warehouse_id", "merge_node") + + // Merge -> Upload + g.AddEdge("merge_node", "upload_product") + + // Upload -> Format + g.AddEdge("upload_product", "format_output") + + // Format -> END + g.AddEdge("format_output", compose.END) + + return g.Compile(ctx) +} diff --git a/internal/domain/workflow/provider_set.go b/internal/domain/workflow/provider_set.go new file mode 100644 index 0000000..b9a2815 --- /dev/null +++ b/internal/domain/workflow/provider_set.go @@ -0,0 +1,32 @@ +package workflow + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/component" + "ai_scheduler/internal/domain/repo" + "ai_scheduler/internal/domain/workflow/runtime" + "ai_scheduler/internal/pkg/utils_ollama" + + toolManager "ai_scheduler/internal/domain/tools" + + "github.com/google/wire" +) + +var ProviderSetWorkflow = wire.NewSet(NewRegistry) + +// NewRegistry 注入共享依赖并注册默认 Registry,确保自注册工作流可被发现 +func NewRegistry(conf *config.Config, llm *utils_ollama.Client, repos *repo.Repos, components *component.Components) *runtime.Registry { + // 步骤1:设置运行时依赖(配置与LLM客户端),供工作流工厂在首次实例化时使用;必须在任何调用 Invoke 之前完成,否则会触发 "deps not set" + runtime.SetDeps(&runtime.Deps{ + Conf: conf, + LLM: llm, + ToolManager: toolManager.NewManager(conf), + Repos: repos, + Component: components, + }) + // 步骤2:创建新的工作流注册表;注册表负责按工作流ID惰性实例化并缓存单例实例,保障并发访问下的安全 + r := runtime.NewRegistry() + // 步骤3:将该注册表设置为全局默认,便于通过 runtime.Default() 获取;自注册的工作流可通过默认注册表被发现并调用 + runtime.SetDefault(r) + return r +} diff --git a/internal/domain/workflow/register_all_gen.go b/internal/domain/workflow/register_all_gen.go new file mode 100644 index 0000000..d30ba64 --- /dev/null +++ b/internal/domain/workflow/register_all_gen.go @@ -0,0 +1,7 @@ +package workflow + +import ( + // 手工维护:在此空导入工作流包以触发其 init() 自注册 + // 新增工作流时,只需在这里添加一行 `_ ""` + _ "ai_scheduler/internal/domain/workflow/zltx" +) diff --git a/internal/domain/workflow/registry.go b/internal/domain/workflow/registry.go new file mode 100644 index 0000000..10b24ef --- /dev/null +++ b/internal/domain/workflow/registry.go @@ -0,0 +1,16 @@ +package workflow + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/component" + toolManager "ai_scheduler/internal/domain/tools" + "ai_scheduler/internal/pkg/utils_ollama" +) + +// 仅声明依赖结构,避免在 workflow 包内实现注册中心逻辑导致循环依赖 +type Deps struct { + Conf *config.Config + LLM *utils_ollama.Client + ToolManager *toolManager.Manager + Component *component.Components +} diff --git a/internal/domain/workflow/runtime/registry.go b/internal/domain/workflow/runtime/registry.go new file mode 100644 index 0000000..f804e1d --- /dev/null +++ b/internal/domain/workflow/runtime/registry.go @@ -0,0 +1,99 @@ +package runtime + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/component" + "ai_scheduler/internal/domain/repo" + toolManager "ai_scheduler/internal/domain/tools" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/utils_ollama" + "context" + "errors" + "sync" +) + +type Workflow interface { + ID() string + // Schema() map[string]any + Invoke(ctx context.Context, requireData *entitys.Recognize) (map[string]any, error) +} + +type Deps struct { + Conf *config.Config + LLM *utils_ollama.Client + ToolManager *toolManager.Manager + Component *component.Components // 基础设施能力 + Repos *repo.Repos // 数据访问 +} + +type Factory func(deps *Deps) (Workflow, error) + +var ( + regMu sync.RWMutex + factories = map[string]Factory{} + deps *Deps + defaultReg *Registry +) + +func Register(id string, f Factory) { + regMu.Lock() + factories[id] = f + regMu.Unlock() +} + +func SetDeps(d *Deps) { + regMu.Lock() + deps = d + regMu.Unlock() +} + +type Registry struct { + mu sync.RWMutex + instances map[string]Workflow +} + +func NewRegistry() *Registry { + return &Registry{instances: make(map[string]Workflow)} +} + +func SetDefault(r *Registry) { + regMu.Lock() + defaultReg = r + regMu.Unlock() +} + +func Default() *Registry { + regMu.RLock() + r := defaultReg + regMu.RUnlock() + return r +} + +func (r *Registry) Invoke(ctx context.Context, id string, rec *entitys.Recognize) (map[string]any, error) { + regMu.RLock() + f, ok := factories[id] + regMu.RUnlock() + if !ok { + return nil, errors.New("workflow not found: " + id) + } + + r.mu.RLock() + w, exists := r.instances[id] + r.mu.RUnlock() + + if !exists { + if deps == nil { + return nil, errors.New("deps not set") + } + nw, err := f(deps) + if err != nil { + return nil, err + } + r.mu.Lock() + r.instances[id] = nw + w = nw + r.mu.Unlock() + } + + return w.Invoke(ctx, rec) +} diff --git a/internal/domain/workflow/zltx/bug_optimization_submit.bak.go b/internal/domain/workflow/zltx/bug_optimization_submit.bak.go new file mode 100644 index 0000000..6ed6bb4 --- /dev/null +++ b/internal/domain/workflow/zltx/bug_optimization_submit.bak.go @@ -0,0 +1,170 @@ +package zltx + +import ( + "context" + "encoding/json" + "errors" + "time" + + "ai_scheduler/internal/domain/component/callback" + "ai_scheduler/internal/domain/repo" + "ai_scheduler/internal/domain/workflow/runtime" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/l_request" + + "github.com/cloudwego/eino/compose" + "github.com/google/uuid" + "github.com/redis/go-redis/v9" +) + +const WorkflowIDBugOptimizationSubmitBak = "bug_optimization_submit_bak" + +func init() { + runtime.Register(WorkflowIDBugOptimizationSubmitBak, func(d *runtime.Deps) (runtime.Workflow, error) { + // 从 Deps.Repos 获取 SessionRepo + return &bugOptimizationSubmitBak{ + manager: d.Component.Callback, + sessionRepo: d.Repos.Session, + }, nil + }) +} + +type bugOptimizationSubmitBak struct { + manager callback.Manager + sessionRepo repo.SessionRepo + redisCli *redis.Client +} + +func (w *bugOptimizationSubmitBak) ID() string { + return WorkflowIDBugOptimizationSubmitBak +} + +type BugOptimizationSubmitBakInput struct { + Ch chan entitys.Response + RequireData *entitys.Recognize +} + +type BugOptimizationSubmitBakOutput struct { + Msg string +} + +type contextWithTaskBak struct { + Input *BugOptimizationSubmitBakInput + TaskID string +} + +func (w *bugOptimizationSubmitBak) Invoke(ctx context.Context, recognize *entitys.Recognize) (map[string]any, error) { + chain, err := w.buildWorkflow(ctx) + if err != nil { + return nil, err + } + + input := &BugOptimizationSubmitBakInput{ + Ch: recognize.Ch, + RequireData: recognize, + } + + out, err := chain.Invoke(ctx, input) + if err != nil { + return nil, err + } + + return map[string]any{"msg": out.Msg}, nil +} + +func (w *bugOptimizationSubmitBak) buildWorkflow(ctx context.Context) (compose.Runnable[*BugOptimizationSubmitBakInput, *BugOptimizationSubmitBakOutput], error) { + c := compose.NewChain[*BugOptimizationSubmitBakInput, *BugOptimizationSubmitBakOutput]() + + // Node 1: Prepare and Call + c.AppendLambda(compose.InvokableLambda(w.prepareAndCall)) + + // Node 2: Wait + c.AppendLambda(compose.InvokableLambda(w.waitCallback)) + + return c.Compile(ctx) +} + +func (w *bugOptimizationSubmitBak) prepareAndCall(ctx context.Context, in *BugOptimizationSubmitBakInput) (*contextWithTaskBak, error) { + // 生成 TaskID + taskID := uuid.New().String() + + // Ext 中获取 sessionId + sessionID := in.RequireData.GetSession() + + // 注册回调映射 + if err := w.manager.Register(ctx, taskID, sessionID); err != nil { + return nil, err + } + + // 查询用户名 + userName := "unknown" + if w.sessionRepo != nil { + name, err := w.sessionRepo.GetUserName(ctx, sessionID) + if err == nil && name != "" { + userName = name + } + } + + // 构建请求参数 + var fileUrls, fileContent string + if len(in.RequireData.UserContent.File) > 0 { + for _, file := range in.RequireData.UserContent.File { + fileUrls += file.FileUrl + "," + fileContent += file.FileRec + "," + } + fileUrls = fileUrls[:len(fileUrls)-1] + fileContent = fileContent[:len(fileContent)-1] + } + + body := map[string]string{ + "mark": in.RequireData.Match.Index, + "text": in.RequireData.UserContent.Text, + "img": fileUrls, + "img_content": fileContent, + "creator": userName, + "task_id": taskID, + } + + request := l_request.Request{ + Url: "https://connector.dingtalk.com/webhook/flow/10352c521dd02104cee9000c", + Method: "POST", + Headers: map[string]string{ + "Content-Type": "application/json", + }, + JsonByte: pkg.JsonByteIgonErr(body), + } + + res, err := request.Send() + if err != nil { + return nil, err + } + + var data map[string]any + if err := json.Unmarshal(res.Content, &data); err != nil { + return nil, err + } + + if success, ok := data["success"].(bool); !ok || !success { + return nil, errors.New("dingtalk flow failed") + } + + entitys.ResLog(in.Ch, in.RequireData.Match.Index, "问题记录中") + entitys.ResLoading(in.Ch, in.RequireData.Match.Index, "问题记录中...") + + return &contextWithTaskBak{Input: in, TaskID: taskID}, nil +} + +func (w *bugOptimizationSubmitBak) waitCallback(ctx context.Context, in *contextWithTask) (*BugOptimizationSubmitBakOutput, error) { + // 阻塞等待回调信号 + // 设置 5 分钟超时 + waitCtx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + + res, err := w.manager.Wait(waitCtx, in.TaskID, 5*time.Minute) + if err != nil { + return nil, err + } + + return &BugOptimizationSubmitBakOutput{Msg: res}, nil +} diff --git a/internal/domain/workflow/zltx/bug_optimization_submit.go b/internal/domain/workflow/zltx/bug_optimization_submit.go new file mode 100644 index 0000000..30ad0bc --- /dev/null +++ b/internal/domain/workflow/zltx/bug_optimization_submit.go @@ -0,0 +1,170 @@ +package zltx + +import ( + "context" + "encoding/json" + "errors" + "time" + + "ai_scheduler/internal/domain/component/callback" + "ai_scheduler/internal/domain/repo" + "ai_scheduler/internal/domain/workflow/runtime" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/l_request" + + "github.com/cloudwego/eino/compose" + "github.com/google/uuid" + "github.com/redis/go-redis/v9" +) + +const WorkflowIDBugOptimizationSubmit = "bug_optimization_submit" + +func init() { + runtime.Register(WorkflowIDBugOptimizationSubmit, func(d *runtime.Deps) (runtime.Workflow, error) { + // 从 Deps.Repos 获取 SessionRepo + return &bugOptimizationSubmit{ + manager: d.Component.Callback, + sessionRepo: d.Repos.Session, + }, nil + }) +} + +type bugOptimizationSubmit struct { + manager callback.Manager + sessionRepo repo.SessionRepo + redisCli *redis.Client +} + +func (w *bugOptimizationSubmit) ID() string { + return WorkflowIDBugOptimizationSubmit +} + +type BugOptimizationSubmitInput struct { + Ch chan entitys.Response + RequireData *entitys.Recognize +} + +type BugOptimizationSubmitOutput struct { + Msg string +} + +type contextWithTask struct { + Input *BugOptimizationSubmitInput + TaskID string +} + +func (w *bugOptimizationSubmit) Invoke(ctx context.Context, recognize *entitys.Recognize) (map[string]any, error) { + chain, err := w.buildWorkflow(ctx) + if err != nil { + return nil, err + } + + input := &BugOptimizationSubmitInput{ + Ch: recognize.Ch, + RequireData: recognize, + } + + out, err := chain.Invoke(ctx, input) + if err != nil { + return nil, err + } + + return map[string]any{"msg": out.Msg}, nil +} + +func (w *bugOptimizationSubmit) buildWorkflow(ctx context.Context) (compose.Runnable[*BugOptimizationSubmitInput, *BugOptimizationSubmitOutput], error) { + c := compose.NewChain[*BugOptimizationSubmitInput, *BugOptimizationSubmitOutput]() + + // Node 1: Prepare and Call + c.AppendLambda(compose.InvokableLambda(w.prepareAndCall)) + + // Node 2: Wait + c.AppendLambda(compose.InvokableLambda(w.waitCallback)) + + return c.Compile(ctx) +} + +func (w *bugOptimizationSubmit) prepareAndCall(ctx context.Context, in *BugOptimizationSubmitInput) (*contextWithTask, error) { + // 生成 TaskID + taskID := uuid.New().String() + + // Ext 中获取 sessionId + sessionID := in.RequireData.GetSession() + + // 注册回调映射 + if err := w.manager.Register(ctx, taskID, sessionID); err != nil { + return nil, err + } + + // 查询用户名 + userName := "unknown" + if w.sessionRepo != nil { + name, err := w.sessionRepo.GetUserName(ctx, sessionID) + if err == nil && name != "" { + userName = name + } + } + + // 构建请求参数 + var fileUrls, fileContent string + if len(in.RequireData.UserContent.File) > 0 { + for _, file := range in.RequireData.UserContent.File { + fileUrls += file.FileUrl + "," + fileContent += file.FileRec + "," + } + fileUrls = fileUrls[:len(fileUrls)-1] + fileContent = fileContent[:len(fileContent)-1] + } + + body := map[string]string{ + "mark": in.RequireData.Match.Index, + "text": in.RequireData.UserContent.Text, + "img": fileUrls, + "img_content": fileContent, + "creator": userName, + "task_id": taskID, + } + + request := l_request.Request{ + Url: "https://connector.dingtalk.com/webhook/flow/10352c521dd02104cee9000c", + Method: "POST", + Headers: map[string]string{ + "Content-Type": "application/json", + }, + JsonByte: pkg.JsonByteIgonErr(body), + } + + res, err := request.Send() + if err != nil { + return nil, err + } + + var data map[string]any + if err := json.Unmarshal(res.Content, &data); err != nil { + return nil, err + } + + if success, ok := data["success"].(bool); !ok || !success { + return nil, errors.New("dingtalk flow failed") + } + + entitys.ResLog(in.Ch, in.RequireData.Match.Index, "问题记录中") + entitys.ResLoading(in.Ch, in.RequireData.Match.Index, "问题记录中...") + + return &contextWithTask{Input: in, TaskID: taskID}, nil +} + +func (w *bugOptimizationSubmit) waitCallback(ctx context.Context, in *contextWithTask) (*BugOptimizationSubmitOutput, error) { + // 阻塞等待回调信号 + // 设置 5 分钟超时 + waitCtx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + + res, err := w.manager.Wait(waitCtx, in.TaskID, 5*time.Minute) + if err != nil { + return nil, err + } + + return &BugOptimizationSubmitOutput{Msg: res}, nil +} diff --git a/internal/domain/workflow/zltx/order_after_reseller_batch.go b/internal/domain/workflow/zltx/order_after_reseller_batch.go new file mode 100644 index 0000000..eee022a --- /dev/null +++ b/internal/domain/workflow/zltx/order_after_reseller_batch.go @@ -0,0 +1,281 @@ +package zltx + +import ( + "ai_scheduler/internal/config" + + toolZoarb "ai_scheduler/internal/domain/tools/zltx/order_after_reseller_batch" + "ai_scheduler/internal/domain/workflow/runtime" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/util" + "context" + "encoding/json" + "errors" + + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +func init() { + runtime.Register("zltx.orderAfterSaleResellerBatch", func(d *runtime.Deps) (runtime.Workflow, error) { + return &orderAfterSaleResellerBatch{cfg: d.Conf.Tools.ZltxOrderAfterSaleResellerBatch}, nil + }) +} + +type orderAfterSaleResellerBatch struct { + cfg config.ToolConfig +} + +// 工作流入参 +type OrderAfterSaleResellerBatchWorkflowInput struct { + Ch chan entitys.Response // 响应通道 + UserInput string // 用户输入文本 + FileContent string // 文件解析结果 + UserHistory entitys.ChatHis // 用户对话历史 + ParameterResult string // 参数解析结果 + Data *OrderAfterSaleResellerBatchNodeData // 节点所需参数 +} + +// 节点所需参数 +type OrderAfterSaleResellerBatchNodeData struct { + OrderNumber []string `json:"orderNumber"` // 订单号 + AfterType string `json:"afterType"` // 处理方式 1.退款 2.扣款 + AfterSalesPrice string `json:"afterSalesPrice"` // 售后金额 + AfterSalesReason string `json:"afterSalesReason"` // 售后原因 + ResponsibleType string `json:"responsibleType"` // 费用承担者 1.供应商 2.商务 3.公司 4.无 + ResponsiblePerson string `json:"responsiblePerson"` // 费用承担供应商 +} + +// 工作流出参 +type OrderAfterSaleResellerBatchWorkflowOutput struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data []*OrderAfterSaleResellerBatchData `json:"data"` +} + +type OrderAfterSaleResellerBatchData struct { + OrderType int `json:"orderType"` + OrderNumber string `json:"orderNumber"` + OrderAmount float64 `json:"orderAmount"` + OrderPrice float64 `json:"orderPrice"` + SignCompany int `json:"signCompany"` + OrderQuantity int `json:"orderQuantity"` + ResellerID int `json:"resellerId"` + ResellerName string `json:"resellerName"` + OurProductID int `json:"ourProductId"` + OurProductTitle string `json:"ourProductTitle"` + Account []string `json:"account"` + Platforms map[int]string `json:"platforms"` + AfterType int `json:"afterType"` // 处理方式 1.退款 2.扣款 + Remark string `json:"remark"` // 售后原因 + AfterAmount float64 `json:"afterAmount"` // 售后金额 + ResponsibleType int `json:"responsibleType"` // 费用承担者 1.供应商 2.商务 3.公司 4.无 + ResponsiblePerson string `json:"responsiblePerson"` // 费用承担供应商 + IsExistsAfterSale bool `json:"isExistsAfterSale"` // 是否已存在售后 + CreateTime int `json:"createTime"` // 创建时间 +} + +// ID 返回工作流唯一标识 +func (o *orderAfterSaleResellerBatch) ID() string { return "zltx.orderAfterSaleResellerBatch" } + +// Invoke 调用原有编排工作流并规范化输出 +func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, rec *entitys.Recognize) (map[string]any, error) { + // 构建工作流 + chain, err := o.buildWorkflow(ctx) + if err != nil { + return nil, err + } + + input := &OrderAfterSaleResellerBatchWorkflowInput{ + Ch: rec.Ch, + UserInput: rec.UserContent.Text, + FileContent: "", + UserHistory: rec.ChatHis, + ParameterResult: rec.Match.Parameters, + } + + // 将 Input 注入 Context + ctx = context.WithValue(ctx, workflowInputContextKey{}, input) + + // 工作流过程输出,不关注最终输出 + _, err = chain.Invoke(ctx, input) + if err != nil { + return nil, err + } + + // 工作流 callback + + // 不关心输出,全部在途中输出 + return nil, nil +} + +var ErrInvalidOrderNumbers = errors.New("orderNumber 不能为空") + +// contextKey 用于在 Context 中传递 WorkflowInput +type workflowInputContextKey struct{} + +// buildWorkflow 构建工作流 +func (o *orderAfterSaleResellerBatch) buildWorkflow(ctx context.Context) (compose.Runnable[*OrderAfterSaleResellerBatchWorkflowInput, *OrderAfterSaleResellerBatchWorkflowOutput], error) { + // 定义工作流、出入参 + c := compose.NewChain[*OrderAfterSaleResellerBatchWorkflowInput, *OrderAfterSaleResellerBatchWorkflowOutput]() + + // 1.llm 推断参数 (若需要) + c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *OrderAfterSaleResellerBatchWorkflowInput) (*schema.Message, error) { + // 已推断完,直接使用 + parameters := in.ParameterResult + return &schema.Message{Content: parameters}, nil + })) + + // 2.参数解析为结构体 + c.AppendLambda(compose.MessageParser( + schema.NewMessageJSONParser[*OrderAfterSaleResellerBatchNodeData](&schema.MessageJSONParseConfig{ + ParseFrom: schema.MessageParseFromContent, + ParseKeyPath: "", // 如果仅需要 parse 子字段,可用 "key.sub.grandsub" + }), + )) + + // 3.参数校验 & 传递 Input + // 注意:为了在后续节点访问 WorkflowInput,这里使用闭包或 Context 传递。 + // Eino Chain 节点间传递的是返回值。这里我们修改节点签名,将 input 一路传下去,或者使用 context。 + // 由于 Eino Chain 是强类型的,这里选择让 Parser 返回的数据结构包含原始 input,或者我们在 Parser 后重新组合。 + // 但最简单的方法是使用 Context 存储 Input (如果 Eino 支持 Context 传递)。Eino 的 Invoke 接受 ctx。 + // 但 Eino Chain 的设计是数据流驱动。 + // 修正方案:修改中间节点的数据结构,或者使用闭包捕获(但闭包捕获的是 build 时的变量,无法捕获运行时 input)。 + // 正确做法:Chain 的节点入参必须是上一个节点的出参。 + // 我们可以把 Parser 的输入改为 Input,输出改为一个包含 Input 和 ParsedData 的结构。 + // 但这里为了最小改动,我们利用 Context 来传递 Input 引用(这在 Eino 中是可行的,因为 ctx 会贯穿整个 Invoke)。 + // 更好的做法是重构 Chain 的数据流,但在保持逻辑不变的前提下,Context 是最快解法。 + + // 为了线程安全,我们在第一个节点把 Input 放入 Context?不行,Chain.Invoke(ctx, input) 的 ctx 是外部传入的。 + // Eino 允许 Lambda 修改 Context 吗?通常不允许。 + + // 让我们重新审视数据流: + // Input -> Lambda1 -> Message -> Parser -> NodeData -> Lambda4 -> ToolResp -> Lambda5 -> Output + // Lambda4 需要 Input.Ch 来发 Loading。 + // Lambda5 需要 Input.Ch 来发 Log/Json,还需要 NodeData。 + + // 根本问题是:中间节点丢失了 Input 信息。 + // 解决方案:使用一个聚合结构体在 Chain 中传递。 + + // 由于要大改数据流比较复杂,这里使用一种技巧: + // 在 Invoke 时,构造一个带有 Input 信息的 Context 传入。 + // 这样每个节点都能从 Context 拿到 Input。 + + // 重新实现 buildWorkflow 以支持 Context 传递 + return o.buildWorkflowWithContext(ctx) +} + +func (o *orderAfterSaleResellerBatch) buildWorkflowWithContext(ctx context.Context) (compose.Runnable[*OrderAfterSaleResellerBatchWorkflowInput, *OrderAfterSaleResellerBatchWorkflowOutput], error) { + c := compose.NewChain[*OrderAfterSaleResellerBatchWorkflowInput, *OrderAfterSaleResellerBatchWorkflowOutput]() + + // 0. Context 注入节点 (Trick: 利用第一个节点将 Input 注入 Context,但 Eino Chain 无法修改 Context 传递给下游) + // 实际上,我们可以在 Invoke 调用前,在外部包装 Context。 + // 所以这里不需要额外的节点,只需要在 Invoke 时处理。 + // 但 Invoke 是由 Chain 提供的,我们只能控制传入的 ctx。 + // 见下文 Invoke 方法的修改。 + + // 1.llm 推断参数 + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in *OrderAfterSaleResellerBatchWorkflowInput) (*schema.Message, error) { + return &schema.Message{Content: in.ParameterResult}, nil + })) + + // 2.参数解析 + c.AppendLambda(compose.MessageParser( + schema.NewMessageJSONParser[*OrderAfterSaleResellerBatchNodeData](&schema.MessageJSONParseConfig{ + ParseFrom: schema.MessageParseFromContent, + }), + )) + + // 3.参数校验 + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in *OrderAfterSaleResellerBatchNodeData) (*OrderAfterSaleResellerBatchNodeData, error) { + if len(in.OrderNumber) == 0 { + return nil, ErrInvalidOrderNumbers + } + // 将解析后的 Data 存入 Input (通过 Context 获取 Input) + input := ctx.Value(workflowInputContextKey{}).(*OrderAfterSaleResellerBatchWorkflowInput) + input.Data = in // 这里修改 Input 是安全的,因为 Input 是请求维度的引用 + return in, nil + })) + + // 4.工具调用 + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in *OrderAfterSaleResellerBatchNodeData) (*toolZoarb.OrderAfterSaleResellerBatchResponse, error) { + input := ctx.Value(workflowInputContextKey{}).(*OrderAfterSaleResellerBatchWorkflowInput) + + entitys.ResLoading(input.Ch, o.ID(), "数据拉取中") + toolRes, err := toolZoarb.Call(ctx, o.cfg, in.OrderNumber) + entitys.ResLog(input.Ch, o.ID(), "数据拉取完成") + + return toolRes, err + })) + + // 5.结果数据映射 + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in *toolZoarb.OrderAfterSaleResellerBatchResponse) (*OrderAfterSaleResellerBatchWorkflowOutput, error) { + return o.dataMapping(ctx, in) + })) + + return c.Compile(ctx) +} + +// 结果数据映射 +func (o *orderAfterSaleResellerBatch) dataMapping(ctx context.Context, in *toolZoarb.OrderAfterSaleResellerBatchResponse) (*OrderAfterSaleResellerBatchWorkflowOutput, error) { + input := ctx.Value(workflowInputContextKey{}).(*OrderAfterSaleResellerBatchWorkflowInput) + + entitys.ResLog(input.Ch, o.ID(), "数据整理中") + + toolResp := &OrderAfterSaleResellerBatchWorkflowOutput{ + Code: in.Code, + Msg: in.Error, + Data: make([]*OrderAfterSaleResellerBatchData, 0, len(in.Data.Data)), + } + + // 转换数据 + for _, item := range in.Data.Data { + // 处理方式 + afterType := util.StringToInt(input.Data.AfterType) + if afterType == 0 { + afterType = 1 // 默认退款 + } + // 费用承担者 + responsibleType := util.StringToInt(input.Data.ResponsibleType) + if responsibleType == 0 { + responsibleType = 4 // 默认无 + } + // 售后金额 + afterSalesPrice := util.StringToFloat64(input.Data.AfterSalesPrice) + if afterSalesPrice == 0 { + afterSalesPrice = item.OrderPrice + } + + toolResp.Data = append(toolResp.Data, &OrderAfterSaleResellerBatchData{ + OrderType: item.OrderType, + OrderNumber: item.OrderNumber, + OrderAmount: item.OrderAmount, + OrderPrice: item.OrderPrice, + SignCompany: item.SignCompany, + OrderQuantity: item.OrderQuantity, + ResellerID: item.ResellerID, + ResellerName: item.ResellerName, + OurProductID: item.OurProductID, + OurProductTitle: item.OurProductTitle, + Account: item.Account, + Platforms: item.Platforms, + AfterType: afterType, + Remark: input.Data.AfterSalesReason, + AfterAmount: afterSalesPrice, + ResponsibleType: responsibleType, + ResponsiblePerson: input.Data.ResponsiblePerson, + }) + } + + // 追加扩展数据 + for _, item := range toolResp.Data { + if extItem, ok := in.Data.ExtData[item.OrderNumber]; ok { + item.IsExistsAfterSale = item.OrderType > 100 // 102 批充&已售后 + item.CreateTime = extItem.SerialCreateTime + } + } + + toolRespJson, _ := json.Marshal(toolResp) + entitys.ResJson(input.Ch, o.ID(), string(toolRespJson)) + + return toolResp, nil +} diff --git a/internal/entitys/bot.go b/internal/entitys/bot.go index 226fbc8..ed0902c 100644 --- a/internal/entitys/bot.go +++ b/internal/entitys/bot.go @@ -1,7 +1,23 @@ package entitys -type BotType int +import ( + "ai_scheduler/internal/data/model" -const ( - BugAndQuesDingTalk BotType = iota + 1 + "gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/chatbot" ) + +type RequireDataDingTalkBot struct { + Histories []model.AiChatHi + UserInfo *DingTalkUserInfo + Tools []model.AiBotTool + Match *Match + Req *chatbot.BotCallbackDataModel + Ch chan Response + ID int32 +} + +type DingTalkBot struct { + BotIndex string `json:"bot_index"` + ClientId string `json:"client_id"` + ClientSecret string `json:"client_secret"` +} diff --git a/internal/entitys/chat_history.go b/internal/entitys/chat_history.go index a26fa46..b50148e 100644 --- a/internal/entitys/chat_history.go +++ b/internal/entitys/chat_history.go @@ -2,6 +2,9 @@ package entitys import ( "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/model" + "encoding/json" + "log" ) type ChatHistory struct { @@ -14,3 +17,50 @@ type ChatHistory struct { type ChatHisLog struct { HisId int64 `json:"his_id"` } + +type ChatHistQuery struct { + HisID int64 `json:"his_id"` + SessionID string `json:"session_id"` + Page int `json:"page"` + PageSize int `json:"page_size"` +} + +type ChatHisQueryResponse struct { + HisID int64 `gorm:"column:his_id;primaryKey;autoIncrement:true" json:"his_id"` + SessionID string `gorm:"column:session_id;not null" json:"session_id"` + Ques string `gorm:"column:ques;not null" json:"ques"` + Ans string `gorm:"column:ans;not null" json:"ans"` + Files string `gorm:"column:files;not null" json:"files"` + Useful int32 `gorm:"column:useful;not null;comment:0不评价,1有用,其他为无用" json:"useful"` // 0不评价,1有用,其他为无用 + CreateAt string `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` + TaskID int32 `gorm:"column:task_id;not null" json:"task_id"` // 任务ID + TaskName string `gorm:"column:task_name;not null" json:"task_name"` // 任务名称 + Contents []string `gorm:"column:contents" json:"contents"` // 前端回传数据 +} + +func (c *ChatHisQueryResponse) FromModel(chat model.AiChatHi, task model.AiTask) { + c.HisID = chat.HisID + c.SessionID = chat.SessionID + c.Ques = chat.Ques + c.Ans = chat.Ans + c.Files = chat.Files + c.Useful = chat.Useful + c.CreateAt = chat.CreateAt.Format("2006-01-02 15:04:05") + c.TaskID = chat.TaskID + c.TaskName = task.Name + c.Contents = make([]string, 0) + + // 解析Content + if "" != chat.Content { + err := json.Unmarshal([]byte(chat.Content), &c.Contents) + if err != nil { + c.Contents = append(c.Contents, chat.Content) + log.Println("解析Content失败 error: ", err) + } + } +} + +type UpdateContentRequest struct { + HisID int64 `json:"his_id" validate:"required"` + Content string `json:"content" validate:"required"` +} diff --git a/internal/entitys/dingtalk.go b/internal/entitys/dingtalk.go new file mode 100644 index 0000000..93cc825 --- /dev/null +++ b/internal/entitys/dingtalk.go @@ -0,0 +1,23 @@ +package entitys + +import ( + "ai_scheduler/internal/data/constants" + "time" +) + +type DingTalkUserInfo struct { + UserId int `json:"user_id"` + StaffId string `json:"staff_id"` + Name string `json:"name"` + Dept []*Dept `json:"dept"` + IsBoss constants.IsBoss `json:"is_boss"` + IsSenior constants.IsSenior `json:"is_senior"` + HiredDate time.Time `json:"hired_date"` + Extension string `json:"extension"` +} + +type Dept struct { + Name string `json:"name"` + DeptId int `json:"dept_id"` + ToolList string `json:"tool_list"` +} diff --git a/internal/entitys/ollama.go b/internal/entitys/ollama.go new file mode 100644 index 0000000..7297dab --- /dev/null +++ b/internal/entitys/ollama.go @@ -0,0 +1,10 @@ +package entitys + +import ( + "github.com/ollama/ollama/api" +) + +type ToolSelect struct { + Prompt []api.Message + Tools []RegistrationTask +} diff --git a/internal/entitys/recognize.go b/internal/entitys/recognize.go new file mode 100644 index 0000000..fcd2fe5 --- /dev/null +++ b/internal/entitys/recognize.go @@ -0,0 +1,67 @@ +package entitys + +import ( + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/model" + "encoding/json" +) + +type Recognize struct { + SystemPrompt string // 系统提示内容 + UserContent *RecognizeUserContent // 用户输入内容 + ChatHis ChatHis // 会话历史记录 + Tasks []RegistrationTask + Ch chan Response + Match *Match + Ext []byte +} + +type TaskExt struct { + Auth string `json:"auth"` + Session string `json:"session"` + Key string `json:"key"` + SessionInfo model.AiSession + Sys model.AiSy + KnowledgeConf KnowledgeBaseRequest +} + +type RegistrationTask struct { + Name string + Desc string + Index string + TaskConfigDetail TaskConfigDetail +} + +type RecognizeUserContent struct { + Text string // 用户输入的文本内容 + File []*RecognizeFile // 文件内容 + ActionCardUrl string // 操作卡片链接 + Tag string // 工具标签 +} + +type FileData []byte + +type RecognizeFile struct { + FileRec string //文件识别内容 + FileData FileData // 文件数据(二进制格式) + FileType constants.FileType // 文件类型(文件类型,能填最好填,可以跳过一层判断) + FileRealMime string // 文件真实MIME类型 + FileUrl string // 文件下载链接 +} + +func (r *Recognize) GetTaskExt() *TaskExt { + var ext TaskExt + if err := json.Unmarshal(r.Ext, &ext); err != nil { + return nil + } + + return &ext +} + +func (r *Recognize) GetSession() string { + ext := r.GetTaskExt() + if ext == nil { + return "" + } + return ext.Session +} diff --git a/internal/entitys/response.go b/internal/entitys/response.go index cdadc98..39e81ad 100644 --- a/internal/entitys/response.go +++ b/internal/entitys/response.go @@ -1,26 +1,33 @@ package entitys import ( + "ai_scheduler/internal/gateway" "encoding/json" + "github.com/gofiber/websocket/v2" ) type ResponseType string const ( - ResponseJson ResponseType = "json" - ResponseLoading ResponseType = "loading" - ResponseEnd ResponseType = "end" - ResponseStream ResponseType = "stream" - ResponseText ResponseType = "txt" - ResponseImg ResponseType = "img" - ResponseFile ResponseType = "file" - ResponseErr ResponseType = "error" - ResponseLog ResponseType = "log" - ResponseAuth ResponseType = "auth" + ResponseJson ResponseType = "json" + ResponseLoading ResponseType = "loading" + ResponseEnd ResponseType = "end" + ResponseStream ResponseType = "stream" + ResponseText ResponseType = "txt" + ResponseImg ResponseType = "img" + ResponseFile ResponseType = "file" + ResponseErr ResponseType = "error" + ResponseLog ResponseType = "log" + ResponseAuth ResponseType = "auth" + ResponseMarkdown ResponseType = "markdown" + ResponseActionCard ResponseType = "actionCard" ) func ResLog(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, @@ -29,6 +36,9 @@ func ResLog(ch chan Response, index string, content string) { } func ResStream(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, @@ -37,6 +47,9 @@ func ResStream(ch chan Response, index string, content string) { } func ResJson(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, @@ -45,6 +58,9 @@ func ResJson(ch chan Response, index string, content string) { } func ResEnd(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, @@ -53,6 +69,9 @@ func ResEnd(ch chan Response, index string, content string) { } func ResText(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, @@ -61,6 +80,9 @@ func ResText(ch chan Response, index string, content string) { } func ResLoading(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, @@ -68,6 +90,9 @@ func ResLoading(ch chan Response, index string, content string) { } } func ResError(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, @@ -100,13 +125,13 @@ func MsgSet(msgType ResponseType, msg string, done bool) []byte { return jsonByte } -func MsgSend(c *websocket.Conn, msg Response) error { +func MsgSend(client *gateway.Client, msg Response) error { // 检查上下文是否已取消 if msg.Type == ResponseText { } jsonByte, _ := json.Marshal(msg) - return c.WriteMessage(websocket.TextMessage, jsonByte) + return client.SendFunc(jsonByte) } func MsgSendByte(c *websocket.Conn, msg []byte) { diff --git a/internal/entitys/types.go b/internal/entitys/types.go index ece5ceb..601f50e 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -3,7 +3,6 @@ package entitys import ( "ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/model" - "context" "encoding/json" @@ -31,7 +30,7 @@ type FirstSockRequest struct { type ChatSockRequest struct { Text string `json:"text" binding:"required"` - Img string `json:"img" binding:"required"` + Img string `json:"img" binding:"required"` // 多图片使用 英文, 分割 File string `json:"file" binding:"required"` Tags string `json:"tags" binding:"required"` MarkHis int64 `json:"mark_his" ` @@ -79,7 +78,7 @@ type Tool interface { Name() string Description() string Definition() ToolDefinition - Execute(ctx context.Context, requireData *RequireData) error + Execute(ctx context.Context, requireData *Recognize) error } type ConfigDataHttp struct { @@ -139,8 +138,8 @@ type HisMessage struct { } type HisContext struct { - UserLanguage string `json:"user_language"` - SystemMode string `json:"system_mode"` + UserLanguage string `json:"user_language"` // 用户语言 + SystemMode string `json:"system_mode"` // 系统模式, } type RequireData struct { @@ -150,7 +149,8 @@ type RequireData struct { Histories []model.AiChatHi SessionInfo model.AiSession Tasks []model.AiTask - Match *Match + Task model.AiTask + Match Match Req *ChatSockRequest Auth string Ch chan Response diff --git a/internal/gateway/client.go b/internal/gateway/client.go index b49bf45..1da0dd8 100644 --- a/internal/gateway/client.go +++ b/internal/gateway/client.go @@ -3,33 +3,47 @@ package gateway import ( errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/model" - "encoding/hex" - "fmt" - "github.com/gofiber/websocket/v2" + "context" + "log" "math/rand" + "sync" "time" + + "github.com/google/uuid" + + "github.com/gofiber/websocket/v2" ) var ( - ErrConnClosed = errors.SysErr("连接不存在或已关闭") + ErrConnClosed = errors.SysErrf("连接不存在或已关闭") + rng = rand.New(rand.NewSource(time.Now().UnixNano())) + idBuf = make([]byte, 20) ) type Client struct { - id string // 客户端唯一ID - conn *websocket.Conn // WebSocket 连接 - session string // 会话ID - key string // 应用密钥 - auth string // 用户凭证token - codes []string // 用户权限code - sysInfo *model.AiSy // 系统信息 - tasks []model.AiTask // 任务列表 - sysCode string // 系统编码 + id string // 客户端唯一ID + conn *websocket.Conn // WebSocket 连接 + session string // 会话ID + key string // 应用密钥 + auth string // 用户凭证token + codes []string // 用户权限code + sysInfo *model.AiSy // 系统信息 + tasks []model.AiTask // 任务列表 + sysCode string // 系统编码 + Ctx context.Context + Cancel context.CancelFunc + LastActive time.Time + mu sync.Mutex } -func NewClient(conn *websocket.Conn) *Client { +func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelFunc) *Client { + return &Client{ - id: generateClientID(), - conn: conn, + id: generateClientID(), + conn: conn, + Ctx: ctx, + Cancel: cancel, + mu: sync.Mutex{}, } } @@ -63,7 +77,7 @@ func (c *Client) GetCodes() []string { return c.codes } -// GetSysCode 获取系统编码 +// 获取系统编码 func (c *Client) GetSysCode() string { return c.sysCode } @@ -93,22 +107,51 @@ func (c *Client) SetCodes(codes []string) { c.codes = codes } +// Close 关闭客户端连接 +func (c *Client) Close() { + //c.mu.Lock() + //defer c.mu.Unlock() + if c.conn != nil { + _ = c.conn.Close() + c.conn = nil + } +} + // SendFunc 发送消息到客户端 func (c *Client) SendFunc(msg []byte) error { - if c.conn != nil { - return c.conn.WriteMessage(websocket.TextMessage, msg) + return c.SendMessage(websocket.TextMessage, msg) +} + +// 在Client结构体中添加更详细的日志 +func (c *Client) SendMessage(msgType int, msg []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn == nil { + return ErrConnClosed } - return ErrConnClosed + + err := c.conn.WriteMessage(msgType, msg) + if err != nil { + log.Printf("发送消息失败: %v, 客户端ID: %s, 消息类型: %d", + err, c.id, msgType) + } + return err } // 生成唯一的客户端ID func generateClientID() string { - // 使用时间戳+随机数确保唯一性 - timestamp := time.Now().UnixNano() - randomBytes := make([]byte, 4) - rand.Read(randomBytes) - randomStr := hex.EncodeToString(randomBytes) - return fmt.Sprintf("%d%s", timestamp, randomStr) + return uuid.New().String() + //// 1. 时间戳 + //timestamp := time.Now().UnixNano() + //binary.BigEndian.PutUint64(idBuf[:8], uint64(timestamp)) + // + //// 2. 随机数(4字节) + //binary.BigEndian.PutUint32(idBuf[8:12], rng.Uint32()) + // + //// 3. 十六进制编码 + //n := pkg.HexEncode(idBuf[:12], idBuf[12:]) + //return string(idBuf[12 : 12+n]) } // 连接数据验证和收集 @@ -136,3 +179,34 @@ func (c *Client) DataAuth() (err error) { } return } + +// 总结:目前绝大多数浏览器不支持直接发送WebSocket Ping帧,因此在实际开发中,应该实现应用层ping机制作为主要心跳检测方案,todo 同时保留对未来可能的原生支持的兼容检测。 +func (c *Client) InitHeartbeat(timeoutSecond time.Duration) { + ticker := time.NewTicker(timeoutSecond * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + //*2是防止丢包,连续丢包两次,再加5s网络延迟容错 + if time.Since(c.LastActive) > (timeoutSecond*2*time.Second + 5) { // 5秒容错 + log.Println("心跳超时", "clientId", c.id) + err := c.SendMessage(websocket.CloseMessage, []byte("Heartbeat timeout")) + if err != nil { + log.Println("发送心跳超时消息失败", err) + } + c.Close() + return + } + case <-c.Ctx.Done(): + return + } + } +} + +// 在Client结构体中添加ReadMessage方法 +func (c *Client) ReadMessage() (messageType int, message []byte, err error) { + if c.conn == nil { + return 0, nil, ErrConnClosed + } + return c.conn.ReadMessage() +} diff --git a/internal/gateway/gateway.go b/internal/gateway/gateway.go index 0f0e6f3..6e09bdc 100644 --- a/internal/gateway/gateway.go +++ b/internal/gateway/gateway.go @@ -2,7 +2,9 @@ package gateway import ( "errors" + "log" "sync" + "time" ) type Gateway struct { @@ -20,14 +22,29 @@ func NewGateway() *Gateway { func (g *Gateway) AddClient(c *Client) { g.mu.Lock() - defer g.mu.Unlock() + defer func() { + g.mu.Unlock() + //心跳开始计时 + c.LastActive = time.Now() + log.Println("client connected:", c.GetID()) + log.Println("客户端已连接") + }() g.clients[c.GetID()] = c } -func (g *Gateway) RemoveClient(clientID string) { +func (g *Gateway) Cleanup(clientID string) { g.mu.Lock() - defer g.mu.Unlock() - delete(g.clients, clientID) + // 从网关管理中移除客户端 + defer func() { + if c, ex := g.clients[clientID]; ex { + delete(g.clients, clientID) + c.Close() + c.Cancel() + } + g.mu.Unlock() + log.Println("client disconnected:", clientID) + }() + // 从所有绑定的UID列表中移除该客户端 for uid, list := range g.uidMap { newList := []string{} for _, cid := range list { @@ -37,6 +54,7 @@ func (g *Gateway) RemoveClient(clientID string) { } g.uidMap[uid] = newList } + } func (g *Gateway) SendToAll(msg []byte) { @@ -63,6 +81,7 @@ func (g *Gateway) BindUid(clientID, uid string) error { return errors.New("client not found") } g.uidMap[uid] = append(g.uidMap[uid], clientID) + log.Printf("绑定 clientId %s -> uid:%s\n", clientID, uid) return nil } diff --git a/internal/pkg/dingtalk/contact_client.go b/internal/pkg/dingtalk/contact_client.go index 4461995..ed3e3dd 100644 --- a/internal/pkg/dingtalk/contact_client.go +++ b/internal/pkg/dingtalk/contact_client.go @@ -54,10 +54,10 @@ func (c *ContactClient) SearchUserOne(accessToken string, name string) (string, } if resp.Body == nil { - return "", errorcode.ParamErr("empty response body") + return "", errorcode.ParamErrf("empty response body") } if len(resp.Body.List) == 0 { - return "", errorcode.ParamErr("empty user list") + return "", errorcode.ParamErrf("empty user list") } userId := resp.Body.List[0] diff --git a/internal/pkg/dingtalk/notable_client.go b/internal/pkg/dingtalk/notable_client.go index 7cfe04f..885e111 100644 --- a/internal/pkg/dingtalk/notable_client.go +++ b/internal/pkg/dingtalk/notable_client.go @@ -3,6 +3,8 @@ package dingtalk import ( "ai_scheduler/internal/config" errorcode "ai_scheduler/internal/data/error" + "encoding/json" + "time" openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client" notable "github.com/alibabacloud-go/dingtalk/notable_1_0" @@ -67,8 +69,75 @@ func (c *NotableClient) UpdateRecord(accessToken string, req *UpdateRecordReq) ( } if resp.Body == nil { - return false, errorcode.ParamErr("empty response body") + return false, errorcode.ParamErrf("empty response body") } return true, nil } + +type InsertRecordReq struct { + BaseId string + SheetIdOrName string + OperatorId string + CreatorUnionId string + Content string + AttachmentUrl string +} + +func (c *NotableClient) InsertRecord(accessToken string, req *InsertRecordReq) (string, error) { + // 默认使用“数据表” + if req.SheetIdOrName == "" { + req.SheetIdOrName = "数据表" + } + + headers := ¬able.InsertRecordsHeaders{} + headers.XAcsDingtalkAccessToken = tea.String(accessToken) + resp, err := c.cli.InsertRecordsWithOptions( + tea.String(req.BaseId), + tea.String(req.SheetIdOrName), + ¬able.InsertRecordsRequest{ + OperatorId: tea.String(req.OperatorId), + Records: []*notable.InsertRecordsRequestRecords{ + { + Fields: map[string]any{ + "创建日期": time.Now().Format(time.DateTime), + "需求内容": req.Content, + "提交人": []map[string]any{ + { + "unionId": req.CreatorUnionId, + }, + }, + "附件": map[string]any{ + "link": req.AttachmentUrl, + }, + }, + }, + }, + }, headers, &util.RuntimeOptions{}) + if err != nil { + return "", err + } + + if resp.Body == nil || resp.Body.Value == nil || len(resp.Body.Value) == 0 { + return "", errorcode.ParamErrf("empty response body") + } + + return *resp.Body.Value[0].Id, nil +} + +func (c *NotableClient) GetHTTPStatus(err error) int { + if sdkErr, ok := err.(*tea.SDKError); ok { + if sdkErr.StatusCode != nil { + return *sdkErr.StatusCode + } + if sdkErr.Data != nil { + var m struct { + StatusCode int `json:"statusCode"` + } + if json.Unmarshal([]byte(*sdkErr.Data), &m) == nil { + return m.StatusCode + } + } + } + return 0 // 0 = 非 HTTP 错误 +} diff --git a/internal/pkg/func.go b/internal/pkg/func.go index 4e6481a..32c404b 100644 --- a/internal/pkg/func.go +++ b/internal/pkg/func.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net/url" + "strconv" "strings" ) @@ -55,3 +56,112 @@ func ValidateImageURL(rawURL string) error { return nil } + +// hexEncode 将 src 的二进制数据编码为十六进制字符串,写入 dst,返回写入长度 +func HexEncode(src, dst []byte) int { + const hextable = "0123456789abcdef" + for i := 0; i < len(src); i++ { + dst[i*2] = hextable[src[i]>>4] + dst[i*2+1] = hextable[src[i]&0xf] + } + return len(src) * 2 +} + +// Ter 三目运算 Ter(true, 1, 2) +func Ter[T any](cond bool, a, b T) T { + if cond { + return a + } + return b +} + +// StringToSlice [num,num]转slice +func StringToSlice(s string) ([]int, error) { + // 1. 去掉两端的方括号 + trimmed := strings.Trim(s, "[]") + + // 2. 按逗号分割 + parts := strings.Split(trimmed, ",") + + // 3. 转换为 []int + result := make([]int, 0, len(parts)) + for _, part := range parts { + num, err := strconv.Atoi(strings.TrimSpace(part)) + if err != nil { + return nil, err + } + result = append(result, num) + } + return result, nil +} + +// Difference 差集 +func Difference[T comparable](a, b []T) []T { + // 创建 b 的映射(T 必须是可比较的类型) + bMap := make(map[T]struct{}, len(b)) + for _, item := range b { + bMap[item] = struct{}{} + } + + var diff []T // 修正为 []T 而非 []int + for _, item := range a { + if _, found := bMap[item]; !found { + diff = append(diff, item) + } + } + return diff +} + +// SliceStringToInt []string=>[]int +func SliceStringToInt(strSlice []string) []int { + numSlice := make([]int, len(strSlice)) + for i, str := range strSlice { + num, err := strconv.Atoi(str) + if err != nil { + return nil + } + numSlice[i] = num + } + return numSlice +} + +// SliceIntToString []int=>[]string +func SliceIntToString(slice []int) []string { + strSlice := make([]string, len(slice)) // len=cap=len(slice) + for i, num := range slice { + strSlice[i] = strconv.Itoa(num) // 直接赋值,无 append + } + return strSlice +} + +// SafeReplace 替换字符串中的 %s,并自动转义特殊字符(如 ") +/** + * SafeReplace 函数用于安全地替换模板字符串中的占位符 + * @param template 原始模板字符串 + * @param replaceTag 要被替换的占位符(如 "%s") + * @param replacements 可变参数,用于替换占位符的字符串 + * @return 返回替换后的字符串和可能的错误 + */ +func SafeReplace(template string, replaceTag string, replacements ...string) (string, error) { + // 如果没有提供替换参数,直接返回原始模板 + if len(replacements) == 0 { + return template, nil + } + + // 检查模板中 %s 的数量是否匹配替换参数 + expectedReplacements := strings.Count(template, replaceTag) + if expectedReplacements != len(replacements) { + return "", fmt.Errorf("模板需要 %d 个替换参数,但提供了 %d 个", expectedReplacements, len(replacements)) + } + + // 逐个替换 %s,并转义特殊字符 + for _, rep := range replacements { + // 转义特殊字符(如 ", \, \n 等) + escaped := strconv.Quote(rep) + // 去掉 strconv.Quote 添加的额外引号 + escaped = escaped[1 : len(escaped)-1] + template = strings.Replace(template, replaceTag, escaped, 1) + } + + return template, nil +} diff --git a/internal/pkg/provider_set.go b/internal/pkg/provider_set.go index 93d9180..f8fadac 100644 --- a/internal/pkg/provider_set.go +++ b/internal/pkg/provider_set.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/pkg/dingtalk" "ai_scheduler/internal/pkg/utils_langchain" "ai_scheduler/internal/pkg/utils_ollama" + "ai_scheduler/internal/pkg/utils_vllm" "github.com/google/wire" ) @@ -13,6 +14,7 @@ var ProviderSetClient = wire.NewSet( NewGormDb, utils_langchain.NewUtilLangChain, utils_ollama.NewClient, + utils_vllm.NewClient, NewSafeChannelPool, dingtalk.NewOldClient, dingtalk.NewContactClient, diff --git a/internal/pkg/rec_extra/ext.go b/internal/pkg/rec_extra/ext.go new file mode 100644 index 0000000..b8b3c86 --- /dev/null +++ b/internal/pkg/rec_extra/ext.go @@ -0,0 +1,22 @@ +package rec_extra + +import ( + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "encoding/json" +) + +func SetTaskRecExt(requireData *entitys.RequireData, rec *entitys.Recognize) { + TaskExt := entitys.TaskExt{ + Auth: requireData.Auth, + Session: requireData.Session, + Key: requireData.Key, + Sys: requireData.Sys, + } + rec.Ext = pkg.JsonByteIgonErr(TaskExt) +} + +func GetTaskRecExt(rec *entitys.Recognize) (ext entitys.TaskExt, err error) { + err = json.Unmarshal(rec.Ext, &ext) + return ext, err +} diff --git a/internal/pkg/util/ctx.go b/internal/pkg/util/ctx.go new file mode 100644 index 0000000..1998805 --- /dev/null +++ b/internal/pkg/util/ctx.go @@ -0,0 +1,25 @@ +package util + +import ( + "context" +) + +type ContextKey string + +const ( + ContextKeyToken ContextKey = "token" +) + +// token 写入上下文 +func SetTokenToContext(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, ContextKeyToken, token) +} + +// 从上下文获取token +func GetTokenFromContext(ctx context.Context) string { + token, ok := ctx.Value(ContextKeyToken).(string) + if !ok { + return "" + } + return token +} diff --git a/internal/pkg/util/map.go b/internal/pkg/util/map.go new file mode 100644 index 0000000..5ca80c8 --- /dev/null +++ b/internal/pkg/util/map.go @@ -0,0 +1,14 @@ +package util + +import "encoding/json" + +// StructToMap 将结构体转换为 map[string]any +func StructToMap(v any) (map[string]any, error) { + b, err := json.Marshal(v) + if err != nil { + return nil, err + } + var m map[string]any + err = json.Unmarshal(b, &m) + return m, err +} diff --git a/internal/pkg/util/point.go b/internal/pkg/util/point.go new file mode 100644 index 0000000..beceddf --- /dev/null +++ b/internal/pkg/util/point.go @@ -0,0 +1,6 @@ +package util + +// AnyToPoint converts any value to a pointer. +func AnyToPoint[T any](v T) *T { + return &v +} diff --git a/internal/pkg/util/string.go b/internal/pkg/util/string.go index 1fc898a..9dd4056 100644 --- a/internal/pkg/util/string.go +++ b/internal/pkg/util/string.go @@ -32,3 +32,13 @@ func StringToFloat64(s string) float64 { i, _ := strconv.ParseFloat(s, 64) return i } + +// 是否包含在数组中 +func Contains[T comparable](strings []T, str T) bool { + for _, s := range strings { + if s == str { + return true + } + } + return false +} diff --git a/internal/pkg/util/time.go b/internal/pkg/util/time.go new file mode 100644 index 0000000..1c1bf3e --- /dev/null +++ b/internal/pkg/util/time.go @@ -0,0 +1,41 @@ +package util + +import "time" + +// 判断当前时间是否在时间窗口内 +// ts 时间戳字符串,支持秒级或毫秒级 +// window 时间窗口,例如 10 * time.Minute +func IsInTimeWindow(ts string, window time.Duration) bool { + // 期望毫秒时间戳或秒级,简单容错 + // 尝试解析为整数 + var n int64 + for _, base := range []int64{1, 1000} { // 秒或毫秒 + if v, ok := parseInt64(ts); ok { + n = v + // 归一为毫秒 + if base == 1 && len(ts) <= 10 { + n = n * 1000 + } + now := time.Now().UnixMilli() + diff := now - n + if diff < 0 { + diff = -diff + } + if diff <= window.Milliseconds() { + return true + } + } + } + return false +} + +func parseInt64(s string) (int64, bool) { + var n int64 + for _, ch := range s { + if ch < '0' || ch > '9' { + return 0, false + } + n = n*10 + int64(ch-'0') + } + return n, true +} diff --git a/internal/pkg/utils_ollama/client.go b/internal/pkg/utils_ollama/client.go index 91640f0..fa88afa 100644 --- a/internal/pkg/utils_ollama/client.go +++ b/internal/pkg/utils_ollama/client.go @@ -90,6 +90,25 @@ func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messa return } +func (c *Client) Chat(ctx context.Context, model string, messages []api.Message) (res api.ChatResponse, err error) { + // 构建聊天请求 + req := &api.ChatRequest{ + Model: model, + Messages: messages, + Stream: new(bool), // 设置为false,不使用流式响应 + Think: &api.ThinkValue{Value: false}, + } + err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error { + res = resp + return nil + }) + if err != nil { + return + } + + return +} + func (c *Client) Generation(ctx context.Context, generateRequest *api.GenerateRequest) (result api.GenerateResponse, err error) { err = c.client.Generate(ctx, generateRequest, func(resp api.GenerateResponse) error { result = resp diff --git a/internal/pkg/utils_vllm/client.go b/internal/pkg/utils_vllm/client.go new file mode 100644 index 0000000..c8c4aec --- /dev/null +++ b/internal/pkg/utils_vllm/client.go @@ -0,0 +1,90 @@ +package utils_vllm + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/util" + "context" + "encoding/base64" + + "github.com/cloudwego/eino-ext/components/model/openai" + "github.com/cloudwego/eino/schema" +) + +type Client struct { + model *openai.ChatModel + config *config.Config +} + +func NewClient(config *config.Config) (*Client, func(), error) { + m, err := openai.NewChatModel(context.Background(), &openai.ChatModelConfig{ + BaseURL: config.Vllm.BaseURL, + Model: config.Vllm.VlModel, + Timeout: config.Vllm.Timeout, + }) + if err != nil { + return nil, nil, err + } + c := &Client{model: m, config: config} + cleanup := func() {} + return c, cleanup, nil +} + +func (c *Client) Chat(ctx context.Context, msgs []*schema.Message) (*schema.Message, error) { + return c.model.Generate(ctx, msgs) +} + +func (c *Client) RecognizeWithImg(ctx context.Context, systemPrompt, userPrompt string, imgURLs []string) (*schema.Message, error) { + in := []*schema.Message{ + { + Role: schema.System, + Content: systemPrompt, + }, + { + Role: schema.User, + }, + } + parts := []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeText, Text: userPrompt}, + } + for i := range imgURLs { + u := imgURLs[i] + parts = append(parts, schema.MessageInputPart{ + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageInputImage{ + MessagePartCommon: schema.MessagePartCommon{URL: &u}, + Detail: schema.ImageURLDetailHigh, + }, + }) + } + + in[1].UserInputMultiContent = parts + return c.model.Generate(ctx, in) +} + +// 识别图片by二进制文件 +func (c *Client) RecognizeWithImgBytes(ctx context.Context, systemPrompt, userPrompt string, imgBytes []byte, imgType string) (*schema.Message, error) { + in := []*schema.Message{ + { + Role: schema.System, + Content: systemPrompt, + }, + { + Role: schema.User, + }, + } + parts := []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeText, Text: userPrompt}, + } + parts = append(parts, schema.MessageInputPart{ + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageInputImage{ + MessagePartCommon: schema.MessagePartCommon{ + MIMEType: imgType, + Base64Data: util.AnyToPoint(base64.StdEncoding.EncodeToString(imgBytes)), + }, + }, + }) + + in[1].UserInputMultiContent = parts + return c.model.Generate(ctx, in) +} diff --git a/internal/pkg/utils_vllm/client_test.go b/internal/pkg/utils_vllm/client_test.go new file mode 100644 index 0000000..e71390d --- /dev/null +++ b/internal/pkg/utils_vllm/client_test.go @@ -0,0 +1,66 @@ +package utils_vllm + +import ( + "ai_scheduler/internal/config" + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/cloudwego/eino/schema" +) + +func newMockServer() *httptest.Server { + h := http.NewServeMux() + h.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"cmpl-1","object":"chat.completion","created":173,"model":"x","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`)) + }) + return httptest.NewServer(h) +} + +func Test_Vllm_Chat_Generate(t *testing.T) { + cfg, err := config.LoadConfig("../../../config/config_test.yaml") + if err != nil { + t.Fatalf("load config: %v", err) + } + + ctx := context.Background() + client, _, err := NewClient(cfg) + if err != nil { + t.Fatalf("new client: %v", err) + } + + msgs := []*schema.Message{{Role: schema.User, Content: "hi"}} + out, err := client.Chat(ctx, msgs) + if err != nil { + t.Fatalf("chat generate: %v", err) + } + if out == nil || out.Content != "ok" { + t.Fatalf("unexpected content: %v", out) + } + + t.Logf("结果: %v", out) +} + +func Test_Vllm_RecognizeWithImg(t *testing.T) { + cfg, err := config.LoadConfig("../../../config/config_test.yaml") + if err != nil { + t.Fatalf("load config: %v", err) + } + + ctx := context.Background() + client, _, err := NewClient(cfg) + if err != nil { + t.Fatalf("new client: %v", err) + } + + out, err := client.RecognizeWithImg(ctx, "sys", "user", []string{"https://img0.baidu.com/it/u=910428455,194434251&fm=253&app=138&f=JPEG?w=1122&h=800"}) + if err != nil { + t.Fatalf("recognize with img: %v", err) + } + if out == nil || out.Content != "ok" { + t.Fatalf("unexpected content: %v", out) + } +} diff --git a/internal/pkg/validate/validate.go b/internal/pkg/validate/validate.go new file mode 100644 index 0000000..de58134 --- /dev/null +++ b/internal/pkg/validate/validate.go @@ -0,0 +1,45 @@ +package validate + +import ( + "fmt" + "github.com/go-playground/locales/zh" + ut "github.com/go-playground/universal-translator" + "github.com/go-playground/validator/v10" + zh_translations "github.com/go-playground/validator/v10/translations/zh" + "reflect" +) + +func Struct(s interface{}) (errMsg []string, err error) { + // 创建验证器实例 + validate := validator.New() + + // 创建中文翻译器 + zh_ch := zh.New() + uni := ut.New(zh_ch, zh_ch) + trans, _ := uni.GetTranslator("zh") + + //注册一个函数,获取struct tag里自定义的label作为字段名 + validate.RegisterTagNameFunc(func(fld reflect.StructField) string { + name := fld.Tag.Get("label") + return name + }) + + // 注册中文翻译器到验证器 + _ = zh_translations.RegisterDefaultTranslations(validate, trans) + + // 验证结构体 + err = validate.Struct(s) + if err != nil { + // 处理验证错误 + if _, ok := err.(*validator.InvalidValidationError); ok { + fmt.Println("处理验证错误error:", err) + errMsg = append(errMsg, err.Error()) + } else { + for _, v := range err.(validator.ValidationErrors) { + errMsg = append(errMsg, v.Translate(trans)) + } + } + } + + return +} diff --git a/internal/server/ding_talk_bot.go b/internal/server/ding_talk_bot.go new file mode 100644 index 0000000..2eb31c6 --- /dev/null +++ b/internal/server/ding_talk_bot.go @@ -0,0 +1,107 @@ +package server + +import ( + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/services" + "context" + "fmt" + "sync" + + "gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/chatbot" + "gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/client" + "github.com/go-kratos/kratos/v2/log" +) + +type DingBotServiceInterface interface { + GetServiceCfg() ([]entitys.DingTalkBot, error) + OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error) +} + +type DingTalkBotServer struct { + Clients map[string]*client.StreamClient +} + +// NewDingTalkBotServer 批量注册钉钉客户端cli +// 这里支持两种方式,一种是完全独立service,一种是直接用现成的service +// 独立的service,在本页的ProvideAllDingBotServices方法进行注册 +// 现成的service参考services->dtalk_bot.go +// 具体使用请根据实际业务需求 +func NewDingTalkBotServer( + services []DingBotServiceInterface, +) *DingTalkBotServer { + clients := make(map[string]*client.StreamClient) + for _, service := range services { + serviceConfigs, err := service.GetServiceCfg() + for _, serviceConf := range serviceConfigs { + if serviceConf.ClientId == "" || serviceConf.ClientSecret == "" { + continue + } + cli := DingBotServerInit(serviceConf.ClientId, serviceConf.ClientSecret, service) + if cli == nil { + log.Info("%s客户端初始失败:%s", serviceConf.BotIndex, err.Error()) + continue + } + clients[serviceConf.BotIndex] = cli + } + } + return &DingTalkBotServer{ + Clients: clients, + } +} + +func ProvideAllDingBotServices( + dingBotSvc *services.DingBotService, +) []DingBotServiceInterface { + return []DingBotServiceInterface{dingBotSvc} +} + +func (d *DingTalkBotServer) Run(ctx context.Context, botIndex string) { + if botIndex == "" { + log.Info("未指定机器人索引,跳过启动") + return + } + + var targets []string + switch { + case botIndex == "All": + targets = make([]string, 0, len(d.Clients)) + for name := range d.Clients { + targets = append(targets, name) + } + default: + if _, exists := d.Clients[botIndex]; exists { + targets = []string{botIndex} + } else { + log.Infof("未找到索引为 %s 的机器人", botIndex) + return + } + } + + var wg sync.WaitGroup + errors := make([]error, 0, len(targets)) + + for _, name := range targets { + wg.Add(1) + go func(name string) { + defer wg.Done() + err := d.Clients[name].Start(ctx) + if err != nil { + log.Errorf("%s 启动失败: %v", name, err) + errors = append(errors, fmt.Errorf("%s: %w", name, err)) + } else { + log.Infof("%s 启动成功", name) + } + }(name) + } + + wg.Wait() + if len(errors) > 0 { + log.Errorf("部分机器人启动失败,总数: %d, 成功: %d, 失败: %d", + len(targets), len(targets)-len(errors), len(errors)) + } +} +func DingBotServerInit(clientId string, clientSecret string, service DingBotServiceInterface) (cli *client.StreamClient) { + cli = client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(clientId, clientSecret))) + cli.RegisterChatBotCallbackRouter(service.OnChatBotMessageReceived) + return +} diff --git a/internal/server/http.go b/internal/server/http.go index 4cdc393..53446c8 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -11,24 +11,28 @@ import ( ) type HTTPServer struct { - app *fiber.App - service *services.ChatService - session *services.SessionService - gateway *gateway.Gateway - callback *services.CallbackService + app *fiber.App + service *services.ChatService + session *services.SessionService + gateway *gateway.Gateway + callback *services.CallbackService + chatHis *services.HistoryService + capabilityService *services.CapabilityService } func NewHTTPServer( - service *services.ChatService, - session *services.SessionService, - task *services.TaskService, - gateway *gateway.Gateway, - callback *services.CallbackService, + service *services.ChatService, + session *services.SessionService, + task *services.TaskService, + gateway *gateway.Gateway, + callback *services.CallbackService, + chatHis *services.HistoryService, + capabilityService *services.CapabilityService, ) *fiber.App { - //构建 server - app := initRoute() - router.SetupRoutes(app, service, session, task, gateway, callback) - return app + //构建 server + app := initRoute() + router.SetupRoutes(app, service, session, task, gateway, callback, chatHis, capabilityService) + return app } func initRoute() *fiber.App { diff --git a/internal/server/provider_set.go b/internal/server/provider_set.go index dd6b4b4..d5cef3d 100644 --- a/internal/server/provider_set.go +++ b/internal/server/provider_set.go @@ -1,5 +1,12 @@ package server -import "github.com/google/wire" +import ( + "github.com/google/wire" +) -var ProviderSetServer = wire.NewSet(NewServers, NewHTTPServer) +var ProviderSetServer = wire.NewSet( + NewServers, + NewHTTPServer, + ProvideAllDingBotServices, + NewDingTalkBotServer, +) diff --git a/internal/server/router/router.go b/internal/server/router/router.go index 5c935d7..091c85f 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -15,20 +15,30 @@ import ( ) type RouterServer struct { - app *fiber.App - service *services.ChatService - session *services.SessionService - gateway *gateway.Gateway + app *fiber.App + service *services.ChatService + session *services.SessionService + gateway *gateway.Gateway + chatHist *services.HistoryService + capabilityService *services.CapabilityService } // SetupRoutes 设置路由 -func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService, gateway *gateway.Gateway, callbackService *services.CallbackService) { +func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService, + gateway *gateway.Gateway, callbackService *services.CallbackService, chatHist *services.HistoryService, + capabilityService *services.CapabilityService, +) { app.Use(func(c *fiber.Ctx) error { // 设置 CORS 头 c.Set("Access-Control-Allow-Origin", "*") c.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") c.Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + // AI能力调用路由,设置不同的 CORS 头 + if strings.HasPrefix(c.Path(), "/api/v1/capability") { + c.Set("Access-Control-Allow-Headers", "Content-Type, X-Source-Key, X-Timestamp") + } + // 如果是预检请求(OPTIONS),直接返回 204 if c.Method() == "OPTIONS" { return c.SendStatus(fiber.StatusNoContent) // 204 @@ -42,7 +52,6 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi r := app.Group("api/v1/") registerResponse(r) - // 注册 CORS 中间件 r.Get("/health", func(c *fiber.Ctx) error { c.Response().SetBody([]byte("1")) return nil @@ -77,15 +86,20 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi return ctx.Status(400).SendString("unknown action") } }) + + // 会话历史 + r.Post("/chat/history/list", chatHist.List) + r.Post("/chat/history/update/content", chatHist.UpdateContent) + + // 能力 + r.Post("/capability/product/ingest", capabilityService.ProductIngest) // 商品数据提取 + r.Post("/capability/product/ingest/:thread_id/confirm", capabilityService.ProductIngestConfirm) // 商品数据提取确认 } func routerSocket(app *fiber.App, chatService *services.ChatService) { ws := app.Group("ws/v1/") // WebSocket 路由配置 - ws.Get("/chat", websocket.New(func(c *websocket.Conn) { - // 可以在这里添加握手前的中间件逻辑(如头校验) - chatService.Chat(c) // 调用实际的 Chat 处理函数 - }, websocket.Config{ + ws.Get("/chat", websocket.New(chatService.Chat, websocket.Config{ // 可选配置:跨域检查、最大负载大小等 HandshakeTimeout: 10 * time.Second, //Subprotocols: []string{"json", "msgpack"}, diff --git a/internal/server/server.go b/internal/server/server.go index 488ef38..02c8f84 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,13 +1,27 @@ package server -import "github.com/gofiber/fiber/v2" +import ( + "ai_scheduler/internal/config" + + "github.com/gofiber/fiber/v2" +) type Servers struct { - HttpServer *fiber.App + cfg *config.Config + HttpServer *fiber.App + DingBotServer *DingTalkBotServer } -func NewServers(fiber *fiber.App) *Servers { +func NewServers(cfg *config.Config, fiber *fiber.App, DingBotServer *DingTalkBotServer) *Servers { return &Servers{ - HttpServer: fiber, + HttpServer: fiber, + cfg: cfg, + DingBotServer: DingBotServer, } } + +//func DingBotServerInit(clientId string, clientSecret string, cfg *config.Config, handler *do.Handle, do *do.Do) (cli *client.StreamClient) { +// cli = client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(clientId, clientSecret))) +// cli.RegisterChatBotCallbackRouter(services.NewDingBotService(cfg, handler, do).OnChatBotMessageReceived) +// return +//} diff --git a/internal/services/callback.go b/internal/services/callback.go index 415f7cb..c68e1c3 100644 --- a/internal/services/callback.go +++ b/internal/services/callback.go @@ -4,12 +4,13 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" errorcode "ai_scheduler/internal/data/error" + "ai_scheduler/internal/domain/component/callback" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" "ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg/dingtalk" "ai_scheduler/internal/pkg/util" - "ai_scheduler/internal/tools_bot" + "ai_scheduler/internal/tool_callback" "context" "encoding/json" "strings" @@ -25,17 +26,17 @@ type CallbackService struct { dingtalkOldClient *dingtalk.OldClient dingtalkContactClient *dingtalk.ContactClient dingtalkNotableClient *dingtalk.NotableClient - botTool *tools_bot.BotTool + callbackManager callback.Manager } -func NewCallbackService(cfg *config.Config, gateway *gateway.Gateway, dingtalkOldClient *dingtalk.OldClient, dingtalkContactClient *dingtalk.ContactClient, dingtalkNotableClient *dingtalk.NotableClient, botTool *tools_bot.BotTool) *CallbackService { +func NewCallbackService(cfg *config.Config, gateway *gateway.Gateway, dingtalkOldClient *dingtalk.OldClient, dingtalkContactClient *dingtalk.ContactClient, dingtalkNotableClient *dingtalk.NotableClient, callbackManager callback.Manager) *CallbackService { return &CallbackService{ cfg: cfg, gateway: gateway, dingtalkOldClient: dingtalkOldClient, dingtalkContactClient: dingtalkContactClient, dingtalkNotableClient: dingtalkNotableClient, - botTool: botTool, + callbackManager: callbackManager, } } @@ -77,20 +78,21 @@ func (s *CallbackService) Callback(c *fiber.Ctx) error { ts := strings.TrimSpace(c.Get("X-Timestamp")) // 时间窗口(如果提供了 ts 则校验,否则跳过),窗口 5 分钟 - if ts != "" && !validateTimestamp(ts, 5*time.Minute) { + // if ts != "" && !validateTimestamp(ts, 5*time.Minute) { + if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) { return errorcode.AuthNotFound } // 解析 Envelope var env Envelope if err := json.Unmarshal(c.Body(), &env); err != nil { - return errorcode.ParamErr("invalid json: %v", err) + return errorcode.ParamErrf("invalid json: %v", err) } if env.Action == "" || env.TaskID == "" { - return errorcode.ParamErr("missing action/task_id") + return errorcode.ParamErrf("missing action/task_id") } if env.Data == nil { - return errorcode.ParamErr("missing data") + return errorcode.ParamErrf("missing data") } switch sourceKey { @@ -101,48 +103,51 @@ func (s *CallbackService) Callback(c *fiber.Ctx) error { } } -func validateTimestamp(ts string, window time.Duration) bool { - // 期望毫秒时间戳或秒级,简单容错 - // 尝试解析为整数 - var n int64 - for _, base := range []int64{1, 1000} { // 秒或毫秒 - if v, ok := parseInt64(ts); ok { - n = v - // 归一为毫秒 - if base == 1 && len(ts) <= 10 { - n = n * 1000 - } - now := time.Now().UnixMilli() - diff := now - n - if diff < 0 { - diff = -diff - } - if diff <= window.Milliseconds() { - return true - } - } - } - return false -} +// func validateTimestamp(ts string, window time.Duration) bool { +// // 期望毫秒时间戳或秒级,简单容错 +// // 尝试解析为整数 +// var n int64 +// for _, base := range []int64{1, 1000} { // 秒或毫秒 +// if v, ok := parseInt64(ts); ok { +// n = v +// // 归一为毫秒 +// if base == 1 && len(ts) <= 10 { +// n = n * 1000 +// } +// now := time.Now().UnixMilli() +// diff := now - n +// if diff < 0 { +// diff = -diff +// } +// if diff <= window.Milliseconds() { +// return true +// } +// } +// } +// return false +// } -func parseInt64(s string) (int64, bool) { - var n int64 - for _, ch := range s { - if ch < '0' || ch > '9' { - return 0, false - } - n = n*10 + int64(ch-'0') - } - return n, true -} +// func parseInt64(s string) (int64, bool) { +// var n int64 +// for _, ch := range s { +// if ch < '0' || ch > '9' { +// return 0, false +// } +// n = n*10 + int64(ch-'0') +// } +// return n, true +// } func (s *CallbackService) handleDingTalkCallback(c *fiber.Ctx, env Envelope) error { // 校验taskId - sessionID, ok := s.botTool.GetSessionByTaskID(env.TaskID) - if !ok { - return errorcode.ParamErr("missing session_id for task_id: %s", env.TaskID) - } ctx := c.Context() + sessionID, err := s.callbackManager.GetSession(ctx, env.TaskID) + if err != nil { + return errorcode.ParamErrf("failed to get session for task_id: %s, err: %v", env.TaskID, err) + } + if sessionID == "" { + return errorcode.ParamErrf("missing session_id for task_id: %s", env.TaskID) + } switch env.Action { case ActionBugOptimizationSubmitUpdate: @@ -165,8 +170,10 @@ func (s *CallbackService) handleDingTalkCallback(c *fiber.Ctx, env Envelope) err // 发送日志 s.sendStreamTxt(sessionID, msg) - // 删除映射 - s.botTool.DelTaskMapping(env.TaskID) + // 通知等待者 + if err := s.callbackManager.Notify(ctx, env.TaskID, msg); err != nil { + // 记录错误但继续 + } return c.JSON(fiber.Map{"code": 0, "message": "ok"}) case ActionBugOptimizationSubmitProcess: @@ -175,14 +182,14 @@ func (s *CallbackService) handleDingTalkCallback(c *fiber.Ctx, env Envelope) err } var data processData if err := json.Unmarshal(env.Data, &data); err != nil { - return errorcode.ParamErr("invalid json: %v", err) + return errorcode.ParamErrf("invalid json: %v", err) } s.sendStreamLoading(sessionID, data.Process) return c.JSON(fiber.Map{"code": 0, "message": "ok"}) default: - return errorcode.ParamErr("unknown action: %s", env.Action) + return errorcode.ParamErrf("unknown action: %s", env.Action) } } @@ -212,7 +219,7 @@ func (s *CallbackService) sendStreamLog(sessionID string, content string) { } streamLog := entitys.Response{ - Index: constants.BotToolsBugOptimizationSubmit, + Index: string(constants.BotToolsBugOptimizationSubmit), Content: content, Type: entitys.ResponseLog, } @@ -227,7 +234,7 @@ func (s *CallbackService) sendStreamTxt(sessionID string, content string) { } streamLog := entitys.Response{ - Index: constants.BotToolsBugOptimizationSubmit, + Index: string(constants.BotToolsBugOptimizationSubmit), Content: content, Type: entitys.ResponseText, } @@ -242,7 +249,7 @@ func (s *CallbackService) sendStreamLoading(sessionID string, content string) { } streamLog := entitys.Response{ - Index: constants.BotToolsBugOptimizationSubmit, + Index: string(constants.BotToolsBugOptimizationSubmit), Content: content, Type: entitys.ResponseLoading, } @@ -254,27 +261,27 @@ func (s *CallbackService) sendStreamLoading(sessionID string, content string) { func (s *CallbackService) handleBugOptimizationSubmitUpdate(ctx context.Context, taskData json.RawMessage) (string, *errorcode.BusinessErr) { var data BugOptimizationSubmitUpdateData if err := json.Unmarshal(taskData, &data); err != nil { - return "", errorcode.ParamErr("invalid data type: %v", err) + return "", errorcode.ParamErrf("invalid data type: %v", err) } if data.Creator == "" { - return "", errorcode.ParamErr("empty creator") + return "", errorcode.ParamErrf("empty creator") } // 获取创建者uid accessToken, _ := s.dingtalkOldClient.GetAccessToken() creatorId, err := s.dingtalkContactClient.SearchUserOne(accessToken, data.Creator) if err != nil { - return "", errorcode.ParamErr("invalid data type: %v", err) + return "", errorcode.ParamErrf("invalid data type: %v", err) } // 获取用户详情 userDetails, err := s.dingtalkOldClient.QueryUserDetails(ctx, creatorId) if err != nil { - return "", errorcode.ParamErr("invalid data type: %v", err) + return "", errorcode.ParamErrf("invalid data type: %v", err) } if userDetails == nil { - return "", errorcode.ParamErr("user details not found") + return "", errorcode.ParamErrf("user details not found") } unionId := userDetails.UnionID @@ -283,14 +290,14 @@ func (s *CallbackService) handleBugOptimizationSubmitUpdate(ctx context.Context, BaseId: data.BaseId, SheetId: data.SheetId, RecordId: data.RecordId, - OperatorId: tools_bot.BotBugOptimizationSubmitAdminUnionId, + OperatorId: tool_callback.BotBugOptimizationSubmitAdminUnionId, CreatorUnionId: unionId, }) if err != nil { - return "", errorcode.ParamErr("invalid data type: %v", err) + return "", errorcode.ParamErrf("invalid data type: %v", err) } if !ok { - return "", errorcode.ParamErr("update record failed") + return "", errorcode.ParamErrf("update record failed") } return "问题记录即将完成", nil @@ -300,16 +307,16 @@ func (s *CallbackService) handleBugOptimizationSubmitUpdate(ctx context.Context, func (s *CallbackService) handleBugOptimizationSubmitDone(ctx context.Context, taskData json.RawMessage) (string, *errorcode.BusinessErr) { var data BugOptimizationSubmitDoneData if err := json.Unmarshal(taskData, &data); err != nil { - return "", errorcode.ParamErr("invalid data type: %v", err) + return "", errorcode.ParamErrf("invalid data type: %v", err) } if len(data.Receivers) == 0 { - return "", errorcode.ParamErr("empty receivers") + return "", errorcode.ParamErrf("empty receivers") } // 构建接收者 receivers := s.getDingtalkReceivers(ctx, data.Receivers) if receivers == "" { - return "", errorcode.ParamErr("invalid receivers") + return "", errorcode.ParamErrf("invalid receivers") } // 构建跳转链接 diff --git a/internal/services/capability.go b/internal/services/capability.go new file mode 100644 index 0000000..759433c --- /dev/null +++ b/internal/services/capability.go @@ -0,0 +1,207 @@ +package services + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/constants" + errorcode "ai_scheduler/internal/data/error" + "ai_scheduler/internal/domain/workflow/runtime" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/util" + "ai_scheduler/internal/pkg/utils_ollama" + "ai_scheduler/utils" + "context" + "encoding/json" + "fmt" + "strings" + "time" + + hytWorkflow "ai_scheduler/internal/domain/workflow/hyt" + + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" + "github.com/ollama/ollama/api" + "github.com/redis/go-redis/v9" +) + +// CapabilityService 统一回调入口 +type CapabilityService struct { + cfg *config.Config + workflowManager *runtime.Registry + rdsCli *redis.Client +} + +func NewCapabilityService(cfg *config.Config, workflowManager *runtime.Registry, rdb *utils.Rdb) *CapabilityService { + return &CapabilityService{ + cfg: cfg, + workflowManager: workflowManager, + rdsCli: rdb.Rdb, + } +} + +// 产品数据提取入参 +type ProductIngestReq struct { + SysId string `json:"sys_id"` // 业务系统ID - 当前仅支持货易通(hyt) + Url string `json:"url"` // 商品详情页URL + Title string `json:"title"` // 商品标题 + Text string `json:"text"` // 商品描述 + Images []string `json:"images"` // 商品图片URL列表 +} + +type ProductIngestResp struct { + ThreadId string `json:"thread_id"` // 线程ID,后续确认调用时需要 + SysId string `json:"sys_id"` // 业务系统ID + MetaData any `json:"meta"` // 元数据 + Draft string `json:"draft"` // 草稿数据,后续确认调用时需要 +} + +// ProductIngest 产品数据提取 +func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error { + ctx := context.Background() + // 请求头校验 + if err := s.checkRequestHeader(c); err != nil { + return err + } + + // 解析请求参数 + req := ProductIngestReq{} + if err := c.BodyParser(&req); err != nil { + return errorcode.ParamErrf("invalid request body: %v", err) + } + // 必要参数校验 + if req.Text == "" || req.SysId == "" { + return errorcode.ParamErrf("missing required fields") + } + + // 映射目标系统商品属性中文模板 + var sysProductPropertyTemplateZH string + switch req.SysId { + case "hyt": // 货易通 + sysProductPropertyTemplateZH = constants.HYTGoodsAddPropertyTemplateZH + default: + return errorcode.ParamErrf("invalid sys_id") + } + + // 模型调用 + client, cleanup, err := utils_ollama.NewClient(s.cfg) + if err != nil { + return err + } + defer cleanup() + res, err := client.Chat(ctx, s.cfg.Ollama.MappingModel, []api.Message{ + { + Role: "system", + Content: constants.SystemPrompt, + }, + { + Role: "assistant", + Content: fmt.Sprintf("目标属性模板:%s。", sysProductPropertyTemplateZH), + }, + { + Role: "user", + Content: req.Text, + }, + { + Role: "user", + Content: "商品图片URL列表:" + strings.Join(req.Images, ","), + }, + }) + if err != nil { + return err + } + + // 生成thread_id + threadId := uuid.NewString() + resp := &ProductIngestResp{ + ThreadId: threadId, + SysId: req.SysId, + MetaData: req, + Draft: res.Message.Content, // Go中map会无序,交给前端解析 + } + respJson, _ := json.Marshal(resp) + + // 存redis缓存 + if err = s.rdsCli.Set(ctx, fmt.Sprintf(constants.CapabilityProductIngestCacheKey, threadId), respJson, 30*time.Minute).Err(); err != nil { + return err + } + + // 解析模型输出 + c.JSON(resp) + + return nil +} + +// checkRequestHeader 校验请求头 +func (s *CapabilityService) checkRequestHeader(c *fiber.Ctx) error { + // 读取头 + token := strings.TrimSpace(c.Get("X-Source-Key")) + ts := strings.TrimSpace(c.Get("X-Timestamp")) + + // 时间窗口校验 + if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) { + return errorcode.AuthNotFound + } + // token校验 + if token == "" || token != constants.CapabilityProductIngestToken { + return errorcode.KeyNotFound + } + + return nil +} + +type ProductIngestConfirmReq struct { + ThreadId string `json:"thread_id"` // 线程ID + Confirmed string `json:"confirmed"` // 已确认数据json字符串 +} + +// ProductIngestConfirm 商品数据提取确认 +func (s *CapabilityService) ProductIngestConfirm(c *fiber.Ctx) error { + ctx := context.Background() + + // 请求头校验 + if err := s.checkRequestHeader(c); err != nil { + return err + } + // 获取路径参数中的 thread_id + threadId := c.Params("thread_id") + if threadId == "" { + return errorcode.ParamErrf("missing required fields") + } + // 解析请求参数 body + req := ProductIngestConfirmReq{} + if err := c.BodyParser(&req); err != nil { + return errorcode.ParamErrf("invalid request body: %v", err) + } + // 必要参数校验 + if req.Confirmed == "" || threadId == "" { + return errorcode.ParamErr("missing required fields") + } + + // 校验线程ID是否存在 + resp, err := s.rdsCli.Get(ctx, fmt.Sprintf(constants.CapabilityProductIngestCacheKey, threadId)).Result() + if err != nil { + return errorcode.ParamErr("invalid thread_id") + } + var respData ProductIngestResp + if err = json.Unmarshal([]byte(resp), &respData); err != nil { + return errorcode.ParamErr("invalid thread_id data") + } + + // 映射目标系统工作流ID + var workflowId string + switch respData.SysId { + // 货易通 + case "hyt": + workflowId = hytWorkflow.WorkflowIDGoodsAdd + default: + return errorcode.ParamErr("invalid sys_id") + } + + // 调用eino工作流,实现商品上传到目标系统 + rec := &entitys.Recognize{UserContent: &entitys.RecognizeUserContent{Text: req.Confirmed}} + res, err := s.workflowManager.Invoke(ctx, workflowId, rec) + if err != nil { + return err + } + + return c.JSON(res) +} diff --git a/internal/services/chat.go b/internal/services/chat.go index ff75bee..eb47809 100644 --- a/internal/services/chat.go +++ b/internal/services/chat.go @@ -6,9 +6,11 @@ import ( "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" + "context" "encoding/json" "log" "sync" + "time" "github.com/gofiber/fiber/v2" "github.com/gofiber/websocket/v2" @@ -38,20 +40,6 @@ func NewChatService( } } -// ToolCallResponse 工具调用响应 -type ToolCallResponse struct { - ID string `json:"id" example:"call_1"` - Type string `json:"type" example:"function"` - Function FunctionCallResponse `json:"function"` - Result interface{} `json:"result,omitempty"` -} - -// FunctionCallResponse 函数调用响应 -type FunctionCallResponse struct { - Name string `json:"name" example:"get_weather"` - Arguments interface{} `json:"arguments"` -} - func (h *ChatService) ChatFail(c *websocket.Conn, content string) { err := c.WriteMessage(websocket.TextMessage, []byte(content)) if err != nil { @@ -63,84 +51,113 @@ func (h *ChatService) ChatFail(c *websocket.Conn, content string) { // Chat 处理WebSocket聊天连接 // 这是WebSocket处理的主入口函数 func (h *ChatService) Chat(c *websocket.Conn) { + ctx, cancel := context.WithCancel(context.Background()) + // 创建新的客户端实例 - h.mu.Lock() - client := gateway.NewClient(c) - h.mu.Unlock() - - // 将客户端添加到网关管理 - h.Gw.AddClient(client) - log.Println("client connected:", client.GetID()) - log.Println("客户端已连接") - - // 绑定会话ID - uid := c.Query("x-session") - if uid != "" { - if err := h.Gw.BindUid(client.GetID(), uid); err != nil { - log.Println("绑定UID错误:", err) - } - log.Printf("bind %s -> uid:%s\n", client.GetID(), uid) - } + client := gateway.NewClient(c, ctx, cancel) // 验证并收集连接数据,后续对话中会使用 if err := client.DataAuth(); err != nil { log.Println("数据验证错误:", err) - h.ChatFail(c, err.Error()) + _ = client.SendFunc([]byte(err.Error())) + client.Close() return } + // 验证通过后,将客户端添加到网关管理 + h.Gw.AddClient(client) + + // 使用信号量限制并发处理的消息数量 + semaphore := make(chan struct{}, 1) // 最多1个并发消息处理 + // 用于等待所有goroutine完成的wait group + var wg sync.WaitGroup // 确保在函数返回时移除客户端并关闭连接 defer func() { - h.Gw.RemoveClient(client.GetID()) - _ = c.Close() - log.Println("client disconnected:", client.GetID()) + wg.Wait() // 等待所有消息处理goroutine完成 + close(semaphore) // 关闭信号量通道 + h.Gw.Cleanup(client.GetID()) }() + // 绑定会话ID, sessionId 为空时, 则不绑定 + if uid := client.GetSession(); uid != "" { + if err := h.Gw.BindUid(client.GetID(), uid); err != nil { + log.Println("绑定UID错误:", err) + } + } + + // 开启心跳检测 + go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval)) + // 循环读取客户端消息 for { - // 读取消息 - messageType, message, err := c.ReadMessage() + messageType, message, err := client.ReadMessage() if err != nil { - log.Println("读取错误:", err) + log.Printf("读取错误: %v, 客户端ID: %s", err, client.GetID()) break } - // 处理消息 - msg, chatType := h.handleMessageToString(c, messageType, message) - if chatType == constants.ConnStatusClosed { - break - } - if chatType == constants.ConnStatusIgnore { + // 处理心跳消息 + if messageType == websocket.PingMessage || string(message) == "PING" { + client.LastActive = time.Now() + msgType := websocket.TextMessage + if messageType == websocket.PingMessage { + msgType = websocket.PongMessage + } + if err = client.SendMessage(msgType, []byte(`PONG`)); err != nil { + log.Printf("发送pong消息失败: %v", err) + } continue } - log.Printf("收到消息: %s", string(msg)) + // 使用信号量限制并发 + semaphore <- struct{}{} + wg.Add(1) + go func(msgType int, msg []byte) { + defer func() { + <-semaphore + wg.Done() + // 恢复panic + if r := recover(); r != nil { + log.Printf("消息处理goroutine发生panic: %v", r) + } + }() - // 解析请求 - var req entitys.ChatSockRequest - if err = json.Unmarshal(msg, &req); err != nil { - log.Println("JSON parse error:", err) - continue - } + // 消息处理逻辑 + h.processMessage(client, msgType, msg) + }(messageType, message) + } +} - // 路由处理请求 - err = h.routerBiz.RouteWithSocket(client, &req) - if err != nil { - log.Println("处理失败:", err) - } +// 将消息处理逻辑提取到单独的方法 +func (h *ChatService) processMessage(client *gateway.Client, msgType int, msg []byte) { + // 处理消息 + processedMsg, _ := h.handleMessageToString(client, msgType, msg) + log.Printf("收到消息:消息类型 %d, 内容 %s, 客户端ID: %s", + msgType, string(processedMsg), client.GetID()) + + // 解析请求 + var req entitys.ChatSockRequest + if err := json.Unmarshal(processedMsg, &req); err != nil { + log.Printf("JSON解析错误: %v, 客户端ID: %s", err, client.GetID()) + return + } + + // 路由处理请求 + if err := h.routerBiz.RouteWithSocket(client, &req); err != nil { + log.Printf("处理失败: %v, 客户端ID: %s", err, client.GetID()) } } // handleMessageToString 处理不同类型的WebSocket消息 // 参数: -// - c: WebSocket连接 +// - client: 客户端对象 // - msgType: 消息类型 // - msg: 消息内容 // // 返回: // - text: 处理后的文本内容 // - chatType: 连接状态 -func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constants.ConnStatus) { +func (h *ChatService) handleMessageToString(client *gateway.Client, msgType int, msg any) (text []byte, chatType constants.ConnStatus) { switch msgType { case websocket.TextMessage: return msg.([]byte), constants.ConnStatusNormal @@ -150,15 +167,13 @@ func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg return nil, constants.ConnStatusClosed case websocket.PingMessage: - // 可选:回复 Pong - c.WriteMessage(websocket.PongMessage, nil) return nil, constants.ConnStatusIgnore case websocket.PongMessage: return nil, constants.ConnStatusIgnore default: + log.Printf("未知的消息类型: %d", msgType) return nil, constants.ConnStatusIgnore } - return msg.([]byte), constants.ConnStatusIgnore } func (s *ChatService) Useful(c *fiber.Ctx) error { diff --git a/internal/services/chat_history.go b/internal/services/chat_history.go new file mode 100644 index 0000000..7bd4d75 --- /dev/null +++ b/internal/services/chat_history.go @@ -0,0 +1,63 @@ +package services + +import ( + "ai_scheduler/internal/biz" + errors "ai_scheduler/internal/data/error" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/validate" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/log" + "strings" +) + +type HistoryService struct { + chatRepo *biz.ChatHistoryBiz +} + +func NewHistoryService(chatRepo *biz.ChatHistoryBiz) *HistoryService { + return &HistoryService{ + chatRepo: chatRepo, + } +} + +// GetHistoryService 获取会话历史 +func (h *HistoryService) List(c *fiber.Ctx) error { + var query entitys.ChatHistQuery + if err := c.BodyParser(&query); err != nil { + return err + } + // 校验参数 + if query.SessionID == "" { + return errors.SessionNotFound + } + if query.Page <= 0 { + query.Page = 1 + } + if query.PageSize <= 0 { + query.PageSize = 10 + } + + // 查询历史 + history, err := h.chatRepo.List(c.Context(), &query) + if err != nil { + return err + } + + return c.JSON(history) +} + +func (h *HistoryService) UpdateContent(c *fiber.Ctx) error { + var req entitys.UpdateContentRequest + if err := c.BodyParser(&req); err != nil { + return err + } + // 校验参数 + msg, err := validate.Struct(req) + if err != nil { + log.Error(c.UserContext(), "参数错误 error: ", err) + return errors.NewBusinessErr(errors.InvalidParamCode, strings.Join(msg, ";")) + } + + // 更新历史 + return h.chatRepo.UpdateContent(c.Context(), &req) +} diff --git a/internal/services/dtalk_bot.go b/internal/services/dtalk_bot.go new file mode 100644 index 0000000..b71e40b --- /dev/null +++ b/internal/services/dtalk_bot.go @@ -0,0 +1,137 @@ +package services + +import ( + "ai_scheduler/internal/biz" + "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" + "context" + "log" + "sync" + "time" + + "gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/chatbot" + "golang.org/x/sync/errgroup" +) + +type DingBotService struct { + config *config.Config + dingTalkBotBiz *biz.DingTalkBotBiz +} + +func NewDingBotService(config *config.Config, dingTalkBotBiz *biz.DingTalkBotBiz) *DingBotService { + return &DingBotService{ + config: config, + dingTalkBotBiz: dingTalkBotBiz, + } +} + +func (d *DingBotService) GetServiceCfg() ([]entitys.DingTalkBot, error) { + return d.dingTalkBotBiz.GetDingTalkBotCfgList() +} + +func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) { + requireData, err := d.dingTalkBotBiz.InitRequire(ctx, data) + if err != nil { + return nil, err + } + + // 启动后台任务(独立生命周期,带超时控制) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + if err := d.runBackgroundTasks(ctx, data, requireData); err != nil { + log.Printf("后台任务执行失败: %v", err) + } + }() + + return []byte("success"), nil +} + +func (d *DingBotService) runBackgroundTasks(ctx context.Context, data *chatbot.BotCallbackDataModel, requireData *entitys.RequireDataDingTalkBot) error { + g, ctx := errgroup.WithContext(ctx) + var ( + chat []string + chatMu sync.Mutex + resChan = make(chan string, 10) + ) + + // 1. 流式处理协程 + g.Go(func() error { + defer func() { + // 确保通道最终关闭 + close(resChan) + }() + return d.dingTalkBotBiz.HandleStreamRes(ctx, data, resChan) + }) + + // 2. 业务处理协程(负责关闭requireData.Ch) + g.Go(func() error { + // 在完成时关闭通道 + defer close(requireData.Ch) + + //entitys.ResLoading(requireData.Ch, "", "![图片](") + //entitys.ResLoading(requireData.Ch, "", "https://p6-img.") + //entitys.ResLoading(requireData.Ch, "", "searchpstatp.com/") + //entitys.ResLoading(requireData.Ch, "", "tos-cn-i-vvloioitz3/ab5ae998d8162b431f44fb2a0ed9ae33~tplv-vvloioitz3-6:190:124.jpeg)") + + return d.dingTalkBotBiz.Do(ctx, requireData) + }) + + // 3. 结果收集协程(修改后的版本) + resultDone := make(chan struct{}) + g.Go(func() error { + // 使用defer确保通道关闭 + defer close(resultDone) + // 处理通道中的数据 + for { + select { + case resp, ok := <-requireData.Ch: + if !ok { + return nil // 通道已关闭,正常退出 + } + if resp.Type != entitys.ResponseLog { + chatMu.Lock() + chat = append(chat, resp.Content) + chatMu.Unlock() + + select { + case resChan <- resp.Content: + case <-ctx.Done(): + return ctx.Err() + } + } + case <-ctx.Done(): + return ctx.Err() // 上下文取消,提前退出 + } + } + }) + + // 4. 统一关闭通道的协程(只关闭resChan) + g.Go(func() error { + <-resultDone + // resChan已在流式处理协程关闭 + return nil + }) + + // 5. 历史记录保存协程 + g.Go(func() error { + <-resultDone + chatMu.Lock() + savedChat := make([]string, len(chat)) + copy(savedChat, chat) + chatMu.Unlock() + + if err := d.dingTalkBotBiz.SaveHis(ctx, requireData, savedChat); err != nil { + log.Printf("保存历史记录失败: %v", err) + return err + } + return nil + }) + + // 阻塞直到所有协程完成或出错 + if err := g.Wait(); err != nil { + return err + } + + return nil +} diff --git a/internal/services/dtalk_bot.go.bak b/internal/services/dtalk_bot.go.bak new file mode 100644 index 0000000..75c2c7f --- /dev/null +++ b/internal/services/dtalk_bot.go.bak @@ -0,0 +1,130 @@ +package services + +import ( + "ai_scheduler/internal/biz" + "log" + "sync" + "time" + + "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" + "context" + + "gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/chatbot" +) + +type DingBotService struct { + config *config.Config + dingTalkBotBiz *biz.DingTalkBotBiz +} + +func NewDingBotService(config *config.Config, DingTalkBotBiz *biz.DingTalkBotBiz) *DingBotService { + return &DingBotService{config: config, dingTalkBotBiz: DingTalkBotBiz} +} + +func (d *DingBotService) GetServiceCfg() ([]entitys.DingTalkBot, error) { + return d.dingTalkBotBiz.GetDingTalkBotCfgList() +} + +func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) { + var ( + lastErr error + chat []string + streamWG sync.WaitGroup + resChan = make(chan string, 100) // 缓冲通道防止阻塞 + ) + + // 初始化请求 + requireData, err := d.dingTalkBotBiz.InitRequire(ctx, data) + if err != nil { + return nil, err + } + + // 创建子上下文用于控制goroutine生命周期 + subCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // 启动流式处理goroutine + streamWG.Add(1) + go func() { + defer streamWG.Done() + err = d.dingTalkBotBiz.HandleStreamRes(subCtx, data, resChan) + if err != nil { + return + } + }() + + // 启动业务处理goroutine + done := make(chan error, 1) + go func() { + done <- d.dingTalkBotBiz.Do(subCtx, requireData) + }() + + // 主处理循环 + for { + select { + case <-ctx.Done(): + lastErr = ctx.Err() + goto cleanup + + case resp, ok := <-requireData.Ch: + if !ok { + goto cleanup + } + + // 处理不同类型响应 + switch resp.Type { + case entitys.ResponseLog: + // 忽略日志类型 + continue + + //case entitys.ResponseText, entitys.ResponseJson: + // chat = append(chat, resp.Content) + // if err := d.dingTalkBotBiz.ReplyText(ctx, data.SessionWebhook, resp.Content); err != nil { + // log.Printf("处理非流响应失败: %v", err) + // lastErr = err + // } + + default: + chat = append(chat, resp.Content) + select { + case resChan <- resp.Content: + case <-ctx.Done(): + lastErr = ctx.Err() + goto cleanup + } + } + } + } + +cleanup: + streamWG.Wait() + // 关闭流式通道 + close(resChan) + + // 保存历史记录 + if saveErr := d.dingTalkBotBiz.SaveHis(ctx, requireData, chat); saveErr != nil { + log.Printf("保存历史记录失败: %v", saveErr) + if lastErr == nil { + lastErr = saveErr + } + } + + // 等待业务处理完成(带超时) + select { + case err := <-done: + if err != nil { + log.Printf("业务处理失败: %v", err) + if lastErr == nil { + lastErr = err + } + } + case <-time.After(3 * time.Second): // 增加超时时间 + log.Println("警告:等待业务处理超时,可能发生goroutine泄漏") + } + + if lastErr != nil { + return nil, lastErr + } + return []byte("success"), nil +} diff --git a/internal/services/provider_set.go b/internal/services/provider_set.go index 0e1284b..375a886 100644 --- a/internal/services/provider_set.go +++ b/internal/services/provider_set.go @@ -6,4 +6,12 @@ import ( "github.com/google/wire" ) -var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService) +var ProviderSetServices = wire.NewSet( + NewChatService, + NewSessionService, gateway.NewGateway, + NewTaskService, + NewCallbackService, + NewDingBotService, + NewHistoryService, + NewCapabilityService, +) diff --git a/internal/tools_bot/bug_optimization_submit.go b/internal/tool_callback/bug_optimization_submit.go similarity index 72% rename from internal/tools_bot/bug_optimization_submit.go rename to internal/tool_callback/bug_optimization_submit.go index 1bb2d31..57a5e73 100644 --- a/internal/tools_bot/bug_optimization_submit.go +++ b/internal/tool_callback/bug_optimization_submit.go @@ -1,10 +1,13 @@ -package tools_bot +package tool_callback import ( + "ai_scheduler/internal/config" errors "ai_scheduler/internal/data/error" + "ai_scheduler/internal/data/impl" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/utils_ollama" "context" "encoding/json" "fmt" @@ -15,6 +18,23 @@ import ( "xorm.io/builder" ) +type CallBackTool struct { + config *config.Config + llm *utils_ollama.Client + sessionImpl *impl.SessionImpl + taskMap map[string]string +} + +// NewBotTool 创建直连天下订单详情工具 +func NewCallBackTool(config *config.Config, llm *utils_ollama.Client, sessionImpl *impl.SessionImpl) *CallBackTool { + return &CallBackTool{config: config, llm: llm, sessionImpl: sessionImpl, taskMap: make(map[string]string)} +} + +// Execute 执行直连天下订单详情查询 +func (w *CallBackTool) Execute(ctx context.Context, toolName string, requireData *entitys.RequireData) (err error) { + return +} + // BugOptimizationSubmitForm 工单提交表单参数 type BugOptimizationSubmitForm struct { Mark string `json:"mark"` // 工单标识 @@ -34,13 +54,13 @@ const ( ) // BugOptimizationSubmit 工单提交 -func (w *BotTool) BugOptimizationSubmit(ctx context.Context, requireData *entitys.RequireData) (err error) { +func (w *CallBackTool) BugOptimizationSubmit(ctx context.Context, requireData *entitys.RequireData) (err error) { // 获取用户信息 cond := builder.NewCond() cond = cond.And(builder.Eq{"session_id": requireData.Session}) sessionInfo, err := w.sessionImpl.GetOneBySearch(&cond) if err != nil { - err = errors.SysErr("获取会话信息失败:%v", err.Error()) + err = errors.SysErrf("获取会话信息失败:%v", err.Error()) return } userName := sessionInfo["user_name"].(string) @@ -98,7 +118,7 @@ func (w *BotTool) BugOptimizationSubmit(ctx context.Context, requireData *entity // SetTaskMapping 设置 task_id 到 session_id 的映射(内存版)。 // 后续考虑使用 Redis,确保幂等与过期清理。 -func (w *BotTool) SetTaskMapping(taskID, sessionID string) { +func (w *CallBackTool) SetTaskMapping(taskID, sessionID string) { if taskID == "" || sessionID == "" { return } @@ -106,12 +126,12 @@ func (w *BotTool) SetTaskMapping(taskID, sessionID string) { } // GetSessionByTaskID 读取映射 -func (w *BotTool) GetSessionByTaskID(taskID string) (string, bool) { +func (w *CallBackTool) GetSessionByTaskID(taskID string) (string, bool) { v, ok := w.taskMap[taskID] return v, ok } // DelTaskMapping 删除 task_id 到 session_id 的映射(内存版)。 -func (w *BotTool) DelTaskMapping(taskID string) { +func (w *CallBackTool) DelTaskMapping(taskID string) { delete(w.taskMap, taskID) } diff --git a/internal/tool_callback/provider_set.go b/internal/tool_callback/provider_set.go new file mode 100644 index 0000000..c2671d4 --- /dev/null +++ b/internal/tool_callback/provider_set.go @@ -0,0 +1,9 @@ +package tool_callback + +import ( + "github.com/google/wire" +) + +var ProviderSetCallBackTools = wire.NewSet( + NewCallBackTool, +) diff --git a/internal/tools/calculator.go b/internal/tools/calculator.go deleted file mode 100644 index fecde64..0000000 --- a/internal/tools/calculator.go +++ /dev/null @@ -1,121 +0,0 @@ -package tools - -import ( - "ai_scheduler/internal/entitys" - - "context" - "encoding/json" - "fmt" - "math" -) - -// CalculatorTool 计算器工具 -type CalculatorTool struct{} - -// NewCalculatorTool 创建计算器工具 -func NewCalculatorTool() *CalculatorTool { - return &CalculatorTool{} -} - -// Name 返回工具名称 -func (c *CalculatorTool) Name() string { - return "calculate" -} - -// Description 返回工具描述 -func (c *CalculatorTool) Description() string { - return "执行基本的数学运算,支持加减乘除和幂运算" -} - -// Definition 返回工具定义 -func (c *CalculatorTool) Definition() entitys.ToolDefinition { - return entitys.ToolDefinition{ - Type: "function", - Function: entitys.FunctionDef{ - Name: c.Name(), - Description: c.Description(), - Parameters: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "operation": map[string]interface{}{ - "type": "string", - "description": "运算类型", - "enum": []string{"add", "subtract", "multiply", "divide", "power"}, - }, - "a": map[string]interface{}{ - "type": "number", - "description": "第一个数字", - }, - "b": map[string]interface{}{ - "type": "number", - "description": "第二个数字", - }, - }, - "required": []string{"operation", "a", "b"}, - }, - }, - } -} - -// CalculateRequest 计算请求参数 -type CalculateRequest struct { - Operation string `json:"operation"` - A float64 `json:"a"` - B float64 `json:"b"` -} - -// CalculateResponse 计算响应 -type CalculateResponse struct { - Operation string `json:"operation"` - A float64 `json:"a"` - B float64 `json:"b"` - Result float64 `json:"result"` - Expression string `json:"expression"` -} - -// Execute 执行计算 -func (c *CalculatorTool) Execute(ctx context.Context, args json.RawMessage) (interface{}, error) { - var req CalculateRequest - if err := json.Unmarshal(args, &req); err != nil { - return nil, fmt.Errorf("invalid calculate request: %w", err) - } - - var result float64 - var expression string - - switch req.Operation { - case "add": - result = req.A + req.B - expression = fmt.Sprintf("%.2f + %.2f = %.2f", req.A, req.B, result) - case "subtract": - result = req.A - req.B - expression = fmt.Sprintf("%.2f - %.2f = %.2f", req.A, req.B, result) - case "multiply": - result = req.A * req.B - expression = fmt.Sprintf("%.2f × %.2f = %.2f", req.A, req.B, result) - case "divide": - if req.B == 0 { - return nil, fmt.Errorf("division by zero is not allowed") - } - result = req.A / req.B - expression = fmt.Sprintf("%.2f ÷ %.2f = %.2f", req.A, req.B, result) - case "power": - result = math.Pow(req.A, req.B) - expression = fmt.Sprintf("%.2f ^ %.2f = %.2f", req.A, req.B, result) - default: - return nil, fmt.Errorf("unsupported operation: %s", req.Operation) - } - - // 检查结果是否有效 - if math.IsNaN(result) || math.IsInf(result, 0) { - return nil, fmt.Errorf("calculation resulted in invalid number") - } - - return &CalculateResponse{ - Operation: req.Operation, - A: req.A, - B: req.B, - Result: result, - Expression: expression, - }, nil -} diff --git a/internal/tools/manager.go b/internal/tools/manager.go index 2ebbb69..eedb4ee 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -2,9 +2,9 @@ package tools import ( "ai_scheduler/internal/config" - "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/utils_ollama" + "ai_scheduler/internal/tools/public" zltxtool "ai_scheduler/internal/tools/zltx" "context" @@ -24,50 +24,32 @@ func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager { llm: llm, } - // 注册天气工具 - //if config.Tools.Weather.Enabled { - // weatherTool := NewWeatherTool() - // m.tools[weatherTool.Name()] = weatherTool - //} - // - //// 注册计算器工具 - //if config.Tools.Calculator.Enabled { - // calcTool := NewCalculatorTool() - // m.tools[calcTool.Name()] = calcTool - //} - - // 注册知识库工具 - // if config.Knowledge.Enabled { - // knowledgeTool := NewKnowledgeTool() - // m.tools[knowledgeTool.Name()] = knowledgeTool - // } - // 注册直连天下订单详情工具 if config.Tools.ZltxOrderDetail.Enabled { - zltxOrderDetailTool := NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm) + zltxOrderDetailTool := zltxtool.NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm) m.tools[zltxOrderDetailTool.Name()] = zltxOrderDetailTool } //注册直连天下订单日志工具 if config.Tools.ZltxOrderDirectLog.Enabled { - zltxOrderLogTool := NewZltxOrderLogTool(config.Tools.ZltxOrderDirectLog) + zltxOrderLogTool := zltxtool.NewZltxOrderLogTool(config.Tools.ZltxOrderDirectLog) m.tools[zltxOrderLogTool.Name()] = zltxOrderLogTool } //注册直连天下商品工具 if config.Tools.ZltxProduct.Enabled { - zltxProductTool := NewZltxProductTool(config.Tools.ZltxProduct) + zltxProductTool := zltxtool.NewZltxProductTool(config.Tools.ZltxProduct) m.tools[zltxProductTool.Name()] = zltxProductTool } //注册直连天下订单统计工具 if config.Tools.ZltxOrderStatistics.Enabled { - zltxOrderStatisticsTool := NewZltxOrderStatisticsTool(config.Tools.ZltxOrderStatistics) + zltxOrderStatisticsTool := zltxtool.NewZltxOrderStatisticsTool(config.Tools.ZltxOrderStatistics) m.tools[zltxOrderStatisticsTool.Name()] = zltxOrderStatisticsTool } // 注册知识库工具 if config.Tools.Knowledge.Enabled { - knowledgeTool := NewKnowledgeBaseTool(config.Tools.Knowledge) + knowledgeTool := public.NewKnowledgeBaseTool(config.Tools.Knowledge) m.tools[knowledgeTool.Name()] = knowledgeTool } @@ -86,9 +68,24 @@ func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager { zltxOrderAfterSaleResellerBatchTool := zltxtool.NewOrderAfterSaleResellerBatchTool(config.Tools.ZltxOrderAfterSaleResellerBatch) m.tools[zltxOrderAfterSaleResellerBatchTool.Name()] = zltxOrderAfterSaleResellerBatchTool } + // 注册天气工具 + if config.Tools.Weather.Enabled { + weatherTool := public.NewWeatherTool(config.Tools.Weather) + m.tools[weatherTool.Name()] = weatherTool + } + // 注册 Coze 快递查询工具 + if config.Tools.CozeExpress.Enabled { + cozeTool := public.NewCozeExpress(config.Tools.CozeExpress, m.llm) + m.tools[cozeTool.Name()] = cozeTool + } + // 注册 Coze 公司查询工具 + if config.Tools.CozeCompany.Enabled { + cozeTool := public.NewCozeCompany(config.Tools.CozeCompany, m.llm) + m.tools[cozeTool.Name()] = cozeTool + } // 普通对话 - chat := NewNormalChatTool(m.llm, config) + chat := public.NewNormalChatTool(m.llm, config) m.tools[chat.Name()] = chat return m @@ -100,63 +97,12 @@ func (m *Manager) GetTool(name string) (entitys.Tool, bool) { return tool, exists } -// GetAllTools 获取所有工具 -func (m *Manager) GetAllTools() []entitys.Tool { - tools := make([]entitys.Tool, 0, len(m.tools)) - for _, tool := range m.tools { - tools = append(tools, tool) - } - return tools -} - -// GetToolDefinitions 获取所有工具定义 -func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition { - definitions := make([]entitys.ToolDefinition, 0, len(m.tools)) - for _, tool := range m.tools { - definitions = append(definitions, tool.Definition()) - } - - return definitions -} - // ExecuteTool 执行工具 -func (m *Manager) ExecuteTool(ctx context.Context, name string, requireData *entitys.RequireData) error { +func (m *Manager) ExecuteTool(ctx context.Context, name string, rec *entitys.Recognize) error { tool, exists := m.GetTool(name) if !exists { return fmt.Errorf("tool not found: %s", name) } - return tool.Execute(ctx, requireData) + return tool.Execute(ctx, rec) } - -// ExecuteToolCalls 执行多个工具调用 -//func (m *Manager) ExecuteToolCalls(c *websocket.Conn, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) { -// results := make([]entitys.ToolCall, len(toolCalls)) -// -// for i, toolCall := range toolCalls { -// results[i] = toolCall -// -// // 执行工具 -// err := m.ExecuteTool(c, toolCall.Function.Name, toolCall.Function.Arguments) -// if err != nil { -// // 将错误信息作为结果返回 -// errorResult := map[string]interface{}{ -// "error": err.Error(), -// } -// resultBytes, _ := json.Marshal(errorResult) -// results[i].Result = resultBytes -// } else { -// // 将成功结果序列化 -// resultBytes, err := json.Marshal(result) -// if err != nil { -// errorResult := map[string]interface{}{ -// "error": fmt.Sprintf("failed to serialize result: %v", err), -// } -// resultBytes, _ = json.Marshal(errorResult) -// } -// results[i].Result = resultBytes -// } -// } -// -// return results, nil -//} diff --git a/internal/tools/public/coze_company.go b/internal/tools/public/coze_company.go new file mode 100644 index 0000000..e061965 --- /dev/null +++ b/internal/tools/public/coze_company.go @@ -0,0 +1,265 @@ +package public + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/utils_ollama" + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/ollama/ollama/api" + + "github.com/coze-dev/coze-go" +) + +type CozeCompany struct { + cozeApi coze.CozeAPI + config config.ToolConfig + llm *utils_ollama.Client +} + +// NewCoze 创建 Coze 实例 +func NewCozeCompany(config config.ToolConfig, llm *utils_ollama.Client) *CozeCompany { + return &CozeCompany{ + cozeApi: newCozeApi(config), + config: config, + llm: llm, + } +} + +// newCozeClient 创建 Coze 客户端 +func newCozeApi(config config.ToolConfig) coze.CozeAPI { + authCli := coze.NewTokenAuth(config.APISecret) + cozeApi := coze.NewCozeAPI(authCli, coze.WithBaseURL(config.BaseURL), coze.WithHttpClient(&http.Client{ + Timeout: time.Second * 120, + })) + return cozeApi +} + +// Name 返回工具名称 +func (c *CozeCompany) Name() string { + return "coze_company" +} + +// Description 返回工具描述 +func (c *CozeCompany) Description() string { + return "查询企业信息" +} + +// Definition 返回工具定义 +func (c *CozeCompany) Definition() entitys.ToolDefinition { + return entitys.ToolDefinition{ + Type: "function", + Function: entitys.FunctionDef{ + Name: c.Name(), + Description: c.Description(), + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "company_name": map[string]interface{}{ + "type": "string", + "description": "企业名称", + }, + }, + "required": []string{"company_name"}, + }, + }, + } +} + +// Execute 执行查询 +func (c *CozeCompany) Execute(ctx context.Context, requireData *entitys.Recognize) error { + var req map[string]interface{} + if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + return fmt.Errorf("invalid express request: %w", err) + } + + if req["company_name"] == "" { + return fmt.Errorf("company_name is required") + } + + // 调用 Coze 工作流 + rsp, err := c.callWorkflow(ctx, req) + if err != nil { + return fmt.Errorf("failed to get real weather: %w", err) + } + + companyInfo := CompanyInfo{} + err = json.Unmarshal([]byte(rsp.Data), &companyInfo) + if err != nil { + return fmt.Errorf("failed to unmarshal company info: %w", err) + } + + // 调用 LLM 模型 + err = c.llm.ChatStream(ctx, requireData.Ch, []api.Message{ + { + Role: "system", + Content: `# Role: 企业信息分析与经营诊断专家: +请基于以下12项指定数据字段(无需补充未提供的信息),完成目标企业的全维度分析总结,要求每部分结论必须100%锚定对应数据,拒绝主观推测,突出“风险可见性”与“关键信息关联性”: +一、输入数据清单(需逐一对应分析) +行政处罚:公司是否有行政处罚(含处罚事由、处罚机关、处罚日期、处罚文号) +清算信息:公司的清算信息(含清算原因、清算组构成、清算进展状态) +变更记录:公司的变更记录(含变更事项:注册资本/股东/法定代表人/经营范围、变更时间、变更前后内容) +主要成员:公司名搜索公司的主要成员(含姓名、职位、任职时间、核心履历关键词) +企业详情:公司名称,搜索公司详细信息(含成立时间、注册资本、经营范围、行业分类、注册地址) +经营异常:公司是否有经营异常(含列入原因、列入日期、移出状态) +破产重组:公司破产重组的信息(含申请法院、受理时间、重组方案核心内容) +严重违法:是否有严重违法信息(含违法事由、认定机关、公示期限) +司法信息:司法信息(含案件类型、原被告身份、判决结果/进展) +被执行信息:公司的被执行信息(含执行法院、执行标的、未履行金额、失信状态) +企业手机号:公司名称查询企业手机号(需标注是否为公开备案号) +股东信息:公司对应的股东(含股东名称、出资额、出资比例、股东类型:自然人/企业/机构) +二、分析总结框架(严格对应数据) +1. 企业基础画像(锚定公司名称,搜索公司详细信息) +核心属性:成立时间、注册资本(实缴/认缴)、行业分类(如“批发和零售业”“软件和信息技术服务业”)、主营业务(从经营范围中提炼1-2个核心赛道,如“专注于智能仓储设备研发与销售”); +注册地址:是否与主要经营地一致(若公司名称,搜索公司详细信息有披露)。 +2. 股权结构与股东特征(锚定公司对应的股东) +股权集中度:前三大股东出资占比之和(如“第一大股东持股60%,为绝对控股股东”); +股东类型分布:自然人股东、企业股东、机构股东的占比(如“70%为企业股东,含1家行业头部企业”); +关键股东亮点:若有知名企业/机构股东,需明确标注(如“股东含XX产业投资基金,具备产业链资源协同潜力”)。 +3. 经营稳定性与合规风险(锚定公司是否有行政处罚/公司是否有经营异常/是否有严重违法信息/公司的清算信息/公司破产重组的信息) +行政处罚风险:是否存在行政处罚? +若有:列清「处罚事由(如“虚假宣传”“税务逾期申报”)、处罚机关、处罚日期」,并判断“是否属于高频违规”(如1年内≥2次同类处罚→高风险); +若无:标注“无公开行政处罚记录”。 +经营异常风险:是否存在经营异常? +若有:说明「列入原因(如“通过登记的住所无法联系”“未公示年度报告”)、是否已移出」,并评估对业务的影响(如“地址失联可能导致客户信任下降”); +若无:标注“无经营异常记录”。 +严重违法红线:是否存在严重违法? +若有:明确「违法事由(如“拒不履行生效法律文书”“欺诈消费者”)、认定机关、公示期限」,并标注“触及监管红线,需重点核查整改情况”; +若无:标注“无严重违法记录”。 +极端状态预警:是否存在清算或破产重组? +若有清算:说明「清算原因(如“股东会决议解散”“营业期限届满”)、清算进展(如“清算组已完成资产清查”)」; +若有破产重组:说明「申请法院、受理时间、重组方案核心内容(如“拟引入战略投资者注资5000万”)」; +若无:标注“无清算或破产重组记录”。 +4. 司法与执行风险(锚定司法信息/公司的被执行信息) +涉诉情况:司法案件的核心特征(如“80%为买卖合同纠纷,20%为劳动争议仲裁”;“作为被告的案件占比75%”;“判决胜诉率约60%”); +被执行压力:是否存在被执行信息? +若有:列清「执行法院、执行标的金额、未履行金额、是否纳入失信被执行人名单」,并计算“未履行金额占注册资本的比例”(如“未履行800万,占注册资本的20%”); +若无:标注“无公开被执行记录”。 +5. 管理团队与联系信息(锚定公司名搜索公司的主要成员/公司名称查询企业手机号) +核心团队稳定性:主要成员的任职时间分布(如“CEO任职4年,CFO任职2年,核心团队平均任职3年”);是否有频繁变动(如“1年内2位高管离职”→需提示“管理稳定性风险”); +关键岗位资质:核心成员(如法定代表人、CEO、CFO)的过往履历亮点(如“CEO曾任职XX上市公司,主导过亿元级项目落地”); +联系信息可信度:企业手机号是否为公开备案(如“与工商登记预留电话一致→可信度高”;“为非备案号→需提示‘联系信息真实性存疑’”)。 +6. 综合结论与行动建议(整合所有数据) +整体风险评级:基于数据密度,给出“低风险/中风险/高风险”定性(示例:“无行政处罚、无被执行、仅1条普通合同纠纷→低风险”;“有失信被执行+严重违法→高风险”); +Top3核心风险:按“严重违法>被执行>破产重组>行政处罚>经营异常>司法纠纷”排序,列出最需关注的3个问题(如“1. 未履行金额占注册资本20%,存在债务违约风险;2. 1年内2次地址异常,经营稳定性弱;3. 自然人股东占比过高,决策易受个人因素影响”); + actionable 建议:针对每个核心风险,给出可执行的核查/应对动作(如“核查未履行金额对应的案件进展,评估企业偿债能力;要求企业提供近1年的地址证明,确认经营场所稳定性;穿透核查自然人股东的资产状况,降低决策风险”)。 +三、输出规则(强制遵守) +数据溯源:每句结论必须标注对应数据字段(如“根据公司是否有行政处罚,企业2023年因‘税务逾期申报’被区税务局处罚”); +量化优先:拒绝模糊表述(如不说“很多案件”,要说“近1年涉及5起买卖合同纠纷”); +风险分级:用“★”标注风险等级(★越多越严重,如“严重违法★★★”“被执行★★”“经营异常★”); +语言风格:专业简洁,避免冗余,适合风控/投资/合作前的快速决策阅读。 + +【示例输出片段】 +3. 经营稳定性与合规风险 + +行政处罚:根据公司是否有行政处罚,企业2022年11月因“发布虚假广告”被区市场监管局处以3万元罚款(文号:X市监罚字〔2022〕456号),无后续同类处罚→风险等级★; + +经营异常:根据公司是否有经营异常,企业2023年6月因“通过登记的住所无法联系”被列入异常名录,2023年12月已移出→风险等级★; + +严重违法:根据是否有严重违法信息,无公开严重违法记录→风险等级无; + +极端状态:根据公司的清算信息/公司破产重组的信息,无清算或破产重组记录→风险等级无。 + +6. 综合结论与行动建议 + +整体评级:低风险(仅1次轻微行政处罚,无重大合规瑕疵); + +Top3核心风险:1. 根据司法信息,近1年作为被告的合同纠纷占比75%,需警惕应收账款回收风险★★;2. 根据公司名搜索公司的主要成员,1年内2位销售总监离职,管理稳定性弱★;3. 根据公司名称查询企业手机号,企业手机号为非备案号,联系信息真实性存疑★; + +建议:核查合同纠纷案件的原告身份及回款情况,评估坏账概率;要求企业提供离职人员的交接说明,确认业务连续性;索要企业备案的联系方式,验证沟通有效性。 +`, + }, + + { + Role: "assistant", + Content: fmt.Sprintf(`请分析企业:%s, +公司是否有行政处罚:%v, +公司的清算信息:%v, +公司的变更记录:%v, +公司名搜索公司的主要成员:%v, +公司名称,搜索公司详细信息:%v, +公司是否有经营异常:%v, +公司破产重组的信息:%v, +是否有严重违法信息:%v, +司法信息:%v, +公司的被执行信息:%v, +公司名称查询企业手机号:%v, +公司对应的股东:%v`, req["company_name"], companyInfo.Xzcf, companyInfo.Clears, companyInfo.Changes, companyInfo.Employees, companyInfo.Searchdata, companyInfo.Operations, companyInfo.BankruptcyPublicList, companyInfo.Illegals, companyInfo.JudicialList, companyInfo.Executes, companyInfo.Phone, companyInfo.Partners), + }, + { + Role: "user", + Content: requireData.UserContent.Text, + }, + }, + c.Name(), "") + if err != nil { + return fmt.Errorf("failed to get express info: %w", err) + } + + //entitys.ResText(requireData.Ch, "", rsp.Data) + + return nil +} + +// CallWorkflow 调用 Coze 工作流 +// 参数: +// - ctx: 上下文,用于控制超时和取消 +// - workflowId: 工作流 ID +// - params: 工作流参数 +// 返回: +// - interface{}: 工作流执行结果 +// - error: 错误信息 +func (c *CozeCompany) callWorkflow(ctx context.Context, params map[string]interface{}) (*coze.RunWorkflowsResp, error) { + // 准备工作流请求参数 + workflowReq := &coze.RunWorkflowsReq{ + WorkflowID: c.config.APIKey, + Parameters: params, + } + + // 调用工作流 + resp, err := c.cozeApi.Workflows.Runs.Create(ctx, workflowReq) + if err != nil { + return nil, fmt.Errorf("工作流调用失败: %w", err) + } + + // 处理工作流响应 + if resp == nil { + return nil, fmt.Errorf("工作流响应为空") + } + + // 返回工作流执行结果 + return resp, nil +} + +type CompanyInfo struct { + BankruptcyPublicList interface{} `json:"bankruptcy_public_list"` // 破产公示列表 + Changes interface{} `json:"changes"` // 变更记录 + Clears interface{} `json:"clears"` // 清算记录 + Employees interface{} `json:"employees"` // 员工列表 + Executes interface{} `json:"executes"` // 执行记录 + Illegals interface{} `json:"illegals"` // 违法记录 + JudicialList interface{} `json:"judicial_list"` // 司法记录 + Operations interface{} `json:"operations"` // 经营记录 + Partners interface{} `json:"partners"` // 合伙人列表 + Phone string `json:"phone"` // 联系电话 + // 搜索数据 + Searchdata struct { + Authority interface{} `json:"authority"` + BusinessScope interface{} `json:"business_scope"` + Capital interface{} `json:"capital"` + CompanyAddress interface{} `json:"company_address"` + CompanyName string `json:"company_name"` + CompanyStatus interface{} `json:"company_status"` + CompanyType interface{} `json:"company_type"` + CreditNo interface{} `json:"credit_no"` + EstablishDate interface{} `json:"establish_date"` + Industry interface{} `json:"industry"` + LegalPerson interface{} `json:"legal_person"` + Province interface{} `json:"province"` + } `json:"searchdata"` + Xzcf interface{} `json:"xzcf"` // 行政处罚 +} diff --git a/internal/tools/public/coze_express.go b/internal/tools/public/coze_express.go new file mode 100644 index 0000000..58e6172 --- /dev/null +++ b/internal/tools/public/coze_express.go @@ -0,0 +1,141 @@ +package public + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/utils_ollama" + "context" + "encoding/json" + "fmt" + + "github.com/ollama/ollama/api" + + "github.com/coze-dev/coze-go" +) + +type CozeExpress struct { + cozeApi coze.CozeAPI + config config.ToolConfig + llm *utils_ollama.Client +} + +// NewCozeExpress 创建 CozeExpress 实例 +func NewCozeExpress(config config.ToolConfig, llm *utils_ollama.Client) *CozeExpress { + return &CozeExpress{ + cozeApi: newCozeApi(config), + config: config, + llm: llm, + } +} + +// newCozeExpressClient 创建 CozeExpress 客户端 +func newCozeExpressApi(config config.ToolConfig) coze.CozeAPI { + authCli := coze.NewTokenAuth(config.APISecret) + cozeApi := coze.NewCozeAPI(authCli, coze.WithBaseURL(config.BaseURL)) + return cozeApi +} + +// Name 返回工具名称 +func (c *CozeExpress) Name() string { + return "coze_express" +} + +// Description 返回工具描述 +func (c *CozeExpress) Description() string { + return "查询快递物流信息" +} + +// Definition 返回工具定义 +func (c *CozeExpress) Definition() entitys.ToolDefinition { + return entitys.ToolDefinition{ + Type: "function", + Function: entitys.FunctionDef{ + Name: c.Name(), + Description: c.Description(), + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "express_id": map[string]interface{}{ + "type": "string", + "description": "快递单号", + }, + }, + "required": []string{"express_id"}, + }, + }, + } +} + +// Execute 执行查询 +func (c *CozeExpress) Execute(ctx context.Context, requireData *entitys.Recognize) error { + var req map[string]interface{} + if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + return fmt.Errorf("invalid express request: %w", err) + } + + if req["express_id"] == "" { + return fmt.Errorf("express_id is required") + } + + // 调用 Coze 工作流查询快递物流信息 + rsp, err := c.callWorkflow(ctx, req) + if err != nil { + return fmt.Errorf("failed to get real weather: %w", err) + } + err = c.llm.ChatStream(ctx, requireData.Ch, []api.Message{ + { + Role: "system", + Content: "你是一个快递查询助手。用户可能会提供快递单号,你需要分析快递单号,根据快递单号查询物流信息并反馈给我", + }, + { + Role: "assistant", + Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.ChatHis)), + }, + { + Role: "assistant", + Content: fmt.Sprintf("需要分析的快递单号:%s", rsp.Data), + }, + { + Role: "user", + Content: requireData.UserContent.Text, + }, + }, c.Name(), "") + if err != nil { + return fmt.Errorf("failed to get express info: %w", err) + } + + //entitys.ResText(requireData.Ch, "", rsp.Data) + + return nil +} + +// CallWorkflow 调用 Coze 工作流 +// 参数: +// - ctx: 上下文,用于控制超时和取消 +// - workflowId: 工作流 ID +// - params: 工作流参数 +// 返回: +// - interface{}: 工作流执行结果 +// - error: 错误信息 +func (c *CozeExpress) callWorkflow(ctx context.Context, params map[string]interface{}) (*coze.RunWorkflowsResp, error) { + // 准备工作流请求参数 + workflowReq := &coze.RunWorkflowsReq{ + WorkflowID: c.config.APIKey, + Parameters: params, + } + + // 调用工作流 + resp, err := c.cozeApi.Workflows.Runs.Create(ctx, workflowReq) + if err != nil { + return nil, fmt.Errorf("工作流调用失败: %w", err) + } + + // 处理工作流响应 + if resp == nil { + return nil, fmt.Errorf("工作流响应为空") + } + + // 返回工作流执行结果 + return resp, nil +} diff --git a/internal/tools/konwledge_base.go b/internal/tools/public/konwledge_base.go similarity index 92% rename from internal/tools/konwledge_base.go rename to internal/tools/public/konwledge_base.go index fb5a7f1..48ddc1e 100644 --- a/internal/tools/konwledge_base.go +++ b/internal/tools/public/konwledge_base.go @@ -1,9 +1,10 @@ -package tools +package public import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/rec_extra" "bufio" "context" "encoding/json" @@ -58,7 +59,8 @@ func (k *KnowledgeBaseTool) Definition() entitys.ToolDefinition { } // Execute 执行知识库查询 -func (k *KnowledgeBaseTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (k *KnowledgeBaseTool) Execute(ctx context.Context, requireData *entitys.Recognize) error { + entitys.ResLoading(requireData.Ch, k.Name(), "正在为您搜索相关信息") return k.chat(requireData) @@ -90,20 +92,23 @@ func (this *KnowledgeBaseTool) msgContentParse(input string, channel chan entity } // 请求知识库聊天 -func (this *KnowledgeBaseTool) chat(requireData *entitys.RequireData) (err error) { - +func (this *KnowledgeBaseTool) chat(rec *entitys.Recognize) (err error) { + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return + } req := l_request.Request{ Method: "post", - Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + requireData.KnowledgeConf.Session, + Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + ext.KnowledgeConf.Session, Params: nil, Headers: map[string]string{ "Content-Type": "application/json", - "X-API-Key": requireData.KnowledgeConf.ApiKey, + "X-API-Key": ext.KnowledgeConf.ApiKey, }, Cookies: nil, Data: nil, Json: map[string]interface{}{ - "query": requireData.KnowledgeConf.Query, + "query": ext.KnowledgeConf.Query, }, Files: nil, Raw: "", @@ -117,7 +122,7 @@ func (this *KnowledgeBaseTool) chat(requireData *entitys.RequireData) (err error } defer rsp.Body.Close() - err = this.connectAndReadSSE(rsp, requireData.Ch) + err = this.connectAndReadSSE(rsp, rec.Ch) if err != nil { return } diff --git a/internal/tools/konwledge_base_test.go b/internal/tools/public/konwledge_base_test.go similarity index 97% rename from internal/tools/konwledge_base_test.go rename to internal/tools/public/konwledge_base_test.go index de1ab3f..0838587 100644 --- a/internal/tools/konwledge_base_test.go +++ b/internal/tools/public/konwledge_base_test.go @@ -1,4 +1,4 @@ -package tools +package public import ( "testing" diff --git a/internal/tools/normal_chat.go b/internal/tools/public/normal_chat.go similarity index 80% rename from internal/tools/normal_chat.go rename to internal/tools/public/normal_chat.go index 56010fc..cde667c 100644 --- a/internal/tools/normal_chat.go +++ b/internal/tools/public/normal_chat.go @@ -1,4 +1,4 @@ -package tools +package public import ( "ai_scheduler/internal/config" @@ -43,9 +43,9 @@ func (w *NormalChatTool) Definition() entitys.ToolDefinition { } // Execute 执行直连天下订单详情查询 -func (w *NormalChatTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (w *NormalChatTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req NormalChat - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid zltxOrderDetail request: %w", err) } if req.ChatContent == "" { @@ -53,25 +53,25 @@ func (w *NormalChatTool) Execute(ctx context.Context, requireData *entitys.Requi } // 这里可以集成真实的直连天下订单详情API - return w.chat(requireData, &req) + return w.chat(rec, &req) } // getMockZltxOrderDetail 获取模拟直连天下订单详情数据 -func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat) (err error) { +func (w *NormalChatTool) chat(rec *entitys.Recognize, chat *NormalChat) (err error) { //requireData.Ch <- entitys.Response{ // Index: w.Name(), // Content: "", // Type: entitys.ResponseStream, //} - err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{ + err = w.llm.ChatStream(context.TODO(), rec.Ch, []api.Message{ { Role: "system", Content: "你是一个聊天助手", }, { Role: "assistant", - Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)), + Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(rec.ChatHis)), }, { Role: "user", diff --git a/internal/tools/public/weather.go b/internal/tools/public/weather.go new file mode 100644 index 0000000..ca36907 --- /dev/null +++ b/internal/tools/public/weather.go @@ -0,0 +1,345 @@ +package public + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/l_request" + "log" + "strconv" + + "context" + "encoding/json" + "fmt" + "time" +) + +// WeatherTool 天气查询工具 +type WeatherTool struct { + mockData bool + config config.ToolConfig +} + +// NewWeatherTool 创建天气工具 +func NewWeatherTool(config config.ToolConfig) *WeatherTool { + return &WeatherTool{config: config} +} + +// Name 返回工具名称 +func (w *WeatherTool) Name() string { + return "get_weather" +} + +// Description 返回工具描述 +func (w *WeatherTool) Description() string { + return "获取指定城市的天气信息" +} + +// Definition 返回工具定义 +func (w *WeatherTool) Definition() entitys.ToolDefinition { + return entitys.ToolDefinition{ + Type: "function", + Function: entitys.FunctionDef{ + Name: w.Name(), + Description: w.Description(), + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{ + "type": "string", + "description": "城市名称,如:北京、上海、广州", + }, + "unit": map[string]interface{}{ + "type": "string", + "description": "温度单位,celsius(摄氏度)或fahrenheit(华氏度)", + "enum": []string{"celsius", "fahrenheit"}, + "default": "celsius", + }, + "extensions": map[string]interface{}{ + "type": "string", + "description": "扩展参数,base/all base:返回实况天气 all:返回预报天气", + "enum": []string{"base", "all"}, + "default": "base", + }, + }, + "required": []string{"city"}, + }, + }, + } +} + +// WeatherRequest 天气请求参数 +type WeatherRequest struct { + City string `json:"city"` + Extensions string `json:"extensions"` // 扩展参数,base/all base:返回实况天气 all:返回预报天气 + Unit string `json:"unit,omitempty"` +} + +// WeatherResponse 天气响应 +type WeatherResponse struct { + City string `json:"city"` + Unit string `json:"unit"` + Timestamp string `json:"timestamp"` + LiveWeather *LiveWeather `json:"live_weather,omitempty"` // 实时天气 + Forecasts []ForecastWeather `json:"forecasts,omitempty"` // 预报天气 +} + +// ForecastWeather 预报天气 +type ForecastWeather struct { + Date string `json:"date"` + Week string `json:"week"` + DayWeather string `json:"day_weather"` + NightWeather string `json:"night_weather"` + DayTemp float64 `json:"day_temp"` + NightTemp float64 `json:"night_temp"` + DayWind string `json:"day_wind"` + NightWind string `json:"night_wind"` + DayWindPower string `json:"day_wind_power"` + NightWindPower string `json:"night_wind_power"` +} + +// LiveWeather 实时天气 +type LiveWeather struct { + Temperature float64 `json:"temperature"` + Condition string `json:"condition"` + Humidity int `json:"humidity"` + WindSpeed float64 `json:"wind_speed"` + WindDirection string `json:"wind_direction"` +} + +// Execute 执行天气查询 +func (w *WeatherTool) Execute(ctx context.Context, rec *entitys.Recognize) error { + var req WeatherRequest + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { + return fmt.Errorf("invalid weather request: %w", err) + } + + if req.City == "" { + return fmt.Errorf("city is required") + } + + if req.Unit == "" { + req.Unit = "celsius" + } + + // 设置默认获取实时天气信息 + if req.Extensions == "" { + req.Extensions = "base" + } + + // 这里可以集成真实的天气API + responseMsg, err := w.getRealWeather(req) + if err != nil { + return fmt.Errorf("failed to get real weather: %w", err) + } + + // 根据 extensions 参数返回不同的天气信息 + if req.Extensions == "base" { + entitys.ResText(rec.Ch, "", fmt.Sprintf("%s实时天气:%s,温度:%.1f℃,湿度:%d%%,风速:%.1fkm/h,风向:%s", + req.City, + responseMsg.LiveWeather.Condition, + responseMsg.LiveWeather.Temperature, + responseMsg.LiveWeather.Humidity, + responseMsg.LiveWeather.WindSpeed, + responseMsg.LiveWeather.WindDirection)) + } else { + + rspStr := fmt.Sprintf("%s天气预报:\n", req.City) + for _, forecast := range responseMsg.Forecasts { + rspStr += fmt.Sprintf("%s 温度:%.1f℃/%.1f℃ 风力:%s %s\n", + forecast.Date, forecast.DayTemp, forecast.NightTemp, forecast.DayWind, forecast.NightWind) + } + + entitys.ResText(rec.Ch, "", rspStr) + } + return nil +} + +//// getMockWeather 获取模拟天气数据 +//func (w *WeatherTool) getMockWeather(city, unit string) *WeatherResponse { +// rand.Seed(time.Now().UnixNano()) +// +// // 模拟不同城市的基础温度 +// baseTemp := map[string]float64{ +// "北京": 15.0, +// "上海": 18.0, +// "广州": 25.0, +// "深圳": 26.0, +// "杭州": 17.0, +// "成都": 16.0, +// } +// +// temp := baseTemp[city] +// if temp == 0 { +// temp = 20.0 // 默认温度 +// } +// +// // 添加随机变化 +// temp += (rand.Float64() - 0.5) * 10 +// +// // 转换温度单位 +// if unit == "fahrenheit" { +// temp = temp*9/5 + 32 +// } +// +// conditions := []string{"晴朗", "多云", "阴天", "小雨", "中雨"} +// condition := conditions[rand.Intn(len(conditions))] +// +// return &WeatherResponse{ +// City: city, +// Temperature: float64(int(temp*10)) / 10, // 保留一位小数 +// Unit: unit, +// Condition: condition, +// Humidity: rand.Intn(40) + 40, // 40-80% +// WindSpeed: float64(rand.Intn(20)) + 1.0, +// Timestamp: time.Now().Format("2006-01-02 15:04:05"), +// } +//} + +// getRealWeather 调用高德天气API +func (w *WeatherTool) getRealWeather(request WeatherRequest) (*WeatherResponse, error) { + // 构建请求URL + req := l_request.Request{ + Url: w.config.BaseURL, + Headers: map[string]string{}, + Params: map[string]string{ + "city": request.City, // 城市名称 + "key": w.config.APIKey, // API密钥 + // extensions: 基础天气数据 可选值:base/all base:返回实况天气 all:返回预报天气 + "extensions": request.Extensions, // 基础天气数据 + "output": "JSON", // JSON格式返回 + }, + Method: "GET", + } + res, err := req.Send() + if err != nil { + return nil, err + } + + // 解析API响应 + var apiResp struct { + Status string `json:"status"` + Count string `json:"count"` + Info string `json:"info"` + Infocode string `json:"infocode"` + + // 预报天气信息数据 + Forecasts []struct { + City string `json:"city"` + Adcode string `json:"adcode"` + Province string `json:"province"` + Reporttime string `json:"reporttime"` + Casts []struct { + Date string `json:"date"` + Week string `json:"week"` + Dayweather string `json:"dayweather"` + Nightweather string `json:"nightweather"` + Daytemp string `json:"daytemp"` + Nighttemp string `json:"nighttemp"` + Daywind string `json:"daywind"` + Nightwind string `json:"nightwind"` + Daypower string `json:"daypower"` + Nightpower string `json:"nightpower"` + DaytempFloat string `json:"daytemp_float"` + NighttempFloat string `json:"nighttemp_float"` + } `json:"casts"` + } `json:"forecasts"` + // 实况天气信息数据 + Lives []struct { + Province string `json:"province"` + City string `json:"city"` + Adcode string `json:"adcode"` + Weather string `json:"weather"` + Temperature string `json:"temperature"` + Winddirection string `json:"winddirection"` + Windpower string `json:"windpower"` + Humidity string `json:"humidity"` + Reporttime string `json:"reporttime"` + TemperatureFloat string `json:"temperature_float"` + HumidityFloat string `json:"humidity_float"` + } `json:"lives"` + } + + log.Printf("weather API response: %s", string(res.Content)) + + if err = json.Unmarshal(res.Content, &apiResp); err != nil { + return nil, fmt.Errorf("parse weather API response failed: %w", err) + } + + // 检查API返回状态 + if apiResp.Status != "1" { + return nil, fmt.Errorf("weather API returned error: %s, info: %s", apiResp.Status, apiResp.Info) + } + + // 获取城市名称 + cityName := "" + if len(apiResp.Lives) > 0 { + cityName = apiResp.Lives[0].City + } else if len(apiResp.Forecasts) > 0 { + cityName = apiResp.Forecasts[0].City + } else { + return nil, fmt.Errorf("no weather data found") + } + + // 构建响应 + response := &WeatherResponse{ + City: cityName, + Unit: request.Unit, + Timestamp: time.Now().Format("2006-01-02 15:04:05"), + } + // 处理实时天气 + if len(apiResp.Lives) > 0 { + liveData := apiResp.Lives[0] + + // 转换温度 + temp, _ := strconv.ParseFloat(liveData.Temperature, 64) + if request.Unit == "fahrenheit" { + temp = temp*9/5 + 32 + } + + // 转换湿度和风速 + humidity, _ := strconv.Atoi(liveData.Humidity) + windSpeed, _ := strconv.ParseFloat(liveData.Windpower, 64) + + response.LiveWeather = &LiveWeather{ + Temperature: temp, + Condition: liveData.Weather, + Humidity: humidity, + WindSpeed: windSpeed, + WindDirection: liveData.Winddirection, + } + } + + // 处理预报天气 + if len(apiResp.Forecasts) > 0 && len(apiResp.Forecasts[0].Casts) > 0 { + response.Forecasts = make([]ForecastWeather, 0, len(apiResp.Forecasts[0].Casts)) + + for _, cast := range apiResp.Forecasts[0].Casts { + // 转换温度 + dayTemp, _ := strconv.ParseFloat(cast.Daytemp, 64) + nightTemp, _ := strconv.ParseFloat(cast.Nighttemp, 64) + + if request.Unit == "fahrenheit" { + dayTemp = dayTemp*9/5 + 32 + nightTemp = nightTemp*9/5 + 32 + } + + forecast := ForecastWeather{ + Date: cast.Date, + Week: cast.Week, + DayWeather: cast.Dayweather, + NightWeather: cast.Nightweather, + DayTemp: dayTemp, + NightTemp: nightTemp, + DayWind: cast.Daywind, + NightWind: cast.Nightwind, + DayWindPower: cast.Daypower, + NightWindPower: cast.Nightpower, + } + + response.Forecasts = append(response.Forecasts, forecast) + } + } + + return response, nil + +} diff --git a/internal/tools/weather.go b/internal/tools/weather.go deleted file mode 100644 index 744216f..0000000 --- a/internal/tools/weather.go +++ /dev/null @@ -1,139 +0,0 @@ -package tools - -import ( - "ai_scheduler/internal/entitys" - - "context" - "encoding/json" - "fmt" - "math/rand" - "time" -) - -// WeatherTool 天气查询工具 -type WeatherTool struct { - mockData bool -} - -// NewWeatherTool 创建天气工具 -func NewWeatherTool() *WeatherTool { - return &WeatherTool{} -} - -// Name 返回工具名称 -func (w *WeatherTool) Name() string { - return "get_weather" -} - -// Description 返回工具描述 -func (w *WeatherTool) Description() string { - return "获取指定城市的天气信息" -} - -// Definition 返回工具定义 -func (w *WeatherTool) Definition() entitys.ToolDefinition { - return entitys.ToolDefinition{ - Type: "function", - Function: entitys.FunctionDef{ - Name: w.Name(), - Description: w.Description(), - Parameters: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "city": map[string]interface{}{ - "type": "string", - "description": "城市名称,如:北京、上海、广州", - }, - "unit": map[string]interface{}{ - "type": "string", - "description": "温度单位,celsius(摄氏度)或fahrenheit(华氏度)", - "enum": []string{"celsius", "fahrenheit"}, - "default": "celsius", - }, - }, - "required": []string{"city"}, - }, - }, - } -} - -// WeatherRequest 天气请求参数 -type WeatherRequest struct { - City string `json:"city"` - Unit string `json:"unit,omitempty"` -} - -// WeatherResponse 天气响应 -type WeatherResponse struct { - City string `json:"city"` - Temperature float64 `json:"temperature"` - Unit string `json:"unit"` - Condition string `json:"condition"` - Humidity int `json:"humidity"` - WindSpeed float64 `json:"wind_speed"` - Timestamp string `json:"timestamp"` -} - -// Execute 执行天气查询 -func (w *WeatherTool) Execute(ctx context.Context, args json.RawMessage) (interface{}, error) { - var req WeatherRequest - if err := json.Unmarshal(args, &req); err != nil { - return nil, fmt.Errorf("invalid weather request: %w", err) - } - - if req.City == "" { - return nil, fmt.Errorf("city is required") - } - - if req.Unit == "" { - req.Unit = "celsius" - } - - if w.mockData { - return w.getMockWeather(req.City, req.Unit), nil - } - - // 这里可以集成真实的天气API - return w.getMockWeather(req.City, req.Unit), nil -} - -// getMockWeather 获取模拟天气数据 -func (w *WeatherTool) getMockWeather(city, unit string) *WeatherResponse { - rand.Seed(time.Now().UnixNano()) - - // 模拟不同城市的基础温度 - baseTemp := map[string]float64{ - "北京": 15.0, - "上海": 18.0, - "广州": 25.0, - "深圳": 26.0, - "杭州": 17.0, - "成都": 16.0, - } - - temp := baseTemp[city] - if temp == 0 { - temp = 20.0 // 默认温度 - } - - // 添加随机变化 - temp += (rand.Float64() - 0.5) * 10 - - // 转换温度单位 - if unit == "fahrenheit" { - temp = temp*9/5 + 32 - } - - conditions := []string{"晴朗", "多云", "阴天", "小雨", "中雨"} - condition := conditions[rand.Intn(len(conditions))] - - return &WeatherResponse{ - City: city, - Temperature: float64(int(temp*10)) / 10, // 保留一位小数 - Unit: unit, - Condition: condition, - Humidity: rand.Intn(40) + 40, // 40-80% - WindSpeed: float64(rand.Intn(20)) + 1.0, - Timestamp: time.Now().Format("2006-01-02 15:04:05"), - } -} diff --git a/internal/tools/zltx/order_after_reseller.go b/internal/tools/zltx/order_after_reseller.go index 71e0629..fcb71e0 100644 --- a/internal/tools/zltx/order_after_reseller.go +++ b/internal/tools/zltx/order_after_reseller.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/rec_extra" "ai_scheduler/internal/pkg/util" "context" "encoding/json" @@ -104,9 +105,9 @@ type OrderAfterSaleResellerApiExtItem struct { SerialCreateTime int `json:"createTime"` // 流水创建时间 } -func (t *OrderAfterSaleResellerTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (t *OrderAfterSaleResellerTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req OrderAfterSaleResellerRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("解析参数失败,请重试或联系管理员") } if len(req.OrderNumber) == 0 && len(req.Account) == 0 { @@ -116,18 +117,22 @@ func (t *OrderAfterSaleResellerTool) Execute(ctx context.Context, requireData *e if req.SerialCreateTime != "" { _, err := time.ParseInLocation(time.DateTime, req.SerialCreateTime, time.Local) if err != nil { - entitys.ResLog(requireData.Ch, t.Name(), "时间格式不匹配,已置为空") + entitys.ResLog(rec.Ch, t.Name(), "时间格式不匹配,已置为空") req.SerialCreateTime = "" } } - entitys.ResLog(requireData.Ch, t.Name(), "正在拉取售后订单信息") + entitys.ResLog(rec.Ch, t.Name(), "正在拉取售后订单信息") - return t.checkOrderAfterSaleReseller(req, requireData) + return t.checkOrderAfterSaleReseller(req, rec) } -func (t *OrderAfterSaleResellerTool) checkOrderAfterSaleReseller(toolReq OrderAfterSaleResellerRequest, requireData *entitys.RequireData) error { +func (t *OrderAfterSaleResellerTool) checkOrderAfterSaleReseller(toolReq OrderAfterSaleResellerRequest, rec *entitys.Recognize) error { var serialStartTime, serialEndTime int64 + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return err + } if toolReq.SerialCreateTime != "" { // 流水创建时间上下浮动10min serialCreateTime, err := time.ParseInLocation(time.DateTime, toolReq.SerialCreateTime, time.Local) @@ -144,17 +149,16 @@ func (t *OrderAfterSaleResellerTool) checkOrderAfterSaleReseller(toolReq OrderAf // 账号数量超过10直接截断 if len(toolReq.Account) > 10 { - entitys.ResLog(requireData.Ch, t.Name(), "账号数量超过10已被截断") + entitys.ResLog(rec.Ch, t.Name(), "账号数量超过10已被截断") toolReq.Account = toolReq.Account[:10] } headers := map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), } // 最终输出 var orderList []*OrderAfterSaleResellerData - var err error // 多订单号 if len(toolReq.OrderNumber) > 0 { @@ -217,8 +221,8 @@ func (t *OrderAfterSaleResellerTool) checkOrderAfterSaleReseller(toolReq OrderAf return err } - entitys.ResLog(requireData.Ch, t.Name(), "售后订单信息拉取完成") - entitys.ResJson(requireData.Ch, t.Name(), string(jsonByte)) + entitys.ResLog(rec.Ch, t.Name(), "售后订单信息拉取完成") + entitys.ResJson(rec.Ch, t.Name(), string(jsonByte)) return nil } diff --git a/internal/tools/zltx/order_after_reseller_batch.go b/internal/tools/zltx/order_after_reseller_batch.go index e12e602..664d3d5 100644 --- a/internal/tools/zltx/order_after_reseller_batch.go +++ b/internal/tools/zltx/order_after_reseller_batch.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/rec_extra" "ai_scheduler/internal/pkg/util" "context" "encoding/json" @@ -100,25 +101,29 @@ type OrderAfterSaleResellerBatchApiExtItem struct { SerialCreateTime int `json:"createTime"` // 流水创建时间 } -func (t *OrderAfterSaleResellerBatchTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (t *OrderAfterSaleResellerBatchTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req OrderAfterSaleResellerBatchRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("解析参数失败,请重试或联系管理员") } if len(req.OrderNumber) == 0 { return fmt.Errorf("批充订单号不能为空") } - entitys.ResLog(requireData.Ch, t.Name(), "正在拉取售后订单信息") + entitys.ResLog(rec.Ch, t.Name(), "正在拉取售后订单信息") - return t.checkOrderAfterSaleResellerBatch(req, requireData) + return t.checkOrderAfterSaleResellerBatch(req, rec) } -func (t *OrderAfterSaleResellerBatchTool) checkOrderAfterSaleResellerBatch(toolReq OrderAfterSaleResellerBatchRequest, requireData *entitys.RequireData) error { +func (t *OrderAfterSaleResellerBatchTool) checkOrderAfterSaleResellerBatch(toolReq OrderAfterSaleResellerBatchRequest, rec *entitys.Recognize) error { + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return err + } req := l_request.Request{ Url: t.config.BaseURL, Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), }, Method: "POST", Json: map[string]any{ @@ -200,7 +205,7 @@ func (t *OrderAfterSaleResellerBatchTool) checkOrderAfterSaleResellerBatch(toolR return err } - entitys.ResLog(requireData.Ch, t.Name(), "售后订单信息拉取完成") - entitys.ResJson(requireData.Ch, t.Name(), string(jsonByte)) + entitys.ResLog(rec.Ch, t.Name(), "售后订单信息拉取完成") + entitys.ResJson(rec.Ch, t.Name(), string(jsonByte)) return nil } diff --git a/internal/tools/zltx/order_after_supplier.go b/internal/tools/zltx/order_after_supplier.go index 4baf7f9..9e9ad82 100644 --- a/internal/tools/zltx/order_after_supplier.go +++ b/internal/tools/zltx/order_after_supplier.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/rec_extra" "ai_scheduler/internal/pkg/util" "context" "encoding/json" @@ -98,9 +99,9 @@ type OrderAfterSaleSupplierApiExtItem struct { SerialCreateTime int `json:"createTime"` // 流水创建时间 } -func (t *OrderAfterSaleSupplierTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (t *OrderAfterSaleSupplierTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req OrderAfterSaleSupplierRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("解析参数失败,请重试或联系管理员") } if len(req.SerialNumber) == 0 && len(req.Account) == 0 { @@ -110,18 +111,24 @@ func (t *OrderAfterSaleSupplierTool) Execute(ctx context.Context, requireData *e if req.SerialCreateTime != "" { _, err := time.ParseInLocation(time.DateTime, req.SerialCreateTime, time.Local) if err != nil { - entitys.ResLog(requireData.Ch, t.Name(), "时间格式不匹配,已置为空") + entitys.ResLog(rec.Ch, t.Name(), "时间格式不匹配,已置为空") req.SerialCreateTime = "" } } - entitys.ResLog(requireData.Ch, t.Name(), "正在拉取售后订单信息") + entitys.ResLog(rec.Ch, t.Name(), "正在拉取售后订单信息") - return t.checkOrderAfterSaleSupplier(req, requireData) + return t.checkOrderAfterSaleSupplier(req, rec) } -func (t *OrderAfterSaleSupplierTool) checkOrderAfterSaleSupplier(toolReq OrderAfterSaleSupplierRequest, requireData *entitys.RequireData) error { +func (t *OrderAfterSaleSupplierTool) checkOrderAfterSaleSupplier(toolReq OrderAfterSaleSupplierRequest, rec *entitys.Recognize) error { + var serialStartTime, serialEndTime int64 + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return err + } + if toolReq.SerialCreateTime != "" { // 流水创建时间上下浮动10min serialCreateTime, err := time.ParseInLocation(time.DateTime, toolReq.SerialCreateTime, time.Local) @@ -138,17 +145,16 @@ func (t *OrderAfterSaleSupplierTool) checkOrderAfterSaleSupplier(toolReq OrderAf // 账号数量超过10直接截断 if len(toolReq.Account) > 10 { - entitys.ResLog(requireData.Ch, t.Name(), "账号数量超过10已被截断") + entitys.ResLog(rec.Ch, t.Name(), "账号数量超过10已被截断") toolReq.Account = toolReq.Account[:10] } headers := map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), } // 最终输出 var orderList []*OrderAfterSaleSupplierData - var err error // 多流水号 if len(toolReq.SerialNumber) > 0 { @@ -210,8 +216,8 @@ func (t *OrderAfterSaleSupplierTool) checkOrderAfterSaleSupplier(toolReq OrderAf return err } - entitys.ResLog(requireData.Ch, t.Name(), "售后订单信息拉取完成") - entitys.ResJson(requireData.Ch, t.Name(), string(jsonByte)) + entitys.ResLog(rec.Ch, t.Name(), "售后订单信息拉取完成") + entitys.ResJson(rec.Ch, t.Name(), string(jsonByte)) return nil } diff --git a/internal/tools/zltx_after_direct.go b/internal/tools/zltx/zltx_after_direct.go similarity index 99% rename from internal/tools/zltx_after_direct.go rename to internal/tools/zltx/zltx_after_direct.go index 5f1e20e..01a6282 100644 --- a/internal/tools/zltx_after_direct.go +++ b/internal/tools/zltx/zltx_after_direct.go @@ -1,4 +1,4 @@ -package tools +package zltx import ( "ai_scheduler/internal/config" diff --git a/internal/tools/zltx_after_pre.go b/internal/tools/zltx/zltx_after_pre.go similarity index 99% rename from internal/tools/zltx_after_pre.go rename to internal/tools/zltx/zltx_after_pre.go index b44812b..5bbf809 100644 --- a/internal/tools/zltx_after_pre.go +++ b/internal/tools/zltx/zltx_after_pre.go @@ -1,4 +1,4 @@ -package tools +package zltx import ( "ai_scheduler/internal/config" diff --git a/internal/tools/zltx_order_detail.go b/internal/tools/zltx/zltx_order_detail.go similarity index 82% rename from internal/tools/zltx_order_detail.go rename to internal/tools/zltx/zltx_order_detail.go index f253f59..52d6bbb 100644 --- a/internal/tools/zltx_order_detail.go +++ b/internal/tools/zltx/zltx_order_detail.go @@ -1,9 +1,10 @@ -package tools +package zltx import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/rec_extra" "ai_scheduler/internal/pkg/utils_ollama" "context" "encoding/json" @@ -81,9 +82,9 @@ type ZltxOrderDetailData struct { } // Execute 执行直连天下订单详情查询 -func (w *ZltxOrderDetailTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (w *ZltxOrderDetailTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req ZltxOrderDetailRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid zltxOrderDetail request: %w", err) } if req.OrderNumber == "" { @@ -91,16 +92,20 @@ func (w *ZltxOrderDetailTool) Execute(ctx context.Context, requireData *entitys. } // 这里可以集成真实的直连天下订单详情API - return w.getZltxOrderDetail(requireData, req.OrderNumber) + return w.getZltxOrderDetail(rec, req.OrderNumber) } // getMockZltxOrderDetail 获取模拟直连天下订单详情数据 -func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireData, number string) (err error) { +func (w *ZltxOrderDetailTool) getZltxOrderDetail(rec *entitys.Recognize, number string) (err error) { + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return + } //查询订单详情 req := l_request.Request{ Url: fmt.Sprintf(w.config.BaseURL, number), Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), }, Method: "GET", } @@ -121,15 +126,15 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat if err = json.Unmarshal(res.Content, &resData); err != nil { return } - entitys.ResJson(requireData.Ch, w.Name(), res.Text) + entitys.ResJson(rec.Ch, w.Name(), res.Text) if resData.Data.Direct != nil { - entitys.ResLoading(requireData.Ch, w.Name(), "正在分析订单日志") + entitys.ResLoading(rec.Ch, w.Name(), "正在分析订单日志") req = l_request.Request{ Url: fmt.Sprintf(w.config.AddURL, resData.Data.Direct["orderOrderNumber"].(string), resData.Data.Direct["serialNumber"].(string)), Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), }, Method: "GET", } @@ -149,14 +154,14 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat return fmt.Errorf("订单日志解析失败:%s", err) } - err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{ + err = w.llm.ChatStream(context.TODO(), rec.Ch, []api.Message{ { Role: "system", Content: "你是一个订单日志助手。用户可能会提供订单日志,你需要分析订单日志,失败订单->分析失败原因,成功订单->找出整个日志的 Base64 编码的 JSON 数据的内容进行转换并反馈给我", }, { Role: "assistant", - Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)), + Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(rec.ChatHis)), }, { Role: "assistant", @@ -164,7 +169,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat }, { Role: "user", - Content: requireData.Req.Text, + Content: rec.UserContent.Text, }, }, w.Name(), "") if err != nil { @@ -172,7 +177,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat } } if resData.Data.Direct == nil { - entitys.ResText(requireData.Ch, w.Name(), "该订单无充值流水记录,不需要进行订单错误分析,建议检查订单收单模式以及扣款方式等原因😘") + entitys.ResText(rec.Ch, w.Name(), "该订单无充值流水记录,不需要进行订单错误分析,建议检查订单收单模式以及扣款方式等原因😘") } return } diff --git a/internal/tools/zltx_order_direct_log.go b/internal/tools/zltx/zltx_order_direct_log.go similarity index 83% rename from internal/tools/zltx_order_direct_log.go rename to internal/tools/zltx/zltx_order_direct_log.go index bf82408..188a954 100644 --- a/internal/tools/zltx_order_direct_log.go +++ b/internal/tools/zltx/zltx_order_direct_log.go @@ -1,8 +1,9 @@ -package tools +package zltx import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/rec_extra" "context" "encoding/json" "fmt" @@ -67,25 +68,28 @@ type ZltxOrderDirectLogData struct { Data map[string]interface{} `json:"data"` } -func (t *ZltxOrderLogTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (t *ZltxOrderLogTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req ZltxOrderLogRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid zltxOrderLog request: %w", err) } if req.OrderNumber == "" || req.SerialNumber == "" { return fmt.Errorf("orderNumber and serialNumber is required") } - return t.getZltxOrderLog(req.OrderNumber, req.SerialNumber, requireData) + return t.getZltxOrderLog(req.OrderNumber, req.SerialNumber, rec) } -func (t *ZltxOrderLogTool) getZltxOrderLog(orderNumber, serialNumber string, requireData *entitys.RequireData) (err error) { +func (t *ZltxOrderLogTool) getZltxOrderLog(orderNumber, serialNumber string, rec *entitys.Recognize) (err error) { //查询订单详情 - + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return + } url := fmt.Sprintf("%s%s/%s", t.config.BaseURL, orderNumber, serialNumber) req := l_request.Request{ Url: url, Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), }, Method: "GET", } @@ -100,7 +104,7 @@ func (t *ZltxOrderLogTool) getZltxOrderLog(orderNumber, serialNumber string, req if err = json.Unmarshal(res.Content, &resData); err != nil { return } - entitys.ResJson(requireData.Ch, t.Name(), res.Text) + entitys.ResJson(rec.Ch, t.Name(), res.Text) return } diff --git a/internal/tools/zltx_product.go b/internal/tools/zltx/zltx_product.go similarity index 94% rename from internal/tools/zltx_product.go rename to internal/tools/zltx/zltx_product.go index 61b8bb1..0c63d84 100644 --- a/internal/tools/zltx_product.go +++ b/internal/tools/zltx/zltx_product.go @@ -1,8 +1,9 @@ -package tools +package zltx import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/rec_extra" "context" "encoding/json" "fmt" @@ -53,12 +54,12 @@ type ZltxProductRequest struct { Name string `json:"name"` } -func (z ZltxProductTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (z ZltxProductTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req ZltxProductRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid zltxProduct request: %w", err) } - return z.getZltxProduct(&req, requireData) + return z.getZltxProduct(&req, rec) } type ZltxProductResponse struct { @@ -133,8 +134,11 @@ type ZltxProductData struct { PlatformProductList interface{} `json:"platform_product_list"` } -func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *entitys.RequireData) error { - +func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, rec *entitys.Recognize) error { + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return err + } var Url string var params map[string]string if body.Id != "" { @@ -153,7 +157,7 @@ func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *e //根据商品ID或名称走不同的接口查询 Url: Url, Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), }, Params: params, Method: "GET", @@ -185,7 +189,7 @@ func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *e for i := range resp.Data.List { // 调用 平台商品列表 if resp.Data.List[i].AuthProductIds != "" { - platformProductList := z.ExecutePlatformProductList(requireData.Auth, resp.Data.List[i].AuthProductIds, resp.Data.List[i].OfficialProductID) + platformProductList := z.ExecutePlatformProductList(ext.Auth, resp.Data.List[i].AuthProductIds, resp.Data.List[i].OfficialProductID) resp.Data.List[i].PlatformProductList = platformProductList } } @@ -194,7 +198,7 @@ func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *e if err != nil { return err } - entitys.ResJson(requireData.Ch, z.Name(), string(marshal)) + entitys.ResJson(rec.Ch, z.Name(), string(marshal)) return nil } diff --git a/internal/tools/zltx_statistics.go b/internal/tools/zltx/zltx_statistics.go similarity index 86% rename from internal/tools/zltx_statistics.go rename to internal/tools/zltx/zltx_statistics.go index f53d4b1..5d71a9b 100644 --- a/internal/tools/zltx_statistics.go +++ b/internal/tools/zltx/zltx_statistics.go @@ -1,8 +1,9 @@ -package tools +package zltx import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/rec_extra" "context" "encoding/json" "fmt" @@ -47,15 +48,15 @@ type ZltxOrderStatisticsRequest struct { Number string `json:"number"` } -func (z ZltxOrderStatisticsTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (z ZltxOrderStatisticsTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req ZltxOrderStatisticsRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return err } if req.Number == "" { return fmt.Errorf("number is required") } - return z.getZltxOrderStatistics(req.Number, requireData) + return z.getZltxOrderStatistics(req.Number, rec) } type ZltxOrderStatisticsResponse struct { @@ -75,14 +76,18 @@ type ZltxOrderStatisticsData struct { Total int `json:"total"` } -func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(number string, requireData *entitys.RequireData) error { +func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(number string, rec *entitys.Recognize) error { + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return err + } //查询订单详情 url := fmt.Sprintf("%s%s", z.config.BaseURL, number) req := l_request.Request{ Url: url, Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), }, Method: "GET", } @@ -108,7 +113,7 @@ func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(number string, requireDa if err != nil { return err } - entitys.ResJson(requireData.Ch, z.Name(), string(jsonByte)) + entitys.ResJson(rec.Ch, z.Name(), string(jsonByte)) return nil } diff --git a/internal/tools_bot/dtalk_bot.go b/internal/tools_bot/dtalk_bot.go deleted file mode 100644 index 0e9525a..0000000 --- a/internal/tools_bot/dtalk_bot.go +++ /dev/null @@ -1,38 +0,0 @@ -package tools_bot - -import ( - "ai_scheduler/internal/config" - "ai_scheduler/internal/data/constants" - errors "ai_scheduler/internal/data/error" - "ai_scheduler/internal/data/impl" - "ai_scheduler/internal/entitys" - "ai_scheduler/internal/pkg/utils_ollama" - "context" - - "github.com/gofiber/fiber/v2/log" -) - -type BotTool struct { - config *config.Config - llm *utils_ollama.Client - sessionImpl *impl.SessionImpl - taskMap map[string]string // task_id -> session_id - // zltxOrderAfterSaleTool tools.ZltxOrderAfterSaleTool -} - -// NewBotTool 创建直连天下订单详情工具 -func NewBotTool(config *config.Config, llm *utils_ollama.Client, sessionImpl *impl.SessionImpl) *BotTool { - return &BotTool{config: config, llm: llm, sessionImpl: sessionImpl, taskMap: make(map[string]string)} -} - -// Execute 执行直连天下订单详情查询 -func (w *BotTool) Execute(ctx context.Context, toolName string, requireData *entitys.RequireData) (err error) { - switch toolName { - case constants.BotToolsBugOptimizationSubmit: - err = w.BugOptimizationSubmit(ctx, requireData) - default: - log.Errorf("未知的工具类型:%s", toolName) - err = errors.ParamErr("未知的工具类型:%s", toolName) - } - return -} diff --git a/internal/tools_bot/provider_set.go b/internal/tools_bot/provider_set.go deleted file mode 100644 index 7bcff12..0000000 --- a/internal/tools_bot/provider_set.go +++ /dev/null @@ -1,9 +0,0 @@ -package tools_bot - -import ( - "github.com/google/wire" -) - -var ProviderSetBotTools = wire.NewSet( - NewBotTool, -) diff --git a/utils/rds.go b/utils/rds.go index 57b89f1..874eb3e 100644 --- a/utils/rds.go +++ b/utils/rds.go @@ -13,10 +13,10 @@ type Rdb struct { var rdb *Rdb -func NewRdb(c *config.Redis) *Rdb { +func NewRdb(c *config.Config) *Rdb { if rdb == nil { //构建 redis - rdbBuild := buildRdb(c) + rdbBuild := buildRdb(&c.Redis) //退出时清理资源 rdb = &Rdb{Rdb: rdbBuild} }