import argparse
import json
import logging
import retry
import socket
import sys
import time

import dns.exception
import dns.message
import dns.name
import dns.query
import dns.rdatatype

from color import colored
from tqdm import tqdm


# Layout of every log line; handed to logging.basicConfig by init_logger.
LOGGING_FORMAT = '[%(asctime)s] [%(levelname)-5s] %(message)s'

# Seconds by which shoot() wakes up ahead of a message's scheduled send time.
EPS = 1e-2

# Module logger; records propagate to the root logger configured in init_logger.
logger = logging.getLogger('check_zone')


def get_messages_from_awacs_nameserver_config(input_file):
    """Build AAAA queries (with expected answers) from an awacs config dump.

    The file is read as JSON; the keys used below imply an object with a
    ``zone`` string and a ``records`` list (schema inferred from this code,
    not validated). Only records whose ``type`` is ``'ADDRESS'`` produce a
    query.

    Returns:
        list of ``(idx, query, expected_addresses)`` tuples where ``idx`` is
        the record's position in ``records`` (used later as the send offset
        in seconds by ``shoot``), ``query`` is an ``IN AAAA`` query for
        ``<record zone>.<config zone>`` and ``expected_addresses`` is the set
        of ``ipv6Addrs[*]['value']`` strings.
    """
    logger.info('Parsing json...')
    with open(input_file, 'r') as config_file:
        config = json.load(config_file)
    logger.info('Done')

    base_zone = config['zone']
    all_records = config['records']
    progress = tqdm(enumerate(all_records), total=len(all_records),
                    desc='Reading records json', ncols=100)

    result = []
    for position, entry in progress:
        if entry['type'] == 'ADDRESS':
            address = entry['address']
            name = '.'.join([address['zone'], base_zone])
            wanted = {item['value'] for item in address['ipv6Addrs']}
            result.append((position, dns.message.make_query(name, 'AAAA', 'IN'), wanted))
    return result


def create_messages(input_file, file_format):
    """Load DNS queries and their expected answers from *input_file*.

    Args:
        input_file: path to the input file.
        file_format: name of the input format; only ``'awacs'`` is supported.

    Returns:
        list of ``(ts, query, expected_addresses)`` tuples.

    Raises:
        ValueError: if *file_format* is not a supported format.
    """
    if file_format == 'awacs':
        return get_messages_from_awacs_nameserver_config(input_file)
    # Without this, an unknown format surfaced as an UnboundLocalError.
    raise ValueError('Unsupported input format: {!r}'.format(file_format))


def filter_messages(messages, zone):
    """Keep only queries whose question name lies within *zone*.

    Args:
        messages: iterable of ``(ts, query, expected)`` triples.
        zone: zone name as text, or ``None`` to keep every message.

    Returns:
        a new list of ``[ts, query, expected]`` lists. Lists (not tuples) are
        produced so ``set_duration`` can rewrite the timestamp in place.
    """
    # Parse the zone once rather than once per message.
    zone_name = None if zone is None else dns.name.from_text(zone)

    result = []
    for ts, message, resp in messages:
        if zone_name is not None and not message.question[0].name.is_subdomain(zone_name):
            continue
        result.append([ts, message, resp])
    return result


def set_duration(messages, duration):
    """Reschedule *messages* evenly over *duration* seconds, in place.

    The first message keeps its timestamp; message ``i`` is moved to
    ``first_ts + duration / len(messages) * i``. Every element must be a
    mutable sequence whose item 0 is the timestamp. The same list is returned.
    """
    if not messages:
        return messages
    origin = messages[0][0]
    step = duration / float(len(messages))
    for position, entry in enumerate(messages):
        entry[0] = origin + step * position
    return messages


def create_validation_config(args):
    """Pick the answer-validation settings out of parsed CLI *args*.

    Returns a dict with ``max_records`` (from ``--max-records-expected``) and
    ``compare_mode`` (from ``--compare-mode``).
    """
    config = dict(
        max_records=args.max_records_expected,
        compare_mode=args.compare_mode,
    )
    return config


