// ai_scheduler/internal/domain/llm/provider/vllm/adapter.go

package vllm
import (
	"context"

	eino_openai "github.com/cloudwego/eino-ext/components/model/openai"
	eino_model "github.com/cloudwego/eino/components/model"
	"github.com/cloudwego/eino/schema"

	"ai_scheduler/internal/domain/llm"
)
type Adapter struct{}
func New() *Adapter { return &Adapter{} }
func (a *Adapter) Generate(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) {
cm, err := eino_openai.NewChatModel(ctx, &eino_openai.ChatModelConfig{
BaseURL: opts.Endpoint,
Timeout: opts.Timeout,
Model: opts.Model,
})
if err != nil {
return nil, err
}
var mopts []eino_model.Option
if opts.Temperature != 0 {
mopts = append(mopts, eino_model.WithTemperature(opts.Temperature))
}
if opts.MaxTokens > 0 {
mopts = append(mopts, eino_model.WithMaxTokens(opts.MaxTokens))
}
if opts.Model != "" {
mopts = append(mopts, eino_model.WithModel(opts.Model))
}
if opts.TopP != 0 {
mopts = append(mopts, eino_model.WithTopP(opts.TopP))
}
if len(opts.Stop) > 0 {
mopts = append(mopts, eino_model.WithStop(opts.Stop))
}
return cm.Generate(ctx, input, mopts...)
}
func (a *Adapter) Stream(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.StreamReader[*schema.Message], error) {
cm, err := eino_openai.NewChatModel(ctx, &eino_openai.ChatModelConfig{
BaseURL: opts.Endpoint,
Timeout: opts.Timeout,
Model: opts.Model,
})
if err != nil {
return nil, err
}
var mopts []eino_model.Option
if opts.Temperature != 0 {
mopts = append(mopts, eino_model.WithTemperature(opts.Temperature))
}
if opts.MaxTokens > 0 {
mopts = append(mopts, eino_model.WithMaxTokens(opts.MaxTokens))
}
if opts.Model != "" {
mopts = append(mopts, eino_model.WithModel(opts.Model))
}
if opts.TopP != 0 {
mopts = append(mopts, eino_model.WithTopP(opts.TopP))
}
if len(opts.Stop) > 0 {
mopts = append(mopts, eino_model.WithStop(opts.Stop))
}
return cm.Stream(ctx, input, mopts...)
}