#include "graph_handler.h"

namespace {
    using TVertexId = size_t;
    using TDistance = size_t;
    using TAdjList = TVector<THashSet<TVertexId>>;
    using TDistanceList = TVector<TDistance>;
    // Sentinel distance for vertices BFS has not reached. Typed as TDistance (not a hard-coded size_t) so it
    // always matches the element type of TDistanceList, even if TDistance changes.
    constexpr TDistance kInf{std::numeric_limits<TDistance>::max()};

    // Default predicate for BuildAdjList: keeps every edge.
    struct TAcceptAnyEdge {
        template <class TEdgeLike>
        bool operator()(const TEdgeLike& /*edge*/) const noexcept {
            return true;
        }
    };

    // Builds an undirected adjacency list over `verticesSize` vertices from `edges`, keeping only the edges for
    // which `predicate(edge)` returns true. Every element of `edges` must expose GetVertex1()/GetVertex2()
    // vertex indices; an index >= verticesSize throws std::out_of_range instead of writing out of bounds.
    // The predicate is a template parameter (rather than std::function) so the per-edge call can be inlined;
    // callers may still pass any callable, or omit it to keep all edges.
    template <class TEdges, class TPredicate = TAcceptAnyEdge>
    TAdjList BuildAdjList(const size_t verticesSize, const TEdges& edges,
                          const TPredicate& predicate = TPredicate{}) {
        TAdjList adjList(verticesSize);
        for (const auto& edge : edges) {
            if (predicate(edge)) {
                adjList.at(edge.GetVertex1()).insert(edge.GetVertex2());
                adjList.at(edge.GetVertex2()).insert(edge.GetVertex1());
            }
        }
        return adjList;
    }

    // Breadth-first search from `start` over an unweighted adjacency list.
    // Returns the hop distance from `start` to every vertex: 0 for `start` itself, kInf when unreachable.
    // A vertex is marked when it is enqueued (not when it is popped), so each vertex enters the queue at most
    // once and the queue is bounded by the number of vertices instead of the number of edges.
    // An out-of-range `start` throws std::out_of_range.
    TDistanceList Bfs(const size_t start, const TAdjList& adjList) {
        TDistanceList distances(adjList.size(), kInf);
        TQueue<TVertexId> frontier;

        distances.at(start) = 0;
        frontier.push(start);

        while (!frontier.empty()) {
            const TVertexId vertex = frontier.front();
            frontier.pop();

            const TDistance nextDistance = distances[vertex] + 1;
            for (const TVertexId adj : adjList.at(vertex)) {
                if (distances.at(adj) == kInf) {
                    distances[adj] = nextDistance;
                    frontier.push(adj);
                }
            }
        }
        return distances;
    }
}

namespace NMichurin {
    // Link type stamped on the key and value records built by ConvertToVulture.
    constexpr auto kLtMichurin = yabs::proto::Profile::TSourceUniq::LT_CRYPTA_MICHURIN;

    // Wraps an existing graph proto (not owned here) and indexes it:
    //  - vertexToIndexCountMap: vertex -> (position in the vertex list, number of edges touching it);
    //  - edgeToIndexMap: edge key (see edgeDataFromEdge) -> position in the edge list.
    TGraphHandler::TGraphHandler(NCrypta::NGraphEngine::TGraph* graph)
        : graph(graph)
    {
        const auto& vertices = graph->GetVertices();

        // Each vertex starts with a zero reference count; the edge pass below accounts for its edges.
        ui32 vertexIndex = 0;
        for (const auto& protoVertex : vertices) {
            vertexToIndexCountMap[TGenericID(protoVertex)] = {vertexIndex++, 0};
        }

        ui32 edgeIndex = 0;
        for (const auto& edge : graph->GetEdges()) {
            edgeToIndexMap[edgeDataFromEdge(edge)] = edgeIndex++;
            ++vertexToIndexCountMap[TGenericID(vertices[edge.GetVertex1()])].refCount;
            ++vertexToIndexCountMap[TGenericID(vertices[edge.GetVertex2()])].refCount;
        }
    }

