#include <infra/netmon/metrics.h>
#include <infra/netmon/settings.h>
#include <infra/netmon/statistics/histograms.h>

#include <infra/netmon/topology/topology_storage.h>

namespace NNetmon {
    // Builds the cross-DC counter tables.
    //
    // For every DC listed in the probe schedule (the "source" DCs), a slot i is
    // allocated in PacketCountersList / RttCountersList and the DC's reduced id is
    // mapped to i in DcIndexMapping. Within slot i, a default-constructed counter
    // is created for every other known "target" DC, keyed by the target's reduced id.
    //
    // Source or target DC names that topologyStorage cannot resolve are logged and
    // skipped.
    TCrossDcCounters::TCrossDcCounters(const TTopologyStorage& topologyStorage) {
        const auto& dcs = TSettings::Get()->GetProbeScheduleDcs();
        const auto dcsCount = dcs.size();
        PacketCountersList.resize(dcsCount);
        RttCountersList.resize(dcsCount);

        // The candidate target DCs are the same for every source DC, so collect them
        // once: all probe-schedule DCs plus every DC mapped to a configured netmon URL.
        THashSet<TString> targetDcs(begin(dcs), end(dcs));
        for (const auto& item : TSettings::Get()->GetNetmonUrls()) {
            for (const auto& mappedDc : TTopologySettings::Get()->GetMappedDcs(item.first)) {
                targetDcs.insert(mappedDc);
            }
        }

        for (size_t i = 0; i < dcsCount; ++i) {
            const auto& sourceDcName = dcs[i];
            auto sourceDc = topologyStorage.FindDatacenter(sourceDcName);
            if (!sourceDc) {
                ERROR_LOG << "Cross-dc DcIndexMapping initialization error: unknown dc " << sourceDcName << Endl;
                continue;
            }
            DcIndexMapping[sourceDc->GetReducedId()] = i;

            for (const auto& dcName : targetDcs) {
                // No counters for traffic from a DC to itself.
                if (dcName == sourceDcName) {
                    continue;
                }

                auto dc = topologyStorage.FindDatacenter(dcName);
                if (!dc) {
                    ERROR_LOG << "Cross-dc counters initialization error: unknown dc " << dcName << Endl;
                    continue;
                }

                // operator[] is used only for its side effect: it inserts a
                // default-constructed counter for this target DC.
                PacketCountersList[i][dc->GetReducedId()];
                RttCountersList[i][dc->GetReducedId()];
            }
        }
    }

    // Adds one probe result for the (sourceDc -> targetDc) pair to the packet and
    // RTT counters of the given network.
    //
    // The sample is dropped if either of these holds:
    // - sourceDc was not registered in DcIndexMapping by the constructor.
    // - No counter was created for targetDc under that source.
    void TCrossDcCounters::RegisterPackets(const TDatacenter& sourceDc, const TDatacenter& targetDc,
                                           ENetworkType network, ui64 successCount, ui64 failCount,
                                           ui64 changedCount, double rtt) {
        // Use a non-inserting lookup. DcIndexMapping[...] would default-insert an
        // unknown source DC with index 0, which has three consequences:
        // - Its packets would be charged to the first configured DC.
        // - The shared mapping would be mutated on the registration path.
        // - With no configured DCs, PacketCountersList[0] would be indexed out of range.
        const auto mappingIt = DcIndexMapping.find(sourceDc.GetReducedId());
        if (mappingIt == DcIndexMapping.end()) {
            return;
        }
        const auto sourceIndex = mappingIt->second;

        auto it = PacketCountersList[sourceIndex].find(targetDc.GetReducedId());
        if (!it.IsEnd()) {
            IncrementCounter(it->second, network, successCount, failCount, changedCount);
        }

        auto rttIt = RttCountersList[sourceIndex].find(targetDc.GetReducedId());
        if (!rttIt.IsEnd()) {
            IncrementCounter(rttIt->second, network, rtt);
        }
    }

    // Atomically bumps the RTT histogram bucket that `rtt` falls into.
    //
    // The bucket is chosen from the GetBucketWeights() table:
    // - It is the last weight that is <= rtt.
    // - An rtt smaller than every weight (or an empty table) lands in bucket 0.
    // - An rtt not smaller than every weight lands in the last bucket.
    //
    // The increment happens while holding a light read guard on counters.Lock.
    void IncrementCounter(TGenericRttCounters<TAtomic>& counters, double rtt) {
        // The counter layout and the weights table must describe the same buckets.
        static_assert(TGenericRttCounters<TAtomic>::RTT_BUCKET_COUNT == RTT_BUCKET_COUNT);

        const auto& bucketWeights = GetBucketWeights();
        const auto firstWeight = bucketWeights.cbegin();

        // How many weights are <= rtt (position of the first weight strictly above rtt).
        const auto weightsNotAbove = std::distance(firstWeight, UpperBound(firstWeight, bucketWeights.cend(), rtt));
        const size_t bucket = weightsNotAbove > 0 ? static_cast<size_t>(weightsNotAbove - 1) : 0;

        TLightReadGuard guard(counters.Lock);
        AtomicIncrement((*counters.Buckets)[bucket]);
    }
}
