import argparse
import logging
import os
import requests
import sys
import retry

from requests.packages.urllib3 import Retry

from infra.yp_service_discovery.api import api_pb2
from infra.yp_service_discovery.python.resolver.resolver import Resolver
from yp.client import YpClient, find_token
from yp.common import YtResponseError

from sandbox.projects.common.nanny.client import NannyClient


# Client name passed to every Service Discovery Resolver built in this file.
SD_CLIENT_NAME = 'yp.monitoring.chaos_service_discovery'
# Endpoint set id resolved by all chaos-service checks below.
CHAOS_ESID = 'chaos-service-slave'

# Nanny services whose current instances are used as SD backends.
SD_NANNY_SERVICES = [
    'sas_yp_service_discovery',
    'man_yp_service_discovery',
    'vla_yp_service_discovery',
    'msk_yp_service_discovery',
]

# Hostnames returned when the 'balancers' address type is requested.
SD_BALANCERS = [
    'sas.sd.yandex.net',
    'man.sd.yandex.net',
    'vla.sd.yandex.net',
    'msk.sd.yandex.net',
    'sd.yandex.net',
]

# Port used for the HTTP /ping requests sent to SD instances.
SD_HTTP_PORT = '8080'
# Port appended to every address that is handed to the gRPC resolver.
SD_GRPC_PORT = '8081'

# Address kinds understood by get_sd_addresses().
ADDR_TYPES = ['balancers', 'backends']


# Root logger (getLogger() without a name) with one stream handler attached at
# import time; its level is chosen later in main() from the --verbose flag.
logger = logging.getLogger()
sh = logging.StreamHandler()
sh.setFormatter(logging.Formatter('[%(asctime)s] [%(levelname)s] %(message)s'))
logger.addHandler(sh)


def get_retry_session(retries=5, backoff_factor=2, status_forcelist=(500, 502, 503, 504)):
    """Build a ``requests`` session that retries plain-HTTP requests.

    Args:
        retries: value used for the total, read and connect retry counters.
        backoff_factor: backoff factor passed to urllib3's ``Retry``.
        status_forcelist: HTTP status codes passed to ``Retry`` as the
            statuses that force a retry.

    Returns:
        A new ``requests.Session`` with the retrying adapter mounted for the
        ``http://`` prefix only.
    """
    retry_policy = Retry(
        total=retries,
        connect=retries,
        read=retries,
        status_forcelist=status_forcelist,
        backoff_factor=backoff_factor,
    )
    http_session = requests.Session()
    http_session.mount('http://', requests.adapters.HTTPAdapter(max_retries=retry_policy))
    return http_session


@retry.retry(tries=5, delay=2, backoff=2)
def juggler_notify(cluster, status, description):
    """Push a single monitoring event for ``cluster`` to Juggler.

    Args:
        cluster: cluster name; becomes the event host ``yp-<cluster>.yandex.net``.
        status: event status string.
        description: free-form event description.

    Raises:
        requests.HTTPError: from ``raise_for_status()`` when the push endpoint
            answers with an error status (after the retry decorator gives up).
    """
    event = {
        "description": description,
        "host": "yp-{}.yandex.net".format(cluster),
        "instance": "",
        "service": "chaos_service_discovery_monitoring",
        "status": status,
    }
    payload = {
        "source": "chaos_service_discovery_monitoring",
        "events": [event],
    }

    response = get_retry_session().post(
        "http://juggler-push.search.yandex.net/events",
        json=payload,
        timeout=10,
    )
    response.raise_for_status()


def get_current_clusters():
    """Return the list of cluster names to monitor.

    The comma-separated ``YP_MONITORING_CLUSTERS`` environment variable takes
    precedence. Whitespace around each name is stripped and empty entries are
    dropped, so ``"sas, man,"`` yields ``['sas', 'man']``. When the variable
    is unset or names no cluster at all, the cluster is taken from the second
    dot-separated label of the local hostname.

    Returns:
        A non-empty list of cluster name strings.

    Raises:
        IndexError: if the hostname fallback is used and the hostname
            contains no dot.
    """
    raw_clusters = os.getenv('YP_MONITORING_CLUSTERS')
    if raw_clusters is not None:
        clusters = [name.strip() for name in raw_clusters.split(',') if name.strip()]
        # An empty or blank variable must not produce a bogus '' cluster;
        # fall through to the hostname-based default instead.
        if clusters:
            return clusters

    return [os.uname()[1].split('.')[1]]


