#include "classifier.h"
#include "common.h"

#include <maps/wikimap/mapspro/services/mrc/libs/db/include/feature.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/serialization.h>
#include <maps/wikimap/mapspro/services/mrc/libs/sideview_classifier/include/sideview.h>
#include <maps/libs/log8/include/log8.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/exif.h>

#include <mapreduce/yt/interface/node.h>

#include <opencv2/opencv.hpp>
#include <optional>

namespace maps::mrc::sideview_classifier {

namespace {

/// Downloads an image over HTTP, decodes it and applies the given orientation.
///
/// The download is wrapped in `retry` with up to 5 tries, a 1 s initial
/// cooldown and a backoff factor of 2. A response whose status is not 200 is
/// turned into a `RuntimeError`, so it takes part in the retry loop like any
/// other failure. (Assumes `retry` retries on thrown exceptions; confirm
/// against `maps::common::retry`.)
///
/// @param httpClient  client used to issue the GET requests
/// @param url         location of the encoded image
/// @param orientation orientation passed to `common::transformByImageOrientation`
/// @return the decoded image after `transformByImageOrientation` is applied
/// @throws whatever `retry` rethrows once the attempts are exhausted. For a bad
///         status this is the `RuntimeError` below, which names the URL.
cv::Mat loadImage(
    http::Client& httpClient,
    const std::string& url,
    common::ImageOrientation orientation)
{
    static constexpr int HTTP_STATUS_OK = 200;

    auto bytes = retry(
        [&]() -> common::Bytes {
            http::Request request(httpClient, http::GET, http::URL(url));
            auto response = request.perform();

            const auto status = response.status();
            if (status != HTTP_STATUS_OK) {
                // Name the URL: when retries are exhausted this message is all
                // that identifies which input image broke the job.
                throw RuntimeError() << "Http status " << status << " for url " << url;
            }

            return response.readBodyToVector();
        },
        ::maps::common::RetryPolicy()
            .setTryNumber(5)
            .setInitialCooldown(std::chrono::seconds(1))
            .setCooldownBackoff(2)
    );

    auto mat = common::decodeImage(bytes);
    return common::transformByImageOrientation(mat, orientation);
}

} // namespace

/// Runs the side-view classifier over every input row.
///
/// Each row supplies two image URLs together with their EXIF orientation
/// codes. Both images are downloaded via `loadImage` and passed to
/// `sideview::SideViewClassifier::inference`. One output row is written per
/// input row. It carries the source/feature identifier columns plus the
/// predicted type (as int64) and its confidence.
///
/// A failure while reading a column, downloading or classifying is not caught
/// here and propagates out of `Do`.
void SideviewClassifier::Do(TReader* reader, TWriter* writer)
{
    // A single HTTP client and a single model instance serve all rows.
    http::Client client;
    client.setTimeout(std::chrono::seconds(10));

    sideview::SideViewClassifier model;

    while (reader->IsValid()) {
        const auto& input = reader->GetRow();

        // Read every input column before any network traffic starts.
        const TString firstUrl = input[COL_URL_1].AsString();
        const TString secondUrl = input[COL_URL_2].AsString();
        auto firstOrientation =
            common::ImageOrientation::fromExif(input[COL_ORIENTATION_1].AsInt64());
        auto secondOrientation =
            common::ImageOrientation::fromExif(input[COL_ORIENTATION_2].AsInt64());

        auto firstImage = loadImage(client, firstUrl, firstOrientation);
        auto secondImage = loadImage(client, secondUrl, secondOrientation);

        auto [type, confidence] = model.inference(firstImage, secondImage);

        NYT::TNode output = NYT::TNode()
            (COL_KEY_SOURCE_ID, input[COL_KEY_SOURCE_ID].AsString())
            (COL_KEY_MIN_FEATURE_ID, input[COL_KEY_MIN_FEATURE_ID].AsInt64())
            (COL_FEATURE_ID_1, input[COL_FEATURE_ID_1].AsInt64())
            (COL_FEATURE_ID_2, input[COL_FEATURE_ID_2].AsInt64())
            (COL_TYPE, static_cast<int64_t>(type))
            (COL_CONFIDENCE, confidence);
        writer->AddRow(output);

        reader->Next();
    }
}

// Register SideviewClassifier with the YT job framework so that operations
// built from this binary can refer to the mapper class.
REGISTER_MAPPER(SideviewClassifier);

} // namespace maps::mrc::sideview_classifier
