YouChuKoffee/app/utils/postbank/internal/util/sm2x.go

164 lines
3.5 KiB
Go
Raw Normal View History

2024-06-19 18:32:34 +08:00
package util
import (
"bytes"
"crypto/elliptic"
"crypto/rand"
"errors"
zzsm2 "github.com/ZZMarquis/gm/sm2"
"github.com/tjfoc/gmsm/sm3"
"github.com/tjfoc/gmsm/x509"
"math/big"
"qteam/app/utils/postbank/internal/sm2"
)
// Sm2Decrypt decrypts an SM2 ciphertext laid out as C1 || C2 || C3 with
// privateKey and returns the plaintext.
//
// C1 is the leading 65-byte uncompressed ephemeral point, C3 the trailing
// 32-byte hash computed by calculateHash, and C2 (the masked plaintext) is
// everything in between. An error is returned when the input is too short to
// hold C1 and C3, when C1 does not decode to a point on the key's curve, or
// when the recomputed hash does not match C3.
func Sm2Decrypt(privateKey *sm2.PrivateKey, encryptData []byte) ([]byte, error) {
	const c1Len, c3Len = 65, 32
	// Without this check, short input panics on slicing, or yields a
	// negative plaintext length.
	if len(encryptData) < c1Len+c3Len {
		return nil, errors.New("sm2: ciphertext too short")
	}
	x, y := elliptic.Unmarshal(privateKey.Curve, encryptData[:c1Len])
	if x == nil {
		return nil, errors.New("sm2: invalid C1 point")
	}
	// (x2, y2) = d·C1, the shared point the sender derived as k·P.
	dBC1X, dBC1Y := privateKey.Curve.ScalarMult(x, y, bigIntToByte(privateKey.D))
	dBC1Bytes := elliptic.Marshal(privateKey.Curve, dBC1X, dBC1Y)
	kLen := len(encryptData) - c1Len - c3Len
	t, err := kdf(dBC1Bytes, kLen)
	if err != nil {
		return nil, err
	}
	M := make([]byte, kLen)
	for i := range M {
		M[i] = encryptData[c1Len+i] ^ t[i]
	}
	C3 := encryptData[len(encryptData)-c3Len:]
	u := calculateHash(dBC1X, M, dBC1Y)
	// Compare in constant time so response timing does not reveal how many
	// leading bytes of C3 matched.
	var diff byte
	for i := range C3 {
		diff |= u[i] ^ C3[i]
	}
	if len(u) != len(C3) || diff != 0 {
		// Message means "decryption failed".
		return nil, errors.New("解密失败")
	}
	return M, nil
}
// Sm2Encrypt encrypts m for publicKey and returns the ciphertext laid out as
// C1 || C2 || C3:
//   - C1 is the marshaled ephemeral point k·G,
//   - C2 is m XORed with the kdf keystream (len(m) bytes),
//   - C3 is calculateHash(x2, m, y2),
//
// where (x2, y2) = k·P for the recipient public key P.
//
// A fresh ephemeral k is drawn until it is non-zero and produces a keystream
// that is not all zero. An error is returned if the random source or kdf fails.
func Sm2Encrypt(publicKey *sm2.PublicKey, m []byte) ([]byte, error) {
	kLen := len(m)
	var C1, t []byte
	var kx, ky *big.Int
	for {
		k, err := rand.Int(rand.Reader, publicKey.Params().N)
		if err != nil {
			return nil, err
		}
		// k must lie in [1, n-1]. rand.Int may return 0, which is the point at
		// infinity and which bigIntToByte cannot encode safely.
		if k.Sign() == 0 {
			continue
		}
		kBytes := bigIntToByte(k)
		// NOTE(review): C1 is computed on zzsm2's SM2 P-256 curve but marshaled
		// with publicKey.Curve; presumably both carry identical parameters — confirm.
		C1x, C1y := zzsm2.GetSm2P256V1().ScalarBaseMult(kBytes)
		C1 = elliptic.Marshal(publicKey.Curve, C1x, C1y)
		kx, ky = publicKey.ScalarMult(publicKey.X, publicKey.Y, kBytes)
		kpbBytes := elliptic.Marshal(publicKey, kx, ky)
		t, err = kdf(kpbBytes, kLen)
		if err != nil {
			return nil, err
		}
		// Retry if the keystream actually used for C2 is all zero, since
		// C2 would then equal m. Only the first kLen bytes are XORed into m.
		// An empty message has no keystream to check.
		if kLen == 0 || !isAllZero(t[:kLen]) {
			break
		}
	}
	C2 := make([]byte, kLen)
	for i := range C2 {
		C2[i] = m[i] ^ t[i]
	}
	C3 := calculateHash(kx, m, ky)
	r := make([]byte, 0, len(C1)+len(C2)+len(C3))
	r = append(r, C1...)
	r = append(r, C2...)
	r = append(r, C3...)
	return r, nil
}
// isAllZero reports whether every byte of m is zero. A nil or empty slice
// counts as all zero.
func isAllZero(m []byte) bool {
	var acc byte
	for _, b := range m {
		acc |= b
	}
	return acc == 0
}
// calculateHash returns the 32-byte SM3 digest of x || M || y. It is used as
// the C3 integrity value by both Sm2Encrypt and Sm2Decrypt.
//
// The coordinates are serialized with bigIntToByte: minimal big-endian bytes,
// plus a leading 0x00 when the top bit is set. They are not fixed-width
// 32-byte fields.
// NOTE(review): this presumably mirrors the counterparty's encoding; confirm
// before changing, because any change alters every C3.
func calculateHash(x *big.Int, M []byte, y *big.Int) []byte {
	h := sm3.New()
	for _, part := range [][]byte{bigIntToByte(x), M, bigIntToByte(y)} {
		h.Write(part)
	}
	return h.Sum(nil)[:32]
}
// bigIntToByte encodes the non-negative integer n as minimal big-endian
// bytes. A 0x00 byte is prepended when the most significant bit of the first
// byte is set, so the value cannot be read as negative by a signed decoder.
// Zero encodes as the single byte 0x00.
//
// n is assumed non-negative: big.Int.Bytes discards the sign.
func bigIntToByte(n *big.Int) []byte {
	byteArray := n.Bytes()
	// big.Int.Bytes returns an empty slice for zero. Indexing it below would panic.
	if len(byteArray) == 0 {
		return []byte{0}
	}
	if byteArray[0]&0x80 != 0 {
		byteArray = append([]byte{0}, byteArray...)
	}
	return byteArray
}
// kdf derives klen bytes of keystream from the shared secret Z. It
// concatenates SM3(Z || ct) for ct = 1, 2, ..., with ct encoded as a 4-byte
// big-endian counter by toByteArray, and returns the first klen bytes.
//
// For klen == 0, one full 32-byte block is still returned. A zero-length result
// would be vacuously all zero, so callers that reject an all-zero keystream
// could never accept it. A negative klen is reported as an error.
func kdf(Z []byte, klen int) ([]byte, error) {
	if klen < 0 {
		return nil, errors.New("sm2: negative kdf output length")
	}
	blocks := (klen + 31) / 32
	if blocks == 0 {
		blocks = 1
	}
	result := make([]byte, 0, blocks*32)
	for ct := 1; ct <= blocks; ct++ {
		b, err := sm3hash(Z, toByteArray(ct))
		if err != nil {
			return nil, err
		}
		result = append(result, b...)
	}
	if klen == 0 {
		return result, nil
	}
	// Truncate to exactly klen bytes. Previously an extra hash block was computed and appended.
	return result[:klen], nil
}
// sm3hash concatenates sources in order and returns the 32-byte SM3 digest
// of the result. An error is returned only if joining the inputs fails.
func sm3hash(sources ...[]byte) ([]byte, error) {
	data, err := joinBytes(sources...)
	if err != nil {
		return nil, err
	}
	digest := x509.SM3.New()
	digest.Write(data)
	return digest.Sum(make([]byte, 0, 32)), nil
}
// joinBytes concatenates params, in order, into a single slice.
// The error result is forwarded from bytes.Buffer.Write.
func joinBytes(params ...[]byte) ([]byte, error) {
	var joined bytes.Buffer
	for _, part := range params {
		if _, err := joined.Write(part); err != nil {
			return nil, err
		}
	}
	return joined.Bytes(), nil
}
// toByteArray returns the low 32 bits of i as 4 big-endian bytes. kdf uses
// it to encode its block counter.
func toByteArray(i int) []byte {
	u := uint32(i)
	return []byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)}
}