import time
import typing
import logging
from itertools import chain
from functools import partial
from collections import OrderedDict, defaultdict, namedtuple

from gevent.pool import Pool as GeventPool

import yt.yson as yson
from yp.client import YpClient as BaseYpClient
from yp_proto.yp.client.api.proto import autogen_pb2, object_service_pb2

from infra.deploy_export_stats.src.libs.metrics import ROOT_REGISTRY
from infra.deploy_export_stats.src.reporters.base import BaseReporter

# Module-level logger; everything logged through ``log`` goes to the 'yp' logger.
log = logging.getLogger('yp')


def yson_loads(data, default=None):
    """Parse a YSON payload, substituting ``default`` for a YSON entity.

    :param data: serialized YSON, passed straight to ``yson.loads``.
    :param default: value returned when the parsed result is a ``yson.YsonEntity``.
    :return: the parsed value, or ``default`` if parsing produced a ``yson.YsonEntity``.
    """
    parsed = yson.loads(data)
    return default if isinstance(parsed, yson.YsonEntity) else parsed


class YpObject(object):
    """Base class for typed views over rows selected from YP.

    Subclasses set ``TYPE`` (used as the object type of the select request) and
    ``SELECTORS`` (an ordered mapping whose keys are used as selector paths and
    whose values are ``(attribute name, parser)`` pairs).
    """
    TYPE = None
    SELECTORS = None

    class QueryMaker(object):
        """Builds the filter for the next page from the greatest id seen so far."""
        __slots__ = ('base_query', '_id')

        def __init__(self, base_query='true'):
            self._id = None
            self.base_query = base_query

        def update(self, obj):
            """Remember ``obj.obj_id`` when it is greater than the stored id."""
            last_id = self._id
            if last_id and not (last_id < obj.obj_id):
                return
            self._id = obj.obj_id

        def make_query(self):
            """Return the base query, narrowed to ids after the stored one if any."""
            last_id = self._id
            if last_id:
                return '({}) AND [/meta/id] > "{}"'.format(self.base_query, last_id)
            return self.base_query


class Stage(YpObject):
    """Stage row exposing the object id and the ``deploy_engine`` label."""
    TYPE = autogen_pb2.OT_STAGE
    SELECTORS = OrderedDict((
        ('/meta/id', ('obj_id', yson.loads)),
        ('/labels/deploy_engine', ('deploy_engine', yson.loads)),
    ))

    __slots__ = ['obj_id', 'deploy_engine']

    def __init__(self, data):
        """Parse ``data`` positionally: ``data[i]`` is the value of the i-th selector."""
        for position, selector in enumerate(self.SELECTORS.values()):
            attr_name, parse = selector
            setattr(self, attr_name, parse(data[position]))


class ReplicaSet(YpObject):
    """Replica set row exposing the object id and the ``deploy_engine`` label."""
    TYPE = autogen_pb2.OT_REPLICA_SET

    SELECTORS = OrderedDict((
        ('/meta/id', ('obj_id', yson.loads)),
        ('/labels/deploy_engine', ('deploy_engine', yson.loads)),
    ))

    __slots__ = ['obj_id', 'deploy_engine']

    def __init__(self, data):
        """Parse ``data`` positionally: ``data[i]`` is the value of the i-th selector."""
        for column, spec in enumerate(self.SELECTORS.values()):
            name, loader = spec
            raw_value = data[column]
            setattr(self, name, loader(raw_value))


class MultiClusterReplicaSet(YpObject):
    """Multi-cluster replica set row exposing the id and the ``deploy_engine`` label."""
    TYPE = autogen_pb2.OT_MULTI_CLUSTER_REPLICA_SET

    SELECTORS = OrderedDict((
        ('/meta/id', ('obj_id', yson.loads)),
        ('/labels/deploy_engine', ('deploy_engine', yson.loads)),
    ))

    __slots__ = ['obj_id', 'deploy_engine']

    def __init__(self, data):
        """Parse ``data`` positionally: ``data[i]`` is the value of the i-th selector."""
        for idx, field in enumerate(self.SELECTORS.values()):
            field_name, decode = field
            setattr(self, field_name, decode(data[idx]))


