import datetime
import email.utils
import errno
import json
import logging
import os
import random
import time
import urllib.parse as urlparse

from asn1crypto import crl
from asn1crypto import pem
from asn1crypto import keys
from asn1crypto import x509

import requests
from urllib3.util.retry import Retry
from requests.adapters import HTTPAdapter

import pytz  # asn1crypto returns tz-aware datetime objects

from infra.rtc.certman import fileutil

log = logging.getLogger(__name__)

# Certificate API host and port used for reissue and chain download requests.
PROD_SERVER = 'crt-api.yandex-team.ru'
PROD_PORT = 444
# Expected issuer of the host certificate; compared against the certificate's
# `issuer.native` value during validation.
RTC_ISSUER = {'common_name': 'YandexRCCA', 'domain_component': ['ru', 'yandex', 'ld']}

# CRL can be pretty big, current value for InternalCA is ~5.7Mb, though RCCA is small
CRL_MAX_SIZE_BYTES = 10 * 1024 * 1024
# CRL of the issuing CA, used to check whether the host certificate is revoked.
# NOTE(review): fetched over plain HTTP and this module does not verify the CRL
# signature (only serial numbers are compared) — confirm this is acceptable.
RCCA_CRL_URL = 'http://crls.yandex.ru/YandexRCCA/YandexRCCA.crl'


def crl_from_bytes(buf):
    """
    Parse raw CRL bytes into an asn1crypto ``CertificateList``.

    :returns: ``(CertificateList, None)`` on success, ``(None, error_message)``
        if loading raises.
    """
    # NOTE(review): asn1crypto may defer parsing of nested fields, so some
    # malformed input could only fail later on access — confirm if relevant.
    try:
        parsed = crl.CertificateList.load(buf)
    except Exception as e:
        return None, str(e)
    return parsed, None


def mtime_from_modified(s):
    """
    Convert an HTTP ``Last-Modified`` header value into a POSIX timestamp.

    HTTP dates (e.g. ``Sun, 06 Nov 1994 08:49:37 GMT``) are always expressed
    in UTC, so the value is interpreted as UTC regardless of the host's local
    timezone or locale.

    :returns: ``(timestamp, None)`` on success or ``(None, error_message)``
        if the value cannot be parsed (including an empty string).
    """
    try:
        dt = email.utils.parsedate_to_datetime(s)
    except Exception as e:
        # Depending on the Python version a bad value raises ValueError or TypeError.
        return None, str(e)
    if dt.tzinfo is None:
        # "-0000"-style zones parse as naive datetimes; HTTP dates are UTC.
        dt = dt.replace(tzinfo=datetime.timezone.utc)
    return dt.timestamp(), None


