# coding=utf-8
import logging
from datetime import datetime, timedelta

import pymongo
from states import states
from libraries.hosts_renaming import rename

from utils import singleton
from mongo_params import HEARTBEAT_MONGODB

# A group counts as alive only when strictly more than this fraction of its
# instances is alive (used by _is_alive as alive > fraction * total).
DEAD_GROUP_THRESHOLD = 0.05
# After a suspension, the lost share of a group (dead + suspended instances)
# may not exceed 1 - ALIVE_GROUP_THRESHOLD.  0.88 lets a 9-instance group
# (the 3-locations x 3-instances case) lose one instance: 9 * 0.12 >= 1.
ALIVE_GROUP_THRESHOLD = 0.88


class _SuspensionApi(object):
    def __init__(self):
        self.client = pymongo.MongoReplicaSetClient(
            HEARTBEAT_MONGODB.uri,
            replicaSet=HEARTBEAT_MONGODB.replicaset,
            w='majority',
            wtimeout=5000,
            read_preference=HEARTBEAT_MONGODB.read_preference,
        )
        self.coll = self.client['heartbeat']['suspension']
        self.log = logging.getLogger('suspension')

    def acquire(self, taskid, hosts, dry_run=False):
        assert len(hosts) == 1, 'suspension of multiple hosts is disabled by agreement'
        host = hosts[0]

        status, groups = self.calc_suspension_possibility(host)
        message = _get_message(host, groups)

        # What about duplicate taskid?
        if not dry_run and status is not None:
            self.coll.insert({'taskid': taskid, 'hosts': hosts, 'status': status, 'message': message})

        return status, message

    def get_status(self, taskid=None):
        result = {}

        for r in self.coll.find(
            {} if taskid is None else {'taskid': taskid},
            {'hosts': 1, 'taskid': 1, 'status': 1, 'message': 1}
        ):
            hosts = r['hosts']
            status = r['status']
            message = r.get('message', '')

            if hosts is None:
                raise KeyError('Task [{}] not found'.format(taskid))

            result[r['taskid']] = {'hosts': hosts, 'status': status, 'message': message}

        return result

    def update_status(self):
        successes = []
        rejects = {}

        for r in self.coll.find(
            {},
            {'hosts': 1, 'taskid': 1, 'status': 1}
        ):
            host = r['hosts'][0]
            status = r['status']

            if not status:
                status, groups = self.calc_suspension_possibility(host)
                if status:
                    successes.append(r['taskid'])
                else:
                    rejects[r['taskid']] = (host, groups)

        for taskid in successes:
            self.coll.update({'taskid': taskid}, {'$set': {'status': True}})
        for taskid in rejects:
            host, groups = rejects[taskid]
            self.coll.update({'taskid': taskid}, {'$set': {'message': _get_message(host, groups)}})

    def release(self, taskid):
        self.coll.remove({'taskid': taskid})

    def calc_suspension_possibility(self, host):
        suspended = self._get_suspended_hosts()

        last_seen = self._last_seen(host)
        if last_seen == 'never':
            return None, []
        elif last_seen == 'long_ago':
            return True, []
        else:
            endangered_groups = _get_endangered_groups(host, suspended)
            return (not bool(endangered_groups)), endangered_groups

    def _get_suspended_hosts(self):
        hosts = set()
        for r in self.coll.find({'status': True}, {'hosts': 1}):
            for h in r['hosts']:
                hosts.add(h)
        return hosts

    def _last_seen(self, host):
        seen = False

        for record in self.client['heartbeat']['instancestatev3'].find(
                {'$or': [{'host': host}, {'host': rename(host)}]},
                {'last_update': 1}
        ):
            seen = True

            self.log.info('host [%s] last seen [%s]', host, record['last_update'])
            if record['last_update'] > datetime.now() - timedelta(hours=1):
                return 'recent'

        if not seen:
            self.log.info('host [%s] never seen', host)

        return 'never' if not seen else 'long_ago'


def _get_message(host, groups):
    return '{} postponed because of {}'.format(host, list(groups))


def _get_endangered_groups(host, suspended):
    """List the preserved groups on *host* that would not survive its suspension.

    :param host: host being considered for suspension.
    :param suspended: hosts already counted as suspended.
    :returns: list of ``(group_id, group_state)`` pairs.
    """
    return [
        (group_id, group_state)
        for group_id, group_state in _get_online_groups_on_host(host)
        if not _group_survives_suspension(group_id, group_state, host, suspended)
    ]


def _get_online_groups_on_host(host):
    """Return ``(group_id, group_state)`` pairs on *host* worth protecting.

    A group is kept when it is alive (see _is_alive) or is a "bad trunk"
    (see _is_bad_trunk), unless its name is ignored (see _ignored).
    """
    log = logging.getLogger('suspension')
    groups_on_host = states().instance_state.groups_state.groups_on_host(host)

    preserved = {}
    for group_id, group_state in groups_on_host.iteritems():
        all_instances = list(group_state.iter_all_instances())
        alive_instances = list(group_state.iter_alive_instances())
        worth_keeping = (
            _is_alive(len(all_instances), len(alive_instances))
            or _is_bad_trunk(group_id, all_instances)
        )
        if worth_keeping and not _ignored(group_id[0]):
            preserved[group_id] = group_state

    log.info('[%s] on host %s', host, groups_on_host)
    log.info('[%s] to preserve %s', host, preserved.keys())

    return preserved.items()


def _is_alive(total_count, alive_count):
    """True when alive instances exceed DEAD_GROUP_THRESHOLD of the total."""
    min_alive = DEAD_GROUP_THRESHOLD * total_count
    return alive_count > min_alive


def _ignored(group):
    if group in {
        'MAN_KIWICALC_PROD_VM',
        'ALL_INFORMANT_PRODUCTION',
        'ALL_SEARCH',
    }:
        return True

    if 'PSI' in group and 'AGENTS' in group:
        return True

    return any(prefix in group for prefix in [
        'YASMAGENT',
        'JUGGLER_CLIENT',
        'RTC_SLA_TENTACLES_PROD',
        'MAIL_LUCENE',
    ])


def _group_survives_suspension(group_id, group_state, host, suspended_hosts):
    """Decide whether a group stays healthy if *host* is suspended.

    Instances on *host* (under its current or renamed name) and on any host
    in *suspended_hosts* are treated as lost, together with already-dead ones.

    :returns: False for bad-trunk groups. Otherwise True when small groups
        (<= 9 instances) lose at most one instance, or when the lost share
        does not exceed ``1 - ALIVE_GROUP_THRESHOLD``.
    """
    if _is_bad_trunk(group_id, group_state):
        return False

    # rename(host) does not depend on the instance; compute it once instead
    # of once per instance inside the loop.
    renamed_host = rename(host)

    survived_instances = set(group_state.iter_alive_instances())
    for instance in group_state.iter_all_instances():
        instance_host = instance[0]
        if instance_host == host or instance_host == renamed_host or instance_host in suspended_hosts:
            survived_instances.discard(instance)

    group_size = len(group_state)
    lost_count = group_size - len(survived_instances)

    # special case: one instance can die in small groups
    if group_size <= 9 and lost_count <= 1:
        return True

    # common case: the lost share must stay within the allowed fraction
    return group_size * (1 - ALIVE_GROUP_THRESHOLD) >= lost_count


def _is_bad_trunk(group_id, group_instances):
    return group_id[1] > 1000000000 and len(group_instances) == 0


def acquire(taskid, hosts, dry_run=False):
    """Request suspension of *hosts* for *taskid* via the shared API object.

    *dry_run* defaults to False, matching _SuspensionApi.acquire; existing
    callers that pass it explicitly are unaffected.

    :returns: ``(status, message)`` as produced by _SuspensionApi.acquire.
    """
    return _suspension_api().acquire(taskid, hosts, dry_run)


def get_status(taskid=None):
    """Return stored suspension statuses (all tasks, or just *taskid*)."""
    api = _suspension_api()
    return api.get_status(taskid)


def release(taskid):
    """Drop the suspension record(s) of *taskid*."""
    api = _suspension_api()
    return api.release(taskid)


@singleton
def _suspension_api():
    """Construct the _SuspensionApi instance.

    NOTE(review): `utils.singleton` presumably caches the first result so all
    callers share one Mongo client -- confirm against utils.
    """
    api = _SuspensionApi()
    return api


def _update_status_loop(interval=60):
    """Forever: refresh heartbeat states, then re-evaluate pending suspensions.

    :param interval: seconds to sleep between passes (default 60, as before).
    """
    import time
    log = logging.getLogger('suspension')
    while True:
        try:
            states().update_alive()
            states().update_istates()
            _suspension_api().update_status()
        except Exception:
            # A single failed pass (e.g. a Mongo write timeout with
            # w='majority') must not kill the updater daemon: log and retry
            # on the next tick. KeyboardInterrupt still stops the loop.
            log.exception('suspension status update failed')

        time.sleep(interval)


def check(host):
    """CLI helper: refresh states and print the suspension verdict for *host*.

    Prints the host, the verdict (True / False / None) and the postponement
    message.

    :returns: the verdict from calc_suspension_possibility.
    """
    states().update_alive()
    states().update_istates()

    result, groups = _suspension_api().calc_suspension_possibility(host)

    import pprint
    pprint.pprint(host)
    # The message always reads "postponed", even when suspension is allowed
    # or the host was never seen, so show the actual verdict as well.
    pprint.pprint(result)
    pprint.pprint(_get_message(host, groups))
    return result


def parse_args():
    """Parse the command line; print usage if neither --update nor --check is set."""
    import argparse

    parser = argparse.ArgumentParser(description='Check and update suspension status')
    parser.add_argument(
        '--update',
        action='store_true',
        dest='update',
        help='run infinite update loop',
    )
    parser.add_argument(
        '--check',
        type=str,
        default=None,
        dest='check',
        help='check host',
    )

    parsed = parser.parse_args()
    if not (parsed.update or parsed.check):
        parser.print_help()
    return parsed


def main():
    """Script entry point: set up logging, then run --check or --update."""
    from utils import configure_log
    configure_log(debug=True)

    args = parse_args()
    if args.check:
        check(args.check)
        return
    if args.update:
        _update_status_loop()


# Run the CLI only when executed as a script, not on import.
if __name__ == '__main__':
    main()