def get_address(backend):
    """Resolve a backend spec to an ``(ip, port)`` pair.

    Accepted forms (port defaults to 53)::

        host            host:port
        [ipv6]          [ipv6]:port

    An unbracketed spec containing more than one ``:`` is rejected, because
    an IPv6 literal cannot be told apart from ``ipv6:port`` without brackets.

    Returns:
        ``(ip, port)``: ``ip`` is the address of the first entry returned by
        ``socket.getaddrinfo`` for the host; ``port`` is an int.

    Raises:
        ValueError: malformed/ambiguous spec or non-numeric port.
        socket.gaierror: the host cannot be resolved.
    """
    if backend.startswith('['):
        host, closed, rest = backend[1:].partition(']')
        if not closed or (rest and not rest.startswith(':')):
            raise ValueError('Malformed backend address: {!r}'.format(backend))
        port = int(rest[1:]) if rest else 53
    elif backend.count(':') > 1:
        raise ValueError(
            'Ambiguous backend address {!r}: wrap IPv6 literals in brackets, '
            'e.g. [2001:db8::1]:53'.format(backend))
    else:
        # The original code assigned the whole split list to host here when
        # no port was given, which made getaddrinfo fail.
        host, has_port, port_text = backend.partition(':')
        port = int(port_text) if has_port else 53
    return socket.getaddrinfo(host, port)[0][4][0], port


def send_udp(message, where, port=53):
    """Send *message* to ``where:port`` over UDP and return the answer.

    Uses a fixed 1-second timeout.
    """
    # NOTE(review): a truncated (TC) UDP answer is presumably returned as-is,
    # so very large RRsets may need --protocol tcp — confirm for the dnspython
    # version in use.
    udp_timeout = 1
    return dns.query.udp(message, where, timeout=udp_timeout, port=port)


def send_tcp(message, where, port=53):
    """Send *message* to ``where:port`` over TCP and return the answer.

    Uses a fixed 1-second timeout.
    """
    tcp_timeout = 1
    return dns.query.tcp(message, where, timeout=tcp_timeout, port=port)


@retry.retry(tries=3, delay=0.5, backoff=2)
def _send_with_retries(sender, message, ip, port):
    """Call ``sender(message, ip, port)``, retried by the ``retry`` decorator
    (3 tries, 0.5s initial delay, backoff factor 2)."""
    return sender(message, ip, port)


def send(message, ip, port, protocol):
    """Send *message* to ``ip:port`` over *protocol* and return the answer.

    Args:
        message: the DNS query message.
        ip: server IP address.
        port: server port.
        protocol: ``'udp'`` or ``'tcp'``.

    Returns:
        the response message from ``send_udp`` / ``send_tcp``. Transport
        failures are retried (see ``_send_with_retries``) before propagating.

    Raises:
        ValueError: for any other *protocol*. This is checked before the
            retry wrapper so a programming error fails immediately instead
            of being retried, and instead of the old silent ``None`` return.
    """
    senders = {'udp': send_udp, 'tcp': send_tcp}
    sender = senders.get(protocol)
    if sender is None:
        raise ValueError('Unsupported protocol: {!r}'.format(protocol))
    return _send_with_retries(sender, message, ip, port)


def shoot(backend_addr, messages, protocol):
    """Send every query to *backend_addr*, paced by each message's timestamp.

    Args:
        backend_addr: ``(ip, port)`` pair, as returned by ``get_address``.
        messages: sequence of ``(ts, query, expected)`` entries; ``ts`` is the
            send offset in seconds from the moment shooting starts.
        protocol: ``'udp'`` or ``'tcp'``.

    Returns:
        list with one entry per message, in order: the response message, or
        ``None`` if that query failed (timeout, bad or unexpected response,
        socket error). Each failure is logged as a warning and the run
        continues with the next query.
    """
    ip, port = backend_addr  # loop-invariant: unpack once
    responses = []
    start_time = time.time()
    for ts, message, _ in tqdm(messages, total=len(messages), desc='shooting', ncols=100):
        # Sleep until EPS seconds before the scheduled send time (presumably to
        # absorb sleep/scheduling overshoot).
        deadline = start_time + ts
        current_time = time.time()
        if current_time + EPS < deadline:
            time.sleep(deadline - current_time - EPS)

        response = None
        try:
            response = send(message, ip, port, protocol)
        except dns.exception.Timeout:
            logger.warning('Query timed out. Query:\n%s', message.to_text())
        except dns.query.BadResponse as e:
            logger.warning('Bad response %s. Query:\n%s', e, message.to_text())
        except dns.query.UnexpectedSource as e:
            logger.warning('Invalid source %s. Query:\n%s', e, message.to_text())
        except socket.error as e:
            # e.g. TCP connection refused/reset: previously this aborted the
            # whole run; treat it like the other per-query failures.
            logger.warning('Socket error %s. Query:\n%s', e, message.to_text())

        responses.append(response)

    return responses


