import argparse
import copy
import json
import os

import yaml

import yp.client
import yp.common
import yp.data_model as data_model
import yt.yson as yson
import yt_yson_bindings
from yp_proto.yp.client.api.proto import object_service_pb2

import infra.dctl.src.consts as consts

# Environment variable checked first for the YP token (see get_token).
YP_TOKEN_ENV = "YP_TOKEN"
# Fallback token file, read when the environment variable is unset or empty.
YP_TOKEN_FILE = os.path.expanduser("~/.yp/token")

# Default for the "dump" command's --select-limit: the maximum number of
# stages requested in the single SelectObjects call made by get_stages.
SELECT_LIMIT = 10000


def get_cluster_list():
    """Return the names of all clusters known to consts.CLUSTER_CONFIGS as a list."""
    return list(consts.CLUSTER_CONFIGS.keys())


def get_token(token_env, token_path):
    """Return the YP token, preferring the environment over a token file.

    Lookup order:
      1. the environment variable ``token_env`` (used if non-empty);
      2. the file ``token_path`` (used if it exists and, after stripping
         surrounding whitespace, is non-empty).

    Raises:
        Exception: if neither source yields a non-empty token. An empty token
            file is treated as "no token" instead of silently returning "".
    """
    token = os.getenv(token_env)
    if token:
        print("Use yp token from env {}".format(token_env))
        return token

    if os.path.isfile(token_path):
        with open(token_path, 'r') as f:
            token = f.read().strip()
        if token:
            print("Use yp token from file {}".format(token_path))
            return token
        # An empty file would otherwise surface later as an opaque auth error.
        print("Yp token file {} is empty".format(token_path))

    raise Exception("No yp token provided")


def dict_to_protobuf(object_type, object_dict):
    """Convert ``object_dict`` with yp.common.dict_to_protobuf for ``object_type``.

    The caller-supplied ``object_type`` is now forwarded. Previously it was
    ignored and data_model.OT_STAGE was always used, so passing OT_STAGE gives
    the same result as before.
    """
    return yp.common.dict_to_protobuf(object_dict, object_type)


def create_yp_client(cluster):
    """Build a YP client for ``cluster`` and return ``(client, grpc_object_stub)``.

    The address comes from consts.CLUSTER_CONFIGS; the token is resolved via
    get_token (environment first, then the token file).
    """
    config = consts.CLUSTER_CONFIGS[cluster]
    token = get_token(YP_TOKEN_ENV, YP_TOKEN_FILE)

    client = yp.client.YpClient(address=config.address, config={'token': token})
    stub = client.create_grpc_object_stub()
    return client, stub


def get_stages(yp_client_stub, stage_filter, select_limit):
    """Select stage objects from YP and parse them into data_model.TStage.

    Args:
        yp_client_stub: gRPC object-service stub used for SelectObjects.
        stage_filter: optional YP filter query string; None selects all.
        select_limit: value sent as the request's result limit.

    Returns:
        List of parsed TStage messages. Rows that fail to parse are reported
        on stdout and skipped. A warning is printed when the number of rows
        equals the limit, since the result may then be truncated.
    """
    request = object_service_pb2.TReqSelectObjects()
    request.object_type = data_model.OT_STAGE
    request.limit.value = select_limit
    # Column 0: the whole object; column 1: its id (only used in error output).
    request.selector.paths.extend(["", "/meta/id"])

    if stage_filter is not None:
        request.filter.query = stage_filter

    response = yp_client_stub.SelectObjects(request)

    parsed = []
    for row in response.results:
        try:
            stage = yt_yson_bindings.loads_proto(row.values[0], data_model.TStage)
        except Exception as error:
            print("Error while parsing object '{}' in select: '{}'".format(str(row.values[1]), str(error)))
        else:
            parsed.append(stage)

    if select_limit == len(response.results):
        print("WARNING You have selected as many objects as specified in the query limit")

    return parsed


