import argparse
import copy
import datetime
import ipaddress
import json
import logging
import os
import random
import requests
import retry
import sys

import dns.exception
import dns.name
import dns.namedict
import dns.rdataclass
import dns.rdatatype

from filelock import FileLock

from multiprocessing.dummy import Pool

import infra.yp_dns.monitoring.tools.yp_dns_monitoring.lib.util as util

from infra.yp_dns.monitoring.tools.yp_dns_monitoring.lib.quantity import Quantity

from infra.yp_dns.daemon import DnsClient
from infra.yp_dns.tools.handle_duplicate_records.lib.handle_duplicate_records import handle_duplicate_records

from infra.yp_service_discovery.api import api_pb2
from infra.yp_service_discovery.python.resolver.resolver import Resolver

from yt.yson import YsonEntity

# Attach a timestamped stream handler (stderr by default) to the root logger.
# `logger` is the root logger itself, so records from any logger that propagates
# to the root are emitted with this format.
sh = logging.StreamHandler()
sh.setFormatter(logging.Formatter('[%(asctime)s] [%(levelname)s] %(message)s'))
logger = logging.getLogger()
logger.addHandler(sh)


# YP DNS name servers compared against YP master by the diff_with_yp check.
NS_SERVERS = ['ns1.yp-dns.yandex.net', 'ns2.yp-dns.yandex.net']

# Datacenters whose `<DC>_YP_DNS.mtn` endpoint sets get_config tries, in order.
YP_DNS_DATACENTERS = ['SAS', 'MAN', 'VLA', 'MSK']

# Record types compared between YP master and YP DNS answers.
RECORD_TYPES = ['AAAA', 'PTR']

# Maximum age, in minutes, of a saved record-set selection before check_records
# reports NO_DATA. Keyed by upper-cased YP cluster name.
MAX_SELECTION_AGE = {
    'MAN-PRE': 15,
    'SAS-TEST': 15,
    'SAS': 40,
    'MAN': 40,
    'VLA': 40,
    'IVA': 15,
    'MYT': 15,
    'XDC': 210,
}

# Default file locations for the diff_with_yp preload/check modes.
QUANTITY_JSON_PATH = 'quantity.json'
SAVED_RECORD_SETS_JSON_PATH = 'saved_record_sets.json'

# Zones for which check_records' equality check also accepts a YP DNS answer that
# is a 3-element subset of the master's addresses.
IN_YANDEX_TEAM_RU, IN_YANDEX_NET = map(dns.name.from_text, ('in.yandex-team.ru', 'in.yandex.net'))


def check_duplicate_records(cluster, limit):
    """Report duplicate records in ``cluster`` without removing anything.

    Thin wrapper around ``handle_duplicate_records`` called with ``remove=False``;
    ``limit`` is forwarded as ``remove_limit``. The caller
    (check_cluster_count_duplicates) unpacks the result as
    ``(number_of_duplicates, number_of_record_sets_with_duplicates)``.
    """
    return handle_duplicate_records(
        util.get_yp_client(cluster),
        remove=False,
        remove_limit=limit,
        logger=logger,
    )


