MarketingSystemDataExportTool/server/internal/exporter/stream.go


package exporter

import (
	"database/sql"
	"log"
	"strings"
	"time"

	"server/internal/logging"
	"server/internal/schema"
	"server/internal/utils"
)
// CursorSQL builds keyset-pagination SQL fragments for a main table, using
// its create-time column as the cursor timestamp and its order number as
// the unique tiebreaker.
type CursorSQL struct{ ds, main, mt, tsCol, pkCol string }

func NewCursorSQL(ds, main string) *CursorSQL {
	sch := schema.Get(ds, main)
	mt := sch.TableName(main)
	ts, _ := sch.MapField(main, "create_time")
	pk, _ := sch.MapField(main, "order_number")
	return &CursorSQL{ds: ds, main: main, mt: mt, tsCol: ts, pkCol: pk}
}
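
// Illustrative only: assuming schema.Get maps a logical "order" table to a
// physical table `orders` with a `gmt_create` timestamp and an `order_no`
// key (all hypothetical names), the constructor resolves:
//
//	c := NewCursorSQL("marketing", "order")
//	// c.mt == "orders", c.tsCol == "gmt_create", c.pkCol == "order_no"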
// InjectSelect prepends the cursor columns (aliased __ts and __pk) to the
// SELECT list so every fetched batch carries the values needed to resume.
func (c *CursorSQL) InjectSelect(base string) string {
	u := strings.ToUpper(base)
	// Match FROM case-insensitively, consistent with the prefix checks below.
	if strings.Index(u, " FROM ") <= 0 {
		return base
	}
	prefix := "SELECT "
	if strings.HasPrefix(u, "SELECT DISTINCT ") {
		prefix = "SELECT DISTINCT "
	} else if strings.HasPrefix(u, "SELECT SQL_NO_CACHE ") {
		prefix = "SELECT SQL_NO_CACHE "
	}
	return prefix + "`" + c.mt + "`." + c.tsCol + " AS __ts, `" + c.mt + "`." + c.pkCol + " AS __pk, " + base[len(prefix):]
}
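
// A sketch of the rewrite, using the hypothetical mapping above:
//
//	c.InjectSelect("SELECT `orders`.`status` FROM `orders`")
//	// => "SELECT `orders`.gmt_create AS __ts, `orders`.order_no AS __pk,
//	//     `orders`.`status` FROM `orders`"
//
// Queries starting with SELECT DISTINCT or SELECT SQL_NO_CACHE keep their
// modifier; anything without a FROM clause is returned unchanged.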
// AddOrder appends the deterministic (ts, pk) ordering keyset paging requires.
func (c *CursorSQL) AddOrder(base string) string {
	return base + " ORDER BY `" + c.mt + "`." + c.tsCol + ", `" + c.mt + "`." + c.pkCol
}
// AddCursor appends the keyset predicate (ts, pk) > (lastTs, lastPk). When
// the base query has no WHERE clause, one is added instead of emitting a
// bare predicate.
func (c *CursorSQL) AddCursor(base string) string {
	u := strings.ToUpper(base)
	cond := " AND ((`" + c.mt + "`." + c.tsCol + ") > ? OR ((`" + c.mt + "`." + c.tsCol + ") = ? AND (`" + c.mt + "`." + c.pkCol + ") > ?))"
	if strings.Contains(u, " WHERE ") {
		return base + cond
	}
	return base + " WHERE " + strings.TrimPrefix(cond, " AND ")
}
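
// With the same hypothetical mapping, a base query without a WHERE clause
// gains one:
//
//	c.AddCursor("SELECT ... FROM `orders`")
//	// => "SELECT ... FROM `orders` WHERE ((`orders`.gmt_create) > ? OR
//	//     ((`orders`.gmt_create) = ? AND (`orders`.order_no) > ?))"
//
// The three placeholders are bound as lastTs, lastTs, lastPk, in that
// order (see StreamWithCursor).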
// CountRows counts the rows a base query would return. It rewrites the
// SELECT list to "SELECT 1" and drops any trailing ORDER BY/LIMIT/OFFSET
// before wrapping the query in COUNT(1), so the database skips column
// materialization and sorting.
func CountRows(db *sql.DB, base string, args []interface{}) int64 {
	u := strings.ToUpper(base)
	idx := strings.Index(u, " FROM ")
	cut := len(base)
	if idx > 0 {
		for _, tok := range []string{" ORDER BY ", " LIMIT ", " OFFSET "} {
			if p := strings.Index(u[idx:], tok); p >= 0 {
				if cp := idx + p; cp < cut {
					cut = cp
				}
			}
		}
	}
	minimal := base
	if idx > 0 {
		minimal = "SELECT 1" + base[idx:cut]
	}
	q := "SELECT COUNT(1) FROM (" + minimal + ") AS sub"
	var c int64
	if err := db.QueryRow(q, args...).Scan(&c); err != nil {
		logging.JSON("ERROR", map[string]interface{}{"event": "count_error", "error": err.Error(), "sql": q, "args": args})
		log.Printf("count_error sql=%s args=%v err=%v", q, args, err)
		return 0
	}
	return c
}
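
// Example of the rewrite (illustrative table and columns):
//
//	CountRows(db, "SELECT a, b FROM t WHERE x = ? ORDER BY a LIMIT 10", args)
//	// executes: SELECT COUNT(1) FROM (SELECT 1 FROM t WHERE x = ?) AS sub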
// CountRowsFast approximates the row count by querying the main table
// alone, skipping joins: only filters that map onto the main ("order")
// table are applied, so the result is an estimate used for progress
// reporting rather than an exact count.
func CountRowsFast(db *sql.DB, ds, main string, filters map[string]interface{}) int64 {
	sch := schema.Get(ds, main)
	mt := sch.TableName(main)
	q := "SELECT COUNT(1) FROM `" + mt + "` WHERE 1=1"
	args := []interface{}{}
	// addIn appends an IN (...) clause; the type switch covers the slice
	// shapes filters arrive in (raw JSON decoding yields []interface{},
	// typed callers may pass []string, []int, or []int64).
	addIn := func(col string, v interface{}) {
		switch t := v.(type) {
		case []interface{}:
			if len(t) == 0 {
				return
			}
			ph := make([]string, len(t))
			for i := range t {
				ph[i] = "?"
				args = append(args, t[i])
			}
			q += " AND `" + col + "` IN (" + strings.Join(ph, ",") + ")"
		case []string:
			if len(t) == 0 {
				return
			}
			ph := make([]string, len(t))
			for i := range t {
				ph[i] = "?"
				args = append(args, t[i])
			}
			q += " AND `" + col + "` IN (" + strings.Join(ph, ",") + ")"
		case []int:
			if len(t) == 0 {
				return
			}
			ph := make([]string, len(t))
			for i := range t {
				ph[i] = "?"
				args = append(args, t[i])
			}
			q += " AND `" + col + "` IN (" + strings.Join(ph, ",") + ")"
		case []int64:
			if len(t) == 0 {
				return
			}
			ph := make([]string, len(t))
			for i := range t {
				ph[i] = "?"
				args = append(args, t[i])
			}
			q += " AND `" + col + "` IN (" + strings.Join(ph, ",") + ")"
		}
	}
	for k, v := range filters {
		tbl, col, ok := sch.FilterColumn(k)
		if !ok {
			continue
		}
		if tbl != "order" {
			continue
		}
		switch k {
		case "creator_in":
			addIn(col, v)
		case "create_time_between":
			switch t := v.(type) {
			case []interface{}:
				if len(t) == 2 {
					q += " AND `" + col + "` BETWEEN ? AND ?"
					args = append(args, t[0], t[1])
				}
			case []string:
				if len(t) == 2 {
					q += " AND `" + col + "` BETWEEN ? AND ?"
					args = append(args, t[0], t[1])
				}
			}
		default:
			q += " AND `" + col + "` = ?"
			args = append(args, v)
		}
	}
	// Log the count SQL so estimation problems can be traced.
	logging.JSON("INFO", map[string]interface{}{
		"event":      "count_fast_query",
		"datasource": ds,
		"main":       main,
		"sql":        q,
		"args":       args,
		"filters":    filters,
	})
	var c int64
	if err := db.QueryRow(q, args...).Scan(&c); err != nil {
		logging.JSON("ERROR", map[string]interface{}{"event": "count_fast_error", "error": err.Error(), "sql": q, "args": args})
		log.Printf("count_fast_error sql=%s args=%v err=%v", q, args, err)
		return 0
	}
	return c
}
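
// Illustrative call, assuming sch.FilterColumn maps creator_in and
// create_time_between onto main-table columns `creator` and `create_time`
// (hypothetical names):
//
//	CountRowsFast(db, "marketing", "order", map[string]interface{}{
//		"creator_in":          []string{"u1", "u2"},
//		"create_time_between": []string{"2024-01-01 00:00:00", "2024-01-31 23:59:59"},
//	})
//	// => SELECT COUNT(1) FROM `orders` WHERE 1=1
//	//      AND `creator` IN (?,?) AND `create_time` BETWEEN ? AND ?
//
// Clause order follows Go's map iteration order, so it varies between runs.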
// CountRowsFastChunked splits a create_time_between range into sub-ranges
// and sums CountRowsFast over them, keeping each COUNT within a time-index
// range the optimizer handles well.
func CountRowsFastChunked(db *sql.DB, ds, main string, filters map[string]interface{}) int64 {
	start := ""
	end := ""
	if v, ok := filters["create_time_between"]; ok {
		switch t := v.(type) {
		case []interface{}:
			if len(t) == 2 {
				start = utils.ToString(t[0])
				end = utils.ToString(t[1])
			}
		case []string:
			if len(t) == 2 {
				start = t[0]
				end = t[1]
			}
		}
	}
	if start == "" || end == "" {
		return CountRowsFast(db, ds, main, filters)
	}
	// Compute the span in days to pick a chunking strategy.
	layout := "2006-01-02 15:04:05"
	st, err1 := time.Parse(layout, start)
	en, err2 := time.Parse(layout, end)
	if err1 != nil || err2 != nil {
		return CountRowsFast(db, ds, main, filters)
	}
	daysDiff := int(en.Sub(st).Hours() / 24)
	// Adaptive chunking by span:
	// ≤15 days: single query; 15–30 days: 15-day chunks;
	// 30–90 days: weekly chunks; >90 days: calendar-month chunks.
	var ranges [][2]string
	switch {
	case daysDiff <= 15:
		return CountRowsFast(db, ds, main, filters)
	case daysDiff <= 30:
		ranges = SplitByDays(start, end, 15)
	case daysDiff <= 90:
		ranges = SplitByWeeks(start, end)
	default:
		ranges = SplitByMonths(start, end)
		logging.JSON("INFO", map[string]interface{}{
			"event":      "count_chunked_by_months",
			"datasource": ds,
			"main":       main,
			"days_diff":  daysDiff,
			"chunks":     len(ranges),
		})
	}
	var total int64
	for _, rg := range ranges {
		fl := map[string]interface{}{}
		for k, v := range filters {
			fl[k] = v
		}
		fl["create_time_between"] = []string{rg[0], rg[1]}
		total += CountRowsFast(db, ds, main, fl)
	}
	return total
}
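
// For example, a 45-day window such as 2024-01-01 00:00:00 through
// 2024-02-15 00:00:00 falls in the 30–90 day band and is counted as seven
// weekly chunks, each a separate CountRowsFast call.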
// SplitByDays splits [startStr, endStr] into consecutive sub-ranges of at
// most stepDays days each, clamping the last range to endStr. Unparseable
// or inverted input falls back to the original range unchanged.
func SplitByDays(startStr, endStr string, stepDays int) [][2]string {
	layout := "2006-01-02 15:04:05"
	s := strings.TrimSpace(startStr)
	e := strings.TrimSpace(endStr)
	st, err1 := time.Parse(layout, s)
	en, err2 := time.Parse(layout, e)
	if err1 != nil || err2 != nil || !en.After(st) || stepDays <= 0 {
		return [][2]string{{s, e}}
	}
	var out [][2]string
	cur := st
	step := time.Duration(stepDays) * 24 * time.Hour
	for cur.Before(en) {
		nxt := cur.Add(step)
		if nxt.After(en) {
			nxt = en
		}
		out = append(out, [2]string{cur.Format(layout), nxt.Format(layout)})
		cur = nxt
	}
	return out
}
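
// Example:
//
//	SplitByDays("2024-01-01 00:00:00", "2024-01-31 00:00:00", 15)
//	// => [ [2024-01-01 00:00:00, 2024-01-16 00:00:00],
//	//      [2024-01-16 00:00:00, 2024-01-31 00:00:00] ]
//
// Adjacent ranges share their boundary timestamp, and CountRowsFast uses
// inclusive BETWEEN, so a row created exactly on a boundary is counted in
// both chunks; for a progress estimate this slight overcount is tolerable.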
// SplitByWeeks splits [startStr, endStr] into consecutive 7-day sub-ranges,
// clamping the last one to endStr (equivalent to SplitByDays with a 7-day
// step). Unparseable or inverted input falls back to the original range.
func SplitByWeeks(startStr, endStr string) [][2]string {
	layout := "2006-01-02 15:04:05"
	s := strings.TrimSpace(startStr)
	e := strings.TrimSpace(endStr)
	st, err1 := time.Parse(layout, s)
	en, err2 := time.Parse(layout, e)
	if err1 != nil || err2 != nil || !en.After(st) {
		return [][2]string{{s, e}}
	}
	var out [][2]string
	cur := st
	weekDuration := 7 * 24 * time.Hour
	for cur.Before(en) {
		nxt := cur.Add(weekDuration)
		if nxt.After(en) {
			nxt = en
		}
		out = append(out, [2]string{cur.Format(layout), nxt.Format(layout)})
		cur = nxt
	}
	return out
}
// SplitByMonths splits [startStr, endStr] at calendar-month boundaries, for
// chunked counting over very long spans. The first range runs from the
// start to the first day of the following month; subsequent ranges cover
// whole months, with the last clamped to endStr.
func SplitByMonths(startStr, endStr string) [][2]string {
	layout := "2006-01-02 15:04:05"
	s := strings.TrimSpace(startStr)
	e := strings.TrimSpace(endStr)
	st, err1 := time.Parse(layout, s)
	en, err2 := time.Parse(layout, e)
	if err1 != nil || err2 != nil || !en.After(st) {
		return [][2]string{{s, e}}
	}
	var out [][2]string
	cur := st
	for cur.Before(en) {
		// First day of the month after cur; time.Date normalizes Month+1
		// across year boundaries.
		nxt := time.Date(cur.Year(), cur.Month()+1, 1, 0, 0, 0, 0, cur.Location())
		if nxt.After(en) {
			nxt = en
		}
		out = append(out, [2]string{cur.Format(layout), nxt.Format(layout)})
		cur = nxt
	}
	return out
}
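
// Example:
//
//	SplitByMonths("2024-01-15 00:00:00", "2024-03-10 00:00:00")
//	// => [ [2024-01-15 00:00:00, 2024-02-01 00:00:00],
//	//      [2024-02-01 00:00:00, 2024-03-01 00:00:00],
//	//      [2024-03-01 00:00:00, 2024-03-10 00:00:00] ]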
// RowTransform optionally rewrites a row before it is written.
type RowTransform func([]string) []string

// RollCallback is invoked after each output file is closed (rolled).
type RollCallback func(path string, size int64, partRows int64) error

// ProgressCallback reports the cumulative number of exported rows.
type ProgressCallback func(totalRows int64) error
// StreamWithCursor exports the base query in keyset-paginated batches,
// rolling to a new output file every maxRowsPerFile rows. If the cursor
// query or row scan fails (for example, when the SELECT list cannot be
// rewritten safely), it falls back to LIMIT/OFFSET pagination.
func StreamWithCursor(db *sql.DB, base string, args []interface{}, cur *CursorSQL, batch int, cols []string, newWriter func() (RowWriter, error), transform RowTransform, maxRowsPerFile int64, onRoll RollCallback, onProgress ProgressCallback) (int64, []string, error) {
	w, err := newWriter()
	if err != nil {
		return 0, nil, err
	}
	_ = w.WriteHeader(cols)
	if onProgress != nil {
		_ = onProgress(0)
	}
	// Two extra slots hold the injected cursor columns (__ts, __pk).
	out := make([]interface{}, len(cols)+2)
	dest := make([]interface{}, len(cols)+2)
	for i := range out {
		dest[i] = &out[i]
	}
	var total int64
	var part int64
	var tick int64
	files := []string{}
	lastTs := ""
	lastPk := ""
	for {
		q2 := cur.InjectSelect(base)
		if lastTs != "" || lastPk != "" {
			q2 = cur.AddCursor(q2)
		}
		q2 = cur.AddOrder(q2) + " LIMIT ?"
		args2 := append([]interface{}{}, args...)
		if lastTs != "" || lastPk != "" {
			args2 = append(args2, lastTs, lastTs, lastPk)
		}
		args2 = append(args2, batch)
		rows, e := db.Query(q2, args2...)
		if e != nil {
			logging.JSON("ERROR", map[string]interface{}{"event": "cursor_query_error", "sql": q2, "args": args2, "error": e.Error()})
			log.Printf("cursor_query_error sql=%s args=%v err=%v", q2, args2, e)
			// Fall back to LIMIT/OFFSET pagination when the cursor query fails.
			_, _, _ = w.Close()
			return pagedOffset(db, base, args, batch, cols, newWriter, transform, maxRowsPerFile, onRoll, onProgress)
		}
		fetched := false
		for rows.Next() {
			fetched = true
			if e := rows.Scan(dest...); e != nil {
				rows.Close()
				// Fall back to LIMIT/OFFSET when the scan fails (likely a column mismatch).
				logging.JSON("ERROR", map[string]interface{}{"event": "cursor_scan_error", "error": e.Error()})
				log.Printf("cursor_scan_error err=%v", e)
				_, _, _ = w.Close()
				return pagedOffset(db, base, args, batch, cols, newWriter, transform, maxRowsPerFile, onRoll, onProgress)
			}
			vals := make([]string, len(cols))
			for i := 0; i < len(cols); i++ {
				// Skip the injected cursor columns (__ts, __pk) at positions 0 and 1.
				idx := i + 2
				if b, ok := out[idx].([]byte); ok {
					vals[i] = string(b)
				} else if out[idx] == nil {
					vals[i] = ""
				} else {
					vals[i] = utils.ToString(out[idx])
				}
			}
			if transform != nil {
				vals = transform(vals)
			}
			_ = w.WriteRow(vals)
			total++
			part++
			tick++
			// Advance the cursor state from the injected columns.
			lastTs = utils.ToString(out[0])
			lastPk = utils.ToString(out[1])
			if onProgress != nil && (tick == 1 || tick%200 == 0) {
				_ = onProgress(total)
				logging.JSON("INFO", map[string]interface{}{"event": "progress_tick", "total_rows": total})
			}
			if part >= maxRowsPerFile {
				p, sz, _ := w.Close()
				files = append(files, p)
				if onRoll != nil {
					_ = onRoll(p, sz, part)
				}
				w, e = newWriter()
				if e != nil {
					rows.Close()
					return total, files, e
				}
				_ = w.WriteHeader(cols)
				part = 0
			}
		}
		rows.Close()
		if !fetched {
			break
		}
	}
	p, sz, _ := w.Close()
	if part > 0 || len(files) == 0 {
		files = append(files, p)
		if onRoll != nil {
			_ = onRoll(p, sz, part)
		}
	}
	if onProgress != nil {
		_ = onProgress(total)
	}
	return total, files, nil
}
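
// A minimal call-site sketch; NewCSVWriter and the query are hypothetical,
// and error handling is elided:
//
//	cur := NewCursorSQL("marketing", "order")
//	total, files, err := StreamWithCursor(db,
//		"SELECT `orders`.`id`, `orders`.`status` FROM `orders` WHERE `orders`.`status` = ?",
//		[]interface{}{"PAID"},
//		cur, 5000, []string{"id", "status"},
//		func() (RowWriter, error) { return NewCSVWriter("/tmp/export") }, // assumed writer factory
//		nil,      // no per-row transform
//		500000,   // roll files every 500k rows
//		nil, nil) // no roll/progress callbacks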
// pagedOffset is a robust fallback that paginates the unmodified base query
// with LIMIT/OFFSET instead of injected cursor columns. OFFSET pagination
// re-scans the skipped rows on every page, so cost grows with depth; it is
// used only when the cursor path fails.
func pagedOffset(db *sql.DB, base string, args []interface{}, batch int, cols []string, newWriter func() (RowWriter, error), transform RowTransform, maxRowsPerFile int64, onRoll RollCallback, onProgress ProgressCallback) (int64, []string, error) {
	w, err := newWriter()
	if err != nil {
		return 0, nil, err
	}
	_ = w.WriteHeader(cols)
	if onProgress != nil {
		_ = onProgress(0)
	}
	files := []string{}
	var total int64
	var part int64
	var tick int64
	for off := 0; ; off += batch {
		q := "SELECT * FROM (" + base + ") AS sub LIMIT ? OFFSET ?"
		args2 := append(append([]interface{}{}, args...), batch, off)
		rows, e := db.Query(q, args2...)
		if e != nil {
			logging.JSON("ERROR", map[string]interface{}{"event": "offset_query_error", "sql": q, "args": args2, "error": e.Error()})
			log.Printf("offset_query_error sql=%s args=%v err=%v", q, args2, e)
			return total, files, e
		}
		fetched := false
		out := make([]interface{}, len(cols))
		dest := make([]interface{}, len(cols))
		for i := range out {
			dest[i] = &out[i]
		}
		for rows.Next() {
			fetched = true
			if e := rows.Scan(dest...); e != nil {
				rows.Close()
				logging.JSON("ERROR", map[string]interface{}{"event": "offset_scan_error", "error": e.Error()})
				log.Printf("offset_scan_error err=%v", e)
				return total, files, e
			}
			vals := make([]string, len(cols))
			for i := 0; i < len(cols); i++ {
				if b, ok := out[i].([]byte); ok {
					vals[i] = string(b)
				} else if out[i] == nil {
					vals[i] = ""
				} else {
					vals[i] = utils.ToString(out[i])
				}
			}
			if transform != nil {
				vals = transform(vals)
			}
			_ = w.WriteRow(vals)
			total++
			part++
			tick++
			if onProgress != nil && (tick == 1 || tick%200 == 0) {
				_ = onProgress(total)
				logging.JSON("INFO", map[string]interface{}{"event": "progress_tick", "total_rows": total})
			}
			if part >= maxRowsPerFile {
				p, sz, _ := w.Close()
				files = append(files, p)
				if onRoll != nil {
					_ = onRoll(p, sz, part)
				}
				w, e = newWriter()
				if e != nil {
					rows.Close()
					return total, files, e
				}
				_ = w.WriteHeader(cols)
				part = 0
			}
		}
		rows.Close()
		if !fetched {
			break
		}
	}
	p, sz, _ := w.Close()
	if part > 0 || len(files) == 0 {
		files = append(files, p)
		if onRoll != nil {
			_ = onRoll(p, sz, part)
		}
	}
	if onProgress != nil {
		_ = onProgress(total)
	}
	return total, files, nil
}