#include "shards_map.h"

#include "exception.h"
#include "strings.h"
#include "utils.h"

#include <passport/infra/libs/cpp/dbpool/db_pool.h>
#include <passport/infra/libs/cpp/utils/log/global.h>
#include <passport/infra/libs/cpp/utils/string/split.h>
#include <passport/infra/libs/cpp/utils/string/string_utils.h>
#include <passport/infra/libs/cpp/xml/config.h>

#include <util/stream/file.h>

namespace NPassport::NBb {
    // Out-of-line defaulted constructor; the uid -> shard range map is filled by Init().
    TShardRanges::TShardRanges() = default;

    // Field name passed to TUtils::ToUInt when parsing the shard column of the ranges config.
    static const TString SHARED_RANGE("shared_range");

    // Loads the uid -> shard mapping from a text config.
    //
    // Format (one range per line; empty lines and lines starting with '#' are ignored):
    //   uid1 shard1
    //   uid2 shard2
    // meaning range [uid1;uid2) is in shard1, etc.; the last range has no upper bound.
    // Constraints: uid1 = 0, uidN < uidN+1, 1 <= shardN <= shardCount.
    // Shard ids are 1-based in the file and stored 0-based.
    //
    // @param config     stream with the ranges config
    // @param shardCount number of configured shards (valid ids are 1..shardCount)
    // @throws yexception on a malformed config (also logged with the line number);
    //         in that case the previously loaded ranges are left untouched.
    void TShardRanges::Init(IInputStream& config, TShard shardCount) {
        // Parse into a local map and publish it only on success, so a failed
        // load never leaves a partially filled mapping behind.
        decltype(Ranges_) ranges;

        TString line;
        int linenum = 0;
        TUid lastUid = 0;

        try {
            while (config.ReadLine(line) > 0) {
                ++linenum;

                if (line.empty() || line[0] == '#') {
                    continue;
                }

                std::vector<TString> vec = NUtils::ToVector(line, " \t");
                if (vec.size() != 2) {
                    throw yexception() << "wrong number of elements in line";
                }

                TUid u;
                TShard s;
                try {
                    u = TUtils::ToUInt(vec[0], TStrings::UID);
                    s = TUtils::ToUInt(vec[1], SHARED_RANGE);
                } catch (const TBlackboxError& e) {
                    throw yexception() << "can't parse range line '" << line << "' : " << e.StatusStr();
                }

                if (ranges.empty()) {
                    // The first range must start at uid 0, otherwise low uids are unmapped.
                    if (u != 0) {
                        throw yexception() << "uncovered range [0," << vec[0] << "]";
                    }
                } else if (u <= lastUid) {
                    throw yexception() << "range not sorted";
                }

                // Shard ids are 1-based: 0 must be rejected as well, because
                // "s - 1" below would produce an index with no pool behind it.
                if (s == 0 || s > shardCount) {
                    throw yexception() << "range mapped to unknown shard: " << vec[1];
                }

                ranges[u] = s - 1;
                lastUid = u;
            }

            if (ranges.empty()) {
                throw yexception() << "range config empty";
            }
        } catch (const std::exception& e) {
            TLog::Error("Failed to read shard config: '%s' at line %d", e.what(), linenum);
            throw;
        }

        Ranges_.swap(ranges);
    }

    // Returns the 0-based shard index of the range that contains the given uid.
    // @throws yexception if uid cannot be parsed as a uid, or if no range covers it
    //         (e.g. Init() has not loaded any ranges).
    TShardRanges::TShard TShardRanges::GetShard(const TString& uid) const {
        TUid u;
        try {
            u = TUtils::ToUInt(uid, TStrings::UID);
        } catch (const TBlackboxError&) {
            throw yexception() << "Internal error: can't parse uid " << uid;
        }

        // Keys are range lower bounds, so the range containing u starts at the
        // greatest key <= u: the element just before upper_bound(u). If u is past
        // the last bound, upper_bound returns end() and its predecessor is the
        // last (open-ended) range.
        std::map<TUid, TShard>::const_iterator p = Ranges_.upper_bound(u);

        // Should never happen after a successful Init() (the first bound is 0).
        // This check also covers an empty map, where begin() == end() and there
        // is no predecessor to step back to.
        if (p == Ranges_.begin()) {
            throw yexception() << "Internal error: can't find shard for uid " << uid;
        }
        --p; // take the range that starts at or before u
        return p->second;
    }

    // Special members are defaulted here in the .cpp rather than in the header.
    // NOTE(review): presumably so that member types only forward-declared in
    // shards_map.h are complete at this point — confirm against the header.
    TShardsMap::TShardsMap() = default;

    TShardsMap::~TShardsMap() = default;

