"""Memoize stage decorator, designed to be API-compatible with SandboxTask.memoize_stage."""

import logging
from contextlib import contextmanager

from sandbox.common import errors
from sandbox.sandboxsdk import channel


class MemoizeStage(object):
    """Run a stage function at most ``max_runs`` times, counting runs in the task context.

    An instance is meant to be applied to a function as a decorator. Applying
    it calls the function immediately (or skips it) and returns the stage
    object itself, so the decorated name ends up bound to this object, not to
    the function.

    Execution and skip counters live in ``task.ctx`` under keys derived from
    ``stage_name`` and are also pushed through
    ``channel.channel.sandbox.set_task_context_value`` (presumably so that the
    counters survive a task restart -- confirm against the Sandbox API).

    When the execution counter is incremented depends on the flags:

    * ``commit_on_entrance=True`` (default): before the function is called.
    * ``commit_on_entrance=False``: after the function returns normally.
    * ``commit_on_entrance=False`` and ``commit_on_wait=True``: additionally
      when the function raises ``errors.Wait``.

    ``commit_on_wait`` has no effect while ``commit_on_entrance`` is true.

    Attributes:
        task: object providing ``ctx`` (dict-like) and ``id``.
        stage_name: name used in log messages and in the context keys.
        max_runs: the stage is skipped once ``runs`` reaches this value.
        executed: whether the most recent ``__call__`` ran the function;
            ``False`` until the stage has been applied to a function.
    """

    def __init__(self, task, stage_name, max_runs=1, commit_on_entrance=True, commit_on_wait=True):
        self.task = task
        self.stage_name = stage_name
        self.max_runs = max_runs
        self.commit_on_entrance = commit_on_entrance
        self.commit_on_wait = commit_on_wait

        # Defined here so that reading `executed` before the stage is applied
        # to a function yields False instead of raising AttributeError.
        self.executed = False

        self._exec_key = '_a_stage_{}_exec_count__'.format(stage_name)
        self._skip_key = '_a_stage_{}_skip_count__'.format(stage_name)
        self._logger = logging.getLogger(task.__class__.__name__)

    @property
    def runs(self):
        """Number of committed executions recorded in ``task.ctx`` (0 if none)."""
        return self.task.ctx.get(self._exec_key, 0)

    @property
    def passes(self):
        """Number of skips recorded in ``task.ctx`` (0 if none)."""
        return self.task.ctx.get(self._skip_key, 0)

    @classmethod
    def commit_on_success(cls, task, stage_name, max_runs=1):
        """Build a stage that is memoized only when the function finishes.

        This option allows the stage to be restarted after an exception as
        well as resumed after running subtasks. Subtasks should be handled
        separately, though; see `wait`.
        """
        return cls(task, stage_name, max_runs=max_runs, commit_on_entrance=False, commit_on_wait=False)

    @classmethod
    def wait(cls, task, stage_name, max_runs=1):
        """Build a stage that is also memoized when ``errors.Wait`` is raised.

        This option helps to run subtasks only once: the stage is memoized
        when the subtask is started, so when control returns to the current
        task this stage is already finished.
        """
        return cls(task, stage_name, max_runs=max_runs, commit_on_entrance=False, commit_on_wait=True)

    def __call__(self, func):
        """Run ``func()`` unless the run limit is reached; return this stage object.

        Sets ``self.executed`` to tell the caller whether ``func`` was invoked.
        Exceptions raised by ``func`` propagate to the caller.
        """
        skip = self.runs >= self.max_runs
        self.executed = not skip

        if skip:
            self._skip()
        else:
            self._execute(func)

        return self

    def _skip(self):
        """Log the skip and increment the skip counter."""
        self._logger.info("Skipping stage '%s'", self.stage_name)
        self._inc_key(self._skip_key)

    def _execute(self, func):
        """Call ``func`` inside the execution-counting context, logging entry and exit."""
        self._logger.info("Entering stage '%s'", self.stage_name)

        try:
            with self._count_executions():
                func()
        finally:
            # Logged on both normal return and exception.
            self._logger.info("Exiting stage '%s'", self.stage_name)

    def _inc_key(self, key):
        """Increment the counter ``key`` in ``task.ctx`` and push the new value via the channel."""
        value = self.task.ctx.get(key, 0) + 1
        self.task.ctx[key] = value
        channel.channel.sandbox.set_task_context_value(self.task.id, key, value)

    @contextmanager
    def _count_executions(self):
        """Combine the three commit policies around the stage body.

        The wait handler is the innermost manager, so it sees ``errors.Wait``
        first and re-raises it; the exit handler then sees the exception and
        does not commit a second time.
        """
        with self._commit_on_entrance(), self._commit_on_exit(), self._commit_on_wait():
            yield

    @contextmanager
    def _commit_on_entrance(self):
        """Increment the execution counter before the body if ``commit_on_entrance`` is set."""
        if self.commit_on_entrance:
            self._inc_key(self._exec_key)

        yield

    @contextmanager
    def _commit_on_exit(self):
        """Increment the execution counter after a successful body if not committed on entrance."""
        # No try/finally on purpose: an exception from the body is re-raised at
        # this yield by contextmanager, so the increment below is skipped.
        yield

        if not self.commit_on_entrance:
            self._inc_key(self._exec_key)

    @contextmanager
    def _commit_on_wait(self):
        """Increment the execution counter when the body raises ``errors.Wait``.

        Applies only when ``commit_on_wait`` is set and the counter was not
        already incremented on entrance. The exception is always re-raised.
        """
        try:
            yield
        except errors.Wait:
            if not self.commit_on_entrance and self.commit_on_wait:
                self._inc_key(self._exec_key)
            raise
