#include <maps/wikimap/mapspro/services/mrc/eye/lib/common/include/id.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/object_manager/impl/collision.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/location/include/rotation.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/unit_test/include/frame.h>

#include "mocks.h"
#include "fixtures.h"

#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/frame_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/recognition_gateway.h>

#include <library/cpp/testing/gtest/gtest.h>
#include <library/cpp/testing/unittest/registar.h>
#include <library/cpp/testing/gmock_in_unittest/gmock.h>

namespace maps::mrc::eye::tests {


TEST(collistions_tests, find_collisions)
{
    // Detection 1 is shared by objects 0 and 1, detection 2 by objects 1
    // and 2, and detection 3 by all three objects.
    const std::vector<db::TIdSet> objectsDetectionIds{
        db::TIdSet{1, 3},
        db::TIdSet{1, 2, 3},
        db::TIdSet{2, 3}
    };

    const db::IdTo<std::vector<size_t>> expected{
        {1, {0, 1}},
        {2, {1, 2}},
        {3, {0, 1, 2}}
    };

    const db::IdTo<std::vector<size_t>> collisions = findCollisions(objectsDetectionIds);

    EXPECT_EQ(collisions.size(), expected.size());

    // Every reported detection must be expected; the object indices of each
    // entry are compared without regard to their order.
    for (const auto& [collidingId, objectIndices] : collisions) {
        const auto expectedIt = expected.find(collidingId);
        ASSERT_TRUE(expectedIt != expected.end());

        EXPECT_THAT(objectIndices, ::testing::UnorderedElementsAreArray(expectedIt->second));
    }
}

TEST(collistions_tests, apply_collision_solutions)
{
    const std::vector<db::TIdSet> objectsDetectionIds{
        db::TIdSet{1, 3},
        db::TIdSet{1, 2, 3},
        db::TIdSet{2, 3}
    };

    // Detection id -> indices of every object that contains it.
    const db::IdTo<std::vector<size_t>> collisions{
        {1, {0, 1}},
        {2, {1, 2}},
        {3, {0, 1, 2}}
    };

    // Detection id -> index of the object that keeps it.
    const db::IdTo<size_t> collisionSolutions{
        {1, 0},
        {2, 1},
        {3, 1}
    };

    // Object 0 keeps only detection 1, object 1 keeps detections 2 and 3,
    // and object 2 loses all of its detections and is not expected in the
    // output at all.
    const std::vector<db::TIdSet> expected{
        db::TIdSet{1},
        db::TIdSet{2, 3}
    };

    const std::vector<db::TIdSet> result = applyCollisionSolutions(
        objectsDetectionIds, collisions, collisionSolutions);

    EXPECT_THAT(result, ::testing::UnorderedElementsAreArray(expected));
}

struct ResolveCollisionsFixture: public BaseFixture {
    // Stores one device with seven frames and commits everything in a single
    // transaction. Every frame gets a location, one HouseNumber detection
    // group and one detected house number "12". Frames 0 and 1 are placed at
    // (0, 0) and (1, 0); frames 2..6 all share the point (0, 2). All frames
    // use the same 90-degree heading.
    ResolveCollisionsFixture() {
        auto txn = newTxn();

        devices = {
            {db::eye::MrcDeviceAttrs{"M1"}},
        };
        db::eye::DeviceGateway(*txn).insertx(devices);

        const auto deviceId = devices[0].id();

        // NOTE(review): frames 3..6 reuse the URL context of frame 2
        // (makeUrlContext(3, "3")); confirm this duplication is intentional.
        frames = {
            {deviceId, identical, makeUrlContext(1, "1"), {1200, 800}, time()},
            {deviceId, identical, makeUrlContext(2, "2"), {1200, 800}, time()},
            {deviceId, identical, makeUrlContext(3, "3"), {1200, 800}, time()},
            {deviceId, identical, makeUrlContext(3, "3"), {1200, 800}, time()},
            {deviceId, identical, makeUrlContext(3, "3"), {1200, 800}, time()},
            {deviceId, identical, makeUrlContext(3, "3"), {1200, 800}, time()},
            {deviceId, identical, makeUrlContext(3, "3"), {1200, 800}, time()},
        };
        db::eye::FrameGateway(*txn).insertx(frames);

        const auto rotation = toRotation(geolib3::Heading(90), identical);
        frameLocations = {
            {frames[0].id(), geolib3::Point2{0, 0}, rotation},
            {frames[1].id(), geolib3::Point2{1, 0}, rotation},
            {frames[2].id(), geolib3::Point2{0, 2}, rotation},
            {frames[3].id(), geolib3::Point2{0, 2}, rotation},
            {frames[4].id(), geolib3::Point2{0, 2}, rotation},
            {frames[5].id(), geolib3::Point2{0, 2}, rotation},
            {frames[6].id(), geolib3::Point2{0, 2}, rotation},
        };
        db::eye::FrameLocationGateway(*txn).insertx(frameLocations);

        const auto houseNumberType = db::eye::DetectionType::HouseNumber;
        groups = {
            {frames[0].id(), houseNumberType},
            {frames[1].id(), houseNumberType},
            {frames[2].id(), houseNumberType},
            {frames[3].id(), houseNumberType},
            {frames[4].id(), houseNumberType},
            {frames[5].id(), houseNumberType},
            {frames[6].id(), houseNumberType},
        };
        db::eye::DetectionGroupGateway(*txn).insertx(groups);

        // Boxes: {0, 0, 10, 10} on frames 0 and 2, {0, 0, 20, 20} on frame 1,
        // {0, 0, 70, 70} on frames 3..6.
        detections = {
            {groups[0].id(), db::eye::DetectedHouseNumber{{0, 0, 10, 10}, 1.0, "12"}},
            {groups[1].id(), db::eye::DetectedHouseNumber{{0, 0, 20, 20}, 1.0, "12"}},
            {groups[2].id(), db::eye::DetectedHouseNumber{{0, 0, 10, 10}, 1.0, "12"}},
            {groups[3].id(), db::eye::DetectedHouseNumber{{0, 0, 70, 70}, 1.0, "12"}},
            {groups[4].id(), db::eye::DetectedHouseNumber{{0, 0, 70, 70}, 1.0, "12"}},
            {groups[5].id(), db::eye::DetectedHouseNumber{{0, 0, 70, 70}, 1.0, "12"}},
            {groups[6].id(), db::eye::DetectedHouseNumber{{0, 0, 70, 70}, 1.0, "12"}},
        };
        db::eye::DetectionGateway(*txn).insertx(detections);

        txn->commit();
    }
};

class PredefinedDetectionMatcher : public DetectionMatcher {
public:
    PredefinedDetectionMatcher(const MatchedFrameDetections& matches)
    {
        fillMatches(matches);
    }

