#pragma once

#include <infra/ebpf-agent/lib/utils.h>
#include <infra/ebpf-agent/lib/types.h>
#include <infra/ebpf-agent/lib/fd.h>

#include <infra/ebpf-agent/progs/include/net_stat.h>

#include <util/generic/hash_set.h>
#include <util/generic/vector.h>
#include <util/generic/hash.h>
#include <util/digest/multi.h>
#include <util/string/hex.h>

#include <libbpf.h>

// NOTE(review): not referenced anywhere in this header; presumably the maximum
// length of a congestion-control algorithm name buffer — confirm against users.
#define MAX_CONG_LEN 16

// Binary `+` for the network statistics structures. Only declared here; the
// definitions live in another translation unit.
// NOTE(review): presumably a field-wise sum of the two operands — confirm
// against the implementation.
struct net_stat operator+(const struct net_stat& a, const struct net_stat& b);
struct cgroup_net_stat operator+(const struct cgroup_net_stat& a, const struct cgroup_net_stat& b);
struct cgroup_net_stat_dc operator+(const struct cgroup_net_stat_dc& a, const struct cgroup_net_stat_dc& b);

// String rendering of per-cgroup statistics. The argument is taken by
// non-const reference, so const objects and temporaries cannot be passed.
TString ToString(struct cgroup_net_stat& stat);
TString ToString(struct cgroup_net_stat_dc& stat);

namespace NEbpfAgent {

    class TBpfProgram: public TNonCopyable {
    public:
        ~TBpfProgram();

        inline bool operator==(const struct bpf_prog_info& info) const noexcept {
            return Name == info.name && Type == info.type;
        }

        inline operator const TBpfFd&() const noexcept {
            return Fd;
        }

        bool IsOutdated(const struct bpf_prog_info& prog) const noexcept {
            return memcmp(Info.tag, prog.tag, Y_ARRAY_SIZE(prog.tag));
        }

        const TStringBuf GetName() const noexcept {
            return Name;
        }

        const TStringBuf GetMinKernelVersion() const noexcept {
            return MinKernelVersion;
        }

        bool IsEnabled() const noexcept {
            return Enabled;
        }

        void Load() const noexcept;
        TBpfFd GetMap(const TStringBuf& name) const noexcept;

        static std::size_t ListLoadedProgs(TVector<struct bpf_prog_info>* progs = nullptr);
        static inline std::size_t CountLoadedProgs() {
            return ListLoadedProgs(nullptr);
        }

        static const TVector<const TBpfProgram*>& KnownProgs() noexcept;

        static void EnableYttl() noexcept;

        // These fd are public so that daemon.cpp can access them
        static TBpfFd YttlBlacklistNetsMapFd;
        static TBpfFd TcpBytesCountersMapFd;

    protected:
        TBpfProgram(const char* name, const unsigned char* bytes, unsigned int size,
                    const char* minKernelVersion = "4.19", bool enabled = false) noexcept;

        static const TBpfFd& InitConfigMap() noexcept;
        static const TBpfFd& InitRtoMap() noexcept;
        static const TBpfFd& InitYaNetsMap() noexcept;
        static const TBpfFd& InitProjectIdMap() noexcept;
        static const TBpfFd& InitYttlBlacklistNetsMap() noexcept;
        static const TBpfFd& InitTcpBytesCountersMap() noexcept;

        typedef TVector<TVector<TString>> TNetCongArray;
        static void InitCongMap(TBpfFd&, const TNetCongArray&) noexcept;
        static const TBpfFd& InitBbCongMap() noexcept;
        static const TBpfFd& InitFbCongMap() noexcept;

        typedef const TBpfFd& (*TSharedMapInitFunc)();
        void TryReuseMap(const TStringBuf& name, TSharedMapInitFunc initFunc) const noexcept;

        static bool IsEnabledByConfig(const TStringBuf& name);
        static struct bpf_prog_info GetProgInfo(const TBpfFd& fd, TVector<ui32>* mapIds = nullptr);
        static struct bpf_map_info GetMapInfo(const TBpfFd& fd);

        static TBpfFd ConfigMapFd;
        static TBpfFd RtoMapFd;
        static TBpfFd YaNetsMapFd;
        static TBpfFd ProjectIdMapFd;
        static TBpfFd BbCongMapFd;
        static TBpfFd FbCongMapFd;

        const TStringBuf Name;
        const TStringBuf MinKernelVersion;
        bool Enabled;
        struct bpf_object* Obj;
        struct bpf_program* Prog;
        const TStringBuf Title;
        enum bpf_prog_type Type;
        enum bpf_attach_type ExpectedAttachType;
        enum bpf_attach_type AttachType;
        mutable TBpfFd Fd;
        mutable struct bpf_prog_info Info;
    };

