#include "histograms.h"

#include "bit_stream.h"

#include <util/generic/algorithm.h>
#include <util/string/cast.h>

namespace NYasmServer {
    // Type codes emitted by the encoder right after the "type has changed" flag
    // and dispatched on by the decoder when it reads a full encoding.
    constexpr ui8 SIMPLE_EMPTY_CODE = 0b000;        // simple histogram, empty
    constexpr ui8 SIMPLE_SINGLE_VALUE_CODE = 0b001; // simple histogram, single value
    constexpr ui8 SIMPLE_ZERO_COUNT_CODE = 0b010;   // simple histogram, zero count only
    constexpr ui8 SIMPLE_NORMAL_CODE = 0b011;       // simple histogram, zero count + values
    constexpr ui8 LOGARITHMIC_CODE = 0b100;         // log histogram
    constexpr ui8 UGRAM_CODE = 0b101;               // user histogram (ugram)

    // Number of bits used to store a type code in the stream.
    constexpr size_t TYPE_CODE_SIZE_BITS = 3;

    // Decodes the next histogram record from `stream`, starting at `bitPosition`
    // (which is advanced past everything consumed), and returns a reference to
    // the decoder's LastValue.
    //
    // Record layout, mirroring THistogramEncoder::Write:
    //   1 bit - "value changed" flag; 0 means the previous value repeats.
    //   1 bit - "type changed" flag; 1 is followed by a 3-bit type code and a
    //           full encoding, 0 is followed by a partial encoding that updates
    //           LastValue in place.
    //
    // Throws yexception when the 3-bit type code is not one of the known codes.
    const THistogram& THistogramDecoder::Read(const TString& stream, size_t& bitPosition) {
        const bool valueChanged = ReadFromBitStream(stream, bitPosition, 1);
        if (!valueChanged) {
            // A "repeat" record is only meaningful when something was decoded before.
            Y_ASSERT(!LastValue.IsNull());
            return LastValue;
        }
        const bool typeChanged = ReadFromBitStream(stream, bitPosition, 1);

        if (typeChanged) {
            const ui8 typeCode = ReadFromBitStream(stream, bitPosition, TYPE_CODE_SIZE_BITS);
            switch (typeCode) {
                case SIMPLE_EMPTY_CODE:
                    // An empty simple histogram carries no payload.
                    LastValue = THistogram(TSimpleHistogram());
                    break;
                case SIMPLE_SINGLE_VALUE_CODE:
                    ReadFullSimpleSingleValue(stream, bitPosition);
                    break;
                case SIMPLE_ZERO_COUNT_CODE:
                    ReadFullSimpleZeroCount(stream, bitPosition);
                    break;
                case SIMPLE_NORMAL_CODE:
                    ReadFullSimpleNormal(stream, bitPosition);
                    break;
                case LOGARITHMIC_CODE:
                    ReadFullLog(stream, bitPosition);
                    break;
                case UGRAM_CODE:
                    ReadFullUser(stream, bitPosition);
                    break;
                default:
                    // Three bits can hold codes that are not assigned to any
                    // histogram type; fail loudly instead of silently returning
                    // the stale LastValue.
                    ythrow yexception() << "Unknown histogram type code [" << ToString(static_cast<ui32>(typeCode)) << "]";
            }
        } else {
            // Partial encoding is a delta against the previously decoded value.
            Y_ASSERT(!LastValue.IsNull());
            switch (LastValue.GetKind()) {
                case EHistogramKind::Simple:
                    switch (LastValue.AsSimpleHistogram().GetSimpleKind()) {
                        case TSimpleHistogram::EKind::Empty:
                            // The encoder never emits a partial update for an empty histogram.
                            Y_FAIL("This is impossible");
                            break;
                        case TSimpleHistogram::EKind::SingleValue:
                            ReadPartialSimpleSingleValue(stream, bitPosition);
                            break;
                        case TSimpleHistogram::EKind::ZeroCount:
                            ReadPartialSimpleZeroCount(stream, bitPosition);
                            break;
                        case TSimpleHistogram::EKind::Normal:
                            ReadPartialSimpleNormal(stream, bitPosition);
                            break;
                    }
                    break;
                case EHistogramKind::User:
                    ReadPartialUser(stream, bitPosition);
                    break;
                case EHistogramKind::Log:
                    ReadPartialLog(stream, bitPosition);
                    break;
            }
        }
        return LastValue;
    }

