#pragma once
#include <crypta/graph/mrcc_opt/proto/messages.pb.h>
#include <crypta/graph/mrcc_opt/lib/edge.h>
#include <crypta/graph/mrcc_opt/lib/time.h>
#include <crypta/graph/mrcc_opt/lib/yt.h>
#include <mapreduce/yt/interface/client.h>
#include <mapreduce/yt/interface/operation.h>
#include <mapreduce/yt/library/operation_tracker/operation_tracker.h>
#include <mapreduce/yt/util/ypath_join.h>
#include <util/folder/path.h>
#include <util/generic/hash.h>
#include <util/generic/hash_set.h>
#include <util/generic/vector.h>
#include <util/string/join.h>
#include <util/string/cast.h>
#include <util/digest/murmur.h>
#include <iostream>

using namespace NYT;

namespace NConnectedComponents {
    using NYT::IMapper;
    using NYT::IReducer;
    using NYT::TTableReader;
    using NYT::TTableWriter;

    /// Returns the 64-bit MurmurHash of the bytes of `id`.
    ///
    /// Declared `inline`: this function is defined in a header (`#pragma once`),
    /// and a non-inline definition would produce duplicate symbols as soon as two
    /// translation units include the header.
    inline ui64 GetHash(const TString& id) {
        return MurmurHash<ui64>(id.c_str(), id.size());
    }

    // Builds an id of type T from the already-stringified values of the id
    // columns. Only declared here; the explicit specializations for TString and
    // ui64 are defined further down in this file.
    template <typename T>
    T GetID(const TVector<TString>& values);

    /// Builds an id of type T from the columns `fields` of `row`.
    ///
    /// Every listed column is converted to TString via TNode::ConvertTo, in the
    /// order given by `fields`. The resulting list is then handed to the
    /// GetID<T>(values) overload.
    template <typename T>
    T GetID(const NYT::TNode& row, const TVector<TString>& fields) {
        TVector<TString> parts;
        parts.reserve(fields.size());
        for (size_t idx = 0; idx < fields.size(); ++idx) {
            const NYT::TNode& cell = row[fields[idx]];
            parts.emplace_back(cell.ConvertTo<TString>());
        }
        return GetID<T>(parts);
    }

    /// String id: the column values joined with '|'.
    ///
    /// NOTE(review): values that themselves contain '|' can collide, e.g.
    /// {"a|b", "c"} and {"a", "b|c"} both give "a|b|c".
    ///
    /// `inline` is required: an explicit specialization of a function template
    /// is not implicitly inline, and this definition lives in a header.
    template<>
    inline TString GetID<TString>(const TVector<TString>& values) {
        return JoinSeq("|", values);
    }

    /// Numeric id: GetHash of the '|'-joined string id (see GetID<TString>).
    ///
    /// Distinct string ids may hash to the same value; such vertices become
    /// indistinguishable downstream.
    ///
    /// `inline` is required for the same ODR reason as the TString specialization.
    template<>
    inline ui64 GetID<ui64>(const TVector<TString>& values) {
        auto id = GetID<TString>(values);
        return GetHash(id);
    }

    /// Builds an MRCC edge for the pair (`source`, `destination`) in normalized
    /// orientation.
    ///
    /// Source receives `source` only when `source < destination`; otherwise the
    /// two ids are swapped. Both orientations of the same pair therefore yield
    /// the same edge.
    template <typename TIdType>
    typename TTypedGraphEdge<TIdType>::Type PrepareOutputEdge(const TIdType& source, const TIdType& destination) {
        const bool keepOrder = source < destination;
        const TIdType& first = keepOrder ? source : destination;
        const TIdType& second = keepOrder ? destination : source;

        typename TTypedGraphEdge<TIdType>::Type edge;
        edge.SetSource(first);
        edge.SetDestination(second);
        return edge;
    }

    /// Reducer that writes only the first row of each reduce key and ignores
    /// the remaining rows of that key.
    class TExtractFirstRecordReducer: public IReducer<TTableReader<NYT::TNode>, TTableWriter<NYT::TNode>> {
    public:
        void Do(TTableReader<NYT::TNode>* input, TTableWriter<NYT::TNode>* output) override {
            // No IsValid() check: the reader is assumed to be positioned on a
            // valid row when Do() is entered.
            output->AddRow(input->MoveRow());
        }
    };

    /// Reducer that writes at most the first two rows of each reduce key.
    /// In this file it is the middle (reduce-combiner) job of the MapReduce in
    /// ConvertMRCCEdgesToComponents.
    class TExtractFirstTwoRecordsReducer: public IReducer<TTableReader<NYT::TNode>, TTableWriter<NYT::TNode>> {
    public:
        void Do(TTableReader<NYT::TNode>* input, TTableWriter<NYT::TNode>* output) override {
            // The first row is taken unconditionally (the reader is assumed to
            // start on a valid row).
            output->AddRow(input->MoveRow());

            input->Next();
            if (!input->IsValid()) {
                return;
            }
            // Second row, when the key has one.
            output->AddRow(input->MoveRow());
        }
    };

    // Names of temporary columns added to intermediate rows by the jobs in this
    // file:
    //   HASH - the vertex id (GetID result or an edge Destination); it is used
    //          as the reduce key.
    //   EDGE - set to 1 on rows that come from MRCC edge tables; it stays null
    //          on vertex rows.
    //
    // Declared `inline` (C++17) instead of `static`. The in-class Do() methods
    // are inline functions defined in every TU that includes this header. If
    // each TU odr-used its own internal-linkage copy of these objects, that
    // would be an ODR violation.
    namespace NTmpFields {
        inline const TString EDGE("is_edge");
        inline const TString HASH("hash");
    }

    /// Jobs that attach a component id to every extracted vertex.
    ///
    /// Input table 0 holds vertex rows (vertex id columns plus NTmpFields::HASH).
    /// Every other input table holds MRCC edges. The MapReduce built by
    /// ConvertMRCCEdgesToComponents reduces by HASH and sorts by (HASH, EDGE).
    class TGeneralMatchVerticesWithFinalEdges {
    public:
        /// Mapper: forwards vertex rows unchanged. Edge rows are copied, flagged
        /// with EDGE = 1 and keyed by their Destination (written into HASH).
        class TMapper: public IMapper<TTableReader<NYT::TNode>, TTableWriter<NYT::TNode>> {
        public:
            void Do(TTableReader<NYT::TNode>* input, TTableWriter<NYT::TNode>* output) override {
                while (input->IsValid()) {
                    const NYT::TNode& source = input->GetRow();
                    NYT::TNode emitted(source);
                    const bool fromEdgeTable = input->GetTableIndex() != 0;
                    if (fromEdgeTable) {
                        emitted(NTmpFields::EDGE, 1);
                        emitted(NTmpFields::HASH, source[NGraphEdgeFields::DESTINATION]);
                    }
                    output->AddRow(emitted);
                    input->Next();
                }
            }
        };

        /// Reducer: writes one row per HASH key that starts with a vertex row.
        ///
        /// The row holds the vertex id columns (VertexIDFields) and ComponentField.
        /// ComponentField is the Source of the first edge row of the key. If the
        /// key has no edge row, ComponentField is the vertex's own HASH.
        /// Keys whose first row is an edge row produce nothing.
        class TReducer: public IReducer<TTableReader<NYT::TNode>, TTableWriter<NYT::TNode>> {
        public:
            TReducer() {}
            TReducer(const TVector<TString>& vertexIDFields, const TString& componentField)
                : VertexIDFields(vertexIDFields)
                , ComponentField(componentField)
            {
            }
            Y_SAVELOAD_JOB(VertexIDFields, ComponentField);