    bool TGraphHandler::ProcessEventMessage(const TEventMessage& eventMessage) {
        const auto body{NCrypta::NEvent::UnpackAny(eventMessage)};
        const auto& content{static_cast<TSoupEvent&>(*body)};
        return ProcessSoupEvent(content);
    }

    bool TGraphHandler::ProcessSoupEvent(const TSoupEvent& soupEvent) {
        const auto& edgeBetween = soupEvent.GetEdge();
        const auto [_, added] = AddEdge(edgeBetween, soupEvent.GetUnixtime());
        return added;
    }

    // Returns the bookkeeping entry of `vertex`. An unknown vertex is first registered with the next free index
    // and a zero reference count, and appended to the graph proto.
    TVertexInfo& TGraphHandler::AddVertex(const TGenericID& vertex) {
        if (const auto existing = vertexToIndexCountMap.find(vertex); existing != vertexToIndexCountMap.end()) {
            return existing->second;
        }

        // New vertices are appended, so their index is the number of vertices registered so far.
        const ui32 nextIndex = vertexToIndexCountMap.size();
        auto& info = vertexToIndexCountMap[vertex];
        info = {nextIndex, 0};
        graph->MutableVertices()->Add()->CopyFrom(vertex.ToProto());
        return info;
    }

    std::pair<TEdge*, bool>
    TGraphHandler::AddEdge(const TEdgeBetween& incomingEdge, ui32 timestamp) {
        auto& vInfo1 = AddVertex(TGenericID{incomingEdge.GetVertex1()});
        auto& vInfo2 = AddVertex(TGenericID{incomingEdge.GetVertex2()});

        TEdgeDataTuple edgeData{vInfo1.index, vInfo2.index, incomingEdge.GetSourceType(), incomingEdge.GetLogSource()};
        if (const auto it = edgeToIndexMap.find(edgeData); it != edgeToIndexMap.end()) {
            auto edge = graph->MutableEdges()->Mutable(it->second);
            edge->SetTimeStamp(std::max(edge->GetTimeStamp(), timestamp));
            if (incomingEdge.GetSeenCount() == 0) {
                edge->SetSeenCount(edge->GetSeenCount() + 1);
            } else {
                edge->SetSeenCount(std::max(edge->GetSeenCount(), incomingEdge.GetSeenCount()));
                //max is more safety, but maybe we should take sum
            }
            return std::make_pair(edge, false);
        } else {
            edgeToIndexMap[edgeData] = edgeToIndexMap.size();
            auto newEdge = graph->MutableEdges()->Add();
            newEdge->SetSourceType(incomingEdge.GetSourceType());
            newEdge->SetLogSource(incomingEdge.GetLogSource());
            newEdge->SetVertex1(vInfo1.index);
            newEdge->SetVertex2(vInfo2.index);
            newEdge->SetTimeStamp(timestamp);
            newEdge->SetCreated(timestamp);
            newEdge->SetSeenCount(std::max(incomingEdge.GetSeenCount(), {1}));
            ++(vInfo1.refCount);
            ++(vInfo2.refCount);
            return std::make_pair(newEdge, true);
        }
    }

    // Adds every edge of `other` to this graph through AddEdge. Vertices are matched by TGenericID, so vertex
    // indices are remapped. The edge's seen count is carried over. The resulting edge's DatesWeight,
    // SurvivalWeight, Indevice and IsStrong are then overwritten with the values from `other`.
    void TGraphHandler::Merge(const TGraphHandler& other) {
        // Bind by reference: the original `const auto` deep-copied the whole repeated field.
        const auto& otherEdges = other.graph->GetEdges();
        const auto& otherVertices = other.graph->GetVertices();

        // Iterate by position over a snapshot of the size. If `other` shares this handler's graph, AddEdge may
        // append to the very field being read; that invalidates range-for iterators, but not element references
        // or indices below the snapshot.
        const int edgesCount = otherEdges.size();
        for (int edgeIndex = 0; edgeIndex < edgesCount; ++edgeIndex) {
            const auto& edge = otherEdges.Get(edgeIndex);

            TEdgeBetween newEdgeBetween;
            newEdgeBetween.SetSourceType(edge.GetSourceType());
            newEdgeBetween.SetLogSource(edge.GetLogSource());
            // Keep the accumulated count, as ToSoup/Split do when replaying edges. AddEdge still treats 0 as
            // "one more sighting", and otherwise keeps the max of stored and incoming counts.
            newEdgeBetween.SetSeenCount(edge.GetSeenCount());
            newEdgeBetween.MutableVertex1()->CopyFrom(otherVertices.Get(edge.GetVertex1()));
            newEdgeBetween.MutableVertex2()->CopyFrom(otherVertices.Get(edge.GetVertex2()));

            auto [newEdge, _] = AddEdge(newEdgeBetween, edge.GetTimeStamp());
            newEdge->SetDatesWeight(edge.GetDatesWeight());
            newEdge->SetSurvivalWeight(edge.GetSurvivalWeight());
            newEdge->SetIndevice(edge.GetIndevice());
            newEdge->SetIsStrong(edge.GetIsStrong());
        }
    }

