#include "fixtures.h"
#include "mocks.h"

#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/include/match.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/include/visibility_predictor.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/object_manager/impl/location.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/object_manager/impl/batch.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/object_manager/impl/metadata.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/object_manager/include/object_manager.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/unit_test/include/frame.h>

#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/frame_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/recognition_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/object_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/verified_detection_missing_on_frame_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/verified_detection_pair_match_gateway.h>

#include <maps/libs/geolib/include/vector.h>
#include <maps/libs/sql_chemistry/include/system_information.h>
#include <maps/libs/introspection/include/stream_output.h>

#include <library/cpp/testing/gtest/gtest.h>

#include <chrono>

using namespace std::literals::chrono_literals;

namespace maps::mrc::db::eye {

// Make the introspection-based stream output operator visible inside this
// namespace, so that it can be found by argument-dependent lookup for
// db::eye types.
// NOTE(review): presumably this is what lets gtest print db::eye values in
// assertion failure messages -- confirm.
using maps::introspection::operator<<;

} //namespace maps::mrc::db::eye


namespace maps::mrc::eye::tests {

namespace {

/// Fixture whose worker config carries one verification rule that turns on
/// `verifyObjectDetections` for Sign objects bound to geo id 225.
struct VerifyObjectDetectionsFixture : public StraightForwardMatchFixture
{
    /// Takes the base fixture config and adds the verification rule together
    /// with a mocked geo id provider answering every point lookup with the
    /// geo id of that rule.
    ObjectManagerConfig makeWorkerConfig(db::eye::DetectionTypes detectionTypes) const override
    {
        const db::TId targetGeoId = 225;

        ObjectManagerConfig workerConfig
            = StraightForwardMatchFixture::makeWorkerConfig(detectionTypes);

        // Any point is reported as belonging to the geo id of the rule below.
        auto mockGeoIdProvider = makeMockGeoIdProvider();
        EXPECT_CALL(*mockGeoIdProvider, load(::testing::An<const geolib3::Point2&>()))
            .WillRepeatedly(::testing::Return(db::TIds{targetGeoId}));
        workerConfig.geoIdProvider = mockGeoIdProvider;

        VerificationRule detectionsRule{
            .geoId = targetGeoId,
            .objectType = db::eye::ObjectType::Sign,
            // The rule applies to every object of the type above.
            .objectPredicate = [](const auto& /* object */) { return true; },
            .verificationAction = VerificationAction{
                .verifyObjectDetections = true
            }
        };
        workerConfig.verificationRules.push_back(std::move(detectionsRule));

        return workerConfig;
    }
};

/// Fixture whose worker config carries one verification rule that turns on
/// `verifyObjectDuplication` for Sign objects bound to geo id 225 and that
/// uses `MockDetectionMatcherMatchAll` as the detection matcher.
struct VerifyObjectDuplicationFixture : public StraightForwardMatchFixture
{
    /// Takes the base fixture config and adds the duplication rule, a mocked
    /// geo id provider answering every point lookup with the geo id of that
    /// rule, and the match-all detection matcher.
    ObjectManagerConfig makeWorkerConfig(db::eye::DetectionTypes detectionTypes) const override
    {
        const db::TId targetGeoId = 225;

        ObjectManagerConfig workerConfig
            = StraightForwardMatchFixture::makeWorkerConfig(detectionTypes);

        // Any point is reported as belonging to the geo id of the rule below.
        auto mockGeoIdProvider = makeMockGeoIdProvider();
        EXPECT_CALL(*mockGeoIdProvider, load(::testing::An<const geolib3::Point2&>()))
            .WillRepeatedly(::testing::Return(db::TIds{targetGeoId}));
        workerConfig.geoIdProvider = mockGeoIdProvider;

        VerificationRule duplicationRule{
            .geoId = targetGeoId,
            .objectType = db::eye::ObjectType::Sign,
            // The rule applies to every object of the type above.
            .objectPredicate = [](const auto& /* object */) { return true; },
            .verificationAction = VerificationAction{
                .verifyObjectDuplication = true
            }
        };
        workerConfig.verificationRules.push_back(std::move(duplicationRule));

        workerConfig.detectionMatcher = std::make_shared<MockDetectionMatcherMatchAll>();

        return workerConfig;
    }
};


} // namespace


/// After a full processing run with the detections verification rule the
/// test expects 2 detection clusters and 4 stored pair-match requests.
TEST_F(VerifyObjectDetectionsFixture, should_create_verification_request)
{
    auto config = makeWorkerConfig({db::eye::DetectionType::Sign});
    ObjectManager manager(config);
    manager.processBatchInLoopMode(100);

    auto readTxn = newTxn();
    db::eye::VerifiedDetectionPairMatchGateway pairMatchGateway(*readTxn);
    auto pairMatchRequests = pairMatchGateway.load();

    const auto loadedClusters = loadDetectionClusters(*newTxn());

    // Two objects are expected, with 2 verification requests per object.
    EXPECT_EQ(loadedClusters.size(), 2u);
    EXPECT_EQ(pairMatchRequests.size(), 4u);
}

/// With some frames marked as non-public, the test expects the same number
/// of requests as in the public case, but split between two sources.
TEST_F(VerifyObjectDetectionsFixture, should_consider_frames_privacy)
{
    const auto primaryId = detectionsInPassages[0].detectionsInFrames[0][0].id();

    /// Pre-create an object for the chosen primary detection in order to
    /// narrow the possible variants of VerifiedDetectionPairMatches
    {
        auto setupTxn = newTxn();

        const auto isPrimary = [primaryId](const auto& detection) {
            return detection.id() == primaryId;
        };
        auto primaryIt = std::find_if(detections.begin(), detections.end(), isPrimary);
        ASSERT_TRUE(primaryIt != detections.end());
        makeSignObject(*setupTxn, *primaryIt, {10, 0});

        // Frames of the second passage become non-public
        framePrivacies[2].setType(db::FeaturePrivacy::Restricted);
        framePrivacies[3].setType(db::FeaturePrivacy::Secret);
        db::eye::FramePrivacyGateway(*setupTxn).updatex(framePrivacies);

        setupTxn->commit();
    }

    auto config = makeWorkerConfig({db::eye::DetectionType::Sign});
    ObjectManager manager(config);
    manager.processBatchInLoopMode(100);

    auto readTxn = newTxn();
    db::eye::VerifiedDetectionPairMatchGateway pairMatchGateway(*readTxn);
    auto pairMatchRequests = pairMatchGateway.load();

    const auto loadedClusters = loadDetectionClusters(*newTxn());

    // Two objects are expected, with 2 verification requests per object.
    EXPECT_EQ(loadedClusters.size(), 2u);
    EXPECT_EQ(pairMatchRequests.size(), 4u);

    // NOTE(review): presumably requests that touch non-public frames are
    // routed to Yang instead of Toloka -- confirm against the object manager.
    std::vector<db::eye::VerificationSource> requestSources;
    for (const auto& pairMatchRequest : pairMatchRequests) {
        requestSources.push_back(pairMatchRequest.source());
    }
    EXPECT_THAT(requestSources, ::testing::UnorderedElementsAre(
        db::eye::VerificationSource::Toloka,
        db::eye::VerificationSource::Toloka,
        db::eye::VerificationSource::Yang,
        db::eye::VerificationSource::Yang
    ));
}


/// Pair-match requests that are stored before processing are counted
/// together with the ones the object manager adds: 5 rows in total.
TEST_F(VerifyObjectDetectionsFixture, wont_create_verification_request_if_it_exists)
{
    const auto primaryId = detectionsInPassages[0].detectionsInFrames[0][0].id();
    // Detections of another passage that get a request stored up front
    const db::TIds alreadyRequestedIds{
        detectionsInPassages[1].detectionsInFrames[0][0].id(),
        detectionsInPassages[1].detectionsInFrames[1][0].id()
    };

    /// Pre-create an object for the chosen primary detection in order to
    /// narrow the possible variants of VerifiedDetectionPairMatches
    {
        auto setupTxn = newTxn();

        const auto isPrimary = [primaryId](const auto& detection) {
            return detection.id() == primaryId;
        };
        auto primaryIt = std::find_if(detections.begin(), detections.end(), isPrimary);
        ASSERT_TRUE(primaryIt != detections.end());
        makeSignObject(*setupTxn, *primaryIt, {10, 0});

        // Store one Toloka request per selected detection, paired with the primary one
        db::eye::VerifiedDetectionPairMatches existingRequests;
        for (auto otherId : alreadyRequestedIds) {
            existingRequests.push_back(
                db::eye::VerifiedDetectionPairMatch{
                    db::eye::VerificationSource::Toloka,
                    primaryId,
                    otherId});
        }
        db::eye::VerifiedDetectionPairMatchGateway{*setupTxn}.insertx(existingRequests);

        setupTxn->commit();
    }

    auto config = makeWorkerConfig({db::eye::DetectionType::Sign});
    ObjectManager manager(config);
    manager.processBatchInLoopMode(100);

    auto readTxn = newTxn();
    db::eye::VerifiedDetectionPairMatchGateway pairMatchGateway(*readTxn);
    auto pairMatchRequests = pairMatchGateway.load();

    const auto loadedClusters = loadDetectionClusters(*newTxn());

    // 2 requests were stored above, so 5 rows in total mean that the
    // object manager added 3 new ones across the two objects.
    EXPECT_EQ(loadedClusters.size(), 2u);
    EXPECT_EQ(pairMatchRequests.size(), 5u);
}


/// With the duplication rule and the match-all detection matcher the test
/// expects 2 detection clusters and exactly 1 stored pair-match request.
TEST_F(VerifyObjectDuplicationFixture, should_create_verification_request)
{
    const auto primaryId = *secondSignDetections.begin();

    /// Pre-create an object for the chosen primary detection in order to
    /// narrow the possible variants of VerifiedDetectionPairMatches
    {
        auto setupTxn = newTxn();

        const auto isPrimary = [primaryId](const auto& detection) {
            return detection.id() == primaryId;
        };
        auto primaryIt = std::find_if(detections.begin(), detections.end(), isPrimary);
        ASSERT_TRUE(primaryIt != detections.end());
        makeSignObject(*setupTxn, *primaryIt, {10, 0});

        setupTxn->commit();
    }

    auto config = makeWorkerConfig({db::eye::DetectionType::Sign});
    ObjectManager manager(config);
    manager.processBatchInLoopMode(100);

    auto readTxn = newTxn();
    db::eye::VerifiedDetectionPairMatchGateway pairMatchGateway(*readTxn);
    auto pairMatchRequests = pairMatchGateway.load();

    const auto loadedClusters = loadDetectionClusters(*newTxn());

    // A single verification request between the two clusters is expected.
    EXPECT_EQ(loadedClusters.size(), 2u);
    EXPECT_EQ(pairMatchRequests.size(), 1u);
}


/// Frame matcher stub: reports every requested pair of frames as matched.
class MatchAllFrameMatcher : public FrameMatcher {
public:
    MatchAllFrameMatcher() = default;

