#include "geo_kdtree.h"
#include "mingeo_index.h"

#include <library/cpp/geo/util.h>
#include <library/cpp/testing/unittest/registar.h>
#include <util/stream/buffer.h>
#include <util/generic/bitops.h>
#include <util/generic/buffer.h>
#include <util/string/hex.h>

#include <algorithm>
#include <array>
#include <random>

namespace NRTYGeo {
    using NGeo::TGeoPoint;
    using NGeo::TGeoWindow;
    using NGeo::TSize;

    class TGeoModelTest: public NUnitTest::TTestBase {
        UNIT_TEST_SUITE(TGeoModelTest)
        UNIT_TEST(TestNormalization);
        UNIT_TEST(TestSearchWithRollovers);
        UNIT_TEST_SUITE_END();

    private:
        static void PrintRect(IOutputStream& s, const TGeoWindow& rect) {
            // we dont want standard TGeoWindow.TString() here
            s << "(" << rect.GetLowerLeftCorner() << " " << rect.GetUpperRightCorner() << ")";
        }

        static void PrintRect(IOutputStream& s, const TVector<TGeoWindow>& normalization) {
            bool first = true;
            for (const TGeoWindow& rect : normalization) {
                if (first)
                    first = false;
                else
                    s << " ";
                PrintRect(s, rect);
            }
        }

        static TString FormatRect(const TGeoWindow& rect) {
            TStringStream s;
            PrintRect(s, rect);
            return s.Str();
        }

        static TString FormatRect(const TVector<TGeoWindow>& normalization) {
            TStringStream s;
            PrintRect(s, normalization);
            return s.Str();
        }

        static TGeoWindow MakeMetricWindow(const TGeoPoint& geodeiticCoords, NGeo::TSize metricSize) {
            return MercatorToLL(NGeo::TMercatorWindow(NGeo::LLToMercator(geodeiticCoords), metricSize));
        }

        static TGeoWindow MakeLonShift(const TGeoWindow& rect, double shift) {
            return TGeoWindow(
                TGeoPoint(rect.GetLowerLeftCorner().Lon() + shift, rect.GetLowerLeftCorner().Lat()),
                TGeoPoint(rect.GetUpperRightCorner().Lon() + shift, rect.GetUpperRightCorner().Lat()));
        }

        static TMaybe<NGeo::TGeoWindow> GetIntersectionEx(const TGeoWindow& obj, const TGeoWindow& rect) {
            TMaybe<NGeo::TGeoWindow> result;
            result = NGeo::Intersection(obj, rect);
            if (result.Defined())
                return result;

            std::array<double, 2> shifts = {-NGeo::WORLD_WIDTH, NGeo::WORLD_WIDTH};
            for (double shift : shifts) {
                TGeoWindow s = MakeLonShift(rect, shift);
                result = NGeo::Intersection(obj, s);
                if (result.Defined())
                    break;
            }
            return result;
        }

        static void ValidateNormalization(const TGeoWindow& obj, const TVector<TGeoWindow>& rects) {
            double coveredArea = 0.0;
            double objArea = obj.Area();
            double rectsArea = 0.0;

            UNIT_ASSERT(rects.size() == 1 || !TGeoIndexTree::IsNormalized(obj));

            for (const NGeo::TGeoWindow& rect : rects) {
                UNIT_ASSERT(rects.size() == 1 || rect.Area() > 0);

                rectsArea += rect.Area();
                TMaybe<NGeo::TGeoWindow> intersection = GetIntersectionEx(obj, rect);
                if (intersection.Defined()) {
                    coveredArea += intersection->Area();
                }
                UNIT_ASSERT(intersection.Defined());
            }

            // Check that objArea == coveredArea, rectsArea == objArea, and the tiles does not intersect
            UNIT_ASSERT_LE_C(objArea - coveredArea, objArea * 1e-6, "the object is not fully covered by the tiles: " << objArea << " " << coveredArea);

            for (size_t i = 0; i < rects.size(); ++i) {
                for (size_t j = i + 1; j < rects.size(); ++j) {
                    TMaybe<NGeo::TGeoWindow> intersection = NGeo::Intersection(rects[i], rects[j]);
                    UNIT_ASSERT_C(!intersection.Defined() || intersection->Area() == 0.0, "some tiles intersect");
                }
            }
            UNIT_ASSERT_LE_C(coveredArea - objArea, objArea * 1e-6, "some tiles intersect");
            UNIT_ASSERT_LE_C(rectsArea - objArea, rectsArea * 1e-6, "the tiles is bigger than the object: " << rectsArea << " " << objArea);
        }

