package sm4

import (
	"PaymentCenter/app/utils/encrypt/sm4/internal/sm2"
	"PaymentCenter/app/utils/encrypt/sm4/internal/util"
	"crypto/rand"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"github.com/ZZMarquis/gm/sm4"
	"math/big"
	"strings"
)

// checkInData looks up key in reqData. When the key is present it returns the
// stored value (which may be the empty string); otherwise it returns an error
// whose message names the missing key.
func checkInData(reqData map[string]string, key string) (string, error) {
	if value, present := reqData[key]; present {
		return value, nil
	}
	return "", errors.New("请求数据中不存在" + key)
}

// Sm4Decrypt verifies and decrypts a JSON envelope in the format produced by
// Sm4Encrypt and returns the decrypted payload.
//
// respJson must be a JSON object whose values are all strings. It must contain
// the data field ("request" when isRequest is true, otherwise "response") plus
// "signature", "encryptKey" and "accessToken".
//
// Steps:
//  1. Verify signature over data+encryptKey+accessToken with the hex-encoded
//     public key sopPublicKey, passing merchantId through to verify.
//  2. Hex-decode encryptKey and decrypt it with SM2 using the hex-encoded
//     privateKey to recover the SM4 key.
//  3. Base64-decode the data field and decrypt it with SM4-CBC using the IV
//     from util.GetSM4IV. encoding/base64 ignores '\r' and '\n', so
//     line-wrapped input such as Sm4Encrypt's output decodes correctly. The
//     result is then passed through util.Padding(..., 0).
//
// accessToken is only covered by the signature; it is not otherwise checked
// here. A failure in any step, including the SM2 key decryption, base64
// decoding and SM4 decryption, is returned to the caller as an error.
func Sm4Decrypt(merchantId, privateKey, sopPublicKey, respJson string, isRequest bool) (string, error) {
	var reqData map[string]string
	if err := json.Unmarshal([]byte(respJson), &reqData); err != nil {
		return "", err
	}

	dataKey := "response"
	if isRequest {
		dataKey = "request"
	}
	// Fields are looked up in a fixed order, so a missing-field error always
	// names the first missing key.
	var fields [4]string
	for i, key := range [4]string{dataKey, "signature", "encryptKey", "accessToken"} {
		value, err := checkInData(reqData, key)
		if err != nil {
			return "", err
		}
		fields[i] = value
	}
	inData, inSignature, inEncryptKey, inAccessToken := fields[0], fields[1], fields[2], fields[3]

	if !verify(inData+inEncryptKey+inAccessToken, inSignature, sopPublicKey, merchantId) {
		return "", errors.New("签名验证失败")
	}

	priKey, err := sm2.ReadPrivateKeyFromHex(privateKey)
	if err != nil {
		return "", errors.New("读取私钥失败")
	}
	encryptKeyBytes, err := hex.DecodeString(inEncryptKey)
	if err != nil {
		return "", errors.New("解密sm4key失败")
	}
	sm4Key, err := util.Sm2Decrypt(priKey, encryptKeyBytes)
	if err != nil {
		return "", fmt.Errorf("解密sm4key失败: %w", err)
	}

	cipherText, err := base64.StdEncoding.DecodeString(inData)
	if err != nil {
		return "", fmt.Errorf("报文base64解码失败: %w", err)
	}
	plainText, err := sm4.CBCDecrypt(sm4Key, util.GetSM4IV(), cipherText)
	if err != nil {
		return "", fmt.Errorf("报文sm4解密失败: %w", err)
	}

	return string(util.Padding(plainText, 0)), nil
}

