import argparse
import os
import time
import requests

import yp.client
import yp.data_model as data_model
import yt.yson as yson
import infra.dctl.src.consts as consts

from pprint import pprint

from yp_proto.yp.client.api.proto import object_service_pb2
from yp.common import YpNoSuchObjectError, YpAuthorizationError

# Environment variable and fallback file that main() hands to get_token()
# when no --token argument is given.
YP_TOKEN_ENV = "YP_TOKEN"
YP_TOKEN_FILE = os.path.expanduser("~/.yp/token")
# Default value of --select-limit (the limit passed to list_objects()).
SELECT_LIMIT = 100000
# Window size used by do() when --update-window is unset and --dry-run is off.
DEFAULT_UPDATE_LIMIT = 20
# Subjects starting with one of these prefixes are skipped by count_acl().
IGNORED_SUBJECT_PREFIXES = ['robot-', 'zomb-']

# Maps the command-line "type" argument to the YP object type constant.
TYPE_MAP = {
    'stage': data_model.OT_STAGE,
    'project': data_model.OT_PROJECT,
    'group': data_model.OT_GROUP
}

# Staff API groups endpoint; convert_idm_to_staff() fills the placeholder with
# the part of the IDM group name that follows the first colon.
STAFF_API_URL_FORMAT = 'https://staff-api.yandex-team.ru/v3/groups?id={}'


def is_subject_ignored(subject):
    """Return True when ``subject`` starts with any of IGNORED_SUBJECT_PREFIXES."""
    return any(subject.startswith(ignored) for ignored in IGNORED_SUBJECT_PREFIXES)


def convert_idm_to_staff(idm_group, token):
    """Translate an IDM group reference into an ABC service or Staff department id.

    A value without a colon is returned untouched and no request is made.
    Otherwise the part after the first colon is looked up through
    STAFF_API_URL_FORMAT, authorised with ``token`` as an OAuth header.

    Returns:
        ``'abc:service:<id>'`` when the single matching group has a service
        id, ``'staff:department:<id>'`` when it has a department id, and the
        original ``idm_group`` in every other case (non-200 response, zero or
        several results, neither id present).

    Raises:
        Whatever ``requests`` raises for network failures, including a
        timeout when the Staff API does not answer in time.
    """
    if ':' not in idm_group:
        return idm_group

    group = idm_group.split(':')[1]
    rsp = requests.get(
        STAFF_API_URL_FORMAT.format(group),
        headers={
            'Authorization': 'OAuth ' + token,
        },
        # Without a timeout a stalled connection would block the tool forever.
        timeout=30,
    )
    if rsp.status_code != 200:
        return idm_group

    result = rsp.json().get('result') or []
    if len(result) != 1:
        return idm_group

    entry = result[0]
    # Guard against a missing or null 'service' / 'department' field so the
    # lookup falls through to the fallback instead of raising.
    service_id = (entry.get('service') or {}).get('id')
    if service_id:
        return 'abc:service:{}'.format(service_id)

    department_id = (entry.get('department') or {}).get('id')
    if department_id:
        return 'staff:department:{}'.format(department_id)

    return idm_group


def get_cluster_list():
    """Return the keys of ``consts.CLUSTER_CONFIGS`` as a fresh list."""
    return list(consts.CLUSTER_CONFIGS.keys())


def get_token(token_env, token_path):
    """Find a YP token in the environment or, failing that, in a file.

    The environment variable named ``token_env`` wins when it is set to a
    non-empty value. Otherwise the contents of ``token_path`` are returned
    with surrounding whitespace stripped.

    Raises:
        Exception: when the variable is unset/empty and ``token_path`` is not
            an existing file.
    """
    from_env = os.getenv(token_env)
    if from_env:
        print("Use yp token from env {}".format(token_env))
        return from_env

    if not os.path.isfile(token_path):
        raise Exception("No yp token provided")

    print("Use yp token from file {}".format(token_path))
    with open(token_path, 'r') as token_file:
        contents = token_file.read()
    return contents.strip()


