#include <crypta/lib/native/dns/addr_hash.h>

#include <library/cpp/testing/unittest/registar.h>
#include <util/generic/hash.h>

// Map from an opaque socket address to a string, using the project's
// NCrypta::TAddrOnlyHash / NCrypta::TAddrOnlyEqual functors for hashing and
// key comparison. The tests below expect keys that differ only in port to be
// treated as the same key.
using TAddrOnlyMap = THashMap<NAddr::TOpaqueAddr, TString, NCrypta::TAddrOnlyHash, NCrypta::TAddrOnlyEqual>;

// Returns a sockaddr_in6 whose bytes are all zero except for the address
// family, which is set to AF_INET6.
sockaddr_in6 GetEmptyIpv6Addr() {
    sockaddr_in6 blank;
    ::memset(&blank, 0, sizeof(blank));
    blank.sin6_family = AF_INET6;
    return blank;
}

// Returns a sockaddr_in whose bytes are all zero except for the address
// family, which is set to AF_INET.
sockaddr_in GetEmptyIpv4Addr() {
    sockaddr_in blank;
    ::memset(&blank, 0, sizeof(blank));
    blank.sin_family = AF_INET;
    return blank;
}

// Wraps the IPv6 loopback address (IN6ADDR_LOOPBACK_INIT) in a TOpaqueAddr.
// `port` is stored into sin6_port exactly as given; no byte-order conversion
// is applied here.
NAddr::TOpaqueAddr GetIpv6Loopback(in_port_t port = 0) {
    sockaddr_in6 raw = GetEmptyIpv6Addr();
    raw.sin6_port = port;
    raw.sin6_addr = IN6ADDR_LOOPBACK_INIT;
    return NAddr::TOpaqueAddr(reinterpret_cast<sockaddr*>(&raw));
}

// Wraps the IPv4 loopback address (127.0.0.1) in a TOpaqueAddr.
// `port` is stored into sin_port exactly as given; no byte-order conversion
// is applied to it here.
NAddr::TOpaqueAddr GetIpv4Loopback(in_port_t port = 0) {
    auto in = GetEmptyIpv4Addr();
    // INADDR_LOOPBACK is defined in host byte order, while s_addr must hold
    // the address in network byte order. Without htonl() a little-endian
    // host would end up with 1.0.0.127 rather than 127.0.0.1.
    in.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    in.sin_port = port;
    return NAddr::TOpaqueAddr((sockaddr*)&in);
}

// Builds an IPv6 TOpaqueAddr in which every byte of sin6_addr equals `fill`.
// `port` is stored into sin6_port exactly as given.
NAddr::TOpaqueAddr GetIpv6Addr(uint8_t fill, in_port_t port = 0) {
    sockaddr_in6 raw = GetEmptyIpv6Addr();
    raw.sin6_port = port;
    ::memset(&raw.sin6_addr, fill, sizeof(raw.sin6_addr));
    return NAddr::TOpaqueAddr(reinterpret_cast<sockaddr*>(&raw));
}

// Builds an IPv4 TOpaqueAddr in which every byte of sin_addr equals `fill`.
// `port` is stored into sin_port exactly as given.
NAddr::TOpaqueAddr GetIpv4Addr(uint8_t fill, in_port_t port = 0) {
    sockaddr_in raw = GetEmptyIpv4Addr();
    raw.sin_port = port;
    ::memset(&raw.sin_addr, fill, sizeof(raw.sin_addr));
    return NAddr::TOpaqueAddr(reinterpret_cast<sockaddr*>(&raw));
}

// Tests for TAddrOnlyMap, i.e. THashMap keyed by NAddr::TOpaqueAddr with the
// NCrypta::TAddrOnlyHash / NCrypta::TAddrOnlyEqual functors.
Y_UNIT_TEST_SUITE(AddHashTest) {
    // A single IPv6 key can be inserted, looked up and erased.
    Y_UNIT_TEST(OneIpv6) {
        TAddrOnlyMap map;
        auto addr = GetIpv6Loopback();

        map[addr] = "loopback-ip6";
        UNIT_ASSERT(map.contains(addr));
        UNIT_ASSERT_VALUES_EQUAL(map.size(), 1u);

        map.erase(addr);
        UNIT_ASSERT(map.empty());
    }

    // A single IPv4 key can be inserted, looked up and erased.
    Y_UNIT_TEST(OneIpv4) {
        TAddrOnlyMap map;
        auto addr = GetIpv4Loopback();

        map[addr] = "loopback-ip";
        UNIT_ASSERT(map.contains(addr));
        UNIT_ASSERT_VALUES_EQUAL(map.size(), 1u);

        map.erase(addr);
        UNIT_ASSERT(map.empty());
    }

    // Shared scenario for one address family.
    //   GetLoopback(port)   -> loopback address with the given port
    //   GetAddr(fill, port) -> address whose address bytes are all `fill`
    // Checks that two keys differing only in port are the same map key, while
    // a key with different address bytes is a separate entry.
    template<typename TGetLoopback, typename TGetAddr>
    void RunMultiple(TGetLoopback GetLoopback, TGetAddr GetAddr) {
        auto loopback1 = GetLoopback(1);
        auto loopback2 = GetLoopback(2);

        TAddrOnlyMap map;
        map[loopback1] = "port does not count";
        UNIT_ASSERT(map.contains(loopback1));
        UNIT_ASSERT(map.contains(loopback2));
        UNIT_ASSERT_VALUES_EQUAL(map.size(), 1u);

        auto addrFf = GetAddr(0xff, 0);
        map[addrFf] = "but sin6_addr (sin_addr) does";

        UNIT_ASSERT_VALUES_EQUAL(map.size(), 2u);
        UNIT_ASSERT(map.contains(addrFf));
        UNIT_ASSERT(map.contains(loopback1));
        UNIT_ASSERT_VALUES_EQUAL(map[loopback1], map[loopback2]);
        UNIT_ASSERT_VALUES_UNEQUAL(map[loopback1], map[addrFf]);

        // Erasing through the other port must remove the shared entry and
        // must leave the unrelated address in place.
        map.erase(loopback2);
        UNIT_ASSERT(!map.contains(loopback1));
        UNIT_ASSERT(!map.contains(loopback2));
        UNIT_ASSERT(map.contains(addrFf));
        UNIT_ASSERT_VALUES_EQUAL(map.size(), 1u);
    }

    // Runs the shared scenario for both IPv6 and IPv4.
    Y_UNIT_TEST(MultipleAddresses) {
        RunMultiple<>(GetIpv6Loopback, GetIpv6Addr);
        RunMultiple<>(GetIpv4Loopback, GetIpv4Addr);
    }

    // IPv6 and IPv4 loopback addresses must be distinct keys in the same map.
    Y_UNIT_TEST(MixedIpVersions) {
        auto loopback6 = GetIpv6Loopback();
        auto loopback4 = GetIpv4Loopback();

        TAddrOnlyMap map;
        map[loopback6] = "ip6-loopback";
        map[loopback4] = "loopback";

        UNIT_ASSERT_VALUES_EQUAL(map.size(), 2u);
        UNIT_ASSERT_VALUES_UNEQUAL(map[loopback6], map[loopback4]);

        // Removing the IPv4 key must not affect the IPv6 one.
        map.erase(loopback4);
        UNIT_ASSERT(!map.contains(loopback4));
        UNIT_ASSERT(map.contains(loopback6));
        UNIT_ASSERT_VALUES_EQUAL(map.size(), 1u);
    }
}
