import argparse
import ipaddress
import logging
import random
import requests
import socket
import sys
import time

import dns.message
import dns.query
import dns.rdataclass
import dns.rdatatype

from infra.yp_dns_api.bridge.api import api_pb2
from infra.yp_dns_api.client.client import YpDnsApiBridgeClient

from yp.client import YpClient, find_token


# Zone checked by default (--zone) and the YP clusters it is served from (--clusters).
TEST_R_ZONE = 'test-zone-r.yandex.net'
TEST_R_ZONE_CLUSTERS = ['sas-test', 'man-pre']

# Defaults for --max-record-sets and --max-replication-time (in seconds).
DEFAULT_MAX_RECORD_SETS = 15
DEFAULT_MAX_REPLICATION_TIME = 15.0


# Everything is logged through the root logger, via a stderr handler with a timestamped format.
_LOG_FORMAT = '[%(asctime)s] [%(levelname)s] %(message)s'
logger = logging.getLogger()
sh = logging.StreamHandler()
sh.setFormatter(logging.Formatter(_LOG_FORMAT))
logger.addHandler(sh)


class CheckRecordSetsException(Exception):
    """Record sets stored in a YP cluster do not match the expected ones."""


class CheckDNSResponseException(Exception):
    """The nameserver answered with records that differ from the expected ones."""


def get_address(hostname, port):
    """Resolve ``hostname``/``port`` and return the socket address of the first result.

    This is ``socket.getaddrinfo(hostname, port)[0][4]``: a ``(host, port)`` tuple for IPv4,
    or ``(host, port, flowinfo, scope_id)`` for IPv6.
    """
    return socket.getaddrinfo(hostname, port)[0][4]


def get_address_and_port(address):
    """Resolve a ``'host:port'`` string into a ``(host, port)`` pair.

    The port is taken after the LAST colon. An IPv6 literal therefore works when written in
    brackets, e.g. ``'[::1]:53'``; a plain ``split(':')`` cannot handle that form.

    Raises:
        ValueError: if ``address`` has no ``host:port`` form or the port is not an integer.
        socket.gaierror: if the host cannot be resolved.
    """
    hostname, separator, port = address.rpartition(':')
    if not separator or not hostname:
        raise ValueError('Expected address in "host:port" form, got {!r}'.format(address))
    if hostname.startswith('[') and hostname.endswith(']'):
        hostname = hostname[1:-1]
    return get_address(hostname, int(port))[:2]


def unique_records(records):
    """Return ``records`` with duplicates removed.

    Two records are duplicates when they share ``type``, ``class`` and ``data``. A missing
    ``class`` counts as ``''`` (the same default ``sort_records`` uses), so records built
    without a class, as ``fill_zone`` builds them, are accepted. The result keeps the position
    of each key's first occurrence but holds its last occurrence, so among duplicates that
    differ in other keys (e.g. ``ttl``) the last one wins.

    Args:
        records: iterable of record dicts, or ``None``/empty.

    Returns:
        A new list of record dicts; ``[]`` for ``None`` or empty input.
    """
    if not records:
        return []

    unique = {}
    for record in records:
        # Re-assigning an existing key keeps its original insertion position.
        unique[(record['type'], record.get('class', ''), record['data'])] = record
    return list(unique.values())


def sort_records(records):
    """Return a new list of ``records`` ordered by ``(class, type, data)``.

    Records without a ``class`` key sort as if their class were ``''``.
    """
    def ordering_key(rec):
        return rec.get('class', ''), rec['type'], rec['data']

    return sorted(records, key=ordering_key)


def records_equal(lhs, rhs):
    """Compare two record lists position by position.

    The lists must have equal length. For each pair of records, only keys present in BOTH
    records are compared; a key that exists on just one side is ignored.
    """
    if len(lhs) != len(rhs):
        return False
    return not any(
        value != right[key]
        for left, right in zip(lhs, rhs)
        for key, value in left.items()
        if key in right
    )


