#include <maps/wikimap/mapspro/services/mrc/libs/carsegm/include/carsegm.h>
#include <maps/wikimap/mapspro/services/mrc/libs/config/include/config.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/feature_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/hypothesis_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/sign_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/common.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/io.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/operation.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/serialization.h>

#include <maps/libs/chrono/include/time_point.h>
#include <maps/libs/cmdline/include/cmdline.h>

#include <yandex/maps/wiki/common/extended_xml_doc.h>
#include <maps/libs/common/include/exception.h>
#include <maps/libs/common/include/retry.h>
#include <maps/libs/log8/include/log8.h>

#include <mapreduce/yt/interface/client.h>
#include <util/generic/size_literals.h>

#include <optional>
#include <sstream>
#include <string>
#include <utility>

namespace maps::mrc::hide {

namespace {

// Short local names for the configuration and DB pool types used in this file.
using MrcConfig = maps::mrc::common::Config;
using PoolHolder = wiki::common::PoolHolder;
using WikiConfig = wiki::common::ExtendedXmlDoc;

// Makes duration literals such as `10s` available.
using namespace std::chrono_literals;

// Loads every sign feature (sign <-> feature link) stored in the MRC database,
// reading through a slave (read-only) transaction.
db::SignFeatures loadSignFeatures(pgpool3::Pool& mrcPool)
{
    auto slaveTxn = mrcPool.slaveTransaction();
    db::SignFeatureGateway gateway(*slaveTxn);
    return gateway.load();
}

// Returns the distinct feature ids referenced by `signFeatures`.
// Duplicates are dropped through a db::TIdSet; the result follows that set's
// iteration order.
db::TIds collectFeatureIds(const db::SignFeatures& signFeatures)
{
    db::TIdSet uniqueIds;
    for (const auto& link: signFeatures) {
        const auto id = link.featureId();
        uniqueIds.insert(id);
    }

    return db::TIds(uniqueIds.begin(), uniqueIds.end());
}

// Loads the features whose ids are listed in `featureIds` through a slave
// (read-only) transaction. The ids are consumed by the filter.
db::Features loadFeatures(pgpool3::Pool& mrcPool, db::TIds featureIds)
{
    auto slaveTxn = mrcPool.slaveTransaction();

    const auto byId = db::table::Feature::id.in(std::move(featureIds));
    return db::FeatureGateway(*slaveTxn).load(byId);
}

using Reader = NYT::TTableReader<NYT::TNode>;
using Writer = NYT::TTableWriter<NYT::TNode>;

// YT directory that holds every table written by this tool.
const TString PATH = "//home/maps/core/mrc/hide";
// Sign features, sorted by feature_id before the join.
const TString SIGN_FEATURE_TABLE = PATH + "/sign_feature";
// Features with uploaded images, sorted by feature_id before the join.
const TString FEATURE_TABLE = PATH + "/feature";
// Join output: one row per (feature, sign feature) pair.
const TString FEATURE_WITH_BOX_TABLE = PATH + "/feature_with_box";
// Sign features whose box was found on a car.
const TString SIGN_ON_CAR_TABLE = PATH + "/sign_on_car";
// Copy of the signs that are removed from the database.
const TString DROPPED_SIGN_TABLE = PATH + "/dropped_sign";

// Joins the feature and sign_feature tables by feature_id. Every output row is
// a map with two fields: "feature" (the feature row) and "signFeature" (one
// sign_feature row that references that feature).
//
// Input order is fixed by join(): table 0 is the foreign feature table, table 1
// is the primary sign_feature table. The reducer relies on the feature row of a
// key being delivered before the sign_feature rows of the same key.
class Join: public NYT::IReducer<Reader, Writer>  {
public:
    void Do(Reader* reader, Writer* writer) override;
};

void Join::Do(Reader* reader, Writer* writer)
{
    // Last feature row seen; undefined until the first one arrives.
    NYT::TNode feature;
    size_t skippedCount = 0;

    for (; reader->IsValid(); reader->Next()) {
        const auto tableIndex = reader->GetTableIndex();

        if (tableIndex == 0) {
            // Copy: the reader's row is invalidated by Next().
            feature = reader->GetRow();
        } else if (tableIndex == 1) {
            const NYT::TNode& signFeature = reader->GetRow();
            const NYT::TNode& lastFeature = feature;

            // A sign feature whose feature row is absent (e.g. the feature was
            // removed between the two database loads) must be paired neither
            // with an undefined node nor with the feature of another key.
            if (lastFeature.IsUndefined() ||
                lastFeature["feature_id"] != signFeature["feature_id"])
            {
                ++skippedCount;
                continue;
            }

            writer->AddRow(
                NYT::TNode::CreateMap()
                    ("feature", lastFeature)
                    ("signFeature", signFeature)
            );
        } else {
            throw maps::RuntimeError() << "Unknown table index " << tableIndex;
        }
    }

    if (skippedCount > 0) {
        ERROR() << "Skipped " << skippedCount << " sign features without a matching feature";
    }
}

REGISTER_REDUCER(Join);

// Join-reduces the primary table `prime` with the foreign table `foreign` by
// feature_id and writes the Join reducer's output to `result`.
// The foreign table is added first, so the reducer sees it as table index 0.
void join(
        NYT::IOperationClient& client,
        const TString& prime,
        const TString& foreign,
        const TString& result)
{
    NYT::TJoinReduceOperationSpec spec;
    spec.JoinBy({"feature_id"});
    spec.AddInput<NYT::TNode>(NYT::TRichYPath(foreign).Foreign(true));
    spec.AddInput<NYT::TNode>(prime);
    spec.AddOutput<NYT::TNode>(result);

    client.JoinReduce(spec, new Join());
}

// Mapper over the joined (feature, sign feature) table: writes the
// "signFeature" part of a row when the sign's box overlaps the car mask of its
// photo by a large enough fraction of the box area.
class SignOnCarDetector: public NYT::IMapper<Reader, Writer>  {
public:
    void Do(Reader* reader, Writer* writer) override;

private:
    // Default minimal share of the box area that has to be covered by the mask.
    static constexpr double DEFAULT_INTERSECTION_RATIO = 0.5;

