fix: 1.eino 工作流方案调整 2.下游批量订单售后接入工作流

This commit is contained in:
fuzhongyun 2025-12-11 18:35:48 +08:00
parent 15c3febe76
commit d310bf8104
12 changed files with 258 additions and 112 deletions

View File

@ -4,15 +4,16 @@
package main
import (
"ai_scheduler/internal/biz"
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/server"
"ai_scheduler/internal/services"
"ai_scheduler/internal/tools"
"ai_scheduler/internal/tools_bot"
"ai_scheduler/utils"
"ai_scheduler/internal/biz"
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/server"
"ai_scheduler/internal/services"
"ai_scheduler/internal/domain/workflow"
"ai_scheduler/internal/tools"
"ai_scheduler/internal/tools_bot"
"ai_scheduler/utils"
"github.com/gofiber/fiber/v2/log"
"github.com/google/wire"
@ -20,16 +21,17 @@ import (
// InitializeApp 初始化应用程序
func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), error) {
panic(wire.Build(
server.ProviderSetServer,
llm.ProviderSet,
tools.ProviderSetTools,
pkg.ProviderSetClient,
services.ProviderSetServices,
biz.ProviderSetBiz,
impl.ProviderImpl,
utils.ProviderUtils,
tools_bot.ProviderSetBotTools,
))
panic(wire.Build(
server.ProviderSetServer,
llm.ProviderSet,
workflow.ProviderSetWorkflow,
tools.ProviderSetTools,
pkg.ProviderSetClient,
services.ProviderSetServices,
biz.ProviderSetBiz,
impl.ProviderImpl,
utils.ProviderUtils,
tools_bot.ProviderSetBotTools,
))
}

View File

