# coding: utf-8
import collections

import datetime
import inject
import six
from infra.yasm.yasmapi import GolovanRequest, Transport

from infra.awacs.proto import model_pb2
from awacs.lib.gutils import gevent_idle_iter
from awacs.model import cache
from awacs.wrappers.base import Holder
import awacs.wrappers.main as wrappers


# Number of seconds in one calendar day (no leap seconds).
DAY_IN_SECONDS = 60 * 60 * 24
# Report uuid requested for every balancer tagset; its values are also summed into namespace totals.
TOTAL_REPORT_UUID = 'service_total'


def datetime_to_timestamp(d):
    """Return the Unix timestamp of midnight of ``d``'s calendar date, taking that date as a UTC date.

    Only the date part of ``d`` is used: any time-of-day component (and tzinfo) is ignored.
    ``d`` may be a ``datetime.date`` or a ``datetime.datetime``.
    """
    unix_epoch = datetime.date(1970, 1, 1)
    elapsed = datetime.date.fromordinal(d.toordinal()) - unix_epoch
    return elapsed.days * DAY_IN_SECONDS


# Golovan tags selecting the signals of balancers in a single geo.
BalancerTagSet = collections.namedtuple('BalancerTagSet', 'itype ctype prj geo')
# Golovan tags selecting the signals of all balancers of a namespace: one itype, several ctypes/prjs, no geo.
NamespaceTagSet = collections.namedtuple('NamespaceTagSet', 'itype ctypes prjs')


class RpsInfo(object):
    """
    RPS statistics of a single signal over a time window.

    Attributes:
        max: maximum RPS value among the points (0.0 by default, i.e. when there is no data);
        avg: average RPS over the points;
        sum: sum of the RPS values of all points;
        len: number of points;
        by_hour: list of per-point RPS values (RpsStatisticsUpdater fills one point per hour),
            None by default.
    """
    __slots__ = ('max', 'avg', 'sum', 'len', 'by_hour')

    def __init__(self, max_=0.0, avg_=0.0, sum_=0.0, len_=0, by_hour=None):
        self.max = max_
        self.avg = avg_
        self.sum = sum_
        self.len = len_
        self.by_hour = by_hour

    def __str__(self):
        return 'RpsInfo(max={}, avg={}, sum={}, len={})'.format(self.max, self.avg, self.sum, self.len)

    # Use the same summary for repr(), so RpsInfo values are readable when a dict/list of them is printed or logged.
    __repr__ = __str__