    /// Returns one MatchedFramesPair per entry of `frameIdPairs`, in the same
    /// order, with only the two frame ids filled in. The store is not used.
    MatchedFramesPairs makeMatches(
        const DetectionStore& /* store */,
        const std::vector<std::pair<db::TId, db::TId>>& frameIdPairs) const override
    {
        MatchedFramesPairs matchedPairs;
        for (const auto& frameIdPair : frameIdPairs) {
            MatchedFramesPair matched{
                .id0 = frameIdPair.first,
                .id1 = frameIdPair.second
            };
            matchedPairs.push_back(std::move(matched));
        }
        return matchedPairs;
    }
};

/// Detections from the first passage do not match to other detections
/// but expected to be visible on their frames
struct VerifyMissingDetectionsFixture : public ObjectClusterizationFixture
{
    std::vector<db::TIdSet> detectionClusters;

    VerifyMissingDetectionsFixture()
    {
        addPassageWithMissingDetection();
        evaluateClusters();
    }

    ObjectManagerConfig makeWorkerConfig(db::eye::DetectionTypes detectionTypes) const override
    {
        auto config = ObjectClusterizationFixture::makeWorkerConfig(detectionTypes);

        config.detectionMatcher = std::make_shared<MockDetectionMatcher>(
            detectionClusters
        );

        config.visibilityPredictor = std::make_shared<DummyVisibilityPredictor>(true);
        config.frameMatcher = std::make_shared<MatchAllFrameMatcher>();

        const db::TId geoId = 225;
        config.verificationRules.push_back(
            VerificationRule{
                .geoId = geoId,
                .objectType = db::eye::ObjectType::Sign,
                .objectPredicate = [](const auto& /* obj */) { return true; },
                .verificationAction = VerificationAction{
                    .verifyObjectMissingness = true
                }
            }
        );
        auto geoIdProvider = makeMockGeoIdProvider();
        config.geoIdProvider = geoIdProvider;
        EXPECT_CALL(*geoIdProvider, load(::testing::An<const geolib3::Point2&>()))
        .WillRepeatedly(::testing::Return(db::TIds{geoId}));
        return config;
    }

