#include "mysql-driver.h"

#include "type.h"

#include <passport/infra/libs/cpp/dbpool/exception.h>
#include <passport/infra/libs/cpp/dbpool/query.h>
#include <passport/infra/libs/cpp/dbpool/result.h>

#include <passport/infra/libs/cpp/utils/string/string_utils.h>

#include <contrib/libs/libmysql_r/include/mysql.h>

#include <util/generic/singleton.h>
#include <util/generic/string.h>

namespace NPassport::NDbPool {
    /// Maps a MySQL C API column type code (MYSQL_FIELD::type) onto the
    /// driver-independent ESqlType.
    ///
    /// @param mt  column type code reported by the client library.
    /// @return    the matching ESqlType, or ESqlType::INVALID for any code
    ///            that is not listed in the switch below.
    static ESqlType MysqlTypeConvert(int mt) {
        switch (mt) {
            // MYSQL_TYPE_DECIMAL is the legacy code; since MySQL 5.0.3
            // DECIMAL/NUMERIC columns are reported as MYSQL_TYPE_NEWDECIMAL,
            // which previously fell through to INVALID.
            case FIELD_TYPE_DECIMAL:
            case MYSQL_TYPE_NEWDECIMAL:
                return ESqlType::DECIMAL;
            case FIELD_TYPE_TINY:
                return ESqlType::TINYINT;
            case FIELD_TYPE_SHORT:
                return ESqlType::SMALLINT;
            case FIELD_TYPE_LONG:
            case FIELD_TYPE_INT24:
                return ESqlType::INTEGER;
            case FIELD_TYPE_LONGLONG:
                return ESqlType::BIGINT;
            case FIELD_TYPE_FLOAT:
                return ESqlType::FLOAT;
            case FIELD_TYPE_DOUBLE:
                return ESqlType::DOUBLE;
            case FIELD_TYPE_TIMESTAMP:
            case FIELD_TYPE_DATETIME:
            case FIELD_TYPE_NEWDATE:
                return ESqlType::DATETIME;
            case FIELD_TYPE_YEAR:
            case FIELD_TYPE_DATE:
                return ESqlType::DATE;
            case FIELD_TYPE_TIME:
                return ESqlType::TIME;
            case FIELD_TYPE_ENUM:
            case FIELD_TYPE_SET:
            case FIELD_TYPE_STRING:
                return ESqlType::CHAR;
            // NOTE(review): TEXT columns are also reported with the BLOB codes.
            case FIELD_TYPE_TINY_BLOB:
            case FIELD_TYPE_MEDIUM_BLOB:
            case FIELD_TYPE_LONG_BLOB:
            case FIELD_TYPE_BLOB:
                return ESqlType::BLOB;
            case FIELD_TYPE_VAR_STRING:
                return ESqlType::VARCHAR;

            default:
                return ESqlType::INVALID;
        }
    }

    /// Constructs the driver, initialising Lib_ from Get() on the
    /// TMySqlLibHolder singleton.
    // NOTE(review): what Get() returns and any library initialisation it
    // performs are defined in TMySqlLibHolder, not here — confirm there.
    TMysqlDriver::TMysqlDriver()
        : Lib_(Singleton<TMySqlLibHolder>()->Get())
    {
    }

    // Extended-argument key selecting the TLS mode, and the accepted values:
    // each value is the literal name of a mysql_ssl_mode enumerator.
    static const TString SSL_MODE = "ssl_mode";
    static const std::map<TString, mysql_ssl_mode> SSL_MODES = {
        {"SSL_MODE_DISABLED", SSL_MODE_DISABLED},
        {"SSL_MODE_PREFERRED", SSL_MODE_PREFERRED},
        {"SSL_MODE_REQUIRED", SSL_MODE_REQUIRED},
        {"SSL_MODE_VERIFY_CA", SSL_MODE_VERIFY_CA},
        {"SSL_MODE_VERIFY_IDENTITY", SSL_MODE_VERIFY_IDENTITY},
    };

    /// Resolves the TLS mode for a connection from the extended arguments.
    ///
    /// @param ext  extended arguments; only the "ssl_mode" key is read.
    /// @return     the mode named by "ssl_mode", or SSL_MODE_VERIFY_CA when
    ///             the key is absent.
    /// @throws     via Y_ENSURE when the value is not a key of SSL_MODES.
    static mysql_ssl_mode GetSslMode(const IDriver::TExtendedArgs& ext) {
        const auto requested = ext.find(SSL_MODE);
        if (requested != ext.end()) {
            const auto known = SSL_MODES.find(requested->second);
            Y_ENSURE(known != SSL_MODES.end(),
                     "'ssl_mode' is not supported: " << requested->second);
            return known->second;
        }

        // Nothing configured: verify the server certificate against the CA.
        return SSL_MODE_VERIFY_CA;
    }

