#include "netlink_monitor.h"

#include <sys/types.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <netdb.h>
#include <string.h>
#include <arpa/inet.h>
#include <linux/rtnetlink.h>

#include <array>
#include <iostream>
#include <string.h>

#include <yandex_io/libs/logging/logging.h>
#include <yandex_io/libs/net/netlink_base.h>

#include <ifaddrs.h>

namespace {
    /* netlink implementation */

    /**
     * Formats a raw IPv4/IPv6 address (as found in an IFA_ADDRESS/IFA_LOCAL payload) as text.
     *
     * @param family address family of @p data; only AF_INET and AF_INET6 are handled.
     * @param data   pointer to the raw address bytes.
     * @param sz     number of bytes available at @p data.
     * @return the textual address, or an empty string when the family is unsupported,
     *         @p sz does not match the family's address size, or inet_ntop fails.
     */
    std::string addrToString(int family, const void* data, size_t sz) {
        // INET6_ADDRSTRLEN is large enough for the text form of both address families.
        char buf[INET6_ADDRSTRLEN];
        const char* text = nullptr;
        switch (family) {
            case AF_INET6: {
                in6_addr addr;
                if (sz == sizeof(addr)) {
                    // Copy into a local, correctly typed object before handing it to inet_ntop.
                    memcpy(&addr, data, sz);
                    text = inet_ntop(family, &addr, buf, sizeof(buf));
                }
                break;
            }
            case AF_INET: {
                in_addr addr;
                if (sz == sizeof(addr.s_addr)) {
                    memcpy(&addr.s_addr, data, sz);
                    text = inet_ntop(family, &addr, buf, sizeof(buf));
                }
                break;
            }
            default:
                break;
        }
        // inet_ntop returns nullptr on failure; building a std::string from nullptr is undefined behaviour.
        return text != nullptr ? std::string(text) : std::string();
    }

    /**
     * NetlinkMonitor implementation on top of a NETLINK_ROUTE socket provided by NetlinkBase.
     *
     * Issues dump requests for addresses and links (and optionally link statistics),
     * decodes RTM_NEWADDR / RTM_DELADDR / RTM_NEWLINK / RTM_DELLINK / RTM_NEWSTATS
     * messages and forwards each decoded record to the Handler passed at construction.
     */
    class NetlinkMonitorImpl: public quasar::net::NetlinkMonitor, quasar::net::NetlinkBase {
        using Handler = quasar::net::NetlinkMonitor::Handler;
        using HandledRequest = NetlinkBase::HandledRequest;

        // Time of the last RTM_GETSTATS dump request; drives the 60 s polling interval in monitor().
        std::chrono::steady_clock::time_point statsRequested = std::chrono::steady_clock::now();
        // Receiver of all decoded events. Stored by reference, so it must outlive this object.
        Handler& hndl;

        // Translates the IFF_* bits of ifinfomsg::ifi_flags into IfFlags.
        static IfFlags fromIfiFlags(uint32_t ifi_flags) {
            return IfFlags{
                .up = (bool)(ifi_flags & IFF_UP),
                .running = (bool)(ifi_flags & IFF_RUNNING),
                .loopback = (bool)(ifi_flags & IFF_LOOPBACK),
                .multicast = (bool)(ifi_flags & IFF_MULTICAST),
                .p2p = (bool)(ifi_flags & IFF_POINTOPOINT),
                .dynamic = (bool)(ifi_flags & IFF_DYNAMIC)};
        }

        /**
         * Length of the rtattr area following a fixed family header of @p hdrSize bytes,
         * or -1 when the message is too short to contain that header at all.
         */
        static int attributesLength(const struct nlmsghdr* nlh, size_t hdrSize) {
            if (nlh->nlmsg_len < NLMSG_SPACE(hdrSize)) {
                return -1;
            }
            return static_cast<int>(NLMSG_PAYLOAD(nlh, hdrSize));
        }

