import infra.callisto.controllers.utils.yp_utils as yp_utils

import argparse
import collections
import logging
import time

import intlookup_parser


# Subset of a YP pod's fields, one instance per row produced by read_pods().
Pod = collections.namedtuple(
    'Pod',
    'id index node_id hint eviction_state',
)


# Endpoint set state; make_endpoint_set() builds the desired value and
# read_endpointsets() the stored one, and the two are compared field by field.
EndPointSet = collections.namedtuple(
    'EndPointSet',
    'id acl account_id labels port liveness_limit_ratio protocol pod_filter',
)


def main():
    """Parse command-line arguments and run the requested subcommand."""
    args = parse_args()
    configure_logging()

    command = args.command
    if command == 'endpoint_sets':
        endpoint_sets(
            endpointset_number=args.endpointset_number,
            yp_cluster=args.yp_cluster,
            pod_set=args.pod_set,
            port=args.workload_port,
            workload=args.workload_name,
            yplookup=args.yplookup,
            apply_changes=args.apply,
        )
    elif command == 'scheduling_hints':
        scheduling_hints(
            intlookup_file=args.intlookup,
            yp_cluster=args.yp_cluster,
            pod_set=args.pod_set,
            itype=args.itype,
            apply_changes=args.apply,
        )
    elif command == 'evict_misplaced':
        evict_misplaced(
            yp_cluster=args.yp_cluster,
            pod_set=args.pod_set,
            force=args.force,
            apply_changes=args.apply,
        )


def endpoint_sets(endpointset_number, yp_cluster, pod_set, port, workload, yplookup, apply_changes):
    """Create or update one endpoint set per shard of ``pod_set``.

    Builds the desired endpoint set for every shard id in
    ``range(endpointset_number)``. Each one is compared with the currently
    stored object. Missing endpoint sets are created; differing ones are
    overwritten. Every difference is logged. Writes happen only when
    ``apply_changes`` is true.
    """
    wanted = [
        make_endpoint_set(pod_set, workload, port, shard, endpointset_number, yplookup)
        for shard in range(endpointset_number)
    ]
    current = read_endpointsets(yp_cluster, [target.id for target in wanted])

    if not apply_changes:
        logging.info('Read-only mode')
    else:
        applying_changes_alert()

    for target in wanted:
        existing = current.get(target.id)
        if existing is None:
            logging.info('None -> %s', target)
            if apply_changes:
                create_endpoint_set(yp_cluster, target)
        elif existing != target:
            logging.info('%s -> %s', existing, target)
            if apply_changes:
                update_endpoint_set(yp_cluster, target)


def make_endpoint_set(pod_set, workload, port, shard_id, shard_count, yplookup):
    """Build the desired EndPointSet for shard ``shard_id`` of ``pod_set``.

    The id is ``<pod_set>.<shard_id>``, with ``.<workload>`` appended when a
    workload name is given. The pod filter selects pods of ``pod_set`` whose
    ``pod_index`` label satisfies ``index % shard_count == shard_id`` and whose
    ``/labels/gencfg/ban/set`` is not ``%true``.

    :raises ValueError: if ``pod_set`` does not split into exactly two
        dot-separated parts.
    """
    suffix = '.' + workload if workload else ''
    endpoint_set_id = '{}.{}{}'.format(pod_set, shard_id, suffix)

    pod_filter = '[/meta/pod_set_id] = "{}" AND int64([/labels/pod_index]) % {} = {} AND NOT [/labels/gencfg/ban/set] = %true'.format(
        pod_set, shard_count, shard_id)

    # The parts are unused except by the commented-out deploy ACL subjects.
    # The unpack still rejects pod_set ids that do not have exactly one dot.
    _stage, _deploy_unit = pod_set.split('.')

    subjects = [
        'mcden',
        'okats',
        'robot-juba',
        'abc:service:3429',  # webbasesearchsshsudo
        # 'deploy:{}.OWNER'.format(stage),
        # 'deploy:{}.MAINTAINER'.format(stage),
    ]
    acl = [{'action': 'allow', 'subjects': subjects, 'permissions': ['read', 'write']}]

    return EndPointSet(
        id=endpoint_set_id,
        acl=acl,
        account_id='abc:service:3429',
        labels={'gencfg': {'yplookup': yplookup}},
        port=int(port),
        liveness_limit_ratio=1.0,
        protocol='TCP',
        pod_filter=pod_filter,
    )


