import contextlib
import errno
import logging
import os
import shutil
import sys
import tempfile

import six


# TODO (BYIN-14594): remove this backport once Python 2 support is dropped and use contextlib.ExitStack directly.
class _ExitStackBackport(object):
    """
    Backport of python3 contextlib.ExitStack.

    Exit callbacks run in LIFO order when the stack is closed or exited.
    As in the stdlib version:

    * a callback returning a true value suppresses the exception being
      propagated, and later callbacks then receive ``(None, None, None)``;
    * an exception raised by a callback replaces the one being propagated
      and is re-raised after all callbacks have run.
    """

    def __init__(self):
        super(_ExitStackBackport, self).__init__()
        # Registered exit callbacks, each called as f(exc_type, exc_val, traceback).
        self._exit_stack = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, traceback):
        # A true return value tells the interpreter to swallow the exception.
        return self._close(exc_type, exc_val, traceback)

    def push(self, exit_):
        """
        Register a callable with the ``__exit__`` signature.

        Returns ``exit_`` unchanged (as the stdlib does) so it can be used as a decorator.
        """
        self._exit_stack.append(exit_)
        return exit_

    def enter_context(self, cm):
        """Enter ``cm``, schedule its ``__exit__`` and return the ``__enter__`` result."""
        value = cm.__enter__()
        self.push(cm.__exit__)
        return value

    def callback(self, callback, *args, **kwargs):
        """Schedule ``callback(*args, **kwargs)`` to run on exit; returns ``callback``."""
        def _exit_wrapper(exc_type, exc_val, traceback):
            # Deliberately discard the result: a plain callback must never
            # suppress an exception just because it returned something truthy.
            callback(*args, **kwargs)

        self.push(_exit_wrapper)
        return callback

    def pop_all(self):
        """Move every registered callback to a new stack and return it."""
        new_stack = self.__class__()
        new_stack._exit_stack = self._exit_stack
        self._exit_stack = []
        return new_stack

    def close(self):
        """Run all registered callbacks immediately, as if exiting without an exception."""
        self._close(None, None, None)

    def _close(self, exc_type, exc_val, traceback):
        """
        Unwind all callbacks and return True if a received exception was suppressed.

        Re-raises the last exception raised by a callback unless a later
        callback suppressed it.
        """
        received_exc = exc_type is not None
        suppressed = False
        pending_raise = False
        # Pop one at a time (like the stdlib) so the stack ends up empty even
        # if a callback raises, and callbacks pushed during unwinding still run.
        while self._exit_stack:
            exit_function = self._exit_stack.pop()
            try:
                if exit_function(exc_type, exc_val, traceback):
                    suppressed = True
                    pending_raise = False
                    exc_type, exc_val, traceback = None, None, None
            except:  # noqa: E722 -- mirror stdlib: any exception (even non-Exception) replaces the current one
                exc_type, exc_val, traceback = sys.exc_info()
                pending_raise = True
        if pending_raise:
            six.reraise(exc_type, exc_val, traceback)
        return received_exc and suppressed


class TempEnvironment(object):
    """
    Temporarily modify ``os.environ`` and undo the changes later.

    The first time a variable is touched, its current value is remembered
    (``None`` meaning "not set"). ``restore()``, which also runs on context
    exit, puts every remembered variable back and forgets the saved values.
    """
    logger = logging.getLogger('TempEnvironment')

    def __init__(self):
        # variable name -> value before our first change (None if it was unset)
        self._env = {}

    @staticmethod
    def _write(key, value):
        """Set ``key`` to ``value`` in ``os.environ``, or unset it when ``value`` is None."""
        if value is not None:
            os.environ[key] = value
        else:
            os.environ.pop(key, None)

    def set_var(self, key, value):
        """Set (or, with ``value=None``, unset) ``key``, remembering its original value once."""
        if key not in self._env:
            initial = os.environ.get(key)
            self.logger.debug('Remember initial value of %s: %r', key, initial)
            self._env[key] = initial
        self.logger.debug('Set %s to %r', key, value)
        self._write(key, value)

    def del_var(self, key):
        """Unset ``key``; its original value is restored later."""
        self.set_var(key, None)

    def restore(self):
        """Put back every remembered variable and forget the saved values."""
        for key, initial in self._env.items():
            current = os.environ.get(key)
            self.logger.debug('Restore initial value of %s: %r (was %r)', key, initial, current)
            self._write(key, initial)
        self._env.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()


