import os
from kernel.util.functional import memoized, threadsafe


@threadsafe
@memoized
def heartbeat_api(namespace=None):
    """Look up the python API object of the ``heartbeat-client`` service.

    ``namespace`` falls back to the ``SKYNET_HEARTBEAT_NAMESPACE`` environment
    variable and then to ``'skynet'``. Callers in this module use the result
    as ``heartbeat_api()(master=True).call(...)``.

    Caching and locking are provided by the ``memoized`` / ``threadsafe``
    decorators. Presumably the environment variable is therefore read only
    on the first call per argument value; confirm against
    ``kernel.util.functional``.
    """
    resolved_namespace = (
        os.getenv('SKYNET_HEARTBEAT_NAMESPACE', 'skynet')
        if namespace is None
        else namespace
    )

    # Local import: importing this module alone does not pull in skycore.
    from api.skycore import ServiceManager

    manager = ServiceManager()
    return manager.get_service_python_api(resolved_namespace, 'heartbeat-client')


def listServers(*types):
    """Forward ``types`` to the heartbeat master's ``listServers`` call.

    With no ``types`` the master is not contacted and ``[]`` is returned.
    """
    if types:
        client = heartbeat_api()(master=True)
        return client.call('listServers', *types)
    return []


def shardState(*args, **kwargs):
    """Fetch per-host state via the heartbeat ``shardState`` call.

    The server streams zlib-compressed msgpack records into
    ``stateCallback``. The first record is kept as the field table
    (``stateFields``). Every following record is a
    ``(hostname, timestamp, statePacked)`` triple, where ``statePacked`` is
    itself a msgpack blob.

    Keyword arguments:
        compact: popped from ``kwargs`` (default ``False``). When false, the
            keys of every value dict in a decoded state are translated
            through ``stateFields``. ``generated_chksum`` values are
            hex-encoded, and ``chksum`` values are hex-encoded with an
            ``'MD5:'`` prefix.

    All remaining ``args`` / ``kwargs`` are forwarded to the API call;
    ``stateCallback`` is always overwritten.

    Returns:
        ``(result, stateFields, maxTimestamp)`` if ``compact`` is true,
        otherwise ``(result, maxTimestamp)``. ``result`` maps hostname to
        ``{'state': <decoded state>, 'modified': <timestamp>}``.
        ``maxTimestamp`` is the largest timestamp seen (0 if none).

    Raises:
        AssertionError: if the API call returns a falsy value.
    """
    api = heartbeat_api()(master=True)

    import binascii
    import msgpack
    import zlib

    compact = kwargs.pop('compact', False)

    def _toHex(value):
        # ``value.encode('hex')`` only works on Python 2 strings; hexlify
        # works on both versions. Python 2 returns ``str`` directly. Python 3
        # returns ``bytes``, which is decoded so callers still get text.
        hexed = binascii.hexlify(value)
        if not isinstance(hexed, str):
            hexed = hexed.decode('ascii')
        return hexed

    class _Cb(object):
        def __init__(self):
            self.result = {}
            self.decompressor = zlib.decompressobj()
            self.unpacker = msgpack.Unpacker()
            self.stateFields = None
            self.maxTimestamp = 0

        def _expand(self, state):
            # In place: rename short keys via the field table and hex-encode
            # checksum values.
            for hostState in state.values():
                # Iterate over a copy: keys are popped/added while looping.
                for key, value in list(hostState.items()):
                    if key in self.stateFields:
                        hostState.pop(key)
                        key = self.stateFields[key]
                        hostState[key] = value

                    if key == 'generated_chksum':
                        hostState[key] = _toHex(value)
                    elif key == 'chksum':
                        hostState[key] = 'MD5:%s' % (_toHex(value), )

        def __call__(self, d):
            self.unpacker.feed(self.decompressor.decompress(d))

            for data in self.unpacker:
                if self.stateFields is None:
                    # The first record of the stream is the field table.
                    self.stateFields = data
                    continue
                hostname, timestamp, statePacked = data
                self.maxTimestamp = max(timestamp, self.maxTimestamp)
                state = msgpack.loads(statePacked)
                if not compact:
                    self._expand(state)
                self.result[hostname] = {'state': state, 'modified': timestamp}

    cb = _Cb()
    kwargs['stateCallback'] = cb
    returncode = api.call('shardState', *args, **kwargs)
    if not returncode:
        # Explicit check rather than ``assert`` so it is not stripped by -O.
        raise AssertionError('shardState call failed: %r' % (returncode, ))

    if compact:
        return cb.result, cb.stateFields, cb.maxTimestamp
    return cb.result, cb.maxTimestamp


