from logging import Logger, getLogger

import requests
from typing import List, Optional

from travel.avia.price_index.lib.settings import Settings, settings


class NationalVersionModel(object):
    """A national version record: a numeric primary key plus its string code."""

    def __init__(self, pk, code):
        # type: (int, str) -> None
        """Store ``pk`` and ``code`` verbatim as public attributes."""
        self.pk, self.code = pk, code

    def __repr__(self):
        template = '<NationalVersionModel pk={} code={}>'
        return template.format(self.pk, self.code)


class NationalVersionProvider(object):
    """In-memory cache of national versions loaded from the backend REST API.

    The cache is empty until :meth:`fetch` is called; lookups are served
    from dictionaries indexed by primary key and by code.
    """

    def __init__(self, settings, logger):
        # type: (Settings, Logger) -> None
        """Remember the backend host from ``settings`` and start with empty caches."""
        self._backend_host = settings.BACKEND_HOST

        self._by_id = {}
        self._by_code = {}
        self._all = []

        self._logger = logger

    @staticmethod
    def _check_response(result):
        """Raise ``Exception`` unless ``result`` is a dict whose ``status`` is ``'ok'``.

        A missing ``status`` key or a non-dict payload is reported as a
        backend error too, instead of surfacing as a bare KeyError/TypeError.
        """
        if not isinstance(result, dict) or result.get('status') != 'ok':
            # Format the message eagerly: Exception does not apply %-style
            # formatting to its arguments the way logging calls do.
            raise Exception('Error in backend: {!r}'.format(result))

    def fetch(self, timeout=None):
        """Reload national versions from the backend and rebuild the lookup indexes.

        The new indexes are built in local variables and published only after
        the whole payload has been parsed, so an error part-way through leaves
        the previously fetched data in place.

        :param timeout: optional ``requests`` timeout (seconds, or a
            ``(connect, read)`` tuple). ``None`` keeps the previous behaviour,
            in which the request never times out.
        :raises Exception: if the backend response is not ``status == 'ok'``.
            Network and JSON-decoding errors from ``requests`` propagate unchanged.
        """
        self._logger.info('start: refetch national versions')

        url = '{}/rest/national_versions'.format(self._backend_host)
        result = requests.get(url, timeout=timeout).json()
        self._check_response(result)

        all_models = []
        by_id = {}
        by_code = {}

        for raw_item in result['data']:
            model = NationalVersionModel(pk=raw_item['id'], code=raw_item['code'])

            all_models.append(model)
            by_id[model.pk] = model
            by_code[model.code] = model

        self._all = all_models
        self._by_code = by_code
        self._by_id = by_id

        self._logger.info('finish: refetch national versions')

    def get_by_id(self, pk):
        # type: (int) -> Optional[NationalVersionModel]
        """Return the cached model with primary key ``pk``, or ``None`` (logging a warning)."""
        model = self._by_id.get(pk)
        if model is None:
            self._logger.warning('Can not find national version by id %r', pk)
        return model

    def get_by_code(self, code):
        # type: (str) -> Optional[NationalVersionModel]
        """Return the cached model with code ``code``, or ``None`` (logging a warning)."""
        model = self._by_code.get(code)
        if model is None:
            self._logger.warning('Can not find national version by code %r', code)
        return model

    def get_all(self):
        # type: () -> List[NationalVersionModel]
        """Return the cached list of all models (the internal list, not a copy)."""
        return self._all


# Shared module-level provider. Its caches start empty: fetch() is not
# called at import time, so some caller must invoke it before lookups succeed.
national_version_provider = NationalVersionProvider(logger=getLogger(__name__), settings=settings)
