import time
import urllib2
import json
import re
import logging
import xmlrpclib
import datetime
import itertools
import operator

from collections import defaultdict
from collections import namedtuple
from multiprocessing.pool import ThreadPool

try:
    import kazoo
except ImportError:
    __import__('pkg_resources').require('zc_zookeeper_static', 'kazoo')
    import kazoo
import kazoo.client
import kazoo.exceptions


# survey() is the only name exported by `from <this module> import *`;
# everything else in the file is a helper for it
__all__ = ['survey']


def inconsistency(incs, type, msg,
        warn=logging.warning,
        **argv):
    '''
        Report a non-lethal error ("inconsistency") and remember it.

        incs -- list the record is appended to
        type -- short category of the problem
        msg -- human readable description; also handed to warn
        warn -- callable used for reporting, logging.warning by default
        argv -- any extra fields to store in the record
            (they win over 'type' and 'message' on a name clash)
    '''
    warn(msg)
    record = {'type': type, 'message': msg}
    for field, value in argv.items():
        record[field] = value
    incs.append(record)


def survey(conf, sections=None):
    '''
        Collect the current state of trie distribution and dump it as json.

        Reads the trie configuration and the robot state from ZooKeeper,
        the list of instances from cms (get_instances) and the tries of every
        instance (zk_get_instance_tries, or fetch_instance_tries when
        conf.direct_instance_inspection is set), then removes objects with
        dangling references and packs what is left.

        conf -- configuration object; attributes read here:
            zk_hosts, zk_conf_path, target, direct_instance_inspection,
            sections ({section name: {'base_instances': ...}}),
            protobuf_dl ({'threadcount': ..., 'get_attempts': ...})
        sections -- names of sections to survey; every section of
            conf.sections when empty or omitted

        returns json dump of {cell_class_name : list for Storage.from_list}
        with two extra keys: 'inconsistencies' (list of non-lethal errors)
        and 'timestamp'

        raises RuntimeError if the ZooKeeper connection can't be established
        in 5 attempts
    '''
    if not sections:
        sections = conf.sections.keys()

    zk = kazoo.client.KazooClient(conf.zk_hosts)
    for i in xrange(5):
        try:
            zk.start()
        except Exception as e:
            time.sleep(3)
            logging.exception(e)
        else:
            break
    else:
        raise RuntimeError('failed to establish zookeeper connection')

    # naming: a__keys lists the keys of cell class a,
    # a__b maps a key of a to the related b (or to a list of them)
    section__keys = []
    trie_class__keys = []
    shard__keys = []
    sharded_class__keys = []
    instance__keys = []
    trie__keys = []

    section__trie_classes = defaultdict(list)
    section__shards = defaultdict(list)
    trie_class__section = {}
    trie_class__sharded_classes = defaultdict(list)
    trie_class__max_age = {}
    trie_class__period = {}
    shard__sharded_classes = defaultdict(list)
    shard__instances = defaultdict(list)
    sharded_class__tries = defaultdict(list)
    sharded_class__prod_trie = {}
    instance__shard = {}
    # instance__tries = {} # returned from function
    trie__sharded_class = {}
    trie__gen_time = {}
    trie__register_time = {}
    trie__accept_time = {}
    trie__size = {}

    _shard__root = {}
    _shortname__shard = {}
    _section__shards_amount = {}
    _config__sharded_classes = []

    incs = []  # list of non-lethal errors ("inconsistencies")

    try:
        # zookeeper config
        zk_conf_json = safe_zk_get(zk, conf.zk_conf_path, 3)
        zk_conf = json.loads(zk_conf_json)

        for section, sect_info in zk_conf['sections'].iteritems():
            if section in sections:
                section__keys.append(section)
                _section__shards_amount[section] = int(sect_info['shards_amount'])

                for n in xrange(int(sect_info['shards_amount'])):
                    shard = (section, n)
                    shard__keys.append(shard)
                    # a single-shard section keeps its state right in the section root
                    _shard__root[shard] = sect_info['root'] + \
                        ('/shard-%s' % n if int(sect_info['shards_amount']) > 1 else '')
                    _shortname = sect_info['root'].rsplit('/', 1)[-1].replace('-beta', '') + '-{0:0>3}'.format(n)
                    _shortname__shard[_shortname] = shard

                    section__shards[section].append(shard)

        for trie_class, tc_data in zk_conf['tries'].iteritems():
            for section in sections:
                if section in tc_data[conf.target]:
                    trie_class__keys.append(trie_class)
                    trie_class__section[trie_class] = section
                    section__trie_classes[section].append(trie_class)
                    trie_class__max_age[trie_class] = age_parser(tc_data['traits']['max_age'])
                    trie_class__period[trie_class] = age_parser(tc_data['traits'].get('period', '0d'))

                    for sharded_class in [((section, n), trie_class) for n in xrange(_section__shards_amount[section])]:
                        _config__sharded_classes.append(sharded_class)

                    # a trie class belongs to the first matching section only
                    break

        # membership is tested for every class on every shard, hence the set
        _config__sharded_class_set = set(_config__sharded_classes)

        # zookeeper
        for shard in shard__keys:
            # production_indexes
            classnames = safe_zk_get_children(zk, _shard__root[shard] + '/robot_state/production_indexes')
            for classname in classnames:
                sharded_class = (shard, classname)
                if sharded_class in _config__sharded_class_set:
                    sharded_class__keys.append(sharded_class)
                    trie_class__sharded_classes[classname].append(sharded_class)
                    shard__sharded_classes[shard].append(sharded_class)
                    nodepath = '/'.join((_shard__root[shard], 'robot_state/production_indexes', classname))
                    prod_trie = safe_zk_get(zk, nodepath, 3)
                    sharded_class__prod_trie[sharded_class] = prod_trie
                else:
                    inconsistency(incs,
                        type='trie class on shard controdicts config',
                        msg='sharded_class {0} found in production_indexes yet is inconsistent with zk config'.format(sharded_class),
                        sharded_class=sharded_class)

            # robot_indexes
            tries = safe_zk_get_children(zk, _shard__root[shard] + '/robot_state/robot_indexes')
            for trie in tries:
                classname = trie_classname(trie)
                if not any(sharded_class in shard__sharded_classes[shard] \
                        for sharded_class in trie_class__sharded_classes[classname]):
                    inconsistency(incs,
                        type='trie off its shard',
                        msg='Trie {0} is off its shard and is ignored'.format(trie))
                    continue
                sharded_class = (shard, classname)
                trie__sharded_class[trie] = sharded_class
                nodepath = _shard__root[shard] + '/robot_state/robot_indexes/' + trie
                trie_info_json = safe_zk_get(zk, nodepath, 3)
                try:
                    # TypeError covers the node that could not be read (None)
                    trie_info = json.loads(trie_info_json)
                except (ValueError, TypeError):
                    inconsistency(incs,
                        type='malformed zookeeper trie info',
                        msg='Trie node {0} has info "{1}". It isn\'t even json'.format(trie, trie_info_json))
                    continue
                trie__keys.append(trie)
                sharded_class__tries[sharded_class].append(trie)
                size = trie_info.get('trie_metadata', {}).get('trie_size', 0)
                gen_time = normalize_time(trie_info.get('trie_metadata', {}).get('Version', 0))
                register_time = trie_info.get('publisher_metadata', {}).get('register_time', 0)
                accept_time = trie_info.get('publisher_metadata', {}).get('accept_time', 0)
                trie__size[trie] = int(size)
                trie__gen_time[trie] = int(gen_time)
                trie__register_time[trie] = int(register_time)
                trie__accept_time[trie] = int(accept_time)

        for sharded_class in _config__sharded_classes:
            if sharded_class not in sharded_class__keys:
                inconsistency(incs,
                    type='trie class absent from shard',
                    msg='trie_class {0} not expanded on shard {1}'.format(sharded_class[1], sharded_class[0]),
                    sharded_class=sharded_class)

        # cms
        for inst_dict in itertools.chain(*(get_instances(sectinfo['base_instances']) \
                for sectname, sectinfo in conf.sections.iteritems() if sectname in section__keys)):
            instance = (inst_dict['host'], inst_dict['port'])
            shard_shortname = inst_dict['shard'].rsplit('-', 1)[0]
            if shard_shortname not in _shortname__shard:
                inconsistency(incs,
                    type='unknown shard of instance',
                    msg='Instance {0} belongs to unknown shard {1} and is ignored'.format(instance, shard_shortname),
                    instance=instance,
                    shard=shard_shortname)
                continue
            instance__keys.append(instance)
            shard = _shortname__shard[shard_shortname]
            shard__instances[shard].append(instance)
            instance__shard[instance] = shard

        # instance protobufs / zookeeper cluster state
        if conf.direct_instance_inspection:
            # deprecated
            # actually, prohibited unless someone thoroughly debugs it
            # in which case it is strongly recommended to refactor and merge duplicate functionality
            #   in fetch_instance_tries and zk_get_instance_tries
            instance__tries = fetch_instance_tries(
                instance__keys,
                conf.protobuf_dl['threadcount'],
                conf.protobuf_dl['get_attempts'],
                incs)
        else:
            instance__tries = zk_get_instance_tries(zk, instance__keys, instance__shard, _shard__root, incs)
    finally:
        # the connection is released even when the survey fails half-way
        zk.stop()

    # consistency enforcement
    # check all references to have actual keys
    # if referred key is not found, either reference or referring object is removed

    enforce_consistency = True

    if enforce_consistency:
        # RefUp: reference to the owner; a dead one removes the referring cell
        RefUp = namedtuple('RefUp', ['cont', 'keys'])

        # RefDown: references to owned cells; a dead one is dropped from the
        # list (with rm(key, ref_key) when the list can't be edited directly)
        RefDownT = namedtuple('RefDown', ['cont', 'keys', 'rm'])
        RefDown = lambda cont, keys, rm=None: RefDownT(cont, keys, rm)

        dead_ref_up = lambda cell_class, cell_key, ref_class, ref_key: \
            inconsistency(incs,
                type='relational consistency fault: up',
                msg='{0} {1} belongs to absent {2} {3} and is removed'.\
                    format(cell_class, cell_key, ref_class, ref_key),
                cell_class=cell_class,
                cell_key=cell_key,
                ref_class=ref_class,
                ref_key=ref_key)

        dead_ref_down = lambda cell_class, cell_key, ref_class, ref_key: \
            inconsistency(incs,
                type='relational consistency fault: down',
                msg='{0} {1} refers to absent {2} {3}'.\
                    format(cell_class, cell_key, ref_class, ref_key),
                cell_class=cell_class,
                cell_key=cell_key,
                ref_class=ref_class,
                ref_key=ref_key)

        # the containers refer to each other, so they are created empty
        # and filled afterwards (this rebinds the `sections` argument,
        # which is not needed any more)
        sections = {}
        shards = {}
        trie_classes = {}
        sharded_classes = {}
        instances = {}
        tries = {}

        containers = sections, shards, trie_classes, sharded_classes, instances, tries

        sections.update(dict(
            cell_class='Section',
            keys=section__keys,
            refs_up=[],
            refs_down=[
                RefDown(shards, section__shards),
                RefDown(trie_classes, section__trie_classes),
            ]))
        shards.update(dict(
            cell_class='Shard',
            keys=shard__keys,
            refs_up=[
                RefUp(sections, lambda key: key[0]),
            ],
            refs_down=[
                RefDown(sharded_classes, shard__sharded_classes),
                RefDown(instances, shard__instances),
            ]))
        trie_classes.update(dict(
            cell_class='TrieClass',
            keys=trie_class__keys,
            refs_up=[
                RefUp(sections, trie_class__section),
            ],
            refs_down=[
                RefDown(sharded_classes, trie_class__sharded_classes),
            ]))
        sharded_classes.update(dict(
            cell_class='ShardedClass',
            keys=sharded_class__keys,
            refs_up=[
                RefUp(shards, lambda key: key[0]),
                RefUp(trie_classes, lambda key: key[1]),
                RefUp(tries, sharded_class__prod_trie),
            ],
            refs_down=[
                RefDown(tries, sharded_class__tries),
            ]))
        instances.update(dict(
            cell_class='Instance',
            keys=instance__keys,
            refs_up=[
                RefUp(shards, instance__shard),
            ],
            refs_down=[
                RefDown(trie_classes, lambda key: instance__tries[key].keys(),
                    lambda key, ref_key: operator.delitem(instance__tries[key], ref_key)),
                RefDown(tries, lambda key: instance__tries[key].values(),
                    lambda key, ref_key: del_by_value(instance__tries[key], ref_key)),
            ]))
        tries.update(dict(
            cell_class='Trie',
            keys=trie__keys,
            refs_up=[
                RefUp(sharded_classes, trie__sharded_class),
            ],
            refs_down=[]))

        # removal of a cell may orphan other cells, so sweep until stable
        while True:
            removed = False

            for cont in containers:
                to_remove = []
                for key in cont['keys']:
                    for ref in cont['refs_up']:
                        if isinstance(ref.keys, dict):
                            ref_key = ref.keys[key]
                        else:
                            ref_key = ref.keys(key)
                        if ref_key not in ref.cont['keys']:
                            if not to_remove or to_remove[-1] != key:
                                to_remove.append(key)
                            dead_ref_up(cont['cell_class'], key, ref.cont['cell_class'], ref_key)
                    if not to_remove or to_remove[-1] != key:
                        for ref in cont['refs_down']:
                            if isinstance(ref.keys, dict):
                                ref_keys = ref.keys[key]
                                if not ref_keys:
                                    # a cell with several empty references must be
                                    # queued once: a second remove() of the same key
                                    # would raise ValueError
                                    if not to_remove or to_remove[-1] != key:
                                        to_remove.append(key)
                                    # TODO: write down inconsistency here. if anyone is going to use them, that is
                                    # raise StandardError('{0} not found in {1}...'.format(str(key), str(ref.keys)[:200]))
                            else:
                                ref_keys = ref.keys(key)
                            # walk a snapshot: dead references are removed from
                            # ref_keys on the way, which would make the loop
                            # skip the element following each removed one
                            for ref_key in list(ref_keys):
                                if ref_key not in ref.cont['keys']:
                                    if ref.rm:
                                        ref.rm(key, ref_key)
                                    else:
                                        ref_keys.remove(ref_key)
                                    dead_ref_down(cont['cell_class'], key, ref.cont['cell_class'], ref_key)
                for key in to_remove:
                    cont['keys'].remove(key)
                removed |= bool(to_remove)

            if not removed:
                break

    # tuple keys are replaced with their positions in the key lists
    shard_scalar_keys = dict([(vk, sk) for sk, vk in enumerate(shard__keys)])
    sharded_class_scalar_keys = dict([(vk, sk) for sk, vk in enumerate(sharded_class__keys)])
    instance_scalar_keys = dict([(vk, sk) for sk, vk in enumerate(instance__keys)])

    shard_scalar = lambda vk: shard_scalar_keys.get(vk)
    sharded_class_scalar = lambda vk: sharded_class_scalar_keys.get(vk)
    instance_scalar = lambda vk: instance_scalar_keys.get(vk)
    # only unknown keys (None) are dropped; scalar key 0 is valid and is kept
    keepmap = lambda f, xs: [sk for sk in itertools.imap(f, xs) if sk is not None]

    res = {}
    res['Section'] = [dict(
        section=section_key,
        trie_class_keys=section__trie_classes[section_key],
        shard_keys=keepmap(shard_scalar, section__shards[section_key])
        ) for section_key in section__keys]
    logging.info('Sections packed')

    res['TrieClass'] = [dict(
        trie_class=trie_class_key,
        section_key=trie_class__section[trie_class_key],
        sharded_class_keys=keepmap(sharded_class_scalar, trie_class__sharded_classes[trie_class_key]),
        max_age=trie_class__max_age[trie_class_key],
        period=trie_class__period[trie_class_key]
        ) for trie_class_key in trie_class__keys]
    logging.info('Trie Classes packed')

    res['Shard'] = [dict(
        key=shard_scalar_keys[shard_key],
        section_key=shard_key[0],
        number=shard_key[1],
        sharded_class_keys=keepmap(sharded_class_scalar, shard__sharded_classes[shard_key]),
        instance_keys=keepmap(instance_scalar, shard__instances[shard_key])
        ) for shard_key in shard__keys]
    logging.info('Shards packed')

    res['ShardedClass'] = [dict(
        key=sharded_class_scalar(sharded_class_key),
        shard_key=shard_scalar(sharded_class_key[0]),
        trie_class_key=sharded_class_key[1],
        trie_keys=sharded_class__tries[sharded_class_key],
        prod_trie_keys=sharded_class__prod_trie[sharded_class_key]
        ) for sharded_class_key in sharded_class__keys]
    logging.info('Sharded Classes packed')

    res['Instance'] = [dict(
        key=instance_scalar(instance_key),
        host=instance_key[0],
        port=instance_key[1],
        shard_key=shard_scalar(instance__shard[instance_key]),
        trie_keys=instance__tries[instance_key]
        ) for instance_key in instance__keys]
    logging.info('Instances packed')

    try:
        res['Trie'] = [dict(
            trie=trie_key,
            sharded_class_key=sharded_class_scalar(trie__sharded_class[trie_key]),
            gen_time=trie__gen_time[trie_key],
            register_time=trie__register_time[trie_key],
            accept_time=trie__accept_time[trie_key],
            size=trie__size[trie_key]
            ) for trie_key in trie__keys]
    except KeyError as e:
        # str(e) is the repr of the key (quoted), which is never found below
        trie = e.args[0] if e.args else str(e)
        logging.error(('FAIL INFO\ntrie = {0}\ntrie__sharded_class = {1}\ntrie__gen_time = {2}' +\
                        '\ntrie__register_time = {3}\ntrie__accept_time = {4}\ntrie__size = {5}').\
                        format(trie, trie__sharded_class.get(trie, 'EMPTY'), trie__gen_time.get(trie, 'EMPTY'),
                        trie__register_time.get(trie, 'EMPTY'), trie__accept_time.get(trie, 'EMPTY'),
                        trie__size.get(trie, 'EMPTY')))
        raise

    logging.info('Tries packed')

    res['inconsistencies'] = incs

    res['timestamp'] = int(time.time())

    return json.dumps(res, indent=4)


def del_by_value(dct, val):
    '''
        Remove from dct, in place, every item whose value equals val.

        dct -- dict to clean up
        val -- value to get rid of
    '''
    # keys are collected first: deleting while walking dct.items() is safe
    # only where items() returns a detached list
    doomed = [key for key, value in dct.items() if value == val]
    for key in doomed:
        del dct[key]


def get_instances(section_cfg, server=xmlrpclib.ServerProxy('http://cmsearch.yandex.ru/xmlrpc/bs')):
    '''
        Ask cms for search instances.

        section_cfg -- sequence unpacked into positional arguments of
            listSearchInstancesIntersectTags
        server -- xmlrpc proxy to ask; the default one is created once,
            when the function is defined

        returns whatever the remote call returns
    '''
    list_instances = server.listSearchInstancesIntersectTags
    return list_instances(*section_cfg)


def age_parser(age, time_units={
        's': 1,
        'm': 60,
        'h': 3600,
        'd': 86400,
    }):
    '''
        Convert age like '3d' to seconds.

        age -- integer followed by one unit letter
        time_units -- {unit letter: seconds in the unit}

        raises ValueError for a non-integer amount, KeyError for unknown unit
    '''
    # amount is parsed before the unit is looked at
    amount = int(age[:-1])
    seconds_per_unit = time_units[age[-1]]
    return amount * seconds_per_unit


def trie_classname(trie):
    '''
        Extract trie class name from trie name.

        The class name is the shortest piece of the name that contains
        a dot and is followed by '-', e.g. 'foo.bar-123' -> 'foo.bar'

        raises AttributeError if the name has no such piece
    '''
    found = re.search(r'.*?\..*?(?=\-)', trie)
    return found.group()


def normalize_time(version):
    '''
        Convert trie version to integer unix time.

        version -- number or string, treated by the length of its str():
            more than 10 characters -- the first 10 are taken
                (presumably a timestamp of finer than second precision)
            10 characters -- unix time as is
            8 characters -- date YYYYMMDD, taken as 04:00 local time of that day
            '0' -- no version, 0

        returns int in every case

        raises ValueError for anything else
    '''
    version = str(version)
    length = len(version)
    if length > 10:
        return int(version[:10])
    if length == 10:
        return int(version)
    if length == 8:
        day = datetime.datetime(int(version[:4]), int(version[4:6]), int(version[6:8]), 4, 0)
        # mktime is the portable spelling of strftime('%s') (a glibc extension)
        # and gives a number like the other branches, not a string
        return int(time.mktime(day.timetuple()))
    if version == '0':
        return 0
    raise ValueError(version)


def zk_get_instance_tries(zk, instance__keys, instance__shard, _shard__root, incs):
    '''
        Read from ZooKeeper cluster state which tries every instance serves.

        zk -- kazoo client handed to safe_zk_get / safe_zk_get_children
        instance__keys -- list of instances, (host, port) each;
            MODIFIED IN PLACE: instances without 'alive' node, without
            readable stats or with malformed stats are removed
        instance__shard -- {instance: shard}
        _shard__root -- {shard: zookeeper path of the shard}
        incs -- list of inconsistencies, appended to

        returns {instance: {trie_class: trie}}; when an instance reports
        several tries of a class, the last one listed is taken
    '''
    instance__tries = {}
    remove_keys = []
    for instance in instance__keys:
        shard = instance__shard[instance]
        inst_node = '/'.join((_shard__root[shard], 'cluster_state', instance[0] + '.yandex.ru'))
        if 'alive' not in safe_zk_get_children(zk, inst_node):
            inconsistency(incs,
                type='dead instance',
                msg='Instance {0} of shard {1} lacks \'alive\' tag in ZooKeeper'.\
                    format(instance, shard),
                instance=instance,
                shard=shard)
            remove_keys.append(instance)
            continue

        try:
            for _ in range(3):
                inst_node_data = safe_zk_get(zk, inst_node)
                if inst_node_data:
                    break
            if not inst_node_data:
                inconsistency(incs,
                    type='node inaccessible',
                    msg="Can't access node {0} in zookeeper".format(inst_node))
                # the instance is dropped by the final sweep below
                continue
            instance_stats = json.loads(inst_node_data)['basesearch_state']
            # explicit checks instead of asserts, which vanish under -O
            if not instance_stats or 'error' in instance_stats:
                raise ValueError('empty or erroneous basesearch_state')
            stats = {}
            for trie_class, tries in instance_stats.iteritems():
                if len(tries) == 1:
                    stats[trie_class] = tries[0]
                else:
                    inconsistency(incs,
                        type='multiple tries on instance',
                        msg='Multiple tries of class {0} on instance {1} of shard {2}'.\
                            format(trie_class, instance, shard),
                        instance=instance,
                        shard=shard)
                    stats[trie_class] = tries[-1]
            instance__tries[instance] = stats
        except Exception:
            # any defect of the stats makes the instance unusable, but
            # KeyboardInterrupt / SystemExit are no longer swallowed here
            inconsistency(incs,
                type='empty or erraneous instance stats',
                msg='Malformed zookeeper stats for instance {0} on shard {1}'.\
                    format(instance, shard),
                instance=instance,
                shard=shard)
            remove_keys.append(instance)

    for inst in remove_keys:
        instance__keys.remove(inst)
    # whatever is left without tries info (inaccessible nodes) goes away too
    remove_keys = []
    for inst in instance__keys:
        if inst not in instance__tries:
            logging.warning(
                'for not entirely comprehensible reason instance {0} is not associated with any tries info'.\
                format(inst))
            remove_keys.append(inst)
    logging.info('instance__keys {0}'.format(instance__keys))
    logging.info('remove_keys {0}'.format(remove_keys))
    for inst in remove_keys:
        instance__keys.remove(inst)
    return instance__tries


def fetch_instance_tries(instance__keys, threadcount, attempts, incs):
    '''
        normally shouldn't be used; is activated by direct_instance_inspection configuration option
        requires debugging to be used safely

        Currently always raises NotImplementedError; the code after the
        raise is kept for whoever is going to revive it.

        instance__keys -- list of (host, port); dead instances are removed in place
        threadcount -- size of the download thread pool
        attempts -- passed to instance_tries
        incs -- list of inconsistencies, appended to

        returns {instance: {trie_class: trie}}
    '''
    raise NotImplementedError()
    # ---- unreachable on purpose, see the docstring ----
    remove_keys = []
    logging.info('{0} instances reported by cms'.format(len(instance__keys)))
    instance__tries = {}
    tp = ThreadPool(threadcount)
    try:
        _t = tp.map(lambda inst: (inst, instance_tries(*inst, attempts=attempts)), instance__keys)
    finally:
        # worker threads are not left behind
        tp.close()
        tp.join()
    logging.info('{0} results at imap_unordered'.format(len(_t)))
    for instance, instance_stats in _t:
        if instance_stats:
            stats = {}
            for trie_class, tries in instance_stats.iteritems():
                if len(tries) == 1:
                    stats[trie_class] = tries[0]
                else:
                    inconsistency(incs,
                        type='multiple tries',
                        msg='Multiple tries of class {0} on instance {1}'.format(trie_class, instance))
                    stats[trie_class] = tries[-1]
            instance__tries[instance] = stats
        else:
            inconsistency(incs,
                type='dead instance',
                msg='Failed to get stats from instance {0}'.format(instance))
            remove_keys.append(instance)
    for inst in remove_keys:
        # was remove(instance): the leftover variable of the loop above
        instance__keys.remove(inst)
    logging.info('{0} instances confirmed'.format(len(instance__keys)))
    logging.info('{0} records in instance__tries'.format(len(instance__tries)))
    logging.info('DARK STATS:\n' + '\n'.join(('{0} {1}'.format(inst, stats) for inst, stats in _t if inst not in instance__tries)))
    return instance__tries


def instance_tries(host, port, attempts, timeout=5, raw=False):
    '''
        currently unused: data is already in zookeeper
        keeping here as a backup of a kind and for debugging purposes

        Download querydata stats of an instance and list its tries.

        host, port -- instance address ('.yandex.ru' is appended to host)
        attempts -- how many times to download and parse before giving up
        timeout -- passed to urllib2.urlopen
        raw -- return the parsed stats json as is

        returns the first non-empty result among the attempts:
            {trie_class: [tries in descending order]} normally,
            the stats json if raw,
            {'stats': ..., 'error': ...} if the stats have unexpected structure;
        {} if every attempt came out empty
    '''

    def get_instance_protobuf(host, port):
        '''
            str(instance) --> querysearch stats protobuf
            return None on any problems
        '''
        url = 'http://{0}.yandex.ru:{1}/yandsearch?text=1&hr=da&ms=proto&rearr=get_querydata_stats'.\
            format(host, port)
        try:
            return urllib2.urlopen(url, timeout=timeout).read()
        except Exception:  # sorry...
            return None

    def collect_tries(entries, res):
        '''
            add error-free tries of entries ({trie_class: [trie info]}) to res
        '''
        for trie_class, trie_list in entries.iteritems():
            for trie_info in trie_list:
                if 'errors' in trie_info and trie_info['errors']:
                    continue
                # trie name is the directory of the trie file
                trie = trie_info['file']['real'].split('/')[-2]
                # there used to be a guard `trie > res.get(trie_class)` here:
                # string against None or list, i.e. always true
                res.setdefault(trie_class, []).append(trie)

    def parse_instance_protobuf(protobuf,
            inst_name='unspecified',
            log_malformed_protobufs=True,
            alert=False,
            raw=False):
        '''
            returns {trie_class: [trie(s) from newest to oldest]}
        '''
        if not protobuf:
            if alert:
                logging.warning('Instance {0} seems unwilling to show its protobuf'.format(inst_name))
            return {}

        try:
            json_querydata_stats = re.search(
                re.compile(r'(?<=Key: \"QueryData\.stats\.nodump\"\n  Value: \").*?[^\\](?=\")',
                    re.S | re.M),
                protobuf).group().replace('\\', '')
            stats = json.loads(json_querydata_stats)
        except Exception:
            if alert:
                logging.error('Malformed protobuf at instance {0}'.format(inst_name))
                if log_malformed_protobufs:
                    with open('malformed_protobufs', 'a') as f:
                        f.write('\n'.join(('=' * 80, inst_name, str(datetime.datetime.now()),
                            protobuf, '\n\n\n')))
            return {}
        else:
            logging.info('Protobuf @ {0} ok'.format(inst_name))
        if raw:
            return stats
        res = {}
        try:
            if stats['sources']:
                # tries listen in config
                collect_tries(stats['sources'], res)
            else:
                # tries listed in lists
                for conf_list in stats['lists'].itervalues():
                    for sublist in conf_list:
                        collect_tries(sublist['entries'], res)
            res = dict((trie_class, sorted(tries, reverse=True))
                for trie_class, tries in res.iteritems())
        except Exception as e:
            # truthy on purpose: the caller gets the reason instead of {}
            res = {'stats': json_querydata_stats, 'error': str(e)}
        return res

    for n in range(attempts):
        # complain only if the last attempt fails, too
        res = parse_instance_protobuf(
            get_instance_protobuf(host, port), '%s:%s' % (host, port), alert=(n == attempts-1),
            raw=raw)
        if res:
            return res
    return {}


def safe_zk_get(zk, path, attempts=3):
    '''
        get node info with automatic reask
        wrapper for kazoo handle

        zk -- kazoo client
        path -- path of the node
        attempts -- how many times to ask; a second passes between attempts

        returns the first element of what zk.get returns (node data),
        None if the node is absent at every attempt

        only NoNodeError is retried, other errors propagate at once
    '''
    for attempt in range(attempts):
        try:
            res = zk.get(path)
        except kazoo.exceptions.NoNodeError:
            # no point in waiting when nothing is going to be asked afterwards
            if attempt < attempts - 1:
                time.sleep(1)
        else:
            return res[0] if res else None
    return None


def safe_zk_get_children(zk, path, attempts=3):
    '''
        get names of node children with automatic reask
        wrapper for kazoo handle

        zk -- kazoo client
        path -- path of the node
        attempts -- how many times to ask; a second passes between attempts

        returns what zk.get_children returns; an empty list (with a warning
        in the log) if the node is absent at every attempt

        only NoNodeError is retried, other errors propagate at once
    '''
    # the fallback is a list, like the value on success (it used to be a dict)
    res = []
    for attempt in range(attempts):
        try:
            res = zk.get_children(path)
        except kazoo.exceptions.NoNodeError:
            # no point in waiting when nothing is going to be asked afterwards
            if attempt < attempts - 1:
                time.sleep(1)
        else:
            break
    if not res:
        logging.warning('Node {0} is inaccessible'.format(path))
    return res


def restore(foo):
    '''
        lists to tuples
        currently unused; keeping it for sentimental reasons

        Decorator: the result of foo has every list and tuple turned
        into a tuple, nested ones included; anything else (contents of
        dicts as well) is left alone.
    '''
    def tuplify(value):
        if not isinstance(value, (list, tuple)):
            return value
        return tuple(tuplify(item) for item in value)

    def resfoo(*args, **kwargs):
        return tuplify(foo(*args, **kwargs))

    return resfoo