    void THistogramDecoder::ReadFullSimpleSingleValue(const TString& stream, size_t& bitPosition) {
        TSimpleHistogram hist;
        hist.MutableValues() = {ReadFromBitStreamUncompressed<double>(stream, bitPosition)};
        LastValue = THistogram(std::move(hist));
    }

    void THistogramDecoder::ReadFullSimpleZeroCount(const TString& stream, size_t& bitPosition) {
        TSimpleHistogram hist;
        hist.SetZeroCount(ReadCompressedUint64(stream, bitPosition));
        LastValue = THistogram(std::move(hist));
    }

    void THistogramDecoder::ReadFullSimpleNormal(const TString& stream, size_t& bitPosition) {
        TSimpleHistogram hist;

        bitPosition = AlignToByteBoundary(bitPosition);
        // actual data is aligned
        hist.SetZeroCount(AlignedReadCompressedUint64(stream, bitPosition));
        ui8 itemsCount = AlignedReadFromBitStreamUncompressed<ui8>(stream, bitPosition);
        hist.MutableValues().resize(itemsCount);
        for (size_t i = 0; i < itemsCount; i++) {
            hist.MutableValues()[i] = AlignedReadFromBitStreamUncompressed<double>(stream, bitPosition);
        }
        LastValue = THistogram(std::move(hist));
    }

    void THistogramDecoder::ReadFullLog(const TString& stream, size_t& bitPosition) {
        TLogHistogram hist;
        bitPosition = AlignToByteBoundary(bitPosition);
        hist.SetZeroCount(AlignedReadCompressedUint64(stream, bitPosition));
        hist.SetStartPower(AlignedReadFromBitStreamUncompressed<i16>(stream, bitPosition));

        ui8 itemsCount = AlignedReadFromBitStreamUncompressed<ui8>(stream, bitPosition);
        hist.MutableWeights().resize(itemsCount);
        for (size_t i = 0; i < itemsCount; i++) {
            hist.MutableWeights()[i] = AlignedReadCompressedUint64(stream, bitPosition);
        }
        LastValue = THistogram(std::move(hist));
    }

    void THistogramDecoder::ReadFullUser(const TString& stream, size_t& bitPosition) {
        TUserHistogram hist;
        // actual data is aligned
        bitPosition = AlignToByteBoundary(bitPosition);
        ui8 itemsCount = AlignedReadFromBitStreamUncompressed<ui8>(stream, bitPosition);
        hist.MutableBuckets().resize(itemsCount);
        for (size_t i = 0; i < itemsCount; i++) {
            hist.MutableBuckets()[i].LowerBound = AlignedReadFromBitStreamUncompressed<double>(stream, bitPosition);
        }
        for (size_t i = 0; i < itemsCount; i++) {
            hist.MutableBuckets()[i].Weight = AlignedReadCompressedUint64(stream, bitPosition);
        }
        LastValue = THistogram(std::move(hist));
    }

    // Partial update of a single-value simple histogram: overwrites the first
    // value of LastValue with a freshly read uncompressed double.
    void THistogramDecoder::ReadPartialSimpleSingleValue(const TString& stream, size_t& bitPosition) {
        auto& values = LastValue.MutableSimpleHistogram().MutableValues();
        values[0] = ReadFromBitStreamUncompressed<double>(stream, bitPosition);
    }

    // Partial update of a zero-count-only simple histogram: replaces the zero
    // count of LastValue with a freshly read compressed integer.
    void THistogramDecoder::ReadPartialSimpleZeroCount(const TString& stream, size_t& bitPosition) {
        auto& simple = LastValue.MutableSimpleHistogram();
        simple.SetZeroCount(ReadCompressedUint64(stream, bitPosition));
    }

    void THistogramDecoder::ReadPartialSimpleNormal(const TString& stream, size_t& bitPosition) {
        auto& last = LastValue.MutableSimpleHistogram();

        bool zeroCountChanged = ReadFromBitStream(stream, bitPosition, 1);
        if (zeroCountChanged) {
            last.SetZeroCount(ReadCompressedUint64(stream, bitPosition));
        }

        bool valuesChanged = ReadFromBitStream(stream, bitPosition, 1);
        if (valuesChanged) {
            // actual data is aligned
            bitPosition = AlignToByteBoundary(bitPosition);
            ui8 itemsCount = AlignedReadFromBitStreamUncompressed<ui8>(stream, bitPosition);
            last.MutableValues().resize(itemsCount);
            for (size_t i = 0; i < itemsCount; i++) {
                last.MutableValues()[i] = AlignedReadFromBitStreamUncompressed<double>(stream, bitPosition);
            }
        }
    }

