373 lines
7.6 KiB
Go
373 lines
7.6 KiB
Go
package pkg
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"math/rand/v2"
|
||
"net"
|
||
"net/http"
|
||
"os"
|
||
"path/filepath"
|
||
"reflect"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/go-viper/mapstructure/v2"
|
||
"github.com/google/uuid"
|
||
)
|
||
|
||
// GetModuleDir walks upward from the current working directory and returns
// the first directory that contains an entry named "go.mod".
//
// It returns an error if the working directory cannot be determined, or if
// the filesystem root is reached without finding go.mod.
func GetModuleDir() (string, error) {
	cur, err := os.Getwd()
	if err != nil {
		return "", err
	}

	for {
		if _, statErr := os.Stat(filepath.Join(cur, "go.mod")); statErr == nil {
			return cur, nil
		}
		up := filepath.Dir(cur)
		if up == cur {
			// filepath.Dir only returns its argument unchanged at the root.
			return "", fmt.Errorf("go.mod not found in current directory or parents")
		}
		cur = up
	}
}
|
||
|
||
// GetCacheDir returns the path of the "cache" directory directly under the
// module root found by GetModuleDir, creating it (and any missing parents)
// with permission 0755 if it does not exist yet.
//
// Returns an error if the module root cannot be located or the directory
// cannot be created.
func GetCacheDir() (string, error) {
	modDir, err := GetModuleDir()
	if err != nil {
		return "", err
	}
	// filepath.Join yields an OS-correct, cleaned path.
	dir := filepath.Join(modDir, "cache")
	if err := os.MkdirAll(dir, 0755); err != nil {
		return "", fmt.Errorf("创建目录失败: %w", err)
	}
	return dir, nil
}
|
||
|
||
// GetTmplDir returns the path of the "tmpl" directory directly under the
// module root found by GetModuleDir, creating it (and any missing parents)
// with permission 0755 if it does not exist yet.
//
// Returns an error if the module root cannot be located or the directory
// cannot be created.
func GetTmplDir() (string, error) {
	modDir, err := GetModuleDir()
	if err != nil {
		return "", err
	}
	// filepath.Join yields an OS-correct, cleaned path.
	dir := filepath.Join(modDir, "tmpl")
	if err := os.MkdirAll(dir, 0755); err != nil {
		return "", fmt.Errorf("创建目录失败: %w", err)
	}
	return dir, nil
}
|
||
|
||
// ReverseSliceNew returns a newly allocated slice holding the elements of s
// in reverse order. s itself is not modified. The result is never nil: a nil
// or empty input yields an empty, non-nil slice.
func ReverseSliceNew[T any](s []T) []T {
	n := len(s)
	out := make([]T, n)
	for i, v := range s {
		out[n-1-i] = v
	}
	return out
}
|
||
|
||
func JsonStringIgonErr(data interface{}) string {
|
||
return string(JsonByteIgonErr(data))
|
||
}
|
||
|
||
func JsonByteIgonErr(data interface{}) []byte {
|
||
dataByte, _ := json.Marshal(data)
|
||
return dataByte
|
||
}
|
||
|
||
// IntersectionGeneric returns the distinct values that occur in both slice1
// and slice2, ordered by their first appearance in slice2. Duplicates are
// emitted only once. The result is never nil: no overlap yields an empty,
// non-nil slice.
func IntersectionGeneric[T comparable](slice1, slice2 []T) []T {
	// Values from slice1 not yet emitted; removing an entry once it has been
	// appended prevents duplicates in the output.
	pending := make(map[T]struct{}, len(slice1))
	for _, x := range slice1 {
		pending[x] = struct{}{}
	}

	out := []T{}
	for _, y := range slice2 {
		if _, ok := pending[y]; !ok {
			continue
		}
		out = append(out, y)
		delete(pending, y)
	}
	return out
}
|
||
|
||
// CreateOrderNum builds an order number of the form
//
//	prefix + local time as "20060102150405" (yyyyMMddHHmmss) + 4 random digits
//
// e.g. CreateOrderNum("SO") -> "SO202401021504050837".
//
// The random suffix is zero-padded to four digits (0000-9999), so only
// 10,000 distinct values exist per second. Numbers are therefore not
// guaranteed unique; enforce uniqueness where it matters (e.g. a unique
// database index).
func CreateOrderNum(prefix string) string {
	suffix := fmt.Sprintf("%04d", rand.IntN(10000))
	return prefix + time.Now().Format("20060102150405") + suffix
}
|
||
|
||
// BuildUpdateMap collects the set (non-nil) pointer fields of a struct into a
// map. Presumably the map is meant for partial, PATCH-style updates.
//
// For each exported field of obj (a struct or a pointer to a struct), the
// field is included only if it is a non-nil pointer. The key is
// CamelToSnake(<Go field name>) and the value is the pointed-to value.
// These fields are skipped:
//   - fields whose Go name appears in omitFields
//   - non-pointer fields and nil pointers
//   - unexported fields (their values cannot be read via reflection)
//
// A nil obj, a nil pointer, or a value that is not a struct yields an empty
// (non-nil) map instead of panicking.
func BuildUpdateMap(obj interface{}, omitFields ...string) map[string]interface{} {
	result := make(map[string]interface{})

	omit := make(map[string]bool, len(omitFields))
	for _, f := range omitFields {
		omit[f] = true
	}

	v := reflect.ValueOf(obj)
	if v.Kind() == reflect.Ptr {
		if v.IsNil() {
			return result
		}
		v = v.Elem()
	}
	// Also covers obj == nil, where v is the invalid zero Value.
	if v.Kind() != reflect.Struct {
		return result
	}
	t := v.Type()

	for i := 0; i < v.NumField(); i++ {
		sf := t.Field(i)
		// Value.Interface panics on values reached through unexported fields.
		if !sf.IsExported() || omit[sf.Name] {
			continue
		}
		field := v.Field(i)
		if field.Kind() == reflect.Ptr && !field.IsNil() {
			// Column name derived from the Go field name; adjust if your
			// database uses a different naming scheme.
			result[CamelToSnake(sf.Name)] = field.Elem().Interface()
		}
	}
	return result
}
|
||
|
||
// CamelToSnake converts a CamelCase / mixedCase identifier to lower
// snake_case. It keeps runs of capitals (acronyms) together as one word:
//
//	"UserName"   -> "user_name"
//	"UserID"     -> "user_id"
//	"HTTPServer" -> "http_server"
//	"Field1Name" -> "field1_name"
//
// Only ASCII letters are case-mapped; other runes, underscores included,
// are copied unchanged.
func CamelToSnake(s string) string {
	isUpper := func(r rune) bool { return r >= 'A' && r <= 'Z' }
	isLower := func(r rune) bool { return r >= 'a' && r <= 'z' }
	isDigit := func(r rune) bool { return r >= '0' && r <= '9' }

	runes := []rune(s)
	var b strings.Builder
	b.Grow(len(s) + len(s)/2)

	for i, r := range runes {
		if !isUpper(r) {
			b.WriteRune(r)
			continue
		}
		if i > 0 {
			prev := runes[i-1]
			nextIsLower := i+1 < len(runes) && isLower(runes[i+1])
			// A new word starts after a lowercase letter or digit
			// ("userName", "v2Name"). It also starts at the last capital of an
			// acronym followed by a lowercase letter ("HTTPServer" -> "http_server").
			if isLower(prev) || isDigit(prev) || (isUpper(prev) && nextIsLower) {
				b.WriteByte('_')
			}
		}
		b.WriteRune(r + ('a' - 'A'))
	}
	return b.String()
}
|
||
|
||
func CopyNonNilFields(src, dst interface{}) error {
|
||
config := &mapstructure.DecoderConfig{
|
||
Result: dst,
|
||
TagName: "json",
|
||
ZeroFields: false, // 重要:不清零目标字段
|
||
Squash: false,
|
||
}
|
||
|
||
decoder, err := mapstructure.NewDecoder(config)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
return decoder.Decode(src)
|
||
}
|
||
|
||
// DownloadFile downloads the resource at url into saveDir and returns the
// absolute path of the written file. saveDir is created (0755) if needed.
//
// File extension:
//  1. the extension of the final (post-redirect) request URL's path; the
//     host name, query string and fragment are never considered;
//  2. otherwise, derived from the response Content-Type media type
//     (parameters such as "; charset=..." are ignored). Only .docx, .doc,
//     .pdf, .mp4 and .avi are recognized.
//
// If no extension can be determined, an error is returned and nothing is
// written.
//
// If filename is empty, a random UUID-based name is used. Otherwise ext is
// appended unless filename already ends with it.
//
// A non-2xx response is returned as an error. If writing the body fails, the
// partially written file is removed.
//
// NOTE: uses http.DefaultClient, which has no timeout. A whole-request
// timeout would also cap large video downloads, so callers needing a
// deadline should choose one deliberately.
func DownloadFile(url string, saveDir string, filename string) (string, error) {
	if err := os.MkdirAll(saveDir, 0755); err != nil {
		return "", fmt.Errorf("创建目录失败: %w", err)
	}

	resp, err := http.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return "", fmt.Errorf("download %s: unexpected status %s", url, resp.Status)
	}

	// 1. Extension from the parsed URL path of the request that produced
	//    this response.
	ext := ""
	if resp.Request != nil && resp.Request.URL != nil {
		ext = filepath.Ext(resp.Request.URL.Path)
	}

	// 2. Fall back to the Content-Type media type.
	if ext == "" {
		mediaType, _, _ := strings.Cut(resp.Header.Get("Content-Type"), ";")
		switch strings.ToLower(strings.TrimSpace(mediaType)) {
		case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
			ext = ".docx"
		case "application/msword":
			ext = ".doc"
		case "application/pdf":
			ext = ".pdf"
		case "video/mp4":
			ext = ".mp4"
		case "video/x-msvideo":
			ext = ".avi"
		}
	}

	if ext == "" {
		return "", fmt.Errorf("无法确定文件类型: %s", url)
	}

	if filename == "" {
		filename = uuid.New().String() + ext
	} else if !strings.HasSuffix(filename, ext) {
		filename += ext
	}
	filePath := filepath.Join(saveDir, filename)

	out, err := os.Create(filePath)
	if err != nil {
		return "", err
	}
	if _, err := io.Copy(out, resp.Body); err != nil {
		out.Close()         // best-effort; the copy error is the one reported
		os.Remove(filePath) // don't leave a truncated file behind
		return "", err
	}
	// Close explicitly and check it: for a written file, Close can report a
	// failed flush that would otherwise be lost.
	if err := out.Close(); err != nil {
		os.Remove(filePath)
		return "", err
	}

	return filepath.Abs(filePath)
}
|
||
|
||
// DownloadImage downloads the image at url into dir and returns the absolute
// path of the saved file. dir is created (0755) if needed.
//
// The file is named "<requestID>_<random UUID><ext>". ext is the extension
// of the final (post-redirect) request URL's path, or ".jpg" when that path
// has none. Query strings and fragments never become part of the name.
//
// A non-2xx response is returned as an error. If writing the body fails, the
// partially written file is removed.
//
// NOTE: uses http.DefaultClient, which has no timeout.
func DownloadImage(url string, requestID string, dir string) (string, error) {
	if err := os.MkdirAll(dir, 0755); err != nil {
		return "", fmt.Errorf("创建目录失败: %w", err)
	}

	resp, err := http.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return "", fmt.Errorf("download %s: unexpected status %s", url, resp.Status)
	}

	// Use the parsed URL path (not the raw string) so "?query"/"#fragment"
	// cannot leak into the file name.
	ext := ""
	if resp.Request != nil && resp.Request.URL != nil {
		ext = filepath.Ext(resp.Request.URL.Path)
	}
	if ext == "" {
		ext = ".jpg"
	}

	filePath := filepath.Join(dir, requestID+"_"+uuid.New().String()+ext)

	out, err := os.Create(filePath)
	if err != nil {
		return "", err
	}
	if _, err := io.Copy(out, resp.Body); err != nil {
		out.Close()         // best-effort; the copy error is the one reported
		os.Remove(filePath) // don't leave a truncated file behind
		return "", err
	}
	// Close explicitly and check it: for a written file, Close can report a
	// failed flush that would otherwise be lost.
	if err := out.Close(); err != nil {
		os.Remove(filePath)
		return "", err
	}

	return filepath.Abs(filePath)
}
|
||
|
||
// DeleteFile removes the file at path and does nothing when path is empty.
//
// Any error from os.Remove, including a missing file, is discarded.
// Presumably this is intended as best-effort cleanup.
func DeleteFile(path string) {
	if path == "" {
		return
	}
	_ = os.Remove(path)
}
|
||
|
||
func GenerateUUID() string {
|
||
return uuid.New().String()
|
||
}
|
||
|
||
func GenerateUserIndex() string {
|
||
return uuid.New().String()[:20]
|
||
}
|
||
|
||
// GetString returns m[key] rendered as a string:
//   - missing key or nil value → ""
//   - []byte (e.g. raw database column data) → converted directly
//   - string → returned as-is
//   - int64 → decimal digits
//   - anything else → fmt's %v formatting
//
// A nil value maps to "" rather than fmt's "<nil>", so an absent value and
// a NULL-like value read the same way. A nil map is safe to pass.
func GetString(m map[string]interface{}, key string) string {
	v, ok := m[key]
	if !ok {
		return ""
	}
	switch x := v.(type) {
	case nil:
		return ""
	case []uint8:
		return string(x)
	case string:
		return x
	case int64:
		return fmt.Sprintf("%d", x)
	default:
		return fmt.Sprintf("%v", x)
	}
}
|
||
|
||
// ParseTags splits a comma-separated tag list and trims surrounding
// whitespace from each tag. Entries that end up empty are dropped. The
// result is never nil: empty or all-blank input yields an empty slice.
func ParseTags(tagStr string) []string {
	tags := []string{}
	// strings.Split("") yields [""], which the emptiness filter discards, so
	// the empty-input case needs no special handling.
	for _, raw := range strings.Split(tagStr, ",") {
		if tag := strings.TrimSpace(raw); tag != "" {
			tags = append(tags, tag)
		}
	}
	return tags
}
|
||
|
||
// GetFreePort asks the OS for an unused TCP port on localhost. It listens
// on port 0, reads back the assigned port, and closes the listener before
// returning the port number.
//
// The port is not reserved after this returns; another process may bind it
// before the caller does.
func GetFreePort() (int, error) {
	tcpAddr, err := net.ResolveTCPAddr("tcp", "localhost:0")
	if err != nil {
		return 0, err
	}

	ln, err := net.ListenTCP("tcp", tcpAddr)
	if err != nil {
		return 0, err
	}
	port := ln.Addr().(*net.TCPAddr).Port
	ln.Close()
	return port, nil
}
|
||
|
||
// StructToMap converts v to a map[string]any by round-tripping it through
// encoding/json. Keys therefore follow v's json tags, and JSON numbers come
// back as float64.
//
// Returns the json.Marshal error, or the json.Unmarshal error when v does
// not encode to a JSON object.
func StructToMap(v any) (map[string]any, error) {
	raw, err := json.Marshal(v)
	if err != nil {
		return nil, err
	}
	var out map[string]any
	if err := json.Unmarshal(raw, &out); err != nil {
		return out, err
	}
	return out, nil
}
|
||
|
||
// ReadDocxToBytes reads the whole file at filePath and returns its bytes.
// Despite the name, it does no docx-specific processing; any file's
// contents are returned verbatim.
//
// Read failures are wrapped (%w) with the prefix "读取文件失败".
func ReadDocxToBytes(filePath string) ([]byte, error) {
	content, readErr := os.ReadFile(filePath)
	if readErr != nil {
		return nil, fmt.Errorf("读取文件失败: %w", readErr)
	}
	return content, nil
}
|
||
|
||
// GetFileExtension returns the extension of filename, including the leading
// dot, or ".jpg" when filename has no extension.
//
// Only the final path element is examined, so dots in directory names are
// ignored ("/tmp/my.dir/photo" has no extension and yields ".jpg"). For
// multi-dot names only the last part counts: "archive.tar.gz" -> ".gz".
func GetFileExtension(filename string) string {
	if ext := filepath.Ext(filename); ext != "" {
		return ext
	}
	return ".jpg" // default extension
}
|
||
|
||
// GenerateImageFileName builds an image file name of the form
// "img_<unix milliseconds><random suffix><ext>".
//
// ext is appended verbatim. Presumably it includes the leading dot, like the
// values GetFileExtension returns — confirm with callers. The suffix comes
// from GenerateRandomLowerString(3), which is defined elsewhere in this
// package (from its name, likely 3 random lowercase letters — verify).
func GenerateImageFileName(ext string) string {
	millis := time.Now().UnixMilli()
	suffix := GenerateRandomLowerString(3)
	return fmt.Sprintf("img_%d%s%s", millis, suffix, ext)
}
|