393 lines
8.3 KiB
Go
393 lines
8.3 KiB
Go
package cmbv2
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"encoding/asn1"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"github.com/ZZMarquis/gm/sm4"
|
|
"io"
|
|
"math/big"
|
|
"strings"
|
|
"voucher/internal/pkg/cmb/sm2/sdk"
|
|
"voucher/internal/pkg/cmbv2/model"
|
|
"voucher/internal/pkg/cmbv2/utils"
|
|
)
|
|
|
|
type Cmb struct {
|
|
publicKey *model.PublicKey
|
|
privateKey *model.PrivateKey
|
|
|
|
cipherType model.CipherType
|
|
c3Len int
|
|
uid []byte
|
|
|
|
sdk sdk.SDK
|
|
cmbLifeSdk sdk.SDK
|
|
|
|
sm2P256 *utils.Sm2P256Curve
|
|
}
|
|
|
|
func NewCmb(privateKey, sopPublicKey string) (*Cmb, error) {
|
|
cmb := &Cmb{
|
|
c3Len: 32,
|
|
uid: model.DefaultUid,
|
|
cipherType: model.C1C3C2,
|
|
|
|
sdk: sdk.NewBaseSdk(),
|
|
cmbLifeSdk: sdk.NewCmbLifeSdk(),
|
|
}
|
|
|
|
sm2P256 := utils.NewP256Sm2()
|
|
cmb.sm2P256 = &sm2P256
|
|
|
|
if len(privateKey) > 0 {
|
|
if err := cmb.setHexPrivateKey(privateKey); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if err := cmb.setHexPublicKey(sopPublicKey); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return cmb, nil
|
|
}
|
|
|
|
// setHexPublicKey decodes hexStr and parses it on s.sm2P256 as the peer
// public key.
//
// s.publicKey is replaced only when parsing succeeds, so a failed call leaves
// any previously loaded key untouched. The hex decoding error is wrapped
// (%w) so callers can inspect it with errors.Is/As.
func (s *Cmb) setHexPublicKey(hexStr string) error {
	raw, err := hex.DecodeString(hexStr)
	if err != nil {
		return fmt.Errorf("publicKey is not hex string: %w", err)
	}

	pub, err := utils.HexToPublicKey(s.sm2P256, raw)
	if err != nil {
		return fmt.Errorf("parse publicKey err: %+v", err)
	}

	s.publicKey = pub
	return nil
}
|
|
|
|
// setHexPrivateKey decodes hexStr and parses it on s.sm2P256 as the local
// private key.
//
// s.privateKey is replaced only when parsing succeeds, so a failed call leaves
// any previously loaded key untouched. The hex decoding error is wrapped
// (%w) so callers can inspect it with errors.Is/As.
func (s *Cmb) setHexPrivateKey(hexStr string) error {
	raw, err := hex.DecodeString(hexStr)
	if err != nil {
		return fmt.Errorf("privateKey is not hex string: %w", err)
	}

	priv, err := utils.HexToPrivateKey(s.sm2P256, raw)
	if err != nil {
		return fmt.Errorf("parse privateKey err: %+v", err)
	}

	s.privateKey = priv
	return nil
}
|
|
|
|
// Encrypt encrypts input for the configured peer public key.
//
// The procedure is:
//  1. Generate a fresh SM4 key and IV.
//  2. Pad input with utils.Padding(input, 1) and encrypt it with SM4-CBC.
//  3. Join the key and IV with utils.AssemblingByteArray and base64-encode
//     the result.
//  4. SM2-encrypt that base64 text with the peer public key.
//
// The returned string is "base64(SM2 ciphertext)|base64(SM4 ciphertext)",
// which is the format Decrypt splits on "|".
//
// An error is returned for nil or zero-length input, if SM4 encryption fails,
// or if SM2 encryption fails.
func (s *Cmb) Encrypt(input []byte) (string, error) {
	// Reject both nil and zero-length payloads, as the message states.
	if len(input) == 0 {
		return "", fmt.Errorf("加密元数据 is empty")
	}

	sm4Key := utils.GenerateSM4Key()
	iv := utils.GetSM4IV()

	encryptedBody, err := sm4.CBCEncrypt(sm4Key, iv, utils.Padding(input, 1))
	if err != nil {
		// Previously this error was overwritten unchecked, which could emit a
		// ciphertext built from a nil body.
		return "", fmt.Errorf("sm4 encrypt: %w", err)
	}

	// The SM2 layer protects the base64 text of the assembled key/IV, not the
	// raw bytes; Decrypt base64-decodes after SM2 decryption accordingly.
	keyAndIv := utils.AssemblingByteArray(sm4Key, iv)
	data := []byte(base64.StdEncoding.EncodeToString(keyAndIv))

	envelope, err := s.encrypt(data)
	if err != nil {
		return "", err
	}

	return base64.StdEncoding.EncodeToString(envelope) + "|" +
		base64.StdEncoding.EncodeToString(encryptedBody), nil
}
|
|
|
|
// encrypt performs SM2 public-key encryption of data with s.publicKey.
//
// The ciphertext has three parts:
//   - C1 is the marshalled ephemeral point k*G.
//   - C2 is a copy of data masked in place by cmbLifeSdk.Kdf using k*P.
//   - C3 is cmbLifeSdk.CalculateHash(x2, data, y2).
//
// The parts are concatenated as C1||C2||C3 or C1||C3||C2 according to
// s.cipherType; any other cipher type is an error. data must be non-empty.
func (s *Cmb) encrypt(data []byte) ([]byte, error) {
	// With empty data the "ciphertext differs from plaintext" test below can
	// never succeed, so the retry loop would spin forever.
	if len(data) == 0 {
		return nil, fmt.Errorf("encrypt data is empty")
	}

	curve := s.publicKey.Curve
	n := s.publicKey.Params().N

	var (
		c1, c2 []byte
		kx, ky *big.Int
	)
	for {
		k, err := rand.Int(rand.Reader, n)
		if err != nil {
			return nil, fmt.Errorf("rand error: %w", err)
		}
		// rand.Int yields values in [0, n); k = 0 is not a valid ephemeral
		// scalar (k*G would be the point at infinity).
		if k.Sign() == 0 {
			continue
		}

		c1x, c1y := curve.ScalarBaseMult(k.Bytes())
		c1 = elliptic.Marshal(curve, c1x, c1y)
		kx, ky = curve.ScalarMult(s.publicKey.X, s.publicKey.Y, k.Bytes())

		// Mask a fresh copy on every attempt so a retry never re-masks an
		// already-masked buffer.
		c2 = append([]byte(nil), data...)
		if err := s.cmbLifeSdk.Kdf(s.publicKey, kx, ky, c2); err != nil {
			return nil, fmt.Errorf("kdf error: %w", err)
		}
		// If the KDF output left c2 identical to data, retry with a new k.
		if s.encrypted(c2, data) {
			break
		}
	}

	c3 := s.cmbLifeSdk.CalculateHash(kx, data, ky)

	out := make([]byte, 0, len(c1)+len(c2)+len(c3))
	switch s.cipherType {
	case model.C1C2C3:
		out = append(out, c1...)
		out = append(out, c2...)
		out = append(out, c3...)
	case model.C1C3C2:
		out = append(out, c1...)
		out = append(out, c3...)
		out = append(out, c2...)
	default:
		return nil, fmt.Errorf("cipher type not support")
	}

	return out, nil
}
|
|
|
|
func (s *Cmb) encrypted(encData []byte, in []byte) bool {
|
|
encDataLen := len(encData)
|
|
|
|
for i := 0; i != encDataLen; i++ {
|
|
if encData[i] != in[i] {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// Decrypt reverses Encrypt.
//
// input must be exactly two "|"-separated base64 fields. The first is an SM2
// ciphertext, which is decrypted with the private key to a base64 string.
// That string decodes to a 33-byte blob:
//   - bytes [0,16) are the SM4 key;
//   - byte 16 is skipped;
//   - bytes [17,33) are the IV.
//
// The second field is SM4-CBC decrypted with that key and IV, and the result
// is passed through utils.Padding(..., 0), presumably undoing the mode-1
// padding Encrypt applies.
func (s *Cmb) Decrypt(input string) (string, error) {
	fields := strings.Split(input, "|")
	if len(fields) != 2 {
		return "", fmt.Errorf("数据格式错误")
	}
	envelopeB64, bodyB64 := fields[0], fields[1]

	envelope, err := base64.StdEncoding.DecodeString(envelopeB64)
	if err != nil {
		return "", fmt.Errorf("data is not base64 string: %s", err.Error())
	}

	keyIvB64, err := s.decrypt(envelope)
	if err != nil {
		return "", err
	}

	keyIv, err := base64.StdEncoding.DecodeString(keyIvB64)
	if err != nil {
		return "", err
	}
	if len(keyIv) != 33 {
		return "", fmt.Errorf("iv长度不等于33")
	}
	key, iv := keyIv[:16], keyIv[17:]

	cipherBody, err := base64.StdEncoding.DecodeString(bodyB64)
	if err != nil {
		return "", err
	}

	padded, err := sm4.CBCDecrypt(key, iv, cipherBody)
	if err != nil {
		return "", err
	}
	return string(utils.Padding(padded, 0)), nil
}
|
|
|
|
func (s *Cmb) decrypt(data []byte) (string, error) {
|
|
|
|
c1Len := 65
|
|
C1Byte := make([]byte, c1Len)
|
|
copy(C1Byte, data[:c1Len])
|
|
|
|
x, y := elliptic.Unmarshal(s.privateKey.Curve, C1Byte)
|
|
dBC1X, dBC1Y := s.privateKey.Curve.ScalarMult(x, y, s.privateKey.D.Bytes())
|
|
|
|
c2Len := len(data) - c1Len - s.c3Len
|
|
|
|
c2 := make([]byte, c2Len)
|
|
c3 := make([]byte, s.c3Len)
|
|
|
|
if s.cipherType == model.C1C2C3 {
|
|
copy(c2, data[c1Len:c1Len+c2Len])
|
|
copy(c3, data[c1Len+c2Len:])
|
|
} else if s.cipherType == model.C1C3C2 {
|
|
copy(c3, data[c1Len:c1Len+s.c3Len])
|
|
copy(c2, data[c1Len+s.c3Len:])
|
|
} else {
|
|
return "", fmt.Errorf("cipher type not support")
|
|
}
|
|
|
|
if err := s.cmbLifeSdk.Kdf(s.privateKey.Curve, dBC1X, dBC1Y, c2); err != nil {
|
|
return "", fmt.Errorf("kdf error: %v", err)
|
|
}
|
|
|
|
u := s.cmbLifeSdk.CalculateHash(dBC1X, c2, dBC1Y)
|
|
if bytes.Compare(u, c3) == 0 {
|
|
return string(c2), nil
|
|
}
|
|
|
|
return "", fmt.Errorf("decrypt error")
|
|
}
|
|
|
|
// Sign computes an SM2 signature over input with the configured private key
// and returns the ASN.1-encoded (r, s) pair as standard base64.
func (s *Cmb) Sign(input []byte) (string, error) {
	der, err := s.sign(input)
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(der), nil
}
|
|
|
|
// sign produces an ASN.1-marshalled model.Signature{R, S} over data using
// s.privateKey.
//
// The digest is e = GetE(GetZ(Px, Py, uid), data). For a random k in [1, n-1]:
//   - r = (x(k*G) + e) mod n, rejecting r == 0 and r + k == n;
//   - s = (1+d)^-1 * (k - r*d) mod n, rejecting s == 0.
//
// An error is returned when no private key is loaded, when the curve order is
// zero, when 1+d has no inverse mod n, or when randomness or marshalling fails.
func (s *Cmb) sign(data []byte) ([]byte, error) {
	if s.privateKey == nil {
		return nil, fmt.Errorf("privateKey is not set")
	}
	priv := s.privateKey

	z := s.sdk.GetZ(priv.PublicKey.X, priv.PublicKey.Y, s.uid)
	e := s.sdk.GetE(z, data)

	c := priv.PublicKey.Curve
	n := c.Params().N
	if n.Sign() == 0 {
		return nil, fmt.Errorf("invalid curve order")
	}

	// (1+d)^-1 mod n does not depend on k, so compute it once. ModInverse
	// returns nil when 1+d is not invertible (d ≡ -1 mod n); using that nil
	// in Mul would panic.
	d1Inv := new(big.Int).ModInverse(new(big.Int).Add(priv.D, model.One), n)
	if d1Inv == nil {
		return nil, fmt.Errorf("invalid private key: 1+d is not invertible mod n")
	}

	var r, sb *big.Int
	for {
		var k *big.Int
		for {
			var err error
			k, err = s.randFieldElement(c, rand.Reader)
			if err != nil {
				return nil, fmt.Errorf("randFieldElement: %w", err)
			}
			r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
			r.Add(r, e)
			r.Mod(r, n)
			// Reject r == 0 and r + k == n, then retry with a fresh k.
			if r.Sign() != 0 && new(big.Int).Add(r, k).Cmp(n) != 0 {
				break
			}
		}

		// s = (1+d)^-1 * (k - r*d) mod n
		sb = new(big.Int).Mul(priv.D, r)
		sb.Sub(k, sb)
		sb.Mul(sb, d1Inv)
		sb.Mod(sb, n)
		if sb.Sign() != 0 {
			break
		}
	}

	si, err := asn1.Marshal(model.Signature{R: r, S: sb})
	if err != nil {
		return nil, fmt.Errorf("asn1.Marshal: %w", err)
	}
	return si, nil
}
|
|
|
|
// randFieldElement returns a random scalar in [1, n-1] for curve c.
//
// It reads BitSize/8+8 bytes from random (presumably the extra 64 bits keep
// the modulo bias small), reduces the value mod n-1 and adds 1. Any read
// error is returned with a nil k.
func (s *Cmb) randFieldElement(c elliptic.Curve, random io.Reader) (k *big.Int, err error) {
	params := c.Params()

	buf := make([]byte, params.BitSize/8+8)
	if _, err = io.ReadFull(random, buf); err != nil {
		return nil, err
	}

	nMinusOne := new(big.Int).Sub(params.N, model.One)
	k = new(big.Int).SetBytes(buf)
	k.Mod(k, nMinusOne)
	return k.Add(k, model.One), nil
}
|
|
|
|
func (s *Cmb) Verify(input, sign string) (bool, error) {
|
|
|
|
signBytes, err := base64.StdEncoding.DecodeString(sign)
|
|
if err != nil {
|
|
return false, fmt.Errorf("signature is not base64 string: %s", err.Error())
|
|
}
|
|
|
|
signature := model.Signature{}
|
|
_, err = asn1.Unmarshal(signBytes, &signature)
|
|
if err != nil {
|
|
return false, fmt.Errorf("signature is not asn1: %s", err.Error())
|
|
}
|
|
|
|
ok := s.verify([]byte(input), signature)
|
|
|
|
if !ok {
|
|
return false, nil
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
// verifyBool maps a verification marker to a bool: it is true exactly when
// data equals model.VerifyTrue byte-for-byte.
func (s *Cmb) verifyBool(data []byte) bool {
	return bytes.Equal(data, model.VerifyTrue)
}
|
|
|
|
func (s *Cmb) verify(data []byte, signature model.Signature) bool {
|
|
|
|
c := s.publicKey.Curve
|
|
N := c.Params().N
|
|
|
|
if signature.R.Cmp(model.One) < 0 || signature.S.Cmp(model.One) < 0 {
|
|
return s.verifyBool(model.VerifyFalse)
|
|
}
|
|
|
|
if signature.R.Cmp(N) >= 0 || signature.S.Cmp(N) >= 0 {
|
|
return s.verifyBool(model.VerifyFalse)
|
|
}
|
|
|
|
z := s.sdk.GetZ(s.publicKey.X, s.publicKey.Y, s.uid)
|
|
e := s.sdk.GetE(z, data)
|
|
t := new(big.Int).Add(signature.R, signature.S)
|
|
t.Mod(t, N)
|
|
|
|
if t.Sign() == 0 {
|
|
return s.verifyBool(model.VerifyFalse)
|
|
}
|
|
|
|
var x *big.Int
|
|
x1, y1 := c.ScalarBaseMult(signature.S.Bytes())
|
|
x2, y2 := c.ScalarMult(s.publicKey.X, s.publicKey.Y, t.Bytes())
|
|
x, _ = c.Add(x1, y1, x2, y2)
|
|
|
|
x.Add(x, e)
|
|
x.Mod(x, N)
|
|
if x.Cmp(signature.R) == 0 {
|
|
return s.verifyBool(model.VerifyTrue)
|
|
}
|
|
|
|
return s.verifyBool(model.VerifyFalse)
|
|
}
|