import argparse
import os
import time

import yp.client
import yp.data_model as data_model
import yt.yson as yson
from yp_proto.yp.client.api.proto import object_service_pb2

import infra.dctl.src.consts as consts

# Name of the environment variable that is checked first for a YP token.
YP_TOKEN_ENV = "YP_TOKEN"
# Fallback token file, read when the environment variable is unset or empty.
YP_TOKEN_FILE = os.path.expanduser("~/.yp/token")

# Default value of --select-limit: the limit put on the stage select request.
SELECT_LIMIT = 10000
# Up to this many selected stages a plain "y" confirms the update; for more
# stages the operator has to type a longer confirmation phrase.
WARNING_UPDATE_LIMIT = 15


def get_cluster_list():
    """Return a list with the keys (cluster names) of ``consts.CLUSTER_CONFIGS``."""
    return list(consts.CLUSTER_CONFIGS.keys())


def get_token(token_env, token_path):
    """Look up the YP token, preferring the environment over the token file.

    :param token_env: name of the environment variable to check first.
    :param token_path: path of the file to read when the variable is unset
        or empty.
    :return: the token; when read from the file, surrounding whitespace is
        stripped.
    :raises Exception: if neither source provides a token.
    """
    env_value = os.environ.get(token_env)
    if env_value:
        print("Use yp token from env {}".format(token_env))
        return env_value

    if not os.path.isfile(token_path):
        raise Exception("No yp token provided")

    print("Use yp token from file {}".format(token_path))
    with open(token_path, 'r') as token_file:
        contents = token_file.read()
    return contents.strip()


def list_stages(yp_client_stub, stage_filter, select_limit):
    """Select stage ids from YP.

    :param yp_client_stub: object service stub used to send SelectObjects.
    :param stage_filter: filter query for the select, or None for no filter.
    :param select_limit: limit put on the select request.
    :return: list of parsed "/meta/id" values; rows whose id cannot be
        parsed are reported and skipped.
    """
    request = object_service_pb2.TReqSelectObjects()
    request.object_type = data_model.OT_STAGE
    request.limit.value = select_limit
    request.selector.paths.append("/meta/id")
    if stage_filter is not None:
        request.filter.query = stage_filter

    response = yp_client_stub.SelectObjects(request)

    stage_ids = []
    for row in response.results:
        try:
            parsed_id = yson.loads(row.values[0])
        except Exception:
            print("Error while parsing object id in select '{}'".format(str(row)))
        else:
            stage_ids.append(parsed_id)

    # Hitting the limit exactly means the selection may have been truncated.
    if select_limit == len(response.results):
        print("WARNING You have selected as many objects as specified in the query limit")

    return stage_ids


def start_transaction(yp_client_stub):
    """Start a YP transaction and return ``(transaction_id, start_timestamp)``."""
    response = yp_client_stub.StartTransaction(object_service_pb2.TReqStartTransaction())
    return response.transaction_id, response.start_timestamp


def commit_transaction(yp_client_stub, transaction_id):
    """Send a CommitTransaction request for the given transaction id."""
    commit_request = object_service_pb2.TReqCommitTransaction(transaction_id=transaction_id)
    yp_client_stub.CommitTransaction(commit_request)


def get_stage_revision(yp_client_stub, transaction_timestamp, stage):
    """Read "/spec/revision" of a stage at the given timestamp.

    :param yp_client_stub: object service stub used to send GetObject.
    :param transaction_timestamp: timestamp the read is performed at.
    :param stage: id of the stage.
    :return: the revision value parsed with ``yson.loads``.
    """
    request = object_service_pb2.TReqGetObject()
    request.timestamp = transaction_timestamp
    request.object_type = data_model.OT_STAGE
    request.object_id = stage
    request.selector.paths.append("/spec/revision")

    response = yp_client_stub.GetObject(request)
    raw_revision = response.result.values[0]
    return yson.loads(raw_revision)


def fix_pod_template_spec_mutable_workloads(yp_client_stub, transaction_timestamp, spec, path_prefix, req):
    """Make sure every workload of a pod template has a mutable workload entry.

    Workload ids not yet referenced from ``mutable_workloads`` are appended
    to that list, and a set-update with the whole list is added to ``req``.
    Missing references are appended in the order the workloads are declared,
    so the written list is deterministic.

    :param yp_client_stub: unused, kept for interface compatibility.
    :param transaction_timestamp: unused, kept for interface compatibility.
    :param spec: pod template spec dict; the pod agent spec is read from
        ``spec/pod_agent_payload/spec`` inside it. Its ``mutable_workloads``
        list, when present, is extended in place.
    :param path_prefix: object path of the pod template spec, used as the
        prefix of the update path.
    :param req: update request that receives the set-update.
    :return: True if an update was added to ``req``, False otherwise.
    """
    pod_agent_spec = spec.get("spec", dict()).get("pod_agent_payload", dict()).get("spec", dict())

    workloads = pod_agent_spec.get("workloads", [])
    mutable_workloads = pod_agent_spec.get("mutable_workloads", [])

    known_refs = set(mutable_workload["workload_ref"] for mutable_workload in mutable_workloads)

    # Collect missing ids in declaration order; adding each one to known_refs
    # keeps a repeated workload id from being appended twice.
    missing_refs = []
    for workload in workloads:
        workload_id = workload["id"]
        if workload_id not in known_refs:
            known_refs.add(workload_id)
            missing_refs.append(workload_id)

    if not missing_refs:
        return False

    for workload_id in missing_refs:
        mutable_workloads.append({"workload_ref": workload_id})

    upd = req.set_updates.add()
    upd.path = "{}/{}".format(path_prefix, "spec/pod_agent_payload/spec/mutable_workloads")
    upd.value = yson.dumps(mutable_workloads)

    return True


def get_stage_deploy_units(yp_client_stub, transaction_timestamp, stage):
    """Read "/spec/deploy_units" of a stage.

    :param yp_client_stub: object service stub used to send GetObject.
    :param transaction_timestamp: timestamp to read at; when falsy the
        request timestamp is left unset.
    :param stage: id of the stage.
    :return: the deploy units value parsed with ``yson.loads``.
    """
    request = object_service_pb2.TReqGetObject()
    request.object_type = data_model.OT_STAGE
    request.object_id = stage
    request.selector.paths.append("/spec/deploy_units")
    if transaction_timestamp:
        request.timestamp = transaction_timestamp

    response = yp_client_stub.GetObject(request)
    raw_deploy_units = response.result.values[0]
    return yson.loads(raw_deploy_units)


def stage_has_logs(yp_client_stub, stage):
    """Tell whether any workload of the stage has ``transmit_logs`` enabled.

    The deploy units are read without a timestamp. Deploy units with neither
    "replica_set" nor "multi_cluster_replica_set" are reported and treated
    as having no workloads.

    :param yp_client_stub: object service stub used to read the stage.
    :param stage: id of the stage.
    :return: True if at least one workload has a truthy ``transmit_logs``;
        False otherwise, including when the deploy units are not a dict.
    """
    deploy_units = get_stage_deploy_units(yp_client_stub, None, stage)

    if not isinstance(deploy_units, dict):
        print("Warning! stage '{}' does not have deploy units".format(stage))
        return False

    for deploy_unit_id, deploy_unit_spec in deploy_units.items():
        pod_spec = {}
        if "replica_set" in deploy_unit_spec:
            # Default to an empty dict so a replica set without a template is
            # treated as "no workloads" instead of failing on None.
            replica_set_template = deploy_unit_spec["replica_set"].get("replica_set_template", dict())
            pod_spec = replica_set_template.get("pod_template_spec", dict())
        elif "multi_cluster_replica_set" in deploy_unit_spec:
            pod_spec = deploy_unit_spec["multi_cluster_replica_set"].get("replica_set", dict()).get("pod_template_spec", dict())
        else:
            print("Warning! No pod_deploy_primitive in deploy unit '{}' spec for stage '{}'".format(deploy_unit_id, stage))

        workloads = pod_spec.get("spec", dict()).get("pod_agent_payload", dict()).get("spec", dict()).get("workloads", [])
        for workload in workloads:
            if workload.get("transmit_logs", False):
                return True

    return False


def filter_stages_with_logs(yp_client_stub, stages):
    """Return the stages, in input order, for which ``stage_has_logs`` is true."""
    return [stage_id for stage_id in stages if stage_has_logs(yp_client_stub, stage_id)]


