#include "simple_edge_fixers.h"

#include "../utils/geom_helpers.h"

#include <maps/libs/geolib/include/distance.h>
#include <maps/libs/common/include/profiletimer.h>
#include <yandex/maps/wiki/threadutils/executor.h>
#include <yandex/maps/wiki/common/batch.h>

#include <algorithm>
#include <iomanip>
#include <iterator>
#include <queue>
#include <vector>

namespace maps {
namespace wiki {
namespace topology_fixer {

namespace {

// Batch size handed to common::applyBatchOp by EdgesSimplifier; every batch
// becomes one executor task.
constexpr size_t BATCH_COUNT = 1000;

// HeavyEdgesSplitter keeps splitting an edge part until its length (as
// computed by geolib3::length) is below this value. The original note says
// 500 km; units are those of the edge geometry — TODO confirm they are meters.
constexpr double MAX_GEOMETRY_LENGTH = 500000;

} // namespace

// A piece of an edge that HeavyEdgesSplitter is cutting up: its geometry and
// the ids of the nodes at its start (fNodeId) and end (tNodeId).
struct EdgePart
{
    geolib3::Polyline2 geom;

    NodeId fNodeId;
    NodeId tNodeId;

    // Length of `geom` as reported by geolib3::length.
    double length() const { return geolib3::length(geom); }
};

// Splits every edge having more than maxVertices_ vertices into several
// edges. Parts are halved (at the middle vertex, or at the segment midpoint
// once only two vertices remain) until each has at most maxVertices_
// vertices and is shorter than MAX_GEOMETRY_LENGTH. Every cut creates a new
// node; the original edge is then replaced by the resulting edges, which all
// inherit the edge's z-level.
//
// Edges whose two z-levels differ are skipped with a warning. An edge whose
// end nodes do not sit on its geometry ends fails the REQUIRE check.
//
// NOTE(review): edges with at most maxVertices_ vertices are skipped whatever
// their length, so MAX_GEOMETRY_LENGTH only constrains parts of edges that
// are split for having too many vertices — confirm this is intended.
// NOTE(review): if maxVertices_ < 2, a two-point part can never be accepted
// and the loop below would not terminate; presumably maxVertices_ is
// validated elsewhere.
void
HeavyEdgesSplitter::operator () (TopologyData& data) const
{
    // Snapshot of the ids: the loop below adds and replaces edges.
    const auto originalEdgeIds = data.edgeIds();

    INFO() << "Edge count before splitting " << originalEdgeIds.size();

    for (const auto edgeId : originalEdgeIds) {
        const Edge& edge = data.edge(edgeId);
        const auto& edgePoints = edge.linestring().points();
        ASSERT(edgePoints.size() >= 2);
        const Node& fromNode = data.node(edge.fnodeId());
        const Node& toNode = data.node(edge.tnodeId());

        if (edge.fZlev() != edge.tZlev()) {
            WARN() << "Edge " << edgeId << " zlevels differ, skipping";
            continue;
        }
        const ZLevelType zlev = edge.fZlev();

        REQUIRE(fromNode.point() == edgePoints.front() && toNode.point() == edgePoints.back(),
            "Edge " << edgeId << " nodes are misplaced");

        if (edgePoints.size() <= maxVertices_) {
            continue;
        }

        // FIFO work list of parts still to be checked; accepted parts are
        // collected in processing order.
        std::queue<EdgePart> pending;
        pending.push(EdgePart{edge.linestring(), edge.fnodeId(), edge.tnodeId()});
        std::vector<EdgePart> accepted;

        while (!pending.empty()) {
            // Take ownership of the front part before queueing its halves.
            EdgePart part = std::move(pending.front());
            pending.pop();
            const auto& partPoints = part.geom.points();

            if (partPoints.size() <= maxVertices_ && part.length() < MAX_GEOMETRY_LENGTH) {
                accepted.push_back(std::move(part));
                continue;
            }

            if (partPoints.size() <= 2) {
                // Only the two end vertices are left: cut the segment in half.
                ASSERT(partPoints.size() == 2);
                const geolib3::Point2 middle(
                    (partPoints[0].x() + partPoints[1].x()) * 0.5,
                    (partPoints[0].y() + partPoints[1].y()) * 0.5);
                const NodeId middleNodeId = data.addNode(middle);

                pending.push(EdgePart{
                    geolib3::Polyline2{geolib3::PointsVector{partPoints[0], middle}},
                    part.fNodeId, middleNodeId});
                pending.push(EdgePart{
                    geolib3::Polyline2{geolib3::PointsVector{middle, partPoints[1]}},
                    middleNodeId, part.tNodeId});
            } else {
                // Cut at the middle vertex; both halves share it as an endpoint.
                const size_t pivot = partPoints.size() / 2;
                const NodeId pivotNodeId = data.addNode(partPoints[pivot]);

                pending.push(EdgePart{
                    geolib3::Polyline2{geolib3::PointsVector(
                        partPoints.begin(), partPoints.begin() + pivot + 1)},
                    part.fNodeId, pivotNodeId});
                pending.push(EdgePart{
                    geolib3::Polyline2{geolib3::PointsVector(
                        partPoints.begin() + pivot, partPoints.end())},
                    pivotNodeId, part.tNodeId});
            }
        }

        IdSet replacementIds;
        for (auto& part : accepted) {
            replacementIds.insert(data.addEdge(
                std::move(part.geom), part.fNodeId, part.tNodeId, zlev, zlev));
        }

        data.replaceEdge(edgeId, replacementIds);

        INFO() << "Splitting heavy edge: [edge: " << edgeId << "]"
            << "[new edges: " << toString(replacementIds) << "]";
    }

    INFO() << "Edge count after splitting " << data.edgeIds().size();
}

// Runs simplifyEdge over every edge of `data`. The edge ids are cut into
// batches of BATCH_COUNT via common::applyBatchOp; each batch becomes one
// executor task, and all tasks are then run on `pool`.
void
EdgesSimplifier::operator()(TopologyData& data, FaceLocker& locker, ThreadPool& pool) const
{
    Executor executor;

    // Queues one task per batch; the task owns a copy of its id list.
    auto scheduleBatch = [&, this](const EdgeIdsVector& batch) {
        executor.addTask([&, batch, this]() {
            ProfileTimer timer;
            INFO() << "EdgesSimplifier start batch with the edge count " << batch.size();

            for (const auto edgeId : batch) {
                simplifyEdge(edgeId, data, locker);
            }

            INFO() << "EdgesSimplifier batched finished in " << timer.getElapsedTime()
                << " edge count " << batch.size();
        });
    };

    common::applyBatchOp<EdgeIdsVector>(data.edgeIds(), BATCH_COUNT, scheduleBatch);

    executor.executeAllInThreads(pool);
}

// Drops interior vertices of edge `edgeId` that lie closer than
// minSegmentLength_ to a kept neighbour. The first and last vertices are
// always kept.
//
// The new geometry is applied only if:
// - at least one vertex was dropped;
// - the result is not a degenerate closed edge;
// - isEdgeGeometryAllowed accepts it.
// For TopologyType::Contour the edge's faces are locked through `locker`
// before the check and are held until the function returns.
void EdgesSimplifier::simplifyEdge(EdgeId edgeId, TopologyData& data, FaceLocker& locker) const
{
    const Edge& edge = data.edge(edgeId);
    const auto& points = edge.linestring().points();
    ASSERT(points.size() >= 2);

    geolib3::PointsVector newPoints;
    newPoints.reserve(points.size());
    newPoints.push_back(points.front());
    bool simplified = false;

    // Forward pass: keep an interior vertex only if it is at least
    // minSegmentLength_ away from the previously kept vertex.
    for (size_t i = 1; i + 1 < points.size(); ++i) {
        if (geolib3::distance(newPoints.back(), points[i]) < minSegmentLength_) {
            simplified = true;
            continue;
        }
        newPoints.push_back(points[i]);
    }
    newPoints.push_back(points.back());

    // Backward pass: the forward pass never compares against the last vertex,
    // so drop trailing interior vertices that are too close to it. Index 0
    // (the first vertex) is never dropped.
    size_t lastKept = newPoints.size() - 2;
    while (lastKept > 0
        && geolib3::distance(newPoints.back(), newPoints[lastKept]) < minSegmentLength_)
    {
        --lastKept;
        simplified = true;
    }
    // Erasing an empty range is a no-op, so no guard is needed.
    newPoints.erase(newPoints.begin() + lastKept + 1, std::prev(newPoints.end()));

    if (!simplified) {
        return;
    }

    // Both end vertices are always kept, so newPoints has at least 2 points.
    // A closed edge, however, needs at least 4 points (3 distinct vertices);
    // with fewer it collapses to a point or a back-and-forth segment.
    if (newPoints.front() == newPoints.back() && newPoints.size() < 4) {
        WARN() << "Edge " << edgeId << " becomes degenerate after simplification, skipped";
        return;
    }
    ASSERT(newPoints.front() == data.node(edge.fnodeId()).point() &&
        newPoints.back() == data.node(edge.tnodeId()).point());

    // Held until return so the check and the update see a stable set of faces.
    FaceLocker::Locks locks;
    if (data.ftGroup()->topologyType() == TopologyType::Contour) {
        locks = locker.lockFaces(edge.faceIds());
    }

    if (!isEdgeGeometryAllowed(data, edgeId, newPoints)) {
        WARN() << "Edge " << edgeId << " simplification leads to topology invalidation";
        return;
    }
    data.setEdgeGeometry(edgeId, geolib3::Polyline2(std::move(newPoints)));
    INFO() << "Edge " << edgeId << " geometry simplified";
}

namespace {

enum class NodeType { From, To };

// Makes the `nodeType` end of edge `edgeId` (currently attached to `node`)
// sit at `pos`.
//
// If `node` is used only by this edge, the node itself is moved. Otherwise a
// new node is created at `pos` and the edge is reattached to it. For a closed
// edge (same f- and t-node) both ends are reattached so the edge stays
// closed. The old node is removed if it ends up with no edges.
void
fixEdgeNode(TopologyData& data, EdgeId edgeId, Node& node, const geolib3::Point2& pos,
    NodeType nodeType)
{
    const auto nodeId = node.id();
    if (node.edgeIds().size() == 1) {
        ASSERT(*node.edgeIds().begin() == edgeId);
        data.setNodePos(nodeId, pos);
        return;
    }

    const Edge& edge = data.edge(edgeId);
    const bool isClosed = edge.fnodeId() == edge.tnodeId();

    const auto newNodeId = data.addNode(pos);
    if (isClosed) {
        // Reattaching only one end would turn a closed edge into an open one.
        data.setFNode(edgeId, newNodeId);
        data.setTNode(edgeId, newNodeId);
    } else if (nodeType == NodeType::From) {
        data.setFNode(edgeId, newNodeId);
    } else {
        data.setTNode(edgeId, newNodeId);
    }

    // Look the old node up again by id instead of reusing `node` after the
    // topology was modified.
    if (data.node(nodeId).edgeIds().empty()) {
        data.removeNode(nodeId);
    }
}

} // namespace

// Makes every edge's f-node / t-node agree with the first / last point of its
// geometry:
// - An open edge whose nodes sit at the opposite geometry ends gets its
//   geometry reversed.
// - An edge with closed geometry is given a single node, which is placed at
//   the ring's start/end point.
// - Otherwise a misplaced node is moved (or replaced by a new node, if it is
//   shared with other edges).
// An open geometry whose two ends already share one node id is left as is.
void
JunctionsMismatchFixer::operator () (TopologyData& data) const
{
    for (auto edgeId : data.edgeIds()) {
        const Edge& edge = data.edge(edgeId);
        const auto& points = edge.linestring().points();
        ASSERT(points.size() >= 2);

        if (points.front() == points.back()) {
            // Closed geometry: reversing it is pointless. Make sure the edge
            // uses one node and that the node sits on the ring's start point.
            if (edge.fnodeId() != edge.tnodeId()) {
                WARN() << "Fixing closed edge t_node_id, edge id " << edgeId;
                const auto oldTNodeId = edge.tnodeId();
                data.setTNode(edgeId, edge.fnodeId());
                if (data.node(oldTNodeId).edgeIds().empty()) {
                    data.removeNode(oldTNodeId);
                }
            }
            Node& node = data.node(edge.fnodeId());
            if (node.point() != points.front()) {
                WARN() << "Fixing closed edge node position, edge id " << edgeId;
                fixEdgeNode(data, edgeId, node, points.front(), NodeType::From);
            }
            continue;
        }

        if (edge.fnodeId() == edge.tnodeId()) {
            // Open geometry but a single node: not handled here.
            continue;
        }

        Node& fnode = data.node(edge.fnodeId());
        if (fnode.point() == points.back()
            && data.node(edge.tnodeId()).point() == points.front())
        {
            WARN() << "Reversing direction, edge id " << edgeId;
            auto npoints = points;
            std::reverse(npoints.begin(), npoints.end());
            data.setEdgeGeometry(edgeId, geolib3::Polyline2(std::move(npoints)));
            continue;
        }
        if (points.front() != fnode.point()) {
            WARN() << "Fixing edge f_node position, edge id " << edgeId;
            fixEdgeNode(data, edgeId, fnode, points.front(), NodeType::From);
        }
        // Fetched only now, after any f-node change above.
        Node& tnode = data.node(edge.tnodeId());
        if (points.back() != tnode.point()) {
            WARN() << "Fixing edge t_node position, edge id " << edgeId;
            fixEdgeNode(data, edgeId, tnode, points.back(), NodeType::To);
        }
    }
}

} // namespace topology_fixer
} // namespace wiki
} // namespace maps
