#include <maps/wikimap/mapspro/services/mrc/libs/common/include/exif.h>
#include <maps/wikimap/mapspro/services/mrc/libs/config/include/config.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/feature.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/feature_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/sideview_classifier/include/sideview.h>

#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/libs/common/include/exception.h>
#include <maps/libs/geolib/include/point.h>
#include <maps/libs/http/include/http.h>
#include <maps/libs/log8/include/log8.h>
#include <maps/libs/pgpool/include/pgpool3.h>

#include <mapreduce/yt/interface/client.h>
#include <mapreduce/yt/util/temp_table.h>

#include <opencv2/opencv.hpp>

#include <vector>
#include <random>

namespace {

// Columns of the YT table of photo pairs written by uploadPairs (one row per pair).
const TString COLUMN_NAME_SOURCE_ID    = "source_id";
const TString COLUMN_NAME_CLUSTER_IND  = "cluster_ind";
const TString COLUMN_NAME_URL1         = "url1";
const TString COLUMN_NAME_URL2         = "url2";
const TString COLUMN_NAME_ORIENTATION1 = "orientation1";
const TString COLUMN_NAME_ORIENTATION2 = "orientation2";

// Columns that TClassifierMapper adds to every row of its output table.
const TString COLUMN_NAME_SIDEVIEW     = "sideview";
const TString COLUMN_NAME_CONFIDENCE   = "confidence";

/// A run of features of one source, ordered by timestamp, whose neighbouring timestamps
/// are no more than the configured time gap apart (see extractFeatureClusters), together
/// with the classifier votes accumulated for it.
struct FeaturesCluster {
    std::string sourceId;
    maps::mrc::db::Features features;

    // Number of sampled pairs classified as front view, and the sum of their confidences.
    int frontViewCount{0};
    float frontViewConfidenceSum{0.f};