def remove_stage(yp_client_stub, stage_id):
    """Send a RemoveObject request for the stage ``stage_id`` and return the response."""
    request = object_service_pb2.TReqRemoveObject(
        object_type=data_model.OT_STAGE,
        object_id=stage_id,
    )
    return yp_client_stub.RemoveObject(request)


def create_stage(yp_client_stub, stage_id, stage_data):
    """Send a CreateObject request for a stage and return the response.

    ``stage_data`` (a TStage message) is serialized to YSON and sent as the
    object's attributes. ``stage_id`` is not sent; presumably the id is
    carried inside ``stage_data`` (meta.id) — confirm with callers.
    """
    serialized = yt_yson_bindings.dumps_proto(stage_data)
    request = object_service_pb2.TReqCreateObject(
        object_type=data_model.OT_STAGE,
        attributes=serialized,
    )
    return yp_client_stub.CreateObject(request)


def get_batch_size(index):
    """Return how many stages belong in batch number ``index``.

    Sizes grow with the index: batch 0 (or lower) gets 1 stage, batch 1 gets
    5, and every later batch gets 25.
    """
    # (largest index for this size, size), checked in order.
    for max_index, size in ((0, 1), (1, 5)):
        if index <= max_index:
            return size
    return 25


def get_full_dump_dir(cluster):
    """Return the directory holding unmodified stage dumps for ``cluster``."""
    return "/".join((cluster, "full_dump"))


def get_dump_dir(cluster):
    """Return the directory holding stripped stage dumps for ``cluster``."""
    return "/".join((cluster, "dump"))


def get_update_dir(cluster):
    """Return the directory holding updated (to-be-uploaded) stages for ``cluster``."""
    return "/".join((cluster, "update"))


def get_batches_dir(cluster):
    """Return the directory holding batch id-list files for ``cluster``."""
    return "/".join((cluster, "batches"))


def get_batch_file(batches_dir, index):
    """Return the path of batch file ``index`` (zero-padded to 3 chars) in ``batches_dir``."""
    return "{}/{:03}".format(batches_dir, index)


def update_stage(stage):
    """Return an updated deep copy of a dumped stage dict, ready for re-creation.

    Changes applied to the copy:
      * ``spec.revision`` is incremented by one;
      * if ``spec.account_id`` is "tmp", every ACL subject containing "abc:"
        is removed, and ACL entries left with no subjects are dropped.

    The input is never modified. The previous shallow ``copy()`` let the
    nested spec/meta/acl edits leak into the caller's dict.
    """
    local_stage = copy.deepcopy(stage)

    spec = local_stage["spec"]
    spec["revision"] += 1

    if spec["account_id"] == "tmp":
        meta = local_stage["meta"]
        # Tolerate a missing "acl" key rather than raising KeyError.
        acls = meta.get("acl")
        if acls is not None:
            new_acls = []
            for acl in acls:
                kept_subjects = [subject for subject in acl["subjects"] if "abc:" not in subject]
                if kept_subjects:
                    acl["subjects"] = kept_subjects
                    new_acls.append(acl)
            meta["acl"] = new_acls

    return local_stage


def safe_stage_to_file(object_dict, file_path):
    """Serialize ``object_dict`` as YAML and write it to ``file_path``.

    The YAML text is produced before the file is opened, so a serialization
    error no longer leaves an empty or truncated file behind.
    """
    # JSON round-trip (borrowed from dctl) reduces the object to plain
    # dict/list/str/number values; presumably needed so yaml.safe_dump
    # accepts yson-typed input.
    plain = json.loads(json.dumps(object_dict))
    text = yaml.safe_dump(plain)
    with open(file_path, 'w') as f:
        f.write(text)


