import argparse
import os

import yp.client
import yp.data_model as data_model
import yt.yson as yson
from yp_proto.yp.client.api.proto import object_service_pb2

import infra.dctl.src.consts as consts

# Name of the environment variable that get_token() checks first for the YP token.
YP_TOKEN_ENV = "YP_TOKEN"
# Token file used as a fallback when the environment variable is unset or empty.
YP_TOKEN_FILE = os.path.expanduser("~/.yp/token")

# Default value of --select-limit: the limit placed on the pod select request.
SELECT_LIMIT = 10000
# Up to this many selected pods are listed one by one and confirmed with a
# plain y/n answer; larger selections require typing a longer phrase.
WARNING_UPDATE_LIMIT = 15


def get_cluster_list():
    """Return the names of all configured clusters as a new list.

    The names are the keys of ``consts.CLUSTER_CONFIGS``.
    """
    return list(consts.CLUSTER_CONFIGS.keys())


def get_token(token_env, token_path):
    """Look up the YP token, preferring the environment over the token file.

    Args:
        token_env: name of the environment variable to check first.
        token_path: path of the file to read when the variable is unset
            or empty.

    Returns:
        The token string. A token read from the file has surrounding
        whitespace stripped; a token from the environment is returned as is.

    Raises:
        Exception: if the variable is unset/empty and the file does not exist.
    """
    env_token = os.getenv(token_env)
    if env_token:
        print("Use yp token from env {}".format(token_env))
        return env_token

    # Guard clause: with no usable environment value and no file, give up.
    if not os.path.isfile(token_path):
        raise Exception("No yp token provided")

    print("Use yp token from file {}".format(token_path))
    with open(token_path, 'r') as token_file:
        contents = token_file.read()
    return contents.strip()


def list_pods(yp_client_stub, pod_filter, select_limit):
    """Select pod ids through the object service stub.

    Args:
        yp_client_stub: stub whose ``SelectObjects`` method executes the request.
        pod_filter: filter query string, or None to send no filter.
        select_limit: value placed in the request's ``limit`` field.

    Returns:
        List of ids, each parsed with ``yson.loads`` from the first value of
        a result row. Rows whose id cannot be parsed are reported and skipped.
    """
    request = object_service_pb2.TReqSelectObjects()
    request.object_type = data_model.OT_POD
    request.limit.value = select_limit

    # Only the object id is requested.
    request.selector.paths.append("/meta/id")

    if pod_filter is not None:
        print("Current pod filter '{}'".format(pod_filter))
        request.filter.query = pod_filter

    response = yp_client_stub.SelectObjects(request)

    pod_ids = []
    for row in response.results:
        try:
            pod_id = yson.loads(row.values[0])
        except Exception:
            print("Error while parsing object id in select '{}'".format(str(row)))
        else:
            pod_ids.append(pod_id)

    # Receiving exactly `select_limit` rows hints that the result may have been cut off.
    if len(response.results) == select_limit:
        print("WARNING You have selected as many objects as specified in the query limit")

    return pod_ids


def start_transaction(yp_client_stub):
    """Send a StartTransaction request through the stub.

    Returns:
        Tuple ``(transaction_id, start_timestamp)`` taken from the response.
    """
    request = object_service_pb2.TReqStartTransaction()
    response = yp_client_stub.StartTransaction(request)

    transaction = (response.transaction_id, response.start_timestamp)
    return transaction


def commit_transaction(yp_client_stub, transaction_id):
    """Send a CommitTransaction request for ``transaction_id`` through the stub."""
    commit_request = object_service_pb2.TReqCommitTransaction()
    commit_request.transaction_id = transaction_id

    yp_client_stub.CommitTransaction(commit_request)


def update_pod_custom(yp_client_stub, transaction_id, pod):
    """Send an UpdateObject request for one pod inside ``transaction_id``.

    As written the request carries no updates; the commented template below
    shows where a custom update is meant to be filled in.

    Args:
        yp_client_stub: stub whose ``UpdateObject`` method executes the request.
        transaction_id: id of the transaction the update belongs to.
        pod: id of the pod to update.
    """
    update_request = object_service_pb2.TReqUpdateObject()
    update_request.transaction_id = transaction_id
    update_request.object_type = data_model.OT_POD
    update_request.object_id = pod

    # TODO custom update
    # upd = update_request.set_updates.add()
    # upd.path = "/spec/revision"
    # upd.value = yson.dumps(revision)

    yp_client_stub.UpdateObject(update_request)


def update_pod(yp_client_stub, pod):
    """Update a single pod in its own transaction, logging each step.

    Starts a transaction, runs ``update_pod_custom`` for ``pod`` inside it and
    commits. Any exception raised by one of the steps propagates to the caller.
    """
    print("Start transaction")
    # The start timestamp is returned as well but is not needed here.
    transaction_id, _start_timestamp = start_transaction(yp_client_stub)
    print("Started")

    print("Run custom update")
    update_pod_custom(yp_client_stub, transaction_id, pod)
    print("Updated")

    print("Commit")
    commit_transaction(yp_client_stub, transaction_id)
    print("Success commit")


