#include "../include/catboost_visibility_predictor.h"

#include <maps/libs/geolib/include/contains.h>
#include <maps/libs/geolib/include/direction.h>
#include <maps/libs/geolib/include/distance.h>
#include <maps/libs/geolib/include/heading.h>

#include <maps/libs/log8/include/log8.h>

#include <library/cpp/resource/resource.h>
#include <library/cpp/iterator/zip.h>

#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <tuple>
#include <type_traits>
#include <vector>

namespace maps::mrc::eye {

namespace {

/// A point in the analysed frame's pixel coordinates paired with a
/// confidence score. Members are default-initialised so that a
/// default-constructed instance never carries an indeterminate confidence.
struct PointWithConfidence {
    geolib3::Point2 point{};
    float confidence = 0.f;
};

/// Converts a line given in implicit form `a*x + b*y + c = 0`, packed as
/// `cv::Point3f{a, b, c}` (the layout produced by
/// `cv::computeCorrespondEpilines`), into a point + direction geolib3::Line2.
///
/// At least one of `a`, `b` must be non-zero.
geolib3::Line2 toGeolibLine(const cv::Point3f& line)
{
    const double a = line.x;
    const double b = line.y;
    const double c = line.z;

    // Anchor the line on the axis whose coefficient has the larger magnitude:
    // dividing by the bigger of |a|, |b| keeps the anchor well conditioned even
    // when the other coefficient is tiny but non-zero.
    const geolib3::Point2 point = std::abs(a) >= std::abs(b)
        ? geolib3::Point2(-c / a, 0)
        : geolib3::Point2(0, -c / b);

    // (b, -a) is orthogonal to the line normal (a, b), i.e. it runs along the
    // line. (Using (b, a) would describe a different line whenever a*b != 0.)
    // NOTE(review): features derived from these lines feed the CatBoost model;
    // if that model was trained on the previous (b, a) direction it needs
    // retraining — confirm.
    return geolib3::Line2(point, geolib3::Vector2(b, -a));
}


/// Computes, via cv::computeCorrespondEpilines, the epipolar line that
/// corresponds to the single point `ptf` and converts it to geolib3::Line2.
///
/// @param matchPlace  position of the frame within the matched pair;
///                    FrameMatchPlace::First selects OpenCV image index 2,
///                    anything else index 1 (presumably: the detection lies in
///                    the other image of the pair — TODO confirm).
/// @param fundMatrix  fundamental matrix of the frame pair.
/// @param ptf         point of the detection, in pixels.
geolib3::Line2
evalEpiline(FrameMatchPlace matchPlace,
            const cv::Mat& fundMatrix,
            const cv::Point2f& ptf)
{
    // OpenCV wants the 1-based index of the image that contains the points.
    const int detectionImageIndex =
        (matchPlace == FrameMatchPlace::First) ? 2 : 1;

    const std::vector<cv::Point2f> sourcePoints{ptf};
    std::vector<cv::Point3f> epilines;
    cv::computeCorrespondEpilines(
        sourcePoints, detectionImageIndex, fundMatrix, epilines);

    ASSERT(epilines.size() == 1);
    return toGeolibLine(epilines.front());
}

/// Match-quality score: `2 * goodPtsCnt / (ptsCnt0 + ptsCnt1)`, i.e. the
/// number of good points relative to the mean point count of the two frames.
/// Asserts that the pair has at least one point in total.
float calcConfidence(const FramesMatchData& match)
{
    const int pointsInBothFrames = match.ptsCnt0 + match.ptsCnt1;
    ASSERT(pointsInBothFrames > 0);

    const float goodShare =
        static_cast<float>(match.goodPtsCnt) / pointsInBothFrames;
    return 2.0f * goodShare;
}

/// Returns the median of `func(x)` over the elements of `c`, or a
/// value-initialised ValueType when `c` is empty.
///
/// For an even element count the upper of the two middle values is returned
/// (no averaging), so the result is always one of the projected inputs.
/// NOTE: partially reorders `c` in place (std::nth_element).
template<typename Container,
         typename ValueGetter = std::identity,
         typename ValueType =  std::decay_t<std::invoke_result_t<ValueGetter, typename Container::value_type>>>
ValueType evalMedian(Container& c, ValueGetter func = {})
{
    if (c.empty()) {
        return {};
    }

    const auto lessByValue = [&func](const auto& lhs, const auto& rhs) {
        return func(lhs) < func(rhs);
    };
    const auto middle = c.begin() + c.size() / 2;
    std::nth_element(c.begin(), middle, c.end(), lessByValue);
    return func(*middle);
}


/// Returns the arithmetic mean of `func(x)` over the elements of `c`, or a
/// value-initialised ValueType when `c` is empty. `c` is not modified.
///
/// For signed integral ValueType the division is carried out in a signed
/// type (truncating toward zero): a plain `sum / c.size()` would first
/// convert a negative sum to the unsigned size type and yield garbage.
/// Floating-point and other ValueTypes divide by `c.size()` as before.
template<typename Container,
         typename ValueGetter = std::identity,
         typename ValueType =  std::decay_t<std::invoke_result_t<ValueGetter, typename Container::value_type>>>
ValueType evalAvg(const Container& c, ValueGetter func = {})
{
    if (c.empty()) {
        return {};
    }

    const ValueType sum = std::accumulate(c.begin(), c.end(), ValueType{},
        [&func](const ValueType& acc, const auto& item) {
            return acc + func(item);
        });

    if constexpr (std::is_integral_v<ValueType> && std::is_signed_v<ValueType>) {
        return static_cast<ValueType>(
            static_cast<long long>(sum) / static_cast<long long>(c.size()));
    } else {
        return sum / c.size();
    }
}


/// Returns the smallest `func(x)` over the elements of `c`, or a
/// value-initialised ValueType when `c` is empty. `c` is not reordered.
template<typename Container,
         typename ValueGetter = std::identity,
         typename ValueType =  std::decay_t<std::invoke_result_t<ValueGetter, typename Container::value_type>>>
ValueType evalMin(Container& c, ValueGetter func = {})
{
    if (c.empty()) {
        return {};
    }

    const auto lessByValue = [&func](const auto& lhs, const auto& rhs) {
        return func(lhs) < func(rhs);
    };
    const auto smallest = std::min_element(c.begin(), c.end(), lessByValue);
    return func(*smallest);
}

/// Returns the largest `func(x)` over the elements of `c`, or a
/// value-initialised ValueType when `c` is empty. `c` is not reordered.
template<typename Container,
         typename ValueGetter = std::identity,
         typename ValueType =  std::decay_t<std::invoke_result_t<ValueGetter, typename Container::value_type>>>
ValueType evalMax(Container& c, ValueGetter func = {})
{
    if (c.empty()) {
        return {};
    }

    const auto lessByValue = [&func](const auto& lhs, const auto& rhs) {
        return func(lhs) < func(rhs);
    };
    const auto largest = std::max_element(c.begin(), c.end(), lessByValue);
    return func(*largest);
}

/// Deserializes a CatBoost model from a resource embedded in the binary.
/// @param resourceName  resource key, passed unchanged to NResource::Find.
TFullModel readCatboostModelFromResource(const std::string& resourceName)
{
    const TString serializedModel = NResource::Find(resourceName);
    return ReadModel(serializedModel.data(), serializedModel.size());
}

} // namespace

/// Computes three epipolar lines for a matched detection: through the centre
/// of its bounding box and through its top-left and bottom-right corners.
///
/// @param matchPlace     position of the frame within the matched pair
///                       (forwarded to evalEpiline).
/// @param fundMatrix     fundamental matrix of the frame pair.
/// @param detectionBbox  detection bounding box, in pixels.
MatchEpilines calcEpilines(
    FrameMatchPlace matchPlace,
    const cv::Mat& fundMatrix,
    const common::ImageBox& detectionBbox)
{
    const cv::Rect rect = detectionBbox;

    // Average the corners in floating point: `(tl + br) / 2` on integer
    // cv::Point would quantize the centre to whole pixels.
    const cv::Point2f center =
        (cv::Point2f(rect.tl()) + cv::Point2f(rect.br())) * 0.5f;

    return MatchEpilines{
        .centerEpiline = evalEpiline(matchPlace, fundMatrix, center),
        .topLeftEpiline = evalEpiline(matchPlace, fundMatrix, rect.tl()),
        .bottomRightEpiline = evalEpiline(matchPlace, fundMatrix, rect.br())};
}

/// Builds the feature vector passed to the CatBoost visibility model for one
/// frame, from detections of the object matched in other frames.
///
/// @param frameSize          size of the analysed frame, in pixels; pairwise
///                           intersections that fall within this rectangle
///                           count as "inner", the rest as outliers.
/// @param frameMercPosition  position of the analysed frame.
/// @param frameHeading       heading of the analysed frame.
/// @param matches            detections matched to this frame.
/// @param epilinesVec        epilines of each match in the analysed frame;
///                           index-aligned with `matches` (sizes asserted equal).
/// @return 12 features in this fixed order (order matters to the model):
///    0  median |heading difference| to the match frames
///    1  number of inner centre-epiline pair intersections
///    2  mean confidence of inner intersections
///    3  max confidence of inner intersections
///    4  share of inner intersections among all intersections (0 if none)
///    5  max angle between centre epilines whose intersection is inner
///    6  median goodPtsCnt over matches
///    7  max bbox side of the match with the largest goodPtsCnt
///    8  largest bbox side over all matches
///    9  median diagonal of the box spanned by the top-left and bottom-right
///       epiline intersections, over pairs of matches
///   10  fastGeoDistance to the frame of the match with the largest bbox
///   11  |heading difference| to the frame of the match with the largest bbox
std::vector<float> calcObjectVisibilityFactors(
    const common::Size& frameSize,
    const geolib3::Point2& frameMercPosition,
    geolib3::Heading frameHeading,
    const std::vector<DetectionMatchData>& matches,
    const std::vector<MatchEpilines>& epilinesVec)
{
    ASSERT(matches.size() == epilinesVec.size());

    const geolib3::BoundingBox imageBbox(
        geolib3::Point2{},
        geolib3::Point2{
            static_cast<float>(frameSize.width),
            static_cast<float>(frameSize.height)});

    // --- Per-match statistics ---------------------------------------------
    int maxBboxSize = 0;
    float distanceToMaxBbox = 0.f;
    float headingDiffWithMaxBbox = 0.f;
    int maxGoodPointsCount = 0;
    int maxMatchBboxSize = 0;
    std::vector<float> headingDiffs;
    std::vector<int> goodPointsCountVec;
    headingDiffs.reserve(matches.size());
    goodPointsCountVec.reserve(matches.size());

    for (const auto& match : matches) {
        const float distanceMeters = geolib3::fastGeoDistance(
            frameMercPosition, match.detectionFrameMercPosition);
        const float headingDiff = std::abs(
            geolib3::angleBetween(frameHeading, match.detectionFrameHeading)
                .value());
        headingDiffs.push_back(headingDiff);

        const int bboxMaxSize =
            std::max(match.detectionBbox.width(), match.detectionBbox.height());
        // Strict comparisons: on ties the first match wins.
        if (bboxMaxSize > maxBboxSize) {
            maxBboxSize = bboxMaxSize;
            distanceToMaxBbox = distanceMeters;
            headingDiffWithMaxBbox = headingDiff;
        }

        if (maxGoodPointsCount < match.match.goodPtsCnt) {
            maxGoodPointsCount = match.match.goodPtsCnt;
            maxMatchBboxSize = bboxMaxSize;
        }
        goodPointsCountVec.push_back(match.match.goodPtsCnt);
    }

    // --- Pairwise epiline intersections ------------------------------------
    float innerPointsMaxIntersectionAngle = 0.f;
    std::vector<PointWithConfidence> innerPoints;
    size_t outlierPointsCount = 0;
    std::vector<float> predictedBboxSizes;

    for (size_t i = 0; i < matches.size(); ++i) {
        for (size_t j = i + 1; j < matches.size(); ++j) {
            const MatchEpilines& first = epilinesVec[i];
            const MatchEpilines& second = epilinesVec[j];

            const auto centerIntersection = geolib3::intersection(
                first.centerEpiline, second.centerEpiline);
            const auto topLeftIntersection = geolib3::intersection(
                first.topLeftEpiline, second.topLeftEpiline);
            const auto bottomRightIntersection = geolib3::intersection(
                first.bottomRightEpiline, second.bottomRightEpiline);

            // Skip pairs where any of the three line pairs does not meet in
            // exactly one point (e.g. parallel lines).
            if (centerIntersection.size() != 1 ||
                topLeftIntersection.size() != 1 ||
                bottomRightIntersection.size() != 1) {
                continue;
            }

            const float lineAngle = geolib3::angleBetween(
                                        first.centerEpiline.direction(),
                                        second.centerEpiline.direction())
                                        .value();
            // A pair is only as trustworthy as its weaker match.
            const float confidence = std::min(
                calcConfidence(matches[i].match),
                calcConfidence(matches[j].match));
            predictedBboxSizes.push_back(
                geolib3::BoundingBox(
                    topLeftIntersection[0], bottomRightIntersection[0])
                    .diagonalLength());

            const PointWithConfidence point{
                .point = centerIntersection[0], .confidence = confidence};

            if (geolib3::contains(imageBbox, point.point)) {
                innerPointsMaxIntersectionAngle =
                    std::max(innerPointsMaxIntersectionAngle, lineAngle);
                innerPoints.push_back(point);
            } else {
                ++outlierPointsCount;
            }
        }
    }

    const size_t totalPoints = innerPoints.size() + outlierPointsCount;
    const float innerPointsRatio = totalPoints > 0
        ? static_cast<float>(innerPoints.size()) / totalPoints
        : 0.f;

    const auto confidenceOf = [](const PointWithConfidence& p) {
        return p.confidence;
    };

    return std::vector<float>{
        evalMedian(headingDiffs),
        static_cast<float>(innerPoints.size()),
        evalAvg(innerPoints, confidenceOf),
        evalMax(innerPoints, confidenceOf),
        innerPointsRatio,
        innerPointsMaxIntersectionAngle,
        static_cast<float>(evalMedian(goodPointsCountVec)),
        static_cast<float>(maxMatchBboxSize),
        static_cast<float>(maxBboxSize),
        evalMedian(predictedBboxSizes),
        distanceToMaxBbox,
        headingDiffWithMaxBbox};
}

namespace {

// Resource key under which the serialized visibility model is embedded.
constexpr const char* VISIBILITY_MODEL_RESOURCE =
    "/maps/mrc/predict_visibility/model.bin";

} // namespace

/// Loads the CatBoost visibility model embedded in the binary.
CatboostClusterVisibilityPredictor::CatboostClusterVisibilityPredictor()
    : model_(readCatboostModelFromResource(VISIBILITY_MODEL_RESOURCE))
{
}

/// Predicts whether the object described by `matches` is visible in the
/// given frame.
///
/// Computes the epilines of every match, builds the feature vector with
/// calcObjectVisibilityFactors and compares the model output against 0.5.
/// NOTE(review): TFullModel::Calc yields the raw formula value; if the model
/// is a Logloss classifier that value is a log-odds, not a probability, and a
/// 0.5 threshold corresponds to p ~= 0.62 — confirm this is intended.
bool CatboostClusterVisibilityPredictor::isVisible(
    const common::Size& frameSize,
    const geolib3::Point2& frameMercPosition,
    geolib3::Heading frameHeading,
    const std::vector<DetectionMatchData>& matches) const
{
    std::vector<MatchEpilines> epilinesPerMatch;
    epilinesPerMatch.reserve(matches.size());
    for (const DetectionMatchData& detectionMatch : matches) {
        epilinesPerMatch.push_back(calcEpilines(
            detectionMatch.matchPlace,
            detectionMatch.match.fundMatrix,
            detectionMatch.detectionBbox));
    }

    const std::vector<float> factors = calcObjectVisibilityFactors(
        frameSize, frameMercPosition, frameHeading, matches, epilinesPerMatch);

    double prediction = 0.;
    model_.Calc(factors, {}, MakeArrayRef(&prediction, 1));

    constexpr double VISIBILITY_THRESHOLD = 0.5;
    return prediction > VISIBILITY_THRESHOLD;
}

} // namespace maps::mrc::eye