def load_batch(cluster, index):
    """Load batch ``index`` of ``cluster`` as a list of ``(stage_id, TStage)`` pairs.

    Stage ids are read one per line from the batch file (trailing whitespace
    stripped; blank lines ignored). Each stage is parsed from its YAML file
    in the cluster's update directory.
    """
    with open(get_batch_file(get_batches_dir(cluster), index), 'r') as f:
        # Skip blank lines (e.g. a trailing empty line left by hand-editing);
        # otherwise "" would resolve to the update directory itself.
        stage_ids = [stage_id for stage_id in (line.rstrip() for line in f) if stage_id]

    update_dir = get_update_dir(cluster)
    stages_data = []

    for stage_id in stage_ids:
        with open("{}/{}".format(update_dir, stage_id), 'r') as f:
            # dctl copypaste: YAML -> plain objects -> YSON -> TStage.
            parsed = yaml.load(f.read(), Loader=yaml.SafeLoader)
        stage_data = yt_yson_bindings.loads_proto(yson.dumps(parsed), proto_class=data_model.TStage, skip_unknown_fields=False)

        stages_data.append((stage_id, stage_data))

    return stages_data


def check_ac():
    """Ask the operator to confirm on stdin; return True for 'y', False for 'n'.

    Re-prompts until the answer (trailing whitespace stripped) is exactly 'y'
    or 'n'. End of input is treated as 'n', so a closed stdin never counts as
    consent and does not crash the script. Works on Python 2 and 3.
    """
    try:
        read_line = raw_input  # Python 2: read a line without evaluating it.
    except NameError:
        read_line = input  # Python 3: input() no longer evaluates its result.

    prompt = "Continue? (y/n): "
    while True:
        try:
            answer = read_line(prompt).rstrip()
        except EOFError:
            return False
        if answer in ("y", "n"):
            return answer == "y"
        prompt = "Type 'y' or 'n': "


def _split_into_batches(stage_ids):
    """Split ``stage_ids`` into consecutive chunks sized by get_batch_size(batch_number)."""
    batches = []
    start = 0
    while start < len(stage_ids):
        size = get_batch_size(len(batches))
        batches.append(stage_ids[start:start + size])
        start += size
    return batches


def dump_db(arguments):
    """Dump stages of ``arguments.cluster`` into a local working directory.

    Creates ``<cluster>/`` with four subdirectories:
      * full_dump/ - each stage as returned by YP;
      * dump/      - the same without status and meta.creation_time/uuid/type;
      * update/    - the stripped stage after update_stage();
      * batches/   - numbered files listing stage ids, one per line.

    Raises:
        Exception: if ``<cluster>/`` already exists.
    """
    if os.path.exists(arguments.cluster):
        raise Exception("Db already dumped. To dump again remove dir '{}'".format(arguments.cluster))

    full_dump_dir = get_full_dump_dir(arguments.cluster)
    dump_dir = get_dump_dir(arguments.cluster)
    update_dir = get_update_dir(arguments.cluster)
    batches_dir = get_batches_dir(arguments.cluster)

    # Talk to YP before creating anything on disk. If this fails (e.g. no
    # token), no directory is left behind to trip the "already dumped" check.
    # NOTE(review): yp_client is kept bound while the stub is used; unclear
    # whether the stub depends on it staying alive.
    yp_client, yp_client_stub = create_yp_client(arguments.cluster)
    stages = get_stages(yp_client_stub, arguments.filter, arguments.select_limit)

    os.mkdir(arguments.cluster)
    for path in [full_dump_dir, dump_dir, update_dir, batches_dir]:
        os.mkdir(path)

    stage_ids = []
    for stage in stages:
        stage = yson.loads(yt_yson_bindings.dumps_proto(stage))

        stage_id = stage["meta"]["id"]
        stage_ids.append(stage_id)

        safe_stage_to_file(stage, "{}/{}".format(full_dump_dir, stage_id))

        # Drop status and meta fields that are presumably filled in by the
        # server. Use a default so an absent (unset) field is not a KeyError.
        stage.pop("status", None)
        stage["meta"].pop("creation_time", None)
        stage["meta"].pop("uuid", None)
        stage["meta"].pop("type", None)

        safe_stage_to_file(stage, "{}/{}".format(dump_dir, stage_id))
        safe_stage_to_file(update_stage(stage), "{}/{}".format(update_dir, stage_id))

    for index, batch in enumerate(_split_into_batches(stage_ids)):
        with open(get_batch_file(batches_dir, index), 'w') as f:
            f.writelines(stage_id + "\n" for stage_id in batch)