    void THistogramDecoder::ReadPartialLog(const TString& stream, size_t& bitPosition) {
        auto& last = LastValue.MutableLogHistogram();

        bool zeroCountChanged = ReadFromBitStream(stream, bitPosition, 1);
        if (zeroCountChanged) {
            last.SetZeroCount(ReadCompressedUint64(stream, bitPosition));
        }

        bool startPowerChanged = ReadFromBitStream(stream, bitPosition, 1);
        if (startPowerChanged) {
            last.SetStartPower(ReadFromBitStreamUncompressed<i16>(stream, bitPosition));
        }

        bool weightsChanged = ReadFromBitStream(stream, bitPosition, 1);
        if (weightsChanged) {
            bitPosition = AlignToByteBoundary(bitPosition);
            ui8 itemsCount = AlignedReadFromBitStreamUncompressed<ui8>(stream, bitPosition);
            last.MutableWeights().resize(itemsCount);
            for (size_t i = 0; i < itemsCount; i++) {
                last.MutableWeights()[i] = AlignedReadCompressedUint64(stream, bitPosition);
            }
        }
    }

    void THistogramDecoder::ReadPartialUser(const TString& stream, size_t& bitPosition) {
        auto& last = LastValue.MutableUserHistogram();

        bool bordersChanged = ReadFromBitStream(stream, bitPosition, 1);
        // rest of the data is aligned
        bitPosition = AlignToByteBoundary(bitPosition);
        if (bordersChanged) {
            ui8 itemsCount = AlignedReadFromBitStreamUncompressed<ui8>(stream, bitPosition);
            last.MutableBuckets().resize(itemsCount);
            for (size_t i = 0; i < last.GetBuckets().size(); i++) {
                last.MutableBuckets()[i].LowerBound = AlignedReadFromBitStreamUncompressed<double>(stream, bitPosition);
            }
        }

        for (size_t i = 0; i < last.GetBuckets().size(); i++) {
            last.MutableBuckets()[i].Weight = AlignedReadCompressedUint64(stream, bitPosition);
        }
    }

    inline void THistogramEncoder::WriteVector(const TVector<ui64>& values, bool writeSize, TString& stream, ui32& bitPosition) {
        if (writeSize) {
            if (values.size() > 255) {
                throw yexception() << "More than 255 values in histogram";
            }
            // 8 bits for count
            AlignedAddToBitStreamUncompressed<ui8>(values.size(), stream, bitPosition);
        }
        for (ui64 value : values) {
            AlignedWriteCompressedUint64(value, stream, bitPosition);
        }
    }

    inline void THistogramEncoder::WriteVector(const TVector<double>& values, bool writeSize, TString& stream, ui32& bitPosition) {
        if (writeSize) {
            if (values.size() > 255) {
                throw yexception() << "More than 255 values in histogram";
            }
            // 8 bits for count
            AlignedAddToBitStreamUncompressed<ui8>(values.size(), stream, bitPosition);
        }
        for (double value : values) {
            // 64 bits for each double value
            AlignedAddToBitStreamUncompressed(value, stream, bitPosition);
        }
    }

