import logging
import time

import weakref
import json

from . import panopticon_util as _util


class BaseCell(object):
    '''
        Base class for the entities ("cells") stored in a snapshot.

        A cell is created with references (BaseCell.Ref) to other cells and
        swaps them for weak proxies of the real cells when connect() is called.
        The identity() method of every class is exposed as a read-only
        property by the CellMeta metaclass.
    '''

    class CellMeta(type):
        ''' Metaclass that turns the ``identity`` method of a class into a property. '''

        def __new__(cls, name, bases, dct):
            # Wrap only an ``identity`` defined in this very class body:
            # a subclass without its own ``identity`` inherits the parent's
            # property, and something that already is a property must not be
            # wrapped a second time (that would make it unusable).
            identity = dct.get('identity')
            if identity is not None and not isinstance(identity, property):
                dct['identity'] = property(identity)
            return type.__new__(cls, name, bases, dct)

    def identity(self):
        ''' Human-readable identity of the cell; subclasses must override it. '''
        raise NotImplementedError()

    __metaclass__ = CellMeta

    class Ref(object):
        ''' Lazy reference to a cell: storage key + cell class + owning snapshot. '''

        def __init__(self, cell_key, cell_class, snapshot):
            self.cell_key = cell_key
            self.cell_class = cell_class
            self.snapshot = snapshot

        def actualize(self):
            '''
                Return a weak proxy of the referenced cell.

                The cell is looked up in ``snapshot.storages`` by the class
                name and the key. Any failure (no such storage, no such cell)
                is logged and re-raised.
            '''
            try:
                storage = self.snapshot.storages[self.cell_class.__name__]
                # weakref.proxy(None) raises TypeError, so a missing cell
                # ends up in the handler below as well.
                return weakref.proxy(storage.get(self.cell_key))
            except Exception:
                logging.error('Dead reference to cell of class {0} with key {1}'.\
                    format(self.cell_class.__name__, self.cell_key))
                raise

        @property
        def key(self):
            ''' Key of the referenced cell. '''
            return self.cell_key

    def ref(self, name, cell_class):
        ''' Make a reference to the cell of ``cell_class`` with key ``name``. '''
        return BaseCell.Ref(name, cell_class, self.snapshot)

    def ref_list(self, names, cell_class):
        ''' Make a list of references, one per key in ``names``. '''
        return [BaseCell.Ref(name, cell_class, self.snapshot) for name in names]

    def ref_dict(self, _dict, cell_class):
        ''' Make a dict with the keys of ``_dict`` and references built from its values. '''
        return dict((k, BaseCell.Ref(v, cell_class, self.snapshot)) for k, v in _dict.items())

    def __init__(self, snapshot, **kwargs):
        ''' Remember the owning snapshot; extra keyword arguments are ignored. '''
        self.snapshot = snapshot

        self.__connected = False

    @staticmethod
    def _resolve(value, log_format, *log_args):
        '''
            Return the actualized cell if ``value`` is a reference.

            Anything that is not a reference is returned untouched. A
            reference that cannot be actualized is logged and returned as is
            (best effort: one dead reference must not break the whole cell).
        '''
        if not isinstance(value, BaseCell.Ref):
            return value
        try:
            return value.actualize()
        except Exception:
            logging.error(log_format, *log_args)
            return value

    def connect(self):
        '''
            Replace references held by this cell with proxies of real cells.

            Handles attributes that are a reference, a list containing
            references or a dict whose values contain references. May be
            called only once per cell.
        '''
        assert not self.__connected
        # Iterate over copies: the containers are updated during the walk.
        for key, value in list(self.__dict__.items()):
            if isinstance(value, BaseCell.Ref):
                self.__dict__[key] = self._resolve(
                    value, "could not actualize key '%s'", key)
            elif isinstance(value, list):
                for i, x in enumerate(value):
                    value[i] = self._resolve(
                        x, "could not actualize key '%s', index '%d'", key, i)
            elif isinstance(value, dict):
                for k, v in list(value.items()):
                    value[k] = self._resolve(
                        v, "could not actualize key '%s', subkey '%s'", key, k)
        self.__connected = True


