import logging
import netaddr
import socket
import time

import porto

from pyroute2 import IPRoute

from infra.shawshank.lib import util
from infra.shawshank.lib import netlink

from infra.shawshank.proto import tuns_pb2
from infra.awacs.proto import model_pb2 as awacs_model_pb2
from google.protobuf.json_format import Parse as ProtoParse


# File name of the JSON-encoded awacs BalancerContainerSpec; presumably the
# config_path handed to read_spec() by the caller (not referenced in this part
# of the file).
CONFIG_PATH = "awacs-balancer-container-spec.pb.json"

# Interfaces that form_current_tunnels_proto() skips, so they never show up in
# the "current" state and are therefore never diffed, changed or deleted.
DONT_TOUCH_TUNS = ('ip6tnl0', 'ip_ext_tun0')

# tuns_pb2.Tunnel.mode value -> netlink tunnel proto passed to
# netlink.create_ip6_tun() (see create_tunnels()).
# NOTE: "PYROTE" spelling kept as-is; the name may be referenced elsewhere.
PROTO_TO_PYROTE_MODES_MAP = {
    0: netlink.PROTO_IPIP6,
    1: netlink.PROTO_IP6IP6,
    2: netlink.PROTO_ANY,
}

# Inverse of PROTO_TO_PYROTE_MODES_MAP: netlink tunnel proto -> tuns_pb2 mode.
# Keep both maps in sync.
PYROUTE_TO_PROTO_MODES_MAP = {
    netlink.PROTO_IPIP6: 0,
    netlink.PROTO_IP6IP6: 1,
    netlink.PROTO_ANY: 2,
}

# Link kinds considered tunnels (not referenced in this part of the file).
TUN_KINDS = (
    'ip6tnl',
)


# Rule priority used by setup_rules_and_routes() when a rule has no priority.
BASE_RULE_PRIORITY = 10000


def read_spec(config_path):
    """ Load the awacs balancer spec from a JSON file and convert it to the internal proto.

    Unknown JSON fields are ignored. Any failure while reading, parsing or
    converting the file is logged as a warning and yields None instead of
    raising.

    Returns:
        The result of util.awacs_instance_spec_pb_to_internal_pb(), or None on error.
    """
    logging.info("Reading configuration")
    awacs_spec = awacs_model_pb2.BalancerContainerSpec()
    try:
        with open(config_path, 'r') as spec_file:
            raw_json = spec_file.read()
            ProtoParse(raw_json, awacs_spec, ignore_unknown_fields=True)
            return util.awacs_instance_spec_pb_to_internal_pb(awacs_spec)
    except Exception as err:
        # Deliberate best-effort: callers get None and decide what to do.
        logging.warning("Could not parse config {}:\n{}".format(config_path, err))
    return None


def validate_v4_vips_and_from_ips(spec):
    """ Check that IPv4 virtual IPs and IPv4 rule source addresses coincide.

    The set of IPv4 ``from_ip`` values over all ``spec.tunnels[*].rules`` must be
    equal to the set of IPv4 ``spec.virtual_ips[*].ip`` values (compared as
    parsed addresses). IPv6 addresses on either side are ignored.

    Returns:
        True if both sets are equal, False otherwise.
    """
    def v4_values(addresses):
        # Parse each address once and keep only the integer value of IPv4 ones.
        parsed = (netaddr.IPAddress(address) for address in addresses)
        return {ip.value for ip in parsed if ip.version == 4}

    rule_sources = v4_values(rule.from_ip for tunnel in spec.tunnels for rule in tunnel.rules)
    virtual_ips = v4_values(vip.ip for vip in spec.virtual_ips)
    return rule_sources == virtual_ips


