#include "db_pool_ctx.h"

#include "db_pool.h"
#include "db_pool_stats.h"
#include "destination.h"
#include "misc/counter.h"

#include <passport/infra/libs/cpp/json/config.h>
#include <passport/infra/libs/cpp/json/writer.h>
#include <passport/infra/libs/cpp/unistat/time_stat.h>
#include <passport/infra/libs/cpp/utils/file.h>
#include <passport/infra/libs/cpp/utils/string/split.h>
#include <passport/infra/libs/cpp/xml/config.h>

#include <limits>

namespace NPassport::NDbPool {
    // Loads DB credentials from the JSON file `filename`.
    // `/db_user` is read without a fallback; `/db_pass` defaults to an empty string.
    TDbPoolCtx::TCredentials TDbPoolCtx::TCredentials::ReadFromFile(const TString& filename) {
        const NJson::TConfig json = NJson::TConfig::ReadFromFile(filename);

        TString user = json.As<TString>("/db_user");
        TString password = json.As<TString>("/db_pass", "");

        return TCredentials{
            .User = std::move(user),
            .Password = std::move(password),
        };
    }

    std::shared_ptr<TDbPoolCtx> TDbPoolCtx::Create() {
        return std::make_shared<TDbPoolCtx>();
    }

    // Replaces the context logger with one constructed from `fileName`.
    // An empty name or the "_NOLOG_" sentinel leaves the current logger untouched.
    void TDbPoolCtx::InitLogger(const TString& fileName) {
        const bool loggingDisabled = fileName.empty() || fileName == "_NOLOG_";
        if (!loggingDisabled) {
            Log_ = TDbPoolLog(fileName);
        }
    }

    // Returns a copy of the logger that pools created by this context receive.
    TDbPoolLog TDbPoolCtx::GetLogger() const {
        const TDbPoolLog& current = Log_;
        return current;
    }

    // Convenience overload: `xpath` serves as both the specific and the common
    // section, so there is effectively no fallback section.
    std::shared_ptr<TDbPool> TDbPoolCtx::CreateDb(const NXml::TConfig& config,
                                                  const TString& xpath,
                                                  TQueryOpts&& opts) {
        const TString& sameSection = xpath;
        return CreateDb(config, sameSection, sameSection, std::move(opts));
    }

    // Reads the string at `key`, preferring the special section over the common one.
    // The common-section read is unguarded, so a key missing from both sections
    // surfaces whatever NXml::TConfig::AsString does for absent paths.
    static TString Get(const NXml::TConfig& config,
                       const TString& xpathSpecial,
                       const TString& xpathCommon,
                       const TString& key) {
        const TString specialPath = xpathSpecial + key;
        return config.Contains(specialPath)
                   ? config.AsString(specialPath)
                   : config.AsString(xpathCommon + key);
    }

    template <typename Config>
    static std::vector<TString> GetSubKeys(const Config& config,
                                           const TString& prefixSpecial,
                                           const TString& prefixCommon,
                                           const TString& key) {
        if (config.Contains(prefixSpecial + key)) {
            return config.SubKeys(prefixSpecial + key);
        }
        return config.SubKeys(prefixCommon + key);
    }

    // Builds a DB pool from an XML config and registers it in this context.
    //
    // Settings are looked up first under `xpathSpecial`, then under `xpathCommon`;
    // optional ones fall back to the defaults below. The exception is `/db_host`,
    // which is read from `xpathSpecial` only and must appear at least once.
    //
    // Credentials come from the first `/db_credentials` file found (special, then
    // common). Without such a file, `/db_user` and `/db_pass` are read inline with the
    // same special-then-common lookup, defaulting to empty strings.
    //
    // @param config        parsed XML configuration
    // @param xpathSpecial  pool-specific section; takes precedence
    // @param xpathCommon   shared section used as a fallback
    // @param opts          default query options, moved into the pool settings
    // @return the new pool (also stored in this context)
    // @throws yexception if `xpathSpecial` has no `/db_host` entry. Errors from the
    //         unguarded `/db_driver` / `/db_name` reads and from reading a
    //         credentials file propagate unchanged.
    std::shared_ptr<TDbPool> TDbPoolCtx::CreateDb(const NXml::TConfig& config,
                                                  const TString& xpathSpecial,
                                                  const TString& xpathCommon,
                                                  TQueryOpts&& opts) {
        // Value of `key` from the special section, else from the common section,
        // else `defaultValue`; the result has the type of `defaultValue`.
        auto getWithDefault = [&](const TString& key, auto defaultValue) {
            if (config.Contains(xpathSpecial + key)) {
                return config.As<decltype(defaultValue)>(xpathSpecial + key, defaultValue);
            }
            if (config.Contains(xpathCommon + key)) {
                return config.As<decltype(defaultValue)>(xpathCommon + key, defaultValue);
            }
            return defaultValue;
        };

        // Pool sizing and timings; every timeout/period below is in milliseconds.
        size_t dbpoolsize = getWithDefault("/poolsize", size_t{10});
        double totalSizeRatio = getWithDefault("/total_size_ratio", 1.5);
        ui32 getTimeout = getWithDefault("/get_timeout", ui32{10});
        ui32 connectionTimeout = getWithDefault("/connect_timeout", ui32{500});
        ui32 queryTimeout = getWithDefault("/query_timeout", ui32{500});
        ui32 failThreshold = getWithDefault("/fail_threshold", ui32{500});
        ui32 pingPeriod = getWithDefault("/ping_period", ui32{3000});
        ui32 timeToInit = getWithDefault("/time_to_init", ui32{2000});

        TString driver = Get(config, xpathSpecial, xpathCommon, "/db_driver");
        ui16 port = getWithDefault("/db_port", ui16{0});
        TString dbname = Get(config, xpathSpecial, xpathCommon, "/db_name");
        TString displayName = getWithDefault("/display_name", dbname);

        TCredentials creds;
        if (TString path = xpathSpecial + "/db_credentials"; config.Contains(path)) {
            creds = TCredentials::ReadFromFile(config.AsString(path));
        } else if (TString path = xpathCommon + "/db_credentials"; config.Contains(path)) {
            creds = TCredentials::ReadFromFile(config.AsString(path));
        } else {
            // Inline credentials honour the common section as well, just like every
            // other setting here and like the JSON overload of CreateDb.
            creds.User = getWithDefault("/db_user", TString());
            creds.Password = getWithDefault("/db_pass", TString());
        }

        double balancingCloseRate = getWithDefault("/balancing_close_rate", 0.1);
        double balancingOpenRate = getWithDefault("/balancing_open_rate", 0.3);

        bool fetchStatusOnPing = getWithDefault("/fetch_status_on_ping", false);

        TString localeCmd = getWithDefault("/locale_cmd", TString());

        // Hosts are pool-specific: they are never inherited from the common section.
        // All hosts share the single configured port; weight defaults to 1.
        std::vector<TDbHost> hosts;
        for (const TString& path : config.SubKeys(xpathSpecial + "/db_host")) {
            hosts.push_back(TDbHost{
                .Host = config.AsString(path),
                .Port = port,
                .Weight = config.AsNum<size_t>(path + "/@weight", 1),
            });
        }

        if (hosts.empty()) {
            throw yexception() << "db_host is missing at " << xpathSpecial;
        }

        // Each `extended` element contributes one driver parameter: its `key`
        // attribute is the name, its text is the value.
        NDbPool::TDestination::TExtendedParams ext;
        for (const TString& path : GetSubKeys(config, xpathSpecial, xpathCommon, "/extended")) {
            ext.insert({config.AsString(path + "/@key"), config.AsString(path)});
        }

        TDbPoolSettings settings{
            .Dsn = NDbPool::TDestination::Create(
                driver,
                creds.User,
                creds.Password,
                dbname,
                localeCmd,
                displayName,
                std::move(ext)),
            .Hosts = std::move(hosts),
            .Size = dbpoolsize,
            .TotalSizeRatio = totalSizeRatio,
            .GetTimeout = TDuration::MilliSeconds(getTimeout),
            .ConnectionTimeout = TDuration::MilliSeconds(connectionTimeout),
            .QueryTimeout = TDuration::MilliSeconds(queryTimeout),
            .FailThreshold = TDuration::MilliSeconds(failThreshold),
            .PingPeriod = TDuration::MilliSeconds(pingPeriod),
            .TimeToInit = TDuration::MilliSeconds(timeToInit),
            .DefaultQueryOpts = std::make_shared<TQueryOpts>(std::move(opts)),
            .BalancingCloseRate = balancingCloseRate,
            .BalancingOpenRate = balancingOpenRate,
            .FetchStatusOnPing = fetchStatusOnPing,
        };

        return CreateDb(settings);
    }

