import json
import logging
import grpc
import random

from tornado import gen
from tornado.httpclient import AsyncHTTPClient

from enum import Enum
from collections import defaultdict

# Logger shared by everything in this module, named after the module itself.
log = logging.getLogger(name=__name__)

# Default value for the gRPC primary user agent of every gateway channel.
DEFAULT_GRPC_USER_AGENT = "yasmgateway"


class ClusterType(Enum):
    """Solomon installation whose gateways should be discovered.

    Each member's value is the lowercase installation name.
    """

    PRESTABLE = 'prestable'
    PRODUCTION = 'production'

    def __str__(self):
        # Render just the raw value (e.g. "prestable"), not "ClusterType.PRESTABLE".
        raw_value = self.value
        return raw_value


# Default discovery endpoint (a JSON document listing gateway hosts) per ClusterType;
# used when SolomonGatewayClusterProvider is not given an explicit discovery_url.
CLUSTER_TYPE_TO_DISCOVERY_URL = dict([
    (ClusterType.PRESTABLE, "https://solomon.yandex.net/discovery/PRESTABLE/gateway.json"),
    (ClusterType.PRODUCTION, "https://solomon.yandex.net/discovery/PRODUCTION/gateway.json"),
])


class GatewayHost(object):
    """One Solomon gateway endpoint together with its gRPC channel.

    The host subscribes to the channel's connectivity notifications and keeps
    the most recently reported state, so ``check_connection`` does not block.
    """

    def __init__(self, fqdn, cluster, channel):
        """
        Args:
            fqdn: host name of the gateway.
            cluster: cluster name the host was reported under by discovery.
            channel: gRPC channel to the host's endpoint.
        """
        self.fqdn, self.cluster = fqdn, cluster

        self.channel = channel
        # Last state delivered to _on_channel_state_changed; None until the first notification.
        self._channel_state = None

        # Subscribe immediately (nothing to unsubscribe yet) so connecting starts eagerly.
        self._try_to_connect_and_subscribe(False)

    def check_connection(self):
        """Return True iff the last reported channel state is READY.

        When the channel is not ready, a fresh connection attempt is requested
        before returning False.
        """
        ready = self._channel_state == grpc.ChannelConnectivity.READY
        if not ready:
            self._try_to_connect_and_subscribe(True)
        return ready

    def _try_to_connect_and_subscribe(self, unsubscribe):
        """(Re-)register the state callback, asking the channel to connect.

        Args:
            unsubscribe: drop the existing registration first, so the callback
                is not registered twice.
        """
        state_callback = self._on_channel_state_changed
        if unsubscribe:
            self.channel.unsubscribe(state_callback)
        # Second argument is grpc's ``try_to_connect`` flag.
        self.channel.subscribe(state_callback, True)

    def _on_channel_state_changed(self, state):
        """Connectivity callback: remember the newest state reported by the channel."""
        self._channel_state = state


