import copy
import itertools

from django.conf import settings

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import rsa, ec

from intranet.crt.csr.fields import CsrField, SubjectField, ExtensionField, EXTENSION_FIELD

# Public exponent passed to rsa.generate_private_key() when an RSA key is generated.
CRYPTOGRAPHY_PUBLIC_EXPONENT = 65537
# RSA key size (in bits) passed to rsa.generate_private_key().
CRYPTOGRAPHY_KEY_SIZE = 2048


class CsrConfigMeta(type):
    """Metaclass that collects CSR field declarations of a config class.

    For every class it creates, the metaclass stores three attributes on it:

    * ``subject_fields`` - name -> ``SubjectField`` mapping;
    * ``extension_fields`` - name -> ``ExtensionField`` mapping;
    * ``required_fields`` - set of names of fields whose ``required`` flag
      is truthy.

    Fields inherited from base classes are included; fields declared in the
    class body override inherited ones with the same name.
    """

    def __new__(mcs, class_name, bases, attrs):
        """Create the class and attach the collected field registries to it.

        Raises:
            RuntimeError: if ``SUBJECT_ORDER`` or ``EXTENSION_ORDER`` mention
                a name that is not a declared (or inherited) field.
        """
        config_class = super().__new__(mcs, class_name, bases, attrs)

        subject_fields = {}
        extension_fields = {}
        # Walk the MRO from the most generic ancestor to the closest one so
        # that nearer bases override fields declared further up the chain.
        for base in reversed(config_class.__mro__[1:]):
            subject_fields.update(getattr(base, 'subject_fields', {}))
            extension_fields.update(getattr(base, 'extension_fields', {}))

        # Fields declared directly in the class body win over inherited ones.
        for key, value in attrs.items():
            if not isinstance(value, CsrField):
                continue
            if isinstance(value, SubjectField):
                subject_fields[key] = value
            if isinstance(value, ExtensionField):
                extension_fields[key] = value

        all_fields = itertools.chain(subject_fields.items(), extension_fields.items())
        required_fields = {name for name, field in all_fields if field.required}

        config_class.subject_fields = subject_fields
        config_class.extension_fields = extension_fields
        config_class.required_fields = required_fields

        # chain() instead of ``+`` so that the two ORDER attributes may be
        # any iterables (e.g. a tuple and a list).
        ordered_fields = set(
            itertools.chain(config_class.SUBJECT_ORDER, config_class.EXTENSION_ORDER)
        )
        declared_fields = subject_fields.keys() | extension_fields.keys()

        # Every ordered name must be declared. A plain set difference is used
        # on purpose: a strict-superset comparison would miss the problem as
        # soon as the class also declares a field that is not ordered.
        missed = ordered_fields - declared_fields
        if missed:
            missed_fields = ', '.join(sorted(missed))
            raise RuntimeError(
                'Missed fields detected in {class_name}: ({fields})'
                .format(class_name=config_class.__name__, fields=missed_fields)
            )

        return config_class


class BaseCsrConfig(object, metaclass=CsrConfigMeta):
    """Declarative description of a CSR, similar in spirit to Django forms.

    A subclass declares the fields that should end up in the CSR as class
    attributes. A field declared with the ``required`` flag must be passed
    as a keyword argument to ``__init__``; any other keyword arguments passed
    to ``__init__`` are ignored.

    Subclasses must define ``SUBJECT_ORDER`` and ``EXTENSION_ORDER``. These
    list which fields are put into the certificate request and in what
    order: ``SUBJECT_ORDER`` names the fields that go into the DN,
    ``EXTENSION_ORDER`` names the extensions.
    """

    SUBJECT_ORDER = []
    EXTENSION_ORDER = []

    def __new__(cls, *args, **kwargs):
        """Create an instance that owns deep copies of the class-level fields.

        Each copied field is available both as an instance attribute and via
        the instance-level ``subject_fields`` / ``extension_fields`` dicts,
        which shadow the class-level registries built by the metaclass.
        """
        new_config = super(BaseCsrConfig, cls).__new__(cls)

        new_config.subject_fields = {}
        for name, field in cls.subject_fields.items():
            new_field = copy.deepcopy(field)
            setattr(new_config, name, new_field)
            new_config.subject_fields[name] = new_field

        new_config.extension_fields = {}
        for name, field in cls.extension_fields.items():
            new_field = copy.deepcopy(field)
            setattr(new_config, name, new_field)
            new_config.extension_fields[name] = new_field

        return new_config

    def __init__(self, **kwargs):
        """Set the values of all required fields from ``kwargs``.

        Keyword arguments that do not correspond to a required field are
        ignored.

        Raises:
            RuntimeError: if a value for any required field is not supplied.
        """
        self._csr = None
        self._private_key = None

        # Plain set difference: extra, ignored kwargs must not hide the fact
        # that a required one is absent.
        missed = self.required_fields.difference(kwargs)
        if missed:
            missed_fields = ', '.join(sorted(missed))
            raise RuntimeError(
                'Missed initial arguments for {class_name}(): ({fields})'
                .format(fields=missed_fields, class_name=self.__class__.__name__)
            )

        for name in self.required_fields:
            getattr(self, name).set_value(kwargs[name])

    def update_extensions_with_context(self, context):
        """Add or replace extension fields using values from ``context``.

        ``context`` maps an extension field name to an iterable of strings.
        For each entry the strings are joined with commas and passed to the
        field class registered under that name in ``EXTENSION_FIELD``.

        The instance-level ``EXTENSION_ORDER`` keeps the existing order;
        names from ``context`` that are not listed yet are appended in the
        order they appear in ``context``.

        Raises:
            KeyError: if a name in ``context`` is not in ``EXTENSION_FIELD``.
        """
        # Build the new fields first so that an unknown name leaves the
        # config untouched.
        extensions = {
            field_name: EXTENSION_FIELD[field_name](','.join(field_value))
            for field_name, field_value in context.items()
        }
        # dict.fromkeys() de-duplicates while preserving order; a set union
        # would make the resulting extension order arbitrary.
        self.EXTENSION_ORDER = list(
            dict.fromkeys(itertools.chain(self.EXTENSION_ORDER, context))
        )
        self.extension_fields.update(extensions)

    def get_subject_name(self):
        """Return an ``x509.Name`` built from the fields in ``SUBJECT_ORDER``."""
        subject_attributes = [
            self.subject_fields[name].get_attribute()
            for name in self.SUBJECT_ORDER
        ]
        return x509.Name(subject_attributes)

    def get_extensions(self, **kwargs):
        """Return the extensions of the fields listed in ``EXTENSION_ORDER``.

        ``kwargs`` are accepted but not used by this base implementation.
        """
        return [
            self.extension_fields[name].get_extension()
            for name in self.EXTENSION_ORDER
        ]

    def _create_csr(self, is_ecc=False):
        """Generate a private key and a CSR signed with it (SHA-256).

        The results are stored in ``self._private_key`` and ``self._csr``.
        All extensions are added as non-critical.
        """
        subject_name = self.get_subject_name()
        extensions = self.get_extensions()
        private_key = self.generate_private_key(is_ecc=is_ecc)

        builder = x509.CertificateSigningRequestBuilder()
        builder = builder.subject_name(subject_name)

        for extension in extensions:
            builder = builder.add_extension(extension, critical=False)

        self._csr = builder.sign(private_key, hashes.SHA256(), default_backend())
        self._private_key = private_key

    def generate_private_key(self, is_ecc=False):
        """Return a new private key.

        An elliptic-curve key on ``settings.ELLIPTIC_CURVE`` is generated when
        ``is_ecc`` is true; otherwise an RSA key with the module-level
        exponent and key size.
        """
        if is_ecc:
            return ec.generate_private_key(
                curve=settings.ELLIPTIC_CURVE,
                backend=default_backend()
            )
        else:
            return rsa.generate_private_key(
                public_exponent=CRYPTOGRAPHY_PUBLIC_EXPONENT,
                key_size=CRYPTOGRAPHY_KEY_SIZE,
                backend=default_backend(),
            )

    def get_csr(self, pem=True, is_ecc=False):
        """Return the CSR, creating it (and the private key) on first use.

        ``is_ecc`` only has an effect if the CSR has not been created yet.
        Returns PEM bytes when ``pem`` is true, otherwise the CSR object.
        """
        if self._csr is None:
            self._create_csr(is_ecc=is_ecc)

        return self.get_pem_csr() if pem else self._csr

    def get_private_key(self, pem=True, is_ecc=False):
        """Return the private key, creating it (and the CSR) on first use.

        ``is_ecc`` mirrors the parameter of ``get_csr`` and only has an effect
        if the key has not been created yet. Returns PEM bytes when ``pem`` is
        true, otherwise the key object.
        """
        if self._private_key is None:
            self._create_csr(is_ecc=is_ecc)

        return self.get_pem_private_key() if pem else self._private_key

    def get_pem_csr(self):
        """Return the already created CSR serialized as PEM.

        ``self._csr`` must have been created beforehand (e.g. by ``get_csr``).
        """
        return self._csr.public_bytes(serialization.Encoding.PEM)

    def get_pem_private_key(self):
        """Return the already created private key as unencrypted PKCS8 PEM.

        ``self._private_key`` must have been created beforehand (e.g. by
        ``get_private_key``).
        """
        return self._private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )
