import abc
import inject
import kazoo.exceptions
import os
import six
from google.protobuf.message import Message
from six.moves import zip
from typing import TypeVar, Generic

from awacs.lib import zookeeper_client, strutils
import logging

# Module-level logger shared by the ZooKeeper helpers defined in this file.
log = logging.getLogger(u'zk-transaction-client')


# Type variable for the protobuf message class a codec works with
# (any subclass of google.protobuf.message.Message).
PB_CLS = TypeVar('PB_CLS', bound=Message)


class Codec(six.with_metaclass(abc.ABCMeta, Generic[PB_CLS])):
    """
    Abstract, non-instantiable helper that converts protobuf messages to and
    from bytes and tracks their "generation".

    Subclasses set ``message_cls`` and implement ``set_generation`` and
    ``get_generation``; everything is used through the class itself.
    """
    # Protobuf message class instantiated by decode(); subclasses override it.
    message_cls = None  # type: PB_CLS

    def __init__(self):
        # Codecs hold no state, so creating an instance is always a mistake.
        raise RuntimeError(u"Don't instantiate me, use classmethods")

    @staticmethod
    @abc.abstractmethod
    def set_generation(obj, generation):
        """
        Stores the given generation on the object.

        :type obj: PB_CLS
        :type generation: six.text_type
        """
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def get_generation(obj):
        """
        Reads the generation previously stored on the object.

        :type obj: PB_CLS
        :rtype: six.text_type
        """
        raise NotImplementedError

    @staticmethod
    def encode(obj):
        """
        Serializes the object to its protobuf wire representation.

        :type obj: PB_CLS
        :rtype: six.binary_type
        """
        serialized = obj.SerializeToString()
        return serialized

    @classmethod
    def decode(cls, buf, stat=None):
        """
        Builds a fresh ``message_cls`` instance from serialized bytes.

        When ``stat`` is given, its ``version`` is recorded on the result
        through ``set_generation``.

        :type buf: six.binary_type
        :type stat: kazoo.protocol.states.ZnodeStat
        :rtype: PB_CLS
        """
        message = cls.message_cls()  # type: PB_CLS  # noqa
        message.MergeFromString(buf)
        if stat is None:
            return message
        cls.set_generation(message, stat.version)
        return message


class NodeAlreadyExistsError(Exception):
    """Raised when a key that is being created is already present in storage."""


class KazooTransactionException(kazoo.exceptions.KazooException):
    """
    Error describing a transaction whose commit did not fully succeed.

    Carries the failures that were detected as an immutable tuple.
    """

    def __init__(self, message, failures):
        """
        :param message: human-readable description of the problem
        :param failures: iterable of detected failures; stored as a tuple
        """
        collected = tuple(failures)
        super(KazooTransactionException, self).__init__(message)
        self._failures = collected

    @property
    def failures(self):
        """
        Failures passed at construction time.

        :rtype: tuple
        """
        return self._failures


def path_with_prefix(prefix, key):
    """
    Joins prefix and key with exactly one slash between them.

    Trailing slashes of the prefix and leading slashes of the key are dropped
    before joining. A falsy prefix (empty or None) leaves the key untouched.

    :param prefix: path prefix, may be empty or None
    :param key: key to place under the prefix
    :return: the combined path, or key itself when there is no prefix
    """
    if prefix:
        return u'%s/%s' % (prefix.rstrip(u'/'), key.lstrip(u'/'))
    return key


