123 lines
3.2 KiB
Go
123 lines
3.2 KiB
Go
package utils_vllm
|
|
|
|
import (
	"context"
	"encoding/base64"
	"fmt"
	"net/http"

	"github.com/cloudwego/eino-ext/components/model/openai"
	"github.com/cloudwego/eino/components/model"
	"github.com/cloudwego/eino/schema"

	"ai_scheduler/internal/config"
	"ai_scheduler/internal/pkg/util"
)
|
|
|
|
type Client struct {
|
|
vlModel *openai.ChatModel
|
|
generateModel *openai.ChatModel
|
|
config *config.Config
|
|
}
|
|
|
|
func NewClient(config *config.Config) (*Client, func(), error) {
|
|
// 初始化视觉模型
|
|
vl, err := openai.NewChatModel(context.Background(), &openai.ChatModelConfig{
|
|
BaseURL: config.Vllm.VLModel.BaseURL,
|
|
Model: config.Vllm.VLModel.Model,
|
|
Timeout: config.Vllm.VLModel.Timeout,
|
|
})
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// 初始化生成模型
|
|
gen, err := openai.NewChatModel(context.Background(), &openai.ChatModelConfig{
|
|
BaseURL: config.Vllm.TextModel.BaseURL,
|
|
Model: config.Vllm.TextModel.Model,
|
|
Timeout: config.Vllm.TextModel.Timeout,
|
|
ExtraFields: map[string]any{
|
|
"chat_template_kwargs": map[string]any{
|
|
"enable_thinking": false,
|
|
},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
c := &Client{
|
|
vlModel: vl,
|
|
generateModel: gen,
|
|
config: config,
|
|
}
|
|
cleanup := func() {}
|
|
return c, cleanup, nil
|
|
}
|
|
|
|
func (c *Client) Chat(ctx context.Context, msgs []*schema.Message) (*schema.Message, error) {
|
|
// 默认聊天使用生成模型
|
|
return c.generateModel.Generate(ctx, msgs)
|
|
}
|
|
|
|
func (c *Client) ToolSelect(ctx context.Context, msgs []*schema.Message, tools []*schema.ToolInfo) (*schema.Message, error) {
|
|
// 工具选择使用生成模型
|
|
return c.generateModel.Generate(ctx, msgs, model.WithTools(tools))
|
|
}
|
|
|
|
func (c *Client) RecognizeWithImg(ctx context.Context, systemPrompt, userPrompt string, imgURLs []string) (*schema.Message, error) {
|
|
// 图片识别使用视觉模型
|
|
in := []*schema.Message{
|
|
{
|
|
Role: schema.System,
|
|
Content: systemPrompt,
|
|
},
|
|
{
|
|
Role: schema.User,
|
|
},
|
|
}
|
|
parts := []schema.MessageInputPart{
|
|
{Type: schema.ChatMessagePartTypeText, Text: userPrompt},
|
|
}
|
|
for i := range imgURLs {
|
|
u := imgURLs[i]
|
|
parts = append(parts, schema.MessageInputPart{
|
|
Type: schema.ChatMessagePartTypeImageURL,
|
|
Image: &schema.MessageInputImage{
|
|
MessagePartCommon: schema.MessagePartCommon{URL: &u},
|
|
Detail: schema.ImageURLDetailHigh,
|
|
},
|
|
})
|
|
}
|
|
|
|
in[1].UserInputMultiContent = parts
|
|
return c.vlModel.Generate(ctx, in)
|
|
}
|
|
|
|
// 识别图片by二进制文件
|
|
func (c *Client) RecognizeWithImgBytes(ctx context.Context, systemPrompt, userPrompt string, imgBytes []byte, imgType string) (*schema.Message, error) {
|
|
// 图片识别使用视觉模型
|
|
in := []*schema.Message{
|
|
{
|
|
Role: schema.System,
|
|
Content: systemPrompt,
|
|
},
|
|
{
|
|
Role: schema.User,
|
|
},
|
|
}
|
|
parts := []schema.MessageInputPart{
|
|
{Type: schema.ChatMessagePartTypeText, Text: userPrompt},
|
|
}
|
|
parts = append(parts, schema.MessageInputPart{
|
|
Type: schema.ChatMessagePartTypeImageURL,
|
|
Image: &schema.MessageInputImage{
|
|
MessagePartCommon: schema.MessagePartCommon{
|
|
MIMEType: imgType,
|
|
Base64Data: util.AnyToPoint(base64.StdEncoding.EncodeToString(imgBytes)),
|
|
},
|
|
Detail: schema.ImageURLDetailHigh,
|
|
},
|
|
})
|
|
|
|
in[1].UserInputMultiContent = parts
|
|
return c.vlModel.Generate(ctx, in)
|
|
}
|