def full_scan_cluster(cluster, dns_zones, quantity):
    """Select a sample of DNS record sets from a YP cluster for later consistency checks.

    Pages through all DNS record sets in ``cluster`` (5000 per request). Each one is
    attributed to a zone and counted in a fresh per-zone tally. Record sets that were
    created in the last 10 minutes, or that are empty, are skipped.

    Zone attribution tries these in order:

    * the ``/labels/zone`` label;
    * ``gencfg-c.yandex.net`` when a gencfg label is set;
    * the deepest zone in ``dns_zones`` containing the qname.

    Sampling keeps roughly 1000 record sets per zone. When ``quantity`` has a previous
    count for the zone, each record set is kept with probability
    ``1000 / max(1000, count)``. Otherwise the first 1000 are kept.

    Args:
        cluster: YP cluster name passed to ``util.get_yp_client``.
        dns_zones: ``dns.namedict.NameDict`` of zone name -> served flag, as built by
            ``get_zones``.
        quantity: per-zone counts from a previous scan, used for sampling.

    Returns:
        ``(creation_time, new_quantity, saved_record_sets)``:

        * ``creation_time``: the scan start time.
        * ``new_quantity``: maps every zone key of ``dns_zones`` to its count.
        * ``saved_record_sets``: a list of ``(zone, qname, records, yp_timestamp)``
          tuples.

        Errors are logged and whatever was collected up to that point is returned.
    """
    def check(limit, continuation_token, yp_client, new_quantity):
        """Process one page of record sets; return paging state and the sampled entries."""
        result = util.select_objects(yp_client,
                                     selectors=['/meta/id', '/meta/creation_time', '/spec/records', '/labels/zone', '/labels/gencfg'],
                                     limit=limit,
                                     continuation_token=continuation_token)

        timestamp = result['timestamp']
        continuation_token = result['continuation_token']
        should_continue = len(result['results']) == limit

        # Both sides of the age comparison are timezone-aware UTC. Mixing a naive
        # local-time fromtimestamp() with utcnow() would skew the computed age by
        # the host's UTC offset.
        now = datetime.datetime.now(datetime.timezone.utc)

        record_sets_batch = []

        for item in result['results']:
            qname, creation_time, record_set, label_zone, label_gencfg = (field['value'] for field in item)
            # Divided by 10**6 to get epoch seconds (presumably microseconds in YP).
            creation_date = datetime.datetime.fromtimestamp(creation_time / 1000 / 1000, tz=datetime.timezone.utc)
            # Too fresh: YP DNS may legitimately not have it yet.
            if now - creation_date <= datetime.timedelta(minutes=10):
                continue

            if not record_set:
                continue

            if label_zone != YsonEntity():
                superdomain = dns.name.from_text(label_zone)
                found = superdomain in dns_zones
            elif label_gencfg != YsonEntity():
                superdomain, found = dns.name.from_text("gencfg-c.yandex.net"), True
            else:
                superdomain, found = dns_zones.get_deepest_match(dns.name.from_text(qname))

            if not found or superdomain not in new_quantity:
                continue

            new_quantity[superdomain] += 1
            # Known zone size from the previous scan: keep each record set with
            # probability 1000 / size, so about 1000 are selected per zone.
            if superdomain in quantity and random.random() > 1000 / max(1000, quantity[superdomain]):
                continue

            # Unknown zone size: keep only the first 1000.
            if superdomain not in quantity and new_quantity[superdomain] > 1000:
                continue

            record_sets_batch.append((str(superdomain), qname, record_set, timestamp))

        return continuation_token, should_continue, record_sets_batch

    # Deliberately a naive utcnow().timestamp(). check_records compares utcnow()
    # with fromtimestamp(generation_time), and the two conventions cancel out only
    # when used together.
    creation_time = datetime.datetime.utcnow().timestamp()
    # Bound before the try so that a failure in get_yp_client still returns empty
    # results instead of raising UnboundLocalError at the return statement.
    new_quantity = dict.fromkeys(dns_zones, 0)
    saved_record_sets = []

    try:
        yp_client = util.get_yp_client(cluster)
        limit = 5000
        continuation_token = None

        while True:
            continuation_token, should_continue, record_sets_batch = check(limit, continuation_token, yp_client, new_quantity)
            saved_record_sets += record_sets_batch

            if not should_continue:
                break

    except Exception as e:
        # Best effort: log and return what was collected before the failure.
        logger.exception(str(e))

    return creation_time, new_quantity, saved_record_sets


def check_records(dns_address, cluster, generation_time, record_sets, recheck_records=False):
    """Compare saved YP master record sets with live answers from one YP DNS server.

    Args:
        dns_address: name server to query; passed through ``util.get_address`` to
            build a ``DnsClient``.
        cluster: YP cluster the record sets were selected from.
        generation_time: selection timestamp written by ``full_scan_cluster``.
            Selections older than ``MAX_SELECTION_AGE[cluster.upper()]`` minutes are
            rejected with NO_DATA.
        record_sets: list of ``(zone, qname, records, yp_timestamp)`` entries.
        recheck_records: when True, a mismatching record set is re-read from YP
            master before it is reported. When False, a mismatch whose record set was
            modified after selection (per ``util.watch_objects``) counts as a skip.

    Returns:
        ``(status, diffs, timeouts, skips, total, error)``. ``status`` is one of:

        * 'WARN' if skips exceed 10% of ``total``;
        * 'CRIT' if diffs exceed ``max(1, 5%)``, timeouts exceed 10%, or skips
          exceed 20%;
        * 'NO_DATA' on stale data, an unknown cluster, or any error;
        * 'OK' otherwise.

        ``error`` is a message or None.
    """
    def check_equal(zone, yp_dns_addresses, yp_master_addresses):
        # Compare as IP addresses when every value parses as one, which normalises
        # different textual forms of the same address. Otherwise (e.g. PTR targets)
        # compare as DNS names.
        try:
            formatted_dns_addresses = set(map(ipaddress.ip_address, yp_dns_addresses))
            formatted_master_addresses = set(map(ipaddress.ip_address, yp_master_addresses))
        except ValueError:
            formatted_dns_addresses = set(map(dns.name.from_text, yp_dns_addresses))
            formatted_master_addresses = set(map(dns.name.from_text, yp_master_addresses))

        if dns.name.from_text(zone) in (IN_YANDEX_NET, IN_YANDEX_TEAM_RU):
            # For these zones an answer with exactly three of the master's
            # addresses is also treated as consistent.
            return formatted_dns_addresses == formatted_master_addresses or (
                formatted_dns_addresses.issubset(formatted_master_addresses)
                and len(formatted_dns_addresses) == 3
            )

        return formatted_dns_addresses == formatted_master_addresses

    def get_addresses(record_set):
        """Group record ``data`` values by type, keeping only RECORD_TYPES that occur."""
        if not record_set:
            return dict()

        result = dict()
        for record_type in RECORD_TYPES:
            addresses = {record['data'] for record in record_set if record['type'] == record_type}
            if addresses:
                result[record_type] = addresses

        return result

    def check(limit, offset, yp_client):
        """Check one page of ``record_sets``; return counters and paging state."""
        result = record_sets[offset:offset + limit]
        offset += limit
        should_continue = len(result) == limit

        yp_dns = DnsClient(util.get_address(dns_address))

        # qname -> (yp timestamp, dns address, master addresses, dns addresses)
        diff_data = dict()
        first_diff_record_timestamp = float('inf')

        diffs, timeouts, skips, total = 0, 0, 0, 0
        for zone, qname, record_set, timestamp in result:
            type_2_addresses = get_addresses(record_set)
            total += 1

            yp_master_addresses = set()
            yp_dns_addresses = set()

            udp_timeout = 5
            try:
                for record_type, addresses in type_2_addresses.items():
                    yp_master_addresses = yp_master_addresses.union(addresses)
                    resp = yp_dns.udp(qname, record_type, timeout=udp_timeout)
                    answer = yp_dns.get_answer(resp, qname, record_type) or []
                    yp_dns_addresses = yp_dns_addresses.union(str(address) for address in answer)
            except dns.exception.Timeout:
                timeouts += 1
                logger.warning(f"DNS Client on address {dns_address} cluster {cluster} UDP query response with timeout {udp_timeout} on qname {qname}.")
                continue

            if not check_equal(zone, yp_dns_addresses, yp_master_addresses):
                if recheck_records:
                    # Re-read the record set from YP master, since the saved copy may be stale.
                    yp_master_addresses = set()
                    data = util.get_object(yp_client, id=qname, selectors=['/spec/records'])
                    timestamp = data['timestamp']
                    type_2_addresses = get_addresses(data['result'][0]['value']) if data['result'] else dict()
                    for addresses in type_2_addresses.values():
                        yp_master_addresses = yp_master_addresses.union(addresses)
                    if check_equal(zone, yp_dns_addresses, yp_master_addresses):
                        continue

                diff_data[qname] = int(timestamp), dns_address, yp_master_addresses, yp_dns_addresses
                first_diff_record_timestamp = min(first_diff_record_timestamp, int(timestamp))

        if diff_data:
            if not recheck_records:
                # A record set modified in YP after it was saved is expected to
                # differ: count it as a skip instead of a diff.
                watch_result = util.watch_objects(yp_client, first_diff_record_timestamp)
                for event in watch_result:
                    object_id = event['object_id']
                    if object_id in diff_data and event['timestamp'] > diff_data[object_id][0]:
                        skips += 1
                        logger.warning(f"Answers for {object_id} differ on {diff_data[object_id][1]} ns server: " +
                                       f"{diff_data[object_id][3]} (yp dns) " +
                                       f"but dns_record_set has been modified since it has been added to record sets selection file with YP request timestamp {event['timestamp']}.")
                        del diff_data[object_id]

            for qname, data in diff_data.items():
                diffs += 1
                logger.warning(f"Answers for {qname} differ on {data[1]} ns server with YP request timestamp {data[0]}: {data[2]} (yp master) and {data[3]} (yp dns)")
                if recheck_records:
                    logger.warning(f"Answers for {qname} on {data[1]} ns server still differ AFTER recheck on yp master.")

        return diffs, timeouts, skips, total, offset, should_continue

    status = 'OK'
    diffs, timeouts, skips, total = 0, 0, 0, 0
    error = None

    # Report an unknown cluster as NO_DATA instead of raising KeyError. The call
    # runs on a pool worker, and an exception there aborts the whole run.
    max_age = MAX_SELECTION_AGE.get(cluster.upper())
    if max_age is None:
        status = 'NO_DATA'
        error = f'No maximum selection age configured for cluster {cluster}'
        return status, diffs, timeouts, skips, total, error

    # generation_time comes from utcnow().timestamp(), so fromtimestamp() recovers
    # the UTC wall-clock time that is compared with utcnow() here.
    if datetime.datetime.utcnow() - datetime.datetime.fromtimestamp(generation_time) >= datetime.timedelta(minutes=max_age):
        status = 'NO_DATA'
        error = f'Old data in dns record sets selection. Generated on {datetime.datetime.fromtimestamp(generation_time)}, but maximum age is {max_age} minutes'
        return status, diffs, timeouts, skips, total, error

    try:
        offset = 0
        limit = 5000
        yp_client = util.get_yp_client(cluster)

        while True:
            current_diffs, current_timeouts, current_skips, current_total, offset, should_continue = check(limit, offset, yp_client)
            diffs += current_diffs
            timeouts += current_timeouts
            skips += current_skips
            total += current_total

            if not should_continue:
                break

        if skips > total * 0.1:
            status = 'WARN'
        if diffs > max(1, total * 0.05) or timeouts > total * 0.1 or skips > total * 0.2:
            status = 'CRIT'

    except Exception as e:
        error = str(e)
        logger.exception(e)
        status = 'NO_DATA'

    return status, diffs, timeouts, skips, total, error