def fetch_crl(url, cache_path, max_size, stat_func=os.stat, http_get=None, timeout=8):
    """
    Fetch a CRL from ``url``, using ``cache_path`` as a conditional-GET cache.

    The cache file's mtime stores the server's ``Last-Modified`` time (as a
    UTC epoch), and is sent back as ``If-Modified-Since``. On HTTP 200 the body
    is written to the cache with that mtime; on HTTP 304 the cached copy is read.

    :param url: CRL location.
    :param cache_path: path of the cached CRL file.
    :param max_size: maximum accepted CRL size in bytes (download and cache read).
    :param stat_func: ``os.stat`` replacement (injectable for tests).
    :param http_get: ``requests.get``-compatible callable; when omitted, a
        ``requests.Session`` with retries is created and closed afterwards.
    :param timeout: request timeout in seconds passed to ``http_get``.
    :returns: ``(CertificateList, None)`` or ``(None, error_message)``.
    """
    try:
        mtime = stat_func(cache_path).st_mtime
    except EnvironmentError as e:
        if e.errno != errno.ENOENT:
            return None, str(e)
        mtime = None  # no cached copy yet: unconditional GET

    headers = {'User-Agent': 'infra/rtc/certman'}
    if mtime is not None:
        # Format the UTC epoch as an RFC-compliant HTTP date; unlike
        # strftime('%a ... %b'), this does not depend on locale or local TZ.
        headers['If-Modified-Since'] = email.utils.formatdate(mtime, usegmt=True)

    session = None
    if http_get is None:
        # fast fix for HOSTMAN-69, retry errors in 0s, 2s, 4s, 8s
        session = requests.Session()
        session.mount('', HTTPAdapter(max_retries=Retry(total=4, backoff_factor=1,
                                                         status_forcelist=(500, 502, 503, 504))))
        http_get = session.get

    try:
        resp = http_get(url, headers=headers, timeout=timeout)
    except Exception as e:
        return None, str(e)
    finally:
        if session is not None:
            # Body is already read (no stream=True), so the pool can be released.
            session.close()

    if resp.status_code == 200:
        buf = resp.content
        if len(buf) > max_size:
            return None, 'crl is too large, {} bytes'.format(len(buf))
        mtime, err = mtime_from_modified(resp.headers.get('last-modified', ''))
        if err is not None:
            return None, 'failed to parse last-modified from CRL response: {}'.format(err)
        err = fileutil.atomic_write(cache_path, buf, mode='wb',
                                    chmod=0o644,
                                    times=(mtime, mtime),  # We save last-modified in file directly
                                    )
        if err is not None:
            # Although we have CRL at hand, do not return it,
            # so that we can signal failure
            return None, err
        return crl_from_bytes(buf)
    elif resp.status_code == 304:
        buf, err = fileutil.read_file(cache_path, 'rb',
                                      max_size=max_size)
        if err is not None:
            return None, err
        return crl_from_bytes(buf)
    else:
        return None, 'bad http code {}: {}'.format(resp.status_code, resp.reason)


def is_revoked(certs, c):
    """
    Tell whether certificate ``c`` appears in the CRL ``certs``.

    Only serial numbers are compared (via their ``.native`` values); the CRL
    issuer and signature are not checked here.

    :returns: True if the serial is listed among the revoked certificates.
    """
    serial = c['tbs_certificate']['serial_number'].native
    revoked_entries = certs['tbs_cert_list']['revoked_certificates']
    return any(serial == entry['user_certificate'].native for entry in revoked_entries)


def cert_from_pem(pem_bytes):
    """
    Extract the **first** certificate from PEM data.

    Records of other types (e.g. a private key) preceding it are skipped.

    :returns: ``(x509.Certificate, None)`` or ``(None, error_message)``.
    """
    try:
        records = pem.unarmor(pem_bytes, multiple=True)
    except Exception as e:
        return None, 'failed to parse PEM bytes: {}'.format(e)

    # Advance to the first CERTIFICATE record; decoding errors surface lazily here.
    while True:
        try:
            type_name, _, der_bytes = next(records)
        except StopIteration:
            return None, 'No CERTIFICATE found in pem'
        except Exception as e:
            return None, 'failed to decode PEM: {}'.format(e)
        if type_name == 'CERTIFICATE':
            break

    try:
        return x509.Certificate.load(der_bytes), None
    except Exception as e:
        return None, 'failed to load DER from {} PEM: {}'.format(type_name, e)


def pk_from_pem(pem_bytes):
    """
    Extract the private key, which must be the **first** record of the PEM data.

    Only records labelled ``PRIVATE KEY`` (loaded as ``keys.PrivateKeyInfo``)
    are accepted; anything else as the first record is reported as an error.

    :returns: ``(keys.PrivateKeyInfo, None)`` or ``(None, error_message)``.
    """
    try:
        records = pem.unarmor(pem_bytes, multiple=True)
    except Exception as e:
        return None, 'failed to parse PEM bytes: {}'.format(e)

    try:
        type_name, _, der_bytes = next(records)
    except StopIteration:
        return None, 'PEM is empty'
    except Exception as e:
        return None, 'failed to decode PEM: {}'.format(e)

    if type_name == 'PRIVATE KEY':
        try:
            return keys.PrivateKeyInfo.load(der_bytes), None
        except Exception as e:
            return None, 'failed to load DER from {} PEM: {}'.format(type_name, e)
    return None, "first PEM record is '{}' != 'PRIVATE KEY'".format(type_name)


