feat: 1.调整通用文件识别 2.新增eino vllm 图片字节识别方法 3.实现直连AI业务的图片识别接入
This commit is contained in:
parent
1d604a4af9
commit
ae34efb989
|
|
@ -14,7 +14,7 @@ ollama:
|
||||||
|
|
||||||
vllm:
|
vllm:
|
||||||
base_url: "http://117.175.169.61:16001/v1"
|
base_url: "http://117.175.169.61:16001/v1"
|
||||||
vl_model: "models/Qwen2.5-VL-3B-Instruct-AWQ"
|
vl_model: "Qwen2.5-VL-3B-Instruct-AWQ"
|
||||||
timeout: "120s"
|
timeout: "120s"
|
||||||
level: "info"
|
level: "info"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ ollama:
|
||||||
|
|
||||||
vllm:
|
vllm:
|
||||||
base_url: "http://host.docker.internal:8001/v1"
|
base_url: "http://host.docker.internal:8001/v1"
|
||||||
vl_model: "models/Qwen2.5-VL-3B-Instruct-AWQ"
|
vl_model: "Qwen2.5-VL-3B-Instruct-AWQ"
|
||||||
timeout: "120s"
|
timeout: "120s"
|
||||||
level: "info"
|
level: "info"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,11 @@ package do
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/biz/handle"
|
"ai_scheduler/internal/biz/handle"
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
"ai_scheduler/internal/data/constants"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"ai_scheduler/internal/pkg"
|
"ai_scheduler/internal/pkg"
|
||||||
|
"ai_scheduler/internal/pkg/utils_vllm"
|
||||||
"context"
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
|
@ -15,6 +18,7 @@ type PromptOption interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type WithSys struct {
|
type WithSys struct {
|
||||||
|
Config *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) {
|
func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) {
|
||||||
|
|
@ -43,7 +47,7 @@ func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes
|
||||||
|
|
||||||
func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) {
|
func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) {
|
||||||
var hasFile bool
|
var hasFile bool
|
||||||
if rec.UserContent.File != nil && len(rec.UserContent.File) > 0 {
|
if len(rec.UserContent.File) > 0 {
|
||||||
hasFile = true
|
hasFile = true
|
||||||
}
|
}
|
||||||
content.WriteString(rec.UserContent.Text)
|
content.WriteString(rec.UserContent.Text)
|
||||||
|
|
@ -67,13 +71,51 @@ func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (c
|
||||||
content.WriteString("### 文件内容:\n")
|
content.WriteString("### 文件内容:\n")
|
||||||
for _, file := range rec.UserContent.File {
|
for _, file := range rec.UserContent.File {
|
||||||
handle.HandleRecognizeFile(file)
|
handle.HandleRecognizeFile(file)
|
||||||
}
|
// 文件识别
|
||||||
|
switch file.FileType {
|
||||||
|
case constants.FileTypeImage:
|
||||||
|
entitys.ResLog(rec.Ch, "recognize_img_start", "图片识别中...")
|
||||||
|
var imageContent string
|
||||||
|
imageContent, err = f.recognizeWithImgVllm(ctx, file)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
entitys.ResLog(rec.Ch, "recognize_img_end", "图片识别完成,识别内容:"+imageContent)
|
||||||
|
|
||||||
//...do something with file
|
// 解析结果回写到file
|
||||||
|
file.FileRec = imageContent
|
||||||
|
default:
|
||||||
|
content.WriteString(file.FileRec)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *WithSys) recognizeWithImgVllm(ctx context.Context, file *entitys.RecognizeFile) (content string, err error) {
|
||||||
|
if file.FileData == nil || file.FileType != constants.FileTypeImage {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client, cleanup, err := utils_vllm.NewClient(f.Config)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
outMsg, err := client.RecognizeWithImgBytes(ctx,
|
||||||
|
f.Config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
||||||
|
f.Config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||||
|
file.FileData,
|
||||||
|
file.FileRealMime,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return outMsg.Content, nil
|
||||||
|
}
|
||||||
|
|
||||||
type WithDingTalkBot struct {
|
type WithDingTalkBot struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -65,17 +65,14 @@ func HandleRecognizeFile(files *entitys.RecognizeFile) {
|
||||||
// 分支3:仅有数据、无类型→内容检测并填充
|
// 分支3:仅有数据、无类型→内容检测并填充
|
||||||
if len(files.FileData) > 0 && len(strings.TrimSpace(files.FileType.String())) == 0 {
|
if len(files.FileData) > 0 && len(strings.TrimSpace(files.FileType.String())) == 0 {
|
||||||
if len(files.FileData) > maxSize {
|
if len(files.FileData) > maxSize {
|
||||||
files.FileType = constants.Caller(constants.FileTypeUnknown)
|
files.FileType = constants.FileTypeUnknown
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
reader := bytes.NewReader(files.FileData)
|
reader := bytes.NewReader(files.FileData)
|
||||||
detected := detectFileType(reader, "")
|
detected, fileRealMime := detectFileType(reader, "")
|
||||||
if detected == constants.FileTypeUnknown {
|
files.FileType = detected
|
||||||
files.FileType = constants.Caller(constants.FileTypeUnknown)
|
files.FileRealMime = fileRealMime
|
||||||
return
|
|
||||||
}
|
|
||||||
files.FileType = constants.Caller(detected)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -83,18 +80,19 @@ func HandleRecognizeFile(files *entitys.RecognizeFile) {
|
||||||
if len(files.FileUrl) > 0 {
|
if len(files.FileUrl) > 0 {
|
||||||
fileBytes, contentType, err := downloadFile(files.FileUrl)
|
fileBytes, contentType, err := downloadFile(files.FileUrl)
|
||||||
if err != nil || len(fileBytes) == 0 {
|
if err != nil || len(fileBytes) == 0 {
|
||||||
files.FileType = constants.Caller(constants.FileTypeUnknown)
|
files.FileType = constants.FileTypeUnknown
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(fileBytes) > maxSize {
|
if len(fileBytes) > maxSize {
|
||||||
// 超限:不写入数据,类型置 unknown
|
// 超限:不写入数据,类型置 unknown
|
||||||
files.FileType = constants.Caller(constants.FileTypeUnknown)
|
files.FileType = constants.FileTypeUnknown
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 优先使用响应头的 Content-Type 映射
|
// 优先使用响应头的 Content-Type 映射
|
||||||
detected := mapToFileType(contentType)
|
detected := mapToFileType(contentType)
|
||||||
|
fileRealMime := contentType
|
||||||
|
|
||||||
if detected == constants.FileTypeUnknown {
|
if detected == constants.FileTypeUnknown {
|
||||||
// 回退:内容检测 + URL 文件名扩展名辅助
|
// 回退:内容检测 + URL 文件名扩展名辅助
|
||||||
|
|
@ -103,17 +101,13 @@ func HandleRecognizeFile(files *entitys.RecognizeFile) {
|
||||||
fname = filepath.Base(u.Path)
|
fname = filepath.Base(u.Path)
|
||||||
}
|
}
|
||||||
reader := bytes.NewReader(fileBytes)
|
reader := bytes.NewReader(fileBytes)
|
||||||
detected = detectFileType(reader, fname)
|
detected, fileRealMime = detectFileType(reader, fname)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 写入数据
|
// 写入数据
|
||||||
files.FileData = fileBytes
|
files.FileData = fileBytes
|
||||||
|
files.FileType = detected
|
||||||
if detected == constants.FileTypeUnknown {
|
files.FileRealMime = fileRealMime
|
||||||
files.FileType = constants.Caller(constants.FileTypeUnknown)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
files.FileType = constants.Caller(detected)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -150,7 +144,7 @@ func downloadFile(fileUrl string) (fileBytes []byte, contentType string, err err
|
||||||
}
|
}
|
||||||
|
|
||||||
// detectFileType 判断文件类型
|
// detectFileType 判断文件类型
|
||||||
func detectFileType(file io.ReadSeeker, filename string) constants.FileType {
|
func detectFileType(file io.ReadSeeker, filename string) (constants.FileType, string) {
|
||||||
// 1. 读取文件头检测 MIME
|
// 1. 读取文件头检测 MIME
|
||||||
buffer := make([]byte, 512)
|
buffer := make([]byte, 512)
|
||||||
n, _ := file.Read(buffer)
|
n, _ := file.Read(buffer)
|
||||||
|
|
@ -160,7 +154,7 @@ func detectFileType(file io.ReadSeeker, filename string) constants.FileType {
|
||||||
for fileType, items := range constants.FileTypeMappings {
|
for fileType, items := range constants.FileTypeMappings {
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
if !strings.HasPrefix(item, ".") && item == detectedMIME {
|
if !strings.HasPrefix(item, ".") && item == detectedMIME {
|
||||||
return fileType
|
return fileType, detectedMIME
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -170,10 +164,10 @@ func detectFileType(file io.ReadSeeker, filename string) constants.FileType {
|
||||||
for fileType, items := range constants.FileTypeMappings {
|
for fileType, items := range constants.FileTypeMappings {
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
if strings.HasPrefix(item, ".") && item == ext {
|
if strings.HasPrefix(item, ".") && item == ext {
|
||||||
return fileType
|
return fileType, ext
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return constants.FileTypeUnknown
|
return constants.FileTypeUnknown, ""
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ package biz
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/biz/do"
|
"ai_scheduler/internal/biz/do"
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/data/constants"
|
"ai_scheduler/internal/data/constants"
|
||||||
errors "ai_scheduler/internal/data/error"
|
errors "ai_scheduler/internal/data/error"
|
||||||
"ai_scheduler/internal/gateway"
|
"ai_scheduler/internal/gateway"
|
||||||
|
|
@ -21,16 +22,19 @@ import (
|
||||||
type AiRouterBiz struct {
|
type AiRouterBiz struct {
|
||||||
do *do.Do
|
do *do.Do
|
||||||
handle *do.Handle
|
handle *do.Handle
|
||||||
|
config *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAiRouterBiz 创建路由服务
|
// NewAiRouterBiz 创建路由服务
|
||||||
func NewAiRouterBiz(
|
func NewAiRouterBiz(
|
||||||
do *do.Do,
|
do *do.Do,
|
||||||
handle *do.Handle,
|
handle *do.Handle,
|
||||||
|
config *config.Config,
|
||||||
) *AiRouterBiz {
|
) *AiRouterBiz {
|
||||||
return &AiRouterBiz{
|
return &AiRouterBiz{
|
||||||
do: do,
|
do: do,
|
||||||
handle: handle,
|
handle: handle,
|
||||||
|
config: config,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -94,7 +98,7 @@ func (r *AiRouterBiz) SetRec(ctx context.Context, requireData *entitys.RequireDa
|
||||||
// 对应不同的appKey, 配置不同的系统提示词
|
// 对应不同的appKey, 配置不同的系统提示词
|
||||||
switch requireData.Sys.AppKey {
|
switch requireData.Sys.AppKey {
|
||||||
default:
|
default:
|
||||||
sys = &do.WithSys{}
|
sys = &do.WithSys{Config: r.config}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1. 系统提示词
|
// 1. 系统提示词
|
||||||
|
|
|
||||||
|
|
@ -46,3 +46,7 @@ var FileTypeMappings = map[FileType][]string{
|
||||||
".csv",
|
".csv",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// String returns ft converted to a string with a plain string conversion.
func (ft FileType) String() string {
|
||||||
|
return string(ft)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -41,8 +41,9 @@ type RecognizeUserContent struct {
|
||||||
type FileData []byte
|
type FileData []byte
|
||||||
|
|
||||||
type RecognizeFile struct {
|
type RecognizeFile struct {
|
||||||
FileRec string //文件识别内容
|
FileRec string //文件识别内容
|
||||||
FileData FileData // 文件数据(二进制格式)
|
FileData FileData // 文件数据(二进制格式)
|
||||||
FileType constants.Caller // 文件类型(文件类型,能填最好填,可以跳过一层判断)
|
FileType constants.FileType // 文件类型(文件类型,能填最好填,可以跳过一层判断)
|
||||||
FileUrl string // 文件下载链接
|
FileRealMime string // 文件真实MIME类型
|
||||||
|
FileUrl string // 文件下载链接
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
// AnyToPoint returns a pointer to a fresh copy of v.
//
// It is useful for taking the address of a literal or of a call result, which
// cannot be addressed directly. Writes through the returned pointer never
// affect the caller's original variable.
func AnyToPoint[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
|
||||||
|
|
@ -2,7 +2,9 @@ package utils_vllm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
|
"ai_scheduler/internal/pkg/util"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
|
||||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
|
|
@ -58,3 +60,31 @@ func (c *Client) RecognizeWithImg(ctx context.Context, systemPrompt, userPrompt
|
||||||
in[1].UserInputMultiContent = parts
|
in[1].UserInputMultiContent = parts
|
||||||
return c.model.Generate(ctx, in)
|
return c.model.Generate(ctx, in)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 识别图片by二进制文件
|
||||||
|
func (c *Client) RecognizeWithImgBytes(ctx context.Context, systemPrompt, userPrompt string, imgBytes []byte, imgType string) (*schema.Message, error) {
|
||||||
|
in := []*schema.Message{
|
||||||
|
{
|
||||||
|
Role: schema.System,
|
||||||
|
Content: systemPrompt,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: schema.User,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
parts := []schema.MessageInputPart{
|
||||||
|
{Type: schema.ChatMessagePartTypeText, Text: userPrompt},
|
||||||
|
}
|
||||||
|
parts = append(parts, schema.MessageInputPart{
|
||||||
|
Type: schema.ChatMessagePartTypeImageURL,
|
||||||
|
Image: &schema.MessageInputImage{
|
||||||
|
MessagePartCommon: schema.MessagePartCommon{
|
||||||
|
MIMEType: imgType,
|
||||||
|
Base64Data: util.AnyToPoint(base64.StdEncoding.EncodeToString(imgBytes)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
in[1].UserInputMultiContent = parts
|
||||||
|
return c.model.Generate(ctx, in)
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue