#include "ipaddr.h"

#include "string/coder.h"
#include "string/string_utils.h"

#include <util/draft/ip.h>

#include <array>

namespace NPassport::NUtils {
    // A 64-bit word with every bit set. The range and shift helpers in this file
    // shift it left/right to build bit masks.
    const ui64 MaskOf64Bits = ~static_cast<ui64>(0);

    // Constructs an address from its textual representation by delegating to Parse().
    // If Parse() rejects the string, Y_ENSURE fails (presumably by throwing — confirm
    // against the macro definition), so a constructed object always holds a parsed address.
    TIpAddr::TIpAddr(const TString& strAddr) {
        Y_ENSURE(Parse(strAddr));
    }

    // Builds the IPv4 address a.b.c.d.
    // The four octets fill the lowest 32 bits of Low_ (a is the most significant octet)
    // and bits 32..47 of Low_ are set to 0xffff, which is the pattern IsIpv4() tests for.
    // High_ is not assigned here.
    constexpr TIpAddr::TIpAddr(ui8 a, ui8 b, ui8 c, ui8 d) noexcept
        : Low_((ui64(0xffff) << 32) | (ui64(a) << 24) | (ui64(b) << 16) | (ui64(c) << 8) | ui64(d))
    {
    }

    // Builds an IPv6 address from its eight 16-bit groups a:b:c:d:e:f:g:h.
    // Groups a..d are packed into High_ and groups e..h into Low_; within each word the
    // earlier group is the more significant one.
    constexpr TIpAddr::TIpAddr(ui16 a, ui16 b, ui16 c, ui16 d, ui16 e, ui16 f, ui16 g, ui16 h) noexcept
        : Low_((((((ui64(e) << 16) | f) << 16) | g) << 16) | h)
        , High_((((((ui64(a) << 16) | b) << 16) | c) << 16) | d)
    {
    }

    // Parses strIp and returns it re-rendered by ToString().
    // Returns an empty string when strIp cannot be parsed.
    TString TIpAddr::Normalize(const TString& strIp) {
        TIpAddr parsed;
        return parsed.Parse(strIp) ? parsed.ToString() : TString();
    }

    // An address is treated as IPv4 when High_ is zero and the upper 32 bits of Low_
    // equal 0xffff (the layout produced by the IPv4 constructor and by Parse()).
    bool TIpAddr::IsIpv4() const {
        if (High_ != 0) {
            return false;
        }

        return (Low_ >> 32) == 0xffff;
    }

    // Every address that is not IPv4 (see IsIpv4()) is considered IPv6.
    bool TIpAddr::IsIpv6() const {
        const bool isV4 = IsIpv4();
        return !isV4;
    }

    // Returns the upper 32 bits of Low_ for IPv6 addresses and 0 for IPv4 addresses.
    ui32 TIpAddr::ProjectId() const {
        return IsIpv4() ? 0 : static_cast<ui32>(Low_ >> 32);
    }

    // Parses a textual IPv4 or IPv6 address (in the forms accepted by inet_pton) into
    // High_/Low_.
    //
    // Returns true on success. On failure returns false and leaves the object zeroed
    // (both words are reset before parsing starts).
    //
    // inet_pton() returns 1 on success, 0 for a malformed string and -1 on error
    // (unsupported address family), so the result is compared with 1 explicitly:
    // a plain truthiness test would treat the -1 error as a successful parse.
    bool TIpAddr::Parse(const TString& strAddr) {
        High_ = 0;
        Low_ = 0;

        struct in_addr baseaddr {};
        if (inet_pton(AF_INET, strAddr.c_str(), &baseaddr) == 1) {
            // IPv4: 0xffff marker in bits 32..47, address in host byte order below it
            Low_ = ((ui64)0xffff) << 32;
            // NOLINTNEXTLINE(readability-isolate-declaration)
            Low_ += ntohl(baseaddr.s_addr);
            return true;
        }

        struct in6_addr baseaddrv6 {};

        if (inet_pton(AF_INET6, strAddr.c_str(), &baseaddrv6) == 1) {
            // IPv6 address is just an array of 16 bytes, most significant first:
            // bytes 0..7 go to High_, bytes 8..15 go to Low_
            for (size_t idx = 0; idx < 8; ++idx) {
                High_ = (High_ << 8) | (ui64)baseaddrv6.s6_addr[idx];
                Low_ = (Low_ << 8) | (ui64)baseaddrv6.s6_addr[idx + 8];
            }

            return true;
        }

        return false;
    }

