
#include <library/cpp/vowpalwabbit/vowpal_wabbit_model.h>
#include <library/cpp/vowpalwabbit/vowpal_wabbit_predictor.h>
#include <library/cpp/accurate_accumulate/accurate_accumulate.h>
#include <mapreduce/yt/interface/operation.h>
#include <mapreduce/yt/interface/client.h>

#include <util/string/split.h>

#include <util/ysaveload.h>

#include <utility>
#include <util/generic/hash_set.h>
#include <util/generic/iterator_range.h>
#include <util/string/join.h>

#include <mail/so/spamstop/tools/lsa/data/clusterify.h>
#include <util/generic/ymath.h>

#include "launch_mode.h"

/// One "field:ns" rule: the comma-separated words of string column `field`
/// are emitted into Vowpal Wabbit namespace `ns` by TRowProcessor::Process.
struct TFieldNamespace{
    TString field; ///< input record column to read words from
    TString ns;    ///< namespace name the words are emitted into

    /// Combined length of both strings; feeds the memory-consumption estimate.
    size_t Size() const {
        return ns.size() + field.size();
    }

    // Save/Load order (field, ns) must stay fixed for serialized-state compatibility.
    Y_SAVELOAD_DEFINE(field, ns);
};

/// One "crossField:field:ns" rule: for every word of column `field`,
/// TRowProcessor::Process also emits "<value of column crossField>_<word>".
/// NOTE(review): `ns` is parsed and serialized but TRowProcessor::Process does not read it;
/// crossed features land in the namespace of the matching TFieldNamespace.
struct TCrossFieldNamespace{
    TString crossField; ///< column whose string value prefixes each word
    TString field;      ///< column whose words are crossed
    TString ns;         ///< namespace given in the rule (currently unused, see note above)

    /// Combined length of the three strings; feeds the memory-consumption estimate.
    size_t Size() const {
        size_t total = crossField.size();
        total += field.size();
        total += ns.size();
        return total;
    }

    // Save/Load order (crossField, field, ns) must stay fixed for serialized-state compatibility.
    Y_SAVELOAD_DEFINE(crossField, field, ns);
};

class TRowProcessor{
public:
    struct TResult{
        TVector<std::pair<TString, const NYT::TNode *>> forwards;
        TVector<std::pair<TString, TVector<TString>>> nsWithHashes;
    };

    TResult Process(const NYT::TNode & node) const {
        const auto & record = node.AsMap();

        TResult result;

        for(const auto & forwardField : record)
            result.forwards.emplace_back(forwardField.first, &forwardField.second);

        for(const auto & fieldNs : fieldNamepsaces)
        {
            auto it = record.find(fieldNs.field);
            if(record.cend() == it || !it->second.IsString() || it->second.AsString().empty())
                continue;

            TVector<TStringBuf> words(Reserve(200));
            Split(it->second.AsString(), ",", words);

            TVector<TString> resultWords;

            for(const auto & word : words) {
                resultWords.emplace_back(TString{word});

                for(const auto & crossFieldNs : crossNamespaces) {
                    if(crossFieldNs.field != fieldNs.field)
                        continue;

                    it = record.find(crossFieldNs.crossField);
                    if(record.cend() == it || !it->second.IsString())
                        continue;

                    const auto & crossField = it->second.AsString();

                    resultWords.emplace_back(crossField + '_' + word);
                }
            }
            result.nsWithHashes.emplace_back(fieldNs.ns, std::move(resultWords));
        }

        return result;
    }

    size_t Consumption() const {
        size_t size = 1024;

        for(const auto & f : crossNamespaces)
            size += f.Size();
        for(const auto & f : fieldNamepsaces)
            size += f.Size();

        return size * 2;
    }

    TRowProcessor() = default;
    TRowProcessor(TRowProcessor&&) = default;
    TRowProcessor(const TString & fieldNS, const TString & rawCrossFields) {
        if(!rawCrossFields.empty())
        {
            TVector<TString> crossFields;
            Split(rawCrossFields, ",", crossFields);

            for(const auto & fieldWithNamespace : crossFields) {
                TCrossFieldNamespace & fieldNamepsace = crossNamespaces.emplace_back();
                Split(fieldWithNamespace, ':', fieldNamepsace.crossField, fieldNamepsace.field, fieldNamepsace.ns);
            }
        }
        {
            TVector<TStringBuf> rawFieldNamepsaces;
            Split(fieldNS, ",", rawFieldNamepsaces);

            for(const auto & rawFieldNamepsace : rawFieldNamepsaces)
            {
                TFieldNamespace & fieldNamepsace = fieldNamepsaces.emplace_back();
                Split(rawFieldNamepsace, ':', fieldNamepsace.field, fieldNamepsace.ns);
            }
        }
    }
    Y_SAVELOAD_DEFINE(fieldNamepsaces, crossNamespaces);
private:
    TVector<TFieldNamespace> fieldNamepsaces;
    TVector<TCrossFieldNamespace> crossNamespaces;
};