// Sm4Encrypt builds a signed, encrypted JSON envelope around inputJson. It is
// the counterpart of Sm4Decrypt.
//
// Steps:
//  1. Obtain an SM4 key from util.GenerateSM4Key and the IV from util.GetSM4IV.
//  2. Pass inputJson through util.Padding(..., 1) and encrypt it with SM4-CBC.
//     Base64-encode the ciphertext and wrap it into 76-character lines joined
//     by "\r\n" (addNewline).
//  3. Encrypt the SM4 key with SM2 under the hex-encoded sopPublicKey and
//     hex-encode the result in upper case.
//  4. Derive the access token from token via util.GenAccessToken.
//  5. Sign data+encryptKey+accessToken with the hex-encoded privateKey,
//     passing merchantId through to sign.
//
// The result is a JSON object with the data key "request" (isRequest true) or
// "response" (isRequest false), plus "signature", "encryptKey" and
// "accessToken". If any step fails, the function returns an empty string and
// an error.
func Sm4Encrypt(merchantId, privateKey, sopPublicKey, inputJson, token string, isRequest bool) (string, error) {
	sm4Key := util.GenerateSM4Key()
	iv := util.GetSM4IV()
	cipherText, err := sm4.CBCEncrypt(sm4Key, iv, util.Padding([]byte(inputJson), 1))
	if err != nil {
		return "", err
	}
	payload := addNewline(base64.StdEncoding.EncodeToString(cipherText))

	pubKey, err := sm2.ReadPublicKeyFromHex(sopPublicKey)
	if err != nil {
		// This parses the counterpart's public key, so report it as such.
		return "", errors.New("读取公钥失败")
	}
	encryptKeyBytes, err := util.Sm2Encrypt(pubKey, sm4Key)
	if err != nil {
		return "", fmt.Errorf("加密sm4key失败: %w", err)
	}
	encryptKey := strings.ToUpper(hex.EncodeToString(encryptKeyBytes))

	accessToken := util.GenAccessToken(token)
	signature, err := sign(merchantId, privateKey, payload+encryptKey+accessToken)
	if err != nil {
		return "", fmt.Errorf("签名失败: %w", err)
	}

	dataKey := "response"
	if isRequest {
		dataKey = "request"
	}
	envelope := map[string]string{
		dataKey:       payload,
		"signature":   signature,
		"encryptKey":  encryptKey,
		"accessToken": accessToken,
	}

	jsonStr, err := json.Marshal(envelope)
	if err != nil {
		return "", err
	}
	return string(jsonStr), nil
}

// GenerateKey generates a new SM2 key pair and returns it as
// (privateKeyHex, publicKeyHex).
//
// The private key is the scalar D in upper-case hex, zero-padded to at least
// 64 digits. big.Int.Text drops leading zeros, which can yield a short or
// odd-length string that byte-oriented hex decoders reject. The public key is
// encoded by publicKeyToString.
//
// GenerateKey panics if key generation fails (for example, if the system
// random source is unavailable). No usable key can be returned in that case,
// and the signature has no error result.
func GenerateKey() (string, string) {
	pri, err := sm2.GenerateKey(rand.Reader)
	if err != nil {
		panic(fmt.Sprintf("sm2 GenerateKey err: %v", err))
	}
	hexPri := fmt.Sprintf("%064X", pri.D)
	return hexPri, publicKeyToString(&pri.PublicKey)
}

// publicKeyToString converts an sm2.PublicKey to an upper-case hex string in
// the uncompressed point form 0x04 || X || Y. This is intended to be fully
// interoperable with key pairs generated by Java's org.bouncycastle.crypto.
//
// Each coordinate is left-padded with zeros to a fixed width of 32 bytes (the
// coordinate size of the standard 256-bit SM2 curve). The encoding therefore
// always has the same length, even when both X and Y happen to have leading
// zero bytes. A coordinate longer than 32 bytes is never truncated; the width
// grows to fit it.
func publicKeyToString(publicKey *sm2.PublicKey) string {
	const coordSize = 32

	xBytes := publicKey.X.Bytes()
	yBytes := publicKey.Y.Bytes()

	size := coordSize
	if len(xBytes) > size {
		size = len(xBytes)
	}
	if len(yBytes) > size {
		size = len(yBytes)
	}

	encoded := make([]byte, 1+2*size)
	encoded[0] = 0x04 // uncompressed point marker
	// Right-align each coordinate in its slot; make() already zeroed the padding.
	copy(encoded[1+size-len(xBytes):1+size], xBytes)
	copy(encoded[1+2*size-len(yBytes):], yBytes)

	return strings.ToUpper(hex.EncodeToString(encoded))
}

