#include "simple_client.h"

#include <arpa/inet.h>

#include <contrib/libs/libmnl/include/libmnl/libmnl.h>

#include <linux/if_link.h>
#include <linux/rtnetlink.h>

#include <net/if.h>
#include <netinet/in.h>

#include <util/datetime/base.h>

#include <errno.h>
#include <string.h>

#include <utility>

namespace NInfra::NPodAgent {

namespace {

// Owning wrapper around a libmnl NETLINK_ROUTE socket.
//
// Every operation reports failure through TExpected instead of a raw return code;
// calling an operation before Open() fails with that operation's error code.
// The socket is released by Close() or, at the latest, by the destructor.
// The class is non-copyable: a copy would share the raw mnl_socket pointer and
// close it twice.
class TMnlSocket {
public:
    TMnlSocket() = default;

    TMnlSocket(const TMnlSocket&) = delete;
    TMnlSocket& operator=(const TMnlSocket&) = delete;

    ~TMnlSocket() {
        Close();
    }

    // Opens a NETLINK_ROUTE socket. Fails if this object already holds an open socket.
    TExpected<void, TIpClientError> Open() {
        if (MnlSocket_ != nullptr) {
            return TIpClientError(
                EIpClientError::SocketOpenError
                , "socket is already open"
            );
        }

        MnlSocket_ = mnl_socket_open(NETLINK_ROUTE);

        if (MnlSocket_ == nullptr) {
            // Capture errno before the message is built: appending to the string
            // builder may allocate, and allocation is allowed to modify errno.
            const int savedErrno = errno;
            return TIpClientError(
                EIpClientError::SocketOpenError
                , TStringBuilder() << "mnl_socket_open: " << strerror(savedErrno)
            );
        }

        return TExpected<void, TIpClientError>::DefaultSuccess();
    }

    // Binds the opened socket with no multicast groups and MNL_SOCKET_AUTOPID,
    // i.e. the port id is chosen automatically (read it back with GetPortId()).
    TExpected<void, TIpClientError> Bind() {
        if (MnlSocket_ == nullptr) {
            return TIpClientError(
                EIpClientError::SocketBindError
                , "can't bind unopened socket"
            );
        }

        if (mnl_socket_bind(MnlSocket_, 0, MNL_SOCKET_AUTOPID) < 0) {
            const int savedErrno = errno;
            return TIpClientError(
                EIpClientError::SocketBindError
                , TStringBuilder() << "mnl_socket_bind: " << strerror(savedErrno)
            );
        }

        return TExpected<void, TIpClientError>::DefaultSuccess();
    }

    // Returns the netlink port id of the socket (used to match kernel replies).
    TExpected<ui32, TIpClientError> GetPortId() {
        if (MnlSocket_ == nullptr) {
            return TIpClientError(
                EIpClientError::SocketGetPortIdError
                , "can't get port id from unopened socket"
            );
        }

        // mnl_socket_get_portid has no failure mode.
        return mnl_socket_get_portid(MnlSocket_);
    }

    // Sends `bufSize` bytes of `buf` (a serialized netlink message) to the kernel.
    TExpected<void, TIpClientError> SendTo(
        const void* buf
        , size_t bufSize
    ) {
        if (MnlSocket_ == nullptr) {
            return TIpClientError(
                EIpClientError::SocketSendToError
                , "can't send to unopened socket"
            );
        }

        if (mnl_socket_sendto(MnlSocket_, buf, bufSize) < 0) {
            const int savedErrno = errno;
            return TIpClientError(
                EIpClientError::SocketSendToError
                , TStringBuilder() << "mnl_socket_sendto: " << strerror(savedErrno)
            );
        }

        return TExpected<void, TIpClientError>::DefaultSuccess();
    }

