diff --git a/internal/pkg/helper/ctx.go b/internal/pkg/helper/ctx.go index 83495ba..81b1273 100644 --- a/internal/pkg/helper/ctx.go +++ b/internal/pkg/helper/ctx.go @@ -16,3 +16,21 @@ func (valueOnlyContext) Err() error { return nil } func CopyValueCtx(ctx context.Context) context.Context { return valueOnlyContext{ctx} } + +// RunTaskWithTimeout 带有超时控制的任务执行函数 +func RunTaskWithTimeout(ctx context.Context, task func(ctx context.Context) error, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- task(ctx) + }() + + select { + case err := <-done: + return err + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/internal/pkg/helper/ctx_test.go b/internal/pkg/helper/ctx_test.go new file mode 100644 index 0000000..98d6aa1 --- /dev/null +++ b/internal/pkg/helper/ctx_test.go @@ -0,0 +1,60 @@ +package helper + +import ( + "context" + "fmt" + "sync" + "testing" + "time" +) + +// 模拟服务端任务 +func mockServerTask(ctx context.Context) error { + // 模拟从 context 中获取数据 + value := ctx.Value("testKey") + if value == nil { + return fmt.Errorf("testKey not found in context") + } + + // 模拟耗时操作 + for i := 0; i < 5; i++ { + select { + case <-ctx.Done(): + return ctx.Err() + default: + fmt.Printf("Task is running... Step %d\n", i+1) + time.Sleep(1 * time.Second) + } + } + + fmt.Println("Task completed successfully") + return nil +} + +func TestCopyValueCtx(t *testing.T) { + // 创建带有值的原始 context + originalCtx := context.WithValue(context.Background(), "testKey", "testValue") + // 模拟客户端取消 context + clientCtx, cancel := context.WithCancel(originalCtx) + defer cancel() + + // 复制只有 value 的 context + serverCtx := CopyValueCtx(clientCtx) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + // 执行带有超时控制的任务 + err := RunTaskWithTimeout(serverCtx, mockServerTask, 10*time.Second) + if err != nil { + fmt.Printf("Task failed: %v\n", err) + } + }() + + // 模拟客户端断开连接 + time.Sleep(2 * time.Second) + + wg.Wait() +}