fix: 1.eino 工作流方案调整 2.下游批量订单售后接入工作流

This commit is contained in:
fuzhongyun 2025-12-11 18:35:48 +08:00
parent 15c3febe76
commit d310bf8104
12 changed files with 258 additions and 112 deletions

View File

@ -4,15 +4,16 @@
package main package main
import ( import (
"ai_scheduler/internal/biz" "ai_scheduler/internal/biz"
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/impl"
"ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg"
"ai_scheduler/internal/server" "ai_scheduler/internal/server"
"ai_scheduler/internal/services" "ai_scheduler/internal/services"
"ai_scheduler/internal/tools" "ai_scheduler/internal/domain/workflow"
"ai_scheduler/internal/tools_bot" "ai_scheduler/internal/tools"
"ai_scheduler/utils" "ai_scheduler/internal/tools_bot"
"ai_scheduler/utils"
"github.com/gofiber/fiber/v2/log" "github.com/gofiber/fiber/v2/log"
"github.com/google/wire" "github.com/google/wire"
@ -20,16 +21,17 @@ import (
// InitializeApp 初始化应用程序 // InitializeApp 初始化应用程序
func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), error) { func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), error) {
panic(wire.Build( panic(wire.Build(
server.ProviderSetServer, server.ProviderSetServer,
llm.ProviderSet, llm.ProviderSet,
tools.ProviderSetTools, workflow.ProviderSetWorkflow,
pkg.ProviderSetClient, tools.ProviderSetTools,
services.ProviderSetServices, pkg.ProviderSetClient,
biz.ProviderSetBiz, services.ProviderSetServices,
impl.ProviderImpl, biz.ProviderSetBiz,
utils.ProviderUtils, impl.ProviderImpl,
tools_bot.ProviderSetBotTools, utils.ProviderUtils,
)) tools_bot.ProviderSetBotTools,
))
} }

View File

