import string

import click as c
import requests

from infra.awacs.tools.awacstoolslib.app import Selector, NamespaceSelector
from infra.awacs.tools.awacstoolslib.util import (cli, wait_for_confirmation, update_l7_macro_version,
                                                  update_instance_macro_version,
                                                  update_awacslet_version, update_instancectl_conf_version,
                                                  update_base_layer_version, is_l7_macro_version_gte,
                                                  update_l7_macro_version_preserving_comments,
                                                  is_instance_macro_version_gte, create_awacs_backends_list_href,
                                                  create_awacs_backend_href)


def get_backend_groups(a, namespace_id):
    """
    Collect the sets of backends that are included together by the upstreams of a namespace.

    Only upstreams whose included backends all belong to ``namespace_id`` are returned; an
    upstream that includes at least one backend from another namespace is skipped entirely.

    :type a: infra.awacs.tools.awacstoolslib.awacsclient.AwacsClient
    :type namespace_id: six.text_type
    :rtype: List[frozenset]
    :returns: one frozenset of backend ids (without the namespace prefix) per eligible upstream
    :raises ValueError: if the inclusion graph of the namespace is reported as not OK
    """
    ok, graph = a.get_inclusion_graph(namespace_id)
    if not ok:
        raise ValueError('inclusion graph for {} is not OK'.format(namespace_id))
    rv = []
    for node in graph:
        if node['type'] != 'upstream':
            continue
        included_backend_ids = set()
        for flat_id in node['included_backend_ids']:
            # Flat ids are "<namespace_id>/<backend_id>"; split only on the first "/" so that
            # a "/" inside the backend id does not break the unpacking.
            backend_namespace_id, backend_id = flat_id.split('/', 1)
            if backend_namespace_id != namespace_id:
                # The upstream includes a foreign backend: skip the whole upstream.
                break
            included_backend_ids.add(backend_id)
        else:
            # No break: every included backend belongs to this namespace.
            rv.append(frozenset(included_backend_ids))
    return rv


def does_endpoint_set_exist(cluster, endpoint_set_id, timeout=30):
    """
    Ask the YP service discovery resolver whether an endpoint set resolves in a cluster.

    :type cluster: six.text_type
    :type endpoint_set_id: six.text_type
    :param timeout: seconds to wait for the resolver (passed to ``requests.post``). Without a
        timeout, a stuck connection would block the calling migration indefinitely.
    :rtype: bool
    :returns: True if the resolver answers with ``resolve_status == 2``
    :raises requests.HTTPError: if the resolver responds with an HTTP error status
    :raises requests.Timeout: if the resolver does not answer within ``timeout`` seconds
    """
    req = {
        'endpoint_set_id': endpoint_set_id,
        'client_name': 'awacssdtool',
        'cluster_name': cluster.lower(),
    }
    resp = requests.post('http://sd.yandex.net:8080/resolve_endpoints/json', json=req, timeout=timeout)
    resp.raise_for_status()
    # NOTE(review): 2 is presumably the SD "endpoint set found" resolve status -- confirm against the SD API.
    return resp.json()['resolve_status'] == 2


@cli.command('check_migrated_backends')
@c.pass_obj
def check_migrated_backends(app):
    """
    Check that backends recorded as migrated by the op use SD and are valid in all L7 statuses.

    :type app: infra.awacs.tools.awacstoolslib.app.App
    """
    a = app.awacs_client
    if app.op is None:
        c.secho('Op must be configured (see help)', fg='red')
        return

    # The op stores migrated backends as "<namespace_id>/<backend_id>" strings.
    full_migrated_backend_ids = set()
    for flat_backend_id in app.op.data.get('migrated_flat_backend_ids', []):
        full_migrated_backend_ids.add(tuple(flat_backend_id.split('/', 1)))

    checked_statuses_count = 0
    for namespace_id, backend_id in sorted(full_migrated_backend_ids):
        backend_pb = a.get_backend(namespace_id, backend_id)
        selector_pb = backend_pb.spec.selector
        ok = True
        if selector_pb.type != selector_pb.YP_ENDPOINT_SETS_SD:
            # Report instead of asserting: one bad backend must not abort the whole check,
            # and the check must not disappear when Python runs with -O.
            c.secho('{}:{} does not use YP_ENDPOINT_SETS_SD'.format(namespace_id, backend_id), fg='red')
            ok = False
        for status_pb in backend_pb.statuses:
            for cond_pb in status_pb.validated.values():
                checked_statuses_count += 1
                if cond_pb.status != 'True':
                    ok = False
        if not ok:
            c.secho('{}:{} is not OK'.format(namespace_id, backend_id), fg='red')
            c.echo(create_awacs_backend_href(namespace_id, backend_id))
    c.secho(u'Checked {} backends, {} rev x balancer pairs'.format(
        len(full_migrated_backend_ids), checked_statuses_count), fg='green')


