package util

import (
	"bytes"
	"crypto/elliptic"
	"crypto/rand"
	"errors"
	"gitea.cdlsxd.cn/self-tools/l_crypt/encrypt_way/sm4/internal/sm2"
	zzsm2 "github.com/ZZMarquis/gm/sm2"
	"github.com/tjfoc/gmsm/sm3"
	"github.com/tjfoc/gmsm/x509"
	"math/big"
)

// Sm2Decrypt decrypts a ciphertext laid out as C1 || C2 || C3, the format
// produced by Sm2Encrypt:
//   - C1: a 65-byte uncompressed point (elliptic.Marshal encoding);
//   - C2: the masked message, as long as the plaintext;
//   - C3: a 32-byte SM3 check value.
//
// It recomputes the shared point d·C1, derives the keystream t with kdf,
// recovers M = C2 XOR t, and accepts M only if calculateHash(x2, M, y2)
// equals C3.
//
// It returns an error if encryptData is shorter than 97 bytes, if C1 is not
// a valid point on the key's curve, if key derivation fails, or if the check
// value does not match.
func Sm2Decrypt(privateKey *sm2.PrivateKey, encryptData []byte) ([]byte, error) {
	const (
		c1Len = 65 // 0x04 || X || Y with 32-byte coordinates
		c3Len = 32 // SM3 digest size
	)
	if len(encryptData) < c1Len+c3Len {
		return nil, errors.New("sm2: ciphertext too short")
	}

	// Unmarshal returns nil when the bytes are not a valid point on the curve.
	x, y := elliptic.Unmarshal(privateKey.Curve, encryptData[:c1Len])
	if x == nil {
		return nil, errors.New("sm2: invalid C1 point")
	}
	dBC1X, dBC1Y := privateKey.Curve.ScalarMult(x, y, bigIntToByte(privateKey.D))
	dBC1Bytes := elliptic.Marshal(privateKey.Curve, dBC1X, dBC1Y)

	kLen := len(encryptData) - c1Len - c3Len
	t, err := kdf(dBC1Bytes, kLen)
	if err != nil {
		return nil, err
	}

	C2 := encryptData[c1Len : c1Len+kLen]
	M := make([]byte, kLen)
	for i := range M {
		M[i] = C2[i] ^ t[i]
	}

	C3 := encryptData[c1Len+kLen:]
	u := calculateHash(dBC1X, M, dBC1Y)

	// Compare in constant time so timing does not reveal how many digest
	// bytes matched. u and C3 are both 32 bytes long.
	var diff byte
	for i := range C3 {
		diff |= u[i] ^ C3[i]
	}
	if diff != 0 {
		// "解密失败" means "decryption failed"; kept verbatim because callers
		// may match on it.
		return nil, errors.New("解密失败")
	}
	return M, nil
}

// Sm2Encrypt encrypts m for publicKey and returns the ciphertext laid out as
// C1 || C2 || C3:
//   - C1: the ephemeral point k·G, encoded with elliptic.Marshal;
//   - C2: m XOR t, where t = kdf(Marshal(k·P), len(m)) and P is the public key;
//   - C3: calculateHash(x2, m, y2), where (x2, y2) = k·P.
//
// A fresh random k is drawn until the keystream bytes that mask m are not
// all zero. It returns an error if reading randomness or key derivation fails.
func Sm2Encrypt(publicKey *sm2.PublicKey, m []byte) ([]byte, error) {
	kLen := len(m)
	n := publicKey.Params().N
	var C1, t []byte
	var kx, ky *big.Int
	for {
		k, err := rand.Int(rand.Reader, n)
		if err != nil {
			return nil, err
		}
		// rand.Int returns a value in [0, n); the ephemeral scalar must be non-zero.
		if k.Sign() == 0 {
			continue
		}
		kBytes := bigIntToByte(k)

		// NOTE(review): C1 is computed on ZZMarquis/gm's SM2 curve while k·P
		// uses publicKey's own curve; presumably both are the SM2 P-256
		// curve — confirm.
		C1x, C1y := zzsm2.GetSm2P256V1().ScalarBaseMult(kBytes)
		C1 = elliptic.Marshal(publicKey.Curve, C1x, C1y)

		kx, ky = publicKey.ScalarMult(publicKey.X, publicKey.Y, kBytes)
		kpbBytes := elliptic.Marshal(publicKey, kx, ky)
		t, err = kdf(kpbBytes, kLen)
		if err != nil {
			return nil, err
		}
		// Only the first kLen bytes mask the message, so only they are
		// checked. An empty message has nothing to mask and needs no retry.
		if kLen == 0 || !isAllZero(t[:kLen]) {
			break
		}
	}

	C2 := make([]byte, kLen)
	for i := range C2 {
		C2[i] = m[i] ^ t[i]
	}

	C3 := calculateHash(kx, m, ky)

	r := make([]byte, 0, len(C1)+len(C2)+len(C3))
	r = append(r, C1...)
	r = append(r, C2...)
	r = append(r, C3...)
	return r, nil
}

