import os
import uuid
import math
import time
import logging
import random
import functools

from redis import StrictRedis
from redis import WatchError, ConnectionError, ReadOnlyError
from redis.sentinel import Sentinel

from lacmus2 import hostlist


def make_client(redis_cfg):
    """Create a Redis client from the ``redis_cfg`` mapping.

    Keys read from ``redis_cfg``:

    * ``DBNUM`` -- database number;
    * ``SOCKET_TIMEOUT`` (optional, default 10) and ``RETRY_ON_TIMEOUT``
      (optional, default True) -- passed to the connection;
    * ``USE_SENTINEL`` -- when truthy, ``SENTINELS`` (a sequence of sentinel
      addresses, presumably ``(host, port)`` pairs) and
      ``SENTINEL_SERVICE_NAME`` are used and the returned client talks to
      that service's master; otherwise ``HOST`` and ``PORT`` are used.

    Responses are decoded to ``str`` (``decode_responses=True``).  The whole
    configuration is logged at INFO level.
    """
    logging.getLogger('lacmus2.redis').info('redis config: %s', redis_cfg)
    conn_kwargs = {
        'socket_timeout': redis_cfg.get('SOCKET_TIMEOUT', 10),
        'retry_on_timeout': redis_cfg.get('RETRY_ON_TIMEOUT', True),
        'decode_responses': True,
        'db': redis_cfg['DBNUM'],
    }
    if redis_cfg['USE_SENTINEL']:
        # list() rather than [:] so a tuple in the config can be shuffled
        # as well (random.shuffle needs a mutable sequence); the copy also
        # leaves the caller's configuration untouched.  The shuffle
        # presumably spreads the initial sentinel queries across sentinels.
        sentinels = list(redis_cfg['SENTINELS'])
        random.shuffle(sentinels)
        sentinel = Sentinel(sentinels, **conn_kwargs)
        return sentinel.master_for(redis_cfg['SENTINEL_SERVICE_NAME'])

    return StrictRedis(redis_cfg['HOST'], redis_cfg['PORT'], **conn_kwargs)


def retry(method):
    """Decorate a storage method so Redis connection problems are retried.

    When a call fails with ``ConnectionError`` or ``ReadOnlyError`` it is
    repeated with the current client up to ``RETRIES_WITH_SAME_CLIENT``
    times, sleeping ``DELAY`` seconds after each failure.  After that,
    ``self.connect()`` builds a fresh client and the whole cycle starts
    again, for ``RETRIES_RECONNECT`` cycles in total.  The very last failure
    is logged and re-raised.  Any other exception propagates immediately.

    The instance the method is bound to must provide ``logger``, ``redis``
    (whose ``connection_pool`` is inspected for logging) and ``connect()``.
    """
    RETRIES_WITH_SAME_CLIENT = 4
    RETRIES_RECONNECT = 4
    DELAY = 2
    name = method.__name__

    @functools.wraps(method)
    def wrapped(self, *args, **kwargs):
        for reconnect_no in range(RETRIES_RECONNECT):
            out_of_clients = reconnect_no == RETRIES_RECONNECT - 1
            for try_no in range(RETRIES_WITH_SAME_CLIENT):
                out_of_tries = try_no == RETRIES_WITH_SAME_CLIENT - 1
                self.logger.info('calling %s', name)
                try:
                    return method(self, *args, **kwargs)
                except (ConnectionError, ReadOnlyError) as exc:
                    if not out_of_tries:
                        self.logger.error(
                            '%s failed with %s, sleeping for %.1fs',
                            name, type(exc).__name__, DELAY
                        )
                    elif out_of_clients:
                        # Nothing left to try: give up for good.
                        self.logger.exception(
                            '%s fails after %d attempts to reconnect',
                            name, RETRIES_RECONNECT
                        )
                        raise
                    else:
                        self.logger.exception(
                            '%s failed too many times, '
                            'giving up with this client',
                            name
                        )
                # Only reached after a handled failure.
                pool = self.redis.connection_pool
                self.logger.info('connections in the pool: %s',
                                 pool._available_connections)
                time.sleep(DELAY)
            # This client has used up its tries; build a new one.
            self.connect()
    return wrapped


