import io
import tarfile

import os
import re
import six
import time
from asn1crypto import pem
from asn1crypto.x509 import Certificate


# Matches every character that is not a word character (letters, digits, underscore; Unicode-aware
# for text patterns). normalize_cert_id replaces each match with "_".
normalized_cert_id_re = re.compile(r'\W')


def get_end_entity_cert(cert_bytes):
    """
    Extract the end-entity certificate from a PEM certificate chain, checking the chain order.

    The chain must consist only of "CERTIFICATE" PEM blocks. The first block is treated as the
    end-entity certificate. For every adjacent pair, the issuer name of the earlier certificate
    must equal the subject name of the following one (compared via the names' ``sha256``).
    If ``pem.unarmor`` yields no blocks at all, None is returned.

    :type cert_bytes: bytes
    :rtype: Certificate
    :raises ValueError: if a non-"CERTIFICATE" block is found, or an issuer/subject pair does not match
    """
    end_entity = prev_cert = None
    for type_name, _headers, der_bytes in pem.unarmor(cert_bytes, multiple=True):
        # Check the block type before parsing, so that e.g. a private key pasted into the chain
        # is reported with a clear message rather than being handed to the certificate parser.
        if type_name != 'CERTIFICATE':
            raise ValueError('Expected only "CERTIFICATE" parts in public cert, instead got {}'.format(type_name))
        current_cert = Certificate.load(der_bytes)
        # Compare against None explicitly: asn1crypto objects may define __len__, so their
        # truthiness is not a reliable "have we seen a certificate yet" flag.
        if prev_cert is None:
            end_entity = current_cert
        elif prev_cert.issuer.sha256 != current_cert.subject.sha256:
            raise ValueError('Check intermediate certs: issuer not equal subject')
        prev_cert = current_cert
    return end_entity


def pack_certs(private_key, public_key, cert_id):
    """
    Build the gzip-compressed tar bundle used by the balancer for one certificate.

    Members, in archive order:
      ./priv/1st.<cert_id>.key, ./priv/2nd.<cert_id>.key, ./priv/3rd.<cert_id>.key
          48 fresh random bytes each (os.urandom)
      ./priv/<cert_id>.pem      -- ``private_key`` verbatim
      ./allCAs-<cert_id>.pem    -- ``public_key`` verbatim
    Every member gets the same mtime, taken once when the function starts.

    :type private_key: bytes
    :type public_key: bytes
    :type cert_id: six.text_type
    :rtype: bytes
    """
    out = io.BytesIO()
    stamp = time.time()

    def put(archive, member_name, payload):
        # One regular-file member whose content is exactly ``payload``.
        info = tarfile.TarInfo(name=member_name)
        info.mtime = stamp
        info.size = len(payload)
        archive.addfile(tarinfo=info, fileobj=io.BytesIO(payload))

    with tarfile.open(fileobj=out, mode='w:gz') as archive:
        for ordinal in ('1st', '2nd', '3rd'):
            put(archive, './priv/{}.{}.key'.format(ordinal, cert_id), os.urandom(48))
        put(archive, './priv/{}.pem'.format(cert_id), private_key)
        put(archive, './allCAs-{}.pem'.format(cert_id), public_key)
    return out.getvalue()


