#include "addr.h"

#include <library/cpp/coroutine/engine/impl.h>
#include <library/cpp/coroutine/engine/network.h>
#include <util/generic/overloaded.h>

#include <util/digest/city.h>
#include <util/generic/scope.h>
#include <util/generic/xrange.h>
#include <util/string/builder.h>
#include <util/string/join.h>
#include <util/string/strip.h>

#include <type_traits>

namespace NSrvKernel {

    namespace {
        // Removes optional surrounding "[...]" brackets and a trailing "%scope" suffix from an
        // IPv6 literal. Returns an empty TMaybe when nothing had to be stripped, so the caller
        // can keep using its original string; an empty input or an unterminated '[' is an error.
        TErrorOr<TMaybe<TString>> StripIp6(TStringBuf ip) {
            if (ip.empty()) {
                return Y_MAKE_ERROR(TNetworkResolutionError(EAI_NONAME));
            }

            const size_t originalSize = ip.size();
            TStringBuf host = ip;

            if (host.front() == '[') {
                const bool properlyClosed = host.size() >= 2 && host.back() == ']';
                if (!properlyClosed) {
                    return Y_MAKE_ERROR(TNetworkResolutionError(EAI_NONAME));
                }
                host = host.substr(1, host.size() - 2);
            }

            const auto scopeAt = host.find('%');
            if (scopeAt != TString::npos) {
                host = host.substr(0, scopeAt);
            }

            // Any stripping strictly shortens the view, so an unchanged size means "untouched".
            if (host.size() == originalSize) {
                return TMaybe<TString>();
            }
            return TMaybe<TString>(TString{host});
        }

        // Converts a single addrinfo entry (ignoring ai_next) into a TSockAddr.
        // A missing ai_addr is reported as EAI_NONAME. The entry's own ai_addrlen is forwarded,
        // so FromSockAddr validates the address against the length the resolver actually
        // reported rather than against its default length argument.
        TErrorOr<TSockAddr> FromAddrInfoSingle(const addrinfo& info) noexcept {
            if (!info.ai_addr) {
                return Y_MAKE_ERROR(TNetworkResolutionError(EAI_NONAME));
            }
            TSockAddr addr;
            Y_PROPAGATE_ERROR(TSockAddr::FromSockAddr(
                *info.ai_addr,
                static_cast<socklen_t>(info.ai_addrlen)
            ).AssignTo(addr));
            return addr;
        }
    }

    // Builds an IPv4 address from four octets given most-significant first (s[0].s[1].s[2].s[3]).
    TIpAddr::TIpAddr(std::array<ui8, 4> s) noexcept
    {
        ui32 hostOrder = 0;
        for (const ui8 octet : s) {
            hostOrder = (hostOrder << 8) | octet;
        }
        // in_addr stores the address in network byte order.
        Impl_ = in_addr{.s_addr = htonl(hostOrder)};
    }

    // Builds an IPv6 address from its eight 16-bit groups, first group first.
    // Each group is laid out big-endian: high byte, then low byte.
    TIpAddr::TIpAddr(std::array<ui16, 8> s) noexcept
    {
        in6_addr res = {};
        auto* out = res.s6_addr;
        for (const ui16 group : s) {
            *out++ = static_cast<ui8>(group >> 8);
            *out++ = static_cast<ui8>(group & 0xFF);
        }
        Impl_ = res;
    }