    // Convenience overload: `jpoint` serves as both the specific and the common
    // json-point, so there is effectively no fallback section.
    std::shared_ptr<TDbPool> TDbPoolCtx::CreateDb(const NJson::TConfig& config,
                                                  const TString& jpoint,
                                                  TQueryOpts&& opts) {
        const TString& sameSection = jpoint;
        return CreateDb(config, sameSection, sameSection, std::move(opts));
    }

    // Typed lookup of `key`: the special json-point is used when it contains the key,
    // otherwise the common one is queried without a guard (so a key missing from
    // both surfaces whatever NJson::TConfig::As<T> does for absent paths).
    template <typename T>
    static auto Get(const NJson::TConfig& config,
                    const TString& jpointSpecial,
                    const TString& jpointCommon,
                    const TString& key) {
        const TString& base = config.Contains(jpointSpecial + key) ? jpointSpecial : jpointCommon;
        return config.As<T>(base + key);
    }

    // Builds a DB pool from a JSON config and registers it in this context.
    //
    // Settings are looked up first under `jpointSpecial`, then under `jpointCommon`;
    // optional ones fall back to the defaults below (the same defaults as the XML
    // overload). The exception is `/db_host`, which is read from `jpointSpecial`
    // only and must contain at least one entry.
    //
    // Credentials come from the first `/db_credentials` file found (special, then
    // common). Without such a file, `/db_user` and `/db_pass` are read inline with the
    // same special-then-common lookup, defaulting to empty strings.
    //
    // @param config         parsed JSON configuration
    // @param jpointSpecial  pool-specific json-point; takes precedence
    // @param jpointCommon   shared json-point used as a fallback
    // @param opts           default query options, moved into the pool settings
    // @return the new pool (also stored in this context)
    // @throws yexception if `/db_port` does not fit into 16 bits or if
    //         `jpointSpecial` has no `/db_host` entry. Errors from the unguarded
    //         `/db_driver` / `/db_name` reads and from reading a credentials file
    //         propagate unchanged.
    std::shared_ptr<TDbPool> TDbPoolCtx::CreateDb(const NJson::TConfig& config,
                                                  const TString& jpointSpecial,
                                                  const TString& jpointCommon,
                                                  TQueryOpts&& opts) {
        // Value of `key` from the special json-point, else from the common one,
        // else `defaultValue`; the result has the type of `defaultValue`.
        auto getWithDefault = [&](const TString& key, auto defaultValue) {
            if (config.Contains(jpointSpecial + key)) {
                return config.As<decltype(defaultValue)>(jpointSpecial + key, defaultValue);
            }
            if (config.Contains(jpointCommon + key)) {
                return config.As<decltype(defaultValue)>(jpointCommon + key, defaultValue);
            }
            return defaultValue;
        };

        // Pool sizing and timings; every timeout/period below is in milliseconds.
        size_t dbpoolsize = getWithDefault("/poolsize", size_t{10});
        double totalSizeRatio = getWithDefault("/total_size_ratio", 1.5);
        ui32 getTimeout = getWithDefault("/get_timeout", ui32{10});
        ui32 connectionTimeout = getWithDefault("/connect_timeout", ui32{500});
        ui32 queryTimeout = getWithDefault("/query_timeout", ui32{500});
        ui32 failThreshold = getWithDefault("/fail_threshold", ui32{500});
        ui32 pingPeriod = getWithDefault("/ping_period", ui32{3000});
        ui32 timeToInit = getWithDefault("/time_to_init", ui32{2000});

        TString driver = Get<TString>(config, jpointSpecial, jpointCommon, "/db_driver");

        // The port is read as ui32 and range-checked explicitly: assigning it
        // straight to ui16 would silently wrap an out-of-range value into a
        // different, valid-looking port.
        const ui32 portValue = getWithDefault("/db_port", ui32{0});
        if (portValue > std::numeric_limits<ui16>::max()) {
            throw yexception() << "db_port is out of range: " << portValue;
        }
        const ui16 port = static_cast<ui16>(portValue);

        TString dbname = Get<TString>(config, jpointSpecial, jpointCommon, "/db_name");
        TString displayName = getWithDefault("/display_name", dbname);

        TCredentials creds;
        if (TString path = jpointSpecial + "/db_credentials"; config.Contains(path)) {
            creds = TCredentials::ReadFromFile(config.As<TString>(path));
        } else if (TString path = jpointCommon + "/db_credentials"; config.Contains(path)) {
            creds = TCredentials::ReadFromFile(config.As<TString>(path));
        } else {
            creds.User = getWithDefault("/db_user", TString());
            creds.Password = getWithDefault("/db_pass", TString());
        }

        double balancingCloseRate = getWithDefault("/balancing_close_rate", 0.1);
        double balancingOpenRate = getWithDefault("/balancing_open_rate", 0.3);

        bool fetchStatusOnPing = getWithDefault("/fetch_status_on_ping", false);

        TString localeCmd = getWithDefault("/locale_cmd", TString());

        // Hosts are pool-specific: they are never inherited from the common section.
        // All hosts share the single configured port; weight defaults to 1.
        std::vector<TDbHost> hosts;
        for (const TString& path : config.SubKeys(jpointSpecial + "/db_host")) {
            hosts.push_back(TDbHost{
                .Host = config.As<TString>(path + "/host"),
                .Port = port,
                .Weight = config.As<size_t>(path + "/weight", 1),
            });
        }

        // Same validation as the XML overload: a pool without hosts is a config error.
        if (hosts.empty()) {
            throw yexception() << "db_host is missing at " << jpointSpecial;
        }

        // Every member of the `extended` object becomes one driver parameter
        // (member name -> string value).
        NDbPool::TDestination::TExtendedParams ext;
        for (const TString& path : GetSubKeys(config, jpointSpecial, jpointCommon, "/extended")) {
            ext.insert({NJson::TConfig::GetKeyFromPath(path), config.As<TString>(path)});
        }

        TDbPoolSettings settings{
            .Dsn = NDbPool::TDestination::Create(
                driver,
                creds.User,
                creds.Password,
                dbname,
                localeCmd,
                displayName,
                std::move(ext)),
            .Hosts = std::move(hosts),
            .Size = dbpoolsize,
            .TotalSizeRatio = totalSizeRatio,
            .GetTimeout = TDuration::MilliSeconds(getTimeout),
            .ConnectionTimeout = TDuration::MilliSeconds(connectionTimeout),
            .QueryTimeout = TDuration::MilliSeconds(queryTimeout),
            .FailThreshold = TDuration::MilliSeconds(failThreshold),
            .PingPeriod = TDuration::MilliSeconds(pingPeriod),
            .TimeToInit = TDuration::MilliSeconds(timeToInit),
            .DefaultQueryOpts = std::make_shared<TQueryOpts>(std::move(opts)),
            .BalancingCloseRate = balancingCloseRate,
            .BalancingOpenRate = balancingOpenRate,
            .FetchStatusOnPing = fetchStatusOnPing,
        };

        return CreateDb(settings);
    }

