#include <infra/netmon/statistics/histograms.h>
#include <infra/netmon/statistics/varint.h>

#include <library/cpp/sse/sse.h>

#include <util/generic/algorithm.h>
#include <util/generic/xrange.h>
#include <util/generic/singleton.h>
#include <util/generic/buffer.h>
#include <util/generic/ymath.h>
#include <util/stream/buffer.h>
#include <util/stream/mem.h>

#include <algorithm>
#include <array>
#include <cmath>
#include <cstring>
#include <iterator>
#include <numeric>
#include <vector>

namespace NNetmon {
    namespace {
        // Table of per-bucket lower boundaries for RTT samples (values in milliseconds),
        // one entry per histogram bucket.
        using TBucketWeights = std::array<double, RTT_BUCKET_COUNT>;

        // Rounds `value` down to the nearest multiple of 4, i.e. the number of
        // 32-bit lanes processed per SSE step in TSampleHistogram::Merge.
        constexpr std::size_t AlignSize(std::size_t value) noexcept {
            return value / 4 * 4;
        }

        // Rounds `value` up to the nearest even number: even values are returned
        // unchanged, odd ones are incremented by one.
        constexpr std::size_t EvenSize(std::size_t value) noexcept {
            return (value % 2 == 0) ? value : value + 1;
        }

        // Builds the table of bucket lower boundaries.
        // Append() assigns a sample v to bucket i when intervals[i] <= v < intervals[i + 1].
        // Values below intervals[0] go to bucket 0; values at or above the last entry go to the
        // last bucket.
        TBucketWeights ComputeBucketWeights() {
            // TSampleHistogram::Merge walks the buckets four at a time with SSE and has no
            // scalar tail, so the bucket count must be "aligned". This is known at compile time.
            constexpr std::size_t bucketCount = std::tuple_size<TBucketWeights>::value;
            static_assert(AlignSize(bucketCount) == bucketCount,
                          "RTT bucket count must be a multiple of 4 for the SSE merge");

            // in milliseconds, biased to inter datacenter latency
            TBucketWeights intervals{{
                0.01,
                0.02,
                0.03,
                0.04,
                0.05,
                0.06,
                0.07,
                0.08,
                0.09,
                0.1,
                0.15,
                0.2,
                0.25,
                0.3,
                0.4,
                0.5,
                0.6,
                0.7,
                0.8,
                0.9,
                1.0,
                1.2,
                1.4,
                1.6,
                1.8,
                2.0,
                2.5,
                3.0,
                3.5,
                4.0,
                4.5,
                5.0,
                5.5,
                6.0,
                6.5,
                7.0,
                7.5,
                8.0,
                8.5,
                9.0,
                9.5,
                10.0,
                15.0,
                20.0,
                25.0,
                30.0,
                35.0,
                40.0,
                45.0,
                50.0,
                60.0,
                70.0,
                80.0,
                90.0,
                100.0,
                200.0,
                300.0,
                400.0,
                500.0,
                600.0,
                700.0,
                800.0,
                900.0,
                1000.0
            }};

            // Append() bisects this table with UpperBound, which requires sorted input.
            // Requiring a strictly increasing table also catches a list shorter than the bucket
            // count: its missing tail would be zero-filled and break the ordering.
            Y_VERIFY(std::adjacent_find(intervals.cbegin(), intervals.cend(),
                                        [](double lhs, double rhs) { return lhs >= rhs; }) == intervals.cend());
            return intervals;
        }

        const TBucketWeights BucketWeights = ComputeBucketWeights();

        class TDefaultWeights {
        public:
            TDefaultWeights() {
                for (const auto& boundary : BucketWeights) {
                    Values.push_back(boundary);
                }
            }

            TVector<double> Values;
        };
    }

    // Returns the bucket boundary table as a TVector owned by Singleton<TDefaultWeights>.
    const TVector<double>& GetBucketWeights() {
        const TDefaultWeights* const holder = Singleton<TDefaultWeights>();
        return holder->Values;
    }

