#include "netlink_base.h"

#include <yandex_io/libs/errno/errno_exception.h>
#include <yandex_io/libs/logging/logging.h>

#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <unistd.h>
#include <poll.h>

#ifndef SOL_NETLINK
    #define SOL_NETLINK 270
#endif

namespace {
    // Returns the lowercase hexadecimal digit for a nibble value (0..15).
    // Values of 16 and above are not range-checked and map past 'f'.
    char nible(unsigned char c) {
        if (c >= 10) {
            return static_cast<char>('a' + (c - 10));
        }
        return static_cast<char>('0' + c);
    }
} // namespace

namespace quasar::net {
    // Formats one byte as two lowercase hex digits, high nibble first.
    std::array<char, 2> hexchar(unsigned char c) {
        const char high = nible((c >> 4) & 0xf);
        const char low = nible(c & 0xf);
        return std::array<char, 2>{high, low};
    }

    // Renders a MAC address as colon-separated lowercase hex octets
    // (e.g. "00:1a:2b:3c:4d:5e").
    std::string macToStr(const MacAddress& src) {
        std::string text;
        bool needColon = false;
        for (const auto octet : src) {
            if (needColon) {
                text.push_back(':');
            }
            needColon = true;
            const auto digits = hexchar(octet);
            text.append(digits.begin(), digits.end());
        }
        return text;
    }

    // Debug dump of `len` raw bytes: each byte becomes "<index>:<hex>",
    // followed by the character in single quotes when it is printable ASCII,
    // and every entry is terminated by a space.
    std::string rawDataToString(const void* dataIn, unsigned len) {
        const auto* bytes = static_cast<const char*>(dataIn);
        std::string out;
        for (unsigned idx = 0; idx != len; ++idx) {
            const char ch = bytes[idx];
            const auto digits = hexchar(ch);
            out += std::to_string(idx);
            out.push_back(':');
            out.append(digits.begin(), digits.end());
            const bool printable = ch >= ' ' && ch < 127;
            if (printable) {
                out.push_back('\'');
                out.push_back(ch);
                out.push_back('\'');
            }
            out.push_back(' ');
        }
        return out;
    }

} // namespace quasar::net

using namespace quasar::net;

FdHolder::FdHolder(int f)
    : fd(f)
{
}

// Closes the owned descriptor, if any. The close() result is ignored.
FdHolder::~FdHolder() {
    if (fd == -1) {
        return; // nothing owned
    }
    close(fd);
}

