from __future__ import unicode_literals
import logging
import collections

import gevent
import yp.data_model as data_model
from sepelib.gevent import greenthread

from infra.swatlib.gevent import geventutil as gutil
from infra.rsc.src.lib import podutil
from infra.controllers_monitoring.src import consts
from infra.controllers_monitoring.src.collectors import mcrs as mcrs_collector
from infra.controllers_monitoring.src.collectors import rs as rs_collector


class ClusterPollResult(object):
    """Data gathered from a single cluster during one YpPoller iteration.

    Attributes:
      rs_dict -- ``defaultdict(rs_collector.Rs)`` keyed by id; the poller
          fills each entry's ``.rs`` (keyed by replica set id) and appends to
          its ``.pods`` (keyed by pod set id) for RSC-managed pods.
      mcrs_pod_dict -- ``defaultdict(list)`` mapping pod set id to the pods
          whose ``deploy_engine`` label is the MCRSC deploy engine.
      unready_pods_over_budget_rs_dict -- ``None`` until the poller stores the
          return value of ``rs_collector.collect_metrics`` here.
    """

    def __init__(self):
        rs_factory = rs_collector.Rs
        self.rs_dict = collections.defaultdict(rs_factory)
        self.mcrs_pod_dict = collections.defaultdict(list)
        self.unready_pods_over_budget_rs_dict = None