class YpClientApi(object):
    """Wrapper around the YP gRPC object stub of a single cluster.

    Instance attributes:
        yp_client_stub: stub returned by ``YpClient.create_grpc_object_stub()``.
        dry_run: when true, update_acl(), update_group() and create_group()
            only print what they would change and skip the write request.
        name: upper-cased first dot-separated component of the cluster
            address; used as a prefix in printed messages.
        token: the YP token; count_acl() also passes it to
            convert_idm_to_staff().
    """

    def __init__(self, cluster_config, yp_token, dry_run):
        """Create the gRPC stub for ``cluster_config.address`` using ``yp_token``."""
        yp_client = yp.client.YpClient(
            address=cluster_config.address,
            config={
                'token': yp_token,
            }
        )
        self.yp_client_stub = yp_client.create_grpc_object_stub()
        self.dry_run = dry_run
        self.name = cluster_config.address.split('.')[0].upper()
        self.token = yp_token

    def list_objects(self, type, filter, limit):
        """Select the ids (``/meta/id``) of objects of the given ``type``.

        ``filter`` becomes the select query unless it is None and ``limit``
        is sent as the request limit. Rows whose id cannot be parsed are
        reported and skipped. A warning is printed when the number of
        returned rows equals ``limit``.

        Returns:
            list of parsed object ids.
        """
        req = object_service_pb2.TReqSelectObjects()
        req.object_type = type
        req.limit.value = limit

        req.selector.paths.append("/meta/id")

        if filter is not None:
            req.filter.query = filter

        resp = self.yp_client_stub.SelectObjects(req)

        objects = []
        for row in resp.results:
            try:
                object_id = yson.loads(row.values[0])
            except Exception:
                print("{}: Error while parsing object id in select '{}'".format(self.name, str(row)))
                continue

            objects.append(object_id)

        if len(resp.results) == limit:
            print("{}: WARNING You have selected as many objects as specified in the query limit".format(self.name))

        return objects

    def start_transaction(self):
        """Start a transaction and return ``(transaction_id, start_timestamp)``."""
        req = object_service_pb2.TReqStartTransaction()
        resp = self.yp_client_stub.StartTransaction(req)

        return resp.transaction_id, resp.start_timestamp

    def commit_transaction(self, transaction_id):
        """Commit the transaction identified by ``transaction_id``."""
        req = object_service_pb2.TReqCommitTransaction()
        req.transaction_id = transaction_id
        self.yp_client_stub.CommitTransaction(req)

    @staticmethod
    def _check_prefix(subject, prefixes):
        """Return True when ``subject`` starts with any entry of ``prefixes``."""
        # str.startswith accepts a tuple of candidates; an empty tuple gives False.
        return subject.startswith(tuple(prefixes))

    def update_acl(self, transaction_timestamp, transaction_id, type, object_id, prefixes, remove):
        """Filter the subjects of every ACE of an object by prefix.

        With ``remove`` true, subjects matching ``prefixes`` are dropped;
        otherwise only matching subjects are kept. ACEs left without subjects
        are dropped from the new ACL. Old and new ACLs are printed when at
        least one subject was filtered out, and the new ACL is written inside
        ``transaction_id`` unless ``dry_run`` is set.
        """
        acls = self._get_acl(transaction_timestamp, type, object_id)

        new_acls = []
        need_update = False

        for acl in acls:
            subjects = acl.get('subjects', [])
            if not subjects:
                print('[{}] {}: EMPTY SUBJECTS'.format(self.name, object_id))

            new_subjects = []
            for subject in subjects:
                # remove mode keeps non-matching subjects, cleanup mode keeps
                # matching ones; anything not kept means the ACL must change.
                if self._check_prefix(subject, prefixes) != bool(remove):
                    new_subjects.append(subject)
                else:
                    need_update = True

            if new_subjects:
                new_acl = dict(acl)
                new_acl["subjects"] = new_subjects
                new_acls.append(new_acl)

        if not need_update:
            print('[{}] {}: no need to update'.format(self.name, object_id))
            return

        print('[{}] {}: old ACLs:'.format(self.name, object_id))
        pprint(acls)
        print('[{}] {}: new ACLs:'.format(self.name, object_id))
        pprint(new_acls)

        if not self.dry_run:
            self._set_acl(transaction_id, type, object_id, new_acls)

    def count_acl(self, transaction_timestamp, type, object_id):
        """Collect the subjects that hold the 'create' permission on an object.

        Subjects rejected by is_subject_ignored() are skipped. Subjects
        without a colon are counted as personal. ``abc:`` subjects are
        recorded as is, and ``deploy:`` subjects are expanded into their
        group members, each converted with convert_idm_to_staff().

        Returns:
            ``(personal_subjects, all_subjects)`` as two lists.
        """
        acls = self._get_acl(transaction_timestamp, type, object_id)

        all_new_subjects = []
        personal_new_subjects = []
        for acl in acls:
            # ACEs may lack 'subjects' (update_acl() handles that case too),
            # so read both keys defensively.
            if 'create' not in acl.get('permissions', []):
                continue

            for subject in acl.get('subjects', []):
                if is_subject_ignored(subject):
                    continue

                if ':' not in subject:
                    personal_new_subjects.append(subject)
                    all_new_subjects.append(subject)
                elif subject.startswith("abc:"):
                    # append, not extend: extend() on a string would add the
                    # subject one character at a time.
                    all_new_subjects.append(subject)
                elif subject.startswith("deploy:"):
                    members = self.get_group(subject)['spec'].get('members', [])
                    all_new_subjects.extend(
                        'deploy -> ' + convert_idm_to_staff(member, self.token)
                        for member in members
                    )

        pprint(personal_new_subjects)
        return personal_new_subjects, all_new_subjects

    def get_group(self, group_id):
        """Fetch the whole group object (empty selector path) and return it parsed."""
        print('{}: get group: {}'.format(self.name, group_id))
        req = object_service_pb2.TReqGetObject()
        req.object_type = data_model.OT_GROUP
        req.object_id = group_id
        req.selector.paths.append("")
        resp = self.yp_client_stub.GetObject(req)

        return yson.loads(resp.result.values[0])

    def update_group(self, group_info):
        """Overwrite ``/spec``, ``/meta/acl`` and ``/labels`` of an existing group.

        The group is identified by ``group_info['meta']['id']``. Nothing is
        sent in dry-run mode. A YpAuthorizationError is printed, not raised.
        """
        print('{}: update group'.format(self.name))
        pprint(group_info)

        if self.dry_run:
            return

        # One request carrying three set-updates: spec, ACL and labels.
        req = object_service_pb2.TReqUpdateObject()
        req.object_type = data_model.OT_GROUP
        req.object_id = group_info['meta']['id']

        upd = req.set_updates.add()
        upd.path = "/spec"
        upd.value = yson.dumps(group_info['spec'])

        upd = req.set_updates.add()
        upd.path = "/meta/acl"
        upd.value = yson.dumps(group_info['meta']['acl'])

        upd = req.set_updates.add()
        upd.path = "/labels"
        upd.value = yson.dumps(group_info['labels'])

        try:
            self.yp_client_stub.UpdateObject(req)
            print('{}: updated {}'.format(self.name, group_info['meta']['id']))
        except YpAuthorizationError as e:
            print('{}: YpAuthorizationError {}'.format(self.name, e))

    def create_group(self, group_info):
        """Create a group from the id, ACL, spec and labels of ``group_info``.

        Only those four parts are copied into the new object. Nothing is sent
        in dry-run mode. A YpAuthorizationError is printed, not raised.
        """
        print('{}: create group:'.format(self.name))

        new_info = {
            'meta': {
                'id': group_info['meta']['id'],
                'acl': group_info['meta']['acl']
            },
            'spec': group_info['spec'],
            'labels': group_info['labels']
        }

        pprint(new_info)

        if self.dry_run:
            return

        req = object_service_pb2.TReqCreateObject(
            object_type=data_model.OT_GROUP,
            attributes=yson.dumps(new_info)
        )
        try:
            rsp = self.yp_client_stub.CreateObject(req)
            print('{}: created {}'.format(self.name, rsp.object_id))
        except YpAuthorizationError as e:
            print('{}: YpAuthorizationError {}'.format(self.name, e))

    def _get_acl(self, transaction_timestamp, type, object_id):
        """Read ``/meta/acl`` of an object at ``transaction_timestamp`` and parse it."""
        req = object_service_pb2.TReqGetObject()
        req.object_type = type
        req.object_id = object_id
        req.selector.paths.append("/meta/acl")
        req.timestamp = transaction_timestamp
        resp = self.yp_client_stub.GetObject(req)

        return yson.loads(resp.result.values[0])

    def _set_acl(self, transaction_id, type, object_id, new_acls):
        """Write ``new_acls`` to ``/meta/acl`` of an object inside ``transaction_id``.

        This method does not look at ``dry_run``; callers check it first.
        """
        print('{}: write new ACL'.format(self.name))
        req = object_service_pb2.TReqUpdateObject()
        req.object_type = type
        req.object_id = object_id
        req.transaction_id = transaction_id

        upd = req.set_updates.add()
        upd.path = "/meta/acl"
        upd.value = yson.dumps(new_acls)

        self.yp_client_stub.UpdateObject(req)


