// MarketingSystemDataExportTool/server/internal/exporter/stream.go

package exporter

import (
	"database/sql"
	"log"
	"server/internal/logging"
	"server/internal/schema"
	"strings"
	"time"
)
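
// CursorSQL carries the identifiers needed to build keyset-pagination
// ("seek method") SQL fragments: the physical main table name plus its
// timestamp and primary-key cursor columns.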
type CursorSQL struct{ ds, main, mt, tsCol, pkCol string }
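
// NewCursorSQL resolves the physical table and cursor column names for the
// given dataset and main entity via the schema registry.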
func NewCursorSQL(ds, main string) *CursorSQL {
	sch := schema.Get(ds, main)
	mt := sch.TableName(main)
	// Mapping errors are ignored: a missing mapping leaves an empty column
	// name, and the resulting query fails loudly at execution time instead.
	ts, _ := sch.MapField(main, "create_time")
	pk, _ := sch.MapField(main, "order_number")
	return &CursorSQL{ds: ds, main: main, mt: mt, tsCol: ts, pkCol: pk}
}
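
// InjectSelect prepends the two cursor columns, aliased __ts and __pk, to the
// select list of base so every fetched row carries its own cursor position.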
func (c *CursorSQL) InjectSelect(base string) string {
	u := strings.ToUpper(base)
	// Match the FROM keyword case-insensitively; bail out if there is none.
	if strings.Index(u, " FROM ") <= 0 {
		return base
	}
	prefix := "SELECT "
	if strings.HasPrefix(u, "SELECT DISTINCT ") {
		prefix = "SELECT DISTINCT "
	} else if strings.HasPrefix(u, "SELECT SQL_NO_CACHE ") {
		prefix = "SELECT SQL_NO_CACHE "
	}
	return prefix + "`" + c.mt + "`." + c.tsCol + " AS __ts, `" + c.mt + "`." + c.pkCol + " AS __pk, " + base[len(prefix):]
}
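
// AddOrder appends the ORDER BY clause that the keyset predicate relies on:
// rows must come back sorted by (tsCol, pkCol).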
func (c *CursorSQL) AddOrder(base string) string {
	return base + " ORDER BY `" + c.mt + "`." + c.tsCol + ", `" + c.mt + "`." + c.pkCol
}
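
// AddCursor appends the keyset predicate selecting rows strictly after the
// (ts, pk) position supplied as three bound parameters.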
func (c *CursorSQL) AddCursor(base string) string {
	u := strings.ToUpper(base)
	cond := " AND ((`" + c.mt + "`." + c.tsCol + ") > ? OR ((`" + c.mt + "`." + c.tsCol + ") = ? AND (`" + c.mt + "`." + c.pkCol + ") > ?))"
	if strings.Contains(u, " WHERE ") {
		return base + cond
	}
	// No existing WHERE clause: start one instead of emitting a dangling AND.
	return base + " WHERE " + strings.TrimPrefix(cond, " AND ")
}
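
// Illustrative sketch of how the three helpers compose (the table and filter
// here are hypothetical). Given
//
//	base := "SELECT o.id, o.amount FROM orders o WHERE o.status = ?"
//
// and a CursorSQL with mt="orders", tsCol="create_time", pkCol="order_number",
// AddOrder(AddCursor(InjectSelect(base))) yields roughly:
//
//	SELECT `orders`.create_time AS __ts, `orders`.order_number AS __pk,
//	       o.id, o.amount FROM orders o WHERE o.status = ?
//	  AND ((`orders`.create_time) > ? OR ((`orders`.create_time) = ?
//	       AND (`orders`.order_number) > ?))
//	ORDER BY `orders`.create_time, `orders`.order_number
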
// CountRows counts the rows base would return: it trims any trailing
// ORDER BY / LIMIT / OFFSET clauses, replaces the select list with a
// constant, and wraps the result in SELECT COUNT(1). It assumes every
// placeholder in args binds inside the retained part of the query.
func CountRows(db *sql.DB, base string, args []interface{}) int64 {
	u := strings.ToUpper(base)
	idx := strings.Index(u, " FROM ")
	cut := len(base)
	if idx > 0 {
		for _, tok := range []string{" ORDER BY ", " LIMIT ", " OFFSET "} {
			if p := strings.Index(u[idx:], tok); p >= 0 {
				if cp := idx + p; cp < cut {
					cut = cp
				}
			}
		}
	}
	minimal := base
	if idx > 0 {
		// Replace the select list with a constant so the subquery stays cheap.
		minimal = "SELECT 1" + base[idx:cut]
	}
	q := "SELECT COUNT(1) FROM (" + minimal + ") AS sub"
	var c int64
	if err := db.QueryRow(q, args...).Scan(&c); err != nil {
		logging.JSON("ERROR", map[string]interface{}{"event": "count_error", "error": err.Error(), "sql": q, "args": args})
		log.Printf("count_error sql=%s args=%v err=%v", q, args, err)
		return 0
	}
	return c
}
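
// CountRowsFast counts rows of the main table directly with a plain COUNT(1),
// applying only the filters that resolve to the main ("order") table; it
// skips the wrapped-subquery approach of CountRows.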
func CountRowsFast(db *sql.DB, ds, main string, filters map[string]interface{}) int64 {
	sch := schema.Get(ds, main)
	mt := sch.TableName(main)
	q := "SELECT COUNT(1) FROM `" + mt + "` WHERE 1=1"
	args := []interface{}{}
	// addIn appends an IN (...) clause for any supported slice type; empty
	// slices and unsupported types add no condition.
	addIn := func(col string, v interface{}) {
		var items []interface{}
		switch t := v.(type) {
		case []interface{}:
			items = t
		case []string:
			for _, s := range t {
				items = append(items, s)
			}
		case []int:
			for _, n := range t {
				items = append(items, n)
			}
		case []int64:
			for _, n := range t {
				items = append(items, n)
			}
		}
		if len(items) == 0 {
			return
		}
		ph := make([]string, len(items))
		for i := range ph {
			ph[i] = "?"
		}
		args = append(args, items...)
		q += " AND `" + col + "` IN (" + strings.Join(ph, ",") + ")"
	}
	for k, v := range filters {
		tbl, col, ok := sch.FilterColumn(k)
		if !ok || tbl != "order" {
			// Only filters that resolve to the main ("order") table apply here.
			continue
		}
		switch k {
		case "creator_in":
			addIn(col, v)
		case "create_time_between":
			switch t := v.(type) {
			case []interface{}:
				if len(t) == 2 {
					q += " AND `" + col + "` BETWEEN ? AND ?"
					args = append(args, t[0], t[1])
				}
			case []string:
				if len(t) == 2 {
					q += " AND `" + col + "` BETWEEN ? AND ?"
					args = append(args, t[0], t[1])
				}
			}
		default:
			q += " AND `" + col + "` = ?"
			args = append(args, v)
		}
	}
	var c int64
	if err := db.QueryRow(q, args...).Scan(&c); err != nil {
		logging.JSON("ERROR", map[string]interface{}{"event": "count_fast_error", "error": err.Error(), "sql": q, "args": args})
		log.Printf("count_fast_error sql=%s args=%v err=%v", q, args, err)
		return 0
	}
	return c
}
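
// CountRowsFastChunked splits a create_time_between filter into 15-day chunks
// and sums CountRowsFast over each, presumably to keep individual COUNT scans
// short. Because BETWEEN is inclusive at both ends, a row created exactly on
// a chunk boundary is counted twice, so the total is an upper-bound estimate.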
func CountRowsFastChunked(db *sql.DB, ds, main string, filters map[string]interface{}) int64 {
	start := ""
	end := ""
	if v, ok := filters["create_time_between"]; ok {
		switch t := v.(type) {
		case []interface{}:
			if len(t) == 2 {
				start = toString(t[0])
				end = toString(t[1])
			}
		case []string:
			if len(t) == 2 {
				start = t[0]
				end = t[1]
			}
		}
	}
	if start == "" || end == "" {
		return CountRowsFast(db, ds, main, filters)
	}
	ranges := SplitByDays(start, end, 15)
	var total int64
	for _, rg := range ranges {
		// Copy the filter map so the caller's shared input is not mutated.
		fl := make(map[string]interface{}, len(filters))
		for k, v := range filters {
			fl[k] = v
		}
		fl["create_time_between"] = []string{rg[0], rg[1]}
		total += CountRowsFast(db, ds, main, fl)
	}
	return total
}

// SplitByDays splits the [startStr, endStr] range into consecutive intervals
// of at most stepDays days each; for example, Jan 1 to Feb 1 with stepDays=15
// yields three intervals. Unparseable input, a non-positive step, or a start
// that is not before the end falls back to a single untouched range.
func SplitByDays(startStr, endStr string, stepDays int) [][2]string {
	layout := "2006-01-02 15:04:05"
	s := strings.TrimSpace(startStr)
	e := strings.TrimSpace(endStr)
	st, err1 := time.Parse(layout, s)
	en, err2 := time.Parse(layout, e)
	if err1 != nil || err2 != nil || !en.After(st) || stepDays <= 0 {
		return [][2]string{{s, e}}
	}
	var out [][2]string
	cur := st
	step := time.Duration(stepDays) * 24 * time.Hour
	for cur.Before(en) {
		nxt := cur.Add(step)
		if nxt.After(en) {
			nxt = en
		}
		out = append(out, [2]string{cur.Format(layout), nxt.Format(layout)})
		cur = nxt
	}
	return out
}

// RowTransform optionally rewrites a row's string values before writing.
type RowTransform func([]string) []string

// RollCallback is invoked after each output file is closed ("rolled").
type RollCallback func(path string, size int64, partRows int64) error

// ProgressCallback receives the running total of rows exported so far.
type ProgressCallback func(totalRows int64) error
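
// StreamWithCursor exports the rows of base in keyset-paginated batches,
// writing them through newWriter and rolling to a fresh output file every
// maxRowsPerFile rows. If the cursor query or a scan fails, it falls back to
// pagedOffset. It returns the total row count and the written file paths.
//
// A minimal call sketch (newCSVWriter, onRoll, and onProgress are
// hypothetical caller-provided values; RowWriter is defined elsewhere in this
// package):
//
//	cur := NewCursorSQL(ds, "order")
//	total, files, err := StreamWithCursor(db, baseSQL, args, cur, 5000,
//		cols, newCSVWriter, nil, 500000, onRoll, onProgress)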
func StreamWithCursor(db *sql.DB, base string, args []interface{}, cur *CursorSQL, batch int, cols []string, newWriter func() (RowWriter, error), transform RowTransform, maxRowsPerFile int64, onRoll RollCallback, onProgress ProgressCallback) (int64, []string, error) {
	w, err := newWriter()
	if err != nil {
		return 0, nil, err
	}
	_ = w.WriteHeader(cols)
	if onProgress != nil {
		_ = onProgress(0)
	}
	// Two extra scan slots hold the injected cursor columns (__ts, __pk).
	out := make([]interface{}, len(cols)+2)
	dest := make([]interface{}, len(cols)+2)
	for i := range out {
		dest[i] = &out[i]
	}
	var total int64
	var part int64
	var tick int64
	files := []string{}
	lastTs := ""
	lastPk := ""
	for {
		q2 := cur.InjectSelect(base)
		if lastTs != "" || lastPk != "" {
			q2 = cur.AddCursor(q2)
		}
		q2 = cur.AddOrder(q2) + " LIMIT ?"
		args2 := append([]interface{}{}, args...)
		if lastTs != "" || lastPk != "" {
			// The cursor predicate binds three parameters: ts > ?, ts = ?, pk > ?.
			args2 = append(args2, lastTs, lastTs, lastPk)
		}
		args2 = append(args2, batch)
		rows, e := db.Query(q2, args2...)
		if e != nil {
			logging.JSON("ERROR", map[string]interface{}{"event": "cursor_query_error", "sql": q2, "args": args2, "error": e.Error()})
			log.Printf("cursor_query_error sql=%s args=%v err=%v", q2, args2, e)
			// Fall back to LIMIT/OFFSET pagination when the cursor query fails.
			_, _, _ = w.Close()
			return pagedOffset(db, base, args, batch, cols, newWriter, transform, maxRowsPerFile, onRoll, onProgress)
		}
		fetched := false
		for rows.Next() {
			fetched = true
			if e := rows.Scan(dest...); e != nil {
				rows.Close()
				// Fall back to LIMIT/OFFSET when a scan fails (likely a column mismatch).
				logging.JSON("ERROR", map[string]interface{}{"event": "cursor_scan_error", "error": e.Error()})
				log.Printf("cursor_scan_error err=%v", e)
				_, _, _ = w.Close()
				return pagedOffset(db, base, args, batch, cols, newWriter, transform, maxRowsPerFile, onRoll, onProgress)
			}
			vals := make([]string, len(cols))
			for i := 0; i < len(cols); i++ {
				// Skip the injected cursor columns at positions 0 and 1.
				idx := i + 2
				if b, ok := out[idx].([]byte); ok {
					vals[i] = string(b)
				} else if out[idx] == nil {
					vals[i] = ""
				} else {
					vals[i] = toString(out[idx])
				}
			}
			if transform != nil {
				vals = transform(vals)
			}
			_ = w.WriteRow(vals)
			total++
			part++
			tick++
			// Advance the cursor from the injected columns.
			lastTs = toString(out[0])
			lastPk = toString(out[1])
			if onProgress != nil && (tick == 1 || tick%200 == 0) {
				_ = onProgress(total)
				logging.JSON("INFO", map[string]interface{}{"event": "progress_tick", "total_rows": total})
			}
			if part >= maxRowsPerFile {
				p, sz, _ := w.Close()
				files = append(files, p)
				if onRoll != nil {
					_ = onRoll(p, sz, part)
				}
				w, e = newWriter()
				if e != nil {
					rows.Close()
					return total, files, e
				}
				_ = w.WriteHeader(cols)
				part = 0
			}
		}
		// Log iteration errors; the next pass re-queries from the last good cursor.
		if e := rows.Err(); e != nil {
			logging.JSON("ERROR", map[string]interface{}{"event": "cursor_rows_error", "error": e.Error()})
			log.Printf("cursor_rows_error err=%v", e)
		}
		rows.Close()
		if !fetched {
			break
		}
	}
	p, sz, _ := w.Close()
	if part > 0 || len(files) == 0 {
		files = append(files, p)
		if onRoll != nil {
			_ = onRoll(p, sz, part)
		}
	}
	if onProgress != nil {
		_ = onProgress(total)
	}
	return total, files, nil
}

// pagedOffset is the fallback path: it pages through base with LIMIT/OFFSET
// instead of cursor columns. Without a stable ORDER BY in base the page
// boundaries are not guaranteed to be deterministic, and OFFSET paging slows
// down at large offsets; this is a best-effort recovery path.
func pagedOffset(db *sql.DB, base string, args []interface{}, batch int, cols []string, newWriter func() (RowWriter, error), transform RowTransform, maxRowsPerFile int64, onRoll RollCallback, onProgress ProgressCallback) (int64, []string, error) {
	w, err := newWriter()
	if err != nil {
		return 0, nil, err
	}
	_ = w.WriteHeader(cols)
	if onProgress != nil {
		_ = onProgress(0)
	}
	files := []string{}
	var total int64
	var part int64
	var tick int64
	out := make([]interface{}, len(cols))
	dest := make([]interface{}, len(cols))
	for i := range out {
		dest[i] = &out[i]
	}
	for off := 0; ; off += batch {
		q := "SELECT * FROM (" + base + ") AS sub LIMIT ? OFFSET ?"
		args2 := append(append([]interface{}{}, args...), batch, off)
		rows, e := db.Query(q, args2...)
		if e != nil {
			logging.JSON("ERROR", map[string]interface{}{"event": "offset_query_error", "sql": q, "args": args2, "error": e.Error()})
			log.Printf("offset_query_error sql=%s args=%v err=%v", q, args2, e)
			return total, files, e
		}
		fetched := false
		for rows.Next() {
			fetched = true
			if e := rows.Scan(dest...); e != nil {
				rows.Close()
				logging.JSON("ERROR", map[string]interface{}{"event": "offset_scan_error", "error": e.Error()})
				log.Printf("offset_scan_error err=%v", e)
				return total, files, e
			}
			vals := make([]string, len(cols))
			for i := 0; i < len(cols); i++ {
				if b, ok := out[i].([]byte); ok {
					vals[i] = string(b)
				} else if out[i] == nil {
					vals[i] = ""
				} else {
					vals[i] = toString(out[i])
				}
			}
			if transform != nil {
				vals = transform(vals)
			}
			_ = w.WriteRow(vals)
			total++
			part++
			tick++
			if onProgress != nil && (tick == 1 || tick%200 == 0) {
				_ = onProgress(total)
				logging.JSON("INFO", map[string]interface{}{"event": "progress_tick", "total_rows": total})
			}
			if part >= maxRowsPerFile {
				p, sz, _ := w.Close()
				files = append(files, p)
				if onRoll != nil {
					_ = onRoll(p, sz, part)
				}
				w, e = newWriter()
				if e != nil {
					rows.Close()
					return total, files, e
				}
				_ = w.WriteHeader(cols)
				part = 0
			}
		}
		// Log iteration errors rather than failing the whole export.
		if e := rows.Err(); e != nil {
			logging.JSON("ERROR", map[string]interface{}{"event": "offset_rows_error", "error": e.Error()})
			log.Printf("offset_rows_error err=%v", e)
		}
		rows.Close()
		if !fetched {
			break
		}
	}
	p, sz, _ := w.Close()
	if part > 0 || len(files) == 0 {
		files = append(files, p)
		if onRoll != nil {
			_ = onRoll(p, sz, part)
		}
	}
	if onProgress != nil {
		_ = onProgress(total)
	}
	return total, files, nil
}