#pragma once

#include <mail/so/libs/talkative_config/config.h>
#include <library/cpp/hnsw/index/index_base.h>
#include <library/cpp/hnsw/index/index_item_storage_base.h>
#include <library/cpp/hnsw/index_builder/index_builder.h>
#include <library/cpp/hnsw/index_builder/index_writer.h>
#include <library/cpp/threading/atomic_shared_ptr/atomic_shared_ptr.h>
#include <library/cpp/threading/skip_list/skiplist.h>
#include <util/random/fast.h>
#include <util/random/shuffle.h>
#include <util/random/easy.h>
#include <util/stream/buffer.h>
#include <util/system/rwlock.h>
#include <util/system/mutex.h>
#include <util/system/condvar.h>
#include <util/thread/pool.h>
#include <util/generic/scope.h>
#include <util/generic/deque.h>

namespace NBatchHnsw {

    // Tunables for the batch HNSW index manager. Every field has a default, so a
    // default-constructed TConfig is usable as is.
    struct TConfig {
        // Builds a config from a dictionary. Keys that are absent leave the
        // corresponding field at its default value.
        static TConfig FromConfig(const NConfig::TDict &config) {
            TConfig knnConfig;

            if (auto it = MapFindPtr(config, "ttl")) {
                knnConfig.Ttl = NTalkativeConfig::As<TDuration>(*it);
            }

            if (auto it = MapFindPtr(config, "buffer_max_size")) {
                knnConfig.BuffersMaxSize = NTalkativeConfig::As<size_t>(*it);
            }

            if (auto it = MapFindPtr(config, "max_neighbors")) {
                knnConfig.MaxNeighbors = NTalkativeConfig::As<size_t>(*it);
            }

            if (auto it = MapFindPtr(config, "dump")) {
                knnConfig.DumpingFile = NTalkativeConfig::Get<TString>(*it);
            }

            if (auto it = MapFindPtr(config, "index_threads")) {
                knnConfig.IndexThreads = NTalkativeConfig::As<size_t>(*it);
            }

            if (auto it = MapFindPtr(config, "max_indexes")) {
                knnConfig.MaxIndexes = NTalkativeConfig::As<size_t>(*it);
            }

            if (auto it = MapFindPtr(config, "search_neighborhood_size")) {
                knnConfig.SearchNeighborhoodSize = NTalkativeConfig::As<size_t>(*it);
            }

            if (auto it = MapFindPtr(config, "num_exact_candidates")) {
                // This key used to be stored into MaxIndexes, which both ignored the
                // setting and clobbered "max_indexes".
                knnConfig.NumExactCandidates = NTalkativeConfig::As<size_t>(*it);
            }

            return knnConfig;
        }

        // Translates this config into HNSW build options.
        // With firstLaunch == true the thread count is left to AutoSelect; otherwise
        // IndexThreads is used.
        NHnsw::THnswBuildOptions MakeOpts(bool firstLaunch) const {
            NHnsw::THnswBuildOptions opts;
            opts.MaxNeighbors = MaxNeighbors;
            opts.SearchNeighborhoodSize = SearchNeighborhoodSize;
            opts.NumExactCandidates = NumExactCandidates;
            opts.Verbose = false;
            opts.ReportProgress = true;
            opts.NumThreads = firstLaunch ? NHnsw::THnswBuildOptions::AutoSelect : IndexThreads;
            return opts;
        }

        TConfig() = default;

        // Items whose last-use timestamp is older than this are treated as expired.
        TDuration Ttl = TDuration::Hours(6);
        // Buffer size at which building a new index is scheduled.
        size_t BuffersMaxSize = 5000;
        // Passed to the HNSW builder and used as the cap on returned neighbours.
        size_t MaxNeighbors = 100;
        // Number of indexes at which collapsing is scheduled.
        size_t MaxIndexes = 20;
        size_t SearchNeighborhoodSize = 300;
        size_t NumExactCandidates = 100;
        // Dump file path; dumping and loading are disabled when unset.
        TMaybe<TFsPath> DumpingFile;
        // Threads used for index builds other than the first one (see MakeOpts).
        size_t IndexThreads = NHnsw::THnswBuildOptions::AutoSelect;
    };

