#include "ota_update_script.h"

#include <yandex_io/libs/device/device.h>

#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/pem.h>
#include <openssl/rsa.h>
#include <openssl/sha.h>

#include <sys/stat.h>
#include <unistd.h>

#include <cerrno>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// NOTE(review): presumably registers "updates" as the log module used by the
// YIO_LOG_* macros in this file — confirm against the logging library.
YIO_DEFINE_LOG_MODULE("updates");

using namespace quasar;

/**
 * Creates an OTA update script wrapper and loads the RSA public key that
 * verify() uses to check the script signature.
 *
 * The key path is taken from the "otaScriptPublicKey" field of the "updatesd"
 * service config of @p device.
 *
 * @param device          device whose configuration provides the key path
 * @param script          script text to verify and execute
 * @param signature       signature (passed through base64Decode() by verify)
 *                        over script + std::to_string(scriptTimestamp)
 * @param scriptTimestamp timestamp bound into the signed message
 * @param quasarTimestamp verify() rejects the script if scriptTimestamp is
 *                        smaller than this value
 * @param type            selects how execute() runs the script
 * @throws std::runtime_error if "otaScriptPublicKey" is missing from the
 *         config or the key cannot be read
 *
 * NOTE(review): the signed message is script + decimal timestamp with no
 * separator, so trailing digits of a script and leading digits of the
 * timestamp are interchangeable ("x1" + "23" == "x" + "123"). A signed script
 * ending in digits could thus be replayed as a shorter script with a larger
 * timestamp. Fixing this requires changing the signing format.
 */
OtaUpdateScript::OtaUpdateScript(std::shared_ptr<YandexIO::IDevice> device,
                                 const std::string& script,
                                 const std::string& signature, size_t scriptTimestamp,
                                 size_t quasarTimestamp, Type type)
    : script_(script)
    , signature_(signature)
    , scriptTimestamp_(scriptTimestamp)
    , quasarTimestamp_(quasarTimestamp)
    , scriptType_(type)
{
    // Build the signed message before acquiring the raw RSA key: if anything
    // threw after publicKey_ took ownership, the destructor would not run
    // (the object is not fully constructed) and the key would leak.
    scriptWithTimestamp_ = script_ + std::to_string(scriptTimestamp_);

    const auto& updatesdConfig = device->configuration()->getServiceConfig("updatesd");
    if (!updatesdConfig.isMember("otaScriptPublicKey")) {
        throw std::runtime_error("otaScriptPublicKey is missing from"
                                 " updatesd config");
    }
    const std::string publicKeyPath = updatesdConfig["otaScriptPublicKey"].asString();

    // Last step of construction: nothing after this assignment may throw
    // while holding the key.
    publicKey_ = getPublicKey(publicKeyPath);
    if (!publicKey_) {
        throw std::runtime_error("Error reading OTA script public key");
    }
}

/**
 * Releases the RSA public key loaded in the constructor, if one is held.
 *
 * NOTE(review): this class owns a raw RSA*. Unless the header deletes the copy
 * operations, copying an instance would lead to a double RSA_free — confirm.
 */
OtaUpdateScript::~OtaUpdateScript() {
    if (publicKey_ == nullptr) {
        return;
    }
    RSA_free(publicKey_);
}

/**
 * Checks that the script is authentic and not outdated.
 *
 * The signature is checked first; the timestamp is only checked when the
 * signature is valid. Each failure is logged with its own event.
 *
 * @return true only if both the signature and the timestamp checks pass.
 * @throws std::runtime_error propagated from checkSignature() when the
 *         signature cannot be decrypted.
 */
bool OtaUpdateScript::verify() const {
    if (checkSignature()) {
        if (checkTimestamp()) {
            return true;
        }
        YIO_LOG_ERROR_EVENT("OtaUpdateScript.BadTimestamp", "Script timestamp verification failed");
        return false;
    }

    YIO_LOG_ERROR_EVENT("OtaUpdateScript.BadSignature", "Signature verification failed");
    return false;
}

/**
 * Runs the script in the mode selected at construction.
 *
 * T_STORE_EXEC writes the script to a file and runs that file (storeExec());
 * T_EXEC hands the script text to the shell directly (exec()). Any other type
 * is rejected.
 *
 * NOTE(review): this does not call verify(); presumably callers verify first.
 *
 * @return the result of the selected runner, or false for an unknown type.
 */
bool OtaUpdateScript::execute() const {
    switch (scriptType_) {
        case Type::T_STORE_EXEC:
            return storeExec();
        case Type::T_EXEC:
            return exec();
        default:
            return false;
    }
}