class PodSet(YpObject):
    """Pod set row: id, node segment and the ``deploy_engine`` label.

    After parsing, ``segment`` is normalised with :meth:`prepare_segment`.
    """
    # by locations
    TYPE = autogen_pb2.OT_POD_SET

    SELECTORS = OrderedDict([
        ('/meta/id', ('obj_id', yson.loads)),
        ('/spec/node_segment_id', ('segment', yson.loads)),
        ('/labels/deploy_engine', ('deploy_engine', yson.loads)),
    ])

    __slots__ = ['obj_id', 'deploy_engine', 'segment']

    def __init__(self, data):
        """Parse ``data`` positionally: ``data[i]`` is the value of the i-th selector."""
        for i, (attr_name, parser) in enumerate(self.SELECTORS.values()):
            setattr(self, attr_name, parser(data[i]))

        self.segment = self.prepare_segment(self.segment)

    def prepare_segment(self, segment):
        """Collapse every segment whose name starts with ``yt_`` into ``yt_any``."""
        if segment.startswith("yt_"):
            return "yt_any"
        return segment

    def __repr__(self):
        # The label used to read "Pod(", which made pod sets indistinguishable
        # from pods in logs and debuggers.
        return 'PodSet(obj_id={}, {}, {})'.format(
            self.obj_id, self.deploy_engine, self.segment
        )


class Pod(YpObject):
    """Pod row with its resource requests and derived per-storage-class disk totals.

    ``segment`` is not selected from YP: it starts as ``None`` and is expected to
    be assigned by the caller. ``ssd_bytes`` and ``hdd_bytes`` are computed from
    ``disk_volume_requests``.
    """
    # by locations

    TYPE = autogen_pb2.OT_POD

    # NOTE: the ``list()`` default below is created once and shared by every pod
    # whose disk volume requests parse to a YSON entity; treat it as read-only.
    SELECTORS = OrderedDict([
        ('/meta/pod_set_id', ('pod_set_id', yson.loads)),
        ('/meta/id', ('obj_id', yson.loads)),
        ('/status/scheduling/state', ('scheduling_state', yson.loads)),
        ('/labels/deploy_engine', ('deploy_engine', yson.loads)),
        ('/labels/vmagent_version', ('vmagent_version', yson.loads)),
        ('/spec/resource_requests/vcpu_guarantee', ('vcpu_guarantee', yson.loads)),
        ('/spec/resource_requests/memory_limit', ('memory_limit', partial(yson_loads, default=0))),
        ('/spec/disk_volume_requests', ('disk_volume_requests', partial(yson_loads, default=list()))),
    ])

    __slots__ = ['pod_set_id',
                 'obj_id',
                 'deploy_engine',
                 'scheduling_state',
                 'vcpu_guarantee',
                 'memory_limit',
                 'disk_volume_requests',
                 'segment',
                 'vmagent_version',
                 'hdd_bytes',
                 'ssd_bytes',
                 ]

    def __init__(self, data):
        """Parse ``data`` positionally: ``data[i]`` is the value of the i-th selector."""
        for i, (attr_name, parser) in enumerate(self.SELECTORS.values()):
            setattr(self, attr_name, parser(data[i]))

        # ``segment`` has no selector; initialise the slot so that ``__repr__``
        # (and any other reader) does not raise AttributeError before a caller
        # assigns the real value.
        self.segment = None
        self.ssd_bytes = 0
        self.hdd_bytes = 0

        # A missing label is reported as 'n-a' for QYP pods and as None otherwise.
        if isinstance(self.vmagent_version, yson.YsonEntity):
            self.vmagent_version = 'n-a' if self.deploy_engine == "QYP" else None

        # Sum requested capacity per storage class; other classes are ignored.
        for volume in self.disk_volume_requests:
            storage_class = volume.get('storage_class')
            if storage_class == 'ssd':
                self.ssd_bytes += volume['quota_policy']['capacity']
            elif storage_class == 'hdd':
                self.hdd_bytes += volume['quota_policy']['capacity']

    def __repr__(self):
        return 'Pod(pod_id={}, {}, {})'.format(
            self.obj_id, self.deploy_engine, self.segment
        )

    class QueryMaker(YpObject.QueryMaker):
        """Paginates pods by the ``(pod_set_id, id)`` pair instead of the id alone."""
        # ``base_query`` is already a slot of the parent class; declaring it again
        # would shadow the parent's descriptor, which the data model leaves undefined.
        __slots__ = ('_id_tuple',)

        ID_TEMPLATE = '({}) AND ([/meta/pod_set_id], [/meta/id]) > ("{}", "{}")'

        def __init__(self, base_query='true'):
            self.base_query = base_query
            self._id_tuple = None

        def update(self, obj):  # type: (Pod) -> None
            """Remember the ``(pod_set_id, obj_id)`` pair if it is the greatest seen."""
            id_tuple = obj.pod_set_id, obj.obj_id
            if not self._id_tuple or self._id_tuple < id_tuple:
                self._id_tuple = id_tuple

        def make_query(self):
            """Return the base query, narrowed to pairs after the stored one if any."""
            if not self._id_tuple:
                return self.base_query
            return self.ID_TEMPLATE.format(self.base_query, *self._id_tuple)


