import copy
import grpc
import logging
import time
import random

from tornado import gen
from tornado.ioloop import IOLoop

from infra.yasm.gateway.lib.util import stat

from cluster_provider import DEFAULT_GRPC_USER_AGENT

# Module-level logger; used as the default logger of the classes in this module.
module_log = logging.getLogger(__name__)

# NOTE(review): not referenced in this chunk. Presumably the `cluster_attempts` /
# `in_cluster_attempts` arguments for ReadyHostAttemptSequence.make_random -- verify callers.
SOLOMON_GATEWAY_CLUSTER_ATTEMPTS = 2
SOLOMON_GATEWAY_IN_CLUSTER_REQUEST_ATTEMPTS = 2


def _fwrap(f, gf):
    try:
        f.set_result(gf.result())
    except Exception as e:
        f.set_exception(e)


def fwrap(gf, ioloop=None):
    """
    Adapt the gRPC future ``gf`` into a tornado future that coroutines can ``yield``.

    When ``gf`` completes, copying its outcome is scheduled on ``ioloop`` (or on
    ``IOLoop.current()`` when ``ioloop`` is None) via ``add_callback``.
    This is the solution recommended by grpc: https://github.com/grpc/grpc/wiki/Integration-with-tornado-(python)
    """
    tornado_future = gen.Future()
    loop = IOLoop.current() if ioloop is None else ioloop

    def _on_grpc_done(_finished):
        # NOTE(review): gRPC may invoke done-callbacks outside the IOLoop thread, which is
        # presumably why the hand-over goes through add_callback rather than a direct call.
        loop.add_callback(_fwrap, tornado_future, gf)

    gf.add_done_callback(_on_grpc_done)
    return tornado_future


class NoHostsToRequest(Exception):
    """Raised when a host sequence would be created without any hosts to call."""


class GrpcCallsFailed(Exception):
    """Raised once every gRPC call attempt has failed or the time budget ran out."""


class ReadyHostAttemptSequence(object):
    """
    Host sequence to make a call with retries on.

    Each ``pop`` prefers a host whose ``check_connection()`` reports a ready
    channel; when none is ready, the last remaining host is returned anyway.
    Every host is returned at most once.
    """
    def __init__(self, sequence):
        # Copy so that pop() never mutates the caller's list.
        self._sequence = list(sequence)
        if not self._sequence:
            raise NoHostsToRequest("No solomon gateway hosts to request")

    @staticmethod
    def make_random(hosts_per_cluster, cluster_attempts, in_cluster_attempts):
        """
        Build a shuffled sequence from randomly chosen clusters and hosts.

        :param hosts_per_cluster: sequence of ``(cluster, hosts)`` pairs.
        :param cluster_attempts: max number of clusters to draw hosts from.
        :param in_cluster_attempts: max number of hosts taken from each chosen cluster.
        :raises NoHostsToRequest: if no host ends up in the sequence.
        """
        chosen_clusters = random.sample(hosts_per_cluster, min(cluster_attempts, len(hosts_per_cluster)))
        sequence = []
        for _cluster, hosts in chosen_clusters:
            sequence.extend(random.sample(hosts, min(in_cluster_attempts, len(hosts))))
        random.shuffle(sequence)
        # The constructor raises NoHostsToRequest when the sequence is empty.
        return ReadyHostAttemptSequence(sequence)

    def pop(self, log=None):
        """
        Remove and return the next host to call.

        Hosts are scanned from the end of the sequence. The first one whose channel
        reports ready is returned, and a warning is logged for each host skipped
        as not ready. If no channel is ready, the last remaining host is returned.

        :param log: logger for the "not ready" warnings; ``None`` means the module logger.
        :raises IndexError: if the sequence is already empty (check ``empty()`` first).
        """
        if log is None:
            log = module_log
        sequence = self._sequence
        for idx in reversed(range(len(sequence))):
            candidate = sequence[idx]
            if candidate.check_connection():
                # list.pop() takes from the end, so move the ready host to the back.
                sequence[idx], sequence[-1] = sequence[-1], candidate
                break
            log.warning("Channel to {} is not ready".format(candidate.fqdn))

        return sequence.pop()

    def clone(self):
        """Return an independent sequence with the same remaining hosts (constructor copies)."""
        return ReadyHostAttemptSequence(self._sequence)

    def empty(self):
        """Return True when no hosts are left to pop."""
        return not self._sequence


