#include "Conversion.hpp"
#include "HybridRsa.hpp"
#include <NitroModules/ArrayBuffer.hpp>

#include <cstring>
#include <limits>
#include <memory>
#include <stdexcept>

#include <openssl/crypto.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/pem.h>
#include <openssl/rsa.h>

namespace icure::nitrokryptom::rsa {

/**
 * Serializes `key` with `writeToBio` and returns a copy of the produced bytes.
 *
 * @param key        key to serialize; ownership stays with the caller.
 * @param writeToBio writer callback (e.g. a PEM writer) that must return 1 on success.
 * @return a NativeArrayBuffer holding a copy of everything the writer produced; the copy is
 *         wiped with OPENSSL_cleanse before being freed.
 * @throws std::runtime_error if the BIO cannot be allocated, the writer fails, or no output was produced.
 */
std::shared_ptr<margelo::nitro::ArrayBuffer> exportKey(EVP_PKEY* key, int (*writeToBio)(EVP_PKEY*, BIO*)) {
    // BIO_s_secmem: a memory BIO whose storage comes from OpenSSL's secure heap (when one is initialized),
    // because the serialized output may be private key material.
    std::unique_ptr<BIO, decltype(&BIO_free_all)> bio(BIO_new(BIO_s_secmem()), BIO_free_all);
    if (bio == nullptr) throw std::runtime_error("Failed to allocate bio");
    if (writeToBio(key, bio.get()) != 1) throw std::runtime_error("PEM_write_bio PKCS8PrivateKey/PEM_write_bio_PUBKEY failed");
    BUF_MEM* bptr = nullptr;
    if (BIO_get_mem_ptr(bio.get(), &bptr) != 1 || bptr == nullptr || bptr->length == 0) throw std::runtime_error("Couldn't retrieve bio mem info");
    const size_t length = bptr->length;
    // The buffer is owned by a unique_ptr until the ArrayBuffer has been constructed, so it is not
    // leaked if make_shared throws.
    auto resBuffer = std::make_unique<uint8_t[]>(length);
    memcpy(resBuffer.get(), bptr->data, length);
    uint8_t* const raw = resBuffer.get();
    // The copy lives outside the secure heap: wipe it before releasing the memory.
    auto result = std::make_shared<margelo::nitro::NativeArrayBuffer>(raw, length, [raw, length]() {
        OPENSSL_cleanse(raw, length);
        delete[] raw;
    });
    static_cast<void>(resBuffer.release()); // ownership now belongs to the ArrayBuffer's deleter
    return result;
}

std::shared_ptr<EVP_PKEY> loadKey(std::shared_ptr<margelo::nitro::ArrayBuffer> key, EVP_PKEY *(*loadKey)(BIO* out, EVP_PKEY **x, pem_password_cb *cb, void *u)) {
    BIO* bio = BIO_new_mem_buf(key->data(), conversions::safe_size_to_int(key->size()));
    if (bio == nullptr) throw std::runtime_error("Failed to load arraybuffer as bio");
    EVP_PKEY* imported = loadKey(bio, nullptr, nullptr, nullptr);
    BIO_free(bio);
    if (imported != nullptr) {
        return std::shared_ptr<EVP_PKEY>(imported, EVP_PKEY_free);

    } else {
        throw std::runtime_error("Failed to load key");
    }
}

/**
 * Parses a PEM-encoded private key.
 *
 * Thin wrapper over loadKey that selects OpenSSL's PEM_read_bio_PrivateKey reader; throws whatever loadKey throws.
 */
std::shared_ptr<EVP_PKEY> loadPrivateKey(std::shared_ptr<margelo::nitro::ArrayBuffer> key) {
    auto* const privateReader = PEM_read_bio_PrivateKey;
    return loadKey(key, privateReader);
}

/**
 * Parses a PEM-encoded public key.
 *
 * Thin wrapper over loadKey that selects OpenSSL's PEM_read_bio_PUBKEY reader; throws whatever loadKey throws.
 */
std::shared_ptr<EVP_PKEY> loadPublicKey(std::shared_ptr<margelo::nitro::ArrayBuffer> key) {
    auto* const publicReader = PEM_read_bio_PUBKEY;
    return loadKey(key, publicReader);
}

/**
 * Maps an RSA-OAEP algorithm spec to the OpenSSL digest used for OAEP.
 *
 * @param algorithm OAEPWITHSHA1 or OAEPWITHSHA256.
 * @return EVP_sha1() or EVP_sha256() respectively (static OpenSSL objects, not to be freed).
 * @throws std::runtime_error for any value not listed in the switch.
 */
const EVP_MD * getRsaOaepMd(margelo::nitro::nitrokryptom::RsaEncryptionAlgorithmSpec algorithm) {
    switch (algorithm) {
        case margelo::nitro::nitrokryptom::RsaEncryptionAlgorithmSpec::OAEPWITHSHA1:
            return EVP_sha1();
        case margelo::nitro::nitrokryptom::RsaEncryptionAlgorithmSpec::OAEPWITHSHA256:
            return EVP_sha256();
    }
    // Falling off the end of a non-void function is undefined behaviour. An out-of-range value
    // (e.g. from an unchecked integer-to-enum conversion) would otherwise hand OpenSSL a garbage pointer.
    throw std::runtime_error("Unsupported RSA encryption algorithm");
}

}

