"""
"Model object" is any object that is stored in zk+cache[+mongo].
They have a lot of common behavior (for example, add/remove from zk/cache), and some object-specific
(for example, operation names and associated locks).
The goal of this module is to abstract away as many common mechanisms as possible, while not exposing too many
internal details to the end user (in this case a controller or an order processor).

All objects are implemented as instance-less classes, which operate on uids or protobufs, without storing any state.
State is usually contained in protobuf itself, or in a controller, so it doesn't make much sense to wrap protobufs
into other classes. As a plus, we don't create a lot of temporary objects.

Class layout:
- Descriptor is a collection of object constants - paths, names, protobuf class, etc. All other classes use it.
- ModelObject is the main user-facing class and provides high-level abstractions. It is intended to replace the old DAO.
- ModelZkClient is a lower-level zk storage accessor.
- ModelMongoClient is a lower-level mongo storage accessor.
- ModelCache replaces cache.py as protobuf storage. In the future, cache.py will lose its own caches and will
only forward protobufs to the individual ModelCaches.
"""

import abc
import six
import bson
import uuid
import collections
from datetime import datetime
from boltons import strutils, typeutils
from google.protobuf.message import Message, DecodeError
from typing import Type, final, Tuple, TypeVar

from awacs.lib import zk_storage, zookeeper_client, mongo
from awacs.lib.pagination import SliceResult
from awacs.lib.vectors import version as ver
from awacs.model import errors
from infra.swatlib.zk.treecache import SkippedNode, DataNode


# Class registries, filled in by the metaclasses defined in this module.
_MODELS_BY_CLASS_NAME = dict()  # ModelObject classes by class name
_MODELS_BY_ZK_PREFIX = dict()  # ModelObject classes with a descriptor, by desc.zk_prefix

_DESCRIPTORS = dict()  # Descriptor classes (including the base) by class name
_ZK_CLIENTS = dict()  # ModelZkClient classes by class name
_MONGO_CLIENTS = dict()  # ModelMongoClient classes by class name
_CACHES = dict()  # ModelCache classes by class name

# Any protobuf message type; used in Descriptor's type comments.
PMType = TypeVar('PMType', bound=Message)


def get_model_by_zk_prefix(zk_prefix):
    """
    Look up a registered ModelObject class by its descriptor's zk prefix.

    :param zk_prefix: value of ``desc.zk_prefix`` of the wanted model
    :return: the model class, or None if no model with this prefix has been defined
    """
    return _MODELS_BY_ZK_PREFIX.get(zk_prefix)


def get_all_descriptors():
    """
    :return: list of the classes that directly subclass Descriptor (indirect subclasses are not included)
    """
    direct_subclasses = Descriptor.__subclasses__()
    return direct_subclasses


def raise_on_init(*_, **__):
    """
    Replacement ``__init__`` for classes that must only be used through their classmethods.

    Ignores any arguments and always raises RuntimeError.
    """
    message = u"Don't instantiate me, use classmethods"
    raise RuntimeError(message)


class ZkNodeType(object):
    """
    Sentinels describing how a model's nodes are laid out under its zk prefix. Not instantiable.

    boltons sentinels are falsy, so always compare them with ``is`` / ``is not None``.
    """
    __slots__ = ()
    __init__ = raise_on_init
    # Named explicitly: with the default name every boltons sentinel has the same repr,
    # which makes them indistinguishable in logs and error messages.
    SINGLE_NESTED = typeutils.make_sentinel(name=u'SINGLE_NESTED')  # e.g. /namespaces/<namespace_id>/<pb>
    DOUBLE_NESTED = typeutils.make_sentinel(name=u'DOUBLE_NESTED')  # e.g. /domains/<namespace_id>/<domain_id>/<pb>


class AwacsDaemon(object):
    """
    Sentinels identifying awacs daemons; combined into tuples by IncludeInDaemons. Not instantiable.

    boltons sentinels are falsy, so always compare them with ``is`` / ``is not None``.
    """
    __slots__ = ()
    __init__ = raise_on_init
    # Named explicitly: with the default name every boltons sentinel has the same repr,
    # which makes them indistinguishable in logs and error messages.
    WORKER = typeutils.make_sentinel(name=u'WORKER')
    STATUS = typeutils.make_sentinel(name=u'STATUS')
    RESOLVER = typeutils.make_sentinel(name=u'RESOLVER')


class IncludeInDaemons(object):
    """
    Ready-made tuples of AwacsDaemon sentinels, used as ``Descriptor.include_in_daemons`` values.
    Not instantiable.
    """
    __slots__ = ()
    __init__ = raise_on_init
    ONLY_WORKER = (AwacsDaemon.WORKER,)
    ALL_EXCEPT_RESOLVER = (AwacsDaemon.STATUS,) + ONLY_WORKER  # (STATUS, WORKER)
    ALL = ALL_EXCEPT_RESOLVER + (AwacsDaemon.RESOLVER,)  # (STATUS, WORKER, RESOLVER)
    ALL_EXCEPT_STATUS = ALL[1:]  # (WORKER, RESOLVER)
    NOWHERE = tuple()


class cached_class_property(object):  # noqa
    """
    Class-level lazily computed attribute.

    The wrapped function is called with the owner class on first lookup; its result is then set on that
    class under the function's name, so later lookups on the class read the cached value directly.
    """

    def __init__(self, func):
        self.func = func
        self.__doc__ = getattr(func, u'__doc__')
        # keep abc's view of the wrapped function (abstract or not)
        self.__isabstractmethod__ = getattr(func, u'__isabstractmethod__', False)

    def __get__(self, obj, owner):
        result = self.func(owner)
        setattr(owner, self.func.__name__, result)
        return result

    def __repr__(self):
        return u'<{} func={}>'.format(type(self).__name__, self.func)


class DescriptorMeta(abc.ABCMeta):
    """
    Metaclass for Descriptor classes.

    Registers every created class in _DESCRIPTORS (rejecting duplicate class names). For concrete descriptors
    (not declared directly on ``object``) it validates the required attributes and derives ``codec``,
    ``canonical_name``, ``canonical_id``, ``readable_name`` and ``slugified_name``. Created classes cannot be
    instantiated.
    """

    def __new__(mcs, name, bases, attrs):
        if name in _DESCRIPTORS:
            raise RuntimeError(u'Config class "{}" already exists: {!r}'.format(name, _DESCRIPTORS[name]))

        if object not in bases:
            # concrete descriptor: required attributes must be declared in the class body itself
            assert attrs[u'zk_prefix']
            assert attrs[u'zk_node_type'] is not None
            assert attrs[u'include_in_daemons'] is not None
            assert attrs[u'proto_class'] is not None
            assert attrs[u'codec_class'] is not None

            # dedicated codec subclass bound to this descriptor's protobuf class
            attrs[u'codec'] = type('%s%s' % (attrs[u'proto_class'].__name__, 'Codec'),
                                   (attrs[u'codec_class'],),
                                   {'message_cls': attrs[u'proto_class']})

            attrs[u'canonical_name'] = DescriptorMeta.make_canonical_name(attrs[u'zk_prefix'])
            attrs[u'canonical_id'] = attrs[u'canonical_name'] + u'_id'
            attrs[u'readable_name'] = DescriptorMeta.make_readable_name(attrs[u'canonical_name'])
            attrs[u'slugified_name'] = DescriptorMeta.make_slugified_name(attrs[u'canonical_name'])

        # don't allow creating/modifying class instances
        attrs[u'__slots__'] = ()
        attrs[u'__init__'] = raise_on_init

        klass = _DESCRIPTORS[name] = abc.ABCMeta.__new__(mcs, name, bases, attrs)
        return klass

    @staticmethod
    def make_canonical_name(zk_prefix):
        """
        Strip the plural suffix from a zk prefix.

        :return: For example, "namespace_operation" for "namespace_operations",
                 or "balancer" for "balancers_2"
        """
        if zk_prefix.endswith(u's_2'):  # balancers_2 and balancer_states_2 - for historical reasons
            return zk_prefix[:-3]
        elif zk_prefix.endswith(u's'):  # everything else: namespaces, domains, etc.
            return zk_prefix[:-1]
        else:
            return zk_prefix  # just in case

    @staticmethod
    def make_readable_name(canonical_name):
        """
        Title-case the canonical name and replace underscores with spaces (``str.title`` semantics,
        so every word is capitalized).

        :return: For example, "Namespace Operation", or "L3 Balancer State"
        """
        return canonical_name.title().replace(u'_', u' ')

    @staticmethod
    def make_slugified_name(canonical_name):
        """
        :return: For example, "namespace-operation", or "l3-balancer-state"
        """
        return strutils.slugify(canonical_name, delim=u'-')


