# -*- coding: utf-8 -*-
from collections import defaultdict
from functools import wraps

from passport.backend.core.tracks.serializer import Serializer
import six


try:
    from contextlib import GeneratorContextManager
except ImportError:
    # Python 3 no longer exposes the public name; the implementation lives
    # under the private ``_GeneratorContextManager``.
    from contextlib import _GeneratorContextManager as GeneratorContextManager


# Number of currently open transactions, keyed by track_id.  An entry is
# deleted again once its count drops back to zero.
_TRANSACTION_LEVEL = defaultdict(int)


class NotNestedGeneratorContextManager(GeneratorContextManager):
    """Generator-based context manager that counts how deeply it is nested.

    Each successful ``__enter__`` increments the counter stored in the
    module-level ``_TRANSACTION_LEVEL`` map under ``key``; the matching
    ``__exit__`` decrements it and removes the entry once it reaches zero.
    Entering while the counter for ``key`` is already non-zero raises
    ``RuntimeError`` unless ``allow_nested`` is truthy.
    """

    def __init__(self, generator, args, kwargs, key, allow_nested):
        self.key = key
        self.allow_nested = allow_nested
        base = super(NotNestedGeneratorContextManager, self)
        # The two base-class flavours have different constructors: the
        # Python 2 one wants an already-created generator object, the
        # Python 3 one wants the generator function and its arguments.
        if not six.PY2:
            base.__init__(generator, args, kwargs)
        else:
            base.__init__(generator(*args, **kwargs))

    def _get_level(self):
        # ``.get`` (not indexing) so an unseen key is not inserted into the
        # defaultdict; returns None for keys with no open transaction.
        return _TRANSACTION_LEVEL.get(self.key)

    def _increase_level(self):
        _TRANSACTION_LEVEL[self.key] += 1

    def _decrease_level(self):
        remaining = _TRANSACTION_LEVEL[self.key] - 1
        if remaining:
            _TRANSACTION_LEVEL[self.key] = remaining
        else:
            # Drop exhausted keys so the shared map does not keep growing.
            del _TRANSACTION_LEVEL[self.key]

    def __enter__(self):
        if not self.allow_nested and self._get_level():
            raise RuntimeError('Nested transactions are prohibited')
        try:
            self._increase_level()
            entered = super(NotNestedGeneratorContextManager, self).__enter__()
        except Exception:
            # The with-statement will not call __exit__ when __enter__
            # fails, so undo the bump here.
            self._decrease_level()
            raise
        return entered

    def __exit__(self, type, value, traceback):
        base = super(NotNestedGeneratorContextManager, self)
        try:
            return base.__exit__(type, value, traceback)
        finally:
            self._decrease_level()


def not_nested_contextmanager(func):
    """Turn a generator method into a factory of nesting-aware context managers.

    Behaves like ``contextlib.contextmanager`` but builds a
    ``NotNestedGeneratorContextManager``.  The first positional argument of
    each call (the method's ``self``) provides the nesting key, taken from
    its ``track.track_id``, and the ``allow_nested`` flag.
    """
    def wrapper(*args, **kwargs):
        owner = args[0]
        nesting_key = owner.track.track_id
        nesting_allowed = owner.allow_nested
        return NotNestedGeneratorContextManager(
            generator=func,
            args=args,
            kwargs=kwargs,
            key=nesting_key,
            allow_nested=nesting_allowed,
        )

    return wraps(func)(wrapper)


class TrackTransaction(object):
    """Scope a set of changes to a track and persist them with ``Serializer``.

    Each context-manager method yields the track, lets the caller modify it,
    and then hands the old and new track states to ``Serializer().execute``
    together with the track's redis node.  The methods are wrapped with
    ``not_nested_contextmanager``, so opening one while another transaction on
    the same ``track_id`` is open raises ``RuntimeError`` unless
    ``allow_nested`` is true.

    :param track: the track object being modified.
    :param manager: object whose ``_get_redis_node(track_id)`` picks the
        redis node used for this track.
    :param allow_nested: permit opening a transaction while another one on
        the same track is already open.
    """

    def __init__(self, track, manager, allow_nested=False):
        self.track = track
        self.manager = manager
        self.redis_node = self.manager._get_redis_node(self.track.track_id)
        self.allow_nested = allow_nested

    @property
    def is_in_nested_transaction(self):
        """True when more than one transaction is currently open on this track.

        Reads the counter with ``.get``: indexing the shared ``defaultdict``
        would insert a zero entry for tracks with no open transaction, and
        that entry would stay until a later transaction on the same track
        closes.
        """
        return _TRANSACTION_LEVEL.get(self.track.track_id, 0) > 1

    @not_nested_contextmanager
    def rollback_on_error(self):
        """Save the track on success; restore its snapshot on any error.

        If the ``with`` body raises, or saving itself fails, the track's
        ``_data`` and ``track_version`` are reset to the snapshot taken on
        entry and the exception is re-raised.
        """
        snapshot = self.track.snapshot()
        try:
            yield self.track
            Serializer().execute(
                old_track=snapshot,
                new_track=self.track,
                redis_node=self.redis_node,
                # Incremental updates are disabled inside nested transactions.
                allow_incremental_updates=not self.is_in_nested_transaction,
            )
        except BaseException:
            # Deliberately broad: roll back on anything (including
            # GeneratorExit / KeyboardInterrupt) and always re-raise.
            self.track._data = snapshot._data
            self.track.track_version = snapshot.track_version
            raise

    @not_nested_contextmanager
    def commit_on_error(self):
        """Save the track on exit, whether or not the ``with`` body raised.

        The save runs in a ``finally`` clause, so changes made before an
        error are written too; the body's exception (if any) then propagates
        unless the save itself raises.
        """
        snapshot = self.track.snapshot()
        try:
            yield self.track
        finally:
            Serializer().execute(
                old_track=snapshot,
                new_track=self.track,
                redis_node=self.redis_node,
                # Incremental updates are disabled inside nested transactions.
                allow_incremental_updates=not self.is_in_nested_transaction,
            )

    @not_nested_contextmanager
    def delete(self):
        """Yield the track, then delete it once the ``with`` body succeeds.

        The deletion is expressed as a transition to ``new_track=None``. If
        the body raises, the exception propagates and nothing is deleted.
        """
        yield self.track
        Serializer().execute(
            old_track=self.track,
            new_track=None,
            redis_node=self.redis_node,
            allow_incremental_updates=True,  # anything may be deleted
        )