class YpTransaction(object):
    """Context manager that starts a YP transaction and commits it on clean exit.

    ``yp_client`` must provide ``start_transaction()`` returning
    ``(transaction_id, start_timestamp)`` and ``commit_transaction(id)``.
    """

    def __init__(self, yp_client):
        self.yp_client = yp_client
        # Filled in by __enter__.
        self.transaction_id = None
        self.start_timestamp = None

    def __enter__(self):
        self.transaction_id, self.start_timestamp = self.yp_client.start_transaction()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Commit only when the body finished normally; a transaction whose
        # body failed is left uncommitted.
        if exc_type is None:
            self.yp_client.commit_transaction(self.transaction_id)
        # A truthy return value would swallow the exception (including
        # KeyboardInterrupt), so always let it propagate.
        return False

    def get_id(self):
        """Return the id of the transaction started by ``__enter__``."""
        return self.transaction_id

    def get_timestamp(self):
        """Return the start timestamp of the transaction started by ``__enter__``."""
        return self.start_timestamp


def continue_prompt():
    """Ask the user whether to continue and translate the answer.

    Returns:
        True  -- the answer was ``s`` (skip),
        False -- the answer was ``y`` (go on),
        None  -- any other answer (stop).
    """
    # raw_input exists only on Python 2; on Python 3 the same behaviour is
    # provided by input().
    try:
        read_line = raw_input  # noqa: F821
    except NameError:
        read_line = input

    answer = read_line("Continue? [(y)es | (n)o | (s)kip] ")
    if answer == 's':
        return True
    if answer == 'y':
        return False
    return None