    bool TIpAddr::ParseBase64(const TStringBuf strAddr) {
        const int len = strAddr.size();

        if (len != 6 && len != 22) { // not ipv6, not ipv4, seems broken
            return false;
        }

        TString binaddr = Base64url2bin(strAddr);

        if (binaddr.size() == 4) {
            High_ = Low_ = 0;
            for (int idx = 3; idx >= 0; --idx) {
                Low_ <<= 8;
                Low_ |= (ui8)binaddr[idx];
            }
            Low_ |= ((ui64)0xffff) << 32;
            return true;
        }

        if (binaddr.size() == 16) {
            High_ = Low_ = 0;
            for (int idx = 7; idx >= 0; --idx) {
                Low_ <<= 8;
                Low_ |= (ui8)binaddr[idx];
            }
            for (int idx = 15; idx >= 8; --idx) {
                High_ <<= 8;
                High_ |= (ui8)binaddr[idx];
            }
            return true;
        }

        return false;
    }

    // Splits the lowest 32 bits of `low` into four octets, most significant octet first.
    static std::array<ui8, 4> ToBytesIpv4(ui64 low) {
        std::array<ui8, 4> octets{};

        for (size_t pos = 0; pos < octets.size(); ++pos) {
            // pos 0 takes bits 24..31, pos 3 takes bits 0..7
            octets[pos] = static_cast<ui8>(low >> (8 * (3 - pos)));
        }

        return octets;
    }

    // Serializes the 128-bit value high:low into 16 bytes, most significant byte first:
    // bytes 0..7 come from `high`, bytes 8..15 from `low`.
    static std::array<ui8, 16> ToBytesIpv6(ui64 low, ui64 high) {
        std::array<ui8, 16> out{};

        for (size_t pos = 0; pos < 8; ++pos) {
            const size_t shift = 8 * (7 - pos);
            out[pos] = static_cast<ui8>(high >> shift);
            out[pos + 8] = static_cast<ui8>(low >> shift);
        }

        return out;
    }

    // Renders the address as text.
    //
    // IPv4 addresses are written as dotted decimal "a.b.c.d"; everything else is
    // formatted by inet_ntop(AF_INET6, ...).
    //
    // The inet_ntop() result is checked: it returns nullptr on failure and then leaves the
    // output buffer unspecified, so in that case an empty string is returned instead of
    // converting an uninitialized buffer.
    TString TIpAddr::ToString() const {
        if (IsIpv4()) {
            const std::array<ui8, 4> bytes = ToBytesIpv4(Low_);
            return NUtils::CreateStr(
                (unsigned)bytes[0],
                ".",
                (unsigned)bytes[1],
                ".",
                (unsigned)bytes[2],
                ".",
                (unsigned)bytes[3]);
        }

        struct in6_addr addr {};

        // IPv6 address is just array of bytes
        const std::array<ui8, 16> bytes = ToBytesIpv6(Low_, High_);
        memcpy(addr.s6_addr, bytes.data(), bytes.size());

        char buf[INET6_ADDRSTRLEN] = {};
        if (inet_ntop(AF_INET6, &addr, buf, sizeof(buf)) == nullptr) {
            return TString();
        }

        return buf;
    }

