361 lines
8.8 KiB
Go
361 lines
8.8 KiB
Go
package pkg
|
||
|
||
import (
|
||
"ai_scheduler/internal/entitys"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"net/url"
|
||
"os"
|
||
"path/filepath"
|
||
"reflect"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
jsoniter "github.com/json-iterator/go"
|
||
)
|
||
|
||
func JsonStringIgonErr(data interface{}) string {
|
||
return string(JsonByteIgonErr(data))
|
||
}
|
||
|
||
func JsonByteIgonErr(data interface{}) []byte {
|
||
dataByte, _ := json.Marshal(data)
|
||
return dataByte
|
||
}
|
||
|
||
// IsChannelClosed 检查给定的 channel 是否已经关闭
|
||
// 参数 ch: 要检查的 channel,类型为 chan entitys.ResponseData
|
||
// 返回值: bool 类型,true 表示 channel 已关闭,false 表示未关闭
|
||
func IsChannelClosed(ch chan entitys.ResponseData) bool {
|
||
select {
|
||
case _, ok := <-ch: // 尝试从 channel 中读取数据
|
||
return !ok // 如果 ok=false,说明 channel 已关闭
|
||
default: // 如果 channel 暂时无数据可读(但不一定关闭)
|
||
return false // channel 未关闭(但可能有数据未读取)
|
||
}
|
||
}
|
||
|
||
// ValidateImageURL 验证图片 URL 是否有效
|
||
func ValidateImageURL(rawURL string) error {
|
||
// 1. 基础格式验证
|
||
parsed, err := url.Parse(rawURL)
|
||
if err != nil {
|
||
return fmt.Errorf("未知的图片格式: %v", err)
|
||
}
|
||
|
||
// 2. 检查协议是否为 http/https
|
||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||
return errors.New("必须是http/https结构")
|
||
}
|
||
|
||
// 3. 检查是否有空的主机名
|
||
if parsed.Host == "" {
|
||
return errors.New("未知的url地址")
|
||
}
|
||
|
||
// 4. 检查路径是否为空
|
||
if strings.TrimSpace(parsed.Path) == "" {
|
||
return errors.New("url为空")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// hexEncode 将 src 的二进制数据编码为十六进制字符串,写入 dst,返回写入长度
|
||
func HexEncode(src, dst []byte) int {
|
||
const hextable = "0123456789abcdef"
|
||
for i := 0; i < len(src); i++ {
|
||
dst[i*2] = hextable[src[i]>>4]
|
||
dst[i*2+1] = hextable[src[i]&0xf]
|
||
}
|
||
return len(src) * 2
|
||
}
|
||
|
||
// Ter 三目运算 Ter(true, 1, 2)
|
||
func Ter[T any](cond bool, a, b T) T {
|
||
if cond {
|
||
return a
|
||
}
|
||
return b
|
||
}
|
||
|
||
// StringToSlice [num,num]转slice
|
||
func StringToSlice(s string) ([]int, error) {
|
||
// 1. 去掉两端的方括号
|
||
trimmed := strings.Trim(s, "[]")
|
||
|
||
// 2. 按逗号分割
|
||
parts := strings.Split(trimmed, ",")
|
||
|
||
// 3. 转换为 []int
|
||
result := make([]int, 0, len(parts))
|
||
for _, part := range parts {
|
||
num, err := strconv.Atoi(strings.TrimSpace(part))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
result = append(result, num)
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// Difference 差集
|
||
func Difference[T comparable](a, b []T) []T {
|
||
// 创建 b 的映射(T 必须是可比较的类型)
|
||
bMap := make(map[T]struct{}, len(b))
|
||
for _, item := range b {
|
||
bMap[item] = struct{}{}
|
||
}
|
||
|
||
var diff []T // 修正为 []T 而非 []int
|
||
for _, item := range a {
|
||
if _, found := bMap[item]; !found {
|
||
diff = append(diff, item)
|
||
}
|
||
}
|
||
return diff
|
||
}
|
||
|
||
// SliceStringToInt []string=>[]int
|
||
func SliceStringToInt(strSlice []string) []int {
|
||
numSlice := make([]int, len(strSlice))
|
||
for i, str := range strSlice {
|
||
num, err := strconv.Atoi(str)
|
||
if err != nil {
|
||
return nil
|
||
}
|
||
numSlice[i] = num
|
||
}
|
||
return numSlice
|
||
}
|
||
|
||
// SliceIntToString []int=>[]string
|
||
func SliceIntToString(slice []int) []string {
|
||
strSlice := make([]string, len(slice)) // len=cap=len(slice)
|
||
for i, num := range slice {
|
||
strSlice[i] = strconv.Itoa(num) // 直接赋值,无 append
|
||
}
|
||
return strSlice
|
||
}
|
||
|
||
// SafeReplace 替换字符串中的 %s,并自动转义特殊字符(如 ")
|
||
/**
|
||
* SafeReplace 函数用于安全地替换模板字符串中的占位符
|
||
* @param template 原始模板字符串
|
||
* @param replaceTag 要被替换的占位符(如 "%s")
|
||
* @param replacements 可变参数,用于替换占位符的字符串
|
||
* @return 返回替换后的字符串和可能的错误
|
||
*/
|
||
func SafeReplace(template string, replaceTag string, replacements ...string) (string, error) {
|
||
// 如果没有提供替换参数,直接返回原始模板
|
||
if len(replacements) == 0 {
|
||
return template, nil
|
||
}
|
||
|
||
// 检查模板中 %s 的数量是否匹配替换参数
|
||
expectedReplacements := strings.Count(template, replaceTag)
|
||
if expectedReplacements != len(replacements) {
|
||
return "", fmt.Errorf("模板需要 %d 个替换参数,但提供了 %d 个", expectedReplacements, len(replacements))
|
||
}
|
||
|
||
// 逐个替换 %s,并转义特殊字符
|
||
for _, rep := range replacements {
|
||
// 转义特殊字符(如 ", \, \n 等)
|
||
escaped := strconv.Quote(rep)
|
||
// 去掉 strconv.Quote 添加的额外引号
|
||
escaped = escaped[1 : len(escaped)-1]
|
||
template = strings.Replace(template, replaceTag, escaped, 1)
|
||
}
|
||
|
||
return template, nil
|
||
}
|
||
|
||
func StructToMapUsingJsoniter(obj interface{}) (map[string]string, error) {
|
||
var json = jsoniter.ConfigCompatibleWithStandardLibrary
|
||
|
||
// 转换为JSON
|
||
jsonBytes, err := json.Marshal(obj)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 解析为map[string]interface{}
|
||
var tempMap map[string]interface{}
|
||
err = json.Unmarshal(jsonBytes, &tempMap)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 转换为map[string]string
|
||
result := make(map[string]string)
|
||
for k, v := range tempMap {
|
||
result[k] = fmt.Sprintf("%v", v)
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
func GetModuleDir() (string, error) {
|
||
dir, err := os.Getwd()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
for {
|
||
modPath := filepath.Join(dir, "go.mod")
|
||
if _, err := os.Stat(modPath); err == nil {
|
||
return dir, nil // 找到 go.mod
|
||
}
|
||
|
||
// 向上查找父目录
|
||
parent := filepath.Dir(dir)
|
||
if parent == dir {
|
||
break // 到达根目录,未找到
|
||
}
|
||
dir = parent
|
||
}
|
||
|
||
return "", fmt.Errorf("go.mod not found in current directory or parents")
|
||
}
|
||
|
||
// GetCacheDir 用于获取缓存目录路径
|
||
// 如果缓存目录不存在,则会自动创建
|
||
// 返回值:
|
||
// - string: 缓存目录的路径
|
||
// - error: 如果获取模块目录失败或创建缓存目录失败,则返回错误信息
|
||
func GetCacheDir() (string, error) {
|
||
// 获取模块目录
|
||
modDir, err := GetModuleDir()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
// 拼接缓存目录路径
|
||
path := fmt.Sprintf("%s/cache", modDir)
|
||
// 创建目录(包括所有必要的父目录),权限设置为0755
|
||
err = os.MkdirAll(path, 0755)
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建目录失败: %w", err)
|
||
}
|
||
// 返回成功创建的缓存目录路径
|
||
return path, nil
|
||
}
|
||
|
||
func GetTmplDir() (string, error) {
|
||
modDir, err := GetModuleDir()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
path := fmt.Sprintf("%s/tmpl", modDir)
|
||
err = os.MkdirAll(path, 0755)
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建目录失败: %w", err)
|
||
}
|
||
return path, nil
|
||
}
|
||
|
||
// 通用结构体转 Query 参数
|
||
func StructToQuery(obj interface{}) (url.Values, error) {
|
||
values := url.Values{}
|
||
v := reflect.ValueOf(obj)
|
||
t := reflect.TypeOf(obj)
|
||
|
||
// 如果是指针,获取指向的值
|
||
if v.Kind() == reflect.Ptr {
|
||
v = v.Elem()
|
||
t = t.Elem()
|
||
}
|
||
|
||
// 确保是结构体
|
||
if v.Kind() != reflect.Struct {
|
||
return values, fmt.Errorf("expected struct, got %v", v.Kind())
|
||
}
|
||
|
||
for i := 0; i < v.NumField(); i++ {
|
||
field := v.Field(i)
|
||
fieldType := t.Field(i)
|
||
|
||
// 跳过零值字段(omitempty)
|
||
tag := fieldType.Tag.Get("json")
|
||
if strings.Contains(tag, "omitempty") && field.IsZero() {
|
||
continue
|
||
}
|
||
|
||
// 获取字段名
|
||
fieldName := getFieldName(fieldType)
|
||
if fieldName == "" {
|
||
continue
|
||
}
|
||
|
||
// 处理不同类型的字段
|
||
addFieldToValues(values, fieldName, field)
|
||
}
|
||
|
||
return values, nil
|
||
}
|
||
|
||
func getFieldName(field reflect.StructField) string {
|
||
tag := field.Tag.Get("json")
|
||
if tag != "" {
|
||
parts := strings.Split(tag, ",")
|
||
if parts[0] != "-" && parts[0] != "" {
|
||
return parts[0]
|
||
}
|
||
if parts[0] == "-" {
|
||
return "" // 跳过该字段
|
||
}
|
||
}
|
||
return field.Name
|
||
}
|
||
|
||
func addFieldToValues(values url.Values, name string, field reflect.Value) {
|
||
if !field.IsValid() || field.IsZero() {
|
||
return
|
||
}
|
||
|
||
switch field.Kind() {
|
||
case reflect.String:
|
||
values.Add(name, field.String())
|
||
|
||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||
values.Add(name, strconv.FormatInt(field.Int(), 10))
|
||
|
||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||
values.Add(name, strconv.FormatUint(field.Uint(), 10))
|
||
|
||
case reflect.Float32, reflect.Float64:
|
||
values.Add(name, strconv.FormatFloat(field.Float(), 'f', -1, 64))
|
||
|
||
case reflect.Bool:
|
||
values.Add(name, strconv.FormatBool(field.Bool()))
|
||
|
||
case reflect.Slice:
|
||
// 处理切片,特别是 []string
|
||
if field.Type().Elem().Kind() == reflect.String {
|
||
for i := 0; i < field.Len(); i++ {
|
||
item := field.Index(i).String()
|
||
// 特殊处理 ct 字段
|
||
if name == "ct" {
|
||
formatted := strings.Replace(item, " ", "+", 1)
|
||
if i == 1 && field.Len() >= 2 {
|
||
formatted = formatted + ".999"
|
||
}
|
||
values.Add("ct[]", formatted)
|
||
} else {
|
||
values.Add(fmt.Sprintf("%s[]", name), item)
|
||
}
|
||
}
|
||
}
|
||
|
||
case reflect.Struct:
|
||
// 处理 time.Time
|
||
if t, ok := field.Interface().(time.Time); ok {
|
||
values.Add(name, t.Format("2006-01-02+15:04:05"))
|
||
}
|
||
|
||
default:
|
||
values.Add(name, fmt.Sprintf("%v", field.Interface()))
|
||
}
|
||
}
|