// File: ai_scheduler/internal/pkg/utils_vllm/client.go

package utils_vllm
import (
	"context"
	"encoding/base64"
	"errors"
	"fmt"
	"net/http"

	"github.com/cloudwego/eino-ext/components/model/openai"
	"github.com/cloudwego/eino/components/model"
	"github.com/cloudwego/eino/schema"

	"ai_scheduler/internal/config"
	"ai_scheduler/internal/pkg/util"
)
// Client wraps the two OpenAI-compatible chat models configured under
// config.Vllm: a vision-language model for image recognition and a text
// model for plain chat and tool selection.
type Client struct {
	// vlModel serves the image-recognition methods; generateModel serves
	// Chat and ToolSelect.
	vlModel, generateModel *openai.ChatModel

	// config is the configuration the client was constructed from.
	config *config.Config
}
// NewClient builds a Client from the Vllm section of cfg.
//
// It creates two OpenAI-compatible chat models:
//   - the vision-language model described by cfg.Vllm.VLModel, used for
//     image recognition;
//   - the text model described by cfg.Vllm.TextModel, used for chat and tool
//     selection. Its requests carry
//     chat_template_kwargs.enable_thinking=false (NOTE(review): presumably to
//     turn off "thinking" output in the served model's chat template; confirm
//     the deployed model honours this flag).
//
// The returned cleanup function is currently a no-op. It exists so the
// constructor matches a (value, cleanup, error) provider shape (presumably for
// a DI tool such as wire; verify against the caller).
//
// An error is returned if cfg is nil or if either model fails to initialise.
// Initialisation errors are wrapped with %w, so errors.Is/As still see the
// underlying cause.
func NewClient(cfg *config.Config) (*Client, func(), error) {
	if cfg == nil {
		return nil, nil, errors.New("vllm client: nil config")
	}
	ctx := context.Background()

	// Vision-language model: handles image + text requests.
	vl, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{
		BaseURL: cfg.Vllm.VLModel.BaseURL,
		Model:   cfg.Vllm.VLModel.Model,
		Timeout: cfg.Vllm.VLModel.Timeout,
	})
	if err != nil {
		return nil, nil, fmt.Errorf("init vllm vision model %q: %w", cfg.Vllm.VLModel.Model, err)
	}

	// Text generation model: handles chat and tool selection.
	gen, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{
		BaseURL: cfg.Vllm.TextModel.BaseURL,
		Model:   cfg.Vllm.TextModel.Model,
		Timeout: cfg.Vllm.TextModel.Timeout,
		ExtraFields: map[string]any{
			"chat_template_kwargs": map[string]any{
				"enable_thinking": false,
			},
		},
	})
	if err != nil {
		return nil, nil, fmt.Errorf("init vllm text model %q: %w", cfg.Vllm.TextModel.Model, err)
	}

	c := &Client{
		vlModel:       vl,
		generateModel: gen,
		config:        cfg,
	}
	// Nothing to release yet; see the doc comment.
	cleanup := func() {}
	return c, cleanup, nil
}
// Chat sends msgs to the text-generation model and returns its reply.
// Plain conversation always goes to the generation model, never the vision model.
func (c *Client) Chat(ctx context.Context, msgs []*schema.Message) (*schema.Message, error) {
	textModel := c.generateModel
	return textModel.Generate(ctx, msgs)
}
// ToolSelect asks the text-generation model to respond to msgs with tools
// made available for this single call (passed as a per-call option, not bound
// to the model). The model's reply is returned as-is.
func (c *Client) ToolSelect(ctx context.Context, msgs []*schema.Message, tools []*schema.ToolInfo) (*schema.Message, error) {
	toolOpt := model.WithTools(tools)
	return c.generateModel.Generate(ctx, msgs, toolOpt)
}
// RecognizeWithImg runs image recognition on the vision model.
//
// The request is a system message containing systemPrompt plus one user
// message. That user message holds userPrompt as a text part, followed by one
// high-detail image part per entry of imgURLs, in order. An empty imgURLs
// yields a text-only user message.
func (c *Client) RecognizeWithImg(ctx context.Context, systemPrompt, userPrompt string, imgURLs []string) (*schema.Message, error) {
	content := make([]schema.MessageInputPart, 0, len(imgURLs)+1)
	content = append(content, schema.MessageInputPart{Type: schema.ChatMessagePartTypeText, Text: userPrompt})

	for idx := range imgURLs {
		// Take a fresh copy each iteration so every part points at its own string.
		imageURL := imgURLs[idx]
		content = append(content, schema.MessageInputPart{
			Type: schema.ChatMessagePartTypeImageURL,
			Image: &schema.MessageInputImage{
				MessagePartCommon: schema.MessagePartCommon{URL: &imageURL},
				Detail:            schema.ImageURLDetailHigh,
			},
		})
	}

	return c.vlModel.Generate(ctx, []*schema.Message{
		{Role: schema.System, Content: systemPrompt},
		{Role: schema.User, UserInputMultiContent: content},
	})
}
// RecognizeWithImgBytes runs image recognition on an image supplied as raw
// bytes rather than a URL.
//
// The request sent to the vision model is a system message containing
// systemPrompt, followed by a user message. The user message carries
// userPrompt as text and then the image, embedded inline as base64 data at
// high detail.
//
// imgType is used as the image MIME type (presumably a value such as
// "image/png"; confirm with callers). When it is empty, the type is sniffed
// from the bytes with http.DetectContentType so the inline image is never sent
// without a MIME type.
//
// If imgBytes is empty, an error is returned and the model is not called.
func (c *Client) RecognizeWithImgBytes(ctx context.Context, systemPrompt, userPrompt string, imgBytes []byte, imgType string) (*schema.Message, error) {
	if len(imgBytes) == 0 {
		return nil, errors.New("recognize image bytes: empty image data")
	}

	mimeType := imgType
	if mimeType == "" {
		// Base64 image data without a MIME type is likely rejected (or yields an
		// invalid data URL), so fall back to content sniffing.
		mimeType = http.DetectContentType(imgBytes)
	}

	parts := []schema.MessageInputPart{
		{Type: schema.ChatMessagePartTypeText, Text: userPrompt},
		{
			Type: schema.ChatMessagePartTypeImageURL,
			Image: &schema.MessageInputImage{
				MessagePartCommon: schema.MessagePartCommon{
					MIMEType:   mimeType,
					Base64Data: util.AnyToPoint(base64.StdEncoding.EncodeToString(imgBytes)),
				},
				Detail: schema.ImageURLDetailHigh,
			},
		},
	}

	in := []*schema.Message{
		{Role: schema.System, Content: systemPrompt},
		{Role: schema.User, UserInputMultiContent: parts},
	}
	// Image recognition always uses the vision model.
	return c.vlModel.Generate(ctx, in)
}