from logging import Logger, getLogger

import requests
from typing import List, Optional

from travel.avia.price_index.lib.national_version_provider import NationalVersionProvider, national_version_provider
from travel.avia.price_index.lib.settings import Settings, settings


class CurrencyModel(object):
    """A single currency as returned by the backend ``/rest/currencies`` endpoint.

    Holds the backend id (``pk``) and the currency ``code``.
    """

    def __init__(self, pk, code):
        # type: (int, str) -> None
        self.pk, self.code = pk, code

    def __repr__(self):
        template = "<CurrencyModel pk={pk} code={code}>"
        return template.format(pk=self.pk, code=self.code)


class CurrencyProvider(object):
    """In-memory cache of backend currencies, grouped by national version.

    ``fetch()`` (re)loads the data from the backend. The ``get_*`` methods only
    read the cached maps. For unknown keys they log a warning and return
    ``None`` (single lookups) or ``[]`` (``get_all``).
    """

    # Language segment of the backend currencies URL (always 'ru' here).
    _LANGUAGE = 'ru'

    def __init__(self, settings, national_version_provider, logger, request_timeout=None):
        # type: (Settings, NationalVersionProvider, Logger, Optional[float]) -> None
        """
        :param settings: provides ``BACKEND_HOST``, the base URL of the backend.
        :param national_version_provider: ``get_all()`` yields the national
            versions to fetch currencies for (their ``code`` and ``pk`` are used).
        :param logger: receives progress messages and lookup-miss warnings.
        :param request_timeout: optional timeout in seconds forwarded to
            ``requests.get``. ``None`` (the default) means no timeout, which was
            the previous behaviour.
        """
        self._backend_host = settings.BACKEND_HOST
        self._request_timeout = request_timeout

        # Each map is keyed by national version pk (filled by fetch()):
        #   _by_id:     nv pk -> {currency pk -> CurrencyModel}
        #   _by_code:   nv pk -> {currency code -> CurrencyModel}
        #   _all_by_nv: nv pk -> [CurrencyModel, ...]
        self._by_id = {}
        self._by_code = {}
        self._all_by_nv = {}

        self._national_version_provider = national_version_provider
        self._logger = logger

    def fetch(self):
        # type: () -> None
        """Reload currencies for every national version from the backend.

        New maps are built completely before they are assigned. If the loop
        raises part-way through, the previously fetched data stays in place.

        :raises Exception: if the backend payload's ``status`` is not ``'ok'``.
        """
        self._logger.info('start: fetch currency versions')

        all_models = {}
        by_id = {}
        by_code = {}

        for nv in self._national_version_provider.get_all():
            url = '{}/rest/currencies/{}/{}'.format(self._backend_host, nv.code, self._LANGUAGE)
            result = requests.get(url, timeout=self._request_timeout).json()

            if result.get('status') != 'ok':
                # Interpolate the payload into the message. Passing it as a separate
                # Exception argument would leave '%r' unformatted.
                raise Exception('Error in backend: %r' % (result,))

            models = [CurrencyModel(pk=raw_item['id'], code=raw_item['code']) for raw_item in result['data']]

            # On duplicate ids/codes the last item wins, as with sequential assignment.
            all_models[nv.pk] = models
            by_id[nv.pk] = {m.pk: m for m in models}
            by_code[nv.pk] = {m.code: m for m in models}

        self._all_by_nv = all_models
        self._by_code = by_code
        self._by_id = by_id

        self._logger.info('finish: fetch currency versions')

    def get_by_id(self, pk, nv_id):
        # type: (int, int) -> Optional[CurrencyModel]
        """Return the currency with id ``pk`` for national version ``nv_id``, or ``None``."""
        by_id = self._by_id.get(nv_id)
        if by_id is None:
            self._logger.warning('Unknown national version id %r', nv_id)
            return None

        model = by_id.get(pk)
        if model is None:
            self._logger.warning('Unknown currency id %r', pk)
        return model

    def get_by_code(self, code, nv_id):
        # type: (str, int) -> Optional[CurrencyModel]
        """Return the currency with ``code`` for national version ``nv_id``, or ``None``."""
        by_code = self._by_code.get(nv_id)
        if by_code is None:
            self._logger.warning('Unknown national version id %r', nv_id)
            return None

        model = by_code.get(code)
        if model is None:
            self._logger.warning('Unknown currency code %r', code)
        return model

    def get_all(self, nv_id):
        # type: (int) -> List[CurrencyModel]
        """Return all currencies for national version ``nv_id``, or ``[]`` if it is unknown.

        The returned list is the cached object itself, not a copy.
        """
        models = self._all_by_nv.get(nv_id)
        if models is None:
            self._logger.warning('Unknown national version %r', nv_id)
            return []
        return models


# Module-level instance wired to the default settings and national-version
# provider. Its cache is empty until fetch() is called; nothing here calls it.
currency_provider = CurrencyProvider(
    logger=getLogger(__name__),
    national_version_provider=national_version_provider,
    settings=settings,
)
