import collections
import logging
from datetime import datetime, timedelta
import pprint
import pymongo
import pytz
import time

from libraries.hosts_renaming import rename

from states import states
from utils import singleton
from mongo_params import HEARTBEAT_MONGODB

# A group counts as alive only while more than this fraction of its instances
# is alive; groups at or below it are not protected unless they are bad trunks.
DEAD_GROUP_THRESHOLD = 0.05
# Share of a group's instances that must still be alive after a suspension.
# Picked so that a 9-instance group spread over 3 locations can lose exactly
# one instance: 9 * (1 - 0.87) = 1.17 >= 1.
ALIVE_GROUP_THRESHOLD = 0.87
# Minimum waiting time before a host that was never seen may be suspended.
QUARANTINE_THRESHOLD = timedelta(hours=6)

# Outcome of one suspension check: allowed flag, blocking groups, explanation.
SuspensionResult = collections.namedtuple(
    typename='SuspensionResult',
    field_names=['possible', 'groups', 'message'],
)


class _SuspensionApi(object):
    """Mongo-backed store and decision logic for host suspension requests.

    Every request is one document in ``heartbeat.suspension`` keyed by
    ``taskid``. It holds the requested ``hosts`` (exactly one host is
    accepted), a ``status`` flag (True once the suspension is granted) and a
    ``message`` explaining the latest decision.
    """

    def __init__(self):
        # Writes wait for acknowledgement from a majority of replica-set
        # members, for at most 5000 ms (wtimeout).
        self.client = pymongo.MongoReplicaSetClient(
            HEARTBEAT_MONGODB.uri,
            replicaSet=HEARTBEAT_MONGODB.replicaset,
            w='majority',
            wtimeout=5000,
            read_preference=HEARTBEAT_MONGODB.read_preference,
        )
        self.coll = self.client['heartbeat']['suspension']

    def acquire(self, taskid, hosts, dry_run=False):
        """Request suspension of ``hosts`` under ``taskid``.

        The decision is computed right away. Unless ``dry_run`` is set it is
        stored with ``$setOnInsert``, so a task that is already registered
        keeps its stored status/message. The returned values are always
        freshly computed.

        :param taskid: identifier of the suspension task.
        :param hosts: list with exactly one host name.
        :param dry_run: when True, only compute the decision and write nothing.
        :returns: ``(possible, message)`` tuple.
        :raises AssertionError: if ``hosts`` does not hold exactly one host.
        """
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if len(hosts) != 1:
            raise AssertionError('suspension of multiple hosts is disabled by agreement')
        host = hosts[0]

        # Only hosts already granted in Mongo count as unavailable here.
        result = self.calc_suspension_possibility(host)

        if not dry_run:
            self.coll.update(
                {'taskid': taskid},
                {'$setOnInsert': {'taskid': taskid, 'hosts': hosts,
                                  'status': result.possible, 'message': result.message}},
                upsert=True,
            )

        return result.possible, result.message

    def get_status(self, taskid=None):
        """Return ``{taskid: {'hosts': ..., 'status': ..., 'message': ...}}``.

        :param taskid: restrict the result to this task; all tasks when None.
        :raises KeyError: if a matching document has ``hosts`` set to None.
        """
        # NOTE(review): an unknown ``taskid`` yields an empty dict. The
        # "not found" KeyError below only fires for a stored document whose
        # ``hosts`` is None -- confirm which behaviour callers expect.
        query = {} if taskid is None else {'taskid': taskid}
        projection = {'hosts': 1, 'taskid': 1, 'status': 1, 'message': 1}

        result = {}
        for r in self.coll.find(query, projection):
            hosts = r['hosts']
            status = r['status']
            message = r.get('message', '')

            if hosts is None:
                raise KeyError('Task [{}] not found'.format(taskid))

            result[r['taskid']] = {'hosts': hosts, 'status': status, 'message': message}

        return result

    def update_status(self, readonly=False):
        """Re-evaluate every pending (not yet granted) task and store the outcome.

        Tasks are processed in cursor order. A host granted earlier in this
        pass counts as unavailable for the tasks evaluated after it. Granted
        tasks get ``status: True``; rejected ones get their ``message`` refreshed.

        :param readonly: only log the decisions, do not write them to Mongo.
        """
        log = _log.getChild('update')
        successes = []
        rejects = {}
        to_suspend = set()

        for r in self.coll.find({}, {'_id': 1, 'hosts': 1, 'taskid': 1, 'status': 1}):
            if r['status']:
                continue  # already granted, nothing to re-evaluate

            host = r['hosts'][0]
            taskid = r['taskid']
            # Age of the task, from the creation time embedded in its ObjectId.
            spent = datetime.now(pytz.utc) - r['_id'].generation_time

            result = self.calc_suspension_possibility(host, to_suspend, spent)
            if result.possible:
                successes.append(taskid)
                to_suspend.add(host)
            else:
                rejects[taskid] = (host, result.message)

        for taskid in successes:
            if readonly:
                log.info('%s success', taskid)
            else:
                self.coll.update({'taskid': taskid}, {'$set': {'status': True}})
        for taskid, (host, message) in rejects.items():
            if readonly:
                log.info('reject %s %s', host, message)
            else:
                self.coll.update({'taskid': taskid}, {'$set': {'message': message}})

    def release(self, taskid):
        """Remove the suspension document(s) stored under ``taskid``."""
        self.coll.remove({'taskid': taskid})

    def calc_suspension_possibility(self, host, suspended=None, quarantine_spent=timedelta(0)):
        """Decide whether ``host`` may be suspended now.

        * Host never seen: allowed only once ``quarantine_spent`` exceeds
          ``QUARANTINE_THRESHOLD``.
        * Host last seen more than an hour ago: allowed.
        * Host seen recently: allowed only if no protected group on it is
          endangered by the suspension.

        :param host: host name to check.
        :param suspended: extra hosts to treat as already suspended, in
            addition to those granted in Mongo. Defaults to none.
        :param quarantine_spent: how long the request has been waiting.
        :returns: a ``SuspensionResult``.
        """
        suspended = self._get_suspended_hosts() | set(suspended or ())

        last_seen = self._last_seen(host)
        if last_seen == 'never':
            if quarantine_spent > QUARANTINE_THRESHOLD:
                result = SuspensionResult(possible=True, groups=[], message='host never seen, quarantine passed')
            else:
                result = SuspensionResult(
                    possible=False, groups=[],
                    message='host never seen, put in quarantine for {}'.format(QUARANTINE_THRESHOLD))
        elif last_seen == 'long_ago':
            result = SuspensionResult(possible=True, groups=[], message='host seen long ago, suspension accepted')
        else:
            _log.info('[%s] seen recently, checking groups', host)
            endangered_groups = _get_endangered_groups(host, suspended)
            result = SuspensionResult(possible=not bool(endangered_groups), groups=endangered_groups,
                                      message=_get_message(host, endangered_groups))

        _log.info('%s', result)
        return result

    def _get_suspended_hosts(self):
        """Return the set of hosts of all tasks whose suspension was granted."""
        return {h for r in self.coll.find({'status': True}, {'hosts': 1}) for h in r['hosts']}

    def _last_seen(self, host):
        """Classify when ``host`` last reported: 'recent', 'long_ago' or 'never'.

        The host is looked up in ``heartbeat.instancestatev3`` under both its
        own name and ``rename(host)``. 'recent' means some record has a
        ``last_update`` newer than one hour ago.
        """
        # NOTE(review): compares against naive local time. PyMongo returns
        # naive UTC datetimes by default, so this is only correct if the
        # process runs in UTC or ``last_update`` is written as local time.
        recent_cutoff = datetime.now() - timedelta(hours=1)
        query = {'$or': [{'host': host}, {'host': rename(host)}]}

        seen = False
        for record in self.client['heartbeat']['instancestatev3'].find(query, {'last_update': 1}):
            seen = True

            _log.info('[%s] last seen [%s]', host, record['last_update'])
            if record['last_update'] > recent_cutoff:
                return 'recent'

        if not seen:
            _log.info('[%s] never seen', host)
            return 'never'
        return 'long_ago'


def _get_message(host, groups):
    """Build the rejection message naming ``host`` and the groups blocking it."""
    blocking_groups = list(groups)
    return '{} postponed because of {}'.format(host, blocking_groups)


def _get_endangered_groups(host, suspended):
    """Return ``(group_id, group_state)`` pairs that would not survive suspending ``host``.

    Only groups returned by ``_get_online_groups_on_host`` are considered;
    ``suspended`` is the set of hosts already counted as unavailable.
    """
    return [
        (group_id, group_state)
        for group_id, group_state in _get_online_groups_on_host(host)
        if not _group_survives_suspension(group_id, group_state, host, suspended)
    ]


def _get_online_groups_on_host(host):
    """Return the groups on ``host`` whose instances must be protected.

    A group on the host is kept when it is alive (``_is_alive``) or is a
    "bad trunk" (``_is_bad_trunk``), unless ``_ignored`` matches
    ``group_id[0]`` (presumably the group name).

    :returns: ``(group_id, group_state)`` pairs (``dict.items()`` of the kept groups).
    """
    groups_on_host = states().instance_state.groups_state.groups_on_host(host)

    protected = {}
    for gid, gstate in groups_on_host.iteritems():
        every_instance = list(gstate.iter_all_instances())
        alive_instances = list(gstate.iter_alive_instances())

        # Dump full instance lists only for small groups to keep log lines bounded.
        if len(every_instance) >= 100:
            _log.warning('%s', gid)
        else:
            _log.warning('%s %s %s', gid, every_instance, alive_instances)

        needs_protection = (_is_alive(len(every_instance), len(alive_instances))
                            or _is_bad_trunk(gid, every_instance))
        if needs_protection and not _ignored(gid[0]):
            protected[gid] = gstate

    _log.info('[%s] on host %s', host, groups_on_host)
    _log.info('[%s] to preserve %s', host, protected.keys())

    return protected.items()


def _is_alive(total_count, alive_count):
    """Return True if more than DEAD_GROUP_THRESHOLD of ``total_count`` instances are alive."""
    min_alive = total_count * DEAD_GROUP_THRESHOLD
    return alive_count > min_alive


# Group names excluded from suspension protection by exact match.
_IGNORED_GROUPS = frozenset([
    'MAN_KIWICALC_PROD_VM',
    'ALL_INFORMANT_PRODUCTION',
    'ALL_SEARCH',
])

# Group names containing any of these substrings (anywhere, not only as a
# prefix) are excluded as well.
_IGNORED_GROUP_SUBSTRINGS = (
    'YASMAGENT',
    'JUGGLER_CLIENT',
    'RTC_SLA_TENTACLES_PROD',
    'MAIL_LUCENE',
    'PSI',
    'YT_PROD',
    '_WEB_DEPLOY',
    'MAN_YT_TESTING1_PORTOVM',
)


def _ignored(group):
    """Return True if ``group`` must not block a suspension."""
    if group in _IGNORED_GROUPS:
        return True
    return any(marker in group for marker in _IGNORED_GROUP_SUBSTRINGS)


def _group_survives_suspension(group_id, group_state, host, suspended_hosts):
    """Return True if the group keeps enough live instances once ``host`` is suspended.

    Bad trunks (``_is_bad_trunk``) never survive. Otherwise an instance is
    lost if it is already dead or lives on ``host`` (under its own name or
    ``rename(host)``) or on one of ``suspended_hosts``. Groups of at most 9
    instances may lose one instance. Any group may lose at most a
    ``1 - ALIVE_GROUP_THRESHOLD`` share of its instances.
    """
    if _is_bad_trunk(group_id, group_state):
        return False

    # Resolve the renamed host once instead of once per instance.
    unavailable_hosts = {host, rename(host)}
    unavailable_hosts.update(suspended_hosts)

    # instance[0] is compared against host names, i.e. it is the instance's host.
    survived_instances = set(group_state.iter_alive_instances())
    for instance in group_state.iter_all_instances():
        if instance[0] in unavailable_hosts:
            survived_instances.discard(instance)

    group_size = len(group_state)
    survived_size = len(survived_instances)
    lost = group_size - survived_size

    # Special case: one instance can die in small groups.
    if group_size <= 9 and lost <= 1:
        return True

    # Common case. The former debug ``print`` wrote to stdout on every call.
    _log.debug('%s: group size %s, survived %s', group_id, group_size, survived_size)
    return group_size * (1 - ALIVE_GROUP_THRESHOLD) >= lost


