package rsa

import (
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/x509"
	"encoding/base64"
	"encoding/pem"
	"fmt"
	"strings"
)

// parseRSAPublicKeyFromPEM parses an RSA public key.
//
// Accepted inputs:
//   - a PEM block of type "PUBLIC KEY" (PKIX / SubjectPublicKeyInfo DER);
//   - a PEM block of type "RSA PUBLIC KEY" (PKCS#1 RSAPublicKey DER);
//   - the bare base64 body of a PKIX key with the PEM BEGIN/END lines
//     removed, which is the format GenerateKey in this package returns.
//
// Any other PEM type, undecodable input, or a PKIX key whose algorithm is
// not RSA yields an error.
func parseRSAPublicKeyFromPEM(pemData []byte) (*rsa.PublicKey, error) {
	var der []byte
	block, _ := pem.Decode(pemData)
	switch {
	case block == nil:
		// No PEM armor found: accept the armor-less base64 DER body so that
		// keys produced by GenerateKey can be used directly.
		raw, err := base64.StdEncoding.DecodeString(strings.TrimSpace(string(pemData)))
		if err != nil || len(raw) == 0 {
			return nil, fmt.Errorf("failed to parse PEM block containing the RSA public key")
		}
		der = raw
	case block.Type == "RSA PUBLIC KEY":
		// PKCS#1 form, e.g. as produced by x509.MarshalPKCS1PublicKey.
		return x509.ParsePKCS1PublicKey(block.Bytes)
	case block.Type == "PUBLIC KEY":
		der = block.Bytes
	default:
		return nil, fmt.Errorf("failed to parse PEM block containing the RSA public key")
	}

	pub, err := x509.ParsePKIXPublicKey(der)
	if err != nil {
		return nil, err
	}
	rsaPub, ok := pub.(*rsa.PublicKey)
	if !ok {
		return nil, fmt.Errorf("unknown public key type in PKIX wrapping")
	}
	return rsaPub, nil
}

// Encrypt encrypts plaintext with the RSA public key in publicKeyPEM using
// RSA-OAEP with SHA-256 and an empty label.
//
// A single OAEP operation can carry at most k-2*hLen-2 bytes, where k is the
// modulus size in bytes and hLen is the SHA-256 digest size. The plaintext is
// therefore split into chunks of at most that size, and each chunk is
// encrypted independently. The result is the concatenation of the ciphertext
// blocks, each exactly k bytes long. The returned bytes are raw, not
// base64-encoded.
//
// An empty plaintext yields a nil slice and a nil error. An error is returned
// in three cases:
//   - the key cannot be parsed;
//   - the key is too small to hold any data under OAEP/SHA-256;
//   - a block fails to encrypt.
func Encrypt(publicKeyPEM string, plaintext string) ([]byte, error) {
	pubKey, err := parseRSAPublicKeyFromPEM([]byte(publicKeyPEM))
	if err != nil {
		return nil, err
	}

	hash := sha256.New()
	keySize := pubKey.Size()
	maxBlockSize := keySize - 2*hash.Size() - 2
	if maxBlockSize <= 0 {
		// Without this guard, a tiny key makes the slicing below panic
		// (negative size) or never make progress (zero size).
		return nil, fmt.Errorf("rsa: %d-byte key is too small for OAEP with SHA-256", keySize)
	}
	label := []byte("") // OAEP label; must match the label used for decryption.

	data := []byte(plaintext)
	var encryptedData []byte
	if len(data) > 0 {
		// Every chunk becomes exactly keySize bytes of ciphertext.
		blocks := (len(data) + maxBlockSize - 1) / maxBlockSize
		encryptedData = make([]byte, 0, blocks*keySize)
	}
	for len(data) > 0 {
		n := maxBlockSize
		if len(data) < n {
			n = len(data)
		}
		encryptedBlock, err := rsa.EncryptOAEP(hash, rand.Reader, pubKey, data[:n], label)
		if err != nil {
			return nil, err
		}
		encryptedData = append(encryptedData, encryptedBlock...)
		data = data[n:]
	}
	return encryptedData, nil
}