    // Format version tag: TIndexManager::Save writes it first and TIndexManager::Load
    // ignores any dump whose tag differs from this value.
    static const int SERIAL_VER = 4;

    // TAtomicSharedPtr<T> extended with pointer-identity ordering and Save/Load
    // that serialize the pointee rather than the pointer.
    template<class T>
    class TSerializablePtr: public TAtomicSharedPtr<T> {
    private:
        using TBase = TAtomicSharedPtr<T>;

    public:
        using TBase::TBase;

        TSerializablePtr(TBase base)
            : TBase(std::move(base))
        {
        }

        // Orders by the address of the pointee, not by its value.
        bool operator< (const TSerializablePtr& toCompare) const {
            const T* const lhs = this->Get();
            const T* const rhs = toCompare.Get();
            return lhs < rhs;
        }

        // Writes the pointee to the stream.
        void Save(IOutputStream *rh) const {
            const T& pointee = *this->Get();
            ::Save(rh, pointee);
        }

        // Reads a default-constructed T from the stream and takes ownership of it.
        void Load(IInputStream *rh) {
            T loaded;
            ::Load(rh, loaded);
            *this = MakeAtomicShared<T>(std::move(loaded));
        }

        // Same as above for a T that has to be constructed over a memory pool.
        void Load(IInputStream *rh, TMemoryPool &pool) {
            T loaded(pool);
            ::Load(rh, loaded);
            *this = MakeAtomicShared<T>(std::move(loaded));
        }
    };

    template<class TPoint, class TDistance>
    struct TypedNs {
        // A point together with the timestamp (microseconds) of its last use.
        // A timestamp of 0 marks the point as deleted.
        class TUpdatablePoint {
        public:
            TUpdatablePoint() = default;

            // The timestamp is copied by value (std::atomic itself cannot be moved);
            // the point is moved.
            TUpdatablePoint(TUpdatablePoint &&point) noexcept
                    : LastUpdate(point.LastUpdate.load()), Point(std::move(point.Point)) {
            }

            // Constructs the point in place from args and stamps it with `now`.
            template<class... TPointArgs>
            explicit TUpdatablePoint(TInstant now, TPointArgs &&... args)
                    // Each argument is forwarded with its own type; expanding the type pack
                    // inside std::forward<> only compiled for a single argument.
                    : LastUpdate(now.MicroSeconds()), Point(std::forward<TPointArgs>(args)...) {
            }

            const TPoint &GetPoint() const {
                return Point;
            }

            TPoint &GetMutablePoint() {
                return Point;
            }

            // True when the last-use timestamp is strictly later than `now`.
            // Callers in this file pass a deadline (current time minus TTL) here.
            // A deleted point (timestamp 0) never passes this check.
            bool FreshAndExists(TInstant now) const {
                return LastUpdate.load() > now.MicroSeconds();
            }

            // Sets the last-use timestamp to `now` unless the point is marked deleted.
            void Touch(TInstant now) {
                while (true) {
                    ui64 t = LastUpdate.load();
                    // Leave deleted points alone; otherwise retry until the CAS succeeds.
                    if (t == 0 ||
                        LastUpdate.compare_exchange_weak(t, now.MicroSeconds()))
                    {
                        break;
                    }
                }
            }

            // Marks the point as deleted by zeroing its timestamp.
            void MarkDeleted() {
                LastUpdate.store(0);
            }

            // Writes the timestamp followed by the point.
            inline void Save(IOutputStream *s) const {
                ::Save(s, LastUpdate.load());
                Point.Save(s);
            }

            // Reads the timestamp followed by the point, in the order Save wrote them.
            inline void Load(IInputStream *s) {
                {
                    ui64 t;
                    ::Load(s, t);
                    LastUpdate.store(t);
                }
                ::Load(s, Point);
            }