    // BPF program that is attached to cgroups. Keeps a set of "children":
    // (program id, cgroup path) pairs describing where it is attached.
    //
    // The class has a virtual member (AddChild) and serves as a base for
    // several subclasses, so its destructor is virtual.
    class TCgroupBpfProgram: public TBpfProgram {
    public:
        // Forwards the program description to TBpfProgram and stores `attachFlags`.
        TCgroupBpfProgram(const char* name, const unsigned char* bytes, unsigned int size,
                          const char* minKernelVersion = "4.19", bool enabled = false,
                          int attachFlags = 0)
            : TBpfProgram(name, bytes, size, minKernelVersion, enabled)
            , AttachFlags(attachFlags)
        {
        }

        // Shorthand overload: min kernel version "4.19", disabled by default,
        // explicit attach flags.
        TCgroupBpfProgram(const char* name, const unsigned char* bytes, unsigned int size,
                          int attachFlags)
            : TBpfProgram(name, bytes, size, "4.19", false)
            , AttachFlags(attachFlags)
        {
        }

        // Virtual: the class is polymorphic and acts as a base class, so
        // destruction through a base pointer must reach the derived destructor.
        virtual ~TCgroupBpfProgram();

        void Attach(const TCgroupFd& cgFd) const;
        // Opens the cgroup at `path` and attaches to it.
        inline void Attach(const TString& path) const {
            Attach(OpenCgroup(path));
        }

        // Detaches every recorded child and forgets all of them.
        inline void Detach() const {
            for (const auto& child: Children) {
                child.Detach();
            }
            Children.clear();
        }

        // Records that the program with id `prog.id` is attached at `attachPath`.
        // An already recorded (id, path) pair is left untouched: emplace() on a
        // set does not insert when an equal element is present, so no separate
        // lookup is needed.
        virtual void AddChild(const struct bpf_prog_info& prog, const TString& attachPath) const {
            Children.emplace(TAttachedProgram(*this, prog.id, attachPath));
        }

        static void Attach(const TBpfFd& progFd, const TCgroupFd& cgFd, enum bpf_attach_type type, unsigned int flags = 0);
        static inline void Attach(const TBpfFd& progFd, const TString& path, enum bpf_attach_type type, unsigned int flags = 0) {
            Attach(progFd, OpenCgroup(path), type, flags);
        }

        static void Detach(const TBpfFd& progFd, const TCgroupFd& cgFd, enum bpf_attach_type type);
        static inline void Detach(const TBpfFd& progFd, const TString& path, enum bpf_attach_type type) {
            Detach(progFd, OpenCgroup(path), type);
        }

        // Programs attached to a cgroup for one attach type.
        static TVector<struct bpf_prog_info> ListAttachedProgs(const TCgroupFd& cgFd, enum bpf_attach_type type);
        static inline TVector<struct bpf_prog_info> ListAttachedProgs(const TString& path, enum bpf_attach_type type) {
            return ListAttachedProgs(OpenCgroup(path), type);
        }

        // Programs attached to a cgroup, grouped by attach type.
        static THashMap<enum bpf_attach_type, TVector<struct bpf_prog_info>> ListAttachedProgs(const TCgroupFd& cgFd);
        static inline THashMap<enum bpf_attach_type, TVector<struct bpf_prog_info>> ListAttachedProgs(const TString& path) {
            return ListAttachedProgs(OpenCgroup(path));
        }

        static void Check(const TVector<const TCgroupBpfProgram*>& knownProgs, const TCgroupFd& cgFd, bool repair);
        static void Check(const TVector<const TCgroupBpfProgram*>& knownProgs, const TString& path, bool repair) {
            Check(knownProgs, OpenCgroup(path), repair);
        }

        static const TCgroupBpfProgram& TcpTos();
        static const TCgroupBpfProgram& TclassLock();

    protected:
        // One attachment of the parent program: kernel program id + cgroup path.
        class TAttachedProgram {
        public:
            // Aborts (Y_VERIFY) on a zero id or an empty attach path.
            TAttachedProgram(const TCgroupBpfProgram& parent, ui32 id, const TString& attachPath)
                : Parent(parent)
                , Id(id)
                , AttachPath(attachPath)
            {
                Y_VERIFY(Id > 0);
                Y_VERIFY(AttachPath);
            }

            void Detach() const;

            // Combined hash of (Id, AttachPath).
            // NOTE(review): no operator== is declared, so THashSet presumably
            // compares elements through this conversion as well, i.e. two
            // attachments with colliding hashes would be treated as equal —
            // confirm and consider an explicit operator==.
            inline operator size_t() const {
                return MultiHash(Id, AttachPath);
            }

        private:
            const TCgroupBpfProgram& Parent;
            mutable ui32 Id;
            mutable TString AttachPath;
        };

        int AttachFlags;
        // Mutable so that the const Attach/Detach/AddChild methods can update it.
        mutable THashSet<TAttachedProgram> Children;
    };