class RpsStatisticsUpdater(object):
    """
    Collects daily RPS statistics of balancers, upstreams and namespaces from the Golovan
    ``balancer_report-report-<uuid>-requests_summ`` signals.
    """
    _cache = inject.attr(cache.IAwacsCache)  # type: cache.AwacsCache

    # Golovan aggregation period in seconds; every point is turned into RPS by dividing by it.
    _PERIOD = 3600
    # Maximum number of signals requested from Golovan in a single request.
    _MAX_SIGNALS_PER_REQUEST = 200

    @staticmethod
    def _get_signal_name_from_tagset_and_uuid(tagset, report_uuid):
        """
        Build the tagged Golovan signal name of the ``report_uuid`` report for ``tagset``.

        :type tagset: BalancerTagSet | NamespaceTagSet
        :type report_uuid: six.text_type
        :rtype: six.text_type
        :raises AssertionError: if ``tagset`` is neither a BalancerTagSet nor a NamespaceTagSet
        """
        signal_name = 'balancer_report-report-{}-requests_summ'.format(report_uuid)
        if isinstance(tagset, BalancerTagSet):
            return 'itype={};ctype={};prj={};geo={}:{}'.format(tagset.itype, tagset.ctype, tagset.prj,
                                                               tagset.geo, signal_name)
        if isinstance(tagset, NamespaceTagSet):
            # One namespace signal covers all ctypes/prjs of the namespace's balancers and every geo.
            return 'itype={};ctype={};prj={}:{}'.format(tagset.itype, ','.join(tagset.ctypes), ','.join(tagset.prjs),
                                                        signal_name)
        raise AssertionError('Unknown tagset type')

    @staticmethod
    def _get_rps_by_tags(tagsets_and_uuids, start_time, end_time):
        """
        Query Golovan once for a single batch of (tagset, report uuid) pairs.

        :type tagsets_and_uuids: list[(BalancerTagSet | NamespaceTagSet, six.text_type)]
        :type start_time: float
        :type end_time: float
        :return: RpsInfo for every input pair; pairs for which Golovan returned nothing get an empty ``RpsInfo()``
        :rtype: dict[(BalancerTagSet | NamespaceTagSet, six.text_type), RpsInfo]
        """
        period = RpsStatisticsUpdater._PERIOD
        key_by_signal_name = {
            RpsStatisticsUpdater._get_signal_name_from_tagset_and_uuid(tagset, report_uuid): (tagset, report_uuid)
            for tagset, report_uuid in tagsets_and_uuids
        }

        values_by_key = collections.defaultdict(list)
        for _, values in GolovanRequest(host='ASEARCH', period=period, st=int(start_time), et=int(end_time),
                                        fields=list(key_by_signal_name), load_segments=200,
                                        transport=Transport(connect_timeout=100)):
            for signal, val in six.iteritems(values):
                # Empty points (None/0) are recorded as zero requests.
                values_by_key[key_by_signal_name[signal]].append(val if val else 0)

        rv = {}
        for tagset, report_uuid in tagsets_and_uuids:
            key = (tagset, report_uuid)
            values = values_by_key.get(key)
            if not values:
                rv[key] = RpsInfo()
                continue
            # Each point is the number of requests in one period; convert it to requests per second.
            hourly_rps = [float(v) / period for v in values]
            total = sum(hourly_rps)
            rv[key] = RpsInfo(
                max_=max(hourly_rps),
                avg_=total / len(hourly_rps),
                sum_=total,
                len_=len(hourly_rps),
                by_hour=hourly_rps,
            )
        return rv

    @staticmethod
    def get_rps_by_tags(tagsets_and_uuids, start_time, end_time):
        """
        Fetch RPS statistics for any number of (tagset, report uuid) pairs, splitting them into
        Golovan requests of at most ``_MAX_SIGNALS_PER_REQUEST`` signals.

        :type tagsets_and_uuids: list[(BalancerTagSet | NamespaceTagSet, six.text_type)]
        :type start_time: float
        :type end_time: float
        :rtype: dict[(BalancerTagSet | NamespaceTagSet, six.text_type), RpsInfo]
        """
        limit = RpsStatisticsUpdater._MAX_SIGNALS_PER_REQUEST
        rv = {}
        for offset in range(0, len(tagsets_and_uuids), limit):
            batch = tagsets_and_uuids[offset:offset + limit]
            rv.update(RpsStatisticsUpdater._get_rps_by_tags(batch, start_time, end_time))
        return rv

    @staticmethod
    def extract_upstream_report_uuid(upstream_pb):
        """
        Return the report uuid used to look up the upstream's RPS signal, or None if it has none.

        * EASY_MODE2: the ``l7_upstream_macro`` monitoring uuid, falling back to the upstream id;
        * FULL_MODE: the uuid of the first ``report`` module with a non-empty uuid yielded by
          ``Holder.walk_chain()``;
        * any other mode: None.
        """
        yandex_balancer_pb = upstream_pb.spec.yandex_balancer
        if yandex_balancer_pb.mode == yandex_balancer_pb.EASY_MODE2:
            return yandex_balancer_pb.config.l7_upstream_macro.monitoring.uuid or upstream_pb.meta.id
        if yandex_balancer_pb.mode == yandex_balancer_pb.FULL_MODE:
            for module in Holder(yandex_balancer_pb.config).walk_chain():
                if isinstance(module, wrappers.Report) and module.pb.uuid:
                    return module.pb.uuid
        return None

    @staticmethod
    def _get_balancer_tagset(balancer_pb):
        """
        Build the BalancerTagSet of a balancer, or return None if itype/ctype/prj/geo is missing.

        The geo is lower-cased, and 'myt'/'iva' are mapped to 'msk'.
        """
        instance_tags = balancer_pb.spec.config_transport.nanny_static_file.instance_tags
        itype, ctype, prj = instance_tags.itype, instance_tags.ctype, instance_tags.prj
        geo = balancer_pb.meta.location.yp_cluster or balancer_pb.meta.location.gencfg_dc
        if not (itype and ctype and prj and geo):
            return None
        geo = geo.lower()
        if geo in ('myt', 'iva'):
            geo = 'msk'
        return BalancerTagSet(itype, ctype, prj, geo)

    @staticmethod
    def _get_namespace_tagset(balancer_tagsets):
        """
        Merge the balancer tagsets of one namespace into a NamespaceTagSet.

        Returns None unless all balancers share exactly one itype: the namespace signal name
        carries a single ``itype`` tag.
        """
        itypes = {tagset.itype for tagset in balancer_tagsets}
        if len(itypes) != 1:
            return None
        ctypes = {tagset.ctype for tagset in balancer_tagsets}
        prjs = {tagset.prj for tagset in balancer_tagsets}
        # Sorted so that the tagset, and hence the signal name, does not depend on set iteration order.
        return NamespaceTagSet(itypes.pop(), tuple(sorted(ctypes)), tuple(sorted(prjs)))

    @staticmethod
    def _sum_by_hour(series_iter):
        """
        Element-wise sum of several per-hour series.

        Series may differ in length: a missing point counts as zero, consistent with
        ``_get_rps_by_tags`` recording empty Golovan values as 0.
        """
        totals = []
        for series in series_iter:
            for i, value in enumerate(series):
                if i < len(totals):
                    totals[i] += value
                else:
                    totals.append(value)
        return totals

    @staticmethod
    def _fill_rps(rps_by_id_pb, full_ids, rps_info):
        """
        Copy max/average from ``rps_info`` into the entry of every id in ``full_ids``.

        When several ids share one signal, the entries are marked UPPER_BOUND; otherwise EXACT.
        """
        if len(full_ids) > 1:
            accuracy = model_pb2.LoadStatisticsEntry.Content.Rps.UPPER_BOUND
        else:
            accuracy = model_pb2.LoadStatisticsEntry.Content.Rps.EXACT
        for full_id in full_ids:
            entry_pb = rps_by_id_pb[full_id]
            entry_pb.max = rps_info.max
            entry_pb.average = rps_info.avg
            entry_pb.accuracy = accuracy

    def update_statistics(self, rps_by_balancer_pb, rps_by_namespace_pb, rps_by_upstream_pb):
        """
        Fill RPS statistics for the previous UTC day (from yesterday's to today's UTC midnight).

        * balancers ("<namespace_id>/<balancer_id>"): the ``service_total`` report of the balancer's tagset;
        * upstreams ("<namespace_id>/<upstream_id>"): the upstream's report over all balancers of its
          namespace; namespaces whose balancers do not share exactly one itype are skipped;
        * namespaces: max is the maximum over hours of the summed ``service_total`` RPS of all
          balancer tagsets, average is the sum of their averages.

        :type rps_by_balancer_pb: MutableMapping[str, model_pb2.LoadStatisticsEntry.Content]
        :type rps_by_namespace_pb: MutableMapping[str, model_pb2.LoadStatisticsEntry.Content]
        :type rps_by_upstream_pb: MutableMapping[str, model_pb2.LoadStatisticsEntry.Content]
        """
        utc_midnight = datetime.datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
        yesterday = utc_midnight - datetime.timedelta(days=1)
        start_time = datetime_to_timestamp(yesterday)
        end_time = datetime_to_timestamp(utc_midnight)

        balancer_ids_by_tagset = collections.defaultdict(list)
        upstream_ids_by_tagset_and_report_uuid = collections.defaultdict(list)
        balancer_tagsets_by_namespace_id = collections.defaultdict(set)
        upstream_ids_by_report_uuid_by_namespace_id = collections.defaultdict(lambda: collections.defaultdict(list))
        tagsets_and_uuids = set()

        # Pass 1: balancer tagsets and upstream report uuids of every namespace.
        for namespace_pb in gevent_idle_iter(self._cache.list_all_namespaces()):
            namespace_id = namespace_pb.meta.id
            for balancer_pb in self._cache.list_all_balancers(namespace_id=namespace_id):
                tagset = self._get_balancer_tagset(balancer_pb)
                if tagset is None:
                    continue
                full_balancer_id = '{}/{}'.format(balancer_pb.meta.namespace_id, balancer_pb.meta.id)
                balancer_ids_by_tagset[tagset].append(full_balancer_id)
                balancer_tagsets_by_namespace_id[namespace_id].add(tagset)
                tagsets_and_uuids.add((tagset, TOTAL_REPORT_UUID))
            for upstream_pb in gevent_idle_iter(self._cache.list_all_upstreams(namespace_id=namespace_id)):
                report_uuid = self.extract_upstream_report_uuid(upstream_pb)
                if report_uuid is not None:
                    full_upstream_id = '{}/{}'.format(namespace_id, upstream_pb.meta.id)
                    upstream_ids_by_report_uuid_by_namespace_id[namespace_id][report_uuid].append(full_upstream_id)

        # Pass 2: upstream reports are queried with the namespace-wide tagset of their namespace.
        for namespace_id, upstream_ids_by_report_uuid in gevent_idle_iter(
                six.iteritems(upstream_ids_by_report_uuid_by_namespace_id)):
            # .get() rather than [] so that no empty entry is inserted into the defaultdict.
            balancer_tagsets = balancer_tagsets_by_namespace_id.get(namespace_id)
            if not balancer_tagsets:
                continue
            tagset = self._get_namespace_tagset(balancer_tagsets)
            if tagset is None:
                continue
            for report_uuid, upstream_ids in gevent_idle_iter(six.iteritems(upstream_ids_by_report_uuid)):
                upstream_ids_by_tagset_and_report_uuid[(tagset, report_uuid)].extend(upstream_ids)
                tagsets_and_uuids.add((tagset, report_uuid))

        rps = self.get_rps_by_tags(list(tagsets_and_uuids), start_time=start_time, end_time=end_time)

        for (tagset, report_uuid), rps_info in six.iteritems(rps):
            if isinstance(tagset, BalancerTagSet):
                self._fill_rps(rps_by_balancer_pb, balancer_ids_by_tagset[tagset], rps_info)
            elif isinstance(tagset, NamespaceTagSet):
                self._fill_rps(rps_by_upstream_pb, upstream_ids_by_tagset_and_report_uuid[(tagset, report_uuid)],
                               rps_info)

        for namespace_id, balancer_tagsets in six.iteritems(balancer_tagsets_by_namespace_id):
            totals = [rps[(tagset, TOTAL_REPORT_UUID)] for tagset in balancer_tagsets]
            # Series of different lengths (e.g. a tagset without data has by_hour=None) are summed with
            # zero padding instead of aborting the whole update on an assertion.
            hourly_totals = self._sum_by_hour(info.by_hour or [] for info in totals)
            rps_by_namespace_pb[namespace_id].max = max([0] + hourly_totals)
            rps_by_namespace_pb[namespace_id].average = sum(info.avg for info in totals)