namespace margelo::nitro::nitrokryptom {

/**
 * Generates an RSA key pair with public exponent RSA_F4 inside a Promise::async task.
 *
 * @param keySize modulus size in bits; must be an integral value in [1, INT_MAX]. OpenSSL applies its own
 *                minimum on top of that.
 * @return a promise for a HybridKeypair holding the PKCS#8 PEM private key (written by
 *         PEM_write_bio_PKCS8PrivateKey, unencrypted) and the PEM public key (PEM_write_bio_PUBKEY).
 * @throws std::runtime_error from inside the async task (presumably surfaced as a promise rejection) if
 *         keySize is invalid or any OpenSSL step fails.
 */
std::shared_ptr<Promise<HybridKeypair>> HybridRsa::generateKeypair(double keySize) {
    return Promise<HybridKeypair>::async([keySize]() -> HybridKeypair {
        // keySize is a double: converting NaN, infinity or an out-of-range value to int is undefined
        // behaviour, so validate before narrowing instead of relying on an implicit conversion.
        if (!(keySize >= 1.0 && keySize <= static_cast<double>(std::numeric_limits<int>::max()))) throw std::runtime_error("Invalid RSA key size");
        const int bits = static_cast<int>(keySize);
        if (static_cast<double>(bits) != keySize) throw std::runtime_error("Invalid RSA key size");

        std::unique_ptr<EVP_PKEY_CTX, decltype(&EVP_PKEY_CTX_free)> ctx(EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr), EVP_PKEY_CTX_free);
        if (ctx == nullptr) throw std::runtime_error("EVP_PKEY_CTX_new_id failed");
        std::unique_ptr<BIGNUM, decltype(&BN_free)> e(BN_new(), BN_free);
        if (e == nullptr) throw std::runtime_error("BN_new failed");
        if (EVP_PKEY_keygen_init(ctx.get()) != 1) throw std::runtime_error("EVP_PKEY_keygen_init failed");
        if (EVP_PKEY_CTX_set_rsa_keygen_bits(ctx.get(), bits) != 1) throw std::runtime_error("EVP_PKEY_CTX_set_rsa_keygen_bits failed");
        if (BN_set_word(e.get(), RSA_F4) != 1) throw std::runtime_error("BN_set_word failed");
        if (EVP_PKEY_CTX_set1_rsa_keygen_pubexp(ctx.get(), e.get()) != 1) throw std::runtime_error("EVP_PKEY_CTX_set1_rsa_keygen_pubexp failed");
        EVP_PKEY* rawKey = nullptr;
        if (EVP_PKEY_generate(ctx.get(), &rawKey) != 1) throw std::runtime_error("EVP_PKEY_generate failed");
        // Take ownership right away so the key is freed even if one of the exports below throws.
        std::unique_ptr<EVP_PKEY, decltype(&EVP_PKEY_free)> key(rawKey, EVP_PKEY_free);
        std::shared_ptr<ArrayBuffer> privPem = icure::nitrokryptom::rsa::exportKey(key.get(), [](EVP_PKEY* k, BIO *b) -> int { return PEM_write_bio_PKCS8PrivateKey(b, k, nullptr, nullptr, 0, nullptr, nullptr); });
        std::shared_ptr<ArrayBuffer> pubPem = icure::nitrokryptom::rsa::exportKey(key.get(), [](EVP_PKEY* k, BIO *b) -> int { return PEM_write_bio_PUBKEY(b, k); });
        return HybridKeypair(privPem, pubPem);
    });
}

