import click
import prettytable
import yt.yson as yson
import yt_yson_bindings
import yp.data_model as data_model
from yp_proto.yp.client.api.proto import deploy_pb2
from infra.dctl.src import consts
from infra.dctl.src.lib import helpers
from infra.dctl.src.lib import yputil

# Stage label read by update_sidecars(): per deploy unit ID, a mapping of
# sidecar name -> sandbox revision to apply (presumably written by an external
# updater; verify against the producer). update_sidecars() resets it to ''.
TARGET_SIDECAR_REVISION_LABEL = 'du_sidecar_target_revision'
# Stage label read by update_sidecars(): per deploy unit ID, the target
# patchers_revision. update_sidecars() resets it to ''.
TARGET_RUNTIME_REVISION_LABEL = 'du_patchers_target_revision'


def pop_notifications_state(stage_dict):
    """Remove the 'notifications_last_state' annotation from a stage dict in place.

    Does nothing when the dict has no 'annotations' key or when the
    annotations value is empty or null (e.g. ``annotations:`` with no value
    in YAML). Previously both cases raised.

    :type stage_dict: dict
    """
    annotations = stage_dict.get('annotations')
    if annotations:
        annotations.pop('notifications_last_state', None)


def get(stage_id, client):
    """Fetch a stage from YP and clear its /spec/account_id.

    Presumably /spec/account_id is cleared so that /meta/account_id is the
    only account source; validate() rejects a spec account that differs from
    the meta one.

    :type stage_id: str
    :type client: infra.dctl.src.lib.yp_client.YpClient
    :rtype: yp.data_model.TStage
    """
    fetched_stage = client.get(object_type=data_model.OT_STAGE, object_id=stage_id)
    fetched_stage.spec.ClearField("account_id")
    return fetched_stage


def validate(stage):
    """Run local sanity checks on a stage before sending it to YP.

    Raises click.ClickException when:
      * /meta/project_id is empty;
      * the stage ID is longer than consts.MAX_STAGE_ID_LENGTH;
      * /spec/account_id is set and differs from /meta/account_id;
      * a deploy unit ID is longer than consts.MAX_DEPLOY_UNIT_ID_LENGTH.
    A deploy unit whose patchers_revision is 0 only produces a warning on
    stderr; it does not fail validation.

    :type stage: yp.data_model.TStage
    """
    meta = stage.meta
    spec = stage.spec

    if not meta.project_id:
        raise click.ClickException(
            'Stage must have /meta/project_id field, please fill it or get your existing stage from YP')

    stage_id_length = len(meta.id)
    if stage_id_length > consts.MAX_STAGE_ID_LENGTH:
        raise click.ClickException(
            'Stage ID "{}" is too long: {} chars, max: {} chars'.format(
                meta.id, stage_id_length, consts.MAX_STAGE_ID_LENGTH))

    # An empty /spec/account_id is allowed; only a conflicting one is an error.
    if spec.account_id and spec.account_id != meta.account_id:
        raise click.ClickException(
            '/meta/account_id and /spec/account_id are not equal, please fill them with same value or use /meta/account_id instead of /spec/account_id')

    for du_id, du_spec in spec.deploy_units.items():
        du_id_length = len(du_id)
        if du_id_length > consts.MAX_DEPLOY_UNIT_ID_LENGTH:
            raise click.ClickException(
                'Deploy unit ID "{}" is too long: {} chars, max: {} chars'.format(
                    du_id, du_id_length, consts.MAX_DEPLOY_UNIT_ID_LENGTH))
        if du_spec.patchers_revision == 0:
            # Warning only, validation continues.
            click.echo('Deploy unit "{}" must have patchers_revision field, please fill it'.format(du_id), err=True)