    // Returns the subnet of `prefixBits` length that contains this address, with every host bit
    // cleared. A prefix at least as long as the address is clamped to the full width, and the
    // address is then returned unchanged.
    TIpSubnet TIpAddr::Subnet(ui32 prefixBits) const noexcept {
        const ui32 width = AddrBits();
        if (prefixBits >= width) {
            return {TIpSubnet::TNoop(), *this, width};
        }

        // From here on prefixBits < width.
        return std::visit(TOverloaded{
            [prefixBits](const in_addr& a) -> TIpSubnet {
                // A 32-bit shift by 32 is undefined, hence the explicit zero-prefix case.
                const ui32 netMask = prefixBits == 0 ? 0 : Max<ui32>() << (32 - prefixBits);
                const ui32 network = ntohl(a.s_addr) & netMask;
                return {TIpSubnet::TNoop(), TIpAddr(in_addr{.s_addr = htonl(network)}), prefixBits};
            },
            [prefixBits](const in6_addr& a) -> TIpSubnet {
                in6_addr masked = a;
                const ui32 wholeBytes = prefixBits / 8;
                const ui32 extraBits = prefixBits % 8;
                // Keep only the top `extraBits` bits of the boundary byte (none when extraBits == 0).
                masked.s6_addr[wholeBytes] &= static_cast<ui8>(0xff << (8 - extraBits));
                for (ui32 i = wholeBytes + 1; i < 16; ++i) {
                    masked.s6_addr[i] = 0;
                }
                return {TIpSubnet::TNoop(), TIpAddr(masked), prefixBits};
            },
            [](const std::monostate&) -> TIpSubnet {
                return TIpSubnet();
            }
        }, Impl_);
    }

    // Returns the highest address of the `prefixLength`-bit subnet containing this address,
    // i.e. this address with every host bit set. A prefix covering the whole address (or
    // longer) leaves it unchanged.
    TIpAddr TIpAddr::Broadcast(ui32 prefixLength) const noexcept {
        const ui32 width = AddrBits();
        if (prefixLength >= width) {
            return *this;
        }

        return std::visit(TOverloaded{
            [prefixLength](const in_addr& a) -> TIpAddr {
                const ui32 hostMask = Max<ui32>() >> prefixLength;
                return TIpAddr(in_addr{.s_addr = htonl(ntohl(a.s_addr) | hostMask)});
            },
            [prefixLength](const in6_addr& a) -> TIpAddr {
                in6_addr filled = a;
                const ui32 boundaryByte = prefixLength / 8;
                // Set the host bits of the byte where the prefix ends, then every byte after it.
                filled.s6_addr[boundaryByte] |= static_cast<ui8>(0xff >> (prefixLength % 8));
                for (ui32 i = boundaryByte + 1; i < 16; ++i) {
                    filled.s6_addr[i] = 0xff;
                }
                return TIpAddr(filled);
            },
            [](const std::monostate&) -> TIpAddr {
                return TIpAddr();
            }
        }, Impl_);
    }

    // True for IPv4 127.0.0.0/8 and the IPv6 loopback address; false for an empty address.
    bool TIpAddr::Loopback() const noexcept {
        if (const auto* v4 = std::get_if<in_addr>(&Impl_)) {
            return (ntohl(v4->s_addr) >> 24) == 0x7f;
        }
        if (const auto* v6 = std::get_if<in6_addr>(&Impl_)) {
            return IN6_IS_ADDR_LOOPBACK(v6);
        }
        return false;
    }

    // True for IPv4 169.254.0.0/16 and IPv6 link-local unicast; false for an empty address.
    bool TIpAddr::LinkLocal() const noexcept {
        if (const auto* v4 = std::get_if<in_addr>(&Impl_)) {
            // 169.254 == 0xa9fe in the top 16 bits.
            return (ntohl(v4->s_addr) >> 16) == 0xa9fe;
        }
        if (const auto* v6 = std::get_if<in6_addr>(&Impl_)) {
            return IN6_IS_ADDR_LINKLOCAL(v6);
        }
        return false;
    }

    // Unicast global: a non-empty address that is neither multicast, link-local nor loopback.
    bool TIpAddr::Global() const noexcept {
        if (Empty()) {
            return false;
        }
        return !(Multicast() || LinkLocal() || Loopback());
    }

    // True for IPv4 224.0.0.0/4 and IPv6 ff00::/8; false for an empty address.
    bool TIpAddr::Multicast() const noexcept {
        if (const auto* v4 = std::get_if<in_addr>(&Impl_)) {
            // Top nibble 1110.
            return (ntohl(v4->s_addr) >> 28) == 0xe;
        }
        if (const auto* v6 = std::get_if<in6_addr>(&Impl_)) {
            return IN6_IS_ADDR_MULTICAST(v6);
        }
        return false;
    }

    // Multicast node-local. Only IPv6 has this scope, so IPv4 and empty addresses yield false.
    bool TIpAddr::McNodeLocal() const noexcept {
        const auto* v6 = std::get_if<in6_addr>(&Impl_);
        return v6 != nullptr && IN6_IS_ADDR_MC_NODELOCAL(v6);
    }

    // Multicast link-local: IPv4 224.0.0.0/24 or IPv6 link-local-scope multicast.
    bool TIpAddr::McLinkLocal() const noexcept {
        if (const auto* v4 = std::get_if<in_addr>(&Impl_)) {
            // The top 24 bits must be exactly 224.0.0.
            return (ntohl(v4->s_addr) >> 8) == 0xe00000;
        }
        if (const auto* v6 = std::get_if<in6_addr>(&Impl_)) {
            return IN6_IS_ADDR_MC_LINKLOCAL(v6);
        }
        return false;
    }

    // Multicast "global": any multicast address that is neither node-local nor link-local.
    bool TIpAddr::McGlobal() const noexcept {
        if (!Multicast()) {
            return false;
        }
        return !McNodeLocal() && !McLinkLocal();
    }

    // Presentation form via inet_ntop (no brackets for IPv6); empty string for an empty address.
    TString TIpAddr::ToString() const noexcept {
        if (const auto* v4 = std::get_if<in_addr>(&Impl_)) {
            char text[INET_ADDRSTRLEN];
            return TString(inet_ntop(AF_INET, v4, text, sizeof(text)));
        }
        if (const auto* v6 = std::get_if<in6_addr>(&Impl_)) {
            char text[INET6_ADDRSTRLEN] = {};
            return TString(inet_ntop(AF_INET6, v6, text, sizeof(text)));
        }
        return TString();
    }

    // Parses a numeric IP literal. Any ':' selects IPv6; otherwise inet_pton(AF_INET) is used.
    // Every failure is reported as TNetworkResolutionError(EAI_NONAME).
    TErrorOr<TIpAddr> TIpAddr::FromIp(const TString& ip) noexcept {
        if (ip.empty()) {
            return Y_MAKE_ERROR(TNetworkResolutionError(EAI_NONAME));
        }

        if (ip.find(':') == TString::npos) {
            in_addr v4;
            if (inet_pton(AF_INET, ip.c_str(), &v4) != 1) {
                return Y_MAKE_ERROR(TNetworkResolutionError(EAI_NONAME));
            }
            return v4;
        }

        // IPv6: drop surrounding brackets and any "%scope" suffix before parsing; the scope
        // is discarded, not carried into the result.
        TMaybe<TString> stripped;
        Y_PROPAGATE_ERROR(StripIp6(ip).AssignTo(stripped));
        const TString& literal = stripped ? *stripped : ip;
        in6_addr v6;
        if (inet_pton(AF_INET6, literal.c_str(), &v6) != 1) {
            return Y_MAKE_ERROR(TNetworkResolutionError(EAI_NONAME));
        }
        return v6;
    }

    // 127.0.0.1, constructed once on first use.
    const TIpAddr& Loopback4() noexcept {
        static const TIpAddr loopback4(TIp4Raw{127, 0, 0, 1});
        return loopback4;
    }

    // ::1 (in6addr_loopback), constructed once on first use.
    const TIpAddr& Loopback6() noexcept {
        static const TIpAddr loopback6(in6addr_loopback);
        return loopback6;
    }

