# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import random
import typing
from functools import partial
from typing import Callable, Any, Union, List

import six

from travel.rasp.library.python.db.replica_health import ReplicaHealth
from travel.rasp.library.python.db.utils import TimeInterval, get_dc_priority

# Module-level logger; ClusterBase uses it as the default value of its `log` argument.
log = logging.getLogger(__name__)


# Default DbInstance priority, also used as the fallback DC priority in the default sort key.
# Sort keys are ordered ascending, so this large value places such instances last.
LOW_INSTANCE_PRIORITY = 9999


class ClusterException(Exception):
    """Raised when a connection could not be obtained from any instance of a cluster."""


class ConnectionFailure(Exception):
    """
    Raised by a connection getter when a specific instance cannot be connected to.

    ClusterBase catches it to mark the instance as dead.
    """


class DbInstance(object):
    """
    A single database host of a cluster.

    Identity is defined by `host` alone: two instances with the same host are equal
    and hash identically, and an instance also compares equal to (and hashes like)
    its bare host string. This lets a host string be used to look up an instance
    in a dict or set.
    """

    def __init__(self, host, is_master=False, dc=None, priority=LOW_INSTANCE_PRIORITY, health=ReplicaHealth.ALIVE):
        # type: (str, bool, typing.Optional[str], int, str) -> None
        """
        :param host: host name of the instance; the only field used for equality and hashing.
        :param is_master: True for the master, False for a replica.
        :param dc: datacenter of the instance, if known.
        :param priority: sort priority; lower values are preferred by the default sort key.
        :param health: health status of the instance (one of the ReplicaHealth values).
        """
        self.host = host
        self.is_master = is_master
        self.dc = dc
        self.priority = priority
        self.health = health

    @property
    def is_replica(self):
        """True when the instance is not the master."""
        return not self.is_master

    def __eq__(self, instance_or_host):
        """
        Compare by host with another DbInstance or with a host string.

        Returns NotImplemented for any other operand type, so that `==` falls back
        to Python's default handling (False) instead of raising. This keeps
        membership tests and dict/set lookups against foreign objects safe.
        """
        if isinstance(instance_or_host, DbInstance):
            return self.host == instance_or_host.host
        if isinstance(instance_or_host, six.string_types):
            return self.host == instance_or_host
        return NotImplemented

    def __ne__(self, instance_or_host):
        """Inverse of __eq__; defined explicitly because Python 2 does not derive `!=` from `==`."""
        equal = self.__eq__(instance_or_host)
        if equal is NotImplemented:
            return equal
        return not equal

    def __hash__(self):
        # Must stay consistent with __eq__: equal to the hash of the host string.
        return hash(self.host)

    def __repr__(self):
        return '{} ({})'.format(self.host, 'master' if self.is_master else 'replica')

    def __str__(self):
        return self.host


