import time

from enum import Enum
from tornado import gen

from solomon.services.gateway.api import yasm_gateway_service_pb2 as gateway_service_proto
from solomon.services.gateway.api import yasm_gateway_service_pb2_grpc as gateway_service_grpc

from infra.yasm.gateway.lib.client.requester import MultiAttemptRpc, ReadyHostAttemptSequence, GrpcCallsFailed
from infra.yasm.gateway.lib.client.requester import SOLOMON_GATEWAY_CLUSTER_ATTEMPTS
from infra.yasm.gateway.lib.client.requester import SOLOMON_GATEWAY_IN_CLUSTER_REQUEST_ATTEMPTS
from infra.yasm.gateway.lib.handlers.hist_base import HistFrontHandler
from infra.yasm.gateway.lib.tags.errors import TagsError
from infra.yasm.gateway.lib.tags.request import RequestKey
from infra.yasm.gateway.lib.util.common import adjust_timestamp, parse_proto_timeseries_with_range
from infra.yasm.gateway.lib.util.verification import is_correct_tag_name

# Query argument names with a fixed meaning for the "top" handler.
SIGNAL_ARG = "signal"  # needed to build the request key
GROUP_ARG = "group"  # needed to build the request key
COUNT_ARG = "count"  # how many hosts to return
METHOD_ARG = "method"  # ranking method (one of the TopMethod values)

# Every query argument NOT listed here is interpreted as a tag.
NO_TAGS_TOP_ARGS = (
    {SIGNAL_ARG, GROUP_ARG, COUNT_ARG, METHOD_ARG}
    | {"start", "end", "period"}  # allowed but ignored
)

TOP_PERIOD = 5  # grid step of the requested timeseries, in seconds
TOP_INTERVAL = 20 * 60  # look-back window: 20 minutes, in seconds
DEFAULT_TOP_HOST_COUNT = 5  # used when "count" is not passed


class TopMethod(str, Enum):
    """Ranking method accepted in the "method" query argument.

    Members subclass ``str``, so each one compares equal to its raw value
    (``TopMethod.max_ == "max"``).
    """

    max_ = "max"  # MAX aggregation, descending
    min_ = "min"  # MIN aggregation, ascending
    maxavg = "maxavg"  # AVG aggregation, descending
    minavg = "minavg"  # AVG aggregation, ascending


# Ranking method used when the request does not specify one.
DEFAULT_TOP_METHOD = TopMethod.max_

# Timeout handed to the QueryTop RPC call (presumably seconds -- confirm against MultiAttemptRpc).
TOP_REQUEST_TIMEOUT = 40