    void THistogramEncoder::Write(const TSimpleHistogram& hist, TString& stream, ui32& bitPosition) {
        if (!LastValue.IsNull() &&
            LastValue.GetKind() == EHistogramKind::Simple &&
            LastValue.AsSimpleHistogram().GetSimpleKind() == hist.GetSimpleKind()) {
            // '0' - type hasn't changed, partial encoding
            AddToBitStream(0, 1, stream, bitPosition);
            // this should be impossible - empty histogram cannot change its value without changing its type
            Y_VERIFY(!hist.IsEmpty());

            const auto& lastSimple = LastValue.AsSimpleHistogram();
            if (hist.IsZeroCount()) {
                // 64 bits for zero count
                WriteCompressedUint64(hist.GetZeroCount(), stream, bitPosition);
            } else if (hist.IsSingleValue()) {
                // 64 bits for single value
                AddToBitStreamUncompressed<double>(hist.GetValues()[0], stream, bitPosition);
            } else {
                if (hist.GetZeroCount() == lastSimple.GetZeroCount()) {
                    // '0' - zero count hasn't changed
                    AddToBitStream(0, 1, stream, bitPosition);
                } else {
                    // '1' - zero count has changed
                    AddToBitStream(1, 1, stream, bitPosition);
                    WriteCompressedUint64(hist.GetZeroCount(), stream, bitPosition);
                }

                if (hist.GetValues() == lastSimple.GetValues()) {
                    // '0' - values haven't changed
                    AddToBitStream(0, 1, stream, bitPosition);
                } else {
                    // '1' - values changed
                    AddToBitStream(1, 1, stream, bitPosition);

                    bitPosition = AlignToByteBoundary(bitPosition);
                    WriteVector(hist.GetValues(), true, stream, bitPosition);
                }
            }
        } else {
            // '1' - type has changed, full encoding
            AddToBitStream(1, 1, stream, bitPosition);

            if (hist.IsEmpty()) {
                // 3 bits for typecode, no further info needed
                AddToBitStream(SIMPLE_EMPTY_CODE, TYPE_CODE_SIZE_BITS, stream, bitPosition);
            } else if (hist.IsZeroCount()) {
                // 3 bits for typecode
                AddToBitStream(SIMPLE_ZERO_COUNT_CODE, TYPE_CODE_SIZE_BITS, stream, bitPosition);
                WriteCompressedUint64(hist.GetZeroCount(), stream, bitPosition);
            } else if (hist.IsSingleValue()) {
                // 3 bits for typecode
                AddToBitStream(SIMPLE_SINGLE_VALUE_CODE, TYPE_CODE_SIZE_BITS, stream, bitPosition);
                // 64 bits for single value
                AddToBitStreamUncompressed<double>(hist.GetValues()[0], stream, bitPosition);
            } else {
                // 3 bits for typecode
                AddToBitStream(SIMPLE_NORMAL_CODE, TYPE_CODE_SIZE_BITS, stream, bitPosition);

                bitPosition = AlignToByteBoundary(bitPosition);
                AlignedWriteCompressedUint64(hist.GetZeroCount(), stream, bitPosition);
                WriteVector(hist.GetValues(), true, stream, bitPosition);
            }
        }
    }

    // Encodes a user histogram (ugram).
    //   Partial (previous value is also a user histogram): flag bit 0, then one
    //   bit telling whether the bucket bounds changed.
    //   Full: flag bit 1, then the 3-bit UGRAM_CODE.
    // In both cases the rest is byte-aligned: the bucket count and lower bounds
    // (omitted only when a partial encoding found them unchanged), followed by
    // one compressed weight per bucket.
    // Throws yexception when bounds have to be written for more than 255 buckets.
    void THistogramEncoder::Write(const TUserHistogram& hist, TString& stream, ui32& bitPosition) {
        const bool previousIsUser = !LastValue.IsNull() && LastValue.GetKind() == EHistogramKind::User;
        const auto& buckets = hist.GetBuckets();

        bool writeBounds = true;
        if (previousIsUser) {
            // Flag '0': same type as the previous value, partial encoding follows.
            AddToBitStream(0, 1, stream, bitPosition);

            const auto& previous = LastValue.AsUserHistogram();
            const bool sameBounds = buckets.size() == previous.GetBuckets().size() &&
                                    Equal(buckets.begin(), buckets.end(), previous.GetBuckets().begin(),
                                          [](const auto& lhs, const auto& rhs) { return lhs.LowerBound == rhs.LowerBound; });

            writeBounds = !sameBounds;
            // '1' when the bucket bounds changed, '0' otherwise.
            AddToBitStream(writeBounds ? 1 : 0, 1, stream, bitPosition);
        } else {
            // Flag '1': the type differs from the previous value, full encoding follows.
            AddToBitStream(1, 1, stream, bitPosition);
            AddToBitStream(UGRAM_CODE, TYPE_CODE_SIZE_BITS, stream, bitPosition);
        }

        // Everything below is written byte-aligned.
        bitPosition = AlignToByteBoundary(bitPosition);

        if (writeBounds) {
            if (buckets.size() > 255) {
                ythrow yexception() << "Ugram has more than 255 buckets [" << ToString(buckets.size()) << "]";
            }
            // The bucket count is stored in a single byte.
            AlignedAddToBitStreamUncompressed<ui8>(buckets.size(), stream, bitPosition);
            for (const auto& bucket : buckets) {
                // Each lower bound is stored as an uncompressed double.
                AlignedAddToBitStreamUncompressed<double>(bucket.LowerBound, stream, bitPosition);
            }
        }

        for (const auto& bucket : buckets) {
            AlignedWriteCompressedUint64(bucket.Weight, stream, bitPosition);
        }
    }