def fix_stage_mutable_workloads(yp_client_stub, transaction_id, transaction_timestamp, stage):
    """Add missing mutable workload entries to every deploy unit of a stage.

    The deploy units are read at ``transaction_timestamp``. The required
    set-updates are collected into one update request, which is sent within
    the transaction only when at least one deploy unit needs a fix.

    :param yp_client_stub: object service stub used to read and update.
    :param transaction_id: id of the transaction the update is sent in.
    :param transaction_timestamp: timestamp the deploy units are read at.
    :param stage: id of the stage.
    """
    deploy_units = get_stage_deploy_units(yp_client_stub, transaction_timestamp, stage)

    req = object_service_pb2.TReqUpdateObject()
    req.object_type = data_model.OT_STAGE
    req.object_id = stage

    req.transaction_id = transaction_id

    need_fix = False
    for deploy_unit_id, deploy_unit_spec in deploy_units.items():
        if "replica_set" in deploy_unit_spec:
            # Default to an empty dict so a replica set without a template is
            # treated as "nothing to fix" instead of failing on None.
            replica_set_template = deploy_unit_spec["replica_set"].get("replica_set_template", dict())
            need_fix |= fix_pod_template_spec_mutable_workloads(
                yp_client_stub,
                transaction_timestamp,
                replica_set_template.get("pod_template_spec", dict()),
                "/spec/deploy_units/{}/replica_set/replica_set_template/pod_template_spec".format(deploy_unit_id),
                req
            )
        elif "multi_cluster_replica_set" in deploy_unit_spec:
            need_fix |= fix_pod_template_spec_mutable_workloads(
                yp_client_stub,
                transaction_timestamp,
                deploy_unit_spec["multi_cluster_replica_set"].get("replica_set", dict()).get("pod_template_spec", dict()),
                "/spec/deploy_units/{}/multi_cluster_replica_set/replica_set/pod_template_spec".format(deploy_unit_id),
                req
            )
        else:
            print("Warning! No pod_deploy_primitive in deploy unit '{}' spec".format(deploy_unit_id))

    if need_fix:
        print("Stage need fix")
        yp_client_stub.UpdateObject(req)
        print("Stage fixed")
    else:
        print("Stage correct")


def update_stage_revision(yp_client_stub, transaction_id, stage, revision, revision_description):
    """Write a new revision and its description to a stage.

    Sets "/spec/revision" to ``revision`` and "/spec/revision_info" to
    ``{"description": revision_description}`` in one update request sent
    within the given transaction.
    """
    request = object_service_pb2.TReqUpdateObject()
    request.object_type = data_model.OT_STAGE
    request.object_id = stage
    request.transaction_id = transaction_id

    new_values = (
        ("/spec/revision", revision),
        ("/spec/revision_info", {"description": revision_description}),
    )
    for update_path, update_value in new_values:
        update = request.set_updates.add()
        update.path = update_path
        update.value = yson.dumps(update_value)

    yp_client_stub.UpdateObject(request)


def reset_number_of_pods(yp_client_stub, transaction_id, transaction_timestamp, stage):
    """Set the pod count of every deploy unit of a stage to zero.

    The deploy units are read at ``transaction_timestamp``; the set-updates
    produced by ``fix_rsc_pod_count`` / ``fix_mcrsc_pod_count`` are gathered
    into one update request, which is always sent within the transaction.
    """
    units = get_stage_deploy_units(yp_client_stub, transaction_timestamp, stage)

    request = object_service_pb2.TReqUpdateObject()
    request.object_type = data_model.OT_STAGE
    request.object_id = stage
    request.transaction_id = transaction_id

    for unit_id, unit_spec in units.items():
        if "replica_set" in unit_spec:
            per_cluster_settings = unit_spec["replica_set"].get("per_cluster_settings", dict())
            settings_path = "/spec/deploy_units/{}/replica_set/per_cluster_settings".format(unit_id)
            fix_rsc_pod_count(per_cluster_settings, settings_path, request)
            continue

        if "multi_cluster_replica_set" in unit_spec:
            clusters = unit_spec["multi_cluster_replica_set"].get("replica_set", dict()).get("clusters", [])
            clusters_path = "/spec/deploy_units/{}/multi_cluster_replica_set/replica_set/clusters".format(unit_id)
            fix_mcrsc_pod_count(clusters, clusters_path, request)
            continue

        print("Warning! No pod_deploy_primitive in deploy unit '{}' spec".format(unit_id))

    yp_client_stub.UpdateObject(request)