    // Asks every registered pool to append its unistat signals to `builder`.
    void TDbPoolCtx::AddUnistat(NUnistat::TBuilder& builder) const {
        for (const auto& pool : Dbpools_) {
            pool->AddUnistat(builder);
        }
    }

    // Asks every registered pool to append its extended unistat signals to `builder`.
    void TDbPoolCtx::AddUnistatExtended(NUnistat::TBuilder& builder) const {
        for (const auto& pool : Dbpools_) {
            pool->AddUnistatExtended(builder);
        }
    }

    // Serializes the extended statistics of every registered pool as JSON. Each pool
    // writes its own entries into a top-level `dbpools` object via
    // TDbPool::GetExtendedStats.
    TString TDbPoolCtx::GetExtendedStats() const {
        TString res;
        {
            // The writer and object guards are scoped so they are destroyed before
            // `res` is returned. They presumably emit closing brackets (and may flush)
            // in their destructors. If they outlived the return statement, the caller
            // would only see complete JSON when NRVO happens to apply, which the
            // language does not guarantee for named locals.
            NJson::TWriter wr(res);
            NJson::TObject root(wr);

            NJson::TObject dbpools(root, "dbpools");
            for (const std::shared_ptr<TDbPool>& db : Dbpools_) {
                db->GetExtendedStats(dbpools);
            }
        }

        return res;
    }

    // Feeds every registered pool into a TDbPoolStats aggregator and returns its
    // rendered result.
    TString TDbPoolCtx::GetLegacyStats() const {
        TDbPoolStats aggregate;
        for (const auto& pool : Dbpools_) {
            aggregate.Add(pool.get());
        }
        return aggregate.Result();
    }

    // Creates a pool from fully resolved settings using the current logger.
    // The pool is stored in this context, so it is included in the stats and
    // unistat output produced by the other members.
    std::shared_ptr<TDbPool> TDbPoolCtx::CreateDb(const TDbPoolSettings& settings) {
        auto pool = std::make_shared<NDbPool::TDbPool>(settings, Log_);
        Dbpools_.push_back(pool);
        return pool;
    }
}
