# -*- coding: utf-8 -*-
import weakref
from datetime import datetime, timedelta
from logging import getLogger

from django.core.cache.backends.base import DEFAULT_TIMEOUT
from django_redis.client import DefaultClient
from redis.client import Redis
from redis.exceptions import (ConnectionError, ResponseError, TimeoutError)
from redis.sentinel import Sentinel, SentinelConnectionPool, SentinelManagedConnection, SlaveNotFoundError
from retrying import retry

from travel.library.python.avia_mdb_replica_info.avia_mdb_replica_info.ping import ping_hosts
from travel.library.python.solomon.metrics import ExplicitHistogramRateMetric, GaugeMetric, RateMetric

# Module-level logger used by the connection, sentinel and cache-client classes in this module.
log = getLogger(__name__)


class Timer(object):
    """Simple stopwatch that starts running as soon as it is created.

    Time is read with ``datetime.now()``, i.e. naive local wall-clock time.
    """

    def __init__(self):
        # Reference point for every measurement made by this instance.
        self.start_time = datetime.now()

    @property
    def elapsed(self):
        """The ``timedelta`` that has passed since the timer was created."""
        now = datetime.now()
        return now - self.start_time

    def get_elapsed_seconds(self):
        """Elapsed time as a float number of seconds."""
        passed = self.elapsed
        return passed.total_seconds()

    def seconds_to_delta(self, delta):
        """Seconds left until ``delta`` (a ``timedelta``) has passed since start.

        The value becomes negative once ``delta`` has already been exceeded.
        """
        remaining = delta - self.elapsed
        return remaining.total_seconds()

    def seconds_to_shift(self, shift):
        """Same as ``seconds_to_delta`` but with ``shift`` given in seconds."""
        target = timedelta(seconds=shift)
        return self.seconds_to_delta(target)


class RedisSentinelManagedConnection(SentinelManagedConnection):
    """Sentinel-managed connection that walks a list of candidate hosts.

    A connection owned by a master pool connects to the master address the
    pool reports. A connection owned by a replica pool tries every host
    returned by ``connection_pool.near_hosts()`` in order and keeps the first
    one that accepts the connection.
    """

    @retry(
        wait_exponential_multiplier=100,
        wait_exponential_max=1000,
        stop_max_attempt_number=5,
    )
    def connect(self):
        """Open the socket unless it is already open.

        Wrapped by the ``retrying`` decorator: at most 5 attempts with
        exponential back-off capped at 1000 ms.

        Raises:
            SlaveNotFoundError: when none of the candidate hosts accepted a
                connection (including the case of an empty candidate list).
        """
        if self._sock:
            return

        pool = self.connection_pool
        if pool.is_master:
            master = pool.get_master_address()
            log.info('Try to connect to master: %s', master)
            self.connect_to(master)
            log.info('Connected to master: %s', master)
            return

        for slave in pool.near_hosts():
            try:
                log.info('Try to connect to slave: %s', slave)
                self.connect_to(slave)
                log.info('Connected to slave: %s', slave)
                return
            except ConnectionError as exc:
                # Fall through to the next candidate instead of giving up.
                log.warning('Failed to connect to slave %s: %s', slave, exc)

        # Reached when every candidate refused the connection or
        # near_hosts() returned nothing.
        raise SlaveNotFoundError('No slave found for %r' % pool.service_name)


class RedisSentinelConnectionPool(SentinelConnectionPool):
    """Sentinel connection pool whose connections pick hosts via ``near_hosts``.

    Connections default to ``RedisSentinelManagedConnection`` unless the
    caller passes another ``connection_class``.
    """

    def __init__(self, service_name, sentinel_manager, **kwargs):
        kwargs.setdefault('connection_class', RedisSentinelManagedConnection)
        self.is_master = kwargs.pop('is_master', True)
        self.check_connection = kwargs.pop('check_connection', False)
        # Deliberately starts the MRO lookup *after* SentinelConnectionPool,
        # so SentinelConnectionPool.__init__ itself is not run.
        super(SentinelConnectionPool, self).__init__(**kwargs)
        # Connections get a weak proxy back to this pool.
        self.connection_kwargs['connection_pool'] = weakref.proxy(self)
        self.service_name = service_name
        self.sentinel_manager = sentinel_manager

    def near_host(self):
        """Return the first host discovered for this service.

        Raises:
            SlaveNotFoundError: if discovery returns no hosts.
        """
        candidates = self.sentinel_manager.discover_hosts(self.service_name)
        if not candidates:
            raise SlaveNotFoundError('No host found for %r' % self.service_name)
        return candidates[0]

    def near_hosts(self):
        """Return all hosts discovered for this service, in discovery order."""
        manager = self.sentinel_manager
        return manager.discover_hosts(self.service_name)


class RedisSentinel(Sentinel):
    """Sentinel client that orders candidate hosts by measured ping latency.

    ``discover_hosts`` returns the service's replicas plus its master, sorted
    by the latest ping results. ``master_for``/``slave_for`` build clients on
    top of ``RedisSentinelConnectionPool`` by default.
    """

    def __init__(self, sentinels, min_other_sentinels=0, sentinel_kwargs=None, **connection_kwargs):
        # Forward the caller's values to Sentinel rather than hard-coding them.
        super(RedisSentinel, self).__init__(
            sentinels,
            min_other_sentinels=min_other_sentinels,
            sentinel_kwargs=sentinel_kwargs,
            **connection_kwargs
        )
        # Last ping results, keyed by the first element of each host tuple
        # (presumably the host name/IP -- confirm against ping_hosts).
        self.ping_by_host = {}

    def update_ping(self, hosts):
        """Return True if any host in ``hosts`` has no recorded ping yet."""
        return any(h[0] not in self.ping_by_host for h in hosts)

    def discover_hosts(self, service_name):
        """Return candidate hosts for ``service_name`` ordered by ping.

        Sentinels are queried in turn. The first one that answers with at
        least one replica surviving ``filter_slaves`` is used. The master is
        appended to that list so it competes on latency too. Hosts without a
        recorded ping sort last. If no sentinel yields a replica, only the
        master is returned.
        """
        for sentinel in self.sentinels:
            try:
                slaves = sentinel.sentinel_slaves(service_name)
            except (ConnectionError, ResponseError, TimeoutError):
                continue
            hosts = self.filter_slaves(slaves)
            if not hosts:
                continue
            hosts.append(self.discover_master(service_name))
            # Re-ping only when a host we have never measured shows up.
            if self.update_ping(hosts):
                self.ping_by_host = ping_hosts([h[0] for h in hosts])
                log.info('Redis hosts ping: %s', self.ping_by_host)
            hosts.sort(key=lambda h: self.ping_by_host.get(h[0], float('inf')))
            return hosts
        return [self.discover_master(service_name)]

    def _client_for(self, service_name, redis_class, connection_pool_class, is_master, kwargs):
        """Build a ``redis_class`` client over a fresh pool for ``service_name``.

        Per-call ``kwargs`` override this sentinel's connection kwargs;
        ``is_master`` always wins.
        """
        connection_kwargs = dict(self.connection_kwargs)
        connection_kwargs.update(kwargs)
        connection_kwargs['is_master'] = is_master
        return redis_class(connection_pool=connection_pool_class(service_name, self, **connection_kwargs))

    def master_for(self, service_name, redis_class=Redis, connection_pool_class=RedisSentinelConnectionPool, **kwargs):
        """Return a client whose pool connects to the service's master."""
        return self._client_for(service_name, redis_class, connection_pool_class, True, kwargs)

    def slave_for(self, service_name, redis_class=Redis, connection_pool_class=RedisSentinelConnectionPool, **kwargs):
        """Return a client whose pool connects to the nearest discovered host."""
        return self._client_for(service_name, redis_class, connection_pool_class, False, kwargs)