    // Returns the first address of the network that contains this address, i.e. this
    // address with all bits after the first `mask_width` bits cleared.
    //
    // mask_width is the prefix length: 0..32 for IPv4 addresses (internally shifted by 96,
    // because the IPv4 payload lives in the lowest 32 of the 128 bits) and 0..128 for IPv6.
    // Values outside that range are clamped to the nearest bound.
    //
    // Shifting a 64-bit value by 64 or by a negative amount is undefined behaviour, so the
    // "/0" case and out-of-range widths are handled explicitly instead of being fed into
    // the shift.
    TIpAddr TIpAddr::GetRangeStart(int mask_width) const {
        const bool isV4 = IsIpv4();
        const int maxWidth = isV4 ? 32 : 128;

        if (mask_width < 0) {
            mask_width = 0;
        } else if (mask_width > maxWidth) {
            mask_width = maxWidth;
        }

        if (isV4) {
            mask_width += 96;
        }

        TIpAddr result;

        if (mask_width > 64) {
            // prefix covers all of High_ and part of Low_; hostBits is in [0, 63]
            const int hostBits = 128 - mask_width;
            result.High_ = High_;
            result.Low_ = Low_ & (MaskOf64Bits << hostBits);
        } else {
            // prefix ends inside High_; hostBits is in [0, 64], and 64 must not be shifted
            const int hostBits = 64 - mask_width;
            const ui64 mask = (hostBits < 64) ? (MaskOf64Bits << hostBits) : 0;
            result.High_ = High_ & mask;
            result.Low_ = 0;
        }

        return result;
    }

    // Returns the last address of the network that contains this address, i.e. this
    // address with all bits after the first `mask_width` bits set to one.
    //
    // mask_width is the prefix length: 0..32 for IPv4 addresses (internally shifted by 96,
    // because the IPv4 payload lives in the lowest 32 of the 128 bits) and 0..128 for IPv6.
    // Values outside that range are clamped to the nearest bound.
    //
    // Shifting a 64-bit value by 64 or by a negative amount is undefined behaviour, so the
    // "/0" case and out-of-range widths are handled explicitly instead of being fed into
    // the shift.
    TIpAddr TIpAddr::GetRangeEnd(int mask_width) const {
        const bool isV4 = IsIpv4();
        const int maxWidth = isV4 ? 32 : 128;

        if (mask_width < 0) {
            mask_width = 0;
        } else if (mask_width > maxWidth) {
            mask_width = maxWidth;
        }

        if (isV4) {
            mask_width += 96;
        }

        TIpAddr result;

        if (mask_width > 64) {
            // prefix covers all of High_ and part of Low_; hostBits is in [0, 63]
            const int hostBits = 128 - mask_width;
            const ui64 mask = MaskOf64Bits << hostBits;
            result.High_ = High_;
            result.Low_ = Low_ | ~mask;
        } else {
            // prefix ends inside High_; hostBits is in [0, 64], and 64 must not be shifted
            const int hostBits = 64 - mask_width;
            const ui64 mask = (hostBits < 64) ? (MaskOf64Bits << hostBits) : 0;
            result.High_ = High_ | ~mask;
            result.Low_ = MaskOf64Bits;
        }

        return result;
    }

    // Returns this address plus one, treating High_:Low_ as a single 128-bit integer
    // (the value wraps around after the all-ones address).
    TIpAddr TIpAddr::Next() const {
        TIpAddr successor;
        successor.Low_ = Low_ + 1;
        successor.High_ = High_;

        if (successor.Low_ == 0) { // low word wrapped around, carry into the high word
            successor.High_ += 1;
        }

        return successor;
    }

    // Returns this address minus one, treating High_:Low_ as a single 128-bit integer
    // (the value wraps around below the all-zero address).
    TIpAddr TIpAddr::Prev() const {
        TIpAddr predecessor;
        // a zero low word means the decrement borrows from the high word
        predecessor.High_ = (Low_ == 0) ? (High_ - 1) : High_;
        predecessor.Low_ = Low_ - 1;
        return predecessor;
    }

    // 128-bit addition of `other` to this address (modulo 2^128).
    TIpAddr& TIpAddr::operator+=(const TIpAddr& other) {
        const ui64 prevLow = Low_;

        Low_ = prevLow + other.Low_;
        // the low word got smaller only if the addition overflowed
        const ui64 carry = (Low_ < prevLow) ? 1 : 0;
        High_ = High_ + other.High_ + carry;

        return *this;
    }

