from collections import defaultdict, namedtuple

import six
import enum
from typing import Generator, List
from datetime import datetime
from six.moves import map

from awacs.lib.vectors.version import WeightSectionVersion
from awacs.lib.models.classes import (
    Descriptor,
    ModelObject,
    ModelZkClient,
    ModelCache,
    ModelMongoClient,
    ZkNodeType,
    IncludeInDaemons,
)
from awacs.model import codecs, util, errors
from infra.awacs.proto import model_pb2


class WeightSectionDescriptor(Descriptor):
    """
    Storage descriptor for WeightSection objects.

    The ZK client, cache, Mongo client and model classes below all point their
    ``desc`` at this class and use it (e.g. ``uid_to_zk_path``) to locate
    weight sections.
    """
    # ZK subtree under which weight sections are stored.
    zk_prefix = u'weight_sections'
    # Nested (namespace_id, ws_id) layout; presumably <prefix>/<namespace_id>/<ws_id> -- TODO confirm in ZkNodeType.
    zk_node_type = ZkNodeType.DOUBLE_NESTED

    # Mongo storage name for revision documents and the proto class of a single revision.
    mongo_revs_column = 'weight_section_revisions'
    rev_proto_class = model_pb2.WeightSectionRevision

    # Proto class of the stored object and the codec class the model layer uses for it.
    proto_class = model_pb2.WeightSection
    codec_class = codecs.GenerationCodec

    # NOTE(review): exact meaning of IncludeInDaemons.ALL is defined in awacs.lib.models.classes.
    include_in_daemons = IncludeInDaemons.ALL


class WeightSectionZkClient(ModelZkClient):
    """
    ZooKeeper access for WeightSection objects.

    Every method maps ``(namespace_id, ws_id)`` to a ZK path with
    ``WeightSectionDescriptor.uid_to_zk_path`` and delegates to the generic
    ``ModelZkClient`` helpers.
    """
    desc = WeightSectionDescriptor

    @classmethod
    def create(cls, pb):
        """
        Create the ZK node for a new weight section.

        :type pb: model_pb2.WeightSection
        """
        return cls._create(cls.desc.uid_to_zk_path(pb.meta.namespace_id, pb.meta.id), pb)

    @classmethod
    def get(cls, namespace_id, ws_id, sync=False):
        """
        Read a weight section from ZK.

        :type namespace_id: six.text_type
        :type ws_id: six.text_type
        :type sync: bool
        :rtype: model_pb2.WeightSection | None
        """
        return cls._get(cls.desc.uid_to_zk_path(namespace_id, ws_id), sync=sync)

    @classmethod
    def must_get(cls, namespace_id, ws_id, sync=False):
        """
        Read a weight section from ZK, failing if it does not exist.

        :type namespace_id: six.text_type
        :type ws_id: six.text_type
        :type sync: bool
        :rtype: model_pb2.WeightSection
        :raises: errors.NotFoundError
        """
        # Forward ``sync`` like ``get`` does; previously the flag was accepted but silently ignored.
        return cls._must_get(cls.desc.uid_to_zk_path(namespace_id, ws_id), sync=sync)

    @classmethod
    def update(cls, namespace_id, ws_id, pb=None):
        """
        Yield the stored weight section for in-place modification by the caller.

        Thin wrapper around ``ModelZkClient._update``; when *pb* is given it must
        describe the same (namespace_id, ws_id) object.

        :type namespace_id: six.text_type
        :type ws_id: six.text_type
        :type pb: model_pb2.WeightSection | None
        :rtype: Generator[model_pb2.WeightSection, None, None]
        :raises: errors.NotFoundError
        """
        if pb is not None:
            # Internal invariants: callers must pass a pb matching the addressed object.
            assert isinstance(pb, model_pb2.WeightSection)
            assert pb.meta.namespace_id == namespace_id
            assert pb.meta.id == ws_id
        zk_path = cls.desc.uid_to_zk_path(namespace_id, ws_id)
        # Distinct loop name so the ``pb`` argument is not shadowed.
        for ws_pb in cls._update(zk_path, pb):
            yield ws_pb

    @classmethod
    def remove(cls, namespace_id, ws_id):
        """
        Remove the ZK node of a single weight section.

        :type namespace_id: six.text_type
        :type ws_id: six.text_type
        """
        return cls._remove(cls.desc.uid_to_zk_path(namespace_id, ws_id))

    @classmethod
    def remove_all(cls, namespace_id):
        """
        Recursively remove all weight sections of a namespace.

        :type namespace_id: six.text_type
        """
        return cls._remove_recursive(namespace_id)

    @classmethod
    def cancel_order(cls, namespace_id, ws_id, author, comment=u''):
        """
        Delegate to ``ModelZkClient._cancel_order`` for the given weight section.

        :type namespace_id: six.text_type
        :type ws_id: six.text_type
        :type author: six.text_type
        :type comment: six.text_type
        """
        return cls._cancel_order(cls.desc.uid_to_zk_path(namespace_id, ws_id), author, comment)


