import gevent
import inject
import six
from six.moves import urllib_parse, http_client as httplib
from typing import List

from awacs.lib import l3mgrclient
from awacs.lib.strutils import to_full_id
from awacs.lib.vectors import cacheutil
from awacs.lib.vectors.state_handler import StateHandler, Rev
from awacs.lib.vectors.vector import Vector, ValidationError
from awacs.lib.vectors.vector_discovered import DiscoveredVectorWithBackends
from awacs.lib.vectors.vector_mutable import MutableVectorWithBackends
from awacs.lib.vectors.version import L3BalancerVersion, BackendVersion, EndpointSetVersion
from awacs.model.l3_balancer import events
from infra.awacs.proto import model_pb2


class L3BalancerVector(Vector):
    """
    Vector made of one L3 balancer version plus backend and endpoint set versions.

    The three constructor arguments are stored as same-named attributes.
    """
    # Version classes declared for the Vector base class
    # (how the base class consumes them is not visible here).
    __main_version_class__ = L3BalancerVersion
    __version_classes__ = (BackendVersion, EndpointSetVersion)

    def __init__(self, l3_balancer_version, backend_versions, endpoint_set_versions):
        """
        :param l3_balancer_version: main version of the vector
        :param backend_versions: backend versions of the vector
        :param endpoint_set_versions: endpoint set versions of the vector
        """
        # Attributes are assigned before the base constructor runs.
        self.l3_balancer_version = l3_balancer_version
        self.backend_versions = backend_versions
        self.endpoint_set_versions = endpoint_set_versions
        super(L3BalancerVector, self).__init__()


class MutableL3BalancerVector(MutableVectorWithBackends):
    """
    Mutable L3 balancer vector that can validate itself against the cached
    balancer spec and against L3mgr.
    """
    __main_version_class__ = L3BalancerVersion
    __version_classes__ = (BackendVersion, EndpointSetVersion)

    _l3mgr_client = inject.attr(l3mgrclient.IL3MgrClient)  # type: l3mgrclient.L3MgrClient
    _l3mgr_error_counter = events.L3_CTL_REGISTRY.get_counter(u'l3mgr-error')

    def __init__(self, l3_balancer_version, backend_versions, endpoint_set_versions, validated_pbs):
        # The main version is stored before the base constructor runs.
        self.l3_balancer_version = l3_balancer_version
        base_init = super(MutableL3BalancerVector, self).__init__
        base_init(backend_versions=backend_versions,
                  endpoint_set_versions=endpoint_set_versions,
                  validated_pbs=validated_pbs)

    def validate(self, ctx):
        """
        Check the vector in three steps: the L3mgr service exists, the included
        backends and endpoint sets are valid, and a spec with ctl_version >= 2
        declares virtual servers.

        :type ctx: context.OpCtx
        :raises: ValidationError
        """
        assert self.l3_balancer_version
        main_version = self.l3_balancer_version
        namespace_id, l3_balancer_id = main_version.id
        spec_pb = cacheutil.must_get_l3_balancer_revision_spec_with_cache(
            self._cache, namespace_id, l3_balancer_id, main_version.version)

        self._validate_l3mgr_service_existence(ctx, spec_pb.l3mgr_service_id)

        backend_ids = get_included_backend_ids(namespace_id, spec_pb)
        self.validate_backends_and_endpoint_sets(ctx, namespace_id, backend_ids)

        # Older controller versions, or specs that do list virtual servers, are fine.
        if spec_pb.ctl_version < 2 or spec_pb.virtual_servers:
            return
        raise ValidationError(
            u'L3Balancer with ctl_version 2 or greater must have virtual_servers',
            cause=self.l3_balancer_version)

    def _l3mgr_validation_error(self, error):
        """
        Bump the L3mgr error counter and build the ValidationError for `error`.

        :rtype: ValidationError
        """
        self._l3mgr_error_counter.inc()
        return ValidationError(six.text_type(error), cause=self.l3_balancer_version)

    def _validate_l3mgr_service_existence(self, ctx, l3mgr_service_id):
        """
        Fetch the service from L3mgr; a timeout or an L3mgr error is counted
        and re-raised as ValidationError.

        :type ctx: context.OpCtx
        :raises: ValidationError
        """
        try:
            return self._l3mgr_client.get_service(l3mgr_service_id)
        except gevent.Timeout as timeout:
            raise self._l3mgr_validation_error(timeout)
        except l3mgrclient.L3MgrException as l3mgr_error:
            response = l3mgr_error.resp
            not_found = response is not None and response.status_code == httplib.NOT_FOUND
            if not_found:
                ctx.log.warn(u'%s is missing from l3mgr for some reason', l3mgr_service_id)
            else:
                # Logged inside the handler so the active exception is recorded.
                ctx.log.exception(u'failed to call get_service')
            raise self._l3mgr_validation_error(l3mgr_error)


