#include "mysql_driver.h"

#include <passport/infra/libs/cpp/dbpool2/misc/utils.h>

#include <passport/infra/libs/cpp/dbpool/exception.h>
#include <passport/infra/libs/cpp/utils/string/string_utils.h>

#include <util/string/builder.h>
#include <util/string/cast.h>
#include <util/thread/singleton.h>

namespace NPassport::NDbPool2 {
    /// Reports whether a MySQL column type code is one of the four BLOB field types
    /// (tiny, medium, long or plain blob).
    static bool IsBlob(int mt) {
        const bool isBlobType = mt == FIELD_TYPE_TINY_BLOB
                             || mt == FIELD_TYPE_MEDIUM_BLOB
                             || mt == FIELD_TYPE_LONG_BLOB
                             || mt == FIELD_TYPE_BLOB;
        return isBlobType;
    }

    /// Returns the client library's last error message for @p handle,
    /// or "<no_error>" when the library hands back a null pointer.
    static const char* GetMysqlError(MYSQL& handle) {
        if (const char* message = mysql_error(&handle)) {
            return message;
        }
        return "<no_error>";
    }

    /// Calls mysql_thread_init() on construction and mysql_thread_end() on destruction, so the
    /// thread that owns the object stays registered with the MySQL client library while it lives.
    /// TMysqlDriver::InitMysql() creates one instance per thread through FastTlsSingleton.
    class TMySqlThreadHolder {
    public:
        /// @throws yexception if mysql_thread_init() returns a non-zero value.
        TMySqlThreadHolder() {
            Y_ENSURE(0 == mysql_thread_init());
        }

        ~TMySqlThreadHolder() {
            mysql_thread_end();
        }
    };

    /// Value-initializes the embedded MYSQL struct and initializes it in place with mysql_init().
    /// @throws yexception if mysql_init() returns NULL.
    TMysqlDriver::TMySqlHolder::TMySqlHolder()
        : MYSQL()
    {
        if (!mysql_init(this)) {
            // The destructor does not run for an object whose constructor throws, so the handle
            // is released by this scope guard instead. The throw operand (including the error
            // text read from the handle) is evaluated before unwinding destroys the guard, so
            // mysql_close() runs only after the message has been captured.
            struct TCloser {
                MYSQL* Handle;
                ~TCloser() {
                    mysql_close(Handle);
                }
            } closer_ = {this};

            throw yexception() << "mysql_init() failed: " << GetMysqlError(*this);
        }
    }

    /// Closes the handle, releasing whatever the client library associated with it.
    TMysqlDriver::TMySqlHolder::~TMySqlHolder() {
        MYSQL* const handle = this;
        mysql_close(handle);
    }

    /// Initializes LibHolder_ from the process-wide NDbPool::TMySqlLibHolder singleton.
    // NOTE(review): presumably this keeps the MySQL client library initialized for as long as
    // the driver lives - confirm against TMySqlLibHolder.
    TMysqlDriver::TMysqlDriver()
        : LibHolder_(Singleton<NDbPool::TMySqlLibHolder>()->Get())
    {
    }

    // Defaulted: members such as Handle_ and PerQuery_ release their resources in their own destructors.
    TMysqlDriver::~TMysqlDriver() = default;

    void TMysqlDriver::OnPollEvent() {
        try {
            InitMysql();

            while (Proccess() == EProcessResult::ContinueCycle) {
            }
        } catch (const std::exception& e) {
            PerQuery_.Promise.SetValue(GetErrorInfo(e.what()));
            SetState(EState::Error);
        }
    }

    /// Returns the file descriptor of the underlying MySQL connection.
    /// @throws yexception if Init() has not been called yet.
    SOCKET TMysqlDriver::GetSocket() const {
        const bool inited = State_ != EState::NotInited;
        Y_ENSURE(inited, "driver cannot return socket: " << State_);

        const SOCKET fd = Handle_.net.fd;
        return fd;
    }

    static const TString SSL_MODE = "ssl_mode";
    static const std::map<TString, mysql_ssl_mode> SSL_MODES = {
        {"SSL_MODE_DISABLED", SSL_MODE_DISABLED},
        {"SSL_MODE_PREFERRED", SSL_MODE_PREFERRED},
        {"SSL_MODE_REQUIRED", SSL_MODE_REQUIRED},
        {"SSL_MODE_VERIFY_CA", SSL_MODE_VERIFY_CA},
        {"SSL_MODE_VERIFY_IDENTITY", SSL_MODE_VERIFY_IDENTITY},
    };

