#pragma once

// Finds a function f(x) using imprecise values of f and a
// precise (but not perfectly precise) derivative of f.
//
// Let's assume that there is an unknown array of x, where x[i] - x[i-1] ≈ x[i+1] - x[i]
// We know h(x) for each x[i] where
// h(x[i]) ≈ f(x[i])
// We know g(x) for each x[i] where
// g(x[i]) != f(x[i]) but g(x[i]) - g(x[i-1]) == f(x[i]) - f(x[i-1])

// Algorithm:
// f(x[i]) = g(x[i]) + average(h(x[j]) - g(x[j]) for j in range [i-K; i+K])
// To calculate the average(h(x[j])-g(x[j])) we need to define:
// 1) the size of the range K
// 2) the averaging algorithm
// 3) the weight of (h(x[j]) - g(x[j])) for each j.

// What is weight:
// If we know that the accuracy of h(x[j1]) is bad and the accuracy of
// h(x[j2]) is good, we should take h(x[j2]) into consideration more than h(x[j1]).
// So the weight can be (1 / accuracy).
// Also the precision of g(x) is not perfect, so we can't take very big K.
// We can take it into consideration using the weight:
// weight = 1 / accuracy / (1 + abs(i-j))
// This formula is just a recommendation

// Averaging algorithm
// We use a kind of median value. We have an array of values and the
// weight for each value. We sort all the items by value and take the item in the
// center of mass in the sorted array.

#include <maps/wikimap/mapspro/services/mrc/libs/position_improvment/include/events.h>

#include <maps/libs/common/include/exception.h>

#include <vector>
#include <algorithm>

namespace maps::mrc::pos_improvment {

// One observed offset h(x) - g(x) together with the weight it carries
// when several offsets are averaged.
template<typename DecimalType>
struct Shift {
    // The h(x) - g(x) difference itself.
    DecimalType shift;
    // Contribution of this sample to the weighted average.
    double weight;

    // Orders samples by offset only; the weight takes no part in the
    // comparison.
    bool operator <(const Shift& rhs) const
    {
        return shift < rhs.shift;
    }
};

// Weighted median of the given shifts.
//
// Sorts `shifts` in place by shift value (so the caller's vector is
// reordered), then walks the sorted items accumulating their weights and
// returns the shift of the first item at which the accumulated weight
// reaches half of the total weight.
//
// Throws Exception when no item reaches that threshold; in particular this
// is what happens for an empty vector.
template<typename DecimalType>
DecimalType getMedian(std::vector<Shift<DecimalType>>& shifts)
{
    double totalWeight = 0;
    for (const auto& item : shifts) {
        totalWeight += item.weight;
    }
    const double halfWeight = totalWeight / 2;

    std::sort(shifts.begin(), shifts.end());

    double accumulatedWeight = 0;
    for (const auto& item : shifts) {
        accumulatedWeight += item.weight;
        if (accumulatedWeight >= halfWeight) {
            return item.shift;
        }
    }
    throw Exception("Can't find median");
}

// Index aliases used in the signature of ValueWeightFunction below.
using BaseIndex = int64_t;
using CurIndex = int64_t;

// Should return weight for each value.
// @param BaseIndex - index i, it means that we are calculating f(x[i]) now.
// @param CurIndex - the index j in range [i-K; i+K]. The function should
// return the weight of (h(x[j]) - g(x[j])).
using ValueWeightFunction = std::function<double(BaseIndex, CurIndex)>;

// Returns f(x[i]) for each i;
// @param frameWidth = the size K of the range near x[i] (number of
// neighboring points)
// @param values - values of h(x[i]) for each i where
// h(x[i]) ≈ f(x[i]);
// @param valuesWithDerivative - values of g(x[i]) for each x[i] where
// g(x[i]) != f(x[i]) but g(x[i]) - g(x[i-1]) == f(x[i]) - f(x[i-1]);
template<typename DecimalType>
std::vector<DecimalType> matchFunction(
    std::vector<DecimalType> values,
    std::vector<DecimalType> valuesWithDerivative,
    size_t frameWidth,
    ValueWeightFunction getValueWeight = [](BaseIndex, CurIndex) { return 1; })
{
    REQUIRE(values.size() == valuesWithDerivative.size(),
            "input arrays should have equal sizes");

    std::vector<DecimalType> actualValues; // f(x)
    actualValues.reserve(values.size());
    std::vector<Shift<DecimalType>> shifts; // (h(x)-g(x))
    shifts.reserve(frameWidth);
    size_t baseIndex;

    auto addShift = [&](int curIndex) {
        if (curIndex >= 0 && curIndex < (int)values.size()) {
            double weight = getValueWeight(baseIndex, curIndex);
            if (weight > 0) {
                shifts.push_back(
                    {values[curIndex] - valuesWithDerivative[curIndex],
                     weight});
            }
        }
    };

    for (int i = 0; i < (int)values.size(); i++) {
        shifts.resize(0);
        baseIndex = i;

        addShift(i);
        int curFrameWidth = 1;
        while (shifts.size() < frameWidth) {
            if (i - curFrameWidth < 0 && i + curFrameWidth >= (int)values.size()) {
                break;
            }
            addShift(i + curFrameWidth);
            addShift(i - curFrameWidth);
            curFrameWidth++;
        }

        actualValues.push_back(
            valuesWithDerivative[i] + getMedian(shifts));
    }

    return actualValues;
}

} // namespace maps::mrc::pos_improvment
