#pragma once

#include <util/generic/bitops.h>
#include <util/stream/input.h>
#include <util/stream/output.h>

#include <climits>
#include <cstring>
#include <limits>
#include <tuple>
#include <utility>

namespace NNetmon {
    namespace {
        // Counts the zero bits above the most significant set bit of `value`.
        //
        // @param value  integer to inspect (NOTE(review): only unsigned types
        //               appear to be intended here; signed inputs are not guarded).
        // @return number of leading zero bits; for `value == 0` this is the full
        //         bit width of T (the zero case never reaches GetValueBitCount).
        template <class T>
        constexpr inline ui8 LeadingZeros(T value) noexcept {
            // The result is a *bit* count, so it is the bit width of T (not its
            // byte size) that has to fit into the ui8 return type.
            static_assert(sizeof(T) * CHAR_BIT <= std::numeric_limits<ui8>::max(), "too big value type");
            constexpr std::size_t length(sizeof(T) * CHAR_BIT);
            return static_cast<ui8>(value ? (length - GetValueBitCount(value)) : length);
        }

        // Counts the zero bits below the least significant set bit of `value`.
        //
        // @param value  integer to inspect.
        // @return number of trailing zero bits; for `value == 0` this is the full
        //         bit width of T (the zero case never reaches CountTrailingZeroBits).
        template <class T>
        constexpr inline ui8 TrailingZeros(T value) noexcept {
            // The result is a *bit* count, so it is the bit width of T (not its
            // byte size) that has to fit into the ui8 return type.
            static_assert(sizeof(T) * CHAR_BIT <= std::numeric_limits<ui8>::max(), "too big value type");
            constexpr std::size_t length(sizeof(T) * CHAR_BIT);
            return static_cast<ui8>(value ? CountTrailingZeroBits(value) : length);
        }

        // Turns a count of droppable zero bits into the length field kept in the
        // low bits of the flags nibble: (number of payload bytes) - 1.
        // `zeros` is rounded down to whole bytes. When every byte of T is zero
        // the field is 0, i.e. a single payload byte is still written.
        template <class T>
        inline ui8 EncodeCount(ui8 zeros) noexcept {
            const ui8 zeroBytes = zeros / CHAR_BIT;
            const ui8 payloadBytes(sizeof(T) - zeroBytes);
            if (payloadBytes == 0) {
                return 0;
            }
            return payloadBytes - 1;
        }

        // Number of bits a payload is shifted right when encoding (and left when
        // decoding) to move it back to its position inside T.
        // Only values whose flags carry the CHAR_BIT bit ("trailing zeros were
        // dropped") are shifted, by the number of bytes of T that were not stored.
        template <class T>
        inline ui8 ShiftCount(ui8 flags) noexcept {
            if ((flags & CHAR_BIT) == 0) {
                return 0;
            }
            const std::size_t droppedBytes = sizeof(T) - (flags & (CHAR_BIT - 1)) - 1;
            return droppedBytes * CHAR_BIT;
        }

        // Payload size in bytes described by a flags nibble. The bits selected by
        // (CHAR_BIT - 1) hold (bytes - 1), so the result lies in [1, CHAR_BIT].
        inline ui8 FlagsToBytes(ui8 flags) noexcept {
            const ui8 storedLengthField = flags & (CHAR_BIT - 1);
            return storedLengthField + 1;
        }

        // Compresses one value by dropping whole zero bytes from one of its ends.
        //
        // Returns the payload (value shifted so its significant bytes are the
        // lowest ones) and a 4-bit flags nibble:
        //   bits 0..2 - number of payload bytes minus one (see EncodeCount);
        //   bit  3    - set when the trailing zero bytes were dropped, in which
        //               case the decoder shifts the payload back left
        //               (see ShiftCount).
        // Trailing zeros are dropped only when they exceed the leading zeros by at
        // least CHAR_BIT bits; otherwise leading zeros are dropped. Because that
        // rule implies at least one dropped trailing byte, for types of at most
        // 8 bytes the nibble 0xF is never produced.
        template <class T>
        inline std::pair<T, ui8> EncodeOneValue(T value) noexcept {
            const ui8 leadingZeros = LeadingZeros(value);
            const ui8 trailingZeros = TrailingZeros(value);

            const bool dropTrailing =
                trailingZeros > leadingZeros && trailingZeros - leadingZeros >= CHAR_BIT;

            const ui8 lengthField = EncodeCount<T>(dropTrailing ? trailingZeros : leadingZeros);
            const ui8 modeBit = dropTrailing ? CHAR_BIT : 0;
            const ui8 flags = modeBit | (lengthField & (CHAR_BIT - 1));

            return {value >> ShiftCount<T>(flags), flags};
        }
    }

    // Outcome of VarintDecode.
    enum class EVarintStatus {
        END_OF_STREAM = 0, // no complete record: stream ended, record truncated or malformed
        ONE_VALUE = 1,     // record carried only the first value
        TWO_VALUES = 2     // record carried both values
    };

    // Appends one record holding two values to `stream`.
    //
    // Layout: a header byte whose low nibble describes `first` and whose high
    // nibble describes `second` (see EncodeOneValue), then the payload bytes of
    // `first`, then those of `second`. Payload bytes are copied from the start of
    // the object representation of T; they are the low-order bytes only on a
    // little-endian host (NOTE(review): confirm no big-endian targets).
    template <class T>
    inline void VarintEncode(IOutputStream& stream, T first, T second) {
        const std::pair<T, ui8> low = EncodeOneValue(first);
        const std::pair<T, ui8> high = EncodeOneValue(second);

        const ui8 header(low.second | (high.second << 4));
        stream.Write(reinterpret_cast<const char*>(&header), sizeof(header));
        stream.Write(reinterpret_cast<const char*>(&low.first), FlagsToBytes(low.second));
        stream.Write(reinterpret_cast<const char*>(&high.first), FlagsToBytes(high.second));
    }

