#include "gateway_monitor.h"

#include <yandex_io/libs/base/utils.h>
#include <yandex_io/libs/errno/errno_exception.h>
#include <yandex_io/libs/logging/logging.h>

#include <linux/rtnetlink.h>

#include <netinet/in_systm.h>

#include <util/generic/scope.h>

#include <arpa/inet.h>
#include <net/if.h>

namespace {

    // IPv4 route fields extracted from one RTM_NEWROUTE message by parse_routes().
    // Fields whose attribute is absent from the message keep whatever value they had.
    struct route_info {
        in_addr dst_addr;          // RTA_DST: destination network address
        in_addr src_addr;          // RTA_PREFSRC: preferred source address
        in_addr gw_addr;           // RTA_GATEWAY: next-hop gateway address
        char if_name[IF_NAMESIZE]; // RTA_OIF resolved to an interface name via if_indextoname()
    };

    ssize_t read_netlink_socket(int sock_fd, char* buffer, size_t buffer_sz, uint32_t seq, uint32_t pid) {
        const struct nlmsghdr* nl_hdr = nullptr;
        size_t msg_len = 0;

        do {
            if (msg_len > buffer_sz) {
                YIO_LOG_DEBUG("No size in buffer: " << quasar::strError(errno))
                return -1;
            }

            /* Receive response from the kernel */
            const ssize_t sread_len = recv(sock_fd, buffer, buffer_sz - msg_len, 0);
            if (sread_len < 0) {
                YIO_LOG_DEBUG("SOCK READ error: " << quasar::strError(errno))
                return -1;
            }

            const size_t read_len = static_cast<size_t>(sread_len);
            nl_hdr = reinterpret_cast<const struct nlmsghdr*>(buffer);

            /* Check if the header is valid */
            if (!NLMSG_OK(nl_hdr, read_len) || (nl_hdr->nlmsg_type == NLMSG_ERROR)) {
                YIO_LOG_DEBUG("Error in received packet: " << quasar::strError(errno))
                return -1;
            }

            /* Check if the its the last message */
            if (nl_hdr->nlmsg_type == NLMSG_DONE) {
                break;
            } else {
                /* Else move the pointer to buffer appropriately */
                buffer += read_len;
                msg_len += read_len;
            }

            /* Check if its a multi part message */
            if (!(nl_hdr->nlmsg_flags & NLM_F_MULTI)) {
                /* return if its not */
                break;
            }

        } while ((nl_hdr->nlmsg_seq != seq) || (nl_hdr->nlmsg_pid != pid));

        return static_cast<ssize_t>(msg_len);
    }

    /* parse the route info returned */
    /*
     * Parse one message from a netlink RTM_GETROUTE dump into `rt_info`.
     *
     * Only RTM_NEWROUTE messages of the AF_INET family are accepted. The
     * attributes RTA_OIF (resolved to an interface name), RTA_GATEWAY,
     * RTA_PREFSRC and RTA_DST are copied into the matching fields of
     * `rt_info`. Attributes that are absent, or whose payload is too short,
     * leave the field untouched, so the caller is expected to zero `rt_info`
     * beforehand. If the RTA_OIF index cannot be resolved, `if_name` is set to
     * the empty string.
     *
     * Returns 0 for a parsed IPv4 route, and -1 when the message is not an
     * IPv4 route (other message type, header too short for a struct rtmsg,
     * or a non-AF_INET family).
     */
    int parse_routes(struct nlmsghdr* nl_hdr, struct route_info* rt_info) {
        /* Reject anything that cannot hold a struct rtmsg (e.g. NLMSG_DONE) */
        if (nl_hdr->nlmsg_type != RTM_NEWROUTE || nl_hdr->nlmsg_len < NLMSG_LENGTH(sizeof(struct rtmsg))) {
            return -1;
        }

        struct rtmsg* rt_msg = static_cast<struct rtmsg*>(NLMSG_DATA(nl_hdr));
        if (rt_msg->rtm_family != AF_INET) {
            return -1;
        }

        /* Copy an IPv4 address attribute, ignoring payloads that are too short */
        const auto copy_in_addr = [](struct rtattr* attr, struct in_addr& out) {
            if (RTA_PAYLOAD(attr) >= sizeof(out)) {
                memcpy(&out, RTA_DATA(attr), sizeof(out));
            }
        };

        /* Walk the rtattr list that follows the rtmsg header */
        int rt_len = static_cast<int>(RTM_PAYLOAD(nl_hdr));
        for (struct rtattr* rt_attr = RTM_RTA(rt_msg); RTA_OK(rt_attr, rt_len); rt_attr = RTA_NEXT(rt_attr, rt_len)) {
            switch (rt_attr->rta_type) {
                case RTA_OIF: {
                    // memcpy instead of *(int*) to avoid an aliasing/alignment-unsafe read.
                    uint32_t if_index = 0;
                    if (RTA_PAYLOAD(rt_attr) >= sizeof(if_index)) {
                        memcpy(&if_index, RTA_DATA(rt_attr), sizeof(if_index));
                        if (if_indextoname(if_index, rt_info->if_name) == nullptr) {
                            rt_info->if_name[0] = '\0';
                        }
                    }
                    break;
                }
                case RTA_GATEWAY:
                    copy_in_addr(rt_attr, rt_info->gw_addr);
                    break;
                case RTA_PREFSRC:
                    copy_in_addr(rt_attr, rt_info->src_addr);
                    break;
                case RTA_DST:
                    copy_in_addr(rt_attr, rt_info->dst_addr);
                    break;
                default:
                    break;
            }
        }

        return 0;
    }