class Node(YpObject):
    """Node row: id, ``segment`` label and the migration source label."""
    TYPE = autogen_pb2.OT_NODE
    SELECTORS = OrderedDict((
        ('/meta/id', ('obj_id', yson.loads)),
        ('/labels/segment', ('segment', yson.loads)),
        ('/labels/extras/migration/source', ('migration_source', yson.loads)),
    ))

    __slots__ = ['obj_id', 'segment', 'migration_source']

    def __init__(self, data):
        """Parse ``data`` positionally: ``data[i]`` is the value of the i-th selector."""
        for offset, entry in enumerate(self.SELECTORS.values()):
            target_attr, convert = entry
            setattr(self, target_attr, convert(data[offset]))


class CPUResource(object):
    """Parsed ``cpu`` section of a resource spec.

    Attributes:
        cpu_to_vcpu_factor: value from the spec, or ``1.0`` when absent or falsy.
        total_capacity: value from the spec, or ``None`` when absent.
        real_vcpu: ``total_capacity / cpu_to_vcpu_factor / 1000``, or ``0.0``
            when the spec carries no ``total_capacity``.
    """

    def __init__(self, data):
        _parsed = yson.loads(data).get('cpu', {})
        self.cpu_to_vcpu_factor = _parsed.get('cpu_to_vcpu_factor') or 1.0
        self.total_capacity = _parsed.get('total_capacity')
        if self.total_capacity is None:
            # A spec without a cpu section (or without a capacity) is tolerated by
            # the ``.get`` calls above; do not fail on ``None / factor`` afterwards.
            self.real_vcpu = 0.0
        else:
            # Host may be split into parts like 18.8 cores for YP and 13.2 cores for gencfg
            self.real_vcpu = (self.total_capacity / self.cpu_to_vcpu_factor) / 1000.


class Resource(YpObject):
    """Resource row: owning node id, own id, kind, parsed spec and used cpu capacity.

    ``spec`` is always parsed with ``CPUResource``, whatever the resource kind is.
    """
    TYPE = autogen_pb2.OT_RESOURCE
    SELECTORS = OrderedDict([
        ('/meta/node_id', ('node_id', yson.loads)),
        ('/meta/id', ('obj_id', yson.loads)),
        ('/meta/kind', ('kind', yson.loads)),
        ('/spec', ('spec', CPUResource)),
        ('/status/used/cpu/capacity', ('used_vcpu', yson.loads)),
    ])

    __slots__ = ['obj_id',
                 'node_id',
                 'kind',
                 'spec',
                 'used_vcpu',
                 ]

    def __init__(self, data):
        """Parse ``data`` positionally: ``data[i]`` is the value of the i-th selector."""
        for i, (attr_name, parser) in enumerate(self.SELECTORS.values()):
            setattr(self, attr_name, parser(data[i]))

    class QueryMaker(YpObject.QueryMaker):
        """Paginates resources by the ``(node_id, id)`` pair instead of the id alone."""
        # ``base_query`` is already a slot of the parent class; declaring it again
        # would shadow the parent's descriptor, which the data model leaves undefined.
        __slots__ = ('_id_tuple',)

        ID_TEMPLATE = '({}) AND ([/meta/node_id], [/meta/id]) > ("{}", "{}")'

        def __init__(self, base_query='true'):
            self.base_query = base_query
            self._id_tuple = None

        def update(self, obj):  # type: (Resource) -> None
            """Remember the ``(node_id, obj_id)`` pair if it is the greatest seen."""
            id_tuple = obj.node_id, obj.obj_id
            if not self._id_tuple or self._id_tuple < id_tuple:
                self._id_tuple = id_tuple

        def make_query(self):
            """Return the base query, narrowed to pairs after the stored one if any."""
            if not self._id_tuple:
                return self.base_query
            return self.ID_TEMPLATE.format(self.base_query, *self._id_tuple)