    // Number of sampled pairs classified as side view, and the sum of their confidences.
    int sideViewCount{0};
    float sideViewConfidenceSum{0.f};
};

/// Loads every feature whose assignment id equals `assignmentId`, using a slave
/// (read-only) transaction from `pool`.
maps::mrc::db::Features loadAssignmentPhotos(maps::pgpool3::Pool& pool, maps::mrc::db::TId assignmentId)
{
    namespace db = maps::mrc::db;

    auto txn = pool.slaveTransaction();
    db::FeatureGateway gateway{*txn};
    return gateway.load(db::table::Feature::assignmentId.equals(assignmentId));
}

/// Returns true when `lhs` and `rhs` must belong to different clusters: they come from
/// different sources, or their timestamps differ by more than `timeGapSeconds`.
bool needSplit(const maps::mrc::db::Feature& lhs, const maps::mrc::db::Feature& rhs, int timeGapSeconds) {
    if (lhs.sourceId() != rhs.sourceId()) {
        return true;
    }
    const auto distance = std::chrono::abs(lhs.timestamp() - rhs.timestamp());
    return distance > std::chrono::seconds{timeGapSeconds};
}

/// Loads all features of the assignment and splits them into clusters.
///
/// Features are ordered by (source id, timestamp). A new cluster starts whenever
/// needSplit() reports that two neighbouring features differ in source or are more
/// than `timeGapSeconds` apart.
///
/// @return clusters in (source id, timestamp) order; empty if the assignment has no features.
std::vector<FeaturesCluster> extractFeatureClusters(maps::pgpool3::Pool& pool, maps::mrc::db::TId assignmentId, int timeGapSeconds)
{
    maps::mrc::db::Features features = loadAssignmentPhotos(pool, assignmentId);
    if (features.empty()) {
        return {};
    }
    // forward_as_tuple only references the accessor results, unlike make_tuple. When the
    // accessors return references, no source-id string is copied per comparison.
    std::sort(features.begin(), features.end(),
              [](const auto& lhs, const auto& rhs) {
                  return std::forward_as_tuple(lhs.sourceId(), lhs.timestamp())
                         < std::forward_as_tuple(rhs.sourceId(), rhs.timestamp());
              });

    std::vector<FeaturesCluster> featureClusters;
    // Appends the features [begin, end) as a new cluster.
    const auto appendCluster = [&featureClusters](auto begin, auto end) {
        FeaturesCluster cluster;
        cluster.sourceId = begin->sourceId();
        cluster.features = {begin, end};
        featureClusters.emplace_back(std::move(cluster));
    };

    auto first = features.begin();
    for (auto it = first + 1; it != features.end(); ++it) {
        if (needSplit(*(it - 1), *it, timeGapSeconds)) {
            appendCluster(first, it);
            first = it;
        }
    }
    // `features` is non-empty, so the trailing cluster [first, end) always has at least one feature.
    appendCluster(first, features.end());
    return featureClusters;
}

/// Builds the MDS storage URL of the photo of feature `f` from its MDS group id and path.
TString getFeatureUrl(const maps::mrc::db::Feature& f) {
    const auto url = "http://storage-int.mds.yandex.net/get-maps_mrc/" + f.mdsGroupId() + "/" + f.mdsPath();
    return url.c_str();
}

// Number of pairs to sample from a cluster of `clusterSize` features (semantics of
// `pairsCount` are documented on uploadPairs). Never more than half the cluster size.
int pairsToSample(int clusterSize, int pairsCount)
{
    if (pairsCount == 0) {
        // Avoids the division by zero below: asking for zero samples yields zero pairs.
        return 0;
    }
    // long long so that negating INT_MIN does not overflow.
    const long long requested = (pairsCount > 0)
        ? static_cast<long long>(pairsCount)
        : clusterSize / -static_cast<long long>(pairsCount);
    return static_cast<int>(std::min<long long>(requested, clusterSize / 2));
}

/// Samples random pairs of adjacent features (in cluster order) from every cluster.
/// The pairs are written to the YT table `tableName`, one row per pair, as input for
/// TClassifierMapper.
///
/// @param pairsCount  pairs per cluster: at most `pairsCount` if positive; one pair per
///                    `-pairsCount` features of the cluster if negative; none if zero.
///                    In every case at most half the cluster size.
/// @return the total number of rows written.
int uploadPairs(NYT::IClientPtr& client, const TString& tableName, const std::vector<FeaturesCluster>& featureClusters, int pairsCount)
{
    // Default-constructed engine (default seed), so the sampling is reproducible between runs.
    static std::default_random_engine engine;

    NYT::TTableWriterPtr<NYT::TNode> writer = client->CreateTableWriter<NYT::TNode>(tableName);

    int uploaded = 0;
    for (size_t i = 0; i < featureClusters.size(); i++) {
        const FeaturesCluster& cluster = featureClusters[i];
        const int clusterSize = static_cast<int>(cluster.features.size());
        const int pairs = pairsToSample(clusterSize, pairsCount);
        if (pairs <= 0) {
            // Also skips clusters with fewer than 2 features. For those, the range
            // [1, clusterSize - 1] below would be empty, which violates the
            // distribution's precondition a <= b.
            continue;
        }
        // idx selects the adjacent pair (idx - 1, idx); the same pair may be drawn twice.
        std::uniform_int_distribution<int> uniform_dist(1, clusterSize - 1);
        uploaded += pairs;
        for (int p = 0; p < pairs; p++) {
            const int idx = uniform_dist(engine);
            const maps::mrc::db::Feature& firstFeature = cluster.features[idx - 1];
            const maps::mrc::db::Feature& secondFeature = cluster.features[idx];
            // A feature without orientation is uploaded with orientation 1.
            writer->AddRow(
                NYT::TNode()
                    (COLUMN_NAME_SOURCE_ID,    NYT::TNode(cluster.sourceId.c_str()))
                    (COLUMN_NAME_CLUSTER_IND,  NYT::TNode(i))
                    (COLUMN_NAME_URL1,         NYT::TNode(getFeatureUrl(firstFeature)))
                    (COLUMN_NAME_ORIENTATION1, NYT::TNode(firstFeature.hasOrientation() ? static_cast<int>(firstFeature.orientation()) : 1))
                    (COLUMN_NAME_URL2,         NYT::TNode(getFeatureUrl(secondFeature)))
                    (COLUMN_NAME_ORIENTATION2, NYT::TNode(secondFeature.hasOrientation() ? static_cast<int>(secondFeature.orientation()) : 1))
            );
        }
    }
    // Finish explicitly so that write errors surface here as exceptions, instead of
    // depending on the writer's destructor.
    writer->Finish();
    return uploaded;
}

/// Downloads `url` with an HTTP GET and returns the response body.
///
/// Attempts are repeated through maps::common::retry: up to 10 tries, initial cooldown
/// 1 second, backoff factor 2. An attempt is accepted once it yields a response that is
/// not a server error.
/// Fails via REQUIRE unless the final response is a success.
std::vector<uint8_t> downloadImage(maps::http::Client& client, const std::string& url)
{
    auto fetch = [&]() {
        return maps::http::Request(client, maps::http::GET, maps::http::URL(url)).perform();
    };
    auto isAcceptable = [](const auto& maybeResponse) {
        if (!maybeResponse.valid()) {
            return false;
        }
        return maybeResponse.get().responseClass() != maps::http::ResponseClass::ServerError;
    };

    maps::common::RetryPolicy policy;
    policy.setTryNumber(10)
        .setInitialCooldown(std::chrono::seconds(1))
        .setCooldownBackoff(2);

    auto response = maps::common::retry(fetch, policy, isAcceptable);
    REQUIRE(response.responseClass() == maps::http::ResponseClass::Success,
        "Unexpected response status " << response.status() << " for url " << url);
    return response.readBodyToVector();
}

/// Downloads the image at `url`, decodes it as a color image and transforms it
/// according to the EXIF `orientation` value.
///
/// Fails via REQUIRE if the downloaded data cannot be decoded.
cv::Mat loadImage(maps::http::Client& client, const std::string& url, int orientation)
{
    std::vector<uint8_t> data = downloadImage(client, url);
    // The embedded EXIF orientation is ignored on decode because `orientation` is applied explicitly below.
    cv::Mat image = cv::imdecode(data, cv::IMREAD_COLOR | cv::IMREAD_IGNORE_ORIENTATION);
    // imdecode reports invalid data by returning an empty matrix rather than throwing.
    REQUIRE(!image.empty(), "Failed to decode image from url " << url);
    return maps::mrc::common::transformByImageOrientation(image, maps::mrc::common::ImageOrientation::fromExif(orientation));
}

// YT Mappers

/// Classifies each pair of photos in the input table (rows written by uploadPairs).
/// Every output row is the input row plus two columns:
///   COLUMN_NAME_SIDEVIEW   - 0 if the classifier says ForwardView, 1 otherwise;
///   COLUMN_NAME_CONFIDENCE - the confidence returned by the classifier.
class TClassifierMapper
    : public NYT::IMapper<NYT::TTableReader<NYT::TNode>, NYT::TTableWriter<NYT::TNode>>  {
public:
    void Do(NYT::TTableReader<NYT::TNode>* reader, NYT::TTableWriter<NYT::TNode>* writer) override {
        INFO() << "Start classification ... ";
        maps::http::Client httpClient;
        maps::mrc::sideview::SideViewClassifier classifier;

        while (reader->IsValid()) {
            const NYT::TNode& row = reader->GetRow();

            const std::pair<maps::mrc::sideview::SideViewType, float> verdict = classifier.inference(
                loadImage(httpClient, row[COLUMN_NAME_URL1].AsString(), row[COLUMN_NAME_ORIENTATION1].AsInt64()),
                loadImage(httpClient, row[COLUMN_NAME_URL2].AsString(), row[COLUMN_NAME_ORIENTATION2].AsInt64())
            );
            const auto& [viewType, confidence] = verdict;
            const bool isForwardView = (viewType == maps::mrc::sideview::SideViewType::ForwardView);

            NYT::TNode result = row;
            result[COLUMN_NAME_SIDEVIEW]   = isForwardView ? 0 : 1;
            result[COLUMN_NAME_CONFIDENCE] = confidence;
            writer->AddRow(result);

            reader->Next();
        }
    }
};
REGISTER_MAPPER(TClassifierMapper);

/// Reads the classifier output table `tableName` (written by TClassifierMapper) and
/// accumulates its votes into `featureClusters`. Each row is assigned to a cluster by its
/// COLUMN_NAME_CLUSTER_IND value. A row with sideview 0 adds a front-view vote; any
/// other value adds a side-view vote. The row's confidence is added to the matching sum.
///
/// Fails via REQUIRE if a row refers to a cluster index outside `featureClusters`.
void loadClassificationResults(
    NYT::IClientPtr& client,
    const TString& tableName,
    std::vector<FeaturesCluster>& featureClusters)
{
    NYT::TTableReaderPtr<NYT::TNode> reader = client->CreateTableReader<NYT::TNode>(tableName);

    for (; reader->IsValid(); reader->Next()) {
        const NYT::TNode& inpRow = reader->GetRow();
        const size_t clusterInd = inpRow[COLUMN_NAME_CLUSTER_IND].AsUint64();
        // The index comes from an external table; indexing the vector with it unchecked would be UB.
        REQUIRE(clusterInd < featureClusters.size(),
            "Cluster index " << clusterInd << " is out of range, clusters count: " << featureClusters.size());
        FeaturesCluster& cluster = featureClusters[clusterInd];
        const double confidence = inpRow[COLUMN_NAME_CONFIDENCE].AsDouble();
        if (0 == inpRow[COLUMN_NAME_SIDEVIEW].AsInt64()) {
            cluster.frontViewCount++;
            cluster.frontViewConfidenceSum += confidence;
        } else {
            cluster.sideViewCount++;
            cluster.sideViewConfidenceSum += confidence;
        }
    }
}

/// Sets `cameraDeviation` on every feature of `features` that doesn't already have it.
///
/// With `dryRun`, only logs how many and which features would change; the DB is not
/// touched. Otherwise the changed features are written in a single master transaction
/// and then logged. Nothing happens when no feature needs a change.
void updateFeaturesCameraDeviation(
    maps::pgpool3::Pool& pool,
    const maps::mrc::db::Features& features,
    maps::mrc::db::CameraDeviation cameraDeviation,
    bool dryRun)
{
    maps::mrc::db::Features filtered;
    // Bind by reference: only features that actually change are copied.
    for (const auto& feature : features) {
        if (!feature.hasCameraDeviation() || feature.cameraDeviation() != cameraDeviation) {
            maps::mrc::db::Feature updated = feature;
            updated.setCameraDeviation(cameraDeviation);
            filtered.push_back(std::move(updated));
        }
    }

    const char* deviationName =
        (maps::mrc::db::CameraDeviation::Front == cameraDeviation) ? "0" : "90";
    const auto logFeatureIds = [&filtered]() {
        for (const auto& feature : filtered) {
            INFO() << "    fid = " << feature.id();
        }
    };

    if (dryRun) {
        INFO() << "Need update: " << filtered.size() << " features from " << features.size()
               << " to camera deviation " << deviationName;
        logFeatureIds();
    } else if (!filtered.empty()) {
        {
            auto txn = pool.masterWriteableTransaction();
            maps::mrc::db::FeatureGateway{*txn}.update(filtered, maps::mrc::db::UpdateFeatureTxn::Yes);
            txn->commit();
        }

        INFO() << "Updated: " << filtered.size() << " features from " << features.size()
               << " to camera deviation " << deviationName;
        logFeatureIds();
    }
}

/// Decides every cluster's camera deviation by the number of classified pairs.
///
/// If the front-view pairs exceed `minVotesFraction` of all classified pairs of the
/// cluster, its features become Front. Otherwise, if the side-view pairs do, they become
/// Right. Otherwise the cluster is left unchanged; that is logged for clusters with at
/// least two features.
void updateFeaturesCameraDeviationsByCount(
    const std::vector<FeaturesCluster>& featureClusters,
    maps::pgpool3::Pool& pool,
    double minVotesFraction,
    bool dryRun)
{
    for (size_t idx = 0; idx != featureClusters.size(); ++idx) {
        const FeaturesCluster& cluster = featureClusters[idx];
        const double requiredVotes = minVotesFraction * (cluster.frontViewCount + cluster.sideViewCount);

        if (cluster.frontViewCount > requiredVotes) {
            updateFeaturesCameraDeviation(pool, cluster.features, maps::mrc::db::CameraDeviation::Front, dryRun);
            continue;
        }
        if (cluster.sideViewCount > requiredVotes) {
            updateFeaturesCameraDeviation(pool, cluster.features, maps::mrc::db::CameraDeviation::Right, dryRun);
            continue;
        }
        if (cluster.features.size() >= 2) {
            INFO()  << "Cluster " << idx << " features: " << cluster.features[0].id() << ", " << cluster.features[1].id() << "..."
                    << " has " << cluster.frontViewCount << " front view pairs and " << cluster.sideViewCount << " side view pairs."
                    << " Camera deviation didn't change";
        }
    }
}

/// Decides every cluster's camera deviation by classifier confidence instead of pair count.
///
/// If the front-view confidence sum exceeds `minVotesFraction` of the total confidence
/// sum of the cluster, its features become Front. Otherwise, if the side-view confidence
/// sum does, they become Right. Otherwise the cluster is left unchanged; that is logged
/// for clusters with at least two features.
void updateFeaturesCameraDeviationsBySumConfidence(
    const std::vector<FeaturesCluster>& featureClusters,
    maps::pgpool3::Pool& pool,
    double minVotesFraction,
    bool dryRun)
{
    for (size_t i = 0; i < featureClusters.size(); i++) {
        const FeaturesCluster& cluster = featureClusters[i];
        const float votesSum = cluster.frontViewConfidenceSum + cluster.sideViewConfidenceSum;
        if (cluster.frontViewConfidenceSum > minVotesFraction * votesSum) {
            updateFeaturesCameraDeviation(pool, cluster.features, maps::mrc::db::CameraDeviation::Front, dryRun);
        } else if (cluster.sideViewConfidenceSum > minVotesFraction * votesSum) {
            // Compare like with like. The threshold is a fraction of the confidence sum,
            // so it must be checked against the side-view confidence sum, not the pair count.
            updateFeaturesCameraDeviation(pool, cluster.features, maps::mrc::db::CameraDeviation::Right, dryRun);
        } else if (2 <= cluster.features.size()) {
            INFO()  << "Cluster " << i << " features: " << cluster.features[0].id() << ", " << cluster.features[1].id() << "..."
                    << " has " << cluster.frontViewConfidenceSum << " front view confidence and " << cluster.sideViewConfidenceSum << " side view confidence."
                    << " Camera deviation didn't change";
        }
    }
}

} //namespace

