#include "net_utils.h"

#include <yandex_io/libs/base/utils.h>
#include <yandex_io/libs/errno/errno_exception.h>
#include <yandex_io/libs/logging/logging.h>

#include <util/generic/scope.h>
#include <util/system/byteorder.h>

#include <algorithm>
#include <iostream>
#include <memory>

#include <arpa/inet.h>
#include <net/if.h>
#include <sys/ioctl.h>
#include <sys/types.h>

#include <ifaddrs.h>

#include <linux/rtnetlink.h>
#include <netinet/in_systm.h>

YIO_DEFINE_LOG_MODULE("net_lib");

namespace quasar {

    namespace {

        // IPv4 attributes collected from one rtnetlink route message (filled by parseRoutes).
        struct RouteInfo {
            in_addr dst_addr;          // RTA_DST: destination of the route
            in_addr src_addr;          // RTA_PREFSRC: preferred source address
            in_addr gw_addr;           // RTA_GATEWAY: gateway address
            char if_name[IF_NAMESIZE]; // interface name resolved from the RTA_OIF index
        };

        ssize_t readNetlinkSocket(int sock_fd, char* buffer, size_t buffer_sz, uint32_t seq, uint32_t pid) {
            const struct nlmsghdr* nl_hdr = nullptr;
            size_t msg_len = 0;

            do {
                if (msg_len > buffer_sz) {
                    YIO_LOG_ERROR_EVENT("ReadNetLinkSocket.NotEnoughBufferSize", "No size in buffer: " << quasar::strError(errno))
                    return -1;
                }

                /* Receive response from the kernel */
                const ssize_t sread_len = recv(sock_fd, buffer, buffer_sz - msg_len, 0);
                if (sread_len < 0) {
                    YIO_LOG_ERROR_EVENT("ReadNetLinkSocket.SocketReadFail", "SOCK READ error: " << quasar::strError(errno))
                    return -1;
                }

                const size_t read_len = static_cast<size_t>(sread_len);
                nl_hdr = reinterpret_cast<const struct nlmsghdr*>(buffer);

                /* Check if the header is valid */
                if (!NLMSG_OK(nl_hdr, read_len) || (nl_hdr->nlmsg_type == NLMSG_ERROR)) {
                    YIO_LOG_ERROR_EVENT("ReadNetLinkSocket.InvalidPacketHeader", "Error in received packet: " << quasar::strError(errno))
                    return -1;
                }

                /* Check if the its the last message */
                if (nl_hdr->nlmsg_type == NLMSG_DONE) {
                    break;
                } else {
                    /* Else move the pointer to buffer appropriately */
                    buffer += read_len;
                    msg_len += read_len;
                }

                /* Check if its a multi part message */
                if (!(nl_hdr->nlmsg_flags & NLM_F_MULTI)) {
                    /* return if its not */
                    break;
                }

            } while ((nl_hdr->nlmsg_seq != seq) || (nl_hdr->nlmsg_pid != pid));

            return static_cast<ssize_t>(msg_len);
        }

        /**
         * Extracts the IPv4 route attributes of one rtnetlink message into `rt_info`.
         *
         * Only the attributes found in the message are written; the caller is expected to have
         * initialised `rt_info` beforehand.
         *
         * @param nl_hdr  message to parse; it must lie completely inside a valid buffer
         * @param rt_info destination for RTA_OIF (as interface name), RTA_GATEWAY, RTA_PREFSRC and RTA_DST
         * @return 0 when the message was an IPv4 route and has been parsed, -1 when it was skipped
         */
        int parseRoutes(struct nlmsghdr* nl_hdr, RouteInfo* rt_info) {
            // Only RTM_NEWROUTE messages carry an rtmsg payload. Anything else, or a message too short
            // to hold an rtmsg, must not be reinterpreted as a route.
            if (nl_hdr->nlmsg_type != RTM_NEWROUTE || nl_hdr->nlmsg_len < NLMSG_LENGTH(sizeof(struct rtmsg))) {
                return -1;
            }

            const auto* rt_msg = static_cast<const struct rtmsg*>(NLMSG_DATA(nl_hdr));
            if (rt_msg->rtm_family != AF_INET) {
                return -1;
            }

            /* walk the rtattr list that follows the rtmsg header */
            const struct rtattr* rt_attr = RTM_RTA(rt_msg);
            int rt_len = RTM_PAYLOAD(nl_hdr);

            for (; RTA_OK(rt_attr, rt_len); rt_attr = RTA_NEXT(rt_attr, rt_len)) {
                const void* const payload = RTA_DATA(rt_attr);
                // RTA_OK guarantees rta_len >= sizeof(rtattr), so the payload size cannot be negative
                const size_t payload_len = static_cast<size_t>(RTA_PAYLOAD(rt_attr));

                switch (rt_attr->rta_type) {
                    case RTA_OIF: {
                        if (payload_len < sizeof(int)) {
                            break;
                        }
                        int if_index = 0;
                        memcpy(&if_index, payload, sizeof(if_index));
                        // On failure if_name is left as the caller initialised it
                        if_indextoname(static_cast<unsigned int>(if_index), rt_info->if_name);
                        break;
                    }
                    case RTA_GATEWAY:
                        if (payload_len >= sizeof(rt_info->gw_addr)) {
                            memcpy(&rt_info->gw_addr, payload, sizeof(rt_info->gw_addr));
                        }
                        break;
                    case RTA_PREFSRC:
                        if (payload_len >= sizeof(rt_info->src_addr)) {
                            memcpy(&rt_info->src_addr, payload, sizeof(rt_info->src_addr));
                        }
                        break;
                    case RTA_DST:
                        if (payload_len >= sizeof(rt_info->dst_addr)) {
                            memcpy(&rt_info->dst_addr, payload, sizeof(rt_info->dst_addr));
                        }
                        break;
                    default:
                        break;
                }
            }

            return 0;
        }