def get_current_utc_tz_aware_datetime():
    """Return the current moment as a timezone-aware datetime in UTC."""
    utc = pytz.UTC
    return datetime.datetime.now(utc)


def cert_needs_reissue(now, days_left, lo=10, hi=100, delta=365, h=None):
    """
    Decide whether this host should ask for a new certificate right now.

    :param now: current datetime; the working-hours window is applied to
        ``now.hour`` as given.
    :param days_left: days until the current certificate expires.
    :param lo: with fewer days left than this, always reissue.
    :param hi: with more days left than this, never reissue.
    :param delta: lifetime in days assumed for the new certificate; used to
        avoid a new certificate that would expire on a weekend.
    :param h: seed for the per-host "turn"; defaults to the local hostname.
    :returns: True if a reissue should be requested now.
    """
    # Always reissue if less than `lo` days left.
    if days_left < lo:
        return True
    # More than `hi` days is okay.
    if days_left > hi:
        return False
    # We have some room, so we can choose a "better" time to ask for a certificate.
    # NOTE(review): the 10..19 window is compared with `now.hour` directly — if
    # `now` is UTC this is a UTC window; confirm that is intended.
    if now.hour < 10 or now.hour > 19:
        # Too early or too late, let's not do it now.
        return False
    # Skip days when a certificate issued now (valid `delta` days) would expire
    # on a weekend.
    if (now + datetime.timedelta(days=delta)).weekday() in (5, 6):
        return False
    # Every host gets a stable "turn" drawn uniformly from [lo, hi], so across
    # the fleet hosts ask for a new certificate with (lo + hi) / 2 days left on
    # average. random.Random seeded with a str is deterministic across runs;
    # hash(hostname) would NOT be a stable alternative because str hashes are
    # salted per process (PYTHONHASHSEED).
    if not h:
        import socket
        h = socket.gethostname()
    turn = random.Random(h).randint(lo, hi)
    # While days_left >= turn it is too early for this host. Because off-hours
    # and weekend-expiry days are skipped above, the first suitable moment after
    # the turn wins, which may cause small peaks; accepted for simplicity.
    return days_left < turn


def update_certificate(pem_path, host, port, ttl_days):
    """
    Reissue the certificate through the CRT API, authenticating with the current one.

    Sends a forced reissue POST to ``https://host:port/api/reissue/`` using
    ``pem_path`` as the TLS client certificate, then downloads the new chain
    from the path of the ``download2`` URL in the response (always against
    ``host:port``).

    Equivalent curl commands are logged to ease manual debugging on hosts;
    keep them in sync with the requests below when changing this code.

    :returns: ``(pem_bytes, None)`` on success or ``(None, error_message)``.
    """
    base_url = 'https://{}:{}'.format(host, port)
    reissue_url = base_url + '/api/reissue/'
    payload = json.dumps({'force': True, 'desired_ttl_days': str(ttl_days)})
    log.info("curl -X POST --cert {} -H 'Content-Type: application/json' -d '{}' {}".format(
        pem_path, payload, reissue_url))

    json_headers = {'Content-Type': 'application/json'}
    try:
        resp = requests.post(reissue_url, data=payload, headers=json_headers,
                             cert=pem_path, timeout=60)
        if resp.status_code != 201:
            return None, "return code {}, expected 201, response: {}".format(resp.status_code, resp.text)

        body = resp.json()
        if 'download2' not in body:
            return None, "no 'download2' key in response"
        download_ref = body['download2']
        log.info('Downloading new chain from {}'.format(download_ref))

        # Only the path is taken from the response; host and port stay ours.
        download_url = base_url + urlparse.urlparse(download_ref).path
        log.debug(f'curl --cert {pem_path} {download_url}')
        resp = requests.get(download_url, cert=pem_path, timeout=60)
        if resp.status_code != 200:
            return None, "return {}, expected 200, response: {}".format(resp.status_code, resp.text)
        return resp.content, None
    except Exception as e:
        return None, str(e)


