import argparse
import ipaddress
import logging
import random
import requests
import sys
import time

from enum import Enum

from infra.yp_dns_api.bridge.api import api_pb2
from infra.yp_dns_api.client.client import YpDnsApiBridgeClient
from infra.yp_dns_api.monitoring.libs.solomon import SolomonClient

from infra.libs.yp_dns.dynamic_zones.zones_manager_service.api import api_pb2 as zm_api_pb2
from infra.libs.yp_dns.zone.protos import config_pb2 as zone_pb2

from yp.client import YpClient, find_token


# Parent domain of the test zones; run_checks names them zone-<N>.<apex>.
DEFAULT_ZONE_APEX = "zm-monitoring.test-ypdnsapi.yandex.net"
# YP clusters every test zone is configured with.
DEFAULT_CLUSTERS = ["sas-test", "man-pre"]

DEFAULT_NS_ADDRESS = "ns1.yp-dns.yandex.net:53"
DEFAULT_BRIDGE_GRPC_ADDRESS = "dns-api-bridge.yp.yandex.net:8081"

DEFAULT_ZONES_NUMBER = 5
# Seconds allowed per service type for a new zone to appear in zone listings.
DEFAULT_MAX_ZONE_CREATION_TIME = 30.0
DEFAULT_RECORD_SETS_NUMBER = 15
# Presumably seconds as well; not used by the checks implemented so far.
DEFAULT_MAX_REPLICATION_TIME = 15.0


# Service types queried via ListZones, grouped into ordered stages. When waiting for a
# zone to appear, check_wait_for_zone_in_listing treats a zone that shows up in a later
# stage before the previous stage as an invalid listing.
SERVICE_TYPES = [
    ["BRIDGE", "REPLICATOR"],
    ["DNS-API"],
]


class InvalidZonesListing(Exception):
    """Zone listings of different service types contradict the expected stage order."""

    JUGGLER_SERVICE = "zones_listing"


class TooLongStatusChangeForZone(Exception):
    """A zone took longer than allowed to reach its target state in zone listings."""

    JUGGLER_SERVICE = "zone_change_status_time"


class UnclassifiedAssertionError(Exception):
    """Wraps a failed ``assert`` raised while running the checks."""

    JUGGLER_SERVICE = "assert_error"


class UpdateRecordsError(Exception):
    """The bridge reported a failed record update or removal."""

    JUGGLER_SERVICE = "zone_update_records"


# Failures run_checks reports to Juggler, each under its class's JUGGLER_SERVICE;
# on success an OK event is pushed for every one of them, in this order.
ERROR_EXCEPTIONS = (
    InvalidZonesListing, TooLongStatusChangeForZone,
    UnclassifiedAssertionError, UpdateRecordsError,
)


# Module-wide logger: getLogger() with no name returns the root logger. A stderr
# StreamHandler with a timestamped format is attached at import time; main() sets
# the level.
logger = logging.getLogger()
sh = logging.StreamHandler()
sh.setFormatter(logging.Formatter('[%(asctime)s] [%(levelname)s] %(message)s'))
logger.addHandler(sh)


# Solomon client used by check_wait_for_zone_in_listing to push per-service-type
# zone listing timings.
solomon = SolomonClient(
    project="yp_dns_api",
    cluster="external_monitoring",
    service="e2e",
)


def unique_records(records):
    """Return ``records`` de-duplicated by their (type, class, data) triple.

    'class' is optional, as in ``add_records``. Records without it are keyed with a
    ``None`` class instead of raising KeyError.

    When several records share a key, the last one wins but keeps the position of the
    first. ``None`` or an empty input yields an empty list.
    """
    if not records:
        return []

    by_key = {}
    for record in records:
        by_key[(record['type'], record.get('class'), record['data'])] = record
    return list(by_key.values())


def list_record_sets(yp_client, zone):
    """Return every ``dns_record_set`` labelled with ``zone``, selecting /meta/id and /spec/records.

    Objects are fetched in pages of 1000, all read at one timestamp generated up front.
    Pagination uses the continuation token from the previous page and stops after the
    first page that is not full.
    """
    page_size = 1000
    snapshot_ts = yp_client.generate_timestamp()

    collected = []
    token = None
    page_is_full = True
    while page_is_full:
        page = yp_client.select_objects(
            'dns_record_set',
            selectors=['/meta/id', '/spec/records'],
            filter='[/labels/zone] = "{}"'.format(zone),
            limit=page_size,
            timestamp=snapshot_ts,
            options={'continuation_token': token},
            enable_structured_response=True,
        )
        token = page['continuation_token']
        rows = page['results']
        collected.extend(rows)
        page_is_full = len(rows) >= page_size

    return collected


def update_records(bridge_client, request):
    """Send ``request`` through ``bridge_client.update_records`` and return the response.

    In this file the request is an ``api_pb2.TReqUpdateRecords`` built by add_records
    or remove_records. The request and the response are both logged at INFO level.
    """
    for line in ("Updating records", f"Request: {request}"):
        logger.info(line)
    result = bridge_client.update_records(request)
    logger.info(f"Response: {result}")
    return result


def add_records(bridge_client, records_by_rs_id):
    """Upsert every given record with a single bridge UpdateRecords request.

    ``records_by_rs_id`` maps a record set FQDN to a list of record dicts. Each dict
    needs 'type' (an ``api_pb2.ERecordType`` name) and 'data'; 'ttl' and 'class' are
    copied into the request only when present.
    """
    request = api_pb2.TReqUpdateRecords()
    for fqdn, record_list in records_by_rs_id.items():
        for rec in record_list:
            upsert = request.requests.add().update
            upsert.fqdn = fqdn
            upsert.type = api_pb2.ERecordType.Value(rec['type'])
            upsert.data = rec['data']
            if 'ttl' in rec:
                upsert.ttl = rec['ttl']
            if 'class' in rec:
                upsert.class_ = rec['class']
    return update_records(bridge_client, request)


def remove_records(bridge_client, records_by_rs_id):
    """Remove every given record with a single bridge UpdateRecords request.

    ``records_by_rs_id`` maps a record set FQDN to a list of record dicts. Only each
    record's 'type' (an ``api_pb2.ERecordType`` name) and 'data' are sent.
    """
    request = api_pb2.TReqUpdateRecords()
    for fqdn, record_list in records_by_rs_id.items():
        for rec in record_list:
            removal = request.requests.add().remove
            removal.fqdn = fqdn
            removal.type = api_pb2.ERecordType.Value(rec['type'])
            removal.data = rec['data']
    return update_records(bridge_client, request)


def cleanup_zone(bridge_client, yp_clients, zone_name, clusters):
    """Remove every record of ``zone_name`` stored in any of ``clusters``.

    Record sets labelled with the zone are read directly from YP in each cluster and
    merged by record set id. They are de-duplicated with ``unique_records`` and removed
    with a single bridge UpdateRecords request. Nothing is sent when there are no
    records at all.

    Raises:
        UpdateRecordsError: if any per-record response is not a successful removal
            reported for one of ``clusters``.
    """
    merged_record_sets = {}
    for cluster in clusters:
        # Each result holds the selected values in selector order (/meta/id,
        # /spec/records), each wrapped as {'value': ...}.
        for rs_id, records in list_record_sets(yp_clients[cluster], zone_name):
            merged_record_sets.setdefault(rs_id['value'], []).extend(records['value'] or [])

    records_to_remove = {
        rs_id: unique_records(records)
        for rs_id, records in merged_record_sets.items()
    }
    # Record sets without records add no entries to the request; skip the call
    # entirely when nothing is left to remove.
    if not any(records_to_remove.values()):
        return

    response = remove_records(bridge_client, records_to_remove)
    for record_response in response.responses:
        # Explicit checks instead of `assert`: they survive `python -O` and produce a
        # meaningful message for the Juggler event.
        kind = record_response.WhichOneof('response')
        if kind != 'remove':
            problem = f"unexpected response type {kind!r}"
        elif record_response.remove.status != api_pb2.TRspRemoveRecord.ERemoveRecordStatus.OK:
            problem = (f"status {record_response.remove.status}, "
                       f"error message {record_response.remove.error_message!r}")
        elif record_response.remove.error_message:
            problem = f"error message {record_response.remove.error_message!r}"
        elif record_response.remove.cluster not in clusters:
            problem = f"unexpected cluster {record_response.remove.cluster!r}"
        else:
            continue
        logger.error(f"Failed to remove records in zone {zone_name}: {problem}")
        raise UpdateRecordsError(f"Failed to remove records in {zone_name}. Error: {problem}")


