import logging
import random
import re
import time
import six
import copy
import requests
import threading
import itertools

from itertools import groupby
from six import iteritems, string_types
from typing import List, Dict  # noqa
from cached_property import cached_property_with_ttl

import saas.tools.devops.lib23.deploy_manager_api as dm_api
import saas.tools.devops.lib23.nanny_helpers as nanny
import saas.tools.devops.lib23.saas_slot as saas_slot  # noqa

from saas.tools.devops.lib23.endpointset import Endpointset, YpLitePod


# Number of slots accumulated before SaasService.modify_searchmap() and
# SaasService.release_slots() send one request to the deploy manager.
SLOT_OPERATION_BATCH_SIZE = 20


class SaasService(object):
    LOGGER = logging.getLogger(__name__)
    _dm = dm_api.DeployManagerApiClient()

    def __init__(self, ctype, service):
        self._ctype = ctype
        self._service = service
        self._service_type = 'rtyserver'

    def __eq__(self, other):
        if isinstance(other, SaasService):
            return self._ctype == other._ctype and self.name == other.name and self._service_type == other._service_type
        else:
            return NotImplemented

    def __ne__(self, other):
        if isinstance(other, SaasService):
            return self._ctype != other._ctype or self.name != other.name or self._service_type != other._service_type
        elif isinstance(other, string_types):
            return self.name != other
        else:
            return NotImplemented

    def __hash__(self):
        return hash((self._ctype, self._service, self._service_type))

    def __repr__(self):
        return 'SaasService({}, {})'.format(self.ctype, self.name)

    def __str__(self):
        return 'SaasService({}/{})'.format(self.ctype, self.name)

    @property
    def free_slots_pool(self):
        """Free slots of this ctype/service as returned by the deploy manager (queried on every access)."""
        query = {'ctype': self._ctype, 'service': self._service}
        return self._dm.get_free_slots(**query)

    @property
    def ctype(self):
        """The ctype this object was constructed with."""
        return self._ctype

    @property
    def name(self):
        """The service name this object was constructed with."""
        return self._service

    @property
    def prj_tag(self):
        if self.ctype.startswith('stable'):
            return '-'.join([self.ctype.replace('stable', 'saas'), self.name]).replace('_', '-')
        else:
            return 'saas-{}'.format(self.ctype)

    @cached_property_with_ttl(ttl=30)
    def endpointsets(self):
        """
        All endpointsets of the service, merged across shards.

        :rtype: Set[Endpointset]
        """
        # endpointsets_by_shard maps shard -> set of endpointsets, so the values
        # have to be flattened: plain chain(values) would yield the sets
        # themselves, which are unhashable and cannot be put into a set.
        per_shard_sets = self.endpointsets_by_shard.values()
        return set(itertools.chain.from_iterable(per_shard_sets))

    @cached_property_with_ttl(ttl=30)
    def endpointsets_by_shard(self):
        """
        Group the endpointsets registered in the deploy manager by shard.

        Only slots flagged ``is_sd`` are taken into account; their ids are
        parsed as ``<cluster>@<endpointset id>:<port>``. Shards without such
        slots are absent from the result.

        :rtype: Dict[str, Set[Endpointset]]
        """
        by_shard = {}  # type: Dict[str, Set[Endpointset]]

        for interval in self._dm.get_slots_by_interval(self.ctype, self.name):
            shard_id = interval['id']
            sd_slots = (slot for slot in interval['slots'] if slot['is_sd'])
            for slot in sd_slots:
                cluster, es_id, _port = re.findall('(.*)@(.*):(.*)', slot['id'])[0]
                shard_endpointsets = by_shard.setdefault(shard_id, set())
                shard_endpointsets.add(Endpointset(es_id, cluster.replace('-', '_')))
            self.LOGGER.debug('Shard %s endpointsets: %s', shard_id, by_shard.get(shard_id, None))
        return by_shard

    @cached_property_with_ttl(ttl=30)
    def shards(self):
        """Keys of the service's cluster map as returned by the deploy manager."""
        cluster_map = self._dm.get_cluster_map(self._ctype, self._service)
        return cluster_map.keys()

    @cached_property_with_ttl(ttl=30)
    def shards_with_clusters(self):
        """
        Map every shard of the cluster map to the list of clusters hosting it.

        The cluster name is the second dot-separated component of the slot
        host, upper-cased and with dashes replaced by underscores; only hosts
        containing ``yp-c`` are considered. Each list holds unique names.
        """
        cl_map = self._dm.get_cluster_map(self._ctype, self._service)
        result = {}
        for shard, dcs in cl_map.items():
            unique_clusters = set()
            for _dc, dc_slots in dcs.items():
                unique_clusters.update(
                    slot.host.split('.')[1].upper().replace('-', '_')
                    for slot in dc_slots
                    if 'yp-c' in slot.host
                )
            result[shard] = list(unique_clusters)
        return result

    @cached_property_with_ttl(ttl=30)
    def slots(self):
        # type: () -> List[saas_slot.Slot]
        """Flat list of every slot in the service's cluster map, over all shards and locations."""
        cluster_map = self._dm.get_cluster_map(self._ctype, self._service)
        return [
            slot
            for _shard, replicas in iteritems(cluster_map)
            for _location, location_slots in iteritems(replicas)
            for slot in location_slots
        ]

    def get_search_map(self):
        """Fetch the search map of this ctype/service from the deploy manager."""
        ctype, service = self.ctype, self._service
        return self._dm.search_map(ctype, service)

    def get_slots_on_hosts(self, hostlist):
        """
        Select the service slots that live on any of the given hosts.

        :param hostlist: iterable of host names; falsy entries are ignored
        :return: slots whose ``host`` or ``physical_host`` is in ``hostlist``
        """
        wanted_hosts = set(filter(None, hostlist))
        cluster_map = self._dm.get_cluster_map(self._ctype, self._service)
        matched = []
        for _shard, replicas in iteritems(cluster_map):
            for _location, location_slots in iteritems(replicas):
                for slot in location_slots:
                    if slot.host in wanted_hosts or slot.physical_host in wanted_hosts:
                        matched.append(slot)
        return matched

    @cached_property_with_ttl(ttl=30)
    def sla_info(self):
        """SLA record of this ctype/service as returned by the deploy manager."""
        ctype, service = self._ctype, self._service
        return self._dm.get_sla(ctype, service)

    @cached_property_with_ttl(ttl=30)
    def tags_info(self):
        """Tags info of this ctype/service as returned by the deploy manager."""
        ctype, service = self._ctype, self._service
        return self._dm.get_tags_info(ctype, service)

    @property
    def nanny_services_from_tags_info(self):
        return [x.split('@')[0] for x in self.tags_info.get('nanny_services', [])]

    @cached_property_with_ttl(ttl=60)
    def nanny_services(self):
        """Set of nanny services the deploy manager reports for this ctype/service."""
        reported = self._dm.get_nanny_services(self.ctype, self.name)
        return set(reported)

    @cached_property_with_ttl(ttl=30)
    def gencfg_groups(self):
        # type: () -> Set[saas_slot.gencfg.GencfgGroup]
        """Result of ``get_gencfg_groups(names_only=False)``."""
        groups = self.get_gencfg_groups(names_only=False)
        return groups

    @cached_property_with_ttl(ttl=30)
    def gencfg_groups_names(self):
        # type: () -> Set[str]
        """Result of ``get_gencfg_groups(names_only=True)``."""
        group_names = self.get_gencfg_groups(names_only=True)
        return group_names

    def get_gencfg_groups(self, names_only=False):
        # type: (bool) -> (Set[str], Set[saas_slot.gencfg.GencfgGroup])
        all_groups = set()
        for nanny_service in self.nanny_services:
            try:
                if not nanny_service.is_gencfg_allocated():
                    continue
                all_groups |= nanny_service.get_gencfg_groups(no_raise=True, names_only=names_only)
            except nanny.NannyApiError:
                self.LOGGER.critical('Saas service %s%:%s has invalid nanny service %s', self.ctype, self.name, nanny_service.name)
        return all_groups

    def update_tags_info_nanny(self, nanny_services, resolve_containers=True):
        return self._dm.modify_tags_info(self._ctype, self._service, {'nanny_services': ','.join(nanny_services), 'use_container_names': resolve_containers})

    @cached_property_with_ttl(ttl=5)
    def per_dc_search(self):
        """Per-DC search setting of this ctype/service as returned by the deploy manager."""
        ctype, service = self._ctype, self._service
        return self._dm.get_per_dc_search(ctype, service)

    def get_free_slots_in_geo(self, geo):
        """
        Ask the deploy manager for the free slots of this service in one location.

        :param geo: location passed through to the deploy manager
        """
        query = {'ctype': self._ctype, 'service': self._service, 'geo': geo}
        return self._dm.get_free_slots_in_geo(**query)

    def _get_service_replicas(self):

        replicas = {}
        per_dc_search = self.per_dc_search
        cluster_map = self._dm.get_cluster_map(self._ctype, self._service)

        for slot_id, replicas in six.iteritems(cluster_map):
            if per_dc_search:
                for geo, slots in six.iteritems(replicas):
                    replicas['{}@{}'.format(slot_id, geo)] = slots
            else:
                all_slots = []
                for geo, slots in six.iteritems(replicas):
                    all_slots.extend(slots)

                replicas['{}@ALL'.format(slot_id)] = all_slots

        return replicas

    def get_service_sharding(self):

        # FIXME: Detect replicas intersection
        replicas = self._get_service_replicas()
        shards = len(replicas)
        replicas = min([len(replica) for replica in replicas.values()])
        per_dc_search = self.per_dc_search

        return {
            'shards': shards,
            'replicas': replicas,
            'per_dc_search': per_dc_search
        }

    def get_rtyserver_diffconfig_raw(self):
        path = 'configs/{}/rtyserver.diff-{}'.format(self._service, self._service)
        return self._dm.get_storage_file(path=path)

    def sample_info_server(self):

        info_server = None
        for slot in self.slots:  # type: saas_slot.Slot
            try:
                info_server = slot.info_server
                break
            except requests.exceptions.BaseHTTPError:
                continue

        assert info_server is not None, "Could not retrieve info_server for any slot in service {}/{}".format(
            self._ctype,
            self._service
        )

        return info_server

    def has_backup(self):

        info_server = self.sample_info_server()

        # This logic originaly is ported from https://a.yandex-team.ru/arc/trunk/arcadia/junk/anikella/saas_dm/saas_deploy/saas_deploy/templates/service_cluster_table.html?rev=4502118#L282
        # Some improvements are on he way
        has_backup = False
        conf = info_server['config']['Server'][0]['ModulesConfig'][0]
        BACKUP_INDICATOR_FIELDS = ['SyncPath', 'BackupTable']

        if 'DOCFETCHER' not in conf.keys():
            self.LOGGER.debug("No DOCFETCHER module found in %s/%s", self._ctype, self._service)

        for docfetcher in conf['DOCFETCHER']:

            if docfetcher["Enabled"] != '1':
                self.LOGGER.debug("DOCFETCHER module is disabled in %s/%s", self._ctype, self._service)
            else:

                streams = [st for st in docfetcher['Stream'] if st.get('Enabled', '0') != '0']
                if len(streams) > 1:
                    self.LOGGER.debug("More than 1 stream found in %s/%s", self._ctype, self._service)
                    for stream in streams:
                        if stream.get('SyncPath', "") != "":
                            has_backup = True
                            break
                else:
                    for key in BACKUP_INDICATOR_FIELDS:
                        if streams[0].get(key, "") != "":
                            self.LOGGER.debug('Found %s, assuming service %s/%s has backup', key, self._ctype, self._service)
                            has_backup = True
                            break
                    if not has_backup and streams[0].get('SnapshotPath', '') != '':
                        has_backup = (streams[0].get('StreamType', '') == 'Snapshot'
                                      and streams[0].get('ConsumeMode', '') in ('replace', 'hard_replace'))
        return has_backup

    def restore_index_from_backup(self, num_docs_tolerance_factor=0.999, time_interval=120, slot_filter=".*", slot_id_list=None, force_backup=False, degrade_level=1, degrade_level_by_shard=1):
        """
        Clear and re-fill the index of the selected slots, a few slots at a time.

        Every selected slot is handled by its own thread: the index is cleared,
        the slot is shut down (aborted if shutdown fails), then the thread polls
        the slot until its document count reaches the threshold and finally
        sends ``enable_search``. Slots are processed in random order.

        :param num_docs_tolerance_factor: share of the average per-slot document
            count (measured before the restore) a slot must reach to be ready
        :param time_interval: seconds between document count polls; also the
            pause after starting each restore thread
        :param slot_filter: regular expression searched in ``slot.id``
        :param slot_id_list: when given, only slots with these ids are restored
        :param force_backup: skip the ``has_backup()`` check
        :param degrade_level: limit of tracked restore threads over all shards
        :param degrade_level_by_shard: limit of tracked restore threads per shard
            (per shard and location when per-DC search is enabled)
        :raises AssertionError: when the service has no backup and ``force_backup`` is false
        """
        def get_num_docs(slot, retries=1):
            # Returns None when every attempt failed
            for _ in range(retries):
                try:
                    return slot.info_server['docs_in_final_indexes']
                except Exception:
                    logging.exception("Can't obtain num_docs for slot %s", slot.id)
            return None

        def restore_slot(slot, numdoc_treshold, time_interval=120):
            # A slot that did not answer (None) is treated as having no documents;
            # comparing None with an int raises TypeError on Python 3.
            if (get_num_docs(slot, retries=3) or 0) > 0:
                logging.info("Clearing index at slot {}".format(slot.id))
                for _ in range(100):
                    try:
                        slot.clear_index()
                        break
                    except Exception:
                        # best effort: keep retrying, but leave a trace
                        logging.debug("clear_index failed for slot %s, retrying", slot.id, exc_info=True)
                try:
                    slot.shutdown()
                except Exception:
                    slot.abort()
                time.sleep(30)  # Just to avoid 30s caching of get_info_server()
            num_docs = 0
            while num_docs < numdoc_treshold:
                num_docs = get_num_docs(slot) or 0
                logging.debug("%s docs found in slot %s. Waiting for at least %s", num_docs, slot.id, numdoc_treshold)
                time.sleep(time_interval)
            logging.info("Assuming slot %s is ready now", slot.id)
            slot.execute_command('enable_search')

        assert force_backup or self.has_backup(), "Can restore frome backup only services with backup in YT"

        num_docs_by_slot = {}
        for slot in self.slots:  # type: saas_slot.Slot
            try:
                num_docs_by_slot[slot.id] = slot.info_server['docs_in_final_indexes']
            except Exception:
                self.LOGGER.exception("Can't get num docs in slot %s", slot.id)

        avg_num_docs = sum(num_docs_by_slot.values()) / len(num_docs_by_slot)
        self.LOGGER.info("Average documents per instance: %s", avg_num_docs)
        numdoc_treshold = avg_num_docs * num_docs_tolerance_factor

        shuffled_slots = list(self.slots)
        random.shuffle(shuffled_slots)

        restoring_slots_by_shard = {}
        for slot in shuffled_slots:
            if slot.wait_up() and re.search(slot_filter, slot.id) is not None and (slot_id_list is None or slot.id in slot_id_list):

                if self.per_dc_search:
                    shard_id = '{}-{}@{}'.format(slot.shards_min, slot.shards_max, slot.geo)
                else:
                    shard_id = '{}-{}'.format(slot.shards_min, slot.shards_max)
                if shard_id not in restoring_slots_by_shard:
                    restoring_slots_by_shard[shard_id] = []

                self.LOGGER.debug("Current tasks to restore %s", restoring_slots_by_shard)
                if len(restoring_slots_by_shard[shard_id]) >= degrade_level_by_shard:
                    self.LOGGER.info("Current tasks to restore %s", restoring_slots_by_shard)
                    thread_to_wait = restoring_slots_by_shard[shard_id].pop()
                    self.LOGGER.info("Waiting for thread %s", thread_to_wait)
                    thread_to_wait.join()
                if sum(len(x) for x in restoring_slots_by_shard.values()) >= degrade_level:
                    self.LOGGER.info("Current tasks to restore %s", restoring_slots_by_shard)
                    # Take exactly one tracked thread and wait for it. Popping a
                    # thread from every shard while joining only the last one
                    # would drop the others from tracking without waiting.
                    thread_to_wait = None
                    for threads in restoring_slots_by_shard.values():
                        if threads:
                            thread_to_wait = threads.pop()
                            break
                    if thread_to_wait is not None:
                        self.LOGGER.info("Waiting for thread %s", thread_to_wait)
                        thread_to_wait.join()

                t = threading.Thread(
                    target=restore_slot,
                    args=(slot, numdoc_treshold, time_interval),
                    name='restore_{}'.format(slot.id)
                )
                t.start()
                restoring_slots_by_shard[shard_id].append(t)
                self.LOGGER.debug("Current tasks to restore %s", restoring_slots_by_shard)
                time.sleep(time_interval)

        # Wait for everything that is still running
        for shard_threads in restoring_slots_by_shard.values():
            for t in shard_threads:
                self.LOGGER.debug('Waiting for thread %s', t)
                t.join()

    def duplicate_slots(self, slots, new_slots_pool=None):
        # type: (List[saas_slot.Slot], List[saas_slot.Slot]) -> List[saas_slot.Slot]
        """
        :param slots: List of slots to duplicate
        :param new_slots_pool: List of free slots
        :type slots: List[Slot]
        :type new_slots_pool: List[Slot]
        """
        result = []
        for geo, slots in groupby(slots, lambda sl: sl.geo):
            intervals = [s.interval for s in slots]
            free_slots_loc = [fs for fs in new_slots_pool if fs.geo == geo] if new_slots_pool else None
            added_slots = self._dm.allocate_same_slots(ctype=self._ctype, service=self._service, geo=geo, intervals=intervals, new_slots_pool=free_slots_loc)
            for slot in added_slots:
                slot.restart()
                slot.disable_search()
            result.extend(added_slots)
        return result

    def get_release_branch(self):
        info_server = self.sample_info_server()
        svn_root = info_server['Svn_root']

        saas_branch_candidates = re.findall(r'.*branches/saas/([\d.]+)/.*', svn_root)
        if len(saas_branch_candidates) == 1:
            return saas_branch_candidates[0]

        refresh_branch_candidates = re.findall(r'.*/arc/tags/base/stable-([\d.]+)-([\d.]+)/.*', svn_root)
        if len(refresh_branch_candidates) == 1:
            return 'refresh'

        if svn_root == 'svn://arcadia.yandex.ru/arc/trunk/arcadia':
            return 'trunk'

        raise Exception("Can't guess branch from svn_root {}".format(svn_root))

    def clear_dm_cache(self):
        """Ask the deploy manager to clear its cache for this ctype/service."""
        ctype, service = self.ctype, self.name
        self._dm.clear_cache(ctype, service)

    def modify_searchmap(self, slots, action, no_batch=False, **kwargs):
        """
        Apply ``action`` to the slots in the search map, batch by batch.

        :param slots: iterable of slots to process
        :param action: action name passed to the deploy manager
        :param no_batch: send every slot in a request of its own
        :param kwargs: extra arguments forwarded to the deploy manager call
        :return: actual list of processed slots (those whose request succeeded)
        """
        processed = []

        def flush(chunk):
            response = self._dm.modify_searchmap(ctype=self._ctype, service=self._service, slots=chunk, action=action, **kwargs)
            if response.ok:
                processed.extend(chunk)

        pending = []
        for sl in slots:
            pending.append(sl)
            if len(pending) >= SLOT_OPERATION_BATCH_SIZE or no_batch:
                flush(pending)
                pending = []

        if len(pending) > 0:
            flush(pending)

        return processed

    def set_searchmap(self, slots=None, search_enabled=None, indexing_enabled=None, **kwargs):
        actions = []
        if slots is None:
            slots = self.slots
        if search_enabled is not None:
            if search_enabled:
                actions.append('enable_search')
            else:
                actions.append('disable_search')
        if indexing_enabled is not None:
            if indexing_enabled:
                actions.append('enable_indexing')
            else:
                actions.append('disable_indexing')
        for action in actions:
            self.modify_searchmap(slots, action, **kwargs)

    def release_slots(self, slots, no_batch=False):
        """
        Returns actual list of released slots
        """
        batch = []
        result = []
        for sl in slots:
            batch.append(sl)
            if len(batch) >= SLOT_OPERATION_BATCH_SIZE or no_batch:
                if self._dm.release_slots(self._ctype, self._service, slots).ok:
                    result += batch
                batch = []

        if len(batch) > 0:
            if self._dm.release_slots(self._ctype, self._service, slots).ok:
                result += batch

        return result

    def release_gencfg_slots(self):
        gencfg_slots = [sl for sl in self.slots if '.yp-c.' not in sl.id]
        self.release_slots(gencfg_slots)

    def add_endpoint_sets(self, shards_endpoint_sets):
        """
        Register endpointsets for this ctype/service in the deploy manager.

        :param shards_endpoint_sets: passed to the deploy manager unchanged
        :return: whatever the deploy manager call returns
        """
        ctype, service = self.ctype, self.name
        return self._dm.add_endpoint_sets(ctype, service, shards_endpoint_sets)

    def update_pod_labels(self, update_disable_search=False):  # not meant to be here
        """
        Write shard/service/ctype labels to the YP pods of the service.

        Pods found in the search map get ``shard`` set to
        ``<service>_<shard id with dashes as underscores>``; pods from the free
        slots pool get an empty ``shard`` (free slots are skipped for the
        ``prestable`` and ``testing`` ctypes). Only slots whose id contains
        ``yp-c`` are touched.

        :param update_disable_search: also copy ``disable_search`` from the
            search map into the pod labels
        """
        for shard_id, slots_info in iteritems(self.get_search_map()):
            shard_labels = {
                'shard': "{}_{}".format(self.name, shard_id.replace('-', '_')),
                'saas_service': self._service,
                'saas_ctype': self._ctype,
            }
            for slot, slot_info in iteritems(slots_info):
                self.LOGGER.debug('%s searchmap info: %s', slot, slot_info)
                if 'yp-c' not in slot.id:  # FixMe
                    continue
                pod_labels = copy.deepcopy(shard_labels)
                if update_disable_search:
                    pod_labels['disable_search'] = str(slot_info['disable_search'])
                YpLitePod(fqdn=slot.id).update_pod_labels(**pod_labels)

        free_slots_labels = {'shard': '', 'saas_service': self._service, 'saas_ctype': self._ctype}
        for slot in self.free_slots_pool:
            if 'yp-c' not in slot.id or self._ctype in ('prestable', 'testing'):  # FixMe
                continue
            self.LOGGER.debug('Updating free slot %s', slot.id)
            YpLitePod(fqdn=slot.id).update_pod_labels(**free_slots_labels)

    def create_endpointsets(self, do_update_dm=False):
        """
        Make sure the YP Lite nanny services of this service have their endpointsets.

        For every YP cluster of every YP Lite nanny service:

        * an endpointset named after the nanny service, with an empty pod
          filter, is created when it does not exist;
        * for each shard whose clusters (``shards_with_clusters``) include that
          cluster, a per-shard endpointset is created, or its pod filter and
          description are updated when it already exists.

        :param do_update_dm: when true, wait 90 seconds and then register in
            the deploy manager the endpointsets it does not know yet
        """
        self.LOGGER.info('Service sharding: %s', self.shards)
        try:
            nanny_services = self.nanny_services
        except Exception:
            # Fall back to fixed nanny services for the ctypes where one is known
            from saas.tools.devops.lib23.nanny_helpers import NannyService
            if self.ctype == 'testing':
                nanny_services = [NannyService('saas_yp_cloud_base_testing')]
            elif self.ctype == 'prestable':
                nanny_services = [NannyService('saas_yp_cloud_prestable')]
            else:
                raise Exception("Can't guess nanny services for {}/{}".format(self.ctype, self.name))

        # shard id -> endpointsets that still have to be registered in the deploy manager
        dm_shards = {}
        for ns in nanny_services:
            if not ns.is_yp_lite():
                self.LOGGER.info('%s is not YP_LITE', ns)
                continue

            for yp_cluster in ns.yp_clusters:
                all_instances_endpointset_name = ns.name
                all_instances_endpointset = Endpointset.get_endpointset(all_instances_endpointset_name, yp_cluster)
                if all_instances_endpointset is None:
                    ns.create_endpointset(
                        endpointset_name=all_instances_endpointset_name,
                        cluster=yp_cluster,
                        pod_filter="",
                        protocol='tcp',
                        port=80,
                        description='All instances in {}'.format(ns.name)
                    )

                for shard_id, shard_clusters in self.shards_with_clusters.items():
                    if yp_cluster not in shard_clusters:
                        continue
                    if dm_shards.get(shard_id, None) is None:
                        dm_shards[shard_id] = set()
                    exist_in_dm = False
                    yp_shard_label = '{}_{}'.format(self.name, shard_id.replace('-', '_'))
                    pod_filter = '[/labels/shard] = "{}" and [/labels/disable_search] != "True"'.format(yp_shard_label)
                    description = 'Endpointset for SaaS service {}/{}'.format(self.ctype, self.name)
                    endpointset_name = "{}--{}--{}".format(ns.name, self.name, shard_id)

                    # Reuse the name already registered in the deploy manager for this cluster
                    dm_endpoint_sets = self.endpointsets_by_shard.get(shard_id, set())
                    for es in dm_endpoint_sets:
                        self.LOGGER.info('Existing endpointset %s found', es)
                        if yp_cluster == es.cluster:
                            endpointset_name = es.name
                            exist_in_dm = True

                    endpointset = Endpointset.get_endpointset(endpointset_name, yp_cluster)

                    if endpointset is None:

                        self.LOGGER.info("Creating endpointset %s in cluster %s for nanny service %s", endpointset_name, yp_cluster, ns.name)
                        endpointset = ns.create_endpointset(
                            endpointset_name=endpointset_name,
                            cluster=yp_cluster,
                            pod_filter=pod_filter,
                            protocol='tcp',
                            port=80,
                            description=description
                        )
                        self.LOGGER.debug(
                            "ns.create_endpointset (endpointset_name=%s, cluster=%s, pod_filter=%s, protocol='tcp', port=80, description=description); Result=%s",
                            endpointset_name, yp_cluster, pod_filter, endpointset
                        )
                    else:
                        endpointset.update(
                            pod_filter=pod_filter,
                            description=description
                        )

                    if not exist_in_dm:
                        dm_shards[shard_id].add(Endpointset(endpointset_name, yp_cluster))

        if do_update_dm:
            time.sleep(90)  # give time for endpointsets to emerge
            # Register at most `batch_size` shards per request; each request
            # carries only its own shards, never the whole mapping.
            batch_size = 20
            shard_ids = list(dm_shards)
            for start in range(0, len(shard_ids), batch_size):
                batch = {shard: dm_shards[shard] for shard in shard_ids[start:start + batch_size]}
                res = self.add_endpoint_sets(batch)
                if res is not None:
                    self.LOGGER.info('Endpointsets registered in DM: %s', res)
                else:
                    self.LOGGER.error('Adding endpointsets failed')
