import argparse
import json
import logging

from collections import defaultdict

from yp.client import YpClient, find_token

from library.python import resource

import dns.name


def in_zone(fqdn, zone):
    """Return whether *fqdn* lies inside *zone*.

    Delegates to ``fqdn.is_subdomain(zone)``. dnspython documents that a name
    counts as a subdomain of itself, so the zone apex also matches.
    """
    contained = fqdn.is_subdomain(zone)
    return contained


def batched_select_objects(yp_client, batch_size, object_type, selectors, filter=None):
    """Select all objects of *object_type*, paging through results.

    Every page is read at a single timestamp obtained from
    ``yp_client.generate_timestamp()``, so all pages see the same snapshot.
    Paging stops at the first page that returns fewer than *batch_size* rows.

    Args:
        yp_client: YP client used for ``generate_timestamp``/``select_objects``.
        batch_size: page size passed as ``limit``.
        object_type: YP object type to select.
        selectors: attribute paths to fetch for each object.
        filter: optional YP filter expression.

    Returns:
        A list with one tuple per object, holding the selected values in
        *selectors* order.
    """
    snapshot_ts = yp_client.generate_timestamp()
    token = None
    collected = []
    while True:
        page = yp_client.select_objects(
            object_type,
            selectors=selectors,
            timestamp=snapshot_ts,
            filter=filter,
            limit=batch_size,
            options={'continuation_token': token},
            enable_structured_response=True,
        )
        token = page['continuation_token']
        rows = page['results']
        # Structured responses wrap each selected field as {'value': ...}.
        collected.extend(tuple(field['value'] for field in row) for row in rows)
        if len(rows) < batch_size:
            return collected


def batched_update_objects(args, batch_size, update_requests):
    """Send *update_requests* to YP cluster ``args.cluster`` in batches.

    Args:
        args: parsed CLI namespace. ``args.cluster`` selects the YP cluster and
            ``args.dry_run`` suppresses sending (batches are only logged).
        batch_size: maximum number of requests per ``update_objects`` call;
            must be positive.
        update_requests: list of update request dicts.

    Raises:
        ValueError: if *batch_size* is not positive. Previously a negative
            value silently produced an empty range and skipped every update.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # A client (and therefore a token) is only needed when updates are sent.
    yp_client = None if args.dry_run else YpClient(args.cluster, config={'token': find_token()})
    for start in range(0, len(update_requests), batch_size):
        reqs = update_requests[start:start + batch_size]
        logging.info('Update batch of size {}'.format(len(reqs)))
        if args.dry_run:
            logging.info("Updates disabled due to dry run")
            continue
        yp_client.update_objects(reqs)


def get_dynamic_zones():
    """Return the set of ``/meta/id`` values of all ``dns_zone`` objects in YP 'xdc'.

    NOTE(review): the cluster is fixed to 'xdc' regardless of ``--cluster``;
    presumably dynamic zones are only stored there. Confirm.
    """
    client = YpClient('xdc', config={'token': find_token()})
    rows = batched_select_objects(
        client,
        batch_size=1000,
        object_type='dns_zone',
        selectors=['/meta/id'],
    )
    return {row[0] for row in rows}


def get_static_zones():
    """Return the set of zone names listed in the bundled bridge config.

    Reads the ``/proto_config/bridge_config.json`` resource and collects the
    ``"Name"`` field of every entry in its ``"ZoneConfigs"`` list.
    """
    resource_name = '/proto_config/bridge_config.json'
    config = json.loads(resource.find(resource_name))
    return {zone_config["Name"] for zone_config in config["ZoneConfigs"]}


def get_all_zones():
    """Return the union of static and dynamic zones as ``dns.name.Name`` objects."""
    zone_names = get_static_zones() | get_dynamic_zones()
    return {dns.name.from_text(zone_name) for zone_name in zone_names}


def get_record_sets(args):
    """Fetch record sets from ``args.cluster`` that have a ``/labels/zone`` set.

    Returns:
        A list of ``(id, fqdn, zone)`` triples. ``id`` is the raw
        ``/meta/id`` string. ``fqdn`` and ``zone`` are ``dns.name.Name``
        values parsed from the id and the zone label respectively.
    """
    client = YpClient(args.cluster, config={'token': find_token()})
    rows = batched_select_objects(
        client,
        batch_size=1000,
        object_type='dns_record_set',
        selectors=['/meta/id', '/labels/zone'],
        filter='[/labels/zone] != #',
    )
    triples = []
    for record_id, zone_label in rows:
        triples.append((record_id, dns.name.from_text(record_id), dns.name.from_text(zone_label)))
    return triples


def find_matching_zone(fqdn, zones):
    """Return the single zone in *zones* that contains *fqdn*, or None.

    Raises:
        Exception: if *fqdn* lies inside more than one zone of *zones*.

    NOTE(review): nested zones (a zone and one of its child zones) both match,
    so this raises instead of choosing the closest enclosing zone. Confirm
    this is the intended policy.
    """
    found = None
    for candidate in zones:
        if not in_zone(fqdn, candidate):
            continue
        if found is not None:
            raise Exception(f"{fqdn}: matched several zones: {found}, {candidate}")
        found = candidate
    return found


def parse_args():
    """Parse command-line options for the zone-label migration.

    Options:
        -z/--zone: optional zone name.
        -c/--cluster: required YP cluster, one of a fixed list.
        -d/--dry-run: log planned updates without sending them.
    """
    clusters = ['sas-test', 'sas', 'man-pre', 'man', 'vla', 'myt', 'iva', 'xdc']
    p = argparse.ArgumentParser(description="Changes all record sets' /labels/zone to existing zones")
    p.add_argument('-z', '--zone', required=False)
    p.add_argument('-c', '--cluster', required=True, choices=clusters, help='YP cluster name')
    p.add_argument('-d', '--dry-run', action='store_true')
    return p.parse_args()


def main():
    """Rewrite ``/labels/zone`` of record sets whose label is not a known zone.

    Known zones are the static zones from the bundled bridge config plus the
    dynamic ``dns_zone`` objects (see ``get_all_zones``). For every record set
    on ``--cluster`` whose zone label is not a known zone, the known zone
    containing the record set's FQDN becomes the new label. Record sets with
    no matching known zone are left untouched.

    ``--zone`` restricts the run to record sets being moved into that zone.
    ``--dry-run`` logs the planned updates without sending them.
    """
    args = parse_args()

    logging.basicConfig(level=logging.INFO, format='[%(asctime)s] [%(levelname)-5s] %(message)s')

    zones = get_all_zones()
    logging.info(f"Zones found: {len(zones)}")

    # Honour -z/--zone as a filter on the target zone. It was previously
    # parsed but ignored, so record sets in every zone were rewritten.
    only_zone = dns.name.from_text(args.zone) if args.zone else None
    if only_zone is not None and only_zone not in zones:
        logging.warning(f"Zone {only_zone} is not a known zone; nothing will be updated")

    record_sets = get_record_sets(args)
    logging.info(f"Record sets with zone label found: {len(record_sets)}")

    updates_by_zone = {}
    # target zone -> {stale zone label -> number of record sets carrying it}
    subzones_by_zone = defaultdict(lambda: defaultdict(int))
    for record_set_id, fqdn, zone in record_sets:
        if zone in zones:
            continue

        # Raises if fqdn lies inside more than one known zone.
        actual_zone = find_matching_zone(fqdn, zones)
        if actual_zone is None:
            logging.debug(f"{fqdn} {zone} no zone found")
            continue
        if only_zone is not None and actual_zone != only_zone:
            continue

        logging.info(f"{fqdn}: change {zone} -> {actual_zone}")
        subzones_by_zone[actual_zone][zone] += 1
        updates_by_zone.setdefault(actual_zone, []).append({
            'object_type': 'dns_record_set',
            'object_id': record_set_id,
            'set_updates': [
                {
                    'path': '/labels/zone',
                    'value': actual_zone.to_text(omit_final_dot=True),
                },
            ],
        })
        logging.debug(json.dumps(updates_by_zone[actual_zone][-1], indent=2))

    for zone, subzones in subzones_by_zone.items():
        logging.info(f"Aggregate zones below to {zone}")
        for subzone, record_sets_num in subzones.items():
            logging.info(f"  -> {subzone}: {record_sets_num}")

    for zone, updates in updates_by_zone.items():
        logging.info(f"Updates with zone {zone}: {len(updates)}")

    for updates in updates_by_zone.values():
        batched_update_objects(args, batch_size=1000, update_requests=updates)


if __name__ == '__main__':
    main()
