diff --git a/csv.go b/csv.go index 4ac6631..4f7a2a4 100644 --- a/csv.go +++ b/csv.go @@ -10,11 +10,22 @@ import ( "strings" ) -func getCsvInfo(dir string) (dataMap []map[string]string, title []string) { +func getCsvInfo(dir string) (dataMap []map[string]string, title []string, files []string) { fp := fmt.Sprintf("%s/%s/", dir, "csvfile") + if _, err := os.Stat(fp); os.IsNotExist(err) { + // 文件夹不存在,创建它 + err := os.Mkdir(fp, 0755) // 0755 是常见的目录权限 + if err != nil { + fatal("无法创建目录 %s: %v", fp, err) + } + } entries, err := os.ReadDir(fp) if err != nil { - panic("当前目录下未找到csvfile文件夹,请自行创建") + fatal("获取文件目录失败 %s: %v", fp, err) + } + if len(entries) == 0 { + warning("请将需导入文件放入%s目录下", fp) + exit() } var csvFiles []string for _, entry := range entries { @@ -25,7 +36,7 @@ func getCsvInfo(dir string) (dataMap []map[string]string, title []string) { for { prompt := promptui.Select{ Label: "请选择需要导入的csv文件: ", - Items: csvFiles, + Items: append([]string{"全选", "导入非`已导入_`开头的文件"}, csvFiles...), Size: len(csvFiles), } _, f, err := prompt.Run() @@ -33,10 +44,35 @@ func getCsvInfo(dir string) (dataMap []map[string]string, title []string) { warning("Prompt failed %v\n", err) continue } + fp = fmt.Sprintf("%s/%s", fp, f) + switch fp { + case "全选": + files = csvFiles + case "导入非`已导入_`开头的文件": + for _, s := range csvFiles { + if !strings.HasPrefix(s, "已导入_") { + files = append(files, s) + } + } + default: + files = append(files, fp) + } break } - return ReadCsvData(fp) + if len(files) == 0 { + warning("未找到符合条件的文件") + finish(dir) + } + + for _, v := range files { + data, fileTitle := ReadCsvData(v) + if title == nil { + title = fileTitle + } + dataMap = append(dataMap, data...) + } + return dataMap, title, files } func ReadCsvData(fp string) (dataMap []map[string]string, title []string) { diff --git a/func.go b/func.go index 90fba1d..6ec9ed2 100644 --- a/func.go +++ b/func.go @@ -90,7 +90,7 @@ func saveData(dataChan chan []map[string]string, set *Set, title []string, dir s for _, item := range v { var dataRaw strings.Builder is_start := true - dataRaw.WriteString(fmt.Sprintf("SELECT %s From %s WHERE ", strings.Join(title, "`,`"), c.Table)) + dataRaw.WriteString(fmt.Sprintf("SELECT %s From %s WHERE ", strings.Join(set.OverWriteJudColumns, "`,`"), c.Table)) for _, t := range set.OverWriteJudColumns { if strings.Contains(item[t], "'") { item[t] = strings.ReplaceAll(item[t], "'", "`") @@ -99,10 +99,17 @@ func saveData(dataChan chan []map[string]string, set *Set, title []string, dir s item[t] = strings.ReplaceAll(item[t], `"`, "`") } if !is_start { - dataRaw.WriteString(fmt.Sprintf("'%s',", item[t])) + dataRaw.WriteString(fmt.Sprintf("'%s'=%s", t, item[t])) } - dataRaw.WriteString(fmt.Sprintf("'%s',", item[t])) + dataRaw.WriteString(fmt.Sprintf("AND '%s'=%s", item[t])) } + raw := dataRaw.String() + result := db.Exec(raw) + if result.Error != nil { + warning(fmt.Sprintf(" 查询失败: %v", result.Error)) + return + } + } default: query.WriteString(";") diff --git a/main.go b/main.go index 5a84566..fbe34bb 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "gorm.io/gorm" "os" "strings" + "time" ) var ( @@ -54,10 +55,11 @@ func main() { } func do(dir string) { - csvData, title := getCsvInfo(dir) + csvData, title, files := getCsvInfo(dir) set := setOp(title) - cutChannel := cutData(csvData, 100) + cutChannel := cutData(csvData, 1) saveData(cutChannel, set, title, dir) + overFiles(files) finish(dir) select {} } @@ -141,9 +143,29 @@ func cConf(fp string) *Conf { } func exit() { - os.Exit(1) + waitTime := 3 + warning("程序即将再%d秒后退出", waitTime) + for { + if waitTime == 0 { + os.Exit(1) + } + time.Sleep(1 * time.Second) + warning("%d", waitTime) + waitTime-- + } + } func finish(dir string) { do(dir) } + +func overFiles(files []string) { + for _, v := range files { + err := os.Rename(v, fmt.Sprintf("已导入_%s", v)) + if err != nil { + warning("重命名文件时出错: %v\n", err) + exit() + } + } +}