        /**
         * Reads an rtnl_link_stats64 attribute payload of any size.
         *
         * Copies at most sizeof(rtnl_link_stats64) bytes and leaves the remaining fields
         * zeroed, so a kernel whose struct is shorter (older) or longer (newer) than the one
         * in our headers never causes an out-of-bounds read.
         */
        static rtnl_link_stats64 readLinkStats(struct rtattr* attr) {
            rtnl_link_stats64 st{};
            const size_t payload = RTA_PAYLOAD(attr);
            memcpy(&st, RTA_DATA(attr), payload < sizeof(st) ? payload : sizeof(st));
            return st;
        }

        /**
         * Invokes @p cb for every IFA_LOCAL / IFA_ADDRESS attribute of an IPv4/IPv6 address message.
         *
         * @param msg address message header.
         * @param len byte length of the attribute area following @p msg.
         * @param cb  called as cb(ifindex, addressText, family, scope, isLocal).
         */
        template <typename Callback_>
        void iterateAddresses(ifaddrmsg* msg, int len, Callback_ cb) {
            if (msg->ifa_family != AF_INET && msg->ifa_family != AF_INET6) {
                return;
            }
            // Signed counter: RTA_NEXT subtracts the *aligned* attribute length, so with an
            // unsigned counter an unpadded tail would wrap around instead of ending the loop.
            int remaining = len;
            for (struct rtattr* a = IFA_RTA(msg); RTA_OK(a, remaining); a = RTA_NEXT(a, remaining)) {
                bool local = false;
                switch (a->rta_type) {
                    // case IFA_FLAGS: // IFF_NOARP flag is here
                    case IFA_LOCAL:
                        local = true;
                        [[fallthrough]]; // report the local address through the same callback
                    case IFA_ADDRESS:
                        cb(msg->ifa_index, addrToString(msg->ifa_family, RTA_DATA(a), RTA_PAYLOAD(a)), msg->ifa_family, Scope(msg->ifa_scope), local);
                        break;
                    default:
                        break;
                }
            }
        }

        /**
         * Decodes the attributes of a link message: name, hardware address and 64-bit stats.
         *
         * @param ifi_ptr  link message header.
         * @param attr_len byte length of the attribute area following @p ifi_ptr.
         */
        void iterateLinks(ifinfomsg* ifi_ptr, int attr_len) {
            IfFlags flags = fromIfiFlags(ifi_ptr->ifi_flags);
            int remaining = attr_len;
            for (struct rtattr* attr_ptr = IFLA_RTA(ifi_ptr); RTA_OK(attr_ptr, remaining); attr_ptr = RTA_NEXT(attr_ptr, remaining)) {
                const size_t payloadSize = RTA_PAYLOAD(attr_ptr);
                switch (attr_ptr->rta_type) {
                    case IFLA_IFNAME: {
                        // The name is NUL-terminated; strnlen drops the terminator without
                        // underflowing on an empty payload or overrunning a missing terminator.
                        const char* name = static_cast<const char*>(RTA_DATA(attr_ptr));
                        hndl.onLink(ifi_ptr->ifi_index, std::string(name, strnlen(name, payloadSize)), flags);
                        break;
                    }
                    case IFLA_ADDRESS: {
                        // Hardware address rendered as colon-separated two-digit hex octets.
                        auto* bytes = static_cast<unsigned char*>(RTA_DATA(attr_ptr));
                        std::string mac;
                        for (size_t i = 0; i < payloadSize; ++i) {
                            if (i != 0) {
                                mac += ':';
                            }
                            mac.append(quasar::net::hexchar(bytes[i]).data(), 2);
                        }
                        hndl.onMac(ifi_ptr->ifi_index, mac);
                        break;
                    }
                    case IFLA_STATS64: {
                        rtnl_link_stats64 st = readLinkStats(attr_ptr);
                        hndl.onStats(ifi_ptr->ifi_index, st);
                        break;
                    }
                    default:
                        break;
                }
            }
        }