            void Do(TTableReader<NYT::TNode>* input, TTableWriter<NYT::TNode>* output) override {
                // Expects reduce by HASH and sort by (HASH, EDGE), so that vertex
                // rows (null EDGE) precede edge rows.
                // NOTE(review): this relies on null sorting first; confirm.
                const NYT::TNode& head = input->GetRow();
                const bool headIsEdge = !head[NTmpFields::EDGE].IsNull();
                if (headIsEdge) {
                    return;
                }

                NYT::TNode result;
                for (const TString& idField : VertexIDFields) {
                    result(idField, head[idField]);
                }
                result(ComponentField, head[NTmpFields::HASH]);

                // NOTE(review): the check below tolerates up to two vertex rows
                // per key and throws on the third, although the message suggests
                // a single vertex was intended. Confirm before tightening.
                ui64 position = 0;
                while (input->IsValid()) {
                    const NYT::TNode& current = input->GetRow();
                    if (!current[NTmpFields::EDGE].IsNull()) {
                        result(ComponentField, current[NGraphEdgeFields::SOURCE]);
                        break;
                    }
                    if (position > 1) {
                        throw yexception() << "Too much vertices for one _hash";
                    }
                    input->Next();
                    ++position;
                }
                output->AddRow(result);
            }
        private:
            TVector<TString> VertexIDFields{};
            TString ComponentField{};
        };

    };


    /// MapReduce jobs that prepare input for the MRCC connected-components step.
    /// TIdType is the vertex id type (instantiated for TString and ui64 in this
    /// file).
    template <class TIdType>
    class TMRPrepareOperations {
        public:
        using TGraphEdge = typename NConnectedComponents::TTypedGraphEdge<TIdType>::Type;
        using TEdges = typename NConnectedComponents::TTypedGraphEdge<TIdType>::Edges;


        /// Mapper: turns every input row into one MRCC edge between
        /// GetID(row, FirstIDFields) and GetID(row, SecondIDFields), normalized
        /// by PrepareOutputEdge. Rows whose two ids are equal are skipped.
        class TConvertToMRCCEdgesMapper: public IMapper<TTableReader<NYT::TNode>, TTableWriter<TGraphEdge>> {
        public:
            TConvertToMRCCEdgesMapper() {}
            TConvertToMRCCEdgesMapper(const TVector<TString>& firstIDFields, const TVector<TString>& secondIDFields):
                FirstIDFields(firstIDFields), SecondIDFields(secondIDFields) {
            }
            Y_SAVELOAD_JOB(FirstIDFields, SecondIDFields);
            virtual void Do(TTableReader<NYT::TNode>* input, TTableWriter<TGraphEdge>* output) override {
                for (; input->IsValid(); input->Next()) {
                    const auto& row = input->GetRow();
                    auto source = GetID<TIdType>(row, FirstIDFields);
                    auto destination = GetID<TIdType>(row, SecondIDFields);
                    if (source == destination) {
                        continue;
                    }
                    TGraphEdge out = PrepareOutputEdge(source, destination);
                    output->AddRow(out);
                }
            }
        private:
            TVector<TString> FirstIDFields{};
            TVector<TString> SecondIDFields{};
        };

        /// Mapper: turns every labelled vertex row into one MRCC edge between its
        /// label (ComponentField converted to TIdType) and its own id
        /// GetID(row, VertexIDFields). Vertices labelled with their own id are
        /// skipped.
        class TConvertVerticesToEdges: public IMapper<TTableReader<NYT::TNode>, TTableWriter<TGraphEdge>> {
        public:
            TConvertVerticesToEdges() {}
            TConvertVerticesToEdges(const TVector<TString>& vertexIDFields, const TString& componentField):
                VertexIDFields(vertexIDFields), ComponentField(componentField) {
            }
            Y_SAVELOAD_JOB(VertexIDFields, ComponentField);
            virtual void Do(TTableReader<NYT::TNode>* input, TTableWriter<TGraphEdge>* output) override {
                for (; input->IsValid(); input->Next()) {
                    const auto& row = input->GetRow();

                    auto source = row[ComponentField].template ConvertTo<TIdType>();
                    auto destination = GetID<TIdType>(row, VertexIDFields);
                    if (source == destination) {
                        continue;
                    }
                    auto out = PrepareOutputEdge(source, destination);

                    output->AddRow(out);
                }
            }
        private:
            TVector<TString> VertexIDFields{};
            TString ComponentField{};
        };

