#include <infra/ebpf-agent/lib/namespace.h>
#include <infra/ebpf-agent/lib/config.h>
#include <infra/ebpf-agent/lib/utils.h>
#include <infra/ebpf-agent/lib/error.h>
#include <infra/ebpf-agent/lib/log.h>
#include <infra/ebpf-agent/lib/nets.h>

#include <util/network/interface.h>
#include <util/stream/file.h>

#include <sys/resource.h>
#include <sys/utsname.h>
#include <string.h>
#include <errno.h>

namespace {

    template <typename T>
    T ReadSysctl(const TString& path) {
        TFileInput sysctl(path);
        T res;
        sysctl >> res;
        return res;
    }

    template <typename T>
    void WriteSysctl(const TString& path, T&& value) {
        TFileOutput sysctl(path);
        sysctl << value;
    }

    // Values read by OverrideNetSysctl before it overwrote them, keyed first by
    // the namespace inode number it was given and then by sysctl file path.
    // RestoreNetSysctls writes these values back.
    THashMap<ino_t, THashMap<TString, unsigned int>> SavedNetSysctls;

    // Records the current content of the sysctl file at `path` under `nsIno`
    // in SavedNetSysctls, then writes `value` to that file.
    void OverrideNetSysctl(ino_t nsIno, const TString& path, unsigned int value) {
        // Read first: if the read throws, nothing is recorded and nothing is written.
        const unsigned int previous = ReadSysctl<unsigned int>(path);
        SavedNetSysctls[nsIno][path] = previous;
        WriteSysctl(path, value);
    }

    // Writes back every sysctl value that OverrideNetSysctl recorded for `nsIno`.
    // Does nothing when no value was recorded for that namespace.
    void RestoreNetSysctls(ino_t nsIno) {
        // Look the namespace up with find(): operator[] would insert an empty
        // entry for a namespace that was never overridden.
        const auto saved = SavedNetSysctls.find(nsIno);
        if (saved == SavedNetSysctls.end()) {
            return;
        }
        for (const auto& [path, value]: saved->second) {
            WriteSysctl(path, value);
        }
    }

} // namespace

namespace NEbpfAgent {

    // Returns the `release` field reported by uname(), cached in a function-local
    // static after the first successful call. Aborts via Y_VERIFY if uname() fails.
    const TString& KernelVersion() {
        static TString cachedRelease;
        if (cachedRelease.empty()) {
            struct utsname info;
            Y_VERIFY(uname(&info) == 0);
            cachedRelease = info.release;
        }
        return cachedRelease;
    }

    int CompareKernelVersion(const TStringBuf& version) {
        return strverscmp(KernelVersion().c_str(), version.data());
    }

    // Returns the length of one jiffy in milliseconds, computed as 1000 / HZ.
    // HZ is taken to be 1000 when the kernel release compares >= "5.4" and 250
    // otherwise. The result is cached after the first call.
    // NOTE(review): the tick rate is derived from the version string only, not
    // queried from the running kernel.
    unsigned int GetJiffieMs() {
        static unsigned int cachedMs = 0;
        if (cachedMs == 0) {
            const unsigned int hz = (CompareKernelVersion("5.4") >= 0) ? 1000 : 250;
            cachedMs = 1000 / hz;
        }
        return cachedMs;
    }

    // True when the running kernel release compares >= "5.4.180-31".
    // The comparison runs once; later calls return the cached answer.
    bool IsExtendedTcpCountersSupported() {
        static const bool supported = (CompareKernelVersion("5.4.180-31") >= 0);
        return supported;
    }

    std::tuple<TStringBuf, ESignal> CheckPrerequisites(bool forceRun) {
        // check kernel version
        if (CompareKernelVersion("4.19") < 0) {
            return std::make_tuple("Kernel is too old, need at least 4.19", ESignal::OldKernel);
        }

        if (forceRun) {
            return std::make_tuple("", ESignal::Running);
        }

        // check leaked cgroup
        auto root = FindCgroupRoot(false);
        if (root) {
            auto stat = GetCgroupStat(root);
            if (stat.DyingDescendants > 10000) {
                return std::make_tuple("Too many (> 10000) dying descendants at root cgroup", ESignal::LeakedCgroups);
            }
        }

        return std::make_tuple("", ESignal::Running);
    }

