#include <util/system/yassert.h>
#include <util/generic/yexception.h>
#include <util/string/cast.h>

#include <iostream>
#include <stdlib.h>
#include <errno.h>
#include <stdio.h>

#include "ipv6.h"
#include "tkipv6.h"

// Fallback for platforms whose socket headers do not provide IN6_ARE_ADDR_EQUAL:
// two in6_addr values are equal when every byte of the structures matches.
#if !defined(IN6_ARE_ADDR_EQUAL)
#define IN6_ARE_ADDR_EQUAL(a, b) \
    (memcmp((const void*)(a), (const void*)(b), sizeof(struct in6_addr)) == 0)
#endif

// Two IPv4 addresses are equal when their raw 32-bit s_addr values match.
static bool operator==(const in_addr& lhs, const in_addr& rhs) {
    const auto lhsValue = lhs.s_addr;
    const auto rhsValue = rhs.s_addr;
    return lhsValue == rhsValue;
}

// Two IPv6 addresses are equal when all of their address bytes match.
static bool operator==(const in6_addr& lhs, const in6_addr& rhs) {
    const bool same = IN6_ARE_ADDR_EQUAL(&lhs, &rhs);
    return same;
}

// Returns true when the leading `maskBits` bits of `addr` equal those of `net`.
//
// `s_addr` is kept in network byte order, so both values are converted to host order
// with ntohl before shifting; shifting the raw value drops the wrong bits on
// little-endian hosts. A mask of 0 matches every address (and avoids an undefined
// shift by 32). Masks wider than 32 bits are treated as /32, i.e. an exact match.
static bool belongs_to_net_with_mask(const in_addr& addr, const in_addr& net, size_t maskBits) {
    if (maskBits == 0)
        return true;
    if (maskBits > 32)
        maskBits = 32;
    const auto diff = ntohl(addr.s_addr) ^ ntohl(net.s_addr);
    return (diff >> (32 - maskBits)) == 0;
}

// Returns true when the leading `maskBits` bits of `addr` equal those of `net`.
//
// The address bytes are compared in network order, byte by byte. The previous version
// compared the low 64 bits for masks <= 64, shifted native-endian ui64 loads, shifted by
// 64 for a zero mask, and type-punned s6_addr through ui64*.
// A mask of 0 matches every address. Masks wider than 128 bits are treated as /128.
static bool belongs_to_net_with_mask(const in6_addr& addr, const in6_addr& net, size_t maskBits) {
    if (maskBits > 128)
        maskBits = 128;

    // Whole bytes covered by the prefix must match exactly.
    const size_t fullBytes = maskBits / 8;
    for (size_t i = 0; i < fullBytes; ++i) {
        if (addr.s6_addr[i] != net.s6_addr[i])
            return false;
    }

    // Compare only the high `restBits` bits of the next byte, if the prefix ends mid-byte.
    const size_t restBits = maskBits % 8;
    if (restBits == 0)
        return true;
    const unsigned lastMask = (0xFFu << (8 - restBits)) & 0xFFu;
    return ((addr.s6_addr[fullBytes] ^ net.s6_addr[fullBytes]) & lastMask) == 0;
}

static bool operator==(const sockaddr_storage& s1, const sockaddr_storage& s2) {
    if (s1.ss_family != s2.ss_family)
        return false;
    if (s1.ss_family == AF_INET6) {
        return (((sockaddr_in6*)&s1)->sin6_addr == ((sockaddr_in6*)&s2)->sin6_addr);
    } else {
        return (((sockaddr_in*)&s1)->sin_addr == ((sockaddr_in*)&s2)->sin_addr);
    }
}

bool TIpAddr::BelongsToNetWithMask(const TIpAddr& addr, const TIpAddr& net, size_t maskSize) {
    if (!addr.IsValid())
        throw yexception() << "addr " << addr.ToString() << " isn't valid";
    if (!net.IsValid())
        throw yexception() << "net " << net.ToString() << " isn't valid";

    if (addr.IsIpv6() != net.IsIpv6())
        throw yexception() << "net " << net.ToString() << " ipv6:" << int(net.IsIpv6())
                           << " but addr " << addr.ToString() << " ipv6:" << int(addr.IsIpv6());

    if (addr.IsIpv6()) {
        return belongs_to_net_with_mask(((sockaddr_in6*)&net.addr_storage)->sin6_addr, ((sockaddr_in6*)&addr.addr_storage)->sin6_addr, maskSize);
    } else {
        return belongs_to_net_with_mask(((sockaddr_in*)&net.addr_storage)->sin_addr, ((sockaddr_in*)&addr.addr_storage)->sin_addr, maskSize);
    }
}

