diff --git a/internal/biz/llm_service/ollama.go b/internal/biz/llm_service/ollama.go index 60d9c78..98100fa 100644 --- a/internal/biz/llm_service/ollama.go +++ b/internal/biz/llm_service/ollama.go @@ -13,6 +13,8 @@ import ( "strings" "time" + "github.com/cloudwego/eino-ext/components/model/openai" + "github.com/cloudwego/eino/schema" "github.com/ollama/ollama/api" "xorm.io/builder" ) @@ -147,6 +149,60 @@ func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entit return } +func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) { + if requireData.ImgByte == nil { + return + } + entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...") + + chatModel, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{}) + if err != nil { + return api.GenerateResponse{}, err + } + in := []*schema.Message{ + { + Role: "system", + Content: r.config.DefaultPrompt.ImgRecognize.SystemPrompt, + }, + { + Role: "user", + Content: r.config.DefaultPrompt.ImgRecognize.UserPrompt, + UserInputMultiContent: []schema.MessageInputPart{ + { + Type: schema.ChatMessagePartTypeText, + Text: r.config.DefaultPrompt.ImgRecognize.UserPrompt, + }, + }, + }, + } + + for _, imgUrl := range requireData.ImgUrls { + imgTmp := imgUrl + + in[1].UserInputMultiContent = append(in[1].UserInputMultiContent, schema.MessageInputPart{ + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageInputImage{ + MessagePartCommon: schema.MessagePartCommon{ + URL: &imgTmp, + }, + Detail: schema.ImageURLDetailHigh, + }, + }) + } + + outMsg, err := chatModel.Generate(ctx, in) + if err != nil { + return api.GenerateResponse{}, err + } + + desc = api.GenerateResponse{ + Response: outMsg.Content, + } + + entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response) + return +} + func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool { taskPrompt := make([]api.Tool, 0) for _, task := range tasks { diff --git a/internal/domain/common/vision_builder.go b/internal/domain/common/vision_builder.go new file mode 100644 index 0000000..66153bf --- /dev/null +++ b/internal/domain/common/vision_builder.go @@ -0,0 +1,37 @@ +package common + +import ( + "errors" + "strings" + + "github.com/cloudwego/eino/schema" +) + +type ImageInput struct { + URLs []string +} + +func BuildVisionMessages(systemPrompt string, userText string, images ImageInput) ([]*schema.Message, error) { + if len(images.URLs) == 0 { + return nil, errors.New("vision requires at least one image url") + } + parts := make([]schema.MessageInputPart, 0, 1+len(images.URLs)) + if strings.TrimSpace(userText) != "" { + parts = append(parts, schema.MessageInputPart{Type: schema.ChatMessagePartTypeText, Text: userText}) + } + for _, u := range images.URLs { + if u == "" { + continue + } + if !strings.HasPrefix(u, "http://") && !strings.HasPrefix(u, "https://") { + continue + } + parts = append(parts, schema.MessageInputPart{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{URL: &u}, Detail: schema.ImageURLDetailHigh}}) + } + if len(parts) == 0 { + return nil, errors.New("vision inputs invalid: no text or valid image urls") + } + msgs := []*schema.Message{schema.SystemMessage(systemPrompt)} + msgs = append(msgs, &schema.Message{Role: schema.User, UserInputMultiContent: parts}) + return msgs, nil +} diff --git a/internal/domain/llm/errors.go b/internal/domain/llm/errors.go index 59d76e6..366cf33 100644 --- a/internal/domain/llm/errors.go +++ b/internal/domain/llm/errors.go @@ -2,8 +2,8 @@ package llm import "errors" -var ErrInvalidCapability = errors.New("invalid capability") -var ErrProviderNotFound = errors.New("provider not found") -var ErrModelNotFound = errors.New("model not found") -var ErrModalityMismatch = errors.New("modality mismatch") +var ErrInvalidCapability = errors.New("能力未配置或无效") +var ErrProviderNotFound = errors.New("提供者未找到或未注册") +var ErrModelNotFound = errors.New("模型未找到或未配置") +var ErrModalityMismatch = errors.New("模态不匹配:视觉能力需要包含 image") var ErrNotImplemented = errors.New("not implemented") diff --git a/internal/domain/llm/pipeline/vision.go b/internal/domain/llm/pipeline/vision.go index 3563048..12f7981 100644 --- a/internal/domain/llm/pipeline/vision.go +++ b/internal/domain/llm/pipeline/vision.go @@ -1,26 +1,78 @@ package pipeline import ( - "context" - "github.com/cloudwego/eino/compose" - "github.com/cloudwego/eino/schema" - "ai_scheduler/internal/config" - "ai_scheduler/internal/domain/llm" - "ai_scheduler/internal/domain/llm/capability" - "ai_scheduler/internal/domain/llm/provider" + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/common" + "ai_scheduler/internal/domain/llm" + "ai_scheduler/internal/domain/llm/capability" + "ai_scheduler/internal/domain/llm/provider" + "context" + + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" ) func BuildVision(ctx context.Context, cfg *config.Config) (compose.Runnable[[]*schema.Message, *schema.Message], error) { - choice, opts, err := capability.Route(cfg, capability.Vision) - if err != nil { return nil, err } - if err = capability.Validate(capability.Vision, opts); err != nil { return nil, err } - f := provider.Get(choice.Provider) - if f == nil { return nil, llm.ErrProviderNotFound } - ad := f() - c := compose.NewChain[[]*schema.Message, *schema.Message]() - c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in []*schema.Message) (*schema.Message, error) { - return ad.Generate(ctx, in, opts) - })) - return c.Compile(ctx) + choice, opts, err := capability.Route(cfg, capability.Vision) + if err != nil { + return nil, err + } + if err = capability.Validate(capability.Vision, opts); err != nil { + return nil, err + } + f := provider.Get(choice.Provider) + if f == nil { + return nil, llm.ErrProviderNotFound + } + ad := f() + c := compose.NewChain[[]*schema.Message, *schema.Message]() + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in []*schema.Message) (*schema.Message, error) { + if len(in) == 0 { + msgs, err := common.BuildVisionMessages("你是一名视觉助手,请根据图片与描述进行理解与回答。", "", common.ImageInput{URLs: []string{}}) + if err != nil { + return nil, err + } + return ad.Generate(ctx, msgs, opts) + } + if len(in[0].MultiContent) == 0 { + urls := []string{} + for _, tok := range splitBySpace(in[0].Content) { + if hasHTTPPrefix(tok) { + urls = append(urls, tok) + } + } + msgs, err := common.BuildVisionMessages("你是一名视觉助手,请根据图片与描述进行理解与回答。", "", common.ImageInput{URLs: urls}) + if err != nil { + return nil, err + } + return ad.Generate(ctx, msgs, opts) + } + return ad.Generate(ctx, in, opts) + })) + return c.Compile(ctx) } +func splitBySpace(s string) []string { + res := []string{} + start := -1 + for i, r := range s { + if r == ' ' || r == '\n' || r == '\t' || r == '\r' { + if start >= 0 { + res = append(res, s[start:i]) + start = -1 + } + } else { + if start < 0 { + start = i + } + } + } + if start >= 0 { + res = append(res, s[start:]) + } + return res +} + +func hasHTTPPrefix(s string) bool { + return len(s) >= 7 && (s[:7] == "http://" || (len(s) >= 8 && s[:8] == "https://")) +} diff --git a/internal/domain/llm/prompt/templates.go b/internal/domain/llm/prompt/templates.go index 47da00a..47674a5 100644 --- a/internal/domain/llm/prompt/templates.go +++ b/internal/domain/llm/prompt/templates.go @@ -1,6 +1,11 @@ package prompt -func SystemForChat() string { return "You are a helpful assistant." } -func SystemForVision() string { return "You are a vision assistant." } -func SystemForIntent() string { return "You classify user intent." } - +func SystemForChat() string { + return "你是一名有用的助手,请用清晰、简洁的中文回答。" +} +func SystemForVision() string { + return "你是一名视觉助手,请根据图片与描述进行中文理解与回答。" +} +func SystemForIntent() string { + return "你负责意图识别,请用中文给出明确的意图类别与理由。" +}