def update_cluster(yp_token, cluster, object_type, opt, remove):
    """Select objects in one cluster and run the ACL update over them.

    ``remove`` picks do_remove_acls as the per-object action; otherwise
    do_cleanup_acls is used.
    """
    client = YpClientApi(consts.CLUSTER_CONFIGS[cluster], yp_token, opt.dry_run)
    selected = client.list_objects(object_type, opt.filter, opt.select_limit)

    if remove:
        action = do_remove_acls
    else:
        action = do_cleanup_acls
    do(selected, [client], opt, action)


def list_cluster(yp_token, cluster, object_type, opt):
    """Select objects in one cluster and run do_list_acls over them."""
    config = consts.CLUSTER_CONFIGS[cluster]
    client = YpClientApi(config, yp_token, opt.dry_run)

    selected = client.list_objects(object_type, opt.filter, opt.select_limit)
    do(selected, [client], opt, do_list_acls)


def sync_clusters(yp_token, master, slaves, opt):
    """Run do_sync_groups for master groups whose id starts with ``opt.type + ':'``.

    One client is created per cluster, master first. The group ids are
    selected from the master only.
    """
    clients = [
        YpClientApi(consts.CLUSTER_CONFIGS[name], yp_token, opt.dry_run)
        for name in [master] + list(slaves)
    ]

    groups = []
    for group_id in clients[0].list_objects(data_model.OT_GROUP, opt.filter, opt.select_limit):
        if group_id.startswith(opt.type + ':'):
            groups.append(group_id)

    do(groups, clients, opt, do_sync_groups)


