#pragma once

#include <ymod_webserver/codes.h>
#include <ymod_webserver/settings.h>
#include <ymod_webserver/websocket_codes.h>
#include <yplatform/util/histogram.h>
#include <yplatform/util/per_second_accumulator.h>
#include <yplatform/ptree.h>
#include <boost/lexical_cast.hpp>
#include <boost/algorithm/string/case_conv.hpp>
#include <boost/algorithm/string/trim.hpp>
#include <mutex>
#include <cctype>
#include <tuple>
#include <map>
#include <optional>

namespace ymod_webserver {

namespace detail {

/// Normalizes a request path into a token that can be embedded in a stats key.
///
/// Steps, in order: lower-case the path, turn every '/' into '_', drop every
/// character that is neither alphanumeric nor '_', then strip leading and
/// trailing '_'. For example "/Api/V1/Get-User/" becomes "api_v1_getuser".
///
/// @param path  Raw path; may be empty.
/// @return The normalized token. It is empty when nothing survives the
///         filtering (for instance for "" or "/").
inline string format_path(const string& path)
{
    auto formatted_path = path;
    boost::algorithm::to_lower(formatted_path);

    std::replace(formatted_path.begin(), formatted_path.end(), '/', '_');

    formatted_path.erase(
        std::remove_if(
            formatted_path.begin(),
            formatted_path.end(),
            [](char c) {
                // The <cctype> classifiers have undefined behaviour for
                // negative arguments other than EOF, and plain char may be
                // signed, so bytes >= 0x80 must be passed as unsigned char.
                const auto byte = static_cast<unsigned char>(c);
                return !(std::isalnum(byte) || c == '_');
            }),
        formatted_path.end());

    boost::algorithm::trim_if(formatted_path, [](char c) { return c == '_'; });

    return formatted_path;
}

}

/// Cumulative per-code counters: one set of totals plus one set for every
/// registered path.
///
/// Paths are normalized with detail::format_path() both when they are
/// registered and when they are looked up. Only codes present in the counter
/// set handed to the constructor are counted; any other code is ignored.
/// The counters are std::atomic, but this class takes no lock around the maps
/// themselves.
class code_stats
{
public:
    using code_counters = std::map<uint32_t, std::atomic<uint32_t>>;

    /// @param counters  The set of codes to track; it becomes the totals map.
    explicit code_stats(code_counters counters) : code_counters_(std::move(counters))
    {
    }

    /// Registers a path so that codes are also counted for it separately.
    /// Paths that normalize to an empty string and paths that are already
    /// registered are left alone.
    void add_path(const string& path)
    {
        const auto key = detail::format_path(path);
        if (key.empty())
        {
            return;
        }

        const auto insertion = code_counters_by_path_.emplace(key, code_counters{});
        if (!insertion.second)
        {
            return;
        }

        // A fresh path tracks exactly the same codes as the totals do.
        auto& per_path = insertion.first->second;
        for (const auto& tracked : code_counters_)
        {
            per_path[tracked.first] = {};
        }
    }

    /// Bumps the total counter for `code` and, when `path` is registered, the
    /// counter of that path as well.
    void increment_code_counters(const string& path, const uint32_t code)
    {
        if (auto total = code_counters_.find(code); total != code_counters_.end())
        {
            total->second.fetch_add(1);
        }

        const auto by_path = code_counters_by_path_.find(detail::format_path(path));
        if (by_path == code_counters_by_path_.end())
        {
            return;
        }

        auto& per_path = by_path->second;
        if (auto hit = per_path.find(code); hit != per_path.end())
        {
            hit->second.fetch_add(1);
        }
    }

