#pragma once

#include <ymod_ratecontroller/errors.h>

#include <ymod_ratecontroller/rate_controller.h>

#include <yplatform/module.h>
#include <yplatform/find.h>
#include <yplatform/util/split.h>

#include <boost/asio/io_service.hpp>
#include <boost/algorithm/string.hpp>
#include <boost/property_tree/ptree.hpp>

#include <atomic>
#include <cassert>
#include <cstddef>
#include <deque>
#include <exception>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>

namespace ymod_ratecontroller {

namespace ph = std::placeholders;

/// FIFO concurrency limiter for a single resource.
///
/// Queued tasks are started in order while fewer than `max_concurrency` of them are
/// running. All queue state (`tasks_`, `tasks_counter_`, task states) is touched only
/// from handlers posted to `io_`; the public mutators `post()` and `cancel()` merely
/// post work there. NOTE(review): this relies on `io_` running on a single thread —
/// the module below rejects reactors with more than one io thread per pool.
/// `queue_size_` and `running_tasks_` are atomic mirrors so the stats getters can be
/// read from any thread.
class rate_controller_impl
    : public rate_controller
    , public std::enable_shared_from_this<rate_controller_impl>
{
public:
    /// @param io              io_service on which every queue operation is executed.
    /// @param max_concurrency maximum number of started-but-not-completed tasks.
    /// @param max_queue_size  maximum number of queue entries; this count includes
    ///                        aborted entries that have not been drained yet.
    rate_controller_impl(
        boost::asio::io_service& io,
        std::size_t max_concurrency,
        std::size_t max_queue_size)
        : io_(io), max_concurrency_(max_concurrency), max_queue_size_(max_queue_size)
    {
    }

    /// Enqueues `task` (asynchronously, on `io_`).
    ///
    /// The task is later invoked exactly once, with one of:
    ///  - `error::capacity_exceeded` if the queue is full (completion handler is a no-op);
    ///  - `error::add_to_rc_queue_exception` if queueing threw;
    ///  - `error::task_aborted` if cancelled or if `deadline` passed while still pending;
    ///  - an empty `error_code` plus a completion handler the task must call when done.
    /// `time_point::max()` as `deadline` means "no deadline".
    void post(
        const task_type& task,
        const std::string& task_id,
        const time_traits::time_point& deadline) override
    {
        io_.post([this, self = shared_from_this(), task_id, task, deadline] {
            if (tasks_.size() >= max_queue_size_)
            {
                task(error::capacity_exceeded, [] {});
                return;
            }
            // Only consume a sequence number once the task is really in the queue:
            // on_deadline() locates a task by `task_num - front().task_num`, which is
            // correct only while the numbers stored in `tasks_` are contiguous.
            const std::size_t task_num = tasks_counter_;
            try
            {
                tasks_.emplace_back(io_, task_id, task, task_num);
            }
            catch (const std::exception&)
            {
                task(error::add_to_rc_queue_exception, [] {});
                return;
            }
            ++tasks_counter_;
            queue_size_.fetch_add(1, std::memory_order_relaxed);
            if (deadline != time_traits::time_point::max())
            {
                auto& timer = tasks_.back().deadline_timer;
                timer.expires_at(deadline);
                timer.async_wait(
                    std::bind(&rate_controller_impl::on_deadline, self, ph::_1, task_num));
            }
        });
        io_.post(std::bind(&rate_controller_impl::keep_running, shared_from_this()));
    }

    /// Aborts every still-pending queued task whose id equals `task_id`.
    /// Tasks that are already running are not affected.
    void cancel(const std::string& task_id) override
    {
        io_.post([this, self = shared_from_this(), task_id] {
            for (auto& task_data : tasks_)
            {
                if (task_data.task_id == task_id)
                {
                    abort(task_data);
                }
            }
        });
    }

    std::size_t max_queue_size() const
    {
        return max_queue_size_;
    }

    /// Number of queue entries (pending + aborted-but-not-drained); safe from any thread.
    std::size_t queue_size() const override
    {
        return queue_size_.load(std::memory_order_relaxed);
    }

    /// Number of started tasks that have not called their completion handler yet.
    std::size_t running_tasks_count() const override
    {
        return running_tasks_.load(std::memory_order_relaxed);
    }

    std::size_t max_concurrency() const override
    {
        return max_concurrency_;
    }

private:
    enum class task_state
    {
        pending,
        aborted,
        running
    };

    /// One queue entry. `task_num` is a per-controller sequence number; numbers of
    /// entries in `tasks_` are consecutive, starting at `tasks_.front().task_num`.
    struct queued_task
    {
        queued_task(
            boost::asio::io_service& io,
            const std::string& task_id,
            const task_type& task,
            std::size_t task_num)
            : task_id(task_id), task(task), deadline_timer(io), task_num(task_num)
        {
        }

        std::string task_id;
        task_type task;
        time_traits::timer deadline_timer;
        task_state state = task_state::pending;
        std::size_t task_num;
    };

    /// Completion handler target (wrapped onto `io_`): frees a slot and starts more.
    void on_complete()
    {
        assert(running_tasks_count() > 0);
        running_tasks_.fetch_sub(1, std::memory_order_relaxed);
        keep_running();
    }

    /// Deadline timer handler for task `task_num`.
    ///
    /// Returns early when the timer was cancelled (`err` set) or the task has already
    /// left the queue (it was started or drained); otherwise aborts it if still pending.
    void on_deadline(error_code err, std::size_t task_num)
    {
        io_.post(std::bind(&rate_controller_impl::keep_running, shared_from_this()));
        if (err || tasks_.empty() || tasks_.front().task_num > task_num)
        {
            return;
        }
        const std::size_t index = task_num - tasks_.front().task_num;
        if (index >= tasks_.size())
        {
            // Cannot happen while numbering is contiguous; never throw out of a handler.
            return;
        }
        auto& task_data = tasks_[index];
        assert(task_data.task_num == task_num);
        abort(task_data);
    }

    /// Marks a pending task aborted and notifies it with `error::task_aborted`.
    /// The entry stays in the queue; keep_running() drops it when it reaches the front.
    void abort(queued_task& task_data)
    {
        if (task_data.state == task_state::pending)
        {
            task_data.state = task_state::aborted;
            io_.post(std::bind(task_data.task, error::task_aborted, [] {}));
        }
    }

    /// Pops entries from the front while concurrency allows: pending ones are started,
    /// aborted ones are just discarded.
    void keep_running()
    {
        completion_handler on_complete_cb =
            io_.wrap(std::bind(&rate_controller_impl::on_complete, shared_from_this()));
        while (!tasks_.empty() && running_tasks_count() < max_concurrency_)
        {
            auto& data = tasks_.front();
            if (data.state == task_state::pending)
            {
                data.state = task_state::running;
                running_tasks_.fetch_add(1, std::memory_order_relaxed);
                io_.post(std::bind(data.task, error_code(), on_complete_cb));
                data.deadline_timer.cancel();
            }
            tasks_.pop_front();
            queue_size_.fetch_sub(1, std::memory_order_relaxed);
        }
    }

    boost::asio::io_service& io_;
    std::deque<queued_task> tasks_;
    const std::size_t max_concurrency_;
    const std::size_t max_queue_size_;
    std::size_t tasks_counter_ = 0; // next sequence number to hand out
    std::atomic<std::size_t> running_tasks_{ 0 };
    std::atomic<std::size_t> queue_size_{ 0 };
};

/// Module that hands out one rate_controller per dot-separated resource path.
///
/// Limits come from a tree of "settings"/"children" config sections built once in
/// init(); each controller is created lazily on first request and then cached.
class rate_controller_module_impl
    : public rate_controller_module
    , public yplatform::module
{
public:
    /// Returns the cached controller for `resource_path`, creating it on first use
    /// with the settings resolved by get_settings(). `controllers_` is guarded by `mutex_`.
    rate_controller_ptr get_controller(const std::string& resource_path) override
    {
        std::lock_guard<std::mutex> lock(mutex_);
        auto it = controllers_.find(resource_path);
        if (it != controllers_.end())
        {
            return it->second;
        }
        auto cur_settings = get_settings(resource_path);
        auto rc = std::make_shared<rate_controller_impl>(
            *reactor_->io(), cur_settings.max_concurrency, cur_settings.max_queue_size);
        controllers_.emplace(resource_path, rc);
        return rc;
    }

    /// Looks up the reactor named by `conf.reactor`, requires every pool in it to be
    /// single-threaded (rate_controller_impl does no locking), and builds the settings tree.
    /// @throws std::runtime_error if some pool has more than one io thread.
    void init(const yplatform::ptree& conf)
    {
        reactor_ = yplatform::find_reactor(conf.get<std::string>("reactor"));
        for (std::size_t i = 0; i < reactor_->size(); ++i)
        {
            if ((*reactor_)[i]->size() != 1)
            {
                throw std::runtime_error("rate controller is optimized for single-thread reactors "
                                         "- set pool_count=N and io_threads=1");
            }
        }

        settings_ = build_settings(conf);
    }

    /// Per-controller snapshot keyed by resource path: queue/concurrency counters and limits.
    yplatform::ptree get_stats() const override
    {
        std::lock_guard<std::mutex> lock(mutex_);

        yplatform::ptree ret;
        for (const auto& pair : controllers_)
        {
            const auto& resource_path = pair.first;
            const auto& rc = pair.second;
            yplatform::ptree controller;
            controller.put("queue_size", rc->queue_size());
            controller.put("max_queue_size", rc->max_queue_size());
            controller.put("max_concurrency", rc->max_concurrency());
            controller.put("running_tasks_count", rc->running_tasks_count());
            ret.push_back(std::make_pair(resource_path, controller));
        }

        return ret;
    }

private:
    /// Limits for one node of the settings tree.
    struct settings
    {
        settings() = default;
        /// Reads both limits from `conf`; throws if either key is missing or malformed.
        explicit settings(const yplatform::ptree& conf)
            : max_concurrency(conf.get<std::size_t>("max_concurrency"))
            , max_queue_size(conf.get<std::size_t>("max_queue_size"))
        {
        }

        std::size_t max_concurrency = 0;
        std::size_t max_queue_size = 0;
    };

    using settings_tree = boost::property_tree::basic_ptree<std::string, settings>;

    /// Recursively converts a config node into a settings tree.
    ///
    /// A node's own "settings" section, if present, replaces the inherited `def`
    /// wholesale; otherwise `def` is stored. Every "children" entry becomes a child
    /// keyed by its "name" and inherits this node's settings.
    settings_tree build_settings(const yplatform::ptree& conf, const settings& def = settings())
    {
        settings_tree node;
        settings node_settings = def;

        auto optional_settings = conf.get_child_optional("settings");
        if (optional_settings)
        {
            node_settings = settings(*optional_settings);
        }
        node.put_value(node_settings);

        auto values = conf.equal_range("children");
        for (auto& pair : boost::make_iterator_range(values))
        {
            auto& child_node = pair.second;
            auto child_name = child_node.get<std::string>("name");
            node.put_child(child_name, build_settings(child_node, node_settings));
        }

        return node;
    }

    /// Resolves settings for a dot-separated `resource_path` (e.g. "a.b.c").
    ///
    /// Descends `settings_` one path element at a time and stops at the first element
    /// without a matching child; the deepest node reached supplies the result. Since
    /// every node stores its effective (possibly inherited) settings, a partially
    /// matching path gets the settings of its longest configured prefix.
    settings get_settings(const std::string& resource_path) const
    {
        auto path_elements = yplatform::util::split(resource_path, ".");

        const settings_tree* settings_node = &settings_;
        // Range-for bounds the walk, so the end iterator is never dereferenced even
        // when every element of the path matches a configured child.
        for (const auto& element : path_elements)
        {
            auto child = settings_node->find(element);
            if (child == settings_node->not_found())
            {
                break;
            }
            settings_node = &child->second;
        }

        return settings_node->get_value<settings>();
    }

    mutable std::mutex mutex_;
    yplatform::reactor_ptr reactor_;
    std::map<std::string, std::shared_ptr<rate_controller_impl>> controllers_;
    settings_tree settings_;
};

}
