#pragma once

#include <util/datetime/base.h>
#include <util/generic/guid.h>
#include <util/generic/hash_set.h>
#include <util/generic/variant.h>
#include <util/generic/vector.h>
#include <util/generic/yexception.h>
#include <util/system/guard.h>
#include <util/system/mutex.h>

#include <type_traits>

namespace NTravel {
    // Estimator of the number of bytes a value of type T occupies, including
    // sizeof(T) itself (GetByteSizeWithoutSizeof relies on this by subtracting
    // sizeof(item) from the result).
    //
    // The primary template is intentionally empty: for a type with no matching
    // specialization there is no operator(), so using it fails to compile.
    // Ways to support a new type: add a specialization, or give the type a
    // `size_t CalcTotalByteSize() const` member.
    template <class T, typename Enable = void>
    struct TTotalByteSize {
    };

    // Scalars (arithmetic, enum, pointer, ...) are counted by their sizeof only.
    template <class T>
    struct TTotalByteSize<T, std::enable_if_t<std::is_scalar_v<T>>> {
        inline size_t operator()(const T&) const {
            return sizeof(T);
        }
    };

    // Types that report their own size through a `CalcTotalByteSize()` member
    // returning size_t.
    // Detection is performed on a const lvalue, exactly like the call inside
    // operator(), so a type whose CalcTotalByteSize() is not const-callable does
    // not select this specialization (instead of failing to compile in its body).
    template <class T>
    struct TTotalByteSize<T, std::enable_if_t<std::is_same_v<decltype(std::declval<const T&>().CalcTotalByteSize()), size_t>>> {
        inline size_t operator()(const T& item) const {
            return item.CalcTotalByteSize();
        }
    };

    template <>
    struct TTotalByteSize<TString, void> {
        inline size_t operator()(const TString& item) const {
            return 32 + item.capacity();
        }
    };

    // std::monostate (the "no value" variant alternative) is counted by its sizeof only.
    template <>
    struct TTotalByteSize<std::monostate> {
        size_t operator()(const std::monostate& /*unused*/) const {
            constexpr size_t kMonostateBytes = sizeof(std::monostate);
            return kMonostateBytes;
        }
    };

    // std::variant with two alternatives: sizeof the variant itself plus whatever the
    // active alternative owns beyond its own sizeof (its in-place storage is already
    // part of sizeof(item)). Mirrors the three-alternative specialization.
    // The active alternative is fetched with std::get, which throws
    // std::bad_variant_access if the variant is valueless.
    template <class T1, class T2>
    struct TTotalByteSize<std::variant<T1, T2>, void> {
        inline size_t operator()(const std::variant<T1, T2>& item) const {
            if (std::holds_alternative<T1>(item)) {
                // Construct the functor, then call it (the functor has no argument constructor).
                return sizeof(item) + TTotalByteSize<T1>()(std::get<T1>(item)) - sizeof(T1);
            } else {
                return sizeof(item) + TTotalByteSize<T2>()(std::get<T2>(item)) - sizeof(T2);
            }
        }
    };

    // std::variant with three alternatives: sizeof the variant itself plus whatever the
    // active alternative owns beyond its own sizeof (its in-place storage is already
    // part of sizeof(item)).
    template <class T1, class T2, class T3>
    struct TTotalByteSize<std::variant<T1, T2, T3>, void> {
        inline size_t operator()(const std::variant<T1, T2, T3>& item) const {
            // Bytes owned by the active alternative beyond its in-place storage.
            size_t extra = 0;
            switch (item.index()) {
                case 0:
                    extra = TTotalByteSize<T1>()(std::get<0>(item)) - sizeof(T1);
                    break;
                case 1:
                    extra = TTotalByteSize<T2>()(std::get<1>(item)) - sizeof(T2);
                    break;
                default:
                    // Also reached for a valueless variant, where std::get throws.
                    extra = TTotalByteSize<T3>()(std::get<2>(item)) - sizeof(T3);
                    break;
            }
            return sizeof(item) + extra;
        }
    };

    // TGUID is counted by its own sizeof only.
    template <>
    struct TTotalByteSize<TGUID> {
        size_t operator()(const TGUID& /*guid*/) const {
            constexpr size_t kGuidBytes = sizeof(TGUID);
            return kGuidBytes;
        }
    };

    // std::pair: sizeof the pair itself (including any padding between or after the
    // members) plus whatever each member owns beyond its own sizeof. This follows the
    // same "container sizeof + members' extra bytes" scheme as the TMaybe and
    // std::variant specializations, so padding is not lost.
    template <class TFirst, class TSecond>
    struct TTotalByteSize<std::pair<TFirst, TSecond>, void> {
        inline size_t operator()(const std::pair<TFirst, TSecond>& item) const {
            // Padding inside the pair. It cannot be negative: first and second are
            // distinct, non-overlapping data members of the pair.
            const size_t padding = sizeof(item) - sizeof(TFirst) - sizeof(TSecond);
            return padding + TTotalByteSize<TFirst>()(item.first) + TTotalByteSize<TSecond>()(item.second);
        }
    };

    // TMaybe<T>: sizeof the TMaybe itself, plus, when a value is present, whatever
    // that value owns beyond its own sizeof.
    template <class T>
    struct TTotalByteSize<TMaybe<T>, void> {
        inline size_t operator()(const TMaybe<T>& maybe) const {
            size_t total = sizeof(maybe);
            if (maybe.Defined()) {
                total += TTotalByteSize<T>()(maybe.GetRef()) - sizeof(T);
            }
            return total;
        }
    };

    template <typename T>
    size_t GetTotalByteSize(const T& item) noexcept {
        return TTotalByteSize<T>()(item);
    }

    template <typename T>
    size_t GetByteSizeWithoutSizeof(const T& item) noexcept {
        return TTotalByteSize<T>()(item) - sizeof(item);
    }

    // Size of a TVector counting the vector object and its whole allocated element
    // buffer (capacity, not size), but not any memory owned by the elements themselves.
    template <typename T>
    size_t GetVectorByteSizeWithoutElementAllocations(const TVector<T>& vector) noexcept {
        const size_t bufferBytes = vector.capacity() * sizeof(T);
        return bufferBytes + sizeof(vector);
    }

    template <typename K, typename V>
    size_t GetHashMapByteSizeWithoutElementAllocations(const THashMap<K, V>& hashMap) noexcept {
        // Inspired by https://a.yandex-team.ru/arc/trunk/arcadia/kernel/gazetteer/common/bytesize.h?rev=r7864195#L256

        typedef typename THashMap<K, V>::value_type TValue; // value_type is pair<K, V>
        typedef __yhashtable_node<TValue> TNode; // TNode is ptr to next node + TValue

        const size_t bucketsOverhead = hashMap.bucket_count() * sizeof(TNode*); // hashMap.bucket_count() is size of inner array of TNode*
        const size_t valuesSize = hashMap.size() * sizeof(TNode); // each item creates TNode

        return sizeof(THashMap<K, V>) + valuesSize + bucketsOverhead;
    }

    template <typename TOne>
    size_t GetTotalByteSizeMulti(const TOne& one) noexcept {
        return TTotalByteSize<TOne>()(one);
    }

    template <typename THead, typename... TTail>
    size_t GetTotalByteSizeMulti(const THead& head, TTail&&... tail) noexcept {
        return GetTotalByteSizeMulti(tail...) + TTotalByteSize<THead>()(head);
    }
}