def validate_certificate(c, cname, at, issuer):
    """
    Check that certificate ``c`` is usable for host ``cname`` at time ``at``.

    Checks, in this order:
      * the subject is exactly ``{'common_name': cname}``;
      * ``c.issuer.native`` equals ``issuer``;
      * ``at`` lies within the validity period (both bounds inclusive).

    :returns: None when all checks pass, otherwise a description of the first failure.
    """
    subject = c.subject.native
    if subject != {'common_name': cname}:
        return 'unexpected subject in certificate: {}'.format(dict(subject))

    issuer_name = c.issuer.native
    if issuer_name != issuer:
        return 'unexpected issuer in certificate: {}'.format(dict(issuer_name))

    validity = c['tbs_certificate']['validity']
    not_after = validity['not_after'].native
    if at > not_after:
        return 'certificate is expired, not_after={}'.format(not_after)
    not_before = validity['not_before'].native
    if at < not_before:
        return 'certificate is not valid yet, not_before={}'.format(not_before)
    return None


def days_left_for(c, now):
    """
    Return the number of whole days from ``now`` until the certificate's ``not_after``.

    Uses ``timedelta.days``, so partial days are floored and the result is
    negative once ``not_after`` has passed.
    """
    not_after = c['tbs_certificate']['validity']['not_after'].native
    remaining = not_after - now
    return remaining.days


def save_pem(pem_bytes, pem_path, chmod=0o400, chown=1049):
    """
    Install new PEM bytes at ``pem_path``, keeping the previous file as ``<pem_path>.old``.

    Rotation order:
      CUR -> OLD  (hardlink via ``os.link``; best effort, failures are only logged)
      NEW -> CUR  (``fileutil.atomic_write`` with mode ``chmod``)
    then ownership is set to ``chown:chown``.

    A hardlink is used for CUR -> OLD instead of a rename so that ``pem_path``
    does not disappear before the new content is in place.
    See https://a.yandex-team.ru/review/1800452/details for details.

    :raises Exception: if writing the new PEM fails; ``os.chown`` errors propagate.
    """
    old_path = f'{pem_path}.old'
    try:
        try:
            os.unlink(old_path)
        except FileNotFoundError:
            pass  # nothing to rotate out yet
        os.link(pem_path, old_path)
    except Exception as e:
        # The old cert is kept for debug purposes only, so do not fail here.
        log.error(f'Failed to save old PEM file: {e}')

    log.info(f'Saving PEM to {pem_path}')
    write_err = fileutil.atomic_write(pem_path, pem_bytes, mode='wb', chmod=chmod)
    if write_err is not None:
        raise Exception(write_err)

    # NOTE(review): ownership changes only after the file is in place, so there
    # may be a short window with the writer's owner — confirm this is acceptable.
    log.info(f'Set owner {chown}:{chown} for {pem_path}')
    os.chown(pem_path, chown, chown)