def validate_porto_ips_and_vips(spec):
    """ Check that every tunnel IP assigned by porto is a virtual IP of the spec.

    Returns:
        True when get_porto_tun_ips() yields no address outside
        ``spec.virtual_ips``; False when it does (the extra addresses are logged)
        or when get_porto_tun_ips() raises (the error is logged).
    """
    spec_vips = {netaddr.IPAddress(vip.ip).value for vip in spec.virtual_ips}

    try:
        porto_ips = get_porto_tun_ips()
    except Exception as e:
        logging.critical('Could not get porto ips: {}'.format(e))
        return False

    unexpected = porto_ips - spec_vips
    if not unexpected:
        return True

    readable = {netaddr.IPAddress(value) for value in unexpected}
    logging.critical('IP addresses that are not in awacs spec: {}'.format(readable))
    return False


def is_valid_spec(spec, noporto):
    """ Run all spec validators and report whether every one of them passed.

    The porto check is skipped when ``noporto`` is truthy. Validators are not
    short-circuited: each applicable one always runs.
    """
    logging.info("Validating spec...")
    results = [validate_v4_vips_and_from_ips(spec)]
    if not noporto:
        results.append(validate_porto_ips_and_vips(spec))
    return all(results)


def _get_non_inherited_net_props():
    """ Return porto ('net', 'ip') of the nearest container whose net is not 'inherited'.

    Containers are probed with the relative names '.', '..', '../..', ...
    (presumably the current container and its ancestors). Returns (None, None)
    if porto answers with a Permission error before such a container is found.
    The porto connection is always closed, even if an unexpected error is raised.
    """
    rpc = porto.Connection()
    rpc.connect()
    try:
        depth = 0
        while True:
            path = '.' if depth == 0 else '/'.join(['..'] * depth)
            net = rpc.GetProperty(path, 'net')
            if net != 'inherited':
                return net, rpc.GetProperty(path, 'ip')
            depth += 1
    except porto.exceptions.Permission:
        logging.info("Non-inherited net not found")
        return None, None
    finally:
        rpc.disconnect()


def get_porto_tun_ips():
    """ Return integer values of porto-assigned IPs that are not on L3 interfaces.

    Reads the 'net' and 'ip' porto properties of the nearest container with a
    non-inherited net. Addresses of interfaces declared as 'L3 <iface>' in 'net'
    are skipped; every other address from 'ip' is returned.

    Returns:
        A set of ``netaddr.IPAddress(...).value`` integers; empty when no such
        container is reachable or it has no IPs.
    """
    porto_tun_ips = set()

    net, ips = _get_non_inherited_net_props()
    if not ips:
        return porto_tun_ips

    l3_ifaces = set()
    # Example of net:
    # 'L3 veth;ipip6 tun0 2a02:6b8:0:3400::aaaa 2a02:6b8:c0a:100f:0:696:489e:0;MTU tun0 1450;MTU ip6tnl0 1450'
    for entry in net.split(';'):
        fields = entry.split()
        if len(fields) == 2 and fields[0] == 'L3':
            l3_ifaces.add(fields[1])

    # Example of ips:
    # 'veth 2a02:6b8:fc12:100f:0:696:98b5:0;veth 2a02:6b8:c0a:100f:0:696:489e:0;ip6tnl0 2a02:6b8::1:119;tun0 5.45.202.13'
    for entry in ips.split(';'):
        if not entry.strip():
            continue  # tolerate empty entries such as a trailing ';'
        iface_name, iface_ip = entry.split()
        if iface_name not in l3_ifaces:
            porto_tun_ips.add(netaddr.IPAddress(iface_ip).value)

    return porto_tun_ips


def check_vip_addrs_existence(spec):
    """ Return spec addresses that are not configured on any host interface.

    The spec addresses are the virtual IPs plus every rule ``from_ip`` of every
    tunnel. They are compared against all addresses reported by IPRoute.get_addr().

    Returns:
        A set of ``netaddr.IPAddress(...).value`` integers that are missing.
    """
    with IPRoute() as ipr:
        host_addrs = {
            netaddr.IPAddress(addr_msg.get_attrs('IFA_ADDRESS')[0]).value
            for addr_msg in ipr.get_addr()
        }

    wanted = {netaddr.IPAddress(vip.ip).value for vip in spec.virtual_ips}
    # RTCNETWORK-529 - new version of shawshank with old spec can create more than one IP.
    # So, we should check not only VIPs, but those ips in rules too.
    wanted.update(
        netaddr.IPAddress(rule.from_ip).value
        for tunnel in spec.tunnels
        for rule in tunnel.rules
    )
    return wanted - host_addrs


