package sm2 import ( "bytes" "crypto/elliptic" "crypto/rand" "encoding/asn1" "encoding/base64" "encoding/hex" "fmt" "io" "math/big" "voucher/internal/pkg/cmb/sm2/model" "voucher/internal/pkg/cmb/sm2/sdk" "voucher/internal/pkg/cmb/sm2/util" ) type Sm2 struct { publicKey *model.PublicKey privateKey *model.PrivateKey signature model.Signature sdk sdk.SDK cipherType model.CipherType c3Len int uid []byte data []byte err []error toData []byte } func NewSm2() *Sm2 { return &Sm2{ sdk: sdk.NewBaseSdk(), c3Len: 32, uid: model.DefaultUid, cipherType: model.C1C2C3, err: make([]error, 0), } } func (s *Sm2) Encrypt() *Sm2 { if err := s.encryptErrHandler(); err != nil { return s } c2 := make([]byte, len(s.data)) copy(c2, s.data) var c1 []byte var kx, ky *big.Int for { k, err := rand.Int(rand.Reader, s.publicKey.Params().N) if err != nil { return s.SetError(fmt.Errorf("rand error: %v", err)) } c1x, c1y := s.publicKey.Curve.ScalarBaseMult(k.Bytes()) c1 = elliptic.Marshal(s.publicKey.Curve, c1x, c1y) kx, ky = s.publicKey.Curve.ScalarMult(s.publicKey.X, s.publicKey.Y, k.Bytes()) err = s.sdk.Kdf(s.publicKey, kx, ky, c2) if err != nil { return s.SetError(fmt.Errorf("kdf error: %v", err)) } if s.encrypted(c2, s.data) { break } } c3 := s.sdk.CalculateHash(kx, s.data, ky) c1Len := len(c1) c2Len := len(c2) c3Len := len(c3) s.toData = make([]byte, c1Len+c2Len+c3Len) if s.cipherType == model.C1C2C3 { copy(s.toData[:c1Len], c1) copy(s.toData[c1Len:c1Len+c2Len], c2) copy(s.toData[c1Len+c2Len:], c3) } else if s.cipherType == model.C1C3C2 { copy(s.toData[:c1Len], c1) copy(s.toData[c1Len:c1Len+c3Len], c3) copy(s.toData[c1Len+c3Len:], c2) } else { return s.SetError(fmt.Errorf("cipher type not support")) } return s } func (s *Sm2) Decrypt() *Sm2 { if err := s.decryptErrHandler(); err != nil { return s } c1Len := 65 C1Byte := make([]byte, c1Len) copy(C1Byte, s.data[:c1Len]) x, y := elliptic.Unmarshal(s.privateKey.Curve, C1Byte) dBC1X, dBC1Y := s.privateKey.Curve.ScalarMult(x, y, s.privateKey.D.Bytes()) c2Len := len(s.data) - c1Len - s.c3Len c2 := make([]byte, c2Len) c3 := make([]byte, s.c3Len) if s.cipherType == model.C1C2C3 { copy(c2, s.data[c1Len:c1Len+c2Len]) copy(c3, s.data[c1Len+c2Len:]) } else if s.cipherType == model.C1C3C2 { copy(c3, s.data[c1Len:c1Len+s.c3Len]) copy(c2, s.data[c1Len+s.c3Len:]) } else { return s.SetError(fmt.Errorf("cipher type not support")) } err := s.sdk.Kdf(s.privateKey.Curve, dBC1X, dBC1Y, c2) if err != nil { return s.SetError(fmt.Errorf("kdf error: %v", err)) } u := s.sdk.CalculateHash(dBC1X, c2, dBC1Y) if bytes.Compare(u, c3) == 0 { s.toData = c2 return s } else { return s.SetError(fmt.Errorf("decrypt error")) } } func (s *Sm2) Verify() *Sm2 { if err := s.encryptErrHandler(); err != nil { return s } c := s.publicKey.Curve N := c.Params().N if s.signature.R.Cmp(model.One) < 0 || s.signature.S.Cmp(model.One) < 0 { s.toData = model.VerifyFalse return s } if s.signature.R.Cmp(N) >= 0 || s.signature.S.Cmp(N) >= 0 { s.toData = model.VerifyFalse return s } z := s.sdk.GetZ(s.publicKey.X, s.publicKey.Y, s.uid) e := s.sdk.GetE(z, s.data) t := new(big.Int).Add(s.signature.R, s.signature.S) t.Mod(t, N) if t.Sign() == 0 { s.toData = model.VerifyFalse return s } var x *big.Int x1, y1 := c.ScalarBaseMult(s.signature.S.Bytes()) x2, y2 := c.ScalarMult(s.publicKey.X, s.publicKey.Y, t.Bytes()) x, _ = c.Add(x1, y1, x2, y2) x.Add(x, e) x.Mod(x, N) if x.Cmp(s.signature.R) == 0 { s.toData = model.VerifyTrue } else { s.toData = model.VerifyFalse } return s } func (s *Sm2) Sign() *Sm2 { if err := s.signErrHandler(); err != nil { return s } z := s.sdk.GetZ(s.privateKey.PublicKey.X, s.privateKey.PublicKey.Y, s.uid) e := s.sdk.GetE(z, s.data) c := s.privateKey.PublicKey.Curve N := c.Params().N if N.Sign() == 0 { return s.SetError(fmt.Errorf("invalid curve order")) } var k, r, sb *big.Int var err error for { for { k, err = s.randFieldElement(c, rand.Reader) if err != nil { return s.SetError(fmt.Errorf("randFieldElement: %+v", err)) } r, _ = s.privateKey.Curve.ScalarBaseMult(k.Bytes()) r.Add(r, e) r.Mod(r, N) if r.Sign() != 0 { if t := new(big.Int).Add(r, k); t.Cmp(N) != 0 { break } } } rD := new(big.Int).Mul(s.privateKey.D, r) sb = new(big.Int).Sub(k, rD) d1 := new(big.Int).Add(s.privateKey.D, model.One) d1Inv := new(big.Int).ModInverse(d1, N) sb.Mul(sb, d1Inv) sb.Mod(sb, N) if sb.Sign() != 0 { break } } s.signature = model.Signature{R: r, S: sb} si, err := asn1.Marshal(s.signature) if err != nil { return s.SetError(fmt.Errorf("asn1.Marshal: %+v", err)) } s.toData = si return s } func (s *Sm2) randFieldElement(c elliptic.Curve, random io.Reader) (k *big.Int, err error) { params := c.Params() b := make([]byte, params.BitSize/8+8) _, err = io.ReadFull(random, b) if err != nil { return } k = new(big.Int).SetBytes(b) n := new(big.Int).Sub(params.N, model.One) k.Mod(k, n) k.Add(k, model.One) return } func (s *Sm2) encrypted(encData []byte, in []byte) bool { encDataLen := len(encData) for i := 0; i != encDataLen; i++ { if encData[i] != in[i] { return true } } return false } func (s *Sm2) SetCipherType(cipherType model.CipherType) *Sm2 { s.cipherType = cipherType return s } func (s *Sm2) SetHexPublicKey(hexStr string) *Sm2 { d, err := hex.DecodeString(hexStr) if err != nil { return s.SetError(fmt.Errorf("publicKey is not hex string: %s", err.Error())) } s.publicKey, err = util.HexToPublicKey(d) if err != nil { return s.SetError(fmt.Errorf("parse publicKey err: %+v", err)) } return s } func (s *Sm2) SetHexPrivateKey(hexStr string) *Sm2 { d, err := hex.DecodeString(hexStr) if err != nil { return s.SetError(fmt.Errorf("privateKey is not hex string: %s", err.Error())) } s.privateKey, err = util.HexToPrivateKey(d) if err != nil { return s.SetError(fmt.Errorf("parse privateKey err: %+v", err)) } return s } func (s *Sm2) SetPublicKey(publicKey *model.PublicKey) *Sm2 { s.publicKey = publicKey return s } func (s *Sm2) SetPrivateKey(privateKey *model.PrivateKey) *Sm2 { s.privateKey = privateKey return s } func (s *Sm2) SetHexSignature(hexStr string) *Sm2 { sign, err := hex.DecodeString(hexStr) if err != nil { return s.SetError(fmt.Errorf("signature is not hex string: %s", err.Error())) } _, err = asn1.Unmarshal(sign, &s.signature) if err != nil { return s.SetError(fmt.Errorf("signature is not asn1: %s", err.Error())) } return s } func (s *Sm2) SetBase64Signature(base64Str string) *Sm2 { sign, err := base64.StdEncoding.DecodeString(base64Str) if err != nil { return s.SetError(fmt.Errorf("signature is not base64 string: %s", err.Error())) } _, err = asn1.Unmarshal(sign, &s.signature) if err != nil { return s.SetError(fmt.Errorf("signature is not asn1: %s", err.Error())) } return s } func (s *Sm2) SetHexSignatureData(hexStr string) *Sm2 { sign, err := hex.DecodeString(hexStr) if err != nil { return s.SetError(fmt.Errorf("signature is not hex string: %s", err.Error())) } _, err = asn1.Unmarshal(sign, &s.signature) if err != nil { return s.SetError(fmt.Errorf("signature is not asn1: %s", err.Error())) } return s } func (s *Sm2) SetSignature(signature model.Signature) *Sm2 { s.signature = signature return s } func (s *Sm2) SetSdk(sdk sdk.SDK) *Sm2 { s.sdk = sdk return s } func (s *Sm2) SetC3Len(c3Len int) *Sm2 { s.c3Len = c3Len return s } func (s *Sm2) SetUid(uid []byte) *Sm2 { s.uid = uid return s } func (s *Sm2) SetBase64StringData(base64String string) *Sm2 { data, err := base64.StdEncoding.DecodeString(base64String) if err != nil { return s.SetError(fmt.Errorf("data is not base64 string: %s", err.Error())) } s.data = data return s } func (s *Sm2) SetStringData(str string) *Sm2 { s.data = []byte(str) return s } func (s *Sm2) SetHexData(hexStr string) *Sm2 { data, err := hex.DecodeString(hexStr) if err != nil { return s.SetError(fmt.Errorf("data is not hex string: %s", err.Error())) } s.data = data return s } func (s *Sm2) SetData(data []byte) *Sm2 { s.data = data return s } func (s *Sm2) ToBytes() ([]byte, error) { if err := s.errHandle(); err != nil { return nil, err } return s.toData, nil } func (s *Sm2) ToString() (string, error) { if err := s.errHandle(); err != nil { return "", err } return string(s.toData), nil } func (s *Sm2) ToBase64String() (string, error) { if err := s.errHandle(); err != nil { return "", err } return base64.StdEncoding.EncodeToString(s.toData), nil } func (s *Sm2) ToHexString() (string, error) { if err := s.errHandle(); err != nil { return "", err } return hex.EncodeToString(s.toData), nil } func (s *Sm2) ToBool() (bool, error) { if err := s.errHandle(); err != nil { return false, err } if bytes.Equal(s.toData, model.VerifyTrue) { return true, nil } return false, nil } func (s *Sm2) ToSignature() (*model.Signature, error) { if err := s.errHandle(); err != nil { return nil, err } return &s.signature, nil }