#include <boost/range/iterator_range.hpp>
#include <boost/algorithm/string/predicate.hpp>

#include "options.h"
#include "log.h"
#include "public_suffix_list.h"
#include "dmarc.h"

#include <boost/format.hpp>

namespace NNwSmtp {
namespace dmarc {

    // Passed as the second argument of NWLOG_L by the log calls in this file
    // (presumably the log component/category name — confirm against log.h).
    static const std::string DMARC{"DMARC"};

    /**
     * Map a DMARC policy token to record::policy.
     *
     * Recognised tokens are "none", "quarantine" and "reject"; the comparison
     * is exact (case-sensitive). Anything else, including an empty range,
     * yields record::UNKNOWN.
     *
     * @param begin, end  character range holding the policy token
     */
    template <typename Iterator>
    record::policy parse_policy(Iterator begin, Iterator end)
    {
        struct known_policy
        {
            const char *token;
            record::policy value;
        };
        static const known_policy table[] = {
            {"none", record::NONE},
            {"quarantine", record::QUARANTINE},
            {"reject", record::REJECT},
        };

        const boost::iterator_range<Iterator> candidate(begin, end);
        for (const known_policy &entry : table)
        {
            if (boost::equal(candidate, boost::as_literal(entry.token)))
                return entry.value;
        }
        return record::UNKNOWN;
    }

    template <typename Iterator>
    void parse_record(Iterator begin, Iterator end, record &result)
    {
        while (begin != end)
        {
            char c = *begin;
            if (!isspace(c))
            {
                // extract tag name
                Iterator name_begin = begin;
                for (; begin != end && *begin != '=' && !isspace(*begin); ++begin);
                Iterator name_end = begin;
                for (; begin != end && isspace(*begin); ++begin);

                if (begin == end || *begin != '=')
                    break;

                // skip spaces between '=' and value
                for (++begin; begin != end && (*begin == ' ' || *begin == '\t'); ++begin);

                // extract tag value
                Iterator value_begin = begin, value_end = begin;
                while (begin != end)
                {
                    if ((c = *begin) == ';')
                    {
                        value_end = begin++;
                        break;
                    }

                    value_end = ++begin;
                }

                // handle entry
                std::string name(name_begin, name_end);
                if (boost::iequals(name, "v")) {
                    result.m_version.assign(value_begin, value_end);
                } else if (boost::iequals(name, "adkim")) {
                    result.m_strict_dkim = boost::equal(
                        boost::make_iterator_range(value_begin, value_end), boost::as_literal("s"));
                } else if (boost::iequals(name, "aspf")) {
                    result.m_strict_spf = boost::equal(
                        boost::make_iterator_range(value_begin, value_end), boost::as_literal("s"));
                } else if (boost::iequals(name, "p")) {
                    if ((result.m_policy = parse_policy(value_begin, value_end)) == record::UNKNOWN) {
                        return;
                    }
                } else if (boost::iequals(name, "sp")) {
                    result.m_subdomain_policy = parse_policy(value_begin, value_end);
                } else if (boost::iequals(name, "pct")) {
                    std::string value(value_begin, value_end);
                    try { result.percent = boost::lexical_cast<unsigned int>(value); } catch (...) {}
                } else if (boost::iequals(name, "rua")) {
                    result.rua.assign(value_begin, value_end);
                } else if (boost::iequals(name, "fo")) {
                    result.failureOpts.assign(value_begin, value_end);
                }
            }
            else
                for (; begin != end && isspace(*begin); ++begin);
        }
    }

    /**
     * Create a checker bound to `io_service`.
     *
     * The strand, the timeout timer and the TXT resolver are all constructed on
     * that io_service; `resolver_options` is forwarded to the resolver.
     * m_tld_fallback starts out false so handle_resolve() may retry the lookup
     * once with the organizational domain.
     */
    checker::checker(boost::asio::io_service &io_service, const resolver_options& resolver_options)
        : m_io_service(io_service)
        , m_strand(io_service)
        , m_timer(io_service)
        , m_resolver(io_service, resolver_options)
        , m_tld_fallback(false)
    {
    }

    void checker::start(const std::string &domain, const result_callback &handler, TContextPtr context)
    {
        m_handler = handler;
        Context = std::move(context);

        m_timer.expires_from_now(boost::posix_time::seconds(gconfig->dmarc.timeout));
        m_timer.async_wait(m_strand.wrap(boost::bind(&checker::handle_timeout, shared_from_this(), boost::asio::placeholders::error)));

        resolve(domain);
    }

    /**
     * Request cancellation of the timer and the resolver.
     *
     * cancel() is not called directly; it is posted to the io_service through
     * m_strand, so it runs serialized with the other strand-wrapped handlers.
     * The bound shared_ptr keeps this checker alive until cancel() has run.
     */
    void checker::stop()
    {
        auto cancel_handler = m_strand.wrap(boost::bind(&checker::cancel, shared_from_this()));
        m_io_service.post(cancel_handler);
    }

