#include "context.h"

#include <maps/wikimap/mapspro/services/mrc/libs/common/include/algorithm/for_each_batch.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/algorithm/for_each_passage.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/threadpool_wrapper.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/utility.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/dataset.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/disqualified_source_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/feature_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/ride_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/takeout_data_erasure_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/track_point_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/ugc/gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/ugc_account_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/graph_matcher_adapter/include/feature_positioner.h>
#include <maps/wikimap/mapspro/services/mrc/libs/sensors_feature_positioner/include/sensors_feature_positioner_pool.h>

#include <maps/libs/log8/include/log8.h>
#include <maps/libs/sql_chemistry/include/batch_load.h>

namespace maps::mrc::feature_publisher {
namespace {

// Tells whether `dataset` is one of the datasets listed by
// db::ridePhotosDatasets().
bool isRidesDataset(db::Dataset dataset)
{
    const auto& rideDatasets = db::ridePhotosDatasets();
    return std::any_of(rideDatasets.begin(),
                       rideDatasets.end(),
                       [dataset](const auto& candidate) {
                           return candidate == dataset;
                       });
}

// Decides whether the classifier scores of `feature` allow automatic
// publication:
//  - Agents: only quality is checked. Public front-facing photos need the
//    strict threshold; every other agent photo needs the relaxed one.
//  - standalone-photo datasets: only the forbidden probability is checked;
//  - all other datasets: the forbidden probability AND the relaxed quality
//    threshold must both pass.
// Note: for Agents this reads feature.privacy(), so privacy must already be
// set on the feature.
bool areClassifiersOk(const db::Feature& feature)
{
    constexpr auto GOOD_QUALITY_THRESHOLD = .5;
    constexpr auto BAD_QUALITY_THRESHOLD = .05;
    constexpr auto FORBIDDEN_THRESHOLD = .5;

    const auto dataset = feature.dataset();
    if (dataset == db::Dataset::Agents) {
        // NOTE(review): forbiddenProbability is not consulted for Agents —
        // presumably intentional; confirm.
        const bool isPublicFrontPhoto =
            feature.cameraDeviation() == db::CameraDeviation::Front &&
            feature.privacy() == db::FeaturePrivacy::Public;
        const auto minQuality = isPublicFrontPhoto ? GOOD_QUALITY_THRESHOLD
                                                   : BAD_QUALITY_THRESHOLD;
        return feature.quality() >= minQuality;
    }

    const bool notForbidden =
        feature.forbiddenProbability() <= FORBIDDEN_THRESHOLD;
    if (db::isStandalonePhotosDataset(dataset)) {
        return notForbidden;
    }
    return notForbidden && feature.quality() >= BAD_QUALITY_THRESHOLD;
}

// Returns true when the timestamp of `feature` lies within at least one of
// `deletedIntervals`, bounds as interpreted by common::between.
bool overlap(const db::Feature& feature,
             const db::DeletedIntervals& deletedIntervals)
{
    const auto timestamp = feature.timestamp();
    return std::any_of(deletedIntervals.begin(),
                       deletedIntervals.end(),
                       [&timestamp](const auto& interval) {
                           return common::between(timestamp,
                                                  interval.startedAt(),
                                                  interval.endedAt());
                       });
}

}  // namespace

// Builds the publishing context:
//  - cfg provides the database pool holder (LONG_READ_DB_ID /
//    LONG_READ_POOL_ID) and the MDS client that loadImage() reads from;
//  - roadGraphPath and pedestrianGraphPath initialize the road and
//    pedestrian graph matchers used by evalPositionForPassage();
//  - regionPrivacy and cameraDeviationClassifier are taken by value and
//    moved into the context.
Context::Context(const common::Config& cfg,
                 const std::string& roadGraphPath,
                 const std::string& pedestrianGraphPath,
                 privacy::RegionPrivacyPtr regionPrivacy,
                 ICameraDeviationClassifierHolder cameraDeviationClassifier)
    : poolHolder_(cfg.makePoolHolder(common::LONG_READ_DB_ID,
                                     common::LONG_READ_POOL_ID))
    , mds_(cfg.makeMdsClient())
    , roadMatcher_(roadGraphPath)
    , pedestrianMatcher_(pedestrianGraphPath)
    , regionPrivacy_(std::move(regionPrivacy))
    , cameraDeviationClassifier_(std::move(cameraDeviationClassifier))
{
}

// Replaces the camera deviation classifier used by
// evalCameraDeviationForPassage().
// NOTE(review): there is no locking here; presumably this must not be called
// while processPhotos() is running passages on its worker threads — confirm.
void Context::setCameraDeviationClassifier(
    ICameraDeviationClassifierHolder cameraDeviationClassifier)
{
    cameraDeviationClassifier_ = std::move(cameraDeviationClassifier);
}

db::TrackPoints Context::loadTrackPoints(const std::string& sourceId,
                                         chrono::TimePoint startTime,
                                         chrono::TimePoint endTime)
{
    // Track points of `sourceId` whose timestamp is between startTime and
    // endTime (presumably inclusive, as SQL BETWEEN), read in a slave
    // transaction.
    auto txn = pool().slaveTransaction();
    db::TrackPointGateway gateway{*txn};
    return gateway.load(
        db::table::TrackPoint::sourceId.equals(sourceId) &&
        db::table::TrackPoint::timestamp.between(startTime, endTime));
}

db::TrackPoints Context::loadAssignmentTrackPoints(db::TId assignmentId)
{
    // All track points recorded for `assignmentId`, read in a slave
    // transaction.
    auto txn = pool().slaveTransaction();
    return db::TrackPointGateway{*txn}.load(
        db::table::TrackPoint::assignmentId.equals(assignmentId));
}

common::Blob Context::loadImage(const db::Feature& feature)
{
    // The photo bytes are fetched from MDS by the feature's mdsKey.
    const auto& key = feature.mdsKey();
    return mds_.get(key);
}

std::vector<db::CameraDeviation> Context::loadAllowedCameraDeviations(
    const db::Feature& feature)
{
    // Returns the camera deviations a feature may be classified into:
    //  - Agents: the deviations of the UGC task that owns the feature's
    //    assignment (assignmentId() must be set; `.value()` presumably
    //    throws otherwise);
    //  - ride datasets on the road graph: Front and Right;
    //  - anything else: an empty list.
    const auto dataset = feature.dataset();

    if (dataset == db::Dataset::Agents) {
        auto txn = pool().slaveTransaction();
        auto assignment = db::ugc::AssignmentGateway{*txn}.loadById(
            feature.assignmentId().value());
        auto task = db::ugc::TaskGateway{*txn}.loadById(assignment.taskId());
        return task.cameraDeviations();
    }

    const bool isRoadRide = isRidesDataset(dataset) &&
                            feature.graph() == db::GraphType::Road;
    if (!isRoadRide) {
        return {};
    }
    return {db::CameraDeviation::Front, db::CameraDeviation::Right};
}

void Context::evalSize(db::Features::iterator first,
                       db::Features::iterator last)
{
    // Sets the image size of every feature in [first, last) that has none yet.
    // It does this by loading the image and decoding it. A cv::Exception is
    // logged and leaves that feature without a size; any other exception
    // propagates to the caller.
    for (auto it = first; it != last; ++it) {
        auto& feature = *it;
        if (feature.hasSize()) {
            continue;
        }
        try {
            auto blob = loadImage(feature);
            auto image = common::decodeImage(blob);
            auto dims = image.size();
            feature.setSize(dims.width, dims.height);
        }
        catch (const cv::Exception& e) {
            WARN() << "evalSize(" << feature.id() << "): " << e.what();
        }
    }
}

void Context::evalUserSetingsForPassage(db::Features::iterator first,
                                        db::Features::iterator last)
{
    // Applies per-user settings to the features of [first, last).
    //
    // Features are grouped by (dataset, userId, clientRideId). Groups without
    // a user id are skipped. For every other group, the following are loaded
    // and applied to each feature of the group:
    //  - the user's UGC account: showAuthorship is set when the account
    //    enables it;
    //  - the user's most recent takeout data erasure request: features taken
    //    at or before its requestedAt are marked GDPR-deleted;
    //  - for Rides features with a source id: the deleted intervals of the
    //    same user, source and client ride that intersect the group's time
    //    span. Features inside any such interval are marked deleted by user.
    //    NOTE(review): the span is taken from the first and last feature, so
    //    this assumes each group is timestamp-ordered (processPhotos sorts by
    //    (sourceId, timestamp)).
    forEachEqualRange(
        first,
        last,
        common::equalFn(common::makeTupleFn(&db::Feature::dataset,
                                            &db::Feature::userId,
                                            &db::Feature::clientRideId)),
        [this](auto first, auto last) {
            if (!first->userId()) {
                return;
            }
            const auto& userId = first->userId().value();
            // One read transaction serves every lookup of the group, as in
            // loadAllowedCameraDeviations, instead of checking out a fresh
            // slave transaction for each query.
            auto txn = pool().slaveTransaction();
            auto account = db::UgcAccountGateway{*txn}.tryLoadOne(
                db::table::UgcAccount::userId == userId);
            // Only the latest erasure request matters.
            auto gdpr = db::TakeoutDataErasureGateway{*txn}.tryLoadOne(
                db::table::TakeoutDataErasure::userId == userId,
                orderBy(db::table::TakeoutDataErasure::requestedAt)
                    .desc()
                    .limit(1));
            auto deletedIntervals = db::DeletedIntervals{};
            if (first->dataset() == db::Dataset::Rides &&
                first->sourceId() != db::feature::NO_SOURCE_ID) {
                auto allOf = sql_chemistry::FiltersCollection{
                    sql_chemistry::op::Logical::And};
                allOf.add(db::table::DeletedInterval::userId == userId);
                allOf.add(db::table::DeletedInterval::sourceId ==
                          first->sourceId());
                if (first->clientRideId()) {
                    allOf.add(db::table::DeletedInterval::clientRideId ==
                              *first->clientRideId());
                }
                else {
                    allOf.add(
                        db::table::DeletedInterval::clientRideId.isNull());
                }
                // Keep only intervals intersecting the group's time span.
                allOf.add(db::table::DeletedInterval::startedAt <=
                          std::prev(last)->timestamp());
                allOf.add(db::table::DeletedInterval::endedAt >=
                          first->timestamp());
                deletedIntervals =
                    db::DeletedIntervalGateway{*txn}.load(allOf);
            }
            std::for_each(first, last, [&](auto& feature) {
                if (account && account->showAuthorship()) {
                    feature.setShowAuthorship(true);
                }
                if (gdpr && feature.timestamp() <= gdpr->requestedAt()) {
                    feature.setGdprDeleted(true);
                }
                if (overlap(feature, deletedIntervals)) {
                    feature.setDeletedByUser(true);
                }
            });
        });
}

void Context::evalPositionForPassage(db::Features::iterator first,
                                     db::Features::iterator last)
{
    // Positions features group by group, grouping by (dataset, assignmentId):
    //  - Agents: FeaturePositioner on the road graph with the track classified
    //    as a vehicle, then SensorsFeaturePositionerPool over the track
    //    points of the assignment;
    //  - ride datasets: FeaturePositioner on the road and pedestrian graphs,
    //    then SensorsFeaturePositionerPool over the source's track in a time
    //    window widened so that it covers at least MIN_TRACK;
    //  - groups without a source id, or of other datasets, are left untouched.
    auto positionGroup = [this](auto groupBegin, auto groupEnd) {
        if (groupBegin->sourceId() == db::feature::NO_SOURCE_ID) {
            return;
        }

        if (groupBegin->dataset() == db::Dataset::Agents) {
            adapters::FeaturePositioner{
                {{db::GraphType::Road, &roadMatcher_}},
                loadTrackPointsFn(),
                adapters::classifyAs(track_classifier::TrackType::Vehicle)}(
                groupBegin, groupEnd);
            const auto assignmentId = groupBegin->assignmentId().value();
            sensors_feature_positioner::SensorsFeaturePositionerPool{
                mds_,
                pool(),
                roadMatcher_,
                assignmentId,
                loadAssignmentTrackPoints(assignmentId)}
                .applyPositionIfPossible(groupBegin, groupEnd);
            return;
        }

        if (!isRidesDataset(groupBegin->dataset())) {
            return;
        }

        adapters::FeaturePositioner{
            {{db::GraphType::Road, &roadMatcher_},
             {db::GraphType::Pedestrian, &pedestrianMatcher_}},
            loadTrackPointsFn()}(groupBegin, groupEnd);

        // Short groups get an equal margin on both sides so the requested
        // track spans at least MIN_TRACK; long groups get no margin.
        const auto MIN_TRACK = std::chrono::minutes(40);
        const auto startTime = groupBegin->timestamp();
        const auto endTime = std::prev(groupEnd)->timestamp();
        const auto shortfall = MIN_TRACK - (endTime - startTime);
        const auto margin = std::max(shortfall / 2, {});
        sensors_feature_positioner::SensorsFeaturePositionerPool{
            mds_,
            pool(),
            roadMatcher_,
            groupBegin->sourceId(),
            startTime - margin,
            endTime + margin}
            .applyPositionIfPossible(groupBegin, groupEnd);
    };

    forEachEqualRange(
        first,
        last,
        common::equalFn(common::makeTupleFn(&db::Feature::dataset,
                                            &db::Feature::assignmentId)),
        positionGroup);
}

void Context::evalCameraDeviationForPassage(db::Features::iterator first,
                                            db::Features::iterator last)
{
    // Features are grouped by (dataset, assignmentId, graph, hasSize). A group
    // is only considered if it has a source id, known image sizes and a
    // non-empty list from loadAllowedCameraDeviations(). For such a group the
    // classifier yields one deviation, which is stored on every feature of
    // the group if it is in that list. A cv::Exception is logged with the
    // group's id range and the group is left unchanged.
    forEachEqualRange(
        first,
        last,
        common::equalFn(common::makeTupleFn(&db::Feature::dataset,
                                            &db::Feature::assignmentId,
                                            &db::Feature::graph,
                                            &db::Feature::hasSize)),
        [this](auto groupBegin, auto groupEnd) {
            try {
                const bool classifiable =
                    groupBegin->sourceId() != db::feature::NO_SOURCE_ID &&
                    groupBegin->hasSize();
                if (!classifiable) {
                    return;
                }
                const auto allowed = loadAllowedCameraDeviations(*groupBegin);
                if (allowed.empty()) {
                    return;
                }
                const auto deviation =
                    cameraDeviationClassifier_->evalForPassage(
                        groupBegin, groupEnd, [this](const auto& feature) {
                            ASSERT(feature.hasSize());
                            return loadImage(feature);
                        });
                const bool isAllowed =
                    std::find(allowed.begin(), allowed.end(), deviation) !=
                    allowed.end();
                if (!isAllowed) {
                    return;
                }
                for (auto it = groupBegin; it != groupEnd; ++it) {
                    it->setCameraDeviation(deviation);
                }
            }
            catch (const cv::Exception& e) {
                WARN() << "evalCameraDeviationForPassage(" << groupBegin->id()
                       << "-" << std::prev(groupEnd)->id() << "): "
                       << e.what();
            }
        });
}

db::FeaturePrivacy Context::evalFeatureRegionPrivacy(const db::Feature& feature) const
{
    // A feature without a position has no region to look up and is treated
    // as public here; otherwise regionPrivacy_ decides from its geodetic
    // position.
    if (!feature.hasPos()) {
        return db::FeaturePrivacy::Public;
    }
    return regionPrivacy_->evalFeaturePrivacy(feature.geodeticPos());
}

db::FeaturePrivacy Context::evalFeaturePrivacy(const db::Feature& feature) const
{
    // The effective privacy is the stricter of the region-based privacy and
    // the privacy db::evalPrivacy derives from the feature's dataset.
    const auto byRegion = evalFeatureRegionPrivacy(feature);
    const auto byDataset = db::evalPrivacy(feature.dataset());
    return db::selectStricterPrivacy(byRegion, byDataset);
}

// Returns true when `sourceId` is currently barred from publishing.
// Only the DisablePublishing disqualification sorted first by endedAt DESC is
// consulted. It is in force if it has no endedAt (open-ended) or if
// now <= endedAt.
// NOTE(review): relies on NULL endedAt sorting first under DESC (PostgreSQL
// default) so that open-ended disqualifications win — confirm for the DB.
bool isPublishingDisabled(sql_chemistry::Transaction& txn,
                          const std::string& sourceId, chrono::TimePoint now)
{
    auto latest = db::DisqualifiedSourceGateway{txn}.load(
        db::table::DisqualifiedSource::sourceId == sourceId &&
            db::table::DisqualifiedSource::disqType ==
                db::DisqType::DisablePublishing,
        orderBy(db::table::DisqualifiedSource::endedAt).desc().limit(1));
    if (latest.empty()) {
        return false;
    }
    const auto& endedAt = latest.front().endedAt();
    return !endedAt || now <= *endedAt;
}

void Context::publishPassage(db::Features::iterator first,
                             db::Features::iterator last)
{
    // Finalizes a non-empty passage. Every feature gets its privacy, the
    // automatic "should be published" flag and processedAt = now. A feature
    // is publishable only if all of the following hold:
    //  - the source is not disqualified;
    //  - it has a position, a heading and a size;
    //  - its classifier scores pass.
    ASSERT(first != last);
    const auto now = chrono::TimePoint::clock::now();
    // Source-less passages are never checked for disqualification.
    const bool disabled = first->sourceId() != db::feature::NO_SOURCE_ID &&
                          isPublishingDisabled(*pool().slaveTransaction(),
                                               first->sourceId(),
                                               now);
    for (auto it = first; it != last; ++it) {
        auto& feature = *it;
        // Privacy is set first: areClassifiersOk reads it for Agents.
        feature.setPrivacy(evalFeaturePrivacy(feature));
        const bool toPublish = !disabled && feature.hasPos() &&
                               feature.hasHeading() && feature.hasSize() &&
                               areClassifiersOk(feature);
        feature.setAutomaticShouldBePublished(toPublish).setProcessedAt(now);
    }
}

void Context::processPhotos(db::Features features)
{
    // Evaluates, publishes and saves `features`.
    //
    // Features are sorted by (sourceId, timestamp) and split into passages by
    // common::forEachPassage. Each passage runs as a separate task on a pool
    // of 4 threads:
    //  - passages with a source id first get size, user settings, position
    //    and camera deviation evaluated;
    //  - every passage then goes through publishPassage().
    // After the pool is drained, checkExceptions() is called. If it throws
    // (presumably on a task failure), nothing is written. Otherwise all
    // features are updated in batches of 2000, each batch in its own
    // committed master transaction.
    std::sort(features.begin(),
              features.end(),
              common::lessFn(common::makeTupleFn(&db::Feature::sourceId,
                                                 &db::Feature::timestamp)));

    const auto processPassage = [this](auto first, auto last) {
        if (first->sourceId() != db::feature::NO_SOURCE_ID) {
            evalSize(first, last);
            evalUserSetingsForPassage(first, last);
            evalPositionForPassage(first, last);
            evalCameraDeviationForPassage(first, last);
        }
        publishPassage(first, last);
        INFO() << "eval passage of " << std::distance(first, last)
               << " photos";
    };

    auto threadPool = common::ThreadpoolWrapper{4u /* threadsNumber */};
    common::forEachPassage(
        features.begin(),
        features.end(),
        [&threadPool, &processPassage](auto first, auto last) {
            threadPool->add([processPassage, first, last] {
                processPassage(first, last);
            });
        });
    threadPool->drain();
    threadPool.checkExceptions();

    common::forEachBatch(
        features, 2000u, [this](auto batchBegin, auto batchEnd) {
            auto txn = pool().masterWriteableTransaction();
            db::FeatureGateway{*txn}.update({batchBegin, batchEnd},
                                            db::UpdateFeatureTxn::Yes);
            txn->commit();
        });
}

size_t Context::processClassifiedPhotos()
{
    // Processes every feature that is not processed yet (processedAt IS NULL)
    // but already has quality, roadProbability and forbiddenProbability set.
    // Features are loaded with sql_chemistry::BatchLoad (100000 is presumably
    // the batch size), and each batch is handed to processPhotos(). Returns
    // the total number of features processed.
    //
    // The counter uses the function's return type (size_t) so that adding
    // features.size() cannot be truncated to unsigned int.
    size_t result = 0;
    auto batch = sql_chemistry::BatchLoad<db::table::Feature>{
        100000u,
        db::table::Feature::processedAt.isNull() &&
            db::table::Feature::quality.isNotNull() &&
            db::table::Feature::roadProbability.isNotNull() &&
            db::table::Feature::forbiddenProbability.isNotNull()};
    while (batch.next(*pool().slaveTransaction())) {
        auto features = db::Features{batch.begin(), batch.end()};
        result += features.size();
        processPhotos(std::move(features));
        INFO() << "processed " << result << " features";
    }
    return result;
}

}  // namespace maps::mrc::feature_publisher