def _same_items(left, right, what):
    """Compare two lists ignoring order; print why they differ, if they do."""
    from collections import Counter

    if len(left) != len(right):
        print('{} mismatched by size'.format(what))
        return False

    # Counter compares as a multiset, so repeated entries are handled correctly
    # (a set intersection would report [a, a] vs [a, a] as different).
    if Counter(left) != Counter(right):
        print('{} mismatched by value'.format(what))
        return False

    return True


def are_groups_equal(master, slave):
    """Tell whether a slave group already matches the master group.

    Compared, in this order: ``spec.members`` (order-insensitive), the
    ``system`` label, the number of ACEs in ``meta.acl`` and, pairwise by
    position, the subjects of each ACE (order-insensitive). The first
    difference found is printed.

    A missing ``labels`` mapping or ``system`` label is treated as a
    mismatching value rather than raising ``KeyError``.

    Returns:
        True when no difference was found, False otherwise.
    """
    master_members = master['spec'].get('members', [])
    slave_members = slave['spec'].get('members', [])
    if not _same_items(master_members, slave_members, 'Members'):
        return False

    master_system = (master.get('labels') or {}).get('system')
    slave_system = (slave.get('labels') or {}).get('system')
    if master_system != slave_system:
        print('Labels mismatched')
        return False

    master_acls = master['meta']['acl']
    slave_acls = slave['meta']['acl']
    if len(master_acls) != len(slave_acls):
        print('ACLs mismatched by size')
        return False

    # NOTE(review): only the subjects of each ACE are compared; differences in
    # other ACE fields are not detected -- confirm this is intended.
    for master_acl, slave_acl in zip(master_acls, slave_acls):
        master_subjects = master_acl.get('subjects', [])
        slave_subjects = slave_acl.get('subjects', [])
        if not _same_items(master_subjects, slave_subjects, 'ACL subjects'):
            return False

    return True


def do(objects, clients, opt, func):
    """Apply ``func(obj, clients, opt)`` to every object, window by window.

    The user is asked for confirmation before the first object and again
    after every full window while objects remain. The window is
    ``opt.update_window`` when set, otherwise ``len(objects) + 1`` in dry-run
    mode (so no further prompt appears) or DEFAULT_UPDATE_LIMIT.

    A prompt answer of "skip" makes the following window only count its
    objects as skipped; any answer other than "yes"/"skip" stops processing.
    ``func`` reports failure by returning a non-empty value, which is
    collected and printed in the final summary. After each processed object
    the function sleeps for ``opt.time_interval`` seconds.
    """
    if len(objects) == 0:
        print("No objects to update")
        return

    total = len(objects)
    print("Selected objects [{}]: {} ".format(total, objects))

    if opt.update_window:
        window = opt.update_window
    elif opt.dry_run:
        window = total + 1
    else:
        window = DEFAULT_UPDATE_LIMIT
    print('Update window: {}\n'.format(window))

    skip = continue_prompt()
    if skip is None:
        return

    success = 0
    skipped = 0
    errors = []
    left_in_window = window

    for position, obj in enumerate(objects, 1):
        if skip:
            print("Skipped {}".format(obj))
            skipped += 1
        else:
            err = func(obj, clients, opt)
            if not err:
                success += 1
            else:
                errors.append((obj, err))
            print("")
            time.sleep(opt.time_interval)

        left_in_window -= 1
        # Ask again only when a window is exhausted and objects remain.
        if left_in_window == 0 and position < total:
            print("Processed {}/{}".format(skipped + success + len(errors), total))
            skip = continue_prompt()
            if skip is None:
                break
            # Reset to the computed window, not opt.update_window: the latter
            # is 0 by default, which would disable all further prompts.
            left_in_window = window

    print('-' * 50)
    print("Processed/Skipped/Total: {}/{}/{}".format(success, skipped, total))

    if errors:
        print("Errors:")
        for obj, error in errors:
            print("Object {}, error: '{}'".format(obj, error))

        print("")
    else:
        print("No errors")


def do_remove_acls(obj, clients, opt):
    """Drop subjects matching ``opt.prefixes`` from the ACL of ``obj``.

    Works on the first client inside a YpTransaction. Returns ``""`` on
    success or the text of the exception raised during the update.
    """
    client = clients[0]
    with YpTransaction(client) as txn:
        try:
            yp_type = TYPE_MAP[opt.type]
            client.update_acl(txn.get_timestamp(), txn.get_id(), yp_type, obj, opt.prefixes, True)
        except Exception as exc:
            message = str(exc)
            print("Error in object {}: '{}'".format(obj, message))
            return message

    return ""