class DiscoveredL3BalancerVector(DiscoveredVectorWithBackends):
    """
    L3 balancer vector assembled from the versions currently present in the cache.
    """
    __main_version_class__ = L3BalancerVersion
    __version_classes__ = (BackendVersion, EndpointSetVersion)

    def __init__(self, l3_balancer_version, backend_versions, endpoint_set_versions):
        # The main version is stored before the base constructor runs.
        self.l3_balancer_version = l3_balancer_version
        base_init = super(DiscoveredL3BalancerVector, self).__init__
        base_init(backend_versions=backend_versions, endpoint_set_versions=endpoint_set_versions)

    @classmethod
    def from_cache(cls, namespace_id, main_id):
        """
        Build the vector for L3 balancer `main_id` from the versions cached
        for `namespace_id`.
        """
        return cls(
            cacheutil.must_get_l3_balancer_version(cls._cache, namespace_id, main_id),
            cacheutil.get_backend_versions(cls._cache, namespace_id),
            cacheutil.get_endpoint_set_versions(cls._cache, namespace_id)
        )

    def get_included_backends(self, version):
        """
        Return the backend ids included by the balancer spec at revision
        `version.version`; the balancer itself is identified by this vector's
        own main version id.
        """
        namespace_id, balancer_id = self.l3_balancer_version.id
        spec_pb = cacheutil.must_get_l3_balancer_revision_spec_with_cache(
            self._cache, namespace_id, balancer_id, version.version)
        return get_included_backend_ids(namespace_id, spec_pb)


class L3Rev(Rev):
    """
    Rev helpers that also maintain the L3mgr config attached to an in-progress revision.
    """
    __slots__ = ()

    @staticmethod
    def set_in_progress(rev_pb, l3mgr_config_pb):
        """
        Mark `rev_pb` as in progress with exactly `l3mgr_config_pb` attached.

        :param rev_pb: revision status message, or None (then nothing is done)
        :param l3mgr_config_pb: L3mgr config to attach
        :returns: True if `rev_pb` was modified, False otherwise
        :rtype: bool
        :raises RuntimeError: if more than one config is already attached
        """
        if rev_pb is None:
            return False

        in_progress_pb = rev_pb.in_progress
        attached_configs = in_progress_pb.meta.l3mgr.configs
        if len(attached_configs) > 1:
            raise RuntimeError(u'More than one L3Mgr config is in progress at the same time: {}'.format(
                attached_configs))

        attach_config = True
        if attached_configs:
            current_pb = attached_configs[0]
            wanted_id = (l3mgr_config_pb.service_id, l3mgr_config_pb.config_id)
            if (current_pb.service_id, current_pb.config_id) == wanted_id:
                if in_progress_pb.status == u'True':
                    # Already in progress with this very config: nothing to change.
                    return False
                attach_config = False
            else:
                # L3mgr config changed, clean it up and set correct info
                del attached_configs[:]

        if attach_config:
            attached_configs.add().CopyFrom(l3mgr_config_pb)
        in_progress_pb.status = u'True'
        in_progress_pb.last_transition_time.GetCurrentTime()
        return True