    // Extended-argument key overriding the CA bundle path.
    static const TString SSL_CA = "ssl_ca";

    /// Returns the CA bundle path used for server certificate verification.
    ///
    /// @param ext  extended arguments; only the "ssl_ca" key is read.
    /// @return     the "ssl_ca" value when present, otherwise the fixed path
    ///             "/etc/ssl/certs/ca-certificates.crt". A configured value is
    ///             returned as a pointer into `ext`, so it is only valid while
    ///             `ext` is alive and that entry is unchanged.
    static const char* GetSslCa(const IDriver::TExtendedArgs& ext) {
        const auto configured = ext.find(SSL_CA);
        return configured != ext.end()
                   ? configured->second.c_str()
                   : "/etc/ssl/certs/ca-certificates.crt";
    }

    static const TString LOCALHOST_ = "localhost:";
    bool TMysqlDriver::Connect(const TString& host,
                               int port,
                               const TString& user,
                               const TZtStringBuf pwd,
                               const TString& db,
                               TDuration connectTimeout,
                               TDuration queryTimeout,
                               const TExtendedArgs& ext,
                               bool fetchStatusOnPing) {
        const char* h = nullptr;
        const char* s = nullptr;
        if (host.StartsWith(LOCALHOST_)) {
            s = host.c_str() + LOCALHOST_.size();
        } else if (!host.empty()) {
            h = host.c_str();
        }

        const char* u = user.empty() ? nullptr : user.c_str();
        const char* p = pwd.empty() ? nullptr : pwd.c_str();
        const char* d = db.empty() ? nullptr : db.c_str();

        // https://dev.mysql.com/doc/refman/8.0/en/mysql-options.html
        // 'Call mysql_options() after mysql_init() and before mysql_connect() or mysql_real_connect().'

        unsigned int conn_timeout = connectTimeout.Seconds() + 1;
        mysql_options(&Handle_, MYSQL_OPT_CONNECT_TIMEOUT, &conn_timeout);
        unsigned int read_timeout = queryTimeout.Seconds() + 1;
        mysql_options(&Handle_, MYSQL_OPT_READ_TIMEOUT, &read_timeout);
        unsigned int sslMode = GetSslMode(ext);
        mysql_options(&Handle_, MYSQL_OPT_SSL_MODE, &sslMode);
        mysql_options(&Handle_, MYSQL_OPT_SSL_CA, GetSslCa(ext));

        PingQuery_ = fetchStatusOnPing ? BuildPingQuery(ext) : "SELECT 1";

        return mysql_real_connect(&Handle_, h, u, p, d, port, s, 0) != nullptr;
    }

    /// Escapes `s` with mysql_real_escape_string_quote() for use inside a
    /// single-quoted SQL string literal.
    ///
    /// The output buffer is sized 2 * len + 1, the worst case the MySQL C API
    /// docs require, then trimmed to the length the call reports.
    TString TMysqlDriver::EscapeQueryParam(const TStringBuf s) const {
        TString escaped(s.length() * 2 + 1, 0);

        // The C API takes non-const pointers, hence the casts in this const method.
        MYSQL* const conn = const_cast<MYSQL*>(&Handle_);
        char* const out = const_cast<char*>(escaped.data());

        const unsigned written =
            mysql_real_escape_string_quote(conn, out, s.data(), s.length(), '\'');
        escaped.resize(written);
        return escaped;
    }

    // Defaults for the status-table health check.
    static const TStringBuf DEFAULT_STATUS_TABLE = "db_status";
    static const TStringBuf DEFAULT_STATUS_COLUMN = "value";

    /// Builds the health-check query "SELECT <column> FROM <table> LIMIT 1".
    ///
    /// Table and column come from the "status_table" / "status_column"
    /// extended arguments, falling back to DEFAULT_STATUS_TABLE /
    /// DEFAULT_STATUS_COLUMN. Both are inserted verbatim (not quoted or
    /// escaped), so they must come from trusted configuration.
    TString TMysqlDriver::BuildPingQuery(const TExtendedArgs& ext) {
        auto lookup = [&ext](const char* key, TStringBuf fallback) -> TStringBuf {
            const auto found = ext.find(key);
            return found != ext.end() ? TStringBuf(found->second) : fallback;
        };

        const TStringBuf table = lookup("status_table", DEFAULT_STATUS_TABLE);
        const TStringBuf column = lookup("status_column", DEFAULT_STATUS_COLUMN);

        return NUtils::CreateStr("SELECT ", column, " FROM ", table, " LIMIT 1");
    }