    /// Dumps every non-zero counter as a string value. Per-path entries use
    /// the key "codes_<path>_<code>_cumulative", totals use
    /// "codes_code_<code>_cumulative".
    // NOTE(review): a path whose normalized form is "code" would produce the
    // same keys as the totals, and the totals are written last.
    yplatform::ptree to_ptree() const
    {
        yplatform::ptree codes;

        for (const auto& [path, counters] : code_counters_by_path_)
        {
            const auto prefix = "codes_" + path + "_";
            for (const auto& [code, count] : counters)
            {
                if (count.load() == 0)
                {
                    continue;
                }
                codes.put(
                    prefix + std::to_string(code) + "_cumulative", std::to_string(count.load()));
            }
        }

        for (const auto& [code, count] : code_counters_)
        {
            if (count.load() == 0)
            {
                continue;
            }
            codes.put(
                "codes_code_" + std::to_string(code) + "_cumulative", std::to_string(count.load()));
        }

        return codes;
    }

private:
    std::map<string, code_counters> code_counters_by_path_;
    code_counters code_counters_;
};

/// Timing histograms: one for all requests plus one per registered path.
///
/// Paths are normalized with detail::format_path() on registration and on
/// lookup. Every access to the histograms, including path registration, is
/// serialized through an internal mutex.
class timing_stats
{
public:
    using histogram_axis =
        yplatform::hgram::axis::regular<double, yplatform::hgram::axis::transform::log>;
    using histogram = yplatform::hgram::histogram<std::tuple<histogram_axis>>;

    /// @param axis  Axis shared by the total histogram and by every per-path
    ///              histogram created later.
    explicit timing_stats(histogram_axis axis)
        : timing_hgram_axis_(std::move(axis))
        , total_timing_hgram_(std::tuple<histogram_axis>(timing_hgram_axis_))
    {
    }

    /// Registers a path so that timings are also recorded for it separately.
    /// Paths that normalize to an empty string are ignored; registering a
    /// path twice keeps the existing histogram.
    void add_path(const string& path)
    {
        auto formatted_path = detail::format_path(path);
        if (formatted_path.empty()) return;
        // The map is read and written under mutex_ everywhere else, so the
        // insertion has to hold it too.
        lock_t lock(mutex_);
        timing_hgrams_by_path_.emplace(
            formatted_path, yplatform::hgram::make_histogram(timing_hgram_axis_));
    }

    /// Records one sample in the total histogram and, when `path` is
    /// registered, in the histogram of that path.
    ///
    /// @param path     Raw request path.
    /// @param seconds  Sample value fed to the histograms.
    void add_timing(const string& path, double seconds)
    {
        // Normalizing the path touches no shared state, so it stays outside
        // the critical section.
        const auto formatted_path = detail::format_path(path);

        lock_t lock(mutex_);
        total_timing_hgram_(seconds);
        auto timing_hgram_by_path_it = timing_hgrams_by_path_.find(formatted_path);
        if (timing_hgram_by_path_it != timing_hgrams_by_path_.end())
        {
            timing_hgram_by_path_it->second(seconds);
        }
    }

    /// @return A copy of the per-path histograms, keyed by normalized path.
    std::map<string, histogram> timings_by_path() const
    {
        lock_t lock(mutex_);
        return timing_hgrams_by_path_;
    }

    /// @return A copy of the histogram that covers all recorded samples.
    histogram total_timings() const
    {
        lock_t lock(mutex_);
        return total_timing_hgram_;
    }

    /// Dumps the total histogram under "timings_cumulative_hgram" and every
    /// per-path histogram under "timings_<path>_cumulative_hgram".
    yplatform::ptree to_ptree() const
    {
        lock_t lock(mutex_);
        yplatform::ptree timings;
        timings.put_child(
            "timings_cumulative_hgram", yplatform::hgram::to_ptree(total_timing_hgram_));
        for (auto&& [path, timing] : timing_hgrams_by_path_)
        {
            timings.put_child(
                "timings_" + path + "_cumulative_hgram", yplatform::hgram::to_ptree(timing));
        }
        return timings;
    }

private:
    using lock_t = std::lock_guard<std::mutex>;