// parseRSAPrivateKeyFromPEM parses an RSA private key.
//
// Accepted inputs:
//   - a PEM block of type "RSA PRIVATE KEY" (PKCS#1 DER);
//   - a PEM block of type "PRIVATE KEY" (PKCS#8 DER wrapping an RSA key);
//   - the bare base64 body of either DER form with the PEM BEGIN/END lines
//     removed, which is the format GenerateKey in this package returns.
//
// The PKCS#1 key format is tried first, then PKCS#8. If both fail, the
// PKCS#8 parse error is returned. A PKCS#8 key of any non-RSA algorithm is
// rejected.
func parseRSAPrivateKeyFromPEM(pemData []byte) (*rsa.PrivateKey, error) {
	var der []byte
	block, _ := pem.Decode(pemData)
	switch {
	case block == nil:
		// No PEM armor found: accept the armor-less base64 DER body so that
		// keys produced by GenerateKey can be used directly.
		raw, err := base64.StdEncoding.DecodeString(strings.TrimSpace(string(pemData)))
		if err != nil || len(raw) == 0 {
			return nil, fmt.Errorf("failed to parse PEM block containing the RSA private key")
		}
		der = raw
	case block.Type == "RSA PRIVATE KEY", block.Type == "PRIVATE KEY":
		// Standard PKCS#8 files are labelled "PRIVATE KEY". That label must
		// be accepted for the PKCS#8 fallback below to be reachable.
		der = block.Bytes
	default:
		return nil, fmt.Errorf("failed to parse PEM block containing the RSA private key")
	}

	// PKCS#1 RSAPrivateKey structure.
	if priv, err := x509.ParsePKCS1PrivateKey(der); err == nil {
		return priv, nil
	}

	// Fall back to a PKCS#8 PrivateKeyInfo wrapper.
	key, err := x509.ParsePKCS8PrivateKey(der)
	if err != nil {
		return nil, err
	}
	priv, ok := key.(*rsa.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("unknown private key type in PKCS#8 wrapping")
	}
	return priv, nil
}

// Decrypt decrypts ciphertext produced by Encrypt, after base64 encoding, with
// the RSA private key in privateKeyPEM.
//
// encryptedDataBase64 is decoded with base64.StdEncoding. The result must be
// a whole number of key-sized blocks. Each block is decrypted with RSA-OAEP
// (SHA-256, empty label). If OAEP fails for a block, PKCS#1 v1.5 decryption of
// that same block is attempted. The decrypted blocks are concatenated in
// order.
//
// An empty ciphertext yields a nil slice and a nil error. An error is returned
// in four cases:
//   - the key cannot be parsed;
//   - the base64 is invalid;
//   - the ciphertext length is not a multiple of the key size;
//   - a block fails to decrypt under both schemes.
func Decrypt(privateKeyPEM string, encryptedDataBase64 string) ([]byte, error) {
	privKey, err := parseRSAPrivateKeyFromPEM([]byte(privateKeyPEM))
	if err != nil {
		return nil, err
	}

	encryptedData, err := base64.StdEncoding.DecodeString(encryptedDataBase64)
	if err != nil {
		return nil, err
	}

	keySize := privKey.PublicKey.Size()
	// Guard the fixed-size slicing below; truncated input used to panic here.
	if len(encryptedData)%keySize != 0 {
		return nil, fmt.Errorf("ciphertext length %d is not a multiple of the %d-byte key size", len(encryptedData), keySize)
	}

	hash := sha256.New()
	label := []byte("") // OAEP label; must match the label used for encryption.
	var decryptedData []byte
	for len(encryptedData) > 0 {
		block := encryptedData[:keySize]
		decryptedBlock, err := rsa.DecryptOAEP(hash, rand.Reader, privKey, block, label)
		if err != nil {
			// Fall back to PKCS#1 v1.5 padding for this block only.
			// NOTE(review): an OAEP-then-v1.5 fallback can act as a padding
			// oracle if error details reach an attacker; consider removing it
			// once no v1.5 ciphertexts need to be supported.
			decryptedBlock, err = rsa.DecryptPKCS1v15(rand.Reader, privKey, block)
			if err != nil {
				return nil, err
			}
		}
		decryptedData = append(decryptedData, decryptedBlock...)
		encryptedData = encryptedData[keySize:]
	}
	return decryptedData, nil
}

// GenerateKey creates a fresh 2048-bit RSA key pair.
//
// It returns (publicKey, privateKey, err). Each key string is the base64
// body of its PEM encoding, with the BEGIN/END lines and every newline
// removed:
//   - publicKey is PKIX (SubjectPublicKeyInfo) DER, i.e. the body of a
//     "PUBLIC KEY" PEM block;
//   - privateKey is PKCS#1 DER, i.e. the body of an "RSA PRIVATE KEY" PEM
//     block.
func GenerateKey() (string, string, error) {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return "", "", err
	}

	// stripArmor PEM-encodes blk and then deletes the given header line, the
	// given footer line and all remaining line breaks. Only the base64
	// payload is left.
	stripArmor := func(blk *pem.Block, header, footer string) string {
		text := string(pem.EncodeToMemory(blk))
		text = strings.ReplaceAll(text, header, "")
		text = strings.ReplaceAll(text, footer, "")
		return strings.ReplaceAll(text, "\n", "")
	}

	privBlock := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}

	pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
	if err != nil {
		return "", "", err
	}
	pubBlock := &pem.Block{Type: "PUBLIC KEY", Bytes: pubDER}

	pubText := stripArmor(pubBlock, "-----BEGIN PUBLIC KEY-----\n", "\n-----END PUBLIC KEY-----\n")
	privText := stripArmor(privBlock, "-----BEGIN RSA PRIVATE KEY-----\n", "\n-----END RSA PRIVATE KEY-----\n")
	return pubText, privText, nil
}