#include "quality_classifier.h"
#include "road_classifier.h"
#include "rotation_classifier.h"
#include "yavision_classifier.h"
#include "position_accuracy_classifier.h"

#include <library/cpp/testing/gtest/gtest.h>
#include <maps/wikimap/mapspro/services/mrc/libs/config/include/config.h>
#include <opencv2/imgcodecs/imgcodecs_c.h>
#include <maps/libs/common/include/file_utils.h>
#include <yandex/maps/mrc/unittest/local_server.h>
#include <yandex/maps/mrc/unittest/unittest_config.h>

namespace maps::mrc::classifiers::tests {

namespace {

// Paths of test resources, all built by appending a file name to GetWorkPath().
//
// NOTE(review): the three *_CLASSIFIER_PATH constants below are not referenced
// anywhere else in this file as shown — confirm whether they are still needed.
// They presumably point at serialized classifier models (.gdef); verify.
const std::string QUALITY_CLASSIFIER_PATH = GetWorkPath() + "/quality_classifier.gdef";
const std::string ROAD_CLASSIFIER_PATH = GetWorkPath() + "/road_classifier.gdef";
const std::string ROTATION_CLASSIFIER_PATH = GetWorkPath() + "/rotation_classifier.gdef";
// Sample photo that the tests below expect to be classified as
// "road" / "good_quality" / not forbidden.
const std::string GOOD_IMAGE_PATH = GetWorkPath() + "/good_photo.jpg";
// Sample photos for the rotation tests; per the comment in
// test_rotation_classifier the number in the file name is the clockwise rotation.
const std::string ROTATION_0_IMAGE_PATH = GetWorkPath() + "/0.jpg";
const std::string ROTATION_90_IMAGE_PATH = GetWorkPath() + "/90.jpg";
const std::string ROTATION_180_IMAGE_PATH = GetWorkPath() + "/180.jpg";
const std::string ROTATION_270_IMAGE_PATH = GetWorkPath() + "/270.jpg";

// Reads the whole file at `path` into memory and decodes it with
// common::decodeImage, returning the decoded matrix.
cv::Mat loadImage(const std::string& path)
{
    cv::Mat decoded = common::decodeImage(maps::common::readFileToVector(path));
    return decoded;
}

// Produces the synthetic "bad" input used by the tests: a square
// 224x224 image with three 8-bit channels, every pixel set to (0, 0, 0).
cv::Mat getBadPhoto()
{
    constexpr int SIDE_PX = 224;
    return cv::Mat(cv::Size(SIDE_PX, SIDE_PX), CV_8UC3, cv::Scalar(0, 0, 0));
}

// Encodes `img` via common::encodeImage and converts the encoded bytes
// into a common::Blob via common::toBlob.
common::Blob matToBlob(const cv::Mat& img)
{
    common::Blob encodedBlob = common::toBlob(common::encodeImage(img));
    return encodedBlob;
}

// Inverse of matToBlob for the purposes of these tests: decodes the
// encoded image stored in `blob` with common::decodeImage.
cv::Mat blobToMat(const common::Blob& blob)
{
    cv::Mat decoded = common::decodeImage(blob);
    return decoded;
}

// Asks `classifier` for the orientation of the encoded image in `blob`.
// If that orientation is already normal, `blob` is returned as is;
// otherwise the image is passed through common::transformByImageOrientation
// with the detected orientation and the output is repacked into a Blob.
common::Blob
RotationClassifierNormalizeOrientation(const RotationClassifier &classifier, const common::Blob& blob)
{
    auto detected = classifier.detectImageOrientation(blob);
    if (!detected.isNormal()) {
        auto transformed = common::transformByImageOrientation(blob, detected);
        return common::Blob{transformed.begin(), transformed.end()};
    }
    return blob;
}

// Fixture type instantiated by test_yavision_classifier; its config() is used
// there to obtain the Yavision URL passed to YavisionClassifier.
// NOTE(review): presumably YavisionStubFixture starts a local stub of the
// Yavision service — confirm against the unittest library.
using TestFixture = unittest::WithUnittestConfig<unittest::YavisionStubFixture>;


// Fixture that prepares three tracks for the position-accuracy tests:
//   badTrackPoints             - 5 points whose reported accuracy is between
//                                ~82 and ~116 meters;
//   goodTrackPoints            - 6 points, one per second, each with an
//                                accuracy of 3.9 meters;
//   missingAccuracyTrackPoints - the same 6 positions/timestamps as
//                                goodTrackPoints, but only the first 2 points
//                                carry an accuracy value (4 of 6 have none).
struct PositionClassifierFixture : testing::Test {
    PositionClassifierFixture() {
        // --- Track with poor reported accuracy -----------------------------
        const std::vector<float> noisyLon {39.6942974, 39.6942974, 39.6947809, 39.6948805, 39.6949278};
        const std::vector<float> noisyLat {54.59144, 54.59144, 54.590741, 54.5905765, 54.5904848};
        const std::vector<float> noisyAccuracy {116.1000, 111.3560, 82.8630, 88.9390, 99.1430};
        const std::vector<std::string> noisyTime {
            "2019-08-08 17:46:09",
            "2019-08-08 17:46:12",
            "2019-08-08 17:46:19",
            "2019-08-08 17:46:23",
            "2019-08-08 17:46:28"
        };

        for (size_t idx = 0; idx < noisyLon.size(); ++idx) {
            badTrackPoints.emplace_back()
                .setTimestamp(chrono::parseSqlDateTime(noisyTime[idx]))
                .setGeodeticPos(geolib3::Point2(noisyLon[idx], noisyLat[idx]))
                .setAccuracyMeters(noisyAccuracy[idx]);
        }

        // --- Track with uniformly good reported accuracy --------------------
        const std::vector<float> smoothLon {37.533023, 37.532898, 37.532766, 37.532609, 37.532454, 37.532304};
        const std::vector<float> smoothLat {56.566329, 56.566427, 56.566534, 56.566640, 56.566751, 56.566865};
        const std::vector<std::string> smoothTime {
            "2019-08-08 17:46:09",
            "2019-08-08 17:46:10",
            "2019-08-08 17:46:11",
            "2019-08-08 17:46:12",
            "2019-08-08 17:46:13",
            "2019-08-08 17:46:14"
        };
        constexpr double GOOD_ACCURACY_METERS = 3.9;

        for (size_t idx = 0; idx < smoothLon.size(); ++idx) {
            goodTrackPoints.emplace_back()
                .setTimestamp(chrono::parseSqlDateTime(smoothTime[idx]))
                .setGeodeticPos(geolib3::Point2(smoothLon[idx], smoothLat[idx]))
                .setAccuracyMeters(GOOD_ACCURACY_METERS);
        }

        // --- Same positions, but most accuracy values are absent -------------
        // Only the first size/2 - 1 points (2 of the 6) get an accuracy value,
        // so more than half of the full track lacks accuracy, while exactly
        // half of its first four points lacks it.
        const size_t withAccuracyCount = smoothLon.size() / 2 - 1;
        for (size_t idx = 0; idx < smoothLon.size(); ++idx) {
            missingAccuracyTrackPoints.emplace_back()
                .setTimestamp(chrono::parseSqlDateTime(smoothTime[idx]))
                .setGeodeticPos(geolib3::Point2(smoothLon[idx], smoothLat[idx]));
            if (idx < withAccuracyCount) {
                // The point just appended is the last element.
                missingAccuracyTrackPoints.back().setAccuracyMeters(GOOD_ACCURACY_METERS);
            }
        }
    }

