#include "ares.h"
#include "algorithm.h"
#include <array>
#include <util/string/cast.h>
#include <util/generic/xrange.h>
#include <contrib/libs/c-ares/include/ares.h>
#include <library/cpp/coroutine/engine/impl.h>
#include <mail/so/libs/curl/coroutine_helper.h>

namespace NAres {
    // Scope guard for the c-ares library-wide state: the constructor calls
    // ares_library_init(ARES_LIB_INIT_ALL) and the destructor calls
    // ares_library_cleanup().
    struct TAresGlobalInit {
        TAresGlobalInit() {
            // The status code returned by ares_library_init is not inspected.
            static_cast<void>(ares_library_init(ARES_LIB_INIT_ALL));
        }

        ~TAresGlobalInit() {
            ares_library_cleanup();
        }
    };

    // Constructing any TAresInit makes sure the library-wide c-ares state
    // exists: the function-local static below is constructed on the first
    // call only and destroyed at program exit.
    TAresInit::TAresInit() {
        static TAresGlobalInit globalInit;
    }

    // Deleter used for the owned channel: hands it to ares_destroy.
    void TAres::TDestroyer::Destroy(ares_channel channel) {
        ares_destroy(channel);
    }

    // Default resolver: delegates with a zero timeout and zero tries, which
    // the full constructor treats as "do not override the c-ares defaults".
    TAres::TAres()
        : TAres(TDuration::Zero(), 0) {
    }

    // Resolver with an explicit timeout / number of tries and no custom DNS
    // server. An empty server string (with port 0) makes the full constructor
    // skip the server override.
    // The server is passed as an empty TString rather than nullptr so that no
    // string is ever constructed from a null pointer.
    TAres::TAres(const TDuration& timeout, size_t tries)
        : TAres(timeout, tries, TString(), 0)
    { }

    // Creates the c-ares channel owned by this resolver.
    //   timeout - passed as ARES_OPT_TIMEOUTMS (milliseconds, clamped to int);
    //             TDuration::Zero() leaves the option unset.
    //   tries   - passed as ARES_OPT_TRIES (clamped to int); 0 leaves the option unset.
    //   server  - if non-empty, passed to ares_set_servers_ports_csv.
    //   port    - if non-zero, appended to server as ":<port>".
    // Throws yexception when ares_init_options reports an error.
    TAres::TAres(const TDuration& timeout, size_t tries, const TString& server, ui16 port) {
        constexpr ui64 maxInt = std::numeric_limits<int>::max();

        int flags = 0;
        // Zero-initialise so that fields not selected by `flags` never hold garbage.
        ares_options options = {};

        if (timeout != TDuration::Zero()) {
            flags |= ARES_OPT_TIMEOUTMS;
            options.timeout = static_cast<int>(std::min<ui64>(timeout.MilliSeconds(), maxInt));
        }

        if (tries != 0) {
            flags |= ARES_OPT_TRIES;
            // size_t -> int: clamp instead of silently wrapping.
            options.tries = static_cast<int>(std::min<ui64>(tries, maxInt));
        }

        ares_channel channel = nullptr;
        const int status = ares_init_options(&channel, &options, flags);
        if (status != ARES_SUCCESS) {
            ythrow yexception() << "ares_init_options failed: " << ares_strerror(status);
        }

        // Take ownership right away so the channel is released even if the
        // string handling below throws.
        context.Reset(channel);

        if (!server.empty()) {
            TString s = server;
            if (port) {
                s += ':';
                s += ToString(port);
            }
            // NOTE(review): the status of ares_set_servers_ports_csv is still
            // ignored, as before; a malformed server string leaves the
            // channel's default servers in place.
            ares_set_servers_ports_csv(channel, s.c_str());
        }
    }

    // Host callback passed to ares_gethostbyname.
    //   arg     - pointer to the TVector<TString> that collects textual addresses.
    //   status  - c-ares status; anything other than ARES_SUCCESS is ignored.
    //   hostent - lookup result; every entry of h_addr_list is converted to
    //             text with ares_inet_ntop and appended to the vector.
    // Addresses that cannot be converted are skipped instead of appending
    // whatever the buffer happened to contain.
    static void Callback(void *arg, int status, int, struct hostent *hostent) {
        if (status != ARES_SUCCESS || hostent == nullptr || hostent->h_addr_list == nullptr) {
            return;
        }

        auto& addresses = *static_cast<TVector<TString>*>(arg);
        std::array<char, INET6_ADDRSTRLEN> buffer{};

        for (auto it = hostent->h_addr_list; *it != nullptr; ++it) {
            if (ares_inet_ntop(hostent->h_addrtype, *it, buffer.data(), buffer.size()) != nullptr) {
                addresses.push_back(buffer.data());
            }
        }
    }

    // Clears both descriptor sets and lets ares_fds fill them for the channel.
    // The value returned by ares_fds is stored in nfds; the function reports
    // whether that value is non-zero.
    static bool GetFds(ares_channel channel, fd_set& readers, fd_set& writers, int& nfds) {
        FD_ZERO(&readers);
        FD_ZERO(&writers);

        nfds = ares_fds(channel, &readers, &writers);
        if (nfds == 0) {
            return false;
        }
        return true;
    }