    // Scans /proc/self/mountinfo and returns the mount point of the first entry
    // whose filesystem type is "cgroup2".
    //
    // Each line is expected to look like
    //   <mount id> <parent id> <major:minor> <root> <mount point> <options> [optional fields] - <fs type> ...
    // Lines that do not have this shape are skipped.
    //
    // When no cgroup2 mount is found: throws TCgroupError if throwOnError is set,
    // otherwise returns an empty string.
    TString FindCgroupRoot(bool throwOnError) {
        TFileInput input("/proc/self/mountinfo");
        TString line;
        while (input.ReadLine(line)) {
            // Skip the four leading fields: mount id, parent id, major:minor, root.
            size_t pointStart = 0;
            bool wellFormed = true;
            for (int skipped = 0; skipped < 4; ++skipped) {
                const auto space = line.find(' ', pointStart);
                if (space == TString::npos) {
                    wellFormed = false;
                    break;
                }
                pointStart = space + 1;
            }
            if (!wellFormed) {
                continue;
            }

            // Fifth field: the mount point.
            const auto pointEnd = line.find(' ', pointStart);
            if (pointEnd == TString::npos) {
                continue;
            }

            // The filesystem type follows the standalone "-" separator field.
            // Match the separator with its surrounding spaces so that a hyphen
            // inside another field is not mistaken for it.
            const auto separator = line.find(" - ", pointEnd);
            if (separator == TString::npos) {
                continue;
            }

            const auto typeStart = separator + 3;
            const auto typeEnd = line.find(' ', typeStart);
            const auto typeLen = (typeEnd == TString::npos) ? TString::npos : typeEnd - typeStart;
            if (line.substr(typeStart, typeLen) == "cgroup2") {
                return line.substr(pointStart, pointEnd - pointStart);
            }
        }

        if (throwOnError) {
            ythrow TCgroupError() << "Failed to find root cgroup";
        }

        return "";
    }

    // Returns the 8-byte file handle that name_to_handle_at() reports for `path`,
    // interpreted as a ui64.
    //
    // `path` is first resolved relative to the current working directory
    // (AT_FDCWD); if that fails it is resolved relative to whatever
    // OpenCgroupRoot() returns.
    //
    // Throws TCgroupError when both lookups fail or when the returned handle is
    // not exactly sizeof(ui64) bytes long.
    ui64 GetCgroupId(const TString& path) {
        // Storage for a struct file_handle header followed by sizeof(ui64) handle
        // bytes. Aligned explicitly because it is accessed through a
        // struct file_handle pointer.
        alignas(struct file_handle) char handleBuf[sizeof(struct file_handle) + sizeof(ui64)] = { 0 };
        auto* handle = reinterpret_cast<struct file_handle*>(handleBuf);
        int mountId;

        handle->handle_bytes = sizeof(ui64);
        if (name_to_handle_at(AT_FDCWD, path.c_str(), handle, &mountId, 0) < 0) {
            /* try at cgroup2 mount */
            // NOTE(review): if OpenCgroupRoot() returns a raw descriptor rather than
            // an owning wrapper, it is never closed here — confirm.
            auto rootFd = OpenCgroupRoot();

            // Reset the capacity before retrying; the failed call may have changed it.
            handle->handle_bytes = sizeof(ui64);
            if (name_to_handle_at(rootFd, path.c_str(), handle, &mountId, 0) < 0) {
                ythrow TCgroupError() << "Failed to get cgroup id";
            }
        }
        if (handle->handle_bytes != sizeof(ui64)) {
            ythrow TCgroupError() << "Invalid cgroup id size";
        }

        // Copy the opaque handle bytes into an integer without type punning.
        ui64 cgroupId = 0;
        memcpy(&cgroupId, handle->f_handle, sizeof(cgroupId));
        return cgroupId;
    }

    // Reads `<path>/cgroup.stat` line by line and fills TCgroupStat::Descendants
    // from the line starting with "nr_descendants" and
    // TCgroupStat::DyingDescendants from the line starting with
    // "nr_dying_descendants". Other lines are ignored.
    TCgroupStat GetCgroupStat(const TString& path) {
        // sizeof() of these arrays counts the trailing NUL, so `sizeof - 1` is the
        // key length and `sizeof` is the offset one character past the key,
        // i.e. just after the single character that follows the key.
        static constexpr char descendantsKey[] = "nr_descendants";
        static constexpr char dyingKey[] = "nr_dying_descendants";

        TCgroupStat result;
        TFileInput statFile(path + "/cgroup.stat");
        TString row;
        while (statFile.ReadLine(row)) {
            const char* text = row.c_str();
            if (strncmp(text, descendantsKey, sizeof(descendantsKey) - 1) == 0) {
                result.Descendants = FromString<ui64>(text + sizeof(descendantsKey));
            } else if (strncmp(text, dyingKey, sizeof(dyingKey) - 1) == 0) {
                result.DyingDescendants = FromString<ui64>(text + sizeof(dyingKey));
            }
        }
        return result;
    }