    // Adds the 64-bit number n to this address (modulo 2^128).
    TIpAddr& TIpAddr::operator+=(ui64 n) {
        const ui64 prevLow = Low_;

        Low_ = prevLow + n;
        // the low word got smaller only if the addition overflowed
        High_ += (Low_ < prevLow) ? 1 : 0;

        return *this;
    }

    // 128-bit subtraction of `other` from this address (modulo 2^128).
    TIpAddr& TIpAddr::operator-=(const TIpAddr& other) {
        const ui64 prevLow = Low_;

        Low_ = prevLow - other.Low_;
        // the low word got bigger only if the subtraction wrapped around
        const ui64 borrow = (Low_ > prevLow) ? 1 : 0;
        High_ = High_ - other.High_ - borrow;

        return *this;
    }

    // Subtracts the 64-bit number n from this address (modulo 2^128).
    TIpAddr& TIpAddr::operator-=(ui64 n) {
        const ui64 prevLow = Low_;

        Low_ = prevLow - n;
        // the low word got bigger only if the subtraction wrapped around
        High_ -= (Low_ > prevLow) ? 1 : 0;

        return *this;
    }

    // Shifts the 128-bit value High_:Low_ left by n bits, filling with zeros.
    // Shifting by 128 or more clears the address; shifting by 0 leaves it unchanged.
    TIpAddr& TIpAddr::operator<<=(unsigned n) {
        if (n == 0) {
            return *this;
        }

        if (n >= 128) {
            High_ = 0;
            Low_ = 0;
        } else if (n >= 64) {
            // the whole low word moves into the high word
            High_ = Low_ << (n - 64);
            Low_ = 0;
        } else {
            // 1 <= n <= 63: the top n bits of Low_ become the bottom n bits of High_
            High_ = (High_ << n) | (Low_ >> (64 - n));
            Low_ <<= n;
        }

        return *this;
    }

    // Shifts the 128-bit value High_:Low_ right by n bits, filling with zeros.
    // Shifting by 128 or more clears the address; shifting by 0 leaves it unchanged.
    TIpAddr& TIpAddr::operator>>=(unsigned n) {
        if (n == 0) {
            return *this;
        }

        if (n >= 128) {
            High_ = 0;
            Low_ = 0;
        } else if (n >= 64) {
            // the whole high word moves into the low word
            Low_ = High_ >> (n - 64);
            High_ = 0;
        } else {
            // 1 <= n <= 63: the bottom n bits of High_ become the top n bits of Low_
            Low_ = (Low_ >> n) | (High_ << (64 - n));
            High_ >>= n;
        }

        return *this;
    }

    // Serializes the address for ParseBase64(): 4 bytes for IPv4, 16 bytes for IPv6,
    // passed through Bin2base64url().
    //
    // The bytes are written explicitly in little-endian order (byte 0 is the least
    // significant byte of Low_, bytes 8..15 hold High_ the same way), which is exactly the
    // order ParseBase64() reads them in. Dumping the in-memory representation of the words
    // instead would produce the same bytes only on little-endian hosts.
    TString TIpAddr::ToBase64String() const {
        std::array<ui8, 16> raw{};

        for (size_t idx = 0; idx < 8; ++idx) {
            raw[idx] = static_cast<ui8>(Low_ >> (8 * idx));
            raw[8 + idx] = static_cast<ui8>(High_ >> (8 * idx));
        }

        // for IPv4 only the four address octets are emitted, the 0xffff marker is implied
        const size_t len = IsIpv4() ? 4 : 16;
        return Bin2base64url(TStringBuf(reinterpret_cast<const char*>(raw.data()), len));
    }

    // Returns the full 128-bit value as a 16-byte binary string, most significant byte first.
    TString TIpAddr::ToPackedStringV6() const {
        const std::array<ui8, 16> packed = ToBytesIpv6(Low_, High_);
        const char* const begin = (const char*)packed.data();
        return TString(begin, packed.size());
    }

