79 lines
2.1 KiB
Go
79 lines
2.1 KiB
Go
package pipeline
|
|
|
|
import (
	"ai_scheduler/internal/config"
	"ai_scheduler/internal/domain/common"
	"ai_scheduler/internal/domain/llm"
	"ai_scheduler/internal/domain/llm/capability"
	"ai_scheduler/internal/domain/llm/provider"
	"context"
	"strings"

	"github.com/cloudwego/eino/compose"
	"github.com/cloudwego/eino/schema"
)
|
|
|
|
func BuildVision(ctx context.Context, cfg *config.Config) (compose.Runnable[[]*schema.Message, *schema.Message], error) {
|
|
choice, opts, err := capability.Route(cfg, capability.Vision)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err = capability.Validate(capability.Vision, opts); err != nil {
|
|
return nil, err
|
|
}
|
|
f := provider.Get(choice.Provider)
|
|
if f == nil {
|
|
return nil, llm.ErrProviderNotFound
|
|
}
|
|
ad := f()
|
|
c := compose.NewChain[[]*schema.Message, *schema.Message]()
|
|
c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in []*schema.Message) (*schema.Message, error) {
|
|
if len(in) == 0 {
|
|
msgs, err := common.BuildVisionMessages("你是一名视觉助手,请根据图片与描述进行理解与回答。", "", common.ImageInput{URLs: []string{}})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return ad.Generate(ctx, msgs, opts)
|
|
}
|
|
if len(in[0].MultiContent) == 0 {
|
|
urls := []string{}
|
|
for _, tok := range splitBySpace(in[0].Content) {
|
|
if hasHTTPPrefix(tok) {
|
|
urls = append(urls, tok)
|
|
}
|
|
}
|
|
msgs, err := common.BuildVisionMessages("你是一名视觉助手,请根据图片与描述进行理解与回答。", "", common.ImageInput{URLs: urls})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return ad.Generate(ctx, msgs, opts)
|
|
}
|
|
return ad.Generate(ctx, in, opts)
|
|
}))
|
|
return c.Compile(ctx)
|
|
}
|
|
|
|
// splitBySpace breaks s into its whitespace-separated tokens.
//
// Runs of consecutive whitespace act as one separator, and leading or
// trailing whitespace never yields an empty token. Whitespace is whatever
// unicode.IsSpace accepts (this is what strings.Fields uses), so in addition
// to space, tab, CR and LF it covers separators such as the ideographic space
// U+3000. The result has length zero when s contains no tokens.
func splitBySpace(s string) []string {
	return strings.Fields(s)
}
|
|
|
|
// hasHTTPPrefix reports whether s begins with the literal prefix "http://" or
// "https://". The comparison is case-sensitive, and a string consisting of the
// prefix alone also matches.
func hasHTTPPrefix(s string) bool {
	return strings.HasPrefix(s, "http://") || strings.HasPrefix(s, "https://")
}
|