def get_sd_addresses(addr_types=ADDR_TYPES):
    """Collect ``host:port`` gRPC addresses of Service Discovery.

    Args:
        addr_types: a single address type string or a collection of them;
            the values 'backends' and 'balancers' are recognised.

    Returns:
        A list of ``'<host>:<SD_GRPC_PORT>'`` strings: backend container
        hostnames (queried from Nanny) first, then the static balancers.

    Raises:
        KeyError: if backends are requested and ``OAUTH_NANNY`` is not set.
    """
    requested = [addr_types] if isinstance(addr_types, str) else addr_types

    hosts = []
    if 'backends' in requested:
        nanny = NannyClient('http://nanny.yandex-team.ru', os.environ['OAUTH_NANNY'])
        for service in SD_NANNY_SERVICES:
            instances = nanny.get_service_current_instances(service)['result']
            hosts.extend(instance['container_hostname'] for instance in instances)
    if 'balancers' in requested:
        hosts.extend(SD_BALANCERS)

    return ['{}:{}'.format(host, SD_GRPC_PORT) for host in hosts]


def get_sd_resolver(address):
    """Return the Resolver for ``address``, creating it on first use.

    Resolvers are memoised per address in the ``resolvers`` dict stored as an
    attribute of this function, so repeated calls reuse the same object.
    """
    try:
        cache = get_sd_resolver.resolvers
    except AttributeError:
        cache = get_sd_resolver.resolvers = {}

    try:
        return cache[address]
    except KeyError:
        resolver = Resolver(client_name=SD_CLIENT_NAME, grpc_address=address, timeout=5)
        cache[address] = resolver
        return resolver


def get_yp_client(address):
    """Return the YpClient for ``address``, creating it on first use.

    Clients are memoised per address in the ``clients`` dict stored as an
    attribute of this function; the token comes from ``find_token()``.
    """
    try:
        cache = get_yp_client.clients
    except AttributeError:
        cache = get_yp_client.clients = {}

    try:
        return cache[address]
    except KeyError:
        client = YpClient(address=address, config={'token': find_token()})
        cache[address] = client
        return client


class ResolveChaosServiceEndpointsResult:
    """Result of resolving an endpoint set, exposed through getter-only properties.

    Properties:
        timestamp: the ``timestamp`` value given at construction.
        endpoint_set: the ``endpoint_set`` value given at construction.
        prev_timestamp: the ``prev_timestamp`` value given at construction.
    """

    def __init__(self, timestamp, endpoint_set, prev_timestamp):
        self._timestamp, self._endpoint_set, self._prev_timestamp = (
            timestamp, endpoint_set, prev_timestamp
        )

    # Getter-only properties: no setters are defined for these values.
    timestamp = property(lambda self: self._timestamp)
    endpoint_set = property(lambda self: self._endpoint_set)
    prev_timestamp = property(lambda self: self._prev_timestamp)


def sort_endpoints(endpoints):
    """Sort ``endpoints`` in place, ordered by each element's ``id`` attribute."""
    def endpoint_id(item):
        return item.id

    endpoints.sort(key=endpoint_id)


@retry.retry(tries=5, delay=2, backoff=2)
def resolve_via_sd(sd_resolver, cluster, esid):
    """Resolve endpoint set ``esid`` of ``cluster`` through Service Discovery.

    Args:
        sd_resolver: object providing ``resolve_endpoints(request)``.
        cluster: cluster name placed into the request.
        esid: endpoint set id placed into the request.

    Returns:
        The resolver's response, with ``endpoint_set.endpoints`` sorted by id.
    """
    request = api_pb2.TReqResolveEndpoints(cluster_name=cluster, endpoint_set_id=esid)

    response = sd_resolver.resolve_endpoints(request)
    # Sort so the result can be compared with the YP-side endpoint list.
    sort_endpoints(response.endpoint_set.endpoints)

    message = "SD answer to (cluster={cluster}, esid={esid})\n{response}".format(
        cluster=cluster, esid=esid, response=response
    )
    logger.debug(message)
    return response


