#include "yt_types.h"
#include "yt_track_mapper.h"

#include <maps/libs/log8/include/log8.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/io.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/operation.h>
#include <maps/wikimap/mapspro/services/mrc/libs/yt/include/serialization.h>

#include <mapreduce/yt/interface/operation.h>
#include <util/ysaveload.h>

#include <chrono>
#include <string>

// Required for Y_SAVELOAD_JOB
template <>
class TSerializer<maps::mrc::import_taxi::TimeInterval> {
public:
    static void Save(IOutputStream* s, const maps::mrc::import_taxi::TimeInterval& v)
    {
        using Duration = std::chrono::microseconds;
        ::Save(s, maps::chrono::sinceEpoch<Duration>(v.begin));
        ::Save(s, maps::chrono::sinceEpoch<Duration>(v.end));
    }

    static void Load(IInputStream* s, maps::mrc::import_taxi::TimeInterval& v)
    {
        int64_t begin, end;
        ::Load(s, begin);
        ::Load(s, end);

        using Duration = std::chrono::microseconds;
        v.begin = maps::chrono::sinceEpochToTimePoint<Duration>(begin);
        v.end = maps::chrono::sinceEpochToTimePoint<Duration>(end);
    }
};

// Required for Y_SAVELOAD_JOB.
// Stream layout: the number of entries, then for every entry the device id
// (written as a TString) followed by that device's TimeIntervals.
template <>
class TSerializer<maps::mrc::import_taxi::DeviceIdToTimeIntervals> {
public:
    // Writes the entry count and then every (device id, intervals) pair.
    static void Save(IOutputStream* s, const maps::mrc::import_taxi::DeviceIdToTimeIntervals& map)
    {
        ::Save(s, map.size());
        for (const auto& [deviceId, intervals] : map) {
            // The key is converted to TString so that Load can read it back
            // through the TString serializer.
            ::Save(s, TString(deviceId));
            ::Save(s, intervals);
        }
    }

    // Reads the entries written by Save. The previous contents of `map` are
    // discarded, so the result always equals the saved map.
    static void Load(IInputStream* s, maps::mrc::import_taxi::DeviceIdToTimeIntervals& map)
    {
        size_t size = 0;
        ::Load(s, size);

        // Without this, loading into a non-empty map would keep stale entries,
        // and emplace() below would silently retain the stale value for a key
        // that is present in both the old and the loaded data.
        map.clear();

        for (size_t i = 0; i < size; ++i) {
            TString deviceId;
            maps::mrc::import_taxi::TimeIntervals intervals;
            ::Load(s, deviceId);
            ::Load(s, intervals);
            map.emplace(std::string(deviceId), std::move(intervals));
        }
    }
};

namespace maps::mrc::import_taxi {

namespace {

// Mapper that passes through only those track records whose device id is
// present in the configured map and whose time point falls into one of that
// device's time intervals. All other records are dropped.
class FilterDevicesMapper : public yt::Mapper
{
public:
    // NOTE(review): the default constructor is presumably needed so the job
    // can be instantiated before its state is restored via Y_SAVELOAD_JOB —
    // confirm against the YT job registration requirements.
    FilterDevicesMapper() = default;

    // Takes ownership of the device id -> time intervals filter.
    // Assumes each device's intervals are sorted in ascending order and do
    // not overlap (e.g. the output of mergeTimeIntervals) — TODO confirm at
    // the call site; isRecordUseful() binary-searches them.
    explicit FilterDevicesMapper(DeviceIdToTimeIntervals deviceIdToTimeIntervals)
        : deviceIdToTimeIntervals_(std::move(deviceIdToTimeIntervals))
    {}

    // Reads every input row as a DeviceTrackRecord and re-emits the rows that
    // pass the device/time filter.
    void Do(yt::Reader* reader, yt::Writer* writer) override
    {
        for (; reader->IsValid(); reader->Next()) {
            auto record = yt::deserialize<DeviceTrackRecord>(reader->GetRow());
            if (!isRecordUseful(record)) {
                continue;
            }
            writer->AddRow(yt::serialize(record));
        }
    }

    Y_SAVELOAD_JOB(deviceIdToTimeIntervals_);

private:
    // Returns true when the record's device is known and its time point lies
    // within [begin, end) of one of the intervals registered for that device.
    bool isRecordUseful(const DeviceTrackRecord& record) const
    {
        const auto deviceItr = deviceIdToTimeIntervals_.find(record.deviceId);
        if (deviceItr == deviceIdToTimeIntervals_.end()) {
            return false;
        }
        const auto& intervals = deviceItr->second;
        const auto timePoint = record.timePoint();

        // First interval whose end lies strictly after the time point; it is
        // the only candidate that can contain the point.
        const auto intervalItr = std::upper_bound(
            intervals.begin(), intervals.end(), timePoint,
            [](chrono::TimePoint value, const TimeInterval& interval) {
                return value < interval.end;
            });

        return intervalItr != intervals.end()
            && timePoint >= intervalItr->begin;
    }

    DeviceIdToTimeIntervals deviceIdToTimeIntervals_;
};

} // namespace

// Registers FilterDevicesMapper with the YT client library.
// NOTE(review): presumably this is what lets the job side re-create the mapper
// before its state is restored via Y_SAVELOAD_JOB — confirm in the YT C++ API docs.
REGISTER_MAPPER(FilterDevicesMapper);

// Runs a YT map operation with FilterDevicesMapper over the given input
// tables and writes the surviving rows to `outputTablePath`.
//
// Input paths that do not exist are skipped. The filter map is copied into
// the mapper instance handed to the operation.
// NOTE(review): if none of the input paths exist the spec is submitted with
// no inputs — confirm how client.Map() reacts to that.
void loadSelectedTracks(
    NYT::IClientBase& client,
    const std::vector<TString>& inputTablePaths,
    const TString& outputTablePath,
    const DeviceIdToTimeIntervals& deviceIdToTimeIntervals)
{
    INFO() << "Loading selected track points...";

    NYT::TMapOperationSpec spec;
    for (const TString& tablePath : inputTablePaths) {
        if (!client.Exists(tablePath)) {
            continue;
        }
        spec.AddInput<NYT::TNode>(tablePath);
    }
    spec.AddOutput<NYT::TNode>(outputTablePath);

    // The returned operation handle is not needed here.
    client.Map(spec, new FilterDevicesMapper(deviceIdToTimeIntervals));
}

// Collapses overlapping or touching intervals into a minimal list.
//
// The input is taken by value, sorted by `begin`, and then swept once: an
// interval whose `begin` is not later than the `end` of the last emitted
// interval is absorbed into it, otherwise it starts a new output interval.
// The result is ordered by `begin`; an empty input yields an empty result.
TimeIntervals mergeTimeIntervals(TimeIntervals timeIntervals)
{
    std::sort(timeIntervals.begin(), timeIntervals.end(),
        [](const auto& lhs, const auto& rhs) {
            return lhs.begin < rhs.begin;
        });

    TimeIntervals merged;
    for (const auto& interval : timeIntervals) {
        const bool extendsLast =
            !merged.empty() && merged.back().end >= interval.begin;

        if (extendsLast) {
            // Overlapping or adjacent: stretch the last interval if needed.
            auto& last = merged.back();
            last.end = std::max(last.end, interval.end);
        } else {
            merged.push_back(interval);
        }
    }
    return merged;
}

} // namespace maps::mrc::import_taxi
