import argparse
import collections
import infra.dctl.src.consts as consts
import os
import requests
import sys
import time
import yp.client
import yp.data_model as data_model
import yt.yson as yson
from pprint import pprint
from yp.common import YpAuthorizationError
from yp_proto.yp.client.api.proto import object_service_pb2

# Python 2 only: site.py removes sys.setdefaultencoding, so sys is reloaded to
# get it back and make UTF-8 the implicit str/unicode conversion codec.
reload(sys)
sys.setdefaultencoding("utf8")

# Token lookup (see get_token): the environment variable wins over the file.
YP_TOKEN_ENV = "YP_TOKEN"
YP_TOKEN_FILE = os.path.expanduser("~/.yp/token")
IDM_TOKEN_ENV = "IDM_TOKEN"
IDM_TOKEN_FILE = os.path.expanduser("~/.idm/token")
# IDM endpoint that IdmClientApi POSTs role requests to.
IDM_API_ROLEREQUESTS_PATH = "https://idm-api.yandex-team.ru/api/v1/rolerequests/"
# Default for --select-limit (the ``limit`` of each SelectObjects request).
SELECT_LIMIT = 100000
# Default update window for do() when neither --update-window nor --dry-run
# is given.
DEFAULT_UPDATE_LIMIT = 20
# Subjects starting with any of these prefixes are left out of ACL
# comparisons (see is_subject_ignored).
IGNORED_SUBJECT_PREFIXES = ['robot-', 'zomb-']

# CLI object-type name -> YP object type.
TYPE_MAP = {
    'stage': data_model.OT_STAGE,
    'project': data_model.OT_PROJECT,
    'group': data_model.OT_GROUP
}

# YP cluster name -> IDM system slug used for role requests. IdmClientApi
# looks the cluster up here, so clusters missing from this map cannot be
# migrated.
CLUSTER_TO_IDM_SYSTEM = {
    "xdc": "deploy-prod",
    "sas-test": "deploy-test",
    "man-pre": "deploy-pre"
}

# Staff API group lookups: by Staff group id, and by ABC service id.
STAFF_API_URL_FORMAT_ID_REQUEST = 'https://staff-api.yandex-team.ru/v3/groups?id={}'
STAFF_API_URL_FORMAT_SERVICE_ID_REQUEST = 'https://staff-api.yandex-team.ru/v3/groups?service.id={}'


def is_subject_ignored(subject):
    """Tell whether ``subject`` should be excluded from ACL comparisons.

    :param subject: ACL subject id (login or group id).
    :returns: True if it starts with one of ``IGNORED_SUBJECT_PREFIXES``.
    """
    # str.startswith accepts a tuple and matches if any element is a prefix.
    return subject.startswith(tuple(IGNORED_SUBJECT_PREFIXES))


def is_person_dismissed(login, token, timeout=30):
    """Return True if Staff marks ``login`` as dismissed.

    ``root`` is always reported as dismissed without a lookup. When Staff
    returns no matching person, the login is treated as not dismissed.

    :param login: subject to look up in the Staff persons API.
    :param token: OAuth token for the Staff API.
    :param timeout: seconds to wait for Staff before ``requests`` raises.
    :returns: the person's ``official.is_dismissed`` flag, default False.
    """
    if login == "root":
        return True

    rsp = requests.get(
        'https://staff-api.yandex-team.ru/v3/persons',
        # Let requests URL-encode the login instead of splicing it into the URL.
        params={'login': login, '_fields': 'official.is_dismissed'},
        headers={"Authorization": "OAuth {}".format(token)},
        # Without a timeout a stuck Staff request hangs the whole migration.
        timeout=timeout,
    )
    result = rsp.json().get('result', [])
    if not result:
        return False
    return result[0].get('official', {}).get('is_dismissed', False)