def list_record_sets(yp_client, zone):
    """Return every ``dns_record_set`` labelled with ``zone`` as raw select rows.

    Results are fetched in pages of 1000 via continuation tokens. All pages use the single
    timestamp from ``yp_client.generate_timestamp()`` (presumably giving a consistent
    snapshot). Each row holds the ``/meta/id`` and ``/spec/records`` selector values, in that
    order; callers read them as ``row[i]['value']``.
    """
    page_size = 1000
    snapshot = yp_client.generate_timestamp()
    token = None
    rows = []

    last_page_full = True
    while last_page_full:
        page = yp_client.select_objects(
            'dns_record_set',
            selectors=['/meta/id', '/spec/records'],
            filter='[/labels/zone] = "{}"'.format(zone),
            limit=page_size,
            timestamp=snapshot,
            options={'continuation_token': token},
            enable_structured_response=True,
        )
        token = page['continuation_token']
        page_rows = page['results']
        rows.extend(page_rows)
        # A page shorter than the limit means there is nothing more to fetch.
        last_page_full = len(page_rows) >= page_size

    return rows


def update_records(bridge_client, request):
    """Send ``request`` through ``bridge_client.update_records`` and return its reply.

    Both the request and the reply are logged at INFO level.
    """
    logger.info('YP DNS API request:\n{}'.format(request))
    reply = bridge_client.update_records(request)
    logger.info('YP DNS API response:\n{}'.format(reply))
    return reply


def add_records(bridge_client, records_by_rs_id):
    """Add records through the YP DNS API bridge in a single batched request.

    Args:
        bridge_client: YP DNS API bridge client.
        records_by_rs_id: mapping of record set fqdn to a list of record dicts with ``type``
            (an ``ERecordType`` name) and ``data``, plus optional ``ttl`` and ``class``.

    Returns:
        The bridge response returned by ``update_records``.
    """
    request = api_pb2.TReqUpdateRecords()
    for fqdn, rs_records in records_by_rs_id.items():
        for rec in rs_records:
            # One "update" entry per record; optional fields are copied only when present.
            upd = request.requests.add().update
            upd.fqdn = fqdn
            upd.type = api_pb2.ERecordType.Value(rec['type'])
            upd.data = rec['data']
            if 'ttl' in rec:
                upd.ttl = rec['ttl']
            if 'class' in rec:
                upd.class_ = rec['class']
    return update_records(bridge_client, request)


def remove_records(bridge_client, records_by_rs_id):
    """Remove records through the YP DNS API bridge in a single batched request.

    Only ``type`` and ``data`` of each record are sent; other keys are ignored.

    Args:
        bridge_client: YP DNS API bridge client.
        records_by_rs_id: mapping of record set fqdn to a list of record dicts.

    Returns:
        The bridge response returned by ``update_records``.
    """
    request = api_pb2.TReqUpdateRecords()
    for fqdn, rs_records in records_by_rs_id.items():
        for rec in rs_records:
            removal = request.requests.add().remove
            removal.fqdn = fqdn
            removal.type = api_pb2.ERecordType.Value(rec['type'])
            removal.data = rec['data']
    return update_records(bridge_client, request)


def cleanup_zone(bridge_client, yp_clients, zone, clusters):
    """Remove all records of every record set in ``zone`` from the given clusters.

    Record sets are listed in every cluster. Their records are merged per record set id and
    de-duplicated, so a record stored in several clusters is removed once. Everything is then
    removed with a single YP DNS API bridge request.

    Args:
        bridge_client: YP DNS API bridge client used to send the removal request.
        yp_clients: mapping of cluster name to YP client; must contain every name in ``clusters``.
        zone: DNS zone whose record sets are cleaned up.
        clusters: names of the clusters to list.

    Raises:
        AssertionError: if any per-record response is not a successful ``remove`` response from
            one of ``clusters``. It is raised explicitly rather than with ``assert`` so the
            checks still run under ``python -O``.
    """
    record_sets_by_cluster = {
        cluster: list_record_sets(yp_clients[cluster], zone)
        for cluster in clusters
    }

    merged_record_sets = {}
    for record_sets in record_sets_by_cluster.values():
        for id_field, records_field in record_sets:
            rs_id, records = id_field['value'], records_field['value']
            merged_record_sets.setdefault(rs_id, [])
            # A record set may have no records (None/empty); keep its id but add nothing.
            if records:
                merged_record_sets[rs_id].extend(records)
    merged_record_sets = {
        rs_id: unique_records(records)
        for rs_id, records in merged_record_sets.items()
    }

    response = remove_records(bridge_client, merged_record_sets)
    for record_response in response.responses:
        kind = record_response.WhichOneof('response')
        if kind != 'remove':
            raise AssertionError(
                'Unexpected YP DNS API response kind {!r}:\n{}'.format(kind, record_response))
        result = record_response.remove
        if (result.status != api_pb2.TRspRemoveRecord.ERemoveRecordStatus.OK
                or result.error_message
                or result.cluster not in clusters):
            raise AssertionError('YP DNS API failed to remove record:\n{}'.format(record_response))