    bool isSignOnCar(
        const db::SignFeature& signFeature,
        const cv::Mat& carMask,
        double intersectionRatioSignWithCar = DEFAULT_INTERSECTION_RATIO) const;

    carsegm::CarSegmentator segmentator_;
};

REGISTER_MAPPER(SignOnCarDetector);

// Returns true when at least `intersectionRatioSignWithCar` of the sign box
// area is covered by non-zero pixels of `carMask`.
//
// The box is the ROI starting at (minX, minY) with width maxX - minX and height
// maxY - minY; it must lie inside `carMask` (the caller checks the bounds).
// Throws if the box is empty or inverted.
bool SignOnCarDetector::isSignOnCar(
        const db::SignFeature& signFeature,
        const cv::Mat& carMask,
        double intersectionRatioSignWithCar) const
{
    // Compare the coordinates before subtracting: an unsigned difference of an
    // inverted box would wrap around to a huge "positive" size.
    REQUIRE(
        signFeature.maxX() > signFeature.minX() && signFeature.maxY() > signFeature.minY(),
        "Required nondegenerate box"
    );

    const int width = static_cast<int>(signFeature.maxX() - signFeature.minX());
    const int height = static_cast<int>(signFeature.maxY() - signFeature.minY());

    const cv::Rect signBox(signFeature.minX(), signFeature.minY(), width, height);
    const int intersectionArea = cv::countNonZero(carMask(signBox));

    // Computed in floating point: no integer overflow, no integer division.
    const double signBoxArea = static_cast<double>(width) * height;

    return intersectionArea / signBoxArea >= intersectionRatioSignWithCar;
}

// Processes rows of the joined table. Each row holds a "feature" (a photo with
// its image) and one "signFeature" box on that photo; the "signFeature" node is
// written out when isSignOnCar() holds for it.
//
// Car segmentation is expensive, so the image and its (lazily computed) car
// mask are cached while consecutive rows refer to the same feature_id. Rows of
// a feature with an empty image are skipped silently; boxes that are empty or
// do not fit into the image are logged and skipped instead of failing the job.
void SignOnCarDetector::Do(Reader* reader, Writer* writer)
{
    using FeatureWithImage = yt::FeatureWithImage<db::Feature>;

    // Feature whose image is held in `image`; empty before the first row.
    std::optional<db::TId> cachedFeatureId;
    cv::Mat image;
    // Car mask of `image`; stays empty until a valid box of the feature needs it.
    cv::Mat carMask;

    for (; reader->IsValid(); reader->Next()) {
        // Bind by reference: the row carries the encoded image, copying it is costly.
        const NYT::TNode& row = reader->GetRow();
        const auto signFeature = yt::deserialize<db::SignFeature>(row["signFeature"]);

        if (cachedFeatureId != signFeature.featureId()) {
            image = yt::deserialize<FeatureWithImage>(row["feature"]).image;
            carMask.release();
            cachedFeatureId = signFeature.featureId();
        }

        if (image.empty()) {
            continue;
        }

        // Checked for every box, not only for the first box of a feature: the
        // ROI [minX, maxX) x [minY, maxY) must be non-empty and inside the image.
        const bool isValidBox =
                signFeature.minX() >= 0 &&
                signFeature.minY() >= 0 &&
                signFeature.minX() < signFeature.maxX() &&
                signFeature.minY() < signFeature.maxY() &&
                signFeature.maxX() <= image.cols &&
                signFeature.maxY() <= image.rows;

        if (!isValidBox) {
            ERROR() << "Invalid box of feature " << signFeature.featureId();
            continue;
        }

        if (carMask.empty()) {
            carMask = segmentator_.segment(image);
        }

        if (isSignOnCar(signFeature, carMask)) {
            writer->AddRow(row["signFeature"]);
        }
    }
}

// Sorts `table` in place by the feature_id column; both inputs of join() are
// prepared with it.
void sortByFeatureId(NYT::IOperationClient& client, const TString& table)
{
    NYT::TSortOperationSpec spec;
    spec.SortBy({"feature_id"});
    spec.AddInput(table);
    spec.Output(table);

    client.Sort(spec);
}

// Runs SignOnCarDetector over `featureWithBox` using `jobCount` jobs and writes
// the sign features lying on cars to `result`. Every mapper job gets a 16 GB
// memory limit with a 0.6 memory reserve factor.
void detectSignsOnCar(
        NYT::IOperationClient& client,
        const TString& featureWithBox,
        const TString& result,
        size_t jobCount)
{
    const auto mapperResources = NYT::TNode::CreateMap()
        ("memory_limit", 16_GB)
        ("memory_reserve_factor", 0.6);

    NYT::TOperationOptions options;
    options.Spec(
        NYT::TNode::CreateMap()
            ("title", "Signs on car detector")
            ("mapper", mapperResources)
    );

    NYT::TMapOperationSpec spec;
    spec.AddInput<NYT::TNode>(featureWithBox);
    spec.AddOutput<NYT::TNode>(result);
    spec.JobCount(jobCount);

    client.Map(spec, new SignOnCarDetector(), options);
}

// Returns ids of signs for which the number of sign features in `onCar` equals
// the number of sign features in `all`, i.e. signs seen only on cars.
// Throws std::out_of_range if `onCar` mentions a sign absent from `all`.
db::TIds findSignsOnCarOnly(const db::SignFeatures& all, const db::SignFeatures& onCar)
{
    using CountBySign = std::unordered_map<db::TId, size_t>;

    const auto countPerSign = [](const db::SignFeatures& links) {
        CountBySign counts;
        for (const auto& link: links) {
            ++counts[link.signId()];
        }
        return counts;
    };

    const CountBySign totalBySign = countPerSign(all);

    db::TIds onCarOnly;
    for (const auto& entry: countPerSign(onCar)) {
        const bool everyFeatureOnCar = entry.second == totalBySign.at(entry.first);
        if (everyFeatureOnCar) {
            onCarOnly.push_back(entry.first);
        }
    }

    return onCarOnly;
}

// Deletes the signs `signIds` in one committed master transaction and returns
// the value reported by SignGateway::removeByIds.
size_t removeSigns(pgpool3::Pool& mrcPool, const db::TIds& signIds)
{
    auto writeTxn = mrcPool.masterWriteableTransaction();

    db::SignGateway gateway(*writeTxn);
    const size_t removedCount = gateway.removeByIds(signIds);

    writeTxn->commit();
    return removedCount;
}

// Returns the non-zero feedback task ids of hypotheses attached to `signIds`,
// in the order the hypotheses are loaded (duplicates are kept).
db::TIds collectFeedbackTaskIds(pgpool3::Pool& mrcPool, const db::TIds& signIds)
{
    auto slaveTxn = mrcPool.slaveTransaction();
    const auto hypotheses = db::HypothesisGateway(*slaveTxn).load(
        db::table::Hypothesis::signId.in(signIds)
    );

    db::TIds taskIds;
    for (const auto& hypothesis: hypotheses) {
        const auto taskId = hypothesis.feedbackTaskId();
        if (!taskId) {
            continue;
        }
        taskIds.push_back(taskId);
    }

    return taskIds;
}

void closeFeedbackTask(const MrcConfig& config, db::TId feedbackTaskId)
{
    std::stringstream url;

    url << config.externals().socialBackofficeUrl()
        << "/feedback/tasks/"
        << feedbackTaskId
        << "/resolve";

    http::Client сlient;
    сlient.setTimeout(10s);

    retry(
        [&]() {
            auto request = http::Request(сlient, http::POST, url.str());
            return request.perform();
        },
        maps::common::RetryPolicy().setTryNumber(3)
    );
}

// Loads the signs with ids `signIds` through a slave (read-only) transaction.
db::Signs loadSigns(pgpool3::Pool& mrcPool, const db::TIds& signIds)
{
    auto slaveTxn = mrcPool.slaveTransaction();
    db::SignGateway gateway(*slaveTxn);
    return gateway.loadByIds(signIds);
}

void run(const MrcConfig& mrcConfig)
{
    INFO() << "Starting...";

    PoolHolder mrc(mrcConfig.makePoolHolder());

    INFO() << "Load all sign features...";
    const auto signFeatures = loadSignFeatures(mrc.pool());
    INFO() << "Number of sign features " << signFeatures.size();

    NYT::IClientPtr client = mrcConfig.externals().yt().makeClient();
    if (!client->Exists(PATH)) {
        INFO() << "Create node " << PATH;
        client->Create(PATH, NYT::NT_MAP, NYT::TCreateOptions().Recursive(true));
    }

    INFO() << "Load features...";
    const db::Features features = loadFeatures(mrc.pool(), collectFeatureIds(signFeatures));
    INFO() << "Number of features " << features.size();

    INFO() << "Store sign features to " << SIGN_FEATURE_TABLE;
    yt::saveToTable(*client, SIGN_FEATURE_TABLE, signFeatures);

    INFO() << "Sort table " << SIGN_FEATURE_TABLE << " by feature_id";
    sortByFeatureId(*client, SIGN_FEATURE_TABLE);

    INFO() << "Store features to " << FEATURE_TABLE;
    yt::saveToTable(*client, FEATURE_TABLE, features);

    INFO() << "Upload images " << FEATURE_TABLE;
    yt::uploadFeatureImage(
        mrcConfig,
        *client,
        FEATURE_TABLE,
        FEATURE_TABLE
    );

    INFO() << "Sort table " << FEATURE_TABLE << " by feature_id";
    sortByFeatureId(*client, FEATURE_TABLE);

    INFO() << "Join table " << FEATURE_TABLE << " with " << SIGN_FEATURE_TABLE;
    join(*client, SIGN_FEATURE_TABLE, FEATURE_TABLE, FEATURE_WITH_BOX_TABLE);

    INFO() << "Detect signs on cars...";

    constexpr size_t JOB_FEAUTER_N = 30;
    detectSignsOnCar(
        *client,
        FEATURE_WITH_BOX_TABLE,
        SIGN_ON_CAR_TABLE,
        std::max(features.size() / JOB_FEAUTER_N, 1ul)
    );

    const auto signIds = findSignsOnCarOnly(
        signFeatures,
        yt::loadFromTable<db::SignFeatures>(*client, SIGN_ON_CAR_TABLE)
    );
    INFO() << "Found " << signIds.size() << " signs on car only";

    INFO() << "Save dropped signs to " << DROPPED_SIGN_TABLE;
    yt::saveToTable(
        *client,
        DROPPED_SIGN_TABLE,
        loadSigns(mrc.pool(), signIds)
    );

    const db::TIds feedbackTaskIds = collectFeedbackTaskIds(mrc.pool(), signIds);
    INFO() << "Found " << feedbackTaskIds.size() << " feedbacks";
    for (auto feedbackTaskId: feedbackTaskIds) try {
        INFO() << "Close " << feedbackTaskId;
        closeFeedbackTask(mrcConfig, feedbackTaskId);
    } catch (const maps::Exception&) {
        ERROR() << "Impossible closeFeedbackTask feedback tasks" << feedbackTaskId;
    }

    INFO() << "Remove " << removeSigns(mrc.pool(), signIds) << " signs";
    INFO() << "Finish!";
}

} // namepspace
} // namespace maps::mrc::hide

using namespace maps::mrc::hide;

int main(int argc, const char** argv) try {
    NYT::Initialize(argc, argv);

    maps::cmdline::Parser parser("Hide signs on cars");
    auto mrcConfigPath = parser.string("mrc-config").help("path to mrc config");
    parser.parse(argc, const_cast<char**>(argv));

    run(maps::mrc::common::templateConfigFromCmdPath(mrcConfigPath));

    return EXIT_SUCCESS;
} catch (const maps::Exception& ex) {
    FATAL() << "Failed: " << ex;
    return EXIT_FAILURE;
} catch (const yexception& ex) {
    FATAL() << "Failed: " << ex.what();
    if (ex.BackTrace()) {
        FATAL() << ex.BackTrace()->PrintToString();
    }
    return EXIT_FAILURE;
} catch (const std::exception& ex) {
    FATAL() << "Failed: " << ex.what();
    return EXIT_FAILURE;
}