def convert_idm_to_staff(idm_group, token, timeout=30):
    """Translate a ``<prefix>:<staff group id>`` subject into an ABC/Staff subject.

    Looks the group up in the Staff groups API. Returns ``abc:service:<id>``
    if the group has a service id, else ``staff:department:<id>`` if it has a
    department id.

    The subject is returned unchanged when:
    * it contains no ``:``;
    * Staff answers with a non-200 status;
    * Staff does not return exactly one group;
    * the group has neither id.

    :param idm_group: subject such as ``idm:12345``.
    :param token: OAuth token for the Staff API.
    :param timeout: seconds to wait for Staff before ``requests`` raises.
    """
    if ':' not in idm_group:
        return idm_group

    group = idm_group.split(':')[1]
    rsp = requests.get(
        STAFF_API_URL_FORMAT_ID_REQUEST.format(group),
        headers={
            'Authorization': 'OAuth ' + token,
        },
        timeout=timeout,
    )
    if rsp.status_code != 200:
        return idm_group

    result = rsp.json().get('result', [])
    if len(result) != 1:
        return idm_group

    group_info = result[0]
    # Defensive: tolerate a missing or null ``service``/``department`` object
    # instead of failing with KeyError/TypeError.
    service_id = (group_info.get('service') or {}).get('id')
    if service_id:
        return 'abc:service:{}'.format(service_id)

    department_id = (group_info.get('department') or {}).get('id')
    if department_id:
        return 'staff:department:{}'.format(department_id)

    return idm_group


def convert_abc_service_to_idm_id(abc_service, token, timeout=30):
    """Map an ``abc:service:<id>`` subject to the matching ``idm:<group id>``.

    Queries the Staff groups API by ABC service id. Subjects that do not
    start with ``abc:service:`` are returned unchanged. The input is also
    returned unchanged, after printing a diagnostic, when Staff does not
    answer 200 with exactly one group.

    :param abc_service: subject to convert.
    :param token: OAuth token for the Staff API.
    :param timeout: seconds to wait for Staff before ``requests`` raises.
    """
    if not abc_service.startswith('abc:service:'):
        return abc_service

    service_id = abc_service.split(':')[2]
    rsp = requests.get(
        STAFF_API_URL_FORMAT_SERVICE_ID_REQUEST.format(service_id),
        headers={
            'Authorization': 'OAuth ' + token,
        },
        timeout=timeout,
    )

    try:
        body = rsp.json()
    except ValueError:
        # Error pages are not necessarily JSON; report the raw body instead of
        # crashing while building the diagnostic.
        body = rsp.text

    if rsp.status_code == 200 and isinstance(body, dict):
        result = body.get('result', [])
        if len(result) == 1:
            idm_id = 'idm:{}'.format(result[0]['id'])
            print('Successful convert {} to {}'.format(abc_service, idm_id))
            return idm_id
    print('Invalid response from staff for {}, code {}, message: {}'.format(abc_service, rsp.status_code, str(body)))
    return abc_service


def get_cluster_list():
    """Return the names of all clusters configured in ``consts.CLUSTER_CONFIGS``."""
    return list(consts.CLUSTER_CONFIGS.keys())


def get_token(token_env, token_path):
    """Return an OAuth token from ``$token_env`` or, failing that, ``token_path``.

    The environment variable takes precedence. The file content is stripped
    of surrounding whitespace. This helper serves both the YP and the IDM
    token, so its messages name the concrete source.

    :raises Exception: if neither source yields a non-empty token.
    """
    token = os.getenv(token_env)
    if token:
        print("Use token from env {}".format(token_env))
        return token

    if os.path.isfile(token_path):
        with open(token_path, 'r') as f:
            token = f.read().strip()
        # An empty file is as good as no file; fall through to the error.
        if token:
            print("Use token from file {}".format(token_path))
            return token

    raise Exception("No token provided: set ${} or put it into {}".format(token_env, token_path))


class YpClientApi(object):
    """Wrapper around the YP gRPC object-service stub of a single cluster.

    Log lines are prefixed with ``self.name``, the upper-cased first
    component of the cluster address. With ``dry_run`` set, ``update_acl``,
    ``update_group`` and ``create_group`` print the intended change but do
    not write it.
    """

    def __init__(self, cluster_config, yp_token, dry_run):
        """Create the gRPC object stub for ``cluster_config.address``.

        :param cluster_config: object with an ``address`` attribute.
        :param yp_token: OAuth token passed to the YP client.
        :param dry_run: if true, mutating helpers skip the actual write.
        """
        yp_client = yp.client.YpClient(
            address=cluster_config.address,
            config={
                'token': yp_token,
            }
        )
        self.yp_client_stub = yp_client.create_grpc_object_stub()
        self.dry_run = dry_run
        self.name = cluster_config.address.split('.')[0].upper()
        self.token = yp_token

    def generate_timestamp(self):
        """Return a new YP timestamp, e.g. to read several objects at one point in time."""
        req = object_service_pb2.TReqGenerateTimestamp()
        return self.yp_client_stub.GenerateTimestamp(req).timestamp

    def list_objects(self, type, filter, limit, extra_selectors=None, timestamp=None):
        """Select objects of YP ``type`` and return their selected field values.

        :param type: YP object type (e.g. a ``TYPE_MAP`` value).
        :param filter: YP filter query, or None for no filter.
        :param limit: maximum number of objects to select. A warning is
            printed when exactly ``limit`` results come back, since the
            output may then be truncated.
        :param extra_selectors: attribute paths selected after ``/meta/id``.
        :param timestamp: read timestamp; when falsy none is set on the request.
        :returns: list of rows. Each row is ``[meta_id, *extra values]``,
            yson-decoded. Rows that fail to decode are printed and skipped.
        """
        req = object_service_pb2.TReqSelectObjects()
        req.object_type = type
        req.limit.value = limit
        if timestamp:
            req.timestamp = timestamp

        req.selector.paths.append("/meta/id")
        # None default instead of a shared mutable list.
        if extra_selectors:
            req.selector.paths.extend(extra_selectors)

        if filter is not None:
            req.filter.query = filter

        resp = self.yp_client_stub.SelectObjects(req)

        objects = []
        for r in resp.results:
            try:
                object_fields = [yson.loads(value) for value in r.values]
            except Exception:
                print("{}: Error while parsing object id in select '{}'".format(self.name, str(r)))
                continue

            objects.append(object_fields)

        if len(resp.results) == limit:
            print("{}: WARNING You have selected as many objects as specified in the query limit".format(self.name))

        return objects

    def start_transaction(self):
        """Start a YP transaction; return ``(transaction_id, start_timestamp)``."""
        req = object_service_pb2.TReqStartTransaction()
        resp = self.yp_client_stub.StartTransaction(req)

        return resp.transaction_id, resp.start_timestamp

    def commit_transaction(self, transaction_id):
        """Commit the YP transaction ``transaction_id``."""
        req = object_service_pb2.TReqCommitTransaction()
        req.transaction_id = transaction_id
        self.yp_client_stub.CommitTransaction(req)

    @staticmethod
    def _check_prefix(subject, prefixes):
        """Return True if ``subject`` starts with any of ``prefixes``."""
        return any(subject.startswith(prefix) for prefix in prefixes)

    def update_acl(self, transaction_timestamp, transaction_id, type, object_id, prefixes, remove):
        """Filter an object's ACL subjects by prefix and write the result back.

        With ``remove`` true, subjects matching ``prefixes`` are dropped.
        Otherwise only matching subjects are kept. ACL entries left with no
        subjects are removed entirely. Old and new ACLs are printed.

        :param transaction_timestamp: timestamp the current ACL is read at.
        :param transaction_id: transaction the new ACL is written in.
        """
        acls = self._get_acl(type, object_id, transaction_timestamp)

        new_acls = []
        need_update = False

        for acl in acls:
            subjects = acl.get('subjects', [])
            if not subjects:
                print('[{}] {}: EMPTY SUBJECTS'.format(self.name, object_id))

            new_subjects = []
            for subject in subjects:
                prefix_found = self._check_prefix(subject, prefixes)
                if remove:
                    if not prefix_found:
                        new_subjects.append(subject)
                    else:
                        need_update = True
                else:
                    if prefix_found:
                        new_subjects.append(subject)
                    else:
                        need_update = True

            if new_subjects:
                new_acl = dict(acl)
                new_acl["subjects"] = new_subjects
                new_acls.append(new_acl)

        if not need_update:
            print('[{}] {}: no need to update'.format(self.name, object_id))
            return

        print('[{}] {}: old ACLs:'.format(self.name, object_id))
        pprint(acls)
        print('[{}] {}: new ACLs:'.format(self.name, object_id))
        pprint(new_acls)

        if not self.dry_run:
            self._set_acl(transaction_id, type, object_id, new_acls)

    def check_permission(self, type, object_id, permission='write', transaction_timestamp=None):
        """Compare who holds ``permission`` on an object via legacy vs ``deploy`` groups.

        Only ACL entries granting ``permission`` are considered, and subjects
        with ignored prefixes are skipped. Remaining subjects are classified:

        * plain logins and ``abc:*`` groups are "old". Logins are taken
          directly; groups are expanded to their members.
        * ``deploy*`` groups are "new". Each member goes into the compact set;
          members containing ``:`` are expanded one more level, others are
          taken directly.

        Failures resolving a ``deploy`` group are printed and skipped.
        Failures resolving an ``abc`` group propagate.

        :returns: ``(lost, old_compact, new_compact)``:
            * ``lost`` is the set of expanded old subjects missing from the
              expanded new ones;
            * the compact sets hold subjects before the last expansion.
            All three are empty sets when the object has no ACL.
        """
        acls = self._get_acl(type, object_id, transaction_timestamp)

        if not acls:
            print('Empty ACLs')
            # Callers unpack three values; a bare set() made them crash.
            return set(), set(), set()

        new_subjects = set()
        new_subject_compact = set()
        old_subjects = set()
        old_subject_compact = set()
        for acl in acls:
            subjects = acl.get("subjects", [])
            permissions = acl.get("permissions", [])
            if permission not in permissions:
                continue
            for subject in subjects:
                if is_subject_ignored(subject):
                    continue

                if ':' not in subject:
                    old_subjects.add(subject)
                    old_subject_compact.add(subject)
                elif subject.startswith("abc:"):
                    old_subjects.update(
                        self.get_group(subject, transaction_timestamp)['spec'].get('members', [])
                    )
                    old_subject_compact.add(subject)
                elif subject.startswith("deploy"):
                    try:
                        for m in self.get_group(subject, transaction_timestamp)['spec'].get('members', []):
                            new_subject_compact.add(m)
                            if ':' in m:
                                new_subjects.update(
                                    self.get_group(m, transaction_timestamp)['spec'].get('members', [])
                                )
                            else:
                                new_subjects.add(m)
                    except Exception as e:
                        print("Error while resolving group {}: {}".format(subject, str(e)))
                else:
                    print('ERROR: Unexpected subject: {}'.format(subject))

        diff = old_subjects - new_subjects
        if diff:
            pprint('OLD: {}'.format(sorted(old_subjects)))
            pprint('NEW: {}'.format(sorted(new_subjects)))
        return diff, old_subject_compact, new_subject_compact

    def get_group(self, group_id, timestamp=None):
        """Fetch the whole group object (empty selector path) and return it yson-decoded."""
        print('{}: get group: {}'.format(self.name, group_id))
        req = object_service_pb2.TReqGetObject()
        if timestamp:
            req.timestamp = timestamp
        req.object_type = data_model.OT_GROUP
        req.object_id = group_id
        req.selector.paths.append("")
        resp = self.yp_client_stub.GetObject(req)

        return yson.loads(resp.result.values[0])

    def update_group(self, group_info):
        """Overwrite ``/spec``, ``/meta/acl`` and ``/labels`` of an existing group.

        Values come from ``group_info``. A ``YpAuthorizationError`` is
        printed, not raised.
        """
        print('{}: update group'.format(self.name))
        pprint(group_info)

        if self.dry_run:
            return

        # Overwrite spec, ACL and labels in a single UpdateObject call.
        req = object_service_pb2.TReqUpdateObject()
        req.object_type = data_model.OT_GROUP
        req.object_id = group_info['meta']['id']

        upd = req.set_updates.add()
        upd.path = "/spec"
        upd.value = yson.dumps(group_info['spec'])

        upd = req.set_updates.add()
        upd.path = "/meta/acl"
        upd.value = yson.dumps(group_info['meta']['acl'])

        upd = req.set_updates.add()
        upd.path = "/labels"
        upd.value = yson.dumps(group_info['labels'])

        try:
            self.yp_client_stub.UpdateObject(req)
            print('{}: updated {}'.format(self.name, group_info['meta']['id']))
        except YpAuthorizationError as e:
            print('{}: YpAuthorizationError {}'.format(self.name, e))

    def create_group(self, group_info):
        """Create a group with the id, ACL, spec and labels of ``group_info``.

        Other fields of ``group_info`` are not copied. A
        ``YpAuthorizationError`` is printed, not raised.
        """
        print('{}: create group:'.format(self.name))

        new_info = {
            'meta': {
                'id': group_info['meta']['id'],
                'acl': group_info['meta']['acl']
            },
            'spec': group_info['spec'],
            'labels': group_info['labels']
        }

        pprint(new_info)

        if self.dry_run:
            return

        req = object_service_pb2.TReqCreateObject(
            object_type=data_model.OT_GROUP,
            attributes=yson.dumps(new_info)
        )
        try:
            rsp = self.yp_client_stub.CreateObject(req)
            print('{}: created {}'.format(self.name, rsp.object_id))
        except YpAuthorizationError as e:
            print('{}: YpAuthorizationError {}'.format(self.name, e))

    def _get_acl(self, type, object_id, transaction_timestamp=None):
        """Return the yson-decoded ``/meta/acl`` of an object, optionally at a timestamp."""
        req = object_service_pb2.TReqGetObject()
        req.object_type = type
        req.object_id = object_id
        req.selector.paths.append("/meta/acl")
        if transaction_timestamp:
            req.timestamp = transaction_timestamp
        resp = self.yp_client_stub.GetObject(req)

        return yson.loads(resp.result.values[0])

    def _set_acl(self, transaction_id, type, object_id, new_acls):
        """Replace ``/meta/acl`` of an object within transaction ``transaction_id``."""
        print('{}: write new ACL'.format(self.name))
        req = object_service_pb2.TReqUpdateObject()
        req.object_type = type
        req.object_id = object_id
        req.transaction_id = transaction_id

        upd = req.set_updates.add()
        upd.path = "/meta/acl"
        upd.value = yson.dumps(new_acls)

        self.yp_client_stub.UpdateObject(req)