        /**
         * Reports every IFLA_STATS_LINK_64 attribute of an RTM_NEWSTATS message for interface @p idx.
         *
         * @param idx      interface index taken from the if_stats_msg header.
         * @param attr_ptr first attribute after the header.
         * @param attr_len byte length of the attribute area.
         */
        void iterateStats(int idx, struct rtattr* attr_ptr, int attr_len) {
            int remaining = attr_len;
            for (; RTA_OK(attr_ptr, remaining); attr_ptr = RTA_NEXT(attr_ptr, remaining)) {
                if (attr_ptr->rta_type == IFLA_STATS_LINK_64) {
                    rtnl_link_stats64 st = readLinkStats(attr_ptr);
                    hndl.onStats(idx, st);
                }
            }
        }

        // Dispatches one rtnetlink message; messages too short for their family header are ignored.
        void processRtmMessage(const struct nlmsghdr* nlmsg_ptr) {
            switch (nlmsg_ptr->nlmsg_type) {
                case RTM_NEWADDR: {
                    const int attrLen = attributesLength(nlmsg_ptr, sizeof(struct ifaddrmsg));
                    if (attrLen >= 0) {
                        iterateAddresses(static_cast<ifaddrmsg*>(NLMSG_DATA(nlmsg_ptr)), attrLen,
                                         [this](int idx, std::string addr, int family, Scope scope, bool local) {
                                             hndl.onAddress(idx, std::move(addr), family, scope, local);
                                         });
                    }
                    break;
                }
                case RTM_DELADDR: {
                    const int attrLen = attributesLength(nlmsg_ptr, sizeof(struct ifaddrmsg));
                    if (attrLen >= 0) {
                        iterateAddresses(static_cast<ifaddrmsg*>(NLMSG_DATA(nlmsg_ptr)), attrLen,
                                         [this](int idx, std::string addr, int family, Scope scope, bool local) {
                                             hndl.onAddressRemove(idx, std::move(addr), family, scope, local);
                                         });
                    }
                    break;
                }
                case RTM_NEWLINK: {
                    const int attrLen = attributesLength(nlmsg_ptr, sizeof(struct ifinfomsg));
                    if (attrLen >= 0) {
                        iterateLinks(static_cast<ifinfomsg*>(NLMSG_DATA(nlmsg_ptr)), attrLen);
                    }
                    break;
                }
                case RTM_DELLINK: {
                    if (attributesLength(nlmsg_ptr, sizeof(struct ifinfomsg)) >= 0) {
                        const auto* ifmsg = static_cast<const ifinfomsg*>(NLMSG_DATA(nlmsg_ptr));
                        hndl.onLinkRemoved(ifmsg->ifi_index);
                    }
                    break;
                }
                case RTM_NEWSTATS: {
                    const int attrLen = attributesLength(nlmsg_ptr, sizeof(struct if_stats_msg));
                    if (attrLen >= 0) {
                        auto* ifsm = static_cast<if_stats_msg*>(NLMSG_DATA(nlmsg_ptr));
                        auto* firstAttr = reinterpret_cast<struct rtattr*>(reinterpret_cast<char*>(ifsm) + NLMSG_ALIGN(sizeof(*ifsm)));
                        iterateStats(ifsm->ifindex, firstAttr, attrLen);
                    }
                    break;
                }
                default:
                    break;
            }
        }

    public:
        /**
         * @param h receiver of decoded events; held by reference and must outlive this monitor.
         *
         * The NetlinkBase arguments are the protocol (NETLINK_ROUTE) and, presumably, the
         * multicast groups to join (link and IPv4/IPv6 address changes) — see NetlinkBase.
         */
        explicit NetlinkMonitorImpl(Handler& h)
            : NetlinkBase(NETLINK_ROUTE, RTMGRP_LINK | RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR)
            , hndl(h)
        {
        }

        void stop() override {
            NetlinkBase::stop();
        }

