"""Simple distributed cron for periodic tasks that ensures a job is not executed more often than its period.

This guarantee is not strict at the moment, but it is strict enough for most applications.
"""

import collections
import datetime
import logging
import re
import time
import typing as tp
from collections import namedtuple
from contextlib import contextmanager

import gevent
from gevent.event import Event
from gevent.pool import Pool
from mongoengine import Document, StringField, DictField, DoesNotExist

from sepelib.core import config, constants
from sepelib.core.exceptions import Error
from sepelib.mongo.util import register_model
from walle.clients import juggler as juggler_client
from walle.errors import RecoverableError
from walle.locks import CronJobInterruptableLock, lost_mongo_lock_retry
from walle.statbox.loggers import cron_logger
from walle.stats import stats_manager as stats, DISTRIBUTED as DISTRIBUTED_CNTR, ABSOLUTE_MINIMUM
from walle.util import cloud_tools
from walle.util.mongo import MongoPartitionerService

log = logging.getLogger(__name__)

# NOTE(review): not referenced in this module as seen here; confirm external users before removing.
_LOCK_PATH = "/cron"
# Seconds to wait before the next attempt after an internal (non-job) error; also the daemon's idle poll interval.
_INTERNAL_ERROR_RETRY_PERIOD = constants.MINUTE_SECONDS
# A job is reported as CRIT once this many job timeouts have passed without a successful run.
_JOB_TIMEOUT_MULTIPLIER = 3

# Represents current job state: timestamps of the last successful and the last failed run
# (None if never recorded) and the identifier of the process that recorded the latest result.
JobState = namedtuple("JobState", "success_time failure_time instance")


class CronJobTryingToRunTooSoon(RecoverableError):
    """Raised when a job is about to run although its (re-read) scheduled run time is still in the future."""

    def __init__(self, cron_id, job_id):
        template = "Cron '{}': job: {} trying to run too soon"
        super().__init__(template, cron_id, job_id, cron_id=cron_id, job_id=job_id)


class CronJobKnownError(Exception):
    """Raised by a job function to report an expected failure.

    The cron records the run as failed and logs ``str(exc)`` as the failure message
    instead of treating the run as a crash.
    """


# Log key (job ID or e.g. "cron-jobs") -> time of the last verbose log for it. A key looked up for the first
# time is initialised to "now", so its first verbose log is allowed only after a full minute has passed.
LAST_VERBOSE_JOB_LOG = collections.defaultdict(time.time)


def _verbose_logging_ratelimit(job_id):
    """Return True at most once a minute per ``job_id``, remembering when it last did so."""
    current_time = time.time()
    elapsed = current_time - LAST_VERBOSE_JOB_LOG[job_id]
    if elapsed <= constants.MINUTE_SECONDS:
        return False
    LAST_VERBOSE_JOB_LOG[job_id] = current_time
    return True