class YpTransaction(object):
    """Context manager: start a YP transaction on enter, commit it on a clean exit.

    ``yp_client`` must provide ``start_transaction()``, returning
    ``(transaction_id, start_timestamp)``, and ``commit_transaction(id)``.
    If the ``with`` body raises, nothing is committed and the exception
    propagates to the caller.

    NOTE(review): the client exposes no abort call, so an uncommitted
    transaction is simply abandoned — confirm YP expires it server-side.
    """

    def __init__(self, yp_client):
        self.yp_client = yp_client
        self.transaction_id = None
        self.start_timestamp = None

    def __enter__(self):
        self.transaction_id, self.start_timestamp = self.yp_client.start_transaction()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Commit only when the body succeeded; never commit partial work.
        if exc_type is None:
            self.yp_client.commit_transaction(self.transaction_id)
        # A truthy return value would swallow the body's exception.
        return False

    def get_id(self):
        """Return the id of the started transaction."""
        return self.transaction_id

    def get_timestamp(self):
        """Return the start timestamp of the transaction."""
        return self.start_timestamp


class IdmClientApi(object):
    """Client for the IDM rolerequests API of the IDM system bound to a YP cluster.

    The IDM system slug comes from ``CLUSTER_TO_IDM_SYSTEM``; constructing a
    client for a cluster missing there raises ``KeyError``. In ``dry_run``
    mode role requests are only printed.
    """

    # Seconds to wait for IDM before ``requests`` raises a timeout error.
    REQUEST_TIMEOUT = 30

    def __init__(self, cluster, idm_token, dry_run):
        self.name = cluster.upper()
        self.system = CLUSTER_TO_IDM_SYSTEM[cluster]
        self.token = idm_token
        self.dry_run = dry_run

    def add_role_for_group(self, role_path, group, object_id):
        """Request ``role_path`` for IDM group ``group``; return the request id (None in dry run)."""
        return self._add_role_subject(role_path, {"group": group}, group, object_id)

    def add_role_for_user(self, role_path, user, object_id):
        """Request ``role_path`` for login ``user``; return the request id (None in dry run)."""
        return self._add_role_subject(role_path, {"user": user}, user, object_id)

    def _add_role_subject(self, role_path, json_req, subject_id, object_id):
        """POST a role request in which ``json_req`` identifies the subject.

        ``subject_id`` and ``object_id`` are used only in log messages.

        :returns: the ``id`` from IDM's 201 response, or None in dry run.
        :raises Exception: if IDM does not answer 201.
        """
        print("{}: object {} add {} for '{}'".format(self.name, object_id, role_path, subject_id))
        if self.dry_run:
            return

        # Build a copy so the caller's dict is left untouched.
        payload = dict(json_req)
        payload["system"] = self.system
        payload["path"] = role_path
        rsp = requests.post(
            IDM_API_ROLEREQUESTS_PATH,
            json=payload,
            headers={
                "Authorization": "OAuth {}".format(self.token)
            },
            timeout=self.REQUEST_TIMEOUT)

        if rsp.status_code != 201:
            raise Exception("{}: Bad request for role {} adding, code {}, msg: {}".format(
                self.name, role_path, rsp.status_code, self._error_message(rsp)))

        print("{}: object {} added {} for '{}'".format(self.name, object_id, role_path, subject_id))
        return rsp.json()["id"]

    @staticmethod
    def _error_message(rsp):
        """Return the JSON ``message`` of an error response, or its raw body text.

        Never raises on malformed bodies, so the HTTP error itself is what
        gets reported.
        """
        try:
            body = rsp.json()
        except ValueError:
            return rsp.text
        if isinstance(body, dict) and "message" in body:
            return body["message"]
        return rsp.text