    public:
        void TestNormalization() {
            // Check some samples
            TGeoPoint yandexHQ(/*lon=*/37.6155, /*lat=*/55.7522);

            constexpr int M = 1000000;
            constexpr int K = 1000;

            for (int size : {0, 5, 20, 2 * K, 20 * K, 5 * M, 30 * M}) {
                TGeoWindow w = MakeMetricWindow(yandexHQ, TSize(size, size));
                auto vect = TGeoIndexTree::Normalize(w);
                UNIT_ASSERT(vect.size() == 1);

                const TGeoWindow& expected = w;
                const TGeoWindow& actual = vect[0];

                ValidateNormalization(w, vect);
                UNIT_ASSERT_VALUES_EQUAL(FormatRect(expected), FormatRect(actual));
            }

            // check a case when the window crosses lon=+-180 line
            {
                TGeoWindow rolloverWindow(TGeoPoint(179.99, 33.3), TSize(0.1, 0.1));
                TVector<TGeoWindow> actual = TGeoIndexTree::Normalize(rolloverWindow);
                ValidateNormalization(rolloverWindow, actual);
                UNIT_ASSERT_VALUES_EQUAL("([179.94, 33.24998553] [180, 33.34998553]) ([-180, 33.24998553] [-179.96, 33.34998553])", FormatRect(actual));

                //...and the negative one
                TGeoWindow rolloverWindow2(TGeoPoint(-179.99, 1.3), TSize(0.1, 0.1));
                TVector<TGeoWindow> actual2 = TGeoIndexTree::Normalize(rolloverWindow2);
                ValidateNormalization(rolloverWindow2, actual2);
                UNIT_ASSERT_VALUES_EQUAL("([179.96, 1.249999498] [180, 1.349999498]) ([-180, 1.249999498] [-179.94, 1.349999498])", FormatRect(actual2));
            }

            // check the "whole world" case
            {
                TGeoWindow wholeWorld(TGeoPoint(-180, -90), TGeoPoint(180, 90));
                auto vect = TGeoIndexTree::Normalize(wholeWorld);
                UNIT_ASSERT(vect.size() == 1);
                ValidateNormalization(wholeWorld, vect);
                const TGeoWindow& actual = vect[0];
                UNIT_ASSERT_VALUES_EQUAL(FormatRect(wholeWorld), FormatRect(actual));
            }

            // check 15 "almost whole world" cases
            {
                for (ui32 flags = 1; flags < (1 << 4); ++flags) {
                    const double eps = 1e-3;
                    std::array<double, 4> deltas = {
                        ((flags & 1) ? eps : 0),
                        ((flags & 2) ? eps : 0),
                        ((flags & 4) ? -eps : 0),
                        ((flags & 8) ? -eps : 0)};
                    TGeoWindow almostWholeWorld(TGeoPoint(-180 + deltas[0], -90 + deltas[1]), TGeoPoint(180 + deltas[2], 90 + deltas[3]));
                    auto vect = TGeoIndexTree::Normalize(almostWholeWorld);
                    ValidateNormalization(almostWholeWorld, vect);
                    UNIT_ASSERT(vect.size() == 1);
                    const TGeoWindow& actual = vect[0];
                    UNIT_ASSERT_VALUES_EQUAL(FormatRect(almostWholeWorld), FormatRect(actual));
                }
            }

            //TODO(yrum): add more zero-sized cases, corner cases, rollover test cases
        }

