#pragma once

#include <library/cpp/containers/absl_flat_hash/flat_hash_set.h>
#include <library/cpp/geo/geo.h>
#include <util/generic/vector.h>
#include <util/generic/hash_set.h>
#include <array>

namespace NRTYGeo {
    // Flat storage of 4-dimensional (DocId, key) records queried with componentwise
    // "less or equal" lookups (FindLe / CountLe / CountUniqueLe).
    // NOTE(review): the name suggests Data is kept in k-d tree order; that ordering is produced
    // outside this header (presumably by TGeoIndexTreeBuilder::Finalize) — confirm in the .cpp.
    class TKdTree {
    public:
        // Number of coordinates in a key.
        static constexpr size_t Dimension = 4;
        // A point in the 4D key space.
        using TKey = std::array<float, Dimension>;
        // One boolean per key dimension (not referenced anywhere in this header).
        using TFlags = std::array<bool, Dimension>;

        // A single indexed entry: document id plus its 4D key.
        // NOTE(review): GetData() hands the vector of these out for serialization, so the field
        // order and types likely define the persisted layout; do not reorder fields or make the
        // struct non-trivial (e.g. default member initializers) without checking the serializer.
        struct TRecord {
            ui32 DocId;
            TKey Coords;
        };

    public:
        // All records of the tree.
        TVector<TRecord> Data;
        // Marks the tree as finalized; GetData() refuses to expose Data while this is false.
        bool ReadOnly = false;

        // Returns the records for serialization.
        // Throws (via Y_ENSURE) if the tree has not been marked ReadOnly yet.
        const TVector<TRecord>& GetData() const {
            Y_ENSURE(ReadOnly); // the 4D tree must be finalized before serialization
            return Data;
        }

    protected:
        // Invokes onResult for records whose key is componentwise <= `key` (as implied by the
        // "Le" naming; the definition lives elsewhere). TODO confirm the callback's argument type.
        template <class TOnResult>
        void FindLe(const TKey& key, TOnResult&& onResult) const;

        // Counts records whose key is componentwise <= `key` (presumably; see FindLe).
        size_t CountLe(const TKey& key) const;

        // Inserts the DocIds of records whose key is componentwise <= `key` into `output`.
        // NOTE(review): `limit` presumably caps how many unique ids are collected — confirm in the .cpp.
        void CountUniqueLe(absl::flat_hash_set<ui32>& output, ui32 limit, const TKey& key) const;
    };

    // Spatial index over geographic windows (axis-aligned lon/lat boxes) built on TKdTree.
    //
    // Every indexed window ("item") is stored under the key produced by ItemToKey():
    //     (lower.Lon, lower.Lat, -upper.Lon, -upper.Lat)
    // and every query window is turned into a key by QueryToKey():
    //     (upper.Lon, upper.Lat, -lower.Lon, -lower.Lat)
    // The componentwise comparison `itemKey <= queryKey` therefore expands to
    //     item.lower <= query.upper  and  item.upper >= query.lower
    // on both axes, which is exactly "the two windows overlap" (touching edges included).
    // NOTE(review): this relies on TKdTree::FindLe/CountLe selecting records whose key is
    // componentwise <= the given key (CountHits(const TKey&) delegates to CountLe); confirm
    // against the .cpp implementation.
    //
    // Coordinates are narrowed from double to float when keys are built, so results for windows
    // whose edges are closer than float precision may differ from an exact double comparison.
    class TGeoIndexTree: public TKdTree {
    public:
        // Builds the query key for the window [lower, upper]; see the class comment for the layout.
        static TKey QueryToKey(const NGeo::TGeoPoint& lower, const NGeo::TGeoPoint& upper) {
            return {
                static_cast<float>(upper.Lon()),
                static_cast<float>(upper.Lat()),
                static_cast<float>(-lower.Lon()),
                static_cast<float>(-lower.Lat()),
            };
        }

        // Builds the query key for window `w`.
        static TKey QueryToKey(const NGeo::TGeoWindow& w) {
            return QueryToKey(w.GetLowerLeftCorner(), w.GetUpperRightCorner());
        }

        // Builds the stored (item) key for window `w`; the inverse of KeyToItem().
        static TKey ItemToKey(const NGeo::TGeoWindow& w) {
            const auto lower = w.GetLowerLeftCorner();
            const auto upper = w.GetUpperRightCorner();
            return {
                static_cast<float>(lower.Lon()),
                static_cast<float>(lower.Lat()),
                static_cast<float>(-upper.Lon()),
                static_cast<float>(-upper.Lat()),
            };
        }

        // True when `key` decodes to a non-inverted window, i.e. upper >= lower on both axes.
        // Any NaN component makes the comparisons, and therefore the result, false.
        static bool IsItem(const TKey& key) {
            return -key[2] >= key[0] && -key[3] >= key[1];
        }

