import socket
import ctypes
import struct
import time
from fcntl import ioctl

from infra.netconfig.lib.exceptions import NetconfigError

# setsockopt(SOL_SOCKET, ...) option number used to attach a classic BPF
# program (value from uapi/asm-generic/socket.h).
SO_ATTACH_FILTER = 26
# Protocol wildcard for the packet socket: receive every EtherType.
ETH_P_ALL = 0x0003
# Overall time budget, in seconds, for receiving an LLDP frame.
LLDP_PACKET_RECV_TIMEOUT = 120
# Ethernet header = destination MAC (6) + source MAC (6) + EtherType (2).
ETH_HEADER_LEN = 6 + 6 + 2
# Destination MAC address of LLDP frames: 01:80:c2:00:00:0e.
LLDP_MULTICAST = b'\x01\x80\xc2\x00\x00\x0e'
# EtherType value identifying an LLDP frame.
LLDP_ETH_PROTO = 0x88CC

# Every TLV begins with a 16-bit header: 7 bits of type, then 9 bits of payload length.
LLDP_TLV_HEADER_LEN = 16 // 8
LLDP_TLV_TYPE_MASK = 0b1111111000000000  # == 0xfe00, the 7 type bits
LLDP_TLV_PAYLOAD_LEN_MASK = ~LLDP_TLV_TYPE_MASK & 0xFFFF  # == 0x01ff, the 9 length bits

# TLV type codes extracted by this module.
LLDP_PORT_ID_TLV = 2  # Port ID
LLDP_SYS_NAME_TLV = 5  # System Name


# Classic BPF program that accepts only LLDP frames; each tuple is one
# (code, jt, jf, k) instruction. Generated with:
#     tcpdump -dd "ether[0] & 1 = 1 and ether proto 0x88cc and ether dst 01:80:c2:00:00:0e"
#
# The first test looks at the lowest bit of the first destination-MAC byte
# (the group/multicast bit), so unicast frames are rejected before the more
# specific checks run. The trick comes from lldpd (daemon/lldpd.h), whose
# multi-protocol filter treats a cleared bit as an EDP frame; this filter is
# LLDP-only, so here a cleared bit simply means "drop".
# jt/jf are forward jump offsets counted from the next instruction; every
# failed comparison jumps to the final "ret #0" (drop).
LLDP_FILTER = [
    (0x30, 0, 0, 0x00000000),  # ldb [0]          first byte of destination MAC
    (0x54, 0, 0, 0x00000001),  # and #0x1         keep only the group (multicast) bit
    (0x15, 0, 7, 0x00000001),  # jeq #0x1         else -> drop
    (0x28, 0, 0, 0x0000000c),  # ldh [12]         EtherType
    (0x15, 0, 5, 0x000088cc),  # jeq #0x88cc      else -> drop
    (0x20, 0, 0, 0x00000002),  # ld [2]           destination MAC bytes 2..5
    (0x15, 0, 3, 0xc200000e),  # jeq #0xc200000e  else -> drop
    (0x28, 0, 0, 0x00000000),  # ldh [0]          destination MAC bytes 0..1
    (0x15, 0, 1, 0x00000180),  # jeq #0x180       else -> drop
    (0x6, 0, 0, 0x00040000),   # ret #0x40000     accept, keeping up to 0x40000 bytes
    (0x6, 0, 0, 0x00000000),   # ret #0           drop
]


class BpfProgram(ctypes.Structure):
    """ctypes mirror of the kernel's ``struct sock_fprog`` (SO_ATTACH_FILTER argument).

    C layout (uapi/linux/filter.h)::

        struct sock_fprog {
            unsigned short len;          /* number of instructions */
            struct sock_filter *filter;  /* pointer to the instruction array */
        };
    """
    _fields_ = [
        # Must be an unsigned short like the kernel field. Declaring it as an
        # int only happens to work on little-endian hosts, where the int's
        # low-order bytes coincide with the short the kernel reads.
        ('bf_len', ctypes.c_ushort),
        ('bf_filter', ctypes.c_void_p),
    ]


class BpfInstruction(ctypes.Structure):
    """One classic-BPF instruction, laid out like the kernel's ``struct sock_filter``.

    C layout::

        struct sock_filter {
            __u16 code;
            __u8  jt;
            __u8  jf;
            __u32 k;
        };

    Instances can be built positionally: ``BpfInstruction(code, jt, jf, k)``.
    """
    _fields_ = [
        ('code', ctypes.c_uint16),  # opcode
        ('jt', ctypes.c_uint8),     # jump offset when the test is true
        ('jf', ctypes.c_uint8),     # jump offset when the test is false
        ('k', ctypes.c_uint32),     # operand (constant / packet offset / return value)
    ]


