import argparse
import datetime
import dpkt
import itertools
import json
import logging
import os
import re
import requests
import retry
import socket
import sys
import time

from multiprocessing.dummy import Pool as ThreadPool

import dns.exception
import dns.message
import dns.name
import dns.query
import dns.rdatatype

from logging.handlers import RotatingFileHandler

from color import colored
from tqdm import tqdm

from library.python.monlib.metric_registry import MetricRegistry
from library.python.monlib.encoder import dumps


# Record layout shared by the root logging configuration and the optional
# rotating file handler (see init_loggers).
LOGGING_FORMAT = '[%(asctime)s] [%(levelname)-5s] %(message)s'

# Pacing constants used by shoot(): a computed wait shorter than MIN_SLEEP_TIME
# seconds is skipped, and EPS seconds are subtracted from every sleep that is
# actually performed.
MIN_SLEEP_TIME = 0.01
EPS = 0.01

# Names of the parts a response text is split into before comparison.
# ATTRIBUTES collects the lines that precede the first section header.
ATTRIBUTES = 'ATTRIBUTES'
QUESTION = 'QUESTION'
ANSWER = 'ANSWER'
AUTHORITY = 'AUTHORITY'
ADDITIONAL = 'ADDITIONAL'

# Every section name; used to initialise per-section containers and counters
# and to recognise section header lines in the textual response.
SECTIONS = [
    ATTRIBUTES,
    QUESTION,
    ANSWER,
    AUTHORITY,
    ADDITIONAL,
]

# Endpoint push_stats() posts the collected metrics to.
# NOTE(review): the request carries an OAuth token in its headers while this
# URL is plain HTTP - confirm whether an HTTPS endpoint should be used.
SOLOMON_URL = 'http://solomon.yandex.net/push/json'


# Module-wide logger; init_loggers() optionally attaches a rotating file handler.
logger = logging.getLogger('diff_ns_servers')


