From f052000d591abc33a35c09cb725a1a8d2d5b0e63 Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Thu, 9 Oct 2025 11:38:37 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BB=93=E6=9E=84=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/router.go | 33 +++++++++++++++++++++++++++++++ internal/data/error/error_code.go | 4 ++++ internal/entitys/types.go | 1 + internal/pkg/func.go | 30 ++++++++++++++++++++++++++++ internal/tools/normal_chat.go | 8 +++++++- 5 files changed, 75 insertions(+), 1 deletion(-) diff --git a/internal/biz/router.go b/internal/biz/router.go index df7e82c..c88fa62 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -80,6 +80,13 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe cancel() }() + //获取图片信息 + err = r.getImgData(req.Img, &requireData) + if err != nil { + log.Errorf("GetImgData error: %v", err) + return + } + //获取系统信息 err = r.getRequireData(req.Text, &requireData) if err != nil { @@ -173,6 +180,32 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura } } +func (r *AiRouterBiz) getImgData(imgUrl string, requireData *entitys.RequireData) (err error) { + return + if len(imgUrl) == 0 { + return + } + if err = pkg.ValidateImageURL(imgUrl); err != nil { + return err + } + req := l_request.Request{ + Method: "GET", + Url: imgUrl, + } + res, err := req.Send() + if err != nil { + return + } + if _, ex := res.Headers["Content-Type"]; !ex { + return errors.ParamErr("图片格式错误:Content-Type未获取") + } + if !strings.HasPrefix(res.Headers["Content-Type"], "image/") { + return errors.ParamErr("expected image content, got %s", res.Headers["Content-Type"]) + } + requireData.ImgByte = append(requireData.ImgByte, res.Content) + return +} + func (r *AiRouterBiz) recognize(ctx context.Context, requireData *entitys.RequireData) (err error) { requireData.Ch <- entitys.Response{ Index: "", diff --git a/internal/data/error/error_code.go b/internal/data/error/error_code.go index e6c7230..e9cc67d 100644 --- a/internal/data/error/error_code.go +++ b/internal/data/error/error_code.go @@ -46,6 +46,10 @@ func SysErr(message string, arg ...any) *BusinessErr { return &BusinessErr{code: SystemError.code, message: fmt.Sprintf(message, arg)} } +func ParamErr(message string, arg ...any) *BusinessErr { + return &BusinessErr{code: ParamError.code, message: fmt.Sprintf(message, arg)} +} + func (e *BusinessErr) Wrap(err error) *BusinessErr { return NewBusinessErr(e.code, err.Error()) } diff --git a/internal/entitys/types.go b/internal/entitys/types.go index ca4eb96..0395b29 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -151,6 +151,7 @@ type RequireData struct { Auth string Ch chan Response KnowledgeConf KnowledgeBaseRequest + ImgByte []api.ImageData } type KnowledgeBaseRequest struct { diff --git a/internal/pkg/func.go b/internal/pkg/func.go index 0583a46..27ce2a3 100644 --- a/internal/pkg/func.go +++ b/internal/pkg/func.go @@ -3,6 +3,10 @@ package pkg import ( "ai_scheduler/internal/entitys" "encoding/json" + "errors" + "fmt" + "net/url" + "strings" ) func JsonStringIgonErr(data interface{}) string { @@ -25,3 +29,29 @@ func IsChannelClosed(ch chan entitys.ResponseData) bool { return false // channel 未关闭(但可能有数据未读取) } } + +// ValidateImageURL 验证图片 URL 是否有效 +func ValidateImageURL(rawURL string) error { + // 1. 基础格式验证 + parsed, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid URL format: %v", err) + } + + // 2. 检查协议是否为 http/https + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return errors.New("URL must use http or https protocol") + } + + // 3. 检查是否有空的主机名 + if parsed.Host == "" { + return errors.New("URL missing host") + } + + // 4. 检查路径是否为空(可选) + if strings.TrimSpace(parsed.Path) == "" { + return errors.New("URL path is empty") + } + + return nil +} diff --git a/internal/tools/normal_chat.go b/internal/tools/normal_chat.go index 489eb2e..9ef8903 100644 --- a/internal/tools/normal_chat.go +++ b/internal/tools/normal_chat.go @@ -56,6 +56,12 @@ func (w *NormalChatTool) Execute(ctx context.Context, requireData *entitys.Requi // getMockZltxOrderDetail 获取模拟直连天下订单详情数据 func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat) (err error) { + requireData.Ch <- entitys.Response{ + Index: w.Name(), + Content: "", + Type: entitys.ResponseStream, + } + err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{ { Role: "system", @@ -67,7 +73,7 @@ func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat }, { Role: "user", - Content: requireData.UserInput, + Content: chat.ChatContent, }, }, w.Name()) if err != nil {