# coding: utf-8
import inject
import monotonic
import six
import time
from sepelib.core import config

from awacs.lib import ctlmanager, context
from awacs.lib.certificator import ICertificatorClient, CertificatorClient
from awacs.lib.order_processor.model import has_actionable_spec, needs_removal
from awacs.lib.strutils import to_full_id
from awacs.lib.ya_vault import IYaVaultClient, YaVaultClient
from awacs.model import events, cache, util
from awacs.model.certs import cert
from awacs.model.dao import IDao, Dao
from awacs.model.errors import ConflictError
from awacs.model.zk import IZkStorage, ZkStorage
from infra.awacs.proto import model_pb2


# Maps the issuer common name found in a cert (`spec.fields.issuer_common_name`) to the CA name
# that is put into a renewal order when the cert spec carries no explicit `certificator.ca_name`.
# This is hacky, but there's no clear way to do it properly: https://st.yandex-team.ru/TOOLSUP-58560
CERTIFICATOR_COMMON_NAMES = {
    'YandexInternalCA': 'InternalCA',
    'Yandex CA': 'CertumProductionCA',
    'Test-YandexInternal-Ca': 'InternalTestCA',
}


class CertCtl(ctlmanager.ContextedCtl):
    """
    Controller of a single certificate identified by `(namespace_id, cert_id)`.

    Every call to `_process` re-reads the cert from the cache and then either:

    * deletes the cert, if it is marked for removal and is not referenced by any balancer, or
    * creates a renewal order, if renewal is forced or the cert is about to expire.

    Both `_process_event` and `_process_empty_queue` delegate to `_process`.
    """
    _cache = inject.attr(cache.IAwacsCache)  # type: cache.AwacsCache
    _zk = inject.attr(IZkStorage)  # type: ZkStorage
    _dao = inject.attr(IDao)  # type: Dao
    _certificator_client = inject.attr(ICertificatorClient)  # type: CertificatorClient
    _yav_client = inject.attr(IYaVaultClient)  # type: YaVaultClient

    # Not referenced in this class; presumably consumed by ctlmanager.ContextedCtl -- TODO confirm.
    EVENTS_QUEUE_GET_TIMEOUT = 10
    # Minimal interval (seconds, monotonic clock) between two consecutive renewal checks
    # and between two consecutive self-deletion checks.
    SELF_ACTIONS_DELAY_INTERVAL = 30
    # Minimal age (seconds, wall clock) of the last cert modification before the cert may be deleted.
    SELF_DELETION_COOLDOWN_PERIOD = 30
    # Forced timeout (seconds) for the "is the cert still referenced?" check done before deletion.
    IS_USED_CHECK_TIMEOUT = 60 * 10

    # sleep more after unexpected exceptions such as 429 TOO MANY REQUESTS (see AWACS-652 for details)
    SLEEP_AFTER_EXCEPTION_TIMEOUT = 60
    SLEEP_AFTER_EXCEPTION_TIMEOUT_JITTER = 30

    def __init__(self, namespace_id, cert_id):
        """
        :param namespace_id: id of the namespace the cert belongs to
        :param cert_id: id of the cert inside that namespace
        """
        name = 'cert-ctl("{}:{}")'.format(namespace_id, cert_id)
        super(CertCtl, self).__init__(name)
        self._cert_id = cert_id
        self._namespace_id = namespace_id
        # Both throttling deadlines start at "now", so the very first pass is allowed to run its checks.
        self._self_deletion_check_deadline = monotonic.monotonic()
        self._renewal_check_deadline = monotonic.monotonic()
        self._full_cert_id = (self._namespace_id, self._cert_id)
        self._allow_automation = bool(config.get_value('run.allow_automatic_cert_management'))
        # Despite its name this is a lead time in seconds: renewal starts that long before `not_after`.
        self._renewal_deadline = int(config.get_value('run.days_until_cert_expiration_to_renew')) * util.SECONDS_IN_DAY
        self._pb = None  # type: model_pb2.Certificate or None

    def _accept_event(self, event):
        """
        Accept only updates of this very cert that carry an actionable spec.

        :rtype: bool
        """
        return (isinstance(event, events.CertUpdate) and
                event.pb.meta.namespace_id == self._namespace_id and
                event.pb.meta.id == self._cert_id and
                has_actionable_spec(event.pb))

    def _start(self, ctx):
        """
        Run one processing pass and then bind `self._callback` to the cache.

        Unexpected exceptions raised by the initial pass are logged and not propagated, so the
        callback is bound regardless of its outcome.
        """
        try:
            self._process(ctx)
        except ctlmanager.UNEXPECTED_EXCEPTIONS as e:
            ctx.log.exception('failed to process cert on start: %s', e)
        self._cache.bind(self._callback)

    def _stop(self):
        """Unbind `self._callback` from the cache."""
        self._cache.unbind(self._callback)

    def _crossed_renewal_deadline(self):
        """
        Tell whether the current wall-clock time is within the renewal lead time of `validity.not_after`.

        :rtype: bool
        :raises RuntimeError: if `validity.not_after` is missing or equals 0
        """
        validity_pb = self._pb.spec.fields.validity
        if not validity_pb.HasField('not_after'):
            raise RuntimeError('Cert has no field "validity.not_after"')
        not_after = validity_pb.not_after.ToSeconds()
        if not_after == 0:
            raise RuntimeError('Cert has "validity.not_after" equal to 0')
        return (not_after - self._renewal_deadline) <= time.time()

    def _needs_renewal(self, ctx):
        """
        Decide whether a renewal must be started now.

        The check is throttled: it is performed at most once per `SELF_ACTIONS_DELAY_INTERVAL`,
        in between it returns False. As a side effect `self._pb` may be replaced with a fresh
        copy read from Zookeeper.

        :type ctx: context.OpCtx
        :rtype: bool
        :raises RuntimeError: see `_crossed_renewal_deadline`
        """
        current_time = monotonic.monotonic()
        if current_time < self._renewal_check_deadline:
            return False
        self._renewal_check_deadline = current_time + self.SELF_ACTIONS_DELAY_INTERVAL

        forced = self._pb.meta.force_renewal
        if forced.value:
            ctx.log.info('renewal is forced by %s with comment "%s"', forced.author, forced.comment)
            return True

        if not self._crossed_renewal_deadline():
            return False

        # avoid race with renewed cert spec: the cached copy may be stale, so re-check against
        # a synced read before reporting that the deadline has been crossed
        self._pb = self._zk.must_get_cert(namespace_id=self._namespace_id, cert_id=self._cert_id, sync=True)
        return self._crossed_renewal_deadline()

    def _fill_cert_renewal_order(self, ctx, cert_renewal_pb):
        """
        Fill `cert_renewal_pb.order.content` using the current cert (`self._pb`).

        :type ctx: context.OpCtx
        :type cert_renewal_pb: model_pb2.CertificateRenewal
        :raises RuntimeError: if the cert source is not supported or the CA name cannot be determined
        """
        spec_pb = self._pb.spec
        fields_pb = spec_pb.fields
        content = cert_renewal_pb.order.content

        if spec_pb.source == model_pb2.CertificateSpec.IMPORTED:
            abc_service_id = spec_pb.imported.abc_service_id
        elif spec_pb.source in (model_pb2.CertificateSpec.CERTIFICATOR,
                                model_pb2.CertificateSpec.CERTIFICATOR_TESTING):
            abc_service_id = spec_pb.certificator.abc_service_id
        else:
            # The message has a single placeholder and it is meant for the source name.
            raise RuntimeError('Unsupported certificate source: {}'.format(
                model_pb2.CertificateSpec.Source.Name(spec_pb.source)))

        # An explicit CA name wins; otherwise derive it from the issuer common name.
        ca_name = spec_pb.certificator.ca_name or CERTIFICATOR_COMMON_NAMES.get(fields_pb.issuer_common_name)
        if not ca_name:
            raise RuntimeError('Unknown CA common name "{}". Supported names: {}'.format(
                fields_pb.issuer_common_name, ', '.join(sorted(CERTIFICATOR_COMMON_NAMES))))

        content.ca_name = ca_name
        content.common_name = fields_pb.subject_common_name
        # The common name is passed separately, so it is left out of the alternative names.
        san = [s for s in fields_pb.subject_alternative_names if s != content.common_name]
        content.subject_alternative_names.extend(san)
        content.abc_service_id = abc_service_id
        content.public_key_algorithm_id = fields_pb.public_key_info.algorithm_id
        content.ttl = self._pb.meta.force_renewal.cert_ttl

    def _make_cert_renewal_pb(self, ctx):
        """
        Build a new renewal for the current cert revision.

        The renewal is created paused unless automatic cert management is allowed by the config.

        :type ctx: context.OpCtx
        :rtype: model_pb2.CertificateRenewal
        :raises RuntimeError: see `_fill_cert_renewal_order`
        """
        cert_renewal_pb = model_pb2.CertificateRenewal()
        cert_renewal_pb.meta.id = self._cert_id
        cert_renewal_pb.meta.namespace_id = self._namespace_id
        cert_renewal_pb.meta.target_rev = self._pb.meta.version
        if self._allow_automation:
            cert_renewal_pb.meta.paused.value = False
        else:
            cert_renewal_pb.meta.paused.value = True
            cert_renewal_pb.meta.paused.comment = 'Automatic renewal paused due to awacs configuration'
        cert_renewal_pb.meta.paused.mtime.GetCurrentTime()
        cert_renewal_pb.meta.paused.author = util.NANNY_ROBOT_LOGIN
        cert_renewal_pb.spec.incomplete = True
        self._fill_cert_renewal_order(ctx, cert_renewal_pb)
        return cert_renewal_pb

    def _start_renewal(self, ctx):
        """
        Create a renewal for the current cert revision unless one already exists.

        :type ctx: context.OpCtx
        :raises RuntimeError: if a renewal exists but targets another cert revision,
                              or if the renewal order cannot be filled
        """
        cert_renewal_pb = self._zk.get_cert_renewal(namespace_id=self._namespace_id,
                                                    cert_renewal_id=self._cert_id,
                                                    sync=True)
        if cert_renewal_pb:
            if cert_renewal_pb.meta.target_rev != self._pb.meta.version:
                raise RuntimeError('Renewal already exists, but has invalid target_rev')
            # a renewal for this very revision is already there, nothing to do
            return

        cert_renewal_pb = self._make_cert_renewal_pb(ctx)
        try:
            self._zk.create_cert_renewal(namespace_id=self._namespace_id,
                                         cert_renewal_id=self._cert_id,
                                         cert_renewal_pb=cert_renewal_pb)
            ctx.log.info('cert renewal successfully started')
        except ConflictError:
            # somebody created the renewal between our read and our write
            ctx.log.info('cert renewal is already in progress')

    def _process(self, ctx):
        """
        Run a single processing pass: delete the cert if it is marked for removal,
        otherwise start a renewal if one is needed.

        :type ctx: context.OpCtx
        """
        self._pb = self._cache.must_get_cert(self._namespace_id, self._cert_id)

        if needs_removal(self._pb):
            ctx.log.debug('Maybe deleting cert marked for removal')
            if self._ready_to_delete(ctx):
                self._self_delete(ctx)
            return

        assert has_actionable_spec(self._pb), 'Guaranteed by CertCtlManager'

        if self._needs_renewal(ctx):
            return self._start_renewal(ctx)

    def _is_used(self):
        """
        A bullet-proof and hugely inefficient method to know if cert is referenced by any balancer state --
        balancer states are not taken from the cache, they are read from Zookeeper with sync=True
        (the list of balancers itself still comes from the cache).
        Let's call it at least for the time being, until we make sure there are no bugs in our caching code.

        :rtype: bool
        """
        for balancer_pb in self._cache.list_all_balancers(namespace_id=self._namespace_id):
            balancer_state_pb = self._zk.must_get_balancer_state(
                namespace_id=balancer_pb.meta.namespace_id,
                balancer_id=balancer_pb.meta.id,
                sync=True)
            # only certs that have at least one status recorded in the state count as referenced
            referenced_full_cert_ids = {
                to_full_id(balancer_state_pb.namespace_id, cert_id)
                for cert_id, cert_state_pb in six.iteritems(balancer_state_pb.certificates)
                if cert_state_pb.statuses
            }
            if self._full_cert_id in referenced_full_cert_ids:
                return True

        return False

    def _self_delete(self, ctx):
        """
        Delete the cert: revoke it and remove it from the storage where required,
        drop its renewal and finally remove the cert itself.

        Returns without deleting anything if the usage check times out or the context is cancelled.

        :type ctx: context.OpCtx
        :raises RuntimeError: if the cert turns out to be still referenced by a balancer state
        """
        ctx.log.info('started self deletion')
        try:
            ctx.log.info('starting _is_used()')
            with ctx.with_forced_timeout(self.IS_USED_CHECK_TIMEOUT):
                is_cert_used = self._is_used()
        except context.CtxTimeoutExceeded:
            ctx.log.warn('_is_used() timed out, returning...')
            return
        except context.CtxTimeoutCancelled:
            ctx.log.debug('ctx is cancelled: %s, returning...', ctx.error())
            return
        ctx.log.info('finished _is_used()')
        if is_cert_used:
            # last line of defence: the cache-based check in _ready_to_delete() said "unused"
            raise RuntimeError("Critical error: would delete a referenced cert if it wasn't for this raise")

        if cert.needs_revocation(self._pb) and self._pb.spec.certificator.order_id:
            if self._pb.meta.unrevokable.value:
                ctx.log.info('not revoking cert, it\'s marked as unrevokable by %s with comment "%s"',
                             self._pb.meta.unrevokable.author, self._pb.meta.unrevokable.comment)
            else:
                ctx.log.info('revoking cert')
                self._certificator_client.revoke_certificate(self._pb.spec.certificator.order_id)
                ctx.log.info('finished revoking cert')
        if cert.needs_storage_removal(self._pb) and self._pb.spec.storage.ya_vault_secret.secret_id:
            ctx.log.info('removing cert from storage')
            self._yav_client.remove_secret(self._pb.spec.storage.ya_vault_secret.secret_id)
            ctx.log.info('finished removing cert from storage')
        if self._cache.get_cert_renewal(*self._full_cert_id):
            ctx.log.info('removing cert renewal from zk')
            self._zk.remove_cert_renewal(*self._full_cert_id)
        ctx.log.info('removing cert from db')
        self._dao.delete_cert(*self._full_cert_id)

    def _ready_to_delete(self, ctx):
        """
        Cheap, cache-based pre-check that tells whether the deletion may be attempted now.

        :type ctx: context.OpCtx
        :rtype: bool
        """
        current_time = monotonic.monotonic()

        # don't check too often
        if current_time < self._self_deletion_check_deadline:
            return False

        # let balancers update their state before proceeding
        self_deletion_elapsed_time = time.time() - self._pb.meta.mtime.ToSeconds()
        if self_deletion_elapsed_time < self.SELF_DELETION_COOLDOWN_PERIOD:
            ctx.log.info('too little time passed since cert modification: %s out of minimum %s',
                         self_deletion_elapsed_time, self.SELF_DELETION_COOLDOWN_PERIOD)
            return False
        # the throttling deadline is moved only once the cooldown has passed
        self._self_deletion_check_deadline = current_time + self.SELF_ACTIONS_DELAY_INTERVAL

        full_balancer_ids = self._cache.list_full_balancer_ids_for_cert(*self._full_cert_id)
        if full_balancer_ids:
            ctx.log.info("cannot delete cert since it's used in balancers: %s",
                         ', '.join(i for _, i in full_balancer_ids))
            return False
        return True

    def _process_event(self, ctx, event):
        """
        Handle an accepted event by running a processing pass; the event payload itself is not used,
        the cert is re-read from the cache by `_process`.
        """
        assert isinstance(event, events.CertUpdate)
        self._process(ctx)

    def _process_empty_queue(self, ctx):
        """Run a processing pass when there is no event to handle."""
        self._process(ctx)