class Descriptor(six.with_metaclass(DescriptorMeta, object)):
    """
    Collection of constants describing one model kind: zk layout, mongo revisions column, protobuf classes,
    codec and derived names.

    Subclasses declare the "required" attributes; DescriptorMeta fills in the "computed" ones.
    Descriptors are never instantiated; all helpers are classmethods.
    """
    # required
    zk_prefix = None  # type: six.text_type
    zk_node_type = None  # type: ZkNodeType
    mongo_revs_column = None  # type: six.text_type
    rev_proto_class = None  # type: Type[PMType]
    include_in_daemons = None  # type: Tuple[AwacsDaemon]
    proto_class = None  # type: Type[PMType]
    codec_class = None  # type: Type[zk_storage.Codec]

    # computed by metaclass
    codec = None  # type: zk_storage.Codec
    canonical_name = None  # type: six.text_type
    readable_name = None  # type: six.text_type
    slugified_name = None  # type: six.text_type
    canonical_id = None  # type: six.text_type

    @classmethod
    @final
    def zk_path_to_str_uid(cls, zk_path):
        """
        Convert a model zk path (relative to the model's zk prefix) into a human-readable uid.

        :param zk_path: "<id>" for SINGLE_NESTED models, "<namespace_id>/<id>" for DOUBLE_NESTED models
        :return: "<id>" or "<namespace_id>:<id>" respectively
        :raises RuntimeError: if ``zk_node_type`` is not a known ZkNodeType
        """
        if cls.zk_node_type is ZkNodeType.DOUBLE_NESTED:
            parts = zk_path.split(u'/')
            assert len(parts) == 2
            return u':'.join(parts)
        elif cls.zk_node_type is ZkNodeType.SINGLE_NESTED:
            assert u'/' not in zk_path
            return zk_path
        raise RuntimeError(u'{}: unsupported zk_node_type {!r}'.format(cls.__name__, cls.zk_node_type))

    @classmethod
    @final
    def uid_to_zk_path(cls, *full_uid):
        """
        Join id components into a model zk path (relative to the model's zk prefix).

        :param full_uid: (namespace_id, id) for DOUBLE_NESTED models, (id,) for SINGLE_NESTED models
        :return: "<namespace_id>/<id>" or "<id>" respectively
        :raises RuntimeError: if ``zk_node_type`` is not a known ZkNodeType
        """
        if cls.zk_node_type is ZkNodeType.DOUBLE_NESTED:
            assert len(full_uid) == 2
            return u'/'.join(full_uid)
        elif cls.zk_node_type is ZkNodeType.SINGLE_NESTED:
            assert len(full_uid) == 1
            return full_uid[0]
        raise RuntimeError(u'{}: unsupported zk_node_type {!r}'.format(cls.__name__, cls.zk_node_type))

    @classmethod
    @final
    def get_cache_structure(cls):
        """
        :return: treecache structure for this model: a DataNode built with this descriptor's codec, wrapped in
                 SkippedNode once for SINGLE_NESTED models and twice for DOUBLE_NESTED models
        :raises RuntimeError: if ``zk_node_type`` is not a known ZkNodeType
        """
        rv = SkippedNode(DataNode(cls.codec))
        if cls.zk_node_type is ZkNodeType.DOUBLE_NESTED:
            return SkippedNode(rv)
        elif cls.zk_node_type is ZkNodeType.SINGLE_NESTED:
            return rv
        raise RuntimeError(u'{}: unsupported zk_node_type {!r}'.format(cls.__name__, cls.zk_node_type))


