import time

from sqlalchemy.exc import (
    InternalError,
    StatementError,
    OperationalError,
    InvalidRequestError,
    DBAPIError,
)


class ConnectionErrorFor503(RuntimeError):
    """Raised by ``ExecuteRequest`` when a DB request still fails after all retries.

    The original DB error, when there is one, is attached as ``__cause__``.
    NOTE(review): the name suggests callers translate this into an HTTP 503
    response — confirm against the web layer.
    """


# @contextmanager
# def blabla(session, attempts=10):
#     while attempts > 0:
#         attempts -= 1
#         try:
#             yield
#         except SA.exc.DBAPIError as exc:
#             if attempts > 0 and exc.connection_invalidated:
#                 session.rollback()
#             else:
#                 raise


class RetryData:
    """Placeholder for retry/backoff state; currently holds nothing.

    NOTE(review): presumably meant to become the sleep-delay generator referenced
    by the TODO in ``ExecuteRequest`` — confirm before relying on it.
    """

    def __init__(self):
        """Create an empty retry-state object (no fields yet)."""

    def get_next_sleep(self):
        """Return the next delay to sleep; not implemented, always ``None``."""
        return None


class ExecuteRequest(object):
    """
    Run a DB action, retrying it on selected SQLAlchemy errors.

    Meant to be used as a context manager. ``__enter__`` performs the call,
    retrying if needed, and binds the action's return value to the ``as``
    target. ``__exit__`` does nothing and does not suppress exceptions::

        with ExecuteRequest(session, fetch_rows, retries=2) as rows:
            ...

    What happens after a failed attempt:

    * ``OperationalError`` / ``InvalidRequestError`` / ``InternalError``:
      roll the session back (best effort) and retry.
    * Any other ``DBAPIError``: if its connection was invalidated, roll back
      (best effort) and retry. Otherwise re-raise it unchanged. Such errors
      (e.g. constraint violations) would presumably fail the same way again,
      and swallowing them would hide the real cause.
    * Any other ``StatementError``: retry without rolling back.
    * Anything else propagates immediately.

    When every attempt has failed, ``ConnectionErrorFor503`` is raised,
    chained to the last retryable error.

    TODO: pass custom object with exponential backoff sleep generator
    """

    def __init__(
        self,
        session,
        action: callable,
        retries: int = 0,
        retry_delay: float = 0.1,
    ):
        """
        :param session: object with a ``rollback()`` method (presumably a
            SQLAlchemy ``Session``). It is only used after a failed attempt.
        :param action: zero-argument callable that does the DB work. Its
            return value becomes the ``with`` target.
        :param retries: number of extra attempts after the first one
            (total attempts = ``1 + retries``).
        :param retry_delay: seconds to sleep between consecutive attempts.
            There is no sleep after the final attempt.
        :raises ValueError: if ``retries`` or ``retry_delay`` is negative.
        """
        if retries < 0:
            raise ValueError("wrong retry number: %s" % (retries,))
        if retry_delay < 0:
            raise ValueError("wrong retry delay: %s" % (retry_delay,))
        self.session = session
        self.action = action
        self.retries = retries
        self.retry_delay = retry_delay

    def __exit__(self, exc_type, exc_val, traceback):
        # Nothing to clean up. Returning None lets exceptions raised inside
        # the ``with`` body propagate.
        pass

    def __enter__(self):
        """Run the action with retries and return its result.

        :raises ConnectionErrorFor503: when every attempt failed with a
            retryable error.
        """
        last_exc = None
        for attempt in range(1 + self.retries):
            if attempt:
                # Wait only *between* attempts; the last failure is reported
                # without an extra pointless delay.
                time.sleep(self.retry_delay)  # TODO: use a backoff generator
            try:
                return self.do_action()
            # OperationalError and InternalError subclass DBAPIError, so this
            # handler must stay ahead of the generic DBAPIError one below.
            except (OperationalError, InvalidRequestError, InternalError) as exc:
                self._rollback_quietly()
                last_exc = exc
            except DBAPIError as exc:
                if not exc.connection_invalidated:
                    # Not a connectivity problem: surface the real error
                    # instead of silently retrying it.
                    raise
                self._rollback_quietly()
                last_exc = exc
            except StatementError as exc:
                last_exc = exc

        raise ConnectionErrorFor503("unable to satisfy DB request in %d retries" % (self.retries,)) from last_exc

    def _rollback_quietly(self):
        """Roll the session back, ignoring DB errors from a broken connection."""
        try:
            self.session.rollback()
        except DBAPIError:
            # connection may be already broken at this point
            pass

    def do_action(self):
        """Invoke the wrapped action once and return its result."""
        return self.action()


if __name__ == "__main__":
    # Smoke demo. ExecuteRequest calls ``action()``, so it must be a callable.
    # Passing None (as before) raised TypeError immediately. This action
    # succeeds on the first attempt, so the session is never touched and
    # None is acceptable for it.
    def query():
        return 42

    with ExecuteRequest(None, query) as result:
        print(type(result))
