import logging
import socket
import ipaddr
import netaddr
from pyroute2 import IPRoute

# IPv6 prefix that get_local_bb_ip() treats as the MTN backbone when it is
# asked to restrict its search to MTN addresses.
MTN_BACKBONE_PREFIX = ipaddr.IPNetwork("2a02:6b8:c00::/40")

DEFAULT_TABLES = (253, 254, 255, 0)  # default, main, local, unspec

# Destinations used by set_default_route() for the default route of each family.
V6_DEF_GW = "::/0"
V4_DEF_GW = "0.0.0.0/0"

# Defaults applied by create_ip6_tun(): V6_DEF_HOPLIMIT is passed as
# 'ip6tnl_ttl', V6_DEF_ENCAPLIMIT as 'ip6tnl_encap_limit' when the caller
# supplies no encap settings.
V6_DEF_HOPLIMIT = 64
V6_DEF_ENCAPLIMIT = 4

# Values for the tunnel protocol attribute:
# IFLA_IP6TNL_PROTO = 4 for ipip6
# IFLA_IP6TNL_PROTO = 41 for ip6ip6
PROTO_ANY = 0
PROTO_IPIP6 = 4
PROTO_IP6IP6 = 41

# Tunnel flag that create_ip6_tun() sets (as 'ip6tnl_flags') when the encap
# settings carry no encapsulation limit.
IP6_TNL_F_IGN_ENCAP_LIMIT = 1

# RTA_ENCAP_TYPE value that get_default_route_params() treats as MPLS.
LWTUNNEL_ENCAP_MPLS = 1

# IP version -> mask that guess_mask() returns for a bare address.
FAMILY_MASK_MAP = {
    4: 32,
    6: 128,
}

# NOTE(review): not referenced by the functions visible here; presumably the
# kernel's default ip6tnl device name - confirm against callers.
FALLBACK_IFACE = 'ip6tnl0'


def create_ip6_tun(name, proto, local_ip, remote_ip, mtu, encap=None):
    """Create an ``ip6tnl`` link, set its MTU, bring it up and return its index.

    Args:
        name: name of the interface to create.
        proto: value for the link's ``ip6tnl_proto`` attribute.
        local_ip: local tunnel endpoint.
        remote_ip: remote tunnel endpoint.
        mtu: MTU for the new link.
        encap: optional mapping with the keys ``'type'``, ``'sport'``,
            ``'dport'`` and ``'limit'``.  Falsy ports are skipped; a falsy
            limit sets ``IP6_TNL_F_IGN_ENCAP_LIMIT`` instead of a limit.
            Without ``encap`` the limit defaults to ``V6_DEF_ENCAPLIMIT``.

    Returns:
        Index of the created link.
    """
    link_params = dict(
        ifname=name,
        kind='ip6tnl',
        ip6tnl_proto=proto,
        ip6tnl_local=local_ip,
        ip6tnl_remote=remote_ip,
        ip6tnl_ttl=V6_DEF_HOPLIMIT,
        mtu=mtu,
    )

    if not encap:
        link_params['ip6tnl_encap_limit'] = V6_DEF_ENCAPLIMIT
    else:
        link_params['ip6tnl_encap_type'] = encap['type']
        port_attrs = (
            ('sport', 'ip6tnl_encap_sport'),
            ('dport', 'ip6tnl_encap_dport'),
        )
        for source_key, link_attr in port_attrs:
            if encap[source_key]:
                link_params[link_attr] = encap[source_key]
        if encap['limit']:
            link_params['ip6tnl_encap_limit'] = encap['limit']
        else:
            link_params['ip6tnl_flags'] = IP6_TNL_F_IGN_ENCAP_LIMIT

    with IPRoute() as ipr:
        logging.debug("add link kwargs: {}".format(link_params))
        ipr.link('add', **link_params)
        link_id = ipr.link_lookup(ifname=name)[0]
        # The MTU given at creation time was not reliably applied, so it is
        # set again explicitly before the link is brought up.
        for change in ({'mtu': mtu}, {'state': 'up'}):
            ipr.link('set', index=link_id, **change)
    return link_id