def wait_for_zone_empty(yp_clients, zone_name, clusters, timeout=None, poll_interval=1.0):
    """Block until no cluster in ``clusters`` has record sets labelled with ``zone_name``.

    Args:
        yp_clients: mapping from cluster name to YP client.
        zone_name: zone whose record sets are polled.
        clusters: clusters to check; polling stops at the first non-empty one.
        timeout: give up after this many seconds. ``None`` (the default) waits
            indefinitely.
        poll_interval: seconds to sleep between polls.

    Raises:
        TimeoutError: if ``timeout`` is set and the zone still has record sets when it
            expires.
    """
    start_time = time.monotonic()
    while any(list_record_sets(yp_clients[cluster], zone_name) for cluster in clusters):
        if timeout is not None and time.monotonic() - start_time >= timeout:
            raise TimeoutError(f"Zone {zone_name} still has record sets after {timeout:.1f}s")
        time.sleep(poll_interval)


def check_fill_zone(bridge_client, zone_name, record_sets_num, clusters):
    """Add ``record_sets_num`` record sets with one random AAAA record each to ``zone_name``.

    FQDNs look like ``fqdn-<i>.<zone_name>``, with ``i`` zero-padded to the number of
    digits of ``record_sets_num``. All records are sent in one bridge UpdateRecords
    request.

    Returns:
        dict mapping each generated FQDN to the list of record dicts that was sent.

    Raises:
        UpdateRecordsError: if any per-record response is not a successful update
            reported for one of ``clusters``.
    """
    logger.info(f"Filling {zone_name} with {record_sets_num} record sets")
    width = len(str(record_sets_num))
    records_by_rs_id = {}
    for i in range(record_sets_num):
        fqdn = 'fqdn-{:0{}}.{}'.format(i, width, zone_name)
        address = str(ipaddress.IPv6Address(random.randint(0, 2**128 - 1)))
        records_by_rs_id[fqdn] = [{'type': 'AAAA', 'data': address}]

    if records_by_rs_id:
        response = add_records(bridge_client, records_by_rs_id)
        for record_response in response.responses:
            # Explicit checks instead of `assert`: they survive `python -O` and produce
            # a meaningful message for the Juggler event.
            kind = record_response.WhichOneof('response')
            if kind != 'update':
                problem = f"unexpected response type {kind!r}"
            elif record_response.update.status != api_pb2.TRspUpdateRecord.EUpdateRecordStatus.OK:
                problem = (f"status {record_response.update.status}, "
                           f"error message {record_response.update.error_message!r}")
            elif record_response.update.error_message:
                problem = f"error message {record_response.update.error_message!r}"
            elif record_response.update.cluster not in clusters:
                problem = f"unexpected cluster {record_response.update.cluster!r}"
            else:
                continue
            logger.error(f"Failed to add records in zone {zone_name}: {problem}")
            raise UpdateRecordsError(f"Failed to update records in {zone_name}. Error: {problem}")

    logger.info(f"Done filling {zone_name}. Added {record_sets_num} record sets.")

    return records_by_rs_id


def list_zones(bridge_client, service_type, log_level=logging.INFO):
    """Call ListZones for ``service_type`` through the bridge and return the raw response.

    Only the response status and the number of zones are logged, at ``log_level``.
    Polling callers pass ``logging.NOTSET``, which is below the DEBUG level set in
    main, so their calls produce no log output.
    """
    for line in ("Listing zones", f"Request: (service_type: {service_type})"):
        logger.log(log_level, line)
    response = bridge_client.list_zones(service_type)
    logger.log(log_level, f"Response: (status: {response.status}, zones_num: {len(response.zones)})")
    return response


def create_zone(bridge_client, zone_id, clusters):
    """Ask the bridge to create zone ``zone_id`` served from ``clusters``; return its response.

    The zone uses the fixed nameserver ``ns1.testns.yandex.net`` both as primary
    nameserver and as its only NS entry.
    """
    nameserver = "ns1.testns.yandex.net"
    config = zone_pb2.TZoneConfig()
    config.Name = zone_id
    config.PrimaryNameserver = nameserver
    config.Nameservers.append(nameserver)
    config.YPClusters.extend(clusters)

    for line in (f"Creating zone {zone_id}", f"Request: (zone_config:\n{config})"):
        logger.info(line)
    result = bridge_client.create_zone(config)
    logger.info(f"Response: {result}")
    return result


