#include <util/digest/fnv.h>
#include <util/generic/size_literals.h>
#include <util/stream/file.h>
#include <util/string/join.h>
#include <util/thread/pool.h>

#include <dict/word2vec/model/model.h>
#include <kernel/mirrors/mirrors_trie.h>
#include <library/cpp/getopt/last_getopt.h>
#include <mapreduce/yt/interface/client.h>
#include <mapreduce/yt/util/temp_table.h>
#include <quality/logs/parse_lib/parse_lib.h>
#include <quality/logs/parse_lib/parsing_rules.h>
#include <quality/traffic/iterator/iterator.h>
#include <robot/library/yt/static/command.h>
#include <wmconsole/version3/library/jupiter/jupiter.h>
#include <wmconsole/version3/wmcutil/args.h>
#include <wmconsole/version3/wmcutil/compress.h>
#include <wmconsole/version3/wmcutil/log.h>
#include <wmconsole/version3/wmcutil/periodic.h>
#include <wmconsole/version3/wmcutil/url.h>
#include <wmconsole/version3/wmcutil/yt/yt_runner.h>
#include <wmconsole/version3/wmcutil/yt/yt_utils.h>
#include <wmconsole/version3/searchqueries-mr/tools/host2vec-model/protos/embedding.pb.h>
#include <util/draft/date.h>

#include "config.h"

using namespace NJupiter;

namespace NWebmaster {

namespace {
// Column names used in the YT tables this tool reads and writes.
const char *F_COMPRESSED_CHUNK_NO   = "CompressedChunkNo";
const char *F_COMPRESSED_DATA       = "CompressedData";
const char *F_HOSTS                 = "Hosts";
const char *F_KEY                   = "key";
const char *F_PARTITION_ID          = "PartitionId";
const char *F_SESSION_ID            = "SessionId";
const char *F_SESSION_PART          = "SessionPart";
const char *F_SUBKEY                = "subkey";
//const char *F_VALUE                 = "value";

// Date pattern passed to TDate::ToStroka when building per-day input table paths.
const char *FORMAT = "%Y-%m-%d";
// Number of buckets that session ids are hashed into (range of the PartitionId column).
const int COMPRESS_PARTITIONS = 512;
}

// Signature shared by every task entry point that can be selected by mode name.
using TTaskHandler = void (*)(const TConfig &config);

// Holds a mapped mirrors trie together with the scratch buffer its lookups need.
struct TMirrors {
    // Opens the trie stored in mirrorsTrieFile and allocates the lookup buffer from it.
    TMirrors(const TString &mirrorsTrieFile = "mirrors.trie")
        : MirrorsTrie(new TMirrorsMappedTrie(mirrorsTrieFile.data(), PCHM_Force_Lock))
        , MirrorCharBuffer(MirrorsTrie->MakeCharBuffer())
    {
    }

    // Lower-cases host and looks it up in the trie.
    // Returns the trie's answer when there is one, otherwise the lower-cased host itself.
    TString GetMainMirror(TString host) {
        host.to_lower();
        if (const char *found = MirrorsTrie->GetCheck(host.data(), MirrorCharBuffer.Get())) {
            return TString(found);
        }
        return host;
    }

public:
    TSimpleSharedPtr<TMirrorsMappedTrie> MirrorsTrie;
    TMirrorsMappedTrie::TCharBuffer MirrorCharBuffer;
};

// Reducer that turns the raw log rows of one key into rows of visited hosts, one row per session part.
// Output columns: PartitionId, SessionId, SessionPart, Hosts (space-separated host list).
//ReduceBy F_KEY
//SortBy F_KEY, F_SUBKEY
struct TParseBrowserLogsReduce : public NYT::IReducer<NYT::TTableReader<NYT::TYaMRRow>, NYT::TTableWriter<NYT::TNode>> {
    // Only the trie file name is serialized with the job; the heavy objects are rebuilt in Start().
    Y_SAVELOAD_JOB(MirrorsTrieFile)

public:
    TParseBrowserLogsReduce() = default;
    TParseBrowserLogsReduce(const TString &mirrorsTrieFile)
        : MirrorsTrieFile(mirrorsTrieFile)
    {
    }

    // Creates the parsing rules and opens the mirrors trie named by MirrorsTrieFile.
    void Start(TWriter */*writer*/) override {
        PRules.Reset(new TStraightForwardParsingRules);
        Mirrors.Reset(new TMirrors(MirrorsTrieFile));
    }

    // Parses the reader's current row and returns it as a TTrafficItem.
    // Returns nullptr when parsing throws (any exception is swallowed on purpose)
    // or when the parsed object is not a TTrafficItem.
    static inline const TTrafficItem* SafeParseMRData(TParsingRules& prules, NYT::TTableReader<NYT::TYaMRRow>* iter) {
        const TTrafficItem* item = nullptr;
        try {
            item = dynamic_cast<const TTrafficItem*>(prules.ParseMRData(iter->GetRow().Key, iter->GetRow().SubKey, iter->GetRow().Value));
        } catch (...) {
        }
        return item;
    }

    // Normalizes host into fixedHost.
    // Returns false (fixedHost untouched) when host has no dot or cannot be parsed as a URL.
    // Otherwise fixedHost is the main mirror of scheme+host, with the scheme removed and
    // the result passed through NUtils::FixDomainPrefix.
    inline bool FixHost(const TString &host, TString &fixedHost) {
        if (host.find(".") == TString::npos) {
            return false;
        }

        THttpURL parsedUrl;
        if (!NUtils::ParseUrl(parsedUrl, host)) {
            return false;
        }

        const TString parsedHost = parsedUrl.PrintS(THttpURL::FlagScheme | THttpURL::FlagHost);
        const TString mainMirror = Mirrors->GetMainMirror(parsedHost);
        fixedHost = TString{NUtils::FixDomainPrefix(NUtils::RemoveScheme(mainMirror))};
        return true;
    }

    // Processes all rows of one key. The key is used as the session id.
    // Hosts are collected in order; a new session part starts whenever the gap between
    // two accepted events exceeds 15 minutes. Consecutive duplicates of a host are collapsed.
    void Do(TReader *input, TWriter *output) override {
        TString id = TString{input->GetRow().Key};
        NTrafficLib::TTrafficIterator iter;
        TMap<size_t, TVector<TString>> hostsBySessionParts;
        time_t prevTimestamp = 0;
        size_t sessionPart = 0;
        for (; input->IsValid(); input->Next()) {
            if (iter.Next(SafeParseMRData(*PRules, input))) {
                // The first event seen seeds the timestamp, so it never opens a new part.
                if (prevTimestamp == 0) {
                    prevTimestamp = iter.GetTimestamp();
                }

                TString fixedHost;
                // Note: a rejected host skips the prevTimestamp update at the end of the loop body,
                // so gaps are measured between accepted hosts only.
                if (!FixHost(iter.GetUrlHost(), fixedHost)) {
                    continue;
                }

                if ((iter.GetTimestamp() - prevTimestamp) > 60 * 15) { //https://wiki.yandex-team.ru/JandeksPoisk/Jekosistema/PonimaniePolzovatelejj/threshold/#rezultaty
                    sessionPart++;
                }

                TVector<TString> &hosts = hostsBySessionParts[sessionPart];
                if (hosts.empty() || fixedHost != hosts.back()) {
                    hosts.push_back(fixedHost);
                }

                //if (+hosts > 1000) {
                    //Cerr << "Too long sessions " << input->GetRow().Key << "\n";
                    //return;
                //}

                prevTimestamp = iter.GetTimestamp();
            }
        }

        // All parts of one session land in the same partition: the bucket is derived from the session id.
        const size_t partitionId = FnvHash<ui32>(id.data(), id.size()) % COMPRESS_PARTITIONS;
        for (const auto &sessionObj : hostsBySessionParts) {
            size_t sessionPart = sessionObj.first;
            const TVector<TString> &hosts = sessionObj.second;
            output->AddRow(NYT::TNode()
                (F_PARTITION_ID, partitionId)
                (F_SESSION_ID, id)
                (F_SESSION_PART, sessionPart)
                (F_HOSTS, JoinStrings(hosts, " "))
            );
        }
    }

public:
    THolder<TParsingRules> PRules;      // created in Start()
    TString MirrorsTrieFile;            // name of the mirrors trie file opened in Start()
    THolder <TMirrors> Mirrors;         // created in Start()
};

REGISTER_REDUCER(TParseBrowserLogsReduce);

// Reducer that packs the Hosts lines of one partition into NUtils::TChunk buffers and
// emits them as rows (PartitionId, CompressedChunkNo, CompressedData).
//ReduceBy F_PARTITION_ID
//SortBy F_PARTITION_ID, F_SESSION_ID, F_SESSION_PART
struct TCompressPoolReduce : public NYT::IReducer<NYT::TTableReader<NYT::TNode>, NYT::TTableWriter<NYT::TNode>> {
public:
    // Writes every row's Hosts value, followed by a newline, into the chunk.
    // A row is emitted each time the chunk reports Overflow(), and once more at the end
    // for whatever remains after Finish().
    void Do(TReader* input, TWriter* output) override {
        NUtils::TChunk chunk;
        const size_t partitionId = input->GetRow()[F_PARTITION_ID].AsUint64();
        for (; input->IsValid(); input->Next()) {
            const NYT::TNode &row = input->GetRow();
            const TString hosts = row[F_HOSTS].AsString() + "\n";

            chunk.Write(hosts.data(), hosts.size());

            // The view is taken before Overflow() is queried.
            // NOTE(review): keep this order unless TChunk::Overflow() is confirmed to have no side effects.
            TStringBuf compressedData(chunk.Data(), chunk.Size());
            if (chunk.Overflow()) {
                output->AddRow(NYT::TNode()
                    (F_PARTITION_ID, partitionId)
                    (F_COMPRESSED_CHUNK_NO, chunk.No++)
                    (F_COMPRESSED_DATA, compressedData)
                );
                chunk.Clear();
            }
        }

        chunk.Finish();

        // Emit the tail only when there is something left in the chunk.
        if (chunk.Size() > 0) {
            TStringBuf compressedData(chunk.Data(), chunk.Size());
            output->AddRow(NYT::TNode()
                (F_PARTITION_ID, partitionId)
                (F_COMPRESSED_CHUNK_NO, chunk.No++)
                (F_COMPRESSED_DATA, compressedData)
            );
        }
    }
};

REGISTER_REDUCER(TCompressPoolReduce);

// Returns the name of the compressed counterpart of a table: the same path with ".gz" appended.
TString GetCompressed(const TString &table) {
    TString compressedName(table);
    compressedName += ".gz";
    return compressedName;
}

// Builds the training dataset table and its compressed counterpart inside one transaction.
//   config       - supplies the day window, output table path, mirrors trie paths
//   client       - YT client the transaction is started on
//   sessionsRoot - root directory holding per-day "<date>/clean" input tables
// Steps: reduce raw logs into per-session host lists, sort the result, then reduce it
// into compressed chunks written to GetCompressed(config.TABLE_HOST2VEC_TRAIN_DATASET).
void PrepareDataset(const TConfig &config, NYT::IClientBasePtr client, const TString& sessionsRoot) {
    // The window covers TABLE_HOST2VEC_TRAIN_DAYS days ending with yesterday (inclusive).
    // NOTE(review): days is unsigned, so a configured value of 0 wraps around on days-- below.
    size_t days = config.TABLE_HOST2VEC_TRAIN_DAYS;
    TDate endDate(Now().TimeT());
    endDate = endDate - 1; //previous day
    days--; //days will include endDate
    TDate startDate = endDate - days;

    NYT::ITransactionPtr tx = client->StartTransaction();

    // Schema of the intermediate per-session table.
    NYT::TTableSchema datasetTableSchema;
    datasetTableSchema.Strict(true);
    datasetTableSchema.AddColumn(NYT::TColumnSchema().Name(F_SESSION_ID).Type(NYT::VT_STRING).SortOrder(NYT::SO_ASCENDING));
    datasetTableSchema.AddColumn(NYT::TColumnSchema().Name(F_SESSION_PART).Type(NYT::VT_UINT64).SortOrder(NYT::SO_ASCENDING));
    datasetTableSchema.AddColumn(NYT::TColumnSchema().Name(F_PARTITION_ID).Type(NYT::VT_UINT64));
    datasetTableSchema.AddColumn(NYT::TColumnSchema().Name(F_HOSTS).Type(NYT::VT_STRING));

    // Schema of the compressed table that DownloadDataset later reads by PartitionId range.
    NYT::TTableSchema compressedDatasetTableSchema;
    compressedDatasetTableSchema.Strict(true);
    compressedDatasetTableSchema.AddColumn(NYT::TColumnSchema().Name(F_PARTITION_ID).Type(NYT::VT_UINT64).SortOrder(NYT::SO_ASCENDING));
    compressedDatasetTableSchema.AddColumn(NYT::TColumnSchema().Name(F_COMPRESSED_CHUNK_NO).Type(NYT::VT_UINT64).SortOrder(NYT::SO_ASCENDING));
    compressedDatasetTableSchema.AddColumn(NYT::TColumnSchema().Name(F_COMPRESSED_DATA).Type(NYT::VT_STRING));

    TOpRunner runner(tx);

    // Days whose input table does not exist are silently skipped.
    for (TDate curDate = startDate; curDate <= endDate; ++curDate) {
        const TString inputTable = NYTUtils::JoinPath(sessionsRoot, curDate.ToStroka(FORMAT), "clean");
        if (tx->Exists(inputTable)) {
            runner.InputYaMR(inputTable);
        }
    }

    // The chain below is order-sensitive: each Reduce/Sort call uses the settings accumulated before it.
    runner
        .OutputNode(NYT::TRichYPath(config.TABLE_HOST2VEC_TRAIN_DATASET).Schema(datasetTableSchema))
        .File(config.FILE_MIRRORS_TRIE_YT_PATH)
        .MemoryLimit(12_GBs)
        .ReduceBy(F_KEY)
        .SortBy(F_KEY, F_SUBKEY)
        .Reduce(new TParseBrowserLogsReduce(config.FILE_MIRRORS_TRIE_DISK_PATH))

        .SortBy(F_PARTITION_ID, F_SESSION_ID, F_SESSION_PART)
        .Sort(config.TABLE_HOST2VEC_TRAIN_DATASET)

        .InputNode(config.TABLE_HOST2VEC_TRAIN_DATASET)
        .OutputNode(NYT::TRichYPath(GetCompressed(config.TABLE_HOST2VEC_TRAIN_DATASET)).Schema(compressedDatasetTableSchema))
        .ReduceBy(F_PARTITION_ID)
        .SortBy(F_PARTITION_ID, F_SESSION_ID, F_SESSION_PART)
        .MaxRowWeight(128_MBs)
        .MemoryLimit(1_GBs)
        .Reduce(new TCompressPoolReduce)
    ;

    tx->Commit();
}

// Downloads the compressed dataset table into local files "0000.gz", "0001.gz", ...
//   client         - YT client used to read the table
//   table          - path of the compressed dataset table (keyed by PartitionId)
//   metaPartitions - requested number of output files; must be in [1..COMPRESS_PARTITIONS]
// The COMPRESS_PARTITIONS buckets are split into consecutive ranges of
// COMPRESS_PARTITIONS / metaPartitions buckets each. When the division is not exact,
// one extra file is produced for the remainder. Ranges are fetched on a pool of 4 threads.
// Throws yexception when metaPartitions is outside the valid range.
void DownloadDataset(NYT::IClientBasePtr client, const TString &table, const ui32 metaPartitions) {
    const ui32 totalPartitions = COMPRESS_PARTITIONS;

    // A zero value would divide by zero below; a value above the bucket count would give
    // a zero step and an endless loop.
    if (metaPartitions == 0 || metaPartitions > totalPartitions) {
        throw yexception() << "metaPartitions must be within [1.." << totalPartitions << "], got " << metaPartitions;
    }

    LOG_INFO("downloading dataset");

    THolder<IThreadPool> queue(CreateThreadPool(4));

    const ui32 step = totalPartitions / metaPartitions;
    ui32 metaPartitionId = 0;
    for (ui32 startPartitionId = 0; startPartitionId < totalPartitions; startPartitionId += step, ++metaPartitionId) {
        const ui32 endPartitionId = startPartitionId + step;

        // Everything is captured by value except client, which is referenced:
        // queue->Stop() below is called before this function returns.
        queue->SafeAddFunc(
            [=, &client]() {
                LOG_INFO("downloading range %u [%u..%u]", metaPartitionId, startPartitionId, endPartitionId);
                NYT::TRichYPath path(table);

                // Key range over the leading PartitionId column.
                path.AddRange(NYT::TReadRange()
                    .LowerLimit(NYT::TReadLimit().Key(startPartitionId))
                    .UpperLimit(NYT::TReadLimit().Key(endPartitionId))
                );

                TUnbufferedFileOutput fo(Sprintf("%04d.gz", static_cast<int>(metaPartitionId)));
                auto reader = client->CreateTableReader<NYT::TNode>(path);
                for (; reader->IsValid(); reader->Next()) {
                    const NYT::TNode &row = reader->GetRow();
                    const TString data = row[F_COMPRESSED_DATA].AsString();
                    fo.Write(data.data(), data.size());
                }
                fo.Finish();

                LOG_INFO("downloading range %u [%u..%u] - done", metaPartitionId, startPartitionId, endPartitionId);
            }
        );
    }

    queue->Stop();

    LOG_INFO("downloading dataset - done");
}

void UpdateMirrorsTrie(const TConfig &config, NYT::IClientBasePtr srcClient, NYT::IClientBasePtr dstClient) {
    const char *ATTR_MIRRORS_SOURCE = "MirrorsSource";
    const TString jupiterMirrorsTriePath = GetJupiterMirrorsTrieInProdFile(srcClient);

    try {
        if (dstClient->Exists(config.FILE_MIRRORS_TRIE_YT_PATH) && NYTUtils::GetAttr(dstClient, config.FILE_MIRRORS_TRIE_YT_PATH, ATTR_MIRRORS_SOURCE).AsString() == jupiterMirrorsTriePath) {
            LOG_INFO("mirrors will not be updated: %s to %s", jupiterMirrorsTriePath.data(), config.FILE_MIRRORS_TRIE_YT_PATH.data());
            return;
        }
    } catch (yexception &e) {
        LOG_WARN("updating mirrors: %s", e.what());
    }

    LOG_INFO("updating mirrors: %s to %s", jupiterMirrorsTriePath.data(), config.FILE_MIRRORS_TRIE_YT_PATH.data());
    NYTUtils::DownloadFile(srcClient, jupiterMirrorsTriePath, config.FILE_MIRRORS_TRIE_DISK_PATH);
    NYT::ITransactionPtr tx = dstClient->StartTransaction();
    NYTUtils::UploadFile(dstClient, config.FILE_MIRRORS_TRIE_DISK_PATH, config.FILE_MIRRORS_TRIE_YT_PATH);
    NYTUtils::SetAttr(dstClient, config.FILE_MIRRORS_TRIE_YT_PATH, ATTR_MIRRORS_SOURCE, jupiterMirrorsTriePath);
    tx->Commit();
    LOG_INFO("updating mirrors: %s to %s -done", jupiterMirrorsTriePath.data(), config.FILE_MIRRORS_TRIE_YT_PATH.data());
}

void TaskBuildHost2vecDatasetSpyLog(const TConfig &config) {
    const int DATASET_SHARDS = 8;
    NYT::IClientPtr jupiterClient = NYT::CreateClient(config.MR_SERVER_HOST_JUPITER);
    NYT::IClientPtr logsClient = NYT::CreateClient(config.MR_SERVER_HOST_LOGS);
    UpdateMirrorsTrie(config, jupiterClient, logsClient);
    PrepareDataset(config, logsClient, "//user_sessions/pub/spy_log/daily");
    DownloadDataset(logsClient, GetCompressed(config.TABLE_HOST2VEC_TRAIN_DATASET), DATASET_SHARDS);
}

void TaskBuildHost2vecDatasetSimilarGroup(const TConfig &config) {
    const int DATASET_SHARDS = 8;
    NYT::IClientPtr jupiterClient = NYT::CreateClient(config.MR_SERVER_HOST_JUPITER);
    NYT::IClientPtr logsClient = NYT::CreateClient(config.MR_SERVER_HOST_LOGS);
    UpdateMirrorsTrie(config, jupiterClient, logsClient);
    PrepareDataset(config, logsClient, "//user_sessions/pub/similargroup/daily");
    DownloadDataset(logsClient, GetCompressed(config.TABLE_HOST2VEC_TRAIN_DATASET), DATASET_SHARDS);
}

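// Manual commands run outside this tool (presumably between building the dataset and uploading the model):
//pigz -p 4 -c -d *.gz > dataset
//./train -debug 1 -train dataset -binary 2 -window 10 -size 256 -threads 16 -output-vectors vectors -output words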

void TaskUploadHost2vecModel(const TConfig &config, const TString &modelRoot) {
    static_assert(std::is_same<NWord2Vec::TCoordinate, float>::value, "this code is based on w2v float coordinates");

    NYT::IClientPtr logsClient = NYT::CreateClient(config.MR_SERVER_HOST_LOGS);
    NYT::ITransactionPtr tx = logsClient->StartTransaction();
    NYTUtils::CreatePath(tx, modelRoot);
    const TString ytWordsPath = NYTUtils::JoinPath(modelRoot, config.FILE_HOST2VEC_MODEL_WORDS);
    const TString ytVectorsPath = NYTUtils::JoinPath(modelRoot, config.FILE_HOST2VEC_MODEL_VECTORS);
    const TString ytTablePath = NYTUtils::JoinPath(modelRoot, config.FILE_HOST2VEC_MODEL_TABLE);
    LOG_INFO("uploading %s to %s", config.FILE_HOST2VEC_MODEL_WORDS.data(), ytWordsPath.data());
    NYTUtils::UploadFile(tx, config.FILE_HOST2VEC_MODEL_WORDS, ytWordsPath);
    LOG_INFO("uploading %s to %s", config.FILE_HOST2VEC_MODEL_VECTORS.data(), ytVectorsPath.data());
    NYTUtils::UploadFile(tx, config.FILE_HOST2VEC_MODEL_VECTORS, ytVectorsPath);

    LOG_INFO("uploading model to %s", ytTablePath.data());
    NHost2Vec::NProto::TEmbedding msg;
    auto writer = TTable<NHost2Vec::NProto::TEmbedding>(tx, ytTablePath).GetWriter();
    TBlob vectorsBlob = TBlob::FromFileContent(config.FILE_HOST2VEC_MODEL_VECTORS);
    TFileInput wordsStream(config.FILE_HOST2VEC_MODEL_WORDS);
    NWord2Vec::TModel model;
    model.LoadFromYandex(&wordsStream, vectorsBlob);
    for (const auto &obj : model) {
        const auto &embedding = obj.second;
        const char* begin = reinterpret_cast<const char*>(embedding.begin());
        const char* end = reinterpret_cast<const char*>(embedding.end());
        const TString embeddingStr(begin, end);
        msg.Sethost(WideToUTF8(obj.first));
        msg.Setvector(embeddingStr);
        writer->AddRow(msg);
    }
    writer->Finish();

    LOG_INFO("uploading done");
    tx->Commit();
}

// Task: upload the model into the spy_log model root.
void TaskUploadHost2vecModelSpyLog(const TConfig &config) {
    const auto &modelRoot = config.FILE_HOST2VEC_MODEL_SPY_LOG_ROOT;
    TaskUploadHost2vecModel(config, modelRoot);
}

// Task: upload the model into the similargroup model root.
void TaskUploadHost2vecModelSimilarGroup(const TConfig &config) {
    const auto &modelRoot = config.FILE_HOST2VEC_MODEL_SIMILARGROUP_ROOT;
    TaskUploadHost2vecModel(config, modelRoot);
}

// Task "sync": currently only prints the names of the nodes found under
// config.FILE_HOST2VEC_MODEL_SPY_LOG_ROOT to stdout.
// The commented-out fragments below are unfinished development notes.
void TaskSyncModels(const TConfig &config) {
    //TaskUploadHost2vecModel(config, config.FILE_HOST2VEC_MODEL_SIMILARGROUP_ROOT);

    //config.MR_SERVER_HOST_LOGS
    //config.MR_SERVER_HOST_MAIN

    //config.FILE_HOST2VEC_MODEL_SIMILARGROUP_ROOT;
    //config.FILE_HOST2VEC_MODEL_SPY_LOG_ROOT;

    NYT::IClientPtr client = NYT::CreateClient(config.MR_SERVER_HOST_LOGS);

    // One node name per line.
    for (auto &node : client->List(config.FILE_HOST2VEC_MODEL_SPY_LOG_ROOT)) {
        Cout << node.AsString() << Endl;
    }

    //TDeque<TSourceTable> sourceTables;
    //LoadSourceTables(client, TCommonYTConfigSQ::CInstance().TABLE_PARSED_USER_SESSIONS_DAILY_ROOT, sourceTables, 100, TSourceTable::E_FMT_USER_SESSIONS);
/*
    THolder<IThreadPool> processQueue(CreateThreadPool(4));
    for (const TSourceTable &table : sourceTables) {
        processQueue->SafeAddFunc([=]() {
            try {
                LOG_INFO("sort table %s", table.Name.data());
                TSortCmd<NProto::TQuery>(client)
                    .OperationWeight(OPERATION_WEIGHT)
                    .Input(TTable<NProto::TQuery>(client, table.Name))
                    .Output(TTable<NProto::TQuery>(client, table.Name)
                        .SetCompressionCodec(ECompressionCodec::BROTLI_6)
                        .SetErasureCodec(EErasureCodec::LRC_12_2_2)
                    )
                    .By({"Host", "CorrectedQuery", "Path", "RegionId", "IsMobile", "IsPad", "Position", "RequestSource", "ResultSource"})
                    .Do()
                ;
                LOG_INFO("sort table %s - done", table.Name.data());
            } catch (yexception &e) {
                LOG_ERROR("sort table %s error: %s", table.Name.data(), e.what());
            }
        });
    }
    processQueue->Stop();
*/

}

// Task "dev": placeholder entry point that does nothing.
void TaskDev(const TConfig & /*config*/) {
}

static void LogInfo(const TString &msg) {
    LOG_INFO("%s", msg.data());
}

// Runs one task handler with the given configuration.
// Throws yexception when the configuration reports it is not OK.
// Returns 0 on success and 1 when the handler throws a yexception (the error is logged).
int RunTask(const TTaskHandler &taskHandler, const TConfig &config) {
    const bool configOk = config.IsGlobalOk();
    if (!configOk) {
        throw yexception() << "during loading configuration files errors occurred";
    }

    // Route TOpRunner's informational output into this tool's log.
    TOpRunner::LogInfo = LogInfo;

    int exitCode = 0;
    try {
        taskHandler(config);
    } catch (const yexception &e) {
        LOG_ERROR("something went wrong: %s", e.what());
        exitCode = 1;
    }

    return exitCode;
}

// Returns all registered mode names, in map order, joined with " | ".
TString GetAvailableModes(const TMap<TString, TTaskHandler> &taskHandlers) {
    TVector<TString> modeNames;

    for (auto it = taskHandlers.begin(); it != taskHandlers.end(); ++it) {
        modeNames.emplace_back(it->first);
    }

    return JoinSeq(" | ", modeNames);
}

} //namespace NWebmaster

int main(int argc, const char** argv) {
    NYT::Initialize(argc, argv);
    using namespace NWebmaster;

    int res = 1;

    TMap<TString, TTaskHandler> taskHandlers;
    taskHandlers["host2vec_build_train_spy_log"]        = TaskBuildHost2vecDatasetSpyLog;
    taskHandlers["host2vec_build_train_similargroup"]   = TaskBuildHost2vecDatasetSimilarGroup;
    taskHandlers["host2vec_upload_model_spy_log"]       = TaskUploadHost2vecModelSpyLog;
    taskHandlers["host2vec_upload_model_similargroup"]  = TaskUploadHost2vecModelSimilarGroup;
    taskHandlers["sync"]                                = TaskSyncModels;
    taskHandlers["dev"] = TaskDev;

    TArgs::Init(argc, argv);
    TArgs::Opts().SetFreeArgsMax(1);
    TArgs::Opts().SetFreeArgTitle(0, GetAvailableModes(taskHandlers), " ");
    auto opts = TArgs::ParseOpts();
    TVector<TString> freeArgs = opts->GetFreeArgs();

    if (freeArgs.empty()) {
        Cerr << "no mode requested" << Endl;
        return res;
    }

    const TString RequestedMode = freeArgs.front();

    LOG_INFO("Started");

    try {
        Cerr << "------ " << Now() << " ------" << Endl;

        TConfig config;
        config.Load();

        TPeriodicLog::SetDefaultHandler(config.PERIODIC_LOG_HANDLER);

        const auto modeIt = taskHandlers.find(RequestedMode);
        if (modeIt == taskHandlers.end()) {
            LOG_WARN("unknown mode requested: %s", RequestedMode.data());
        } else {
            LOG_INFO("mode requested: %s", RequestedMode.data());
            res = RunTask(modeIt->second, config);
        }
    } catch (const std::exception &e) {
        LOG_CRIT("%s", e.what());
    }

    LOG_INFO("Finished");

    return res;
}