    const std::string& checker::get_txt_record() const {
        return m_txt_record;
    }

    /**
     * Completion handler for the "_dmarc.<domain>" TXT lookup issued by resolve().
     *
     * operation_aborted (cancellation) is ignored. On success, each TXT string
     * is parsed into a fresh record. The first one with version "DMARC1" and a
     * known policy is stored in m_record / m_txt_record and reported as success.
     *
     * If no string qualifies, or the lookup failed, the lookup is retried once
     * for the domain returned by g_public_suffixes.find_tld(), provided that it
     * differs from the queried one. A resolver error other than host_not_found
     * is logged before that retry. Without a retry, such an error goes through
     * handle_error(); every other outcome ends in report(false).
     */
    void checker::handle_resolve(const boost::system::error_code &error, resolver_t::iterator_txt result)
    {
        if (error == boost::asio::error::operation_aborted)
            return;

        if (!error)
        {
            for (resolver_t::iterator_txt i = result; i != resolver_t::iterator_txt(); ++i)
            {
                std::string r = *i;

                // Parse every TXT string into a clean record, so tags from a
                // rejected string (or from the previous domain before the
                // organizational-domain fallback) cannot leak into the record
                // that is eventually accepted.
                record candidate;
                candidate.m_domain = m_record.m_domain;
                parse_record(r.begin(), r.end(), candidate);

                if (candidate.m_version != "DMARC1" || candidate.m_policy == record::UNKNOWN)
                    continue;

                m_record = std::move(candidate);
                m_txt_record = r;
                return report();
            }
        }

        if (!m_tld_fallback)
        {
            m_tld_fallback = true;

            // try the organizational domain
            std::string domain = g_public_suffixes.find_tld(m_record.m_domain);
            if (domain != m_record.m_domain)
            {
                if (error && error != boost::asio::error::host_not_found) {
                    NWLOG_L(notice, DMARC, "domain='" + m_record.m_domain + "', stat=error, report='" +
                        error.message() + "'");
                }

                return resolve(domain);
            }
        }

        if (error && error != boost::asio::error::host_not_found)
            return handle_error(error.message());
        report(false);
    }

    /// Timer callback armed in start(). Unless the wait was cancelled
    /// (operation_aborted), the lookup is reported as failed via handle_error().
    void checker::handle_timeout(const boost::system::error_code &error)
    {
        if (error != boost::asio::error::operation_aborted)
            handle_error("Timed out");
    }

    void checker::handle_error(const std::string &message)
    {
        NWLOG_L(notice, DMARC, "domain='" + m_record.m_domain + "', stat=error, report='" + message + "'");
        report(false);
    }

    /**
     * Issue an asynchronous TXT lookup for "_dmarc.<domain>".
     *
     * Stores `domain` in m_record.m_domain (read by handle_resolve() and the
     * log messages) and arranges for handle_resolve() to run on m_strand with
     * the lookup result. The bound shared_ptr keeps this checker alive until then.
     */
    void checker::resolve(const std::string &domain)
    {
        m_record.m_domain = domain;

        const std::string query_name = "_dmarc." + domain;
        auto on_resolved = boost::bind(&checker::handle_resolve,
                                       shared_from_this(),
                                       boost::asio::placeholders::error,
                                       boost::asio::placeholders::iterator);
        m_resolver.async_resolve_txt(query_name, m_strand.wrap(on_resolved));
    }

    /**
     * Cancel the pending timer wait and TXT lookup (runs on m_strand via stop()).
     *
     * Each cancellation is attempted independently, so a failure to cancel the
     * timer no longer leaves the resolver running. Failures are deliberately
     * ignored because this is best-effort cleanup. The handlers check for
     * operation_aborted, which presumably is what a cancelled resolver
     * delivers (verify against resolver_t).
     */
    void checker::cancel()
    {
        try
        {
            m_timer.cancel();
        }
        catch (...)
        {
            // best effort: still try to cancel the resolver below
        }

        try
        {
            m_resolver.cancel();
        }
        catch (...)
        {
            // best effort
        }
    }

    void checker::report(bool success)
    {
        stop();

        if (m_handler)
            m_handler(success ? m_record : std::optional<record>());
    }

    /**
     * Textual name of a DMARC policy: "none", "quarantine" or "reject".
     *
     * Any other value (record::UNKNOWN or anything unexpected) maps to
     * "unknown". An explicit switch is used instead of indexing a name array
     * with the enum value: the array only works if NONE/QUARANTINE/REJECT are
     * exactly 0/1/2, and reads out of bounds for any other enumerator.
     */
    std::string convert_policy(record::policy value)
    {
        switch (value)
        {
            case record::NONE:
                return "none";
            case record::QUARANTINE:
                return "quarantine";
            case record::REJECT:
                return "reject";
            default:
                return "unknown";
        }
    }

}   // namespace dmarc
}   // namespace NNwSmtp