def attach_lldp_filter(sock):
    """Attach the LLDP-only classic BPF program (LLDP_FILTER) to ``sock``.

    :param sock: socket to filter (this module passes its AF_PACKET raw socket).
    :raises socket.error: if the kernel rejects the SO_ATTACH_FILTER request.
    """
    # Contiguous C array of struct sock_filter, one element per LLDP_FILTER tuple.
    instructions = (BpfInstruction * len(LLDP_FILTER))(
        *[BpfInstruction(code, jt, jf, k) for code, jt, jf, k in LLDP_FILTER]
    )

    prog = BpfProgram()
    prog.bf_len = len(LLDP_FILTER)
    # ``instructions`` stays referenced by this frame until setsockopt returns,
    # so the raw address handed to the kernel remains valid for the call.
    prog.bf_filter = ctypes.addressof(instructions)

    # setsockopt needs the raw bytes of struct sock_fprog. ``buffer()`` exists
    # only on Python 2; string_at returns str on Py2 and bytes on Py3.
    prog_bytes = ctypes.string_at(ctypes.addressof(prog), ctypes.sizeof(prog))
    sock.setsockopt(socket.SOL_SOCKET, SO_ATTACH_FILTER, prog_bytes)


def form_mreq(iface_index):
    """Pack a ``struct packet_mreq`` joining ``iface_index`` to the LLDP multicast MAC.

    C layout being reproduced::

        struct packet_mreq {
            int mr_ifindex;
            unsigned short int mr_type;
            unsigned short int mr_alen;
            unsigned char mr_address[8];
        };

    :param iface_index: index of the interface to subscribe.
    :returns: the packed request, suitable for PACKET_ADD_MEMBERSHIP.
    """
    membership_type = 0  # PACKET_MR_MULTICAST
    group_address = LLDP_MULTICAST
    # Native layout: uint ifindex, ushort type, ushort alen, 8-byte address
    # ('8s' zero-pads the 6-byte MAC).
    mreq_format = 'IHH8s'
    return struct.pack(mreq_format, iface_index, membership_type, len(group_address), group_address)


def get_interface_index(sock, iface_name):
    """Return the kernel interface index of ``iface_name`` using SIOCGIFINDEX.

    :param sock: socket whose file descriptor is used for the ioctl.
    :param iface_name: interface name as text or bytes, e.g. 'eth0'.
    :returns: the interface index as an int.
    :raises IOError/OSError: if the ioctl fails (e.g. unknown interface).

    Only the name and a leading int of the ``ifr_ifru`` union are packed; the
    index is read back from offset 16, right after the name::

        IFNAMSIZ = 16

        struct ifreq {
            union {
                char ifrn_name[IFNAMSIZ];   /* Interface name, e.g. "eth0".  */
            } ifr_ifrn;

            union {
                struct sockaddr ifru_addr;
                struct sockaddr ifru_dstaddr;
                struct sockaddr ifru_broadaddr;
                struct sockaddr ifru_netmask;
                struct sockaddr ifru_hwaddr;
                short int ifru_flags;
                int ifru_ivalue;
                int ifru_mtu;
                struct ifmap ifru_map;
                char ifru_slave[IFNAMSIZ];  /* Just fits the size */
                char ifru_newname[IFNAMSIZ];
                __caddr_t ifru_data;
            } ifr_ifru;
        };
    """
    SIOCGIFINDEX = 0x8933
    if not isinstance(iface_name, bytes):
        # struct's 's' format requires bytes on Python 3 (str is bytes on Python 2).
        iface_name = iface_name.encode()
    request = struct.pack('16sI', iface_name, 0)
    _, iface_index = struct.unpack('16sI', ioctl(sock, SIOCGIFINDEX, request))
    return iface_index


def hardcore_get_mac_address(sock, iface_name):
    """Return the hardware address of ``iface_name`` using SIOCGIFHWADDR.

    The request is a 24-byte ``struct ifreq`` prefix: the 16-byte name followed
    by ``ifr_hwaddr`` (a ``struct sockaddr``), whose 2-byte ``sa_family``
    precedes the address bytes.

    :param sock: socket whose file descriptor is used for the ioctl.
    :param iface_name: interface name as text or bytes, e.g. 'eth0'.
    :returns: the first 6 address bytes (the MAC for Ethernet-like interfaces).
    :raises IOError/OSError: if the ioctl fails.
    """
    SIOCGIFHWADDR = 0x8927
    if not isinstance(iface_name, bytes):
        # struct's 's' format requires bytes on Python 3 (str is bytes on Python 2).
        iface_name = iface_name.encode()
    response = ioctl(sock, SIOCGIFHWADDR, struct.pack('24s', iface_name))
    _name, _sa_family, mac = struct.unpack('16s2s6s', response)
    return mac