def continue_prompt():
    """Ask the user whether to go on and map the answer to a skip flag.

    :returns: False for ``y`` (process the next window), True for ``s``
        (skip it), None for any other answer (stop).
    """
    answer = raw_input("Continue? [(y)es | (n)o | (s)kip] ")
    return {'y': False, 's': True}.get(answer)


def check_cluster(yp_token, cluster, object_type, opt):
    """List objects of ``object_type`` on ``cluster`` and check each one's ACL via ``do``."""
    client = YpClientApi(consts.CLUSTER_CONFIGS[cluster], yp_token, opt.dry_run)

    rows = client.list_objects(object_type, opt.filter, opt.select_limit)
    # No extra selectors were requested, so each row starts with /meta/id.
    object_ids = [row[0] for row in rows]

    do(object_ids, [client], opt, do_check_permission)


def migrate_cluster(yp_token, cluster, opt):
    """Request IDM MAINTAINER roles for subjects only present in legacy ACLs.

    For every selected object of ``opt.type`` (``stage`` or ``project``), the
    write-permission subjects are compared via
    ``YpClientApi.check_permission``. Compact "old" subjects (direct logins
    and ``abc:`` groups) are processed as follows:
    * dismissed people are dropped;
    * ``abc:service:`` ids are converted to ``idm:`` ids;
    * subjects already covered by ``deploy`` groups are removed.

    The rest is aggregated per project and handed to ``do`` / ``do_add_role``.
    All YP reads use a single timestamp.

    :raises Exception: for any other ``opt.type``. Those objects carry no
        project id to build a role path from.
    """
    # Checked up front: with no /meta/project_id selector, fields[1] below
    # would raise IndexError for any other type.
    if opt.type not in ("stage", "project"):
        raise Exception("Migration is not supported for objects of type '{}'".format(opt.type))

    yp_client = YpClientApi(consts.CLUSTER_CONFIGS[cluster], yp_token, opt.dry_run)
    idm_token = opt.idm_token or get_token(IDM_TOKEN_ENV, IDM_TOKEN_FILE)
    # The IDM OAuth token is reused for Staff API lookups.
    staff_token = idm_token
    idm_client = IdmClientApi(cluster, idm_token, opt.dry_run)

    object_type = TYPE_MAP[opt.type]
    is_stage = opt.type == "stage"
    extra_selectors = ["/meta/project_id"] if is_stage else []

    timestamp = yp_client.generate_timestamp()
    objects = yp_client.list_objects(object_type, opt.filter, opt.select_limit, extra_selectors, timestamp)

    project_id_to_subjects = collections.defaultdict(set)
    project_id_to_object_id = {}
    for fields in objects:
        object_id = fields[0]
        print("Object '{}': ".format(object_id))
        _, old, new = yp_client.check_permission(object_type, object_id, transaction_timestamp=timestamp)
        old = [
            convert_abc_service_to_idm_id(subject, staff_token)
            for subject in old
            if not is_person_dismissed(subject, staff_token)
        ]
        subjects_to_migrate = set(old) - new
        if not subjects_to_migrate:
            continue

        project_id = fields[1] if is_stage else object_id
        # The last object seen for a project wins; it only labels log output.
        project_id_to_object_id[project_id] = object_id
        project_id_to_subjects[project_id] |= subjects_to_migrate

    objects_to_migrate = [
        {
            "object_id": project_id_to_object_id[project_id],
            "project_id": project_id,
            "subjects_to_migrate": project_id_to_subjects[project_id],
        }
        for project_id in sorted(project_id_to_subjects)
    ]

    do(objects_to_migrate, [idm_client], opt, do_add_role)


