MarketingSystemDataExportTool/server/internal/exporter/stream.go

package exporter

import (
	"database/sql"
	"log"
	"strings"
	"time"

	"server/internal/logging"
	"server/internal/schema"
	"server/internal/utils"
)

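// This file implements the streaming side of the exporter: keyset-cursor
// pagination over the main table, fast row-count estimation for progress
// reporting, and rolling of output files for large exports.
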
// CursorSQL rewrites a base SELECT so it can be paginated with a keyset
// cursor over the main table's (create_time, order_number) columns.
type CursorSQL struct{ ds, main, mt, tsCol, pkCol string }

// NewCursorSQL resolves the physical table and column names for the cursor
// from the schema of the given datasource and main entity.
func NewCursorSQL(ds, main string) *CursorSQL {
	sch := schema.Get(ds, main)
	mt := sch.TableName(main)
	ts, _ := sch.MapField(main, "create_time")
	pk, _ := sch.MapField(main, "order_number")
	return &CursorSQL{ds: ds, main: main, mt: mt, tsCol: ts, pkCol: pk}
}

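// The three methods below rewrite a query for keyset pagination. As an
// illustration (table and column names here are placeholders, not taken
// from a real schema), a base query such as
//
//	SELECT `order`.id, `order`.amount FROM `order` WHERE `order`.status = ?
//
// becomes, after InjectSelect, AddCursor, and AddOrder:
//
//	SELECT `order`.create_time AS __ts, `order`.order_number AS __pk,
//	       `order`.id, `order`.amount
//	FROM `order`
//	WHERE `order`.status = ?
//	  AND ((`order`.create_time) > ? OR
//	       ((`order`.create_time) = ? AND (`order`.order_number) > ?))
//	ORDER BY `order`.create_time, `order`.order_number
//
// Each batch then resumes strictly after the (__ts, __pk) pair of the last
// row seen, which stays cheap no matter how deep into the export we are.
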
// InjectSelect prepends the cursor columns (__ts, __pk) to the SELECT list
// of base, preserving a leading DISTINCT or SQL_NO_CACHE modifier.
func (c *CursorSQL) InjectSelect(base string) string {
	u := strings.ToUpper(base)
	// Locate FROM case-insensitively; bail out if the query has no FROM.
	if strings.Index(u, " FROM ") <= 0 {
		return base
	}
	prefix := "SELECT "
	if strings.HasPrefix(u, "SELECT DISTINCT ") {
		prefix = "SELECT DISTINCT "
	} else if strings.HasPrefix(u, "SELECT SQL_NO_CACHE ") {
		prefix = "SELECT SQL_NO_CACHE "
	}
	return prefix + "`" + c.mt + "`." + c.tsCol + " AS __ts, `" + c.mt + "`." + c.pkCol + " AS __pk, " + base[len(prefix):]
}

// AddOrder appends a deterministic ORDER BY on the cursor columns.
func (c *CursorSQL) AddOrder(base string) string {
	return base + " ORDER BY `" + c.mt + "`." + c.tsCol + ", `" + c.mt + "`." + c.pkCol
}

// AddCursor appends the keyset predicate (ts, pk) > (lastTs, lastPk),
// spelled out for index-friendly comparison.
func (c *CursorSQL) AddCursor(base string) string {
	u := strings.ToUpper(base)
	cond := " AND ((`" + c.mt + "`." + c.tsCol + ") > ? OR ((`" + c.mt + "`." + c.tsCol + ") = ? AND (`" + c.mt + "`." + c.pkCol + ") > ?))"
	if strings.Contains(u, " WHERE ") {
		return base + cond
	}
	// Without an existing WHERE clause the predicate must introduce one.
	return base + " WHERE " + strings.TrimPrefix(cond, " AND ")
}

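// CountRows counts the rows base would return. To keep the count cheap it
// drops the SELECT list as well as any trailing ORDER BY / LIMIT / OFFSET,
// and wraps what remains in a COUNT(1) subquery; for example
//
//	SELECT a, b FROM t WHERE x = ? ORDER BY a LIMIT 10
//
// is counted as
//
//	SELECT COUNT(1) FROM (SELECT 1 FROM t WHERE x = ?) AS sub
//
// (t, a, b, x are placeholders). This assumes every bound placeholder lives
// in the FROM/WHERE portion of the query.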
func CountRows(db *sql.DB, base string, args []interface{}) int64 {
	u := strings.ToUpper(base)
	idx := strings.Index(u, " FROM ")
	cut := len(base)
	if idx > 0 {
		// Cut the query at the first ORDER BY / LIMIT / OFFSET after FROM.
		for _, tok := range []string{" ORDER BY ", " LIMIT ", " OFFSET "} {
			if p := strings.Index(u[idx:], tok); p >= 0 {
				if cp := idx + p; cp < cut {
					cut = cp
				}
			}
		}
	}
	minimal := base
	if idx > 0 {
		minimal = "SELECT 1" + base[idx:cut]
	}
	q := "SELECT COUNT(1) FROM (" + minimal + ") AS sub"
	var c int64
	if err := db.QueryRow(q, args...).Scan(&c); err != nil {
		logging.JSON("ERROR", map[string]interface{}{"event": "count_error", "error": err.Error(), "sql": q, "args": args})
		log.Printf("count_error sql=%s args=%v err=%v", q, args, err)
		return 0
	}
	return c
}

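// CountRowsFast estimates the export size by counting against the main
// table alone, skipping whatever joins the full export query performs. Only
// filters that map onto the main ("order") table are applied, so the result
// may overcount when join-side filters would eliminate rows.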
func CountRowsFast(db *sql.DB, ds, main string, filters map[string]interface{}) int64 {
	sch := schema.Get(ds, main)
	mt := sch.TableName(main)
	q := "SELECT COUNT(1) FROM `" + mt + "` WHERE 1=1"
	args := []interface{}{}
	// addIn appends "`col` IN (?,?,...)" for any supported slice type;
	// empty slices and unsupported types are ignored.
	addIn := func(col string, v interface{}) {
		var items []interface{}
		switch t := v.(type) {
		case []interface{}:
			items = t
		case []string:
			for _, s := range t {
				items = append(items, s)
			}
		case []int:
			for _, n := range t {
				items = append(items, n)
			}
		case []int64:
			for _, n := range t {
				items = append(items, n)
			}
		}
		if len(items) == 0 {
			return
		}
		ph := make([]string, len(items))
		for i := range items {
			ph[i] = "?"
			args = append(args, items[i])
		}
		q += " AND `" + col + "` IN (" + strings.Join(ph, ",") + ")"
	}
	for k, v := range filters {
		tbl, col, ok := sch.FilterColumn(k)
		if !ok {
			continue
		}
		// Only columns of the main "order" table take part in the fast count.
		if tbl != "order" {
			continue
		}
		switch k {
		case "creator_in":
			addIn(col, v)
		case "create_time_between":
			switch t := v.(type) {
			case []interface{}:
				if len(t) == 2 {
					q += " AND `" + col + "` BETWEEN ? AND ?"
					args = append(args, t[0], t[1])
				}
			case []string:
				if len(t) == 2 {
					q += " AND `" + col + "` BETWEEN ? AND ?"
					args = append(args, t[0], t[1])
				}
			}
		default:
			q += " AND `" + col + "` = ?"
			args = append(args, v)
		}
	}
	// Log the fast-count SQL so estimation problems can be traced.
	logging.JSON("INFO", map[string]interface{}{
		"event":      "count_fast_query",
		"datasource": ds,
		"main":       main,
		"sql":        q,
		"args":       args,
		"filters":    filters,
	})
	var c int64
	if err := db.QueryRow(q, args...).Scan(&c); err != nil {
		logging.JSON("ERROR", map[string]interface{}{"event": "count_fast_error", "error": err.Error(), "sql": q, "args": args})
		log.Printf("count_fast_error sql=%s args=%v err=%v", q, args, err)
		return 0
	}
	return c
}

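// CountRowsFastChunked splits a create_time_between filter into 15-day
// windows and sums the per-window fast counts, which keeps each COUNT on a
// smaller range (useful when create_time is indexed). Because BETWEEN is
// inclusive on both ends and adjacent windows share a boundary timestamp,
// rows falling exactly on a boundary are counted twice; the result is an
// estimate, not an exact total.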
func CountRowsFastChunked(db *sql.DB, ds, main string, filters map[string]interface{}) int64 {
	start := ""
	end := ""
	if v, ok := filters["create_time_between"]; ok {
		switch t := v.(type) {
		case []interface{}:
			if len(t) == 2 {
				start = utils.ToString(t[0])
				end = utils.ToString(t[1])
			}
		case []string:
			if len(t) == 2 {
				start = t[0]
				end = t[1]
			}
		}
	}
	if start == "" || end == "" {
		return CountRowsFast(db, ds, main, filters)
	}
	ranges := SplitByDays(start, end, 15)
	var total int64
	for _, rg := range ranges {
		// Copy the filter map so the caller's map is not mutated.
		fl := map[string]interface{}{}
		for k, v := range filters {
			fl[k] = v
		}
		fl["create_time_between"] = []string{rg[0], rg[1]}
		total += CountRowsFast(db, ds, main, fl)
	}
	return total
}

// SplitByDays splits the [startStr, endStr] time range into windows of at
// most stepDays days and returns the resulting intervals.
func SplitByDays(startStr, endStr string, stepDays int) [][2]string {
	layout := "2006-01-02 15:04:05"
	s := strings.TrimSpace(startStr)
	e := strings.TrimSpace(endStr)
	st, err1 := time.Parse(layout, s)
	en, err2 := time.Parse(layout, e)
	if err1 != nil || err2 != nil || !en.After(st) || stepDays <= 0 {
		// Unparseable or degenerate input: return the range unsplit.
		return [][2]string{{s, e}}
	}
	var out [][2]string
	cur := st
	step := time.Duration(stepDays) * 24 * time.Hour
	for cur.Before(en) {
		nxt := cur.Add(step)
		if nxt.After(en) {
			nxt = en
		}
		out = append(out, [2]string{cur.Format(layout), nxt.Format(layout)})
		cur = nxt
	}
	return out
}

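// For example, SplitByDays("2024-01-01 00:00:00", "2024-02-01 00:00:00", 15)
// (illustrative values) yields three windows:
//
//	[2024-01-01 00:00:00, 2024-01-16 00:00:00]
//	[2024-01-16 00:00:00, 2024-01-31 00:00:00]
//	[2024-01-31 00:00:00, 2024-02-01 00:00:00]
//
// with the last window clamped to the requested end time.
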
// RowTransform optionally rewrites a row before it is written out.
type RowTransform func([]string) []string

// RollCallback is invoked each time an output file is closed (rolled).
type RollCallback func(path string, size int64, partRows int64) error

// ProgressCallback receives the running total of exported rows.
type ProgressCallback func(totalRows int64) error

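// StreamWithCursor streams the rows of base into one or more files. Each
// round trip fetches up to batch rows ordered by the injected (__ts, __pk)
// cursor columns, remembers the last pair seen, and re-issues the query
// strictly after it; output rolls to a new file every maxRowsPerFile rows.
// If the rewritten cursor query fails, it falls back to pagedOffset below.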
func StreamWithCursor(db *sql.DB, base string, args []interface{}, cur *CursorSQL, batch int, cols []string, newWriter func() (RowWriter, error), transform RowTransform, maxRowsPerFile int64, onRoll RollCallback, onProgress ProgressCallback) (int64, []string, error) {
	w, err := newWriter()
	if err != nil {
		return 0, nil, err
	}
	_ = w.WriteHeader(cols)
	if onProgress != nil {
		_ = onProgress(0)
	}
	// Two extra scan slots hold the injected __ts and __pk cursor columns.
	out := make([]interface{}, len(cols)+2)
	dest := make([]interface{}, len(cols)+2)
	for i := range out {
		dest[i] = &out[i]
	}
	var total int64
	var part int64
	var tick int64
	files := []string{}
	lastTs := ""
	lastPk := ""
	for {
		q2 := cur.InjectSelect(base)
		if lastTs != "" || lastPk != "" {
			q2 = cur.AddCursor(q2)
		}
		q2 = cur.AddOrder(q2) + " LIMIT ?"
		args2 := append([]interface{}{}, args...)
		if lastTs != "" || lastPk != "" {
			args2 = append(args2, lastTs, lastTs, lastPk)
		}
		args2 = append(args2, batch)
		rows, e := db.Query(q2, args2...)
		if e != nil {
			logging.JSON("ERROR", map[string]interface{}{"event": "cursor_query_error", "sql": q2, "args": args2, "error": e.Error()})
			log.Printf("cursor_query_error sql=%s args=%v err=%v", q2, args2, e)
			// Fall back to LIMIT/OFFSET pagination when the cursor query fails.
			_, _, _ = w.Close()
			return pagedOffset(db, base, args, batch, cols, newWriter, transform, maxRowsPerFile, onRoll, onProgress)
		}
		fetched := false
		for rows.Next() {
			fetched = true
			if e := rows.Scan(dest...); e != nil {
				rows.Close()
				// Fall back to LIMIT/OFFSET when a scan fails (likely a column mismatch).
				logging.JSON("ERROR", map[string]interface{}{"event": "cursor_scan_error", "error": e.Error()})
				log.Printf("cursor_scan_error err=%v", e)
				_, _, _ = w.Close()
				return pagedOffset(db, base, args, batch, cols, newWriter, transform, maxRowsPerFile, onRoll, onProgress)
			}
			vals := make([]string, len(cols))
			for i := 0; i < len(cols); i++ {
				// Skip the injected cursor columns (__ts, __pk) at positions 0 and 1.
				idx := i + 2
				if b, ok := out[idx].([]byte); ok {
					vals[i] = string(b)
				} else if out[idx] == nil {
					vals[i] = ""
				} else {
					vals[i] = utils.ToString(out[idx])
				}
			}
			if transform != nil {
				vals = transform(vals)
			}
			_ = w.WriteRow(vals)
			total++
			part++
			tick++
			// Advance the cursor to the row just written.
			lastTs = utils.ToString(out[0])
			lastPk = utils.ToString(out[1])
			if onProgress != nil && (tick == 1 || tick%200 == 0) {
				_ = onProgress(total)
				logging.JSON("INFO", map[string]interface{}{"event": "progress_tick", "total_rows": total})
			}
			if part >= maxRowsPerFile {
				// Roll to a new output file once the per-file row cap is reached.
				p, sz, _ := w.Close()
				files = append(files, p)
				if onRoll != nil {
					_ = onRoll(p, sz, part)
				}
				w, e = newWriter()
				if e != nil {
					rows.Close()
					return total, files, e
				}
				_ = w.WriteHeader(cols)
				part = 0
			}
		}
		// Surface iteration errors; the next round trip resumes from the cursor.
		if e := rows.Err(); e != nil {
			logging.JSON("ERROR", map[string]interface{}{"event": "cursor_rows_error", "error": e.Error()})
			log.Printf("cursor_rows_error err=%v", e)
		}
		rows.Close()
		if !fetched {
			break
		}
	}
	// Close the final file; keep it if it has rows or is the only output.
	p, sz, _ := w.Close()
	if part > 0 || len(files) == 0 {
		files = append(files, p)
		if onRoll != nil {
			_ = onRoll(p, sz, part)
		}
	}
	if onProgress != nil {
		_ = onProgress(total)
	}
	return total, files, nil
}

// pagedOffset provides a robust fallback using LIMIT/OFFSET without cursor columns
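// Unlike the cursor path, OFFSET pagination forces the database to skip all
// previously returned rows on every batch, so total cost grows roughly
// quadratically with row count; it is used only when the cursor rewrite of
// base fails.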
func pagedOffset(db *sql.DB, base string, args []interface{}, batch int, cols []string, newWriter func() (RowWriter, error), transform RowTransform, maxRowsPerFile int64, onRoll RollCallback, onProgress ProgressCallback) (int64, []string, error) {
	w, err := newWriter()
	if err != nil {
		return 0, nil, err
	}
	_ = w.WriteHeader(cols)
	if onProgress != nil {
		_ = onProgress(0)
	}
	files := []string{}
	var total int64
	var part int64
	var tick int64
	for off := 0; ; off += batch {
		// Wrap base in a subquery so LIMIT/OFFSET applies whatever its shape.
		q := "SELECT * FROM (" + base + ") AS sub LIMIT ? OFFSET ?"
		args2 := append(append([]interface{}{}, args...), batch, off)
		rows, e := db.Query(q, args2...)
		if e != nil {
			logging.JSON("ERROR", map[string]interface{}{"event": "offset_query_error", "sql": q, "args": args2, "error": e.Error()})
			log.Printf("offset_query_error sql=%s args=%v err=%v", q, args2, e)
			return total, files, e
		}
		fetched := false
		out := make([]interface{}, len(cols))
		dest := make([]interface{}, len(cols))
		for i := range out {
			dest[i] = &out[i]
		}
		for rows.Next() {
			fetched = true
			if e := rows.Scan(dest...); e != nil {
				rows.Close()
				logging.JSON("ERROR", map[string]interface{}{"event": "offset_scan_error", "error": e.Error()})
				log.Printf("offset_scan_error err=%v", e)
				return total, files, e
			}
			vals := make([]string, len(cols))
			for i := 0; i < len(cols); i++ {
				if b, ok := out[i].([]byte); ok {
					vals[i] = string(b)
				} else if out[i] == nil {
					vals[i] = ""
				} else {
					vals[i] = utils.ToString(out[i])
				}
			}
			if transform != nil {
				vals = transform(vals)
			}
			_ = w.WriteRow(vals)
			total++
			part++
			tick++
			if onProgress != nil && (tick == 1 || tick%200 == 0) {
				_ = onProgress(total)
				logging.JSON("INFO", map[string]interface{}{"event": "progress_tick", "total_rows": total})
			}
			if part >= maxRowsPerFile {
				p, sz, _ := w.Close()
				files = append(files, p)
				if onRoll != nil {
					_ = onRoll(p, sz, part)
				}
				w, e = newWriter()
				if e != nil {
					rows.Close()
					return total, files, e
				}
				_ = w.WriteHeader(cols)
				part = 0
			}
		}
		// Surface iteration errors before deciding whether to stop.
		if e := rows.Err(); e != nil {
			logging.JSON("ERROR", map[string]interface{}{"event": "offset_rows_error", "error": e.Error()})
			log.Printf("offset_rows_error err=%v", e)
		}
		rows.Close()
		if !fetched {
			break
		}
	}
	p, sz, _ := w.Close()
	if part > 0 || len(files) == 0 {
		files = append(files, p)
		if onRoll != nil {
			_ = onRoll(p, sz, part)
		}
	}
	if onProgress != nil {
		_ = onProgress(total)
	}
	return total, files, nil
}

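// A minimal usage sketch (hypothetical: newCSVWriter, csvDir, and the
// exportSQL/exportArgs pair are assumed helpers built by the caller and do
// not appear in this file):
//
//	cur := NewCursorSQL("mysql_main", "order")
//	total, files, err := StreamWithCursor(db, exportSQL, exportArgs, cur,
//		5000, cols,
//		func() (RowWriter, error) { return newCSVWriter(csvDir) },
//		nil,       // no row transform
//		1_000_000, // roll output files every million rows
//		func(path string, size, rows int64) error { return nil },
//		func(total int64) error { return nil },
//	)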