/// YT mapper with two modes:
///  * TLaunchMode::Apply: scores every row with the Vowpal Wabbit model held in `data`. It
///    forwards all input columns and adds "const_weight", "clusters", "vw_w" (raw score) and
///    "vw_w_log" (sigmoid of the raw score).
///  * TLaunchMode::Prepare: turns every row into a VW text line in column "vw_row". The label
///    is derived from `targetField` / `positiveValues`.
class TVWMapper : public NYT::IMapper<NYT::TTableReader<NYT::TNode>, NYT::TTableWriter<NYT::TNode>>
{
public:
    /// Rough memory estimate (bytes); main() uses it to size the job's memory limit.
    size_t Consumption() const {
        size_t size = 1024;
        size += data.size();
        size += sizeof(mode);
        size += rowProcessor.Consumption();

        return size * 2;
    }

    /// Job-side initialization. In Apply mode, builds the model over the bytes in `data`
    /// (TBlob::NoCopy, so `data` must outlive `model`) and a predictor over that model.
    void Start(TWriter*) override {
        if(mode == TLaunchMode::Apply) {
            model = MakeHolder<NVowpalWabbit::TModel>(TBlob::NoCopy(data.data(), data.size()));
            predictor = MakeHolder<NVowpalWabbit::TPredictor<NVowpalWabbit::TModel>>(*model);
        }
    }

    void Do(NYT::TTableReader<NYT::TNode>* input, NYT::TTableWriter<NYT::TNode>* output) override
    {
        for (; input->IsValid(); input->Next()) {
            const auto & record = input->GetRow();
            auto prepared = rowProcessor.Process(record);

            NYT::TNode result;

            switch(mode) {
                case TLaunchMode::Apply:{
                    for(const auto & forwardField : prepared.forwards) {
                        result(forwardField.first, *forwardField.second);
                    }

                    const auto constWeight = predictor->GetConstPrediction();

                    // Kahan summation keeps the sum of many small weights accurate.
                    TKahanAccumulator<double> prediction(constWeight);
                    result("const_weight", prediction.Get());

                    // Weights of every feature of the row, in namespace order; input for clustering.
                    // (Per-namespace "<ns>_weights" columns are not emitted.)
                    TVector<double> allPredictions;
                    for(const auto & fieldNs : prepared.nsWithHashes) {
                        // NOTE(review): the literal 2 is presumably the n-gram order for hashing;
                        // confirm in vowpal_wabbit_predictor.h.
                        TVector<ui32> hashes(Reserve(fieldNs.second.size() * 2));
                        predictor->CalcHashes(fieldNs.first, fieldNs.second, 2, hashes);

                        for(const auto hash : hashes) {
                            const auto p = (*model)[hash];
                            prediction += p;
                            allPredictions.emplace_back(p);
                        }
                    }

                    {
                        auto clusters = ClusterifyValues<3>(allPredictions.cbegin(), allPredictions.cend());

                        for(auto& p: clusters)
                            p = Sigmoid(p + constWeight);

                        result("clusters", JoinSeq(",", clusters));
                    }

                    const auto resolution = Sigmoid(prediction.Get());

                    result("vw_w", prediction.Get());
                    result("vw_w_log", resolution);

                    break;
                }
                case TLaunchMode::Prepare: {
                    const auto & recordMap = record.AsMap();

                    const auto targetIt = recordMap.find(targetField);
                    if(recordMap.cend() == targetIt)
                        continue; // no target column: the row is dropped

                    TString target;
                    switch(targetIt->second.GetType()){
                        // NOTE(review): map, null and list targets are not rejected. They leave
                        // `target` empty, so the row is labelled by whether "" is a positive value
                        // (normally -1). Confirm that null targets should not be skipped instead.
                        case NYT::TNode::Map:break;
                        case NYT::TNode::Null:break;
                        case NYT::TNode::List:break;
                        case NYT::TNode::Undefined:
                            continue;
                        case NYT::TNode::String:
                            target = targetIt->second.AsString();
                            break;
                        case NYT::TNode::Int64:
                            target = ToString(targetIt->second.AsInt64());
                            break;
                        case NYT::TNode::Uint64:
                            target = ToString(targetIt->second.AsUint64());
                            break;
                        case NYT::TNode::Double:
                            target = ToString(targetIt->second.AsDouble());
                            break;
                        case NYT::TNode::Bool:
                            target = ToString(targetIt->second.AsBool());
                            break;
                    }

                    // Line header: label, importance 1.0, base 0.0.
                    TStringStream vwRow;
                    vwRow << (positiveValues.contains(target) ? "1" : "-1") << " 1.0 0.0";

                    for(const auto & fieldNs : prepared.nsWithHashes) {
                        if(fieldNs.second.empty())
                            continue;

                        vwRow << " |" << fieldNs.first;

                        for(const auto & word : fieldNs.second)
                            vwRow << ' ' << word;
                    }

                    // NOTE(review): this appends the two characters '\' and 'n', not a line break.
                    // It is kept unchanged in case the export step relies on it; confirm intent.
                    vwRow << "\\n";

                    result("vw_row", vwRow.Str());
                    break;
                }
            }

            output->AddRow(result);
        }
    }

    Y_SAVELOAD_JOB(mode, data, rowProcessor, targetField, positiveValues);

