# -*- coding: utf-8 -*-
from logging import getLogger

from requests import ConnectionError
from retrying import retry

from travel.avia.library.python.common.saas.json_saas import JsonSaas, saas_index, saas_search


# Module-scoped logger used to report exceptions seen while calling SAAS.
logger = getLogger(name=__name__)


def retry_on_connection_exc(exc):
    """Predicate for ``retrying``'s ``retry_on_exception`` hook.

    Every exception passed in is logged as a warning. Only instances of the
    ``ConnectionError`` name in scope (imported from ``requests`` in this
    module) are reported as retryable.

    :param exc: the exception raised by the wrapped call.
    :returns: ``True`` if the call should be retried, ``False`` otherwise.
    """
    # Log first, so non-retryable failures are recorded as well.
    logger.warning('Caught %r on requesting SAAS', exc)
    # NOTE(review): read timeouts are presumably not ConnectionError
    # subclasses in requests, so they are not retried. Confirm this is intended.
    retryable = isinstance(exc, ConnectionError)
    return retryable


class RetryingJsonSaas(JsonSaas):
    """``JsonSaas`` client whose ``search`` and ``index`` retry on connection errors.

    Both methods share one retry policy:
    - at most 3 attempts;
    - exponential back-off between attempts, with multiplier 100 and cap 500
      (milliseconds, per the ``retrying`` API).

    ``retry_on_connection_exc`` decides which exceptions are retried. With
    ``retrying``'s default ``wrap_exception=False``, the last exception is
    re-raised once attempts run out.
    """

    # A single decorator object applied to both methods, so the policy cannot
    # drift between them. It is removed from the class namespace at the end of
    # the class body.
    _retry_on_connection_error = retry(
        retry_on_exception=retry_on_connection_exc,
        stop_max_attempt_number=3,
        wait_exponential_multiplier=100,
        wait_exponential_max=500,
    )

    @_retry_on_connection_error
    def search(self, keys, columns=None, label=None, low_priority=False, timeout=None):
        """Call ``JsonSaas.search`` under the shared retry policy.

        All arguments are forwarded positionally and unchanged to the parent
        implementation. Its result is returned as-is.
        """
        parent = super(RetryingJsonSaas, self)
        return parent.search(keys, columns, label, low_priority, timeout)

    @_retry_on_connection_error
    def index(self, key, doc, labels=None, expires_at=None, realtime=True, timeout=None):
        """Call ``JsonSaas.index`` under the shared retry policy.

        All arguments are forwarded positionally and unchanged to the parent
        implementation. Its result is returned as-is.
        """
        parent = super(RetryingJsonSaas, self)
        return parent.index(key, doc, labels, expires_at, realtime, timeout)

    # Both methods are already decorated, so drop the helper so that it does
    # not become a class attribute.
    del _retry_on_connection_error


# Shared module-level client, built from the ``saas_index`` and ``saas_search``
# objects imported from ``json_saas``. Presumably these are endpoint/config
# objects; verify against json_saas.
retrying_json_saas = RetryingJsonSaas(index=saas_index, search=saas_search)
