import logging
import os
import sys
import subprocess
from threading import Timer
import requests
import collections
import json
import random
import datetime
from dateutil import parser
import pytz
from sandbox import sdk2
import sandbox.common.types.client as ctc


class InvalidSecretData(Exception):
    """ Error type for unusable secret data (not raised in the visible part of this module) """


class AwacsCheckCerts(sdk2.Task):
    """ Check certificates expiration in AWACS """
    # For every AWACS namespace owned by the given ABC services the task:
    #   * reads each certificate spec and records WARN/CRIT events when it
    #     expires within `warn_diff`/`crit_diff` days (or expiration is unknown);
    #   * optionally (`active_ssl_check`) connects to every instance of the
    #     balancers the certificate is active on and compares the served
    #     serial number with the spec one;
    #   * pushes a single aggregated event to Juggler.

    class Requirements(sdk2.Requirements):
        cores = 1
        ram = 1024
        client_tags = ctc.Tag.LINUX_BIONIC

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Parameters):

        yav_awacs_token_secret = sdk2.parameters.YavSecret('YAV secret with AWACS token', required=True)
        abc_services = sdk2.parameters.List('ABC services', required=True)
        only_active = sdk2.parameters.Bool('Check only active certificates', default=True, required=True)
        check_ca = sdk2.parameters.List('Check certificates issued only by given CAs', required=False)
        warn_about_incomplete_spec = sdk2.parameters.Bool('Warn about incomplete specs', default=True, required=True)
        active_ssl_check = sdk2.parameters.Bool('Use active SSL check', default=True, required=True)
        warn_about_active_ssl_check_fail = sdk2.parameters.Bool('Warn about active SSL check fail', default=True,
                                                                required=True)
        warn_diff = sdk2.parameters.Integer('WARN days diff', default=30, required=True)
        crit_diff = sdk2.parameters.Integer('CRIT days diff', default=14, required=True)
        juggler_host = sdk2.parameters.String('Juggler host', default='juggler.sb.awacs_check_certs', required=True)
        juggler_service = sdk2.parameters.String('Juggler service', required=True)
        juggler_tags = sdk2.parameters.List('Juggler tags', required=False)
        print_sb_task_url = sdk2.parameters.Bool('Print SB task URL', required=True, default=False)

    def on_execute(self):
        def nested_dict():
            """ Return an auto-vivifying dict: reading a missing key creates a nested dict """
            return collections.defaultdict(nested_dict)

        def http_response(response, decode_json=False):
            """ Get response content

            Terminates the task via sys.exit with a descriptive message on any non-200 status.
            Returns the decoded JSON body when decode_json is set, otherwise the response text.
            """
            if response.status_code != requests.codes.ok:
                sys.exit('Error {} [{}], {}: {}'.format(
                    response.status_code,
                    response.url,
                    response.reason,
                    response.text)
                )
            content = response.json() if decode_json else response.text
            return content

        def request_post(url, headers, post_content, decode_json=False):
            """ Make POST request with a JSON body (10 s timeout) """
            response = requests.post(url, headers=headers, json=post_content, timeout=10)
            content = http_response(response, decode_json)
            return content

        def request_get(url, headers, decode_json=False):
            """ Make GET request (10 s timeout) """
            response = requests.get(url, headers=headers, timeout=10)
            content = http_response(response, decode_json)
            return content

        def awacs_list_namespaces_ids(abc_service, headers):
            """ List AWACS namespaces IDs belonging to the given ABC service """
            url = 'https://awacs.yandex-team.ru/api/ListNamespaces/'
            post_content = {
                'skip': 0,
                'limit': 0,
                'query': {
                    'abcServiceIdIn': [abc_service]
                },
                'fieldMask': 'meta.id'
            }
            content = request_post(url, headers, post_content, True)
            namespace_ids = [namespace['meta']['id'] for namespace in content['namespaces']]
            return namespace_ids

        def awacs_list_certificates_ids(namespace_id, headers):
            """ List certificates IDs of a namespace """
            url = 'https://awacs.yandex-team.ru/api/ListCertificates/'
            post_content = {
                'skip': 0,
                'limit': 0,
                'namespaceId': namespace_id,
            }
            content = request_post(url, headers, post_content, True)
            certs_ids = [certificate['meta']['id'] for certificate in content['certificates']]
            return certs_ids

        def awacs_get_certificate(namespace_id, id, headers):
            """ Get certificate info (raw GetCertificate response) """
            url = 'https://awacs.yandex-team.ru/api/GetCertificate/'
            post_content = {
                'id': id,
                'namespaceId': namespace_id
            }
            cert_info = request_post(url, headers, post_content, True)
            return cert_info

        def awacs_list_balancers(namespace_id, headers):
            """ List AWACS balancers as {balancer_id: Nanny service ID} """
            url = 'https://awacs.yandex-team.ru/api/ListBalancers/'
            post_content = {
                'skip': 0,
                'limit': 0,
                'namespaceId': namespace_id,
                'fieldMask': 'meta.id,spec.configTransport.nannyStaticFile.serviceId'
            }
            content = request_post(url, headers, post_content, True)
            balancers = {balancer['meta']['id']: balancer['spec']['configTransport']['nannyStaticFile']['serviceId']
                         for balancer in content['balancers']}
            return balancers

        def awacs_list_domains(namespace_id, balancer_id, headers):
            """ List domains of a balancer as {cert_id: [fqdn, ...]}

            FQDNs and shadow FQDNs of every complete domain using a certificate are
            merged, so a certificate shared by several domains keeps all of them.
            Domains with an incomplete spec or without a certificate are skipped.
            """
            url = 'https://awacs.yandex-team.ru/api/ListDomains/'
            post_content = {
                'skip': 0,
                'limit': 0,
                'namespaceId': namespace_id,
                'balancerId': balancer_id,
            }
            content = request_post(url, headers, post_content, True)
            fqdns_by_cert = {}
            for domain in content['domains']:
                if domain['spec'].get('incomplete'):
                    continue
                config = domain['spec']['yandexBalancer']['config']
                cert_id = config.get('cert', {}).get('id')
                if cert_id is None:
                    # presumably a domain without TLS; nothing to map to a certificate
                    continue
                fqdns = fqdns_by_cert.setdefault(cert_id, set())
                fqdns.update(config.get('fqdns', []))
                fqdns.update(config.get('shadowFqdns', []))
            return {cert_id: sorted(fqdns) for cert_id, fqdns in fqdns_by_cert.items()}

        def nanny_get_instances(nanny_service_id, headers):
            """ Get container hostnames of the current Nanny service instances """
            url = 'http://nanny.yandex-team.ru/v2/services/{}/current_state/instances/partial/'.format(nanny_service_id)
            content = request_get(url, headers, True)
            instances = [balancer['container_hostname'] for balancer in content['instancesPart']]
            return instances

        def kill_quietly(process):
            """ Kill a process, ignoring the error raised if it is already gone """
            try:
                process.kill()
            except OSError:
                pass

        def get_certificate_serial_number(hostname, sni, port=443, timeout=1):
            """ Connect and get certificate serial number

            Runs `openssl s_client | openssl x509 -serial` against hostname:port with the
            given SNI. Returns the serial with leading zeros stripped, or None when the
            pipeline fails or does not finish within `timeout` seconds. Both openssl
            processes are always killed (if still running) and reaped, and the devnull
            descriptor is closed.
            """
            devnull = os.open(os.devnull, os.O_RDWR)
            processes = []
            try:
                s_client = subprocess.Popen(['/usr/bin/openssl', 's_client', '-connect',
                                             '{}:{}'.format(hostname, port), '-servername', sni],
                                            stdin=devnull,
                                            stdout=subprocess.PIPE,
                                            stderr=devnull)
                processes.append(s_client)
                x509 = subprocess.Popen(['/usr/bin/openssl', 'x509', '-noout', '-serial'],
                                        stdin=s_client.stdout,
                                        stdout=subprocess.PIPE,
                                        stderr=devnull,
                                        universal_newlines=True)
                processes.append(x509)
                # Only x509 must hold the read end of the pipe, so that s_client gets
                # SIGPIPE instead of blocking if x509 exits early
                s_client.stdout.close()

                def kill_pipeline():
                    for process in processes:
                        kill_quietly(process)

                # Kill the whole pipeline (not just x509) if the handshake hangs
                timer = Timer(timeout, kill_pipeline)
                timer.start()
                try:
                    output, _ = x509.communicate()
                finally:
                    timer.cancel()
            finally:
                os.close(devnull)
                for process in processes:
                    if process.poll() is None:
                        kill_quietly(process)
                    process.wait()
            if x509.returncode != 0:
                return None
            return output.strip().partition('=')[2].lstrip('0')

        def send_event_to_juggler(juggler_host, juggler_service, juggler_events, juggler_tags, print_sb_task_url):
            """ Send event to Juggler

            The event status is CRIT if any CRIT event exists, else WARN if any WARN
            event exists, else OK. The description lists every event (or points to the
            task page when it exceeds 1024 characters). Returns the decoded response.
            """
            logging.info('Juggler events: {}'.format(json.dumps(juggler_events, indent=2)))

            # Juggler status
            juggler_status = 'OK'
            if 'CRIT' in juggler_events:
                juggler_status = 'CRIT'
            elif 'WARN' in juggler_events:
                juggler_status = 'WARN'
            logging.info('Juggler status: {}'.format(juggler_status))

            # Prepare Juggler description
            juggler_description = []
            if juggler_status == 'OK':
                juggler_description.append('OK')
            else:
                for status in ['CRIT', 'WARN']:
                    if status in juggler_events:
                        for namespace_id in juggler_events[status].keys():
                            for cert_id in juggler_events[status][namespace_id].keys():
                                for event in juggler_events[status][namespace_id][cert_id]:
                                    description = '{}: {}:{} certificate {} [{}]'.format(
                                        status,
                                        namespace_id,
                                        cert_id,
                                        event['desc'],
                                        event['sectask']
                                    )
                                    logging.info('{} {}'.format(description, event['link']))
                                    juggler_description.append(description)

            sb_task_url = 'https://sandbox.yandex-team.ru/task/{}/view'.format(sdk2.Task.current.id)
            if print_sb_task_url:
                juggler_description.append(sb_task_url)

            juggler_description_str = '\n'.join(juggler_description)
            if len(juggler_description_str) > 1024:
                juggler_description_str = 'Too many objects to show. See {}'.format(sb_task_url)
            logging.info('Juggler description: {}'.format(juggler_description_str))

            url = 'http://juggler-push.search.yandex.net/events'
            events = {
                'host': juggler_host,
                'service': juggler_service,
                'status': juggler_status,
                'description': juggler_description_str
            }
            if juggler_tags:
                events['tags'] = juggler_tags
            post_content = {
                'source': 'sandbox',
                'events': [
                    events
                ]
            }
            # Never forward the AWACS OAuth token to Juggler (this endpoint is plain HTTP)
            juggler_headers = {
                'Content-Type': 'application/json'
            }
            content = request_post(url, juggler_headers, post_content, True)
            logging.info('Successfully sent event to Juggler'
                         if content['events'][0]['code'] == 200 else 'Failed to send event to Juggler')
            return content

        # Fill variables from SB
        yav_awacs_token = self.Parameters.yav_awacs_token_secret
        abc_services = self.Parameters.abc_services
        only_active = self.Parameters.only_active
        check_ca = self.Parameters.check_ca
        warn_about_incomplete_spec = self.Parameters.warn_about_incomplete_spec
        active_ssl_check = self.Parameters.active_ssl_check
        warn_about_active_ssl_check_fail = self.Parameters.warn_about_active_ssl_check_fail
        warn_diff = self.Parameters.warn_diff
        crit_diff = self.Parameters.crit_diff
        juggler_host = self.Parameters.juggler_host
        juggler_service = self.Parameters.juggler_service
        juggler_tags = self.Parameters.juggler_tags
        print_sb_task_url = self.Parameters.print_sb_task_url

        logging.info('YAV secret with AWACS token: {}'.format(yav_awacs_token))
        logging.info('ABC services: {}'.format(abc_services))
        logging.info('Check only active certificates: {}'.format(only_active))
        logging.info('Check certificates issued only by given CAs: {}'.format(check_ca))
        logging.info('Warn about incomplete specs: {}'.format(warn_about_incomplete_spec))
        logging.info('Use active SSL check: {}'.format(active_ssl_check))
        logging.info('Warn about active SSL check fail: {}'.format(warn_about_active_ssl_check_fail))
        logging.info('WARN diff: {}'.format(warn_diff))
        logging.info('CRIT diff: {}'.format(crit_diff))
        logging.info('Juggler host: {}'.format(juggler_host))
        logging.info('Juggler service: {}'.format(juggler_service))
        logging.info('Juggler tags: {}'.format(juggler_tags))
        logging.info('Print SB task URL: {}'.format(print_sb_task_url))

        # Get AWACS token from YAV
        logging.info('Getting YAV secret')
        awacs_token = yav_awacs_token.value()

        # Prepare variables
        current_date = pytz.utc.localize(datetime.datetime.utcnow())
        awacs_headers = {
            'Authorization': 'OAuth {}'.format(awacs_token),
            'Content-Type': 'application/json'
        }
        nanny_headers = {
            'Content-Type': 'application/json'
        }
        juggler_events = nested_dict()  # status -> namespace_id -> cert_id -> [event, ...]
        balancers_cache = nested_dict()  # namespace_id -> {balancer_id: [instance hostname, ...]}
        balancers_domains_cache = nested_dict()  # namespace_id -> balancer_id -> {cert_id: [fqdn, ...]}

        def add_event(status, namespace_id, cert_id, desc, sectask, link):
            """ Record an event for certificate cert_id of namespace_id under the given Juggler status """
            juggler_events[status][namespace_id].setdefault(cert_id, []).append({
                'desc': desc,
                'sectask': sectask,
                'link': link
            })

        # Check given ABC services
        for abc_service in abc_services:
            # Get AWACS namespaces IDs
            namespace_ids = awacs_list_namespaces_ids(abc_service, awacs_headers)
            # Check AWACS namespaces
            for namespace_id in namespace_ids:
                # Get certificates IDs
                certs_ids = awacs_list_certificates_ids(namespace_id, awacs_headers)
                # Check certificates
                for cert_id in certs_ids:
                    cert_info = awacs_get_certificate(namespace_id, cert_id, awacs_headers)
                    cert_spec = cert_info['certificate']['spec']
                    # Status keys are split on ':' and the second part is used as the balancer ID
                    active_balancers = [x.split(':')[1] for x in
                                        cert_info['certificate']['statuses'][0]['active'].keys()]
                    active_balancers_count = len(active_balancers)

                    # Check only active certificates
                    if only_active and active_balancers_count == 0:
                        continue
                    # Check certificates issued only by given CAs
                    certificator = cert_spec.get('certificator', {})
                    ca_name = certificator.get('caName')
                    if check_ca and ca_name is not None and ca_name not in check_ca:
                        continue

                    sectask_ticket = certificator.get('approval', {}).get('startrek', {}).get('issueId')
                    if sectask_ticket is None:
                        sectask_ticket = 'SECTASK-?'
                    spec_incomplete = cert_spec['incomplete']
                    awacs_cert_link = 'https://nanny.yandex-team.ru/ui/#/awacs/namespaces/list/' + \
                                      '{}/certs/list/{}/show/'.format(namespace_id, cert_id)

                    logging.info('Namespace {} certificate {}'.format(namespace_id, cert_id))
                    logging.info('\tAWACS certificate link: {}'.format(awacs_cert_link))
                    logging.info('\tActive on {} balancers: {}'.format(active_balancers_count,
                                                                       ', '.join(active_balancers)))
                    logging.info('\tCA name: {}'.format(ca_name))
                    logging.info('\tSECTASK ticket: {}'.format(sectask_ticket))
                    logging.info('\tSpec incomplete: {}'.format(spec_incomplete))

                    # Nothing to check if spec is incomplete
                    if spec_incomplete:
                        # Warn about incomplete spec
                        if warn_about_incomplete_spec:
                            add_event('WARN', namespace_id, cert_id, 'spec is incomplete',
                                      sectask_ticket, awacs_cert_link)
                        continue

                    cert_fields = cert_spec['fields']
                    cert_serial_number = cert_fields['serialNumber'].lstrip('0')
                    domains = cert_fields['subjectAlternativeNames']
                    valid_not_before = cert_fields['validity']['notBefore']
                    valid_not_after = cert_fields['validity']['notAfter']

                    logging.info('\tCertificate serial number: {}'.format(cert_serial_number))
                    logging.info('\tDomains: {}'.format(', '.join(domains)))
                    logging.info('\tValid not before: {}'.format(valid_not_before))
                    logging.info('\tValid not after: {}'.format(valid_not_after))

                    # Calculate certificate expiration
                    if valid_not_after is None:
                        add_event('WARN', namespace_id, cert_id, 'expiration date is unknown',
                                  sectask_ticket, awacs_cert_link)
                        continue

                    valid_till = parser.parse(valid_not_after)
                    if valid_till.tzinfo is None:
                        # NOTE(review): a timestamp without an offset is assumed to be UTC;
                        # subtracting a naive datetime from an aware one would raise TypeError
                        valid_till = pytz.utc.localize(valid_till)
                    expires_in = valid_till - current_date

                    logging.info('\tExpires in: {} days'.format(expires_in.days))

                    if active_ssl_check:
                        # Fill cache if needed
                        if namespace_id not in balancers_cache:
                            balancers_cache[namespace_id] = awacs_list_balancers(namespace_id, awacs_headers)
                            for balancer in balancers_cache[namespace_id]:
                                balancers_cache[namespace_id][balancer] = nanny_get_instances(
                                    balancers_cache[namespace_id][balancer], nanny_headers)
                                balancers_domains_cache[namespace_id][balancer] = awacs_list_domains(
                                    namespace_id, balancer, awacs_headers)

                        # Check certificate on balancer instances
                        for balancer in active_balancers:
                            instances = balancers_cache[namespace_id].get(balancer)
                            if instances is None:
                                logging.info('\tBalancer {} is not listed in namespace {}, skipping'.format(
                                    balancer, namespace_id))
                                continue
                            # Prefer FQDNs of the AWACS domains bound to this certificate;
                            # fall back to the certificate SANs when there are none
                            balancer_domains = balancers_domains_cache[namespace_id].get(balancer, {}).get(cert_id)
                            filtered_domains = balancer_domains or domains
                            if not filtered_domains:
                                logging.info('\tNo SNIs for active check on balancer {}, skipping'.format(balancer))
                                continue
                            logging.info('\tSNIs for active check: {}'.format(', '.join(filtered_domains)))
                            for instance in instances:
                                random_domain = random.choice(filtered_domains)
                                instance_cert_serial_number = get_certificate_serial_number(instance, random_domain)
                                if instance_cert_serial_number is None:
                                    logging.info('\tFailed to connect to balancer {} instance {}'.format(
                                        balancer, instance))
                                    if warn_about_active_ssl_check_fail:
                                        add_event('WARN', namespace_id, cert_id,
                                                  'active SSL check failed on balancer {} instance {}'.format(
                                                      balancer, instance),
                                                  sectask_ticket, awacs_cert_link)
                                    continue
                                logging.info('\tCertificate serial number on balancer {} instance {}: {}'.format(
                                    balancer, instance, instance_cert_serial_number))
                                # openssl prints the serial in upper-case hex; compare case-insensitively
                                # in case the spec value uses lower case
                                if instance_cert_serial_number.upper() != cert_serial_number.upper():
                                    add_event('CRIT', namespace_id, cert_id,
                                              '{} differs from spec one {} on balancer {} instance {}'.format(
                                                  instance_cert_serial_number, cert_serial_number, balancer,
                                                  instance),
                                              sectask_ticket, awacs_cert_link)

                    # Calculate Juggler event status
                    if expires_in.days <= crit_diff:
                        juggler_status = 'CRIT'
                    elif expires_in.days <= warn_diff:
                        juggler_status = 'WARN'
                    else:
                        continue

                    # Add event
                    add_event(juggler_status, namespace_id, cert_id,
                              'expires in {} days at {}'.format(expires_in.days, valid_not_after),
                              sectask_ticket, awacs_cert_link)

        # Send to Juggler
        send_event_to_juggler(juggler_host, juggler_service, juggler_events, juggler_tags, print_sb_task_url)