def put(stage, cluster, rewrite_delegation_tokens,
        vault_client, vault_client_rsa_fallback, client):
    """Create the stage in YP, or update it if a stage with that ID exists.

    The passed ``stage`` is modified in place before writing:
      * the 'deploy_engine' label defaults to 'env_controller' when absent;
      * spec.revision is reset to 1;
      * helpers.patch_stage_spec_before_put() adjusts the spec, using the
        vault clients, the cluster and the rewrite_delegation_tokens flag.
    When the stage already exists, it is written with
    update_revision_increment(copy_notifications_state=True). Any changes to
    /meta/account_id or /meta/project_id are passed explicitly as
    path-specific values.

    :type stage: yp.data_model.TStage
    :type cluster: str
    :type rewrite_delegation_tokens: bool
    :type vault_client: library.python.vault_client.VaultClient
    :type vault_client_rsa_fallback: library.python.vault_client.VaultClient
    :type client: infra.dctl.src.lib.yp_client.YpClient
    :return: the stage ID (stage.meta.id after the write).
    :rtype: str
    """
    if yputil.get_label(stage.labels, 'deploy_engine') is None:
        yputil.set_label(stage.labels, 'deploy_engine', 'env_controller')

    stage.spec.revision = 1
    helpers.patch_stage_spec_before_put(
        vault_client=vault_client,
        vault_client_rsa_fallback=vault_client_rsa_fallback,
        stage_spec=stage.spec,
        stage_id=stage.meta.id,
        cluster=cluster,
        rewrite_delegation_tokens=rewrite_delegation_tokens,
    )

    existing = client.get(object_type=data_model.OT_STAGE,
                          object_id=stage.meta.id,
                          ignore_nonexistent=True)
    if existing is None:
        client.create(object_type=data_model.OT_STAGE, obj=stage,
                      create_with_acl=False, add_default_user_acl=False)
        return stage.meta.id

    # Meta fields whose value changed and must be set explicitly on update.
    changed_meta = {}
    for path, field_name in (('/meta/account_id', 'account_id'),
                             ('/meta/project_id', 'project_id')):
        desired = getattr(stage.meta, field_name)
        if getattr(existing.meta, field_name) != desired:
            changed_meta[path] = desired

    client.update_revision_increment(object_type=data_model.OT_STAGE,
                                     object_id=stage.meta.id,
                                     obj=stage,
                                     path_to_value_specific_fields=changed_meta,
                                     copy_notifications_state=True)
    return stage.meta.id


def remove(stage_id, client):
    """Delete the stage ``stage_id`` from YP.

    :type stage_id: str
    :type client: infra.dctl.src.lib.yp_client.YpClient
    :return: whatever ``client.remove`` returns.
    """
    removal_result = client.remove(object_type=data_model.OT_STAGE, object_id=stage_id)
    return removal_result


def get_deploy_units_total(stage):
    """Count the stage's deploy units by pod deploy primitive.

    :type stage: yp.data_model.TStage
    :return: (replica_set units, multi_cluster_replica_set units). Units with
        any other primitive, or none set, are not counted.
    :rtype: (int, int)
    """
    primitives = [unit.WhichOneof('pod_deploy_primitive')
                  for unit in stage.spec.deploy_units.values()]
    return primitives.count('replica_set'), primitives.count('multi_cluster_replica_set')


def aggregate_progress(stage):
    """Sum pod counters over every deploy unit status of the stage.

    :type stage: yp.data_model.TStage
    :return: a new TDeployProgress whose pods_ready, pods_in_progress and
        pods_total are the totals across stage.status.deploy_units.
    :rtype: deploy_pb2.TDeployProgress
    """
    total = deploy_pb2.TDeployProgress()
    for unit_status in stage.status.deploy_units.values():
        unit_progress = unit_status.progress
        total.pods_ready += unit_progress.pods_ready
        total.pods_in_progress += unit_progress.pods_in_progress
        total.pods_total += unit_progress.pods_total
    return total


def stringify_progress(p):
    """Render a deploy progress as "ready/in_progress/total".

    The ready count is styled green and the in-progress count blue with
    click.style; the total is plain text.

    :type p: deploy_pb2.TDeployProgress
    :rtype: str
    """
    ready = click.style(str(p.pods_ready), fg='green')
    in_progress = click.style(str(p.pods_in_progress), fg='blue')
    return '/'.join((ready, in_progress, str(p.pods_total)))