class Cron:
    """A cron daemon that periodically runs registered jobs in cooperation with other instances.

    Jobs are registered with :meth:`add_job`. After :meth:`start`, a daemon greenlet runs every due job that the
    ``MongoPartitionerService`` assigns to this instance, and a monitor greenlet reports each job's status to
    Juggler. Every run is guarded by the job's own lock, and its outcome (success/failure time plus the instance
    that recorded it) is stored in the ``_Cron`` document, so all instances compute the next run from shared state.
    """

    # Interval (seconds) between two rounds of Juggler status reporting in the monitor greenlet.
    JOB_MONITORING_INTERVAL = constants.MINUTE_SECONDS

    def __init__(self, cron_id, retry_delay=0):
        """
        :param cron_id: cron identifier (letters, digits and hyphens); also the ID of its ``_Cron`` state document.
        :param retry_delay: delay for retrying of failed task on host where it has failed to give a chance to other
        workers to successfully complete it.
        """

        # Checked in tests
        self._loops = 0

        self.__id = _validate_id("cron", cron_id)
        self.__instance = cloud_tools.get_process_identifier()
        self.__retry_delay = retry_delay

        self.__jobs = {}
        self.__stop = False
        # Set to wake the daemon loop early: a job was added or finished, or the daemon is stopping.
        self.__notify = Event()
        self.__daemon_greenlet = None
        self.__start_time = None
        self.__monitor_greenlet = None
        self.__pool = Pool()
        self.__stop_handlers: tp.Dict[str, tp.Callable[[], None]] = dict()
        self.__statbox_logger = cron_logger(id=cron_id)

    def add_job(
        self,
        job_type,
        func,
        period,
        retry_period=None,
        expected_work_time=None,
        stop_handler: tp.Optional[tp.Callable[[], None]] = None,
    ):
        """Adds a job that will be executed with the specified period.

        Job is considered failed if it raised an exception or returned False.

        :param job_type: object whose ``value`` is used as the job ID (presumably an Enum member).
        :param func: callable invoked without arguments on every run.
        :param period: seconds between a successful run and the next one.
        :param retry_period: specifies retry period for failed jobs; defaults to ``period``.
        :param expected_work_time: expected time of job execution. It is passed to the job lock, and the monitor
            uses it as the timeout base instead of ``period``.
        :param stop_handler: optional, handler what should be called on cron shutdown
        :raises Error: if the job ID is invalid or a job with this ID is already registered.
        """
        job_id = job_type.value
        _validate_id("job", job_id)

        if job_id in self.__jobs:
            raise Error("Job '{}' already exists.", job_id)

        if retry_period is None:
            retry_period = period
        # The lock is held around the whole run; without an explicit estimate, use the retry period capped at 1 hour.
        locking_time = expected_work_time if expected_work_time else min(60 * 60, retry_period)
        lock = CronJobInterruptableLock(job_id, locking_time)
        self.__jobs[job_id] = _Job(job_id, func, lock, period, retry_period, expected_work_time)
        if stop_handler:
            self.__stop_handlers[job_id] = stop_handler
        log.debug("Cron '%s': '%s' job added.", self.__id, job_id)

        # Let a running daemon pick up the new job immediately.
        self.__notify.set()

    def get_job_state(self, job_id):
        """Returns current state of the specified job."""

        return self.__get_job_state(_validate_id("job", job_id))

    def get_jobs(self):
        """Return a list of all registered jobs."""
        return list(self.__jobs.values())

    def start(self):
        """Starts the daemon.

        :raises Error: if the daemon is already started.
        """

        if self.__daemon_greenlet is not None:
            raise Error("The daemon is already started.")

        log.info("Starting '%s' cron daemon...", self.__id)

        self.__stop = False
        self.__notify.clear()
        # Set before spawning: the monitor uses it as the reference time for jobs that have never succeeded.
        self.__start_time = time.time()
        self.__daemon_greenlet = gevent.spawn(self.__daemon)
        self.__monitor_greenlet = gevent.spawn(self.__monitor)

    def stop(self):
        """Stops the daemon, runs registered stop handlers and kills running jobs. No-op if not started."""

        if self.__daemon_greenlet is None:
            return

        running_jobs = ", ".join(job.id for job in self.__jobs.values() if job.running)
        log.info(
            "Stopping '%s' cron daemon%s...",
            self.__id,
            " (running jobs: {})".format(running_jobs) if running_jobs else "",
        )

        self.__stop = True
        self.__notify.set()
        self.__daemon_greenlet.kill()
        self.__monitor_greenlet.kill()

        # Snapshot: handlers may switch greenlets.
        for job_id, handler in list(self.__stop_handlers.items()):
            log.info("Executing stop handler for job %s...", job_id)
            try:
                handler()
            except Exception as e:
                log.exception("Failed to execute stop handler for job %s: %s", job_id, e)
            else:
                log.info("Executing stop handler for job %s is finished", job_id)

        self.__pool.kill()

        self.__daemon_greenlet = None
        self.__monitor_greenlet = None
        log.info("Cron is stopped")

    def __monitor(self):
        """Periodically report every job's last-success status to Juggler until the daemon is stopped."""
        while not self.__stop:
            # Snapshot: the Mongo/Juggler calls below may switch greenlets while add_job() mutates the dict.
            for job in list(self.__jobs.values()):
                # An error for one job (e.g. Mongo or Juggler being unavailable) must not kill the monitor
                # greenlet; otherwise monitoring would silently stop for the rest of the process lifetime.
                try:
                    self.__report_job_status(job)
                except Exception as e:
                    log.error("Cron '%s': failed to report '%s' job status: %s", self.__id, job.id, e)
            gevent.sleep(self.JOB_MONITORING_INTERVAL)

    def __report_job_status(self, job):
        """Send a CRIT event if ``job`` is overdue for a success, an OK event if its latest result is a success."""
        state = self.__get_job_state(job.id)
        current_time = time.time()
        job_timeout = job.expected_work_time if job.expected_work_time else job.period
        # Before the first recorded success, measure the overdue interval from the daemon start.
        if (state.success_time or self.__start_time) + _JOB_TIMEOUT_MULTIPLIER * job_timeout <= current_time:
            last_success_run_time = _format_time(state.success_time) if state.success_time else "Does not exist"
            last_failure_run_time = _format_time(state.failure_time) if state.failure_time else "Does not exist"
            message = (
                "Cron job {} last success run timeout\n"
                "Last success run: {}\n"
                "Last failure run: {}\n"
                "Job interval: {} seconds\n"
                "Instance: {}"
            ).format(job.id, last_success_run_time, last_failure_run_time, job.period, state.instance)
            juggler_client.send_event(
                juggler_service_name="wall-e.cron.{}.last_success_run".format(job.id),
                status=juggler_client.JugglerCheckStatus.CRIT,
                message=message,
                host_name="wall-e.srv.{}".format(config.get_value("environment.name")),
                tags=["wall-e.cron", state.instance],
            )
        elif state.success_time is not None and (
            state.failure_time is None or state.success_time > state.failure_time
        ):
            juggler_client.send_event(
                juggler_service_name="wall-e.cron.{}.last_success_run".format(job.id),
                status=juggler_client.JugglerCheckStatus.OK,
                message="Cron job {} runs successfully".format(job.id),
                host_name="wall-e.srv.{}".format(config.get_value("environment.name")),
                tags=["wall-e.cron", state.instance],
            )

    def __daemon(self):
        """Main loop: spawn due jobs, then sleep until the next run time or until ``__notify`` is set."""
        with self.__partitioner() as partitioner:
            while not self.__stop:
                self._loops += 1
                try:
                    awake_time = self.__handle_jobs(partitioner)
                except Exception as e:
                    log.exception("Cron '%s' got an error during processing its jobs: %s", self.__id, e)
                    timeout = _INTERNAL_ERROR_RETRY_PERIOD
                else:
                    if awake_time is None:
                        # Nothing scheduled for the future; finishing/added jobs wake us up via __notify.
                        timeout = _INTERNAL_ERROR_RETRY_PERIOD
                    else:
                        # The earliest run time may already have passed while other jobs were being handled;
                        # never pass a negative timeout to Event.wait().
                        timeout = max(0, awake_time - time.time())
                if self.__notify.wait(timeout):
                    self.__notify.clear()

    def __handle_jobs(self, partitioner):
        """Spawn every due, idle job owned by this instance.

        :return: the earliest future run time among this instance's idle jobs, or None if there is none.
        """
        awake_time = None

        # NOTE(rocco66): cron has his own job.lock for acquiring
        # Iterate over a snapshot: get_shard() may switch greenlets while add_job() mutates the dict.
        current_instance_jobs = [
            job
            for job in list(self.__jobs.values())
            if partitioner.get_shard(job.id, log_state=_verbose_logging_ratelimit(job.id))
        ]
        if _verbose_logging_ratelimit("cron-jobs"):
            log.debug(f"Current instance cron jobs: {[(j.id, j.running) for j in current_instance_jobs]}")
        for job in current_instance_jobs:
            if job.running:
                continue

            if job.run_time is None:
                job.calculate_run_time(self.__get_job_state(job.id))
                log.debug(
                    "Cron '%s': next run for '%s' job is scheduled at %s.",
                    self.__id,
                    job.id,
                    _format_time(job.run_time),
                )

            if job.run_time <= time.time():
                # NOTE(rocco66): cron jobs use their own locks, not shards locks
                job.running = True
                self.__pool.spawn(self.__run_job, job)
            else:
                awake_time = _min(awake_time, job.run_time)

        return awake_time

    def __run_job(self, job):
        """Run ``job`` once under its lock, persist the outcome and schedule the next run.

        Executed in the daemon's greenlet pool. ``job.running`` is always reset and the daemon woken up on exit.
        """
        # None means the job function has not been called; True/False is the outcome of the run.
        result = None

        try:
            with job.lock:
                # Re-read the shared state under the lock: another instance may have run the job meanwhile.
                job.calculate_run_time(self.__get_job_state(job.id))
                time_since_run_time = time.time() - job.run_time

                if time_since_run_time < 0:
                    raise CronJobTryingToRunTooSoon(self.__id, job.id)

                delay = round(time_since_run_time)
                if delay > 1:
                    (log.error if delay > 5 else log.warning)(
                        "Cron '%s': '%s' job has been delayed for %.1f seconds.", self.__id, job.id, time_since_run_time
                    )

                log.debug("Cron '%s': run '%s' job...", self.__id, job.id)
                self.__statbox_logger.log(job_id=job.id, event="started")
                start_time = time.time()
                log.info(f"Cron-job {job.id} ({self.__id}) started.")

                # Use logging very carefully after job execution:
                # it may trigger greenlet switching and cause LockIsLostError

                failed_message = ""
                try:
                    with collect_stats(job):
                        try:
                            result = job.func() is not False
                        except CronJobKnownError as exc:
                            result = False
                            failed_message = str(exc)
                # Catches BaseException, so a greenlet kill (e.g. Pool.kill() in stop()) is also recorded as a
                # failed run and is not re-raised.
                except BaseException as e:
                    result = False
                    result_time = time.time()
                    execution_time = result_time - start_time
                    job.run_time = result_time + job.retry_period + self.__retry_delay

                    log.exception("Cron '%s': job '%s' has crashed:", self.__id, job.id)
                    log.info(f"Cron-job {job.id} ({self.__id}) crashed. Execution time: {execution_time} sec")
                    self.__statbox_logger.log(
                        job_id=job.id, event="crashed", error=str(e), execution_time=execution_time
                    )
                else:
                    result_time = time.time()
                    execution_time = result_time - start_time

                    if result:
                        job.run_time = result_time + job.period
                        log.info(f"Cron-job {job.id} ({self.__id}) finished. Execution time: {execution_time} sec")
                        self.__statbox_logger.log(job_id=job.id, event="finished", execution_time=execution_time)
                    else:
                        job.run_time = result_time + job.retry_period + self.__retry_delay
                        log.info(
                            f"Cron-job {job.id} ({self.__id}) failed. "
                            f"Execution time: {execution_time} sec. "
                            f"{failed_message}"
                        )
                        self.__statbox_logger.log(job_id=job.id, event="failed", execution_time=execution_time)
                finally:
                    # result_time is assigned first thing in both branches above.
                    lost_mongo_lock_retry(self.__set_job_state)(job, result, result_time)

        except CronJobTryingToRunTooSoon as e:
            log.info("%s", e)
        except Exception as e:
            if result is None:
                # The job did not run (e.g. lock or state read failed): retry after the internal error period.
                log.error("Cron '%s': failed to run '%s' job: %s", self.__id, job.id, e)
                job.run_time = time.time() + _INTERNAL_ERROR_RETRY_PERIOD
            else:
                log.error("Cron '%s': error during handling '%s' job: %s", self.__id, job.id, e)
        finally:
            job.running = False
            self.__notify.set()

            log.debug(
                "Cron '%s': next run for '%s' job is scheduled at %s.", self.__id, job.id, _format_time(job.run_time)
            )

    def __get_job_state(self, job_id) -> JobState:
        """Read the persisted state of ``job_id``; every field is None if nothing has been recorded yet.

        :raises Error: if the state can't be read.
        """
        state = {}

        try:
            # Load only this job's entry of the ``jobs`` dict.
            state = _Cron.objects(id=self.__id).only("jobs." + job_id).get().jobs.get(job_id, state)
        except DoesNotExist:
            pass
        except Exception as e:
            raise Error("Failed to get current job state: {}", e) from e

        return JobState(state.get("success_time"), state.get("failure_time"), state.get("instance"))

    def __set_job_state(self, job, success, state_time):
        """Record ``state_time`` as the job's last success (or failure) time and this instance as its author.

        Best effort: errors are logged, not raised.
        """
        try:
            field = "success_time" if success else "failure_time"
            update_kwargs = {
                "set__jobs__{job_id}__{field}".format(job_id=job.id, field=field): state_time,
                "set__jobs__{job_id}__instance".format(job_id=job.id): self.__instance,
            }

            _Cron.objects(id=self.__id).update(upsert=True, multi=False, **update_kwargs)
        except Exception as e:
            log.error("Cron %s: failed to save '%s' job state: %s", self.__id, job.id, e)

    @contextmanager
    def __partitioner(self):
        """Start the cron partitioner for the duration of the ``with`` block and stop it afterwards."""
        partitioner = MongoPartitionerService("cron")
        try:
            partitioner.start()
            yield partitioner
        finally:
            partitioner.stop()