class YpPoller(greenthread.GreenThread):
    """Greenthread that periodically polls YP clusters and exports replica set
    (RS) and multi-cluster replica set (MCRS) metrics.

    Each iteration (see :meth:`poll`):
      * polls every cluster in ``client_map`` concurrently, one greenlet per
        cluster (:meth:`poll_cluster`), collecting pods and replica sets and
        reporting per-cluster RS metrics;
      * reports deploy unit budget metrics over the successfully polled
        clusters;
      * selects MCRS objects from ``xdc_cluster`` and reports MCRS metrics
        using the per-cluster pods gathered above.
    """
    name = 'poller'

    # Pause between polling iterations, in seconds (passed to gevent.sleep).
    ITERATION_DELAY = 20

    def __init__(self, metrics_registry, client_map, xdc_cluster,
                 batch_size, pod_match_dict, rs_match_dict, yasm_export_mcrs_ids=None):
        """
        :param metrics_registry: registry handed to the rs/mcrs collectors.
        :param client_map: mapping of cluster name -> YP client; every key is
            polled on each iteration.
        :param xdc_cluster: key of ``client_map`` whose client is used to
            select multi-cluster replica sets.
        :param batch_size: batch size for every ``select_all_objects`` call.
        :param pod_match_dict: dict converted into the pod filter, which is
            further restricted to pods labelled with the RSC or MCRSC deploy
            engine.
        :param rs_match_dict: dict converted into the replica set filter; the
            same filter is reused for multi-cluster replica sets.
        :param yasm_export_mcrs_ids: optional iterable of MCRS ids, stored as a
            set and forwarded to the collectors.
        """
        super(YpPoller, self).__init__()
        self.log = logging.getLogger(self.name)
        self.metrics_registry = metrics_registry
        self.client_map = client_map
        self.xdc_cluster = xdc_cluster
        self.batch_size = batch_size
        self.rs_filter = podutil.make_filter_from_dict(rs_match_dict)
        self.yasm_export_mcrs_ids = set(yasm_export_mcrs_ids) if yasm_export_mcrs_ids else set()
        self.mcrs_filter = self.rs_filter
        # Only pods managed by one of the two replica set controllers are of
        # interest; poll_cluster dispatches on this same label.
        self.pod_filter = (
            '{} AND '
            '([/labels/deploy_engine] = "{}" OR '
            '[/labels/deploy_engine] = "{}")'.format(
                podutil.make_filter_from_dict(pod_match_dict),
                consts.RSC_DEPLOY_ENGINE,
                consts.MCRSC_DEPLOY_ENGINE
            )
        )
        self.pod_selectors = [
            "/meta/id",
            "/meta/pod_set_id",
            "/spec/pod_agent_payload/spec/revision",
            "/labels/deploy_engine",
        ]
        self.rs_selectors = [
            "/meta/id",
            "/spec/replica_count",
            "/spec/revision_id",
            "/spec/deployment_strategy/max_unavailable",
            "/status"
        ]
        self.mcrs_selectors = [
            "/meta/id",
            "/spec/clusters",
            "/spec/revision",
            "/spec/deployment_strategy/max_unavailable",
            "/status"
        ]

    def poll_cluster(self, cluster, cluster_results, yasm_export_mcrs_ids):
        """Poll one cluster and store a ClusterPollResult in
        ``cluster_results[cluster]``.

        Pods and replica sets are both selected at the same YP timestamp.
        Pods labelled with the RSC deploy engine are appended to the
        ``rs_collector.Rs`` entry keyed by their pod set id (presumably equal
        to the replica set id -- confirm). Pods labelled with the MCRSC deploy
        engine are grouped by pod set id in ``mcrs_pod_dict``.

        ``cluster_results`` is written only after everything succeeded; any
        exception propagates (poll() runs this in a greenlet and treats a
        raised exception as the cluster being unavailable).
        """
        client = self.client_map[cluster]
        ts = client.generate_timestamp()
        pod_g = client.select_all_objects(obj_type=data_model.OT_POD,
                                          obj_class=data_model.TPod,
                                          query=self.pod_filter,
                                          batch_size=self.batch_size,
                                          timestamp=ts,
                                          selectors=self.pod_selectors)
        rs_g = client.select_all_objects(obj_type=data_model.OT_REPLICA_SET,
                                         obj_class=data_model.TReplicaSet,
                                         query=self.rs_filter,
                                         batch_size=self.batch_size,
                                         timestamp=ts,
                                         selectors=self.rs_selectors)
        rv = ClusterPollResult()

        for rs in gutil.gevent_idle_iter(rs_g):
            rv.rs_dict[rs.meta.id].rs = rs

        for p in gutil.gevent_idle_iter(pod_g):
            de = podutil.get_label(p.labels, 'deploy_engine')
            ps_id = p.meta.pod_set_id
            if de == consts.RSC_DEPLOY_ENGINE:
                # NOTE(review): rs_dict is a defaultdict, so a pod whose pod set
                # has no matching replica set (e.g. filtered out by rs_filter)
                # creates an Rs entry without ``.rs`` being assigned here --
                # confirm rs_collector handles that.
                rv.rs_dict[ps_id].pods.append(p)
            elif de == consts.MCRSC_DEPLOY_ENGINE:
                rv.mcrs_pod_dict[ps_id].append(p)

        rv.unready_pods_over_budget_rs_dict = rs_collector.collect_metrics(
            cluster=cluster,
            registry=self.metrics_registry,
            rs_list=rv.rs_dict.itervalues(),
            yasm_export_mcrs_ids=yasm_export_mcrs_ids
        )
        cluster_results[cluster] = rv

    def poll(self):
        """Run one polling iteration over all clusters and export metrics.

        Clusters whose poll greenlet raised are logged and reported to the
        MCRS collector as ``unavailable_clusters``; they have no entry in
        ``cluster_results``. Errors while selecting MCRS objects from
        ``xdc_cluster`` propagate to the caller.
        """
        cluster_results = {}
        greenlets = {}
        for c in self.client_map:
            greenlets[c] = gevent.Greenlet.spawn(self.poll_cluster, c, cluster_results, self.yasm_export_mcrs_ids)

        unavailable_clusters = set()
        # NOTE(review): joinall has no timeout, so a single hung cluster
        # blocks the whole iteration.
        gevent.joinall(greenlets.values(), raise_error=False)
        for cluster, greenlet in greenlets.iteritems():
            if greenlet.exception is not None:
                unavailable_clusters.add(cluster)
                self.log.error('Poll cluster "%s" failed with exception: \n%s', cluster, greenlet.exception)

        rs_collector.set_deploy_units_budget_metrics(cluster_results=cluster_results, registry=self.metrics_registry)
        client = self.client_map[self.xdc_cluster]
        ts = client.generate_timestamp()
        mcrs_g = client.select_all_objects(
            obj_type=data_model.OT_MULTI_CLUSTER_REPLICA_SET,
            obj_class=data_model.TMultiClusterReplicaSet,
            query=self.mcrs_filter,
            batch_size=self.batch_size,
            timestamp=ts,
            selectors=self.mcrs_selectors
        )
        mcrs_list = []
        for yp_mcrs in gutil.gevent_idle_iter(mcrs_g):
            mcrs_id = yp_mcrs.meta.id
            mcrs = mcrs_collector.Mcrs()
            mcrs.mcrs = yp_mcrs
            # Only successfully polled clusters are present in cluster_results;
            # failed ones are conveyed via unavailable_clusters instead.
            for c, r in cluster_results.iteritems():
                mcrs.pods[c] = r.mcrs_pod_dict[mcrs_id]
            mcrs_list.append(mcrs)
        mcrs_collector.collect_metrics(xdc_cluster=self.xdc_cluster,
                                       registry=self.metrics_registry,
                                       mcrs_list=mcrs_list,
                                       yasm_export_mcrs_ids=self.yasm_export_mcrs_ids,
                                       unavailable_clusters=unavailable_clusters,
                                       clusters_list=list(self.client_map.keys()))

    def run(self):
        """Poll forever, sleeping ITERATION_DELAY seconds between iterations.

        Any Exception raised by an iteration is logged with its traceback and
        the loop continues.
        """
        while True:
            self.log.info('Starting polling all')
            try:
                self.poll()
            except Exception as e:
                self.log.error('Iteration failed: %s', e, exc_info=True)
            gevent.sleep(self.ITERATION_DELAY)