def list_objects(client, cluster, user, limit, project):
    """Build a summary table of the user's stages on one cluster.

    Each row holds: cluster, stage ID, spec revision, the number of
    replica_set and multi_cluster_replica_set deploy units, and the
    aggregated pod progress.

    :type client: infra.dctl.src.lib.yp_client.YpClient
    :type cluster: str
    :type user: str
    :type limit: int
    :type project: str | None
    :param project: when truthy, only stages whose /meta/project_id equals it
        are listed.
    :rtype: prettytable.PrettyTable
    """
    table = prettytable.PrettyTable(['Cluster', 'ID', 'SpecRev', 'RS', 'MCRS', 'Pods'])

    query = None
    if project:
        # NOTE(review): project is interpolated without escaping; a value
        # containing '"' would produce a malformed query.
        query = '[/meta/project_id]="{}"'.format(project)

    stages = client.list(object_type=data_model.OT_STAGE, user=user, limit=limit, query=query)
    for stage in stages:
        rs_count, mcrs_count = get_deploy_units_total(stage)
        table.add_row([cluster, stage.meta.id, stage.spec.revision,
                       rs_count, mcrs_count,
                       stringify_progress(aggregate_progress(stage))])
    return table


def get_cluster_statuses(s):
    """Summarize where a deploy unit is placed, for status tables.

    Cluster statuses come from ``s.multi_cluster_replica_set`` when the
    ``details`` oneof is 'multi_cluster_replica_set', and from
    ``s.replica_set`` otherwise. Only the first cluster status (in map
    iteration order) is read for the replica set / endpoint set IDs, on the
    assumption that those IDs are equal in every cluster.

    :type s: yp_proto.yp.client.api.proto.stage_pb2.TDeployUnitStatus
    :return: (clusters, rs_id, mcrs_id, es_id).
        ``clusters`` is a comma-separated, sorted string of cluster names.
        ``clusters``, ``rs_id`` and ``es_id`` are None when there are no
        cluster statuses. ``rs_id`` is filled only for 'replica_set' units;
        ``mcrs_id`` only for 'multi_cluster_replica_set' units.
    :rtype: (str | None, str | None, str | None, str | None)
    """
    details_kind = s.WhichOneof('details')

    if details_kind == 'multi_cluster_replica_set':
        mcrs_id = s.multi_cluster_replica_set.replica_set_id
        statuses = s.multi_cluster_replica_set.cluster_statuses
    else:
        mcrs_id = None
        statuses = s.replica_set.cluster_statuses

    if not statuses:
        return None, None, mcrs_id, None

    first_status = next(iter(statuses.values()))
    rs_id = first_status.replica_set_id if details_kind == 'replica_set' else None
    return ','.join(sorted(statuses)), rs_id, mcrs_id, first_status.endpoint_set_id


