diff --git a/cmd/server/wire.go b/cmd/server/wire.go index bcd1d2c..8a29212 100644 --- a/cmd/server/wire.go +++ b/cmd/server/wire.go @@ -4,15 +4,16 @@ package main import ( - "ai_scheduler/internal/biz" - "ai_scheduler/internal/config" - "ai_scheduler/internal/data/impl" - "ai_scheduler/internal/pkg" - "ai_scheduler/internal/server" - "ai_scheduler/internal/services" - "ai_scheduler/internal/tools" - "ai_scheduler/internal/tools_bot" - "ai_scheduler/utils" + "ai_scheduler/internal/biz" + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/server" + "ai_scheduler/internal/services" + "ai_scheduler/internal/domain/workflow" + "ai_scheduler/internal/tools" + "ai_scheduler/internal/tools_bot" + "ai_scheduler/utils" "github.com/gofiber/fiber/v2/log" "github.com/google/wire" @@ -20,16 +21,17 @@ import ( // InitializeApp 初始化应用程序 func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), error) { - panic(wire.Build( - server.ProviderSetServer, - llm.ProviderSet, - tools.ProviderSetTools, - pkg.ProviderSetClient, - services.ProviderSetServices, - biz.ProviderSetBiz, - impl.ProviderImpl, - utils.ProviderUtils, - tools_bot.ProviderSetBotTools, - )) + panic(wire.Build( + server.ProviderSetServer, + llm.ProviderSet, + workflow.ProviderSetWorkflow, + tools.ProviderSetTools, + pkg.ProviderSetClient, + services.ProviderSetServices, + biz.ProviderSetBiz, + impl.ProviderImpl, + utils.ProviderUtils, + tools_bot.ProviderSetBotTools, + )) } diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index fa7b3c5..2028f54 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -7,7 +7,7 @@ import ( errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" - "ai_scheduler/internal/domain/workflow/zltx" + "ai_scheduler/internal/domain/workflow/runtime" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" "ai_scheduler/internal/pkg/l_request" @@ -24,11 +24,12 @@ import ( ) type Handle struct { - Ollama *llm_service.OllamaService - toolManager *tools.Manager - Bot *tools_bot.BotTool - conf *config.Config - sessionImpl *impl.SessionImpl + Ollama *llm_service.OllamaService + toolManager *tools.Manager + Bot *tools_bot.BotTool + conf *config.Config + sessionImpl *impl.SessionImpl + workflowManager *runtime.Registry } func NewHandle( @@ -37,13 +38,15 @@ func NewHandle( conf *config.Config, sessionImpl *impl.SessionImpl, dTalkBot *tools_bot.BotTool, + workflowManager *runtime.Registry, ) *Handle { return &Handle{ - Ollama: Ollama, - toolManager: toolManager, - conf: conf, - sessionImpl: sessionImpl, - Bot: dTalkBot, + Ollama: Ollama, + toolManager: toolManager, + conf: conf, + sessionImpl: sessionImpl, + Bot: dTalkBot, + workflowManager: workflowManager, } } @@ -285,17 +288,24 @@ func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.Require func (r *Handle) handleEinoWorkflow(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { // token 写入ctx ctx = util.SetTokenToContext(ctx, requireData.Auth) - - // 构建工作流 - todo 示例,后续抽象出来 - zltxWorkflow, err := zltx.BuildOrderAfterResellerBatchWorkflow(ctx, r.conf.Tools.ZltxOrderAfterSaleResellerBatch) - if err != nil { - return + // 解析入参:workflow_id 与 input + var params map[string]any + if len(requireData.Match.Parameters) > 0 { + _ = json.Unmarshal([]byte(requireData.Match.Parameters), ¶ms) } - - // 工作流执行 - _, err = zltxWorkflow.Invoke(ctx, zltx.OrderAfterSaleResellerBatchWorkflowInput{}) - - return + wfID, _ := params["workflow_id"].(string) + input, _ := params["input"].(map[string]any) + if wfID == "" { + return fmt.Errorf("workflow_id 不能为空") + } + entitys.ResLoading(requireData.Ch, requireData.Task.Index, "正在执行工作流") + res, err := r.workflowManager.Invoke(ctx, wfID, input) + if err != nil { + return err + } + b, _ := json.Marshal(res) + entitys.ResJson(requireData.Ch, requireData.Task.Index, string(b)) + return nil } // 权限验证 diff --git a/internal/biz/provider_set.go b/internal/biz/provider_set.go index aefa3ce..1bdc0f7 100644 --- a/internal/biz/provider_set.go +++ b/internal/biz/provider_set.go @@ -1,10 +1,10 @@ package biz import ( - "ai_scheduler/internal/biz/do" - "ai_scheduler/internal/biz/llm_service" - - "github.com/google/wire" + "ai_scheduler/internal/biz/do" + "ai_scheduler/internal/biz/llm_service" + + "github.com/google/wire" ) var ProviderSetBiz = wire.NewSet( @@ -15,7 +15,7 @@ var ProviderSetBiz = wire.NewSet( llm_service.NewOllamaGenerate, //handle.NewHandle, do.NewDo, - do.NewHandle, + do.NewHandle, NewTaskBiz, NewDingTalkBotBiz, ) diff --git a/internal/domain/workflow/manager.go b/internal/domain/workflow/manager.go deleted file mode 100644 index 0e59ea2..0000000 --- a/internal/domain/workflow/manager.go +++ /dev/null @@ -1 +0,0 @@ -package workflow diff --git a/internal/domain/workflow/provider_set.go b/internal/domain/workflow/provider_set.go new file mode 100644 index 0000000..c728b44 --- /dev/null +++ b/internal/domain/workflow/provider_set.go @@ -0,0 +1,22 @@ +package workflow + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/workflow/runtime" + "ai_scheduler/internal/pkg/utils_ollama" + + "github.com/google/wire" +) + +var ProviderSetWorkflow = wire.NewSet(NewRegistry) + +// NewRegistry 注入共享依赖并注册默认 Registry,确保自注册工作流可被发现 +func NewRegistry(conf *config.Config, llm *utils_ollama.Client) *runtime.Registry { + // 步骤1:设置运行时依赖(配置与LLM客户端),供工作流工厂在首次实例化时使用;必须在任何调用 Invoke 之前完成,否则会触发 "deps not set" + runtime.SetDeps(&runtime.Deps{Conf: conf, LLM: llm}) + // 步骤2:创建新的工作流注册表;注册表负责按工作流ID惰性实例化并缓存单例实例,保障并发访问下的安全 + r := runtime.NewRegistry() + // 步骤3:将该注册表设置为全局默认,便于通过 runtime.Default() 获取;自注册的工作流可通过默认注册表被发现并调用 + runtime.SetDefault(r) + return r +} diff --git a/internal/domain/workflow/register_all_gen.go b/internal/domain/workflow/register_all_gen.go new file mode 100644 index 0000000..d30ba64 --- /dev/null +++ b/internal/domain/workflow/register_all_gen.go @@ -0,0 +1,7 @@ +package workflow + +import ( + // 手工维护:在此空导入工作流包以触发其 init() 自注册 + // 新增工作流时,只需在这里添加一行 `_ ""` + _ "ai_scheduler/internal/domain/workflow/zltx" +) diff --git a/internal/domain/workflow/registry.go b/internal/domain/workflow/registry.go new file mode 100644 index 0000000..cbde3b6 --- /dev/null +++ b/internal/domain/workflow/registry.go @@ -0,0 +1,12 @@ +package workflow + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/utils_ollama" +) + +// 仅声明依赖结构,避免在 workflow 包内实现注册中心逻辑导致循环依赖 +type Deps struct { + Conf *config.Config + LLM *utils_ollama.Client +} diff --git a/internal/domain/workflow/runtime/registry.go b/internal/domain/workflow/runtime/registry.go new file mode 100644 index 0000000..1260173 --- /dev/null +++ b/internal/domain/workflow/runtime/registry.go @@ -0,0 +1,95 @@ +package runtime + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/utils_ollama" + "context" + "errors" + "sync" +) + +type Workflow interface { + ID() string + Schema() map[string]any + Invoke(ctx context.Context, input map[string]any) (map[string]any, error) +} + +type Deps struct { + Conf *config.Config + LLM *utils_ollama.Client +} + +type Factory func(deps *Deps) (Workflow, error) + +var ( + regMu sync.RWMutex + factories = map[string]Factory{} + deps *Deps + defaultReg *Registry +) + +func Register(id string, f Factory) { + regMu.Lock() + factories[id] = f + regMu.Unlock() +} + +func SetDeps(d *Deps) { + regMu.Lock() + deps = d + regMu.Unlock() +} + +type Registry struct { + mu sync.RWMutex + instances map[string]Workflow +} + +func NewRegistry() *Registry { + return &Registry{instances: make(map[string]Workflow)} +} + +func SetDefault(r *Registry) { + regMu.Lock() + defaultReg = r + regMu.Unlock() +} + +func Default() *Registry { + regMu.RLock() + r := defaultReg + regMu.RUnlock() + return r +} + +func (r *Registry) Invoke(ctx context.Context, id string, input map[string]any) (map[string]any, error) { + if input == nil { + input = map[string]any{} + } + regMu.RLock() + f, ok := factories[id] + regMu.RUnlock() + if !ok { + return nil, errors.New("workflow not found: " + id) + } + + r.mu.RLock() + w, exists := r.instances[id] + r.mu.RUnlock() + + if !exists { + if deps == nil { + return nil, errors.New("deps not set") + } + nw, err := f(deps) + if err != nil { + return nil, err + } + r.mu.Lock() + r.instances[id] = nw + w = nw + r.mu.Unlock() + } + + return w.Invoke(ctx, input) +} diff --git a/internal/domain/workflow/zltx/crontab_supplier.go b/internal/domain/workflow/zltx/crontab_supplier.go deleted file mode 100644 index 29ee896..0000000 --- a/internal/domain/workflow/zltx/crontab_supplier.go +++ /dev/null @@ -1 +0,0 @@ -package zltx diff --git a/internal/domain/workflow/zltx/order_after_reseller_batch.go b/internal/domain/workflow/zltx/order_after_reseller_batch.go index 5c3081a..e00b726 100644 --- a/internal/domain/workflow/zltx/order_after_reseller_batch.go +++ b/internal/domain/workflow/zltx/order_after_reseller_batch.go @@ -3,19 +3,66 @@ package zltx import ( "ai_scheduler/internal/config" toolZoarb "ai_scheduler/internal/domain/tools/zltx/order_after_reseller_batch" + "ai_scheduler/internal/domain/workflow/runtime" "context" "errors" "github.com/cloudwego/eino/compose" ) +func init() { + runtime.Register("zltx.orderAfterSaleResellerBatch", func(d *runtime.Deps) (runtime.Workflow, error) { + return &orderAfterSaleResellerBatch{cfg: d.Conf.Tools.ZltxOrderAfterSaleResellerBatch}, nil + }) +} + +type orderAfterSaleResellerBatch struct { + cfg config.ToolConfig +} + +// ID 返回工作流唯一标识 +func (o *orderAfterSaleResellerBatch) ID() string { return "zltx.orderAfterSaleResellerBatch" } + +// Schema 返回入参约束(用于校验/表单生成) +func (o *orderAfterSaleResellerBatch) Schema() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{"orderNumber": map[string]any{"type": "array", "items": map[string]any{"type": "string"}}}, + "required": []string{"orderNumber"}, + } +} + +// Invoke 调用原有编排工作流并规范化输出 +func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) { + // 构建工作流 + chain, err := o.buildWorkflow(ctx) + if err != nil { + return nil, err + } + + var in OrderAfterSaleResellerBatchWorkflowInput + if v, ok := input["orderNumber"].([]string); ok { + in.OrderNumber = v + } + _, err = chain.Invoke(ctx, in) + if err != nil { + return nil, err + } + + // 工作流 callback + + // 不关心输出,全部在途中输出 + return nil, nil +} + type OrderAfterSaleResellerBatchWorkflowInput struct { OrderNumber []string `json:"orderNumber"` } var ErrInvalidOrderNumbers = errors.New("orderNumber 不能为空") -func BuildOrderAfterResellerBatchWorkflow(ctx context.Context, cfg config.ToolConfig) (compose.Runnable[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData], error) { +// buildWorkflow 构建工作流 +func (o *orderAfterSaleResellerBatch) buildWorkflow(ctx context.Context) (compose.Runnable[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData], error) { // 定义工作流、出入参 c := compose.NewChain[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData]() @@ -29,7 +76,7 @@ func BuildOrderAfterResellerBatchWorkflow(ctx context.Context, cfg config.ToolCo // 2.调用工具 c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in OrderAfterSaleResellerBatchWorkflowInput) (toolZoarb.OrderAfterSaleResellerBatchData, error) { - return toolZoarb.Call(ctx, cfg, in.OrderNumber) + return toolZoarb.Call(ctx, o.cfg, in.OrderNumber) })) // 3.结果映射与整形 diff --git a/internal/domain/workflow/zltx/order_after_reseller_batch_test.go b/internal/domain/workflow/zltx/order_after_reseller_batch_test.go deleted file mode 100644 index be56786..0000000 --- a/internal/domain/workflow/zltx/order_after_reseller_batch_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package zltx - -import ( - "ai_scheduler/internal/config" - "context" - "errors" - "strings" - "testing" -) - -func TestOrderAfterResellerBatch_InvalidOrderNumbers(t *testing.T) { - ctx := context.Background() - cfg := config.ToolConfig{} - - run, err := BuildOrderAfterResellerBatchWorkflow(ctx, cfg) - if err != nil { - t.Fatalf("build workflow error: %v", err) - } - - _, err = run.Invoke(ctx, OrderAfterSaleResellerBatchWorkflowInput{}) - if err == nil { - t.Fatalf("expected error, got nil") - } - if !errors.Is(err, ErrInvalidOrderNumbers) { - t.Fatalf("expected ErrInvalidOrderNumbers, got %v", err) - } -} - -func TestOrderAfterResellerBatch_ContextTokenRequired(t *testing.T) { - ctx := context.Background() - cfg := config.ToolConfig{} - - run, err := BuildOrderAfterResellerBatchWorkflow(ctx, cfg) - if err != nil { - t.Fatalf("build workflow error: %v", err) - } - - in := OrderAfterSaleResellerBatchWorkflowInput{OrderNumber: []string{"123"}} - _, err = run.Invoke(ctx, in) - if err == nil { - t.Fatalf("expected error, got nil") - } - if !strings.Contains(err.Error(), "token 未注入") { - t.Fatalf("expected contains 'token 未注入', got %v", err) - } -} diff --git a/internal/tools/manager.go b/internal/tools/manager.go index 2ebbb69..406d11c 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -2,7 +2,6 @@ package tools import ( "ai_scheduler/internal/config" - "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/utils_ollama" zltxtool "ai_scheduler/internal/tools/zltx" @@ -101,23 +100,23 @@ func (m *Manager) GetTool(name string) (entitys.Tool, bool) { } // GetAllTools 获取所有工具 -func (m *Manager) GetAllTools() []entitys.Tool { - tools := make([]entitys.Tool, 0, len(m.tools)) - for _, tool := range m.tools { - tools = append(tools, tool) - } - return tools -} +// func (m *Manager) GetAllTools() []entitys.Tool { +// tools := make([]entitys.Tool, 0, len(m.tools)) +// for _, tool := range m.tools { +// tools = append(tools, tool) +// } +// return tools +// } -// GetToolDefinitions 获取所有工具定义 -func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition { - definitions := make([]entitys.ToolDefinition, 0, len(m.tools)) - for _, tool := range m.tools { - definitions = append(definitions, tool.Definition()) - } +// // GetToolDefinitions 获取所有工具定义 +// func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition { +// definitions := make([]entitys.ToolDefinition, 0, len(m.tools)) +// for _, tool := range m.tools { +// definitions = append(definitions, tool.Definition()) +// } - return definitions -} +// return definitions +// } // ExecuteTool 执行工具 func (m *Manager) ExecuteTool(ctx context.Context, name string, requireData *entitys.RequireData) error {