    void TGraphHandler::RebuildState(const TVector<TSoupEvent>& soupToKeep) {
        clear();
        for (const auto& soupEvent : soupToKeep) {
            ProcessSoupEvent(soupEvent);
        }
    }

    // Splits the graph into components as computed by NGraphEngine::Split on the TCommonGraph view.
    // The biggest component keeps the current crypta id and stays in this handler, which is rebuilt from that
    // component's edges. Edges inside every other component are returned, grouped by that component's crypta id,
    // for the caller to rewind. Edges crossing components are dropped.
    // Returns an empty map (and leaves the graph untouched) when the graph is empty or does not split.
    THashMap<ui64, TVector<TSoupEvent>> TGraphHandler::Split(bool forceEdgesStrong) {
        THashMap<ui64, TVector<TSoupEvent>> toRewind;

        const auto& customGraph = NCrypta::NGraphEngine::TCommonGraph(*graph, forceEdgesStrong);
        const auto& innerGraph = customGraph.GetInnerGraph();
        const auto& splitGraph = NCrypta::NGraphEngine::Split(innerGraph);
        if (splitGraph.begin() == splitGraph.end()) {
            // Empty graph: nothing to split, and dereferencing MaxElement() of an empty range would be UB.
            return toRewind;
        }

        // Copied by value: RebuildState() below goes through clear(), which calls graph->Clear().
        const auto oldCid = graph->GetId();

        THashMap<TString, ui32> vertexValueToComponentIndex;
        THashMap<ui32, ui64> componentIndexToCid;

        const auto componentsNumber = *MaxElement(splitGraph.begin(), splitGraph.end()) + 1;
        if (componentsNumber == 1) {
            // the graph will not split so we can not rebuild it
            return toRewind;
        }

        // Component of every vertex (keyed by vertex value) and size of every component.
        TVector<ui32> componentsSize(componentsNumber, 0);
        for (const auto& vertex : innerGraph.Vertices) {
            vertexValueToComponentIndex[vertex.second.CustomVertex.Value] = splitGraph[vertex.first];
            ++componentsSize[splitGraph[vertex.first]];
        }

        // The biggest component stays in this graph under the old crypta id.
        const ui32 biggestComponentIndex = MaxElement(componentsSize.begin(), componentsSize.end()) - componentsSize.begin();
        componentIndexToCid[biggestComponentIndex] = oldCid;

        // Any other multi-vertex component takes the first id generated from one of its vertices that differs
        // from oldCid. Singleton components keep oldCid.
        for (const auto& protoVertex : graph->GetVertices()) {
            const TGenericID genericId(protoVertex);
            const auto componentIndex = vertexValueToComponentIndex[genericId.GetValue()];

            if (componentsSize[componentIndex] == 1) {
                componentIndexToCid[componentIndex] = oldCid;
                // TODO shouldn't be here, maybe add some solomon alerts
                continue;
            }

            if (!componentIndexToCid.contains(componentIndex)) {
                if (const auto newCid = NCrypta::GenerateCryptaId(genericId); newCid != oldCid) {
                    componentIndexToCid[componentIndex] = newCid;
                }
            }
        }

        TVector<TSoupEvent> edgesToKeep;
        const auto& vertices{graph->GetVertices()};

        for (const auto& edge : graph->GetEdges()) {
            const TString firstVertexValue = TGenericID{vertices[edge.GetVertex1()]}.GetValue();
            const TString secondVertexValue = TGenericID{vertices[edge.GetVertex2()]}.GetValue();
            const ui32 firstVertexComponentIndex = vertexValueToComponentIndex[firstVertexValue];
            const ui32 secondVertexComponentIndex = vertexValueToComponentIndex[secondVertexValue];
            if (firstVertexComponentIndex != secondVertexComponentIndex) {
                // Edge crossing two components: dropped.
                continue;
            }

            TSoupEvent soupEvent;
            soupEvent.SetUnixtime(edge.GetTimeStamp());

            auto mutableEdgeBetween = soupEvent.MutableEdge();
            mutableEdgeBetween->SetSourceType(edge.GetSourceType());
            mutableEdgeBetween->SetLogSource(edge.GetLogSource());
            mutableEdgeBetween->SetSeenCount(edge.GetSeenCount());

            mutableEdgeBetween->MutableVertex1()->CopyFrom(vertices[edge.GetVertex1()]);
            mutableEdgeBetween->MutableVertex2()->CopyFrom(vertices[edge.GetVertex2()]);

            // NOTE(review): operator[] yields 0 if no id was assigned to this component
            // (i.e. every GenerateCryptaId() call for its vertices returned oldCid).
            const auto cid = componentIndexToCid[firstVertexComponentIndex];
            soupEvent.SetCryptaId1(cid);
            soupEvent.SetCryptaId2(cid);
            if (firstVertexComponentIndex == biggestComponentIndex) {
                edgesToKeep.push_back(soupEvent);
            } else {
                toRewind[cid].push_back(soupEvent);
            }
        }

        // NOTE(review): clear() (via RebuildState) calls graph->Clear(), which resets every proto field,
        // including the id read above. Confirm callers re-set it if they rely on it.
        RebuildState(edgesToKeep);
        return toRewind;
    }