    // 100::1, an address inside 100::/64 (the IPv6 discard-only prefix, RFC 6666).
    // Constructed once on first use.
    const TIpAddr& Blackhole6() noexcept {
        static const TIpAddr blackhole(TIp6Raw{0x100, 0, 0, 0, 0, 0, 0, 1});
        return blackhole;
    }


    // Builds the normalized subnet containing `prefix`: host bits beyond `prefixBits` are cleared
    // and the length is clamped to the address width (see TIpAddr::Subnet).
    TIpSubnet::TIpSubnet(TIpAddr prefix, ui32 prefixBits) noexcept
        : TIpSubnet(prefix.Subnet(prefixBits))
    {}

    // Raw constructor: stores `prefix` and `prefixBits` verbatim, without masking host bits or
    // clamping the length. Callers in this file pass either an address already masked to
    // `prefixBits` (TIpAddr::Subnet) or a full-length prefix (FromSubnet); Contains() relies on
    // Addr_ being normalized this way.
    TIpSubnet::TIpSubnet(TIpSubnet::TNoop, TIpAddr prefix, ui32 prefixBits) noexcept
        : Addr_(prefix)
        , PrefixBits_(prefixBits)
    {}

    // "<address>/<prefix bits>", or an empty string for an empty subnet.
    TString TIpSubnet::ToString() const noexcept {
        TString text;
        if (!Empty()) {
            text = TStringBuilder() << Addr_ << '/' << PrefixBits_;
        }
        return text;
    }

    // Parses "addr/bits" (host bits are masked off) or a bare "addr" (a full-length subnet).
    // A non-numeric or over-long prefix length is reported as EAI_NONAME.
    TErrorOr<TIpSubnet> TIpSubnet::FromSubnet(const TString& subnet) noexcept {
        const auto slash = subnet.find('/');
        TIpAddr addr;

        if (slash == TString::npos) {
            Y_PROPAGATE_ERROR(TIpAddr::FromIp(subnet).AssignTo(addr));
            return TIpSubnet{TNoop(), addr, addr.AddrBits()};
        }

        Y_PROPAGATE_ERROR(TIpAddr::FromIp(subnet.substr(0, slash)).AssignTo(addr));
        ui32 prefix = 0;
        const bool prefixValid = TryFromString(subnet.substr(slash + 1), prefix)
            && prefix <= addr.AddrBits();
        if (!prefixValid) {
            return Y_MAKE_ERROR(TNetworkResolutionError(EAI_NONAME));
        }
        return addr.Subnet(prefix);
    }

    // True when `snet` lies entirely inside this subnet. For equal prefix lengths that means
    // identical base addresses; a longer prefix here can never contain a shorter one.
    bool TIpSubnet::Contains(const TIpSubnet& snet) const noexcept {
        if (PrefixBits_ > snet.PrefixBits_) {
            return false;
        }
        if (PrefixBits_ < snet.PrefixBits_) {
            return Contains(snet.Addr_);
        }
        return Addr_ == snet.Addr_;
    }

    // True when `addr` is of the same family and masking it to this subnet's length yields
    // this subnet's base address.
    bool TIpSubnet::Contains(const TIpAddr& addr) const noexcept {
        const bool sameFamily = Addr_.Impl_.index() == addr.Impl_.index();
        return sameFamily && Addr_ == addr.Subnet(PrefixBits_).Addr_;
    }


    // Builds a sockaddr_in / sockaddr_in6 from an IP, a host-order port and (IPv6 only) the flow
    // info and scope id. Port and flow info are stored in network byte order. An empty TIpAddr
    // leaves Impl_ as default-constructed.
    TSockAddr::TSockAddr(const TIpAddr& addr, ui16 port, const TIp6AuxFields& aux) noexcept
    {
        if (const auto* v4 = std::get_if<in_addr>(&addr.Impl_)) {
            sockaddr_in sin{};
            sin.sin_family = AF_INET;
            sin.sin_port = htons(port);
            sin.sin_addr = *v4;
            Impl_.emplace<sockaddr_in>(sin);
        } else if (const auto* v6 = std::get_if<in6_addr>(&addr.Impl_)) {
            sockaddr_in6 sin6{};
            sin6.sin6_family = AF_INET6;
            sin6.sin6_port = htons(port);
            sin6.sin6_flowinfo = htonl(aux.FlowInfo);
            sin6.sin6_addr = *v6;
            sin6.sin6_scope_id = aux.ScopeId;
            Impl_.emplace<sockaddr_in6>(sin6);
        }
    }