def add_indent(lines, indent=0):
    """Lazily prefix each item of *lines* with *indent* spaces.

    Items are rendered via ``str.format``, so non-string items become text.
    Returns a one-shot iterator, not a list.
    """
    padding = ' ' * indent
    return ('{}{}'.format(padding, line) for line in lines)


def print_message(message, indent=0):
    """Return ``message.to_text()`` with every line indented by *indent* spaces.

    Despite the name, nothing is printed: the indented text is returned.
    """
    text = message.to_text()
    indented = add_indent(text.split('\n'), indent=indent)
    return '\n'.join(indented)


def _validation_errors(query, response, expected, config):
    """Return the problems found in one response; empty list means it passed."""
    if response is None:
        # shoot() stores None when the query failed; there is no answer to
        # compare (previously this crashed on response.get_rrset).
        return ['No response received']

    question = query.question[0]
    rrset = response.get_rrset(dns.message.ANSWER, question.name, question.rdclass, question.rdtype)
    # get_rrset() returns None when ANSWER has no matching RRset (e.g.
    # NXDOMAIN / NODATA); treat that as an empty answer instead of crashing.
    response_addresses = set(map(str, rrset)) if rrset is not None else set()

    expected_records_number = len(expected)
    if config['max_records']:
        # The answer is expected to hold at most max_records records.
        # NOTE(review): with max_records < len(expected), 'equal' mode can
        # never pass; 'subset' is presumably intended for that case.
        expected_records_number = min(expected_records_number, config['max_records'])

    errors = []
    if len(response_addresses) != expected_records_number:
        errors.append('Invalid number of records in ANSWER: found {}, expected {}'.format(
            len(response_addresses), expected_records_number))

    compare_mode = config['compare_mode']
    if compare_mode == 'equal' and response_addresses != expected:
        errors.append('Invalid set of records in ANSWER')
    elif compare_mode == 'subset' and not response_addresses.issubset(expected):
        errors.append('Invalid set of records in ANSWER')
    return errors


def _format_failure(query, response, expected, errors):
    """Render the colored, indented report logged for one failed query."""
    question = query.question[0]
    header = 'Validation failed on query ({}, {}).'.format(
        question.name, dns.rdatatype.to_text(question.rdtype))
    report = colored(header, 'red')
    report += ' Errors:\n{}'.format(colored('\n'.join(add_indent(errors, indent=2)), 'yellow'))
    report += '\n\n  Query:\n{}'.format(print_message(query, indent=4))
    if response is not None:
        response_text = print_message(response, indent=4)
    else:
        response_text = '    <no response>'
    report += '\n\n  Response:\n{}'.format(response_text)
    # Sorted so the expected set prints in a stable, readable order.
    report += '\n\n  Expected records:\n{}'.format('\n'.join(add_indent(sorted(expected), indent=4)))
    return report