    // Keeps only the `limit` newest edges (by timestamp) and rebuilds the handler from them.
    // Returns the vertices that no longer appear in the graph afterwards (empty if nothing was trimmed).
    // Throws (Y_ENSURE) on a negative limit, before the graph is modified.
    // NOTE(review): the rebuild replays ToSoup(), which does not carry edge weights/flags/creation time,
    // and clear() resets graph-level proto fields such as the id — confirm this is acceptable for callers.
    THashSet<TGenericID> TGraphHandler::LimitEdges(int limit) {
        // A negative limit previously wiped every edge and then failed the size check,
        // or called RemoveLast() on an empty field.
        Y_ENSURE(limit >= 0);

        auto edges = graph->MutableEdges();
        THashSet<TGenericID> droppedVertices;

        if (edges->size() <= limit) {
            return droppedVertices;
        }

        // Sorting the reversed range ascending leaves the forward order newest-first,
        // so the oldest edges end up at the tail and are removed there.
        // TODO: Add more sort conditions like seen count, when they're here
        SortBy(edges->rbegin(), edges->rend(), [](auto& e) { return e.GetTimeStamp(); });
        while (edges->size() > limit) {
            edges->RemoveLast();
        }
        Y_ENSURE(edges->size() == limit);

        // Replay oldest-first so the surviving edges are re-added in chronological order.
        auto soup = ToSoup();
        Reverse(soup.begin(), soup.end());
        Y_ENSURE(soup.size() == static_cast<size_t>(limit));

        // Start from every current vertex and strike out those that survive the rebuild.
        droppedVertices.reserve(graph->GetVertices().size());
        for (const auto& vertex : graph->GetVertices()) {
            droppedVertices.insert(TGenericID(vertex));
        }

        RebuildState(soup);
        for (const auto& vertex : graph->GetVertices()) {
            droppedVertices.erase(TGenericID(vertex));
        }
        return droppedVertices;
    }

