#include "mapper.h"

#include <maps/wikimap/mapspro/services/mrc/libs/classifiers/include/road_classifier.h>
#include <maps/wikimap/mapspro/services/mrc/libs/classifiers/include/rotation_classifier.h>
#include <maps/wikimap/mapspro/services/mrc/libs/classifiers/include/quality_classifier.h>
#include <maps/wikimap/mapspro/services/mrc/libs/classifiers/include/yavision_classifier.h>
#include <maps/wikimap/mapspro/services/mrc/libs/classifiers/include/forbidden_classifier.h>
#include <maps/wikimap/mapspro/services/mrc/libs/config/include/config.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/feature_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/object_in_photo_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/privacy_detector/include/privacy_detector_faster_rcnn.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/serialization.h>
#include <maps/libs/log8/include/log8.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/exif.h>

#include <opencv2/opencv.hpp>

#include <memory>

namespace maps::mrc::image_analyzer {

namespace {

class ImageClassifier {
public:
    ImageClassifier() = default;

    ImageClassifier(bool enableForbiddenClassifier)
        : forbiddenClassifier_(
            enableForbiddenClassifier
            ? new classifiers::ForbiddenClassifier()
            : nullptr)
    {
    }
    double estimateRoadProbability(const cv::Mat& image) const
    {
        return roadClassifier_.estimateRoadProbability(image);
    }

    double estimateImageQuality(const cv::Mat& image) const
    {
        return qualityClassifier_.estimateImageQuality(image);
    }

    double estimateForbiddenProbability(const cv::Mat& image) const
    {
        REQUIRE(forbiddenClassifier_, "Forbidden classifier was not enabled");
        return forbiddenClassifier_->estimateForbiddenProbability(image);
    }