def validate(messages, responses, config):
    """Check every response against its expected addresses; log failures.

    Args:
        messages: sequence of ``(ts, query, expected_addresses)`` entries.
        responses: sequence aligned with *messages*; each item is a response
            message or ``None`` for a failed query (as ``shoot`` produces).
        config: dict with ``max_records`` (falsy = no cap) and
            ``compare_mode`` (``'equal'`` or ``'subset'``).

    Returns:
        the number of queries that failed validation (0 if all passed).

    Raises:
        ValueError: if *messages* and *responses* differ in length.
    """
    if len(messages) != len(responses):
        raise ValueError('Got {} responses for {} queries'.format(len(responses), len(messages)))

    failures = 0
    pairs = zip(messages, responses)
    for message, response in tqdm(pairs, total=len(messages), desc='Validating responses', ncols=100):
        _, query, expected = message
        errors = _validation_errors(query, response, expected, config)
        if not errors:
            continue
        failures += 1
        logger.critical(_format_failure(query, response, expected, errors))
    return failures


def check(ns, backend, messages, protocol, validation_config):
    """Resolve *backend*, send all *messages* to it and validate the answers.

    Args:
        ns: nameserver from the command line. Not used here (only *backend*
            is queried); kept for interface compatibility.
        backend: ``host[:port]`` spec resolved with ``get_address``.
        messages: ``(ts, query, expected)`` entries to send.
        protocol: ``'udp'`` or ``'tcp'``.
        validation_config: dict from ``create_validation_config``.

    Returns:
        whatever ``validate`` returns, so callers can act on the outcome.
    """
    backend_addr = get_address(backend)
    responses = shoot(backend_addr, messages, protocol)
    return validate(messages, responses, validation_config)


def init_logger(args):
    """Configure root logging with LOGGING_FORMAT.

    The level is DEBUG when ``args.verbose`` is set, otherwise INFO.
    """
    if args.verbose:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logging.basicConfig(format=LOGGING_FORMAT, level=level)


def parse_args(argv):
    """Parse command-line arguments.

    Args:
        argv: argument list without the program name (e.g. ``sys.argv[1:]``).
            ``None`` makes argparse fall back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace. When ``--ns-backend`` is omitted it defaults to
        ``<ns>:53``, or ``[<ns>]:53`` when *ns* contains ``:`` (an IPv6
        literal), so the host/port split stays unambiguous.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input', required=True,
                        help='Path to input file (parsed according to --format)')
    parser.add_argument('--format', choices=['awacs'], default='awacs',
                        help='Input file format')
    parser.add_argument('--ns', required=True,
                        help='Nameserver')
    parser.add_argument('--ns-backend',
                        help='Backend for nameserver')
    parser.add_argument('--zone',
                        help='Filter queries by zone if specified')
    parser.add_argument('--protocol', choices=['udp', 'tcp'], default='udp',
                        help='Protocol to use')
    parser.add_argument('--duration', type=float,
                        help='Seconds to shoot')
    parser.add_argument('--compare-mode', choices=['equal', 'subset'], default='equal',
                        help='Mode for checking records in answers with actual')
    parser.add_argument('--max-records-expected', type=int, default=0,
                        help='Max number of records expected in answer (0 = no limit)')
    parser.add_argument('--verbose', action='store_true')
    # Honour the argv argument; previously it was ignored and sys.argv was read.
    args = parser.parse_args(argv)

    if not args.ns_backend:
        ns = args.ns
        if ':' in ns and not ns.startswith('['):
            ns = '[{}]'.format(ns)
        args.ns_backend = '{}:53'.format(ns)

    return args


def main(argv):
    """Run the zone check end to end.

    Args:
        argv: command-line arguments without the program name.

    Returns:
        process exit status: 1 when ``check`` returns a truthy value (i.e.
        reports validation failures), otherwise 0.
    """
    args = parse_args(argv)

    init_logger(args)

    messages = create_messages(args.input, args.format)
    messages = filter_messages(messages, args.zone)

    if not messages:
        logger.info('No queries left after filter')
        return 0

    if args.duration:
        messages = set_duration(messages, args.duration)

    validation_config = create_validation_config(args)

    failures = check(args.ns, args.ns_backend, messages, args.protocol, validation_config)
    if failures:
        # Non-zero exit so the check can gate scripts/CI.
        logger.error('Validation failed for %s queries', failures)
        return 1
    return 0


if __name__ == '__main__':
    # Use main()'s return value as the process exit status (None -> 0).
    sys.exit(main(sys.argv[1:]))
