#include "traffic_sign_groups.h"
#include "loader.h"
#include "utils.h"

#include <maps/libs/log8/include/log8.h>
#include <maps/libs/sql_chemistry/include/batch_load.h>

#include <maps/wikimap/mapspro/services/mrc/libs/common/include/algorithm/collection.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/exif.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/frame_gateway.h>

#include <pqxx/transaction_base>

#include <string>
#include <unordered_map>

namespace maps::mrc::signs_export_gen {

namespace {

/// One traffic sign detection on a frame, together with the object id it is
/// associated with (see toDetectedSign / collectGroupsInParallel).
struct DetectedSign {
    db::TId detectionId; // id of the source db::eye::Detection
    db::TId frameId;     // id of the frame the detection was made on
    db::TId objectId;    // object associated with this detection
    traffic_signs::TrafficSign type;
    common::ImageBox box; // transformed by image orientation
};

/// Signs detected on a single frame.
using DetectedSigns = std::vector<DetectedSign>;

/// A group of signs found on a single frame, before groups coming from
/// different frames are merged in mergeGroups().
struct RawSignGroup {
    DetectedSigns detectedSigns;
    GroupAlignment alignment; // axis along which the signs were found aligned
};


/// Buckets @p detections by their detection group id.
/// Each detection is moved into its bucket; within a bucket the input order
/// is preserved.
std::unordered_map<db::TId, db::eye::Detections>
makeMapByDetectionGroupId(db::eye::Detections detections)
{
    std::unordered_map<db::TId, db::eye::Detections> detectionsByGroupId;
    for (auto& detection : detections) {
        detectionsByGroupId[detection.groupId()].push_back(std::move(detection));
    }
    return detectionsByGroupId;
}

/// Converts a sign detection made on @p frame into a DetectedSign bound to
/// @p objectId. The detection box is transformed by the frame orientation
/// (common::transformByImageOrientation) using the frame's original size.
DetectedSign toDetectedSign(
    const db::eye::Detection& detection,
    const db::eye::Frame& frame,
    db::TId objectId)
{
    auto signAttrs = detection.attrs<db::eye::DetectedSign>();
    auto orientedBox = common::transformByImageOrientation(
        signAttrs.box, frame.originalSize(), frame.orientation());

    DetectedSign sign{
        .detectionId = detection.id(),
        .frameId = frame.id(),
        .objectId = objectId,
        .type = signAttrs.type,
        .box = orientedBox
    };
    return sign;
}

// Tolerances used by areBoxesAligned for the offset between box centers along
// the grouping axis, in units of the smaller box extent:
//   - horizontal groups: center X offset <= MULTIPLIER_H * smaller width;
//   - vertical groups:   center Y offset <= MULTIPLIER_V * smaller height.
constexpr double MULTIPLIER_H = 5;
constexpr double MULTIPLIER_V = 2.5;

/**
 * @brief Checks if two boxes are aligned on frame.
 *
 * Let minWidth / minHeight be the smaller width / height of the two boxes,
 * each measured on its own box.
 *   - Horizontal: the Y distance between box centers is at most
 *     minHeight / 2, and (rhs center X - lhs center X) is at most
 *     MULTIPLIER_H * minWidth.
 *   - Vertical: the X distance between box centers is at most
 *     minWidth / 2, and (rhs center Y - lhs center Y) is at most
 *     MULTIPLIER_V * minHeight.
 * The offset along the grouping axis is signed (not absolute), which is why
 * the ordering precondition below matters.
 *
 * Precondition: ordering must hold:
 *   - for horizontal alignment: lhs.minX() <= rhs.minX()
 *   - for vertical alignment: lhs.minY() <= rhs.minY()
 *
 * @return false for any other alignment value.
 */
bool areBoxesAligned(
    const common::ImageBox& lhs,
    const common::ImageBox& rhs,
    GroupAlignment alignment)
{
    int lhsCenterX = (lhs.minX() + lhs.maxX()) / 2;
    int rhsCenterX = (rhs.minX() + rhs.maxX()) / 2;
    int lhsCenterY = (lhs.minY() + lhs.maxY()) / 2;
    int rhsCenterY = (rhs.minY() + rhs.maxY()) / 2;

    // Each extent is taken from a single box (max - min of the same box).
    int minWidth = std::min(lhs.maxX() - lhs.minX(), rhs.maxX() - rhs.minX());
    int minHeight = std::min(lhs.maxY() - lhs.minY(), rhs.maxY() - rhs.minY());

    if (alignment == GroupAlignment::Horizontal)
    {
        return std::abs(rhsCenterY - lhsCenterY) <= minHeight / 2
            && rhsCenterX - lhsCenterX <= MULTIPLIER_H * minWidth;
    }
    else if (alignment == GroupAlignment::Vertical)
    {
        return std::abs(rhsCenterX - lhsCenterX) <= minWidth / 2
            && rhsCenterY - lhsCenterY <= MULTIPLIER_V * minHeight;
    }
    return false;
}

/**
 * Make sign groups from the signs aligned on a single frame
 * Preconditions:
 *   - for horizontal alignment: @param detectedSigns must be sorted by box.minX()
 *   - for vertical alignment:   @param detectedSigns must be sorted by box.minY()
 */
std::vector<RawSignGroup> makeGroups(
    const DetectedSigns& detectedSigns,
    GroupAlignment alignment)
{
    const auto size = detectedSigns.size();

    db::TId nextGroupId = 0;
    std::map<db::TId, db::TId> objectIdToGroupId;
    std::map<db::TId, RawSignGroup> groupById;

    for (size_t i = 0; i < size; ++i) {
        const auto& lhs = detectedSigns[i];
        auto lhsGroupId = findOptional(objectIdToGroupId, lhs.objectId);

        for (size_t j = i + 1; j < size; ++j) {
            const auto& rhs = detectedSigns[j];

            auto rhsGroupId = findOptional(objectIdToGroupId, rhs.objectId);

            // 1. Second sign must not be in any group yet.
            // 2. Signs must have different types.
            // 3. Signs' bboxes must be aligned.
            if (!rhsGroupId && lhs.type != rhs.type
                    && areBoxesAligned(lhs.box, rhs.box, alignment))
            {
                if (lhsGroupId) {
                    // Add sign to existing group
                    groupById[*lhsGroupId].detectedSigns.push_back(rhs);
                    objectIdToGroupId.emplace(rhs.objectId, *lhsGroupId);
                } else {
                    // Make a new group with both signs
                    lhsGroupId = nextGroupId++;
                    auto& group = groupById[*lhsGroupId];
                    group.detectedSigns.push_back(lhs);
                    group.detectedSigns.push_back(rhs);
                    objectIdToGroupId.emplace(lhs.objectId, *lhsGroupId);
                    objectIdToGroupId.emplace(rhs.objectId, *lhsGroupId);
                    group.alignment = alignment;
                }
            }
        }
    }

    std::vector<RawSignGroup> result;
    result.reserve(groupById.size());
    for (auto& [_, group] : groupById) {
        result.push_back(std::move(group));
    }
    return result;
}

/**
 * Make sign groups from the signs detected on a single frame.
 * Each sign can fall into at most one group.
 * Vertical groups are searched for first, among all signs. Horizontal groups
 * are then searched for among the remaining signs, and may contain only lane
 * direction prescription signs.
 * Fewer than two signs yield no groups.
 *
 * Example: frame with 2 groups and 1 standalone sign.
 *
 *                  frame
 *  +----------------------------------+
 *  |    GROUP1:                       |
 *  |   +-------+     sign4            |
 *  |   | sign1 |                      |
 *  |   |       |          GROUP2:     |
 *  |   | sign2 |      +-------------+ |
 *  |   |       |      | sign5 sign6 | |
 *  |   | sign3 |      +-------------+ |
 *  |   +-------+                      |
 *  +----------------------------------+
 *
 */
std::vector<RawSignGroup> makeGroups(DetectedSigns detectedSigns)
{
    if (detectedSigns.size() < 2) {
        return {};
    }

    const auto byMinY = [](const DetectedSign& l, const DetectedSign& r) {
        return l.box.minY() < r.box.minY();
    };
    const auto byMinX = [](const DetectedSign& l, const DetectedSign& r) {
        return l.box.minX() < r.box.minX();
    };

    // Pass 1: vertical groups among all signs (requires minY ordering).
    std::sort(detectedSigns.begin(), detectedSigns.end(), byMinY);
    std::vector<RawSignGroup> result =
        makeGroups(detectedSigns, GroupAlignment::Vertical);

    std::set<db::TId> groupedDetectionIds;
    for (const RawSignGroup& group : result) {
        for (const DetectedSign& sign : group.detectedSigns) {
            groupedDetectionIds.insert(sign.detectionId);
        }
    }

    // Pass 2: horizontal groups among ungrouped lane direction prescription
    // signs (requires minX ordering).
    std::erase_if(detectedSigns, [&](const DetectedSign& sign) {
        return !traffic_signs::isLaneDirectionPrescription(sign.type)
            || groupedDetectionIds.count(sign.detectionId);
    });
    std::sort(detectedSigns.begin(), detectedSigns.end(), byMinX);

    auto horizontalGroups = makeGroups(detectedSigns, GroupAlignment::Horizontal);
    for (auto& group : horizontalGroups) {
        result.push_back(std::move(group));
    }
    return result;
}

/**
 * Builds raw sign groups for the frames in @p frameIdsBatch and appends them
 * to @p result.
 *
 * The ids are handed to parallelForEachBatch<NUM_THREADS, BATCH_SIZE>.
 * Presumably it splits them into chunks of at most BATCH_SIZE ids and
 * processes them on up to NUM_THREADS workers — confirm against its
 * definition.
 *
 * For each chunk, the callback:
 *   - opens its own slave transaction;
 *   - loads frames -> detection groups -> detections -> detection objects;
 *   - groups the signs of every detection group that has at least two
 *     detections and a loaded frame;
 *   - appends the chunk's groups to @p result while holding a mutex.
 * Detections without an associated object are ignored.
 */
void collectGroupsInParallel(
    pgpool3::Pool& pool,
    const db::TIds& frameIdsBatch,
    std::vector<RawSignGroup>& result)
{
    constexpr auto BATCH_SIZE = 50'000;
    constexpr auto NUM_THREADS = 6;

    std::mutex resultGuard;

    parallelForEachBatch<NUM_THREADS, BATCH_SIZE>(
        frameIdsBatch, [&](std::span<const db::TId> batch) {
            auto txn = pool.slaveTransaction();

            auto frameById = loadFrames(*txn, {batch.begin(), batch.end()});
            auto loadedFrameIds = keys(frameById);

            auto detectionGroups = loadDetectionGroups(*txn, loadedFrameIds);
            auto detectionGroupIds =
                transform(detectionGroups, &db::eye::DetectionGroup::id);
            auto frameIdByDetectionGroupId = makeMap(detectionGroups,
                &db::eye::DetectionGroup::id,
                &db::eye::DetectionGroup::frameId);

            auto allDetections = loadDetections(*txn, detectionGroupIds);
            auto allDetectionIds = transform(allDetections, &db::eye::Detection::id);
            auto detectionsByGroupId =
                makeMapByDetectionGroupId(std::move(allDetections));

            auto objectIdByDetectionId =
                loadDetectionIdToObjectId(*txn, allDetectionIds);

            std::vector<RawSignGroup> batchGroups;
            for (const auto& [detectionGroupId, groupDetections] : detectionsByGroupId) {
                const auto frameId = frameIdByDetectionGroupId.at(detectionGroupId);
                const auto frameIt = frameById.find(frameId);
                if (groupDetections.size() < 2 || frameIt == frameById.end()) {
                    continue;
                }
                const auto& frame = frameIt->second;

                DetectedSigns frameSigns;
                for (const auto& detection : groupDetections) {
                    if (auto objectId = findOptional(objectIdByDetectionId, detection.id())) {
                        frameSigns.push_back(toDetectedSign(detection, frame, *objectId));
                    }
                }

                auto frameGroups = makeGroups(std::move(frameSigns));
                for (auto& group : frameGroups) {
                    batchGroups.push_back(std::move(group));
                }
            }

            const std::lock_guard lock{resultGuard};
            for (auto& group : batchGroups) {
                result.push_back(std::move(group));
            }
    });
}

std::pair<SignGroups, ObjectIdToGroupId>
mergeGroups(const std::vector<RawSignGroup>& rawGroups)
{
    SignGroups groups;
    ObjectIdToGroupId objectIdToGroupId;

    INFO() << "Merging " << rawGroups.size() << " groups...";
    for (const auto& rawGroup : rawGroups) {
        std::set<db::TId> existingGroupIds;

        for (const DetectedSign& sign : rawGroup.detectedSigns) {
            auto existingGroupItr = objectIdToGroupId.find(sign.objectId);
            if (existingGroupItr != objectIdToGroupId.end()) {
                existingGroupIds.insert(existingGroupItr->second);
            }
        }

        if (existingGroupIds.empty()) {
            // These objects do not belong to any group yet, create a new one.
            db::TId groupId = groups.size() + 1; // number groups from 1.
            groups.push_back(
                SignGroup{.groupId = groupId, .alignment = rawGroup.alignment});
            for (const DetectedSign& sign : rawGroup.detectedSigns) {
                groups.back().objectIds.push_back(sign.objectId);
                objectIdToGroupId.emplace(sign.objectId, groupId);
            }
        } else if (existingGroupIds.size() == 1) {
            // There is a single existing group, merge objects into it
            auto& group = groups[*existingGroupIds.begin() - 1];
            if (group.alignment != rawGroup.alignment) {
                continue;
            }

            for (const DetectedSign& sign : rawGroup.detectedSigns) {
                if (!contains(group.objectIds, sign.objectId)) {
                    group.objectIds.push_back(sign.objectId);
                    objectIdToGroupId.emplace(sign.objectId, group.groupId);
                }
            }
        } else { // existingGroupIds.size() > 1
            // Objects in this group belong to different groups on other frames.
            // This is a conflict, so just skip this group.
        }
    }
    return {std::move(groups), std::move(objectIdToGroupId)};
}

} // namespace

/**
 * Collects traffic sign groups over all frames.
 *
 * Frame ids are read from slave transactions in ascending order, in pages of
 * at most IDS_BATCH_SIZE ids (each page starts after the last id of the
 * previous one). Each page is passed to collectGroupsInParallel, and all raw
 * groups are finally merged by mergeGroups.
 *
 * @return the merged groups and the mapping from object id to group id.
 */
std::pair<SignGroups, ObjectIdToGroupId>
collectTrafficSignGroups(pgpool3::Pool& pool)
{
    namespace tbl = db::eye::table;
    constexpr auto IDS_BATCH_SIZE = 5'000'000;

    std::vector<RawSignGroup> rawGroups;
    size_t processed = 0;

    for (db::TId lastFrameId = 0;;) {
        const auto frameIds = db::eye::FrameGateway{*pool.slaveTransaction()}
            .loadIds(
                tbl::Frame::id > lastFrameId,
                orderBy(tbl::Frame::id).asc().limit(IDS_BATCH_SIZE));
        if (frameIds.empty()) {
            break;
        }
        INFO() << "Loaded frame ids batch";

        collectGroupsInParallel(pool, frameIds, rawGroups);

        // Ids are sorted ascending, so the last one is the page boundary.
        processed += frameIds.size();
        lastFrameId = frameIds.back();
        INFO() << "Frames: " << processed << ", raw groups: " << rawGroups.size();
    }

    return mergeGroups(rawGroups);
}

}  // namespace maps::mrc::signs_export_gen
