diff --git a/config/config_env.yaml b/config/config_env.yaml index 188e5f9..23a123b 100644 --- a/config/config_env.yaml +++ b/config/config_env.yaml @@ -14,7 +14,7 @@ ollama: vllm: base_url: "http://117.175.169.61:16001/v1" - vl_model: "models/Qwen2.5-VL-3B-Instruct-AWQ" + vl_model: "Qwen2.5-VL-3B-Instruct-AWQ" timeout: "120s" level: "info" diff --git a/config/config_test.yaml b/config/config_test.yaml index 5fa0dbd..ab91824 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -15,7 +15,7 @@ ollama: vllm: base_url: "http://host.docker.internal:8001/v1" - vl_model: "models/Qwen2.5-VL-3B-Instruct-AWQ" + vl_model: "Qwen2.5-VL-3B-Instruct-AWQ" timeout: "120s" level: "info" diff --git a/internal/biz/do/prompt.go b/internal/biz/do/prompt.go index ea82f84..cada763 100644 --- a/internal/biz/do/prompt.go +++ b/internal/biz/do/prompt.go @@ -2,8 +2,11 @@ package do import ( "ai_scheduler/internal/biz/handle" + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/utils_vllm" "context" "strings" @@ -15,6 +18,7 @@ type PromptOption interface { } type WithSys struct { + Config *config.Config } func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) { @@ -43,7 +47,7 @@ func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) { var hasFile bool - if rec.UserContent.File != nil && len(rec.UserContent.File) > 0 { + if len(rec.UserContent.File) > 0 { hasFile = true } content.WriteString(rec.UserContent.Text) @@ -67,13 +71,51 @@ func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (c content.WriteString("### 文件内容:\n") for _, file := range rec.UserContent.File { handle.HandleRecognizeFile(file) - } + // 文件识别 + switch file.FileType { + case constants.FileTypeImage: + entitys.ResLog(rec.Ch, "recognize_img_start", "图片识别中...") + var imageContent string + imageContent, err = f.recognizeWithImgVllm(ctx, file) + if err != nil { + return + } + entitys.ResLog(rec.Ch, "recognize_img_end", "图片识别完成,识别内容:"+imageContent) - //...do something with file + // 解析结果回写到file + file.FileRec = imageContent + default: + content.WriteString(file.FileRec) + } + } } return } +func (f *WithSys) recognizeWithImgVllm(ctx context.Context, file *entitys.RecognizeFile) (content string, err error) { + if file.FileData == nil || file.FileType != constants.FileTypeImage { + return + } + + client, cleanup, err := utils_vllm.NewClient(f.Config) + if err != nil { + return "", err + } + defer cleanup() + + outMsg, err := client.RecognizeWithImgBytes(ctx, + f.Config.DefaultPrompt.ImgRecognize.SystemPrompt, + f.Config.DefaultPrompt.ImgRecognize.UserPrompt, + file.FileData, + file.FileRealMime, + ) + if err != nil { + return "", err + } + + return outMsg.Content, nil +} + type WithDingTalkBot struct { } diff --git a/internal/biz/handle/file.go b/internal/biz/handle/file.go index 882ab86..e93fbc3 100644 --- a/internal/biz/handle/file.go +++ b/internal/biz/handle/file.go @@ -65,17 +65,14 @@ func HandleRecognizeFile(files *entitys.RecognizeFile) { // 分支3:仅有数据、无类型→内容检测并填充 if len(files.FileData) > 0 && len(strings.TrimSpace(files.FileType.String())) == 0 { if len(files.FileData) > maxSize { - files.FileType = constants.Caller(constants.FileTypeUnknown) + files.FileType = constants.FileTypeUnknown return } reader := bytes.NewReader(files.FileData) - detected := detectFileType(reader, "") - if detected == constants.FileTypeUnknown { - files.FileType = constants.Caller(constants.FileTypeUnknown) - return - } - files.FileType = constants.Caller(detected) + detected, fileRealMime := detectFileType(reader, "") + files.FileType = detected + files.FileRealMime = fileRealMime return } @@ -83,18 +80,19 @@ func HandleRecognizeFile(files *entitys.RecognizeFile) { if len(files.FileUrl) > 0 { fileBytes, contentType, err := downloadFile(files.FileUrl) if err != nil || len(fileBytes) == 0 { - files.FileType = constants.Caller(constants.FileTypeUnknown) + files.FileType = constants.FileTypeUnknown return } if len(fileBytes) > maxSize { // 超限:不写入数据,类型置 unknown - files.FileType = constants.Caller(constants.FileTypeUnknown) + files.FileType = constants.FileTypeUnknown return } // 优先使用响应头的 Content-Type 映射 detected := mapToFileType(contentType) + fileRealMime := contentType if detected == constants.FileTypeUnknown { // 回退:内容检测 + URL 文件名扩展名辅助 @@ -103,17 +101,13 @@ func HandleRecognizeFile(files *entitys.RecognizeFile) { fname = filepath.Base(u.Path) } reader := bytes.NewReader(fileBytes) - detected = detectFileType(reader, fname) + detected, fileRealMime = detectFileType(reader, fname) } // 写入数据 files.FileData = fileBytes - - if detected == constants.FileTypeUnknown { - files.FileType = constants.Caller(constants.FileTypeUnknown) - return - } - files.FileType = constants.Caller(detected) + files.FileType = detected + files.FileRealMime = fileRealMime return } } @@ -150,7 +144,7 @@ func downloadFile(fileUrl string) (fileBytes []byte, contentType string, err err } // detectFileType 判断文件类型 -func detectFileType(file io.ReadSeeker, filename string) constants.FileType { +func detectFileType(file io.ReadSeeker, filename string) (constants.FileType, string) { // 1. 读取文件头检测 MIME buffer := make([]byte, 512) n, _ := file.Read(buffer) @@ -160,7 +154,7 @@ func detectFileType(file io.ReadSeeker, filename string) constants.FileType { for fileType, items := range constants.FileTypeMappings { for _, item := range items { if !strings.HasPrefix(item, ".") && item == detectedMIME { - return fileType + return fileType, detectedMIME } } } @@ -170,10 +164,10 @@ func detectFileType(file io.ReadSeeker, filename string) constants.FileType { for fileType, items := range constants.FileTypeMappings { for _, item := range items { if strings.HasPrefix(item, ".") && item == ext { - return fileType + return fileType, ext } } } - return constants.FileTypeUnknown + return constants.FileTypeUnknown, "" } diff --git a/internal/biz/router.go b/internal/biz/router.go index 28ce0b8..7d045f8 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -2,6 +2,7 @@ package biz import ( "ai_scheduler/internal/biz/do" + "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/gateway" @@ -21,16 +22,19 @@ import ( type AiRouterBiz struct { do *do.Do handle *do.Handle + config *config.Config } // NewAiRouterBiz 创建路由服务 func NewAiRouterBiz( do *do.Do, handle *do.Handle, + config *config.Config, ) *AiRouterBiz { return &AiRouterBiz{ do: do, handle: handle, + config: config, } } @@ -94,7 +98,7 @@ func (r *AiRouterBiz) SetRec(ctx context.Context, requireData *entitys.RequireDa // 对应不同的appKey, 配置不同的系统提示词 switch requireData.Sys.AppKey { default: - sys = &do.WithSys{} + sys = &do.WithSys{Config: r.config} } // 1. 系统提示词 diff --git a/internal/data/constants/file.go b/internal/data/constants/file.go index b825971..f234572 100644 --- a/internal/data/constants/file.go +++ b/internal/data/constants/file.go @@ -46,3 +46,7 @@ var FileTypeMappings = map[FileType][]string{ ".csv", }, } + +func (ft FileType) String() string { + return string(ft) +} diff --git a/internal/entitys/recognize.go b/internal/entitys/recognize.go index 831bef7..87f684c 100644 --- a/internal/entitys/recognize.go +++ b/internal/entitys/recognize.go @@ -41,8 +41,9 @@ type RecognizeUserContent struct { type FileData []byte type RecognizeFile struct { - FileRec string //文件识别内容 - FileData FileData // 文件数据(二进制格式) - FileType constants.Caller // 文件类型(文件类型,能填最好填,可以跳过一层判断) - FileUrl string // 文件下载链接 + FileRec string //文件识别内容 + FileData FileData // 文件数据(二进制格式) + FileType constants.FileType // 文件类型(文件类型,能填最好填,可以跳过一层判断) + FileRealMime string // 文件真实MIME类型 + FileUrl string // 文件下载链接 } diff --git a/internal/pkg/util/point.go b/internal/pkg/util/point.go new file mode 100644 index 0000000..beceddf --- /dev/null +++ b/internal/pkg/util/point.go @@ -0,0 +1,6 @@ +package util + +// AnyToPoint converts any value to a pointer. +func AnyToPoint[T any](v T) *T { + return &v +} diff --git a/internal/pkg/utils_vllm/client.go b/internal/pkg/utils_vllm/client.go index c333350..c8c4aec 100644 --- a/internal/pkg/utils_vllm/client.go +++ b/internal/pkg/utils_vllm/client.go @@ -2,7 +2,9 @@ package utils_vllm import ( "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/util" "context" + "encoding/base64" "github.com/cloudwego/eino-ext/components/model/openai" "github.com/cloudwego/eino/schema" @@ -58,3 +60,31 @@ func (c *Client) RecognizeWithImg(ctx context.Context, systemPrompt, userPrompt in[1].UserInputMultiContent = parts return c.model.Generate(ctx, in) } + +// 识别图片by二进制文件 +func (c *Client) RecognizeWithImgBytes(ctx context.Context, systemPrompt, userPrompt string, imgBytes []byte, imgType string) (*schema.Message, error) { + in := []*schema.Message{ + { + Role: schema.System, + Content: systemPrompt, + }, + { + Role: schema.User, + }, + } + parts := []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeText, Text: userPrompt}, + } + parts = append(parts, schema.MessageInputPart{ + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageInputImage{ + MessagePartCommon: schema.MessagePartCommon{ + MIMEType: imgType, + Base64Data: util.AnyToPoint(base64.StdEncoding.EncodeToString(imgBytes)), + }, + }, + }) + + in[1].UserInputMultiContent = parts + return c.model.Generate(ctx, in) +}