        /// Mapper producing candidate vertex rows keyed by NTmpFields::HASH.
        ///
        /// Table 0 holds edge rows. Each one yields two vertex rows, one per
        /// endpoint; VertexIDFields[i] is copied positionally from
        /// FirstIDFields[i] or SecondIDFields[i].
        ///
        /// Other tables hold vertex rows (e.g. previous labels). Only their
        /// VertexIDFields are kept. Every emitted row gets HASH = the
        /// corresponding GetID<TIdType>.
        class TExtractVerticesMapper: public IMapper<TTableReader<NYT::TNode>, TTableWriter<NYT::TNode>> {
        public:
            TExtractVerticesMapper() {}
            TExtractVerticesMapper(const TVector<TString>& firstIDFields, const TVector<TString>& secondIDFields, const TVector<TString>& vertexIDFields):
                FirstIDFields(firstIDFields), SecondIDFields(secondIDFields), VertexIDFields(vertexIDFields) {
            }
            Y_SAVELOAD_JOB(FirstIDFields, SecondIDFields, VertexIDFields);
            virtual void Do(TTableReader<NYT::TNode>* input, TTableWriter<NYT::TNode>* output) override {
                for (; input->IsValid(); input->Next()) {
                    const auto& row = input->GetRow();
                    if (input->GetTableIndex() == 0) {
                        // The loop below reads FirstIDFields[i] and SecondIDFields[i]
                        // for every i < VertexIDFields.size(). Fail loudly on a
                        // misconfiguration instead of reading out of bounds.
                        if (FirstIDFields.size() < VertexIDFields.size() || SecondIDFields.size() < VertexIDFields.size()) {
                            throw yexception()
                                << "TExtractVerticesMapper: FirstIDFields (" << FirstIDFields.size()
                                << ") and SecondIDFields (" << SecondIDFields.size()
                                << ") must contain at least as many fields as VertexIDFields ("
                                << VertexIDFields.size() << ")";
                        }
                        NYT::TNode outSource;
                        NYT::TNode outTarget;
                        for (ui64 i = 0; i < VertexIDFields.size(); ++i) {
                            outSource(VertexIDFields[i], row[FirstIDFields[i]]);
                            outTarget(VertexIDFields[i], row[SecondIDFields[i]]);
                        }
                        outSource(NTmpFields::HASH, GetID<TIdType>(row, FirstIDFields));
                        outTarget(NTmpFields::HASH, GetID<TIdType>(row, SecondIDFields));

                        output->AddRow(outSource);
                        output->AddRow(outTarget);
                    } else {
                        NYT::TNode out;
                        for (const auto& field: VertexIDFields) {
                            out(field, row[field]);
                        }
                        out(NTmpFields::HASH, GetID<TIdType>(row, VertexIDFields));
                        output->AddRow(out);
                    }
                }
            }
        private:
            TVector<TString> FirstIDFields{};
            TVector<TString> SecondIDFields{};
            TVector<TString> VertexIDFields{};
        };