class ZkStorageClient(object):
    """
    Object storage on top of a ZooKeeper client: every key is resolved to a
    path under a fixed prefix and values are converted by the supplied codec.

    :type _client: awacs.lib.zookeeper_client.ZookeeperClient
    :type _prefix: six.text_type
    :type _codec: Codec
    """
    __slots__ = (u'_client', u'_prefix', u'_codec')

    def __init__(self, client, prefix, codec):
        """
        :param client: ZooKeeper client wrapper used for all operations
        :param prefix: path prefix prepended to every key
        :param codec: codec used to encode/decode objects; must be truthy
        """
        assert codec
        self._client = client
        self._prefix = prefix
        self._codec = codec

    def _prefix_path(self, path):
        """
        Resolves a key into a full path under this storage's prefix.
        """
        return path_with_prefix(self._prefix, path)

    def sync(self, key=u''):
        """
        Issues a sync on the underlying client for the path of the given key.
        """
        self._client.client.sync(self._prefix_path(key))

    def guaranteed_update(self, key, obj=None):
        """
        Optimistic read-modify-write helper, meant to drive a "for" loop:

        for obj in storage.guaranteed_update(key):
            obj.value = "new_value"

        Each iteration:
        1. Loads the object for the key (unless one was passed in as "obj").
        2. Hands it to the caller for modification.
        3. Saves it conditionally on the generation it was loaded with.
           Use "break" inside the calling loop to skip the save.
        4. On a version conflict, reloads the object and starts over.

        If the key does not exist, None is yielded once and nothing is saved.

        :param key: object key
        :param obj: object value to start from instead of fetching it
        """
        current = obj
        while True:
            if current is None:
                current = self.get(key)
                if current is None:
                    # Nothing stored under the key: report it and stop.
                    yield None
                    return
            # Remember the generation before the caller mutates the object.
            expected_generation = self._codec.get_generation(current)
            yield current
            try:
                self.put(key, current, generation=expected_generation)
            except kazoo.exceptions.BadVersionError:
                # Somebody else updated the node first: refetch and retry.
                current = None
                continue
            return

    def list_keys(self, key=''):
        """
        Lists the children of the path the key resolves to.

        :rtype: list[six.text_type]
        """
        return self._client.get_children(self._prefix_path(key))

    def exists(self, key):
        """
        Tells whether the path of the given key exists.

        :rtype: bool
        """
        found = self._client.client.exists(self._prefix_path(key))
        return bool(found)

    def get(self, key):
        """
        Reads and decodes the object stored under the key.

        Returns None when the node does not exist.

        :type key: six.string_types
        """
        node_path = self._prefix_path(key)
        try:
            payload, stat = self._client.read_file(node_path)
        except kazoo.exceptions.NoNodeError:
            return None
        else:
            return self._codec.decode(payload, stat=stat)

    def create(self, key, obj):
        """
        Stores the object under a key that must not exist yet.

        :raises NodeAlreadyExistsError: if the key is already present
        """
        node_path = self._prefix_path(key)
        payload = self._codec.encode(obj)
        try:
            self._client.create_file(node_path, value=payload)
        except kazoo.exceptions.NodeExistsError:
            raise NodeAlreadyExistsError(key)

    def put(self, key, obj, generation=None):
        """
        Writes the object under the key and records the resulting version on it.

        Without a generation the write is unconditional; with one, the write
        is issued for that specific version.
        """
        payload = self._codec.encode(obj)
        node_path = self._prefix_path(key)
        if generation is not None:
            # Conditional update against the expected version.
            stat = self._client.client.set(node_path, payload, version=generation)
        else:
            # Unconditional update
            stat = self._client.write_file(node_path, payload)
        self._codec.set_generation(obj, stat.version)

    def remove(self, key, recursive=False):
        """
        Deletes the key; a key that is already absent is silently ignored.

        :param recursive: passed through to the client's delete call
        """
        node_path = self._prefix_path(key)
        try:
            self._client.delete_file(node_path, recursive=recursive)
        except kazoo.exceptions.NoNodeError:
            # Already gone, nothing to do.
            pass


class ZkTransactionClient(object):
    """
    Batch create/remove helpers built on the ZooKeeper client's transactions.

    Everything is exposed as classmethods; the client itself is obtained
    through ``inject``.
    """
    __slots__ = ()
    _client = inject.attr(zookeeper_client.IZookeeperClient)

    @classmethod
    def batch_create(cls, paths, value):
        """
        Creates all given paths with the same value in a single transaction.

        Parent paths are ensured before the transaction is committed.

        :param paths: iterable of full node paths to create
        :param value: node content; text is encoded as UTF-8
        :return: results returned by the transaction commit
        :raises KazooTransactionException: if any operation failed or the
            number of results does not match the number of operations
        """
        # The value is the same for every path, so convert it only once.
        if isinstance(value, six.text_type):
            value = value.encode('utf-8')
        txn = cls._client.client.transaction()
        for path in paths:
            cls._ensure_parent_path(path)
            txn.create(path, value=value)
        results = txn.commit()
        failures = cls._get_txn_failures(txn, results)
        if failures:
            raise KazooTransactionException(
                u"Transaction with %s operations failed: %s"
                % (len(txn.operations), failures), failures)
        return results

    @classmethod
    def batch_remove(cls, paths):
        """
        Deletes all given paths in a single transaction.

        Failures caused by a missing node are skipped; the first other
        failure is re-raised.

        :param paths: iterable of full node paths to delete
        :raises KazooTransactionException: if the number of results does not
            match the number of operations
        """
        txn = cls._client.client.transaction()
        for path in paths:
            txn.delete(path)
        results = txn.commit()
        # NOTE(review): ZooKeeper transactions are presumably atomic, so one
        # missing node may fail the remaining deletes with rollback-type
        # errors that are raised below -- confirm this is the intended outcome.
        for op, exc in cls._get_txn_failures(txn, results):
            if isinstance(exc, kazoo.exceptions.NoNodeError):
                continue
            elif isinstance(exc, Exception):
                raise exc
            else:
                log.error(exc)

    @classmethod
    def _ensure_parent_path(cls, path):
        """
        Makes sure the parent of the given path exists.
        """
        parent, _ = os.path.split(strutils.removesuffix(path, u'/'))
        cls._client.client.ensure_path(parent)

    @classmethod
    def _get_txn_failures(cls, txn, results):
        """
        Pairs transaction operations with their results and picks the failures.

        :param txn: the committed transaction (its ``operations`` are used)
        :param results: results returned by the commit
        :return: list of (operation, exception) tuples for failed operations
        :raises KazooTransactionException: if the number of results differs
            from the number of operations
        """
        failures = [
            (op, result)
            for op, result in zip(txn.operations, results)
            if isinstance(result, kazoo.exceptions.KazooException)
        ]
        # A length mismatch means results cannot be attributed reliably;
        # whatever failures were matched so far are attached to the error.
        if len(results) < len(txn.operations):
            raise KazooTransactionException(
                u"Transaction returned %s results, this is less than the number of expected operations %s"
                % (len(results), len(txn.operations)), failures)
        if len(results) > len(txn.operations):
            raise KazooTransactionException(
                u"Transaction returned %s results, this is greater than the number of expected operations %s"
                % (len(results), len(txn.operations)), failures)
        return failures
