#include "geo_kdtree.h"

#include <library/cpp/geo/util.h>
#include <util/generic/set.h>
#include <util/generic/utility.h>
#include <util/generic/xrange.h>

#include <algorithm>
#include <cmath>

namespace NRTYGeo {
    using namespace NGeo;

    namespace {
        // Corners of the "standard" coordinate range; IsNormalized() requires windows to lie inside it.
        static const TGeoPoint globalLower(MIN_LONGITUDE, MIN_LATITUDE); // lon, lat
        static const TGeoPoint globalUpper(MAX_LONGITUDE, MAX_LATITUDE); // lon, lat
        // NOTE(review): globalLimits is not referenced by the code visible in this file; consider removing it.
        static const TGeoBoundingBox globalLimits(globalLower, globalUpper);

        // Longitude range spanning three adjacent copies of the world (one extra on each side).
        // Normalize() clamps incoming windows to this range before splitting them.
        static constexpr double EXTENDED_MIN_LONGITUDE = MIN_LONGITUDE - WORLD_WIDTH;
        static constexpr double EXTENDED_MAX_LONGITUDE = MAX_LONGITUDE + WORLD_WIDTH;

        using TRecord = TKdTree::TRecord;

        // Returns the split dimension used one level below a node split on dimIndex
        // (dimensions are used round-robin from the root down).
        size_t NextDimIndex(size_t dimIndex) {
            ++dimIndex;
            return dimIndex % TKdTree::Dimension;
        }

        // Returns the middle element of [begin, end), rounding towards begin. That element is the
        // root (pivot) of the implicit kd-tree stored in the range. An empty range yields begin.
        template <class T>
        static T* Mid(T* begin, T* end) {
            const auto half = (end - begin) / 2;
            return begin + half;
        }

        // Reorders [begin, end) in place into an implicit kd-tree. The middle element becomes the
        // pivot for dimension dimIndex: nothing before it is greater and nothing after it is less in
        // that dimension. Both halves are arranged the same way using the next dimension.
        // The left half is handled recursively and the right half by the loop.
        void MakeKdTree(TRecord* begin, TRecord* end, size_t dimIndex) {
            while (begin != end) {
                TRecord* const pivot = Mid(begin, end);
                const auto byDimension = [dimIndex](const TRecord& lhs, const TRecord& rhs) {
                    return lhs.Coords[dimIndex] < rhs.Coords[dimIndex];
                };
                std::nth_element(begin, pivot, end, byDimension);

                dimIndex = NextDimIndex(dimIndex);
                MakeKdTree(begin, pivot, dimIndex);
                begin = pivot + 1;
            }
        }

        template <typename TOnResult>
        void FindLe(
                const TRecord* begin,
                const TRecord* end,
                TKdTree::TFlags fl,
                const TKdTree::TKey& key,
                size_t dimIndex,
                TOnResult&& onResult) {
            if (begin == end) {
                return;
            }
            if (std::all_of(fl.begin(), fl.end(), [](bool b) { return b; })) {
                for (; begin != end; ++begin) {
                    if (!onResult(*begin)) {
                        return;
                    }
                }
                return;
            }

            const TRecord* mid = Mid(begin, end);
            if (mid->Coords[dimIndex] <= key[dimIndex]) {
                if (std::equal(mid->Coords.begin(), mid->Coords.end(), key.begin(), std::less_equal<>{})) {
                    if (!onResult(*mid)) {
                        return;
                    }
                }

                // left branch
                fl[dimIndex] = true;
                FindLe(begin, mid, fl, key, NextDimIndex(dimIndex), onResult);
                // right branch
                fl[dimIndex] = false;
                FindLe(mid + 1, end, fl, key, NextDimIndex(dimIndex), onResult);
            } else {
                // left branch
                fl[dimIndex] = false;
                FindLe(begin, mid, fl, key, NextDimIndex(dimIndex), onResult);
            }
        }