    /// Resolves the TLS mode from the "ssl_mode" DSN extension.
    /// Falls back to SSL_MODE_VERIFY_IDENTITY when the extension is absent.
    /// @throws yexception if the value is not one of the names listed in SSL_MODES.
    static mysql_ssl_mode GetSslMode(const NDbPool::IDriver::TExtendedArgs& ext) {
        const auto arg = ext.find(SSL_MODE);
        if (arg != ext.end()) {
            const auto mode = SSL_MODES.find(arg->second);
            Y_ENSURE(mode != SSL_MODES.end(),
                     "'ssl_mode' is not supported: " << arg->second);
            return mode->second;
        }

        return SSL_MODE_VERIFY_IDENTITY;
    }

    static const TString SSL_CA = "ssl_ca";

    /// Returns the CA bundle path from the "ssl_ca" DSN extension, or
    /// "/etc/ssl/certs/ca-certificates.crt" when the extension is absent.
    /// The pointer refers either to a string literal or to storage owned by @p ext, so it is only
    /// valid while @p ext is alive and unmodified.
    static const char* GetSslCa(const NDbPool::IDriver::TExtendedArgs& ext) {
        const auto found = ext.find(SSL_CA);
        return found != ext.end() ? found->second.c_str() : "/etc/ssl/certs/ca-certificates.crt";
    }

    /// Stores @p settings, builds the serialized destination string and configures the client
    /// options on the handle: the connect timeout, the TLS mode (from the "ssl_mode" extension)
    /// and the CA bundle (from the "ssl_ca" extension).
    ///
    /// Must be called once, while the driver is NotInited; on success the state becomes Inited,
    /// which StartConnecting() requires.
    /// @throws yexception if the driver is already initialized, the "ssl_mode" value is not
    ///         supported, or the client library rejects one of the options.
    void TMysqlDriver::Init(const TDriverSettings& settings) {
        Y_ENSURE(EState::NotInited == State_, "driver cannot be inited: " << State_);
        InitMysql();

        Settings_ = settings;

        Settings_.SerializedDestination = NUtils::CreateStr(
            "db2;"
            "driver=mysql;"
            "host=",
            Settings_.Host, ":", Settings_.Port, ";",
            "db=", Settings_.Dsn->Db, ";");

        // mysql_options() returns non-zero when it rejects an option. Ignoring that could, e.g.,
        // leave the library's default TLS mode in place instead of the requested one.
        unsigned int connTimeout = 1; // seconds
        Y_ENSURE(0 == mysql_options(&Handle_, MYSQL_OPT_CONNECT_TIMEOUT, &connTimeout),
                 "failed to set MYSQL_OPT_CONNECT_TIMEOUT: " << Settings_.SerializedDestination);

        // MYSQL_OPT_SSL_MODE takes a pointer to unsigned int rather than to the enum type.
        unsigned int sslMode = GetSslMode(Settings_.Dsn->Extended);
        Y_ENSURE(0 == mysql_options(&Handle_, MYSQL_OPT_SSL_MODE, &sslMode),
                 "failed to set MYSQL_OPT_SSL_MODE: " << Settings_.SerializedDestination);

        Y_ENSURE(0 == mysql_options(&Handle_, MYSQL_OPT_SSL_CA, GetSslCa(Settings_.Dsn->Extended)),
                 "failed to set MYSQL_OPT_SSL_CA: " << Settings_.SerializedDestination);

        SetState(EState::Inited);
    }

    /// Begins connecting to the server configured by Init().
    ///
    /// The first connection attempt runs immediately. If it cannot finish, Proccess() resumes it
    /// on later poll events.
    /// @return future that receives a value once connected (the bool itself carries no meaning)
    ///         or a TDriverError.
    /// @throws yexception if the driver is not in the Inited state; errors raised by the first
    ///         attempt propagate to the caller.
    IDriver::TResultFuture<bool> TMysqlDriver::StartConnecting() {
        InitMysql();

        Y_ENSURE(EState::Inited == State_);
        SetState(EState::Connecting);

        auto& promise = PerQuery_.PromiseConnecting;
        promise = NThreading::NewPromise<TErrorOr<bool>>();
        Connect();

        return promise.GetFuture();
    }

    /// True only in the Connected state, i.e. connected with no query in flight.
    bool TMysqlDriver::IsReadyForQuery() const noexcept {
        const bool connected = State_ == EState::Connected;
        return connected;
    }

