feat: 增加多模态参数
This commit is contained in:
parent
9c19e42600
commit
f240e1fac4
|
|
@ -13,6 +13,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"xorm.io/builder"
|
"xorm.io/builder"
|
||||||
)
|
)
|
||||||
|
|
@ -147,6 +149,60 @@ func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entit
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
|
||||||
|
if requireData.ImgByte == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
|
||||||
|
|
||||||
|
chatModel, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{})
|
||||||
|
if err != nil {
|
||||||
|
return api.GenerateResponse{}, err
|
||||||
|
}
|
||||||
|
in := []*schema.Message{
|
||||||
|
{
|
||||||
|
Role: "system",
|
||||||
|
Content: r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||||
|
UserInputMultiContent: []schema.MessageInputPart{
|
||||||
|
{
|
||||||
|
Type: schema.ChatMessagePartTypeText,
|
||||||
|
Text: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, imgUrl := range requireData.ImgUrls {
|
||||||
|
imgTmp := imgUrl
|
||||||
|
|
||||||
|
in[1].UserInputMultiContent = append(in[1].UserInputMultiContent, schema.MessageInputPart{
|
||||||
|
Type: schema.ChatMessagePartTypeImageURL,
|
||||||
|
Image: &schema.MessageInputImage{
|
||||||
|
MessagePartCommon: schema.MessagePartCommon{
|
||||||
|
URL: &imgTmp,
|
||||||
|
},
|
||||||
|
Detail: schema.ImageURLDetailHigh,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
outMsg, err := chatModel.Generate(ctx, in)
|
||||||
|
if err != nil {
|
||||||
|
return api.GenerateResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
desc = api.GenerateResponse{
|
||||||
|
Response: outMsg.Content,
|
||||||
|
}
|
||||||
|
|
||||||
|
entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
|
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
|
||||||
taskPrompt := make([]api.Tool, 0)
|
taskPrompt := make([]api.Tool, 0)
|
||||||
for _, task := range tasks {
|
for _, task := range tasks {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,37 @@
|
||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ImageInput struct {
|
||||||
|
URLs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildVisionMessages(systemPrompt string, userText string, images ImageInput) ([]*schema.Message, error) {
|
||||||
|
if len(images.URLs) == 0 {
|
||||||
|
return nil, errors.New("vision requires at least one image url")
|
||||||
|
}
|
||||||
|
parts := make([]schema.MessageInputPart, 0, 1+len(images.URLs))
|
||||||
|
if strings.TrimSpace(userText) != "" {
|
||||||
|
parts = append(parts, schema.MessageInputPart{Type: schema.ChatMessagePartTypeText, Text: userText})
|
||||||
|
}
|
||||||
|
for _, u := range images.URLs {
|
||||||
|
if u == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(u, "http://") && !strings.HasPrefix(u, "https://") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parts = append(parts, schema.MessageInputPart{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{URL: &u}, Detail: schema.ImageURLDetailHigh}})
|
||||||
|
}
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return nil, errors.New("vision inputs invalid: no text or valid image urls")
|
||||||
|
}
|
||||||
|
msgs := []*schema.Message{schema.SystemMessage(systemPrompt)}
|
||||||
|
msgs = append(msgs, &schema.Message{Role: schema.User, UserInputMultiContent: parts})
|
||||||
|
return msgs, nil
|
||||||
|
}
|
||||||
|
|
@ -2,8 +2,8 @@ package llm
|
||||||
|
|
||||||
import "errors"
|
import "errors"
|
||||||
|
|
||||||
var ErrInvalidCapability = errors.New("invalid capability")
|
var ErrInvalidCapability = errors.New("能力未配置或无效")
|
||||||
var ErrProviderNotFound = errors.New("provider not found")
|
var ErrProviderNotFound = errors.New("提供者未找到或未注册")
|
||||||
var ErrModelNotFound = errors.New("model not found")
|
var ErrModelNotFound = errors.New("模型未找到或未配置")
|
||||||
var ErrModalityMismatch = errors.New("modality mismatch")
|
var ErrModalityMismatch = errors.New("模态不匹配:视觉能力需要包含 image")
|
||||||
var ErrNotImplemented = errors.New("not implemented")
|
var ErrNotImplemented = errors.New("not implemented")
|
||||||
|
|
|
||||||
|
|
@ -1,26 +1,78 @@
|
||||||
package pipeline
|
package pipeline
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"ai_scheduler/internal/config"
|
||||||
"github.com/cloudwego/eino/compose"
|
"ai_scheduler/internal/domain/common"
|
||||||
"github.com/cloudwego/eino/schema"
|
"ai_scheduler/internal/domain/llm"
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/domain/llm/capability"
|
||||||
"ai_scheduler/internal/domain/llm"
|
"ai_scheduler/internal/domain/llm/provider"
|
||||||
"ai_scheduler/internal/domain/llm/capability"
|
"context"
|
||||||
"ai_scheduler/internal/domain/llm/provider"
|
|
||||||
|
"github.com/cloudwego/eino/compose"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BuildVision(ctx context.Context, cfg *config.Config) (compose.Runnable[[]*schema.Message, *schema.Message], error) {
|
func BuildVision(ctx context.Context, cfg *config.Config) (compose.Runnable[[]*schema.Message, *schema.Message], error) {
|
||||||
choice, opts, err := capability.Route(cfg, capability.Vision)
|
choice, opts, err := capability.Route(cfg, capability.Vision)
|
||||||
if err != nil { return nil, err }
|
if err != nil {
|
||||||
if err = capability.Validate(capability.Vision, opts); err != nil { return nil, err }
|
return nil, err
|
||||||
f := provider.Get(choice.Provider)
|
}
|
||||||
if f == nil { return nil, llm.ErrProviderNotFound }
|
if err = capability.Validate(capability.Vision, opts); err != nil {
|
||||||
ad := f()
|
return nil, err
|
||||||
c := compose.NewChain[[]*schema.Message, *schema.Message]()
|
}
|
||||||
c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in []*schema.Message) (*schema.Message, error) {
|
f := provider.Get(choice.Provider)
|
||||||
return ad.Generate(ctx, in, opts)
|
if f == nil {
|
||||||
}))
|
return nil, llm.ErrProviderNotFound
|
||||||
return c.Compile(ctx)
|
}
|
||||||
|
ad := f()
|
||||||
|
c := compose.NewChain[[]*schema.Message, *schema.Message]()
|
||||||
|
c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in []*schema.Message) (*schema.Message, error) {
|
||||||
|
if len(in) == 0 {
|
||||||
|
msgs, err := common.BuildVisionMessages("你是一名视觉助手,请根据图片与描述进行理解与回答。", "", common.ImageInput{URLs: []string{}})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ad.Generate(ctx, msgs, opts)
|
||||||
|
}
|
||||||
|
if len(in[0].MultiContent) == 0 {
|
||||||
|
urls := []string{}
|
||||||
|
for _, tok := range splitBySpace(in[0].Content) {
|
||||||
|
if hasHTTPPrefix(tok) {
|
||||||
|
urls = append(urls, tok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
msgs, err := common.BuildVisionMessages("你是一名视觉助手,请根据图片与描述进行理解与回答。", "", common.ImageInput{URLs: urls})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ad.Generate(ctx, msgs, opts)
|
||||||
|
}
|
||||||
|
return ad.Generate(ctx, in, opts)
|
||||||
|
}))
|
||||||
|
return c.Compile(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func splitBySpace(s string) []string {
|
||||||
|
res := []string{}
|
||||||
|
start := -1
|
||||||
|
for i, r := range s {
|
||||||
|
if r == ' ' || r == '\n' || r == '\t' || r == '\r' {
|
||||||
|
if start >= 0 {
|
||||||
|
res = append(res, s[start:i])
|
||||||
|
start = -1
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if start < 0 {
|
||||||
|
start = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if start >= 0 {
|
||||||
|
res = append(res, s[start:])
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasHTTPPrefix(s string) bool {
|
||||||
|
return len(s) >= 7 && (s[:7] == "http://" || (len(s) >= 8 && s[:8] == "https://"))
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,11 @@
|
||||||
package prompt
|
package prompt
|
||||||
|
|
||||||
func SystemForChat() string { return "You are a helpful assistant." }
|
func SystemForChat() string {
|
||||||
func SystemForVision() string { return "You are a vision assistant." }
|
return "你是一名有用的助手,请用清晰、简洁的中文回答。"
|
||||||
func SystemForIntent() string { return "You classify user intent." }
|
}
|
||||||
|
func SystemForVision() string {
|
||||||
|
return "你是一名视觉助手,请根据图片与描述进行中文理解与回答。"
|
||||||
|
}
|
||||||
|
func SystemForIntent() string {
|
||||||
|
return "你负责意图识别,请用中文给出明确的意图类别与理由。"
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue