#include "async_coro_requests.h"

#include <balancer/kernel/dns/common.h>

#include <library/cpp/coroutine/engine/impl.h>

namespace NSrvKernel::NDns {

// One resolver query for a single address family, awaited from a coroutine.
//
// Lifetime: instances are heap-allocated by TCoroRequests::ExecuteRequest and
// delete themselves. Two events must both happen before the object is freed:
// the resolver callback (OnComplete) and the end of the coroutine wait (WaitD).
// Whichever of the two runs when the other has already happened performs the
// `delete this`, so no caller may touch the object after calling WaitD.
class TRequest final : public IRequest {
public:
    // Calls result.AddRequestInProgress(); the destructor later calls
    // result.SetFailedIfDone(). `e` is the executor the wait event is bound to.
    TRequest(const char* host, int family, TContExecutor& e, ICounters& counters, IEntry& result);
    TRequest(const TRequest&) = delete;
    TRequest(TRequest&&) = delete;
    TRequest& operator=(const TRequest&) = delete;
    TRequest& operator=(TRequest&&) = delete;
    ~TRequest() override;

    TString Host() const noexcept override { return Host_; }
    int Family() const noexcept override { return Family_; }

    // Completion callback: bumps the counter matching `status`, stores the
    // addresses on ARES_SUCCESS and signals the waiter. May delete `this`.
    void OnComplete(int status, const hostent* result) noexcept override;
    // Waits on Event_ until it is signalled or `deadline` passes. Returns
    // EWAKEDUP without waiting if OnComplete already ran. May delete `this`.
    int WaitD(TInstant deadline) noexcept override;

private:
    // Appends every address listed in `result` to Result_.
    void SetResult(const hostent* result) noexcept;

private:
    IEntry& Result_; // destination for resolved addresses, shared between requests
    TString Host_;
    int  Family_ = AF_UNSPEC;
    TContSimpleEvent Event_; // signalled by OnComplete, waited on by WaitD
    bool WaitCalled_ = false; // set once WaitD's event wait has returned
    bool CbCalled_ = false; // set by the first OnComplete call
    ICounters& Counters_;
};


// Captures the query parameters and binds the wait event to `executor`.
// The request registers itself with `result` as in progress; the destructor
// makes the matching SetFailedIfDone() call.
TRequest::TRequest(const char* host, int family, TContExecutor& executor, ICounters& counters, IEntry& result)
    : Result_(result)
    , Host_(host)
    , Family_(family)
    , Event_(&executor)
    , Counters_(counters)
{
    // Done in the body so it only happens once every member is constructed.
    Result_.AddRequestInProgress();
}

// Each TRequest reports exactly once to the shared entry when it goes away.
// NOTE(review): presumably SetFailedIfDone() marks the entry failed once the
// last outstanding request is gone with no addresses — confirm against IEntry.
TRequest::~TRequest() {
    Result_.SetFailedIfDone();
}

// Resolver completion callback.
//
// Updates the counter that matches `status`. On ARES_SUCCESS it also copies the
// returned addresses into the shared entry. On every exit path the scope guard
// below signals the waiting coroutine. It then deletes the request if WaitD has
// already finished waiting, because nobody else will touch it afterwards.
void TRequest::OnComplete(int status, const hostent* he) noexcept
{
    Y_DEFER {
        Event_.Signal();
        if (WaitCalled_) {
            delete this;
        }
    };

    if (CbCalled_) {
        // A second completion for the same request is not expected.
        Y_ASSERT(false);
        return;
    }
    CbCalled_ = true;

    switch (status) {
        case ARES_SUCCESS:
            Counters_.AddResolve();
            SetResult(he);
            break;
        case ARES_ENODATA:
            Counters_.AddNoData();
            break;
        case ARES_ENOTFOUND:
            Counters_.AddNxdomain();
            break;
        default:
            Counters_.AddResolveError();
            break;
    }
}

// Suspends the calling coroutine until OnComplete signals Event_ or `deadline`
// passes.
//
// Returns EWAKEDUP without waiting if the callback already ran. Otherwise it
// returns whatever Event_.WaitD returned.
//
// Ownership hand-off: if the callback has run by the time the wait is over, the
// request is deleted here. Otherwise WaitCalled_ is raised so that the late
// OnComplete deletes it instead. The caller must not use the object afterwards.
int TRequest::WaitD(TInstant deadline) noexcept
{
    // The callback already fired, so report a wake-up rather than a timeout.
    int waitResult = EWAKEDUP;
    if (!CbCalled_) {
        waitResult = Event_.WaitD(deadline);
        WaitCalled_ = true;
    }

    if (CbCalled_) {
        delete this;
    }
    return waitResult;
}

// Appends every address listed in `he` to Result_ as a sockaddr_in/sockaddr_in6
// with the port left zero.
//
// Entries are only copied when h_addrtype and h_length agree: each h_addr_list
// element is h_length bytes long, so copying a fixed-size in_addr/in6_addr
// without checking it could read past the end of the entry. Entries with an
// unexpected family or length trip a debug assertion and are skipped.
void TRequest::SetResult(const hostent* he) noexcept
{
    if (!he || !he->h_addr_list) {
        return;
    }

    // Family and entry length are the same for every element of the list.
    const bool isIpv6 = he->h_addrtype == AF_INET6
        && he->h_length == static_cast<int>(sizeof(in6_addr));
    const bool isIpv4 = he->h_addrtype == AF_INET
        && he->h_length == static_cast<int>(sizeof(in_addr));

    for (char** addr = he->h_addr_list; *addr; ++addr) {
        if (isIpv6) {
            sockaddr_in6 sockAddr;
            Zero(sockAddr);
            sockAddr.sin6_family = AF_INET6;
            memcpy(&sockAddr.sin6_addr, *addr, sizeof(sockAddr.sin6_addr));
            Result_.AddResult({sockAddr});
        } else if (isIpv4) {
            sockaddr_in sockAddr;
            Zero(sockAddr);
            sockAddr.sin_family = AF_INET;
            memcpy(&sockAddr.sin_addr, *addr, sizeof(sockAddr.sin_addr));
            Result_.AddResult({sockAddr});
        } else {
            Y_ASSERT(false);
        }
    }
}

// ----------------------------------------------------------------------------
// TCoroRequests
//
// Remembers `executor`. Every TRequest created by ExecuteRequest binds its wait
// event to it.
TCoroRequests::TCoroRequests(TContExecutor& executor)
    : Executor_(executor)
{}

bool TCoroRequests::ExecuteRequest(const TString& host, int family, TInstant deadline,
                                   ScheduleFunc scheduler, ICounters& counters, IEntry& result) noexcept
{
    if (family == AF_UNSPEC) {
        // c-ares checks for AF_INET only if no AF_INET6 result found, but we need all
        // possible addresses, both for IPv4 and IPv6. IPv6 has priority.
        auto request1 = new(std::nothrow) TRequest(host.c_str(), AF_INET6, Executor_, counters, result);
        auto request2 = new(std::nothrow) TRequest(host.c_str(), AF_INET, Executor_, counters, result);
        if (!request1 || !request2) {
            if (!request1 && !request2) {
                result.SetFailedIfDone(); // set error status manually
            }
            delete request1;
            delete request2;
            return false;
        }

        Requests_.PushBack(request1);
        Requests_.PushBack(request2);

        scheduler(*request1);
        scheduler(*request2);

        request1->WaitD(deadline);
        request2->WaitD(deadline);
    } else {
        Y_ASSERT(family == AF_INET6 || family == AF_INET);
        auto request1 = new(std::nothrow) TRequest(host.c_str(), family, Executor_, counters, result);
        if (!request1) {
            result.SetFailedIfDone(); // set error status manually
            return false;
        }

        Requests_.PushBack(request1);

        scheduler(*request1);
        request1->WaitD(deadline);
    }

    return true;
}

}