    // Determines the local datacenter from the host's network interfaces.
    //
    // Walks the interfaces returned by NAddr::GetNetworkInterfaces() and, for the
    // first IPv6 address whose leading 32-bit word equals YA_NETS, returns
    // GetDatacenter() of that address. When no address matches the result is
    // DC_UNKNOWN. A non-negative result is cached in a function-local static.
    EDatacenter GetMyDatacenter() {
        static int cachedDc = -1;

        if (cachedDc >= 0) {
            return static_cast<EDatacenter>(cachedDc);
        }

        cachedDc = DC_UNKNOWN;

        for (const auto& iface: NAddr::GetNetworkInterfaces()) {
            const struct sockaddr* rawAddr = iface.Address->Addr();
            if (rawAddr->sa_family != AF_INET6) {
                continue;
            }

            const struct sockaddr_in6* v6Addr = reinterpret_cast<const struct sockaddr_in6*>(rawAddr);
            const __u32* words = v6Addr->sin6_addr.s6_addr32;
            if (words[0] != YA_NETS) {
                continue;
            }

            INFO_LOG << "Match self datacenter by address " << PrintHost(*iface.Address) << Endl;
            cachedDc = GetDatacenter(words);
            break;
        }

        return static_cast<EDatacenter>(cachedDc);
    }

    // Maps a datacenter enumerator to its short lowercase name.
    // DC_SAS1 and DC_SAS2 both map to "sas"; any value not listed below yields
    // an empty string.
    TString DatacenterToString(EDatacenter dc) {
        const char* shortName = "";
        switch (dc) {
        case DC_SAS1:
        case DC_SAS2:
            shortName = "sas";
            break;
        case DC_VLA:
            shortName = "vla";
            break;
        case DC_VLX:
            shortName = "vlx";
            break;
        case DC_MAN:
            shortName = "man";
            break;
        case DC_IVA:
            shortName = "iva";
            break;
        case DC_MYT:
            shortName = "myt";
            break;
        default:
            break;
        }
        return shortName;
    }

    // Sets both the soft and the hard RLIMIT_MEMLOCK limit of the calling process
    // to `bytes`. Returns true when setrlimit() succeeds.
    bool SetMlockLimit(size_t bytes) {
        struct rlimit limit;
        limit.rlim_cur = bytes;
        limit.rlim_max = bytes;
        const int rc = setrlimit(RLIMIT_MEMLOCK, &limit);
        return rc == 0;
    }

    // For every namespace inode that TNetNamespace::ForEach passes to the callback:
    // writes 0 to net/ipv6/auto_flowlabels and writes back the sysctl values that
    // were recorded for that inode.
    // NOTE(review): presumably ForEach runs the callback with the corresponding
    // network namespace entered, so the /proc/sys/net paths refer to it — confirm.
    void CleanupNetNamespaces() {
        const auto cleanupOne = [](ino_t netNsInode) {
            // auto_flowlabels is set to a fixed value here, not to a recorded one.
            WriteSysctl("/proc/sys/net/ipv6/auto_flowlabels", 0);

            RestoreNetSysctls(netNsInode);
        };
        TNetNamespace::ForEach(cleanupOne);
    }

    void SetupNetNamespaces() {
        TNetNamespace::ForEach([](ino_t nsIno) {
            WriteSysctl("/proc/sys/net/ipv6/auto_flowlabels", 3);

            OverrideNetSysctl(nsIno, "/proc/sys/net/ipv4/tcp_syn_retries", Config().GetTcp().GetSynRetries());
            OverrideNetSysctl(nsIno, "/proc/sys/net/ipv4/tcp_synack_retries", Config().GetTcp().GetSynAckRetries());
            OverrideNetSysctl(nsIno, "/proc/sys/net/ipv4/tcp_orphan_retries", Config().GetTcp().GetOrphanRetries());
            OverrideNetSysctl(nsIno, "/proc/sys/net/ipv4/tcp_retries1", Config().GetTcp().GetRetries1());
            OverrideNetSysctl(nsIno, "/proc/sys/net/ipv4/tcp_retries2", Config().GetTcp().GetRetries2());

            if (TCgroupBpfProgram::TcpTos().IsEnabled()) {
                OverrideNetSysctl(nsIno, "/proc/sys/net/ipv4/tcp_ecn_fallback", 0);
            };
        });
    }

} // namespace NEbpfAgent
