From aa86882f7dc9f4db7f2364171b55e172b247e378 Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Tue, 23 Sep 2025 16:30:30 +0800 Subject: [PATCH 1/3] =?UTF-8?q?feat:=20=E6=8E=A5=E5=85=A5=E7=9F=A5?= =?UTF-8?q?=E8=AF=86=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config.yaml | 4 +- go.mod | 9 +- go.sum | 25 +-- internal/biz/router.go | 92 +++++++++- internal/entitys/types.go | 5 +- internal/pkg/l_request/README.md | 25 +++ internal/pkg/l_request/pool.go | 44 +++++ internal/pkg/l_request/pool_test.go | 11 ++ internal/pkg/l_request/request.go | 169 ++++++++++++++++++ internal/pkg/util/http.go | 168 +++++++++++++++++ internal/tools/konwledge_base.go | 248 ++++++++++++++++++++++++++ internal/tools/konwledge_base_test.go | 34 ++++ internal/tools/manager.go | 6 + 13 files changed, 803 insertions(+), 37 deletions(-) create mode 100644 internal/pkg/l_request/README.md create mode 100644 internal/pkg/l_request/pool.go create mode 100644 internal/pkg/l_request/pool_test.go create mode 100644 internal/pkg/l_request/request.go create mode 100644 internal/pkg/util/http.go create mode 100644 internal/tools/konwledge_base.go create mode 100644 internal/tools/konwledge_base_test.go diff --git a/config/config.yaml b/config/config.yaml index 4baff3f..c85ac02 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -47,4 +47,6 @@ tools: base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/direct/ai/search/" enabled: true api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU2MTgyNTM1LCJuYmYiOjE3NTYxODA3MzUsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.N1xv1PYbcO8_jR5adaczc16YzGsr4z101gwEZdulkRaREBJNYTOnFrvRxTFx3RJTooXsqTqroE1MR84v_1WPX6BS6kKonA-kC1Jgot6yrt5rFWhGNGb2Cpr9rKIFCCQYmiGd3AUgDazEeaQ0_sodv3E-EXg9VfE1SX8nMcck9Yjnc8NCy7RTWaBIaSeOdZcEl-JfCD0S6GSx3oErp_hk-U9FKGwf60wAuDGTY1R0BP4BYpcEqS-C2LSnsSGyURi54Cuk5xH8r1WuF0Dm5bwAj5d7Hvs77-N_sUF-C5ONqyZJRAEhYLgcmN9RX_WQZfizdQJxizlTczdpzYfy-v-1eQ" - + knowledge: + base_url: "http://117.175.169.61:10000" + enabled: true diff --git a/go.mod b/go.mod index 7768c99..4b59108 100644 --- a/go.mod +++ b/go.mod @@ -6,13 +6,12 @@ toolchain go1.24.7 require ( gitea.cdlsxd.cn/self-tools/l_request v1.0.8 - github.com/bytedance/sonic v1.11.6 github.com/emirpasic/gods v1.18.1 github.com/go-kratos/kratos/v2 v2.9.1 github.com/gofiber/fiber/v2 v2.52.9 github.com/gofiber/websocket/v2 v2.2.1 github.com/google/wire v0.7.0 - github.com/ollama/ollama v0.11.11 + github.com/ollama/ollama v0.12.0 github.com/redis/go-redis/v9 v9.14.0 github.com/spf13/viper v1.17.0 github.com/tmc/langchaingo v0.1.13 @@ -27,10 +26,7 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/andybalholm/brotli v1.1.0 // indirect - github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/cloudwego/base64x v0.1.4 // indirect - github.com/cloudwego/iasm v0.2.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.4 // indirect github.com/fasthttp/websocket v1.5.3 // indirect @@ -41,7 +37,6 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/klauspost/compress v1.17.9 // indirect - github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -59,13 +54,11 @@ require ( github.com/spf13/pflag v1.0.5 // indirect github.com/stretchr/testify v1.11.1 // indirect github.com/subosito/gotenv v1.6.0 // indirect - github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect - golang.org/x/arch v0.8.0 // indirect golang.org/x/crypto v0.36.0 // indirect golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect golang.org/x/sys v0.31.0 // indirect diff --git a/go.sum b/go.sum index 916bf1c..fa9c4ef 100644 --- a/go.sum +++ b/go.sum @@ -50,10 +50,6 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= -github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= -github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= -github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= -github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -61,10 +57,6 @@ github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWR github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= -github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= -github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= -github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= @@ -176,10 +168,6 @@ github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/X github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= -github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= -github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= -github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -199,8 +187,8 @@ github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6T github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= -github.com/ollama/ollama v0.11.11 h1:mErMiUGclp47rCDbSUmBiY2L76EpT0uIYRZVBO6qg/k= -github.com/ollama/ollama v0.11.11/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms= +github.com/ollama/ollama v0.12.0 h1:BRry7G2Skz7Mu+E6rz40tzBXNbLTEhheGT8umc1zvxo= +github.com/ollama/ollama v0.12.0/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -244,7 +232,6 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= @@ -253,8 +240,6 @@ github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8 github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/tmc/langchaingo v0.1.13 h1:rcpMWBIi2y3B90XxfE4Ao8dhCQPVDMaNPnN5cGB1CaA= github.com/tmc/langchaingo v0.1.13/go.mod h1:vpQ5NOIhpzxDfTZK9B6tf2GM/MoaHewPWM5KXXGh7hg= -github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= -github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= @@ -275,9 +260,6 @@ go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= -golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= -golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -414,7 +396,6 @@ golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= @@ -603,9 +584,7 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= -rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo= diff --git a/internal/biz/router.go b/internal/biz/router.go index dd9c910..6e5fccb 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -10,7 +10,7 @@ import ( "ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg/mapstructure" "ai_scheduler/internal/pkg/utils_ollama" - "ai_scheduler/internal/tools" + tools "ai_scheduler/internal/tools" "ai_scheduler/tmpl/dataTemp" "context" "encoding/json" @@ -145,7 +145,7 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe } }() defer close(ch) - err = r.handleMatch(c, ch, &matchJson, task) + err = r.handleMatch(c, ch, &matchJson, task, sysInfo) if err != nil { return } @@ -169,7 +169,7 @@ func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.Respons return } -func (r *AiRouterBiz) handleMatch(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match, tasks []model.AiTask) (err error) { +func (r *AiRouterBiz) handleMatch(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match, tasks []model.AiTask, sysInfo model.AiSy) (err error) { if !matchJson.IsMatch { ch <- entitys.ResponseData{ @@ -195,6 +195,8 @@ func (r *AiRouterBiz) handleMatch(c *websocket.Conn, ch chan entitys.ResponseDat return r.handleApiTask(ch, c, matchJson, pointTask) case constants.TaskTypeFunc: return r.handleTask(ch, c, matchJson, pointTask) + case constants.TaskTypeKnowle: + return r.handleKnowle(ch, c, matchJson, pointTask, sysInfo) default: return r.handleOtherTask(c, ch, matchJson) } @@ -215,6 +217,90 @@ func (r *AiRouterBiz) handleTask(channel chan entitys.ResponseData, c *websocket return } +// 知识库 +func (r *AiRouterBiz) handleKnowle(channel chan entitys.ResponseData, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask, sysInfo model.AiSy) (err error) { + + var ( + configData entitys.ConfigDataTool + sessionIdKnowledge string + query string + host string + ) + err = json.Unmarshal([]byte(task.Config), &configData) + if err != nil { + return + } + + // 通过session 找到知识库session + session := c.Headers("X-Session", "") + if len(session) == 0 { + return errors.SessionNotFound + } + sessionInfo, has, err := r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(session)) + if err != nil { + return + } else if !has { + return errors.SessionNotFound + } + + // 找到知识库的host + { + tool, exists := r.toolManager.GetTool(configData.Tool) + if !exists { + return fmt.Errorf("tool not found: %s", configData.Tool) + } + + if knowledgeTool, ok := tool.(*tools.KnowledgeBaseTool); !ok { + return fmt.Errorf("未找到知识库Tool: %s", configData.Tool) + } else { + host = knowledgeTool.GetConfig().BaseURL + } + + } + + // 知识库的session为空,请求知识库获取, 并绑定 + if sessionInfo.KnowlegeSessionID == "" { + // 请求知识库 + if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, sysInfo.KnowlegeBaseID, sysInfo.KnowlegeTenantKey); err != nil { + return + } + + // 绑定知识库session,下次可以使用 + sessionInfo.KnowlegeSessionID = sessionIdKnowledge + if err = r.sessionImpl.Update(&sessionInfo, r.sessionImpl.WithSessionId(sessionInfo.SessionID)); err != nil { + return + } + } + + // 用户输入解析 + var ok bool + input := make(map[string]string) + if err = json.Unmarshal([]byte(matchJson.Parameters), &input); err != nil { + return + } + if query, ok = input["query"]; !ok { + return fmt.Errorf("query不能为空") + } + + knowledgeConfig := tools.KnowledgeBaseRequest{ + Session: sessionInfo.KnowlegeSessionID, + ApiKey: sysInfo.KnowlegeTenantKey, + Query: query, + } + b, err := json.Marshal(knowledgeConfig) + if err != nil { + return + } + + // 执行工具 + err = r.toolManager.ExecuteTool(channel, c, configData.Tool, b) + if err != nil { + return + } + + return +} + func (r *AiRouterBiz) handleApiTask(channels chan entitys.ResponseData, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { var ( request l_request.Request diff --git a/internal/entitys/types.go b/internal/entitys/types.go index 189f426..1d10a01 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -76,8 +76,9 @@ type ConfigDataHttp struct { } type ConfigDataTool struct { - Param map[string]interface{} `json:"param"` - Tool string `json:"tool"` + Param map[string]interface{} `json:"param"` + Request map[string]interface{} `json:"request"` + Tool string `json:"tool"` } // Message 消息 diff --git a/internal/pkg/l_request/README.md b/internal/pkg/l_request/README.md new file mode 100644 index 0000000..f895ffe --- /dev/null +++ b/internal/pkg/l_request/README.md @@ -0,0 +1,25 @@ +## 安装 + +```bash +$ go get -u gitea.cdlsxd.cn/rzy_tools/request/tags +``` + + +## 正常使用 +```go + req := request.Request{ + Method: "POST", + Url: reqUrl, + Json: RequestBody, + Headers: header, + } + resp, _ := req.Send() +``` + +## 同时大量请求或者在协程中使用建议使用 + +```go + r := RequestPools.Get() + defer RequestPools.ClearAndPut(r) + ... +``` \ No newline at end of file diff --git a/internal/pkg/l_request/pool.go b/internal/pkg/l_request/pool.go new file mode 100644 index 0000000..6aed978 --- /dev/null +++ b/internal/pkg/l_request/pool.go @@ -0,0 +1,44 @@ +package l_request + +import ( + "sync" +) + +type RequestPool struct { + pool sync.Pool +} + +var RequestPools = &RequestPool{ + pool: sync.Pool{ + New: func() interface{} { + return new(Request) + }, + }, +} + +func (re *RequestPool) Get() *Request { + return re.pool.Get().(*Request) +} + +func (re *RequestPool) Put(r *Request) { + re.pool.Put(r) +} + +// 重置对象 +func (re *RequestPool) Reset(r *Request) { + r.Method = "" + r.Url = "" + r.Params = nil + r.Headers = nil + r.Cookies = nil + r.Data = nil + r.Json = nil + r.Files = nil + r.Raw = "" + r.JsonByte = nil +} + +func (re *RequestPool) ClearAndPut(r *Request) { + re.Reset(r) + re.Put(r) +} diff --git a/internal/pkg/l_request/pool_test.go b/internal/pkg/l_request/pool_test.go new file mode 100644 index 0000000..596177f --- /dev/null +++ b/internal/pkg/l_request/pool_test.go @@ -0,0 +1,11 @@ +package l_request + +import "testing" + +func TestPool(t *testing.T) { + r := RequestPools.Get() + r.Url = "http://www.baidu.com" + RequestPools.ClearAndPut(r) + a := RequestPools.Get() + t.Log(a.Url) +} diff --git a/internal/pkg/l_request/request.go b/internal/pkg/l_request/request.go new file mode 100644 index 0000000..d47730c --- /dev/null +++ b/internal/pkg/l_request/request.go @@ -0,0 +1,169 @@ +package l_request + +import ( + "crypto/tls" + "encoding/json" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +// 请求结构体 +type Request struct { + Method string `json:"method"` // 请求方法 + Url string `json:"url"` // 请求url + Params map[string]string `json:"params"` // Query参数 + Headers map[string]string `json:"headers"` // 请求头 + Cookies map[string]string `json:"cookies"` // todo 处理 Cookies + Data map[string]string `json:"data"` // 表单格式请求数据 + Json map[string]interface{} `json:"json"` // JSON格式请求数据 todo 多层 嵌套 + Files map[string]string `json:"files"` // todo 处理 Files + Raw string `json:"raw"` // 原始请求数据 + JsonByte []byte `json:"json_raw"` // JSON格式请求数据 todo 多层 嵌套 + Xml []byte `json:"xml"` // xml +} + +// 响应结构体 +type Response struct { + StatusCode int `json:"status_code"` // 状态码 + Reason string `json:"reason"` // 状态码说明 + Elapsed float64 `json:"elapsed"` // 请求耗时(秒) + Content []byte `json:"content"` // 响应二进制内容 + Text string `json:"text"` // 响应文本 + Headers map[string]string `json:"headers"` // 响应头 + Cookies map[string]string `json:"cookies"` // todo 添加响应Cookies + Request *Request `json:"request"` // 原始请求 +} + +// 处理请求方法 +func (r *Request) getMethod() string { + return strings.ToUpper(r.Method) // 必须转为全部大写 +} + +// 组装URL +func (r *Request) getUrl() string { + if r.Params != nil { + urlValues := url.Values{} + Url, _ := url.Parse(r.Url) // todo 处理err + for key, value := range r.Params { + urlValues.Set(key, value) + } + Url.RawQuery = urlValues.Encode() + return Url.String() + } + return r.Url +} + +// 组装请求数据 +func (r *Request) getData() io.Reader { + var reqBody string + if r.Headers == nil { + r.Headers = make(map[string]string, 1) + } + if r.Raw != "" { + reqBody = r.Raw + } else if r.Data != nil { + urlValues := url.Values{} + for key, value := range r.Data { + urlValues.Add(key, value) + } + reqBody = urlValues.Encode() + if _, ex := r.Headers["Content-Type"]; !ex { + r.Headers["Content-Type"] = "application/x-www-form-urlencoded" + } + + } else if r.Json != nil { + bytesData, _ := json.Marshal(r.Json) + reqBody = string(bytesData) + if _, ex := r.Headers["Content-Type"]; !ex { + r.Headers["Content-Type"] = "application/json" + } + } else if r.JsonByte != nil { + reqBody = string(r.JsonByte) + if _, ex := r.Headers["Content-Type"]; !ex { + r.Headers["Content-Type"] = "application/json" + } + } else if r.Xml != nil { + reqBody = string(r.Xml) + if _, ex := r.Headers["Content-Type"]; !ex { + r.Headers["Content-Type"] = "application/xml" + } + } + return strings.NewReader(reqBody) +} + +// 添加请求头-需要在getData后使用 +func (r *Request) addHeaders(req *http.Request) { + if r.Headers != nil { + for key, value := range r.Headers { + req.Header.Add(key, value) + } + } +} + +// 准备请求 +func (r *Request) prepare() *http.Request { + Method := r.getMethod() + Url := r.getUrl() + Data := r.getData() + req, _ := http.NewRequest(Method, Url, Data) + r.addHeaders(req) + return req +} + +// 组装响应对象 +func (r *Request) packResponse(res *http.Response, elapsed float64) Response { + var resp Response + resBody, _ := io.ReadAll(res.Body) + resp.Content = resBody + resp.Text = string(resBody) + resp.StatusCode = res.StatusCode + resp.Reason = strings.Split(res.Status, " ")[1] + resp.Elapsed = elapsed + resp.Headers = map[string]string{} + for key, value := range res.Header { + resp.Headers[key] = strings.Join(value, ";") + } + return resp +} + +// 发送请求 +func (r *Request) Send() (Response, error) { + req := r.prepare() + client := &http.Client{} + start := time.Now() + res, err := client.Do(req) + if err != nil { + return Response{}, err + } + defer res.Body.Close() + elapsed := time.Since(start).Seconds() + return r.packResponse(res, elapsed), nil +} + +// 跳过证书发送请求 +func (r *Request) SendWithoutSsl() (Response, error) { + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + req := r.prepare() + client := &http.Client{Transport: tr} + start := time.Now() + res, err := client.Do(req) + if err != nil { + return Response{}, err + } + defer res.Body.Close() + elapsed := time.Since(start).Seconds() + return r.packResponse(res, elapsed), nil +} + +// 发送请求 +func (r *Request) SendNoParseResponse() (*http.Response, error) { + req := r.prepare() + client := &http.Client{} + res, err := client.Do(req) + return res, err +} diff --git a/internal/pkg/util/http.go b/internal/pkg/util/http.go new file mode 100644 index 0000000..5bddb26 --- /dev/null +++ b/internal/pkg/util/http.go @@ -0,0 +1,168 @@ +package util + +import ( + "ai_scheduler/internal/pkg/l_request" + "bufio" + "context" + "encoding/json" + "fmt" + + "net/http" + "strings" +) + +type KnowledgeBase struct { + session string + url string + apiKey string +} + +func NewKnowledgeBase(url, apiKey, session string) *KnowledgeBase { + return &KnowledgeBase{ + session: session, + url: url, + apiKey: apiKey, + } +} + +// 请求知识库聊天 +func (this *KnowledgeBase) Chat(ctx context.Context, query string) (text string, err error) { + + req := l_request.Request{ + Method: "post", + Url: this.url + "/api/v1/knowledge-chat/" + this.session, + Params: nil, + Headers: map[string]string{ + "Content-Type": "application/json", + "X-API-Key": this.apiKey, + }, + Cookies: nil, + Data: nil, + Json: map[string]interface{}{ + "query": query, + }, + Files: nil, + Raw: "", + JsonByte: nil, + Xml: nil, + } + + rsp, err := req.SendNoParseResponse() + if err != nil { + return + } + defer rsp.Body.Close() + + err = connectAndReadSSE(rsp) + if err != nil { + return + } + + return +} + +// Message 表示解析后的 SSE 消息 +type Message struct { + Event string // 事件类型(默认 "message") + Data string // 消息内容(可能多行) + ID string // 消息 ID(可选) +} + +// 连接 SSE 并读取数据 +func connectAndReadSSE(resp *http.Response) error { + + // 验证响应状态和格式 + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("非 200 状态码: %d", resp.StatusCode) + } + contentType := resp.Header.Get("Content-Type") + if !strings.Contains(contentType, "text/event-stream") { + return fmt.Errorf("不支持的 Content-Type: %s", contentType) + } + + // 逐行读取响应流 + scanner := bufio.NewScanner(resp.Body) + var currentMsg Message // 当前正在组装的消息 + + for scanner.Scan() { + line := scanner.Text() + if line == "" { + // 空行表示一条消息结束,处理当前消息 + if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" { + printMessage(currentMsg) + currentMsg = Message{} // 重置消息 + } + continue + } + + // 解析字段(格式:"field: value") + parts := strings.SplitN(line, ":", 2) + if len(parts) < 2 { + continue // 无效行(无冒号),跳过 + } + field := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + switch field { + case "event": + currentMsg.Event = value + case "data": + // data 可能多行,用换行符拼接(最后一条消息可能无结尾空行) + currentMsg.Data += value + "" + //case "id": + // currentMsg.ID = value + // 可选:处理 "retry" 字段(服务器建议的重连时间,单位秒) + } + } + + // 检查扫描错误(如连接断开) + if err := scanner.Err(); err != nil { + return fmt.Errorf("读取流失败: %w", err) + } + + // 处理最后一条未结束的消息(无结尾空行) + if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" { + printMessage(currentMsg) + } + + return nil +} + +type MegContent struct { + Id string `json:"id"` // 消息 ID + ResponseType string `json:"response_type"` // 响应类型,answer 或 references + Content string `json:"content"` // 消息内容 + Done bool `json:"done"` // 是否完成 + KnowledgeReferences interface{} `json:"knowledge_references"` +} + +// printMessage 打印解析后的 SSE 消息 +func printMessage(msg Message) { + + //fmt.Printf("--- 收到 SSE 消息 ---") + //fmt.Printf("事件类型: %s,", msg.Event) + //fmt.Printf("消息 ID: %s,", msg.ID) + //fmt.Printf("内容:%s,", strings.TrimSpace(msg.Data)) // 去除末尾多余换行 + + var content MegContent + _ = json.Unmarshal([]byte(msg.Data), &content) + fmt.Println(msg.Data) + + //if content.ResponseType == "answer" { + // //fmt.Printf("%s", content.Content) + // fmt.Println(content) + //} else { + // fmt.Printf("--- 收到 SSE 消息 ---") + // fmt.Printf("事件类型: %s,", msg.Event) + // fmt.Printf("消息 ID: %s,", msg.ID) + // fmt.Printf("内容:%s,", strings.TrimSpace(msg.Data)) // 去除末尾多余换行 + //} + +} + +// getRetryAfter 从响应头获取重连时间(示例,需根据实际响应头调整) +func getRetryAfter(url string) int { + // 实际需重新请求并获取响应头(此处简化为固定值) + // 正确做法:在 connectAndReadSSE 中记录响应头的 Retry-After 字段 + return 5 // 示例:等待 5 秒 +} diff --git a/internal/tools/konwledge_base.go b/internal/tools/konwledge_base.go new file mode 100644 index 0000000..bf82b28 --- /dev/null +++ b/internal/tools/konwledge_base.go @@ -0,0 +1,248 @@ +package tools + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/l_request" + "bufio" + "encoding/json" + "fmt" + "github.com/gofiber/websocket/v2" + "net/http" + "strings" +) + +// 知识库工具 +type KnowledgeBaseTool struct { + config config.ToolConfig +} + +// NewKnowledgeBaseTool 创建知识库工具 +func NewKnowledgeBaseTool(config config.ToolConfig) *KnowledgeBaseTool { + return &KnowledgeBaseTool{config: config} +} + +func (k *KnowledgeBaseTool) GetConfig() config.ToolConfig { + return k.config +} + +// Name 返回工具名称 +func (k *KnowledgeBaseTool) Name() string { + return "knowledgeBase" +} + +// Description 返回工具描述 +func (k *KnowledgeBaseTool) Description() string { + return "请求知识库" +} + +// Definition 返回工具定义 +func (k *KnowledgeBaseTool) Definition() entitys.ToolDefinition { + return entitys.ToolDefinition{ + Type: "function", + Function: entitys.FunctionDef{ + Name: k.Name(), + Description: k.Description(), + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "query": map[string]interface{}{ + "type": "string", + "description": "知识库查询条件", + }, + }, + "required": []string{"query"}, + }, + }, + } +} + +// Execute 执行知识库查询 +func (k *KnowledgeBaseTool) Execute(channel chan entitys.ResponseData, c *websocket.Conn, args json.RawMessage) error { + + var params KnowledgeBaseRequest + if err := json.Unmarshal(args, ¶ms); err != nil { + return fmt.Errorf("unmarshal args failed: %w", err) + } + + return k.chat(channel, c, params) + +} + +type KnowledgeBaseRequest struct { + Session string // 知识库会话id + ApiKey string // 知识库apiKey + Query string // 用户输入 +} + +// Message 表示解析后的 SSE 消息 +type Message struct { + Event string // 事件类型(默认 "message") + Data string // 消息内容(可能多行) + ID string // 消息 ID(可选) +} + +type MegContent struct { + Id string `json:"id"` + ResponseType string `json:"response_type"` + Content string `json:"content"` + Done bool `json:"done"` + KnowledgeReferences interface{} `json:"knowledge_references"` +} + +// 请求知识库聊天 +func (this *KnowledgeBaseTool) chat(channel chan entitys.ResponseData, c *websocket.Conn, param KnowledgeBaseRequest) (err error) { + + req := l_request.Request{ + Method: "post", + Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + param.Session, + Params: nil, + Headers: map[string]string{ + "Content-Type": "application/json", + "X-API-Key": param.ApiKey, + }, + Cookies: nil, + Data: nil, + Json: map[string]interface{}{ + "query": param.Query, + }, + Files: nil, + Raw: "", + JsonByte: nil, + Xml: nil, + } + + rsp, err := req.SendNoParseResponse() + if err != nil { + return + } + defer rsp.Body.Close() + + err = connectAndReadSSE(rsp, channel) + if err != nil { + return + } + + return +} + +// 连接 SSE 并读取数据 +func connectAndReadSSE(resp *http.Response, channel chan entitys.ResponseData) error { + + // 验证响应状态和格式 + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("非 200 状态码: %d", resp.StatusCode) + } + contentType := resp.Header.Get("Content-Type") + if !strings.Contains(contentType, "text/event-stream") { + return fmt.Errorf("不支持的 Content-Type: %s", contentType) + } + + // 逐行读取响应流 + scanner := bufio.NewScanner(resp.Body) + var currentMsg Message // 当前正在组装的消息 + + for scanner.Scan() { + line := scanner.Text() + if line == "" { + // 空行表示一条消息结束,处理当前消息 + if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" { + channel <- entitys.ResponseData{ + Done: false, + Content: currentMsg.Data, + Type: entitys.ResponseJson, + } + currentMsg = Message{} // 重置消息 + } + continue + } + + // 解析字段(格式:"field: value") + parts := strings.SplitN(line, ":", 2) + if len(parts) < 2 { + continue // 无效行(无冒号),跳过 + } + field := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + switch field { + case "event": + currentMsg.Event = value + case "data": + // data 可能多行,用换行符拼接(最后一条消息可能无结尾空行) + currentMsg.Data += value + "" + default: + // 忽略未知字段 + } + } + + // 检查扫描错误(如连接断开) + if err := scanner.Err(); err != nil { + return fmt.Errorf("读取流失败: %w", err) + } + + // 处理最后一条未结束的消息(无结尾空行) + if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" { + channel <- entitys.ResponseData{ + Done: false, + Content: currentMsg.Data, + Type: entitys.ResponseJson, + } + } + + return nil +} + +// 获取知识库 session +func GetKnowledgeBaseSession(host, baseId, apiKey string) (string, error) { + req := l_request.Request{ + Method: "post", + Url: host + "/api/v1/sessions", + Params: nil, + Headers: map[string]string{ + "Content-Type": "application/json", + "X-API-Key": apiKey, + }, + Cookies: nil, + Data: nil, + Json: map[string]interface{}{ + "knowledge_base_id": baseId, + }, + Files: nil, + Raw: "", + JsonByte: nil, + Xml: nil, + } + + rsp, err := req.Send() + if err != nil { + return "", err + } + + var result sessionRsp + err = json.Unmarshal(rsp.Content, &result) + + return result.Data.Id, err +} + +type sessionRsp struct { + Data struct { + Id string `json:"id"` + Title string `json:"title"` + Description string `json:"description"` + TenantId int `json:"tenant_id"` + KnowledgeBaseId string `json:"knowledge_base_id"` + MaxRounds int `json:"max_rounds"` + EnableRewrite bool `json:"enable_rewrite"` + FallbackStrategy string `json:"fallback_strategy"` + FallbackResponse string `json:"fallback_response"` + EmbeddingTopK int `json:"embedding_top_k"` + KeywordThreshold float64 `json:"keyword_threshold"` + VectorThreshold float64 `json:"vector_threshold"` + RerankModelId string `json:"rerank_model_id"` + RerankTopK int `json:"rerank_top_k"` + RerankThreshold float64 `json:"rerank_threshold"` + SummaryModelId string `json:"summary_model_id"` + } `json:"data"` + Success bool `json:"success"` +} diff --git a/internal/tools/konwledge_base_test.go b/internal/tools/konwledge_base_test.go new file mode 100644 index 0000000..4ea9bed --- /dev/null +++ b/internal/tools/konwledge_base_test.go @@ -0,0 +1,34 @@ +package tools + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" + "testing" +) + +func TestKnowledgeBaseTool_Execute(t *testing.T) { + + kb := NewKnowledgeBaseTool(config.ToolConfig{}) + channel := make(chan entitys.ResponseData) + err := kb.Execute(channel, nil, nil) + if err != nil { + t.Errorf("Execute() error = %v", err) + } + +} + +// session +func TestKnowledgeBaseTool_Submit(t *testing.T) { + + apiKey := "sk-EfnUANKMj3DUOiEPJZ5xS8SGMsbO6be_qYAg9uZ8T3zyoFM-" + baseId := "kb-00000001" + host := "http://117.175.169.61:10000" + + sessionId, err := GetKnowledgeBaseSession(host, baseId, apiKey) + if err != nil { + t.Errorf("GetKnowledgeBaseSession() error = %v", err) + } + + t.Log("sessionId:", sessionId) + +} diff --git a/internal/tools/manager.go b/internal/tools/manager.go index 4b854ca..1aa0281 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -65,6 +65,12 @@ func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager { zltxOrderStatisticsTool := NewZltxOrderStatisticsTool(config.Tools.ZltxOrderStatistics) m.tools[zltxOrderStatisticsTool.Name()] = zltxOrderStatisticsTool } + + // 注册知识库工具 + if config.Tools.Knowledge.Enabled { + knowledgeTool := NewKnowledgeBaseTool(config.Tools.Knowledge) + m.tools[knowledgeTool.Name()] = knowledgeTool + } return m } From 5523a8f78afc1713366c5023576dc90c03f8c15e Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Tue, 23 Sep 2025 21:17:44 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E7=BB=93=E6=9E=84=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config.yaml | 3 +- internal/biz/router.go | 172 +++++++++++++++++++++------- internal/config/config.go | 1 + internal/pkg/channel_pool.go | 15 ++- internal/pkg/func.go | 17 ++- internal/pkg/utils_ollama/client.go | 2 +- internal/pkg/utils_ollama/ollama.go | 116 ++++++++++++++++--- 7 files changed, 258 insertions(+), 68 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index c85ac02..c6945d6 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -9,12 +9,13 @@ ollama: timeout: "120s" level: "info" format: "json" + sys: session_len: 3 channel_pool_len: 100 channel_pool_size: 32 - + llm_pool_len: 5 redis: host: 47.97.27.195:6379 type: node diff --git a/internal/biz/router.go b/internal/biz/router.go index 6e5fccb..1ed0ed8 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -16,10 +16,12 @@ import ( "encoding/json" "fmt" "strings" + "time" "gitea.cdlsxd.cn/self-tools/l_request" "github.com/gofiber/fiber/v2/log" "github.com/gofiber/websocket/v2" + "github.com/ollama/ollama/api" "github.com/tmc/langchaingo/llms" "xorm.io/builder" ) @@ -34,6 +36,7 @@ type AiRouterBiz struct { hisImpl *impl.ChatImpl conf *config.Config utilAgent *utils_ollama.UtilOllama + ollama *utils_ollama.Client channelPool *pkg.SafeChannelPool } @@ -48,6 +51,7 @@ func NewAiRouterBiz( conf *config.Config, utilAgent *utils_ollama.UtilOllama, channelPool *pkg.SafeChannelPool, + ollama *utils_ollama.Client, ) *AiRouterBiz { return &AiRouterBiz{ //aiClient: aiClient, @@ -59,6 +63,7 @@ func NewAiRouterBiz( taskImpl: taskImpl, utilAgent: utilAgent, channelPool: channelPool, + ollama: ollama, } } @@ -68,30 +73,52 @@ func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*ent return nil, nil } -// Route 执行智能路由 func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) { - ch, err := r.channelPool.Get() - if err != nil { - return err - } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + //ch := r.channelPool.Get() + ch := make(chan entitys.ResponseData) + done := make(chan struct{}) + + go func() { + defer close(done) + for { + select { + case v, ok := <-ch: + if !ok { + return + } + // 带超时的发送,避免阻塞 + if err := sendWithTimeout(c, v, 2*time.Second); err != nil { + log.Errorf("Send error: %v", err) + cancel() // 通知主流程退出 + return + } + case <-ctx.Done(): + return + } + } + }() + defer func() { + if err != nil { - entitys.MsgSend(c, entitys.ResponseData{ + _ = entitys.MsgSend(c, entitys.ResponseData{ Done: false, Content: err.Error(), Type: entitys.ResponseErr, }) } - entitys.MsgSend(c, entitys.ResponseData{ + _ = entitys.MsgSend(c, entitys.ResponseData{ Done: true, Content: "", Type: entitys.ResponseEnd, }) - err = r.channelPool.Put(ch) - if err != nil { - close(ch) - } + //r.channelPool.Put(ch) + close(ch) }() + session := c.Headers("X-Session", "") if len(session) == 0 { return errors.SessionNotFound @@ -119,46 +146,94 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe if err != nil { return errors.SystemError } - //意图预测 + prompt := r.getPromptLLM(sysInfo, history, req.Text, task) - match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt, - llms.WithJSONMode(), - ) - if err != nil { - return errors.SystemError - } - log.Info(match.Choices[0].Content) + + AgentClient := r.utilAgent.Get() + ch <- entitys.ResponseData{ Done: false, - Content: match.Choices[0].Content, + Content: "准备意图识别", Type: entitys.ResponseLog, } - var matchJson entitys.Match - err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson) + + match, err := AgentClient.Llm.GenerateContent( + ctx, // 使用可取消的上下文 + prompt, + llms.WithJSONMode(), + ) + resMsg := match.Choices[0].Content + + r.utilAgent.Put(AgentClient) + ch <- entitys.ResponseData{ + Done: false, + Content: resMsg, + Type: entitys.ResponseLog, + } + ch <- entitys.ResponseData{ + Done: false, + Content: "意图识别结束", + Type: entitys.ResponseLog, + } + //for i := 1; i < 10; i++ { + // ch <- entitys.ResponseData{ + // Done: false, + // Content: fmt.Sprintf("%d", i), + // Type: entitys.ResponseLog, + // } + // time.Sleep(1 * time.Second) + //} + //return + if err != nil { + log.Errorf("LLM error: %v", err) return errors.SystemError } - go func() { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("recovered from panic: %v", r) - } - }() - defer close(ch) - err = r.handleMatch(c, ch, &matchJson, task, sysInfo) - if err != nil { - return - } - }() - for v := range ch { - if err := entitys.MsgSend(c, v); err != nil { - return err - } + //msg, err := r.ollama.ToolSelect(ctx, r.getPromptOllama(sysInfo, history, req.Text), []api.Tool{}) + //if err != nil { + // return + //} + //resMsg := msg.Message.Content + select { + case ch <- entitys.ResponseData{ + Done: false, + Content: resMsg, + Type: entitys.ResponseLog, + }: + case <-ctx.Done(): + return ctx.Err() } - return + + var matchJson entitys.Match + if err := json.Unmarshal([]byte(resMsg), &matchJson); err != nil { + return errors.SystemError + } + + if err := r.handleMatch(ctx, c, ch, &matchJson, task, sysInfo); err != nil { + return err + } + + return nil } +// 辅助函数:带超时的 WebSocket 发送 +func sendWithTimeout(c *websocket.Conn, data entitys.ResponseData, timeout time.Duration) error { + sendCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- entitys.MsgSend(c, data) + }() + + select { + case err := <-done: + return err + case <-sendCtx.Done(): + return sendCtx.Err() + } +} func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match) (err error) { ch <- entitys.ResponseData{ Done: false, @@ -169,7 +244,7 @@ func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.Respons return } -func (r *AiRouterBiz) handleMatch(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match, tasks []model.AiTask, sysInfo model.AiSy) (err error) { +func (r *AiRouterBiz) handleMatch(ctx context.Context, c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match, tasks []model.AiTask, sysInfo model.AiSy) (err error) { if !matchJson.IsMatch { ch <- entitys.ResponseData{ @@ -404,6 +479,23 @@ func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, re return prompt } +func (r *AiRouterBiz) getPromptOllama(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []api.Message { + var ( + prompt = make([]api.Message, 0) + ) + prompt = append(prompt, api.Message{ + Role: "system", + Content: r.buildSystemPrompt(sysInfo.SysPrompt), + }, api.Message{ + Role: "assistant", + Content: pkg.JsonStringIgonErr(r.buildAssistant(history)), + }, api.Message{ + Role: "user", + Content: reqInput, + }) + return prompt +} + func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { var ( prompt = make([]llms.MessageContent, 0) diff --git a/internal/config/config.go b/internal/config/config.go index 6a0bcc4..2e80b18 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -28,6 +28,7 @@ type SysConfig struct { SessionLen int `mapstructure:"session_len"` ChannelPoolLen int `mapstructure:"channel_pool_len"` ChannelPoolSize int `mapstructure:"channel_pool_size"` + LlmPoolLen int `mapstructure:"llm_pool_len"` } // ServerConfig 服务器配置 diff --git a/internal/pkg/channel_pool.go b/internal/pkg/channel_pool.go index 5fa2137..eda85fa 100644 --- a/internal/pkg/channel_pool.go +++ b/internal/pkg/channel_pool.go @@ -3,7 +3,6 @@ package pkg import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" - "errors" "sync" ) @@ -25,29 +24,29 @@ func NewSafeChannelPool(c *config.Config) (*SafeChannelPool, func()) { } // 从池中获取 channel(若无空闲则创建新 channel) -func (p *SafeChannelPool) Get() (chan entitys.ResponseData, error) { +func (p *SafeChannelPool) Get() chan entitys.ResponseData { p.mu.Lock() defer p.mu.Unlock() if p.closed { - return nil, errors.New("pool is closed") + return make(chan entitys.ResponseData, p.bufSize) } select { case ch := <-p.pool: // 从池中取 - return ch, nil + return ch default: // 池为空,创建新 channel - return make(chan entitys.ResponseData, p.bufSize), nil + return make(chan entitys.ResponseData, p.bufSize) } } // 将 channel 放回池中(必须确保 channel 已清空!) -func (p *SafeChannelPool) Put(ch chan entitys.ResponseData) error { +func (p *SafeChannelPool) Put(ch chan entitys.ResponseData) { p.mu.Lock() defer p.mu.Unlock() if p.closed { - return errors.New("pool is closed") + return } // 清空 channel(防止复用时读取旧数据) @@ -62,7 +61,7 @@ func (p *SafeChannelPool) Put(ch chan entitys.ResponseData) error { default: // 池已满,直接关闭 channel(避免泄漏) close(ch) } - return nil + return } // 关闭池(释放所有资源) diff --git a/internal/pkg/func.go b/internal/pkg/func.go index a31a818..20d9d7c 100644 --- a/internal/pkg/func.go +++ b/internal/pkg/func.go @@ -1,8 +1,23 @@ package pkg -import "encoding/json" +import ( + "ai_scheduler/internal/entitys" + "encoding/json" +) func JsonStringIgonErr(data interface{}) string { dataByte, _ := json.Marshal(data) return string(dataByte) } + +// IsChannelClosed 检查给定的 channel 是否已经关闭 +// 参数 ch: 要检查的 channel,类型为 chan entitys.ResponseData +// 返回值: bool 类型,true 表示 channel 已关闭,false 表示未关闭 +func IsChannelClosed(ch chan entitys.ResponseData) bool { + select { + case _, ok := <-ch: // 尝试从 channel 中读取数据 + return !ok // 如果 ok=false,说明 channel 已关闭 + default: // 如果 channel 暂时无数据可读(但不一定关闭) + return false // channel 未关闭(但可能有数据未读取) + } +} diff --git a/internal/pkg/utils_ollama/client.go b/internal/pkg/utils_ollama/client.go index 390caa8..9d9ac82 100644 --- a/internal/pkg/utils_ollama/client.go +++ b/internal/pkg/utils_ollama/client.go @@ -46,7 +46,7 @@ func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools [ Messages: messages, Stream: new(bool), // 设置为false,不使用流式响应 Think: &api.ThinkValue{Value: true}, - Tools: tools, + //Tools: tools, } err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error { diff --git a/internal/pkg/utils_ollama/ollama.go b/internal/pkg/utils_ollama/ollama.go index 65388a3..6121838 100644 --- a/internal/pkg/utils_ollama/ollama.go +++ b/internal/pkg/utils_ollama/ollama.go @@ -2,40 +2,112 @@ package utils_ollama import ( "ai_scheduler/internal/config" + "math/rand" + "net" "net/http" "os" + "sync" + "time" "github.com/gofiber/fiber/v2/log" "github.com/tmc/langchaingo/llms/ollama" ) type UtilOllama struct { - Llm *ollama.LLM + LlmClientPool *sync.Pool + poolSize int // 记录池大小,用于调试 + model string + serverURL string + c *config.Config +} + +type LlmObj struct { + Number string + Llm *ollama.LLM } func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama { - llm, err := ollama.New( - ollama.WithModel(c.Ollama.Model), - ollama.WithHTTPClient(http.DefaultClient), - ollama.WithServerURL(getUrl(c)), - ollama.WithKeepAlive("-1s"), - ) - if err != nil { - logger.Fatal(err) - panic(err) + poolSize := c.Sys.LlmPoolLen + if poolSize <= 0 { + poolSize = 10 // 默认值 + logger.Warnf("LlmPoolLen not set, using default: %d", poolSize) + } + + // 初始化 Pool + pool := &sync.Pool{ + New: func() interface{} { + llm, err := ollama.New( + ollama.WithModel(c.Ollama.Model), + ollama.WithHTTPClient(http.DefaultClient), + ollama.WithServerURL(getUrl(c)), + ollama.WithKeepAlive("-1s"), + ) + if err != nil { + logger.Fatalf("Failed to create Ollama client: %v", err) + panic(err) // 或者返回 nil + 错误处理 + } + number := randStr(5) + log.Info(number) + return &LlmObj{ + Number: number, + Llm: llm, + } + }, + } + + // 预填充 Pool + for i := 0; i < poolSize; i++ { + pool.Put(pool.New()) } return &UtilOllama{ - Llm: llm, + LlmClientPool: pool, + poolSize: poolSize, + model: c.Ollama.Model, + serverURL: getUrl(c), } + } -//func (o *UtilOllama) a() { -// var agent agents.Agent -// agent = agents.NewOneShotAgent(llm, tools, opts...) -// -// agents.NewExecutor() -//} +func (o *UtilOllama) NewClient() *ollama.LLM { + llm, _ := ollama.New( + ollama.WithModel(o.c.Ollama.Model), + ollama.WithHTTPClient(&http.Client{ + Transport: &http.Transport{ + MaxIdleConns: 100, // 最大空闲连接数(默认 2,太小) + MaxIdleConnsPerHost: 100, // 每个 Host 的最大空闲连接数(默认 2) + IdleConnTimeout: 90 * time.Second, // 空闲连接超时时间 + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, // 连接超时 + KeepAlive: 30 * time.Second, // TCP Keep-Alive + }).DialContext, + }, + Timeout: 60 * time.Second, // 整体请求超时(避免无限等待) + }), + ollama.WithServerURL(getUrl(o.c)), + ollama.WithKeepAlive("-1s"), + ) + return llm +} + +// Get 返回一个可用的 LLM 客户端 +func (o *UtilOllama) Get() *LlmObj { + client := o.LlmClientPool.Get().(*LlmObj) + return client +} + +// Put 归还客户端(可选:检查是否仍可用) +func (o *UtilOllama) Put(llm *LlmObj) { + if llm == nil { + return + } + o.LlmClientPool.Put(llm) +} + +// Stats 返回池的统计信息(用于监控) +func (o *UtilOllama) Stats() (current, max int) { + return o.poolSize, o.poolSize +} func getUrl(c *config.Config) string { baseURL := c.Ollama.BaseURL @@ -45,3 +117,13 @@ func getUrl(c *config.Config) string { } return baseURL } + +var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + +func randStr(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + return string(b) +} From 5deb04576786556328dc85ecd4f163932eff37e7 Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Wed, 24 Sep 2025 09:43:26 +0800 Subject: [PATCH 3/3] =?UTF-8?q?feat:=20=E6=8E=A5=E5=85=A5=E7=9F=A5?= =?UTF-8?q?=E8=AF=86=E5=BA=93=E8=A7=A3=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tools/konwledge_base.go | 34 +++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/internal/tools/konwledge_base.go b/internal/tools/konwledge_base.go index bf82b28..b614c11 100644 --- a/internal/tools/konwledge_base.go +++ b/internal/tools/konwledge_base.go @@ -7,6 +7,7 @@ import ( "bufio" "encoding/json" "fmt" + "github.com/gofiber/fiber/v2/log" "github.com/gofiber/websocket/v2" "net/http" "strings" @@ -64,6 +65,7 @@ func (k *KnowledgeBaseTool) Execute(channel chan entitys.ResponseData, c *websoc if err := json.Unmarshal(args, ¶ms); err != nil { return fmt.Errorf("unmarshal args failed: %w", err) } + log.Info("开始执行知识库 KnowledgeBaseTool Execute, params: %v", params) return k.chat(channel, c, params) @@ -82,7 +84,7 @@ type Message struct { ID string // 消息 ID(可选) } -type MegContent struct { +type MsgContent struct { Id string `json:"id"` ResponseType string `json:"response_type"` Content string `json:"content"` @@ -90,6 +92,22 @@ type MegContent struct { KnowledgeReferences interface{} `json:"knowledge_references"` } +// 解析知识库响应内容,并把通过channel结果返回 +func msgContentParse(input string, channel chan entitys.ResponseData) (msgContent MsgContent, err error) { + err = json.Unmarshal([]byte(input), &msgContent) + if err != nil { + err = fmt.Errorf("unmarshal input failed: %w", err) + } + + channel <- entitys.ResponseData{ + Done: msgContent.Done, + Content: msgContent.Content, + Type: entitys.ResponseStream, + } + + return +} + // 请求知识库聊天 func (this *KnowledgeBaseTool) chat(channel chan entitys.ResponseData, c *websocket.Conn, param KnowledgeBaseRequest) (err error) { @@ -147,10 +165,9 @@ func connectAndReadSSE(resp *http.Response, channel chan entitys.ResponseData) e if line == "" { // 空行表示一条消息结束,处理当前消息 if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" { - channel <- entitys.ResponseData{ - Done: false, - Content: currentMsg.Data, - Type: entitys.ResponseJson, + _, err := msgContentParse(currentMsg.Data, channel) + if err != nil { + return fmt.Errorf("msgContentParse failed: %w", err) } currentMsg = Message{} // 重置消息 } @@ -183,10 +200,9 @@ func connectAndReadSSE(resp *http.Response, channel chan entitys.ResponseData) e // 处理最后一条未结束的消息(无结尾空行) if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" { - channel <- entitys.ResponseData{ - Done: false, - Content: currentMsg.Data, - Type: entitys.ResponseJson, + _, err := msgContentParse(currentMsg.Data, channel) + if err != nil { + return fmt.Errorf("msgContentParse failed: %w", err) } }