#include "Conversion.hpp"
#include "HybridAes.hpp"
#include <NitroModules/ArrayBuffer.hpp>
#include <cstring>
#include <openssl/evp.h>
#include <openssl/rand.h>

namespace margelo::nitro::nitrokryptom {

/**
 * Encrypts `data` with AES in CBC mode using PKCS#7 padding.
 *
 * @param data plaintext to encrypt (may be empty).
 * @param key  raw AES key; 16 bytes selects AES-128, 32 bytes selects AES-256.
 * @param iv   optional 16-byte initialization vector; when absent a fresh IV is
 *             generated with RAND_bytes.
 * @return a promise already resolved with `IV || ciphertext`: the 16-byte IV
 *         followed by the padded ciphertext.
 * @throws std::runtime_error on an unsupported key size, an IV that is not
 *         16 bytes long, or any OpenSSL failure.
 */
std::shared_ptr<Promise<std::shared_ptr<ArrayBuffer>>> HybridAes::encrypt(
    const std::shared_ptr<ArrayBuffer>& data,
    const std::shared_ptr<ArrayBuffer>& key,
    const std::optional<std::shared_ptr<ArrayBuffer>>& iv
) {
    // AES block size, which is also the CBC IV size, in bytes.
    constexpr size_t kBlockSize = 16;

    // Key sizes are in bytes: 16 -> AES-128, 32 -> AES-256.
    if (key->size() != 16 && key->size() != 32) throw std::runtime_error("unexpected key size");
    // The IV is copied as exactly one block; a shorter buffer would be over-read.
    if (iv.has_value() && iv.value()->size() != kBlockSize) throw std::runtime_error("unexpected IV size");

    std::unique_ptr<EVP_CIPHER_CTX, decltype(&EVP_CIPHER_CTX_free)> ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free);
    if (!ctx) throw std::runtime_error("Failed to create EVP_CIPHER_CTX");

    // PKCS#7 always appends between 1 and 16 bytes of padding.
    const size_t expectedCipherSize = (data->size() / kBlockSize + 1) * kBlockSize;
    const size_t outputSize = kBlockSize + expectedCipherSize;

    // Owned by a unique_ptr so every early exit (including exceptions thrown by
    // safe_size_to_int) frees it; ownership moves to the ArrayBuffer at the end.
    auto resBuffer = std::make_unique<uint8_t[]>(outputSize);
    uint8_t* const out = resBuffer.get();

    // The output starts with the IV so the decrypt side can recover it.
    if (iv.has_value()) {
        memcpy(out, iv.value()->data(), kBlockSize);
    } else if (RAND_bytes(out, static_cast<int>(kBlockSize)) != 1) {
        throw std::runtime_error("RAND_bytes for IV failed");
    }

    const auto cipher = key->size() == 16 ? EVP_aes_128_cbc() : EVP_aes_256_cbc();
    if (EVP_EncryptInit_ex(ctx.get(), cipher, nullptr, key->data(), out) != 1) {
        throw std::runtime_error("EVP_EncryptInit failed");
    }

    int writtenBytes = 0;
    if (EVP_EncryptUpdate(ctx.get(), out + kBlockSize, &writtenBytes, data->data(), icure::nitrokryptom::conversions::safe_size_to_int(data->size())) != 1) {
        throw std::runtime_error("EVP_EncryptUpdate failed");
    }
    int totalWritten = writtenBytes;

    if (EVP_EncryptFinal_ex(ctx.get(), out + kBlockSize + totalWritten, &writtenBytes) != 1) {
        throw std::runtime_error("EVP_EncryptFinal_ex failed");
    }
    totalWritten += writtenBytes;

    if (static_cast<size_t>(totalWritten) != expectedCipherSize) {
        throw std::runtime_error("Unexpected number of bytes written");
    }

    auto result = std::make_shared<margelo::nitro::NativeArrayBuffer>(out, outputSize, [out]() { delete[] out; });
    // The ArrayBuffer's deleter now owns the allocation.
    static_cast<void>(resBuffer.release());
    return Promise<std::shared_ptr<ArrayBuffer>>::resolved(result);
}

/**
 * Decrypts data produced by HybridAes::encrypt: AES in CBC mode with PKCS#7 padding.
 *
 * @param ivAndEncryptedData the 16-byte IV followed by the ciphertext.
 * @param key raw AES key; 16 bytes selects AES-128, 32 bytes selects AES-256.
 * @return a promise already resolved with the unpadded plaintext.
 * @throws std::runtime_error on an unsupported key size, input too short to
 *         hold an IV, or any OpenSSL failure (including bad padding, which
 *         EVP_DecryptFinal_ex reports).
 */
std::shared_ptr<Promise<std::shared_ptr<ArrayBuffer>>> HybridAes::decrypt(const std::shared_ptr<ArrayBuffer>& ivAndEncryptedData, const std::shared_ptr<ArrayBuffer>& key) {
    // AES block size, which is also the CBC IV size, in bytes.
    constexpr size_t kBlockSize = 16;

    // Key sizes are in bytes: 16 -> AES-128, 32 -> AES-256.
    if (key->size() != 16 && key->size() != 32) throw std::runtime_error("unexpected key size");

    const size_t inputSize = ivAndEncryptedData->size();
    // Without this check the IV read and the ciphertext offset would run past the buffer.
    if (inputSize < kBlockSize) throw std::runtime_error("input too short to contain an IV");

    std::unique_ptr<EVP_CIPHER_CTX, decltype(&EVP_CIPHER_CTX_free)> ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free);
    if (!ctx) throw std::runtime_error("Failed to create EVP_CIPHER_CTX");

    const uint8_t* const input = ivAndEncryptedData->data();
    const int cipherLength = icure::nitrokryptom::conversions::safe_size_to_int(inputSize - kBlockSize);

    // EVP_DecryptUpdate needs room for inl + block_size bytes, which equals the
    // full input size. The buffer is owned by a unique_ptr so every early exit
    // frees it; ownership moves to the ArrayBuffer at the end.
    auto resBuffer = std::make_unique<uint8_t[]>(inputSize);
    uint8_t* const out = resBuffer.get();

    const auto cipher = key->size() == 16 ? EVP_aes_128_cbc() : EVP_aes_256_cbc();
    if (EVP_DecryptInit_ex(ctx.get(), cipher, nullptr, key->data(), input) != 1) {
        throw std::runtime_error("EVP_DecryptInit_ex failed");
    }

    int writtenBytes = 0;
    if (EVP_DecryptUpdate(ctx.get(), out, &writtenBytes, input + kBlockSize, cipherLength) != 1) {
        throw std::runtime_error("EVP_DecryptUpdate failed");
    }
    int totalWritten = writtenBytes;

    if (EVP_DecryptFinal_ex(ctx.get(), out + totalWritten, &writtenBytes) != 1) {
        throw std::runtime_error("EVP_DecryptFinal_ex failed");
    }
    totalWritten += writtenBytes;

    auto result = std::make_shared<margelo::nitro::NativeArrayBuffer>(out, static_cast<size_t>(totalWritten), [out]() { delete[] out; });
    // The ArrayBuffer's deleter now owns the allocation.
    static_cast<void>(resBuffer.release());
    return Promise<std::shared_ptr<ArrayBuffer>>::resolved(result);
}

}
