ai_scheduler/internal/domain/component/callback/manager.go

72 lines
1.8 KiB
Go

package callback
import (
	"context"
	"errors"
	"fmt"
	"time"

	"ai_scheduler/internal/pkg"
	"github.com/redis/go-redis/v9"
)
// Manager coordinates asynchronous task callbacks: it associates a task with
// the session that started it, and hands a task's result from the party that
// produces it (Notify) to the party waiting for it (Wait).
type Manager interface {
	// Register associates sessionID with taskID.
	Register(ctx context.Context, taskID string, sessionID string) error
	// GetSession returns the session ID previously associated with taskID.
	GetSession(ctx context.Context, taskID string) (string, error)
	// Wait blocks until a result for taskID is delivered through Notify or
	// the timeout elapses, and returns that result.
	Wait(ctx context.Context, taskID string, timeout time.Duration) (string, error)
	// Notify delivers result to the waiter for taskID.
	Notify(ctx context.Context, taskID string, result string) error
}
// RedisManager implements Manager on top of Redis. Session registrations are
// stored as plain string keys; callback results travel through per-task lists.
type RedisManager struct {
	rdb *redis.Client
}

// NewRedisManager returns a RedisManager backed by the Redis client wrapped
// in rdb.
func NewRedisManager(rdb *pkg.Rdb) *RedisManager {
	m := new(RedisManager)
	m.rdb = rdb.Rdb
	return m
}
// Redis key layout and expiry used by RedisManager.
const (
	// keyPrefixSession prefixes the string key holding a task's session ID.
	keyPrefixSession = "callback:session:"
	// keyPrefixSignal prefixes the list key through which a task's result
	// is delivered from Notify to Wait.
	keyPrefixSignal = "callback:signal:"
	// defaultTTL is how long a session registration is kept in Redis.
	defaultTTL = time.Hour * 24
)
// Register stores sessionID under the task's session key (overwriting any
// previous value) so GetSession can resolve it later. The entry expires after
// defaultTTL.
func (m *RedisManager) Register(ctx context.Context, taskID string, sessionID string) error {
	cmd := m.rdb.Set(ctx, keyPrefixSession+taskID, sessionID, defaultTTL)
	return cmd.Err()
}
// ErrCallbackTimeout is returned by RedisManager.Wait when no result is
// delivered for the task before the timeout elapses. Test for it with
// errors.Is.
var ErrCallbackTimeout = errors.New("timeout waiting for callback")

// Wait blocks until a result for taskID is pushed by Notify or the timeout
// elapses, and returns that result.
//
// The result is popped from the task's signal list with BLPOP, so each
// notified result is consumed by exactly one waiter. A result pushed before
// Wait is called is still picked up while the list has not expired.
//
// Per Redis BLPOP semantics, a timeout of 0 blocks indefinitely.
// NOTE(review): go-redis sends the timeout in whole seconds; confirm that
// callers do not rely on sub-second precision.
//
// It returns ErrCallbackTimeout on timeout. Any other Redis or client error
// is wrapped with the task ID.
func (m *RedisManager) Wait(ctx context.Context, taskID string, timeout time.Duration) (string, error) {
	key := keyPrefixSignal + taskID

	// BLPOP blocks on the Redis server until an element arrives or the
	// timeout expires.
	result, err := m.rdb.BLPop(ctx, timeout, key).Result()
	if err != nil {
		// go-redis reports BLPOP's nil reply (timeout expired) as redis.Nil.
		if errors.Is(err, redis.Nil) {
			return "", ErrCallbackTimeout
		}
		return "", fmt.Errorf("waiting for callback of task %s: %w", taskID, err)
	}

	// BLPOP replies with [key, value].
	if len(result) < 2 {
		return "", fmt.Errorf("invalid redis result: got %d elements, want 2", len(result))
	}
	return result[1], nil
}
// Notify delivers result to whoever waits on taskID via Wait. The result is
// appended to the task's signal list. The list is given an expiry so that
// signals nobody consumes do not pile up in Redis.
//
// RPUSH and EXPIRE are sent in a single MULTI/EXEC transaction. With a plain
// pipeline, partial execution could leave the list behind with no TTL.
func (m *RedisManager) Notify(ctx context.Context, taskID string, result string) error {
	// signalTTL bounds how long an unconsumed signal list survives.
	const signalTTL = time.Hour

	key := keyPrefixSignal + taskID
	_, err := m.rdb.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
		pipe.RPush(ctx, key, result)
		pipe.Expire(ctx, key, signalTTL)
		return nil
	})
	if err != nil {
		return fmt.Errorf("notify callback for task %s: %w", taskID, err)
	}
	return nil
}
// GetSession returns the session ID that Register stored for taskID.
// If no entry exists (never registered, or expired), the go-redis sentinel
// redis.Nil is returned unchanged.
func (m *RedisManager) GetSession(ctx context.Context, taskID string) (string, error) {
	cmd := m.rdb.Get(ctx, keyPrefixSession+taskID)
	return cmd.Result()
}