# coding: utf-8
import inject
import monotonic
import time
from sepelib.core import config

from infra.awacs.proto import model_pb2
from awacs.lib import ctlmanager, context
from awacs.lib.certificator import ICertificatorClient, CertificatorClient
from awacs.lib.order_processor.model import is_spec_complete, needs_removal
from awacs.lib.ya_vault import IYaVaultClient, YaVaultClient
from awacs.model import events, db, cache, util
from awacs.model.certs import cert
from awacs.model.dao import IDao, Dao
from awacs.model.zk import IZkStorage, ZkStorage


class CertRenewalCtl(ctlmanager.ContextedCtl):
    """
    Controller for a single cert renewal object.

    Each processing pass (on start, on matching CertRenewalUpdate events and from
    ``_process_empty_queue``):

    1. copies the renewal's spec into the target certificate once renewal is due
       (see :meth:`_should_renew`);
    2. removes the backed-up (pre-renewal) certificate's Ya Vault secret once the
       backup certificate has expired;
    3. removes the renewal object itself from ZK storage once the backup has expired
       and the current certificate is either gone or not expired.
    """
    _cache = inject.attr(cache.IAwacsCache)  # type: cache.AwacsCache
    _db = inject.attr(db.IMongoStorage)  # type: db.MongoStorage
    _zk = inject.attr(IZkStorage)  # type: ZkStorage
    _dao = inject.attr(IDao)  # type: Dao
    _certificator_client = inject.attr(ICertificatorClient)  # type: CertificatorClient
    _yav_client = inject.attr(IYaVaultClient)  # type: YaVaultClient

    # NOTE(review): presumably consumed by the ContextedCtl base class — confirm
    EVENTS_QUEUE_GET_TIMEOUT = 20
    # minimum number of seconds between two consecutive renewal checks in _should_renew
    SELF_ACTIONS_DELAY_INTERVAL = 30

    # sleep more after unexpected exceptions such as 429 TOO MANY REQUESTS (see AWACS-652 for details)
    SLEEP_AFTER_EXCEPTION_TIMEOUT = 60
    SLEEP_AFTER_EXCEPTION_TIMEOUT_JITTER = 30

    def __init__(self, namespace_id, cert_renewal_id):
        """
        :param namespace_id: namespace the cert renewal belongs to
        :param cert_renewal_id: id of the cert renewal object

        Loads the renewal from the cache and the backup (target) certificate revision
        ``meta.target_rev`` from the DB; the ``must_get_*`` helpers are presumably
        raising when the object is missing — confirm.
        """
        name = 'cert-renewal-ctl("{}:{}")'.format(namespace_id, cert_renewal_id)
        super(CertRenewalCtl, self).__init__(name)
        self._cert_renewal_id = cert_renewal_id
        self._namespace_id = namespace_id
        # initialised to "now" so that the very first check is not throttled
        self._renewal_check_deadline = monotonic.monotonic()
        # when set, renewal is forced near expiration even if it is paused
        self._allow_time_pressured_renewal = bool(
            config.get_value('run.ignore_cert_renewal_pause_on_expiration_deadline'))
        self._full_cert_renewal_id = (self._namespace_id, self._cert_renewal_id)
        self._pb = self._cache.must_get_cert_renewal(namespace_id=namespace_id,
                                                     cert_renewal_id=cert_renewal_id)
        self._backup_pb = self._db.must_get_cert_rev(self._pb.meta.target_rev)
        self._exact_expiration_time = self._backup_pb.spec.fields.validity.not_after.ToSeconds()
        # after this moment (one day before the backup cert expires) renewal is time-pressured
        self._not_after = self._exact_expiration_time - util.SECONDS_IN_DAY
        # set once the backup secret has been removed, so the removal is not repeated on every pass
        self._backup_deleted = False

    @property
    def _cert_renewal_delay_after_issuing(self):
        """Seconds to wait after the renewed cert's not_before (config value, default one day, clamped at 0)."""
        return max(config.get_value('run.cert_renewal_delay_after_issuing', default=util.SECONDS_IN_DAY), 0)

    @property
    def _not_before(self):
        """Earliest time (unix seconds) at which an ordinary (non-forced) renewal is allowed."""
        return (self._pb.spec.fields.validity.not_before.ToSeconds() +
                self._cert_renewal_delay_after_issuing)

    def _accept_event(self, event):
        """Accept only CertRenewalUpdate events for this controller's renewal object."""
        return (isinstance(event, events.CertRenewalUpdate) and
                event.pb.meta.namespace_id == self._namespace_id and
                event.pb.meta.id == self._cert_renewal_id)

    def _start(self, ctx):
        """
        Run an initial processing pass, then bind the callback to the cache.

        Exceptions from ``ctlmanager.UNEXPECTED_EXCEPTIONS`` raised by the initial pass
        are logged, not propagated, so the controller still subscribes to updates.
        """
        try:
            self._process(ctx)
        except ctlmanager.UNEXPECTED_EXCEPTIONS as e:
            ctx.log.exception('failed to process cert renewal on start: %s', e)
        self._cache.bind(self._callback)

    def _stop(self):
        """Unbind the callback from the cache."""
        self._cache.unbind(self._callback)

    @staticmethod
    def _check_algorithms(pb, renewal_pb, field_name, upgraded_algorithms=None):
        """
        Ensure the renewal uses the same algorithm (and parameters) as the original cert.

        :param pb: original certificate pb
        :param renewal_pb: cert renewal pb
        :param field_name: name of the ``spec.fields`` sub-message to compare
            (e.g. ``'public_key_info'`` or ``'signature'``)
        :param upgraded_algorithms: optional mapping ``original_algo -> allowed_new_algo``
            of permitted algorithm upgrades
        :raises RuntimeError: if algorithms or their parameters differ
        """
        original_field_pb = getattr(pb.spec.fields, field_name)
        renewal_field_pb = getattr(renewal_pb.spec.fields, field_name)

        original_algo = original_field_pb.algorithm_id
        new_algo = renewal_field_pb.algorithm_id
        algos_match = original_algo == new_algo

        if not algos_match and upgraded_algorithms:
            algos_match = upgraded_algorithms.get(original_algo) == new_algo
        if not algos_match or original_field_pb.parameters != renewal_field_pb.parameters:
            raise RuntimeError('{} algorithm mismatch. Original cert has "{}{}", renewal has "{}{}"'.format(
                field_name,
                original_algo,
                ' ({})'.format(original_field_pb.parameters) if original_field_pb.parameters else '',
                new_algo,
                ' ({})'.format(renewal_field_pb.parameters) if renewal_field_pb.parameters else '',
            ))

    def _should_renew(self, ctx):
        """
        Decide whether the renewal spec should be copied into the target certificate now.

        Checks are throttled to at most one per ``SELF_ACTIONS_DELAY_INTERVAL`` seconds.

        :type ctx: context.OpCtx
        :returns: the target certificate pb if renewal should happen now, otherwise None
        :raises RuntimeError: if the target certificate changed unexpectedly since the
            backup revision, or if the renewal's algorithms don't match the original's
        """
        current_time = monotonic.monotonic()

        # Don't check too often
        if current_time < self._renewal_check_deadline:
            return None
        self._renewal_check_deadline = current_time + self.SELF_ACTIONS_DELAY_INTERVAL

        cert_pb = self._cache.get_cert(namespace_id=self._namespace_id,
                                       cert_id=self._pb.meta.id)
        if not cert_pb:
            # renewal object itself will be removed shortly by manager
            ctx.log.info('target certificate does not exist')
            return None

        if needs_removal(cert_pb):
            ctx.log.info('target certificate is marked for deletion')
            return None

        if cert_pb.meta.version != self._pb.meta.target_rev:
            # a newer revision that already carries the renewal's validity means the spec was copied
            if cert_pb.spec.fields.validity.not_after == self._pb.spec.fields.validity.not_after:
                ctx.log.info('target certificate is already renewed')
                return None
            else:
                raise RuntimeError('Certificate version "{}" doesn\'t match expected version "{}"'.format(
                    cert_pb.meta.version, self._pb.meta.target_rev
                ))

        if cert_pb.spec != self._backup_pb.spec:
            raise RuntimeError('Current certificate spec doesn\'t match spec from backup revision "{}"'.format(
                self._pb.meta.target_rev
            ))

        self._check_algorithms(cert_pb, self._pb, 'public_key_info', upgraded_algorithms=None)
        # EC certs are allowed to move from an RSA signature to ECDSA on renewal
        if cert_pb.spec.fields.public_key_info.algorithm_id == u'ec':
            upgraded_signature_algos = {u'sha256_rsa': u'sha384_ecdsa'}
        else:
            upgraded_signature_algos = None
        self._check_algorithms(cert_pb, self._pb, 'signature', upgraded_algorithms=upgraded_signature_algos)

        paused = self._pb.meta.paused

        # if less than 24 hours are left before cert expiration, we really need to renew it
        real_time = time.time()
        if real_time >= self._not_after:
            seconds_until_expiration = self._exact_expiration_time - real_time
            if not paused.value or self._allow_time_pressured_renewal:
                ctx.log.info('target cert expires in less than 24h (%s seconds from now), forcing spec copy',
                             seconds_until_expiration)
                return cert_pb
            else:
                ctx.log.info('target cert expires in less than 24h (%s seconds from now), but renewal is paused',
                             seconds_until_expiration)
                return None

        # otherwise respect the paused status
        if paused.value:
            ctx.log.info('renewal is paused by %s with comment "%s"', paused.author, paused.comment)
            return None

        forced = cert_pb.meta.force_renewal
        if forced.value:
            ctx.log.info('renewal is forced by %s with comment "%s"', forced.author, forced.comment)
            return cert_pb

        # avoid problems for users with invalid clock settings
        if real_time < self._not_before:
            seconds_until_valid = self._not_before - real_time

            ctx.log.info('renewed cert has been valid for less than '
                         '"run.cert_renewal_delay_after_issuing" = %s seconds '
                         '(will be ready in %s seconds), waiting',
                         self._cert_renewal_delay_after_issuing, seconds_until_valid)
            return None

        ctx.log.info('ordinary renewal')
        return cert_pb

    def _copy_spec_to_target_cert(self, ctx):
        """
        Write the renewal's spec into the target certificate at version ``meta.target_rev``.

        If the renewal has ``target_discoverability`` set, it is applied to the target
        cert's meta as well. ``disable_force_renewal=True`` is passed to the DAO.
        """
        ctx.log.info('updating target cert spec')
        if self._pb.meta.HasField('target_discoverability'):
            meta_pb = model_pb2.CertificateMeta(discoverability=self._pb.meta.target_discoverability)
        else:
            meta_pb = None
        self._dao.update_cert(namespace_id=self._namespace_id,
                              cert_id=self._pb.meta.id,
                              version=self._pb.meta.target_rev,
                              comment='Renewed certificate',
                              login=util.NANNY_ROBOT_LOGIN,
                              updated_spec_pb=self._pb.spec,
                              updated_meta_pb=meta_pb,
                              disable_force_renewal=True)
        ctx.log.info('finished updating target cert spec')

    def _process(self, ctx):
        """
        Run one processing pass: renew if due, delete the backup if expired, remove self if done.

        :type ctx: context.OpCtx
        """
        self._pb = self._cache.must_get_cert_renewal(self._namespace_id, self._cert_renewal_id)

        assert is_spec_complete(self._pb), 'Guaranteed by CertRenewalCtlManager'

        ctx.log.debug('Checking if should renew certificate')
        if self._should_renew(ctx):
            self._copy_spec_to_target_cert(ctx)
        else:
            ctx.log.debug('Renewal is not needed')

        ctx.log.debug('Checking if should delete backup')
        if self._should_delete_backup():
            self._delete_backup(ctx)
        else:
            ctx.log.debug('Not deleting backup')

        ctx.log.debug('Checking if should delete self')
        if self._should_self_delete():
            self._delete_self(ctx)
        else:
            ctx.log.debug('Not removing self')

    def _should_delete_backup(self):
        """True if the backup cert has expired and its secret has not been removed by this controller yet."""
        return not self._backup_deleted and cert.is_expired(self._backup_pb)

    def _should_self_delete(self):
        """
        True when both hold:

        1. the backup cert is expired;
        2. the current cert is not expired (or has been deleted).
        """
        if cert.is_expired(self._backup_pb):
            cert_pb = self._zk.get_cert(namespace_id=self._namespace_id, cert_id=self._pb.meta.id)
            return not cert_pb or not cert.is_expired(cert_pb)
        return False

    def _delete_backup(self, ctx):
        """
        Remove the backup cert's Ya Vault secret, if it has one.

        On success (or when there is nothing to remove) the deletion is remembered so
        later passes don't issue the same removal again.

        :type ctx: context.OpCtx
        """
        ctx.log.info('started cert backup deletion')
        ctx.log.info('cleaning up old cert')
        if self._backup_pb.spec.storage.ya_vault_secret.secret_id:
            ctx.log.info('removing old cert from storage')
            self._yav_client.remove_secret(self._backup_pb.spec.storage.ya_vault_secret.secret_id)
            ctx.log.info('finished removing old cert from storage')
        else:
            ctx.log.info('not removing old cert from storage: no storage.ya_vault_secret.secret_id')
        self._backup_deleted = True

    def _delete_self(self, ctx):
        """Remove this cert renewal object from ZK storage."""
        ctx.log.info('removing cert renewal from zk')
        self._zk.remove_cert_renewal(namespace_id=self._namespace_id, cert_renewal_id=self._cert_renewal_id)

    def _process_event(self, ctx, event):
        """Handle an accepted CertRenewalUpdate by running a full processing pass."""
        assert isinstance(event, events.CertRenewalUpdate)
        self._process(ctx)

    def _process_empty_queue(self, ctx):
        """Run a processing pass when no events arrived (time-based conditions may have changed)."""
        self._process(ctx)