def tarball_needs_update(cert_tarball, private_key, public_key, cert_id):
    """
    Decide whether a stored secrets.tgz bundle has to be regenerated.

    The bundle is up to date only if it can be read and contains both "./priv/<cert_id>.pem"
    and "./allCAs-<cert_id>.pem" as regular files with exactly the expected contents.
    Other members are not inspected. Files are checked in that order; the first problem wins.

    :type cert_tarball: bytes | None
    :param private_key: expected content of ./priv/<cert_id>.pem
    :type private_key: bytes | six.text_type
    :param public_key: expected content of ./allCAs-<cert_id>.pem
    :type public_key: bytes | six.text_type
    :type cert_id: six.text_type
    :returns: (needs_update, human-readable reason); the reason is None when no update is needed
    :rtype: tuple[bool, six.text_type | None]
    """
    if not cert_tarball:
        return True, 'added secrets.tgz'
    # A list (not a dict) so the check order, and therefore the reported reason, is deterministic.
    expected_files = [
        ('./priv/{}.pem'.format(cert_id), private_key),
        ('./allCAs-{}.pem'.format(cert_id), public_key),
    ]
    try:
        with tarfile.open(fileobj=io.BytesIO(cert_tarball), mode='r:gz') as t:
            for path, expected in expected_files:
                # Tar members are always read back as bytes; encode text so that equal content
                # is not reported as "changed" merely because of a str/bytes type mismatch.
                if not isinstance(expected, bytes) and hasattr(expected, 'encode'):
                    expected = expected.encode('utf-8')
                try:
                    member = t.extractfile(path)
                except KeyError:  # expected file is absent from the archive
                    member = None
                if member is None:  # absent, or present but not a regular file (e.g. a directory)
                    return True, 'file "{}" not found'.format(path)
                if member.read() != expected:
                    return True, '"{}" has changed'.format(path)
    except (tarfile.ReadError, EOFError):  # archive is corrupted or truncated
        return True, 'failed to read old secrets.tgz'
    return False, None


def decimal_to_hexadecimal(value):
    """
    Convert a decimal number given as a string to its hexadecimal form,
    using uppercase letters for the digits above 9. No padding is applied.

    :type value: six.text_type
    :return: six.text_type

    >>> decimal_to_hexadecimal('432745928323474932744197')
    '5BA3342B00020009C005'

    >>> decimal_to_hexadecimal('2215605510564605947266274012786846555')
    '1AAB5C9194CB1C2AFE170FD441FEB5B'
    """
    number = int(value)
    return '{:X}'.format(number)


def decimal_to_padded_hexadecimal(value):
    """
    Like decimal_to_hexadecimal, but prepend a single "0" when the result has an odd
    number of characters, so that its length is always even.

    :type value: six.text_type
    :return: six.text_type

    >>> decimal_to_padded_hexadecimal('432745928323474932744197')
    '5BA3342B00020009C005'

    >>> decimal_to_padded_hexadecimal('2215605510564605947266274012786846555')
    '01AAB5C9194CB1C2AFE170FD441FEB5B'
    """
    hex_digits = decimal_to_hexadecimal(value)
    # len % 2 is 1 for odd lengths -> exactly one leading zero, otherwise none.
    padding = '0' * (len(hex_digits) % 2)
    return padding + hex_digits


def ensure_single_file_match(value_name, file_name, old_value_is_set, new_value):
    """
    Refuse a second match for the same slot, otherwise return the matched value as bytes.

    :param value_name: human-readable slot name used in the error message (e.g. "public key")
    :param file_name: name of the file that is being matched right now
    :param old_value_is_set: whether this slot was already filled from another file
    :param new_value: contents of the file being matched
    :return: ``new_value`` ASCII-encoded if it is text, otherwise ``new_value`` unchanged
    :raises RuntimeError: if ``old_value_is_set`` is true
    """
    if not old_value_is_set:
        return new_value.encode('ascii') if isinstance(new_value, six.text_type) else new_value
    raise RuntimeError('Encountered a second file that could match "{}": "{}"'.format(value_name, file_name))


