#include "yt.h"
#include "metadata_keys.h"

#include <maps/libs/chrono/include/time_point.h>
#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/feature.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/feature_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/metadata_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/object_in_photo_gateway.h>
#include <maps/libs/common/include/make_batches.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/pg_locks.h>
#include <maps/libs/log8/include/log8.h>
#include <yandex/maps/pgpool3utils/pg_advisory_mutex.h>
#include <mapreduce/yt/interface/client.h>

#include <chrono>
#include <cstdlib>
#include <thread>

namespace db = maps::mrc::db;

namespace {

// Maximum number of not-yet-classified features loaded from the DB per run.
constexpr std::size_t FEATURES_LIMIT{40'000};
// Number of results written back per master transaction in save().
constexpr std::size_t BATCH_SIZE{1'000};

// Copies each classifier field (orientation, quality, road probability,
// forbidden probability, size) that is set on `from` onto `to`.
// A field that is unset on `from` leaves the corresponding value on `to`
// untouched; nothing on `to` is ever cleared.
void copyClassifierFields(const db::Feature& from, db::Feature& to)
{
    if (from.hasOrientation()) {
        to.setOrientation(from.orientation());
    }

    if (from.hasQuality()) {
        to.setQuality(from.quality());
    }

    if (from.hasRoadProbability()) {
        to.setRoadProbability(from.roadProbability());
    }

    if (from.hasForbiddenProbability()) {
        to.setForbiddenProbability(from.forbiddenProbability());
    }

    if (from.hasSize()) {
        to.setSize(from.size());
    }
}

// Loads up to FEATURES_LIMIT features for which quality, road probability
// and forbidden probability are all NULL, i.e. features the classifiers
// have not produced results for yet. Reads through a slave transaction.
db::Features loadFeatures(maps::pgpool3::Pool& pool)
{
    auto txn = pool.slaveTransaction();
    INFO() << "Load features from db";

    return db::FeatureGateway{*txn}.load(
        db::table::Feature::quality.isNull() &&
            db::table::Feature::roadProbability.isNull() &&
            db::table::Feature::forbiddenProbability.isNull(),
        maps::sql_chemistry::limit(FEATURES_LIMIT));
}

// Writes YT processing results back to the database.
//
// Results are handled in batches of BATCH_SIZE. Each batch runs in its own
// master transaction, so batches committed before a failure stay committed.
// For each batch:
//   * the classifier fields are copied from the YT result onto the feature
//     rows freshly loaded from the DB, and those rows are updated;
//   * all privacy objects detected for the batch are inserted.
//
// featureById - processing results keyed by feature id.
// pool        - connection pool used to open the master transactions.
void save(const maps::mrc::image_analyzer::ProcessedFeatureById& featureById,
          maps::pgpool3::Pool& pool)
{
    INFO() << "Save " << featureById.size() << " features to db";
    for (auto batch : maps::common::makeBatches(featureById, BATCH_SIZE)) {
        auto txn = pool.masterWriteableTransaction();

        db::TIds featureIds;
        featureIds.reserve(BATCH_SIZE);
        db::ObjectsInPhoto objects;

        for (const auto& item : batch) {
            const auto& processed = item.second;
            featureIds.push_back(processed.feature.id());
            objects.insert(objects.end(),
                           processed.privacyObjects.begin(),
                           processed.privacyObjects.end());
        }

        // To avoid overwriting possible changes made during YT processing,
        // set the updated fields directly on the features loaded from the DB
        // instead of writing the YT copies back wholesale.
        db::FeatureGateway featureGtw{*txn};

        db::Features dbFeatures = featureGtw.loadByIds(featureIds);
        for (auto& dbFeature : dbFeatures) {
            // loadByIds() was queried with ids taken from featureById, so
            // every loaded id is present and at() does not throw here.
            const auto& ytFeature = featureById.at(dbFeature.id()).feature;
            copyClassifierFields(ytFeature, dbFeature);
        }

        featureGtw.update(dbFeatures, db::UpdateFeatureTxn::No);
        db::ObjectInPhotoGateway{*txn}.insert(objects);
        txn->commit();
    }
}

// Stores the current time, formatted as an SQL datetime, under the
// LAST_RUN_TIME key via MetadataGateway::upsertByKey, then commits.
void updateMetadata(maps::pgpool3::Pool& pool)
{
    INFO() << "Update metadata";
    auto txn = pool.masterWriteableTransaction();

    db::MetadataGateway metadataGateway{*txn};
    const auto now = maps::chrono::TimePoint::clock::now();
    metadataGateway.upsertByKey(
        maps::mrc::image_analyzer::LAST_RUN_TIME,
        maps::chrono::formatSqlDateTime(now));

    txn->commit();
}

} // namespace

int main(int argc, const char** argv) try
{
    NYT::Initialize(argc, argv);

    maps::cmdline::Parser parser(
        "This service runs image classifiers/detectors on every feature which"
        " is not yet processed in the database. Classifiers are run on YT."
        " The results are saved back to the database."
    );
    auto syslog = parser.string("syslog-tag")
        .help("redirect log output to syslog with given tag");

    auto configPath = parser.string("config").help("path to configuration");

    auto secretVersion = parser.string("secret-version")
            .help("version for secrets from yav.yandex-team.ru");

    auto useGpu = parser.flag("use-gpu").help("Use GPU in yt operations");

    auto infiniteLoop = parser.flag("infinite-loop")
            .help("the program will work constantly in infinite loop");

    auto dryRun = parser.flag("dry-run")
            .help("do not save changes to database or publish data to social");

    parser.parse(argc, const_cast<char**>(argv));

    if (syslog.defined()) {
        maps::log8::setBackend(maps::log8::toSyslog(syslog));
    }

    const auto mrcConfig =
        maps::mrc::common::templateConfigFromCmdPath(secretVersion, configPath);
    maps::wiki::common::PoolHolder poolHolder(mrcConfig.makePoolHolder());

    // Pause between cycles in --infinite-loop mode when a cycle found nothing
    // to do, failed, or (with --dry-run) wrote nothing back.
    constexpr auto RETRY_PERIOD = std::chrono::minutes(5);

    // Runs one processing cycle while holding the ImageAnalyzerYt advisory
    // lock. Returns the number of features loaded and sent to YT, or 0 when
    // another process holds the lock or no unprocessed features were found.
    auto runOnce =
        [&]() -> size_t {
            maps::pgp3utils::PgAdvisoryXactMutex mutex(
                poolHolder.pool(),
                static_cast<int64_t>(maps::mrc::common::LockId::ImageAnalyzerYt));
            if (!mutex.try_lock()) {
                INFO() << "Another process is ongoing";
                return 0;
            }

            db::Features features = loadFeatures(poolHolder.pool());
            INFO() << "Loaded " << features.size() << " features";
            if (features.empty()) {
                return 0;
            }

            auto results = maps::mrc::image_analyzer::processOnYT(
                features, mrcConfig,
                useGpu ? maps::mrc::image_analyzer::UseGpu::Yes
                    : maps::mrc::image_analyzer::UseGpu::No);

            if (!dryRun) {
                INFO() << "Saving results to database";
                save(results, poolHolder.pool());
            }

            return features.size();
        };

    do {
        bool shouldWait = false;

        try {
            const size_t processedFeatures = runOnce();

            if (!dryRun) {
                updateMetadata(poolHolder.pool());
            }

            // A dry run saves nothing, so the next cycle would load and
            // resubmit exactly the same features to YT right away. Pause in
            // that case too, so the loop does not hammer YT without a break.
            if (processedFeatures == 0 || dryRun) {
                shouldWait = true;
            }
        } catch (const maps::Exception& e) {
            ERROR() << "Run cycle failed: " << e;
            shouldWait = true;
        } catch (const yexception& e) {
            ERROR() << "Run cycle failed: " << e.what();
            if (e.BackTrace()) {
                ERROR() << e.BackTrace()->PrintToString();
            }
            shouldWait = true;
        } catch (const std::exception& e) {
            ERROR() << "Run cycle failed: " << e.what();
            shouldWait = true;
        }

        if (infiniteLoop && shouldWait) {
            INFO() << "Sleeping";
            std::this_thread::sleep_for(RETRY_PERIOD);
        }
    } while (infiniteLoop);

    return EXIT_SUCCESS;
} catch (const maps::Exception& e) {
    FATAL() << "Worker failed: " << e;
    return EXIT_FAILURE;
} catch (const std::exception& e) {
    FATAL() << "Worker failed: " << e.what();
    return EXIT_FAILURE;
}
