#pragma once

#include "request.h"
#include "session_pool.h"
#include "session.h"
#include "log_util.h"
#include "simple_handler.h"
#include "prepare_headers_to_log.h"
#include <ymod_httpclient/detail/eval_timeouts.h>
#include <ymod_httpclient/util/url_parser.h>
#include <ymod_httpclient/call.h>
#include <yplatform/algorithm/congestion_controller.h>
#include <yplatform/net/io_service.h>
#include <yplatform/find.h>
#include <yplatform/module.h>
#include <yplatform/log/typed.h>
#include <yplatform/reactor/tracer.h>
#include <yplatform/util/visit.h>
#include <yplatform/util/sstream.h>
#include <yplatform/util.h>
#include <ticket_provider/for_host.h>
#include <boost/format.hpp>

namespace ymod_httpclient {

namespace detail {

struct remote_point_info_less
{
    bool operator()(const remote_point_info& a, const remote_point_info& b) const
    {
        return std::tie(a.proto, a.host, a.port) < std::tie(b.proto, b.host, b.port);
    }
};

/**
 * Bridges a future-based request result to a call::callback_type.
 *
 * When `future` completes, the exception it carries (if any) is mapped to the
 * corresponding errc value and `callback` is invoked once with that error and
 * either the response collected by `handler` (on success) or an empty response
 * (on failure). Exceptions that are not classified, including ones that do not
 * derive from std::exception, are logged and reported as errc::unknown_error, so
 * the callback is still called instead of the exception escaping the future
 * continuation.
 *
 * @param future   completion of the underlying request
 * @param handler  collects the response; its result is moved into the callback
 * @param callback user callback
 * @param logger   log source for unclassified errors
 * @param ctx      task context attached to the log records
 */
inline void adapt_async(
    future_void_t future,
    boost::shared_ptr<simple_handler> handler,
    call::callback_type callback,
    yplatform::log::source logger,
    task_context_ptr ctx)
{
    future.add_callback([cb = std::move(callback), future, handler, logger, ctx]() {
        auto err = errc();
        // TODO handle all exceptions or refactor
        try
        {
            future.get();
        }
        catch (const connect_error&)
        {
            err = errc::connect_error;
        }
        catch (const connection_timeout&)
        {
            err = errc::connection_timeout;
        }
        catch (const ssl_error&)
        {
            err = errc::ssl_error;
        }
        catch (const read_error&)
        {
            err = errc::read_error;
        }
        catch (const write_error&)
        {
            err = errc::write_error;
        }
        catch (const request_timeout_error&)
        {
            err = errc::request_timeout;
        }
        catch (const server_response_error&)
        {
            err = errc::server_response_error;
        }
        catch (const server_status_error&)
        {
            err = errc::server_status_error;
        }
        catch (const server_header_error&)
        {
            err = errc::server_header_error;
        }
        catch (const task_canceled_error&)
        {
            err = errc::task_canceled;
        }
        catch (const parse_response_error&)
        {
            err = errc::parse_response_error;
        }
        catch (const eof_error&)
        {
            err = errc::eof_error;
        }
        catch (const request_uri_too_long_error&)
        {
            err = errc::request_uri_too_long;
        }
        catch (const no_service_ticket_error&)
        {
            err = errc::no_service_ticket;
        }
        catch (const task_throttled_error&)
        {
            err = errc::task_throttled;
        }
        catch (const std::exception& e)
        {
            YLOG_CTX(logger, ctx, error) << " unclassified error: " << e.what();
            err = errc::unknown_error;
        }
        catch (...)
        {
            // Without this branch, a non-std exception would escape the continuation
            // and the user callback would never be invoked.
            YLOG_CTX(logger, ctx, error) << " unclassified error: unknown";
            err = errc::unknown_error;
        }
        cb(err, err ? response() : std::move(handler->result));
    });
}

/**
 * Stores in `prom` the exception type that corresponds to an http_error code.
 *
 * Each known code maps to its dedicated exception class, constructed with
 * `reason`. Any other code (the `default` branch) is stored as unknown_error.
 */
inline void set_exception(
    yplatform::future::promise<void>& prom,
    http_error::code errc,
    const string& reason)
{
    using code = http_error::code;

    switch (errc)
    {
    case code::connect_error:
        prom.set_exception(connect_error(reason));
        return;
    case code::ssl_error:
        prom.set_exception(ssl_error(reason));
        return;
    case code::read_error:
        prom.set_exception(read_error(reason));
        return;
    case code::write_error:
        prom.set_exception(write_error(reason));
        return;
    case code::server_response_error:
        prom.set_exception(server_response_error(reason));
        return;
    case code::server_status_error:
        prom.set_exception(server_status_error(reason));
        return;
    case code::server_header_error:
        prom.set_exception(server_header_error(reason));
        return;
    case code::request_timeout:
        prom.set_exception(request_timeout_error(reason));
        return;
    case code::connection_timeout:
        prom.set_exception(connection_timeout(reason));
        return;
    case code::parse_response_error:
        prom.set_exception(parse_response_error(reason));
        return;
    case code::task_canceled:
        prom.set_exception(task_canceled_error(reason));
        return;
    case code::eof_error:
        prom.set_exception(eof_error(reason));
        return;
    case code::request_uri_too_long:
        prom.set_exception(request_uri_too_long_error(reason));
        return;
    case code::task_throttled:
        prom.set_exception(task_throttled_error(reason));
        return;
    default:
        // Codes without a dedicated exception class.
        prom.set_exception(unknown_error(reason));
        return;
    }
}

/**
 * Installs a completion handler on `req` that settles `prom`.
 *
 * A falsy error code fulfils the promise. Any other code is converted into the
 * matching exception via set_exception(). The handler keeps `prom` alive by
 * holding its own copy of the shared pointer.
 */
inline void adapt_async(shared_ptr<promise_void_t> prom, request_data& req)
{
    req.handler = [prom](auto&& err, auto&& reason) {
        if (err)
        {
            set_exception(*prom, err, reason);
            return;
        }
        prom->set();
    };
}

/// Formats an endpoint as "proto://host[:port]"; the port is omitted when it is 0.
inline string to_string(const remote_point_info& remote)
{
    string formatted = remote.proto + "://" + remote.host;
    if (remote.port != 0)
    {
        formatted += ":" + std::to_string(remote.port);
    }
    return formatted;
}

/// Replaces `container` with headers taken from either alternative of
/// request::headers_type: a raw header string or a header_dict.
inline void assign(header_container& container, request::headers_type headers)
{
    auto from_raw = [&container](string raw) { container = header_container(std::move(raw)); };
    auto from_dict = [&container](header_dict dict) {
        container = header_container(std::move(dict));
    };
    yplatform::visit(std::move(headers), std::move(from_raw), std::move(from_dict));
}

/// Returns the context's request id, falling back to its unique id when the
/// request id is empty. The returned reference points into `context`.
inline const std::string& request_id_or_unique_id(task_context_ptr context)
{
    const auto& request_id = context->request_id();
    if (request_id.empty())
    {
        return context->uniq_id();
    }
    return request_id;
}

// Short alias for the std::bind placeholders (ph::_1, ph::_2) used by call_impl below.
namespace ph = std::placeholders;

/**
 * HTTP client module implementation, parameterized by the session type.
 *
 * Requests are dispatched through per-endpoint session pools stored in
 * active_conns_ (keyed by proto/host/port via remote_point_info_less). The map
 * is guarded by mux_. A request is served either by a pooled session (when
 * connection reuse is enabled and one is available) or by a freshly created one.
 */
template <typename Session>
class call_impl
    : public call
    , public yplatform::module
{
    typedef Session session;
    typedef std::mutex mutex_t;
    typedef std::unique_lock<mutex_t> lock_t;
    typedef shared_ptr<session> session_ptr;
    typedef session_pool<session, yplatform::congestion_controller> connections;
    typedef shared_ptr<connections> connections_ptr;
    typedef task_context_ptr context_ptr;
    typedef std::map<remote_point_info, connections_ptr, remote_point_info_less> active_connections;

public:
    /// Initializes from config, using the reactor named by the "reactor" key
    /// (default "global").
    void init(const yplatform::ptree& conf)
    {
        auto reactor = yplatform::global_reactor_set->get(conf.get("reactor", "global"));
        init(reactor, conf);
    }

    /// Parses `conf` into settings_ and initializes on `reactor`.
    void init(yplatform::reactor_ptr reactor, const yplatform::ptree& conf)
    {
        settings_.parse_ptree(conf);
        init(reactor, settings_);
    }

    /**
     * Initializes per-io_service state (DNS/SSL io_data), the optional reactor
     * tracers, and the optional service-ticket provider.
     *
     * @throws std::runtime_error if reactor->plain() is false.
     */
    void init(yplatform::reactor_ptr reactor, const settings& settings)
    {
        if (!reactor->plain())
        {
            throw std::runtime_error("httpclient is optimized for single-thread reactors - set "
                                     "pool_count=N and io_threads=1");
        }

        reactor_ = reactor;
        settings_ = settings;

        // One io_data (DNS + SSL setup) per io_service; create_session() looks it up by io.
        io_data_.reserve(reactor->size());
        for (size_t i = 0; i < reactor->size(); ++i)
        {
            auto io = (*reactor_)[i]->io();
            yplatform::net::io_data io_data(*io);
            io_data.setup_dns(settings_.dns);
            io_data.setup_ssl(settings_.ssl);
            io_data_[io] = std::move(io_data);
        }

        // Tracers are started only when reactor_overload_delay is configured.
        if (settings_.reactor_overload_delay != time_traits::duration::max())
        {
            for (size_t i = 0; i < reactor->size(); ++i)
            {
                auto& io = *(*reactor_)[i]->io();
                auto tracer =
                    make_shared<yplatform::reactor_tracer>(io);
                store_tracer(io, tracer);
                tracer->start();
            }
        }

        if (settings_.service_ticket_provider.module.size())
        {
            service_ticket_provider_ = yplatform::find<ticket_provider::for_host, shared_ptr>(
                *reactor->io(), settings_.service_ticket_provider.module);
        }
    }

    /// Returns the running-request count and per-endpoint pool stats. The pool map
    /// is copied under mux_ so the per-pool stats() calls run without the lock.
    yplatform::ptree get_stats() const
    {
        active_connections connections_copy;
        {
            lock_t guard(mux_);
            connections_copy = active_conns_;
        }
        yplatform::ptree stats;
        stats.put("running-requests", std::to_string(running_requests_.load()));
        for (auto& pool : connections_copy)
        {
            auto pool_stats = pool.second->stats();
            const auto address = to_string(pool.first);
            pool_stats.put("address", address);
            stats.push_back({ address, std::move(pool_stats) });
        }
        return stats;
    }

    /**
     * Completion of a single request attempt.
     *
     * Logs the outcome, updates pool stats, returns `session` (if any) to its pool,
     * and then does one of the following:
     * - invokes req->handler with success;
     * - reports task_canceled if the context was cancelled;
     * - re-dispatches the request on session_closed_error;
     * - forwards `errc` / `reason`.
     * Exceptions thrown by the handler are logged and swallowed.
     */
    void request_i_cb(
        request_data_ptr req,
        session_ptr session,
        http_error::code errc,
        const string& reason)
    {
        const auto session_id = session ? session->id() : nullptr;
        {
            auto total_time = time_traits::clock::now() - req->started_at;
            report_request_finished(*req, session, total_time, errc);
            update_stats(*req, errc);
        }

        if (session) return_session(*req, session);

        // Each call of this function matches exactly one increment in request_impl(),
        // so the counter is released here, once, before dispatching the outcome.
        // A retry below goes through request_impl(), which increments it again if
        // the attempt is actually started.
        running_requests_--;

        try
        {
            if (!errc)
            {
                req->handler({}, "");
            }
            else if (req->context->is_cancelled())
            {
                static const string canceled_txt = "context canceled";
                req->handler(http_error::task_canceled, canceled_txt);
            }
            else if (errc == http_error::code::session_closed_error)
            {
                YLOG_CTX_LOCAL(req->context, info)
                    << "request wll be retried: conn=" << session_id << " req=" << req->number
                    << " reason=\"" << reason << "\"";
                request_impl(req);
            }
            else
            {
                req->handler(errc, reason);
            }
        }
        catch (const std::exception& ex)
        {
            YLOG_CTX_LOCAL(req->context, error)
                << "uncaugth exception in request handler: conn=" << session_id
                << " req=" << req->number << " what=\"" << ex.what() << "\"";
        }
        catch (...)
        {
            YLOG_CTX_LOCAL(req->context, error)
                << "uncaugth exception in request handler: conn=" << session_id
                << " req=" << req->number << " what=\"unknown\"";
        }
    }

    /// Synchronous request with default options.
    response run(task_context_ptr ctx, request req)
    {
        return run(ctx, std::move(req), options());
    }

    /// Synchronous request: blocks on the request future and returns the collected
    /// response. URL parse failures surface as bad_url_error from make_rm_info().
    response run(task_context_ptr ctx, request req, const options& options)
    {
        auto remote_point = make_rm_info(req.url);
        auto handler = boost::make_shared<simple_handler>();
        auto res = make_shared<request_data>();
        res->context = ctx;
        res->response_handler = handler;
        res->remote_point = remote_point;
        assign(res->headers, std::move(req.headers));
        res->post = std::move(req.body);
        res->mpost = std::move(req.multipart_body);
        res->method = req.method;
        res->attempt = req.attempt;
        do_request(res, options).get();
        return handler->result;
    }

    /// Asynchronous request with default options.
    void async_run(task_context_ptr ctx, request req, callback_type cb)
    {
        async_run(ctx, std::move(req), options(), std::move(cb));
    }

    /// Asynchronous request. URL parse errors are reported through `cb` immediately.
    /// Otherwise `cb` is invoked when the request future completes.
    void async_run(task_context_ptr ctx, request req, const options& options, callback_type cb)
    {
        auto err = errc();
        auto remote_point = make_rm_info(req.url, timeouts(), false, err);
        if (err)
        {
            cb(err, response());
            return;
        }
        auto handler = boost::make_shared<simple_handler>();
        auto res = make_shared<request_data>();
        res->context = ctx;
        res->response_handler = handler;
        res->remote_point = remote_point;
        assign(res->headers, std::move(req.headers));
        res->post = std::move(req.body);
        res->mpost = std::move(req.multipart_body);
        res->method = req.method;
        res->attempt = req.attempt;

        auto future = do_request(res, options);
        adapt_async(future, handler, std::move(cb), logger(), ctx);
    }

    /// GET `req` on `host`, using the host's operation timeouts and reuse flag.
    future_void_t get_url(
        context_ptr context,
        response_handler_ptr handler,
        const remote_point_info_ptr host,
        const string& req,
        const string& headers)
    {
        auto res = make_shared<request_data>();
        res->context = context;
        res->response_handler = handler;
        res->remote_point = host;
        res->uri = req;
        assign(res->headers, headers);
        res->method = request::method_t::GET;

        options opt;
        opt.timeouts = host->operation_timeouts;
        opt.reuse_connection = host->reuse_connection;
        return do_request(res, opt);
    }

    /// HEAD `req` on `host`, using the host's operation timeouts and reuse flag.
    future_void_t head_url(
        context_ptr context,
        response_handler_ptr handler,
        const remote_point_info_ptr host,
        const string& req,
        const string& headers)
    {
        auto res = make_shared<request_data>();
        res->context = context;
        res->response_handler = handler;
        res->remote_point = host;
        res->uri = req;
        assign(res->headers, headers);
        res->method = request::method_t::HEAD;

        options opt;
        opt.timeouts = host->operation_timeouts;
        opt.reuse_connection = host->reuse_connection;
        return do_request(res, opt);
    }

    /// POST a single-buffer body; `log_post_args` overrides body logging for this request.
    future_void_t post_url(
        context_ptr context,
        response_handler_ptr handler,
        const remote_point_info_ptr host,
        const string& req,
        const string_ptr& post,
        const string& headers,
        bool log_post_args)
    {
        auto res = make_shared<request_data>();
        res->context = context;
        res->response_handler = handler;
        res->remote_point = host;
        res->uri = req;
        res->post = post;
        assign(res->headers, headers);
        res->method = request::method_t::POST;

        options opt;
        opt.timeouts = host->operation_timeouts;
        opt.log_post_body = log_post_args;
        opt.reuse_connection = host->reuse_connection;
        return do_request(res, opt);
    }

    /// POST a multipart body built from `post` via body_form_data().
    future_void_t mpost_url(
        context_ptr context,
        response_handler_ptr handler,
        const remote_point_info_ptr host,
        const string& req,
        post_chunks&& post,
        const string& headers,
        bool log_post_args)
    {
        auto res = make_shared<request_data>();
        res->context = context;
        res->response_handler = handler;
        res->remote_point = host;
        res->uri = req;
        res->mpost = boost::make_shared<multipart>(body_form_data(std::move(post)));
        assign(res->headers, headers);
        res->method = request::method_t::POST;

        options opt;
        opt.timeouts = host->operation_timeouts;
        opt.log_post_body = log_post_args;
        opt.reuse_connection = host->reuse_connection;
        return do_request(res, opt);
    }

    /// Parses `host` into a remote_point_info. Parse failures are reported through
    /// `err` and are not thrown.
    remote_point_info_ptr make_rm_info(
        const string& host,
        const timeouts& timeouts,
        bool reuse_connection,
        errc& err)
    {
        remote_point_info_ptr res(new remote_point_info);
        err = parse_url(host, res->proto, res->host, res->port, res->uri_prefix);
        res->reuse_connection = reuse_connection;
        res->operation_timeouts = timeouts;
        return res;
    }

    /// Like the errc overload, with reuse_connection from settings; throws bad_url_error.
    remote_point_info_ptr make_rm_info(const string& host, const timeouts& timeouts)
    {
        return make_rm_info(host, timeouts, settings_.reuse_connection);
    }

    /// Like the errc overload, but throws bad_url_error on parse failure.
    remote_point_info_ptr make_rm_info(
        const string& host,
        const timeouts& timeouts,
        bool reuse_connection)
    {
        auto err = errc();
        remote_point_info_ptr res = make_rm_info(host, timeouts, reuse_connection, err);
        if (err)
        {
            throw bad_url_error();
        }
        return res;
    }

    /// Default timeouts and settings' reuse_connection; throws bad_url_error.
    remote_point_info_ptr make_rm_info(const string& host)
    {
        return make_rm_info(host, settings_.reuse_connection);
    }

    /// Default timeouts; throws bad_url_error.
    remote_point_info_ptr make_rm_info(const string& host, bool reuse_connection)
    {
        return make_rm_info(host, timeouts(), reuse_connection);
    }

private:
    /**
     * Validates and prepares `req`, then starts it. The returned future is settled
     * by the request handler installed via adapt_async(). Errors thrown while
     * preparing the request are stored in the future.
     */
    future_void_t do_request(request_data_ptr req, const options& options)
    {
        auto promise_ptr = make_shared<promise_void_t>();
        try
        {
            adapt_async(promise_ptr, *req);
            check_request(*req);
            set_options(*req, options);
            append_headers(*req);
            add_service_ticket(*req);
            request_impl(req);
        }
        // The concrete types thrown by check_request() / add_service_ticket() are
        // caught explicitly so set_exception() receives them with their own static
        // type and adapt_async() can still tell them apart.
        // NOTE(review): assumes promise::set_exception copies the exception by its
        // static type (as boost::copy_exception does) — confirm in yplatform.
        catch (const request_uri_too_long_error& ex)
        {
            promise_ptr->set_exception(ex);
        }
        catch (const no_service_ticket_error& ex)
        {
            promise_ptr->set_exception(ex);
        }
        catch (const error& ex)
        {
            promise_ptr->set_exception(ex);
        }
        return *promise_ptr;
    }

    /// @throws request_uri_too_long_error if uri_prefix + uri exceeds max_request_line_size.
    void check_request(request_data& req)
    {
        if (req.remote_point->uri_prefix.length() + req.uri.length() >
            settings_.max_request_line_size)
        {
            throw request_uri_too_long_error();
        }
    }

    /// Adds an X-Ya-Service-Ticket header unless the caller already supplied one or
    /// no provider is configured. Provider errors throw no_service_ticket_error
    /// unless ignore_errors is set.
    void add_service_ticket(request_data& req)
    {
        static const string TICKET_HEADER = "X-Ya-Service-Ticket";
        if (req.headers.contained_headers().x_ya_service_ticket) return;
        if (!service_ticket_provider_) return;

        if (auto [err, ticket] =
                service_ticket_provider_->get_service_ticket_for_host(req.remote_point->host);
            !err && ticket.size())
        {
            req.headers.add(TICKET_HEADER, ticket);
        }
        else if (err && !settings_.service_ticket_provider.ignore_errors)
        {
            throw no_service_ticket_error(
                "failed to get ticket for host " + req.remote_point->host);
        }
    }

    /// Session factory used by the pool; needs the io_data registered in init().
    session_ptr create_session(boost::asio::io_service& io, request_data_ptr req)
    {
        auto it = io_data_.find(&io);
        if (it == io_data_.end())
        {
            throw std::runtime_error("create session error: io_data not found");
        }
        session_ptr session(
            make_shared<session>(it->second, *req->remote_point, logger(), settings_));
        return session;
    }

    void return_session(request_data& req, session_ptr session)
    {
        auto pool = find_pool(req);
        pool->put(session);
    }

    /// Returns the pool for the request's endpoint, creating it on first use (under mux_).
    connections_ptr find_pool(request_data& req)
    {
        lock_t guard(mux_);
        auto ipool = active_conns_.find(*req.remote_point);
        if (ipool == active_conns_.end())
        {
            ipool = active_conns_.insert(
                ipool,
                std::make_pair(
                    *req.remote_point,
                    std::make_shared<connections>(
                        reactor_,
                        settings_.preffered_pool_size,
                        settings_.reactor_overload_delay)));
        }
        return ipool->second;
    }

    /// Resolves per-request options, falling back to settings_ where unset.
    void set_options(request_data& req, const options& options)
    {
        req.log_post_body =
            options.log_post_body ? *options.log_post_body : settings_.enable_post_args_logging;
        req.log_headers =
            options.log_headers ? *options.log_headers : settings_.enable_headers_logging;
        req.reuse_connection =
            options.reuse_connection ? *options.reuse_connection : settings_.reuse_connection;

        auto timeouts = eval_timeouts(req.context, options, settings_);
        req.connect_timeout = timeouts.connect;
        req.total_timeout = timeouts.total;
    }

    /// Adds X-Request-Id / X-Request-Attempt / X-Request-Timeout unless already
    /// present, plus User-Agent when configured.
    void append_headers(request_data& req)
    {
        auto timeout =
            time_traits::duration_cast<time_traits::milliseconds>(req.total_timeout).count();
        if (!req.headers.contained_headers().x_request_id)
            req.headers.add("X-Request-Id", request_id_or_unique_id(req.context));
        if (!req.headers.contained_headers().x_request_attempt)
            req.headers.add("X-Request-Attempt", std::to_string(req.attempt));
        if (!req.headers.contained_headers().x_request_timeout)
            req.headers.add("X-Request-Timeout", std::to_string(timeout));
        if (!settings_.user_agent.empty()) req.headers.add("User-Agent", settings_.user_agent);
    }

    /**
     * Starts one attempt of `req`. If the context is cancelled or its deadline is
     * closer than min_request_duration, the handler gets task_canceled and no
     * attempt is counted. Otherwise running_requests_ is incremented; it is
     * released by the matching request_i_cb() call.
     */
    void request_impl(request_data_ptr req)
    {
        if (req->context->is_cancelled() ||
            time_traits::clock::now() + settings_.min_request_duration >= req->context->deadline())
        {
            static const string canceled_txt = "context canceled";
            req->handler(http_error::task_canceled, canceled_txt);
            return;
        }

        running_requests_++;
        const string host_suffix =
            0 == req->remote_point->port ? string() : ":" + std::to_string(req->remote_point->port);
        req->full_uri =
            req->remote_point->host + host_suffix + req->remote_point->uri_prefix + req->uri;

        try
        {
            auto pool = find_pool(*req);
            if (req->reuse_connection)
            {
                if (auto session = pool->get())
                {
                    run_request_in_existing_session(req, session);
                }
                else
                {
                    run_request_in_new_session(req, pool);
                }
            }
            else
            {
                run_request_in_new_session(req, pool);
            }
        }
        catch (const std::exception& e)
        {
            YLOG_CTX_LOCAL(req->context, error)
                << "unexpected error in request_impl: exception=\"" << e.what() << "\"";
            request_i_cb(req, nullptr, http_error::unknown_error, e.what());
        }
        catch (...)
        {
            YLOG_CTX_LOCAL(req->context, error)
                << "unexpected error in request_impl: exception=\"unknown\"";
            request_i_cb(req, nullptr, http_error::unknown_error, "");
        }
    }

    /// Asks the pool for a new session and runs `req` on it. `self` keeps the module
    /// alive until the asynchronous callbacks fire.
    void run_request_in_new_session(request_data_ptr req, connections_ptr pool)
    {
        report_request_started(*req, nullptr);
        pool->create(
            req->context,
            std::bind(&call_impl::create_session, yplatform::shared_from(this), ph::_1, req),
            req->connect_timeout,
            [this, self = yplatform::shared_from(this), req](
                auto&& err, auto&& reason, auto&& session) {
                if (err)
                {
                    return request_i_cb(req, session, err, reason);
                }
                req->number = session->get_next_request_number();
                session->run(
                    req, std::bind(&call_impl::request_i_cb, self, req, session, ph::_1, ph::_2));
            });
    }

    /// Runs `req` on a pooled session. Like the new-session path, the completion
    /// holds a shared reference to the module rather than a raw `this`, so the
    /// module cannot be destroyed while the session is still running the request.
    void run_request_in_existing_session(request_data_ptr req, session_ptr session)
    {
        req->number = session->get_next_request_number();
        report_request_started(*req, session->id());
        session->run(
            req,
            std::bind(
                &call_impl::request_i_cb,
                yplatform::shared_from(this),
                req,
                session,
                ph::_1,
                ph::_2));
    }

    /// Logs the start event (plain logger only; skipped when tskv logging is used).
    void report_request_started(request_data& req, void const* session_id)
    {
        if (!settings_.enable_logging || settings_.tskv_logger) return;
        std::stringstream post_args_stream;
        if (settings_.enable_post_args_logging && req.post && req.log_post_body)
        {
            yplatform::util::log_cmd(
                post_args_stream << " post_args=\"",
                boost::make_iterator_range(*req.post),
                settings_.post_args_log_entry_max_size);
            post_args_stream << "\"";
        }
        YLOG_CTX_LOCAL(req.context, info)
            << "conn=" << session_id << " req=" << req.number << " event=start uri=\""
            << req.full_uri << "\"" << post_args_stream.str();
    }

    /// Logs the finish event, either as a tskv record or as a plain log line.
    /// Resolve/connect/TLS timings are reported only for a session's first request.
    void report_request_finished(
        request_data& req,
        session_ptr session,
        time_traits::duration total_time,
        http_error::code errc)
    {
        if (!settings_.enable_logging) return;
        using time_traits::to_string;
        time_traits::duration resolve_time = {}, connect_time = {}, tls_time = {};
        if (session && session->stats().requests_processed <= 1)
        {
            auto& stats = session->stats();
            resolve_time = stats.resolve_time;
            connect_time = stats.connect_time;
            tls_time = stats.tls_time;
        }
        int port = req.remote_point->port;
        if (!port)
        {
            port = req.remote_point->proto == "https" ? 443 : 80;
        }
        const auto session_id = session ? session->id() : nullptr;
        const auto status = req.response_handler->code();
        const auto bytes_in = req.bytes_in;
        const auto bytes_out = req.bytes_out;
        if (settings_.tskv_logger)
        {
            using namespace yplatform::log::typed;

            attributes_map extra_attrs;
            if (req.log_post_body && req.post)
            {
                extra_attrs << make_attr(
                    "post_args",
                    replace_tskv_specials(
                        req.post->substr(0, settings_.post_args_log_entry_max_size)));
            }

            if (req.log_headers && req.headers.data().size())
            {
                extra_attrs << make_attr(
                    "headers",
                    prepare_headers_to_log(req.headers.data(), settings_.protected_log_headers));
            }
            if (!settings_.name_of_request_id_field.empty())
            {
                extra_attrs << make_attr(
                    settings_.name_of_request_id_field, req.context->request_id());
            }

            YLOG((*settings_.tskv_logger), info)
                << make_attr("session_id", session_id)
                << make_attr(settings_.name_of_uniq_id_field, req.context->uniq_id())
                << make_attr("host", req.remote_point->host) << make_attr("port", port)
                << make_attr("uri", req.full_uri) << make_attr("status", status)
                << make_attr("resolve_time", to_string(resolve_time))
                << make_attr("connect_time", to_string(connect_time))
                << make_attr("tls_time", to_string(tls_time))
                << make_attr("total_time", to_string(total_time))
                << make_attr("error", http_error::message(errc))
                << make_attr("bytes_out", bytes_out) << make_attr("bytes_in", bytes_in)
                << make_attr("attempt", req.attempt) << extra_attrs;
        }
        else
        {
            YLOG_CTX_LOCAL(req.context, info)
                << "conn=" << session_id << " req=" << req.number << " event=fin tm={"
                << to_string(resolve_time) << ", " << to_string(connect_time) << ", "
                << to_string(tls_time) << ", " << to_string(total_time) << "}"
                << " status=" << status << " attempt=" << req.attempt << " bytes_out=" << bytes_out
                << " bytes_in=" << bytes_in;
        }
    }

    /// Records the request outcome in its endpoint pool's stats, if the pool exists.
    void update_stats(request_data& req, http_error::code errc)
    {
        assert(req.remote_point);
        connections_ptr pool;
        {
            // active_conns_ is modified by find_pool() under mux_. The lookup here
            // must hold the same lock to avoid a data race on the std::map.
            lock_t guard(mux_);
            auto it = active_conns_.find(*req.remote_point);
            if (it == active_conns_.cend()) return;
            pool = it->second;
        }
        assert(pool);
        assert(req.response_handler);
        const int status_code = req.response_handler->code();
        pool->increment_request_processing_stats(status_code, req.attempt, errc);
    }

    active_connections active_conns_; // guarded by mux_
    mutable mutex_t mux_;
    std::atomic_int running_requests_ = { 0 };

    yplatform::reactor_ptr reactor_;
    std::unordered_map<boost::asio::io_service*, yplatform::net::io_data> io_data_;
    settings settings_;
    yplatform::net::ssl_settings ssl_settings_;
    shared_ptr<ticket_provider::for_host> service_ticket_provider_;
};

}

// Public alias: the call_impl module instantiated with the `session` type.
using call_impl = detail::call_impl<session>;

}