class SolomonGatewayClusterProvider(object):
    """Keeps a set of Solomon gateway hosts in sync with the discovery API.

    The discovery document (JSON with ``ports.grpc`` and a ``hosts`` list of
    ``{"fqdn", "cluster"}`` entries) is turned into ``GatewayHost`` objects keyed
    by ``"fqdn:port"`` and exposed grouped by cluster via ``get_cluster_hosts``.
    Reloads are triggered manually (``reload`` / ``reload_from_str``) or
    periodically via ``schedule_reloads``. The outcome of each discovery fetch
    is optionally reported to ``unistat``.
    """

    DISCOVERY_API_FAILURE_METRIC_NAME = "cluster_provider.discovery_api_failures"
    DISCOVERY_API_SUCCESS_METRIC_NAME = "cluster_provider.discovery_api_success"

    def __init__(self,
                 cluster_type=ClusterType.PRESTABLE,
                 grpc_user_agent=DEFAULT_GRPC_USER_AGENT,
                 discovery_url=None,
                 unistat=None):
        """
        Args:
            cluster_type: ClusterType whose default discovery URL is used when
                ``discovery_url`` is not given.
            grpc_user_agent: sent as ``grpc.primary_user_agent`` on every channel.
            discovery_url: explicit discovery URL; overrides ``cluster_type`` when truthy.
            unistat: optional metrics sink exposing ``create_float(name)`` and
                ``push(name, value)``.
        """
        self._discovery_url = CLUSTER_TYPE_TO_DISCOVERY_URL[cluster_type] if not discovery_url else discovery_url
        self._unistat = unistat
        self._prepare_unistat_metrics()
        # endpoint ("fqdn:port") -> GatewayHost
        self._endpoint_to_host = {}
        self._cluster_hosts = None  # calculated on demand, reset whenever the host map changes
        self.grpc_user_agent = grpc_user_agent

        # fields needed for self scheduling
        self._schedule_io_loop = None
        self._schedule_interval = None
        self._schedule_jitter = None
        # (ioloop, timeout handle) of the pending scheduled reload, if any.
        self._scheduled_reload = None

    def get_cluster_hosts(self):
        """Return the current hosts grouped by cluster.

        Returns:
            A list of ``(cluster, [GatewayHost, ...])`` pairs. The same cached
            list object is returned until the host map changes, so callers
            should not mutate it.
        """
        if self._cluster_hosts is None:
            grouped_cluster_hosts = defaultdict(list)
            for host in self._endpoint_to_host.itervalues():
                grouped_cluster_hosts[host.cluster].append(host)
            self._cluster_hosts = [(cluster, hosts) for cluster, hosts in grouped_cluster_hosts.iteritems()]
        return self._cluster_hosts

    @gen.coroutine
    def reload(self):
        """Fetch the discovery document, apply it, and schedule the next reload.

        Any failure (HTTP error, timeout, malformed document) is logged and
        counted in the failure metric; the host map is then not replaced.
        """
        log.info("Loading solomon gateway instances")
        try:
            gateway_json_response = yield AsyncHTTPClient().fetch(self._discovery_url, request_timeout=30)
            self.reload_from_str(gateway_json_response.body)
        except Exception:
            log.exception("Failed to update solomon gateway instances")
            self._push_unistat_metric(self.DISCOVERY_API_FAILURE_METRIC_NAME, 1)
        else:
            self._push_unistat_metric(self.DISCOVERY_API_SUCCESS_METRIC_NAME, 1)
        log.info("Current number of solomon gateway instances: %d", len(self._endpoint_to_host))
        self._schedule_next_reload()

    def reload_from_str(self, discovery_response):
        """Apply a discovery document (JSON text) to the host map.

        Endpoints still listed keep their existing GatewayHost and channel. New
        endpoints get a fresh channel. Channels of endpoints that disappeared
        are closed. Nothing happens when the endpoint set is unchanged.

        Raises:
            ValueError: if ``discovery_response`` is not valid JSON.
            KeyError: if "ports"/"grpc"/"hosts"/"fqdn"/"cluster" is missing.
        """
        gateway_instances = json.loads(discovery_response)

        port = gateway_instances["ports"]["grpc"]
        host_infos = gateway_instances["hosts"]

        # Compare as sets: duplicated entries in the response must not make an
        # unchanged endpoint set look different on every reload.
        new_endpoints = set("{}:{}".format(host["fqdn"], port) for host in host_infos)
        if new_endpoints == set(self._endpoint_to_host):
            return

        log.info("Solomon gateway endpoint set changed")
        new_endpoint_to_host = {}
        for host_info in host_infos:
            fqdn = host_info["fqdn"]
            endpoint = "{}:{}".format(fqdn, port)
            if endpoint in new_endpoint_to_host:
                # Duplicate entry: keep the host already placed for this endpoint
                # instead of opening (and leaking) a second channel.
                continue
            old_host = self._endpoint_to_host.get(endpoint)
            if old_host is not None:
                new_endpoint_to_host[endpoint] = old_host
            else:
                new_endpoint_to_host[endpoint] = GatewayHost(
                    fqdn,
                    host_info["cluster"],
                    self._create_new_channel(endpoint)
                )
                log.info("New solomon gateway endpoint: %s", endpoint)

        old_endpoint_to_host = self._endpoint_to_host
        removed_endpoints = [endpoint for endpoint in old_endpoint_to_host
                             if endpoint not in new_endpoint_to_host]

        # Publish the new map before closing anything, so it never exposes a closed channel.
        self._endpoint_to_host = new_endpoint_to_host
        self._cluster_hosts = None

        for endpoint in removed_endpoints:
            log.info("Removed solomon gateway endpoint: %s", endpoint)
            self._close_channel(old_endpoint_to_host[endpoint].channel)

    def schedule_reloads(self, ioloop, interval, jitter):
        """Start periodic reloads on ``ioloop``.

        The first reload happens after one delay, not immediately. Calling this
        again replaces the previous schedule instead of adding a second chain.

        Args:
            ioloop: tornado IOLoop used for ``add_callback`` / ``call_later``.
            interval: base delay between reloads, in seconds.
            jitter: integer; each delay is shifted by a random integer in
                ``[-jitter, jitter]`` (``random.randint`` requires integers).
        """
        self._schedule_io_loop = ioloop
        self._schedule_interval = interval
        self._schedule_jitter = jitter
        self._schedule_io_loop.add_callback(self._schedule_next_reload)

    def _schedule_next_reload(self):
        """Schedule ``reload`` after ``interval +/- jitter`` seconds.

        This is a no-op until ``schedule_reloads`` has been called. A still
        pending scheduled reload is cancelled first. Otherwise every manual
        ``reload()`` (or repeated ``schedule_reloads``) would start another
        endless reload chain.
        """
        if self._schedule_io_loop is None:
            return
        if self._scheduled_reload is not None:
            pending_loop, pending_timeout = self._scheduled_reload
            # Safe even if that timeout has already fired (IOLoop.remove_timeout contract).
            pending_loop.remove_timeout(pending_timeout)
        new_delay = self._schedule_interval + random.randint(-self._schedule_jitter, self._schedule_jitter)
        timeout = self._schedule_io_loop.call_later(new_delay, self.reload)
        self._scheduled_reload = (self._schedule_io_loop, timeout)

    def _prepare_unistat_metrics(self):
        """Register both discovery counters in unistat, if a sink was given."""
        if self._unistat:
            self._unistat.create_float(self.DISCOVERY_API_FAILURE_METRIC_NAME)
            self._unistat.create_float(self.DISCOVERY_API_SUCCESS_METRIC_NAME)

    def _push_unistat_metric(self, metric, value):
        """Push ``value`` for ``metric`` to unistat, if a sink was given."""
        if self._unistat:
            self._unistat.push(metric, value)

    def _create_new_channel(self, endpoint):
        """Open an insecure gRPC channel to ``endpoint`` ("host:port")."""
        channel_options = [
            # Allow messages up to 128 MiB in both directions.
            ('grpc.max_receive_message_length', 128 * 1024 * 1024),
            ('grpc.max_send_message_length', 128 * 1024 * 1024),
            ('grpc.primary_user_agent', self.grpc_user_agent),
            # Reconnect backoff between 1 s and 10 s.
            ('grpc.initial_reconnect_backoff_ms', 1000),
            ('grpc.max_reconnect_backoff_ms', 10000),
            # Keepalive ping every 5 s; the connection is considered dead after 2 s without an ack.
            ('grpc.keepalive_time_ms', 5000),
            ('grpc.keepalive_timeout_ms', 2000)
        ]
        return grpc.insecure_channel(endpoint, channel_options)

    @staticmethod
    def _close_channel(channel):
        """Best-effort release of a channel no longer referenced by the host map.

        ``grpc.Channel.close`` releases the channel's resources, including its
        connectivity subscriptions. Without it, the subscription keeps the channel
        reconnecting to the vanished host indefinitely. Older grpcio releases
        have no ``close``; the channel is then left as before.
        """
        close = getattr(channel, "close", None)
        if close is None:
            return
        try:
            close()
        except Exception:
            # Never let cleanup of an obsolete channel break a successful reload.
            log.exception("Failed to close solomon gateway channel")