def get_iface_id(name, ipr=None):
    """Return the index of the interface called ``name``, or ``None``.

    Args:
        name: interface name to look up.
        ipr: optional, already open IPRoute-like handle to reuse.  When it is
            ``None`` a temporary ``IPRoute`` is opened for the lookup.

    Returns:
        The ``'index'`` value of the first matching link, or ``None`` (after
        logging a warning) when no link matches.
    """
    def _get_iface_id(name, ipr):
        try:
            link = ipr.get_links(ifname=name)[0]
            return link.get('index')
        except IndexError:
            logging.warning("Could not find {} iface".format(name))
        return None

    # Compare against None explicitly: a caller-supplied handle must be used
    # even if the object happens to evaluate as false.
    if ipr is not None:
        return _get_iface_id(name, ipr)

    with IPRoute() as ipr:
        return _get_iface_id(name, ipr)


def get_iface_name(iface_id, ipr=None):
    """Return the name of the interface with index ``iface_id``, or ``None``.

    Args:
        iface_id: interface index to look up.
        ipr: optional, already open IPRoute-like handle to reuse.  When it is
            ``None`` a temporary ``IPRoute`` is opened for the lookup.

    Returns:
        The ``IFLA_IFNAME`` attribute of the first returned link, or ``None``
        (after logging a warning) when the lookup returns no links.
    """
    def _get_iface_name(iface_id, ipr):
        try:
            link = ipr.get_links(iface_id)[0]
            return link.get_attr('IFLA_IFNAME')
        except IndexError:
            logging.warning("Could not find {} iface".format(iface_id))
        return None

    # Compare against None explicitly: a caller-supplied handle must be used
    # even if the object happens to evaluate as false.
    if ipr is not None:
        return _get_iface_name(iface_id, ipr)

    with IPRoute() as ipr:
        return _get_iface_name(iface_id, ipr)


def get_mpls_labels_from_encap(encap):
    """Collect the MPLS labels carried by a route's encap attribute.

    Args:
        encap: encap attribute of a route, or a falsy value when the route
            has none.  Labels are read from the ``'label'`` key of every
            entry in ``encap['attrs'][0][1]``.

    Returns:
        List of labels in their original order; empty when ``encap`` is falsy.
    """
    if not encap:
        return []
    label_entries = encap['attrs'][0][1]
    return [entry['label'] for entry in label_entries]


def get_default_route_params(family, table, need_ifname=False):
    """Describe the first default route of ``family``.

    Args:
        family: address family passed to ``get_default_routes``.
        table: routing table to query; a falsy value leaves the table
            argument out of the query.
        need_ifname: when true, also resolve the outgoing interface name.

    Returns:
        ``None`` when no default route is reported, otherwise a dict that
        always has ``'iface_id'`` (the route's ``RTA_OIF``) and, when
        available, ``'iface'``, ``'mtu'``, ``'advmss'`` and ``'mpls_labels'``.
    """
    query = {'family': family}
    if table:
        query['table'] = table

    with IPRoute() as ipr:
        routes = ipr.get_default_routes(**query)
        if not routes:
            return None

        first_route = routes[0]
        oif = first_route.get_attr('RTA_OIF')
        result = {'iface_id': oif}
        if need_ifname:
            result['iface'] = get_iface_name(oif, ipr)

        metrics = first_route.get_attr('RTA_METRICS')
        if metrics:
            # Only metrics with a truthy value are copied into the result.
            for metric_attr, key in (('RTAX_MTU', 'mtu'), ('RTAX_ADVMSS', 'advmss')):
                value = metrics.get_attr(metric_attr)
                if value:
                    result[key] = value

        if first_route.get_attr('RTA_ENCAP_TYPE') == LWTUNNEL_ENCAP_MPLS:
            result['mpls_labels'] = get_mpls_labels_from_encap(
                first_route.get_attr('RTA_ENCAP'))
    return result


