//
// Created by luckybug on 04.10.17.
//

#include "usage.h"

#include <util/generic/vector.h>
#include <util/generic/list.h>
#include <util/string/split.h>
#include <util/stream/file.h>
#include <util/generic/hash.h>
#include <util/generic/yexception.h>
#include <util/folder/path.h>
#include <library/cpp/json/json_writer.h>
#include <library/cpp/json/json_reader.h>
#include <util/generic/buffer.h>
#include <util/memory/blob.h>
#include <util/random/fast.h>
#include <util/random/shuffle.h>
#include <library/cpp/hnsw/index_builder/index_data.h>
#include <library/cpp/hnsw/index_builder/index_builder.h>
#include <util/stream/buffer.h>
#include <library/cpp/hnsw/index_builder/index_writer.h>
#include <library/cpp/hnsw/index/index_base.h>


#include <mail/so/spamstop/tools/lsa/data/matrix.h>
#include <mail/so/spamstop/tools/lsa/data/dictionary.h>
#include <mail/so/spamstop/tools/lsa/data/theme_tree.h>
#include <mail/so/spamstop/tools/lsa/data/hnsw.h>
#include <mail/so/spamstop/tools/lsa/data/mark.h>
#include <mail/so/libs/syslog/so_log.h>
#include <mapreduce/yt/interface/init.h>
#include <mapreduce/yt/interface/client.h>
#include <library/cpp/hnsw/index_builder/dense_vector_index_builder.h>

// Loads a word-vector dictionary from a JSON file.
//
// The file must contain a JSON array whose elements are objects with:
//   "word"       - string key of the entry (required, Y_VERIFY aborts otherwise);
//   "coordinate" - array of numbers, stored as a 1 x N NLSA::TMatrix (required);
//   "cluster"    - optional unsigned integer, 0 when absent.
// Prints the number of array elements to Cout.
// Throws if the root is not an array or a coordinate is not numeric.
NLSA::TDictionary LoadDictionaryFromJSON(const TFsPath & dictJson) {
    NJson::TJsonValue val;
    {
        // The blob is only needed while parsing; release it right after.
        const TBlob blob = TBlob::FromFile(dictJson);
        Y_VERIFY(NJson::ReadJsonFastTree(TStringBuf(static_cast<const char*>(blob.Data()), blob.Size()), &val));
    }

    // GetArraySafe: a non-array root is an error, not a silently empty dictionary.
    const auto & array = val.GetArraySafe();

    Cout << "source dictionary size: " << array.size() << Endl;

    NLSA::TDictionary dictionary;

    for(const NJson::TJsonValue & el : array)
    {
        TString word;
        {
            NJson::TJsonValue v;
            Y_VERIFY(el.GetValue("word", &v));
            Y_VERIFY(v.GetString(&word));
        }
        NJson::TJsonValue::TArray coordinate;
        {
            NJson::TJsonValue v;
            Y_VERIFY(el.GetValue("coordinate", &v));
            Y_VERIFY(v.GetArray(&coordinate));
        }
        // Optional field: the result of GetUInteger is deliberately ignored, so a value of
        // another type is not treated as an error.
        unsigned long long cluster{};
        {
            NJson::TJsonValue v;
            if(el.GetValue("cluster", &v))
                v.GetUInteger(&cluster);
        }

        NLSA::TMatrix m(1, coordinate.size());
        for(size_t i = 0; i < coordinate.size(); i++)
            m(0, i) = coordinate[i].GetDoubleSafe(); // throws on non-numeric values instead of storing 0

        // Neither local is used after this point, so hand them over instead of copying.
        dictionary.emplace(std::move(word), NLSA::TW2VTrait{std::move(m), cluster});
    }

    return dictionary;
}

// Loads a word-vector dictionary from a word2vec-style text file.
//
// The first line (header) is skipped. Every following line has the form
// "<word> <v1> <v2> ...", with fields separated by spaces and/or tabs. Blank lines
// are ignored. Every entry gets cluster 0.
// Throws if a vector component is not a valid float.
NLSA::TViewDictionary LoadDictionaryFromTSV(const TFsPath & dictTSV) {
    TFileInput inputTsv(dictTSV);
    inputTsv.ReadLine(); // header line

    NLSA::TViewDictionary dictionary;

    TString line;
    while(inputTsv.ReadLine(line)) {
        // SkipEmpty: repeated separators and trailing whitespace (some writers end every line
        // with a space) must not produce empty tokens, which FromString<float> rejects.
        const auto splitter = StringSplitter(line);
        auto range = splitter.SplitBySet(" \t").SkipEmpty();
        auto it = range.begin();
        if(it == range.end())
            continue; // blank line: nothing to dereference

        const TStringBuf word = it->Token();
        ++it;

        TVector<float> m;
        for(; it != range.end(); ++it)
            m.emplace_back(FromString<float>(it->Token()));

        dictionary.emplace(word, NLSA::TW2VViewTrait{NLSA::TViewMatrixData(std::move(m)), 0});
    }

    return dictionary;
}