class L3BalancerStateHandler(StateHandler):
    """
    State handler for the L3 balancer state protobuf (``model_pb2.L3BalancerState``).

    Invariant: only one L3mgr config can be in progress at the same time
    """
    __protobuf__ = model_pb2.L3BalancerState
    __vector_class__ = L3BalancerVector
    __mutable_vector_class__ = MutableL3BalancerVector
    __zk_update_method__ = u'update_l3_balancer_state'

    @property
    def full_id(self):
        """Full id built from the state's namespace id and L3 balancer id."""
        return to_full_id(self._pb.namespace_id, self._pb.l3_balancer_id)

    @property
    def ignore_existing_l3mgr_config(self):
        """Current value of the state's `ignore_existing_l3mgr_config` flag."""
        return self._pb.ignore_existing_l3mgr_config.value

    def mark_versions_as_in_progress(self, versions, l3mgr_config_pb):
        """
        Turn the `ignore_existing_l3mgr_config` flag off and mark every given
        version as in progress with `l3mgr_config_pb` attached.

        :type versions: list[Version]
        :type l3mgr_config_pb: model_pb2.L3mgrConfig
        :returns: whether the last processed state was modified
        :rtype: bool
        """
        updated = False
        # update_zk() comes from the base class; presumably it yields a fresh
        # state for every write attempt -- confirm against StateHandler.
        for state_pb in self.update_zk():
            self._pb = state_pb
            updated = self._set_ignore_existing_l3mgr_config(False, author=u'awacs', comment=u'')
            for version in versions:
                rev_pb = self._get_rev_pb(version)
                updated |= L3Rev.set_in_progress(rev_pb, l3mgr_config_pb)
            if not updated:
                # Nothing changed, so stop iterating.
                break
        return updated

    def reset_in_progress_vector(self, vector, author=u'awacs', comment=u'', ignore_l3mgr_config=True):
        """
        Clear the in-progress status and drop the attached L3mgr configs for
        every version of `vector`.

        When `ignore_l3mgr_config` is true, the `ignore_existing_l3mgr_config`
        flag is switched on first; the skip counter for the vector's weak hash
        is incremented only if the flag actually changed.

        :type vector: L3BalancerVector
        :type author: six.text_type
        :type comment: six.text_type
        :type ignore_l3mgr_config: bool
        :returns: whether the last processed state was modified
        :rtype: bool
        """
        updated = False
        for state_pb in self.update_zk():
            updated = False
            self._pb = state_pb
            if ignore_l3mgr_config:
                set_ignore_updated = self._set_ignore_existing_l3mgr_config(True, author, comment)
                if set_ignore_updated:
                    updated = True
                    self._increment_skip_count(vector.get_weak_hash_str())
            for version in vector:
                rev_pb = self._get_rev_pb(version)
                updated |= L3Rev.clear_in_progress(rev_pb)
                # NOTE(review): L3Rev.set_in_progress tolerates rev_pb being None, but
                # rev_pb is dereferenced unconditionally here -- confirm _get_rev_pb
                # cannot return None for the versions of this vector.
                if rev_pb.in_progress.meta.l3mgr.configs:
                    del rev_pb.in_progress.meta.l3mgr.configs[:]
                    updated |= True
            if not updated:
                break
        return updated

    def handle_l3mgr_config_activation(self, versions):
        """
        Drop all skip counters, then clear the in-progress status of every
        given version and mark it active.

        :type versions: List[ver.Version]
        :returns: whether the last processed state was modified
        :rtype: bool
        """
        updated = False
        for state_pb in self.update_zk():
            updated = False
            self._pb = state_pb
            if len(self._pb.skip_counts) > 0:
                updated = True
                self._pb.ClearField('skip_counts')
            for version in versions:
                rev_pb = self._get_rev_pb(version)
                updated |= L3Rev.clear_in_progress(rev_pb)
                updated |= L3Rev.set_active(rev_pb, status=u'True')
            if not updated:
                break
        return updated

    def get_in_progress_l3mgr_config_pb(self):
        """
        Return the L3mgr config attached to the in-progress revisions, preferring
        the one with the latest ctime, or None if no revision is in progress.

        :rtype: model_pb2.L3mgrConfig
        :raises RuntimeError: if an in-progress revision has no config or more
            than one config attached, or if a revision's config id differs from
            the currently selected one
        """
        l3mgr_config_pb = None
        for version, rev_pb in self.iter_in_progress_versions_and_rev_pbs():
            meta_pb = rev_pb.in_progress.meta  # type: model_pb2.L3ConfigTransportMeta
            if not meta_pb.l3mgr.configs:
                raise RuntimeError(u'No L3mgr configs found for in-progress version {}'.format(version))
            elif len(meta_pb.l3mgr.configs) > 1:
                raise RuntimeError(u'More than one L3Mgr config is in progress for version {}: {}'.format(
                    version, meta_pb.l3mgr.configs))
            else:
                config_pb = meta_pb.l3mgr.configs[0]
            if l3mgr_config_pb is None or config_pb.ctime.ToMicroseconds() > l3mgr_config_pb.ctime.ToMicroseconds():
                l3mgr_config_pb = config_pb
            # NOTE(review): the selection above runs before this comparison, so a newer
            # config with a different id replaces the previous one without raising; only
            # an older or same-aged config with a different id is reported. Confirm this
            # order-dependence is intended.
            full_config_id = (l3mgr_config_pb.service_id, l3mgr_config_pb.config_id)
            if full_config_id != (config_pb.service_id, config_pb.config_id):
                raise RuntimeError(u'Multiple L3mgr config ids are present in different versions: {} and {}'.format(
                    (config_pb.service_id, config_pb.config_id), full_config_id))
        return l3mgr_config_pb

    def get_skip_stuck_count(self, vector_hash):
        """Return the skip counter stored for `vector_hash`, or 0 if there is none."""
        return self._pb.skip_counts.get(vector_hash, 0)

    def iter_in_progress_versions_and_rev_pbs(self):
        """
        Yield `(version, rev_pb)` pairs for every revision whose in-progress
        status is u'True': main revisions first, then the revisions of each
        additional version class.
        """
        for rev_pb in self._get_main_rev_statuses_pb():
            if rev_pb.in_progress.status != u'True':
                continue
            yield self._main_version_class.from_rev_status_pb(self.full_id, rev_pb), rev_pb
        for version_class in self._version_classes:
            for flat_id, rev_statuses_pb in self._get_pb_field(version_class.pb_field_name).items():
                full_id = to_full_id(self._pb.namespace_id, flat_id)
                for rev_pb in self._get_rev_statuses_pb(rev_statuses_pb):
                    if rev_pb.in_progress.status != u'True':
                        continue
                    yield version_class.from_rev_status_pb(full_id, rev_pb), rev_pb

    def _set_ignore_existing_l3mgr_config(self, value, author, comment):
        """
        Set the `ignore_existing_l3mgr_config` flag and record mtime, author and comment.

        :returns: False if the flag already had `value` (nothing is touched), True otherwise
        :rtype: bool
        """
        if self._pb.ignore_existing_l3mgr_config.value == value:
            return False
        self._pb.ignore_existing_l3mgr_config.value = value
        self._pb.ignore_existing_l3mgr_config.mtime.GetCurrentTime()
        self._pb.ignore_existing_l3mgr_config.author = author
        self._pb.ignore_existing_l3mgr_config.comment = comment
        return True

    def _increment_skip_count(self, vector_hash):
        """Add one to the skip counter stored for `vector_hash`."""
        self._pb.skip_counts[vector_hash] += 1