class CellStorage(object):
    ''' Collection of the cells of one class in one snapshot, indexed by cell key. '''

    def __init__(self, cell_class, snapshot):
        self.snapshot = snapshot
        self.cell_class = cell_class
        self.__connected = False
        self.__dict = {}

    def add(self, *args, **kwargs):
        '''
            Build a cell of the storage's class and register it under its key.

            The snapshot is passed to the cell constructor as a keyword
            argument next to the given arguments. A construction failure is
            logged and re-raised; the key must not be registered yet.
        '''
        try:
            new_cell = self.cell_class(*args, snapshot=self.snapshot, **kwargs)
        except:
            message = "Couldn't create cell of class {0} from args:\n{1} {2}"
            logging.error(message.format(self.cell_class.__name__, str(args), str(kwargs)))
            raise
        cells = self.__dict
        assert new_cell.key not in cells
        cells[new_cell.key] = new_cell

    def get(self, cell_key):
        ''' Return the cell stored under ``cell_key`` or None if there is none. '''
        cells = self.__dict
        return cells.get(cell_key)

    def __iter__(self):
        ''' Iterate over the stored cells (not over their keys). '''
        cells = self.__dict
        return cells.itervalues()

    def connect(self):
        ''' Connect every stored cell; may be called only once per storage. '''
        assert not self.__connected
        for member in self.__dict.itervalues():
            if not member:
                continue
            member.connect()
        self.__connected = True

    @staticmethod
    def from_list(_list, cell_class, snapshot):
        '''
            Create a storage from a list of dicts.

            Every dict is unpacked into the keyword arguments of one add() call.
        '''
        assert isinstance(_list, list)
        storage = CellStorage(cell_class, snapshot)
        for cell_kwargs in _list:
            storage.add(**cell_kwargs)
        return storage


class Section(BaseCell):
    ''' Cell of a section; stored under and identified by the section name. '''

    def __init__(self, section, trie_class_keys, shard_keys, snapshot):
        super(Section, self).__init__(snapshot)

        # The section name doubles as the storage key.
        self.key = section
        self.section = section

        # Lists of references; connect() swaps them for the real cells.
        self.trie_classes = self.ref_list(trie_class_keys, TrieClass)
        self.shards = self.ref_list(shard_keys, Shard)

    def identity(self):
        ''' The section name. '''
        return self.section


class TrieClass(BaseCell):
    ''' Cell of a trie class; stored under and identified by the trie class name. '''

    def __init__(self, trie_class, section_key, sharded_class_keys, snapshot, max_age, period):
        super(TrieClass, self).__init__(snapshot)

        # The trie class name doubles as the storage key.
        self.key = trie_class
        self.trie_class = trie_class

        # Plain values taken from the survey as is.
        # NOTE(review): max_age is compared with ages computed from
        # time.time(), so it is presumably in seconds - confirm.
        self.max_age = max_age
        self.period = period

        # References; connect() swaps them for the real cells.
        self.section = self.ref(section_key, Section)
        self.sharded_classes = self.ref_list(sharded_class_keys, ShardedClass)

    def identity(self):
        ''' The trie class name. '''
        return self.trie_class


class Shard(BaseCell):
    ''' Cell of a shard: a numbered part of a section. '''

    def __init__(self, key, section_key, number, sharded_class_keys, instance_keys, snapshot):
        super(Shard, self).__init__(snapshot)

        self.key = key
        self.number = number

        # References; connect() swaps them for the real cells.
        self.section = self.ref(section_key, Section)
        self.sharded_classes = self.ref_list(sharded_class_keys, ShardedClass)
        self.instances = self.ref_list(instance_keys, Instance)

    def identity(self):
        '''
            "<section identity>-<number>".

            Works only after connect(): until then ``section`` is a bare
            reference without an identity.
        '''
        section_identity = self.section.identity
        return '{0}-{1}'.format(section_identity, self.number)


class ShardedClass(BaseCell):
    ''' Cell that ties one shard to one trie class and to the tries of that pair. '''

    def __init__(self, key, shard_key, trie_class_key, trie_keys, prod_trie_keys, snapshot):
        super(ShardedClass, self).__init__(snapshot)

        self.key = key

        # References; connect() swaps them for the real cells.
        self.shard = self.ref(shard_key, Shard)
        self.trie_class = self.ref(trie_class_key, TrieClass)
        self.tries = self.ref_list(trie_keys, Trie)
        # NOTE(review): despite the plural parameter name this builds a single
        # reference, so prod_trie_keys is presumably one trie key - confirm.
        self.prod_trie = self.ref(prod_trie_keys, Trie)

    def identity(self):
        '''
            "<shard identity>:<trie class identity>".

            Works only after connect(): until then ``shard`` and
            ``trie_class`` are bare references without an identity.
        '''
        shard_identity = self.shard.identity
        trie_class_identity = self.trie_class.identity
        return '{0}:{1}'.format(shard_identity, trie_class_identity)


class Instance(BaseCell):
    ''' Cell of an instance, identified by its host and port. '''

    def __init__(self, key, host, port, shard_key, trie_keys, snapshot):
        super(Instance, self).__init__(snapshot)

        self.key = key
        self.host = host
        self.port = port

        # References; connect() swaps them for the real cells.
        self.shard = self.ref(shard_key, Shard)
        # trie_keys is a dict: its keys are kept, its values become references.
        self.tries = self.ref_dict(trie_keys, Trie)

    def identity(self):
        ''' "<host>:<port>". '''
        host, port = self.host, self.port
        return '{0}:{1}'.format(host, port)


class Trie(BaseCell):
    ''' Cell of a single trie; stored under and identified by the trie name. '''

    def __init__(self, trie, sharded_class_key, gen_time, register_time, accept_time, size, snapshot):
        super(Trie, self).__init__(snapshot)

        # The trie name doubles as the storage key.
        self.key = trie
        self.trie = trie

        # Plain values taken from the survey as is.
        self.size = size
        self.gen_time = gen_time  # = Version ~ stamp_in_name
        self.register_time = register_time
        self.accept_time = accept_time

        # Reference; connect() swaps it for the real cell.
        self.sharded_class = self.ref(sharded_class_key, ShardedClass)

    def identity(self):
        ''' The trie name. '''
        return self.trie


