import json
import time

from itertools import izip
from collections import defaultdict, namedtuple
from tornado import gen

from solomon.services.gateway.api import yasm_gateway_service_pb2 as gateway_service_proto
from solomon.services.gateway.api import yasm_gateway_service_pb2_grpc as gateway_service_grpc

from infra.yasm.gateway.lib.client.requester import MultiAttemptRpc, ReadyHostAttemptSequence
from infra.yasm.gateway.lib.client.requester import SOLOMON_GATEWAY_CLUSTER_ATTEMPTS
from infra.yasm.gateway.lib.client.requester import SOLOMON_GATEWAY_IN_CLUSTER_REQUEST_ATTEMPTS
from infra.yasm.gateway.lib.handlers.hist_base import HistFrontHandler
from infra.yasm.gateway.lib.tags.request import RequestKey
from infra.yasm.gateway.lib.tags.errors import TagsError
from infra.yasm.gateway.lib.util.common import MIN_PERIOD, adjust_timestamp, parse_proto_timeseries_with_range
from infra.yasm.gateway.lib.util import stat

# Formats the default context id "<host>:<st>_<et>_<period>".
get_id = '{}:{}_{}_{}'.format


def get_context_id(ctx):
    """
    Return the id of a request context.

    An explicit 'id' field wins; otherwise the id is built from the context's
    host, st, et and period as "<host>:<st>_<et>_<period>".

    :param ctx: context dict.
    :return: the context id.
    :raises KeyError: if ctx has no 'id' and lacks any of host/st/et/period.
    """
    if 'id' in ctx:
        return ctx['id']
    # Build the default lazily so an explicit id does not require the other fields.
    return get_id(ctx['host'], ctx['st'], ctx['et'], ctx['period'])


# Request limits; SeriesRequestDataProcessor reads these attributes via self._limits.
HistSeriesLimits = namedtuple(
    "HistSeriesLimits",
    "max_request_points_total max_request_points_signal max_signals_per_request",
)
# One accepted signal: the raw "<tags>:<expr>" string, its parsed RequestKey
# and the expression part after the ':'.
TaggedSignalExpr = namedtuple(
    "TaggedSignalExpr",
    "raw_tagged_signal_expr request_key signal_expr",
)
# Passed as `timeout` to the gateway ReadData call (units presumably seconds — confirm).
HIST_SERIES_REQUEST_TIMEOUT = 30


class ReqCtxWrapper(object):

    """
    Convenience wrapper which describes user request.

    Normalizes one validated context: period becomes an int, and st/et are
    converted to int and passed through adjust_timestamp with that period.
    Instances hash and compare by their context id.
    """

    def __init__(self, ctx, host_name=None):
        """
        :param ctx: validated context dict (host, period, st, et, signals, optional id).
        :param host_name: optional override for ctx['host']; the id is still
            derived from ctx itself via get_context_id.
        """
        self.host = host_name or ctx['host']
        self.period = int(ctx['period'])
        self.st = adjust_timestamp(int(ctx['st']), self.period)
        self.et = adjust_timestamp(int(ctx['et']), self.period)
        self.signals = ctx['signals']
        self.id = get_context_id(ctx)

    def __hash__(self):
        return hash(self.id)

    def __eq__(self, other):
        if not isinstance(other, ReqCtxWrapper):
            # Let Python fall back to its default comparison instead of
            # raising AttributeError on objects without `id`.
            return NotImplemented
        return self.id == other.id

    def __ne__(self, other):
        # Python 2 does not derive `!=` from `__eq__`; without this, two equal
        # wrappers would compare as "not equal" by identity.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return NotImplemented
        return not equal


