#include "passage_loader.h"
#include "yt.h"

#include <maps/libs/chrono/include/time_point.h>
#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/libs/chrono/include/time_point.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/feature_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/metadata_gateway.h>
#include <maps/libs/common/include/make_batches.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/pg_locks.h>
#include <yandex/maps/wiki/common/string_utils.h>
#include <maps/libs/log8/include/log8.h>
#include <mapreduce/yt/interface/client.h>

#include <maps/libs/sql_chemistry/include/order.h>

#include <algorithm>
#include <cstdlib>
#include <iterator>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>

namespace db = maps::mrc::db;
using namespace maps::mrc::sideview_classifier;

namespace {

// Tool name; also serves as the prefix of the metadata key below.
const std::string APP_NAME{"sideview_classifier.tool"};
// Metadata key holding the id of the last feature of the last processed batch;
// read back at startup when --start-id is not given.
const std::string LAST_FEATURE_ID{APP_NAME + ".feature_id"};

// Default value of --batch-size: number of feature ids handled per iteration.
constexpr std::size_t FEATURES_BATCH_SIZE{2000000};
// Fallback start id when neither --start-id nor saved metadata is available.
constexpr db::TId START_FEATURE_ID{34854949};
// How many features are sampled from each passage for classification
// (fewer when the passage is shorter).
constexpr int SAMPLES_PER_PASSAGE{7};

/// @brief Loads ids of all features with id strictly greater than @p initialId,
///        excluding the Walks dataset, sorted by id ascending.
/// @param pool       connection pool; a slave transaction is used for reading.
/// @param initialId  exclusive lower bound for feature ids.
db::TIds loadFeatureIds(maps::pgpool3::Pool& pool, db::TId initialId)
{
    // Read-only lookup, so a slave transaction is sufficient.
    auto readTxn = pool.slaveTransaction();

    return db::FeatureGateway{*readTxn}.loadIds(
        db::table::Feature::id > initialId
            && db::table::Feature::dataset != db::Dataset::Walks,
        maps::sql_chemistry::orderBy(db::table::Feature::id).asc());
}

/// @brief Persists the camera deviation of the given classified features.
///
/// The features are re-loaded from the master DB and only the camera
/// deviation is copied onto them. Any other changes made to the DB rows while
/// the YT job was running are therefore not overwritten.
///
/// @param pool        connection pool; a master writeable transaction is used.
/// @param features    classified features; read only, left intact for the caller.
/// @param skipCommit  when true, the update is executed but not committed (dry run).
void saveFeatures(maps::pgpool3::Pool& pool,
                  db::Features& features,
                  bool skipCommit)
{
    INFO() << "Save " << features.size() << " features to db...";

    if (features.empty()) {
        // Nothing to update: avoid opening a master write transaction.
        return;
    }

    // Index the classified features by id without moving them out of the
    // caller's container. The pointers stay valid because `features` is not
    // modified below.
    db::TIds ids;
    std::unordered_map<db::TId, const db::Feature*> classifiedById;
    ids.reserve(features.size());
    classifiedById.reserve(features.size());

    for (const auto& feature : features) {
        ids.push_back(feature.id());
        classifiedById.emplace(feature.id(), &feature);
    }

    // To avoid overwriting possible changes made during YT processing,
    // set updated fields directly into the features loaded from DB.
    auto txn = pool.masterWriteableTransaction();
    db::FeatureGateway featureGtw{*txn};

    db::Features dbFeatures = featureGtw.loadByIds(ids);
    for (auto& dbFeature : dbFeatures) {
        const db::Feature& classified = *classifiedById.at(dbFeature.id());
        dbFeature.setCameraDeviation(classified.cameraDeviation());
    }

    featureGtw.update(dbFeatures, db::UpdateFeatureTxn::Yes);
    if (skipCommit) {
        INFO() << "Dry run => skip commit";
        return;
    }

    txn->commit();
    INFO() << "Done saving";
}

/// @brief Picks up to SAMPLES_PER_PASSAGE iterators from [first, last),
///        uniformly at random without replacement.
///
/// The chosen iterators keep their relative order, because std::sample is
/// stable for forward iterators.
template <class FeatureIt>
std::vector<FeatureIt> sample(FeatureIt first, FeatureIt last)
{
    // Seeded once from std::random_device; one engine per template instantiation.
    static std::mt19937 engine{std::random_device{}()};

    // Collect the range as iterators so the result can point back into
    // the caller's container.
    std::vector<FeatureIt> candidates;
    while (first != last) {
        candidates.push_back(first);
        ++first;
    }

    std::vector<FeatureIt> chosen;
    std::sample(candidates.cbegin(), candidates.cend(),
                std::back_inserter(chosen), SAMPLES_PER_PASSAGE, engine);
    return chosen;
}

/// @brief Builds classifier inputs from sampled pairs of consecutive features.
///
/// For every passage, up to SAMPLES_PER_PASSAGE features that have a
/// predecessor are sampled. Each sampled feature yields one input made of
/// (passage key, previous feature, sampled feature).
///
/// @param passagesByKey  passages to sample from; not modified.
/// @return inputs for classifyOnYT; passages with fewer than two features
///         contribute nothing.
maps::mrc::sideview_classifier::Inputs
prepareInputs(const PassageKeyToFeaturesMap& passagesByKey)
{
    Inputs inputs;
    inputs.reserve(passagesByKey.size() * SAMPLES_PER_PASSAGE);

    for (const auto& [key, features] : passagesByKey) {
        // An input needs two consecutive features; this guard also makes
        // std::next(features.begin()) below well-defined.
        if (features.size() < 2) {
            continue;
        }
        // Sample only among features that have a predecessor. Otherwise a
        // sampled first feature would be discarded and the passage would
        // contribute fewer than SAMPLES_PER_PASSAGE inputs.
        for (auto it : sample(std::next(features.begin()), features.end())) {
            inputs.push_back({key, *std::prev(it), *it});
        }
    }
    return inputs;
}


/// @brief Decides the camera direction of a passage by a confidence-weighted vote.
///
/// SideView outputs add their confidence to the side score. Every other output
/// type adds its confidence to the front score.
/// @return Right if the side score is strictly greater, otherwise Front
///         (ties resolve to Front).
db::CameraDeviation estimateCameraDeviation(const Outputs& outputs)
{
    float sideScore = 0.f;
    float frontScore = 0.f;

    for (const auto& output : outputs) {
        const bool votesSide =
            output.type == maps::mrc::sideview::SideViewType::SideView;
        float& score = votesSide ? sideScore : frontScore;
        score += output.confidence;
    }

    if (frontScore < sideScore) {
        return db::CameraDeviation::Right;
    }
    return db::CameraDeviation::Front;
}

/// @brief Applies classifier verdicts to passages and saves the SIDE ones.
///
/// Outputs are grouped by passage key. Each passage gets the camera deviation
/// estimated from its outputs, and passages classified as Right are written
/// to the DB through saveFeatures().
///
/// @param pool           connection pool forwarded to saveFeatures().
/// @param passagesByKey  passages the outputs were computed for; their features
///                       get the estimated camera deviation assigned.
/// @param outputs        classifier outputs from YT.
/// @param skipCommit     dry-run flag forwarded to saveFeatures().
void saveOutputs(maps::pgpool3::Pool& pool,
                 PassageKeyToFeaturesMap& passagesByKey,
                 const maps::mrc::sideview_classifier::Outputs& outputs,
                 bool skipCommit)
{
    size_t featuresCounter = 0;
    size_t passagesCounter = 0;

    // Group outputs by passage. `outputs` is const, so elements are copied
    // (std::move on a const reference would copy anyway).
    OutputsToFeaturesMap outputsByKey;
    for (const auto& output : outputs) {
        outputsByKey[output.key].push_back(output);
    }

    for (const auto& [key, passageOutputs] : outputsByKey) {
        auto cameraDeviation = estimateCameraDeviation(passageOutputs);

        // Use find() instead of operator[]: an output key missing from
        // passagesByKey must not insert an empty passage, whose front()/back()
        // below would be undefined behavior.
        auto passageIt = passagesByKey.find(key);
        if (passageIt == passagesByKey.end() || passageIt->second.empty()) {
            INFO() << "Skip outputs for unknown or empty passage with sourceID = "
                   << key.sourceId;
            continue;
        }

        auto& features = passageIt->second;
        for (auto& feature : features) {
            feature.setCameraDeviation(cameraDeviation);
        }

        if (cameraDeviation == db::CameraDeviation::Right) {
            INFO() << "Set camera deviation RIGHT for " << features.size()
                    << " features with sourceID = " << key.sourceId
                    << ", featureID from " << features.front().id()
                    << " to " << features.back().id();
            saveFeatures(pool, features, skipCommit);
            featuresCounter += features.size();
            ++passagesCounter;
        }
    }

    INFO() << "SIDE passages found: " << passagesCounter;
    INFO() << "SIDE features found: " << featuresCounter;
}

/// @brief Stores @p featureId under LAST_FEATURE_ID in the metadata table.
///
/// main() reads this value at startup to pick the start feature id when
/// --start-id is not given. Nothing is written in dry-run mode.
/// @param pool        connection pool; a master writeable transaction is used.
/// @param featureId   id to store.
/// @param skipCommit  when true, the function does nothing.
void updateMetadata(maps::pgpool3::Pool& pool,
                    db::TId featureId,
                    bool skipCommit)
{
    if (skipCommit) {
        return;
    }

    INFO() << "Update metadata";
    auto writeTxn = pool.masterWriteableTransaction();
    db::MetadataGateway{*writeTxn}.upsertByKey(
        LAST_FEATURE_ID, std::to_string(featureId));
    writeTxn->commit();
}


} // namespace

