252 lines
6.0 KiB
Go
252 lines
6.0 KiB
Go
// Package exporter 导出任务执行器
|
||
package exporter
|
||
|
||
import (
	"database/sql"
	"fmt"
	"strconv"

	"server/internal/constants"
	"server/internal/logging"
)
|
||
|
||
// ==================== 任务执行配置 ====================
|
||
|
||
// JobConfig 导出任务配置
|
||
type JobConfig struct {
|
||
JobID uint64 // 任务ID
|
||
Datasource string // 数据源
|
||
MainTable string // 主表
|
||
Format constants.FileFormat // 导出格式
|
||
Query string // SQL查询
|
||
Args []interface{} // SQL参数
|
||
Fields []string // 字段列表
|
||
Headers []string // 表头列表
|
||
Filters map[string]interface{} // 过滤条件
|
||
}
|
||
|
||
// JobCallbacks bundles the optional hooks a job runner invokes while it works.
// The runner checks every hook for nil before calling it, so any subset may be
// left unset.
type JobCallbacks struct {
	// OnStart is called when the job starts.
	OnStart func(jobID uint64) error

	// OnProgress is called when progress is updated.
	OnProgress func(jobID uint64, totalRows int64) error

	// OnFileCreated is called when a file has been created.
	OnFileCreated func(jobID uint64, path string, size int64, rowCount int64) error

	// OnComplete is called when the job completes.
	OnComplete func(jobID uint64, totalRows int64, files []string) error

	// OnFailed is called when the job fails.
	OnFailed func(jobID uint64, err error) error

	// Transform converts the values of one row.
	Transform func(datasource string, fields []string, values []string) []string
}
|
||
|
||
// ==================== 任务执行器 ====================
|
||
|
||
// JobRunner 导出任务执行器
|
||
type JobRunner struct {
|
||
db *sql.DB
|
||
config JobConfig
|
||
callbacks JobCallbacks
|
||
}
|
||
|
||
// NewJobRunner 创建任务执行器
|
||
func NewJobRunner(db *sql.DB, config JobConfig, callbacks JobCallbacks) *JobRunner {
|
||
return &JobRunner{
|
||
db: db,
|
||
config: config,
|
||
callbacks: callbacks,
|
||
}
|
||
}
|
||
|
||
// Run 执行导出任务
|
||
func (r *JobRunner) Run() error {
|
||
defer r.recoverPanic()
|
||
|
||
// 通知任务开始
|
||
if r.callbacks.OnStart != nil {
|
||
if err := r.callbacks.OnStart(r.config.JobID); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
// 根据格式执行导出
|
||
return r.runExport()
|
||
}
|
||
|
||
// recoverPanic 恢复panic并标记任务失败
|
||
func (r *JobRunner) recoverPanic() {
|
||
if rec := recover(); rec != nil {
|
||
logging.JSON("ERROR", map[string]interface{}{
|
||
"event": "export_panic",
|
||
"job_id": r.config.JobID,
|
||
"error": anyToString(rec),
|
||
})
|
||
if r.callbacks.OnFailed != nil {
|
||
r.callbacks.OnFailed(r.config.JobID, nil)
|
||
}
|
||
}
|
||
}
|
||
|
||
// runExport 执行导出(统一CSV和XLSX逻辑)
|
||
func (r *JobRunner) runExport() error {
|
||
cfg := r.config
|
||
maxRowsPerFile := constants.ExportThresholds.MaxRowsPerFile
|
||
|
||
// 创建写入器工厂
|
||
dir := constants.StorageConfig.Directory
|
||
name := constants.StorageConfig.FilePrefix + strconv.FormatUint(cfg.JobID, 10)
|
||
|
||
newWriter := func() (RowWriter, error) {
|
||
w, err := NewWriter(cfg.Format, dir, name)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
// 写入表头
|
||
if err := w.WriteHeader(cfg.Headers); err != nil {
|
||
return nil, err
|
||
}
|
||
return w, nil
|
||
}
|
||
|
||
// 创建游标
|
||
cursor := NewCursorSQL(cfg.Datasource, cfg.MainTable)
|
||
batchSize := constants.ChooseBatchSize(0, cfg.Format)
|
||
|
||
// 转换函数
|
||
transform := func(vals []string) []string {
|
||
if r.callbacks.Transform != nil {
|
||
return r.callbacks.Transform(cfg.Datasource, cfg.Fields, vals)
|
||
}
|
||
return vals
|
||
}
|
||
|
||
// 收集生成的文件
|
||
var files []string
|
||
var totalRows int64
|
||
|
||
// 进度回调
|
||
onProgress := func(rows int64) error {
|
||
if r.callbacks.OnProgress != nil {
|
||
return r.callbacks.OnProgress(cfg.JobID, rows)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// 文件滚动回调
|
||
onRoll := func(path string, size, partRows int64) error {
|
||
files = append(files, path)
|
||
if r.callbacks.OnFileCreated != nil {
|
||
return r.callbacks.OnFileCreated(cfg.JobID, path, size, partRows)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// 执行流式导出
|
||
count, resultFiles, err := StreamWithCursor(
|
||
r.db,
|
||
cfg.Query,
|
||
cfg.Args,
|
||
cursor,
|
||
batchSize,
|
||
cfg.Headers,
|
||
newWriter,
|
||
transform,
|
||
maxRowsPerFile,
|
||
onRoll,
|
||
onProgress,
|
||
)
|
||
|
||
if err != nil {
|
||
logging.JSON("ERROR", map[string]interface{}{
|
||
"event": "export_stream_error",
|
||
"job_id": cfg.JobID,
|
||
"error": err.Error(),
|
||
})
|
||
if r.callbacks.OnFailed != nil {
|
||
r.callbacks.OnFailed(cfg.JobID, err)
|
||
}
|
||
return err
|
||
}
|
||
|
||
totalRows = count
|
||
if len(resultFiles) > 0 {
|
||
files = resultFiles
|
||
}
|
||
|
||
// 通知完成
|
||
if r.callbacks.OnComplete != nil {
|
||
return r.callbacks.OnComplete(cfg.JobID, totalRows, files)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ==================== 辅助函数 ====================
|
||
|
||
// anyToString converts a dynamically typed value to a string.
//
// string and []byte values are returned as text; int, int64 and uint64 are
// formatted in base 10; a float64 that survives a round trip through int64 is
// printed as that integer, any other float64 in the shortest 'f' notation;
// an error yields its Error() text. nil and every other type give "".
func anyToString(v interface{}) string {
	switch val := v.(type) {
	case nil:
		return ""
	case string:
		return val
	case []byte:
		return string(val)
	case error:
		return val.Error()
	case int:
		return strconv.Itoa(val)
	case int64:
		return strconv.FormatInt(val, 10)
	case uint64:
		return strconv.FormatUint(val, 10)
	case float64:
		// Whole numbers are printed without a decimal part.
		if whole := int64(val); float64(whole) == val {
			return strconv.FormatInt(whole, 10)
		}
		return strconv.FormatFloat(val, 'f', -1, 64)
	}
	return ""
}
|
||
|
||
// ==================== 分块导出支持 ====================
|
||
|
||
// ChunkRange is one time chunk of a split time range.
type ChunkRange struct {
	// Start and End are the two bounds of the chunk, kept as strings.
	Start, End string
}
|
||
|
||
// SplitTimeRange 按天数分割时间范围
|
||
func SplitTimeRange(startStr, endStr string, stepDays int) []ChunkRange {
|
||
chunks := SplitByDays(startStr, endStr, stepDays)
|
||
result := make([]ChunkRange, len(chunks))
|
||
for i, c := range chunks {
|
||
result[i] = ChunkRange{Start: c[0], End: c[1]}
|
||
}
|
||
return result
|
||
}
|
||
|
||
// ShouldUseChunkedExport 判断是否应该使用分块导出
|
||
func ShouldUseChunkedExport(datasource, mainTable string, estimate int64, fields []string, filters map[string]interface{}) bool {
|
||
// 行数太少不需要分块
|
||
if estimate <= constants.ExportThresholds.ChunkThreshold {
|
||
return false
|
||
}
|
||
|
||
// Marketing系统的某些场景不适合分块
|
||
if datasource == "marketing" && mainTable == "order" {
|
||
// 检查是否有order_voucher相关字段或过滤
|
||
for _, f := range fields {
|
||
if len(f) > 14 && f[:14] == "order_voucher." {
|
||
return false
|
||
}
|
||
}
|
||
if _, ok := filters["order_voucher_channel_activity_id_eq"]; ok {
|
||
return false
|
||
}
|
||
}
|
||
|
||
return true
|
||
}
|