        private:
            std::atomic<ui64> LastUpdate = ::MicroSeconds();
            TPoint Point;
        };

        // Shared, serializable handle to a TUpdatablePoint; ordered by pointer identity.
        using TUpdatablePointPtr = TSerializablePtr<TUpdatablePoint>;


        // Adapts a distance over TPoint to a distance over TUpdatablePointPtr handles.
        class TPtrDistance {
        public:
            using TResult = float;
            using TLess = std::less<float>;

            TPtrDistance() = default;

            explicit TPtrDistance(TDistance distance) noexcept
                    : Distance(std::move(distance)) {
            }

            // Distance between the points behind the two handles.
            // Handles are taken by const reference: taking them by value copied both shared
            // pointers (two atomic refcount round-trips) on every distance evaluation.
            TResult operator()(const TUpdatablePointPtr &item1, const TUpdatablePointPtr &item2) const noexcept {
                return Distance(item1->GetPoint(), item2->GetPoint());
            }

        private:
            TDistance Distance;
        };

        class TItemsStorage {
        private:
            TDeque<TUpdatablePointPtr> Items;

        public:
            using TItem = TUpdatablePointPtr;

            TItemsStorage() = default;

            TItemsStorage(TItemsStorage &&) noexcept = default;

            explicit TItemsStorage(TDeque<TUpdatablePointPtr> items) noexcept
                    : Items(std::move(items)) {
            }

            Y_SAVELOAD_DEFINE(Items);

            [[nodiscard]] size_t GetNumItems() const noexcept {
                return Items.size();
            }

            const TUpdatablePointPtr &GetItem(size_t I) const {
                return Items.at(I);
            }

            const auto &GetItems() const {
                return Items;
            }

            auto &GetItemsMutable() {
                return Items;
            }

            TBlob MakeBlob(const NHnsw::THnswBuildOptions &opts, const TPtrDistance &distance) const {
                TBufferOutput out;
                {
                    NHnsw::THnswIndexData indexData = NHnsw::BuildIndex(opts, *this, distance);
                    NHnsw::WriteIndex(indexData, out);
                }

                return TBlob::FromBuffer(out.Buffer());
            }
        };

        // A point handle paired with its distance to some query; sorts by distance ascending.
        struct TUpdatablePointPtrWithDistance {
            TUpdatablePointPtr Item;
            typename TPtrDistance::TResult Distance{};

            TUpdatablePointPtrWithDistance() = default;

            TUpdatablePointPtrWithDistance(TUpdatablePointPtr item, typename TPtrDistance::TResult distance) noexcept
                    : Item(std::move(item))
                    , Distance(distance)
            {
            }

            // Compares distances only; the handle takes no part in the ordering.
            bool operator<(const TUpdatablePointPtrWithDistance& another) const noexcept {
                return another.Distance > Distance;
            }
        };

        // An HNSW index bundled with the items it was built over, the build options and
        // the distance used for both building and searching.
        class TIndex : public TItemsStorage, public NHnsw::THnswIndexBase {
        public:
            using TBase = NHnsw::THnswIndexBase;

        public:
            // Builds an index over a default-constructed (empty) storage.
            TIndex(const NHnsw::THnswBuildOptions &opts, TPtrDistance distance)
                    : TBase(TItemsStorage::MakeBlob(opts, distance))
                    , Opts(opts)
                    , Distance(distance)
            {
            }

            // Takes ownership of storage and builds the index over it.
            TIndex(TItemsStorage storage, const NHnsw::THnswBuildOptions &opts, TPtrDistance distance)
                    : TItemsStorage(std::move(storage))
                    , TBase(TItemsStorage::MakeBlob(opts, distance))
                    , Opts(opts)
                    , Distance(std::move(distance))
            {
            }

