#include "options.h"
#include "file_converter.h"
#include "log.h"

#include <boost/tokenizer.hpp>

#include <algorithm>
#include <exception>
#include <functional>
#include <limits>
#include <stdexcept>
#include <string>
#include <utility>

namespace NNwSmtp {

// Canned SMTP reply texts.
// 451: transient failure, the text asks the client to try again later.
const char *temp_error = "451 4.7.1 Service unavailable - try again later";
// 550: permanent failure for the requested mailbox.
const char *perm_user_error = "550 Requested action not taken: mailbox unavailable";

// Default message size limit: 51 MiB (51 * 1024 * 1024 = 53477376 bytes).
const std::size_t MB_51 = std::size_t{51} * 1024 * 1024;

/// Fills the connection settings of an HTTP-style client options object from `pt`.
///
/// `connect_timeout` and `timeout` default to 1 and `try` defaults to 2.
/// `addr` has no default: if it is missing, `pt.get` throws.
template <typename Opts>
inline void read_http_info(const ptree& pt, Opts& opts) {
    const unsigned int defaultTimeout = 1;
    const unsigned int defaultAttempts = 2;

    opts.connectTimeout = pt.get<unsigned int>("connect_timeout", defaultTimeout);
    opts.timeout = pt.get<unsigned int>("timeout", defaultTimeout);
    opts.attempts = pt.get<unsigned int>("try", defaultAttempts);

    auto addressText = pt.get<std::string>("addr");
    opts.addr = remote_point::from_string(addressText);
}

/// Overlays the `connect`, `command` and `data` keys of `pt` onto `timeouts`.
///
/// Any key absent from `pt` keeps the value already present in `timeouts`.
/// Without a second argument, `timeouts` is a default-constructed SmtpTimeouts.
/// The value type of each key is deduced from the corresponding field.
inline SmtpTimeouts read_smtp_timeouts(const ptree& pt, SmtpTimeouts timeouts = SmtpTimeouts()) {
    SmtpTimeouts merged = timeouts;
    merged.connect = pt.get("connect", timeouts.connect);
    merged.command = pt.get("command", timeouts.command);
    merged.data = pt.get("data", timeouts.data);
    return merged;
}

/// Returns a Targeting object populated from `pt` via Targeting::init.
///
/// Keys missing from `pt` keep the values the object had right after
/// default construction.
inline Options::Targeting read_targeting_options(const ptree& pt) {
    Options::Targeting result;
    result.init(pt);
    return result;
}

/// Reads the SMTP session settings from `pt`.
///
/// Absent scalar keys keep the member's current value, with one exception:
/// `banner` is reset to "". The `proto_constraints` subtree is mandatory, and
/// its fields fall back to the literal defaults below. The
/// `session_time_limit` subtree is optional.
///
/// @throws boost::property_tree::ptree_bad_path if `proto_constraints` is missing.
void Options::SmtpConnectionOpts::init(const ptree& pt) {
    startTls = pt.get<bool>("starttls", startTls);

    maxMessagesPerAuth = pt.get("max_messages_per_auth", maxMessagesPerAuth);

    banner = pt.get<std::string>("banner", "");

    // Bind by reference: plain `auto` would deduce `ptree` and deep-copy the subtree.
    const auto& constraintsNode = pt.get_child("proto_constraints");
    constraints.rcptCountLimit = constraintsNode.get<unsigned int>("recipient_limit", 35);
    constraints.checkSenderSyntax = constraintsNode.get<bool>("check_sender_syntax", true);
    constraints.strictAsciiRecipient = constraintsNode.get<bool>("strict_ascii_recipient", true);
    constraints.allowPercentHack = constraintsNode.get<bool>("allow_percent_hack", true);
    constraints.messageSizeLimit = constraintsNode.get<std::size_t>("message_size_limit", MB_51);

    if (const auto sessionTimeLimitNode = pt.get_child_optional("session_time_limit")) {
        // NOTE(review): `use` falls back to false rather than its current value when
        // the subtree exists without the key. Kept as-is.
        sessionTimeLimit.use = sessionTimeLimitNode->get("use", false);
        sessionTimeLimit.timeout = sessionTimeLimitNode->get("timeout", sessionTimeLimit.timeout);
    }

    unrecognizedCommandsMaxCount = pt.get<std::size_t>(
        "unrecognized_commands_max_count", unrecognizedCommandsMaxCount);

    useRateLimiter = pt.get<bool>("use_rate_limiter", useRateLimiter);
}

/// Reads the message-processing settings from `pt`.
///
/// `remove_headers_list` and `trust_headers` may appear several times. Each
/// value is stored lowercased in the corresponding set. The boolean flags
/// default to false and `headers_count_limit` defaults to 250. The
/// `control_from` subtree is mandatory.
///
/// @throws boost::property_tree::ptree_bad_path if `control_from` is missing.
void Options::MessageProcessingOpts::init(const ptree& pt) {
    // Inserts the lowercased value of every direct child of `pt` named `key` into `out`.
    const auto collectLowercased = [&pt](const char* key, auto& out) {
        boost::range::for_each(pt.equal_range(key), [&out](const auto& entry) {
            out.insert(boost::to_lower_copy(entry.second.data()));
        });
    };

    collectLowercased("remove_headers_list", removeHeadersSet);

    headersCountLimit = pt.get<unsigned int>("headers_count_limit", 250);

    hideSourceInReceived = pt.get<bool>("hide_source_in_received", false);
    rewriteSenderFromHeader = pt.get<bool>("rewrite_sender_from_header", false);
    addReturnPath = pt.get<bool>("add_return_path", false);
    markMLMessage = pt.get<bool>("mark_message_ml", false);

    collectLowercased("trust_headers", trustHeaders);

    ControlFromOpts.Init(pt.get_child("control_from"));
}

/// Reads the delivery-to-sender control flags.
///
/// `use` defaults to false. `check_sender_in_big_ml_subscribers` keeps its
/// current value when absent.
void Options::DeliveryToSenderControl::init(const ptree& pt) {
    use = pt.get("use", false);

    const bool currentCheck = checkSenderInBigMlSubscribers;
    checkSenderInBigMlSubscribers =
        pt.get<bool>("check_sender_in_big_ml_subscribers", currentCheck);
}

/// Reads the AUTH settings. Every key is optional: `use` and `ssl_only`
/// default to false, and `ssl_message` defaults to a fixed reply text.
void Options::AuthOpts::init(const ptree& pt) {
    use = pt.get("use", false);
    sslOnly = pt.get("ssl_only", false);

    const std::string defaultSslMessage = "Message on AUTH without STARTTLS or ssl";
    sslMessage = pt.get<std::string>("ssl_message", defaultSslMessage);
}

/// Reads the YARM settings. `use_rpop_auth` defaults to false. `secret` is
/// read (and is then mandatory) only when RPOP auth is enabled.
void Options::YarmOpts::init(const ptree& pt) {
    useRpopAuth = pt.get<bool>("use_rpop_auth", false);
    if (!useRpopAuth) {
        return;
    }

    secret = pt.get<std::string>("secret");
}

/// Reads the RBL settings. `use` defaults to false. When enabled, every
/// repeated `hosts` value is appended to `hosts`.
void Options::RBLOpts::init(const ptree& pt) {
    use = pt.get<bool>("use", false);
    if (!use) {
        return;
    }

    const auto [first, last] = pt.equal_range("hosts");
    for (auto host = first; host != last; ++host) {
        hosts.push_back(host->second.data());
    }
}

/// Reads the SPF settings: `use` (default false) and `timeout` (default 1).
void Options::SPFOpts::init(const ptree& pt) {
    const unsigned int defaultTimeout = 1;
    use = pt.get("use", false);
    timeout = pt.get("timeout", defaultTimeout);
}

/// Reads the DMARC settings: `use` (default false) and `timeout` (default 1).
void Options::DMARCOpts::init(const ptree& pt) {
    const unsigned int defaultTimeout = 1;
    use = pt.get("use", false);
    timeout = pt.get("timeout", defaultTimeout);
}

/// Reads the LL client settings. `use` defaults to false, and nothing else is
/// read unless it is enabled. When enabled:
/// - `addr` is mandatory;
/// - `timeout`, `connect_timeout` and `try` keep their current values when absent;
/// - `ignore_errors` defaults to false.
void Options::LLOpts::init(const ptree& pt) {
    use = pt.get<bool>("use", false);
    if (!use) {
        return;
    }

    auto addressText = pt.get<std::string>("addr");
    addr = remote_point::from_string(addressText);
    timeout = pt.get("timeout", timeout);
    connectTimeout = pt.get("connect_timeout", connectTimeout);
    attempts = pt.get("try", attempts);
    ignoreErrors = pt.get<bool>("ignore_errors", false);
}

/// Reads the big-maillist client settings. `use` defaults to false, and
/// nothing else is read unless it is enabled. When enabled, `path` is
/// mandatory and the other keys keep their current values when absent.
void Options::BigMlOpts::init(const ptree& pt) {
    use = pt.get<bool>("use", false);
    if (!use) {
        return;
    }

    add_recipients = pt.get("add_recipients", add_recipients);
    path = pt.get<std::string>("path");
    timeout = pt.get("timeout", timeout);
    connectTimeout = pt.get("connect_timeout", connectTimeout);
}

/// Reads `use`, keeping the current value when the key is absent.
void Options::AuthSettingsOpts::init(const ptree& pt) {
    const auto current = use;
    use = pt.get("use", current);
}

/// Reads the blackbox client settings from `pt`.
///
/// Repeated keys:
/// - `oauth_scopes`: values are collected as-is;
/// - `access_restricted_links_list`: each entry needs `key` and `link`.
///
/// `temp_user_error_reply_text` is mandatory. The remaining flags keep their
/// current value when absent.
///
/// @throws boost::property_tree::ptree_bad_path on a missing mandatory key.
void Options::BlackBoxOpts::init(const ptree& pt) {
    auto scopesRange = pt.equal_range("oauth_scopes");
    for (auto it = scopesRange.first; it != scopesRange.second; ++it) {
        oauthScopes.insert(it->second.data());
    }

    boost::range::for_each(pt.equal_range("access_restricted_links_list"), [&](const auto& value) {
        auto key = value.second.template get<std::string>("key");
        auto link = value.second.template get<std::string>("link");
        // Build the entry in place from the moved strings.
        // std::make_pair(key, link) would copy both strings first.
        accessRestrictedLinksMap.emplace(std::move(key), std::move(link));
    });

    allowUnknownRcpt = pt.get("allow_unknown_rcpt", allowUnknownRcpt);
    allowReceiveWithBadKarma =
        pt.get("client_rcpt_with_bad_karma_receives_mails", allowReceiveWithBadKarma);
    checkRcpt = pt.get("check_rcpt", checkRcpt);
    checkSender = pt.get("check_sender", checkSender);
    tempUserErrorReplyText = pt.get<std::string>("temp_user_error_reply_text");
    denyAuthForAssessors = pt.get("deny_auth_for_assessors", denyAuthForAssessors);
    allowUnknownSender = pt.get("allow_unknown_sender", allowUnknownSender);
}

void Options::Targeting::init(const ptree& pt) {
    use = pt.get("use", use);
    fallback = pt.get("fallback", fallback);
    if (pt.count("timeouts")) {
        timeouts = read_smtp_timeouts(pt.get_child("timeouts"), timeouts);
    }
    if (auto port = pt.get_optional<std::uint16_t>("bypass_port")) {
        bypassPort = *port;
    }
}

/// Collects every repeated `remove_headers` value, lowercased, into RemoveHeaders.
void Options::Mds::init(const ptree& pt) {
    const auto [first, last] = pt.equal_range("remove_headers");
    for (auto entry = first; entry != last; ++entry) {
        RemoveHeaders.insert(boost::to_lower_copy(entry->second.data()));
    }
}

void Options::DeliveryOpts::init(const ptree& pt) {
    auto routingPt = pt.get_child_optional("routing");
    if (routingPt) {
        routing.init(routingPt.get());
    }

    const auto & relaysPt = pt.get_child("relays");
    // local
    const auto & localPt = relaysPt.get_child("local");
    local.addr = SmtpPoint::fromString(localPt.get<std::string>("addr"));
    if (localPt.count("timeouts")) {
        local.timeouts = read_smtp_timeouts(localPt.get_child("timeouts"));
    }
    if (localPt.count("targeting")) {
        localTargeting = read_targeting_options(localPt.get_child("targeting"));
    }
    // fallback
    const auto & fallbackPt = relaysPt.get_child("fallback");
    fallback.addr = SmtpPoint::fromString(fallbackPt.get<std::string>("addr"));
    if (fallbackPt.count("timeouts")) {
        fallback.timeouts = read_smtp_timeouts(fallbackPt.get_child("timeouts"));
    }
    if (fallbackPt.count("targeting")) {
        fallbackTargeting = read_targeting_options(fallbackPt.get_child("targeting"));
    }
    // external
    auto externalPt = relaysPt.get_child_optional("external");
    if (externalPt) {
        if (auto addr = externalPt->get<std::string>("addr"); !addr.empty()) {
            external.addr = SmtpPoint::fromString(addr);
        }
        if (externalPt.get().count("timeouts")) {
            external.timeouts = read_smtp_timeouts(externalPt->get_child("timeouts"));
        }
    } else if (routing.primary != RoutingSettings::Type::NONE) {
        throw std::runtime_error("external relay not specified");
    }

    // sender_dependent
    if (auto opt = relaysPt.get_child_optional("sender_dependent")) {
        const auto& senderDepPt = opt.get();
        if (senderDepPt.get<bool>("use", false)) {
            boost::for_each(senderDepPt.equal_range("relays_map"), [&](const auto& value){
                auto key = value.second.template get<std::string>("key");
                auto addr = value.second.template get<std::string>("addr");
                senderDependent.relaysMap.insert({key, SmtpPoint::fromString(addr)});
            });
            if (senderDepPt.count("timeouts")) {
                senderDependent.timeouts = read_smtp_timeouts(senderDepPt.get_child("timeouts"));
            }
        }
    }

    auto mxcodes = pt.equal_range("mxcode_map");
    for (auto it = mxcodes.first; it != mxcodes.second; ++it) {
        auto key = it->second.get<std::string>("key");
        auto addr = it->second.get<std::string>("addr");
        mxCodeMap.insert({key, remote_point::from_string(addr)});
    }

    boost::range::for_each(pt.equal_range("remove_headers"), [&](const auto& header) {
        removeHeaders.insert(boost::to_lower_copy(header.second.data()));;
    });
}

/// Reads the decycler settings. `update_header` and `reject` default to
/// false. `ttl` defaults to the largest std::size_t value, presumably meaning
/// "no limit" (TODO confirm with the decycler implementation).
void Options::DecyclerOpts::init(const ptree& pt) {
    constexpr auto maxTtl = std::numeric_limits<std::size_t>::max();

    updateHeader = pt.get("update_header", false);
    reject = pt.get("reject", false);
    ttl = pt.get<std::size_t>("ttl", maxTtl);
}

/// Reads the DSN settings.
///
/// `mode` defaults to "none" and is converted by dsn::parse_mode. When the
/// result is dsn::Options::NEVER nothing else is read. Otherwise:
/// - `addr` is mandatory;
/// - `sign` defaults to false;
/// - the same `pt` is passed to composer.init().
void Options::DsnOptions::init(const ptree& pt) {
    mode = dsn::parse_mode(pt.get<std::string>("mode", "none"));

    if (mode == dsn::Options::NEVER) {
        // `addr` is not required in this mode.
        return;
    }

    const std::string relayAddr = pt.get<std::string>("addr");
    relay = SmtpPoint::fromString(relayAddr);
    sign = pt.get("sign", false);
    composer.init(pt);
}

/// Builds the options by reading every section of `pt` (see Options::read).
Options::Options(const ptree& pt) {
    this->read(pt);
}

void Options::read(const ptree& pt) {
    mydomain = pt.get<std::string>("mydomain");
    clusterName = pt.get<std::string>("cluster_name", "");
    hostName = boost::asio::ip::host_name();
    publicSuffixListFile = pt.get<std::string>("public_suffix_list", "");

    smtpOpts.init(pt.get_child("smtp_connection"));
    msgOpts.init(pt.get_child("message_processing"));
    auth.init(pt.get_child("auth"));
    yarm.init(pt.get_child("yarm"));
    rbl.init(pt.get_child("rbl"));
    spf.init(pt.get_child("spf"));
    dsn.init(pt.get_child("dsn"));
    dmarc.init(pt.get_child("dmarc"));
    corpList.init(pt.get_child("corp_maillist"));
    bigMlOpts.init(pt.get_child("big_ml"));
    authSettingsOpts.init(pt.get_child("auth_settings"));
    blackbox.init(pt.get_child("blackbox"));
    soOpts.Read(pt.get_child("so"));
    avir.read(pt.get_child("av"));
    delivery.init(pt.get_child("delivery"));
    deliveryToSenderControl.init(pt.get_child("delivery_to_sender_control"));
    mds.init(pt.get_child("mds"));
    decyclerOpts.init(pt.get_child("decycler"));
}

/// Re-applies only the `delivery.routing` subtree, if present.
/// Every other option keeps its current value.
void Options::reload(const ptree& pt) {
    if (const auto routingPt = pt.get_child_optional("delivery.routing")) {
        delivery.routing.reload(*routingPt);
    }
}

// Global options instance; null until it is assigned elsewhere.
std::shared_ptr<Options> gconfig{};

}   // namespace NNwSmtp