def fix_rsc_pod_count(spec, path_prefix, req):
    """Add set-updates that zero ``pod_count`` for every cluster of a replica set.

    :param spec: per-cluster settings mapping, keyed by cluster name.
    :param path_prefix: object path of the per-cluster settings.
    :param req: update request that receives one set-update per cluster.
    """
    # items() instead of the Python-2-only iteritems(), matching the rest of
    # the file.
    for dc, value in spec.items():
        print("DC: {}, pod_count: {}".format(dc, value.get("pod_count", 0)))
        upd = req.set_updates.add()
        upd.path = "{}/{}/pod_count".format(path_prefix, dc)
        upd.value = yson.dumps(0)


def fix_mcrsc_pod_count(spec, path_prefix, req):
    """Zero ``replica_count`` of every cluster of a multi-cluster replica set.

    The entries of ``spec`` are modified in place, and the whole list is then
    written back with a single set-update at ``path_prefix``.
    """
    for cluster_entry in spec:
        cluster_name = cluster_entry['cluster']
        cluster_spec = cluster_entry['spec']
        print("DC: {}, pod_count: {}".format(cluster_name, cluster_spec.get('replica_count', 0)))
        cluster_spec['replica_count'] = 0

    clusters_update = req.set_updates.add()
    clusters_update.path = path_prefix
    clusters_update.value = yson.dumps(spec)


def remove_deploy_acl(yp_client_stub, transaction_timestamp, transaction_id, stage):
    """Strip subjects starting with "deploy:" from the ACL of a stage.

    "/meta/acl" is read at ``transaction_timestamp``. If any such subject is
    found, the ACL is rewritten within the transaction without those
    subjects, and entries left with no subjects are dropped. Otherwise
    nothing is updated.
    """
    # Read the current ACL.
    get_request = object_service_pb2.TReqGetObject()
    get_request.object_type = data_model.OT_STAGE
    get_request.object_id = stage
    get_request.selector.paths.append("/meta/acl")
    get_request.timestamp = transaction_timestamp

    get_response = yp_client_stub.GetObject(get_request)

    acls = yson.loads(get_response.result.values[0])
    print("acls:", acls)

    new_acls = []
    need_update = False

    for acl in acls:
        original_subjects = acl["subjects"]
        kept_subjects = [subject for subject in original_subjects if not subject.startswith("deploy:")]
        # A shorter list means at least one deploy subject was filtered out.
        if len(kept_subjects) != len(original_subjects):
            need_update = True
        acl["subjects"] = kept_subjects

        if kept_subjects:
            new_acls.append(acl)

    if not need_update:
        print("There are no deploy ACLs")
        return

    print("new_acls:", new_acls)

    # Write the filtered ACL back.
    update_request = object_service_pb2.TReqUpdateObject()
    update_request.object_type = data_model.OT_STAGE
    update_request.object_id = stage
    update_request.transaction_id = transaction_id

    acl_update = update_request.set_updates.add()
    acl_update.path = "/meta/acl"
    acl_update.value = yson.dumps(new_acls)

    yp_client_stub.UpdateObject(update_request)


def drop_stage_acl(yp_client_stub, stage):
    """Remove deploy subjects from the ACL of a stage in its own transaction."""
    print("Start transaction")
    tx_id, tx_timestamp = start_transaction(yp_client_stub)
    print("Started")

    remove_deploy_acl(
        yp_client_stub,
        transaction_timestamp=tx_timestamp,
        transaction_id=tx_id,
        stage=stage
    )

    print("Commit")
    commit_transaction(yp_client_stub, tx_id)
    print("Success commit")