            // Searches this index for query and returns the found handles with their distances,
            // in the order the underlying search reported them.
            TVector<TUpdatablePointPtrWithDistance> GetNearestNeighbors(const TUpdatablePointPtr &query) const {
                const auto found = TBase::GetNearestNeighbors(
                        query, Opts.MaxNeighbors, Opts.SearchNeighborhoodSize, Max<size_t>(), *this, Distance);

                TVector<TUpdatablePointPtrWithDistance> result(Reserve(found.size()));
                for (const auto &hit : found) {
                    // Map the index-local id back to the stored handle.
                    result.emplace_back(this->GetItem(hit.Id), hit.Dist);
                }
                return result;
            }

        private:
            NHnsw::THnswBuildOptions Opts;
            TPtrDistance Distance;
        };

        struct TSafeItems {
            explicit TSafeItems(TMemoryPool &memoryPool) : MemoryPool(memoryPool), Items(memoryPool) {}

            TSafeItems(TSafeItems &&another) : MemoryPool(another.MemoryPool), Items(MemoryPool) {
                for (auto it = another.Items.SeekToFirst(); it.IsValid(); it.Next()) {
                    Items.Insert(std::move(it.GetValue()));
                }
            }

            inline void Save(IOutputStream *s) const {
                TDeque<TUpdatablePointPtr> items;
                for (auto it = Items.SeekToFirst(); it.IsValid(); it.Next())
                    items.emplace_back(it.GetValue());
                ::Save(s, items);
            }

            inline void Load(IInputStream *s) {
                TDeque<TUpdatablePointPtr> items;
                ::Load(s, items);
                for (auto &item: items) {
                    Items.Insert(std::move(item));
                }
            }

            TMemoryPool &MemoryPool;
            NThreading::TSkipList<TUpdatablePointPtr> Items;
        };

        class TIndexManager {
        private:
            // Settings captured at construction.
            const TConfig Config;

            // Pool handed to every buffer and index list created by this manager.
            TMemoryPool MemoryPool;
            // Buf1 receives new items (AddItem); Buf2 is the buffer AddIndex turns into an
            // index before replacing it with the current Buf1.
            TSerializablePtr<TSafeItems> Buf1, Buf2;

            // Pools used to schedule AddIndex / CollapseIndexes. They point either at the pools
            // passed to the constructor or at the Default* pools below.
            IThreadPool *AddingIndexThreadPool, *CollapsingThreadPool;
            // Created only when the corresponding pool was not supplied by the caller.
            THolder<TThreadPool> DefaultAddingIndexThreadPool, DefaultCollapsingThreadPool;

            const TPtrDistance Distance;

            // Indexes1 receives newly built indexes (AddIndex, Load); Indexes2 is the list whose
            // items CollapseIndexes merges before re-pointing it at the current Indexes1.
            TAtomicSharedPtr<NThreading::TSkipList<TSerializablePtr<TIndex>>> Indexes1, Indexes2;

            // Flags (0/1) set by compare-and-swap before scheduling the corresponding task and
            // reset to 0 when the task finishes or cannot be scheduled.
            TAtomic AddingIndex{}, CollapsingIndexes{};

            // DumpingLock is try-acquired by Dump(), so overlapping Dump() calls are skipped.
            // CondDumping is broadcast after each dump; no waiter is visible in this file.
            mutable TMutex DumpingLock;
            mutable TCondVar CondDumping;