// Converts a JSON word-vector dictionary (see LoadDictionaryFromJSON) into the binary
// NLSA::TDictionary file written with ::Save. Prints the number of loaded entries to Cout.
void PrepareBinanryForLSA(const TFsPath & dictJson, const TFsPath & dictDst)
{
    const NLSA::TDictionary dictionary = LoadDictionaryFromJSON(dictJson);
    Cout << dictionary.size() << Endl;

    TFileOutput output(dictDst);
    ::Save(&output, dictionary);
    // The buffered output's destructor does not propagate errors from its final flush;
    // finish explicitly so a failed write is reported instead of leaving a truncated file.
    output.Finish();
}

// Converts a word2vec text dictionary (see LoadDictionaryFromTSV) into two outputs:
//   binPath     - the NLSA::TViewDictionary serialized with ::Save;
//   tableTarget - a YT table on the "hahn" cluster with one row per word, having the
//                 columns "word" and "coordinate" (list of the vector components).
// Prints the number of loaded entries to Cout.
void W2VTsvToBin(const TFsPath & tsv, const TFsPath & binPath, const NYT::TRichYPath& tableTarget)
{
    const NLSA::TViewDictionary dictionary = LoadDictionaryFromTSV(tsv);
    Cout << dictionary.size() << Endl;

    {
        TFileOutput output(binPath);
        ::Save(&output, dictionary);
        // Report errors of the final flush; the destructor would not propagate them.
        output.Finish();
    }
    {
        auto c = NYT::CreateClient("hahn");

        auto writer = c->CreateTableWriter<NYT::TNode>(tableTarget);

        for(const auto& [word, traits]: dictionary){
            auto coordsNode = NYT::TNode::CreateList();
            auto& coords = coordsNode.AsList();
            for(const auto v : traits.coordinate.data)
                coords.emplace_back(v);
            writer->AddRow(NYT::TNode()("word", word)("coordinate", coordsNode));
        }
        // Commit explicitly so write/commit failures surface as exceptions here
        // rather than being lost when the writer is destroyed.
        writer->Finish();
    }
}

// Loads the theme hierarchy from a JSON file.
//
// The file must hold a JSON array of theme objects, for example:
//     {"id": 0, "parent": 0, "description": "Root", "children": [1, 2, 3, 34, 44]}
// Each theme is added to the returned tree and then read back through GetThemeById
// to check that it round-trips (Y_VERIFY). Every theme is also echoed to Cout.
//
// Throws if the root is not an array, an entry is not an object, any of the four
// fields is missing, or a field has the wrong JSON type.
NLSA::TThemeTree LoadThemes(const TFsPath & themesJson)
{
    NJson::TJsonValue val;
    {
        const TBlob blob = TBlob::FromFile(themesJson);
        Y_VERIFY(NJson::ReadJsonFastTree(TStringBuf(static_cast<const char*>(blob.Data()), blob.Size()), &val));
    }

    NLSA::TThemeTree themeTree;

    for (const auto &t : val.GetArraySafe()) {
        // Validate up front so a malformed entry yields a clear error instead of
        // dereferencing a missing map element.
        Y_ENSURE(t.IsMap(), "theme entry is not a JSON object");
        for (const TStringBuf key : {TStringBuf("id"), TStringBuf("description"), TStringBuf("parent"), TStringBuf("children")}) {
            Y_ENSURE(t.Has(key), "theme entry has no \"" << key << "\" field");
        }

        NLSA::TTheme theme;

        theme.themeId = static_cast<int>(t["id"].GetIntegerSafe());
        theme.description = t["description"].GetStringSafe();
        theme.parentId = static_cast<int>(t["parent"].GetIntegerSafe());
        const auto &children = t["children"].GetArraySafe();
        std::transform(children.cbegin(), children.cend(), std::back_inserter(theme.children),
                       [](const NJson::TJsonValue &v) { return v.GetIntegerSafe(); });

        themeTree.AddTheme(theme);

        Y_VERIFY(*themeTree.GetThemeById(theme.themeId) == theme);

        Cout << theme.themeId << Endl;
        Cout << theme.description << Endl;
        Cout << theme.parentId << Endl;
        for(auto ch : theme.children)
        {
            Cout << ch << ' ';
        }
        Cout << Endl;
        Cout << Endl;
    }

    return themeTree;
}