/**
 * Signs `data` with RSASSA-PSS over SHA-256 using a PEM-encoded private key.
 *
 * The key is parsed on the calling thread, so a malformed key throws from sign() itself.
 * Non-owning input buffers are copied before being handed to the async task.
 *
 * @param data bytes to sign.
 * @param key  PEM-encoded private key.
 * @return a promise for the signature bytes.
 * @throws std::runtime_error synchronously if the key cannot be parsed; from inside the async task if any
 *         OpenSSL signing step fails.
 */
std::shared_ptr<Promise<std::shared_ptr<ArrayBuffer>>> HybridRsa::sign(const std::shared_ptr<ArrayBuffer>& data, const std::shared_ptr<ArrayBuffer>& key) {
    const std::shared_ptr<ArrayBuffer> owningData = data->isOwner() ? data : ArrayBuffer::copy(data);
    // Deliberately not const: std::move on a const shared_ptr silently falls back to a copy.
    std::shared_ptr<EVP_PKEY> loadedKey = icure::nitrokryptom::rsa::loadPrivateKey(key);
    return Promise<std::shared_ptr<ArrayBuffer>>::async([loadedKey = std::move(loadedKey), owningData]() -> std::shared_ptr<ArrayBuffer> {
        std::unique_ptr<EVP_MD_CTX, decltype(&EVP_MD_CTX_free)> ctx(EVP_MD_CTX_new(), EVP_MD_CTX_free);
        if (ctx == nullptr) throw std::runtime_error("Failed to initialize context");
        if (EVP_DigestSignInit(ctx.get(), nullptr, EVP_sha256(), nullptr, loadedKey.get()) != 1) throw std::runtime_error("EVP_DigestSignInit failed");
        auto pkeyCtx = EVP_MD_CTX_get_pkey_ctx(ctx.get()); // no need to free (as per documentation)
        if (pkeyCtx == nullptr) throw std::runtime_error("Failed to get pkey context");
        if (EVP_PKEY_CTX_set_rsa_padding(pkeyCtx, RSA_PKCS1_PSS_PADDING) != 1) throw std::runtime_error("EVP_PKEY_CTX_set_rsa_padding failed");
        if (EVP_DigestSignUpdate(ctx.get(), owningData->data(), owningData->size()) != 1) throw std::runtime_error("EVP_DigestSignUpdate failed");
        size_t outlen = 0;
        if (EVP_DigestSignFinal(ctx.get(), nullptr, &outlen) != 1) throw std::runtime_error("EVP_DigestSignFinal (get size) failed");
        // Owned by a unique_ptr until the ArrayBuffer exists: no leak on any throw path.
        auto resBuffer = std::make_unique<uint8_t[]>(outlen);
        if (EVP_DigestSignFinal(ctx.get(), resBuffer.get(), &outlen) != 1) throw std::runtime_error("EVP_DigestSignFinal failed");
        uint8_t* const raw = resBuffer.get();
        auto result = std::make_shared<margelo::nitro::NativeArrayBuffer>(raw, outlen, [raw]() { delete[] raw; });
        static_cast<void>(resBuffer.release()); // ownership now belongs to the ArrayBuffer's deleter
        return result;
    });
}

/**
 * Verifies an RSASSA-PSS / SHA-256 signature over `data` with a PEM-encoded public key.
 *
 * The key is parsed on the calling thread, so a malformed key throws from verify() itself.
 * Non-owning input buffers are copied before being handed to the async task.
 *
 * @return a promise resolving to true for a valid signature and false for a mismatching one.
 * @throws std::runtime_error from inside the async task for any OpenSSL failure other than a plain mismatch.
 */
