#include "connect.h"

#include <library/cpp/coroutine/engine/network.h>

namespace NModProxy {

    // Builds a list of pointers into `addr.Addresses`, keeping only the entries
    // whose address family is compatible with `family`:
    //   - family == AF_INET  -> IPv6 entries are dropped;
    //   - family == AF_INET6 -> IPv4 entries are dropped;
    //   - any other value    -> both IPv4 and IPv6 entries are kept.
    // Entries that are neither IPv4 nor IPv6 trigger Y_ASSERT and are skipped.
    // The returned pointers refer into `addr`, so `addr` must outlive the result.
    TAddrVec FilterAddr(const TSockAddrInfo& addr, int family) noexcept {
        TAddrVec filtered;
        filtered.reserve(addr.Addresses.size());

        for (const auto& candidate : addr.Addresses) {
            switch (candidate.AddrFamily()) {
                case AF_INET:
                    if (family != AF_INET6) {
                        filtered.push_back(&candidate);
                    }
                    break;
                case AF_INET6:
                    if (family != AF_INET) {
                        filtered.push_back(&candidate);
                    }
                    break;
                default:
                    // TODO(velavokr): We cannot handle anything but ip anyway. Maybe just crash here?
                    Y_ASSERT(false);
                    break;
            }
        }

        return filtered;
    }

    // Shuffles `addr` in place, but only within each maximal run of consecutive
    // entries sharing the same address family. The runs themselves keep their
    // original relative order. Presumably this preserves a family preference
    // expressed by the input order; confirm with callers.
    void ShuffleAddr(TAddrVec& addr) noexcept {
        const auto stop = addr.end();
        auto runBegin = addr.begin();

        for (auto cursor = runBegin; cursor != stop; ++cursor) {
            const bool familyChanged = (*cursor)->AddrFamily() != (*runBegin)->AddrFamily();
            if (!familyChanged) {
                continue;
            }
            // `cursor` opens a new run: shuffle the run that just ended.
            Shuffle(runBegin, cursor);
            runBegin = cursor;
        }

        // Shuffle the trailing run (an empty range when `addr` is empty).
        Shuffle(runBegin, stop);
    }

    // Connects socket `s` to one of `addrs`, trying them in ShuffleAddr order.
    //
    // Retry policy:
    //   - A refused connection (EIO is reported as ECONNREFUSED) moves on to the
    //     next address.
    //   - Cancellation, a timeout, or any other error stops the search.
    // Stats: `stats` counts refused, timed-out and other failures. Cancellation
    // is not counted.
    // Deadline: the same `deadLine` is passed to every attempt.
    //
    // Returns the address that connected. Otherwise returns an error built from
    // the last error code and the last attempted address. An empty `addrs`
    // yields an EHOSTUNREACH error without any connection attempt.
    TErrorOr<const TSockAddr*> Connect(TContExecutor* const exec, TSocketHolder& s, const TAddrVec& addrs, int type,
                        int protocol, const TInstant deadLine, TConnStats& stats) noexcept {
        int err = EHOSTUNREACH;
        if (addrs.empty()) {
            return Y_MAKE_ERROR(TSystemError(err) << TErrno(err) << " ");
        }

        // Records a failed attempt in `stats` and reports whether the next
        // address is still worth trying. ETIMEDOUT takes precedence over
        // ErrorIsConnRefused.
        const auto worthRetrying = [&stats](int code) {
            if (code == ECANCELED) {
                return false;
            }
            if (code == ETIMEDOUT) {
                ++stats.ConnTimeout;
                return false;
            }
            if (ErrorIsConnRefused(code)) {
                ++stats.ConnRefused;
                return true;
            }
            ++stats.ConnOtherError;
            return false;
        };

        TAddrVec order(addrs);
        ShuffleAddr(order);

        const TSockAddr* attempted = nullptr;
        for (const auto& candidate : order) {
            attempted = candidate;
            err = ConnectD(exec->Running(), s, *candidate, type, protocol, deadLine);
            if (err == EIO) {
                err = ECONNREFUSED; // TODO(velavokr): fix the poller wrapper
            }
            if (err == 0) {
                return candidate;
            }
            if (!worthRetrying(err)) {
                break;
            }
        }

        // `attempted` is non-null here: `addrs` was checked to be non-empty.
        return Y_MAKE_ERROR(TSystemError(err) << TErrno(err) << " connect to " << attempted->ToString() << " ");
    }
}