        std::optional<RouteInfo> getGatewayIp()
        {
            std::optional<RouteInfo> found_gatewayip;

            struct nlmsghdr* nl_msg;
            RouteInfo route_info;
            char msg_buffer[8192]; // pretty large buffer
            size_t len = 0;
            const int myPid = getpid();

            {
                /* Create Socket */
                int sock = socket(PF_NETLINK, SOCK_DGRAM, NETLINK_ROUTE);
                if (sock < 0) {
                    YIO_LOG_ERROR_EVENT("GetGatewayIp.CreateSocketFail", "Socket Creation error: " << quasar::strError(errno))
                    return std::nullopt;
                }

                Y_DEFER {
                    close(sock);
                };

                /* Initialize the buffer */
                int msg_seq = 0;
                memset(msg_buffer, 0, sizeof(msg_buffer));

                /* point the header and the msg structure pointers into the buffer */
                nl_msg = (struct nlmsghdr*)msg_buffer;

                /* Fill in the nlmsg header*/
                static_assert(sizeof(msg_buffer) >= sizeof(struct nlmsghdr));
                nl_msg->nlmsg_len = NLMSG_LENGTH(sizeof(struct rtmsg)); // Length of message.
                nl_msg->nlmsg_type = RTM_GETROUTE;                      // Get the routes from kernel routing table .
                nl_msg->nlmsg_flags = NLM_F_DUMP | NLM_F_REQUEST;       // The message is a request for dump.
                nl_msg->nlmsg_seq = msg_seq++;                          // Sequence of the message packet.
                nl_msg->nlmsg_pid = myPid;                              // PID of process sending the request.

                /* Send the request */
                if (send(sock, nl_msg, nl_msg->nlmsg_len, 0) < 0) {
                    YIO_LOG_ERROR_EVENT("GetGatewayIp.SocketSendFail", "Write To Socket Failed: " << quasar::strError(errno))
                    return std::nullopt;
                }

                /* Read the response */
                const ssize_t slen = readNetlinkSocket(sock, msg_buffer, sizeof(msg_buffer), msg_seq, myPid);
                if (slen < 0) {
                    YIO_LOG_ERROR_EVENT("GetGatewayIp.ReadNetLinkSocketFail", "Read From Socket Failed")
                    return std::nullopt;
                }
                len = static_cast<size_t>(slen);
            }

            /* Parse and print the response */
            for (; NLMSG_OK(nl_msg, len); nl_msg = NLMSG_NEXT(nl_msg, len)) {
                memset(&route_info, 0, sizeof(route_info));
                if (parseRoutes(nl_msg, &route_info) < 0) {
                    continue; // don't check route_info if it has not been set up
                }

                // Check if default gateway
                if (route_info.dst_addr.s_addr == INADDR_ANY) {
                    found_gatewayip = route_info;
                    break;
                }
            }

            return found_gatewayip;
        }

    } // namespace

    // Returns the local IP used for the default route: the address of the route's interface when it
    // is known and non-empty, otherwise the route's source IP. Empty string when there is no
    // default route.
    std::string getMyIp() noexcept {
        const auto defaultRoute = interfaceOfDefaultRoute();
        if (!defaultRoute) {
            return "";
        }

        const auto& routeIface = defaultRoute->iface;
        if (routeIface && !routeIface->ip.empty()) {
            return routeIface->ip;
        }
        return defaultRoute->sourceIp;
    }

