import collections
import datetime
import logging
import random

import yp.client
import yp.common

from yt.yson.yson_types import YsonEntity

import entities


# Pod-agent volume ids recognised as the storage volume of a Deploy pod: the
# first pod volume whose id is listed here selects the virtual disk used for
# storage size / SSD checks.
STORAGE_VOLUME_IDS = ('basesearch', 'storage', 'storage_volume', 'shard_root')


# Lightweight records built from YP selector results.
PodFields = collections.namedtuple('PodFields', 'id pod_index fqdn node')
# NOTE(review): presumably the element type of InstanceProvider's ``pod_sets``
# (it reads .cluster/.deploy_unit/.port) — verify against callers.
DeploySlot = collections.namedtuple('DeploySlot', 'cluster deploy_unit port')
Node = collections.namedtuple('Node', 'id rack network_10G')


class YpMasters(object):
    """Addresses (``host:port``) of the YP masters used by this module."""

    man_pre = 'man-pre.yp.yandex.net:8090'

    # Per-location masters.
    man = 'man.yp.yandex.net:8090'
    sas = 'sas.yp.yandex.net:8090'
    vla = 'vla.yp.yandex.net:8090'

    xdc = 'xdc.yp.yandex.net:8090'


# TODO: replace the publicly accessible client with utility functions.
def client(master, token=None):
    """Create a YP client connected to ``master``.

    :param master: YP master address (``host:port``, e.g. a ``YpMasters`` value).
    :param token: auth token; when falsy, ``yp.client.find_token()`` supplies one.
    :returns: a ``yp.client.YpClient``; callers typically use it as a context
        manager (``with client(...) as yp_client``).
    """
    if not token:
        token = yp.client.find_token()
    _log.info('Connect to [%s]', master)
    return yp.client.YpClient(master, config={'token': token})


def man():
    """Return a YP client for the ``YpMasters.man`` master."""
    master_address = YpMasters.man
    return client(master_address)


def man_pre():
    """Return a YP client for the ``YpMasters.man_pre`` master."""
    master_address = YpMasters.man_pre
    return client(master_address)


def xdc():
    """Return a YP client for the ``YpMasters.xdc`` master."""
    master_address = YpMasters.xdc
    return client(master_address)


def read_pods(yp_cluster, pod_set_id):
    """List the pods of one pod set.

    :param yp_cluster: YP master address passed to ``client``.
    :param pod_set_id: id of the pod set whose pods are selected.
    :returns: list of ``PodFields`` (id, ``pod_index`` label, persistent FQDN,
        node id), in the order YP returned them.
    """
    selectors = ['/meta/id', '/labels/pod_index',
                 '/status/dns/persistent_fqdn', '/spec/node_id']
    pod_set_filter = '[/meta/pod_set_id]="{}"'.format(pod_set_id)
    with client(yp_cluster) as yp_client:
        _log.info('get pods %s', pod_set_id)
        rows = yp_client.select_objects('pod', filter=pod_set_filter, selectors=selectors)
        return [
            PodFields(id=row[0], pod_index=row[1], fqdn=row[2], node=row[3])
            for row in rows
        ]


def agent_shard_number_mapping(pods_list, port, tier):
    """Map every pod's agent to the shard number it serves.

    :param pods_list: ``PodFields``-like records with ``fqdn`` and ``pod_index``
        (e.g. the result of ``read_pods``).
    :param port: port used to build each ``entities.Agent``.
    :param tier: object exposing ``list_shard_numbers()`` and ``shards_count``.
    :returns: Dict[Agent(fqdn, port), shard_number] where
        shard_number = tier.list_shard_numbers()[int(pod_index) % tier.shards_count];
        pods whose ``pod_index`` is falsy (e.g. missing label) get the first
        shard number.
    """
    shard_numbers = tier.list_shard_numbers()

    def _shard_number(pod_index):
        if not pod_index:
            return shard_numbers[0]
        # The label value may come back as a string (``Instance.pod_index`` also
        # converts with int()); ``'5' % n`` would be string formatting, not modulo.
        return shard_numbers[int(pod_index) % tier.shards_count]

    return {
        entities.Agent(pod.fqdn, port): _shard_number(pod.pod_index)
        for pod in pods_list
    }


class Instance(entities.AbstractInstance):
    """A single YP pod exposed through the ``entities.AbstractInstance`` interface.

    ``raw_data`` maps YP selector paths (e.g. ``'/meta/id'``) to the values
    returned by ``select_objects``. ``InstanceProvider`` additionally stores the
    synthetic keys ``'_port'``, ``'_tags'`` and ``'_node'`` (a ``Node`` tuple or
    ``None``) in it.
    """

    @property
    def id(self):
        """Pod id (``/meta/id``)."""
        return self.raw_data['/meta/id']

    @property
    def hostname(self):
        """Persistent FQDN of the pod (``/status/dns/persistent_fqdn``)."""
        return self.raw_data['/status/dns/persistent_fqdn']

    @property
    def node_name(self):
        """Id of the node the pod is scheduled on (``/spec/node_id``)."""
        return self.raw_data['/spec/node_id']

    @property
    def pod_index(self):
        """The ``pod_index`` label converted to ``int``."""
        return int(self.raw_data['/labels/pod_index'])

    @property
    def port(self):
        """Port attached by the provider (``'_port'``)."""
        return self.raw_data['_port']

    @property
    def rack(self):
        """Rack of the pod's node; requires a ``Node`` under ``'_node'``."""
        return self.raw_data['_node'].rack

    @property
    def is_slow_network(self):
        """``True`` unless the pod's node reports 10G bandwidth."""
        return not self.raw_data['_node'].network_10G

    @property
    def shard_number(self):
        """Not supported for YP pods."""
        raise NotImplementedError()

    @property
    def tags(self):
        """Report tags attached by the provider, as a new ``set``."""
        return set(self.raw_data['_tags'])

    @property
    def storage_size(self):
        """Capacity from the storage disk's ``quota_policy``, or 0 when no disk is found."""
        disk_spec = self._get_storage_disk()

        if disk_spec:
            return int(disk_spec.get('quota_policy', {}).get('capacity', 0))
        return 0

    @property
    def is_on_ssd(self):
        """Whether the storage disk's ``storage_class`` is ``'ssd'``."""
        disk_spec = self._get_storage_disk()

        if disk_spec:
            return disk_spec.get('storage_class') == 'ssd'
        return False

    @property
    def podset(self):
        """Id of the pod set the pod belongs to (``/meta/pod_set_id``)."""
        return self.raw_data['/meta/pod_set_id']

    @property
    def is_alive(self):
        """Whether the pod reports itself as running.

        Deploy pods (pod agent ready status present) use that status; Nanny
        pods are alive when any ISS state summary is ``ACTIVE``. Any error while
        inspecting the status is logged and the pod is treated as alive
        (deliberate best effort: malformed status must not drop a pod).
        """
        try:
            status_ready = self.raw_data['/status/agent/pod_agent_payload/status/ready/status']
            if not isinstance(status_ready, YsonEntity):
                # pod found in deploy
                # NOTE(review): if YP returns this condition status as a string
                # such as "false", bool() treats it as alive — confirm the type.
                return bool(status_ready)

            # pod found in nanny
            summaries = self.raw_data['/status/agent/iss_summary/state_summaries']
            if isinstance(summaries, YsonEntity):
                return False
            return any(
                state['current_state'] == 'ACTIVE'
                for state in dict(summaries).values()
            )
        except Exception as e:
            # Pass ``e`` as a lazy logging argument instead of calling str(e)
            # inside the handler, where a failing conversion would escape.
            _log.warning("could not verify activation for pod %s, because of %s",
                         self.id,
                         e)
            return True

    def _get_storage_disk(self):
        """Return the disk volume request backing the pod's storage.

        Pods without pod-agent volumes use the YP-lite lookup (``root_fs``
        request, raising when absent); otherwise the Deploy lookup is used,
        which may return ``None``.
        """
        volumes_spec = self.raw_data['/spec/pod_agent_payload/spec/volumes']
        disk_volume_requests = self.raw_data['/spec/disk_volume_requests']

        if isinstance(volumes_spec, YsonEntity):
            return _get_storage_disk_from_yp_lite(disk_volume_requests)
        return _get_storage_disk_from_ydeploy(volumes_spec, disk_volume_requests)