def read_endpointsets(cluster, ids):
    """Fetch the stored endpoint sets with the given ids from YP.

    :param cluster: YP cluster name passed to ``yp_utils.client``.
    :param ids: endpoint set ids to look up; requested with
        ``ignore_nonexistent`` so missing ids do not fail the call.
    :returns: dict mapping endpoint set id to an ``EndPointSet`` built from
        the stored object. Falsy entries in the response are skipped;
        presumably these are the nonexistent ids (TODO confirm).
    """
    selectors = [
        '/meta/id',
        '/meta/acl',
        '/meta/account_id',
        '/labels',
        '/spec/port',
        '/spec/liveness_limit_ratio',
        '/spec/protocol',
        '/spec/pod_filter'
    ]
    with yp_utils.client(cluster) as client:
        rows = client.get_objects(
            'endpoint_set',
            ids,
            selectors=selectors,
            options={'ignore_nonexistent':  True},
        )

    found = {}
    for row in rows:
        if not row:
            continue
        found[row[0]] = _endpoint_set_from_row(row)
    return found


def _endpoint_set_from_row(row):
    """Convert one read_endpointsets() row (selector order) into an EndPointSet."""
    return EndPointSet(
        id=row[0],
        acl=row[1],
        account_id=row[2],
        labels=row[3],
        port=row[4],
        # Falsy ratio/protocol values are normalized to None.
        liveness_limit_ratio=row[5] or None,
        protocol=row[6] or None,
        pod_filter=row[7],
    )


def update_endpoint_set(cluster, endpoint_set):
    """Overwrite every managed field of an existing YP endpoint set.

    :param cluster: YP cluster name passed to ``yp_utils.client``.
    :param endpoint_set: desired ``EndPointSet``; its ``id`` selects the object.
    """
    fields = (
        ('/meta/acl', endpoint_set.acl),
        ('/meta/account_id', endpoint_set.account_id),
        ('/labels', endpoint_set.labels),
        ('/spec/port', endpoint_set.port),
        ('/spec/liveness_limit_ratio', endpoint_set.liveness_limit_ratio),
        ('/spec/protocol', endpoint_set.protocol),
        ('/spec/pod_filter', endpoint_set.pod_filter),
    )
    set_updates = [{'path': path, 'value': value} for path, value in fields]
    with yp_utils.client(cluster) as client:
        client.update_object('endpoint_set', endpoint_set.id, set_updates=set_updates)


def create_endpoint_set(cluster, endpoint_set):
    """Create ``endpoint_set`` as a new YP endpoint_set object on ``cluster``."""
    meta = {
        'id': endpoint_set.id,
        'acl': endpoint_set.acl,
        'account_id': endpoint_set.account_id
    }
    spec = {
        'pod_filter': endpoint_set.pod_filter,
        'protocol': endpoint_set.protocol,
        'liveness_limit_ratio': endpoint_set.liveness_limit_ratio,
        'port': endpoint_set.port
    }
    attributes = {'meta': meta, 'labels': endpoint_set.labels, 'spec': spec}
    with yp_utils.client(cluster) as client:
        client.create_object('endpoint_set', attributes=attributes)


def scheduling_hints(intlookup_file, yp_cluster, pod_set, itype, apply_changes):
    """Set each pod's strong scheduling hint to the host intlookup assigns it.

    Pods of ``pod_set`` are walked in ``pod_index`` order. Pod ``i`` takes the
    next unused host of shard ``i % shard_count``. Pods whose current hint
    differs from that host get a ``/spec/scheduling/hints`` update. Pods
    running off their hint, or not running on any node, are logged.

    :param intlookup_file: path passed to ``intlookup_parser.parse``.
    :param yp_cluster: YP cluster name.
    :param pod_set: pod set id whose pods are processed.
    :param itype: ``'base'``, ``'int'`` or ``'intl2'``.
    :param apply_changes: when false, only logs the planned changes.
    :raises RuntimeError: on an unknown ``itype``, when the itype has no
        instances in the intlookup, or when a shard has fewer hosts than pods.
    """
    intlookup = intlookup_parser.parse(intlookup_file)
    logging.info('Tier: %s, itype: %s', intlookup.tier, itype)

    if apply_changes:
        applying_changes_alert()
    else:
        logging.info('Read-only mode')

    # Pods are sorted by pod index.
    # Hosts are in intlookup-defined order.
    shard_count, instances = _intlookup_instances(intlookup, itype)
    # Explicit raise instead of ``assert``: asserts are stripped under ``python -O``.
    if shard_count <= 0:
        raise RuntimeError('itype {} has no instances at intlookup {}'.format(itype, intlookup_file))

    updates = []
    for pod in sorted(read_pods(yp_cluster, pod_set), key=lambda p: p.index):
        shard = pod.index % shard_count
        shard_hosts = instances[shard]
        if not shard_hosts:
            # Otherwise pop(0) would fail with an unexplained IndexError.
            raise RuntimeError('No free host left in shard {} of intlookup {} for pod {}'.format(
                shard, intlookup_file, pod.id))
        # Consume hosts in order, so pods sharing a shard get distinct hosts.
        node = shard_hosts.pop(0)
        if pod.hint != pod.node_id and pod.node_id:
            logging.error('Outlaw pod: %s %s should be %s', pod, pod.node_id, pod.hint)
        if not pod.node_id:
            logging.warning('Unassigned pod: %s should be %s', pod, pod.hint)
        if pod.hint != node:
            logging.info('%s -> %s', pod, node)
            updates.append({
                'object_id': pod.id,
                'object_type': 'pod',
                'set_updates': [{
                    'path': '/spec/scheduling/hints',
                    'value': [{'node_id': node, 'strong': True}]
                }]
            })

    if updates and apply_changes:
        with yp_utils.client(yp_cluster) as client:
            client.update_objects(updates)

    logging.info('Total %s misplaced pods', len(updates))


def _intlookup_instances(intlookup, itype):
    """Return ``(shard_count, instances)`` for ``itype`` from a parsed intlookup.

    ``instances[shard]`` is expected to be a mutable list of hosts for that
    shard, since scheduling_hints pops from it (presumably host ids; confirm
    against intlookup_parser).

    :raises RuntimeError: if ``itype`` is not one of base/int/intl2.
    """
    if itype == 'base':
        shard_count = intlookup.tier.shard_count if intlookup.tier else intlookup.instance_count
        return shard_count, intlookup.shard_hosts
    if itype == 'int':
        return len(intlookup.ints), intlookup.ints
    if itype == 'intl2':
        return len(intlookup.ints_l2), intlookup.ints_l2
    raise RuntimeError('itype {} is unknown'.format(itype))


def read_pods(cluster, pod_set):
    """Select all pods of ``pod_set`` from YP and return them as ``Pod`` tuples.

    ``index`` is the ``/labels/pod_index`` label converted with ``int()``.
    ``hint`` is the node id of the first scheduling hint.
    """
    query = '[/meta/pod_set_id]="{}"'.format(pod_set)
    selectors = ['/meta/id', '/labels/pod_index',
                 '/spec/node_id', '/spec/scheduling/hints/0/node_id',
                 '/status/eviction/state']
    with yp_utils.client(cluster) as client:
        rows = client.select_objects('pod', filter=query, selectors=selectors)
    logging.info('%s pods', len(rows))

    result = []
    for pod_id, pod_index, node_id, hint, eviction_state in rows:
        result.append(Pod(id=pod_id, index=int(pod_index), node_id=node_id,
                          hint=hint, eviction_state=eviction_state))
    return result


def evict_misplaced(yp_cluster, pod_set, force, apply_changes):
    """Request eviction of pods running on a node other than their hint.

    A pod is selected when it has a node, has a hint, the two differ, and its
    ``/status/eviction/state`` is ``'none'``.

    :param yp_cluster: YP cluster name.
    :param pod_set: pod set id whose pods are inspected.
    :param force: when true, eviction is requested with
        ``validate_disruption_budget`` set to false.
    :param apply_changes: when false, only logs what would be evicted.
    """
    if apply_changes:
        applying_changes_alert()
    else:
        logging.info('Read-only mode')

    updates = []
    for pod in read_pods(yp_cluster, pod_set):
        if not pod.node_id or not pod.hint:
            continue
        if pod.node_id == pod.hint or pod.eviction_state != 'none':
            continue
        # Include the pod id: "node -> hint" alone does not say which pod moves.
        logging.info('%s: %s -> %s', pod.id, pod.node_id, pod.hint)
        updates.append({
            'object_type': 'pod',
            'object_id': pod.id,
            'set_updates': [{
                'path': '/control/evict',
                'value': {'message': 'evict missplaced pod',
                          'validate_disruption_budget': not force}
            }]
        })

    if apply_changes:
        # Like scheduling_hints: skip the client/transaction when nothing to do.
        if updates:
            with yp_utils.client(yp_cluster) as client:
                client.update_objects(updates)
        logging.info('Total %s pods have been evicted', len(updates))
    else:
        logging.info('Total %s pods have to be evicted', len(updates))


def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Subcommands: ``endpoint_sets``, ``scheduling_hints``, ``evict_misplaced``.
    ``--apply`` belongs to the top-level parser, so it must be given before
    the subcommand name.

    :returns: ``argparse.Namespace``; ``command`` holds the subcommand name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--apply', action='store_true', default=False,
                        help='Apply changes (read-only otherwise); '
                             'must precede the subcommand')
    subparsers = parser.add_subparsers(dest='command')
    # Subcommands are optional by default on Python 3, which made the script
    # exit silently without doing anything. The attribute form also works on 2.7.
    subparsers.required = True

    endpoint_set_parser = subparsers.add_parser('endpoint_sets')
    endpoint_set_parser.add_argument('--endpointset_number', type=int,
                                     required=True)
    endpoint_set_parser.add_argument('--yp_cluster', required=True)
    endpoint_set_parser.add_argument('--pod_set', required=True)
    endpoint_set_parser.add_argument('--workload_port', type=int, default=80)
    endpoint_set_parser.add_argument('--workload_name', required=False,
                                     help='Empty by default')
    endpoint_set_parser.add_argument('--yplookup', required=True, help='Used in labels only')

    scheduling_parser = subparsers.add_parser('scheduling_hints')
    scheduling_parser.add_argument('--intlookup', required=True)
    scheduling_parser.add_argument('--yp_cluster', required=True)
    scheduling_parser.add_argument('--pod_set', required=True)
    scheduling_parser.add_argument('--itype', required=True,
                                   choices=['base', 'int', 'intl2'],
                                   help='Instance level/type in intlookup')

    # Named evict_parser so it does not shadow the evict_misplaced() function.
    evict_parser = subparsers.add_parser('evict_misplaced')
    evict_parser.add_argument('--yp_cluster', required=True)
    evict_parser.add_argument('--pod_set', required=True)
    evict_parser.add_argument('--force', action='store_true', default=False)

    return parser.parse_args()


def configure_logging():
    """Lower the root logger's threshold to INFO.

    No handler is attached here. Per the stdlib logging docs, module-level
    helpers such as ``logging.info()`` call ``basicConfig()`` when the root
    logger has no handlers.
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)


def applying_changes_alert(delay=1):
    """Log a warning and count down 3..1 before changes are applied.

    After each message it sleeps ``delay`` seconds, so the operator has a
    moment to interrupt the run before any writes are sent.

    :param delay: seconds to sleep after each message (default 1).
    """
    # logging.warning, not the deprecated logging.warn alias.
    logging.warning('Applying changes')
    time.sleep(delay)
    for remaining in (3, 2, 1):
        logging.warning('...%d', remaining)
        time.sleep(delay)


# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