        // Checks that the non-empty range [begin, end) has the layout MakeKdTree produces. The middle
        // element is the pivot on dimension dimIndex. Every element before it must not be greater than
        // the pivot in that dimension, and every element after it must not be less. Both halves must
        // satisfy the same property recursively with the next dimension.
        // Every element of each half is compared with the pivot, not just the halves' own roots:
        // FindLe's pruning relies on that full property. Cost is O(n log n).
        // @returns "end" if not sorted, "mid" (the pivot) otherwise
        const TRecord* IsSorted(const TRecord* begin, const TRecord* end, size_t dimIndex) {
            Y_ASSERT(begin != end);
            const TRecord* const root = Mid(begin, end);
            const auto& pivot = root->Coords[dimIndex];

            // Same ordering as the comparator MakeKdTree passes to std::nth_element.
            const auto notAbovePivot = [&](const TRecord& r) { return !(pivot < r.Coords[dimIndex]); };
            const auto notBelowPivot = [&](const TRecord& r) { return !(r.Coords[dimIndex] < pivot); };
            if (!std::all_of(begin, root, notAbovePivot) || !std::all_of(root + 1, end, notBelowPivot)) {
                return end;
            }

            const size_t nextDimIndex = NextDimIndex(dimIndex);
            // A recursive call reports failure by returning its own "end" argument.
            if (begin != root && IsSorted(begin, root, nextDimIndex) == root) {
                return end;
            }
            if (root + 1 != end && IsSorted(root + 1, end, nextDimIndex) == end) {
                return end;
            }
            return root;
        }

        // True when [begin, end) is laid out as a kd-tree whose root splits on dimension 0.
        bool IsSorted(const TRecord* begin, const TRecord* end) {
            if (begin == end) {
                return true; // an empty tree is trivially ordered
            }
            return IsSorted(begin, end, 0) != end;
        }

        // A plain (longitude, latitude) pair used only by the normalization helpers below.
        struct TPoint {
            double Lon;
            double Lat;

            TPoint(double lon, double lat)
                : Lon{lon}
                , Lat{lat}
            {
            }

            // Intentionally implicit: TGeoPoint values are passed directly where a TPoint is expected.
            TPoint(const TGeoPoint& p)
                : TPoint(p.Lon(), p.Lat())
            {
            }
        };

        // A plain (lower corner, upper corner) rectangle used only by the normalization helpers below.
        struct TWindow {
            TPoint Lower;
            TPoint Upper;

            TWindow(const TPoint& lower, const TPoint& upper)
                : Lower{lower}
                , Upper{upper}
            {
            }
        };

        // True unless the upper corner lies strictly west of or strictly south of the lower corner.
        // Written with "<" so that NaN coordinates are not reported as disordered.
        bool IsNormalOrder(const TWindow& box) {
            const bool lonReversed = box.Upper.Lon < box.Lower.Lon;
            const bool latReversed = box.Upper.Lat < box.Lower.Lat;
            return !(lonReversed || latReversed);
        }

        // Converts box to a TGeoWindow moved by -shift world widths in longitude; latitudes are kept.
        TGeoWindow AsShiftedWindow(const TWindow& box, int shift) {
            const double offset = shift * WORLD_WIDTH;
            const TGeoPoint lower(box.Lower.Lon - offset, box.Lower.Lat);
            const TGeoPoint upper(box.Upper.Lon - offset, box.Upper.Lat);
            return TGeoWindow(lower, upper);
        }
    }

    template <class TOnResult>
    void TKdTree::FindLe(const TKey& key, TOnResult&& onResult) const {
        return ::NRTYGeo::FindLe(GetData().begin(), GetData().end(), {}, key, 0, onResult);
    }

    size_t TKdTree::CountLe(const TKey& key) const {
        size_t count = 0;
        FindLe(key, [&](const TRecord&) { ++count; return true; });
        return count;
    }

    // Adds the docIds of records matching key (all coordinates <= key) to output. Collection stops
    // once output holds limit entries. output may already contain entries from earlier calls.
    void TKdTree::CountUniqueLe(absl::flat_hash_set<ui32>& output, ui32 limit, const TKey& key) const {
        FindLe(key, [&output, limit](const TRecord& record) {
            if (output.size() < limit) {
                output.insert(record.DocId);
                return output.size() < limit;
            }
            return false; // already saturated: ask the traversal to stop
        });
    }

