# Unique identity markers, always compared with ``is``:
#   MISSING             - "no return_value was configured" (None stays a
#                         legitimate configured return value).
#   SHARED_MOCK_DEFAULT - returned by a side-effect callable to request the
#                         mock's default result instead of its own value.
MISSING, SHARED_MOCK_DEFAULT = object(), object()


class SharedMock:
    """Small callable test double that records calls and ``with`` usage.

    Every call stores its ``(args, kwargs)`` pair.  The result of a call is
    decided in this order:

    1. ``side_effect`` when it is not ``None`` (see ``_exec_side_effect``),
    2. ``return_value`` when it is not the ``MISSING`` sentinel,
    3. the mock itself.

    The instance also works as a context manager: ``__enter__`` returns the
    mock and ``__exit__`` stores the exception triple it was given.

    Keyword arguments:
        return_value: value returned by calls; defaults to ``MISSING``.
        side_effect: ``None``, a list/tuple of per-call results, an exception
            instance or class, or a callable.
        Any other keyword argument is ignored.
    """

    def __init__(self, **kwargs):
        self._call_args = []
        self._enter_called = False
        self._exit_call_args = None
        self._cur_seq_pos = 0
        self.return_value = kwargs.get("return_value", MISSING)
        self.side_effect = kwargs.get("side_effect")

    def __call__(self, *args, **kwargs):
        """Record the call and return (or raise) the configured result."""
        self._call_args.append((args, kwargs))

        if self.side_effect is not None:
            return self._exec_side_effect(self.side_effect, args, kwargs)
        return self._default_result()

    def __enter__(self):
        self._enter_called = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None implicitly, so an exception raised inside the ``with``
        # body is not suppressed.
        self._exit_call_args = (exc_type, exc_val, exc_tb)

    @property
    def side_effect(self):
        """Configured side effect (``None`` when unset)."""
        return self._side_effect

    @side_effect.setter
    def side_effect(self, value):
        # A replaced side effect must start from its first element rather
        # than from wherever the previous sequence had advanced to.
        self._side_effect = value
        self._cur_seq_pos = 0

    @property
    def call_args_list(self):
        """Shallow copy of all recorded ``(args, kwargs)`` pairs, oldest first."""
        return self._call_args.copy()

    @property
    def call_args(self):
        """The most recent ``(args, kwargs)`` pair, or ``None`` if never called."""
        return self._call_args[-1] if self._call_args else None

    @property
    def call_count(self):
        """Number of recorded calls."""
        return len(self._call_args)

    @property
    def called(self):
        """``True`` once the mock has been called at least once."""
        return self.call_count > 0

    # The assertion helpers raise AssertionError explicitly instead of using
    # ``assert`` statements, which are removed when Python runs with ``-O``
    # and would turn every check into a no-op.

    def assert_called(self):
        """Raise ``AssertionError`` unless the mock was called."""
        if not self.called:
            raise AssertionError("Expected 'shared_mock' to have been called.")

    def assert_not_called(self):
        """Raise ``AssertionError`` if the mock was called."""
        if self.called:
            raise AssertionError(
                "Expected 'shared_mock' to not have been called. "
                f"Called {self.call_count} times."
            )

    def assert_called_with(self, *args, **kwargs):
        """Raise ``AssertionError`` unless the last call used these arguments."""
        if not self.called:
            raise AssertionError("Not called")

        if (args, kwargs) != self._call_args[-1]:
            raise AssertionError("Last called with other arguments")

    def assert_any_call(self, *args, **kwargs):
        """Raise ``AssertionError`` unless some call used these arguments."""
        if not self.called:
            raise AssertionError("Not called")

        if (args, kwargs) not in self._call_args:
            raise AssertionError("Never called with these arguments")

    def assert_enter_called(self):
        """Raise ``AssertionError`` unless ``__enter__`` was invoked."""
        if not self._enter_called:
            raise AssertionError("Not called")

    def assert_exit_called(self):
        """Raise ``AssertionError`` unless ``__exit__`` was invoked."""
        if self._exit_call_args is None:
            raise AssertionError("Not called")

    def assert_exit_called_without_exception(self):
        """Raise ``AssertionError`` unless ``__exit__`` ran with no exception."""
        if self._exit_call_args is None:
            raise AssertionError("Not called")

        if self._exit_call_args[0] is not None:
            raise AssertionError("Called with exception")

    def assert_exit_called_with_exception(self):
        """Raise ``AssertionError`` unless ``__exit__`` ran with an exception."""
        if self._exit_call_args is None:
            raise AssertionError("Not called")

        if self._exit_call_args[0] is None:
            raise AssertionError("Called without exception")

    def _default_result(self):
        """Return ``return_value`` when configured, otherwise the mock itself."""
        if self.return_value is not MISSING:
            return self.return_value
        return self

    @staticmethod
    def _is_exception(value):
        """Tell whether *value* is an exception instance or exception class."""
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException)
        )

    def _exec_side_effect(self, side_effect, args=None, kwargs=None, from_seq=False):
        """Resolve *side_effect* into a returned value or a raised exception.

        Args:
            side_effect: the side effect (or, when *from_seq* is true, one
                item taken from a sequence side effect).
            args: positional arguments of the call, forwarded to a callable.
            kwargs: keyword arguments of the call, forwarded to a callable.
            from_seq: true when *side_effect* is an item of a list/tuple.

        Returns:
            The sequence item, the callable's result, or the default result
            when either of those is ``SHARED_MOCK_DEFAULT``.

        Raises:
            StopIteration: a list/tuple side effect has no items left.
            TypeError: the side effect is none of the supported kinds.
            BaseException: whatever exception the side effect specifies.
        """
        if from_seq:
            # Sequence items: exceptions are raised, the default marker
            # defers to the configured result, and anything else (callables
            # included) is handed back as-is without being called.
            if self._is_exception(side_effect):
                raise side_effect
            if side_effect is SHARED_MOCK_DEFAULT:
                return self._default_result()
            return side_effect

        if isinstance(side_effect, (list, tuple)):
            try:
                item = side_effect[self._cur_seq_pos]
            except IndexError:
                # The IndexError is an implementation detail; do not chain it.
                raise StopIteration from None
            # Advance only after a successful lookup.
            self._cur_seq_pos += 1
            return self._exec_side_effect(item, from_seq=True)

        # Checked before ``callable`` because exception classes are callable.
        if self._is_exception(side_effect):
            raise side_effect

        if callable(side_effect):
            result = side_effect(*(args or ()), **(kwargs or {}))
            if result is SHARED_MOCK_DEFAULT:
                return self._default_result()
            return result

        raise TypeError(f"Bad side effect of type {type(side_effect)}")

    def _set_return_value(self, value):
        self.return_value = value

    def _set_side_effect(self, value):
        self.side_effect = value