    TVWMapper() = default;

    /// @param data              serialized model bytes (used in Apply mode only)
    /// @param rowProcessor      feature-extraction rules
    /// @param targetField       label column (Prepare mode only)
    /// @param rawPositiveValues comma-separated label values treated as positive (Prepare mode only)
    explicit TVWMapper(TLaunchMode mode, TVector<char> data, TRowProcessor rowProcessor, TString targetField, const TString & rawPositiveValues)
            : mode(mode), data(std::move(data)), rowProcessor(std::move(rowProcessor)), targetField(std::move(targetField)) {

        TVector<TString> tmp;
        Split(rawPositiveValues, ",", tmp);
        positiveValues.insert(tmp.cbegin(), tmp.cend());
    }
private:
    TLaunchMode mode{};
    TVector<char> data;

    // Built in Start(); not serialized.
    THolder<NVowpalWabbit::TModel> model;
    THolder<NVowpalWabbit::TPredictor<NVowpalWabbit::TModel>> predictor;


    TRowProcessor rowProcessor;
    TString targetField;
    THashSet<TString> positiveValues;
};

// NOTE(review): presumably registers TVWMapper with the YT C++ wrapper, so that the job side of the
// Map operation started in main() can instantiate it. Confirm in mapreduce/yt/interface.
REGISTER_MAPPER(TVWMapper)

int main(int argc, const char ** argv)
{
    NYT::Initialize(argc, argv);

    const auto & progName = argv[0];
    argc --;
    argv ++;

    if(argc == 0) {
        Cerr << "Usage: " << progName << " mode" << Endl;
        return -1;
    }

    TString targetField;
    TString rawPositiveValues;
    TString fieldNS;
    TString pathToBinary;
    TString srcTable;
    TString dstTable;
    TString crossFields;
    const auto mode = FromString<TLaunchMode>(argv[0]);


    switch(mode) {

        case TLaunchMode::Apply:
            if(argc < 5) {
                Cerr << "Usage: " << progName << " mode (field:namespace,)* binary_model src_table (name:type,)* dst_table (cross:name:namespace,)*" << Endl;
                return -2;
            }
            fieldNS = argv[1];
            pathToBinary = argv[2];
            srcTable = argv[3];
            dstTable = argv[4];
            crossFields = argc >= 5 ? argv[5] : "";
            break;
        case TLaunchMode::Prepare:
            if(argc < 6) {
                Cerr << "Usage: " << progName << " mode target_field (positive_value,)* (field:namespace,)* src_table (name:type,)* dst_table (cross:name:namespace,)*" << Endl;
                return -3;
            }
            targetField = argv[1];
            rawPositiveValues = argv[2];
            fieldNS = argv[3];
            srcTable = argv[4];
            dstTable = argv[5];
            crossFields = argc >= 6 ? argv[6] : "";
            break;
    }

    auto client = NYT::CreateClient("hahn");

    NYT::TRichYPath dstTableYPath(dstTable);
    if(mode == TLaunchMode::Apply) {
        dstTableYPath.OptimizeFor(NYT::OF_SCAN_ATTR);
        NYT::TTableSchema schema;
        schema.Strict(false);

        const auto& srcScheme = client->Get(srcTable + "/@schema");
        for(const auto& val : srcScheme.AsList()) {
            if(val.HasKey("sort_order"))
                schema.AddColumn(
                        val["name"].AsString(),
                        FromString<NYT::EValueType>("VT_" + to_upper(val["type"].AsString())),
                        FromString<NYT::ESortOrder>(val["sort_order"].AsString()));
            else
                schema.AddColumn(
                        val["name"].AsString(),
                        FromString<NYT::EValueType>("VT_" + to_upper(val["type"].AsString()))
                );
        }
        schema
            .AddColumn("vw_w", NYT::VT_DOUBLE)
            .AddColumn("vw_w_log", NYT::VT_DOUBLE)
            .AddColumn("clusters", NYT::VT_STRING);

        dstTableYPath.Schema(schema);
    }

    THolder<TVWMapper> mapper;

    switch(mode) {
        case TLaunchMode::Apply: {
            const auto data = TBlob::FromFile(pathToBinary);
            mapper = MakeHolder<TVWMapper>(mode, TVector<char>((char*)data.Data(), (char*)data.Data() + data.Size()), TRowProcessor{fieldNS, crossFields}, "", "");
            break;
        }
        case TLaunchMode::Prepare:
            mapper = MakeHolder<TVWMapper>(mode, TVector<char>{}, TRowProcessor{fieldNS, crossFields}, targetField, rawPositiveValues);
            break;
    }

    NYT::TUserJobSpec mapSpec;

    if(mode == TLaunchMode::Apply)
        mapSpec.MemoryLimit(std::max(size_t(128)*1024*1024, mapper->Consumption()));

    NYT::TMapOperationSpec spec;
    spec
            .AddInput<NYT::TNode>(srcTable)
            .AddOutput<NYT::TNode>(dstTableYPath)
            .MapperSpec(mapSpec);
    client->Map(spec, mapper.Release());

    return 0;
}