    // Inserts into result the docIds of all hits intersecting the normalized window w.
    // Throws (Y_ENSURE) if w is not normalized.
    void TGeoIndexTree::FindIntersections(absl::flat_hash_set<ui32>& result, const NGeo::TGeoWindow& w) const {
        Y_ENSURE(IsNormalized(w));
        const TKey key = QueryToKey(w);
        auto collect = [&result](const TRecord& record) {
            result.insert(record.DocId);
            return true;
        };
        FindLe(key, collect);
    }

    // Returns the distinct docIds of hits intersecting w0, sorted ascending and truncated to prunLimit.
    // w0 is permitted to be non-normalized; it is split into normalized parts first.
    TVector<ui32> TGeoIndexTree::FindIntersections(const NGeo::TGeoWindow& w0, ui32 prunLimit) const {
        TVector<NGeo::TGeoWindow> request;
        if (Y_UNLIKELY(!IsNormalized(w0))) {
            request = Normalize(w0);
        } else {
            request.push_back(w0);
        }

        absl::flat_hash_set<ui32> uniqueDocs;
        for (const NGeo::TGeoWindow& part : request) {
            FindIntersections(uniqueDocs, part);
        }

        TVector<ui32> docIds;
        docIds.assign(uniqueDocs.begin(), uniqueDocs.end());
        std::sort(docIds.begin(), docIds.end());
        docIds.crop(prunLimit);
        return docIds;
    }

    size_t TGeoIndexTree::CountHits(const TVector<NGeo::TGeoWindow>& request) const {
        size_t numHits = 0;
        for (const NGeo::TGeoWindow& w: request) {
            Y_ASSERT(IsNormalized(w));
            numHits += CountHits(w.GetLowerLeftCorner(), w.GetUpperRightCorner());
        }
        return numHits;
    }

    // If mayExpand, first grows w0 via ExpandToDocCount towards nDoc documents. The growth is capped
    // by *maxSize, or by the world size when maxSize is null.
    // Then, if mayCollapse and either no expansion happened or an explicit maxSize was given, returns
    // the nDoc closest documents (FindNClosest). Otherwise returns all documents intersecting the window.
    TVector<ui32> TGeoIndexTree::FindIntersectionsWithDynamicPruning(const NGeo::TGeoWindow& w0, ui32 nDoc, bool mayCollapse, bool mayExpand, const NGeo::TSize* maxSize) const {
        static const NGeo::TSize noSizeLimit{NGeo::WORLD_WIDTH, NGeo::WORLD_HEIGHT};
        NGeo::TGeoWindow window = w0;
        bool wasExpanded = false;
        if (mayExpand) {
            static constexpr float minExpansionSize = 1e-3;
            const NGeo::TSize& sizeLimit = maxSize ? *maxSize : noSizeLimit;
            const auto expansion = ExpandToDocCount(window, minExpansionSize, sizeLimit, nDoc);
            window = expansion.first;
            wasExpanded = expansion.second;
        }

        const bool collapse = mayCollapse && (!wasExpanded || maxSize != nullptr);
        return collapse ? FindNClosest(window, nDoc) : FindIntersections(window, Max<ui32>());
    }

    bool TGeoIndexTree::IsEnoughDocsPredicate(const TVector<NGeo::TGeoWindow>& request, ui32 requiredCount) const {
        if (CountHits(request) < requiredCount) {
            return false;
        }

        absl::flat_hash_set<ui32> output;
        for (const NGeo::TGeoWindow& w: request) {
            Y_ASSERT(IsNormalized(w));
            TKey key = QueryToKey(w.GetLowerLeftCorner(), w.GetUpperRightCorner());
            CountUniqueLe(output, requiredCount, key);
            if (output.size() >= requiredCount)
                break;
        }
        return output.size() >= requiredCount;
    }