        /// Typed reducer that keeps only the first edge of every reduce key.
        class TExtractFirstRecordReducer: public IReducer<TTableReader<TGraphEdge>, TTableWriter<TGraphEdge>> {
        public:
            virtual void Do(TTableReader<TGraphEdge>* input, TTableWriter<TGraphEdge>* output) override {
                output->AddRow(input->MoveRow());
            }
        };
    };


// Registration of the job classes used by the operations in this file
// (YT REGISTER_MAPPER / REGISTER_REDUCER).
// NOTE(review): these statements live in a header. Confirm that the job
// registry tolerates the same class being registered from every translation
// unit that includes this file.
REGISTER_REDUCER(NConnectedComponents::TExtractFirstRecordReducer);
REGISTER_REDUCER(NConnectedComponents::TExtractFirstTwoRecordsReducer);
REGISTER_MAPPER(NConnectedComponents::TGeneralMatchVerticesWithFinalEdges::TMapper);
REGISTER_REDUCER(NConnectedComponents::TGeneralMatchVerticesWithFinalEdges::TReducer);

// Templated jobs: one registration per supported vertex id type (TString ...
REGISTER_MAPPER(NConnectedComponents::TMRPrepareOperations<TString>::TExtractVerticesMapper);
REGISTER_MAPPER(NConnectedComponents::TMRPrepareOperations<TString>::TConvertToMRCCEdgesMapper);
REGISTER_MAPPER(NConnectedComponents::TMRPrepareOperations<TString>::TConvertVerticesToEdges);
REGISTER_REDUCER(NConnectedComponents::TMRPrepareOperations<TString>::TExtractFirstRecordReducer);


// ... and ui64).
REGISTER_MAPPER(NConnectedComponents::TMRPrepareOperations<ui64>::TExtractVerticesMapper);
REGISTER_MAPPER(NConnectedComponents::TMRPrepareOperations<ui64>::TConvertToMRCCEdgesMapper);
REGISTER_MAPPER(NConnectedComponents::TMRPrepareOperations<ui64>::TConvertVerticesToEdges);
REGISTER_REDUCER(NConnectedComponents::TMRPrepareOperations<ui64>::TExtractFirstRecordReducer);


// Interface for converting between a user table layout (described by
// TDataView) and MRCC edge tables whose rows are TTypedGraphEdge<TIdType>::Type.
//
// The primary template only declares the members. This file defines them in
// the partial specialization for TGeneralDataView.
// NOTE(review): other data views presumably need their own specialization
// elsewhere; confirm.
//
// The per-method notes below describe the TGeneralDataView implementation.
template <class TDataView, class TIdType>
struct TConverter {
    // Map: input data rows in `source` -> MRCC edges in `destination`.
    void ConvertDataToMRCCEdges(const TYT& yt, const TString& source, const TString& destination, const TDataView& dataView, IOutputStream& logger = Cout);
    // Map: labelled vertex rows in `source` -> MRCC edges in `destination`;
    // `append` appends to `destination` instead of overwriting it.
    void ConvertComponentsToMRCCEdges(const TYT& yt, const TString& source, const TString& destination, const TDataView& dataView, IOutputStream& logger = Cout, bool append = false);
    // MapReduce: one vertex row per distinct id, taken from `source` and, when
    // non-empty, `previousLabels`. `wait` selects whether the call blocks. The
    // started operation is returned.
    IOperationPtr ExtractVerticesFromData(const TYT& yt, const TString& source, const TString& destination, const TString& previousLabels, const TDataView& dataView, IOutputStream& logger = Cout, bool wait = true);
    // MapReduce: joins `vertices` with the MRCC edge tables `sources` into
    // vertex -> component rows written to `destination`.
    void ConvertMRCCEdgesToComponents(const TYT& yt, const TString& vertices, const TVector<TString>& sources, const TString& destination, const TDataView& dataView, IOutputStream& logger = Cout);
    TConverter() {}
};


struct TGeneralDataView {
    struct {
        TVector<TString> FirstIDFields{};
        TVector<TString> SecondIDFields{};
    } Edge;
    struct {
        TVector<TString> IDFields{};
        TString ComponentField{};
    } Vertex;
    TGeneralDataView() {}
    TGeneralDataView(const TVector<TString>& firstIDFieldsIn, const TVector<TString>& secondIDFieldsIn,
        const TVector<TString>& fieldsOut, const TString& componentFieldOut) :
        Edge({.FirstIDFields=firstIDFieldsIn, .SecondIDFields=secondIDFieldsIn}),
        Vertex({.IDFields=fieldsOut, .ComponentField=componentFieldOut}) {}
};


/// TConverter implementation for the column layout described by TGeneralDataView.
template <class TIdType>
struct TConverter<TGeneralDataView, TIdType> {
    /// Map operation: every row of `source` becomes one MRCC edge in
    /// `destination`. Endpoints come from dataView.Edge.FirstIDFields and
    /// dataView.Edge.SecondIDFields; self-loops are dropped.
    void ConvertDataToMRCCEdges(const TYT& yt, const TString& source, const TString& destination, const TGeneralDataView& dataView, IOutputStream& logger = Cout) {
        TMeasure measure(logger, __func__);
        NYT::TMapOperationSpec spec;
        spec.AddInput<NYT::TNode>(source);
        spec.template AddOutput<typename TMRPrepareOperations<TIdType>::TGraphEdge>(destination);

        auto op = yt.Client->Map(
            spec,
            new typename NConnectedComponents::TMRPrepareOperations<TIdType>::TConvertToMRCCEdgesMapper(dataView.Edge.FirstIDFields, dataView.Edge.SecondIDFields),
            yt.CommonOperationOptions
        );
    }

