import calendar
import json
import logging
import os
import time
from datetime import datetime

import requests
from dateutil import parser

# Optional dependencies: when OpenSSL or dateutil.tz can not be imported,
# HAS_EXPIRATION_CHECK is False and the expiration check of certs cached
# on disk is skipped.
try:
    import OpenSSL
    from dateutil.tz import tzutc
    HAS_EXPIRATION_CHECK = True
except ImportError:
    HAS_EXPIRATION_CHECK = False

# Module-level logger
LOG = logging.getLogger(__name__)

# Runlist components for which ext_pillar provides a server cert and key
COMPONENTS_ENABLED = {
    'components.clickhouse',
    'components.mongodb',
}

# Runlist components for which ext_pillar returns only the CA bundle
CA_ONLY_COMPONENTS = {'components.pg-barman'}

# Certificator endpoint used both for listing and for issuing certs
CERTIFICATOR_BASE_URL = 'https://crt-api.yandex-team.ru/api/certificate/'


def _expired(crt):
    """
    Tell whether PEM cert should be treated as expired

    Returns True when fewer than 30 whole days are left until notAfter,
    and also when the cert can not be loaded or its date can not be parsed
    (the failure is logged). Returns False otherwise.
    """
    try:
        loaded = OpenSSL.crypto.load_certificate(
            OpenSSL.crypto.FILETYPE_PEM, crt)
        not_after = parser.parse(loaded.get_notAfter())
        remaining = not_after - datetime.now(tz=tzutc())
    except Exception as exc:
        # Any failure here means the cert is unusable for us: report it
        # as expired so that the caller replaces it.
        LOG.error('Unable to check expiration: %s', repr(exc))
        return True

    return remaining.days < 30


def _get_existing_cert(server_key_root, token, hostname):
    """
    Load cert and key cached on disk in server_key_root

    If server_key_root does not exist it is created (together with any
    missing parent directories) and None is returned. The token and
    hostname arguments are not used by this function.

    Returns dict with 'cert.crt' and 'cert.key' contents, or None when
    either server.crt or server.key is missing or empty, or when the
    expiration check is available and reports the cert as expired.
    """
    if not os.path.exists(server_key_root):
        # makedirs (not mkdir): a missing parent directory must not abort
        # the lookup with OSError.
        os.makedirs(server_key_root)
        return None

    ret = {}

    for suffix in ('crt', 'key'):
        path = os.path.join(server_key_root, 'server.' + suffix)
        if not os.path.exists(path):
            return None
        with open(path) as inp_file:
            content = inp_file.read()
        if not content:
            return None
        ret['cert.' + suffix] = content

    if HAS_EXPIRATION_CHECK and _expired(ret['cert.crt']):
        return None

    return ret


def _get_all_certs(headers, hostname):
    """
    Get all certs for hostname

    Sends GET to CERTIFICATOR_BASE_URL with ``host`` query parameter and
    returns the 'results' list of the JSON response (empty list when the
    key is absent).
    """
    # Let requests build the query string so that hostname is always
    # properly URL-encoded instead of being pasted into the URL as is.
    res = requests.get(
        CERTIFICATOR_BASE_URL,
        params={'host': hostname},
        headers=headers).json()
    return res.get('results', [])


def _epoch_seconds(value):
    """
    Convert timestamp string to seconds since the epoch
    """
    moment = parser.parse(value)
    if moment.utcoffset() is None:
        # No UTC offset in the value: interpret it as local time.
        return time.mktime(moment.timetuple())
    # timetuple() drops the offset and mktime() would then read the value
    # as local time, so convert offset-aware values through UTC.
    return calendar.timegm(moment.utctimetuple())


def _get_latest_cert(headers, hostname):
    """
    Find latest cert in certificator with not expired key

    Only certs with status 'issued', not revoked, with more than 30 days
    left until end_date and without priv_key_deleted_at are considered.
    Returns the cert description with the greatest 'issued' time, or an
    empty dict when nothing suitable is found.
    """
    cert_info = {}
    max_issued = 0
    # Certs ending earlier than 30 days from now are not worth reusing
    threshold = time.time() + 30 * 24 * 3600
    for result in _get_all_certs(headers, hostname):
        if result.get('status') != 'issued' or \
                result.get('revoked') is not None:
            continue
        issue_ts = _epoch_seconds(result['issued'])
        end_ts = _epoch_seconds(result['end_date'])
        if end_ts > threshold and issue_ts > max_issued and \
                not result.get('priv_key_deleted_at'):
            cert_info = result
            max_issued = issue_ts

    return cert_info


