From ca671694f95cf20416743ee90ef01e7e351e5021 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Sat, 28 Feb 2026 14:16:53 +0800 Subject: [PATCH] =?UTF-8?q?fix:=201.=20=E8=B0=83=E6=95=B4vllm=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E5=8F=8A=E5=85=B6=E7=9B=B8=E5=85=B3=202.=20=E6=84=8F?= =?UTF-8?q?=E5=9B=BE=E8=AF=86=E5=88=AB=E6=A8=A1=E5=9E=8B=E5=88=87=E6=8D=A2?= =?UTF-8?q?=20ollama=20->=20vllm?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config.yaml | 14 ++- config/config_env.yaml | 14 ++- config/config_test.yaml | 14 ++- internal/biz/do/handle.go | 6 +- internal/biz/llm_service/vllm.go | 153 ++++++++++++++++++++++++++++++ internal/biz/provider_set.go | 1 + internal/config/config.go | 7 +- internal/pkg/utils_vllm/client.go | 52 ++++++++-- 8 files changed, 237 insertions(+), 24 deletions(-) create mode 100644 internal/biz/llm_service/vllm.go diff --git a/config/config.yaml b/config/config.yaml index 57ce0fc..0636216 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -17,10 +17,16 @@ ollama: format: "json" vllm: - base_url: "http://172.17.0.1:8001/v1" - vl_model: "qwen2.5-vl-3b-awq" - timeout: "120s" - level: "info" + vl_model: + base_url: "http://192.168.6.115:8001/v1" + model: "qwen2.5-vl-3b-awq" + timeout: "120s" + level: "info" + text_model: + base_url: "http://192.168.6.115:8002/v1" + model: "qwen3-8b-fp8" + timeout: "120s" + level: "info" coze: base_url: "https://api.coze.cn" diff --git a/config/config_env.yaml b/config/config_env.yaml index 532191a..00c5349 100644 --- a/config/config_env.yaml +++ b/config/config_env.yaml @@ -14,10 +14,16 @@ ollama: format: "json" vllm: - base_url: "http://117.175.169.61:16001/v1" - vl_model: "qwen2.5-vl-3b-awq" - timeout: "120s" - level: "info" + vl_model: + base_url: "http://192.168.6.115:8001/v1" + model: "qwen2.5-vl-3b-awq" + timeout: "120s" + level: "info" + text_model: + base_url: "http://192.168.6.115:8002/v1" + model: "qwen3-8b-fp8" + timeout: "120s" + level: "info" coze: base_url: "https://api.coze.cn" diff --git a/config/config_test.yaml b/config/config_test.yaml index 563cd0e..82e7df1 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -14,10 +14,16 @@ ollama: format: "json" vllm: - base_url: "http://host.docker.internal:8001/v1" - vl_model: "qwen2.5-vl-3b-awq" - timeout: "120s" - level: "info" + vl_model: + base_url: "http://192.168.6.115:8001/v1" + model: "qwen2.5-vl-3b-awq" + timeout: "120s" + level: "info" + text_model: + base_url: "http://192.168.6.115:8002/v1" + model: "qwen3-8b-fp8" + timeout: "120s" + level: "info" coze: base_url: "https://api.coze.cn" diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index e312d04..828c059 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -36,6 +36,7 @@ import ( type Handle struct { Ollama *llm_service.OllamaService + Vllm *llm_service.VllmService toolManager *tools.Manager conf *config.Config sessionImpl *impl.SessionImpl @@ -47,6 +48,7 @@ type Handle struct { func NewHandle( Ollama *llm_service.OllamaService, + Vllm *llm_service.VllmService, toolManager *tools.Manager, conf *config.Config, sessionImpl *impl.SessionImpl, @@ -57,6 +59,7 @@ func NewHandle( ) *Handle { return &Handle{ Ollama: Ollama, + Vllm: Vllm, toolManager: toolManager, conf: conf, sessionImpl: sessionImpl, @@ -72,7 +75,8 @@ func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptPr prompt, err := promptProcessor.CreatePrompt(ctx, rec) //意图识别 - recognizeMsg, err := r.Ollama.IntentRecognize(ctx, &entitys.ToolSelect{ + // recognizeMsg, err := r.Ollama.IntentRecognize(ctx, &entitys.ToolSelect{ + recognizeMsg, err := r.Vllm.IntentRecognize(ctx, &entitys.ToolSelect{ Prompt: prompt, Tools: rec.Tasks, }) diff --git a/internal/biz/llm_service/vllm.go b/internal/biz/llm_service/vllm.go new file mode 100644 index 0000000..c054f27 --- /dev/null +++ b/internal/biz/llm_service/vllm.go @@ -0,0 +1,153 @@ +package llm_service + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/utils_vllm" + "context" + "encoding/base64" + "errors" + "fmt" + "strings" + + "github.com/cloudwego/eino/schema" + "github.com/ollama/ollama/api" +) + +type VllmService struct { + client *utils_vllm.Client + config *config.Config +} + +func NewVllmService( + client *utils_vllm.Client, + config *config.Config, +) *VllmService { + return &VllmService{ + client: client, + config: config, + } +} + +func (s *VllmService) IntentRecognize(ctx context.Context, req *entitys.ToolSelect) (msg string, err error) { + msgs := s.convertMessages(req.Prompt) + tools := s.convertTools(req.Tools) + + resp, err := s.client.ToolSelect(ctx, msgs, tools) + if err != nil { + return + } + + if resp.Content == "" { + if len(resp.ToolCalls) > 0 { + call := resp.ToolCalls[0] + var matchFromTools = &entitys.Match{ + Confidence: 1, + Index: call.Function.Name, + Parameters: call.Function.Arguments, + IsMatch: true, + } + msg = pkg.JsonStringIgonErr(matchFromTools) + } else { + err = errors.New("不太明白你想表达的意思呢,可以在仔细描述一下您所需要的内容吗,感谢感谢") + return + } + } else { + msg = resp.Content + } + return +} + +func (s *VllmService) convertMessages(prompts []api.Message) []*schema.Message { + msgs := make([]*schema.Message, 0, len(prompts)) + for _, p := range prompts { + msg := &schema.Message{ + Role: schema.RoleType(p.Role), + Content: p.Content, + } + + // 这里实际应该不会走进来 + if len(p.Images) > 0 { + parts := []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeText, Text: p.Content}, + } + for _, imgData := range p.Images { + b64 := base64.StdEncoding.EncodeToString(imgData) + mimeType := "image/jpeg" + parts = append(parts, schema.MessageInputPart{ + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageInputImage{ + MessagePartCommon: schema.MessagePartCommon{ + MIMEType: mimeType, + Base64Data: &b64, + }, + }, + }) + } + msg.UserInputMultiContent = parts + } + msgs = append(msgs, msg) + } + return msgs +} + +func (s *VllmService) convertTools(tasks []entitys.RegistrationTask) []*schema.ToolInfo { + tools := make([]*schema.ToolInfo, 0, len(tasks)) + for _, task := range tasks { + params := make(map[string]*schema.ParameterInfo) + for k, v := range task.TaskConfigDetail.Param.Properties { + dt := schema.String + + // Handle v.Type dynamically to support both string and []string (compiler suggests []string) + // Using fmt.Sprint handles both cases safely without knowing exact type structure + typeStr := fmt.Sprintf("%v", v.Type) + typeStr = strings.Trim(typeStr, "[]") // normalize "[string]" -> "string" + + switch typeStr { + case "string": + dt = schema.String + case "integer", "int": + dt = schema.Integer + case "number", "float": + dt = schema.Number + case "boolean", "bool": + dt = schema.Boolean + case "object": + dt = schema.Object + case "array": + dt = schema.Array + } + + required := false + for _, r := range task.TaskConfigDetail.Param.Required { + if r == k { + required = true + break + } + } + + desc := v.Description + if len(v.Enum) > 0 { + var enumStrs []string + for _, e := range v.Enum { + enumStrs = append(enumStrs, fmt.Sprintf("%v", e)) + } + desc += " Enum: " + strings.Join(enumStrs, ", ") + } + + params[k] = &schema.ParameterInfo{ + Type: dt, + Desc: desc, + Required: required, + } + } + + tools = append(tools, &schema.ToolInfo{ + Name: task.Name, + Desc: task.Desc, + ParamsOneOf: schema.NewParamsOneOfByParams(params), + }) + } + return tools +} diff --git a/internal/biz/provider_set.go b/internal/biz/provider_set.go index 6f95898..2054c54 100644 --- a/internal/biz/provider_set.go +++ b/internal/biz/provider_set.go @@ -13,6 +13,7 @@ var ProviderSetBiz = wire.NewSet( NewChatHistoryBiz, //llm_service.NewLangChainGenerate, llm_service.NewOllamaGenerate, + llm_service.NewVllmService, //handle.NewHandle, do.NewDo, do.NewHandle, diff --git a/internal/config/config.go b/internal/config/config.go index 6e9a2e6..52932bb 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -122,8 +122,13 @@ type OllamaConfig struct { } type VllmConfig struct { + VLModel VllmModel `mapstructure:"vl_model"` + TextModel VllmModel `mapstructure:"text_model"` +} + +type VllmModel struct { BaseURL string `mapstructure:"base_url"` - VlModel string `mapstructure:"vl_model"` + Model string `mapstructure:"model"` Timeout time.Duration `mapstructure:"timeout"` Level string `mapstructure:"level"` } diff --git a/internal/pkg/utils_vllm/client.go b/internal/pkg/utils_vllm/client.go index c8c4aec..6926b21 100644 --- a/internal/pkg/utils_vllm/client.go +++ b/internal/pkg/utils_vllm/client.go @@ -7,33 +7,63 @@ import ( "encoding/base64" "github.com/cloudwego/eino-ext/components/model/openai" + "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/schema" ) type Client struct { - model *openai.ChatModel - config *config.Config + vlModel *openai.ChatModel + generateModel *openai.ChatModel + config *config.Config } func NewClient(config *config.Config) (*Client, func(), error) { - m, err := openai.NewChatModel(context.Background(), &openai.ChatModelConfig{ - BaseURL: config.Vllm.BaseURL, - Model: config.Vllm.VlModel, - Timeout: config.Vllm.Timeout, + // 初始化视觉模型 + vl, err := openai.NewChatModel(context.Background(), &openai.ChatModelConfig{ + BaseURL: config.Vllm.VLModel.BaseURL, + Model: config.Vllm.VLModel.Model, + Timeout: config.Vllm.VLModel.Timeout, }) if err != nil { return nil, nil, err } - c := &Client{model: m, config: config} + + // 初始化生成模型 + gen, err := openai.NewChatModel(context.Background(), &openai.ChatModelConfig{ + BaseURL: config.Vllm.TextModel.BaseURL, + Model: config.Vllm.TextModel.Model, + Timeout: config.Vllm.TextModel.Timeout, + ExtraFields: map[string]any{ + "chat_template_kwargs": map[string]any{ + "enable_thinking": false, + }, + }, + }) + if err != nil { + return nil, nil, err + } + + c := &Client{ + vlModel: vl, + generateModel: gen, + config: config, + } cleanup := func() {} return c, cleanup, nil } func (c *Client) Chat(ctx context.Context, msgs []*schema.Message) (*schema.Message, error) { - return c.model.Generate(ctx, msgs) + // 默认聊天使用生成模型 + return c.generateModel.Generate(ctx, msgs) +} + +func (c *Client) ToolSelect(ctx context.Context, msgs []*schema.Message, tools []*schema.ToolInfo) (*schema.Message, error) { + // 工具选择使用生成模型 + return c.generateModel.Generate(ctx, msgs, model.WithTools(tools)) } func (c *Client) RecognizeWithImg(ctx context.Context, systemPrompt, userPrompt string, imgURLs []string) (*schema.Message, error) { + // 图片识别使用视觉模型 in := []*schema.Message{ { Role: schema.System, @@ -58,11 +88,12 @@ func (c *Client) RecognizeWithImg(ctx context.Context, systemPrompt, userPrompt } in[1].UserInputMultiContent = parts - return c.model.Generate(ctx, in) + return c.vlModel.Generate(ctx, in) } // 识别图片by二进制文件 func (c *Client) RecognizeWithImgBytes(ctx context.Context, systemPrompt, userPrompt string, imgBytes []byte, imgType string) (*schema.Message, error) { + // 图片识别使用视觉模型 in := []*schema.Message{ { Role: schema.System, @@ -82,9 +113,10 @@ func (c *Client) RecognizeWithImgBytes(ctx context.Context, systemPrompt, userPr MIMEType: imgType, Base64Data: util.AnyToPoint(base64.StdEncoding.EncodeToString(imgBytes)), }, + Detail: schema.ImageURLDetailHigh, }, }) in[1].UserInputMultiContent = parts - return c.model.Generate(ctx, in) + return c.vlModel.Generate(ctx, in) }