#pragma once

#include <yplatform/util/per_second_accumulator.h>
#include <ymod_httpclient/errors.h>
#include <ymod_httpclient/settings.h>
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <ctime>
#include <stdexcept>
#include <vector>

namespace ymod_httpclient::detail {

/// Request/accept counters for a single node, used to derive its WRS weight.
struct wrs_stat
{
    int64_t requests = 0; // requests sent to the node
    int64_t accepts = 0;  // requests counted as accepted (successful)

    /// Weight = ((accepts + 1) / (requests + 1)) ^ factor, never below 1e-5.
    /// The +1 terms keep an idle node at ratio 1 and avoid division by zero;
    /// the floor keeps every node selectable with a tiny probability.
    double weight(double factor) const
    {
        const double smoothed_ratio = (accepts + 1.0) / (requests + 1.0);
        const double scaled = std::pow(smoothed_ratio, factor);
        return std::max(0.00001, scaled);
    }

    /// Element-wise accumulation of counters (used when summing windows).
    wrs_stat& operator+=(const wrs_stat& rhs)
    {
        this->requests += rhs.requests;
        this->accepts += rhs.accepts;
        return *this;
    }
};

/// Additive counters backing the retry ("errors") budget.
struct errors_budget_stat
{
    int64_t requests = 0;                   // all counted requests
    int64_t retries = 0;                    // requests that were retries
    int64_t connect_errors = 0;             // requests that ended in a connect error
    int64_t retries_with_connect_error = 0; // retries that ended in a connect error

    /// Field-wise sum; returns *this so windows can be folded together.
    errors_budget_stat& operator+=(const errors_budget_stat& rhs)
    {
        this->requests += rhs.requests;
        this->retries += rhs.retries;
        this->connect_errors += rhs.connect_errors;
        this->retries_with_connect_error += rhs.retries_with_connect_error;
        return *this;
    }
};

/// Running sum/count of computed weights; sum / count gives the average.
struct weight_stat
{
    /// Empty statistic: no samples.
    weight_stat() = default;

    /// One sample of `weight`. Intentionally implicit so a plain double can be
    /// passed where a weight_stat is expected.
    weight_stat(double weight) : sum(weight), count(1)
    {
    }

    /// Merges another statistic into this one.
    void operator+=(const weight_stat& rhs)
    {
        this->sum += rhs.sum;
        this->count += rhs.count;
    }

    double sum = 0;
    size_t count = 0;
};

/// Request statistics feeding the retry ("errors") budget and, when
/// `wrs_nodes_count > 0`, per-node counters for weighted random selection.
/// All counters live in yplatform::per_second_accumulator instances created
/// with `stats_period` (bucket semantics are defined by that class).
/// NOTE(review): no synchronization here — presumably callers serialize access; confirm.
struct request_stat
{
    /// @param stats_period          period passed to every accumulator; must be > 1.
    /// @param count_connect_errors  when false, requests/retries that failed with
    ///                              http_error::connect_error are excluded from the
    ///                              retry ratio.
    /// @param wrs_nodes_count       number of per-node accumulators; 0 means
    ///                              per-node counting is skipped.
    /// @throws std::runtime_error   if stats_period <= 1.
    ///
    /// `stats_period` is the first data member, so the validation in its
    /// initializer runs before any accumulator is built with a bad period.
    request_stat(unsigned stats_period, bool count_connect_errors, unsigned wrs_nodes_count)
        : stats_period(checked_stats_period(stats_period))
        , count_connect_errors(count_connect_errors)
        , errors_budget(stats_period)
        , wrs(wrs_nodes_count, stats_period)
        , weights_average(wrs_nodes_count, stats_period)
    {
    }

