#include <maps/wikimap/mapspro/services/mrc/eye/lib/detect_object/tests/common.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/detect_object/tests/fixture.h>

#include <maps/wikimap/mapspro/services/mrc/eye/lib/detect_object/include/detect_object.h>

#include <library/cpp/testing/unittest/registar.h>
#include <mapreduce/yt/tests/yt_unittest_lib/yt_unittest_lib.h>

namespace maps::mrc::eye::tests {

Y_UNIT_TEST_SUITE_F(simple_detect_object, Fixture)
{

// Runs the detector over frames 1, 2 and 4 with rework enabled and checks that
// each of them ends up with a committed (txnId set) DetectTrafficLight
// recognition from the Model source at the fixture's `version`, holding the
// expected number of detected traffic lights.
Y_UNIT_TEST(run)
{
    const auto loader = FrameLoader::fromConfig(config());
    const auto yt = NYT::NTesting::CreateTestClient();

    SimpleDetectObjectConfig workerConfig;
    workerConfig.yt.client = yt.Get();
    workerConfig.yt.frameLoader = &loader;
    workerConfig.yt.rootPath = "//tmp/detect_object/run";
    workerConfig.yt.partitionSize = 10;
    workerConfig.yt.concurrency = 1;
    workerConfig.yt.useGpu = false;
    workerConfig.mrc.pool = &pool();
    workerConfig.mrc.commit = true;
    workerConfig.rework = true;

    ExampleDetectObject detect(workerConfig);
    detect.processBatch(frameIdsAt({1, 2, 4}));

    // Expected detection count per processed frame. Keep this table in sync
    // with the frame indices passed to processBatch() above.
    struct Expectation {
        int frameIndex;
        size_t trafficLightCount;
    };
    const Expectation expectations[] = {
        {1, 1},
        {2, 0},
        {4, 2},
    };

    for (const Expectation& expected : expectations) {
        const auto& frame = frames.at(expected.frameIndex);
        const auto recognition = recognitionFor(frame, version);

        // Every message carries the frame index so a failure inside the loop
        // identifies which frame broke.
        UNIT_ASSERT_EQUAL_C(recognition.type(), db::eye::RecognitionType::DetectTrafficLight,
                            "frame " << expected.frameIndex);
        UNIT_ASSERT_EQUAL_C(recognition.source(), db::eye::RecognitionSource::Model,
                            "frame " << expected.frameIndex);
        UNIT_ASSERT_EQUAL_C(recognition.version(), version,
                            "frame " << expected.frameIndex);
        UNIT_ASSERT_C(recognition.txnId(), "frame " << expected.frameIndex);

        const auto trafficLights = recognition.value<db::eye::DetectedTrafficLights>();

        UNIT_ASSERT_EQUAL_C(trafficLights.size(), expected.trafficLightCount,
                            "frame " << expected.frameIndex);
    }
}

// Runs the detector twice in loop mode with an argument of 2 and checks that
// both calls return true.
// NOTE(review): the meaning of the return value is defined by
// processBatchInLoopMode(), which is not visible here. Presumably it means
// "a batch was processed"; confirm.
Y_UNIT_TEST(batch)
{
    const auto loader = FrameLoader::fromConfig(config());
    const auto yt = NYT::NTesting::CreateTestClient();

    // Value-initialized because this test does not set `rework`. `{}` ensures
    // it (and any other unset field) is not left indeterminate if the config
    // struct has no default member initializers.
    SimpleDetectObjectConfig workerConfig{};
    workerConfig.yt.client = yt.Get();
    workerConfig.yt.frameLoader = &loader;
    // Rooted under //tmp/detect_object like the rest of the suite, in a
    // directory of its own so tests do not share YT state.
    workerConfig.yt.rootPath = "//tmp/detect_object/batch";
    workerConfig.yt.partitionSize = 10;
    workerConfig.yt.concurrency = 1;
    workerConfig.yt.useGpu = false;
    workerConfig.mrc.pool = &pool();
    workerConfig.mrc.commit = true;

    ExampleDetectObject detect(workerConfig);
    UNIT_ASSERT(detect.processBatchInLoopMode(2));
    UNIT_ASSERT(detect.processBatchInLoopMode(2));
}

// Processes frame 0 with rework disabled and checks its DetectTrafficLight
// recognition (Model source, fixture `version`, committed) holds 2 traffic
// lights. rework_true expects 1 on the same frame. Presumably the fixture
// seeds frame 0 with an existing recognition that must be kept when rework
// is off; confirm against fixture.h.
Y_UNIT_TEST(rework_false)
{
    const auto loader = FrameLoader::fromConfig(config());
    const auto yt = NYT::NTesting::CreateTestClient();

    SimpleDetectObjectConfig workerConfig;
    workerConfig.yt.client = yt.Get();
    workerConfig.yt.frameLoader = &loader;
    // Own directory: this test used to share "//tmp/detect_object/run" with
    // the `run` and `rework_true` tests, coupling them through YT state.
    workerConfig.yt.rootPath = "//tmp/detect_object/rework_false";
    workerConfig.yt.partitionSize = 10;
    workerConfig.yt.concurrency = 1;
    workerConfig.yt.useGpu = false;
    workerConfig.mrc.pool = &pool();
    workerConfig.mrc.commit = true;
    workerConfig.rework = false;

    ExampleDetectObject detect(workerConfig);
    detect.processBatch(frameIdsAt({0}));

    const auto& frame = frames.at(0);
    const auto recognition = recognitionFor(frame, version);

    UNIT_ASSERT_EQUAL(recognition.type(), db::eye::RecognitionType::DetectTrafficLight);
    UNIT_ASSERT_EQUAL(recognition.source(), db::eye::RecognitionSource::Model);
    UNIT_ASSERT_EQUAL(recognition.version(), version);
    UNIT_ASSERT(recognition.txnId());

    const auto trafficLights = recognition.value<db::eye::DetectedTrafficLights>();

    UNIT_ASSERT_EQUAL(trafficLights.size(), 2u);
}

// Processes frame 0 with rework enabled and checks its DetectTrafficLight
// recognition (Model source, fixture `version`, committed) holds 1 traffic
// light. rework_false expects 2 on the same frame. Presumably rework forces
// the detector to recompute over a pre-existing recognition; confirm against
// fixture.h.
Y_UNIT_TEST(rework_true)
{
    const auto loader = FrameLoader::fromConfig(config());
    const auto yt = NYT::NTesting::CreateTestClient();

    SimpleDetectObjectConfig workerConfig;
    workerConfig.yt.client = yt.Get();
    workerConfig.yt.frameLoader = &loader;
    // Own directory: this test used to share "//tmp/detect_object/run" with
    // the `run` and `rework_false` tests, coupling them through YT state.
    workerConfig.yt.rootPath = "//tmp/detect_object/rework_true";
    workerConfig.yt.partitionSize = 10;
    workerConfig.yt.concurrency = 1;
    workerConfig.yt.useGpu = false;
    workerConfig.mrc.pool = &pool();
    workerConfig.mrc.commit = true;
    workerConfig.rework = true;

    ExampleDetectObject detect(workerConfig);
    detect.processBatch(frameIdsAt({0}));

    const auto& frame = frames.at(0);
    const auto recognition = recognitionFor(frame, version);

    UNIT_ASSERT_EQUAL(recognition.type(), db::eye::RecognitionType::DetectTrafficLight);
    UNIT_ASSERT_EQUAL(recognition.source(), db::eye::RecognitionSource::Model);
    UNIT_ASSERT_EQUAL(recognition.version(), version);
    UNIT_ASSERT(recognition.txnId());

    const auto trafficLights = recognition.value<db::eye::DetectedTrafficLights>();

    UNIT_ASSERT_EQUAL(trafficLights.size(), 1u);
}

} // Y_UNIT_TEST_SUITE

} // namespace maps::mrc::eye::tests
