feat: 增加多模态参数

This commit is contained in:
fuzhongyun 2025-12-05 09:11:38 +08:00
parent 9c19e42600
commit f240e1fac4
5 changed files with 176 additions and 26 deletions

View File

@ -13,6 +13,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/schema"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"xorm.io/builder" "xorm.io/builder"
) )
@ -147,6 +149,60 @@ func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entit
return return
} }
// RecognizeWithImgVllm asks an OpenAI-compatible chat model (created through
// the eino openai component) to describe the images attached to requireData
// and returns the model's text wrapped in an api.GenerateResponse.
//
// The request consists of a system message and a user message, both taken
// from r.config.DefaultPrompt.ImgRecognize. The user message carries the user
// prompt as a text part followed by one high-detail image-URL part per entry
// in requireData.ImgUrls. Progress is reported on requireData.Ch via
// entitys.ResLog before and after the call.
//
// When requireData.ImgByte is nil the function returns the zero response and
// a nil error without contacting the model. Any error from creating the model
// or from Generate is returned unchanged together with a zero response.
//
// NOTE(review): the guard inspects ImgByte while the request itself is built
// from ImgUrls — confirm both are always populated together.
// NOTE(review): the model is created from an empty ChatModelConfig — confirm
// that endpoint/model/key are expected to come from elsewhere.
func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
	if requireData.ImgByte == nil {
		return desc, err
	}

	entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")

	chatModel, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{})
	if err != nil {
		return api.GenerateResponse{}, err
	}

	prompts := r.config.DefaultPrompt.ImgRecognize

	// The user turn starts with the textual prompt; images are appended after it.
	userParts := []schema.MessageInputPart{
		{
			Type: schema.ChatMessagePartTypeText,
			Text: prompts.UserPrompt,
		},
	}
	for _, imgUrl := range requireData.ImgUrls {
		// Take a per-iteration copy so each part owns a distinct pointer.
		link := imgUrl
		userParts = append(userParts, schema.MessageInputPart{
			Type: schema.ChatMessagePartTypeImageURL,
			Image: &schema.MessageInputImage{
				MessagePartCommon: schema.MessagePartCommon{
					URL: &link,
				},
				Detail: schema.ImageURLDetailHigh,
			},
		})
	}

	messages := []*schema.Message{
		{
			Role:    "system",
			Content: prompts.SystemPrompt,
		},
		{
			Role:                  "user",
			Content:               prompts.UserPrompt,
			UserInputMultiContent: userParts,
		},
	}

	reply, err := chatModel.Generate(ctx, messages)
	if err != nil {
		return api.GenerateResponse{}, err
	}

	desc.Response = reply.Content
	entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
	return desc, nil
}
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool { func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
taskPrompt := make([]api.Tool, 0) taskPrompt := make([]api.Tool, 0)
for _, task := range tasks { for _, task := range tasks {

View File

@ -0,0 +1,37 @@
package common
import (
"errors"
"strings"
"github.com/cloudwego/eino/schema"
)
// ImageInput groups the image references passed to BuildVisionMessages.
type ImageInput struct {
// URLs lists the image locations. BuildVisionMessages only forwards entries
// that start with "http://" or "https://"; other entries are skipped.
URLs []string
}
// BuildVisionMessages assembles the message list for a vision request: a
// system message holding systemPrompt followed by a single user message whose
// multi-part content is the optional text plus one high-detail image-URL part
// per accepted URL.
//
// Parameters:
//   - systemPrompt: content of the leading system message.
//   - userText: optional user text; it is added as a text part only when it is
//     not blank after trimming whitespace.
//   - images: candidate image URLs. Entries that do not start with "http://"
//     or "https://" (including empty strings) are silently skipped.
//
// It returns an error when images.URLs is empty, or when neither a text part
// nor any valid image part could be produced.
//
// NOTE(review): when userText is non-blank and every URL is rejected, the
// result is a text-only user message with no image — confirm this is intended
// for a vision request.
func BuildVisionMessages(systemPrompt string, userText string, images ImageInput) ([]*schema.Message, error) {
	if len(images.URLs) == 0 {
		return nil, errors.New("vision requires at least one image url")
	}

	parts := make([]schema.MessageInputPart, 0, 1+len(images.URLs))
	if strings.TrimSpace(userText) != "" {
		parts = append(parts, schema.MessageInputPart{
			Type: schema.ChatMessagePartTypeText,
			Text: userText,
		})
	}

	for _, u := range images.URLs {
		// The prefix test also rejects the empty string.
		if !strings.HasPrefix(u, "http://") && !strings.HasPrefix(u, "https://") {
			continue
		}
		// Copy before taking the address: with pre-Go-1.22 loop semantics the
		// range variable is shared across iterations, so &u would make every
		// image part point at the last URL.
		imgURL := u
		parts = append(parts, schema.MessageInputPart{
			Type: schema.ChatMessagePartTypeImageURL,
			Image: &schema.MessageInputImage{
				MessagePartCommon: schema.MessagePartCommon{URL: &imgURL},
				Detail:            schema.ImageURLDetailHigh,
			},
		})
	}

	if len(parts) == 0 {
		return nil, errors.New("vision inputs invalid: no text or valid image urls")
	}

	msgs := []*schema.Message{
		schema.SystemMessage(systemPrompt),
		{Role: schema.User, UserInputMultiContent: parts},
	}
	return msgs, nil
}

View File

@ -2,8 +2,8 @@ package llm
import "errors" import "errors"
var ErrInvalidCapability = errors.New("invalid capability") var ErrInvalidCapability = errors.New("能力未配置或无效")
var ErrProviderNotFound = errors.New("provider not found") var ErrProviderNotFound = errors.New("提供者未找到或未注册")
var ErrModelNotFound = errors.New("model not found") var ErrModelNotFound = errors.New("模型未找到或未配置")
var ErrModalityMismatch = errors.New("modality mismatch") var ErrModalityMismatch = errors.New("模态不匹配:视觉能力需要包含 image")
var ErrNotImplemented = errors.New("not implemented") var ErrNotImplemented = errors.New("not implemented")

View File

@ -1,26 +1,78 @@
package pipeline package pipeline
import ( import (
"context" "ai_scheduler/internal/config"
"github.com/cloudwego/eino/compose" "ai_scheduler/internal/domain/common"
"github.com/cloudwego/eino/schema" "ai_scheduler/internal/domain/llm"
"ai_scheduler/internal/config" "ai_scheduler/internal/domain/llm/capability"
"ai_scheduler/internal/domain/llm" "ai_scheduler/internal/domain/llm/provider"
"ai_scheduler/internal/domain/llm/capability" "context"
"ai_scheduler/internal/domain/llm/provider"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
) )
func BuildVision(ctx context.Context, cfg *config.Config) (compose.Runnable[[]*schema.Message, *schema.Message], error) { func BuildVision(ctx context.Context, cfg *config.Config) (compose.Runnable[[]*schema.Message, *schema.Message], error) {
choice, opts, err := capability.Route(cfg, capability.Vision) choice, opts, err := capability.Route(cfg, capability.Vision)
if err != nil { return nil, err } if err != nil {
if err = capability.Validate(capability.Vision, opts); err != nil { return nil, err } return nil, err
f := provider.Get(choice.Provider) }
if f == nil { return nil, llm.ErrProviderNotFound } if err = capability.Validate(capability.Vision, opts); err != nil {
ad := f() return nil, err
c := compose.NewChain[[]*schema.Message, *schema.Message]() }
c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in []*schema.Message) (*schema.Message, error) { f := provider.Get(choice.Provider)
return ad.Generate(ctx, in, opts) if f == nil {
})) return nil, llm.ErrProviderNotFound
return c.Compile(ctx) }
ad := f()
c := compose.NewChain[[]*schema.Message, *schema.Message]()
c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in []*schema.Message) (*schema.Message, error) {
if len(in) == 0 {
msgs, err := common.BuildVisionMessages("你是一名视觉助手,请根据图片与描述进行理解与回答。", "", common.ImageInput{URLs: []string{}})
if err != nil {
return nil, err
}
return ad.Generate(ctx, msgs, opts)
}
if len(in[0].MultiContent) == 0 {
urls := []string{}
for _, tok := range splitBySpace(in[0].Content) {
if hasHTTPPrefix(tok) {
urls = append(urls, tok)
}
}
msgs, err := common.BuildVisionMessages("你是一名视觉助手,请根据图片与描述进行理解与回答。", "", common.ImageInput{URLs: urls})
if err != nil {
return nil, err
}
return ad.Generate(ctx, msgs, opts)
}
return ad.Generate(ctx, in, opts)
}))
return c.Compile(ctx)
} }
// splitBySpace cuts s into the maximal runs of characters that contain none
// of the four separators space, newline, tab and carriage return. Other
// Unicode whitespace is treated as ordinary text. Consecutive separators
// produce no empty tokens, and the result is always a non-nil slice (empty
// when s has no tokens).
func splitBySpace(s string) []string {
	tokens := []string{}
	// begin is the byte offset where the current token started, or -1 while
	// scanning separators.
	begin := -1
	for pos, ch := range s {
		switch ch {
		case ' ', '\n', '\t', '\r':
			if begin >= 0 {
				tokens = append(tokens, s[begin:pos])
				begin = -1
			}
		default:
			if begin < 0 {
				begin = pos
			}
		}
	}
	// Flush a token that runs to the end of the input.
	if begin >= 0 {
		tokens = append(tokens, s[begin:])
	}
	return tokens
}
// hasHTTPPrefix reports whether s begins with the literal scheme prefix
// "http://" or "https://". The comparison is case-sensitive.
func hasHTTPPrefix(s string) bool {
	const (
		plainScheme  = "http://"
		secureScheme = "https://"
	)
	if len(s) < len(plainScheme) {
		return false
	}
	if s[:len(plainScheme)] == plainScheme {
		return true
	}
	return len(s) >= len(secureScheme) && s[:len(secureScheme)] == secureScheme
}

View File

@ -1,6 +1,11 @@
package prompt package prompt
func SystemForChat() string { return "You are a helpful assistant." } func SystemForChat() string {
func SystemForVision() string { return "You are a vision assistant." } return "你是一名有用的助手,请用清晰、简洁的中文回答。"
func SystemForIntent() string { return "You classify user intent." } }
func SystemForVision() string {
return "你是一名视觉助手,请根据图片与描述进行中文理解与回答。"
}
func SystemForIntent() string {
return "你负责意图识别,请用中文给出明确的意图类别与理由。"
}