def remove_zone(bridge_client, zone_id):
    """Ask the bridge to remove zone ``zone_id`` and return its response (logged at INFO)."""
    for line in (f"Removing zone {zone_id}", f"Request: (zone_id: {zone_id})"):
        logger.info(line)
    result = bridge_client.remove_zone(zone_id)
    logger.info(f"Response: {result}")
    return result


class ZoneStatusInListing(Enum):
    """State of a zone in ListZones responses that check_wait_for_zone_in_listing waits for."""

    IN = 1   # the zone is listed
    OUT = 2  # the zone is not listed


def check_wait_for_zone_in_listing(bridge_client, zone_name,
                                   target_state: ZoneStatusInListing, max_time,
                                   poll_interval=0.5, timeout=None):
    """Poll ListZones for every service type until ``zone_name`` reaches ``target_state``.

    For each service type in SERVICE_TYPES, the time until the zone was first seen in
    ``target_state`` is pushed to Solomon as ``zone_in_listing.time.seconds`` and
    compared against ``max_time``.

    Args:
        bridge_client: bridge client used for the ListZones calls.
        zone_name: zone to look for.
        target_state: IN waits for the zone to appear, OUT for it to disappear.
        max_time: maximum acceptable number of seconds per service type.
        poll_interval: seconds to sleep between polling rounds.
        timeout: stop polling after this many seconds. Defaults to ``3 * max_time``, so
            changes that are slow but do complete are still measured.

    Raises:
        AssertionError: if ListZones returns a non-OK status.
        InvalidZonesListing: (IN only) if a service type of a later stage lists the zone
            while some service type of the previous stage does not.
        TooLongStatusChangeForZone: if a service type reached ``target_state`` after
            ``max_time``, or the target state was not reached before ``timeout``.
    """
    if timeout is None:
        timeout = 3 * max_time
    all_service_types = [st for service_types in SERVICE_TYPES for st in service_types]

    logger.info(f"Waiting for state {target_state.name} in zones listing for {zone_name}")
    start_time = time.monotonic()
    time_by_service_type = {}
    while True:
        reached_by_service_type = {}
        # Query stages from last to first. Suppose a later stage already shows the change
        # at t0 while an earlier stage still does not at t1 > t0; then the ordering is
        # really broken. Querying first-to-last could instead flag a change that merely
        # propagated to every stage between two calls.
        for service_types in reversed(SERVICE_TYPES):
            for service_type in service_types:
                list_zones_resp = list_zones(bridge_client, service_type, log_level=logging.NOTSET)
                assert list_zones_resp.status == zm_api_pb2.TRspListZones.EStatus.OK, (
                    f"ListZones for {service_type} returned status {list_zones_resp.status}"
                )
                listed = any(zone.config.Name == zone_name for zone in list_zones_resp.zones)
                reached = listed if target_state == ZoneStatusInListing.IN else not listed
                reached_by_service_type[service_type] = reached
                if reached and service_type not in time_by_service_type:
                    time_by_service_type[service_type] = time.monotonic()

        # A stage is reached only when every one of its service types is, not just the
        # last one queried.
        stage_reached = [
            all(reached_by_service_type[st] for st in service_types)
            for service_types in SERVICE_TYPES
        ]

        if target_state == ZoneStatusInListing.IN:
            for stage in range(1, len(SERVICE_TYPES)):
                later_stage_has_zone = any(reached_by_service_type[st] for st in SERVICE_TYPES[stage])
                if later_stage_has_zone and not stage_reached[stage - 1]:
                    ok_service_types = [st for st in all_service_types if reached_by_service_type[st]]
                    not_ok_service_types = [st for st in all_service_types if not reached_by_service_type[st]]
                    raise InvalidZonesListing(
                        f"Zone {zone_name} is in status {target_state.name} in listing for service types {', '.join(ok_service_types)}, "
                        f"but not for {', '.join(not_ok_service_types)}"
                    )

        completed = all(stage_reached)
        # Bounded wait: without a deadline a zone that never changes state would hang
        # the monitor forever instead of raising an alert.
        if completed or time.monotonic() - start_time > timeout:
            break
        time.sleep(poll_interval)

    elapsed_by_service_type = {
        service_type: seen_at - start_time
        for service_type, seen_at in time_by_service_type.items()
    }
    if completed:
        logger.info(f"Done waiting for {zone_name} to be in status {target_state.name}")
    else:
        logger.info(f"Gave up waiting for {zone_name} to be in status {target_state.name} after {timeout:.1f}s")
    logger.info(
        f"Time spent before {zone_name} appeared in status {target_state.name} in listing response by service type:\n"
        + '\n'.join(f"{st}: {elapsed:.1f}s" for st, elapsed in elapsed_by_service_type.items())
    )

    for service_type, elapsed in elapsed_by_service_type.items():
        solomon.set_value("zone_in_listing.time.seconds", elapsed, {"target_state": target_state.name, "service_type": service_type})

    exceeded = [
        f"{st}: {elapsed_by_service_type[st]:.1f}s"
        for st in all_service_types
        if st in elapsed_by_service_type and elapsed_by_service_type[st] > max_time
    ]
    if not completed:
        exceeded.extend(
            f"{st}: not in status {target_state.name} after {timeout:.1f}s"
            for st in all_service_types
            if not reached_by_service_type[st]
        )
    if exceeded:
        raise TooLongStatusChangeForZone(
            f"Too long status changing to {target_state.name} for zone {zone_name}. "
            f"Time spent by service type: {chr(10).join(exceeded)}"
        )


def cleanup(bridge_client, yp_clients, zones, args):
    """Remove the test ``zones`` and their records left over from previous runs.

    For every zone, records are removed from all ``args.clusters``. If the BRIDGE zone
    listing still contains the zone, the function first waits until YP shows no record
    sets for it. It then removes the zone and checks, with a 10 s limit, that the zone
    disappears from every zone listing.

    Raises:
        AssertionError: if ListZones or RemoveZone returns a non-OK status.
    """
    list_zones_resp = list_zones(bridge_client, "BRIDGE")
    assert list_zones_resp.status == zm_api_pb2.TRspListZones.EStatus.OK, (
        f"ListZones for BRIDGE returned status {list_zones_resp.status}"
    )
    existing_zones = {zone.config.Name for zone in list_zones_resp.zones}

    logger.info("Cleaning up all testing zones")
    for zone_name in zones:
        logger.info(f"Cleaning up zone {zone_name}")

        # Records are removed even for zones the bridge no longer lists, presumably to
        # catch record sets left behind in YP.
        cleanup_zone(bridge_client, yp_clients, zone_name, args.clusters)

        if zone_name in existing_zones:
            # Presumably RemoveZone requires the zone to be empty in every cluster.
            wait_for_zone_empty(yp_clients, zone_name, args.clusters)
            remove_zone_resp = remove_zone(bridge_client, zone_name)
            assert remove_zone_resp.status == zm_api_pb2.TRspRemoveZone.EStatus.OK, (
                f"RemoveZone for {zone_name} returned status {remove_zone_resp.status}"
            )
            check_wait_for_zone_in_listing(bridge_client, zone_name, ZoneStatusInListing.OUT, 10)
        logger.info(f"Cleanup has finished for zone {zone_name}")

    logger.info("Cleanup has finished for all testing zones")


def check_create_zone(bridge_client, zone_name, clusters):
    """Create ``zone_name`` served from ``clusters`` and check that the bridge accepted it.

    Raises:
        AssertionError: if CreateZone returns a non-OK status. The message names the
            zone and the status, so the resulting alert is informative.
    """
    create_zone_resp = create_zone(bridge_client, zone_name, clusters)
    assert create_zone_resp.status == zm_api_pb2.TRspCreateZone.EStatus.OK, (
        f"CreateZone for {zone_name} returned status {create_zone_resp.status}"
    )


def juggler_notify(service, status, description):
    """Push one event for host ``e2e.yp_dns_api`` and ``service`` to the Juggler push API.

    Raises:
        requests.HTTPError: if Juggler answers with an HTTP error status.
    """
    logger.info(f'Push juggler event.\nStatus: {status}\nDescription:\n{description}')
    event = {
        "status": status,
        "description": description,
        "host": "e2e.yp_dns_api",
        "service": service,
        "instance": "",
    }
    payload = {"source": "e2e_monitoring", "events": [event]}
    response = requests.post("http://juggler-push.search.yandex.net/events", json=payload, timeout=10)
    response.raise_for_status()