/**
 * Verifies signature_ against script_ + scriptTimestamp_ using the loaded
 * public key.
 *
 * @return true if the RSA-PSS signature is valid.
 * @throws std::runtime_error if the signature cannot be decrypted.
 */
bool OtaUpdateScript::checkSignature() const {
    const bool valid = OtaUpdateScript::verify(scriptWithTimestamp_, signature_, publicKey_);
    return valid;
}

/**
 * Rejects scripts signed with a timestamp older than quasarTimestamp_.
 *
 * @return true when quasarTimestamp_ <= scriptTimestamp_ (equal is accepted).
 */
bool OtaUpdateScript::checkTimestamp() const {
    return quasarTimestamp_ <= scriptTimestamp_;
}

bool OtaUpdateScript::storeExec() const {
    std::ofstream updateScript(SCRIPT_STORE_PATH);
    updateScript << script_;
    updateScript.close();

    ::chmod(SCRIPT_STORE_PATH, S_IRWXU);

    int ret = ::system(SCRIPT_STORE_PATH);
    if (ret < 0) {
        YIO_LOG_ERROR_EVENT("OtaUpdateScript.FailedStoreExec", "Error executing update script. Retval = " << ret);
    }

    ::unlink(SCRIPT_STORE_PATH);

    return (ret >= 0);
}

bool OtaUpdateScript::exec() const {
    int ret = ::system(script_.c_str());
    if (ret < 0) {
        YIO_LOG_ERROR_EVENT("OtaUpdateScript.FailedExec", "Error executing update script. Retval = " << ret);
        return false;
    }

    return true;
}

/**
 * Loads a PEM-encoded RSA public key from @p path.
 *
 * The key is parsed with PEM_read_bio_RSAPublicKey(), i.e. the PKCS#1
 * "RSA PUBLIC KEY" PEM form.
 *
 * @param path file whose contents (read via getFileContent) hold the key
 * @return a newly allocated RSA object; the caller must RSA_free() it.
 * @throws std::runtime_error if the memory BIO cannot be allocated or the key
 *         cannot be parsed (message includes the OpenSSL error string).
 */
RSA* OtaUpdateScript::getPublicKey(const std::string& path) {
    const std::string encryptionKey(getFileContent(path));

    // The BIO is owned here so it is freed on every exit path, including throws.
    // const_cast keeps compatibility with OpenSSL < 1.1, where the buffer
    // parameter is a non-const void*; the BIO is read-only either way.
    std::unique_ptr<BIO, decltype(&BIO_free)> bufio(
        BIO_new_mem_buf(const_cast<char*>(encryptionKey.c_str()),
                        static_cast<int>(encryptionKey.length())),
        &BIO_free);
    if (!bufio) {
        throw std::runtime_error("Error allocating memory for BIO");
    }

    // ERR_get_error() reports the oldest queued error; drop stale entries so
    // the message below describes this read.
    ERR_clear_error();
    RSA* result = PEM_read_bio_RSAPublicKey(bufio.get(), nullptr, nullptr, nullptr);
    if (!result) {
        throw std::runtime_error("Cannot read encryption key: " + getErrorMessage());
    }

    return result;
}

bool OtaUpdateScript::verify(const std::string& data, const std::string& sign,
                             rsa_st* pubKey) {
    const std::string decodedSignature = base64Decode(sign);
    const int64_t publicKeyLen = RSA_size(pubKey);
    std::vector<uint8_t> signDecryptBuf(publicKeyLen);

    int ret = RSA_public_decrypt(decodedSignature.size(),
                                 (const unsigned char*)decodedSignature.data(),
                                 signDecryptBuf.data(), pubKey, RSA_NO_PADDING);
    if (ret != publicKeyLen) {
        throw std::runtime_error("Error decrypting signature: " + getErrorMessage());
    }

    unsigned char scriptHashed[64];
    {
        SHA256_CTX ctx;
        SHA256_Init(&ctx);
        SHA256_Update(&ctx, data.c_str(), data.length());
        SHA256_Final(scriptHashed, &ctx);
    }

    ret = RSA_verify_PKCS1_PSS_mgf1(pubKey,
                                    (const unsigned char*)scriptHashed,
                                    EVP_sha256(), EVP_sha256(), signDecryptBuf.data(), -2);

    return (ret == 1);
}

std::string OtaUpdateScript::getErrorMessage() {
    ERR_load_crypto_strings();
    char buf[256];
    return ERR_error_string(ERR_get_error(), buf);
}