    TSampleHistogram::TSampleHistogram()
        : MinIndex(RTT_BUCKET_COUNT)
        , MaxIndex(0)
    {
        Buckets.fill(0);
    }

    TSampleHistogram::TSampleHistogram(const NCommon::TSampleHistogram& hist)
        : TSampleHistogram()
    {
        FromProto(hist);
    }

    void TSampleHistogram::Append(double value) {
        auto it = UpperBound(BucketWeights.cbegin(), BucketWeights.cend(), value);
        size_t idx = (BucketWeights.cbegin() == it) ? 0 : std::distance(BucketWeights.cbegin(), it) - 1;
        Buckets[idx]++;
        MinIndex = Min(static_cast<ui8>(idx), MinIndex);
        MaxIndex = Max(static_cast<ui8>(idx), MaxIndex);
    }

    // Adds `weight` * hist.Buckets[i] to Buckets[i] for every bucket, and widens
    // [MinIndex, MaxIndex] to also cover hist's range (this happens even when weight == 0).
    // Arithmetic is 32-bit and wraps modulo 2^32: only the low half of each 64-bit
    // product is kept, and _mm_add_epi32 does not saturate.
    // Processes four 32-bit buckets per iteration with SSE2 and has no scalar tail, so it
    // relies on Buckets.size() being a multiple of 4. ComputeBucketWeights checks this for
    // RTT_BUCKET_COUNT; assumes Buckets has RTT_BUCKET_COUNT entries.
    void TSampleHistogram::Merge(const TSampleHistogram& hist, ui32 weight) {
        // broadcast the weight into all four 32-bit lanes
        __m128i multiplier = _mm_set1_epi32(weight);
        for (std::size_t idx(0); idx < Buckets.size(); idx += 4) {
            // unaligned loads/stores, so Buckets need not be 16-byte aligned
            __m128i lhs = _mm_loadu_si128(reinterpret_cast<const __m128i*>(&Buckets[idx]));
            __m128i rhs = _mm_loadu_si128(reinterpret_cast<const __m128i*>(&hist.Buckets[idx]));

            // multiplication
            // SSE2 has no lane-wise 32-bit multiply (_mm_mullo_epi32 is SSE4.1), so
            // _mm_mul_epu32 is used twice. It multiplies lanes 0 and 2 into 64-bit products;
            // shifting rhs right by 4 bytes moves lanes 1 and 3 into those positions.
            __m128i low = _mm_mul_epu32(rhs, multiplier); /* mul 2,0 */
            __m128i high = _mm_mul_epu32(_mm_srli_si128(rhs, 4), multiplier); /* mul 3,1 */
            // Keep the low 32 bits of each product (dwords 0 and 2), then interleave
            // them back into lane order [0, 1, 2, 3].
            __m128i normalized = _mm_unpacklo_epi32(
                _mm_shuffle_epi32(low, _MM_SHUFFLE (0,0,2,0)),
                _mm_shuffle_epi32(high, _MM_SHUFFLE (0,0,2,0))
            ); /* shuffle results to [63..0] and pack */

            // summation
            __m128i result = _mm_add_epi32(lhs, normalized);

            // copy result
            _mm_storeu_si128(reinterpret_cast<__m128i*>(&Buckets[idx]), result);
        }

        MinIndex = Min(MinIndex, hist.MinIndex);
        MaxIndex = Max(MaxIndex, hist.MaxIndex);
    }

    // Writes running totals: cumulativeCount[i] = Buckets[0] + ... + Buckets[i].
    void TSampleHistogram::FillCumulativeCount(TBuckets& cumulativeCount) const {
        auto out = cumulativeCount.begin();
        decltype(Buckets)::value_type runningTotal = 0;
        for (const auto count : Buckets) {
            runningTotal += count;
            *out = runningTotal;
            ++out;
        }
    }