// Loads per-id network parameters ("webs") from a JSON file.
//
// The root must be an object whose keys are decimal integer ids. Each value must be
// an object with:
//   "intercepts" - array of numeric arrays; each becomes a 1 x N matrix appended to web.offsets;
//   "coefs"      - array of non-empty 2-D numeric arrays whose rows all have the same
//                  length; each becomes a rows x cols matrix appended to web.layers.
// Throws on malformed input: non-object entries, missing keys, non-numeric values,
// empty or ragged "coefs" matrices, or ids that are not integers.
NLSA::TMlByIds LoadML(const TFsPath & mlJson)
{
    NJson::TJsonValue val;

    {
        TMappedFileInput in(mlJson);
        Y_VERIFY(NJson::ReadJsonTree(&in, &val));
    }

    NLSA::TMlByIds websById;

    for(const auto & [rawId, ml] : val.GetMapSafe())
    {
        const auto id = FromString<int>(rawId);

        // Has() is false for non-objects too, so this also rejects non-object entries.
        Y_ENSURE(ml.Has("intercepts") && ml.Has("coefs"),
                 "ml entry " << rawId << " must be an object with \"intercepts\" and \"coefs\"");

        NLSA::TWeb web;

        for(const auto & rawArray : ml["intercepts"].GetArraySafe())
        {
            const auto & array = rawArray.GetArraySafe();

            NLSA::TMatrix m(1, array.size());
            for(size_t i = 0; i < array.size(); i++)
                m(0, i) = array[i].GetDoubleSafe();

            web.offsets.emplace_back(std::move(m));
        }

        for(const auto & rawMatrix : ml["coefs"].GetArraySafe()) {
            const auto & matrix = rawMatrix.GetArraySafe();

            // The width is taken from the first row, so there must be one.
            Y_ENSURE(!matrix.empty(), "empty coefs matrix in ml entry " << rawId);

            const size_t rows = matrix.size();
            const size_t cols = matrix[0].GetArraySafe().size();

            NLSA::TMatrix m(rows, cols);
            for(size_t y = 0; y < rows; y++) {
                const auto & array = matrix[y].GetArraySafe();
                // A shorter row would be indexed past its end below.
                Y_ENSURE(array.size() == cols,
                         "ragged coefs matrix in ml entry " << rawId << ": row " << y
                         << " has " << array.size() << " values, expected " << cols);

                for(size_t x = 0; x < cols; x++)
                    m(y, x) = array[x].GetDoubleSafe();
            }
            web.layers.emplace_back(std::move(m));
        }

        websById.emplace(id, std::move(web));
    }

    return websById;
}

void PrepareThemesForLSA(const TFsPath & themesJson, const TFsPath & dstBin)
{
    const NLSA::TThemeTree & tree = LoadThemes(themesJson);

    {
        TUnbufferedFileOutput output(dstBin);

        ::Save(&output, tree);
    }
}

void PrepareML(const TFsPath & mlJson, const TFsPath & dstBin)
{
    const NLSA::TMlByIds & ml = LoadML(mlJson);

    {
        TUnbufferedFileOutput output(dstBin);

        ::Save(&output, ml);
    }
}

