#include <util/system/compiler.h>
#include <util/system/yassert.h>

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <memory>

#include <dlfcn.h>

namespace {

    // Per-thread flag that is true while RealAlloc's constructor is running on
    // this thread (it is raised before the symbol lookups and cleared after).
    // The exported hooks consult it: malloc/calloc switch to StaticAlloc, and
    // the hooks guarded by ENSURE_INITIALIZED() abort.
    thread_local bool initializing = false;

    // Table of pointers to the next definitions of the allocation functions in
    // the dynamic symbol search order, looked up with dlsym(RTLD_NEXT, ...).
    // The exported hooks in this file forward to these pointers.
    struct RealAlloc {
        using MallocFn = void* (*)(size_t);
        using FreeFn = void (*)(void*);
        using CallocFn = void* (*)(size_t, size_t);
        using ReallocFn = void* (*)(void*, size_t);

        using PosixMemalignFn = int (*)(void**, size_t, size_t);
        // Previous (misspelled) name of PosixMemalignFn, kept as an alias so
        // that any remaining reference to it still compiles.
        using PoxisMemalignFn = PosixMemalignFn;
        using AlignedAllocFn = void* (*)(size_t, size_t);
        using VallocFn = void* (*)(size_t);
        using MemalignFn = void* (*)(size_t, size_t);
        using PvallocFn = void* (*)(size_t);

        // Resolves every entry point. `initializing` stays raised for the
        // whole lookup so the hooks can detect that they were entered from
        // inside it on this thread.
        RealAlloc() {
            initializing = true;

            malloc = loadSymbol<MallocFn>("malloc");
            free = loadSymbol<FreeFn>("free");
            calloc = loadSymbol<CallocFn>("calloc");
            realloc = loadSymbol<ReallocFn>("realloc");

            posix_memalign = loadSymbol<PosixMemalignFn>("posix_memalign");
            aligned_alloc = loadSymbol<AlignedAllocFn>("aligned_alloc");
            valloc = loadSymbol<VallocFn>("valloc");
            memalign = loadSymbol<MemalignFn>("memalign");
            pvalloc = loadSymbol<PvallocFn>("pvalloc");

            initializing = false;
        }

        // Looks up `name` with dlsym(RTLD_NEXT, ...) and returns the address
        // cast to the function pointer type Fn.
        //
        // Never returns on failure: prints a message naming the symbol to
        // stderr and aborts when the lookup fails, when dladdr() cannot
        // describe the address, or when the symbol comes from an object whose
        // path contains "/ld-linux".
        template <class Fn>
        static Fn loadSymbol(const char* name) {
            void* const sym = dlsym(RTLD_NEXT, name);
            if (!sym) {
                fprintf(stderr, "panic: dlsym() failed for symbol %s\n", name);
                std::abort();
            }

            Dl_info info;
            if (!dladdr(sym, &info)) {
                fprintf(stderr, "panic: dladdr() failed for symbol %s\n", name);
                std::abort();
            }

            // NOTE(review): presumably this rejects the dynamic loader's own
            // internal allocator, which must not be used as the "real" one --
            // confirm. The null check keeps strstr() safe should dladdr()
            // leave the path unset.
            if (info.dli_fname && strstr(info.dli_fname, "/ld-linux")) {
                fprintf(stderr, "panic: dlsym() returned symbol %s from linker\n", name);
                std::abort();
            }

            return reinterpret_cast<Fn>(sym);
        }

        // Resolved entry points; all of them are assigned by the constructor.
        MallocFn malloc = nullptr;
        FreeFn free = nullptr;
        CallocFn calloc = nullptr;
        ReallocFn realloc = nullptr;

        PosixMemalignFn posix_memalign = nullptr;
        AlignedAllocFn aligned_alloc = nullptr;
        VallocFn valloc = nullptr;
        MemalignFn memalign = nullptr;
        PvallocFn pvalloc = nullptr;

        // Returns the process-wide table, constructing it (and therefore
        // resolving the symbols) on first use.
        static RealAlloc& instance() {
            static RealAlloc instance;
            return instance;
        }
    };

    // Bump allocator over a fixed in-object buffer of N bytes. Blocks are
    // handed out front to back and are never reclaimed; contains() lets a
    // caller recognise pointers that were handed out by this object.
    class StaticAlloc {
    public:
        // Capacity of the backing buffer in bytes.
        static constexpr size_t N = 65536;

        StaticAlloc()
            : Ptr_(Buffer_)
            , Size_(N)
        {
        }