@ -7,7 +7,7 @@ import (
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/domain/workflow/zltx"
"ai_scheduler/internal/domain/workflow/runtime"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/gateway"
"ai_scheduler/internal/pkg/l_request"
@ -24,11 +24,12 @@ import (
)
type Handle struct {
Ollama *llm_service.OllamaService
toolManager *tools.Manager
Bot *tools_bot.BotTool
conf *config.Config
sessionImpl *impl.SessionImpl
Ollama *llm_service.OllamaService
toolManager *tools.Manager
Bot *tools_bot.BotTool
conf *config.Config
sessionImpl *impl.SessionImpl
workflowManager *runtime.Registry
}
func NewHandle(
@ -37,13 +38,15 @@ func NewHandle(
conf *config.Config,
sessionImpl *impl.SessionImpl,
dTalkBot *tools_bot.BotTool,
workflowManager *runtime.Registry,
) *Handle {
return &Handle{
Ollama: Ollama,
toolManager: toolManager,
conf: conf,
sessionImpl: sessionImpl,
Bot: dTalkBot,
Ollama: Ollama,
toolManager: toolManager,
conf: conf,
sessionImpl: sessionImpl,
Bot: dTalkBot,
workflowManager: workflowManager,
}
}
@ -285,17 +288,24 @@ func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.Require
func (r *Handle) handleEinoWorkflow(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
// token 写入ctx
ctx = util.SetTokenToContext(ctx, requireData.Auth)
// 构建工作流 - todo 示例,后续抽象出来
zltxWorkflow, err := zltx.BuildOrderAfterResellerBatchWorkflow(ctx, r.conf.Tools.ZltxOrderAfterSaleResellerBatch)
if err != nil {
return
// 解析入参workflow_id 与 input
var params map[string]any
if len(requireData.Match.Parameters) > 0 {
_ = json.Unmarshal([]byte(requireData.Match.Parameters), &params)
}
// 工作流执行
_, err = zltxWorkflow.Invoke(ctx, zltx.OrderAfterSaleResellerBatchWorkflowInput{})
return
wfID, _ := params["workflow_id"].(string)
input, _ := params["input"].(map[string]any)
if wfID == "" {
return fmt.Errorf("workflow_id 不能为空")
}
entitys.ResLoading(requireData.Ch, requireData.Task.Index, "正在执行工作流")
res, err := r.workflowManager.Invoke(ctx, wfID, input)
if err != nil {
return err
}
b, _ := json.Marshal(res)
entitys.ResJson(requireData.Ch, requireData.Task.Index, string(b))
return nil
}
// 权限验证

View File

@ -1,10 +1,10 @@
package biz
import (
"ai_scheduler/internal/biz/do"
"ai_scheduler/internal/biz/llm_service"
"github.com/google/wire"
"ai_scheduler/internal/biz/do"
"ai_scheduler/internal/biz/llm_service"
"github.com/google/wire"
)
var ProviderSetBiz = wire.NewSet(
@ -15,7 +15,7 @@ var ProviderSetBiz = wire.NewSet(
llm_service.NewOllamaGenerate,
//handle.NewHandle,
do.NewDo,
do.NewHandle,
do.NewHandle,
NewTaskBiz,
NewDingTalkBotBiz,
)

View File

@ -1 +0,0 @@
package workflow

View File

@ -0,0 +1,22 @@
package workflow
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/domain/workflow/runtime"
"ai_scheduler/internal/pkg/utils_ollama"
"github.com/google/wire"
)
var ProviderSetWorkflow = wire.NewSet(NewRegistry)
// NewRegistry 注入共享依赖并注册默认 Registry确保自注册工作流可被发现
func NewRegistry(conf *config.Config, llm *utils_ollama.Client) *runtime.Registry {
// 步骤1设置运行时依赖配置与LLM客户端供工作流工厂在首次实例化时使用必须在任何调用 Invoke 之前完成,否则会触发 "deps not set"
runtime.SetDeps(&runtime.Deps{Conf: conf, LLM: llm})
// 步骤2创建新的工作流注册表注册表负责按工作流ID惰性实例化并缓存单例实例保障并发访问下的安全
r := runtime.NewRegistry()
// 步骤3将该注册表设置为全局默认便于通过 runtime.Default() 获取;自注册的工作流可通过默认注册表被发现并调用
runtime.SetDefault(r)
return r
}

View File

@ -0,0 +1,7 @@
package workflow
import (
// 手工维护:在此空导入工作流包以触发其 init() 自注册
// 新增工作流时,只需在这里添加一行 `_ "<import_path>"`
_ "ai_scheduler/internal/domain/workflow/zltx"
)

View File

@ -0,0 +1,12 @@
package workflow
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/pkg/utils_ollama"
)
// 仅声明依赖结构,避免在 workflow 包内实现注册中心逻辑导致循环依赖
type Deps struct {
Conf *config.Config
LLM *utils_ollama.Client
}

View File

@ -0,0 +1,95 @@
package runtime
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/pkg/utils_ollama"
"context"
"errors"
"sync"
)
type Workflow interface {
ID() string
Schema() map[string]any
Invoke(ctx context.Context, input map[string]any) (map[string]any, error)
}
type Deps struct {
Conf *config.Config
LLM *utils_ollama.Client
}
type Factory func(deps *Deps) (Workflow, error)
var (
regMu sync.RWMutex
factories = map[string]Factory{}
deps *Deps
defaultReg *Registry
)
func Register(id string, f Factory) {
regMu.Lock()
factories[id] = f
regMu.Unlock()
}
func SetDeps(d *Deps) {
regMu.Lock()
deps = d
regMu.Unlock()
}
type Registry struct {
mu sync.RWMutex
instances map[string]Workflow
}
func NewRegistry() *Registry {
return &Registry{instances: make(map[string]Workflow)}
}
func SetDefault(r *Registry) {
regMu.Lock()
defaultReg = r
regMu.Unlock()
}
func Default() *Registry {
regMu.RLock()
r := defaultReg
regMu.RUnlock()
return r
}
func (r *Registry) Invoke(ctx context.Context, id string, input map[string]any) (map[string]any, error) {
if input == nil {
input = map[string]any{}
}
regMu.RLock()
f, ok := factories[id]
regMu.RUnlock()
if !ok {
return nil, errors.New("workflow not found: " + id)
}
r.mu.RLock()
w, exists := r.instances[id]
r.mu.RUnlock()
if !exists {
if deps == nil {
return nil, errors.New("deps not set")
}
nw, err := f(deps)
if err != nil {
return nil, err
}
r.mu.Lock()
r.instances[id] = nw
w = nw
r.mu.Unlock()
}
return w.Invoke(ctx, input)
}

View File

@ -1 +0,0 @@
package zltx

View File

@ -3,19 +3,66 @@ package zltx
import (
"ai_scheduler/internal/config"
toolZoarb "ai_scheduler/internal/domain/tools/zltx/order_after_reseller_batch"
"ai_scheduler/internal/domain/workflow/runtime"
"context"
"errors"
"github.com/cloudwego/eino/compose"
)
func init() {
runtime.Register("zltx.orderAfterSaleResellerBatch", func(d *runtime.Deps) (runtime.Workflow, error) {
return &orderAfterSaleResellerBatch{cfg: d.Conf.Tools.ZltxOrderAfterSaleResellerBatch}, nil
})
}
type orderAfterSaleResellerBatch struct {
cfg config.ToolConfig
}
// ID 返回工作流唯一标识
func (o *orderAfterSaleResellerBatch) ID() string { return "zltx.orderAfterSaleResellerBatch" }
// Schema 返回入参约束(用于校验/表单生成)
func (o *orderAfterSaleResellerBatch) Schema() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{"orderNumber": map[string]any{"type": "array", "items": map[string]any{"type": "string"}}},
"required": []string{"orderNumber"},
}
}
// Invoke 调用原有编排工作流并规范化输出
func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
// 构建工作流
chain, err := o.buildWorkflow(ctx)
if err != nil {
return nil, err
}
var in OrderAfterSaleResellerBatchWorkflowInput
if v, ok := input["orderNumber"].([]string); ok {
in.OrderNumber = v
}
_, err = chain.Invoke(ctx, in)
if err != nil {
return nil, err
}
// 工作流 callback
// 不关心输出,全部在途中输出
return nil, nil
}
type OrderAfterSaleResellerBatchWorkflowInput struct {
OrderNumber []string `json:"orderNumber"`
}
var ErrInvalidOrderNumbers = errors.New("orderNumber 不能为空")
func BuildOrderAfterResellerBatchWorkflow(ctx context.Context, cfg config.ToolConfig) (compose.Runnable[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData], error) {
// buildWorkflow 构建工作流
func (o *orderAfterSaleResellerBatch) buildWorkflow(ctx context.Context) (compose.Runnable[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData], error) {
// 定义工作流、出入参
c := compose.NewChain[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData]()
@ -29,7 +76,7 @@ func BuildOrderAfterResellerBatchWorkflow(ctx context.Context, cfg config.ToolCo
// 2.调用工具
c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in OrderAfterSaleResellerBatchWorkflowInput) (toolZoarb.OrderAfterSaleResellerBatchData, error) {
return toolZoarb.Call(ctx, cfg, in.OrderNumber)
return toolZoarb.Call(ctx, o.cfg, in.OrderNumber)
}))
// 3.结果映射与整形