    // Removes the edge with the smallest timestamp. The edge is swapped with the last element of the repeated
    // field and the tail is removed. edgeToIndexMap is updated so the edge that moves into the freed slot keeps
    // a correct index. Each endpoint's refCount is decremented; an endpoint whose count reaches zero is removed
    // via DropVertex.
    // Returns the vertices removed this way (empty when the graph has no edges).
    // TODO: Maybe rewrite to DropNOldestEdges, so dropping several edges needs a single pass/sort instead of one
    // MinElement scan per dropped edge.
    TVector<TGenericID> TGraphHandler::DropOldestEdge() {
        TVector<TGenericID> dropped;
        auto ee = graph->MutableEdges();
        if (ee->size() == 0) {
            return dropped;
        }
        // Linear scan for the edge with the smallest timestamp.
        const auto& it = MinElement(ee->begin(), ee->end(), [](const auto& lhs, const auto& rhs) {
            return lhs.GetTimeStamp() < rhs.GetTimeStamp();
        });
        const ui32 pos = it - ee->begin();
        const ui32 lastPos = ee->size() - 1;
        const auto& oldEdgeInfo = edgeDataFromEdge(*it);
        // The last edge will be swapped into `pos` below, so re-point its map entry first,
        // then forget the dropped edge's key.
        if (pos != lastPos) {
            const auto& lastEdgeInfo = edgeDataFromEdge(ee->Get(lastPos));
            edgeToIndexMap[lastEdgeInfo] = pos;
        }
        edgeToIndexMap.erase(oldEdgeInfo);

        // NOTE(review): operator[] default-inserts an entry if an endpoint is missing from the map,
        // which would indicate broken bookkeeping rather than a valid state.
        // For a self-loop (v1 == v2) both references alias the same entry. AddEdge increments a self-loop
        // endpoint twice, so with consistent counts the entry is only erased after the second decrement.
        auto v1 = TGenericID(graph->GetVertices()[it->GetVertex1()]);
        auto& v1Info = vertexToIndexCountMap[v1];
        auto v2 = TGenericID(graph->GetVertices()[it->GetVertex2()]);
        auto& v2Info = vertexToIndexCountMap[v2];

        // DropVertex rewrites edges in place (no insertion/removal in the edges field), so `it` stays valid.
        // NOTE(review): DropVertex changes edge vertex indices; edgeToIndexMap keys embed those indices and must
        // stay in sync with them.
        if (--(v1Info.refCount); v1Info.refCount == 0) {
            DropVertex(v1);
            dropped.push_back(std::move(v1));
        }
        if (--(v2Info.refCount); v2Info.refCount == 0) {
            DropVertex(v2);
            dropped.push_back(std::move(v2));
        }

        // Move the dropped edge to the tail and remove it.
        ee->SwapElements(pos, lastPos);
        ee->RemoveLast();
        return dropped;
    }

    // Swappes out the vertex to the end of repeated field and removes it.
    // Keeps index map updated for the swapped in vertex and updates indices for
    // any edges that references the swapped in vertex.
    void TGraphHandler::DropVertex(const TGenericID& vertex) {
        auto vertices = graph->MutableVertices();
        const ui32 pos = vertexToIndexCountMap[vertex].index;
        const ui32 lastPos = vertices->size() - 1;
        if (pos != lastPos) {
            const auto& lastVertex = TGenericID(vertices->Get(lastPos));
            vertexToIndexCountMap[lastVertex].index = pos;

            for (auto& edge : *graph->MutableEdges()) {
                if (edge.GetVertex1() == lastPos) {
                    edge.SetVertex1(pos);
                }
                if (edge.GetVertex2() == lastPos) {
                    edge.SetVertex2(pos);
                }
            }
        }
        vertexToIndexCountMap.erase(vertex);
        vertices->SwapElements(pos, lastPos);
        vertices->RemoveLast();
    }

    // Serialises every stored edge, in storage order, into a TSoupEvent. The unixtime is the edge timestamp;
    // both endpoints are copied from the vertex list; source type, log source and seen count are preserved.
    // Crypta ids are not set. Other edge fields (weights, flags, creation time) are not carried over.
    const TVector<TSoupEvent> TGraphHandler::ToSoup() const {
        const auto& vertices = graph->GetVertices();
        const auto& edges = graph->GetEdges();

        TVector<TSoupEvent> soup;
        soup.reserve(edges.size());
        for (const auto& edge : edges) {
            auto& event = soup.emplace_back();
            event.SetUnixtime(edge.GetTimeStamp());

            auto* between = event.MutableEdge();
            between->SetSourceType(edge.GetSourceType());
            between->SetLogSource(edge.GetLogSource());
            between->SetSeenCount(edge.GetSeenCount());
            between->MutableVertex1()->CopyFrom(vertices.Get(edge.GetVertex1()));
            between->MutableVertex2()->CopyFrom(vertices.Get(edge.GetVertex2()));
        }
        return soup;
    }