def check_cluster_count_duplicates(cluster, limit):
    """Count duplicate records in ``cluster`` and map the result to a Juggler status.

    Returns:
        ``('CRIT', description)`` when any duplicate record exists, otherwise
        ``('OK', description)``.
    """
    number_of_duplicates, number_of_record_sets_with_duplicates = check_duplicate_records(cluster, limit)

    if number_of_duplicates <= 0:
        return 'OK', "OK: There are NO duplicate records in cluster {} \n".format(cluster)

    description = "CRIT: Found {} duplicate records in {} dns_record_sets in cluster {} \n\n".format(
        number_of_duplicates, number_of_record_sets_with_duplicates, cluster)
    return 'CRIT', description


def check_ns_server(args):
    """Pool worker: run ``check_records`` for a single name server.

    ``args`` is one ``(ns_server, cluster, generation_time, record_sets, recheck)``
    tuple, so the function can be passed to ``Pool.map``. Returns the
    ``check_records`` result prefixed with ``ns_server``.
    """
    ns_server, cluster, generation_time, record_sets, recheck = args
    logger.info(f"Check {ns_server} ns server")

    outcome = check_records(ns_server, cluster, generation_time, record_sets, recheck)
    status, diffs, timeouts, skips, total, error = outcome
    return ns_server, status, diffs, timeouts, skips, total, error


