#include "connection_controller.h"
#include "session.h"
#include "module_stats.h"
#include "net_server.h"
#include "default_handler.h"
#include "rate_limit.h"
#include <ymod_webserver/server.h>
#include <ymod_webserver/methods/dispatcher.h>
#include <yplatform/net/server.h>
#include <yplatform/net/acceptor.h>
#include <yplatform/net/module.h>
#include <yplatform/util/execution_holder.h>
#include <yplatform/util/weak_bind.h>

namespace ymod_webserver {

/// Implementation behind ymod_webserver::server.
///
/// For every io_service of the reactor it keeps one io_data and, per configured
/// endpoint, one tcp_server plus one yplatform_net_server. Request handlers are
/// kept per endpoint name: either a user handler installed via subscribe(), or a
/// lazily created default_handler populated by the bind_* / set_* methods.
/// start() wires a handler to every endpoint and begins listening.
class server::impl
    : public yplatform::log::contains_logger
    , public std::enable_shared_from_this<server::impl>
{
public:
    typedef boost::shared_ptr<default_handler> default_handler_ptr;

    /// Builds per-io state, binds the listening sockets and creates the rate
    /// limiters and the optional congestion controller from @p conf.
    /// @throws std::runtime_error if reactor->plain() is false.
    impl(yplatform::reactor_ptr reactor, const settings& conf)
        : settings_(std::make_shared<settings>(conf)), reactor_(reactor)
    {
        if (!reactor->plain())
        {
            throw std::runtime_error(
                "webserver requires a single-thread reactor - set pool_count=N and io_threads=1");
        }

        io_ = reactor_->io();
        stats_ = std::make_shared<module_stats>(*settings_);

        // Read certificates and bind to sockets at init because
        // yplatform process may have insufficient privileges at start.
        io_data_.reserve(reactor_->size());
        for (size_t i = 0; i < reactor_->size(); ++i)
        {
            io_data_.emplace_back(*(*reactor_)[i]->io());
            auto& io = io_data_.back();
            io.setup_ssl(settings_->ssl);
            io.setup_dns();
        }

        // One (webserver, tcp_server) pair per (io, endpoint), stored io-major and
        // endpoint-minor; listen() relies on exactly this order to pair them by index.
        // Reserving up front also avoids relocating already-bound tcp_servers.
        const size_t server_count = io_data_.size() * settings_->endpoints.size();
        webservers_.reserve(server_count);
        tcp_servers_.reserve(server_count);
        for (auto& io : io_data_)
        {
            for (auto& ep_pair : settings_->endpoints)
            {
                auto& ep = ep_pair.second;
                webservers_.push_back(
                    boost::make_shared<yplatform_net_server>(io, settings_, stats_));
                tcp_servers_.emplace_back(io, ep.addr, ep.port, ep.socket_settings);
            }
        }

        for (auto& limit_settings : settings_->rate_limits)
        {
            rate_limiters_.emplace_back(std::make_shared<rate_limiter>(limit_settings));
        }

        if (settings_->congestion_control)
        {
            connection_controller_ =
                std::make_shared<connection_controller>(reactor_, *settings_->congestion_control);
        }
    }

    /// Passes this object's logger to every webserver, wires a handler to every
    /// endpoint and starts accepting connections.
    void start()
    {
        for (auto& server : webservers_)
        {
            server->logger(logger());
        }
        setup_default_handlers();
        listen();
    }

    /// Calls stop() on every tcp_server created in the constructor.
    void stop()
    {
        for (auto& server : tcp_servers_)
        {
            server.stop();
        }
    }

    /// Installs @p h as the handler of endpoint @p ep_name. It takes precedence over
    /// the default handler built by bind_* / set_* calls.
    /// NOTE(review): presumably meant to be called before start(); afterwards the
    /// map is read concurrently by accept_connection().
    void subscribe(const string& ep_name, const handler_ptr& h)
    {
        YLOG_LOCAL(info) << "add " << static_cast<void*>(h.get()) << " to " << ep_name
                         << " subscribers";
        endpoint_handlers_[ep_name] = h;
    }

    /// Module statistics, plus a "congestion-control" subtree when a connection
    /// controller is configured.
    yplatform::ptree get_stats() const
    {
        auto ret = stats_->to_ptree();
        if (connection_controller_)
        {
            ret.push_back(std::pair("congestion-control", connection_controller_->get_stats()));
        }
        return ret;
    }

    /// Returns the default handler of @p ep_name, creating an empty one on first use.
    default_handler_ptr get_default_handler(const string& ep_name)
    {
        auto& h = default_handlers_[ep_name];
        if (!h) h = boost::make_shared<default_handler>();
        return h;
    }

    /// Routes each of @p paths on endpoint @p ep_name to @p method. Paths matched by
    /// any configured rate limiter get @p method wrapped with those limiters.
    void bind_http(
        const string& ep_name,
        const std::vector<std::string>& paths,
        const http_method& method)
    {
        auto h = get_default_handler(ep_name);
        for (const auto& path : paths)
        {
            std::vector<std::shared_ptr<rate_limiter>> limiters;
            for (const auto& limiter : rate_limiters_)
            {
                if (limiter->match_by_path(path))
                {
                    limiters.emplace_back(limiter);
                }
            }
            if (!limiters.empty())
            {
                // synchronize access to rate_limiters with io_
                h->http_dispatcher[path] = io_->wrap(rate_limit_wrapper(limiters, method));
            }
            else
            {
                h->http_dispatcher[path] = method;
            }
            stats_->add_http_path(path);
        }
    }

    /// Routes each of @p paths on endpoint @p ep_name to the websocket @p method.
    void bind_websocket(
        const string& ep_name,
        const std::vector<std::string>& paths,
        const websocket_method& method)
    {
        auto h = get_default_handler(ep_name);
        for (const auto& path : paths)
        {
            h->ws_dispatcher[path] = method;
            stats_->add_websocket_path(path);
        }
    }

    /// Replaces the HTTP dispatcher key extractor of endpoint @p ep_name.
    void set_custom_key_extractor(const string& ep_name, http_key_extractor_type&& extractor)
    {
        auto h = get_default_handler(ep_name);
        h->http_dispatcher.set_key_extractor(std::move(extractor));
    }

    /// Replaces the websocket dispatcher key extractor of endpoint @p ep_name.
    void set_custom_websocket_key_extractor(
        const string& ep_name,
        websocket_key_extractor_type&& extractor)
    {
        auto h = get_default_handler(ep_name);
        h->ws_dispatcher.set_key_extractor(std::move(extractor));
    }

    /// Sets the HTTP dispatcher's default method for endpoint @p ep_name.
    void set_default_method(const string& ep_name, http_method&& method)
    {
        auto h = get_default_handler(ep_name);
        h->http_dispatcher.set_default_method(std::move(method));
    }

    /// Sets the websocket dispatcher's default method for endpoint @p ep_name.
    void set_websocket_default_method(const string& ep_name, websocket_method&& method)
    {
        auto h = get_default_handler(ep_name);
        h->ws_dispatcher.set_default_method(std::move(method));
    }

private:
    /// Ensures every configured endpoint has an entry in endpoint_handlers_: the
    /// subscribe()d handler if present, otherwise the endpoint's default_handler.
    void setup_default_handlers()
    {
        for (auto& ep_pair : settings_->endpoints)
        {
            auto& ep = ep_pair.second;
            // Use default handler if there was no custom handlers set by user.
            bool use_default = false;
            auto it = endpoint_handlers_.find(ep.name);
            if (it == endpoint_handlers_.end())
            {
                use_default = true;
                // get_default_handler() creates the handler when no bind_* / set_* call
                // touched this endpoint. Indexing default_handlers_ directly would
                // install a null handler that every session on this endpoint would receive.
                it = endpoint_handlers_.emplace(ep.name, get_default_handler(ep.name)).first;
            }
            YLOG_LOCAL(info) << "set " << (use_default ? "default" : "custom") << " handler "
                             << static_cast<void*>(it->second.get()) << " to endpoint \"" << ep.name
                             << "\"";
        }
    }

    /// Starts every tcp_server, pairing tcp_servers_[i] with webservers_[i] in the
    /// io-major / endpoint-minor order the constructor created them in.
    void listen()
    {
        size_t i = 0;
        for (auto& io : io_data_)
        {
            for (auto& ep_pair : settings_->endpoints)
            {
                auto& ep = ep_pair.second;
                YLOG_L(info) << "listen " << ep.addr << ":" << ep.port << " secure=" << ep.secure;
                auto& tcp_server = tcp_servers_[i];
                tcp_server.listen(yplatform::weak_bind(
                    &ymod_webserver::server::impl::accept_connection,
                    shared_from_this(),
                    io.get_io(),
                    webservers_[i],
                    std::ref(ep),
                    std::placeholders::_1));
                ++i;
            }
        }
    }

    /// Accept callback: creates a session for @p socket. Plain connections are
    /// registered with @p webserver immediately. Secure ones first go through the
    /// optional congestion controller and then the TLS handshake.
    void accept_connection(
        boost::asio::io_service* io,
        boost::shared_ptr<yplatform_net_server> webserver,
        endpoint& ep,
        yplatform::net::tcp_socket&& socket)
    {
        // TODO do not copy settings
        // at() instead of operator[]: this callback runs concurrently on every io
        // thread, and std::map::operator[] is not a const (race-free) operation.
        // setup_default_handlers() inserted every endpoint, so at() cannot throw here.
        auto session = boost::make_shared<ymod_webserver::session>(
            std::move(socket), *settings_, ep, endpoint_handlers_.at(ep.name));
        YLOG_CTX_LOCAL(session->ctx(), info) << "new connection: "
                                             << "from=\"" << session->ctx()->remote_address << ":"
                                             << session->ctx()->remote_port << "\"";
        stats_->session_stats.total_connected += 1;
        if (ep.secure)
        {
            yplatform::execution_holder_ptr handshake_holder;
            if (connection_controller_)
            {
                if (!connection_controller_->try_run(io))
                {
                    YLOG_CTX_LOCAL(session->ctx(), info) << "tls handshake throttled";
                    stats_->session_stats.throttled_handshakes += 1;
                    return;
                }
                // Presumably runs connection_controller::finish(io) when the last
                // copy of the holder is released, i.e. after the handshake completes.
                handshake_holder = std::make_shared<yplatform::execution_holder>(
                    std::bind(&connection_controller::finish, connection_controller_, io));
            }
            session->ctx()->profilers.push("ssl");
            session->enable_ssl(yplatform::weak_bind(
                &ymod_webserver::server::impl::handle_ssl,
                shared_from_this(),
                handshake_holder,
                webserver,
                session,
                std::placeholders::_1));
        }
        else
        {
            webserver->register_session(session);
        }
    }

    /// TLS handshake completion: registers the session on success and logs the
    /// error otherwise. @p handshake_holder is bound only to keep it alive until now.
    void handle_ssl(
        yplatform::execution_holder_ptr /*handshake_holder*/,
        boost::shared_ptr<yplatform_net_server> webserver,
        session_ptr session,
        const boost::system::error_code& ec)
    {
        session->ctx()->profilers.pop("ssl");
        if (ec)
        {
            YLOG_CTX_LOCAL(session->ctx(), info) << "tls handshake error: " << ec.message();
        }
        else
        {
            webserver->register_session(session);
        }
    }

    std::shared_ptr<settings> settings_;
    yplatform::reactor_ptr reactor_;
    // io_service returned by reactor_->io(); rate-limited handlers are wrapped onto it.
    boost::asio::io_service* io_ = nullptr;
    std::map<string, default_handler_ptr> default_handlers_;
    std::map<string, handler_ptr> endpoint_handlers_;
    std::vector<yplatform::net::io_data> io_data_;
    // tcp_servers_[i] and webservers_[i] belong together (see constructor / listen()).
    std::vector<yplatform::net::tcp_server> tcp_servers_;
    std::vector<boost::shared_ptr<yplatform_net_server>> webservers_;
    std::shared_ptr<module_stats> stats_;
    std::vector<std::shared_ptr<rate_limiter>> rate_limiters_;
    std::shared_ptr<connection_controller> connection_controller_;
};

server::server(yplatform::reactor& reactor, const yplatform::ptree& conf)
{
    // Parse the module configuration, then build the implementation on a
    // non-owning copy of the caller's reactor.
    settings parsed_settings;
    parsed_settings.parse_ptree(conf);
    impl_ = std::make_shared<impl>(
        yplatform::reactor::make_not_owning_copy(reactor), parsed_settings);
    auto& reactor_io = *reactor.io();
    impl_->logger(reactor_io);
}

server::server(boost::asio::io_service& io, const settings& conf)
{
    // Build a reactor around the caller's io_service (io_pool size argument 1)
    // so that impl can be constructed the same way as in the reactor-based ctor.
    auto pool = std::make_shared<yplatform::io_pool>(io, 1);
    auto own_reactor = boost::make_shared<yplatform::reactor>(std::move(pool));
    impl_ = std::make_shared<impl>(own_reactor, conf);
    impl_->logger(io);
}

void server::start()
{
    // Delegates to impl: propagate loggers, wire endpoint handlers, start listening.
    auto& implementation = *impl_;
    implementation.start();
}

void server::stop()
{
    // Delegates to impl, which stops its tcp servers.
    impl_.get()->stop();
}

void server::logger(const yplatform::log::source& logger)
{
    // Keep the facade and its implementation logging through the same source.
    this->contains_logger::logger(logger);
    auto& implementation = *impl_;
    implementation.logger(logger);
}

void server::subscribe(const string& ep_name, const handler_ptr& h)
{
    // Install a user handler for the endpoint (overrides the default handler).
    auto& implementation = *impl_;
    implementation.subscribe(ep_name, h);
}

yplatform::ptree server::get_stats() const
{
    // Statistics are collected entirely by the implementation.
    const auto& implementation = *impl_;
    return implementation.get_stats();
}

void server::bind_http(
    const string& ep_name,
    const std::vector<std::string>& paths,
    const http_method& method)
{
    // Route the given HTTP paths of the endpoint to `method`.
    auto& target = *impl_;
    target.bind_http(ep_name, paths, method);
}

void server::bind_websocket(
    const string& ep_name,
    const std::vector<std::string>& paths,
    const websocket_method& method)
{
    // Route the given websocket paths of the endpoint to `method`.
    auto& target = *impl_;
    target.bind_websocket(ep_name, paths, method);
}

void server::set_custom_key_extractor(const string& ep_name, http_key_extractor_type&& extractor)
{
    // Hand the extractor over to the endpoint's HTTP dispatcher.
    auto& target = *impl_;
    target.set_custom_key_extractor(ep_name, std::move(extractor));
}

void server::set_custom_websocket_key_extractor(
    const string& ep_name,
    websocket_key_extractor_type&& extractor)
{
    // Hand the extractor over to the endpoint's websocket dispatcher.
    (*impl_).set_custom_websocket_key_extractor(ep_name, std::move(extractor));
}

void server::set_default_method(const string& ep_name, http_method&& method)
{
    // Fallback HTTP method for the endpoint's dispatcher.
    auto* target = impl_.get();
    target->set_default_method(ep_name, std::move(method));
}

void server::set_websocket_default_method(const string& ep_name, websocket_method&& method)
{
    // Fallback websocket method for the endpoint's dispatcher.
    auto* target = impl_.get();
    target->set_websocket_default_method(ep_name, std::move(method));
}

// Alias of the server class so it can also be registered as
// ymod_webserver::impl by the REGISTER_MODULE calls at the end of this file.
// NOTE(review): presumably kept for configurations that refer to the module
// by that name — confirm before removing.
using impl = server;

} // namespace ymod_webserver

#include <yplatform/module_registration.h>
// Register the server class as a yplatform module, once under its own name and
// once under the ymod_webserver::impl alias (presumably keyed by the spelled type
// name — confirm in yplatform/module_registration.h).
REGISTER_MODULE(ymod_webserver::server)
REGISTER_MODULE(ymod_webserver::impl)
