39 lines
1.3 KiB
Go
39 lines
1.3 KiB
Go
package pipeline
|
|
|
|
import (
|
|
"context"
|
|
"github.com/cloudwego/eino/compose"
|
|
"github.com/cloudwego/eino/schema"
|
|
"ai_scheduler/internal/config"
|
|
"ai_scheduler/internal/domain/llm"
|
|
"ai_scheduler/internal/domain/llm/capability"
|
|
"ai_scheduler/internal/domain/llm/provider"
|
|
"ai_scheduler/internal/domain/llm/provider/ollama"
|
|
"ai_scheduler/internal/domain/llm/provider/vllm"
|
|
)
|
|
|
|
func init() {
|
|
provider.Register("ollama", func() provider.Adapter { return ollama.New() })
|
|
provider.Register("vllm", func() provider.Adapter { return vllm.New() })
|
|
}
|
|
|
|
func BuildChat(ctx context.Context, cfg *config.Config) (compose.Runnable[[]*schema.Message, *schema.Message], error) {
|
|
choice, opts, err := capability.Route(cfg, capability.Chat)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err = capability.Validate(capability.Chat, opts); err != nil {
|
|
return nil, err
|
|
}
|
|
f := provider.Get(choice.Provider)
|
|
if f == nil {
|
|
return nil, llm.ErrProviderNotFound
|
|
}
|
|
ad := f()
|
|
c := compose.NewChain[[]*schema.Message, *schema.Message]()
|
|
c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in []*schema.Message) (*schema.Message, error) {
|
|
return ad.Generate(ctx, in, opts)
|
|
}))
|
|
return c.Compile(ctx)
|
|
}
|