    // Returns the gateway address of the default route in dotted-decimal form, or an empty string
    // when getGatewayIp() finds no default route.
    std::string getDefaultGatewayIp() noexcept {
        const auto defaultRoute = getGatewayIp();
        if (!defaultRoute) {
            return "";
        }

        char text[INET_ADDRSTRLEN];
        inet_ntop(AF_INET, &defaultRoute->gw_addr, text, sizeof(text));
        return text;
    }

    // Convenience overload: views the IPv4 socket address as a generic sockaddr and delegates to
    // the sockaddr overload.
    std::string toIpAddress(const struct sockaddr_in& addr) noexcept {
        const auto& generic = reinterpret_cast<const struct sockaddr&>(addr);
        return toIpAddress(generic);
    }

    // Formats the address stored in `addr` as text. AF_INET and AF_INET6 are supported; any other
    // address family yields an empty string.
    std::string toIpAddress(const struct sockaddr& addr) noexcept {
        switch (addr.sa_family) {
            case AF_INET: {
                const auto* v4 = reinterpret_cast<const struct sockaddr_in*>(&addr);
                char text[INET_ADDRSTRLEN];
                inet_ntop(AF_INET, &v4->sin_addr, text, sizeof(text));
                return text;
            }
            case AF_INET6: {
                const auto* v6 = reinterpret_cast<const struct sockaddr_in6*>(&addr);
                char text[INET6_ADDRSTRLEN];
                inet_ntop(AF_INET6, &v6->sin6_addr, text, sizeof(text));
                return text;
            }
            default:
                return std::string{};
        }
    }

    // Lexicographic ordering over (name, ip, netmask, broadcast, ptp): the first field that differs
    // decides the result.
    bool InterfaceConfig::operator<(const InterfaceConfig& other) const noexcept {
        if (name != other.name) {
            return name < other.name;
        }
        if (ip != other.ip) {
            return ip < other.ip;
        }
        if (netmask != other.netmask) {
            return netmask < other.netmask;
        }
        if (broadcast != other.broadcast) {
            return broadcast < other.broadcast;
        }
        // All preceding fields are equal, so ptp is the tie-breaker
        return ptp < other.ptp;
    }

    // Lists every interface reported by findInterface(): an empty identifier disables filtering.
    std::vector<InterfaceConfig> interfaceConfigList() noexcept {
        const std::string noFilter;
        return findInterface(noFilter);
    }

    /**
     * Enumerates the AF_INET interface entries reported by SIOCGIFCONF and returns those matching
     * `anyInterfaceIdentifier`.
     *
     * For every entry the IP address, broadcast address and netmask are queried with separate
     * ioctls; a field whose ioctl fails is left empty.
     *
     * @param anyInterfaceIdentifier value compared for equality against the interface name, ip,
     *        netmask, broadcast and ptp fields; an empty string matches every interface
     * @return matching interfaces sorted with InterfaceConfig::operator<; empty when the helper
     *         socket cannot be created or the interface list cannot be obtained
     */
    std::vector<InterfaceConfig> findInterface(const std::string& anyInterfaceIdentifier) noexcept {
        std::vector<InterfaceConfig> result;

        /* Create a socket so we can use ioctl on the file
         * descriptor to retrieve the interface info.
         */
        const int sd = socket(PF_INET, SOCK_DGRAM, 0);
        if (sd < 0) { // 0 is a valid descriptor; only a negative value signals failure
            return result;
        }

        // SIOCGIFCONF silently drops the entries that do not fit into the supplied buffer. A buffer
        // that comes back completely full may therefore be truncated: grow it and ask again until
        // some space is left unused or the sanity cap is reached.
        constexpr size_t kInitialEntries = 16;
        constexpr size_t kMaxEntries = 1024;
        std::vector<struct ifreq> entries(kInitialEntries);
        struct ifconf ifc {};
        bool listed = false;
        for (;;) {
            const size_t capacityBytes = entries.size() * sizeof(struct ifreq);
            ifc.ifc_len = static_cast<int>(capacityBytes);
            ifc.ifc_ifcu.ifcu_req = entries.data();
            if (ioctl(sd, SIOCGIFCONF, &ifc) != 0 || ifc.ifc_len < 0) {
                break;
            }

            const bool maybeTruncated = static_cast<size_t>(ifc.ifc_len) >= capacityBytes;
            if (!maybeTruncated || entries.size() >= kMaxEntries) {
                listed = true;
                break;
            }
            entries.resize(entries.size() * 2);
        }

        if (listed) {
            const size_t count = std::min(static_cast<size_t>(ifc.ifc_len) / sizeof(struct ifreq), entries.size());
            result.reserve(count);

            for (size_t i = 0; i < count; ++i) {
                struct ifreq& entry = entries[i];
                if (entry.ifr_addr.sa_family != AF_INET) {
                    continue;
                }

                InterfaceConfig interfaceConfig;
                interfaceConfig.name = entry.ifr_name;

                /* Retrieve the IP address, broadcast address, and subnet mask.
                 * Each ioctl overwrites the same union inside the entry, so the value is
                 * converted to text right after the call.
                 */
                if (ioctl(sd, SIOCGIFADDR, &entry) == 0) {
                    interfaceConfig.ip = toIpAddress(entry.ifr_addr);
                }
                if (ioctl(sd, SIOCGIFBRDADDR, &entry) == 0) {
                    interfaceConfig.broadcast = toIpAddress(entry.ifr_broadaddr);
                }
                if (ioctl(sd, SIOCGIFNETMASK, &entry) == 0) {
                    interfaceConfig.netmask = toIpAddress(entry.ifr_netmask);
                }

                if (anyInterfaceIdentifier.empty() ||
                    anyInterfaceIdentifier == interfaceConfig.name ||
                    anyInterfaceIdentifier == interfaceConfig.ip ||
                    anyInterfaceIdentifier == interfaceConfig.netmask ||
                    anyInterfaceIdentifier == interfaceConfig.broadcast ||
                    anyInterfaceIdentifier == interfaceConfig.ptp)
                {
                    result.emplace_back(std::move(interfaceConfig));
                }
            }
        }
        close(sd);

        std::sort(result.begin(), result.end());
        return result;
    }

