#include "sqlite-driver.h"

#include "type.h"

#include <passport/infra/libs/cpp/dbpool/query.h>
#include <passport/infra/libs/cpp/dbpool/result.h>

#include <passport/infra/libs/cpp/utils/string/string_utils.h>

#include <util/datetime/base.h>
#include <util/generic/string.h>
#include <util/stream/file.h>

#include <time.h>

namespace NPassport::NDbPool {
    // Implementation of the SQL function UNIX_TIMESTAMP([date]) that Connect()
    // registers on the SQLite connection:
    //  - no arguments: the current time (time(nullptr)) in seconds since the epoch;
    //  - one SQL NULL argument: NULL;
    //  - one other argument: its text form is parsed with TInstant::TryParseIso8601
    //    and returned as seconds since the epoch ("bad format" error if parsing fails);
    //  - more than one argument: "bad arguments" error.
    static void UnixTimeFunc(sqlite3_context* context, int argc, sqlite3_value** argv) {
        if (argc > 1) {
            sqlite3_result_error(context, "bad arguments", -1);
            return;
        }

        // Results are reported as 64-bit integers: narrowing time_t to int would
        // overflow for instants after 2038-01-19T03:14:07Z.
        if (0 == argc) {
            sqlite3_result_int64(context, static_cast<sqlite3_int64>(time(nullptr)));
            return;
        }

        // Propagate SQL NULL instead of failing the whole statement, so the function
        // can be applied to nullable columns.
        if (SQLITE_NULL == sqlite3_value_type(argv[0])) {
            sqlite3_result_null(context);
            return;
        }

        const char* data = reinterpret_cast<const char*>(sqlite3_value_text(argv[0]));

        TInstant ins;
        if (!data || !TInstant::TryParseIso8601(data, ins)) {
            sqlite3_result_error(context, "bad format", -1);
            return;
        }

        sqlite3_result_int64(context, static_cast<sqlite3_int64>(ins.TimeT()));
    }

    // Maps a column's declared type name, as returned by sqlite3_column_decltype(),
    // to ESqlType. That name may be nullptr, e.g. for expression columns.
    // Matching is case-insensitive and looks only at the base type name. Any
    // length/precision list or trailing modifier is ignored, so "VARCHAR(255)",
    // "DECIMAL(10,2)" and "INT(11) UNSIGNED" map like "VARCHAR", "DECIMAL" and "INT".
    // Missing or unknown names map to ESqlType::INVALID.
    static ESqlType SqliteTypeConvert(const char* type) {
        if (!type) {
            return ESqlType::INVALID;
        }

        // sqlite3_column_decltype() returns the declared type text verbatim, so cut
        // it at the first '(' or whitespace character before comparing.
        auto isDelimiter = [](char c) {
            return c == '(' || c == ' ' || c == '\t' || c == '\n' || c == '\r';
        };
        size_t baseLen = 0;
        while (type[baseLen] != '\0' && !isDelimiter(type[baseLen])) {
            ++baseLen;
        }
        const TString base(type, baseLen);

        struct TTypeMapping {
            const char* Name;
            ESqlType Type;
        };
        static constexpr TTypeMapping MAPPINGS[] = {
            {"DECIMAL", ESqlType::DECIMAL},
            {"FIXED", ESqlType::DECIMAL},
            {"DEC", ESqlType::DECIMAL},
            {"TINYINT", ESqlType::TINYINT},
            {"SMALLINT", ESqlType::SMALLINT},
            {"INTEGER", ESqlType::INTEGER},
            {"INT", ESqlType::INTEGER},
            {"BIGINT", ESqlType::BIGINT},
            {"FLOAT", ESqlType::FLOAT},
            {"DOUBLE", ESqlType::DOUBLE},
            {"TIMESTAMP", ESqlType::DATETIME},
            {"DATETIME", ESqlType::DATETIME},
            {"YEAR", ESqlType::DATE},
            {"DATE", ESqlType::DATE},
            {"TIME", ESqlType::TIME},
            {"ENUM", ESqlType::CHAR},
            {"SET", ESqlType::CHAR},
            {"CHAR", ESqlType::CHAR},
            {"TINYBLOB", ESqlType::BLOB},
            {"MEDIUMBLOB", ESqlType::BLOB},
            {"LONGBLOB", ESqlType::BLOB},
            {"BLOB", ESqlType::BLOB},
            {"VARCHAR", ESqlType::VARCHAR},
        };

        for (const TTypeMapping& mapping : MAPPINGS) {
            if (0 == strcasecmp(base.c_str(), mapping.Name)) {
                return mapping.Type;
            }
        }

        return ESqlType::INVALID;
    }

    TSqliteDriver::TSqliteDriver()
        : Handle_(nullptr, sqlite3_close)
    {
    }