def do(objects, clients, opt, func):
    """Apply ``func(obj, clients, opt)`` to each object, asking for confirmation per window.

    The user is prompted before the first object and again after every
    window of objects.

    Window size:
    * ``opt.update_window`` if set;
    * otherwise, under ``--dry-run``, all objects at once;
    * otherwise ``DEFAULT_UPDATE_LIMIT``.

    Prompt answers:
    * ``y`` processes the next window;
    * ``s`` skips it;
    * anything else stops.

    ``func`` returns a falsy value on success, or an error description
    that is collected and printed in the final summary.
    """
    if not objects:
        print("No objects to update")
        return

    print("Selected objects [{}]: {} ".format(len(objects), objects))

    if opt.update_window:
        window = opt.update_window
    elif opt.dry_run:
        window = len(objects) + 1
    else:
        window = DEFAULT_UPDATE_LIMIT
    print('Update window: {}\n'.format(window))

    skip = continue_prompt()
    if skip is None:
        return

    success = 0
    skipped = 0
    errors = []
    step = window

    for obj in objects:
        if skip:
            print("Skipped {}".format(obj))
            skipped += 1
        else:
            err = func(obj, clients, opt)
            if not err:
                success += 1
            else:
                errors.append((obj, err))
            print("")
            time.sleep(opt.time_interval)

        step -= 1
        if not step:
            print("Processed {}/{}".format(skipped + success + len(errors), len(objects)))
            skip = continue_prompt()
            if skip is None:
                break
            # Reset to the computed window. opt.update_window is 0 when unset,
            # which would make step negative and suppress every later prompt.
            step = window

    print('-' * 50)
    print("Processed/Skipped/Total: {}/{}/{}".format(success, skipped, len(objects)))

    if errors:
        print("Errors:")
        for obj, error in errors:
            print("Object {}, error: '{}'".format(obj, error))

        print("")
    else:
        print("No errors")


