# encoding: utf-8
from __future__ import unicode_literals

import json
import logging
import operator
import time
from urllib import quote

import tornadis
from tornado import gen

from intranet.webauth.lib.utils import get_redis_pool

# How many set members remove_version() fetches (SSCAN COUNT) and removes per round.
REMOVAL_BATCH_SIZE = 10
# Pause after each non-empty removal round in remove_version(), in seconds (200 microseconds).
REMOVAL_SLEEP_TIME = 200 / 1e6
# check_role_cache() logs a warning when a lookup exceeds this many seconds (50 ms).
GET_ROLE_LOG_TIME = 50 / 1e3

logger = logging.getLogger(__name__)


def fix_role_path(path):
    """Normalize an IDM role path so it is wrapped in exactly one '/' on each side.

    All leading and trailing slashes are stripped and a single slash is put
    back at both ends, e.g. 'a/b', '/a/b' and '//a/b//' all become '/a/b/'.
    Falsy values ('' or None) are returned unchanged. A path consisting only
    of slashes (e.g. '/') becomes '//'.
    """
    if path:
        path = '/{}/'.format(path.strip('/'))
    return path


def get_version_set_name(version):
    """Return the Redis key of the set that holds role-cache entries for ``version``."""
    return 'roles/{version}'.format(version=version)


def get_role_cache_key(login, idm_role, fields_data):
    """Build the set member that identifies one (login, role, fields) assignment.

    The result has the form ``<role>/<login>/<fields>``:

    * ``<role>`` is ``idm_role`` URL-quoted with no safe characters, so the
      slashes inside a role path are encoded and cannot be confused with the
      separators;
    * ``<login>`` is inserted as-is;
    * ``<fields>`` is the URL-quoted JSON dump of the key-sorted items of
      ``fields_data`` (a falsy ``fields_data`` is treated as an empty dict).
    """
    encoded_role = quote(idm_role, safe='')
    sorted_fields = normalize_fields_data(fields_data or {})
    encoded_fields = quote(json.dumps(sorted_fields), safe='')
    return '{role}/{login}/{fields_data}'.format(
        role=encoded_role,
        login=login,
        fields_data=encoded_fields,
    )


@gen.coroutine
def _redis_call(*args):
    """Execute one raw Redis command and return its reply.

    ``args`` is the command name followed by its arguments, e.g.
    ``_redis_call('GET', key)``. A client is borrowed from the pool returned
    by ``get_redis_pool()`` for the duration of the command.

    Raises:
        tornadis.TornadisException: if acquiring the client or executing the
            command produced one.
    """
    # The future from connected_client() is yielded inside the ``with``; the
    # context manager is presumably what hands the client back to the pool
    # when the block exits (tornadis pool pattern) -- confirm against tornadis.
    with (yield get_redis_pool().connected_client()) as client:
        # Errors come back as TornadisException *values* rather than being
        # raised, hence the explicit isinstance checks.
        if isinstance(client, tornadis.TornadisException):
            raise client
        raw_result = yield client.call(*args)
    # The reply is inspected only after the ``with`` block has exited.
    if isinstance(raw_result, tornadis.TornadisException):
        raise raw_result
    raise gen.Return(raw_result)


@gen.coroutine
def _get(key):
    """Return the value stored at ``key`` (Redis ``GET``), as the client delivers it."""
    stored_value = yield _redis_call('GET', key)
    raise gen.Return(stored_value)


@gen.coroutine
def _set(key, value):
    """Store ``value`` at ``key`` (Redis ``SET``); the reply is discarded."""
    command = ('SET', key, value)
    yield _redis_call(*command)


@gen.coroutine
def _del(*keys):
    """Delete ``keys`` with a single Redis ``DEL`` and return the reply.

    Per Redis semantics, the reply is the number of keys actually removed.
    """
    command = ('DEL',) + keys
    removed_count = yield _redis_call(*command)
    raise gen.Return(removed_count)


@gen.coroutine
def _sadd(set_name, *keys):
    """Add ``keys`` to the Redis set ``set_name`` (``SADD``).

    Returns the Redis reply, i.e. the number of members that were newly added.

    Redis rejects ``SADD`` without at least one member ("wrong number of
    arguments"). An empty ``keys`` therefore short-circuits to 0 without
    contacting Redis. This happens, for example, when write_role_cache() is
    given an empty ``roles`` iterable.
    """
    if not keys:
        raise gen.Return(0)
    added_items = yield _redis_call('SADD', set_name, *keys)
    raise gen.Return(added_items)


@gen.coroutine
def _sismember(set_name, key):
    """Check whether ``key`` belongs to the Redis set ``set_name`` (``SISMEMBER``)."""
    membership = yield _redis_call('SISMEMBER', set_name, key)
    # Redis answers with the integer 1 (member) or 0 (not a member / no such set).
    raise gen.Return(membership)


@gen.coroutine
def _srandmember(set_name, count=None):
    """Fetch random member(s) of ``set_name`` via Redis ``SRANDMEMBER``.

    The ``count`` argument is only sent when it is given. Without it, Redis
    replies with a single member (nil for a missing set). With it, Redis
    replies with a list of members.
    """
    command = ['SRANDMEMBER', set_name]
    if count is not None:
        command.append(count)
    reply = yield _redis_call(*command)
    raise gen.Return(reply)


