import abc
import google.protobuf.message
import inject
import six
from boltons import namedutils
from datetime import datetime
from typing import Type, Set, List, Tuple, final

from awacs.lib.strutils import flatten_full_id, to_full_id
from awacs.lib.vectors import version as ver, vector as vec, vector_mutable as mut_vec
from awacs.lib.vectors.version import choose_newer_version
from awacs.model import zk
from awacs.model.util import find_rev_status_by_revision_id


# The four vectors returned by StateHandler.generate_vectors(), in this order.
vectors = namedutils.namedtuple('vectors', ['current', 'valid', 'in_progress', 'active'])
# Named pair (state_updated, removed_versions); not used within the visible part of this module.
cleanup_result = namedutils.namedtuple('cleanup_result', ['state_updated', 'removed_versions'])

# protobuf message full name (DESCRIPTOR.full_name) -> StateHandler subclass registered for it
REGISTRY = {}


class StateHandlerMeta(abc.ABCMeta):
    """
    Metaclass for StateHandler and its subclasses.

    The root class (the one created directly on top of ``object``) is built as a plain ABC and is not
    registered. Every other class must define ``__protobuf__``, ``__vector_class__``,
    ``__mutable_vector_class__`` and ``__zk_update_method__`` in its own body; it is validated and stored
    in REGISTRY under ``__protobuf__.DESCRIPTOR.full_name``. Defining a second handler for the same
    message raises RuntimeError.
    """

    def __new__(mcs, name, bases, attrs):
        if object in bases:
            # Root class (StateHandler itself): nothing to validate or register. It must still go through
            # ABCMeta.__new__: that is what makes @abc.abstractmethod effective for subclasses and sets up
            # the bookkeeping isinstance()/issubclass() checks against this class rely on.
            return abc.ABCMeta.__new__(mcs, name, bases, attrs)

        # These must be present in the class body itself; inherited values are not looked at.
        message_cls = attrs[u'__protobuf__']
        assert message_cls
        vector_class = attrs[u'__vector_class__']
        assert issubclass(vector_class, vec.Vector)
        mutable_vector_class = attrs[u'__mutable_vector_class__']
        assert issubclass(mutable_vector_class, mut_vec.MutableVector)
        zk_update_method = attrs[u'__zk_update_method__']
        if not callable(zk_update_method):
            # a non-callable value must be the name of a zk.ZkStorage method
            assert hasattr(zk.ZkStorage, zk_update_method)

        # REGISTRY is keyed by the descriptor's full name, so the duplicate check must use the same key
        full_name = message_cls.DESCRIPTOR.full_name
        if full_name in REGISTRY:
            raise RuntimeError(u'StateHandler for {} already exists: {!r}'.format(
                message_cls.__name__, REGISTRY[full_name]))

        # Subclasses get exactly the slots they declare themselves (possibly none).
        slots = set(attrs.get(u'__slots__', ()))
        attrs[u'__slots__'] = sorted(slots)

        klass = REGISTRY[full_name] = abc.ABCMeta.__new__(mcs, name, bases, attrs)

        return klass


class Rev(object):
    """
    Stateless helpers that update the condition sub-messages (``validated`` / ``active`` / ``in_progress``)
    of a single revision status protobuf.

    Every helper accepts None (and then does nothing) and reports whether it actually modified the
    message, so callers can OR the results together to decide whether a save is needed.
    """
    __slots__ = ()

    @staticmethod
    @final
    def set_validated(rev_pb, status, message):
        """
        Set the ``validated`` condition to (status, message) and refresh its transition time.

        :type rev_pb: google.protobuf.message.Message | None
        :type status: six.text_type
        :type message: six.text_type
        :return: True if rev_pb was modified; False if it is None or already has this status and message
        :rtype: bool
        """
        if rev_pb is None:
            return False
        assert rev_pb.HasField('validated')
        condition_pb = rev_pb.validated
        already_set = condition_pb.status == status and condition_pb.message == message
        if already_set:
            return False
        condition_pb.status = status
        condition_pb.message = message
        condition_pb.last_transition_time.GetCurrentTime()
        return True

    @staticmethod
    @final
    def set_active(rev_pb, status):
        """
        Set the ``active`` condition status and refresh its transition time.

        :type rev_pb: google.protobuf.message.Message | None
        :type status: six.text_type
        :return: True if rev_pb was modified; False if it is None or already has this status
        :rtype: bool
        """
        if rev_pb is None:
            return False
        assert rev_pb.HasField('active')
        condition_pb = rev_pb.active
        if condition_pb.status == status:
            return False
        condition_pb.status = status
        condition_pb.last_transition_time.GetCurrentTime()
        return True

    @staticmethod
    @final
    def clear_in_progress(rev_pb):
        """
        Set the ``in_progress`` condition status to 'False' and refresh its transition time.

        :type rev_pb: google.protobuf.message.Message | None
        :return: True if rev_pb was modified; False if it is None or its status is already 'False'
        :rtype: bool
        """
        if rev_pb is None:
            return False
        assert rev_pb.HasField('in_progress')
        condition_pb = rev_pb.in_progress
        if condition_pb.status == u'False':
            return False
        condition_pb.status = u'False'
        condition_pb.last_transition_time.GetCurrentTime()
        return True