int main(int argc, const char** argv) try {
    static const TString YT_PROXY = "hahn";

    NYT::Initialize(argc, argv);

    maps::cmdline::Parser parser("Update camera deviation");

    maps::cmdline::Option<int> assignmentId = parser.num("assignment-id")
        .required()
        .help("Assigment ID");

    maps::cmdline::Option<std::string> mrcConfigPath = parser.string("mrc-config")
        .help("Path to mrc config");

    maps::cmdline::Option<std::string> secretVersion = parser.string("secret-version")
        .help("version for secrets from yav.yandex-team.ru");

    maps::cmdline::Option<int> timeGapSeconds = parser.num("time-gap")
        .defaultValue(300)
        .help("Time gap in seconds for split features to different cluster (default: 300)");

    maps::cmdline::Option<int> samplesCount = parser.num("samples-cnt")
        .defaultValue(-20)
        .help("Samples for one cluster, if negative value use as fraction of all samples in cluster (default: -20)");

    maps::cmdline::Option<double> minVotesFraction = parser.real("min-votes-frac")
        .defaultValue(0.7)
        .help("Minimal fraction of votes to apply calculated camera deviation (default: 0.7)");

    maps::cmdline::Option<bool> useConfidenceSum = parser.flag("use-confidence")
        .help("Use sum off confidence from classifier instead of count");

    maps::cmdline::Option<bool> dryRun = parser.flag("dry-run")
        .help("Don't apply changes to DB");

    parser.parse(argc, const_cast<char**>(argv));

    REQUIRE(0.5 < minVotesFraction && minVotesFraction < 1.0,  "min-votes-frac parameters should be between 0.5 and 1.0");

    const maps::mrc::common::Config mrcConfig =
        maps::mrc::common::templateConfigFromCmdPath(secretVersion, mrcConfigPath);

    maps::wiki::common::PoolHolder mrc(mrcConfig.makePoolHolder());

    std::vector<FeaturesCluster> featureClusters = extractFeatureClusters(mrc.pool(), assignmentId, timeGapSeconds);
    INFO() << "Clusters count: " << featureClusters.size();
    if (featureClusters.empty()) {
        return EXIT_SUCCESS;
    }

    INFO() << "Connecting to yt::" << YT_PROXY;
    NYT::IClientPtr client = NYT::CreateClient(YT_PROXY);
    const NYT::TTempTable inputTable(client);
    int uploaded = uploadPairs(client, inputTable.Name(), featureClusters, samplesCount);
    INFO() << uploaded << " pairs uploaded for classifier to " << inputTable.Name();

    const NYT::TTempTable outputTable(client);
    INFO() << "Start classifier to " << outputTable.Name();
    client->Map(
        NYT::TMapOperationSpec()
            .AddInput<NYT::TNode>(inputTable.Name())
            .AddOutput<NYT::TNode>(outputTable.Name())
            .JobCount(std::max(1, uploaded / 10)),
        new TClassifierMapper()
    );
    INFO() << "Classifier finished";

    loadClassificationResults(client, outputTable.Name(), featureClusters);
    if (useConfidenceSum) {
        updateFeaturesCameraDeviationsBySumConfidence(featureClusters, mrc.pool(), minVotesFraction, dryRun);
    } else {
        updateFeaturesCameraDeviationsByCount(featureClusters, mrc.pool(), minVotesFraction, dryRun);
    }

    return EXIT_SUCCESS;
}
catch (const maps::Exception& e) {
    FATAL() << "Worker failed: " << e;
    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    FATAL() << "Worker failed: " << e.what();
    return EXIT_FAILURE;
}