    std::vector<InterfaceConfig> broadcastAddresses(std::string ipAddress) noexcept {
        std::vector<InterfaceConfig> result;

        if (ipAddress.empty()) {
            ipAddress = "255.255.255.255";
        }

        struct sockaddr_in addr {};
        addr.sin_family = AF_INET;
        if (!inet_aton(ipAddress.c_str(), &addr.sin_addr)) {
            return result;
        }

        ipAddress = toIpAddress(addr);

        auto ifcs = findInterface("");
        if ((addr.sin_addr.s_addr == INADDR_ANY) || (addr.sin_addr.s_addr == INADDR_BROADCAST)) {
            result = std::move(ifcs);
        } else {
            result.reserve(ifcs.size());
            for (auto& ifc : ifcs) {
                if (!ifc.broadcast.empty()) {
                    if (ifc.ip == ipAddress || ifc.broadcast == ipAddress) {
                        result.emplace_back(std::move(ifc));
                    }
                }
            }
        }
        std::sort(result.begin(), result.end(),
                  [](const InterfaceConfig& if1, const InterfaceConfig& if2) {
                      return if1.broadcast < if2.broadcast;
                  });
        result.erase(std::unique(result.begin(), result.end(),
                                 [](const InterfaceConfig& if1, const InterfaceConfig& if2) {
                                     return if1.broadcast == if2.broadcast;
                                 }), result.end());

        return result;
    }

    std::optional<RouteInterface> interfaceOfDefaultRoute() noexcept {
        std::optional<RouteInterface> routeInterface;
        if (auto routeInfo = getGatewayIp()) {
            char ipv4[INET_ADDRSTRLEN];
            inet_ntop(AF_INET, &routeInfo->gw_addr, ipv4, INET_ADDRSTRLEN);
            std::string gatewayIp = ipv4;

            inet_ntop(AF_INET, &routeInfo->src_addr, ipv4, INET_ADDRSTRLEN);
            std::string sourceIp = ipv4;

            inet_ntop(AF_INET, &routeInfo->dst_addr, ipv4, INET_ADDRSTRLEN);
            std::string destinationIp = ipv4;

            auto ifaces = findInterface(routeInfo->if_name);
            const InterfaceConfig* routeIface = nullptr;
            if (ifaces.size() == 1) {
                routeIface = &ifaces.front();
            } else if (ifaces.size() > 1) {
                for (const auto& iface : ifaces) {
                    if (iface.ip == sourceIp) {
                        routeIface = &iface;
                    }
                }
                if (!routeIface) {
                    routeIface = &ifaces.front();
                }
            }

            routeInterface.emplace();
            routeInterface->destinationIp = std::move(destinationIp);
            routeInterface->sourceIp = std::move(sourceIp);
            routeInterface->gatewayIp = std::move(gatewayIp);
            if (routeIface) {
                routeInterface->iface = *routeIface;
            }
        }
        return routeInterface;
    }