def fill_zone(bridge_client, zone, clusters, max_record_sets):
    """Create ``max_record_sets`` record sets in ``zone``, each with one random AAAA record.

    Record sets are named ``fqdn-<i>.<zone>``, with ``i`` zero-padded to the number of digits
    in ``max_record_sets``. Each gets one AAAA record with a uniformly random IPv6 address.
    All records are added with a single YP DNS API bridge request.

    Args:
        bridge_client: YP DNS API bridge client.
        zone: DNS zone to fill.
        clusters: names of the clusters a successful response may come from.
        max_record_sets: number of record sets to create.

    Returns:
        dict mapping each fqdn to the list of record dicts (``type``/``data``) that were added.

    Raises:
        AssertionError: if any per-record response is not a successful ``update`` response from
            one of ``clusters``. It is raised explicitly rather than with ``assert`` so the
            checks still run under ``python -O``.
    """
    width = len(str(max_record_sets))
    records_by_rs_id = {}
    for i in range(max_record_sets):
        fqdn = 'fqdn-{:0{}}.{}'.format(i, width, zone)
        address = str(ipaddress.IPv6Address(random.randint(0, 2**128 - 1)))
        records_by_rs_id[fqdn] = [
            {
                'type': 'AAAA',
                'data': address,
            },
        ]

    response = add_records(bridge_client, records_by_rs_id)
    for record_response in response.responses:
        kind = record_response.WhichOneof('response')
        if kind != 'update':
            raise AssertionError(
                'Unexpected YP DNS API response kind {!r}:\n{}'.format(kind, record_response))
        result = record_response.update
        if (result.status != api_pb2.TRspUpdateRecord.EUpdateRecordStatus.OK
                or result.error_message
                or result.cluster not in clusters):
            raise AssertionError('YP DNS API failed to add record:\n{}'.format(record_response))

    return records_by_rs_id


def check_record_sets(yp_clients, zone, clusters, records_by_rs_id):
    """Verify that each cluster stores exactly the expected record sets for ``zone``.

    For every cluster, the zone's record sets are listed and compared with
    ``records_by_rs_id``:
    - each listed record set must be expected;
    - its de-duplicated records must match the expected ones (via ``records_equal``, which
      compares only keys present on both sides);
    - every expected record set must be found.

    Args:
        yp_clients: mapping of cluster name to YP client.
        zone: DNS zone to check.
        clusters: names of the clusters to check, in order.
        records_by_rs_id: mapping of record set id (fqdn) to its expected list of record dicts;
            not modified.

    Raises:
        CheckRecordSetsException: on an unexpected or duplicated record set, a record mismatch,
            or a missing record set. Other errors (e.g. from listing) propagate unchanged. In
            both cases an ERROR status line is logged for the cluster first.
    """
    def log_status(cluster, status, log_level=logging.INFO):
        logger.log(log_level, 'Check record sets for {} zone in cluster {}: {}'.format(zone, cluster, status))

    for cluster in clusters:
        log_status(cluster, 'START')

        try:
            actual_record_sets = list_record_sets(yp_clients[cluster], zone)

            # Shallow copy: ids are popped as they are matched; the caller's dict is untouched.
            expected_records = records_by_rs_id.copy()
            for id_field, records_field in actual_record_sets:
                rs_id, records = id_field['value'], records_field['value']

                # Report unknown ids as a check failure rather than a KeyError, so that
                # run_checks (which only catches CheckRecordSetsException) raises an alert.
                if rs_id not in expected_records:
                    raise CheckRecordSetsException(
                        'Unexpected or duplicated record set {} in cluster {}'.format(rs_id, cluster))

                sorted_expected_records = sort_records(expected_records.pop(rs_id))
                sorted_actual_records = sort_records(unique_records(records))
                if not records_equal(sorted_expected_records, sorted_actual_records):
                    logger.error('Record sets do not match:\n(actual):\t{}\n!=\n(expected):\t{}'.format(
                        sorted_actual_records,
                        sorted_expected_records
                    ))
                    raise CheckRecordSetsException('Record set {} in cluster {} differs from expected'.format(rs_id, cluster))

            if expected_records:
                raise CheckRecordSetsException('Did not find record sets {} in cluster {}'.format(
                    sorted(expected_records), cluster))
        except Exception:
            log_status(cluster, 'ERROR', logging.ERROR)
            raise

        log_status(cluster, 'OK')


