import abc

import pem
from OpenSSL import crypto
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from django.utils import timezone
from django.utils.encoding import force_text, force_bytes

from intranet.crt.constants import CERT_EXTENSION
from intranet.crt.exceptions import (
    CrtError,
    CsrMultipleOid,
)
from intranet.crt.utils.punicode import decode_punicode


def serial_number_converter(serial_number_int):
    """Render an integer serial number as upper-case hex with an even digit count.

    A single leading ``'0'`` is prepended when the hex form has an odd number
    of characters, e.g. ``15 -> '0F'`` and ``256 -> '0100'``.
    """
    hex_digits = format(serial_number_int, 'X')
    needs_pad = len(hex_digits) % 2 == 1
    return ('0' + hex_digits) if needs_pad else hex_digits


def get_x509_custom_extensions(x509_obj):
    """Collect the values of the project's custom extensions present on ``x509_obj``.

    Args:
        x509_obj: a ``cryptography`` X.509 object exposing ``extensions``
            (certificate or CSR).

    Returns:
        dict mapping each field of ``CERT_EXTENSION.ALL`` that is present on
        the object to the corresponding extension's ``value``. Fields whose
        extension is absent are omitted.
    """
    found = {}
    for field in CERT_EXTENSION.ALL:
        oid = CERT_EXTENSION.OID[field]
        try:
            found[field] = x509_obj.extensions.get_extension_for_oid(oid).value
        except x509.ExtensionNotFound:
            # Missing extensions are expected; just skip this field.
            continue
    return found


class BasePemCertificate(object, metaclass=abc.ABCMeta):
    """Base wrapper around a PEM-encoded X.509 object (certificate or CSR).

    Subclasses provide ``loader``: a callable taking PEM ``bytes`` and
    returning a ``cryptography`` X.509 object. The parsed object is stored in
    ``self.x509_object``, and the helpers below read from it.
    """

    # `abc.abstractproperty` is deprecated since Python 3.3; stacking
    # `property` over `abc.abstractmethod` is the documented replacement.
    # Subclasses satisfy it with a plain class attribute.
    @property
    @abc.abstractmethod
    def loader(self):
        """Callable turning PEM bytes into a ``cryptography`` X.509 object."""

    def __init__(self, pem_data):
        # `pem_data` may be text or bytes; the loader is always given bytes.
        self.x509_object = self.loader(force_bytes(pem_data))

    def get_extension(self, extension_type):
        """Return the first extension whose ``value`` is an ``extension_type``.

        Returns None when no such extension is present.
        """
        for extension in self.x509_object.extensions:
            if isinstance(extension.value, extension_type):
                return extension

        return None

    def get_subject_attribute(self, oid):
        """Return the value of the single subject attribute identified by ``oid``.

        Returns None when the attribute is absent.

        Raises:
            CsrMultipleOid: if the subject holds the attribute more than once.
        """
        attributes = self.x509_object.subject.get_attributes_for_oid(oid)
        if not attributes:
            return None
        if len(attributes) > 1:
            raise CsrMultipleOid()
        return attributes[0].value

    @property
    def is_ecc(self):
        """True when the signature algorithm is ecdsa-with-SHA256.

        NOTE(review): other ECDSA variants (e.g. SHA-384/512) are not matched;
        confirm that is intended.
        """
        # Compare with the public OID constant instead of the private
        # `ObjectIdentifier._name`, which is an implementation detail.
        return (
            self.x509_object.signature_algorithm_oid
            == x509.oid.SignatureAlgorithmOID.ECDSA_WITH_SHA256
        )

    @property
    def sans(self):
        """DNS names from the SubjectAlternativeName extension, or ``[]``.

        Each name is passed through ``decode_punicode`` (presumably punycode to
        Unicode; see that helper).
        """
        sans_extension = self.get_extension(x509.SubjectAlternativeName)
        if sans_extension is None:
            return []

        sans = sans_extension.value.get_values_for_type(x509.DNSName)
        return [decode_punicode(san) for san in sans]

    @property
    def common_name(self):
        """The subject CN, or None when absent.

        CN entries whose value is exactly ``'Users'`` are ignored before the
        uniqueness check.

        Raises:
            CsrMultipleOid: if more than one remaining CN is present.
        """
        cns = self.x509_object.subject.get_attributes_for_oid(x509.OID_COMMON_NAME)
        cns = [cn for cn in cns if cn.value != 'Users']
        if len(cns) > 1:
            raise CsrMultipleOid()
        return cns[0].value if cns else None

    @property
    def organizational_unit(self):
        """The single subject OU value, or None (see ``get_subject_attribute``)."""
        return self.get_subject_attribute(x509.OID_ORGANIZATIONAL_UNIT_NAME)

    @property
    def email_address(self):
        """The single subject emailAddress value, or None (see ``get_subject_attribute``)."""
        return self.get_subject_attribute(x509.OID_EMAIL_ADDRESS)


class PemCertificateRequest(BasePemCertificate):
    """Wrapper for a PEM-encoded certificate signing request (CSR)."""

    # Backend-bound loader that parses PEM bytes into a CSR object.
    loader = default_backend().load_pem_x509_csr