def set_default_route(mode, family, oif_id, spec, table):
    """Add or replace the default route of ``family`` via interface ``oif_id``.

    Args:
        mode: route command passed to ``IPRoute.route`` (the callers in this
            file use ``'add'`` and ``'replace'``).
        family: ``socket.AF_INET`` selects the IPv4 default destination; any
            other value selects the IPv6 one.
        oif_id: index of the outgoing interface.
        spec: object with ``mtu``, ``advmss`` and ``mpls_labels`` attributes,
            or ``None`` when no metrics or encapsulation are wanted.
        table: routing table id; a falsy value leaves the table unspecified.
    """
    v4 = family == socket.AF_INET

    # check_default_route() forwards its optional spec (default None)
    # unchanged, so a missing spec must not blow up on attribute access.
    if spec is None:
        mtu = advmss = mpls_labels = None
    else:
        mtu = spec.mtu
        advmss = spec.advmss
        mpls_labels = spec.mpls_labels

    logging.info("Setting {} default route {}to iface_id={} mtu={} advmss={} mpls_labels={}".format(
        "v4" if v4 else "v6", "in table {} ".format(table) if table else "", oif_id, mtu, advmss, mpls_labels))
    kwargs = {
        'oif': oif_id,
        'dst': V4_DEF_GW if v4 else V6_DEF_GW,
        # NOTE(review): 253 is presumably RT_SCOPE_LINK - confirm against rtnetlink.
        'scope': 253,
        'family': family,
    }
    if table:
        kwargs['table'] = table
    if mtu or advmss:
        kwargs['metrics'] = {}
        if mtu:
            kwargs['metrics']['mtu'] = mtu
        if advmss:
            kwargs['metrics']['advmss'] = advmss
    if mpls_labels:
        kwargs['encap'] = {'type': 'mpls', 'labels': list(mpls_labels)}
    with IPRoute() as ipr:
        ipr.route(mode, **kwargs)


def check_default_route(family, oif_id, spec=None, table=None):
    """Make sure the default route of ``family`` matches the wanted settings.

    A missing route is added.  An existing one is replaced when its outgoing
    interface differs from ``oif_id`` or when a value requested in ``spec``
    differs from the one currently reported.  A ``spec`` value is compared
    only if it is truthy and the current route reports that attribute.

    Args:
        family: address family of the route (``socket.AF_INET`` means IPv4).
        oif_id: index of the interface the route should point to.
        spec: optional object with ``mtu``, ``advmss`` and ``mpls_labels``
            attributes; forwarded unchanged to ``set_default_route``.
        table: optional routing table id.

    Returns:
        ``True`` when the route was added or replaced, ``False`` otherwise.
    """
    family_tag = "v4" if family == socket.AF_INET else "v6"

    current = get_default_route_params(family, table)
    if current is None:
        logging.info("Default {} route does not exists".format(family_tag))
        set_default_route('add', family, oif_id, spec, table)
        return True

    oif_changed = current['iface_id'] != oif_id
    metrics_changed = False
    encap_changed = False

    if spec:
        if spec.mtu and 'mtu' in current:
            metrics_changed = current['mtu'] != spec.mtu
        # advmss is only consulted while no metric difference has been found.
        if not metrics_changed and spec.advmss and 'advmss' in current:
            metrics_changed = current['advmss'] != spec.advmss
        if spec.mpls_labels and 'mpls_labels' in current:
            encap_changed = current['mpls_labels'] != list(spec.mpls_labels)

    up_to_date = current and not (oif_changed or metrics_changed or encap_changed)
    if up_to_date:
        return False

    logging.info("Default {} route is not configured properly".format(family_tag))
    set_default_route('replace', family, oif_id, spec, table)
    return True


def get_local_bb_ip(iface_name, check_if_in_mtn):
    """Return the first backbone IPv6 address configured on ``iface_name``.

    Args:
        iface_name: name of the interface to inspect.
        check_if_in_mtn: when true, only an address inside
            ``MTN_BACKBONE_PREFIX`` is accepted; otherwise the first
            address returned is taken.

    Returns:
        ``ipaddr.IPNetwork`` built from the address and its prefix length,
        or ``None`` when no address qualifies.
    """
    with IPRoute() as ipr:
        iface_id = ipr.link_lookup(ifname=iface_name)[0]
        # NOTE(review): scope=0 presumably selects global-scope addresses - confirm.
        for entry in ipr.get_addr(index=iface_id, scope=0, family=socket.AF_INET6):
            address = entry.get_attrs('IFA_ADDRESS')[0]
            candidate = ipaddr.IPNetwork('/'.join((address, str(entry['prefixlen']))))
            if not check_if_in_mtn or candidate in MTN_BACKBONE_PREFIX:
                return candidate
    return None