    MatchedFrameDetections makeMatches(
        const DetectionStore& /*store*/,
        const DetectionIdPairSet& detectionPairs,
        const FrameMatcher* = nullptr) const override
    {
        MatchedFrameDetections matches;
        for (const DetectionIdPair& detectionPair : detectionPairs) {
            const auto it = matches_.find(detectionPair);
            if (it != matches_.end()) {
                matches.push_back(it->second);
            } else {
                DetectionIdPair revDetectionPair;
                revDetectionPair.first = detectionPair.second;
                revDetectionPair.second = detectionPair.first;

                const auto revIt = matches_.find(revDetectionPair);
                if (revIt != matches_.end()) {
                    matches.push_back(reverseMatch(revIt->second));
                }
            }
        }

        return matches;
    }

private:
    std::map<std::pair<db::TId, db::TId>, MatchedFrameDetection> matches_;
    void fillMatches(const MatchedFrameDetections& matches) {
        for (size_t i = 0; i < matches.size(); i++) {
            const MatchedFrameDetection& match = matches[i];
            matches_.emplace(std::make_pair(match.id0().detectionId, match.id1().detectionId), match);
        }
    }
};

TEST_F(ResolveCollisionsFixture, find_collision_solutions)
{
    // Shorthands for the ids of fixture rows by index.
    const auto detectionId = [this](size_t i) { return detections[i].id(); };
    const auto frameId = [this](size_t i) { return frames[i].id(); };

    DetectionStore detectionStore;
    detectionStore.extendByDetections(*newTxn(), detections);

    // Detections 1 and 3 are claimed by both objects.
    const std::vector<db::TIdSet> objectsDetectionIds{
        db::TIdSet{detectionId(0), detectionId(1), detectionId(2), detectionId(3)},
        db::TIdSet{detectionId(1), detectionId(3), detectionId(4), detectionId(5), detectionId(6)},
    };
    const db::IdTo<std::vector<size_t>> collisions{
        {detectionId(1), {0, 1}},
        {detectionId(3), {0, 1}},
    };

    // Detection 1 is matched with detections 0 and 2 (value 1.) and with
    // detection 5 (value 0.9); detection 3 has no predefined matches.
    const MockFrameMatcher frameMatcher;
    const PredefinedDetectionMatcher detectionMatcher({
        {{frameId(1), detectionId(1)}, {frameId(0), detectionId(0)}, 1.},
        {{frameId(1), detectionId(1)}, {frameId(2), detectionId(2)}, 1.},
        {{frameId(1), detectionId(1)}, {frameId(5), detectionId(5)}, 0.9},
    });

    const db::IdTo<size_t> expected{
        {detectionId(1), 0},
        {detectionId(3), 1},
    };

    const db::IdTo<size_t> result = findCollisionSolutions(
        detectionStore, frameMatcher, detectionMatcher, objectsDetectionIds, collisions);

    EXPECT_EQ(result, expected);
}

TEST_F(ResolveCollisionsFixture, resolve_collisions)
{
    // Shorthands for the ids of fixture rows by index.
    const auto detectionId = [this](size_t i) { return detections[i].id(); };
    const auto frameId = [this](size_t i) { return frames[i].id(); };

    DetectionStore detectionStore;
    detectionStore.extendByDetections(*newTxn(), detections);

    // Detections 1 and 3 initially belong to both objects.
    const std::vector<db::TIdSet> objectsDetectionIds{
        db::TIdSet{detectionId(0), detectionId(1), detectionId(2), detectionId(3)},
        db::TIdSet{detectionId(1), detectionId(3), detectionId(4), detectionId(5), detectionId(6)},
    };

    // Detection 1 is matched with detections 0 and 2 (value 1.) and with
    // detection 5 (value 0.9); detection 3 has no predefined matches.
    const MockFrameMatcher frameMatcher;
    const PredefinedDetectionMatcher detectionMatcher({
        {{frameId(1), detectionId(1)}, {frameId(0), detectionId(0)}, 1.},
        {{frameId(1), detectionId(1)}, {frameId(2), detectionId(2)}, 1.},
        {{frameId(1), detectionId(1)}, {frameId(5), detectionId(5)}, 0.9},
    });

    const std::vector<db::TIdSet> result = resolveCollisions(
        detectionStore, frameMatcher, detectionMatcher, objectsDetectionIds);

    // Detection 1 is expected to stay with the first object and detection 3
    // with the second one.
    const std::vector<db::TIdSet> expected{
        db::TIdSet{detectionId(0), detectionId(1), detectionId(2)},
        db::TIdSet{detectionId(3), detectionId(4), detectionId(5), detectionId(6)},
    };

    // Compare ignoring object order, as apply_collision_solutions does, so the
    // test checks how detections are distributed between objects rather than
    // the order in which surviving objects are emitted.
    EXPECT_THAT(result, ::testing::UnorderedElementsAreArray(expected));
}

} // namespace maps::mrc::eye::tests