def do_check_permission(obj, clients, opt):
    """Check ``opt.permission`` on object ``obj`` inside a YP transaction.

    ACLs are read at the transaction's start timestamp.

    :returns: ``''`` when no subject is lost (counted as success by ``do``);
        otherwise a report of the old, new and lost subjects. On an
        exception, returns the exception text.
    """
    yp_client = clients[0]
    with YpTransaction(yp_client) as transaction:
        try:
            object_type = TYPE_MAP[opt.type]
            print('{}:'.format(obj))
            lost, old, new = yp_client.check_permission(
                object_type, obj, opt.permission, transaction.get_timestamp())
        except Exception as e:
            print("Error in object {}: '{}'".format(obj, str(e)))
            return str(e)

    if not lost:
        return ''
    return '\nold: {}\nnew:{}\ndiff:{}\n'.format(sorted(old), sorted(new), sorted(lost))


def do_add_role(obj, idm_clients, opt):
    """Request the project MAINTAINER role for every subject in ``obj``.

    ``idm:<n>`` subjects are requested as IDM group ``n`` (an int). Any other
    subject is requested as a user login. A failure for one subject is
    printed and does not stop the others. Subjects are processed in sorted
    order for reproducible logs.

    :param obj: dict with ``project_id``, ``object_id`` and
        ``subjects_to_migrate``.
    :param idm_clients: list whose first element is the IDM client to use.
    :returns: None if every request succeeded; otherwise one line per failed
        subject, so ``do`` reports the object as an error instead of a success.
    """
    idm_client = idm_clients[0]
    role_path = "/{}/MAINTAINER/".format(obj['project_id'])
    errors = []
    for subject in sorted(obj['subjects_to_migrate']):
        try:
            if subject.startswith('idm:'):
                group = int(subject.split(':')[1])
                idm_client.add_role_for_group(role_path, group, obj['object_id'])
            else:
                idm_client.add_role_for_user(role_path, subject, obj['object_id'])
        except Exception as e:
            print("Error in object {}: {}".format(obj['object_id'], str(e)))
            errors.append("{}: {}".format(subject, e))
    return '\n'.join(errors) if errors else None