@retry.retry(tries=5, delay=2, backoff=2)
def resolve_via_yp_client(cluster, esid, timestamp=None):
    """Read the endpoints of ``esid`` directly from the YP master.

    Args:
        cluster: address used to obtain the YP client.
        esid: endpoint set id used in the selection filter.
        timestamp: timestamp to read at; a fresh one is generated when None.

    Returns:
        ResolveChaosServiceEndpointsResult holding the timestamp used, a
        ``TEndpointSet`` with endpoints sorted by id, and ``prev_timestamp``:
        the second-largest number parsed from the endpoints' fqdns, or 0
        when fewer than two fqdns were seen.
    """
    client = get_yp_client(cluster)
    if timestamp is None:
        timestamp = client.generate_timestamp()

    rows = client.select_objects(
        "endpoint",
        selectors=["/meta/id", "/spec", "/status"],
        filter="[/meta/endpoint_set_id] = '{}'".format(esid),
        timestamp=timestamp
    )

    endpoint_set = api_pb2.TEndpointSet()
    fqdn_numbers = []
    for endpoint_id, spec, status in rows:
        endpoint = endpoint_set.endpoints.add()
        endpoint.id = endpoint_id
        # Copy only the spec fields that are actually present.
        for field in ("protocol", "fqdn", "ip4_address", "ip6_address", "port"):
            if field not in spec:
                continue
            setattr(endpoint, field, spec[field])
            if field == "fqdn":
                # The first fqdn label is parsed as "<prefix>-<integer>";
                # presumably the integer is a creation timestamp - confirm.
                first_label = spec["fqdn"].split(".")[0]
                fqdn_numbers.append(int(first_label.split("-")[1]))
        if "ready" in status:
            endpoint.ready = status["ready"]

    # The set id is filled in only when at least one endpoint was found.
    if endpoint_set.endpoints:
        endpoint_set.endpoint_set_id = esid

    if len(fqdn_numbers) >= 2:
        prev_timestamp = sorted(fqdn_numbers)[-2]
    else:
        prev_timestamp = 0

    sort_endpoints(endpoint_set.endpoints)

    logger.debug("YP master answer to (cluster={cluster}, esid={esid}, ts={timestamp}):\n{endpoint_set}".format(
        cluster=cluster, esid=esid, timestamp=timestamp, endpoint_set=endpoint_set
    ))

    return ResolveChaosServiceEndpointsResult(timestamp, endpoint_set, prev_timestamp)


def check_consistency(cluster, esid, sd_resolver, addr_type):
    """Compare the SD answer for ``esid`` with the YP master at the same timestamp.

    Args:
        cluster: cluster name to resolve in.
        esid: endpoint set id to resolve.
        sd_resolver: resolver bound to the SD address under test.
        addr_type: 'balancers' makes any failure a CRIT; other values make
            it NO_DATA.

    Returns:
        ``(status, error)``: status is 'OK', 'CRIT' or 'NO_DATA'; error is the
        exception text when resolution failed, otherwise None.
    """
    try:
        sd_answer = resolve_via_sd(sd_resolver, cluster, esid)
        # Read YP at the timestamp SD reported, so both views are comparable.
        yp_answer = resolve_via_yp_client(cluster, esid, sd_answer.timestamp)
        consistent = sd_answer.endpoint_set == yp_answer.endpoint_set
    except Exception as e:
        error = str(e)
        logger.exception(e)
        return ('CRIT' if addr_type == 'balancers' else 'NO_DATA'), error

    return ('OK' if consistent else 'CRIT'), None


def ping(address, port):
    """GET ``http://<address>:<port>/ping`` and return the response.

    Raises:
        requests.HTTPError: from ``raise_for_status()`` on an error status.
    """
    response = get_retry_session().get('http://{}:{}/ping'.format(address, port), timeout=10)
    response.raise_for_status()
    return response