    const sockaddr* TSockAddr::Addr() const noexcept {
        return (const sockaddr*)GetRawBytes().data();
    }

    // The IP part of the socket address; an empty TIpAddr when no address is stored.
    TIpAddr TSockAddr::Ip() const noexcept {
        if (const auto* sin = std::get_if<sockaddr_in>(&Impl_)) {
            return TIpAddr(sin->sin_addr);
        }
        if (const auto* sin6 = std::get_if<sockaddr_in6>(&Impl_)) {
            return TIpAddr(sin6->sin6_addr);
        }
        return TIpAddr();
    }

    // The port in host byte order; 0 when no address is stored.
    ui16 TSockAddr::Port() const noexcept {
        ui16 netPort = 0;
        if (const auto* sin = std::get_if<sockaddr_in>(&Impl_)) {
            netPort = sin->sin_port;
        } else if (const auto* sin6 = std::get_if<sockaddr_in6>(&Impl_)) {
            netPort = sin6->sin6_port;
        }
        return ntohs(netPort);
    }

    // Replaces the port (given in host byte order); a no-op when no address is stored.
    void TSockAddr::SetPort(ui16 port) noexcept {
        const ui16 netPort = htons(port);
        if (auto* sin = std::get_if<sockaddr_in>(&Impl_)) {
            sin->sin_port = netPort;
        } else if (auto* sin6 = std::get_if<sockaddr_in6>(&Impl_)) {
            sin6->sin6_port = netPort;
        }
    }

    // IPv6 flow info (converted to host byte order) and scope id; default-constructed
    // fields for IPv4 or empty addresses.
    TIp6AuxFields TSockAddr::Ip6AuxFields() const noexcept {
        const auto* sin6 = std::get_if<sockaddr_in6>(&Impl_);
        if (sin6 == nullptr) {
            return TIp6AuxFields();
        }
        return TIp6AuxFields{
            .FlowInfo = ntohl(sin6->sin6_flowinfo),
            .ScopeId = sin6->sin6_scope_id,
        };
    }

    // Overwrites the IPv6 flow info and scope id; a no-op for IPv4 or empty addresses.
    // `aux.FlowInfo` is in host byte order, while sin6_flowinfo is kept in network byte order
    // (the constructor stores it with htonl and Ip6AuxFields() reads it back with ntohl).
    // So the conversion here is host -> network, i.e. htonl.
    void TSockAddr::SetIp6AuxFields(const TIp6AuxFields& aux) noexcept {
        std::visit(TOverloaded{
            [&](sockaddr_in6& a) {
                a.sin6_flowinfo = htonl(aux.FlowInfo);
                a.sin6_scope_id = aux.ScopeId;
            },
            [](auto&) { },
        }, Impl_);
    }

