import abc
import os.path
import functools
import logging

from django.conf import settings
from django.db import OperationalError, transaction
from django.utils import timezone
from ylog.context import log_context

from intranet.crt.constants import CERT_STATUS, CERT_TYPE
from intranet.crt.core.ca.exceptions import RetryCaException, CaError, ValidationCaError
from intranet.crt.utils.domain import get_domain_levels

log = logging.getLogger(__name__)


# Certificate statuses from which ``BaseCA.revoke`` accepts a request; the check
# itself is performed by the ``log_and_status`` decorator applied to ``revoke``.
REVOKE_ACCEPTED_STATUSES = [
    CERT_STATUS.REQUESTED,
    CERT_STATUS.VALIDATION,
    CERT_STATUS.ISSUED,
    CERT_STATUS.HOLD,
]


def log_and_status(accepted_statuses):
    """Build a decorator that guards a CA operation performed on a certificate.

    The decorated method must have the signature ``(self, cert)``. Before it
    runs, two checks are made:

    * if ``self.supported_types`` is non-empty and does not contain
      ``cert.type.name``, ``ValidationCaError`` is raised;
    * inside a ``log_context`` tagged with ``certificate_id=cert.id``, an info
      line naming the operation (the wrapped function's ``__name__``) is
      logged, and ``ValidationCaError`` is raised if ``cert.status`` is not in
      ``accepted_statuses``.

    :param accepted_statuses: certificate statuses for which the operation is allowed.
    :return: a decorator for ``(self, cert)`` methods.
    """

    def decorator(func):
        @functools.wraps(func)
        def guarded(self, cert):
            allowed_types = self.supported_types
            if allowed_types and cert.type.name not in allowed_types:
                raise ValidationCaError(
                    f'Invalid {cert.type.name} cert for {cert.ca_name}'
                )

            with log_context(certificate_id=cert.id):
                operation = func.__name__
                log.info(f'Trying to {operation} certificate')

                if cert.status in accepted_statuses:
                    return func(self, cert)

                raise ValidationCaError(
                    f'Can not {operation} certificate with status {cert.status}'
                )

        return guarded

    return decorator