def do_cleanup_acls(obj, clients, opt):
    """Keep only subjects matching ``opt.prefixes`` in the ACL of ``obj``.

    Works on the first client inside a YpTransaction. Returns ``""`` on
    success or the text of the exception raised during the update.
    """
    client = clients[0]
    with YpTransaction(client) as txn:
        try:
            yp_type = TYPE_MAP[opt.type]
            client.update_acl(txn.get_timestamp(), txn.get_id(), yp_type, obj, opt.prefixes, False)
        except Exception as exc:
            message = str(exc)
            print("Error in object {}: '{}'".format(obj, message))
            return message

    return ""


def do_list_acls(obj, clients, opt):
    """Print and evaluate the 'create' subjects of ``obj`` on the first client.

    Returns the exception text when counting fails. Otherwise returns ``''``
    when exactly one personal subject was found, and the formatted list of
    personal subjects in every other case.
    """
    client = clients[0]
    with YpTransaction(client) as txn:
        try:
            yp_type = TYPE_MAP[opt.type]
            print('{}:'.format(obj))
            personal, _everyone = client.count_acl(txn.get_timestamp(), yp_type, obj)
        except Exception as exc:
            message = str(exc)
            print("Error in object {}: '{}'".format(obj, message))
            return message

    if len(personal) == 1:
        return ''
    return '{}'.format(personal)


def do_sync_groups(obj, clients, opt):
    """Copy group ``obj`` from the master (first client) to every other client.

    When the master's ``system`` label is absent or differs from
    ``opt.type`` the labels are rebuilt (for type 'deploy' with a ``project``
    label derived from the group id) and, if the user answers "yes", written
    back to the master. Each slave group that differs from the master gets
    the master's spec, labels and ACL. A group missing on a slave is created.

    Returns ``'updated in [...]'`` naming the slaves that would be updated in
    dry-run mode, or ``''`` when that list is empty.
    """
    master = clients[0]
    group_info = master.get_group(obj)

    labels = group_info['labels']
    label_is_valid = 'system' in labels and labels['system'] == opt.type
    if not label_is_valid:
        # invalid label, we have to update it in master
        print('Fix invalid labels:')
        pprint(labels)
        group_info['labels'] = {'system': opt.type}
        if opt.type == 'deploy':
            group_info['labels']['project'] = obj.split(':')[1].split('.')[0]
        print('Correct labels:')
        pprint(group_info['labels'])
        if continue_prompt() is False:
            master.update_group(group_info)

    pending = []
    for slave in clients[1:]:
        try:
            slave_info = slave.get_group(obj)
            if are_groups_equal(group_info, slave_info):
                print('Nothing to sync')
                continue

            if opt.dry_run:
                # in dry-run mode return list of stages to update as error
                pending.append(slave.name)
            slave_info['spec'] = group_info['spec']
            slave_info['labels'] = group_info['labels']
            slave_info['meta']['acl'] = group_info['meta']['acl']
            slave.update_group(slave_info)
        except YpNoSuchObjectError:
            slave.create_group(group_info)

    if pending:
        return 'updated in {}'.format(pending)
    return ''