        public:
            // Creates empty buffers and index lists, falls back to private thread pools when none
            // are supplied, restores state from the dump file if it exists, then starts both
            // pools with a single thread each.
            explicit TIndexManager(TConfig config = {}, TDistance distance = {},
                                   IThreadPool *addingIndexThreadPool = {}, IThreadPool *collapsingThreadPool = {})
                    : Config(std::move(config))
                    , MemoryPool(Config.BuffersMaxSize * 2, TMemoryPool::TExpGrow::Instance())
                    , Buf1(MakeAtomicShared<TSafeItems>(MemoryPool))
                    , Buf2(MakeAtomicShared<TSafeItems>(MemoryPool))
                    , AddingIndexThreadPool(addingIndexThreadPool)
                    , CollapsingThreadPool(collapsingThreadPool)
                    , Distance(std::move(distance))
                    , Indexes1(MakeAtomicShared<NThreading::TSkipList<TSerializablePtr<TIndex>>>(MemoryPool))
                    , Indexes2(MakeAtomicShared<NThreading::TSkipList<TSerializablePtr<TIndex>>>(MemoryPool))
            {
                if (AddingIndexThreadPool == nullptr) {
                    DefaultAddingIndexThreadPool = MakeHolder<TThreadPool>();
                    AddingIndexThreadPool = DefaultAddingIndexThreadPool.Get();
                }

                if (CollapsingThreadPool == nullptr) {
                    DefaultCollapsingThreadPool = MakeHolder<TThreadPool>();
                    CollapsingThreadPool = DefaultCollapsingThreadPool.Get();
                }

                // Restore the previous state before any background work can be scheduled.
                if (Config.DumpingFile && NFs::Exists(*Config.DumpingFile)) {
                    TIFStream dumpStream(*Config.DumpingFile);
                    Load(&dumpStream);
                }

                AddingIndexThreadPool->Start(1, 0);
                CollapsingThreadPool->Start(1, 0);
            }

            // Writes a final dump via Dump(), which returns immediately when no dump file is
            // configured.
            // NOTE(review): the thread pools are not stopped before this runs, and Dump() does
            // file I/O that presumably can throw from a destructor — confirm both are acceptable.
            ~TIndexManager() {
                Dump();
            }

            void Dump() const {
                if (!Config.DumpingFile)
                    return;

                if (DumpingLock.TryAcquire()) {
                    Y_DEFER {
                                DumpingLock.Release();
                                CondDumping.BroadCast();
                            };

                    const TString tempName = Config.DumpingFile->GetPath() + "_" + CreateGuidAsString() + ".tmp";
                    {
                        TOFStream s(tempName);
                        Save(&s);
                    }
                    NFs::Rename(tempName, *Config.DumpingFile);
                }
            }

            // Number of items currently held in the two buffers.
            [[nodiscard]] size_t BuffersSize() const {
                size_t total = 0;
                for (const auto buffer : {Buf1, Buf2}) {
                    total = total + buffer->Items.GetSize();
                }
                return total;
            }

            [[nodiscard]] size_t IndexSize() const {
                size_t size{};

                for (const auto indexes : {Indexes1, Indexes2}) {
                    for (auto it = indexes->SeekToFirst(); it.IsValid(); it.Next()) {
                        auto indexPtr = it.GetValue();
                        size += indexPtr->GetNumItems();
                    }
                }

                return size;
            }

            // Number of indexes held in both index lists.
            [[nodiscard]] size_t IndexesCount() const {
                size_t count = 0;
                for (const auto indexList : {Indexes1, Indexes2}) {
                    count = count + indexList->GetSize();
                }
                return count;
            }

            // Sum of BuffersSize() and IndexSize(); no de-duplication is performed.
            [[nodiscard]] size_t Size() const {
                return BuffersSize() + IndexSize();
            }