def _has_validated_true_condition(status_pbs):
    """
    Return True if any status in ``status_pbs`` has a ``validated`` condition with status 'True'.

    ``migrate_backends`` applies this to a backend's ``dns_record_statuses``, ``l3_statuses`` and
    ``statuses`` to compute its ``used_in_dns`` / ``used_in_l3`` / ``used_in_l7`` flags.
    """
    for rev_pb in status_pbs:
        for cond_pb in rev_pb.validated.values():
            if cond_pb.status == 'True':
                return True
    return False


def _is_backend_migratable_to_sd(backend_pb):
    """
    Return True if the backend selector can be switched from YP_ENDPOINT_SETS to YP_ENDPOINT_SETS_SD.

    The selector must be YP_ENDPOINT_SETS with the KEEP port policy. Every endpoint set must use
    KEEP port and weight policies and must resolve via ``does_endpoint_set_exist``. The checks
    stop at the first failing endpoint set, so SD is not queried once the answer is known.
    """
    selector_pb = backend_pb.spec.selector
    if selector_pb.type != selector_pb.YP_ENDPOINT_SETS:
        return False
    if selector_pb.port.policy != selector_pb.port.KEEP:
        return False
    for es_pb in selector_pb.yp_endpoint_sets:
        if es_pb.port.policy != es_pb.port.KEEP or es_pb.weight.policy != es_pb.weight.KEEP:
            return False
        if not does_endpoint_set_exist(es_pb.cluster, es_pb.endpoint_set_id):
            return False
    return True


@cli.command('migrate_backends')
@c.option('--selector', type=NamespaceSelector.parse, required=True)
@c.pass_obj
def migrate_backends(app, selector):
    """
    Switch eligible backends of the selected namespaces to YP_ENDPOINT_SETS_SD.

    For every namespace, all balancers must be YP.lite-powered, have every pod processed
    (a /awacs volume is present) and run l7_macro >= 0.2.3 or instance_macro >= 0.0.2;
    otherwise the namespace is reported as failed. A backend is migrated when it is used in L7
    but not in L3 or DNS and its endpoint sets are SD-compatible. The backends of an upstream are
    migrated all together or not at all. Migrated backends are recorded in the op.

    :type app: infra.awacs.tools.awacstoolslib.app.App
    :type selector: NamespaceSelector
    """
    if app.op is None:
        c.secho('Op must be configured (see help)', fg='red')
        return

    def get_migrated_backends(op):
        # The op stores migrated backends as "<namespace_id>/<backend_id>" strings.
        rv = set()
        for flat_backend_id in op.data.get('migrated_flat_backend_ids', []):
            rv.add(tuple(flat_backend_id.split('/', 1)))
        return rv

    def save_migrated_backends(op, full_migrated_backend_ids):
        op.data['migrated_flat_backend_ids'] = sorted('/'.join(id_) for id_ in full_migrated_backend_ids)
        op.save()

    a = app.awacs_client

    full_migrated_backend_ids = get_migrated_backends(app.op)

    def check_balancers_are_ready(namespace_id):
        """Raise ValueError unless every balancer of the namespace passes the pre-migration checks."""
        for balancer_pb in a.iter_all_balancers(namespace_id_in=(namespace_id,)):
            balancer_id = balancer_pb.meta.id
            cluster = balancer_pb.meta.location.yp_cluster.upper()
            if not cluster:
                raise ValueError('{}:{} is not YP.lite-powered'.format(namespace_id, balancer_id))
            service_id = balancer_pb.spec.config_transport.nanny_static_file.service_id
            c.echo('Looking at {} pods in {}'.format(service_id, cluster))
            for pod_pb in app.nanny_client.iter_pods(cluster, service_id):
                if not is_pod_processed(pod_pb):
                    raise ValueError('Pod {} in {}:{} is not processed'.format(
                        pod_pb.meta.id, namespace_id, balancer_id))

            spec_pb = balancer_pb.spec
            config_pb = spec_pb.yandex_balancer.config
            if config_pb.HasField('l7_macro'):
                if not is_l7_macro_version_gte(spec_pb, '0.2.3'):
                    raise ValueError('{}:{} l7_macro is < 0.2.3'.format(namespace_id, balancer_id))
            elif config_pb.HasField('instance_macro'):
                if not is_instance_macro_version_gte(spec_pb, '0.0.2'):
                    raise ValueError('{}:{} instance_macro is < 0.0.2'.format(namespace_id, balancer_id))

    def process_namespace(namespace_id):
        check_balancers_are_ready(namespace_id)

        rps_stats = a.get_yesterday_max_rps_stats_by_namespace()
        namespace_rps = rps_stats.get(namespace_id, 0)

        backend_groups = get_backend_groups(a, namespace_id)

        l7_backend_id_types = {}  # backend id -> selector type name, for backends used in L7
        l7_backend_ids_to_be_migrated = set()
        unused_backend_ids = set()

        backend_pbs = a.list_backends(namespace_id)
        for backend_pb in backend_pbs:
            backend_id = backend_pb.meta.id
            if backend_pb.spec.deleted:
                continue
            if backend_pb.resolver_status.last_attempt.succeeded.status != 'True':
                continue
            used_in_dns = _has_validated_true_condition(backend_pb.dns_record_statuses)
            used_in_l3 = _has_validated_true_condition(backend_pb.l3_statuses)
            used_in_l7 = _has_validated_true_condition(backend_pb.statuses)
            if used_in_l7:
                selector_pb = backend_pb.spec.selector
                l7_backend_id_types[backend_id] = selector_pb.Type.Name(selector_pb.type)
            if not used_in_dns and not used_in_l3 and not used_in_l7:
                unused_backend_ids.add(backend_id)
            # Unused backends are still evaluated here: they take part in the upstream group
            # check below and are only discarded after it.
            if not used_in_dns and not used_in_l3 and _is_backend_migratable_to_sd(backend_pb):
                l7_backend_ids_to_be_migrated.add(backend_id)

        # Backends included by the same upstream are migrated together or not at all.
        for group_backend_ids in backend_groups:
            if (l7_backend_ids_to_be_migrated & group_backend_ids and
                    not group_backend_ids.issubset(l7_backend_ids_to_be_migrated)):
                c.secho('group {} contains both fixable and unfixable backends'.format(
                    ', '.join(sorted(group_backend_ids))), fg='yellow')
                l7_backend_ids_to_be_migrated -= group_backend_ids

        l7_backend_ids_to_be_migrated -= unused_backend_ids

        if not l7_backend_ids_to_be_migrated:
            c.secho('No backends to be migrated', fg='green')
            return

        c.secho('Backends to be migrated:')
        total_non_sd = 0
        for backend_id, selector_type in sorted(l7_backend_id_types.items()):
            if selector_type == 'YP_ENDPOINT_SETS_SD':
                fg = 'green'
            else:
                fg = 'red' if selector_type == 'YP_ENDPOINT_SETS' else 'blue'
                total_non_sd += 1
            # str.ljust rather than string.ljust: the latter does not exist on Python 3.
            c.echo(backend_id.ljust(60) + ' ' + c.style(selector_type, fg=fg))
        c.secho('\nTotal non-SD backends: {}, to be migrated: {}'.format(total_non_sd,
                                                                         len(l7_backend_ids_to_be_migrated)),
                fg='blue')
        for backend_id in sorted(l7_backend_ids_to_be_migrated):
            c.secho(' * {}'.format(backend_id))
        c.secho('Namespace {} serves {} RPS at peak times'.format(namespace_id, namespace_rps))

        msg = 'The tool is about to migrate {} affected backends in {}'.format(
            len(l7_backend_ids_to_be_migrated), namespace_id)
        if not wait_for_confirmation(msg, confirm_automatically_after=60):
            return
        for backend_pb in backend_pbs:
            if backend_pb.meta.id not in l7_backend_ids_to_be_migrated:
                continue
            backend_pb.spec.selector.type = backend_pb.spec.selector.YP_ENDPOINT_SETS_SD
            a.update_backend(namespace_id=backend_pb.meta.namespace_id,
                             backend_id=backend_pb.meta.id,
                             version=backend_pb.meta.version,
                             spec_pb=backend_pb.spec,
                             comment='SWATOPS-270: use SD')
            c.secho('Updated backend "{}"'.format(backend_pb.meta.id), fg='green')
            # Persist after every backend so an interrupted run keeps track of its progress.
            full_migrated_backend_ids.add((backend_pb.meta.namespace_id, backend_pb.meta.id))
            save_migrated_backends(app.op, full_migrated_backend_ids)
        c.secho('Migrated {}'.format(create_awacs_backends_list_href(namespace_id)), fg='green')

    for namespace_id in sorted(selector.resolver(app)):
        try:
            process_namespace(namespace_id)
        except Exception as e:
            c.secho('Failed to migrate {}: {}'.format(namespace_id, e), fg='red')