    // Grows w0 around its centre until at least requiredCount distinct documents intersect it. The
    // growth is capped by `limits`, a window of size maxSize centred on w0's centre.
    // Strategy: double the half-extents up to 20 times. Once a large-enough multiplier is found, run up
    // to 5 bisection steps between the last too-small and the first large-enough multiplier to shrink
    // the result back.
    // @param w0            query window; may be non-normalized
    // @param epsilonSize   if w0 is smaller than this (by NGeo::SquaredDiagonal), it is first inflated so
    //                      that doubling makes progress
    // @param maxSize       cap on the size of the expanded window
    // @param requiredCount number of distinct documents wanted
    // @returns {w0, false} if w0 already has enough documents; otherwise {expanded window, true}.
    //          NOTE: "true" is also returned when growth stopped (cap reached, or 20 doublings done)
    //          without finding enough documents.
    std::pair<NGeo::TGeoWindow, bool> TGeoIndexTree::ExpandToDocCount(const NGeo::TGeoWindow& w0, const double epsilonSize, const NGeo::TSize& maxSize, ui32 requiredCount) const {
        // w0 is permitted to not be normalized
        Y_ENSURE(w0.IsValid() && maxSize.IsValid());
        TVector<NGeo::TGeoWindow> request = Normalize(w0);
        if (IsEnoughDocsPredicate(request, requiredCount))
            return {w0, false}; // no changes in window

        NGeo::TGeoWindow w = w0;
        if (NGeo::SquaredDiagonal(w.GetSize(), w.GetCenter().Lat()) < Sqr(epsilonSize)) {
            // to double the size of window iteratively, the initial size must be non-zero
            // (the width is scaled by 1 / GetLatCos(lat), floored at 1e-6 and capped at WORLD_WIDTH)
            const double minWidth = Min(epsilonSize / Max(GetLatCos(w.GetCenter().Lat()), 1e-6), NGeo::WORLD_WIDTH);
            const double minHeight = epsilonSize;
            NGeo::TSize newSize{Max(w.GetSize().GetWidth(), minWidth), Max(w.GetSize().GetHeight(), minHeight)};
            w = NGeo::TGeoWindow(w.GetCenter(), newSize);
        }

        // mult scales the half-extents of w; prevMult is the previous (too small) multiplier
        double prevMult = 1.0f;
        double mult = 1.0f;
        NGeo::TGeoPoint prevLower, prevUpper;
        const double height0 = w.GetSize().GetHeight() / 2;
        const double width0 = w.GetSize().GetWidth() / 2;
        const NGeo::TGeoPoint p = w.GetCenter();
        const NGeo::TGeoWindow limits(w.GetCenter(), maxSize);

        ui32 retryLimit = 20;

        // Corners of the window centred at p with half-extents scaled by m, clamped to `limits`.
        auto getCornersForMult = [=](double m) {
            double height = height0 * m;
            double width = width0 * m;
            NGeo::TGeoPoint lower(
                    Max(p.Lon() - width, limits.GetLowerLeftCorner().Lon()),
                    Max(p.Lat() - height, limits.GetLowerLeftCorner().Lat()));
            NGeo::TGeoPoint upper(
                    Min(p.Lon() + width, limits.GetUpperRightCorner().Lon()),
                    Min(p.Lat() + height, limits.GetUpperRightCorner().Lat()));
            return std::make_pair(lower, upper);
        };

        // Doubling phase. After the loop, retryLimit != 0 if and only if a window with enough
        // documents was found (the `break` on success happens before the decrement).
        NGeo::TGeoPoint lower = w.GetLowerLeftCorner();
        NGeo::TGeoPoint upper = w.GetUpperRightCorner();
        for (; retryLimit; --retryLimit) {
            prevMult = mult;
            mult *= 2;
            std::tie(prevLower, prevUpper) = std::tie(lower, upper);
            std::tie(lower, upper) = getCornersForMult(mult);
            if (lower == prevLower && upper == prevUpper) {
                // limit is reached, not enough docs
                retryLimit = 0;
                break;
            }

            w = NGeo::TGeoWindow(lower, upper);
            request = Normalize(w);
            if (IsEnoughDocsPredicate(request, requiredCount)) {
                break;
            }
        }

        if (retryLimit) {
            // we have "true" for mult and "false" for prevMult
            // let us make a few binary search iterations to choose a smaller value of mult
            double begin = prevMult;
            double end = mult;
            double mid = mult;
            // NOTE(review): these shadow the outer lower/upper declared before the doubling loop
            NGeo::TGeoPoint lower, upper;
            for (retryLimit = Min(5u, retryLimit); retryLimit; --retryLimit) {
                mid = (begin + end) / 2;
                std::tie(lower, upper) = getCornersForMult(mid);
                w = NGeo::TGeoWindow(lower, upper);
                request = Normalize(w);
                if (IsEnoughDocsPredicate(request, requiredCount)) {
                    end = mid;
                } else {
                    begin = mid;
                }
            }
            // w currently corresponds to the last probe (mid). If that probe was too small, rebuild w
            // for `end`, the smallest multiplier known to be large enough.
            if (end != mid) {
                std::tie(lower, upper) = getCornersForMult(end);
                w = NGeo::TGeoWindow(lower, upper);
            }
            Y_ASSERT(w.IsValid());
        }

        return {w, true};
    }