/// Performs one non-blocking recvmsg() on the netlink socket and dispatches
/// every netlink message contained in the received datagram.
///
/// @param inRequest true while a request sent by this object is outstanding;
///        datagrams whose SCM_CREDENTIALS report a non-zero sender pid are then
///        dropped as "not my request".
/// @return .read is false only when recvmsg() returned no data (nothing queued
///         or an error); .done is true when the datagram contained NLMSG_DONE
///         or NLMSG_ERROR (an error report or, with error == 0, an ack).
NetlinkBase::ReadMsgStatus NetlinkBase::readMsg(bool inRequest) {
    // Aligned so the netlink headers parsed from it can be accessed directly.
    alignas(struct nlmsghdr) char buf[BUF_SIZE];
    struct iovec iov;
    struct msghdr msg;
    memset(&msg, 0, sizeof(msg));
    memset(&iov, 0, sizeof(iov));
    iov.iov_base = buf;
    iov.iov_len = BUF_SIZE;

    // recvmsg() writes the sender address into msg_name. Use a local buffer:
    // writing it into nlAddr would clobber the kernel destination used for
    // sending (e.g. set nl_groups after a multicast notification).
    struct sockaddr_nl sender;
    memset(&sender, 0, sizeof(sender));

    char cred_msg[CMSG_SPACE(sizeof(struct ucred))];
    msg.msg_control = cred_msg;
    msg.msg_controllen = sizeof(cred_msg);
    msg.msg_iov = &iov;
    msg.msg_iovlen = 1;
    msg.msg_name = &sender;
    msg.msg_namelen = sizeof(sender);
    int msg_len = recvmsg(sock.fd, &msg, MSG_DONTWAIT);
    if (msg_len <= 0) {
        return {.done = false, .read = false};
    }
    if (msg.msg_flags & MSG_TRUNC) {
        // Whatever did not fit into buf is lost; NLMSG_OK below stops at the
        // first incomplete message.
        YIO_LOG_WARN("netlink datagram truncated, only " << msg_len << " bytes received");
    }

    // Credentials are attached only when SO_PASSCRED is in effect (its
    // setsockopt() result is not checked) and can be dropped on control-buffer
    // truncation, so search for them instead of assuming they are present.
    struct ucred cred;
    memset(&cred, 0, sizeof(cred));
    bool haveCred = false;
    for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
        if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_CREDENTIALS &&
            cmsg->cmsg_len >= CMSG_LEN(sizeof(cred))) {
            // memcpy rather than a cast: CMSG_DATA is not guaranteed to be
            // suitably aligned for struct ucred.
            memcpy(&cred, CMSG_DATA(cmsg), sizeof(cred));
            haveCred = true;
            break;
        }
    }
    // Without credentials the sender cannot be told apart, so the datagram is processed.
    if (inRequest && haveCred && cred.pid != 0) { // not my request
        YIO_LOG_DEBUG("Not my request " << cred.pid << " vs " << nlPidId << " vs " << procPidId);
        return {.done = false, .read = true};
    }
    bool done = false;
    auto* nlmsg_ptr = reinterpret_cast<struct nlmsghdr*>(buf);
    if (static_cast<size_t>(msg_len) >= sizeof(struct nlmsghdr)) {
        YIO_LOG_DEBUG("nlmsg_seq = " << nlmsg_ptr->nlmsg_seq << " vs " << seqNumber << " (credentials " << (haveCred ? "present" : "missing") << ')');
    }
    while (msg_len > 0 && NLMSG_OK(nlmsg_ptr, (size_t)msg_len)) {
        switch (nlmsg_ptr->nlmsg_type) {
            case NLMSG_DONE:
                YIO_LOG_DEBUG("NLMSG_DONE");
                done = true;
                break;
            case NLMSG_ERROR: {
                YIO_LOG_DEBUG("NLMSG_ERROR");
                // nlmsg_len includes the netlink header, so the nlmsgerr body
                // is complete only when it covers NLMSG_LENGTH(sizeof(nlmsgerr)).
                const size_t errMsgLen = NLMSG_LENGTH(sizeof(struct nlmsgerr));
                if (nlmsg_ptr->nlmsg_len >= errMsgLen) {
                    const auto* err = static_cast<const struct nlmsgerr*>(NLMSG_DATA(nlmsg_ptr));
                    if (err->error) {
                        // The echoed request payload is err->msg.nlmsg_len minus
                        // its header. The kernel may cap or omit it (e.g.
                        // NETLINK_CAP_ACK), so never dump past this message's end.
                        const size_t hdrLen = NLMSG_HDRLEN;
                        const size_t echoedLen = err->msg.nlmsg_len > hdrLen ? err->msg.nlmsg_len - hdrLen : 0;
                        const size_t availLen = nlmsg_ptr->nlmsg_len - errMsgLen;
                        const size_t dumpLen = echoedLen < availLen ? echoedLen : availLen;
                        YIO_LOG_WARN("nlmsgerror " << err->error << " nlmsg_len = " << err->msg.nlmsg_len << " payload length = " << echoedLen);
                        YIO_LOG_WARN(quasar::strError(abs(err->error)));
                        YIO_LOG_DEBUG(rawDataToString(NLMSG_DATA(&(err->msg)), static_cast<unsigned>(dumpLen)));
                    }
                }
                done = true;
                break;
            }
            default:
                YIO_LOG_DEBUG("nlmsg_type = " << nlmsg_ptr->nlmsg_type << ",  nlmsg_len = " << nlmsg_ptr->nlmsg_len);
                msgHandler(nlmsg_ptr);
                break;
        }
        nlmsg_ptr = NLMSG_NEXT(nlmsg_ptr, msg_len);
    }
    return {.done = done, .read = true};
}

