From e3448ae41eac3442529ac595a1576ecd21c216b1 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Thu, 25 Dec 2025 14:46:52 +0800 Subject: [PATCH] =?UTF-8?q?feat:=201.=20=E5=A2=9E=E5=8A=A0=E5=90=8C?= =?UTF-8?q?=E6=AD=A5=E9=98=BB=E5=A1=9E=E7=AD=89=E5=BE=85=E5=BC=82=E6=AD=A5?= =?UTF-8?q?=E5=9B=9E=E8=B0=83redis=E7=BB=84=E4=BB=B6=202.=E5=8E=9F?= =?UTF-8?q?=E9=9C=80=E6=B1=82=E6=94=B6=E9=9B=86=E6=9C=BA=E5=99=A8=E4=BA=BA?= =?UTF-8?q?=E8=BF=81=E7=A7=BB=E8=87=B3eino=E5=B7=A5=E4=BD=9C=E6=B5=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/server/wire.go | 9 +- internal/domain/component/callback/manager.go | 71 ++++++++ .../domain/component/callback/provider_set.go | 5 + internal/domain/component/components.go | 15 ++ internal/domain/component/provider_set.go | 14 ++ internal/domain/repo/adapter.go | 29 +++ internal/domain/repo/provider_set.go | 5 + internal/domain/repo/repos.go | 17 ++ internal/domain/repo/session.go | 11 ++ internal/domain/workflow/provider_set.go | 12 +- internal/domain/workflow/registry.go | 2 + internal/domain/workflow/runtime/registry.go | 4 + .../workflow/zltx/bug_optimization_submit.go | 170 ++++++++++++++++++ .../zltx/order_after_reseller_batch.go | 106 ++++++++--- internal/entitys/recognize.go | 18 ++ internal/services/callback.go | 22 ++- 16 files changed, 475 insertions(+), 35 deletions(-) create mode 100644 internal/domain/component/callback/manager.go create mode 100644 internal/domain/component/callback/provider_set.go create mode 100644 internal/domain/component/components.go create mode 100644 internal/domain/component/provider_set.go create mode 100644 internal/domain/repo/adapter.go create mode 100644 internal/domain/repo/provider_set.go create mode 100644 internal/domain/repo/repos.go create mode 100644 internal/domain/repo/session.go create mode 100644 internal/domain/workflow/zltx/bug_optimization_submit.go diff --git a/cmd/server/wire.go b/cmd/server/wire.go index a35fa9b..8357aae 100644 --- a/cmd/server/wire.go +++ b/cmd/server/wire.go @@ -9,11 +9,14 @@ import ( "ai_scheduler/internal/biz/tools_regis" "ai_scheduler/internal/config" "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/domain/component" + "ai_scheduler/internal/domain/repo" "ai_scheduler/internal/domain/workflow" "ai_scheduler/internal/pkg" "ai_scheduler/internal/server" "ai_scheduler/internal/services" - "ai_scheduler/internal/tool_callback" + + // "ai_scheduler/internal/tool_callback" "ai_scheduler/internal/tools" "ai_scheduler/utils" @@ -34,7 +37,9 @@ func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), erro utils.ProviderUtils, dingtalk.ProviderSetDingTalk, tools_regis.ProviderToolsRegis, - tool_callback.ProviderSetCallBackTools, + // tool_callback.ProviderSetCallBackTools, + component.ProviderSet, + repo.ProviderSet, )) } diff --git a/internal/domain/component/callback/manager.go b/internal/domain/component/callback/manager.go new file mode 100644 index 0000000..7803b09 --- /dev/null +++ b/internal/domain/component/callback/manager.go @@ -0,0 +1,71 @@ +package callback + +import ( + "context" + "fmt" + "time" + + "ai_scheduler/internal/pkg" + + "github.com/redis/go-redis/v9" +) + +type Manager interface { + Register(ctx context.Context, taskID string, sessionID string) error + Wait(ctx context.Context, taskID string, timeout time.Duration) (string, error) + Notify(ctx context.Context, taskID string, result string) error + GetSession(ctx context.Context, taskID string) (string, error) +} + +type RedisManager struct { + rdb *redis.Client +} + +func NewRedisManager(rdb *pkg.Rdb) *RedisManager { + return &RedisManager{ + rdb: rdb.Rdb, + } +} + +const ( + keyPrefixSession = "callback:session:" + keyPrefixSignal = "callback:signal:" + defaultTTL = 24 * time.Hour +) + +func (m *RedisManager) Register(ctx context.Context, taskID string, sessionID string) error { + key := keyPrefixSession + taskID + return m.rdb.Set(ctx, key, sessionID, defaultTTL).Err() +} + +func (m *RedisManager) Wait(ctx context.Context, taskID string, timeout time.Duration) (string, error) { + key := keyPrefixSignal + taskID + // BLPop 阻塞等待 + result, err := m.rdb.BLPop(ctx, timeout, key).Result() + if err != nil { + if err == redis.Nil { + return "", fmt.Errorf("timeout waiting for callback") + } + return "", err + } + // result[0] is key, result[1] is value + if len(result) < 2 { + return "", fmt.Errorf("invalid redis result") + } + return result[1], nil +} + +func (m *RedisManager) Notify(ctx context.Context, taskID string, result string) error { + key := keyPrefixSignal + taskID + // Push 信号,同时设置过期时间防止堆积 + pipe := m.rdb.Pipeline() + pipe.RPush(ctx, key, result) + pipe.Expire(ctx, key, 1*time.Hour) // 信号列表也需要过期 + _, err := pipe.Exec(ctx) + return err +} + +func (m *RedisManager) GetSession(ctx context.Context, taskID string) (string, error) { + key := keyPrefixSession + taskID + return m.rdb.Get(ctx, key).Result() +} diff --git a/internal/domain/component/callback/provider_set.go b/internal/domain/component/callback/provider_set.go new file mode 100644 index 0000000..302b5c1 --- /dev/null +++ b/internal/domain/component/callback/provider_set.go @@ -0,0 +1,5 @@ +package callback + +import "github.com/google/wire" + +var ProviderSet = wire.NewSet(NewRedisManager, wire.Bind(new(Manager), new(*RedisManager))) diff --git a/internal/domain/component/components.go b/internal/domain/component/components.go new file mode 100644 index 0000000..11c8d86 --- /dev/null +++ b/internal/domain/component/components.go @@ -0,0 +1,15 @@ +package component + +import ( + "ai_scheduler/internal/domain/component/callback" +) + +type Components struct { + Callback callback.Manager +} + +func NewComponents(callbackManager callback.Manager) *Components { + return &Components{ + Callback: callbackManager, + } +} diff --git a/internal/domain/component/provider_set.go b/internal/domain/component/provider_set.go new file mode 100644 index 0000000..9d6abe6 --- /dev/null +++ b/internal/domain/component/provider_set.go @@ -0,0 +1,14 @@ +package component + +import ( + "ai_scheduler/internal/domain/component/callback" + + "github.com/google/wire" +) + +var ProviderSetComponent = wire.NewSet(NewComponents) + +var ProviderSet = wire.NewSet( + callback.NewRedisManager, wire.Bind(new(callback.Manager), new(*callback.RedisManager)), + NewComponents, +) diff --git a/internal/domain/repo/adapter.go b/internal/domain/repo/adapter.go new file mode 100644 index 0000000..c9a1357 --- /dev/null +++ b/internal/domain/repo/adapter.go @@ -0,0 +1,29 @@ +package repo + +import ( + "ai_scheduler/internal/data/impl" + "context" + "errors" +) + +// SessionAdapter 适配 impl.SessionImpl 到 SessionRepo 接口 +type SessionAdapter struct { + impl *impl.SessionImpl +} + +func NewSessionAdapter(impl *impl.SessionImpl) *SessionAdapter { + return &SessionAdapter{impl: impl} +} + +func (s *SessionAdapter) GetUserName(ctx context.Context, sessionID string) (string, error) { + // 复用 SessionImpl 的查询能力 + // 这里假设 sessionID 是唯一的,直接用 FindOne + session, has, err := s.impl.FindOne(s.impl.WithSessionId(sessionID)) + if err != nil { + return "", err + } + if !has { + return "", errors.New("session not found") + } + return session.UserName, nil +} diff --git a/internal/domain/repo/provider_set.go b/internal/domain/repo/provider_set.go new file mode 100644 index 0000000..c5b2437 --- /dev/null +++ b/internal/domain/repo/provider_set.go @@ -0,0 +1,5 @@ +package repo + +import "github.com/google/wire" + +var ProviderSet = wire.NewSet(NewRepos) diff --git a/internal/domain/repo/repos.go b/internal/domain/repo/repos.go new file mode 100644 index 0000000..40ba3de --- /dev/null +++ b/internal/domain/repo/repos.go @@ -0,0 +1,17 @@ +package repo + +import ( + "ai_scheduler/internal/data/impl" + "ai_scheduler/utils" +) + +// Repos 聚合所有 Repository +type Repos struct { + Session SessionRepo +} + +func NewRepos(sessionImpl *impl.SessionImpl, rdb *utils.Rdb) *Repos { + return &Repos{ + Session: NewSessionAdapter(sessionImpl), + } +} diff --git a/internal/domain/repo/session.go b/internal/domain/repo/session.go new file mode 100644 index 0000000..5ccc66c --- /dev/null +++ b/internal/domain/repo/session.go @@ -0,0 +1,11 @@ +package repo + +import ( + "context" +) + +// SessionRepo 定义会话相关的查询接口 +// 这里只暴露 workflow 真正需要的方法,避免直接依赖 impl 层 +type SessionRepo interface { + GetUserName(ctx context.Context, sessionID string) (string, error) +} diff --git a/internal/domain/workflow/provider_set.go b/internal/domain/workflow/provider_set.go index 97e1b5d..b9a2815 100644 --- a/internal/domain/workflow/provider_set.go +++ b/internal/domain/workflow/provider_set.go @@ -2,6 +2,8 @@ package workflow import ( "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/component" + "ai_scheduler/internal/domain/repo" "ai_scheduler/internal/domain/workflow/runtime" "ai_scheduler/internal/pkg/utils_ollama" @@ -13,9 +15,15 @@ import ( var ProviderSetWorkflow = wire.NewSet(NewRegistry) // NewRegistry 注入共享依赖并注册默认 Registry,确保自注册工作流可被发现 -func NewRegistry(conf *config.Config, llm *utils_ollama.Client) *runtime.Registry { +func NewRegistry(conf *config.Config, llm *utils_ollama.Client, repos *repo.Repos, components *component.Components) *runtime.Registry { // 步骤1:设置运行时依赖(配置与LLM客户端),供工作流工厂在首次实例化时使用;必须在任何调用 Invoke 之前完成,否则会触发 "deps not set" - runtime.SetDeps(&runtime.Deps{Conf: conf, LLM: llm, ToolManager: toolManager.NewManager(conf)}) + runtime.SetDeps(&runtime.Deps{ + Conf: conf, + LLM: llm, + ToolManager: toolManager.NewManager(conf), + Repos: repos, + Component: components, + }) // 步骤2:创建新的工作流注册表;注册表负责按工作流ID惰性实例化并缓存单例实例,保障并发访问下的安全 r := runtime.NewRegistry() // 步骤3:将该注册表设置为全局默认,便于通过 runtime.Default() 获取;自注册的工作流可通过默认注册表被发现并调用 diff --git a/internal/domain/workflow/registry.go b/internal/domain/workflow/registry.go index af69a03..10b24ef 100644 --- a/internal/domain/workflow/registry.go +++ b/internal/domain/workflow/registry.go @@ -2,6 +2,7 @@ package workflow import ( "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/component" toolManager "ai_scheduler/internal/domain/tools" "ai_scheduler/internal/pkg/utils_ollama" ) @@ -11,4 +12,5 @@ type Deps struct { Conf *config.Config LLM *utils_ollama.Client ToolManager *toolManager.Manager + Component *component.Components } diff --git a/internal/domain/workflow/runtime/registry.go b/internal/domain/workflow/runtime/registry.go index 2b4049b..f804e1d 100644 --- a/internal/domain/workflow/runtime/registry.go +++ b/internal/domain/workflow/runtime/registry.go @@ -2,6 +2,8 @@ package runtime import ( "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/component" + "ai_scheduler/internal/domain/repo" toolManager "ai_scheduler/internal/domain/tools" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/utils_ollama" @@ -20,6 +22,8 @@ type Deps struct { Conf *config.Config LLM *utils_ollama.Client ToolManager *toolManager.Manager + Component *component.Components // 基础设施能力 + Repos *repo.Repos // 数据访问 } type Factory func(deps *Deps) (Workflow, error) diff --git a/internal/domain/workflow/zltx/bug_optimization_submit.go b/internal/domain/workflow/zltx/bug_optimization_submit.go new file mode 100644 index 0000000..30ad0bc --- /dev/null +++ b/internal/domain/workflow/zltx/bug_optimization_submit.go @@ -0,0 +1,170 @@ +package zltx + +import ( + "context" + "encoding/json" + "errors" + "time" + + "ai_scheduler/internal/domain/component/callback" + "ai_scheduler/internal/domain/repo" + "ai_scheduler/internal/domain/workflow/runtime" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/l_request" + + "github.com/cloudwego/eino/compose" + "github.com/google/uuid" + "github.com/redis/go-redis/v9" +) + +const WorkflowIDBugOptimizationSubmit = "bug_optimization_submit" + +func init() { + runtime.Register(WorkflowIDBugOptimizationSubmit, func(d *runtime.Deps) (runtime.Workflow, error) { + // 从 Deps.Repos 获取 SessionRepo + return &bugOptimizationSubmit{ + manager: d.Component.Callback, + sessionRepo: d.Repos.Session, + }, nil + }) +} + +type bugOptimizationSubmit struct { + manager callback.Manager + sessionRepo repo.SessionRepo + redisCli *redis.Client +} + +func (w *bugOptimizationSubmit) ID() string { + return WorkflowIDBugOptimizationSubmit +} + +type BugOptimizationSubmitInput struct { + Ch chan entitys.Response + RequireData *entitys.Recognize +} + +type BugOptimizationSubmitOutput struct { + Msg string +} + +type contextWithTask struct { + Input *BugOptimizationSubmitInput + TaskID string +} + +func (w *bugOptimizationSubmit) Invoke(ctx context.Context, recognize *entitys.Recognize) (map[string]any, error) { + chain, err := w.buildWorkflow(ctx) + if err != nil { + return nil, err + } + + input := &BugOptimizationSubmitInput{ + Ch: recognize.Ch, + RequireData: recognize, + } + + out, err := chain.Invoke(ctx, input) + if err != nil { + return nil, err + } + + return map[string]any{"msg": out.Msg}, nil +} + +func (w *bugOptimizationSubmit) buildWorkflow(ctx context.Context) (compose.Runnable[*BugOptimizationSubmitInput, *BugOptimizationSubmitOutput], error) { + c := compose.NewChain[*BugOptimizationSubmitInput, *BugOptimizationSubmitOutput]() + + // Node 1: Prepare and Call + c.AppendLambda(compose.InvokableLambda(w.prepareAndCall)) + + // Node 2: Wait + c.AppendLambda(compose.InvokableLambda(w.waitCallback)) + + return c.Compile(ctx) +} + +func (w *bugOptimizationSubmit) prepareAndCall(ctx context.Context, in *BugOptimizationSubmitInput) (*contextWithTask, error) { + // 生成 TaskID + taskID := uuid.New().String() + + // Ext 中获取 sessionId + sessionID := in.RequireData.GetSession() + + // 注册回调映射 + if err := w.manager.Register(ctx, taskID, sessionID); err != nil { + return nil, err + } + + // 查询用户名 + userName := "unknown" + if w.sessionRepo != nil { + name, err := w.sessionRepo.GetUserName(ctx, sessionID) + if err == nil && name != "" { + userName = name + } + } + + // 构建请求参数 + var fileUrls, fileContent string + if len(in.RequireData.UserContent.File) > 0 { + for _, file := range in.RequireData.UserContent.File { + fileUrls += file.FileUrl + "," + fileContent += file.FileRec + "," + } + fileUrls = fileUrls[:len(fileUrls)-1] + fileContent = fileContent[:len(fileContent)-1] + } + + body := map[string]string{ + "mark": in.RequireData.Match.Index, + "text": in.RequireData.UserContent.Text, + "img": fileUrls, + "img_content": fileContent, + "creator": userName, + "task_id": taskID, + } + + request := l_request.Request{ + Url: "https://connector.dingtalk.com/webhook/flow/10352c521dd02104cee9000c", + Method: "POST", + Headers: map[string]string{ + "Content-Type": "application/json", + }, + JsonByte: pkg.JsonByteIgonErr(body), + } + + res, err := request.Send() + if err != nil { + return nil, err + } + + var data map[string]any + if err := json.Unmarshal(res.Content, &data); err != nil { + return nil, err + } + + if success, ok := data["success"].(bool); !ok || !success { + return nil, errors.New("dingtalk flow failed") + } + + entitys.ResLog(in.Ch, in.RequireData.Match.Index, "问题记录中") + entitys.ResLoading(in.Ch, in.RequireData.Match.Index, "问题记录中...") + + return &contextWithTask{Input: in, TaskID: taskID}, nil +} + +func (w *bugOptimizationSubmit) waitCallback(ctx context.Context, in *contextWithTask) (*BugOptimizationSubmitOutput, error) { + // 阻塞等待回调信号 + // 设置 5 分钟超时 + waitCtx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + + res, err := w.manager.Wait(waitCtx, in.TaskID, 5*time.Minute) + if err != nil { + return nil, err + } + + return &BugOptimizationSubmitOutput{Msg: res}, nil +} diff --git a/internal/domain/workflow/zltx/order_after_reseller_batch.go b/internal/domain/workflow/zltx/order_after_reseller_batch.go index ff21d84..eee022a 100644 --- a/internal/domain/workflow/zltx/order_after_reseller_batch.go +++ b/internal/domain/workflow/zltx/order_after_reseller_batch.go @@ -22,8 +22,7 @@ func init() { } type orderAfterSaleResellerBatch struct { - cfg config.ToolConfig - data *OrderAfterSaleResellerBatchWorkflowInput + cfg config.ToolConfig } // 工作流入参 @@ -86,15 +85,19 @@ func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, rec *entitys.R return nil, err } - o.data = &OrderAfterSaleResellerBatchWorkflowInput{ + input := &OrderAfterSaleResellerBatchWorkflowInput{ Ch: rec.Ch, UserInput: rec.UserContent.Text, FileContent: "", UserHistory: rec.ChatHis, ParameterResult: rec.Match.Parameters, } + + // 将 Input 注入 Context + ctx = context.WithValue(ctx, workflowInputContextKey{}, input) + // 工作流过程输出,不关注最终输出 - _, err = chain.Invoke(ctx, o.data) + _, err = chain.Invoke(ctx, input) if err != nil { return nil, err } @@ -107,6 +110,9 @@ func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, rec *entitys.R var ErrInvalidOrderNumbers = errors.New("orderNumber 不能为空") +// contextKey 用于在 Context 中传递 WorkflowInput +type workflowInputContextKey struct{} + // buildWorkflow 构建工作流 func (o *orderAfterSaleResellerBatch) buildWorkflow(ctx context.Context) (compose.Runnable[*OrderAfterSaleResellerBatchWorkflowInput, *OrderAfterSaleResellerBatchWorkflowOutput], error) { // 定义工作流、出入参 @@ -127,39 +133,93 @@ func (o *orderAfterSaleResellerBatch) buildWorkflow(ctx context.Context) (compos }), )) + // 3.参数校验 & 传递 Input + // 注意:为了在后续节点访问 WorkflowInput,这里使用闭包或 Context 传递。 + // Eino Chain 节点间传递的是返回值。这里我们修改节点签名,将 input 一路传下去,或者使用 context。 + // 由于 Eino Chain 是强类型的,这里选择让 Parser 返回的数据结构包含原始 input,或者我们在 Parser 后重新组合。 + // 但最简单的方法是使用 Context 存储 Input (如果 Eino 支持 Context 传递)。Eino 的 Invoke 接受 ctx。 + // 但 Eino Chain 的设计是数据流驱动。 + // 修正方案:修改中间节点的数据结构,或者使用闭包捕获(但闭包捕获的是 build 时的变量,无法捕获运行时 input)。 + // 正确做法:Chain 的节点入参必须是上一个节点的出参。 + // 我们可以把 Parser 的输入改为 Input,输出改为一个包含 Input 和 ParsedData 的结构。 + // 但这里为了最小改动,我们利用 Context 来传递 Input 引用(这在 Eino 中是可行的,因为 ctx 会贯穿整个 Invoke)。 + // 更好的做法是重构 Chain 的数据流,但在保持逻辑不变的前提下,Context 是最快解法。 + + // 为了线程安全,我们在第一个节点把 Input 放入 Context?不行,Chain.Invoke(ctx, input) 的 ctx 是外部传入的。 + // Eino 允许 Lambda 修改 Context 吗?通常不允许。 + + // 让我们重新审视数据流: + // Input -> Lambda1 -> Message -> Parser -> NodeData -> Lambda4 -> ToolResp -> Lambda5 -> Output + // Lambda4 需要 Input.Ch 来发 Loading。 + // Lambda5 需要 Input.Ch 来发 Log/Json,还需要 NodeData。 + + // 根本问题是:中间节点丢失了 Input 信息。 + // 解决方案:使用一个聚合结构体在 Chain 中传递。 + + // 由于要大改数据流比较复杂,这里使用一种技巧: + // 在 Invoke 时,构造一个带有 Input 信息的 Context 传入。 + // 这样每个节点都能从 Context 拿到 Input。 + + // 重新实现 buildWorkflow 以支持 Context 传递 + return o.buildWorkflowWithContext(ctx) +} + +func (o *orderAfterSaleResellerBatch) buildWorkflowWithContext(ctx context.Context) (compose.Runnable[*OrderAfterSaleResellerBatchWorkflowInput, *OrderAfterSaleResellerBatchWorkflowOutput], error) { + c := compose.NewChain[*OrderAfterSaleResellerBatchWorkflowInput, *OrderAfterSaleResellerBatchWorkflowOutput]() + + // 0. Context 注入节点 (Trick: 利用第一个节点将 Input 注入 Context,但 Eino Chain 无法修改 Context 传递给下游) + // 实际上,我们可以在 Invoke 调用前,在外部包装 Context。 + // 所以这里不需要额外的节点,只需要在 Invoke 时处理。 + // 但 Invoke 是由 Chain 提供的,我们只能控制传入的 ctx。 + // 见下文 Invoke 方法的修改。 + + // 1.llm 推断参数 + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in *OrderAfterSaleResellerBatchWorkflowInput) (*schema.Message, error) { + return &schema.Message{Content: in.ParameterResult}, nil + })) + + // 2.参数解析 + c.AppendLambda(compose.MessageParser( + schema.NewMessageJSONParser[*OrderAfterSaleResellerBatchNodeData](&schema.MessageJSONParseConfig{ + ParseFrom: schema.MessageParseFromContent, + }), + )) + // 3.参数校验 - c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *OrderAfterSaleResellerBatchNodeData) (*OrderAfterSaleResellerBatchNodeData, error) { - // 校验必填项 + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in *OrderAfterSaleResellerBatchNodeData) (*OrderAfterSaleResellerBatchNodeData, error) { if len(in.OrderNumber) == 0 { return nil, ErrInvalidOrderNumbers } - - o.data.Data = in - + // 将解析后的 Data 存入 Input (通过 Context 获取 Input) + input := ctx.Value(workflowInputContextKey{}).(*OrderAfterSaleResellerBatchWorkflowInput) + input.Data = in // 这里修改 Input 是安全的,因为 Input 是请求维度的引用 return in, nil })) // 4.工具调用 - c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *OrderAfterSaleResellerBatchNodeData) (*toolZoarb.OrderAfterSaleResellerBatchResponse, error) { - entitys.ResLoading(o.data.Ch, o.ID(), "数据拉取中") + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in *OrderAfterSaleResellerBatchNodeData) (*toolZoarb.OrderAfterSaleResellerBatchResponse, error) { + input := ctx.Value(workflowInputContextKey{}).(*OrderAfterSaleResellerBatchWorkflowInput) + entitys.ResLoading(input.Ch, o.ID(), "数据拉取中") toolRes, err := toolZoarb.Call(ctx, o.cfg, in.OrderNumber) - - entitys.ResLog(o.data.Ch, o.ID(), "数据拉取完成") + entitys.ResLog(input.Ch, o.ID(), "数据拉取完成") return toolRes, err })) // 5.结果数据映射 - c.AppendLambda(compose.InvokableLambda(o.dataMapping)) + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in *toolZoarb.OrderAfterSaleResellerBatchResponse) (*OrderAfterSaleResellerBatchWorkflowOutput, error) { + return o.dataMapping(ctx, in) + })) - // 编译工作流 return c.Compile(ctx) } // 结果数据映射 -func (o *orderAfterSaleResellerBatch) dataMapping(_ context.Context, in *toolZoarb.OrderAfterSaleResellerBatchResponse) (*OrderAfterSaleResellerBatchWorkflowOutput, error) { - entitys.ResLog(o.data.Ch, o.ID(), "数据整理中") +func (o *orderAfterSaleResellerBatch) dataMapping(ctx context.Context, in *toolZoarb.OrderAfterSaleResellerBatchResponse) (*OrderAfterSaleResellerBatchWorkflowOutput, error) { + input := ctx.Value(workflowInputContextKey{}).(*OrderAfterSaleResellerBatchWorkflowInput) + + entitys.ResLog(input.Ch, o.ID(), "数据整理中") toolResp := &OrderAfterSaleResellerBatchWorkflowOutput{ Code: in.Code, @@ -170,17 +230,17 @@ func (o *orderAfterSaleResellerBatch) dataMapping(_ context.Context, in *toolZoa // 转换数据 for _, item := range in.Data.Data { // 处理方式 - afterType := util.StringToInt(o.data.Data.AfterType) + afterType := util.StringToInt(input.Data.AfterType) if afterType == 0 { afterType = 1 // 默认退款 } // 费用承担者 - responsibleType := util.StringToInt(o.data.Data.ResponsibleType) + responsibleType := util.StringToInt(input.Data.ResponsibleType) if responsibleType == 0 { responsibleType = 4 // 默认无 } // 售后金额 - afterSalesPrice := util.StringToFloat64(o.data.Data.AfterSalesPrice) + afterSalesPrice := util.StringToFloat64(input.Data.AfterSalesPrice) if afterSalesPrice == 0 { afterSalesPrice = item.OrderPrice } @@ -199,10 +259,10 @@ func (o *orderAfterSaleResellerBatch) dataMapping(_ context.Context, in *toolZoa Account: item.Account, Platforms: item.Platforms, AfterType: afterType, - Remark: o.data.Data.AfterSalesReason, + Remark: input.Data.AfterSalesReason, AfterAmount: afterSalesPrice, ResponsibleType: responsibleType, - ResponsiblePerson: o.data.Data.ResponsiblePerson, + ResponsiblePerson: input.Data.ResponsiblePerson, }) } @@ -215,7 +275,7 @@ func (o *orderAfterSaleResellerBatch) dataMapping(_ context.Context, in *toolZoa } toolRespJson, _ := json.Marshal(toolResp) - entitys.ResJson(o.data.Ch, o.ID(), string(toolRespJson)) + entitys.ResJson(input.Ch, o.ID(), string(toolRespJson)) return toolResp, nil } diff --git a/internal/entitys/recognize.go b/internal/entitys/recognize.go index 87f684c..fcd2fe5 100644 --- a/internal/entitys/recognize.go +++ b/internal/entitys/recognize.go @@ -3,6 +3,7 @@ package entitys import ( "ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/model" + "encoding/json" ) type Recognize struct { @@ -47,3 +48,20 @@ type RecognizeFile struct { FileRealMime string // 文件真实MIME类型 FileUrl string // 文件下载链接 } + +func (r *Recognize) GetTaskExt() *TaskExt { + var ext TaskExt + if err := json.Unmarshal(r.Ext, &ext); err != nil { + return nil + } + + return &ext +} + +func (r *Recognize) GetSession() string { + ext := r.GetTaskExt() + if ext == nil { + return "" + } + return ext.Session +} diff --git a/internal/services/callback.go b/internal/services/callback.go index e224ce3..c68e1c3 100644 --- a/internal/services/callback.go +++ b/internal/services/callback.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" errorcode "ai_scheduler/internal/data/error" + "ai_scheduler/internal/domain/component/callback" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" "ai_scheduler/internal/pkg" @@ -25,17 +26,17 @@ type CallbackService struct { dingtalkOldClient *dingtalk.OldClient dingtalkContactClient *dingtalk.ContactClient dingtalkNotableClient *dingtalk.NotableClient - callBackTool *tool_callback.CallBackTool + callbackManager callback.Manager } -func NewCallbackService(cfg *config.Config, gateway *gateway.Gateway, dingtalkOldClient *dingtalk.OldClient, dingtalkContactClient *dingtalk.ContactClient, dingtalkNotableClient *dingtalk.NotableClient, callBackTool *tool_callback.CallBackTool) *CallbackService { +func NewCallbackService(cfg *config.Config, gateway *gateway.Gateway, dingtalkOldClient *dingtalk.OldClient, dingtalkContactClient *dingtalk.ContactClient, dingtalkNotableClient *dingtalk.NotableClient, callbackManager callback.Manager) *CallbackService { return &CallbackService{ cfg: cfg, gateway: gateway, dingtalkOldClient: dingtalkOldClient, dingtalkContactClient: dingtalkContactClient, dingtalkNotableClient: dingtalkNotableClient, - callBackTool: callBackTool, + callbackManager: callbackManager, } } @@ -139,11 +140,14 @@ func (s *CallbackService) Callback(c *fiber.Ctx) error { func (s *CallbackService) handleDingTalkCallback(c *fiber.Ctx, env Envelope) error { // 校验taskId - sessionID, ok := s.callBackTool.GetSessionByTaskID(env.TaskID) - if !ok { + ctx := c.Context() + sessionID, err := s.callbackManager.GetSession(ctx, env.TaskID) + if err != nil { + return errorcode.ParamErrf("failed to get session for task_id: %s, err: %v", env.TaskID, err) + } + if sessionID == "" { return errorcode.ParamErrf("missing session_id for task_id: %s", env.TaskID) } - ctx := c.Context() switch env.Action { case ActionBugOptimizationSubmitUpdate: @@ -166,8 +170,10 @@ func (s *CallbackService) handleDingTalkCallback(c *fiber.Ctx, env Envelope) err // 发送日志 s.sendStreamTxt(sessionID, msg) - // 删除映射 - s.callBackTool.DelTaskMapping(env.TaskID) + // 通知等待者 + if err := s.callbackManager.Notify(ctx, env.TaskID, msg); err != nil { + // 记录错误但继续 + } return c.JSON(fiber.Map{"code": 0, "message": "ok"}) case ActionBugOptimizationSubmitProcess: