#include "compact_bbox.h"
#include "packer.h"

#include <contrib/libs/flatbuffers64/include/flatbuffers/flatbuffers.h>
#include <maps/libs/common/include/exception.h>
#include <maps/libs/log8/include/log8.h>
#include <maps/libs/succinct_buffers/include/writers.h>

#include <util/generic/array_ref.h>
#include <util/generic/buffer.h>
#include <util/generic/xrange.h>

#include <fstream>
#include <variant>

namespace maps::mrc::fb_rtree {
namespace impl {

using geolib3::BoundingBox;
struct InnerNode;
// Non-owning views over a contiguous run of child nodes; the viewed storage
// (a vector of the level below) must outlive every node referring to it.
using InnerNodesRef = TArrayRef<const InnerNode>;
using LeafNodesRef = TArrayRef<const LeafNode>;

// A node of the tree under construction.
struct InnerNode {
    // Children are either further inner nodes or leaves.
    std::variant<InnerNodesRef, LeafNodesRef> children;
    // True when this node has a child count different from `branching`, or
    // when any inner child has this flag set (see hasUnfilledChildren).
    // rearrangeSlices ASSERTs that at most one node of the array it receives
    // has this flag. Defaults to false so that a default-initialized node is
    // never left with an indeterminate value.
    bool hasUnfilledChildren = false;
    BoundingBox bbox;
};

// Folds the bboxes of all `nodes` together with `expand`, starting from the
// first node's bbox.
//
// Precondition: `nodes` is non-empty. An empty range is rejected by ASSERT
// rather than dereferencing an empty optional, which would be undefined
// behaviour.
template <class Node>
BoundingBox commonBbox(TArrayRef<const Node> nodes)
{
    auto result = std::optional<BoundingBox>{};
    for (const auto& node : nodes) {
        if (!result) {
            result = node.bbox;
        }
        else {
            result = expand(*result, node.bbox);
        }
    }
    ASSERT(result.has_value());
    return *result;
}

// Leaf overload: intentionally a no-op. It lets groupNodes call
// rearrangeSlices uniformly for both leaves and inner nodes; only the
// InnerNode overload has a hasUnfilledChildren flag to act on.
void rearrangeSlices(TArrayRef<LeafNode>, size_t) {}

// Reorders `nodes` so that the node flagged with hasUnfilledChildren (if any)
// becomes the very last element, i.e. the last element of the last slice,
// where the array is cut into consecutive slices of `sliceSize` elements and
// the last slice may be shorter.
//
// The reordering is kept small: at most one swap inside the unfilled node's
// slice, then a single rotation that moves a run of `lastSliceSize` elements
// ending at the unfilled node to the end of the array.
//
// Preconditions: at most one node is flagged (ASSERTed below) and `sliceSize`
// is non-zero (it is used as a divisor).
void rearrangeSlices(TArrayRef<InnerNode> nodes, size_t sliceSize)
{
    // Locate the unfilled node; more than one is a logic error.
    auto maybeUnfilledIndex = std::optional<size_t>{};
    for (const size_t i : xrange(nodes.size())) {
        if (nodes[i].hasUnfilledChildren) {
            ASSERT(!maybeUnfilledIndex);
            maybeUnfilledIndex = i;
        }
    }

    if (!maybeUnfilledIndex) {
        // All nodes are filled
        return;
    }

    size_t unfilledIndex = *maybeUnfilledIndex;

    // Slice count rounded up; the last slice holds the remainder and has
    // between 1 and sliceSize elements.
    const size_t totalSlices = (nodes.size() + sliceSize - 1) / sliceSize;
    const size_t unfilledSlice = unfilledIndex / sliceSize;
    const size_t lastSliceSize = nodes.size() - (totalSlices - 1) * sliceSize;

    // If the unfilled node lies within the first lastSliceSize elements of its
    // slice, swap it to offset lastSliceSize - 1. The run moved below then
    // starts exactly at the slice boundary. When the node is already in the
    // last slice, this swap puts it at the end of the array and the rotation
    // below is a no-op.
    if (unfilledIndex - unfilledSlice * sliceSize < lastSliceSize) {
        // Optimal solution is possible
        const size_t targetUnfilledIndex =
            unfilledSlice * sliceSize + lastSliceSize - 1;
        std::swap(nodes[unfilledIndex], nodes[targetUnfilledIndex]);
        unfilledIndex = targetUnfilledIndex;
    }

    // Move the lastSliceSize elements ending at the unfilled node to the end
    // of the array: they form the new last slice, with the unfilled node last.
    // At this point the node's offset within its slice is >= lastSliceSize - 1,
    // so the run never starts before the beginning of that slice.
    std::rotate(nodes.begin() + unfilledIndex - lastSliceSize + 1,
                nodes.begin() + unfilledIndex + 1,
                nodes.end());
    ASSERT(nodes.back().hasUnfilledChildren);
}

// A group of leaves is unfilled unless it contains exactly `branching` leaves.
bool hasUnfilledChildren(TArrayRef<const LeafNode> children, size_t branching)
{
    const bool isFull = (children.size() == branching);
    return !isFull;
}

// A group of inner nodes is unfilled when it does not contain exactly
// `branching` nodes, or when any of those nodes is itself flagged unfilled.
// The flag therefore propagates upward through every ancestor.
bool hasUnfilledChildren(TArrayRef<const InnerNode> children, size_t branching)
{
    if (children.size() == branching) {
        return std::any_of(children.begin(),
                           children.end(),
                           [](const InnerNode& child) {
                               return child.hasUnfilledChildren;
                           });
    }
    return true;
}

// Groups nodes using STR (Sort Tiles Recursive) algorithm
//
// Builds the parent level for `nodes`, with up to `branching` children per
// parent:
//   1. stable-sort all nodes by bbox minX and cut them into vertical slices
//      of sliceSize = branching * ceil(sqrt(parentNodes)) nodes;
//   2. stable-sort each slice by bbox minY and cut it into groups of
//      `branching` consecutive nodes, one parent per group.
// rearrangeSlices runs after each sort so that the single node flagged
// hasUnfilledChildren (InnerNode levels only) ends up at the end of the whole
// array. Every slice except the last is a multiple of `branching` long, so
// only the last group can be short. Together these are meant to leave only the
// last returned parent possibly flagged.
//
// `*nodes` is reordered in place. Each returned InnerNode's `children` is a
// view into `nodes`' buffer, so that vector must not be resized or destroyed
// while the result is in use.
//
// Empty input yields a single value-initialized node with no children.
// NOTE(review): branching == 0 divides by zero on the first line; callers
// must validate it.
template <class Node>
std::vector<InnerNode> groupNodes(std::vector<Node>* nodes, size_t branching)
{
    // Round up
    size_t parentNodes = (nodes->size() + branching - 1) / branching;
    // Computed in double and truncated back to size_t.
    size_t sliceSize = branching * std::ceil(std::sqrt(parentNodes));

    auto result = std::vector<InnerNode>{};
    result.reserve(parentNodes);
    // sliceSize is 0 only for empty input: emit one childless root.
    if (!sliceSize) {
        result.emplace_back();
        return result;
    }

    std::stable_sort(
        nodes->begin(), nodes->end(), [](const auto& lhs, const auto& rhs) {
            return lhs.bbox.minX() < rhs.bbox.minX();
        });
    rearrangeSlices(*nodes, sliceSize);

    for (size_t sliceBegin : xrange(size_t(0), nodes->size(), sliceSize)) {
        size_t sliceEnd = std::min(sliceBegin + sliceSize, nodes->size());
        std::stable_sort(nodes->begin() + sliceBegin,
                         nodes->begin() + sliceEnd,
                         [](const auto& lhs, const auto& rhs) {
                             return lhs.bbox.minY() < rhs.bbox.minY();
                         });
        // Within the slice, push the unfilled node into the slice's last group.
        rearrangeSlices({nodes->data() + sliceBegin, sliceEnd - sliceBegin},
                        branching);
        for (size_t groupBegin : xrange(sliceBegin, sliceEnd, branching)) {
            size_t groupSize =
                std::min(groupBegin + branching, sliceEnd) - groupBegin;
            // View into *nodes; see the aliasing note above.
            auto children =
                TArrayRef<const Node>{nodes->data() + groupBegin, groupSize};
            result.push_back({children,
                              hasUnfilledChildren(children, branching),
                              commonBbox(children)});
        }
    }

    return result;
}

// Queues an inner `child` for the next level of the traversal. The child keeps
// its children and fill flag, but its bbox is replaced by the one decoded from
// `compactBbox` relative to `parentBbox`. Its own children are then encoded
// against the stored (presumably lossy) bbox rather than the exact one.
// Returns false: the traversal has not reached the leaves yet.
bool fillNextLayer(const InnerNode& child,
                   std::vector<InnerNode>& nextLayer,
                   Ids&,
                   const fb::CompactBoundingBox& compactBbox,
                   const BoundingBox& parentBbox)
{
    auto decodedBbox = uncompact(compactBbox, parentBbox);
    nextLayer.push_back(InnerNode{
        child.children, child.hasUnfilledChildren, std::move(decodedBbox)});
    return false;
}

// Handles a leaf `child`. It queues a childless placeholder carrying the
// leaf's decoded bbox, then records the leaf id, so ids follow the order in
// which leaf bboxes are written.
// Returns true: this level consists of leaves.
bool fillNextLayer(const LeafNode& child,
                   std::vector<InnerNode>& nextLayer,
                   Ids& ids,
                   const fb::CompactBoundingBox& compactBbox,
                   const BoundingBox& parentBbox)
{
    auto childless = InnerNode{InnerNodesRef{nullptr, nullptr},
                               false,
                               uncompact(compactBbox, parentBbox)};
    nextLayer.push_back(std::move(childless));
    ids.push_back(child.id);
    return true;
}

// Serializes the tree rooted at `root` into a finished flatbuffer.
//
// The tree is walked level by level, starting at `root`. For each node of the
// current layer, the compact bbox of every child is appended to `nodes`; it is
// encoded relative to that node's bbox. After each level, the running size of
// `nodes` is appended to `levelBegin`. The walk stops once a layer's children
// are leaves. Their ids are appended to `ids` in the same order as their bboxes
// in `nodes`.
//
// `leaves` is used only to reserve space for the ids. `version`, the root bbox
// and `branching` are written into the Rtree table.
flatbuffers64::FlatBufferBuilder buildRtree(const std::string& version,
                                            const InnerNode& root,
                                            const LeafNodes& leaves,
                                            size_t branching)
{
    auto builder = flatbuffers64::FlatBufferBuilder{};

    auto nodes = std::vector<fb::CompactBoundingBox>{};
    auto levelBegin = std::vector<uint64_t>{};

    auto layer = std::vector<InnerNode>{};
    auto nextLayer = std::vector<InnerNode>{};
    auto ids = Ids{};

    ids.reserve(leaves.size());
    layer.push_back(root);

    for (bool leafLevel = false; !leafLevel;) {
        for (const auto& node : layer) {
            leafLevel = std::visit(
                [&](const auto& children) {
                    // fillNextLayer returns true for leaf children and false
                    // for inner ones. A node without children counts as a
                    // leaf level, which ends the walk.
                    bool result = true;
                    for (const auto& child : children) {
                        const auto& compactBbox =
                            compact(child.bbox, node.bbox);
                        nodes.push_back(compactBbox);
                        // The next layer receives the bbox decoded back from
                        // compactBbox, so grandchildren are encoded against
                        // the stored value rather than the exact one.
                        result = fillNextLayer(
                            child, nextLayer, ids, compactBbox, node.bbox);
                    }
                    return result;
                },
                node.children);
        }
        layer.swap(nextLayer);
        nextLayer.clear();
        levelBegin.push_back(nodes.size());
    }

    const auto nodesOffset = builder.CreateVectorOfStructs(nodes);
    const auto levelBeginOffset = writeVarWidthVector(builder, levelBegin);
    const auto idsOffset = writeFixedWidthVector(builder, ids);
    const auto bbox = fb::BoundingBox{
        root.bbox.minX(), root.bbox.maxX(), root.bbox.minY(), root.bbox.maxY()};
    const auto rtreeOffset = fb::CreateRtree(builder,
                                             builder.CreateString(version),
                                             &bbox,
                                             branching,
                                             nodesOffset,
                                             levelBeginOffset,
                                             idsOffset);

    builder.Finish(rtreeOffset);
    return builder;
}

// Builds a packed R-tree over `leaves` bottom-up with the STR algorithm, one
// level at a time until a single root remains, and serializes it.
//
// `leaves` is reordered in place, and the intermediate levels hold views into
// it and into each other.
//
// `branching` must be at least 2. With 0, the grouping divides by zero. With 1,
// every level has as many nodes as the level below, so the grouping loop would
// never terminate. Both cases are rejected up front.
flatbuffers64::FlatBufferBuilder buildRtree(const std::string& version,
                                            LeafNodes&& leaves,
                                            size_t branching)
{
    ASSERT(branching >= 2);

    INFO() << "grouping " << leaves.size() << " leaves";
    // Each level's nodes reference the buffer of the level below. When
    // `levels` reallocates, the inner vectors are moved (std::vector's move
    // constructor is noexcept), which keeps their buffers and so the views
    // stay valid.
    auto levels = std::vector<std::vector<InnerNode>>{};
    levels.emplace_back(groupNodes(&leaves, branching));
    while (levels.back().size() > 1) {
        INFO() << "grouping " << levels.back().size() << " inner nodes";
        levels.emplace_back(groupNodes(&levels.back(), branching));
    }
    INFO() << "writing rtree...";
    return buildRtree(version, levels.back().front(), leaves, branching);
}

}  // namespace impl

// Builds an R-tree over `leaves` with the given `branching` and writes the
// serialized flatbuffer to `outputFilename`.
void packRtree(const std::string& version,
               LeafNodes leaves,
               const std::string& outputFilename,
               size_t branching)
{
    auto rtreeBuilder =
        impl::buildRtree(version, std::move(leaves), branching);
    writeFlatBuffersToFile(rtreeBuilder, outputFilename);
}

// Builds an R-tree over `leaves` in memory and returns a heap-allocated Rtree
// constructed from the serialized buffer.
std::unique_ptr<Rtree> buildRtree(const std::string& version,
                                  LeafNodes leaves,
                                  size_t branching)
{
    auto serialized =
        buildRtreeToBuffer(version, std::move(leaves), branching);
    return std::make_unique<Rtree>(std::move(serialized));
}

// Builds an R-tree over `leaves` and returns the bytes of the finished
// flatbuffer as a TBuffer.
TBuffer buildRtreeToBuffer(const std::string& version,
                           LeafNodes leaves,
                           size_t branching)
{
    auto fbBuilder = impl::buildRtree(version, std::move(leaves), branching);
    const auto* bytes =
        reinterpret_cast<const char*>(fbBuilder.GetBufferPointer());
    const auto byteCount = fbBuilder.GetSize();
    return TBuffer(bytes, byteCount);
}

}  // namespace maps::mrc::fb_rtree