class Certificate(object):
    """
    Checks the host certificate stored in a PEM file and reissues it when needed.

    The PEM file must start with the private key record and contain the
    certificate; the certificate subject must be ``{'common_name': hostname}``.
    """

    def __init__(self, hostname):
        # Expected certificate common name; also seeds per-host reissue timing.
        self.hostname = hostname
        # Outcome of the latest manage() call (pessimistic defaults until checked).
        self.status = Status()

    @staticmethod
    def _parse_pem(pem_bytes, source):
        """
        Extract the private key and certificate from PEM bytes.

        The key is not returned. It is only checked, because without it the PEM
        cannot be used as a client certificate for the next reissue.

        :param source: where the PEM came from, used in error messages.
        :returns: ``(certificate, None)`` or ``(None, error_message)``.
        """
        _, err = pk_from_pem(pem_bytes)
        if err is not None:
            return None, f'Failed to extract key from {source} ({err})'
        c, err = cert_from_pem(pem_bytes)
        if err is not None:
            return None, f'Failed to extract cert from {source} ({err})'
        return c, None

    def manage(self, pem_path, min_days_left, ttl_days, crl_path, dry_run):
        """
        Validate the certificate at ``pem_path`` and reissue it if required.

        Steps:
          1. Read and parse the PEM, then validate subject, issuer and validity.
          2. Check the RCCA CRL (cached at ``crl_path``). This is skipped with an
             error log if the CRL cannot be fetched.
          3. Ask ``Status.is_update_required`` whether to reissue.
          4. Validate the reissued PEM before replacing the current file.

        :param pem_path: PEM file with the private key first, then the certificate.
        :param min_days_left: strict reissue threshold in days; ``<= 0`` selects
            randomized per-host timing.
        :param ttl_days: desired lifetime of a new certificate (365 when falsy),
            also used as the upper bound on acceptable days left.
        :param crl_path: cache file for the downloaded CRL.
        :param dry_run: only report whether an update is needed, never reissue.
        :returns: ``self.status``; an empty ``error_msg`` means no problem was detected.
        """
        pem_bytes, err = fileutil.read_file(pem_path, 'rb', max_size=1 * 1024 * 1024)
        if err is not None:
            self.status.error_msg = f'Failed to read PEM file: {err}'
            return self.status

        # We cannot update the certificate without the PK and the current cert.
        c, err = self._parse_pem(pem_bytes, pem_path)
        if err is not None:
            self.status.error_msg = err
            return self.status

        now = get_current_utc_tz_aware_datetime()
        err = validate_certificate(c, self.hostname, now, RTC_ISSUER)
        if err is not None:
            # We cannot update an invalid certificate automatically, give up.
            self.status.error_msg = err
            return self.status

        days_left = days_left_for(c, now)
        # Check before storing: Status.days_left rejects negatives with TypeError.
        if days_left < 0:
            # We cannot update an expired certificate automatically, give up.
            self.status.error_msg = 'Certificate already expired'
            return self.status
        self.status.days_left = days_left

        certs, err = fetch_crl(RCCA_CRL_URL, crl_path, CRL_MAX_SIZE_BYTES)
        if err is None:
            # Status.crl_check accepts only non-negative ints.
            checked_at = int(time.time())
            self.status.crl_check = checked_at
            self.crl_check = checked_at  # legacy attribute, kept for compatibility
            if is_revoked(certs, c):
                # We cannot update a revoked certificate automatically, give up.
                self.status.error_msg = 'Certificate is revoked'
                return self.status
        else:
            log.error(f'Unable to get CRL, revoke check skipped ({err})')

        self.status.error_msg = ''  # current certificate is usable

        # Use the same lifetime for the reissue request and for the "too many days
        # left" bound. Otherwise a falsy ttl_days would make every run reissue.
        ttl_days = ttl_days or 365

        if not self.status.is_update_required(now, min_days_left, ttl_days, self.hostname):
            self.status.error_msg = ''
            log.debug(f'Certificate {pem_path} is okay: days_left={days_left}')
            return self.status

        log.warning(f'Need to update certificate: days_left={days_left}, '
                    f'min_days={min_days_left} max_days={ttl_days}')
        if dry_run:
            return self.status

        new_pem, err = update_certificate(pem_path,
                                          PROD_SERVER,
                                          PROD_PORT,
                                          ttl_days=ttl_days)
        if err is not None:
            # NOTE(review): status keeps an empty error_msg here although the
            # reissue failed; only the log records the failure.
            log.error(f'Failed to update certificate: {err}')
            return self.status

        # TODO(mixas): verify once again and remove post-update checks:
        # It seems post-update checks are useless: the new PEM may simply be put
        # into its place. There is no second chance to issue a new cert using the
        # old one, because the old one is automatically revoked when a new cert
        # is issued. Together with the old cert kept as `.old`, this would give
        # all the info required for debugging. The cert status for monitoring
        # will be updated on the next check cycle. Alternatively, the checks may
        # be factored out into a separate method.
        # NOTE(review): as written, a new PEM failing these checks is dropped
        # without being saved.
        new_cert, err = self._parse_pem(new_pem, 'new PEM')
        if err is not None:
            self.status.error_msg = err
            return self.status

        now = get_current_utc_tz_aware_datetime()
        err = validate_certificate(new_cert, self.hostname, now, RTC_ISSUER)
        if err is not None:
            self.status.error_msg = err
            return self.status

        try:
            save_pem(new_pem, pem_path)
        except Exception as e:
            self.status.error_msg = f'Failed to save new PEM: {e}'
            return self.status

        days_left = days_left_for(new_cert, now)
        if days_left < 0:
            self.status.error_msg = 'New certificate is expired'
            return self.status
        self.status.days_left = days_left

        self.status.error_msg = ''
        log.info('Certificate updated successfully')
        return self.status


