#pragma once

#include <infra/yasm/stockpile_client/common/base_types.h>
#include "cluster.h"

namespace NHistDb::NStockpile {
    template <class TShard>
    class TShardCache {
    public:
        static constexpr size_t FAIL_THRESHOLD = 6;

        using THostKey = std::pair<TString, TStockpilePort>; // host, port

        struct TShardStatus {
            TShard Key;
            bool ReadyWrite = false;
            bool ReadyRead = false;

            bool operator==(const TShardStatus& other) const noexcept {
                return Key == other.Key && ReadyWrite == other.ReadyWrite && ReadyRead == other.ReadyRead;
            }

            bool operator!=(const TShardStatus& other) const noexcept {
                return !operator==(other);
            }
        };

        void UpdateShardsForHost(const TGrpcRemoteHost& host, const TVector<TShardStatus>& shards) {
            UpdateShardsForHost(MakeHostKey(host), shards);
        }

        void UpdateShardsForHost(const THostKey& host, const TVector<TShardStatus>& shards) {
            SeenHosts.emplace(host);

            typename THashMap<THostKey, TCacheEntry>::insert_ctx ctx;
            auto it(Cache.find(host, ctx));
            if (it == Cache.end()) {
                it = Cache.emplace_direct(ctx, host, TCacheEntry{});
            }

            it->second.Shards = shards;
            it->second.FailCounter = 0;
        }

        void MarkHostAsFailed(const TGrpcRemoteHost& host) {
            MarkHostAsFailed(MakeHostKey(host));
        }

        void MarkHostAsFailed(const THostKey& host) {
            SeenHosts.emplace(host);

            const auto it(Cache.find(host));
            if (it == Cache.end()) {
                return;
            }

            it->second.FailCounter++;
            if (it->second.FailCounter >= FAIL_THRESHOLD) {
                Cache.erase(it);
            }
        }

        void Cleanup() {
            for (auto it(Cache.begin()); it != Cache.end();) {
                auto prevIt = it;
                ++it;
                if (!SeenHosts.contains(prevIt->first)) {
                    Cache.erase(prevIt);
                }
            }
            SeenHosts.clear();
        }

        TVector<TShardStatus> GetShardsForHost(const TGrpcRemoteHost& host) const {
            return GetShardsForHost(MakeHostKey(host));
        }

        TVector<TShardStatus> GetShardsForHost(const THostKey& host) const {
            const auto it(Cache.find(host));
            if (it != Cache.end()) {
                return it->second.Shards;
            } else {
                return {};
            }
        }

    private:
        struct TCacheEntry {
            TVector<TShardStatus> Shards;
            size_t FailCounter = 0;
        };

        THostKey MakeHostKey(const TGrpcRemoteHost& host) const {
            return std::make_pair(host.GetHost(), host.GetPort());
        }

        THashMap<THostKey, TCacheEntry> Cache;
        THashSet<THostKey> SeenHosts;
    };

    // Shard cache instantiated for stockpile shard ids (TShardCache keyed by TStockpileShardId shards).
    using TStockpileShardCache = TShardCache<TStockpileShardId>;
}
