import argparse
import json
import logging

from yp.client import YpClient, find_token


def parse_args():
    """Parse the command-line options of the ACL update script.

    Returns:
        argparse.Namespace: parsed options. ``users`` and ``exclude`` are
        normalised to sets (empty when the option was not given) so callers
        can use set arithmetic on them directly.
    """
    # The trailing space matters: adjacent string literals are concatenated
    # verbatim, so without it the help text reads "record setby filter".
    parser = argparse.ArgumentParser(
        description='Sets write permission for every record set '
                    'by filter for specified users')

    parser.add_argument('-c', '--cluster', required=True,
                        choices=['sas-test', 'sas', 'man-pre', 'man', 'vla', 'myt', 'iva', 'xdc'],
                        help='YP cluster name')
    parser.add_argument('-z', '--zone',
                        help='Select record sets labelled with this zone')
    parser.add_argument('-f', '--filter',
                        help='Additional filter')
    parser.add_argument('-u', '--users', nargs='+',
                        help='user names')
    parser.add_argument('-x', '--exclude', nargs='+',
                        help='user names to delete from acl')
    parser.add_argument('-d', '--dry-run', action='store_true',
                        help='Do not apply updates')
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='Verbose output')
    args = parser.parse_args()

    # nargs='+' yields None when the option is absent; collapse to empty sets.
    args.users = set(args.users or [])
    args.exclude = set(args.exclude or [])

    return args


def _build_filter(zone, extra_filter):
    """Compose the object selection filter from the zone and the extra filter.

    Returns an empty string when neither part is provided.
    """
    query = ''
    if zone:
        query = '[/labels/zone] = "{}"'.format(zone)
    if extra_filter:
        # With no zone restriction the extra filter is and-ed with "%true".
        query = '({}) and ({})'.format(query or '%true', extra_filter)
    return query


def _patch_acl(acl, users, exclude):
    """Apply the requested subject changes to one ACL.

    Args:
        acl: list of ACL entry dicts; entries are modified in place.
        users: set of subjects that must end up with write permission.
        exclude: set of subjects to drop from allow/write entries.

    Returns:
        (new_acl, changed): the resulting entry list with subject-less
        entries removed, and whether it differs from the input.
    """
    changed = False
    writers = set()
    for entry in acl:
        entry.setdefault('subjects', [])
        # Only entries that grant write access are inspected or edited.
        if entry.get('action') != 'allow':
            continue
        if 'write' not in entry.get('permissions', ()):
            continue
        if exclude.intersection(entry['subjects']):
            # Keep the remaining subjects in their original order.
            entry['subjects'] = [s for s in entry['subjects'] if s not in exclude]
            changed = True
        writers.update(entry['subjects'])

    missing = users - writers
    if missing:
        acl.append({
            'action': 'allow',
            'permissions': [
                'read',
                'write',
            ],
            # Sorted so that repeated runs produce identical requests.
            'subjects': sorted(missing),
        })
        changed = True

    # Entries left without any subject are dropped from the ACL.
    kept = [entry for entry in acl if entry['subjects']]
    if len(kept) != len(acl):
        changed = True

    return kept, changed


def main():
    """Select DNS record sets by filter and rewrite their write ACLs.

    Pages through all ``dns_record_set`` objects matching the zone/filter
    options, computes the ACL changes for ``--users`` / ``--exclude``, and
    sends the updates in batches unless ``--dry-run`` is given.

    Raises:
        ValueError: if neither ``--zone`` nor ``--filter`` was supplied.
    """
    args = parse_args()
    logging.basicConfig(level=(logging.DEBUG if args.verbose else logging.INFO))

    selection_filter = _build_filter(args.zone, args.filter)
    # An explicit check rather than ``assert``: asserts are stripped under
    # ``python -O`` and this guard prevents running with an empty filter.
    if not selection_filter:
        raise ValueError('Filter must be set')

    yp_client = YpClient(address='{}.yp.yandex.net:8090'.format(args.cluster), config={'token': find_token()})

    page_limit = 10000
    debug_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)

    continuation_token = None
    update_requests = []
    while True:
        # NOTE(review): a fresh timestamp is taken for every page, as in the
        # original code; confirm whether one timestamp per scan is preferable.
        timestamp = yp_client.generate_timestamp()

        result = yp_client.select_objects('dns_record_set', filter=selection_filter,
                                          selectors=['/meta/id', '/meta/acl'], timestamp=timestamp,
                                          limit=page_limit,
                                          options={'continuation_token': continuation_token},
                                          enable_structured_response=True)
        continuation_token = result['continuation_token']

        for record_id, acl in result['results']:
            new_acl, changed = _patch_acl(acl['value'], args.users, args.exclude)
            if not changed:
                continue

            request = {
                'object_type': 'dns_record_set',
                'object_id': record_id['value'],
                'set_updates': [
                    {
                        'path': '/meta/acl',
                        'value': new_acl,
                    },
                ],
            }
            update_requests.append(request)
            # Serialising every request is costly; only do it when it is shown.
            if debug_enabled:
                logging.debug('add request: %s', json.dumps(request, indent=2))

        # A short page means there is nothing left to fetch.
        if len(result['results']) < page_limit:
            break

    logging.info('total updates: {}'.format(len(update_requests)))

    batch_size = 10000
    for start in range(0, len(update_requests), batch_size):
        # Slicing clamps to the list length, so no explicit min() is needed.
        reqs = update_requests[start:start + batch_size]
        logging.info('record sets to update: {}'.format(len(reqs)))
        if not args.dry_run:
            yp_client.update_objects(reqs)
        else:
            logging.info('skip (dry run)')

# Run the tool only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