    mutable std::mutex mutex_;
    // Never modified after construction; declared before total_timing_hgram_
    // because that member is initialized from it.
    histogram_axis timing_hgram_axis_;
    std::map<string, histogram> timing_hgrams_by_path_;
    histogram total_timing_hgram_;
};

namespace detail {

// These constants live in a header and are used by inline functions, so they
// are inline constexpr: one entity shared by every translation unit instead of
// a separate internal-linkage copy per translation unit.

// Half-open range [MIN, MAX) of status codes probed by make_http_code_counters().
inline constexpr uint32_t MIN_HTTP_STATUS_CODE = 100;
inline constexpr uint32_t MAX_HTTP_STATUS_CODE = 600;
// Lower and upper bounds of the timing histogram axes. The histograms are fed
// by timing_stats::add_timing(), whose sample parameter is named `seconds`.
inline constexpr double TIMING_MIN_BOUND = 0.001;
inline constexpr double HTTP_TIMING_MAX_BOUND = 600;
inline constexpr double WEBSOCKET_TIMING_MAX_BOUND = 86400;
// Bucket count handed to make_hgram_axis() for both timing axes.
inline constexpr std::size_t TIMING_HGRAM_BUCKET_NUM = 50;

/// Builds the HTTP counter set: a zeroed counter for every code in
/// [MIN_HTTP_STATUS_CODE, MAX_HTTP_STATUS_CODE) for which codes::reason::get()
/// returns something other than codes::reason::get_unknown_reason().
inline code_stats::code_counters make_http_code_counters()
{
    code_stats::code_counters result;
    for (uint32_t status = MIN_HTTP_STATUS_CODE; status < MAX_HTTP_STATUS_CODE; ++status)
    {
        const bool has_reason =
            codes::reason::get(status) != codes::reason::get_unknown_reason();
        if (!has_reason)
        {
            continue;
        }
        result[status].store(0);
    }
    return result;
}

/// Builds the websocket counter set: a zeroed counter for every code listed
/// by websocket::codes::stat_codes().
inline code_stats::code_counters make_websocket_code_counters()
{
    code_stats::code_counters result;
    for (const auto tracked_code : websocket::codes::stat_codes())
    {
        result[tracked_code].store(0);
    }
    return result;
}

/// Creates a timing histogram axis.
///
/// @param bucket_num  Total bucket count, including the underflow and
///                    overflow buckets; must be greater than 2.
/// @param min_bound   Lower bound passed to the axis.
/// @param max_bound   Upper bound passed to the axis.
/// @throws std::domain_error when bucket_num is 2 or less.
inline timing_stats::histogram_axis make_hgram_axis(
    std::size_t bucket_num,
    double min_bound,
    double max_bound)
{
    // Under- and overflow buckets are added by default, so they are taken out
    // of the requested total before the axis is built.
    constexpr std::size_t implicit_buckets = 2;
    if (bucket_num <= implicit_buckets)
    {
        throw std::domain_error("invalid number of histogram buckets");
    }
    return timing_stats::histogram_axis(bucket_num - implicit_buckets, min_bound, max_bound);
}

/// Axis for HTTP timing histograms: TIMING_HGRAM_BUCKET_NUM buckets between
/// TIMING_MIN_BOUND and HTTP_TIMING_MAX_BOUND.
inline timing_stats::histogram_axis make_http_timings_axis()
{
    const double lower_bound = TIMING_MIN_BOUND;
    const double upper_bound = HTTP_TIMING_MAX_BOUND;
    return make_hgram_axis(TIMING_HGRAM_BUCKET_NUM, lower_bound, upper_bound);
}

/// Axis for websocket timing histograms: TIMING_HGRAM_BUCKET_NUM buckets
/// between TIMING_MIN_BOUND and WEBSOCKET_TIMING_MAX_BOUND.
inline timing_stats::histogram_axis make_websocket_timings_axis()
{
    const double lower_bound = TIMING_MIN_BOUND;
    const double upper_bound = WEBSOCKET_TIMING_MAX_BOUND;
    return make_hgram_axis(TIMING_HGRAM_BUCKET_NUM, lower_bound, upper_bound);
}

}

class module_stats
{
public:
    module_stats(const settings& settings = {})
        : rps_acc_(61)
        , http_code_stats_(detail::make_http_code_counters())
        , websocket_code_stats_(detail::make_websocket_code_counters())
        , websocket_streams_max_(settings.ws_max_streams)
    {
        if (settings.enable_timing_stats)
        {
            http_timing_stats_.emplace(detail::make_http_timings_axis());
            websocket_timing_stats_.emplace(detail::make_websocket_timings_axis());
        }
    }