class MultiAttemptRpc(object):
    """
    Calls ``rpc_name`` on stubs built from hosts of ``host_sequence``.

    After a retriable gRPC error the call moves on to the next host. This continues
    until the overall ``timeout`` budget (measured with ``time_func``) is spent or
    no hosts are left. Timings, call errors and final failures are pushed to
    ``unistat`` when one is given.
    """
    TIME_METRIC_NAME_FMT = "gateway_call.{}{}_total_time"
    FAIL_METRIC_NAME_FMT = "gateway_call.{}{}_failed"
    CALL_TIME_METRIC_NAME_FMT = "gateway_call.{}{}_grpc_time"
    CALL_ERROR_METRIC_NAME_FMT = "gateway_call.{}{}_grpc_errors"

    # A failed call is repeated on another host only for these status codes.
    RETRY_CODES = {
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.DEADLINE_EXCEEDED,
        grpc.StatusCode.RESOURCE_EXHAUSTED,
        grpc.StatusCode.INTERNAL,
    }

    def __init__(self, host_sequence, stub_cls, rpc_name, request, timeout,
                 client_id=DEFAULT_GRPC_USER_AGENT, unistat=None, time_func=time.time,
                 log=module_log, signals_prefix=None):
        # what to call, and where
        self._stub_cls = stub_cls
        self._rpc_name = rpc_name
        self._request = request
        self._host_sequence = host_sequence
        self._client_id = client_id
        self._log = log

        # time budget; the absolute deadline is fixed when future() starts
        self._time_func = time_func
        self._timeout = timeout
        self._deadline = None

        # stat collection
        self._unistat = unistat
        if signals_prefix is None:
            self._signals_prefix = ""
        else:
            self._signals_prefix = "{}.".format(signals_prefix)
        self._prepare_unistat_metrics()

    def _prepare_unistat_metrics(self):
        """Compute the metric names and, if a unistat is set, register the metrics."""
        self._rpc_signal_name = self._rpc_name.lower()
        prefix, rpc_signal = self._signals_prefix, self._rpc_signal_name
        self._TIME_METRIC_NAME = self.TIME_METRIC_NAME_FMT.format(prefix, rpc_signal)
        self._FAIL_METRIC_NAME = self.FAIL_METRIC_NAME_FMT.format(prefix, rpc_signal)
        self._CALL_TIME_METRIC_NAME = self.CALL_TIME_METRIC_NAME_FMT.format(prefix, rpc_signal)
        self._CALL_ERROR_METRIC_NAME = self.CALL_ERROR_METRIC_NAME_FMT.format(prefix, rpc_signal)

        if not self._unistat:
            return
        intervals = stat.TIMING_INTERVALS_LOG_1_3
        self._unistat.create_histogram(self._TIME_METRIC_NAME, intervals=intervals)
        self._unistat.create_float(self._FAIL_METRIC_NAME)
        self._unistat.create_histogram(self._CALL_TIME_METRIC_NAME, intervals=intervals)
        self._unistat.create_float(self._CALL_ERROR_METRIC_NAME)

    def _push_metric_value(self, metric, value):
        """Push ``value`` for ``metric``; a no-op when no unistat is configured."""
        unistat = self._unistat
        if not unistat:
            return
        unistat.push(metric, value)

    @gen.coroutine
    def future(self):
        """
        Perform the call, retrying on further hosts when allowed.

        The deadline is ``time_func() + timeout``, taken at entry. Each attempt gets
        the time still left as its gRPC timeout, and no new attempt starts once that
        is <= 0. Hosts are taken from ``host_sequence`` with ``pop`` (the first one
        unconditionally), so the sequence is consumed by this call.

        Resolves with the RPC response. Raises ``GrpcCallsFailed`` after a
        non-retriable error, when no hosts are left, or when the deadline has
        passed. Exceptions other than ``grpc.RpcError`` propagate unchanged.
        """
        with stat.time_it(self._TIME_METRIC_NAME, self._unistat):
            self._deadline = self._time_func() + self._timeout
            next_host = self._host_sequence.pop(self._log)
            last_call_on_host = next_host
            last_error_code = None
            attempts = 0

            while next_host:
                time_left = self._deadline - self._time_func()
                if time_left <= 0:
                    break

                last_call_on_host = next_host
                attempts += 1
                stub = self._stub_cls(next_host.channel)
                rpc = getattr(stub, self._rpc_name)
                # created-at is in milliseconds, assuming time_func returns seconds
                call_metadata = [
                    ("x-solomon-clientid", self._client_id),
                    ("x-solomon-created-at", str(int(self._time_func() * 1000))),
                ]

                try:
                    with stat.time_it(self._CALL_TIME_METRIC_NAME, self._unistat):
                        response = yield fwrap(rpc.future(self._request, time_left, metadata=call_metadata))
                    raise gen.Return(response)
                except grpc.RpcError as call_error:
                    last_error_code = call_error.code() if isinstance(call_error, grpc.Call) else None
                    self._log.exception("Call to {} failed (error_code: {}).".format(last_call_on_host.fqdn, last_error_code))

                    retry_host = None
                    if last_error_code in self.RETRY_CODES and not self._host_sequence.empty():
                        retry_host = self._host_sequence.pop(self._log)
                        self._log.info("Retrying on {}.".format(retry_host.fqdn))
                    next_host = retry_host

                    self._push_metric_value(self._CALL_ERROR_METRIC_NAME, 1)

            self._push_metric_value(self._FAIL_METRIC_NAME, 1)
            raise GrpcCallsFailed("All ({}) attempts to call {} failed. Last call to {} failed with code {}".format(
                attempts,
                self._rpc_name,
                last_call_on_host.fqdn,
                last_error_code
            ))