    // Works like FindIntersections with prunLimit = Max<ui32>(), then ranks the matches by an
    // approximated distance to the centre of w0 and keeps the n closest distinct documents.
    // w0 is permitted to not be normalized.
    // @returns at most n distinct docIds, sorted ascending by docId (empty when n == 0)
    TVector<ui32> TGeoIndexTree::FindNClosest(const NGeo::TGeoWindow& w0, ui32 n) const {
        Y_ENSURE(w0.IsValid());
        const TVector<NGeo::TGeoWindow> request = Normalize(w0);

        const NGeo::TGeoPoint center = w0.GetCenter();
        const TKey centerKey = QueryToKey(center, center);
        const float latCoeff = NGeo::GetLatCos(center.Lat());
        // Normalized items lie within [-180, 180]; fold the (possibly non-normalized) centre into that range too.
        const double centerLon = std::remainder(center.Lon(), NGeo::WORLD_WIDTH);

        // Distance from x to the closed interval [lo, hi]; 0 when x lies inside it.
        const auto distToRange = [](double x, double lo, double hi) -> double {
            if (x < lo) {
                return lo - x;
            }
            if (x > hi) {
                return x - hi;
            }
            return 0.0;
        };

        // Approximate, latitude-corrected distance (in degrees) from the centre to an item.
        const auto dist = [&](const TKey& item) -> float {
            const bool centerIsInsideItem = std::equal(item.begin(), item.end(), centerKey.begin(), [](float a, float b) { return a <= b; });
            if (centerIsInsideItem) {
                return 0.f;
            }
            NGeo::TGeoPoint lower, upper;
            std::tie(lower, upper) = KeyToCorners(item);

            // Latitude does not wrap around.
            const float dLat = static_cast<float>(distToRange(center.Lat(), lower.Lat(), upper.Lat()));
            // Longitude does. Besides the item itself, try its copies one world to the west and one to
            // the east, and keep the nearest. This is exact for items with non-zero longitude extent too.
            double dLonDegrees = distToRange(centerLon, lower.Lon(), upper.Lon());
            dLonDegrees = Min(dLonDegrees, distToRange(centerLon, lower.Lon() - NGeo::WORLD_WIDTH, upper.Lon() - NGeo::WORLD_WIDTH));
            dLonDegrees = Min(dLonDegrees, distToRange(centerLon, lower.Lon() + NGeo::WORLD_WIDTH, upper.Lon() + NGeo::WORLD_WIDTH));
            const float dLon = static_cast<float>(dLonDegrees) * latCoeff;

            // octagonal approximation of sqrt(dx^2+dy^2) (6% MoE)
            return 0.41f * Min(dLat, dLon) + 0.941246f * Max(dLat, dLon);
        };

        TVector<std::pair<ui32, float>> result;
        for (const NGeo::TGeoWindow& w : request) {
            const TKey key = QueryToKey(w);
            FindLe(key, [&](const TRecord& record) {
                result.emplace_back(record.DocId, dist(record.Coords));
                return true;
            });
        }

        TVector<ui32> docIds;
        if (result.size() > n) {
            // Closest first; ties broken by docId so that the selection is deterministic.
            SortBy(result, [](const std::pair<ui32, float>& hit) {
                return std::pair(hit.second, hit.first);
            });

            TSet<ui32> uniqueDocs;
            for (const auto& [docId, _] : result) {
                // Checked before inserting, so that n == 0 selects nothing.
                if (uniqueDocs.size() >= n) {
                    break;
                }
                uniqueDocs.insert(docId);
            }
            docIds.assign(uniqueDocs.begin(), uniqueDocs.end());
        } else {
            docIds.reserve(result.size());
            for (const auto& [docId, _] : result) {
                docIds.push_back(docId);
            }
            SortUnique(docIds);
        }
        Y_ASSERT(std::is_sorted(docIds.begin(), docIds.end()));
        return docIds;
    }