def update_stage(yp_client_stub, stage, drop_pods, revision_description):
    """Update one stage inside a single transaction and bump its revision.

    Steps: fix mutable workloads, optionally reset pod counts to zero, then
    write revision + 1 together with ``revision_description`` and commit.

    Other modifications may be added here, but the update MUST bump the
    revision to trigger the stage_ctl update.
    """
    print("Start transaction")
    tx_id, tx_timestamp = start_transaction(yp_client_stub)
    print("Started")

    print("Fix stage mutable workloads")
    fix_stage_mutable_workloads(yp_client_stub, tx_id, tx_timestamp, stage)
    print("Mutable workloads fixed")

    if drop_pods:
        print("Drop pods")
        reset_number_of_pods(yp_client_stub, tx_id, tx_timestamp, stage)
        print("Pods dropped")

    print("Get revision")
    current_revision = get_stage_revision(yp_client_stub, tx_timestamp, stage)
    print("Current revision {}".format(current_revision))

    next_revision = current_revision + 1
    print("Update to revision {}".format(next_revision))
    update_stage_revision(yp_client_stub, tx_id, stage, next_revision, revision_description)
    print("Revision updated")

    print("Commit")
    commit_transaction(yp_client_stub, tx_id)
    print("Success commit")


def update_cluster(
    cluster,
    stage_filter,
    update_only_stages_with_logs,
    select_limit,
    update_interval_seconds,
    dry_run,
    drop_pods,
    token,
    update_window,
    revision_description,
    drop_acls
):
    """Interactively update the selected stages of one cluster.

    Selects stages, asks the operator for confirmation, then processes the
    stages one by one, sleeping ``update_interval_seconds`` after each.
    Per-stage failures are collected and reported at the end instead of
    stopping the run.

    :param cluster: key in ``consts.CLUSTER_CONFIGS``.
    :param stage_filter: filter query for the stage select, or None.
    :param update_only_stages_with_logs: additionally keep only stages for
        which ``stage_has_logs`` is true.
    :param select_limit: limit of the stage select request.
    :param update_interval_seconds: pause after every processed stage.
    :param dry_run: only print what would be updated.
    :param drop_pods: reset pod counts to zero as part of the stage update.
    :param token: YP token; when falsy the token is taken from the
        environment or the token file.
    :param update_window: number of stages processed between "Continue?"
        prompts; a falsy value disables the prompts.
    :param revision_description: description stored with the new revision.
    :param drop_acls: remove deploy ACLs instead of updating the stage.
    """
    cluster_config = consts.CLUSTER_CONFIGS[cluster]

    print("-" * 50)
    print("Update cluster {}".format(cluster))

    yp_token = get_token(YP_TOKEN_ENV, YP_TOKEN_FILE) if not token else token

    yp_client = yp.client.YpClient(
        address=cluster_config.address,
        config={
            'token': yp_token,
        }
    )
    yp_client_stub = yp_client.create_grpc_object_stub()

    stages = list_stages(yp_client_stub, stage_filter, select_limit)

    # Separate from the main select request
    # because sometimes the yp response does not fit within the grpc limit
    if update_only_stages_with_logs:
        print("Selected {} stages".format(len(stages)))
        print("Filtering stages with logs")
        stages = filter_stages_with_logs(yp_client_stub, stages)
        print("Filtered {} stages".format(len(stages)))

    if len(stages) == 0:
        print("No stages to update")
        print("Cluster {} skipped".format(cluster))
        return
    elif len(stages) <= WARNING_UPDATE_LIMIT:
        print("Selected {} stages:".format(len(stages)))
        for stage in stages:
            print(stage)
        print("")

        expected_ret = "y"
        ret = raw_input("Continue? (y/n): ").rstrip()
        while ret != "y" and ret != "n":
            ret = raw_input("Use 'y' or 'n': ").rstrip()
    else:
        print("You select {} stages".format(len(stages)))

        # Large selections require typing a full phrase to confirm.
        expected_ret = "Yes, do as I say!"
        ret = raw_input(
            "This is a lot of objects, you really want to continue? If yes type '{}': ".format(expected_ret)).rstrip()

    if ret != expected_ret:
        print("Cluster {} skipped".format(cluster))
        return

    print("Current stages [{}]: ".format(len(stages)))
    print(", ".join(stages))

    success = 0
    errors = []
    # Without an update window the counter starts above the number of stages,
    # so it never reaches zero and no intermediate prompt is shown.
    step = update_window if update_window else len(stages) + 1
    skip = False
    for stage in stages:
        print("Update stage {}".format(stage))
        if dry_run or skip:
            print("[DRY_RUN] Updated {}\n".format(stage))
        else:
            try:
                if drop_acls:
                    drop_stage_acl(yp_client_stub, stage)
                else:
                    update_stage(yp_client_stub, stage, drop_pods, revision_description)
                print("Successfully updated {}\n".format(stage))
                success += 1
            except Exception as e:
                print("Error in stage {}: '{}'".format(stage, str(e)))
                errors.append((stage, str(e)))
        time.sleep(update_interval_seconds)
        step -= 1
        if not step:
            # Strip the answer like the other prompts do, so trailing
            # whitespace does not abort the run.
            ret = raw_input("Continue? [y|n|s]").rstrip()
            if ret == 's':
                print('Skip next bunch of stages')
                skip = True
            elif ret != 'y':
                break
            else:
                skip = False
            step = update_window

    print("Updated {}/{}".format(success, len(stages)))

    if len(errors) != 0:
        print("Errors:")
        for stage, error in errors:
            print("Stage {}, error: '{}'".format(stage, error))

        print("")
    else:
        print("No errors")

    raw_input("Press enter to continue")


