97 lines
2.6 KiB
Go
97 lines
2.6 KiB
Go
package service
|
|
|
|
import (
|
|
"ai_scheduler/internal/config"
|
|
"ai_scheduler/internal/domain/llm"
|
|
"ai_scheduler/internal/domain/llm/capability"
|
|
"ai_scheduler/internal/domain/llm/provider"
|
|
"context"
|
|
|
|
"github.com/cloudwego/eino/schema"
|
|
)
|
|
|
|
// LLMService dispatches LLM requests (chat, streaming chat, vision, intent)
// to a provider chosen per call by capability routing over its configuration.
type LLMService struct {
	// cfg is handed to capability.Route on every call to select the provider
	// and the default options for the requested capability.
	cfg *config.Config
}
|
|
|
|
// NewLLMService returns an LLMService that routes every request using cfg.
// The pointer is stored as given (not copied), and nil is not rejected here.
func NewLLMService(cfg *config.Config) *LLMService {
	svc := &LLMService{}
	svc.cfg = cfg
	return svc
}
|
|
|
|
func (s *LLMService) Chat(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) {
|
|
choice, routeOpts, err := capability.Route(s.cfg, capability.Chat)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
mergeOptions(&routeOpts, opts)
|
|
f := provider.Get(choice.Provider)
|
|
if f == nil {
|
|
return nil, llm.ErrProviderNotFound
|
|
}
|
|
return f().Generate(ctx, input, routeOpts)
|
|
}
|
|
|
|
func (s *LLMService) ChatStream(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.StreamReader[*schema.Message], error) {
|
|
choice, routeOpts, err := capability.Route(s.cfg, capability.Chat)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
routeOpts.Stream = true
|
|
mergeOptions(&routeOpts, opts)
|
|
f := provider.Get(choice.Provider)
|
|
if f == nil {
|
|
return nil, llm.ErrProviderNotFound
|
|
}
|
|
return f().Stream(ctx, input, routeOpts)
|
|
}
|
|
|
|
func (s *LLMService) Vision(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) {
|
|
choice, routeOpts, err := capability.Route(s.cfg, capability.Vision)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
mergeOptions(&routeOpts, opts)
|
|
f := provider.Get(choice.Provider)
|
|
if f == nil {
|
|
return nil, llm.ErrProviderNotFound
|
|
}
|
|
return f().Generate(ctx, input, routeOpts)
|
|
}
|
|
|
|
func (s *LLMService) Intent(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) {
|
|
choice, routeOpts, err := capability.Route(s.cfg, capability.Intent)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
mergeOptions(&routeOpts, opts)
|
|
f := provider.Get(choice.Provider)
|
|
if f == nil {
|
|
return nil, llm.ErrProviderNotFound
|
|
}
|
|
return f().Generate(ctx, input, routeOpts)
|
|
}
|
|
|
|
func mergeOptions(base *llm.Options, override llm.Options) {
|
|
if override.Model != "" {
|
|
base.Model = override.Model
|
|
}
|
|
if override.MaxTokens > 0 {
|
|
base.MaxTokens = override.MaxTokens
|
|
}
|
|
if override.Temperature != 0 {
|
|
base.Temperature = override.Temperature
|
|
}
|
|
if override.Timeout > 0 {
|
|
base.Timeout = override.Timeout
|
|
}
|
|
if len(override.Modalities) > 0 {
|
|
base.Modalities = override.Modalities
|
|
}
|
|
if override.SystemPrompt != "" {
|
|
base.SystemPrompt = override.SystemPrompt
|
|
}
|
|
if override.TopP != 0 {
|
|
base.TopP = override.TopP
|
|
}
|
|
if len(override.Stop) > 0 {
|
|
base.Stop = override.Stop
|
|
}
|
|
base.Stream = base.Stream || override.Stream
|
|
}
|