import time

from itertools import chain
from collections import namedtuple
from tornado import gen

from solomon.services.gateway.api import yasm_gateway_service_pb2 as gateway_service_proto
from solomon.services.gateway.api import yasm_gateway_service_pb2_grpc as gateway_service_grpc

from infra.yasm.gateway.lib.client.requester import MultiAttemptRpc, ReadyHostAttemptSequence
from infra.yasm.gateway.lib.client.requester import SOLOMON_GATEWAY_CLUSTER_ATTEMPTS
from infra.yasm.gateway.lib.client.requester import SOLOMON_GATEWAY_IN_CLUSTER_REQUEST_ATTEMPTS
from infra.yasm.gateway.lib.handlers.base import BaseGatewayHandler
from infra.yasm.gateway.lib.tags.request import RequestKey
from infra.yasm.gateway.lib.tags.errors import TagsError
from infra.yasm.gateway.lib.util.common import adjust_timestamp, parse_proto_timeseries_with_range, TimeSeries

# Grid step of realtime series, in seconds (multiplied by 1000 when sent as grid_millis).
RT_PERIOD = 5
# Timeout handed to the gateway RPC; presumably seconds -- confirm against MultiAttemptRpc.
RT_REQUEST_TIMEOUT = 5

# Handler settings: the maximum number of signals accepted per request and the
# offsets subtracted from the adjusted "now" to get the returned time range.
RtOptions = namedtuple(
    "RtOptions",
    "signals_limit range_start_offset range_end_offset"
)
# One parsed "host:tags:signal" request: the parsed tag key plus the raw pieces
# needed to rebuild the original signal name in the response.
FullRequestKey = namedtuple(
    "FullRequestKey",
    "host request_key signal_expr raw_tags"
)


def fill_timeseries_result(timeseries, messages, without_messages):
    """
    Convert a timeseries into the list of per-point records returned to the client.

    :param timeseries: object exposing ``values``, ``start_ts`` and ``period``
    :param messages: iterable of error messages attached to every point; may be empty or None
    :param without_messages: when true, the "errors" field is omitted from every record
    :return: list of dicts with "timestamp", "value" and (unless suppressed) "errors"
    """
    # The messages are identical for every point, so sort them once up front
    # instead of re-sorting for each record.
    if without_messages:
        sorted_messages = None
    else:
        sorted_messages = sorted(messages) if messages else []

    result = []
    for i, value in enumerate(timeseries.values):
        record = {
            "timestamp": timeseries.start_ts + i * timeseries.period,
            "value": value
        }
        if sorted_messages is not None:
            # every record owns its list, so mutating one record never affects another
            record["errors"] = list(sorted_messages)
        result.append(record)
    return result


class SignalLimitError(Exception):
    """Raised when a single request asks for more signals than the configured limit."""


