voucher/internal/pkg/cmbv2/sm2.go

388 lines
8.1 KiB
Go

package cmbv2
import (
"bytes"
"crypto/elliptic"
"crypto/rand"
"encoding/asn1"
"encoding/base64"
"encoding/hex"
"fmt"
"github.com/ZZMarquis/gm/sm4"
"io"
"math/big"
"strings"
"voucher/internal/pkg/cmb/sm2/sdk"
"voucher/internal/pkg/cmbv2/model"
"voucher/internal/pkg/cmbv2/utils"
)
// Cmb bundles the SM2 key pair, the SM2 curve and the cipher settings used by
// Encrypt/Decrypt and Sign/Verify in this file. Build it with NewCmb; the zero
// value has no keys and no sdk, so its methods cannot be used directly.
type Cmb struct {
	publicKey  *model.PublicKey  // peer public key (NewCmb's sopPublicKey); used by encrypt and verify
	privateKey *model.PrivateKey // own private key; used by decrypt and sign
	cipherType model.CipherType  // SM2 ciphertext layout (C1C2C3 or C1C3C2); NewCmb sets C1C3C2
	c3Len      int               // expected length in bytes of the C3 hash when decrypting; NewCmb sets 32
	uid        []byte            // signer ID passed to sdk.GetZ when signing/verifying; NewCmb sets model.DefaultUid
	sdk        sdk.SDK           // helper providing Kdf, CalculateHash, GetZ and GetE
	sm2P256    *utils.Sm2P256Curve // curve passed to utils.HexToPublicKey / HexToPrivateKey when parsing keys
}
// NewCmb builds a Cmb from two hex-encoded SM2 keys: privateKey (used by
// Decrypt and Sign) and sopPublicKey (used by Encrypt and Verify).
//
// The returned value uses the C1C3C2 ciphertext layout, a 32-byte C3 and
// model.DefaultUid as the signer ID. It returns an error if either key cannot
// be decoded or parsed; the private key is processed first.
func NewCmb(privateKey, sopPublicKey string) (*Cmb, error) {
	c := new(Cmb)
	c.c3Len = 32
	c.uid = model.DefaultUid
	c.cipherType = model.C1C3C2
	c.sdk = sdk.NewBaseSdk()

	curve := utils.NewP256Sm2()
	c.sm2P256 = &curve

	// Key parsing depends on c.sm2P256, so it must run after the curve is set.
	loaders := []func() error{
		func() error { return c.setHexPrivateKey(privateKey) },
		func() error { return c.setHexPublicKey(sopPublicKey) },
	}
	for _, load := range loaders {
		if err := load(); err != nil {
			return nil, err
		}
	}
	return c, nil
}
// setHexPublicKey decodes hexStr and parses it as an SM2 public key on
// s.sm2P256.
//
// s.publicKey is replaced only when decoding and parsing both succeed, so a
// failed call leaves any previously configured key untouched. Underlying
// errors are wrapped with %w so callers can inspect them with errors.Is/As.
func (s *Cmb) setHexPublicKey(hexStr string) error {
	if hexStr == "" {
		return fmt.Errorf("publicKey is empty")
	}
	raw, err := hex.DecodeString(hexStr)
	if err != nil {
		return fmt.Errorf("publicKey is not hex string: %w", err)
	}
	key, err := utils.HexToPublicKey(s.sm2P256, raw)
	if err != nil {
		return fmt.Errorf("parse publicKey err: %w", err)
	}
	s.publicKey = key
	return nil
}
// setHexPrivateKey decodes hexStr and parses it as an SM2 private key on
// s.sm2P256.
//
// s.privateKey is replaced only when decoding and parsing both succeed, so a
// failed call leaves any previously configured key untouched. An empty string
// is rejected up front rather than being handed to the parser as zero bytes.
// Underlying errors are wrapped with %w.
func (s *Cmb) setHexPrivateKey(hexStr string) error {
	if hexStr == "" {
		return fmt.Errorf("privateKey is empty")
	}
	raw, err := hex.DecodeString(hexStr)
	if err != nil {
		return fmt.Errorf("privateKey is not hex string: %w", err)
	}
	key, err := utils.HexToPrivateKey(s.sm2P256, raw)
	if err != nil {
		return fmt.Errorf("parse privateKey err: %w", err)
	}
	s.privateKey = key
	return nil
}
// Encrypt encrypts input for the peer in two layers:
//
//  1. A fresh SM4 key and IV (utils.GenerateSM4Key / utils.GetSM4IV) encrypt
//     the padded input in CBC mode.
//  2. The key and IV, combined by utils.AssemblingByteArray and
//     base64-encoded, are SM2-encrypted under s.publicKey.
//
// The result is "<base64(SM2 ciphertext)>|<base64(SM4 ciphertext)>".
// It returns an error for empty input or if either encryption step fails.
func (s *Cmb) Encrypt(input []byte) (string, error) {
	// len covers both nil and empty non-nil slices, matching the error text.
	if len(input) == 0 {
		return "", fmt.Errorf("加密元数据 is empty")
	}
	sm4Key := utils.GenerateSM4Key()
	iv := utils.GetSM4IV()
	encryptedBody, err := sm4.CBCEncrypt(sm4Key, iv, utils.Padding(input, 1))
	if err != nil {
		// Previously this error was overwritten by the next assignment and a
		// broken body could be sent.
		return "", fmt.Errorf("sm4 encrypt body: %w", err)
	}
	keyAndIv := utils.AssemblingByteArray(sm4Key, iv)
	data := []byte(base64.StdEncoding.EncodeToString(keyAndIv))
	encryptedKey, err := s.encrypt(data)
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(encryptedKey) + "|" +
		base64.StdEncoding.EncodeToString(encryptedBody), nil
}
// encrypt performs SM2 public-key encryption of data under s.publicKey. It
// returns the ciphertext laid out according to s.cipherType:
//
//   - C1 is the marshalled ephemeral point k*G.
//   - C2 is a copy of data passed through s.sdk.Kdf together with the shared
//     point k*P. NOTE(review): this assumes Kdf XORs the derived key stream
//     into its last argument in place; confirm.
//   - C3 is s.sdk.CalculateHash(x2, data, y2).
//
// An attempt is retried with a new nonce when C2 comes out identical to data
// (see encrypted). It returns an error for empty data, for an unsupported
// cipher type, or if randomness or the KDF fails.
func (s *Cmb) encrypt(data []byte) ([]byte, error) {
	if len(data) == 0 {
		// With no plaintext, encrypted() can never report a difference and the
		// retry loop below would spin forever.
		return nil, fmt.Errorf("encrypt data is empty")
	}
	// Validate the layout before doing any expensive curve arithmetic.
	switch s.cipherType {
	case model.C1C2C3, model.C1C3C2:
	default:
		return nil, fmt.Errorf("cipher type not support")
	}

	curve := s.publicKey.Curve
	n := s.publicKey.Params().N
	c2 := make([]byte, len(data))
	var c1 []byte
	var kx, ky *big.Int
	for {
		k, err := rand.Int(rand.Reader, n)
		if err != nil {
			return nil, fmt.Errorf("rand error: %w", err)
		}
		if k.Sign() == 0 {
			// The nonce must lie in [1, n-1]; k = 0 would give the point at infinity.
			continue
		}
		c1x, c1y := curve.ScalarBaseMult(k.Bytes())
		c1 = elliptic.Marshal(curve, c1x, c1y)
		kx, ky = curve.ScalarMult(s.publicKey.X, s.publicKey.Y, k.Bytes())

		// Start every attempt from the plaintext so a retry never operates on
		// the output of a previous attempt.
		copy(c2, data)
		if err := s.sdk.Kdf(s.publicKey, kx, ky, c2); err != nil {
			return nil, fmt.Errorf("kdf error: %w", err)
		}
		if s.encrypted(c2, data) {
			break
		}
	}

	c3 := s.sdk.CalculateHash(kx, data, ky)
	out := make([]byte, 0, len(c1)+len(c2)+len(c3))
	out = append(out, c1...)
	if s.cipherType == model.C1C2C3 {
		out = append(out, c2...)
		out = append(out, c3...)
	} else { // model.C1C3C2, validated above
		out = append(out, c3...)
		out = append(out, c2...)
	}
	return out, nil
}
// encrypted reports whether encData differs from in at any index in
// [0, len(encData)).
//
// encrypt uses it to reject an attempt whose KDF output left the plaintext
// unchanged (presumably an all-zero key stream). in must be at least as long
// as encData.
func (s *Cmb) encrypted(encData []byte, in []byte) bool {
	for idx := range encData {
		if in[idx] != encData[idx] {
			return true
		}
	}
	return false
}
// Decrypt reverses the two-layer format "<base64(SM2 ciphertext)>|<base64(SM4 ciphertext)>".
//
// The first part is SM2-decrypted with s.privateKey, yielding a base64 string
// of 33 bytes: bytes [0,16) are the SM4 key and bytes [17,33) the IV. The
// second part is SM4-CBC decrypted with that key and IV, and the padding is
// removed via utils.Padding(_, 0). It returns an error on any malformed part
// or failed decryption step.
func (s *Cmb) Decrypt(input string) (string, error) {
	parts := strings.Split(input, "|")
	if len(parts) != 2 {
		return "", fmt.Errorf("数据格式错误")
	}
	wrappedKeyB64, bodyB64 := parts[0], parts[1]

	wrappedKey, err := base64.StdEncoding.DecodeString(wrappedKeyB64)
	if err != nil {
		return "", fmt.Errorf("data is not base64 string: %s", err.Error())
	}

	keyIvB64, err := s.decrypt(wrappedKey)
	if err != nil {
		return "", err
	}
	keyIv, err := base64.StdEncoding.DecodeString(keyIvB64)
	if err != nil {
		return "", err
	}
	if len(keyIv) != 33 {
		return "", fmt.Errorf("iv长度不等于33")
	}

	body, err := base64.StdEncoding.DecodeString(bodyB64)
	if err != nil {
		return "", err
	}
	// Byte 16 of keyIv sits between key and IV and is skipped.
	plain, err := sm4.CBCDecrypt(keyIv[:16], keyIv[17:], body)
	if err != nil {
		return "", err
	}
	return string(utils.Padding(plain, 0)), nil
}
func (s *Cmb) decrypt(data []byte) (string, error) {
c1Len := 65
C1Byte := make([]byte, c1Len)
copy(C1Byte, data[:c1Len])
x, y := elliptic.Unmarshal(s.privateKey.Curve, C1Byte)
dBC1X, dBC1Y := s.privateKey.Curve.ScalarMult(x, y, s.privateKey.D.Bytes())
c2Len := len(data) - c1Len - s.c3Len
c2 := make([]byte, c2Len)
c3 := make([]byte, s.c3Len)
if s.cipherType == model.C1C2C3 {
copy(c2, data[c1Len:c1Len+c2Len])
copy(c3, data[c1Len+c2Len:])
} else if s.cipherType == model.C1C3C2 {
copy(c3, data[c1Len:c1Len+s.c3Len])
copy(c2, data[c1Len+s.c3Len:])
} else {
return "", fmt.Errorf("cipher type not support")
}
if err := s.sdk.Kdf(s.privateKey.Curve, dBC1X, dBC1Y, c2); err != nil {
return "", fmt.Errorf("kdf error: %v", err)
}
u := s.sdk.CalculateHash(dBC1X, c2, dBC1Y)
if bytes.Compare(u, c3) == 0 {
return string(c2), nil
}
return "", fmt.Errorf("decrypt error")
}
func (s *Cmb) Sign(input []byte) (string, error) {
signData, err := s.sign(input)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(signData), nil
}
// sign computes an SM2 signature over data with s.privateKey and returns the
// ASN.1-marshalled model.Signature{R, S}.
//
// e is derived from Z (s.sdk.GetZ over the signer's public key and s.uid) and
// data. A random k in [1, N-1] is drawn until:
//
//	r = (e + x1) mod N is non-zero and r + k != N, where (x1, _) = k*G
//	s = (1 + d)^-1 * (k - r*d) mod N is non-zero.
//
// It returns an error for a zero curve order, for a private key with no
// (1 + d) inverse, or if randomness or marshalling fails.
func (s *Cmb) sign(data []byte) ([]byte, error) {
	priv := s.privateKey
	z := s.sdk.GetZ(priv.PublicKey.X, priv.PublicKey.Y, s.uid)
	e := s.sdk.GetE(z, data)
	c := priv.PublicKey.Curve
	N := c.Params().N
	if N.Sign() == 0 {
		return nil, fmt.Errorf("invalid curve order")
	}

	// (1 + d)^-1 mod N does not depend on the nonce, so compute it once. It has
	// no inverse only for an invalid key (d ≡ -1 mod N); without this check
	// ModInverse returns nil and the multiplication below would panic.
	d1Inv := new(big.Int).ModInverse(new(big.Int).Add(priv.D, model.One), N)
	if d1Inv == nil {
		return nil, fmt.Errorf("invalid private key")
	}

	var r, sb *big.Int
	for {
		k, err := s.randFieldElement(c, rand.Reader)
		if err != nil {
			return nil, fmt.Errorf("randFieldElement: %w", err)
		}
		x1, _ := priv.Curve.ScalarBaseMult(k.Bytes())
		r = new(big.Int).Add(x1, e)
		r.Mod(r, N)
		// r must be non-zero and r + k must not equal N; otherwise draw a new k.
		if r.Sign() == 0 || new(big.Int).Add(r, k).Cmp(N) == 0 {
			continue
		}
		sb = new(big.Int).Mul(priv.D, r)
		sb.Sub(k, sb)
		sb.Mul(sb, d1Inv)
		sb.Mod(sb, N)
		if sb.Sign() != 0 {
			break
		}
	}

	si, err := asn1.Marshal(model.Signature{R: r, S: sb})
	if err != nil {
		return nil, fmt.Errorf("asn1.Marshal: %w", err)
	}
	return si, nil
}
// randFieldElement returns a random scalar k in [1, N-1] for curve c.
//
// It reads BitSize/8+8 bytes from random and reduces them modulo N-1 before
// adding one. The extra 8 bytes presumably keep the modulo bias negligible.
// Any error from reading random is returned unchanged.
func (s *Cmb) randFieldElement(c elliptic.Curve, random io.Reader) (k *big.Int, err error) {
	params := c.Params()
	buf := make([]byte, params.BitSize/8+8)
	if _, err = io.ReadFull(random, buf); err != nil {
		return nil, err
	}
	nMinusOne := new(big.Int).Sub(params.N, model.One)
	k = new(big.Int).SetBytes(buf)
	k.Mod(k, nMinusOne)
	k.Add(k, model.One)
	return k, nil
}
func (s *Cmb) Verify(input, sign string) (bool, error) {
signBytes, err := base64.StdEncoding.DecodeString(sign)
if err != nil {
return false, fmt.Errorf("signature is not base64 string: %s", err.Error())
}
signature := model.Signature{}
_, err = asn1.Unmarshal(signBytes, &signature)
if err != nil {
return false, fmt.Errorf("signature is not asn1: %s", err.Error())
}
ok := s.verify([]byte(input), signature)
if !ok {
return false, nil
}
return true, nil
}
// verifyBool reports whether data is byte-for-byte equal to model.VerifyTrue.
func (s *Cmb) verifyBool(data []byte) bool {
	return bytes.Equal(data, model.VerifyTrue)
}
func (s *Cmb) verify(data []byte, signature model.Signature) bool {
c := s.publicKey.Curve
N := c.Params().N
if signature.R.Cmp(model.One) < 0 || signature.S.Cmp(model.One) < 0 {
return s.verifyBool(model.VerifyFalse)
}
if signature.R.Cmp(N) >= 0 || signature.S.Cmp(N) >= 0 {
return s.verifyBool(model.VerifyFalse)
}
z := s.sdk.GetZ(s.publicKey.X, s.publicKey.Y, s.uid)
e := s.sdk.GetE(z, data)
t := new(big.Int).Add(signature.R, signature.S)
t.Mod(t, N)
if t.Sign() == 0 {
return s.verifyBool(model.VerifyFalse)
}
var x *big.Int
x1, y1 := c.ScalarBaseMult(signature.S.Bytes())
x2, y2 := c.ScalarMult(s.publicKey.X, s.publicKey.Y, t.Bytes())
x, _ = c.Add(x1, y1, x2, y2)
x.Add(x, e)
x.Mod(x, N)
if x.Cmp(signature.R) == 0 {
return s.verifyBool(model.VerifyTrue)
}
return s.verifyBool(model.VerifyFalse)
}