voucher/internal/pkg/cmb/sm2/sm2.go

407 lines
9.1 KiB
Go

package sm2
import (
"bytes"
"crypto/elliptic"
"crypto/rand"
"encoding/asn1"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"math/big"
"voucher/internal/pkg/cmb/sm2/model"
"voucher/internal/pkg/cmb/sm2/sdk"
"voucher/internal/pkg/cmb/sm2/util"
)
// Sm2 is a chainable builder for SM2 operations: configure it with the Set*
// methods, run Encrypt, Decrypt, Sign or Verify, then read the outcome with
// one of the To* methods.
type Sm2 struct {
// publicKey is the key used by Encrypt and Verify.
publicKey *model.PublicKey
// privateKey is the key used by Decrypt and Sign.
privateKey *model.PrivateKey
// signature is the value checked by Verify and overwritten by Sign.
signature model.Signature
// sdk supplies the primitives called from this file (Kdf, CalculateHash, GetZ, GetE).
sdk sdk.SDK
// cipherType selects the ciphertext segment order (C1C2C3 or C1C3C2).
cipherType model.CipherType
// c3Len is the byte length of the C3 segment that Decrypt expects; NewSm2 sets it to 32.
c3Len int
// uid is the identifier passed to sdk.GetZ when signing and verifying.
uid []byte
// data is the input of the next operation.
data []byte
// err collects errors raised along the chain.
// NOTE(review): presumably appended to by SetError and reported by errHandle,
// neither of which is defined in this part of the file — confirm.
err []error
// toData is the output of the last operation, as returned by the To* methods.
toData []byte
}
// NewSm2 returns an Sm2 preloaded with the package defaults: the base SDK, a
// 32-byte C3 segment, model.DefaultUid as the user id, the C1C2C3 ciphertext
// layout and an empty error list.
func NewSm2() *Sm2 {
	s := new(Sm2)
	s.sdk = sdk.NewBaseSdk()
	s.c3Len = 32
	s.uid = model.DefaultUid
	s.cipherType = model.C1C2C3
	s.err = make([]error, 0)
	return s
}
// Encrypt encrypts the configured data for the configured public key and
// stores the ciphertext as the chain result. The ciphertext is the marshalled
// ephemeral point C1 followed by the masked data C2 and the hash C3, in the
// order selected by the cipher type (C1C2C3 or C1C3C2).
//
// Any failure (random source, KDF, unsupported cipher type) is recorded with
// SetError and surfaces from the To* methods. It returns s for chaining.
func (s *Sm2) Encrypt() *Sm2 {
	if err := s.encryptErrHandler(); err != nil {
		return s
	}
	// Reject an unsupported layout before doing any curve arithmetic, so no
	// partially built result is ever stored.
	if s.cipherType != model.C1C2C3 && s.cipherType != model.C1C3C2 {
		return s.SetError(fmt.Errorf("cipher type not support"))
	}

	c2 := make([]byte, len(s.data))
	copy(c2, s.data)
	var c1 []byte
	var kx, ky *big.Int
	for {
		k, err := rand.Int(rand.Reader, s.publicKey.Params().N)
		if err != nil {
			return s.SetError(fmt.Errorf("rand error: %w", err))
		}
		// rand.Int draws from [0, N); k = 0 would yield the point at infinity.
		if k.Sign() == 0 {
			continue
		}
		c1x, c1y := s.publicKey.Curve.ScalarBaseMult(k.Bytes())
		c1 = elliptic.Marshal(s.publicKey.Curve, c1x, c1y)
		kx, ky = s.publicKey.Curve.ScalarMult(s.publicKey.X, s.publicKey.Y, k.Bytes())
		// Kdf is expected to mask c2 in place with the derived key stream.
		if err = s.sdk.Kdf(s.publicKey, kx, ky, c2); err != nil {
			return s.SetError(fmt.Errorf("kdf error: %w", err))
		}
		// Retry only when the mask left the data unchanged. An empty plaintext
		// can never differ from itself, so it must not trigger a retry;
		// otherwise this loop would never terminate.
		if len(c2) == 0 || s.encrypted(c2, s.data) {
			break
		}
	}
	c3 := s.sdk.CalculateHash(kx, s.data, ky)

	out := make([]byte, 0, len(c1)+len(c2)+len(c3))
	out = append(out, c1...)
	if s.cipherType == model.C1C2C3 {
		out = append(out, c2...)
		out = append(out, c3...)
	} else {
		out = append(out, c3...)
		out = append(out, c2...)
	}
	s.toData = out
	return s
}
// Decrypt decrypts the configured ciphertext with the configured private key
// and stores the recovered plaintext as the chain result.
//
// The input must be the uncompressed point C1 followed by C2 and C3 in the
// order selected by the cipher type, with C3 being c3Len bytes long. A
// truncated ciphertext, a C1 that is not a valid curve point, a KDF failure
// or a C3 mismatch is recorded with SetError instead of panicking.
// It returns s for chaining.
func (s *Sm2) Decrypt() *Sm2 {
	if err := s.decryptErrHandler(); err != nil {
		return s
	}
	if s.cipherType != model.C1C2C3 && s.cipherType != model.C1C3C2 {
		return s.SetError(fmt.Errorf("cipher type not support"))
	}

	curve := s.privateKey.Curve
	// Uncompressed point: one format byte plus two field elements. This is 65
	// bytes for a 256-bit curve and is the only length elliptic.Unmarshal accepts.
	c1Len := 1 + 2*((curve.Params().BitSize+7)/8)
	if s.c3Len < 0 || len(s.data) < c1Len+s.c3Len {
		return s.SetError(fmt.Errorf("decrypt error: ciphertext too short"))
	}

	// Unmarshal returns nil coordinates for anything that is not a valid
	// point on the curve; feeding those to ScalarMult would panic.
	x, y := elliptic.Unmarshal(curve, s.data[:c1Len])
	if x == nil || y == nil {
		return s.SetError(fmt.Errorf("decrypt error: invalid C1 point"))
	}
	dBC1X, dBC1Y := curve.ScalarMult(x, y, s.privateKey.D.Bytes())

	payload := s.data[c1Len:]
	c2Len := len(payload) - s.c3Len
	c2 := make([]byte, c2Len)
	c3 := make([]byte, s.c3Len)
	if s.cipherType == model.C1C2C3 {
		copy(c2, payload[:c2Len])
		copy(c3, payload[c2Len:])
	} else {
		copy(c3, payload[:s.c3Len])
		copy(c2, payload[s.c3Len:])
	}

	// Kdf is expected to unmask c2 in place, turning it into the plaintext.
	if err := s.sdk.Kdf(curve, dBC1X, dBC1Y, c2); err != nil {
		return s.SetError(fmt.Errorf("kdf error: %w", err))
	}
	// The recomputed hash must match the transmitted C3.
	u := s.sdk.CalculateHash(dBC1X, c2, dBC1Y)
	if !bytes.Equal(u, c3) {
		return s.SetError(fmt.Errorf("decrypt error"))
	}
	s.toData = c2
	return s
}
// Verify checks the configured signature over the configured data with the
// configured public key and user id. The outcome is stored as
// model.VerifyTrue or model.VerifyFalse and is normally read with ToBool.
//
// A signature whose R or S is missing (never set, or a failed Set*Signature
// call) is reported through SetError rather than dereferencing nil.
// It returns s for chaining.
func (s *Sm2) Verify() *Sm2 {
	if err := s.encryptErrHandler(); err != nil {
		return s
	}
	sigR, sigS := s.signature.R, s.signature.S
	if sigR == nil || sigS == nil {
		s.toData = model.VerifyFalse
		return s.SetError(fmt.Errorf("signature is not set"))
	}

	c := s.publicKey.Curve
	N := c.Params().N
	// Both components must lie in [1, N-1].
	if sigR.Cmp(model.One) < 0 || sigS.Cmp(model.One) < 0 {
		s.toData = model.VerifyFalse
		return s
	}
	if sigR.Cmp(N) >= 0 || sigS.Cmp(N) >= 0 {
		s.toData = model.VerifyFalse
		return s
	}

	z := s.sdk.GetZ(s.publicKey.X, s.publicKey.Y, s.uid)
	e := s.sdk.GetE(z, s.data)

	// t = (R + S) mod N; a zero t invalidates the signature.
	t := new(big.Int).Add(sigR, sigS)
	t.Mod(t, N)
	if t.Sign() == 0 {
		s.toData = model.VerifyFalse
		return s
	}

	// (x, _) = S*G + t*P; the signature is valid when (e + x) mod N == R.
	x1, y1 := c.ScalarBaseMult(sigS.Bytes())
	x2, y2 := c.ScalarMult(s.publicKey.X, s.publicKey.Y, t.Bytes())
	x, _ := c.Add(x1, y1, x2, y2)
	x.Add(x, e)
	x.Mod(x, N)
	if x.Cmp(sigR) == 0 {
		s.toData = model.VerifyTrue
	} else {
		s.toData = model.VerifyFalse
	}
	return s
}
// Sign signs the configured data with the configured private key and user id.
// The resulting (R, S) pair is kept on the receiver (see ToSignature) and its
// ASN.1 encoding is stored as the chain result.
//
// Failures (zero curve order, a private key for which d+1 has no inverse
// modulo N, random source, ASN.1 encoding) are recorded with SetError.
// It returns s for chaining.
func (s *Sm2) Sign() *Sm2 {
	if err := s.signErrHandler(); err != nil {
		return s
	}
	z := s.sdk.GetZ(s.privateKey.PublicKey.X, s.privateKey.PublicKey.Y, s.uid)
	e := s.sdk.GetE(z, s.data)
	c := s.privateKey.PublicKey.Curve
	N := c.Params().N
	if N.Sign() == 0 {
		return s.SetError(fmt.Errorf("invalid curve order"))
	}

	// (1 + d)^-1 mod N does not depend on the nonce, so compute it once.
	// ModInverse returns nil when no inverse exists (e.g. d = N-1); using
	// that result in Mul would panic.
	d1 := new(big.Int).Add(s.privateKey.D, model.One)
	d1Inv := new(big.Int).ModInverse(d1, N)
	if d1Inv == nil {
		return s.SetError(fmt.Errorf("invalid private key: d+1 has no inverse modulo the curve order"))
	}

	var k, r, sb *big.Int
	var err error
	for {
		// Draw nonces until r = (e + x1) mod N is non-zero and r + k != N.
		for {
			k, err = s.randFieldElement(c, rand.Reader)
			if err != nil {
				return s.SetError(fmt.Errorf("randFieldElement: %w", err))
			}
			r, _ = s.privateKey.Curve.ScalarBaseMult(k.Bytes())
			r.Add(r, e)
			r.Mod(r, N)
			if r.Sign() != 0 {
				if t := new(big.Int).Add(r, k); t.Cmp(N) != 0 {
					break
				}
			}
		}
		// S = (1 + d)^-1 * (k - r*d) mod N; retry with a new nonce if it is zero.
		rD := new(big.Int).Mul(s.privateKey.D, r)
		sb = new(big.Int).Sub(k, rD)
		sb.Mul(sb, d1Inv)
		sb.Mod(sb, N)
		if sb.Sign() != 0 {
			break
		}
	}

	s.signature = model.Signature{R: r, S: sb}
	si, err := asn1.Marshal(s.signature)
	if err != nil {
		return s.SetError(fmt.Errorf("asn1.Marshal: %w", err))
	}
	s.toData = si
	return s
}
// randFieldElement reads BitSize/8+8 bytes from random and reduces them to a
// value k with 1 <= k <= N-1, where N is the order of curve c. The eight
// extra bytes are read beyond the curve size before the reduction.
// It returns a nil k together with the read error if random fails.
func (s *Sm2) randFieldElement(c elliptic.Curve, random io.Reader) (k *big.Int, err error) {
	curveParams := c.Params()
	buf := make([]byte, curveParams.BitSize/8+8)
	if _, err = io.ReadFull(random, buf); err != nil {
		return nil, err
	}

	// k = (raw mod (N-1)) + 1 maps the raw bytes into [1, N-1].
	nMinusOne := new(big.Int).Sub(curveParams.N, model.One)
	k = new(big.Int).SetBytes(buf)
	k.Mod(k, nMinusOne)
	k.Add(k, model.One)
	return k, nil
}
// encrypted reports whether encData differs from in at any index within
// len(encData). It returns false when every compared byte matches, which
// includes the case where encData is empty. in must be at least as long as
// encData.
func (s *Sm2) encrypted(encData []byte, in []byte) bool {
	for idx, b := range encData {
		if b != in[idx] {
			return true
		}
	}
	return false
}
// SetCipherType selects the ciphertext segment order that Encrypt produces
// and Decrypt expects. It returns s for chaining.
func (s *Sm2) SetCipherType(cipherType model.CipherType) *Sm2 {
	s.cipherType = cipherType
	return s
}
// SetHexPublicKey hex-decodes hexStr and parses the bytes into the public key
// used by Encrypt and Verify. A decoding or parsing failure is recorded with
// SetError. It returns s for chaining.
func (s *Sm2) SetHexPublicKey(hexStr string) *Sm2 {
	raw, err := hex.DecodeString(hexStr)
	if err != nil {
		return s.SetError(fmt.Errorf("publicKey is not hex string: %s", err.Error()))
	}

	key, err := util.HexToPublicKey(raw)
	// Stored before the error check: a failed parse replaces any earlier key
	// with whatever the parser returned.
	s.publicKey = key
	if err != nil {
		return s.SetError(fmt.Errorf("parse publicKey err: %+v", err))
	}
	return s
}
// SetHexPrivateKey hex-decodes hexStr and parses the bytes into the private
// key used by Decrypt and Sign. A decoding or parsing failure is recorded
// with SetError. It returns s for chaining.
func (s *Sm2) SetHexPrivateKey(hexStr string) *Sm2 {
	raw, err := hex.DecodeString(hexStr)
	if err != nil {
		return s.SetError(fmt.Errorf("privateKey is not hex string: %s", err.Error()))
	}

	key, err := util.HexToPrivateKey(raw)
	// Stored before the error check: a failed parse replaces any earlier key
	// with whatever the parser returned.
	s.privateKey = key
	if err != nil {
		return s.SetError(fmt.Errorf("parse privateKey err: %+v", err))
	}
	return s
}
// SetPublicKey installs an already parsed public key for Encrypt and Verify.
// It returns s for chaining.
func (s *Sm2) SetPublicKey(publicKey *model.PublicKey) *Sm2 {
	s.publicKey = publicKey
	return s
}
// SetPrivateKey installs an already parsed private key for Decrypt and Sign.
// It returns s for chaining.
func (s *Sm2) SetPrivateKey(privateKey *model.PrivateKey) *Sm2 {
	s.privateKey = privateKey
	return s
}
// SetHexSignature hex-decodes hexStr, parses the bytes as an ASN.1 encoded
// signature and stores it for Verify.
//
// The encoding must be consumed completely: bytes left over after the
// signature are rejected. The stored signature is replaced only when parsing
// succeeds, so a bad input never leaves a partially filled value behind.
// Failures are recorded with SetError. It returns s for chaining.
func (s *Sm2) SetHexSignature(hexStr string) *Sm2 {
	der, err := hex.DecodeString(hexStr)
	if err != nil {
		return s.SetError(fmt.Errorf("signature is not hex string: %s", err.Error()))
	}

	var sig model.Signature
	rest, err := asn1.Unmarshal(der, &sig)
	if err != nil {
		return s.SetError(fmt.Errorf("signature is not asn1: %s", err.Error()))
	}
	if len(rest) != 0 {
		return s.SetError(fmt.Errorf("signature is not asn1: %d trailing bytes after signature", len(rest)))
	}
	s.signature = sig
	return s
}
// SetBase64Signature decodes base64Str with the standard base64 alphabet,
// parses the bytes as an ASN.1 encoded signature and stores it for Verify.
//
// The encoding must be consumed completely: bytes left over after the
// signature are rejected. The stored signature is replaced only when parsing
// succeeds, so a bad input never leaves a partially filled value behind.
// Failures are recorded with SetError. It returns s for chaining.
func (s *Sm2) SetBase64Signature(base64Str string) *Sm2 {
	der, err := base64.StdEncoding.DecodeString(base64Str)
	if err != nil {
		return s.SetError(fmt.Errorf("signature is not base64 string: %s", err.Error()))
	}

	var sig model.Signature
	rest, err := asn1.Unmarshal(der, &sig)
	if err != nil {
		return s.SetError(fmt.Errorf("signature is not asn1: %s", err.Error()))
	}
	if len(rest) != 0 {
		return s.SetError(fmt.Errorf("signature is not asn1: %d trailing bytes after signature", len(rest)))
	}
	s.signature = sig
	return s
}
// SetHexSignatureData is an alias of SetHexSignature: it hex-decodes hexStr,
// parses it as an ASN.1 encoded signature and stores it for Verify.
//
// It delegates instead of repeating the parsing code so the two entry points
// cannot drift apart. It returns s for chaining.
func (s *Sm2) SetHexSignatureData(hexStr string) *Sm2 {
	return s.SetHexSignature(hexStr)
}
// SetSignature installs an already parsed signature for Verify.
// It returns s for chaining.
func (s *Sm2) SetSignature(signature model.Signature) *Sm2 {
	s.signature = signature
	return s
}
// SetSdk replaces the SDK implementation that provides Kdf, CalculateHash,
// GetZ and GetE to the operations of this type. It returns s for chaining.
func (s *Sm2) SetSdk(sdk sdk.SDK) *Sm2 {
	s.sdk = sdk
	return s
}
// SetC3Len sets the byte length of the C3 segment that Decrypt expects in the
// ciphertext. It returns s for chaining.
func (s *Sm2) SetC3Len(c3Len int) *Sm2 {
	s.c3Len = c3Len
	return s
}
// SetUid sets the user identifier that Sign and Verify pass to the SDK's
// GetZ. The slice is stored as given, without copying.
// It returns s for chaining.
func (s *Sm2) SetUid(uid []byte) *Sm2 {
	s.uid = uid
	return s
}
// SetBase64StringData decodes base64String with the standard base64 alphabet
// and uses the bytes as the input of the next operation. A decoding failure
// is recorded with SetError and leaves the current input untouched.
// It returns s for chaining.
func (s *Sm2) SetBase64StringData(base64String string) *Sm2 {
	decoded, err := base64.StdEncoding.DecodeString(base64String)
	if err != nil {
		wrapped := fmt.Errorf("data is not base64 string: %s", err.Error())
		return s.SetError(wrapped)
	}

	s.data = decoded
	return s
}
// SetStringData uses the bytes of str as the input of the next operation.
// It returns s for chaining.
func (s *Sm2) SetStringData(str string) *Sm2 {
	s.data = []byte(str)
	return s
}
// SetHexData hex-decodes hexStr and uses the bytes as the input of the next
// operation. A decoding failure is recorded with SetError and leaves the
// current input untouched. It returns s for chaining.
func (s *Sm2) SetHexData(hexStr string) *Sm2 {
	decoded, err := hex.DecodeString(hexStr)
	if err != nil {
		wrapped := fmt.Errorf("data is not hex string: %s", err.Error())
		return s.SetError(wrapped)
	}

	s.data = decoded
	return s
}
// SetData uses data as the input of the next operation. The slice is stored
// as given, without copying. It returns s for chaining.
func (s *Sm2) SetData(data []byte) *Sm2 {
	s.data = data
	return s
}
// ToBytes returns the raw result of the last operation. If errHandle reports
// an error, that error is returned together with a nil slice. The returned
// slice is the receiver's own buffer, not a copy.
func (s *Sm2) ToBytes() ([]byte, error) {
	err := s.errHandle()
	if err != nil {
		return nil, err
	}
	return s.toData, nil
}
// ToString returns the result of the last operation converted directly to a
// string. If errHandle reports an error, that error is returned together
// with an empty string.
func (s *Sm2) ToString() (string, error) {
	err := s.errHandle()
	if err != nil {
		return "", err
	}
	text := string(s.toData)
	return text, nil
}
// ToBase64String returns the result of the last operation encoded with the
// standard base64 alphabet. If errHandle reports an error, that error is
// returned together with an empty string.
func (s *Sm2) ToBase64String() (string, error) {
	err := s.errHandle()
	if err != nil {
		return "", err
	}
	encoded := base64.StdEncoding.EncodeToString(s.toData)
	return encoded, nil
}
// ToHexString returns the result of the last operation as a hexadecimal
// string. If errHandle reports an error, that error is returned together
// with an empty string.
func (s *Sm2) ToHexString() (string, error) {
	err := s.errHandle()
	if err != nil {
		return "", err
	}
	encoded := hex.EncodeToString(s.toData)
	return encoded, nil
}
// ToBool reports whether the stored result equals model.VerifyTrue, the value
// Verify stores for a valid signature. If errHandle reports an error, that
// error is returned together with false.
func (s *Sm2) ToBool() (bool, error) {
	err := s.errHandle()
	if err != nil {
		return false, err
	}
	return bytes.Equal(s.toData, model.VerifyTrue), nil
}
// ToSignature returns a pointer to the signature held by the receiver, i.e.
// the one produced by Sign or installed by a Set*Signature method. The
// pointer refers to the receiver's own field, not a copy. If errHandle
// reports an error, that error is returned together with nil.
func (s *Sm2) ToSignature() (*model.Signature, error) {
	err := s.errHandle()
	if err != nil {
		return nil, err
	}
	sig := &s.signature
	return sig, nil
}