            // Finds the nearest fresh neighbours of newItem across both buffers (exhaustive scan)
            // and all indexes (HNSW search per index).
            //
            // Returns at most Config.MaxNeighbors entries sorted by ascending distance; an empty
            // deque when nothing fresh is found. Every returned item gets its last-use timestamp
            // refreshed to the current time.
            TDeque<TUpdatablePointPtrWithDistance> GetNearestNeighbors(const TUpdatablePointPtr &newItem) const {
                const auto now = Now();
                const auto freshDeadline = now - Config.Ttl;

                TDeque<TUpdatablePointPtrWithDistance> bestNeighbors{};

                for (const auto buf : {Buf1, Buf2}) {
                    for (auto it = buf->Items.SeekToFirst(); it.IsValid(); it.Next()) {
                        const auto &item = it.GetValue();
                        // Check freshness first so no distance is computed for expired or
                        // deleted items.
                        if (!item->FreshAndExists(freshDeadline)) {
                            continue;
                        }
                        bestNeighbors.emplace_back(item, Distance(item, newItem));
                    }
                }

                for (const auto cp : {Indexes1, Indexes2}) {
                    for (auto it = cp->SeekToFirst(); it.IsValid(); it.Next()) {
                        auto indexPtr = it.GetValue();

                        TVector<TUpdatablePointPtrWithDistance> neighbors = indexPtr->GetNearestNeighbors(newItem);
                        for (auto &neighbor : neighbors) {
                            if (neighbor.Item->FreshAndExists(freshDeadline))
                                bestNeighbors.emplace_back(std::move(neighbor.Item), neighbor.Distance);
                        }
                    }
                }

                if (!bestNeighbors)
                    return {};

                if (bestNeighbors.size() > Config.MaxNeighbors) {
                    // Only the best MaxNeighbors need to be ordered; the tail is dropped.
                    auto it = std::next(bestNeighbors.begin(), Config.MaxNeighbors);
                    PartialSort(bestNeighbors.begin(), it, bestNeighbors.end());
                    bestNeighbors.resize(std::distance(bestNeighbors.begin(), it));
                } else {
                    Sort(bestNeighbors.begin(), bestNeighbors.end());
                }

                // Being returned counts as a use, which extends the item's lifetime.
                for (auto &n : bestNeighbors)
                    n.Item->Touch(now);

                return bestNeighbors;
            }

            // Inserts newItem into the active buffer. Once that buffer holds at least
            // Config.BuffersMaxSize items, schedules AddIndex on the adding pool unless one is
            // already scheduled or running.
            void AddItem(TUpdatablePointPtr newItem) {
                size_t bufferedCount = 0;
                {
                    auto activeBuffer = Buf1;
                    activeBuffer->Items.Insert(std::move(newItem));
                    bufferedCount = activeBuffer->Items.GetSize();
                }

                if (bufferedCount < Config.BuffersMaxSize) {
                    return;
                }

                // Only the caller that flips the flag from 0 to 1 may schedule the task.
                if (!AtomicCas(&AddingIndex, 1, 0)) {
                    return;
                }

                const bool scheduled = AddingIndexThreadPool->AddFunc([this]() {
                    Y_DEFER { AtomicSet(AddingIndex, 0); };
                    AddIndex();
                });
                if (!scheduled) {
                    // The task will never run, so release the flag here.
                    AtomicSet(AddingIndex, 0);
                }
            }

            // Writes the format version, then both buffers (kept for the legacy layout), then a
            // single storage containing every fresh item of the manager.
            void Save(IOutputStream *s) const {
                ::Save(s, SERIAL_VER);

                {
                    TDeque<TSerializablePtr<TSafeItems>> legacyBuffers{Buf1, Buf2}; // just legacy
                    ::Save(s, legacyBuffers);
                }

                {
                    const TItemsStorage allItems = CollectAllItems();
                    ::Save(s, allItems);
                }
            }

            // Restores state written by Save. A stream with a different format version is
            // ignored. The legacy buffer section is read and discarded; the item storage that
            // follows becomes a single index inserted into Indexes1.
            void Load(IInputStream *s) {
                int storedVersion = 0;
                ::Load(s, storedVersion);
                if (storedVersion != SERIAL_VER) {
                    return;
                }

                {
                    TDeque<TSerializablePtr<TSafeItems>> legacyBuffers; // just legacy
                    ::Load(s, legacyBuffers, MemoryPool);
                }

                TItemsStorage restored;
                ::Load(s, restored);

                auto target = Indexes1;
                target->Insert(MakeAtomicShared<TIndex>(std::move(restored), Config.MakeOpts(true), Distance));
            }

