package pkg import ( "ai_scheduler/internal/entitys" "encoding/json" "errors" "fmt" "net/url" "os" "path/filepath" "reflect" "strconv" "strings" "time" jsoniter "github.com/json-iterator/go" ) func JsonStringIgonErr(data interface{}) string { return string(JsonByteIgonErr(data)) } func JsonByteIgonErr(data interface{}) []byte { dataByte, _ := json.Marshal(data) return dataByte } // IsChannelClosed 检查给定的 channel 是否已经关闭 // 参数 ch: 要检查的 channel,类型为 chan entitys.ResponseData // 返回值: bool 类型,true 表示 channel 已关闭,false 表示未关闭 func IsChannelClosed(ch chan entitys.ResponseData) bool { select { case _, ok := <-ch: // 尝试从 channel 中读取数据 return !ok // 如果 ok=false,说明 channel 已关闭 default: // 如果 channel 暂时无数据可读(但不一定关闭) return false // channel 未关闭(但可能有数据未读取) } } // ValidateImageURL 验证图片 URL 是否有效 func ValidateImageURL(rawURL string) error { // 1. 基础格式验证 parsed, err := url.Parse(rawURL) if err != nil { return fmt.Errorf("未知的图片格式: %v", err) } // 2. 检查协议是否为 http/https if parsed.Scheme != "http" && parsed.Scheme != "https" { return errors.New("必须是http/https结构") } // 3. 检查是否有空的主机名 if parsed.Host == "" { return errors.New("未知的url地址") } // 4. 检查路径是否为空 if strings.TrimSpace(parsed.Path) == "" { return errors.New("url为空") } return nil } // hexEncode 将 src 的二进制数据编码为十六进制字符串,写入 dst,返回写入长度 func HexEncode(src, dst []byte) int { const hextable = "0123456789abcdef" for i := 0; i < len(src); i++ { dst[i*2] = hextable[src[i]>>4] dst[i*2+1] = hextable[src[i]&0xf] } return len(src) * 2 } // Ter 三目运算 Ter(true, 1, 2) func Ter[T any](cond bool, a, b T) T { if cond { return a } return b } // StringToSlice [num,num]转slice func StringToSlice(s string) ([]int, error) { // 1. 去掉两端的方括号 trimmed := strings.Trim(s, "[]") // 2. 按逗号分割 parts := strings.Split(trimmed, ",") // 3. 转换为 []int result := make([]int, 0, len(parts)) for _, part := range parts { num, err := strconv.Atoi(strings.TrimSpace(part)) if err != nil { return nil, err } result = append(result, num) } return result, nil } // Difference 差集 func Difference[T comparable](a, b []T) []T { // 创建 b 的映射(T 必须是可比较的类型) bMap := make(map[T]struct{}, len(b)) for _, item := range b { bMap[item] = struct{}{} } var diff []T // 修正为 []T 而非 []int for _, item := range a { if _, found := bMap[item]; !found { diff = append(diff, item) } } return diff } // SliceStringToInt []string=>[]int func SliceStringToInt(strSlice []string) []int { numSlice := make([]int, len(strSlice)) for i, str := range strSlice { num, err := strconv.Atoi(str) if err != nil { return nil } numSlice[i] = num } return numSlice } // SliceIntToString []int=>[]string func SliceIntToString(slice []int) []string { strSlice := make([]string, len(slice)) // len=cap=len(slice) for i, num := range slice { strSlice[i] = strconv.Itoa(num) // 直接赋值,无 append } return strSlice } // SafeReplace 替换字符串中的 %s,并自动转义特殊字符(如 ") /** * SafeReplace 函数用于安全地替换模板字符串中的占位符 * @param template 原始模板字符串 * @param replaceTag 要被替换的占位符(如 "%s") * @param replacements 可变参数,用于替换占位符的字符串 * @return 返回替换后的字符串和可能的错误 */ func SafeReplace(template string, replaceTag string, replacements ...string) (string, error) { // 如果没有提供替换参数,直接返回原始模板 if len(replacements) == 0 { return template, nil } // 检查模板中 %s 的数量是否匹配替换参数 expectedReplacements := strings.Count(template, replaceTag) if expectedReplacements != len(replacements) { return "", fmt.Errorf("模板需要 %d 个替换参数,但提供了 %d 个", expectedReplacements, len(replacements)) } // 逐个替换 %s,并转义特殊字符 for _, rep := range replacements { // 转义特殊字符(如 ", \, \n 等) escaped := strconv.Quote(rep) // 去掉 strconv.Quote 添加的额外引号 escaped = escaped[1 : len(escaped)-1] template = strings.Replace(template, replaceTag, escaped, 1) } return template, nil } func StructToMapUsingJsoniter(obj interface{}) (map[string]string, error) { var json = jsoniter.ConfigCompatibleWithStandardLibrary // 转换为JSON jsonBytes, err := json.Marshal(obj) if err != nil { return nil, err } // 解析为map[string]interface{} var tempMap map[string]interface{} err = json.Unmarshal(jsonBytes, &tempMap) if err != nil { return nil, err } // 转换为map[string]string result := make(map[string]string) for k, v := range tempMap { result[k] = fmt.Sprintf("%v", v) } return result, nil } func GetModuleDir() (string, error) { dir, err := os.Getwd() if err != nil { return "", err } for { modPath := filepath.Join(dir, "go.mod") if _, err := os.Stat(modPath); err == nil { return dir, nil // 找到 go.mod } // 向上查找父目录 parent := filepath.Dir(dir) if parent == dir { break // 到达根目录,未找到 } dir = parent } return "", fmt.Errorf("go.mod not found in current directory or parents") } // GetCacheDir 用于获取缓存目录路径 // 如果缓存目录不存在,则会自动创建 // 返回值: // - string: 缓存目录的路径 // - error: 如果获取模块目录失败或创建缓存目录失败,则返回错误信息 func GetCacheDir() (string, error) { // 获取模块目录 modDir, err := GetModuleDir() if err != nil { return "", err } // 拼接缓存目录路径 path := fmt.Sprintf("%s/cache", modDir) // 创建目录(包括所有必要的父目录),权限设置为0755 err = os.MkdirAll(path, 0755) if err != nil { return "", fmt.Errorf("创建目录失败: %w", err) } // 返回成功创建的缓存目录路径 return path, nil } func GetTmplDir() (string, error) { modDir, err := GetModuleDir() if err != nil { return "", err } path := fmt.Sprintf("%s/tmpl", modDir) err = os.MkdirAll(path, 0755) if err != nil { return "", fmt.Errorf("创建目录失败: %w", err) } return path, nil } // 通用结构体转 Query 参数 func StructToQuery(obj interface{}) (url.Values, error) { values := url.Values{} v := reflect.ValueOf(obj) t := reflect.TypeOf(obj) // 如果是指针,获取指向的值 if v.Kind() == reflect.Ptr { v = v.Elem() t = t.Elem() } // 确保是结构体 if v.Kind() != reflect.Struct { return values, fmt.Errorf("expected struct, got %v", v.Kind()) } for i := 0; i < v.NumField(); i++ { field := v.Field(i) fieldType := t.Field(i) // 跳过零值字段(omitempty) tag := fieldType.Tag.Get("json") if strings.Contains(tag, "omitempty") && field.IsZero() { continue } // 获取字段名 fieldName := getFieldName(fieldType) if fieldName == "" { continue } // 处理不同类型的字段 addFieldToValues(values, fieldName, field) } return values, nil } func getFieldName(field reflect.StructField) string { tag := field.Tag.Get("json") if tag != "" { parts := strings.Split(tag, ",") if parts[0] != "-" && parts[0] != "" { return parts[0] } if parts[0] == "-" { return "" // 跳过该字段 } } return field.Name } func addFieldToValues(values url.Values, name string, field reflect.Value) { if !field.IsValid() || field.IsZero() { return } switch field.Kind() { case reflect.String: values.Add(name, field.String()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: values.Add(name, strconv.FormatInt(field.Int(), 10)) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: values.Add(name, strconv.FormatUint(field.Uint(), 10)) case reflect.Float32, reflect.Float64: values.Add(name, strconv.FormatFloat(field.Float(), 'f', -1, 64)) case reflect.Bool: values.Add(name, strconv.FormatBool(field.Bool())) case reflect.Slice: // 处理切片,特别是 []string if field.Type().Elem().Kind() == reflect.String { for i := 0; i < field.Len(); i++ { item := field.Index(i).String() // 特殊处理 ct 字段 if name == "ct" { formatted := strings.Replace(item, " ", "+", 1) if i == 1 && field.Len() >= 2 { formatted = formatted + ".999" } values.Add("ct[]", formatted) } else { values.Add(fmt.Sprintf("%s[]", name), item) } } } case reflect.Struct: // 处理 time.Time if t, ok := field.Interface().(time.Time); ok { values.Add(name, t.Format("2006-01-02+15:04:05")) } default: values.Add(name, fmt.Sprintf("%v", field.Interface())) } }