def _is_bad_trunk(group_id, group_instances):
    """Return True for a group whose id number exceeds 1e9 and that has no instances.

    NOTE(review): ``group_id[1]`` is presumably a tag/revision number where
    values above 1000000000 mean "trunk" -- confirm against the id format.
    """
    is_trunk_like = group_id[1] > 1000000000
    return is_trunk_like and len(group_instances) == 0


def acquire(taskid, hosts, dry_run=False):
    """Request suspension of ``hosts`` via the shared API; see ``_SuspensionApi.acquire``.

    ``dry_run`` now defaults to False, matching the method's own default.

    :returns: ``(possible, message)`` tuple.
    """
    return _suspension_api().acquire(taskid, hosts, dry_run)


def get_status(taskid=None):
    """Return stored task statuses from the shared API (all tasks when ``taskid`` is None)."""
    api = _suspension_api()
    return api.get_status(taskid=taskid)


def release(taskid):
    """Drop the suspension record(s) of ``taskid`` through the shared API."""
    api = _suspension_api()
    return api.release(taskid)


@singleton
def _suspension_api():
    """Build the ``_SuspensionApi`` used by the module-level helpers.

    ``@singleton`` (from ``utils``) presumably caches the result so that one
    instance, and one Mongo client, is shared per process -- verify in utils.
    """
    api = _SuspensionApi()
    return api


def update_status_loop():
    """Refresh cluster state and re-evaluate pending suspensions forever.

    One pass runs roughly every 60 seconds. Previously a single failing pass
    (for example a transient Mongo or state-refresh error) ended the loop for
    good, leaving pending tasks unprocessed. Now the error is logged and the
    next pass retries. KeyboardInterrupt/SystemExit still propagate.
    """
    while True:
        try:
            states().update_alive()
            states().update_istates()
            _suspension_api().update_status()
        except Exception:
            _log.exception('suspension status update failed, retrying in 60s')

        time.sleep(60)


def check(hosts):
    """Print, host by host, whether each of ``hosts`` could be suspended right now.

    Cluster state is refreshed first. The quarantine is treated as already
    passed. Every host found suspendable is added to the set treated as
    unavailable for the hosts checked after it. Nothing is written to Mongo.
    """
    states().update_alive()
    states().update_istates()

    granted = set()
    for candidate in hosts:
        verdict = _suspension_api().calc_suspension_possibility(
            host=candidate,
            suspended=granted,
            quarantine_spent=QUARANTINE_THRESHOLD + timedelta(seconds=1),
        )
        if verdict.possible:
            granted.add(candidate)

        pprint.pprint(granted)
        pprint.pprint(candidate)
        pprint.pprint({'result': verdict.possible, 'message': verdict.message})


# Module logger; functions above resolve it at call time, so defining it last is fine.
_log = logging.getLogger(name='suspension')
