# coding: utf-8
from __future__ import unicode_literals, absolute_import, division, print_function

import re
import logging

from common.models.schedule import Company
from common.utils.caching import cache_method_result
from travel.rasp.admin.importinfo.models import IataCorrection


# Module-level logger; AviaCompanyFinder.error() reports every recorded problem through it.
log = logging.getLogger(__name__)


class AviaCompanyFinder(object):
    """Resolve an airline ``Company`` from a carrier code.

    The code is tried, in order, as an IATA code (optionally disambiguated by
    ``IataCorrection`` flight-number rules and by the countries of the given
    stations), a Sirena code, an ICAO code and a Russian ICAO code.  Every
    problem encountered is logged and collected in ``self.errors``.
    """

    def __init__(self):
        # Per-instance memo of IATA resolutions, keyed by a string built from
        # the code, the flight number and the sorted station ids.  Values may
        # be None (a cached miss).
        self._cached_iata = {}
        # Unique error messages reported via ``error``.
        self.errors = set()

    def get_company(self, code, flight_number=None, stations=None):
        """Return the company matching ``code``, or None if nothing matches.

        :param code: carrier code; tried as IATA, Sirena, ICAO, then ICAO-ru.
        :param flight_number: optional flight number, used to choose between
            companies sharing an IATA code via ``IataCorrection`` rules.
        :param stations: optional iterable of stations (objects exposing ``id``
            and ``country_id``), used to choose between companies sharing an
            IATA code by country.
        :returns: a ``Company`` instance, or None (an error is recorded).
        """
        company = None
        # A blank code would otherwise be looked up literally (``iata=''``,
        # ``sirena_id=''``, ... or IS NULL for None) and could match some
        # company whose code column is simply empty, if such rows exist.
        if code:
            company = (
                self.get_by_iata(code, flight_number, stations)
                or self.get_by_sirena(code)
                or self.get_by_icao(code)
                or self.get_by_icao_ru(code)
            )
        if company:
            return company

        self.error('Не нашли компанию с кодом {} номер рейса {}'.format(code, flight_number))
        return None

    def get_by_iata(self, code, flight_number, stations):
        """Return the company for IATA ``code``, or None if there is none.

        Results, misses included, are memoised per instance by code, flight
        number and sorted station ids.
        """
        # Materialise once: stations are read both for the cache key and for
        # the country check, which would silently see nothing on a generator.
        stations = list(stations or ())
        station_ids = sorted(s.id for s in stations)
        key = '{}_{}_{}'.format(code, flight_number, '&'.join(map(unicode, station_ids)))

        if key not in self._cached_iata:
            self._cached_iata[key] = self._resolve_iata_company(code, flight_number, stations)
        return self._cached_iata[key]

    def _resolve_iata_company(self, code, flight_number, stations):
        """Pick the company for IATA ``code`` without consulting the cache."""
        # Explicit flight-number corrections take precedence over everything.
        company = self.correct_iata_company_by_number(code, flight_number)
        if company:
            return company

        iata_companies = self.get_companies_by_iata_code(code)
        if not iata_companies:
            return None
        if len(iata_companies) == 1:
            return iata_companies[0]

        # Several companies share this IATA code: try the station countries,
        # otherwise fall back to the lowest-id company (list is ordered by id).
        if not stations:
            self.error('Не приложены станции не можем откорректировать компанию {} {}'.format(code, flight_number))
            return iata_companies[0]

        company = self.correct_iata_company_by_station_country(iata_companies, stations)
        if company:
            return company

        self.error('Не удалось скорректировать по стране iata компанию {} {}'.format(code, flight_number))
        return iata_companies[0]

    @cache_method_result
    def get_iata_corrections(self, code):
        """Return all ``IataCorrection`` rows for IATA ``code`` as a list."""
        return list(IataCorrection.objects.filter(code=code))

    def find_by_corrections(self, corrections, flight_number):
        """Return the company of the first correction whose ``number`` regex
        matches the beginning of ``flight_number`` (case-insensitive), else None.
        """
        for correction in corrections:
            number_re = re.compile(correction.number, re.IGNORECASE | re.UNICODE)
            if number_re.match(flight_number):
                return correction.company
        return None

    def error(self, message):
        """Record ``message`` in ``self.errors`` and log it at ERROR level."""
        self.errors.add(message)
        log.error(message)

    def correct_iata_company_by_number(self, code, flight_number):
        """Resolve IATA ``code`` via ``IataCorrection`` rules for ``flight_number``.

        Returns None silently when no corrections exist for ``code``.  Records
        an error and returns None when corrections exist but the flight number
        is missing or matches none of them.
        """
        corrections = self.get_iata_corrections(code)
        if not corrections:
            return None

        if not flight_number:
            self.error('Не указан номер, не возможно точно'
                       ' определить компанию по iata {}'.format(code))
            return None

        company = self.find_by_corrections(corrections, flight_number)
        if not company:
            self.error('Коррекции для iata {} не покрывают случай с номером {}'.format(
                code, flight_number
            ))
        return company

    def correct_iata_company_by_station_country(self, iata_companies, stations):
        """Return the first of ``iata_companies`` whose country matches the
        country of any of ``stations``, or None.
        """
        country_ids = {s.country_id for s in stations}
        for company in iata_companies:
            if company.country_id in country_ids:
                return company
        return None

    @cache_method_result
    def get_companies_by_iata_code(self, code):
        """Return all companies with IATA ``code`` as a list ordered by id."""
        return list(Company.objects.filter(iata=code).order_by('id'))

    @cache_method_result
    def get_by_sirena(self, code):
        """Return the company whose Sirena id is ``code``, or None."""
        return self._get_single_company(sirena_id=code)

    @cache_method_result
    def get_by_icao(self, code):
        """Return the company whose ICAO code is ``code``, or None."""
        return self._get_single_company(icao=code)

    @cache_method_result
    def get_by_icao_ru(self, code):
        """Return the company whose Russian ICAO code is ``code``, or None."""
        return self._get_single_company(icao_ru=code)

    @staticmethod
    def _get_single_company(**lookup):
        """Return the ``Company`` matching ``lookup``, or None if there is none.

        NOTE(review): ``MultipleObjectsReturned`` is deliberately not caught, so
        a code shared by several companies raises — confirm these columns are
        unique in the schema.
        """
        try:
            return Company.objects.get(**lookup)
        except Company.DoesNotExist:
            return None