def get_status(stage_id, client):
    """Build a per-deploy-unit status table for a stage.

    DeployStatus is chosen as follows:
      * 'Outdated' when the spec revision differs from the status revision;
      * otherwise 'InProgress' when the unit's in_progress condition is true;
      * otherwise 'Ready' when its ready condition is true;
      * otherwise '-'.
    A stage with no deploy unit statuses produces a single row of '-'.
    Missing cluster/replica set/endpoint set IDs are also shown as '-'.

    :type stage_id: str
    :type client: infra.dctl.src.lib.yp_client.YpClient
    :rtype: prettytable.PrettyTable
    """
    stage = get(stage_id, client)
    rv = prettytable.PrettyTable(['StageID', 'SpecRev', 'DeployUnitID',
                                  'DeployStatus', 'Pods', 'Clusters',
                                  'RS', 'MCRS', 'ES', 'StatusRev'])
    spec_rev = stage.spec.revision
    status_rev = stage.status.revision
    if not stage.status.deploy_units:
        # 10 columns: StageID, SpecRev, 7 placeholders, StatusRev.
        rv.add_row([stage.meta.id, spec_rev] + ['-'] * 7 + [status_rev])
        return rv

    for unit_id, unit_status in sorted(stage.status.deploy_units.items()):
        if spec_rev != status_rev:
            deploy_status = 'Outdated'
        elif unit_status.in_progress.status == data_model.CS_TRUE:
            deploy_status = 'InProgress'
        # `ready` is a condition message like `in_progress`. A bare truthiness
        # test on a message is always true, so compare its status explicitly.
        elif unit_status.ready.status == data_model.CS_TRUE:
            deploy_status = 'Ready'
        else:
            deploy_status = '-'

        row = [stage.meta.id, spec_rev, unit_id, deploy_status,
               stringify_progress(unit_status.progress)]
        # get_cluster_statuses() returns None for absent IDs; render those
        # as '-' like the rest of the table instead of the text "None".
        row.extend('-' if value is None else value
                   for value in get_cluster_statuses(unit_status))
        row.append(status_rev)
        rv.add_row(row)
    return rv


def copy_stage(stage, new_stage_id,
               vault_client, vault_client_rsa_fallback,
               cluster, client):
    """Write ``stage`` to YP under a new ID, rewriting delegation tokens.

    The passed ``stage`` object is modified in place. Its meta.id becomes
    ``new_stage_id``, and put() further changes labels and spec before
    writing. If a stage with ``new_stage_id`` already exists, put() updates
    it rather than failing. Returns None; the ID returned by put() is
    discarded.

    :type stage: yp.data_model.TStage
    :type new_stage_id: str
    :type vault_client: library.python.vault_client.VaultClient
    :type vault_client_rsa_fallback: library.python.vault_client.VaultClient
    :type cluster: str
    :type client: infra.dctl.src.lib.yp_client.YpClient
    """
    stage.meta.id = new_stage_id
    put_arguments = dict(
        stage=stage,
        cluster=cluster,
        rewrite_delegation_tokens=True,
        vault_client=vault_client,
        vault_client_rsa_fallback=vault_client_rsa_fallback,
        client=client,
    )
    put(**put_arguments)


def cast_yaml_dict_to_yp_object(d):
    """Convert a plain stage dict (e.g. loaded from YAML) into a TStage proto.

    The dict is serialized to YSON and parsed into data_model.TStage. Unknown
    fields are not skipped (skip_unknown_fields=False), so an unexpected key
    is reported by the parser instead of being silently dropped.

    :type d: dict
    :rtype: yp.data_model.TStage
    """
    yson_payload = yson.dumps(d)
    return yt_yson_bindings.loads_proto(
        yson_payload,
        proto_class=data_model.TStage,
        skip_unknown_fields=False,
    )


def override_deployment_strategy(client, stage_id, revision, du_id, new_max_unavailable_value,
                                 clusters_to_override):
    """Send an 'override_deployment_strategy' control to a stage.

    Fills TOverrideDeploymentStrategy.max_unavailable with the deploy unit
    ID, revision, new value and the list of clusters to apply it to.

    :type client: infra.dctl.src.lib.yp_client.YpClient
    :type stage_id: str
    :param revision: value written to max_unavailable.revision.
    :type du_id: str
    :param new_max_unavailable_value: value written to max_unavailable.value.
    :param clusters_to_override: iterable of cluster names.
    """
    request = data_model.TStageControl.TOverrideDeploymentStrategy()
    max_unavailable = request.max_unavailable
    max_unavailable.deploy_unit_id = du_id
    max_unavailable.revision = revision
    max_unavailable.value = new_max_unavailable_value
    max_unavailable.clusters.extend(clusters_to_override)
    client.control(data_model.OT_STAGE, stage_id, "override_deployment_strategy", request)


def approve_location(client, stage_id, du_revision, du_id, clusters):
    """Send an "approve" control for each cluster of the deploy unit revision."""
    approve_disapprove_location(client, stage_id, du_revision, du_id, clusters,
                                control="approve")


