# -*- coding: utf-8 -*-
from copy import deepcopy
import time

from passport.backend.utils.string import smart_bytes
from redis import (
    RedisError,
    ResponseError,
)


class FakeConnectionPool(object):
    """Minimal connection-pool placeholder attached to ``FakeRedis``.

    It only carries the mapping passed to the constructor. Presumably this
    mirrors redis-py's ``ConnectionPool.connection_kwargs`` so code that reads
    host/port from the pool keeps working; verify against callers.
    """

    def __init__(self, connection_kwargs):
        # The mapping is kept by reference, not copied.
        setattr(self, 'connection_kwargs', connection_kwargs)


class FakeRedis(object):
    """In-memory stand-in for a subset of the redis-py client API, for tests.

    Storage layout:
      * ``self._redis`` maps a key (bytes via ``smart_bytes``) to its value:
        bytes/str for strings, dict for hashes, list for lists, set for sets.
      * ``self._key_expiration_map`` maps a key to an absolute expiry
        timestamp based on ``time.time()``. Expired keys are purged lazily,
        whenever a command touches them.

    Deviations from real Redis, kept so existing callers keep working:
      * ``incr``/``hincrby`` store counters as ``str`` rather than bytes.
      * ``hgetall`` returns ``False`` for a missing key.
      * ``ttl`` returns ``None`` instead of -1/-2.
      * ``delete``/``hdel`` return booleans rather than counts.

    ``password``, ``socket_timeout`` and ``ssl`` are accepted and ignored.
    Only ``host``/``port`` are kept, on ``connection_pool.connection_kwargs``.
    """

    def __init__(self, host, port, password=None, socket_timeout=None, ssl=None):
        self._redis = {}
        self._key_expiration_map = {}
        self._in_transaction = False
        self.connection_pool = FakeConnectionPool(dict(host=host, port=port))

    # ----- internal helpers -------------------------------------------------

    def _purge_if_expired(self, key):
        """Drop ``key`` and its TTL if its expiry time has passed.

        ``key`` must already be converted with ``smart_bytes``.
        Returns True if the key was dropped.
        """
        expiration_time = self._key_expiration_map.get(key)
        if expiration_time is None or time.time() < expiration_time:
            return False
        self._redis.pop(key, None)
        # Also forget the timestamp; otherwise a key recreated later under the
        # same name would inherit the stale, already-past expiry.
        del self._key_expiration_map[key]
        return True

    @staticmethod
    def _list_bounds(length, start, end):
        """Translate inclusive Redis list indices into Python slice bounds.

        Follows LRANGE/LTRIM rules:
          * negative indices count from the tail;
          * ``start`` below the head is clamped to 0;
          * ``end`` past the tail is clamped to the last element;
          * an empty range yields ``(0, 0)``.
        """
        if start < 0:
            start = max(length + start, 0)
        if end < 0:
            end += length
        end = min(end, length - 1)
        if start > end:
            return 0, 0
        return start, end + 1

    # ----- strings ----------------------------------------------------------

    def get(self, key):
        """Return the stored value of ``key``, or None if it is missing/expired."""
        key = smart_bytes(key)
        if not self.exists(key):
            return None
        return self._redis[key]

    def mget(self, keys):
        """Return a list with the value of each key, or None where missing/expired."""
        keys = [smart_bytes(key) for key in keys]
        return [self._redis.get(key) if self.exists(key) else None
                for key in keys]

    def set(self, key, value):
        """Store ``value`` (as bytes) under ``key`` and clear any TTL.

        Always returns True.
        """
        key = smart_bytes(key)
        self._redis[key] = smart_bytes(value)
        # A plain SET discards a previous expiry, as in Redis.
        self._key_expiration_map.pop(key, None)
        return True

    def setex(self, key, ttl, value):
        """Store ``value`` under ``key`` and expire it after ``ttl`` seconds.

        Returns True.
        """
        self.set(key, value)
        self.expire(key, ttl)
        return True

    def incr(self, key):
        """Increment the integer stored at ``key`` by 1; a missing key counts as 0.

        Returns the new value as int. The counter is stored as ``str``.
        Raises ResponseError if the stored value is not an integer.
        """
        key = smart_bytes(key)
        self._purge_if_expired(key)
        try:
            new_value = int(self._redis.get(key, '0')) + 1
        except ValueError:
            raise ResponseError('value is not an integer or out of range')
        self._redis[key] = str(new_value)
        return new_value

    # ----- keys -------------------------------------------------------------

    def exists(self, key):
        """Return True if ``key`` holds a live (non-expired, non-empty) value.

        Has a side effect: an expired key is purged from storage.
        """
        key = smart_bytes(key)
        if self._purge_if_expired(key):
            return False
        value = self._redis.get(key)
        if value is None:
            return False
        if isinstance(value, (dict, list, set)):
            # Redis never keeps empty hashes/lists/sets. A container emptied
            # by hdel/ltrim therefore counts as missing.
            return bool(value)
        # String values exist even when empty (b'').
        return True

    def expire(self, key, ttl):
        """Set ``key`` to expire ``ttl`` seconds from now.

        Returns False if the key does not exist, True otherwise.
        """
        key = smart_bytes(key)
        if not self.exists(key):
            return False
        self._key_expiration_map[key] = time.time() + ttl
        return True

    def delete(self, *keys):
        """Delete every given key together with its TTL.

        All keys are processed even if some are missing. Returns True only
        if every key existed.
        """
        all_existed = True
        for key in keys:
            key = smart_bytes(key)
            if not self.exists(key):
                all_existed = False
            self._redis.pop(key, None)
            self._key_expiration_map.pop(key, None)
        return all_existed

    def ttl(self, key):
        """Return the remaining lifetime of ``key`` in seconds, as a float.

        Returns None when the key is missing, has no expiry, or has expired.
        """
        key = smart_bytes(key)
        if self._purge_if_expired(key) or key not in self._redis:
            return None
        expiration_time = self._key_expiration_map.get(key)
        if expiration_time is None:
            return None
        remaining = expiration_time - time.time()
        if remaining < 0:
            return None
        return remaining

    # ----- hashes -----------------------------------------------------------

    def hget(self, key, field):
        """Return ``field`` of hash ``key``, or None if either is missing."""
        key = smart_bytes(key)
        field = smart_bytes(field)
        if not self.exists(key):
            return None
        return self._redis[key].get(field)

    def hmset(self, key, data):
        """Set several hash fields from the ``data`` mapping.

        Creates the hash if needed and returns True.
        """
        key = smart_bytes(key)
        self._purge_if_expired(key)
        fields = self._redis.setdefault(key, {})
        for field, value in data.items():
            fields[smart_bytes(field)] = smart_bytes(value)
        return True

    def hgetall(self, key):
        """Return a copy of the whole hash at ``key``, or False if it is missing."""
        key = smart_bytes(key)
        if not self.exists(key):
            return False
        # Return a copy so callers cannot mutate the stored hash.
        return dict(self._redis[key])

    def hset(self, key, field, value):
        """Set one hash field.

        Returns 1 if the field is new, 0 if it was overwritten.
        """
        key = smart_bytes(key)
        field = smart_bytes(field)
        self._purge_if_expired(key)
        fields = self._redis.setdefault(key, {})
        is_new = field not in fields
        fields[field] = smart_bytes(value)
        return 1 if is_new else 0

    def hdel(self, key, *fields):
        """Delete the given fields from hash ``key``.

        Returns 0 if the hash is missing. Otherwise every present field is
        removed, and the result is True only if all of them were present.
        """
        key = smart_bytes(key)
        self._purge_if_expired(key)
        if key not in self._redis:
            return 0
        stored = self._redis[key]
        all_present = True
        for field in fields:
            field = smart_bytes(field)
            if field in stored:
                del stored[field]
            else:
                all_present = False
        return all_present

    def hincrby(self, key, field, value):
        """Add the integer ``value`` to a hash field; a missing field counts as 0.

        Returns the new value as int. It is stored as ``str``.
        Raises ResponseError if the stored value or ``value`` is not an integer.
        In that case the field is left untouched.
        """
        key = smart_bytes(key)
        field = smart_bytes(field)
        self._purge_if_expired(key)
        fields = self._redis.setdefault(key, {})
        try:
            stored_value = int(fields.get(field, '0'))
        except ValueError:
            raise ResponseError('hash value is not an integer')
        try:
            new_value = stored_value + int(value)
        except ValueError:
            raise ResponseError('value is not an integer or out of range')
        fields[field] = str(new_value)
        return new_value

    # ----- lists ------------------------------------------------------------

    def rpush(self, key, *values):
        """Append ``values`` to the tail of list ``key``.

        Returns the new length. Raises ResponseError when no values are given.
        """
        if not values:
            raise ResponseError('wrong number of arguments for "rpush" command')
        key = smart_bytes(key)
        self._purge_if_expired(key)
        items = self._redis.setdefault(key, [])
        items.extend(smart_bytes(value) for value in values)
        return len(items)

    def lpush(self, key, *values):
        """Prepend ``values`` to the head of list ``key``.

        Returns the new length. Raises ResponseError when no values are given.
        """
        if not values:
            raise ResponseError('wrong number of arguments for "lpush" command')
        key = smart_bytes(key)
        self._purge_if_expired(key)
        items = self._redis.setdefault(key, [])
        # LPUSH inserts one value at a time at the head, so the given values
        # end up in reverse order.
        items[:0] = [smart_bytes(value) for value in reversed(values)]
        return len(items)

    def lrange(self, key, start, end):
        """Return a copy of the elements ``start``..``end`` (inclusive) of list ``key``.

        Redis index rules apply, including negative indices.
        """
        key = smart_bytes(key)
        self._purge_if_expired(key)
        if key not in self._redis:
            return []
        items = self._redis[key]
        lo, hi = self._list_bounds(len(items), start, end)
        return items[lo:hi]

    def llen(self, key):
        """Return the length of list ``key``, or 0 if it is missing."""
        key = smart_bytes(key)
        self._purge_if_expired(key)
        return len(self._redis.get(key, ()))

    def ltrim(self, key, start, end):
        """Keep only the elements ``start``..``end`` (inclusive) of list ``key``.

        Redis index rules apply. Always returns True.
        """
        key = smart_bytes(key)
        self._purge_if_expired(key)
        if key not in self._redis:
            return True
        items = self._redis[key]
        lo, hi = self._list_bounds(len(items), start, end)
        self._redis[key] = items[lo:hi]
        return True

    # ----- sets -------------------------------------------------------------

    def sadd(self, key, *values):
        """Add ``values`` to set ``key``.

        Returns how many of them were not already members.
        """
        key = smart_bytes(key)
        self._purge_if_expired(key)
        members = self._redis.setdefault(key, set())
        new_members = set(smart_bytes(value) for value in values) - members
        members.update(new_members)
        return len(new_members)

    def smembers(self, key):
        """Return a copy of set ``key``, or an empty set if it is missing."""
        key = smart_bytes(key)
        self._purge_if_expired(key)
        return set(self._redis.get(key, ()))

    def sismember(self, key, value):
        """Return True if ``value`` is a member of set ``key``."""
        key = smart_bytes(key)
        value = smart_bytes(value)
        self._purge_if_expired(key)
        return value in self._redis.get(key, set())

    # ----- connection / transactions ---------------------------------------

    def ping(self):
        pass

    def pipeline(self):
        return FakePipeline(self)

    def multi(self):
        # Only raises the flag; FakePipeline.execute() reads it and resets it.
        self._in_transaction = True

    def watch(self, *fields):
        pass

    def unwatch(self):
        pass


