#include <cerrno>
#include <cstring>
#include <linux/if_packet.h>
#include <linux/if_ether.h>
#include <sys/ioctl.h>
#include <util/generic/yexception.h>
#include "socket.h"

namespace {
    // Resolves the kernel index of the network interface `interfaceName`
    // by issuing SIOCGIFINDEX on `socket`.
    //
    // Throws (Y_ENSURE) if the name does not fit into ifreq::ifr_name together
    // with its terminating NUL, or if the ioctl fails.
    int getInterfaceIndex(int socket, const TString& interfaceName) {
        // ifr_name is char[IFNAMSIZ] and must remain NUL-terminated, so at most
        // IFNAMSIZ - 1 characters fit. A name of exactly IFNAMSIZ characters
        // would overwrite the last zero byte left by the memset below.
        Y_ENSURE(interfaceName.size() < IFNAMSIZ,
                 "getInterfaceIndex: interface name '" << interfaceName << "' is too long");

        ifreq ifr;
        memset(&ifr, 0, sizeof(ifr));
        memcpy(ifr.ifr_name, interfaceName.c_str(), interfaceName.size());

        const int result = ioctl(socket, SIOCGIFINDEX, &ifr);
        // Capture errno immediately, before building the exception message can clobber it.
        const int savedErrno = errno;
        Y_ENSURE(result != -1, "getInterfaceIndex for '" << interfaceName << "' failed: " << strerror(savedErrno));

        return ifr.ifr_ifindex;
    }
}

// Opens a socket of address family `family`. `protocol` is the socket type
// and must be SOCK_DGRAM or SOCK_STREAM.
//
// Throws (Y_ENSURE) on an unsupported socket type or if socket(2) fails.
TTcpUdpSocket::TTcpUdpSocket(int family, int protocol)
    : AddressFamily(family)
{
    // Reject unsupported types up front. The previous switch had no default,
    // so SocketFd was left unassigned and then read by the check below.
    Y_ENSURE(protocol == SOCK_DGRAM || protocol == SOCK_STREAM,
             "TSocket::TSocket unsupported socket type " << protocol);

    SocketFd = socket(AddressFamily, protocol, 0);
    const int savedErrno = errno;

    // socket(2) returns -1 on failure. Descriptor 0 is valid and must not be
    // rejected, because rejecting it would also leak it.
    Y_ENSURE(SocketFd >= 0, "TSocket::TSocket SocketFd = " << SocketFd << ": " << strerror(savedErrno));
}

// Releases the descriptor opened by the constructor. The close() result is
// deliberately discarded: a destructor has no way to report it.
TTcpUdpSocket::~TTcpUdpSocket() {
    (void)close(SocketFd);
}

// Opens a PF_PACKET / SOCK_RAW socket that receives every Ethernet protocol (ETH_P_ALL).
//
// Throws (Y_ENSURE) if socket(2) fails, e.g. when the process lacks the
// privileges needed for raw sockets.
TRawSocket::TRawSocket() {
    SocketFd = socket(PF_PACKET, SOCK_RAW, htons(ETH_P_ALL));
    const int savedErrno = errno;

    // socket(2) returns -1 on failure; descriptor 0 is valid.
    Y_ENSURE(SocketFd >= 0, "TRawSocket::TRawSocket socket SOCK_RAW = " << SocketFd << ": " << strerror(savedErrno));
}

// Releases the raw packet socket. The close() result is deliberately
// discarded: a destructor has no way to report it.
TRawSocket::~TRawSocket() {
    (void)close(SocketFd);
}

int TRawSocket::fd() const {
    return SocketFd;
}

int TTcpUdpSocket::fd() const {
    return SocketFd;
}

// Opens a raw packet socket and binds it to the interface `interfaceName`,
// so that only that interface's traffic (all Ethernet protocols) is seen.
//
// Throws (Y_ENSURE) if the socket cannot be opened, the interface index cannot
// be resolved, or bind(2) fails.
TInterfaceBindedSocket::TInterfaceBindedSocket(const TString& interfaceName)
    // Default-construct the raw socket in place. Initializing it from a
    // temporary TRawSocket relied on copy elision: a real copy or move would
    // let the temporary's destructor close the descriptor this member keeps.
    : Socket()
    , InterfaceName(interfaceName)
    // NOTE(review): this reads Socket.fd(), so it relies on Socket being
    // declared before InterfaceIndex in the class definition. Confirm in socket.h.
    , InterfaceIndex(getInterfaceIndex(Socket.fd(), interfaceName))
{
    sockaddr_ll bindAddress;
    memset(&bindAddress, 0, sizeof(bindAddress));

    bindAddress.sll_family = AF_PACKET;
    bindAddress.sll_protocol = htons(ETH_P_ALL);
    bindAddress.sll_ifindex = InterfaceIndex;

    const auto result = bind(fd(), reinterpret_cast<const sockaddr*>(&bindAddress), sizeof(bindAddress));
    // Capture errno before the exception message is built.
    const int savedErrno = errno;
    Y_ENSURE(result != -1, "TInterfaceBindedSocket::TInterfaceBindedSocket bind for interface index " << InterfaceIndex << " failed: " << strerror(savedErrno));
}

int TInterfaceBindedSocket::fd() const {
    return Socket.fd();
}