def get_l3_balancer_versions(cache, namespace_id):
    """
    Map every L3 balancer listed by the cache for `namespace_id` to its version,
    keyed by the balancer's (namespace id, id) pair.

    :type cache: AwacsCache
    :param six.text_type namespace_id:
    :rtype: dict[(six.text_type, six.text_type), L3BalancerVersion]
    """
    versions = {}
    for balancer_pb in cache.list_all_l3_balancers(namespace_id):
        meta_pb = balancer_pb.meta
        versions[(meta_pb.namespace_id, meta_pb.id)] = L3BalancerVersion.from_pb(balancer_pb)
    return versions


def get_included_backend_ids(namespace_id, l3_balancer_spec_pb):
    """
    List the full ids of the backends selected by the spec's real servers selector.

    :type namespace_id: six.text_type
    :type l3_balancer_spec_pb: model_pb2.L3BalancerSpec
    :rtype: list[(six.text_type, six.text_type)]
    :raises AssertionError: if the selector type is neither BACKENDS nor BALANCERS
    """
    real_servers_pb = l3_balancer_spec_pb.real_servers
    selector_type = real_servers_pb.type
    selector_cls = model_pb2.L3BalancerRealServersSelector
    if selector_type == selector_cls.BACKENDS:
        selected_pbs = real_servers_pb.backends
    elif selector_type == selector_cls.BALANCERS:
        # system backend's id always matches balancer id
        selected_pbs = real_servers_pb.balancers
    else:
        raise AssertionError(u'Unknown L3 selector type {}'.format(selector_type))
    return [(namespace_id, selected_pb.id) for selected_pb in selected_pbs]