@register_model
class _Cron(Document):
    """Persisted state of all jobs of one cron.

    The document ID is the cron ID. ``jobs`` maps a job ID to a dict that ``Cron`` fills with ``success_time`` /
    ``failure_time`` (``time.time()`` timestamps) and ``instance`` (identifier of the process that wrote the
    latest result).
    """

    id = StringField(primary_key=True, required=True, help_text="Cron ID")
    jobs = DictField(help_text="Current job state")

    # Stored in the "cron" collection.
    meta = {"collection": "cron"}


class _Job:
    """A registered cron job: its callable, lock, periods, and in-memory scheduling state."""

    def __init__(self, job_id, func, lock, period, retry_period, expected_work_time):
        self.id = job_id
        self.func = func
        self.lock = lock
        self.period = period
        self.retry_period = retry_period
        self.expected_work_time = expected_work_time
        # Mutable scheduling state, maintained by the owning Cron.
        self.running = False
        self.run_time = None

    def calculate_run_time(self, state):
        """Set ``run_time`` from the persisted ``state``: retry after a failure, repeat after a success."""
        succeeded_at = state.success_time
        failed_at = state.failure_time

        if failed_at is not None and (succeeded_at is None or failed_at >= succeeded_at):
            # The latest recorded outcome is a failure.
            self.run_time = failed_at + self.retry_period
        elif succeeded_at is not None:
            self.run_time = succeeded_at + self.period
        else:
            # Never ran: due right now.
            self.run_time = time.time()