            // Builds a new index from the items of Buf2, inserts it into Indexes1 and rotates the
            // buffers: Buf2 takes over the current Buf1 and Buf1 becomes a fresh empty buffer.
            // If Indexes1 then holds at least Config.MaxIndexes indexes, schedules
            // CollapseIndexes unless one is already scheduled or running.
            void AddIndex() {
                TItemsStorage storage;
                auto &mutableItems = storage.GetItemsMutable();
                {
                    // The indexed items come from Buf2, i.e. the buffer rotated out by the
                    // previous call, not from the buffer that is currently receiving inserts.
                    // NOTE(review): presumably this lets in-flight inserts into the old Buf1
                    // finish before it is indexed — confirm.
                    auto cp = Buf2;
                    for (auto it = cp->Items.SeekToFirst(); it.IsValid(); it.Next()) {
                        mutableItems.emplace_back(it.GetValue());
                    }
                }

                size_t size{};
                {
                    auto cp = Indexes1;
                    cp->Insert(MakeAtomicShared<TIndex>(std::move(storage), Config.MakeOpts(false), Distance));
                    size = cp->GetSize();
                }

                // Rotate only after the index is inserted, so the items just indexed stay
                // reachable through Buf2 until this point.
                Buf2 = Buf1;
                Buf1 = MakeAtomicShared<TSafeItems>(MemoryPool);

                if (size >= Config.MaxIndexes && AtomicCas(&CollapsingIndexes, 1, 0)) {
                    if (!CollapsingThreadPool->AddFunc([this]() {
                        Y_DEFER { AtomicSet(CollapsingIndexes, 0); };
                        CollapseIndexes();
                    })) {
                        // The task will never run, so release the flag here.
                        AtomicSet(CollapsingIndexes, 0);
                    }
                }
            }

            // Merges the fresh items of every index in Indexes2 into one new index, dropping
            // expired and deleted items. Afterwards Indexes2 refers to the list that was
            // Indexes1 (with the merged index added to it) and Indexes1 is a fresh empty list.
            void CollapseIndexes() {
                TItemsStorage storage;
                auto &mutableItems = storage.GetItemsMutable();

                {
                    const auto now = Now();
                    const auto freshDeadline = now - Config.Ttl;

                    const auto cp = Indexes2;
                    for (auto it = cp->SeekToFirst(); it.IsValid(); it.Next()) {
                        const auto indexPtr = it.GetValue();

                        for (auto item : indexPtr->GetItems())
                            if (item->FreshAndExists(freshDeadline))
                                mutableItems.emplace_back(item);
                    }
                }

                // After this assignment both members refer to the same list, so the Insert below
                // goes into the list that Indexes1 still points at until the last line.
                Indexes2 = Indexes1;
                Indexes2->Insert(MakeAtomicShared<TIndex>(std::move(storage), Config.MakeOpts(false), Distance));
                Indexes1 = MakeAtomicShared<NThreading::TSkipList<TSerializablePtr<TIndex>>>(MemoryPool);
            }

        private:
            // Gathers every fresh item from both buffers and from all indexes into one storage,
            // buffers first. Items reachable from more than one place are added once per place.
            TItemsStorage CollectAllItems() const {
                TItemsStorage collected;
                auto &out = collected.GetItemsMutable();

                const auto freshDeadline = Now() - Config.Ttl;
                const auto appendIfFresh = [&out, &freshDeadline](const TUpdatablePointPtr &candidate) {
                    if (candidate->FreshAndExists(freshDeadline)) {
                        out.emplace_back(candidate);
                    }
                };

                for (auto buffer : {Buf1, Buf2}) {
                    auto it = buffer->Items.SeekToFirst();
                    while (it.IsValid()) {
                        appendIfFresh(it.GetValue());
                        it.Next();
                    }
                }

                for (const auto indexList : {Indexes1, Indexes2}) {
                    auto it = indexList->SeekToFirst();
                    while (it.IsValid()) {
                        // Hold the index for the duration of the walk over its items.
                        const auto index = it.GetValue();
                        for (const auto &stored : index->GetItems()) {
                            appendIfFresh(stored);
                        }
                        it.Next();
                    }
                }

                return collected;
            }
        };
    };
} // namespace NBatchHnsw