class Status(object):
    """Certificate status."""

    def __init__(self, crl_check=0, days_left=0, error_msg='Unknown error'):
        """
        Build a certificate status record.

        The defaults are deliberately the worst combination. It is better to
        reissue a certificate once more than to break automatic updates entirely.

        Runtime-dependent decisions (like "update required") are not stored. They
        are computed on demand by ``is_update_required``, because the same
        certificate may be fine in one environment and need an update in another
        (for example, a host moved from prod to testing).

        :param crl_check: timestamp of the last CRL check; non-negative int.
        :param days_left: days until certificate expiry; non-negative int.
        :param error_msg: empty string when no problem was detected.
        :raises TypeError: if any value has the wrong type or is negative.
        """
        self.crl_check = crl_check
        self.days_left = days_left
        self.error_msg = error_msg

    @property
    def crl_check(self):
        return self.__crl_check

    @crl_check.setter
    def crl_check(self, tstamp):
        # Strict: only plain non-negative ints (no bools, no floats).
        if tstamp.__class__ is not int or tstamp < 0:
            raise TypeError(tstamp)
        self.__crl_check = tstamp

    @property
    def days_left(self):
        return self.__days_left

    @days_left.setter
    def days_left(self, days):
        # Strict: only plain non-negative ints (no bools, no floats).
        if days.__class__ is not int or days < 0:
            raise TypeError(days)
        self.__days_left = days

    @property
    def error_msg(self):
        return self.__error_msg

    @error_msg.setter
    def error_msg(self, msg):
        if msg.__class__ is not str:
            raise TypeError(msg)
        self.__error_msg = msg
        # Every assignment, even of an unchanged message, refreshes mtime.
        self.__mtime = time.time()

    def dump(self):
        """Return the status as a JSON-serializable dict whose keys match ``__init__``."""
        return dict(
            crl_check=self.__crl_check,
            days_left=self.__days_left,
            error_msg=self.__error_msg,
        )

    def dump_to_file(self, path):
        """
        Write ``dump()`` as JSON to ``path`` via ``fileutil.atomic_write``.

        The file times are set to the moment ``error_msg`` was last assigned.

        :raises Exception: if the write fails.
        """
        serialized = json.dumps(self.dump())
        stamp = self.__mtime
        err = fileutil.atomic_write(path, serialized, times=(stamp, stamp))
        if err is not None:
            raise Exception(err)

    def is_update_required(self, now, cert_min_days, cert_max_days, hostname):
        """
        Decide whether the certificate should be reissued.

        * More than ``cert_max_days`` left: yes.
        * ``cert_min_days > 0``: strict threshold; yes iff fewer days are left.
        * Otherwise: ``cert_needs_reissue`` decides, using randomized per-host timing.
        """
        if self.days_left > cert_max_days:
            # We must have been moved from production to prestable
            return True

        if cert_min_days > 0:  # strict threshold, obey.
            return self.days_left < cert_min_days

        # Try to randomize updates from host to host and make the cert update
        # process smooth across the cluster (even if hosts were installed at
        # the same time). This avoids certs expiring in big numbers within
        # small time intervals, and lets us lose hosts slowly even if a cert
        # infrastructure disaster lasts for weeks.
        return cert_needs_reissue(now, self.days_left, h=hostname)

    def is_valid(self):
        """Return True when no error message is set."""
        return not self.__error_msg

    @classmethod
    def load_from_file(cls, filename):
        """Load a status written by ``dump_to_file``; ``mtime`` comes from the file itself."""
        with open(filename) as fh:
            loaded = cls(**json.load(fh))
            loaded.__mtime = os.fstat(fh.fileno()).st_mtime
        return loaded

    @property
    def mtime(self):
        return self.__mtime