    // Parses "ip", "ip:port", "[ipv6]" or "[ipv6]:port"; a missing port yields port 0.
    // An unbracketed string with several ':' is taken as a bare IPv6 address, because its last
    // ':' cannot be told apart from an address separator. Errors are EAI_NONAME.
    TErrorOr<TSockAddr> TSockAddr::FromIpPort(const TString& ipPort) noexcept {
        // Only the empty string is rejected up front. Short literals such as "::" (2 chars) are
        // valid IPv6 addresses; every other malformed input is rejected by FromIp / TryFromString
        // with the same error.
        if (ipPort.empty()) {
            return Y_MAKE_ERROR(TNetworkResolutionError(EAI_NONAME));
        }
        TIpAddr ip;
        ui16 port = 0;
        const auto firstColon = ipPort.find(':');
        const auto lastColon = ipPort.rfind(':');
        // A port is present when there is exactly one ':' or when the last ':' directly follows
        // the closing ']' of a bracketed IPv6 literal. If lastColon != firstColon then
        // lastColon >= 1, so ipPort[lastColon - 1] is in range.
        const bool hasPort = lastColon != TString::npos
            && (lastColon == firstColon || ipPort[lastColon - 1] == ']');
        if (!hasPort) {
            Y_PROPAGATE_ERROR(TIpAddr::FromIp(ipPort).AssignTo(ip));
        } else {
            Y_PROPAGATE_ERROR(TIpAddr::FromIp(ipPort.substr(0, lastColon)).AssignTo(ip));
            if (!TryFromString<ui16>(ipPort.substr(lastColon + 1), port)) {
                return Y_MAKE_ERROR(TNetworkResolutionError(EAI_NONAME));
            }
        }
        return {ip, port};
    }

    // "ip:port" for IPv4, "[ip]:port" for IPv6, or just the bare IP when the port is 0.
    // An empty address gives an empty string.
    TString TSockAddr::ToString() const noexcept {
        if (const auto* sin = std::get_if<sockaddr_in>(&Impl_)) {
            const TIpAddr ip(sin->sin_addr);
            if (!sin->sin_port) {
                return ip.ToString();
            }
            TStringBuilder out;
            out.reserve(INET_ADDRSTRLEN + 1 + 5);
            out << ip << ':' << ntohs(sin->sin_port);
            return out;
        }
        if (const auto* sin6 = std::get_if<sockaddr_in6>(&Impl_)) {
            // TODO(velavokr): handle the nonzero scope id case
            const TIpAddr ip(sin6->sin6_addr);
            if (!sin6->sin6_port) {
                return ip.ToString();
            }
            TStringBuilder out;
            out.reserve(INET6_ADDRSTRLEN + 1 + 2 + 5);
            out << '[' << ip << "]:" << ntohs(sin6->sin6_port);
            return out;
        }
        return {};
    }

    // Combines a numeric IP literal (see TIpAddr::FromIp) with an explicit host-order port.
    TErrorOr<TSockAddr> TSockAddr::FromIpPort(const TString& ip, ui16 port) noexcept {
        TIpAddr parsed;
        Y_PROPAGATE_ERROR(TIpAddr::FromIp(ip).AssignTo(parsed));
        return {parsed, port};
    }

    // Copies an AF_INET / AF_INET6 address out of a generic sockaddr, after checking that `len`
    // covers the family-specific structure. Other families give EAFNOSUPPORT.
    // NOTE(review): a too-short `len` is reported as ENOMEM; EINVAL may describe it better, but
    // the errno is kept unchanged for compatibility.
    TErrorOr<TSockAddr> TSockAddr::FromSockAddr(const sockaddr& addr, socklen_t len) noexcept {
        const auto family = addr.sa_family;
        if (family == AF_INET) {
            if (len < sizeof(sockaddr_in)) {
                return Y_MAKE_ERROR(TSystemError(ENOMEM));
            }
            return reinterpret_cast<const sockaddr_in&>(addr);
        }
        if (family == AF_INET6) {
            if (len < sizeof(sockaddr_in6)) {
                return Y_MAKE_ERROR(TSystemError(ENOMEM));
            }
            return reinterpret_cast<const sockaddr_in6&>(addr);
        }
        return Y_MAKE_ERROR(TSystemError(EAFNOSUPPORT));
    }

    // Convenience overload: the whole sockaddr_storage is taken as the available length.
    TErrorOr<TSockAddr> TSockAddr::FromSockAddr(const sockaddr_storage& addr) noexcept {
        const auto& generic = reinterpret_cast<const sockaddr&>(addr);
        return FromSockAddr(generic, sizeof(sockaddr_storage));
    }

