import collections

import abc
import inject
import six

from awacs.lib.strutils import flatten_full_id, join_full_uids
from awacs.lib.vectors import cacheutil
from awacs.lib.vectors.vector import Vector, VectorMeta, ValidationError
from awacs.model import cache


# Outcome of MutableVector._get_root_cause_version():
#   version          -- the version considered responsible for the validation failure,
#   error            -- the ValidationError observed right before that version was rolled back,
#   rollback_counter -- how many rollbacks were performed while searching for it.
root_cause_result = collections.namedtuple('root_cause_result', 'version error rollback_counter')


def copy_dict_of_dicts(d):
    """
    Return a two-level copy of ``d``.

    The outer mapping is rebuilt, and every inner mapping is shallow-copied into a new
    ``dict``, so mutating the copy's inner dicts does not affect ``d``.
    """
    copied = {}
    for outer_key, inner in six.iteritems(d):
        copied[outer_key] = dict(inner)
    return copied


class MutableVector(six.with_metaclass(VectorMeta, Vector)):
    """
    A Vector that can be validated and repaired in place.

    Repair works by rolling back individual versions to their counterparts from a known
    valid vector, or removing them when no counterpart exists.

    Subclasses are expected to set __main_version_class__ / __version_classes__ (checked in
    rollback_version) and implement validate(), has_anything_to_validate() and
    remove_orphan_versions().
    """
    __slots__ = ('validated_pbs',)

    __main_version_class__ = None
    __version_classes__ = None

    def __init__(self, validated_pbs, *args, **kwargs):
        """
        :param validated_pbs: validation status messages, keyed by version class name
                              and then by version id
        """
        self.validated_pbs = validated_pbs
        super(MutableVector, self).__init__(*args, **kwargs)

    @abc.abstractmethod
    def validate(self, ctx):
        """
        :type ctx: context.OpCtx
        :raises: ConfigValidationError
        """
        raise NotImplementedError

    @abc.abstractmethod
    def has_anything_to_validate(self):
        """
        :rtype: bool
        """
        raise NotImplementedError

    @abc.abstractmethod
    def remove_orphan_versions(self):
        raise NotImplementedError

    def clone(self):
        """
        Return a copy of this vector.

        Every version dict and the nested validated_pbs mapping are copied, so rollbacks
        performed on the clone do not affect this vector.
        """
        kwargs = {}
        for version_field_name in self.__version_field_names__:
            kwargs[version_field_name] = dict(self.must_get_version_dict(version_field_name))
        return type(self)(self.main_version,
                          validated_pbs=copy_dict_of_dicts(self.validated_pbs),
                          **kwargs)

    def get_version_to_rollback(self, ctx, valid_vector):
        """
        "self" is tryout vector

        Priority when looking for a version to rollback:
        1) Version most recently marked as invalid
        2) Version most recently created that doesn't match valid version

        Versions without a validation status (e.g. ones just restored by a rollback, whose
        status is dropped in _replace_version_with) are treated as "not marked invalid".

        :type ctx: context.OpCtx
        :type valid_vector: Vector
        :rtype Optional[Version]
        """
        if not isinstance(valid_vector, Vector):
            raise NotImplementedError

        first_diff_version = None
        first_invalid_version = None
        first_invalid_version_mtime = None
        for version in self:
            validated_pb = self._get_validated_pb(version)
            # validated_pb may be missing: _get_validated_pb() uses .get() and rollbacks pop entries
            if validated_pb is not None and validated_pb.status == u'False':
                version_mtime = validated_pb.last_transition_time.ToMicroseconds()
                if first_invalid_version is None or first_invalid_version_mtime < version_mtime:
                    first_invalid_version = version
                    first_invalid_version_mtime = version_mtime
            else:
                valid_version = valid_vector.get_version_item(version.vector_field_name, version.id)
                if version != valid_version:
                    if first_diff_version is None or first_diff_version.ctime < version.ctime:
                        first_diff_version = version

        if first_invalid_version:
            ctx.log.debug('found an invalid version to rollback: %s', first_invalid_version)
            return first_invalid_version
        elif first_diff_version:
            ctx.log.debug("invalid version to rollback not found, "
                          "instead found a version that doesn't match valid version: %s",
                          first_diff_version)
            return first_diff_version
        ctx.log.debug('version to rollback not found')
        return None

    def rollback_version(self, valid_vector, version_to_rollback):
        """
        Replace version_to_rollback in this vector with its counterpart from valid_vector.

        If there is no counterpart, the version is removed instead. Orphaned versions are
        then dropped. Modifies this vector in place.

        :type valid_vector: Vector
        :type version_to_rollback: Version
        :rtype: None
        :raises NotImplementedError: if valid_vector is not a Vector
        :raises RuntimeError: if version_to_rollback's type is not handled by this vector
        """
        if not isinstance(valid_vector, Vector):
            raise NotImplementedError

        version_class = type(version_to_rollback)
        if version_class != self.__main_version_class__ and version_class not in self.__version_classes__:
            raise RuntimeError(u'Unsupported version type: {}'.format(version_class))

        valid_version = valid_vector.get_version_item_by_version(version_to_rollback)
        self._replace_version_with(version_to_rollback, valid_version)
        self.remove_orphan_versions()

    def _replace_version_with(self, version, valid_version):
        """
        Swap `version` for `valid_version`, or drop it if valid_version is None.

        Also discard the stored validation status for that version's id.
        """
        if version == self.main_version:
            self.main_version = valid_version
        else:
            version_dict = self.must_get_version_dict_by_version(version)
            if valid_version is not None:
                version_dict[version.id] = valid_version
            else:
                version_dict.pop(version.id, None)
        self.validated_pbs.get(version.__class__.__name__, {}).pop(version.id, None)

    def _get_validated_pb(self, version):
        """Return the stored validation status for `version`, or None if there is none."""
        return self.validated_pbs.get(version.__class__.__name__, {}).get(version.id)

    def transform_until_valid(self, ctx, valid_vector, state_handler):
        """
        While this vector is invalid:
          1) Clone this vector.
          2) Continuously rollback versions from the clone until it becomes valid.
          3) Consider the last rolled back version as the root cause of the invalidness and mark it as invalid in zk.
          4) Rollback root cause version from the original vector and re-validate it.
          4a) If it's valid, mark all versions remaining in this vector as valid.
          4b) If it's invalid, repeat the process starting from (1).

        Rollback means that we replace a version in this vector with version from the known valid vector,
        or remove version completely if valid vector doesn't have a matching valid version.

        :type ctx: context.OpCtx
        :type valid_vector: Vector
        :type state_handler: StateHandler
        :rtype: int
        :returns: total number of rollbacks performed while searching for root causes
        """
        rollback_counter = 0
        if self != valid_vector:
            ctx.log.debug(u'valid_vector: %s', valid_vector)
            ctx.log.debug(u'tryout_vector: %s', self)
        else:
            ctx.log.debug(u'tryout_vector == valid_vector, nothing to validate')
            return rollback_counter
        try:
            root_cause = None  # type: root_cause_result  # noqa
            while 1:
                # this loop is guaranteed to eventually exit, see self._get_root_cause_version()

                if not self.has_anything_to_validate():
                    if self.main_version:
                        if root_cause is not None:
                            msg = six.text_type(root_cause.error)
                        else:
                            msg = u'All included objects are invalid'
                        state_handler.mark_version_as_invalid(self.main_version, msg)
                    return rollback_counter

                try:
                    ctx.log.debug(u'validating tryout_vector')
                    self.validate(ctx)
                    ctx.log.debug(u'new valid vector: %s', self)
                    state_handler.mark_versions_as_valid(list(self))
                    ctx.log.debug(u'marked vector as valid')
                    return rollback_counter
                except ValidationError as error:
                    ctx.log.debug(u'tryout_vector is not valid: %s, trying to find the root cause', error)
                    root_cause = self._get_root_cause_version(ctx, valid_vector, error)
                    ctx.log.debug(u'consider version %s as root cause and roll it back from tryout_vector',
                                  root_cause.version)
                    self.rollback_version(valid_vector, root_cause.version)
                    state_handler.mark_version_as_invalid(root_cause.version,
                                                          six.text_type(root_cause.error))
                    ctx.log.debug(u'new tryout_vector: %s', self)
                    rollback_counter += root_cause.rollback_counter
        finally:
            state_handler.flush()  # if we marked some versions as invalid, save them to ZK all at once

    def _get_root_cause_version(self, ctx, valid_vector, initial_error):
        """
        1) Clone this vector.
        2) Continuously rollback versions from the clone until it becomes valid.
        3) Consider the last rolled back version as the root cause.

        This always returns a version, or throws an exception

        :type ctx: context.OpCtx
        :type valid_vector: Vector
        :type initial_error: ValidationError
        :rtype: root_cause_result
        :raises RuntimeError: if no version to roll back can be found, or a rollback
                              does not change the candidate vector
        """
        rollback_counter = 0
        candidate_vector = self.clone()
        error = initial_error
        while 1:
            version_to_rollback = None

            # Prefer the explicit hint carried by the error, unless it is already the valid version
            if error.cause:
                valid_version = valid_vector.get_version_item_by_version(error.cause)
                if valid_version == error.cause:
                    ctx.log.warning('Rollback hint matches valid version, ignoring it: %s', error.cause)
                else:
                    ctx.log.debug('Error caused by %s', error.cause)
                    version_to_rollback = error.cause

            if not version_to_rollback:
                ctx.log.debug('Choosing rollback version by heuristic')
                version_to_rollback = candidate_vector.get_version_to_rollback(ctx, valid_vector)

            if not version_to_rollback:
                raise RuntimeError("Didn't find version to rollback, validation failed")

            prev = candidate_vector.clone()
            candidate_vector.rollback_version(valid_vector, version_to_rollback)
            candidate_vector.remove_orphan_versions()
            if prev == candidate_vector:
                # RuntimeError does not do %-formatting itself, so format the message explicitly
                raise RuntimeError(u'Infinite loop detected: rolling back {} results in the same vector'.format(
                    version_to_rollback))

            rollback_counter += 1
            ctx.log.debug(u'rolled back %s, validating candidate_vector %s', version_to_rollback, candidate_vector)
            if not candidate_vector.has_anything_to_validate():
                # there's nothing left to roll back, so consider the version just rolled back the root cause
                return root_cause_result(version_to_rollback, error, rollback_counter)

            try:
                candidate_vector.validate(ctx)
                # the last rollback led to successful validation, so consider it a root cause
                return root_cause_result(version_to_rollback, error, rollback_counter)
            except ValidationError as e:
                ctx.log.debug(u'candidate_vector is not valid: %s, continuing to rollback versions', e)
                error = e


