import os
import math
import logging
import logging.handlers
import argparse

import retry
import kernel.util.logging as skynet_logging
import yt

from libraries.topology.utils import is_trunk, version_to_tag, tag_to_version
from libraries.topology.groups import get_commit_of_version
from libraries.topology.searcher_lookup import load_instances

import libraries.yp_sd.lock as lock
import libraries.yp_sd.yp_utils as yp_utils
import libraries.yp_sd.endpoint as endpoint
import libraries.yp_sd.endpoint_set as endpoint_set
import groups_state


def is_probably_alive(state):
    """Loose liveness heuristic for the state of a single version.

    ``state`` must provide ``alive_instances_cnt()`` and
    ``all_instances_cnt()``. The version counts as probably alive when
    more than 5% of its instances are alive, or when it has fewer than
    10 instances and at least one of them is alive. A state without any
    instances is never alive.
    """
    alive = state.alive_instances_cnt()
    total = state.all_instances_cnt()

    if total == 0:
        return False

    above_ratio = float(alive) / total > 0.05
    small_with_survivor = alive > 0 and total < 10
    return above_ratio or small_with_survivor


def is_alive(version):
    """Strict liveness check for a version summary dict.

    ``version`` carries ``total_instances`` and ``alive_instances``.
    The number of alive instances required depends on the group size:

    * more than 10 instances: ``int(0.88 * total)``
    * 5 to 10 instances: ``ceil(0.6 * total)``
    * 3 or 4 instances: 2
    * anything smaller: 1
    """
    total = version['total_instances']
    alive = version['alive_instances']

    if total > 10:
        required = int(0.88 * total)
    elif total > 4:
        required = math.ceil(0.6 * total)
    elif total in (3, 4):
        required = 2
    else:
        required = 1
    return alive >= required


def alive_version(versions):
    """Return the newest version summary that passes ``is_alive``.

    Summaries are ordered by ``tag_to_version`` of their ``topology``
    tag, highest first; the first alive one wins. Returns None when no
    summary is alive (or the input is empty).
    """
    newest_first = sorted(
        versions,
        key=lambda item: tag_to_version(item['topology']),
        reverse=True,
    )
    return next((item for item in newest_first if is_alive(item)), None)


def probably_alive_versions(versions):
    """Summarize the non-trunk versions that look alive.

    ``versions`` maps a version to its state object. Every version that
    passes ``is_probably_alive`` and is not trunk is turned into a dict
    with ``topology``, ``alive_instances`` and ``total_instances`` keys.
    The returned list is ordered by alive instance count, largest first.
    """
    summaries = [
        {
            'topology': version_to_tag(name),
            'alive_instances': state.alive_instances_cnt(),
            'total_instances': state.all_instances_cnt(),
        }
        for name, state in versions.items()
        if is_probably_alive(state) and not is_trunk(name)
    ]
    return sorted(summaries, key=lambda item: item['alive_instances'], reverse=True)


def _has_mtn_section(instance):
    """Tell whether ``instance`` has an ``hbf -> interfaces -> backbone`` entry."""
    if 'hbf' not in instance:
        return False
    hbf = instance['hbf']
    if 'interfaces' not in hbf:
        return False
    return 'backbone' in hbf['interfaces']


def group_endpoints(group, topology, mtn):
    """Build endpoint records for ``group`` at the given topology tag.

    The instance list is loaded for the commit that corresponds to
    ``topology``. Without ``mtn`` every instance contributes its own
    hostname/ipv6addr; with ``mtn`` only instances that have an
    ``hbf -> interfaces -> backbone`` section are used, and the address
    is taken from that section. The port always comes from the instance.

    Returns a list of ``endpoint.Record`` objects, or None when
    ``load_instances`` returned None.
    """
    commit = get_commit_of_version(tag_to_version(topology))
    instances = load_instances(commit, group)
    if instances is None:
        return None

    records = []
    for item in instances:
        if mtn:
            if not _has_mtn_section(item):
                # No backbone interface: this instance has no mtn endpoint.
                continue
            backbone = item['hbf']['interfaces']['backbone']
            fqdn, ipv6 = backbone['hostname'], backbone['ipv6addr']
        else:
            fqdn, ipv6 = item['hostname'], item['ipv6addr']
        records.append(endpoint.Record(fqdn=fqdn, ipv6=ipv6, port=item['port']))
    return records


