#include "tnetipv6.h"
#include "tnetipv6_ut_helpers.h"

#include <library/cpp/testing/unittest/registar.h>

namespace detail
{

// Probes every `step`-th address in [begin, end) and fails the current test
// on the first address whose membership in `netIP` differs from `expected`.
//
//   netIP    - net list under test.
//   begin    - first address to probe.
//   end      - exclusive upper bound of the probed range.
//   step     - distance between consecutive probes; must be positive for a
//              non-empty range, otherwise `it` never reaches `end`.
//   message  - suffix appended to the offending address in the failure report.
//   expected - required result of IsIncludeNet for every probed address.
template<typename Iterator>
void IsIncludeRange(TNetIPv6& netIP, const Iterator& begin, const Iterator& end, const typename Iterator::difference_type& step, const TString& message, bool expected) {
   for (Iterator it = begin; it < end; it += step) {
      if (netIP.IsIncludeNet(*it) != expected) {
         // Put the offending address and condition into the failure itself,
         // instead of printing them to stdout and asserting a context-free literal.
         UNIT_FAIL("wrong detect: " << it->toStroka() << message);
      }
   }
}

// Fails the current test if any probed address in [begin, end) is NOT in `netIP`.
template<typename Iterator>
void IsIncludeRange(TNetIPv6& netIP, const Iterator& begin, const Iterator& end, const typename Iterator::difference_type& step, const TString& message) {
   const bool mustBeIncluded = true;
   IsIncludeRange(netIP, begin, end, step, message, mustBeIncluded);
}

// Fails the current test if any probed address in [begin, end) IS in `netIP`.
template<typename Iterator>
void IsNotIncludeRange(TNetIPv6& netIP, const Iterator& begin, const Iterator& end, const typename Iterator::difference_type& step, const TString& message) {
   const bool mustBeIncluded = false;
   IsIncludeRange(netIP, begin, end, step, message, mustBeIncluded);
}

// Loads `condition` into a fresh TNetIPv6 and samples three regions of the
// Iterator's address space:
//   [Iterator::begin(), left)  - must NOT match,
//   [left, right]              - must match,
//   (right, Iterator::end())   - must NOT match.
// Each region is probed from a start offset given by helpers::get_random_start
// with a stride given by helpers::get_step.
template<typename Iterator>
void TestIPByRangeImpl(const TKIPv6& left, const TKIPv6& right, const TString& condition) {
   using TDiff = typename Iterator::difference_type;

   TNetIPv6 netIP;
   netIP.Init(nullptr, nullptr);
   netIP.ReloadMemList(condition);

   // Region below the range.
   Iterator lower(left);
   TDiff span = lower - Iterator::begin();
   TDiff offset = helpers::get_random_start(span);
   IsNotIncludeRange<Iterator>(netIP, Iterator::begin() + offset, lower, helpers::get_step(span), " is not in " + condition);

   // The range itself; `upper` is advanced one past `right` so that `right` is probed too.
   Iterator upper(right);
   span = upper - lower;
   offset = helpers::get_random_start(span);
   ++upper;
   IsIncludeRange<Iterator>(netIP, lower + offset, upper, helpers::get_step(span), " must be in " + condition);

   // Region above the range, starting just past `right`.
   span = Iterator::end() - upper;
   offset = helpers::get_random_start(span);
   IsNotIncludeRange<Iterator>(netIP, upper + offset, Iterator::end(), helpers::get_step(span), " is not in " + condition);
}

// Largest 32-bit IPv4 address value (255.255.255.255), held as a 64-bit integer.
static constexpr ui64 max_ipv4 = 0xFFFFFFFFULL;
void TestIPv4ByMask(const ui8 ip_mask) {
   const ui64 mask = max_ipv4 >> ip_mask;
   ui64 low = helpers::get_random(max_ipv4);

   const TString ipv4_with_mask = TKIPv6(0, helpers::make_ipv4(low)).toStroka() + "/" + UI64ToStroka(ip_mask);

   low &= ~mask;
   const TKIPv6 left(0, helpers::make_ipv4(low));
   const TKIPv6 right(0, helpers::make_ipv4(low | mask));

   TestIPByRangeImpl<helpers::iteratorIPv4>(left, right, ipv4_with_mask);
}

void TestIPv4ByRange(const TKIPv6& left, const TKIPv6& right) {
   const TString ipv4_range = left.toStroka() + "\t" + right.toStroka();
   TestIPByRangeImpl<helpers::iteratorIPv4>(left, right, ipv4_range);
}

// All-ones 64-bit value.
static const ui64 max64 = std::numeric_limits<ui64>::max();
// IPv6 address with both 64-bit halves set to all ones.
static const TKIPv6 max_ipv6(max64, max64);

void TestIPv6ByMask(const ui8 ip_mask) {
   const helpers::ip_difference mask = helpers::ip_difference(max_ipv6) >> ip_mask;
   helpers::ip_difference ipv6 = helpers::get_random(max_ipv6);

   const TString ipv6_with_mask = ipv6.toIP().toStroka() + "/" + UI64ToStroka(ip_mask);

   ipv6 &= ~mask;
   const TKIPv6 left = ipv6.toIP();
   const TKIPv6 right = (ipv6 | mask).toIP();

   TestIPByRangeImpl<helpers::iteratorIPv6>(left, right, ipv6_with_mask);
}

void TestIPv6ByRange(const TKIPv6& left, const TKIPv6& right) {
   const TString ipv6_range = left.toStroka() + "\t" + right.toStroka();
   TestIPByRangeImpl<helpers::iteratorIPv6>(left, right, ipv6_range);
}

} //detail