def get_label(labels, key, default=None):
    """
    Look up a label value by its key.

    :param labels: object with an iterable ``attributes`` of items having ``key`` and ``value``
    :param key: attribute key to look for
    :param default: value returned when no attribute has this key
    :returns: the value of the first attribute whose key equals ``key``, else ``default``
    """
    return next((attr.value for attr in labels.attributes if attr.key == key), default)


def get_disk_request_by_mount_path(disk_volume_requests, mount_path):
    """
    Find the disk volume request whose ``mount_path`` label equals ``mount_path``.

    :returns: the first matching request, or None if there is none
    """
    matching = (request for request in disk_volume_requests
                if get_label(request.labels, key='mount_path') == mount_path)
    return next(matching, None)


def is_pod_processed(pod_pb):
    """
    Tell whether a pod spec has a disk volume request mounted at ``/awacs``.

    :returns: True if such a disk volume request exists
    """
    requests_ = pod_pb.spec.disk_volume_requests
    return get_disk_request_by_mount_path(requests_, '/awacs') is not None


def _update_balancer_pods(app, namespace_id, balancer_id, cluster, service_id):
    """
    Run the Nanny pod update script on every pod of ``service_id`` that still lacks a /awacs volume.

    A pod counts as already processed if it has a disk volume request mounted at /awacs. A pod
    without one is updated only if it has exactly two disk volume requests, the only layout this
    tool updates. Pods with any other layout are reported, left untouched and NOT counted as
    processed. This matches ``is_pod_processed``, which ``migrate_backends`` later relies on.

    :returns: tuple (updated_pods_count, processed_pods_count, total_pods_count), where
        processed_pods_count = pods that already had /awacs + pods updated by this call
    """
    pod_versions_to_process = {}
    unsupported_pod_ids = []
    total_pods_count = 0
    already_processed_pods_count = 0
    c.echo('Looking at {} pods in {}'.format(service_id, cluster))
    for pod_pb in app.nanny_client.iter_pods(cluster, service_id):
        total_pods_count += 1
        volume_requests = pod_pb.spec.disk_volume_requests
        if get_disk_request_by_mount_path(volume_requests, '/awacs') is not None:
            already_processed_pods_count += 1
        elif len(volume_requests) == 2:
            pod_versions_to_process[pod_pb.meta.id] = get_label(pod_pb.labels, 'nanny_version')
        else:
            unsupported_pod_ids.append(pod_pb.meta.id)

    c.echo('{} of {} pods are already processed'.format(already_processed_pods_count, total_pods_count))
    if unsupported_pod_ids:
        c.secho('{} pods have no /awacs volume and an unexpected number of disk volumes, '
                'not touching them: {}'.format(len(unsupported_pod_ids), ', '.join(sorted(unsupported_pod_ids))),
                fg='yellow')

    updated_pods_count = 0
    if pod_versions_to_process:
        c.secho('Processing {} affected pods in {}:{}'.format(
            len(pod_versions_to_process), namespace_id, balancer_id))
        for pod_id, pod_version in sorted(pod_versions_to_process.items()):
            try:
                app.nanny_client.update_pod_by_script(
                    cluster=cluster,
                    service_id=service_id,
                    pod_id=pod_id,
                    pod_version=pod_version
                )
            except Exception as e:
                # Best effort per pod: a failure leaves the pod unprocessed, which the caller detects.
                c.secho(' * failed to update pod "{}": {}'.format(pod_id, e), fg='yellow')
            else:
                c.secho(' * updated pod "{}"'.format(pod_id), fg='green')
                updated_pods_count += 1
        c.secho('Updated {} pods: https://nanny.yandex-team.ru/ui/#/services/catalog/{}/yp_pods/'.format(
            updated_pods_count, service_id), fg='green')

    return updated_pods_count, already_processed_pods_count + updated_pods_count, total_pods_count