class PemCertificate(BasePemCertificate):
    """Wrapper for a PEM-encoded X.509 certificate with validity/serial helpers."""

    loader = default_backend().load_pem_x509_certificate

    # cryptography's `not_valid_after` / `not_valid_before` are naive datetimes
    # expressed in UTC. `replace(tzinfo=...)` attaches UTC the same way pytz's
    # `utc.localize` does, but also works when `timezone.utc` is the stdlib
    # `datetime.timezone.utc` (as in newer Django), which has no `localize`.

    @property
    def not_after(self):
        """End of the validity period as a timezone-aware UTC datetime."""
        return self.x509_object.not_valid_after.replace(tzinfo=timezone.utc)

    @property
    def not_before(self):
        """Start of the validity period as a timezone-aware UTC datetime."""
        return self.x509_object.not_valid_before.replace(tzinfo=timezone.utc)

    @property
    def serial_number(self):
        """Serial number as even-length upper-case hex (see ``serial_number_converter``)."""
        return serial_number_converter(self.x509_object.serial_number)


def create_pfx(db_certificate, password, include_cacerts, cacerts_filename):
    """Bundle a stored certificate and its private key into a PKCS#12 archive.

    Args:
        db_certificate: object exposing PEM data in ``priv_key`` and
            ``certificate``.
        password: passphrase passed to ``PKCS12.export``.
        include_cacerts: when truthy, also embed every certificate parsed
            from ``cacerts_filename`` as the CA chain.
        cacerts_filename: path to a PEM file with one or more certificates.

    Returns:
        The PKCS#12 blob returned by ``PKCS12.export``.
    """
    key = crypto.load_privatekey(crypto.FILETYPE_PEM, db_certificate.priv_key)
    cert = crypto.load_certificate(crypto.FILETYPE_PEM, db_certificate.certificate)

    bundle = crypto.PKCS12()
    bundle.set_privatekey(key)
    bundle.set_certificate(cert)

    if include_cacerts:
        bundle.set_ca_certificates([
            crypto.load_certificate(crypto.FILETYPE_PEM, ca_pem.as_bytes())
            for ca_pem in pem.parse_file(cacerts_filename)
        ])

    return bundle.export(password)


def parse_pfx(pfx, passphrase):
    """Extract the certificate and private key from a PKCS#12 archive.

    Args:
        pfx: the PKCS#12 blob.
        passphrase: passphrase protecting the archive.

    Returns:
        ``(pem_certificate, pem_private_key)``, both as text.
    """
    bundle = crypto.load_pkcs12(pfx, passphrase)
    cert_obj, key_obj = bundle.get_certificate(), bundle.get_privatekey()

    return (
        force_text(crypto.dump_certificate(crypto.FILETYPE_PEM, cert_obj)),
        force_text(crypto.dump_privatekey(crypto.FILETYPE_PEM, key_obj)),
    )


class PrivateKeyCryptographer(object):
    """Encrypt and decrypt PEM private keys using a list of passwords.

    ``passwords[0]`` is used for encryption. ``decrypt`` tries each password
    in order (presumably so that keys encrypted under an older password stay
    readable after rotation; confirm with the caller's configuration).
    """

    def __init__(self, passwords):
        # Sequence of passwords (text or bytes); only index 0 is used to encrypt.
        self.passwords = passwords

    def encrypt(self, pem_data):
        """Return ``pem_data`` as PKCS#8 PEM bytes encrypted with ``passwords[0]``.

        ``pem_data`` must be an unencrypted PEM private key (loaded with no
        password). Raises IndexError if ``passwords`` is empty.
        """
        private_key = default_backend().load_pem_private_key(force_bytes(pem_data), None)

        return private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.BestAvailableEncryption(force_bytes(self.passwords[0])),
        )

    def decrypt(self, pem_data):
        """Return ``pem_data`` re-serialized as unencrypted PKCS#8 PEM bytes.

        Raises:
            CrtError: if none of ``self.passwords`` can load the key (the last
                loader error, if any, is chained as the cause).
        """
        pem_bytes = force_bytes(pem_data)
        private_key = None
        last_error = None
        for password in self.passwords:
            try:
                private_key = default_backend().load_pem_private_key(pem_bytes, force_bytes(password))
            except ValueError as error:
                # Wrong password (or unreadable data): remember it, try the next one.
                last_error = error
            else:
                # Stop at the first password that works; trying the remaining
                # ones would only repeat expensive decryption attempts.
                break

        if private_key is None:
            raise CrtError('Can not decrypt private key (invalid passwords)') from last_error

        return private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )


def request_contains_custom_crt_extensions(request_bytes):
    """Return True if the PEM CSR carries any extension listed in ``CERT_EXTENSION.OID``.

    Args:
        request_bytes: PEM-encoded certificate signing request (text or bytes).
    """
    x509_object = PemCertificateRequest(request_bytes).x509_object
    custom_oids = set(CERT_EXTENSION.OID.values())
    # Iterate the public `Extensions` sequence instead of its private
    # `_extensions` list, which is a cryptography implementation detail.
    return any(extension.oid in custom_oids for extension in x509_object.extensions)