def check_dns_response(records_by_rs_id, address, can_be_empty=True):
    """Check that the nameserver at ``address`` serves the expected AAAA records.

    Each record set id is queried for AAAA over UDP (5 s timeout). The answer must consist only
    of IN/AAAA RRsets whose rdata, taken together, equal the expected ``data`` values
    (order-insensitive).

    Args:
        records_by_rs_id: mapping of fqdn to its expected list of record dicts.
        address: nameserver address in ``'host:port'`` form.
        can_be_empty: when true, an empty answer is accepted (e.g. while still replicating).

    Raises:
        CheckDNSResponseException: if a (non-accepted) answer differs from the expectation.
    """
    host, port = get_address_and_port(address)
    for fqdn, records in records_by_rs_id.items():
        query = dns.message.make_query(fqdn, 'AAAA')
        resp = dns.query.udp(query, where=host, port=port, timeout=5)

        if can_be_empty and not resp.answer:
            continue

        # resp.answer holds RRsets (one per name/type/class), not individual records, and an
        # RRset is unordered. Flatten all rdata and compare as sorted lists, so a record set
        # with several records is not reported as a mismatch.
        only_in_aaaa = all(
            rrset.rdtype == dns.rdatatype.AAAA and rrset.rdclass == dns.rdataclass.IN
            for rrset in resp.answer
        )
        actual_data = sorted(str(rdata) for rrset in resp.answer for rdata in rrset)
        expected_data = sorted(record['data'] for record in records)
        if not only_in_aaaa or actual_data != expected_data:
            raise CheckDNSResponseException('Expected {}, but DNS answer is {}'.format(records, list(map(str, resp.answer))))


def wait_for_replication(records_by_rs_id, max_replication_time, ns_address):
    """Poll the nameserver until ``max_replication_time`` seconds have passed.

    After an initial sleep of half the window, answers are checked repeatedly (empty answers
    allowed) with a pause of 1% of the window. Once the window is over, a final check requires
    every record set to be answered.
    """
    logger.info('Wait for replication for {} seconds'.format(max_replication_time))

    started_at = time.time()
    time.sleep(max_replication_time / 2)

    poll_pause = max_replication_time / 100.0
    while time.time() - started_at < max_replication_time:
        check_dns_response(records_by_rs_id, address=ns_address)
        time.sleep(poll_pause)

    check_dns_response(records_by_rs_id, address=ns_address, can_be_empty=False)


def juggler_notify(service, status, description):
    """Push one event for ``service`` to the Juggler push endpoint.

    Raises:
        requests.HTTPError: if the push endpoint replies with an error status.
    """
    logging.info('Push juggler event.\nStatus: {}\nDescription:\n{}'.format(status, description))

    event = {
        "status": status,
        "description": description,
        "host": "multicluster_updates.yp_dns_api",
        "service": service,
        "instance": "",
    }
    payload = {
        "source": "multicluster_updates_monitoring",
        "events": [event],
    }
    reply = requests.post("http://juggler-push.search.yandex.net/events", json=payload, timeout=10)
    reply.raise_for_status()


