#include "importer.h"
#include "common.h"
#include "track_utils.h"
#include "yt_track_mapper.h"

#include "maps/wikimap/mapspro/services/mrc/long_tasks/import_taxi/lib/video_events_feed.h"
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/algorithm/parallel_for_each.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/algorithm/retry.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/mds_path.h>
#include <maps/wikimap/mapspro/services/mrc/libs/config/include/config.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/feature_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/import_config_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/import_taxi_event_config_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/track_point_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/video_gateway.h>

#include <maps/libs/chrono/include/time_point.h>
#include <maps/libs/log8/include/log8.h>

#include <chrono>
#include <exception>
#include <unordered_set>

namespace maps::mrc::import_taxi {

namespace {

// Exception thrown by Importer::loadWithRetry() when the server answers with
// a status >= 500; it is the exception type handed to
// common::retryOnException there.
class HttpStatus500 : public Exception {
public:
    // Reuse the constructors of the base Exception as is.
    using Exception::Exception;
};

// Returns the latest actualizationDate among the edge coverages whose
// cameraDeviation is Front and whose coverageFraction exceeds 0.95.
// A default-constructed TimePoint is returned when no coverage qualifies.
chrono::TimePoint getActualizationDate(const fb::TEdge& edge)
{
    constexpr float MIN_COVERAGE_FRACTION = 0.95f;

    chrono::TimePoint latest{};
    for (const auto& item : edge.coverages) {
        if (item.cameraDeviation != db::CameraDeviation::Front) {
            continue;
        }
        // Written as a negated '>' so that the comparison semantics stay
        // exactly the same for every floating point input.
        if (!(item.coverageFraction > MIN_COVERAGE_FRACTION)) {
            continue;
        }
        if (latest < item.actualizationDate) {
            latest = item.actualizationDate;
        }
    }
    return latest;
}

// Encodes the frame as JPEG and returns the encoded bytes as a blob.
// Throws (via REQUIRE) when cv::imencode reports that encoding failed, so
// that an empty or partial image is never passed on to the caller.
common::Blob encodeImage(const cv::Mat& frame)
{
    common::Bytes buf;
    // cv::imencode returns false when the image could not be encoded.
    const bool encoded = cv::imencode(".jpg", frame, buf);
    REQUIRE(encoded, "Failed to encode frame as jpg");
    return common::Blob(buf.begin(), buf.end());
}

} // namespace

// Constructs the importer.
//
// pool            - database connection pool, stored by reference
// mdsClient       - MDS client used to upload videos and photos
// videoEventsFeed - source of video events; may be replaced later via
//                   setEventsFeed()
// trackProvider   - provider of tracks for video events, must not be null
// graphFolder     - folder from which the graph matcher, "road_graph.fb"
//                   and "graph_coverage.fb" are loaded
// geoIdPath       - path passed to privacy::makeGeoIdProvider
// dryRun          - when true, nothing is uploaded to MDS and database
//                   changes are not committed
Importer::Importer(
    pgpool3::Pool& pool,
    mds::Mds mdsClient,
    std::shared_ptr<VideoEventsFeed> videoEventsFeed,
    std::unique_ptr<ITrackProvider> trackProvider,
    const std::string& graphFolder,
    const std::string& geoIdPath,
    bool dryRun
)
    : pool_(pool)
    , mdsClient_(std::move(mdsClient))
    // The feed is a by-value sink parameter: move it like the other sinks
    // instead of paying for an extra shared_ptr copy.
    , videoEventsFeed_(std::move(videoEventsFeed))
    , trackProvider_(std::move(trackProvider))
    , graphMatcher_(
        graphFolder,
        [](const std::string& path) {
            return std::make_shared<adapters::CompactGraphMatcherAdapter>(path);
        })
    , graphReader_(
        graphFolder,
        [](const std::string& path) {
            return std::make_shared<road_graph::Graph>(path + "/road_graph.fb");
        })
    , graphCoverageReader_(
        graphFolder,
        [](const std::string& path) {
            return std::make_shared<fb::GraphReader>(path + "/graph_coverage.fb");
        })
    , geoIdProvider_(
        geoIdPath,
        [](const std::string& path) {
            return privacy::makeGeoIdProvider(path);
        })
    , dryRun_(dryRun)
{
    ASSERT(trackProvider_);
}

void Importer::import()
{
    REQUIRE(videoEventsFeed_, "Video events feed is not set");
    refreshDatasets();
    refreshThresholds();
    refreshTaxiEventsConfig();

    auto eventsBatch = readEventsBatch();
    INFO() << "Process events batch, size = " << eventsBatch.size();
    trackProvider_->clearTracks();
    trackProvider_->preloadTracks(eventsBatch);

    common::parallelForEach(eventsBatch.begin(),
                            eventsBatch.end(),
                            [&](auto& /*guard*/, auto&& videoEvent) {
                                MAPS_LOG_THREAD_PREFIX_APPEND(std::to_string(videoEvent.id));
                                importEvent(std::move(videoEvent));
                            });

    if (!dryRun_) {
        updateMetadata(pool_, videoEventsFeed_->maxProcessedId());
    }
}

// Reads up to 10000 events from the feed.
// Events whose type is missing from taxiEventsConfig_ and events whose url
// was already seen in this batch are skipped. Reading stops at the first
// event that is less than one day old.
// NOTE(review): the event that triggers the stop has already been taken from
// the feed with next() - confirm the feed does not count it as processed.
VideoEvents Importer::readEventsBatch()
{
    static constexpr size_t BATCH_LIMIT = 10000;
    static constexpr auto MIN_EVENT_AGE = std::chrono::days{1};

    VideoEvents batch;
    std::unordered_set<std::string> seenUrls;

    while (videoEventsFeed_->hasNext() && batch.size() < BATCH_LIMIT) {
        auto event = videoEventsFeed_->next();

        if (taxiEventsConfig_.count(event.eventType) == 0) {
            INFO() << "Skip event " << event.id << " with type " << event.eventType;
            continue;
        }

        // insert() reports whether the url is new, so one lookup is enough
        const bool isNewUrl = seenUrls.insert(event.url).second;
        if (!isNewUrl) {
            INFO() << "Skip event " << event.id << " with duplicate url";
            continue;
        }

        const auto now = chrono::TimePoint::clock::now();
        if (event.eventTime + MIN_EVENT_AGE > now) {
            break;
        }
        batch.push_back(std::move(event));
    }
    return batch;
}

// Imports a single video event.
// The event is skipped when it has no track or when none of its matched
// track segments needs new coverage. Otherwise the video, its photos, the
// frame-to-video links and the track points are saved within one master
// transaction, which is committed only when not in dry-run mode.
// Any std::exception raised while saving is logged and the event is dropped
// without a commit; a unique violation is reported as an already imported
// event.
void Importer::importEvent(VideoEvent&& videoEvent)
{
    INFO() << "Importing video event " << videoEvent.id;

    videoEvent.track = trackProvider_->getTrack(videoEvent);
    if (videoEvent.track.empty()) {
        INFO() << "Video event " << videoEvent.id << " doesn't have track";
        return;
    }

    auto segments = graphMatcher_.get()->match(videoEvent.track);
    const auto needsCoverage = [&](const auto& segment) {
        return segmentNeedsCoverage(segment, videoEvent.eventTime);
    };
    if (std::none_of(segments.begin(), segments.end(), needsCoverage)) {
        INFO() << "Video event " << videoEvent.id << " doesn't add new coverage";
        return;
    }

    auto txn = pool_.masterWriteableTransaction();

    try {
        video::VideoFramesReader reader{};
        reader.open(videoEvent.url);
        const auto duration = reader.duration();

        // The video starts secondsBefore() seconds ahead of the event time
        const auto& eventConfig = taxiEventsConfig_.at(videoEvent.eventType);
        const auto videoStartTime = videoEvent.eventTime
            - std::chrono::seconds(eventConfig.secondsBefore());

        const auto savedVideo = saveVideo(*txn, videoEvent, videoStartTime, duration);
        const auto features = savePhotos(*txn, reader, videoEvent, videoStartTime, segments);
        saveFrameToVideo(*txn, savedVideo, features, videoStartTime);
        saveTrackPoints(*txn, videoEvent.track);
    } catch (const sql_chemistry::UniqueViolationError&) {
        WARN() << "Event " << videoEvent.id << " is already in database";
        return;
    } catch (const std::exception& e) {
        ERROR() << "Failed to import event: " << e.what();
        return;
    }

    if (dryRun_) {
        INFO() << "Skip commit in dryRun mode";
        return;
    }

    txn->commit();
    INFO() << "Video " << videoEvent.id << " imported";
}

// Refreshes the graph matcher, the road graph, the graph coverage and the
// geo id provider, then requires that the matcher, the graph and the
// coverage all report the same version.
void Importer::refreshDatasets()
{
    graphMatcher_.refresh();
    graphReader_.refresh();
    graphCoverageReader_.refresh();
    geoIdProvider_.refresh();

    const auto matcherVersion = graphMatcher_.get()->graphVersion();
    const auto graphVersion = graphReader_.get()->version();
    const auto coverageVersion = graphCoverageReader_.get()->version();

    const bool versionsAgree =
        matcherVersion == graphVersion && graphVersion == coverageVersion;
    REQUIRE(versionsAgree,
        "Graph versions mismatch"
            << ": graph matcher: " << matcherVersion
            << ", graph: " << graphVersion
            << ", coverage: " << coverageVersion);
}


void Importer::refreshThresholds()
{
    auto txn = pool_.slaveTransaction();
    auto items = db::ImportConfigGateway{*txn}
        .load(db::table::ImportConfig::dataset == db::Dataset::TaxiSignalQ2);

    thresholdsConfig_.clear();
    for (const auto& item : items) {
        thresholdsConfig_.emplace(
            std::make_pair<TGeoId, TFc>(item.geoId(), item.fc()),
            std::chrono::days{item.thresholdDays()});
    }
}

// Reloads taxiEventsConfig_ from the database: every loaded config item is
// stored under its event type.
void Importer::refreshTaxiEventsConfig()
{
    auto txn = pool_.slaveTransaction();
    auto items = db::ImportTaxiEventConfigGateway{*txn}.load();

    taxiEventsConfig_.clear();
    // Iterate by non-const reference: std::move on a const reference would
    // silently copy the item instead of moving it into the map.
    for (auto& item : items) {
        // The key is taken before the item itself is moved from
        auto eventType = item.eventType();
        taxiEventsConfig_.emplace(std::move(eventType), std::move(item));
    }
}

// Replaces the feed from which video events are read.
void Importer::setEventsFeed(std::shared_ptr<VideoEventsFeed> videoEventsFeed)
{
    // The parameter is taken by value, so move it in instead of making one
    // more shared_ptr copy.
    videoEventsFeed_ = std::move(videoEventsFeed);
}

// Tells whether a video taken at videoTimestamp should be imported for the
// given track segment.
// Returns false when the segment is not matched to a graph edge, when the
// video is older than the current actualization date of the edge, or when
// no threshold is configured for any geo id of the segment's bounding box
// together with the edge category. Otherwise returns true if the time
// passed since the actualization date exceeds the configured threshold.
bool Importer::segmentNeedsCoverage(
    const adapters::TrackSegment& segment,
    chrono::TimePoint videoTimestamp) const
{
    // Segment was not matched on graph
    if (!segment.edgeId) {
        return false;
    }

    auto fc = graphReader_.get()->edgeData(*segment.edgeId).category();

    // Stays default-constructed when the edge has no coverage record
    chrono::TimePoint coveredAt{};
    if (auto edge = graphCoverageReader_.get()->edgeById(segment.edgeId->value())) {
        coveredAt = getActualizationDate(*edge);
    }

    // Video would not add newer coverage
    if (videoTimestamp < coveredAt) {
        return false;
    }

    // geoIds are in area ascending order
    auto geoIds = geoIdProvider_.get()->load(segment.segment.boundingBox());
    for (auto geoId : geoIds) {
        const auto found = thresholdsConfig_.find({geoId, fc});
        if (found == thresholdsConfig_.end()) {
            continue;
        }
        // The first geo id that has a configured threshold decides
        return found->second < chrono::TimePoint::clock::now() - coveredAt;
    }

    // No threshold is configured for this segment
    return false;
}


// Downloads the body of the given url with an HTTP GET request.
// The request is run through common::retryOnException<HttpStatus500> with
// at most 3 attempts, an initial timeout of 1 second and a timeout backoff
// of 2. A status >= 500 raises HttpStatus500; any other status different
// from 200 fails the REQUIRE check.
std::string Importer::loadWithRetry(const std::string& url)
{
    const auto fetchOnce = [&]() {
        http::Request request(httpClient_, http::GET, url);
        auto reply = request.perform();
        if (reply.status() >= 500) {
            throw HttpStatus500() << url << ", status: " << reply.status();
        }
        REQUIRE(reply.status() == 200,
            "Unexpected server response [" << url << "]:\n" << reply.readBody());
        return reply;
    };

    auto response = common::retryOnException<HttpStatus500>(
        common::RetryPolicy()
            .setInitialTimeout(std::chrono::seconds(1))
            .setMaxAttempts(3)
            .setTimeoutBackoff(2),
        fetchOnce);

    return response.readBody();
}


// Inserts a video record for the event, downloads the video from the event
// url and, unless in dry-run mode, uploads it to MDS and stores the
// resulting key in the record. The record is updated in the database in
// either case and returned.
// NOTE(review): the event hash is a std::hash value stored with the record;
// std::hash is only guaranteed to be consistent within one program
// execution - confirm that this is acceptable for a persisted value.
db::Video Importer::saveVideo(
    pqxx::transaction_base& txn,
    const VideoEvent& videoEvent,
    chrono::TimePoint videoStartTime,
    std::chrono::milliseconds duration)
{
    INFO() << "Saving video, duration = " << duration.count() << "ms";

    const auto urlHash = std::hash<std::string>{}(videoEvent.url);

    db::Video video(
        db::Dataset::TaxiSignalQ2,
        videoEvent.sourceId,
        videoStartTime,
        duration.count() / 1000.,
        mds::Key{"", ""});
    video.setEventId(videoEvent.id);
    video.setEventHash(static_cast<std::int64_t>(urlHash));

    // The record is inserted first: its id is part of the MDS path
    db::VideoGateway{txn}.insert(video);

    auto content = loadWithRetry(videoEvent.url);
    auto path = common::makeMdsPath(
        common::MdsObjectSource::Imported,
        common::MdsObjectType::Video,
        video.id());

    if (!dryRun_) {
        auto uploaded = mdsClient_.post(path, content);
        video.setMdsKey(uploaded.key());
    }

    db::VideoGateway{txn}.update(video);
    return video;
}


// Picks track points (at least 1000 ms and 5 meters apart, as passed to
// pickTrackPoints) and saves one photo for each of them: the first frame of
// the video whose timestamp is not earlier than the point's timestamp.
// Points that precede the start of the video are ignored; processing stops
// when the video runs out of frames. The saved features are updated in the
// database and returned.
db::Features Importer::savePhotos(
    pqxx::transaction_base& txn,
    video::VideoFramesReader& framesReader,
    const VideoEvent& videoEvent,
    chrono::TimePoint videoStartTime,
    const adapters::TrackSegments& trackSegments)
{
    constexpr std::chrono::milliseconds MIN_PHOTOS_INTERVAL{1000};
    constexpr double MIN_PHOTOS_DISTANCE_METERS{5.0};

    // Get points at which to take photos
    auto photoPoints = pickTrackPoints(
        trackSegments, MIN_PHOTOS_INTERVAL, MIN_PHOTOS_DISTANCE_METERS);

    // Fast-forward to the start of the video
    auto pointIt = photoPoints.begin();
    while (pointIt != photoPoints.end() && pointIt->timestamp() < videoStartTime) {
        ++pointIt;
    }

    db::Features features;

    // Read frames from desired timestamps
    while (pointIt != photoPoints.end()) {
        // Every point consumes at least one new frame from the reader
        std::optional<video::Frame> frame = framesReader.readFrame();
        while (frame.has_value()
            && videoStartTime + frame->timeFromStart < pointIt->timestamp()) {
            frame = framesReader.readFrame();
        }

        // The video has no more frames
        if (!frame) {
            break;
        }

        const auto frameTime = videoStartTime + frame->timeFromStart;
        features.push_back(
            savePhoto(txn, frame->frame, videoEvent.sourceId, frameTime, *pointIt));
        ++pointIt;
    }

    db::FeatureGateway{txn}.update(features, db::UpdateFeatureTxn::No);
    INFO() << "Photos saved: " << features.size();
    return features;
}

// Creates a feature for the given frame at the given track point, inserts it
// into the database and, unless in dry-run mode, uploads the JPEG-encoded
// frame to MDS and sets the resulting key on the returned feature.
// The MDS key is only set on the returned object here, not written to the
// database by this function.
db::Feature Importer::savePhoto(
    pqxx::transaction_base& txn,
    const cv::Mat& frame,
    const std::string& sourceId,
    chrono::TimePoint timestamp,
    const db::TrackPoint& location)
{
    INFO() << "Save photo, ts = " << chrono::formatSqlDateTime(timestamp);

    auto feature = sql_chemistry::GatewayAccess<db::Feature>::construct()
        .setTimestamp(timestamp)
        .setSourceId(sourceId)
        .setDataset(db::Dataset::TaxiSignalQ2)
        .setGeodeticPos(location.geodeticPos())
        .setUploadedAt(chrono::TimePoint::clock::now())
        .setSize(frame.cols, frame.rows);

    // The heading is optional for a track point
    if (location.heading()) {
        feature.setHeading(*location.heading());
    }

    // The feature is inserted first: its id is part of the MDS path
    db::FeatureGateway{txn}.insert(feature);

    auto encoded = encodeImage(frame);
    auto path = common::makeMdsPath(
        common::MdsObjectSource::Imported,
        common::MdsObjectType::Image,
        feature.id());

    if (dryRun_) {
        return feature;
    }

    auto uploaded = mdsClient_.post(path, encoded);
    feature.setMdsKey(uploaded.key());
    return feature;
}


// Inserts one frame-to-video link for every feature. Each link stores the
// feature's offset from the start of the video in seconds, computed with
// millisecond precision.
void Importer::saveFrameToVideo(
    pqxx::transaction_base& txn,
    const db::Video& video,
    const db::Features& features,
    chrono::TimePoint videoStartTime)
{
    INFO() << "Saving frame to video";

    db::FrameToVideos links;
    links.reserve(features.size());

    for (const auto& feature : features) {
        // Truncate to whole milliseconds before converting to seconds
        const auto offset = std::chrono::duration_cast<std::chrono::milliseconds>(
            feature.timestamp() - videoStartTime);
        double secondsFromStart = offset.count() / 1000.;
        links.emplace_back(feature.id(), video.id(), secondsFromStart);
    }

    db::FrameToVideoGateway{txn}.insert(links);
}


// Inserts the given track points into the database within the given
// transaction.
void Importer::saveTrackPoints(
    pqxx::transaction_base& txn,
    db::TrackPoints& trackPoints)
{
    INFO() << "Saving track points";

    db::TrackPointGateway gateway{txn};
    gateway.insert(trackPoints);
}

} // namespace maps::mrc::import_taxi