// addNewline breaks str into consecutive chunks of 76 bytes joined by "\r\n".
// No line break is added before the first chunk or after the last one, so a
// string of at most 76 bytes is returned unchanged. Sm4Encrypt applies it to
// the base64-encoded ciphertext.
func addNewline(str string) string {
	const width = 76
	if len(str) <= width {
		return str
	}

	var out strings.Builder
	out.Grow(len(str) + 2*(len(str)/width))
	for start := 0; start < len(str); start += width {
		if start > 0 {
			out.WriteString("\r\n")
		}
		end := start + width
		if end > len(str) {
			end = len(str)
		}
		out.WriteString(str[start:end])
	}
	return out.String()
}

// sign computes an SM2 signature over signContent using the hex-encoded
// private key privateKeyHex, passing merchantId through to sm2.Sm2Sign. The
// signature is returned in the "r#s" hex form built by rSToSign. Errors from
// key parsing or signing are returned unchanged.
func sign(merchantId string, privateKeyHex string, signContent string) (string, error) {
	key, err := sm2.ReadPrivateKeyFromHex(privateKeyHex)
	if err != nil {
		return "", err
	}

	msg, id := []byte(signContent), []byte(merchantId)
	r, s, err := sm2.Sm2Sign(key, msg, id, rand.Reader)
	if err != nil {
		return "", err
	}
	return rSToSign(r, s), nil
}

// verify reports whether signature is a valid SM2 signature of content under
// the hex-encoded public key publicKeyStr. merchantId is passed through to
// sm2.Sm2Verify and must match the value used by the signer.
//
// signature is expected in the "r#s" form produced by rSToSign, with r and s
// in hexadecimal. signature arrives from the remote party, so a malformed
// value makes verify return false rather than panic. An unparsable
// publicKeyStr is likewise reported as a failed verification.
func verify(content string, signature string, publicKeyStr string, merchantId string) bool {
	pubKey, err := sm2.ReadPublicKeyFromHex(publicKeyStr)
	if err != nil {
		return false
	}

	parts := strings.Split(signature, "#")
	if len(parts) != 2 {
		return false
	}
	r, okR := new(big.Int).SetString(parts[0], 16)
	s, okS := new(big.Int).SetString(parts[1], 16)
	// SetString returns nil on failure. Valid SM2 signature components are
	// also strictly positive, so this check keeps nil or nonsensical
	// integers away from Sm2Verify.
	if !okR || !okS || r.Sign() <= 0 || s.Sign() <= 0 {
		return false
	}
	return sm2.Sm2Verify(pubKey, []byte(content), []byte(merchantId), r, s)
}

// signToRS parses a signature in the "r#s" form produced by rSToSign, where r
// and s are hexadecimal integers, and returns r and s.
//
// It panics if signStr does not contain exactly one '#' or if either half is
// not valid hexadecimal. The hex check matters because returning the nil
// *big.Int that SetString yields on failure would only crash later, far from
// the cause. Callers handling untrusted input should validate it first.
func signToRS(signStr string) (*big.Int, *big.Int) {
	signSub := strings.Split(signStr, "#")
	if len(signSub) != 2 {
		panic(fmt.Sprintf("err rs: %q", signStr))
	}
	r, ok := new(big.Int).SetString(signSub[0], 16)
	if !ok {
		panic(fmt.Sprintf("err rs: invalid hex r %q", signSub[0]))
	}
	s, ok := new(big.Int).SetString(signSub[1], 16)
	if !ok {
		panic(fmt.Sprintf("err rs: invalid hex s %q", signSub[1]))
	}
	return r, s
}

// rSToSign serialises the signature components r and s as base-16 text,
// without padding, joined by a single '#': "<r>#<s>". signToRS parses this
// format back.
func rSToSign(r *big.Int, s *big.Int) string {
	return r.Text(16) + "#" + s.Text(16)
}