import json
import argparse
import logging
import sys

from yp.client import YpClient, find_token


# A single stderr handler with a timestamped format is attached to the root
# logger; the level itself is chosen later in main() from --verbose.
sh = logging.StreamHandler()
sh.setFormatter(logging.Formatter(fmt='[%(asctime)s] [%(levelname)s] %(message)s'))
logger = logging.getLogger()
logger.addHandler(sh)


def count_number_of_zone_record_sets_in_cluster(zone, cluster, group_by_exprs, additional_filter):
    """Count `dns_record_set` objects of ``zone`` in one YP cluster.

    A record set belongs to the zone when its ``/meta/id`` equals the zone
    name or ends with ``.<zone>``, optionally followed by a trailing dot.

    Args:
        zone: Zone name (the CLI strips a trailing dot before calling this).
        cluster: YP cluster name passed to ``YpClient``.
        group_by_exprs: Extra YP query expressions to group by; each gets its
            own per-value breakdown in the result.
        additional_filter: Optional YP filter expression AND-ed with the zone
            filter; a falsy value means no extra filtering.

    Returns:
        A dict with ``'total'`` -> number of matching record sets (the key is
        absent when nothing matched) and, for each group-by expression,
        ``expr -> {value: count}``.
    """
    yp_client = YpClient(cluster, config={'token': find_token()})

    # Dots in the zone name must be escaped, otherwise e.g. "a.b" would also
    # match "aXb". This uses the same "\." escaping as the rest of the pattern.
    escaped_zone = zone.replace('.', '\\.')
    match_zone_filter = f'regex_partial_match("(^|\\.){escaped_zone}\\.?$", [/meta/id]) = %true'
    filter_expr = match_zone_filter
    if additional_filter:
        filter_expr = f'({filter_expr}) and ({additional_filter})'
    aggregate_result = yp_client.aggregate_objects(
        'dns_record_set',
        # Row layout: [zone_match, *group_by_values, count].
        group_by=[match_zone_filter] + group_by_exprs,
        aggregators=['sum(1u)'],
        filter=filter_expr,
    )
    logger.debug('Aggregate objects result:\n%s', aggregate_result)

    count_by_group = {}
    for group in aggregate_result:
        zone_matched, group_values, record_set_count = group[0], group[1:-1], group[-1]
        if zone_matched:
            count_by_group['total'] = count_by_group.get('total', 0) + record_set_count
        for expr, value in zip(group_by_exprs, group_values):
            # NOTE(review): falsy values (null, empty, false, 0) are skipped,
            # so e.g. a boolean group-by only reports its true bucket.
            if value:
                by_value = count_by_group.setdefault(expr, {})
                by_value[value] = by_value.get(value, 0) + record_set_count
    return count_by_group


def count(args):
    """Query every cluster in ``args.clusters`` and log its record-set counts.

    Results are logged at INFO level as indented JSON, one message per
    cluster, in the order the clusters were given.
    """
    for cluster_name in args.clusters:
        counts = count_number_of_zone_record_sets_in_cluster(
            args.zone, cluster_name, args.group_by, args.filter,
        )
        pretty = json.dumps(counts, indent=2)
        logger.info(f'Record sets of {args.zone} in YP-{cluster_name.upper()}:\n{pretty}')


def parse_args(argv):
    """Parse command-line arguments for the record-set counter.

    Args:
        argv: Argument list without the program name (e.g. ``sys.argv[1:]``).
            ``None`` makes argparse fall back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace where ``zone`` has any trailing dot removed,
        ``clusters`` is a list of non-empty, whitespace-stripped names and
        ``group_by`` is a (possibly empty) list of expressions.

    Raises:
        SystemExit: on invalid arguments or an empty cluster list (argparse).
    """
    parser = argparse.ArgumentParser(description='Counts number of record sets by zone in YP')

    parser.add_argument('-c', '--clusters',
                        default=','.join(['sas-test', 'sas', 'man-pre', 'man', 'vla', 'myt', 'iva', 'xdc']),
                        help='Comma-separated YP cluster names')
    parser.add_argument('-z', '--zone', required=True,
                        help='Zone name')
    parser.add_argument('-g', '--group-by', action='append',
                        help='Additional group by expression')
    parser.add_argument('-f', '--filter',
                        help='Additional filter')
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='Verbose output')
    # Parse the list we were given rather than implicitly re-reading sys.argv.
    args = parser.parse_args(argv)
    args.zone = args.zone.removesuffix('.')
    # Tolerate "sas, man" and stray commas such as "sas,,man".
    args.clusters = [name.strip() for name in args.clusters.split(',') if name.strip()]
    if not args.clusters:
        parser.error('at least one cluster name is required')
    args.group_by = args.group_by or []
    return args


def main(argv):
    """Parse ``argv``, pick DEBUG or INFO logging from --verbose, then count."""
    parsed = parse_args(argv)
    logger.setLevel(logging.DEBUG if parsed.verbose else logging.INFO)
    count(parsed)


if __name__ == '__main__':
    # Drop the program name; main() expects only the user-supplied arguments.
    cli_args = sys.argv[1:]
    main(cli_args)
