import copy

import six
import ipaddress  # noqa
from typing import Optional, Set, Tuple

from awacs.model import components
from infra.awacs.proto import model_pb2


# Remote endpoint of the default (non-DZEN) L3 decapsulation tunnel.
L3_DECAPSULATOR_IP = '2a02:6b8:0:3400::aaaa'
# Id of the default L3 decapsulation outbound tunnel; _parse_outbound_tunnels finds the tunnel by this id.
L3_DECAPSULATOR_TUNNEL_ID = 'awacs-l3-decap'
# Template for the default L3 decap tunnel (IPIP6 mode). It is copied into a balancer's container spec,
# which then gets one plain ``from_ip`` rule per IPv4 virtual IP.
DEFAULT_L3_DECAP_TUNNEL = model_pb2.BalancerContainerSpec.OutboundTunnel(
    id=L3_DECAPSULATOR_TUNNEL_ID,
    mode=model_pb2.BalancerContainerSpec.OutboundTunnel.IPIP6,
    remote_ip=L3_DECAPSULATOR_IP
)


# Id of the outbound tunnel used for DZEN balancers; it is also the interface name of the default
# IPv4 route in DEFAULT_DZEN_DECAP_RULE. NOTE(review): "ENCUP" looks like a typo of "ENCAP";
# the name is kept as is because other code may import it.
DZEN_ENCUP_ID = 'tunvk'
# Template for the decap tunnel of DZEN balancers (mode ANY, FOU encapsulation).
# NOTE(review): the mtu / dport values are not explained here — confirm against the DZEN network setup.
DEFAULT_DZEN_DECAP_TUNNEL = model_pb2.BalancerContainerSpec.OutboundTunnel(
    id=DZEN_ENCUP_ID,
    mode=model_pb2.BalancerContainerSpec.OutboundTunnel.ANY,
    remote_ip='2a02:6b8:0:3400::deca',
    mtu=8952,
    rules=[],  # empty in the template: one rule per IPv4 virtual IP is added after copying it into a spec
    encap=model_pb2.BalancerContainerSpec.OutboundTunnel.Encap(
        type=model_pb2.BalancerContainerSpec.OutboundTunnel.Encap.FOU,
        dport=6635
    )
)

# Template for each rule of the DZEN decap tunnel: a default IPv4 route through the DZEN_ENCUP_ID
# interface with a fixed MPLS label. ``from_ip`` is filled in per virtual IP after copying.
DEFAULT_DZEN_DECAP_RULE = model_pb2.BalancerContainerSpec.OutboundTunnel.Rule(
    default_v4_route=model_pb2.BalancerContainerSpec.DefaultRoute(
        iface=DZEN_ENCUP_ID,
        mtu=1350,
        mpls_labels=[1002000]
    )
)

def ensure_shawshank_layer(cache, balancer_spec_pb):
    """
    Pin the shawshank layer component to its latest published version if the balancer needs it.

    The layer is needed when the container spec has inbound tunnels or virtual IPs.
    A layer whose state is already SET is left untouched (its version is not bumped).

    :type cache: IAwacsCache
    :type balancer_spec_pb: model_pb2.BalancerSpec
    :rtype: bool
    :returns: True if this call set the shawshank layer, False otherwise
    :raises RuntimeError: if no published shawshank component version is found
    """
    shawshank_layer_pb = balancer_spec_pb.components.shawshank_layer
    # Enum values are plain ints: compare by value. Identity (`is`) only works by accident
    # thanks to CPython's small-int cache.
    if shawshank_layer_pb.state == shawshank_layer_pb.SET:
        return False  # already set
    container_spec_pb = balancer_spec_pb.container_spec
    if not (container_spec_pb.inbound_tunnels or container_spec_pb.virtual_ips):
        return False  # don't need to set
    shawshank_config = components.get_component_config(model_pb2.ComponentMeta.SHAWSHANK_LAYER)
    latest_shawshank_version = shawshank_config.get_latest_published_version(cache)
    if latest_shawshank_version is None:
        raise RuntimeError(u'shawshank component not found')
    shawshank_layer_pb.version = latest_shawshank_version
    shawshank_layer_pb.state = shawshank_layer_pb.SET
    return True