/// Creates a PF_NETLINK datagram socket for protocol `netlinkScope`.
/// It binds the socket with a port id derived from the thread id and process id,
/// subscribed to the multicast `groups` mask. It then reuses nlAddr as the
/// destination address for requests (nl_pid = 0, nl_groups = 0).
///
/// @throws quasar::ErrnoException when socket() or bind() fails.
NetlinkBase::NetlinkBase(int netlinkScope, int groups)
    : sock(socket(PF_NETLINK, SOCK_DGRAM, netlinkScope))
    , seqNumber(1)
    , procPidId(getpid())
    // getpid() rather than procPidId: members are initialized in their header
    // declaration order, so reading procPidId here is only safe if it happens
    // to be declared before nlPidId.
    , nlPidId(pthread_self() << 16 | getpid())
{
    if (sock.fd == -1) {
        throw quasar::ErrnoException(errno, "NetlinkMonitor: Failed to open netlink socket");
    }
    memset(&nlAddr, 0, sizeof(nlAddr));
    nlAddr.nl_family = AF_NETLINK;
    nlAddr.nl_pid = nlPidId; // unique id of source
    nlAddr.nl_groups = groups;
    if (bind(sock.fd, (struct sockaddr*)&nlAddr, sizeof(nlAddr)) == -1) {
        throw quasar::ErrnoException(errno, "NetlinkMonitor: Failed to bind netlink groups");
    }
    nlAddr.nl_groups = 0; // destination: no multicast groups
    nlAddr.nl_pid = 0;    // sending to kernel
    // Both options are optional safety features, so results are not checked.
    int on = 1;
    setsockopt(sock.fd, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on));
#ifdef NETLINK_EXT_ACK
    // Older kernel headers lack extended-ack support.
    setsockopt(sock.fd, SOL_NETLINK, NETLINK_EXT_ACK, &on, sizeof(on));
#endif
}

void NetlinkBase::stop() {
    quit = true;
}

/// Runs the poll/read loop on the netlink socket until `quit` becomes true.
///
/// @param indefinitelly when false, the loop ends once the last queued request
///        has completed; when true it keeps running until stop() is called.
/// @param initialRequests (request sender, message handler) pairs. They are taken
///        from the back of the vector one at a time. The next one is sent only
///        after readMsg() reports the previous one as done.
/// @param periodic optional callback run whenever no request is pending; its
///        return value says whether it started a new request.
void NetlinkBase::monitorImpl(bool indefinitelly,
                              std::vector<HandledRequest> initialRequests,
                              std::function<bool()> periodic) {
    struct pollfd waitFd {};
    waitFd.fd = sock.fd;
    waitFd.events = POLLIN;
    waitFd.revents = 0;

    // Installs the handler of the last queued request and sends that request.
    // Returns false when no queued request is left.
    const auto startNextRequest = [this, &initialRequests]() -> bool {
        if (initialRequests.empty()) {
            return false;
        }
        auto& [sendFn, handlerFn] = initialRequests.back();
        msgHandler = std::move(handlerFn);
        sendFn();
        initialRequests.pop_back();
        return true;
    };

    bool pending = startNextRequest();
    quit = false;

    while (!quit) {
        poll(&waitFd, 1, 1000);
        if ((waitFd.revents & POLLIN) != 0) {
            // Drain every datagram currently queued on the socket.
            for (;;) {
                const auto status = readMsg(pending);
                if (status.done) {
                    pending = startNextRequest();
                    if (!pending && !indefinitelly) {
                        quit = true;
                    }
                }
                if (!status.read) {
                    break;
                }
            }
        }
        if (!pending && periodic) {
            pending = periodic();
        }
    }
}

void NetlinkBase::sendRequestImpl(void* reqPtr, int reqLen) {
    struct msghdr msg;
    memset(&msg, 0, sizeof(msg));
    struct iovec iov;
    iov.iov_base = reqPtr;
    iov.iov_len = reqLen;

    msg.msg_iov = &iov;
    msg.msg_iovlen = 1;

    msg.msg_name = &nlAddr;
    msg.msg_namelen = sizeof(nlAddr);

    auto res = sendmsg(sock.fd, &msg, 0);
    YIO_LOG_DEBUG("sendmsg returned " << res);
    if (res == -1) {
        YIO_LOG_DEBUG(quasar::strError(errno));
    }
}

// Fills a netlink request header: message type `rtmRequest`, total length of
// header plus `genLen` payload bytes, and the next sequence number. The header
// carries this socket's port id and asks the kernel for an acknowledgement.
void NetlinkBase::setupHeaderImpl(struct nlmsghdr& hdr, unsigned genLen, int rtmRequest) {
    hdr.nlmsg_len = NLMSG_LENGTH(genLen);
    hdr.nlmsg_type = rtmRequest;
    hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
    hdr.nlmsg_seq = seqNumber++;
    hdr.nlmsg_pid = nlPidId; // source of request
}