@retry.retry(tries=3, delay=5)
def update_group(yp_client, group_name, alive_versions, mtn):
    """Write the current state of one group into its endpoint set.

    The endpoint set name is derived from ``group_name`` (with the
    ``'mtn'`` suffix when ``mtn`` is true). The online version is the one
    chosen by ``alive_version``; when there is none, the topology is
    stored as ``'unknown'`` and the endpoints are left untouched.

    Args:
        yp_client: client passed through to the yp_sd helpers.
        group_name: name of the group to update.
        alive_versions: list of version summaries as produced by
            ``probably_alive_versions``; stored in the annotations.
        mtn: whether to update the mtn flavour of the endpoint set.

    Raises:
        RuntimeError: when the instances of the online topology could
            not be loaded (``group_endpoints`` returned None).

    Any exception triggers the retry decorator (3 tries, 5 s delay).
    """
    endpoint_set_name = endpoint_set.group_to_endpoint_set(group_name, 'mtn' if mtn else None)
    online = alive_version(alive_versions) or {}

    # Resolve the records before opening the transaction: group_endpoints()
    # may return None, and failing here guarantees that nothing has been
    # written yet and that the error names the actual problem instead of
    # surfacing as "TypeError: 'NoneType' object is not iterable".
    records = None
    if online:
        loaded = group_endpoints(group_name, online['topology'], mtn)
        if loaded is None:
            raise RuntimeError(
                'Could not load instances of group %s for topology %s'
                % (group_name, online['topology'])
            )
        records = set(loaded)

    with yp_utils.transaction(yp_client) as transaction_id:
        endpoint_set.update_endpoint_set(
            yp_client=yp_client,
            transaction_id=transaction_id,
            endpoint_set=endpoint_set_name,
            topology=online.get('topology', 'unknown'),
            annotations={
                'all': alive_versions,
                'online': online,
            }
        )
        # An empty set is still written: it means the online topology has
        # no endpoints of this flavour.
        if records is not None:
            endpoint.update_endpoints(
                yp_client=yp_client,
                transaction_id=transaction_id,
                endpoint_set=endpoint_set_name,
                records=records,
                set_ready=True,
            )


def groups_to_update(yp_client, groups_state_, all_groups):
    """Return the groups whose stored topology differs from the live one.

    For every group the currently alive topology tag (or ``'unknown'``)
    is compared with the topology recorded for the group's endpoint set
    and for its ``'.mtn'`` counterpart. A group is selected when either
    of the two differs.

    A group that has no recorded topology at all (for example because
    its endpoint sets were not created in readonly mode) is also
    selected instead of aborting the whole run with ``KeyError``.
    """
    to_update = []
    group_to_topology = endpoint_set.list_endpoint_sets_topologies(yp_client)
    for group in all_groups:
        versions = probably_alive_versions(groups_state_.group_versions(group))
        topology = (alive_version(versions) or {}).get('topology', 'unknown')
        try:
            outdated = (
                topology != group_to_topology[group]
                or topology != group_to_topology[group + '.mtn']
            )
        except KeyError:
            # No endpoint set known for this group yet: it needs an update.
            outdated = True
        if outdated:
            to_update.append(group)
    return to_update


def update_all(yp_cluster, readonly):
    """Synchronize the endpoint sets of all groups on ``yp_cluster``.

    Steps:
      1. load the current groups state;
      2. unless ``readonly``, create missing endpoint sets (plain and
         ``mtn`` flavours) for every group;
      3. find the groups whose stored topology is outdated;
      4. unless ``readonly``, update both flavours of every such group.

    Groups whose name contains ``REMOTE_STORAGE_BASE`` or
    ``INVERTED_INDEX`` are skipped. A failure to update one group is
    logged and does not stop the processing of the remaining groups.
    """
    yp_client = yp_utils.make_client(yp_cluster)
    _log.info('start')
    _log.info('load instances state')
    groups_state_ = groups_state.get_groups_state()
    _log.info('done instances state')

    groups = sorted(groups_state_.groups())
    _log.info('total %s groups', len(groups))

    if not readonly:
        endpoint_set.create_new_endpoint_sets(
            yp_client=yp_client,
            names=groups,
        )
        endpoint_set.create_new_endpoint_sets(
            yp_client=yp_client,
            names=groups,
            name_suffix='mtn',
        )
    else:
        _log.info('readonly, do not create new endpoint_sets')

    _log.info('starting update')
    to_update = sorted(groups_to_update(yp_client, groups_state_, groups))
    _log.info('found %s groups to update', len(to_update))
    for group in to_update:
        if 'REMOTE_STORAGE_BASE' in group or 'INVERTED_INDEX' in group:
            _log.info('Skip %s', group)
            continue

        all_versions = groups_state_.group_versions(group)
        alive_versions = probably_alive_versions(all_versions)
        if not readonly:
            try:
                update_group(yp_client, group, alive_versions, mtn=False)
                update_group(yp_client, group, alive_versions, mtn=True)
            except Exception:
                # Best effort: keep going with the other groups.
                _log.exception('Could not update endpoints for group %s', group)
        else:
            # The message used to say "do change", which is the opposite
            # of what happens in readonly mode.
            _log.info('readonly, do not change endpoints for group %s', group)


def ensure_dir(path):
    """Create directory ``path`` (with parents) unless it already exists.

    The directory is created first and the "already exists" case is
    handled afterwards, so two processes starting at the same time
    cannot fail on the gap between an existence check and the creation.

    Raises:
        OSError: when the directory cannot be created, including the
            case where ``path`` exists but is not a directory.
    """
    try:
        os.makedirs(path)
    except OSError:
        # Fine if somebody (maybe a concurrent run) created it already.
        if not os.path.isdir(path):
            raise


def configure_logging():
    """Send DEBUG-level logs to ``./logs/debug.log`` and to the console.

    The file handler rotates at 1 GiB and keeps 10 backups. Each handler
    is registered through ``skynet_logging.initialize``.
    """
    ensure_dir('./logs')

    file_handler = logging.handlers.RotatingFileHandler(
        './logs/debug.log',
        maxBytes=1024 ** 3,
        backupCount=10,
    )
    file_handler.setLevel(logging.DEBUG)
    skynet_logging.initialize(handler=file_handler)

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.DEBUG)
    skynet_logging.initialize(handler=stream_handler)


def parse_args():
    """Parse the command line.

    ``--readonly`` / ``--not-readonly`` toggle the ``readonly`` flag
    (True by default); ``--yp-cluster`` selects one of
    ``yp_utils.CLUSTERS`` and defaults to ``sas-test``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--readonly',
        dest='readonly',
        action='store_true',
        default=True,
    )
    arg_parser.add_argument(
        '--not-readonly',
        dest='readonly',
        action='store_false',
    )
    arg_parser.add_argument(
        '--yp-cluster',
        default='sas-test',
        choices=yp_utils.CLUSTERS,
    )
    return arg_parser.parse_args()


def main():
    """Script entry point: configure logging, take the lock, run the update.

    The update runs inside ``lock.lock_path(args.yp_cluster,
    args.readonly)``. When taking the lock fails with
    ``YtCypressTransactionLockConflict`` a warning is logged and the
    process exits with status 0.
    """
    configure_logging()
    args = parse_args()
    try:
        with lock.lock_path(args.yp_cluster, args.readonly):
            update_all(
                yp_cluster=args.yp_cluster,
                readonly=args.readonly,
            )
    except yt.wrapper.errors.YtCypressTransactionLockConflict:
        _log.warning('could not lock, exit 0')
        # Raise SystemExit directly: the exit() builtin is injected by the
        # ``site`` module for interactive use and is not always available.
        raise SystemExit(0)


# Module-level logger used by the functions above. It is bound at import
# time, i.e. before main() is invoked below.
_log = logging.getLogger(__name__)


# Run the update only when the file is executed as a script.
if __name__ == '__main__':
    main()
