import requests
from datetime import timedelta

from logging import getLogger, Logger

from typing import Dict

from travel.avia.price_index.lib.environment import environment, Environment
from travel.avia.price_index.lib.national_version_provider import (
    NationalVersionProvider,
    NationalVersionModel,
    national_version_provider,
)
from travel.avia.price_index.lib.settings import settings, Settings


class RatesProvider(object):
    """Caches currency rates per national version, loaded from the backend.

    Rates are loaded by ``fetch()``. ``get_rates_for`` and
    ``get_base_currency_id`` refetch them lazily when nothing has been
    loaded yet, or when the cached data is older than ``_REFRESH_INTERVAL``.
    """

    # How many times fetching rates for one national version is attempted.
    _FETCH_ATTEMPTS = 3
    # Cached data older than this is refetched on the next read.
    _REFRESH_INTERVAL = timedelta(hours=1)
    # Seconds to wait for the backend; without a timeout a stuck
    # connection would block the caller indefinitely.
    _REQUEST_TIMEOUT = 30

    def __init__(self, settings, national_version_provider, environment, logger):
        # type: (Settings, NationalVersionProvider, Environment, Logger) -> None
        self._backend_host = settings.BACKEND_HOST
        self._environment = environment
        self._national_version_provider = national_version_provider
        self._logger = logger

        # national version pk -> {currency_id: rate}; None until first fetch()
        self._rates_by_nv_id = None
        # national version pk -> base currency id; set together with the rates
        self._base_currency_id_by_nv_id = None
        # environment.now() at the end of the last fetch(); None = never fetched
        self._last_update = None

    def fetch(self):
        # type: () -> None
        """Reload rates for every national version from the backend.

        Each national version is tried up to ``_FETCH_ATTEMPTS`` times; every
        failed attempt is logged. If all attempts fail, the previously cached
        data for that national version (if any) is kept.

        Raises:
            Exception: if all attempts fail for a national version and no
                rates have been cached before.
        """
        self._logger.info('start: refetch national versions')
        index = {}
        base_currency_id_to_nv = {}
        for nv in self._national_version_provider.get_all():
            for attempt in range(self._FETCH_ATTEMPTS):
                try:
                    data = self._fetch_rates_for(nv)
                    rates = {item['currency_id']: item['rate'] for item in data['rates']}
                    base_currency_id = data['base_currency_id']
                except Exception:
                    # Log and retry instead of aborting the whole refresh.
                    self._logger.exception(
                        'Can not fetch rates for %s (attempt %d of %d)',
                        nv.code, attempt + 1, self._FETCH_ATTEMPTS,
                    )
                    continue

                index[nv.pk] = rates
                base_currency_id_to_nv[nv.pk] = base_currency_id
                break
            else:
                # Every attempt for this national version failed.
                if not self._rates_by_nv_id:
                    raise Exception('Can not fetch rates in {} retries'.format(self._FETCH_ATTEMPTS))
                # Keep serving the stale data rather than dropping this
                # national version from the cache.
                previous_base = self._base_currency_id_by_nv_id or {}
                if nv.pk in self._rates_by_nv_id:
                    index[nv.pk] = self._rates_by_nv_id[nv.pk]
                if nv.pk in previous_base:
                    base_currency_id_to_nv[nv.pk] = previous_base[nv.pk]

        self._rates_by_nv_id = index
        self._base_currency_id_by_nv_id = base_currency_id_to_nv
        self._last_update = self._environment.now()
        self._logger.info('end: refetch national versions')

    def _fetch_rates_for(self, national_version):
        # type: (NationalVersionModel) -> dict
        """Request the rates payload for one national version from the backend.

        Returns:
            The ``data`` member of the backend JSON response.

        Raises:
            Exception: if the response ``status`` is not ``'ok'``.
        """
        url = '{}/rest/currencies/rates/{}/ru'.format(self._backend_host, national_version.code)
        result = requests.get(url, timeout=self._REQUEST_TIMEOUT).json()

        if result['status'] != 'ok':
            raise Exception('Error in backend: %r' % (result,))

        return result['data']

    def _check_fetch(self):
        # type: () -> None
        """Fetch rates if never loaded or older than ``_REFRESH_INTERVAL``."""
        if self._last_update is not None and \
                self._environment.now() - self._last_update < self._REFRESH_INTERVAL:
            return

        self.fetch()

    def get_rates_for(self, nv_id):
        # type: (int) -> Dict[int, float]
        """Return ``{currency_id: rate}`` for the national version ``nv_id``.

        Unknown ids are logged and answered with the rates of the ``'ru'``
        national version.
        """
        self._check_fetch()

        if nv_id not in self._rates_by_nv_id:
            self._logger.warning('Unknown national version: %s', nv_id)
            return self._rates_by_nv_id[self._national_version_provider.get_by_code('ru').pk]

        return self._rates_by_nv_id[nv_id]

    def get_base_currency_id(self, nv_id):
        # type: (int) -> int
        """Return the base currency id for the national version ``nv_id``.

        Unknown ids are logged and answered with the base currency of the
        ``'ru'`` national version.
        """
        # Refresh like get_rates_for; also avoids failing on a None cache
        # when this is the first read.
        self._check_fetch()

        if nv_id not in self._base_currency_id_by_nv_id:
            self._logger.warning('Unknown national version id: %s', nv_id)
            return self._base_currency_id_by_nv_id[self._national_version_provider.get_by_code('ru').pk]

        return self._base_currency_id_by_nv_id[nv_id]


# Shared module-level instance, wired to the settings, national version
# provider and environment instances imported at the top of this module.
# Arguments follow RatesProvider.__init__ order:
# (settings, national_version_provider, environment, logger).
rates_provider = RatesProvider(
    settings,
    national_version_provider,
    environment,
    getLogger(__name__),
)