class ModelObjectMeta(abc.ABCMeta):
    """
    Metaclass for ModelObject classes.

    Registers every created class in _MODELS_BY_CLASS_NAME (rejecting duplicate class names) and, when the class
    declares a non-None ``desc``, in _MODELS_BY_ZK_PREFIX under ``desc.zk_prefix``. Concrete models (not declared
    directly on ``object``) must declare ``desc``, ``cache`` and ``zk``. Created classes cannot be instantiated.
    """

    def __new__(mcs, name, bases, attrs):
        if name in _MODELS_BY_CLASS_NAME:
            raise RuntimeError(u'Model object "{}" already exists: {!r}'.format(name, _MODELS_BY_CLASS_NAME[name]))

        if object not in bases:
            assert attrs[u'desc'] is not None
            assert attrs[u'cache'] is not None
            assert attrs[u'zk'] is not None

        # don't allow creating/modifying class instances
        attrs[u'__slots__'] = ()
        attrs[u'__init__'] = raise_on_init

        klass = _MODELS_BY_CLASS_NAME[name] = abc.ABCMeta.__new__(mcs, name, bases, attrs)

        # Base classes may omit ``desc`` entirely or leave it None; only models with a descriptor
        # are addressable by zk prefix.
        desc = attrs.get(u'desc')
        if desc is not None:
            _MODELS_BY_ZK_PREFIX[desc.zk_prefix] = klass

        return klass


class ModelObject(six.with_metaclass(ModelObjectMeta, object)):
    """
    User-facing, instance-less model class.

    Ties together a Descriptor, a ModelCache, a ModelZkClient and a ModelMongoClient for one kind of object.
    Subclasses implement create/update/remove.
    """
    # required
    desc = None  # type: Descriptor
    cache = None  # type: ModelCache
    zk = None  # type: ModelZkClient
    mongo = None  # type: ModelMongoClient
    version = None  # type: ver.Version
    state = None  # type: ModelObject

    @classmethod
    @abc.abstractmethod
    def create(cls, *_, **__):
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def update(cls, *_, **__):
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def remove(cls, *_, **__):
        raise NotImplementedError

    @classmethod
    def find_rev_spec(cls, version):
        """
        Return the spec of a specific revision of an object.

        The cached object's spec is used when the cached object is at that revision; otherwise the
        revision is loaded from mongo.

        :param version: either a ver.Version, whose ``id`` is a (namespace_id, object_id) pair, or an object whose
                        ``<desc.canonical_id>`` attribute holds that pair; both expose the revision id as ``version``
        :raises errors.NotFoundError: (from mongo.must_get_rev) if the revision is not cached and not in mongo
        """
        full_id = version.id if isinstance(version, ver.Version) else getattr(version, cls.desc.canonical_id)
        namespace_id, object_id = full_id

        cached_pb = cls.cache.get(namespace_id, object_id)
        if cached_pb is not None and cached_pb.meta.version == version.version:
            return cached_pb.spec
        return cls.mongo.must_get_rev(version.version).spec

    @staticmethod
    def _gen_version_id():
        """:return: a new random (uuid4) version id as text"""
        new_id = uuid.uuid4()
        return six.text_type(new_id)

    @classmethod
    def _gen_new_rev(cls, namespace_id, object_id, version, comment, author, spec_pb, utcnow=None):
        """
        Build (without saving) a revision protobuf of class ``desc.rev_proto_class``.

        :param version: revision id, stored as the revision's ``meta.id``
        :param utcnow: revision creation time; defaults to datetime.utcnow()
        """
        ctime = datetime.utcnow() if utcnow is None else utcnow
        rev_pb = cls.desc.rev_proto_class(spec=spec_pb)
        rev_meta = rev_pb.meta
        rev_meta.id = version
        rev_meta.namespace_id = namespace_id
        setattr(rev_meta, cls.desc.canonical_id, object_id)
        rev_meta.author = author
        rev_meta.comment = comment
        rev_meta.ctime.FromDatetime(ctime)
        return rev_pb

    @classmethod
    def _create_with_rev(cls, meta_pb, login, spec_pb=None, order_content_pb=None, utcnow=None):
        """
        Create a new object: save its first revision to mongo, then create it in zk.

        :param meta_pb: meta to copy into the new object; ctime/mtime/author/version are overwritten
        :param login: author of the object and of its first revision
        :param spec_pb: optional spec to copy into the object
        :param order_content_pb: optional order content to copy into the object
        :param utcnow: creation time; defaults to datetime.utcnow()
        :return: the created object protobuf
        :raises errors.ConflictError: if the object already exists in zk (the just-saved revision is removed first)
        """
        if utcnow is None:
            utcnow = datetime.utcnow()

        pb = cls.desc.proto_class(meta=meta_pb)
        new_meta = pb.meta
        new_meta.CopyFrom(meta_pb)
        new_meta.ctime.FromDatetime(utcnow)
        new_meta.mtime.FromDatetime(utcnow)
        new_meta.author = login
        new_meta.version = cls._gen_version_id()
        if spec_pb is not None:
            pb.spec.CopyFrom(spec_pb)
        if order_content_pb is not None:
            pb.order.content.CopyFrom(order_content_pb)

        first_rev_pb = cls._gen_new_rev(new_meta.namespace_id, new_meta.id, new_meta.version, new_meta.comment,
                                        new_meta.author, pb.spec, utcnow=utcnow)
        cls.mongo.save_rev(first_rev_pb)

        try:
            cls.zk.create(pb)
        except errors.ConflictError:
            # don't leave an orphaned revision behind
            cls.mongo.remove_rev(new_meta.version)
            raise errors.ConflictError('{} "{}" already exists in namespace "{}".'
                                       .format(cls.desc.readable_name, meta_pb.id, meta_pb.namespace_id))
        return pb