class RedisClient(DefaultClient):
    """django-redis client that talks to Redis through Sentinel.

    Writes go to the master pool. Reads go to the Sentinel slave pool unless
    ``MASTER_ONLY`` is set, in which case they go to the master too.

    Each wrapped cache call is timed, logged, and reported to
    ``monitoring_metric_queue`` when that queue is truthy.

    Options read from ``self._options``:
        ``HOSTS``, ``SENTINEL_SERVICE_NAME``, ``MASTER_ONLY`` (optional,
        default False), ``SOCKET_TIMEOUT``, ``PASSWORD``,
        ``monitoring_metric_queue``, ``monitoring_sensor_prefix``.
    """

    def __init__(self, server, params, backend):
        super(RedisClient, self).__init__(server, params, backend)
        options = self._options
        self._hosts = options['HOSTS']

        self.service_name = options['SENTINEL_SERVICE_NAME']
        # When true, reads are routed to the master as well (see get_client).
        self.master_only = options.get('MASTER_ONLY', False)
        self.sentinel = RedisSentinel(self._hosts, socket_timeout=options['SOCKET_TIMEOUT'])

        # Connection options shared by the master and slave pools.
        pool_options = {'password': options['PASSWORD']} if options['PASSWORD'] else {}
        pool_options['retry_on_timeout'] = True
        pool_options['socket_timeout'] = options['SOCKET_TIMEOUT']
        self.master = self.sentinel.master_for(self.service_name, **pool_options)
        self.slave = self.sentinel.slave_for(self.service_name, **pool_options)
        log.info('Redis sentinel init hosts %s with service name %s', self._hosts, self.service_name)

        # A falsy queue disables metric reporting in _send_metrics.
        self.monitoring_metric_queue = options['monitoring_metric_queue']
        sensor_prefix = options['monitoring_sensor_prefix']
        self.sensor_requests = '{}.requests'.format(sensor_prefix)
        self.sensor_timings = '{}.timings'.format(sensor_prefix)
        self.sensor_timings_histogram = '{}.timings-histogram'.format(sensor_prefix)

    def _send_metrics(self, method_name, elapsed_seconds):
        """Queue request-rate, timing and histogram metrics for one call.

        Timings are reported in whole milliseconds. Does nothing when no
        metric queue is configured.
        """
        if not self.monitoring_metric_queue:
            return

        milliseconds = int(elapsed_seconds * 1000)
        labels = {
            'cache_method': method_name,
            'cache_backend': 'redis',
        }
        self.monitoring_metric_queue.put_nowait(RateMetric(self.sensor_requests, labels))
        self.monitoring_metric_queue.put_nowait(GaugeMetric(self.sensor_timings, labels, milliseconds))
        self.monitoring_metric_queue.put_nowait(
            ExplicitHistogramRateMetric(self.sensor_timings_histogram, labels, milliseconds),
        )

    def call_method(self, method_name, *args, **kwargs):
        """Invoke ``DefaultClient.<method_name>`` and record its timing.

        Returns whatever the parent method returns. Errors while sending
        metrics are logged and swallowed so they never fail the cache call.
        An exception raised by the cache call itself propagates before
        anything is logged or reported.
        """
        # For set_many the first argument is the whole data mapping (values
        # included), so it is not logged as a key.
        key = args[0] if args and method_name != 'set_many' else ''
        timer = Timer()
        val = getattr(super(RedisClient, self), method_name)(*args, **kwargs)
        elapsed_seconds = timer.get_elapsed_seconds()
        log.info('Redis %s %s %s', method_name, elapsed_seconds, key)

        try:
            self._send_metrics(method_name, elapsed_seconds)
        except Exception:
            log.exception('Error sending metrics')

        return val

    def get_client(self, write=True, tried=(), show_index=False):
        """Return the master client for writes (or MASTER_ONLY), else the slave.

        ``tried`` is accepted for signature compatibility with
        ``DefaultClient.get_client`` and is ignored. With ``show_index`` a
        ``(client, 0)`` pair is returned instead of the bare client.
        """
        client = self.master if write or self.master_only else self.slave
        if show_index:
            return client, 0
        return client

    def get(self, key, default=None, version=None, client=None):
        """Timed ``DefaultClient.get``."""
        return self.call_method('get', key, default=default, version=version, client=client)

    def get_many(self, keys, version=None, client=None):
        """Timed ``DefaultClient.get_many``."""
        return self.call_method('get_many', keys, version=version, client=client)

    def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, client=None, nx=False, xx=False):
        """Timed ``DefaultClient.set``; returns the parent's result (e.g. the nx/xx outcome)."""
        return self.call_method('set', key, value, timeout=timeout, version=version, client=client, nx=nx, xx=xx)

    def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None, client=None):
        """Timed ``DefaultClient.set_many``; returns the parent's result."""
        return self.call_method('set_many', data, timeout=timeout, version=version, client=client)

    def delete(self, key, version=None, prefix=None, client=None):
        """Timed ``DefaultClient.delete``; returns the parent's result."""
        return self.call_method('delete', key, version=version, prefix=prefix, client=client)

    def delete_many(self, keys, version=None, client=None):
        """Timed ``DefaultClient.delete_many``; returns the parent's result."""
        return self.call_method('delete_many', keys, version=version, client=client)

    def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, client=None):
        """Store ``value`` only if ``key`` is absent.

        Implemented as ``set(nx=True)``, so it is logged and measured under
        the ``'set'`` method name.
        """
        return self.call_method('set', key, value, timeout=timeout, version=version, client=client, nx=True)