// Loads labelled messages ("marks") from a JSON file holding an array of objects.
//
// Entries lacking any of "rcpt", "mid" or "target" are skipped. "rcpt" must be an
// unsigned integer, "mid" a string and "target" an array of integers. The optional
// "body", "subject", "fromaddr" and "fromname" fields are arrays of strings appended
// to the matching TMark members.
// Throws if the root is not an array or a present field has the wrong JSON type.
TVector<NLSA::TMark> LoadMarks(const TFsPath & jsonPath)
{
    NJson::TJsonValue val;
    {
        const TBlob blob = TBlob::FromFile(jsonPath);
        Y_VERIFY(NJson::ReadJsonFastTree(TStringBuf(static_cast<const char*>(blob.Data()), blob.Size()), &val));
    }

    const auto & jsonArray = val.GetArraySafe();

    // Appends every string of the optional array field `key` of `jsonMark` to `dst`.
    const auto appendStrings = [](const NJson::TJsonValue & jsonMark, TStringBuf key, auto & dst) {
        if(!jsonMark.Has(key))
            return;
        for(const auto & v : jsonMark[key].GetArraySafe())
            dst.emplace_back(v.GetStringSafe());
    };

    TVector<NLSA::TMark> marks;
    marks.reserve(jsonArray.size());

    for(const auto & jsonMark : jsonArray)
    {
        if(!jsonMark.Has("rcpt") || !jsonMark.Has("mid") || !jsonMark.Has("target"))
            continue;

        NLSA::TMark mark;
        mark.rcpt = jsonMark["rcpt"].GetUIntegerSafe();
        mark.mid = jsonMark["mid"].GetStringSafe();

        for(const auto & v : jsonMark["target"].GetArraySafe())
            mark.targets.emplace(v.GetIntegerSafe());

        appendStrings(jsonMark, "body", mark.body);
        appendStrings(jsonMark, "subject", mark.subject);
        appendStrings(jsonMark, "fromaddr", mark.fromaddr);
        appendStrings(jsonMark, "fromname", mark.fromname);

        // `mark` is not used afterwards: move it instead of copying all of its strings.
        marks.emplace_back(std::move(mark));
    }

    return marks;
}

// Builds an HNSW nearest-neighbour index over the word vectors of the words listed in a
// YT table and writes it to a file.
//
//   w2vbin   - dictionary written by W2VTsvToBin (NLSA::TViewDictionary); normalized with
//              NLSA::NormalizeDict after loading.
//   srcTable - YT table on the "hahn" cluster whose rows have a string "word" column;
//              words missing from the dictionary are skipped.
//   dstBin   - receives the serialized NLSA::THnswContext (index bytes + item storage).
//
// Throws if none of the table's words is found in the dictionary, or on I/O errors.
void PrepareHNSW(
        const TFsPath & w2vbin,
        const NYT::TRichYPath& srcTable,
        const TFsPath & dstBin)
{
    NLSA::TViewDictionary dict;
    {
        TMappedFileInput f(w2vbin);
        ::Load(&f, dict);
        NLSA::NormalizeDict(dict);
    }

    Cout << "loaded dict: " << dict.size() << Endl;

    TVector<NLSA::TItemStorage::TItem> coordinates;
    {
        auto c = NYT::CreateClient("hahn");

        for(auto it = c->CreateTableReader<NYT::TNode>(srcTable); it->IsValid(); it->Next()){
            const auto& word = it->GetRow()["word"].AsString();

            auto w2vIt = dict.find(word);
            if(dict.cend() == w2vIt)
                continue;
            coordinates.emplace_back(w2vIt->second.coordinate);
        }
    }

    // An index over zero items is useless to its consumers; fail loudly instead of writing one.
    Y_ENSURE(!coordinates.empty(), "none of the words from the source table were found in the w2v dictionary");

    NLSA::TItemStorage itemStorage(std::move(coordinates));

    const auto opts = NLSA::BuildOptions(itemStorage.GetNumItems());
    const auto indexData = NHnsw::BuildIndex<NLSA::TDistance>(opts, itemStorage);

    TBufferOutput bufferOutput;
    NHnsw::WriteIndex(indexData, bufferOutput);

    NLSA::THnswContext hnswContext(bufferOutput.Buffer(), itemStorage);

    TFileOutput output(dstBin);
    ::Save(&output, hnswContext);
    // Report errors of the final flush; the destructor would not propagate them.
    output.Finish();
}

int main(int argc, const char * argv[]) try{
    NYT::Initialize(argc, argv);
    argc --;
    argv ++;

    if(argc < 1)
        ythrow yexception() << "usage";

    const auto usage = FromString<EUsage>(argv[0]);

    argc --;
    argv ++;

    switch(usage){
        case EUsage::w2vbin:
            W2VTsvToBin(argv[0], argv[1], argv[2]);
            break;
        case EUsage::w2v:
            PrepareBinanryForLSA(argv[0], argv[1]);
            break;
        case EUsage::themes:
            PrepareThemesForLSA(argv[0], argv[1]);
            break;
        case EUsage::ml:
            PrepareML(argv[0], argv[1]);
            break;
        case EUsage::hnsw:
            PrepareHNSW(argv[0], argv[1], argv[2]);
            break;
    }
} catch(...) {
    Cerr << CurrentExceptionMessageWithBt() << Endl;
    return 1;
}