    // A window is normalized when its corners are ordered (lower <= upper in both coordinates) and it
    // lies entirely within [globalLower, globalUpper].
    bool TGeoIndexTree::IsNormalized(const TGeoWindow& w) {
        const auto lower = w.GetLowerLeftCorner();
        const auto upper = w.GetUpperRightCorner();
        const bool ordered = lower.Lon() <= upper.Lon() && lower.Lat() <= upper.Lat();
        const bool notBelowMin = lower.Lon() >= globalLower.Lon() && lower.Lat() >= globalLower.Lat();
        const bool notAboveMax = upper.Lon() <= globalUpper.Lon() && upper.Lat() <= globalUpper.Lat();
        return ordered && notBelowMin && notAboveMax;
    }

    // Converts a valid, possibly non-normalized window into normalized windows whose union covers it.
    // - Returns {w} unchanged when w is already normalized.
    // - Clamps latitude to [MIN_LATITUDE, MAX_LATITUDE] and longitude to
    //   [EXTENDED_MIN_LONGITUDE, EXTENDED_MAX_LONGITUDE].
    // - Replaces a window at least WORLD_WIDTH wide with the full standard longitude range.
    // - Cuts the remainder at the world boundaries and shifts each piece back into the standard world.
    // May return an empty vector when nothing is left after clamping, or (with a debug assert) when the
    // window's corners are disordered.
    // NOTE(review): longitudes beyond the extended range are clamped away rather than wrapped, so a
    // window lying entirely outside it yields no pieces.
    TVector<TGeoWindow> TGeoIndexTree::Normalize(const TGeoWindow& w) {
        Y_ENSURE(w.IsValid()); // Incorrect argument: a non-empty range was expected
        if (Y_LIKELY(IsNormalized(w)))
            return {w};

        TVector<TGeoWindow> result;
        TWindow cur(w.GetLowerLeftCorner(), w.GetUpperRightCorner());
        if (Y_UNLIKELY(!IsNormalOrder(cur))) {
            Y_ASSERT(0); // Incorrect argument: TGeoWindow internal data is broken
            return result;
        }

        // First, clamp the window into the {{-540,-90},{540,90}} range (three copies of the world).
        // Then split that range into one or more rectangles inside the "standard" globalLimits range,
        // removing any rollovers.
        cur.Lower.Lat = Max(MIN_LATITUDE, cur.Lower.Lat);
        cur.Lower.Lon = Max(EXTENDED_MIN_LONGITUDE, cur.Lower.Lon);
        cur.Upper.Lat = Min(MAX_LATITUDE, cur.Upper.Lat);
        cur.Upper.Lon = Min(EXTENDED_MAX_LONGITUDE, cur.Upper.Lon);
        if (cur.Upper.Lon < cur.Lower.Lon || cur.Upper.Lat < cur.Lower.Lat) {
            // Nothing left after clamping
            return result;
        }
        if (cur.Upper.Lon - cur.Lower.Lon >= WORLD_WIDTH) {
            // the window is bigger than the world and intersects with itself - let us replace it with a sensible range
            cur.Lower.Lon = MIN_LONGITUDE;
            cur.Upper.Lon = MAX_LONGITUDE;
        }

        // World 0 is the western copy [-540, -180], world 1 the standard one, world 2 the eastern copy.
        // WorldEnds holds the eastern boundary of each world.
        constexpr auto Worlds = xrange<ui32>(3);
        constexpr std::array<double, Worlds.size()> WorldEnds{MIN_LONGITUDE, MAX_LONGITUDE, EXTENDED_MAX_LONGITUDE};

        for (ui32 worldNo : Worlds) {
            Y_ASSERT(cur.Lower.Lon <= cur.Upper.Lon);

            double split = WorldEnds[worldNo];
            if (cur.Lower.Lon < split) {
                TWindow w1(
                    TPoint(cur.Lower.Lon, cur.Lower.Lat),
                    TPoint(Min(cur.Upper.Lon, split), cur.Upper.Lat));
                // shift -1 moves world 0 east by one world width, +1 moves world 2 west by one
                result.push_back(AsShiftedWindow(w1, (int)worldNo - 1));
                if (cur.Upper.Lon <= split)
                    break; // all done

                // Now exclude w1 from cur and handle the rest
                cur.Lower.Lon = w1.Upper.Lon;
                Y_ASSERT(cur.Lower.Lon == split);
            }
        }
        return result;
    }