class TmpdirContext(object):
    """
    Redirect temporary files into a fresh directory for the duration of the context.

    On enter, a new directory is created inside ``base_tmpdir``, which is
    itself created if missing. Each variable in ``ENVVARS`` that is already
    set in ``os.environ`` is pointed at the new directory. Python's
    ``tempfile`` module is pinned to it as well, so ``tempfile.gettempdir()``
    returns it even when none of those variables were set.

    On exit, the environment and ``tempfile.tempdir`` are restored and the
    directory is removed (best-effort). ``__enter__`` returns the path of the
    created directory.
    """
    ENVVARS = ('TMP', 'TMPDIR', 'TEMP', 'TEMPDIR')
    logger = logging.getLogger('TmpdirContext')

    def __init__(self, base_tmpdir):
        self.base_tmpdir = base_tmpdir

        self._temp_env = TempEnvironment()
        self._old_python_tempdir = None

        self._tmpdir = None

    def _create_tmpdir(self):
        """Ensure ``base_tmpdir`` exists and return a new absolute subdirectory of it."""
        try:
            os.makedirs(self.base_tmpdir)
        except OSError:
            # Already present (possibly created concurrently) is fine; anything else is not.
            if not os.path.isdir(self.base_tmpdir):
                raise
        # Absolute, so the env vars stay valid for processes with a different cwd.
        return os.path.abspath(tempfile.mkdtemp(dir=self.base_tmpdir))

    def _cleanup(self):
        """Restore the environment and tempfile cache, then remove the directory."""
        self._temp_env.restore()
        tempfile.tempdir = self._old_python_tempdir
        if self._tmpdir is not None:
            shutil.rmtree(self._tmpdir, ignore_errors=True)
            self._tmpdir = None

    def __enter__(self):
        self._old_python_tempdir = tempfile.tempdir
        self.logger.debug('tempfile.tempdir was %s', self._old_python_tempdir)

        self._tmpdir = self._create_tmpdir()

        entered = False
        try:
            for key in self.ENVVARS:
                # Only redirect variables that are already set; unset ones stay unset.
                if key in os.environ:
                    self._temp_env.set_var(key, self._tmpdir)

            # Pin tempfile's cache explicitly. Recomputing it from the environment
            # would fall back to a system default (e.g. /tmp) when none of
            # ENVVARS is set.
            tempfile.tempdir = self._tmpdir
            self.logger.debug('tempfile.gettempdir() is %s', tempfile.gettempdir())
            entered = True
        finally:
            if not entered:
                # __exit__ is not called when __enter__ fails: undo partial setup here.
                self._cleanup()
        return self._tmpdir

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._cleanup()


# Use the stdlib implementation on Python 3 and the local backport on Python 2.
ExitStack = contextlib.ExitStack if six.PY3 else _ExitStackBackport


@contextlib.contextmanager
def TempfileContext(**tempfile_arguments):
    """
    Provide the path of a named temporary file that is removed on context exit.

    The file is created with ``tempfile.NamedTemporaryFile(delete=False)`` and
    closed immediately, so the body works with it by name only.

    :param tempfile_arguments: forwarded to ``tempfile.NamedTemporaryFile``
        (e.g. ``suffix``, ``prefix``, ``dir``); ``delete`` is managed here.
    :yields: path of the temporary file.
    """
    named_file = tempfile.NamedTemporaryFile(delete=False, **tempfile_arguments)
    try:
        named_file.close()
        yield named_file.name
    finally:
        try:
            os.remove(named_file.name)
        except OSError as error:
            # The body may legitimately have moved or deleted the file (e.g.
            # "write to temp, then rename"). A missing file must neither fail
            # the exit nor mask an exception raised by the body.
            if error.errno != errno.ENOENT:
                raise


@contextlib.contextmanager
def ChdirContext(new_cwd):
    """
    Temporarily switch the process working directory to ``new_cwd``.

    Yields the directory that was current before the switch; that directory
    is made current again when the context exits, even if the body raised.
    """
    restore_to = os.getcwd()
    try:
        os.chdir(new_cwd)
        yield restore_to
    finally:
        # Always return to where we started.
        os.chdir(restore_to)