def _run_batch_action(arguments, verb, gerund, action):
    """Shared driver for the remove-batch and upload-batch commands.

    Loads batch ``arguments.index`` of ``arguments.cluster``, prints its stage
    ids, asks for confirmation, then calls
    ``action(yp_client_stub, stage_id, stage_data)`` for each stage.
    Per-stage errors are printed and do not stop the loop. A summary of the
    failed ids is printed at the end.

    Raises:
        Exception: if the cluster has not been dumped yet.
    """
    if not os.path.exists(arguments.cluster):
        raise Exception("Db not dumped for '{}'".format(arguments.cluster))

    stages_data = load_batch(arguments.cluster, arguments.index)

    print("Stages in batch:")
    for stage_id, _ in stages_data:
        print(stage_id)
    print("")

    if not check_ac():
        print("Interrupt")
        return

    yp_client, yp_client_stub = create_yp_client(arguments.cluster)

    failed = []
    for stage_id, stage_data in stages_data:
        try:
            print("{} '{}'".format(verb, stage_id))
            action(yp_client_stub, stage_id, stage_data)
            print("Success")
        except Exception as e:
            # Best-effort per stage: report and continue with the rest.
            print("Error while {} '{}': '{}'".format(gerund, stage_id, str(e)))
            failed.append(stage_id)

    if failed:
        print("Failed {} of {} stages: {}".format(len(failed), len(stages_data), ", ".join(failed)))


def remove_batch(arguments):
    """Remove from YP every stage listed in batch ``arguments.index``."""
    _run_batch_action(
        arguments, "Remove", "removing",
        lambda stub, stage_id, stage_data: remove_stage(stub, stage_id),
    )


def upload_batch(arguments):
    """Create in YP every stage of batch ``arguments.index`` from its update-dir file."""
    _run_batch_action(arguments, "Create", "creating", create_stage)


def main(arguments):
    """Dispatch to the sub-command handler stored in ``arguments.func``."""
    handler = arguments.func
    handler(arguments)


def parse_arguments():
    """Parse the command line: ``<cluster> {dump,remove-batch,upload-batch} ...``.

    Each sub-command stores its handler in ``func``. The sub-command is
    mandatory. On Python 3 argparse subparsers are optional by default, and a
    missing command used to crash later with ``AttributeError: func``.
    """
    parser = argparse.ArgumentParser(description="Dump and recreate stages.")

    parser.add_argument(
        "cluster",
        metavar="cluster",
        type=str,
        choices=get_cluster_list(),
        help="cluster for update."
    )

    # dest must be set for argparse to render the "required" error message.
    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True

    parse_dump = subparsers.add_parser("dump", help="Dump db.")
    parse_dump.add_argument(
        "-f", "--filter",
        dest="filter",
        default=None,
        help="filter for stages."
    )
    parse_dump.add_argument(
        "--select-limit",
        dest="select_limit",
        type=int,
        default=SELECT_LIMIT,
        help="maximum number of objects to select."
    )
    parse_dump.set_defaults(func=dump_db)

    parse_remove_batch = subparsers.add_parser("remove-batch", help="Remove stages from batch")
    parse_remove_batch.add_argument(
        "index",
        metavar="index",
        type=int,
        help="batch index."
    )
    parse_remove_batch.set_defaults(func=remove_batch)

    parse_upload_batch = subparsers.add_parser("upload-batch", help="Upload stages from batch")
    parse_upload_batch.add_argument(
        "index",
        metavar="index",
        type=int,
        help="batch index."
    )
    parse_upload_batch.set_defaults(func=upload_batch)

    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: parse the command line, then run the chosen command.
    cli_arguments = parse_arguments()
    main(cli_arguments)
