#include <maps/wikimap/mapspro/services/mrc/eye/lib/detect_road_marking/tests/fixture.h>

#include <maps/wikimap/mapspro/services/mrc/eye/lib/detect_road_marking/include/detect_road_marking.h>

#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/frame_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/recognition_gateway.h>

#include <maps/wikimap/mapspro/services/mrc/eye/lib/location/include/rotation.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/unit_test/include/frame.h>

#include <library/cpp/testing/unittest/registar.h>
#include <mapreduce/yt/tests/yt_unittest_lib/yt_unittest_lib.h>

namespace maps::mrc::eye::tests {

// Fixture for the detect_road_marking suite.
// On construction it fills `devices` with a single MRC device ("M1"), inserts
// it through DeviceGateway and commits, so every test can attach frames to
// devices[0].
struct DetectRoadMarkingFixture : public Fixture {
    DetectRoadMarkingFixture()
    {
        auto transaction = newTxn();

        devices = {{db::eye::MrcDeviceAttrs{"M1"}}};

        db::eye::DeviceGateway(*transaction).insertx(devices);
        transaction->commit();
    }
};

Y_UNIT_TEST_SUITE_F(detect_road_marking_tests, DetectRoadMarkingFixture)
{

Y_UNIT_TEST(process_batch)
{
    const auto loader = FrameLoader::fromConfig(config());
    const auto client = NYT::NTesting::CreateTestClient();

    DetectRoadMarkingConfig workerConfig;
    workerConfig.yt.client = client.Get();
    workerConfig.yt.concurrency = 1;
    workerConfig.yt.frameLoader = &loader;
    workerConfig.yt.partitionSize = 10;
    workerConfig.yt.rootPath = "//tmp/detect_road_marking/run";
    workerConfig.yt.commit = false;
    workerConfig.mrc.pool = &pool();
    workerConfig.mrc.commit = true;

    DetectRoadMarking worker(workerConfig);

    // Asserts that `markings` contains an item of `type` and that its box equals `box`.
    // UNIT_ASSERT aborts the test on failure, so `it` is never dereferenced at end().
    const auto expectRoadMarking = [](
        const auto& markings, auto type, const common::ImageBox& box)
    {
        const auto it = std::find_if(markings.begin(), markings.end(),
            [type](const db::eye::DetectedRoadMarking& item) {
                return item.type == type;
            }
        );
        UNIT_ASSERT(it != markings.end());
        UNIT_ASSERT_EQUAL(it->box, box);
    };

    frames.push_back(
        {devices[0].id(), identical, makeUrlContext(1, "1s"), {1920, 1080}, time()}
    );
    {
        auto txn = newTxn();
        db::eye::FrameGateway(*txn).insertx(frames.back());
        txn->commit();
    }

    recognitions.push_back(
        {
            frames.back().id(),
            frames.back().orientation(),
            db::eye::RecognitionType::DetectPanel,
            db::eye::RecognitionSource::Model,
            0, // version
            db::eye::DetectedPanel{
                common::ImageBox(0, 0, 1920, 950)
            }
        }
    );
    {
        auto txn = newTxn();
        db::eye::RecognitionGateway(*txn).insertx(recognitions.back());
        txn->commit();
    }

    db::TIds frameIds{frames[0].id()};

    worker.processBatch(frameIds);

    // Named dbRecognitions so it does not shadow the fixture member `recognitions`.
    db::eye::Recognitions dbRecognitions;
    {
        auto txn = newTxn();
        dbRecognitions = db::eye::RecognitionGateway(*txn).load(
            db::eye::table::Recognition::frameId == frames[0].id() &&
            db::eye::table::Recognition::type == db::eye::RecognitionType::DetectRoadMarking
        );
    }

    // Check the size before indexing: an empty result must fail the test,
    // not read out of bounds.
    UNIT_ASSERT_EQUAL(dbRecognitions.size(), 1u);

    const db::eye::Recognition& recognition = dbRecognitions[0];
    UNIT_ASSERT_EQUAL(recognition.frameId(), frames[0].id());
    UNIT_ASSERT_EQUAL(recognition.orientation(), frames[0].orientation());
    UNIT_ASSERT_EQUAL(recognition.type(), db::eye::RecognitionType::DetectRoadMarking);
    UNIT_ASSERT_EQUAL(recognition.version(), 0);

    const auto roadMarkings = recognition.value<db::eye::DetectedRoadMarkings>();
    UNIT_ASSERT_EQUAL(roadMarkings.size(), 2u);

    expectRoadMarking(roadMarkings,
        traffic_signs::TrafficSign::RoadMarkingLaneDirectionL,
        common::ImageBox(564, 780, 919, 895));
    expectRoadMarking(roadMarkings,
        traffic_signs::TrafficSign::RoadMarkingLaneDirectionF,
        common::ImageBox(1197, 777, 1747, 901));
}

Y_UNIT_TEST(process_batch_in_loop_mode)
{
    const auto loader = FrameLoader::fromConfig(config());
    const auto client = NYT::NTesting::CreateTestClient();

    DetectRoadMarkingConfig workerConfig;
    workerConfig.yt.client = client.Get();
    workerConfig.yt.concurrency = 1;
    workerConfig.yt.frameLoader = &loader;
    workerConfig.yt.partitionSize = 10;
    workerConfig.yt.rootPath = "//tmp/detect_road_marking/run";
    workerConfig.yt.commit = false;
    workerConfig.mrc.pool = &pool();
    workerConfig.mrc.commit = true;

    const size_t batchSize = 2u;
    DetectRoadMarking worker(workerConfig);

    // Loads every DetectRoadMarking recognition stored for the frame with the given id.
    // Read-only, so the transaction is not committed.
    const auto loadRoadMarkingRecognitions = [this](const auto id) {
        auto txn = newTxn();
        return db::eye::RecognitionGateway(*txn).load(
            db::eye::table::Recognition::frameId == id &&
            db::eye::table::Recognition::type == db::eye::RecognitionType::DetectRoadMarking
        );
    };

    // Asserts that `recognition` is a version-0 DetectRoadMarking result for `frame`.
    const auto expectRoadMarkingRecognitionOf = [](
        const db::eye::Recognition& recognition, const auto& frame)
    {
        UNIT_ASSERT_EQUAL(recognition.frameId(), frame.id());
        UNIT_ASSERT_EQUAL(recognition.orientation(), frame.orientation());
        UNIT_ASSERT_EQUAL(recognition.type(), db::eye::RecognitionType::DetectRoadMarking);
        UNIT_ASSERT_EQUAL(recognition.version(), 0);
    };

    // Asserts that `markings` contains an item of `type` and that its box equals `box`.
    // UNIT_ASSERT aborts the test on failure, so `it` is never dereferenced at end().
    const auto expectRoadMarking = [](
        const auto& markings, auto type, const common::ImageBox& box)
    {
        const auto it = std::find_if(markings.begin(), markings.end(),
            [type](const db::eye::DetectedRoadMarking& item) {
                return item.type == type;
            }
        );
        UNIT_ASSERT(it != markings.end());
        UNIT_ASSERT_EQUAL(it->box, box);
    };

    // This test has not inserted any frame yet, so there is nothing to process.
    UNIT_ASSERT_EQUAL(worker.processBatchInLoopMode(batchSize), false);

    frames.push_back(
        {devices[0].id(), identical, makeUrlContext(1, "1s"), {1920, 1080}, time()}
    );
    {
        auto txn = newTxn();
        db::eye::FrameGateway(*txn).insertx(frames.back());
        txn->commit();
    }

    // The new frame is picked up, but no road-marking recognition is stored for it yet.
    // Its DetectPanel recognition is only inserted below.
    UNIT_ASSERT_EQUAL(worker.processBatchInLoopMode(batchSize), true);
    UNIT_ASSERT_EQUAL(loadRoadMarkingRecognitions(frames[0].id()).size(), 0u);

    recognitions.push_back(
        {
            frames.back().id(),
            frames.back().orientation(),
            db::eye::RecognitionType::DetectPanel,
            db::eye::RecognitionSource::Model,
            0, // version
            db::eye::DetectedPanel{
                common::ImageBox(0, 0, 1920, 950)
            }
        }
    );
    {
        auto txn = newTxn();
        db::eye::RecognitionGateway(*txn).insertx(recognitions.back());
        txn->commit();
    }

    UNIT_ASSERT_EQUAL(worker.processBatchInLoopMode(batchSize), true);

    {
        const auto dbRecognitions = loadRoadMarkingRecognitions(frames[0].id());
        UNIT_ASSERT_EQUAL(dbRecognitions.size(), 1u);

        const db::eye::Recognition& recognition = dbRecognitions[0];
        expectRoadMarkingRecognitionOf(recognition, frames[0]);

        const auto roadMarkings = recognition.value<db::eye::DetectedRoadMarkings>();
        UNIT_ASSERT_EQUAL(roadMarkings.size(), 2u);

        expectRoadMarking(roadMarkings,
            traffic_signs::TrafficSign::RoadMarkingLaneDirectionL,
            common::ImageBox(564, 780, 919, 895));
        expectRoadMarking(roadMarkings,
            traffic_signs::TrafficSign::RoadMarkingLaneDirectionF,
            common::ImageBox(1197, 777, 1747, 901));
    }

    // The road-marking recognition written on the previous step still has to be
    // processed. After that iteration the worker has nothing left to do.
    UNIT_ASSERT_EQUAL(worker.processBatchInLoopMode(batchSize), true);
    UNIT_ASSERT_EQUAL(worker.processBatchInLoopMode(batchSize), false);

    frames.push_back(
        {
            devices[0].id(),
            common::ImageOrientation(common::Rotation::CW_180),
            makeUrlContext(1, "2s"),
            {1920, 1080}, time()
        }
    );
    {
        auto txn = newTxn();
        db::eye::FrameGateway(*txn).insertx(frames.back());
        txn->commit();
    }

    recognitions.push_back(
        {
            frames.back().id(),
            frames.back().orientation(),
            db::eye::RecognitionType::DetectPanel,
            db::eye::RecognitionSource::Model,
            0, // version
            db::eye::DetectedPanel{
                common::ImageBox(0, 85, 1920, 1080)
            }
        }
    );
    {
        auto txn = newTxn();
        db::eye::RecognitionGateway(*txn).insertx(recognitions.back());
        txn->commit();
    }

    UNIT_ASSERT_EQUAL(worker.processBatchInLoopMode(batchSize), true);

    {
        const auto dbRecognitions = loadRoadMarkingRecognitions(frames[1].id());
        UNIT_ASSERT_EQUAL(dbRecognitions.size(), 1u);

        const db::eye::Recognition& recognition = dbRecognitions[0];
        expectRoadMarkingRecognitionOf(recognition, frames[1]);

        const auto roadMarkings = recognition.value<db::eye::DetectedRoadMarkings>();
        UNIT_ASSERT_EQUAL(roadMarkings.size(), 2u);

        expectRoadMarking(roadMarkings,
            traffic_signs::TrafficSign::RoadMarkingLaneDirectionFR,
            common::ImageBox(844, 188, 1111, 268));
        expectRoadMarking(roadMarkings,
            traffic_signs::TrafficSign::RoadMarkingLaneDirectionF,
            common::ImageBox(1364, 176, 1746, 260));
    }

    // The road-marking recognition written on the previous step still has to be
    // processed. After that iteration the worker has nothing left to do.
    UNIT_ASSERT_EQUAL(worker.processBatchInLoopMode(batchSize), true);
    UNIT_ASSERT_EQUAL(worker.processBatchInLoopMode(batchSize), false);
}

} // Y_UNIT_TEST_SUITE

} // namespace maps::mrc::eye::tests