    // Empties both lookup maps and the underlying graph proto.
    // NOTE(review): graph->Clear() resets every field of the proto (id included), not only vertices/edges.
    void TGraphHandler::clear() {
        edgeToIndexMap.clear();
        vertexToIndexCountMap.clear();
        graph->Clear();
    }

    // Builds the edgeToIndexMap key of a stored edge: (vertex1 index, vertex2 index, source type, log source).
    // The key is ordered: edges stored as (a, b) and (b, a) produce different keys.
    TEdgeDataTuple edgeDataFromEdge(const NCrypta::NGraphEngine::TEdge& edge) {
        const auto firstVertex = edge.GetVertex1();
        const auto secondVertex = edge.GetVertex2();
        const auto sourceType = edge.GetSourceType();
        const auto logSource = edge.GetLogSource();
        return TEdgeDataTuple{firstVertex, secondVertex, sourceType, logSource};
    }

    // Builds one TAssociatedUids per vertex whose type has an ads id type (per CryptaToAds).
    // The key record describes the vertex itself. Value records describe every other reachable vertex that is
    // a profile type with an ads id type, carrying its BFS hop distance over all edges and whether it is also
    // reachable through in-device edges only.
    THashMap<TGenericID, TAssociatedUids> TGraphHandler::ConvertToVulture() const {
        const auto& vertices{graph->GetVertices()};
        const auto& edges{graph->GetEdges()};
        const auto vertexCount{static_cast<size_t>(vertices.size())};

        // Reachability over every edge, and over in-device edges only.
        const auto allEdgesAdjList{BuildAdjList(vertexCount, edges)};
        const auto indeviceAdjList{BuildAdjList(
            vertexCount, edges,
            [](const auto& edge) { return edge.GetIndevice(); })};

        // Every record carries the same Michurin link metadata; only id, type, distance and in-device flag vary.
        const auto fillRecord = [](auto* record, const auto& userId, const auto& idType, const auto distance, const bool indevice) {
            record->set_user_id(userId);
            record->set_id_type(idType);
            record->set_crypta_graph_distance(distance);
            record->set_link_type(kLtMichurin);
            record->add_link_types(kLtMichurin);
            record->set_is_indevice(indevice);
        };

        THashMap<TGenericID, TAssociatedUids> vultureUids;
        for (size_t source = 0; source < vertexCount; ++source) {
            const TGenericID keyId{vertices[source]};
            const auto keyYabsType{CryptaToAds(keyId.GetType())};
            if (!keyYabsType.Defined()) {
                continue;  // the vertex type has no vulture counterpart
            }

            const auto distances{Bfs(source, allEdgesAdjList)};
            const auto indeviceDistances{Bfs(source, indeviceAdjList)};
            auto& associated{vultureUids[keyId]};
            fillRecord(associated.MutableKeyRecord(), keyId.GetValue(), *keyYabsType, 1, true);

            for (size_t target = 0; target < distances.size(); ++target) {
                const auto distance = distances[target];
                if (distance == 0 || distance == kInf) {
                    continue;  // the source itself, or a vertex not reachable from it
                }

                const TGenericID generic{vertices[target]};
                if (!IsProfileType(generic.GetType())) {
                    continue;  // see https://st.yandex-team.ru/CRYPTA-16244
                }
                const auto yabsType{CryptaToAds(generic.GetType())};
                if (!yabsType.Defined()) {
                    continue;  // the vertex type has no vulture counterpart
                }

                // todo: mskorokhod missing keys
                // set_is_crypta_main_yandexuid ???
                // set_crypta_graph_weight
                fillRecord(associated.AddValueRecords(), generic.GetValue(), *yabsType, distance,
                           indeviceDistances[target] != kInf);
            }

            // todo: mskorokhod add cryptaid as vertex
        }

        return vultureUids;
    }
}