class YpClient(object):
    """Thin wrapper around an object-service stub of a single YP cluster."""

    def __init__(self, yp_cluster_name, stub):
        self.yp_cluster_name = yp_cluster_name
        self.stub = stub

    def _select_objects_values(self, object_type, limit,
                               query=None, selectors=None, timestamp=None):
        """Issue one ``SelectObjects`` request and yield ``values`` of every result.

        ``selectors`` falls back to a single empty path when empty or omitted;
        ``query`` and ``timestamp`` are only put into the request when provided.
        """
        request = object_service_pb2.TReqSelectObjects()
        request.object_type = object_type
        request.limit.value = limit
        if timestamp is not None:
            request.timestamp = timestamp
        request.selector.paths.extend(selectors if selectors else [''])
        if query:
            request.filter.query = query

        response = self.stub.SelectObjects(request)
        for result in response.results:
            yield result.values

    def select_all_objects(self, obj_cls, query='true', batch_size=200, timestamp=None):
        # type: (typing.Type[YpObject], str, int, int) -> typing.Generator[YpObject]
        """Yield every object of ``obj_cls`` matching ``query``, page by page.

        Each page holds at most ``batch_size`` rows; the filter of the next page is
        produced by ``obj_cls.QueryMaker`` from the objects already yielded.
        Iteration ends after the first page that is shorter than ``batch_size``.
        """
        paginator = obj_cls.QueryMaker(query)
        fetched = batch_size

        # A full page means there may be more rows; a short page is the last one.
        while not fetched < batch_size:
            page_query = paginator.make_query()
            rows = self._select_objects_values(object_type=obj_cls.TYPE,
                                               limit=batch_size,
                                               selectors=obj_cls.SELECTORS.keys(),
                                               query=page_query,
                                               timestamp=timestamp)
            fetched = 0
            for row in rows:
                item = obj_cls(row)
                yield item
                fetched += 1
                paginator.update(item)


# Key under which pod statistics are grouped and exported: (location, deploy_engine, segment).
RegistryPath = namedtuple('RegistryPath', 'location deploy_engine segment')


class PodsStats(object):
    """Accumulator of pod count and requested resources.

    Raw totals are kept in the slots (``cpu`` is divided by 1000 to get cores,
    the ``*_bytes`` fields by powers of 1024); the properties convert them into
    the units used for logging and export.
    """
    __slots__ = [
        'count',
        'cpu',
        'memory_bytes',
        'hdd_bytes',
        'ssd_bytes',
    ]

    def __init__(self):
        for attr in self.__slots__:
            setattr(self, attr, 0)

    # The conversions divide by a float on purpose: with plain ``/`` between two
    # integers Python 2 truncates, which would e.g. report 0 Tb for any total
    # below a full terabyte.

    @property
    def cpu_cores(self):
        """``cpu`` divided by 1000, as a float."""
        return self.cpu / 1000.0

    @property
    def memory_gb(self):
        """``memory_bytes`` divided by 1024 ** 3, as a float."""
        return self.memory_bytes / float(1024 ** 3)

    @property
    def hdd_tb(self):
        """``hdd_bytes`` divided by 1024 ** 4, as a float."""
        return self.hdd_bytes / float(1024 ** 4)

    @property
    def ssd_tb(self):
        """``ssd_bytes`` divided by 1024 ** 4, as a float."""
        return self.ssd_bytes / float(1024 ** 4)

    def __str__(self):
        return "count:{self.count};" \
               " cpu:{self.cpu_cores}c;" \
               " memory:{self.memory_gb}Gb;" \
               " hdd:{self.hdd_tb}Tb;" \
               " ssd:{self.ssd_tb}Tb;".format(self=self)

    def join(self, other):
        """Add every total of ``other`` to this accumulator in place."""
        self.ssd_bytes += other.ssd_bytes
        self.hdd_bytes += other.hdd_bytes
        self.cpu += other.cpu
        self.count += other.count
        self.memory_bytes += other.memory_bytes

    def export(self, registry):
        """Set the cpu/memory/ssd/hdd/count gauges of ``registry`` from the totals."""
        registry.get_gauge('cpu-cores').set(self.cpu_cores)
        registry.get_gauge('memory-gb').set(self.memory_gb)
        registry.get_gauge('ssd-tb').set(self.ssd_tb)
        registry.get_gauge('hdd-tb').set(self.hdd_tb)
        registry.get_gauge('pods-count').set(self.count)


class StagesStats(object):
    """Counter of stages, exported as the ``stages-count`` gauge."""
    __slots__ = ['count']

    def __init__(self):
        # The only slot is the counter; it starts at zero.
        self.count = 0

    def export(self, registry):
        """Set the ``stages-count`` gauge of ``registry`` to the current count."""
        gauge = registry.get_gauge('stages-count')
        gauge.set(self.count)


