#include <maps/wikimap/mapspro/services/mrc/libs/track_classifier/include/track_classifier.h>
#include <maps/wikimap/mapspro/services/mrc/libs/graph_matcher_adapter/include/compact_graph_matcher_adapter.h>

#include <maps/libs/chrono/include/time_point.h>
#include <maps/libs/common/include/exception.h>
#include <maps/libs/geolib/include/distance.h>
#include <maps/libs/log8/include/log8.h>

#include <algorithm>
#include <cfloat>
#include <chrono>
#include <iterator>
#include <map>
#include <numeric>
#include <optional>
#include <utility>
#include <vector>

namespace maps::mrc::track_classifier {

namespace {

// Maximum time gap between two consecutive track points; a larger gap is treated
// as a break between separate tracks.
constexpr std::chrono::seconds BETWEEN_TRACKS_TIME_GAP{60};

// Trimmed mean of speed samples: the samples are sorted, the lowest and highest
// tenth (size / 10 each, rounded down) are erased, and the rest are averaged.
// The argument is consumed (sorted and shrunk in place).
float calculateAverageSpeed(std::vector<float>&& speeds) {
    std::sort(speeds.begin(), speeds.end());

    const size_t trimCount = speeds.size() / 10;
    // Drop outliers from both ends; erasing zero elements is a no-op.
    speeds.erase(speeds.end() - trimCount, speeds.end());
    speeds.erase(speeds.begin(), speeds.begin() + trimCount);

    // Accumulate in double to limit rounding error before the final narrowing to float.
    double total = 0.;
    for (const float sample : speeds) {
        total += sample;
    }
    return total / speeds.size();
}

/*
    Estimates the average speed (meters per second) of a chunk of track points.

    Speed samples are taken between an "anchor" point and the current point:
      - if the two points are more than BETWEEN_TRACKS_TIME_GAP apart in time (either
        direction), no sample is produced and the current point becomes the new anchor;
      - if less than MIN_TIME_SECONDS or less than MIN_DISTANCE_METERS separates them,
        the point is skipped and the anchor is kept, so the next point is measured
        against the same anchor;
      - otherwise the speed is recorded unless it reaches MAX_SPEED_METERS_PER_SECONDS,
        and the current point becomes the new anchor.

    Returns std::nullopt when the chunk has at most MIN_TRACK_LENGTH points or fewer
    than MIN_TRACK_LENGTH samples survive; otherwise the trimmed mean of the samples.
*/
std::optional<float> calculateAverageSpeed(const db::TrackPoints& tpts) {
    constexpr size_t MIN_TRACK_LENGTH = 10;
    constexpr double MIN_TIME_SECONDS = 1.;
    constexpr double MIN_DISTANCE_METERS = 1.;
    // Movement faster than 200 km/h is treated as a positioning error.
    constexpr double MAX_SPEED_METERS_PER_SECONDS = 200. * 1000. / 3600.;

    if (tpts.size() <= MIN_TRACK_LENGTH) {
        return std::nullopt;
    }

    maps::chrono::TimePoint anchorTime = tpts.front().timestamp();
    maps::geolib3::Point2 anchorPos = tpts.front().mercatorPos();
    std::vector<float> speeds;
    for (size_t idx = 1; idx < tpts.size(); ++idx) {
        const db::TrackPoint& current = tpts[idx];
        const maps::chrono::TimePoint currentTime = current.timestamp();

        // If the input is pre-split with splitTrackPoints, a gap cannot occur here and
        // this check could become a REQUIRE instead of silently re-anchoring.
        const bool isGap = std::chrono::abs(currentTime - anchorTime) > BETWEEN_TRACKS_TIME_GAP;
        if (!isGap) {
            const double elapsedSeconds =
                std::chrono::duration_cast<std::chrono::milliseconds>(currentTime - anchorTime).count() / 1000.0;
            if (elapsedSeconds < MIN_TIME_SECONDS) {
                // Too close in time: keep the anchor.
                continue;
            }
            const maps::geolib3::Point2 currentPos = current.mercatorPos();
            const float meters = maps::geolib3::toMeters(maps::geolib3::distance(currentPos, anchorPos), currentPos);
            if (meters < MIN_DISTANCE_METERS) {
                // Too close in space: keep the anchor.
                continue;
            }
            const double metersPerSecond = meters / elapsedSeconds;
            if (metersPerSecond < MAX_SPEED_METERS_PER_SECONDS) {
                speeds.push_back(metersPerSecond);
            }
        }
        anchorTime = currentTime;
        anchorPos = current.mercatorPos();
    }

    if (speeds.size() < MIN_TRACK_LENGTH) {
        return std::nullopt;
    }
    return calculateAverageSpeed(std::move(speeds));
}

/*
    Map-matches a track whose points carry per-point types.

    Consecutive points with the same type form a run. Pedestrian runs are matched with
    pedestrianMatcher and Vehicle runs with vehicleMatcher. The resulting segments are
    concatenated in track order. Runs of any other type (e.g. Undefined) produce no
    segments. An empty track yields no segments.

    Throws (via REQUIRE) if a matcher is null or the sizes of trackpts and pttypes differ.
*/
adapters::TrackSegments match(
    const db::TrackPoints& trackpts,
    const std::vector<track_classifier::TrackType>& pttypes,
    const adapters::Matcher* vehicleMatcher,
    const adapters::Matcher* pedestrianMatcher)
{
    REQUIRE(vehicleMatcher, "Vehicle graph matcher is undefined");
    REQUIRE(pedestrianMatcher, "Pedestrian graph matcher is undefined");
    REQUIRE(trackpts.size() == pttypes.size(), "Different amount of track points and types of points");

    adapters::TrackSegments segments;
    if (trackpts.empty()) {
        // Nothing to match; also avoids reading trackpts[0] / pttypes[0] below.
        return segments;
    }

    // Matches one run of same-typed points and appends its segments to `segments`.
    const auto matchPart = [&](const db::TrackPoints& part, track_classifier::TrackType type) {
        const adapters::Matcher* matcher = nullptr;
        if (track_classifier::TrackType::Pedestrian == type) {
            matcher = pedestrianMatcher;
        } else if (track_classifier::TrackType::Vehicle == type) {
            matcher = vehicleMatcher;
        }
        if (matcher == nullptr) {
            // Undefined (or any other) type: this run is intentionally left unmatched.
            return;
        }
        const adapters::TrackSegments partSegments = matcher->match(part);
        segments.insert(segments.end(), partSegments.begin(), partSegments.end());
    };

    db::TrackPoints part;
    part.emplace_back(trackpts[0]);
    track_classifier::TrackType type = pttypes[0];
    for (size_t i = 1; i < trackpts.size(); i++) {
        if (pttypes[i] != type) {
            // Type changed: flush the finished run and start a new one.
            matchPart(part, type);
            type = pttypes[i];
            part.clear();
        }
        part.emplace_back(trackpts[i]);
    }
    // The last run always holds at least one point.
    matchPart(part, type);
    return segments;
}

// Matching-quality statistic for one labelling of a track.
struct IterationStatistic {
    // Per-point distance to the segment the point was compared with; DBL_MAX until update() is called.
    std::vector<double> dists;
    // Sum of all distances recorded by update(); calcIterationStat sets it to DBL_MAX when there are no segments.
    double distSum = 0.;
    // Number of points whose timestamp lies outside [startTime, endTime] of the segment they were compared with.
    size_t outliersCnt = 0;