def _read_line(prompt):
    """Show ``prompt`` and return the line typed by the operator.

    Uses ``raw_input`` where it exists (Python 2) and falls back to the
    builtin ``input`` otherwise (Python 3), so the script works on both.
    """
    try:
        reader = raw_input  # noqa: F821 -- Python 2 only
    except NameError:
        reader = input  # Python 3: input() already returns the raw string
    return reader(prompt)


def _confirm_update(pods):
    """Ask the operator to confirm updating ``pods``; return True on consent.

    Small selections (at most WARNING_UPDATE_LIMIT pods) are listed and need
    a plain ``y``; larger ones need an explicit phrase typed in full.
    """
    if len(pods) <= WARNING_UPDATE_LIMIT:
        print("Selected {} pods:".format(len(pods)))
        for pod in pods:
            print(pod)
        print("")

        answer = _read_line("Continue? (y/n): ").rstrip()
        # Keep asking until the answer is unambiguous.
        while answer not in ("y", "n"):
            answer = _read_line("Use 'y' or 'n': ").rstrip()
        return answer == "y"

    print("You select {} pods".format(len(pods)))

    expected = "Yes, do as I say!"
    answer = _read_line("This is a lot of objects, you really want to continue? If yes type '{}': ".format(expected)).rstrip()
    return answer == expected


def _report_update(success, total, errors):
    """Print how many pods were updated and every collected (pod, error) pair."""
    print("Updated {}/{}".format(success, total))

    if errors:
        print("Errors:")
        for pod, error in errors:
            print("Pod {}, error: '{}'".format(pod, error))

        print("")
    else:
        print("No errors")


def update_cluster(cluster, pod_filter, select_limit):
    """Interactively update all pods of one cluster that match ``pod_filter``.

    Selects the matching pods, asks the operator for confirmation, then
    updates the pods one at a time. A failure on one pod is recorded and does
    not stop the remaining pods. Finally prints a summary and waits for the
    operator to press enter.

    Args:
        cluster: key into ``consts.CLUSTER_CONFIGS``.
        pod_filter: filter query passed to ``list_pods`` (may be None).
        select_limit: select limit passed to ``list_pods``.

    Raises:
        KeyError: if ``cluster`` is not a configured cluster.
        Exception: whatever token lookup, client creation or pod selection raise.
    """
    cluster_config = consts.CLUSTER_CONFIGS[cluster]

    print("-" * 50)
    print("Update cluster {}".format(cluster))

    yp_token = get_token(YP_TOKEN_ENV, YP_TOKEN_FILE)

    yp_client = yp.client.YpClient(
        address=cluster_config.address,
        config={
            'token': yp_token,
        }
    )
    yp_client_stub = yp_client.create_grpc_object_stub()

    pods = list_pods(yp_client_stub, pod_filter, select_limit)

    if not pods:
        print("No pods to update")
        print("Cluster {} skiped".format(cluster))
        return

    if not _confirm_update(pods):
        print("Cluster {} skiped".format(cluster))
        return

    print("Current pods: ")
    print(", ".join(pods))

    success = 0
    errors = []
    for pod in pods:
        print("Update pod {}".format(pod))
        try:
            update_pod(yp_client_stub, pod)
            print("Success update {}\n".format(pod))
            success += 1
        except Exception as e:
            # Best effort: remember the failure and carry on with the next pod.
            print("Error in pod {}: '{}'".format(pod, str(e)))
            errors.append((pod, str(e)))

    _report_update(success, len(pods), errors)

    # Pause so the operator can read the summary before the next cluster starts.
    _read_line("Press enter to continue")


def main(arguments):
    """Run ``update_cluster`` for every cluster listed in ``arguments``.

    ``arguments`` supplies ``clusters``, ``filter`` and ``select_limit``.
    An exception raised while handling one cluster is printed and the loop
    moves on to the next cluster.
    """
    for cluster_name in arguments.clusters:
        try:
            update_cluster(cluster_name, arguments.filter, arguments.select_limit)
        except Exception as error:
            print("Error in cluster {}: '{}'".format(cluster_name, str(error)))


def parse_arguments():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        Namespace with ``clusters`` (one or more configured cluster names),
        ``filter`` (pod filter query) and ``select_limit`` (int).
    """
    # Used when -f/--filter is not given on the command line.
    default_filter = '[/labels/deploy_engine] = "RSC" or [/labels/deploy_engine] = "MCRSC"'

    arg_parser = argparse.ArgumentParser(description="Patch spec of every selected pod.")
    arg_parser.add_argument(
        "clusters",
        metavar="cluster",
        type=str,
        nargs="+",
        choices=get_cluster_list(),
        help="clusters for update.",
    )
    arg_parser.add_argument(
        "-f", "--filter",
        dest="filter",
        default=default_filter,
        help="filter for pods.",
    )
    arg_parser.add_argument(
        "--select-limit",
        dest="select_limit",
        type=int,
        default=SELECT_LIMIT,
        help="maximum number of objects to update.",
    )

    parsed = arg_parser.parse_args()
    return parsed


# Script entry point: parse the command line, then process each requested cluster.
if __name__ == "__main__":
    main(parse_arguments())