    // Encodes a log histogram.
    //   Full (previous value is not a log histogram): flag bit 1, the 3-bit
    //   LOGARITHMIC_CODE, then byte-aligned zero count, start power and weights.
    //   Partial: flag bit 0, then for zero count, start power and weights one
    //   "changed" bit each, followed by the new data when the bit is set.
    // Throws yexception when there are more than 255 weights.
    void THistogramEncoder::Write(const TLogHistogram& hist, TString& stream, ui32& bitPosition) {
        const bool previousIsLog = !LastValue.IsNull() && LastValue.GetKind() == EHistogramKind::Log;

        if (!previousIsLog) {
            // Flag '1': the type differs from the previous value, full encoding follows.
            AddToBitStream(1, 1, stream, bitPosition);
            AddToBitStream(LOGARITHMIC_CODE, TYPE_CODE_SIZE_BITS, stream, bitPosition);

            bitPosition = AlignToByteBoundary(bitPosition);
            AlignedWriteCompressedUint64(hist.GetZeroCount(), stream, bitPosition);
            AlignedAddToBitStreamUncompressed<i16>(hist.GetStartPower(), stream, bitPosition);

            WriteVector(hist.GetWeights(), true, stream, bitPosition);
            return;
        }

        // Flag '0': same type as the previous value, partial encoding follows.
        AddToBitStream(0, 1, stream, bitPosition);
        const auto& previous = LastValue.AsLogHistogram();

        const bool zeroCountSame = hist.GetZeroCount() == previous.GetZeroCount();
        AddToBitStream(zeroCountSame ? 0 : 1, 1, stream, bitPosition);
        if (!zeroCountSame) {
            WriteCompressedUint64(hist.GetZeroCount(), stream, bitPosition);
        }

        const bool startPowerSame = hist.GetStartPower() == previous.GetStartPower();
        AddToBitStream(startPowerSame ? 0 : 1, 1, stream, bitPosition);
        if (!startPowerSame) {
            AddToBitStreamUncompressed<i16>(hist.GetStartPower(), stream, bitPosition);
        }

        // The size limit is enforced whether or not the weights changed.
        if (hist.GetWeights().size() > 255) {
            ythrow yexception() << "Log hist has more than 255 weights [" << ToString(hist.GetWeights().size()) << "]";
        }

        const bool weightsSame = hist.GetWeights() == previous.GetWeights();
        AddToBitStream(weightsSame ? 0 : 1, 1, stream, bitPosition);
        if (!weightsSame) {
            // The weight list is written byte-aligned.
            bitPosition = AlignToByteBoundary(bitPosition);
            WriteVector(hist.GetWeights(), true, stream, bitPosition);
        }
    }

    // Encodes `value` relative to the previously written histogram. A single
    // '0' bit is written when the value equals LastValue; otherwise a '1' bit
    // is followed by the kind-specific encoding, and `value` becomes the new
    // LastValue.
    void THistogramEncoder::Write(THistogram value, TString& stream, ui32& bitPosition) {
        const bool unchanged = (LastValue == value);
        // '0' - value hasn't changed, '1' - value has changed in some way
        AddToBitStream(unchanged ? 0 : 1, 1, stream, bitPosition);
        if (unchanged) {
            return;
        }

        const auto kind = value.GetKind();
        if (kind == EHistogramKind::Simple) {
            Write(value.AsSimpleHistogram(), stream, bitPosition);
        } else if (kind == EHistogramKind::User) {
            Write(value.AsUserHistogram(), stream, bitPosition);
        } else if (kind == EHistogramKind::Log) {
            Write(value.AsLogHistogram(), stream, bitPosition);
        }

        // Remember the value so the next call can emit a delta against it.
        LastValue = std::move(value);
    }
} // namespace NYasmServer