    // Estimates the `percent` quantile from precomputed running totals (see FillCumulativeCount).
    // `percent` is a fraction in [0, 1], not a value in [0, 100].
    // Returns Nothing() for an empty histogram or when percent == 0.
    //
    // Append() stores a sample in bucket i when BucketWeights[i] <= value < BucketWeights[i + 1],
    // so an inner bucket is reported as the midpoint of that interval. The first and last buckets
    // are open-ended: they also take values below BucketWeights[0], or at/above the last boundary.
    // For those two buckets the boundary itself is reported.
    TMaybe<double> TSampleHistogram::GetPercentileValue(double percent, const TBuckets& cumulativeCount) const {
        Y_VERIFY(0 <= percent && percent <= 1.0);
        if (!cumulativeCount.back() || !percent) {
            return Nothing();
        }
        // Rank of the requested sample. The first bucket whose running total reaches it holds the sample.
        const double needle = percent * cumulativeCount.back();
        const auto it = LowerBound(cumulativeCount.cbegin(), cumulativeCount.cend(), needle);
        const auto bucketIndex = static_cast<size_t>(std::distance(cumulativeCount.cbegin(), it));
        if (0 < bucketIndex && bucketIndex + 1 < RTT_BUCKET_COUNT) {
            // Use the width of bucket i itself (not of bucket i - 1), so the estimate is the
            // centre of the interval Append() actually mapped the samples to.
            const double lower = BucketWeights[bucketIndex];
            const double upper = BucketWeights[bucketIndex + 1];
            return lower + (upper - lower) / 2.0;
        }
        return BucketWeights[bucketIndex];
    }

    TMaybe<double> TSampleHistogram::GetPercentileValue(double percent) const {
        TBuckets cumulativeCount{};
        FillCumulativeCount(cumulativeCount);
        return GetPercentileValue(percent, cumulativeCount);
    }

    double TSampleHistogram::GetPercentileValueOrNan(double percent, const TBuckets& cumulativeCount) const {
        auto value(GetPercentileValue(percent, cumulativeCount));
        return value.Defined() ? value.GetRef() : NAN;
    }

    double TSampleHistogram::GetPercentileValueOrNan(double percent) const {
        auto value(GetPercentileValue(percent));
        return value.Defined() ? value.GetRef() : NAN;
    }

    // Writes [p50, p70, p90, p95, p99] as a JSON list. The list is left empty when any of
    // these percentiles is undefined (the histogram holds no samples).
    void TSampleHistogram::ToJson(NJsonWriter::TBuf& buf) const {
        TBuckets cumulativeCount{};
        FillCumulativeCount(cumulativeCount);

        const double percents[] = {0.50, 0.70, 0.90, 0.95, 0.99};
        TVector<double> values;
        for (const double percent : percents) {
            const auto value = GetPercentileValue(percent, cumulativeCount);
            if (!value.Defined()) {
                values.clear();
                break;
            }
            values.push_back(value.GetRef());
        }

        auto context(buf.BeginList());
        for (const double value : values) {
            context.WriteDouble(value);
        }
        context.EndList();
    }

    // Serializes the populated bucket range together with MinIndex.
    // Buckets are varint-encoded in pairs (2k, 2k + 1), from pair AlignSize(MinIndex) / 2
    // (the start FromProto also uses) through the pair that contains MaxIndex.
    flatbuffers::Offset<NCommon::TSampleHistogram> TSampleHistogram::ToProto(
            flatbuffers::FlatBufferBuilder& builder) const {
        TBuffer encoded;
        TBufferStream stream(encoded);
        const std::size_t firstPair = AlignSize(MinIndex) / 2;
        const std::size_t endPair = EvenSize(MaxIndex + 1) / 2;
        for (std::size_t pair = firstPair; pair < endPair; ++pair) {
            const std::size_t even = 2 * pair;
            Y_ASSERT(even + 1 < Buckets.size());
            VarintEncode(stream, Buckets[even], Buckets[even + 1]);
        }
        stream.Finish();

        // The byte vector must exist before the table is started, so build it first.
        const auto bytes = builder.CreateVector(
            reinterpret_cast<const i8*>(encoded.Data()),
            encoded.Size()
        );
        return NCommon::CreateTSampleHistogram(builder, MinIndex, bytes);
    }

