#include "parser/uri.h"
#include <ymod_webserver/settings.h>
#include <yplatform/find.h>

namespace ymod_webserver {

// Fills this endpoint from a single "listen" entry of the configuration.
//
// "addr" and "port" are first looked up as XML attributes ("<xmlattr>.addr",
// "<xmlattr>.port") and otherwise as child keys. A missing "addr" resets the address
// to "", whereas port, ssl and endpoint_name keep their current values when absent.
// A non-empty "default_host" is parsed into default_host. Finally the same subtree is
// passed to socket_settings.parse_ptree (settings::parse_ptree applies the shared
// "socket" section before calling this function).
void endpoint::parse_ptree(const yplatform::ptree& conf)
{
    if (auto attr_addr = conf.get_optional<std::string>("<xmlattr>.addr"))
        addr = *attr_addr;
    else
        addr = conf.get<std::string>("addr", "");

    if (auto attr_port = conf.get_optional<unsigned short>("<xmlattr>.port"))
        port = *attr_port;
    else
        port = conf.get<unsigned short>("port", port);

    secure = conf.get<bool>("ssl", secure);
    name = conf.get("endpoint_name", name);

    string host_value = conf.get("default_host", "");
    if (host_value.size() != 0)
        parser::parse_host(host_value.begin(), host_value.end(), default_host);

    socket_settings.parse_ptree(conf);
}

// Selects which request attribute a rate limit is keyed on.
//
// "param" keys the limit by a URL parameter and "header" by a request header. At most
// one of them may be present: giving both throws std::runtime_error. When neither is
// present, limiting_attr is left unchanged.
void rate_limit_settings::parse_limiting_attr(const yplatform::ptree& conf)
{
    const auto param = conf.get_optional<std::string>("param");
    const auto header = conf.get_optional<std::string>("header");

    if (param && header)
        throw std::runtime_error("rate limit can not be set by param and header together");

    // The branches below are mutually exclusive after the check above.
    if (header)
        limiting_attr = { request_attr::type::header, *header };
    else if (param)
        limiting_attr = { request_attr::type::url_param, *param };
}

// Appends one filter to `filters` for every "if" child of a rate limit rule.
//
// Each filter names exactly one request attribute: a URL parameter ("param") or a
// header ("header"). That attribute is paired with the regular expression in "value".
// Having neither or both attribute keys throws std::runtime_error. A missing "value"
// makes ptree::get throw.
void rate_limit_settings::parse_filters(const yplatform::ptree& conf)
{
    const auto range = conf.equal_range("if");
    for (auto it = range.first; it != range.second; ++it)
    {
        const auto& filter_conf = it->second;
        const auto param = filter_conf.get_optional<std::string>("param");
        const auto header = filter_conf.get_optional<std::string>("header");

        // Both absent or both present: the filter is ambiguous.
        if (!param == !header)
            throw std::runtime_error("rate limit filter must be set either by param or header");

        const auto attr_type = param ? request_attr::type::url_param : request_attr::type::header;
        const std::string& attr_name = param ? *param : *header;
        filters.push_back(
            { { attr_type, attr_name }, boost::regex(filter_conf.get<std::string>("value")) });
    }
}

// Reads one rate limit rule from `conf`.
//
// "recovery_interval" is required: ptree::get without a default throws when it is
// missing. It must be at least 1ms, otherwise std::runtime_error is thrown. The other
// scalar keys keep their current values when absent. Each repeated "path" child is
// appended to `path`. The limiting attribute and the "if" filters are parsed by
// parse_limiting_attr / parse_filters, which may also throw.
void rate_limit_settings::parse_ptree(const yplatform::ptree& conf)
{
    name = conf.get("name", name);
    limit = conf.get("limit", limit);
    recovery_rate = conf.get("recovery_rate", recovery_rate);
    const auto recovery_interval = conf.get<time_traits::duration>("recovery_interval");
    // The interval is stored with millisecond resolution (duration_cast below), so
    // anything shorter than 1ms is rejected. Exactly 1ms is accepted, hence "at least".
    if (recovery_interval < time_traits::milliseconds(1))
        throw std::runtime_error("recovery interval must be at least 1ms");
    recovery_interval_ms =
        time_traits::duration_cast<time_traits::milliseconds>(recovery_interval).count();
    // "status" is the legacy key; "response_status" is read second and wins if both exist.
    response_status =
        static_cast<codes::code>(conf.get<int>("status", response_status)); // For compatibility.
    response_status = static_cast<codes::code>(conf.get<int>("response_status", response_status));
    response_body = conf.get("response_body", response_body);
    for (auto& pair : boost::make_iterator_range(conf.equal_range("path")))
    {
        path.emplace_back(pair.second.data());
    }
    parse_limiting_attr(conf);
    parse_filters(conf);
}

// Reads the congestion control section.
//
// "reactor_overload_delay" is mandatory: ptree::get is called without a default, so a
// missing or unparsable value throws.
void congestion_control_settings::parse_ptree(const yplatform::ptree& conf)
{
    const auto delay = conf.get<time_traits::duration>("reactor_overload_delay");
    reactor_overload_delay = delay;
}

std::set<header_t> default_log_headers()
{
    static auto headers = { "X-Request-Id"s, "X-Request-Attempt"s, "X-Request-Timeout"s };
    std::set<header_t> res;
    for (const auto& header : headers)
    {
        res.insert({ header, boost::to_lower_copy(header) });
    }
    return res;
}

// Built-in defaults for the web server settings.
//
// settings::parse_ptree reads most options with the current member value as its
// fallback, so these are the values in effect for keys a config omits. Some keys are
// handled differently there (e.g. log_cookies, access_control_*, max_request_line_size).
// NOTE(review): initializers should stay in member declaration order (-Wreorder); that
// order is defined in the header, which is not visible here.
settings::settings()
    : access_log() // assigned only when access logging is configured
    , read_chunk_size(1024)
    , max_request_line_size(10 * 1024) // parse_ptree falls back to max_headers_size instead
    , max_headers_size(10 * 1024)
    , max_post_size(40 * 1024 * 1024)
    , max_websocket_size(10 * 1024 * 1024)
    , enable_gzip(false)
    , enable_deflate(false)
    , enable_compress(false)
    , connect_per_request(false)
    , enable_websocket_masking(false)
    , log_cookies(true)
    // The three websocket limits below are effectively unlimited unless configured.
    , ws_max_length(std::numeric_limits<size_t>::max())
    , ws_max_fragmentation(std::numeric_limits<size_t>::max())
    , ws_max_streams(std::numeric_limits<size_t>::max())
    , http_default_version{ 1, 0 } // HTTP/1.0 unless http_default_version.{major,minor} set
    , log_headers(default_log_headers()) // configured log_headers are added to this set
    , access_log_tskv_socket_stats(false)
    , access_log_tskv_timings(false)
    , keep_alive_requests(100)
    , min_acceptable_timeout(yplatform::time_traits::duration::zero())
    , augment_unique_id_with_request_id(false)
    , enable_timing_stats(false)
{
}

// Builds default settings that listen on exactly one endpoint, keyed by its name.
//
// The parameter is named `ep` rather than `endpoint` so that it does not shadow the
// endpoint type inside the body.
settings::settings(endpoint ep) : settings()
{
    // Replace (rather than merge with) whatever the delegated constructor left behind.
    endpoints.clear();
    // The pair copies the key from ep.name before ep is moved into the mapped value, so
    // the by-value argument is moved, never copied again.
    endpoints.emplace(ep.name, std::move(ep));
}

// Reads the whole policy file into `policy_file`.
//
// An empty file name is a no-op. Throws std::runtime_error when the file cannot be
// opened, sought, or its size determined. A short read throws std::ios_base::failure,
// because the badbit/eofbit exceptions are armed before reading. `policy_file` is only
// replaced after the read completes.
void settings::load_policy_data(const string& policy_file_name)
{
    if (policy_file_name.empty()) return;
    // Binary mode: the byte count reported by tellg() must equal what read() delivers,
    // which text-mode newline translation does not guarantee on every platform.
    std::ifstream is(policy_file_name.c_str(), std::ios::in | std::ios::binary);
    if (!is)
    {
        throw std::runtime_error("can not open policy file: " + policy_file_name);
    }
    is.exceptions(std::ios::badbit | std::ios::eofbit);
    if (!is.seekg(0, std::ios::end))
    {
        throw std::runtime_error("can not seek to end of policy file: " + policy_file_name);
    }
    const std::streamsize size = is.tellg();
    // tellg() signals failure with -1; passing that to resize() would request a huge buffer.
    if (size < 0)
    {
        throw std::runtime_error("can not determine size of policy file: " + policy_file_name);
    }
    if (!is.seekg(0, std::ios::beg))
    {
        throw std::runtime_error("can not seek to begin of policy file: " + policy_file_name);
    }
    std::string data;
    data.resize(static_cast<std::size_t>(size));
    is.read(&data[0], size);
    policy_file.swap(data);
}

// Populates the web server settings from a configuration subtree.
//
// Most options fall back to their current (constructor) values when absent. The
// exceptions are visible below: log_cookies defaults to "on", the access_control_*
// options reset to empty/false, and max_request_line_size falls back to
// max_headers_size. Throws:
//   - std::domain_error for deprecated option names;
//   - std::runtime_error for a typed access log whose format is not "tskv", and
//     whatever the rate limit / policy file / ssl / endpoint parsers throw;
//   - boost::property_tree errors for a missing "endpoints" section (get_child without
//     a default) or other mandatory keys read by the nested parsers.
void settings::parse_ptree(const yplatform::ptree& data)
{
    // Deny deprecated options: fail loudly instead of silently ignoring an old config.
    if (data.count("default"))
        throw std::domain_error("config section 'default' is deprecated, "
                                "please rename it to 'socket'");
    if (data.count("net_settings"))
        throw std::domain_error("config section 'net_settings' is deprecated, "
                                "please rename it to 'socket'");
    if (data.count("ssl_context"))
        throw std::domain_error("config section 'ssl_context' is deprecated, "
                                "please rename it to 'ssl'");
    if (data.count("ssl_verify_hostname"))
        throw std::domain_error("ssl_verify_hostname option is deprecated, "
                                "please use ssl.verify_hostname");

    ssl.parse_ptree(data.get_child("ssl", yplatform::ptree()));

    // Every "listen" entry starts from the shared "socket" section and then applies its
    // own keys on top.
    auto base_socket_conf = data.get_child("socket", yplatform::ptree());
    auto range = data.get_child("endpoints").equal_range("listen");
    for (auto it = range.first; it != range.second; ++it)
    {
        endpoint ep;
        ep.socket_settings.parse_ptree(base_socket_conf);
        ep.parse_ptree(it->second);
        // NOTE(review): if `endpoints` is a map, insert() keeps the first endpoint when
        // two entries share a name.
        endpoints.insert(std::make_pair(ep.name, ep));
    }

    for (auto& pair : boost::make_iterator_range(data.equal_range("rate_limits")))
    {
        rate_limit_settings st;
        st.parse_ptree(pair.second);
        rate_limits.push_back(std::move(st));
    }

    auto access_log_cfg = data.get_child_optional("access_log");
    if (access_log_cfg && access_log_cfg->size() == 0)
    {
        // Legacy form: "access_log" is a scalar flag and "access_log_id" names the logger.
        bool access_log_enabled = data.get<bool>("access_log", false);
        if (access_log_enabled)
        {
            access_log =
                yplatform::find<yplatform::log::source>(data.get("access_log_id", "WebAccess"));
        }
    }
    else if (access_log_cfg)
    {
        // Structured form with optional "plain" and "typed" (tskv only) sub-sections.
        yplatform::read_ptree(protected_url_params, *access_log_cfg, "protected_url_params");
        if (data.get_child_optional("access_log.plain"))
        {
            access_log = yplatform::find<yplatform::log::source>(
                data.get("access_log.plain.logger", "WebAccess"));
        }
        if (data.get_child_optional("access_log.typed"))
        {
            if (data.get("access_log.typed.format", "") != "tskv")
            {
                throw std::runtime_error("only tskv typed access log format is supported");
            }
            access_log_tskv = yplatform::find<yplatform::log::typed::logger>(
                data.get("access_log.typed.logger", "WebAccessTskv"));
            access_log_tskv_timings = data.get<bool>("access_log.typed.log_timings", false);
            access_log_tskv_socket_stats =
                data.get("access_log.typed.socket_stats", access_log_tskv_socket_stats);
        }
    }

    read_chunk_size = data.get("read_chunk_size", read_chunk_size);
    max_headers_size = data.get("max_headers_size", max_headers_size);
    // Use specified max_headers_size as default value to keep backward compatibility.
    max_request_line_size = data.get("max_request_line_size", max_headers_size);
    max_post_size = data.get("max_post_size", max_post_size);
    max_websocket_size = data.get("max_websocket_size", max_websocket_size);
    enable_gzip = (data.get("enable_gzip", (enable_gzip ? 1 : 0)) != 0);
    enable_deflate = (data.get("enable_deflate", (enable_deflate ? 1 : 0)) != 0);
    enable_compress = (data.get("enable_compress", (enable_compress ? 1 : 0)) != 0);
    websocket_origin = data.get("websocket_origin", websocket_origin);
    websocket_location = data.get("websocket_location", websocket_location);
    websocket_protocol = data.get("websocket_protocol", websocket_protocol);
    // On/off switches: only the literal "on" enables them; the current value is the fallback.
    connect_per_request =
        (data.get<std::string>("connect_per_request", connect_per_request ? "on" : "off") ==
         "on");
    enable_websocket_masking =
        (data.get<std::string>(
             "enable_websocket_masking", enable_websocket_masking ? "on" : "off") == "on");
    log_cookies = (data.get("log_cookies", "on") == "on");
    ws_max_length = data.get("websocket_max_message_size", ws_max_length);
    ws_max_fragmentation = data.get("websocket_max_fragmentation", ws_max_fragmentation);
    ws_max_streams = data.get("websocket_max_streams", ws_max_streams);
    load_policy_data(data.get("policy_file_name", ""));
    http_default_version.major = data.get("http_default_version.major", http_default_version.major);
    http_default_version.minor = data.get("http_default_version.minor", http_default_version.minor);
    // Configured headers are added to the current set (built-in defaults included).
    // Elements are bound by reference, so no subtree is copied.
    for (const auto& entry : boost::make_iterator_range(data.equal_range("log_headers")))
    {
        const auto header = entry.second.get_value<std::string>();
        log_headers.insert(header_t{ header, boost::to_lower_copy(header) });
    }
    keep_alive_requests = data.get("keep_alive_requests", keep_alive_requests);
    access_control_allow_origin = data.get("access_control_allow_origin", "");
    access_control_allow_headers = data.get("access_control_allow_headers", "");
    access_control_expose_headers = data.get("access_control_expose_headers", "");
    access_control_allow_credentials = data.get("access_control_allow_credentials", false);
    min_acceptable_timeout = data.get("min_acceptable_timeout", min_acceptable_timeout);
    augment_unique_id_with_request_id =
        data.get("augment_unique_id_with_request_id", augment_unique_id_with_request_id);
    if (auto congestion_control_cfg = data.get_child_optional("congestion_control"))
    {
        congestion_control = congestion_control_settings();
        congestion_control->parse_ptree(*congestion_control_cfg);
    }
    enable_timing_stats = data.get("enable_timing_stats", enable_timing_stats);
}

}