Y_UNIT_TEST_SUITE(TestNetIPv6) {
   Y_UNIT_TEST(Simple) {
      // A '#' comment line ("комментарий" is Russian for "comment") followed by
      // an IPv6 net, an IPv6 range, an IPv4 net, an IPv4 range and a host name.
      const TString data =
         "# - комментарий\n"
         "3a02:2f0c:3000::/20\n"
         "2a02:2f0c:3000::1 2a02:2f0c:3000::2\n"
         "128.0.0.1/24\n"
         "127.0.0.1 127.0.0.2\n"
         "yandex.ru\n";

      TNetIPv6 NetIP;
      NetIP.Init(nullptr, nullptr);
      NetIP.ReloadMemList(data);

      // Host names: the listed one matches, an unlisted one does not.
      UNIT_ASSERT_EQUAL(NetIP.IsIncludeNet("yandex.ru"), true);
      UNIT_ASSERT_EQUAL(NetIP.IsIncludeNet("yandex.com"), false);

      // Addresses paired with whether the list above must contain them.
      struct TProbe {
         const char* Address;
         bool Included;
      };
      const TProbe probes[] = {
         {"127.0.0.1", true},           // first address of 127.0.0.1 - 127.0.0.2
         {"127.0.0.3", false},          // just past that range
         {"128.0.0.86", true},          // inside 128.0.0.1/24
         {"128.0.1.86", false},         // the next /24
         {"2a02:2f0c:3000::1", true},   // first address of the IPv6 range
         {"2a02:2f0c:3000::3", false},  // just past the IPv6 range
         {"3a02:2eee:eeee::0", true},   // same leading 20 bits (0x3a022) as 3a02:2f0c:3000::/20
         {"3a02:3f00:0000::0", false},  // leading 20 bits are 0x3a023
      };
      for (const TProbe& probe : probes) {
         UNIT_ASSERT_EQUAL(NetIP.IsIncludeNet(TKIPv6(probe.Address)), probe.Included);
      }
   }

   Y_UNIT_TEST(IPv4Mask) {
      // 128 rounds cycling through IPv4 prefix lengths: /1, /32, /31, ..., /2, /1, /32, ...
      for (ui32 round = 128; round != 0; --round) {
         const ui8 prefix = static_cast<ui8>(round % 32 + 1);
         detail::TestIPv4ByMask(prefix);
      }
   }

   Y_UNIT_TEST(IPv4Range) {
      // 128 random IPv4 ranges; the two random bounds are ordered before use.
      for (ui32 round = 0; round < 128; ++round) {
         const ui64 first = helpers::get_random(detail::max_ipv4);
         const ui64 second = helpers::get_random(detail::max_ipv4);
         const ui64 low = first > second ? second : first;
         const ui64 high = first > second ? first : second;

         detail::TestIPv4ByRange(TKIPv6(0, helpers::make_ipv4(low)), TKIPv6(0, helpers::make_ipv4(high)));
      }
   }

   Y_UNIT_TEST(IPv6Mask) {
      // Every IPv6 prefix length from /128 down to /1.
      for (int prefix = 128; prefix >= 1; --prefix) {
         detail::TestIPv6ByMask(static_cast<ui8>(prefix));
      }
   }

   Y_UNIT_TEST(IPv6Range) {
      // 128 random IPv6 ranges; the two random bounds are ordered before use.
      for (ui32 round = 0; round < 128; ++round) {
         TKIPv6 lo = helpers::get_random(detail::max_ipv6);
         TKIPv6 hi = helpers::get_random(detail::max_ipv6);
         if (lo > hi) {
            std::swap(lo, hi);
         }

         detail::TestIPv6ByRange(lo, hi);
      }
   }

   Y_UNIT_TEST(ZeroMask) {
      // For each prefix length, clear the top `mask` bits of a fixed address and
      // check that the result is contained in its own "<address>/<mask>" net.
      // NOTE(review): relies on helpers::ip_difference's >> acting as a logical
      // 128-bit shift (TestIPv6ByMask relies on the same) - confirm in helpers.
      const TString ipv6 = "471e:35a5:da3d:11d5:92db:a01d:55ce:e097";
      // The source address is loop-invariant, so parse it once.
      const TKIPv6 source(ipv6.c_str());

      for (ui32 mask = 127; mask > 0; --mask) {
         helpers::ip_difference tmp = source;
         // All-ones shifted right by `mask` keeps only the low (host) bits.
         tmp &= helpers::ip_difference(detail::max_ipv6) >> mask;
         const TKIPv6 ipv6_with_zero = tmp.toIP();

         TNetIPv6 netIP;
         netIP.Init(nullptr, nullptr);
         // ReloadMemList takes the TString directly; no c_str() round-trip needed.
         netIP.ReloadMemList(ipv6_with_zero.toStroka() + "/" + UI64ToStroka(mask));

         UNIT_ASSERT_EQUAL(netIP.IsIncludeNet(ipv6_with_zero), true);
      }
   }

   Y_UNIT_TEST(RealFail) {
      TNetIPv6 netIP;
      netIP.Init(nullptr, nullptr);
      // A /1 net whose address starts with 0x471e, i.e. leading bit 0.
      netIP.ReloadMemList("471e:35a5:da3d:11d5:92db:a01d:55ce:e096/1");

      // The probe's leading bit is 0 as well, so it must be inside that net.
      const TKIPv6 probe("0000:196b:b313:0c53:7060:ad30:da41:3174");
      UNIT_ASSERT_EQUAL(netIP.IsIncludeNet(probe), true);
   }

   Y_UNIT_TEST(RealFailIPv4) {
      TNetIPv6 netIP;
      netIP.Init(nullptr, nullptr);
      // A /1 net written as an IPv4-mapped address; 31 = 0b00011111, leading bit 0.
      netIP.ReloadMemList("::ffff:31.137.213.173/1");

      // 0.0.26.115 also has leading bit 0, so it must be inside that net.
      const TKIPv6 probe("::ffff:0.0.26.115");
      UNIT_ASSERT_EQUAL(netIP.IsIncludeNet(probe), true);
   }

   Y_UNIT_TEST(GetAddress) {
      TNetIPv6 netIP;
      netIP.Init(nullptr, nullptr);

      const TKIPv6 low("127.0.1.0");
      const TKIPv6 high("127.1.0.0");
      TWIPv6B value;

      // A single address: both bounds equal the address.
      UNIT_ASSERT_EQUAL(netIP.GetAddresses("127.0.1.0", value), TNetIPv6::TIPADDRESS);
      UNIT_ASSERT_EQUAL(value.ip1, value.ip2);
      UNIT_ASSERT_EQUAL(value.ip1, low);

      // Two space-separated addresses: a range with those bounds.
      UNIT_ASSERT_EQUAL(netIP.GetAddresses("127.0.1.0 127.1.0.0", value), TNetIPv6::TIPRANGE);
      UNIT_ASSERT_EQUAL(value.ip1, low);
      UNIT_ASSERT_EQUAL(value.ip2, high);

      // A /32 net: classified as a net whose bounds collapse to one address.
      UNIT_ASSERT_EQUAL(netIP.GetAddresses("127.0.1.0/32", value), TNetIPv6::TIPNET);
      UNIT_ASSERT_EQUAL(value.ip1, value.ip2);
      UNIT_ASSERT_EQUAL(value.ip1, low);
   }
}