    // Restores the buckets written by ToProto(), discarding any previous content.
    // MinIndex is taken from the message. MaxIndex becomes the odd index of the last decoded
    // pair, which is an upper bound for the last non-empty bucket, or 0 when no pair is decoded.
    void TSampleHistogram::FromProto(const NCommon::TSampleHistogram& hist) {
        Buckets.fill(0);
        MinIndex = hist.MinIndex();
        // Reset the upper bound as well. Otherwise a stale MaxIndex from an earlier state
        // survives when the message carries no bucket pairs.
        MaxIndex = 0;
        const auto* encoded = hist.Buckets();
        if (encoded == nullptr) {
            // Absent vector (flatbuffers returns nullptr for unset fields): nothing to decode.
            return;
        }
        TMemoryInput stream(encoded->Data(), encoded->size());
        for (const auto idx : xrange(AlignSize(MinIndex) / 2, Buckets.size() / 2)) {
            if (VarintDecode(stream, Buckets[2 * idx], Buckets[2 * idx + 1]) == EVarintStatus::END_OF_STREAM) {
                return;
            }
            MaxIndex = 2 * idx + 1;
        }
    }

    // Histograms are equal when every bucket counter matches; MinIndex/MaxIndex are not compared.
    bool TSampleHistogram::operator==(const TSampleHistogram& rhs) const noexcept {
        for (std::size_t idx = 0; idx < Buckets.size(); ++idx) {
            if (Buckets[idx] != rhs.Buckets[idx]) {
                return false;
            }
        }
        return true;
    }

    // Recomputes MinIndex/MaxIndex from the bucket contents. For an all-zero histogram,
    // MinIndex becomes Buckets.size() and MaxIndex becomes 0.
    void TSampleHistogram::InitMinMaxIndexFromBuckets() {
        std::size_t first = 0;
        while (first < Buckets.size() && Buckets[first] == 0) {
            ++first;
        }
        MinIndex = first;

        MaxIndex = 0;
        for (std::size_t end = Buckets.size(); end > 0; --end) {
            if (Buckets[end - 1] != 0) {
                MaxIndex = end - 1;
                break;
            }
        }
    }

    TConnectivityHistogram::TConnectivityHistogram()
        : WeightAccumulator(0.0)
    {
        Buckets.fill(0.0);
    }

    TConnectivityHistogram::TConnectivityHistogram(const NCommon::TConnectivityHistogram& hist)
        : TConnectivityHistogram()
    {
        FromProto(hist);
    }

    void TConnectivityHistogram::Append(double value, double weight) {
#ifdef NOC_SLA_BUILD
        (void)weight;
        for (size_t i = 0; i < CONNECTIVITY_BUCKET_COUNT; ++i) {
            Buckets[i] += value;
        }
        WeightAccumulator += 1;
#else
        // 3 means [1, 1]
        if (value == 1.0) {
            Buckets[3] += weight;
        }
        // 2 means [0.9, 1]
        if (value >= 0.9) {
            Buckets[2] += weight;
        }
        // 1 means [0.7, 1]
        if (value >= 0.7) {
            Buckets[1] += weight;
        }
        // 0 means [0.5, 1]
        if (value >= 0.5) {
            Buckets[0] += weight;
        }
        WeightAccumulator += weight;
#endif
    }

    // Adds another histogram's bucket weights and total weight to this one.
    void TConnectivityHistogram::Merge(const TConnectivityHistogram& hist) {
        for (size_t idx = 0; idx < Buckets.size(); ++idx) {
            Buckets[idx] += hist.Buckets[idx];
        }
        WeightAccumulator += hist.WeightAccumulator;
    }

    // Returns every bucket weight divided by the total weight, or Nothing() when the
    // total weight is zero.
    TMaybe<TConnectivityHistogram::TNormalizedBuckets> TConnectivityHistogram::GetValues() const {
        if (!WeightAccumulator) {
            return Nothing();
        }
        TNormalizedBuckets normalized{};
        normalized.fill(0.0);
        for (size_t idx = 0; idx < Buckets.size(); ++idx) {
            normalized[idx] = Buckets[idx] / WeightAccumulator;
        }
        return normalized;
    }

