ai_scheduler/internal/pkg/func.go

444 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package pkg
import (
"ai_scheduler/internal/entitys"
"encoding/json"
"errors"
"fmt"
"math/rand"
"net/url"
"reflect"
"strconv"
"strings"
"time"
)
func JsonStringIgonErr(data interface{}) string {
return string(JsonByteIgonErr(data))
}
func JsonByteIgonErr(data interface{}) []byte {
dataByte, _ := json.Marshal(data)
return dataByte
}
// IsChannelClosed 检查给定的 channel 是否已经关闭
// 参数 ch: 要检查的 channel类型为 chan entitys.ResponseData
// 返回值: bool 类型true 表示 channel 已关闭false 表示未关闭
func IsChannelClosed(ch chan entitys.ResponseData) bool {
select {
case _, ok := <-ch: // 尝试从 channel 中读取数据
return !ok // 如果 ok=false说明 channel 已关闭
default: // 如果 channel 暂时无数据可读(但不一定关闭)
return false // channel 未关闭(但可能有数据未读取)
}
}
// ValidateImageURL 验证图片 URL 是否有效
func ValidateImageURL(rawURL string) error {
// 1. 基础格式验证
parsed, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("未知的图片格式: %v", err)
}
// 2. 检查协议是否为 http/https
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return errors.New("必须是http/https结构")
}
// 3. 检查是否有空的主机名
if parsed.Host == "" {
return errors.New("未知的url地址")
}
// 4. 检查路径是否为空
if strings.TrimSpace(parsed.Path) == "" {
return errors.New("url为空")
}
return nil
}
// hexEncode 将 src 的二进制数据编码为十六进制字符串,写入 dst返回写入长度
func HexEncode(src, dst []byte) int {
const hextable = "0123456789abcdef"
for i := 0; i < len(src); i++ {
dst[i*2] = hextable[src[i]>>4]
dst[i*2+1] = hextable[src[i]&0xf]
}
return len(src) * 2
}
// Ter 三目运算 Ter(true, 1, 2)
func Ter[T any](cond bool, a, b T) T {
if cond {
return a
}
return b
}
// StringToSlice [num,num]转slice
func StringToSlice(s string) ([]int, error) {
// 1. 去掉两端的方括号
trimmed := strings.Trim(s, "[]")
// 2. 按逗号分割
parts := strings.Split(trimmed, ",")
// 3. 转换为 []int
result := make([]int, 0, len(parts))
for _, part := range parts {
num, err := strconv.Atoi(strings.TrimSpace(part))
if err != nil {
return nil, err
}
result = append(result, num)
}
return result, nil
}
// Difference 差集
func Difference[T comparable](a, b []T) []T {
// 创建 b 的映射T 必须是可比较的类型)
bMap := make(map[T]struct{}, len(b))
for _, item := range b {
bMap[item] = struct{}{}
}
var diff []T // 修正为 []T 而非 []int
for _, item := range a {
if _, found := bMap[item]; !found {
diff = append(diff, item)
}
}
return diff
}
// SliceStringToInt []string=>[]int
func SliceStringToInt(strSlice []string) []int {
numSlice := make([]int, len(strSlice))
for i, str := range strSlice {
num, err := strconv.Atoi(str)
if err != nil {
return nil
}
numSlice[i] = num
}
return numSlice
}
// SliceIntToString []int=>[]string
func SliceIntToString(slice []int) []string {
strSlice := make([]string, len(slice)) // len=cap=len(slice)
for i, num := range slice {
strSlice[i] = strconv.Itoa(num) // 直接赋值,无 append
}
return strSlice
}
// SafeReplace 替换字符串中的 %s并自动转义特殊字符如 "
/**
* SafeReplace 函数用于安全地替换模板字符串中的占位符
* @param template 原始模板字符串
* @param replaceTag 要被替换的占位符(如 "%s"
* @param replacements 可变参数,用于替换占位符的字符串
* @return 返回替换后的字符串和可能的错误
*/
func SafeReplace(template string, replaceTag string, replacements ...string) (string, error) {
// 如果没有提供替换参数,直接返回原始模板
if len(replacements) == 0 {
return template, nil
}
// 检查模板中 %s 的数量是否匹配替换参数
expectedReplacements := strings.Count(template, replaceTag)
if expectedReplacements != len(replacements) {
return "", fmt.Errorf("模板需要 %d 个替换参数,但提供了 %d 个", expectedReplacements, len(replacements))
}
// 逐个替换 %s并转义特殊字符
for _, rep := range replacements {
// 转义特殊字符(如 ", \, \n 等)
escaped := strconv.Quote(rep)
// 去掉 strconv.Quote 添加的额外引号
escaped = escaped[1 : len(escaped)-1]
template = strings.Replace(template, replaceTag, escaped, 1)
}
return template, nil
}
// 配置选项
type URLValuesOptions struct {
ArrayFormat string // 数组格式:"brackets" -> name[], "indices" -> name[0], "repeat" -> name=value1&name=value2
TimeFormat string // 时间格式
}
var defaultOptions = URLValuesOptions{
ArrayFormat: "brackets", // 默认使用括号格式
TimeFormat: time.DateTime,
}
// StructToURLValues 将结构体转换为 url.Values
func StructToURLValues(input interface{}, options ...URLValuesOptions) (url.Values, error) {
opts := defaultOptions
if len(options) > 0 {
opts = options[0]
}
values := url.Values{}
if input == nil {
return values, nil
}
v := reflect.ValueOf(input)
t := reflect.TypeOf(input)
// 如果是指针,获取其指向的值
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return values, nil
}
v = v.Elem()
t = t.Elem()
}
// 确保是结构体类型
if v.Kind() != reflect.Struct {
return nil, fmt.Errorf("input must be a struct or pointer to struct")
}
// 遍历结构体字段
for i := 0; i < v.NumField(); i++ {
field := t.Field(i)
fieldValue := v.Field(i)
// 跳过非导出字段
if !field.IsExported() {
continue
}
// 解析 JSON 标签(也可以支持 form 标签)
tag := field.Tag.Get("json")
fieldName, omitempty := parseJSONTag(tag)
if fieldName == "-" {
continue // 忽略该字段
}
if fieldName == "" {
fieldName = field.Name
}
// 处理指针类型
if fieldValue.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
if omitempty {
continue
}
// 可以为 nil 指针添加空值
values.Set(fieldName, "")
continue
}
fieldValue = fieldValue.Elem()
}
// 处理切片/数组
if fieldValue.Kind() == reflect.Slice || fieldValue.Kind() == reflect.Array {
if fieldValue.Len() == 0 && omitempty {
continue
}
// 将切片转换为 URL 参数
err := addSliceToValues(values, fieldName, fieldValue, opts)
if err != nil {
return nil, err
}
continue
}
// 检查是否需要忽略空值
if omitempty && isEmptyValue(fieldValue) {
continue
}
// 转换单个值
str, err := valueToString(fieldValue, opts)
if err != nil {
return nil, err
}
values.Set(fieldName, str)
}
return values, nil
}
// 解析 JSON 标签
func parseJSONTag(tag string) (fieldName string, omitempty bool) {
if tag == "" {
return "", false
}
parts := strings.Split(tag, ",")
fieldName = parts[0]
if len(parts) > 1 {
for _, part := range parts[1:] {
if part == "omitempty" {
omitempty = true
}
}
}
return fieldName, omitempty
}
// 添加切片到 values
func addSliceToValues(values url.Values, fieldName string, slice reflect.Value, opts URLValuesOptions) error {
length := slice.Len()
if length == 0 {
return nil
}
switch opts.ArrayFormat {
case "brackets":
// 格式field[]=value1&field[]=value2
for i := 0; i < length; i++ {
item := slice.Index(i)
str, err := valueToString(item, opts)
if err != nil {
return err
}
values.Add(fieldName, str)
}
case "indices":
// 格式field[0]=value1&field[1]=value2
for i := 0; i < length; i++ {
item := slice.Index(i)
str, err := valueToString(item, opts)
if err != nil {
return err
}
values.Set(fmt.Sprintf("%s[%d]", fieldName, i), str)
}
case "repeat":
// 格式field=value1&field=value2
for i := 0; i < length; i++ {
item := slice.Index(i)
str, err := valueToString(item, opts)
if err != nil {
return err
}
values.Add(fieldName, str)
}
default:
// 默认使用 brackets 格式
for i := 0; i < length; i++ {
item := slice.Index(i)
str, err := valueToString(item, opts)
if err != nil {
return err
}
values.Add(fieldName+"[]", str)
}
}
return nil
}
// 将值转换为字符串
func valueToString(v reflect.Value, opts URLValuesOptions) (string, error) {
if !v.IsValid() {
return "", nil
}
// 处理不同类型
switch v.Kind() {
case reflect.String:
return v.String(), nil
case reflect.Bool:
return strconv.FormatBool(v.Bool()), nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(v.Int(), 10), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return strconv.FormatUint(v.Uint(), 10), nil
case reflect.Float32, reflect.Float64:
return strconv.FormatFloat(v.Float(), 'f', -1, 64), nil
case reflect.Struct:
// 特殊处理 time.Time
if t, ok := v.Interface().(time.Time); ok {
return t.Format(opts.TimeFormat), nil
}
// 其他结构体递归处理
// 这里可以扩展为递归处理嵌套结构体
default:
// 默认使用 fmt 的字符串表示
return fmt.Sprintf("%v", v.Interface()), nil
}
return fmt.Sprintf("%v", v.Interface()), nil
}
// 检查值是否为空
func isEmptyValue(v reflect.Value) bool {
switch v.Kind() {
case reflect.String:
return v.String() == ""
case reflect.Bool:
return false
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Slice, reflect.Array, reflect.Map:
return v.Len() == 0
case reflect.Ptr, reflect.Interface:
return v.IsNil()
case reflect.Struct:
if t, ok := v.Interface().(time.Time); ok {
return t.IsZero()
}
return false
default:
return false
}
}
// 方便函数:直接生成查询字符串
func StructToQueryString(input interface{}, options ...URLValuesOptions) (string, error) {
values, err := StructToURLValues(input, options...)
if err != nil {
return "", err
}
return values.Encode(), nil
}
const (
letterBytes = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" // 62个字符
)
// RandomString 生成随机字符串,包含 0-9, a-z, A-Z
// length: 要生成的字符串长度
func RandomString(length int) string {
// 使用 crypto/rand 替代 math/rand更安全适用于密码学场景
// 但如果不需要高安全性math/rand 更快
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
result := make([]byte, length)
for i := range result {
result[i] = letterBytes[rng.Intn(len(letterBytes))]
}
return string(result)
}