def check_and_add_addr(spec):
    """ Flush the fallback interface and re-add every spec address missing on the host.

    All addresses are flushed from netlink.FALLBACK_IFACE, then the function
    waits 3 seconds. Each address reported missing by check_vip_addrs_existence()
    is added to netlink.FALLBACK_IFACE. Address lists are logged at debug level
    before and after the flush and at the end.

    Raises:
        AssertionError: if an address could not be added. This is an explicit
            raise rather than ``assert``, so it also fires under ``python -O``.
    """
    addrs_before = netlink.get_addr(netlink.FALLBACK_IFACE)
    logging.debug("addrs on {} before flushing: {}".format(netlink.FALLBACK_IFACE, addrs_before))

    logging.debug("flushing {}".format(netlink.FALLBACK_IFACE))
    netlink.flush_addr(netlink.FALLBACK_IFACE)  # not to calc diff

    logging.debug("sleep 3 seconds after flush")
    time.sleep(3)

    addrs_after = netlink.get_addr(netlink.FALLBACK_IFACE)
    logging.debug("addrs on {} after flushing: {}".format(netlink.FALLBACK_IFACE, addrs_after))

    for missing_value in check_vip_addrs_existence(spec):
        na = netaddr.IPAddress(missing_value)
        logging.info("Trying to add missing ip: {}".format(na))
        if add_addr_and_check(netlink.FALLBACK_IFACE, str(na)) is not True:
            raise AssertionError("Could not add address {} to {}".format(na, netlink.FALLBACK_IFACE))

    addrs = netlink.get_all_addresses()
    logging.debug("get_all_addresses(): {}".format(addrs))


def check_default_routes(spec, table=None):
    """ Apply netlink.check_default_route() to the v4 and v6 default routes of ``spec``.

    ``spec`` is any message with ``default_v4_route`` / ``default_v6_route``
    fields; in this file that is a ContainerSpec or a Rule. A route is skipped
    when its ``iface`` is empty or netlink.get_iface_id() gives a falsy id.

    Returns:
        The results of the performed checks OR-ed together (False if none ran).
    """
    changed = False
    route_pairs = (
        (socket.AF_INET, spec.default_v4_route),
        (socket.AF_INET6, spec.default_v6_route),
    )
    with IPRoute() as ipr:
        for family, route in route_pairs:
            if not route.iface:
                continue
            iface_id = netlink.get_iface_id(route.iface, ipr=ipr)
            if not iface_id:
                continue
            changed |= netlink.check_default_route(family, iface_id, route, table)
    return changed


def extend_with_default_route(family, route_msg, table=None):
    """ Copy the current default route parameters of ``family`` into ``route_msg``.

    Uses netlink.get_default_route_params(family, table, need_ifname=True).
    Nothing is written when it returns a falsy value. Otherwise ``iface`` is
    always set, and ``mtu``, ``advmss`` and ``mpls_labels`` only when present.
    """
    params = netlink.get_default_route_params(family, table, need_ifname=True)
    if not params:
        return

    route_msg.iface = params['iface']
    for scalar_field in ('mtu', 'advmss'):
        if scalar_field in params:
            setattr(route_msg, scalar_field, params[scalar_field])
    if 'mpls_labels' in params:
        route_msg.mpls_labels.extend(params['mpls_labels'])


