diff --git a/server/internal/api/exports.go b/server/internal/api/exports.go index 5932197..2fd1453 100644 --- a/server/internal/api/exports.go +++ b/server/internal/api/exports.go @@ -14,6 +14,7 @@ import ( "server/internal/exporter" "server/internal/logging" "server/internal/repo" + "server/internal/utils" "server/internal/ymtcrypto" "strconv" "strings" @@ -455,7 +456,7 @@ func (a *ExportsAPI) runJob(id uint64, db *sql.DB, q string, args []interface{}, defer func() { if r := recover(); r != nil { repo.NewExportRepo().MarkFailed(a.meta, id) - logging.JSON("ERROR", map[string]interface{}{"event": "export_panic", "job_id": id, "error": toString(r)}) + logging.JSON("ERROR", map[string]interface{}{"event": "export_panic", "job_id": id, "error": utils.ToString(r)}) } }() // load datasource once for transform decisions @@ -498,7 +499,7 @@ func (a *ExportsAPI) runJob(id uint64, db *sql.DB, q string, args []interface{}, var chunks [][2]string if v, ok := fl["create_time_between"]; ok { if arr, ok2 := v.([]interface{}); ok2 && len(arr) == 2 { - chunks = exporter.SplitByDays(toString(arr[0]), toString(arr[1]), constants.ExportThresholds.ChunkDays) + chunks = exporter.SplitByDays(utils.ToString(arr[0]), utils.ToString(arr[1]), constants.ExportThresholds.ChunkDays) } if arrs, ok3 := v.([]string); ok3 && len(arrs) == 2 { chunks = exporter.SplitByDays(arrs[0], arrs[1], constants.ExportThresholds.ChunkDays) @@ -606,53 +607,6 @@ func (a *ExportsAPI) runJob(id uint64, db *sql.DB, q string, args []interface{}, rrepo.MarkCompleted(a.meta, id, count) return } - w, err := newBaseWriter() - if err != nil { - a.meta.Exec("UPDATE export_jobs SET status=?, finished_at=? WHERE id= ?", "failed", time.Now(), id) - return - } - _ = w.WriteHeader(cols) - rows, err := db.Query(q, args...) - if err != nil { - a.meta.Exec("UPDATE export_jobs SET status=?, finished_at=? WHERE id= ?", "failed", time.Now(), id) - return - } - defer rows.Close() - out := make([]interface{}, len(cols)) - dest := make([]interface{}, len(cols)) - for i := range out { - dest[i] = &out[i] - } - var count int64 - var tick int64 - for rows.Next() { - if err := rows.Scan(dest...); err != nil { - a.meta.Exec("UPDATE export_jobs SET status=?, finished_at=? WHERE id=?", "failed", time.Now(), id) - return - } - vals := make([]string, len(cols)) - for i := range out { - if b, ok := out[i].([]byte); ok { - vals[i] = string(b) - } else if out[i] == nil { - vals[i] = "" - } else { - vals[i] = toString(out[i]) - } - } - w.WriteRow(vals) - count++ - tick++ - if tick%1000 == 0 { - rrepo.UpdateProgress(a.meta, id, count) - } - } - path, size, _ := w.Close() - log.Printf("job_id=%d sql=%s args=%v", id, "INSERT INTO export_job_files (job_id, storage_uri, row_count, size_bytes, created_at, updated_at) VALUES (?,?,?,?,?,?)", []interface{}{id, path, count, size, time.Now(), time.Now()}) - a.meta.Exec("INSERT INTO export_job_files (job_id, storage_uri, row_count, size_bytes, created_at, updated_at) VALUES (?,?,?,?,?,?)", id, path, count, size, time.Now(), time.Now()) - log.Printf("job_id=%d sql=%s args=%v", id, "UPDATE export_jobs SET status=?, finished_at=?, total_rows=?, updated_at=? WHERE id= ?", []interface{}{"completed", time.Now(), count, time.Now(), id}) - a.meta.Exec("UPDATE export_jobs SET status=?, finished_at=?, total_rows=?, updated_at=? WHERE id= ?", "completed", time.Now(), count, time.Now(), id) - return } if fmt == "xlsx" { files := []string{} @@ -674,7 +628,7 @@ func (a *ExportsAPI) runJob(id uint64, db *sql.DB, q string, args []interface{}, var chunks [][2]string if v, ok := fl["create_time_between"]; ok { if arr, ok2 := v.([]interface{}); ok2 && len(arr) == 2 { - chunks = exporter.SplitByDays(toString(arr[0]), toString(arr[1]), constants.ExportThresholds.ChunkDays) + chunks = exporter.SplitByDays(utils.ToString(arr[0]), utils.ToString(arr[1]), constants.ExportThresholds.ChunkDays) } if arrs, ok3 := v.([]string); ok3 && len(arrs) == 2 { chunks = exporter.SplitByDays(arrs[0], arrs[1], constants.ExportThresholds.ChunkDays) @@ -772,14 +726,14 @@ func (a *ExportsAPI) runJob(id uint64, db *sql.DB, q string, args []interface{}, if err != nil { logging.JSON("ERROR", map[string]interface{}{"event": "export_writer_error", "job_id": id, "stage": "xlsx_direct", "error": err.Error()}) log.Printf("export_writer_error job_id=%d stage=xlsx_direct err=%v", id, err) - a.meta.Exec("UPDATE export_jobs SET status=?, finished_at=? WHERE id= ?", "failed", time.Now(), id) + a.meta.Exec("UPDATE export_jobs SET status=?, finished_at=? WHERE id= ?", string(constants.JobStatusFailed), time.Now(), id) return } _ = x.WriteHeader(cols) rrepo.UpdateProgress(a.meta, id, 0) rows, err := db.Query(q, args...) if err != nil { - a.meta.Exec("UPDATE export_jobs SET status=?, finished_at=? WHERE id= ?", "failed", time.Now(), id) + a.meta.Exec("UPDATE export_jobs SET status=?, finished_at=? WHERE id= ?", string(constants.JobStatusFailed), time.Now(), id) return } defer rows.Close() @@ -792,7 +746,7 @@ func (a *ExportsAPI) runJob(id uint64, db *sql.DB, q string, args []interface{}, var tick int64 for rows.Next() { if err := rows.Scan(dest...); err != nil { - a.meta.Exec("UPDATE export_jobs SET status=?, finished_at=? WHERE id=?", "failed", time.Now(), id) + a.meta.Exec("UPDATE export_jobs SET status=?, finished_at=? WHERE id=?", string(constants.JobStatusFailed), time.Now(), id) return } vals := make([]string, len(cols)) @@ -802,7 +756,7 @@ func (a *ExportsAPI) runJob(id uint64, db *sql.DB, q string, args []interface{}, } else if out[i] == nil { vals[i] = "" } else { - vals[i] = toString(out[i]) + vals[i] = utils.ToString(out[i]) } } vals = transformRow(jobDS, fields, vals) @@ -820,7 +774,7 @@ func (a *ExportsAPI) runJob(id uint64, db *sql.DB, q string, args []interface{}, rrepo.MarkCompleted(a.meta, id, count) return } - a.meta.Exec("UPDATE export_jobs SET status=?, finished_at=?, updated_at=? WHERE id= ?", "failed", time.Now(), time.Now(), id) + a.meta.Exec("UPDATE export_jobs SET status=?, finished_at=?, updated_at=? WHERE id= ?", string(constants.JobStatusFailed), time.Now(), time.Now(), id) } // recompute final rows for a job and correct export_jobs.total_rows @@ -1142,57 +1096,10 @@ func decodeOrderKey(s string) string { } func (a *ExportsAPI) cancel(w http.ResponseWriter, r *http.Request, id string) { - a.meta.Exec("UPDATE export_jobs SET status=?, updated_at=? WHERE id=? AND status IN ('queued','running')", "canceled", time.Now(), id) + a.meta.Exec("UPDATE export_jobs SET status=?, updated_at=? WHERE id=? AND status IN ('queued','running')", string(constants.JobStatusCanceled), time.Now(), id) w.Write([]byte("ok")) } -func toString(v interface{}) string { - switch t := v.(type) { - case []byte: - return string(t) - case string: - return t - case int64: - return strconv.FormatInt(t, 10) - case int32: - return strconv.FormatInt(int64(t), 10) - case int: - return strconv.Itoa(t) - case uint64: - return strconv.FormatUint(t, 10) - case uint32: - return strconv.FormatUint(uint64(t), 10) - case uint: - return strconv.FormatUint(uint64(t), 10) - case float64: - // 对于整数部分,使用整数格式;对于小数部分,保留必要精度 - if t == float64(int64(t)) { - return strconv.FormatInt(int64(t), 10) - } - return strconv.FormatFloat(t, 'f', -1, 64) - case float32: - // 对于整数部分,使用整数格式;对于小数部分,保留必要精度 - if t == float32(int64(t)) { - return strconv.FormatInt(int64(t), 10) - } - return strconv.FormatFloat(float64(t), 'f', -1, 32) - case bool: - if t { - return "1" - } - return "0" - case time.Time: - return t.Format("2006-01-02 15:04:05") - case nil: - return "" - default: - // 尝试转换为字符串,如果是数字类型则格式化 - if s := fmt.Sprintf("%v", t); s != "" { - return s - } - return "" - } -} func renderSQL(q string, args []interface{}) string { formatArg := func(a interface{}) string { switch t := a.(type) { @@ -1409,7 +1316,7 @@ func normalizeIDs(v interface{}) []interface{} { switch t := v.(type) { case []interface{}: for _, x := range t { - if s := toString(x); s != "" { + if s := utils.ToString(x); s != "" { out = append(out, s) } } @@ -1438,7 +1345,7 @@ func normalizeIDs(v interface{}) []interface{} { } } default: - if s := toString(t); s != "" { + if s := utils.ToString(t); s != "" { out = append(out, s) } } @@ -1458,7 +1365,7 @@ func pickFirst(perm map[string]interface{}, filters map[string]interface{}, keys if len(arr) > 0 { return arr[0], true } - if s := toString(v); s != "" { + if s := utils.ToString(v); s != "" { return s, true } } @@ -1468,7 +1375,7 @@ func pickFirst(perm map[string]interface{}, filters map[string]interface{}, keys if len(arr) > 0 { return arr[0], true } - if s := toString(v); s != "" { + if s := utils.ToString(v); s != "" { return s, true } } diff --git a/server/internal/exporter/sqlbuilder.go b/server/internal/exporter/sqlbuilder.go index 7338ced..9720e13 100644 --- a/server/internal/exporter/sqlbuilder.go +++ b/server/internal/exporter/sqlbuilder.go @@ -4,10 +4,10 @@ import ( "encoding/json" "errors" "fmt" + "server/internal/constants" "server/internal/schema" - "strconv" + "server/internal/utils" "strings" - "time" ) type BuildRequest struct { @@ -66,19 +66,19 @@ func BuildSQL(req BuildRequest, whitelist map[string]bool) (string, []interface{ mf, _ := sch.MapField(t, f) if req.Datasource == "marketing" && t == "order" && req.MainTable == "order" { if f == "status" { - cols = append(cols, "CASE `order`.status WHEN 0 THEN '待充值' WHEN 1 THEN '充值中' WHEN 2 THEN '已完成' WHEN 3 THEN '充值失败' WHEN 4 THEN '已取消' WHEN 5 THEN '已过期' WHEN 6 THEN '待支付' END AS `order.status`") + cols = append(cols, constants.BuildCaseWhen("order", "status", constants.MarketingOrderStatus, "order.status")) continue } if f == "type" { - cols = append(cols, "CASE `order`.type WHEN 1 THEN '直充卡密' WHEN 2 THEN '立减金' WHEN 3 THEN '红包' ELSE '' END AS `order.type`") + cols = append(cols, constants.BuildCaseWhen("order", "type", constants.MarketingOrderType, "order.type")) continue } if f == "pay_type" { - cols = append(cols, "CASE `order`.pay_type WHEN 1 THEN '支付宝' WHEN 5 THEN '微信' ELSE '' END AS `order.pay_type`") + cols = append(cols, constants.BuildCaseWhen("order", "pay_type", constants.MarketingPayType, "order.pay_type")) continue } if f == "pay_status" { - cols = append(cols, "CASE `order`.pay_status WHEN 1 THEN '待支付' WHEN 2 THEN '已支付' WHEN 3 THEN '已退款' ELSE '' END AS `order.pay_status`") + cols = append(cols, constants.BuildCaseWhen("order", "pay_status", constants.MarketingPayStatus, "order.pay_status")) continue } if req.Datasource == "marketing" && f == "card_code" { @@ -88,7 +88,7 @@ func BuildSQL(req BuildRequest, whitelist map[string]bool) (string, []interface{ } if req.Datasource == "ymt" && t == "order" { if f == "type" { - cols = append(cols, "CASE `"+mt+"`.type WHEN 1 THEN '红包订单' WHEN 2 THEN '直充卡密订单' WHEN 3 THEN '立减金订单' ELSE '' END AS `order.type`") + cols = append(cols, constants.BuildCaseWhen(mt, "type", constants.YMTOrderType, "order.type")) continue } if f == "recharge_suc_time" { @@ -97,15 +97,15 @@ func BuildSQL(req BuildRequest, whitelist map[string]bool) (string, []interface{ continue } if f == "status" { - cols = append(cols, "CASE `"+mt+"`.status WHEN 1 THEN '待充值' WHEN 2 THEN '充值中' WHEN 3 THEN '充值成功' WHEN 4 THEN '充值失败' WHEN 5 THEN '已过期' WHEN 6 THEN '已作废' WHEN 7 THEN '已核销' WHEN 8 THEN '核销失败' WHEN 9 THEN '订单重置' WHEN 10 THEN '卡单' ELSE '' END AS `order.status`") + cols = append(cols, constants.BuildCaseWhen(mt, "status", constants.YMTOrderStatus, "order.status")) continue } if f == "pay_status" { - cols = append(cols, "CASE `"+mt+"`.pay_status WHEN 1 THEN '待支付' WHEN 2 THEN '支付中' WHEN 3 THEN '已支付' WHEN 4 THEN '取消支付' WHEN 5 THEN '退款中' WHEN 6 THEN '退款成功' ELSE '' END AS `order.pay_status`") + cols = append(cols, constants.BuildCaseWhen(mt, "pay_status", constants.YMTPayStatus, "order.pay_status")) continue } if f == "is_retry" { - cols = append(cols, "CASE `"+mt+"`.is_retry WHEN 0 THEN '可以失败重试' WHEN 1 THEN '可以失败重试' WHEN 2 THEN '不可以失败重试' ELSE '' END AS `order.is_retry`") + cols = append(cols, constants.BuildCaseWhen(mt, "is_retry", constants.YMTIsRetry, "order.is_retry")) continue } if f == "supplier_name" { @@ -114,40 +114,40 @@ func BuildSQL(req BuildRequest, whitelist map[string]bool) (string, []interface{ continue } if f == "is_inner" { - cols = append(cols, "CASE `"+mt+"`.is_inner WHEN 1 THEN '内部供应商' ELSE '外部供应商' END AS `order.is_inner`") + cols = append(cols, constants.BuildCaseWhen(mt, "is_inner", constants.YMTIsInner, "order.is_inner")) continue } } if req.Datasource == "ymt" && t == "activity" { if f == "settlement_type" { - cols = append(cols, "CASE `"+mt+"`.settlement_type WHEN 1 THEN '发放结算' WHEN 2 THEN '打开结算' WHEN 3 THEN '领用结算' WHEN 4 THEN '核销结算' ELSE '' END AS `activity.settlement_type`") + cols = append(cols, constants.BuildCaseWhen(mt, "settlement_type", constants.YMTSettlementType, "activity.settlement_type")) continue } } if t == "merchant" { if f == "third_party" { - cols = append(cols, "CASE `"+mt+"`.third_party WHEN 1 THEN '外部供应商' WHEN 2 THEN '内部供应商' ELSE '' END AS `merchant.third_party`") + cols = append(cols, constants.BuildCaseWhen(mt, "third_party", constants.ThirdPartyType, "merchant.third_party")) continue } } // Generic mapping for order.is_retry across datasources if t == "order" && f == "is_retry" { - cols = append(cols, "CASE `"+mt+"`.is_retry WHEN 0 THEN '可以失败重试' WHEN 1 THEN '可以失败重试' WHEN 2 THEN '不可以失败重试' ELSE '' END AS `order.is_retry`") + cols = append(cols, constants.BuildCaseWhen(mt, "is_retry", constants.YMTIsRetry, "order.is_retry")) continue } // Generic mapping for order.is_inner across datasources if t == "order" && f == "is_inner" { - cols = append(cols, "CASE `"+mt+"`.is_inner WHEN 1 THEN '内部供应商' ELSE '外部供应商' END AS `order.is_inner`") + cols = append(cols, constants.BuildCaseWhen(mt, "is_inner", constants.YMTIsInner, "order.is_inner")) continue } if req.Datasource == "ymt" && t == "order_digit" { if f == "order_type" { - cols = append(cols, "CASE `"+mt+"`.order_type WHEN 1 THEN '直充' WHEN 2 THEN '卡密' ELSE '' END AS `order_digit.order_type`") + cols = append(cols, constants.BuildCaseWhen(mt, "order_type", constants.OrderDigitOrderType, "order_digit.order_type")) continue } if f == "sms_channel" { // 短信渠道枚举:1=官方,2=专票 - cols = append(cols, "CASE `"+mt+"`.sms_channel WHEN 1 THEN '官方' WHEN 2 THEN '专票' ELSE '' END AS `order_digit.sms_channel`") + cols = append(cols, constants.BuildCaseWhen(mt, "sms_channel", constants.OrderDigitSmsChannel, "order_digit.sms_channel")) continue } } @@ -350,7 +350,7 @@ func BuildSQL(req BuildRequest, whitelist map[string]bool) (string, []interface{ } } if v, ok := req.Filters["out_trade_no_eq"]; ok { - s := toString(v) + s := utils.ToString(v) if s != "" { if tbl, col, ok := sch.FilterColumn("out_trade_no_eq"); ok { where = append(where, fmt.Sprintf("`%s`.%s = ?", sch.TableName(tbl), escape(col))) @@ -359,7 +359,7 @@ func BuildSQL(req BuildRequest, whitelist map[string]bool) (string, []interface{ } } if v, ok := req.Filters["account_eq"]; ok { - s := toString(v) + s := utils.ToString(v) if s != "" { if tbl, col, ok := sch.FilterColumn("account_eq"); ok { where = append(where, fmt.Sprintf("`%s`.%s = ?", sch.TableName(tbl), escape(col))) @@ -368,7 +368,7 @@ func BuildSQL(req BuildRequest, whitelist map[string]bool) (string, []interface{ } } if v, ok := req.Filters["plan_id_eq"]; ok { - s := toString(v) + s := utils.ToString(v) if s != "" && s != "0" { if tbl, col, ok := sch.FilterColumn("plan_id_eq"); ok { where = append(where, fmt.Sprintf("`%s`.%s = ?", sch.TableName(tbl), escape(col))) @@ -377,7 +377,7 @@ func BuildSQL(req BuildRequest, whitelist map[string]bool) (string, []interface{ } } if v, ok := req.Filters["key_batch_id_eq"]; ok { - s := toString(v) + s := utils.ToString(v) if s != "" { if tbl, col, ok := sch.FilterColumn("key_batch_id_eq"); ok { where = append(where, fmt.Sprintf("`%s`.%s = ?", sch.TableName(tbl), escape(col))) @@ -386,7 +386,7 @@ func BuildSQL(req BuildRequest, whitelist map[string]bool) (string, []interface{ } } if v, ok := req.Filters["product_id_eq"]; ok { - s := toString(v) + s := utils.ToString(v) if s != "" { if tbl, col, ok := sch.FilterColumn("product_id_eq"); ok { where = append(where, fmt.Sprintf("`%s`.%s = ?", sch.TableName(tbl), escape(col))) @@ -397,7 +397,7 @@ func BuildSQL(req BuildRequest, whitelist map[string]bool) (string, []interface{ if v, ok := req.Filters["reseller_id_eq"]; ok { // If merchant_id_in is present, it handles the merchant_id logic (via OR condition), if _, hasIn := req.Filters["merchant_id_in"]; !hasIn { - s := toString(v) + s := utils.ToString(v) if s != "" { if tbl, col, ok := sch.FilterColumn("reseller_id_eq"); ok { where = append(where, fmt.Sprintf("`%s`.%s = ?", sch.TableName(tbl), escape(col))) @@ -407,7 +407,7 @@ func BuildSQL(req BuildRequest, whitelist map[string]bool) (string, []interface{ } } if v, ok := req.Filters["code_batch_id_eq"]; ok { - s := toString(v) + s := utils.ToString(v) if s != "" { if tbl, col, ok := sch.FilterColumn("code_batch_id_eq"); ok { where = append(where, fmt.Sprintf("`%s`.%s = ?", sch.TableName(tbl), escape(col))) @@ -416,7 +416,7 @@ func BuildSQL(req BuildRequest, whitelist map[string]bool) (string, []interface{ } } if v, ok := req.Filters["order_cash_cash_activity_id_eq"]; ok { - s := toString(v) + s := utils.ToString(v) if s != "" { if tbl, col, ok := sch.FilterColumn("order_cash_cash_activity_id_eq"); ok { where = append(where, fmt.Sprintf("`%s`.%s = ?", sch.TableName(tbl), escape(col))) @@ -425,7 +425,7 @@ func BuildSQL(req BuildRequest, whitelist map[string]bool) (string, []interface{ } } if v, ok := req.Filters["order_voucher_channel_activity_id_eq"]; ok { - s := toString(v) + s := utils.ToString(v) if s != "" { if tbl, col, ok := sch.FilterColumn("order_voucher_channel_activity_id_eq"); ok { where = append(where, fmt.Sprintf("`%s`.%s = ?", sch.TableName(tbl), escape(col))) @@ -434,7 +434,7 @@ func BuildSQL(req BuildRequest, whitelist map[string]bool) (string, []interface{ } } if v, ok := req.Filters["voucher_batch_channel_activity_id_eq"]; ok { - s := toString(v) + s := utils.ToString(v) if s != "" { if tbl, col, ok := sch.FilterColumn("voucher_batch_channel_activity_id_eq"); ok { where = append(where, fmt.Sprintf("`%s`.%s = ?", sch.TableName(tbl), escape(col))) @@ -443,7 +443,7 @@ func BuildSQL(req BuildRequest, whitelist map[string]bool) (string, []interface{ } } if v, ok := req.Filters["merchant_out_biz_no_eq"]; ok { - s := toString(v) + s := utils.ToString(v) if s != "" { if tbl, col, ok := sch.FilterColumn("merchant_out_biz_no_eq"); ok { where = append(where, fmt.Sprintf("`%s`.%s = ?", sch.TableName(tbl), escape(col))) @@ -519,45 +519,6 @@ func escape(s string) string { return s } -func toString(v interface{}) string { - switch t := v.(type) { - case []byte: - return string(t) - case string: - return t - case int64: - return strconv.FormatInt(t, 10) - case int32: - return strconv.FormatInt(int64(t), 10) - case int: - return strconv.Itoa(t) - case uint64: - return strconv.FormatUint(t, 10) - case uint32: - return strconv.FormatUint(uint64(t), 10) - case uint: - return strconv.FormatUint(uint64(t), 10) - case float64: - // 对于整数部分,使用整数格式;对于小数部分,保留必要精度 - if t == float64(int64(t)) { - return strconv.FormatInt(int64(t), 10) - } - return strconv.FormatFloat(t, 'f', -1, 64) - case float32: - // 对于整数部分,使用整数格式;对于小数部分,保留必要精度 - if t == float32(int64(t)) { - return strconv.FormatInt(int64(t), 10) - } - return strconv.FormatFloat(float64(t), 'f', -1, 32) - case time.Time: - return t.Format("2006-01-02 15:04:05") - case nil: - return "" - default: - return "" - } -} - // BuildCountSQL: minimal COUNT for filters-only joins, counting distinct main PK to avoid 1:N duplication func BuildCountSQL(req BuildRequest, whitelist map[string]bool) (string, []interface{}, error) { if req.MainTable != "order" && req.MainTable != "order_info" { @@ -654,7 +615,7 @@ func BuildCountSQL(req BuildRequest, whitelist map[string]bool) (string, []inter continue } } - if k == "plan_id_eq" && toString(v) == "0" { + if k == "plan_id_eq" && utils.ToString(v) == "0" { continue } if tbl, col, ok := sch.FilterColumn(k); ok { @@ -668,7 +629,7 @@ func BuildCountSQL(req BuildRequest, whitelist map[string]bool) (string, []inter args = append(args, arr[0], arr[1]) } default: - s := toString(v) + s := utils.ToString(v) if s != "" { where = append(where, fmt.Sprintf("`%s`.%s = ?", sch.TableName(tbl), escape(col))) args = append(args, s) diff --git a/server/internal/exporter/stream.go b/server/internal/exporter/stream.go index dc72d27..e2550d7 100644 --- a/server/internal/exporter/stream.go +++ b/server/internal/exporter/stream.go @@ -5,6 +5,7 @@ import ( "log" "server/internal/logging" "server/internal/schema" + "server/internal/utils" "strings" "time" ) @@ -172,8 +173,8 @@ func CountRowsFastChunked(db *sql.DB, ds, main string, filters map[string]interf switch t := v.(type) { case []interface{}: if len(t) == 2 { - start = toString(t[0]) - end = toString(t[1]) + start = utils.ToString(t[0]) + end = utils.ToString(t[1]) } case []string: if len(t) == 2 { @@ -286,7 +287,7 @@ func StreamWithCursor(db *sql.DB, base string, args []interface{}, cur *CursorSQ } else if out[idx] == nil { vals[i] = "" } else { - vals[i] = toString(out[idx]) + vals[i] = utils.ToString(out[idx]) } } if transform != nil { @@ -297,8 +298,8 @@ func StreamWithCursor(db *sql.DB, base string, args []interface{}, cur *CursorSQ part++ tick++ // update cursor state from injected columns - lastTs = toString(out[0]) - lastPk = toString(out[1]) + lastTs = utils.ToString(out[0]) + lastPk = utils.ToString(out[1]) if onProgress != nil && (tick == 1 || tick%200 == 0) { _ = onProgress(total) logging.JSON("INFO", map[string]interface{}{"event": "progress_tick", "total_rows": total}) @@ -380,7 +381,7 @@ func pagedOffset(db *sql.DB, base string, args []interface{}, batch int, cols [] } else if out[i] == nil { vals[i] = "" } else { - vals[i] = toString(out[i]) + vals[i] = utils.ToString(out[i]) } } if transform != nil { diff --git a/server/internal/repo/export_repo.go b/server/internal/repo/export_repo.go index 4a83642..102489c 100644 --- a/server/internal/repo/export_repo.go +++ b/server/internal/repo/export_repo.go @@ -4,6 +4,7 @@ package repo import ( "database/sql" "encoding/json" + "server/internal/constants" "server/internal/exporter" "server/internal/logging" "time" @@ -191,7 +192,7 @@ func (r *ExportQueryRepo) InsertJob( args := []interface{}{ templateID, - "queued", + string(constants.JobStatusQueued), requestedBy, ownerID, toJSON(permission), @@ -219,7 +220,7 @@ func (r *ExportQueryRepo) StartJob(metaDB *sql.DB, jobID uint64) { now := time.Now() _, err := metaDB.Exec( "UPDATE export_jobs SET status=?, started_at=?, updated_at=? WHERE id=?", - "running", now, now, jobID, + string(constants.JobStatusRunning), now, now, jobID, ) if err != nil { logging.DBError("start_job", jobID, err) @@ -248,7 +249,7 @@ func (r *ExportQueryRepo) MarkFailed(metaDB *sql.DB, jobID uint64) { now := time.Now() _, err := metaDB.Exec( "UPDATE export_jobs SET status=?, finished_at=? WHERE id=?", - "failed", now, jobID, + string(constants.JobStatusFailed), now, jobID, ) if err != nil { logging.DBError("mark_failed", jobID, err) @@ -263,7 +264,7 @@ func (r *ExportQueryRepo) MarkCompleted(metaDB *sql.DB, jobID uint64, totalRows SET status=?, finished_at=?, total_rows=?, row_estimate=GREATEST(COALESCE(row_estimate,0), ?), updated_at=? WHERE id=?`, - "completed", now, totalRows, totalRows, now, jobID, + string(constants.JobStatusCompleted), now, totalRows, totalRows, now, jobID, ) if err != nil { logging.DBError("mark_completed", jobID, err) diff --git a/server/internal/utils/convert.go b/server/internal/utils/convert.go new file mode 100644 index 0000000..84406fa --- /dev/null +++ b/server/internal/utils/convert.go @@ -0,0 +1,59 @@ +package utils +// Package utils 提供通用工具函数 +package utils + +import ( + "fmt" + "strconv" + "time" +) + +// ToString 将任意类型转换为字符串 +// 支持 []byte, string, int/int32/int64, uint/uint32/uint64, float32/float64, bool, time.Time +func ToString(v interface{}) string { + switch t := v.(type) { + case []byte: + return string(t) + case string: + return t + case int64: + return strconv.FormatInt(t, 10) + case int32: + return strconv.FormatInt(int64(t), 10) + case int: + return strconv.Itoa(t) + case uint64: + return strconv.FormatUint(t, 10) + case uint32: + return strconv.FormatUint(uint64(t), 10) + case uint: + return strconv.FormatUint(uint64(t), 10) + case float64: + // 对于整数部分,使用整数格式;对于小数部分,保留必要精度 + if t == float64(int64(t)) { + return strconv.FormatInt(int64(t), 10) + } + return strconv.FormatFloat(t, 'f', -1, 64) + case float32: + // 对于整数部分,使用整数格式;对于小数部分,保留必要精度 + if t == float32(int64(t)) { + return strconv.FormatInt(int64(t), 10) + } + return strconv.FormatFloat(float64(t), 'f', -1, 32) + case bool: + if t { + return "1" + } + return "0" + case time.Time: + return t.Format("2006-01-02 15:04:05") + case nil: + return "" + default: + // 尝试使用 fmt.Sprintf 作为兜底 + if s := fmt.Sprintf("%v", t); s != "" { + return s + } + return "" + } +}