geoGo/pkg/func.go

281 lines
5.5 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package pkg
import (
"encoding/json"
"fmt"
"io"
"math/rand/v2"
"net/http"
"os"
"path/filepath"
"reflect"
"strings"
"time"
"github.com/go-viper/mapstructure/v2"
"github.com/google/uuid"
)
// GetModuleDir walks upward from the current working directory and returns the
// first directory that contains a go.mod file.
//
// It returns an error if the working directory cannot be determined, or if the
// filesystem root is reached without finding go.mod.
func GetModuleDir() (string, error) {
	current, err := os.Getwd()
	if err != nil {
		return "", err
	}
	for {
		if _, statErr := os.Stat(filepath.Join(current, "go.mod")); statErr == nil {
			return current, nil
		}
		next := filepath.Dir(current)
		if next == current {
			// filepath.Dir of the root is the root itself: nothing left to search.
			return "", fmt.Errorf("go.mod not found in current directory or parents")
		}
		current = next
	}
}
// GetCacheDir returns the path of the "cache" directory located directly under
// the module root (the nearest ancestor of the working directory that contains
// go.mod, as located by GetModuleDir). The directory, and any missing parents,
// is created with permission 0755 if it does not already exist.
//
// Returns:
//   - string: the cache directory path
//   - error: non-nil if the module root cannot be found or the directory cannot be created
func GetCacheDir() (string, error) {
	modDir, err := GetModuleDir()
	if err != nil {
		return "", err
	}
	// filepath.Join uses the OS path separator and cleans the result, unlike
	// hand-formatting "%s/cache".
	path := filepath.Join(modDir, "cache")
	// MkdirAll is a no-op when the directory already exists.
	if err := os.MkdirAll(path, 0755); err != nil {
		return "", fmt.Errorf("创建目录失败: %w", err)
	}
	return path, nil
}
// GetTmplDir returns the path of the "tmpl" directory located directly under
// the module root (the nearest ancestor of the working directory that contains
// go.mod, as located by GetModuleDir). The directory, and any missing parents,
// is created with permission 0755 if it does not already exist.
//
// It returns an error if the module root cannot be found or the directory
// cannot be created.
func GetTmplDir() (string, error) {
	modDir, err := GetModuleDir()
	if err != nil {
		return "", err
	}
	// filepath.Join uses the OS path separator and cleans the result, unlike
	// hand-formatting "%s/tmpl".
	path := filepath.Join(modDir, "tmpl")
	// MkdirAll is a no-op when the directory already exists.
	if err := os.MkdirAll(path, 0755); err != nil {
		return "", fmt.Errorf("创建目录失败: %w", err)
	}
	return path, nil
}
// ReverseSliceNew returns a newly allocated slice holding the elements of s in
// reverse order; s itself is left unmodified. A nil or empty input yields an
// empty, non-nil slice.
func ReverseSliceNew[T any](s []T) []T {
	n := len(s)
	out := make([]T, n)
	for i, v := range s {
		out[n-1-i] = v
	}
	return out
}
// JsonStringIgonErr returns the JSON encoding of data as a string, or "" if
// json.Marshal fails. The marshal error is intentionally discarded.
func JsonStringIgonErr(data interface{}) string {
	encoded, err := json.Marshal(data)
	if err != nil {
		return ""
	}
	return string(encoded)
}
// JsonByteIgonErr returns the JSON encoding of data, or nil if json.Marshal
// fails. The marshal error is intentionally discarded.
func JsonByteIgonErr(data interface{}) []byte {
	if encoded, err := json.Marshal(data); err == nil {
		return encoded
	}
	return nil
}
// IntersectionGeneric returns the values present in both slice1 and slice2,
// ordered by their first appearance in slice2, with duplicates removed. When
// nothing is shared the result is an empty, non-nil slice.
func IntersectionGeneric[T comparable](slice1, slice2 []T) []T {
	pending := make(map[T]struct{}, len(slice1))
	for _, v := range slice1 {
		pending[v] = struct{}{}
	}
	shared := []T{}
	for _, v := range slice2 {
		if _, ok := pending[v]; !ok {
			continue
		}
		shared = append(shared, v)
		// Forget v so a repeat later in slice2 is not emitted twice.
		delete(pending, v)
	}
	return shared
}
func CreateOrderNum(prefix string) string {
code := fmt.Sprintf("%04d", rand.IntN(10000))
fmt.Println("4位随机数字:", code) // 输出示例: "0837"
return prefix + time.Now().Format("20060102150405") + code
}
// BuildUpdateMap collects the non-nil pointer fields of a struct into a map,
// e.g. for a partial update. Keys are CamelToSnake(fieldName) and values are
// the dereferenced field values.
//
// obj may be a struct or a pointer to a struct. The following are skipped:
//   - fields whose Go names appear in omitFields;
//   - non-pointer fields and nil pointers;
//   - unexported fields, because reflect's Value.Interface panics on them.
//
// A nil obj, a nil pointer, or a non-struct obj yields an empty map instead of
// panicking.
func BuildUpdateMap(obj interface{}, omitFields ...string) map[string]interface{} {
	result := make(map[string]interface{})
	omitMap := make(map[string]bool, len(omitFields))
	for _, f := range omitFields {
		omitMap[f] = true
	}
	v := reflect.ValueOf(obj)
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	// Covers a nil interface (invalid Value), a nil pointer (Elem yields the
	// zero Value) and non-struct inputs, all of which would panic below.
	if v.Kind() != reflect.Struct {
		return result
	}
	t := v.Type()
	for i := 0; i < v.NumField(); i++ {
		sf := t.Field(i)
		if !sf.IsExported() || omitMap[sf.Name] {
			continue
		}
		field := v.Field(i)
		// Only set (non-nil) pointer fields are treated as "to be updated".
		if field.Kind() == reflect.Ptr && !field.IsNil() {
			// CamelCase -> underscore separated; adjust to match the DB column names.
			result[CamelToSnake(sf.Name)] = field.Elem().Interface()
		}
	}
	return result
}
// CamelToSnake inserts an underscore before every ASCII upper-case letter that
// is not the first character, e.g. "UserName" -> "User_Name".
//
// Letter case is left unchanged, and consecutive capitals are split
// individually ("ID" -> "I_D"). NOTE(review): confirm this matches the
// intended column naming, since it is not conventional lower snake_case.
func CamelToSnake(s string) string {
	var b strings.Builder
	for idx, r := range s {
		if idx != 0 && 'A' <= r && r <= 'Z' {
			b.WriteByte('_')
		}
		b.WriteRune(r)
	}
	return b.String()
}
// CopyNonNilFields decodes src into dst with mapstructure, matching fields by
// their `json` struct tags. dst is used as the decoder's Result, so it must be
// something mapstructure can decode into (normally a pointer to a struct or map).
//
// ZeroFields is false, so dst is not cleared before decoding and values are
// merged into it. NOTE(review): whether nil fields in src are actually skipped,
// as the name suggests, depends on mapstructure's handling of nil inputs. Confirm
// this against the mapstructure version in go.mod.
func CopyNonNilFields(src, dst interface{}) error {
	cfg := mapstructure.DecoderConfig{
		ZeroFields: false, // keep existing dst values instead of zeroing them first
		Squash:     false,
		TagName:    "json",
		Result:     dst,
	}
	dec, err := mapstructure.NewDecoder(&cfg)
	if err != nil {
		return err
	}
	return dec.Decode(src)
}
// DownloadFile fetches url with an HTTP GET and writes the response body to
// saveDir/filename. saveDir is created with mode 0755 if needed. When filename
// is empty, a random "<uuid>.docx" name is used. On success it returns the
// absolute path of the written file.
//
// A non-2xx response is reported as an error instead of being saved to disk.
// A partially written file is removed if copying or closing fails.
//
// NOTE(review): http.Get uses http.DefaultClient, which has no timeout, so a
// stalled server can block this call indefinitely. If filename can come from
// untrusted input, it must be sanitized: filepath.Join does not stop "../"
// components from escaping saveDir.
func DownloadFile(url string, saveDir string, filename string) (string, error) {
	if err := os.MkdirAll(saveDir, 0755); err != nil {
		return "", fmt.Errorf("creating directory %q: %w", saveDir, err)
	}
	if filename == "" {
		filename = uuid.New().String() + ".docx"
	}
	filePath := filepath.Join(saveDir, filename)

	resp, err := http.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	// Without this check an HTML error page (404, 500, ...) would be saved as the file.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return "", fmt.Errorf("downloading %q: unexpected status %s", url, resp.Status)
	}

	out, err := os.Create(filePath)
	if err != nil {
		return "", err
	}
	if _, err := io.Copy(out, resp.Body); err != nil {
		out.Close()
		_ = os.Remove(filePath) // best-effort cleanup of the partial file
		return "", err
	}
	// Close explicitly: a deferred Close would silently drop write-back errors.
	if err := out.Close(); err != nil {
		_ = os.Remove(filePath) // best-effort cleanup of the possibly incomplete file
		return "", err
	}
	return filepath.Abs(filePath)
}
// DownloadImage fetches url with an HTTP GET and saves the body in dir as
// "<requestID>_<uuid><ext>". dir is created with mode 0755 if needed. ext is the
// extension of the URL's path component (see imageExtFromURL), or ".jpg" when
// there is none. On success it returns the absolute path of the saved file.
//
// A non-2xx response is reported as an error instead of being saved to disk.
// A partially written file is removed if copying or closing fails.
//
// NOTE(review): http.Get uses http.DefaultClient, which has no timeout, so a
// stalled server can block this call indefinitely.
func DownloadImage(url string, requestID string, dir string) (string, error) {
	if err := os.MkdirAll(dir, 0755); err != nil {
		return "", fmt.Errorf("creating directory %q: %w", dir, err)
	}
	ext := imageExtFromURL(url)
	filename := requestID + "_" + uuid.New().String() + ext
	filePath := filepath.Join(dir, filename)

	resp, err := http.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	// Without this check an HTML error page (404, 500, ...) would be saved as the image.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return "", fmt.Errorf("downloading %q: unexpected status %s", url, resp.Status)
	}

	out, err := os.Create(filePath)
	if err != nil {
		return "", err
	}
	if _, err := io.Copy(out, resp.Body); err != nil {
		out.Close()
		_ = os.Remove(filePath) // best-effort cleanup of the partial file
		return "", err
	}
	// Close explicitly: a deferred Close would silently drop write-back errors.
	if err := out.Close(); err != nil {
		_ = os.Remove(filePath) // best-effort cleanup of the possibly incomplete file
		return "", err
	}
	return filepath.Abs(filePath)
}