    // Converts a util NAddr::IRemoteAddr. A null underlying sockaddr gives an empty TSockAddr
    // rather than an error.
    // The address object's own Len() is forwarded, so FromSockAddr validates the sockaddr
    // against its real length instead of relying on the default length argument.
    TErrorOr<TSockAddr> TSockAddr::FromRemoteAddr(const NAddr::IRemoteAddr& addr) noexcept {
        const sockaddr* raw = addr.Addr();
        if (!raw) {
            return TSockAddr();
        }
        return FromSockAddr(*raw, addr.Len());
    }

    // Walks a getaddrinfo() result list starting at `info` and converts every entry; the first
    // conversion failure aborts the whole call. Runs of identical adjacent addresses are
    // collapsed (only adjacent ones, so the resolver's ordering is preserved).
    TErrorOr<std::vector<TSockAddr>> TSockAddr::FromAddrInfo(const addrinfo& info) noexcept {
        std::vector<TSockAddr> result;
        result.reserve(1);
        const addrinfo* node = &info;
        while (node != nullptr) {
            TSockAddr converted;
            Y_PROPAGATE_ERROR(FromAddrInfoSingle(*node).AssignTo(converted));
            result.push_back(converted);
            node = node->ai_next;
        }
        result.erase(std::unique(result.begin(), result.end()), result.end());
        return result;
    }

    // Same as the addrinfo overload, but iterates a util TNetworkAddress. The first conversion
    // failure aborts; adjacent duplicate addresses are collapsed.
    TErrorOr<std::vector<TSockAddr>> TSockAddr::FromAddrInfo(const TNetworkAddress& info) noexcept {
        std::vector<TSockAddr> result;
        result.reserve(1);
        for (auto it = info.Begin(), last = info.End(); it != last; ++it) {
            TSockAddr converted;
            Y_PROPAGATE_ERROR(FromAddrInfoSingle(*it).AssignTo(converted));
            result.push_back(converted);
        }
        result.erase(std::unique(result.begin(), result.end()), result.end());
        return result;
    }

    // Opens a new socket of addr's family and connects it from coroutine `cont` with `deadline`.
    // Returns 0 on success, otherwise the error code (LastSystemError() if the socket could not
    // be created, or the value returned by NCoro::ConnectD).
    // `s` is only touched on success: the connected socket is swapped into it, and whatever `s`
    // previously held moves into the local holder.
    int ConnectD(TCont* cont, TSocketHolder& s, const TSockAddr& addr, int type, int protocol, TInstant deadline) noexcept {
        TSocketHolder candidate(NCoro::Socket(addr.AddrFamily(), type, protocol));
        if (candidate.Closed()) {
            return LastSystemError();
        }

        const int err = NCoro::ConnectD(cont, candidate, addr.Addr(), addr.Len(), deadline);
        if (err != 0) {
            return err;
        }

        s.Swap(candidate);
        return 0;
    }
}

// Stream output for the address types: each writes the object's ToString() text.
template <>
void Out<NSrvKernel::TIpAddr>(IOutputStream& out, const NSrvKernel::TIpAddr& obj) {
    const TString text = obj.ToString();
    out << text;
}

template <>
void Out<NSrvKernel::TIpSubnet>(IOutputStream& out, const NSrvKernel::TIpSubnet& obj) {
    const TString text = obj.ToString();
    out << text;
}

template <>
void Out<NSrvKernel::TSockAddr>(IOutputStream& out, const NSrvKernel::TSockAddr& obj) {
    const TString text = obj.ToString();
    out << text;
}

// Comma-separated list of the resolved addresses.
template <>
void Out<NSrvKernel::TSockAddrInfo>(IOutputStream& out, const NSrvKernel::TSockAddrInfo& obj) {
    const auto joined = JoinSeq(",", obj.Addresses);
    out << joined;
}