def _issue_new(headers, hostname):
    """
    Request a new host cert from InternalCA

    Returns decoded JSON body of the response. Raises RuntimeError when
    certificator answers with anything but 201.
    """
    payload = json.dumps({
        'type': 'host',
        'hosts': hostname,
        'ca_name': 'InternalCA',
    })

    response = requests.post(
        CERTIFICATOR_BASE_URL, data=payload, headers=headers)
    if response.status_code == 201:
        return response.json()

    raise RuntimeError(
        'Unexpected certificator response: {status_code} {text}'.format(
            status_code=response.status_code, text=response.text))


def _download_cert(headers, cert_info):
    """
    Fetch PEM bundle for cert_info and split it into key and cert

    The bundle is downloaded from cert_info['download2'] + '.pem' with
    only the Authorization header passed on. Lines between the BEGIN and
    END 'PRIVATE KEY' markers (markers included) form the key, all the
    other lines form the cert.

    Returns (key, cert) tuple. Raises RuntimeError when cert_info has no
    'download2' url or when the bundle contains no private key.
    """
    if 'download2' not in cert_info:
        raise RuntimeError(
            'Download url not found in {cert}'.format(cert=cert_info))

    response = requests.get(
        '{url}.pem'.format(url=cert_info['download2']),
        headers={'Authorization': headers['Authorization']})
    res = response.text

    key_lines = []
    cert_lines = []
    inside_key = False
    for line in res.split('\n'):
        if 'PRIVATE KEY--' in line:
            if '--BEGIN' in line:
                inside_key = True
                key_lines.append(line)
            elif '--END' in line:
                inside_key = False
                key_lines.append(line)
            # A marker line that is neither BEGIN nor END is dropped
            continue
        if inside_key:
            key_lines.append(line)
        else:
            cert_lines.append(line)

    # Every collected line is terminated with a newline
    key = ''.join(item + '\n' for item in key_lines)
    cert = ''.join(item + '\n' for item in cert_lines)

    if not key:
        raise RuntimeError('No private key in: {res}'.format(res=res))

    return key, cert


def _issue_cert(server_key_root, token, hostname):
    """
    Get cert and key from certificator and store them in server_key_root

    The latest suitable cert is reused when certificator has one,
    otherwise a new cert is issued. The cert is written to server.crt and
    the key to server.key.

    Returns dict with 'cert.key' and 'cert.crt'.
    """
    headers = {
        'Authorization': 'OAuth {token}'.format(token=token),
        'Accept': 'application/json',
        'Content-Type': 'application/json',
    }
    cert_info = (_get_latest_cert(headers, hostname) or
                 _issue_new(headers, hostname))

    key, cert = _download_cert(headers, cert_info)

    # Cert goes first, then the key
    for file_name, content in (('server.crt', cert), ('server.key', key)):
        with open(os.path.join(server_key_root, file_name), 'w') as out_file:
            out_file.write(content)

    return {'cert.key': key, 'cert.crt': cert}


def ext_pillar(minion_id, pillar, token, key_root=''):
    """
    Certificator ext_pillar

    Returns:
      * {} when pillar already has both 'cert.key' and 'cert.crt', or when
        neither data:pg_ssl is set nor the runlist has a component from
        COMPONENTS_ENABLED;
      * {'cert.ca': ...} only, when the runlist has a component from
        CA_ONLY_COMPONENTS;
      * 'cert.key', 'cert.crt' and 'cert.ca' otherwise. The cert is taken
        from disk when a usable one is cached, else from certificator.

    The host name for the cert is data:pg_ssl_balancer, falling back to
    minion_id. The CA bundle is read from <key_root>/prod/allCAs.pem.
    """
    # Nothing to do if certs are already in pillar
    if 'cert.key' in pillar and 'cert.crt' in pillar:
        return {}

    ca_path = os.path.join(key_root, 'prod/allCAs.pem')
    data = pillar.get('data', {})
    run_list = set(data.get('runlist', []))

    def _read_ca():
        with open(ca_path, 'r') as inp_file:
            return inp_file.read()

    # pg-barman gets only cert.ca
    if run_list.intersection(CA_ONLY_COMPONENTS):
        return {'cert.ca': _read_ca()}

    ssl_wanted = (data.get('pg_ssl', False) or
                  run_list.intersection(COMPONENTS_ENABLED))
    if not ssl_wanted:
        return {}

    hostname = data.get('pg_ssl_balancer', minion_id)
    server_key_root = os.path.join(key_root, 'prod', hostname)
    ret = (_get_existing_cert(server_key_root, token, hostname) or
           _issue_cert(server_key_root, token, hostname))

    ret['cert.ca'] = _read_ca()

    return ret