// imageExtFromURL returns the file extension of rawURL's path component, or
// ".jpg" if the path has none. Any scheme, host, query string or fragment is
// ignored.
func imageExtFromURL(rawURL string) string {
	p := rawURL
	// Strip "?query" and "#fragment" so "a.png?x=1" yields ".png", not ".png?x=1".
	if i := strings.IndexAny(p, "?#"); i >= 0 {
		p = p[:i]
	}
	// Drop "scheme://host" so a bare host such as "https://example.com" is not
	// mistaken for a ".com" extension.
	if i := strings.Index(p, "://"); i >= 0 {
		rest := p[i+len("://"):]
		if j := strings.IndexByte(rest, '/'); j >= 0 {
			p = rest[j:]
		} else {
			p = ""
		}
	}
	if ext := filepath.Ext(p); ext != "" {
		return ext
	}
	return ".jpg"
}
// DeleteFile removes the file at path on a best-effort basis. An empty path is
// ignored, and any error returned by os.Remove (for example, the file does not
// exist) is discarded.
func DeleteFile(path string) {
	if path == "" {
		return
	}
	// Deliberately best-effort: callers receive no error signal.
	_ = os.Remove(path)
}
// GenerateUUID returns a freshly generated UUID (via uuid.New) in its standard
// hyphenated string form.
func GenerateUUID() string {
	id := uuid.New()
	return id.String()
}
// GenerateUserIndex returns the first 20 characters of a freshly generated UUID
// string, hyphens included.
//
// NOTE(review): truncation discards part of the UUID's random bits, so
// collisions are more likely than with a full UUID. Enforce uniqueness at the
// storage layer if it matters.
func GenerateUserIndex() string {
	const indexLen = 20
	full := uuid.New().String()
	return full[:indexLen]
}
// GetString reads m[key] and renders it as a string.
//
// The value is rendered as follows:
//   - a missing key or a nil value yields "";
//   - []byte is converted directly (presumably to handle text columns that
//     SQL drivers return as bytes; verify against callers);
//   - strings are returned unchanged;
//   - int64 is formatted in base 10;
//   - any other type falls back to fmt's %v.
//
// A nil map behaves like an empty one.
func GetString(m map[string]interface{}, key string) string {
	v, ok := m[key]
	if !ok || v == nil {
		// A nil value (such as a NULL column scanned into interface{}) used to
		// render as the literal text "<nil>".
		return ""
	}
	switch val := v.(type) {
	case []uint8:
		return string(val)
	case string:
		return val
	case int64:
		return fmt.Sprintf("%d", val)
	default:
		return fmt.Sprintf("%v", val)
	}
}
// ParseTags splits a comma-separated tag list, trims surrounding whitespace
// from each tag and drops tags that end up empty. The result is always a
// non-nil slice, which is empty when there are no tags.
func ParseTags(tagStr string) []string {
	tags := make([]string, 0)
	isComma := func(r rune) bool { return r == ',' }
	for _, raw := range strings.FieldsFunc(tagStr, isComma) {
		if tag := strings.TrimSpace(raw); tag != "" {
			tags = append(tags, tag)
		}
	}
	return tags
}