def check_cluster_diff_with_yp(cluster, generation_time, record_sets, recheck=False):
    """Check every server in NS_SERVERS against the saved YP selection.

    Runs ``check_ns_server`` for each server on a thread pool and aggregates the
    per-server statuses into the worst one, ordered CRIT > WARN > NO_DATA > OK.

    Returns:
        ``(aggr_status, description)``. ``description`` is the multi-line summary
        that main sends to Juggler.
    """
    def format_counts(entry):
        # entry is (server, diffs, timeouts, skips, total)
        server, diffs, timeouts, skips, total = entry
        return "{server} {diffs}/{timeouts}/{skips}/{total} (diffs/timeouts/skips/total)".format(
            server=server, diffs=diffs, timeouts=timeouts, skips=skips, total=total)

    def make_description(oks, warns, no_datas, crits):
        description = ''
        if crits:
            description += "CRIT: Answers from master and dns differ for {}\n".format(
                ', '.join(map(format_counts, crits)))

        if warns:
            description += "WARN: Answers from master and dns differ for {}\n".format(
                ', '.join(map(format_counts, warns)))

        if no_datas:
            description += "NO_DATA: An error occured while checking consistency:\n{}\n".format(
                '\n'.join(
                    "For {server} exception raised:\n{error}".format(server=server, error=error)
                    for server, error in no_datas
                )
            )

        if oks:
            description += "OK: Servers {} are consistent with master\n".format(', '.join(oks))

        return description

    oks, no_datas, warns, crits = [], [], [], []
    aggr_status = 'OK'

    tasks = [(ns_server, cluster, generation_time, record_sets, recheck) for ns_server in NS_SERVERS]
    # multiprocessing.dummy.Pool is a thread pool. The context manager tears it
    # down once map() has returned every result.
    with Pool() as pool:
        results = pool.map(check_ns_server, tasks)

    for ns_server, status, diffs, timeouts, skips, total, error in results:
        if status == 'CRIT':
            aggr_status = 'CRIT'
            crits.append((ns_server, diffs, timeouts, skips, total))
        elif status == 'WARN':
            # Never downgrade an aggregate that is already CRIT.
            if aggr_status != 'CRIT':
                aggr_status = 'WARN'
            warns.append((ns_server, diffs, timeouts, skips, total))
        elif status == 'NO_DATA':
            if aggr_status not in ['CRIT', 'WARN']:
                aggr_status = 'NO_DATA'
            no_datas.append((ns_server, error))
        elif status == 'OK':
            oks.append(ns_server)

    return aggr_status, make_description(oks, warns, no_datas, crits)


@retry.retry(tries=5, delay=2, backoff=2)
def juggler_notify(service, cluster, status, description):
    """Push one monitoring event to Juggler for host ``yp-<cluster>.yandex.net``.

    Raises ``requests.HTTPError`` on an error HTTP status (via
    ``raise_for_status``). The ``retry`` decorator re-invokes the call on failure
    (``tries=5``).
    """
    event = {
        "description": description,
        "host": f"yp-{cluster}.yandex.net",
        "instance": "",
        "service": service,
        "status": status,
    }
    payload = {"source": "yp_dns_monitoring", "events": [event]}

    session = util.get_retry_session()
    response = session.post("http://juggler-push.search.yandex.net/events", json=payload, timeout=10)
    response.raise_for_status()