    /// Identifies this connection: the serialized destination string plus this driver's id.
    TDriverDestination TMysqlDriver::GetDestination() const {
        return TDriverDestination{Settings_.SerializedDestination, GetId()};
    }

    static const TString SELECT_1 = "SELECT 1";

    /// Checks the connection by running "SELECT 1".
    /// @return future with the query's TDriverError on failure, or `true` on success
    ///         (the returned rows are discarded).
    IDriver::TResultFuture<bool> TMysqlDriver::StartPinging() {
        auto toPingResult = [](const TResultType& queryFuture) -> IDriver::TResultFuture<bool> {
            const TErrorOr<NDbPool::TTable>& outcome = queryFuture.GetValue();
            const TDriverError* const failure = std::get_if<TDriverError>(&outcome);

            return failure
                       ? NThreading::MakeFuture<TErrorOr<bool>>(*failure)
                       : NThreading::MakeFuture<TErrorOr<bool>>(true);
        };

        return StartSendingQuery(SELECT_1).Apply(toPingResult);
    }

    /// Escapes @p str with mysql_real_escape_string_quote(), using a single quote as the quote
    /// character, so the result can be placed inside a '...' SQL literal.
    /// @throws yexception if the client library reports an escaping error.
    TString TMysqlDriver::EscapeQueryParam(const TString& str) const {
        // Worst case every input byte becomes two bytes; the library also writes a terminating NUL.
        TString res(str.length() * 2 + 1, 0);

        // The C API takes a non-const handle while this method is const.
        MYSQL* handle = const_cast<MYSQL*>(static_cast<const MYSQL*>(&Handle_));
        const unsigned long len = mysql_real_escape_string_quote(
            handle, const_cast<char*>(res.data()), str.data(), str.length(), '\'');

        // A successful result always fits the buffer. The documented error value
        // ((unsigned long)-1) does not, so this rejects errors instead of resizing to a huge length.
        Y_ENSURE(len < res.size(), "failed to escape query param: " << Settings_.SerializedDestination);
        res.resize(len);

        return res;
    }

    /// Starts executing @p query on a connected, idle driver.
    ///
    /// Steps that can finish without waiting on the socket run right away through OnPollEvent();
    /// the remaining ones continue on later poll events.
    /// @return future with the resulting table or a TDriverError.
    /// @throws yexception if the driver is not Connected or the query text is empty.
    IDriver::TResultType TMysqlDriver::StartSendingQuery(NDbPool::TQuery&& query) {
        InitMysql();

        Y_ENSURE(EState::Connected == State_);
        Y_ENSURE(query.Query(), "query cannot be empty: " << Settings_.SerializedDestination << ". id=" << GetId());

        SetState(EState::SendingQuery);
        PerQuery_ = TPerQuery(std::move(query));

        auto& promise = PerQuery_.Promise;
        promise = NThreading::NewPromise<TErrorOr<NDbPool::TTable>>();
        auto future = promise.GetFuture();

        // sendQuery() + storeResult() + ... run immediately if the result is already available
        OnPollEvent();

        return future;
    }

    /// Runs one step of the state machine for the current State_.
    ///
    /// @return ContinueCycle when the next step may be attempted immediately; BackToPoller when
    ///         the driver has to wait for the next poll event or has nothing left to do.
    /// @throws yexception if State_ holds a value that is not an EState enumerator; errors raised
    ///         by the individual steps propagate unchanged.
    TMysqlDriver::EProcessResult TMysqlDriver::Proccess() {
        auto toEnum = [](bool isSuccess) -> EProcessResult {
            return isSuccess ? EProcessResult::ContinueCycle : EProcessResult::BackToPoller;
        };

        switch (State_) {
            case EState::Connecting:
                return toEnum(Connect());
            case EState::SendingQuery:
                return toEnum(SendQuery());
            case EState::StoringResult:
                return toEnum(StoreResult());
            case EState::FetchingRows:
                return toEnum(FetchRows());
            case EState::FreeingResult:
                // The mapping here is the opposite of toEnum(). Once freeing completes, the query
                // is finished (promise resolved, state Connected), so go back to the poller. While
                // it is unfinished, retry immediately instead of waiting for a poll event.
                // NOTE(review): presumably because mysql_free_result_nonblocking() can report
                // NOT_READY without waiting on the socket - confirm.
                return FreeResult() ? EProcessResult::BackToPoller : EProcessResult::ContinueCycle;
            case EState::Inited:
            case EState::NotInited:
            case EState::Connected:
            case EState::Error:
                return EProcessResult::BackToPoller;
        }

        // Every enumerator returns above. Without this, a State_ outside EState would fall off
        // the end of a non-void function, which is undefined behavior.
        ythrow yexception() << "unexpected driver state: " << static_cast<int>(State_);
    }