    // Receives one datagram into `buf` (at most `bufSize` bytes).
    // Returns the number of bytes received (never negative on success).
    TExpected<i32, TIpClientError> RecvFrom(
        void* buf
        , size_t bufSize
    ) {
        if (MnlSocket_ == nullptr) {
            return TIpClientError(
                EIpClientError::SocketRecvFromError
                , "can't receive from unopened socket"
            );
        }

        i32 ret = mnl_socket_recvfrom(MnlSocket_, buf, bufSize);
        if (ret < 0) {
            const int savedErrno = errno;
            return TIpClientError(
                EIpClientError::SocketRecvFromError
                , TStringBuilder() << "mnl_socket_recvfrom: " << strerror(savedErrno)
            );
        }

        return ret;
    }

    // Runs mnl_cb_run over `numBytes` of received data, checking each message
    // against `seq`/`portId` and dispatching data messages to `callback` with `data`.
    // Returns the non-negative mnl_cb_run result (MNL_CB_STOP or MNL_CB_OK);
    // negative results (including kernel-reported errors) become an error.
    TExpected<i32, TIpClientError> CallbackRun(
        const void* buf
        , size_t numBytes
        , ui32 seq
        , ui32 portId
        , mnl_cb_t callback
        , void* data
    ) {
        if (MnlSocket_ == nullptr) {
            return TIpClientError(
                EIpClientError::SocketCallbackRunError
                , "can't do callback run with unopened socket"
            );
        }

        i32 ret = mnl_cb_run(buf, numBytes, seq, portId, callback, data);
        if (ret < 0) {
            const int savedErrno = errno;
            return TIpClientError(
                EIpClientError::SocketCallbackRunError
                , TStringBuilder() << "mnl_cb_run: " << strerror(savedErrno)
            );
        }

        return ret;
    }