@contextmanager
def collect_stats(job):
    """Update ``stats`` metrics around one run of ``job``.

    Only an exception propagating out of the ``with`` body counts as a failure, and it is re-raised. A body that
    completes normally is recorded as a success, even if the job reported failure some other way.
    """
    started_at = time.time()
    elapsed = 0
    stats.increment_counter(("cron", "jobs_running"), aggregation=DISTRIBUTED_CNTR)

    try:
        yield
    except BaseException:
        elapsed = time.time() - started_at
        for failure_key in (("cron", "jobs_failure_time"), ("cron", "jobs_failure_time", job.id)):
            stats.add_sample(failure_key, elapsed)
        stats.increment_counter(("cron", "jobs_failures", job.id), aggregation=DISTRIBUTED_CNTR)
        raise
    else:
        finished_at = time.time()
        elapsed = finished_at - started_at
        for age_key in (("cron", "jobs_last_success_run"), ("cron", "jobs_last_success_run", job.id)):
            stats.set_age_timestamp(age_key, finished_at, ABSOLUTE_MINIMUM)
        for success_key in (("cron", "jobs_success_time"), ("cron", "jobs_success_time", job.id)):
            stats.add_sample(success_key, elapsed)
        # A successful run resets the consecutive failure counter.
        stats.set_counter_value(("cron", "jobs_failures", job.id), 0, aggregation=DISTRIBUTED_CNTR)
    finally:
        stats.increment_counter(("cron", "jobs_running"), -1, aggregation=DISTRIBUTED_CNTR)
        stats.add_sample(("cron", "jobs_execution_time"), elapsed)


def _validate_id(name, id):
    """Check that ``id`` consists only of ASCII letters, digits and hyphens, and return it unchanged.

    IDs are interpolated into Mongo field paths (e.g. ``"jobs." + job_id``), so anything else is rejected.
    ``re.fullmatch`` is used because ``$`` would also accept a trailing newline. ``re.ASCII`` is used because,
    with IGNORECASE on a Unicode pattern, ``[a-z]`` also matches a few non-ASCII letters (e.g. the Kelvin sign).

    :param name: what is being validated ("cron", "job"); used in the error message.
    :raises Error: if ``id`` is invalid.
    """
    if not re.fullmatch(r"[0-9a-z-]+", id, re.IGNORECASE | re.ASCII):
        raise Error("Invalid {} ID: {}.", name, id)

    return id


def _min(a, b):
    """Return the smaller of ``a`` and ``b``, where None means "no value" and loses to any value."""
    if a is None or b is None:
        return b if a is None else a
    return min(a, b)


def _format_time(timestamp):
    """Format a UNIX ``timestamp`` as local time, ``YYYY.MM.DD HH:MM:SS``."""
    local_moment = datetime.datetime.fromtimestamp(timestamp)
    return f"{local_moment:%Y.%m.%d %H:%M:%S}"
