from travel.avia.library.python.shared_dicts.rasp import iter_protobuf_data, ResourceType


class CompanyCache:
    """In-memory lookup of company (carrier) records by id and by code.

    Records come from ``iter_protobuf_data(SANDBOX_RESOURCE, oauth)``. The
    dictionaries are empty until ``populate`` or ``update_cache`` is called.
    """

    SANDBOX_RESOURCE = ResourceType.TRAVEL_DICT_RASP_CARRIER_PROD

    def __init__(self, logger, oauth=None):
        """
        :param logger: logger used for status messages and cache-miss warnings.
        :param oauth: optional credential passed through to ``iter_protobuf_data``
            (presumably an OAuth token — confirm against that helper).
        """
        self.company_by_id = {}
        self.company_by_code = {}
        self.logger = logger
        self.oauth = oauth

    def get_company_by_id(self, company_id):
        """Return the cached company with ``company_id``, or None.

        A falsy ``company_id`` returns None without logging; an id that is not
        in the cache returns None and logs a warning.
        """
        if not company_id:
            return None
        company = self.company_by_id.get(company_id)
        if company is None:
            # %s (not %d): works for int ids and does not break logging if a
            # caller passes the id as a string.
            self.logger.warning('No company for company_id %s in the cache', company_id)
        return company

    def get_company_by_code(self, company_code):
        """Return the cached company whose IATA or Sirena code is ``company_code``, or None.

        The lookup is case-insensitive (the code is upper-cased, matching how
        keys are stored by ``build_cache``). A falsy code returns None without
        logging; an unknown code returns None and logs a warning.
        """
        if not company_code:
            return None
        company_code = company_code.upper()
        company = self.company_by_code.get(company_code)
        if company is None:
            # company_code is a string: it must be formatted with %s, a %d
            # placeholder makes logging fail instead of emitting the warning.
            self.logger.warning('No company for company_code %s in the cache', company_code)
        return company

    def get_company_code_by_id(self, company_id):
        """Return the first non-empty of Iata, SirenaId, Icao for the company, or None."""
        company = self.get_company_by_id(company_id)
        if company is None:
            return None
        return company.Iata or company.SirenaId or company.Icao

    def clear_cache(self):
        """Drop all cached companies from both the id and the code indexes."""
        # Both indexes must be reset; clearing only the id index would leave
        # get_company_by_code serving stale records.
        self.company_by_id = {}
        self.company_by_code = {}

    def populate(self):
        """Load the cache from the resource; exceptions from ``build_cache`` propagate."""
        self.company_by_id, self.company_by_code = self.build_cache(self.oauth)
        self.logger.info('populated cache with %d companies', len(self.company_by_id))

    def update_cache(self):
        """Rebuild the cache, keeping the previous contents if the rebuild fails.

        Any exception from ``build_cache`` is logged and swallowed (deliberate
        best-effort refresh); on success both indexes are replaced.
        """
        try:
            company_by_id, company_by_code = self.build_cache(self.oauth)
        except Exception as e:
            self.logger.error('Failed to update company cache. Will continue using old cache. Reason: %s', e)
            # Do not report a repopulation that did not happen.
            return

        self.company_by_id, self.company_by_code = company_by_id, company_by_code
        self.logger.info('repopulated cache with %d companies', len(self.company_by_id))

    @staticmethod
    def build_cache(oauth=None):
        """Read all companies from ``SANDBOX_RESOURCE`` and build the lookup indexes.

        :param oauth: passed through to ``iter_protobuf_data``.
        :returns: ``(company_by_id, company_by_code)``. ``company_by_id`` maps
            ``company.Id`` to the record. ``company_by_code`` maps the
            upper-cased Iata and Sirena codes to the record; on key collisions
            a later record overwrites an earlier one.
        """
        company_by_id = {}
        company_by_code = {}

        for company in iter_protobuf_data(CompanyCache.SANDBOX_RESOURCE, oauth):
            company_by_id[company.Id] = company
            if company.Iata:
                company_by_code[company.Iata.upper()] = company
            if company.SirenaId:
                sirena = company.SirenaId
                # SirenaId may be bytes (decode to text) or already text, in
                # which case there is no .decode and the value is kept as-is.
                try:
                    sirena = sirena.decode('utf8')
                except AttributeError:
                    pass
                company_by_code[sirena.upper()] = company

        return company_by_id, company_by_code