@ -7,7 +7,7 @@ import (
errors "ai_scheduler/internal/data/error" errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model" "ai_scheduler/internal/data/model"
"ai_scheduler/internal/domain/workflow/zltx" "ai_scheduler/internal/domain/workflow/runtime"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/gateway" "ai_scheduler/internal/gateway"
"ai_scheduler/internal/pkg/l_request" "ai_scheduler/internal/pkg/l_request"
@ -24,11 +24,12 @@ import (
) )
type Handle struct { type Handle struct {
Ollama *llm_service.OllamaService Ollama *llm_service.OllamaService
toolManager *tools.Manager toolManager *tools.Manager
Bot *tools_bot.BotTool Bot *tools_bot.BotTool
conf *config.Config conf *config.Config
sessionImpl *impl.SessionImpl sessionImpl *impl.SessionImpl
workflowManager *runtime.Registry
} }
func NewHandle( func NewHandle(
@ -37,13 +38,15 @@ func NewHandle(
conf *config.Config, conf *config.Config,
sessionImpl *impl.SessionImpl, sessionImpl *impl.SessionImpl,
dTalkBot *tools_bot.BotTool, dTalkBot *tools_bot.BotTool,
workflowManager *runtime.Registry,
) *Handle { ) *Handle {
return &Handle{ return &Handle{
Ollama: Ollama, Ollama: Ollama,
toolManager: toolManager, toolManager: toolManager,
conf: conf, conf: conf,
sessionImpl: sessionImpl, sessionImpl: sessionImpl,
Bot: dTalkBot, Bot: dTalkBot,
workflowManager: workflowManager,
} }
} }
@ -285,17 +288,24 @@ func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.Require
func (r *Handle) handleEinoWorkflow(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { func (r *Handle) handleEinoWorkflow(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
// token 写入ctx // token 写入ctx
ctx = util.SetTokenToContext(ctx, requireData.Auth) ctx = util.SetTokenToContext(ctx, requireData.Auth)
// 解析入参workflow_id 与 input
// 构建工作流 - todo 示例,后续抽象出来 var params map[string]any
zltxWorkflow, err := zltx.BuildOrderAfterResellerBatchWorkflow(ctx, r.conf.Tools.ZltxOrderAfterSaleResellerBatch) if len(requireData.Match.Parameters) > 0 {
if err != nil { _ = json.Unmarshal([]byte(requireData.Match.Parameters), &params)
return
} }
wfID, _ := params["workflow_id"].(string)
// 工作流执行 input, _ := params["input"].(map[string]any)
_, err = zltxWorkflow.Invoke(ctx, zltx.OrderAfterSaleResellerBatchWorkflowInput{}) if wfID == "" {
return fmt.Errorf("workflow_id 不能为空")
return }
entitys.ResLoading(requireData.Ch, requireData.Task.Index, "正在执行工作流")
res, err := r.workflowManager.Invoke(ctx, wfID, input)
if err != nil {
return err
}
b, _ := json.Marshal(res)
entitys.ResJson(requireData.Ch, requireData.Task.Index, string(b))
return nil
} }
// 权限验证 // 权限验证

View File

@ -1,10 +1,10 @@
package biz package biz
import ( import (
"ai_scheduler/internal/biz/do" "ai_scheduler/internal/biz/do"
"ai_scheduler/internal/biz/llm_service" "ai_scheduler/internal/biz/llm_service"
"github.com/google/wire" "github.com/google/wire"
) )
var ProviderSetBiz = wire.NewSet( var ProviderSetBiz = wire.NewSet(
@ -15,7 +15,7 @@ var ProviderSetBiz = wire.NewSet(
llm_service.NewOllamaGenerate, llm_service.NewOllamaGenerate,
//handle.NewHandle, //handle.NewHandle,
do.NewDo, do.NewDo,
do.NewHandle, do.NewHandle,
NewTaskBiz, NewTaskBiz,
NewDingTalkBotBiz, NewDingTalkBotBiz,
) )

View File

@ -1 +0,0 @@
package workflow

View File

@ -0,0 +1,22 @@
package workflow
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/domain/workflow/runtime"
"ai_scheduler/internal/pkg/utils_ollama"
"github.com/google/wire"
)
var ProviderSetWorkflow = wire.NewSet(NewRegistry)
// NewRegistry 注入共享依赖并注册默认 Registry确保自注册工作流可被发现
func NewRegistry(conf *config.Config, llm *utils_ollama.Client) *runtime.Registry {
// 步骤1设置运行时依赖配置与LLM客户端供工作流工厂在首次实例化时使用必须在任何调用 Invoke 之前完成,否则会触发 "deps not set"
runtime.SetDeps(&runtime.Deps{Conf: conf, LLM: llm})
// 步骤2创建新的工作流注册表注册表负责按工作流ID惰性实例化并缓存单例实例保障并发访问下的安全
r := runtime.NewRegistry()
// 步骤3将该注册表设置为全局默认便于通过 runtime.Default() 获取;自注册的工作流可通过默认注册表被发现并调用
runtime.SetDefault(r)
return r
}

View File

@ -0,0 +1,7 @@
package workflow
import (
// 手工维护:在此空导入工作流包以触发其 init() 自注册
// 新增工作流时,只需在这里添加一行 `_ "<import_path>"`
_ "ai_scheduler/internal/domain/workflow/zltx"
)

View File

@ -0,0 +1,12 @@
package workflow
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/pkg/utils_ollama"
)
// 仅声明依赖结构,避免在 workflow 包内实现注册中心逻辑导致循环依赖
type Deps struct {
Conf *config.Config
LLM *utils_ollama.Client
}

View File

@ -0,0 +1,95 @@
package runtime
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/pkg/utils_ollama"
"context"
"errors"
"sync"
)
type Workflow interface {
ID() string
Schema() map[string]any
Invoke(ctx context.Context, input map[string]any) (map[string]any, error)
}
type Deps struct {
Conf *config.Config
LLM *utils_ollama.Client
}
type Factory func(deps *Deps) (Workflow, error)
var (
regMu sync.RWMutex
factories = map[string]Factory{}
deps *Deps
defaultReg *Registry
)
func Register(id string, f Factory) {
regMu.Lock()
factories[id] = f
regMu.Unlock()
}
func SetDeps(d *Deps) {
regMu.Lock()
deps = d
regMu.Unlock()
}
type Registry struct {
mu sync.RWMutex
instances map[string]Workflow
}
func NewRegistry() *Registry {
return &Registry{instances: make(map[string]Workflow)}
}
func SetDefault(r *Registry) {
regMu.Lock()
defaultReg = r
regMu.Unlock()
}
func Default() *Registry {
regMu.RLock()
r := defaultReg
regMu.RUnlock()
return r
}
func (r *Registry) Invoke(ctx context.Context, id string, input map[string]any) (map[string]any, error) {
if input == nil {
input = map[string]any{}
}
regMu.RLock()
f, ok := factories[id]
regMu.RUnlock()
if !ok {
return nil, errors.New("workflow not found: " + id)
}
r.mu.RLock()
w, exists := r.instances[id]
r.mu.RUnlock()
if !exists {
if deps == nil {
return nil, errors.New("deps not set")
}
nw, err := f(deps)
if err != nil {
return nil, err
}
r.mu.Lock()
r.instances[id] = nw
w = nw
r.mu.Unlock()
}
return w.Invoke(ctx, input)
}