    // Cgroup program specialization for TCP RTO handling. Objects are obtained
    // through TcpRto(); the constructor is private.
    class TTcpRtoBpfProgram: public TCgroupBpfProgram {
    public:
        // Accessor for the program instance (defined out of line).
        static const TTcpRtoBpfProgram& TcpRto();

        static bool IsAnyTcpCounterEnabled() noexcept;

        static void UpdateMetrics();

    private:
        // Uses the TCgroupBpfProgram overload that takes only attach flags,
        // i.e. min kernel version "4.19" and disabled by default.
        TTcpRtoBpfProgram(const char* name,
                          const unsigned char* bytes,
                          unsigned int size,
                          int attachFlags)
            : TCgroupBpfProgram(name, bytes, size, attachFlags)
        {
        }
    };

    // Cgroup program specialization for network statistics. Objects are
    // obtained through Rx() / Tx(); the constructor is private.
    class TNetStatBpfProgram: public TCgroupBpfProgram {
    public:
        // Accessors for the receive-side and transmit-side programs (defined out of line).
        static const TNetStatBpfProgram& Rx(bool forceEnable = false, bool allowMulti = false);
        static const TNetStatBpfProgram& Tx(bool forceEnable = false, bool allowMulti = false);

    private:
        // Always passes "4.19.91" as the minimal kernel version to the base class.
        TNetStatBpfProgram(const char* name,
                           const unsigned char* bytes,
                           unsigned int size,
                           bool enabled = false,
                           int attachFlags = 0)
            : TCgroupBpfProgram(name, bytes, size, "4.19.91", enabled, attachFlags)
        {
        }
    };

    class TNetStatDcBpfProgram: public TCgroupBpfProgram {
    public:
        static const TNetStatDcBpfProgram& Rx(bool forceEnable = false, bool allowMulti = false);
        static const TNetStatDcBpfProgram& Tx(bool forceEnable = false, bool allowMulti = false);

        struct cgroup_net_stat_dc GetStats() const;

        void AddChild(const struct bpf_prog_info& prog, const TString& attachPath) const override {
            TAttachedProgram attachedProg(*this, prog.id, attachPath);
            if (!Children.contains(attachedProg)) {
                Children.emplace(std::move(attachedProg));
                // reset StatMapFd to already used one
                if (!IsOutdated(prog) && prog.nr_map_ids) {
                    TVector<ui32> mapIds(prog.nr_map_ids);
                    (void)GetProgInfo(GetProgFdById(prog.id), &mapIds);
                    for (auto mapId: mapIds) {
                        auto mapFd = GetMapFdById(mapId);
                        auto mapInfo = GetMapInfo(mapFd);
                        if (!strncmp(mapInfo.name, "net_stat_dc_", 12)) {
                            StatMapFd = std::move(mapFd);
                            break;
                        }
                    }
                }
            }
        }

        static bool IsAnyEnabled() noexcept {
            return (Rx().Enabled || Tx().Enabled);
        }
        static void UpdateMetrics();

    private:
        TNetStatDcBpfProgram(const char* name, const unsigned char* bytes, unsigned int size,
                             bool enabled = false, int attachFlags = 0, bool rx = true)
            : TCgroupBpfProgram(name, bytes, size, "4.19.91", enabled, attachFlags)
            , StatMapFd(GetMap(rx ? "net_stat_dc_rx_map" : "net_stat_dc_tx_map"))
        {
            if (Enabled) {
                Y_VERIFY(StatMapFd >= 0);
            }
        }

        mutable TBpfFd StatMapFd;
    };

    class TCgroupRootBpfProgram: public TCgroupBpfProgram {
    public:
        static inline TVector<struct bpf_prog_info> ListAttachedProgs(enum bpf_attach_type type) {
            return TCgroupBpfProgram::ListAttachedProgs(FindCgroupRoot(), type);
        }

        static inline THashMap<enum bpf_attach_type, TVector<struct bpf_prog_info>> ListAttachedProgs() {
            return TCgroupBpfProgram::ListAttachedProgs(FindCgroupRoot());
        }

        static inline void DetachKnownProgs() {
            for (const auto* knownProg: KnownProgs()) {
                knownProg->Detach();
            }
        }

        static void Check(bool repair) {
            TCgroupBpfProgram::Check(KnownProgs(), FindCgroupRoot(), repair);
        }
        static TString JugglerCheck();

        static const TVector<const TCgroupBpfProgram*>& KnownProgs() noexcept;
    };

} // namespace NEbpfAgent

// Stream output for kernel program info: prints the id, type, name and
// hex-encoded tag, one "key value" pair per line (no trailing newline).
template <>
inline void Out<struct bpf_prog_info>(IOutputStream& stream, const struct bpf_prog_info& info) {
    const auto progType = static_cast<NEbpfAgent::EBpfProgType>(info.type);
    const auto tagHex = HexEncode(info.tag, sizeof(info.tag));

    stream << "id " << info.id;
    stream << "\ntype " << progType;
    stream << "\nname " << info.name;
    stream << "\ntag " << tagHex;
}