    common::ImageOrientation detectImageOrientation(const cv::Mat& image) const
    {
        return rotationClassifier_.detectImageOrientation(image);
    }

private:
    const classifiers::RoadClassifier roadClassifier_;
    const classifiers::RotationClassifier rotationClassifier_;
    const classifiers::QualityClassifier qualityClassifier_;
    std::unique_ptr<classifiers::ForbiddenClassifier> forbiddenClassifier_;
};

// Output table indices, passed as the second argument of writer->AddRow().
// Index of the output table that receives serialized db::Feature rows.
constexpr size_t FEATURES_TABLE_INDEX{0};
// Index of the output table that receives serialized db::ObjectInPhoto rows.
constexpr size_t OBJECTS_TABLE_INDEX{1};

} // namespace

/// Stores the serialized configuration string and the forbidden-classifier
/// switch. Nothing is parsed or constructed here; both values are consumed
/// later in Do().
FeatureImageMapper::FeatureImageMapper(
    const TString& config,
    bool enableForbiddenClassifier)
    : config_(config)
    , enableForbiddenClassifier_(enableForbiddenClassifier)
{
}

/// Processes every input row as a db::Feature.
///
/// For each feature the image is loaded from MDS by the feature's key and
/// decoded. Then:
///  - if the decoded image is empty, quality / road / forbidden values are
///    zeroed, the feature is written to the features table and the row is
///    skipped;
///  - otherwise the orientation is determined, the image is transformed
///    according to it, quality and road probability (and, when enabled,
///    forbidden probability) are estimated, and the feature is written to
///    the features table;
///  - finally the privacy detector runs on the transformed image and every
///    detected box that intersects the image is written to the objects table.
///
/// @param reader source of serialized feature rows
/// @param writer sink for the features table and the objects table
void FeatureImageMapper::Do(TReader* reader, TWriter* writer)
{
    auto cfg = common::Config::fromString(config_.c_str());
    auto mds = cfg.makeMdsClient();
    const ImageClassifier classifier(enableForbiddenClassifier_);
    privacy_detector::FasterRCNNDetector privacyDetector;

    for (; reader->IsValid(); reader->Next()) {
        const auto& row = reader->GetRow();

        auto feature = yt::deserialize<db::Feature>(row);
        auto imageBytes = mds.get(feature.mdsKey());
        auto image = common::decodeImage(imageBytes);
        if (image.empty()) {
            // Nothing to classify: emit the feature with zeroed estimates.
            WARN() << "feature " << feature.id() << ": empty image";
            feature.setQuality(0.).setRoadProbability(0.).setForbiddenProbability(0.);
            writer->AddRow(yt::serialize(feature), FEATURES_TABLE_INDEX);
            continue;
        }

        std::optional<common::ImageOrientation> orientation;
        if (db::isStandalonePhotosDataset(feature.dataset())) {
            // Prefer EXIF orientation for walks photos because
            // it showed better quality
            orientation = common::parseImageOrientationFromExif(imageBytes);
            if (!orientation.has_value()) {
                // No usable EXIF orientation: fall back to the classifier.
                orientation = classifier.detectImageOrientation(image);
            }
            // Size is recorded only for standalone-photo datasets.
            feature.setSize(image.cols, image.rows);
        } else {
            orientation = classifier.detectImageOrientation(image);
        }

        if (orientation) {
            feature.setOrientation(*orientation);
        }

        const auto normalizedImage
            = transformByImageOrientation(image, feature.orientation());

        feature.setQuality(classifier.estimateImageQuality(normalizedImage))
            .setRoadProbability(classifier.estimateRoadProbability(normalizedImage));
        if (enableForbiddenClassifier_) {
            feature.setForbiddenProbability(
                classifier.estimateForbiddenProbability(normalizedImage));
        }
        writer->AddRow(yt::serialize(feature), FEATURES_TABLE_INDEX);

        const privacy_detector::PrivacyImageBoxes privacyObjects
            = privacyDetector.detect(normalizedImage);
        const cv::Rect normalizedImageRect(
            0, 0, normalizedImage.cols, normalizedImage.rows);
        for (const privacy_detector::PrivacyImageBox& privacyObject : privacyObjects) {
            // Clip the detection to the image bounds; drop boxes that end up
            // with no area inside the image.
            const cv::Rect boxInsideImage
                = static_cast<cv::Rect>(privacyObject.box) & normalizedImageRect;
            if (boxInsideImage.width <= 0 || boxInsideImage.height <= 0) {
                continue;
            }
            // The box was found on the transformed image; revert it using the
            // original image size and the feature orientation
            // (NOTE(review): presumably yields coordinates in the original,
            // untransformed image frame — confirm against the helper).
            const common::ImageBox box = common::revertByImageOrientation(
                common::ImageBox(boxInsideImage),
                {static_cast<size_t>(image.cols), static_cast<size_t>(image.rows)},
                feature.orientation());
            db::ObjectInPhoto dbObject(
                feature.id(), privacyObject.type, box, privacyObject.confidence);
            writer->AddRow(yt::serialize(dbObject), OBJECTS_TABLE_INDEX);
        }
    }
}

// Registers FeatureImageMapper through the REGISTER_MAPPER macro
// (NOTE(review): presumably required so the job runtime can instantiate the
// mapper by type — confirm against the macro's definition).
REGISTER_MAPPER(FeatureImageMapper);


/// Keeps the serialized configuration string; it is parsed later in Do().
YavisionMapper::YavisionMapper(const TString& config)
    : config_(config)
{
}

/// Processes every input row as a db::Feature and fills in its forbidden
/// probability using the Yavision classifier configured from
/// cfg.externals().yavisionUrl().
///
/// For each feature the image is loaded from MDS by the feature's key and
/// decoded. An empty decoded image results in the feature being written
/// unchanged. Otherwise the image is transformed according to the feature's
/// stored orientation, the forbidden probability is estimated and the
/// updated feature is written to the features table.
///
/// If a maps::Exception is raised while estimating or writing, the error is
/// logged together with its message and no row is emitted for that feature.
///
/// @param reader source of serialized feature rows
/// @param writer sink for the features table
void YavisionMapper::Do(TReader* reader, TWriter* writer)
{
    auto cfg = common::Config::fromString(config_.c_str());
    auto mds = cfg.makeMdsClient();
    const classifiers::YavisionClassifier classifier(
        cfg.externals().yavisionUrl());

    for (; reader->IsValid(); reader->Next()) {
        const auto& row = reader->GetRow();

        auto feature = yt::deserialize<db::Feature>(row);
        auto imageBytes = mds.get(feature.mdsKey());
        auto image = common::decodeImage(imageBytes);
        if (image.empty()) {
            // Nothing to classify: pass the feature through untouched.
            writer->AddRow(yt::serialize(feature), FEATURES_TABLE_INDEX);
            continue;
        }

        const auto normalizedImage
            = transformByImageOrientation(image, feature.orientation());

        try {
            feature.setForbiddenProbability(
                classifier.estimateForbiddenProbability(normalizedImage));
            writer->AddRow(yt::serialize(feature), FEATURES_TABLE_INDEX);
        } catch (const maps::Exception& ex) {
            // The feature is skipped (no output row) on failure; include the
            // exception text so the cause is visible in the job log.
            ERROR() << "Failed to estimate forbidden probability for feature "
                << feature.id() << ": " << ex.what();
        }
    }
}

// Registers YavisionMapper through the REGISTER_MAPPER macro
// (NOTE(review): presumably required so the job runtime can instantiate the
// mapper by type — confirm against the macro's definition).
REGISTER_MAPPER(YavisionMapper);

} // namespace maps::mrc::image_analyzer