class SeriesRequestDataProcessor(object):
    """
    Class for most of the request pre- and post-processing logic.

    Validates incoming ``ctxList`` contexts, turns the valid ones into a single
    solomon gateway ``ReadData`` call, and packs the gateway response (or an
    all-``None`` placeholder) into the result dict returned to the client.
    """
    PARSE_REQUEST_ERRORS_SIGNAL = "hist.series.parse_request_errors"
    VALIDATE_REQUEST_ERRORS_SIGNAL = "hist.series.validate_request_errors"

    def __init__(self, request_id, cluster_provider, unistat, limits, yasmconf, log):
        """
        :param request_id: id written into every packed result under 'request_id'.
        :param cluster_provider: source of gateway hosts and the grpc user agent.
        :param unistat: stats sink for error counters (also passed to the rpc).
        :param limits: object exposing the HistSeriesLimits fields.
        :param yasmconf: config queried for the tags required per itype.
        :param log: logger.
        """
        self._request_id = request_id
        self._cluster_provider = cluster_provider
        self._unistat = unistat
        self._yasmconf = yasmconf
        self._limits = limits
        self._log = log

    @classmethod
    def prepare_unistat_signals(cls, unistat):
        """Register the error counters pushed by check_srv_request."""
        unistat.create_float(cls.PARSE_REQUEST_ERRORS_SIGNAL)
        unistat.create_float(cls.VALIDATE_REQUEST_ERRORS_SIGNAL)

    def _parse_and_filter_valid_signals(self, signals):
        """
        Parse "<tags>:<signal expression>" strings and keep the valid ones.

        :param signals: iterable of raw tagged signal strings.
        :return: (set of TaggedSignalExpr, list of error messages for rejected signals).
        """
        ctx_errors = []
        valid_signals = set()

        for signal in signals:
            try:
                tag, signal_expr = signal.split(":")
            except ValueError:
                # zero or several ':' -> unpacking into two names fails
                ctx_errors.append(
                    "signal '{}' has to contain exactly one ':' symbol "
                    "(separator between tags and signal name)".format(signal))
                continue

            try:
                request_key = RequestKey.from_string(tag)
            except TagsError:
                ctx_errors.append("wrong tags format '{}'".format(tag))
                continue

            required_tags = self._yasmconf.required_tags(request_key.itype)
            if not all(request_key.has_tag_bound(required_tag) for required_tag in required_tags):
                ctx_errors.append("Request should contain tags: {}".format(", ".join(required_tags)))
                continue

            valid_signals.add(TaggedSignalExpr(signal, request_key, signal_expr))

        return valid_signals, ctx_errors

    def _ctx_contains_required_fields(self, ctx):
        """
        Copy the known request fields out of a raw context.

        :param ctx: raw context object from the request body.
        :return: (dict with required + present optional fields, None) on success,
            or (None, error message) if a required field is missing or ctx is not
            a mapping-like object.
        """
        result = {}

        try:
            for field in SeriesHandler.REQUIRED_REQUEST_FIELDS:
                if field not in ctx:
                    return None, "context JSON object have to contain '{}' field".format(field)
                result[field] = ctx[field]

            for field in SeriesHandler.OPTIONAL_REQUEST_FIELDS:
                if field in ctx:
                    result[field] = ctx[field]
        except TypeError:
            self._log.exception('Incorrect ctx structure')
            return None, "Incorrect ctx structure"

        return result, None

    def _check_ctx_limits(self, ctx):
        """
        Validate host, period and requested time range of one context.

        :param ctx: context dict that already contains all required fields.
        :return: (points to request per signal, None) if the context is
            acceptable, otherwise (None, error message). Rejections are logged.
        :raises ValueError: if period/st/et are not convertible to int.
        """
        host = ctx['host']
        if not isinstance(host, basestring):
            msg = "Incorrect host type [{}]. ".format(host)
            self._log.warning(msg)
            return None, msg

        if not host:
            msg = "Host can't be empty"
            self._log.warning(msg)
            return None, msg

        # Convert before using '%': a period coming from GET query args is a
        # string, and "5" % MIN_PERIOD would be string formatting, not modulo.
        period = int(ctx['period'])
        if period < MIN_PERIOD or period % MIN_PERIOD != 0:
            msg = "Unsupported period {}".format(ctx['period'])
            self._log.warning(msg)
            return None, msg

        # Points to request for each signal. The period was validated above
        # (presumably MIN_PERIOD > 0), so a zero period no longer divides by zero.
        pts_to_request = (int(ctx['et']) - int(ctx['st'])) // period + 1

        if pts_to_request <= 0:
            msg = "Can't request negative or zero points"
            self._log.warning(msg)
            return None, msg

        max_pts_per_signal = self._limits.max_request_points_signal
        if pts_to_request > max_pts_per_signal:
            self._log.warning('Too many points per signal requested: %s requested > %s maximum per signal',
                              pts_to_request, max_pts_per_signal)
            return None, 'Pts per signal limit exceeded {} > {}'.format(pts_to_request, max_pts_per_signal)

        return pts_to_request, None

    def check_srv_request(self, request_data):
        """
        Validate a hist series request and keep only the contexts that can be served.

        Request format::

            {"ctxList":[{"name":"get","id":"SEARCH38:1360221625_1360221660","host":"SEARCH38",
                         "st":1360221625,"et":1360221660,"signals":["common_self:ipmi-overheat", ...],
                         "period":5, "single_replica":0, "allow_legacy_types": true}]}

        The name, id, single_replica and allow_legacy_types fields are not required.

        :param request_data: decoded request body.
        :return: (filtered_request, errors). filtered_request is a list of context
            dicts whose 'signals' were replaced by a set of TaggedSignalExpr; it is
            emptied when the total points limit is exceeded. errors is a
            defaultdict(list): {"<context_id>": [errors for that context],
            "common": [errors not tied to a specific context]}.
        """
        filtered_request = []
        errors = defaultdict(list)
        total_pts_requested = 0

        ctx_list = []
        try:
            ctx_list = list(request_data['ctxList'])
        except Exception:
            self._log.exception("No proper ctx list in request")
            errors['common'].append("No proper ctx list in request")

        all_ctx_ids = set()
        for raw_ctx in ctx_list:
            try:
                ctx, error = self._ctx_contains_required_fields(raw_ctx)
                if ctx is None:
                    # append, not assign: several malformed contexts all report under 'common'
                    errors['common'].append(error)
                    continue

                ctx_id = get_context_id(ctx)
                if ctx_id in all_ctx_ids:
                    msg = "Duplicated context id {}.".format(ctx_id)
                    self._log.warning(msg)
                    errors[ctx_id].append(msg)
                    continue
                all_ctx_ids.add(ctx_id)

                sig_num, max_signals = len(ctx['signals']), self._limits.max_signals_per_request
                if sig_num > max_signals:
                    msg = 'Request contains %s > %s signals, skipped'
                    errors[ctx_id].append(msg % (sig_num, max_signals))
                    self._log.warning(msg, sig_num, max_signals)
                    continue

                valid_signals, ctx_errors = self._parse_and_filter_valid_signals(ctx['signals'])

                pts_to_request, error = self._check_ctx_limits(ctx)
                if error is not None:
                    errors[ctx_id].append(error)
                    continue

                total_pts_requested += len(valid_signals) * pts_to_request
                ctx['signals'] = valid_signals
                filtered_request.append(ctx)

                if ctx_errors:
                    errors[ctx_id].extend(ctx_errors)
            except Exception:
                # Top-level guard: a malformed context must not break the others.
                self._log.exception('Error parsing request')
                self._unistat.push(self.PARSE_REQUEST_ERRORS_SIGNAL, 1)

        max_pts_total = self._limits.max_request_points_total
        if total_pts_requested > max_pts_total:
            self._log.warning('Too many total points requested: %s > %s', total_pts_requested, max_pts_total)
            errors['common'].append('Total pts limit exceeded: {} > {}'.format(total_pts_requested, max_pts_total))
            filtered_request = []

        for ctx_errors in errors.itervalues():
            self._unistat.push(self.VALIDATE_REQUEST_ERRORS_SIGNAL, len(ctx_errors))

        return filtered_request, errors

    def prepare_call(self, raw_contexts):
        """
        Build the gateway ReadData rpc for the validated contexts.

        Adds one query per (context, signal) pair. The time range runs from st
        to et + period, in milliseconds.

        :param raw_contexts: contexts returned by check_srv_request.
        :return: (MultiAttemptRpc call, list of (ReqCtxWrapper, TaggedSignalExpr)
            in the same order as the queries).
        """
        requested_context_signals = []
        proto_request = gateway_service_proto.ReadDataRequest()

        for raw_ctx in raw_contexts:
            ctx = ReqCtxWrapper(raw_ctx)
            for tagged_signal_expr in ctx.signals:
                proto_query = proto_request.queries.add()
                proto_query.hosts = ctx.host
                proto_query.tags = tagged_signal_expr.request_key.to_dynamic_string()
                proto_query.expression = tagged_signal_expr.signal_expr
                proto_query.from_millis = ctx.st * 1000
                proto_query.to_millis = (ctx.et + ctx.period) * 1000
                proto_query.grid_millis = ctx.period * 1000
                requested_context_signals.append((ctx, tagged_signal_expr))

        call = MultiAttemptRpc(
            ReadyHostAttemptSequence.make_random(
                self._cluster_provider.get_cluster_hosts(),
                SOLOMON_GATEWAY_CLUSTER_ATTEMPTS,
                SOLOMON_GATEWAY_IN_CLUSTER_REQUEST_ATTEMPTS
            ),
            gateway_service_grpc.YasmGatewayServiceStub,
            rpc_name="ReadData",
            request=proto_request,
            timeout=HIST_SERIES_REQUEST_TIMEOUT,
            client_id=self._cluster_provider.grpc_user_agent,
            unistat=self._unistat,
            log=self._log,
            signals_prefix="hist"
        )

        return call, requested_context_signals

    @staticmethod
    def _add_response_signal_values(ctx, raw_tagged_signal, signal_values, result):
        """
        Store one signal's values under result[ctx.id]["content"]["values"].

        The context's timeline (st .. et inclusive, step period) is created
        when the context is first seen.
        """
        if ctx.id not in result:
            result[ctx.id] = {
                "content": {
                    "timeline": list(xrange(ctx.st, ctx.et + ctx.period, ctx.period)),
                    "values": {
                        raw_tagged_signal: signal_values
                    }
                }
            }
        else:
            result[ctx.id]["content"]["values"][raw_tagged_signal] = signal_values

    def pack_result(self, proto_response, requested_context_signals, request_errors):
        """
        Build the client result from a gateway ReadData response.

        :param proto_response: gateway response; its timeseries are matched
            positionally with requested_context_signals.
        :param requested_context_signals: list returned by prepare_call.
        :param request_errors: errors dict from check_srv_request. It is extended
            in place with the per-timeseries errors.
        :return: result dict with per-context content, 'errors' and 'request_id'.
        """
        if len(proto_response.timeseries) != len(requested_context_signals):
            # izip stops at the shorter sequence, so unmatched signals would vanish silently.
            self._log.warning('Gateway returned %s timeseries for %s requested signals',
                              len(proto_response.timeseries), len(requested_context_signals))

        result = {}
        for (ctx, signal), proto_ts in izip(requested_context_signals, proto_response.timeseries):
            timeseries = parse_proto_timeseries_with_range(proto_ts, ctx.st, ctx.et + ctx.period, ctx.period)
            self._add_response_signal_values(ctx, signal.raw_tagged_signal_expr, timeseries.values, result)
            request_errors[ctx.id].extend(proto_ts.errors)
        self._add_errors_to_result(request_errors, result)
        result['request_id'] = self._request_id
        return result

    def pack_empty_result(self, requested_context_signals, request_errors):
        """
        Build a client result without gateway data.

        Every requested signal gets a list of None, one per timeline point.

        :param requested_context_signals: list of (ReqCtxWrapper, TaggedSignalExpr); may be empty.
        :param request_errors: errors dict to report.
        :return: result dict with per-context content, 'errors' and 'request_id'.
        """
        result = {}
        for ctx, signal in requested_context_signals:
            # same length as the timeline built in _add_response_signal_values
            points_count = (ctx.et + ctx.period - ctx.st) // ctx.period
            self._add_response_signal_values(
                ctx,
                signal.raw_tagged_signal_expr,
                [None] * points_count,
                result)
        self._add_errors_to_result(request_errors, result)
        result['request_id'] = self._request_id
        return result

    def _add_errors_to_result(self, request_errors, result):
        """Store the sorted, de-duplicated union of all error messages as result['errors']."""
        result['errors'] = sorted(set().union(*request_errors.values()))


