diff --git a/go.mod b/go.mod index 48c2db6..7768c99 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ toolchain go1.24.7 require ( gitea.cdlsxd.cn/self-tools/l_request v1.0.8 + github.com/bytedance/sonic v1.11.6 github.com/emirpasic/gods v1.18.1 github.com/go-kratos/kratos/v2 v2.9.1 github.com/gofiber/fiber/v2 v2.52.9 @@ -26,7 +27,10 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/andybalholm/brotli v1.1.0 // indirect + github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cloudwego/base64x v0.1.4 // indirect + github.com/cloudwego/iasm v0.2.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.4 // indirect github.com/fasthttp/websocket v1.5.3 // indirect @@ -37,6 +41,7 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/klauspost/compress v1.17.9 // indirect + github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -54,11 +59,13 @@ require ( github.com/spf13/pflag v1.0.5 // indirect github.com/stretchr/testify v1.11.1 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect + golang.org/x/arch v0.8.0 // indirect golang.org/x/crypto v0.36.0 // indirect golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect golang.org/x/sys v0.31.0 // indirect diff --git a/go.sum b/go.sum index 2542880..916bf1c 100644 --- a/go.sum +++ b/go.sum @@ -50,6 +50,10 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= +github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= +github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= +github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -57,6 +61,10 @@ github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWR github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= +github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= +github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= @@ -168,6 +176,10 @@ github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/X github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= +github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -232,6 +244,7 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= @@ -240,6 +253,8 @@ github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8 github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/tmc/langchaingo v0.1.13 h1:rcpMWBIi2y3B90XxfE4Ao8dhCQPVDMaNPnN5cGB1CaA= github.com/tmc/langchaingo v0.1.13/go.mod h1:vpQ5NOIhpzxDfTZK9B6tf2GM/MoaHewPWM5KXXGh7hg= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= @@ -260,6 +275,9 @@ go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= +golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -396,6 +414,7 @@ golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= @@ -584,7 +603,9 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo= diff --git a/internal/biz/router.go b/internal/biz/router.go index d754745..9376b4b 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -76,10 +76,21 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe } defer func() { if err != nil { - _ = c.WriteMessage(websocket.TextMessage, []byte(err.Error())) + entitys.MsgSend(c, entitys.ResponseData{ + Done: true, + Content: []byte(err.Error()), + Type: entitys.ResponseErr, + }) + } + entitys.MsgSend(c, entitys.ResponseData{ + Done: true, + Content: []byte(err.Error()), + Type: entitys.ResponseEnd, + }) + err = r.channelPool.Put(ch) + if err != nil { + close(ch) } - _ = c.WriteMessage(websocket.TextMessage, []byte("EOF")) - r.channelPool.Put(ch) }() session := c.Headers("X-Session", "") if len(session) == 0 { @@ -117,7 +128,11 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe return errors.SystemError } log.Info(match.Choices[0].Content) - _ = c.WriteMessage(websocket.TextMessage, []byte(match.Choices[0].Content)) + ch <- entitys.ResponseData{ + Done: false, + Content: []byte(match.Choices[0].Content), + Type: entitys.ResponseLog, + } var matchJson entitys.Match err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson) if err != nil { @@ -137,24 +152,31 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe }() for v := range ch { - if err := c.WriteMessage(websocket.TextMessage, v); err != nil { + if err := entitys.MsgSend(c, v); err != nil { return err } } - _ = c.WriteMessage(websocket.TextMessage, []byte("结束")) return } -func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan []byte, matchJson *entitys.Match) (err error) { - ch <- []byte(matchJson.Reasoning) +func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match) (err error) { + ch <- entitys.ResponseData{ + Done: false, + Content: []byte(matchJson.Reasoning), + Type: entitys.ResponseText, + } return } -func (r *AiRouterBiz) handleMatch(c *websocket.Conn, ch chan []byte, matchJson *entitys.Match, tasks []model.AiTask) (err error) { +func (r *AiRouterBiz) handleMatch(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match, tasks []model.AiTask) (err error) { if !matchJson.IsMatch { - ch <- []byte(matchJson.Reasoning) + ch <- entitys.ResponseData{ + Done: false, + Content: []byte(matchJson.Reasoning), + Type: entitys.ResponseText, + } return } var pointTask *model.AiTask @@ -178,7 +200,7 @@ func (r *AiRouterBiz) handleMatch(c *websocket.Conn, ch chan []byte, matchJson * } } -func (r *AiRouterBiz) handleTask(channel chan []byte, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { +func (r *AiRouterBiz) handleTask(channel chan entitys.ResponseData, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { var configData entitys.ConfigDataTool err = json.Unmarshal([]byte(task.Config), &configData) @@ -193,7 +215,7 @@ func (r *AiRouterBiz) handleTask(channel chan []byte, c *websocket.Conn, matchJs return } -func (r *AiRouterBiz) handleApiTask(channels chan []byte, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { +func (r *AiRouterBiz) handleApiTask(channels chan entitys.ResponseData, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { var ( request l_request.Request auth = c.Headers("X-Authorization", "") diff --git a/internal/entitys/response.go b/internal/entitys/response.go index f8bcb00..8d98d83 100644 --- a/internal/entitys/response.go +++ b/internal/entitys/response.go @@ -1,5 +1,11 @@ package entitys +import ( + "encoding/json" + + "github.com/gofiber/websocket/v2" +) + type Response string const ( @@ -10,4 +16,35 @@ const ( ResponseText Response = "txt" ResponseImg Response = "img" ResponseFile Response = "file" + ResponseErr Response = "error" + ResponseLog Response = "log" ) + +type ResponseData struct { + Done bool + Content []byte + Type Response +} + +func MsgSet(msgType Response, msg []byte, done bool) []byte { + jsonByte, err := json.Marshal(ResponseData{ + Done: done, + Content: msg, + Type: msgType, + }) + if err != nil { + return nil + } + return jsonByte +} + +func MsgSend(c *websocket.Conn, msg ResponseData) error { + jsonByte, _ := json.Marshal(msg) + + return c.WriteMessage(websocket.TextMessage, jsonByte) +} + +func MsgSendByte(c *websocket.Conn, msg []byte) { + + _ = c.WriteMessage(websocket.TextMessage, msg) +} diff --git a/internal/entitys/types.go b/internal/entitys/types.go index 820bedc..189f426 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -67,7 +67,7 @@ type Tool interface { Name() string Description() string Definition() ToolDefinition - Execute(channel chan []byte, c *websocket.Conn, args json.RawMessage) error + Execute(channel chan ResponseData, c *websocket.Conn, args json.RawMessage) error } type ConfigDataHttp struct { diff --git a/internal/pkg/channel_pool.go b/internal/pkg/channel_pool.go index 70502b2..5fa2137 100644 --- a/internal/pkg/channel_pool.go +++ b/internal/pkg/channel_pool.go @@ -2,20 +2,21 @@ package pkg import ( "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" "errors" "sync" ) type SafeChannelPool struct { - pool chan chan []byte // 存储空闲 channel 的队列 - bufSize int // channel 缓冲大小 + pool chan chan entitys.ResponseData // 存储空闲 channel 的队列 + bufSize int // channel 缓冲大小 mu sync.Mutex closed bool } func NewSafeChannelPool(c *config.Config) (*SafeChannelPool, func()) { pool := &SafeChannelPool{ - pool: make(chan chan []byte, c.Sys.ChannelPoolLen), + pool: make(chan chan entitys.ResponseData, c.Sys.ChannelPoolLen), bufSize: c.Sys.ChannelPoolSize, } @@ -24,7 +25,7 @@ func NewSafeChannelPool(c *config.Config) (*SafeChannelPool, func()) { } // 从池中获取 channel(若无空闲则创建新 channel) -func (p *SafeChannelPool) Get() (chan []byte, error) { +func (p *SafeChannelPool) Get() (chan entitys.ResponseData, error) { p.mu.Lock() defer p.mu.Unlock() @@ -36,12 +37,12 @@ func (p *SafeChannelPool) Get() (chan []byte, error) { case ch := <-p.pool: // 从池中取 return ch, nil default: // 池为空,创建新 channel - return make(chan []byte, p.bufSize), nil + return make(chan entitys.ResponseData, p.bufSize), nil } } // 将 channel 放回池中(必须确保 channel 已清空!) -func (p *SafeChannelPool) Put(ch chan []byte) error { +func (p *SafeChannelPool) Put(ch chan entitys.ResponseData) error { p.mu.Lock() defer p.mu.Unlock() diff --git a/internal/pkg/utils_ollama/client.go b/internal/pkg/utils_ollama/client.go index bae73cc..730a033 100644 --- a/internal/pkg/utils_ollama/client.go +++ b/internal/pkg/utils_ollama/client.go @@ -60,7 +60,7 @@ func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools [ return } -func (c *Client) ChatStream(ctx context.Context, ch chan []byte, messages []api.Message) (err error) { +func (c *Client) ChatStream(ctx context.Context, ch chan entitys.ResponseData, messages []api.Message) (err error) { // 构建聊天请求 req := &api.ChatRequest{ Model: c.config.Model, @@ -74,15 +74,15 @@ func (c *Client) ChatStream(ctx context.Context, ch chan []byte, messages []api. defer w.Done() err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error { if resp.Message.Content != "" { - ch <- []byte(resp.Message.Content) - } - if resp.Done { - ch <- []byte("EOF") + ch <- entitys.ResponseData{ + Done: false, + Content: []byte(resp.Message.Content), + Type: entitys.ResponseStream, + } } return nil }) if err != nil { - ch <- []byte("EOF") return } }() diff --git a/internal/tools/manager.go b/internal/tools/manager.go index d4a86d9..4b854ca 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -94,7 +94,7 @@ func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefi } // ExecuteTool 执行工具 -func (m *Manager) ExecuteTool(channel chan []byte, c *websocket.Conn, name string, args json.RawMessage) error { +func (m *Manager) ExecuteTool(channel chan entitys.ResponseData, c *websocket.Conn, name string, args json.RawMessage) error { tool, exists := m.GetTool(name) if !exists { return fmt.Errorf("tool not found: %s", name) diff --git a/internal/tools/zltx_order_detail.go b/internal/tools/zltx_order_detail.go index 1181e74..c37dcba 100644 --- a/internal/tools/zltx_order_detail.go +++ b/internal/tools/zltx_order_detail.go @@ -81,7 +81,7 @@ type ZltxOrderDetailData struct { } // Execute 执行直连天下订单详情查询 -func (w *ZltxOrderDetailTool) Execute(channel chan []byte, c *websocket.Conn, args json.RawMessage) error { +func (w *ZltxOrderDetailTool) Execute(channel chan entitys.ResponseData, c *websocket.Conn, args json.RawMessage) error { var req ZltxOrderDetailRequest if err := json.Unmarshal(args, &req); err != nil { return fmt.Errorf("invalid zltxOrderDetail request: %w", err) @@ -96,7 +96,7 @@ func (w *ZltxOrderDetailTool) Execute(channel chan []byte, c *websocket.Conn, ar } // getMockZltxOrderDetail 获取模拟直连天下订单详情数据 -func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan []byte, c *websocket.Conn, number string) (err error) { +func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.ResponseData, c *websocket.Conn, number string) (err error) { //查询订单详情 var auth string if c != nil { @@ -129,9 +129,18 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan []byte, c *websocket.Co if err = json.Unmarshal(res.Content, &resData); err != nil { return } - ch <- res.Content + ch <- entitys.ResponseData{ + Done: false, + Content: res.Content, + Type: entitys.ResponseJson, + } if resData.Data.Direct != nil && resData.Data.Direct["needAi"].(bool) { - ch <- []byte("orderErrorChecking") + ch <- entitys.ResponseData{ + Done: false, + Content: []byte("正在分析订单日志"), + Type: entitys.ResponseLoading, + } + req = l_request.Request{ Url: fmt.Sprintf("%szltx_api/admin/direct/log/%s/%s", w.config.BaseURL, resData.Data.Direct["orderOrderNumber"].(string), resData.Data.Direct["serialNumber"].(string)), Headers: map[string]string{ @@ -154,6 +163,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan []byte, c *websocket.Co if err != nil { return fmt.Errorf("订单日志解析失败:%s", err) } + err = w.llm.ChatStream(context.TODO(), ch, []api.Message{ { Role: "system", diff --git a/internal/tools/zltx_order_direct_log.go b/internal/tools/zltx_order_direct_log.go index abd9ec5..95e2d05 100644 --- a/internal/tools/zltx_order_direct_log.go +++ b/internal/tools/zltx_order_direct_log.go @@ -5,6 +5,7 @@ import ( "ai_scheduler/internal/entitys" "encoding/json" "fmt" + "gitea.cdlsxd.cn/self-tools/l_request" "github.com/gofiber/websocket/v2" ) @@ -66,7 +67,7 @@ type ZltxOrderDirectLogData struct { Data map[string]interface{} `json:"data"` } -func (t *ZltxOrderLogTool) Execute(channel chan []byte, c *websocket.Conn, args json.RawMessage) error { +func (t *ZltxOrderLogTool) Execute(channel chan entitys.ResponseData, c *websocket.Conn, args json.RawMessage) error { var req ZltxOrderLogRequest if err := json.Unmarshal(args, &req); err != nil { return fmt.Errorf("invalid zltxOrderLog request: %w", err) @@ -77,7 +78,7 @@ func (t *ZltxOrderLogTool) Execute(channel chan []byte, c *websocket.Conn, args return t.getZltxOrderLog(channel, c, req.OrderNumber, req.SerialNumber) } -func (t *ZltxOrderLogTool) getZltxOrderLog(channel chan []byte, c *websocket.Conn, orderNumber, serialNumber string) (err error) { +func (t *ZltxOrderLogTool) getZltxOrderLog(channel chan entitys.ResponseData, c *websocket.Conn, orderNumber, serialNumber string) (err error) { //查询订单详情 var auth string if c != nil { @@ -109,7 +110,11 @@ func (t *ZltxOrderLogTool) getZltxOrderLog(channel chan []byte, c *websocket.Con _ = c.WriteMessage(websocket.TextMessage, res.Content) return } else { - channel <- res.Content + channel <- entitys.ResponseData{ + Done: false, + Content: res.Content, + Type: entitys.ResponseJson, + } } return } diff --git a/internal/tools/zltx_product.go b/internal/tools/zltx_product.go index 50f28e2..0e3728a 100644 --- a/internal/tools/zltx_product.go +++ b/internal/tools/zltx_product.go @@ -5,9 +5,10 @@ import ( "ai_scheduler/internal/entitys" "encoding/json" "fmt" + "strings" + "gitea.cdlsxd.cn/self-tools/l_request" "github.com/gofiber/websocket/v2" - "strings" ) type ZltxProductTool struct { @@ -51,7 +52,7 @@ type ZltxProductRequest struct { Name string `json:"name"` } -func (z ZltxProductTool) Execute(channel chan []byte, c *websocket.Conn, args json.RawMessage) error { +func (z ZltxProductTool) Execute(channel chan entitys.ResponseData, c *websocket.Conn, args json.RawMessage) error { var req ZltxProductRequest if err := json.Unmarshal(args, &req); err != nil { return fmt.Errorf("invalid zltxProduct request: %w", err) @@ -125,7 +126,7 @@ type ZltxProductData struct { PlatformProductList interface{} `json:"platform_product_list"` } -func (z ZltxProductTool) getZltxProduct(channel chan []byte, c *websocket.Conn, id string, name string) error { +func (z ZltxProductTool) getZltxProduct(channel chan entitys.ResponseData, c *websocket.Conn, id string, name string) error { var auth string if c != nil { auth = c.Headers("X-Authorization", "") @@ -176,7 +177,11 @@ func (z ZltxProductTool) getZltxProduct(channel chan []byte, c *websocket.Conn, } } } - channel <- res.Content + channel <- entitys.ResponseData{ + Done: false, + Content: res.Content, + Type: entitys.ResponseJson, + } return nil } diff --git a/internal/tools/zltx_statistics.go b/internal/tools/zltx_statistics.go index 4a5716b..ae35c26 100644 --- a/internal/tools/zltx_statistics.go +++ b/internal/tools/zltx_statistics.go @@ -5,6 +5,7 @@ import ( "ai_scheduler/internal/entitys" "encoding/json" "fmt" + "gitea.cdlsxd.cn/self-tools/l_request" "github.com/gofiber/websocket/v2" ) @@ -45,7 +46,7 @@ type ZltxOrderStatisticsRequest struct { Number string `json:"number"` } -func (z ZltxOrderStatisticsTool) Execute(channel chan []byte, c *websocket.Conn, args json.RawMessage) error { +func (z ZltxOrderStatisticsTool) Execute(channel chan entitys.ResponseData, c *websocket.Conn, args json.RawMessage) error { var req ZltxOrderStatisticsRequest if err := json.Unmarshal(args, &req); err != nil { return err @@ -70,7 +71,7 @@ type ZltxOrderStatisticsData struct { Total int `json:"total"` } -func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(channel chan []byte, c *websocket.Conn, number string) error { +func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(channel chan entitys.ResponseData, c *websocket.Conn, number string) error { //查询订单详情 var auth string if c != nil { @@ -102,7 +103,11 @@ func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(channel chan []byte, c * _ = c.WriteMessage(websocket.TextMessage, res.Content) return nil } else { - channel <- res.Content + channel <- entitys.ResponseData{ + Done: false, + Content: res.Content, + Type: entitys.ResponseJson, + } } return nil }