#include "detection_relations.h"
#include "eye/recognition.h"

#include <maps/wikimap/mapspro/services/mrc/eye/lib/common/include/id.h>

#include <maps/wikimap/mapspro/services/mrc/eye/lib/sync_detection/include/metadata.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/sync_detection/impl/sync.h>

#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/frame_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/recognition_gateway.h>

#include <maps/wikimap/mapspro/services/mrc/eye/lib/common/include/id.h>

#include <maps/libs/log8/include/log8.h>
#include <maps/libs/sql_chemistry/include/order.h>

#include <library/cpp/iterator/zip.h>

#include <algorithm>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>

namespace maps::mrc::eye {

namespace {

// Attribute comparison for traffic lights: no attribute is inspected, so any
// two traffic-light detections are considered equal. Matching between them is
// therefore decided purely by box overlap in findClosest().
inline bool equal(
        const db::eye::DetectedTrafficLight& /* lhs */,
        const db::eye::DetectedTrafficLight& /* rhs */)
{
    constexpr bool alwaysEqual = true;
    return alwaysEqual;
}

// Attribute comparison for house numbers: two detections are equal when
// their `number` fields compare equal.
inline bool equal(
        const db::eye::DetectedHouseNumber& lhs,
        const db::eye::DetectedHouseNumber& rhs)
{
    const auto& leftNumber = lhs.number;
    const auto& rightNumber = rhs.number;
    return leftNumber == rightNumber;
}

// Attribute comparison for signs: both the sign `type` and the `temporary`
// flag must match. The `temporary` flag is only examined when the types agree.
inline bool equal(
        const db::eye::DetectedSign& lhs,
        const db::eye::DetectedSign& rhs)
{
    if (lhs.type != rhs.type) {
        return false;
    }
    return lhs.temporary == rhs.temporary;
}

// Attribute comparison for road markings: only the marking `type` is compared.
inline bool equal(
        const db::eye::DetectedRoadMarking& lhs,
        const db::eye::DetectedRoadMarking& rhs)
{
    const bool sameType = (lhs.type == rhs.type);
    return sameType;
}

// Finds the detection in `detections` that best matches `detectionAttrs`.
//
// A candidate qualifies when:
//   * its id is not in `skipIds`;
//   * the IoU of its box with `detectionAttrs.box` is >= the best IoU seen so
//     far (initially `iouThreshold`);
//   * equal() holds for its attributes and `detectionAttrs`.
// Among qualifying candidates the one with the highest IoU wins; on an IoU tie
// the later candidate wins because the comparison is `>=`.
//
// Returns a pointer into `detections`, or nullptr when nothing qualifies.
template<class DetectionAttrs>
const db::eye::Detection* findClosest(
        const db::eye::Detections& detections,
        const DetectionAttrs& detectionAttrs,
        const db::TIdSet& skipIds,
        double iouThreshold)
{
    const db::eye::Detection* closest = nullptr;
    double closestIoU = iouThreshold;

    for (const auto& candidate : detections) {
        const bool skipped = skipIds.count(candidate.id()) != 0;
        if (skipped) {
            continue;
        }

        const auto candidateAttrs = candidate.attrs<DetectionAttrs>();
        const double iou = common::getIoU(candidateAttrs.box, detectionAttrs.box);

        // Keep the `>=` form so a NaN IoU never qualifies.
        const bool overlapsEnough = iou >= closestIoU;
        if (overlapsEnough && equal(candidateAttrs, detectionAttrs)) {
            closestIoU = iou;
            closest = &candidate;
        }
    }

    return closest;
}

template<class DetectionAttrs>
std::pair<std::vector<DetectionAttrs>, bool> merge(db::eye::Recognitions recognitions)
{
    ASSERT(not recognitions.empty());

    std::sort(
        recognitions.begin(), recognitions.end(),
        [](const auto& lhs, const auto& rhs) {
            return lhs.orderKey() > rhs.orderKey();
        }
    );

    using Wrap = db::eye::DetectedObjects<DetectionAttrs>;
    return {
        recognitions.front().value<Wrap>(),
        recognitions.front().source() >= db::eye::RecognitionSource::Toloka
    };
}

// Reconciles a group's stored `detections` with the objects from the
// highest-priority recognition (see merge()).
//
// For each recognized object, the best not-yet-used stored detection whose
// IoU is >= MATCH_IOU_THRESHOLD and whose attributes are equal() is reused
// and reported as `unchanged`. Otherwise a new detection for `group` is
// added to `updated`. Every stored detection that was not reused is copied,
// marked deleted, and added to `updated`. `approved` is taken from merge().
template<class DetectionAttrs>
UpdatedDetections makeUpdatedDetectionsImpl(
        const db::eye::DetectionGroup& group,
        const db::eye::Detections& detections,
        const db::eye::Recognitions& recognitions)
{
    // Minimal IoU for a stored detection to be treated as the same object as a
    // freshly recognized one.
    constexpr double MATCH_IOU_THRESHOLD = 0.8;

    UpdatedDetections updatedDetections;

    db::TIdSet usedIds;
    const auto [detectionAttrsVec, approved] = merge<DetectionAttrs>(recognitions);
    // Bind by reference: attrs may hold strings (e.g. house numbers), so a
    // per-iteration copy is wasted work.
    for (const auto& detectionAttrs : detectionAttrsVec) {
        const db::eye::Detection* closest =
            findClosest(detections, detectionAttrs, usedIds, MATCH_IOU_THRESHOLD);

        if (closest) {
            usedIds.insert(closest->id());
            updatedDetections.unchanged.push_back(*closest);
        } else {
            updatedDetections.updated.emplace_back(group.id(), detectionAttrs);
        }
    }

    // Only copy the detections that actually have to be marked deleted.
    for (const auto& detection : detections) {
        if (not usedIds.count(detection.id())) {
            auto deleted = detection;
            deleted.setDeleted();
            updatedDetections.updated.push_back(std::move(deleted));
        }
    }

    updatedDetections.approved = approved;

    return updatedDetections;
}

// Inserts a new detection group for every batch item that has none yet and
// stores the inserted group (as left by insertx) back into the item.
//
// The group type is derived from the type of the item's first recognition:
//   * `.at(0)` throws if the item has no recognitions;
//   * REQUIRE fails when the type has no DetectionType counterpart.
BatchItems createDetectionGroups(pqxx::transaction_base& txn, BatchItems&& items) {
    db::eye::DetectionGroups newGroups;
    std::vector<size_t> ownerItemIndices;

    for (size_t idx = 0; idx < items.size(); ++idx) {
        const auto& item = items[idx];
        if (item.group.has_value()) {
            continue;
        }

        const auto detectionType = toDetectionType(item.recognitions.at(0).type());
        REQUIRE(detectionType.has_value(), "Invalid detection type");

        ownerItemIndices.push_back(idx);
        newGroups.push_back({item.frame.id(), *detectionType});
    }

    INFO() << "Created " << newGroups.size() << " groups";
    db::eye::DetectionGroupGateway(txn).insertx(newGroups);

    // newGroups[pos] belongs to items[ownerItemIndices[pos]].
    for (size_t pos = 0; pos < newGroups.size(); ++pos) {
        items[ownerItemIndices[pos]].group = newGroups[pos];
    }

    return items;
}

// Brings each batch item's stored detections in line with its recognitions.
// All changes are persisted with one detection upsert and one group upsert.
//
// Deleted frame:
//   * every detection is marked deleted;
//   * the group (if any) is marked approved and upserted when it had
//     detections or was not approved yet.
//
// Live frame (a group is ASSERTed to exist):
//   * makeUpdatedDetections() splits detections into `unchanged` and
//     `updated` (new or deleted);
//   * the group is upserted when any detection changed, or when its approved
//     flag differs from the computed one.
//
// Afterwards each item holds the upserted copy of its group (if one was
// upserted). Its detections are the unchanged ones plus the non-deleted
// updated ones.
BatchItems updateDetections(pqxx::transaction_base& txn, BatchItems&& items) {
    db::eye::DetectionGroups updatedGroups;
    db::eye::Detections updatedDetections;

    // Parallel to updatedGroups / updatedDetections: owning item index.
    std::vector<size_t> groupIndxToItemIndx;
    std::vector<size_t> detectionIndxToItemIndx;
    for (size_t i = 0; i < items.size(); i++) {
        auto& item = items[i];

        if (item.frame.deleted()) {
            if (item.group.has_value() &&
                (!item.detections.empty() || !item.group.value().approved()))
            {
                groupIndxToItemIndx.push_back(i);

                auto group = item.group.value();
                group.setApproved(true);

                updatedGroups.push_back(std::move(group));
            }

            for (auto& detection : item.detections) {
                detectionIndxToItemIndx.push_back(i);

                detection.setDeleted();
                updatedDetections.push_back(std::move(detection));
            }
            item.detections.clear();
        } else {
            ASSERT(item.group.has_value());

            auto group = item.group.value();

            auto [unchanged, updated, approved]
                = makeUpdatedDetections(group, item.detections, item.recognitions);

            item.detections = std::move(unchanged);

            // Capture this before the elements are moved out below, so the
            // condition never reads a container after a std::move.
            const bool detectionsChanged = !updated.empty();
            for (auto& detection : updated) {
                detectionIndxToItemIndx.push_back(i);

                updatedDetections.push_back(std::move(detection));
            }

            if (detectionsChanged || group.approved() != approved) {
                groupIndxToItemIndx.push_back(i);

                group.setApproved(approved);
                updatedGroups.push_back(std::move(group));
            }
        }
    }

    INFO() << "Updating " << updatedDetections.size() << " detections";
    db::eye::DetectionGateway(txn).upsertx(updatedDetections);
    INFO() << "Updating " << updatedGroups.size() << " groups";
    db::eye::DetectionGroupGateway(txn).upsertx(updatedGroups);

    for (size_t i = 0; i < updatedGroups.size(); i++) {
        size_t itemIndx = groupIndxToItemIndx[i];
        items[itemIndx].group = std::move(updatedGroups[i]);
    }
    for (size_t i = 0; i < updatedDetections.size(); i++) {
        if (updatedDetections[i].deleted()) {
            continue;
        }

        size_t itemIndx = detectionIndxToItemIndx[i];
        items[itemIndx].detections.push_back(std::move(updatedDetections[i]));
    }

    return items;
}

// Two relations are the same when they link the same master detection to the
// same slave detection. Only these two ids are compared.
bool sameRelations(
    const db::eye::DetectionRelation& a,
    const db::eye::DetectionRelation& b)
{
    if (!(a.masterDetectionId() == b.masterDetectionId())) {
        return false;
    }
    return a.slaveDetectionId() == b.slaveDetectionId();
}

// Returns the index of the first relation in `relations` that sameRelations()
// considers equal to `a`, or std::nullopt when there is none.
std::optional<size_t> findRelationIdx(
    const db::eye::DetectionRelations &relations,
    const db::eye::DetectionRelation& a)
{
    const auto match = std::find_if(
        relations.begin(), relations.end(),
        [&a](const auto& candidate) { return sameRelations(a, candidate); });

    if (match == relations.end()) {
        return std::nullopt;
    }
    return static_cast<size_t>(match - relations.begin());
}

// Diffs the stored relations of a frame against the relations computed by
// findSignDetectionRelations(frame, detections). Relations are matched by
// master/slave detection ids (sameRelations()).
//
//   * Stored and still computed: kept in `unchanged`. If it was deleted, its
//     deleted flag is cleared and it goes to `changed` instead.
//   * Stored, not deleted, no longer computed: marked deleted, put in `changed`.
//   * Stored, already deleted, no longer computed: dropped. It is already in
//     the desired state, and re-emitting it on every sync would only cause a
//     redundant write.
//   * Computed with no stored counterpart: appended to `changed`.
UpdatedRelations makeUpdatedRelations(
    const db::eye::Frame& frame,
    const db::eye::Detections& detections,
    const db::eye::DetectionRelations& relations)
{
    UpdatedRelations updatedRelations;

    auto newRelations = findSignDetectionRelations(frame, detections);

    for (auto relation : relations) {
        if (auto idx = findRelationIdx(newRelations, relation)) {
            if (relation.deleted()) {
                relation.setDeleted(false);
                updatedRelations.changed.push_back(std::move(relation));
            } else {
                updatedRelations.unchanged.push_back(std::move(relation));
            }

            // Consume the match so it is not re-added as a brand-new relation.
            newRelations.erase(newRelations.begin() + idx.value());
        } else if (!relation.deleted()) {
            relation.setDeleted(true);
            updatedRelations.changed.push_back(std::move(relation));
        }
    }

    for (auto& relation : newRelations) {
        updatedRelations.changed.push_back(std::move(relation));
    }

    return updatedRelations;
}

// Synchronizes sign detection relations for every batch item and persists
// the changes. Groups are upserted first, then relations.
//
//   * Live frame whose group type is Sign: relations are recomputed via
//     makeUpdatedRelations(). If anything changed, the changed relations and
//     the item's group are upserted.
//   * Deleted frame: every relation that is not already deleted is marked
//     deleted. If at least one was, those relations and the item's group are
//     upserted. Already-deleted relations are not rewritten, so re-syncing a
//     deleted frame produces no spurious writes.
//
// Afterwards, upserted groups are stored back into their items. Non-deleted
// upserted relations are appended to their item's `relations`.
BatchItems updateRelations(pqxx::transaction_base& txn, BatchItems&& items) {
    db::eye::DetectionGroups updatedGroups;
    db::eye::DetectionRelations updatedRelations;

    // Parallel to updatedGroups / updatedRelations: owning item index.
    std::vector<size_t> updatedGroupItemIndices;
    std::vector<size_t> updatedRelationItemIndices;
    for (size_t i = 0; i < items.size(); i++) {
        auto& item = items[i];

        if (!item.frame.deleted() &&
            item.group.value().type() == db::eye::DetectionType::Sign)
        {
            auto [unchanged, changed]
                = makeUpdatedRelations(item.frame, item.detections, item.relations);

            item.relations = std::move(unchanged);

            if (!changed.empty()) {
                updatedGroupItemIndices.push_back(i);
                updatedGroups.push_back(item.group.value());

                for (auto& relation : changed) {
                    updatedRelationItemIndices.push_back(i);
                    updatedRelations.push_back(std::move(relation));
                }
            }
        } else if (item.frame.deleted() && !item.relations.empty()) {
            bool anyRelationDeleted = false;
            for (auto& relation : item.relations) {
                if (relation.deleted()) {
                    continue;
                }
                updatedRelationItemIndices.push_back(i);

                relation.setDeleted(true);
                updatedRelations.push_back(std::move(relation));
                anyRelationDeleted = true;
            }

            if (anyRelationDeleted) {
                updatedGroupItemIndices.push_back(i);
                updatedGroups.push_back(item.group.value());
            }

            item.relations.clear();
        }
    }

    INFO() << "Updating " << updatedGroups.size() << " groups";
    db::eye::DetectionGroupGateway(txn).upsertx(updatedGroups);
    INFO() << "Updating " << updatedRelations.size() << " relations";
    db::eye::DetectionRelationGateway(txn).upsertx(updatedRelations);

    for (size_t i = 0; i < updatedGroups.size(); i++) {
        size_t itemIndx = updatedGroupItemIndices[i];
        items[itemIndx].group = std::move(updatedGroups[i]);
    }
    for (size_t i = 0; i < updatedRelations.size(); i++) {
        if (updatedRelations[i].deleted()) {
            continue;
        }

        size_t itemIndx = updatedRelationItemIndices[i];
        items[itemIndx].relations.push_back(std::move(updatedRelations[i]));
    }

    return items;
}

// Removes from every item the recognitions that are empty and whose source
// passes mayBeRemoved(). Then deletes all of them from the database with one
// removeByIds() call.
BatchItems removeEmptyRecognitions(pqxx::transaction_base& txn, BatchItems&& items) {
    db::TIds idsToRemove;

    for (auto& item : items) {
        auto& recognitions = item.recognitions;
        const auto firstRemoved = std::remove_if(
            recognitions.begin(), recognitions.end(),
            [&idsToRemove](const auto& recognition) {
                const bool disposable =
                    recognition.empty() && mayBeRemoved(recognition.source());
                if (disposable) {
                    idsToRemove.push_back(recognition.id());
                }
                return disposable;
            }
        );
        recognitions.erase(firstRemoved, recognitions.end());
    }

    INFO() << "Removing " << idsToRemove.size() << " recognitions";
    db::eye::RecognitionGateway(txn).removeByIds(idsToRemove);

    return items;
}

} // namespace

// True only for recognitions produced by Import or Model. Any other source is
// never removed, even when its recognition is empty.
bool mayBeRemoved(db::eye::RecognitionSource source) {
    switch (source) {
        case db::eye::RecognitionSource::Import:
        case db::eye::RecognitionSource::Model:
            return true;
        default:
            return false;
    }
}

// Dispatches to makeUpdatedDetectionsImpl with the attribute type that
// matches the group's detection type.
//
// Throws std::runtime_error for a DetectionType value not handled below.
UpdatedDetections makeUpdatedDetections(
        const db::eye::DetectionGroup& group,
        const db::eye::Detections& detections,
        const db::eye::Recognitions& recognitions)
{
    switch(group.type()) {
        case db::eye::DetectionType::Sign:
            return makeUpdatedDetectionsImpl<db::eye::DetectedSign>(group, detections, recognitions);

        case db::eye::DetectionType::TrafficLight:
            return makeUpdatedDetectionsImpl<db::eye::DetectedTrafficLight>(group, detections, recognitions);

        case db::eye::DetectionType::HouseNumber:
            return makeUpdatedDetectionsImpl<db::eye::DetectedHouseNumber>(group, detections, recognitions);

        case db::eye::DetectionType::RoadMarking:
            return makeUpdatedDetectionsImpl<db::eye::DetectedRoadMarking>(group, detections, recognitions);
    }

    // Falling off the end of a non-void function is undefined behavior, so an
    // unexpected enum value must fail loudly here instead.
    throw std::runtime_error(
        "Unsupported detection type: "
        + std::to_string(static_cast<int>(group.type())));
}

// Runs the sync pipeline for one batch inside `txn`. Each stage takes the
// batch produced by the previous one:
//   1. create missing detection groups;
//   2. update detections;
//   3. update relations;
//   4. remove empty recognitions.
void sync(pqxx::transaction_base& txn, BatchItems items)
{
    INFO() << "Create new detection groups";
    BatchItems withGroups = createDetectionGroups(txn, std::move(items));

    INFO() << "Update detections";
    BatchItems withDetections = updateDetections(txn, std::move(withGroups));

    INFO() << "Update relations";
    BatchItems withRelations = updateRelations(txn, std::move(withDetections));

    INFO() << "Remove empty recognitions";
    removeEmptyRecognitions(txn, std::move(withRelations));
}

} // namespace maps::mrc::eye
