#include "db.h"

#include <crypta/lib/native/proto_serializer/proto_serializer.h>
#include <crypta/lib/native/yt/utils/helpers.h>
#include <crypta/lib/proto/user_data/token_dict_item.pb.h>
#include <crypta/siberia/bin/custom_audience/common/utils/paths.h>

#include <mapreduce/yt/interface/client.h>
#include <mapreduce/yt/library/parallel_io/parallel_reader.h>
#include <mapreduce/yt/common/config.h>

#include <util/stream/file.h>

using namespace NCrypta::NSiberia::NCustomAudience;
using namespace NCrypta::NLog;
using namespace NLab;

namespace {
    /// Downloads rows of the user data table into memory.
    ///
    /// @param client     YT client used to read the table and its `row_count` attribute.
    /// @param srcTable   Table to read; the path is copied and a row range is added to the copy.
    /// @param rowsCount  Upper bound on the number of rows to download; 0 means "no limit".
    /// @param log        Logger that receives progress messages.
    /// @return The downloaded rows. The reader is created with Ordered(false), so the
    ///         order of rows presumably does not follow the table order.
    TVector<NLab::TUserData> ReadUserData(NYT::IClientPtr client, const NYT::TRichYPath& srcTable, ui64 rowsCount, TLogPtr log) {
        // How often (in rows) a progress message is written.
        constexpr size_t progressLogStep = 100000;

        log->info("[User data] Start to download table '{}'", srcTable.Path_);

        const ui64 realRowsCount = NCrypta::GetAttribute(client, srcTable.Path_, "row_count").AsInt64();
        const ui64 size = (rowsCount == 0) ? realRowsCount : Min(realRowsCount, rowsCount);

        // Limit the read to the row index range starting at 0 with `size` as the upper bound.
        const auto srcTableRichPath = NYT::TRichYPath(srcTable).AddRange(NYT::TReadRange::FromRowIndices(0, size));
        auto reader = NYT::CreateParallelTableReader<NLab::TUserData>(
            client, srcTableRichPath, NYT::TParallelTableReaderOptions().Ordered(false)
        );

        log->info("[User data] Records total count: {}", size);

        TVector<NLab::TUserData> userDatas;
        userDatas.reserve(size);

        for (; reader->IsValid(); reader->Next()) {
            userDatas.push_back(reader->MoveRow());

            // Report the number of rows actually stored so far.
            if (userDatas.size() % progressLogStep == 0) {
                log->info("[User data] Records loaded: {}", userDatas.size());
            }
        }

        // The format string needs a placeholder, otherwise the table path is silently dropped.
        log->info("[User data] Successfully read '{}'", srcTable.Path_);

        return userDatas;
    }

    /// Reads a token dictionary table and builds both lookup directions.
    ///
    /// Every row contributes an `id -> {token, weight}` entry to `Dict` and a
    /// `token -> id` entry to `ReversedDict`. Entries are added with insert(); assuming
    /// the usual map semantics, a repeated key keeps the entry that was inserted first
    /// (TODO confirm against the TTokenDicts definition).
    ///
    /// @param client    YT client used to read the table and its `row_count` attribute.
    /// @param srcTable  Dictionary table to read.
    /// @param log       Logger that receives progress messages.
    /// @return The filled dictionaries.
    TTokenDicts ReadDict(NYT::IClientPtr client, const NYT::TRichYPath& srcTable, TLogPtr log) {
        // How often (in rows) a progress message is written.
        constexpr size_t progressLogStep = 100000;

        log->info("[Dict] Start reading file: '{}'", srcTable.Path_);

        const ui64 count = NCrypta::GetAttribute(client, srcTable.Path_, "row_count").AsInt64();
        log->info("[Dict] Records total count: {}", count);

        auto reader = NYT::CreateParallelTableReader<TTokenDictItem>(
            client, srcTable, NYT::TParallelTableReaderOptions().Ordered(false)
        );

        TTokenDicts dicts;
        dicts.Reserve(count);

        size_t rowsLoaded = 0;
        for (; reader->IsValid(); reader->Next()) {
            const auto& dictItem = reader->GetRow();

            const auto& token = dictItem.GetToken();
            const auto& id = dictItem.GetId();
            dicts.Dict.insert({id, {.Token = token, .Weight = static_cast<float>(dictItem.GetWeight())}});
            dicts.ReversedDict.insert({token, id});

            // Count the row only after it has been processed so the reported number is exact.
            ++rowsLoaded;
            if (rowsLoaded % progressLogStep == 0) {
                log->info("[Dict] Records loaded: {}", rowsLoaded);
            }
        }

        // The format string needs a placeholder, otherwise the table path is silently dropped.
        log->info("[Dict] Successfully read '{}'", srcTable.Path_);

        return dicts;
    }