    // Builds the address range described by `ip` and `netmask` (dotted-decimal IPv4 strings).
    // If either string is rejected by inet_aton() the members are left untouched.
    // A netmask of 255.255.255.255 describes a single address and sets privateIp_; otherwise the
    // range runs from (ip & netmask) to that value with all bits outside the netmask set.
    IpRange::IpRange(const std::string& ip, const std::string& netmask)
    {
        struct in_addr parsedIp {};
        struct in_addr parsedMask {};
        if (!inet_aton(ip.c_str(), &parsedIp) || !inet_aton(netmask.c_str(), &parsedMask)) {
            return;
        }

        // Range arithmetic is done on values converted with InetToHost
        const uint32_t hostOrderIp = InetToHost<uint32_t>(parsedIp.s_addr);
        const uint32_t hostOrderMask = InetToHost<uint32_t>(parsedMask.s_addr);

        if (hostOrderMask == 0xFFFFFFFF) {
            privateIp_ = true;
            leBegin_ = hostOrderIp;
            leEnd_ = hostOrderIp;
            return;
        }

        leBegin_ = hostOrderIp & hostOrderMask;
        leEnd_ = leBegin_ | ~hostOrderMask;
    }

    // Textual form of the first address of the range (leBegin_).
    std::string IpRange::network() const noexcept {
        const uint32_t firstAddress = leBegin_;
        return leToString(firstAddress);
    }

    // Textual form of the last address of the range (leEnd_).
    std::string IpRange::broadcast() const noexcept {
        const uint32_t lastAddress = leEnd_;
        return leToString(lastAddress);
    }

    // Textual form of the host address with the given zero-based index, or an empty string when
    // the index is not below hostCapacity(). Index 0 is the address right after leBegin_, except
    // for a single-address range where it is leBegin_ itself.
    std::string IpRange::ip(size_t index) const noexcept {
        if (index < hostCapacity()) {
            if (privateIp_) {
                return leToString(leBegin_);
            }
            return leToString(leBegin_ + index + 1);
        }
        return std::string{};
    }

    // Inverse of ip(): returns the zero-based host index of `ip` inside the range.
    // std::nullopt is returned when the string is rejected by inet_aton() or when the address is
    // not strictly between leBegin_ and leEnd_. For a single-address range the address itself has
    // index 0.
    std::optional<size_t> IpRange::indexOf(const std::string& ip) const noexcept {
        struct in_addr parsed {};
        if (inet_aton(ip.c_str(), &parsed) == 0) {
            return std::nullopt;
        }

        const uint32_t candidate = InetToHost<uint32_t>(parsed.s_addr);
        if (privateIp_ && candidate == leBegin_) {
            return 0;
        }

        const bool strictlyInside = (leBegin_ < candidate) && (candidate < leEnd_);
        if (!strictlyInside) {
            return std::nullopt;
        }
        return candidate - leBegin_ - 1;
    }

    // Number of host addresses in the range: 1 for a single-address range, otherwise the count of
    // addresses strictly between leBegin_ and leEnd_, reported as 0 unless the two ends are more
    // than 2 apart.
    size_t IpRange::hostCapacity() const noexcept {
        if (privateIp_) {
            return 1;
        }
        if (leBegin_ >= leEnd_) {
            return 0;
        }

        const auto span = leEnd_ - leBegin_;
        if (span > 2) {
            return span - 1;
        }
        return 0;
    }

    // Socket address of the host with the given zero-based index. When the index is not below
    // hostCapacity() a zero-initialised sockaddr_in is returned.
    struct sockaddr_in IpRange::sockaddr(size_t index) const noexcept {
        struct sockaddr_in addr {};
        if (index >= hostCapacity()) {
            return addr;
        }

        uint32_t hostOrderIp = leBegin_;
        if (!privateIp_) {
            // Hosts are numbered from the address right after leBegin_
            hostOrderIp = static_cast<uint32_t>(leBegin_ + index + 1);
        }
        addr.sin_family = AF_INET;
        addr.sin_addr.s_addr = HostToInet<uint32_t>(hostOrderIp);
        return addr;
    }

    // Converts an address kept in the representation used by leBegin_/leEnd_ into text: the value
    // is passed through HostToInet, wrapped into an AF_INET sockaddr_in and handed to toIpAddress().
    std::string IpRange::leToString(uint32_t leIp) noexcept {
        struct sockaddr_in wrapped {};
        wrapped.sin_addr.s_addr = HostToInet<uint32_t>(leIp);
        wrapped.sin_family = AF_INET;
        return toIpAddress(wrapped);
    }

} // namespace quasar