def extend_with_rules_and_routes(msg_w_tuns):
    """ Add info about route/rules to tuns_pb2.Tunnel message.

    For each tunnel in ``msg_w_tuns.tunnels`` (mutated in place), the first
    route through the tunnel link whose table is not in netlink.DEFAULT_TABLES
    determines the tunnel's table. Every IPv4 and IPv6 rule returned for that
    table is appended to ``tun.rules``. Each appended rule gets its source
    (FRA_SRC), its priority (FRA_PRIORITY) and the table's v4/v6 default routes.
    Afterwards, the default v4/v6 routes looked up without a table argument are
    stored in ``msg_w_tuns.default_v4_route`` / ``default_v6_route``.

    Returns:
        The same ``msg_w_tuns`` object.
    """

    logging.info("Extending formed proto with rules and routes")
    with IPRoute() as ipr:
        for tun in msg_w_tuns.tunnels:
            link_id = ipr.link_lookup(ifname=tun.name)[0]
            fam = get_rules_family(tun.mode)

            table_id = None
            for route in ipr.get_routes(family=fam, oif=link_id):
                route_table = route.get_attrs('RTA_TABLE')[0]
                if route_table not in netlink.DEFAULT_TABLES:
                    table_id = route_table
                    break

            if not table_id:
                continue

            # list() keeps the concatenation working whether pyroute2 returns
            # lists or lazy iterators.
            rules = (list(ipr.get_rules(family=socket.AF_INET, table=table_id)) +
                     list(ipr.get_rules(family=socket.AF_INET6, table=table_id)))
            for rule in rules:
                rule_msg = tuns_pb2.Rule()
                rule_msg.table_id = table_id
                rule_msg.from_ip = rule.get_attrs('FRA_SRC')[0]
                rule_msg.priority = rule.get_attrs('FRA_PRIORITY')[0]
                extend_with_default_route(socket.AF_INET, rule_msg.default_v4_route, table_id)
                extend_with_default_route(socket.AF_INET6, rule_msg.default_v6_route, table_id)

                tun.rules.extend([rule_msg])

    extend_with_default_route(socket.AF_INET, msg_w_tuns.default_v4_route)
    extend_with_default_route(socket.AF_INET6, msg_w_tuns.default_v6_route)

    return msg_w_tuns


def form_current_tunnels_proto():
    """ Build a tuns_pb2.ContainerSpec() describing the tunnels configured right now.

    Every link returned by netlink.get_links() that is not listed in
    DONT_TOUCH_TUNS and for which netlink.get_remote_ip_and_proto() returns a
    remote IP and a proto becomes a Tunnel with name, mtu, remote_ip and mode.
    The message is then completed by extend_with_rules_and_routes().
    """
    logging.info("Form proto with current tunnels")
    current = tuns_pb2.ContainerSpec()

    for link in netlink.get_links():
        name = link.get_attr('IFLA_IFNAME')
        if name in DONT_TOUCH_TUNS:
            continue

        remote_ip, remote_proto = netlink.get_remote_ip_and_proto(link)
        if remote_ip is None or remote_proto is None:
            continue  # not a tunnel we can describe

        tunnel = tuns_pb2.Tunnel()
        tunnel.name = name
        tunnel.mtu = netlink.get_link_mtu(name)
        tunnel.remote_ip = remote_ip
        tunnel.mode = PYROUTE_TO_PROTO_MODES_MAP.get(remote_proto)
        current.tunnels.extend([tunnel])

    extend_with_rules_and_routes(current)
    logging.debug("==== FORMED MESSAGE:\n{}\n".format(current))
    return current


