#pragma once

#include <ymod_webserver/settings.h>
#include <yplatform/algorithm/leaky_map.h>
#include <boost/asio/io_service.hpp>
#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

namespace ymod_webserver {

/// Request rate limiter configured by a single rate_limit_settings entry.
///
/// Requests are bucketed by the value of the configured limiting attribute
/// (a URL parameter or a header, selected by settings.limiting_attr). When no
/// limiting attribute is configured, or the request does not carry it, the
/// request is counted under the empty-string key. Counters are kept in a
/// yplatform::leaky_map built from settings.recovery_rate and
/// settings.recovery_interval_ms (presumably counters decay over time —
/// NOTE(review): confirm against leaky_map).
class rate_limiter
{
public:
    rate_limiter(const rate_limit_settings& settings)
        : counters_(settings.recovery_rate, settings.recovery_interval_ms), settings_(settings)
    {
    }

    /// Returns true when this request's bucket has reached settings.limit.
    /// Does not modify the counter itself; see increase_counter().
    template <typename Stream>
    bool limit_exceeded(const Stream& stream)
    {
        return counters_.get(counter_key(get_attr_value(stream))) >= settings_.limit;
    }

    /// Adds one to this request's bucket.
    template <typename Stream>
    void increase_counter(const Stream& stream)
    {
        counters_.add(counter_key(get_attr_value(stream)), 1);
    }

    /// HTTP status to reply with when the limit is exceeded.
    codes::code limit_exceeded_status() const
    {
        return settings_.response_status;
    }

    /// Response body to reply with when the limit is exceeded.
    const std::string& limit_exceeded_body() const
    {
        return settings_.response_body;
    }

    /// Configured limiter name (used e.g. for request logging).
    const string& name() const
    {
        return settings_.name;
    }

    /// True if `path` fully matches any of the configured path patterns.
    bool match_by_path(const string& path) const
    {
        return std::any_of(
            settings_.path.begin(), settings_.path.end(), [&path](const auto& pattern) {
                return boost::regex_match(path, pattern);
            });
    }

    /// True if the request satisfies every configured filter. Each filter is
    /// checked against URL params or headers depending on its attribute kind;
    /// an empty filter list matches every request.
    template <typename Stream>
    bool match_by_stream(const Stream& stream) const
    {
        const auto& params = stream->request()->url.params;
        const auto& headers = stream->request()->headers;
        return std::all_of(
            settings_.filters.begin(),
            settings_.filters.end(),
            [this, &params, &headers](const auto& filter) {
                return filter.attr.is_url_param() ? this->match_filter(filter, params) :
                                                    this->match_filter(filter, headers);
            });
    }

private:
    /// Shared empty string used wherever an attribute is absent.
    static const string& empty_value()
    {
        static const string empty;
        return empty;
    }

    /// Maps an optional attribute value to its counter key without copying.
    /// Both branches are lvalues of the same type, so the conditional yields a
    /// reference; writing `*attr_value : ""` instead would materialise a fresh
    /// string copy of the header/param value on every call.
    static const string& counter_key(const string* attr_value)
    {
        return attr_value ? *attr_value : empty_value();
    }

    /// Returns a pointer to the limiting attribute's value inside the request,
    /// or nullptr if no limiting attribute is configured or the request lacks
    /// it. The pointer refers into the request's params/headers and must not
    /// outlive the request.
    template <typename Stream>
    const string* get_attr_value(const Stream& stream) const
    {
        if (!settings_.limiting_attr)
        {
            return nullptr;
        }

        const auto& attr = *settings_.limiting_attr;
        if (attr.is_url_param())
        {
            const auto& params = stream->request()->url.params;
            auto it = params.find(attr.key);
            return it == params.end() ? nullptr : &it->second;
        }

        const auto& headers = stream->request()->headers;
        auto it = headers.find(attr.key);
        return it == headers.end() ? nullptr : &it->second;
    }

    /// Matches one filter against `attr_map` (URL params or headers). An
    /// absent attribute is matched as the empty string.
    template <typename AttrMap>
    bool match_filter(const rate_limit_settings::filter& filter, const AttrMap& attr_map) const
    {
        auto it = attr_map.find(filter.attr.key);
        if (it == attr_map.end())
        {
            return boost::regex_match(empty_value(), filter.value);
        }
        // Match the stored value in place; a `it->second : ""` conditional
        // would copy it into a temporary string first.
        return boost::regex_match(it->second, filter.value);
    }

    yplatform::leaky_map<string, uint64_t> counters_;
    rate_limit_settings settings_;
};

template <typename Handler>
struct rate_limit_wrapper
{
    using handler_type = typename std::decay_t<Handler>;

    rate_limit_wrapper(
        const std::vector<std::shared_ptr<rate_limiter>>& limiters,
        const Handler& handler)
        : limiters(limiters), handler(handler)
    {
        matched_limiters.reserve(limiters.size());
    }

    template <typename Stream, typename... Args>
    void operator()(Stream stream, Args&&... args)
    {
        matched_limiters.clear();
        for (auto& limiter : limiters)
        {
            if (limiter->match_by_stream(stream))
            {
                if (limiter->limit_exceeded(stream))
                {
                    stream->request()->context->custom_log_data["rate_limit"] = limiter->name();
                    stream->set_code(limiter->limit_exceeded_status());
                    stream->result_body(limiter->limit_exceeded_body());
                    return;
                }
                matched_limiters.push_back(limiter.get());
            }
        }
        for (auto& limiter : matched_limiters)
        {
            limiter->increase_counter(stream);
        }
        stream->get_io_service().post(std::bind(handler, stream, std::forward<Args>(args)...));
    }

    std::vector<std::shared_ptr<rate_limiter>> limiters;
    std::vector<rate_limiter*> matched_limiters;
    handler_type handler;
};

}
