ai_scheduler/internal/pkg/func.go

301 lines
7.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package pkg
import (
"ai_scheduler/internal/entitys"
"encoding/json"
"errors"
"fmt"
"net/url"
"reflect"
"strconv"
"strings"
"time"
jsoniter "github.com/json-iterator/go"
)
func JsonStringIgonErr(data interface{}) string {
return string(JsonByteIgonErr(data))
}
func JsonByteIgonErr(data interface{}) []byte {
dataByte, _ := json.Marshal(data)
return dataByte
}
// IsChannelClosed 检查给定的 channel 是否已经关闭
// 参数 ch: 要检查的 channel类型为 chan entitys.ResponseData
// 返回值: bool 类型true 表示 channel 已关闭false 表示未关闭
func IsChannelClosed(ch chan entitys.ResponseData) bool {
select {
case _, ok := <-ch: // 尝试从 channel 中读取数据
return !ok // 如果 ok=false说明 channel 已关闭
default: // 如果 channel 暂时无数据可读(但不一定关闭)
return false // channel 未关闭(但可能有数据未读取)
}
}
// ValidateImageURL 验证图片 URL 是否有效
func ValidateImageURL(rawURL string) error {
// 1. 基础格式验证
parsed, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("未知的图片格式: %v", err)
}
// 2. 检查协议是否为 http/https
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return errors.New("必须是http/https结构")
}
// 3. 检查是否有空的主机名
if parsed.Host == "" {
return errors.New("未知的url地址")
}
// 4. 检查路径是否为空
if strings.TrimSpace(parsed.Path) == "" {
return errors.New("url为空")
}
return nil
}
// hexEncode 将 src 的二进制数据编码为十六进制字符串,写入 dst返回写入长度
func HexEncode(src, dst []byte) int {
const hextable = "0123456789abcdef"
for i := 0; i < len(src); i++ {
dst[i*2] = hextable[src[i]>>4]
dst[i*2+1] = hextable[src[i]&0xf]
}
return len(src) * 2
}
// Ter 三目运算 Ter(true, 1, 2)
func Ter[T any](cond bool, a, b T) T {
if cond {
return a
}
return b
}
// StringToSlice [num,num]转slice
func StringToSlice(s string) ([]int, error) {
// 1. 去掉两端的方括号
trimmed := strings.Trim(s, "[]")
// 2. 按逗号分割
parts := strings.Split(trimmed, ",")
// 3. 转换为 []int
result := make([]int, 0, len(parts))
for _, part := range parts {
num, err := strconv.Atoi(strings.TrimSpace(part))
if err != nil {
return nil, err
}
result = append(result, num)
}
return result, nil
}
// Difference 差集
func Difference[T comparable](a, b []T) []T {
// 创建 b 的映射T 必须是可比较的类型)
bMap := make(map[T]struct{}, len(b))
for _, item := range b {
bMap[item] = struct{}{}
}
var diff []T // 修正为 []T 而非 []int
for _, item := range a {
if _, found := bMap[item]; !found {
diff = append(diff, item)
}
}
return diff
}
// SliceStringToInt []string=>[]int
func SliceStringToInt(strSlice []string) []int {
numSlice := make([]int, len(strSlice))
for i, str := range strSlice {
num, err := strconv.Atoi(str)
if err != nil {
return nil
}
numSlice[i] = num
}
return numSlice
}
// SliceIntToString []int=>[]string
func SliceIntToString(slice []int) []string {
strSlice := make([]string, len(slice)) // len=cap=len(slice)
for i, num := range slice {
strSlice[i] = strconv.Itoa(num) // 直接赋值,无 append
}
return strSlice
}
// SafeReplace 替换字符串中的 %s并自动转义特殊字符如 "
/**
* SafeReplace 函数用于安全地替换模板字符串中的占位符
* @param template 原始模板字符串
* @param replaceTag 要被替换的占位符(如 "%s"
* @param replacements 可变参数,用于替换占位符的字符串
* @return 返回替换后的字符串和可能的错误
*/
func SafeReplace(template string, replaceTag string, replacements ...string) (string, error) {
// 如果没有提供替换参数,直接返回原始模板
if len(replacements) == 0 {
return template, nil
}
// 检查模板中 %s 的数量是否匹配替换参数
expectedReplacements := strings.Count(template, replaceTag)
if expectedReplacements != len(replacements) {
return "", fmt.Errorf("模板需要 %d 个替换参数,但提供了 %d 个", expectedReplacements, len(replacements))
}
// 逐个替换 %s并转义特殊字符
for _, rep := range replacements {
// 转义特殊字符(如 ", \, \n 等)
escaped := strconv.Quote(rep)
// 去掉 strconv.Quote 添加的额外引号
escaped = escaped[1 : len(escaped)-1]
template = strings.Replace(template, replaceTag, escaped, 1)
}
return template, nil
}
func StructToMapUsingJsoniter(obj interface{}) (map[string]string, error) {
var json = jsoniter.ConfigCompatibleWithStandardLibrary
// 转换为JSON
jsonBytes, err := json.Marshal(obj)
if err != nil {
return nil, err
}
// 解析为map[string]interface{}
var tempMap map[string]interface{}
err = json.Unmarshal(jsonBytes, &tempMap)
if err != nil {
return nil, err
}
// 转换为map[string]string
result := make(map[string]string)
for k, v := range tempMap {
result[k] = fmt.Sprintf("%v", v)
}
return result, nil
}
// 通用结构体转 Query 参数
func StructToQuery(obj interface{}) (url.Values, error) {
values := url.Values{}
v := reflect.ValueOf(obj)
t := reflect.TypeOf(obj)
// 如果是指针,获取指向的值
if v.Kind() == reflect.Ptr {
v = v.Elem()
t = t.Elem()
}
// 确保是结构体
if v.Kind() != reflect.Struct {
return values, fmt.Errorf("expected struct, got %v", v.Kind())
}
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
fieldType := t.Field(i)
// 跳过零值字段omitempty
tag := fieldType.Tag.Get("json")
if strings.Contains(tag, "omitempty") && field.IsZero() {
continue
}
// 获取字段名
fieldName := getFieldName(fieldType)
if fieldName == "" {
continue
}
// 处理不同类型的字段
addFieldToValues(values, fieldName, field)
}
return values, nil
}
func getFieldName(field reflect.StructField) string {
tag := field.Tag.Get("json")
if tag != "" {
parts := strings.Split(tag, ",")
if parts[0] != "-" && parts[0] != "" {
return parts[0]
}
if parts[0] == "-" {
return "" // 跳过该字段
}
}
return field.Name
}
func addFieldToValues(values url.Values, name string, field reflect.Value) {
if !field.IsValid() || field.IsZero() {
return
}
switch field.Kind() {
case reflect.String:
values.Add(name, field.String())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
values.Add(name, strconv.FormatInt(field.Int(), 10))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
values.Add(name, strconv.FormatUint(field.Uint(), 10))
case reflect.Float32, reflect.Float64:
values.Add(name, strconv.FormatFloat(field.Float(), 'f', -1, 64))
case reflect.Bool:
values.Add(name, strconv.FormatBool(field.Bool()))
case reflect.Slice:
// 处理切片,特别是 []string
if field.Type().Elem().Kind() == reflect.String {
for i := 0; i < field.Len(); i++ {
item := field.Index(i).String()
// 特殊处理 ct 字段
if name == "ct" {
formatted := strings.Replace(item, " ", "+", 1)
if i == 1 && field.Len() >= 2 {
formatted = formatted + ".999"
}
values.Add("ct[]", formatted)
} else {
values.Add(fmt.Sprintf("%s[]", name), item)
}
}
}
case reflect.Struct:
// 处理 time.Time
if t, ok := field.Interface().(time.Time); ok {
values.Add(name, t.Format("2006-01-02+15:04:05"))
}
default:
values.Add(name, fmt.Sprintf("%v", field.Interface()))
}
}