class ReplicaSetStats(object):
    """Counter of replica sets, exported as the ``replica-set-count`` gauge."""
    __slots__ = ['count']

    def __init__(self):
        # The only slot is the counter; it starts at zero.
        self.count = 0

    def export(self, registry):
        """Set the ``replica-set-count`` gauge of ``registry`` to the current count."""
        gauge = registry.get_gauge('replica-set-count')
        gauge.set(self.count)

    def join(self, other):
        """Add the count of ``other`` to this counter in place."""
        self.count = self.count + other.count


class MultiClusterReplicaSetStats(object):
    """Counter of multi-cluster replica sets, exported as ``mc-replica-set-count``."""
    __slots__ = ['count']

    def __init__(self):
        # The only slot is the counter; it starts at zero.
        self.count = 0

    def export(self, registry):
        """Set the ``mc-replica-set-count`` gauge of ``registry`` to the current count."""
        gauge = registry.get_gauge('mc-replica-set-count')
        gauge.set(self.count)


class QypVmagentVersionsStats(object):
    """Histogram ``vmagent version -> number of pods``."""

    def __init__(self):
        # Unknown versions implicitly start at zero.
        self.vmagent_version_count = defaultdict(int)

    def increment(self, vmagent_version):
        """Count one more pod running ``vmagent_version``."""
        self.vmagent_version_count[vmagent_version] += 1

    def _prepare_version(self, vmagent_version):
        """Map ``'N/A'`` to ``'n-a'``; otherwise replace dots with dashes.

        NOTE(review): nothing in this class calls this helper - ``export`` uses
        the raw version string; confirm whether that is intended.
        """
        return 'n-a' if vmagent_version == 'N/A' else vmagent_version.replace('.', '-')

    def join(self, other):  # type: (QypVmagentVersionsStats) -> None
        """Add every per-version count of ``other`` to this histogram in place."""
        counts = self.vmagent_version_count
        for version, pods in other.vmagent_version_count.items():
            counts[version] += pods

    def export(self, registry):
        """Set one ``<version>-count`` gauge of ``registry`` per known version."""
        for version, pods in self.vmagent_version_count.items():
            gauge_name = '{}-count'.format(version)
            registry.get_gauge(gauge_name).set(pods)

    def __str__(self):
        return ";".join(
            "{}:{}".format(version, pods)
            for version, pods in self.vmagent_version_count.items()
        )


class QypVolumesCountStats(object):
    """Histogram ``number of volumes -> number of VMs with that many volumes``."""

    def __init__(self):
        # Unknown volume counts implicitly start at zero.
        self.volumes_count_to_vm = defaultdict(int)

    def increment(self, volumes_count):
        """Count one more VM that has ``volumes_count`` volumes."""
        self.volumes_count_to_vm[volumes_count] += 1

    def join(self, other):  # type: (QypVolumesCountStats) -> None
        """Add every bucket of ``other`` to this histogram in place."""
        buckets = self.volumes_count_to_vm
        for volumes, vms in other.volumes_count_to_vm.items():
            buckets[volumes] += vms

    def export(self, registry):
        """Set one ``<volumes count>-count`` gauge of ``registry`` per bucket."""
        for volumes, vms in self.volumes_count_to_vm.items():
            gauge_name = '{}-count'.format(volumes)
            registry.get_gauge(gauge_name).set(vms)

    def __str__(self):
        return ";".join(
            "{}:{}".format(volumes, vms)
            for volumes, vms in self.volumes_count_to_vm.items()
        )