class ModelZkClientMeta(abc.ABCMeta):
    """
    Metaclass for ModelZkClient classes.

    Registers every created class in _ZK_CLIENTS (rejecting duplicate class names). For concrete clients (not
    declared directly on ``object``) it derives ``_lock_path`` from ``desc.canonical_name``. Created classes
    cannot be instantiated.
    """

    def __new__(mcs, name, bases, attrs):
        if name in _ZK_CLIENTS:
            raise RuntimeError(u'ZkClient "{}" already exists: {!r}'.format(name, _ZK_CLIENTS[name]))

        is_concrete = object not in bases
        if is_concrete:
            assert attrs[u'desc'] is not None
            canonical_name = attrs[u'desc'].canonical_name
            attrs[u'_lock_path'] = ModelZkClientMeta.make_lock_path(canonical_name)

        # instances are never created or modified
        attrs.update({u'__slots__': (), u'__init__': raise_on_init})

        _ZK_CLIENTS[name] = klass = abc.ABCMeta.__new__(mcs, name, bases, attrs)
        return klass

    @staticmethod
    def make_lock_path(canonical_name):
        """
        :return: name of the model's lock path, e.g. "namespace_operation_locks"
        """
        return u'{}_locks'.format(canonical_name)


class ModelZkClient(six.with_metaclass(ModelZkClientMeta, object)):
    """
    Lower-level zk storage accessor for one model kind.

    Concrete subclasses set ``desc`` and implement the public create/update/remove/get/must_get classmethods
    on top of the ``_``-prefixed helpers below. Paths given to the helpers are relative to ``desc.zk_prefix``
    (itself placed under the awtest prefix, if one is set).
    """
    # required
    desc = None  # type: Descriptor

    # computed
    _lock_path = None  # type: six.text_type
    # ``_zk_client`` (zk_storage.ZkStorageClient) is a cached_class_property defined below

    # internal
    _awtest_prefix = None

    @classmethod
    @abc.abstractmethod
    def create(cls, *_, **__):
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def update(cls, *_, **__):
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def remove(cls, *_, **__):
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def get(cls, *_, **__):
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def must_get(cls, *_, **__):
        raise NotImplementedError

    @classmethod
    @final
    def _lock_zk_paths(cls, keys, namespace_id):
        """
        :return: set of lock node paths, "<_lock_path>/<key>/<namespace_id>" for each key
        """
        return {u'%s/%s/%s' % (cls._lock_path, key, namespace_id) for key in keys}

    @classmethod
    @final
    def _not_found_error(cls, zk_path):
        """
        :return: errors.NotFoundError describing a missing object at ``zk_path``
        """
        return errors.NotFoundError(u'{} "{}" does not exist'.format(cls.desc.readable_name,
                                                                     cls.desc.zk_path_to_str_uid(zk_path)))

    @classmethod
    @final
    def lock(cls, keys, namespace_id, op_id):
        """
        Create lock nodes for all ``keys`` in ``namespace_id`` with one ZkTransactionClient.batch_create call.

        :param op_id: passed to batch_create along with the paths (presumably stored as the lock owner - confirm)
        :raises errors.ConflictError: if the zk transaction fails, e.g. because one of the locks is already taken
        """
        full_zk_paths = cls._lock_zk_paths(keys, namespace_id)
        try:
            zk_storage.ZkTransactionClient.batch_create(full_zk_paths, op_id)
        except zk_storage.KazooTransactionException:
            raise errors.ConflictError(u'Another namespace operation is already processing same objects, '
                                       u'please wait until it finishes')

    @classmethod
    @final
    def unlock(cls, keys, namespace_id):
        """Remove the lock nodes that ``lock`` creates for the same ``keys`` and ``namespace_id``."""
        zk_storage.ZkTransactionClient.batch_remove(cls._lock_zk_paths(keys, namespace_id))

    @cached_class_property
    @final
    def _zk_client(cls):  # noqa
        """
        Storage client for this model, built on first access and then cached on the class.

        :rtype: zk_storage.ZkStorageClient
        """
        prefix = u'%s/%s' % (cls._awtest_prefix, cls.desc.zk_prefix) if cls._awtest_prefix else cls.desc.zk_prefix
        return zk_storage.ZkStorageClient(
            client=zookeeper_client.IZookeeperClient.instance(),
            prefix=prefix,
            codec=cls.desc.codec,
        )

    @classmethod
    @final
    def awtest_set_zk_prefix(cls, prefix):
        """
        Test hook: place this model's zk nodes under ``prefix``.

        Since ``_zk_client`` is cached on first access, this only takes effect if called before that.
        """
        cls._awtest_prefix = prefix

    @classmethod
    @final
    def _create(cls, zk_path, pb):
        """
        Create the node at ``zk_path`` holding ``pb``.

        :raises errors.ConflictError: if the node already exists
        """
        try:
            return cls._zk_client.create(zk_path, pb)
        except zk_storage.NodeAlreadyExistsError:
            raise errors.ConflictError(u'{} "{}" already exists'.format(cls.desc.readable_name,
                                                                        cls.desc.zk_path_to_str_uid(zk_path)))

    @classmethod
    @final
    def _get(cls, zk_path, sync=False):
        """
        :param sync: sync ``zk_path`` with the zk server before reading
        :return: the stored protobuf, or None if there is none (as treated by ``_must_get``)
        """
        if sync:
            cls._zk_client.sync(zk_path)
        return cls._zk_client.get(zk_path)

    @classmethod
    @final
    def _must_get(cls, zk_path, sync=False):
        """
        Like ``_get``, but a missing object is an error.

        :raises errors.NotFoundError: if nothing is stored at ``zk_path``
        """
        pb = cls._get(zk_path, sync=sync)
        if pb is None:
            raise cls._not_found_error(zk_path)
        return pb

    @classmethod
    @final
    def _update(cls, zk_path, pb=None):
        """
        Generator wrapping ``guaranteed_update``: yields the protobuf for the caller to modify in place.

        guaranteed_update may yield more than once (presumably retrying on concurrent modification - see zk_storage).

        :raises errors.NotFoundError: if guaranteed_update yields None (nothing stored at ``zk_path``)
        """
        for m_pb in cls._zk_client.guaranteed_update(zk_path, obj=pb):
            if m_pb is None:
                raise cls._not_found_error(zk_path)
            yield m_pb

    @classmethod
    @final
    def _remove(cls, zk_path):
        """Remove the node at ``zk_path`` (non-recursively)."""
        cls._zk_client.remove(zk_path)

    @classmethod
    @final
    def _remove_recursive(cls, zk_path):
        """Remove the node at ``zk_path`` together with its children."""
        cls._zk_client.remove(zk_path, recursive=True)

    @classmethod
    @final
    def _cancel_order(cls, zk_path, author, comment, forced=False):
        """
        Mark the object's order as cancelled, recording author, comment, current time and the ``forced`` flag.

        :raises errors.NotFoundError: if nothing is stored at ``zk_path``
        """
        for m_pb in cls._zk_client.guaranteed_update(zk_path):
            if m_pb is None:
                raise cls._not_found_error(zk_path)
            m_pb.order.cancelled.value = True
            m_pb.order.cancelled.author = author
            m_pb.order.cancelled.comment = comment
            m_pb.order.cancelled.mtime.GetCurrentTime()
            m_pb.order.cancelled.forced = forced