    /// MapReduce writing one row per distinct HASH to `destination`.
    ///
    /// Input table 0 is `source` (edge rows). Table 1, added only when
    /// `previousLabels` is non-empty, holds already-labelled vertices. Each row
    /// carries the columns of dataView.Vertex.IDFields plus NTmpFields::HASH.
    /// Blocks only if `wait`; returns the operation handle either way.
    IOperationPtr ExtractVerticesFromData(const TYT& yt, const TString& source, const TString& destination, const TString& previousLabels, const TGeneralDataView& dataView, IOutputStream& logger = Cout, bool wait = true) {
        TMeasure measure(logger, __func__);
        NYT::TMapReduceOperationSpec spec;
        spec
            .AddInput<NYT::TNode>(source)
            .AddOutput<NYT::TNode>(destination)
            .ReduceBy({NTmpFields::HASH})
            .SortBy({NTmpFields::HASH});
        if (previousLabels) {
            spec.AddInput<NYT::TNode>(previousLabels);
        }
        NYT::TOperationOptions options(yt.CommonOperationOptions);
        options.Wait(wait);

        auto op = yt.Client->MapReduce(
            spec,
            new typename NConnectedComponents::TMRPrepareOperations<TIdType>::TExtractVerticesMapper(dataView.Edge.FirstIDFields, dataView.Edge.SecondIDFields, dataView.Vertex.IDFields),
            new NConnectedComponents::TExtractFirstRecordReducer,
            new NConnectedComponents::TExtractFirstRecordReducer,
            options
        );
        return op;
    }

    /// Map operation: labelled vertex rows in `source` become MRCC edges
    /// (component label <-> vertex id) in `destination`.
    /// With `append`, rows are appended to `destination` instead of replacing it.
    ///
    /// The defaults for `logger`/`append` mirror the primary TConverter template.
    void ConvertComponentsToMRCCEdges(const TYT& yt, const TString& source, const TString& destination, const TGeneralDataView& dataView, IOutputStream& logger = Cout, bool append = false) {
        TMeasure measure(logger, __func__);
        NYT::TMapOperationSpec spec;
        spec.AddInput<NYT::TNode>(source);
        spec.template AddOutput<typename TMRPrepareOperations<TIdType>::TGraphEdge>(NYT::TRichYPath(destination).Append(append));

        // Use the common options like every other operation here; previously
        // they were silently omitted. This method returns void, so callers have
        // no handle to wait on: keep it synchronous, as it was with the default
        // options.
        NYT::TOperationOptions options(yt.CommonOperationOptions);
        options.Wait(true);

        auto op = yt.Client->Map(
            spec,
            new typename NConnectedComponents::TMRPrepareOperations<TIdType>::TConvertVerticesToEdges(dataView.Vertex.IDFields, dataView.Vertex.ComponentField),
            options
        );
    }

