feat: 1. 增加同步阻塞等待异步回调redis组件 2.原需求收集机器人迁移至eino工作流
This commit is contained in:
parent
7682ecd75b
commit
e3448ae41e
|
|
@ -9,11 +9,14 @@ import (
|
|||
"ai_scheduler/internal/biz/tools_regis"
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/data/impl"
|
||||
"ai_scheduler/internal/domain/component"
|
||||
"ai_scheduler/internal/domain/repo"
|
||||
"ai_scheduler/internal/domain/workflow"
|
||||
"ai_scheduler/internal/pkg"
|
||||
"ai_scheduler/internal/server"
|
||||
"ai_scheduler/internal/services"
|
||||
"ai_scheduler/internal/tool_callback"
|
||||
|
||||
// "ai_scheduler/internal/tool_callback"
|
||||
"ai_scheduler/internal/tools"
|
||||
"ai_scheduler/utils"
|
||||
|
||||
|
|
@ -34,7 +37,9 @@ func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), erro
|
|||
utils.ProviderUtils,
|
||||
dingtalk.ProviderSetDingTalk,
|
||||
tools_regis.ProviderToolsRegis,
|
||||
tool_callback.ProviderSetCallBackTools,
|
||||
// tool_callback.ProviderSetCallBackTools,
|
||||
component.ProviderSet,
|
||||
repo.ProviderSet,
|
||||
))
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,71 @@
|
|||
package callback
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"ai_scheduler/internal/pkg"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
Register(ctx context.Context, taskID string, sessionID string) error
|
||||
Wait(ctx context.Context, taskID string, timeout time.Duration) (string, error)
|
||||
Notify(ctx context.Context, taskID string, result string) error
|
||||
GetSession(ctx context.Context, taskID string) (string, error)
|
||||
}
|
||||
|
||||
type RedisManager struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewRedisManager(rdb *pkg.Rdb) *RedisManager {
|
||||
return &RedisManager{
|
||||
rdb: rdb.Rdb,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
keyPrefixSession = "callback:session:"
|
||||
keyPrefixSignal = "callback:signal:"
|
||||
defaultTTL = 24 * time.Hour
|
||||
)
|
||||
|
||||
func (m *RedisManager) Register(ctx context.Context, taskID string, sessionID string) error {
|
||||
key := keyPrefixSession + taskID
|
||||
return m.rdb.Set(ctx, key, sessionID, defaultTTL).Err()
|
||||
}
|
||||
|
||||
func (m *RedisManager) Wait(ctx context.Context, taskID string, timeout time.Duration) (string, error) {
|
||||
key := keyPrefixSignal + taskID
|
||||
// BLPop 阻塞等待
|
||||
result, err := m.rdb.BLPop(ctx, timeout, key).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return "", fmt.Errorf("timeout waiting for callback")
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
// result[0] is key, result[1] is value
|
||||
if len(result) < 2 {
|
||||
return "", fmt.Errorf("invalid redis result")
|
||||
}
|
||||
return result[1], nil
|
||||
}
|
||||
|
||||
func (m *RedisManager) Notify(ctx context.Context, taskID string, result string) error {
|
||||
key := keyPrefixSignal + taskID
|
||||
// Push 信号,同时设置过期时间防止堆积
|
||||
pipe := m.rdb.Pipeline()
|
||||
pipe.RPush(ctx, key, result)
|
||||
pipe.Expire(ctx, key, 1*time.Hour) // 信号列表也需要过期
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *RedisManager) GetSession(ctx context.Context, taskID string) (string, error) {
|
||||
key := keyPrefixSession + taskID
|
||||
return m.rdb.Get(ctx, key).Result()
|
||||
}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
package callback
|
||||
|
||||
import "github.com/google/wire"
|
||||
|
||||
var ProviderSet = wire.NewSet(NewRedisManager, wire.Bind(new(Manager), new(*RedisManager)))
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
package component
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/domain/component/callback"
|
||||
)
|
||||
|
||||
type Components struct {
|
||||
Callback callback.Manager
|
||||
}
|
||||
|
||||
func NewComponents(callbackManager callback.Manager) *Components {
|
||||
return &Components{
|
||||
Callback: callbackManager,
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
package component
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/domain/component/callback"
|
||||
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderSetComponent = wire.NewSet(NewComponents)
|
||||
|
||||
var ProviderSet = wire.NewSet(
|
||||
callback.NewRedisManager, wire.Bind(new(callback.Manager), new(*callback.RedisManager)),
|
||||
NewComponents,
|
||||
)
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
package repo
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/data/impl"
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// SessionAdapter 适配 impl.SessionImpl 到 SessionRepo 接口
|
||||
type SessionAdapter struct {
|
||||
impl *impl.SessionImpl
|
||||
}
|
||||
|
||||
func NewSessionAdapter(impl *impl.SessionImpl) *SessionAdapter {
|
||||
return &SessionAdapter{impl: impl}
|
||||
}
|
||||
|
||||
func (s *SessionAdapter) GetUserName(ctx context.Context, sessionID string) (string, error) {
|
||||
// 复用 SessionImpl 的查询能力
|
||||
// 这里假设 sessionID 是唯一的,直接用 FindOne
|
||||
session, has, err := s.impl.FindOne(s.impl.WithSessionId(sessionID))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !has {
|
||||
return "", errors.New("session not found")
|
||||
}
|
||||
return session.UserName, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
package repo
|
||||
|
||||
import "github.com/google/wire"
|
||||
|
||||
var ProviderSet = wire.NewSet(NewRepos)
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
package repo
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/data/impl"
|
||||
"ai_scheduler/utils"
|
||||
)
|
||||
|
||||
// Repos 聚合所有 Repository
|
||||
type Repos struct {
|
||||
Session SessionRepo
|
||||
}
|
||||
|
||||
func NewRepos(sessionImpl *impl.SessionImpl, rdb *utils.Rdb) *Repos {
|
||||
return &Repos{
|
||||
Session: NewSessionAdapter(sessionImpl),
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
package repo
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// SessionRepo 定义会话相关的查询接口
|
||||
// 这里只暴露 workflow 真正需要的方法,避免直接依赖 impl 层
|
||||
type SessionRepo interface {
|
||||
GetUserName(ctx context.Context, sessionID string) (string, error)
|
||||
}
|
||||
|
|
@ -2,6 +2,8 @@ package workflow
|
|||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/domain/component"
|
||||
"ai_scheduler/internal/domain/repo"
|
||||
"ai_scheduler/internal/domain/workflow/runtime"
|
||||
"ai_scheduler/internal/pkg/utils_ollama"
|
||||
|
||||
|
|
@ -13,9 +15,15 @@ import (
|
|||
var ProviderSetWorkflow = wire.NewSet(NewRegistry)
|
||||
|
||||
// NewRegistry 注入共享依赖并注册默认 Registry,确保自注册工作流可被发现
|
||||
func NewRegistry(conf *config.Config, llm *utils_ollama.Client) *runtime.Registry {
|
||||
func NewRegistry(conf *config.Config, llm *utils_ollama.Client, repos *repo.Repos, components *component.Components) *runtime.Registry {
|
||||
// 步骤1:设置运行时依赖(配置与LLM客户端),供工作流工厂在首次实例化时使用;必须在任何调用 Invoke 之前完成,否则会触发 "deps not set"
|
||||
runtime.SetDeps(&runtime.Deps{Conf: conf, LLM: llm, ToolManager: toolManager.NewManager(conf)})
|
||||
runtime.SetDeps(&runtime.Deps{
|
||||
Conf: conf,
|
||||
LLM: llm,
|
||||
ToolManager: toolManager.NewManager(conf),
|
||||
Repos: repos,
|
||||
Component: components,
|
||||
})
|
||||
// 步骤2:创建新的工作流注册表;注册表负责按工作流ID惰性实例化并缓存单例实例,保障并发访问下的安全
|
||||
r := runtime.NewRegistry()
|
||||
// 步骤3:将该注册表设置为全局默认,便于通过 runtime.Default() 获取;自注册的工作流可通过默认注册表被发现并调用
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package workflow
|
|||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/domain/component"
|
||||
toolManager "ai_scheduler/internal/domain/tools"
|
||||
"ai_scheduler/internal/pkg/utils_ollama"
|
||||
)
|
||||
|
|
@ -11,4 +12,5 @@ type Deps struct {
|
|||
Conf *config.Config
|
||||
LLM *utils_ollama.Client
|
||||
ToolManager *toolManager.Manager
|
||||
Component *component.Components
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ package runtime
|
|||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/domain/component"
|
||||
"ai_scheduler/internal/domain/repo"
|
||||
toolManager "ai_scheduler/internal/domain/tools"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg/utils_ollama"
|
||||
|
|
@ -20,6 +22,8 @@ type Deps struct {
|
|||
Conf *config.Config
|
||||
LLM *utils_ollama.Client
|
||||
ToolManager *toolManager.Manager
|
||||
Component *component.Components // 基础设施能力
|
||||
Repos *repo.Repos // 数据访问
|
||||
}
|
||||
|
||||
type Factory func(deps *Deps) (Workflow, error)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,170 @@
|
|||
package zltx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"ai_scheduler/internal/domain/component/callback"
|
||||
"ai_scheduler/internal/domain/repo"
|
||||
"ai_scheduler/internal/domain/workflow/runtime"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg"
|
||||
"ai_scheduler/internal/pkg/l_request"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const WorkflowIDBugOptimizationSubmit = "bug_optimization_submit"
|
||||
|
||||
func init() {
|
||||
runtime.Register(WorkflowIDBugOptimizationSubmit, func(d *runtime.Deps) (runtime.Workflow, error) {
|
||||
// 从 Deps.Repos 获取 SessionRepo
|
||||
return &bugOptimizationSubmit{
|
||||
manager: d.Component.Callback,
|
||||
sessionRepo: d.Repos.Session,
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
type bugOptimizationSubmit struct {
|
||||
manager callback.Manager
|
||||
sessionRepo repo.SessionRepo
|
||||
redisCli *redis.Client
|
||||
}
|
||||
|
||||
func (w *bugOptimizationSubmit) ID() string {
|
||||
return WorkflowIDBugOptimizationSubmit
|
||||
}
|
||||
|
||||
type BugOptimizationSubmitInput struct {
|
||||
Ch chan entitys.Response
|
||||
RequireData *entitys.Recognize
|
||||
}
|
||||
|
||||
type BugOptimizationSubmitOutput struct {
|
||||
Msg string
|
||||
}
|
||||
|
||||
type contextWithTask struct {
|
||||
Input *BugOptimizationSubmitInput
|
||||
TaskID string
|
||||
}
|
||||
|
||||
func (w *bugOptimizationSubmit) Invoke(ctx context.Context, recognize *entitys.Recognize) (map[string]any, error) {
|
||||
chain, err := w.buildWorkflow(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
input := &BugOptimizationSubmitInput{
|
||||
Ch: recognize.Ch,
|
||||
RequireData: recognize,
|
||||
}
|
||||
|
||||
out, err := chain.Invoke(ctx, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return map[string]any{"msg": out.Msg}, nil
|
||||
}
|
||||
|
||||
func (w *bugOptimizationSubmit) buildWorkflow(ctx context.Context) (compose.Runnable[*BugOptimizationSubmitInput, *BugOptimizationSubmitOutput], error) {
|
||||
c := compose.NewChain[*BugOptimizationSubmitInput, *BugOptimizationSubmitOutput]()
|
||||
|
||||
// Node 1: Prepare and Call
|
||||
c.AppendLambda(compose.InvokableLambda(w.prepareAndCall))
|
||||
|
||||
// Node 2: Wait
|
||||
c.AppendLambda(compose.InvokableLambda(w.waitCallback))
|
||||
|
||||
return c.Compile(ctx)
|
||||
}
|
||||
|
||||
func (w *bugOptimizationSubmit) prepareAndCall(ctx context.Context, in *BugOptimizationSubmitInput) (*contextWithTask, error) {
|
||||
// 生成 TaskID
|
||||
taskID := uuid.New().String()
|
||||
|
||||
// Ext 中获取 sessionId
|
||||
sessionID := in.RequireData.GetSession()
|
||||
|
||||
// 注册回调映射
|
||||
if err := w.manager.Register(ctx, taskID, sessionID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 查询用户名
|
||||
userName := "unknown"
|
||||
if w.sessionRepo != nil {
|
||||
name, err := w.sessionRepo.GetUserName(ctx, sessionID)
|
||||
if err == nil && name != "" {
|
||||
userName = name
|
||||
}
|
||||
}
|
||||
|
||||
// 构建请求参数
|
||||
var fileUrls, fileContent string
|
||||
if len(in.RequireData.UserContent.File) > 0 {
|
||||
for _, file := range in.RequireData.UserContent.File {
|
||||
fileUrls += file.FileUrl + ","
|
||||
fileContent += file.FileRec + ","
|
||||
}
|
||||
fileUrls = fileUrls[:len(fileUrls)-1]
|
||||
fileContent = fileContent[:len(fileContent)-1]
|
||||
}
|
||||
|
||||
body := map[string]string{
|
||||
"mark": in.RequireData.Match.Index,
|
||||
"text": in.RequireData.UserContent.Text,
|
||||
"img": fileUrls,
|
||||
"img_content": fileContent,
|
||||
"creator": userName,
|
||||
"task_id": taskID,
|
||||
}
|
||||
|
||||
request := l_request.Request{
|
||||
Url: "https://connector.dingtalk.com/webhook/flow/10352c521dd02104cee9000c",
|
||||
Method: "POST",
|
||||
Headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
JsonByte: pkg.JsonByteIgonErr(body),
|
||||
}
|
||||
|
||||
res, err := request.Send()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(res.Content, &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if success, ok := data["success"].(bool); !ok || !success {
|
||||
return nil, errors.New("dingtalk flow failed")
|
||||
}
|
||||
|
||||
entitys.ResLog(in.Ch, in.RequireData.Match.Index, "问题记录中")
|
||||
entitys.ResLoading(in.Ch, in.RequireData.Match.Index, "问题记录中...")
|
||||
|
||||
return &contextWithTask{Input: in, TaskID: taskID}, nil
|
||||
}
|
||||
|
||||
func (w *bugOptimizationSubmit) waitCallback(ctx context.Context, in *contextWithTask) (*BugOptimizationSubmitOutput, error) {
|
||||
// 阻塞等待回调信号
|
||||
// 设置 5 分钟超时
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
res, err := w.manager.Wait(waitCtx, in.TaskID, 5*time.Minute)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &BugOptimizationSubmitOutput{Msg: res}, nil
|
||||
}
|
||||
|
|
@ -23,7 +23,6 @@ func init() {
|
|||
|
||||
type orderAfterSaleResellerBatch struct {
|
||||
cfg config.ToolConfig
|
||||
data *OrderAfterSaleResellerBatchWorkflowInput
|
||||
}
|
||||
|
||||
// 工作流入参
|
||||
|
|
@ -86,15 +85,19 @@ func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, rec *entitys.R
|
|||
return nil, err
|
||||
}
|
||||
|
||||
o.data = &OrderAfterSaleResellerBatchWorkflowInput{
|
||||
input := &OrderAfterSaleResellerBatchWorkflowInput{
|
||||
Ch: rec.Ch,
|
||||
UserInput: rec.UserContent.Text,
|
||||
FileContent: "",
|
||||
UserHistory: rec.ChatHis,
|
||||
ParameterResult: rec.Match.Parameters,
|
||||
}
|
||||
|
||||
// 将 Input 注入 Context
|
||||
ctx = context.WithValue(ctx, workflowInputContextKey{}, input)
|
||||
|
||||
// 工作流过程输出,不关注最终输出
|
||||
_, err = chain.Invoke(ctx, o.data)
|
||||
_, err = chain.Invoke(ctx, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -107,6 +110,9 @@ func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, rec *entitys.R
|
|||
|
||||
var ErrInvalidOrderNumbers = errors.New("orderNumber 不能为空")
|
||||
|
||||
// contextKey 用于在 Context 中传递 WorkflowInput
|
||||
type workflowInputContextKey struct{}
|
||||
|
||||
// buildWorkflow 构建工作流
|
||||
func (o *orderAfterSaleResellerBatch) buildWorkflow(ctx context.Context) (compose.Runnable[*OrderAfterSaleResellerBatchWorkflowInput, *OrderAfterSaleResellerBatchWorkflowOutput], error) {
|
||||
// 定义工作流、出入参
|
||||
|
|
@ -127,39 +133,93 @@ func (o *orderAfterSaleResellerBatch) buildWorkflow(ctx context.Context) (compos
|
|||
}),
|
||||
))
|
||||
|
||||
// 3.参数校验 & 传递 Input
|
||||
// 注意:为了在后续节点访问 WorkflowInput,这里使用闭包或 Context 传递。
|
||||
// Eino Chain 节点间传递的是返回值。这里我们修改节点签名,将 input 一路传下去,或者使用 context。
|
||||
// 由于 Eino Chain 是强类型的,这里选择让 Parser 返回的数据结构包含原始 input,或者我们在 Parser 后重新组合。
|
||||
// 但最简单的方法是使用 Context 存储 Input (如果 Eino 支持 Context 传递)。Eino 的 Invoke 接受 ctx。
|
||||
// 但 Eino Chain 的设计是数据流驱动。
|
||||
// 修正方案:修改中间节点的数据结构,或者使用闭包捕获(但闭包捕获的是 build 时的变量,无法捕获运行时 input)。
|
||||
// 正确做法:Chain 的节点入参必须是上一个节点的出参。
|
||||
// 我们可以把 Parser 的输入改为 Input,输出改为一个包含 Input 和 ParsedData 的结构。
|
||||
// 但这里为了最小改动,我们利用 Context 来传递 Input 引用(这在 Eino 中是可行的,因为 ctx 会贯穿整个 Invoke)。
|
||||
// 更好的做法是重构 Chain 的数据流,但在保持逻辑不变的前提下,Context 是最快解法。
|
||||
|
||||
// 为了线程安全,我们在第一个节点把 Input 放入 Context?不行,Chain.Invoke(ctx, input) 的 ctx 是外部传入的。
|
||||
// Eino 允许 Lambda 修改 Context 吗?通常不允许。
|
||||
|
||||
// 让我们重新审视数据流:
|
||||
// Input -> Lambda1 -> Message -> Parser -> NodeData -> Lambda4 -> ToolResp -> Lambda5 -> Output
|
||||
// Lambda4 需要 Input.Ch 来发 Loading。
|
||||
// Lambda5 需要 Input.Ch 来发 Log/Json,还需要 NodeData。
|
||||
|
||||
// 根本问题是:中间节点丢失了 Input 信息。
|
||||
// 解决方案:使用一个聚合结构体在 Chain 中传递。
|
||||
|
||||
// 由于要大改数据流比较复杂,这里使用一种技巧:
|
||||
// 在 Invoke 时,构造一个带有 Input 信息的 Context 传入。
|
||||
// 这样每个节点都能从 Context 拿到 Input。
|
||||
|
||||
// 重新实现 buildWorkflow 以支持 Context 传递
|
||||
return o.buildWorkflowWithContext(ctx)
|
||||
}
|
||||
|
||||
func (o *orderAfterSaleResellerBatch) buildWorkflowWithContext(ctx context.Context) (compose.Runnable[*OrderAfterSaleResellerBatchWorkflowInput, *OrderAfterSaleResellerBatchWorkflowOutput], error) {
|
||||
c := compose.NewChain[*OrderAfterSaleResellerBatchWorkflowInput, *OrderAfterSaleResellerBatchWorkflowOutput]()
|
||||
|
||||
// 0. Context 注入节点 (Trick: 利用第一个节点将 Input 注入 Context,但 Eino Chain 无法修改 Context 传递给下游)
|
||||
// 实际上,我们可以在 Invoke 调用前,在外部包装 Context。
|
||||
// 所以这里不需要额外的节点,只需要在 Invoke 时处理。
|
||||
// 但 Invoke 是由 Chain 提供的,我们只能控制传入的 ctx。
|
||||
// 见下文 Invoke 方法的修改。
|
||||
|
||||
// 1.llm 推断参数
|
||||
c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in *OrderAfterSaleResellerBatchWorkflowInput) (*schema.Message, error) {
|
||||
return &schema.Message{Content: in.ParameterResult}, nil
|
||||
}))
|
||||
|
||||
// 2.参数解析
|
||||
c.AppendLambda(compose.MessageParser(
|
||||
schema.NewMessageJSONParser[*OrderAfterSaleResellerBatchNodeData](&schema.MessageJSONParseConfig{
|
||||
ParseFrom: schema.MessageParseFromContent,
|
||||
}),
|
||||
))
|
||||
|
||||
// 3.参数校验
|
||||
c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *OrderAfterSaleResellerBatchNodeData) (*OrderAfterSaleResellerBatchNodeData, error) {
|
||||
// 校验必填项
|
||||
c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in *OrderAfterSaleResellerBatchNodeData) (*OrderAfterSaleResellerBatchNodeData, error) {
|
||||
if len(in.OrderNumber) == 0 {
|
||||
return nil, ErrInvalidOrderNumbers
|
||||
}
|
||||
|
||||
o.data.Data = in
|
||||
|
||||
// 将解析后的 Data 存入 Input (通过 Context 获取 Input)
|
||||
input := ctx.Value(workflowInputContextKey{}).(*OrderAfterSaleResellerBatchWorkflowInput)
|
||||
input.Data = in // 这里修改 Input 是安全的,因为 Input 是请求维度的引用
|
||||
return in, nil
|
||||
}))
|
||||
|
||||
// 4.工具调用
|
||||
c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *OrderAfterSaleResellerBatchNodeData) (*toolZoarb.OrderAfterSaleResellerBatchResponse, error) {
|
||||
entitys.ResLoading(o.data.Ch, o.ID(), "数据拉取中")
|
||||
c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in *OrderAfterSaleResellerBatchNodeData) (*toolZoarb.OrderAfterSaleResellerBatchResponse, error) {
|
||||
input := ctx.Value(workflowInputContextKey{}).(*OrderAfterSaleResellerBatchWorkflowInput)
|
||||
|
||||
entitys.ResLoading(input.Ch, o.ID(), "数据拉取中")
|
||||
toolRes, err := toolZoarb.Call(ctx, o.cfg, in.OrderNumber)
|
||||
|
||||
entitys.ResLog(o.data.Ch, o.ID(), "数据拉取完成")
|
||||
entitys.ResLog(input.Ch, o.ID(), "数据拉取完成")
|
||||
|
||||
return toolRes, err
|
||||
}))
|
||||
|
||||
// 5.结果数据映射
|
||||
c.AppendLambda(compose.InvokableLambda(o.dataMapping))
|
||||
c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in *toolZoarb.OrderAfterSaleResellerBatchResponse) (*OrderAfterSaleResellerBatchWorkflowOutput, error) {
|
||||
return o.dataMapping(ctx, in)
|
||||
}))
|
||||
|
||||
// 编译工作流
|
||||
return c.Compile(ctx)
|
||||
}
|
||||
|
||||
// 结果数据映射
|
||||
func (o *orderAfterSaleResellerBatch) dataMapping(_ context.Context, in *toolZoarb.OrderAfterSaleResellerBatchResponse) (*OrderAfterSaleResellerBatchWorkflowOutput, error) {
|
||||
entitys.ResLog(o.data.Ch, o.ID(), "数据整理中")
|
||||
func (o *orderAfterSaleResellerBatch) dataMapping(ctx context.Context, in *toolZoarb.OrderAfterSaleResellerBatchResponse) (*OrderAfterSaleResellerBatchWorkflowOutput, error) {
|
||||
input := ctx.Value(workflowInputContextKey{}).(*OrderAfterSaleResellerBatchWorkflowInput)
|
||||
|
||||
entitys.ResLog(input.Ch, o.ID(), "数据整理中")
|
||||
|
||||
toolResp := &OrderAfterSaleResellerBatchWorkflowOutput{
|
||||
Code: in.Code,
|
||||
|
|
@ -170,17 +230,17 @@ func (o *orderAfterSaleResellerBatch) dataMapping(_ context.Context, in *toolZoa
|
|||
// 转换数据
|
||||
for _, item := range in.Data.Data {
|
||||
// 处理方式
|
||||
afterType := util.StringToInt(o.data.Data.AfterType)
|
||||
afterType := util.StringToInt(input.Data.AfterType)
|
||||
if afterType == 0 {
|
||||
afterType = 1 // 默认退款
|
||||
}
|
||||
// 费用承担者
|
||||
responsibleType := util.StringToInt(o.data.Data.ResponsibleType)
|
||||
responsibleType := util.StringToInt(input.Data.ResponsibleType)
|
||||
if responsibleType == 0 {
|
||||
responsibleType = 4 // 默认无
|
||||
}
|
||||
// 售后金额
|
||||
afterSalesPrice := util.StringToFloat64(o.data.Data.AfterSalesPrice)
|
||||
afterSalesPrice := util.StringToFloat64(input.Data.AfterSalesPrice)
|
||||
if afterSalesPrice == 0 {
|
||||
afterSalesPrice = item.OrderPrice
|
||||
}
|
||||
|
|
@ -199,10 +259,10 @@ func (o *orderAfterSaleResellerBatch) dataMapping(_ context.Context, in *toolZoa
|
|||
Account: item.Account,
|
||||
Platforms: item.Platforms,
|
||||
AfterType: afterType,
|
||||
Remark: o.data.Data.AfterSalesReason,
|
||||
Remark: input.Data.AfterSalesReason,
|
||||
AfterAmount: afterSalesPrice,
|
||||
ResponsibleType: responsibleType,
|
||||
ResponsiblePerson: o.data.Data.ResponsiblePerson,
|
||||
ResponsiblePerson: input.Data.ResponsiblePerson,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -215,7 +275,7 @@ func (o *orderAfterSaleResellerBatch) dataMapping(_ context.Context, in *toolZoa
|
|||
}
|
||||
|
||||
toolRespJson, _ := json.Marshal(toolResp)
|
||||
entitys.ResJson(o.data.Ch, o.ID(), string(toolRespJson))
|
||||
entitys.ResJson(input.Ch, o.ID(), string(toolRespJson))
|
||||
|
||||
return toolResp, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package entitys
|
|||
import (
|
||||
"ai_scheduler/internal/data/constants"
|
||||
"ai_scheduler/internal/data/model"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type Recognize struct {
|
||||
|
|
@ -47,3 +48,20 @@ type RecognizeFile struct {
|
|||
FileRealMime string // 文件真实MIME类型
|
||||
FileUrl string // 文件下载链接
|
||||
}
|
||||
|
||||
func (r *Recognize) GetTaskExt() *TaskExt {
|
||||
var ext TaskExt
|
||||
if err := json.Unmarshal(r.Ext, &ext); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &ext
|
||||
}
|
||||
|
||||
func (r *Recognize) GetSession() string {
|
||||
ext := r.GetTaskExt()
|
||||
if ext == nil {
|
||||
return ""
|
||||
}
|
||||
return ext.Session
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/data/constants"
|
||||
errorcode "ai_scheduler/internal/data/error"
|
||||
"ai_scheduler/internal/domain/component/callback"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/gateway"
|
||||
"ai_scheduler/internal/pkg"
|
||||
|
|
@ -25,17 +26,17 @@ type CallbackService struct {
|
|||
dingtalkOldClient *dingtalk.OldClient
|
||||
dingtalkContactClient *dingtalk.ContactClient
|
||||
dingtalkNotableClient *dingtalk.NotableClient
|
||||
callBackTool *tool_callback.CallBackTool
|
||||
callbackManager callback.Manager
|
||||
}
|
||||
|
||||
func NewCallbackService(cfg *config.Config, gateway *gateway.Gateway, dingtalkOldClient *dingtalk.OldClient, dingtalkContactClient *dingtalk.ContactClient, dingtalkNotableClient *dingtalk.NotableClient, callBackTool *tool_callback.CallBackTool) *CallbackService {
|
||||
func NewCallbackService(cfg *config.Config, gateway *gateway.Gateway, dingtalkOldClient *dingtalk.OldClient, dingtalkContactClient *dingtalk.ContactClient, dingtalkNotableClient *dingtalk.NotableClient, callbackManager callback.Manager) *CallbackService {
|
||||
return &CallbackService{
|
||||
cfg: cfg,
|
||||
gateway: gateway,
|
||||
dingtalkOldClient: dingtalkOldClient,
|
||||
dingtalkContactClient: dingtalkContactClient,
|
||||
dingtalkNotableClient: dingtalkNotableClient,
|
||||
callBackTool: callBackTool,
|
||||
callbackManager: callbackManager,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -139,11 +140,14 @@ func (s *CallbackService) Callback(c *fiber.Ctx) error {
|
|||
|
||||
func (s *CallbackService) handleDingTalkCallback(c *fiber.Ctx, env Envelope) error {
|
||||
// 校验taskId
|
||||
sessionID, ok := s.callBackTool.GetSessionByTaskID(env.TaskID)
|
||||
if !ok {
|
||||
ctx := c.Context()
|
||||
sessionID, err := s.callbackManager.GetSession(ctx, env.TaskID)
|
||||
if err != nil {
|
||||
return errorcode.ParamErrf("failed to get session for task_id: %s, err: %v", env.TaskID, err)
|
||||
}
|
||||
if sessionID == "" {
|
||||
return errorcode.ParamErrf("missing session_id for task_id: %s", env.TaskID)
|
||||
}
|
||||
ctx := c.Context()
|
||||
|
||||
switch env.Action {
|
||||
case ActionBugOptimizationSubmitUpdate:
|
||||
|
|
@ -166,8 +170,10 @@ func (s *CallbackService) handleDingTalkCallback(c *fiber.Ctx, env Envelope) err
|
|||
// 发送日志
|
||||
s.sendStreamTxt(sessionID, msg)
|
||||
|
||||
// 删除映射
|
||||
s.callBackTool.DelTaskMapping(env.TaskID)
|
||||
// 通知等待者
|
||||
if err := s.callbackManager.Notify(ctx, env.TaskID, msg); err != nil {
|
||||
// 记录错误但继续
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{"code": 0, "message": "ok"})
|
||||
case ActionBugOptimizationSubmitProcess:
|
||||
|
|
|
|||
Loading…
Reference in New Issue