    // Appends a record holding a single value to `stream`.
    //
    // The header's high nibble is set to 0xF. EncodeOneValue never emits that
    // nibble for types of at most 8 bytes, and VarintDecode treats it as
    // "no second value follows".
    template <class T>
    inline void VarintEncode(IOutputStream& stream, T first) {
        const std::pair<T, ui8> encoded = EncodeOneValue(first);
        const ui8 noSecondValueMarker = 0xF;

        const ui8 header(encoded.second | (noSecondValueMarker << 4));
        stream.Write(reinterpret_cast<const char*>(&header), sizeof(header));
        stream.Write(reinterpret_cast<const char*>(&encoded.first), FlagsToBytes(encoded.second));
    }

    // Encodes two doubles as the ui64 values holding their bit patterns.
    // The bits are copied with memcpy; accessing a double through a ui64&
    // (reinterpret_cast) violates the strict-aliasing rule and is undefined.
    template <>
    inline void VarintEncode<double>(IOutputStream& stream, double first, double second) {
        static_assert(sizeof(double) == sizeof(ui64), "too long value type");
        ui64 firstBits;
        ui64 secondBits;
        std::memcpy(&firstBits, &first, sizeof(firstBits));
        std::memcpy(&secondBits, &second, sizeof(secondBits));
        VarintEncode<ui64>(stream, firstBits, secondBits);
    }

    // Encodes one double as the ui64 value holding its bit pattern.
    // The bits are copied with memcpy; accessing a double through a ui64&
    // (reinterpret_cast) violates the strict-aliasing rule and is undefined.
    template <>
    inline void VarintEncode<double>(IOutputStream& stream, double first) {
        static_assert(sizeof(double) == sizeof(ui64), "too long value type");
        ui64 firstBits;
        std::memcpy(&firstBits, &first, sizeof(firstBits));
        VarintEncode<ui64>(stream, firstBits);
    }

    // Reads one record produced by VarintEncode from `stream`.
    //
    // @param stream  source; bytes are read with Load(), which keeps reading until
    //                the requested amount arrives or the stream ends (Read() may
    //                return a short count earlier, e.g. at a buffer boundary).
    // @param first   receives the first value.
    // @param second  receives the second value, or 0 when the record holds one.
    // @return ONE_VALUE or TWO_VALUES on success. END_OF_STREAM when the stream
    //         ends before a complete record is read, and also when a length field
    //         exceeds sizeof(T) (malformed input). On failure the bytes already
    //         consumed are lost and `first`/`second` may hold partial data.
    template <class T>
    inline EVarintStatus VarintDecode(IInputStream& stream, T& first, T& second) {
        static_assert(sizeof(T) <= sizeof(ui64), "too long value type");

        ui8 flags;
        if (stream.Load(reinterpret_cast<char*>(&flags), sizeof(ui8)) != sizeof(ui8)) {
            return EVarintStatus::END_OF_STREAM;
        }

        // Decodes one value described by a 4-bit flags nibble into `value`.
        // Returns false on a malformed length or a truncated payload.
        const auto readValue = [&stream](ui8 nibble, T& value) -> bool {
            value = 0;
            const ui8 length = FlagsToBytes(nibble);
            if (length > sizeof(T)) {
                return false;
            }
            if (stream.Load(reinterpret_cast<char*>(&value), length) != length) {
                return false;
            }
            value <<= ShiftCount<T>(nibble);
            return true;
        };

        if (!readValue(flags & 0x0F, first)) {
            return EVarintStatus::END_OF_STREAM;
        }

        second = 0;
        const ui8 secondFlags = (flags & 0xF0) >> 4;
        if (secondFlags == 0xF) {
            // Marker written by the single-value VarintEncode overload.
            return EVarintStatus::ONE_VALUE;
        }
        if (!readValue(secondFlags, second)) {
            return EVarintStatus::END_OF_STREAM;
        }
        return EVarintStatus::TWO_VALUES;
    }

    // Decodes a record written by VarintEncode<double>. The values are decoded as
    // ui64 bit patterns and copied into the doubles with memcpy; binding a double
    // to a ui64& (reinterpret_cast) violates strict aliasing and is undefined.
    // Arguments that VarintDecode<ui64> leaves untouched (e.g. on an early
    // END_OF_STREAM) keep their previous values, as before.
    template <>
    inline EVarintStatus VarintDecode<double>(IInputStream& stream, double& first, double& second) {
        static_assert(sizeof(double) == sizeof(ui64), "too long value type");
        ui64 firstBits;
        ui64 secondBits;
        // Seed with the current contents so untouched arguments round-trip unchanged.
        std::memcpy(&firstBits, &first, sizeof(firstBits));
        std::memcpy(&secondBits, &second, sizeof(secondBits));
        const EVarintStatus status = VarintDecode<ui64>(stream, firstBits, secondBits);
        std::memcpy(&first, &firstBits, sizeof(first));
        std::memcpy(&second, &secondBits, sizeof(second));
        return status;
    }
}