View File

@ -1,46 +0,0 @@
package zltx
import (
"ai_scheduler/internal/config"
"context"
"errors"
"strings"
"testing"
)
func TestOrderAfterResellerBatch_InvalidOrderNumbers(t *testing.T) {
ctx := context.Background()
cfg := config.ToolConfig{}
run, err := BuildOrderAfterResellerBatchWorkflow(ctx, cfg)
if err != nil {
t.Fatalf("build workflow error: %v", err)
}
_, err = run.Invoke(ctx, OrderAfterSaleResellerBatchWorkflowInput{})
if err == nil {
t.Fatalf("expected error, got nil")
}
if !errors.Is(err, ErrInvalidOrderNumbers) {
t.Fatalf("expected ErrInvalidOrderNumbers, got %v", err)
}
}
func TestOrderAfterResellerBatch_ContextTokenRequired(t *testing.T) {
ctx := context.Background()
cfg := config.ToolConfig{}
run, err := BuildOrderAfterResellerBatchWorkflow(ctx, cfg)
if err != nil {
t.Fatalf("build workflow error: %v", err)
}
in := OrderAfterSaleResellerBatchWorkflowInput{OrderNumber: []string{"123"}}
_, err = run.Invoke(ctx, in)
if err == nil {
t.Fatalf("expected error, got nil")
}
if !strings.Contains(err.Error(), "token 未注入") {
t.Fatalf("expected contains 'token 未注入', got %v", err)
}
}

View File

@ -2,7 +2,6 @@ package tools
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/constants"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/utils_ollama"
zltxtool "ai_scheduler/internal/tools/zltx"
@ -101,23 +100,23 @@ func (m *Manager) GetTool(name string) (entitys.Tool, bool) {
}
// GetAllTools 获取所有工具
func (m *Manager) GetAllTools() []entitys.Tool {
tools := make([]entitys.Tool, 0, len(m.tools))
for _, tool := range m.tools {
tools = append(tools, tool)
}
return tools
}
// func (m *Manager) GetAllTools() []entitys.Tool {
// tools := make([]entitys.Tool, 0, len(m.tools))
// for _, tool := range m.tools {
// tools = append(tools, tool)
// }
// return tools
// }
// GetToolDefinitions 获取所有工具定义
func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition {
definitions := make([]entitys.ToolDefinition, 0, len(m.tools))
for _, tool := range m.tools {
definitions = append(definitions, tool.Definition())
}
// // GetToolDefinitions 获取所有工具定义
// func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition {
// definitions := make([]entitys.ToolDefinition, 0, len(m.tools))
// for _, tool := range m.tools {
// definitions = append(definitions, tool.Definition())
// }
return definitions
}
// return definitions
// }
// ExecuteTool 执行工具
func (m *Manager) ExecuteTool(ctx context.Context, name string, requireData *entitys.RequireData) error {