def check_number_of_endpoints(cluster, esid, threshold):
    """Count how many endpoints of ``esid`` answer an HTTP ping.

    Args:
        cluster: cluster to resolve the endpoint set in.
        esid: endpoint set id.
        threshold: minimal number of responding endpoints for 'OK'.

    Returns:
        ``(status, active_count, total_count, error)``; status is 'OK' or
        'CRIT' when resolution succeeded and 'NO_DATA' (with the error text)
        when it raised.
    """
    def responds(endpoint):
        # Any failure while pinging counts as "not active".
        try:
            reply = ping('[{}]'.format(endpoint.ip6_address), endpoint.port)
            return reply.status_code == 200
        except Exception as e:
            logger.debug(e)
            return False

    known, alive, error = [], [], None
    try:
        known = resolve_via_yp_client(cluster, esid).endpoint_set.endpoints
        alive = [endpoint for endpoint in known if responds(endpoint)]

        if len(alive) >= threshold:
            status = 'OK'
        else:
            status = 'CRIT'
    except Exception as e:
        error = str(e)
        logger.exception(e)
        status = 'NO_DATA'

    return status, len(alive), len(known), error


def check_endpoints_moving(cluster, esid):
    """Check that the endpoints of ``esid`` differ from those at ``prev_timestamp``.

    The endpoint set is resolved now and again at the ``prev_timestamp``
    reported by the first resolution. A ``YtResponseError`` from the second
    resolution is treated as "not moved" (presumably the timestamp is too old
    for YP to serve - confirm).

    Returns:
        ``(status, error)``: 'OK' when the two endpoint lists differ, 'CRIT'
        when they are equal or the old state is unavailable, 'NO_DATA' with
        the error text on any other failure.
    """
    try:
        current = resolve_via_yp_client(cluster, esid)
        try:
            previous = resolve_via_yp_client(cluster, esid, current.prev_timestamp)
        except YtResponseError:
            moved = False
        else:
            moved = current.endpoint_set.endpoints != previous.endpoint_set.endpoints

        status = 'OK' if moved else 'CRIT'
    except Exception as e:
        logger_error = str(e)
        logger.exception(e)
        return 'NO_DATA', logger_error

    return status, None


def check_number_of_instances(service):
    """Count how many instances of a Nanny ``service`` answer an HTTP ping.

    Args:
        service: Nanny service id whose current instances are pinged on
            ``SD_HTTP_PORT``.

    Returns:
        ``(status, active, total, error)``:
        'OK' when every instance responds, 'WARN' when exactly one does not,
        'CRIT' when more than one does not, and 'NO_DATA' (with the error
        text, and ``active``/``total`` possibly None) when the instance list
        could not be obtained.
    """
    def is_active(hostname):
        # ping() raises on connection failures and on error statuses (it
        # calls raise_for_status()), so an exception here means "this
        # instance is down", not "the whole check has no data". Without this
        # guard a single dead instance turned the result into NO_DATA and
        # the WARN/CRIT branches below could practically never be reached.
        try:
            return ping(hostname, SD_HTTP_PORT).status_code == 200
        except Exception as e:
            logger.debug(e)
            return False

    active, total, error = None, None, None
    try:
        nanny_client = NannyClient('http://nanny.yandex-team.ru', os.environ['OAUTH_NANNY'])

        instances = nanny_client.get_service_current_instances(service)['result']
        hostnames = [instance['container_hostname'] for instance in instances]
        total = len(hostnames)
        # active <= total holds by construction: it counts a subset of hostnames.
        active = sum(1 for hostname in hostnames if is_active(hostname))

        if active == total:
            status = 'OK'
        elif active + 1 == total:
            status = 'WARN'
        else:
            status = 'CRIT'
    except Exception as e:
        error = str(e)
        logger.exception(e)
        status = 'NO_DATA'

    return status, active, total, error