class TopHandler(HistFrontHandler):
    """Handler for the "top" request: rank the hosts of a group by a signal.

    Query arguments outside NO_TAGS_TOP_ARGS are treated as tags and folded
    into a RequestKey. The query is sent to the solomon gateway as a QueryTop
    RPC and the answer is repacked into a ``{"top": [...]}`` dict.
    """

    HANDLE_STAT_NAME = "top"

    def initialize(self, cluster_provider, unistat, front_id):
        """Initialize the base handler and remember the gateway cluster provider.

        :param cluster_provider: object providing ``get_cluster_hosts()`` and
            ``grpc_user_agent``, both used by ``handle_get``.
        :param unistat: passed through to the base handler.
        :param front_id: passed through to the base handler.
        """
        super(TopHandler, self).initialize(unistat, front_id)
        self._cluster_provider = cluster_provider

    @staticmethod
    def build_request(group, signal, request_key, count, method):
        """Build a QueryTopRequest covering roughly the last TOP_INTERVAL seconds.

        :param group: value placed into ``query.hosts``.
        :param signal: expression placed into ``query.expression``.
        :param request_key: RequestKey whose dynamic string becomes ``query.tags``.
        :param count: maximum number of hosts to return (``limit``).
        :param method: a TopMethod member or its string value; selects the
            aggregation function and the sort direction.
        :returns: the populated ``QueryTopRequest``.
        :raises ValueError: if ``method`` is not a TopMethod value. Otherwise the
            request would go out with aggregation/sort fields left unset.
        """
        # Normalise (and validate) before touching the proto: TopMethod("max")
        # and TopMethod(TopMethod.max_) both yield the member.
        method = TopMethod(method)

        now = time.time()
        request_proto = gateway_service_proto.QueryTopRequest()
        request_proto.query.hosts = group
        request_proto.query.tags = request_key.to_dynamic_string()
        request_proto.query.expression = signal
        # adjust_timestamp presumably aligns to the TOP_PERIOD grid; the extra
        # TOP_PERIOD on the start widens the window by one step.
        request_proto.query.from_millis = adjust_timestamp(now - TOP_INTERVAL - TOP_PERIOD, TOP_PERIOD) * 1000
        request_proto.query.to_millis = adjust_timestamp(now, TOP_PERIOD) * 1000
        request_proto.query.grid_millis = TOP_PERIOD * 1000
        request_proto.limit = count
        request_proto.fill_timeseries = True
        if method is TopMethod.max_:
            request_proto.aggregation_function = gateway_service_proto.QueryTopRequest.MAX
            request_proto.sort_direction = gateway_service_proto.QueryTopRequest.DESCENDING
        elif method is TopMethod.min_:
            request_proto.aggregation_function = gateway_service_proto.QueryTopRequest.MIN
            request_proto.sort_direction = gateway_service_proto.QueryTopRequest.ASCENDING
        elif method is TopMethod.maxavg:
            request_proto.aggregation_function = gateway_service_proto.QueryTopRequest.AVG
            request_proto.sort_direction = gateway_service_proto.QueryTopRequest.DESCENDING
        else:  # TopMethod.minavg -- the only remaining member
            request_proto.aggregation_function = gateway_service_proto.QueryTopRequest.AVG
            request_proto.sort_direction = gateway_service_proto.QueryTopRequest.ASCENDING
        return request_proto

    @staticmethod
    def pack_response(response_proto, request_proto):
        """Convert a QueryTop response into the handler's JSON payload.

        :param response_proto: gateway response; its ``top_hosts`` are iterated.
        :param request_proto: the request that produced it; its time range
            defines the timeline.
        :returns: ``{"top": [...]}`` with one entry per ``top_hosts`` item, in
            response order. Each entry has ``host``, ``scores`` (the aggregated
            value), ``timeline`` and ``points``. ``timeline`` holds the timestamps
            in seconds from the request start up to, but excluding, its end,
            stepping by TOP_PERIOD.
        """
        # Floor division keeps the bounds integral for xrange() independently
        # of the interpreter's "/" semantics.
        st = request_proto.query.from_millis // 1000
        et = request_proto.query.to_millis // 1000
        return {
            "top": [
                {
                    "host": host_data.host_name,
                    "scores": host_data.aggregated_value,
                    "timeline": list(xrange(st, et, TOP_PERIOD)),
                    "points": parse_proto_timeseries_with_range(host_data.timeseries, st, et, TOP_PERIOD).values
                }
                for host_data in response_proto.top_hosts
            ]
        }

    @gen.coroutine
    def handle_get(self):
        """Validate the query, call QueryTop on the gateway and write the packed result.

        Responds with status 400 for an invalid tag name, tags rejected by
        RequestKey, or an unknown ranking method. If the RPC fails with
        GrpcCallsFailed, the failure is logged and its message goes to ``self.error``.
        """
        query_args = self.get_query_args()
        tags_dict = {}
        for arg, value in query_args.iteritems():
            if arg in NO_TAGS_TOP_ARGS:
                continue
            if not is_correct_tag_name(arg):
                self.error("Invalid argument: {}".format(arg), status_code=400)
                return
            tags_dict[arg] = value
        try:
            request_key = RequestKey.from_dict(tags_dict)
        except TagsError as e:
            self.log.exception("Failed to build request key due to invalid tags")
            self.error("Invalid tags: {}".format(e), status_code=400)
            return

        group = self.get_argument(GROUP_ARG)
        signal = self.get_argument(SIGNAL_ARG)
        count = self.get_int_argument(COUNT_ARG, DEFAULT_TOP_HOST_COUNT)
        raw_method = self.get_argument(METHOD_ARG, default=DEFAULT_TOP_METHOD)
        try:
            method = TopMethod(raw_method)
        except ValueError:
            # Unicode literal: the user-supplied value may contain non-ASCII text.
            self.error(u"Invalid method: {}".format(raw_method), status_code=400)
            return

        request_proto = self.build_request(group, signal, request_key, count, method)

        try:
            call = MultiAttemptRpc(
                ReadyHostAttemptSequence.make_random(
                    self._cluster_provider.get_cluster_hosts(),
                    SOLOMON_GATEWAY_CLUSTER_ATTEMPTS,
                    SOLOMON_GATEWAY_IN_CLUSTER_REQUEST_ATTEMPTS
                ),
                gateway_service_grpc.YasmGatewayServiceStub,
                rpc_name="QueryTop",
                request=request_proto,
                timeout=TOP_REQUEST_TIMEOUT,
                client_id=self._cluster_provider.grpc_user_agent,
                unistat=self.unistat,
                log=self.log
            )
            response_proto = yield call.future()
        except GrpcCallsFailed as e:
            self.log.exception("Can't get info from solomon gateway")
            self.error(e.message)
        else:
            self.response(self.pack_response(response_proto, request_proto))