class StateHandler(six.with_metaclass(StateHandlerMeta, object)):
    """
    Main abstraction for working with state_pb.
    Callers deal with Versions, and StateHandler translates them into state_pb revisions.

    Subclasses MUST define ``__protobuf__``, ``__vector_class__``, ``__mutable_vector_class__``,
    ``__zk_update_method__`` and the ``full_id`` property.

    Important: state_pb MUST NOT live longer than one iteration of discovery OR validation OR transport,
    because it retains a copy of protobuf which is not synced with cache
    """
    __slots__ = (u'_pb', u'_was_updated', u'_has_unsaved_changes')

    # inheritors MUST define these fields
    __protobuf__ = None  # type: Type[google.protobuf.message.Message]
    __vector_class__ = None  # type: Type[vec.Vector]
    __mutable_vector_class__ = None  # type: Type[mut_vec.MutableVector]
    # either the name of a zk.ZkStorage method or a callable; update_zk calls it as
    # method(full_id[0], full_id[1], state_pb) and iterates over the state_pbs it produces
    __zk_update_method__ = None  # type: six.text_type

    _zk = inject.attr(zk.IZkStorage)  # type: zk.ZkStorage

    def __init__(self, pb):
        """
        :param pb: state protobuf; its descriptor full name must match the one of ``__protobuf__``
        :raises AssertionError: if pb is a different message type
        """
        if pb.DESCRIPTOR.full_name != self.__protobuf__.DESCRIPTOR.full_name:
            raise AssertionError(u'{} is not {}'.format(type(pb), self.__protobuf__))
        self._pb = pb
        self._was_updated = False
        self._has_unsaved_changes = False

    @property
    @abc.abstractmethod
    def full_id(self):
        """
        :rtype: Tuple[six.text_type, six.text_type]
        """
        raise NotImplementedError

    def update_zk(self):
        """
        Generator driving a ZK update of this handler's state.

        Yields the state_pbs produced by the configured zk update method. The caller modifies each
        yielded state_pb in place, or stops iterating (``break``) to abandon the update. The handler's
        own pb is replaced by every yielded state_pb. If iteration runs to completion, the handler is
        marked as updated and can no longer be used for another update.

        :raises RuntimeError: if this handler has already completed an update
        """
        if self._was_updated:
            raise RuntimeError(u'Cannot reuse state handler after state was saved to zk')
        if callable(self.__zk_update_method__):
            zk_update_method = self.__zk_update_method__
        else:
            zk_update_method = getattr(self._zk, self.__zk_update_method__)
        full_id = self.full_id
        for state_pb in zk_update_method(full_id[0], full_id[1], self._pb):
            # user of this method should do this assignment for readability,
            # but repeat it here just to make sure that handler always has the latest state_pb
            self._pb = state_pb

            yield state_pb
        else:
            # if the caller didn't "break" out of its loop (which closes this generator at the yield above),
            # then zk update completed successfully
            self._was_updated = True

    def flush(self):
        """
        Save pending changes to ZK.

        Within this class only mark_version_as_invalid records pending changes. Does nothing if there
        are none, or if this handler has already completed an update.
        """
        if self._has_unsaved_changes and not self._was_updated:
            for state_pb in self.update_zk():
                # NOTE(review): update_zk assigns self._pb = state_pb right before yielding, so this copies
                # the message onto itself. Local changes presumably reach ZK only because the zk update
                # method starts from the self._pb passed to it; if it yields a different (fresh) state_pb,
                # e.g. on retry, the local changes are not re-applied. Confirm this is intended.
                state_pb.CopyFrom(self._pb)
            self._has_unsaved_changes = False

    @property
    def was_updated(self):
        """
        :return: True once an update_zk() iteration has run to completion
        :rtype: bool
        """
        return self._was_updated

    def mark_version_as_invalid(self, version, message):
        """
        We don't want to save to ZK here - versions are marked as invalid multiple times,
        and we need to update zk only once (see flush)

        :type version: Version
        :type message: six.text_type
        :return: True if the rev was changed (a missing rev is not an error and returns False)
        :rtype: bool
        """
        rev_pb = self._get_rev_pb(version)
        updated = Rev.set_validated(rev_pb, status=u'False', message=message)
        self._has_unsaved_changes |= updated
        return updated

    def mark_versions_as_valid(self, versions):
        """
        Set validated='True' on the revs of all given versions and save the state to ZK.
        If none of the revs changed, the update is abandoned without saving.

        :type versions: List[Version]
        """
        for state_pb in self.update_zk():
            updated = False
            self._pb = state_pb
            for version in versions:
                rev_pb = self._get_rev_pb(version)
                updated |= Rev.set_validated(rev_pb, status=u'True', message=u'')
            if not updated:
                break

    def update_versions(self, versions_to_add, versions_to_delete):
        """
        Add revs for missing versions and delete the entries of the given versions, then save to ZK.
        If nothing changed, the update is abandoned without saving.

        :type versions_to_add: Set[Version]
        :type versions_to_delete: Set[Version]
        """
        for state_pb in self.update_zk():
            updated = False
            self._pb = state_pb
            for version in versions_to_add:
                updated |= self._add_version_if_missing(version)
            for version in versions_to_delete:
                updated |= self._delete_version(version)
            if not updated:
                break

    def remove_obsolete_versions(self, vector):
        """
        Clean up the state to discard revisions that are no longer relevant.

        This should be given the latest stable vector - it's usually Active vector;
        unless there's no Active state, (e.g. DNSRecord), then it could be Valid or something else.

        :type vector: vec.Vector
        :return: versions removed while processing the last state_pb yielded by update_zk
        :rtype Set[ver.Version]
        """
        removed_versions = set()
        for state_pb in self.update_zk():
            self._pb = state_pb
            removed_versions = self._remove_obsolete_revs(vector)
            updated = len(removed_versions) > 0
            updated |= self._remove_fields_without_revs()
            if not updated:
                break
        return removed_versions

    def generate_vectors(self):
        """
        Build vectors from all revisions present in the state.

        * current -- the newest version of every object (mutable vector; it also receives the ``validated``
          sub-messages of those newest revs, keyed by version class name and version id);
        * valid -- the newest version of every object whose ``validated`` status is 'True';
        * in_progress -- the same for the ``in_progress`` status;
        * active -- the same for the ``active`` status.

        note: we use hasattr instead of HasField because we must support objects that are entirely missing some fields,
        and in that case HasField throws an error

        :rtype vectors[mut_vec.MutableVector, vec.Vector, vec.Vector, vec.Vector]
        """
        current_vers = {}
        valid_vers = {}
        in_progress_vers = {}
        active_vers = {}
        validated_pbs = {}
        for version_class in self._version_classes:
            validated_pbs[version_class.__name__] = {}
            for vers in current_vers, valid_vers, in_progress_vers, active_vers:
                vers[version_class.vector_field_name] = {}
        # the main version is a single item per vector, not a dict keyed by full id
        main_field_name = self._main_version_class.vector_field_name
        validated_pbs[self._main_version_class.__name__] = {}
        for vers in current_vers, valid_vers, in_progress_vers, active_vers:
            vers[main_field_name] = None

        for rev_pb in self._get_main_rev_statuses_pb():
            v = self._main_version_class.from_rev_status_pb(self.full_id, rev_pb)
            current_vers[main_field_name] = choose_newer_version(v, current_vers.get(main_field_name))
            if v == current_vers[main_field_name]:
                validated_pbs[self._main_version_class.__name__][v.id] = rev_pb.validated
            if hasattr(rev_pb, 'validated') and rev_pb.validated.status == u'True':
                valid_vers[main_field_name] = choose_newer_version(v, valid_vers.get(main_field_name))
            if hasattr(rev_pb, 'in_progress') and rev_pb.in_progress.status == u'True':
                in_progress_vers[main_field_name] = choose_newer_version(v, in_progress_vers.get(main_field_name))
            if hasattr(rev_pb, 'active') and rev_pb.active.status == u'True':
                active_vers[main_field_name] = choose_newer_version(v, active_vers.get(main_field_name))

        for version_class in self._version_classes:
            field_name = version_class.vector_field_name
            for flat_id, revs_pb in six.iteritems(self._get_pb_field(version_class.pb_field_name)):
                full_id = to_full_id(self._pb.namespace_id, flat_id)
                for rev_pb in self._get_rev_statuses_pb(revs_pb):
                    v = version_class.from_rev_status_pb(full_id, rev_pb)
                    newest_v = choose_newer_version(v, current_vers[field_name].get(full_id))
                    current_vers[field_name][full_id] = newest_v
                    if v == newest_v:
                        validated_pbs[version_class.__name__][v.id] = rev_pb.validated
                    if hasattr(rev_pb, 'validated') and rev_pb.validated.status == u'True':
                        valid_vers[field_name][full_id] = choose_newer_version(
                            v, valid_vers[field_name].get(full_id))
                    if hasattr(rev_pb, 'in_progress') and rev_pb.in_progress.status == u'True':
                        in_progress_vers[field_name][full_id] = choose_newer_version(
                            v, in_progress_vers[field_name].get(full_id))
                    if hasattr(rev_pb, 'active') and rev_pb.active.status == u'True':
                        active_vers[field_name][full_id] = choose_newer_version(
                            v, active_vers[field_name].get(full_id))

        current_vector = self.__mutable_vector_class__(validated_pbs=validated_pbs, **current_vers)
        valid_vector = self.__vector_class__(**valid_vers)
        in_progress_vector = self.__vector_class__(**in_progress_vers)
        active_vector = self.__vector_class__(**active_vers)
        return vectors(current_vector, valid_vector, in_progress_vector, active_vector)

    @property
    def _main_version_class(self):
        """
        :rtype Type[ver.Version]
        """
        return self.__vector_class__.__main_version_class__

    @property
    def _version_classes(self):
        """
        :rtype Iterable[Type[ver.Version]]
        """
        return self.__vector_class__.__version_classes__

    def _get_main_rev_statuses_pb(self):
        """
        :return: repeated field with the rev statuses of the main version
        """
        revs_pb = getattr(self._pb, self._main_version_class.pb_field_name)
        return self._get_rev_statuses_pb(revs_pb)

    def _get_pb_field(self, pb_field_name):
        """
        :type: pb_field_name: six.text_type
        """
        return getattr(self._pb, pb_field_name)

    def _get_revs_pb(self, version, create=False):
        """
        Find the revisions message that holds revs of the given version's object.

        :type version: ver.Version
        :param create: for non-main versions, create the map entry if it does not exist
        :type create: bool
        :return: revisions message, or None if a non-main object has no entry and create is False
        :rtype: Revisions | None
        """
        if isinstance(version, self._main_version_class):
            return self._get_pb_field(self._main_version_class.pb_field_name)
        assert type(version) in self._version_classes, u'{} not in {}'.format(type(version), self._version_classes)
        version_revs_pb = self._get_pb_field(version.pb_field_name)
        flat_id = flatten_full_id(self._pb.namespace_id, version.id)
        if flat_id in version_revs_pb or create:
            # indexing a protobuf message map inserts a default entry for a missing key
            return version_revs_pb[flat_id]
        return None

    @staticmethod
    def _get_rev_statuses_pb(revs_pb):
        """
        :return: the repeated field of rev statuses of a revisions message
        :raises RuntimeError: if revs_pb has neither a ``statuses`` nor an ``l3_statuses`` field
        """
        if hasattr(revs_pb, 'statuses'):
            return revs_pb.statuses
        elif hasattr(revs_pb, 'l3_statuses'):  # >:-(
            return revs_pb.l3_statuses
        raise RuntimeError(u'{} has neither "statuses" nor "l3_statuses" field'.format(type(revs_pb).__name__))

    def _get_rev_pb(self, version):
        """
        :type version: ver.Version
        :return: rev status matching version.version, or None if the object has no revs in the state
        """
        revs_pb = self._get_revs_pb(version)
        if revs_pb is None:
            return None
        statuses_pb = self._get_rev_statuses_pb(revs_pb)
        return find_rev_status_by_revision_id(statuses_pb, version.version)

    @staticmethod
    def _drop_obsolete_rev_statuses(rev_statuses_pb, version_class, full_id, vector):
        """
        Remove obsolete entries from one repeated field of rev statuses, in place.

        A rev is obsolete if its object is missing from ``vector``, if it is marked as deleted and is the
        very version ``vector`` holds, or if it is older than the version ``vector`` holds.

        :param rev_statuses_pb: repeated field of rev statuses; modified in place
        :type version_class: Type[ver.Version]
        :param full_id: full id of the object these revs belong to
        :type vector: vec.Vector
        :return: versions of the removed revs
        :rtype: Set[ver.Version]
        """
        removed_versions = set()
        kept_revs_pb = []
        for rev_pb in rev_statuses_pb:
            v = version_class.from_rev_status_pb(full_id, rev_pb)
            given_v = vector.get_version_item_by_version(v)
            if (
                    given_v is None  # this object is completely missing from the vector
                    or (v.deleted and v.version == given_v.version)  # this rev is marked as deleted
                    or v.ctime < given_v.ctime  # this object is older than the given version
            ):
                removed_versions.add(v)
            else:
                kept_revs_pb.append(rev_pb)  # preserve this rev
        if removed_versions:
            # rebuild the repeated field rather than deleting from it while iterating
            del rev_statuses_pb[:]
            rev_statuses_pb.extend(kept_revs_pb)
        return removed_versions

    def _remove_obsolete_revs(self, vector):
        """
        Walk through all the revisions in state, and discard revs that are not present or removed in the given vector,
        or that are older than the given vector

        :type vector: vec.Vector
        :rtype: set[Version]
        """
        removed_versions = self._drop_obsolete_rev_statuses(
            self._get_main_rev_statuses_pb(), self._main_version_class, self.full_id, vector)
        for version_class in self._version_classes:
            for flat_id, revs_pb in six.iteritems(self._get_pb_field(version_class.pb_field_name)):
                full_id = to_full_id(self._pb.namespace_id, flat_id)
                removed_versions |= self._drop_obsolete_rev_statuses(
                    self._get_rev_statuses_pb(revs_pb), version_class, full_id, vector)
        return removed_versions

    def _remove_fields_without_revs(self):
        """
        Delete map entries of non-main version classes that have no rev statuses left.

        :return: True if at least one entry was deleted
        :rtype: bool
        """
        updated = False
        for version_class in self._version_classes:
            version_revs_pb = self._get_pb_field(version_class.pb_field_name)
            for flat_id, revs_pb in list(version_revs_pb.items()):  # not iteritems because we modify this dict
                if len(self._get_rev_statuses_pb(revs_pb)) == 0:
                    del version_revs_pb[flat_id]
                    updated = True
        return updated

    def _add_version_if_missing(self, version):
        """
        Append a rev status for the version unless one with the same revision id already exists.
        New revs start as validated='Unknown', in_progress='False', active='False' (for the
        conditions the status message has), all with the same transition time.

        :type version: ver.Version
        :return: True if a rev status was added
        :rtype: bool
        """
        revs_pb = self._get_revs_pb(version, create=True)
        statuses_pb = self._get_rev_statuses_pb(revs_pb)
        if find_rev_status_by_revision_id(statuses_pb, version.version) is not None:
            return False
        status_pb = statuses_pb.add(revision_id=version.version,
                                    deleted=version.deleted,
                                    incomplete=version.incomplete)
        status_pb.ctime.FromMicroseconds(version.ctime)
        # naive UTC datetime: Timestamp.FromDatetime treats naive datetimes as UTC
        utcnow = datetime.utcnow()
        if hasattr(status_pb, 'validated'):
            status_pb.validated.status = 'Unknown'
            status_pb.validated.last_transition_time.FromDatetime(utcnow)
        if hasattr(status_pb, 'in_progress'):
            status_pb.in_progress.status = 'False'
            status_pb.in_progress.last_transition_time.FromDatetime(utcnow)
        if hasattr(status_pb, 'active'):
            status_pb.active.status = 'False'
            status_pb.active.last_transition_time.FromDatetime(utcnow)
        return True

    def _delete_version(self, version):
        """
        Delete the whole map entry (all revs) of a non-main version's object.

        :type version: Version
        :return: True if an entry was deleted
        :rtype: bool
        :raises RuntimeError: if version is a main version
        """
        if isinstance(version, self._main_version_class):
            raise RuntimeError('Cannot remove main version from state')
        revs_pb = self._get_pb_field(version.pb_field_name)
        flat_id = flatten_full_id(self._pb.namespace_id, version.id)
        if flat_id in revs_pb:
            del revs_pb[flat_id]
            return True
        return False