    // Creates one DB pool per "<prefix>/<dbPrefix>db_conf/shard" entry.
    // Shards are looked up by 1-based id (shard[@id=1], shard[@id=2], ...) and
    // share the "<prefix>/<dbPrefix>shard_settings" section. Each new pool is
    // pinged once and also exposed by reference through Out_.
    // @throws yexception if an id in 1..N has no matching shard entry.
    void TShardsMap::Init(const TString& prefix,
                          const TString& dbPrefix,
                          const NXml::TConfig& config,
                          std::shared_ptr<NDbPool::TDbPoolCtx> ctx) {
        const TString base = prefix + "/" + dbPrefix;
        const TString settingsKey = base + "shard_settings";
        const unsigned total = config.SubKeys(base + "db_conf/shard").size();

        for (unsigned idx = 0; idx < total; ++idx) {
            const TString id = IntToString<10>(idx + 1);
            const TString key = base + "db_conf/shard[@id=" + id + "]";

            if (!config.Contains(key)) {
                throw yexception() << "Shard config error: can't find dbpool settings for shard #" << id;
            }

            Shards_.push_back(ctx->CreateDb(config, key, settingsKey));
            auto& pool = *Shards_.back();
            pool.TryPing();
            Out_.push_back(std::ref(pool));
        }
    }

    // Number of shard pools created by Init().
    unsigned TShardsMap::GetShardCount() const {
        return static_cast<unsigned>(Shards_.size());
    }

    // Returns the pool for 0-based shard index s.
    // @throws TBlackboxError (EType::Unknown) when no such shard exists.
    NDbPool::TDbPool& TShardsMap::GetPool(TShardRanges::TShard s) const {
        if (NDbPool::TDbPool* pool = GetPoolUnsafe(s)) {
            return *pool;
        }
        throw TBlackboxError(TBlackboxError::EType::Unknown) << "There is no shard #" << s;
    }

    // Returns the pool for 0-based shard index s, or nullptr if s is out of range.
    NDbPool::TDbPool* TShardsMap::GetPoolUnsafe(TShardRanges::TShard s) const {
        return s < Shards_.size() ? Shards_[s].get() : nullptr;
    }

    // Checks every non-null shard pool.
    // @param statusbuf cleared, then filled with each checked shard's status
    //                  message via NUtils::AddMessage
    // @return true only if every checked shard reports OK
    bool TShardsMap::ShardsOk(TString& statusbuf) const {
        TString tmpMsg;
        statusbuf.clear();
        bool result = true;
        for (const std::shared_ptr<NDbPool::TDbPool>& shard : Shards_) {
            if (!shard) {
                continue;
            }

            // Reset per shard. Otherwise, if IsOk() leaves the buffer untouched
            // (or appends to it), the previous shard's status would be reported again.
            tmpMsg.clear();
            if (!shard->IsOk(&tmpMsg)) {
                result = false;
            }

            NUtils::AddMessage(statusbuf, tmpMsg);
        }

        return result;
    }

    // Special members are defaulted here in the .cpp rather than in the header.
    // NOTE(review): presumably for the same header-completeness reason as
    // TShardsMap — confirm against shards_map.h.
    TRangedShardsMap::TRangedShardsMap() = default;

    TRangedShardsMap::~TRangedShardsMap() = default;

    // Creates the shard pools (with no db prefix), then loads the uid ranges from
    // the file named by "<prefix>/ranges_path". The pools are created first because
    // the range file is validated against the resulting shard count.
    void TRangedShardsMap::Init(const TString& prefix,
                                const NXml::TConfig& config,
                                std::shared_ptr<NDbPool::TDbPoolCtx> ctx) {
        TShardsMap::Init(prefix, TStrings::EMPTY, config, ctx);

        const TString rangesPath(config.AsString(prefix + "/ranges_path"));
        TFileInput rangesFile(rangesPath);
        Ranges_.Init(rangesFile, Shards_.size());
    }

    // Returns the pool serving the given uid, as resolved through the range map.
    // @throws TBlackboxError (EType::Unknown) when the resolved shard has no pool;
    //         GetShard() may throw yexception first for unparsable or unmapped uids.
    NDbPool::TDbPool& TRangedShardsMap::GetPool(const TString& uid) const {
        const TShardRanges::TShard shard = Ranges_.GetShard(uid);
        if (NDbPool::TDbPool* pool = GetPoolUnsafe(shard)) {
            return *pool;
        }
        throw TBlackboxError(TBlackboxError::EType::Unknown) << "Can't get shard for uid: " << uid;
    }

    // Groups uids by shard.
    // @return one entry per shard (indexed by 0-based shard index) holding that
    //         shard's uids joined with ','; shards with no uids get an empty string.
    // @throws TBlackboxError (EType::Unknown) if a uid maps to a shard index with
    //         no pool; GetShard() may throw yexception for unparsable or unmapped uids.
    std::vector<TString> TRangedShardsMap::SplitUids(const std::vector<TString>& uidsList) const {
        std::vector<TString> res(Shards_.size(), TStrings::EMPTY);

        for (const TString& uid : uidsList) {
            const TShardRanges::TShard s = Ranges_.GetShard(uid);
            // res[s] is unchecked; a bad shard index (e.g. from a misconfigured
            // ranges file) must not become an out-of-bounds write.
            if (s >= res.size()) {
                throw TBlackboxError(TBlackboxError::EType::Unknown) << "Can't get shard for uid: " << uid;
            }
            NUtils::AppendSeparated(res[s], ',', uid);
        }

        return res;
    }

}