    /// Fetches the next row of `result` and converts it into a TRow.
    ///
    /// Each cell is built from (value, is-NULL flag, is-BLOB flag). A NULL
    /// value becomes an empty string with the NULL flag set.
    /// @return the row's cells, or an empty TRow when mysql_fetch_row() or
    ///         mysql_fetch_fields() yields NULL.
    /// @throws yexception carrying Error() when ErrNum() is non-zero after
    ///         the fetch.
    TRow TMysqlDriver::FetchRow(MYSQL_RES* result) {
        const MYSQL_ROW raw = mysql_fetch_row(result);
        if (ErrNum() != 0) {
            throw yexception() << Error();
        }

        MYSQL_FIELD* const fields = mysql_fetch_fields(result);
        if (raw == nullptr || fields == nullptr) {
            return {};
        }

        const unsigned long* const lengths = mysql_fetch_lengths(result);
        const size_t columns = mysql_num_fields(result);

        TRow cells;
        cells.reserve(columns);
        for (size_t col = 0; col < columns; ++col) {
            const char* const value = raw[col];
            const bool isNull = value == nullptr;
            const bool isBlob = MysqlTypeConvert(fields[col].type) == ESqlType::BLOB;
            // Explicit length: values may contain embedded zero bytes.
            cells.emplace_back(TString(isNull ? "" : value, lengths[col]), isNull, isBlob);
        }
        return cells;
    }

    std::unique_ptr<TResult> TMysqlDriver::Query(const TQuery& q, TDuration) {
        // query timeout cannot be changed after connect()
        if (const TString& sql = q.Query(); mysql_real_query(&Handle_, sql.data(), sql.length()) != 0) {
            return {};
        }

        // It should be stored to avoid mysql error:
        // 'Commands out of sync; you can't run this command now'
        // https://dev.mysql.com/doc/refman/8.0/en/commands-out-of-sync.html
        return StoreResult();
    }

    /// Returns the error message mysql_error() reports for this connection.
    TString TMysqlDriver::Error() {
        const char* const message = mysql_error(&Handle_);
        return TString(message);
    }

    int TMysqlDriver::ErrNum() const {
        return mysql_errno((MYSQL*)&Handle_);
    }

    /// Buffers the result of the last executed statement into a TResult.
    ///
    /// @return nullptr when the connection reports an error, either before
    ///         mysql_store_result() or from it (details via Error()/ErrNum());
    ///         a TResult with an empty table and mysql_affected_rows() for
    ///         statements that produce no result set; otherwise a TResult with
    ///         the fetched rows and mysql_affected_rows().
    /// @throws yexception propagated from FetchRow() on a fetch error.
    std::unique_ptr<TResult> TMysqlDriver::StoreResult() {
        if (ErrNum() != 0) {
            return {};
        }

        std::unique_ptr<MYSQL_RES, decltype(&mysql_free_result)> result(
            mysql_store_result(&Handle_),
            mysql_free_result);

        if (!result) {
            // mysql_store_result() returns NULL both for statements without a
            // result set (INSERT, UPDATE, ...) and when reading the result set
            // failed. Errno was 0 before the call, so a non-zero value now
            // means the read failed. Previously this was reported as an empty
            // successful result with a bogus affected-row count.
            if (ErrNum() != 0) {
                return {};
            }
            return std::make_unique<TResult>(TTable(), mysql_affected_rows(&Handle_));
        }

        size_t affectedRows = mysql_affected_rows(&Handle_);
        const size_t size = mysql_num_rows(result.get());

        TTable table;
        table.reserve(size);

        for (size_t idx = 0; idx < size; ++idx) {
            TRow row = FetchRow(result.get());
            if (row.empty()) {
                break;
            }

            table.push_back(std::move(row));
        }

        return std::make_unique<TResult>(std::move(table), affectedRows);
    }

    /// Health check: runs PingQuery_ (chosen in Connect()) through Query().
    std::unique_ptr<TResult> TMysqlDriver::Ping(TDuration queryTimeout) {
        // TODO: fetch weight from backend
        const auto& probe = PingQuery_;
        return Query(probe, queryTimeout);
    }
}
