// ai_scheduler/internal/domain/llm/service/service.go
package service
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/domain/llm"
"ai_scheduler/internal/domain/llm/capability"
"ai_scheduler/internal/domain/llm/provider"
"context"
"github.com/cloudwego/eino/schema"
)
// LLMService routes LLM requests (chat, vision, intent) to a configured
// provider, merging per-call option overrides over routed defaults.
type LLMService struct{ cfg *config.Config }
// NewLLMService returns an LLMService backed by the given configuration.
func NewLLMService(cfg *config.Config) *LLMService {
	return &LLMService{cfg: cfg}
}
// Chat performs a non-streaming chat completion. The capability router
// picks a provider and baseline options for the Chat capability; non-zero
// fields in opts override those routed defaults.
func (s *LLMService) Chat(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) {
	route, effective, err := capability.Route(s.cfg, capability.Chat)
	if err != nil {
		return nil, err
	}
	mergeOptions(&effective, opts)
	factory := provider.Get(route.Provider)
	if factory == nil {
		return nil, llm.ErrProviderNotFound
	}
	return factory().Generate(ctx, input, effective)
}
// ChatStream performs a streaming chat completion. Routing works exactly
// as in Chat, but the Stream flag is forced on before option merging (the
// merge keeps Stream sticky, so the caller cannot disable it here).
func (s *LLMService) ChatStream(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.StreamReader[*schema.Message], error) {
	route, effective, err := capability.Route(s.cfg, capability.Chat)
	if err != nil {
		return nil, err
	}
	effective.Stream = true
	mergeOptions(&effective, opts)
	factory := provider.Get(route.Provider)
	if factory == nil {
		return nil, llm.ErrProviderNotFound
	}
	return factory().Stream(ctx, input, effective)
}
// Vision performs a non-streaming completion routed through the Vision
// capability; non-zero fields in opts override the routed defaults.
func (s *LLMService) Vision(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) {
	route, effective, err := capability.Route(s.cfg, capability.Vision)
	if err != nil {
		return nil, err
	}
	mergeOptions(&effective, opts)
	factory := provider.Get(route.Provider)
	if factory == nil {
		return nil, llm.ErrProviderNotFound
	}
	return factory().Generate(ctx, input, effective)
}
// Intent performs a non-streaming completion routed through the Intent
// capability; non-zero fields in opts override the routed defaults.
func (s *LLMService) Intent(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) {
	route, effective, err := capability.Route(s.cfg, capability.Intent)
	if err != nil {
		return nil, err
	}
	mergeOptions(&effective, opts)
	factory := provider.Get(route.Provider)
	if factory == nil {
		return nil, llm.ErrProviderNotFound
	}
	return factory().Generate(ctx, input, effective)
}
// mergeOptions overlays src onto dst in place. A field in src takes effect
// only when it holds a non-zero value, so a caller cannot explicitly force
// Temperature, TopP, or MaxTokens back to zero — the routed default wins.
// NOTE(review): if a literal-zero override is ever needed, llm.Options
// would have to switch those fields to pointers.
// Stream is sticky: it stays enabled if either side requested streaming.
func mergeOptions(dst *llm.Options, src llm.Options) {
	// String fields.
	if src.Model != "" {
		dst.Model = src.Model
	}
	if src.SystemPrompt != "" {
		dst.SystemPrompt = src.SystemPrompt
	}
	// Numeric fields.
	if src.MaxTokens > 0 {
		dst.MaxTokens = src.MaxTokens
	}
	if src.Temperature != 0 {
		dst.Temperature = src.Temperature
	}
	if src.TopP != 0 {
		dst.TopP = src.TopP
	}
	if src.Timeout > 0 {
		dst.Timeout = src.Timeout
	}
	// Slice fields (replaced wholesale, not appended).
	if len(src.Modalities) > 0 {
		dst.Modalities = src.Modalities
	}
	if len(src.Stop) > 0 {
		dst.Stop = src.Stop
	}
	dst.Stream = dst.Stream || src.Stream
}