    bool TMysqlDriver::Connect() {
        Y_ENSURE(EState::Connecting == State_, "driver cannot connect: " << State_);

        // connection is still (blocking) synchronous - it is weird but ok
        // TODO: check it with mysql server 8.0
        const net_async_status status = mysql_real_connect_nonblocking(
            &Handle_,
            Settings_.Host.c_str(),
            Settings_.Dsn->User.c_str(),
            TZtStringBuf(Settings_.Dsn->Passwd).c_str(),
            Settings_.Dsn->Db.c_str(),
            Settings_.Port,
            nullptr,
            0);

        if (!IsReady(status)) {
            return false;
        }

        SetState(EState::Connected);
        // boolean value doesn't matter: future should have value
        PerQuery_.PromiseConnecting.SetValue(true);
        return true;
    }

    /// Drives mysql_real_query_nonblocking() one step for the current query text.
    /// @return true once the call reports completion (the state becomes StoringResult);
    ///         false if it must be resumed on a later poll event.
    bool TMysqlDriver::SendQuery() {
        Y_ENSURE(EState::SendingQuery == State_, "driver cannot send query: " << State_);

        // The first call puts the query buffer into the socket's output queue;
        // mysql_real_query_nonblocking() may report a status before the query is completely sent.
        const auto& text = PerQuery_.Query.Query();
        const bool sent = IsReady(mysql_real_query_nonblocking(&Handle_, text.data(), text.size()));
        if (sent) {
            SetState(EState::StoringResult);
        }

        return sent;
    }

    /// Drives mysql_store_result_nonblocking() one step.
    /// @return true once finished. With a result set, row storage is reserved and the state
    ///         becomes FetchingRows; without one, the state becomes FreeingResult. Returns false
    ///         if the call must be resumed on a later poll event.
    bool TMysqlDriver::StoreResult() {
        Y_ENSURE(EState::StoringResult == State_, "driver cannot store result: " << State_);

        const bool stored = IsReady(mysql_store_result_nonblocking(&Handle_, &PerQuery_.MysqlResult));
        if (!stored) {
            return false;
        }

        if (!PerQuery_.MysqlResult) {
            // Got nothing from server
            Y_ENSURE(mysql_errno(&Handle_) == 0); // should always be ok - some paranoia
            SetState(EState::FreeingResult);
        } else {
            PerQuery_.Result.reserve(mysql_num_rows(PerQuery_.MysqlResult));
            SetState(EState::FetchingRows);
        }

        return true;
    }

    /// Drives mysql_fetch_row_nonblocking() one step.
    /// @return true after a row was appended to the result, or when no row is returned (the state
    ///         then becomes FreeingResult). Returns false if the call must be resumed on a later
    ///         poll event.
    bool TMysqlDriver::FetchRows() {
        Y_ENSURE(EState::FetchingRows == State_, "driver cannot fetch row: " << State_);

        if (IsReady(mysql_fetch_row_nonblocking(PerQuery_.MysqlResult, &PerQuery_.MysqlRow))) {
            if (!PerQuery_.MysqlRow) {
                SetState(EState::FreeingResult);
            } else {
                PerQuery_.Result.emplace_back(BuildRow());
            }
            return true;
        }

        return false;
    }

    /// Drives mysql_free_result_nonblocking() one step.
    /// @return true once the result set is released. The collected rows are then moved into the
    ///         query future and the driver becomes Connected again. Returns false if not finished.
    bool TMysqlDriver::FreeResult() {
        Y_ENSURE(EState::FreeingResult == State_, "driver cannot free result: " << State_);

        const bool freed = IsReady(mysql_free_result_nonblocking(PerQuery_.MysqlResult));
        if (freed) {
            // Cleared so that ~TPerQuery() does not free the result a second time.
            PerQuery_.MysqlResult = nullptr;

            PerQuery_.Promise.SetValue(std::move(PerQuery_.Result));
            SetState(EState::Connected);
        }

        return freed;
    }

    static const TString NULL_STR = "<NULL>";