// Two TIpAddr values are equal when:
// - both are valid;
// - both belong to the same family;
// - their stored addresses compare equal via the sockaddr_storage comparison.
// Two invalid addresses never compare equal.
bool TIpAddr::operator==(const TIpAddr& addr) const {
    if (!is_valid || !addr.is_valid)
        return false;
    if (is_ipv6 != addr.is_ipv6)
        return false;
    return addr_storage == addr.addr_storage;
}

// Negation of operator==; note that two invalid addresses are therefore "not equal".
bool TIpAddr::operator!=(const TIpAddr& addr) const {
    const bool equal = operator==(addr);
    return !equal;
}

// Bitwise-ORs the address bytes of `addr` into this address, in place.
// If either operand is invalid or the families differ, this address becomes invalid.
// Only the address bytes change; the family, other sockaddr fields and the source
// string keep their previous values.
TIpAddr& TIpAddr::operator|=(const TIpAddr& addr) {
    if (!addr.is_valid || !is_valid || (addr.is_ipv6 != is_ipv6)) {
        is_valid = false;
        return *this;
    }
    if (is_ipv6) {
        in6_addr& dst = reinterpret_cast<sockaddr_in6*>(&addr_storage)->sin6_addr;
        // `addr` is const: read it through a const view instead of casting constness away.
        const in6_addr& src = reinterpret_cast<const sockaddr_in6*>(&addr.addr_storage)->sin6_addr;
        for (size_t i = 0; i < sizeof(dst.s6_addr); ++i)
            dst.s6_addr[i] |= src.s6_addr[i];
    } else {
        in_addr& dst = reinterpret_cast<sockaddr_in*>(&addr_storage)->sin_addr;
        const in_addr& src = reinterpret_cast<const sockaddr_in*>(&addr.addr_storage)->sin_addr;
        dst.s_addr |= src.s_addr;
    }
    return *this;
}

// Returns a copy of this address with `addr` OR-ed into it. The result is invalid
// if either operand is invalid or the two address families differ.
TIpAddr TIpAddr::operator|(const TIpAddr& addr) const {
    TIpAddr combined = *this;
    combined |= addr;
    return combined;
}

// Bitwise-ANDs the address bytes of `addr` into this address, in place (e.g. to apply
// a network mask). If either operand is invalid or the families differ, this address
// becomes invalid. Only the address bytes change; the family, other sockaddr fields and
// the source string keep their previous values.
TIpAddr& TIpAddr::operator&=(const TIpAddr& addr) {
    if (!addr.is_valid || !is_valid || (addr.is_ipv6 != is_ipv6)) {
        is_valid = false;
        return *this;
    }
    if (is_ipv6) {
        in6_addr& dst = reinterpret_cast<sockaddr_in6*>(&addr_storage)->sin6_addr;
        // `addr` is const: read it through a const view instead of casting constness away.
        const in6_addr& src = reinterpret_cast<const sockaddr_in6*>(&addr.addr_storage)->sin6_addr;
        for (size_t i = 0; i < sizeof(dst.s6_addr); ++i)
            dst.s6_addr[i] &= src.s6_addr[i];
    } else {
        in_addr& dst = reinterpret_cast<sockaddr_in*>(&addr_storage)->sin_addr;
        const in_addr& src = reinterpret_cast<const sockaddr_in*>(&addr.addr_storage)->sin_addr;
        dst.s_addr &= src.s_addr;
    }
    return *this;
}

// Returns a copy of this address with `addr` AND-ed into it. The result is invalid
// if either operand is invalid or the two address families differ.
TIpAddr TIpAddr::operator&(const TIpAddr& addr) const {
    TIpAddr masked = *this;
    masked &= addr;
    return masked;
}

