from __future__ import unicode_literals

import gevent.event
import kazoo.exceptions
import logging
import six

from infra.swatlib.storage import consts
from infra.swatlib.storage import errors
from infra.swatlib.storage import interfaces
from infra.swatlib.storage.zk import treecache


class ZkWatchableCachingStorage(interfaces.IWatchableCachingStorage):
    """
    Zookeeper implementation of watchable storage.

    Keys are stored as znodes under ``path``. Reads can be served either
    directly from ZK after a ``sync`` (``consts.CONSISTENCY_STRONG``) or from a
    local ``treecache.TreeCache`` mirror of the subtree
    (``consts.CONSISTENCY_WEAK``).
    """

    def __init__(self, client, path, codec, structure):
        """
        :type client: infra.swatlib.zk.client.ZookeeperClient
        :type path: unicode
        :type codec: infra.swatlib.storage.zk.treecache.Codec
        :type structure: infra.swatlib.storage.zk.treecache.Structure
        """
        self._client = client
        self._path = path.rstrip('/')  # ensure_path with trailing slash throws NoNodeError
        self._codec = codec
        self._cache = treecache.TreeCache(client.client, self._path, structure=structure)
        self._log = logging.getLogger(path)

    def _get_from_cache(self, path):
        """
        Return the data held by the local tree cache for ``path``.

        :type path: unicode
        :returns: the data element of the cached node, or None if the cache
            has no entry for ``path``
        """
        node = self._cache.get_data(path)
        if node is None:
            return None
        _, data, _ = node
        return data

    def _get_from_storage(self, path):
        """
        Read and decode ``path`` directly from ZK.

        ``sync`` is issued first so that the read is not answered from a
        server that lags behind the leader.

        :type path: unicode
        :returns: the decoded object, or None if the node does not exist
        """
        self._client.client.sync(path)
        try:
            data, stat = self._client.read_file(path)
        except kazoo.exceptions.NoNodeError:
            return None
        return self._codec.decode(data, stat)

    def _get_children_synced(self, path):
        """
        List children of ``path`` directly from ZK after a ``sync``.

        Strong listings get the same freshness guarantee that
        ``_get_from_storage`` gives strong reads.

        :type path: unicode
        """
        self._client.client.sync(path)
        return self._client.get_children(path)

    def _list(self, list_keys_func, get_func):
        """
        Fetch every child of the storage root, skipping ones that vanish.

        :type list_keys_func: collections.Callable
        :type get_func: collections.Callable
        """
        rv = []
        for child_path in list_keys_func(self._path):
            p = self.prefix_path(child_path)
            v = get_func(p)
            if v is None:
                # Node may be removed after getting children list
                continue
            rv.append(v)
        return rv

    def prefix_path(self, key):
        """
        Build the absolute znode path for ``key`` under the storage root.

        :type key: unicode
        :rtype: unicode
        """
        return '{}/{}'.format(self._path, key.lstrip('/'))

    def start(self):
        """
        Create the storage root if needed and start the tree cache.

        Blocks (without a timeout) until the cache reports its INITIALIZED
        event, so weak reads right after ``start`` see a populated cache.
        """
        self._client.ensure_path(self._path)
        is_initialized = gevent.event.Event()

        def listener(event):
            if event.event_type == event.INITIALIZED:
                is_initialized.set()

        # Register before starting the cache so the INITIALIZED event cannot
        # be emitted before anyone listens for it.
        self.listen(listener)
        self._cache.start()
        self._log.debug('Waiting for cache initialization...')
        is_initialized.wait()
        self._log.debug('Cache initialized')

    def stop(self):
        """Close the underlying tree cache."""
        self._cache.close()

    def watch_list(self):
        """
        Generate ``(added, removed)`` sets of child names on every change.

        The first yielded pair reports every existing child as added, since
        the tracked set starts empty. The children watcher is cancelled when
        the generator is closed.
        """
        cur = set()  # Current list of objects
        w = self._client.children_watcher(self._path)
        try:
            while True:
                new = set(w.wait())
                # Calculate diff
                added = new - cur
                removed = cur - new
                cur = new
                yield added, removed
        finally:
            w.cancel()

    def watch(self, key):
        """
        Generate the decoded object stored at ``key`` on every data change.

        The data watcher is cancelled when the generator is closed.

        :type key: unicode
        :raises errors.NodeNotFoundError: if the node does not exist
        """
        path = self.prefix_path(key)
        w = self._client.data_watcher(path)
        try:
            while True:
                try:
                    data, stat = w.wait()
                except kazoo.exceptions.NoNodeError:
                    raise errors.NodeNotFoundError(key)
                yield self._codec.decode(data, stat)
        finally:
            w.cancel()

    def create(self, key, obj):
        """
        Encode ``obj`` and store it as a new node for ``key``.

        :type key: unicode
        :raises errors.NodeAlreadyExistsError: if the node already exists
            (note: the full znode path, not the key, is passed to the error)
        """
        path = self.prefix_path(key)
        data = self._codec.encode(obj)
        try:
            self._client.create_file(path, value=data)
        except kazoo.exceptions.NodeExistsError:
            raise errors.NodeAlreadyExistsError(path)

    def guaranteed_update(self, key, update_func):
        """
        Apply ``update_func`` to the stored object, retrying on conflicts.

        Optimistic read-modify-write: the node is read and decoded, then
        passed to ``update_func``. The function must modify the object in
        place because its return value is ignored. The object is then written
        back conditionally on the version that was read. On a version
        conflict the whole cycle is repeated.

        :type key: unicode
        :type update_func: collections.Callable
        :returns: the updated object (its generation is refreshed by
            ``update_conditionally``)
        :raises errors.NodeNotFoundError: if the node does not exist
        """
        path = self.prefix_path(key)
        while True:
            try:
                data, stat = self._client.read_file(path)
            except kazoo.exceptions.NoNodeError:
                # Pass the key, consistent with every other NodeNotFoundError here.
                raise errors.NodeNotFoundError(key)
            obj = self._codec.decode(data, stat)
            update_func(obj)
            try:
                self.update_conditionally(key, obj, stat.version)
            except errors.ConcurrentModificationError:
                # Another writer won the race; re-read and try again.
                continue
            return obj

    def update_conditionally(self, key, obj, version):
        """
        Overwrite ``key`` with ``obj`` only if its ZK version equals ``version``.

        On success the codec's generation on ``obj`` is set to the new version.

        :type key: unicode
        :param version: expected node version (anything ``int()`` accepts)
        :returns: the new node version as text
        :raises errors.ConcurrentModificationError: on a version mismatch
        :raises errors.NodeNotFoundError: if the node does not exist
        """
        path = self.prefix_path(key)
        data = self._codec.encode(obj)
        try:
            stat = self._client.client.set(path, value=data, version=int(version))
        except kazoo.exceptions.BadVersionError:
            raise errors.ConcurrentModificationError('Conflict when modifying data for key {} in ZK'.format(key))
        except kazoo.exceptions.NoNodeError:
            raise errors.NodeNotFoundError(key)
        self._codec.set_generation(obj, stat.version)
        return six.text_type(stat.version)

    def remove(self, key):
        """
        Delete the node for ``key``; a missing node is silently ignored.

        :type key: unicode
        """
        path = self.prefix_path(key)
        try:
            self._client.delete_file(path)
        except kazoo.exceptions.NoNodeError:
            # Deletion is idempotent by design.
            pass

    def get(self, key, consistency=consts.CONSISTENCY_STRONG, default=None):
        """
        Retrieves data either from cache, or from synced ZK

        :type key: unicode
        :type consistency: unicode
        :type default: optional[Any]
        :returns: the stored object, or ``default`` if there is none
        :raises ValueError: on an unknown consistency level
        """
        path = self.prefix_path(key)
        if consistency == consts.CONSISTENCY_STRONG:
            r = self._get_from_storage(path)
        elif consistency == consts.CONSISTENCY_WEAK:
            r = self._get_from_cache(path)
        else:
            raise ValueError('Unknown consistency level: {}'.format(consistency))
        if r is None:
            return default
        return r

    def list_keys(self, consistency=consts.CONSISTENCY_STRONG):
        """
        List child names of the storage root.

        :type consistency: unicode
        :raises ValueError: on an unknown consistency level
        """
        if consistency == consts.CONSISTENCY_STRONG:
            return self._get_children_synced(self._path)
        elif consistency == consts.CONSISTENCY_WEAK:
            return self._cache.get_children(self._path)
        raise ValueError('Unknown consistency level: {}'.format(consistency))

    def list(self, consistency=consts.CONSISTENCY_STRONG):
        """
        List all stored objects under the storage root.

        Children that disappear between listing and reading are skipped.

        :type consistency: unicode
        :raises ValueError: on an unknown consistency level
        """
        if consistency == consts.CONSISTENCY_STRONG:
            return self._list(self._get_children_synced, self._get_from_storage)
        elif consistency == consts.CONSISTENCY_WEAK:
            return self._list(self._cache.get_children, self._get_from_cache)
        raise ValueError('Unknown consistency level: {}'.format(consistency))

    def listen(self, listener):
        """Register ``listener`` for events emitted by the tree cache."""
        self._cache.listen(listener)