@gen.coroutine
def _sscan(set_name, cursor='0', count=None, match=None):
    """Run one ``SSCAN`` iteration over the Redis set ``set_name``.

    Args:
        set_name: key of the set to scan.
        cursor: cursor returned by the previous iteration; '0' starts a scan.
        count: optional COUNT hint (amount of work per call). It is only sent
            when truthy, because Redis rejects ``COUNT 0``.
        match: optional MATCH glob pattern; only sent when truthy.

    Returns:
        The raw Redis reply, a ``[next_cursor, members]`` pair (remove_version()
        unpacks it that way). A ``next_cursor`` of '0' means the scan is done.
    """
    command = ['SSCAN', set_name, cursor]
    # Redis accepts MATCH and COUNT together (``SSCAN key cursor [MATCH p]
    # [COUNT n]``), so there is no need to make them mutually exclusive.
    if match:
        command += ['MATCH', match]
    if count:
        command += ['COUNT', count]

    result = yield _redis_call(*command)
    raise gen.Return(result)


@gen.coroutine
def _srem(set_name, keys):
    """Remove ``keys`` from the Redis set ``set_name`` (``SREM``).

    ``keys`` may be a list/tuple of members or a single member. Returns the
    Redis reply (the number of members actually removed).
    """
    members = keys if isinstance(keys, (tuple, list)) else [keys]
    removed_count = yield _redis_call('SREM', set_name, *members)
    raise gen.Return(removed_count)


@gen.coroutine
def get_cache_version():
    """Return the current role-cache version stored under 'role_cache_version'.

    The value is whatever ``GET`` returns. It is presumably None when no
    version has been stored yet.
    """
    current_version = yield _get('role_cache_version')
    raise gen.Return(current_version)


@gen.coroutine
def set_cache_version(value):
    """Store ``str(value)`` as the current role-cache version."""
    yield _set('role_cache_version', str(value))


def normalize_fields_data(fields_data):
    """Return the (key, value) pairs of ``fields_data`` as a list sorted by key.

    The fixed ordering means dicts with equal content serialize identically.
    """
    pairs = list(fields_data.items())
    pairs.sort(key=lambda pair: pair[0])
    return pairs


@gen.coroutine
def check_role_cache(login, idm_role, fields_data):
    """Return True if the (login, role, fields_data) entry is in the current cache.

    The role path is normalized with fix_role_path() and the entry is looked
    up in the set of the version returned by get_cache_version(). A warning
    is logged when the whole lookup takes longer than GET_ROLE_LOG_TIME
    seconds.
    """
    start = time.time()

    idm_role = fix_role_path(idm_role)

    cache_version = yield get_cache_version()
    key = get_role_cache_key(login, idm_role, fields_data)
    set_name = get_version_set_name(cache_version)
    # SISMEMBER replies 1/0; convert once and reuse for logging and the result.
    has_role = bool((yield _sismember(set_name, key)))

    duration = time.time() - start
    if duration > GET_ROLE_LOG_TIME:
        # ``Logger.warn`` is a deprecated alias (removed in Python 3.13).
        logger.warning(
            'check_role_cache worked too slowly: %s ms, result=%s',
            round(duration * 1000, 2),
            has_role,
        )
    raise gen.Return(has_role)


@gen.coroutine
def write_role_cache(version, roles):
    """Add every role assignment in ``roles`` to the cache set of ``version``.

    ``roles`` is an iterable of ``(login, idm_role, fields_data)`` triples.
    Role paths are normalized with fix_role_path() before the member keys are
    built. All members are sent in a single ``SADD``.
    """
    set_name = get_version_set_name(version)
    members = []
    for login, idm_role, fields_data in roles:
        normalized_role = fix_role_path(idm_role)
        members.append(get_role_cache_key(login, normalized_role, fields_data))
    yield _sadd(set_name, *members)


@gen.coroutine
def write_version_timestamp(version):
    """Record the current ``time.time()`` (stored as a string) for ``version``."""
    timestamp_key = 'roles_version_timestamp/{}'.format(version)
    now = str(time.time())
    yield _set(timestamp_key, now)


@gen.coroutine
def get_version_timestamp(version):
    """Return the timestamp recorded for ``version`` as a float, or None if absent."""
    stored = yield _get('roles_version_timestamp/{}'.format(version))
    raise gen.Return(None if stored is None else float(stored))


@gen.coroutine
def is_version_nonempty(version):
    """Return True if the cache set of ``version`` has at least one member.

    A random member is sampled; ``SRANDMEMBER`` yields nil (None) for an
    empty or missing set.
    """
    sample = yield _srandmember(get_version_set_name(version))
    raise gen.Return(sample is not None)


@gen.coroutine
def remove_version(version):
    """Delete the cache set of ``version`` incrementally.

    A plain ``DEL`` on a large set is very slow in Redis
    (https://www.redisgreen.net/blog/deleting-large-objects/). Instead, the set
    is drained in SSCAN/SREM rounds of about REMOVAL_BATCH_SIZE members, with
    a REMOVAL_SLEEP_TIME pause after every non-empty round. Finally the key
    itself is removed with ``DEL``.
    """
    set_name = get_version_set_name(version)
    cursor = '0'
    while True:
        cursor, batch = yield _sscan(set_name, cursor, count=REMOVAL_BATCH_SIZE)
        if batch:
            yield _srem(set_name, batch)
            # Short pause between rounds so other work can interleave.
            yield gen.sleep(REMOVAL_SLEEP_TIME)
        # SSCAN reports the end of a full iteration with cursor '0'.
        if cursor == '0':
            break
    yield _del(set_name)