def check_chaos_service_consistency(cluster):
    """Run the SD-vs-YP consistency check against every SD address.

    Every address of every type in ``ADDR_TYPES`` is checked with
    ``check_consistency`` for ``CHAOS_ESID``.

    Returns:
        ``(ok_count, crit_count, no_data_count, summary)`` where the counts
        are numbers of SD addresses per outcome and ``summary`` is a
        ``"<status>: <description>;"`` string with status 'OK' or 'CRIT'.
    """
    result = {addr_type: {'OK': [], 'CRIT': [], 'NO_DATA': []} for addr_type in ADDR_TYPES}
    last_error = None

    for addr_type in ADDR_TYPES:
        outcomes = result[addr_type]
        for sd_address in get_sd_addresses(addr_type):
            logging.debug('Check {}'.format(sd_address))
            status, error = check_consistency(cluster, CHAOS_ESID, get_sd_resolver(sd_address), addr_type)
            if error is not None:
                last_error = error
            outcomes[status].append(sd_address)

    sum_oks = sum(len(result[addr_type]['OK']) for addr_type in ADDR_TYPES)
    sum_crits = sum(len(result[addr_type]['CRIT']) for addr_type in ADDR_TYPES)
    sum_no_datas = sum(len(result[addr_type]['NO_DATA']) for addr_type in ADDR_TYPES)

    if not (sum_crits + sum_no_datas):
        status, description = 'OK', "Consistent with YP master"
    else:
        logger.warning('Consistency check result:\n{}'.format(result))
        status = 'CRIT'

        backends, balancers = result['backends'], result['balancers']
        if not backends['CRIT'] and not backends['NO_DATA']:
            # Only balancers misbehave: name them explicitly.
            failing_balancers = balancers['CRIT'] + balancers['NO_DATA']
            description = "{} do not respond correctly".format(', '.join(failing_balancers))
        else:
            description = "{crits} SD instances DO NOT maintain consistency with master; {oks} are OK; NO DATA from {no_datas}".format(
                crits=len(backends['CRIT']), oks=len(backends['OK']), no_datas=len(backends['NO_DATA'])
            )

        if sum_no_datas:
            description = "{description}, last NO_DATA error:\n{error}".format(
                description=description, error=last_error
            )

    summary = "{status}: {description};".format(status=status, description=description)
    return sum_oks, sum_crits, sum_no_datas, summary


def check_chaos_service_number_of_endpoints(cluster):
    """Check that at least one ``CHAOS_ESID`` endpoint in ``cluster`` is active.

    Returns:
        ``(is_ok, is_crit, is_no_data, summary)``: three booleans derived from
        the check status plus a ``"<status>: <description>;"`` string.
    """
    threshold = 1
    status, active, total, error = check_number_of_endpoints(cluster, CHAOS_ESID, threshold)

    if status not in ('OK', 'CRIT'):
        description = "An error occured while checking:\n{error}".format(error=error)
    else:
        description = "{active}/{total} instances of chaos-service-slave are ACTIVE (threshold = {threshold})".format(
            active=active, total=total, threshold=threshold
        )

    summary = "{status}: {description};".format(status=status, description=description)
    return status == 'OK', status == 'CRIT', status == 'NO_DATA', summary


def check_chaos_service_endpoints_moving(cluster):
    """Check that the ``CHAOS_ESID`` endpoints in ``cluster`` keep changing.

    Returns:
        ``(is_ok, is_crit, is_no_data, summary)``: three booleans derived from
        the check status plus a ``"<status>: <description>;"`` string.
    """
    status, error = check_endpoints_moving(cluster, CHAOS_ESID)

    if status == 'NO_DATA':
        description = "An error occured while checking:\n{error}".format(error=error)
    elif status == 'CRIT':
        description = "Endpoints do not move for too long"
    elif status == 'OK':
        description = "Endpoints are moving"

    summary = "{status}: {description};".format(status=status, description=description)
    return status == 'OK', status == 'CRIT', status == 'NO_DATA', summary


