diff --git a/server/internal/api/templates.go b/server/internal/api/templates.go index a79ee2b..7a148bc 100644 --- a/server/internal/api/templates.go +++ b/server/internal/api/templates.go @@ -1,3 +1,4 @@ +// Package api 提供HTTP API处理器 package api import ( @@ -13,47 +14,69 @@ import ( "time" ) +// ==================== 模板API处理器 ==================== + +// TemplatesAPI 模板管理API type TemplatesAPI struct { - meta *sql.DB - marketing *sql.DB + metaDB *sql.DB // 元数据库(存储模板和任务) + marketingDB *sql.DB // 营销系统数据库 } -func TemplatesHandler(meta, marketing *sql.DB) http.Handler { - api := &TemplatesAPI{meta: meta, marketing: marketing} +// TemplatesHandler 创建模板API处理器 +func TemplatesHandler(metaDB, marketingDB *sql.DB) http.Handler { + api := &TemplatesAPI{metaDB: metaDB, marketingDB: marketingDB} return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - p := strings.TrimPrefix(r.URL.Path, "/api/templates") - if r.Method == http.MethodPost && p == "" { + path := strings.TrimPrefix(r.URL.Path, "/api/templates") + + // POST /api/templates - 创建模板 + if r.Method == http.MethodPost && path == "" { api.createTemplate(w, r) return } - if r.Method == http.MethodGet && p == "" { + + // GET /api/templates - 获取模板列表 + if r.Method == http.MethodGet && path == "" { api.listTemplates(w, r) return } - if strings.HasPrefix(p, "/") { - id := strings.TrimPrefix(p, "/") + + // 带ID的路径处理 + if strings.HasPrefix(path, "/") { + templateID := strings.TrimPrefix(path, "/") + + // GET /api/templates/:id - 获取单个模板 if r.Method == http.MethodGet { - api.getTemplate(w, r, id) + api.getTemplate(w, r, templateID) return } + + // PATCH /api/templates/:id - 更新模板 if r.Method == http.MethodPatch { - api.patchTemplate(w, r, id) + api.patchTemplate(w, r, templateID) return } + + // DELETE /api/templates/:id - 删除模板 if r.Method == http.MethodDelete { - api.deleteTemplate(w, r, id) + api.deleteTemplate(w, r, templateID) return } - if r.Method == http.MethodPost && strings.HasSuffix(p, "/validate") { - id = strings.TrimSuffix(id, "/validate") - api.validateTemplate(w, r, id) + + // POST /api/templates/:id/validate - 验证模板 + if r.Method == http.MethodPost && strings.HasSuffix(path, "/validate") { + templateID = strings.TrimSuffix(templateID, "/validate") + api.validateTemplate(w, r, templateID) return } } + fail(w, r, http.StatusNotFound, "not found") }) } +// ==================== 请求/响应结构 ==================== + +// TemplatePayload 模板创建/更新请求体 type TemplatePayload struct { Name string `json:"name"` Datasource string `json:"datasource"` @@ -65,291 +88,457 @@ type TemplatePayload struct { Visibility string `json:"visibility"` } -func (a *TemplatesAPI) createTemplate(w http.ResponseWriter, r *http.Request) { - b, _ := io.ReadAll(r.Body) - var p TemplatePayload - json.Unmarshal(b, &p) - r = WithPayload(r, p) - uidStr := r.URL.Query().Get("userId") - if uidStr != "" { - var uid uint64 - _, _ = fmt.Sscan(uidStr, &uid) - if uid > 0 { - p.OwnerID = uid - } - } - now := time.Now() - tplSQL := "INSERT INTO export_templates (name, datasource, main_table, fields_json, filters_json, file_format, visibility, owner_id, enabled, stats_enabled, last_validated_at, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?)" - tplArgs := []interface{}{p.Name, p.Datasource, p.MainTable, toJSON(p.Fields), toJSON(p.Filters), p.FileFormat, p.Visibility, p.OwnerID, 1, 0, now, now, now} - log.Printf("trace_id=%s sql=%s args=%v", TraceIDFrom(r), tplSQL, tplArgs) - _, err := a.meta.Exec(tplSQL, tplArgs...) 
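// Editor's note (illustrative only, not part of this change): the refactored
// TemplatesHandler above dispatches purely on HTTP method plus the trimmed path.
// For a hypothetical request POST /api/templates/42/validate the code shown above
// computes, step by step:
//
//	path := strings.TrimPrefix("/api/templates/42/validate", "/api/templates") // "/42/validate"
//	templateID := strings.TrimPrefix(path, "/")                                // "42/validate"
//	templateID = strings.TrimSuffix(templateID, "/validate")                   // "42"
//
// GET, PATCH and DELETE on /api/templates/42 reach getTemplate, patchTemplate and
// deleteTemplate with templateID "42"; any unmatched combination falls through to
// the 404 fail(...) at the end of the handler.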
- if err != nil { - fail(w, r, http.StatusInternalServerError, err.Error()) - return - } - writeJSON(w, r, http.StatusCreated, 0, "ok", nil) -} +// ==================== API方法 ==================== -func (a *TemplatesAPI) listTemplates(w http.ResponseWriter, r *http.Request) { - uidStr := r.URL.Query().Get("userId") - sqlText := "SELECT id,name,datasource,main_table,file_format,visibility,owner_id,enabled,last_validated_at,created_at,updated_at,fields_json, (SELECT COUNT(1) FROM export_jobs ej WHERE ej.template_id = export_templates.id) AS exec_count FROM export_templates" - args := []interface{}{} - conds := []string{} - if uidStr != "" { - conds = append(conds, "owner_id IN (0, ?)") - args = append(args, uidStr) - } - conds = append(conds, "enabled = 1") - if len(conds) > 0 { - sqlText += " WHERE " + strings.Join(conds, " AND ") - } - sqlText += " ORDER BY datasource ASC, id DESC LIMIT 200" - rows, err := a.meta.Query(sqlText, args...) - if err != nil { - fail(w, r, http.StatusInternalServerError, err.Error()) - return - } - defer rows.Close() - wl := Whitelist() - countFields := func(ds, main string, fs []string) int64 { - seen := map[string]struct{}{} - for _, tf := range fs { - if ds == "ymt" && strings.HasPrefix(tf, "order_info.") { - tf = strings.Replace(tf, "order_info.", "order.", 1) - } - if !wl[tf] { - continue - } - // special dedupe: when both order.merchant_name and merchant.name exist, only count merchant.name - if ds == "ymt" && tf == "order.merchant_name" { - if _, ok := seen["merchant.name"]; ok { - continue - } - } - if _, ok := seen[tf]; ok { - continue - } - seen[tf] = struct{}{} - } - return int64(len(seen)) - } - out := []map[string]interface{}{} - for rows.Next() { - var id uint64 - var name, datasource, mainTable, fileFormat, visibility string - var ownerID uint64 - var enabled int - var lastValidatedAt sql.NullTime - var createdAt, updatedAt time.Time - var execCount int64 - var fieldsRaw []byte - err := rows.Scan(&id, &name, &datasource, &mainTable, &fileFormat, &visibility, &ownerID, &enabled, &lastValidatedAt, &createdAt, &updatedAt, &fieldsRaw, &execCount) - if err != nil { - fail(w, r, http.StatusInternalServerError, err.Error()) - return - } - var fs []string - _ = json.Unmarshal(fieldsRaw, &fs) - fieldCount := countFields(datasource, mainTable, fs) - m := map[string]interface{}{"id": id, "name": name, "datasource": datasource, "main_table": mainTable, "file_format": fileFormat, "visibility": visibility, "owner_id": ownerID, "enabled": enabled == 1, "last_validated_at": lastValidatedAt.Time, "created_at": createdAt, "updated_at": updatedAt, "field_count": fieldCount, "exec_count": execCount} - out = append(out, m) - } - ok(w, r, out) -} - -func (a *TemplatesAPI) getTemplate(w http.ResponseWriter, r *http.Request, id string) { - row := a.meta.QueryRow("SELECT id,name,datasource,main_table,fields_json,filters_json,file_format,visibility,owner_id,enabled,explain_score,last_validated_at,created_at,updated_at FROM export_templates WHERE id=?", id) - var m = map[string]interface{}{} - var tid uint64 - var name, datasource, mainTable, fileFormat, visibility string - var ownerID uint64 - var enabled int - var explainScore sql.NullInt64 - var lastValidatedAt sql.NullTime - var createdAt, updatedAt time.Time - var fields, filters []byte - err := row.Scan(&tid, &name, &datasource, &mainTable, &fields, &filters, &fileFormat, &visibility, &ownerID, &enabled, &explainScore, &lastValidatedAt, &createdAt, &updatedAt) - if err != nil { - fail(w, r, http.StatusNotFound, "not 
found") - return - } - m["id"] = tid - m["name"] = name - m["datasource"] = datasource - m["main_table"] = mainTable - m["file_format"] = fileFormat - m["visibility"] = visibility - m["owner_id"] = ownerID - m["enabled"] = enabled == 1 - m["explain_score"] = explainScore.Int64 - m["last_validated_at"] = lastValidatedAt.Time - m["created_at"] = createdAt - m["updated_at"] = updatedAt - m["fields"] = fromJSON(fields) - m["filters"] = fromJSON(filters) - ok(w, r, m) -} - -func (a *TemplatesAPI) patchTemplate(w http.ResponseWriter, r *http.Request, id string) { - b, err := io.ReadAll(r.Body) +// createTemplate 创建新模板 +func (api *TemplatesAPI) createTemplate(w http.ResponseWriter, r *http.Request) { + // 读取并解析请求体 + body, err := io.ReadAll(r.Body) if err != nil { log.Printf("trace_id=%s error reading request body: %v", TraceIDFrom(r), err) fail(w, r, http.StatusBadRequest, "invalid request body") return } - log.Printf("trace_id=%s patchTemplate request body: %s", TraceIDFrom(r), string(b)) - - var p map[string]interface{} - err = json.Unmarshal(b, &p) - if err != nil { - log.Printf("trace_id=%s error unmarshaling request body: %v", TraceIDFrom(r), err) + var payload TemplatePayload + if err := json.Unmarshal(body, &payload); err != nil { + log.Printf("trace_id=%s error parsing JSON: %v", TraceIDFrom(r), err) fail(w, r, http.StatusBadRequest, "invalid JSON format") return } - log.Printf("trace_id=%s patchTemplate parsed payload: %v", TraceIDFrom(r), p) - log.Printf("trace_id=%s patchTemplate template ID: %s", TraceIDFrom(r), id) + r = WithPayload(r, payload) - set := []string{} - args := []interface{}{} - for k, v := range p { - log.Printf("trace_id=%s patchTemplate processing field: %s, value: %v, type: %T", TraceIDFrom(r), k, v, v) - switch k { - case "name", "visibility", "file_format", "main_table": - if strVal, ok := v.(string); ok { - set = append(set, k+"=?") - args = append(args, strVal) - log.Printf("trace_id=%s patchTemplate added string field: %s, value: %s", TraceIDFrom(r), k, strVal) - } else { - log.Printf("trace_id=%s patchTemplate invalid string field: %s, value: %v, type: %T", TraceIDFrom(r), k, v, v) + // 介URL参数获取用户ID + if userIDStr := r.URL.Query().Get("userId"); userIDStr != "" { + var userID uint64 + if _, scanErr := fmt.Sscan(userIDStr, &userID); scanErr == nil && userID > 0 { + payload.OwnerID = userID + } + } + + // 插入数据库 + now := time.Now() + insertSQL := `INSERT INTO export_templates + (name, datasource, main_table, fields_json, filters_json, file_format, + visibility, owner_id, enabled, stats_enabled, last_validated_at, created_at, updated_at) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?)` + + args := []interface{}{ + payload.Name, + payload.Datasource, + payload.MainTable, + toJSON(payload.Fields), + toJSON(payload.Filters), + payload.FileFormat, + payload.Visibility, + payload.OwnerID, + 1, // enabled + 0, // stats_enabled + now, // last_validated_at + now, // created_at + now, // updated_at + } + + log.Printf("trace_id=%s sql=%s args=%v", TraceIDFrom(r), insertSQL, args) + + if _, err := api.metaDB.Exec(insertSQL, args...); err != nil { + fail(w, r, http.StatusInternalServerError, err.Error()) + return + } + + writeJSON(w, r, http.StatusCreated, 0, "ok", nil) +} + +// listTemplates 获取模板列表 +func (api *TemplatesAPI) listTemplates(w http.ResponseWriter, r *http.Request) { + userIDStr := r.URL.Query().Get("userId") + + // 构建查询SQL + querySQL := `SELECT id, name, datasource, main_table, file_format, visibility, + owner_id, enabled, last_validated_at, created_at, updated_at, 
fields_json, + (SELECT COUNT(1) FROM export_jobs ej WHERE ej.template_id = export_templates.id) AS exec_count + FROM export_templates` + + var args []interface{} + var conditions []string + + if userIDStr != "" { + conditions = append(conditions, "owner_id IN (0, ?)") + args = append(args, userIDStr) + } + conditions = append(conditions, "enabled = 1") + + if len(conditions) > 0 { + querySQL += " WHERE " + strings.Join(conditions, " AND ") + } + querySQL += " ORDER BY datasource ASC, id DESC LIMIT 200" + + rows, err := api.metaDB.Query(querySQL, args...) + if err != nil { + fail(w, r, http.StatusInternalServerError, err.Error()) + return + } + defer rows.Close() + + whitelist := Whitelist() + templates := []map[string]interface{}{} + + for rows.Next() { + var ( + id uint64 + name string + datasource string + mainTable string + fileFormat string + visibility string + ownerID uint64 + enabled int + lastValidatedAt sql.NullTime + createdAt time.Time + updatedAt time.Time + fieldsRaw []byte + execCount int64 + ) + + if err := rows.Scan(&id, &name, &datasource, &mainTable, &fileFormat, &visibility, + &ownerID, &enabled, &lastValidatedAt, &createdAt, &updatedAt, &fieldsRaw, &execCount); err != nil { + fail(w, r, http.StatusInternalServerError, err.Error()) + return + } + + // 解析字段并计算有效字段数 + var fields []string + _ = json.Unmarshal(fieldsRaw, &fields) + fieldCount := countValidFields(datasource, mainTable, fields, whitelist) + + templates = append(templates, map[string]interface{}{ + "id": id, + "name": name, + "datasource": datasource, + "main_table": mainTable, + "file_format": fileFormat, + "visibility": visibility, + "owner_id": ownerID, + "enabled": enabled == 1, + "last_validated_at": lastValidatedAt.Time, + "created_at": createdAt, + "updated_at": updatedAt, + "field_count": fieldCount, + "exec_count": execCount, + }) + } + + ok(w, r, templates) +} + +// countValidFields 计算有效字段数(去重) +func countValidFields(datasource, mainTable string, fields []string, whitelist map[string]bool) int64 { + seen := map[string]struct{}{} + + for _, field := range fields { + // YMT系统的order_info映射为order + if datasource == "ymt" && strings.HasPrefix(field, "order_info.") { + field = strings.Replace(field, "order_info.", "order.", 1) + } + + // 检查白名单 + if !whitelist[field] { + continue + } + + // YMT系统客户名称去重 + if datasource == "ymt" && field == "order.merchant_name" { + if _, exists := seen["merchant.name"]; exists { + continue } + } + + if _, exists := seen[field]; exists { + continue + } + seen[field] = struct{}{} + } + + return int64(len(seen)) +} + +// getTemplate 获取单个模板详情 +func (api *TemplatesAPI) getTemplate(w http.ResponseWriter, r *http.Request, templateID string) { + querySQL := `SELECT id, name, datasource, main_table, fields_json, filters_json, + file_format, visibility, owner_id, enabled, explain_score, + last_validated_at, created_at, updated_at + FROM export_templates WHERE id=?` + + row := api.metaDB.QueryRow(querySQL, templateID) + + var ( + id uint64 + name string + datasource string + mainTable string + fileFormat string + visibility string + ownerID uint64 + enabled int + explainScore sql.NullInt64 + lastValidatedAt sql.NullTime + createdAt time.Time + updatedAt time.Time + fieldsJSON []byte + filtersJSON []byte + ) + + err := row.Scan(&id, &name, &datasource, &mainTable, &fieldsJSON, &filtersJSON, + &fileFormat, &visibility, &ownerID, &enabled, &explainScore, + &lastValidatedAt, &createdAt, &updatedAt) + if err != nil { + fail(w, r, http.StatusNotFound, "not found") + return + } + + result := 
map[string]interface{}{ + "id": id, + "name": name, + "datasource": datasource, + "main_table": mainTable, + "file_format": fileFormat, + "visibility": visibility, + "owner_id": ownerID, + "enabled": enabled == 1, + "explain_score": explainScore.Int64, + "last_validated_at": lastValidatedAt.Time, + "created_at": createdAt, + "updated_at": updatedAt, + "fields": fromJSON(fieldsJSON), + "filters": fromJSON(filtersJSON), + } + + ok(w, r, result) +} + +// patchTemplate 更新模板 +func (api *TemplatesAPI) patchTemplate(w http.ResponseWriter, r *http.Request, templateID string) { + traceID := TraceIDFrom(r) + + // 读取请求体 + body, err := io.ReadAll(r.Body) + if err != nil { + log.Printf("trace_id=%s error reading request body: %v", traceID, err) + fail(w, r, http.StatusBadRequest, "invalid request body") + return + } + + log.Printf("trace_id=%s patchTemplate request body: %s", traceID, string(body)) + + // 解析JSON + var payload map[string]interface{} + if err := json.Unmarshal(body, &payload); err != nil { + log.Printf("trace_id=%s error unmarshaling request body: %v", traceID, err) + fail(w, r, http.StatusBadRequest, "invalid JSON format") + return + } + + log.Printf("trace_id=%s patchTemplate parsed payload: %v", traceID, payload) + log.Printf("trace_id=%s patchTemplate template ID: %s", traceID, templateID) + + // 构建UPDATE语句 + var setClauses []string + var args []interface{} + + for key, value := range payload { + log.Printf("trace_id=%s patchTemplate processing field: %s, value: %v, type: %T", traceID, key, value, value) + + switch key { + case "name", "visibility", "file_format", "main_table": + if strVal, isStr := value.(string); isStr { + setClauses = append(setClauses, key+"=?") + args = append(args, strVal) + log.Printf("trace_id=%s patchTemplate added string field: %s, value: %s", traceID, key, strVal) + } else { + log.Printf("trace_id=%s patchTemplate invalid string field: %s, value: %v, type: %T", traceID, key, value, value) + } + case "fields": - set = append(set, "fields_json=?") - jsonBytes := toJSON(v) + setClauses = append(setClauses, "fields_json=?") + jsonBytes := toJSON(value) args = append(args, jsonBytes) - log.Printf("trace_id=%s patchTemplate added fields_json: %s", TraceIDFrom(r), string(jsonBytes)) + log.Printf("trace_id=%s patchTemplate added fields_json: %s", traceID, string(jsonBytes)) + case "filters": - set = append(set, "filters_json=?") - jsonBytes := toJSON(v) + setClauses = append(setClauses, "filters_json=?") + jsonBytes := toJSON(value) args = append(args, jsonBytes) - log.Printf("trace_id=%s patchTemplate added filters_json: %s", TraceIDFrom(r), string(jsonBytes)) + log.Printf("trace_id=%s patchTemplate added filters_json: %s", traceID, string(jsonBytes)) + case "enabled": - set = append(set, "enabled=?") - if boolVal, ok := v.(bool); ok { + setClauses = append(setClauses, "enabled=?") + if boolVal, isBool := value.(bool); isBool { if boolVal { args = append(args, 1) } else { args = append(args, 0) } - log.Printf("trace_id=%s patchTemplate added enabled: %t", TraceIDFrom(r), boolVal) + log.Printf("trace_id=%s patchTemplate added enabled: %t", traceID, boolVal) } else { - log.Printf("trace_id=%s patchTemplate invalid bool field: %s, value: %v, type: %T", TraceIDFrom(r), k, v, v) + log.Printf("trace_id=%s patchTemplate invalid bool field: %s, value: %v, type: %T", traceID, key, value, value) } } } - if len(set) == 0 { - log.Printf("trace_id=%s patchTemplate no fields to update", TraceIDFrom(r)) + if len(setClauses) == 0 { + log.Printf("trace_id=%s patchTemplate no 
fields to update", traceID) fail(w, r, http.StatusBadRequest, "no patch") return } - // ensure updated_at - set = append(set, "updated_at=?") + // 添加updated_at + setClauses = append(setClauses, "updated_at=?") now := time.Now() - args = append(args, now, id) + args = append(args, now, templateID) - sql := "UPDATE export_templates SET " + strings.Join(set, ",") + " WHERE id= ?" - log.Printf("trace_id=%s patchTemplate executing SQL: %s", TraceIDFrom(r), sql) - log.Printf("trace_id=%s patchTemplate SQL args: %v", TraceIDFrom(r), args) + updateSQL := "UPDATE export_templates SET " + strings.Join(setClauses, ",") + " WHERE id= ?" + log.Printf("trace_id=%s patchTemplate executing SQL: %s", traceID, updateSQL) + log.Printf("trace_id=%s patchTemplate SQL args: %v", traceID, args) - _, err = a.meta.Exec(sql, args...) - if err != nil { - log.Printf("trace_id=%s patchTemplate SQL error: %v", TraceIDFrom(r), err) + if _, err := api.metaDB.Exec(updateSQL, args...); err != nil { + log.Printf("trace_id=%s patchTemplate SQL error: %v", traceID, err) fail(w, r, http.StatusInternalServerError, err.Error()) return } - log.Printf("trace_id=%s patchTemplate update successful", TraceIDFrom(r)) + log.Printf("trace_id=%s patchTemplate update successful", traceID) ok(w, r, nil) } -func (a *TemplatesAPI) deleteTemplate(w http.ResponseWriter, r *http.Request, id string) { - var cnt int64 - row := a.meta.QueryRow("SELECT COUNT(1) FROM export_jobs WHERE template_id=?", id) - _ = row.Scan(&cnt) - if cnt > 0 { - soft := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("soft"))) - if soft == "1" || soft == "true" || soft == "yes" { - a.meta.Exec("UPDATE export_templates SET enabled=?, updated_at=? WHERE id= ?", 0, time.Now(), id) +// deleteTemplate 删除模板 +func (api *TemplatesAPI) deleteTemplate(w http.ResponseWriter, r *http.Request, templateID string) { + // 检查是否有关联的导出任务 + var jobCount int64 + row := api.metaDB.QueryRow("SELECT COUNT(1) FROM export_jobs WHERE template_id=?", templateID) + _ = row.Scan(&jobCount) + + if jobCount > 0 { + // 有关联任务,检查是否要求软删除 + softDelete := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("soft"))) + if softDelete == "1" || softDelete == "true" || softDelete == "yes" { + // 软删除:禁用模板 + _, _ = api.metaDB.Exec("UPDATE export_templates SET enabled=?, updated_at=? 
WHERE id=?", 0, time.Now(), templateID) ok(w, r, nil) return } fail(w, r, http.StatusBadRequest, "template in use") return } - _, err := a.meta.Exec("DELETE FROM export_templates WHERE id= ?", id) - if err != nil { + + // 无关联任务,硬删除 + if _, err := api.metaDB.Exec("DELETE FROM export_templates WHERE id=?", templateID); err != nil { fail(w, r, http.StatusInternalServerError, err.Error()) return } + ok(w, r, nil) } -func (a *TemplatesAPI) validateTemplate(w http.ResponseWriter, r *http.Request, id string) { - row := a.meta.QueryRow("SELECT datasource, main_table, fields_json, filters_json FROM export_templates WHERE id= ?", id) - var ds string - var main string - var fields, filters []byte - err := row.Scan(&ds, &main, &fields, &filters) - if err != nil { +// validateTemplate 验证模板 +func (api *TemplatesAPI) validateTemplate(w http.ResponseWriter, r *http.Request, templateID string) { + // 获取模板信息 + row := api.metaDB.QueryRow( + "SELECT datasource, main_table, fields_json, filters_json FROM export_templates WHERE id=?", + templateID, + ) + + var ( + datasource string + mainTable string + fieldsJSON []byte + filtersJSON []byte + ) + + if err := row.Scan(&datasource, &mainTable, &fieldsJSON, &filtersJSON); err != nil { fail(w, r, http.StatusNotFound, "not found") return } - var fs []string - var fl map[string]interface{} - json.Unmarshal(fields, &fs) - json.Unmarshal(filters, &fl) - wl := Whitelist() - req := exporter.BuildRequest{MainTable: main, Datasource: ds, Fields: fs, Filters: fl} - q, args, err := exporter.BuildSQL(req, wl) + + // 解析字段和过滤条件 + var fields []string + var filters map[string]interface{} + _ = json.Unmarshal(fieldsJSON, &fields) + _ = json.Unmarshal(filtersJSON, &filters) + + // 构建SQL + whitelist := Whitelist() + request := exporter.BuildRequest{ + MainTable: mainTable, + Datasource: datasource, + Fields: fields, + Filters: filters, + } + + query, args, err := exporter.BuildSQL(request, whitelist) if err != nil { failCat(w, r, http.StatusBadRequest, err.Error(), "sql_build_error") return } - dataDB := a.selectDataDB(ds) - score, sugg, err := exporter.EvaluateExplain(dataDB, q, args) + + // 执行EXPLAIN分析 + dataDB := api.selectDataDB(datasource) + score, suggestions, err := exporter.EvaluateExplain(dataDB, query, args) if err != nil { failCat(w, r, http.StatusBadRequest, err.Error(), "explain_error") return } - idxSugg := exporter.IndexSuggestions(req) - sugg = append(sugg, idxSugg...) - _, _ = a.meta.Exec("UPDATE export_templates SET explain_json=?, explain_score=?, last_validated_at=?, updated_at=? WHERE id=?", toJSON(map[string]interface{}{"sql": q, "suggestions": sugg}), score, time.Now(), time.Now(), id) - ok(w, r, map[string]interface{}{"score": score, "suggestions": sugg}) -} -func (a *TemplatesAPI) selectDataDB(ds string) *sql.DB { - if ds == "ymt" { - return a.meta + // 添加索引建议 + indexSuggestions := exporter.IndexSuggestions(request) + suggestions = append(suggestions, indexSuggestions...) + + // 更新模板的验证结果 + explainResult := map[string]interface{}{ + "sql": query, + "suggestions": suggestions, } - return a.marketing + now := time.Now() + _, _ = api.metaDB.Exec( + "UPDATE export_templates SET explain_json=?, explain_score=?, last_validated_at=?, updated_at=? 
WHERE id=?", + toJSON(explainResult), score, now, now, templateID, + ) + + ok(w, r, map[string]interface{}{ + "score": score, + "suggestions": suggestions, + }) } +// selectDataDB 根据数据源选择对应的数据库连接 +func (api *TemplatesAPI) selectDataDB(datasource string) *sql.DB { + if datasource == "ymt" { + return api.metaDB // YMT数据在meta库 + } + return api.marketingDB +} + +// ==================== 辅助函数 ==================== + +// toJSON 将对象转换为JSON字节 func toJSON(v interface{}) []byte { b, _ := json.Marshal(v) return b } +// fromJSON 将JSON字节解析为对象 func fromJSON(b []byte) interface{} { var v interface{} - json.Unmarshal(b, &v) + _ = json.Unmarshal(b, &v) return v } -func Whitelist() map[string]bool { return schema.AllWhitelist() } +// Whitelist 获取字段白名单 +func Whitelist() map[string]bool { + return schema.AllWhitelist() +} +// FieldLabels 获取字段标签映射 func FieldLabels() map[string]string { return schema.AllLabels() } diff --git a/server/internal/constants/constants.go b/server/internal/constants/constants.go new file mode 100644 index 0000000..39d14f9 --- /dev/null +++ b/server/internal/constants/constants.go @@ -0,0 +1,156 @@ +// Package constants 定义系统常量、阈值和配置 +package constants + +import "time" + +// ==================== 导出任务状态 ==================== + +// JobStatus 导出任务状态枚举 +type JobStatus string + +const ( + JobStatusQueued JobStatus = "queued" // 排队中 + JobStatusRunning JobStatus = "running" // 执行中 + JobStatusCompleted JobStatus = "completed" // 已完成 + JobStatusFailed JobStatus = "failed" // 失败 + JobStatusCanceled JobStatus = "canceled" // 已取消 +) + +// ==================== 数据源类型 ==================== + +// Datasource 数据源类型 +type Datasource string + +const ( + DatasourceMarketing Datasource = "marketing" // 营销系统 + DatasourceYMT Datasource = "ymt" // 易码通 +) + +// ==================== 主表类型 ==================== + +// MainTable 主表类型 +type MainTable string + +const ( + MainTableOrder MainTable = "order" // 订单表 + MainTableOrderInfo MainTable = "order_info" // 订单信息表(YMT) +) + +// ==================== 导出格式 ==================== + +// FileFormat 导出文件格式 +type FileFormat string + +const ( + FileFormatCSV FileFormat = "csv" + FileFormatXLSX FileFormat = "xlsx" +) + +// ==================== 导出阈值配置 ==================== + +// ExportThresholds 导出相关阈值 +var ExportThresholds = struct { + // MaxRowsPerFile 单文件最大行数 + MaxRowsPerFile int64 + // PassScoreThreshold EXPLAIN评分通过阈值 + PassScoreThreshold int + // ChunkDays 分块导出的天数步长 + ChunkDays int + // ChunkThreshold 启用分块导出的行数阈值 + ChunkThreshold int64 + // ProgressUpdateInterval 进度更新间隔(行数) + ProgressUpdateInterval int64 +}{ + MaxRowsPerFile: 300000, + PassScoreThreshold: 60, + ChunkDays: 10, + ChunkThreshold: 50000, + ProgressUpdateInterval: 1000, +} + +// BatchSizes 批量处理大小配置 +var BatchSizes = struct { + // CSVDefault CSV默认批次大小 + CSVDefault int + // XLSXDefault XLSX默认批次大小 + XLSXDefault int + // SmallDataset 小数据集批次 + SmallDataset int + // MediumDataset 中等数据集批次 + MediumDataset int + // LargeDataset 大数据集批次 + LargeDataset int + // HugeDataset 超大数据集批次 + HugeDataset int +}{ + CSVDefault: 10000, + XLSXDefault: 5000, + SmallDataset: 10000, + MediumDataset: 20000, + LargeDataset: 50000, + HugeDataset: 100000, +} + +// ChooseBatchSize 根据估算行数和格式选择合适的批次大小 +func ChooseBatchSize(estimate int64, format FileFormat) int { + if format == FileFormatXLSX { + return BatchSizes.XLSXDefault + } + if estimate <= 0 { + return BatchSizes.CSVDefault + } + if estimate < 50000 { + return BatchSizes.SmallDataset + } + if estimate < 200000 { + return BatchSizes.MediumDataset + } + if estimate < 500000 { + return 
BatchSizes.LargeDataset + } + if estimate >= 2000000 { + return BatchSizes.HugeDataset + } + return BatchSizes.LargeDataset +} + +// ==================== HTTP服务器配置 ==================== + +// ServerConfig HTTP服务器配置 +var ServerConfig = struct { + DefaultPort string + ReadTimeout time.Duration + WriteTimeout time.Duration +}{ + DefaultPort: "8077", + ReadTimeout: 15 * time.Second, + WriteTimeout: 60 * time.Second, +} + +// ==================== 分页配置 ==================== + +// PaginationConfig 分页配置 +var PaginationConfig = struct { + DefaultPageSize int + MaxPageSize int +}{ + DefaultPageSize: 15, + MaxPageSize: 100, +} + +// ==================== 存储配置 ==================== + +// StorageConfig 存储配置 +var StorageConfig = struct { + Directory string + FilePrefix string + ZipSuffix string + CSVSuffix string + XLSXSuffix string +}{ + Directory: "storage", + FilePrefix: "export_job_", + ZipSuffix: ".zip", + CSVSuffix: ".csv", + XLSXSuffix: ".xlsx", +} diff --git a/server/internal/constants/enums.go b/server/internal/constants/enums.go new file mode 100644 index 0000000..bb4dd90 --- /dev/null +++ b/server/internal/constants/enums.go @@ -0,0 +1,317 @@ +// Package constants 枚举值映射配置 +package constants + +// ==================== Marketing系统订单状态 ==================== + +// MarketingOrderStatus 营销系统订单状态映射 +var MarketingOrderStatus = map[int]string{ + 0: "待充值", + 1: "充值中", + 2: "已完成", + 3: "充值失败", + 4: "已取消", + 5: "已过期", + 6: "待支付", +} + +// MarketingOrderType 营销系统订单类型映射 +var MarketingOrderType = map[int]string{ + 1: "直充卡密", + 2: "立减金", + 3: "红包", +} + +// MarketingPayType 营销系统支付方式映射 +var MarketingPayType = map[int]string{ + 1: "支付宝", + 5: "微信", +} + +// MarketingPayStatus 营销系统支付状态映射 +var MarketingPayStatus = map[int]string{ + 1: "待支付", + 2: "已支付", + 3: "已退款", +} + +// ==================== YMT系统订单状态 ==================== + +// YMTOrderStatus 易码通订单状态映射 +var YMTOrderStatus = map[int]string{ + 1: "待充值", + 2: "充值中", + 3: "充值成功", + 4: "充值失败", + 5: "已过期", + 6: "已作废", + 7: "已核销", + 8: "核销失败", + 9: "订单重置", + 10: "卡单", +} + +// YMTOrderType 易码通订单类型映射 +var YMTOrderType = map[int]string{ + 1: "红包订单", + 2: "直充卡密订单", + 3: "立减金订单", +} + +// YMTPayStatus 易码通支付状态映射 +var YMTPayStatus = map[int]string{ + 1: "待支付", + 2: "支付中", + 3: "已支付", + 4: "取消支付", + 5: "退款中", + 6: "退款成功", +} + +// YMTIsRetry 易码通重试状态映射 +var YMTIsRetry = map[int]string{ + 0: "可以失败重试", + 1: "可以失败重试", + 2: "不可以失败重试", +} + +// YMTIsInner 易码通供应商类型映射 +var YMTIsInner = map[int]string{ + 0: "外部供应商", + 1: "内部供应商", +} + +// YMTSettlementType 易码通结算类型映射 +var YMTSettlementType = map[int]string{ + 1: "发放结算", + 2: "打开结算", + 3: "领用结算", + 4: "核销结算", +} + +// ==================== 通用枚举 ==================== + +// ThirdPartyType 第三方类型映射 +var ThirdPartyType = map[int]string{ + 1: "外部供应商", + 2: "内部供应商", +} + +// OrderCashReceiveStatus 红包领取状态映射 +var OrderCashReceiveStatus = map[int]string{ + 0: "待领取", + 1: "领取中", + 2: "领取成功", + 3: "领取失败", +} + +// OrderCashChannel 红包渠道映射 +var OrderCashChannel = map[int]string{ + 1: "支付宝", + 2: "微信", + 3: "云闪付", +} + +// OrderVoucherChannel 立减金渠道映射 +var OrderVoucherChannel = map[int]string{ + 1: "支付宝", + 2: "微信", + 3: "云闪付", +} + +// YMTOrderVoucherStatus 易码通立减金状态映射 +var YMTOrderVoucherStatus = map[int]string{ + 1: "待发放", + 2: "发放中", + 3: "发放失败", + 4: "待核销", + 5: "已核销", + 6: "已过期", + 7: "已退款", +} + +// MarketingOrderVoucherStatus 营销系统立减金状态映射 +var MarketingOrderVoucherStatus = map[int]string{ + 1: "可用", + 2: "已实扣", + 3: "已过期", + 4: "已退款", + 5: "领取失败", + 6: "发放中", + 7: "部分退款", + 8: "已退回", + 9: "发放失败", +} + +// OrderVoucherReceiveMode 立减金领取模式映射 +var 
OrderVoucherReceiveMode = map[int]string{ + 1: "渠道授权用户id", + 2: "手机号或邮箱", +} + +// OrderDigitOrderType 数字订单类型映射 +var OrderDigitOrderType = map[int]string{ + 1: "直充", + 2: "卡密", +} + +// OrderDigitSmsChannel 短信渠道映射 +var OrderDigitSmsChannel = map[int]string{ + 1: "官方", + 2: "专票", +} + +// ==================== 表名中文标签 ==================== + +// TableLabels 表名中文标签 +var TableLabels = map[string]string{ + "order": "订单", + "order_info": "订单", + "order_detail": "订单详情", + "order_cash": "红包", + "order_voucher": "立减金", + "plan": "计划", + "activity": "活动", + "merchant": "客户", + "supplier": "供应商", + "key_batch": "卡密批次", + "code_batch": "验证码批次", + "voucher": "券", + "voucher_batch": "券批次", + "order_digit": "数字订单", + "merchant_key_send": "商户发放", +} + +// GetTableLabel 获取表的中文标签 +func GetTableLabel(table string) string { + if label, ok := TableLabels[table]; ok { + return label + } + return table +} + +// ==================== SQL CASE WHEN 生成器 ==================== + +// BuildCaseWhen 生成 CASE WHEN SQL 片段 +// 参数: tableName - 表名, columnName - 列名, enumMap - 枚举映射, alias - 别名 +func BuildCaseWhen(tableName, columnName string, enumMap map[int]string, alias string) string { + if len(enumMap) == 0 { + return "" + } + sql := "CASE `" + tableName + "`." + columnName + for k, v := range enumMap { + sql += " WHEN " + itoa(k) + " THEN '" + v + "'" + } + sql += " ELSE '' END AS `" + alias + "`" + return sql +} + +// itoa 简单的整数转字符串 +func itoa(i int) string { + if i == 0 { + return "0" + } + neg := false + if i < 0 { + neg = true + i = -i + } + var b [20]byte + p := len(b) - 1 + for i > 0 { + b[p] = byte('0' + i%10) + i /= 10 + p-- + } + if neg { + b[p] = '-' + p-- + } + return string(b[p+1:]) +} + +// ==================== 订单类型标签解析 ==================== + +// ParseOrderTypeLabel 从标签解析订单类型值 +func ParseOrderTypeLabel(datasource, label string) int { + if datasource == "ymt" { + for k, v := range YMTOrderType { + if v == label { + return k + } + } + } else { + for k, v := range MarketingOrderType { + if v == label { + return k + } + } + } + return 0 +} + +// ==================== 已支付状态判断 ==================== + +// IsPaidStatus 判断是否为已支付状态 +// datasource: marketing 或 ymt +// status: 状态码或状态文本 +func IsPaidStatus(datasource, status string) bool { + if status == "" { + return false + } + + // 尝试解析为数字 + numeric := -1 + if isDigits(status) { + numeric = parseDigits(status) + } + + switch datasource { + case "marketing": + if numeric >= 0 { + // 2:已支付 3:已退款 + return numeric == 2 || numeric == 3 + } + return status == "已支付" || status == "已退款" + case "ymt": + if numeric >= 0 { + // 3:已支付 5:退款中 6:退款成功 + return numeric == 3 || numeric == 5 || numeric == 6 + } + return status == "已支付" || status == "退款成功" || status == "退款中" + default: + return contains(status, "支付") && !contains(status, "待") + } +} + +func isDigits(s string) bool { + for _, c := range s { + if c < '0' || c > '9' { + return false + } + } + return len(s) > 0 +} + +func parseDigits(s string) int { + n := 0 + for _, c := range s { + if c >= '0' && c <= '9' { + n = n*10 + int(c-'0') + } + } + return n +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && indexString(s, substr) >= 0 +} + +func indexString(s, substr string) int { + n := len(substr) + for i := 0; i+n <= len(s); i++ { + if s[i:i+n] == substr { + return i + } + } + return -1 +} diff --git a/server/internal/exporter/job_runner.go b/server/internal/exporter/job_runner.go new file mode 100644 index 0000000..fe322dd --- /dev/null +++ b/server/internal/exporter/job_runner.go @@ -0,0 +1,251 @@ +// 
Package exporter 导出任务执行器 +package exporter + +import ( + "database/sql" + "server/internal/constants" + "server/internal/logging" + "strconv" +) + +// ==================== 任务执行配置 ==================== + +// JobConfig 导出任务配置 +type JobConfig struct { + JobID uint64 // 任务ID + Datasource string // 数据源 + MainTable string // 主表 + Format constants.FileFormat // 导出格式 + Query string // SQL查询 + Args []interface{} // SQL参数 + Fields []string // 字段列表 + Headers []string // 表头列表 + Filters map[string]interface{} // 过滤条件 +} + +// JobCallbacks 任务回调函数 +type JobCallbacks struct { + // OnStart 任务开始时调用 + OnStart func(jobID uint64) error + // OnProgress 进度更新时调用 + OnProgress func(jobID uint64, totalRows int64) error + // OnFileCreated 文件创建时调用 + OnFileCreated func(jobID uint64, path string, size, rowCount int64) error + // OnComplete 任务完成时调用 + OnComplete func(jobID uint64, totalRows int64, files []string) error + // OnFailed 任务失败时调用 + OnFailed func(jobID uint64, err error) error + // Transform 行数据转换函数 + Transform func(datasource string, fields, values []string) []string +} + +// ==================== 任务执行器 ==================== + +// JobRunner 导出任务执行器 +type JobRunner struct { + db *sql.DB + config JobConfig + callbacks JobCallbacks +} + +// NewJobRunner 创建任务执行器 +func NewJobRunner(db *sql.DB, config JobConfig, callbacks JobCallbacks) *JobRunner { + return &JobRunner{ + db: db, + config: config, + callbacks: callbacks, + } +} + +// Run 执行导出任务 +func (r *JobRunner) Run() error { + defer r.recoverPanic() + + // 通知任务开始 + if r.callbacks.OnStart != nil { + if err := r.callbacks.OnStart(r.config.JobID); err != nil { + return err + } + } + + // 根据格式执行导出 + return r.runExport() +} + +// recoverPanic 恢复panic并标记任务失败 +func (r *JobRunner) recoverPanic() { + if rec := recover(); rec != nil { + logging.JSON("ERROR", map[string]interface{}{ + "event": "export_panic", + "job_id": r.config.JobID, + "error": anyToString(rec), + }) + if r.callbacks.OnFailed != nil { + r.callbacks.OnFailed(r.config.JobID, nil) + } + } +} + +// runExport 执行导出(统一CSV和XLSX逻辑) +func (r *JobRunner) runExport() error { + cfg := r.config + maxRowsPerFile := constants.ExportThresholds.MaxRowsPerFile + + // 创建写入器工厂 + dir := constants.StorageConfig.Directory + name := constants.StorageConfig.FilePrefix + strconv.FormatUint(cfg.JobID, 10) + + newWriter := func() (RowWriter, error) { + w, err := NewWriter(cfg.Format, dir, name) + if err != nil { + return nil, err + } + // 写入表头 + if err := w.WriteHeader(cfg.Headers); err != nil { + return nil, err + } + return w, nil + } + + // 创建游标 + cursor := NewCursorSQL(cfg.Datasource, cfg.MainTable) + batchSize := constants.ChooseBatchSize(0, cfg.Format) + + // 转换函数 + transform := func(vals []string) []string { + if r.callbacks.Transform != nil { + return r.callbacks.Transform(cfg.Datasource, cfg.Fields, vals) + } + return vals + } + + // 收集生成的文件 + var files []string + var totalRows int64 + + // 进度回调 + onProgress := func(rows int64) error { + if r.callbacks.OnProgress != nil { + return r.callbacks.OnProgress(cfg.JobID, rows) + } + return nil + } + + // 文件滚动回调 + onRoll := func(path string, size, partRows int64) error { + files = append(files, path) + if r.callbacks.OnFileCreated != nil { + return r.callbacks.OnFileCreated(cfg.JobID, path, size, partRows) + } + return nil + } + + // 执行流式导出 + count, resultFiles, err := StreamWithCursor( + r.db, + cfg.Query, + cfg.Args, + cursor, + batchSize, + cfg.Headers, + newWriter, + transform, + maxRowsPerFile, + onRoll, + onProgress, + ) + + if err != nil { + logging.JSON("ERROR", 
map[string]interface{}{ + "event": "export_stream_error", + "job_id": cfg.JobID, + "error": err.Error(), + }) + if r.callbacks.OnFailed != nil { + r.callbacks.OnFailed(cfg.JobID, err) + } + return err + } + + totalRows = count + if len(resultFiles) > 0 { + files = resultFiles + } + + // 通知完成 + if r.callbacks.OnComplete != nil { + return r.callbacks.OnComplete(cfg.JobID, totalRows, files) + } + + return nil +} + +// ==================== 辅助函数 ==================== + +// anyToString 将任意类型转换为字符串 +func anyToString(v interface{}) string { + switch t := v.(type) { + case []byte: + return string(t) + case string: + return t + case int64: + return strconv.FormatInt(t, 10) + case int: + return strconv.Itoa(t) + case uint64: + return strconv.FormatUint(t, 10) + case float64: + if t == float64(int64(t)) { + return strconv.FormatInt(int64(t), 10) + } + return strconv.FormatFloat(t, 'f', -1, 64) + case nil: + return "" + case error: + return t.Error() + default: + return "" + } +} + +// ==================== 分块导出支持 ==================== + +// ChunkRange 时间分块范围 +type ChunkRange struct { + Start string + End string +} + +// SplitTimeRange 按天数分割时间范围 +func SplitTimeRange(startStr, endStr string, stepDays int) []ChunkRange { + chunks := SplitByDays(startStr, endStr, stepDays) + result := make([]ChunkRange, len(chunks)) + for i, c := range chunks { + result[i] = ChunkRange{Start: c[0], End: c[1]} + } + return result +} + +// ShouldUseChunkedExport 判断是否应该使用分块导出 +func ShouldUseChunkedExport(datasource, mainTable string, estimate int64, fields []string, filters map[string]interface{}) bool { + // 行数太少不需要分块 + if estimate <= constants.ExportThresholds.ChunkThreshold { + return false + } + + // Marketing系统的某些场景不适合分块 + if datasource == "marketing" && mainTable == "order" { + // 检查是否有order_voucher相关字段或过滤 + for _, f := range fields { + if len(f) > 14 && f[:14] == "order_voucher." 
{ + return false + } + } + if _, ok := filters["order_voucher_channel_activity_id_eq"]; ok { + return false + } + } + + return true +} diff --git a/server/internal/exporter/writer.go b/server/internal/exporter/writer.go index a37fb1d..c4f62f9 100644 --- a/server/internal/exporter/writer.go +++ b/server/internal/exporter/writer.go @@ -1,19 +1,53 @@ +// Package exporter 提供数据导出功能 package exporter import ( "bufio" "encoding/csv" + "errors" "os" "path/filepath" + "server/internal/constants" "time" "github.com/xuri/excelize/v2" ) +// ==================== 接口定义 ==================== + +// RowWriter 行写入器接口 +// 所有导出格式必须实现此接口 type RowWriter interface { + // WriteHeader 写入表头 WriteHeader(cols []string) error + // WriteRow 写入数据行 WriteRow(vals []string) error - Close() (string, int64, error) + // Close 关闭并返回文件路径、大小 + Close() (path string, size int64, err error) +} + +// WriterFactory 写入器工厂函数类型 +type WriterFactory func() (RowWriter, error) + +// ==================== 工厂方法 ==================== + +// NewWriter 根据格式创建对应的写入器 +func NewWriter(format constants.FileFormat, dir, name string) (RowWriter, error) { + switch format { + case constants.FileFormatCSV: + return NewCSVWriter(dir, name) + case constants.FileFormatXLSX: + return NewXLSXWriter(dir, name, "Sheet1") + default: + return nil, errors.New("unsupported format: " + string(format)) + } +} + +// NewWriterFactory 创建写入器工厂函数 +func NewWriterFactory(format constants.FileFormat, dir, name string) WriterFactory { + return func() (RowWriter, error) { + return NewWriter(format, dir, name) + } } type CSVWriter struct { diff --git a/server/internal/logging/logging.go b/server/internal/logging/logging.go index 6f16f52..677b902 100644 --- a/server/internal/logging/logging.go +++ b/server/internal/logging/logging.go @@ -1,39 +1,142 @@ +// Package logging 提供统一的日志功能 package logging import ( - "encoding/json" - "fmt" - "io" - "log" - "os" - "path/filepath" - "time" + "encoding/json" + "fmt" + "io" + "log" + "os" + "path/filepath" + "runtime" + "time" ) +// ==================== 日志级别 ==================== + +// Level 日志级别 +type Level string + +const ( + LevelDebug Level = "DEBUG" + LevelInfo Level = "INFO" + LevelWarn Level = "WARN" + LevelError Level = "ERROR" +) + +// ==================== 初始化 ==================== + +// Init 初始化日志系统 func Init(dir string) error { - if dir == "" { - dir = "log" - } - if err := os.MkdirAll(dir, 0755); err != nil { - return err - } - name := fmt.Sprintf("server-%s.log", time.Now().Format("20060102")) - p := filepath.Join(dir, name) - f, err := os.OpenFile(p, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) - if err != nil { - return err - } - mw := io.MultiWriter(os.Stdout, f) - log.SetOutput(mw) - log.SetFlags(0) - return nil + if dir == "" { + dir = "log" + } + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + name := fmt.Sprintf("server-%s.log", time.Now().Format("20060102")) + p := filepath.Join(dir, name) + f, err := os.OpenFile(p, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + return err + } + mw := io.MultiWriter(os.Stdout, f) + log.SetOutput(mw) + log.SetFlags(0) + return nil } +// ==================== 结构化日志 ==================== + +// JSON 输出JSON格式日志 func JSON(level string, fields map[string]interface{}) { - m := map[string]interface{}{"level": level, "ts": time.Now().Format(time.RFC3339)} - for k, v := range fields { - m[k] = v - } - b, _ := json.Marshal(m) - log.Println(string(b)) + m := map[string]interface{}{"level": level, "ts": time.Now().Format(time.RFC3339)} + for k, v := range fields { + m[k] = v + } 
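// Editor's note (illustrative only, not part of this change): caller-supplied fields
// are copied into m after the "level" and "ts" defaults above, so a field literally
// named "level" or "ts" would overwrite them. A hypothetical call such as
//
//	JSON("INFO", map[string]interface{}{"event": "progress_update", "job_id": 7, "total_rows": 1500})
//
// is expected to emit one log line shaped like
//
//	{"event":"progress_update","job_id":7,"level":"INFO","total_rows":1500,"ts":"2024-05-01T12:00:00+08:00"}
//
// (keys appear in sorted order because encoding/json sorts map keys; the timestamp is
// whatever time.Now() returns, formatted as RFC3339).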
+ b, _ := json.Marshal(m) + log.Println(string(b)) +} + +// ==================== 便捷方法 ==================== + +// Debug 输出Debug级别日志 +func Debug(event string, fields map[string]interface{}) { + if fields == nil { + fields = make(map[string]interface{}) + } + fields["event"] = event + JSON(string(LevelDebug), fields) +} + +// Info 输出Info级别日志 +func Info(event string, fields map[string]interface{}) { + if fields == nil { + fields = make(map[string]interface{}) + } + fields["event"] = event + JSON(string(LevelInfo), fields) +} + +// Warn 输出Warn级别日志 +func Warn(event string, fields map[string]interface{}) { + if fields == nil { + fields = make(map[string]interface{}) + } + fields["event"] = event + JSON(string(LevelWarn), fields) +} + +// Error 输出Error级别日志 +func Error(event string, err error, fields map[string]interface{}) { + if fields == nil { + fields = make(map[string]interface{}) + } + fields["event"] = event + if err != nil { + fields["error"] = err.Error() + } + // 添加调用位置信息 + _, file, line, ok := runtime.Caller(1) + if ok { + fields["caller"] = fmt.Sprintf("%s:%d", filepath.Base(file), line) + } + JSON(string(LevelError), fields) +} + +// ==================== 专用日志 ==================== + +// DBError 数据库错误日志 +func DBError(action string, jobID uint64, err error) { + Error("db_error", err, map[string]interface{}{ + "action": action, + "job_id": jobID, + }) +} + +// ExportProgress 导出进度日志 +func ExportProgress(jobID uint64, totalRows int64) { + Info("progress_update", map[string]interface{}{ + "job_id": jobID, + "total_rows": totalRows, + }) +} + +// ExportSQL 导出SQL日志 +func ExportSQL(jobID uint64, datasource, mainTable, sql string, args []interface{}) { + Info("export_sql", map[string]interface{}{ + "job_id": jobID, + "datasource": datasource, + "main_table": mainTable, + "sql": sql, + "args": args, + }) +} + +// FieldsRemoved 字段移除日志 +func FieldsRemoved(event string, removed []string, reason string) { + Info(event, map[string]interface{}{ + "removed": removed, + "reason": reason, + }) } diff --git a/server/internal/repo/export_repo.go b/server/internal/repo/export_repo.go index 9939c42..4a83642 100644 --- a/server/internal/repo/export_repo.go +++ b/server/internal/repo/export_repo.go @@ -1,3 +1,4 @@ +// Package repo 提供数据访问层 package repo import ( @@ -8,140 +9,308 @@ import ( "time" ) +// ==================== 导出仓库 ==================== + +// ExportQueryRepo 导出查询仓库 type ExportQueryRepo struct{} -func NewExportRepo() *ExportQueryRepo { return &ExportQueryRepo{} } - -func (r *ExportQueryRepo) Build(req exporter.BuildRequest, wl map[string]bool) (string, []interface{}, error) { - return exporter.BuildSQL(req, wl) +// NewExportRepo 创建导出仓库实例 +func NewExportRepo() *ExportQueryRepo { + return &ExportQueryRepo{} } -func (r *ExportQueryRepo) Explain(db *sql.DB, q string, args []interface{}) (int, []string, error) { - return exporter.EvaluateExplain(db, q, args) +// ==================== SQL构建 ==================== + +// Build 构建SQL查询 +func (r *ExportQueryRepo) Build(req exporter.BuildRequest, whitelist map[string]bool) (string, []interface{}, error) { + return exporter.BuildSQL(req, whitelist) } -func (r *ExportQueryRepo) Count(db *sql.DB, base string, args []interface{}) int64 { - return exporter.CountRows(db, base, args) +// Explain 执行EXPLAIN分析 +func (r *ExportQueryRepo) Explain(db *sql.DB, query string, args []interface{}) (int, []string, error) { + return exporter.EvaluateExplain(db, query, args) } -// Count by BuildRequest using filters-only joins and COUNT(DISTINCT main pk) -func (r *ExportQueryRepo) 
CountByReq(db *sql.DB, req exporter.BuildRequest, wl map[string]bool) int64 { - q, args, err := exporter.BuildCountSQL(req, wl) +// ==================== 行数统计 ==================== + +// Count 使用基础查询统计行数 +func (r *ExportQueryRepo) Count(db *sql.DB, baseQuery string, args []interface{}) int64 { + return exporter.CountRows(db, baseQuery, args) +} + +// CountByReq 使用BuildRequest统计行数(COUNT(DISTINCT)) +func (r *ExportQueryRepo) CountByReq(db *sql.DB, req exporter.BuildRequest, whitelist map[string]bool) int64 { + query, args, err := exporter.BuildCountSQL(req, whitelist) if err != nil { - logging.JSON("ERROR", map[string]interface{}{"event": "build_count_sql_error", "error": err.Error()}) + logging.Error("build_count_sql_error", err, nil) return 0 } - var c int64 - row := db.QueryRow(q, args...) - if err := row.Scan(&c); err != nil { - logging.JSON("ERROR", map[string]interface{}{"event": "count_by_req_error", "error": err.Error(), "sql": q, "args": args}) + + var count int64 + if err := db.QueryRow(query, args...).Scan(&count); err != nil { + logging.Error("count_by_req_error", err, map[string]interface{}{ + "sql": query, + "args": args, + }) return 0 } - return c + return count } -func (r *ExportQueryRepo) EstimateFast(db *sql.DB, ds, main string, filters map[string]interface{}) int64 { - return exporter.CountRowsFast(db, ds, main, filters) +// EstimateFast 快速估算行数 +func (r *ExportQueryRepo) EstimateFast(db *sql.DB, datasource, mainTable string, filters map[string]interface{}) int64 { + return exporter.CountRowsFast(db, datasource, mainTable, filters) } -func (r *ExportQueryRepo) EstimateFastChunked(db *sql.DB, ds, main string, filters map[string]interface{}) int64 { - return exporter.CountRowsFastChunked(db, ds, main, filters) +// EstimateFastChunked 快速估算行数(分块) +func (r *ExportQueryRepo) EstimateFastChunked(db *sql.DB, datasource, mainTable string, filters map[string]interface{}) int64 { + return exporter.CountRowsFastChunked(db, datasource, mainTable, filters) } -func (r *ExportQueryRepo) NewCursor(datasource, main string) *exporter.CursorSQL { - return exporter.NewCursorSQL(datasource, main) +// ==================== 游标操作 ==================== + +// NewCursor 创建游标 +func (r *ExportQueryRepo) NewCursor(datasource, mainTable string) *exporter.CursorSQL { + return exporter.NewCursorSQL(datasource, mainTable) } +// ==================== 回调函数类型 ==================== + +// RowWriterFactory 写入器工厂函数 type RowWriterFactory func() (exporter.RowWriter, error) + +// RowTransform 行转换函数 type RowTransform func([]string) []string + +// RollCallback 文件滚动回调 type RollCallback func(path string, size int64, partRows int64) error + +// ProgressCallback 进度回调 type ProgressCallback func(totalRows int64) error -func (r *ExportQueryRepo) StreamCursor(db *sql.DB, base string, args []interface{}, cur *exporter.CursorSQL, batch int, cols []string, newWriter RowWriterFactory, transform RowTransform, maxRowsPerFile int64, onRoll RollCallback, onProgress ProgressCallback) (int64, []string, error) { - return exporter.StreamWithCursor(db, base, args, cur, batch, cols, func() (exporter.RowWriter, error) { return newWriter() }, func(vals []string) []string { return transform(vals) }, maxRowsPerFile, func(p string, sz int64, rows int64) error { return onRoll(p, sz, rows) }, func(total int64) error { return onProgress(total) }) +// ==================== 流式导出 ==================== + +// StreamCursor 流式导出数据 +func (r *ExportQueryRepo) StreamCursor( + db *sql.DB, + baseQuery string, + args []interface{}, + cursor *exporter.CursorSQL, + 
batchSize int, + columns []string, + newWriter RowWriterFactory, + transform RowTransform, + maxRowsPerFile int64, + onRoll RollCallback, + onProgress ProgressCallback, +) (int64, []string, error) { + return exporter.StreamWithCursor( + db, + baseQuery, + args, + cursor, + batchSize, + columns, + func() (exporter.RowWriter, error) { return newWriter() }, + func(vals []string) []string { return transform(vals) }, + maxRowsPerFile, + func(path string, size int64, rows int64) error { return onRoll(path, size, rows) }, + func(total int64) error { return onProgress(total) }, + ) } -func (r *ExportQueryRepo) ZipAndRecord(meta *sql.DB, jobID uint64, files []string, total int64) { +// ZipAndRecord 压缩文件并记录 +func (r *ExportQueryRepo) ZipAndRecord(metaDB *sql.DB, jobID uint64, files []string, totalRows int64) { if len(files) == 0 { return } + zipPath, zipSize := exporter.ZipFiles(jobID, files) - meta.Exec("INSERT INTO export_job_files (job_id, storage_uri, row_count, size_bytes, created_at, updated_at) VALUES (?,?,?,?,?,?)", jobID, zipPath, total, zipSize, time.Now(), time.Now()) + now := time.Now() + + _, err := metaDB.Exec( + "INSERT INTO export_job_files (job_id, storage_uri, row_count, size_bytes, created_at, updated_at) VALUES (?,?,?,?,?,?)", + jobID, zipPath, totalRows, zipSize, now, now, + ) + if err != nil { + logging.DBError("zip_and_record", jobID, err) + } } -// Metadata and job helpers -func (r *ExportQueryRepo) GetTemplateMeta(meta *sql.DB, tplID uint64) (string, string, []string, error) { - var ds string - var main string +// ==================== 模板元数据 ==================== + +// GetTemplateMeta 获取模板元数据 +func (r *ExportQueryRepo) GetTemplateMeta(metaDB *sql.DB, templateID uint64) (datasource, mainTable string, fields []string, err error) { var fieldsJSON []byte - row := meta.QueryRow("SELECT datasource, main_table, fields_json FROM export_templates WHERE id= ?", tplID) - if err := row.Scan(&ds, &main, &fieldsJSON); err != nil { + + row := metaDB.QueryRow( + "SELECT datasource, main_table, fields_json FROM export_templates WHERE id=?", + templateID, + ) + + if err = row.Scan(&datasource, &mainTable, &fieldsJSON); err != nil { return "", "", nil, err } - var fs []string - _ = json.Unmarshal(fieldsJSON, &fs) - return ds, main, fs, nil + + _ = json.Unmarshal(fieldsJSON, &fields) + return datasource, mainTable, fields, nil } -func (r *ExportQueryRepo) GetJobFilters(meta *sql.DB, jobID uint64) (uint64, []byte, error) { - var tplID uint64 - var filtersJSON []byte - row := meta.QueryRow("SELECT template_id, filters_json FROM export_jobs WHERE id= ?", jobID) - if err := row.Scan(&tplID, &filtersJSON); err != nil { +// GetJobFilters 获取任务的过滤条件 +func (r *ExportQueryRepo) GetJobFilters(metaDB *sql.DB, jobID uint64) (templateID uint64, filtersJSON []byte, err error) { + row := metaDB.QueryRow( + "SELECT template_id, filters_json FROM export_jobs WHERE id=?", + jobID, + ) + + if err = row.Scan(&templateID, &filtersJSON); err != nil { return 0, nil, err } - return tplID, filtersJSON, nil + return templateID, filtersJSON, nil } -func (r *ExportQueryRepo) InsertJob(meta *sql.DB, tplID, requestedBy, owner uint64, permission, filters, options map[string]interface{}, explain map[string]interface{}, explainScore int, rowEstimate int64, fileFormat string) (uint64, error) { - ejSQL := "INSERT INTO export_jobs (template_id, status, requested_by, owner_id, permission_scope_json, filters_json, options_json, explain_json, explain_score, row_estimate, file_format, created_at, updated_at) VALUES 
(?,?,?,?,?,?,?,?,?,?,?,?,?)" - ejArgs := []interface{}{tplID, "queued", requestedBy, owner, toJSON(permission), toJSON(filters), toJSON(options), toJSON(explain), explainScore, rowEstimate, fileFormat, time.Now(), time.Now()} - res, err := meta.Exec(ejSQL, ejArgs...) +// ==================== 任务管理 ==================== + +// InsertJob 插入新的导出任务 +func (r *ExportQueryRepo) InsertJob( + metaDB *sql.DB, + templateID, requestedBy, ownerID uint64, + permission, filters, options map[string]interface{}, + explainResult map[string]interface{}, + explainScore int, + rowEstimate int64, + fileFormat string, +) (uint64, error) { + now := time.Now() + + insertSQL := `INSERT INTO export_jobs + (template_id, status, requested_by, owner_id, permission_scope_json, + filters_json, options_json, explain_json, explain_score, row_estimate, + file_format, created_at, updated_at) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?)` + + args := []interface{}{ + templateID, + "queued", + requestedBy, + ownerID, + toJSON(permission), + toJSON(filters), + toJSON(options), + toJSON(explainResult), + explainScore, + rowEstimate, + fileFormat, + now, + now, + } + + result, err := metaDB.Exec(insertSQL, args...) if err != nil { return 0, err } - id, _ := res.LastInsertId() + + id, _ := result.LastInsertId() return uint64(id), nil } -func (r *ExportQueryRepo) StartJob(meta *sql.DB, id uint64) { - if _, err := meta.Exec("UPDATE export_jobs SET status=?, started_at=?, updated_at=? WHERE id= ?", "running", time.Now(), time.Now(), id); err != nil { - logging.JSON("ERROR", map[string]interface{}{"event": "db_update_error", "action": "start_job", "job_id": id, "error": err.Error()}) - } -} -func (r *ExportQueryRepo) UpdateProgress(meta *sql.DB, id uint64, total int64) { - if _, err := meta.Exec("UPDATE export_jobs SET total_rows=GREATEST(COALESCE(total_rows,0), ?), updated_at=?, status=CASE WHEN status='queued' THEN 'running' ELSE status END WHERE id= ?", total, time.Now(), id); err != nil { - logging.JSON("ERROR", map[string]interface{}{"event": "db_update_error", "action": "update_progress", "job_id": id, "error": err.Error()}) - } - logging.JSON("INFO", map[string]interface{}{"event": "progress_update", "job_id": id, "total_rows": total}) -} -func (r *ExportQueryRepo) MarkFailed(meta *sql.DB, id uint64) { - if _, err := meta.Exec("UPDATE export_jobs SET status=?, finished_at=? WHERE id= ?", "failed", time.Now(), id); err != nil { - logging.JSON("ERROR", map[string]interface{}{"event": "db_update_error", "action": "mark_failed", "job_id": id, "error": err.Error()}) - } -} -func (r *ExportQueryRepo) MarkCompleted(meta *sql.DB, id uint64, total int64) { - if _, err := meta.Exec("UPDATE export_jobs SET status=?, finished_at=?, total_rows=?, row_estimate=GREATEST(COALESCE(row_estimate,0), ?), updated_at=? 
WHERE id= ?", "completed", time.Now(), total, total, time.Now(), id); err != nil { - logging.JSON("ERROR", map[string]interface{}{"event": "db_update_error", "action": "mark_completed", "job_id": id, "error": err.Error()}) - } -} -func (r *ExportQueryRepo) InsertJobFile(meta *sql.DB, id uint64, uri string, sheetName string, rowCount, size int64) { - if _, err := meta.Exec("INSERT INTO export_job_files (job_id, storage_uri, sheet_name, row_count, size_bytes, created_at, updated_at) VALUES (?,?,?,?,?,?,?)", id, uri, sheetName, rowCount, size, time.Now(), time.Now()); err != nil { - logging.JSON("ERROR", map[string]interface{}{"event": "db_insert_error", "action": "insert_job_file", "job_id": id, "error": err.Error(), "path": uri}) +// StartJob 标记任务开始执行 +func (r *ExportQueryRepo) StartJob(metaDB *sql.DB, jobID uint64) { + now := time.Now() + _, err := metaDB.Exec( + "UPDATE export_jobs SET status=?, started_at=?, updated_at=? WHERE id=?", + "running", now, now, jobID, + ) + if err != nil { + logging.DBError("start_job", jobID, err) } } -func (r *ExportQueryRepo) UpdateRowEstimate(meta *sql.DB, id uint64, est int64) { - if _, err := meta.Exec("UPDATE export_jobs SET row_estimate=?, updated_at=? WHERE id= ?", est, time.Now(), id); err != nil { - logging.JSON("ERROR", map[string]interface{}{"event": "db_update_error", "action": "update_row_estimate", "job_id": id, "error": err.Error(), "row_estimate": est}) +// UpdateProgress 更新任务进度 +func (r *ExportQueryRepo) UpdateProgress(metaDB *sql.DB, jobID uint64, totalRows int64) { + now := time.Now() + _, err := metaDB.Exec( + `UPDATE export_jobs + SET total_rows=GREATEST(COALESCE(total_rows,0), ?), + updated_at=?, + status=CASE WHEN status='queued' THEN 'running' ELSE status END + WHERE id=?`, + totalRows, now, jobID, + ) + if err != nil { + logging.DBError("update_progress", jobID, err) + } + logging.ExportProgress(jobID, totalRows) +} + +// MarkFailed 标记任务失败 +func (r *ExportQueryRepo) MarkFailed(metaDB *sql.DB, jobID uint64) { + now := time.Now() + _, err := metaDB.Exec( + "UPDATE export_jobs SET status=?, finished_at=? WHERE id=?", + "failed", now, jobID, + ) + if err != nil { + logging.DBError("mark_failed", jobID, err) } } +// MarkCompleted 标记任务完成 +func (r *ExportQueryRepo) MarkCompleted(metaDB *sql.DB, jobID uint64, totalRows int64) { + now := time.Now() + _, err := metaDB.Exec( + `UPDATE export_jobs + SET status=?, finished_at=?, total_rows=?, + row_estimate=GREATEST(COALESCE(row_estimate,0), ?), updated_at=? + WHERE id=?`, + "completed", now, totalRows, totalRows, now, jobID, + ) + if err != nil { + logging.DBError("mark_completed", jobID, err) + } +} + +// InsertJobFile 插入任务文件记录 +func (r *ExportQueryRepo) InsertJobFile(metaDB *sql.DB, jobID uint64, uri, sheetName string, rowCount, sizeBytes int64) { + now := time.Now() + _, err := metaDB.Exec( + `INSERT INTO export_job_files + (job_id, storage_uri, sheet_name, row_count, size_bytes, created_at, updated_at) + VALUES (?,?,?,?,?,?,?)`, + jobID, uri, sheetName, rowCount, sizeBytes, now, now, + ) + if err != nil { + logging.Error("insert_job_file", err, map[string]interface{}{ + "job_id": jobID, + "path": uri, + }) + } +} + +// UpdateRowEstimate 更新行数估算 +func (r *ExportQueryRepo) UpdateRowEstimate(metaDB *sql.DB, jobID uint64, estimate int64) { + now := time.Now() + _, err := metaDB.Exec( + "UPDATE export_jobs SET row_estimate=?, updated_at=? 
WHERE id=?", + estimate, now, jobID, + ) + if err != nil { + logging.Error("update_row_estimate", err, map[string]interface{}{ + "job_id": jobID, + "row_estimate": estimate, + }) + } +} + +// toJSON 将对象序列化为JSON func toJSON(v interface{}) []byte { b, _ := json.Marshal(v) return b } +// ==================== 数据类型 ==================== + +// JobDetail 任务详情 type JobDetail struct { ID uint64 TemplateID uint64 @@ -157,6 +326,7 @@ type JobDetail struct { ExplainJSON sql.NullString } +// JobFile 任务文件 type JobFile struct { URI sql.NullString Sheet sql.NullString @@ -164,6 +334,7 @@ type JobFile struct { SizeBytes sql.NullInt64 } +// JobListItem 任务列表项 type JobListItem struct { ID uint64 TemplateID uint64 @@ -178,79 +349,142 @@ type JobListItem struct { ExplainJSON sql.NullString } -func (r *ExportQueryRepo) GetJob(meta *sql.DB, id string) (JobDetail, error) { - row := meta.QueryRow("SELECT id, template_id, status, requested_by, total_rows, file_format, started_at, finished_at, created_at, updated_at, explain_score, explain_json FROM export_jobs WHERE id= ?", id) - var d JobDetail - err := row.Scan(&d.ID, &d.TemplateID, &d.Status, &d.RequestedBy, &d.TotalRows, &d.FileFormat, &d.StartedAt, &d.FinishedAt, &d.CreatedAt, &d.UpdatedAt, &d.ExplainScore, &d.ExplainJSON) - return d, err +// ==================== 任务查询 ==================== + +// GetJob 获取任务详情 +func (r *ExportQueryRepo) GetJob(metaDB *sql.DB, jobID string) (JobDetail, error) { + querySQL := `SELECT id, template_id, status, requested_by, total_rows, + file_format, started_at, finished_at, created_at, updated_at, + explain_score, explain_json + FROM export_jobs WHERE id=?` + + var detail JobDetail + err := metaDB.QueryRow(querySQL, jobID).Scan( + &detail.ID, &detail.TemplateID, &detail.Status, &detail.RequestedBy, + &detail.TotalRows, &detail.FileFormat, &detail.StartedAt, &detail.FinishedAt, + &detail.CreatedAt, &detail.UpdatedAt, &detail.ExplainScore, &detail.ExplainJSON, + ) + return detail, err } -func (r *ExportQueryRepo) ListJobFiles(meta *sql.DB, jobID string) ([]JobFile, error) { - rows, err := meta.Query("SELECT storage_uri, sheet_name, row_count, size_bytes FROM export_job_files WHERE job_id= ?", jobID) +// ListJobFiles 获取任务文件列表 +func (r *ExportQueryRepo) ListJobFiles(metaDB *sql.DB, jobID string) ([]JobFile, error) { + rows, err := metaDB.Query( + "SELECT storage_uri, sheet_name, row_count, size_bytes FROM export_job_files WHERE job_id=?", + jobID, + ) if err != nil { return nil, err } defer rows.Close() - out := []JobFile{} + + var files []JobFile for rows.Next() { - var f JobFile - rows.Scan(&f.URI, &f.Sheet, &f.RowCount, &f.SizeBytes) - out = append(out, f) + var file JobFile + if err := rows.Scan(&file.URI, &file.Sheet, &file.RowCount, &file.SizeBytes); err != nil { + continue + } + files = append(files, file) } - return out, nil + return files, nil } -func (r *ExportQueryRepo) GetLatestFileURI(meta *sql.DB, jobID string) (string, error) { - row := meta.QueryRow("SELECT storage_uri FROM export_job_files WHERE job_id=? ORDER BY id DESC LIMIT 1", jobID) +// GetLatestFileURI 获取最新文件URI +func (r *ExportQueryRepo) GetLatestFileURI(metaDB *sql.DB, jobID string) (string, error) { var uri string - err := row.Scan(&uri) + err := metaDB.QueryRow( + "SELECT storage_uri FROM export_job_files WHERE job_id=? 
ORDER BY id DESC LIMIT 1", + jobID, + ).Scan(&uri) return uri, err } -func (r *ExportQueryRepo) CountJobs(meta *sql.DB, tplID uint64, owner string) int64 { - var c int64 - if tplID > 0 { - if owner != "" { - _ = meta.QueryRow("SELECT COUNT(1) FROM export_jobs WHERE template_id = ? AND owner_id = ?", tplID, owner).Scan(&c) +// CountJobs 统计任务数量 +func (r *ExportQueryRepo) CountJobs(metaDB *sql.DB, templateID uint64, ownerID string) int64 { + var count int64 + var err error + + if templateID > 0 { + if ownerID != "" { + err = metaDB.QueryRow( + "SELECT COUNT(1) FROM export_jobs WHERE template_id=? AND owner_id=?", + templateID, ownerID, + ).Scan(&count) } else { - _ = meta.QueryRow("SELECT COUNT(1) FROM export_jobs WHERE template_id = ?", tplID).Scan(&c) + err = metaDB.QueryRow( + "SELECT COUNT(1) FROM export_jobs WHERE template_id=?", + templateID, + ).Scan(&count) } } else { - if owner != "" { - _ = meta.QueryRow("SELECT COUNT(1) FROM export_jobs WHERE owner_id = ?", owner).Scan(&c) + if ownerID != "" { + err = metaDB.QueryRow( + "SELECT COUNT(1) FROM export_jobs WHERE owner_id=?", + ownerID, + ).Scan(&count) } else { - _ = meta.QueryRow("SELECT COUNT(1) FROM export_jobs").Scan(&c) + err = metaDB.QueryRow("SELECT COUNT(1) FROM export_jobs").Scan(&count) } } - return c + + if err != nil { + logging.Error("count_jobs", err, nil) + } + return count } -func (r *ExportQueryRepo) ListJobs(meta *sql.DB, tplID uint64, owner string, size, offset int) ([]JobListItem, error) { - var rows *sql.Rows - var err error - if tplID > 0 { - if owner != "" { - rows, err = meta.Query("SELECT id, template_id, status, requested_by, row_estimate, total_rows, file_format, created_at, updated_at, explain_score, explain_json FROM export_jobs WHERE template_id = ? AND owner_id = ? ORDER BY id DESC LIMIT ? OFFSET ?", tplID, owner, size, offset) - } else { - rows, err = meta.Query("SELECT id, template_id, status, requested_by, row_estimate, total_rows, file_format, created_at, updated_at, explain_score, explain_json FROM export_jobs WHERE template_id = ? ORDER BY id DESC LIMIT ? OFFSET ?", tplID, size, offset) - } - } else { - if owner != "" { - rows, err = meta.Query("SELECT id, template_id, status, requested_by, row_estimate, total_rows, file_format, created_at, updated_at, explain_score, explain_json FROM export_jobs WHERE owner_id = ? ORDER BY id DESC LIMIT ? OFFSET ?", owner, size, offset) - } else { - rows, err = meta.Query("SELECT id, template_id, status, requested_by, row_estimate, total_rows, file_format, created_at, updated_at, explain_score, explain_json FROM export_jobs ORDER BY id DESC LIMIT ? OFFSET ?", size, offset) - } +// ListJobs 获取任务列表 +func (r *ExportQueryRepo) ListJobs(metaDB *sql.DB, templateID uint64, ownerID string, pageSize, offset int) ([]JobListItem, error) { + querySQL := `SELECT id, template_id, status, requested_by, row_estimate, + total_rows, file_format, created_at, updated_at, explain_score, explain_json + FROM export_jobs` + + var args []interface{} + var conditions []string + + if templateID > 0 { + conditions = append(conditions, "template_id=?") + args = append(args, templateID) } + if ownerID != "" { + conditions = append(conditions, "owner_id=?") + args = append(args, ownerID) + } + + if len(conditions) > 0 { + querySQL += " WHERE " + joinStrings(conditions, " AND ") + } + querySQL += " ORDER BY id DESC LIMIT ? OFFSET ?" + args = append(args, pageSize, offset) + + rows, err := metaDB.Query(querySQL, args...) 

-func (r *ExportQueryRepo) ListJobs(meta *sql.DB, tplID uint64, owner string, size, offset int) ([]JobListItem, error) {
-	var rows *sql.Rows
-	var err error
-	if tplID > 0 {
-		if owner != "" {
-			rows, err = meta.Query("SELECT id, template_id, status, requested_by, row_estimate, total_rows, file_format, created_at, updated_at, explain_score, explain_json FROM export_jobs WHERE template_id = ? AND owner_id = ? ORDER BY id DESC LIMIT ? OFFSET ?", tplID, owner, size, offset)
-		} else {
-			rows, err = meta.Query("SELECT id, template_id, status, requested_by, row_estimate, total_rows, file_format, created_at, updated_at, explain_score, explain_json FROM export_jobs WHERE template_id = ? ORDER BY id DESC LIMIT ? OFFSET ?", tplID, size, offset)
-		}
-	} else {
-		if owner != "" {
-			rows, err = meta.Query("SELECT id, template_id, status, requested_by, row_estimate, total_rows, file_format, created_at, updated_at, explain_score, explain_json FROM export_jobs WHERE owner_id = ? ORDER BY id DESC LIMIT ? OFFSET ?", owner, size, offset)
-		} else {
-			rows, err = meta.Query("SELECT id, template_id, status, requested_by, row_estimate, total_rows, file_format, created_at, updated_at, explain_score, explain_json FROM export_jobs ORDER BY id DESC LIMIT ? OFFSET ?", size, offset)
-		}
-	}
+// ListJobs lists jobs with pagination.
+func (r *ExportQueryRepo) ListJobs(metaDB *sql.DB, templateID uint64, ownerID string, pageSize, offset int) ([]JobListItem, error) {
+	querySQL := `SELECT id, template_id, status, requested_by, row_estimate,
+		total_rows, file_format, created_at, updated_at, explain_score, explain_json
+		FROM export_jobs`
+
+	var args []interface{}
+	var conditions []string
+
+	if templateID > 0 {
+		conditions = append(conditions, "template_id=?")
+		args = append(args, templateID)
+	}
+	if ownerID != "" {
+		conditions = append(conditions, "owner_id=?")
+		args = append(args, ownerID)
+	}
+
+	if len(conditions) > 0 {
+		querySQL += " WHERE " + joinStrings(conditions, " AND ")
+	}
+	querySQL += " ORDER BY id DESC LIMIT ? OFFSET ?"
+	args = append(args, pageSize, offset)
+
+	rows, err := metaDB.Query(querySQL, args...)
 	if err != nil {
 		return nil, err
 	}
 	defer rows.Close()
-	items := []JobListItem{}
+
+	var items []JobListItem
 	for rows.Next() {
-		var it JobListItem
-		if err := rows.Scan(&it.ID, &it.TemplateID, &it.Status, &it.RequestedBy, &it.RowEstimate, &it.TotalRows, &it.FileFormat, &it.CreatedAt, &it.UpdatedAt, &it.ExplainScore, &it.ExplainJSON); err == nil {
-			items = append(items, it)
+		var item JobListItem
+		if err := rows.Scan(
+			&item.ID, &item.TemplateID, &item.Status, &item.RequestedBy,
+			&item.RowEstimate, &item.TotalRows, &item.FileFormat,
+			&item.CreatedAt, &item.UpdatedAt, &item.ExplainScore, &item.ExplainJSON,
+		); err == nil {
+			items = append(items, item)
 		}
 	}
 	return items, nil
 }
+
+// joinStrings joins a string slice with a separator.
+func joinStrings(strs []string, sep string) string {
+	if len(strs) == 0 {
+		return ""
+	}
+	result := strs[0]
+	for i := 1; i < len(strs); i++ {
+		result += sep + strs[i]
+	}
+	return result
+}
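+
+// NOTE: joinStrings behaves like strings.Join from the standard library; if
+// this file already imports "strings", the call in ListJobs could use
+// strings.Join(conditions, " AND ") directly. It is kept as a local helper
+// here on the assumption that the import is not present.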