" + content + "
", nil + case "markdown": + converter := md.NewConverter("", true, nil) + return converter.ConvertString("" + content + "
") + default: + return strip.StripTags(content), nil + } +} + +func ParseTags(tagStr string) []string { + if tagStr == "" { + return []string{} + } + tags := strings.Split(tagStr, ",") + result := make([]string, 0) + for _, t := range tags { + trimmed := strings.TrimSpace(t) + if trimmed != "" { + result = append(result, trimmed) + } + } + return result +} + +func IsTimeExceeded(targetTime string) bool { + // 实现时间比较 + return false +} diff --git a/pkg/func.go b/pkg/func.go new file mode 100644 index 0000000..03f68c7 --- /dev/null +++ b/pkg/func.go @@ -0,0 +1,247 @@ +package pkg + +import ( + "encoding/json" + "fmt" + "io" + "math/rand/v2" + "net/http" + "os" + "path/filepath" + "reflect" + "time" + + "github.com/go-viper/mapstructure/v2" + "github.com/google/uuid" +) + +func GetModuleDir() (string, error) { + dir, err := os.Getwd() + if err != nil { + return "", err + } + + for { + modPath := filepath.Join(dir, "go.mod") + if _, err := os.Stat(modPath); err == nil { + return dir, nil // 找到 go.mod + } + + // 向上查找父目录 + parent := filepath.Dir(dir) + if parent == dir { + break // 到达根目录,未找到 + } + dir = parent + } + + return "", fmt.Errorf("go.mod not found in current directory or parents") +} + +// GetCacheDir 用于获取缓存目录路径 +// 如果缓存目录不存在,则会自动创建 +// 返回值: +// - string: 缓存目录的路径 +// - error: 如果获取模块目录失败或创建缓存目录失败,则返回错误信息 +func GetCacheDir() (string, error) { + // 获取模块目录 + modDir, err := GetModuleDir() + if err != nil { + return "", err + } + // 拼接缓存目录路径 + path := fmt.Sprintf("%s/cache", modDir) + // 创建目录(包括所有必要的父目录),权限设置为0755 + err = os.MkdirAll(path, 0755) + if err != nil { + return "", fmt.Errorf("创建目录失败: %w", err) + } + // 返回成功创建的缓存目录路径 + return path, nil +} + +func GetTmplDir() (string, error) { + modDir, err := GetModuleDir() + if err != nil { + return "", err + } + path := fmt.Sprintf("%s/tmpl", modDir) + err = os.MkdirAll(path, 0755) + if err != nil { + return "", fmt.Errorf("创建目录失败: %w", err) + } + return path, nil +} + +func ReverseSliceNew[T any](s []T) []T { + result := make([]T, len(s)) + for i := 0; i < len(s); i++ { + result[i] = s[len(s)-1-i] + } + return result +} + +func JsonStringIgonErr(data interface{}) string { + return string(JsonByteIgonErr(data)) +} + +func JsonByteIgonErr(data interface{}) []byte { + dataByte, _ := json.Marshal(data) + return dataByte +} + +func IntersectionGeneric[T comparable](slice1, slice2 []T) []T { + m := make(map[T]bool) + result := []T{} + + for _, v := range slice1 { + m[v] = true + } + + for _, v := range slice2 { + if m[v] { + result = append(result, v) + delete(m, v) // 避免重复 + } + } + + return result +} + +func CreateOrderNum(prefix string) string { + code := fmt.Sprintf("%04d", rand.IntN(10000)) + fmt.Println("4位随机数字:", code) // 输出示例: "0837" + return prefix + time.Now().Format("20060102150405") + code +} + +func BuildUpdateMap(obj interface{}, omitFields ...string) map[string]interface{} { + result := make(map[string]interface{}) + omitMap := make(map[string]bool) + for _, f := range omitFields { + omitMap[f] = true + } + + v := reflect.ValueOf(obj) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + t := v.Type() + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + fieldName := t.Field(i).Name + + if omitMap[fieldName] { + continue + } + + // 只处理非 nil 的指针字段 + if field.Kind() == reflect.Ptr && !field.IsNil() { + // 将驼峰转为下划线(可选的,根据你的数据库列名决定) + colName := CamelToSnake(fieldName) + result[colName] = field.Elem().Interface() + } + } + return result +} + +func CamelToSnake(s string) string { + var result []rune + for i, r := range s { + if i > 0 && r >= 'A' && r <= 'Z' { + result = append(result, '_') + } + result = append(result, r) + } + return string(result) +} + +func CopyNonNilFields(src, dst interface{}) error { + config := &mapstructure.DecoderConfig{ + Result: dst, + TagName: "json", + ZeroFields: false, // 重要:不清零目标字段 + Squash: false, + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(src) +} + +func DownloadFile(url string, saveDir string, filename string) (string, error) { + os.MkdirAll(saveDir, 0755) + + if filename == "" { + filename = uuid.New().String() + ".docx" + } + + filePath := filepath.Join(saveDir, filename) + + resp, err := http.Get(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + + out, err := os.Create(filePath) + if err != nil { + return "", err + } + defer out.Close() + + _, err = io.Copy(out, resp.Body) + if err != nil { + return "", err + } + + absPath, _ := filepath.Abs(filePath) + return absPath, nil +} + +func DownloadImage(url string, requestID string, dir string) (string, error) { + os.MkdirAll(dir, 0755) + + ext := filepath.Ext(url) + if ext == "" { + ext = ".jpg" + } + filename := requestID + "_" + uuid.New().String() + ext + filePath := filepath.Join(dir, filename) + + resp, err := http.Get(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + + out, err := os.Create(filePath) + if err != nil { + return "", err + } + defer out.Close() + + _, err = io.Copy(out, resp.Body) + if err != nil { + return "", err + } + + return filepath.Abs(filePath) +} + +func DeleteFile(path string) { + if path != "" { + os.Remove(path) + } +} + +func GenerateUUID() string { + return uuid.New().String() +} + +func GenerateUserIndex() string { + return uuid.New().String()[:20] +} diff --git a/pkg/mapstructure/decode_hooks.go b/pkg/mapstructure/decode_hooks.go new file mode 100644 index 0000000..3a754ca --- /dev/null +++ b/pkg/mapstructure/decode_hooks.go @@ -0,0 +1,279 @@ +package mapstructure + +import ( + "encoding" + "errors" + "fmt" + "net" + "reflect" + "strconv" + "strings" + "time" +) + +// typedDecodeHook takes a raw DecodeHookFunc (an interface{}) and turns +// it into the proper DecodeHookFunc type, such as DecodeHookFuncType. +func typedDecodeHook(h DecodeHookFunc) DecodeHookFunc { + // Create variables here so we can reference them with the reflect pkg + var f1 DecodeHookFuncType + var f2 DecodeHookFuncKind + var f3 DecodeHookFuncValue + + // Fill in the variables into this interface and the rest is done + // automatically using the reflect package. + potential := []interface{}{f1, f2, f3} + + v := reflect.ValueOf(h) + vt := v.Type() + for _, raw := range potential { + pt := reflect.ValueOf(raw).Type() + if vt.ConvertibleTo(pt) { + return v.Convert(pt).Interface() + } + } + + return nil +} + +// DecodeHookExec executes the given decode hook. This should be used +// since it'll naturally degrade to the older backwards compatible DecodeHookFunc +// that took reflect.Kind instead of reflect.Type. +func DecodeHookExec( + raw DecodeHookFunc, + from reflect.Value, to reflect.Value) (interface{}, error) { + + switch f := typedDecodeHook(raw).(type) { + case DecodeHookFuncType: + return f(from.Type(), to.Type(), from.Interface()) + case DecodeHookFuncKind: + return f(from.Kind(), to.Kind(), from.Interface()) + case DecodeHookFuncValue: + return f(from, to) + default: + return nil, errors.New("invalid decode hook signature") + } +} + +// ComposeDecodeHookFunc creates a single DecodeHookFunc that +// automatically composes multiple DecodeHookFuncs. +// +// The composed funcs are called in order, with the result of the +// previous transformation. +func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc { + return func(f reflect.Value, t reflect.Value) (interface{}, error) { + var err error + data := f.Interface() + + newFrom := f + for _, f1 := range fs { + data, err = DecodeHookExec(f1, newFrom, t) + if err != nil { + return nil, err + } + newFrom = reflect.ValueOf(data) + } + + return data, nil + } +} + +// OrComposeDecodeHookFunc executes all input hook functions until one of them returns no error. In that case its value is returned. +// If all hooks return an error, OrComposeDecodeHookFunc returns an error concatenating all error messages. +func OrComposeDecodeHookFunc(ff ...DecodeHookFunc) DecodeHookFunc { + return func(a, b reflect.Value) (interface{}, error) { + var allErrs string + var out interface{} + var err error + + for _, f := range ff { + out, err = DecodeHookExec(f, a, b) + if err != nil { + allErrs += err.Error() + "\n" + continue + } + + return out, nil + } + + return nil, errors.New(allErrs) + } +} + +// StringToSliceHookFunc returns a DecodeHookFunc that converts +// string to []string by splitting on the given sep. +func StringToSliceHookFunc(sep string) DecodeHookFunc { + return func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + if f != reflect.String || t != reflect.Slice { + return data, nil + } + + raw := data.(string) + if raw == "" { + return []string{}, nil + } + + return strings.Split(raw, sep), nil + } +} + +// StringToTimeDurationHookFunc returns a DecodeHookFunc that converts +// strings to time.Duration. +func StringToTimeDurationHookFunc() DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + if t != reflect.TypeOf(time.Duration(5)) { + return data, nil + } + + // Convert it by parsing + return time.ParseDuration(data.(string)) + } +} + +// StringToIPHookFunc returns a DecodeHookFunc that converts +// strings to net.IP +func StringToIPHookFunc() DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + if t != reflect.TypeOf(net.IP{}) { + return data, nil + } + + // Convert it by parsing + ip := net.ParseIP(data.(string)) + if ip == nil { + return net.IP{}, fmt.Errorf("failed parsing ip %v", data) + } + + return ip, nil + } +} + +// StringToIPNetHookFunc returns a DecodeHookFunc that converts +// strings to net.IPNet +func StringToIPNetHookFunc() DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + if t != reflect.TypeOf(net.IPNet{}) { + return data, nil + } + + // Convert it by parsing + _, net, err := net.ParseCIDR(data.(string)) + return net, err + } +} + +// StringToTimeHookFunc returns a DecodeHookFunc that converts +// strings to time.Time. +func StringToTimeHookFunc(layout string) DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + if t != reflect.TypeOf(time.Time{}) { + return data, nil + } + + // Convert it by parsing + return time.Parse(layout, data.(string)) + } +} + +// WeaklyTypedHook is a DecodeHookFunc which adds support for weak typing to +// the decoder. +// +// Note that this is significantly different from the WeaklyTypedInput option +// of the DecoderConfig. +func WeaklyTypedHook( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + dataVal := reflect.ValueOf(data) + switch t { + case reflect.String: + switch f { + case reflect.Bool: + if dataVal.Bool() { + return "1", nil + } + return "0", nil + case reflect.Float32: + return strconv.FormatFloat(dataVal.Float(), 'f', -1, 64), nil + case reflect.Int: + return strconv.FormatInt(dataVal.Int(), 10), nil + case reflect.Slice: + dataType := dataVal.Type() + elemKind := dataType.Elem().Kind() + if elemKind == reflect.Uint8 { + return string(dataVal.Interface().([]uint8)), nil + } + case reflect.Uint: + return strconv.FormatUint(dataVal.Uint(), 10), nil + } + } + + return data, nil +} + +func RecursiveStructToMapHookFunc() DecodeHookFunc { + return func(f reflect.Value, t reflect.Value) (interface{}, error) { + if f.Kind() != reflect.Struct { + return f.Interface(), nil + } + + var i interface{} = struct{}{} + if t.Type() != reflect.TypeOf(&i).Elem() { + return f.Interface(), nil + } + + m := make(map[string]interface{}) + t.Set(reflect.ValueOf(m)) + + return f.Interface(), nil + } +} + +// TextUnmarshallerHookFunc returns a DecodeHookFunc that applies +// strings to the UnmarshalText function, when the target type +// implements the encoding.TextUnmarshaler interface +func TextUnmarshallerHookFunc() DecodeHookFuncType { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + result := reflect.New(t).Interface() + unmarshaller, ok := result.(encoding.TextUnmarshaler) + if !ok { + return data, nil + } + if err := unmarshaller.UnmarshalText([]byte(data.(string))); err != nil { + return nil, err + } + return result, nil + } +} diff --git a/pkg/mapstructure/decode_hooks_test.go b/pkg/mapstructure/decode_hooks_test.go new file mode 100644 index 0000000..07fbedf --- /dev/null +++ b/pkg/mapstructure/decode_hooks_test.go @@ -0,0 +1,567 @@ +package mapstructure + +import ( + "errors" + "math/big" + "net" + "reflect" + "testing" + "time" +) + +func TestComposeDecodeHookFunc(t *testing.T) { + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "foo", nil + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "bar", nil + } + + f := ComposeDecodeHookFunc(f1, f2) + + result, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err != nil { + t.Fatalf("bad: %s", err) + } + if result.(string) != "foobar" { + t.Fatalf("bad: %#v", result) + } +} + +func TestComposeDecodeHookFunc_err(t *testing.T) { + f1 := func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) { + return nil, errors.New("foo") + } + + f2 := func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) { + panic("NOPE") + } + + f := ComposeDecodeHookFunc(f1, f2) + + _, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err.Error() != "foo" { + t.Fatalf("bad: %s", err) + } +} + +func TestComposeDecodeHookFunc_kinds(t *testing.T) { + var f2From reflect.Kind + + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return int(42), nil + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + f2From = f + return data, nil + } + + f := ComposeDecodeHookFunc(f1, f2) + + _, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err != nil { + t.Fatalf("bad: %s", err) + } + if f2From != reflect.Int { + t.Fatalf("bad: %#v", f2From) + } +} + +func TestOrComposeDecodeHookFunc(t *testing.T) { + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "foo", nil + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "bar", nil + } + + f := OrComposeDecodeHookFunc(f1, f2) + + result, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err != nil { + t.Fatalf("bad: %s", err) + } + if result.(string) != "foo" { + t.Fatalf("bad: %#v", result) + } +} + +func TestOrComposeDecodeHookFunc_correctValueIsLast(t *testing.T) { + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return nil, errors.New("f1 error") + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return nil, errors.New("f2 error") + } + + f3 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "bar", nil + } + + f := OrComposeDecodeHookFunc(f1, f2, f3) + + result, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err != nil { + t.Fatalf("bad: %s", err) + } + if result.(string) != "bar" { + t.Fatalf("bad: %#v", result) + } +} + +func TestOrComposeDecodeHookFunc_err(t *testing.T) { + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return nil, errors.New("f1 error") + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return nil, errors.New("f2 error") + } + + f := OrComposeDecodeHookFunc(f1, f2) + + _, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err == nil { + t.Fatalf("bad: should return an error") + } + if err.Error() != "f1 error\nf2 error\n" { + t.Fatalf("bad: %s", err) + } +} + +func TestComposeDecodeHookFunc_safe_nofuncs(t *testing.T) { + f := ComposeDecodeHookFunc() + type myStruct2 struct { + MyInt int + } + + type myStruct1 struct { + Blah map[string]myStruct2 + } + + src := &myStruct1{Blah: map[string]myStruct2{ + "test": { + MyInt: 1, + }, + }} + + dst := &myStruct1{} + dConf := &DecoderConfig{ + Result: dst, + ErrorUnused: true, + DecodeHook: f, + } + d, err := NewDecoder(dConf) + if err != nil { + t.Fatal(err) + } + err = d.Decode(src) + if err != nil { + t.Fatal(err) + } +} + +func TestStringToSliceHookFunc(t *testing.T) { + f := StringToSliceHookFunc(",") + + strValue := reflect.ValueOf("42") + sliceValue := reflect.ValueOf([]byte("42")) + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + {sliceValue, sliceValue, []byte("42"), false}, + {strValue, strValue, "42", false}, + { + reflect.ValueOf("foo,bar,baz"), + sliceValue, + []string{"foo", "bar", "baz"}, + false, + }, + { + reflect.ValueOf(""), + sliceValue, + []string{}, + false, + }, + } + + for i, tc := range cases { + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + +func TestStringToTimeDurationHookFunc(t *testing.T) { + f := StringToTimeDurationHookFunc() + + timeValue := reflect.ValueOf(time.Duration(5)) + strValue := reflect.ValueOf("") + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + {reflect.ValueOf("5s"), timeValue, 5 * time.Second, false}, + {reflect.ValueOf("5"), timeValue, time.Duration(0), true}, + {reflect.ValueOf("5"), strValue, "5", false}, + } + + for i, tc := range cases { + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + +func TestStringToTimeHookFunc(t *testing.T) { + strValue := reflect.ValueOf("5") + timeValue := reflect.ValueOf(time.Time{}) + cases := []struct { + f, t reflect.Value + layout string + result interface{} + err bool + }{ + {reflect.ValueOf("2006-01-02T15:04:05Z"), timeValue, time.RFC3339, + time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC), false}, + {strValue, timeValue, time.RFC3339, time.Time{}, true}, + {strValue, strValue, time.RFC3339, "5", false}, + } + + for i, tc := range cases { + f := StringToTimeHookFunc(tc.layout) + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + +func TestStringToIPHookFunc(t *testing.T) { + strValue := reflect.ValueOf("5") + ipValue := reflect.ValueOf(net.IP{}) + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + {reflect.ValueOf("1.2.3.4"), ipValue, + net.IPv4(0x01, 0x02, 0x03, 0x04), false}, + {strValue, ipValue, net.IP{}, true}, + {strValue, strValue, "5", false}, + } + + for i, tc := range cases { + f := StringToIPHookFunc() + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + +func TestStringToIPNetHookFunc(t *testing.T) { + strValue := reflect.ValueOf("5") + ipNetValue := reflect.ValueOf(net.IPNet{}) + var nilNet *net.IPNet = nil + + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + {reflect.ValueOf("1.2.3.4/24"), ipNetValue, + &net.IPNet{ + IP: net.IP{0x01, 0x02, 0x03, 0x00}, + Mask: net.IPv4Mask(0xff, 0xff, 0xff, 0x00), + }, false}, + {strValue, ipNetValue, nilNet, true}, + {strValue, strValue, "5", false}, + } + + for i, tc := range cases { + f := StringToIPNetHookFunc() + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + +func TestWeaklyTypedHook(t *testing.T) { + var f DecodeHookFunc = WeaklyTypedHook + + strValue := reflect.ValueOf("") + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + // TO STRING + { + reflect.ValueOf(false), + strValue, + "0", + false, + }, + + { + reflect.ValueOf(true), + strValue, + "1", + false, + }, + + { + reflect.ValueOf(float32(7)), + strValue, + "7", + false, + }, + + { + reflect.ValueOf(int(7)), + strValue, + "7", + false, + }, + + { + reflect.ValueOf([]uint8("foo")), + strValue, + "foo", + false, + }, + + { + reflect.ValueOf(uint(7)), + strValue, + "7", + false, + }, + } + + for i, tc := range cases { + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + +func TestStructToMapHookFuncTabled(t *testing.T) { + var f DecodeHookFunc = RecursiveStructToMapHookFunc() + + type b struct { + TestKey string + } + + type a struct { + Sub b + } + + testStruct := a{ + Sub: b{ + TestKey: "testval", + }, + } + + testMap := map[string]interface{}{ + "Sub": map[string]interface{}{ + "TestKey": "testval", + }, + } + + cases := []struct { + name string + receiver interface{} + input interface{} + expected interface{} + err bool + }{ + { + "map receiver", + func() interface{} { + var res map[string]interface{} + return &res + }(), + testStruct, + &testMap, + false, + }, + { + "interface receiver", + func() interface{} { + var res interface{} + return &res + }(), + testStruct, + func() interface{} { + var exp interface{} = testMap + return &exp + }(), + false, + }, + { + "slice receiver errors", + func() interface{} { + var res []string + return &res + }(), + testStruct, + new([]string), + true, + }, + { + "slice to slice - no change", + func() interface{} { + var res []string + return &res + }(), + []string{"a", "b"}, + &[]string{"a", "b"}, + false, + }, + { + "string to string - no change", + func() interface{} { + var res string + return &res + }(), + "test", + func() *string { + s := "test" + return &s + }(), + false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cfg := &DecoderConfig{ + DecodeHook: f, + Result: tc.receiver, + } + + d, err := NewDecoder(cfg) + if err != nil { + t.Fatalf("unexpected jderr %#v", err) + } + + err = d.Decode(tc.input) + if tc.err != (err != nil) { + t.Fatalf("expected jderr %#v", err) + } + + if !reflect.DeepEqual(tc.expected, tc.receiver) { + t.Fatalf("expected %#v, got %#v", + tc.expected, tc.receiver) + } + }) + + } +} + +func TestTextUnmarshallerHookFunc(t *testing.T) { + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + {reflect.ValueOf("42"), reflect.ValueOf(big.Int{}), big.NewInt(42), false}, + {reflect.ValueOf("invalid"), reflect.ValueOf(big.Int{}), nil, true}, + {reflect.ValueOf("5"), reflect.ValueOf("5"), "5", false}, + } + + for i, tc := range cases { + f := TextUnmarshallerHookFunc() + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} diff --git a/pkg/mapstructure/error.go b/pkg/mapstructure/error.go new file mode 100644 index 0000000..47a99e5 --- /dev/null +++ b/pkg/mapstructure/error.go @@ -0,0 +1,50 @@ +package mapstructure + +import ( + "errors" + "fmt" + "sort" + "strings" +) + +// Error implements the error interface and can represents multiple +// errors that occur in the course of a single decode. +type Error struct { + Errors []string +} + +func (e *Error) Error() string { + points := make([]string, len(e.Errors)) + for i, err := range e.Errors { + points[i] = fmt.Sprintf("* %s", err) + } + + sort.Strings(points) + return fmt.Sprintf( + "%d error(s) decoding:\n\n%s", + len(e.Errors), strings.Join(points, "\n")) +} + +// WrappedErrors implements the errwrap.Wrapper interface to make this +// return value more useful with the errwrap and go-multierror libraries. +func (e *Error) WrappedErrors() []error { + if e == nil { + return nil + } + + result := make([]error, len(e.Errors)) + for i, e := range e.Errors { + result[i] = errors.New(e) + } + + return result +} + +func appendErrors(errors []string, err error) []string { + switch e := err.(type) { + case *Error: + return append(errors, e.Errors...) + default: + return append(errors, e.Error()) + } +} diff --git a/pkg/mapstructure/mapstructure.go b/pkg/mapstructure/mapstructure.go new file mode 100644 index 0000000..0d26c75 --- /dev/null +++ b/pkg/mapstructure/mapstructure.go @@ -0,0 +1,1386 @@ +package mapstructure + +import ( + "encoding/json" + "errors" + "fmt" + "reflect" + "sort" + "strconv" + "strings" +) + +// DecodeHookFunc is the callback function that can be used for +// data transformations. See "DecodeHook" in the DecoderConfig +// struct. +// +// The type must be one of DecodeHookFuncType, DecodeHookFuncKind, or +// DecodeHookFuncValue. +// Values are a superset of Types (Values can return types), and Types are a +// superset of Kinds (Types can return Kinds) and are generally a richer thing +// to use, but Kinds are simpler if you only need those. +// +// The reason DecodeHookFunc is multi-typed is for backwards compatibility: +// we started with Kinds and then realized Types were the better solution, +// but have a promise to not break backwards compat so we now support +// both. +type DecodeHookFunc interface{} + +// DecodeHookFuncType is a DecodeHookFunc which has complete information about +// the source and target types. +type DecodeHookFuncType func(reflect.Type, reflect.Type, interface{}) (interface{}, error) + +// DecodeHookFuncKind is a DecodeHookFunc which knows only the Kinds of the +// source and target types. +type DecodeHookFuncKind func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) + +// DecodeHookFuncValue is a DecodeHookFunc which has complete access to both the source and target +// values. +type DecodeHookFuncValue func(from reflect.Value, to reflect.Value) (interface{}, error) + +// DecoderConfig is the configuration that is used to create a new decoder +// and allows customization of various aspects of decoding. +type DecoderConfig struct { + // DecodeHook, if set, will be called before any decoding and any + // type conversion (if WeaklyTypedInput is on). This lets you modify + // the values before they're set down onto the resulting struct. The + // DecodeHook is called for every map and value in the input. This means + // that if a struct has embedded fields with squash tags the decode hook + // is called only once with all of the input data, not once for each + // embedded struct. + // + // If an error is returned, the entire decode will fail with that error. + DecodeHook DecodeHookFunc + + // If ErrorUnused is true, then it is an error for there to exist + // keys in the original map that were unused in the decoding process + // (extra keys). + ErrorUnused bool + + // If ErrorUnset is true, then it is an error for there to exist + // fields in the result that were not set in the decoding process + // (extra fields). This only applies to decoding to a struct. This + // will affect all nested structs as well. + ErrorUnset bool + + // ZeroFields, if set to true, will zero fields before writing them. + // For example, a map will be emptied before decoded values are put in + // it. If this is false, a map will be merged. + ZeroFields bool + + // If WeaklyTypedInput is true, the decoder will make the following + // "weak" conversions: + // + // - bools to string (true = "1", false = "0") + // - numbers to string (base 10) + // - bools to int/uint (true = 1, false = 0) + // - strings to int/uint (base implied by prefix) + // - int to bool (true if value != 0) + // - string to bool (accepts: 1, t, T, TRUE, true, True, 0, f, F, + // FALSE, false, False. Anything else is an error) + // - empty array = empty map and vice versa + // - negative numbers to overflowed uint values (base 10) + // - slice of maps to a merged map + // - single values are converted to slices if required. Each + // element is weakly decoded. For example: "4" can become []int{4} + // if the target type is an int slice. + // + WeaklyTypedInput bool + + // Squash will squash embedded structs. A squash tag may also be + // added to an individual struct field using a tag. For example: + // + // type Parent struct { + // Child `mapstructure:",squash"` + // } + Squash bool + + // Metadata is the struct that will contain extra metadata about + // the decoding. If this is nil, then no metadata will be tracked. + Metadata *Metadata + + // Result is a pointer to the struct that will contain the decoded + // value. + Result interface{} + + // The tag name that mapstructure reads for field names. This + // defaults to "mapstructure" + TagName string + + // IgnoreUntaggedFields ignores all struct fields without explicit + // TagName, comparable to `mapstructure:"-"` as default behaviour. + IgnoreUntaggedFields bool + + // MatchName is the function used to match the map key to the struct + // field name or tag. Defaults to `strings.EqualFold`. This can be used + // to implement case-sensitive tag values, support snake casing, etc. + MatchName func(mapKey, fieldName string) bool +} + +// A Decoder takes a raw interface value and turns it into structured +// data, keeping track of rich error information along the way in case +// anything goes wrong. Unlike the basic top-level Decode method, you can +// more finely control how the Decoder behaves using the DecoderConfig +// structure. The top-level Decode method is just a convenience that sets +// up the most basic Decoder. +type Decoder struct { + config *DecoderConfig +} + +// Metadata contains information about decoding a structure that +// is tedious or difficult to get otherwise. +type Metadata struct { + // Keys are the keys of the structure which were successfully decoded + Keys []string + + // Unused is a slice of keys that were found in the raw value but + // weren't decoded since there was no matching field in the result interface + Unused []string + + // Unset is a slice of field names that were found in the result interface + // but weren't set in the decoding process since there was no matching value + // in the input + Unset []string +} + +// Decode takes an input structure and uses reflection to translate it to +// the output structure. output must be a pointer to a map or struct. +func Decode(input interface{}, output interface{}) error { + config := &DecoderConfig{ + Metadata: nil, + Result: output, + } + + decoder, err := NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +// WeakDecode is the same as Decode but is shorthand to enable +// WeaklyTypedInput. See DecoderConfig for more info. +func WeakDecode(input, output interface{}) error { + config := &DecoderConfig{ + Metadata: nil, + Result: output, + WeaklyTypedInput: true, + } + + decoder, err := NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +// DecodeMetadata is the same as Decode, but is shorthand to +// enable metadata collection. See DecoderConfig for more info. +func DecodeMetadata(input interface{}, output interface{}, metadata *Metadata) error { + config := &DecoderConfig{ + Metadata: metadata, + Result: output, + } + + decoder, err := NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +// WeakDecodeMetadata is the same as Decode, but is shorthand to +// enable both WeaklyTypedInput and metadata collection. See +// DecoderConfig for more info. +func WeakDecodeMetadata(input interface{}, output interface{}, metadata *Metadata) error { + config := &DecoderConfig{ + Metadata: metadata, + Result: output, + WeaklyTypedInput: true, + } + + decoder, err := NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +// NewDecoder returns a new decoder for the given configuration. Once +// a decoder has been returned, the same configuration must not be used +// again. +func NewDecoder(config *DecoderConfig) (*Decoder, error) { + val := reflect.ValueOf(config.Result) + if val.Kind() != reflect.Ptr { + return nil, errors.New("result must be a pointer") + } + + val = val.Elem() + if !val.CanAddr() { + return nil, errors.New("result must be addressable (a pointer)") + } + + if config.Metadata != nil { + if config.Metadata.Keys == nil { + config.Metadata.Keys = make([]string, 0) + } + + if config.Metadata.Unused == nil { + config.Metadata.Unused = make([]string, 0) + } + + if config.Metadata.Unset == nil { + config.Metadata.Unset = make([]string, 0) + } + } + + if config.TagName == "" { + config.TagName = "mapstructure" + } + + if config.MatchName == nil { + config.MatchName = strings.EqualFold + } + + result := &Decoder{ + config: config, + } + + return result, nil +} + +// Decode decodes the given raw interface to the target pointer specified +// by the configuration. +func (d *Decoder) Decode(input interface{}) error { + return d.decode("", input, reflect.ValueOf(d.config.Result).Elem()) +} + +// Decodes an unknown data type into a specific reflection value. +func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) error { + var inputVal reflect.Value + if input != nil { + inputVal = reflect.ValueOf(input) + + // We need to check here if input is a typed nil. Typed nils won't + // match the "input == nil" below so we check that here. + if inputVal.Kind() == reflect.Ptr && inputVal.IsNil() { + input = nil + } + } + + if input == nil { + // If the data is nil, then we don't set anything, unless ZeroFields is set + // to true. + if d.config.ZeroFields { + outVal.Set(reflect.Zero(outVal.Type())) + + if d.config.Metadata != nil && name != "" { + d.config.Metadata.Keys = append(d.config.Metadata.Keys, name) + } + } + return nil + } + + if !inputVal.IsValid() { + // If the input value is invalid, then we just set the value + // to be the zero value. + outVal.Set(reflect.Zero(outVal.Type())) + if d.config.Metadata != nil && name != "" { + d.config.Metadata.Keys = append(d.config.Metadata.Keys, name) + } + return nil + } + + if d.config.DecodeHook != nil { + // We have a DecodeHook, so let's pre-process the input. + var err error + input, err = DecodeHookExec(d.config.DecodeHook, inputVal, outVal) + if err != nil { + return fmt.Errorf("error decoding '%s': %s", name, err) + } + } + + var err error + outputKind := getKind(outVal) + addMetaKey := true + switch outputKind { + case reflect.Bool: + err = d.decodeBool(name, input, outVal) + case reflect.Interface: + err = d.decodeBasic(name, input, outVal) + case reflect.String: + err = d.decodeString(name, input, outVal) + case reflect.Int: + err = d.decodeInt(name, input, outVal) + case reflect.Uint: + err = d.decodeUint(name, input, outVal) + case reflect.Float32: + err = d.decodeFloat(name, input, outVal) + case reflect.Struct: + err = d.decodeStruct(name, input, outVal) + case reflect.Map: + err = d.decodeMap(name, input, outVal) + case reflect.Ptr: + addMetaKey, err = d.decodePtr(name, input, outVal) + case reflect.Slice: + err = d.decodeSlice(name, input, outVal) + case reflect.Array: + err = d.decodeArray(name, input, outVal) + case reflect.Func: + err = d.decodeFunc(name, input, outVal) + default: + // If we reached this point then we weren't able to decode it + return fmt.Errorf("%s: unsupported type: %s", name, outputKind) + } + + // If we reached here, then we successfully decoded SOMETHING, so + // mark the key as used if we're tracking metainput. + if addMetaKey && d.config.Metadata != nil && name != "" { + d.config.Metadata.Keys = append(d.config.Metadata.Keys, name) + } + + return err +} + +// This decodes a basic type (bool, int, string, etc.) and sets the +// value to "data" of that type. +func (d *Decoder) decodeBasic(name string, data interface{}, val reflect.Value) error { + if val.IsValid() && val.Elem().IsValid() { + elem := val.Elem() + + // If we can't address this element, then its not writable. Instead, + // we make a copy of the value (which is a pointer and therefore + // writable), decode into that, and replace the whole value. + copied := false + if !elem.CanAddr() { + copied = true + + // Make *T + copy := reflect.New(elem.Type()) + + // *T = elem + copy.Elem().Set(elem) + + // Set elem so we decode into it + elem = copy + } + + // Decode. If we have an error then return. We also return right + // away if we're not a copy because that means we decoded directly. + if err := d.decode(name, data, elem); err != nil || !copied { + return err + } + + // If we're a copy, we need to set te final result + val.Set(elem.Elem()) + return nil + } + + dataVal := reflect.ValueOf(data) + + // If the input data is a pointer, and the assigned type is the dereference + // of that exact pointer, then indirect it so that we can assign it. + // Example: *string to string + if dataVal.Kind() == reflect.Ptr && dataVal.Type().Elem() == val.Type() { + dataVal = reflect.Indirect(dataVal) + } + + if !dataVal.IsValid() { + dataVal = reflect.Zero(val.Type()) + } + + dataValType := dataVal.Type() + if !dataValType.AssignableTo(val.Type()) { + return fmt.Errorf( + "'%s' expected type '%s', got '%s'", + name, val.Type(), dataValType) + } + + val.Set(dataVal) + return nil +} + +func (d *Decoder) decodeString(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataKind := getKind(dataVal) + + converted := true + switch { + case dataKind == reflect.String: + val.SetString(dataVal.String()) + case dataKind == reflect.Bool && d.config.WeaklyTypedInput: + if dataVal.Bool() { + val.SetString("1") + } else { + val.SetString("0") + } + case dataKind == reflect.Int && d.config.WeaklyTypedInput: + val.SetString(strconv.FormatInt(dataVal.Int(), 10)) + case dataKind == reflect.Uint && d.config.WeaklyTypedInput: + val.SetString(strconv.FormatUint(dataVal.Uint(), 10)) + case dataKind == reflect.Float32 && d.config.WeaklyTypedInput: + val.SetString(strconv.FormatFloat(dataVal.Float(), 'f', -1, 64)) + case dataKind == reflect.Slice && d.config.WeaklyTypedInput, + dataKind == reflect.Array && d.config.WeaklyTypedInput: + dataType := dataVal.Type() + elemKind := dataType.Elem().Kind() + switch elemKind { + case reflect.Uint8: + var uints []uint8 + if dataKind == reflect.Array { + uints = make([]uint8, dataVal.Len(), dataVal.Len()) + for i := range uints { + uints[i] = dataVal.Index(i).Interface().(uint8) + } + } else { + uints = dataVal.Interface().([]uint8) + } + val.SetString(string(uints)) + default: + converted = false + } + default: + converted = false + } + + if !converted { + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", + name, val.Type(), dataVal.Type(), data) + } + + return nil +} + +func (d *Decoder) decodeInt(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataKind := getKind(dataVal) + dataType := dataVal.Type() + + switch { + case dataKind == reflect.Int: + val.SetInt(dataVal.Int()) + case dataKind == reflect.Uint: + val.SetInt(int64(dataVal.Uint())) + case dataKind == reflect.Float32: + val.SetInt(int64(dataVal.Float())) + case dataKind == reflect.Bool && d.config.WeaklyTypedInput: + if dataVal.Bool() { + val.SetInt(1) + } else { + val.SetInt(0) + } + case dataKind == reflect.String && d.config.WeaklyTypedInput: + str := dataVal.String() + if str == "" { + str = "0" + } + + i, err := strconv.ParseInt(str, 0, val.Type().Bits()) + if err == nil { + val.SetInt(i) + } else { + return fmt.Errorf("cannot parse '%s' as int: %s", name, err) + } + case dataType.PkgPath() == "encoding/json" && dataType.Name() == "Number": + jn := data.(json.Number) + i, err := jn.Int64() + if err != nil { + return fmt.Errorf( + "error decoding json.Number into %s: %s", name, err) + } + val.SetInt(i) + default: + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", + name, val.Type(), dataVal.Type(), data) + } + + return nil +} + +func (d *Decoder) decodeUint(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataKind := getKind(dataVal) + dataType := dataVal.Type() + + switch { + case dataKind == reflect.Int: + i := dataVal.Int() + if i < 0 && !d.config.WeaklyTypedInput { + return fmt.Errorf("cannot parse '%s', %d overflows uint", + name, i) + } + val.SetUint(uint64(i)) + case dataKind == reflect.Uint: + val.SetUint(dataVal.Uint()) + case dataKind == reflect.Float32: + f := dataVal.Float() + if f < 0 && !d.config.WeaklyTypedInput { + return fmt.Errorf("cannot parse '%s', %f overflows uint", + name, f) + } + val.SetUint(uint64(f)) + case dataKind == reflect.Bool && d.config.WeaklyTypedInput: + if dataVal.Bool() { + val.SetUint(1) + } else { + val.SetUint(0) + } + case dataKind == reflect.String && d.config.WeaklyTypedInput: + str := dataVal.String() + if str == "" { + str = "0" + } + + i, err := strconv.ParseUint(str, 0, val.Type().Bits()) + if err == nil { + val.SetUint(i) + } else { + return fmt.Errorf("cannot parse '%s' as uint: %s", name, err) + } + case dataType.PkgPath() == "encoding/json" && dataType.Name() == "Number": + jn := data.(json.Number) + i, err := strconv.ParseUint(string(jn), 0, 64) + if err != nil { + return fmt.Errorf( + "error decoding json.Number into %s: %s", name, err) + } + val.SetUint(i) + default: + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", + name, val.Type(), dataVal.Type(), data) + } + + return nil +} + +func (d *Decoder) decodeBool(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataKind := getKind(dataVal) + + switch { + case dataKind == reflect.Bool: + val.SetBool(dataVal.Bool()) + case dataKind == reflect.Int && d.config.WeaklyTypedInput: + val.SetBool(dataVal.Int() != 0) + case dataKind == reflect.Uint && d.config.WeaklyTypedInput: + val.SetBool(dataVal.Uint() != 0) + case dataKind == reflect.Float32 && d.config.WeaklyTypedInput: + val.SetBool(dataVal.Float() != 0) + case dataKind == reflect.String && d.config.WeaklyTypedInput: + b, err := strconv.ParseBool(dataVal.String()) + if err == nil { + val.SetBool(b) + } else if dataVal.String() == "" { + val.SetBool(false) + } else { + return fmt.Errorf("cannot parse '%s' as bool: %s", name, err) + } + default: + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", + name, val.Type(), dataVal.Type(), data) + } + + return nil +} + +func (d *Decoder) decodeFloat(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataKind := getKind(dataVal) + dataType := dataVal.Type() + + switch { + case dataKind == reflect.Int: + val.SetFloat(float64(dataVal.Int())) + case dataKind == reflect.Uint: + val.SetFloat(float64(dataVal.Uint())) + case dataKind == reflect.Float32: + val.SetFloat(dataVal.Float()) + case dataKind == reflect.Bool && d.config.WeaklyTypedInput: + if dataVal.Bool() { + val.SetFloat(1) + } else { + val.SetFloat(0) + } + case dataKind == reflect.String && d.config.WeaklyTypedInput: + str := dataVal.String() + if str == "" { + str = "0" + } + + f, err := strconv.ParseFloat(str, val.Type().Bits()) + if err == nil { + val.SetFloat(f) + } else { + return fmt.Errorf("cannot parse '%s' as float: %s", name, err) + } + case dataType.PkgPath() == "encoding/json" && dataType.Name() == "Number": + jn := data.(json.Number) + i, err := jn.Float64() + if err != nil { + return fmt.Errorf( + "error decoding json.Number into %s: %s", name, err) + } + val.SetFloat(i) + default: + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", + name, val.Type(), dataVal.Type(), data) + } + + return nil +} + +func (d *Decoder) decodeMap(name string, data interface{}, val reflect.Value) error { + valType := val.Type() + valKeyType := valType.Key() + valElemType := valType.Elem() + + // By default we overwrite keys in the current map + valMap := val + + // If the map is nil or we're purposely zeroing fields, make a new map + if valMap.IsNil() || d.config.ZeroFields { + // Make a new map to hold our result + mapType := reflect.MapOf(valKeyType, valElemType) + valMap = reflect.MakeMap(mapType) + } + + // Check input type and based on the input type jump to the proper func + dataVal := reflect.Indirect(reflect.ValueOf(data)) + switch dataVal.Kind() { + case reflect.Map: + return d.decodeMapFromMap(name, dataVal, val, valMap) + + case reflect.Struct: + return d.decodeMapFromStruct(name, dataVal, val, valMap, data) + + case reflect.Array, reflect.Slice: + if d.config.WeaklyTypedInput { + return d.decodeMapFromSlice(name, dataVal, val, valMap) + } + + fallthrough + + default: + return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind()) + } +} + +func (d *Decoder) decodeMapFromSlice(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error { + // Special case for BC reasons (covered by tests) + if dataVal.Len() == 0 { + val.Set(valMap) + return nil + } + + for i := 0; i < dataVal.Len(); i++ { + err := d.decode( + name+"["+strconv.Itoa(i)+"]", + dataVal.Index(i).Interface(), val) + if err != nil { + return err + } + } + + return nil +} + +func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error { + valType := val.Type() + valKeyType := valType.Key() + valElemType := valType.Elem() + + // Accumulate errors + errors := make([]string, 0) + + // If the input data is empty, then we just match what the input data is. + if dataVal.Len() == 0 { + if dataVal.IsNil() { + if !val.IsNil() { + val.Set(dataVal) + } + } else { + // Set to empty allocated value + val.Set(valMap) + } + + return nil + } + + for _, k := range dataVal.MapKeys() { + fieldName := name + "[" + k.String() + "]" + + // First decode the key into the proper type + currentKey := reflect.Indirect(reflect.New(valKeyType)) + if err := d.decode(fieldName, k.Interface(), currentKey); err != nil { + errors = appendErrors(errors, err) + continue + } + + // Next decode the data into the proper type + v := dataVal.MapIndex(k).Interface() + currentVal := reflect.Indirect(reflect.New(valElemType)) + if err := d.decode(fieldName, v, currentVal); err != nil { + errors = appendErrors(errors, err) + continue + } + + valMap.SetMapIndex(currentKey, currentVal) + } + + // Set the built up map to the value + val.Set(valMap) + + // If we had errors, return those + if len(errors) > 0 { + return &Error{errors} + } + + return nil +} + +func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value, inData interface{}) error { + typ := dataVal.Type() + + for i := 0; i < typ.NumField(); i++ { + // Get the StructField first since this is a cheap operation. If the + // field is unexported, then ignore it. + f := typ.Field(i) + if f.PkgPath != "" { + continue + } + + // Next get the actual value of this field and verify it is assignable + // to the map value. + v := dataVal.Field(i) + if !v.Type().AssignableTo(valMap.Type().Elem()) { + return fmt.Errorf("cannot assign type '%s' to map value field of type '%s'", v.Type(), valMap.Type().Elem()) + } + + tagValue := f.Tag.Get(d.config.TagName) + keyName := f.Name + + if tagValue == "" && d.config.IgnoreUntaggedFields { + continue + } + + // If Squash is set in the config, we squash the field down. + squash := d.config.Squash && v.Kind() == reflect.Struct && f.Anonymous + + v = dereferencePtrToStructIfNeeded(v, d.config.TagName) + + // Determine the name of the key in the map + if index := strings.Index(tagValue, ","); index != -1 { + if tagValue[:index] == "-" { + continue + } + // If "omitempty" is specified in the tag, it ignores empty values. + if strings.Index(tagValue[index+1:], "omitempty") != -1 && isEmptyValue(v) { + continue + } + + // If "squash" is specified in the tag, we squash the field down. + squash = squash || strings.Index(tagValue[index+1:], "squash") != -1 + if squash { + // When squashing, the embedded type can be a pointer to a struct. + if v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Struct { + v = v.Elem() + } + + // The final type must be a struct + if v.Kind() != reflect.Struct { + return fmt.Errorf("cannot squash non-struct type '%s'", v.Type()) + } + } + if keyNameTagValue := tagValue[:index]; keyNameTagValue != "" { + keyName = keyNameTagValue + } + } else if len(tagValue) > 0 { + if tagValue == "-" { + continue + } + keyName = tagValue + } + + switch v.Kind() { + // this is an embedded struct, so handle it differently + case reflect.Struct: + x := reflect.New(v.Type()) + x.Elem().Set(v) + + vType := valMap.Type() + vKeyType := vType.Key() + vElemType := vType.Elem() + mType := reflect.MapOf(vKeyType, vElemType) + vMap := reflect.MakeMap(mType) + + // Creating a pointer to a map so that other methods can completely + // overwrite the map if need be (looking at you decodeMapFromMap). The + // indirection allows the underlying map to be settable (CanSet() == true) + // where as reflect.MakeMap returns an unsettable map. + addrVal := reflect.New(vMap.Type()) + reflect.Indirect(addrVal).Set(vMap) + + err := d.decode(keyName, x.Interface(), reflect.Indirect(addrVal)) + if err != nil { + return err + } + + // the underlying map may have been completely overwritten so pull + // it indirectly out of the enclosing value. + vMap = reflect.Indirect(addrVal) + + if squash { + for _, k := range vMap.MapKeys() { + valMap.SetMapIndex(k, vMap.MapIndex(k)) + } + } else { + valMap.SetMapIndex(reflect.ValueOf(keyName), vMap) + } + + default: + valMap.SetMapIndex(reflect.ValueOf(keyName), v) + } + + } + + if val.CanAddr() { + val.Set(valMap) + } + + return nil +} + +func (d *Decoder) decodePtr(name string, data interface{}, val reflect.Value) (bool, error) { + // If the input data is nil, then we want to just set the output + // pointer to be nil as well. + isNil := data == nil + if !isNil { + switch v := reflect.Indirect(reflect.ValueOf(data)); v.Kind() { + case reflect.Chan, + reflect.Func, + reflect.Interface, + reflect.Map, + reflect.Ptr, + reflect.Slice: + isNil = v.IsNil() + } + } + if isNil { + if !val.IsNil() && val.CanSet() { + nilValue := reflect.New(val.Type()).Elem() + val.Set(nilValue) + } + + return true, nil + } + + // Create an element of the concrete (non pointer) type and decode + // into that. Then set the value of the pointer to this type. + valType := val.Type() + valElemType := valType.Elem() + if val.CanSet() { + realVal := val + if realVal.IsNil() || d.config.ZeroFields { + realVal = reflect.New(valElemType) + } + + if err := d.decode(name, data, reflect.Indirect(realVal)); err != nil { + // 报错情况下依旧设置指针 + val.Set(realVal) + return false, err + } + + val.Set(realVal) + } else { + if err := d.decode(name, data, reflect.Indirect(val)); err != nil { + return false, err + } + } + return false, nil +} + +func (d *Decoder) decodeFunc(name string, data interface{}, val reflect.Value) error { + // Create an element of the concrete (non pointer) type and decode + // into that. Then set the value of the pointer to this type. + dataVal := reflect.Indirect(reflect.ValueOf(data)) + if val.Type() != dataVal.Type() { + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", + name, val.Type(), dataVal.Type(), data) + } + val.Set(dataVal) + return nil +} + +func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataValKind := dataVal.Kind() + valType := val.Type() + valElemType := valType.Elem() + sliceType := reflect.SliceOf(valElemType) + + // If we have a non array/slice type then we first attempt to convert. + if dataValKind != reflect.Array && dataValKind != reflect.Slice { + if d.config.WeaklyTypedInput { + switch { + // Slice and array we use the normal logic + case dataValKind == reflect.Slice, dataValKind == reflect.Array: + break + + // Empty maps turn into empty slices + case dataValKind == reflect.Map: + if dataVal.Len() == 0 { + val.Set(reflect.MakeSlice(sliceType, 0, 0)) + return nil + } + // Create slice of maps of other sizes + return d.decodeSlice(name, []interface{}{data}, val) + + case dataValKind == reflect.String && valElemType.Kind() == reflect.Uint8: + return d.decodeSlice(name, []byte(dataVal.String()), val) + + // All other types we try to convert to the slice type + // and "lift" it into it. i.e. a string becomes a string slice. + default: + // Just re-try this function with data as a slice. + return d.decodeSlice(name, []interface{}{data}, val) + } + } + + return fmt.Errorf( + "'%s': source data must be an array or slice, got %s", name, dataValKind) + } + + // If the input value is nil, then don't allocate since empty != nil + if dataValKind != reflect.Array && dataVal.IsNil() { + return nil + } + + valSlice := val + if valSlice.IsNil() || d.config.ZeroFields { + // Make a new slice to hold our result, same size as the original data. + valSlice = reflect.MakeSlice(sliceType, dataVal.Len(), dataVal.Len()) + } + + // Accumulate any errors + errors := make([]string, 0) + + for i := 0; i < dataVal.Len(); i++ { + currentData := dataVal.Index(i).Interface() + for valSlice.Len() <= i { + valSlice = reflect.Append(valSlice, reflect.Zero(valElemType)) + } + currentField := valSlice.Index(i) + + fieldName := name + "[" + strconv.Itoa(i) + "]" + if err := d.decode(fieldName, currentData, currentField); err != nil { + errors = appendErrors(errors, err) + } + } + + // Finally, set the value to the slice we built up + val.Set(valSlice) + + // If there were errors, we return those + if len(errors) > 0 { + return &Error{errors} + } + + return nil +} + +func (d *Decoder) decodeArray(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataValKind := dataVal.Kind() + valType := val.Type() + valElemType := valType.Elem() + arrayType := reflect.ArrayOf(valType.Len(), valElemType) + + valArray := val + + if valArray.Interface() == reflect.Zero(valArray.Type()).Interface() || d.config.ZeroFields { + // Check input type + if dataValKind != reflect.Array && dataValKind != reflect.Slice { + if d.config.WeaklyTypedInput { + switch { + // Empty maps turn into empty arrays + case dataValKind == reflect.Map: + if dataVal.Len() == 0 { + val.Set(reflect.Zero(arrayType)) + return nil + } + + // All other types we try to convert to the array type + // and "lift" it into it. i.e. a string becomes a string array. + default: + // Just re-try this function with data as a slice. + return d.decodeArray(name, []interface{}{data}, val) + } + } + + return fmt.Errorf( + "'%s': source data must be an array or slice, got %s", name, dataValKind) + + } + if dataVal.Len() > arrayType.Len() { + return fmt.Errorf( + "'%s': expected source data to have length less or equal to %d, got %d", name, arrayType.Len(), dataVal.Len()) + + } + + // Make a new array to hold our result, same size as the original data. + valArray = reflect.New(arrayType).Elem() + } + + // Accumulate any errors + errors := make([]string, 0) + + for i := 0; i < dataVal.Len(); i++ { + currentData := dataVal.Index(i).Interface() + currentField := valArray.Index(i) + + fieldName := name + "[" + strconv.Itoa(i) + "]" + if err := d.decode(fieldName, currentData, currentField); err != nil { + errors = appendErrors(errors, err) + } + } + + // Finally, set the value to the array we built up + val.Set(valArray) + + // If there were errors, we return those + if len(errors) > 0 { + return &Error{errors} + } + + return nil +} + +func (d *Decoder) decodeStruct(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + + // If the type of the value to write to and the data match directly, + // then we just set it directly instead of recursing into the structure. + if dataVal.Type() == val.Type() { + val.Set(dataVal) + return nil + } + + dataValKind := dataVal.Kind() + switch dataValKind { + case reflect.Map: + return d.decodeStructFromMap(name, dataVal, val) + + case reflect.Struct: + // Not the most efficient way to do this but we can optimize later if + // we want to. To convert from struct to struct we go to map first + // as an intermediary. + + // Make a new map to hold our result + mapType := reflect.TypeOf((map[string]interface{})(nil)) + mval := reflect.MakeMap(mapType) + + // Creating a pointer to a map so that other methods can completely + // overwrite the map if need be (looking at you decodeMapFromMap). The + // indirection allows the underlying map to be settable (CanSet() == true) + // where as reflect.MakeMap returns an unsettable map. + addrVal := reflect.New(mval.Type()) + + reflect.Indirect(addrVal).Set(mval) + if err := d.decodeMapFromStruct(name, dataVal, reflect.Indirect(addrVal), mval, data); err != nil { + return err + } + + result := d.decodeStructFromMap(name, reflect.Indirect(addrVal), val) + return result + + default: + return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind()) + } +} + +func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) error { + dataValType := dataVal.Type() + if kind := dataValType.Key().Kind(); kind != reflect.String && kind != reflect.Interface { + return fmt.Errorf( + "'%s' needs a map with string keys, has '%s' keys", + name, dataValType.Key().Kind()) + } + + dataValKeys := make(map[reflect.Value]struct{}) + dataValKeysUnused := make(map[interface{}]struct{}) + for _, dataValKey := range dataVal.MapKeys() { + dataValKeys[dataValKey] = struct{}{} + dataValKeysUnused[dataValKey.Interface()] = struct{}{} + } + + targetValKeysUnused := make(map[interface{}]struct{}) + errors := make([]string, 0) + + // This slice will keep track of all the structs we'll be decoding. + // There can be more than one struct if there are embedded structs + // that are squashed. + structs := make([]reflect.Value, 1, 5) + structs[0] = val + + // Compile the list of all the fields that we're going to be decoding + // from all the structs. + type field struct { + field reflect.StructField + val reflect.Value + } + + // remainField is set to a valid field set with the "remain" tag if + // we are keeping track of remaining values. + var remainField *field + + fields := []field{} + for len(structs) > 0 { + structVal := structs[0] + structs = structs[1:] + + structType := structVal.Type() + + for i := 0; i < structType.NumField(); i++ { + fieldType := structType.Field(i) + fieldVal := structVal.Field(i) + if fieldVal.Kind() == reflect.Ptr && fieldVal.Elem().Kind() == reflect.Struct { + // Handle embedded struct pointers as embedded structs. + fieldVal = fieldVal.Elem() + } + + // If "squash" is specified in the tag, we squash the field down. + squash := d.config.Squash && fieldVal.Kind() == reflect.Struct && fieldType.Anonymous + remain := false + + // We always parse the tags cause we're looking for other tags too + tagParts := strings.Split(fieldType.Tag.Get(d.config.TagName), ",") + for _, tag := range tagParts[1:] { + if tag == "squash" { + squash = true + break + } + + if tag == "remain" { + remain = true + break + } + } + + if squash { + if fieldVal.Kind() != reflect.Struct { + errors = appendErrors(errors, + fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldVal.Kind())) + } else { + structs = append(structs, fieldVal) + } + continue + } + + // Build our field + if remain { + remainField = &field{fieldType, fieldVal} + } else { + // Normal struct field, store it away + fields = append(fields, field{fieldType, fieldVal}) + } + } + } + + // for fieldType, field := range fields { + for _, f := range fields { + field, fieldValue := f.field, f.val + fieldName := field.Name + + tagValue := field.Tag.Get(d.config.TagName) + tagValue = strings.SplitN(tagValue, ",", 2)[0] + if tagValue != "" { + fieldName = tagValue + } + + rawMapKey := reflect.ValueOf(fieldName) + rawMapVal := dataVal.MapIndex(rawMapKey) + if !rawMapVal.IsValid() { + // Do a slower search by iterating over each key and + // doing case-insensitive search. + for dataValKey := range dataValKeys { + mK, ok := dataValKey.Interface().(string) + if !ok { + // Not a string key + continue + } + + if d.config.MatchName(mK, fieldName) { + rawMapKey = dataValKey + rawMapVal = dataVal.MapIndex(dataValKey) + break + } + } + + if !rawMapVal.IsValid() { + // There was no matching key in the map for the value in + // the struct. Remember it for potential errors and metadata. + targetValKeysUnused[fieldName] = struct{}{} + continue + } + } + + if !fieldValue.IsValid() { + // This should never happen + panic("field is not valid") + } + + // If we can't set the field, then it is unexported or something, + // and we just continue onwards. + if !fieldValue.CanSet() { + continue + } + + // Delete the key we're using from the unused map so we stop tracking + delete(dataValKeysUnused, rawMapKey.Interface()) + + // If the name is empty string, then we're at the root, and we + // don't dot-join the fields. + if name != "" { + fieldName = name + "." + fieldName + } + + if err := d.decode(fieldName, rawMapVal.Interface(), fieldValue); err != nil { + errors = appendErrors(errors, err) + } + } + + // If we have a "remain"-tagged field and we have unused keys then + // we put the unused keys directly into the remain field. + if remainField != nil && len(dataValKeysUnused) > 0 { + // Build a map of only the unused values + remain := map[interface{}]interface{}{} + for key := range dataValKeysUnused { + remain[key] = dataVal.MapIndex(reflect.ValueOf(key)).Interface() + } + + // Decode it as-if we were just decoding this map onto our map. + if err := d.decodeMap(name, remain, remainField.val); err != nil { + errors = appendErrors(errors, err) + } + + // Set the map to nil so we have none so that the next check will + // not error (ErrorUnused) + dataValKeysUnused = nil + } + + if d.config.ErrorUnused && len(dataValKeysUnused) > 0 { + keys := make([]string, 0, len(dataValKeysUnused)) + for rawKey := range dataValKeysUnused { + keys = append(keys, rawKey.(string)) + } + sort.Strings(keys) + + err := fmt.Errorf("'%s' has invalid keys: %s", name, strings.Join(keys, ", ")) + errors = appendErrors(errors, err) + } + + if d.config.ErrorUnset && len(targetValKeysUnused) > 0 { + keys := make([]string, 0, len(targetValKeysUnused)) + for rawKey := range targetValKeysUnused { + keys = append(keys, rawKey.(string)) + } + sort.Strings(keys) + + err := fmt.Errorf("'%s' has unset fields: %s", name, strings.Join(keys, ", ")) + errors = appendErrors(errors, err) + } + + if len(errors) > 0 { + return &Error{errors} + } + + // Add the unused keys to the list of unused keys if we're tracking metadata + if d.config.Metadata != nil { + for rawKey := range dataValKeysUnused { + key := rawKey.(string) + if name != "" { + key = name + "." + key + } + + d.config.Metadata.Unused = append(d.config.Metadata.Unused, key) + } + for rawKey := range targetValKeysUnused { + key := rawKey.(string) + if name != "" { + key = name + "." + key + } + + d.config.Metadata.Unset = append(d.config.Metadata.Unset, key) + } + } + + return nil +} + +func isEmptyValue(v reflect.Value) bool { + switch getKind(v) { + case reflect.Array, reflect.Map, reflect.Slice, reflect.String: + return v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + return v.IsNil() + } + return false +} + +func getKind(val reflect.Value) reflect.Kind { + kind := val.Kind() + + switch { + case kind >= reflect.Int && kind <= reflect.Int64: + return reflect.Int + case kind >= reflect.Uint && kind <= reflect.Uint64: + return reflect.Uint + case kind >= reflect.Float32 && kind <= reflect.Float64: + return reflect.Float32 + default: + return kind + } +} + +func isStructTypeConvertibleToMap(typ reflect.Type, checkMapstructureTags bool, tagName string) bool { + for i := 0; i < typ.NumField(); i++ { + f := typ.Field(i) + if f.PkgPath == "" && !checkMapstructureTags { // check for unexported fields + return true + } + if checkMapstructureTags && f.Tag.Get(tagName) != "" { // check for mapstructure tags inside + return true + } + } + return false +} + +func dereferencePtrToStructIfNeeded(v reflect.Value, tagName string) reflect.Value { + if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { + return v + } + deref := v.Elem() + derefT := deref.Type() + if isStructTypeConvertibleToMap(derefT, true, tagName) { + return deref + } + return v +} diff --git a/pkg/mapstructure/mapstructure_benchmark_test.go b/pkg/mapstructure/mapstructure_benchmark_test.go new file mode 100644 index 0000000..b9bde7e --- /dev/null +++ b/pkg/mapstructure/mapstructure_benchmark_test.go @@ -0,0 +1,285 @@ +package mapstructure + +import ( + "encoding/json" + "testing" +) + +type Person struct { + Name string + Age int + Emails []string + Extra map[string]string +} + +func Benchmark_Decode(b *testing.B) { + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "emails": []string{"one", "two", "three"}, + "extra": map[string]string{ + "twitter": "mitchellh", + }, + } + + var result Person + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} + +// decodeViaJSON takes the map data and passes it through encoding/json to convert it into the +// given Go native structure pointed to by v. v must be a pointer to a struct. +func decodeViaJSON(data interface{}, v interface{}) error { + // Perform the task by simply marshalling the input into JSON, + // then unmarshalling it into target native Go struct. + b, err := json.Marshal(data) + if err != nil { + return err + } + return json.Unmarshal(b, v) +} + +func Benchmark_DecodeViaJSON(b *testing.B) { + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "emails": []string{"one", "two", "three"}, + "extra": map[string]string{ + "twitter": "mitchellh", + }, + } + + var result Person + for i := 0; i < b.N; i++ { + decodeViaJSON(input, &result) + } +} + +func Benchmark_JSONUnmarshal(b *testing.B) { + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "emails": []string{"one", "two", "three"}, + "extra": map[string]string{ + "twitter": "mitchellh", + }, + } + + inputB, err := json.Marshal(input) + if err != nil { + b.Fatal("Failed to marshal test input:", err) + } + + var result Person + for i := 0; i < b.N; i++ { + json.Unmarshal(inputB, &result) + } +} + +func Benchmark_DecodeBasic(b *testing.B) { + input := map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "Vuint": 42, + "vbool": true, + "Vfloat": 42.42, + "vsilent": true, + "vdata": 42, + "vjsonInt": json.Number("1234"), + "vjsonFloat": json.Number("1234.5"), + "vjsonNumber": json.Number("1234.5"), + } + + for i := 0; i < b.N; i++ { + var result Basic + Decode(input, &result) + } +} + +func Benchmark_DecodeEmbedded(b *testing.B) { + input := map[string]interface{}{ + "vstring": "foo", + "Basic": map[string]interface{}{ + "vstring": "innerfoo", + }, + "vunique": "bar", + } + + var result Embedded + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} + +func Benchmark_DecodeTypeConversion(b *testing.B) { + input := map[string]interface{}{ + "IntToFloat": 42, + "IntToUint": 42, + "IntToBool": 1, + "IntToString": 42, + "UintToInt": 42, + "UintToFloat": 42, + "UintToBool": 42, + "UintToString": 42, + "BoolToInt": true, + "BoolToUint": true, + "BoolToFloat": true, + "BoolToString": true, + "FloatToInt": 42.42, + "FloatToUint": 42.42, + "FloatToBool": 42.42, + "FloatToString": 42.42, + "StringToInt": "42", + "StringToUint": "42", + "StringToBool": "1", + "StringToFloat": "42.42", + "SliceToMap": []interface{}{}, + "MapToSlice": map[string]interface{}{}, + } + + var resultStrict TypeConversionResult + for i := 0; i < b.N; i++ { + Decode(input, &resultStrict) + } +} + +func Benchmark_DecodeMap(b *testing.B) { + input := map[string]interface{}{ + "vfoo": "foo", + "vother": map[interface{}]interface{}{ + "foo": "foo", + "bar": "bar", + }, + } + + var result Map + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} + +func Benchmark_DecodeMapOfStruct(b *testing.B) { + input := map[string]interface{}{ + "value": map[string]interface{}{ + "foo": map[string]string{"vstring": "one"}, + "bar": map[string]string{"vstring": "two"}, + }, + } + + var result MapOfStruct + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} + +func Benchmark_DecodeSlice(b *testing.B) { + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": []string{"foo", "bar", "baz"}, + } + + var result Slice + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} + +func Benchmark_DecodeSliceOfStruct(b *testing.B) { + input := map[string]interface{}{ + "value": []map[string]interface{}{ + {"vstring": "one"}, + {"vstring": "two"}, + }, + } + + var result SliceOfStruct + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} + +func Benchmark_DecodeWeaklyTypedInput(b *testing.B) { + // This input can come from anywhere, but typically comes from + // something like decoding JSON, generated by a weakly typed language + // such as PHP. + input := map[string]interface{}{ + "name": 123, // number => string + "age": "42", // string => number + "emails": map[string]interface{}{}, // empty map => empty array + } + + var result Person + config := &DecoderConfig{ + WeaklyTypedInput: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + panic(err) + } + + for i := 0; i < b.N; i++ { + decoder.Decode(input) + } +} + +func Benchmark_DecodeMetadata(b *testing.B) { + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "email": "foo@bar.com", + } + + var md Metadata + var result Person + config := &DecoderConfig{ + Metadata: &md, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + panic(err) + } + + for i := 0; i < b.N; i++ { + decoder.Decode(input) + } +} + +func Benchmark_DecodeMetadataEmbedded(b *testing.B) { + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + } + + var md Metadata + var result EmbeddedSquash + config := &DecoderConfig{ + Metadata: &md, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + b.Fatalf("jderr: %s", err) + } + + for i := 0; i < b.N; i++ { + decoder.Decode(input) + } +} + +func Benchmark_DecodeTagged(b *testing.B) { + input := map[string]interface{}{ + "foo": "bar", + "bar": "value", + } + + var result Tagged + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} diff --git a/pkg/mapstructure/mapstructure_bugs_test.go b/pkg/mapstructure/mapstructure_bugs_test.go new file mode 100644 index 0000000..31fa5cd --- /dev/null +++ b/pkg/mapstructure/mapstructure_bugs_test.go @@ -0,0 +1,627 @@ +package mapstructure + +import ( + "reflect" + "testing" + "time" +) + +// GH-1, GH-10, GH-96 +func TestDecode_NilValue(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + in interface{} + target interface{} + out interface{} + metaKeys []string + metaUnused []string + }{ + { + "all nil", + &map[string]interface{}{ + "vfoo": nil, + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "", Vother: nil}, + []string{"Vfoo", "Vother"}, + []string{}, + }, + { + "partial nil", + &map[string]interface{}{ + "vfoo": "baz", + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "baz", Vother: nil}, + []string{"Vfoo", "Vother"}, + []string{}, + }, + { + "partial decode", + &map[string]interface{}{ + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "foo", Vother: nil}, + []string{"Vother"}, + []string{}, + }, + { + "unused values", + &map[string]interface{}{ + "vbar": "bar", + "vfoo": nil, + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "", Vother: nil}, + []string{"Vfoo", "Vother"}, + []string{"vbar"}, + }, + { + "map interface all nil", + &map[interface{}]interface{}{ + "vfoo": nil, + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "", Vother: nil}, + []string{"Vfoo", "Vother"}, + []string{}, + }, + { + "map interface partial nil", + &map[interface{}]interface{}{ + "vfoo": "baz", + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "baz", Vother: nil}, + []string{"Vfoo", "Vother"}, + []string{}, + }, + { + "map interface partial decode", + &map[interface{}]interface{}{ + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "foo", Vother: nil}, + []string{"Vother"}, + []string{}, + }, + { + "map interface unused values", + &map[interface{}]interface{}{ + "vbar": "bar", + "vfoo": nil, + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "", Vother: nil}, + []string{"Vfoo", "Vother"}, + []string{"vbar"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + config := &DecoderConfig{ + Metadata: new(Metadata), + Result: tc.target, + ZeroFields: true, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("should not error: %s", err) + } + + err = decoder.Decode(tc.in) + if err != nil { + t.Fatalf("should not error: %s", err) + } + + if !reflect.DeepEqual(tc.out, tc.target) { + t.Fatalf("%q: TestDecode_NilValue() expected: %#v, got: %#v", tc.name, tc.out, tc.target) + } + + if !reflect.DeepEqual(tc.metaKeys, config.Metadata.Keys) { + t.Fatalf("%q: Metadata.Keys mismatch expected: %#v, got: %#v", tc.name, tc.metaKeys, config.Metadata.Keys) + } + + if !reflect.DeepEqual(tc.metaUnused, config.Metadata.Unused) { + t.Fatalf("%q: Metadata.Unused mismatch expected: %#v, got: %#v", tc.name, tc.metaUnused, config.Metadata.Unused) + } + }) + } +} + +// #48 +func TestNestedTypePointerWithDefaults(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vbool": true, + }, + } + + result := NestedPointer{ + Vbar: &Basic{ + Vuint: 42, + }, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring) + } + + if result.Vbar.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vbar.Vint) + } + + if result.Vbar.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool) + } + + if result.Vbar.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra) + } + + // this is the error + if result.Vbar.Vuint != 42 { + t.Errorf("vuint value should be 42: %#v", result.Vbar.Vuint) + } + +} + +type NestedSlice struct { + Vfoo string + Vbars []Basic + Vempty []Basic +} + +// #48 +func TestNestedTypeSliceWithDefaults(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbars": []map[string]interface{}{ + {"vstring": "foo", "vint": 42, "vbool": true}, + {"vint": 42, "vbool": true}, + }, + "vempty": []map[string]interface{}{ + {"vstring": "foo", "vint": 42, "vbool": true}, + {"vint": 42, "vbool": true}, + }, + } + + result := NestedSlice{ + Vbars: []Basic{ + {Vuint: 42}, + {Vstring: "foo"}, + }, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbars[0].Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbars[0].Vstring) + } + // this is the error + if result.Vbars[0].Vuint != 42 { + t.Errorf("vuint value should be 42: %#v", result.Vbars[0].Vuint) + } +} + +// #48 workaround +func TestNestedTypeWithDefaults(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vbool": true, + }, + } + + result := Nested{ + Vbar: Basic{ + Vuint: 42, + }, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring) + } + + if result.Vbar.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vbar.Vint) + } + + if result.Vbar.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool) + } + + if result.Vbar.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra) + } + + // this is the error + if result.Vbar.Vuint != 42 { + t.Errorf("vuint value should be 42: %#v", result.Vbar.Vuint) + } + +} + +// #67 panic() on extending slices (decodeSlice with disabled ZeroValues) +func TestDecodeSliceToEmptySliceWOZeroing(t *testing.T) { + t.Parallel() + + type TestStruct struct { + Vfoo []string + } + + decode := func(m interface{}, rawVal interface{}) error { + config := &DecoderConfig{ + Metadata: nil, + Result: rawVal, + ZeroFields: false, + } + + decoder, err := NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(m) + } + + { + input := map[string]interface{}{ + "vfoo": []string{"1"}, + } + + result := &TestStruct{} + + err := decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + } + + { + input := map[string]interface{}{ + "vfoo": []string{"1"}, + } + + result := &TestStruct{ + Vfoo: []string{}, + } + + err := decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + } + + { + input := map[string]interface{}{ + "vfoo": []string{"2", "3"}, + } + + result := &TestStruct{ + Vfoo: []string{"1"}, + } + + err := decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + } +} + +// #70 +func TestNextSquashMapstructure(t *testing.T) { + data := &struct { + Level1 struct { + Level2 struct { + Foo string + } `mapstructure:",squash"` + } `mapstructure:",squash"` + }{} + err := Decode(map[interface{}]interface{}{"foo": "baz"}, &data) + if err != nil { + t.Fatalf("should not error: %s", err) + } + if data.Level1.Level2.Foo != "baz" { + t.Fatal("value should be baz") + } +} + +type ImplementsInterfacePointerReceiver struct { + Name string +} + +func (i *ImplementsInterfacePointerReceiver) DoStuff() {} + +type ImplementsInterfaceValueReceiver string + +func (i ImplementsInterfaceValueReceiver) DoStuff() {} + +// GH-140 Type error when using DecodeHook to decode into interface +func TestDecode_DecodeHookInterface(t *testing.T) { + t.Parallel() + + type Interface interface { + DoStuff() + } + type DecodeIntoInterface struct { + Test Interface + } + + testData := map[string]string{"test": "test"} + + stringToPointerInterfaceDecodeHook := func(from, to reflect.Type, data interface{}) (interface{}, error) { + if from.Kind() != reflect.String { + return data, nil + } + + if to != reflect.TypeOf((*Interface)(nil)).Elem() { + return data, nil + } + // Ensure interface is satisfied + var impl Interface = &ImplementsInterfacePointerReceiver{data.(string)} + return impl, nil + } + + stringToValueInterfaceDecodeHook := func(from, to reflect.Type, data interface{}) (interface{}, error) { + if from.Kind() != reflect.String { + return data, nil + } + + if to != reflect.TypeOf((*Interface)(nil)).Elem() { + return data, nil + } + // Ensure interface is satisfied + var impl Interface = ImplementsInterfaceValueReceiver(data.(string)) + return impl, nil + } + + { + decodeInto := new(DecodeIntoInterface) + + decoder, _ := NewDecoder(&DecoderConfig{ + DecodeHook: stringToPointerInterfaceDecodeHook, + Result: decodeInto, + }) + + err := decoder.Decode(testData) + if err != nil { + t.Fatalf("Decode returned error: %s", err) + } + + expected := &ImplementsInterfacePointerReceiver{"test"} + if !reflect.DeepEqual(decodeInto.Test, expected) { + t.Fatalf("expected: %#v (%T), got: %#v (%T)", decodeInto.Test, decodeInto.Test, expected, expected) + } + } + + { + decodeInto := new(DecodeIntoInterface) + + decoder, _ := NewDecoder(&DecoderConfig{ + DecodeHook: stringToValueInterfaceDecodeHook, + Result: decodeInto, + }) + + err := decoder.Decode(testData) + if err != nil { + t.Fatalf("Decode returned error: %s", err) + } + + expected := ImplementsInterfaceValueReceiver("test") + if !reflect.DeepEqual(decodeInto.Test, expected) { + t.Fatalf("expected: %#v (%T), got: %#v (%T)", decodeInto.Test, decodeInto.Test, expected, expected) + } + } +} + +// #103 Check for data type before trying to access its composants prevent a panic error +// in decodeSlice +func TestDecodeBadDataTypeInSlice(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "Toto": "titi", + } + result := []struct { + Toto string + }{} + + if err := Decode(input, &result); err == nil { + t.Error("An error was expected, got nil") + } +} + +// #202 Ensure that intermediate maps in the struct -> struct decode process are settable +// and not just the elements within them. +func TestDecodeIntermediateMapsSettable(t *testing.T) { + type Timestamp struct { + Seconds int64 + Nanos int32 + } + + type TsWrapper struct { + Timestamp *Timestamp + } + + type TimeWrapper struct { + Timestamp time.Time + } + + input := TimeWrapper{ + Timestamp: time.Unix(123456789, 987654), + } + + expected := TsWrapper{ + Timestamp: &Timestamp{ + Seconds: 123456789, + Nanos: 987654, + }, + } + + timePtrType := reflect.TypeOf((*time.Time)(nil)) + mapStrInfType := reflect.TypeOf((map[string]interface{})(nil)) + + var actual TsWrapper + decoder, err := NewDecoder(&DecoderConfig{ + Result: &actual, + DecodeHook: func(from, to reflect.Type, data interface{}) (interface{}, error) { + if from == timePtrType && to == mapStrInfType { + ts := data.(*time.Time) + nanos := ts.UnixNano() + + seconds := nanos / 1000000000 + nanos = nanos % 1000000000 + + return &map[string]interface{}{ + "Seconds": seconds, + "Nanos": int32(nanos), + }, nil + } + return data, nil + }, + }) + + if err != nil { + t.Fatalf("failed to create decoder: %v", err) + } + + if err := decoder.Decode(&input); err != nil { + t.Fatalf("failed to decode input: %v", err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("expected: %#[1]v (%[1]T), got: %#[2]v (%[2]T)", expected, actual) + } +} + +// GH-206: decodeInt throws an error for an empty string +func TestDecode_weakEmptyStringToInt(t *testing.T) { + input := map[string]interface{}{ + "StringToInt": "", + "StringToUint": "", + "StringToBool": "", + "StringToFloat": "", + } + + expectedResultWeak := TypeConversionResult{ + StringToInt: 0, + StringToUint: 0, + StringToBool: false, + StringToFloat: 0, + } + + // Test weak type conversion + var resultWeak TypeConversionResult + err := WeakDecode(input, &resultWeak) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if !reflect.DeepEqual(resultWeak, expectedResultWeak) { + t.Errorf("expected \n%#v, got: \n%#v", expectedResultWeak, resultWeak) + } +} + +// GH-228: Squash cause *time.Time set to zero +func TestMapSquash(t *testing.T) { + type AA struct { + T *time.Time + } + type A struct { + AA + } + + v := time.Now() + in := &AA{ + T: &v, + } + out := &A{} + d, err := NewDecoder(&DecoderConfig{ + Squash: true, + Result: out, + }) + if err != nil { + t.Fatalf("jderr: %s", err) + } + if err := d.Decode(in); err != nil { + t.Fatalf("jderr: %s", err) + } + + // these failed + if !v.Equal(*out.T) { + t.Fatal("expected equal") + } + if out.T.IsZero() { + t.Fatal("expected false") + } +} + +// GH-238: Empty key name when decoding map from struct with only omitempty flag +func TestMapOmitEmptyWithEmptyFieldnameInTag(t *testing.T) { + type Struct struct { + Username string `mapstructure:",omitempty"` + Age int `mapstructure:",omitempty"` + } + + s := Struct{ + Username: "Joe", + } + var m map[string]interface{} + + if err := Decode(s, &m); err != nil { + t.Fatal(err) + } + + if len(m) != 1 { + t.Fatalf("fail: %#v", m) + } + if m["Username"] != "Joe" { + t.Fatalf("fail: %#v", m) + } +} diff --git a/pkg/mapstructure/mapstructure_examples_test.go b/pkg/mapstructure/mapstructure_examples_test.go new file mode 100644 index 0000000..2413b69 --- /dev/null +++ b/pkg/mapstructure/mapstructure_examples_test.go @@ -0,0 +1,256 @@ +package mapstructure + +import ( + "fmt" +) + +func ExampleDecode() { + type Person struct { + Name string + Age int + Emails []string + Extra map[string]string + } + + // This input can come from anywhere, but typically comes from + // something like decoding JSON where we're not quite sure of the + // struct initially. + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "emails": []string{"one", "two", "three"}, + "extra": map[string]string{ + "twitter": "mitchellh", + }, + } + + var result Person + err := Decode(input, &result) + if err != nil { + panic(err) + } + + fmt.Printf("%#v", result) + // Output: + // mapstructure.Person{Name:"Mitchell", Age:91, Emails:[]string{"one", "two", "three"}, Extra:map[string]string{"twitter":"mitchellh"}} +} + +func ExampleDecode_errors() { + type Person struct { + Name string + Age int + Emails []string + Extra map[string]string + } + + // This input can come from anywhere, but typically comes from + // something like decoding JSON where we're not quite sure of the + // struct initially. + input := map[string]interface{}{ + "name": 123, + "age": "bad value", + "emails": []int{1, 2, 3}, + } + + var result Person + err := Decode(input, &result) + if err == nil { + panic("should have an error") + } + + fmt.Println(err.Error()) + // Output: + // 5 error(s) decoding: + // + // * 'Age' expected type 'int', got unconvertible type 'string', value: 'bad value' + // * 'Emails[0]' expected type 'string', got unconvertible type 'int', value: '1' + // * 'Emails[1]' expected type 'string', got unconvertible type 'int', value: '2' + // * 'Emails[2]' expected type 'string', got unconvertible type 'int', value: '3' + // * 'Name' expected type 'string', got unconvertible type 'int', value: '123' +} + +func ExampleDecode_metadata() { + type Person struct { + Name string + Age int + } + + // This input can come from anywhere, but typically comes from + // something like decoding JSON where we're not quite sure of the + // struct initially. + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "email": "foo@bar.com", + } + + // For metadata, we make a more advanced DecoderConfig so we can + // more finely configure the decoder that is used. In this case, we + // just tell the decoder we want to track metadata. + var md Metadata + var result Person + config := &DecoderConfig{ + Metadata: &md, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + panic(err) + } + + if err := decoder.Decode(input); err != nil { + panic(err) + } + + fmt.Printf("Unused keys: %#v", md.Unused) + // Output: + // Unused keys: []string{"email"} +} + +func ExampleDecode_weaklyTypedInput() { + type Person struct { + Name string + Age int + Emails []string + } + + // This input can come from anywhere, but typically comes from + // something like decoding JSON, generated by a weakly typed language + // such as PHP. + input := map[string]interface{}{ + "name": 123, // number => string + "age": "42", // string => number + "emails": map[string]interface{}{}, // empty map => empty array + } + + var result Person + config := &DecoderConfig{ + WeaklyTypedInput: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + panic(err) + } + + err = decoder.Decode(input) + if err != nil { + panic(err) + } + + fmt.Printf("%#v", result) + // Output: mapstructure.Person{Name:"123", Age:42, Emails:[]string{}} +} + +func ExampleDecode_tags() { + // Note that the mapstructure tags defined in the struct type + // can indicate which fields the values are mapped to. + type Person struct { + Name string `mapstructure:"person_name"` + Age int `mapstructure:"person_age"` + } + + input := map[string]interface{}{ + "person_name": "Mitchell", + "person_age": 91, + } + + var result Person + err := Decode(input, &result) + if err != nil { + panic(err) + } + + fmt.Printf("%#v", result) + // Output: + // mapstructure.Person{Name:"Mitchell", Age:91} +} + +func ExampleDecode_embeddedStruct() { + // Squashing multiple embedded structs is allowed using the squash tag. + // This is demonstrated by creating a composite struct of multiple types + // and decoding into it. In this case, a person can carry with it both + // a Family and a Location, as well as their own FirstName. + type Family struct { + LastName string + } + type Location struct { + City string + } + type Person struct { + Family `mapstructure:",squash"` + Location `mapstructure:",squash"` + FirstName string + } + + input := map[string]interface{}{ + "FirstName": "Mitchell", + "LastName": "Hashimoto", + "City": "San Francisco", + } + + var result Person + err := Decode(input, &result) + if err != nil { + panic(err) + } + + fmt.Printf("%s %s, %s", result.FirstName, result.LastName, result.City) + // Output: + // Mitchell Hashimoto, San Francisco +} + +func ExampleDecode_remainingData() { + // Note that the mapstructure tags defined in the struct type + // can indicate which fields the values are mapped to. + type Person struct { + Name string + Age int + Other map[string]interface{} `mapstructure:",remain"` + } + + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "email": "mitchell@example.com", + } + + var result Person + err := Decode(input, &result) + if err != nil { + panic(err) + } + + fmt.Printf("%#v", result) + // Output: + // mapstructure.Person{Name:"Mitchell", Age:91, Other:map[string]interface {}{"email":"mitchell@example.com"}} +} + +func ExampleDecode_omitempty() { + // Add omitempty annotation to avoid map keys for empty values + type Family struct { + LastName string + } + type Location struct { + City string + } + type Person struct { + *Family `mapstructure:",omitempty"` + *Location `mapstructure:",omitempty"` + Age int + FirstName string + } + + result := &map[string]interface{}{} + input := Person{FirstName: "Somebody"} + err := Decode(input, &result) + if err != nil { + panic(err) + } + + fmt.Printf("%+v", result) + // Output: + // &map[Age:0 FirstName:Somebody] +} diff --git a/pkg/mapstructure/mapstructure_ext_test.go b/pkg/mapstructure/mapstructure_ext_test.go new file mode 100644 index 0000000..a646a51 --- /dev/null +++ b/pkg/mapstructure/mapstructure_ext_test.go @@ -0,0 +1,58 @@ +package mapstructure + +import ( + "reflect" + "testing" +) + +func TestDecode_Ptr(t *testing.T) { + t.Parallel() + + type G struct { + Id int + Name string + } + + type X struct { + Id int + Name int + } + + type AG struct { + List []*G + } + + type AX struct { + List []*X + } + + g2 := &AG{ + List: []*G{ + { + Id: 11, + Name: "gg", + }, + }, + } + x2 := AX{} + + // 报错但还是会转换成功,转换后值为目标类型的 0 值 + err := Decode(g2, &x2) + + res := AX{ + List: []*X{ + { + Id: 11, + Name: 0, // 这个类型的 0 值 + }, + }, + } + + if err == nil { + t.Errorf("Decode_Ptr jderr should not be 'nil': %#v", err) + } + + if !reflect.DeepEqual(res, x2) { + t.Errorf("result should be %#v: got %#v", res, x2) + } +} diff --git a/pkg/mapstructure/mapstructure_test.go b/pkg/mapstructure/mapstructure_test.go new file mode 100644 index 0000000..17e609a --- /dev/null +++ b/pkg/mapstructure/mapstructure_test.go @@ -0,0 +1,2763 @@ +package mapstructure + +import ( + "encoding/json" + "io" + "reflect" + "sort" + "strings" + "testing" + "time" +) + +type Basic struct { + Vstring string + Vint int + Vint8 int8 + Vint16 int16 + Vint32 int32 + Vint64 int64 + Vuint uint + Vbool bool + Vfloat float64 + Vextra string + vsilent bool + Vdata interface{} + VjsonInt int + VjsonUint uint + VjsonUint64 uint64 + VjsonFloat float64 + VjsonNumber json.Number +} + +type BasicPointer struct { + Vstring *string + Vint *int + Vuint *uint + Vbool *bool + Vfloat *float64 + Vextra *string + vsilent *bool + Vdata *interface{} + VjsonInt *int + VjsonFloat *float64 + VjsonNumber *json.Number +} + +type BasicSquash struct { + Test Basic `mapstructure:",squash"` +} + +type Embedded struct { + Basic + Vunique string +} + +type EmbeddedPointer struct { + *Basic + Vunique string +} + +type EmbeddedSquash struct { + Basic `mapstructure:",squash"` + Vunique string +} + +type EmbeddedPointerSquash struct { + *Basic `mapstructure:",squash"` + Vunique string +} + +type BasicMapStructure struct { + Vunique string `mapstructure:"vunique"` + Vtime *time.Time `mapstructure:"time"` +} + +type NestedPointerWithMapstructure struct { + Vbar *BasicMapStructure `mapstructure:"vbar"` +} + +type EmbeddedPointerSquashWithNestedMapstructure struct { + *NestedPointerWithMapstructure `mapstructure:",squash"` + Vunique string +} + +type EmbeddedAndNamed struct { + Basic + Named Basic + Vunique string +} + +type SliceAlias []string + +type EmbeddedSlice struct { + SliceAlias `mapstructure:"slice_alias"` + Vunique string +} + +type ArrayAlias [2]string + +type EmbeddedArray struct { + ArrayAlias `mapstructure:"array_alias"` + Vunique string +} + +type SquashOnNonStructType struct { + InvalidSquashType int `mapstructure:",squash"` +} + +type Map struct { + Vfoo string + Vother map[string]string +} + +type MapOfStruct struct { + Value map[string]Basic +} + +type Nested struct { + Vfoo string + Vbar Basic +} + +type NestedPointer struct { + Vfoo string + Vbar *Basic +} + +type NilInterface struct { + W io.Writer +} + +type NilPointer struct { + Value *string +} + +type Slice struct { + Vfoo string + Vbar []string +} + +type SliceOfAlias struct { + Vfoo string + Vbar SliceAlias +} + +type SliceOfStruct struct { + Value []Basic +} + +type SlicePointer struct { + Vbar *[]string +} + +type Array struct { + Vfoo string + Vbar [2]string +} + +type ArrayOfStruct struct { + Value [2]Basic +} + +type Func struct { + Foo func() string +} + +type Tagged struct { + Extra string `mapstructure:"bar,what,what"` + Value string `mapstructure:"foo"` +} + +type Remainder struct { + A string + Extra map[string]interface{} `mapstructure:",remain"` +} + +type StructWithOmitEmpty struct { + VisibleStringField string `mapstructure:"visible-string"` + OmitStringField string `mapstructure:"omittable-string,omitempty"` + VisibleIntField int `mapstructure:"visible-int"` + OmitIntField int `mapstructure:"omittable-int,omitempty"` + VisibleFloatField float64 `mapstructure:"visible-float"` + OmitFloatField float64 `mapstructure:"omittable-float,omitempty"` + VisibleSliceField []interface{} `mapstructure:"visible-slice"` + OmitSliceField []interface{} `mapstructure:"omittable-slice,omitempty"` + VisibleMapField map[string]interface{} `mapstructure:"visible-map"` + OmitMapField map[string]interface{} `mapstructure:"omittable-map,omitempty"` + NestedField *Nested `mapstructure:"visible-nested"` + OmitNestedField *Nested `mapstructure:"omittable-nested,omitempty"` +} + +type TypeConversionResult struct { + IntToFloat float32 + IntToUint uint + IntToBool bool + IntToString string + UintToInt int + UintToFloat float32 + UintToBool bool + UintToString string + BoolToInt int + BoolToUint uint + BoolToFloat float32 + BoolToString string + FloatToInt int + FloatToUint uint + FloatToBool bool + FloatToString string + SliceUint8ToString string + StringToSliceUint8 []byte + ArrayUint8ToString string + StringToInt int + StringToUint uint + StringToBool bool + StringToFloat float32 + StringToStrSlice []string + StringToIntSlice []int + StringToStrArray [1]string + StringToIntArray [1]int + SliceToMap map[string]interface{} + MapToSlice []interface{} + ArrayToMap map[string]interface{} + MapToArray [1]interface{} +} + +func TestBasicTypes(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vint8": 42, + "vint16": 42, + "vint32": 42, + "vint64": 42, + "Vuint": 42, + "vbool": true, + "Vfloat": 42.42, + "vsilent": true, + "vdata": 42, + "vjsonInt": json.Number("1234"), + "vjsonUint": json.Number("1234"), + "vjsonUint64": json.Number("9223372036854775809"), // 2^63 + 1 + "vjsonFloat": json.Number("1234.5"), + "vjsonNumber": json.Number("1234.5"), + } + + var result Basic + err := Decode(input, &result) + if err != nil { + t.Errorf("got an jderr: %s", err.Error()) + t.FailNow() + } + + if result.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vstring) + } + + if result.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vint) + } + if result.Vint8 != 42 { + t.Errorf("vint8 value should be 42: %#v", result.Vint) + } + if result.Vint16 != 42 { + t.Errorf("vint16 value should be 42: %#v", result.Vint) + } + if result.Vint32 != 42 { + t.Errorf("vint32 value should be 42: %#v", result.Vint) + } + if result.Vint64 != 42 { + t.Errorf("vint64 value should be 42: %#v", result.Vint) + } + + if result.Vuint != 42 { + t.Errorf("vuint value should be 42: %#v", result.Vuint) + } + + if result.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbool) + } + + if result.Vfloat != 42.42 { + t.Errorf("vfloat value should be 42.42: %#v", result.Vfloat) + } + + if result.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vextra) + } + + if result.vsilent != false { + t.Error("vsilent should not be set, it is unexported") + } + + if result.Vdata != 42 { + t.Error("vdata should be valid") + } + + if result.VjsonInt != 1234 { + t.Errorf("vjsonint value should be 1234: %#v", result.VjsonInt) + } + + if result.VjsonUint != 1234 { + t.Errorf("vjsonuint value should be 1234: %#v", result.VjsonUint) + } + + if result.VjsonUint64 != 9223372036854775809 { + t.Errorf("vjsonuint64 value should be 9223372036854775809: %#v", result.VjsonUint64) + } + + if result.VjsonFloat != 1234.5 { + t.Errorf("vjsonfloat value should be 1234.5: %#v", result.VjsonFloat) + } + + if !reflect.DeepEqual(result.VjsonNumber, json.Number("1234.5")) { + t.Errorf("vjsonnumber value should be '1234.5': %T, %#v", result.VjsonNumber, result.VjsonNumber) + } +} + +func TestBasic_IntWithFloat(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vint": float64(42), + } + + var result Basic + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } +} + +func TestBasic_Merge(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vint": 42, + } + + var result Basic + result.Vuint = 100 + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + expected := Basic{ + Vint: 42, + Vuint: 100, + } + if !reflect.DeepEqual(result, expected) { + t.Fatalf("bad: %#v", result) + } +} + +// Test for issue #46. +func TestBasic_Struct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vdata": map[string]interface{}{ + "vstring": "foo", + }, + } + + var result, inner Basic + result.Vdata = &inner + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + expected := Basic{ + Vdata: &Basic{ + Vstring: "foo", + }, + } + if !reflect.DeepEqual(result, expected) { + t.Fatalf("bad: %#v", result) + } +} + +func TestBasic_interfaceStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + } + + var iface interface{} = &Basic{} + err := Decode(input, &iface) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + expected := &Basic{ + Vstring: "foo", + } + if !reflect.DeepEqual(iface, expected) { + t.Fatalf("bad: %#v", iface) + } +} + +// Issue 187 +func TestBasic_interfaceStructNonPtr(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + } + + var iface interface{} = Basic{} + err := Decode(input, &iface) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + expected := Basic{ + Vstring: "foo", + } + if !reflect.DeepEqual(iface, expected) { + t.Fatalf("bad: %#v", iface) + } +} + +func TestDecode_BasicSquash(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + } + + var result BasicSquash + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Test.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Test.Vstring) + } +} + +func TestDecodeFrom_BasicSquash(t *testing.T) { + t.Parallel() + + var v interface{} + var ok bool + + input := BasicSquash{ + Test: Basic{ + Vstring: "foo", + }, + } + + var result map[string]interface{} + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if _, ok = result["Test"]; ok { + t.Error("test should not be present in map") + } + + v, ok = result["Vstring"] + if !ok { + t.Error("vstring should be present in map") + } else if !reflect.DeepEqual(v, "foo") { + t.Errorf("vstring value should be 'foo': %#v", v) + } +} + +func TestDecode_Embedded(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "Basic": map[string]interface{}{ + "vstring": "innerfoo", + }, + "vunique": "bar", + } + + var result Embedded + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vstring != "innerfoo" { + t.Errorf("vstring value should be 'innerfoo': %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecode_EmbeddedPointer(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "Basic": map[string]interface{}{ + "vstring": "innerfoo", + }, + "vunique": "bar", + } + + var result EmbeddedPointer + err := Decode(input, &result) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + expected := EmbeddedPointer{ + Basic: &Basic{ + Vstring: "innerfoo", + }, + Vunique: "bar", + } + if !reflect.DeepEqual(result, expected) { + t.Fatalf("bad: %#v", result) + } +} + +func TestDecode_EmbeddedSlice(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "slice_alias": []string{"foo", "bar"}, + "vunique": "bar", + } + + var result EmbeddedSlice + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if !reflect.DeepEqual(result.SliceAlias, SliceAlias([]string{"foo", "bar"})) { + t.Errorf("slice value: %#v", result.SliceAlias) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecode_EmbeddedArray(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "array_alias": [2]string{"foo", "bar"}, + "vunique": "bar", + } + + var result EmbeddedArray + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if !reflect.DeepEqual(result.ArrayAlias, ArrayAlias([2]string{"foo", "bar"})) { + t.Errorf("array value: %#v", result.ArrayAlias) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecode_decodeSliceWithArray(t *testing.T) { + t.Parallel() + + var result []int + input := [1]int{1} + expected := []int{1} + if err := Decode(input, &result); err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if !reflect.DeepEqual(expected, result) { + t.Errorf("wanted %+v, got %+v", expected, result) + } +} + +func TestDecode_EmbeddedNoSquash(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + } + + var result Embedded + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vstring != "" { + t.Errorf("vstring value should be empty: %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecode_EmbeddedPointerNoSquash(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + } + + result := EmbeddedPointer{ + Basic: &Basic{}, + } + + err := Decode(input, &result) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + if result.Vstring != "" { + t.Errorf("vstring value should be empty: %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecode_EmbeddedSquash(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + } + + var result EmbeddedSquash + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecodeFrom_EmbeddedSquash(t *testing.T) { + t.Parallel() + + var v interface{} + var ok bool + + input := EmbeddedSquash{ + Basic: Basic{ + Vstring: "foo", + }, + Vunique: "bar", + } + + var result map[string]interface{} + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if _, ok = result["Basic"]; ok { + t.Error("basic should not be present in map") + } + + v, ok = result["Vstring"] + if !ok { + t.Error("vstring should be present in map") + } else if !reflect.DeepEqual(v, "foo") { + t.Errorf("vstring value should be 'foo': %#v", v) + } + + v, ok = result["Vunique"] + if !ok { + t.Error("vunique should be present in map") + } else if !reflect.DeepEqual(v, "bar") { + t.Errorf("vunique value should be 'bar': %#v", v) + } +} + +func TestDecode_EmbeddedPointerSquash_FromStructToMap(t *testing.T) { + t.Parallel() + + input := EmbeddedPointerSquash{ + Basic: &Basic{ + Vstring: "foo", + }, + Vunique: "bar", + } + + var result map[string]interface{} + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result["Vstring"] != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result["Vstring"]) + } + + if result["Vunique"] != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result["Vunique"]) + } +} + +func TestDecode_EmbeddedPointerSquash_FromMapToStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "Vstring": "foo", + "Vunique": "bar", + } + + result := EmbeddedPointerSquash{ + Basic: &Basic{}, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecode_EmbeddedPointerSquashWithNestedMapstructure_FromStructToMap(t *testing.T) { + t.Parallel() + + vTime := time.Now() + + input := EmbeddedPointerSquashWithNestedMapstructure{ + NestedPointerWithMapstructure: &NestedPointerWithMapstructure{ + Vbar: &BasicMapStructure{ + Vunique: "bar", + Vtime: &vTime, + }, + }, + Vunique: "foo", + } + + var result map[string]interface{} + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + expected := map[string]interface{}{ + "vbar": map[string]interface{}{ + "vunique": "bar", + "time": &vTime, + }, + "Vunique": "foo", + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("result should be %#v: got %#v", expected, result) + } +} + +func TestDecode_EmbeddedPointerSquashWithNestedMapstructure_FromMapToStruct(t *testing.T) { + t.Parallel() + + vTime := time.Now() + + input := map[string]interface{}{ + "vbar": map[string]interface{}{ + "vunique": "bar", + "time": &vTime, + }, + "Vunique": "foo", + } + + result := EmbeddedPointerSquashWithNestedMapstructure{ + NestedPointerWithMapstructure: &NestedPointerWithMapstructure{}, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + expected := EmbeddedPointerSquashWithNestedMapstructure{ + NestedPointerWithMapstructure: &NestedPointerWithMapstructure{ + Vbar: &BasicMapStructure{ + Vunique: "bar", + Vtime: &vTime, + }, + }, + Vunique: "foo", + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("result should be %#v: got %#v", expected, result) + } +} + +func TestDecode_EmbeddedSquashConfig(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + "Named": map[string]interface{}{ + "vstring": "baz", + }, + } + + var result EmbeddedAndNamed + config := &DecoderConfig{ + Squash: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } + + if result.Named.Vstring != "baz" { + t.Errorf("Named.vstring value should be 'baz': %#v", result.Named.Vstring) + } +} + +func TestDecodeFrom_EmbeddedSquashConfig(t *testing.T) { + t.Parallel() + + input := EmbeddedAndNamed{ + Basic: Basic{Vstring: "foo"}, + Named: Basic{Vstring: "baz"}, + Vunique: "bar", + } + + result := map[string]interface{}{} + config := &DecoderConfig{ + Squash: true, + Result: &result, + } + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if _, ok := result["Basic"]; ok { + t.Error("basic should not be present in map") + } + + v, ok := result["Vstring"] + if !ok { + t.Error("vstring should be present in map") + } else if !reflect.DeepEqual(v, "foo") { + t.Errorf("vstring value should be 'foo': %#v", v) + } + + v, ok = result["Vunique"] + if !ok { + t.Error("vunique should be present in map") + } else if !reflect.DeepEqual(v, "bar") { + t.Errorf("vunique value should be 'bar': %#v", v) + } + + v, ok = result["Named"] + if !ok { + t.Error("Named should be present in map") + } else { + named := v.(map[string]interface{}) + v, ok := named["Vstring"] + if !ok { + t.Error("Named: vstring should be present in map") + } else if !reflect.DeepEqual(v, "baz") { + t.Errorf("Named: vstring should be 'baz': %#v", v) + } + } +} + +func TestDecodeFrom_EmbeddedSquashConfig_WithTags(t *testing.T) { + t.Parallel() + + var v interface{} + var ok bool + + input := EmbeddedSquash{ + Basic: Basic{ + Vstring: "foo", + }, + Vunique: "bar", + } + + result := map[string]interface{}{} + config := &DecoderConfig{ + Squash: true, + Result: &result, + } + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if _, ok = result["Basic"]; ok { + t.Error("basic should not be present in map") + } + + v, ok = result["Vstring"] + if !ok { + t.Error("vstring should be present in map") + } else if !reflect.DeepEqual(v, "foo") { + t.Errorf("vstring value should be 'foo': %#v", v) + } + + v, ok = result["Vunique"] + if !ok { + t.Error("vunique should be present in map") + } else if !reflect.DeepEqual(v, "bar") { + t.Errorf("vunique value should be 'bar': %#v", v) + } +} + +func TestDecode_SquashOnNonStructType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "InvalidSquashType": 42, + } + + var result SquashOnNonStructType + err := Decode(input, &result) + if err == nil { + t.Fatal("unexpected success decoding invalid squash field type") + } else if !strings.Contains(err.Error(), "unsupported type for squash") { + t.Fatalf("unexpected error message for invalid squash field type: %s", err) + } +} + +func TestDecode_DecodeHook(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vint": "WHAT", + } + + decodeHook := func(from reflect.Kind, to reflect.Kind, v interface{}) (interface{}, error) { + if from == reflect.String && to != reflect.String { + return 5, nil + } + + return v, nil + } + + var result Basic + config := &DecoderConfig{ + DecodeHook: decodeHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.Vint != 5 { + t.Errorf("vint should be 5: %#v", result.Vint) + } +} + +func TestDecode_DecodeHookType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vint": "WHAT", + } + + decodeHook := func(from reflect.Type, to reflect.Type, v interface{}) (interface{}, error) { + if from.Kind() == reflect.String && + to.Kind() != reflect.String { + return 5, nil + } + + return v, nil + } + + var result Basic + config := &DecoderConfig{ + DecodeHook: decodeHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.Vint != 5 { + t.Errorf("vint should be 5: %#v", result.Vint) + } +} + +func TestDecode_Nil(t *testing.T) { + t.Parallel() + + var input interface{} + result := Basic{ + Vstring: "foo", + } + + err := Decode(input, &result) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + if result.Vstring != "foo" { + t.Fatalf("bad: %#v", result.Vstring) + } +} + +func TestDecode_NilInterfaceHook(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "w": "", + } + + decodeHook := func(f, t reflect.Type, v interface{}) (interface{}, error) { + if t.String() == "io.Writer" { + return nil, nil + } + + return v, nil + } + + var result NilInterface + config := &DecoderConfig{ + DecodeHook: decodeHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.W != nil { + t.Errorf("W should be nil: %#v", result.W) + } +} + +func TestDecode_NilPointerHook(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "value": "", + } + + decodeHook := func(f, t reflect.Type, v interface{}) (interface{}, error) { + if typed, ok := v.(string); ok { + if typed == "" { + return nil, nil + } + } + return v, nil + } + + var result NilPointer + config := &DecoderConfig{ + DecodeHook: decodeHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.Value != nil { + t.Errorf("W should be nil: %#v", result.Value) + } +} + +func TestDecode_FuncHook(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "foo": "baz", + } + + decodeHook := func(f, t reflect.Type, v interface{}) (interface{}, error) { + if t.Kind() != reflect.Func { + return v, nil + } + val := v.(string) + return func() string { return val }, nil + } + + var result Func + config := &DecoderConfig{ + DecodeHook: decodeHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.Foo() != "baz" { + t.Errorf("Foo call result should be 'baz': %s", result.Foo()) + } +} + +func TestDecode_NonStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "foo": "bar", + "bar": "baz", + } + + var result map[string]string + err := Decode(input, &result) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + if result["foo"] != "bar" { + t.Fatal("foo is not bar") + } +} + +func TestDecode_StructMatch(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vbar": Basic{ + Vstring: "foo", + }, + } + + var result Nested + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("bad: %#v", result) + } +} + +func TestDecode_TypeConversion(t *testing.T) { + input := map[string]interface{}{ + "IntToFloat": 42, + "IntToUint": 42, + "IntToBool": 1, + "IntToString": 42, + "UintToInt": 42, + "UintToFloat": 42, + "UintToBool": 42, + "UintToString": 42, + "BoolToInt": true, + "BoolToUint": true, + "BoolToFloat": true, + "BoolToString": true, + "FloatToInt": 42.42, + "FloatToUint": 42.42, + "FloatToBool": 42.42, + "FloatToString": 42.42, + "SliceUint8ToString": []uint8("foo"), + "StringToSliceUint8": "foo", + "ArrayUint8ToString": [3]uint8{'f', 'o', 'o'}, + "StringToInt": "42", + "StringToUint": "42", + "StringToBool": "1", + "StringToFloat": "42.42", + "StringToStrSlice": "A", + "StringToIntSlice": "42", + "StringToStrArray": "A", + "StringToIntArray": "42", + "SliceToMap": []interface{}{}, + "MapToSlice": map[string]interface{}{}, + "ArrayToMap": []interface{}{}, + "MapToArray": map[string]interface{}{}, + } + + expectedResultStrict := TypeConversionResult{ + IntToFloat: 42.0, + IntToUint: 42, + UintToInt: 42, + UintToFloat: 42, + BoolToInt: 0, + BoolToUint: 0, + BoolToFloat: 0, + FloatToInt: 42, + FloatToUint: 42, + } + + expectedResultWeak := TypeConversionResult{ + IntToFloat: 42.0, + IntToUint: 42, + IntToBool: true, + IntToString: "42", + UintToInt: 42, + UintToFloat: 42, + UintToBool: true, + UintToString: "42", + BoolToInt: 1, + BoolToUint: 1, + BoolToFloat: 1, + BoolToString: "1", + FloatToInt: 42, + FloatToUint: 42, + FloatToBool: true, + FloatToString: "42.42", + SliceUint8ToString: "foo", + StringToSliceUint8: []byte("foo"), + ArrayUint8ToString: "foo", + StringToInt: 42, + StringToUint: 42, + StringToBool: true, + StringToFloat: 42.42, + StringToStrSlice: []string{"A"}, + StringToIntSlice: []int{42}, + StringToStrArray: [1]string{"A"}, + StringToIntArray: [1]int{42}, + SliceToMap: map[string]interface{}{}, + MapToSlice: []interface{}{}, + ArrayToMap: map[string]interface{}{}, + MapToArray: [1]interface{}{}, + } + + // Test strict type conversion + var resultStrict TypeConversionResult + err := Decode(input, &resultStrict) + if err == nil { + t.Errorf("should return an error") + } + if !reflect.DeepEqual(resultStrict, expectedResultStrict) { + t.Errorf("expected %v, got: %v", expectedResultStrict, resultStrict) + } + + // Test weak type conversion + var decoder *Decoder + var resultWeak TypeConversionResult + + config := &DecoderConfig{ + WeaklyTypedInput: true, + Result: &resultWeak, + } + + decoder, err = NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if !reflect.DeepEqual(resultWeak, expectedResultWeak) { + t.Errorf("expected \n%#v, got: \n%#v", expectedResultWeak, resultWeak) + } +} + +func TestDecoder_ErrorUnused(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "hello", + "foo": "bar", + } + + var result Basic + config := &DecoderConfig{ + ErrorUnused: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err == nil { + t.Fatal("expected error") + } +} + +func TestDecoder_ErrorUnused_NotSetable(t *testing.T) { + t.Parallel() + + // lowercase vsilent is unexported and cannot be set + input := map[string]interface{}{ + "vsilent": "false", + } + + var result Basic + config := &DecoderConfig{ + ErrorUnused: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err == nil { + t.Fatal("expected error") + } +} +func TestDecoder_ErrorUnset(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "hello", + "foo": "bar", + } + + var result Basic + config := &DecoderConfig{ + ErrorUnset: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err == nil { + t.Fatal("expected error") + } +} + +func TestMap(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vother": map[interface{}]interface{}{ + "foo": "foo", + "bar": "bar", + }, + } + + var result Map + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an error: %s", err) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vother == nil { + t.Fatal("vother should not be nil") + } + + if len(result.Vother) != 2 { + t.Error("vother should have two items") + } + + if result.Vother["foo"] != "foo" { + t.Errorf("'foo' key should be foo, got: %#v", result.Vother["foo"]) + } + + if result.Vother["bar"] != "bar" { + t.Errorf("'bar' key should be bar, got: %#v", result.Vother["bar"]) + } +} + +func TestMapMerge(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vother": map[interface{}]interface{}{ + "foo": "foo", + "bar": "bar", + }, + } + + var result Map + result.Vother = map[string]string{"hello": "world"} + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an error: %s", err) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + expected := map[string]string{ + "foo": "foo", + "bar": "bar", + "hello": "world", + } + if !reflect.DeepEqual(result.Vother, expected) { + t.Errorf("bad: %#v", result.Vother) + } +} + +func TestMapOfStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "value": map[string]interface{}{ + "foo": map[string]string{"vstring": "one"}, + "bar": map[string]string{"vstring": "two"}, + }, + } + + var result MapOfStruct + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.Value == nil { + t.Fatal("value should not be nil") + } + + if len(result.Value) != 2 { + t.Error("value should have two items") + } + + if result.Value["foo"].Vstring != "one" { + t.Errorf("foo value should be 'one', got: %s", result.Value["foo"].Vstring) + } + + if result.Value["bar"].Vstring != "two" { + t.Errorf("bar value should be 'two', got: %s", result.Value["bar"].Vstring) + } +} + +func TestNestedType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vbool": true, + }, + } + + var result Nested + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring) + } + + if result.Vbar.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vbar.Vint) + } + + if result.Vbar.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool) + } + + if result.Vbar.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra) + } +} + +func TestNestedTypePointer(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": &map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vbool": true, + }, + } + + var result NestedPointer + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring) + } + + if result.Vbar.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vbar.Vint) + } + + if result.Vbar.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool) + } + + if result.Vbar.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra) + } +} + +// Test for issue #46. +func TestNestedTypeInterface(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": &map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vbool": true, + + "vdata": map[string]interface{}{ + "vstring": "bar", + }, + }, + } + + var result NestedPointer + result.Vbar = new(Basic) + result.Vbar.Vdata = new(Basic) + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring) + } + + if result.Vbar.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vbar.Vint) + } + + if result.Vbar.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool) + } + + if result.Vbar.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra) + } + + if result.Vbar.Vdata.(*Basic).Vstring != "bar" { + t.Errorf("vstring value should be 'bar': %#v", result.Vbar.Vdata.(*Basic).Vstring) + } +} + +func TestSlice(t *testing.T) { + t.Parallel() + + inputStringSlice := map[string]interface{}{ + "vfoo": "foo", + "vbar": []string{"foo", "bar", "baz"}, + } + + inputStringSlicePointer := map[string]interface{}{ + "vfoo": "foo", + "vbar": &[]string{"foo", "bar", "baz"}, + } + + outputStringSlice := &Slice{ + "foo", + []string{"foo", "bar", "baz"}, + } + + testSliceInput(t, inputStringSlice, outputStringSlice) + testSliceInput(t, inputStringSlicePointer, outputStringSlice) +} + +func TestInvalidSlice(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": 42, + } + + result := Slice{} + err := Decode(input, &result) + if err == nil { + t.Errorf("expected failure") + } +} + +func TestSliceOfStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "value": []map[string]interface{}{ + {"vstring": "one"}, + {"vstring": "two"}, + }, + } + + var result SliceOfStruct + err := Decode(input, &result) + if err != nil { + t.Fatalf("got unexpected error: %s", err) + } + + if len(result.Value) != 2 { + t.Fatalf("expected two values, got %d", len(result.Value)) + } + + if result.Value[0].Vstring != "one" { + t.Errorf("first value should be 'one', got: %s", result.Value[0].Vstring) + } + + if result.Value[1].Vstring != "two" { + t.Errorf("second value should be 'two', got: %s", result.Value[1].Vstring) + } +} + +func TestSliceCornerCases(t *testing.T) { + t.Parallel() + + // Input with a map with zero values + input := map[string]interface{}{} + var resultWeak []Basic + + err := WeakDecode(input, &resultWeak) + if err != nil { + t.Fatalf("got unexpected error: %s", err) + } + + if len(resultWeak) != 0 { + t.Errorf("length should be 0") + } + // Input with more values + input = map[string]interface{}{ + "Vstring": "foo", + } + + resultWeak = nil + err = WeakDecode(input, &resultWeak) + if err != nil { + t.Fatalf("got unexpected error: %s", err) + } + + if resultWeak[0].Vstring != "foo" { + t.Errorf("value does not match") + } +} + +func TestSliceToMap(t *testing.T) { + t.Parallel() + + input := []map[string]interface{}{ + { + "foo": "bar", + }, + { + "bar": "baz", + }, + } + + var result map[string]interface{} + err := WeakDecode(input, &result) + if err != nil { + t.Fatalf("got an error: %s", err) + } + + expected := map[string]interface{}{ + "foo": "bar", + "bar": "baz", + } + if !reflect.DeepEqual(result, expected) { + t.Errorf("bad: %#v", result) + } +} + +func TestArray(t *testing.T) { + t.Parallel() + + inputStringArray := map[string]interface{}{ + "vfoo": "foo", + "vbar": [2]string{"foo", "bar"}, + } + + inputStringArrayPointer := map[string]interface{}{ + "vfoo": "foo", + "vbar": &[2]string{"foo", "bar"}, + } + + outputStringArray := &Array{ + "foo", + [2]string{"foo", "bar"}, + } + + testArrayInput(t, inputStringArray, outputStringArray) + testArrayInput(t, inputStringArrayPointer, outputStringArray) +} + +func TestInvalidArray(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": 42, + } + + result := Array{} + err := Decode(input, &result) + if err == nil { + t.Errorf("expected failure") + } +} + +func TestArrayOfStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "value": []map[string]interface{}{ + {"vstring": "one"}, + {"vstring": "two"}, + }, + } + + var result ArrayOfStruct + err := Decode(input, &result) + if err != nil { + t.Fatalf("got unexpected error: %s", err) + } + + if len(result.Value) != 2 { + t.Fatalf("expected two values, got %d", len(result.Value)) + } + + if result.Value[0].Vstring != "one" { + t.Errorf("first value should be 'one', got: %s", result.Value[0].Vstring) + } + + if result.Value[1].Vstring != "two" { + t.Errorf("second value should be 'two', got: %s", result.Value[1].Vstring) + } +} + +func TestArrayToMap(t *testing.T) { + t.Parallel() + + input := []map[string]interface{}{ + { + "foo": "bar", + }, + { + "bar": "baz", + }, + } + + var result map[string]interface{} + err := WeakDecode(input, &result) + if err != nil { + t.Fatalf("got an error: %s", err) + } + + expected := map[string]interface{}{ + "foo": "bar", + "bar": "baz", + } + if !reflect.DeepEqual(result, expected) { + t.Errorf("bad: %#v", result) + } +} + +func TestDecodeTable(t *testing.T) { + t.Parallel() + + // We need to make new types so that we don't get the short-circuit + // copy functionality. We want to test the deep copying functionality. + type BasicCopy Basic + type NestedPointerCopy NestedPointer + type MapCopy Map + + tests := []struct { + name string + in interface{} + target interface{} + out interface{} + wantErr bool + }{ + { + "basic struct input", + &Basic{ + Vstring: "vstring", + Vint: 2, + Vint8: 2, + Vint16: 2, + Vint32: 2, + Vint64: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vextra: "vextra", + vsilent: true, + Vdata: []byte("data"), + }, + &map[string]interface{}{}, + &map[string]interface{}{ + "Vstring": "vstring", + "Vint": 2, + "Vint8": int8(2), + "Vint16": int16(2), + "Vint32": int32(2), + "Vint64": int64(2), + "Vuint": uint(3), + "Vbool": true, + "Vfloat": 4.56, + "Vextra": "vextra", + "Vdata": []byte("data"), + "VjsonInt": 0, + "VjsonUint": uint(0), + "VjsonUint64": uint64(0), + "VjsonFloat": 0.0, + "VjsonNumber": json.Number(""), + }, + false, + }, + { + "embedded struct input", + &Embedded{ + Vunique: "vunique", + Basic: Basic{ + Vstring: "vstring", + Vint: 2, + Vint8: 2, + Vint16: 2, + Vint32: 2, + Vint64: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vextra: "vextra", + vsilent: true, + Vdata: []byte("data"), + }, + }, + &map[string]interface{}{}, + &map[string]interface{}{ + "Vunique": "vunique", + "Basic": map[string]interface{}{ + "Vstring": "vstring", + "Vint": 2, + "Vint8": int8(2), + "Vint16": int16(2), + "Vint32": int32(2), + "Vint64": int64(2), + "Vuint": uint(3), + "Vbool": true, + "Vfloat": 4.56, + "Vextra": "vextra", + "Vdata": []byte("data"), + "VjsonInt": 0, + "VjsonUint": uint(0), + "VjsonUint64": uint64(0), + "VjsonFloat": 0.0, + "VjsonNumber": json.Number(""), + }, + }, + false, + }, + { + "struct => struct", + &Basic{ + Vstring: "vstring", + Vint: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vextra: "vextra", + Vdata: []byte("data"), + vsilent: true, + }, + &BasicCopy{}, + &BasicCopy{ + Vstring: "vstring", + Vint: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vextra: "vextra", + Vdata: []byte("data"), + }, + false, + }, + { + "struct => struct with pointers", + &NestedPointer{ + Vfoo: "hello", + Vbar: nil, + }, + &NestedPointerCopy{}, + &NestedPointerCopy{ + Vfoo: "hello", + }, + false, + }, + { + "basic pointer to non-pointer", + &BasicPointer{ + Vstring: stringPtr("vstring"), + Vint: intPtr(2), + Vuint: uintPtr(3), + Vbool: boolPtr(true), + Vfloat: floatPtr(4.56), + Vdata: interfacePtr([]byte("data")), + }, + &Basic{}, + &Basic{ + Vstring: "vstring", + Vint: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vdata: []byte("data"), + }, + false, + }, + { + "slice non-pointer to pointer", + &Slice{}, + &SlicePointer{}, + &SlicePointer{}, + false, + }, + { + "slice non-pointer to pointer, zero field", + &Slice{}, + &SlicePointer{ + Vbar: &[]string{"yo"}, + }, + &SlicePointer{}, + false, + }, + { + "slice to slice alias", + &Slice{}, + &SliceOfAlias{}, + &SliceOfAlias{}, + false, + }, + { + "nil map to map", + &Map{}, + &MapCopy{}, + &MapCopy{}, + false, + }, + { + "nil map to non-empty map", + &Map{}, + &MapCopy{Vother: map[string]string{"foo": "bar"}}, + &MapCopy{}, + false, + }, + + { + "slice input - should error", + []string{"foo", "bar"}, + &map[string]interface{}{}, + &map[string]interface{}{}, + true, + }, + { + "struct with slice property", + &Slice{ + Vfoo: "vfoo", + Vbar: []string{"foo", "bar"}, + }, + &map[string]interface{}{}, + &map[string]interface{}{ + "Vfoo": "vfoo", + "Vbar": []string{"foo", "bar"}, + }, + false, + }, + { + "struct with empty slice", + &map[string]interface{}{ + "Vbar": []string{}, + }, + &Slice{}, + &Slice{ + Vbar: []string{}, + }, + false, + }, + { + "struct with slice of struct property", + &SliceOfStruct{ + Value: []Basic{ + Basic{ + Vstring: "vstring", + Vint: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vextra: "vextra", + vsilent: true, + Vdata: []byte("data"), + }, + }, + }, + &map[string]interface{}{}, + &map[string]interface{}{ + "Value": []Basic{ + Basic{ + Vstring: "vstring", + Vint: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vextra: "vextra", + vsilent: true, + Vdata: []byte("data"), + }, + }, + }, + false, + }, + { + "struct with map property", + &Map{ + Vfoo: "vfoo", + Vother: map[string]string{"vother": "vother"}, + }, + &map[string]interface{}{}, + &map[string]interface{}{ + "Vfoo": "vfoo", + "Vother": map[string]string{ + "vother": "vother", + }}, + false, + }, + { + "tagged struct", + &Tagged{ + Extra: "extra", + Value: "value", + }, + &map[string]string{}, + &map[string]string{ + "bar": "extra", + "foo": "value", + }, + false, + }, + { + "omit tag struct", + &struct { + Value string `mapstructure:"value"` + Omit string `mapstructure:"-"` + }{ + Value: "value", + Omit: "omit", + }, + &map[string]string{}, + &map[string]string{ + "value": "value", + }, + false, + }, + { + "decode to wrong map type", + &struct { + Value string + }{ + Value: "string", + }, + &map[string]int{}, + &map[string]int{}, + true, + }, + { + "remainder", + map[string]interface{}{ + "A": "hello", + "B": "goodbye", + "C": "yo", + }, + &Remainder{}, + &Remainder{ + A: "hello", + Extra: map[string]interface{}{ + "B": "goodbye", + "C": "yo", + }, + }, + false, + }, + { + "remainder with no extra", + map[string]interface{}{ + "A": "hello", + }, + &Remainder{}, + &Remainder{ + A: "hello", + Extra: nil, + }, + false, + }, + { + "struct with omitempty tag return non-empty values", + &struct { + VisibleField interface{} `mapstructure:"visible"` + OmitField interface{} `mapstructure:"omittable,omitempty"` + }{ + VisibleField: nil, + OmitField: "string", + }, + &map[string]interface{}{}, + &map[string]interface{}{"visible": nil, "omittable": "string"}, + false, + }, + { + "struct with omitempty tag ignore empty values", + &struct { + VisibleField interface{} `mapstructure:"visible"` + OmitField interface{} `mapstructure:"omittable,omitempty"` + }{ + VisibleField: nil, + OmitField: nil, + }, + &map[string]interface{}{}, + &map[string]interface{}{"visible": nil}, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := Decode(tt.in, tt.target); (err != nil) != tt.wantErr { + t.Fatalf("%q: TestMapOutputForStructuredInputs() unexpected error: %s", tt.name, err) + } + + if !reflect.DeepEqual(tt.out, tt.target) { + t.Fatalf("%q: TestMapOutputForStructuredInputs() expected: %#v, got: %#v", tt.name, tt.out, tt.target) + } + }) + } +} + +func TestInvalidType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": 42, + } + + var result Basic + err := Decode(input, &result) + if err == nil { + t.Fatal("error should exist") + } + + derr, ok := err.(*Error) + if !ok { + t.Fatalf("error should be kind of Error, instead: %#v", err) + } + + if derr.Errors[0] != + "'Vstring' expected type 'string', got unconvertible type 'int', value: '42'" { + t.Errorf("got unexpected error: %s", err) + } + + inputNegIntUint := map[string]interface{}{ + "vuint": -42, + } + + err = Decode(inputNegIntUint, &result) + if err == nil { + t.Fatal("error should exist") + } + + derr, ok = err.(*Error) + if !ok { + t.Fatalf("error should be kind of Error, instead: %#v", err) + } + + if derr.Errors[0] != "cannot parse 'Vuint', -42 overflows uint" { + t.Errorf("got unexpected error: %s", err) + } + + inputNegFloatUint := map[string]interface{}{ + "vuint": -42.0, + } + + err = Decode(inputNegFloatUint, &result) + if err == nil { + t.Fatal("error should exist") + } + + derr, ok = err.(*Error) + if !ok { + t.Fatalf("error should be kind of Error, instead: %#v", err) + } + + if derr.Errors[0] != "cannot parse 'Vuint', -42.000000 overflows uint" { + t.Errorf("got unexpected error: %s", err) + } +} + +func TestDecodeMetadata(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": map[string]interface{}{ + "vstring": "foo", + "Vuint": 42, + "vsilent": "false", + "foo": "bar", + }, + "bar": "nil", + } + + var md Metadata + var result Nested + + err := DecodeMetadata(input, &result, &md) + if err != nil { + t.Fatalf("jderr: %s", err.Error()) + } + + expectedKeys := []string{"Vbar", "Vbar.Vstring", "Vbar.Vuint", "Vfoo"} + sort.Strings(md.Keys) + if !reflect.DeepEqual(md.Keys, expectedKeys) { + t.Fatalf("bad keys: %#v", md.Keys) + } + + expectedUnused := []string{"Vbar.foo", "Vbar.vsilent", "bar"} + sort.Strings(md.Unused) + if !reflect.DeepEqual(md.Unused, expectedUnused) { + t.Fatalf("bad unused: %#v", md.Unused) + } +} + +func TestMetadata(t *testing.T) { + t.Parallel() + + type testResult struct { + Vfoo string + Vbar BasicPointer + } + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": map[string]interface{}{ + "vstring": "foo", + "Vuint": 42, + "vsilent": "false", + "foo": "bar", + }, + "bar": "nil", + } + + var md Metadata + var result testResult + config := &DecoderConfig{ + Metadata: &md, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("jderr: %s", err.Error()) + } + + expectedKeys := []string{"Vbar", "Vbar.Vstring", "Vbar.Vuint", "Vfoo"} + sort.Strings(md.Keys) + if !reflect.DeepEqual(md.Keys, expectedKeys) { + t.Fatalf("bad keys: %#v", md.Keys) + } + + expectedUnused := []string{"Vbar.foo", "Vbar.vsilent", "bar"} + sort.Strings(md.Unused) + if !reflect.DeepEqual(md.Unused, expectedUnused) { + t.Fatalf("bad unused: %#v", md.Unused) + } + + expectedUnset := []string{ + "Vbar.Vbool", "Vbar.Vdata", "Vbar.Vextra", "Vbar.Vfloat", "Vbar.Vint", + "Vbar.VjsonFloat", "Vbar.VjsonInt", "Vbar.VjsonNumber"} + sort.Strings(md.Unset) + if !reflect.DeepEqual(md.Unset, expectedUnset) { + t.Fatalf("bad unset: %#v", md.Unset) + } +} + +func TestMetadata_Embedded(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + } + + var md Metadata + var result EmbeddedSquash + config := &DecoderConfig{ + Metadata: &md, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("jderr: %s", err.Error()) + } + + expectedKeys := []string{"Vstring", "Vunique"} + + sort.Strings(md.Keys) + if !reflect.DeepEqual(md.Keys, expectedKeys) { + t.Fatalf("bad keys: %#v", md.Keys) + } + + expectedUnused := []string{} + if !reflect.DeepEqual(md.Unused, expectedUnused) { + t.Fatalf("bad unused: %#v", md.Unused) + } +} + +func TestNonPtrValue(t *testing.T) { + t.Parallel() + + err := Decode(map[string]interface{}{}, Basic{}) + if err == nil { + t.Fatal("error should exist") + } + + if err.Error() != "result must be a pointer" { + t.Errorf("got unexpected error: %s", err) + } +} + +func TestTagged(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "foo": "bar", + "bar": "value", + } + + var result Tagged + err := Decode(input, &result) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if result.Value != "bar" { + t.Errorf("value should be 'bar', got: %#v", result.Value) + } + + if result.Extra != "value" { + t.Errorf("extra should be 'value', got: %#v", result.Extra) + } +} + +func TestWeakDecode(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "foo": "4", + "bar": "value", + } + + var result struct { + Foo int + Bar string + } + + if err := WeakDecode(input, &result); err != nil { + t.Fatalf("jderr: %s", err) + } + if result.Foo != 4 { + t.Fatalf("bad: %#v", result) + } + if result.Bar != "value" { + t.Fatalf("bad: %#v", result) + } +} + +func TestWeakDecodeMetadata(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "foo": "4", + "bar": "value", + "unused": "value", + "unexported": "value", + } + + var md Metadata + var result struct { + Foo int + Bar string + unexported string + } + + if err := WeakDecodeMetadata(input, &result, &md); err != nil { + t.Fatalf("jderr: %s", err) + } + if result.Foo != 4 { + t.Fatalf("bad: %#v", result) + } + if result.Bar != "value" { + t.Fatalf("bad: %#v", result) + } + + expectedKeys := []string{"Bar", "Foo"} + sort.Strings(md.Keys) + if !reflect.DeepEqual(md.Keys, expectedKeys) { + t.Fatalf("bad keys: %#v", md.Keys) + } + + expectedUnused := []string{"unexported", "unused"} + sort.Strings(md.Unused) + if !reflect.DeepEqual(md.Unused, expectedUnused) { + t.Fatalf("bad unused: %#v", md.Unused) + } +} + +func TestDecode_StructTaggedWithOmitempty_OmitEmptyValues(t *testing.T) { + t.Parallel() + + input := &StructWithOmitEmpty{} + + var emptySlice []interface{} + var emptyMap map[string]interface{} + var emptyNested *Nested + expected := &map[string]interface{}{ + "visible-string": "", + "visible-int": 0, + "visible-float": 0.0, + "visible-slice": emptySlice, + "visible-map": emptyMap, + "visible-nested": emptyNested, + } + + actual := &map[string]interface{}{} + Decode(input, actual) + + if !reflect.DeepEqual(actual, expected) { + t.Fatalf("Decode() expected: %#v, got: %#v", expected, actual) + } +} + +func TestDecode_StructTaggedWithOmitempty_KeepNonEmptyValues(t *testing.T) { + t.Parallel() + + input := &StructWithOmitEmpty{ + VisibleStringField: "", + OmitStringField: "string", + VisibleIntField: 0, + OmitIntField: 1, + VisibleFloatField: 0.0, + OmitFloatField: 1.0, + VisibleSliceField: nil, + OmitSliceField: []interface{}{1}, + VisibleMapField: nil, + OmitMapField: map[string]interface{}{"k": "v"}, + NestedField: nil, + OmitNestedField: &Nested{}, + } + + var emptySlice []interface{} + var emptyMap map[string]interface{} + var emptyNested *Nested + expected := &map[string]interface{}{ + "visible-string": "", + "omittable-string": "string", + "visible-int": 0, + "omittable-int": 1, + "visible-float": 0.0, + "omittable-float": 1.0, + "visible-slice": emptySlice, + "omittable-slice": []interface{}{1}, + "visible-map": emptyMap, + "omittable-map": map[string]interface{}{"k": "v"}, + "visible-nested": emptyNested, + "omittable-nested": &Nested{}, + } + + actual := &map[string]interface{}{} + Decode(input, actual) + + if !reflect.DeepEqual(actual, expected) { + t.Fatalf("Decode() expected: %#v, got: %#v", expected, actual) + } +} + +func TestDecode_mapToStruct(t *testing.T) { + type Target struct { + String string + StringPtr *string + } + + expected := Target{ + String: "hello", + } + + var target Target + err := Decode(map[string]interface{}{ + "string": "hello", + "StringPtr": "goodbye", + }, &target) + if err != nil { + t.Fatalf("got error: %s", err) + } + + // Pointers fail reflect test so do those manually + if target.StringPtr == nil || *target.StringPtr != "goodbye" { + t.Fatalf("bad: %#v", target) + } + target.StringPtr = nil + + if !reflect.DeepEqual(target, expected) { + t.Fatalf("bad: %#v", target) + } +} + +func TestDecoder_MatchName(t *testing.T) { + t.Parallel() + + type Target struct { + FirstMatch string `mapstructure:"first_match"` + SecondMatch string + NoMatch string `mapstructure:"no_match"` + } + + input := map[string]interface{}{ + "first_match": "foo", + "SecondMatch": "bar", + "NO_MATCH": "baz", + } + + expected := Target{ + FirstMatch: "foo", + SecondMatch: "bar", + } + + var actual Target + config := &DecoderConfig{ + Result: &actual, + MatchName: func(mapKey, fieldName string) bool { + return mapKey == fieldName + }, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("Decode() expected: %#v, got: %#v", expected, actual) + } +} + +func TestDecoder_IgnoreUntaggedFields(t *testing.T) { + type Input struct { + UntaggedNumber int + TaggedNumber int `mapstructure:"tagged_number"` + UntaggedString string + TaggedString string `mapstructure:"tagged_string"` + } + input := &Input{ + UntaggedNumber: 31, + TaggedNumber: 42, + UntaggedString: "hidden", + TaggedString: "visible", + } + + actual := make(map[string]interface{}) + config := &DecoderConfig{ + Result: &actual, + IgnoreUntaggedFields: true, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + expected := map[string]interface{}{ + "tagged_number": 42, + "tagged_string": "visible", + } + + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("Decode() expected: %#v\ngot: %#v", expected, actual) + } +} + +func testSliceInput(t *testing.T, input map[string]interface{}, expected *Slice) { + var result Slice + err := Decode(input, &result) + if err != nil { + t.Fatalf("got error: %s", err) + } + + if result.Vfoo != expected.Vfoo { + t.Errorf("Vfoo expected '%s', got '%s'", expected.Vfoo, result.Vfoo) + } + + if result.Vbar == nil { + t.Fatalf("Vbar a slice, got '%#v'", result.Vbar) + } + + if len(result.Vbar) != len(expected.Vbar) { + t.Errorf("Vbar length should be %d, got %d", len(expected.Vbar), len(result.Vbar)) + } + + for i, v := range result.Vbar { + if v != expected.Vbar[i] { + t.Errorf( + "Vbar[%d] should be '%#v', got '%#v'", + i, expected.Vbar[i], v) + } + } +} + +func testArrayInput(t *testing.T, input map[string]interface{}, expected *Array) { + var result Array + err := Decode(input, &result) + if err != nil { + t.Fatalf("got error: %s", err) + } + + if result.Vfoo != expected.Vfoo { + t.Errorf("Vfoo expected '%s', got '%s'", expected.Vfoo, result.Vfoo) + } + + if result.Vbar == [2]string{} { + t.Fatalf("Vbar a slice, got '%#v'", result.Vbar) + } + + if len(result.Vbar) != len(expected.Vbar) { + t.Errorf("Vbar length should be %d, got %d", len(expected.Vbar), len(result.Vbar)) + } + + for i, v := range result.Vbar { + if v != expected.Vbar[i] { + t.Errorf( + "Vbar[%d] should be '%#v', got '%#v'", + i, expected.Vbar[i], v) + } + } +} + +func stringPtr(v string) *string { return &v } +func intPtr(v int) *int { return &v } +func uintPtr(v uint) *uint { return &v } +func boolPtr(v bool) *bool { return &v } +func floatPtr(v float64) *float64 { return &v } +func interfacePtr(v interface{}) *interface{} { return &v } diff --git a/pkg/mapstructure/my_decode.go b/pkg/mapstructure/my_decode.go new file mode 100644 index 0000000..6cafe11 --- /dev/null +++ b/pkg/mapstructure/my_decode.go @@ -0,0 +1,26 @@ +package mapstructure + +import "time" + +// DecodeWithTime 支持时间转字符串 +// 支持 +// 1. *Time.time 转 string/*string +// 2. *Time.time 转 uint/uint32/uint64/int/int32/int64,支持带指针 +// 不能用 Time.time 转,它会在上层认为是一个结构体数据而直接转成map,再到hook方法 +func DecodeWithTime(input, output interface{}, layout string) error { + if layout == "" { + layout = time.DateTime + } + config := &DecoderConfig{ + Metadata: nil, + Result: output, + DecodeHook: ComposeDecodeHookFunc(TimeToStringHook(layout), TimeToUnixIntHook()), + } + + decoder, err := NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} diff --git a/pkg/mapstructure/my_decode_hook.go b/pkg/mapstructure/my_decode_hook.go new file mode 100644 index 0000000..4873731 --- /dev/null +++ b/pkg/mapstructure/my_decode_hook.go @@ -0,0 +1,101 @@ +package mapstructure + +import ( + "reflect" + "time" +) + +// TimeToStringHook 时间转字符串 +// 支持 *Time.time 转 string/*string +// 不能用 Time.time 转,它会在上层认为是一个结构体数据而直接转成map,再到hook方法 +func TimeToStringHook(layout string) DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + // 判断目标类型是否为字符串 + var strType string + var isStrPointer *bool // 要转换的目标类型是否为指针字符串 + if t == reflect.TypeOf(strType) { + isStrPointer = new(bool) + } else if t == reflect.TypeOf(&strType) { + isStrPointer = new(bool) + *isStrPointer = true + } + if isStrPointer == nil { + return data, nil + } + + // 判断类型是否为时间 + timeType := time.Time{} + if f != reflect.TypeOf(timeType) && f != reflect.TypeOf(&timeType) { + return data, nil + } + + // 将时间转换为字符串 + var output string + switch v := data.(type) { + case *time.Time: + output = v.Format(layout) + case time.Time: + output = v.Format(layout) + default: + return data, nil + } + + if *isStrPointer { + return &output, nil + } + return output, nil + } +} + +// TimeToUnixIntHook 时间转时间戳 +// 支持 *Time.time 转 uint/uint32/uint64/int/int32/int64,支持带指针 +// 不能用 Time.time 转,它会在上层认为是一个结构体数据而直接转成map,再到hook方法 +func TimeToUnixIntHook() DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + + tkd := t.Kind() + if tkd != reflect.Int && tkd != reflect.Int32 && tkd != reflect.Int64 && + tkd != reflect.Uint && tkd != reflect.Uint32 && tkd != reflect.Uint64 { + return data, nil + } + + // 判断类型是否为时间 + timeType := time.Time{} + if f != reflect.TypeOf(timeType) && f != reflect.TypeOf(&timeType) { + return data, nil + } + + // 将时间转换为字符串 + var output int64 + switch v := data.(type) { + case *time.Time: + output = v.Unix() + case time.Time: + output = v.Unix() + default: + return data, nil + } + switch tkd { + case reflect.Int: + return int(output), nil + case reflect.Int32: + return int32(output), nil + case reflect.Int64: + return output, nil + case reflect.Uint: + return uint(output), nil + case reflect.Uint32: + return uint32(output), nil + case reflect.Uint64: + return uint64(output), nil + default: + return data, nil + } + } +} diff --git a/pkg/mapstructure/my_decode_hook_test.go b/pkg/mapstructure/my_decode_hook_test.go new file mode 100644 index 0000000..2f5687d --- /dev/null +++ b/pkg/mapstructure/my_decode_hook_test.go @@ -0,0 +1,274 @@ +package mapstructure + +import ( + "testing" + "time" +) + +func Test_TimeToStringHook(t *testing.T) { + type Input struct { + Time time.Time + Id int + } + + type InputTPointer struct { + Time *time.Time + Id int + } + + type Output struct { + Time string + Id int + } + + type OutputTPointer struct { + Time *string + Id int + } + now := time.Now() + target := now.Format("2006-01-02 15:04:05") + idValue := 1 + tests := []struct { + input any + output any + name string + layout string + }{ + { + name: "测试Time.time转string", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output{}, + }, + { + name: "测试*Time.time转*string", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: OutputTPointer{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + decoder, err := NewDecoder(&DecoderConfig{ + DecodeHook: TimeToStringHook(tt.layout), + Result: &tt.output, + }) + if err != nil { + t.Errorf("NewDecoder() jderr = %v,want nil", err) + } + + if i, isOk := tt.input.(Input); isOk { + err = decoder.Decode(i) + } + if i, isOk := tt.input.(InputTPointer); isOk { + err = decoder.Decode(&i) + } + if err != nil { + t.Errorf("Decode jderr = %v,want nil", err) + } + //验证测试值 + if output, isOk := tt.output.(OutputTPointer); isOk { + if *output.Time != target { + t.Errorf("Decode output time = %v,want %v", *output.Time, target) + } + if output.Id != idValue { + t.Errorf("Decode output id = %v,want %v", output.Id, idValue) + } + } + if output, isOk := tt.output.(Output); isOk { + if output.Time != target { + t.Errorf("Decode output time = %v,want %v", output.Time, target) + } + if output.Id != idValue { + t.Errorf("Decode output id = %v,want %v", output.Id, idValue) + } + } + }) + } +} + +func Test_TimeToUnixIntHook(t *testing.T) { + type InputTPointer struct { + Time *time.Time + Id int + } + + type Output[T int | *int | int32 | *int32 | int64 | *int64 | uint | *uint] struct { + Time T + Id int + } + + type test struct { + input any + output any + name string + layout string + } + + now := time.Now() + target := now.Unix() + idValue := 1 + tests := []test{ + { + name: "测试Time.time转int", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[int]{}, + }, + { + name: "测试Time.time转*int", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[*int]{}, + }, + { + name: "测试Time.time转int32", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[int32]{}, + }, + { + name: "测试Time.time转*int32", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[*int32]{}, + }, + { + name: "测试Time.time转int64", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[int64]{}, + }, + { + name: "测试Time.time转*int64", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[*int64]{}, + }, + { + name: "测试Time.time转uint", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[uint]{}, + }, + { + name: "测试Time.time转*uint", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[*uint]{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + decoder, err := NewDecoder(&DecoderConfig{ + DecodeHook: TimeToUnixIntHook(), + Result: &tt.output, + }) + if err != nil { + t.Errorf("NewDecoder() jderr = %v,want nil", err) + } + + if i, isOk := tt.input.(InputTPointer); isOk { + err = decoder.Decode(i) + } + if i, isOk := tt.input.(InputTPointer); isOk { + err = decoder.Decode(&i) + } + if err != nil { + t.Errorf("Decode jderr = %v,want nil", err) + } + + //验证测试值 + switch v := tt.output.(type) { + case Output[int]: + if int64(v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[*int]: + if int64(*v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[int32]: + if int64(v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[*int32]: + if int64(*v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[int64]: + if int64(v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[*int64]: + if int64(*v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[uint]: + if int64(v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[*uint]: + if int64(*v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + } + + }) + } +} diff --git a/pkg/response.go b/pkg/response.go new file mode 100644 index 0000000..7df9813 --- /dev/null +++ b/pkg/response.go @@ -0,0 +1,38 @@ +package pkg + +import ( + "encoding/json" + "fmt" + "os" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/log" +) + +func HandleResponse(c *fiber.Ctx, data interface{}) (err error) { + if os.Getenv("env") == "unit_test" { + log.Debug(data) + } + switch data.(type) { + case error: + err = data.(error) + case int, int32, int64, float32, float64, string, bool: + c.Response().SetBody([]byte(fmt.Sprintf("%s", data))) + case []byte: + c.Response().SetBody(data.([]byte)) + default: + dataByte, _ := json.Marshal(data) + c.Response().SetBody(dataByte) + } + return +} + +func SuccessWithPageMsg(c *fiber.Ctx, list interface{}, total int64, page, pageSize int) error { + response := fiber.Map{ + "list": list, + "total": total, + "page": page, + "pageSize": pageSize, + } + return HandleResponse(c, response) +} diff --git a/pkg/validata.go b/pkg/validata.go new file mode 100644 index 0000000..5c282e9 --- /dev/null +++ b/pkg/validata.go @@ -0,0 +1,743 @@ +package pkg + +import ( + "fmt" + "geo/tmpl/errcode" + "reflect" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/go-playground/validator/v10" + "github.com/gofiber/fiber/v2" +) + +const ( + // 验证标签常量 + TagRequired = "required" + TagEmail = "email" + TagMin = "min" + TagMax = "max" + TagLen = "len" + TagNumeric = "numeric" + TagOneof = "oneof" + TagGte = "gte" + TagGt = "gt" + TagLte = "lte" + TagLt = "lt" + TagURL = "url" + TagUUID = "uuid" + TagDatetime = "datetime" + + // 中文标签常量 + //LabelComment = "comment" + //LabelLabel = "label" + LabelZh = "zh" + + // 上下文Key + CtxRequestBody = "request_body" + + // 性能优化常量 + InitialBuilderSize = 64 + MaxCacheSize = 5000 + CacheCleanupInterval = 10 * time.Minute +) + +// ValidatorConfig 验证器配置 +type ValidatorConfig struct { + StatusCode int // 验证失败时的HTTP状态码,默认422 + ErrorHandler func(c *fiber.Ctx, err error) error // 自定义错误处理 + BeforeParse func(c *fiber.Ctx) error // 解析前执行 + AfterParse func(c *fiber.Ctx, req interface{}) error // 解析后执行 + DisableCache bool // 是否禁用缓存 + MaxCacheSize int // 最大缓存数量 + UseChinese bool // 是否使用中文提示 + EnableMetrics bool // 是否启用指标收集 +} + +// CacheStats 缓存统计 +type CacheStats struct { + hitCount int64 + missCount int64 + evictCount int64 + errorCount int64 + accessCount int64 +} + +// Metrics 指标收集 +type Metrics struct { + totalRequests int64 + validationTime int64 + cacheHitRate float64 + mu sync.RWMutex +} + +// fieldInfo 字段信息 +type fieldInfo struct { + label string + index int + typeKind reflect.Kind + accessCount int64 + lastAccess time.Time +} + +// typeInfo 类型信息 +type typeInfo struct { + fields map[string]*fieldInfo + fieldNames []string + mu sync.RWMutex + accessCount int64 + createdAt time.Time + lastAccess time.Time + typeKey string +} + +// ValidatorHelper 验证器助手 +type ValidatorHelper struct { + validate *validator.Validate + config *ValidatorConfig + typeCache sync.Map + cacheStats *CacheStats + errorFnCache map[string]func(string, string) string + errorFnCacheMu sync.RWMutex + pruneLock sync.Mutex + stopCleanup chan struct{} + cleanupOnce sync.Once + metrics *Metrics +} + +var ( + ChineseErrorTemplates = map[string]string{ + TagRequired: "%s不能为空", + TagEmail: "%s格式不正确", + TagMin: "%s不能小于%s", + TagMax: "%s不能大于%s", + TagLen: "%s长度必须为%s位", + TagNumeric: "%s必须是数字", + TagOneof: "%s必须是以下值之一: %s", + TagGte: "%s不能小于%s", + TagGt: "%s必须大于%s", + TagLte: "%s不能大于%s", + TagLt: "%s必须小于%s", + TagURL: "%s必须是有效的URL地址", + TagUUID: "%s必须是有效的UUID", + TagDatetime: "%s日期格式不正确", + } + + defaultErrorTemplates = map[string]string{ + TagRequired: "%s is required", + TagEmail: "invalid email format", + TagMin: "%s must be at least %s", + TagMax: "%s must be at most %s", + TagLen: "%s must be exactly %s characters", + TagNumeric: "%s must be numeric", + TagOneof: "%s must be one of: %s", + TagGte: "%s must be greater than or equal to %s", + TagGt: "%s must be greater than %s", + TagLte: "%s must be less than or equal to %s", + TagLt: "%s must be less than %s", + TagURL: "%s must be a valid URL", + TagUUID: "%s must be a valid UUID", + TagDatetime: "%s invalid datetime format", + } +) + +var ( + // 对象池 + builderPool = sync.Pool{ + New: func() interface{} { + b := &strings.Builder{} + b.Grow(InitialBuilderSize) + return b + }, + } + + // 错误信息切片池 + errorSlicePool = sync.Pool{ + New: func() interface{} { + slice := make([]string, 0, 8) + return &slice + }, + } + + // 字段信息对象池 + fieldInfoPool = sync.Pool{ + New: func() interface{} { + return &fieldInfo{ + accessCount: 0, + lastAccess: time.Now(), + } + }, + } + + // 类型信息对象池 + typeInfoPool = sync.Pool{ + New: func() interface{} { + return &typeInfo{ + fields: make(map[string]*fieldInfo), + fieldNames: make([]string, 0, 8), + } + }, + } +) + +var ( + Vh *ValidatorHelper + once sync.Once +) + +// NewValidatorHelper 初始化验证器助手 +func NewValidatorHelper(config ...*ValidatorConfig) { + once.Do(func() { + v := validator.New() + + // 优化JSON标签获取 + v.RegisterTagNameFunc(func(fld reflect.StructField) string { + return fld.Tag.Get("json") + }) + + // 默认配置 + cfg := &ValidatorConfig{ + StatusCode: fiber.StatusUnprocessableEntity, + ErrorHandler: defaultErrorHandler, + MaxCacheSize: MaxCacheSize, + UseChinese: true, + EnableMetrics: false, + } + + if len(config) > 0 && config[0] != nil { + if config[0].StatusCode != 0 { + cfg.StatusCode = config[0].StatusCode + } + if config[0].ErrorHandler != nil { + cfg.ErrorHandler = config[0].ErrorHandler + } + if config[0].MaxCacheSize > 0 { + cfg.MaxCacheSize = config[0].MaxCacheSize + } + cfg.DisableCache = config[0].DisableCache + cfg.BeforeParse = config[0].BeforeParse + cfg.AfterParse = config[0].AfterParse + cfg.UseChinese = config[0].UseChinese + cfg.EnableMetrics = config[0].EnableMetrics + } + + // 预编译错误函数 + errorFnCache := make(map[string]func(string, string) string) + templates := ChineseErrorTemplates + if !cfg.UseChinese { + templates = defaultErrorTemplates + } + + for tag, tmpl := range templates { + t := tmpl // 捕获变量 + errorFnCache[tag] = func(field, param string) string { + if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 { + return fmt.Sprintf(t, field, param) + } + return fmt.Sprintf(t, field) + } + } + + Vh = &ValidatorHelper{ + validate: v, + config: cfg, + cacheStats: &CacheStats{}, + errorFnCache: errorFnCache, + errorFnCacheMu: sync.RWMutex{}, + stopCleanup: make(chan struct{}), + metrics: &Metrics{}, + } + + // 启动定期清理 + if !cfg.DisableCache { + go Vh.periodicCleanup() + } + }) +} + +// ParseAndValidate 解析并验证请求 +func ParseAndValidate(c *fiber.Ctx, req interface{}) error { + if Vh == nil { + NewValidatorHelper() + } + + if Vh.config.EnableMetrics { + atomic.AddInt64(&Vh.metrics.totalRequests, 1) + defer Vh.recordValidationTime(time.Now()) + } + + // 执行解析前钩子 + if Vh.config.BeforeParse != nil { + if err := Vh.config.BeforeParse(c); err != nil { + return err + } + } + + // 解析请求体 + err := c.BodyParser(req) + if err != nil { + return errcode.ParamErr("请求格式错误:" + err.Error()) + } + + // 执行解析后钩子 + if Vh.config.AfterParse != nil { + if err = Vh.config.AfterParse(c, req); err != nil { + return errcode.ParamErr(err.Error()) + } + } + + // 验证数据 + err = Vh.validate.Struct(req) + if err != nil { + c.Locals(CtxRequestBody, req) + + if !Vh.config.DisableCache { + t := reflect.TypeOf(req) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() == reflect.Struct { + Vh.safeGetOrCreateTypeInfo(t) + } + } + + return Vh.config.ErrorHandler(c, err) + } + + return nil +} + +// Validate 直接验证结构体 +func Validate(req interface{}) error { + if Vh == nil { + NewValidatorHelper() + } + + if err := Vh.validate.Struct(req); err != nil { + return Vh.wrapValidationError(err, req) + } + return nil +} + +// 默认错误处理 +func defaultErrorHandler(c *fiber.Ctx, err error) error { + validationErrors, ok := err.(validator.ValidationErrors) + if !ok { + return errcode.SystemError + } + + if len(validationErrors) == 0 { + return nil + } + + // 快速路径:单个错误 + if len(validationErrors) == 1 { + e := validationErrors[0] + msg := Vh.safeGetErrorMessage(c, e) + return errcode.ParamErr(msg) + } + + // 从对象池获取builder + builder := builderPool.Get().(*strings.Builder) + builder.Reset() + defer builderPool.Put(builder) + + req := c.Locals(CtxRequestBody) + + for i, e := range validationErrors { + if i > 0 { + builder.WriteByte('\n') + } + builder.WriteString(Vh.safeGetErrorMessageWithReq(req, e)) + } + + return errcode.ParamErr(builder.String()) +} + +// 包装验证错误 +func (vh *ValidatorHelper) wrapValidationError(err error, req interface{}) error { + validationErrors, ok := err.(validator.ValidationErrors) + if !ok { + return err + } + + if len(validationErrors) == 0 { + return nil + } + + // 构建错误消息 + builder := builderPool.Get().(*strings.Builder) + builder.Reset() + defer builderPool.Put(builder) + + for i, e := range validationErrors { + if i > 0 { + builder.WriteByte('\n') + } + builder.WriteString(vh.safeGetErrorMessageWithReq(req, e)) + } + + return errcode.ParamErr(builder.String()) +} + +// 安全获取错误消息 +func (vh *ValidatorHelper) safeGetErrorMessage(c *fiber.Ctx, e validator.FieldError) string { + req := c.Locals(CtxRequestBody) + return vh.safeGetErrorMessageWithReq(req, e) +} + +// 安全获取错误消息(带请求体) +func (vh *ValidatorHelper) safeGetErrorMessageWithReq(req interface{}, e validator.FieldError) string { + if req == nil { + return vh.safeFormatFieldError(e, nil) + } + + t := reflect.TypeOf(req) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return vh.safeFormatFieldError(e, nil) + } + + // 安全获取类型信息 - 这里需要先声明变量 + var typeInfoObj *typeInfo // 这里声明变量 + if !vh.config.DisableCache { + if cached, ok := vh.typeCache.Load(t); ok { + typeInfoObj = cached.(*typeInfo) + atomic.AddInt64(&typeInfoObj.accessCount, 1) + } + } + + return vh.safeFormatFieldError(e, typeInfoObj) +} + +// 安全格式化字段错误 +func (vh *ValidatorHelper) safeFormatFieldError(e validator.FieldError, typeInfo *typeInfo) string { + structField := e.StructField() + fieldName := e.Field() + + // 获取字段标签 + var label string + if typeInfo != nil { + typeInfo.mu.RLock() + if info, ok := typeInfo.fields[structField]; ok { + label = info.label + atomic.AddInt64(&info.accessCount, 1) + } + typeInfo.mu.RUnlock() + } + + // 如果没有标签,返回默认消息 + if label == "" { + return vh.safeGetDefaultErrorMessage(fieldName, e) + } + + // 使用预编译的错误函数生成消息 + vh.errorFnCacheMu.RLock() + fn, ok := vh.errorFnCache[e.Tag()] + vh.errorFnCacheMu.RUnlock() + + if ok { + return fn(label, e.Param()) + } + + return label + "格式不正确" +} + +// 安全获取默认错误消息 +func (vh *ValidatorHelper) safeGetDefaultErrorMessage(field string, e validator.FieldError) string { + vh.errorFnCacheMu.RLock() + defer vh.errorFnCacheMu.RUnlock() + + if fn, ok := vh.errorFnCache[e.Tag()]; ok { + return fn(field, e.Param()) + } + + return field + "验证失败" +} + +// safeGetOrCreateTypeInfo 安全地获取或创建类型信息 +func (vh *ValidatorHelper) safeGetOrCreateTypeInfo(t reflect.Type) *typeInfo { + if vh.config.DisableCache || t == nil || t.Kind() != reflect.Struct { + return nil + } + + // 首先尝试从缓存读取 + if cached, ok := vh.typeCache.Load(t); ok { + info := cached.(*typeInfo) + atomic.AddInt64(&info.accessCount, 1) + atomic.AddInt64(&vh.cacheStats.hitCount, 1) + info.lastAccess = time.Now() + return info + } + + atomic.AddInt64(&vh.cacheStats.missCount, 1) + + // 从对象池获取typeInfo + info := typeInfoPool.Get().(*typeInfo) + + // 重置并初始化 + info.mu.Lock() + + // 清空现有map + for k := range info.fields { + delete(info.fields, k) + } + info.fieldNames = info.fieldNames[:0] + + info.accessCount = 0 + info.createdAt = time.Now() + info.lastAccess = time.Now() + info.typeKey = t.String() + + // 预计算所有字段信息 + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + // 获取标签 + //label := field.Tag.Get(LabelComment) + //if label == "" { + // label = field.Tag.Get(LabelLabel) + //} + //if label == "" { + // label = field.Tag.Get(LabelZh) + //} + label := field.Tag.Get(LabelZh) + // 从对象池获取或创建字段信息 + fieldInfo := fieldInfoPool.Get().(*fieldInfo) + fieldInfo.label = label + fieldInfo.index = i + fieldInfo.typeKind = field.Type.Kind() + fieldInfo.accessCount = 0 + fieldInfo.lastAccess = time.Now() + + info.fields[field.Name] = fieldInfo + info.fieldNames = append(info.fieldNames, field.Name) + } + + info.mu.Unlock() + + // 使用原子操作确保线程安全的存储 + if existing, loaded := vh.typeCache.LoadOrStore(t, info); loaded { + // 如果已经有其他goroutine存储了,使用已有的并回收新创建的 + info.mu.Lock() + for _, fieldInfo := range info.fields { + fieldInfoPool.Put(fieldInfo) + } + info.mu.Unlock() + typeInfoPool.Put(info) + + existingInfo := existing.(*typeInfo) + atomic.AddInt64(&existingInfo.accessCount, 1) + return existingInfo + } + + return info +} + +// 定期清理缓存 +func (vh *ValidatorHelper) periodicCleanup() { + ticker := time.NewTicker(CacheCleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + vh.safeCleanupCache() + case <-vh.stopCleanup: + return + } + } +} + +// safeCleanupCache 安全清理缓存 +func (vh *ValidatorHelper) safeCleanupCache() { + vh.pruneLock.Lock() + defer vh.pruneLock.Unlock() + + var keysToDelete []interface{} + now := time.Now() + + vh.typeCache.Range(func(key, value interface{}) bool { + info, ok := value.(*typeInfo) + if !ok { + return true + } + + // 检查是否需要清理 + accessCount := atomic.LoadInt64(&info.accessCount) + age := now.Sub(info.createdAt) + idleTime := now.Sub(info.lastAccess) + + // 清理条件: + // 1. 很少访问的缓存(访问次数 < 10) + // 2. 空闲时间超过30分钟 + // 3. 缓存年龄超过1小时且访问次数较少 + if (accessCount < 10 && idleTime > 30*time.Minute) || + (age > 1*time.Hour && accessCount < 100) { + keysToDelete = append(keysToDelete, key) + atomic.AddInt64(&vh.cacheStats.evictCount, 1) + } + + return true + }) + + // 删除选中的缓存 + for _, key := range keysToDelete { + if val, ok := vh.typeCache.Load(key); ok { + if info, ok := val.(*typeInfo); ok { + // 安全回收字段信息 + info.mu.Lock() + for _, fieldInfo := range info.fields { + fieldInfoPool.Put(fieldInfo) + } + info.mu.Unlock() + typeInfoPool.Put(info) + } + vh.typeCache.Delete(key) + } + } +} + +// ==================== 指标收集 ==================== + +func (vh *ValidatorHelper) recordValidationTime(start time.Time) { + if !vh.config.EnableMetrics { + return + } + duration := time.Since(start).Nanoseconds() + atomic.AddInt64(&vh.metrics.validationTime, duration) +} + +// GetMetrics 获取性能指标 +func (vh *ValidatorHelper) GetMetrics() map[string]interface{} { + if !vh.config.EnableMetrics { + return nil + } + + vh.metrics.mu.RLock() + defer vh.metrics.mu.RUnlock() + + hitCount := atomic.LoadInt64(&vh.cacheStats.hitCount) + missCount := atomic.LoadInt64(&vh.cacheStats.missCount) + totalRequests := hitCount + missCount + + var hitRate float64 + if totalRequests > 0 { + hitRate = float64(hitCount) / float64(totalRequests) * 100 + } + + var cacheSize int64 + vh.typeCache.Range(func(_, _ interface{}) bool { + cacheSize++ + return true + }) + + return map[string]interface{}{ + "cache_hit_count": hitCount, + "cache_miss_count": missCount, + "cache_evict_count": atomic.LoadInt64(&vh.cacheStats.evictCount), + "cache_hit_rate": fmt.Sprintf("%.2f%%", hitRate), + "cache_size": cacheSize, + "error_count": atomic.LoadInt64(&vh.cacheStats.errorCount), + "total_requests": atomic.LoadInt64(&vh.metrics.totalRequests), + "avg_validation_time_ns": atomic.LoadInt64(&vh.metrics.validationTime) / + max(atomic.LoadInt64(&vh.metrics.totalRequests), 1), + } +} + +// RegisterValidation 注册自定义验证规则 +func (vh *ValidatorHelper) RegisterValidation(tag string, fn validator.Func, callValidationEvenIfNull ...bool) error { + return vh.validate.RegisterValidation(tag, fn, callValidationEvenIfNull...) +} + +// RegisterTranslation 注册自定义翻译 +func (vh *ValidatorHelper) RegisterTranslation(tag string, template string) { + vh.errorFnCacheMu.Lock() + defer vh.errorFnCacheMu.Unlock() + + t := template + vh.errorFnCache[tag] = func(field, param string) string { + if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 { + return fmt.Sprintf(t, field, param) + } + return fmt.Sprintf(t, field) + } +} + +// Stop 停止后台清理任务 +func (vh *ValidatorHelper) Stop() { + close(vh.stopCleanup) +} + +// Reset 重置验证器状态 +func (vh *ValidatorHelper) Reset() { + vh.ClearCache() + atomic.StoreInt64(&vh.cacheStats.hitCount, 0) + atomic.StoreInt64(&vh.cacheStats.missCount, 0) + atomic.StoreInt64(&vh.cacheStats.evictCount, 0) + atomic.StoreInt64(&vh.cacheStats.errorCount, 0) + atomic.StoreInt64(&vh.metrics.totalRequests, 0) + atomic.StoreInt64(&vh.metrics.validationTime, 0) +} + +// ClearCache 清理所有缓存 +func (vh *ValidatorHelper) ClearCache() { + vh.pruneLock.Lock() + defer vh.pruneLock.Unlock() + + // 安全回收所有缓存 + vh.typeCache.Range(func(key, value interface{}) bool { + if info, ok := value.(*typeInfo); ok { + info.mu.Lock() + for _, fieldInfo := range info.fields { + fieldInfoPool.Put(fieldInfo) + } + info.mu.Unlock() + typeInfoPool.Put(info) + } + vh.typeCache.Delete(key) + return true + }) +} + +// SetLanguage 设置语言 +func (vh *ValidatorHelper) SetLanguage(useChinese bool) { + vh.config.UseChinese = useChinese + + templates := defaultErrorTemplates + if useChinese { + templates = ChineseErrorTemplates + } + + vh.errorFnCacheMu.Lock() + defer vh.errorFnCacheMu.Unlock() + + for tag, tmpl := range templates { + t := tmpl + vh.errorFnCache[tag] = func(field, param string) string { + if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 { + return fmt.Sprintf(t, field, param) + } + return fmt.Sprintf(t, field) + } + } +} + +func max(a, b int64) int64 { + if a > b { + return a + } + return b +} + +func GetErr(tag, field, param string) string { + t := ChineseErrorTemplates[tag] + if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 { + return fmt.Sprintf(t, field, param) + } + return fmt.Sprintf(t, field) +} diff --git a/pkg/wx.go b/pkg/wx.go new file mode 100644 index 0000000..4a7ffd8 --- /dev/null +++ b/pkg/wx.go @@ -0,0 +1,195 @@ +package pkg + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "os" + "sync" + "time" + + "net/http" + + "errors" + + "github.com/redis/go-redis/v9" +) + +type WeChatLoginResponse struct { + OpenID string `json:"openid"` // 用户唯一标识 + SessionKey string `json:"session_key"` // 会话密钥 + UnionID string `json:"unionid"` // 用户在开放平台的唯一标识(如果绑定了开放平台才有) + Errcode int `json:"errcode"` // 错误码,0为成功 + Errmsg string `json:"errmsg"` // 错误信息 +} + +func GetOpenID(appID, appSecret, code string) (openid string, err error) { + if os.Getenv("env") == "unit_test" { + return "test_123456", nil + } + // 1. 构建请求微信接口的 URL + url := fmt.Sprintf("https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code", + appID, appSecret, code) + + // 2. 创建 HTTP 客户端(设置超时,避免阻塞) + client := &http.Client{ + Timeout: 5 * time.Second, + // 在某些网络受限的环境(如本地测试跳过证书验证,生产环境建议去掉) + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: false}, // 生产环境建议设为 false + }, + } + + // 3. 发起 GET 请求 + resp, err := client.Get(url) + if err != nil { + return "", fmt.Errorf("请求微信服务器失败: %w", err) + } + defer resp.Body.Close() + + // 4. 读取返回的 Body + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("读取微信响应失败: %w", err) + } + + // 5. 解析 JSON 数据 + var wechatResp WeChatLoginResponse + err = json.Unmarshal(body, &wechatResp) + if err != nil { + return "", fmt.Errorf("解析微信响应 JSON 失败: %s, 原始数据: %s", err.Error(), string(body)) + } + + // 6. 检查微信接口返回的错误码 + if wechatResp.Errcode != 0 { + // 这里可以根据不同的错误码做特殊处理,例如 code 无效、过期等 + return "", fmt.Errorf("微信接口返回错误: code=%d, msg=%s", wechatResp.Errcode, wechatResp.Errmsg) + } + + // 7. 检查 OpenID 是否为空(理论上不会,但防御性编程) + if wechatResp.OpenID == "" { + return "", errors.New("微信返回的 OpenID 为空") + } + + // 8. 返回 OpenID + return wechatResp.OpenID, nil +} + +type AccessTokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + Errcode int `json:"errcode"` + Errmsg string `json:"errmsg"` +} + +var ( + tokenMutex sync.Mutex + cacheKey = "wx:access_token" +) + +// GetAccessToken 获取 access_token,带本地缓存 +func GetAccessToken(ctx context.Context, appID, appSecret string, rdb *redis.Client) (string, error) { + if rdb == nil { + return "", errors.New("缓存工具未提供") + } + + cacheToken := rdb.Get(ctx, cacheKey).Val() + if cacheToken != "" { + return cacheToken, nil + } + tokenMutex.Lock() + defer tokenMutex.Unlock() + + // 请求微信接口获取新的 access_token + url := fmt.Sprintf("https://api.weixin.qq.com/cgi-bin/token?grant_type=client_credential&appid=%s&secret=%s", appID, appSecret) + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Get(url) + if err != nil { + return "", fmt.Errorf("请求 access_token 失败: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var tokenRes AccessTokenResponse + if err := json.Unmarshal(body, &tokenRes); err != nil { + return "", fmt.Errorf("解析 access_token 响应失败: %w", err) + } + + if tokenRes.Errcode != 0 { + return "", fmt.Errorf("获取 access_token 失败: code=%d, msg=%s", tokenRes.Errcode, tokenRes.Errmsg) + } + + // 缓存 token,提前5分钟过期,避免边界情况 + + rdb.Set(ctx, cacheKey, tokenRes.AccessToken, time.Duration(tokenRes.ExpiresIn-300)*time.Second) + return tokenRes.AccessToken, nil +} + +// PhoneInfo 定义手机号信息的结构体,与微信官方文档对齐 [citation:3][citation:8] +type PhoneInfo struct { + PhoneNumber string `json:"phoneNumber"` // 用户绑定的手机号(国外手机号会有区号) + PurePhoneNumber string `json:"purePhoneNumber"` // 没有区号的手机号 + CountryCode string `json:"countryCode"` // 区号 + Watermark struct { + Timestamp int64 `json:"timestamp"` + Appid string `json:"appid"` + } `json:"watermark"` +} + +// PhoneInfoResponse 定义微信接口返回的完整结构 +type PhoneInfoResponse struct { + Errcode int `json:"errcode"` + Errmsg string `json:"errmsg"` + PhoneInfo PhoneInfo `json:"phone_info"` +} + +// GetPhoneNumber 通过手机号 code 获取用户手机号 +// 参数: +// - appID: 小程序的 AppID +// - appSecret: 小程序的 AppSecret +// - phoneCode: 前端通过 getPhoneNumber 获取的 code +// +// 返回: +// - *PhoneInfo: 手机号信息 +// - error: 错误信息 +func GetPhoneNumber(ctx context.Context, appID, appSecret, phoneCode string, rdb *redis.Client) (*PhoneInfo, error) { + // 1. 获取 access_token + accessToken, err := GetAccessToken(ctx, appID, appSecret, rdb) + if err != nil { + return nil, fmt.Errorf("获取 access_token 失败: %w", err) + } + + // 2. 调用微信接口换取手机号 [citation:8] + url := fmt.Sprintf("https://api.weixin.qq.com/wxa/business/getuserphonenumber?access_token=%s", accessToken) + + // 构建请求体 + requestBody := map[string]string{ + "code": phoneCode, + } + jsonBody, _ := json.Marshal(requestBody) + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Post(url, "application/json", bytes.NewReader(jsonBody)) + if err != nil { + return nil, fmt.Errorf("请求手机号接口失败: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + + var phoneResp PhoneInfoResponse + if err := json.Unmarshal(body, &phoneResp); err != nil { + return nil, fmt.Errorf("解析手机号响应失败: %s", string(body)) + } + + // 3. 检查微信接口返回的错误码 [citation:8] + if phoneResp.Errcode != 0 { + return nil, fmt.Errorf("微信接口返回错误: code=%d, msg=%s", phoneResp.Errcode, phoneResp.Errmsg) + } + + return &phoneResp.PhoneInfo, nil +} diff --git a/tmpl/dataTemp/queryTempl.go b/tmpl/dataTemp/queryTempl.go new file mode 100644 index 0000000..c871f3b --- /dev/null +++ b/tmpl/dataTemp/queryTempl.go @@ -0,0 +1,289 @@ +package dataTemp + +import ( + "context" + "database/sql" + "fmt" + "geo/tmpl/errcode" + "geo/utils" + "reflect" + + "github.com/go-kratos/kratos/v2/log" + "gorm.io/gorm" + + "xorm.io/builder" +) + +type PrimaryKey struct { + Id int `json:"id"` +} + +type GormDb struct { + Client *gorm.DB +} +type contextTxKey struct{} + +func (d *Db) DB(ctx context.Context) *gorm.DB { + tx, ok := ctx.Value(contextTxKey{}).(*gorm.DB) + if ok { + return tx + } + return d.Db.Client +} + +func (t *Db) ExecTx(ctx context.Context, f func(ctx context.Context) error) error { + return t.Db.Client.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + ctx = context.WithValue(ctx, contextTxKey{}, tx) + return f(ctx) + }) +} + +type Db struct { + Db *GormDb + Log *log.Helper +} + +type DataTemp struct { + Db *gorm.DB + ModelType reflect.Type // 改为存储类型而不是实例 + modelName string // 可选的表名缓存 +} + +func NewDataTemp(db *utils.Db, model interface{}) *DataTemp { + // 获取模型的类型 + t := reflect.TypeOf(model) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + return &DataTemp{ + Db: db.Client, + ModelType: t, + } +} + +func (k DataTemp) modelInstance() interface{} { + return reflect.New(k.ModelType).Interface() +} + +func (k DataTemp) GetById(id int32) (data map[string]interface{}, err error) { + err = k.Db.Model(k.modelInstance()).Where("id = ?", id).Find(&data).Error + if data == nil { + err = sql.ErrNoRows + } + return +} + +func (k DataTemp) GetByStruct(ctx context.Context, search interface{}, data interface{}, orderBy string) (err error) { + + err = k.Db.Model(k.modelInstance()).WithContext(ctx).Where(search).Find(&data).Error + + if data == nil { + err = sql.ErrNoRows + } + return +} + +func (k DataTemp) SaveByStruct(search interface{}, data interface{}) (err error) { + err = k.Db.Model(k.modelInstance()).Where(search).Save(&data).Error + if data == nil { + err = sql.ErrNoRows + } + return +} + +func (k DataTemp) Add(ctx context.Context, data interface{}) (err error) { + m := k.modelInstance() + if err = k.Db.Model(m).WithContext(ctx).Create(data).Error; err != nil { + return errcode.SqlErr(err) + } + return +} + +func (k DataTemp) AddWithData(data interface{}) (interface{}, error) { + result := k.Db.Model(k.modelInstance()).Create(data) + if result.Error != nil { + return data, result.Error + } + return data, nil +} + +func (k DataTemp) GetList(cond *builder.Cond, pageBoIn *ReqPageBo) (list []map[string]interface{}, pageBoOut *RespPageBo, err error) { + var ( + query, _ = builder.ToBoundSQL(*cond) + model = k.Db.Model(k.modelInstance()).Where(query) + total int64 + ) + model.Count(&total) + pageBoOut = pageBoOut.SetDataByReq(total, pageBoIn) + model.Limit(pageBoIn.GetSize()).Offset(pageBoIn.GetOffset()).Order("updated_at desc").Find(&list) + return +} + +func (k DataTemp) GetRange(ctx context.Context, cond *builder.Cond) (list []map[string]interface{}, err error) { + var ( + query, _ = builder.ToBoundSQL(*cond) + model = k.Db.Model(k.modelInstance()).Where(query) + ) + err = model.WithContext(ctx).Find(&list).Error + return list, err +} + +func (k DataTemp) GetRangeToMapStruct(ctx context.Context, cond *builder.Cond, data interface{}) (err error) { + var ( + query, _ = builder.ToBoundSQL(*cond) + model = k.Db.Model(k.modelInstance()).Where(query) + ) + err = model.WithContext(ctx).Find(data).Error + return err +} + +func (k DataTemp) GetOneBySearch(cond *builder.Cond) (data map[string]interface{}, err error) { + query, _ := builder.ToBoundSQL(*cond) + if err = k.Db.Model(k.modelInstance()).Where(query).Limit(1).Find(&data).Error; err != nil { + return nil, errcode.SqlErr(err) + } + + return +} + +func (k DataTemp) Exist(ctx context.Context, cond *builder.Cond) (bool, error) { + var data map[string]interface{} + query, _ := builder.ToBoundSQL(*cond) + err := k.Db.WithContext(ctx).Model(k.modelInstance()).Where(query).Limit(1).Find(&data).Error + if err != nil || data != nil { + return true, err + } + return false, nil +} + +func (k DataTemp) GetOneBySearchStruct(ctx context.Context, cond *builder.Cond, data interface{}) (err error) { + query, _ := builder.ToBoundSQL(*cond) + if err = k.Db.Model(k.modelInstance()).WithContext(ctx).Where(query).Limit(1).Find(&data).Error; err != nil { + return errcode.SqlErr(err) + } + + return +} + +func (k DataTemp) GetListToStruct(ctx context.Context, cond *builder.Cond, pageBoIn *ReqPageBo, result interface{}, orderBy string) (pageBoOut *RespPageBo, err error) { + // 参数验证 + if result == nil { + return nil, fmt.Errorf("result cannot be nil") + } + + val := reflect.ValueOf(result) + if val.Kind() != reflect.Ptr { + return nil, fmt.Errorf("result must be a pointer") + } + + elem := val.Elem() + if elem.Kind() != reflect.Slice { + return nil, fmt.Errorf("result must be a pointer to slice") + } + + // 构建基础查询 + query, _ := builder.ToBoundSQL(*cond) + + // 预编译 SQL 以提高性能 + // 使用 Table 指定表名,避免 GORM 的反射开销 + + db := k.Db.WithContext(ctx).Model(k.modelInstance()).Where(query) + + // 获取总数(使用单独的计数查询,避免缓存影响) + var total int64 + countDb := db + if pageBoIn != nil { + if err = countDb.Count(&total).Error; err != nil { + return nil, err + } + } + + // 初始化分页响应 + pageBoOut = &RespPageBo{} + pageBoOut = pageBoOut.SetDataByReq(total, pageBoIn) + + // 如果没有数据,直接返回空切片 + if total == 0 && pageBoIn != nil { + elem.Set(reflect.MakeSlice(elem.Type(), 0, 0)) + return pageBoOut, nil + } + + // 设置排序(使用索引字段提高性能) + if orderBy == "" { + orderBy = "updated_at desc" + } + + // 应用分页和排序,执行查询 + // 使用 Select 指定字段,避免查询所有字段(如果需要优化) + baseQuery := db + if pageBoIn != nil { + baseQuery = db.Limit(pageBoIn.GetSize()).Offset(pageBoIn.GetOffset()).Order(orderBy) + } + if err = baseQuery. + Order(orderBy). + Find(result).Error; err != nil { + return nil, err + } + + return pageBoOut, nil +} + +func (k DataTemp) UpdateByKey(ctx context.Context, key string, id interface{}, data interface{}) (err error) { + if err = k.Db.WithContext(ctx).Model(k.modelInstance()).Where(fmt.Sprintf("%s = ?", key), id).Updates(data).Error; err != nil { + return errcode.SqlErr(err) + } + return +} + +func (k DataTemp) UpdateByCond(ctx context.Context, cond *builder.Cond, data interface{}) (err error) { + var ( + query, _ = builder.ToBoundSQL(*cond) + model = k.Db.Model(k.modelInstance()).Where(query) + ) + err = model.WithContext(ctx).Updates(data).Error + return err +} + +func (k DataTemp) UpdateColumnByCond(ctx context.Context, cond *builder.Cond, column string, data interface{}) (err error) { + var ( + query, _ = builder.ToBoundSQL(*cond) + model = k.Db.Model(k.modelInstance()).Where(query) + ) + err = model.WithContext(ctx).Update(column, data).Error + return err +} + +func (k DataTemp) GetByKey(ctx context.Context, key string, value interface{}, data interface{}) (err error) { + if err = k.Db.WithContext(ctx).Model(k.modelInstance()).Where(fmt.Sprintf("%s = ?", key), value).Find(data).Error; err != nil { + return errcode.SqlErr(err) + } + return +} + +func (k DataTemp) DeleteByKey(ctx context.Context, key string, value int64) error { + result := k.Db.WithContext(ctx).Model(k.modelInstance()).Where(fmt.Sprintf("%s = ?", key), value). + Update("deleted_at", gorm.Expr("CURRENT_TIMESTAMP")) + + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return errcode.NotFound("不存在或已被删除") + } + return nil +} + +func (k DataTemp) CountByCond(ctx context.Context, cond *builder.Cond) (int64, error) { + var ( + count int64 + query, _ = builder.ToBoundSQL(*cond) + model = k.Db.Model(k.modelInstance()).Where(query) + ) + err := model.WithContext(ctx).Count(&count).Error + if err != nil { + return 0, err + } + + return count, err +} diff --git a/tmpl/dataTemp/req_page.go b/tmpl/dataTemp/req_page.go new file mode 100644 index 0000000..87f17a2 --- /dev/null +++ b/tmpl/dataTemp/req_page.go @@ -0,0 +1,33 @@ +package dataTemp + +// ReqPageBo 分页请求实体 +type ReqPageBo struct { + Page int //页码,从第1页开始 + Limit int //分页大小 +} + +// GetOffset 获取便宜量 +// 确保 dataTemp/page.go 中有这些方法 +func (p *ReqPageBo) GetSize() int { + if p == nil { + return 10 // 默认每页10条 + } + if p.Limit <= 0 { + return 10 + } + return p.Limit +} + +func (p *ReqPageBo) GetOffset() int { + if p == nil { + return 0 + } + return (p.GetPage() - 1) * p.GetSize() +} + +func (p *ReqPageBo) GetPage() int { + if p == nil || p.Page <= 0 { + return 1 + } + return p.Page +} diff --git a/tmpl/dataTemp/resp_page.go b/tmpl/dataTemp/resp_page.go new file mode 100644 index 0000000..e707c8c --- /dev/null +++ b/tmpl/dataTemp/resp_page.go @@ -0,0 +1,22 @@ +package dataTemp + +// RespPageBo 分页响应实体 +type RespPageBo struct { + Page int //页码 + Limit int //每页大小 + Total int64 //总数 +} + +// SetDataByReq 通过req 设置响应参数 +func (r *RespPageBo) SetDataByReq(total int64, reqPage *ReqPageBo) *RespPageBo { + resp := r + if r == nil { + resp = &RespPageBo{} + } + resp.Total = total + if reqPage != nil { + resp.Page = reqPage.Page + resp.Limit = reqPage.Limit + } + return resp +} diff --git a/tmpl/errcode/common.go b/tmpl/errcode/common.go new file mode 100644 index 0000000..458b7db --- /dev/null +++ b/tmpl/errcode/common.go @@ -0,0 +1,113 @@ +package errcode + +import "fmt" + +var ( + AuthNotFound = &BusinessErr{code: AuthErr, message: "账号不存在"} + AuthStatusFreeze = &BusinessErr{code: AuthErr, message: "账号冻结"} + AuthStatusDel = &BusinessErr{code: AuthErr, message: "身份验证失败"} + AuthStatusPwdFail = &BusinessErr{code: AuthErr, message: "密码错误"} + AuthTokenCreateFail = &BusinessErr{code: AuthErr, message: "token生成失败"} + AuthTokenDelFail = &BusinessErr{code: AuthErr, message: "删除token失败"} + AuthWxLoginFail = &BusinessErr{code: AuthErr, message: "微信登录失败,请稍后重试"} + AuthInfoFail = &BusinessErr{code: AuthErr, message: "登录异常,请重新登录"} + + TokenNotFound = &BusinessErr{code: TokenErr, message: "缺少 Authorization Header"} + TokenFormatErr = &BusinessErr{code: TokenErr, message: "无效的token格式"} + TokenInfoNotFound = &BusinessErr{code: TokenErr, message: "未找到用户信息"} + TokenInvalid = &BusinessErr{code: TokenErr, message: "token过期"} + + PlatsNotFound = &BusinessErr{code: NotFoundErr, message: "信息未找到"} + BadRequest = &BusinessErr{code: BadReqErr, message: "操作失败"} + + ForbiddenError = &BusinessErr{code: ForbiddenErr, message: "权限不足"} + Success = &BusinessErr{code: 200, message: "成功"} + ParamError = &BusinessErr{code: ParamsErr, message: "参数错误"} + + SystemError = &BusinessErr{code: 405, message: "系统错误"} + + ClientNotFound = &BusinessErr{code: 406, message: "未找到client_id"} + SessionNotFound = &BusinessErr{code: 407, message: "未找到会话信息"} + UserNotFound = &BusinessErr{code: NotFoundErr, message: "不存在的用户"} + + KeyNotFound = &BusinessErr{code: 409, message: "身份验证失败"} + SysNotFound = &BusinessErr{code: 410, message: "未找到系统信息"} + SysCodeNotFound = &BusinessErr{code: 411, message: "未找到系统编码"} + InvalidParam = &BusinessErr{code: InvalidParamCode, message: "无效参数"} + WorkflowError = &BusinessErr{code: 501, message: "工作流过程错误"} + ClientInfoNotFound = &BusinessErr{code: NotFoundErr, message: "用户信息未找到"} +) + +const ( + InvalidParamCode = 408 + AuthErr = 403 + TokenErr = 401 + ParamsErr = 422 + BadReqErr = 400 + NotFoundErr = 404 + ForbiddenErr = 403 + BalanceNotEnoughCode = 402 +) + +type BusinessErr struct { + code int + message string +} + +func (e *BusinessErr) Error() string { + return e.message +} +func (e *BusinessErr) Code() int { + return e.code +} + +func NotFound(message string) *BusinessErr { + return &BusinessErr{code: NotFoundErr, message: PlatsNotFound.message + ":" + message} +} + +func (e *BusinessErr) Is(target error) bool { + _, ok := target.(*BusinessErr) + return ok +} + +// CustomErr 自定义错误 +func NewBusinessErr(code int, message string) *BusinessErr { + return &BusinessErr{code: code, message: message} +} + +func SysErrf(message string, arg ...any) *BusinessErr { + return &BusinessErr{code: SystemError.code, message: fmt.Sprintf(message, arg)} +} + +func SysErr(message string) *BusinessErr { + return &BusinessErr{code: SystemError.code, message: message} +} + +func ParamErrf(message string, arg ...any) *BusinessErr { + return &BusinessErr{code: ParamError.code, message: fmt.Sprintf(message, arg)} +} + +func ParamErr(message string) *BusinessErr { + return &BusinessErr{code: ParamError.code, message: ParamError.message + ":" + message} +} + +func SqlErr(err error) *BusinessErr { + + return &BusinessErr{code: ParamError.code, message: "数据操作失败,请联系管理员处理:" + err.Error()} +} + +func BadReq(message string) *BusinessErr { + return &BusinessErr{code: BadReqErr, message: BadRequest.message + ":" + message} +} + +func Forbidden(message string) *BusinessErr { + return &BusinessErr{code: ForbiddenErr, message: ForbiddenError.message + ":" + message} +} + +func (e *BusinessErr) Wrap(err error) *BusinessErr { + return NewBusinessErr(e.code, err.Error()) +} + +func BalanceNotEnoughErr(message string) *BusinessErr { + return NewBusinessErr(BalanceNotEnoughCode, message) +} diff --git a/utils/gorm.go b/utils/gorm.go new file mode 100644 index 0000000..9923987 --- /dev/null +++ b/utils/gorm.go @@ -0,0 +1,133 @@ +package utils + +import ( + "geo/internal/config" + "geo/utils/utils_gorm" + + "gorm.io/gorm" +) + +type Db struct { + Client *gorm.DB +} + +func NewGormDb(c *config.Config) (*Db, func()) { + transDBClient, mf := utils_gorm.DBConn(&c.DB) + //directDBClient, df := directDB(c, hLog) + cleanup := func() { + mf() + //df() + } + return &Db{ + Client: transDBClient, + //DirectDBClient: directDBClient, + }, cleanup +} + +// GetOne 查询单条记录,返回 map +func (d *Db) GetOne(sql string, args ...interface{}) (map[string]interface{}, error) { + var result map[string]interface{} + + // 使用 Raw 执行原生 SQL,Scan 到 map 需要先获取 rows + rows, err := d.Client.Raw(sql, args...).Rows() + if err != nil { + return nil, err + } + defer rows.Close() + + if rows.Next() { + // 获取列名 + columns, err := rows.Columns() + if err != nil { + return nil, err + } + + // 创建扫描用的切片 + values := make([]interface{}, len(columns)) + valuePtrs := make([]interface{}, len(columns)) + for i := range values { + valuePtrs[i] = &values[i] + } + + if err := rows.Scan(valuePtrs...); err != nil { + return nil, err + } + + result = make(map[string]interface{}) + for i, col := range columns { + result[col] = values[i] + } + return result, nil + } + return nil, nil +} + +// GetAll 查询多条记录,返回 map 切片 +func (d *Db) GetAll(sql string, args ...interface{}) ([]map[string]interface{}, error) { + rows, err := d.Client.Raw(sql, args...).Rows() + if err != nil { + return nil, err + } + defer rows.Close() + + results := make([]map[string]interface{}, 0) + + for rows.Next() { + columns, err := rows.Columns() + if err != nil { + return nil, err + } + + values := make([]interface{}, len(columns)) + valuePtrs := make([]interface{}, len(columns)) + for i := range values { + valuePtrs[i] = &values[i] + } + + if err := rows.Scan(valuePtrs...); err != nil { + return nil, err + } + + row := make(map[string]interface{}) + for i, col := range columns { + row[col] = values[i] + } + results = append(results, row) + } + return results, nil +} + +// Execute 执行单条 SQL(INSERT/UPDATE/DELETE),返回影响行数 +func (d *Db) Execute(sql string, args ...interface{}) (int64, error) { + result := d.Client.Exec(sql, args...) + if result.Error != nil { + return 0, result.Error + } + return result.RowsAffected, nil +} + +// ExecuteMany 批量执行 SQL,使用事务 +func (d *Db) ExecuteMany(sql string, argsList [][]interface{}) (int64, error) { + var total int64 + + // 开始事务 + tx := d.Client.Begin() + if tx.Error != nil { + return 0, tx.Error + } + + for _, args := range argsList { + result := tx.Exec(sql, args...) + if result.Error != nil { + tx.Rollback() + return 0, result.Error + } + total += result.RowsAffected + } + + // 提交事务 + if err := tx.Commit().Error; err != nil { + return 0, err + } + return total, nil +} diff --git a/utils/provider_set.go b/utils/provider_set.go new file mode 100644 index 0000000..09fee00 --- /dev/null +++ b/utils/provider_set.go @@ -0,0 +1,9 @@ +package utils + +import ( + "github.com/google/wire" +) + +var ProviderUtils = wire.NewSet( + NewGormDb, +) diff --git a/utils/utils_gorm/gorm.go b/utils/utils_gorm/gorm.go new file mode 100644 index 0000000..d64b671 --- /dev/null +++ b/utils/utils_gorm/gorm.go @@ -0,0 +1,41 @@ +package utils_gorm + +import ( + "database/sql" + "fmt" + "geo/internal/config" + "gorm.io/driver/mysql" + "gorm.io/gorm" + "time" +) + +func DBConn(c *config.DB) (*gorm.DB, func()) { + mysqlConn, err := sql.Open(c.Driver, c.Source) + gormDB, err := gorm.Open( + mysql.New(mysql.Config{Conn: mysqlConn}), + ) + + gormDB.Logger = NewCustomLogger(gormDB) + if err != nil { + panic("failed to connect database") + } + sqlDB, err := gormDB.DB() + + // SetMaxIdleConns sets the maximum number of connections in the idle connection pool. + sqlDB.SetMaxIdleConns(int(c.MaxIdle)) + + // SetMaxOpenConns sets the maximum number of open connections to the database. + sqlDB.SetMaxOpenConns(int(c.MaxLifetime)) + + // SetConnMaxLifetime sets the maximum amount of time a connection may be reused. + sqlDB.SetConnMaxLifetime(time.Hour) + + return gormDB, func() { + if mysqlConn != nil { + fmt.Println("关闭 physicalGoodsDB") + if err := mysqlConn.Close(); err != nil { + fmt.Println("关闭 physicalGoodsDB 失败:", err) + } + } + } +} diff --git a/utils/utils_gorm/sql_log.go b/utils/utils_gorm/sql_log.go new file mode 100644 index 0000000..b34a29d --- /dev/null +++ b/utils/utils_gorm/sql_log.go @@ -0,0 +1,96 @@ +package utils_gorm + +import ( + "context" + "fmt" + "gorm.io/gorm" + "gorm.io/gorm/logger" + "regexp" + "strings" + "time" +) + +type CustomLogger struct { + gormLogger logger.Interface + db *gorm.DB +} + +func NewCustomLogger(db *gorm.DB) *CustomLogger { + return &CustomLogger{ + gormLogger: logger.Default.LogMode(logger.Info), + db: db, + } +} + +func (l *CustomLogger) LogMode(level logger.LogLevel) logger.Interface { + newlogger := *l + newlogger.gormLogger = l.gormLogger.LogMode(level) + return &newlogger +} + +func (l *CustomLogger) Info(ctx context.Context, msg string, data ...interface{}) { + l.gormLogger.Info(ctx, msg, data...) +} + +func (l *CustomLogger) Warn(ctx context.Context, msg string, data ...interface{}) { + l.gormLogger.Warn(ctx, msg, data...) +} + +func (l *CustomLogger) Error(ctx context.Context, msg string, data ...interface{}) { + l.gormLogger.Error(ctx, msg, data...) +} + +func (l *CustomLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + elapsed := time.Since(begin) + sql, _ := fc() + l.gormLogger.Trace(ctx, begin, fc, err) + operation := extractOperation(sql) + tableName := extractTableName(sql) + fmt.Println(tableName) + //// 将SQL语句保存到数据库 + if operation == 0 || tableName == "sql_log" { + return + } + //go l.db.Model(&SqlLog{}).Create(&SqlLog{ + // OperatorID: 1, + // OperatorName: "test", + // SqlInfo: sql, + // TableNames: tableName, + // Type: operation, + //}) + + // 如果有需要,也可以根据执行时间(elapsed)等条件过滤或处理日志记录 + if elapsed > time.Second { + //l.gormLogger.Warn(ctx, "Slow SQL (> 1s): %s", sql) + } +} + +// extractTableName extracts the table name from a SQL query, supporting quoted table names. +func extractTableName(sql string) string { + // 使用非捕获组匹配多种SQL操作关键词 + re := regexp.MustCompile(`(?i)\b(?:from|update|into|delete\s+from)\b\s+[\` + "`" + `"]?(\w+)[\` + "`" + `"]?`) + match := re.FindStringSubmatch(sql) + + // 检查是否匹配成功 + if len(match) > 1 { + return match[1] + } + + return "" +} + +// extractOperation extracts the operation type from a SQL query. +func extractOperation(sql string) int32 { + sql = strings.TrimSpace(strings.ToLower(sql)) + var operation int32 + if strings.HasPrefix(sql, "select") { + operation = 0 + } else if strings.HasPrefix(sql, "insert") { + operation = 1 + } else if strings.HasPrefix(sql, "update") { + operation = 3 + } else if strings.HasPrefix(sql, "delete") { + operation = 2 + } + return operation +}