#include "dns.h"
#include "poller.h"

#include <solomon/libs/cpp/event/event.h>

#include <util/generic/hash_set.h>
#include <util/generic/ptr.h>
#include <util/generic/scope.h>
#include <util/string/builder.h>
#include <util/string/split.h>
#include <library/cpp/deprecated/atomic/atomic.h>
#include <util/system/spinlock.h>
#include <util/system/thread.h>
#include <util/thread/lfqueue.h>

#include <library/cpp/containers/absl_flat_hash/flat_hash_map.h>
#include <library/cpp/threading/future/future.h>

#include <contrib/libs/c-ares/include/ares.h>
#include <contrib/libs/c-ares/include/ares_nameser.h>

namespace NSolomon {
namespace {
    using namespace NThreading;

    // Kind of DNS record a request asks for.
    // Scoped so that the short enumerator names (A, AAAA, Srv) do not leak into
    // the enclosing namespace; all users refer to them as ERecordType::X.
    enum class ERecordType {
        Srv = 0,
        A,
        AAAA,
    };

    // A single DNS request travelling from the client to the poller thread and
    // then, as the opaque callback argument, through c-ares.
    struct TDnsRequestContext {
        // name that is passed to ares_query / ares_gethostbyname
        TString Address;
        // selects which c-ares call is issued and, implicitly, which promise below is used
        ERecordType Type;

        // fulfilled by the SRV response callback (Type == Srv); left default-constructed otherwise
        TPromise<TVector<TSrvRecord>> SrvPromise;
        // fulfilled by the address response callback (Type == A or AAAA); left default-constructed otherwise
        TPromise<TIpv6AddressesSet> AddressPromise;
    };

    // TLockFreeQueue does not work with move-only types like THolder,
    // so requests are carried in a TAutoPtr. Ownership is handed to c-ares as a raw
    // pointer (Release()) when the query is issued and is re-wrapped into a
    // TDnsRequestContextPtr inside the response callbacks.
    using TDnsRequestContextPtr = TAutoPtr<TDnsRequestContext>;

    // Tells whether a c-ares status code denotes a failure,
    // i.e. is anything other than ARES_SUCCESS.
    bool IsError(int status) {
        const bool succeeded = (status == ARES_SUCCESS);
        return !succeeded;
    }

    // Converts a c-ares status code into an exception object describing the failure.
    //
    // The message always starts with "Error while resolving records for <address>: ".
    // Timeouts get the fixed text "timed out", every other status gets ares_strerror().
    // The exception type depends on the status:
    //   ARES_ETIMEOUT                                       -> TDnsRequestTimeoutError
    //   ARES_ENODATA, ARES_ENOTFOUND                        -> TDnsRecordNotFound
    //   ARES_ENOTIMP, ARES_EBADNAME                         -> TDnsBadRequestError
    //   ARES_ENOMEM, ARES_EDESTRUCTION, ARES_ENOTINITIALIZED -> TDnsClientInternalError
    //   anything else                                       -> yexception
    std::exception_ptr MakeException(int status, const TString& address) {
        TStringBuilder text;
        text << TStringBuf("Error while resolving records for ") << address << TStringBuf(": ");

        // the only status that does not use the c-ares provided description
        if (status == ARES_ETIMEOUT) {
            text << TStringBuf("timed out");
            return std::make_exception_ptr(TDnsRequestTimeoutError() << text);
        }

        text << ares_strerror(status);

        switch (status) {
            case ARES_ENODATA:
            case ARES_ENOTFOUND:
                return std::make_exception_ptr(TDnsRecordNotFound() << text);

            case ARES_ENOTIMP:
            case ARES_EBADNAME:
                return std::make_exception_ptr(TDnsBadRequestError() << text);

            case ARES_ENOMEM:
            case ARES_EDESTRUCTION:
            case ARES_ENOTINITIALIZED:
                return std::make_exception_ptr(TDnsClientInternalError() << text);

            default:
                return std::make_exception_ptr(yexception() << text);
        }
    }

    class TChannelPoller {
        struct TAresLib {
            TAresLib() {
                auto status = ares_library_init(ARES_LIB_INIT_ALL);
                Y_ENSURE(status == ARES_SUCCESS, "Could not initialize c-ares: " << ares_strerror(status));
            }

            ~TAresLib() {
                ares_library_cleanup();
            }
        };
    public:
        static int AddSocketShim(ares_socket_t fd, int, void* data) {
            auto* self = static_cast<TChannelPoller*>(data);
            return self->AddSocket(fd);
        }

        static void ModifySocketShim(void* data, ares_socket_t socket, int readable, int writeable) {
            auto* self = static_cast<TChannelPoller*>(data);
            return self->ModifySocket(socket, readable, writeable);
        }

        TChannelPoller() {
            Singleton<TAresLib>();

            Opts_.sock_state_cb = &TChannelPoller::ModifySocketShim;
            Opts_.sock_state_cb_data = this;

            int mask = ARES_OPT_SOCK_STATE_CB;

            Y_VERIFY(ares_init_options(&Channel_, &Opts_, mask) == ARES_SUCCESS);
            ares_set_socket_callback(Channel_, &TChannelPoller::AddSocketShim, this);

            Poller_.AddWait(Wakeup_.Fd(), &Wakeup_);
        }

        ~TChannelPoller() {
            ares_destroy(Channel_);
        }

        static void* ThreadShim(void* raw) {
            auto* self = reinterpret_cast<TChannelPoller*>(raw);
            self->Loop();

            return nullptr;
        }

        static void ProcessAddressResponse(void* arg, int status, int, struct hostent* hostent) {
            TDnsRequestContextPtr ctx {static_cast<TDnsRequestContext*>(arg)};
            auto promise = ctx->AddressPromise;

            if (IsError(status)) {
                promise.SetException(MakeException(status, ctx->Address));
                return;
            }

            // this shouldn't really happen, since docs state that it's null only on error
            Y_VERIFY_DEBUG(hostent != nullptr);
            TIpv6AddressesSet result;
            for (auto i = 0; hostent->h_addr_list[i]; ++i) {
                auto* hostentAddr = hostent->h_addr_list[i];

                if (Y_LIKELY(hostent->h_addrtype == AF_INET6)) {
                    result.emplace(*reinterpret_cast<struct in6_addr*>(hostentAddr), 0);
                } else if (hostent->h_addrtype == AF_INET) {
                    result.emplace(*reinterpret_cast<struct in_addr*>(hostentAddr));
                }
            }

            promise.SetValue(result);
        }

        static void ProcessSrvResponse(void *arg, int status, int, unsigned char* abuf, int alen) {
            TDnsRequestContextPtr ctx{static_cast<TDnsRequestContext*>(arg)};
            auto promise = ctx->SrvPromise;

            if (IsError(status)) {
                promise.SetException(MakeException(status, ctx->Address));
                return;
            }

            struct ares_srv_reply *head = nullptr;
            const auto rc = ares_parse_srv_reply(abuf, alen, &head);
            Y_DEFER { ares_free_data(head); };

            if (rc != ARES_SUCCESS) {
                promise.SetException(TStringBuilder() << "Error while parsing response from "
                    << ctx->Address << ": " << ares_strerror(rc));
                return;
            }

            auto* current = head;
            TVector<TSrvRecord> result;
            while (current != nullptr) {
                TSrvRecord r;
                r.Host = current->host;
                r.Port = current->port;
                r.Weight = current->weight;
                r.Priority = current->priority;
                result.push_back(std::move(r));

                current = current->next;
            }

            promise.SetValue(std::move(result));
        }

        void Stop() {
            ShouldStop_ = true;
            Wakeup_.Signal();
        }

        void Schedule(const TDnsRequestContextPtr& ctx) {
            if (ShouldStop_) {
                return;
            }

            Tasks_.Enqueue(ctx);
            Wakeup_.Signal();
        }

        void Loop() {
            TThread::SetCurrentThreadName("AresPoller");
            constexpr auto TIMEOUT = TDuration::Seconds(1);

            while (!ShouldStop_) {
                ProcessTasks();
                ProcessEvents(TIMEOUT);
            }

            // this way we can ensure that corresponding promises will be fulfilled
            while (!Tasks_.IsEmpty()) {
                ProcessTasks();
            }

            ares_cancel(Channel_);
        }

    private:
        int AddSocket(ares_socket_t socket) {
            auto [it, ok] = Events_.emplace(socket, std::make_unique<TEvent>(socket));
            Y_VERIFY_DEBUG(ok);

            Poller_.AddWait(socket, it->second.get());
            return 0;
        }

        void ModifySocket(ares_socket_t socket, int readable, int writeable) {
            auto it = Events_.find(socket);
            Y_VERIFY_DEBUG(it != Events_.end());
            if (it == Events_.end()) {
                return;
            }

            if (!readable && !writeable) {
                Poller_.Unwait(socket);
                const auto cnt = Events_.erase(socket);
                Y_UNUSED(cnt);
                Y_VERIFY_DEBUG(cnt == 1);
            }
        }

        void ProcessTasks() {
            static constexpr auto ITER_LIMIT = 100;
            TDnsRequestContextPtr ctx;

            auto count = 0;
            while (count++ < ITER_LIMIT && Tasks_.Dequeue(&ctx)) {
                switch (ctx->Type) {
                case ERecordType::Srv:
                    ares_query(Channel_,
                        ctx->Address.c_str(),
                        ns_c_in,
                        ns_t_srv,
                        ProcessSrvResponse,
                        ctx.Release());
                    break;
                case ERecordType::A:
                    ares_gethostbyname(Channel_,
                        ctx->Address.c_str(),
                        AF_INET,
                        ProcessAddressResponse,
                        ctx.Release());
                    break;
                case ERecordType::AAAA:
                    ares_gethostbyname(Channel_,
                        ctx->Address.c_str(),
                        AF_INET6,
                        ProcessAddressResponse,
                        ctx.Release());
                    break;
                }
            }
        }