        // Decodes an item key back into its (lower-left, upper-right) corners.
        // The IsItem() precondition is only verified in debug builds.
        static std::pair<NGeo::TGeoPoint, NGeo::TGeoPoint> KeyToCorners(const TKey& key) {
            Y_VERIFY_DEBUG(IsItem(key), "incorrect item");
            NGeo::TGeoPoint lower(key[0], key[1]);
            NGeo::TGeoPoint upper(-key[2], -key[3]);
            return {lower, upper};
        }

        // Decodes an item key back into a window.
        // Throws (via Y_ENSURE) if the key does not describe a valid item (see IsItem()).
        static NGeo::TGeoWindow KeyToItem(const TKey& key) {
            Y_ENSURE(IsItem(key), "incorrect item");
            // Structured bindings avoid default-constructing the points and the <tuple> dependency of std::tie.
            auto [lower, upper] = KeyToCorners(key);
            return NGeo::TGeoWindow(lower, upper);
        }

        // Window normalization helpers; their exact rules live in the .cpp.
        // NOTE(review): Normalize presumably splits a window into pieces the key encoding can
        // represent — confirm before relying on it.
        static bool IsNormalized(const NGeo::TGeoWindow& w);
        static TVector<NGeo::TGeoWindow> Normalize(const NGeo::TGeoWindow& w);
    private:
        // Presumably: whether the windows in `request` hit at least `requiredCount` documents.
        bool IsEnoughDocsPredicate(const TVector<NGeo::TGeoWindow>& request, ui32 requiredCount) const;

    public:
        // Collects the DocIds of items intersecting `w` into `result` (presumably; see the .cpp).
        void FindIntersections(absl::flat_hash_set<ui32>& result, const NGeo::TGeoWindow& w) const;

        // Returns the DocIds of items intersecting `w0`.
        // NOTE(review): the exact effect of `prunLimit` is defined in the .cpp.
        TVector<ui32> FindIntersections(const NGeo::TGeoWindow& w0, ui32 prunLimit) const;

        // Number of records selected by an already-encoded query key (delegates to CountLe).
        size_t CountHits(const TKey& key) const {
            return CountLe(key);
        }

        // Number of records selected by the query window [lower, upper].
        size_t CountHits(const NGeo::TGeoPoint& lower, const NGeo::TGeoPoint& upper) const {
            return CountHits(QueryToKey(lower, upper));
        }

        // Hit count for a multi-window request (aggregation rule defined in the .cpp).
        size_t CountHits(const TVector<NGeo::TGeoWindow>& request) const;

        // NOTE(review): presumably grows `w0` (step `epsilonSize`, bounded by `maxSize`) until it
        // hits `requiredCount` documents; the bool reports whether that succeeded — confirm.
        std::pair<NGeo::TGeoWindow, bool> ExpandToDocCount(const NGeo::TGeoWindow& w0, const double epsilonSize, const NGeo::TSize& maxSize, ui32 requiredCount) const;

        // Returns up to `n` DocIds near `w0` (distance metric defined in the .cpp).
        TVector<ui32> FindNClosest(const NGeo::TGeoWindow& w0, ui32 n) const;

        // Intersection search that may shrink (`mayCollapse`) or grow (`mayExpand`, bounded by
        // `maxSize` when non-null) the window to approach `nDoc` results; details in the .cpp.
        TVector<ui32> FindIntersectionsWithDynamicPruning(const NGeo::TGeoWindow& w0, ui32 nDoc, bool mayCollapse, bool mayExpand, const NGeo::TSize* maxSize) const;
    };

    // Mutable companion of TGeoIndexTree used to populate the record set before it becomes read-only.
    // NOTE(review): the expected flow (AddHit* -> optional Remap -> Finalize -> serialize/AssignTo)
    // is inferred from the method names; confirm against the .cpp.
    class TGeoIndexTreeBuilder: public TGeoIndexTree {
    public:
        // Adds a record for document `docId` covering window `w`
        // (presumably encoded with ItemToKey()).
        void AddHit(ui32 docId, const NGeo::TGeoWindow& w);

        // Adds a record for document `docId` with an already-encoded item key.
        void AddHit(ui32 docId, const TGeoIndexTree::TKey& key);

        // Rewrites record DocIds through `remapTable`.
        // NOTE(review): presumably remapTable[oldId] == newId — confirm indexing and out-of-range handling.
        void Remap(const TVector<ui32>& remapTable);

        // Completes building; presumably orders Data for tree queries and sets ReadOnly
        // (TKdTree::GetData() requires ReadOnly) — confirm in the .cpp.
        void Finalize();

    public:
        // Replaces the records with `data`. The result must be checked ([[nodiscard]]);
        // NOTE(review): the meaning of a false return is not visible in this header.
        [[nodiscard]]
        bool Assign(TVector<TRecord>&& data);
        // Hands the built records over to `tree` (copy vs. move semantics defined in the .cpp).
        void AssignTo(TGeoIndexTree& tree);
    };

}