class RtRequestDataProcessor(object):
    """
    Parses realtime ("rt") signal requests, builds the Solomon gateway ReadData
    call for them and packs the gateway response into the handler result format.
    """

    FILTERED_SIGNALS_COUNT_SIGNAL = "rt.filtered_signals"
    REQUEST_SIZE_SIGNAL = "rt.request_size"
    DELAY_SIGNAL = "rt.delay"

    def __init__(self, cluster_provider, yasmconf, options, unistat, log):
        """
        :param cluster_provider: source of gateway hosts and of the gRPC user agent
        :param yasmconf: configuration object queried for required tags per itype
        :param options: object with signals_limit, range_start_offset, range_end_offset (see RtOptions)
        :param unistat: metrics sink used via push()
        :param log: logger
        """
        self._cluster_provider = cluster_provider
        self._yasmconf = yasmconf
        self._options = options
        self._unistat = unistat
        self._log = log

    @classmethod
    def prepare_unistat_signals(cls, unistat):
        """Register the metrics this processor pushes: one float and two histograms."""
        unistat.create_float(cls.FILTERED_SIGNALS_COUNT_SIGNAL)
        # range() produces the same bucket borders on Python 2 and Python 3
        unistat.create_histogram(
            cls.REQUEST_SIZE_SIGNAL,
            intervals=list(chain(
                range(0, 100, 10),
                range(100, 500, 50),
                range(500, 1000, 100),
                range(1000, 11000, 1000)))
        )
        unistat.create_histogram(
            cls.DELAY_SIGNAL,
            intervals=list(chain(
                range(0, 60, 5),
                range(60, 120, 10),
                range(120, 300, 20),
                range(300, 3900, 200)))
        )

    def parse_requested_signals(self, signals):
        """
        Returns tuple of (parsed requests, invalid signals).

        Every signal must look like "host:tags:signal". Signals that fail to parse
        or miss required tags are reported in the second element, keyed by the
        original signal string, with the error text as value.

        :type signals (Sequence[str])
        :return Tuple[List[FullRequestKey], Dict[str, str]]
        :raise SignalLimitError: when more signals are requested than options.signals_limit
        """
        if len(signals) > self._options.signals_limit:
            self._unistat.push(self.FILTERED_SIGNALS_COUNT_SIGNAL, len(signals))
            raise SignalLimitError()
        self._unistat.push(self.REQUEST_SIZE_SIGNAL, len(signals))

        full_keys = set()  # remove possible duplicates
        invalid_signals = {}
        for signal in signals:
            try:
                try:
                    host, tag, sig = signal.split(":")
                except ValueError as e:
                    raise ValueError("Wrong request {}, {}".format(signal, str(e)))
                if not host:
                    raise ValueError("Group name is missing in {}".format(signal))
                if not sig:
                    raise ValueError("Signal name missing in {}".format(signal))

                # ensure we have a valid tag
                request_key = RequestKey.from_string(tag)

                required_tags = self._yasmconf.required_tags(request_key.itype)
                if not all(request_key.has_tag_bound(required_tag) for required_tag in required_tags):
                    raise ValueError("Request should contain tags: {}".format(", ".join(required_tags)))

                full_keys.add(FullRequestKey(host=host, request_key=request_key, signal_expr=sig, raw_tags=tag))

            except (ValueError, TagsError) as e:
                invalid_signals[signal] = str(e)

        return list(full_keys), invalid_signals

    def pack_result(self, proto_response, requested_keys, invalid_signals, without_messages, now,
                    request_errors=None):
        """
        Build the response dict mapping "host:tags:signal" to a list of point records.

        :param proto_response: gateway response with one timeseries per requested key
                               (same order), or None when there is no response
        :param requested_keys: sequence of FullRequestKey that were sent to the gateway
        :param invalid_signals: dict of rejected signal string -> error message
        :param without_messages: omit "errors" from records of requested (valid) signals
        :param now: current unix time, aligned to RT_PERIOD internally
        :param request_errors: messages attached to every requested signal when no usable
                               response is available; the passed list is not modified
        :return: dict of signal name -> list of records (see fill_timeseries_result)
        """
        result = {}
        # Work on a private copy: a message may be appended below and the
        # caller's list must stay untouched.
        request_errors = list(request_errors) if request_errors is not None else []

        adjusted_now = adjust_timestamp(now, RT_PERIOD)
        start_ts = adjusted_now - self._options.range_start_offset
        end_ts = adjusted_now - self._options.range_end_offset

        def make_empty_ts():
            # explicit floor division: the number of points must be an integer
            return TimeSeries([None] * ((end_ts - start_ts) // RT_PERIOD), start_ts, RT_PERIOD)

        # Series are matched to keys by position, so a response of the wrong size
        # cannot be used at all: report it and fall back to empty series.
        if proto_response is not None and len(proto_response.timeseries) != len(requested_keys):
            self._log.warning("Solomon gateway returned {} timeseries instead of {} requested".format(
                len(proto_response.timeseries),
                len(requested_keys)
            ))
            request_errors.append("Solomon gateway returned incorrect response")
            proto_response = None

        for i, key in enumerate(requested_keys):
            if proto_response is not None:
                proto_ts = proto_response.timeseries[i]
                if proto_ts.timestamps_millis:
                    # how far the last timestamp in the series lags behind the aligned "now"
                    last_ts_with_real_value = proto_ts.timestamps_millis[-1] // 1000
                    delay = adjusted_now - last_ts_with_real_value
                    self._unistat.push(self.DELAY_SIGNAL, delay)

                timeseries = parse_proto_timeseries_with_range(proto_ts, start_ts, end_ts, RT_PERIOD)
                errors = proto_ts.errors
            else:
                timeseries = make_empty_ts()
                errors = request_errors

            result_key = "{}:{}:{}".format(key.host, key.raw_tags, key.signal_expr)
            result[result_key] = fill_timeseries_result(timeseries, errors, without_messages)

        # treat invalid signals as signals with None values as described here
        # https://nda.ya.ru/3SzoEt
        for signal_name, msg in invalid_signals.items():
            result[signal_name] = fill_timeseries_result(
                make_empty_ts(),
                [msg],
                False  # invalid signals were always returned with error
            )

        return result

    def prepare_call(self, keys_to_request, now):
        """
        Build (without starting) the multi-attempt ReadData RPC for the given keys.

        One query is added per key. Every query covers the interval from
        (aligned now - range_start_offset) to aligned now, on an RT_PERIOD grid.

        :param keys_to_request: iterable of FullRequestKey
        :param now: current unix time
        :return: MultiAttemptRpc instance
        """
        et = adjust_timestamp(now, RT_PERIOD)
        st = et - self._options.range_start_offset
        proto_request = gateway_service_proto.ReadDataRequest()
        for key in keys_to_request:
            proto_query = proto_request.queries.add()
            proto_query.hosts = key.host
            proto_query.tags = key.request_key.to_dynamic_string()
            proto_query.expression = key.signal_expr
            proto_query.from_millis = st * 1000
            proto_query.to_millis = et * 1000
            proto_query.grid_millis = RT_PERIOD * 1000

        call = MultiAttemptRpc(
            ReadyHostAttemptSequence.make_random(
                self._cluster_provider.get_cluster_hosts(),
                SOLOMON_GATEWAY_CLUSTER_ATTEMPTS,
                SOLOMON_GATEWAY_IN_CLUSTER_REQUEST_ATTEMPTS
            ),
            gateway_service_grpc.YasmGatewayServiceStub,
            rpc_name="ReadData",
            request=proto_request,
            timeout=RT_REQUEST_TIMEOUT,
            client_id=self._cluster_provider.grpc_user_agent,
            unistat=self._unistat,
            log=self._log,
            signals_prefix="rt"
        )
        return call


class RtHandler(BaseGatewayHandler):
    """
    HTTP handler returning realtime values for the requested signals.

    Signals come either as repeated "signals" query arguments (GET) or as the
    "signals" field of the decoded body (POST). Passing the query argument
    without_messages=true drops the "errors" field from valid signals.

    Request looks like:

    '{
        "signals": [
            "ASEARCH:base_self:loadlog-success",
            "ASEARCH:base_self:passagelog-success"
        ]
    }'

    Response looks like:

    '{
        "ASEARCH:base_self:loadlog-success": [
            {
                "timestamp": 1360410495,
                "value": 26526382,
                "errors": []
            }, {
                "timestamp": 1360410500,
                "value": 26344397,
                "errors": [
                    "missing values for MAN1.0"
                ]
            },
            ...
        ],
        "ASEARCH:base_self:passagelog-success": [
            {
                "timestamp": 1360410495,
                "value": 1223378,
                "errors": [
                    "missing values for SAS1.1"
                ]
            }, {
                "timestamp": 1360410500,
                "value": 1224852,
                "errors": []
            },
            ...
        ]
    }'
    """

    FQDN_HTTP_HEADER = "X-Golovan-Front-Host"
    TIMESTAMP_HTTP_HEADER = "X-Golovan-Front-Host-Timestamp"
    HANDLE_STAT_NAME = "rt"
    ALLOWED_OPTIONS_HEADERS = ["X-Golovan-Rt-Request"]

    def initialize(self, cluster_provider, unistat, front_id, options, yasmconf):
        """Set up the base handler and create the request data processor."""
        super(RtHandler, self).initialize(unistat, front_id)
        self._data_processor = RtRequestDataProcessor(
            cluster_provider,
            yasmconf,
            options,
            self.unistat,
            self.log
        )

    @classmethod
    def prepare_unistat_signals(cls, unistat):
        """Register the base handler metrics and the data processor metrics."""
        super(RtHandler, cls).prepare_unistat_signals(unistat)
        RtRequestDataProcessor.prepare_unistat_signals(unistat)

    @gen.coroutine
    def handle(self, signals):
        """
        Common GET/POST entry point: answer 400 when no signals are given or
        when there are too many of them, otherwise respond with the rt() result.
        """
        if not signals:
            self.error("signals query parameter expected", status_code=400)
            return

        raw_flag = self.get_query_argument("without_messages", "")
        skip_messages = raw_flag.lower() == "true"
        try:
            payload = yield self.rt(signals, skip_messages)
            self.response(payload)
        except SignalLimitError:
            self.error("Too many signals", status_code=400)

    @gen.coroutine
    def handle_post(self, decoded_body):
        """Handle POST: signals are taken from the "signals" field of the body."""
        requested = decoded_body.get("signals")
        yield self.handle(requested)

    @gen.coroutine
    def handle_get(self):
        """Handle GET: signals are taken from the "signals" query arguments."""
        requested = self.get_query_arguments("signals")
        yield self.handle(requested)

    @gen.coroutine
    def rt(self, signals, without_messages=False):
        """
        Parse the signals, query the gateway for the valid ones and return the
        packed result dict. SignalLimitError from parsing propagates to the caller.
        """
        processor = self._data_processor
        valid_keys, rejected = processor.parse_requested_signals(signals)
        current_time = time.time()

        if not valid_keys:
            # nothing to ask the gateway about: only rejected signals get reported
            raise gen.Return(
                processor.pack_result(None, valid_keys, rejected, without_messages, current_time)
            )

        rpc = processor.prepare_call(valid_keys, current_time)
        try:
            reply = yield rpc.future()
            packed = processor.pack_result(reply, valid_keys, rejected, without_messages, current_time)
        except Exception:
            self.log.exception("Call to solomon gateway failed")
            # With no response every requested signal is packed as an empty
            # series carrying the error message below.
            packed = processor.pack_result(
                None, valid_keys, rejected, without_messages, current_time,
                request_errors=["No response from solomon gateway"]
            )

        raise gen.Return(packed)