    // Allocates one distance slot per track point, initialised to DBL_MAX.
    void resize(size_t tptsCnt) {
        dists.resize(tptsCnt, DBL_MAX);
    }

    // Mean distance plus a fixed penalty per time outlier; DBL_MAX when no usable data exists.
    double getAverageError() const {
        constexpr double OUTLIER_COEFFICIENT = 17.;
        const bool unusable = dists.empty() || distSum == DBL_MAX;
        return unusable
            ? DBL_MAX
            : distSum / dists.size() + OUTLIER_COEFFICIENT * outliersCnt;
    }

    // A point is bad when its distance exceeds MIN_ERROR_METERS.
    bool isBadPoint(size_t idxPt) const {
        constexpr double MIN_ERROR_METERS = 5.;
        return MIN_ERROR_METERS < dists[idxPt];
    }

    // Records the distance from point idxPt to `segment` and counts it as an outlier
    // when the point's timestamp falls outside the segment's time span.
    void update(size_t idxPt, const db::TrackPoint& tpt, const adapters::TrackSegment& segment) {
        const double distance = maps::geolib3::geoDistance(tpt.geodeticPos(), segment.segment);
        dists[idxPt] = distance;
        distSum += distance;

        const bool outsideSegmentTime =
            tpt.timestamp() < segment.startTime || tpt.timestamp() > segment.endTime;
        if (outsideSegmentTime) {
            ++outliersCnt;
        }
    }
};

/*
    Assigns every track point to a matched segment by time and accumulates the
    resulting IterationStatistic.

    Points and segments are walked in order:
      - a point inside a segment's [startTime, endTime] is compared with that segment;
      - a point before the current segment's start is compared with whichever of the
        previous segment (by its end) or the current one (by its start) is closer in time;
      - points after the last segment are compared with the last segment.
    With no segments, every point counts as an outlier and distSum is set to DBL_MAX.
*/
IterationStatistic calcIterationStat(const adapters::TrackSegments& segments, const db::TrackPoints& tpts) {
    const size_t segmsCnt = segments.size();

    IterationStatistic stat;
    stat.resize(tpts.size());

    if (segmsCnt == 0) {
        stat.distSum = DBL_MAX;
        stat.outliersCnt = tpts.size();
        return stat;
    }

    size_t segIdx = 0;
    for (size_t ptIdx = 0; ptIdx < tpts.size(); ++ptIdx) {
        const db::TrackPoint& point = tpts[ptIdx];

        // Skip segments that end before this point (and do not start after it).
        while (segIdx < segmsCnt
               && !(point.timestamp() < segments[segIdx].startTime)
               && !(point.timestamp() <= segments[segIdx].endTime)) {
            ++segIdx;
        }

        if (segIdx == segmsCnt) {
            // Past all segments: compare with the last one.
            stat.update(ptIdx, point, segments.back());
            continue;
        }

        const adapters::TrackSegment& current = segments[segIdx];
        if (point.timestamp() < current.startTime && segIdx > 0) {
            // Between two segments: pick the one closer in time.
            const adapters::TrackSegment& previous = segments[segIdx - 1];
            const bool closerToPrevious =
                current.startTime - point.timestamp() > point.timestamp() - previous.endTime;
            stat.update(ptIdx, point, closerToPrevious ? previous : current);
        } else {
            stat.update(ptIdx, point, current);
        }
    }
    return stat;
}

/*
    Collapses per-point types into intervals of consecutive points that share a type.
    Each interval spans from the timestamp of its first point to the timestamp of its
    last point. Returns an empty vector for an empty track.

    Throws (via REQUIRE) if the sizes of tpts and types differ.
*/
std::vector<track_classifier::TrackInterval> getIntervals(
    const db::TrackPoints& tpts,
    const std::vector<track_classifier::TrackType>& types)
{
    REQUIRE(tpts.size() == types.size(), "Different amount of track points and types of points");
    std::vector<track_classifier::TrackInterval> result;
    if (types.empty()) {
        // Avoid reading tpts[0] / types[0] on an empty track.
        return result;
    }

    track_classifier::TrackInterval interval;
    interval.begin = tpts[0].timestamp();
    interval.end = tpts[0].timestamp();
    interval.type = types[0];
    for (size_t i = 1; i < types.size(); i++) {
        if (interval.type == types[i]) {
            // Same type: extend the current interval.
            interval.end = tpts[i].timestamp();
        } else {
            // Type changed: close the current interval and open a new one at this point.
            result.push_back(interval);
            interval.begin = tpts[i].timestamp();
            interval.end = tpts[i].timestamp();
            interval.type = types[i];
        }
    }
    result.push_back(interval);
    return result;
}

// Returns the "other" type to try when re-classifying: Pedestrian becomes Vehicle,
// and every other value (Vehicle, Undefined, ...) becomes Pedestrian.
track_classifier::TrackType overType(track_classifier::TrackType type) {
    return type == track_classifier::TrackType::Pedestrian
        ? track_classifier::TrackType::Vehicle
        : track_classifier::TrackType::Pedestrian;
}

//first - startIdx, second - length
std::vector < std::pair<size_t, size_t> > getBadParts(
    const IterationStatistic& iterStat,
    const std::vector<track_classifier::TrackType>& types)
{
    int start = -1;
    std::vector < std::pair<size_t, size_t> > badParts;
    for (size_t i = 0; i < iterStat.dists.size(); i++) {
        if (-1 == start) {
            if (!iterStat.isBadPoint(i)) {
                continue;
            }
            start = i;
            continue;
        }
        if (iterStat.isBadPoint(i)) {
            if (types[i] != types[start]) {
                const size_t length = i - start;
                badParts.emplace_back(start, length);
                start = i;
            }
        } else {
            const size_t length = i - start;
            badParts.emplace_back(start, length);
            start = -1;
        }
    }
    if (-1 != start) {
        const size_t length = iterStat.dists.size() - start;
        badParts.emplace_back(start, length);
    }

    constexpr size_t MIN_PTS_SWITCH = 10;
    badParts.erase(
        std::remove_if(badParts.begin(), badParts.end(),
            [&](const std::pair<size_t, size_t>& part) {
                return (MIN_PTS_SWITCH >= part.second);
            }
        ),
        badParts.end()
    );
    return badParts;
}

std::vector<TrackInterval> classifyImpl(const db::TrackPoints& tpts, const std::map<db::GraphType, const adapters::Matcher*>& graphTypeToMatcherMap) {
    constexpr float SPEED_PEDESTRIAN_MAX_METERS_PER_SECOND = 20.f * 1000.f / 3600.f; // 20kmh - если больше всегда считаем автомобилем
    constexpr float SPEED_THRESHOLD_METERS_PER_SECOND = 12.f * 1000.f / 3600.f; // 12kmh
    constexpr size_t MAX_ITER = 10;

    REQUIRE(0 < tpts.size(), "Track points vector is empty");

    const std::optional<float> avgSpeed = calculateAverageSpeed(tpts);
    if (avgSpeed == std::nullopt) {
        track_classifier::TrackInterval result;
        result.begin = tpts.front().timestamp();
        result.end = tpts.back().timestamp();
        result.type = track_classifier::TrackType::Undefined;
        return{ result };
    }

    DEBUG() << "Average speed: " << (*avgSpeed) * 3.6 << " kmh";
    if (*avgSpeed > SPEED_PEDESTRIAN_MAX_METERS_PER_SECOND) {
        track_classifier::TrackInterval result;
        result.begin = tpts.front().timestamp();
        result.end = tpts.back().timestamp();
        result.type = track_classifier::TrackType::Vehicle;
        return{ result };
    }

    const adapters::Matcher* vehicleMatcher = graphTypeToMatcherMap.at(db::GraphType::Road);
    const adapters::Matcher* pedestrianMatcher = graphTypeToMatcherMap.at(db::GraphType::Pedestrian);

    REQUIRE((vehicleMatcher && pedestrianMatcher), "Pointer to graph matcher is null");
    std::vector<track_classifier::TrackType> types(tpts.size());
    // инициализируем по средней скорости
    if (*avgSpeed < SPEED_THRESHOLD_METERS_PER_SECOND) {
        std::fill(types.begin(), types.end(), track_classifier::TrackType::Pedestrian);
    } else {
        std::fill(types.begin(), types.end(), track_classifier::TrackType::Vehicle);
    }

    adapters::TrackSegments segments = match(tpts, types, vehicleMatcher, pedestrianMatcher);
    IterationStatistic iterStat = calcIterationStat(segments, tpts);
    for (size_t iter = 0; iter < MAX_ITER; iter++) {
        DEBUG() << iter << " iteration. Average errors: " << iterStat.getAverageError() << ", outliers: " << iterStat.outliersCnt;
        std::vector < std::pair<size_t, size_t> > badParts = getBadParts(iterStat, types);
        if (badParts.empty()) {
            break;
        }
        std::sort(badParts.begin(), badParts.end(),
            [](const std::pair<size_t, size_t>& lhs, const std::pair<size_t, size_t>& rhs) {
                return lhs.second > rhs.second;
            }
        );

        bool change = false;
        for (size_t p = 0; p < badParts.size(); p++) {
            const std::pair<size_t, size_t>& badPart = badParts[p];
            const track_classifier::TrackType oldType = types[badPart.first];
            const track_classifier::TrackType newType = overType(oldType);
            std::fill(types.begin() + badPart.first, types.begin() + badPart.first + badPart.second, newType);
            adapters::TrackSegments segmentsNew = match(tpts, types, vehicleMatcher, pedestrianMatcher);
            IterationStatistic iterStatNew = calcIterationStat(segmentsNew, tpts);
            DEBUG() << iter << " iteration. Average errors: " << iterStatNew.getAverageError() << ", outliers: " << iterStatNew.outliersCnt;
            if (iterStatNew.getAverageError() >= iterStat.getAverageError())
            {
                std::fill(types.begin() + badPart.first, types.begin() + badPart.first + badPart.second, oldType);
            } else {
                change = true;
                segments = segmentsNew;
                iterStat = iterStatNew;
            }
        }
        if (!change) {
            break;
        }
    }

    return getIntervals(tpts, types);
}

/*
    Splits track points into chunks.
    The points are sorted by timestamp. Within each chunk, two consecutive points
    differ in time by at most BETWEEN_TRACKS_TIME_GAP; a larger gap starts a new chunk.
    Throws (via REQUIRE) on empty input.
*/
std::vector<db::TrackPoints> splitTrackPoints(const db::TrackPoints& trackpts) {
    REQUIRE(0 < trackpts.size(), "Track points vector is empty");

    db::TrackPoints sorted(trackpts);
    std::sort(sorted.begin(), sorted.end(),
        [](const db::TrackPoint& a, const db::TrackPoint& b) {
            return a.timestamp() < b.timestamp();
        }
    );

    std::vector<db::TrackPoints> chunks;
    auto chunkBegin = sorted.begin();
    for (auto it = sorted.begin() + 1; it != sorted.end(); ++it) {
        const db::TrackPoint& previous = *(it - 1);
        if (it->timestamp() - previous.timestamp() > BETWEEN_TRACKS_TIME_GAP) {
            chunks.emplace_back(chunkBegin, it);
            chunkBegin = it;
        }
    }
    chunks.emplace_back(chunkBegin, sorted.end());
    return chunks;
}

} // namespace

/*
    Classifies a track into time intervals of Pedestrian / Vehicle / Undefined movement.
    The points are split into time-sorted chunks at gaps longer than
    BETWEEN_TRACKS_TIME_GAP. Each chunk is classified independently, and the per-chunk
    intervals are concatenated in chunk order.
    Throws (via REQUIRE) if the road or pedestrian matcher is missing, or if trackpts is empty.
*/
std::vector<TrackInterval> classify(const db::TrackPoints& trackpts, const std::map<db::GraphType, const adapters::Matcher*>& graphTypeToMatcherMap) {
    REQUIRE(0 < graphTypeToMatcherMap.count(db::GraphType::Road),
            "graphTypeToMatcherMap doesn't contain road matcher");
    REQUIRE(0 < graphTypeToMatcherMap.count(db::GraphType::Pedestrian),
            "graphTypeToMatcherMap doesn't contain pedestrian matcher.");

    std::vector<TrackInterval> intervals;
    for (const db::TrackPoints& chunk : splitTrackPoints(trackpts)) {
        std::vector<TrackInterval> chunkIntervals = classifyImpl(chunk, graphTypeToMatcherMap);
        std::move(chunkIntervals.begin(), chunkIntervals.end(), std::back_inserter(intervals));
    }
    return intervals;
}


} //maps::mrc::track_classifier

