164 lines
3.5 KiB
Go
164 lines
3.5 KiB
Go
package util
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"errors"
|
|
zzsm2 "github.com/ZZMarquis/gm/sm2"
|
|
"github.com/tjfoc/gmsm/sm3"
|
|
"github.com/tjfoc/gmsm/x509"
|
|
"math/big"
|
|
"qteam/app/utils/postbank/internal/sm2"
|
|
)
|
|
|
|
func Sm2Decrypt(privateKey *sm2.PrivateKey, encryptData []byte) ([]byte, error) {
|
|
C1Byte := make([]byte, 65)
|
|
copy(C1Byte, encryptData[:65])
|
|
x, y := elliptic.Unmarshal(privateKey.Curve, C1Byte)
|
|
dBC1X, dBC1Y := privateKey.Curve.ScalarMult(x, y, bigIntToByte(privateKey.D))
|
|
dBC1Bytes := elliptic.Marshal(privateKey.Curve, dBC1X, dBC1Y)
|
|
|
|
kLen := len(encryptData) - 65 - 32
|
|
t, err := kdf(dBC1Bytes, kLen)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
M := make([]byte, kLen)
|
|
for i := 0; i < kLen; i++ {
|
|
M[i] = encryptData[65+i] ^ t[i]
|
|
}
|
|
|
|
C3 := make([]byte, 32)
|
|
copy(C3, encryptData[len(encryptData)-32:])
|
|
u := calculateHash(dBC1X, M, dBC1Y)
|
|
|
|
if bytes.Compare(u, C3) == 0 {
|
|
return M, nil
|
|
} else {
|
|
return nil, errors.New("解密失败")
|
|
}
|
|
}
|
|
|
|
func Sm2Encrypt(publicKey *sm2.PublicKey, m []byte) ([]byte, error) {
|
|
kLen := len(m)
|
|
var C1, t []byte
|
|
var err error
|
|
var kx, ky *big.Int
|
|
for {
|
|
k, _ := rand.Int(rand.Reader, publicKey.Params().N)
|
|
C1x, C1y := zzsm2.GetSm2P256V1().ScalarBaseMult(bigIntToByte(k))
|
|
// C1x, C1y := sm2.P256Sm2().ScalarBaseMult(bigIntToByte(k))
|
|
C1 = elliptic.Marshal(publicKey.Curve, C1x, C1y)
|
|
|
|
kx, ky = publicKey.ScalarMult(publicKey.X, publicKey.Y, bigIntToByte(k))
|
|
kpbBytes := elliptic.Marshal(publicKey, kx, ky)
|
|
t, err = kdf(kpbBytes, kLen)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !isAllZero(t) {
|
|
break
|
|
}
|
|
}
|
|
|
|
C2 := make([]byte, kLen)
|
|
for i := 0; i < kLen; i++ {
|
|
C2[i] = m[i] ^ t[i]
|
|
}
|
|
|
|
C3 := calculateHash(kx, m, ky)
|
|
|
|
r := make([]byte, 0, len(C1)+len(C2)+len(C3))
|
|
r = append(r, C1...)
|
|
r = append(r, C2...)
|
|
r = append(r, C3...)
|
|
return r, nil
|
|
}
|
|
|
|
// isAllZero reports whether every byte of m is zero. A nil or empty slice
// counts as all zero.
func isAllZero(m []byte) bool {
	var acc byte
	for _, b := range m {
		acc |= b
	}
	return acc == 0
}
|
|
|
|
func calculateHash(x *big.Int, M []byte, y *big.Int) []byte {
|
|
digest := sm3.New()
|
|
digest.Write(bigIntToByte(x))
|
|
digest.Write(M)
|
|
digest.Write(bigIntToByte(y))
|
|
result := digest.Sum(nil)[:32]
|
|
return result
|
|
}
|
|
|
|
// bigIntToByte returns the minimal big-endian two's-complement encoding of a
// non-negative n. A 0x00 byte is prepended when the most significant bit of
// the magnitude is set, so the value cannot be read as negative. Zero is
// encoded as a single 0x00 byte.
//
// This mirrors Java's BigInteger.toByteArray for non-negative values,
// presumably for interoperability with a Java peer. For negative n only the
// magnitude is encoded, because big.Int.Bytes ignores the sign.
func bigIntToByte(n *big.Int) []byte {
	byteArray := n.Bytes()
	// big.Int.Bytes returns an empty slice for zero; indexing it would panic.
	if len(byteArray) == 0 {
		return []byte{0}
	}
	if byteArray[0]&0x80 != 0 {
		byteArray = append([]byte{0}, byteArray...)
	}
	return byteArray
}
|
|
|
|
func kdf(Z []byte, klen int) ([]byte, error) {
|
|
ct := 1
|
|
end := (klen + 31) / 32
|
|
result := make([]byte, 0)
|
|
for i := 1; i <= end; i++ {
|
|
b, err := sm3hash(Z, toByteArray(ct))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result = append(result, b...)
|
|
ct++
|
|
}
|
|
last, err := sm3hash(Z, toByteArray(ct))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if klen%32 == 0 {
|
|
result = append(result, last...)
|
|
} else {
|
|
result = append(result, last[:klen%32]...)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func sm3hash(sources ...[]byte) ([]byte, error) {
|
|
b, err := joinBytes(sources...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
md := make([]byte, 32)
|
|
h := x509.SM3.New()
|
|
h.Write(b)
|
|
h.Sum(md[:0])
|
|
return md, nil
|
|
}
|
|
|
|
// joinBytes concatenates params, in order, into one slice. The error result
// is propagated from bytes.Buffer.Write, which documents that it always
// returns a nil error.
func joinBytes(params ...[]byte) ([]byte, error) {
	var buf bytes.Buffer
	for _, p := range params {
		if _, err := buf.Write(p); err != nil {
			return nil, err
		}
	}
	return buf.Bytes(), nil
}
|
|
|
|
// toByteArray encodes the low 32 bits of i as 4 big-endian bytes. kdf uses it
// to serialize its block counter.
func toByteArray(i int) []byte {
	out := make([]byte, 4)
	for k := range out {
		out[k] = byte(i >> uint(24-8*k))
	}
	return out
}
|