        /**
         * Runs NetlinkBase::monitorImpl with initial RTM_GETADDR and RTM_GETLINK dump requests.
         *
         * When @p pollStats is set, an RTM_GETSTATS dump is issued first, and the callback
         * passed to monitorImpl re-issues it whenever at least 60 s have elapsed since the
         * previous request.
         *
         * @param indefinitely forwarded unchanged to monitorImpl.
         * @param pollStats    enable link statistics requests.
         */
        void monitor(bool indefinitely, bool pollStats) override {
            auto msgHndl = [this](const struct nlmsghdr* msg) { processRtmMessage(msg); };
            std::vector<HandledRequest> requests = {
                {[this]() { request(RTM_GETADDR); }, msgHndl},
                {[this]() { request(RTM_GETLINK); }, msgHndl},
            };

            if (pollStats) {
                requests.insert(requests.begin(), {[this]() { statsRequest(RTM_GETSTATS); }, msgHndl});
                monitorImpl(indefinitely,
                            std::move(requests),
                            [this]() {
                                if (std::chrono::steady_clock::now() - statsRequested >= std::chrono::seconds(60)) {
                                    statsRequest(RTM_GETSTATS);
                                }
                                // NOTE(review): meaning of the return value is defined by monitorImpl.
                                return false;
                            });
            } else {
                monitorImpl(indefinitely,
                            std::move(requests),
                            nullptr);
            }
        }

        /**
         * Sends an NLM_F_DUMP request of type @p rtmRequest with an rtgenmsg body.
         *
         * @param rtmRequest rtnetlink message type, e.g. RTM_GETADDR or RTM_GETLINK.
         * @param reqFamily  address family filter written to rtgen_family.
         */
        void request(int rtmRequest, int reqFamily = AF_UNSPEC) {
            struct NlRequest {
                struct nlmsghdr hdr;
                struct rtgenmsg gen;
            } req{}; // value-initialized so no uninitialized bytes are sent to the kernel

            setupHeader(req, rtmRequest);
            req.gen.rtgen_family = reqFamily;
            req.hdr.nlmsg_flags |= NLM_F_DUMP;

            sendRequest(req);
        }

        /**
         * Sends an NLM_F_DUMP statistics request with every filter_mask bit set and
         * records the send time in statsRequested.
         *
         * @param rtmRequest rtnetlink message type (RTM_GETSTATS).
         * @param reqFamily  address family written to if_stats_msg::family.
         */
        void statsRequest(int rtmRequest, int reqFamily = AF_UNSPEC) {
            struct NlStatsRequest {
                struct nlmsghdr hdr;
                struct if_stats_msg gen;
            } req{}; // zeroes pad1/pad2/ifindex, which kernels with strict checking require for dumps

            setupHeader(req, rtmRequest);
            req.hdr.nlmsg_flags |= NLM_F_DUMP;
            req.gen.family = reqFamily;
            req.gen.filter_mask = 0xffffffff;

            sendRequest(req);
            statsRequested = std::chrono::steady_clock::now();
        }
    };
} // namespace

namespace quasar::net {
    /**
     * Compact textual form of the flags: one letter per set flag, always in the order
     * U (up), R (running), L (loopback), M (multicast), P (point-to-point), D (dynamic).
     * Returns an empty string when no flag is set.
     */
    std::string NetlinkMonitor::IfFlags::toString() const {
        std::string letters;
        const auto appendIf = [&letters](bool isSet, char letter) {
            if (isSet) {
                letters += letter;
            }
        };
        appendIf(up, 'U');
        appendIf(running, 'R');
        appendIf(loopback, 'L');
        appendIf(multicast, 'M');
        appendIf(p2p, 'P');
        appendIf(dynamic, 'D');
        return letters;
    }

    std::unique_ptr<NetlinkMonitor> makeNetlinkMonitor([[maybe_unused]] NetlinkMonitor::Handler& handler) {
        return std::make_unique<NetlinkMonitorImpl>(handler);
    }
} // namespace quasar::net