class Lacmus2RedisStorage(object):
    """Redis-backed storage of host reports and the indexes built on them.

    Every Redis key name is built by ``_key``, which joins its arguments
    with NUL characters.  In the list below ``/`` stands for that separator:

    * ``hk2v/<host>/<key>`` -- last value ``host`` reported for ``key``;
    * ``hk2m/<host>/<key>`` -- timestamp of that report;
    * ``expires`` -- sorted set of ``<host>/<key>`` members scored by report
      timestamp, consumed by ``cleanup``;
    * ``k2vc/<key>`` -- hash mapping each value of ``key`` to the number of
      hosts currently reporting it;
    * ``kv2hh/<key>/<value>`` -- set of hosts currently reporting ``value``;
    * ``all_keys`` -- set of keys that have been reported;
    * ``s2hh/<vtype>/<selector>`` -- host set saved for a selector;
    * ``cv/<chartkey>`` -- short-lived "chart viewed" marker.

    Public methods are wrapped with ``retry``, so connection failures lead
    to retries and reconnects.
    """
    NUM_HOSTS_ON_PAGE = 100

    def __init__(self, redis_cfg):
        # redis_cfg is the mapping understood by make_client(); it is kept
        # so connect() can build a fresh client after connection failures.
        self.redis_cfg = redis_cfg
        self.logger = logging.getLogger('lacmus2.redis.storage')
        self.connect()

    def connect(self):
        """(Re)create ``self.redis`` from ``self.redis_cfg``."""
        self.redis = make_client(self.redis_cfg)
        self.logger.info('connected')

    @classmethod
    def _key(cls, *args):
        """Build a Redis key name: ``str()`` of each argument, NUL-joined."""
        return '\x00'.join(str(x) for x in args)

    @retry
    def process_hostreport(self, host, timestamp, key, value):
        """Record that ``host`` reported ``value`` for ``key`` at ``timestamp``.

        A report older than the one already stored for (host, key) is
        ignored.  Every accepted report refreshes the stored timestamp and
        the ``expires`` index.  The stored value and the per-value indexes
        are only touched when the value actually changes.

        The update runs as an optimistic WATCH/MULTI transaction. It is
        restarted whenever a concurrent writer modifies the same (host, key).

        :raises ValueError: if ``host``, ``key`` or ``value`` contains NUL,
            the key-name separator.
        :raises RuntimeError: if the transaction conflicts 100 times in a row.
        """
        if '\x00' in host or '\x00' in key or '\x00' in value:
            raise ValueError()
        hk2v_key = self._key('hk2v', host, key)
        hk2m_key = self._key('hk2m', host, key)
        for _attempt in range(100):
            try:
                with self.redis.pipeline() as pipeline:
                    pipeline.watch(hk2m_key, hk2v_key)
                    # After WATCH the pipeline executes commands immediately,
                    # so this read returns the current values.
                    last_changed, old_value = pipeline.mget(hk2m_key, hk2v_key)
                    if last_changed and float(last_changed) > timestamp:
                        # A newer report is already stored: nothing to do.
                        break
                    pipeline.multi()
                    pipeline.set(hk2m_key, str(timestamp))
                    # NOTE(review): the keyword form of zadd (member=score) is
                    # the redis-py 2.x API; redis-py >= 3 expects a mapping.
                    # Confirm the installed version before upgrading.
                    pipeline.zadd(self._key('expires'),
                                  **{self._key(host, key): timestamp})
                    if value != old_value:
                        if old_value is not None:
                            # Move the host out of the old value's indexes.
                            pipeline.hincrby(self._key('k2vc', key),
                                             old_value, -1)
                            pipeline.srem(self._key('kv2hh', key, old_value),
                                          host)
                        pipeline.hincrby(self._key('k2vc', key), value, +1)
                        pipeline.sadd(self._key('kv2hh', key, value), host)
                        pipeline.set(hk2v_key, value)
                        pipeline.sadd(self._key('all_keys'), key)
                    pipeline.execute()
            except WatchError:
                # A watched key changed under us; start over.
                continue
            break
        else:
            raise RuntimeError("exceeded 100 attempts to process")

    @retry
    def cleanup(self, max_age, limit):
        """Remove host reports whose timestamp is older than ``max_age``.

        Timestamps are compared against ``time.time() - max_age``. They are
        presumably Unix epoch seconds; confirm with the reporters.  At most
        ``limit`` of the oldest ``expires`` entries are examined per call.
        Each one is removed in its own WATCH/MULTI transaction, and entries
        that are refreshed or modified concurrently are skipped.

        :returns: the number of host reports actually removed.
        """
        expired = time.time() - max_age
        expires_key = self._key('expires')
        pairs = self.redis.zrangebyscore(
            expires_key, min=0, max=expired,
            start=0, num=limit, withscores=True
        )
        removed = 0
        for hk, _score in pairs:
            host, key = hk.split('\x00')
            hk2m_key = self._key('hk2m', host, key)
            hk2v_key = self._key('hk2v', host, key)
            try:
                with self.redis.pipeline() as pipeline:
                    pipeline.watch(hk2m_key, hk2v_key)
                    # Read through the watching connection (immediate mode),
                    # as process_hostreport does.
                    mtime, value = pipeline.mget(hk2m_key, hk2v_key)
                    if mtime is None:
                        # Index entry without a record: drop it.  Otherwise
                        # it stays among the lowest scores forever and, since
                        # every call starts at offset 0, can starve cleanup.
                        pipeline.multi()
                        pipeline.zrem(expires_key, hk)
                        pipeline.execute()
                        continue
                    if float(mtime) > expired:
                        # Refreshed since the index was read.
                        continue
                    pipeline.multi()
                    pipeline.delete(hk2m_key)
                    pipeline.delete(hk2v_key)
                    if value is not None:
                        pipeline.hincrby(self._key('k2vc', key), value, -1)
                        pipeline.srem(self._key('kv2hh', key, value), host)
                    pipeline.zrem(expires_key, hk)
                    pipeline.execute()
            except WatchError:
                continue
            else:
                removed += 1
        return removed

    @retry
    def get_signals(self):
        """Return ``{key: {value: host_count}}`` for the keys in ``all_keys``.

        Values with a zero count are omitted, and so are keys that are left
        with no values.
        """
        result = {}

        for key in self.redis.smembers(self._key('all_keys')):
            value_count = {
                v: int(vc)
                for v, vc in self.redis.hgetall(self._key('k2vc', key)).items()
                if int(vc) != 0
            }
            if value_count:
                result[key] = value_count

        return result

    @retry
    def list_hosts(self, selector_vtype, selector_key,
                   key, value, filters, page, hosts_on_page,
                   compact=False):
        """Return one page of hosts matching the given conditions.

        Two conditions always restrict the hosts:

        * every ``(fkey, fvalue)`` pair in ``filters``: the host reports
          ``fvalue`` for ``fkey``;
        * when both ``selector_vtype`` and ``selector_key`` are truthy, the
          host set saved by ``save_selector_hosts``.

        Then one of two cases applies:

        * if ``value`` is not None, only hosts reporting ``value`` for
          ``key`` are kept;
        * otherwise only hosts reporting no currently counted value for
          ``key`` are kept.  With neither filters nor a selector this
          returns ``([], 0, 0)``.

        ``page`` is zero-based and is clamped into the valid range.  All
        members get the same score here, so ZRANGE yields them in
        lexicographic order.

        :returns: ``(hosts, page, numpages)``.  ``hosts`` is passed through
            ``hostlist.compactify_hostlist`` when ``compact`` is true, and
            ``page`` is the page actually returned.
        """
        to_intersect = [self._key('kv2hh', fkey, fvalue)
                        for fkey, fvalue in filters]
        if selector_vtype and selector_key:
            to_intersect.append(self._key('s2hh', selector_vtype,
                                          selector_key))
        with _RedisTmpStorageManager(self.redis) as tmpstor:
            if value is not None:
                to_intersect.append(self._key('kv2hh', key, value))
                setname = tmpstor.make_name()
                # ZINTERSTORE accepts plain sets and produces a sorted set
                # that can be paged with ZRANGE.
                numhosts = self.redis.zinterstore(setname, to_intersect)
            else:
                if not to_intersect:
                    return [], 0, 0
                k2vc_key = self._key('k2vc', key)
                allvalues = [v for v, n in self.redis.hgetall(k2vc_key).items()
                             if int(n) != 0]
                setname = tmpstor.make_name()
                self.redis.sinterstore(setname, to_intersect)
                if allvalues:
                    # SUNIONSTORE requires at least one source key.  With no
                    # counted values there is nothing to subtract anyway.
                    tmpset = tmpstor.make_name()
                    self.redis.sunionstore(tmpset, [self._key('kv2hh', key, v)
                                                    for v in allvalues])
                    self.redis.sdiffstore(setname, setname, tmpset)
                # Turn the plain set into a sorted set for paging.
                numhosts = self.redis.zunionstore(setname, [setname])
            numpages = int(math.ceil(numhosts / hosts_on_page))
            if page >= numpages:
                page = max(numpages - 1, 0)
            firstidx = max(0, page * hosts_on_page)
            lastidx = min(firstidx + hosts_on_page - 1, numhosts - 1)
            hosts = self.redis.zrange(setname, firstidx, lastidx)
            page = firstidx // hosts_on_page
        if compact:
            hosts = hostlist.compactify_hostlist(hosts)
        return hosts, page, numpages

    @retry
    def get_chart(self, selector_vtype, selector_key, key, filters):
        """Count the matching hosts per value of ``key``.

        Hosts are restricted by ``filters`` and the optional selector, as in
        ``list_hosts``.

        :returns: ``{value: host_count}`` for values with a non-zero count,
            plus ``''`` mapped to the number of matching hosts not counted
            under any of those values.  With neither filters nor a selector,
            the global counts from ``k2vc`` are returned and ``''`` maps
            to 0.
        """
        value_counts = {
            k: int(v)
            for k, v in self.redis.hgetall(self._key('k2vc', key)).items()
            if int(v) != 0
        }
        result = {}
        to_intersect = []
        for fkey, fvalue in filters:
            to_intersect.append(self._key('kv2hh', fkey, fvalue))
        if selector_vtype and selector_key:
            s2hh_key = self._key('s2hh', selector_vtype, selector_key)
            to_intersect.append(s2hh_key)

        if not to_intersect:
            return dict(value_counts, **{'': 0})

        with _RedisTmpStorageManager(self.redis) as tmpstor:
            if len(to_intersect) == 1:
                # A single set can be used directly, without a temporary.
                [base_set] = to_intersect
                base_count = self.redis.scard(base_set)
            else:
                base_set = tmpstor.make_name()
                base_count = self.redis.sinterstore(base_set, to_intersect)

            if base_count == 0:
                return {'': 0}

            for value in value_counts:
                # The temporary intersection is only needed for its size.
                with _RedisTmpStorageManager(self.redis) as tmpstor2:
                    setname = tmpstor2.make_name()
                    vcount = self.redis.sinterstore(
                        setname, base_set, self._key('kv2hh', key, value)
                    )
                if vcount:
                    result[value] = vcount

        result[''] = base_count - sum(result.values())
        return result

    @retry
    def mark_chart_as_viewed(self, chartkey):
        """Set the ``cv`` marker for ``chartkey``; it expires after 5 s."""
        self.redis.set(self._key('cv', chartkey), 1, ex=5)

    @retry
    def is_chart_viewed(self, chartkey):
        """Return True while the ``cv`` marker for ``chartkey`` is set."""
        return bool(self.redis.get(self._key('cv', chartkey)))

    @retry
    def save_selector_hosts(self, selector_vtype, selector_key, hosts):
        """Replace the host set stored for a selector with ``hosts``.

        The delete and the re-add run in one MULTI/EXEC transaction.  An
        empty ``hosts`` leaves the selector without a stored set.
        """
        key = self._key('s2hh', selector_vtype, selector_key)
        with self.redis.pipeline() as pipeline:
            pipeline.multi()
            pipeline.delete(key)
            if hosts:
                pipeline.sadd(key, *hosts)
            pipeline.execute()

    @retry
    def get_secret_key(self):
        """Return the shared secret key stored in Redis, creating it if absent.

        A new key is 48 random bytes rendered as 96 lowercase hex digits.
        SETNX ensures that only one writer wins.  A caller whose SETNX fails
        loops and reads the winner's key.
        """
        while True:
            key = self.redis.get("config.secret_key")
            if key is not None:
                return key
            key = ''.join('%02x' % x for x in os.urandom(48))
            if self.redis.setnx("config.secret_key", key):
                return key


class _RedisTmpStorageManager(object):
    def __init__(self, client):
        self.client = client
        self.tmp_storages = []

    def __enter__(self):
        return self

    def make_name(self):
        name = 'tmp\0%s' % (uuid.uuid4().hex, )
        self.tmp_storages.append(name)
        return name

    def __exit__(self, exc_type, exc_value, tb):
        for name in self.tmp_storages:
            self.client.delete(name)
