package sm2 // reference to ecdsa import ( "crypto" "crypto/elliptic" "crypto/rand" "encoding/asn1" "errors" "github.com/tjfoc/gmsm/sm3" "io" "math/big" ) var ( default_uid = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38} ) type PublicKey struct { elliptic.Curve X, Y *big.Int } type PrivateKey struct { PublicKey D *big.Int } type sm2Signature struct { R, S *big.Int } func (priv *PrivateKey) Public() crypto.PublicKey { return &priv.PublicKey } var errZeroParam = errors.New("zero parameter") var one = new(big.Int).SetInt64(1) var two = new(big.Int).SetInt64(2) func (priv *PrivateKey) Sign(random io.Reader, msg []byte, signer crypto.SignerOpts) ([]byte, error) { r, s, err := Sm2Sign(priv, msg, nil, random) if err != nil { return nil, err } return asn1.Marshal(sm2Signature{r, s}) } func Sm2Sign(priv *PrivateKey, msg, uid []byte, random io.Reader) (r, s *big.Int, err error) { digest, err := priv.PublicKey.Sm3Digest(msg, uid) if err != nil { return nil, nil, err } e := new(big.Int).SetBytes(digest) c := priv.PublicKey.Curve N := c.Params().N if N.Sign() == 0 { return nil, nil, errZeroParam } var k *big.Int for { // 调整算法细节以实现SM2 for { k, err = randFieldElement(c, random) if err != nil { r = nil return } r, _ = priv.Curve.ScalarBaseMult(k.Bytes()) r.Add(r, e) r.Mod(r, N) if r.Sign() != 0 { if t := new(big.Int).Add(r, k); t.Cmp(N) != 0 { break } } } rD := new(big.Int).Mul(priv.D, r) s = new(big.Int).Sub(k, rD) d1 := new(big.Int).Add(priv.D, one) d1Inv := new(big.Int).ModInverse(d1, N) s.Mul(s, d1Inv) s.Mod(s, N) if s.Sign() != 0 { break } } return } func (pub *PublicKey) Sm3Digest(msg, uid []byte) ([]byte, error) { if len(uid) == 0 { uid = default_uid } za, err := getZ(pub, uid) if err != nil { return nil, err } e, err := msgHash(za, msg) if err != nil { return nil, err } return e.Bytes(), nil } func Sm2Verify(pub *PublicKey, msg, uid []byte, r, s *big.Int) bool { c := pub.Curve N := c.Params().N one := new(big.Int).SetInt64(1) if r.Cmp(one) < 0 || s.Cmp(one) < 0 { return false } if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 { return false } if len(uid) == 0 { uid = default_uid } za, err := getZ(pub, uid) if err != nil { return false } e, err := msgHash(za, msg) if err != nil { return false } t := new(big.Int).Add(r, s) t.Mod(t, N) if t.Sign() == 0 { return false } var x *big.Int x1, y1 := c.ScalarBaseMult(s.Bytes()) x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes()) x, _ = c.Add(x1, y1, x2, y2) x.Add(x, e) x.Mod(x, N) return x.Cmp(r) == 0 } func msgHash(za, msg []byte) (*big.Int, error) { e := sm3.New() e.Write(za) e.Write(msg) return new(big.Int).SetBytes(e.Sum(nil)[:32]), nil } func bigIntToByte(n *big.Int) []byte { byteArray := n.Bytes() // If the most significant byte's most significant bit is set, // prepend a 0 byte to the slice to avoid being interpreted as a negative number. if (byteArray[0] & 0x80) != 0 { byteArray = append([]byte{0}, byteArray...) } return byteArray } func getZ(pub *PublicKey, uid []byte) ([]byte, error) { z := sm3.New() uidLen := len(uid) * 8 entla := []byte{byte(uidLen >> 8), byte(uidLen & 255)} z.Write(entla) z.Write(uid) // a 先写死,原来的没有暴露 z.Write(bigIntToByte(sm2P256ToBig(&sm2P256.a))) z.Write(bigIntToByte(sm2P256.B)) z.Write(bigIntToByte(sm2P256.Gx)) z.Write(bigIntToByte(sm2P256.Gy)) z.Write(bigIntToByte(pub.X)) z.Write(bigIntToByte(pub.Y)) return z.Sum(nil), nil } func randFieldElement(c elliptic.Curve, random io.Reader) (k *big.Int, err error) { if random == nil { random = rand.Reader //If there is no external trusted random source,please use rand.Reader to instead of it. } params := c.Params() b := make([]byte, params.BitSize/8+8) _, err = io.ReadFull(random, b) if err != nil { return } k = new(big.Int).SetBytes(b) n := new(big.Int).Sub(params.N, one) k.Mod(k, n) k.Add(k, one) return } func GenerateKey(random io.Reader) (*PrivateKey, error) { c := P256Sm2() if random == nil { random = rand.Reader //If there is no external trusted random source,please use rand.Reader to instead of it. } params := c.Params() b := make([]byte, params.BitSize/8+8) _, err := io.ReadFull(random, b) if err != nil { return nil, err } k := new(big.Int).SetBytes(b) n := new(big.Int).Sub(params.N, two) k.Mod(k, n) k.Add(k, one) priv := new(PrivateKey) priv.PublicKey.Curve = c priv.D = k priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes()) return priv, nil }