def check_service_discovery_number_of_instances(cluster):
    """Summarise instance liveness over all services in ``SD_NANNY_SERVICES``.

    Args:
        cluster: accepted for a uniform check signature; not used here.

    Returns:
        ``(is_ok, is_crit, is_warn_or_no_data, summary)``. The overall status
        is the worst per-service status (CRIT over WARN over OK); it becomes
        NO_DATA only when some service failed to report and the status is
        otherwise OK.
    """
    healthy, degraded, failing = [], [], []
    non_ok_buckets = {'WARN': degraded, 'CRIT': failing}
    last_error = None

    for sd_service in SD_NANNY_SERVICES:
        service_status, active, total, error = check_number_of_instances(sd_service)
        if service_status == 'NO_DATA':
            last_error = error
            continue
        # Anything that is not WARN/CRIT is reported in the OK group.
        non_ok_buckets.get(service_status, healthy).append(
            "{service} ({active}/{total})".format(service=sd_service, active=active, total=total)
        )

    status = 'OK'
    parts = ['Check SD instances activeness: ']
    # Ordered from least to most severe so that the worst status wins.
    for label, template, bucket in (
        ('OK', 'in OK: {}; ', healthy),
        ('WARN', 'in WARN: {}; ', degraded),
        ('CRIT', 'in CRIT: {}; ', failing),
    ):
        if bucket:
            status = label
            parts.append(template.format(', '.join(bucket)))

    if last_error is not None:
        if status == 'OK':
            status = 'NO_DATA'
        parts.append("An error occured while checking:\n{error}".format(error=last_error))

    description = ''.join(parts)
    summary = "{status}: {description}".format(status=status, description=description)
    return status == 'OK', status == 'CRIT', status in ('WARN', 'NO_DATA'), summary


# Checks executed by main() for every cluster, in this order. Each one is
# called with the cluster name and returns a 4-tuple that main() unpacks as
# (oks, crits, no_datas, description).
CHECK_FUNCS = [
    check_chaos_service_consistency,
    check_chaos_service_number_of_endpoints,
    check_chaos_service_endpoints_moving,
    check_service_discovery_number_of_instances,
]


def parse_args(argv):
    """Parse command-line arguments and export given tokens to the environment.

    Args:
        argv: list of argument strings (without the program name).

    Returns:
        The argparse namespace; ``clusters`` is a list split on commas, or
        None when the option was not given.

    Side effects:
        ``--yp-token`` is stored in ``os.environ['YP_TOKEN']`` and
        ``--nanny-token`` in ``os.environ['OAUTH_NANNY']`` when provided.
    """
    parser = argparse.ArgumentParser()
    for option in ('--clusters', '--yp-token', '--nanny-token'):
        parser.add_argument(option)
    parser.add_argument('--verbose', action='store_true')
    args = parser.parse_args(argv)

    if args.clusters is not None:
        args.clusters = args.clusters.split(',')

    token_targets = (
        (args.yp_token, 'YP_TOKEN'),
        (args.nanny_token, 'OAUTH_NANNY'),
    )
    for token, env_name in token_targets:
        if token is not None:
            os.environ[env_name] = token

    return args


def main(argv):
    """Run every check for every cluster and push the aggregate to Juggler.

    Args:
        argv: command-line arguments without the program name.

    The aggregate status per cluster is 'CRIT' if any check reports crits,
    otherwise 'WARN' if any check reports no-data, otherwise 'OK'.
    """
    args = parse_args(argv)
    logger.setLevel(logging.DEBUG if args.verbose else logging.INFO)

    clusters = args.clusters
    if clusters is None:
        clusters = get_current_clusters()

    for cluster in clusters:
        logger.info("Checking Chaos Service in {cluster}".format(cluster=cluster))

        aggr_status = 'OK'
        descriptions = []
        for check in CHECK_FUNCS:
            _, crits, no_datas, description = check(cluster)

            if crits:
                aggr_status = 'CRIT'
            elif no_datas and aggr_status != 'CRIT':
                # A missing result never downgrades an already critical status.
                aggr_status = 'WARN'

            descriptions.append('{}\n'.format(description))

        aggr_description = ''.join(descriptions)
        logger.info("Result status: {aggr_status}".format(aggr_status=aggr_status))
        logger.info("Additional info:\n{aggr_description}".format(aggr_description=aggr_description))
        juggler_notify(cluster, aggr_status, aggr_description)


# Script entry point: pass the command-line arguments (without the program
# name) to main().
if __name__ == '__main__':
    main(sys.argv[1:])
