400 lines
9.0 KiB
Go
400 lines
9.0 KiB
Go
package sm2
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"encoding/asn1"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"io"
|
|
"math/big"
|
|
"voucher/internal/pkg/cmb/sm2/model"
|
|
"voucher/internal/pkg/cmb/sm2/sdk"
|
|
"voucher/internal/pkg/cmb/sm2/util"
|
|
)
|
|
|
|
type Sm2 struct {
|
|
publicKey *model.PublicKey
|
|
privateKey *model.PrivateKey
|
|
signature model.Signature
|
|
sdk sdk.SDK
|
|
cipherType model.CipherType
|
|
c3Len int
|
|
uid []byte
|
|
data []byte
|
|
err []error
|
|
toData []byte
|
|
}
|
|
|
|
func NewSm2() *Sm2 {
|
|
return &Sm2{
|
|
sdk: sdk.NewBaseSdk(),
|
|
c3Len: 32,
|
|
uid: model.DefaultUid,
|
|
cipherType: model.C1C2C3,
|
|
err: make([]error, 0),
|
|
}
|
|
}
|
|
|
|
func (s *Sm2) Encrypt() *Sm2 {
|
|
if err := s.encryptErrHandler(); err != nil {
|
|
return s
|
|
}
|
|
c2 := make([]byte, len(s.data))
|
|
copy(c2, s.data)
|
|
var c1 []byte
|
|
var kx, ky *big.Int
|
|
for {
|
|
k, err := rand.Int(rand.Reader, s.publicKey.Params().N)
|
|
if err != nil {
|
|
return s.SetError(fmt.Errorf("rand error: %v", err))
|
|
}
|
|
c1x, c1y := s.publicKey.Curve.ScalarBaseMult(k.Bytes())
|
|
c1 = elliptic.Marshal(s.publicKey.Curve, c1x, c1y)
|
|
kx, ky = s.publicKey.Curve.ScalarMult(s.publicKey.X, s.publicKey.Y, k.Bytes())
|
|
|
|
err = s.sdk.Kdf(s.publicKey, kx, ky, c2)
|
|
if err != nil {
|
|
return s.SetError(fmt.Errorf("kdf error: %v", err))
|
|
}
|
|
if s.encrypted(c2, s.data) {
|
|
break
|
|
}
|
|
}
|
|
|
|
c3 := s.sdk.CalculateHash(kx, s.data, ky)
|
|
|
|
c1Len := len(c1)
|
|
c2Len := len(c2)
|
|
c3Len := len(c3)
|
|
s.toData = make([]byte, c1Len+c2Len+c3Len)
|
|
if s.cipherType == model.C1C2C3 {
|
|
copy(s.toData[:c1Len], c1)
|
|
copy(s.toData[c1Len:c1Len+c2Len], c2)
|
|
copy(s.toData[c1Len+c2Len:], c3)
|
|
} else if s.cipherType == model.C1C3C2 {
|
|
copy(s.toData[:c1Len], c1)
|
|
copy(s.toData[c1Len:c1Len+c3Len], c3)
|
|
copy(s.toData[c1Len+c3Len:], c2)
|
|
} else {
|
|
return s.SetError(fmt.Errorf("cipher type not support"))
|
|
}
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) Decrypt() *Sm2 {
|
|
if err := s.decryptErrHandler(); err != nil {
|
|
return s
|
|
}
|
|
c1Len := 65
|
|
C1Byte := make([]byte, c1Len)
|
|
copy(C1Byte, s.data[:c1Len])
|
|
x, y := elliptic.Unmarshal(s.privateKey.Curve, C1Byte)
|
|
dBC1X, dBC1Y := s.privateKey.Curve.ScalarMult(x, y, s.privateKey.D.Bytes())
|
|
|
|
c2Len := len(s.data) - c1Len - s.c3Len
|
|
c2 := make([]byte, c2Len)
|
|
c3 := make([]byte, s.c3Len)
|
|
if s.cipherType == model.C1C2C3 {
|
|
copy(c2, s.data[c1Len:c1Len+c2Len])
|
|
copy(c3, s.data[c1Len+c2Len:])
|
|
} else if s.cipherType == model.C1C3C2 {
|
|
copy(c3, s.data[c1Len:c1Len+s.c3Len])
|
|
copy(c2, s.data[c1Len+s.c3Len:])
|
|
} else {
|
|
return s.SetError(fmt.Errorf("cipher type not support"))
|
|
}
|
|
|
|
err := s.sdk.Kdf(s.privateKey.Curve, dBC1X, dBC1Y, c2)
|
|
if err != nil {
|
|
return s.SetError(fmt.Errorf("kdf error: %v", err))
|
|
}
|
|
|
|
u := s.sdk.CalculateHash(dBC1X, c2, dBC1Y)
|
|
if bytes.Compare(u, c3) == 0 {
|
|
s.toData = c2
|
|
return s
|
|
} else {
|
|
return s.SetError(fmt.Errorf("decrypt error"))
|
|
}
|
|
}
|
|
|
|
func (s *Sm2) Verify() *Sm2 {
|
|
if err := s.encryptErrHandler(); err != nil {
|
|
return s
|
|
}
|
|
c := s.publicKey.Curve
|
|
N := c.Params().N
|
|
if s.signature.R.Cmp(model.One) < 0 || s.signature.S.Cmp(model.One) < 0 {
|
|
s.toData = model.VerifyFalse
|
|
return s
|
|
}
|
|
if s.signature.R.Cmp(N) >= 0 || s.signature.S.Cmp(N) >= 0 {
|
|
s.toData = model.VerifyFalse
|
|
return s
|
|
}
|
|
z := s.sdk.GetZ(s.publicKey.X, s.publicKey.Y, s.uid)
|
|
e := s.sdk.GetE(z, s.data)
|
|
t := new(big.Int).Add(s.signature.R, s.signature.S)
|
|
t.Mod(t, N)
|
|
if t.Sign() == 0 {
|
|
s.toData = model.VerifyFalse
|
|
return s
|
|
}
|
|
var x *big.Int
|
|
x1, y1 := c.ScalarBaseMult(s.signature.S.Bytes())
|
|
x2, y2 := c.ScalarMult(s.publicKey.X, s.publicKey.Y, t.Bytes())
|
|
x, _ = c.Add(x1, y1, x2, y2)
|
|
|
|
x.Add(x, e)
|
|
x.Mod(x, N)
|
|
if x.Cmp(s.signature.R) == 0 {
|
|
s.toData = model.VerifyTrue
|
|
} else {
|
|
s.toData = model.VerifyFalse
|
|
}
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) Sign() *Sm2 {
|
|
if err := s.signErrHandler(); err != nil {
|
|
return s
|
|
}
|
|
z := s.sdk.GetZ(s.privateKey.PublicKey.X, s.privateKey.PublicKey.Y, s.uid)
|
|
e := s.sdk.GetE(z, s.data)
|
|
|
|
c := s.privateKey.PublicKey.Curve
|
|
N := c.Params().N
|
|
if N.Sign() == 0 {
|
|
return s.SetError(fmt.Errorf("invalid curve order"))
|
|
}
|
|
var k, r, sb *big.Int
|
|
var err error
|
|
for {
|
|
for {
|
|
k, err = s.randFieldElement(c, rand.Reader)
|
|
if err != nil {
|
|
return s.SetError(fmt.Errorf("randFieldElement: %+v", err))
|
|
}
|
|
r, _ = s.privateKey.Curve.ScalarBaseMult(k.Bytes())
|
|
r.Add(r, e)
|
|
r.Mod(r, N)
|
|
if r.Sign() != 0 {
|
|
if t := new(big.Int).Add(r, k); t.Cmp(N) != 0 {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
rD := new(big.Int).Mul(s.privateKey.D, r)
|
|
sb = new(big.Int).Sub(k, rD)
|
|
d1 := new(big.Int).Add(s.privateKey.D, model.One)
|
|
d1Inv := new(big.Int).ModInverse(d1, N)
|
|
sb.Mul(sb, d1Inv)
|
|
sb.Mod(sb, N)
|
|
if sb.Sign() != 0 {
|
|
break
|
|
}
|
|
}
|
|
s.signature = model.Signature{R: r, S: sb}
|
|
si, err := asn1.Marshal(s.signature)
|
|
if err != nil {
|
|
return s.SetError(fmt.Errorf("asn1.Marshal: %+v", err))
|
|
}
|
|
s.toData = si
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) randFieldElement(c elliptic.Curve, random io.Reader) (k *big.Int, err error) {
|
|
params := c.Params()
|
|
b := make([]byte, params.BitSize/8+8)
|
|
_, err = io.ReadFull(random, b)
|
|
if err != nil {
|
|
return
|
|
}
|
|
k = new(big.Int).SetBytes(b)
|
|
n := new(big.Int).Sub(params.N, model.One)
|
|
k.Mod(k, n)
|
|
k.Add(k, model.One)
|
|
return
|
|
}
|
|
|
|
func (s *Sm2) encrypted(encData []byte, in []byte) bool {
|
|
encDataLen := len(encData)
|
|
for i := 0; i != encDataLen; i++ {
|
|
if encData[i] != in[i] {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (s *Sm2) SetCipherType(cipherType model.CipherType) *Sm2 {
|
|
s.cipherType = cipherType
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) SetHexPublicKey(hexStr string) *Sm2 {
|
|
d, err := hex.DecodeString(hexStr)
|
|
if err != nil {
|
|
return s.SetError(fmt.Errorf("publicKey is not hex string: %s", err.Error()))
|
|
}
|
|
s.publicKey, err = util.HexToPublicKey(d)
|
|
if err != nil {
|
|
return s.SetError(fmt.Errorf("parse publicKey err: %+v", err))
|
|
}
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) SetHexPrivateKey(hexStr string) *Sm2 {
|
|
d, err := hex.DecodeString(hexStr)
|
|
if err != nil {
|
|
return s.SetError(fmt.Errorf("privateKey is not hex string: %s", err.Error()))
|
|
}
|
|
s.privateKey, err = util.HexToPrivateKey(d)
|
|
if err != nil {
|
|
return s.SetError(fmt.Errorf("parse privateKey err: %+v", err))
|
|
}
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) SetPublicKey(publicKey *model.PublicKey) *Sm2 {
|
|
s.publicKey = publicKey
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) SetPrivateKey(privateKey *model.PrivateKey) *Sm2 {
|
|
s.privateKey = privateKey
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) SetHexSignature(hexStr string) *Sm2 {
|
|
sign, err := hex.DecodeString(hexStr)
|
|
if err != nil {
|
|
return s.SetError(fmt.Errorf("signature is not hex string: %s", err.Error()))
|
|
}
|
|
_, err = asn1.Unmarshal(sign, &s.signature)
|
|
if err != nil {
|
|
return s.SetError(fmt.Errorf("signature is not asn1: %s", err.Error()))
|
|
}
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) SetBase64Signature(base64Str string) *Sm2 {
|
|
sign, err := base64.StdEncoding.DecodeString(base64Str)
|
|
if err != nil {
|
|
return s.SetError(fmt.Errorf("signature is not base64 string: %s", err.Error()))
|
|
}
|
|
_, err = asn1.Unmarshal(sign, &s.signature)
|
|
if err != nil {
|
|
return s.SetError(fmt.Errorf("signature is not asn1: %s", err.Error()))
|
|
}
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) SetHexSignatureData(hexStr string) *Sm2 {
|
|
sign, err := hex.DecodeString(hexStr)
|
|
if err != nil {
|
|
return s.SetError(fmt.Errorf("signature is not hex string: %s", err.Error()))
|
|
}
|
|
_, err = asn1.Unmarshal(sign, &s.signature)
|
|
if err != nil {
|
|
return s.SetError(fmt.Errorf("signature is not asn1: %s", err.Error()))
|
|
}
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) SetSignature(signature model.Signature) *Sm2 {
|
|
s.signature = signature
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) SetSdk(sdk sdk.SDK) *Sm2 {
|
|
s.sdk = sdk
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) SetC3Len(c3Len int) *Sm2 {
|
|
s.c3Len = c3Len
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) SetUid(uid []byte) *Sm2 {
|
|
s.uid = uid
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) SetBase64StringData(base64String string) *Sm2 {
|
|
data, err := base64.StdEncoding.DecodeString(base64String)
|
|
if err != nil {
|
|
return s.SetError(fmt.Errorf("data is not base64 string: %s", err.Error()))
|
|
}
|
|
s.data = data
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) SetStringData(str string) *Sm2 {
|
|
s.data = []byte(str)
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) SetHexData(hexStr string) *Sm2 {
|
|
data, err := hex.DecodeString(hexStr)
|
|
if err != nil {
|
|
return s.SetError(fmt.Errorf("data is not hex string: %s", err.Error()))
|
|
}
|
|
s.data = data
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) SetData(data []byte) *Sm2 {
|
|
s.data = data
|
|
return s
|
|
}
|
|
|
|
func (s *Sm2) ToBytes() ([]byte, error) {
|
|
if err := s.errHandle(); err != nil {
|
|
return nil, err
|
|
}
|
|
return s.toData, nil
|
|
}
|
|
|
|
func (s *Sm2) ToString() (string, error) {
|
|
if err := s.errHandle(); err != nil {
|
|
return "", err
|
|
}
|
|
return string(s.toData), nil
|
|
}
|
|
|
|
func (s *Sm2) ToBase64String() (string, error) {
|
|
if err := s.errHandle(); err != nil {
|
|
return "", err
|
|
}
|
|
return base64.StdEncoding.EncodeToString(s.toData), nil
|
|
}
|
|
|
|
func (s *Sm2) ToHexString() (string, error) {
|
|
if err := s.errHandle(); err != nil {
|
|
return "", err
|
|
}
|
|
return hex.EncodeToString(s.toData), nil
|
|
}
|
|
|
|
func (s *Sm2) ToBool() (bool, error) {
|
|
if err := s.errHandle(); err != nil {
|
|
return false, err
|
|
}
|
|
if bytes.Equal(s.toData, model.VerifyTrue) {
|
|
return true, nil
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func (s *Sm2) ToSignature() (*model.Signature, error) {
|
|
if err := s.errHandle(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &s.signature, nil
|
|
}
|