def _get_storage_disk_from_yp_lite(disk_volume_requests):
    """Return the first disk volume request labelled ``volume_type == "root_fs"``.

    :raises Exception: when no such request exists.
    """
    root_fs_requests = (
        request for request in disk_volume_requests
        if request["labels"]["volume_type"] == "root_fs"
    )
    found = next(root_fs_requests, None)
    if found is None:
        raise Exception("Unable to get disk storage in _get_storage_disk_from_yp_lite()")
    return found


def _get_storage_disk_from_ydeploy(volumes_spec, disk_volume_requests):
    """Return the disk request referenced by the Deploy pod's storage volume (or ``None``)."""
    return _get_disk(
        disk_volume_requests,
        _get_disk_ref(volumes_spec, STORAGE_VOLUME_IDS),
    )


def _get_disk_ref(volumes_spec, volume_ids):
    """Pick the ``virtual_disk_id_ref`` of the pod's storage volume.

    Returns the reference of the first volume whose ``id`` is in ``volume_ids``;
    if none matches but there is exactly one volume, that volume's reference;
    otherwise ``None``.
    """
    chosen = next(
        (volume for volume in volumes_spec if volume['id'] in volume_ids),
        None,
    )
    if chosen is not None:
        return chosen['virtual_disk_id_ref']

    if len(volumes_spec) != 1:
        return None
    return volumes_spec[0]['virtual_disk_id_ref']


def _get_disk(disks_spec, disk_id):
    """Return the first entry of ``disks_spec`` whose ``id`` equals ``disk_id``, or ``None``."""
    return next((disk for disk in disks_spec if disk['id'] == disk_id), None)


class InstanceProvider(object):
    """Resolve YP pod sets into ``{agent: Instance}`` and cache the result.

    :param pod_sets: sequence of objects with ``cluster`` (YP master address),
        ``deploy_unit`` (pod set id) and ``port`` attributes — presumably
        ``DeploySlot`` tuples; verify against callers.
    :param report_tags: tags attached to every resolved instance.
    :param tier: optional tier used by ``agent_shard_number_mapping``.
    :param blacklist_nodes: node ids whose pods are skipped.
    :param cache_for: seconds a resolved snapshot is reused before YP is re-read.
    """

    def __init__(self, pod_sets, report_tags, tier=None, blacklist_nodes=None, cache_for=300):
        self._pod_sets = pod_sets
        self._tags = report_tags
        self._blacklist_nodes = blacklist_nodes or set()

        self._cache_for = cache_for
        self._agents_instances = None
        self._last_update = None
        self._tier = tier
        self._pod_count = None

    def _cache_age(self):
        """Return the seconds elapsed since the last successful refresh.

        ``timedelta.seconds`` is only the seconds component (it wraps every day),
        so ``total_seconds()`` is used to get the real age.
        """
        return (datetime.datetime.now() - self._last_update).total_seconds()

    def group_keys(self):
        """Return the pod set ids (``deploy_unit``) this provider resolves."""
        return [pod_set.deploy_unit for pod_set in self._pod_sets]

    def count_pods(self):
        """Return the number of pods in the single configured pod set.

        On a YP gRPC error the last successfully counted value is returned.

        :raises RuntimeError: if not exactly one pod set is configured, or if YP
            fails before any count has been obtained.
        """
        if len(self._pod_sets) != 1:
            raise RuntimeError('Can\'t count pods in {} podsets, one required'.format(len(self._pod_sets)))

        pod_set = self._pod_sets[0]
        try:
            # Use the client as a context manager so it is cleaned up on exit,
            # like the other YP calls in this module.
            with client(pod_set.cluster) as yp_client:
                pod_count = len(self._read_pods(yp_client, pod_set))
        except yp.common.GrpcError:
            if self._pod_count is None:
                raise RuntimeError('Can\'t count pods in {}'.format(pod_set.deploy_unit))
            _log.warning('Can\'t count pods in %s, use cached value', pod_set.deploy_unit)
            return self._pod_count

        self._pod_count = pod_count
        return pod_count

    def _update(self):
        """Re-read every pod set and replace the cached snapshot.

        A pod set that fails with a YP gRPC error keeps its previously cached
        instances (if any).

        :raises AssertionError: if no instance could be resolved at all.
        """
        agents_instances = {}

        for pod_set in self._pod_sets:
            _log.info('get pods %s', pod_set.deploy_unit)
            try:
                pod_set_instances = self._resolve_pod_set(pod_set)
            except yp.common.GrpcError:
                _log.warning('Could not get podset %s, use cached pods',
                             pod_set.deploy_unit)
                pod_set_instances = self._cached_pod_set_instances(pod_set)
            agents_instances.update(pod_set_instances)

        # Explicit raise instead of ``assert`` so the check survives ``python -O``;
        # the exception type callers see is unchanged.
        if not agents_instances:
            raise AssertionError('Could not resolve podsets {}'.format(self._pod_sets))

        self._agents_instances = agents_instances

    def _resolve_pod_set(self, pod_set):
        """Read one pod set from YP and return ``{agent: Instance}`` for its usable pods.

        Pods on blacklisted nodes or without a hostname/node are skipped. Each
        returned instance gets ``'_port'``, ``'_tags'`` and ``'_node'`` in its
        ``raw_data``. Node info is attached only to this pod set's instances,
        so instances of other pod sets (including cached ones) are not touched.
        YP gRPC errors propagate to the caller.
        """
        pod_set_instances = {}
        with client(pod_set.cluster) as yp_client:
            for instance in self._read_pods(yp_client, pod_set):
                if instance.node_name in self._blacklist_nodes:
                    continue
                instance.raw_data.update({
                    '_port': pod_set.port,
                    '_tags': self._tags
                })
                if instance.hostname and instance.node_name:
                    pod_set_instances[instance.get_agent()] = instance

            nodes = self._read_nodes(
                yp_client,
                {instance.node_name for instance in pod_set_instances.values()}
            )

        for instance in pod_set_instances.values():
            instance.raw_data['_node'] = nodes.get(instance.node_name)

        return pod_set_instances

    def _cached_pod_set_instances(self, pod_set):
        """Return the cached instances belonging to ``pod_set`` (empty before the first update)."""
        cached = self._agents_instances or {}
        return {
            agent: instance
            for agent, instance in cached.items()
            if instance.podset == pod_set.deploy_unit
        }

    def _read_pods(self, yp_client, podset):
        """Select all pods of ``podset`` and wrap each row in an ``Instance``.

        Each instance's ``raw_data`` maps the selector paths below to the row values.
        """
        pod_selectors = (
            '/meta/id',
            '/meta/pod_set_id',
            '/labels/pod_index',
            '/status/dns/persistent_fqdn',
            '/spec/node_id',
            '/spec/disk_volume_requests',
            '/spec/pod_agent_payload/spec/volumes',
            '/status/agent/iss_summary/state_summaries',
            '/status/agent/pod_agent_payload/status/ready/status',
        )
        records = yp_client.select_objects(
            'pod',
            filter='[/meta/pod_set_id]="{}"'.format(podset.deploy_unit),
            selectors=pod_selectors
        )

        return [
            Instance(dict(zip(pod_selectors, pod_data)))
            for pod_data in records
        ]

    def _read_nodes(self, yp_client, node_ids):
        """Fetch rack and 10G-bandwidth labels for ``node_ids``; returns ``{node_id: Node}``."""
        node_selectors = (
            '/meta/id',
            '/labels/location/rack',
            '/labels/extras/network/bandwidth_10G'
        )
        records = yp_client.get_objects(
            'node',
            node_ids,
            selectors=node_selectors,
        )

        return {
            node_data[0]: Node(node_data[0], node_data[1], bool(node_data[2]))
            for node_data in records
        }

    @property
    def agents_instances(self):
        """The ``{agent: Instance}`` snapshot, refreshed when older than ``cache_for`` seconds."""
        if not self._last_update or self._cache_age() > self._cache_for:
            self._update()
            self._last_update = datetime.datetime.now()

        return self._agents_instances

    @property
    def ids(self):
        """Pod ids of all resolved instances."""
        return map(lambda x: x.id, self.agents_instances.itervalues())

    @property
    def agents(self):
        """Iterator over the resolved agents."""
        return self.agents_instances.iterkeys()

    @property
    def strict_host_agent_mapping(self):
        """Map each agent's ``node_name`` to the agent.

        :raises entities.HostsIntersection: if two different agents share a ``node_name``.
        """
        mapping = {}
        for agent in self.agents_instances:
            if agent != mapping.get(agent.node_name, agent):
                raise entities.HostsIntersection(
                    '{} conflict with {}'.format(agent, mapping[agent.node_name])
                )
            mapping[agent.node_name] = agent

        return mapping

    @property
    def agent_shard_number_mapping(self):
        """Map each agent to ``tier.list_shard_numbers()[pod_index % tier.shards_count]``."""
        shard_numbers = self._tier.list_shard_numbers()
        return {
            agent: shard_numbers[instance.pod_index % self._tier.shards_count]
            for agent, instance in self.agents_instances.items()
        }

    @property
    def tags(self):
        """Report tags given at construction."""
        return self._tags

    @property
    def tier(self):
        """Tier given at construction (may be ``None``)."""
        return self._tier