        // Returns `size` bytes aligned to alignof(std::max_align_t), or
        // nullptr when `size` is zero or the remaining space cannot hold the
        // request. A failed request leaves the allocator unchanged.
        void* alloc(size_t size) {
            if (size == 0) {
                return nullptr;
            }

            // std::align advances Ptr_ to the aligned address and charges the
            // padding to Size_; on failure it modifies neither.
            void* const block = std::align(alignof(std::max_align_t), size, Ptr_, Size_);
            if (block == nullptr) {
                return nullptr;
            }

            Ptr_ = static_cast<char*>(block) + size;
            Size_ -= size;

            return block;
        }

        // True when `ptr` lies in the part of the buffer handed out so far,
        // i.e. in [Buffer_, Ptr_).
        //
        // `ptr` is usually unrelated to Buffer_, and the built-in relational
        // operators give unspecified results for unrelated pointers, so the
        // comparison goes through std::less, which guarantees a total order.
        inline bool contains(const void* ptr) const {
            const std::less<const void*> before{};
            return !before(ptr, Buffer_) && before(ptr, Ptr_);
        }

        // Returns the process-wide instance, constructed on first use.
        static StaticAlloc& instance() {
            static StaticAlloc instance;
            return instance;
        }

    private:
        char Buffer_[N]; // backing storage
        void* Ptr_;      // first byte not handed out yet
        size_t Size_;    // bytes left between Ptr_ and the end of Buffer_
    };

} // namespace

// Guard for hooks that have no fallback while the real allocator symbols are
// being resolved: if `initializing` is set on the calling thread, prints the
// name of the enclosing function to stderr and aborts.
// The expansion is a complete `if` statement; the call sites in this file
// invoke it without a trailing semicolon.
#define ENSURE_INITIALIZED()                                                            \
    if (Y_UNLIKELY(initializing)) {                                                     \
        fprintf(stderr, "panic: %s() called while initializing hooks\n", __FUNCTION__); \
        std::abort();                                                                   \
    }

extern "C" {

    // malloc() hook. While the real symbols are still being resolved on this
    // thread the request is served from StaticAlloc; otherwise it is forwarded
    // to the resolved malloc.
    void* malloc(size_t size) {
        return Y_UNLIKELY(initializing)
                   ? StaticAlloc::instance().alloc(size)
                   : RealAlloc::instance().malloc(size);
    }

    // free() hook. Null pointers and blocks that were handed out by
    // StaticAlloc are ignored (that allocator never reclaims memory);
    // everything else is passed to the resolved free.
    void free(void* ptr) {
        if (ptr == nullptr || Y_UNLIKELY(StaticAlloc::instance().contains(ptr))) {
            return;
        }

        RealAlloc::instance().free(ptr);
    }

    // calloc() hook. While the real symbols are still being resolved on this
    // thread the zeroed block is carved out of StaticAlloc; otherwise the call
    // is forwarded to the resolved calloc.
    //
    // Returns nullptr in the bootstrap path when nmemb * size does not fit in
    // size_t, when the product is zero, or when StaticAlloc is exhausted.
    void* calloc(size_t nmemb, size_t size) {
        if (Y_UNLIKELY(initializing)) {
            // Without this check an overflowing product would wrap around and
            // yield a block smaller than the caller asked for.
            if (size != 0 && nmemb > SIZE_MAX / size) {
                return nullptr;
            }

            const size_t bytes = nmemb * size;
            void* p = StaticAlloc::instance().alloc(bytes);
            if (!p) {
                return nullptr;
            }
            memset(p, 0, bytes);
            return p;
        }
        return RealAlloc::instance().calloc(nmemb, size);
    }

    void* realloc(void* ptr, size_t size) {
        ENSURE_INITIALIZED()
        return RealAlloc::instance().realloc(ptr, size);
    }

    int posix_memalign(void** memptr, size_t alignment, size_t size) {
        ENSURE_INITIALIZED()
        return RealAlloc::instance().posix_memalign(memptr, alignment, size);
    }

    void* aligned_alloc(size_t alignment, size_t size) {
        ENSURE_INITIALIZED()
        return RealAlloc::instance().aligned_alloc(alignment, size);
    }

    void* valloc(size_t size) {
        ENSURE_INITIALIZED()
        return RealAlloc::instance().valloc(size);
    }

    void* memalign(size_t alignment, size_t size) {
        ENSURE_INITIALIZED()
        return RealAlloc::instance().memalign(alignment, size);
    }

    void* pvalloc(size_t size) {
        ENSURE_INITIALIZED()
        return RealAlloc::instance().pvalloc(size);
    }
}