def main(arguments):
    """Run the ``check`` or ``migrate`` command on every requested cluster.

    Errors are reported per cluster, so one failing cluster does not stop
    the rest.
    """
    if arguments.dry_run:
        print('DRY RUN MODE IS ON !!!')

    command = arguments.command
    yp_token = arguments.yp_token or get_token(YP_TOKEN_ENV, YP_TOKEN_FILE)

    if command == 'check':
        object_type = TYPE_MAP[arguments.type]
        title = "Check ACLs in cluster {}"

        def run(cluster_name):
            check_cluster(yp_token, cluster_name, object_type, arguments)
    elif command == 'migrate':
        title = "Migrate ACLs in cluster {}"

        def run(cluster_name):
            migrate_cluster(yp_token, cluster_name, arguments)
    else:
        print('Unknown command!')
        return

    for cluster in arguments.clusters:
        try:
            print("-" * 50)
            print(title.format(cluster))
            run(cluster)
        except Exception as e:
            print("Error in cluster {}: '{}'".format(cluster, str(e)))


def parse_arguments():
    """Build the CLI parser and parse ``sys.argv``.

    Global options come first, followed by a ``check`` or ``migrate``
    subcommand. Each subcommand takes an object type and one or more
    cluster names.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-f", "--filter",
        dest="filter",
        default=None,
        help="YP filter query for selecting objects."
    )
    parser.add_argument(
        "--select-limit",
        dest="select_limit",
        type=int,
        default=SELECT_LIMIT,
        help="maximum number of objects to select."
    )
    parser.add_argument(
        "--time-interval",
        dest="time_interval",
        type=int,
        default=0,
        help="Interval in seconds between updates."
    )
    parser.add_argument(
        "--dry-run",
        dest="dry_run",
        action="store_true",
        default=False,
        help="print selected objects, do not update anything."
    )
    parser.add_argument(
        "--yp-token",
        dest="yp_token",
        default=None,
        help="use specified YP token."
    )
    parser.add_argument(
        "--idm-token",
        dest="idm_token",
        default=None,
        help="use specified IDM token."
    )
    parser.add_argument(
        "--update-window",
        dest="update_window",
        type=int,
        default=0,
        help="Update window size. Next step has to be approved by user."
    )

    subparsers = parser.add_subparsers(help='List of commands', dest='command')

    check_parser = subparsers.add_parser('check', help='List ACLs')
    check_parser.add_argument(
        "--permission",
        metavar='permission',
        # Must be a list: with a plain string argparse does a substring test,
        # so values like "rea" or "write, read" were accepted.
        choices=['write', 'read', 'create', 'ssh_access', 'root_ssh_access', 'read_secrets', 'use'],
        default='write',
        help="permission to check."
    )
    check_parser.add_argument(
        "type",
        metavar='type',
        choices=TYPE_MAP.keys(),
        help="type of objects for checking."
    )
    check_parser.add_argument(
        "clusters",
        metavar='cluster',
        nargs="+",
        choices=get_cluster_list(),
        help="clusters for checking."
    )

    migrate_parser = subparsers.add_parser('migrate', help='Migrate ACLs')
    migrate_parser.add_argument(
        "type",
        metavar='type',
        choices=TYPE_MAP.keys(),
        help="type of objects for parsing."
    )
    migrate_parser.add_argument(
        "clusters",
        metavar='cluster',
        nargs="+",
        choices=get_cluster_list(),
        help="clusters for parsing."
    )

    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: parse CLI arguments and dispatch to main().
    main(parse_arguments())