# Reference to a YP endpoint set: ``location`` is the YP master address passed
# to client(), ``id`` is the endpoint set id used in the select filter.
EndpointSet = collections.namedtuple('EndpointSet', 'location id')


class EndpointSetResolver(object):
    """Pick a random ``host:port`` from one or more YP endpoint sets.

    Endpoints are re-read from YP at most once per ``cache_for`` seconds. An
    endpoint set that fails to resolve or comes back empty keeps the endpoints
    previously resolved for it.

    :param endpoint_sets: iterable of ``EndpointSet``-like objects (``location``, ``id``).
    :param cache_for: seconds between refreshes.
    """

    def __init__(self, endpoint_sets, cache_for=120):
        self._endpoint_sets = endpoint_sets
        self._cache_for = cache_for

        self._endpoints_map = {}
        # datetime.min makes the first get_server() call refresh.
        self._last_update = datetime.datetime.min

    def _cache_age(self):
        """Return the seconds elapsed since the last refresh.

        ``timedelta.seconds`` is only the seconds-of-day component: measured from
        ``datetime.min`` it equals the current time of day, which skipped the
        initial refresh shortly after midnight. ``total_seconds()`` is exact.
        """
        return (datetime.datetime.now() - self._last_update).total_seconds()

    def _update(self):
        """Re-resolve every endpoint set into ``self._endpoints_map``.

        Empty results and YP gRPC errors are logged; the previous endpoints of
        that set (if any) are kept.
        """
        for endpoint_set in self._endpoint_sets:
            endpoints_filter = '[/meta/endpoint_set_id] = "{}"'.format(
                endpoint_set.id
            )
            with client(endpoint_set.location) as yp_client:
                try:
                    endpoints = yp_client.select_objects(
                        'endpoint',
                        filter=endpoints_filter,
                        selectors=['/spec/fqdn', '/spec/port']
                    )
                    if endpoints:
                        self._endpoints_map[endpoint_set] = endpoints
                    else:
                        _log.error('Empty endpoint set %s@%s',
                                   endpoint_set.location, endpoint_set.id)
                except yp.common.GrpcError as e:
                    _log.error(
                        'Exception while resolve endpoint set %s@%s:\n%s',
                        endpoint_set.location, endpoint_set.id,
                        e
                    )

    def get_server(self):
        """Return a random resolved endpoint as ``'fqdn:port'``.

        :raises IndexError: if no endpoint has been resolved for any endpoint set.
        """
        if self._cache_age() > self._cache_for:
            self._update()
            self._last_update = datetime.datetime.now()

        candidates = [
            endpoint
            for endpoints in self._endpoints_map.values()
            for endpoint in endpoints
        ]
        if not candidates:
            # Same exception type random.choice([]) raised, with a useful message.
            raise IndexError('No endpoints resolved for {}'.format(self._endpoint_sets))

        chosen_endpoint = random.choice(candidates)

        return '{}:{}'.format(
            chosen_endpoint[0],
            int(chosen_endpoint[1])
        )


# Module logger. Bound at the end of the module, which works because it is only
# referenced inside function bodies that run after import.
_log = logging.getLogger(name=__name__)