class BaseCA(object, metaclass=abc.ABCMeta):
    """Abstract base for certificate-authority backends.

    Subclasses implement ``_issue`` (and override ``revoke``/``hold``/``unhold``
    where supported) and provide ``chain_filename``. The public lifecycle
    methods are wrapped with ``log_and_status``, which validates the
    certificate type against ``supported_types`` and the certificate status
    against the statuses each operation accepts.
    """

    # Capability/config flags read elsewhere; their exact use is not visible
    # in this module.
    IS_EXTERNAL = True
    IS_ASYNC = True
    IS_SUPPORTING_MULTIPLE_WILDCARDS = True

    PERMISSION_REQUIRED = None

    ISSUE_TIMEOUT_IN_DAYS = 2

    def __init__(self, supported_types=None):
        """
        :param supported_types: certificate type names this CA accepts; a falsy
            value disables the type check performed by ``log_and_status``.
        """
        self.supported_types = supported_types

    @property
    @abc.abstractmethod
    def chain_filename(self):
        """File name of the CA chain inside ``settings.CRT_CA_CHAINS_PATH``.

        Subclasses must override this with a plain class attribute (a string):
        ``get_chain_path`` reads it from the class, not from an instance.
        """

    @classmethod
    def get_chain_path(cls, is_ecc=False):
        """Return the filesystem path of this CA's chain file.

        :param is_ecc: when true, prefer ``ecc_chain_filename`` if the class
            defines it, falling back to ``chain_filename``.
        """
        chain_filename = cls.chain_filename
        if is_ecc:
            chain_filename = getattr(cls, 'ecc_chain_filename', chain_filename)
        return os.path.join(settings.CRT_CA_CHAINS_PATH, chain_filename)

    @staticmethod
    def get_autovalidate_domain_names():
        """Return the set of ``HostToApprove.host`` values."""
        from intranet.crt.core.models import HostToApprove
        query = HostToApprove.objects.all()
        return set(query.values_list('host', flat=True))

    @classmethod
    def find_non_auto_hosts(cls, fqdns):
        """Return the FQDNs that none of the auto-validated domains cover.

        An FQDN counts as covered when any of its domain levels (as produced by
        ``get_domain_levels``) is in ``get_autovalidate_domain_names()``.
        """
        auto_hosts = cls.get_autovalidate_domain_names()
        return {
            fqdn
            for fqdn in fqdns
            if auto_hosts.isdisjoint(get_domain_levels(fqdn))
        }

    @abc.abstractmethod
    def _issue(self, cert):
        """Request the certificate from the CA and return the certificate payload.

        The return value is passed to ``cert.set_certificate``. Implementations
        raise ``RetryCaException`` for retryable failures and ``CaError`` for
        fatal ones.
        """

    @log_and_status([CERT_STATUS.REQUESTED, CERT_STATUS.VALIDATION])
    @transaction.atomic
    def issue(self, cert):
        """Issue ``cert`` through this CA and persist the outcome.

        Accepted only for certificates in REQUESTED or VALIDATION status, and
        runs in a DB transaction.

        * ``OperationalError`` and ``RetryCaException`` are logged and
          swallowed. The certificate is left untouched, presumably so that a
          later run retries it.
        * ``CaError`` sets the certificate status to ERROR and stores the error
          text in ``error_message``.
        * On success, the payload is stored, the status becomes ISSUED,
          ``issued`` is stamped, and ``hold_old_certificates`` is called.
        """
        try:
            certificate = self._issue(cert)
        except OperationalError as error:
            # NOTE(review): Django advises against catching DB errors inside an
            # ``atomic`` block without an inner savepoint. The transaction may
            # be unusable at this point; confirm this path only returns.
            log.info('OperationalError, retry. Args: %s', error.args, exc_info=True)
            return
        except RetryCaException as error:
            log.info('Retry ca exception: %s', error.args, exc_info=True)
            return
        except CaError as error:
            log.exception('Critical ca exception')

            cert.status = CERT_STATUS.ERROR
            cert.error_message = str(error)
            cert.save()

            return

        cert.set_certificate(certificate)

        cert.status = CERT_STATUS.ISSUED
        cert.issued = timezone.now()

        cert.save()

        self.hold_old_certificates(cert)

    @log_and_status(REVOKE_ACCEPTED_STATUSES)
    def revoke(self, cert):
        """Revoke ``cert``; not supported by the base class."""
        raise NotImplementedError()

    @log_and_status([CERT_STATUS.ISSUED])
    def hold(self, cert):
        """Put an ISSUED ``cert`` on hold; not supported by the base class."""
        raise NotImplementedError()

    @log_and_status([CERT_STATUS.HOLD])
    def unhold(self, cert):
        """Release a held ``cert``; not supported by the base class."""
        raise NotImplementedError()

    @staticmethod
    def get_yandex_domain_names():
        """Return the set of ``HostToApprove.host`` values.

        NOTE(review): currently computes the same set as
        ``get_autovalidate_domain_names``.
        """
        from intranet.crt.core.models import HostToApprove
        return set(HostToApprove.objects.values_list('host', flat=True))

    @classmethod
    def find_non_whitelisted_hosts(cls, fqdns):
        """Return FQDNs that are not whitelisted; the base class reports none."""
        return set()

    def hold_old_certificates(self, cert):
        """Queue the owner's other issued VPN-token certificates for hold.

        Runs only when ``cert`` is a VPN token or one of
        ``CERT_TYPE.USER_HARDWARE_TYPES``. It selects every other issued
        VPN-token certificate of ``cert.user`` whose ``revoke_at`` is unset.
        Each one is passed to its controller's ``add_to_hold_queue``, with
        ``revoke_at`` set ``settings.CRT_VPN_TOKEN_HOLD_ON_REISSUE_AFTER_DAYS``
        days from now.

        :param cert: the newly issued certificate.
        """
        # Local import, matching the file's style: ``timezone.timedelta`` only
        # works because django.utils.timezone happens to import it.
        from datetime import timedelta

        if cert.type.name not in {CERT_TYPE.VPN_TOKEN} | CERT_TYPE.USER_HARDWARE_TYPES:
            return

        old_certs = cert.user.certificates.issued().filter(
            type__name=CERT_TYPE.VPN_TOKEN,
            revoke_at=None,
        ).exclude(pk=cert.pk)

        # Use the type of the triggering (new) certificate. The queryset only
        # yields VPN_TOKEN certificates, so their own type carries no information.
        description = f'by api request {cert.type.name} certificate'
        # One deadline for the whole batch instead of a fresh now() per item.
        revoke_at = timezone.now() + timedelta(
            days=settings.CRT_VPN_TOKEN_HOLD_ON_REISSUE_AFTER_DAYS,
        )
        for old_cert in old_certs:
            old_cert.controller.add_to_hold_queue(
                description=description,
                revoke_at=revoke_at,
            )