def process_balancer(app, balancer_pb, continue_with_incomplete_pods=False):
    """
    Split the /awacs volume off a balancer's pods, then bump its component versions in awacs.

    Unless ``continue_with_incomplete_pods`` is set, the spec is left alone while some pods
    remain unprocessed. The spec update sets:

    * l7_macro to 0.2.3 (unless the YAML looks sedem-managed) or instance_macro to 0.0.2;
    * instancectl.conf to 0.1.6 or awacslet to 0.0.4;
    * the base layer to xenial-1.

    All changes go out in a single awacs update.

    :type app: infra.awacs.tools.awacstoolslib.app.App
    :param balancer_pb: balancer message as returned by the awacs client
    :param continue_with_incomplete_pods: update the spec even if not every pod is processed
    :rtype: (bool, bool)
    :returns: ``(updated, processed)``. ``updated`` is True when the balancer spec was updated,
        or when pods were updated but the spec update was postponed because of unprocessed pods.
        ``processed`` is True when the balancer needs no further work (spec updated, or nothing to
        update). It is False whenever the run stopped early or the awacs update failed.
    """
    namespace_id = balancer_pb.meta.namespace_id
    balancer_id = balancer_pb.meta.id
    cluster = balancer_pb.meta.location.yp_cluster.upper()
    service_id = balancer_pb.spec.config_transport.nanny_static_file.service_id

    updated_pods_count, processed_pods_count, total_pods_count = _update_balancer_pods(
        app, namespace_id, balancer_id, cluster, service_id)
    if processed_pods_count != total_pods_count and not continue_with_incomplete_pods:
        c.secho('Not proceeding, processed just {} out of {} pods'.format(
            processed_pods_count, total_pods_count), fg='yellow')
        return bool(updated_pods_count), False

    update_msg_parts = []
    non_update_msg_parts = []
    tickets = set()

    def record(result, ticket):
        # result is the (updated, message) pair returned by the update_* helpers.
        upd, msg = result
        if upd:
            update_msg_parts.append(msg)
            tickets.add(ticket)
        else:
            non_update_msg_parts.append(msg)

    # All edits below happen on the in-memory spec; nothing reaches awacs before update_balancer,
    # so an early return simply discards them.
    spec_pb = balancer_pb.spec
    config_pb = spec_pb.yandex_balancer.config
    if config_pb.HasField('l7_macro'):
        if '### balancer_deployer_sign' in spec_pb.yandex_balancer.yaml:
            c.secho('Not editing YAML, balancers seems to be sedem-managed', fg='yellow')
        elif '#' in spec_pb.yandex_balancer.yaml:
            # Presumably the plain update re-renders the YAML and would drop the comments.
            c.secho('l7_macro YAML has comments, using update_l7_macro_version_preserving_comments', fg='yellow')
            record(update_l7_macro_version_preserving_comments(spec_pb, '0.2.3'), 'SWATOPS-281')
        else:
            record(update_l7_macro_version(spec_pb, '0.2.3'), 'SWATOPS-281')
    elif config_pb.HasField('instance_macro'):
        record(update_instance_macro_version(spec_pb, '0.0.2'), 'SWATOPS-267')
    elif config_pb.HasField('quick_start_balancer_macro'):
        c.secho('Not proceeding, balancer is in quick start mode', fg='red')
        return False, False

    instancectl_conf_pb = spec_pb.components.instancectl_conf
    awacslet_pb = spec_pb.components.awacslet
    if instancectl_conf_pb.state == instancectl_conf_pb.SET:
        record(update_instancectl_conf_version(spec_pb,
                                               to_version='0.1.6',
                                               to_pushclient_version='0.1.6-pushclient'),
               'SWATOPS-267')
    elif awacslet_pb.state == awacslet_pb.SET:
        record(update_awacslet_version(spec_pb,
                                       to_version='0.0.4',
                                       to_pushclient_version='0.0.4-pushclient'),
               'SWATOPS-267')
    else:
        c.secho('Not proceeding, balancer has neither instancectl.conf nor awacslet', fg='red')
        return False, False

    base_layer_pb = spec_pb.components.base_layer
    if base_layer_pb.state != base_layer_pb.SET:
        # Explicit check instead of an assert: skip this balancer like the other
        # "Not proceeding" cases instead of aborting the whole run (or silently going on under -O).
        c.secho('Not proceeding, balancer does not have base_layer set', fg='red')
        return False, False
    record(update_base_layer_version(spec_pb, to_version='xenial-1'), 'SWATOPS-103')

    if not update_msg_parts:
        c.echo('Did not update {}:{}'.format(namespace_id, balancer_id))
        for msg in non_update_msg_parts:
            c.echo(' * ' + msg)
        return False, True

    comment = '{}: {}'.format(', '.join(sorted(tickets)), '; '.join(update_msg_parts))
    try:
        app.awacs_client.update_balancer(namespace_id=namespace_id,
                                         balancer_id=balancer_id,
                                         version=balancer_pb.meta.version,
                                         spec_pb=balancer_pb.spec,
                                         comment=comment)
    except Exception as e:
        c.secho('failed to update {}:{}: {}'.format(namespace_id, balancer_id, e), fg='red')
        return False, False
    c.secho(
        'Updated balancer spec, see https://nanny.yandex-team.ru/ui/#/awacs/'
        'namespaces/list/{}/monitoring/common/'.format(namespace_id), fg='green')
    for msg in update_msg_parts:
        c.secho(' * ' + msg, fg='green')
    return True, True