class ModelCacheMeta(abc.ABCMeta):
    """
    Metaclass for ModelCache classes.

    Registers every created class in _CACHES (rejecting duplicate class names). Concrete caches (not declared
    directly on ``object``) must declare ``desc`` and ``_cache``. Created classes cannot be instantiated.
    """

    def __new__(mcs, name, bases, attrs):
        if name in _CACHES:
            raise RuntimeError(u'Cache "{}" already exists: {!r}'.format(name, _CACHES[name]))

        is_concrete = object not in bases
        if is_concrete:
            for required_attr in (u'desc', u'_cache'):
                assert attrs[required_attr] is not None

        # instances are never created or modified
        attrs.update({u'__slots__': (), u'__init__': raise_on_init})

        _CACHES[name] = klass = abc.ABCMeta.__new__(mcs, name, bases, attrs)
        return klass


class ModelCache(six.with_metaclass(ModelCacheMeta, object)):
    """
    In-memory protobuf storage for one model kind, keyed by zk path.

    Concrete subclasses set ``desc`` and ``_cache`` and implement the public add/discard/get/must_get/list
    classmethods on top of the ``_``-prefixed helpers below.
    """
    # required
    desc = None  # type: Descriptor
    _cache = None  # type: dict

    legacy_update_event = None  # type: collections.namedtuple
    legacy_remove_event = None  # type: collections.namedtuple

    # Abstract hooks are classmethods, as in ModelObject / ModelZkClient: cache classes are never instantiated.
    @classmethod
    @abc.abstractmethod
    def add(cls, *_, **__):
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def discard(cls, *_, **__):
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def get(cls, *_, **__):
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def must_get(cls, *_, **__):
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def list(cls, *_, **__):
        raise NotImplementedError

    @classmethod
    @final
    def _add(cls, zk_path, pb):
        """
        Store ``pb`` under ``zk_path`` unless the cached pb there has the same or a newer generation.

        :return: True if ``pb`` was stored, False if it was ignored as not newer
        """
        existing_pb = cls._cache.get(zk_path)
        # NOTE(review): only generations are compared here; confirm whether ctime should also be considered,
        # as in the _must_be_ignored cache method.
        if existing_pb is not None and \
                cls.desc.codec.get_generation(existing_pb) >= cls.desc.codec.get_generation(pb):
            return False
        cls._cache[zk_path] = pb
        return True

    @classmethod
    @final
    def _discard(cls, zk_path):
        """:return: the removed pb, or None if nothing was cached under ``zk_path``"""
        return cls._cache.pop(zk_path, None)

    @classmethod
    @final
    def _get(cls, zk_path):
        """:return: the cached pb, or None if nothing is cached under ``zk_path``"""
        return cls._cache.get(zk_path)

    @classmethod
    @final
    def _must_get(cls, zk_path):
        """
        Like ``_get``, but a missing object is an error.

        :raises errors.NotFoundError: if nothing is cached under ``zk_path``
        """
        pb = cls._get(zk_path)
        if pb is None:
            raise errors.NotFoundError(u'{} "{}" does not exist'.format(cls.desc.readable_name,
                                                                        cls.desc.zk_path_to_str_uid(zk_path)))
        return pb