def main(arguments):
    """Run ``update_cluster`` for every requested cluster.

    A failure in one cluster is printed and the remaining clusters are still
    processed.
    """
    for cluster in arguments.clusters:
        try:
            update_cluster(
                cluster=cluster,
                stage_filter=arguments.filter,
                update_only_stages_with_logs=arguments.update_only_stages_with_logs,
                select_limit=arguments.select_limit,
                update_interval_seconds=arguments.time_interval,
                dry_run=arguments.dry_run,
                drop_pods=arguments.drop_pods,
                token=arguments.token,
                update_window=arguments.update_window,
                revision_description=arguments.revision_description,
                drop_acls=arguments.drop_acls
            )
        except Exception as error:
            print("Error in cluster {}: '{}'".format(cluster, str(error)))


def parse_arguments():
    """Build the command-line parser and parse ``sys.argv``.

    :return: the ``argparse.Namespace`` with the parsed options. Cluster
        names are restricted to the values of ``get_cluster_list()`` and
        ``--revision-description`` is mandatory.
    """
    parser = argparse.ArgumentParser(description="Up revision in every selected stage.")
    parser.add_argument(
        "clusters",
        metavar="cluster",
        type=str,
        nargs="+",
        choices=get_cluster_list(),
        help="clusters for update."
    )
    parser.add_argument(
        "-f", "--filter",
        dest="filter",
        default=None,
        help="filter for stages."
    )
    parser.add_argument(
        "--update-only-stages-with-logs",
        dest="update_only_stages_with_logs",
        action="store_true",
        default=False,
        help="Additionally to filter, leaves only those stages that have logs transmitter."
    )
    parser.add_argument(
        "--select-limit",
        dest="select_limit",
        type=int,
        default=SELECT_LIMIT,
        help="maximum number of objects to update."
    )
    parser.add_argument(
        "--time-interval",
        dest="time_interval",
        type=int,
        default=0,
        help="Interval in seconds between updates."
    )
    parser.add_argument(
        "--dry-run",
        dest="dry_run",
        action="store_true",
        default=False,
        help="print selected stages, do not update anything."
    )
    parser.add_argument(
        "--drop-pods",
        dest="drop_pods",
        action="store_true",
        default=False,
        help="set pod count to zero for all selected deploy units."
    )
    parser.add_argument(
        "--drop-acls",
        dest="drop_acls",
        action="store_true",
        default=False,
        help="Delete deploy specific ACLs in all stages. Revision won't be bumped."
    )
    parser.add_argument(
        "--token",
        dest="token",
        default=None,
        help="use specified YP token."
    )
    parser.add_argument(
        "--update-window",
        dest="update_window",
        type=int,
        default=0,
        help="Update window size. Next step has to be approved by user."
    )
    parser.add_argument(
        "--revision-description",
        dest="revision_description",
        type=str,
        required=True,
        help="Description for new stage revision. Please use the format '<TICKET_ID> <Message>' for large updates."
    )

    return parser.parse_args()


# Script entry point: parse the command line and process the requested clusters.
if __name__ == "__main__":
    main(parse_arguments())