    /// Collects the packs produced by GetPacks() for `userDatas` into a TVector.
    ///
    /// @param userDatas  Rows to split; the returned ranges presumably point into this
    ///                   vector, so it must outlive them (verify against GetPacks).
    /// @param packSize   Forwarded to GetPacks() unchanged.
    /// @return One TIterRange per pack, in the order GetPacks() yields them.
    TVector<TIterRange<TUserDataIter>> GetUserDataPacks(const TVector<NLab::TUserData>& userDatas, size_t packSize) {
        TVector<TIterRange<TUserDataIter>> result;

        // The result of GetPacks() stays alive for the whole loop.
        for (const auto& range : GetPacks(userDatas, packSize)) {
            result.push_back(range);
        }

        return result;
    }
}

/// Loads user data, the host/word/app dictionaries and the user data packs.
///
/// The YT clusters listed in `config.GetYt()` are tried in order. The first cluster
/// from which everything is read without a yexception wins: the ready flag is set and
/// the function returns. A yexception raised while reading a cluster is logged and the
/// next cluster is tried.
///
/// @param config  Cluster list, table paths, row limit and pack size.
/// @param log     Logger for progress and error messages.
/// @throws yexception if the DB has already been loaded, or if no cluster could be read.
void TDb::Read(const TDbConfig& config, TLogPtr log) {
    if (IsReady()) {
        ythrow yexception() << "Rereading DB is not supported";
    }

    const auto& tables = config.GetUserDataTables();

    for (const auto& cluster : config.GetYt()) {
        const auto& proxy = cluster.GetProxy();

        try {
            auto ytClient = NYT::CreateClient(proxy);

            // Only the listed columns are requested from the user data table.
            const auto userDataPath = NYT::TRichYPath(tables.GetUserData()).Columns(
                {"yuid", "Attributes", "Segments", "CryptaID", "AffinitiesEncoded"});

            UserDatas = ReadUserData(ytClient, userDataPath, config.GetMaxRowsToDownload(), log);
            log->info("User data db size: {}", UserDatas.size());

            Dicts.HostDicts = ReadDict(ytClient, tables.GetHostDict(), log);
            log->info("Host dict size: {}", Dicts.HostDicts.Dict.size());

            Dicts.WordDicts = ReadDict(ytClient, tables.GetWordDict(), log);
            log->info("Word dict size: {}", Dicts.WordDicts.Dict.size());

            Dicts.AppDicts = ReadDict(ytClient, tables.GetAppDict(), log);
            log->info("App dict size: {}", Dicts.AppDicts.Dict.size());

            // Packs are derived from UserDatas, so they are built only after it is stored.
            Packs = GetUserDataPacks(UserDatas, config.GetPackSize());

            AtomicSet(Ready, 1);
            return;
        } catch (const yexception& e) {
            log->error("Failed to load base from {}: {}", proxy, e.what());
        }
    }

    ythrow yexception() << "Failed to load tables from all clusters";
}

/// Returns the packs stored by Read(); the reference stays owned by this TDb.
const TPacksVector& TDb::GetPacks() const {
    return Packs;
}

/// Returns the dictionaries stored by Read(); the reference stays owned by this TDb.
const TDicts& TDb::GetDicts() const {
    return Dicts;
}

/// Tells whether Read() has finished successfully (it sets the flag to 1 at the end).
bool TDb::IsReady() const {
    return AtomicGet(Ready) != 0;
}
