fix: 1.eino 工作流方案调整 2.下游批量订单售后接入工作流
This commit is contained in:
parent
15c3febe76
commit
d310bf8104
|
|
@ -4,15 +4,16 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/biz"
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/data/impl"
|
||||
"ai_scheduler/internal/pkg"
|
||||
"ai_scheduler/internal/server"
|
||||
"ai_scheduler/internal/services"
|
||||
"ai_scheduler/internal/tools"
|
||||
"ai_scheduler/internal/tools_bot"
|
||||
"ai_scheduler/utils"
|
||||
"ai_scheduler/internal/biz"
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/data/impl"
|
||||
"ai_scheduler/internal/pkg"
|
||||
"ai_scheduler/internal/server"
|
||||
"ai_scheduler/internal/services"
|
||||
"ai_scheduler/internal/domain/workflow"
|
||||
"ai_scheduler/internal/tools"
|
||||
"ai_scheduler/internal/tools_bot"
|
||||
"ai_scheduler/utils"
|
||||
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"github.com/google/wire"
|
||||
|
|
@ -20,16 +21,17 @@ import (
|
|||
|
||||
// InitializeApp 初始化应用程序
|
||||
func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), error) {
|
||||
panic(wire.Build(
|
||||
server.ProviderSetServer,
|
||||
llm.ProviderSet,
|
||||
tools.ProviderSetTools,
|
||||
pkg.ProviderSetClient,
|
||||
services.ProviderSetServices,
|
||||
biz.ProviderSetBiz,
|
||||
impl.ProviderImpl,
|
||||
utils.ProviderUtils,
|
||||
tools_bot.ProviderSetBotTools,
|
||||
))
|
||||
panic(wire.Build(
|
||||
server.ProviderSetServer,
|
||||
llm.ProviderSet,
|
||||
workflow.ProviderSetWorkflow,
|
||||
tools.ProviderSetTools,
|
||||
pkg.ProviderSetClient,
|
||||
services.ProviderSetServices,
|
||||
biz.ProviderSetBiz,
|
||||
impl.ProviderImpl,
|
||||
utils.ProviderUtils,
|
||||
tools_bot.ProviderSetBotTools,
|
||||
))
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import (
|
|||
errors "ai_scheduler/internal/data/error"
|
||||
"ai_scheduler/internal/data/impl"
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/internal/domain/workflow/zltx"
|
||||
"ai_scheduler/internal/domain/workflow/runtime"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/gateway"
|
||||
"ai_scheduler/internal/pkg/l_request"
|
||||
|
|
@ -24,11 +24,12 @@ import (
|
|||
)
|
||||
|
||||
type Handle struct {
|
||||
Ollama *llm_service.OllamaService
|
||||
toolManager *tools.Manager
|
||||
Bot *tools_bot.BotTool
|
||||
conf *config.Config
|
||||
sessionImpl *impl.SessionImpl
|
||||
Ollama *llm_service.OllamaService
|
||||
toolManager *tools.Manager
|
||||
Bot *tools_bot.BotTool
|
||||
conf *config.Config
|
||||
sessionImpl *impl.SessionImpl
|
||||
workflowManager *runtime.Registry
|
||||
}
|
||||
|
||||
func NewHandle(
|
||||
|
|
@ -37,13 +38,15 @@ func NewHandle(
|
|||
conf *config.Config,
|
||||
sessionImpl *impl.SessionImpl,
|
||||
dTalkBot *tools_bot.BotTool,
|
||||
workflowManager *runtime.Registry,
|
||||
) *Handle {
|
||||
return &Handle{
|
||||
Ollama: Ollama,
|
||||
toolManager: toolManager,
|
||||
conf: conf,
|
||||
sessionImpl: sessionImpl,
|
||||
Bot: dTalkBot,
|
||||
Ollama: Ollama,
|
||||
toolManager: toolManager,
|
||||
conf: conf,
|
||||
sessionImpl: sessionImpl,
|
||||
Bot: dTalkBot,
|
||||
workflowManager: workflowManager,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -285,17 +288,24 @@ func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.Require
|
|||
func (r *Handle) handleEinoWorkflow(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
||||
// token 写入ctx
|
||||
ctx = util.SetTokenToContext(ctx, requireData.Auth)
|
||||
|
||||
// 构建工作流 - todo 示例,后续抽象出来
|
||||
zltxWorkflow, err := zltx.BuildOrderAfterResellerBatchWorkflow(ctx, r.conf.Tools.ZltxOrderAfterSaleResellerBatch)
|
||||
if err != nil {
|
||||
return
|
||||
// 解析入参:workflow_id 与 input
|
||||
var params map[string]any
|
||||
if len(requireData.Match.Parameters) > 0 {
|
||||
_ = json.Unmarshal([]byte(requireData.Match.Parameters), ¶ms)
|
||||
}
|
||||
|
||||
// 工作流执行
|
||||
_, err = zltxWorkflow.Invoke(ctx, zltx.OrderAfterSaleResellerBatchWorkflowInput{})
|
||||
|
||||
return
|
||||
wfID, _ := params["workflow_id"].(string)
|
||||
input, _ := params["input"].(map[string]any)
|
||||
if wfID == "" {
|
||||
return fmt.Errorf("workflow_id 不能为空")
|
||||
}
|
||||
entitys.ResLoading(requireData.Ch, requireData.Task.Index, "正在执行工作流")
|
||||
res, err := r.workflowManager.Invoke(ctx, wfID, input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b, _ := json.Marshal(res)
|
||||
entitys.ResJson(requireData.Ch, requireData.Task.Index, string(b))
|
||||
return nil
|
||||
}
|
||||
|
||||
// 权限验证
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
package biz
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/biz/do"
|
||||
"ai_scheduler/internal/biz/llm_service"
|
||||
|
||||
"github.com/google/wire"
|
||||
"ai_scheduler/internal/biz/do"
|
||||
"ai_scheduler/internal/biz/llm_service"
|
||||
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderSetBiz = wire.NewSet(
|
||||
|
|
@ -15,7 +15,7 @@ var ProviderSetBiz = wire.NewSet(
|
|||
llm_service.NewOllamaGenerate,
|
||||
//handle.NewHandle,
|
||||
do.NewDo,
|
||||
do.NewHandle,
|
||||
do.NewHandle,
|
||||
NewTaskBiz,
|
||||
NewDingTalkBotBiz,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
package workflow
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
package workflow
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/domain/workflow/runtime"
|
||||
"ai_scheduler/internal/pkg/utils_ollama"
|
||||
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderSetWorkflow = wire.NewSet(NewRegistry)
|
||||
|
||||
// NewRegistry 注入共享依赖并注册默认 Registry,确保自注册工作流可被发现
|
||||
func NewRegistry(conf *config.Config, llm *utils_ollama.Client) *runtime.Registry {
|
||||
// 步骤1:设置运行时依赖(配置与LLM客户端),供工作流工厂在首次实例化时使用;必须在任何调用 Invoke 之前完成,否则会触发 "deps not set"
|
||||
runtime.SetDeps(&runtime.Deps{Conf: conf, LLM: llm})
|
||||
// 步骤2:创建新的工作流注册表;注册表负责按工作流ID惰性实例化并缓存单例实例,保障并发访问下的安全
|
||||
r := runtime.NewRegistry()
|
||||
// 步骤3:将该注册表设置为全局默认,便于通过 runtime.Default() 获取;自注册的工作流可通过默认注册表被发现并调用
|
||||
runtime.SetDefault(r)
|
||||
return r
|
||||
}
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
package workflow
|
||||
|
||||
import (
|
||||
// 手工维护:在此空导入工作流包以触发其 init() 自注册
|
||||
// 新增工作流时,只需在这里添加一行 `_ "<import_path>"`
|
||||
_ "ai_scheduler/internal/domain/workflow/zltx"
|
||||
)
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
package workflow
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/pkg/utils_ollama"
|
||||
)
|
||||
|
||||
// 仅声明依赖结构,避免在 workflow 包内实现注册中心逻辑导致循环依赖
|
||||
type Deps struct {
|
||||
Conf *config.Config
|
||||
LLM *utils_ollama.Client
|
||||
}
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
package runtime
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/pkg/utils_ollama"
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Workflow interface {
|
||||
ID() string
|
||||
Schema() map[string]any
|
||||
Invoke(ctx context.Context, input map[string]any) (map[string]any, error)
|
||||
}
|
||||
|
||||
type Deps struct {
|
||||
Conf *config.Config
|
||||
LLM *utils_ollama.Client
|
||||
}
|
||||
|
||||
type Factory func(deps *Deps) (Workflow, error)
|
||||
|
||||
var (
|
||||
regMu sync.RWMutex
|
||||
factories = map[string]Factory{}
|
||||
deps *Deps
|
||||
defaultReg *Registry
|
||||
)
|
||||
|
||||
func Register(id string, f Factory) {
|
||||
regMu.Lock()
|
||||
factories[id] = f
|
||||
regMu.Unlock()
|
||||
}
|
||||
|
||||
func SetDeps(d *Deps) {
|
||||
regMu.Lock()
|
||||
deps = d
|
||||
regMu.Unlock()
|
||||
}
|
||||
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
instances map[string]Workflow
|
||||
}
|
||||
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{instances: make(map[string]Workflow)}
|
||||
}
|
||||
|
||||
func SetDefault(r *Registry) {
|
||||
regMu.Lock()
|
||||
defaultReg = r
|
||||
regMu.Unlock()
|
||||
}
|
||||
|
||||
func Default() *Registry {
|
||||
regMu.RLock()
|
||||
r := defaultReg
|
||||
regMu.RUnlock()
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Registry) Invoke(ctx context.Context, id string, input map[string]any) (map[string]any, error) {
|
||||
if input == nil {
|
||||
input = map[string]any{}
|
||||
}
|
||||
regMu.RLock()
|
||||
f, ok := factories[id]
|
||||
regMu.RUnlock()
|
||||
if !ok {
|
||||
return nil, errors.New("workflow not found: " + id)
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
w, exists := r.instances[id]
|
||||
r.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
if deps == nil {
|
||||
return nil, errors.New("deps not set")
|
||||
}
|
||||
nw, err := f(deps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.mu.Lock()
|
||||
r.instances[id] = nw
|
||||
w = nw
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
return w.Invoke(ctx, input)
|
||||
}
|
||||
|
|
@ -1 +0,0 @@
|
|||
package zltx
|
||||
|
|
@ -3,19 +3,66 @@ package zltx
|
|||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
toolZoarb "ai_scheduler/internal/domain/tools/zltx/order_after_reseller_batch"
|
||||
"ai_scheduler/internal/domain/workflow/runtime"
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
func init() {
|
||||
runtime.Register("zltx.orderAfterSaleResellerBatch", func(d *runtime.Deps) (runtime.Workflow, error) {
|
||||
return &orderAfterSaleResellerBatch{cfg: d.Conf.Tools.ZltxOrderAfterSaleResellerBatch}, nil
|
||||
})
|
||||
}
|
||||
|
||||
type orderAfterSaleResellerBatch struct {
|
||||
cfg config.ToolConfig
|
||||
}
|
||||
|
||||
// ID 返回工作流唯一标识
|
||||
func (o *orderAfterSaleResellerBatch) ID() string { return "zltx.orderAfterSaleResellerBatch" }
|
||||
|
||||
// Schema 返回入参约束(用于校验/表单生成)
|
||||
func (o *orderAfterSaleResellerBatch) Schema() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{"orderNumber": map[string]any{"type": "array", "items": map[string]any{"type": "string"}}},
|
||||
"required": []string{"orderNumber"},
|
||||
}
|
||||
}
|
||||
|
||||
// Invoke 调用原有编排工作流并规范化输出
|
||||
func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
// 构建工作流
|
||||
chain, err := o.buildWorkflow(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var in OrderAfterSaleResellerBatchWorkflowInput
|
||||
if v, ok := input["orderNumber"].([]string); ok {
|
||||
in.OrderNumber = v
|
||||
}
|
||||
_, err = chain.Invoke(ctx, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 工作流 callback
|
||||
|
||||
// 不关心输出,全部在途中输出
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type OrderAfterSaleResellerBatchWorkflowInput struct {
|
||||
OrderNumber []string `json:"orderNumber"`
|
||||
}
|
||||
|
||||
var ErrInvalidOrderNumbers = errors.New("orderNumber 不能为空")
|
||||
|
||||
func BuildOrderAfterResellerBatchWorkflow(ctx context.Context, cfg config.ToolConfig) (compose.Runnable[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData], error) {
|
||||
// buildWorkflow 构建工作流
|
||||
func (o *orderAfterSaleResellerBatch) buildWorkflow(ctx context.Context) (compose.Runnable[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData], error) {
|
||||
// 定义工作流、出入参
|
||||
c := compose.NewChain[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData]()
|
||||
|
||||
|
|
@ -29,7 +76,7 @@ func BuildOrderAfterResellerBatchWorkflow(ctx context.Context, cfg config.ToolCo
|
|||
|
||||
// 2.调用工具
|
||||
c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in OrderAfterSaleResellerBatchWorkflowInput) (toolZoarb.OrderAfterSaleResellerBatchData, error) {
|
||||
return toolZoarb.Call(ctx, cfg, in.OrderNumber)
|
||||
return toolZoarb.Call(ctx, o.cfg, in.OrderNumber)
|
||||
}))
|
||||
|
||||
// 3.结果映射与整形
|
||||
|
|
|
|||
|
|
@ -1,46 +0,0 @@
|
|||
package zltx
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOrderAfterResellerBatch_InvalidOrderNumbers(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cfg := config.ToolConfig{}
|
||||
|
||||
run, err := BuildOrderAfterResellerBatchWorkflow(ctx, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("build workflow error: %v", err)
|
||||
}
|
||||
|
||||
_, err = run.Invoke(ctx, OrderAfterSaleResellerBatchWorkflowInput{})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidOrderNumbers) {
|
||||
t.Fatalf("expected ErrInvalidOrderNumbers, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrderAfterResellerBatch_ContextTokenRequired(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cfg := config.ToolConfig{}
|
||||
|
||||
run, err := BuildOrderAfterResellerBatchWorkflow(ctx, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("build workflow error: %v", err)
|
||||
}
|
||||
|
||||
in := OrderAfterSaleResellerBatchWorkflowInput{OrderNumber: []string{"123"}}
|
||||
_, err = run.Invoke(ctx, in)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "token 未注入") {
|
||||
t.Fatalf("expected contains 'token 未注入', got %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -2,7 +2,6 @@ package tools
|
|||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/data/constants"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg/utils_ollama"
|
||||
zltxtool "ai_scheduler/internal/tools/zltx"
|
||||
|
|
@ -101,23 +100,23 @@ func (m *Manager) GetTool(name string) (entitys.Tool, bool) {
|
|||
}
|
||||
|
||||
// GetAllTools 获取所有工具
|
||||
func (m *Manager) GetAllTools() []entitys.Tool {
|
||||
tools := make([]entitys.Tool, 0, len(m.tools))
|
||||
for _, tool := range m.tools {
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
return tools
|
||||
}
|
||||
// func (m *Manager) GetAllTools() []entitys.Tool {
|
||||
// tools := make([]entitys.Tool, 0, len(m.tools))
|
||||
// for _, tool := range m.tools {
|
||||
// tools = append(tools, tool)
|
||||
// }
|
||||
// return tools
|
||||
// }
|
||||
|
||||
// GetToolDefinitions 获取所有工具定义
|
||||
func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition {
|
||||
definitions := make([]entitys.ToolDefinition, 0, len(m.tools))
|
||||
for _, tool := range m.tools {
|
||||
definitions = append(definitions, tool.Definition())
|
||||
}
|
||||
// // GetToolDefinitions 获取所有工具定义
|
||||
// func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition {
|
||||
// definitions := make([]entitys.ToolDefinition, 0, len(m.tools))
|
||||
// for _, tool := range m.tools {
|
||||
// definitions = append(definitions, tool.Definition())
|
||||
// }
|
||||
|
||||
return definitions
|
||||
}
|
||||
// return definitions
|
||||
// }
|
||||
|
||||
// ExecuteTool 执行工具
|
||||
func (m *Manager) ExecuteTool(ctx context.Context, name string, requireData *entitys.RequireData) error {
|
||||
|
|
|
|||
Loading…
Reference in New Issue