def shardStateV3(*args, **kwargs):
    """Fetch per-host state via the heartbeat ``shardStateV3`` call.

    The server streams zlib-compressed msgpack records into
    ``stateCallback``. The first record is stored as ``stateFields`` but not
    otherwise used here. Every following record is a
    ``(hostname, timestamp, statePacked)`` triple.

    Keyword arguments:
        compress: defaulted to 5 when not given, then forwarded.
        stateKeys: popped (default ``False``). When set, each value of a
            decoded host state is wrapped in a slotted ``objV3`` exposing only
            these names. Each name is looked up in the raw dict through
            ``longNameKeyMap`` (e.g. ``'installed'`` -> ``'i'``); names not in
            the map are used as-is, and missing entries become ``None``.

    All remaining ``args`` / ``kwargs`` are forwarded to the API call;
    ``stateCallback`` is always overwritten.

    Returns:
        ``(result, maxTimestamp)``. ``result`` maps hostname to the decoded
        state; ``maxTimestamp`` is the largest timestamp seen (0 if none).

    Raises:
        AssertionError: if the API call returns a falsy value.
    """
    api = heartbeat_api()(master=True)
    import msgpack
    import zlib
    import six

    kwargs.setdefault('compress', 5)

    stateKeys = kwargs.pop('stateKeys', False)
    longNameKeyMap = {
        'acquired': 'a',
        'aspam_mtime': 'sv',
        'downloaded': 'ds',
        'indexarc_size': 'as',
        'indexinv_size': 'is',
        'installed': 'i',
        'mtime': 'ct',          # %mtime (shard.conf)
        'search_zone': 'z',     # SearchZone (stamp.tag)
        'shard_int_md5': 'cs',  # int hash generated checksum
        'total': 'ts',          # sum of downloading files sizes
    }

    # Sentinel for objV3.get: distinguishes "no default" from a falsy default.
    missing = object()

    if stateKeys:
        class objV3(object):
            # Slotted view of one state dict restricted to ``stateKeys``.

            __slots__ = stateKeys

            def __init__(self, stateDict):
                for key in stateKeys:
                    realKey = longNameKeyMap.get(key, key)
                    setattr(self, key, stateDict.get(realKey, None))

            def __getitem__(self, key):
                try:
                    return getattr(self, key)
                except AttributeError:
                    raise KeyError(key)

            def __contains__(self, key):
                try:
                    self[key]
                    return True
                except KeyError:
                    return False

            def __repr__(self):
                res = []
                for key in sorted(self.__slots__):
                    value = getattr(self, key, None)
                    if value is not None:
                        res.append(key + ': ' + str(value))
                return '<%s object at 0x%x {%s}>' % (
                    self.__class__.__name__,
                    id(self),
                    ', '.join(res)
                )

            def get(self, key, default=missing):
                """Return ``key`` if it is a slot, else ``default``.

                Any default is honoured, including falsy ones such as
                ``None`` or ``0``. Without a default, raises ``KeyError``
                like ``__getitem__``.
                """
                try:
                    return getattr(self, key)
                except AttributeError:
                    if default is missing:
                        raise KeyError(key)
                    return default

    class _Cb(object):
        def __init__(self):
            self.result = {}
            self.decompressor = zlib.decompressobj()
            self.unpacker = msgpack.Unpacker()
            self.stateFields = None
            self.maxTimestamp = 0

        def __call__(self, d):
            self.unpacker.feed(self.decompressor.decompress(d))

            for data in self.unpacker:
                if self.stateFields is None:
                    # The first record of the stream is the field table.
                    self.stateFields = data
                    continue
                hostname, timestamp, statePacked = data
                self.maxTimestamp = max(timestamp, self.maxTimestamp)
                state = msgpack.loads(statePacked)
                if stateKeys:
                    # Only values are replaced, so the dict size is unchanged
                    # while iterating.
                    for k, v in six.iteritems(state):
                        state[k] = objV3(v)
                self.result[hostname] = state

    cb = _Cb()
    kwargs['stateCallback'] = cb
    returncode = api.call('shardStateV3', *args, **kwargs)
    if not returncode:
        # Explicit check rather than ``assert`` so it is not stripped by -O.
        raise AssertionError('shardStateV3 call failed: %r' % (returncode, ))

    return cb.result, cb.maxTimestamp


def shardStateYield(*args, **kwargs):
    """Stream records of the heartbeat ``shardState`` call as a generator.

    The API delivers data through a callback (``stateCallback``). A greenlet
    runs the API call, and the callback switches back to this generator once
    per decoded record, so records are yielded one by one instead of being
    collected into a dict first.

    Keyword arguments:
        stateKeys: popped (default ``False``). Without it, every decoded
            msgpack record is yielded as-is, including the first one. With it,
            the first record is consumed as a ``{mangledKey: realKeyName}``
            mapping. Each following ``(host, timestamp, statePacked)`` record
            is then yielded as ``(host, timestamp, {shardName: ShardState})``,
            where ``ShardState`` exposes only ``stateKeys``.

    All other ``args`` / ``kwargs`` are forwarded to
    ``api.call('shardState', ...)``; ``stateCallback`` is overwritten.
    """
    api = heartbeat_api()(master=True)
    from greenlet import greenlet
    import msgpack
    import zlib
    import six

    stateKeys = kwargs.pop('stateKeys', False)

    if stateKeys:
        # Held in a dict because the nested ``cb`` below must update it.
        state = {'fields': None}
        # stateKeysSlots: slot (attribute) name -> key to read from each
        # shard's raw state dict. It starts as the requested name; ``cb``
        # rewrites it to the mangled key once the field mapping arrives.
        # stateKeysSlotsBack: requested name containing '-' -> its slot name
        # with '_' ('-' cannot appear in an attribute name).
        stateKeysSlots = {}
        stateKeysSlotsBack = {}
        for key in stateKeys:
            if '-' in key:
                nkey = key.replace('-', '_')
                stateKeysSlots[nkey] = key
                stateKeysSlotsBack[key] = nkey
            else:
                stateKeysSlots[key] = key

        # Slotted view of one shard's state dict, restricted to ``stateKeys``;
        # keys missing from the dict become ``None``.
        class ShardState(object):
            __slots__ = list(stateKeysSlots.keys())

            def __init__(self, stateDict, cfields):
                # ``cfields`` is accepted but not used.
                for key, rkey in six.iteritems(stateKeysSlots):
                    setattr(
                        self, key,
                        stateDict.get(
                            rkey, None
                        )
                    )

            def __getitem__(self, key):
                # Accept the original dashed name as well as the slot name.
                try:
                    return getattr(self, stateKeysSlotsBack.get(key, key))
                except AttributeError:
                    raise KeyError(key)

            def __contains__(self, key):
                try:
                    self[key]
                    return True
                except KeyError:
                    return False

            def __repr__(self):
                res = []
                for key in sorted(self.__slots__):
                    value = getattr(self, key, None)
                    if value is not None:
                        res.append(key + ': ' + str(value))
                return '<%s object at 0x%x {%s}>' % (
                    self.__class__.__name__,
                    id(self),
                    ', '.join(res)
                )

    decompressor = zlib.decompressobj()
    unpacker = msgpack.Unpacker()
    # The greenlet running this generator body, i.e. the consumer; ``cb``
    # switches back here to hand over each record.
    grn_curr = greenlet.getcurrent()

    def cb(rawdata):
        # Called from inside ``api.call`` (running in ``grn_call``) with a
        # zlib-compressed chunk of the msgpack stream.
        chunk = decompressor.decompress(rawdata)
        unpacker.feed(chunk)

        for data in unpacker:
            if not stateKeys:
                grn_curr.switch(data)
            else:
                if not state['fields']:
                    # First record: {mangledKey: realKey}. Point each
                    # requested slot at the mangled key used in the payload.
                    for mangleKey, realKey in six.iteritems(data):
                        realRealKey = realKey.replace('-', '_')
                        if realRealKey in stateKeysSlots:
                            stateKeysSlots[realRealKey] = mangleKey
                    state['fields'] = True
                    continue

                host, timestamp, statePacked = data
                shardsObjDict = {}

                for shardName, shardStateDict in six.iteritems(msgpack.loads(statePacked)):
                    stateObj = ShardState(shardStateDict, state['fields'])
                    shardsObjDict[shardName] = stateObj

                grn_curr.switch((host, timestamp, shardsObjDict))

    kwargs['stateCallback'] = cb

    def call():
        returncode = api.call('shardState', *args, **kwargs)
        # NOTE(review): ``assert`` is stripped under ``python -O``, so a falsy
        # return code would then go unnoticed.
        assert returncode

    grn_call = greenlet(call)

    # Each switch runs ``call`` until ``cb`` switches back with a record
    # (the switch's return value), or until ``call`` returns and ``grn_call``
    # is dead. Per greenlet semantics, an exception raised in ``call``
    # propagates out of ``grn_call.switch()`` here.
    # NOTE(review): if the consumer stops iterating early, ``grn_call`` stays
    # suspended mid-call; presumably greenlet cleans it up on garbage
    # collection. Confirm.
    while True:
        result = grn_call.switch()
        if not grn_call.dead:
            yield result
        else:
            break


def instanceStateV3(*args, **kwargs):
    """Fetch per-host instance state via the heartbeat ``instanceStateV3`` call.

    The server streams zlib-compressed msgpack records into
    ``stateCallback``:

    1. ``(modelId, packedModel)`` records until an ``'ENDMODELS'`` marker;
       these form a model table.
    2. ``(hostname, timestamp, statePacked)`` records. Each decoded state
       may hold an ``'i'`` dict of per-instance info.

    Keyword arguments:
        addRankingModels: popped (default ``False``). When true, an
            instance's ``'m'`` entry is replaced by the decoded model if it is
            a known model id. When false, ``'m'`` is removed.
        compress: defaulted to 4 when not given, then forwarded.
        stateKeys: popped (default ``False``). When set, each value of
            ``state['i']`` is wrapped in a slotted ``objV3`` exposing only
            these names, looked up via ``longNameKeyMap`` (names not in the
            map are used as-is; missing entries become ``None``).

    Returns:
        ``(result, maxTimestamp)``. ``result`` maps hostname to the decoded
        state; ``maxTimestamp`` is the largest timestamp seen (0 if none).

    Raises:
        AssertionError: if the API call returns a falsy value.
    """
    api = heartbeat_api()(master=True)
    import msgpack
    import zlib
    import six

    addRankingModels = kwargs.pop('addRankingModels', False)
    kwargs.setdefault('compress', 4)

    stateKeys = kwargs.pop('stateKeys', False)
    longNameKeyMap = {
        'alive': 'a',
        'binary_md5': 'md5',
        'binary_revision': 'r',
        'models_md5': 'mv',
        'prepared': 'p',
        'shard': 's',
        'supermind': 'mu',
        'svn_url': 'u',
    }

    # Sentinel for objV3.get: distinguishes "no default" from a falsy default.
    missing = object()

    if stateKeys:
        class objV3(object):
            # Slotted view of one instance-info dict restricted to
            # ``stateKeys``.

            __slots__ = stateKeys

            def __init__(self, stateDict):
                for key in stateKeys:
                    realKey = longNameKeyMap.get(key, key)
                    setattr(self, key, stateDict.get(realKey, None))

            def __getitem__(self, key):
                try:
                    return getattr(self, key)
                except AttributeError:
                    raise KeyError(key)

            def __contains__(self, key):
                try:
                    self[key]
                    return True
                except KeyError:
                    return False

            def __repr__(self):
                res = []
                for key in sorted(self.__slots__):
                    value = getattr(self, key, None)
                    if value is not None:
                        res.append(key + ': ' + str(value))
                return '<%s object at 0x%x {%s}>' % (
                    self.__class__.__name__,
                    id(self),
                    ', '.join(res)
                )

            def get(self, key, default=missing):
                """Return ``key`` if it is a slot, else ``default``.

                Any default is honoured, including falsy ones such as
                ``None`` or ``0``. Without a default, raises ``KeyError``
                like ``__getitem__``.
                """
                try:
                    return getattr(self, key)
                except AttributeError:
                    if default is missing:
                        raise KeyError(key)
                    return default

    class _Cb(object):
        def __init__(self):
            self.instanceNames = {}
            self.result = {}
            self.models = {}
            self.decompressor = zlib.decompressobj()
            self.unpacker = msgpack.Unpacker()
            self.maxTimestamp = 0
            self.endmodels = False

        def __call__(self, d):
            self.unpacker.feed(self.decompressor.decompress(d))

            for data in self.unpacker:
                if data == 'ENDMODELS':
                    self.endmodels = True
                    continue

                if not self.endmodels:
                    mid, model = data[0], msgpack.loads(data[1])
                    self.models[mid] = model
                    continue

                hostname, timestamp, statePacked = data
                self.maxTimestamp = max(timestamp, self.maxTimestamp)
                state = msgpack.loads(statePacked)
                # Use the same tolerant lookup for both passes: a state
                # without 'i' used to raise KeyError when stateKeys was set.
                instances = state.get('i', {})
                for instanceinfo in instances.values():
                    ranking_info = instanceinfo.get('m', None)
                    if ranking_info is None:
                        continue
                    if addRankingModels:
                        if ranking_info in self.models:
                            instanceinfo['m'] = self.models[ranking_info]
                    else:
                        instanceinfo.pop('m', None)
                if stateKeys:
                    # Only values are replaced, so the dict size is unchanged
                    # while iterating.
                    for k, v in six.iteritems(instances):
                        instances[k] = objV3(v)
                self.result[hostname] = state

    cb = _Cb()
    kwargs['stateCallback'] = cb
    returncode = api.call(
        'instanceStateV3',
        *args, **kwargs
    )
    if not returncode:
        # Explicit check rather than ``assert`` so it is not stripped by -O.
        raise AssertionError('instanceStateV3 call failed: %r' % (returncode, ))

    return cb.result, cb.maxTimestamp


def listHosts(*args, **kwargs):
    """Forward keyword arguments to the heartbeat master's ``listHosts`` call.

    Positional arguments are rejected by an ``assert``. Under ``python -O``
    that check is skipped and any positional arguments are forwarded too.
    """
    assert not args, 'Args are not supported here, use keyword arguments only'
    client = heartbeat_api()(master=True)
    return client.call('listHosts', *args, **kwargs)


def hostsInfo(*args, **kwargs):
    """Return per-host OS records via ``hostInfo`` / ``hostInfoEx``.

    ``hostInfoEx`` is called when ``user_hosts`` is among the keyword
    arguments, otherwise ``hostInfo``. ``compress`` defaults to 4. The server
    streams zlib-compressed msgpack records into ``stateCallback``:

    1. ``(osId, *fields)`` records until an ``'ENDOS'`` marker, which build
       an OS table.
    2. ``(hostname, osId)`` pairs, which are resolved against that table.

    Only keyword arguments are accepted; they are forwarded to the call, and
    ``stateCallback`` is overwritten.

    Returns:
        dict mapping hostname to the OS table entry ``fields`` tuple. A host
        reported with OS id ``None`` maps to ``(None, None, None)``.

    Raises:
        AssertionError: if positional arguments are passed (not enforced
            under ``-O``) or the API call returns a falsy value.
        KeyError: if a host references an OS id absent from the table.
    """
    assert not args, 'Plain arguments list is not supported here, use keyword arguments only.'
    api = heartbeat_api()(master=True)

    kwargs.setdefault('compress', 4)

    import zlib
    import msgpack

    class _Cb(object):
        def __init__(self):
            self.decompressor = zlib.decompressobj()
            self.unpacker = msgpack.Unpacker()
            self.endos = False
            self.result = {}
            # Hosts reported with a None OS id resolve to an all-None record.
            self.oses = {None: (None, ) * 3}

        def __call__(self, d):
            self.unpacker.feed(self.decompressor.decompress(d))

            for data in self.unpacker:
                if data == 'ENDOS':
                    self.endos = True
                    continue

                if not self.endos:
                    self.oses[data[0]] = data[1:]
                    continue

                hostname, oid = data
                self.result[hostname] = self.oses[oid]

    cb = _Cb()
    kwargs['stateCallback'] = cb
    method = 'hostInfoEx' if 'user_hosts' in kwargs else 'hostInfo'
    # The request must not live inside an ``assert``: ``python -O`` strips
    # assert statements, which would skip the call and silently return {}.
    returncode = api.call(method, **kwargs)
    if not returncode:
        raise AssertionError('%s call failed: %r' % (method, returncode))

    return cb.result