def run_checks(args):
    """Run the YP DNS API end-to-end scenario and report the result to Juggler.

    Steps:
      1. Remove leftovers of previous runs (``cleanup``).
      2. For every test zone: create it, wait until it shows up in every service type's
         zone listing, and fill it with record sets.
    With ``args.only_cleanup`` only step 1 runs.

    On success, an OK event is pushed for the Juggler service of every exception in
    ERROR_EXCEPTIONS. If a step fails with one of them, a CRIT event is pushed for that
    exception's service and the exception is re-raised; a failed ``assert`` is first
    wrapped into UnclassifiedAssertionError. ``--only-cleanup`` runs push no events at
    all.
    """
    bridge_client = YpDnsApiBridgeClient(args.bridge_address)
    token = find_token()
    yp_clients = {
        cluster: YpClient(cluster, config={'token': token})
        for cluster in args.clusters
    }

    zones = [f"zone-{zone_idx}.{args.zone_apex}" for zone_idx in range(args.zones_num)]

    try:
        try:
            # Cleanup is itself checked (record removal, zone removal time), so it runs
            # inside the error handling. Otherwise its failures would abort the run
            # without any Juggler event.
            cleanup(bridge_client, yp_clients, zones, args)
            if args.only_cleanup:
                return

            for zone_name in zones:
                check_create_zone(bridge_client, zone_name, args.clusters)
                check_wait_for_zone_in_listing(bridge_client, zone_name, ZoneStatusInListing.IN, args.max_zone_creation_time)
                check_fill_zone(bridge_client, zone_name, args.record_sets_num, args.clusters)
                # TODO: check DNS (args.ns_address) and track the replication time of the
                # records against args.max_replication_time.
        except AssertionError as e:
            raise UnclassifiedAssertionError(f"Assert statement failed: {e}") from e
    except ERROR_EXCEPTIONS as e:
        logger.exception("An error occurred during checks")
        # --only-cleanup runs send no Juggler events, so a failed cleanup-only run
        # does not leave a CRIT behind.
        if not args.only_cleanup:
            juggler_notify(
                service=e.JUGGLER_SERVICE,
                status='CRIT',
                description=f'YP DNS API E2E monitoring:\n{e}',
            )
        raise

    for exception_type in ERROR_EXCEPTIONS:
        juggler_notify(
            service=exception_type.JUGGLER_SERVICE,
            status='OK',
            description='YP DNS API E2E monitoring: OK',
        )


def parse_args(argv):
    """Parse the command line ``argv`` (without the program name).

    ``--clusters`` is a comma-separated list. It is returned as a list of cluster names,
    with surrounding whitespace and empty items (e.g. from a trailing comma) dropped.
    An empty result is a usage error.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--zone-apex', default=DEFAULT_ZONE_APEX,
                        help='Zone apex')
    parser.add_argument('--zones-num', type=int, default=DEFAULT_ZONES_NUMBER,
                        help='Number of zones to create')
    parser.add_argument('--clusters', default=','.join(DEFAULT_CLUSTERS),
                        help='YP clusters for created zones')
    parser.add_argument('--max-zone-creation-time', type=float, default=DEFAULT_MAX_ZONE_CREATION_TIME,
                        help='Max acceptable time for zone creation')
    parser.add_argument('--record-sets-num', type=int, default=DEFAULT_RECORD_SETS_NUMBER,
                        help='Number of record sets to generate in created zones')
    parser.add_argument('--max-replication-time', type=float, default=DEFAULT_MAX_REPLICATION_TIME,
                        help='Max acceptable time for records replication')
    parser.add_argument('--only-cleanup', action='store_true',
                        help='Only remove all test zones and record sets and exit')
    parser.add_argument('--ns-address', default=DEFAULT_NS_ADDRESS,
                        help='Nameserver address')
    parser.add_argument('--bridge-address', default=DEFAULT_BRIDGE_GRPC_ADDRESS,
                        help='YP DNS API Bridge gRPC address')

    args = parser.parse_args(argv)
    # "a, b" or "a,b," must not produce " b" or "" as cluster names.
    args.clusters = [cluster.strip() for cluster in args.clusters.split(',') if cluster.strip()]
    if not args.clusters:
        parser.error('--clusters must name at least one YP cluster')
    return args


def main(argv):
    """Parse ``argv``, switch the module logger to DEBUG and run the end-to-end checks."""
    parsed_args = parse_args(argv)
    logger.setLevel(logging.DEBUG)
    run_checks(parsed_args)


if __name__ == '__main__':
    # Pass the command-line arguments without the program name (argv[0]).
    main(sys.argv[1:])