class MutableVectorWithBackends(six.with_metaclass(VectorMeta, MutableVector)):
    """
    A MutableVector that additionally tracks backend and endpoint set versions.

    Both are keyed by (namespace_id, id) pairs.
    """
    _cache = inject.attr(cache.IAwacsCache)  # type: cache.AwacsCache

    __main_version_class__ = None
    __version_classes__ = None

    def __init__(self, backend_versions, endpoint_set_versions, validated_pbs, *args, **kwargs):
        """
        :type backend_versions: dict[(six.text_type, six.text_type), BackendVersion]
        :type endpoint_set_versions: dict[(six.text_type, six.text_type), EndpointSetVersion]
        :param validated_pbs: validation status messages, keyed by version class name and then
                              by version id (see MutableVector._get_validated_pb)
        :type validated_pbs: dict[six.text_type, dict[(six.text_type, six.text_type), Any]]
        """
        self.backend_versions = backend_versions
        self.endpoint_set_versions = endpoint_set_versions

        # super() resolves to Vector instead of MutableVector here, this is a workaround
        MutableVector.__init__(self, validated_pbs=validated_pbs, *args, **kwargs)

    @abc.abstractmethod
    def validate(self, ctx):
        raise NotImplementedError

    def has_anything_to_validate(self):
        """Return True if there is a main version and at least one non-deleted backend."""
        if self.main_version is None:
            return False
        for backend_version in six.itervalues(self.backend_versions):
            if not backend_version.deleted:
                return True
        return False

    def remove_orphan_versions(self):
        """Drop endpoint set versions whose backend is no longer present in this vector."""
        for endpoint_set_id in list(self.endpoint_set_versions):
            if endpoint_set_id not in self.backend_versions:
                del self.endpoint_set_versions[endpoint_set_id]

    def validate_backends_and_endpoint_sets(self, ctx, namespace_id, included_backend_ids):
        """
        Check that every alive backend has a live endpoint set compatible with it.

        Also check that every backend in included_backend_ids is present and not deleted.

        :type ctx: context.OpCtx
        :type namespace_id: six.text_type
        :type included_backend_ids: Iterable[(six.text_type, six.text_type)]
        :raises: ValidationError
        """
        # Fetch revisions only for endpoint sets whose own version and backend are both alive
        endpoint_set_rev_pbs = {}
        for endpoint_set_id, endpoint_set_version in six.iteritems(self.endpoint_set_versions):
            if endpoint_set_version.deleted or endpoint_set_version.incomplete:
                continue
            backend_version = self.backend_versions.get(endpoint_set_id)
            if backend_version is None or backend_version.deleted or backend_version.incomplete:
                continue
            endpoint_set_rev_pb = cacheutil.must_get_endpoint_set_revision_with_cache(
                cache=self._cache,
                namespace_id=namespace_id,
                endpoint_set_id=flatten_full_id(namespace_id, endpoint_set_id),
                version=endpoint_set_version.version)
            endpoint_set_rev_pbs[endpoint_set_version] = endpoint_set_rev_pb

        deleted_backends = {}
        for backend_id, backend_version in six.iteritems(self.backend_versions):
            if backend_version.deleted or backend_version.incomplete:
                deleted_backends[backend_id] = backend_version
                continue
            if backend_id not in self.endpoint_set_versions:
                raise ValidationError(
                    u'backend "{}" is not resolved yet'.format(backend_id),
                    cause=backend_version)
            endpoint_set_version = self.endpoint_set_versions[backend_id]
            if endpoint_set_version.deleted or endpoint_set_version.incomplete:
                # Such endpoint sets were skipped above and have no revision. Report a
                # ValidationError instead of failing with KeyError on the lookup below.
                raise ValidationError(
                    u'endpoint set "{}" is deleted or not resolved yet'.format(backend_id),
                    cause=endpoint_set_version)
            endpoint_set_rev_pb = endpoint_set_rev_pbs[endpoint_set_version]
            if backend_version.version not in endpoint_set_rev_pb.meta.backend_versions:
                raise ValidationError(u'endpoint set "{}" is not compatible with its backend'.format(backend_id),
                                      cause=backend_version)

        included_backend_ids = set(included_backend_ids)
        missing_backend_ids = included_backend_ids - set(self.backend_versions)
        if missing_backend_ids:
            # Missing backends have no version in this vector, so there is nothing to point at
            # as the cause. The previous max() over backend_versions.get(...) always yielded
            # None, and raised TypeError on Python 3 when several ids were missing.
            raise ValidationError(
                u'some included backends are missing or not resolved yet: {}'.format(
                    join_full_uids(missing_backend_ids)),
                cause=None)
        included_deleted_backends = included_backend_ids & set(deleted_backends)
        if included_deleted_backends:
            raise ValidationError(
                u'the following backends are deleted: {}'.format(join_full_uids(included_deleted_backends)),
                cause=max(deleted_backends[deleted_id] for deleted_id in included_deleted_backends))
