Merge branch 'v3' into feature/fzy/refine
This commit is contained in:
commit
5e1fab132a
32
Dockerfile
32
Dockerfile
|
|
@ -1,3 +1,28 @@
|
|||
## 使用官方Go镜像作为构建环境
|
||||
FROM golang:1.24.1-alpine AS builder
|
||||
|
||||
# 设置工作目录
|
||||
WORKDIR /app
|
||||
|
||||
# 使用国内镜像源加速APK包下载
|
||||
RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.aliyun.com/g' /etc/apk/repositories
|
||||
|
||||
# 使用国内镜像源加速依赖下载
|
||||
ENV GOPROXY=https://goproxy.cn,direct
|
||||
|
||||
# 复制项目源码
|
||||
COPY . .
|
||||
|
||||
# 复制go模块依赖文件
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod tidy
|
||||
|
||||
RUN go install github.com/google/wire/cmd/wire@latest
|
||||
RUN wire ./cmd/server
|
||||
|
||||
# 编译Go应用程序,生成静态链接的二进制文件
|
||||
RUN go build -ldflags="-s -w" -o server ./cmd/server
|
||||
|
||||
# 创建最终镜像,用于运行编译后的Go程序
|
||||
FROM alpine
|
||||
|
||||
|
|
@ -12,10 +37,13 @@ RUN echo 'http://mirrors.ustc.edu.cn/alpine/v3.5/main' > /etc/apk/repositories \
|
|||
WORKDIR /app
|
||||
|
||||
# 将编译好的二进制文件从构建阶段复制到运行阶段
|
||||
COPY ./ /app
|
||||
COPY --from=builder /app/server ./server
|
||||
# 复制配置文件夹
|
||||
COPY --from=builder /app/config ./config
|
||||
|
||||
|
||||
ENV TZ=Asia/Shanghai
|
||||
# 设置容器启动时运行的命令
|
||||
|
||||
|
||||
CMD ["./bin/server"]
|
||||
CMD ["./server"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,87 @@
|
|||
# 服务器配置
|
||||
server:
|
||||
port: 8090
|
||||
host: "0.0.0.0"
|
||||
|
||||
ollama:
|
||||
base_url: "http://192.168.6.109:11434"
|
||||
model: "qwen3-coder:480b-cloud"
|
||||
generate_model: "qwen3-coder:480b-cloud"
|
||||
vl_model: "qwen2.5vl:7b"
|
||||
timeout: "120s"
|
||||
level: "info"
|
||||
format: "json"
|
||||
|
||||
vllm:
|
||||
base_url: "http://117.175.169.61:16001/v1"
|
||||
vl_model: "models/Qwen2.5-VL-3B-Instruct-AWQ"
|
||||
timeout: "120s"
|
||||
level: "info"
|
||||
|
||||
|
||||
sys:
|
||||
session_len: 6
|
||||
channel_pool_len: 100
|
||||
channel_pool_size: 32
|
||||
llm_pool_len: 5
|
||||
heartbeat_interval: 30
|
||||
redis:
|
||||
host: 47.97.27.195:6379
|
||||
type: node
|
||||
pass: lansexiongdi@666
|
||||
key: report-api-test
|
||||
pollSize: 5 #连接池大小,不配置,或配置为0表示不启用连接池
|
||||
minIdleConns: 2 #最小空闲连接数
|
||||
maxIdleTime: 30 #每个连接最大空闲时间,如果超过了这个时间会被关闭
|
||||
tls: 30
|
||||
db:
|
||||
db:
|
||||
driver: mysql
|
||||
source: root:SD###sdf323r343@tcp(121.199.38.107:3306)/sys_ai_test?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai
|
||||
|
||||
tools:
|
||||
zltxOrderDetail:
|
||||
enabled: true
|
||||
base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/direct/ai/%s"
|
||||
add_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/direct/log/%s/%s"
|
||||
api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU4MDkxOTU4LCJuYmYiOjE3NTgwOTAxNTgsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.Bjsx9f8yfcrV9EWxb0n6POwnXVOq9XPRD78JFZnnf1_VAVMN78W4W570SZL27PWuDnkD7E4oUg6RzeZwZgl7BZrNpNr-a-QpNC5qCptqrqXeNfVStmX7pxWA8GqnzI8ybkZgbhQ58Gje7DzdJtBq_8zte_LDaYhTYXdIc5EAG0AbCzAk22nPTl47nkMeHtmisXQVLEsdibl1hW3ViFJlXwfXvUrOENItmL1_mRYkggUB0MaTu2nHJOYM6PaOVGLHx-74eepnmK2rm6konFEb6ed-Ukc6gVR-nM9yWZaYLYNGNKJLwZoCX3tRuerq74n4kzQgWmUEJeaVI1yIGSw1zw"
|
||||
zltxProduct:
|
||||
enabled: true
|
||||
base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/oursProduct"
|
||||
add_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/platformProduct/getProductsByOfficialProductId"
|
||||
api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU2MTgyNTM1LCJuYmYiOjE3NTYxODA3MzUsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.N1xv1PYbcO8_jR5adaczc16YzGsr4z101gwEZdulkRaREBJNYTOnFrvRxTFx3RJTooXsqTqroE1MR84v_1WPX6BS6kKonA-kC1Jgot6yrt5rFWhGNGb2Cpr9rKIFCCQYmiGd3AUgDazEeaQ0_sodv3E-EXg9VfE1SX8nMcck9Yjnc8NCy7RTWaBIaSeOdZcEl-JfCD0S6GSx3oErp_hk-U9FKGwf60wAuDGTY1R0BP4BYpcEqS-C2LSnsSGyURi54Cuk5xH8r1WuF0Dm5bwAj5d7Hvs77-N_sUF-C5ONqyZJRAEhYLgcmN9RX_WQZfizdQJxizlTczdpzYfy-v-1eQ"
|
||||
zltxOrderStatistics:
|
||||
base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/direct/ai/search/"
|
||||
enabled: true
|
||||
api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU2MTgyNTM1LCJuYmYiOjE3NTYxODA3MzUsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.N1xv1PYbcO8_jR5adaczc16YzGsr4z101gwEZdulkRaREBJNYTOnFrvRxTFx3RJTooXsqTqroE1MR84v_1WPX6BS6kKonA-kC1Jgot6yrt5rFWhGNGb2Cpr9rKIFCCQYmiGd3AUgDazEeaQ0_sodv3E-EXg9VfE1SX8nMcck9Yjnc8NCy7RTWaBIaSeOdZcEl-JfCD0S6GSx3oErp_hk-U9FKGwf60wAuDGTY1R0BP4BYpcEqS-C2LSnsSGyURi54Cuk5xH8r1WuF0Dm5bwAj5d7Hvs77-N_sUF-C5ONqyZJRAEhYLgcmN9RX_WQZfizdQJxizlTczdpzYfy-v-1eQ"
|
||||
knowledge:
|
||||
base_url: "http://117.175.169.61:10000"
|
||||
enabled: true
|
||||
DingTalkBot:
|
||||
enabled: true
|
||||
api_key: "dingsbbntrkeiyazcfdg"
|
||||
api_secret: "ObqxwyR20r9rVNhju0sCPQyQA98_FZSc32W4vgxnGFH_b02HZr1BPCJsOAF816nu"
|
||||
zltxOrderAfterSaleSupplier:
|
||||
enabled: true
|
||||
base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/afterSales/directs"
|
||||
zltxOrderAfterSaleReseller:
|
||||
enabled: true
|
||||
base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/afterSales/reseller_pre_ai"
|
||||
zltxOrderAfterSaleResellerBatch:
|
||||
enabled: true
|
||||
base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/afterSales/reseller_pre_ai"
|
||||
|
||||
|
||||
|
||||
default_prompt:
|
||||
img_recognize:
|
||||
system_prompt:
|
||||
'你是一个具备图像理解与用户意图分析能力的智能助手。当用户提供一张图片时,请完成以下任务:
|
||||
1. 关键信息提取:
|
||||
提取出图片中对用户可能有用的关键信息(例如金额、日期、标题、编号、联系信息、商品名称等)。
|
||||
若图片为文档类(如合同、发票、收据),请结构化输出关键字段(如客户名称、金额、开票日期等)。
|
||||
'
|
||||
user_prompt: '识别图片内容'
|
||||
# 权限配置
|
||||
permissionConfig:
|
||||
permission_url: "http://api.test.user.1688sup.cn:8001/v1/menu/myCodes?systemId="
|
||||
|
|
@ -3,15 +3,21 @@ server:
|
|||
port: 8090
|
||||
host: "0.0.0.0"
|
||||
|
||||
|
||||
ollama:
|
||||
base_url: "http://127.0.0.1:11434"
|
||||
model: "qwen3-coder:480b-cloud"
|
||||
generate_model: "qwen3-coder:480b-cloud"
|
||||
vl_model: "qwen2.5vl:7b"
|
||||
vl_model: "gemini-3-pro-preview"
|
||||
timeout: "120s"
|
||||
level: "info"
|
||||
format: "json"
|
||||
|
||||
vllm:
|
||||
base_url: "http://host.docker.internal:8001/v1"
|
||||
vl_model: "models/Qwen2.5-VL-3B-Instruct-AWQ"
|
||||
timeout: "120s"
|
||||
level: "info"
|
||||
|
||||
|
||||
sys:
|
||||
|
|
@ -19,6 +25,7 @@ sys:
|
|||
channel_pool_len: 100
|
||||
channel_pool_size: 32
|
||||
llm_pool_len: 5
|
||||
heartbeat_interval: 30
|
||||
redis:
|
||||
host: 47.97.27.195:6379
|
||||
type: node
|
||||
|
|
|
|||
16
deploy.sh
16
deploy.sh
|
|
@ -1,9 +1,9 @@
|
|||
export GO111MODULE=on
|
||||
export GOPROXY=https://goproxy.cn,direct
|
||||
export GOPATH=/root/go
|
||||
export GOCACHE=/root/.cache/go-build
|
||||
#export GO111MODULE=on
|
||||
#export GOPROXY=https://goproxy.cn,direct
|
||||
#export GOPATH=/root/go
|
||||
#export GOCACHE=/root/.cache/go-build
|
||||
export CONTAINER_NAME=ai_scheduler
|
||||
export CGO_ENABLED='0'
|
||||
#export CGO_ENABLED='0'
|
||||
|
||||
|
||||
MODE="$1"
|
||||
|
|
@ -22,8 +22,8 @@ fi
|
|||
git fetch origin
|
||||
git checkout "$BRANCH"
|
||||
git pull origin "$BRANCH"
|
||||
go mod tidy
|
||||
make build
|
||||
#go mod tidy
|
||||
#make build
|
||||
docker build -t ${CONTAINER_NAME} .
|
||||
docker stop ${CONTAINER_NAME}
|
||||
docker rm -f ${CONTAINER_NAME}
|
||||
|
|
@ -33,6 +33,6 @@ docker run -itd \
|
|||
-e "OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-http://host.docker.internal:11434}" \
|
||||
-e "MODE=${MODE}" \
|
||||
-p 8090:8090 \
|
||||
"${CONTAINER_NAME}" ./bin/server --config "./${CONFIG_FILE}"
|
||||
"${CONTAINER_NAME}" ./server --config "./${CONFIG_FILE}"
|
||||
|
||||
docker logs -f ${CONTAINER_NAME}
|
||||
13
go.mod
13
go.mod
|
|
@ -2,21 +2,21 @@ module ai_scheduler
|
|||
|
||||
go 1.24.0
|
||||
|
||||
toolchain go1.24.7
|
||||
|
||||
require (
|
||||
gitea.cdlsxd.cn/self-tools/l_request v1.0.8
|
||||
github.com/alibabacloud-go/darabonba-openapi/v2 v2.0.12
|
||||
github.com/alibabacloud-go/dingtalk v1.6.96
|
||||
github.com/alibabacloud-go/tea v1.2.2
|
||||
github.com/alibabacloud-go/tea-utils/v2 v2.0.6
|
||||
github.com/cloudwego/eino v0.7.6
|
||||
github.com/cloudwego/eino-ext/components/model/ollama v0.1.6
|
||||
github.com/cloudwego/eino v0.7.7
|
||||
github.com/cloudwego/eino-ext/components/model/openai v0.1.5
|
||||
github.com/emirpasic/gods v1.18.1
|
||||
github.com/faabiosr/cachego v0.26.0
|
||||
github.com/fastwego/dingding v1.0.0-beta.4
|
||||
github.com/go-kratos/kratos/v2 v2.9.1
|
||||
github.com/go-playground/locales v0.14.1
|
||||
github.com/go-playground/universal-translator v0.18.1
|
||||
github.com/go-playground/validator/v10 v10.20.0
|
||||
github.com/gofiber/fiber/v2 v2.52.9
|
||||
github.com/gofiber/websocket/v2 v2.2.1
|
||||
github.com/google/uuid v1.6.0
|
||||
|
|
@ -53,10 +53,10 @@ require (
|
|||
github.com/dlclark/regexp2 v1.11.4 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/eino-contrib/jsonschema v1.0.3 // indirect
|
||||
github.com/eino-contrib/ollama v0.1.0 // indirect
|
||||
github.com/evanphx/json-patch v0.5.2 // indirect
|
||||
github.com/fasthttp/websocket v1.5.3 // indirect
|
||||
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||
github.com/goph/emperror v0.17.2 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
|
|
@ -65,6 +65,7 @@ require (
|
|||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.17.9 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/magiconair/properties v1.8.7 // indirect
|
||||
github.com/mailru/easyjson v0.7.7 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
|
|
@ -100,7 +101,7 @@ require (
|
|||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/arch v0.11.0 // indirect
|
||||
golang.org/x/crypto v0.39.0 // indirect
|
||||
golang.org/x/crypto v0.36.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
|
|
|
|||
20
go.sum
20
go.sum
|
|
@ -130,10 +130,8 @@ github.com/clbanning/mxj/v2 v2.5.5/go.mod h1:hNiWqW14h+kc+MdF9C6/YoRfjEJoR3ou6tn
|
|||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||
github.com/cloudwego/eino v0.7.6 h1:9KGY1IZ/5kCf2viMDrPF3ck3tqd92bOhVOoSKTFRwY0=
|
||||
github.com/cloudwego/eino v0.7.6/go.mod h1:nA8Vacmuqv3pqKBQbTWENBLQ8MmGmPt/WqiyLeB8ohQ=
|
||||
github.com/cloudwego/eino-ext/components/model/ollama v0.1.6 h1:ZbrhV91uE0hGIOYXhb2i3G6tQJ/rK2SLYtoYrmocZXM=
|
||||
github.com/cloudwego/eino-ext/components/model/ollama v0.1.6/go.mod h1:GDXrvorGdRNV6g2mK5jdla2D8Xc/hh7XDrTeGDteLLo=
|
||||
github.com/cloudwego/eino v0.7.7 h1:WhP0SMWWPgLdOH03HrKUxtP9/Q96NhziMZNEQl9lxpU=
|
||||
github.com/cloudwego/eino v0.7.7/go.mod h1:nA8Vacmuqv3pqKBQbTWENBLQ8MmGmPt/WqiyLeB8ohQ=
|
||||
github.com/cloudwego/eino-ext/components/model/openai v0.1.5 h1:+yvGbTPw93li9GSmdm6Rix88Yy8AXg5NNBcRbWx3CQU=
|
||||
github.com/cloudwego/eino-ext/components/model/openai v0.1.5/go.mod h1:IPVYMFoZcuHeVEsDTGN6SZjvue0xr1iZFhdpq1SBWdQ=
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.2 h1:r9Id2wzJ05PoHl+Km7jQgNMgciaZI93TVnUYso89esM=
|
||||
|
|
@ -153,8 +151,6 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
|
|||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/eino-contrib/jsonschema v1.0.3 h1:2Kfsm1xlMV0ssY2nuxshS4AwbLFuqmPmzIjLVJ1Fsp0=
|
||||
github.com/eino-contrib/jsonschema v1.0.3/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4=
|
||||
github.com/eino-contrib/ollama v0.1.0 h1:z1NaMdKW6X1ftP8g5xGGR5zDRPUtuTKFq35vBQgxsN4=
|
||||
github.com/eino-contrib/ollama v0.1.0/go.mod h1:mYsQ7b3DeqY8bHPuD3MZJYTqkgyL6LoemxoP/B7ZNhA=
|
||||
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
|
|
@ -178,6 +174,8 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo
|
|||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
|
||||
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY=
|
||||
github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ=
|
||||
github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI=
|
||||
|
|
@ -187,6 +185,14 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2
|
|||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
github.com/go-kratos/kratos/v2 v2.9.1 h1:EGif6/S/aK/RCR5clIbyhioTNyoSrii3FC118jG40Z0=
|
||||
github.com/go-kratos/kratos/v2 v2.9.1/go.mod h1:a1MQLjMhIh7R0kcJS9SzJYR43BRI7EPzzN0J1Ksu2bA=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
|
||||
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw=
|
||||
|
|
@ -296,6 +302,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
|||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
||||
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
|
||||
|
|
|
|||
|
|
@ -1,41 +1,107 @@
|
|||
package biz
|
||||
|
||||
import (
|
||||
errors "ai_scheduler/internal/data/error"
|
||||
"ai_scheduler/internal/data/impl"
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg/util"
|
||||
"context"
|
||||
|
||||
"encoding/json"
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
type ChatHistoryBiz struct {
|
||||
chatRepo *impl.ChatImpl
|
||||
chatHiRepo *impl.ChatHisImpl
|
||||
taskRepo *impl.TaskImpl
|
||||
}
|
||||
|
||||
func NewChatHistoryBiz(chatRepo *impl.ChatImpl) *ChatHistoryBiz {
|
||||
func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl, taskRepo *impl.TaskImpl) *ChatHistoryBiz {
|
||||
s := &ChatHistoryBiz{
|
||||
chatRepo: chatRepo,
|
||||
chatHiRepo: chatHiRepo,
|
||||
taskRepo: taskRepo,
|
||||
}
|
||||
//go s.AsyncProcess(context.Background())
|
||||
return s
|
||||
}
|
||||
|
||||
//func (s *ChatHistoryBiz) create(ctx context.Context, sessionID, role, content string) error {
|
||||
// chat := model.AiChatHi{
|
||||
// SessionID: sessionID,
|
||||
// Role: role,
|
||||
// Content: content,
|
||||
// }
|
||||
//
|
||||
// return s.chatRepo.Create(&chat)
|
||||
//}
|
||||
//
|
||||
// 查询会话历史
|
||||
func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]entitys.ChatHisQueryResponse, error) {
|
||||
chats, err := s.chatHiRepo.FindAll(
|
||||
s.chatHiRepo.WithSessionId(query.SessionID),
|
||||
s.chatHiRepo.PaginateScope(query.Page, query.PageSize),
|
||||
s.chatHiRepo.OrderByDesc("his_id"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
taskIds := make([]int32, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
// 去重任务ID
|
||||
if !util.Contains(taskIds, chat.TaskID) {
|
||||
taskIds = append(taskIds, chat.TaskID)
|
||||
}
|
||||
}
|
||||
|
||||
// 查询任务名称
|
||||
tasks, err := s.taskRepo.FindAll(s.taskRepo.In("task_id", taskIds))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
taskMap := make(map[int32]model.AiTask)
|
||||
for _, task := range tasks {
|
||||
taskMap[task.TaskID] = task
|
||||
}
|
||||
|
||||
// 构建结果
|
||||
result := make([]entitys.ChatHisQueryResponse, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
item := entitys.ChatHisQueryResponse{}
|
||||
item.FromModel(chat, taskMap[chat.TaskID])
|
||||
result = append(result, item)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
//// 添加会话历史
|
||||
//func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error {
|
||||
// return s.create(ctx, chat.SessionID, chat.Role.String(), chat.Content)
|
||||
// return s.chatHiRepo.Create(&model.AiChatHi{
|
||||
// SessionID: chat.SessionID,
|
||||
// Ques: chat.Role.String(),
|
||||
// Ans: chat.Content,
|
||||
// })
|
||||
//}
|
||||
|
||||
// 更新会话历史内容, 追加内容, 不覆盖原有内容, content 使用json格式存储
|
||||
func (c *ChatHistoryBiz) UpdateContent(ctx context.Context, chat *entitys.UpdateContentRequest) error {
|
||||
var contents []string
|
||||
chatHi, has, err := c.chatHiRepo.FindOne(c.chatHiRepo.WithHisId(chat.HisID))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !has {
|
||||
return errors.NewBusinessErr(errors.InvalidParamCode, "chat history not found")
|
||||
}
|
||||
|
||||
if "" != chatHi.Content {
|
||||
// 解析历史内容
|
||||
err = json.Unmarshal([]byte(chatHi.Content), &contents)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
contents = append(contents, chat.Content)
|
||||
|
||||
b, err := json.Marshal(contents)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
chatHi.Content = string(b)
|
||||
return c.chatHiRepo.Update(&chatHi,
|
||||
c.chatHiRepo.Select("content"),
|
||||
c.chatHiRepo.WithHisId(chatHi.HisID))
|
||||
}
|
||||
|
||||
// 异步添加会话历史
|
||||
//func (s *ChatHistoryBiz) AsyncCreate(ctx context.Context, chat entitys.ChatHistory) {
|
||||
// s.chatRepo.AsyncCreate(ctx, model.AiChatHi{
|
||||
|
|
@ -53,5 +119,5 @@ func NewChatHistoryBiz(chatRepo *impl.ChatImpl) *ChatHistoryBiz {
|
|||
func (s *ChatHistoryBiz) Update(ctx context.Context, chat *entitys.UseFulRequest) error {
|
||||
cond := builder.NewCond()
|
||||
cond = cond.And(builder.Eq{"his_id": chat.HisId})
|
||||
return s.chatRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful})
|
||||
return s.chatHiRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,8 +19,6 @@ import (
|
|||
|
||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
|
|
@ -29,14 +27,14 @@ type Do struct {
|
|||
sessionImpl *impl.SessionImpl
|
||||
sysImpl *impl.SysImpl
|
||||
taskImpl *impl.TaskImpl
|
||||
hisImpl *impl.ChatImpl
|
||||
hisImpl *impl.ChatHisImpl
|
||||
conf *config.Config
|
||||
}
|
||||
|
||||
func NewDo(
|
||||
sysImpl *impl.SysImpl,
|
||||
taskImpl *impl.TaskImpl,
|
||||
hisImpl *impl.ChatImpl,
|
||||
hisImpl *impl.ChatHisImpl,
|
||||
conf *config.Config,
|
||||
) *Do {
|
||||
return &Do{
|
||||
|
|
@ -141,10 +139,10 @@ func (d *Do) loadChatHistory(ctx context.Context, requireData *entitys.RequireDa
|
|||
return nil
|
||||
}
|
||||
|
||||
func (d *Do) MakeCh(c *websocket.Conn, requireData *entitys.RequireData) (ctx context.Context, deferFunc func()) {
|
||||
func (d *Do) MakeCh(client *gateway.Client, requireData *entitys.RequireData) (ctx context.Context, deferFunc func()) {
|
||||
requireData.Ch = make(chan entitys.Response)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := d.startMessageHandler(ctx, c, requireData)
|
||||
done := d.startMessageHandler(ctx, client, requireData)
|
||||
return ctx, func() {
|
||||
close(requireData.Ch) //关闭主通道
|
||||
<-done // 等待消息处理完成
|
||||
|
|
@ -235,7 +233,7 @@ func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
|||
// startMessageHandler 启动独立的消息处理协程
|
||||
func (d *Do) startMessageHandler(
|
||||
ctx context.Context,
|
||||
c *websocket.Conn,
|
||||
client *gateway.Client,
|
||||
requireData *entitys.RequireData,
|
||||
) <-chan struct{} {
|
||||
done := make(chan struct{})
|
||||
|
|
@ -254,12 +252,13 @@ func (d *Do) startMessageHandler(
|
|||
Ques: requireData.Req.Text,
|
||||
Ans: strings.Join(chat, ""),
|
||||
Files: requireData.Req.Img,
|
||||
TaskID: requireData.Task.TaskID,
|
||||
}
|
||||
d.hisImpl.AddWithData(AiRes)
|
||||
hisLog.HisId = AiRes.HisID
|
||||
}
|
||||
|
||||
_ = entitys.MsgSend(c, entitys.Response{
|
||||
_ = entitys.MsgSend(client, entitys.Response{
|
||||
Content: pkg.JsonStringIgonErr(hisLog),
|
||||
Type: entitys.ResponseEnd,
|
||||
})
|
||||
|
|
@ -267,7 +266,7 @@ func (d *Do) startMessageHandler(
|
|||
}()
|
||||
|
||||
for v := range requireData.Ch { // 自动检测通道关闭
|
||||
if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
|
||||
if err := sendWithTimeout(client, v, 10*time.Second); err != nil {
|
||||
log.Errorf("Send error: %v", err)
|
||||
return
|
||||
}
|
||||
|
|
@ -281,7 +280,7 @@ func (d *Do) startMessageHandler(
|
|||
}
|
||||
|
||||
// 辅助函数:带超时的 WebSocket 发送
|
||||
func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error {
|
||||
func sendWithTimeout(client *gateway.Client, data entitys.Response, timeout time.Duration) error {
|
||||
sendCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
|
|
@ -294,7 +293,7 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura
|
|||
close(done)
|
||||
}()
|
||||
// 如果 MsgSend 阻塞,这里会卡住
|
||||
err := entitys.MsgSend(c, data)
|
||||
err := entitys.MsgSend(client, data)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
|
|
|
|||
|
|
@ -17,8 +17,9 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gorm.io/gorm/utils"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
type Handle struct {
|
||||
|
|
@ -85,6 +86,7 @@ func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, requir
|
|||
for _, task := range requireData.Tasks {
|
||||
if task.Index == requireData.Match.Index {
|
||||
pointTask = &task
|
||||
requireData.Task = task
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
@ -236,9 +238,19 @@ func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.Require
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth)
|
||||
// request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth)
|
||||
task.Config = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth)
|
||||
for k, v := range requestParam {
|
||||
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
|
||||
if vStr, ok := v.(string); ok {
|
||||
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", vStr)
|
||||
} else {
|
||||
var jsonStr []byte
|
||||
jsonStr, err = json.Marshal(v)
|
||||
if err != nil {
|
||||
return errors.NewBusinessErr(422, "请求参数解析失败")
|
||||
}
|
||||
task.Config = strings.ReplaceAll(task.Config, "\"${"+k+"}\"", string(jsonStr))
|
||||
}
|
||||
}
|
||||
var configData entitys.ConfigDataHttp
|
||||
err = json.Unmarshal([]byte(task.Config), &configData)
|
||||
|
|
|
|||
|
|
@ -7,33 +7,35 @@ import (
|
|||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg"
|
||||
"ai_scheduler/internal/pkg/utils_ollama"
|
||||
"ai_scheduler/internal/pkg/utils_vllm"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/ollama/ollama/api"
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
type OllamaService struct {
|
||||
client *utils_ollama.Client
|
||||
config *config.Config
|
||||
chatHis *impl.ChatImpl
|
||||
client *utils_ollama.Client
|
||||
vllmClient *utils_vllm.Client
|
||||
config *config.Config
|
||||
chatHis *impl.ChatHisImpl
|
||||
}
|
||||
|
||||
func NewOllamaGenerate(
|
||||
client *utils_ollama.Client,
|
||||
vllmClient *utils_vllm.Client,
|
||||
config *config.Config,
|
||||
chatHis *impl.ChatImpl,
|
||||
chatHis *impl.ChatHisImpl,
|
||||
) *OllamaService {
|
||||
return &OllamaService{
|
||||
client: client,
|
||||
config: config,
|
||||
chatHis: chatHis,
|
||||
client: client,
|
||||
vllmClient: vllmClient,
|
||||
config: config,
|
||||
chatHis: chatHis,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -105,7 +107,8 @@ func (r *OllamaService) getUserContent(ctx context.Context, requireData *entitys
|
|||
}
|
||||
|
||||
if len(requireData.ImgByte) > 0 {
|
||||
desc, err := r.RecognizeWithImg(ctx, requireData)
|
||||
// desc, err := r.RecognizeWithImg(ctx, requireData)
|
||||
desc, err := r.RecognizeWithImgVllm(ctx, requireData)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
@ -141,6 +144,7 @@ func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entit
|
|||
Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||
Images: requireData.ImgByte,
|
||||
KeepAlive: &api.Duration{Duration: 3600 * time.Second},
|
||||
//Think: &api.ThinkValue{Value: false},
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -155,42 +159,11 @@ func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *e
|
|||
}
|
||||
entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
|
||||
|
||||
chatModel, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{})
|
||||
if err != nil {
|
||||
return api.GenerateResponse{}, err
|
||||
}
|
||||
in := []*schema.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||
UserInputMultiContent: []schema.MessageInputPart{
|
||||
{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, imgUrl := range requireData.ImgUrls {
|
||||
imgTmp := imgUrl
|
||||
|
||||
in[1].UserInputMultiContent = append(in[1].UserInputMultiContent, schema.MessageInputPart{
|
||||
Type: schema.ChatMessagePartTypeImageURL,
|
||||
Image: &schema.MessageInputImage{
|
||||
MessagePartCommon: schema.MessagePartCommon{
|
||||
URL: &imgTmp,
|
||||
},
|
||||
Detail: schema.ImageURLDetailHigh,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
outMsg, err := chatModel.Generate(ctx, in)
|
||||
outMsg, err := r.vllmClient.RecognizeWithImg(ctx,
|
||||
r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
||||
r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||
requireData.ImgUrls,
|
||||
)
|
||||
if err != nil {
|
||||
return api.GenerateResponse{}, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,11 +39,9 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS
|
|||
requireData := &entitys.RequireData{
|
||||
Req: req,
|
||||
}
|
||||
// 获取WebSocket连接
|
||||
conn := client.GetConn()
|
||||
|
||||
//初始化通道/上下文
|
||||
ctx, clearFunc := r.do.MakeCh(conn, requireData)
|
||||
ctx, clearFunc := r.do.MakeCh(client, requireData)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
entitys.ResError(requireData.Ch, "", err.Error())
|
||||
|
|
|
|||
|
|
@ -17,16 +17,16 @@ import (
|
|||
type SessionBiz struct {
|
||||
sessionRepo *impl.SessionImpl
|
||||
sysRepo *impl.SysImpl
|
||||
chatRepo *impl.ChatImpl
|
||||
chatHisRepo *impl.ChatHisImpl
|
||||
|
||||
conf *config.Config
|
||||
}
|
||||
|
||||
func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, chatImpl *impl.ChatImpl) *SessionBiz {
|
||||
func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, chatImpl *impl.ChatHisImpl) *SessionBiz {
|
||||
return &SessionBiz{
|
||||
sessionRepo: sessionImpl,
|
||||
sysRepo: sysImpl,
|
||||
chatRepo: chatImpl,
|
||||
chatHisRepo: chatImpl,
|
||||
conf: conf,
|
||||
}
|
||||
}
|
||||
|
|
@ -91,10 +91,10 @@ func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRe
|
|||
result.Prologue = sysConfig.Prologue
|
||||
// 存在,返回会话历史
|
||||
var chatList []model.AiChatHi
|
||||
chatList, err = s.chatRepo.FindAll(
|
||||
s.chatRepo.WithSessionId(session.SessionID), // 条件:会话ID
|
||||
s.chatRepo.OrderByDesc("create_at"), // 排序:按创建时间降序
|
||||
s.chatRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数
|
||||
chatList, err = s.chatHisRepo.FindAll(
|
||||
s.chatHisRepo.WithSessionId(session.SessionID), // 条件:会话ID
|
||||
s.chatHisRepo.OrderByDesc("create_at"), // 排序:按创建时间降序
|
||||
s.chatHisRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import (
|
|||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Ollama OllamaConfig `mapstructure:"ollama"`
|
||||
Vllm VllmConfig `mapstructure:"vllm"`
|
||||
Sys SysConfig `mapstructure:"sys"`
|
||||
Tools ToolsConfig `mapstructure:"tools"`
|
||||
Logging LoggingConfig `mapstructure:"logging"`
|
||||
|
|
@ -59,10 +60,11 @@ type LLMCapabilityConfig struct {
|
|||
|
||||
// SysConfig 系统配置
|
||||
type SysConfig struct {
|
||||
SessionLen int `mapstructure:"session_len"`
|
||||
ChannelPoolLen int `mapstructure:"channel_pool_len"`
|
||||
ChannelPoolSize int `mapstructure:"channel_pool_size"`
|
||||
LlmPoolLen int `mapstructure:"llm_pool_len"`
|
||||
SessionLen int `mapstructure:"session_len"`
|
||||
ChannelPoolLen int `mapstructure:"channel_pool_len"`
|
||||
ChannelPoolSize int `mapstructure:"channel_pool_size"`
|
||||
LlmPoolLen int `mapstructure:"llm_pool_len"`
|
||||
HeartbeatInterval int `mapstructure:"heartbeat_interval"`
|
||||
}
|
||||
|
||||
// ServerConfig 服务器配置
|
||||
|
|
@ -80,6 +82,13 @@ type OllamaConfig struct {
|
|||
Timeout time.Duration `mapstructure:"timeout"`
|
||||
}
|
||||
|
||||
// VllmConfig holds the settings read from the "vllm" configuration section.
type VllmConfig struct {
	BaseURL string        `mapstructure:"base_url"` // base URL of the vLLM endpoint
	VlModel string        `mapstructure:"vl_model"` // name of the vision-language model
	Timeout time.Duration `mapstructure:"timeout"`  // request timeout
	Level   string        `mapstructure:"level"`    // NOTE(review): presumably a log level; confirm against its consumer
}
|
||||
|
||||
type Redis struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Type string `mapstructure:"type"`
|
||||
|
|
|
|||
|
|
@ -54,6 +54,8 @@ type BaseRepository[P PO] interface {
|
|||
WithStatus(status int) CondFunc // 查询status
|
||||
GetDb() *gorm.DB // 获取数据库连接
|
||||
WithLimit(limit int) CondFunc // 限制返回条数
|
||||
In(field string, values interface{}) CondFunc // 查询字段是否在列表中
|
||||
Select(fields ...string) CondFunc // 选择字段
|
||||
}
|
||||
|
||||
// PaginationResult 分页查询结果
|
||||
|
|
@ -215,3 +217,17 @@ func (this *BaseModel[P]) WithLimit(limit int) CondFunc {
|
|||
return db.Limit(limit)
|
||||
}
|
||||
}
|
||||
|
||||
// 查询字段是否在列表中
|
||||
func (this *BaseModel[P]) In(field string, values interface{}) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where(fmt.Sprintf("%s IN ?", field), values)
|
||||
}
|
||||
}
|
||||
|
||||
// select 字段
|
||||
func (this *BaseModel[P]) Select(fields ...string) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Select(fields)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,14 +11,14 @@ import (
|
|||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ChatImpl struct {
|
||||
type ChatHisImpl struct {
|
||||
dataTemp.DataTemp
|
||||
BaseRepository[model.AiChatHi]
|
||||
chatChannel chan model.AiChatHi
|
||||
}
|
||||
|
||||
func NewChatImpl(db *utils.Db) *ChatImpl {
|
||||
return &ChatImpl{
|
||||
func NewChatHisImpl(db *utils.Db) *ChatHisImpl {
|
||||
return &ChatHisImpl{
|
||||
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiChatHi)),
|
||||
BaseRepository: NewBaseModel[model.AiChatHi](db.Client),
|
||||
chatChannel: make(chan model.AiChatHi, 100),
|
||||
|
|
@ -26,19 +26,19 @@ func NewChatImpl(db *utils.Db) *ChatImpl {
|
|||
}
|
||||
|
||||
// WithSessionId 条件:会话ID
|
||||
func (impl *ChatImpl) WithSessionId(sessionId interface{}) CondFunc {
|
||||
func (impl *ChatHisImpl) WithSessionId(sessionId interface{}) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("session_id = ?", sessionId)
|
||||
}
|
||||
}
|
||||
|
||||
// 异步添加会话历史
|
||||
func (impl *ChatImpl) AsyncCreate(ctx context.Context, chat model.AiChatHi) {
|
||||
func (impl *ChatHisImpl) AsyncCreate(ctx context.Context, chat model.AiChatHi) {
|
||||
impl.chatChannel <- chat
|
||||
}
|
||||
|
||||
// 异步处理会话历史
|
||||
func (impl *ChatImpl) AsyncProcess(ctx context.Context) {
|
||||
func (impl *ChatHisImpl) AsyncProcess(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case chat := <-impl.chatChannel:
|
||||
|
|
@ -55,3 +55,10 @@ func (impl *ChatImpl) AsyncProcess(ctx context.Context) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// his_id 条件:历史ID
|
||||
func (impl *ChatHisImpl) WithHisId(hisId interface{}) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("his_id = ?", hisId)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,4 +4,4 @@ import (
|
|||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatImpl)
|
||||
var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatHisImpl)
|
||||
|
|
|
|||
|
|
@ -8,8 +8,12 @@ import (
|
|||
|
||||
type TaskImpl struct {
|
||||
dataTemp.DataTemp
|
||||
BaseRepository[model.AiTask]
|
||||
}
|
||||
|
||||
func NewTaskImpl(db *utils.Db) *TaskImpl {
|
||||
return &TaskImpl{*dataTemp.NewDataTemp(db, new(model.AiTask))}
|
||||
return &TaskImpl{
|
||||
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiTask)),
|
||||
BaseRepository: NewBaseModel[model.AiTask](db.Client),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ type AiChatHi struct {
|
|||
Useful int32 `gorm:"column:useful;not null;comment:0不评价,1有用,其他为无用" json:"useful"` // 0不评价,1有用,其他为无用
|
||||
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"`
|
||||
TaskID int32 `gorm:"column:task_id;not null" json:"task_id"` // 任务ID
|
||||
Content string `gorm:"column:content" json:"content"` // 前端回传数据
|
||||
}
|
||||
|
||||
// TableName AiChatHi's table name
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@ package entitys
|
|||
|
||||
import (
|
||||
"ai_scheduler/internal/data/constants"
|
||||
"ai_scheduler/internal/data/model"
|
||||
"encoding/json"
|
||||
"log"
|
||||
)
|
||||
|
||||
type ChatHistory struct {
|
||||
|
|
@ -14,3 +17,49 @@ type ChatHistory struct {
|
|||
type ChatHisLog struct {
|
||||
HisId int64 `json:"his_id"`
|
||||
}
|
||||
|
||||
// ChatHistQuery is the JSON request body used to page through the chat
// history of one session.
type ChatHistQuery struct {
	SessionID string `json:"session_id"` // session whose history is requested
	Page      int    `json:"page"`       // page number
	PageSize  int    `json:"page_size"`  // number of entries per page
}
|
||||
|
||||
// ChatHisQueryResponse is one chat-history entry as returned to the client:
// the stored question/answer plus the task name and the decoded content list.
type ChatHisQueryResponse struct {
	HisID     int64    `gorm:"column:his_id;primaryKey;autoIncrement:true" json:"his_id"`
	SessionID string   `gorm:"column:session_id;not null" json:"session_id"`
	Ques      string   `gorm:"column:ques;not null" json:"ques"`
	Ans       string   `gorm:"column:ans;not null" json:"ans"`
	Files     string   `gorm:"column:files;not null" json:"files"`
	Useful    int32    `gorm:"column:useful;not null;comment:0不评价,1有用,其他为无用" json:"useful"` // 0 = not rated, 1 = useful, anything else = not useful
	CreateAt  string   `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
	TaskID    int32    `gorm:"column:task_id;not null" json:"task_id"`     // task ID
	TaskName  string   `gorm:"column:task_name;not null" json:"task_name"` // task name
	Contents  []string `gorm:"column:contents" json:"contents"`            // data sent back by the frontend
}
|
||||
|
||||
func (c *ChatHisQueryResponse) FromModel(chat model.AiChatHi, task model.AiTask) {
|
||||
c.HisID = chat.HisID
|
||||
c.SessionID = chat.SessionID
|
||||
c.Ques = chat.Ques
|
||||
c.Ans = chat.Ans
|
||||
c.Files = chat.Files
|
||||
c.Useful = chat.Useful
|
||||
c.CreateAt = chat.CreateAt.Format("2006-01-02 15:04:05")
|
||||
c.TaskID = chat.TaskID
|
||||
c.TaskName = task.Name
|
||||
c.Contents = make([]string, 0)
|
||||
|
||||
// 解析Content
|
||||
if "" != chat.Content {
|
||||
err := json.Unmarshal([]byte(chat.Content), &c.Contents)
|
||||
if err != nil {
|
||||
c.Contents = append(c.Contents, chat.Content)
|
||||
log.Println("解析Content失败 error: ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateContentRequest is the JSON request body for updating the content of a
// chat-history entry. Both fields are marked as required for validation.
type UpdateContentRequest struct {
	HisID   int64  `json:"his_id" validate:"required"`  // history entry to update
	Content string `json:"content" validate:"required"` // new content value
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package entitys
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/gateway"
|
||||
"encoding/json"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
)
|
||||
|
|
@ -100,13 +101,13 @@ func MsgSet(msgType ResponseType, msg string, done bool) []byte {
|
|||
return jsonByte
|
||||
}
|
||||
|
||||
func MsgSend(c *websocket.Conn, msg Response) error {
|
||||
func MsgSend(client *gateway.Client, msg Response) error {
|
||||
// 检查上下文是否已取消
|
||||
if msg.Type == ResponseText {
|
||||
|
||||
}
|
||||
jsonByte, _ := json.Marshal(msg)
|
||||
return c.WriteMessage(websocket.TextMessage, jsonByte)
|
||||
return client.SendFunc(jsonByte)
|
||||
}
|
||||
|
||||
func MsgSendByte(c *websocket.Conn, msg []byte) {
|
||||
|
|
|
|||
|
|
@ -150,6 +150,7 @@ type RequireData struct {
|
|||
Histories []model.AiChatHi
|
||||
SessionInfo model.AiSession
|
||||
Tasks []model.AiTask
|
||||
Task model.AiTask
|
||||
Match *Match
|
||||
Req *ChatSockRequest
|
||||
Auth string
|
||||
|
|
|
|||
|
|
@ -3,33 +3,46 @@ package gateway
|
|||
import (
|
||||
errors "ai_scheduler/internal/data/error"
|
||||
"ai_scheduler/internal/data/model"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"context"
|
||||
"github.com/google/uuid"
|
||||
"log"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/websocket/v2"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrConnClosed = errors.SysErr("连接不存在或已关闭")
|
||||
rng = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
idBuf = make([]byte, 20)
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
id string // 客户端唯一ID
|
||||
conn *websocket.Conn // WebSocket 连接
|
||||
session string // 会话ID
|
||||
key string // 应用密钥
|
||||
auth string // 用户凭证token
|
||||
codes []string // 用户权限code
|
||||
sysInfo *model.AiSy // 系统信息
|
||||
tasks []model.AiTask // 任务列表
|
||||
sysCode string // 系统编码
|
||||
id string // 客户端唯一ID
|
||||
conn *websocket.Conn // WebSocket 连接
|
||||
session string // 会话ID
|
||||
key string // 应用密钥
|
||||
auth string // 用户凭证token
|
||||
codes []string // 用户权限code
|
||||
sysInfo *model.AiSy // 系统信息
|
||||
tasks []model.AiTask // 任务列表
|
||||
sysCode string // 系统编码
|
||||
Ctx context.Context
|
||||
Cancel context.CancelFunc
|
||||
LastActive time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewClient(conn *websocket.Conn) *Client {
|
||||
func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelFunc) *Client {
|
||||
|
||||
return &Client{
|
||||
id: generateClientID(),
|
||||
conn: conn,
|
||||
id: generateClientID(),
|
||||
conn: conn,
|
||||
Ctx: ctx,
|
||||
Cancel: cancel,
|
||||
mu: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -63,7 +76,7 @@ func (c *Client) GetCodes() []string {
|
|||
return c.codes
|
||||
}
|
||||
|
||||
// GetSysCode 获取系统编码
|
||||
// 获取系统编码
|
||||
func (c *Client) GetSysCode() string {
|
||||
return c.sysCode
|
||||
}
|
||||
|
|
@ -93,22 +106,51 @@ func (c *Client) SetCodes(codes []string) {
|
|||
c.codes = codes
|
||||
}
|
||||
|
||||
// Close 关闭客户端连接
|
||||
func (c *Client) Close() {
|
||||
//c.mu.Lock()
|
||||
//defer c.mu.Unlock()
|
||||
if c.conn != nil {
|
||||
_ = c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
}
|
||||
|
||||
// SendFunc 发送消息到客户端
|
||||
func (c *Client) SendFunc(msg []byte) error {
|
||||
if c.conn != nil {
|
||||
return c.conn.WriteMessage(websocket.TextMessage, msg)
|
||||
return c.SendMessage(websocket.TextMessage, msg)
|
||||
}
|
||||
|
||||
// 在Client结构体中添加更详细的日志
|
||||
func (c *Client) SendMessage(msgType int, msg []byte) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn == nil {
|
||||
return ErrConnClosed
|
||||
}
|
||||
return ErrConnClosed
|
||||
|
||||
err := c.conn.WriteMessage(msgType, msg)
|
||||
if err != nil {
|
||||
log.Printf("发送消息失败: %v, 客户端ID: %s, 消息类型: %d",
|
||||
err, c.id, msgType)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 生成唯一的客户端ID
|
||||
func generateClientID() string {
|
||||
// 使用时间戳+随机数确保唯一性
|
||||
timestamp := time.Now().UnixNano()
|
||||
randomBytes := make([]byte, 4)
|
||||
rand.Read(randomBytes)
|
||||
randomStr := hex.EncodeToString(randomBytes)
|
||||
return fmt.Sprintf("%d%s", timestamp, randomStr)
|
||||
return uuid.New().String()
|
||||
//// 1. 时间戳
|
||||
//timestamp := time.Now().UnixNano()
|
||||
//binary.BigEndian.PutUint64(idBuf[:8], uint64(timestamp))
|
||||
//
|
||||
//// 2. 随机数(4字节)
|
||||
//binary.BigEndian.PutUint32(idBuf[8:12], rng.Uint32())
|
||||
//
|
||||
//// 3. 十六进制编码
|
||||
//n := pkg.HexEncode(idBuf[:12], idBuf[12:])
|
||||
//return string(idBuf[12 : 12+n])
|
||||
}
|
||||
|
||||
// 连接数据验证和收集
|
||||
|
|
@ -136,3 +178,34 @@ func (c *Client) DataAuth() (err error) {
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 总结:目前绝大多数浏览器不支持直接发送WebSocket Ping帧,因此在实际开发中,应该实现应用层ping机制作为主要心跳检测方案,todo 同时保留对未来可能的原生支持的兼容检测。
|
||||
func (c *Client) InitHeartbeat(timeoutSecond time.Duration) {
|
||||
ticker := time.NewTicker(timeoutSecond * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
//*2是防止丢包,连续丢包两次,再加5s网络延迟容错
|
||||
if time.Since(c.LastActive) > (timeoutSecond*2*time.Second + 5) { // 5秒容错
|
||||
log.Println("心跳超时", "clientId", c.id)
|
||||
err := c.SendMessage(websocket.CloseMessage, []byte("Heartbeat timeout"))
|
||||
if err != nil {
|
||||
log.Println("发送心跳超时消息失败", err)
|
||||
}
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
case <-c.Ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 在Client结构体中添加ReadMessage方法
|
||||
func (c *Client) ReadMessage() (messageType int, message []byte, err error) {
|
||||
if c.conn == nil {
|
||||
return 0, nil, ErrConnClosed
|
||||
}
|
||||
return c.conn.ReadMessage()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@ package gateway
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Gateway struct {
|
||||
|
|
@ -20,14 +22,29 @@ func NewGateway() *Gateway {
|
|||
|
||||
func (g *Gateway) AddClient(c *Client) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
defer func() {
|
||||
g.mu.Unlock()
|
||||
//心跳开始计时
|
||||
c.LastActive = time.Now()
|
||||
log.Println("client connected:", c.GetID())
|
||||
log.Println("客户端已连接")
|
||||
}()
|
||||
g.clients[c.GetID()] = c
|
||||
}
|
||||
|
||||
func (g *Gateway) RemoveClient(clientID string) {
|
||||
func (g *Gateway) Cleanup(clientID string) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
delete(g.clients, clientID)
|
||||
// 从网关管理中移除客户端
|
||||
defer func() {
|
||||
if c, ex := g.clients[clientID]; ex {
|
||||
delete(g.clients, clientID)
|
||||
c.Close()
|
||||
c.Cancel()
|
||||
}
|
||||
g.mu.Unlock()
|
||||
log.Println("client disconnected:", clientID)
|
||||
}()
|
||||
// 从所有绑定的UID列表中移除该客户端
|
||||
for uid, list := range g.uidMap {
|
||||
newList := []string{}
|
||||
for _, cid := range list {
|
||||
|
|
@ -37,6 +54,7 @@ func (g *Gateway) RemoveClient(clientID string) {
|
|||
}
|
||||
g.uidMap[uid] = newList
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (g *Gateway) SendToAll(msg []byte) {
|
||||
|
|
@ -63,6 +81,7 @@ func (g *Gateway) BindUid(clientID, uid string) error {
|
|||
return errors.New("client not found")
|
||||
}
|
||||
g.uidMap[uid] = append(g.uidMap[uid], clientID)
|
||||
log.Printf("绑定 clientId %s -> uid:%s\n", clientID, uid)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -55,3 +55,13 @@ func ValidateImageURL(rawURL string) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HexEncode writes the lowercase hexadecimal encoding of src into dst and
// returns the number of bytes written, which is always len(src)*2.
//
// dst must hold at least len(src)*2 bytes. A shorter dst panics before
// anything is written, so dst is never left half-filled.
func HexEncode(src, dst []byte) int {
	const hextable = "0123456789abcdef"

	n := len(src) * 2
	if n == 0 {
		return 0
	}
	// Touch the last byte that will be written so that a short dst fails
	// up front instead of after a partial write.
	_ = dst[n-1]

	for i, b := range src {
		dst[i*2] = hextable[b>>4]
		dst[i*2+1] = hextable[b&0xf]
	}
	return n
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"ai_scheduler/internal/pkg/dingtalk"
|
||||
"ai_scheduler/internal/pkg/utils_langchain"
|
||||
"ai_scheduler/internal/pkg/utils_ollama"
|
||||
"ai_scheduler/internal/pkg/utils_vllm"
|
||||
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
|
@ -13,6 +14,7 @@ var ProviderSetClient = wire.NewSet(
|
|||
NewGormDb,
|
||||
utils_langchain.NewUtilLangChain,
|
||||
utils_ollama.NewClient,
|
||||
utils_vllm.NewClient,
|
||||
NewSafeChannelPool,
|
||||
dingtalk.NewOldClient,
|
||||
dingtalk.NewContactClient,
|
||||
|
|
|
|||
|
|
@ -32,3 +32,13 @@ func StringToFloat64(s string) float64 {
|
|||
i, _ := strconv.ParseFloat(s, 64)
|
||||
return i
|
||||
}
|
||||
|
||||
// Contains reports whether str is equal to at least one element of strings.
func Contains[T comparable](strings []T, str T) bool {
	for i := 0; i < len(strings); i++ {
		if strings[i] == str {
			return true
		}
	}
	return false
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,60 @@
|
|||
package utils_vllm
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
model *openai.ChatModel
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func NewClient(config *config.Config) (*Client, func(), error) {
|
||||
m, err := openai.NewChatModel(context.Background(), &openai.ChatModelConfig{
|
||||
BaseURL: config.Vllm.BaseURL,
|
||||
Model: config.Vllm.VlModel,
|
||||
Timeout: config.Vllm.Timeout,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
c := &Client{model: m, config: config}
|
||||
cleanup := func() {}
|
||||
return c, cleanup, nil
|
||||
}
|
||||
|
||||
func (c *Client) Chat(ctx context.Context, msgs []*schema.Message) (*schema.Message, error) {
|
||||
return c.model.Generate(ctx, msgs)
|
||||
}
|
||||
|
||||
func (c *Client) RecognizeWithImg(ctx context.Context, systemPrompt, userPrompt string, imgURLs []string) (*schema.Message, error) {
|
||||
in := []*schema.Message{
|
||||
{
|
||||
Role: schema.System,
|
||||
Content: systemPrompt,
|
||||
},
|
||||
{
|
||||
Role: schema.User,
|
||||
},
|
||||
}
|
||||
parts := []schema.MessageInputPart{
|
||||
{Type: schema.ChatMessagePartTypeText, Text: userPrompt},
|
||||
}
|
||||
for i := range imgURLs {
|
||||
u := imgURLs[i]
|
||||
parts = append(parts, schema.MessageInputPart{
|
||||
Type: schema.ChatMessagePartTypeImageURL,
|
||||
Image: &schema.MessageInputImage{
|
||||
MessagePartCommon: schema.MessagePartCommon{URL: &u},
|
||||
Detail: schema.ImageURLDetailHigh,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
in[1].UserInputMultiContent = parts
|
||||
return c.model.Generate(ctx, in)
|
||||
}
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
package utils_vllm
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// newMockServer starts a local HTTP test server that answers every request
// with status 200 and a fixed chat-completion JSON document whose assistant
// message content is "ok". The caller must Close the returned server.
func newMockServer() *httptest.Server {
	const cannedCompletion = `{"id":"cmpl-1","object":"chat.completion","created":173,"model":"x","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`

	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(cannedCompletion))
	})
	return httptest.NewServer(mux)
}
|
||||
|
||||
func Test_Vllm_Chat_Generate(t *testing.T) {
|
||||
cfg, err := config.LoadConfig("../../../config/config_test.yaml")
|
||||
if err != nil {
|
||||
t.Fatalf("load config: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
client, _, err := NewClient(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("new client: %v", err)
|
||||
}
|
||||
|
||||
msgs := []*schema.Message{{Role: schema.User, Content: "hi"}}
|
||||
out, err := client.Chat(ctx, msgs)
|
||||
if err != nil {
|
||||
t.Fatalf("chat generate: %v", err)
|
||||
}
|
||||
if out == nil || out.Content != "ok" {
|
||||
t.Fatalf("unexpected content: %v", out)
|
||||
}
|
||||
|
||||
t.Logf("结果: %v", out)
|
||||
}
|
||||
|
||||
func Test_Vllm_RecognizeWithImg(t *testing.T) {
|
||||
cfg, err := config.LoadConfig("../../../config/config_test.yaml")
|
||||
if err != nil {
|
||||
t.Fatalf("load config: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
client, _, err := NewClient(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("new client: %v", err)
|
||||
}
|
||||
|
||||
out, err := client.RecognizeWithImg(ctx, "sys", "user", []string{"https://img0.baidu.com/it/u=910428455,194434251&fm=253&app=138&f=JPEG?w=1122&h=800"})
|
||||
if err != nil {
|
||||
t.Fatalf("recognize with img: %v", err)
|
||||
}
|
||||
if out == nil || out.Content != "ok" {
|
||||
t.Fatalf("unexpected content: %v", out)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
package validate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/go-playground/locales/zh"
|
||||
ut "github.com/go-playground/universal-translator"
|
||||
"github.com/go-playground/validator/v10"
|
||||
zh_translations "github.com/go-playground/validator/v10/translations/zh"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func Struct(s interface{}) (errMsg []string, err error) {
|
||||
// 创建验证器实例
|
||||
validate := validator.New()
|
||||
|
||||
// 创建中文翻译器
|
||||
zh_ch := zh.New()
|
||||
uni := ut.New(zh_ch, zh_ch)
|
||||
trans, _ := uni.GetTranslator("zh")
|
||||
|
||||
//注册一个函数,获取struct tag里自定义的label作为字段名
|
||||
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
|
||||
name := fld.Tag.Get("label")
|
||||
return name
|
||||
})
|
||||
|
||||
// 注册中文翻译器到验证器
|
||||
_ = zh_translations.RegisterDefaultTranslations(validate, trans)
|
||||
|
||||
// 验证结构体
|
||||
err = validate.Struct(s)
|
||||
if err != nil {
|
||||
// 处理验证错误
|
||||
if _, ok := err.(*validator.InvalidValidationError); ok {
|
||||
fmt.Println("处理验证错误error:", err)
|
||||
errMsg = append(errMsg, err.Error())
|
||||
} else {
|
||||
for _, v := range err.(validator.ValidationErrors) {
|
||||
errMsg = append(errMsg, v.Translate(trans))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
|
@ -11,24 +11,25 @@ import (
|
|||
)
|
||||
|
||||
type HTTPServer struct {
|
||||
app *fiber.App
|
||||
service *services.ChatService
|
||||
session *services.SessionService
|
||||
gateway *gateway.Gateway
|
||||
callback *services.CallbackService
|
||||
app *fiber.App
|
||||
service *services.ChatService
|
||||
session *services.SessionService
|
||||
gateway *gateway.Gateway
|
||||
callback *services.CallbackService
|
||||
}
|
||||
|
||||
func NewHTTPServer(
|
||||
service *services.ChatService,
|
||||
session *services.SessionService,
|
||||
task *services.TaskService,
|
||||
gateway *gateway.Gateway,
|
||||
callback *services.CallbackService,
|
||||
service *services.ChatService,
|
||||
session *services.SessionService,
|
||||
task *services.TaskService,
|
||||
gateway *gateway.Gateway,
|
||||
callback *services.CallbackService,
|
||||
chatHis *services.HistoryService,
|
||||
) *fiber.App {
|
||||
//构建 server
|
||||
app := initRoute()
|
||||
router.SetupRoutes(app, service, session, task, gateway, callback)
|
||||
return app
|
||||
//构建 server
|
||||
app := initRoute()
|
||||
router.SetupRoutes(app, service, session, task, gateway, callback, chatHis)
|
||||
return app
|
||||
}
|
||||
|
||||
func initRoute() *fiber.App {
|
||||
|
|
|
|||
|
|
@ -15,14 +15,17 @@ import (
|
|||
)
|
||||
|
||||
type RouterServer struct {
|
||||
app *fiber.App
|
||||
service *services.ChatService
|
||||
session *services.SessionService
|
||||
gateway *gateway.Gateway
|
||||
app *fiber.App
|
||||
service *services.ChatService
|
||||
session *services.SessionService
|
||||
gateway *gateway.Gateway
|
||||
chatHist *services.HistoryService
|
||||
}
|
||||
|
||||
// SetupRoutes 设置路由
|
||||
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService, gateway *gateway.Gateway, callbackService *services.CallbackService) {
|
||||
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService,
|
||||
gateway *gateway.Gateway, callbackService *services.CallbackService, chatHist *services.HistoryService,
|
||||
) {
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
// 设置 CORS 头
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
|
|
@ -77,15 +80,16 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
|
|||
return ctx.Status(400).SendString("unknown action")
|
||||
}
|
||||
})
|
||||
|
||||
// 会话历史
|
||||
r.Post("/chat/history/list", chatHist.List)
|
||||
r.Post("/chat/history/update/content", chatHist.UpdateContent)
|
||||
}
|
||||
|
||||
func routerSocket(app *fiber.App, chatService *services.ChatService) {
|
||||
ws := app.Group("ws/v1/")
|
||||
// WebSocket 路由配置
|
||||
ws.Get("/chat", websocket.New(func(c *websocket.Conn) {
|
||||
// 可以在这里添加握手前的中间件逻辑(如头校验)
|
||||
chatService.Chat(c) // 调用实际的 Chat 处理函数
|
||||
}, websocket.Config{
|
||||
ws.Get("/chat", websocket.New(chatService.Chat, websocket.Config{
|
||||
// 可选配置:跨域检查、最大负载大小等
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
//Subprotocols: []string{"json", "msgpack"},
|
||||
|
|
|
|||
|
|
@ -6,9 +6,11 @@ import (
|
|||
"ai_scheduler/internal/data/constants"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/gateway"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
|
|
@ -38,20 +40,6 @@ func NewChatService(
|
|||
}
|
||||
}
|
||||
|
||||
// ToolCallResponse 工具调用响应
|
||||
type ToolCallResponse struct {
|
||||
ID string `json:"id" example:"call_1"`
|
||||
Type string `json:"type" example:"function"`
|
||||
Function FunctionCallResponse `json:"function"`
|
||||
Result interface{} `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
// FunctionCallResponse 函数调用响应
|
||||
type FunctionCallResponse struct {
|
||||
Name string `json:"name" example:"get_weather"`
|
||||
Arguments interface{} `json:"arguments"`
|
||||
}
|
||||
|
||||
func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
|
||||
err := c.WriteMessage(websocket.TextMessage, []byte(content))
|
||||
if err != nil {
|
||||
|
|
@ -63,84 +51,113 @@ func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
|
|||
// Chat 处理WebSocket聊天连接
|
||||
// 这是WebSocket处理的主入口函数
|
||||
func (h *ChatService) Chat(c *websocket.Conn) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// 创建新的客户端实例
|
||||
h.mu.Lock()
|
||||
client := gateway.NewClient(c)
|
||||
h.mu.Unlock()
|
||||
|
||||
// 将客户端添加到网关管理
|
||||
h.Gw.AddClient(client)
|
||||
log.Println("client connected:", client.GetID())
|
||||
log.Println("客户端已连接")
|
||||
|
||||
// 绑定会话ID
|
||||
uid := c.Query("x-session")
|
||||
if uid != "" {
|
||||
if err := h.Gw.BindUid(client.GetID(), uid); err != nil {
|
||||
log.Println("绑定UID错误:", err)
|
||||
}
|
||||
log.Printf("bind %s -> uid:%s\n", client.GetID(), uid)
|
||||
}
|
||||
client := gateway.NewClient(c, ctx, cancel)
|
||||
|
||||
// 验证并收集连接数据,后续对话中会使用
|
||||
if err := client.DataAuth(); err != nil {
|
||||
log.Println("数据验证错误:", err)
|
||||
h.ChatFail(c, err.Error())
|
||||
_ = client.SendFunc([]byte(err.Error()))
|
||||
client.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// 验证通过后,将客户端添加到网关管理
|
||||
h.Gw.AddClient(client)
|
||||
|
||||
// 使用信号量限制并发处理的消息数量
|
||||
semaphore := make(chan struct{}, 1) // 最多1个并发消息处理
|
||||
// 用于等待所有goroutine完成的wait group
|
||||
var wg sync.WaitGroup
|
||||
// 确保在函数返回时移除客户端并关闭连接
|
||||
defer func() {
|
||||
h.Gw.RemoveClient(client.GetID())
|
||||
_ = c.Close()
|
||||
log.Println("client disconnected:", client.GetID())
|
||||
wg.Wait() // 等待所有消息处理goroutine完成
|
||||
close(semaphore) // 关闭信号量通道
|
||||
h.Gw.Cleanup(client.GetID())
|
||||
}()
|
||||
|
||||
// 绑定会话ID, sessionId 为空时, 则不绑定
|
||||
if uid := client.GetSession(); uid != "" {
|
||||
if err := h.Gw.BindUid(client.GetID(), uid); err != nil {
|
||||
log.Println("绑定UID错误:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 开启心跳检测
|
||||
go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval))
|
||||
|
||||
// 循环读取客户端消息
|
||||
for {
|
||||
// 读取消息
|
||||
messageType, message, err := c.ReadMessage()
|
||||
messageType, message, err := client.ReadMessage()
|
||||
if err != nil {
|
||||
log.Println("读取错误:", err)
|
||||
log.Printf("读取错误: %v, 客户端ID: %s", err, client.GetID())
|
||||
break
|
||||
}
|
||||
|
||||
// 处理消息
|
||||
msg, chatType := h.handleMessageToString(c, messageType, message)
|
||||
if chatType == constants.ConnStatusClosed {
|
||||
break
|
||||
}
|
||||
if chatType == constants.ConnStatusIgnore {
|
||||
// 处理心跳消息
|
||||
if messageType == websocket.PingMessage || string(message) == "PING" {
|
||||
client.LastActive = time.Now()
|
||||
msgType := websocket.TextMessage
|
||||
if messageType == websocket.PingMessage {
|
||||
msgType = websocket.PongMessage
|
||||
}
|
||||
if err = client.SendMessage(msgType, []byte(`PONG`)); err != nil {
|
||||
log.Printf("发送pong消息失败: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("收到消息: %s", string(msg))
|
||||
// 使用信号量限制并发
|
||||
semaphore <- struct{}{}
|
||||
wg.Add(1)
|
||||
go func(msgType int, msg []byte) {
|
||||
defer func() {
|
||||
<-semaphore
|
||||
wg.Done()
|
||||
// 恢复panic
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("消息处理goroutine发生panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// 解析请求
|
||||
var req entitys.ChatSockRequest
|
||||
if err = json.Unmarshal(msg, &req); err != nil {
|
||||
log.Println("JSON parse error:", err)
|
||||
continue
|
||||
}
|
||||
// 消息处理逻辑
|
||||
h.processMessage(client, msgType, msg)
|
||||
}(messageType, message)
|
||||
}
|
||||
}
|
||||
|
||||
// 路由处理请求
|
||||
err = h.routerBiz.RouteWithSocket(client, &req)
|
||||
if err != nil {
|
||||
log.Println("处理失败:", err)
|
||||
}
|
||||
// 将消息处理逻辑提取到单独的方法
|
||||
func (h *ChatService) processMessage(client *gateway.Client, msgType int, msg []byte) {
|
||||
// 处理消息
|
||||
processedMsg, _ := h.handleMessageToString(client, msgType, msg)
|
||||
log.Printf("收到消息:消息类型 %d, 内容 %s, 客户端ID: %s",
|
||||
msgType, string(processedMsg), client.GetID())
|
||||
|
||||
// 解析请求
|
||||
var req entitys.ChatSockRequest
|
||||
if err := json.Unmarshal(processedMsg, &req); err != nil {
|
||||
log.Printf("JSON解析错误: %v, 客户端ID: %s", err, client.GetID())
|
||||
return
|
||||
}
|
||||
|
||||
// 路由处理请求
|
||||
if err := h.routerBiz.RouteWithSocket(client, &req); err != nil {
|
||||
log.Printf("处理失败: %v, 客户端ID: %s", err, client.GetID())
|
||||
}
|
||||
}
|
||||
|
||||
// handleMessageToString 处理不同类型的WebSocket消息
|
||||
// 参数:
|
||||
// - c: WebSocket连接
|
||||
// - client: 客户端对象
|
||||
// - msgType: 消息类型
|
||||
// - msg: 消息内容
|
||||
//
|
||||
// 返回:
|
||||
// - text: 处理后的文本内容
|
||||
// - chatType: 连接状态
|
||||
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constants.ConnStatus) {
|
||||
func (h *ChatService) handleMessageToString(client *gateway.Client, msgType int, msg any) (text []byte, chatType constants.ConnStatus) {
|
||||
switch msgType {
|
||||
case websocket.TextMessage:
|
||||
return msg.([]byte), constants.ConnStatusNormal
|
||||
|
|
@ -150,15 +167,13 @@ func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg
|
|||
|
||||
return nil, constants.ConnStatusClosed
|
||||
case websocket.PingMessage:
|
||||
// 可选:回复 Pong
|
||||
c.WriteMessage(websocket.PongMessage, nil)
|
||||
return nil, constants.ConnStatusIgnore
|
||||
case websocket.PongMessage:
|
||||
return nil, constants.ConnStatusIgnore
|
||||
default:
|
||||
log.Printf("未知的消息类型: %d", msgType)
|
||||
return nil, constants.ConnStatusIgnore
|
||||
}
|
||||
return msg.([]byte), constants.ConnStatusIgnore
|
||||
}
|
||||
|
||||
func (s *ChatService) Useful(c *fiber.Ctx) error {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,63 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/biz"
|
||||
errors "ai_scheduler/internal/data/error"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg/validate"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type HistoryService struct {
|
||||
chatRepo *biz.ChatHistoryBiz
|
||||
}
|
||||
|
||||
func NewHistoryService(chatRepo *biz.ChatHistoryBiz) *HistoryService {
|
||||
return &HistoryService{
|
||||
chatRepo: chatRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// GetHistoryService 获取会话历史
|
||||
func (h *HistoryService) List(c *fiber.Ctx) error {
|
||||
var query entitys.ChatHistQuery
|
||||
if err := c.BodyParser(&query); err != nil {
|
||||
return err
|
||||
}
|
||||
// 校验参数
|
||||
if query.SessionID == "" {
|
||||
return errors.SessionNotFound
|
||||
}
|
||||
if query.Page <= 0 {
|
||||
query.Page = 1
|
||||
}
|
||||
if query.PageSize <= 0 {
|
||||
query.PageSize = 10
|
||||
}
|
||||
|
||||
// 查询历史
|
||||
history, err := h.chatRepo.List(c.Context(), &query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(history)
|
||||
}
|
||||
|
||||
func (h *HistoryService) UpdateContent(c *fiber.Ctx) error {
|
||||
var req entitys.UpdateContentRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
// 校验参数
|
||||
msg, err := validate.Struct(req)
|
||||
if err != nil {
|
||||
log.Error(c.UserContext(), "参数错误 error: ", err)
|
||||
return errors.NewBusinessErr(errors.InvalidParamCode, strings.Join(msg, ";"))
|
||||
}
|
||||
|
||||
// 更新历史
|
||||
return h.chatRepo.UpdateContent(c.Context(), &req)
|
||||
}
|
||||
|
|
@ -6,4 +6,4 @@ import (
|
|||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService)
|
||||
var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService, NewHistoryService)
|
||||
|
|
|
|||
Loading…
Reference in New Issue