    /// Recomputes and caches `retries_ratio`: retries per first attempt,
    /// i.e. retries / (requests - retries), over the accumulator window ending
    /// at now + 1. Returns 0 when there were no requests and
    /// max_retry_budget_value when every request was a retry.
    double calc_retries_ratio()
    {
        auto stats = errors_budget.get_sum(std::time(nullptr) + 1, stats_period);
        int64_t requests = stats.requests;
        int64_t retries = stats.retries;
        if (!count_connect_errors)
        {
            requests -= stats.connect_errors;
            retries -= stats.retries_with_connect_error;
        }
        if (requests <= 0) return retries_ratio = 0.0;
        // `>=` rather than `==`: the division below must never see a zero or
        // negative denominator, even if the counters become inconsistent.
        if (retries >= requests) return retries_ratio = max_retry_budget_value;
        return retries_ratio = retries / static_cast<double>(requests - retries);
    }

    /// Records a successful request; `attempt > 0` marks it as a retry.
    /// The per-node counter is updated only if `node_index` refers to an
    /// existing node (none exist when wrs_nodes_count was 0).
    void count_successfull_request(unsigned attempt, unsigned node_index)
    {
        errors_budget.add({ 1, attempt > 0, 0, 0 });
        if (node_index < wrs.size())
        {
            wrs[node_index].add({ 1, 1 });
        }
    }

    /// Records a failed request; connect errors are tracked separately so
    /// they can be excluded from the retry ratio. Per-node counter handling
    /// matches count_successfull_request.
    void count_failed_request(error_code ec, unsigned attempt, unsigned node_index)
    {
        auto connect_error = ec == http_error::connect_error;
        errors_budget.add({ 1, attempt > 0, connect_error, connect_error && attempt > 0 });
        if (node_index < wrs.size())
        {
            wrs[node_index].add({ 1, 0 });
        }
    }

    /// Returns `period` unchanged, or throws if it is not greater than 1.
    static unsigned checked_stats_period(unsigned period)
    {
        if (period <= 1) throw std::runtime_error("stats_period must be greater than 1");
        return period;
    }

    /// Ratio reported when every request in the window was a retry.
    static constexpr double max_retry_budget_value = 10.0;

    unsigned stats_period = 5;
    double retries_ratio = 0.0; // last value computed by calc_retries_ratio()
    bool count_connect_errors;
    yplatform::per_second_accumulator<errors_budget_stat> errors_budget;
    std::vector<yplatform::per_second_accumulator<wrs_stat>> wrs;
    std::vector<yplatform::per_second_accumulator<weight_stat>> weights_average;
};

using request_stat_ptr = shared_ptr<request_stat>;

template <typename StatsAcc, typename WeightAcc>
void calc_weights(
    std::vector<StatsAcc>& stats,
    std::vector<double>& out,
    std::vector<WeightAcc>& weights_average,
    double corrected_factor,
    const balancing_settings::wrs_settings& wrs_settings)
{
    out.resize(stats.size());
    time_t current_time = std::time(0);
    for (size_t i = 0; i < stats.size(); ++i)
    {
        out[i] =
            calc_weight(current_time, stats[i], weights_average[i], corrected_factor, wrs_settings);
        weights_average[i].update(current_time, out[i]);
    }
}

/// Weight of a single node at `time`:
///  - base weight comes from the node's full accumulator window;
///  - with no weight history yet, it is capped at settings.max_initial_weight;
///  - if the 2-bucket window weight is lower by more than
///    settings.momentary_fall_threshold times, that lower weight is used;
///  - growth is capped at settings.max_weight_growth times the historical
///    average weight.
template <typename StatsAcc, typename WeightAcc>
double calc_weight(
    time_t time,
    StatsAcc& stats,
    WeightAcc& weights_average,
    double corrected_factor,
    const balancing_settings::wrs_settings& settings)
{
    const auto window_end = time + 1;
    double result = stats.get_sum(window_end, stats.capacity()).weight(corrected_factor);

    const auto history = weights_average.get_sum(time);
    if (!history.count) return std::min(settings.max_initial_weight, result);

    // React quickly to a sudden drop seen in the short window.
    const double recent = stats.get_sum(window_end, 2).weight(corrected_factor);
    if (recent * settings.momentary_fall_threshold < result) result = recent;

    // Limit how fast a weight may climb relative to its recent average.
    const double growth_cap = history.sum / history.count * settings.max_weight_growth;
    return std::min(result, growth_cap);
}

}