View File

@ -1 +0,0 @@
package zltx

View File

@ -3,19 +3,66 @@ package zltx
import ( import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
toolZoarb "ai_scheduler/internal/domain/tools/zltx/order_after_reseller_batch" toolZoarb "ai_scheduler/internal/domain/tools/zltx/order_after_reseller_batch"
"ai_scheduler/internal/domain/workflow/runtime"
"context" "context"
"errors" "errors"
"github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/compose"
) )
func init() {
runtime.Register("zltx.orderAfterSaleResellerBatch", func(d *runtime.Deps) (runtime.Workflow, error) {
return &orderAfterSaleResellerBatch{cfg: d.Conf.Tools.ZltxOrderAfterSaleResellerBatch}, nil
})
}
type orderAfterSaleResellerBatch struct {
cfg config.ToolConfig
}
// ID 返回工作流唯一标识
func (o *orderAfterSaleResellerBatch) ID() string { return "zltx.orderAfterSaleResellerBatch" }
// Schema 返回入参约束(用于校验/表单生成)
func (o *orderAfterSaleResellerBatch) Schema() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{"orderNumber": map[string]any{"type": "array", "items": map[string]any{"type": "string"}}},
"required": []string{"orderNumber"},
}
}
// Invoke 调用原有编排工作流并规范化输出
func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
// 构建工作流
chain, err := o.buildWorkflow(ctx)
if err != nil {
return nil, err
}
var in OrderAfterSaleResellerBatchWorkflowInput
if v, ok := input["orderNumber"].([]string); ok {
in.OrderNumber = v
}
_, err = chain.Invoke(ctx, in)
if err != nil {
return nil, err
}
// 工作流 callback
// 不关心输出,全部在途中输出
return nil, nil
}
type OrderAfterSaleResellerBatchWorkflowInput struct { type OrderAfterSaleResellerBatchWorkflowInput struct {
OrderNumber []string `json:"orderNumber"` OrderNumber []string `json:"orderNumber"`
} }
var ErrInvalidOrderNumbers = errors.New("orderNumber 不能为空") var ErrInvalidOrderNumbers = errors.New("orderNumber 不能为空")
func BuildOrderAfterResellerBatchWorkflow(ctx context.Context, cfg config.ToolConfig) (compose.Runnable[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData], error) { // buildWorkflow 构建工作流
func (o *orderAfterSaleResellerBatch) buildWorkflow(ctx context.Context) (compose.Runnable[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData], error) {
// 定义工作流、出入参 // 定义工作流、出入参
c := compose.NewChain[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData]() c := compose.NewChain[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData]()
@ -29,7 +76,7 @@ func BuildOrderAfterResellerBatchWorkflow(ctx context.Context, cfg config.ToolCo
// 2.调用工具 // 2.调用工具
c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in OrderAfterSaleResellerBatchWorkflowInput) (toolZoarb.OrderAfterSaleResellerBatchData, error) { c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in OrderAfterSaleResellerBatchWorkflowInput) (toolZoarb.OrderAfterSaleResellerBatchData, error) {
return toolZoarb.Call(ctx, cfg, in.OrderNumber) return toolZoarb.Call(ctx, o.cfg, in.OrderNumber)
})) }))
// 3.结果映射与整形 // 3.结果映射与整形