int main(int argc, const char** argv) try
{
    NYT::Initialize(argc, argv);

    maps::cmdline::Parser parser(
        "This tool classifies camera direction on photos. "
        "Classifiers are run on YT. "
        "The results are saved back to the database."
    );
    auto syslog = parser.string("syslog-tag")
        .help("redirect log output to syslog with given tag");
    auto configPath = parser.string("config").help("path to configuration");
    auto startId = parser.num("start-id")
        .help("start feature id. If not specified, taken from db metadata");
    auto batchSize = parser.num("batch-size").help("features batch size")
        .defaultValue(FEATURES_BATCH_SIZE);
    auto dryRun = parser.flag("dry-run").help("Don't commit to database");

    parser.parse(argc, const_cast<char**>(argv));

    if (syslog.defined()) {
        maps::log8::setBackend(maps::log8::toSyslog(syslog));
    }

    auto mrcConfig = maps::mrc::common::templateConfigFromCmdPath(configPath);
    maps::wiki::common::PoolHolder poolHolder(mrcConfig.makePoolHolder());

    db::TId startFeatureId = START_FEATURE_ID;
    if (startId.defined()) {
        startFeatureId = startId;
    } else {
        auto txn = poolHolder.pool().slaveTransaction();
        auto savedId = db::MetadataGateway{*txn}.tryLoadByKey(LAST_FEATURE_ID);
        startFeatureId = savedId ? std::stoul(*savedId) : START_FEATURE_ID;
    }

    bool skipCommit = false;
    if (dryRun.defined()) {
        skipCommit = dryRun;
    }

    INFO() << "Start feature ID = " << startFeatureId;

    auto ids = loadFeatureIds(poolHolder.pool(), startFeatureId);

    for (const auto& batch : maps::common::makeBatches(ids, batchSize)) {
        db::TIds idsBatch{batch.begin(), batch.end()};

        auto passagesByKey = maps::mrc::sideview_classifier::loadPassages(
            poolHolder.pool(), idsBatch);
        INFO() << "Passages loaded: " << passagesByKey.size();

        auto inputs = prepareInputs(passagesByKey);

        auto outputs = maps::mrc::sideview_classifier::classifyOnYT(inputs, mrcConfig);

        saveOutputs(poolHolder.pool(), passagesByKey, outputs, skipCommit);

        updateMetadata(poolHolder.pool(), idsBatch.back(), skipCommit);
    }

    INFO() << "Done";
    return EXIT_SUCCESS;
} catch (const maps::Exception& e) {
    FATAL() << "Worker failed: " << e;
    return EXIT_FAILURE;
} catch (const yexception& e) {
    FATAL() << "Worker failed: " << e.what();
    if (e.BackTrace()) {
        FATAL() << e.BackTrace()->PrintToString();
    }
    return EXIT_FAILURE;
} catch (const std::exception& e) {
    FATAL() << "Worker failed: " << e.what();
    return EXIT_FAILURE;
}