def run_checks(args):
    """Run one monitoring cycle and report its outcome to Juggler.

    The cycle cleans up ``args.zone``, fills it with fresh record sets, waits for them to appear
    on the nameserver, and verifies that every cluster stores exactly those record sets.

    Reporting:
    - ``CheckRecordSetsException``: CRIT for ``unexpected_state``, then re-raised.
    - ``CheckDNSResponseException``: CRIT for ``incorrect_dns_answer``, then re-raised.
    - Success: both services are reported OK.
    - With ``args.only_cleanup``: only the cleanup runs and nothing is reported.
    """
    bridge_client = YpDnsApiBridgeClient(args.bridge_address)
    yp_clients = {
        cluster: YpClient('{}.yp.yandex.net:8090'.format(cluster), config={'token': find_token()})
        for cluster in args.clusters
    }

    cleanup_zone(bridge_client, yp_clients, args.zone, args.clusters)
    if args.only_cleanup:
        return

    records_added = fill_zone(bridge_client, args.zone, args.clusters, args.max_record_sets)

    def report_crit(service, error):
        # Called from inside an ``except`` block, so logging.exception still sees the
        # exception being handled and logs its traceback.
        logging.exception(error)
        juggler_notify(
            service=service,
            status='CRIT',
            description='YP DNS API external monitoring:\n{}'.format(error),
        )

    try:
        wait_for_replication(records_added, args.max_replication_time, args.ns_address)
        check_record_sets(yp_clients, args.zone, args.clusters, records_added)
    except CheckRecordSetsException as e:
        report_crit('unexpected_state', e)
        # Bare ``raise`` re-raises the handled exception without adding this frame again.
        raise
    except CheckDNSResponseException as e:
        report_crit('incorrect_dns_answer', e)
        raise

    for service in ('unexpected_state', 'incorrect_dns_answer'):
        juggler_notify(
            service=service,
            status='OK',
            description='YP DNS API external monitoring: OK',
        )


def parse_args(argv):
    """Parse command-line arguments.

    ``--clusters`` is a comma-separated list. It is returned as a list of names, with surrounding
    whitespace stripped and empty entries (e.g. from a trailing comma) dropped.

    Raises:
        SystemExit: via ``parser.error`` when no cluster name remains, or on invalid arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-z', '--zone', default=TEST_R_ZONE,
                        help='Zone name')
    parser.add_argument('-c', '--clusters', default=','.join(TEST_R_ZONE_CLUSTERS),
                        help='Zone clusters')
    parser.add_argument('-x', '--max-record-sets', type=int, default=DEFAULT_MAX_RECORD_SETS,
                        help='Max record sets to generate for zone')
    parser.add_argument('-t', '--max-replication-time', type=float, default=DEFAULT_MAX_REPLICATION_TIME,
                        help='Max acceptable time for replication')
    parser.add_argument('--only-cleanup', action='store_true',
                        help='Only remove all zone record sets and exit')
    parser.add_argument('-n', '--ns-address', default='ns1.yp-dns.yandex.net:53',
                        help='Nameserver address')
    parser.add_argument('-b', '--bridge-address', default='dns-api-bridge.yp.yandex.net:8081',
                        help='YP DNS API Bridge address')

    args = parser.parse_args(argv)
    # Without stripping, "sas-test, man-pre," would yield ' man-pre' and '' cluster names,
    # which later turn into bogus YP endpoints.
    args.clusters = [cluster.strip() for cluster in args.clusters.split(',') if cluster.strip()]
    if not args.clusters:
        parser.error('at least one cluster must be given in --clusters')
    return args


def main(argv):
    """Parse ``argv``, switch the root logger to DEBUG, and run the checks."""
    parsed = parse_args(argv)
    logger.setLevel(logging.DEBUG)
    run_checks(parsed)


if __name__ == '__main__':
    # Script entry point: pass the command-line arguments without the program name.
    main(argv=sys.argv[1:])