def create_tunnels(msg, iface_name, check_if_in_mtn=True):
    """ Create every tunnel of ``msg`` together with its addresses, rules and routes.

    Input: 'msg' is tuns_pb2.ContainerSpec()

    The local endpoint is the address from
    netlink.get_local_bb_ip(iface_name, check_if_in_mtn). For each tunnel the
    following happens in order:
      * an ip6 tunnel is created, with encap parameters when its encap type is not NONE;
      * every rule's from_ip is added to the tunnel interface;
      * rules and routes are configured by setup_rules_and_routes().

    Raises:
        Exception: if the local address cannot be obtained (IndexError from netlink).
        AssertionError: if a rule's from_ip could not be added. This is an explicit
            raise rather than ``assert``, so it also fires under ``python -O``.
    """
    try:
        local_ip_and_prefix = netlink.get_local_bb_ip(iface_name, check_if_in_mtn)
    except IndexError:
        raise Exception("Could not get bb ip from {} device, check mtn is {}".format(iface_name, check_if_in_mtn))
    local_ip = str(local_ip_and_prefix.ip)
    for tun in msg.tunnels:
        logging.info("Create tunnel {}".format(tun.name))
        encap = {}
        if tun.encap.type != tun.encap.NONE:
            encap["type"] = tun.encap.type
            encap["sport"] = tun.encap.sport
            encap["dport"] = tun.encap.dport
            encap["limit"] = tun.encap.limit
        link_id = netlink.create_ip6_tun(tun.name, PROTO_TO_PYROTE_MODES_MAP.get(tun.mode), local_ip, tun.remote_ip, tun.mtu, encap)
        for r in tun.rules:
            if add_addr_and_check(tun.name, r.from_ip) is not True:
                raise AssertionError("Could not add address {} to {}".format(r.from_ip, tun.name))
        setup_rules_and_routes(tun, link_id, msg.default_v4_route, msg.default_v6_route)


def setup_rules_and_routes(tun_msg, link_id, global_default_v4_route, global_default_v6_route):
    """ Configure proper rules and routes.
        Input: tuns_pb2.Tunnel() message, not whole Spec.

        For every rule of the tunnel an ip rule is added in the tunnel's
        family. The priority falls back to BASE_RULE_PRIORITY when unset, and
        ``dst`` is added only when ``to_ip`` is set. Then
        netlink.check_default_route() is applied to the rule's default route of
        that family in the rule's table. Nothing is configured when the tunnel
        is the interface of the global default route of its family.
    """
    logging.info("Creating rules and routes for {}".format(tun_msg.name))
    if not tun_msg.rules:
        return

    family = get_rules_family(tun_msg.mode)
    is_v4 = family == socket.AF_INET
    global_default_route = global_default_v4_route if is_v4 else global_default_v6_route
    if tun_msg.name == global_default_route.iface:
        return

    for rule in tun_msg.rules:
        rule_kwargs = {
            'family': family,
            'priority': rule.priority or BASE_RULE_PRIORITY,
            'table': rule.table_id,
            'src': rule.from_ip,
        }
        if rule.to_ip:
            rule_kwargs['dst'] = rule.to_ip

        with IPRoute() as ipr:
            logging.debug('add rule kwargs: {}'.format(rule_kwargs))
            ipr.rule("add", **rule_kwargs)

        rule_route = rule.default_v4_route if is_v4 else rule.default_v6_route
        netlink.check_default_route(family, link_id, rule_route, rule.table_id)


def change_tunnels(new_container_msg, old_container_msg, iface_name, check_if_in_mtn):
    """ Replace tunnels: delete those in ``old_container_msg``, then create ``new_container_msg``.

    Deletion runs first. The changed tunnels are matched by name in this
    file's diff logic, so the new ones presumably reuse the old interface names.
    """
    delete_tunnels(container_msg=old_container_msg)
    create_tunnels(new_container_msg, iface_name, check_if_in_mtn)


def delete_tunnels(container_msg):
    """ Remove every tunnel of ``container_msg`` and the ip rules attached to it.

    Each tunnel link is brought down and deleted. Then each of its rules is
    deleted by table, source, family and priority. Raises IndexError if a
    tunnel link does not exist.
    """
    for tunnel in container_msg.tunnels:
        logging.info("Delete tunnel: {}".format(tunnel.name))
        rule_family = get_rules_family(tunnel.mode)
        with IPRoute() as ipr:
            index = ipr.link_lookup(ifname=tunnel.name)[0]
            ipr.link('set', index=index, state='down')
            ipr.link('del', index=index)
            for rule in tunnel.rules:
                logging.info("Delete rule: {}".format(rule))
                selector = dict(
                    table=rule.table_id,
                    src=rule.from_ip,
                    family=rule_family,
                    priority=rule.priority,
                )
                ipr.rule('del', **selector)