    // Appends a hit for docId covering the normalized window w.
    // Throws (Y_ENSURE) if the builder is already read-only.
    void TGeoIndexTreeBuilder::AddHit(ui32 docId, const NGeo::TGeoWindow& w) {
        Y_ENSURE(!ReadOnly);
        Y_ASSERT(IsNormalized(w));
        TRecord record{docId, ItemToKey(w)};
        Data.push_back(std::move(record));
    }

    // Appends a hit for docId whose key was computed by the caller (it must satisfy IsItem).
    // Like the window overload, Remap and Finalize, this refuses to modify a read-only builder even in
    // release builds. Appending after Finalize would silently break the kd-tree ordering.
    void TGeoIndexTreeBuilder::AddHit(ui32 docId, const TGeoIndexTree::TKey& key) {
        Y_ENSURE(!ReadOnly);
        Y_ASSERT(IsItem(key)); // please pre-check
        Data.emplace_back(TRecord{docId, key});
    }

    void TGeoIndexTreeBuilder::Remap(const TVector<ui32>& remapTable) {
        constexpr ui32 Removed = (ui32)-1; // used in remapTable

        Y_ENSURE(!ReadOnly);
        auto pOut = Data.begin();
        for (auto pRecord = Data.begin(); pRecord != Data.end(); ++pRecord) {
            const ui32 tmpDocId = pRecord->DocId;
            const ui32 newDocId = tmpDocId < remapTable.size() ? remapTable[tmpDocId] : Removed;
            if (newDocId != Removed) {
                if (newDocId == tmpDocId && pOut == pRecord) {
                    pOut++;
                } else {
                    *pOut++ = TRecord{newDocId, pRecord->Coords};
                }
            }
        }
        if (pOut != Data.end())
            Data.resize(pOut - Data.begin());
    }

    // Freezes the builder and reorders the collected records in place into kd-tree layout,
    // with the root splitting on dimension 0. Throws (Y_ENSURE) if already read-only.
    void TGeoIndexTreeBuilder::Finalize() {
        Y_ENSURE(!ReadOnly);
        ReadOnly = true;
        const size_t rootDimension = 0;
        MakeKdTree(Data.begin(), Data.end(), rootDimension);
    }

    // Adopts records that are already in kd-tree layout and freezes the builder.
    // Leaves the builder (and data) untouched and returns false if the layout check fails.
    bool TGeoIndexTreeBuilder::Assign(TVector<TRecord>&& data) {
        const bool valid = IsSorted(data.begin(), data.end());
        if (valid) {
            Data = std::move(data);
            ReadOnly = true;
        }
        return valid;
    }

    // Hands the finalized records over to tree (replacing whatever it held) and marks tree read-only.
    // Afterwards this builder holds no records. Throws (Y_ENSURE) unless the builder is read-only.
    void TGeoIndexTreeBuilder::AssignTo(TGeoIndexTree& tree) {
        Y_ENSURE(ReadOnly);

        Data.swap(tree.Data);
        Data.clear(); // discard the records the tree held before
        tree.ReadOnly = true;
    }
}