class WeightSectionCache(ModelCache):
    """
    In-memory cache of WeightSection objects keyed by ZK path.

    Besides the base ``_cache`` (zk_path -> pb), the class keeps secondary
    indexes that ``list`` and ``count`` rely on:

    * ``_cache_by_namespace_id``: namespace_id -> set of zk paths;
    * ``_cache_mtimes``: zk_path -> ``meta.mtime`` as datetime (MTIME sorting);
    * ``_cache_deleted_paths``: zk paths whose ``spec.deleted`` is set.
    """
    desc = WeightSectionDescriptor

    legacy_update_event = namedtuple('WeightSectionUpdate', ['path', 'pb'])
    legacy_remove_event = namedtuple('WeightSectionRemove', ['path'])

    _cache = {}
    _cache_by_namespace_id = defaultdict(set)
    _cache_mtimes = {}
    _cache_deleted_paths = set()

    class QueryTarget(enum.Enum):
        ID_REGEXP = 1
        VALIDATED_STATUS_IN = 2
        IN_PROGRESS_STATUS_IN = 3
        ACTIVE_STATUS_IN = 4
        DELETED = 5

    class SortTarget(enum.Enum):
        ID = 1
        MTIME = 2

    @classmethod
    def add(cls, zk_path, pb):
        """
        Store *pb* in the cache and refresh the secondary indexes.

        The indexes are only touched when the base ``_add`` reports that the
        object was stored.

        :type zk_path: six.text_type
        :type pb: model_pb2.WeightSection
        """
        if not cls._add(zk_path, pb):
            return

        cls._cache_by_namespace_id[pb.meta.namespace_id].add(zk_path)
        # Must be keyed by zk_path: ``list`` sorts paths through this mapping
        # and ``discard`` pops entries by zk_path.
        cls._cache_mtimes[zk_path] = pb.meta.mtime.ToDatetime()

        if pb.spec.deleted:
            cls._cache_deleted_paths.add(zk_path)
        else:
            cls._cache_deleted_paths.discard(zk_path)

    @classmethod
    def discard(cls, zk_path):
        """
        Drop the object at *zk_path* from the cache and from every secondary index.

        :type zk_path: six.text_type
        """
        ws_pb = cls._discard(zk_path)
        if not ws_pb:
            return

        namespace_id = ws_pb.meta.namespace_id
        namespace_paths = cls._cache_by_namespace_id.get(namespace_id)
        if namespace_paths is not None:
            namespace_paths.discard(zk_path)
            if not namespace_paths:
                # Don't keep empty buckets around for namespaces with no sections left.
                del cls._cache_by_namespace_id[namespace_id]
        cls._cache_mtimes.pop(zk_path, None)
        cls._cache_deleted_paths.discard(zk_path)

    @classmethod
    def get(cls, namespace_id, ws_id):
        """
        :type namespace_id: six.text_type
        :type ws_id: six.text_type
        :rtype: model_pb2.WeightSection | None
        """
        return cls._cache.get(cls.desc.uid_to_zk_path(namespace_id, ws_id))

    @classmethod
    def must_get(cls, namespace_id, ws_id):
        """
        :type namespace_id: six.text_type
        :type ws_id: six.text_type
        :rtype: model_pb2.WeightSection
        :raises: errors.NotFoundError
        """
        return cls._must_get(cls.desc.uid_to_zk_path(namespace_id, ws_id))

    @classmethod
    def list(cls, namespace_id=None, query=None, sort=(SortTarget.ID, 1)):
        """
        Return cached weight sections, optionally filtered and sorted.

        :param namespace_id: restrict to one namespace; all namespaces if falsy
        :param query: mapping of ``QueryTarget`` -> filter value:
            ID_REGEXP -- object with ``.search`` (e.g. a compiled regex) applied to ``meta.id``;
            VALIDATED_STATUS_IN / IN_PROGRESS_STATUS_IN / ACTIVE_STATUS_IN -- allowed values
            of the corresponding ``status.<x>.status`` field;
            DELETED -- if truthy, keep only sections with ``spec.deleted`` set
            (a falsy value does NOT exclude deleted sections).
        :param sort: ``(SortTarget, direction)``; a negative last element sorts descending.
            ID sorts by zk path; MTIME sorts by ``meta.mtime`` with the zk path as tie-breaker.
        :type namespace_id: six.text_type | None
        :type query: dict | None
        :type sort: Iterable
        :rtype: List[model_pb2.WeightSection]
        """
        query = query or {}

        if namespace_id:
            paths = cls._cache_by_namespace_id.get(namespace_id, set())
        elif cls._cache_by_namespace_id:
            paths = set.union(*six.itervalues(cls._cache_by_namespace_id))
        else:
            paths = set()

        id_regexp = query.get(cls.QueryTarget.ID_REGEXP)
        if id_regexp:
            paths = {p for p in paths if id_regexp.search(cls._cache[p].meta.id)}

        validated_status_in = query.get(cls.QueryTarget.VALIDATED_STATUS_IN, [])
        if validated_status_in:
            paths = {p for p in paths if cls._cache[p].status.validated.status in validated_status_in}

        in_progress_status_in = query.get(cls.QueryTarget.IN_PROGRESS_STATUS_IN, [])
        if in_progress_status_in:
            paths = {p for p in paths if cls._cache[p].status.in_progress.status in in_progress_status_in}

        active_status_in = query.get(cls.QueryTarget.ACTIVE_STATUS_IN, [])
        if active_status_in:
            paths = {p for p in paths if cls._cache[p].status.active.status in active_status_in}

        deleted = query.get(cls.QueryTarget.DELETED, False)
        if deleted:
            paths = paths & cls._cache_deleted_paths

        if sort[0] == cls.SortTarget.MTIME:
            mtimes = cls._cache_mtimes

            def key(path):
                # datetime.min guards against a missing mtime (None is not orderable on py3);
                # the path makes ordering of equal mtimes deterministic.
                return mtimes.get(path, datetime.min), path
        else:
            key = None  # natural ordering of zk paths
        ordered_paths = sorted(paths, key=key, reverse=sort[-1] < 0)
        return [cls._cache.get(p) for p in ordered_paths]

    @classmethod
    def count(cls, namespace_id=None):
        """
        Number of cached weight sections, in one namespace or overall.

        :type namespace_id: six.text_type | None
        :rtype: int
        """
        if namespace_id:
            return len(cls._cache_by_namespace_id.get(namespace_id) or frozenset())
        return len(cls._cache)