void TIpAddr::MakeMask(int len, bool is_ipv6) {
    this->is_ipv6 = is_ipv6;
    memset(&addr_storage, 0, sizeof(addr_storage));
    if (is_ipv6) {
        addr_storage.ss_family = AF_INET6;
#ifdef __FreeBSD__
        addr_storage.ss_len = 28;
#endif
        in6_addr& ipv6_addr = ((sockaddr_in6*)&addr_storage)->sin6_addr;
        int n_ffs = len / 8;
        for (int i = 0; i < n_ffs; ++i)
            ipv6_addr.s6_addr[i] |= 0xff;

        if(n_ffs < 16) {
            ipv6_addr.s6_addr[n_ffs] = 0xff << (8 - (len % 8));
        }

        for (int i = n_ffs + 1; i < 16; ++i)
            ipv6_addr.s6_addr[i] |= 0;
    } else {
        addr_storage.ss_family = AF_INET;
#ifdef __FreeBSD__
        addr_storage.ss_len = 16;
#endif
        in_addr& ipv4_addr = ((sockaddr_in*)&addr_storage)->sin_addr;
        if (len == 0)
            ipv4_addr.s_addr = 0;
        else if (len >= 32)
            ipv4_addr.s_addr = 0xffffffff;
        else
            ipv4_addr.s_addr = htonl(0xffffffff << (32 - len));
    }
    sSourceString = ::ToString(len);
    is_valid = true;
}

// Formats the address as a reversed DNSBL-style label.
// - IPv4: the four octets of s_addr are printed in reverse order ("4.3.2.1" for 1.2.3.4).
// - IPv6: the 32 hex nibbles are printed dot-separated, starting from the low nibble of
//   the last byte.
// Returns an empty string for an invalid address.
TString TIpAddr::ToDNSBLString() const {
    if (!is_valid)
        return TString("");

    char buf[512];
    if (!is_ipv6) {
        const ui8* octets = reinterpret_cast<const ui8*>(&GetInAddr()->s_addr);
        snprintf(buf, sizeof(buf), "%u.%u.%u.%u",
                 (unsigned)octets[3], (unsigned)octets[2],
                 (unsigned)octets[1], (unsigned)octets[0]);
        return buf;
    }

    static const char digits[] = "0123456789abcdef";
    const ui8* bytes = reinterpret_cast<const ui8*>(GetIn6Addr()->s6_addr);
    char* out = buf;
    for (int pos = 15; pos >= 0; --pos) {
        *out++ = digits[bytes[pos] & 0x0F];
        *out++ = '.';
        *out++ = digits[(bytes[pos] >> 4) & 0x0F];
        *out++ = '.';
    }
    // Overwrite the trailing dot with the terminator.
    *(out - 1) = '\0';
    return buf;
}

// Renders the address in numeric form via getnameinfo(NI_NUMERICHOST).
// Returns an empty string if the address is invalid or getnameinfo fails.
TString TIpAddr::ToString() const {
    if (!is_valid)
        return TString();
    char host[512];
    // Pass the length of the family-specific sockaddr (sockaddr_in / sockaddr_in6)
    // instead of sizeof(sockaddr_storage). `salen` is meant to be the length of the
    // address structure, and a mismatching value may be rejected by some platforms'
    // getnameinfo.
    const socklen_t addrLen = static_cast<socklen_t>(GetAddrSize());
    const int res = getnameinfo(GetAddr(), addrLen, host, sizeof(host), nullptr, 0, NI_NUMERICHOST);
    if (res != 0) {
        //cerr << "getnameinfo failed " << gai_strerror(res) << endl;
        return TString();
    }
    return host;
}

// Parses a numeric IPv4/IPv6 literal (AI_NUMERICHOST, so no name resolution).
//
// Input handling:
// - an optional case-insensitive "ipv6:" prefix is stripped;
// - "::ffff:a.b.c.d" (IPv4-mapped, containing a '.') is stored as plain IPv4;
// - a "%scope" suffix is dropped.
// The raw input is always kept as the source string. On failure the object is left
// invalid.
void TIpAddr::FromString(const TString& str) {
    is_valid = false;
    sSourceString = str;

    TString sIP(str);
#ifdef _MSC_VER
    if (_strnicmp(sIP.c_str(), "ipv6:", 5) == 0)
#else
    if (strncasecmp(sIP.c_str(), "ipv6:", 5) == 0)
#endif
    {
        sIP.erase(0, 5);
    }

    // Convert embedded ipv4 ("::ffff:1.2.3.4") to common ipv4.
#ifdef _MSC_VER
    if (_strnicmp(sIP.c_str(), "::ffff:", 7) == 0 && sIP.find('.') != TString::npos)
#else
    if (strncasecmp(sIP.c_str(), "::ffff:", 7) == 0 && sIP.find('.') != TString::npos)
#endif
    {
        sIP.erase(0, 7);
    }

    // Drop everything from the '%' scope separator to the end of the string.
    const size_t percentPos = sIP.find('%');
    if (percentPos != TString::npos)
        sIP.erase(percentPos, sIP.length() - percentPos);

    addrinfo hint;
    memset(&hint, 0, sizeof(hint));
    hint.ai_flags = AI_NUMERICHOST;
    addrinfo* ai = nullptr;
    const int res = getaddrinfo(sIP.c_str(), nullptr, &hint, &ai);
    if (res != 0) {
        //cerr << "getaddrinfo fail " << gai_strerror(res) << " " << str << endl;
        if (ai)
            freeaddrinfo(ai);
        return;
    }
    if (ai == nullptr)
        return;

    // Only the first result is used. The length check guarantees the copy can never
    // overflow addr_storage.
    if (ai->ai_addr != nullptr && ai->ai_addrlen <= sizeof(addr_storage)) {
        memset(&addr_storage, 0, sizeof(addr_storage));
        memcpy(&addr_storage, ai->ai_addr, ai->ai_addrlen);
        socklen = ai->ai_addrlen;
        is_ipv6 = (ai->ai_family == AF_INET6);
        is_valid = true;
    }
    freeaddrinfo(ai);
}