    void add_http_path(const string& path)
    {
        http_code_stats_.add_path(path);
        if (http_timing_stats_) http_timing_stats_->add_path(path);
    }

    void add_websocket_path(const string& path)
    {
        websocket_code_stats_.add_path(path);
        if (websocket_timing_stats_) websocket_timing_stats_->add_path(path);
    }

    void increment_http_code_counters(const string& path, const uint32_t code)
    {
        http_code_stats_.increment_code_counters(path, code);
    }

    void increment_websocket_code_counters(const string& path, const uint32_t code)
    {
        websocket_code_stats_.increment_code_counters(path, code);
    }

    void add_http_timing(const string& path, double seconds)
    {
        if (http_timing_stats_) http_timing_stats_->add_timing(path, seconds);
    }

    void add_websocket_timing(const string& path, double seconds)
    {
        if (websocket_timing_stats_) websocket_timing_stats_->add_timing(path, seconds);
    }

    void new_request()
    {
        lock_t lock(mutex_);
        rps_acc_.add(1);
    }

    yplatform::ptree to_ptree() const
    {
        yplatform::ptree result;
        {
            lock_t lock(mutex_);
            result.put("rps", boost::lexical_cast<string>(rps_acc_.get_last()));
            result.put("rps_avg_min", boost::lexical_cast<string>(rps_acc_.get_avg(60)));
        }
        result.push_back(std::pair("http", http_code_stats_.to_ptree()));
        if (http_timing_stats_) result.push_back(std::pair("http", http_timing_stats_->to_ptree()));
        result.push_back(std::pair("websocket", websocket_code_stats_.to_ptree()));
        if (websocket_timing_stats_) result.push_back(std::pair("websocket", websocket_timing_stats_->to_ptree()));
        result.push_back(std::pair("sessions", session_stats_to_ptree()));
        result.put(
            "websocket_streams_count", boost::lexical_cast<string>(websocket_streams_count.load()));
        if (websocket_streams_max_ != std::numeric_limits<size_t>::max())
        {
            result.put(
                "websocket_streams_max", boost::lexical_cast<string>(websocket_streams_max_));
        }
        return result;
    }

    std::atomic<size_t> websocket_streams_count = 0;

    struct
    {
        std::atomic<size_t> total_connected = 0;
        std::atomic<size_t> handshake_errors = 0;
        std::atomic<size_t> throttled_handshakes = 0;
        std::atomic<size_t> read_errors = 0;
        std::atomic<size_t> write_errors = 0;
        std::atomic<size_t> ssl_errors = 0;
    } session_stats;

private:
    yplatform::ptree session_stats_to_ptree() const
    {
        yplatform::ptree sessions;
        sessions.put("connected_cumulative", session_stats.total_connected.load());
        sessions.put("handshake_errors_cumulative", session_stats.handshake_errors.load());
        sessions.put("throttled_handshakes_cumulative", session_stats.throttled_handshakes.load());
        sessions.put("read_errors_cumulative", session_stats.read_errors.load());
        sessions.put("write_errors_cumulative", session_stats.write_errors.load());
        sessions.put("ssl_errors_cumulative", session_stats.ssl_errors.load());
        return sessions;
    }

    using lock_t = std::lock_guard<std::mutex>;

    mutable std::mutex mutex_;
    mutable yplatform::per_second_accumulator<int> rps_acc_;
    code_stats http_code_stats_;
    code_stats websocket_code_stats_;
    std::optional<timing_stats> http_timing_stats_;
    std::optional<timing_stats> websocket_timing_stats_;
    size_t websocket_streams_max_;
};

}