def get_unprocessed_full_balancer_ids(app, selector):
    """
    Split the balancers matched by ``selector`` into processed and not-yet-processed ones.

    Processed balancers are read from the op's ``processed_flat_balancer_ids`` list of
    "<namespace_id>/<balancer_id>" strings and turned into tuples by splitting on the first "/".

    :returns: tuple (selected ids, processed ids, selected ids minus processed ids)
    """
    flat_ids = app.op.data.get('processed_flat_balancer_ids', [])
    processed = {tuple(flat_id.split('/', 1)) for flat_id in flat_ids}
    selected = selector.resolver(app)
    return selected, processed, selected - processed


@cli.command()
@c.option('--selector', type=Selector.parse, required=True)
@c.pass_obj
def ls(app, selector):
    """
    List the balancers matched by the selector that the op has not processed yet.

    :type app: infra.awacs.tools.awacstoolslib.app.App
    :type selector: Selector
    """
    if app.op is not None:
        _, _, pending_ids = get_unprocessed_full_balancer_ids(app, selector)
        for namespace_id, balancer_id in sorted(pending_ids):
            c.echo(namespace_id + '/' + balancer_id)
    else:
        c.secho('Op must be configured (see help)', fg='red')


@cli.command('split_off_awacs_volume')
@c.option('--selector', type=Selector.parse, required=True)
@c.option('--confirm-every', type=c.IntRange(min=1), default=1)
@c.option('--confirm-automatically-after', type=int, default=-1)
@c.option('--continue-with-incomplete-pods', type=bool, default=False)
@c.pass_obj
def split_off_awacs_volume(app, selector, confirm_every,
                           continue_with_incomplete_pods=False,
                           confirm_automatically_after=-1):
    """
    Split the /awacs volume off the pods of the selected, not yet processed balancers.

    Balancers are processed one by one with ``process_balancer``. Each balancer that reaches a
    final state is recorded in the op, so a re-run skips it.

    :type app: infra.awacs.tools.awacstoolslib.app.App
    :type selector: Selector
    :type confirm_every: int
    :type confirm_automatically_after: int
    """
    if app.op is None:
        c.secho('Op must be configured (see help)', fg='red')
        return

    def save_processed_balancers(op, full_balancer_ids):
        op.data['processed_flat_balancer_ids'] = sorted('/'.join(full_balancer_id)
                                                        for full_balancer_id in full_balancer_ids)
        op.save()

    _, processed_full_balancer_ids, unprocessed_full_balancer_ids = \
        get_unprocessed_full_balancer_ids(app, selector)

    c.echo('found {} unprocessed balancers in {}'.format(
        len(unprocessed_full_balancer_ids),
        selector.expr))

    if not unprocessed_full_balancer_ids:
        c.secho('No unprocessed balancers found', fg='red')
        return

    if not wait_for_confirmation('Let\'s start?'):
        return

    rps_data = app.awacs_client.get_yesterday_max_rps_stats_by_balancer()
    # Confirmation cadence: i is 1 + the number of balancers updated so far. Confirmation is
    # asked before a balancer only if something was updated since the last confirmation
    # (just_confirmed is False) and i is divisible by confirm_every (guaranteed >= 1 by IntRange).
    i = 1
    just_confirmed = True
    for balancer_pb in app.awacs_client.iter_all_balancers(skip_incomplete=True,
                                                           yp_lite_only=True,
                                                           full_balancer_id_in=unprocessed_full_balancer_ids):
        namespace_id = balancer_pb.meta.namespace_id
        balancer_id = balancer_pb.meta.id

        b_rps = rps_data.get((namespace_id, balancer_id), -1)
        rps = str(int(b_rps)) if b_rps != -1 else 'UNKNOWN'
        c.secho('Looking at {}:{}, {} RPS...'.format(namespace_id, balancer_id, rps), fg='blue')
        if not just_confirmed and i % confirm_every == 0:
            if not wait_for_confirmation(
                    'Going to update {}:{}'.format(namespace_id, balancer_id),
                    confirm_automatically_after=confirm_automatically_after):
                c.echo('Skipped {}:{}...'.format(namespace_id, balancer_id))
                continue
            just_confirmed = True

        try:
            updated, processed = process_balancer(app, balancer_pb,
                                                  continue_with_incomplete_pods=continue_with_incomplete_pods)
        except Exception:
            c.secho('Failed to process {}:{}'.format(namespace_id, balancer_id), fg='red')
            raise

        if updated:
            just_confirmed = False
            i += 1
        if processed:
            # Persist after every balancer so an interrupted run resumes where it stopped.
            processed_full_balancer_ids.add((namespace_id, balancer_id))
            save_processed_balancers(app.op, processed_full_balancer_ids)


if __name__ == '__main__':
    # Script entry point: dispatch to the ``cli`` command group that the
    # @cli.command-decorated functions in this module are registered on.
    cli()