    /// MapReduce attaching component ids to vertices.
    ///
    /// Joins the vertex table `vertices` (input 0) with the MRCC edge tables
    /// `sources` on NTmpFields::HASH (edge Destination). The result, written to
    /// `destination`, has the vertex id columns plus
    /// dataView.Vertex.ComponentField.
    void ConvertMRCCEdgesToComponents(const TYT& yt, const TString& vertices, const TVector<TString>& sources, const TString& destination, const TGeneralDataView& dataView, IOutputStream& logger = Cout) {
        TMeasure measure(logger, __func__);
        NYT::TMapReduceOperationSpec spec;
        spec.AddInput<NYT::TNode>(vertices);
        for (const auto& source: sources) {
            spec.AddInput<NYT::TNode>(source);
        }
        spec.AddOutput<NYT::TNode>(destination);
        // Sorting by (HASH, EDGE) is what lets the reducer see the vertex row
        // before the edge rows of the same key.
        spec.ReduceBy({NConnectedComponents::NTmpFields::HASH})
            .SortBy({NConnectedComponents::NTmpFields::HASH, NConnectedComponents::NTmpFields::EDGE});

        auto op = yt.Client->MapReduce(
            spec,
            new NConnectedComponents::TGeneralMatchVerticesWithFinalEdges::TMapper,
            new NConnectedComponents::TExtractFirstTwoRecordsReducer,
            new NConnectedComponents::TGeneralMatchVerticesWithFinalEdges::TReducer(dataView.Vertex.IDFields, dataView.Vertex.ComponentField),
            yt.CommonOperationOptions
        );
    }
};


template <class TDataView>
struct TDataPaths {
    TDataView DataView;
    TString SourceData{};
    TString SourceDataTmp{};
    TString DestinationComponents{};
    TString DestinationComponentsTmp{};
    TString Workdir{};
    TString PreviousLabels{};

    TDataPaths() {PreviousLabels = "";}
    TDataPaths(const TString& source, const TString& destination, const TString& workdir, const TString& previousLabels = "") :
            SourceData(source), DestinationComponents(destination), Workdir(workdir), PreviousLabels(previousLabels) {
            SourceDataTmp = NYT::JoinYPaths(Workdir, "SourceDataTmp");
            DestinationComponentsTmp = NYT::JoinYPaths(Workdir, "DestinationTmp");
    }
    TDataPaths(const TDataView& dataView, const TString& source, const TString& destination, const TString& workdir, const TString& previousLabels = "") :
        DataView(dataView), SourceData(source), DestinationComponents(destination), Workdir(workdir), PreviousLabels(previousLabels) {
            SourceDataTmp = NYT::JoinYPaths(Workdir, "SourceDataTmp");
            DestinationComponentsTmp = NYT::JoinYPaths(Workdir, "DestinationTmp");
    }

};

/// Runs the TConverter<TDataView, TIdType> operations for one YT client and
/// one data view, and manages the temporary tables listed in TDataPaths.
template <class TDataView, class TIdType>
class TDataTransformer {
public:
    // The constructor is declared with the injected class name. The previous
    // spelling `TDataTransformer<TDataView, TIdType>(...)` (a template-id as
    // constructor name) is ill-formed since C++20 (CWG 2237) and is diagnosed
    // by GCC in C++20 mode.
    TDataTransformer(const TYT& yt, const TDataView& dataView,
        IOutputStream& logger = Cout)
        : Yt(yt)
        , DataView(dataView)
        , Converter()
        , Logger(logger)
    {
    }