class ClusterBase(object):
    """
    Represents a database cluster with abilities:
    - get instances of cluster
    - connect to specific instance
    - connect to "cluster" - i.e. automatically chosen instance (by sorting and filtering of instances)
    - distinction between master and replicas
    - keep track of aliveness of instances

    Subclasses must implement get_actual_instances_list().
    """

    def __init__(
        self,
        connection_getter=None,    # type: Callable[[DbInstance], Any]  # Should raise ConnectionFailure to mark instance as dead
        instance_filter=None,      # type: Callable[[DbInstance], bool]
        instance_sort_key=None,    # type: Callable[[DbInstance, str], tuple]
        instance_dead_ttl=30,     # type: int
        log=log,                   # type: logging.Logger
    ):
        """
        :param connection_getter: callable producing a connection for an instance.
        :param instance_filter: predicate selecting usable instances; defaults to accepting all.
        :param instance_sort_key: sort key called as key(inst, current_dc=...); defaults to
            default_instance_sort_func.
        :param instance_dead_ttl: value passed to TimeInterval when an instance is marked dead.
        :param log: logger used for connection diagnostics.
        """
        self.log = log
        self.connection_getter = connection_getter
        self.instance_filter = instance_filter or self.default_instance_filter_func
        self.instance_sort_key = instance_sort_key or self.default_instance_sort_func

        self._instances = []  # type: List[DbInstance]

        self.instance_dead_ttl = instance_dead_ttl
        # Instances that recently failed to connect, with the interval during which they count as dead.
        self._instances_dead = {}  # type: typing.Dict[DbInstance, TimeInterval]

    def get_actual_instances_list(self):
        # type: () -> typing.List[DbInstance]
        """
        Return the current list of cluster instances. Must be implemented by subclasses.
        """
        raise NotImplementedError

    def update_cluster_configuration(self):
        """Replace the cached instance list with the result of get_actual_instances_list()."""
        self._instances = self.get_actual_instances_list()

    @property
    def instances(self):
        # type: () -> List[DbInstance]
        """Cached instance list; the configuration is loaded first when the cache is empty."""
        if not self._instances:
            self.update_cluster_configuration()

        return self._instances

    def get_instances(self, current_dc=None):
        # type: (typing.Optional[str]) -> typing.List[DbInstance]
        """ Return filtered and sorted instances of a cluster """
        filtered = [inst for inst in self.instances if self.instance_filter(inst)]
        sort_key = partial(self.instance_sort_key, current_dc=current_dc)
        return sorted(filtered, key=sort_key)

    def get_connection_to_instance(self, inst):
        # type: (DbInstance) -> Any
        """
        Return a connection to the given instance using connection_getter.

        On ConnectionFailure the instance is marked dead and the exception is re-raised.
        """
        try:
            return self.connection_getter(inst)
        except ConnectionFailure as ex:
            self.log.warning("Couldn't to connect to %s: %s", inst, ex)
            self.set_instance_dead(inst, reason=repr(ex))
            raise

    def get_connection(self, current_dc=None):
        """
        Return a connection to the first instance (in get_instances() order) that accepts one.

        :raises ClusterException: when there are no instances or every attempt
            ended with ConnectionFailure.
        """
        # Keep the attempted list so the failure report describes exactly the instances that
        # were tried, rather than filtering and sorting the cluster a second time.
        candidates = self.get_instances(current_dc)

        for inst in candidates:
            try:
                conn = self.get_connection_to_instance(inst)
            except ConnectionFailure:
                # Already logged and marked dead by get_connection_to_instance; try the next one.
                continue

            self.log.debug("Connected to instance: %s", inst)
            return conn

        msg = "Failed to connect to cluster: {}: {}".format(
            self,
            ', '.join(self.get_inst_alive_description(inst) for inst in candidates)
        )
        self.log.error(msg)
        raise ClusterException(msg)

    def set_instance_dead(self, inst, reason=""):
        """Mark the instance as dead for instance_dead_ttl, remembering the reason."""
        self._instances_dead[inst] = TimeInterval(self.instance_dead_ttl, reason=reason)

    def set_instance_alive(self, inst):
        """Drop the dead mark of the instance, if there is one."""
        self._instances_dead.pop(inst, None)

    def is_instance_alive(self, inst):
        """
        Return True when the instance health is ReplicaHealth.ALIVE and it has no
        dead mark whose interval is still running.
        """
        if inst.health != ReplicaHealth.ALIVE:
            return False

        dead_mark = self._instances_dead.get(inst)
        if dead_mark and not dead_mark.is_time_passed:
            return False
        return True

    def get_inst_alive_description(self, inst):
        """Return a human-readable description of whether, and why, the instance counts as dead."""
        if self.is_instance_alive(inst):
            return '{} is alive'.format(inst)

        if inst.health != ReplicaHealth.ALIVE:
            return '{} is dead: health is {} but expected {}'.format(inst, inst.health, ReplicaHealth.ALIVE)

        # Not alive despite good health means a dead mark exists for this instance.
        # NOTE(review): relies on TimeInterval exposing a dict-like .get() for 'reason' - confirm.
        inst_dead = self._instances_dead[inst]
        return '{} is dead: {}'.format(inst, inst_dead.get('reason', 'unknown reason'))

    def default_instance_sort_func(self, inst, current_dc=None):
        # type: (DbInstance, str) -> tuple
        """
        Default sort key; smaller tuples are tried first.

        Order of preference: alive instances, then datacenter priority relative to
        current_dc (ignored when current_dc is not given), then instance priority,
        then replicas before the master, then a random tie-breaker.
        """
        return (
            0 if self.is_instance_alive(inst) else 1,
            0 if not current_dc else get_dc_priority(current_dc, inst.dc, default_value=LOW_INSTANCE_PRIORITY),
            inst.priority,
            0 if not inst.is_master else 1,
            random.random()
        )

    def default_instance_filter_func(self, inst):
        """Default filter: accept every instance."""
        return True


class ClusterConst(ClusterBase):
    """Cluster backed by a list of instances supplied by the caller at construction."""

    def __init__(self, instances, *args, **kwargs):
        # type: (typing.List[DbInstance], *Any, **Any) -> None
        """
        :param instances: instances this cluster consists of.

        Remaining arguments are forwarded to ClusterBase.
        """
        super(ClusterConst, self).__init__(*args, **kwargs)
        self.__instances = instances

    def get_actual_instances_list(self):
        """Return the list given to the constructor."""
        return self.__instances

    def __repr__(self):
        current = self.get_instances()
        return 'ClusterConst: {}'.format(current)


class ClusterPeriodicUpdateMixin(object):
    """
    Mixin for ClusterBase that refreshes the cluster configuration once the
    cluster-info interval reports that its time has passed.
    """
    def __init__(self, cluster_info_ttl=90, raise_on_update_fail=False, *args, **kwargs):
        # type: (int, bool, *Any, **typing.Any) -> None
        """
        :param cluster_info_ttl: value passed to the TimeInterval guarding refreshes.
        :param raise_on_update_fail: re-raise refresh errors instead of only logging them.
        """
        super(ClusterPeriodicUpdateMixin, self).__init__(*args, **kwargs)

        self.cluster_info_ttl = cluster_info_ttl
        self.raise_on_update_fail = raise_on_update_fail

        self._cluster_info_actual = TimeInterval(self.cluster_info_ttl)

    def _refresh_configuration_if_due(self):
        # type: (Union[ClusterBase, ClusterPeriodicUpdateMixin]) -> None
        """Run a configuration update when the interval has passed; log (and optionally re-raise) failures."""
        if not self._cluster_info_actual.is_time_passed:
            return

        try:
            self.update_cluster_configuration()
        except Exception:
            self.log.exception("Unable to update cluster configuration")

            # When the error is swallowed, the previously loaded instances stay in place.
            if self.raise_on_update_fail:
                raise

    @property
    def instances(self):
        # type: (Union[ClusterBase, ClusterPeriodicUpdateMixin]) -> List[DbInstance]
        """Instances of the cluster, refreshed first when a refresh is due."""
        self._refresh_configuration_if_due()
        return super(ClusterPeriodicUpdateMixin, self).instances

    def on_before_update_cluster_configuration(self):
        """ Hook to do some setup before actual configuration update, if it's required. """

    def update_cluster_configuration(self):
        # type: (Union[ClusterBase, ClusterPeriodicUpdateMixin]) -> None
        """Run the pre-update hook, update the configuration, then restart the interval."""
        self.on_before_update_cluster_configuration()
        super(ClusterPeriodicUpdateMixin, self).update_cluster_configuration()
        self._cluster_info_actual.reset()