View File

@ -1,46 +0,0 @@
package zltx
import (
"ai_scheduler/internal/config"
"context"
"errors"
"strings"
"testing"
)
func TestOrderAfterResellerBatch_InvalidOrderNumbers(t *testing.T) {
ctx := context.Background()
cfg := config.ToolConfig{}
run, err := BuildOrderAfterResellerBatchWorkflow(ctx, cfg)
if err != nil {
t.Fatalf("build workflow error: %v", err)
}
_, err = run.Invoke(ctx, OrderAfterSaleResellerBatchWorkflowInput{})
if err == nil {
t.Fatalf("expected error, got nil")
}
if !errors.Is(err, ErrInvalidOrderNumbers) {
t.Fatalf("expected ErrInvalidOrderNumbers, got %v", err)
}
}
func TestOrderAfterResellerBatch_ContextTokenRequired(t *testing.T) {
ctx := context.Background()
cfg := config.ToolConfig{}
run, err := BuildOrderAfterResellerBatchWorkflow(ctx, cfg)
if err != nil {
t.Fatalf("build workflow error: %v", err)
}
in := OrderAfterSaleResellerBatchWorkflowInput{OrderNumber: []string{"123"}}
_, err = run.Invoke(ctx, in)
if err == nil {
t.Fatalf("expected error, got nil")
}
if !strings.Contains(err.Error(), "token 未注入") {
t.Fatalf("expected contains 'token 未注入', got %v", err)
}
}

View File

@ -2,7 +2,6 @@ package tools
import ( import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/data/constants"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/utils_ollama" "ai_scheduler/internal/pkg/utils_ollama"
zltxtool "ai_scheduler/internal/tools/zltx" zltxtool "ai_scheduler/internal/tools/zltx"
@ -101,23 +100,23 @@ func (m *Manager) GetTool(name string) (entitys.Tool, bool) {
} }
// GetAllTools 获取所有工具 // GetAllTools 获取所有工具
func (m *Manager) GetAllTools() []entitys.Tool { // func (m *Manager) GetAllTools() []entitys.Tool {
tools := make([]entitys.Tool, 0, len(m.tools)) // tools := make([]entitys.Tool, 0, len(m.tools))
for _, tool := range m.tools { // for _, tool := range m.tools {
tools = append(tools, tool) // tools = append(tools, tool)
} // }
return tools // return tools
} // }
// GetToolDefinitions 获取所有工具定义 // // GetToolDefinitions 获取所有工具定义
func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition { // func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition {
definitions := make([]entitys.ToolDefinition, 0, len(m.tools)) // definitions := make([]entitys.ToolDefinition, 0, len(m.tools))
for _, tool := range m.tools { // for _, tool := range m.tools {
definitions = append(definitions, tool.Definition()) // definitions = append(definitions, tool.Definition())
} // }
return definitions // return definitions
} // }
// ExecuteTool 执行工具 // ExecuteTool 执行工具
func (m *Manager) ExecuteTool(ctx context.Context, name string, requireData *entitys.RequireData) error { func (m *Manager) ExecuteTool(ctx context.Context, name string, requireData *entitys.RequireData) error {