from __future__ import absolute_import

import re
from typing import List  # noqa
from operator import attrgetter
import logging

from travel.avia.library.python.common.models.iatacorrection import IataCorrection
from travel.avia.library.python.common.utils.iterrecipes import group_by

logger = logging.getLogger(__name__)


class IataCorrectionModel(object):
    """In-memory snapshot of one IataCorrection row.

    ``number`` is treated as a regular expression. It is compiled once at
    construction (case-insensitive, Unicode semantics) and applied to
    candidate numbers by :meth:`match_number`.
    """

    __slots__ = (
        'id', 'code', 'number', 'company_id', '_number_re',
    )

    def __init__(self, pk, code, number, company_id):
        # type: (int, str, str, int) -> None
        self.id = pk
        self.code = code
        self.number = number
        self.company_id = company_id
        # None when ``number`` could not be compiled (the error is logged).
        self._number_re = self._try_compile_number_re()

    def _try_compile_number_re(self):
        """Compile ``self.number`` into a pattern, or return None on failure.

        Any exception is logged and swallowed, so a bad pattern never raises
        out of the constructor. The model is still created; it simply never
        matches anything.
        """
        try:
            # Flags are bit masks, so they are combined with ``|``.
            return re.compile(self.number, re.I | re.U)
        except Exception as err:
            logger.error(u'IataCorrection[%s][%s] compile number %r. %r',
                         self.id, self.code, self.number, err)
            return None

    def match_number(self, number):
        """Match ``number`` against the compiled pattern, anchored at its start.

        Returns the ``re`` match object, or None if the pattern does not match
        or the pattern failed to compile.
        """
        if self._number_re is None:
            return None
        return self._number_re.match(number)

    def __repr__(self):
        # format() converts every field to text itself. An explicit unicode()
        # call is unnecessary and is undefined on Python 3.
        return u'<IataCorrectionModel id={} code={} number={} company_id={}>'.format(
            self.id,
            self.code,
            self.number,
            self.company_id,
        )


class IataCorrectionRepository(object):
    """In-memory cache of IataCorrection rows, keyed by id and grouped by code.

    The cache is empty until :meth:`pre_cache` loads it from
    ``IataCorrection.objects``.
    """

    def __init__(self):
        self._index = {}  # id -> IataCorrectionModel
        self._corrections_by_iata = {}  # code -> group built by group_by()

    @staticmethod
    def _load_db_models():
        # type: () -> List[dict]
        """Fetch id, code, number and company_id of every row as plain dicts."""
        fields = ['id', 'code', 'number', 'company_id']

        return list(IataCorrection.objects.values(*fields))

    def pre_cache(self):
        # type: () -> None
        """(Re)build the id index and the per-code grouping.

        Both structures are computed into locals and only assigned at the
        end. If loading or grouping raises, the previously cached data is
        left untouched rather than half-replaced.
        """
        index = {
            row['id']: IataCorrectionModel(
                pk=row['id'],
                code=row['code'],
                number=row['number'],
                company_id=row['company_id'],
            )
            for row in self._load_db_models()
        }
        # Group models by their ``code`` attribute. Presumably group_by yields
        # (code, models) pairs; confirm against iterrecipes.group_by.
        corrections_by_iata = dict(group_by(list(index.values()), attrgetter('code')))

        self._index = index
        self._corrections_by_iata = corrections_by_iata

    def get_all(self):
        # type: () -> List[IataCorrectionModel]
        """Return a new list of every cached model.

        The result is a real list on both Python 2 and 3, never a dict view.
        """
        return list(self._index.values())

    def get_corrections_by_iata(self):
        # type: () -> dict
        """Return the code -> corrections mapping built by :meth:`pre_cache`.

        The internal dict itself is returned, not a copy.
        """
        return self._corrections_by_iata


# Shared module-level instance. It starts empty until pre_cache() is called.
iata_correction_repository = IataCorrectionRepository()