class Snapshot(object):
    '''
        Object model of one survey.

        Holds one CellStorage per cell class (see cell_classes) in
        ``storages``, keyed by the class name, with all cells connected.
    '''

    cell_classes = Section, TrieClass, Shard, ShardedClass, Instance, Trie

    def __init__(self, survey=None, conf=None, sections=None, consistency_check=None):
        '''
            conf is imported python module with configuration
                (see conf/panopticon_conf or something)
            sections is explicit list of sections of interest
                (leave None to rely on conf)
            survey is json survey result
                (in case you want to make survey and construct snapshot object separately)
            consistency_check is callable for, well, consistency check on survey;
                it gets the 'inconsistencies' entry of the survey and returns
                an object with ``passed`` and ``report`` attributes
                (CheckConsistency fits)

            raises ConsistencyFault with the report if the check does not pass
        '''
        if not survey:
            assert conf
            survey = Snapshot.survey(conf, sections)
        data = json.loads(survey)
        logging.info('Survey has timestamp {0}'.format(data['timestamp']))

        if consistency_check:
            check_result = consistency_check(data['inconsistencies'])
            if not check_result.passed:
                raise ConsistencyFault(check_result.report)

        # The survey holds one list of cell descriptions per cell class name.
        self.storages = dict(
            (cl.__name__, CellStorage.from_list(data[cl.__name__], cl, self))
            for cl in Snapshot.cell_classes)
        # Connect only when every storage exists: references cross storages.
        for storage in self.storages.values():
            storage.connect()

    @staticmethod
    def survey(conf, sections=None):
        ''' Make a survey by means of panopticon_util.survey and return its result. '''
        return _util.survey(conf, sections)

    def test_check(self):
        ''' not implemented, does nothing '''
        pass

    def deploy_check(self):
        ''' tries in production: according zk vs on tries (not implemented) '''
        pass

    def freshness_check(self):
        '''
            gen_time vs now - max_age
                for trie in prod (by zk)
            (not implemented)
        '''
        pass

    def _ages_on_instances(self, trie_class, now):
        '''
            Collect ages of the tries of trie_class found on instances.

            return {shard identity: {instance identity: now - gen_time}}
            Instances without a trie for the class are logged and left out.
        '''
        ages = {}
        for sharded_class in trie_class.sharded_classes:
            shard = sharded_class.shard
            shard_ages = ages[shard.identity] = {}
            for inst in shard.instances:
                try:
                    shard_ages[inst.identity] = now - inst.tries[trie_class.key].gen_time
                except KeyError as e:
                    logging.warning('instance {0} shard {1} absent trie {2}'.format(inst.identity, shard.identity, trie_class.identity))
                    logging.exception(e)
        return ages

    def cover_ass_check(self, obsolete_only=True, detailed=False):
        '''
            gen_time vs now - max_age
                for trie in prod (on insts)
            obsolete_only: report only trie classes with obsolete instances
            detailed: add per instance details to every reported trie class
            return {section: {trie_class: {
                age_threshold: int,
                max_age_on_insts: int,
                avg_age_on_insts: int,
                relative_max_age: float,
                relative_avg_age: float,
                # with detailed=True also:
                obsolete_insts: {inst: {shard, age, overage}},
                obsolete: (obsolete insts count, total insts count),
                threshold: int,
                period: trie_class.period
            }}}
            Trie classes found on no instance at all are logged and skipped.
        '''
        res = {}
        now = int(time.time())
        for trie_class in self.storages['TrieClass']:
            ages = self._ages_on_instances(trie_class, now)

            # Flatten before aggregating: one shard without ages must not
            # hide the ages known for the other shards.
            all_ages = [_age for insts in ages.values() for _age in insts.values()]
            if not all_ages:
                logging.warning('{0} seems to be utterly absent'.format(trie_class.identity))
                continue
            max_age_on_insts = max(all_ages)
            # Plain "/" on purpose: the original arithmetic is kept as is.
            avg_age_on_insts = sum(all_ages) / len(all_ages)

            obsolete_insts = {}
            for shard, insts in ages.items():
                for inst, _age in insts.items():
                    if _age > trie_class.max_age:
                        overage = _age - trie_class.max_age
                        # 'overgae' is the historical misspelling of the key;
                        # it stays as an alias for existing consumers.
                        obsolete_insts[inst] = dict(
                            shard=shard, age=_age, overage=overage, overgae=overage)

            if obsolete_only and not obsolete_insts:
                continue
            _el = res.setdefault(trie_class.section.identity, {})
            _el[trie_class.identity] = dict(
                age_threshold=trie_class.max_age,
                max_age_on_insts=max_age_on_insts,
                avg_age_on_insts=avg_age_on_insts,
                relative_max_age=float(max_age_on_insts) / trie_class.max_age,
                relative_avg_age=float(avg_age_on_insts) / trie_class.max_age)
            if detailed:
                _el[trie_class.identity].update(dict(
                    obsolete_insts=obsolete_insts,
                    obsolete=(len(obsolete_insts), len(all_ages)),
                    threshold=trie_class.max_age,
                    period=trie_class.period))
        return res


class ConsistencyFault(StandardError):
    ''' Raised by Snapshot when the consistency check of a survey does not pass. '''


class CheckConsistency(object):
    '''
        Verdict on a list of survey inconsistencies.

        An inconsistency whose 'class' is listed in tolerable_incs is a
        warning, any other one is a fault. ``passed`` is True when there are
        no faults; ``report`` lists the faults (if any) and then the warnings.
        Report lines are built from the 'type' and 'message' entries.
    '''

    tolerable_incs = [
        'trie off its shard',
        'unknown shard of instance',
        'multiple tries',
        'dead instance',
        'relational consistency fault: up',
        'relational consistency fault: down',
    ]

    def __init__(self, inconsistencies):
        def listing(incs):
            # One tab-indented "type: message" line per inconsistency.
            return '\n'.join('\t{0}: {1}'.format(inc['type'], inc['message']) for inc in incs)

        tolerated, fatal = [], []
        for inc in inconsistencies:
            bucket = tolerated if inc['class'] in self.tolerable_incs else fatal
            bucket.append(inc)

        self.passed = not fatal
        report = ''
        if fatal:
            report = 'Faults: {0}\n'.format(listing(fatal))
        # The warnings section is always present, even when it is empty.
        report += 'Warns: {0}\n'.format(listing(tolerated))
        self.report = report