def parse_args(argv):
    """Parse the command-line options of the nameserver diff tool.

    Args:
        argv: Argument list without the program name (e.g. ``sys.argv[1:]``).

    Returns:
        The parsed ``argparse.Namespace``. ``ns_first_backend`` and
        ``ns_second_backend`` default to ``'<nameserver>:53'`` when they are
        not given explicitly.

    Raises:
        SystemExit: Raised by argparse on invalid options, including the case
            where both ``--multiplier`` and ``--duration`` are specified.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input', required=True,
                        help='Path to pcap file')
    parser.add_argument('--format', choices=['pcap', 'ips', 'environments', 'zonefile'], default='pcap',
                        help='Input file format')
    parser.add_argument('--ns-first', required=True,
                        help='First nameserver')
    parser.add_argument('--ns-first-backend',
                        help='Backend for first nameserver')
    parser.add_argument('--ns-second', required=True,
                        help='Second nameserver')
    parser.add_argument('--ns-second-backend',
                        help='Backend for second nameserver')
    parser.add_argument('--zone',
                        help='Filter queries by zone if specified')
    parser.add_argument('--protocol', choices=['udp', 'tcp'], default='udp',
                        help='Protocol to use')
    parser.add_argument('--reqs-num', type=int,
                        help='Number of requests to shoot')
    parser.add_argument('--multiplier', type=float,
                        help='Modify shooting speed to a given multiple')
    parser.add_argument('--duration', type=float,
                        help='Seconds to shoot')
    parser.add_argument('--uniform', action='store_true',
                        help='Send queries uniformly if specified. Otherwise, send packets at the speed at wich they were recorded')
    parser.add_argument('--threads', type=int, default=1,
                        help='Threads number for shooting')
    parser.add_argument('--log-file',
                        help='Path to store logs')
    parser.add_argument('--log-max-bytes', type=int, default=10 * 1024 * 1024,
                        help='Max log file size in bytes')
    parser.add_argument('--log-backup-count', type=int, default=1,
                        help='Max log backups to store')
    parser.add_argument('--solomon-token',
                        help='Solomon token for pushing statistics (do not push if omitted)')
    parser.add_argument('--verbose', action='store_true')

    # Parse the list we were handed; calling parse_args() without it would
    # silently read sys.argv and ignore the ``argv`` parameter.
    args = parser.parse_args(argv)

    # Reported through argparse instead of ``assert`` so that the check is
    # not stripped when Python runs with -O.
    if args.multiplier is not None and args.duration is not None:
        parser.error('Incompatible options: --multiplier and --duration')

    if not args.ns_first_backend:
        args.ns_first_backend = '{}:53'.format(args.ns_first)

    if not args.ns_second_backend:
        args.ns_second_backend = '{}:53'.format(args.ns_second)

    return args


def apply(func, data):
    """Map ``func`` over ``data`` and discard every falsy result."""
    mapped = map(func, data)
    return filter(None, mapped)


def get_address(backend):
    """Resolve a backend specification to an ``(ip_address, port)`` pair.

    Accepted forms are ``host``, ``host:port``, ``[ipv6]`` and ``[ipv6]:port``;
    a bare IPv6 literal (more than one colon, no brackets) is treated as a
    host without a port. The port defaults to 53.

    Args:
        backend: Backend string as described above.

    Returns:
        Tuple of the first address returned by ``socket.getaddrinfo`` for the
        host and the integer port.
    """
    if backend.startswith('['):
        # Bracketed IPv6 literal, optionally followed by ":port".
        host, _, rest = backend[1:].partition(']')
        port = int(rest[1:]) if rest.startswith(':') else 53
    elif backend.count(':') == 1:
        host, port_text = backend.split(':')
        port = int(port_text)
    else:
        # No port given (or an unbracketed IPv6 literal): keep the whole
        # string as the host instead of passing a list to getaddrinfo.
        host, port = backend, 53
    return socket.getaddrinfo(host, port)[0][4][0], port


def get_pcap_packets(pcap_file, preload, sort=False):
    """Open a pcap file and return its packets.

    Args:
        pcap_file: Path to the pcap file.
        preload: When true, read all packets into a list and close the file.
            When false, return the ``dpkt.pcap.Reader`` itself.
        sort: Sort the preloaded packets by their first element (the packet
            timestamp). Only valid together with ``preload``.

    Returns:
        A list of packets when ``preload`` is true, otherwise the reader.
    """
    assert not (not preload and sort), 'Incompatible options'

    if not preload:
        # The reader is handed back to the caller, so the underlying file has
        # to stay open for as long as the caller keeps using it.
        return dpkt.pcap.Reader(open(pcap_file, 'rb'))

    # Everything is read up front, so the file can be closed right away.
    with open(pcap_file, 'rb') as pcap:
        packets = dpkt.pcap.Reader(pcap).readpkts()
    if sort:
        packets.sort(key=lambda pkt: pkt[0])
    return packets


def get_packets(input_file, sort=False):
    """Load all packets from ``input_file``; a zero-length file yields ``[]``."""
    file_size = os.stat(input_file).st_size
    if file_size > 0:
        return get_pcap_packets(input_file, preload=True, sort=sort)
    return []


def inet_to_str(inet):
    """Render a packed IP address as text, trying IPv4 first and then IPv6."""
    try:
        text = socket.inet_ntop(socket.AF_INET, inet)
    except ValueError:
        # Not a valid packed IPv4 address - interpret it as IPv6 instead.
        text = socket.inet_ntop(socket.AF_INET6, inet)
    return text


def get_hostname(addr):
    """Return the host name that ``addr`` reverse-resolves to.

    Args:
        addr: Host name or IP address accepted by ``socket.gethostbyaddr``.

    Returns:
        The primary host name, or ``addr`` itself when the lookup fails, so
        that a single unresolvable source does not abort the whole run.
    """
    try:
        return socket.gethostbyaddr(addr)[0]
    except (socket.herror, socket.gaierror):
        return addr


def decode_udp(udp):
    """Pull the transport payload and the source address out of a captured packet.

    Args:
        udp: ``(timestamp, raw_packet)`` pair; the raw packet is decoded as a
            ``dpkt.sll.SLL`` frame.

    Returns:
        Tuple ``(timestamp, payload, source_address_text)``.
    """
    timestamp, raw_packet = udp

    ip_packet = dpkt.sll.SLL(raw_packet).data
    source = inet_to_str(ip_packet.src)
    datagram = ip_packet.data
    return timestamp, datagram.data, source


def create_message(data):
    """Parse the wire-format payload of a ``(timestamp, payload, source)`` triple.

    Returns the same triple with the payload replaced by the DNS message
    produced by ``dns.message.from_wire``.
    """
    timestamp, payload, source = data
    return timestamp, dns.message.from_wire(payload), source


def get_messages_from_ips(ips_file, max_messages=None, verbose=False):
    """Build DNS queries from a JSON list of host records.

    Every record that has a ``host`` key produces an AAAA query when it also
    has ``ip6`` and an A query when it has ``ip4``.

    Args:
        ips_file: Path to the JSON file.
        max_messages: Stop reading records once at least this many queries
            were generated (checked after each record); ``None`` for no limit.
        verbose: Disables the progress bar when true.

    Returns:
        List of ``(record_index, query, 'localhost')`` tuples; the record
        index takes the place of a timestamp.
    """
    logger.info('Parsing json...')
    # Context manager so the file handle is released as soon as parsing ends.
    with open(ips_file, 'r') as ips:
        records = json.load(ips)
    logger.info('Done')

    messages = []
    progress = tqdm(enumerate(records), total=len(records), desc='Reading records json', ncols=100, disable=verbose)
    for idx, record in progress:
        if 'host' not in record:
            continue
        host = record['host']
        if 'ip6' in record:
            messages.append((idx, dns.message.make_query(host, 'AAAA', 'IN'), 'localhost'))
        if 'ip4' in record:
            messages.append((idx, dns.message.make_query(host, 'A', 'IN'), 'localhost'))
        if max_messages is not None and len(messages) >= max_messages:
            break

    return messages


def get_messages_from_environments(environments_file, max_messages, verbose=False):
    """Build DNS queries from a JSON list of environment records.

    For each record the names of the environment, of its components and of
    every instance (``name`` and ``fqdn``) are collected; each non-empty name
    yields one AAAA and one SRV query.

    Args:
        environments_file: Path to the JSON file.
        max_messages: Stop reading records once at least this many queries
            were generated (checked after each record); ``None`` for no limit.
        verbose: Disables the progress bar when true.

    Returns:
        List of ``(ts, query, 'localhost')`` tuples where ``ts`` is a running
        counter that grows by one per query.
    """
    logger.info('Parsing json...')
    # Context manager so the file handle is released as soon as parsing ends.
    with open(environments_file, 'r') as environments:
        records = json.load(environments)
    logger.info('Done')

    messages = []
    ts = 0
    for record in tqdm(records, total=len(records), desc='Reading records json', ncols=100, disable=verbose):
        domains = [record['name']]
        for component in record['components']:
            domains.append(component['name'])
            for instance in component['instances']:
                domains.append(instance['name'])
                domains.append(instance['fqdn'])

        for domain in domains:
            if domain:
                messages.append((ts, dns.message.make_query(domain, 'AAAA', 'IN'), 'localhost'))
                messages.append((ts + 1, dns.message.make_query(domain, 'SRV', 'IN'), 'localhost'))
                ts += 2

        if max_messages is not None and len(messages) >= max_messages:
            break

    return messages


def get_messages_from_zonefile(zonefile, max_messages, verbose=False):
    """Build one DNS query per usable record line of a zone file.

    The first token of a line is taken as the owner name and the first
    following token that ``dns.rdatatype`` recognises is taken as the record
    type. Blank lines, ``$`` directives, ``;`` comment lines and lines without
    a recognisable type are skipped.

    Args:
        zonefile: Path to the zone file.
        max_messages: Stop after this many queries; ``None`` for no limit.
        verbose: Unused; kept for signature parity with the other loaders.

    Returns:
        List of ``(ts, query, 'localhost')`` tuples where ``ts`` is a running
        counter that grows by one per query.
    """
    logger.info('Reading records from zonefile...')

    def is_record_type(token):
        try:
            dns.rdatatype.from_text(token)
        except dns.rdatatype.UnknownRdatatype:
            return False
        return True

    messages = []
    ts = 0
    with open(zonefile, 'r') as f:
        for line in f:
            parts = line.split()
            # Nothing to query on blank lines, directives and comment lines.
            if not parts or parts[0].startswith(('$', ';')):
                continue

            domain = parts[0]
            # The type precedes the record data, so the first match after the
            # owner name is the real one; later tokens (record data, inline
            # comments) may merely look like type mnemonics.
            rtype = next((part for part in parts[1:] if is_record_type(part)), None)
            if rtype is None:
                logger.debug('Skip zonefile line without record type: {}'.format(line.rstrip()))
                continue

            messages.append((ts, dns.message.make_query(domain, rtype, 'IN'), 'localhost'))
            ts += 1

            if max_messages is not None and len(messages) >= max_messages:
                break
    logger.info('Done')

    return messages


def create_messages(input_file, file_format, max_messages=None, verbose=False):
    """Load query messages from ``input_file`` according to ``file_format``.

    Supported formats are ``pcap``, ``ips``, ``environments`` and
    ``zonefile``; any other value raises ``Exception``. ``max_messages`` and
    ``verbose`` are forwarded to the non-pcap loaders only.
    """
    if file_format == 'pcap':
        decoded = apply(decode_udp, get_packets(input_file, sort=True))
        return apply(create_message, decoded)

    if file_format == 'ips':
        return get_messages_from_ips(input_file, max_messages, verbose)

    if file_format == 'environments':
        return get_messages_from_environments(input_file, max_messages, verbose)

    if file_format == 'zonefile':
        return get_messages_from_zonefile(input_file, max_messages, verbose)

    raise Exception('Unknown file format: {}'.format(file_format))


def filter_messages(messages, zone, reqs_num):
    """Select the messages that should be replayed.

    A message is dropped when it has no question, when ``zone`` is given and
    the first question name is not inside it, or when its source resolves to
    a host name that contains ``vlan`` and ends with ``.yndx.net``.

    Args:
        messages: Iterable of ``(ts, message, src)`` triples.
        zone: Zone name to filter by, or ``None`` to accept every zone.
        reqs_num: Stop once this many messages were selected; ``None`` for
            no limit.

    Returns:
        List of mutable ``[ts, message, src_host]`` entries, where
        ``src_host`` is the resolved host name of the source.
    """
    # Parse the zone once instead of once per message.
    zone_name = dns.name.from_text(zone) if zone is not None else None
    # Captures usually repeat the same sources, so resolve each only once.
    hostnames = {}

    result = []
    for ts, message, src in messages:
        if not message.question:
            # Cannot be matched against a zone nor replayed meaningfully.
            continue
        if zone_name is not None and not message.question[0].name.is_subdomain(zone_name):
            continue
        if src not in hostnames:
            hostnames[src] = get_hostname(src)
        src_host = hostnames[src]
        if 'vlan' in src_host and src_host.endswith('.yndx.net'):
            continue
        result.append([ts, message, src_host])
        if len(result) == reqs_num:
            break
    return result


def make_uniform(messages, duration):
    """Rewrite message timestamps in place so they are evenly spaced.

    The first timestamp is kept as the starting point and entry ``i`` is
    moved to ``start + duration / len(messages) * i``. When ``duration`` is
    ``None`` the original span between first and last message is used.

    Returns the same (mutated) list.
    """
    start = messages[0][0]
    if duration is None:
        duration = messages[-1][0] - start
    count = len(messages)
    for position, entry in enumerate(messages):
        entry[0] = start + duration / count * position
    return messages


def calc_multiplier(messages, duration):
    """Compute the speed multiplier that fits the recording into ``duration``.

    Args:
        messages: Sequence of entries whose first element is a timestamp.
        duration: Desired replay length in seconds.

    Returns:
        ``recorded_span / duration``, or ``1.0`` when there are fewer than two
        messages or all of them share one timestamp. A zero multiplier would
        later be used as a divisor when pacing the replay.
    """
    if len(messages) < 2:
        return 1.0

    real_duration = messages[-1][0] - messages[0][0]
    if real_duration == 0:
        return 1.0
    return real_duration / duration


class Stats:
    """Counters gathered for the queries of one record type."""

    def __init__(self):
        self.total_requests = 0
        self.total_diffs = 0
        # Number of differing responses, per response section.
        self.section_diffs = {section: 0 for section in SECTIONS}
        # Timeout counts: [first nameserver, second nameserver].
        self.timeouts = [0, 0]

    def add(self, other):
        """Accumulate every counter of ``other`` into this instance.

        Both instances must track the same set of sections.
        """
        self.total_requests += other.total_requests
        self.total_diffs += other.total_diffs
        assert sorted(self.section_diffs) == sorted(other.section_diffs)
        for section, value in other.section_diffs.items():
            self.section_diffs[section] += value
        # Timeouts are part of the statistics as well and must not be lost
        # when two instances are merged.
        for idx, value in enumerate(other.timeouts):
            self.timeouts[idx] += value

    def diffs_repr(self, timeouts=False, color=False):
        """Format the diff counters as ``'<value> (<label>), ...'``.

        Args:
            timeouts: Also include both timeout counters after the total.
            color: Wrap each item with ``colored`` - red for a non-zero
                value, green otherwise.
        """
        def plain(item):
            label, value = item
            return '{} ({})'.format(value, label)

        def highlighted(item):
            return colored(plain(item), 'red' if item[1] else 'green')

        items = [('Diffs', self.total_diffs)]
        if timeouts:
            items.append(('Timed out first', self.timeouts[0]))
            items.append(('Timed out second', self.timeouts[1]))
        # dict.items() is a view on Python 3 and cannot be concatenated to a
        # list with "+", so extend the list instead.
        items.extend(self.section_diffs.items())
        return ', '.join(map(highlighted if color else plain, items))


def send_udp(message, where, port=53):
    """Send ``message`` over UDP to ``where:port`` with a one-second timeout."""
    response = dns.query.udp(message, where=where, port=port, timeout=1)
    return response


def send_tcp(message, where, port=53):
    """Send ``message`` over TCP to ``where:port`` with a one-second timeout."""
    response = dns.query.tcp(message, where=where, port=port, timeout=1)
    return response


def send(message, where, port, protocol):
    """Send ``message`` to ``where:port`` over the requested transport.

    Args:
        message: Query to send.
        where: Address of the nameserver.
        port: Port of the nameserver.
        protocol: ``'udp'`` or ``'tcp'``.

    Returns:
        The response returned by the transport-specific sender.

    Raises:
        ValueError: ``protocol`` is neither ``'udp'`` nor ``'tcp'``. Returning
            ``None`` here would be indistinguishable from a lost response.
    """
    if protocol == 'udp':
        return send_udp(message, where, port)
    if protocol == 'tcp':
        return send_tcp(message, where, port)
    raise ValueError('Unknown protocol: {}'.format(protocol))


def shoot(ns_first_address, ns_second_address, messages, protocol, time_delta, speed_multiplier, verbose=False):
    """Replay ``messages`` sequentially against both nameservers.

    Before each query the function waits for
    ``(ts + time_delta - now) / speed_multiplier`` seconds (minus ``EPS``)
    when that wait is at least ``MIN_SLEEP_TIME``; afterwards ``time_delta``
    is shifted by ``to_wait * (1 - speed_multiplier)``.

    Args:
        ns_first_address: ``(address, port)`` of the first nameserver.
        ns_second_address: ``(address, port)`` of the second nameserver.
        messages: Sequence of ``(ts, message, src_host)`` entries.
        protocol: Transport passed on to ``send``.
        time_delta: Offset added to the message timestamps to get deadlines.
        speed_multiplier: Divisor applied to the waiting time.
        verbose: Disables the progress bar when true.

    Returns:
        Tuple ``(responses, stats)``. ``responses`` holds
        ``[[first_response, second_response], src_host]`` for every query
        that did not time out on either server; ``stats`` maps the record
        type text to a ``Stats`` with request and timeout counters.
    """
    stats = {}
    responses = []
    nameservers = (ns_first_address, ns_second_address)
    progress = tqdm(enumerate(messages), total=len(messages), desc='shooting', ncols=100, disable=verbose)
    for _, (ts, message, src_host) in progress:
        question = message.question[0]
        rdtype = dns.rdatatype.to_text(question.rdtype)
        rdtype_stats = stats.setdefault(rdtype, Stats())
        rdtype_stats.total_requests += 1

        deadline = ts + time_delta
        to_wait = (deadline - time.time()) / speed_multiplier
        if to_wait >= MIN_SLEEP_TIME:
            pause = to_wait - EPS
            logger.debug('Sleep for {:.3f}s'.format(pause))
            time.sleep(pause)

        logger.debug('Send <{} {}>'.format(question.name, dns.rdatatype.to_text(question.rdtype)))

        answers = [None, None]
        got_timeout = False
        for position, (ns_server, ns_port) in enumerate(nameservers):
            try:
                answers[position] = send(message, ns_server, ns_port, protocol)
            except dns.exception.Timeout:
                rdtype_stats.timeouts[position] += 1
                got_timeout = True
            except dns.query.BadResponse as e:
                backend = '{}:{}'.format(ns_server, ns_port)
                logger.warning('Bad response from {}: {}'.format(backend, e))
            except dns.query.UnexpectedSource as e:
                backend = '{}:{}'.format(ns_server, ns_port)
                logger.warning('Invalid source from {}: {}'.format(backend, e))

        # A pair is only comparable when neither server timed out.
        if not got_timeout:
            responses.append([answers, src_host])

        time_delta += to_wait * (1 - speed_multiplier)

    return responses, stats


def process_message(packed_data):
    """Send one query to both nameservers; worker function for ``shoot_async``.

    Args:
        packed_data: Tuple ``(ns_first_address, ns_second_address, data,
            protocol)`` where ``data`` is ``(ts, message, src_host)`` and each
            address is an ``(address, port)`` pair.

    Returns:
        Tuple ``(rdtype, timeouts, result)``: the record type text, the
        per-server timeout counts and ``[[first, second], src_host]`` - or
        ``None`` in place of the latter when any server timed out.
    """
    ns_first_address, ns_second_address, (_, message, src_host), protocol = packed_data

    question = message.question[0]
    rdtype = dns.rdatatype.to_text(question.rdtype)

    logger.debug('Send <{} {}>'.format(question.name, dns.rdatatype.to_text(question.rdtype)))

    answers = [None, None]
    timeouts = [0, 0]
    for position, (ns_server, ns_port) in enumerate((ns_first_address, ns_second_address)):
        try:
            answers[position] = send(message, ns_server, ns_port, protocol)
        except dns.exception.Timeout:
            timeouts[position] += 1
        except dns.query.BadResponse as e:
            backend = '{}:{}'.format(ns_server, ns_port)
            logger.warning('Bad response from {}: {}'.format(backend, e))
        except dns.query.UnexpectedSource as e:
            backend = '{}:{}'.format(ns_server, ns_port)
            logger.warning('Invalid source from {}: {}'.format(backend, e))

    time.sleep(0.01)

    if any(timeouts):
        return rdtype, timeouts, None
    return rdtype, timeouts, [answers, src_host]


def shoot_async(ns_first_address, ns_second_address, messages, protocol, threads, verbose=False):
    """Replay ``messages`` against both nameservers using a thread pool.

    Recorded timestamps are ignored; every message is handed to
    ``process_message`` on a pool of ``threads`` workers.

    Args:
        ns_first_address: ``(address, port)`` of the first nameserver.
        ns_second_address: ``(address, port)`` of the second nameserver.
        messages: Sequence of ``(ts, message, src_host)`` entries.
        protocol: Transport passed on to ``send``.
        threads: Number of worker threads.
        verbose: Disables the progress bar when true.

    Returns:
        Tuple ``(responses, stats)`` shaped like the result of ``shoot``.
    """
    stats = {}
    responses = []
    args = itertools.product([ns_first_address], [ns_second_address], messages, [protocol])

    pool = ThreadPool(threads)
    try:
        results = pool.imap(process_message, args)
        for rdtype, timeouts, response in tqdm(results, total=len(messages), desc='Shooting', ncols=100, disable=verbose):
            rdtype_stats = stats.setdefault(rdtype, Stats())
            rdtype_stats.total_requests += 1
            for position, timed_out in enumerate(timeouts):
                rdtype_stats.timeouts[position] += timed_out
            if response is not None:
                responses.append(response)
    except BaseException:
        # Do not leave workers sending queries after a failure or Ctrl-C.
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()

    return responses, stats


def response_by_sections(resp, skip=None, replaces=None, replaces_by_section=None):
    """Split the text form of a response into per-section sets of lines.

    Args:
        resp: Object whose ``to_text()`` returns the response text.
        skip: Line prefixes; lines starting with any of them are ignored.
        replaces: ``(compiled_pattern, replacement)`` pairs applied (first
            match only) to every line.
        replaces_by_section: Mapping from section name to additional
            ``(compiled_pattern, replacement)`` pairs for that section.

    Returns:
        Dict mapping every name in ``SECTIONS`` to the set of processed lines
        of that section. Lines before the first section header are stored
        under ``ATTRIBUTES``.
    """
    # None instead of mutable literals as defaults; normalise them here.
    skip = tuple(skip or ())
    replaces = list(replaces or ())
    replaces_by_section = replaces_by_section or {}
    result = {section_name: set() for section_name in SECTIONS}

    def replace(line, section_name):
        for pattern, repl in replaces + list(replaces_by_section.get(section_name, [])):
            line = pattern.sub(repl, line, 1)
        return line

    def process_authority(line):
        # skip numeric values and email value of SOA record
        return ' '.join(list(filter(lambda token: not token.isnumeric(), line.split()))[:-1])

    section_name = ATTRIBUTES
    for line in resp.to_text().split('\n'):
        if line.startswith(skip):
            continue

        line = replace(line, section_name)

        # A header is one leading character followed by the section name.
        if line[1:] in SECTIONS:
            section_name = line[1:]
            continue

        if section_name == AUTHORITY:
            line = process_authority(line)
        result[section_name].add(line)

    return result


def response_dict_to_text(response_dict):
    """Render a section-to-lines mapping as indented JSON.

    Args:
        response_dict: Mapping from section name to an iterable (usually a
            set) of lines.

    Returns:
        JSON text with two-space indentation. Lines are sorted so that equal
        sets always render identically; ``response_dict`` itself is left
        untouched.
    """
    # Build a new mapping instead of converting the caller's sets in place.
    serializable = {section: sorted(values) for section, values in response_dict.items()}
    return json.dumps(serializable, indent=2)


def compare_response(ns_first, ns_second, first_response, second_response, replaces, replaces_by_section, src_host):
    """Compare two responses section by section and log any difference.

    Lines starting with ``payload`` are ignored. For ANSWER, AUTHORITY and
    ADDITIONAL a difference is recorded when the first response contains a
    line the second one lacks; every other section must match exactly.

    Returns:
        A ``Stats`` whose ``section_diffs`` holds 1 for each differing
        section and whose ``total_diffs`` is 1 when any section differs.
    """
    # skip payload size
    # replace ns[1..N].blah.yandex.net with ns[1..N].<NS-zone>
    split_options = dict(skip=['payload'], replaces=replaces, replaces_by_section=replaces_by_section)
    first_resp = response_by_sections(first_response, **split_options)
    second_resp = response_by_sections(second_response, **split_options)

    stats = Stats()

    subset_sections = (ANSWER, AUTHORITY, ADDITIONAL)
    for section_name in SECTIONS:
        first_lines = first_resp[section_name]
        second_lines = second_resp[section_name]
        if section_name in subset_sections:
            differs = not first_lines.issubset(second_lines)
        else:
            differs = first_lines != second_lines
        if differs:
            stats.section_diffs[section_name] = 1

    stats.total_diffs = int(any(stats.section_diffs.values()))

    if not stats.total_diffs:
        return stats

    logger.warning(colored('Found difference in answers:', 'red'))
    logger.warning('Response from {}:\n{}'.format(ns_first, response_dict_to_text(first_resp)))
    logger.warning('Response from {}:\n{}'.format(ns_second, response_dict_to_text(second_resp)))
    logger.warning('Diffs: {}'.format(stats.diffs_repr(color=False)))
    logger.warning('Original query sent from {}'.format(src_host))
    return stats


def compare_responses(ns_first, ns_second, responses):
    """Compare every collected response pair and aggregate the differences.

    Before comparison the zone part of each nameserver name (everything after
    its first label) is replaced with ``<NS-zone>`` and the number following
    the first token of ANSWER and ADDITIONAL lines is replaced with ``<TTL>``.

    Args:
        ns_first: Name of the first nameserver.
        ns_second: Name of the second nameserver.
        responses: Iterable of ``[[first_response, second_response],
            src_host]`` entries; pairs with a missing response are skipped.

    Returns:
        Dict mapping the record type text to the accumulated ``Stats``.
    """
    def zone_replaces(nameserver):
        # Everything after the first label is the nameserver's own zone.
        zone = nameserver.partition('.')[2]
        if not zone:
            # Single-label name: there is no zone suffix to normalise.
            return []
        # Escaped so that dots match literally rather than any character.
        return [(re.compile(re.escape(zone), re.IGNORECASE), '<NS-zone>')]

    replaces = zone_replaces(ns_first) + zone_replaces(ns_second)

    ttl_replace = (re.compile(r'([^\s]+\s)(?:[\d]+)'), r'\1<TTL>')
    replaces_by_section = {
        ANSWER: [ttl_replace],
        ADDITIONAL: [ttl_replace],
    }

    stats = {}
    for resps, src_host in responses:
        response_first, response_second = resps
        if response_first is None or response_second is None:
            continue
        rdtype = dns.rdatatype.to_text(response_first.question[0].rdtype)
        pair_stats = compare_response(ns_first, ns_second, response_first, response_second, replaces, replaces_by_section, src_host)
        stats.setdefault(rdtype, Stats()).add(pair_stats)
    return stats


def diff(ns_first, ns_first_backend, ns_second, ns_second_backend, messages, protocol, speed_multiplier, threads, verbose):
    """Replay ``messages`` against both backends and compare the answers.

    With ``threads <= 1`` the queries are sent sequentially with pacing,
    otherwise through the thread pool variant.

    Returns:
        Tuple ``(stats, responses_count)``: per record type ``Stats`` that
        combine the sending and the comparison counters, and the number of
        response pairs that were collected.
    """
    ns_first_address = get_address(ns_first_backend)
    ns_second_address = get_address(ns_second_backend)

    first_ts = messages[0][0]
    time_delta = time.time() - first_ts
    duration = messages[-1][0] - first_ts

    logger.info('Send {} queries to {} and {} for {:.2f} seconds with {} thread(s)'.format(len(messages), ns_first, ns_second, duration / speed_multiplier, threads))
    if threads > 1:
        responses, stats = shoot_async(ns_first_address, ns_second_address, messages, protocol, threads, verbose)
    else:
        responses, stats = shoot(ns_first_address, ns_second_address, messages, protocol, time_delta, speed_multiplier, verbose)

    logger.info('Compare responses')
    for rdtype, stat in compare_responses(ns_first, ns_second, responses).items():
        stats.setdefault(rdtype, Stats()).add(stat)
    return stats, len(responses)


@retry.retry(tries=3, delay=1, backoff=2)
def push_stats_to_solomon(solomon_url, data, headers, timeout=10):
    """POST ``data`` to ``solomon_url``; retried by the ``retry`` decorator.

    Args:
        solomon_url: URL to post to.
        data: Request body.
        headers: HTTP headers of the request.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled connection would block forever and the retry
            logic would never get a chance to run.

    Raises:
        requests.HTTPError: The server answered with an error status.
    """
    response = requests.post(solomon_url, data=data, headers=headers, timeout=timeout)
    response.raise_for_status()
    logger.info('Push status: {}'.format(response.status_code))


def push_stats(ns_first, ns_second, zone, stats, protocol, solomon_token):
    """Push the collected statistics to Solomon.

    The metrics are pushed with a ``protocol`` label; for ``udp`` the same
    metrics are pushed a second time without that label. Every sensor gets
    the same timestamp, taken once at the start of the call.

    Args:
        ns_first: Value of the ``baseline`` label.
        ns_second: Value of the ``experiment`` label.
        zone: Value of the ``zone`` label.
        stats: Dict mapping record type text to ``Stats``.
        protocol: Protocol that was used for shooting.
        solomon_token: OAuth token sent in the ``Authorization`` header.
    """
    logger.info('Push statistics to solomon')

    ts = datetime.datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')

    base_labels = {
        'project': 'yp_dns',
        'cluster': 'diff',
        'service': 'diff_nameservers',
        'baseline': ns_first,
        'experiment': ns_second,
        'zone': zone,
    }
    label_sets = [dict(base_labels, protocol=protocol)]
    if protocol == 'udp':
        label_sets.append(base_labels)

    headers = {
        'Content-Type':  'application/json',
        'Authorization': 'OAuth {}'.format(solomon_token),
    }

    for labels in label_sets:
        registry = MetricRegistry(labels)
        for rdtype, stat in stats.items():
            for section, diffs in stat.section_diffs.items():
                registry.int_gauge({'rdtype': rdtype, 'section': section, 'sensor': 'diffs'}).set(diffs)
            totals = (
                ('requests', stat.total_requests),
                ('timeouts.baseline', stat.timeouts[0]),
                ('timeouts.experiment', stat.timeouts[1]),
            )
            for sensor_name, value in totals:
                registry.int_gauge({'rdtype': rdtype, 'sensor': sensor_name}).set(value)

        # Round-trip through JSON to stamp each sensor with the timestamp.
        payload = json.loads(dumps(registry, format='json'))
        for sensor in payload['sensors']:
            sensor['ts'] = ts
        push_stats_to_solomon(SOLOMON_URL, json.dumps(payload), headers)


def init_loggers(args):
    """Configure logging from the parsed command-line options.

    The root logger gets DEBUG level when ``args.verbose`` is set and INFO
    otherwise. When ``args.log_file`` is given, a rotating file handler with
    the configured size and backup limits is attached to the module logger.
    """
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO, format=LOGGING_FORMAT)

    if not args.log_file:
        return

    file_handler = RotatingFileHandler(
        args.log_file,
        maxBytes=args.log_max_bytes,
        backupCount=args.log_backup_count,
    )
    file_handler.setFormatter(logging.Formatter(LOGGING_FORMAT))
    logger.addHandler(file_handler)


def main(argv):
    """Run the tool: load queries, replay them, report and optionally push statistics.

    Args:
        argv: Command-line arguments without the program name.
    """
    args = parse_args(argv)

    init_loggers(args)

    loaded = create_messages(args.input, args.format, args.reqs_num, args.verbose)
    messages = filter_messages(loaded, args.zone, args.reqs_num)

    if not messages:
        logger.info('No queries left after filter')
        return

    if args.uniform:
        messages = make_uniform(messages, args.duration)

    # An explicit duration determines the multiplier; otherwise fall back to
    # real-time speed when no multiplier was requested.
    if args.duration is not None:
        args.multiplier = calc_multiplier(messages, args.duration)
    elif args.multiplier is None:
        args.multiplier = 1.0

    stats, _ = diff(
        args.ns_first, args.ns_first_backend,
        args.ns_second, args.ns_second_backend,
        messages, args.protocol, args.multiplier, args.threads, args.verbose,
    )

    for rdtype, stat in stats.items():
        summary = stat.diffs_repr(timeouts=True, color=True)
        logger.info('Statistics for {:4}: {} (Requests), {}'.format(rdtype, stat.total_requests, summary))

    if args.solomon_token is not None:
        push_stats(args.ns_first, args.ns_second, args.zone, stats, args.protocol, args.solomon_token)

# Script entry point: forward the command-line arguments (without the
# program name) to main().
if __name__ == '__main__':
    main(sys.argv[1:])