    void addPassageWithMissingDetection()
    {
        auto txn = newTxn();
        const auto& device = devices.at(0);
        PassageDetections passage;

        passage.detectionsInFrames.push_back(
            addFrameWithDetections(*txn, device.id(), 24h, {0, 0}, geolib3::Heading{90},
                {traffic_signs::TrafficSign::ProhibitoryMaxSpeed50})
        );

        detectionsInPassages.push_back(std::move(passage));

        txn->commit();
    }

    void evaluateClusters()
    {
        size_t passageIdx = 0;
        db::TIdSet firstSignDetections;
        db::TIdSet secondSignDetections;
        for (const auto& passageDetections : detectionsInPassages) {
            const bool isLastPassage = passageIdx == detectionsInPassages.size() - 1;
            for (const auto& detections : passageDetections.detectionsInFrames) {
                firstSignDetections.insert(detections.at(0).id());
                if (!isLastPassage) {
                    secondSignDetections.insert(detections.at(1).id());
                }
            }
            ++passageIdx;
        }
        detectionClusters.push_back(std::move(firstSignDetections));
        detectionClusters.push_back(std::move(secondSignDetections));
    }
};


/// Processes all detection groups but the last one first, then the last
/// group alone, and expects exactly one missing-on-frame verification
/// request: Toloka-sourced and pointing at the last frame of the fixture.
TEST_F(VerifyMissingDetectionsFixture, should_create_detection_missingness_verification_request)
{
    // `groups.back()` and `frames.back()` below are undefined on empty
    // containers, so fail with a clear message instead.
    ASSERT_GT(groups.size(), 0u);
    ASSERT_FALSE(frames.empty());

    db::TIds detectionGroupIdsBeforeMissing;

    // `i + 1 < size()` cannot wrap around, unlike `i < size() - 1`.
    for (size_t i = 0; i + 1 < groups.size(); ++i) {
        detectionGroupIdsBeforeMissing.push_back(groups.at(i).id());
    }

    auto workerConfig = makeWorkerConfig({db::eye::DetectionType::Sign});
    ObjectManager objectManager(workerConfig);
    /// First process detections where both objects are visible
    objectManager.processBatch(detectionGroupIdsBeforeMissing);
    /// Then process the remaining (last) detection group
    objectManager.processBatch({groups.back().id()});

    auto txn = newTxn();
    auto verificationRequests = db::eye::VerifiedDetectionMissingOnFrameGateway(*txn).load();

    // ASSERT (not EXPECT): the element access below requires one element.
    ASSERT_EQ(verificationRequests.size(), 1u);
    const auto& request = verificationRequests.at(0);
    EXPECT_EQ(request.source(), db::eye::VerificationSource::Toloka);
    EXPECT_EQ(request.frameId(), frames.back().id());
}


} // namespace maps::mrc::eye::tests