    bool TSqliteDriver::Connect(const TString&,
                                int,
                                const TString&,
                                const TZtStringBuf,
                                const TString& db,
                                TDuration,
                                TDuration,
                                const TExtendedArgs&,
                                bool) {
        const bool isSqlFile = db.EndsWith(".sql");

        int status = -1;
        sqlite3* handle = nullptr;
        if (isSqlFile) {
            status = sqlite3_open(":memory:", &handle);
        } else {
            const TString path = NUtils::CreateStr("file:", db, "?mode=ro&immutable=1");

            status = sqlite3_open_v2(
                path.c_str(),
                &handle,
                SQLITE_OPEN_READONLY | SQLITE_OPEN_URI,
                nullptr);
        }

        Handle_.reset(handle);

        if (SQLITE_OK != status) {
            return false;
        }

        sqlite3_create_function(
            handle,
            "UNIX_TIMESTAMP",
            -1,
            SQLITE_UTF8,
            nullptr,
            &UnixTimeFunc,
            nullptr,
            nullptr);

        return isSqlFile ? InitInMemory(db) : true;
    }

    // Escapes `s` for use inside a single-quoted SQL string literal, via the "%q"
    // conversion of sqlite3_mprintf(). That conversion doubles every single quote
    // and does not add the surrounding quotes.
    // Returns an empty string if SQLite fails to allocate the result.
    TString TSqliteDriver::EscapeQueryParam(const TStringBuf s) const {
        // "%q" reads a NUL-terminated C string, so the input is copied into a TString.
        // Anything after an embedded NUL byte in `s` is dropped.
        const TString input(s);

        // Own the sqlite3_mprintf() buffer so it is freed even if the TString copy
        // below throws.
        const std::unique_ptr<char, decltype(&sqlite3_free)> escaped(
            sqlite3_mprintf("%q", input.c_str()),
            &sqlite3_free);
        if (!escaped) {
            return TString();
        }

        return TString(escaped.get());
    }

    bool TSqliteDriver::InitInMemory(const TString& filepath) {
        TFileInput sql(filepath);
        TString queries = sql.ReadAll();

        TStringBuf buf(queries);
        while (buf) {
            TStringBuf current = buf.NextTok(';');
            if (!Query(NUtils::CreateStr(current, ";"), {})) {
                return false;
            }
        }

        return true;
    }

    std::unique_ptr<TResult> TSqliteDriver::Query(const TQuery& q, TDuration) {
        const TString& sqlQuery = q.Query();

        sqlite3_stmt* stmt;
        if (SQLITE_OK != sqlite3_prepare(Handle_.get(), sqlQuery.c_str(), sqlQuery.size(), &stmt, nullptr)) {
            return {};
        }

        int rc = sqlite3_step(stmt);
        int ncols = sqlite3_column_count(stmt);

        std::vector<bool> columnBlobs;
        columnBlobs.reserve(ncols);
        if (SQLITE_ROW == rc) {
            for (int i = 0; i < ncols; ++i) {
                columnBlobs.push_back(
                    SqliteTypeConvert(sqlite3_column_decltype(stmt, i)) == ESqlType::BLOB);
            }
        }

        TTable table;
        while (SQLITE_ROW == rc) {
            TRow row;
            row.reserve(ncols);

            for (int idx = 0; idx < ncols; ++idx) {
                const char* res = (const char*)sqlite3_column_text(stmt, idx);
                row.emplace_back(TString(res ? res : ""),
                                 !res,
                                 columnBlobs.at(idx));
            }
            table.push_back(std::move(row));
            rc = sqlite3_step(stmt);
        }

        // this does NOT work for select requests, which seem to return 1 all the time
        // it is expected that the user will handle the fact that affectedRows() for select is meaningless
        size_t affectedRows = sqlite3_changes(Handle_.get());
        if (SQLITE_OK != sqlite3_finalize(stmt)) {
            return {};
        }

        return std::make_unique<TResult>(std::move(table), affectedRows);
    }

    // Returns the message sqlite3_errmsg() reports for the current handle, or an
    // empty string if it yields nullptr.
    TString TSqliteDriver::Error() {
        if (const char* message = sqlite3_errmsg(Handle_.get())) {
            return TString(message);
        }
        return TString();
    }

    int TSqliteDriver::ErrNum() const {
        return sqlite3_errcode(Handle_.get());
    }

    // Does not touch the SQLite handle. Always reports success as a fixed
    // one-row, one-column result containing "1", with zero affected rows.
    std::unique_ptr<TResult> TSqliteDriver::Ping(TDuration) {
        TRow row;
        row.emplace_back("1");

        TTable table;
        table.push_back(std::move(row));

        return std::make_unique<TResult>(std::move(table), 0);
    }
}