    // Blocking wait (no coroutine): sleeps in select() until one of the
    // descriptors is ready or the timeout computed by ares_timeout elapses.
    // Throws TSystemError(ETIMEDOUT) if the deadline has already passed.
    // Always returns true otherwise.
    static bool SimpleWait(ares_channel channel, fd_set& readers, fd_set& writers, int nfds, const TInstant& deadline) {
        const TInstant now = Now();
        if (deadline < now)
            ythrow TSystemError(ETIMEDOUT);

        // ares_timeout expects the maximum time to wait (a duration), not an
        // absolute point in time, so pass the time remaining until the deadline.
        const ui64 remainingUs = (deadline - now).MicroSeconds();
        timeval maxtv = {};
        maxtv.tv_sec = static_cast<decltype(maxtv.tv_sec)>(remainingUs / 1000000);
        maxtv.tv_usec = static_cast<decltype(maxtv.tv_usec)>(remainingUs % 1000000);

        timeval tv = {};
        timeval* tvp = ares_timeout(channel, &maxtv, &tv);

        // The result of select() is not inspected: the caller runs
        // ares_process on the sets afterwards.
        select(nfds, &readers, &writers, nullptr, tvp);
        return true;
    }

    // Picks the waiting strategy: a blocking SimpleWait when no coroutine is
    // supplied, CoroutineWait otherwise. Returns whatever the chosen helper returns.
    static bool Wait(ares_channel channel, fd_set& readers, fd_set& writers, int nfds, size_t count, const TInstant& deadline, TCont* cont) {
        if (cont == nullptr) {
            return SimpleWait(channel, readers, writers, nfds, deadline);
        }
        return CoroutineWait(readers, writers, nfds, count, deadline, cont);
    }

    // Drives the channel: repeatedly collects its descriptors, waits on them
    // and lets ares_process handle the sets, until GetFds reports nothing to
    // wait for or Wait returns false.
    // Whenever the function is left, normally or by exception, ares_cancel is
    // called so that no query stays pending on the channel. Pending queries
    // keep the `arg` pointers handed to ares_gethostbyname, which refer to
    // storage owned by the caller.
    // NOTE(review): this assumes ares_cancel has no effect when nothing is
    // pending and ends pending queries with a non-success status - confirm
    // against the c-ares documentation.
    static void Process(ares_channel channel, size_t count, const TInstant& deadline, TCont* cont)
    try {
        int nfds = 0;
        fd_set readers = {}, writers = {};
        while (GetFds(channel, readers, writers, nfds) && Wait(channel, readers, writers, nfds, count, deadline, cont)) {
            ares_process(channel, &readers, &writers);
        }
        // Covers the case where Wait returned false while queries were still pending.
        ares_cancel(channel);
    } catch (...) {
        ares_cancel(channel);
        throw;
    }

    // Resolves every name in `hosts`. The returned vector has one entry per
    // host, filled by Callback with the textual addresses from the AF_INET
    // and AF_INET6 lookups. Waiting is delegated to Process.
    TVector<TVector<TString>> TAres::Resolve(const TVector<TString>& hosts, const TInstant& deadline, TCont* cont) {
        TVector<TVector<TString>> addressesPerHost(hosts.size());

        // Two lookups per host, both appending into the same result slot.
        auto enqueue = [this](const auto& name, auto& slot) {
            ares_gethostbyname(context.Get(), name.c_str(), AF_INET, Callback, &slot);
            ares_gethostbyname(context.Get(), name.c_str(), AF_INET6, Callback, &slot);
        };
        simultaneous_for_each_container(enqueue, hosts, addressesPerHost);

        const size_t queryCount = hosts.size() * 2;
        Process(context.Get(), queryCount, deadline, cont);

        return addressesPerHost;
    }

    // Single-host convenience overload: runs the multi-host Resolve on a
    // one-element list and returns that host's addresses.
    TVector<TString> TAres::Resolve(const TString& host, const TInstant& deadline, TCont* cont) {
        TVector<TVector<TString>> perHost = Resolve(TVector<TString>(1, host), deadline, cont);
        return std::move(perHost.front());
    }

    // Builds a new resolver from the parameters stored in these traits.
    THolder<TAres> TPoolTraits::create() const {
        THolder<TAres> resolver = MakeHolder<TAres>(timeout, tries, server, port);
        return resolver;
    }

    // Default traits: delegates with a zero timeout and zero tries.
    TPoolTraits::TPoolTraits()
        : TPoolTraits(TDuration::Zero(), 0) {
    }

    // Traits with an explicit timeout / number of tries and no custom DNS
    // server (empty server string, port 0).
    // The server is passed as an empty TString rather than nullptr so that no
    // string is ever constructed from a null pointer.
    TPoolTraits::TPoolTraits(const TDuration& timeout, size_t tries)
        : TPoolTraits(timeout, tries, TString(), 0)
    { }

    // Stores the resolver parameters and precomputes the hash `h` from all
    // four of them: (timeout in ms, tries) and (hash of server, port) are
    // hashed pairwise and the two results are combined.
    TPoolTraits::TPoolTraits(const TDuration& timeout, size_t tries, const TString& server, ui16 port)
        : timeout(timeout)
        , tries(tries)
        , server(server)
        , port(port)
    {
        const auto limitsHash = CombineHashes<size_t>(timeout.MilliSeconds(), tries);
        const auto endpointHash = CombineHashes<size_t>(THash<TString>()(server), port);
        h = CombineHashes(limitsHash, endpointHash);
    }
    TPoolTraits::TPoolTraits(const TDuration& timeout, size_t tries, const TStringBuf& serverPort) :
            TPoolTraits(timeout, tries, TString{serverPort.Before(':')}, FromString(serverPort.After(':'))) {}
}