def get_instance_weight(l3_balancer_spec_pb, instance_pb):
    """
    Compute the integer weight to use for an instance.

    :returns: None when the spec does not use endpoint weights; 1 for a weight
        strictly between 0 and 1; otherwise the weight truncated to int, with
        negative results replaced by 0
    """
    if l3_balancer_spec_pb.use_endpoint_weights:
        if 0 < instance_pb.weight < 1:
            # clamp to the minimum supported l3mgr weight
            return 1
        truncated = int(instance_pb.weight)
        return truncated if truncated > 0 else 0
    return None


def make_vs_spec_pbs_from_vs_order_pb(vs_order_pb, ip):
    """
    Build one virtual server spec per ordered port, all on the same IP.

    Each spec takes its traffic type and health check URL from the order.
    Port 443 gets the CT_SSL_GET check type, every other port CT_HTTP_GET.

    :type vs_order_pb: model_pb2.L3BalancerOperationOrder.Content.AddVirtualServerOnNewIP
    :type ip: six.text_type
    :rtype: List[model_pb2.L3BalancerSpec.VirtualServer]
    """
    result = []
    for port in vs_order_pb.ports:
        vs_cls = model_pb2.L3BalancerSpec.VirtualServer
        spec_pb = vs_cls(ip=ip, port=port, traffic_type=vs_order_pb.traffic_type)
        settings_pb = spec_pb.health_check_settings
        settings_pb.url = vs_order_pb.health_check_url
        # keepalived uses SSL_GET without SNI and without checking cert contents, so this should be safe
        settings_pb.check_type = (vs_cls.HealthCheckSettings.CT_SSL_GET if port == 443
                                  else vs_cls.HealthCheckSettings.CT_HTTP_GET)
        result.append(spec_pb)
    return result


def make_l3mgr_balancer_meta(namespace_id, l3_balancer_id, include_meta_prefix=False, locked=False):
    """
    Build the metadata dict describing an awacs-owned L3 balancer.

    The dict has an OWNER key (always u'awacs') and a LINK key with the
    balancer's UI URL, in which both ids are URL-quoted. A LOCKED key set to
    True is added when `locked` is true. Every key is prefixed with u'meta-'
    when `include_meta_prefix` is true.

    :type namespace_id: six.text_type
    :type l3_balancer_id: six.text_type
    :type include_meta_prefix: bool
    :type locked: bool
    :rtype: dict
    """
    key_prefix = u''
    if include_meta_prefix:
        key_prefix = u'meta-'

    meta = {}
    meta[u'{}OWNER'.format(key_prefix)] = u'awacs'
    quoted_namespace_id = urllib_parse.quote(namespace_id)
    quoted_balancer_id = urllib_parse.quote(l3_balancer_id)
    meta[u'{}LINK'.format(key_prefix)] = (
        u'https://nanny.yandex-team.ru/ui/#/awacs/namespaces/list/{}/l3-balancers/list/{}/show/'.format(
            quoted_namespace_id, quoted_balancer_id))
    if locked:
        meta[u'{}LOCKED'.format(key_prefix)] = True
    return meta