def disapprove_location(client, stage_id, du_revision, du_id, clusters):
    """Send a "disapprove" control for each cluster of the deploy unit revision."""
    approve_disapprove_location(client, stage_id, du_revision, du_id, clusters,
                                control="disapprove")


def approve_disapprove_location(client, stage_id, du_revision, du_id, clusters, control):
    """Issue one stage control call per cluster with a TApproveAction payload.

    :type client: infra.dctl.src.lib.yp_client.YpClient
    :type stage_id: str
    :param du_revision: deploy unit revision written to options.revision.
    :type du_id: str
    :param clusters: iterable of cluster names; one control call is made for each.
    :param control: control name passed to client.control (callers in this
        module use "approve" or "disapprove").
    """
    for cluster_name in clusters:
        # A fresh payload for every call.
        action = data_model.TStageControl.TApproveAction()
        action_options = action.options
        action_options.cluster = cluster_name
        action_options.revision = du_revision
        action_options.deploy_unit = du_id
        client.control(data_model.OT_STAGE, stage_id, control, action)


def update_sidecars(client, stage_id):
    """Apply target sidecar/runtime revisions stored in the stage labels.

    Reads TARGET_SIDECAR_REVISION_LABEL and TARGET_RUNTIME_REVISION_LABEL
    ('' when absent). If both are empty, returns without writing anything.
    Otherwise the function:
      1. resets both labels to '';
      2. for every deploy unit listed in the runtime label, sets
         patchers_revision;
      3. for every deploy unit listed in the sidecar label, sets the
         revision of each known sidecar's sandbox info;
      4. writes the stage back with update_revision_increment(
         assert_same_revision=True).
    Deploy units absent from a label are left untouched. Unknown sidecar
    names are reported on stderr and skipped.

    :type client: infra.dctl.src.lib.yp_client.YpClient
    :type stage_id: str
    """
    stage = get(stage_id, client)
    sidecars_rev = yputil.get_label(stage.labels, TARGET_SIDECAR_REVISION_LABEL, '')
    runtime_rev = yputil.get_label(stage.labels, TARGET_RUNTIME_REVISION_LABEL, '')
    if sidecars_rev == '' and runtime_rev == '':
        return

    # Sidecar name used in the label -> deploy unit spec field with its sandbox info.
    sidecar_info_fields = {
        'druLayer': 'dynamic_resource_updater_sandbox_info',
        'tvm': 'tvm_sandbox_info',
        'podBin': 'pod_agent_sandbox_info',
        'logbrokerToolsLayer': 'logbroker_tools_sandbox_info',
    }

    yputil.set_label(stage.labels, TARGET_SIDECAR_REVISION_LABEL, '')
    yputil.set_label(stage.labels, TARGET_RUNTIME_REVISION_LABEL, '')
    for du_id, du in stage.spec.deploy_units.items():
        # .get(): a label may cover only some deploy units; indexing with
        # [du_id] would raise KeyError for the others.
        rev = runtime_rev.get(du_id) if runtime_rev != '' else None
        if rev is not None:
            click.echo('Deploy unit "{}", runtime revision will be updated to {}'.format(du_id, rev))
            du.patchers_revision = rev

        req_labels = sidecars_rev.get(du_id) if sidecars_rev != '' else None
        if req_labels is None:
            continue
        click.echo('Deploy unit "{}", labels to update: {}'.format(du_id, req_labels))

        for sidecar, sidecar_rev in req_labels.items():
            field_name = sidecar_info_fields.get(sidecar)
            if field_name is None:
                click.echo('Deploy unit "{}": unknown sidecar "{}" skipped'.format(du_id, sidecar), err=True)
                continue
            getattr(du, field_name).revision = sidecar_rev

    client.update_revision_increment(object_type=data_model.OT_STAGE,
                                     object_id=stage.meta.id,
                                     obj=stage,
                                     path_to_value_specific_fields={},
                                     assert_same_revision=True)

    click.echo("Infra components successfully updated")
