package callback import ( "context" "fmt" "time" "ai_scheduler/internal/pkg" "github.com/redis/go-redis/v9" ) type Manager interface { Register(ctx context.Context, taskID string, sessionID string) error Wait(ctx context.Context, taskID string, timeout time.Duration) (string, error) Notify(ctx context.Context, taskID string, result string) error GetSession(ctx context.Context, taskID string) (string, error) } type RedisManager struct { rdb *redis.Client } func NewRedisManager(rdb *pkg.Rdb) *RedisManager { return &RedisManager{ rdb: rdb.Rdb, } } const ( keyPrefixSession = "callback:session:" keyPrefixSignal = "callback:signal:" defaultTTL = 24 * time.Hour ) func (m *RedisManager) Register(ctx context.Context, taskID string, sessionID string) error { key := keyPrefixSession + taskID return m.rdb.Set(ctx, key, sessionID, defaultTTL).Err() } func (m *RedisManager) Wait(ctx context.Context, taskID string, timeout time.Duration) (string, error) { key := keyPrefixSignal + taskID // BLPop 阻塞等待 result, err := m.rdb.BLPop(ctx, timeout, key).Result() if err != nil { if err == redis.Nil { return "", fmt.Errorf("timeout waiting for callback") } return "", err } // result[0] is key, result[1] is value if len(result) < 2 { return "", fmt.Errorf("invalid redis result") } return result[1], nil } func (m *RedisManager) Notify(ctx context.Context, taskID string, result string) error { key := keyPrefixSignal + taskID // Push 信号,同时设置过期时间防止堆积 pipe := m.rdb.Pipeline() pipe.RPush(ctx, key, result) pipe.Expire(ctx, key, 1*time.Hour) // 信号列表也需要过期 _, err := pipe.Exec(ctx) return err } func (m *RedisManager) GetSession(ctx context.Context, taskID string) (string, error) { key := keyPrefixSession + taskID return m.rdb.Get(ctx, key).Result() }