def get_links():
    """Return whatever ``IPRoute.get_links()`` reports, using a temporary handle."""
    with IPRoute() as ipr:
        links = ipr.get_links()
    return links


def get_link_mtu(iface_name):
    """Return the MTU (``IFLA_MTU``) of the interface called ``iface_name``.

    Raises:
        IndexError: when no interface with that name is found.
    """
    with IPRoute() as ipr:
        # link_lookup() yields a list of indices; take the first one, as the
        # other helpers in this file do, instead of passing the whole list on
        # to get_links() as if it were an index.
        link_id = ipr.link_lookup(ifname=iface_name)[0]
        info = ipr.get_links(link_id)[0]
        return info.get_attr('IFLA_MTU')


def get_remote_ip_and_proto(linkobj):
    """Extract the remote endpoint and protocol of an ``ip6tnl`` link.

    Args:
        linkobj: link message exposing ``get_attr``.

    Returns:
        Tuple ``(remote_ip, remote_proto)`` taken from ``IFLA_IP6TNL_REMOTE``
        and ``IFLA_IP6TNL_PROTO``.  Both are ``None`` when the link carries
        no link info or is not of kind ``'ip6tnl'``; a value that could not
        be read stays ``None`` and a warning is logged.
    """
    remote_ip = None
    remote_proto = None
    link_info = linkobj.get_attr('IFLA_LINKINFO')
    if link_info and link_info.get_attr('IFLA_INFO_KIND') == 'ip6tnl':
        link_info_data = link_info.get_attr('IFLA_INFO_DATA')
        try:
            remote_ip = link_info_data.get_attr('IFLA_IP6TNL_REMOTE')
            remote_proto = link_info_data.get_attr('IFLA_IP6TNL_PROTO')
        except Exception:
            # Best effort: e.g. the info data may be missing.  Catching
            # Exception rather than everything lets KeyboardInterrupt and
            # SystemExit propagate.
            logging.warning("Could not get remote_ip or remote_proto:\n{}".format(link_info))

    return (remote_ip, remote_proto)


def guess_mask(addr):
    """Return the ``FAMILY_MASK_MAP`` entry for the IP version of ``addr``."""
    return FAMILY_MASK_MAP[netaddr.IPAddress(addr).version]


def add_addr(iface_name, addr, mask=None):
    """Add ``addr`` to the interface called ``iface_name``.

    Args:
        iface_name: name of the target interface.
        addr: address to add.
        mask: prefix length; when ``None`` it is derived from the address
            with ``guess_mask``.

    Returns:
        Whatever ``IPRoute.addr("add", ...)`` returns.
    """
    prefix_len = guess_mask(addr) if mask is None else mask

    with IPRoute() as ipr:
        link_id = ipr.link_lookup(ifname=iface_name)[0]
        return ipr.addr("add", address=addr, index=link_id, mask=prefix_len)


def flush_addr(iface_name):
    """Flush the addresses of ``iface_name``; return ``IPRoute.flush_addr``'s result."""
    with IPRoute() as ipr:
        link_id = ipr.link_lookup(ifname=iface_name)[0]
        return ipr.flush_addr(index=link_id)


def get_all_addresses():
    """Return the set of ``IFA_ADDRESS`` values of all address records."""
    found = set()
    with IPRoute() as ipr:
        for record in ipr.get_addr():
            found.add(record.get_attr('IFA_ADDRESS'))
    return found


def addr_exists(addr):
    """Tell whether ``addr`` is among the addresses from ``get_all_addresses``.

    Addresses are compared by their numeric ``netaddr.IPAddress`` value.
    """
    known_values = set()
    for known in get_all_addresses():
        known_values.add(netaddr.IPAddress(known).value)
    wanted = netaddr.IPAddress(addr).value
    return wanted in known_values


def get_addr(iface_name):
    """Return the set of ``IFA_ADDRESS`` values configured on ``iface_name``."""
    found = set()
    with IPRoute() as ipr:
        link_id = ipr.link_lookup(ifname=iface_name)[0]
        for record in ipr.get_addr(index=link_id):
            found.add(record.get_attr('IFA_ADDRESS'))
    return found