// isAllZero reports whether every byte of m is 0x00. A nil or empty slice
// counts as all zero.
func isAllZero(m []byte) bool {
	for _, b := range m {
		if b != 0 {
			return false
		}
	}
	return true
}

// calculateHash returns the SM3 digest of bigIntToByte(x) || M || bigIntToByte(y),
// the C3 check value used by Sm2Encrypt and Sm2Decrypt.
//
// Coordinates are encoded with bigIntToByte: minimal big-endian, with a
// leading 0x00 when the top bit is set. They are not encoded as fixed 32-byte
// fields. NOTE(review): this only interoperates with peers using the same
// variable-length convention — confirm against the counterpart implementation.
func calculateHash(x *big.Int, M []byte, y *big.Int) []byte {
	h := sm3.New()
	for _, part := range [][]byte{bigIntToByte(x), M, bigIntToByte(y)} {
		h.Write(part)
	}
	return h.Sum(nil)[:32]
}

// bigIntToByte returns the big-endian magnitude of n in the same form Java's
// BigInteger.toByteArray produces for non-negative values. The result has
// minimal length, gains a leading 0x00 byte when the first byte's top bit is
// set, and is a single 0x00 byte for zero.
//
// n must be non-nil. It is assumed to be non-negative: big.Int.Bytes ignores
// the sign, so a negative n is encoded by its magnitude only.
func bigIntToByte(n *big.Int) []byte {
	byteArray := n.Bytes()
	// Zero has an empty magnitude; indexing it below would panic.
	if len(byteArray) == 0 {
		return []byte{0}
	}
	// If the most significant byte's most significant bit is set, prepend a
	// 0 byte so the value is not interpreted as negative (two's complement).
	if byteArray[0]&0x80 != 0 {
		byteArray = append([]byte{0}, byteArray...)
	}
	return byteArray
}

// kdf derives key material from the shared secret Z by hashing Z || ct with
// sm3hash. ct is a 4-byte big-endian counter (see toByteArray) starting at 1.
// This appears to follow the SM2 (GM/T 0003) KDF: the digests for
// ct = 1..ceil(klen/32) are concatenated, so the first klen bytes of the
// result are the derived key.
//
// NOTE(review): the returned slice is longer than klen. After the
// ceil(klen/32) counter blocks, one more digest (the next counter value) is
// appended. It is truncated to klen%32 bytes, or kept whole when
// klen%32 == 0, so klen == 0 yields 32 bytes. Callers should rely only on
// result[:klen]. The extra tail is preserved so existing callers see the
// same output length.
//
// It returns an error if klen is negative or hashing fails.
func kdf(Z []byte, klen int) ([]byte, error) {
	if klen < 0 {
		// A negative length would otherwise panic when slicing the tail digest.
		return nil, errors.New("kdf: negative key length")
	}
	blocks := (klen + 31) / 32
	result := make([]byte, 0, (blocks+1)*32)
	ct := 1
	for ; ct <= blocks; ct++ {
		b, err := sm3hash(Z, toByteArray(ct))
		if err != nil {
			return nil, err
		}
		result = append(result, b...)
	}
	last, err := sm3hash(Z, toByteArray(ct))
	if err != nil {
		return nil, err
	}
	tail := klen % 32
	if tail == 0 {
		tail = len(last)
	}
	return append(result, last[:tail]...), nil
}

// sm3hash returns the SM3 digest, computed via x509.SM3, of all sources
// concatenated in order. The error comes only from joining the inputs.
func sm3hash(sources ...[]byte) ([]byte, error) {
	joined, err := joinBytes(sources...)
	if err != nil {
		return nil, err
	}
	hasher := x509.SM3.New()
	hasher.Write(joined)
	return hasher.Sum(nil), nil
}

// joinBytes concatenates params, in order, into one newly allocated slice.
// Any error is the one reported by bytes.Buffer.Write. With no input, or
// only empty inputs, the result has length zero.
func joinBytes(params ...[]byte) ([]byte, error) {
	var buf bytes.Buffer
	for _, p := range params {
		if _, err := buf.Write(p); err != nil {
			return nil, err
		}
	}
	return buf.Bytes(), nil
}

// toByteArray encodes the low 32 bits of i as 4 big-endian bytes, most
// significant byte first. kdf uses it to encode its counter.
func toByteArray(i int) []byte {
	u := uint32(i)
	return []byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)}
}