def run_lldp_sniff(interface):
    """Capture the first LLDP frame received on ``interface`` and return its payload.

    Opens an AF_PACKET raw socket bound to ``interface``, attaches the
    LLDP-only BPF filter and joins the LLDP multicast group. It then reads
    frames until one arrives that meets all of these conditions:
      * destination is LLDP_MULTICAST;
      * EtherType is LLDP_ETH_PROTO;
      * source MAC differs from the interface's own MAC.

    :param interface: interface name, e.g. 'eth0'.
    :returns: the frame bytes following the Ethernet header (the LLDPDU).
    :raises NetconfigError: on socket/ioctl errors, on receive timeout, or if
        no matching frame arrives within LLDP_PACKET_RECV_TIMEOUT seconds.
    """
    SOL_PACKET = 263
    PACKET_ADD_MEMBERSHIP = 1

    s = None
    deadline = time.time() + LLDP_PACKET_RECV_TIMEOUT
    try:
        s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, socket.htons(ETH_P_ALL))
        s.bind((interface, 0))
        attach_lldp_filter(s)

        interface_index = get_interface_index(s, interface)
        interface_mac = hardcore_get_mac_address(s, interface)
        s.setsockopt(SOL_PACKET, PACKET_ADD_MEMBERSHIP, form_mreq(interface_index))

        while True:
            remaining = deadline - time.time()
            if remaining <= 0:
                break
            # Bound each recv by the time still left so the overall deadline
            # holds even when non-matching frames keep arriving.
            s.settimeout(remaining)
            packet = s.recv(65565)
            # Frames queued before the filter was attached bypass it, so every
            # condition is re-checked here, starting with the header length.
            if len(packet) < ETH_HEADER_LEN:
                continue
            ethernet_dst, ethernet_src, ethernet_type = struct.unpack('>6s6sH', packet[:ETH_HEADER_LEN])
            if ethernet_dst == LLDP_MULTICAST and ethernet_type == LLDP_ETH_PROTO and ethernet_src != interface_mac:
                return packet[ETH_HEADER_LEN:]

        raise NetconfigError('Could not get LLDP frame until deadline')

    except socket.timeout:
        raise NetconfigError('LLDP socket timeout')
    except (socket.error, IOError) as e:
        # IOError covers fcntl.ioctl failures on Python 2 (an alias of OSError on Python 3).
        raise NetconfigError('Socket error on LLDP sniff: {}'.format(e))
    finally:
        if s is not None:
            s.close()


def _parse_lldp_tlvs(lldp_data):
    """Collect the Port ID and System Name TLVs from an LLDPDU.

    :param lldp_data: LLDPDU bytes (the frame after the Ethernet header).
    :returns: dict with 'port_name' (Port ID payload without its leading
        'Port ID subtype' byte) and/or 'switch_name' (System Name payload);
        empty if neither TLV is present.

    Walking stops in two cases:
      * at the 'End of LLDPDU' TLV (type 0);
      * at the first TLV whose header or payload would run past the end of
        ``lldp_data``, so a truncated TLV is never reported with a cut-off value.
    """
    parsed_data = {}
    offset = 0
    data_len = len(lldp_data)

    while offset + LLDP_TLV_HEADER_LEN <= data_len:
        # 2-byte header in network byte order: 7 bits of type, then 9 bits of payload length.
        tlv_header = struct.unpack_from('!H', lldp_data, offset)[0]
        tlv_type = tlv_header >> 9
        if tlv_type == 0:
            break  # 'End of LLDPDU'

        payload_start = offset + LLDP_TLV_HEADER_LEN
        payload_end = payload_start + (tlv_header & LLDP_TLV_PAYLOAD_LEN_MASK)
        if payload_end > data_len:
            break  # malformed/truncated TLV
        tlv_payload = lldp_data[payload_start:payload_end]

        if tlv_type == LLDP_PORT_ID_TLV:
            # skip the first payload byte, the 'Port ID subtype' field
            parsed_data['port_name'] = tlv_payload[1:]
        elif tlv_type == LLDP_SYS_NAME_TLV:
            parsed_data['switch_name'] = tlv_payload

        offset = payload_end

    return parsed_data


def parse_lldp_packet(interface):
    """Sniff one LLDP frame on ``interface`` and extract port and switch names.

    :param interface: interface name passed to run_lldp_sniff, e.g. 'eth0'.
    :returns: dict with 'port_name' and/or 'switch_name', as raw byte strings
        sliced from the frame.
    :raises NetconfigError: if sniffing fails, or if the frame contains neither
        a Port ID nor a System Name TLV.
    """
    parsed_data = _parse_lldp_tlvs(run_lldp_sniff(interface))
    if not parsed_data:
        raise NetconfigError('Could not parse lldp packet')
    return parsed_data