def get_rules_family(tun_mode):
    """ Return the address family to use for a tunnel's rules.

    AF_INET6 when ``tun_mode`` is the proto mode that PYROUTE_TO_PROTO_MODES_MAP
    assigns to protocol number 41 (IPPROTO_IPV6, i.e. ip6ip6). AF_INET for
    any other mode.
    """
    ip6ip6_mode = PYROUTE_TO_PROTO_MODES_MAP.get(socket.IPPROTO_IPV6)  # 41 on Linux
    return socket.AF_INET6 if tun_mode == ip6ip6_mode else socket.AF_INET


def filter_tunnels_by_name(set_of_tun_names, container_msg):
    """ Return a new tuns_pb2.ContainerSpec() holding copies of the tunnels whose name is in ``set_of_tun_names``.

    Only the ``tunnels`` field is populated; the original order is kept.
    """
    selected = [tunnel for tunnel in container_msg.tunnels if tunnel.name in set_of_tun_names]
    result = tuns_pb2.ContainerSpec()
    result.tunnels.extend(selected)
    return result


def rules_changed(cur_rules, spec_rules):
    """ Determine if there are changes in rules.
        Input: tuns_pb2.Tunnel().rules

        Returns True when the rule counts differ. It also returns True when some
        current rule has no spec rule that matches on all of the following:
          * from_ip (compared as parsed addresses);
          * to_ip, table_id and priority;
          * check_default_routes(spec_r, spec_r.table_id) returning a falsy value.
        Otherwise returns False.

        NOTE(review): check_default_routes() calls netlink.check_default_route(),
        the same helper setup_rules_and_routes() uses to configure routes, so
        this comparison presumably may modify routing state - confirm.
        NOTE(review): matching is one-directional (current -> spec). With equal
        lengths and duplicate current rules (cur=[A, A], spec=[A, B]) the
        missing spec rule B is not reported.
    """

    if len(cur_rules) != len(spec_rules):
        return True

    for cur_r in cur_rules:
        found_full_match = False
        # No early break: every spec rule is examined. Because `and`
        # short-circuits, check_default_routes() only runs for spec rules whose
        # from_ip/to_ip/table_id/priority all match cur_r.
        for spec_r in spec_rules:
            found_full_match |= (
                netaddr.IPAddress(spec_r.from_ip).value == netaddr.IPAddress(cur_r.from_ip).value and
                spec_r.to_ip == cur_r.to_ip and
                spec_r.table_id == cur_r.table_id and
                spec_r.priority == cur_r.priority and
                not check_default_routes(spec_r, spec_r.table_id)
            )

        if not found_full_match:
            return True
    return False


def encap_changed(cur_encap, spec_encap):
    """ Tell whether two encap settings differ.
        Input: tuns_pb2.Tunnel().encap

        Returns True if any of type, sport, dport or limit is different
        (checked in that order, stopping at the first difference).
    """
    compared_fields = ('type', 'sport', 'dport', 'limit')
    return any(
        getattr(spec_encap, field) != getattr(cur_encap, field)
        for field in compared_fields
    )


def _tunnel_params_differ(cur_t, spec_t):
    """ Return a truthy value if the current tunnel differs from its spec counterpart.

    The checks run in this order and stop at the first difference:
      * remote_ip (as parsed addresses);
      * mode, mtu and advmss;
      * rules (rules_changed);
      * encap (encap_changed).
    """
    # NOTE(review): form_current_tunnels_proto() does not fill advmss or encap
    # for the current state, so a spec with non-default values there may be
    # reported as changed on every run - confirm.
    return (
        netaddr.IPAddress(spec_t.remote_ip).value != netaddr.IPAddress(cur_t.remote_ip).value or
        spec_t.mode != cur_t.mode or
        spec_t.mtu != cur_t.mtu or
        spec_t.advmss != cur_t.advmss or
        rules_changed(cur_t.rules, spec_t.rules) or
        encap_changed(cur_t.encap, spec_t.encap)
    )


