// Package utils_vllm wraps the OpenAI-compatible chat models configured under
// config.Vllm: a vision-language model used for image recognition and a text
// model used for chat and tool selection.
package utils_vllm

import (
	"context"
	"encoding/base64"
	"errors"
	"fmt"

	"ai_scheduler/internal/config"
	"ai_scheduler/internal/pkg/util"

	"github.com/cloudwego/eino-ext/components/model/openai"
	"github.com/cloudwego/eino/components/model"
	"github.com/cloudwego/eino/schema"
)

// Client holds the two chat models used by this package.
type Client struct {
	// vlModel is the vision-language model used by the image-recognition methods.
	vlModel *openai.ChatModel
	// generateModel is the text model used by Chat and ToolSelect.
	generateModel *openai.ChatModel
	// config is the configuration the client was built from; the methods in
	// this file do not read it.
	config *config.Config
}

// NewClient builds the vision model from config.Vllm.VLModel and the text
// model from config.Vllm.TextModel (base URL, model name and timeout each).
//
// It returns an error if config is nil or if either model fails to
// initialise; the error names the model that failed and wraps the underlying
// cause so errors.Is / errors.As still work. The returned cleanup function is
// currently a no-op.
func NewClient(config *config.Config) (*Client, func(), error) {
	if config == nil {
		return nil, nil, errors.New("utils_vllm: nil config")
	}

	vl, err := openai.NewChatModel(context.Background(), &openai.ChatModelConfig{
		BaseURL: config.Vllm.VLModel.BaseURL,
		Model:   config.Vllm.VLModel.Model,
		Timeout: config.Vllm.VLModel.Timeout,
	})
	if err != nil {
		return nil, nil, fmt.Errorf("init vision model %q: %w", config.Vllm.VLModel.Model, err)
	}

	gen, err := openai.NewChatModel(context.Background(), &openai.ChatModelConfig{
		BaseURL: config.Vllm.TextModel.BaseURL,
		Model:   config.Vllm.TextModel.Model,
		Timeout: config.Vllm.TextModel.Timeout,
		// Extra request fields forwarded to the server. chat_template_kwargs
		// asks the server-side chat template to disable the model's
		// "thinking" output; presumably only honoured by templates that read
		// enable_thinking — confirm against the deployed model.
		ExtraFields: map[string]any{
			"chat_template_kwargs": map[string]any{
				"enable_thinking": false,
			},
		},
	})
	if err != nil {
		return nil, nil, fmt.Errorf("init text model %q: %w", config.Vllm.TextModel.Model, err)
	}

	c := &Client{
		vlModel:       vl,
		generateModel: gen,
		config:        config,
	}
	return c, func() {}, nil
}

// Chat sends msgs to the text-generation model and returns its reply.
func (c *Client) Chat(ctx context.Context, msgs []*schema.Message) (*schema.Message, error) {
	return c.generateModel.Generate(ctx, msgs)
}

// ToolSelect sends msgs to the text-generation model with tools bound via
// model.WithTools, so the reply may contain tool calls for the caller to
// dispatch.
func (c *Client) ToolSelect(ctx context.Context, msgs []*schema.Message, tools []*schema.ToolInfo) (*schema.Message, error) {
	return c.generateModel.Generate(ctx, msgs, model.WithTools(tools))
}

// RecognizeWithImg asks the vision model about the images at imgURLs.
//
// The request is a system message carrying systemPrompt followed by a user
// message whose content is userPrompt and then one high-detail image part per
// URL, in the order given.
func (c *Client) RecognizeWithImg(ctx context.Context, systemPrompt, userPrompt string, imgURLs []string) (*schema.Message, error) {
	images := make([]*schema.MessageInputImage, 0, len(imgURLs))
	for i := range imgURLs {
		// Point at a per-image copy rather than &imgURLs[i] so the request
		// does not alias the caller's slice.
		u := imgURLs[i]
		images = append(images, &schema.MessageInputImage{
			MessagePartCommon: schema.MessagePartCommon{URL: &u},
			Detail:            schema.ImageURLDetailHigh,
		})
	}
	return c.vlModel.Generate(ctx, buildVisionMessages(systemPrompt, userPrompt, images...))
}

// RecognizeWithImgBytes asks the vision model about an image supplied as raw
// bytes. The bytes are sent base64-encoded with imgType as the MIME type
// (presumably a full MIME type such as "image/png" — confirm with callers).
//
// It returns an error without contacting the model if imgBytes is empty.
func (c *Client) RecognizeWithImgBytes(ctx context.Context, systemPrompt, userPrompt string, imgBytes []byte, imgType string) (*schema.Message, error) {
	if len(imgBytes) == 0 {
		return nil, errors.New("utils_vllm: empty image data")
	}
	img := &schema.MessageInputImage{
		MessagePartCommon: schema.MessagePartCommon{
			MIMEType:   imgType,
			Base64Data: util.AnyToPoint(base64.StdEncoding.EncodeToString(imgBytes)),
		},
		Detail: schema.ImageURLDetailHigh,
	}
	return c.vlModel.Generate(ctx, buildVisionMessages(systemPrompt, userPrompt, img))
}

// buildVisionMessages assembles the two-message conversation sent to the
// vision model: a system message with systemPrompt, then a user message whose
// multi-part content is the userPrompt text followed by the given images.
func buildVisionMessages(systemPrompt, userPrompt string, images ...*schema.MessageInputImage) []*schema.Message {
	parts := make([]schema.MessageInputPart, 0, 1+len(images))
	parts = append(parts, schema.MessageInputPart{Type: schema.ChatMessagePartTypeText, Text: userPrompt})
	for _, img := range images {
		parts = append(parts, schema.MessageInputPart{
			Type:  schema.ChatMessagePartTypeImageURL,
			Image: img,
		})
	}
	return []*schema.Message{
		{Role: schema.System, Content: systemPrompt},
		{Role: schema.User, UserInputMultiContent: parts},
	}
}