        void ProcessEvents(TDuration timeout) {
            static constexpr auto MAX_POLL = 10;
            void* events[MAX_POLL];
            int r = Poller_.WaitT(events, Y_ARRAY_SIZE(events), timeout);

            bool hasMore = false;
            if (r != 0) {
                for (int i = 0; i < r; ++i) {
                    if (&Wakeup_ == events[i]) {
                        hasMore = true;
                    } else {
                        auto& event = *static_cast<TEvent*>(events[i]);
                        ares_process_fd(Channel_, event.Socket, event.Socket);
                    }
                }
            } else {
                ares_process_fd(Channel_, ARES_SOCKET_BAD, ARES_SOCKET_BAD);
            }

            if (hasMore) {
                Wakeup_.Reset();
            }
        }

    private:
        absl::flat_hash_map<ares_socket_t, std::unique_ptr<TEvent>> Events_;
        TAresPoller Poller_;
        std::atomic<bool> ShouldStop_{false};

        TPollableEvent Wakeup_;
        ares_channel Channel_;
        ares_options Opts_{};
        TLockFreeQueue<TDnsRequestContextPtr> Tasks_;
    };

    // IDnsClient implementation backed by TChannelPoller.
    // The polling thread is started in the constructor and stopped and joined
    // in the destructor.
    class TDnsClient: public IDnsClient {
    public:
        TDnsClient()
            : Poller_{}
            , Thread_{TChannelPoller::ThreadShim, &Poller_}
        {
            Thread_.Start();
        }

        ~TDnsClient() override {
            Poller_.Stop();
            Thread_.Join();
        }

        // Schedules an SRV lookup for `address` and returns the future of its result.
        TFuture<TVector<TSrvRecord>> GetSrvRecords(const TString& address) override {
            TDnsRequestContextPtr request{new TDnsRequestContext{
                .Address = address,
                .Type = ERecordType::Srv,
                .SrvPromise = NewPromise<TVector<TSrvRecord>>(),
            }};

            // take the future before the request is handed over to the poller
            auto future = request->SrvPromise.GetFuture();
            Poller_.Schedule(request);

            return future;
        }

        // Resolves `address` into a set of addresses. With ipv6Only only an AAAA
        // lookup is made; otherwise AAAA and A lookups are issued and their results
        // are merged. The merged future fails only if no address was obtained at all,
        // carrying the collected error messages one per line.
        TFuture<TIpv6AddressesSet> GetAddresses(const TString& address, bool ipv6Only) override {
            auto v6Future = RequestAddresses(address, ERecordType::AAAA);
            if (ipv6Only) {
                return v6Future;
            }

            auto v4Future = RequestAddresses(address, ERecordType::A);

            // Shared between both lookup callbacks; the combined promise is fulfilled
            // from the destructor, i.e. when the last reference goes away.
            struct TMergedResult: TAtomicRefCount<TMergedResult> {
                ~TMergedResult() {
                    if (!Result.empty()) {
                        Promise.SetValue(std::move(Result));
                    } else {
                        Promise.SetException(MakeErrorPretty());
                    }
                }

                void OnDone(TIpv6AddressesSet&& addrs) {
                    auto g = Guard(Lock);
                    Result.insert(addrs.begin(), addrs.end());
                }

                void OnError(TString exc) {
                    auto g = Guard(Lock);
                    Errors.push_back(std::move(exc));
                }

                TString MakeErrorPretty() const {
                    TStringBuilder text;
                    for (const auto& message: Errors) {
                        text << message << "\n";
                    }

                    return text;
                }

                TFuture<TIpv6AddressesSet> GetFuture() const {
                    return Promise.GetFuture();
                }

                TAdaptiveLock Lock;
                TIpv6AddressesSet Result;

                TVector<TString> Errors;
                TPromise<TIpv6AddressesSet> Promise = NewPromise<TIpv6AddressesSet>();
            };

            TIntrusivePtr<TMergedResult> merged{new TMergedResult};

            auto onResolved = [merged] (auto f) {
                try {
                    merged->OnDone(f.ExtractValue());
                } catch (...) {
                    merged->OnError(CurrentExceptionMessage());
                }
            };

            v4Future.Subscribe(onResolved);
            v6Future.Subscribe(onResolved);

            return merged->GetFuture();
        }

    private:
        // Schedules a single address lookup of the given record type (A or AAAA)
        // and returns the future bound to the request's AddressPromise.
        TFuture<TIpv6AddressesSet> RequestAddresses(const TString& address, ERecordType type) {
            TDnsRequestContextPtr request{new TDnsRequestContext{
                .Address = address,
                .Type = type,
                .AddressPromise = NewPromise<TIpv6AddressesSet>(),
            }};

            // take the future before the request is handed over to the poller
            auto future = request->AddressPromise.GetFuture();
            Poller_.Schedule(request);

            return future;
        }

    private:
        TChannelPoller Poller_;
        TThread Thread_;
    };

} // namespace

    // Factory for the c-ares based DNS client; the returned client already has
    // its polling thread running.
    IDnsClientPtr CreateDnsClient() {
        IDnsClientPtr client = MakeAtomicShared<TDnsClient>();
        return client;
    }
} // namespace NSolomon
