feat: 增加多模态参数
This commit is contained in:
parent
9c19e42600
commit
f240e1fac4
|
|
@ -13,6 +13,8 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/ollama/ollama/api"
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
|
@ -147,6 +149,60 @@ func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entit
|
|||
return
|
||||
}
|
||||
|
||||
func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
|
||||
if requireData.ImgByte == nil {
|
||||
return
|
||||
}
|
||||
entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
|
||||
|
||||
chatModel, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{})
|
||||
if err != nil {
|
||||
return api.GenerateResponse{}, err
|
||||
}
|
||||
in := []*schema.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||
UserInputMultiContent: []schema.MessageInputPart{
|
||||
{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, imgUrl := range requireData.ImgUrls {
|
||||
imgTmp := imgUrl
|
||||
|
||||
in[1].UserInputMultiContent = append(in[1].UserInputMultiContent, schema.MessageInputPart{
|
||||
Type: schema.ChatMessagePartTypeImageURL,
|
||||
Image: &schema.MessageInputImage{
|
||||
MessagePartCommon: schema.MessagePartCommon{
|
||||
URL: &imgTmp,
|
||||
},
|
||||
Detail: schema.ImageURLDetailHigh,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
outMsg, err := chatModel.Generate(ctx, in)
|
||||
if err != nil {
|
||||
return api.GenerateResponse{}, err
|
||||
}
|
||||
|
||||
desc = api.GenerateResponse{
|
||||
Response: outMsg.Content,
|
||||
}
|
||||
|
||||
entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
|
||||
return
|
||||
}
|
||||
|
||||
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
|
||||
taskPrompt := make([]api.Tool, 0)
|
||||
for _, task := range tasks {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,37 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// ImageInput groups the images attached to a vision request.
type ImageInput struct {
	// URLs lists the image locations. BuildVisionMessages only forwards
	// entries that start with "http://" or "https://".
	URLs []string
}
|
||||
|
||||
func BuildVisionMessages(systemPrompt string, userText string, images ImageInput) ([]*schema.Message, error) {
|
||||
if len(images.URLs) == 0 {
|
||||
return nil, errors.New("vision requires at least one image url")
|
||||
}
|
||||
parts := make([]schema.MessageInputPart, 0, 1+len(images.URLs))
|
||||
if strings.TrimSpace(userText) != "" {
|
||||
parts = append(parts, schema.MessageInputPart{Type: schema.ChatMessagePartTypeText, Text: userText})
|
||||
}
|
||||
for _, u := range images.URLs {
|
||||
if u == "" {
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(u, "http://") && !strings.HasPrefix(u, "https://") {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, schema.MessageInputPart{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{URL: &u}, Detail: schema.ImageURLDetailHigh}})
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return nil, errors.New("vision inputs invalid: no text or valid image urls")
|
||||
}
|
||||
msgs := []*schema.Message{schema.SystemMessage(systemPrompt)}
|
||||
msgs = append(msgs, &schema.Message{Role: schema.User, UserInputMultiContent: parts})
|
||||
return msgs, nil
|
||||
}
|
||||
|
|
@ -2,8 +2,8 @@ package llm
|
|||
|
||||
import "errors"
|
||||
|
||||
var ErrInvalidCapability = errors.New("invalid capability")
|
||||
var ErrProviderNotFound = errors.New("provider not found")
|
||||
var ErrModelNotFound = errors.New("model not found")
|
||||
var ErrModalityMismatch = errors.New("modality mismatch")
|
||||
var ErrInvalidCapability = errors.New("能力未配置或无效")
|
||||
var ErrProviderNotFound = errors.New("提供者未找到或未注册")
|
||||
var ErrModelNotFound = errors.New("模型未找到或未配置")
|
||||
var ErrModalityMismatch = errors.New("模态不匹配:视觉能力需要包含 image")
|
||||
var ErrNotImplemented = errors.New("not implemented")
|
||||
|
|
|
|||
|
|
@ -1,26 +1,78 @@
|
|||
package pipeline
|
||||
|
||||
import (
	"context"
	"strings"

	"github.com/cloudwego/eino/compose"
	"github.com/cloudwego/eino/schema"

	"ai_scheduler/internal/config"
	"ai_scheduler/internal/domain/common"
	"ai_scheduler/internal/domain/llm"
	"ai_scheduler/internal/domain/llm/capability"
	"ai_scheduler/internal/domain/llm/provider"
)
|
||||
|
||||
func BuildVision(ctx context.Context, cfg *config.Config) (compose.Runnable[[]*schema.Message, *schema.Message], error) {
|
||||
choice, opts, err := capability.Route(cfg, capability.Vision)
|
||||
if err != nil { return nil, err }
|
||||
if err = capability.Validate(capability.Vision, opts); err != nil { return nil, err }
|
||||
f := provider.Get(choice.Provider)
|
||||
if f == nil { return nil, llm.ErrProviderNotFound }
|
||||
ad := f()
|
||||
c := compose.NewChain[[]*schema.Message, *schema.Message]()
|
||||
c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in []*schema.Message) (*schema.Message, error) {
|
||||
return ad.Generate(ctx, in, opts)
|
||||
}))
|
||||
return c.Compile(ctx)
|
||||
choice, opts, err := capability.Route(cfg, capability.Vision)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = capability.Validate(capability.Vision, opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f := provider.Get(choice.Provider)
|
||||
if f == nil {
|
||||
return nil, llm.ErrProviderNotFound
|
||||
}
|
||||
ad := f()
|
||||
c := compose.NewChain[[]*schema.Message, *schema.Message]()
|
||||
c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in []*schema.Message) (*schema.Message, error) {
|
||||
if len(in) == 0 {
|
||||
msgs, err := common.BuildVisionMessages("你是一名视觉助手,请根据图片与描述进行理解与回答。", "", common.ImageInput{URLs: []string{}})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ad.Generate(ctx, msgs, opts)
|
||||
}
|
||||
if len(in[0].MultiContent) == 0 {
|
||||
urls := []string{}
|
||||
for _, tok := range splitBySpace(in[0].Content) {
|
||||
if hasHTTPPrefix(tok) {
|
||||
urls = append(urls, tok)
|
||||
}
|
||||
}
|
||||
msgs, err := common.BuildVisionMessages("你是一名视觉助手,请根据图片与描述进行理解与回答。", "", common.ImageInput{URLs: urls})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ad.Generate(ctx, msgs, opts)
|
||||
}
|
||||
return ad.Generate(ctx, in, opts)
|
||||
}))
|
||||
return c.Compile(ctx)
|
||||
}
|
||||
|
||||
// splitBySpace splits s into its whitespace-separated tokens.
//
// Any run of Unicode white space (as defined by unicode.IsSpace) acts as a
// separator, which also covers the full-width space U+3000. Leading and
// trailing white space produce no empty tokens. The result is an empty slice
// when s is empty or consists only of white space.
func splitBySpace(s string) []string {
	return strings.Fields(s)
}
|
||||
|
||||
// hasHTTPPrefix reports whether s starts with "http://" or "https://".
// The comparison is case-sensitive.
func hasHTTPPrefix(s string) bool {
	return strings.HasPrefix(s, "http://") || strings.HasPrefix(s, "https://")
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
package prompt
|
||||
|
||||
// SystemForChat returns the system prompt for general chat: act as a helpful
// assistant and answer in clear, concise Chinese.
func SystemForChat() string {
	return "你是一名有用的助手,请用清晰、简洁的中文回答。"
}

// SystemForVision returns the system prompt for vision requests: understand
// the images and the description and answer in Chinese.
func SystemForVision() string {
	return "你是一名视觉助手,请根据图片与描述进行中文理解与回答。"
}

// SystemForIntent returns the system prompt for intent recognition: give an
// explicit intent category and the reasoning, in Chinese.
func SystemForIntent() string {
	return "你负责意图识别,请用中文给出明确的意图类别与理由。"
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue