结构修改

This commit is contained in:
renzhiyuan 2025-10-09 11:38:37 +08:00
parent 220ad92c8b
commit f052000d59
5 changed files with 75 additions and 1 deletions

View File

@ -80,6 +80,13 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe
cancel() cancel()
}() }()
//获取图片信息
err = r.getImgData(req.Img, &requireData)
if err != nil {
log.Errorf("GetImgData error: %v", err)
return
}
//获取系统信息 //获取系统信息
err = r.getRequireData(req.Text, &requireData) err = r.getRequireData(req.Text, &requireData)
if err != nil { if err != nil {
@ -173,6 +180,32 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura
} }
} }
func (r *AiRouterBiz) getImgData(imgUrl string, requireData *entitys.RequireData) (err error) {
return
if len(imgUrl) == 0 {
return
}
if err = pkg.ValidateImageURL(imgUrl); err != nil {
return err
}
req := l_request.Request{
Method: "GET",
Url: imgUrl,
}
res, err := req.Send()
if err != nil {
return
}
if _, ex := res.Headers["Content-Type"]; !ex {
return errors.ParamErr("图片格式错误:Content-Type未获取")
}
if !strings.HasPrefix(res.Headers["Content-Type"], "image/") {
return errors.ParamErr("expected image content, got %s", res.Headers["Content-Type"])
}
requireData.ImgByte = append(requireData.ImgByte, res.Content)
return
}
func (r *AiRouterBiz) recognize(ctx context.Context, requireData *entitys.RequireData) (err error) { func (r *AiRouterBiz) recognize(ctx context.Context, requireData *entitys.RequireData) (err error) {
requireData.Ch <- entitys.Response{ requireData.Ch <- entitys.Response{
Index: "", Index: "",

View File

@ -46,6 +46,10 @@ func SysErr(message string, arg ...any) *BusinessErr {
return &BusinessErr{code: SystemError.code, message: fmt.Sprintf(message, arg)} return &BusinessErr{code: SystemError.code, message: fmt.Sprintf(message, arg)}
} }
func ParamErr(message string, arg ...any) *BusinessErr {
return &BusinessErr{code: ParamError.code, message: fmt.Sprintf(message, arg)}
}
func (e *BusinessErr) Wrap(err error) *BusinessErr { func (e *BusinessErr) Wrap(err error) *BusinessErr {
return NewBusinessErr(e.code, err.Error()) return NewBusinessErr(e.code, err.Error())
} }

View File

@ -151,6 +151,7 @@ type RequireData struct {
Auth string Auth string
Ch chan Response Ch chan Response
KnowledgeConf KnowledgeBaseRequest KnowledgeConf KnowledgeBaseRequest
ImgByte []api.ImageData
} }
type KnowledgeBaseRequest struct { type KnowledgeBaseRequest struct {

View File

@ -3,6 +3,10 @@ package pkg
import ( import (
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"encoding/json" "encoding/json"
"errors"
"fmt"
"net/url"
"strings"
) )
func JsonStringIgonErr(data interface{}) string { func JsonStringIgonErr(data interface{}) string {
@ -25,3 +29,29 @@ func IsChannelClosed(ch chan entitys.ResponseData) bool {
return false // channel 未关闭(但可能有数据未读取) return false // channel 未关闭(但可能有数据未读取)
} }
} }
// ValidateImageURL 验证图片 URL 是否有效
func ValidateImageURL(rawURL string) error {
// 1. 基础格式验证
parsed, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("invalid URL format: %v", err)
}
// 2. 检查协议是否为 http/https
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return errors.New("URL must use http or https protocol")
}
// 3. 检查是否有空的主机名
if parsed.Host == "" {
return errors.New("URL missing host")
}
// 4. 检查路径是否为空(可选)
if strings.TrimSpace(parsed.Path) == "" {
return errors.New("URL path is empty")
}
return nil
}

View File

@ -56,6 +56,12 @@ func (w *NormalChatTool) Execute(ctx context.Context, requireData *entitys.Requi
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据 // getMockZltxOrderDetail 获取模拟直连天下订单详情数据
func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat) (err error) { func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat) (err error) {
requireData.Ch <- entitys.Response{
Index: w.Name(),
Content: "<think></think>",
Type: entitys.ResponseStream,
}
err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{ err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{
{ {
Role: "system", Role: "system",
@ -67,7 +73,7 @@ func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat
}, },
{ {
Role: "user", Role: "user",
Content: requireData.UserInput, Content: chat.ChatContent,
}, },
}, w.Name()) }, w.Name())
if err != nil { if err != nil {