    /// Produces MRCC edges and the vertex table for `dataPaths`.
    ///
    /// 1. Copies SourceData to SourceDataTmp.
    /// 2. Starts vertex extraction into DestinationComponentsTmp without waiting.
    /// 3. Meanwhile, writes MRCC edges (plus edges from PreviousLabels, if any)
    ///    to `destinationEdgesPaths`.
    /// 4. Waits for the extraction, then removes SourceDataTmp.
    void ExtractEdgesAndVertices(const TDataPaths<TDataView>& dataPaths, const TString& destinationEdgesPaths) {
        Yt.Client->Copy(dataPaths.SourceData, dataPaths.SourceDataTmp, NYT::TCopyOptions().Recursive(true).Force(true));
        auto op = ExtractVerticesFromData(dataPaths.SourceDataTmp, dataPaths.DestinationComponentsTmp, dataPaths.PreviousLabels, false);
        ConvertDataToMRCCEdges(dataPaths.SourceDataTmp, destinationEdgesPaths, dataPaths.PreviousLabels);
        NYT::TOperationTracker verticesTracker;
        verticesTracker.AddOperation(op);
        verticesTracker.WaitAllCompleted();
        Yt.Client->Remove(dataPaths.SourceDataTmp);
    }

    /// Writes MRCC edges built from `sourcePath` to `destinationPath`. When
    /// `previousLabels` is non-empty, it also appends edges derived from those
    /// labelled vertices.
    void ConvertDataToMRCCEdges(const TString& sourcePath, const TString& destinationPath, const TString& previousLabels = "")  {
        Converter.ConvertDataToMRCCEdges(Yt, sourcePath, destinationPath, DataView, Logger);
        if (previousLabels) {
            Converter.ConvertComponentsToMRCCEdges(
                Yt, previousLabels, destinationPath, DataView, Logger, true
            );
        }
    }

    /// Extracts distinct vertices from `sourceEdgesPath` (and `previousLabels`,
    /// if non-empty) into `destinationPath`. Returns the operation; it blocks
    /// only when `wait` is true.
    NYT::IOperationPtr ExtractVerticesFromData(const TString& sourceEdgesPath, const TString& destinationPath, const TString& previousLabels, bool wait = true) {
        return Converter.ExtractVerticesFromData(Yt, sourceEdgesPath, destinationPath, previousLabels, DataView, Logger, wait);
    }

    /// Extracts vertices into `destinationPath`, then rewrites that table in
    /// place with component ids taken from the MRCC edge tables `mrccEdgesPaths`.
    void JoinSourceVerticesWithComponents(const TString& sourceEdgesPath, const TVector<TString>& mrccEdgesPaths, const TString& destinationPath, const TString& previousLabels) {
        ExtractVerticesFromData(sourceEdgesPath, destinationPath, previousLabels);
        ConvertMRCCEdgesToComponents(destinationPath, mrccEdgesPaths, destinationPath);
    }

    /// Joins the vertex table `verticesPath` with the MRCC edge tables and
    /// writes vertex -> component rows to `destinationPath`.
    void ConvertMRCCEdgesToComponents(const TString& verticesPath, const TVector<TString>& mrccEdgesPaths, const TString& destinationPath) {
        Converter.ConvertMRCCEdgesToComponents(Yt, verticesPath, mrccEdgesPaths, destinationPath, DataView, Logger);
    }

    /// Labels DestinationComponentsTmp in place using `edges`. It then replaces
    /// DestinationComponents with it inside a single transaction (remove the
    /// old table if present, move the new one).
    void ConvertMRCCEdgesToComponents(const TDataPaths<TDataView>& dataPaths, const TVector<TString>& edges) {
        ConvertMRCCEdgesToComponents(dataPaths.DestinationComponentsTmp, edges, dataPaths.DestinationComponentsTmp);
        auto tx = Yt.Client->StartTransaction();
        if (tx->Exists(dataPaths.DestinationComponents)) {
            tx->Remove(dataPaths.DestinationComponents);
        }
        tx->Move(dataPaths.DestinationComponentsTmp, dataPaths.DestinationComponents);
        tx->Commit();
    }

private:
    TYT Yt;
    TDataView DataView;
    TConverter<TDataView, TIdType> Converter;
    IOutputStream& Logger;
};

} // namespace NConnectedComponents