def main(arguments):
    """Dispatch the parsed command line to the matching action.

    The token comes from ``arguments.token`` when given, otherwise from
    get_token(). 'list' walks every requested cluster, 'cleanup' and
    'remove' update one cluster, 'sync' copies groups from the master to the
    other clusters. Errors of list/cleanup/remove are printed per cluster.
    """
    if arguments.dry_run:
        print('DRY RUN MODE IS ON !!!')

    cmd = arguments.command
    yp_token = arguments.token or get_token(YP_TOKEN_ENV, YP_TOKEN_FILE)

    if cmd == 'list':
        object_type = TYPE_MAP[arguments.type]
        for cluster in arguments.clusters:
            try:
                print("-" * 50)
                print("List ACLs in cluster {}".format(cluster))
                list_cluster(yp_token, cluster, object_type, arguments)
            except Exception as exc:
                print("Error in cluster {}: '{}'".format(cluster, str(exc)))
        return

    if cmd in ('cleanup', 'remove'):
        object_type = TYPE_MAP[arguments.type]
        remove = cmd == 'remove'
        if remove:
            banner = "Remove '{}' ACLs in cluster {}"
        else:
            banner = "Keep only '{}' ACLs in cluster {}"
        try:
            print(banner.format(arguments.prefixes, arguments.cluster))
            update_cluster(yp_token, arguments.cluster, object_type, arguments, remove)
        except Exception as exc:
            print("Error in cluster {}: '{}'".format(arguments.cluster, str(exc)))
        return

    if cmd == 'sync':
        sync_clusters(yp_token, arguments.master, arguments.clusters, arguments)
        return

    print('Unknown command!')


def parse_arguments():
    """Build the command-line parser and parse the process arguments.

    Global options: --filter, --select-limit, --time-interval, --dry-run,
    --token and --update-window. Sub-commands (stored in ``command``):

        cleanup <type> <cluster> <prefix>...   keep only the given prefixes
        remove  <type> <cluster> <prefix>...   remove the given prefixes
        list    <type> <cluster>...            list ACLs
        sync    <type> <master> <cluster>...   synchronize groups

    Returns:
        the ``argparse.Namespace`` produced by ``parse_args()``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-f", "--filter",
        dest="filter",
        default=None,
        help="filter for stages."
    )
    parser.add_argument(
        "--select-limit",
        dest="select_limit",
        type=int,
        default=SELECT_LIMIT,
        help="maximum number of objects to update."
    )
    parser.add_argument(
        "--time-interval",
        dest="time_interval",
        type=int,
        default=0,
        help="Interval in seconds between updates."
    )
    parser.add_argument(
        "--dry-run",
        dest="dry_run",
        action="store_true",
        default=False,
        help="print selected stages, do not update anything."
    )
    parser.add_argument(
        "--token",
        dest="token",
        default=None,
        help="use specified YP token."
    )
    parser.add_argument(
        "--update-window",
        dest="update_window",
        type=int,
        default=0,
        help="Update window size. Next step has to be approved by user."
    )

    subparsers = parser.add_subparsers(help='List of commands', dest='command')
    object_types = list(TYPE_MAP.keys())
    clusters = get_cluster_list()

    # cleanup and remove take identical positionals; only the help texts differ.
    acl_commands = (
        ('cleanup', 'Keep ACLs with specified prefixes', 'prefixes to keep.'),
        ('remove', 'Remove ACLs with specified prefixes', 'prefixes to remove.'),
    )
    for command, command_help, prefixes_help in acl_commands:
        acl_parser = subparsers.add_parser(command, help=command_help)
        acl_parser.add_argument(
            "type",
            metavar='type',
            choices=object_types,
            help="type of objects for update."
        )
        acl_parser.add_argument(
            "cluster",
            metavar='cluster',
            choices=clusters,
            help="cluster for update."
        )
        acl_parser.add_argument(
            "prefixes",
            metavar='prefix',
            nargs="+",
            help=prefixes_help
        )

    list_parser = subparsers.add_parser('list', help='List ACLs')
    list_parser.add_argument(
        "type",
        metavar='type',
        choices=object_types,
        help="type of objects for update."
    )
    list_parser.add_argument(
        "clusters",
        metavar='cluster',
        nargs="+",
        choices=clusters,
        help="clusters for update."
    )

    sync_parser = subparsers.add_parser('sync', help='Synchronize groups')
    sync_parser.add_argument(
        "type",
        metavar='type',
        choices=['deploy', 'idm'],
        help="Group type."
    )
    sync_parser.add_argument(
        "master",
        metavar='master',
        choices=clusters,
        help="YP master as a source."
    )
    sync_parser.add_argument(
        "clusters",
        metavar='cluster',
        nargs="+",
        choices=clusters,
        help="YP clusters to synchronize."
    )
    return parser.parse_args()


# Script entry point: parse the command line and hand the result to main().
if __name__ == "__main__":
    main(parse_arguments())