class SeriesHandler(HistFrontHandler):
    """
    Handler for hist series requests.

    POST bodies carry a full ``ctxList``. GET builds a single context from
    query args, which is mainly useful for debugging.
    """
    REQUIRED_REQUEST_FIELDS = ["host", "signals", "period", "st", "et"]
    OPTIONAL_REQUEST_FIELDS = ["id"]

    HIST_TOTAL_TIME_SIGNAL = "hist.series.hist_total_time"

    HANDLE_STAT_NAME = "series"

    def initialize(self, cluster_provider, unistat, front_id, limits, yasmconf):
        super(SeriesHandler, self).initialize(unistat, front_id)
        self._cluster_provider = cluster_provider
        self._yasmconf = yasmconf
        self._limits = limits
        self._data_processor = SeriesRequestDataProcessor(self.request_id, self._cluster_provider, self.unistat,
                                                          self._limits, self._yasmconf, self.log)

    @classmethod
    def prepare_unistat_signals(cls, unistat):
        super(SeriesHandler, cls).prepare_unistat_signals(unistat)
        unistat.create_histogram(cls.HIST_TOTAL_TIME_SIGNAL, intervals=stat.TIMING_INTERVALS_LOG_1_3)
        SeriesRequestDataProcessor.prepare_unistat_signals(unistat)

    @gen.coroutine
    def _process_request(self, request_data):
        """
        Validate the request, query the solomon gateway and pack the result.

        If the gateway call fails, or nothing in the request is valid, the
        result contains only placeholders and the collected errors.
        """
        started = time.time()

        self.log.info('start processing hist request')
        self._log_request_info(request_data)

        with stat.time_it(self.HIST_TOTAL_TIME_SIGNAL, self.unistat):
            verified_request, request_errors = self._data_processor.check_srv_request(request_data)
            if verified_request:
                call, requested_context_signals = self._data_processor.prepare_call(verified_request)
                try:
                    proto_response = yield call.future()
                    result = self._data_processor.pack_result(proto_response, requested_context_signals, request_errors)
                except Exception:
                    self.log.exception("Call to solomon gateway failed")
                    request_errors['common'].append("call to solomon gateway failed")
                    result = self._data_processor.pack_empty_result(requested_context_signals, request_errors)
            else:
                # No call was prepared, so there are no requested signals: report
                # only the validation errors. (Previously this referenced an
                # unassigned variable and raised UnboundLocalError.)
                result = self._data_processor.pack_empty_result([], request_errors)

        took = time.time() - started
        self.log.info('end processing hist request, took %g seconds', took)

        raise gen.Return(result)

    def _log_request_info(self, request_data):
        """Log the request body together with the non-empty caller identification headers/cookie."""
        extra_info = {
            'referrer': self.request.headers.get(self.REFERRER_HEADER),
            'golovan_hostname': self.request.headers.get(self.YASMAPI_HOSTNAME_HEADER),
            'golovan_username': self.request.headers.get(self.YASMAPI_USERNAME_HEADER),
            'yandex_login': self.get_cookie(self.UI_USERNAME_COOKIE)
        }

        request_info = {
            'data': request_data,
            'info': {
                k: v
                for k, v in extra_info.iteritems()
                if v is not None
            }
        }

        self.log.info('request: %s', json.dumps(request_info))

    @gen.coroutine
    def handle_get(self):
        """Serve a single-context request built from query args; signals are comma-separated."""
        query_args = self.get_query_args(SeriesHandler.REQUIRED_REQUEST_FIELDS)

        signals = [signal.strip() for signal in query_args.get("signals", '')
                                                .strip()
                                                .split(',')]

        query_args['signals'] = signals
        request_body = {'ctxList': [query_args]}

        result = yield self._process_request(request_body)
        self.response(result)

    @gen.coroutine
    def handle_post(self, request):
        """Serve a request whose decoded body is passed in as ``request``."""
        result = yield self._process_request(request)
        self.response(result)
