134 lines
2.9 KiB
Go
134 lines
2.9 KiB
Go
package timeslice
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"github.com/hashicorp/go-multierror"
|
|
"golang.org/x/sync/errgroup"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// maxGlobalGoroutines is the hard upper bound Run accepts for Manager.GoNum.
const maxGlobalGoroutines = 1000

// Suggested defaults for building a Manager. They are not referenced in this
// file; presumably callers use them to fill Manager.GoNum and
// Manager.TimeSliceHours — verify against callers.
const (
	DefaultGoNum          = 2
	DefaultTimeSliceHours = 2
)
|
|
|
|
// Callback handles a single time slice. ManagerSrv invokes it once per slice,
// possibly from several goroutines at once (bounded by Manager.GoNum), passing
// along the context given to Run. A non-nil return marks that slice as failed.
type Callback func(ctx context.Context, task *Task) error
|
|
|
|
// ManagerSrv splits a time window into consecutive slices and runs a Callback
// for each slice concurrently (see Run).
type ManagerSrv struct {
	// callback is invoked once per time slice.
	callback Callback
}
|
|
|
|
func NewManager(callback Callback) *ManagerSrv {
|
|
return &ManagerSrv{callback: callback}
|
|
}
|
|
|
|
func (m *ManagerSrv) Run(ctx context.Context, req *Manager) (int, error) {
|
|
|
|
if req.StartTime.After(req.EndTime) {
|
|
return 0, fmt.Errorf("start_time不能大于end_time")
|
|
}
|
|
|
|
if req.GoNum == 0 {
|
|
return 0, fmt.Errorf("协程数量不能为0")
|
|
}
|
|
if req.GoNum > maxGlobalGoroutines {
|
|
return 0, fmt.Errorf("协程数量不能大于%d", maxGlobalGoroutines)
|
|
}
|
|
|
|
timeSliceHours := float64(req.TimeSliceHours)
|
|
|
|
totalHours := req.EndTime.Sub(req.StartTime).Hours()
|
|
taskCount := int(totalHours / timeSliceHours)
|
|
|
|
// 如果剩余时间不足 TimeSliceHours 小时,增加任务数
|
|
if totalHours-float64(taskCount)*timeSliceHours > 0 {
|
|
taskCount++
|
|
}
|
|
|
|
processReq := &Process{
|
|
Manager: req,
|
|
TaskCount: taskCount,
|
|
}
|
|
|
|
return taskCount, m.process(ctx, processReq)
|
|
}
|
|
|
|
func (m *ManagerSrv) process(ctx context.Context, req *Process) error {
|
|
|
|
if req.TaskCount == 0 {
|
|
return fmt.Errorf("该时间范围无可执行任务次数,请检查时间范围")
|
|
}
|
|
|
|
// 设置最大并发任务数为 5
|
|
eg := new(errgroup.Group)
|
|
eg.SetLimit(req.Manager.GoNum)
|
|
var mu sync.Mutex
|
|
|
|
errs := make([]error, 0, req.TaskCount)
|
|
|
|
// 为每个任务按指定的时间片 TimeSliceHours 分配开始和结束时间
|
|
for i := 0; i < req.TaskCount; i++ {
|
|
|
|
currentStart := req.Manager.StartTime.Add(time.Duration(i) * time.Duration(req.Manager.TimeSliceHours) * time.Hour)
|
|
currentEnd := currentStart.Add(time.Duration(req.Manager.TimeSliceHours) * time.Hour)
|
|
|
|
if currentEnd.After(req.Manager.EndTime) {
|
|
currentEnd = req.Manager.EndTime
|
|
}
|
|
taskID := i + 1
|
|
|
|
eg.Go(func() error {
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
mu.Lock()
|
|
errs = append(errs, fmt.Errorf("任务 %d 被上下文取消", taskID))
|
|
mu.Unlock()
|
|
return ctx.Err()
|
|
default:
|
|
// 继续执行
|
|
}
|
|
|
|
defer func() {
|
|
if err := recover(); err != nil {
|
|
mu.Lock()
|
|
errs = append(errs, fmt.Errorf("任务 %d panic: %v", taskID, err))
|
|
mu.Unlock()
|
|
}
|
|
}()
|
|
|
|
taskReq := &Task{
|
|
CurrentStartTime: currentStart,
|
|
CurrentEndTime: currentEnd,
|
|
TaskID: taskID,
|
|
Process: req,
|
|
}
|
|
|
|
if err := m.callback(ctx, taskReq); err != nil {
|
|
mu.Lock()
|
|
errs = append(errs, fmt.Errorf("任务 %d 执行失败: %v", taskID, err))
|
|
mu.Unlock()
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// 等待所有任务完成
|
|
if err := eg.Wait(); err != nil {
|
|
return fmt.Errorf("任务执行失败: %v", err)
|
|
}
|
|
|
|
var result error
|
|
|
|
// 收集错误
|
|
for _, err2 := range errs {
|
|
result = multierror.Append(result, err2)
|
|
}
|
|
|
|
return result
|
|
}
|