class ModelMongoClientMeta(abc.ABCMeta):
    """
    Metaclass for ModelMongoClient classes.

    Registers every created class in _MONGO_CLIENTS (rejecting duplicate class names). Concrete clients (not
    declared directly on ``object``) must declare a ``desc`` that has ``mongo_revs_column`` and ``rev_proto_class``.
    Created classes cannot be instantiated.
    """

    def __new__(mcs, name, bases, attrs):
        if name in _MONGO_CLIENTS:
            raise RuntimeError(u'MongoClient "{}" already exists: {!r}'.format(name, _MONGO_CLIENTS[name]))

        is_concrete = object not in bases
        if is_concrete:
            assert attrs[u'desc'] is not None
            descriptor = attrs[u'desc']  # type: Descriptor
            assert descriptor.mongo_revs_column is not None
            assert descriptor.rev_proto_class is not None

        # instances are never created or modified
        attrs.update({u'__slots__': (), u'__init__': raise_on_init})

        _MONGO_CLIENTS[name] = klass = abc.ABCMeta.__new__(mcs, name, bases, attrs)
        return klass


class ModelMongoClient(six.with_metaclass(ModelMongoClientMeta, object)):
    """
    Lower-level mongo accessor for a model's revisions.

    Each revision is one document in the ``desc.mongo_revs_column`` collection. The document holds the metadata
    fields used for querying/sorting (namespace, ``<desc.canonical_id>``, ctime) plus the serialized revision
    protobuf in ``content``.
    """
    # required
    desc = None  # type: Descriptor

    # page size used by list_revs when no limit is given
    DEFAULT_LIMIT = 30

    @classmethod
    def _revs(cls):
        """:return: the mongo collection holding this model's revisions"""
        return getattr(mongo.get_db(), cls.desc.mongo_revs_column)

    @classmethod
    def ensure_index(cls, field_name):
        """Ensure an index on ``field_name`` exists, named "<field_name>_index"."""
        cls._revs().ensure_index(field_name, name=field_name + '_index')

    @classmethod
    def save_rev(cls, rev_pb):
        """
        Insert a revision document; ``rev_pb.meta.id`` becomes the document ``_id``.

        ``ctime`` is stored in milliseconds, and the whole protobuf is serialized into ``content``.
        """
        content = rev_pb.SerializeToString()
        cls._revs().insert({
            '_id': rev_pb.meta.id,
            'namespace': rev_pb.meta.namespace_id,
            cls.desc.canonical_id: getattr(rev_pb.meta, cls.desc.canonical_id),
            'ctime': rev_pb.meta.ctime.ToMilliseconds(),
            'author': rev_pb.meta.author,
            'comment': rev_pb.meta.comment,
            'content': bson.Binary(content),
        })

    @classmethod
    def get_rev(cls, rev_id):
        """
        :return: the parsed revision protobuf, or None if there is no revision with ``rev_id``
        :raises DecodeError: if the stored content cannot be parsed
        """
        data = cls._revs().find_one(rev_id, projection=('content',))
        if not data:
            return None
        rev_pb = cls.desc.rev_proto_class()
        rev_pb.MergeFromString(data['content'])
        return rev_pb

    @classmethod
    def must_get_rev(cls, rev_id):
        """
        Like ``get_rev``, but a missing revision is an error.

        :raises errors.NotFoundError: if there is no revision with ``rev_id``
        """
        rev_pb = cls.get_rev(rev_id)
        if rev_pb is None:
            raise errors.NotFoundError('{} revision "{}" does not exist'.format(cls.desc.readable_name, rev_id))
        return rev_pb

    @classmethod
    def remove_rev(cls, rev_id):
        """:return: True if a revision document with ``rev_id`` was removed"""
        r = cls._revs().remove({'_id': rev_id}, multi=False)
        return bool(r.get('n'))

    @classmethod
    def _remove_revs(cls, spec):
        """:return: number of removed revision documents matching ``spec`` (0 if the result has no count)"""
        r = cls._revs().remove(spec, multi=True)
        return r.get('n', 0)

    @classmethod
    def remove_revs_by_namespace_id(cls, namespace_id):
        """Remove all revisions in a namespace; returns the number of removed documents."""
        return cls._remove_revs({'namespace': namespace_id})

    @classmethod
    def remove_revs_by_full_id(cls, namespace_id, object_id):
        """Remove all revisions of one object; returns the number of removed documents."""
        return cls._remove_revs({'namespace': namespace_id, cls.desc.canonical_id: object_id})

    @classmethod
    def list_revs(cls, namespace_id, object_id, skip=None, limit=None):
        """
        Return one page of an object's revisions, newest (by ctime) first.

        Documents whose content fails to parse are skipped, but they are still counted in the total.

        :param skip: number of revisions to skip; None means 0
        :param limit: page size; None or 0 means DEFAULT_LIMIT
        :rtype: SliceResult
        """
        spec = {
            'namespace': namespace_id,
            cls.desc.canonical_id: object_id,
        }
        total = cls._revs().find(spec).count()
        if total == 0:
            return SliceResult([], 0)
        cur = cls._revs().find(spec, skip=skip or 0, limit=limit or cls.DEFAULT_LIMIT, sort=[('ctime', -1)])
        rev_pbs = []
        for item in cur:
            rev_pb = cls.desc.rev_proto_class()
            try:
                rev_pb.MergeFromString(item['content'])
            except DecodeError:
                continue
            rev_pbs.append(rev_pb)
        return SliceResult(rev_pbs, total)

    @classmethod
    def raw_list_revs(cls, spec=None, fields=None, skip=0, limit=0, sort=None):
        """Run a raw ``find`` on the revisions collection and return the cursor as is."""
        return cls._revs().find(spec, projection=fields, skip=skip, limit=limit, sort=sort)