    // Writes the normalized bucket values as a JSON list; the list is empty when they are undefined.
    void TConnectivityHistogram::ToJson(NJsonWriter::TBuf& buf) const {
        auto list = buf.BeginList();
        const auto normalized = GetValues();
        if (!normalized.Defined()) {
            list.EndList();
            return;
        }
        for (const double share : normalized.GetRef()) {
            list.WriteDouble(share);
        }
        list.EndList();
    }

    // Serializes the total weight and the raw (unnormalized) bucket weights.
    flatbuffers::Offset<NCommon::TConnectivityHistogram> TConnectivityHistogram::ToProto(
            flatbuffers::FlatBufferBuilder& builder) const {
        // the vector must be built before the table that references it
        const auto bucketsOffset = builder.CreateVector(Buckets.data(), Buckets.size());
        return NCommon::CreateTConnectivityHistogram(builder, WeightAccumulator, bucketsOffset);
    }

    // Restores the total weight and the bucket weights written by ToProto().
    // Aborts (Y_VERIFY) when the buckets vector is missing or has a different size
    // than this histogram's bucket array.
    void TConnectivityHistogram::FromProto(const NCommon::TConnectivityHistogram& hist) {
        // flatbuffers returns nullptr for an absent vector; fail with a clear check
        // instead of dereferencing it
        Y_VERIFY(hist.Buckets() != nullptr);
        const auto& restoredBuckets = *hist.Buckets();
        Y_VERIFY(restoredBuckets.size() == Buckets.size());
        WeightAccumulator = hist.WeightAccumulator();
        for (const auto idx : xrange(restoredBuckets.size())) {
            Buckets[idx] = restoredBuckets.Get(idx);
        }
    }

    // Equal when the total weights compare equal and the bucket arrays are bitwise identical.
    bool TConnectivityHistogram::operator==(const TConnectivityHistogram& rhs) const noexcept {
        if (WeightAccumulator != rhs.WeightAccumulator) {
            return false;
        }
        const auto byteCount = sizeof(decltype(Buckets)::value_type) * Buckets.size();
        return std::memcmp(Buckets.data(), rhs.Buckets.data(), byteCount) == 0;
    }

    // Weighted mean of everything appended or merged so far, or Nothing() when the
    // accumulated weight is zero.
    TMaybe<double> TAverage::GetValue() const {
        if (!WeightAccumulator) {
            return Nothing();
        }
        return ValueAccumulator / WeightAccumulator;
    }

    // Adds `value` with integer weight `weight`.
    void TAverage::Append(double value, ui32 weight) {
        const double weighted = value * weight;
        ValueAccumulator += weighted;
        WeightAccumulator += weight;
    }

    // Combines the sums of another average into this one.
    void TAverage::Merge(const TAverage& other) {
        WeightAccumulator += other.WeightAccumulator;
        ValueAccumulator += other.ValueAccumulator;
    }

    // Packs both accumulators into the flatbuffers struct.
    NCommon::TAverage TAverage::ToProto() const {
        NCommon::TAverage proto(ValueAccumulator, WeightAccumulator);
        return proto;
    }

    // Restores both accumulators from the flatbuffers struct.
    void TAverage::FromProto(const NCommon::TAverage& element) {
        WeightAccumulator = element.WeightAccumulator();
        ValueAccumulator = element.ValueAccumulator();
    }

    // Equal when both accumulators compare equal.
    bool TAverage::operator==(const TAverage& rhs) const noexcept {
        return WeightAccumulator == rhs.WeightAccumulator && ValueAccumulator == rhs.ValueAccumulator;
    }

    // Restores the per-bucket averages from their flatbuffers form.
    TAverageHistogram::TAverageHistogram(const NCommon::TAverageHistogram& hist) {
        FromProto(hist);
    }