    /// Converts the current row (PerQuery_.MysqlRow) into an NDbPool::TRow.
    ///
    /// Each cell is constructed from three arguments:
    /// - the value, built with the reported length so embedded zero bytes are kept; NULL columns
    ///   get the "<NULL>" placeholder;
    /// - whether the column is NULL;
    /// - whether the column type is a BLOB type.
    /// @throws yexception if there is no current result/row, or if the client library does not
    ///         provide the field descriptions or the value lengths.
    NDbPool::TRow TMysqlDriver::BuildRow() const {
        Y_ENSURE(PerQuery_.MysqlResult && PerQuery_.MysqlRow);

        MYSQL_FIELD* flds = mysql_fetch_fields(PerQuery_.MysqlResult);
        Y_ENSURE(flds);

        // mysql_fetch_lengths() returns NULL on error; indexing it unchecked would dereference NULL.
        unsigned long* lengths = mysql_fetch_lengths(PerQuery_.MysqlResult);
        Y_ENSURE(lengths, "mysql_fetch_lengths() failed: " << Settings_.SerializedDestination);

        const size_t numFields = mysql_num_fields(PerQuery_.MysqlResult);

        NDbPool::TRow res;
        res.reserve(numFields);

        for (size_t idx = 0; idx < numFields; ++idx) {
            const char* value = PerQuery_.MysqlRow[idx];
            res.emplace_back(value ? TString(value, lengths[idx]) : TString(NULL_STR),
                             !value,
                             IsBlob(flds[idx].type));
        }

        return res;
    }

    /// Makes sure the calling thread is registered with the MySQL client library.
    /// FastTlsSingleton creates one TMySqlThreadHolder per thread on first use.
    void TMysqlDriver::InitMysql() {
        (void)FastTlsSingleton<TMySqlThreadHolder>();
    }

    /// Maps the status of a non-blocking client call to "finished" (true) or "resume later" (false).
    /// @throws NDbPool::TException on NET_ASYNC_ERROR, on NET_ASYNC_COMPLETE_NO_MORE_RESULTS
    ///         (treated as impossible for the calls this driver makes), and on any unknown status.
    bool TMysqlDriver::IsReady(net_async_status status) const {
        switch (status) {
            case NET_ASYNC_COMPLETE:
                return true;
            case NET_ASYNC_NOT_READY:
                return false;
            case NET_ASYNC_ERROR:
                ythrow NDbPool::TException(Settings_.SerializedDestination) << "got NET_ASYNC_ERROR";
            case NET_ASYNC_COMPLETE_NO_MORE_RESULTS:
                ythrow NDbPool::TException(Settings_.SerializedDestination) << "Impossible result: NET_ASYNC_COMPLETE_NO_MORE_RESULTS";
        }

        // Keeps the function from falling off the end (undefined behavior) if the client library
        // ever returns a status that is not listed above.
        ythrow NDbPool::TException(Settings_.SerializedDestination) << "unexpected net_async_status: " << static_cast<int>(status);
    }

    /// Switches to @p state and remembers the state being left in PreviousState_;
    /// GetErrorInfo() reports both.
    void TMysqlDriver::SetState(TMysqlDriver::EState state) {
        const EState current = State_;
        State_ = state;
        PreviousState_ = current;
    }

    /// Builds a TDriverError for this connection. The message combines the current and previous
    /// state, @p what, and the client library's last error message; the code is the library's
    /// last error number.
    TDriverError TMysqlDriver::GetErrorInfo(TStringBuf what) {
        InitMysql();

        TStringBuilder message;
        message << "state=" << State_
                << ". prev_state=" << PreviousState_
                << ". last_error=" << what
                << ". mysql_error:" << GetMysqlError(Handle_);
        const int code = static_cast<int>(mysql_errno(&Handle_));

        return TDriverError{
            TDriverDestination{
                Settings_.SerializedDestination,
            },
            std::move(message),
            code,
        };
    }

    /// Takes ownership of @p query; every other member is left to its declared initializer.
    TMysqlDriver::TPerQuery::TPerQuery(NDbPool::TQuery&& query)
        : Query(std::move(query))
    {
    }

    /// Frees a result set that FreeResult() did not release (it clears MysqlResult when it does),
    /// e.g. when a query is abandoned part-way.
    TMysqlDriver::TPerQuery::~TPerQuery() {
        if (MysqlResult != nullptr) {
            mysql_free_result(MysqlResult);
        }
    }
}
