package util import ( "bytes" "crypto/elliptic" "crypto/rand" "errors" zzsm2 "github.com/ZZMarquis/gm/sm2" "github.com/tjfoc/gmsm/sm3" "github.com/tjfoc/gmsm/x509" "math/big" "qteam/app/utils/postbank/internal/sm2" ) func Sm2Decrypt(privateKey *sm2.PrivateKey, encryptData []byte) ([]byte, error) { C1Byte := make([]byte, 65) copy(C1Byte, encryptData[:65]) x, y := elliptic.Unmarshal(privateKey.Curve, C1Byte) dBC1X, dBC1Y := privateKey.Curve.ScalarMult(x, y, bigIntToByte(privateKey.D)) dBC1Bytes := elliptic.Marshal(privateKey.Curve, dBC1X, dBC1Y) kLen := len(encryptData) - 65 - 32 t, err := kdf(dBC1Bytes, kLen) if err != nil { return nil, err } M := make([]byte, kLen) for i := 0; i < kLen; i++ { M[i] = encryptData[65+i] ^ t[i] } C3 := make([]byte, 32) copy(C3, encryptData[len(encryptData)-32:]) u := calculateHash(dBC1X, M, dBC1Y) if bytes.Compare(u, C3) == 0 { return M, nil } else { return nil, errors.New("解密失败") } } func Sm2Encrypt(publicKey *sm2.PublicKey, m []byte) ([]byte, error) { kLen := len(m) var C1, t []byte var err error var kx, ky *big.Int for { k, _ := rand.Int(rand.Reader, publicKey.Params().N) C1x, C1y := zzsm2.GetSm2P256V1().ScalarBaseMult(bigIntToByte(k)) // C1x, C1y := sm2.P256Sm2().ScalarBaseMult(bigIntToByte(k)) C1 = elliptic.Marshal(publicKey.Curve, C1x, C1y) kx, ky = publicKey.ScalarMult(publicKey.X, publicKey.Y, bigIntToByte(k)) kpbBytes := elliptic.Marshal(publicKey, kx, ky) t, err = kdf(kpbBytes, kLen) if err != nil { return nil, err } if !isAllZero(t) { break } } C2 := make([]byte, kLen) for i := 0; i < kLen; i++ { C2[i] = m[i] ^ t[i] } C3 := calculateHash(kx, m, ky) r := make([]byte, 0, len(C1)+len(C2)+len(C3)) r = append(r, C1...) r = append(r, C2...) r = append(r, C3...) return r, nil } func isAllZero(m []byte) bool { for i := 0; i < len(m); i++ { if m[i] != 0 { return false } } return true } func calculateHash(x *big.Int, M []byte, y *big.Int) []byte { digest := sm3.New() digest.Write(bigIntToByte(x)) digest.Write(M) digest.Write(bigIntToByte(y)) result := digest.Sum(nil)[:32] return result } func bigIntToByte(n *big.Int) []byte { byteArray := n.Bytes() // If the most significant byte's most significant bit is set, // prepend a 0 byte to the slice to avoid being interpreted as a negative number. if (byteArray[0] & 0x80) != 0 { byteArray = append([]byte{0}, byteArray...) } return byteArray } func kdf(Z []byte, klen int) ([]byte, error) { ct := 1 end := (klen + 31) / 32 result := make([]byte, 0) for i := 1; i <= end; i++ { b, err := sm3hash(Z, toByteArray(ct)) if err != nil { return nil, err } result = append(result, b...) ct++ } last, err := sm3hash(Z, toByteArray(ct)) if err != nil { return nil, err } if klen%32 == 0 { result = append(result, last...) } else { result = append(result, last[:klen%32]...) } return result, nil } func sm3hash(sources ...[]byte) ([]byte, error) { b, err := joinBytes(sources...) if err != nil { return nil, err } md := make([]byte, 32) h := x509.SM3.New() h.Write(b) h.Sum(md[:0]) return md, nil } func joinBytes(params ...[]byte) ([]byte, error) { var buffer bytes.Buffer for i := 0; i < len(params); i++ { _, err := buffer.Write(params[i]) if err != nil { return nil, err } } return buffer.Bytes(), nil } func toByteArray(i int) []byte { byteArray := []byte{ byte(i >> 24), byte((i & 16777215) >> 16), byte((i & 65535) >> 8), byte(i & 255), } return byteArray }