    int get_gatewayip(char* gatewayip, socklen_t size) {
        int found_gatewayip = 0;

        struct nlmsghdr* nl_msg;
        struct route_info route_info;
        char msg_buffer[8192]; // pretty large buffer
        size_t len = 0;
        const int myPid = getpid();

        {
            /* Create Socket */
            int sock = socket(PF_NETLINK, SOCK_DGRAM, NETLINK_ROUTE);
            if (sock < 0) {
                YIO_LOG_DEBUG("Socket Creation error: " << quasar::strError(errno))
                return -1;
            }

            Y_DEFER {
                close(sock);
            };

            /* Initialize the buffer */
            int msg_seq = 0;
            memset(msg_buffer, 0, sizeof(msg_buffer));

            /* point the header and the msg structure pointers into the buffer */
            nl_msg = (struct nlmsghdr*)msg_buffer;

            /* Fill in the nlmsg header*/
            nl_msg->nlmsg_len = NLMSG_LENGTH(sizeof(struct rtmsg)); // Length of message.
            nl_msg->nlmsg_type = RTM_GETROUTE;                      // Get the routes from kernel routing table .
            nl_msg->nlmsg_flags = NLM_F_DUMP | NLM_F_REQUEST;       // The message is a request for dump.
            nl_msg->nlmsg_seq = msg_seq++;                          // Sequence of the message packet.
            nl_msg->nlmsg_pid = myPid;                              // PID of process sending the request.

            /* Send the request */
            if (send(sock, nl_msg, nl_msg->nlmsg_len, 0) < 0) {
                YIO_LOG_DEBUG("Write To Socket Failed: " << quasar::strError(errno))
                return -1;
            }

            /* Read the response */
            const ssize_t slen = read_netlink_socket(sock, msg_buffer, sizeof(msg_buffer), msg_seq, myPid);
            if (slen < 0) {
                YIO_LOG_DEBUG("Read From Socket Failed")
                return -1;
            }
            len = static_cast<size_t>(slen);
        }

        /* Parse and print the response */
        for (; NLMSG_OK(nl_msg, len); nl_msg = NLMSG_NEXT(nl_msg, len)) {
            memset(&route_info, 0, sizeof(route_info));
            if (parse_routes(nl_msg, &route_info) < 0) {
                continue; // don't check route_info if it has not been set up
            }

            // Check if default gateway
            if (strstr((char*)inet_ntoa(route_info.dst_addr), "0.0.0.0")) {
                // copy it over
                if (inet_ntop(AF_INET, &route_info.gw_addr, gatewayip, size)) {
                    found_gatewayip = 1;
                    YIO_LOG_DEBUG("ifname: " << route_info.if_name);
                    break;
                }
            }
        }

        return found_gatewayip;
    }

    // Returns the IPv4 default gateway in dotted-quad form, or an empty string
    // when get_gatewayip() does not report success.
    std::string getGatewayIp() {
        char buffer[128] = {};

        const bool found = get_gatewayip(buffer, sizeof(buffer)) == 1;
        if (!found) {
            YIO_LOG_DEBUG("Can't determine gateway");
            return std::string();
        }

        return std::string(buffer);
    }

} // namespace

namespace quasar {

    // Stores the reload period, the change listener and the queue that runs reloads.
    // Nothing is scheduled here; create() queues the first reload.
    GatewayMonitor::GatewayMonitor(std::chrono::seconds reloadInterval, Listener* listener, ICallbackQueue* worker)
        : reloadInterval_(reloadInterval), listener_(listener), worker_(worker)
    {
    }

    std::shared_ptr<GatewayMonitor> GatewayMonitor::create(std::chrono::seconds reloadInterval, Listener* listener, ICallbackQueue* worker) {
        auto ref = std::make_shared<GatewayMonitor>(reloadInterval, listener, worker);

        ref->worker_->add([ref]() {
            ref->reload();
        });

        return ref;
    }

    // Returns a non-owning handle to this monitor for capture in queued callbacks.
    // Built from shared_from_this(), so *this must already be owned by a std::shared_ptr.
    std::weak_ptr<GatewayMonitor> GatewayMonitor::weakFromThis() {
        std::weak_ptr<GatewayMonitor> weakSelf(shared_from_this());
        return weakSelf;
    }

    void GatewayMonitor::reloadNow() {
        const auto weakRef = weakFromThis();

        worker_->add([this, weakRef]() {
            if (const auto ref = weakRef.lock()) {
                reloadOnce();
            }
        });
    }

    void GatewayMonitor::reload() {
        reloadOnce();
        scheduleNextReload();
    }

    void GatewayMonitor::reloadOnce() {
        const std::string gatewayIp = getGatewayIp();

        if (gatewayIp != gatewayIp_) {
            listener_->onGatewayChanged(gatewayIp_, gatewayIp);
            gatewayIp_ = gatewayIp;
        }
    }

    void GatewayMonitor::scheduleNextReload() {
        const auto weakRef = weakFromThis();

        worker_->addDelayed(
            [this, weakRef]() {
                if (const auto ref = weakRef.lock()) {
                    reload();
                }
            },
            reloadInterval_);
    }

    // Performs a fresh, synchronous routing-table lookup on the calling thread; the value
    // cached by reloadOnce() is neither consulted nor updated.
    std::string GatewayMonitor::gatewayIp() {
        std::string current = getGatewayIp();
        return current;
    }

} // namespace quasar