def compare_tuns_matched_by_name(cur, spec):
    """ Find diffs between tunnels with matched names.
        Input: tuns_pb2.ContainerSpec()

        Returns a (new, old) pair of ContainerSpec messages. ``new`` holds the
        spec versions and ``old`` the current versions of the tunnels whose
        parameters differ; each current tunnel is recorded at most once.
    """
    new = tuns_pb2.ContainerSpec()
    old = tuns_pb2.ContainerSpec()
    for current in cur.tunnels:
        same_name = (wanted for wanted in spec.tunnels if wanted.name == current.name)
        for wanted in same_name:
            if _tunnel_params_differ(current, wanted):
                new.tunnels.extend([wanted])
                old.tunnels.extend([current])
                break
    return new, old


def calc_diff(current_tunnels, spec_tunnels):
    """ Find diffs in configurations.
        Input: tuns_pb2.ContainerSpec() for the current state and for the spec.

        Returns a dict with these keys:
          'new'     -- ContainerSpec of tunnels present only in the spec, or None.
          'changed' -- (new_state, old_state) pair of ContainerSpec for tunnels
                       present in both whose parameters differ, or None.
          'deleted' -- ContainerSpec of tunnels present only in the current
                       state, or None.
          'old'     -- the same value as 'deleted'. It is present only when such
                       tunnels exist and is kept for callers of the previous key.
    """

    logging.info("Calculating diff")
    diff = {
        'new': None,
        'changed': None,
        'deleted': None,
    }

    set_of_current_tun_names = {tun.name for tun in current_tunnels.tunnels}
    set_of_spec_tun_names = {tun.name for tun in spec_tunnels.tunnels}

    new_tunnels_names = set_of_spec_tun_names - set_of_current_tun_names
    old_tunnels_names = set_of_current_tun_names - set_of_spec_tun_names
    matched_tunnels_names = set_of_current_tun_names & set_of_spec_tun_names  # to check if there are some changes

    new_tunnels = filter_tunnels_by_name(new_tunnels_names, spec_tunnels)
    old_tunnels = filter_tunnels_by_name(old_tunnels_names, current_tunnels)
    matched_tunnels_spec = filter_tunnels_by_name(matched_tunnels_names, spec_tunnels)
    matched_tunnels_cur = filter_tunnels_by_name(matched_tunnels_names, current_tunnels)

    changed_tunnels_new_state, changed_tunnels_old_state = compare_tuns_matched_by_name(
        matched_tunnels_cur, matched_tunnels_spec
    )

    if new_tunnels.tunnels:
        diff['new'] = new_tunnels
    if old_tunnels.tunnels:
        # Previously removed tunnels were stored only under the undeclared 'old'
        # key and 'deleted' always stayed None. Fill both so either works.
        diff['old'] = old_tunnels
        diff['deleted'] = old_tunnels
    if changed_tunnels_new_state.tunnels:
        diff['changed'] = (changed_tunnels_new_state, changed_tunnels_old_state)

    logging.debug("========= NEW TUNNELS:\n{}\n".format(new_tunnels))
    logging.debug("========= OLD TUNNELS:\n{}\n".format(old_tunnels))
    logging.debug("========= CHANGED TUNNELS:\n--- NEW:\n{}\n--- OLD:\n{}".format(
        changed_tunnels_new_state,
        changed_tunnels_old_state
    ))

    return diff


def add_addr_and_check(iface_name, addr, mask=None):
    """ Add ``addr`` to ``iface_name`` and report whether the address now exists.

    ``addr`` is validated and normalized through netaddr.IPAddress, which
    raises on invalid input. The result of netlink.add_addr() is only logged.
    The return value is netlink.addr_exists() for the normalized address.
    """
    normalized = str(netaddr.IPAddress(addr))
    add_result = netlink.add_addr(iface_name, normalized, mask=mask)
    logging.info("result for {}: {}".format(normalized, add_result))
    return netlink.addr_exists(normalized)