std::shared_ptr<Promise<bool>> HybridRsa::verify(const std::shared_ptr<ArrayBuffer>& signature, const std::shared_ptr<ArrayBuffer>& data, const std::shared_ptr<ArrayBuffer>& key) {
    auto ensureOwned = [](const std::shared_ptr<ArrayBuffer>& buffer) -> std::shared_ptr<ArrayBuffer> {
        return buffer->isOwner() ? buffer : ArrayBuffer::copy(buffer);
    };
    const std::shared_ptr<ArrayBuffer> message = ensureOwned(data);
    const std::shared_ptr<ArrayBuffer> sig = ensureOwned(signature);
    const std::shared_ptr<EVP_PKEY> publicKey = icure::nitrokryptom::rsa::loadPublicKey(key);

    return Promise<bool>::async([publicKey, message, sig]() -> bool {
        using MdCtxHandle = std::unique_ptr<EVP_MD_CTX, decltype(&EVP_MD_CTX_free)>;
        const MdCtxHandle mdCtx(EVP_MD_CTX_new(), EVP_MD_CTX_free);
        if (!mdCtx) {
            throw std::runtime_error("Failed to initialize context");
        }
        if (EVP_DigestVerifyInit(mdCtx.get(), nullptr, EVP_sha256(), nullptr, publicKey.get()) != 1) {
            throw std::runtime_error("EVP_DigestVerifyInit failed");
        }
        // Belongs to mdCtx; it must not be freed separately.
        EVP_PKEY_CTX* const keyCtx = EVP_MD_CTX_get_pkey_ctx(mdCtx.get());
        if (!keyCtx) {
            throw std::runtime_error("Failed to get pkey context");
        }
        if (EVP_PKEY_CTX_set_rsa_padding(keyCtx, RSA_PKCS1_PSS_PADDING) != 1) {
            throw std::runtime_error("EVP_PKEY_CTX_set_rsa_padding failed");
        }
        if (EVP_DigestVerifyUpdate(mdCtx.get(), message->data(), message->size()) != 1) {
            throw std::runtime_error("EVP_DigestVerifyUpdate failed");
        }

        const int verdict = EVP_DigestVerifyFinal(mdCtx.get(), sig->data(), sig->size());
        if (verdict == 1) {
            return true;
        }
        if (verdict == 0) {
            // A mismatch is an expected outcome, not an error: drop the errors OpenSSL queued for it.
            ERR_clear_error();
            return false;
        }
        throw std::runtime_error("EVP_DigestVerifyFinal failed");
    });
}

/**
 * Encrypts `data` with RSA-OAEP using a PEM-encoded public key.
 *
 * The key is parsed on the calling thread, so a malformed key throws from encrypt() itself.
 * Non-owning input buffers are copied before being handed to the async task.
 *
 * @param algorithm selects the OAEP digest (see getRsaOaepMd).
 * @param data      plaintext bytes.
 * @param publicKey PEM-encoded public key.
 * @return a promise for the ciphertext bytes.
 * @throws std::runtime_error synchronously if the key cannot be parsed; from inside the async task if any
 *         OpenSSL step fails.
 */
std::shared_ptr<Promise<std::shared_ptr<ArrayBuffer>>> HybridRsa::encrypt(RsaEncryptionAlgorithmSpec algorithm, const std::shared_ptr<ArrayBuffer>& data, const std::shared_ptr<ArrayBuffer>& publicKey) {
    const std::shared_ptr<ArrayBuffer> owningData = data->isOwner() ? data : ArrayBuffer::copy(data);
    const std::shared_ptr<EVP_PKEY> loadedKey = icure::nitrokryptom::rsa::loadPublicKey(publicKey);
    return Promise<std::shared_ptr<ArrayBuffer>>::async([loadedKey, owningData, algorithm]() -> std::shared_ptr<ArrayBuffer> {
        std::unique_ptr<EVP_PKEY_CTX, decltype(&EVP_PKEY_CTX_free)> ctx(EVP_PKEY_CTX_new(loadedKey.get(), nullptr), EVP_PKEY_CTX_free);
        if (ctx == nullptr) throw std::runtime_error("Failed to initialize context");
        if (EVP_PKEY_encrypt_init(ctx.get()) != 1) throw std::runtime_error("EVP_PKEY_encrypt_init failed");
        if (EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_OAEP_PADDING) != 1) throw std::runtime_error("EVP_PKEY_CTX_set_rsa_padding failed");
        if (EVP_PKEY_CTX_set_rsa_oaep_md(ctx.get(), icure::nitrokryptom::rsa::getRsaOaepMd(algorithm)) != 1) throw std::runtime_error("EVP_PKEY_CTX_set_rsa_oaep_md failed");
        size_t outlen = 0;
        if (EVP_PKEY_encrypt(ctx.get(), nullptr, &outlen, owningData->data(), owningData->size()) != 1) throw std::runtime_error("EVP_PKEY_encrypt (get size) failed");
        // Owned by a unique_ptr until the ArrayBuffer exists: no leak on any throw path.
        auto resBuffer = std::make_unique<uint8_t[]>(outlen);
        if (EVP_PKEY_encrypt(ctx.get(), resBuffer.get(), &outlen, owningData->data(), owningData->size()) != 1) throw std::runtime_error("EVP_PKEY_encrypt failed");
        uint8_t* const raw = resBuffer.get();
        auto result = std::make_shared<margelo::nitro::NativeArrayBuffer>(raw, outlen, [raw]() { delete[] raw; });
        static_cast<void>(resBuffer.release()); // ownership now belongs to the ArrayBuffer's deleter
        return result;
    });
}