def update_l3_decap_tunnel_if_needed(balancer_spec_pb):
    """
    Keep the L3 decapsulation outbound tunnel in sync with the balancer's IPv4 virtual IPs.

    The decap tunnel gets one rule per IPv4 virtual IP that is not already listed in the rules of
    another outbound tunnel.
      * DZEN balancers: the tunnel is identified by DZEN_ENCUP_ID and created from
        DEFAULT_DZEN_DECAP_TUNNEL. Every rule is a copy of DEFAULT_DZEN_DECAP_RULE with ``from_ip`` set.
      * All other balancers: the tunnel is identified by L3_DECAPSULATOR_TUNNEL_ID and created from
        DEFAULT_L3_DECAP_TUNNEL, with plain ``from_ip`` rules.
    The tunnel is created when needed and removed when there are no IPv4 virtual IPs left for it.
    Existing rules of the decap tunnel are recomputed from scratch.

    :type balancer_spec_pb: model_pb2.BalancerSpec
    :rtype: bool
    :returns: True if the outbound tunnels of balancer_spec_pb were changed
    """
    is_dzen = balancer_spec_pb.custom_service_settings.service == balancer_spec_pb.custom_service_settings.DZEN
    # The DZEN tunnel has its own id. Looking it up by the generic id would make every call treat
    # the already-created DZEN tunnel as foreign and append a duplicate one.
    decap_tunnel_id = DZEN_ENCUP_ID if is_dzen else L3_DECAPSULATOR_TUNNEL_ID
    outbound_tunnels = balancer_spec_pb.container_spec.outbound_tunnels

    decap_tunnel_pb = None
    non_decap_ip4_addrs = set()
    for tunnel_pb in outbound_tunnels:
        if tunnel_pb.id == decap_tunnel_id:
            decap_tunnel_pb = tunnel_pb
        else:
            # addresses already routed through other tunnels must not be added to the decap tunnel
            non_decap_ip4_addrs.update(rule_pb.from_ip for rule_pb in tunnel_pb.rules)

    # only IPv4 virtual IPs get decap rules
    ipv4_addresses = [
        v_ip.ip for v_ip in balancer_spec_pb.container_spec.virtual_ips
        if v_ip.ip not in non_decap_ip4_addrs
        and isinstance(ipaddress.ip_address(v_ip.ip), ipaddress.IPv4Address)
    ]

    if not ipv4_addresses:
        if decap_tunnel_pb is None:
            return False  # no decap tunnel, and it's not needed
        outbound_tunnels.remove(decap_tunnel_pb)  # decap tunnel is not needed anymore
        return True

    if decap_tunnel_pb is None:
        decap_tunnel_pb = outbound_tunnels.add()
        decap_tunnel_pb.CopyFrom(DEFAULT_DZEN_DECAP_TUNNEL if is_dzen else DEFAULT_L3_DECAP_TUNNEL)
        previous_tunnel_pb = None
    else:
        # snapshot so we only report an update when the rules actually changed
        previous_tunnel_pb = copy.deepcopy(decap_tunnel_pb)
        decap_tunnel_pb.ClearField('rules')

    # rules are refilled for both DZEN and regular balancers
    for ipv4 in ipv4_addresses:
        rule_pb = decap_tunnel_pb.rules.add()
        if is_dzen:
            rule_pb.CopyFrom(DEFAULT_DZEN_DECAP_RULE)
        rule_pb.from_ip = ipv4

    return previous_tunnel_pb is None or decap_tunnel_pb != previous_tunnel_pb


def update_inbound_tunnels_if_needed(balancer_spec_pb):
    """
    Add the default inbound tunnel (an empty ``fallback_ip6`` one) when none is configured.

    Inbound tunnels that are already present are left untouched.

    :type balancer_spec_pb: model_pb2.BalancerSpec
    :rtype: bool
    :returns: True if the default inbound tunnel was added
    """
    inbound_tunnels = balancer_spec_pb.container_spec.inbound_tunnels
    if inbound_tunnels:
        return False  # keep the existing configuration

    default_tunnel_pb = inbound_tunnels.add()
    # mark the (empty) fallback_ip6 sub-message as present
    default_tunnel_pb.fallback_ip6.SetInParent()
    return True


def _parse_outbound_tunnels(balancer_spec_pb):
    """
    Split the outbound tunnels into the L3 decap tunnel and the addresses used by all other tunnels.

    A tunnel is the decap tunnel when its id equals L3_DECAPSULATOR_TUNNEL_ID; if several match,
    the last one wins. The ``from_ip`` of every rule of every other tunnel is collected.

    :type balancer_spec_pb: model_pb2.BalancerSpec
    :rtype: Tuple[Optional[model_pb2.BalancerContainerSpec.OutboundTunnel], Set[six.text_type]]
    :returns: (decap tunnel or None, set of from_ip addresses of non-decap tunnels)
    """
    found_decap_pb = None
    foreign_ip4_addrs = set()
    for outbound_pb in balancer_spec_pb.container_spec.outbound_tunnels:
        if outbound_pb.id != L3_DECAPSULATOR_TUNNEL_ID:
            foreign_ip4_addrs.update(rule_pb.from_ip for rule_pb in outbound_pb.rules)
            continue
        found_decap_pb = outbound_pb
    return found_decap_pb, foreign_ip4_addrs


def configure_tunnels_for_ip_addresses(cache, l7_balancer_pb, description, ip_addresses):
    """
    Add virtual IPs for ip_addresses, then update tunnels and the shawshank layer accordingly.

    Addresses that are not yet among the container spec's virtual IPs are appended in sorted order,
    each with the given description. Existing virtual IPs are never removed, even when they are not
    in ip_addresses. Afterwards the spec is passed through update_inbound_tunnels_if_needed,
    update_l3_decap_tunnel_if_needed and ensure_shawshank_layer.

    Warning: the L3 decap tunnel rules are recomputed from the balancer's virtual IPs, so manual
    changes to those rules are overwritten.

    :type cache: IAwacsCache
    :type l7_balancer_pb: model_pb2.Balancer
    :type description: six.text_type
    :param ip_addresses: IP address strings (any iterable); the caller's collection is not modified
    :type ip_addresses: Set[six.text_type]
    :rtype: bool
    :returns: True if l7_balancer_pb.spec was modified
    """
    container_spec_pb = l7_balancer_pb.spec.container_spec
    existing_ips = {v_ip_pb.ip for v_ip_pb in container_spec_pb.virtual_ips}
    # set() works on a copy (the argument is never mutated) and also accepts frozensets and lists
    new_ips = set(ip_addresses) - existing_ips
    # TODO: obsolete IPs (in spec but not in ip_addresses) are deliberately kept. We should ask the user
    # to remove old IPs manually, so we don't silently break anything.

    updated = bool(new_ips)
    for ip in sorted(new_ips):
        container_spec_pb.virtual_ips.add(ip=ip, description=description)

    updated |= update_inbound_tunnels_if_needed(l7_balancer_pb.spec)
    updated |= update_l3_decap_tunnel_if_needed(l7_balancer_pb.spec)
    updated |= ensure_shawshank_layer(cache, l7_balancer_pb.spec)
    return updated