class YpAllocatedReporter(BaseReporter):
    """Reporter that exports YP allocation statistics as gauges.

    Keys of ``config`` read by this class:

    * ``clusters`` - iterable of per-cluster dicts with ``name`` and ``address``;
    * ``xdc_cluster`` - a cluster dict used for stages and multi-cluster replica sets;
    * ``deploy_engines`` - container of deploy engines whose pods are accounted;
    * ``enable_master_discovery`` - optional, defaults to ``True``;
    * ``idle_seconds`` - optional, defaults to one hour.
    """

    def __init__(self, token, config, registry=None):
        """
        :param token: token put into the YP client configuration.
        :param config: mapping with the keys listed in the class docstring.
        :param registry: metrics registry; ``ROOT_REGISTRY`` is used when omitted.
        """
        self.yp_discovery = config.get('enable_master_discovery', True)
        self.yp_cfg = {'token': token, 'enable_master_discovery': self.yp_discovery}
        self.config = config
        self._deploy_engines = self.config.get('deploy_engines')
        self._pool = GeventPool(size=10)
        self._registry = registry or ROOT_REGISTRY
        self._yp_registry = self._registry.path('allocated', 'yp')
        self._qyp_vmagent_ver_registry = self._registry.path('vmagent-ver', 'qyp')
        self._qyp_vols_counts_registry = self._registry.path('vols-count', 'qyp')
        self._reporter_registry = self._registry.path('reporters', self.__class__.__name__)
        self._idle_seconds = config.get('idle_seconds', 60 * 60)

    @property
    def idle_seconds(self):
        """Value of the ``idle_seconds`` config key (3600 when not configured)."""
        return self._idle_seconds

    def build_yp_client(self, cluster_config):
        """Create a ``YpClient`` for the cluster described by ``cluster_config``."""
        address = cluster_config.get('address')
        yp_cluster_name = cluster_config.get('name')
        client_base = BaseYpClient(address=address, config=self.yp_cfg)
        stub = client_base.create_grpc_object_stub()
        return YpClient(yp_cluster_name, stub)

    def get_pod_sets(self, cluster_config):
        """Return ``(cluster name, generator of PodSet)``.

        The generator is lazy and can be consumed only once.
        """
        yp_client = self.build_yp_client(cluster_config)
        return yp_client.yp_cluster_name, yp_client.select_all_objects(PodSet)

    def get_stages_stats(self, cluster_config):  # type: (dict) -> typing.Tuple[str, StagesStats]
        """Count the stages of a cluster; returns ``(cluster name, stats)``."""
        yp_client = self.build_yp_client(cluster_config)
        stages_stats = StagesStats()

        for _ in yp_client.select_all_objects(Stage):  # type: Stage
            stages_stats.count += 1

        return yp_client.yp_cluster_name, stages_stats

    def get_replica_sets_stats(self, cluster_config):  # type: (dict) -> typing.Tuple[str, ReplicaSetStats]
        """Count the replica sets of a cluster; returns ``(cluster name, stats)``."""
        yp_client = self.build_yp_client(cluster_config)
        replica_sets_stats = ReplicaSetStats()

        for _ in yp_client.select_all_objects(ReplicaSet):  # type: ReplicaSet
            replica_sets_stats.count += 1

        return yp_client.yp_cluster_name, replica_sets_stats

    def get_multi_cluster_replica_sets_stats(self, cluster_config):
        # type: (dict) -> typing.Tuple[str, MultiClusterReplicaSetStats]
        """Count the multi-cluster replica sets of a cluster; returns ``(cluster name, stats)``."""
        yp_client = self.build_yp_client(cluster_config)
        stats = MultiClusterReplicaSetStats()

        for _ in yp_client.select_all_objects(MultiClusterReplicaSet):  # type: MultiClusterReplicaSet
            stats.count += 1

        return yp_client.yp_cluster_name, stats

    def get_nodes_vcpu(self, cluster_config):
        """Yield ``(node id, Resource)`` for every resource of kind ``cpu`` in a cluster."""
        yp_client = self.build_yp_client(cluster_config)
        for resource in yp_client.select_all_objects(Resource, query='[/meta/kind] = "cpu"'):  # type: Resource
            yield resource.node_id, resource

    def get_all_nodes_vcpu(self):  # type: () -> typing.Generator[typing.Tuple[str, Resource]]
        """Yield ``(node id, Resource)`` for the cpu resources of every configured cluster."""
        for cluster_config in self.config.get('clusters'):
            for node_id, resource in self.get_nodes_vcpu(cluster_config):
                yield node_id, resource

    def get_nodes(self, cluster_config, query=None):  # type: (dict, str) -> typing.Generator[typing.Tuple[str, Node]]
        """Yield ``(node id, Node)`` for the nodes of a cluster matching ``query``.

        An empty or omitted ``query`` selects every node.
        """
        yp_client = self.build_yp_client(cluster_config)
        for node in yp_client.select_all_objects(Node, query=query or 'true'):  # type: Node
            yield node.obj_id, node

    def get_managed_nodes(self):  # type: () -> typing.Generator[typing.Tuple[str, str, Node]]
        """Yield ``(cluster name, node id, Node)`` for nodes whose segment is not ``not_managed``."""
        query = '[/labels/segment] != "not_managed"'
        for cluster_config in self.config.get('clusters'):
            for node_id, node in self.get_nodes(cluster_config, query=query):
                yield cluster_config.get('name'), node_id, node

    def get_pods_stats(self,
                       cluster_config,  # type: dict
                       pod_sets  # type: typing.Generator[PodSet]
                       ):  # type: (...) -> typing.Iterable[typing.Tuple[RegistryPath, PodsStats]]
        """Aggregate assigned pods of a cluster per ``(cluster, deploy engine, segment)``.

        A pod takes its segment from its pod set. Pods whose pod set is not in
        ``pod_sets`` and pods whose deploy engine is not in the configured
        ``deploy_engines`` are skipped.
        """
        yp_client = self.build_yp_client(cluster_config)

        pod_sets_dict = {pod_set.obj_id: pod_set for pod_set in pod_sets}

        stats = defaultdict(PodsStats)

        for pod in yp_client.select_all_objects(Pod, query='[/status/scheduling/state] = "assigned"'):  # type: Pod
            pod_set = pod_sets_dict.get(pod.pod_set_id)
            if not pod_set:
                continue
            pod.segment = pod_set.segment

            if pod.deploy_engine not in self._deploy_engines:
                continue

            stat_key = RegistryPath(yp_client.yp_cluster_name, pod.deploy_engine.lower(), pod.segment)

            target = stats[stat_key]

            target.cpu += pod.vcpu_guarantee
            target.memory_bytes += pod.memory_limit
            target.hdd_bytes += pod.hdd_bytes
            target.ssd_bytes += pod.ssd_bytes
            target.count += 1

        return stats.items()

    def get_qyp_stats(self, cluster_config):
        """Collect vmagent-version and volume-count histograms over assigned QYP pods.

        :return: ``(cluster name, QypVmagentVersionsStats, QypVolumesCountStats)``.
        """
        yp_client = self.build_yp_client(cluster_config)

        query = '[/status/scheduling/state] = "assigned"' \
                ' AND [/labels/deploy_engine] = "QYP"'
        vmagent_versions_stats = QypVmagentVersionsStats()
        volumes_count_stats = QypVolumesCountStats()
        for pod in yp_client.select_all_objects(Pod, query=query):
            vmagent_versions_stats.increment(pod.vmagent_version)
            volumes_count_stats.increment(len(pod.disk_volume_requests))

        return yp_client.yp_cluster_name, vmagent_versions_stats, volumes_count_stats

    def run(self, start_at, initial=False):
        """Collect and export every statistic once.

        :param start_at: timestamp of the run start; exported as ``last-run-start``
            and used to log how long the pods fetch took.
        :param initial: accepted but not used by this reporter.
        """
        self._reporter_registry.get_gauge('last-run-start').set(start_at)

        self._report_pods_stats(start_at)
        self._report_yp_lite_pod_sets()
        self._report_qyp_stats()
        self._report_stages_stats()
        self._report_replica_sets_stats()
        self._report_multi_cluster_replica_sets_stats()

        # Marked as done only after every section above has been exported, so a
        # failure in any of them is not reported as a completed run.
        self._reporter_registry.get_gauge('last-run-done').set(int(time.time()))

    def _report_pods_stats(self, start_at):
        """Export pod statistics per location/engine/segment plus their roll-ups."""
        clusters = self.config.get('clusters')
        timer_buckets = [5.0, 7.0, 10.0, 15.0, 20.0, 30.0, 50.0, 100.0, 1000.0]

        with self._reporter_registry.get_histogram('fetch-pods-stats', timer_buckets).timer():
            # The pod-set generators are lazy; each one is consumed by the
            # get_pods_stats call it is handed to.
            pod_sets_by_cluster = dict(self.get_pod_sets(cluster_config)
                                       for cluster_config in clusters)

            pods_stats_by_location = self._pool.map(lambda args: self.get_pods_stats(*args),
                                                    [(c, pod_sets_by_cluster[c['name']]) for c in clusters])

        # Keyed by (deploy_engine, segment): totals over all locations.
        global_pod_stats = defaultdict(PodsStats)
        # Keyed by (location, deploy_engine): totals over all segments.
        segments_pod_stats = defaultdict(PodsStats)

        log.info('Fetch yp allocated stats by {}s'.format(int(time.time() - start_at)))

        # Pods stats for all deploy engines
        for (stat_key, pods_stats) in chain.from_iterable(pods_stats_by_location):  # type: RegistryPath, PodsStats
            registry = self._yp_registry.path(*stat_key)
            pods_stats.export(registry)

            global_pod_stats[stat_key.deploy_engine, stat_key.segment].join(pods_stats)
            global_pod_stats[stat_key.deploy_engine, 'all'].join(pods_stats)

            segments_pod_stats[stat_key.location, stat_key.deploy_engine].join(pods_stats)
            segments_pod_stats['global', stat_key.deploy_engine].join(pods_stats)

            log.info("Allocated {s.location} {s.deploy_engine} {s.segment}  \t {}".format(pods_stats, s=stat_key))

        for ((deploy_engine, segment), pods_stats) in global_pod_stats.items():
            registry = self._yp_registry.path('global', deploy_engine, segment)
            pods_stats.export(registry)
            log.info("Allocated {s[0]} {s[1]} {s[2]} \t {}".format(
                pods_stats, s=('global', deploy_engine, segment))
            )

        for ((location, deploy_engine), pods_stats) in segments_pod_stats.items():
            registry = self._yp_registry.path(location, deploy_engine, 'all')
            pods_stats.export(registry)
            log.info("Allocated {s[0]} {s[1]} {s[2]} \t {}".format(
                pods_stats, s=(location, deploy_engine, 'all'))
            )

    def _report_yp_lite_pod_sets(self):
        """Export the number of unique YP_LITE pod-set ids per segment over all clusters."""
        # PodSets stats for YP_LITE (Unique Count)
        pod_sets_by_cluster = dict(self.get_pod_sets(cluster_config)
                                   for cluster_config in self.config.get('clusters'))

        yp_lite_services_by_segment = defaultdict(set)
        for pod_set in chain.from_iterable(pod_sets_by_cluster.values()):
            if pod_set.deploy_engine != 'YP_LITE':
                continue
            yp_lite_services_by_segment[pod_set.segment].add(pod_set.obj_id)

        for segment, pod_set_ids in yp_lite_services_by_segment.items():
            stat_key = 'global', 'yp-lite', segment
            registry = self._yp_registry.path(*stat_key)
            registry.get_gauge('pod-sets-count').set(len(pod_set_ids))
            log.info("Yp-Lite pod sets {} count: {}".format(segment, len(pod_set_ids)))

    def _report_qyp_stats(self):
        """Export QYP vmagent-version and volume-count histograms merged over all clusters."""
        # QYP Vmagent Versions Stats
        qyp_stats_by_location = self._pool.map(self.get_qyp_stats,
                                               list(self.config.get('clusters')))

        global_vmagent_versions_stats = QypVmagentVersionsStats()
        global_volumes_count_stats = QypVolumesCountStats()
        for _cluster, vmagent_version_stats, volumes_count_stats in qyp_stats_by_location:
            global_vmagent_versions_stats.join(vmagent_version_stats)
            global_volumes_count_stats.join(volumes_count_stats)

        global_vmagent_versions_stats.export(self._qyp_vmagent_ver_registry.path('all'))
        global_volumes_count_stats.export(self._qyp_vols_counts_registry.path('all'))
        log.info('QYP Vmagent Versions: {}'.format(global_vmagent_versions_stats))
        log.info('QYP Vols Count: {}'.format(global_volumes_count_stats))

    def _report_stages_stats(self):
        """Export the stages count of the ``xdc_cluster``."""
        # Stages Stats for YD
        location, stages_stats = self.get_stages_stats(self.config.get('xdc_cluster'))
        registry = self._yp_registry.path(location, 'yd', 'default')
        stages_stats.export(registry)

    def _report_replica_sets_stats(self):
        """Export replica-set counts per cluster and their global sum."""
        # Replica Set Stats for RSC
        replica_sets_by_cluster = dict(self.get_replica_sets_stats(cluster_config)
                                       for cluster_config in self.config.get('clusters'))

        global_replica_set_stats = ReplicaSetStats()
        for cluster_name, replica_set_stats in replica_sets_by_cluster.items():
            registry = self._yp_registry.path(cluster_name, 'rsc', 'default')
            replica_set_stats.export(registry)
            global_replica_set_stats.join(replica_set_stats)

        registry = self._yp_registry.path('global', 'rsc', 'default')
        global_replica_set_stats.export(registry)

    def _report_multi_cluster_replica_sets_stats(self):
        """Export the multi-cluster replica-set count of the ``xdc_cluster``."""
        # Multi Cluster Replica Set Stats for MCRSC
        location, mcrsc_stats = self.get_multi_cluster_replica_sets_stats(self.config.get('xdc_cluster'))
        registry = self._yp_registry.path(location, 'mcrsc', 'default')
        mcrsc_stats.export(registry)
