import gevent
import inject
import six
from six.moves import http_client as httplib

from awacs.lib import l3mgrclient
from awacs.model import zk, cache
from awacs.model.l3_balancer import cacheutil
from awacs.model.l3_balancer.errors import L3ConfigValidationError
from awacs.model.l3_balancer.state_handler import L3BalancerStateHandler
from awacs.model.l3_balancer.vector import L3BalancerVersion, EndpointSetVersion, BackendVersion
from infra.swatlib import metrics


# Metrics registry node at the u'l3-ctl' path of the root registry;
# the counters used by this module are obtained from it.
L3_CTL_REGISTRY = metrics.ROOT_REGISTRY.path(u'l3-ctl')


def join_sorted(iterable):
    """
    Sort the given strings and join them with ', '.

    :param iterable: strings to sort and join
    :returns: the sorted items joined into a single comma-separated string
    """
    ordered = list(iterable)
    ordered.sort()
    return ', '.join(ordered)


class Validator(object):
    """
    Validates the current version vector of a single L3 balancer: versions that make
    the vector fail validation are marked as invalid and rolled back, the versions of
    a vector that passes validation are marked as valid.
    """
    # injected zk storage, used to update the L3 balancer state
    _zk = inject.attr(zk.IZkStorage)  # type: zk.ZkStorage
    # injected L3 manager client, used to look up L3 manager services
    _l3mgr_client = inject.attr(l3mgrclient.IL3MgrClient)  # type: l3mgrclient.L3MgrClient
    # injected awacs cache
    _cache = inject.attr(cache.IAwacsCache)  # type: cache.AwacsCache

    # incremented every time a lookup in the L3 manager fails
    _l3mgr_error_counter = L3_CTL_REGISTRY.get_counter(u'l3mgr-error')

    def __init__(self, namespace_id, l3_balancer_id):
        """
        :param namespace_id: id of the namespace of the L3 balancer
        :param l3_balancer_id: id of the L3 balancer whose state is validated
        """
        self._l3_balancer_id = l3_balancer_id
        self._namespace_id = namespace_id

    def _validate_l3mgr_service_exists(self, ctx, l3mgr_service_id):
        """
        Make sure that the given service can be fetched from the L3 manager.

        Every failed lookup (timeout or L3 manager error) increments the
        `l3mgr-error` counter and is reported as a validation error.

        :type ctx: context.OpCtx
        :param l3mgr_service_id: id of the L3 manager service to look up
        :raises: L3ConfigValidationError
        """
        try:
            self._l3mgr_client.get_service(l3mgr_service_id)
        except (gevent.Timeout, l3mgrclient.L3MgrException) as e:
            # timeouts are only counted and reported, L3 manager errors are logged as well
            if not isinstance(e, gevent.Timeout):
                is_not_found = e.resp is not None and e.resp.status_code == httplib.NOT_FOUND
                if is_not_found:
                    ctx.log.warn(u'%s is missing from l3mgr for some reason', l3mgr_service_id)
                else:
                    ctx.log.exception(u'failed to call get_service')
            self._l3mgr_error_counter.inc()
            raise L3ConfigValidationError(six.text_type(e))

    def _validate_vector(self, ctx, l3_vector):
        """
        Check that the vector is consistent:

        * the L3 manager service from the balancer spec can be fetched;
        * every non-deleted backend has an endpoint set whose revision lists
          the backend's version as compatible;
        * every backend included into the balancer spec is present in the vector
          and is not deleted.

        :param L3Vector l3_vector: vector to validate
        :type ctx: context.OpCtx
        :raises: L3ConfigValidationError
        """
        namespace_id = self._namespace_id
        spec_pb = cacheutil.find_l3_balancer_revision_spec_and_use_cache(
            namespace_id, self._l3_balancer_id, l3_vector.balancer_version.version)

        try:
            self._validate_l3mgr_service_exists(ctx, spec_pb.l3mgr_service_id)
        except L3ConfigValidationError as e:
            # a missing/unreachable L3 manager service is blamed on the balancer version
            e.cause = l3_vector.balancer_version
            raise

        backend_versions = l3_vector.backend_versions
        es_versions = l3_vector.endpoint_set_versions

        # load revisions of the non-deleted endpoint sets that have a non-deleted backend
        es_rev_pbs = {}
        for es_id, es_version in six.iteritems(es_versions):
            if es_version.deleted:
                continue
            paired_backend_version = backend_versions.get(es_id)
            if paired_backend_version is None or paired_backend_version.deleted:
                continue
            es_rev_pbs[es_version] = cacheutil.find_endpoint_set_revision_and_use_cache(
                namespace_id, es_id, es_version.version)

        deleted_backend_ids = {
            backend_id
            for backend_id, backend_version in six.iteritems(backend_versions)
            if backend_version.deleted
        }
        for backend_id, backend_version in six.iteritems(backend_versions):
            if backend_version.deleted:
                continue
            if backend_id not in es_versions:
                raise L3ConfigValidationError(
                    u'{}: backend "{}" is not resolved yet'.format(ctx.id(), backend_id),
                    cause=backend_version)
            es_version = es_versions[backend_id]
            # NOTE(review): this lookup raises KeyError if the endpoint set version is deleted
            # while its backend is not, as such revisions are not loaded above -- confirm
            # that this combination cannot occur
            es_rev_pb = es_rev_pbs[es_version]
            if backend_version.version not in es_rev_pb.meta.backend_versions:
                raise L3ConfigValidationError(
                    u'{}: endpoint set "{}" is not compatible with its backend'.format(ctx.id(), backend_id),
                    cause=max(backend_version, es_version))

        included_backend_ids = set(cacheutil.get_included_backend_ids(spec_pb))

        missing_backend_ids = included_backend_ids - set(backend_versions)
        if missing_backend_ids:
            raise L3ConfigValidationError(
                u'{}: some included backends are missing or not resolved yet: {}'.format(
                    ctx.id(), join_sorted(missing_backend_ids)))

        included_deleted_backends = included_backend_ids & deleted_backend_ids
        if included_deleted_backends:
            raise L3ConfigValidationError(
                u'{}: the following backends are deleted: {}'.format(
                    ctx.id(), join_sorted(included_deleted_backends)))

    @staticmethod
    def _get_version_to_rollback(ctx, error, vector, valid_vector):
        """
        Choose the version that should be rolled back from `vector` after it failed validation.

        If the error names its cause, the cause is returned right away. Otherwise the
        candidates are the versions of `vector` that differ from their counterparts
        in `valid_vector`.

        :type ctx: context.OpCtx
        :param L3ConfigValidationError error: validation error raised for `vector`
        :param L3Vector vector: vector that failed validation
        :param valid_vector: valid vector to compare with, may be None
        :returns: version to roll back or None if there is nothing to roll back
        """
        if error.cause:
            ctx.log.debug(u'error caused by %s', error.cause)
            return error.cause

        # The checks below tolerate a missing valid vector, so the balancer version
        # check must not dereference None either.
        valid_balancer_version = valid_vector.balancer_version if valid_vector is not None else None

        not_validated_versions = []
        if vector.balancer_version != valid_balancer_version:
            not_validated_versions.append(vector.balancer_version)
        for backend_id, backend_version in six.iteritems(vector.backend_versions):
            if not valid_vector or backend_version != valid_vector.backend_versions.get(backend_id):
                not_validated_versions.append(backend_version)
        for endpoint_set_id, endpoint_set_version in six.iteritems(vector.endpoint_set_versions):
            if not valid_vector or endpoint_set_version != valid_vector.endpoint_set_versions.get(endpoint_set_id):
                not_validated_versions.append(endpoint_set_version)

        if not not_validated_versions:
            return None

        # NOTE(review): the first collected candidate is taken, so the pick depends on the
        # iteration order of the version dicts -- confirm this is the intended choice
        rv = not_validated_versions[0]
        if valid_vector is not None:
            # skip a balancer version that has no valid counterpart to return to
            if isinstance(rv, L3BalancerVersion) and (valid_vector.balancer_version is None or
                                                      valid_vector.balancer_version == rv):
                not_validated_versions.remove(rv)
                rv = not_validated_versions[0] if not_validated_versions else None
            # skip an endpoint set version that is absent from the valid vector
            # NOTE(review): the second condition compares a stored version with `rv.version`,
            # whereas the balancer branch compares with `rv` itself -- confirm
            if isinstance(rv, EndpointSetVersion) and (
                    rv.endpoint_set_id not in valid_vector.endpoint_set_versions or
                    valid_vector.endpoint_set_versions[rv.endpoint_set_id] == rv.version):
                not_validated_versions.remove(rv)
                rv = not_validated_versions[0] if not_validated_versions else None
        return rv

    @staticmethod
    def _rollback_version(vector, valid_vector, version_to_rollback):
        """
        Return `vector` with `version_to_rollback` undone.

        The version is replaced with its counterpart from `valid_vector` if there is one,
        otherwise it is removed. A backend and the endpoint set with the same id are
        removed together.

        :param vector: vector to roll the version back from
        :param valid_vector: vector to take the replacement versions from, may be falsy
        :param version_to_rollback: balancer, backend or endpoint set version
        :returns: the resulting vector
        :raises AssertionError: if the version is of an unsupported type
        """
        if isinstance(version_to_rollback, L3BalancerVersion):
            if not valid_vector:
                return vector.remove_balancer_version()
            return vector.replace_balancer_version(valid_vector.balancer_version)

        if isinstance(version_to_rollback, BackendVersion):
            backend_id = version_to_rollback.backend_id
            known_good = valid_vector.backend_versions.get(backend_id) if valid_vector else None
            if known_good:
                return vector.replace_backend_version(backend_id, known_good)
            # nothing to return to: drop the backend and the endpoint set with the same id
            trimmed = vector.remove_backend_version(backend_id)
            if backend_id in trimmed.endpoint_set_versions:
                trimmed = trimmed.remove_endpoint_set_version(backend_id)
            return trimmed

        if isinstance(version_to_rollback, EndpointSetVersion):
            endpoint_set_id = version_to_rollback.endpoint_set_id
            known_good = valid_vector.endpoint_set_versions.get(endpoint_set_id) if valid_vector else None
            if known_good:
                return vector.replace_endpoint_set_version(endpoint_set_id, known_good)
            # nothing to return to: drop the endpoint set and the backend with the same id
            trimmed = vector.remove_endpoint_set_version(endpoint_set_id)
            if endpoint_set_id in trimmed.backend_versions:
                trimmed = trimmed.remove_backend_version(endpoint_set_id)
            return trimmed

        raise AssertionError(u'Unsupported version type: {}'.format(type(version_to_rollback)))

    def _mark_version_as(self, status, version, message, l3_balancer_state_pb):
        """
        Set the validation status of a single version in the L3 balancer state.

        :param status: validation status passed to `set_validated`
        :param version: version whose revision is selected in the state
        :param message: message passed to `set_validated` along with the status
        :type l3_balancer_state_pb: infra.awacs.proto.model_pb2.L3BalancerState
        :rtype: bool
        :returns: result of the last `set_validated` call, False if there were none
        """
        changed = False
        state_updates = self._zk.update_l3_balancer_state(
            self._namespace_id, self._l3_balancer_id, l3_balancer_state_pb=l3_balancer_state_pb)
        for state_pb in state_updates:
            rev_status = L3BalancerStateHandler(state_pb).select_rev(version)
            changed = rev_status.set_validated(status=status, message=message)
            if not changed:
                # nothing has changed, leave the update loop early
                break
        return changed

    def _mark_vector_as_valid(self, vector, message, l3_balancer_state_pb):
        """
        Set the validation status of every version of the vector to 'True'.

        :param vector: vector whose versions are marked as valid
        :param message: message passed to `set_validated` along with the status
        :type l3_balancer_state_pb: infra.awacs.proto.model_pb2.L3BalancerState
        :rtype: bool
        :returns: whether any `set_validated` call of the last update attempt reported a change
        """
        changed = False
        state_updates = self._zk.update_l3_balancer_state(
            self._namespace_id, self._l3_balancer_id, l3_balancer_state_pb=l3_balancer_state_pb)
        for state_pb in state_updates:
            handler = L3BalancerStateHandler(state_pb)
            # start from scratch on every attempt
            changed = False
            for version in vector:
                # `|=` rather than `or`: every version has to be processed
                changed |= handler.select_rev(version).set_validated(status='True', message=message)
            if not changed:
                # nothing has changed, leave the update loop early
                break
        return changed

    def validate(self, ctx, l3_balancer_state_pb):
        """
        Validate the current vector of the L3 balancer and record the outcome in its state.

        The current vector is validated as a whole. If it is not valid, versions are rolled
        back one by one from a trial copy until the copy passes validation (or has no backends
        left). The version rolled back last is then marked as invalid and rolled back from the
        current vector, and the procedure starts over. Unless the resulting vector has no
        backends, all its versions are finally marked as valid.

        :type l3_balancer_state_pb: infra.awacs.proto.model_pb2.L3BalancerState
        :type ctx: context.OpCtx
        :rtype: bool
        :returns: whether any validation status has been changed
        :raises RuntimeError: if rolling a version back leaves the vector unchanged
        """
        ctx = ctx.with_op(op_id='validator')
        ctx.log.debug(u'validate(), l3_balancer_state_pb.generation is %s', l3_balancer_state_pb.generation)

        updated = False
        curr_vector, valid_vector, _, _ = cacheutil.l3_balancer_state_to_vectors(l3_balancer_state_pb)

        if valid_vector != curr_vector:
            ctx.log.debug(u'valid_vector: %s', valid_vector)
            ctx.log.debug(u'curr_vector: %s', curr_vector)

        while 1:
            if curr_vector == valid_vector:
                ctx.log.debug(u'curr_vector == valid_vector, nothing to validate')
                # versions may have been marked as invalid on the previous passes,
                # so report the accumulated result instead of a hard-coded False
                return updated
            if not curr_vector.balancer_version:
                ctx.log.debug(u'no curr_vector.balancer_version, nothing to validate')
                break
            if curr_vector.has_no_backends():
                ctx.log.debug(u'curr_vector has no backends')
                break

            curr_vector.omit_orphan_endpoint_sets()
            try:
                ctx.log.debug(u'validating curr_vector')
                self._validate_vector(ctx, curr_vector)
                break
            except L3ConfigValidationError as e:
                ctx.log.debug(u'curr_vector is not valid: %s', e)
                # if current vector is not valid, we must find a change that caused an error
                tryout_curr_vector = curr_vector
                # remember the first error
                error = e
                # and start looking
                while 1:
                    ctx.log.debug(u'start looking for a change that caused tryout_curr_vector to be invalid')
                    # undo the latest change
                    version_to_rollback = self._get_version_to_rollback(ctx, error, tryout_curr_vector, valid_vector)
                    if not version_to_rollback:
                        # NOTE(review): if this happens on the very first pass, curr_vector stays
                        # the same and the outer loop validates it again -- confirm it cannot spin
                        curr_vector = tryout_curr_vector
                        break  # return to the outer loop, so we can try to validate again

                    prev, tryout_curr_vector = (tryout_curr_vector,
                                                self._rollback_version(tryout_curr_vector,
                                                                       valid_vector,
                                                                       version_to_rollback))
                    if prev == tryout_curr_vector:
                        # exceptions do not interpolate their arguments the way logging calls do,
                        # so the message has to be formatted explicitly
                        raise RuntimeError(
                            u'Infinite loop detected: rolling back %s from %s:%s results in the same vector' % (
                                version_to_rollback, self._namespace_id, self._l3_balancer_id))

                    ctx.log.debug(u'rolled back %s, validating tryout_curr_vector %s',
                                  version_to_rollback, tryout_curr_vector)

                    tryout_curr_vector.omit_orphan_endpoint_sets()
                    if tryout_curr_vector.has_no_backends():
                        # it means that we just rolled back last non-deleted invalid backend version,
                        # let's consider tryout_curr_vector valid to stop here
                        # and rollback this version from the `curr_vector`
                        pass
                    else:
                        try:
                            self._validate_vector(ctx, tryout_curr_vector)
                        except L3ConfigValidationError as e:
                            error = e
                            ctx.log.debug(u'tryout_curr_vector is not valid: %s', e)
                            continue  # if the vector is still invalid, remember the error and continue undoing

                    ctx.log.debug(u'tryout_curr_vector is valid, mark %s as invalid and '
                                  u'roll it back from curr_vector', version_to_rollback)
                    # the vector is valid, therefore we can tell that `version_to_rollback` caused the error
                    # we mark change as invalid
                    updated |= self._mark_version_as(u'False', version_to_rollback, six.text_type(error),
                                                     l3_balancer_state_pb=l3_balancer_state_pb)
                    # and remove it from the current vector
                    curr_vector = self._rollback_version(curr_vector, valid_vector, version_to_rollback)
                    ctx.log.debug(u'new curr_vector: %s', curr_vector)
                    break  # return to the outer loop

        if curr_vector.has_no_backends():
            # nothing is left to be marked as valid
            return updated
        ctx.log.debug(u'new valid vector: %s', curr_vector)
        updated |= self._mark_vector_as_valid(curr_vector, message=u'', l3_balancer_state_pb=l3_balancer_state_pb)
        return updated