    private:
        static TVector<ui32> MockRequest(const TGeoIndexTree& kdTree, const NGeo::TGeoWindow& reqRange) {
            TVector<ui32> result;
            if (TGeoIndexTree::IsNormalized(reqRange)) {
                result = kdTree.FindIntersections(reqRange, Max<ui32>());
            } else {
                TVector<NGeo::TGeoWindow> parts = TGeoIndexTree::Normalize(reqRange);
                for (const auto& part : parts) {
                    TVector<ui32> partResult = kdTree.FindIntersections(part, Max<ui32>());
                    result.insert(result.end(), partResult.begin(), partResult.end());
                }
            }
            Sort(result);
            auto pEnd = std::unique(result.begin(), result.end());
            result.resize(pEnd - result.begin());
            return result;
        }

    public:
        void TestSearchWithRollovers() {
            // Run Kd test with a small randomized sample, with many rollovers (also works as a smoke test)
            CheckKdTreeRng(10000, 100, 0xDEADDADD);
        }

        void CheckKdTreeRng(ui32 docCount, ui32 reqCount, std::random_device::result_type seed) {
            TVector<NGeo::TGeoWindow> documents;
            std::uniform_real_distribution<double> disLon(NGeo::MIN_LONGITUDE - NGeo::WORLD_WIDTH, NGeo::MAX_LONGITUDE + NGeo::WORLD_WIDTH);
            std::uniform_real_distribution<double> disLat(NGeo::MIN_LATITUDE, NGeo::MAX_LATITUDE);
            std::uniform_real_distribution<double> disRange(0, NGeo::MAX_LATITUDE / 2);
            std::minstd_rand rng(seed);

            auto randomWindow = [&]() {
                TGeoPoint a(disLon(rng), disLat(rng));
                TGeoPoint b(disLon(rng), disLat(rng));
                return TGeoWindow(a, b); // will rearrange coords as requested
            };

            for (ui32 docId = 0; docId < docCount; ++docId) {
                documents.push_back(randomWindow());
            }

            auto getExpectedIntersections = [&](NGeo::TGeoWindow& request) {
                TVector<ui32> result;

                TVector<NGeo::TGeoWindow> requests;
                for (int shift = -2; shift <= 2; ++shift) {
                    auto req = MakeLonShift(request, shift * NGeo::WORLD_WIDTH);
                    if (req.GetUpperRightCorner().Lon() < disLon.a() || req.GetLowerLeftCorner().Lon() > disLon.b())
                        continue;
                    requests.push_back(req);
                }

                for (ui32 docId = 0; docId < documents.size(); ++docId) {
                    for (const auto& request : requests) {
                        if (NGeo::Intersection(documents[docId], request).Defined()) {
                            result.push_back(docId);
                            break;
                        }
                    }
                }
                return result;
            };

            constexpr ui8 streamId = 127;
            constexpr ui64 kps = 255*256;
            TGeoIndex kdIndex;
            TGeoIndexBuilder bld;

            for (ui32 docId = 0; docId < documents.size(); ++docId) {
                bld.AddDoc(streamId, kps, docId, documents[docId]);
            }
            bld.Finalize(kdIndex);

            TGeoIndexTree& kdTree = *kdIndex.Trees[TPartKey{streamId, kps}];

            for (ui32 reqNo = 0; reqNo < reqCount; ++reqNo) {
                TGeoPoint reqCenter(disLon(rng), disLat(rng));
                TSize reqSize(disRange(rng), disRange(rng));
                TGeoWindow req(reqCenter, reqSize);

                TVector<ui32> expected = getExpectedIntersections(req);
                TVector<ui32> actual = MockRequest(kdTree, req);
                bool differ = !std::equal(expected.begin(), expected.end(), actual.begin(), actual.end());
                UNIT_ASSERT_C(!differ, "reqNo: " << reqNo);
            }
        }
    };
}

// Registers the TGeoModelTest suite with the unittest framework (library/cpp/testing/unittest).
UNIT_TEST_SUITE_REGISTRATION(NRTYGeo::TGeoModelTest);
