#pragma once

#include <infra/netmon/library/boxes.h>
#include <infra/netmon/library/event_hub.h>
#include <infra/netmon/metrics.h>
#include <infra/netmon/topology/topology.h>

#include <util/generic/algorithm.h>

namespace NNetmon {
    class TTopologyStorage;

    namespace NInternal {
        // Direction-split wrapper shared by the rtt and the packet counter
        // families: one TBaseCounters instance for Rx (incoming) and one for
        // Tx (outgoing).
        template <class TBaseCounters>
        struct TOrientedCounters {
            using TValue = typename TBaseCounters::TValue;

            // The same Rx/Tx pair, with the base counters rebound to value type T
            // through TBaseCounters::TGenericType.
            template <class T>
            using TGenericType = TOrientedCounters<typename TBaseCounters::template TGenericType<T>>;

            // The (non-oriented) base counters rebound to TAtomic.
            using TAtomicBaseCounters = typename TBaseCounters::template TGenericType<TAtomic>;

            // Builds a ui64 snapshot of atomic Rx/Tx counters by delegating to
            // TBaseCounters::ExtractMetrics, first for Rx and then for Tx.
            // NOTE(review): `counters` is taken by non-const reference, so the
            // base extraction may modify (e.g. reset) it; confirm in metrics.h.
            static TGenericType<ui64> ExtractMetrics(TGenericType<TAtomic>& counters) {
                return {
                    TBaseCounters::ExtractMetrics(counters.Rx),
                    TBaseCounters::ExtractMetrics(counters.Tx),
                };
            }

            TBaseCounters Rx;
            TBaseCounters Tx;
        };

        // Adds rhs to lhs direction by direction (Rx, then Tx). This is only
        // enabled for plain ui64 counters, not for the atomic variant.
        template <class TBaseCounters,
                  class = typename std::enable_if_t<std::is_same_v<typename TBaseCounters::TValue, ui64>>>
        inline TOrientedCounters<TBaseCounters>& operator+=(TOrientedCounters<TBaseCounters>& lhs, const TOrientedCounters<TBaseCounters>& rhs) {
            auto& [dstRx, dstTx] = lhs;
            const auto& [srcRx, srcTx] = rhs;
            dstRx += srcRx;
            dstTx += srcTx;
            return lhs;
        }

        // Subtracts rhs from lhs direction by direction (Rx, then Tx). This is
        // only enabled for plain ui64 counters, not for the atomic variant.
        template <class TBaseCounters,
                  class = typename std::enable_if_t<std::is_same_v<typename TBaseCounters::TValue, ui64>>>
        inline TOrientedCounters<TBaseCounters>& operator-=(TOrientedCounters<TBaseCounters>& lhs, const TOrientedCounters<TBaseCounters>& rhs) {
            auto& [dstRx, dstTx] = lhs;
            const auto& [srcRx, srcTx] = rhs;
            dstRx -= srcRx;
            dstTx -= srcTx;
            return lhs;
        }

        // Forwards `args` to the IncrementCounter overload for the base counters
        // of one direction. `incoming == true` selects Rx, and false selects Tx.
        // Only atomic (TAtomic-valued) counters are accepted.
        template<class T, class... Args>
        inline void IncrementCounter(TOrientedCounters<T>& counters,
                                     bool incoming,
                                     Args&&... args) {
            static_assert(std::is_same_v<typename T::TValue, TAtomic>);
            // Rx and Tx share type T, so the conditional yields an lvalue T&.
            T& directed = incoming ? counters.Rx : counters.Tx;
            IncrementCounter(directed, std::forward<Args>(args)...);
        }
    }

    // Rx/Tx-split variant of the packet SLA counters.
    template <class TValue>
    using TOrientedPacketSlaCounters =
        NInternal::TOrientedCounters<TGenericPacketSlaCounters<TValue>>;

    // Rx/Tx-split variant of the rtt SLA counters.
    template <class TValue>
    using TOrientedRttSlaCounters =
        NInternal::TOrientedCounters<TGenericRttSlaCounters<TValue>>;

    // Holds three ui64-keyed maps of switch SLA counters (inter-switch packets,
    // inter-switch rtt, link-poller packets) and hands each out as a
    // TRWLockedBox<...>::TReadOwnedBox. All state lives in the out-of-line
    // TImpl.
    // NOTE(review): the meaning of the ui64 key is defined in the .cpp;
    // presumably it is derived from the switch and network type. Confirm.
    class TSwitchSlaCounters {
    public:
        template <class TCounters>
        using TCounterMap = THashMap<ui64, TCounters>;

        // Per-key counter sets (all atomic).
        using TInterSwitchCounters = TOrientedPacketSlaCounters<TAtomic>;
        using TInterSwitchRttCounters = TOrientedRttSlaCounters<TAtomic>;
        using TLinkPollerCounters = TGenericPacketSlaCounters<TAtomic>;

        // Maps from key to the counter sets above.
        using TInterSwitchCounterMap = TCounterMap<TInterSwitchCounters>;
        using TInterSwitchRttCounterMap = TCounterMap<TInterSwitchRttCounters>;
        using TLinkPollerCounterMap = TCounterMap<TLinkPollerCounters>;

        // Read-owned boxes handed out by the Get*Box() accessors.
        using TInterSwitchCounterMapBox = TRWLockedBox<TInterSwitchCounterMap>::TReadOwnedBox;
        using TInterSwitchRttCounterMapBox = TRWLockedBox<TInterSwitchRttCounterMap>::TReadOwnedBox;
        using TLinkPollerCounterMapBox = TRWLockedBox<TLinkPollerCounterMap>::TReadOwnedBox;

        // NOTE(review): if the implementation keeps a reference to
        // topologyStorage, it must outlive this object. Confirm in the .cpp.
        explicit TSwitchSlaCounters(const TTopologyStorage& topologyStorage);
        // Defined out of line because TImpl is only forward-declared here.
        ~TSwitchSlaCounters();

        // Records one batch of inter-switch probe results for switchRef.
        // `incoming` selects the direction (presumably true -> Rx, as in
        // NInternal::IncrementCounter; confirm in the .cpp).
        void RegisterInterSwitchPackets(const TSwitch& switchRef, bool incoming, ENetworkType network,
                                        ui64 successCount, ui64 failCount, ui64 changedCount, double rtt);

        // Records one batch of link-poller results for switchRef (no direction, no rtt).
        void RegisterLinkPollerPackets(const TSwitch& switchRef, ENetworkType network,
                                       ui64 successCount, ui64 failCount);

        // Accessors for the three counter maps, each wrapped in a read-owned box.
        // NOTE(review): presumably the box holds a read lock while alive; see boxes.h.
        TInterSwitchCounterMapBox GetInterSwitchCounterMapBox();
        TInterSwitchRttCounterMapBox GetInterSwitchRttCounterMapBox();
        TLinkPollerCounterMapBox GetLinkPollerCounterMapBox();

        // Event hub for map-change notifications. The exact trigger is defined
        // in the .cpp.
        const TVoidEventHub& OnMapChanged() const noexcept;

    private:
        class TImpl;
        THolder<TImpl> Impl;
    };
}