@retry.retry(tries=5, delay=1, backoff=2)
def get_config():
    """Fetch the YP DNS daemon config from the first reachable YP DNS instance.

    For each datacenter in ``YP_DNS_DATACENTERS``, resolves the
    ``<DC>_YP_DNS.mtn`` endpoint set in the ``xdc`` cluster. It then returns the
    JSON body of ``http://<fqdn>:9091/config`` from the first endpoint that
    answers successfully.

    Raises:
        Exception: if no endpoint returned a usable config. The ``retry``
            decorator re-runs the whole lookup (``tries=5``) before giving up.
    """
    resolver = Resolver(client_name='yp_dns_monitoring', timeout=5)
    port = "9091"

    for datacenter in YP_DNS_DATACENTERS:
        request = api_pb2.TReqResolveEndpoints()
        request.cluster_name = 'xdc'
        request.endpoint_set_id = '{}_YP_DNS.mtn'.format(datacenter)

        try:
            result = resolver.resolve_endpoints(request)
            logger.info(f"resolve_endpoints {request.endpoint_set_id} performed successfully")
        except Exception as e:
            logger.exception(e)
            continue

        for endpoint in result.endpoint_set.endpoints:
            url = "http://{fqdn}:{port}/config".format(fqdn=endpoint.fqdn, port=port)

            try:
                # Without a timeout, a hung instance would block forever instead of
                # falling through to the next endpoint.
                response = requests.get(url, timeout=10)
                # Don't accept an error page as a config.
                response.raise_for_status()
                config = response.json()
            except (requests.exceptions.RequestException, ValueError) as e:
                # ValueError covers a non-JSON body on requests versions whose JSON
                # decode error is not a RequestException subclass.
                logger.exception(e)
                continue

            logger.info(f"requests.get {url} performed successfully")
            return config

    raise Exception("Failed to get config")


def get_zones(config, cluster):
    """Build the zone lookup table used to attribute record sets for ``cluster``.

    Every zone in ``config['Zones']`` whose ``YPClusters`` contains ``cluster`` maps
    to True. The empty name maps to False, so that ``get_deepest_match`` (used by
    full_scan_cluster) reports ``found=False`` when no configured zone matches.
    """
    dns_zones = dns.namedict.NameDict()
    dns_zones[dns.name.empty] = False

    served_zone_names = (zone['Name'] for zone in config['Zones'] if cluster in zone['YPClusters'])
    for zone_name in served_zone_names:
        dns_zones[dns.name.from_text(zone_name)] = True

    return dns_zones


def read_record_sets(path):
    """Load the JSON record-set selection stored at ``path``.

    Reads under the ``<path>.lock`` file lock, which write_record_sets also uses.
    Returns ``{}`` and logs a warning if the file cannot be read or parsed.
    """
    with FileLock(path + ".lock"):
        try:
            # `with` closes the handle; the previous bare open() leaked it.
            with open(path, 'r') as read_file:
                return json.load(read_file)
        except Exception as e:
            logger.warning("Failed to read json from file {}: {}".format(path, e))
            return {}


def write_record_sets(path, result):
    """Save ``result`` as indented JSON to ``path``.

    Runs under the ``<path>.lock`` file lock, which read_record_sets also uses. The
    JSON is written to ``<path>.tmp`` first and then moved over ``path`` with
    ``os.replace``. A failure while serialising therefore leaves the previous file
    intact rather than truncated. Failures are logged as warnings, not raised.
    """
    tmp_path = path + ".tmp"
    with FileLock(path + ".lock"):
        try:
            with open(tmp_path, 'w') as write_file:
                json.dump(result, write_file, indent=4)
            os.replace(tmp_path, path)
        except Exception as e:
            logger.warning("Failed to write json to file {}: {}".format(path, e))
            # Best-effort removal of the partial temp file; nothing more to do if it fails.
            try:
                os.remove(tmp_path)
            except OSError:
                pass


def parse_args(argv):
    """Parse the command line.

    ``--clusters`` is split on commas into a list. ``--yp-token``, when given, is
    exported as the ``YP_TOKEN`` environment variable.
    """
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest='action', help='Choose monitoring')

    for flag in ('--clusters', '--yp-token'):
        parser.add_argument(flag)
    parser.add_argument('--verbose', action='store_true')

    diff_parser = subparsers.add_parser('diff_with_yp', help='diff_with_yp monitoring')
    diff_parser.add_argument('--mode', choices=['slow_preload_batch', 'fast_batch_check'], help='Run on preloaded batch or preload batch for further runs')
    diff_parser.add_argument('--quantity-file', default=QUANTITY_JSON_PATH, help='File for slow_preload_bath mode to calculate probabilities.')
    diff_parser.add_argument('--records-file', default=SAVED_RECORD_SETS_JSON_PATH, help='File to save batch.')
    diff_parser.add_argument('--recheck', action='store_true', help='Should select records from MASTER again if found diff. Runs WatchObjects and skips diffs if not set.')

    duplicates_parser = subparsers.add_parser('count_duplicates', help='count_duplicates monitoring')
    duplicates_parser.add_argument('--limit', type=int, help='Max record sets with duplicates to search (all by default)')

    parsed = parser.parse_args(argv)

    parsed.clusters = None if parsed.clusters is None else parsed.clusters.split(',')

    token = parsed.yp_token
    if token is not None:
        os.environ['YP_TOKEN'] = token

    return parsed