class FakePipeline(object):
    """Queue FakeRedis commands and run them in order on ``execute()``.

    Accessing any public FakeRedis method name (except ``pipeline``) on the
    pipeline returns a stub that records the call instead of running it.
    ``execute()`` returns one result per queued command; a queued ``multi()``
    is left out of that list.

    While a transaction flag is set, a RedisError raised by a command rolls
    the fake storage back to the state captured when ``execute()`` started,
    then re-raises.
    """

    def __init__(self, redis):
        self._redis = redis
        self._command_queue = []

    def __getattribute__(self, name):
        if name in FakeRedis.__dict__ and name != 'pipeline' and not name.startswith('_'):
            pipeline = self

            def add_to_queue(*args, **kwargs):
                pipeline._command_queue.append((name, args, kwargs))
                # Return the pipeline, as redis-py does, so calls can be
                # chained: pipe.set(...).expire(...).
                return pipeline

            # Give the stub the command's name so log output stays meaningful.
            add_to_queue.__name__ = name
            return add_to_queue
        return super(FakePipeline, self).__getattribute__(name)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop anything still queued (e.g. an exception before execute()).
        self.discard()

    def execute(self):
        """Run all queued commands against the underlying FakeRedis.

        Returns the list of results, without the result of any queued
        ``multi()``. The queue and the transaction flag are always reset.
        """
        storage = self._redis
        snapshot = deepcopy((storage._redis, storage._key_expiration_map))
        results = []
        try:
            for command, args, kwargs in self._command_queue:
                value = FakeRedis.__dict__[command](storage, *args, **kwargs)
                # Skip multi()'s own result wherever it was queued, rather
                # than assuming it was the first command.
                if command != 'multi':
                    results.append(value)
            return results
        except RedisError:
            if storage._in_transaction:
                # Roll back to the state from before execute() started.
                storage._redis, storage._key_expiration_map = snapshot
            raise
        finally:
            self._command_queue = []
            storage._in_transaction = False

    def discard(self):
        """Forget all queued commands and clear the transaction flag."""
        self._command_queue = []
        self._redis._in_transaction = False

    def pipeline(self):
        return self