/**
 * Decrypts RSA-OAEP `data` using a PEM-encoded private key.
 *
 * The key is parsed on the calling thread, so a malformed key throws from decrypt() itself.
 * Non-owning input buffers are copied before being handed to the async task.
 *
 * @param algorithm  selects the OAEP digest (see getRsaOaepMd).
 * @param data       ciphertext bytes.
 * @param privateKey PEM-encoded private key.
 * @return a promise for the plaintext bytes. The backing memory is wiped with OPENSSL_cleanse before it is
 *         freed, since the plaintext may be sensitive.
 * @throws std::runtime_error synchronously if the key cannot be parsed; from inside the async task if any
 *         OpenSSL step fails.
 */
std::shared_ptr<Promise<std::shared_ptr<ArrayBuffer>>> HybridRsa::decrypt(RsaEncryptionAlgorithmSpec algorithm, const std::shared_ptr<ArrayBuffer>& data, const std::shared_ptr<ArrayBuffer>& privateKey) {
    const std::shared_ptr<ArrayBuffer> owningData = data->isOwner() ? data : ArrayBuffer::copy(data);
    const std::shared_ptr<EVP_PKEY> loadedKey = icure::nitrokryptom::rsa::loadPrivateKey(privateKey);
    return Promise<std::shared_ptr<ArrayBuffer>>::async([loadedKey, owningData, algorithm]() -> std::shared_ptr<ArrayBuffer> {
        std::unique_ptr<EVP_PKEY_CTX, decltype(&EVP_PKEY_CTX_free)> ctx(EVP_PKEY_CTX_new(loadedKey.get(), nullptr), EVP_PKEY_CTX_free);
        if (ctx == nullptr) throw std::runtime_error("Failed to initialize context");
        if (EVP_PKEY_decrypt_init(ctx.get()) != 1) throw std::runtime_error("EVP_PKEY_decrypt_init failed");
        if (EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_OAEP_PADDING) != 1) throw std::runtime_error("EVP_PKEY_CTX_set_rsa_padding failed");
        if (EVP_PKEY_CTX_set_rsa_oaep_md(ctx.get(), icure::nitrokryptom::rsa::getRsaOaepMd(algorithm)) != 1) throw std::runtime_error("EVP_PKEY_CTX_set_rsa_oaep_md failed");
        size_t capacity = 0;
        if (EVP_PKEY_decrypt(ctx.get(), nullptr, &capacity, owningData->data(), owningData->size()) != 1) throw std::runtime_error("EVP_PKEY_decrypt (get size) failed");
        // Owned by a unique_ptr until the ArrayBuffer exists: no leak on any throw path.
        auto resBuffer = std::make_unique<uint8_t[]>(capacity);
        // On input outlen must hold the buffer size; on output it holds the actual plaintext length.
        size_t outlen = capacity;
        if (EVP_PKEY_decrypt(ctx.get(), resBuffer.get(), &outlen, owningData->data(), owningData->size()) != 1) {
            OPENSSL_cleanse(resBuffer.get(), capacity);
            throw std::runtime_error("EVP_PKEY_decrypt failed");
        }
        uint8_t* const raw = resBuffer.get();
        // Wipe the whole allocation (not just outlen bytes) before freeing it.
        auto result = std::make_shared<margelo::nitro::NativeArrayBuffer>(raw, outlen, [raw, capacity]() {
            OPENSSL_cleanse(raw, capacity);
            delete[] raw;
        });
        static_cast<void>(resBuffer.release()); // ownership now belongs to the ArrayBuffer's deleter
        return result;
    });
}

void HybridRsa::checkValidPublic(const std::shared_ptr<ArrayBuffer>& publicKey) {
    icure::nitrokryptom::rsa::loadPublicKey(publicKey);
}

void HybridRsa::checkValidPrivate(const std::shared_ptr<ArrayBuffer>& privateKey) {
    icure::nitrokryptom::rsa::loadPrivateKey(privateKey);
}

}