def extract_certs_from_yav_secret(log, flat_cert_id, serial_number, cert_secret):
    """
    Pick the certificate, the private key and the optional secrets.tgz out of a Yav secret.

    :param log: logger; only ``log.debug`` is used
    :type flat_cert_id: six.text_type
    :type serial_number: six.text_type
    :type cert_secret: dict
    :rtype: tuple[bytes, bytes, Optional[bytes]]
    :raises ValueError: if the certificate or the private key cannot be found
    :raises RuntimeError: if an exactly-named file is missing and several files could serve as its fallback

    Only works with secrets created by Certificator during certificate order.
    Files named "<serial_number>_certificate" and "<serial_number>_private_key" are preferred.
    They also may be prefixed with zero to make number of hex-digits even, see TOOLSUP-60682;
    such files are matched exactly only if ``serial_number`` already carries that zero.
    Only when an exactly-named file is missing, the single file ending with "_certificate" /
    "_private_key" is used as a final fallback.
    Text values are ASCII-encoded to bytes; "secrets.tgz" is returned as stored (None if absent).
    """
    public_key_name = '{}_certificate'.format(serial_number)
    private_key_name = '{}_private_key'.format(serial_number)

    public_key = private_key = cert_tarball = None
    # (file name, value) pairs for the fallback; resolved only after we know whether an
    # exactly-named file exists, so extra inexact files never shadow or conflict with it.
    public_candidates = []
    private_candidates = []
    for cert_key, cert_value in six.iteritems(cert_secret):
        if cert_key == 'secrets.tgz':
            cert_tarball = cert_value

        # usual case: dict keys are unique, so an exact name can match at most once
        elif cert_key == public_key_name:
            public_key = ensure_single_file_match('public key', cert_key, False, cert_value)
        elif cert_key == private_key_name:
            private_key = ensure_single_file_match('private key', cert_key, False, cert_value)

        # just-in-case fallback case
        elif cert_key.endswith('_certificate'):
            public_candidates.append((cert_key, cert_value))
        elif cert_key.endswith('_private_key'):
            private_candidates.append((cert_key, cert_value))

    def pick_single(value_name, candidates):
        # Raises RuntimeError (via ensure_single_file_match) on the second candidate.
        value = None
        for index, (file_name, file_value) in enumerate(candidates):
            value = ensure_single_file_match(value_name, file_name, index > 0, file_value)
        return value

    has_inexact_files = False
    if not public_key and public_candidates:
        public_key = pick_single('public key', public_candidates)
        has_inexact_files = True
    if not private_key and private_candidates:
        private_key = pick_single('private key', private_candidates)
        has_inexact_files = True

    if public_key and private_key:
        if has_inexact_files:
            log.debug('cert {}: secret contains files with inexact name matches'.format(flat_cert_id))
        return public_key, private_key, cert_tarball
    raise ValueError("cert {}: secret doesn't contain both public and private cert parts".format(flat_cert_id))


def fill_cert_fields(fields_pb, allcas):
    """
    Copy descriptive fields of a parsed certificate into a CertificateSpec.Fields message.

    Used by arcadia/infra/awacs/tools/awacscertsctl

    ``subject_alternative_names`` is cleared before being refilled, so running this twice on
    the same message does not duplicate entries. A subject or issuer without a common name
    yields an empty ``*_common_name`` field instead of raising KeyError. Signature and
    public-key parameters are only set when present in the certificate.

    :type fields_pb: awacs.proto.modules_pb2.CertificateSpec.Fields
    :type allcas: asn1crypto.x509.Certificate
    """
    tbs_certificate = allcas.native['tbs_certificate']
    signature = tbs_certificate['signature']
    validity = tbs_certificate['validity']
    public_key_info = tbs_certificate['subject_public_key_info']
    del fields_pb.subject_alternative_names[:]  # protect against repeated run on the same spec
    fields_pb.subject_alternative_names.extend(allcas.valid_domains)
    fields_pb.subject = allcas.subject.human_friendly
    # Certificates are not required to carry a CN (identity may live only in SANs); don't crash.
    fields_pb.subject_common_name = allcas.subject.native.get('common_name', '')
    fields_pb.issuer = allcas.issuer.human_friendly
    fields_pb.issuer_common_name = allcas.issuer.native.get('common_name', '')
    fields_pb.serial_number = str(allcas.serial_number)
    fields_pb.version = tbs_certificate['version']
    fields_pb.signature.algorithm_id = signature['algorithm']
    if signature['parameters']:
        fields_pb.signature.parameters = signature['parameters']
    fields_pb.public_key_info.algorithm_id = public_key_info['algorithm']['algorithm']
    if public_key_info['algorithm']['parameters']:
        fields_pb.public_key_info.parameters = public_key_info['algorithm']['parameters']
    # FromDatetime is fed naive datetimes; the native values are presumably UTC — TODO confirm.
    fields_pb.validity.not_before.FromDatetime(validity['not_before'].replace(tzinfo=None))
    fields_pb.validity.not_after.FromDatetime(validity['not_after'].replace(tzinfo=None))


def normalize_cert_id(cert_id):
    """
    Return ``cert_id`` with every character matched by ``normalized_cert_id_re``
    (anything that is not a letter, digit or underscore) replaced by "_".
    """
    normalized = normalized_cert_id_re.sub('_', cert_id)
    return normalized
