package rsa

import (
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/x509"
	"encoding/base64"
	"encoding/pem"
	"fmt"
)

// parseRSAPublicKeyFromPEM decodes the first PEM block in pemData and returns
// the RSA public key it contains.
//
// Two encodings are accepted:
//   - "PUBLIC KEY": a PKIX (SubjectPublicKeyInfo) structure, the format
//     GenerateKey emits;
//   - "RSA PUBLIC KEY": a bare PKCS#1 RSAPublicKey structure.
//
// An error is returned if no PEM block is found, the block type is neither of
// the above, the DER payload cannot be parsed, or a PKIX block holds a key
// that is not RSA.
func parseRSAPublicKeyFromPEM(pemData []byte) (*rsa.PublicKey, error) {
	// Anything after the first PEM block is ignored.
	block, _ := pem.Decode(pemData)
	if block == nil {
		return nil, fmt.Errorf("failed to parse PEM block containing the RSA public key")
	}

	switch block.Type {
	case "RSA PUBLIC KEY":
		// PKCS#1 carries no algorithm identifier; the payload is always RSA.
		return x509.ParsePKCS1PublicKey(block.Bytes)
	case "PUBLIC KEY":
		pub, err := x509.ParsePKIXPublicKey(block.Bytes)
		if err != nil {
			return nil, err
		}
		// PKIX can wrap other algorithms (ECDSA, Ed25519, ...); only RSA is usable here.
		rsaPub, ok := pub.(*rsa.PublicKey)
		if !ok {
			return nil, fmt.Errorf("unknown public key type in PKIX wrapping")
		}
		return rsaPub, nil
	default:
		return nil, fmt.Errorf("failed to parse PEM block containing the RSA public key")
	}
}

// Encrypt encrypts plaintext with the PEM-encoded RSA public key publicKeyPEM
// using RSA-OAEP with SHA-256 and an empty label.
//
// The key must be in a format accepted by parseRSAPublicKeyFromPEM. The
// returned ciphertext is raw bytes. Decrypt, by contrast, expects its input
// standard-base64 encoded, so callers must encode the result themselves before
// passing it there. rsa.EncryptOAEP rejects plaintext longer than
// (modulus size in bytes - 2*32 - 2).
func Encrypt(publicKeyPEM string, plaintext []byte) ([]byte, error) {
	pubKey, err := parseRSAPublicKeyFromPEM([]byte(publicKeyPEM))
	if err != nil {
		return nil, err
	}

	// The label must match the one Decrypt passes to rsa.DecryptOAEP (empty).
	// rand.Reader supplies the OAEP seed, so ciphertexts differ on every call.
	emptyLabel := []byte("")
	sealed, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, pubKey, plaintext, emptyLabel)
	if err != nil {
		return nil, err
	}
	return sealed, nil
}

// parseRSAPrivateKeyFromPEM decodes the first PEM block in pemData and returns
// the RSA private key it contains.
//
// Both "RSA PRIVATE KEY" (conventionally PKCS#1) and "PRIVATE KEY"
// (conventionally PKCS#8) blocks are accepted. The DER payload is tried as
// PKCS#1 first and then as PKCS#8, so mislabelled blocks still parse. A PKCS#8
// payload that holds a non-RSA key is rejected. If both parsers fail, the
// PKCS#8 parse error is returned.
func parseRSAPrivateKeyFromPEM(pemData []byte) (*rsa.PrivateKey, error) {
	// Anything after the first PEM block is ignored.
	block, _ := pem.Decode(pemData)
	if block == nil || (block.Type != "RSA PRIVATE KEY" && block.Type != "PRIVATE KEY") {
		return nil, fmt.Errorf("failed to parse PEM block containing the RSA private key")
	}

	// PKCS#1 RSAPrivateKey structure (the format GenerateKey emits).
	if priv, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
		return priv, nil
	}

	// PKCS#8 can wrap any algorithm, so the result must be type-checked.
	key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		return nil, err
	}
	priv, ok := key.(*rsa.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("unknown private key type in PKCS#8 wrapping")
	}
	return priv, nil
}

// Decrypt decodes encryptedDataBase64 (standard base64) and decrypts it with
// the PEM-encoded RSA private key privateKeyPEM.
//
// RSA-OAEP with SHA-256 and an empty label, the scheme Encrypt uses, is tried
// first. If that fails, PKCS#1 v1.5 decryption is attempted. If both fail, the
// PKCS#1 v1.5 error is returned. Key-parsing and base64 errors are returned
// unchanged.
//
// NOTE(review): falling back to a second padding scheme, and accepting
// PKCS#1 v1.5 at all, can act as a Bleichenbacher-style padding oracle if
// distinct failures are observable by an untrusted party. Consider removing
// the fallback once no legacy v1.5 ciphertexts need to be supported.
func Decrypt(privateKeyPEM string, encryptedDataBase64 string) ([]byte, error) {
	privKey, err := parseRSAPrivateKeyFromPEM([]byte(privateKeyPEM))
	if err != nil {
		return nil, err
	}

	ciphertext, err := base64.StdEncoding.DecodeString(encryptedDataBase64)
	if err != nil {
		return nil, err
	}

	// The empty label must match the one Encrypt uses.
	emptyLabel := []byte("")
	plain, oaepErr := rsa.DecryptOAEP(sha256.New(), rand.Reader, privKey, ciphertext, emptyLabel)
	if oaepErr == nil {
		return plain, nil
	}

	// Legacy path for ciphertexts produced with PKCS#1 v1.5 padding.
	plain, err = rsa.DecryptPKCS1v15(rand.Reader, privKey, ciphertext)
	if err != nil {
		return nil, err
	}
	return plain, nil
}

// GenerateKey creates a fresh 2048-bit RSA key pair and returns it PEM-encoded.
//
// The first result is the public key as a PKIX "PUBLIC KEY" block. The second
// is the private key as a PKCS#1 "RSA PRIVATE KEY" block. These are the
// formats Encrypt and Decrypt consume. On error both strings are empty.
func GenerateKey() (string, string, error) {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return "", "", err
	}

	pkixDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
	if err != nil {
		return "", "", err
	}

	pubPEM := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pkixDER})
	privPEM := pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	})

	return string(pubPEM), string(privPEM), nil
}