    // Returns the address as a binary string, most significant byte first:
    // 4 bytes for an IPv4 address, 16 bytes (see ToPackedStringV6()) otherwise.
    TString TIpAddr::ToPackedStringShortest() const {
        if (!IsIpv6()) {
            const std::array<ui8, 4> octets = ToBytesIpv4(Low_);
            const char* const begin = (const char*)octets.data();
            return TString(begin, octets.size());
        }

        return ToPackedStringV6();
    }

    // Renders the address as text, but for IPv6 addresses the lower 64 bits are zeroed
    // first, so only the upper half is shown. IPv4 addresses are rendered unchanged.
    TString TIpAddr::ToStringHalfV6() const {
        TIpAddr shown(*this);

        if (!IsIpv4()) {
            shown.Low_ = 0;
        }

        return shown.ToString();
    }

    bool TIpAddr::IsLoopback() const {
        if (IsIpv6()) {
            // ::1/128
            return High_ == 0 && Low_ == 1;
        }

        // 127.0.0.0/8 (RFC 990)
        return ToBytesIpv4(Low_)[0] == 127;
    }

    bool TIpAddr::IsPrivate() const {
        if (IsIpv6()) {
            // https://a.yandex-team.ru/arc/trunk/arcadia/contrib/python/netaddr/netaddr/ip/__init__.py?rev=5424435#L1893
            // IPNetwork('fc00::/7'),  #   Unique Local Addresses (ULA)
            // IPNetwork('fec0::/10'), #   Site Local Addresses (deprecated - RFC 3879)
            const std::array<ui8, 16> bytes = ToBytesIpv6(Low_, High_);
            if ((bytes[0] & 0xfe) == 0xfc || (bytes[0] == 0xfe && (bytes[1] & 0xc0) == 0xc0)) {
                return true;
            }

            // is link local: RFCs 3927 and 4291
            // fe80::/10
            return bytes[0] == 0xfe && (bytes[1] & 0xc0) == 0x80;
        }

        /*
         * https://a.yandex-team.ru/arc/trunk/arcadia/contrib/python/netaddr/netaddr/ip/__init__.py?rev=6686413#L1859
         * IPNetwork('10.0.0.0/8'),        #   Class A private network local communication (RFC 1918)
         * IPNetwork('100.64.0.0/10'),     #   Carrier grade NAT (RFC 6598)
         * IPNetwork('172.16.0.0/12'),     #   Private network - local communication (RFC 1918)
         * IPNetwork('192.0.0.0/24'),      #   IANA IPv4 Special Purpose Address Registry (RFC 5736)
         * IPNetwork('192.168.0.0/16'),    #   Class B private network local communication (RFC 1918)
         * IPNetwork('198.18.0.0/15'),     #  Testing of inter-network communications between subnets (RFC 2544)
         * IPRange('239.0.0.0', '239.255.255.255'),    #   Administrative Multicast
         */

        const std::array<ui8, 4> bytes = ToBytesIpv4(Low_);

        if ((bytes[0] == 10) ||
            (bytes[0] == 100 && (bytes[1] & 0xc0) == 64) ||
            (bytes[0] == 172 && (bytes[1] & 0xf0) == 16) ||
            (bytes[0] == 192 && bytes[1] == 0 && bytes[2] == 0) ||
            (bytes[0] == 192 && bytes[1] == 168) ||
            (bytes[0] == 198 && (bytes[1] & 0xfe) == 18) ||
            (bytes[0] == 239))
        {
            return true;
        }

        // is link local
        // 169.254.0.0/16
        return bytes[0] == 169 && bytes[1] == 254;
    }

    // True for multicast addresses: ff00::/8 for IPv6 and 224.0.0.0/4 for IPv4.
    bool TIpAddr::IsMulticast() const {
        if (IsIpv6()) {
            // ff00::/8 - the most significant byte of the address is 0xff
            return (High_ >> 56) == 0xff;
        }

        // 224.0.0.0/4 - the top four bits of the first octet (bits 28..31 of Low_) are 1110
        return ((Low_ >> 28) & 0x0f) == 0x0e;
    }
}