// Returns a copy of the stored source text. FromString sets it to its raw input, and
// MakeMask sets it to the textual mask length.
TString TIpAddr::GetSourceString() const {
    const TString& source = this->sSourceString;
    return source;
}

bool TIpAddr::IsIpv6() const {
    return is_ipv6;
}

bool TIpAddr::IsValid() const {
    return is_valid;
}

// Exposes the stored address as a generic sockaddr pointer for socket APIs.
// Pair it with GetAddrSize() for the matching length.
const sockaddr* TIpAddr::GetAddr() const {
    const void* raw = &addr_storage;
    return static_cast<const sockaddr*>(raw);
}

// Size in bytes of the family-specific sockaddr (sockaddr_in6 or sockaddr_in).
size_t TIpAddr::GetAddrSize() const {
    return is_ipv6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in);
}

// Pointer to the stored IPv6 address, or nullptr when the object holds IPv4.
const in6_addr* TIpAddr::GetIn6Addr() const {
    if (is_ipv6) {
        const auto* v6 = reinterpret_cast<const sockaddr_in6*>(&addr_storage);
        return &v6->sin6_addr;
    }
    return nullptr;
}

// Pointer to the stored IPv4 address, or nullptr when the object holds IPv6.
const in_addr* TIpAddr::GetInAddr() const {
    if (!is_ipv6) {
        const auto* v4 = reinterpret_cast<const sockaddr_in*>(&addr_storage);
        return &v4->sin_addr;
    }
    return nullptr;
}

// Converts the address into a TKIPv6, passing the first and last 8 address bytes as
// big-endian 64-bit numbers (highpart, lowpart). IPv4 is embedded as an IPv4-mapped
// IPv6 address (::ffff:a.b.c.d). An invalid address yields TKIPv6(0, 0).
TKIPv6 TIpAddr::ToTKIPv6() const {
    ui64 highpart = 0;
    ui64 lowpart = 0;
    if (is_valid) {
        if (is_ipv6) {
            // Assemble both halves byte by byte from the network-order address. This does
            // not depend on host endianness and avoids type-punning s6_addr through ui32*.
            const unsigned char* bytes = GetIn6Addr()->s6_addr;
            for (int i = 0; i < 8; ++i) {
                highpart = (highpart << 8) | bytes[i];
                lowpart = (lowpart << 8) | bytes[i + 8];
            }
        } else {
            // s_addr is in network byte order; ntohl yields its numeric value.
            lowpart = 0x0000FFFF00000000ULL | ntohl(GetInAddr()->s_addr); //ipv4 mapped into ipv6 address space
        }
    }
    return TKIPv6(highpart, lowpart);
}

// Builds the address by parsing `str` with FromString. An unparsable string yields an
// invalid address. is_ipv6 and socklen are value-initialized, as in the default
// constructor: FromString leaves them untouched on failure, and previously IsIpv6() or
// a copy would then read indeterminate values.
TIpAddr::TIpAddr(const TString& str)
    : is_valid(false)
    , is_ipv6()
    , socklen()
{
    FromString(str);
}

TIpAddr::TIpAddr()
    : is_valid(false)
    , is_ipv6()
    , socklen()
{
}

// Nothing to release explicitly; members clean up after themselves.
TIpAddr::~TIpAddr() = default;