class WeightSectionMongoClient(ModelMongoClient):
    """Mongo client for WeightSection revisions; all behavior comes from ``ModelMongoClient``."""
    desc = WeightSectionDescriptor


class WeightSection(ModelObject):
    """
    High-level WeightSection model.

    Binds the descriptor, ZK client, cache, Mongo revision client and version
    class together and implements create / remove / update on top of them.
    """
    desc = WeightSectionDescriptor
    zk = WeightSectionZkClient
    cache = WeightSectionCache
    mongo = WeightSectionMongoClient
    version = WeightSectionVersion

    @classmethod
    def create(cls, meta_pb, spec_pb, login):
        """
        Create a weight section via ``ModelObject._create_with_rev``.

        :type meta_pb: model_pb2.WeightSectionMeta
        :type spec_pb: model_pb2.WeightSectionSpec
        :type login: six.text_type
        :rtype: model_pb2.WeightSection
        """
        return cls._create_with_rev(meta_pb, login, spec_pb=spec_pb)

    @classmethod
    def remove(cls, namespace_id, ws_id):
        """
        Remove the weight section from ZK, then delete all its revisions from Mongo.

        :type namespace_id: six.text_type
        :type ws_id: six.text_type
        """
        cls.zk.remove(namespace_id, ws_id)
        cls.mongo.remove_revs_by_full_id(namespace_id, ws_id)

    @classmethod
    def update(cls, namespace_id, ws_id, version, comment, login, updated_spec_pb):
        """
        Replace the spec of an existing weight section using optimistic locking.

        A new revision is saved to Mongo first. The ZK object is then modified only
        if its current ``meta.version`` equals *version* and it is not marked as
        deleted; when either check fails the just-saved revision is removed again
        before ``ConflictError`` is raised.

        :type namespace_id: six.text_type
        :type ws_id: six.text_type
        :type version: six.text_type
        :type comment: six.text_type
        :type login: six.text_type
        :type updated_spec_pb: model_pb2.WeightSectionSpec
        :rtype: model_pb2.WeightSection
        :raises: errors.ConflictError, errors.NotFoundError (from the ZK client)
        """
        ws_pb = None
        new_version = cls._gen_version_id()
        utcnow = datetime.utcnow()
        rev_pb = cls._gen_new_rev(namespace_id, ws_id, new_version, comment, login, updated_spec_pb, utcnow=utcnow)
        cls.mongo.save_rev(rev_pb)

        for ws_pb in cls.zk.update(namespace_id, ws_id):
            if ws_pb.meta.version != version:
                cls.mongo.remove_rev(new_version)
                raise errors.ConflictError(
                    u'Weight Section modification conflict: assumed version="{}", current="{}"'.format(
                        version, ws_pb.meta.version))

            if ws_pb.spec.deleted:
                # The revision saved above will never be referenced; drop it as in the
                # version-conflict branch instead of leaving an orphan in Mongo.
                cls.mongo.remove_rev(new_version)
                raise errors.ConflictError('WeightSection with "spec.deleted" flag can not be updated')
            if not util.update_spec_in_pb(ws_pb, updated_spec_pb, new_version, comment, login, utcnow=utcnow):
                # NOTE(review): on this path the revision saved above is kept in Mongo -- confirm intended.
                break

        return ws_pb
