#include "graph_length.h"
#include "strings.h"
#include "yt_tables.h"

#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/schema.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/serialization.h>

#include <maps/libs/introspection/include/comparison.h>
#include <maps/libs/introspection/include/hashing.h>

#include <maps/libs/common/include/make_batches.h>
#include <maps/libs/geolib/include/spatial_relation.h>
#include <maps/libs/log8/include/log8.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/algorithm/parallel_for_each.h>

#include <atomic>
#include <mutex>

namespace maps::mrc::graph_coverage_export {

namespace {

using maps::introspection::operator==;

// Grouping key for road-graph length aggregation: lengths are summed per
// (geoId, fc, isToll, privateArea, graphType) tuple.
struct AggregateKey
{
    int32_t geoId;
    uint32_t fc;              // filled from road_graph EdgeData::category()
    bool isToll;
    bool privateArea;         // result of the private-area index lookup
    db::GraphType graphType;

    // Field list consumed by maps::introspection: it backs operator==
    // (imported above) and introspection::Hasher used by the length map.
    template <typename T>
    static auto introspect(const T& key) {
        return std::tie(key.geoId, key.fc, key.isToll, key.privateArea, key.graphType);
    }
};

// Length of the part of the edge geometry that lies inside `bbox`.
//  - geometry disjoint from the box   -> 0
//  - geometry fully inside the box    -> edgeData.length() (no clipping needed)
//  - otherwise                        -> sum of geolib3::geoLength over the
//                                        pieces produced by clipping to the box
double getIntersectionLength(const geolib3::BoundingBox& bbox,
                             const road_graph::EdgeData& edgeData)
{
    const auto& edgeGeometry = edgeData.geometry();

    if (geolib3::spatialRelation(bbox, edgeGeometry, geolib3::Disjoint)) {
        return 0.0;
    }

    if (geolib3::spatialRelation(bbox, edgeGeometry, geolib3::Contains)) {
        // Fast path: the whole edge is inside, reuse its precomputed length.
        return edgeData.length();
    }

    const auto clippedPieces = geolib3::intersection(bbox, edgeGeometry);
    double insideLength = 0.0;
    for (const auto& piece : clippedPieces) {
        insideLength += geolib3::geoLength(piece);
    }
    return insideLength;
}

// Accumulated edge length per AggregateKey, hashed with introspection::Hasher.
// NOTE(review): "Lengh" is a typo for "Length"; renaming requires touching
// every user of the alias at once.
using RoadGraphLenghMap = std::unordered_map<AggregateKey, double,
                                             introspection::Hasher>;


// Sums the length of base road-graph edges whose access mask is compatible
// with ctx.graphType(), grouped by AggregateKey.
//
//  - When geoBbox is set, only the part of each edge inside the box is counted
//    (see getIntersectionLength); contributions below geolib3::EPS are skipped.
//  - geoIds are evaluated at the first point of the edge geometry, and the
//    edge's length is added once for every returned geoId.
//  - Edges are split into batches processed by parallelForEach; each batch
//    accumulates into a local map which is merged into the result under a
//    mutex, so the lock is taken once per batch rather than per edge.
RoadGraphLenghMap computeRoadGraphLength(
    const ContextBase& ctx,
    const std::optional<geolib3::BoundingBox>& geoBbox)
{
    constexpr std::size_t BATCH_SIZE = 10000000;

    // Identical for every edge: query it once instead of per edge / per geoId.
    const auto graphType = ctx.graphType();

    RoadGraphLenghMap result;
    std::mutex resultGuard;                 // protects `result`
    std::atomic<std::size_t> processed{0};  // edges visited across all batches

    const auto edges = ctx.graph().edges();
    auto batches = maps::common::makeBatches(edges, BATCH_SIZE);

    common::parallelForEach<THREADS_NUMBER>(batches.begin(), batches.end(),
        [&](auto&, const auto& edgesBatch) {
            RoadGraphLenghMap localResult;
            for (const auto& edge : edgesBatch) {
                ++processed;
                if (!ctx.graph().isBase(edge.id)) {
                    continue;
                }
                auto edgeData = ctx.graph().edgeData(edge.id);
                if (!areCompatible(edgeData.accessIdMask(), graphType)) {
                    continue;
                }
                const double length = geoBbox
                    ? getIntersectionLength(*geoBbox, edgeData)
                    : edgeData.length();
                if (length < geolib3::EPS) {
                    continue;
                }

                auto pos = edgeData.geometry().pointAt(0);
                auto geoIds = ctx.evalGeoIds(pos);
                auto privateArea = ctx.privateAreaIndex().isInPrivateArea(
                    edgeData.geometry(), edgeData.category());
                for (auto geoId : geoIds) {
                    AggregateKey key{geoId,
                                     edgeData.category(),
                                     edgeData.isToll(),
                                     privateArea,
                                     graphType};
                    localResult[key] += length;
                }
            }

            {
                // Keep the critical section limited to the merge itself;
                // logging below must not serialize the worker threads.
                std::lock_guard<std::mutex> lock{resultGuard};
                for (const auto& [key, batchLength] : localResult) {
                    result[key] += batchLength;
                }
            }
            INFO() << "Processed edges: " << processed.load();
        });
    return result;
}

} // namespace


// Computes the aggregated road-graph lengths (optionally restricted to
// `geoBbox`) and writes one serialized GraphLength row per aggregate key.
// Rows are emitted in the iteration order of the underlying unordered_map,
// i.e. in no particular order.
void saveGraphLengthToYt(ContextBase& ctx,
                         NYT::TTableWriter<NYT::TNode>& writer,
                         const std::optional<geolib3::BoundingBox>& geoBbox)
{
    const auto lengthByKey = computeRoadGraphLength(ctx, geoBbox);

    for (const auto& [key, totalLength] : lengthByKey) {
        // fc and graphType are narrowed to the uint8_t columns of GraphLength.
        writer.AddRow(yt::serialize(GraphLength{
            .geoId = key.geoId,
            .fc = static_cast<uint8_t>(key.fc),
            .isToll = key.isToll,
            .privateArea = key.privateArea,
            .graphType = static_cast<uint8_t>(key.graphType),
            .length = totalLength,
        }));
    }
}


} // namespace maps::mrc::graph_coverage_export