def main(argv):
    """Run the selected monitoring action for every cluster.

    * ``diff_with_yp --mode slow_preload_batch``: scan each cluster and save a sampled
      record-set selection and per-zone counts to ``--records-file`` and
      ``--quantity-file``.
    * ``diff_with_yp --mode fast_batch_check``: compare the saved selection with live
      YP DNS answers and push the status to Juggler.
    * ``count_duplicates``: count duplicate records and push the status to Juggler.

    Clusters come from ``--clusters``, or from ``util.get_current_clusters()`` when
    it is not given.
    """
    args = parse_args(argv)

    logger.setLevel(logging.DEBUG if args.verbose else logging.INFO)

    # `args.mode` only exists for the diff_with_yp subcommand, so the action check
    # must come first; short-circuiting protects the other actions.
    is_diff = args.action == 'diff_with_yp'
    # Fixed mode name: the former 'slow_preaload_batch' typo never matched, so the
    # quantity file was never loaded and probability sampling never kicked in.
    is_preload = is_diff and args.mode == 'slow_preload_batch'
    is_fast_check = is_diff and args.mode == 'fast_batch_check'

    # The zone config is only needed to attribute record sets while preloading.
    config = get_config() if is_preload else None

    clusters = args.clusters if args.clusters is not None else util.get_current_clusters()

    total_quantity = Quantity(args.quantity_file) if is_preload else Quantity()
    total_saved_record_sets = read_record_sets(args.records_file) if is_fast_check else {}

    for cluster in clusters:
        logger.info(f"Checking YP DNS records for zone {cluster}.yp-c.yandex.net")

        if is_preload:
            if cluster not in total_quantity.data:
                total_quantity.data[cluster] = {}

            dns_zones = get_zones(config, cluster)
            creation_time, new_quantity, new_saved_record_sets = full_scan_cluster(cluster, dns_zones, quantity=total_quantity.data[cluster])
            total_quantity.data[cluster] = new_quantity
            total_saved_record_sets[cluster] = {'creation_time': creation_time, 'saved_records': new_saved_record_sets}

            logger.info(f"Selected record sets from {cluster} cluster")
        elif is_diff:
            if cluster not in total_saved_record_sets:
                logger.warning(f"No saved record sets for {cluster}.yp-c.yandex.net yet")
                # Skip only this cluster; returning here would leave all later
                # clusters unchecked.
                continue

            status, description = check_cluster_diff_with_yp(cluster, total_saved_record_sets[cluster]['creation_time'], total_saved_record_sets[cluster]['saved_records'], args.recheck)

            logger.info(f"Result status: {status}")
            logger.info(f"Additional info:\n{description}")
        elif args.action == 'count_duplicates':
            status, description = check_cluster_count_duplicates(cluster, args.limit)

            logger.info(f"Result status: {status}")
            logger.info(f"Additional info:\n{description}")

        if is_fast_check:
            juggler_notify('yp_dns_master_sync_lag_monitoring', cluster, status, description)
        elif args.action == 'count_duplicates':
            juggler_notify('yp_dns_record_set_duplicates_monitoring', cluster, status, description)

    if is_preload:
        total_quantity.write(args.quantity_file)
        write_record_sets(args.records_file, total_saved_record_sets)


if __name__ == '__main__':
    # Run as a script: forward the CLI arguments without the program name.
    cli_args = sys.argv[1:]
    main(cli_args)