    // Folds the normalized values of a connectivity histogram into the per-bucket
    // averages, each with the given weight. Histograms without data are skipped.
    void TAverageHistogram::Merge(const TConnectivityHistogram& hist, ui32 weight) {
        if (hist.Empty()) {
            return;
        }
        const auto normalized = hist.GetValues();
        if (!normalized.Defined()) {
            return;
        }
        for (size_t idx = 0; idx < Elements.size(); ++idx) {
            Elements[idx].Append(normalized->at(idx), weight);
        }
    }

    // Merges the per-bucket averages of another histogram into this one.
    void TAverageHistogram::Merge(const TAverageHistogram& hist) {
        for (size_t idx = 0; idx < Elements.size(); ++idx) {
            Elements[idx].Merge(hist.Elements[idx]);
        }
    }

    // Returns the average of every bucket, or Nothing() if any bucket has no weight yet.
    TMaybe<TConnectivityHistogram::TNormalizedBuckets> TAverageHistogram::GetValues() const {
        TConnectivityHistogram::TNormalizedBuckets averages;
        size_t idx = 0;
        for (const auto& element : Elements) {
            const auto average = element.GetValue();
            if (!average.Defined()) {
                return Nothing();
            }
            averages[idx] = average.GetRef();
            ++idx;
        }
        return averages;
    }

    // Writes the per-bucket averages as a JSON list; the list is empty when any average is undefined.
    void TAverageHistogram::ToJson(NJsonWriter::TBuf& buf) const {
        auto context = buf.BeginList();
        const auto averages = GetValues();
        if (averages.Defined()) {
            const auto& values = averages.GetRef();
            for (auto it = values.begin(); it != values.end(); ++it) {
                context.WriteDouble(*it);
            }
        }
        context.EndList();
    }

    // Serializes each per-bucket average as an NCommon::TAverage struct, in bucket order.
    flatbuffers::Offset<NCommon::TAverageHistogram> TAverageHistogram::ToProto(
            flatbuffers::FlatBufferBuilder& builder) const {

        std::vector<NCommon::TAverage> elements;
        // Reserve from the source container: `elements` itself is still empty here,
        // so reserving elements.size() would reserve nothing.
        elements.reserve(Elements.size());
        for (const auto& el : Elements) {
            elements.emplace_back(el.ToProto());
        }

        return NCommon::CreateTAverageHistogram(builder, builder.CreateVectorOfStructs(elements));
    }

    // Restores every per-bucket average written by ToProto().
    // Aborts (Y_VERIFY) when the elements vector is missing or its size differs from Elements.
    void TAverageHistogram::FromProto(const NCommon::TAverageHistogram& hist) {
        // flatbuffers returns nullptr for an absent vector; fail with a clear check
        // instead of dereferencing it
        Y_VERIFY(hist.Elements() != nullptr);
        const auto& restoredElements = *hist.Elements();
        Y_VERIFY(restoredElements.size() == Elements.size());
        for (const auto idx : xrange(Elements.size())) {
            Elements[idx].FromProto(*restoredElements.Get(idx));
        }
    }

    // Equal when every per-bucket average compares equal.
    bool TAverageHistogram::operator==(const TAverageHistogram& rhs) const noexcept {
        size_t idx = 0;
        for (const auto& element : Elements) {
            if (element != rhs.Elements[idx]) {
                return false;
            }
            ++idx;
        }
        return true;
    }
}

template <>
void Out<NNetmon::TSampleHistogram>(IOutputStream& stream,
                                    TTypeTraits<NNetmon::TSampleHistogram>::TFuncParam hist) {
    NJsonWriter::TBuf buf;
    hist.ToJson(buf);
    buf.FlushTo(&stream);
}

template <>
void Out<NNetmon::TConnectivityHistogram>(IOutputStream& stream,
                                          TTypeTraits<NNetmon::TConnectivityHistogram>::TFuncParam hist) {
    NJsonWriter::TBuf buf;
    hist.ToJson(buf);
    buf.FlushTo(&stream);
}

template <>
void Out<NNetmon::TAverageHistogram>(IOutputStream& stream,
                                     TTypeTraits<NNetmon::TAverageHistogram>::TFuncParam hist) {
    NJsonWriter::TBuf buf;
    hist.ToJson(buf);
    buf.FlushTo(&stream);
}