    // Closes the socket if it is open; safe to call repeatedly.
    void Close() {
        if (MnlSocket_ != nullptr) {
            mnl_socket_close(MnlSocket_);
            MnlSocket_ = nullptr;
        }
    }

private:
    mnl_socket* MnlSocket_ = nullptr;
};

// Accumulator handed to ListAddressWithFilterCallback through its `data` pointer.
struct ListAddressData {
    // Addresses collected so far.
    TVector<TIpDescription> DumpedIps;
    // Index of the interface whose addresses are collected. Defaults to 0, which
    // NameToIndex treats as "no such interface", until the caller fills it in.
    ui32 iface = 0;
};

// Resolves a network interface name to its kernel interface index via
// if_nametoindex. A result of 0 means failure and is reported as DeviceError.
TExpected<ui32, TIpClientError> NameToIndex(const TString& device) {
    const ui32 iface = if_nametoindex(device.data());
    if (iface == 0) {
        // Capture errno before the message is built: appending to the string
        // builder may allocate, and allocation is allowed to modify errno.
        const int savedErrno = errno;
        return TIpClientError(
            EIpClientError::DeviceError
            , TStringBuilder() << "if_nametoindex: " << strerror(savedErrno)
        );
    }

    return iface;
}

// Converts a textual IPv6 address into its binary in6_addr form.
//
// inet_pton returns 1 on success, 0 when the string is not a valid address for
// the family (errno is NOT set in that case), and -1 when the family itself is
// not supported (errno is set). Both failure cases become Ip6ParseError.
TExpected<in6_addr, TIpClientError> ParseIp6Address(const TString& ip) {
    in6_addr ip6Addr{};
    const int ret = inet_pton(AF_INET6, ip.data(), &ip6Addr);

    if (ret == 0) {
        // errno is meaningless here, so report the rejected input instead.
        return TIpClientError(
            EIpClientError::Ip6ParseError
            , TStringBuilder() << "inet_pton: invalid IPv6 address '" << ip << "'"
        );
    }

    if (ret < 0) {
        const int savedErrno = errno;
        return TIpClientError(
            EIpClientError::Ip6ParseError
            , TStringBuilder() << "inet_pton: " << strerror(savedErrno)
        );
    }

    return ip6Addr;
}

// Serializes a single-IPv6-address rtnetlink request into `buf` and returns a
// pointer to its netlink header (which points into `buf`).
//
// buf        - scratch buffer the message is written into; must outlive the result
// nlmsgType  - netlink message type (callers pass RTM_NEWADDR / RTM_DELADDR)
// nlmsgFlags - NLM_F_* header flags
// seq        - sequence number, later matched against the kernel reply
// iface      - interface index the address belongs to
// ip6Addr    - binary IPv6 address, sent as IFA_ADDRESS
// subnet     - prefix length; assigned to the 8-bit ifa_prefixlen field as-is
nlmsghdr* GetHeaderForAddressAction(
    char buf[MNL_SOCKET_BUFFER_SIZE]
    , ui16 nlmsgType
    , ui16 nlmsgFlags
    , ui32 seq
    , ui32 iface
    , const in6_addr& ip6Addr
    , ui32 subnet
) {
    nlmsghdr* header = mnl_nlmsg_put_header(buf);
    header->nlmsg_type = nlmsgType;
    header->nlmsg_flags = nlmsgFlags;
    header->nlmsg_seq = seq;

    // Family-specific header that immediately follows the netlink header.
    auto* addrMsg = static_cast<ifaddrmsg*>(mnl_nlmsg_put_extra_header(header, sizeof(ifaddrmsg)));
    addrMsg->ifa_family = AF_INET6;
    addrMsg->ifa_prefixlen = subnet;
    addrMsg->ifa_flags = IFA_F_PERMANENT;
    addrMsg->ifa_scope = RT_SCOPE_UNIVERSE;
    addrMsg->ifa_index = iface;

    mnl_attr_put(header, IFA_ADDRESS, sizeof(in6_addr), &ip6Addr);

    // DEPLOY-3121
    // Magic that lowers box ip priority.
    // NOTE(review): presumably a preferred lifetime of 0 makes the kernel treat the
    // address as deprecated (so other addresses win source-address selection),
    // while ifa_valid = -1 wraps to 0xFFFFFFFF, an unlimited valid lifetime — confirm.
    ifa_cacheinfo lifetimes;
    lifetimes.ifa_prefered = 0;
    lifetimes.ifa_valid = -1;
    lifetimes.cstamp = 0;
    lifetimes.tstamp = 0;
    mnl_attr_put(header, IFA_CACHEINFO, sizeof(lifetimes), &lifetimes);

    return header;
}

// Sends one IPv6 address request for `ip` on `device` over a freshly opened
// NETLINK_ROUTE socket and processes the kernel's reply with mnl_cb_run.
//
// device     - interface name, resolved with NameToIndex
// ip         - textual IPv6 address (ip.Ip6) plus prefix length (ip.Subnet)
// nlmsgType  - RTM_NEWADDR / RTM_DELADDR (see AddAddress / RemoveAddress)
// nlmsgFlags - NLM_F_* flags; the callers in this file include NLM_F_ACK
//
// Fails if the device is unknown, the address or prefix length is invalid,
// any socket step fails, or mnl_cb_run reports an error for the reply.
TExpected<void, TIpClientError> AddressAction(
    const TString& device
    , const TIpDescription& ip
    , ui16 nlmsgType
    , ui16 nlmsgFlags
) {
    char buf[MNL_SOCKET_BUFFER_SIZE];

    ui32 iface = OUTCOME_TRYX(NameToIndex(device));
    in6_addr ip6Addr = OUTCOME_TRYX(ParseIp6Address(ip.Ip6));

    // An IPv6 prefix length is at most 128. The request stores it in the 8-bit
    // ifa_prefixlen field, so larger values would be silently truncated to an
    // unrelated prefix (e.g. 300 -> 44) instead of being rejected.
    if (ip.Subnet > 128) {
        return TIpClientError(
            EIpClientError::Ip6ParseError
            , TStringBuilder() << "invalid IPv6 prefix length " << ip.Subnet << ", expected at most 128"
        );
    }

    ui32 seq = TInstant::Now().Seconds();

    nlmsghdr* nlh = GetHeaderForAddressAction(
        buf
        , nlmsgType
        , nlmsgFlags
        , seq
        , iface
        , ip6Addr
        , ip.Subnet
    );

    TMnlSocket socket;

    OUTCOME_TRYV(socket.Open());
    OUTCOME_TRYV(socket.Bind());

    ui32 portId = OUTCOME_TRYX(socket.GetPortId());

    OUTCOME_TRYV(socket.SendTo(nlh, nlh->nlmsg_len));
    // The request has been sent, so `buf` can be reused for the reply.
    i32 ret = OUTCOME_TRYX(socket.RecvFrom(buf, sizeof(buf)));
    // No data callback: only the ack/error handling of mnl_cb_run is needed.
    OUTCOME_TRYV(socket.CallbackRun(buf, ret, seq, portId, nullptr, nullptr));

    return TExpected<void, TIpClientError>::DefaultSuccess();
}

// mnl_attr_parse() callback that records only the IFA_ADDRESS attribute.
//
// attr - attribute currently being visited
// data - table of `const nlattr*` with at least IFA_MAX + 1 slots, indexed by
//        attribute type; tb[IFA_ADDRESS] is set when a valid address is found
//
// Returns MNL_CB_OK to keep parsing, or MNL_CB_ERROR when the IFA_ADDRESS
// payload is not exactly sizeof(in6_addr) bytes (the slot is left unset).
i32 ParseAttributeForListAddressWithFilterCallback(
    const nlattr* attr
    , void* data
) {
    const nlattr** tb = static_cast<const nlattr**>(data);
    const i32 type = mnl_attr_get_type(attr);

    // Only the address is of interest. IFA_ADDRESS is always within IFA_MAX, so
    // no separate mnl_attr_type_valid() check is needed for it.
    if (type != IFA_ADDRESS) {
        return MNL_CB_OK;
    }

    // The payload is later read as an IPv6 address (inet_ntop with AF_INET6), so
    // require its exact size. A type-only check would let a short, malformed
    // attribute cause an out-of-bounds read.
    if (mnl_attr_validate2(attr, MNL_TYPE_BINARY, sizeof(in6_addr)) < 0) {
        return MNL_CB_ERROR;
    }
    tb[type] = attr;

    return MNL_CB_OK;
}

// mnl_cb_run() data callback for RTM_GETADDR dump replies.
//
// nlh  - one dumped message; its payload is read as an ifaddrmsg
// data - pointer to the ListAddressData accumulator
//
// Appends every IPv6 address of interface listData->iface to listData->DumpedIps
// (textual address + prefix length). Messages for other interfaces or families
// are ignored. Always returns MNL_CB_OK so the dump keeps going.
i32 ListAddressWithFilterCallback(
    const nlmsghdr* nlh
    , void* data
) {
    auto* listData = static_cast<ListAddressData*>(data);
    const auto* addrMsg = static_cast<const ifaddrmsg*>(mnl_nlmsg_get_payload(nlh));

    const bool matchesDevice = addrMsg->ifa_index == listData->iface;
    const bool isIp6 = addrMsg->ifa_family == AF_INET6;
    if (!matchesDevice || !isIp6) {
        return MNL_CB_OK;
    }

    // Attributes start right after the ifaddrmsg header; only IFA_ADDRESS is kept.
    const nlattr* attrs[IFA_MAX + 1] = {};
    mnl_attr_parse(nlh, sizeof(*addrMsg), ParseAttributeForListAddressWithFilterCallback, attrs);

    const nlattr* addressAttr = attrs[IFA_ADDRESS];
    if (addressAttr == nullptr) {
        return MNL_CB_OK;
    }

    char textAddress[INET6_ADDRSTRLEN];
    const void* rawAddress = mnl_attr_get_payload(addressAttr);
    if (inet_ntop(addrMsg->ifa_family, rawAddress, textAddress, sizeof(textAddress)) != nullptr) {
        listData->DumpedIps.emplace_back(TString(textAddress), addrMsg->ifa_prefixlen);
    }

    return MNL_CB_OK;
}

// Serializes an RTM_GETADDR dump request (NLM_F_REQUEST | NLM_F_DUMP) for the
// AF_INET6 family into `buf` and returns a pointer to its netlink header.
//
// buf - scratch buffer the message is written into; must outlive the result
// seq - sequence number, later matched against the dump replies
nlmsghdr* GetHeaderForListAddress(
    char buf[MNL_SOCKET_BUFFER_SIZE]
    , ui32 seq
) {
    nlmsghdr* header = mnl_nlmsg_put_header(buf);
    header->nlmsg_type = RTM_GETADDR;
    header->nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
    header->nlmsg_seq = seq;

    // rtgenmsg carries only the address family selector (IPv6 here).
    auto* familyMsg = static_cast<rtgenmsg*>(mnl_nlmsg_put_extra_header(header, sizeof(rtgenmsg)));
    familyMsg->rtgen_family = AF_INET6;

    return header;
}

// Dumps the IPv6 addresses configured on `device` via an RTM_GETADDR request
// over a freshly opened NETLINK_ROUTE socket.
//
// Returns the collected addresses (textual address + prefix length), or an
// error if the device is unknown or any socket/parsing step fails.
TExpected<TVector<TIpDescription>, TIpClientError> ListAddressWithFilter(
    const TString& device
) {
    char buf[MNL_SOCKET_BUFFER_SIZE];
    ListAddressData listData;

    listData.iface = OUTCOME_TRYX(NameToIndex(device));
    ui32 seq = TInstant::Now().Seconds();

    nlmsghdr* nlh = GetHeaderForListAddress(
        buf
        , seq
    );

    TMnlSocket socket;

    OUTCOME_TRYV(socket.Open());
    OUTCOME_TRYV(socket.Bind());

    ui32 portId = OUTCOME_TRYX(socket.GetPortId());

    OUTCOME_TRYV(socket.SendTo(nlh, nlh->nlmsg_len));

    // A dump may span several datagrams: keep receiving and dispatching them
    // until mnl_cb_run returns MNL_CB_STOP (errors abort via OUTCOME_TRYX).
    i32 ret = OUTCOME_TRYX(socket.RecvFrom(buf, sizeof(buf)));
    while (ret > 0) {
        ret = OUTCOME_TRYX(socket.CallbackRun(buf, ret, seq, portId, ListAddressWithFilterCallback, &listData));
        if (ret <= MNL_CB_STOP) {
            break;
        }
        ret = OUTCOME_TRYX(socket.RecvFrom(buf, sizeof(buf)));
    }

    // DumpedIps is a member of a local, so it is not implicitly moved on return;
    // move it explicitly instead of copying the whole vector into the result.
    return std::move(listData.DumpedIps);
}

}

// Adds the IPv6 address `ip` to `device` with RTM_NEWADDR. NLM_F_CREATE together
// with NLM_F_REPLACE lets an already-present address be updated rather than
// rejected; NLM_F_ACK makes the kernel confirm the outcome.
TExpected<void, TIpClientError> TSimpleIpClient::AddAddress(const TString& device, const TIpDescription& ip) {
    const ui16 flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE;
    return AddressAction(device, ip, RTM_NEWADDR, flags);
}

// Removes the IPv6 address `ip` from `device` with RTM_DELADDR, asking the
// kernel for an acknowledgement (NLM_F_ACK).
TExpected<void, TIpClientError> TSimpleIpClient::RemoveAddress(const TString& device, const TIpDescription& ip) {
    const ui16 flags = NLM_F_REQUEST | NLM_F_ACK;
    return AddressAction(device, ip, RTM_DELADDR, flags);
}

// Lists the IPv6 addresses currently configured on `device`.
TExpected<TVector<TIpDescription>, TIpClientError> TSimpleIpClient::ListAddress(const TString& device) {
    return ListAddressWithFilter(device);
}

} // namespace NInfra::NPodAgent
