from travel.proto.dicts.avia.company_pb2 import TCompany

from travel.library.python.dicts.base_repository import BaseRepository


class CompanyRepository(BaseRepository):
    """Repository of ``TCompany`` records with lookup by IATA / Sirena code.

    On top of what ``BaseRepository`` provides, two secondary indexes are
    maintained: one keyed by ``Iata`` and one keyed by ``SirenaID``. They are
    rebuilt after every ``load_from_file`` / ``load_from_string`` call.
    """

    # Record type of this repository (presumably the protobuf message class
    # consumed by BaseRepository -- confirm against the base class).
    _PB = TCompany

    def __init__(self):
        super(CompanyRepository, self).__init__()
        # Secondary indexes: code -> company record.
        self._company_by_iata = {}
        self._company_by_sirena = {}

    def _fill_company_by_code(self):
        """Rebuild the IATA and Sirena indexes from the current records.

        The indexes are built from scratch rather than updated in place, so
        codes belonging to records that are no longer in the repository (or
        whose code has changed) do not linger after a repeated load.
        Records with an empty ``Iata`` / ``SirenaID`` are not indexed under
        that code. When several records share a code, the one iterated last
        wins.
        """
        by_iata = {}
        by_sirena = {}
        for company in self.itervalues():
            if company.Iata:
                by_iata[company.Iata] = company
            if company.SirenaID:
                by_sirena[company.SirenaID] = company
        # Swap the fully built indexes in only once they are complete.
        self._company_by_iata = by_iata
        self._company_by_sirena = by_sirena

    def load_from_file(self, path):
        """Load records from ``path`` via the base class, then reindex."""
        super(CompanyRepository, self).load_from_file(path)
        self._fill_company_by_code()

    def load_from_string(self, content):
        """Load records from ``content`` via the base class, then reindex."""
        super(CompanyRepository, self).load_from_string(content)
        self._fill_company_by_code()

    def get_company_by_code(self, code):
        """Return the company registered under ``code``.

        The IATA index is consulted first; the Sirena index is used only when
        the IATA lookup yields nothing. Returns ``None`` when neither index
        contains ``code``.
        """
        return (
            self._company_by_iata.get(code)
            or self._company_by_sirena.get(code)
        )
