# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import functools
import logging
from time import sleep

from pymongo.collection import Collection

# Logger shared by the retry machinery in this module.
log = logging.getLogger(__name__)

# Names of ``Collection`` methods that ``add_read_retry_to_pymongo`` wraps with retry logic.
COLLECTION_FUNCTIONS_TO_PATCH = set(['find', 'aggregate'])


def add_read_retry_to_pymongo(max_attempt=5, sleep_duration=1, function_names=None):
    """Monkeypatch pymongo's ``Collection`` so selected read methods retry on failure.

    Every patched method is replaced, at class level, by a wrapper that delegates
    to ``Collection._rasp_try_wrapper``. That wrapper calls the original method up
    to ``max_attempt`` times and sleeps ``sleep_duration`` seconds between
    attempts. The original methods are stashed on ``Collection`` as
    ``_rasp_wrapped_<name>`` the first time they are patched. Calling this
    function again therefore only replaces the retry settings; wrappers are not
    stacked.

    :param max_attempt: total number of attempts before the last exception is re-raised.
    :param sleep_duration: seconds to sleep between attempts; a falsy value disables sleeping.
    :param function_names: iterable of ``Collection`` method names to patch;
        ``None`` (the default) means ``COLLECTION_FUNCTIONS_TO_PATCH``.

    NOTE(review): pymongo's ``find`` usually returns a lazy cursor, so errors raised
    while iterating the result are presumably not retried by this patch — confirm.
    """
    Collection._rasp_try_wrapper = _rasp_try_wrapper

    if function_names is None:
        function_names = COLLECTION_FUNCTIONS_TO_PATCH

    def _wrap(name):
        # A dedicated helper gives each wrapper its own ``wrapped_name``
        # (avoids the late-binding closure pitfall of defining it in the loop).
        wrapped_name = '_rasp_wrapped_{}'.format(name)

        # Stash the genuine method only once so re-patching never wraps a wrapper.
        if not hasattr(Collection, wrapped_name):
            setattr(Collection, wrapped_name, getattr(Collection, name))

        # functools.wraps keeps the original __name__/__doc__ for introspection and help().
        @functools.wraps(getattr(Collection, wrapped_name))
        def _retrying(self, *args, **kwargs):
            return self._rasp_try_wrapper(
                *args,
                rasp_func_name=wrapped_name,
                rasp_max_attempt=max_attempt,
                rasp_sleep_duration=sleep_duration,
                **kwargs
            )

        setattr(Collection, name, _retrying)

    for name in function_names:
        _wrap(name)


def _rasp_try_wrapper(self, *args, **kwargs):
    name = kwargs.pop('rasp_func_name')
    max_attempt = kwargs.pop('rasp_max_attempt')
    sleep_duration = kwargs.pop('rasp_sleep_duration')
    attempt = 0

    while True:
        try:
            attempt += 1
            log.log(logging.DEBUG if attempt == 1 else logging.WARNING,
                    '_rasp_try_wrapper(%s, %s, %s): attempt %d', name, max_attempt, sleep_duration, attempt)

            return getattr(self, name)(*args, **kwargs)
        except Exception as e:
            if attempt >= max_attempt:
                raise e
            else:
                log.exception(e)
                if sleep_duration:
                    sleep(sleep_duration)
