220 lines
4.6 KiB
Go
220 lines
4.6 KiB
Go
package sm2
|
|
|
|
// This implementation is modeled on the standard library's crypto/ecdsa package.
|
|
import (
|
|
"crypto"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"encoding/asn1"
|
|
"errors"
|
|
"github.com/tjfoc/gmsm/sm3"
|
|
"io"
|
|
"math/big"
|
|
)
|
|
|
|
// default_uid is the user ID used when the caller supplies none: the ASCII
// string "1234567812345678" (bytes 0x31..0x38, repeated twice).
var default_uid = []byte("1234567812345678")
|
|
|
|
type PublicKey struct {
|
|
elliptic.Curve
|
|
X, Y *big.Int
|
|
}
|
|
|
|
type PrivateKey struct {
|
|
PublicKey
|
|
D *big.Int
|
|
}
|
|
|
|
type sm2Signature struct {
|
|
R, S *big.Int
|
|
}
|
|
|
|
// Public returns the public half of priv as a crypto.PublicKey, as required
// by the crypto.Signer interface. The result is a *PublicKey that points into
// priv rather than a copy of it.
func (priv *PrivateKey) Public() crypto.PublicKey {
	pub := &priv.PublicKey
	return pub
}
|
|
|
|
var errZeroParam = errors.New("zero parameter")
|
|
var one = new(big.Int).SetInt64(1)
|
|
var two = new(big.Int).SetInt64(2)
|
|
|
|
// Sign implements crypto.Signer. Unlike most Signer implementations, msg is
// the raw message rather than a precomputed digest: Sm2Sign hashes it with
// SM3 together with the default user ID (uid is passed as nil). The opts
// argument is ignored. On success the result is the ASN.1 DER encoding of
// the (r, s) pair.
func (priv *PrivateKey) Sign(random io.Reader, msg []byte, signer crypto.SignerOpts) ([]byte, error) {
	r, s, err := Sm2Sign(priv, msg, nil, random)
	if err == nil {
		return asn1.Marshal(sm2Signature{R: r, S: s})
	}
	return nil, err
}
|
|
|
|
// Sm2Sign signs msg with priv using the SM2 signature algorithm.
//
// The message is first bound to the signer's identity: the digest
// e = SM3(Z_A || msg) is computed by priv.PublicKey.Sm3Digest, where uid
// (default_uid when empty) feeds into Z_A. The per-signature nonce k is
// drawn from random via randFieldElement (crypto/rand.Reader when random is
// nil).
//
// On success both r and s lie in [1, N-1]. An error is returned if the
// digest cannot be computed, the curve order is zero, the private key is
// unusable (nil D, or D+1 not invertible mod N), or reading randomness fails.
// On error both r and s are nil.
func Sm2Sign(priv *PrivateKey, msg, uid []byte, random io.Reader) (r, s *big.Int, err error) {
	digest, err := priv.PublicKey.Sm3Digest(msg, uid)
	if err != nil {
		return nil, nil, err
	}
	e := new(big.Int).SetBytes(digest)

	c := priv.PublicKey.Curve
	N := c.Params().N
	if N.Sign() == 0 {
		return nil, nil, errZeroParam
	}

	// (1 + d)^-1 mod N does not depend on the nonce, so compute it once.
	// ModInverse returns nil when no inverse exists (d ≡ -1 mod N for a
	// prime N). Without this check, s.Mul below would panic on a nil operand.
	if priv.D == nil {
		return nil, nil, errors.New("sm2: invalid private key")
	}
	d1Inv := new(big.Int).ModInverse(new(big.Int).Add(priv.D, one), N)
	if d1Inv == nil {
		return nil, nil, errors.New("sm2: invalid private key")
	}

	var k *big.Int
	for {
		if k, err = randFieldElement(c, random); err != nil {
			return nil, nil, err
		}

		// r = (e + x1) mod N, where (x1, y1) = [k]G. Retry with a fresh nonce
		// when r == 0 or r + k == N.
		x1, _ := c.ScalarBaseMult(k.Bytes())
		r = new(big.Int).Add(x1, e)
		r.Mod(r, N)
		if r.Sign() == 0 || new(big.Int).Add(r, k).Cmp(N) == 0 {
			continue
		}

		// s = (1 + d)^-1 * (k - r*d) mod N. Retry with a fresh nonce when s == 0.
		s = new(big.Int).Mul(priv.D, r)
		s.Sub(k, s)
		s.Mul(s, d1Inv)
		s.Mod(s, N)
		if s.Sign() != 0 {
			return r, s, nil
		}
	}
}
|
|
|
|
func (pub *PublicKey) Sm3Digest(msg, uid []byte) ([]byte, error) {
|
|
if len(uid) == 0 {
|
|
uid = default_uid
|
|
}
|
|
|
|
za, err := getZ(pub, uid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
e, err := msgHash(za, msg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return e.Bytes(), nil
|
|
}
|
|
|
|
func Sm2Verify(pub *PublicKey, msg, uid []byte, r, s *big.Int) bool {
|
|
c := pub.Curve
|
|
N := c.Params().N
|
|
one := new(big.Int).SetInt64(1)
|
|
if r.Cmp(one) < 0 || s.Cmp(one) < 0 {
|
|
return false
|
|
}
|
|
if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
|
|
return false
|
|
}
|
|
if len(uid) == 0 {
|
|
uid = default_uid
|
|
}
|
|
za, err := getZ(pub, uid)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
e, err := msgHash(za, msg)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
t := new(big.Int).Add(r, s)
|
|
t.Mod(t, N)
|
|
if t.Sign() == 0 {
|
|
return false
|
|
}
|
|
var x *big.Int
|
|
x1, y1 := c.ScalarBaseMult(s.Bytes())
|
|
x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
|
|
x, _ = c.Add(x1, y1, x2, y2)
|
|
|
|
x.Add(x, e)
|
|
x.Mod(x, N)
|
|
return x.Cmp(r) == 0
|
|
}
|
|
|
|
// msgHash feeds za followed by msg into SM3 and returns the first 32 bytes of
// the hash as a non-negative big integer. The error result is always nil; it
// is kept so callers can treat this like other fallible steps.
func msgHash(za, msg []byte) (*big.Int, error) {
	h := sm3.New()
	for _, part := range [][]byte{za, msg} {
		h.Write(part)
	}
	sum := h.Sum(nil)
	return new(big.Int).SetBytes(sum[:32]), nil
}
|
|
|
|
// bigIntToByte returns the big-endian magnitude of n. A 0x00 byte is
// prefixed when the first byte's top bit is set, so the result reads as a
// non-negative two's-complement integer. Zero encodes as a single 0x00 byte.
// The sign of a negative n is not represented (Bytes returns |n|).
func bigIntToByte(n *big.Int) []byte {
	byteArray := n.Bytes()
	// n.Bytes() is empty for zero; indexing byteArray[0] would panic.
	if len(byteArray) == 0 {
		return []byte{0}
	}
	if byteArray[0]&0x80 != 0 {
		byteArray = append([]byte{0}, byteArray...)
	}
	return byteArray
}
|
|
|
|
func getZ(pub *PublicKey, uid []byte) ([]byte, error) {
|
|
z := sm3.New()
|
|
uidLen := len(uid) * 8
|
|
entla := []byte{byte(uidLen >> 8), byte(uidLen & 255)}
|
|
z.Write(entla)
|
|
z.Write(uid)
|
|
|
|
// a 先写死,原来的没有暴露
|
|
z.Write(bigIntToByte(sm2P256ToBig(&sm2P256.a)))
|
|
z.Write(bigIntToByte(sm2P256.B))
|
|
z.Write(bigIntToByte(sm2P256.Gx))
|
|
z.Write(bigIntToByte(sm2P256.Gy))
|
|
|
|
z.Write(bigIntToByte(pub.X))
|
|
z.Write(bigIntToByte(pub.Y))
|
|
return z.Sum(nil), nil
|
|
}
|
|
|
|
// randFieldElement returns a random scalar k in [1, N-1] for curve c. It reads
// BitSize/8+8 bytes from random, or from crypto/rand.Reader when random is
// nil. The extra 64 bits keep the bias from the modular reduction negligible.
// If the read fails, k is nil and err is the read error.
func randFieldElement(c elliptic.Curve, random io.Reader) (k *big.Int, err error) {
	src := random
	if src == nil {
		// No external trusted random source: fall back to crypto/rand.Reader.
		src = rand.Reader
	}

	params := c.Params()
	buf := make([]byte, params.BitSize/8+8)
	if _, err = io.ReadFull(src, buf); err != nil {
		return nil, err
	}

	// k = (buf mod (N-1)) + 1, which lies in [1, N-1].
	nMinusOne := new(big.Int).Sub(params.N, one)
	k = new(big.Int).SetBytes(buf)
	k.Mod(k, nMinusOne).Add(k, one)
	return k, nil
}
|
|
|
|
// GenerateKey creates a new SM2 key pair on the curve returned by P256Sm2.
// The private scalar D lies in [1, N-2]. It is obtained by reducing
// BitSize/8+8 random bytes modulo N-2 and adding one. The public point is
// D·G. If random is nil, crypto/rand.Reader is used. Only a failed read of
// random produces an error.
func GenerateKey(random io.Reader) (*PrivateKey, error) {
	curve := P256Sm2()
	if random == nil {
		random = rand.Reader // no external trusted random source: use crypto/rand
	}
	params := curve.Params()

	seed := make([]byte, params.BitSize/8+8)
	if _, err := io.ReadFull(random, seed); err != nil {
		return nil, err
	}

	// d = (seed mod (N-2)) + 1, which lies in [1, N-2].
	d := new(big.Int).SetBytes(seed)
	d.Mod(d, new(big.Int).Sub(params.N, two))
	d.Add(d, one)

	x, y := curve.ScalarBaseMult(d.Bytes())
	return &PrivateKey{
		PublicKey: PublicKey{Curve: curve, X: x, Y: y},
		D:         d,
	}, nil
}
|