    db::TrackPoints badTrackPoints;
    db::TrackPoints goodTrackPoints;
    db::TrackPoints missingAccuracyTrackPoints;
};

} // anonymous namespace

// Runs RoadClassifier on the sample photo and on the synthetic black image.
// For each input both the label/confidence pair from callClassifier() and
// the value of estimateRoadProbability() on the encoded image are checked.
TEST(test_classifiers, test_road_classifier)
{
    const double CONFIDENCE = 1.0;
    // The float literal is intentional: it keeps the tolerance bit-identical
    // to the value this test has always used.
    const double EPSILON = 0.1f;
    const double ROAD = 1.0;
    const double NOT_ROAD = 0.0;

    RoadClassifier roadClassifier;

    // Sample photo: expected label "road".
    {
        auto photo = loadImage(GOOD_IMAGE_PATH);
        auto encoded = matToBlob(photo);
        const auto verdict = roadClassifier.callClassifier(photo);
        const auto probability = roadClassifier.estimateRoadProbability(encoded);
        EXPECT_STREQ(verdict.first.c_str(), "road");
        EXPECT_NEAR(verdict.second, CONFIDENCE, EPSILON);
        EXPECT_NEAR(probability, ROAD, EPSILON);
    }

    // Black image: expected label "not_road".
    {
        auto photo = getBadPhoto();
        auto encoded = matToBlob(photo);
        const auto verdict = roadClassifier.callClassifier(photo);
        const auto probability = roadClassifier.estimateRoadProbability(encoded);
        EXPECT_STREQ(verdict.first.c_str(), "not_road");
        EXPECT_NEAR(verdict.second, CONFIDENCE, EPSILON);
        EXPECT_NEAR(probability, NOT_ROAD, EPSILON);
    }
}

// Verifies RotationClassifier on the four rotated sample images.
//
// The image files are named after their clockwise (CW) rotation, whereas the
// classifier labels are expressed counter-clockwise (CCW); hence the file
// "90.jpg" is expected to be labelled "270_degrees" and "270.jpg" to be
// labelled "90_degrees".
TEST(test_classifiers, test_rotation_classifier)
{
    RotationClassifier classifier;

    // 0 degrees
    {
        const common::ImageOrientation expected(false, common::Rotation::CCW_0);
        const auto image = loadImage(ROTATION_0_IMAGE_PATH);

        EXPECT_STREQ(classifier.callClassifier(image).first.c_str(), "0_degrees");
        EXPECT_EQ(classifier.detectImageOrientation(image), expected);
    }

    // 90 degrees CW
    {
        const common::ImageOrientation expected(false, common::Rotation::CCW_90);
        const auto image = loadImage(ROTATION_90_IMAGE_PATH);

        EXPECT_STREQ(classifier.callClassifier(image).first.c_str(), "270_degrees"); // CCW label
        EXPECT_EQ(classifier.detectImageOrientation(image), expected);
    }

    // 180 degrees CW
    {
        const common::ImageOrientation expected(false, common::Rotation::CCW_180);
        const auto image = loadImage(ROTATION_180_IMAGE_PATH);

        EXPECT_STREQ(classifier.callClassifier(image).first.c_str(), "180_degrees"); // CCW label
        EXPECT_EQ(classifier.detectImageOrientation(image), expected);
    }

    // 270 degrees CW
    {
        const common::ImageOrientation expected(false, common::Rotation::CCW_270);
        const auto image = loadImage(ROTATION_270_IMAGE_PATH);

        EXPECT_STREQ(classifier.callClassifier(image).first.c_str(), "90_degrees"); // CCW label
        EXPECT_EQ(classifier.detectImageOrientation(image), expected);
    }
}

// Checks that once an image has been normalized according to the classifier's
// own orientation verdict, the classifier labels the result "0_degrees".
// Exercised for all four rotated sample images.
TEST(test_classifiers, test_rotation_idempotency)
{
    const std::string ZERO_DEGREES = "0_degrees";

    RotationClassifier rotationClassifier;

    // Loads the image at `path`, re-encodes it, normalizes its orientation
    // with RotationClassifierNormalizeOrientation and returns the label the
    // classifier assigns to the normalized image.
    const auto labelAfterNormalization = [&rotationClassifier](const std::string& path) {
        const auto encoded = matToBlob(loadImage(path));
        const auto normalized = blobToMat(
            RotationClassifierNormalizeOrientation(rotationClassifier, encoded));
        return rotationClassifier.callClassifier(normalized).first;
    };

    EXPECT_STREQ(
        labelAfterNormalization(ROTATION_0_IMAGE_PATH).c_str(),
        ZERO_DEGREES.c_str());

    // File names are clockwise; 90 CW corresponds to 270 CCW.
    EXPECT_STREQ(
        labelAfterNormalization(ROTATION_90_IMAGE_PATH).c_str(),
        ZERO_DEGREES.c_str());

    EXPECT_STREQ(
        labelAfterNormalization(ROTATION_180_IMAGE_PATH).c_str(),
        ZERO_DEGREES.c_str());

    // File names are clockwise; 270 CW corresponds to 90 CCW.
    EXPECT_STREQ(
        labelAfterNormalization(ROTATION_270_IMAGE_PATH).c_str(),
        ZERO_DEGREES.c_str());
}

// Runs QualityClassifier on the sample photo and on the synthetic black image.
// For each input the test checks the label/confidence from callClassifier(),
// the value of estimateImageQuality() on the encoded image, and the result of
// isImageHomogeneous().
TEST(test_classifiers, test_quality_classifier)
{
    const double CONFIDENCE = 1.0;
    // The float literal is intentional: it keeps the tolerance bit-identical
    // to the value this test has always used.
    const double EPSILON = 0.1f;
    const double GOOD_QUALITY = 1.0;
    const double BAD_QUALITY = 0.0;

    QualityClassifier qualityClassifier;

    // Sample photo: expected "good_quality" and not homogeneous.
    {
        auto photo = loadImage(GOOD_IMAGE_PATH);
        auto encoded = matToBlob(photo);
        const auto verdict = qualityClassifier.callClassifier(photo);
        const auto quality = qualityClassifier.estimateImageQuality(encoded);
        const bool homogeneous = qualityClassifier.isImageHomogeneous(photo);
        EXPECT_STREQ(verdict.first.c_str(), "good_quality");
        EXPECT_NEAR(verdict.second, CONFIDENCE, EPSILON);
        EXPECT_NEAR(quality, GOOD_QUALITY, EPSILON);
        EXPECT_FALSE(homogeneous);
    }

    // Black image: expected "bad_quality" and homogeneous.
    {
        auto photo = getBadPhoto();
        auto encoded = matToBlob(photo);
        const auto verdict = qualityClassifier.callClassifier(photo);
        const auto quality = qualityClassifier.estimateImageQuality(encoded);
        const bool homogeneous = qualityClassifier.isImageHomogeneous(photo);
        EXPECT_STREQ(verdict.first.c_str(), "bad_quality");
        EXPECT_NEAR(verdict.second, CONFIDENCE, EPSILON);
        EXPECT_NEAR(quality, BAD_QUALITY, EPSILON);
        EXPECT_TRUE(homogeneous);
    }
}

// Exercises YavisionClassifier against the URL taken from the fixture's
// configuration. The sample photo is expected to be reported as not
// forbidden, the synthetic black image as forbidden.
TEST(test_classifiers, test_yavision_classifier)
{
    const double NOT_FORBIDDEN = 0.0;
    const double FORBIDDEN = 1.0;
    const double EPSILON = 0.1;

    TestFixture testFixture;

    auto serviceUrl = testFixture.config().externals().yavisionUrl();
    INFO() << serviceUrl;
    YavisionClassifier yavisionClassifier(serviceUrl);

    // Sample photo
    {
        const auto encoded = matToBlob(loadImage(GOOD_IMAGE_PATH));
        const auto probability = yavisionClassifier.estimateForbiddenProbability(encoded);
        const auto flagged = yavisionClassifier.hasForbiddenContent(encoded);
        EXPECT_FALSE(flagged);
        EXPECT_NEAR(probability, NOT_FORBIDDEN, EPSILON);
    }

    // Black image
    {
        const auto encoded = matToBlob(getBadPhoto());
        const auto probability = yavisionClassifier.estimateForbiddenProbability(encoded);
        const auto flagged = yavisionClassifier.hasForbiddenContent(encoded);
        EXPECT_TRUE(flagged);
        EXPECT_NEAR(probability, FORBIDDEN, EPSILON);
    }
}

// Checks the 8-element feature vector returned by calculatePositionFeatures()
// for several tracks taken from the fixture. In every scenario below element 0
// equals the number of track points and element 1 is checked as the accuracy
// median; -1 is the value expected where a feature is unavailable.
TEST_F(PositionClassifierFixture, test_position_classifier_features)
{
    constexpr double TOLERANCE = 0.0001;

    // Empty track: count is 0, everything else is -1.
    {
        const auto features = calculatePositionFeatures(db::TrackPoints());
        EXPECT_THAT(features, testing::ElementsAreArray({0.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}));
    }

    // Single point: features 1..3 equal its accuracy, the rest stay -1.
    {
        const db::TrackPoints singlePoint(badTrackPoints.begin(), badTrackPoints.begin() + 1);
        const auto features = calculatePositionFeatures(singlePoint);
        EXPECT_THAT(features, testing::ElementsAreArray({1.0f, 116.1f, 116.1f, 116.1f, -1.0f, -1.0f, -1.0f, -1.0f}));
    }

    // Whole bad track: every feature is populated.
    {
        const auto features = calculatePositionFeatures(badTrackPoints);
        EXPECT_EQ(features[0], badTrackPoints.size());
        EXPECT_NEAR(features[1], 99.14299, TOLERANCE);
        EXPECT_NEAR(features[2], 116.0999, TOLERANCE);
        EXPECT_NEAR(features[3], 82.863, TOLERANCE);
        EXPECT_NEAR(features[4], 14.9794, TOLERANCE);
        EXPECT_NEAR(features[5], 83.7581, TOLERANCE);
        EXPECT_NEAR(features[6], 3.47959, TOLERANCE);
        EXPECT_NEAR(features[7], 11.9654, TOLERANCE);
    }

    // Even number of points (first four): accuracy median.
    {
        const db::TrackPoints firstFour(badTrackPoints.begin(), badTrackPoints.begin() + 4);
        const auto features = calculatePositionFeatures(firstFour);
        EXPECT_NEAR(features[1], 100.1475, TOLERANCE);
    }

    // Full track where 4 of 6 points lack accuracy: median is reported as -1.
    {
        const auto features = calculatePositionFeatures(missingAccuracyTrackPoints);
        EXPECT_NEAR(features[1], -1.0, TOLERANCE);
    }

    // First four points, of which exactly half lack accuracy
    // (<= half of the track points without accuracy): median is still computed.
    {
        const db::TrackPoints firstFour(
            missingAccuracyTrackPoints.begin(), missingAccuracyTrackPoints.begin() + 4);
        const auto features = calculatePositionFeatures(firstFour);
        EXPECT_NEAR(features[1], 3.9, TOLERANCE);
    }
}

// End-to-end check of isPositionInaccurate(): the fixture's bad track must be
// flagged as inaccurate, the good track must not.
TEST_F(PositionClassifierFixture, test_position_classifier_calcer)
{
    const auto badTrackFlagged = isPositionInaccurate(badTrackPoints);
    EXPECT_TRUE(badTrackFlagged);

    const auto goodTrackFlagged = isPositionInaccurate(goodTrackPoints);
    EXPECT_FALSE(goodTrackFlagged);
}

} // namespace maps::mrc::classifiers::tests
