package cmbv2 import ( "bytes" "crypto/elliptic" "crypto/rand" "encoding/asn1" "encoding/base64" "encoding/hex" "fmt" "github.com/ZZMarquis/gm/sm4" "io" "math/big" "strings" "voucher/internal/pkg/cmb/sm2/sdk" "voucher/internal/pkg/cmbv2/model" "voucher/internal/pkg/cmbv2/utils" ) type Cmb struct { publicKey *model.PublicKey privateKey *model.PrivateKey cipherType model.CipherType c3Len int uid []byte sdk sdk.SDK cmbLifeSdk sdk.SDK sm2P256 *utils.Sm2P256Curve } func NewCmb(privateKey, sopPublicKey string) (*Cmb, error) { cmb := &Cmb{ c3Len: 32, uid: model.DefaultUid, cipherType: model.C1C3C2, sdk: sdk.NewBaseSdk(), cmbLifeSdk: sdk.NewCmbLifeSdk(), } sm2P256 := utils.NewP256Sm2() cmb.sm2P256 = &sm2P256 if len(privateKey) > 0 { if err := cmb.setHexPrivateKey(privateKey); err != nil { return nil, err } } if err := cmb.setHexPublicKey(sopPublicKey); err != nil { return nil, err } return cmb, nil } func (s *Cmb) setHexPublicKey(hexStr string) error { d, err := hex.DecodeString(hexStr) if err != nil { return fmt.Errorf("publicKey is not hex string: %s", err.Error()) } s.publicKey, err = utils.HexToPublicKey(s.sm2P256, d) if err != nil { return fmt.Errorf("parse publicKey err: %+v", err) } return nil } func (s *Cmb) setHexPrivateKey(hexStr string) error { d, err := hex.DecodeString(hexStr) if err != nil { return fmt.Errorf("privateKey is not hex string: %s", err.Error()) } s.privateKey, err = utils.HexToPrivateKey(s.sm2P256, d) if err != nil { return fmt.Errorf("parse privateKey err: %+v", err) } return nil } func (s *Cmb) Encrypt(input []byte) (string, error) { if input == nil { return "", fmt.Errorf("加密元数据 is empty") } sm4Key := utils.GenerateSM4Key() iv := utils.GetSM4IV() encryptedBody, err := sm4.CBCEncrypt(sm4Key, iv, utils.Padding(input, 1)) keyAndIv := utils.AssemblingByteArray(sm4Key, iv) data := []byte(base64.StdEncoding.EncodeToString(keyAndIv)) kvTmp, err := s.encrypt(data) if err != nil { return "", err } return fmt.Sprintf("%s|%s", base64.StdEncoding.EncodeToString([]byte(kvTmp)), base64.StdEncoding.EncodeToString(encryptedBody)), nil } func (s *Cmb) encrypt(data []byte) ([]byte, error) { c2 := make([]byte, len(data)) copy(c2, data) var c1 []byte var kx, ky *big.Int for { k, err := rand.Int(rand.Reader, s.publicKey.Params().N) if err != nil { return nil, fmt.Errorf("rand error: %v", err) } c1x, c1y := s.publicKey.Curve.ScalarBaseMult(k.Bytes()) c1 = elliptic.Marshal(s.publicKey.Curve, c1x, c1y) kx, ky = s.publicKey.Curve.ScalarMult(s.publicKey.X, s.publicKey.Y, k.Bytes()) err = s.cmbLifeSdk.Kdf(s.publicKey, kx, ky, c2) if err != nil { return nil, fmt.Errorf("kdf error: %v", err) } if s.encrypted(c2, data) { break } } c3 := s.cmbLifeSdk.CalculateHash(kx, data, ky) c1Len := len(c1) c2Len := len(c2) c3Len := len(c3) toData := make([]byte, c1Len+c2Len+c3Len) if s.cipherType == model.C1C2C3 { copy(toData[:c1Len], c1) copy(toData[c1Len:c1Len+c2Len], c2) copy(toData[c1Len+c2Len:], c3) } else if s.cipherType == model.C1C3C2 { copy(toData[:c1Len], c1) copy(toData[c1Len:c1Len+c3Len], c3) copy(toData[c1Len+c3Len:], c2) } else { return nil, fmt.Errorf("cipher type not support") } return toData, nil } func (s *Cmb) encrypted(encData []byte, in []byte) bool { encDataLen := len(encData) for i := 0; i != encDataLen; i++ { if encData[i] != in[i] { return true } } return false } func (s *Cmb) Decrypt(input string) (string, error) { tmpDataArr := strings.Split(input, "|") if len(tmpDataArr) != 2 { return "", fmt.Errorf("数据格式错误") } keyAndIvStr := tmpDataArr[0] encryptedBody := tmpDataArr[1] data, err := base64.StdEncoding.DecodeString(keyAndIvStr) if err != nil { return "", fmt.Errorf("data is not base64 string: %s", err.Error()) } kvBase64Tmp, err := s.decrypt(data) if err != nil { return "", err } kvTmp, err := base64.StdEncoding.DecodeString(kvBase64Tmp) if err != nil { return "", err } if len(kvTmp) != 33 { return "", fmt.Errorf("iv长度不等于33") } plainKey := kvTmp[0:16] plainIv := kvTmp[17:33] data2, err := base64.StdEncoding.DecodeString(encryptedBody) if err != nil { return "", err } plainText, err := sm4.CBCDecrypt(plainKey, plainIv, data2) if err != nil { return "", err } return string(utils.Padding(plainText, 0)), nil } func (s *Cmb) decrypt(data []byte) (string, error) { c1Len := 65 C1Byte := make([]byte, c1Len) copy(C1Byte, data[:c1Len]) x, y := elliptic.Unmarshal(s.privateKey.Curve, C1Byte) dBC1X, dBC1Y := s.privateKey.Curve.ScalarMult(x, y, s.privateKey.D.Bytes()) c2Len := len(data) - c1Len - s.c3Len c2 := make([]byte, c2Len) c3 := make([]byte, s.c3Len) if s.cipherType == model.C1C2C3 { copy(c2, data[c1Len:c1Len+c2Len]) copy(c3, data[c1Len+c2Len:]) } else if s.cipherType == model.C1C3C2 { copy(c3, data[c1Len:c1Len+s.c3Len]) copy(c2, data[c1Len+s.c3Len:]) } else { return "", fmt.Errorf("cipher type not support") } if err := s.cmbLifeSdk.Kdf(s.privateKey.Curve, dBC1X, dBC1Y, c2); err != nil { return "", fmt.Errorf("kdf error: %v", err) } u := s.cmbLifeSdk.CalculateHash(dBC1X, c2, dBC1Y) if bytes.Compare(u, c3) == 0 { return string(c2), nil } return "", fmt.Errorf("decrypt error") } func (s *Cmb) Sign(input []byte) (string, error) { signData, err := s.sign(input) if err != nil { return "", err } return base64.StdEncoding.EncodeToString(signData), nil } func (s *Cmb) sign(data []byte) ([]byte, error) { z := s.sdk.GetZ(s.privateKey.PublicKey.X, s.privateKey.PublicKey.Y, s.uid) e := s.sdk.GetE(z, data) c := s.privateKey.PublicKey.Curve N := c.Params().N if N.Sign() == 0 { return nil, fmt.Errorf("invalid curve order") } var k, r, sb *big.Int var err error for { for { k, err = s.randFieldElement(c, rand.Reader) if err != nil { return nil, fmt.Errorf("randFieldElement: %+v", err) } r, _ = s.privateKey.Curve.ScalarBaseMult(k.Bytes()) r.Add(r, e) r.Mod(r, N) if r.Sign() != 0 { if t := new(big.Int).Add(r, k); t.Cmp(N) != 0 { break } } } rD := new(big.Int).Mul(s.privateKey.D, r) sb = new(big.Int).Sub(k, rD) d1 := new(big.Int).Add(s.privateKey.D, model.One) d1Inv := new(big.Int).ModInverse(d1, N) sb.Mul(sb, d1Inv) sb.Mod(sb, N) if sb.Sign() != 0 { break } } signature := model.Signature{R: r, S: sb} si, err := asn1.Marshal(signature) if err != nil { return nil, fmt.Errorf("asn1.Marshal: %+v", err) } return si, nil } func (s *Cmb) randFieldElement(c elliptic.Curve, random io.Reader) (k *big.Int, err error) { params := c.Params() b := make([]byte, params.BitSize/8+8) _, err = io.ReadFull(random, b) if err != nil { return } k = new(big.Int).SetBytes(b) n := new(big.Int).Sub(params.N, model.One) k.Mod(k, n) k.Add(k, model.One) return } func (s *Cmb) Verify(input, sign string) (bool, error) { signBytes, err := base64.StdEncoding.DecodeString(sign) if err != nil { return false, fmt.Errorf("signature is not base64 string: %s", err.Error()) } signature := model.Signature{} _, err = asn1.Unmarshal(signBytes, &signature) if err != nil { return false, fmt.Errorf("signature is not asn1: %s", err.Error()) } ok := s.verify([]byte(input), signature) if !ok { return false, nil } return true, nil } func (s *Cmb) verifyBool(data []byte) bool { if bytes.Equal(data, model.VerifyTrue) { return true } return false } func (s *Cmb) verify(data []byte, signature model.Signature) bool { c := s.publicKey.Curve N := c.Params().N if signature.R.Cmp(model.One) < 0 || signature.S.Cmp(model.One) < 0 { return s.verifyBool(model.VerifyFalse) } if signature.R.Cmp(N) >= 0 || signature.S.Cmp(N) >= 0 { return s.verifyBool(model.VerifyFalse) } z := s.sdk.GetZ(s.publicKey.X, s.publicKey.Y, s.uid) e := s.sdk.GetE(z, data) t := new(big.Int).Add(signature.R, signature.S) t.Mod(t, N) if t.Sign() == 0 { return s.verifyBool(model.VerifyFalse) } var x *big.Int x1, y1 := c.ScalarBaseMult(signature.S.Bytes()) x2, y2 := c.ScalarMult(s.publicKey.X, s.publicKey.Y, t.Bytes()) x, _ = c.Add(x1, y1, x2, y2) x.Add(x, e) x.Mod(x, N) if x.Cmp(signature.R) == 0 { return s.verifyBool(model.VerifyTrue) } return s.verifyBool(model.VerifyFalse) }