# -*- coding: utf-8 -*-
import json
import logging
from collections import defaultdict
from itertools import chain

from django.db.models import Q

from travel.avia.library.python.avia_data.models import AviaCompany, CompanyTariff
from travel.avia.library.python.common.importinfo.models.bus import BuscomuaStationCode
from travel.avia.library.python.common.models.geo import Station, StationCode, CodeSystem
from travel.avia.library.python.common.models.iatacorrection import IataCorrection
from travel.avia.library.python.common.models.schedule import Company
from travel.avia.library.python.common.utils import tracer

from travel.avia.avia_api.avia.lib.serialization import (
    JsonSerializable, ListConverter, ModelConverter, construct_cls_from_attrs,
)

log = logging.getLogger(__name__)


class CodeCompanyCache(object):
    """Process-wide cache of companies keyed by IATA, Sirena and ICAO codes.

    Until ``precache`` has been called, ``get_companies`` falls back to a
    database query instead of the in-memory index.
    """

    # Shared by every user of the class: code -> Company.
    cache = {}

    # Becomes True once ``precache`` has filled ``cache``.
    precached = False

    @classmethod
    def precache(cls):
        """Fill ``cache`` from ``Company.objects.all_cached()``.

        Each company is registered under every non-empty value among its
        ``iata``, ``sirena_id`` and ``icao`` fields.
        """
        for company in Company.objects.all_cached():
            for attr in ('iata', 'sirena_id', 'icao'):
                company_code = getattr(company, attr)
                if company_code:
                    cls.cache[company_code] = company

        cls.precached = True

    @classmethod
    def get_companies(cls, codes):
        """Return the companies matching any of ``codes`` as a list.

        Without the precached index a single query filtering on the IATA,
        Sirena and ICAO fields is issued. With it, one cache lookup is done
        per code in ``codes`` order and unknown codes are skipped, so a
        company matched by several codes appears several times.
        """
        if cls.precached:
            found = (cls.cache.get(code) for code in codes)
            return [company for company in found if company]

        return list(Company.objects.filter(
            Q(iata__in=codes) | Q(sirena_id__in=codes) | Q(icao__in=codes)
        ))


class Reference(JsonSerializable):
    u"""Reference of companies and stations used to complete flights.

    Usage:

    1. Flights register the codes they need via ``add_company``,
       ``add_station`` and ``add_code_system_station``.
    2. ``_fetch`` loads every registered code with bulk queries.
    3. Each flight resolves its codes through ``company``, ``station`` and
       ``get_code_system_station``.

    Codes that cannot be resolved are reported to the
    ``avia.api.reference.unknown_codes`` logger.
    """

    # Logger for codes that could not be resolved (see ``_get``).
    _log = logging.getLogger('avia.api.reference.unknown_codes')

    def __init__(self):
        # Codes registered by flights, resolved in bulk by ``_fetch``.
        self._code_companies = set()
        self._iata_stations = set()
        self._stations_in_code_system = {}
        self._stations_by_code_system = {}

        self._buscomua_stations = set()
        self._city_title_stations = set()

        self.cache = {}

        # Context used only in unknown-code log messages. These are set by
        # ``complete``; initialised here so ``_get`` works outside it too.
        self._flight = None
        self._url = None
        self._partner_code = None

    _json_attrs = {
        'stations': ListConverter(ModelConverter(Station)),
        'companies': ListConverter(ModelConverter(Company)),
    }

    @classmethod
    def merge(cls, references):
        """Combine several references into one, de-duplicating items by ``id``.

        For every attribute listed in ``_json_attrs``, the items of all
        ``references`` are concatenated. When two items share an ``id``,
        the later one wins.
        """
        def unique(items, key):
            return list({key(v): v for v in items}.values())

        return construct_cls_from_attrs(cls, {
            attr: unique(
                chain.from_iterable(getattr(r, attr) for r in references),
                key=lambda v: v.id,
            )
            for attr in cls._json_attrs
        })

    @classmethod
    @tracer.wrap
    def complete(cls, flights, url=None, partner_code=None):
        """Resolve company and station codes for ``flights`` in place.

        :param flights: iterable of objects implementing ``add_to(reference)``
            and ``complete_from(reference)``.
        :param url: search URL, included in unknown-code log messages.
        :param partner_code: partner code used in log messages for flights
            that have no ``partner`` attribute.
        """
        # The flights are walked twice; materialise them so a one-shot
        # iterator does not silently skip the completion pass.
        flights = list(flights)

        reference = cls()

        # First pass: collect every code the flights need.
        for flight in flights:
            flight.add_to(reference)

        # Load all collected codes with bulk queries.
        reference._fetch()

        reference._url = url
        reference._partner_code = partner_code

        # Second pass: each flight looks up its objects. ``_flight`` is
        # remembered so unknown-code warnings can name the flight.
        for flight in flights:
            reference._flight = flight
            flight.complete_from(reference)

    @classmethod
    def complete_variants(cls, variants, *args, **kwargs):
        """Run ``complete`` over all segments of all ``variants``."""
        cls.complete([
            f
            for v in variants
            for f in v.all_segments
        ], *args, **kwargs)

    @classmethod
    def complete_partner_variants(cls, p_code, variants, search_url):
        """Tag ``variants`` with ``p_code`` and complete their segments."""
        for v in variants:
            v.partner_code = p_code

        cls.complete_variants(
            variants,
            partner_code=p_code,
            url=search_url
        )

    def add_company(self, code):
        """Register a company code (IATA, Sirena or ICAO) to be fetched."""
        self._code_companies.add(code)

    def add_station(self, iata=None, city_title=None, buscomua=None):
        """Register station keys to be fetched.

        :param iata: IATA (or Sirena) station code.
        :param city_title: ``(settlement title, station title)`` pair.
        :param buscomua: ``(server_code, city_code, station_code)`` triple.
        """
        if iata:
            self._iata_stations.add(iata)

        if city_title:
            self._city_title_stations.add(city_title)

        if buscomua:
            self._buscomua_stations.add(buscomua)

    def add_code_system_station(self, code, code_system):
        """Register a station ``code`` within the named ``code_system``."""
        self._stations_in_code_system.setdefault(code_system, set()).add(code)

    def company(self, code, flight_number=None):
        """Return the company for ``code``, applying IATA corrections.

        Correction rules are keyed by IATA code. When the company was found
        by its Sirena or ICAO code, the rules are therefore looked up by
        ``company.iata`` (RASPTICKETS-4807). If no rule matches, the directly
        found company (or ``None``) is returned.
        """
        company = self._get(self._companies, code, 'company code')

        if company:
            code = company.iata

        corrected = self.correct_iata_company_by_flight_number(code, flight_number)
        return corrected or company

    def correct_iata_company_by_flight_number(self, code, flight_number):
        """Return the company chosen by an IATA correction rule, or ``None``.

        :param code: IATA company code the correction rules are keyed by.
        :param flight_number: flight number. Only its last whitespace-separated
            part (the number without the carrier code) is matched.
        """
        if not flight_number:
            return None

        corrections = self._company_iata_corrections.get(code)
        if not corrections:
            return None

        parts = flight_number.split()
        if not parts:
            # A whitespace-only flight number leaves nothing to match against.
            return None
        number = parts[-1]

        for correction in corrections:
            try:
                match = correction.match_number(number)
            except Exception as err:
                log.warning(
                    u'IataCorrection[%s].match_number() error. [%s] %r. %r',
                    correction.id, correction.code, correction.number, err
                )
                continue

            if not match:
                continue

            try:
                return correction.company
            except Exception as err:
                # Three placeholders need three arguments; without ``err``
                # the logging call itself fails to format the message.
                log.warning(
                    u'IataCorrection[%s] company_id[%r]: %r',
                    correction.id, correction.company_id, err
                )

        return None

    def station(self, iata=None, city_title=None, buscomua=None):
        """Return the station for the first given key, or ``None``.

        Keys are checked in the order ``iata``, ``city_title``, ``buscomua``.
        A ``buscomua`` triple ``(server_code, city_code, station_code)`` is
        looked up with its server code first. As a fallback, it is looked up
        by ``(city_code, station_code)`` only.
        """
        if iata:
            return self._get(self._stations_by_iata, iata, 'station iata')

        if city_title:
            return self._get(self._stations_by_city_title, city_title, 'station city title')

        if buscomua:
            # Silent first lookup: a miss here is not yet an unknown code,
            # because the server-agnostic fallback may still resolve it.
            station = self._stations_by_buscomua_with_sale_server.get(buscomua)
            if not station:
                station = self._get(self._stations_by_buscomua_2code, buscomua[1:], 'station buscomua')
            return station

        return None

    def get_code_system_station(self, code, code_system):
        """Return the station with ``code`` in ``code_system``, or ``None``."""
        return self._get(self._stations_by_code_system.get(code_system, {}), code, '{} code'.format(code_system))

    def _get(self, d, key, desc):
        """Return ``d[key]``. On a miss, log the unknown code and return ``None``.

        The warning is a JSON list
        ``[desc, key, flight number, partner code, url]``.
        """
        try:
            return d[key]
        except KeyError:
            pass

        flight = self._flight
        if hasattr(flight, 'partner'):
            partner_code = flight.partner.code
        else:
            partner_code = self._partner_code

        message = [desc, key, getattr(flight, 'number', None), partner_code, self._url]

        self._log.warning(json.dumps(message))

        return None

    def _fetch(self):
        """Load all registered companies and stations."""
        self._fetch_companies()
        self._fetch_stations()

    def _fetch_companies(self):
        """Load companies for the registered codes, then their IATA
        corrections and the related avia companies and tariffs."""
        companies = CodeCompanyCache.get_companies(list(self._code_companies))

        # Index every company by each of its codes (IATA, Sirena, ICAO).
        self._companies = {}
        for company in companies:
            for company_code in (company.iata, company.sirena_id, company.icao):
                if company_code:
                    self._companies[company_code] = company

        self.companies = set(companies)

        self._fetch_company_iata_corrections()

        self._fetch_aviacompanies()

    def _fetch_company_iata_corrections(self):
        """Group IataCorrection rules by code for every known company code.

        The companies the rules point to are added to ``self.companies``.
        """
        corrections = IataCorrection.objects.filter(code__in=list(self._companies))

        by_code = defaultdict(set)
        for correction in corrections:
            by_code[correction.code].add(correction)

            # ``correction.company`` is treated as failure-prone, as in
            # correct_iata_company_by_flight_number; one broken rule must
            # not abort the whole fetch.
            try:
                company = correction.company
            except Exception as err:
                log.warning(
                    u'IataCorrection[%s] company_id[%r]: %r',
                    correction.id, correction.company_id, err
                )
                continue

            if company is not None:
                self.companies.add(company)

        self._company_iata_corrections = dict(by_code)

    def _fetch_aviacompanies(self):
        """Load AviaCompany rows for ``self.companies`` and their published tariffs."""
        self.aviacompanies_by_rasp_company_id = {
            avia_company.rasp_company_id: avia_company
            for avia_company in AviaCompany.objects.filter(
                rasp_company__in=list(self.companies)
            )
        }

        # Fetch the published tariffs of those avia companies in one query.
        tariffs = CompanyTariff.objects.filter(
            published=True,
            avia_company__in=list(self.aviacompanies_by_rasp_company_id.values()),
        )
        tariffs_by_aviacompany_id = defaultdict(list)
        for tariff in tariffs:
            tariffs_by_aviacompany_id[tariff.avia_company_id].append(tariff)
        self.tariffs_by_aviacompany_id = dict(tariffs_by_aviacompany_id)

    def _fetch_stations(self):
        """Resolve every registered station key and collect all stations.

        The resolved stations are collected into ``self.stations``.
        """
        # IATA-registered codes are also tried as Sirena codes; Sirena
        # matches overwrite IATA ones for the same code.
        self._stations_by_iata = self._get_station_by_code_system(self._iata_stations, 'iata')
        self._stations_by_iata.update(self._get_station_by_code_system(self._iata_stations, 'sirena'))

        self.stations = set(self._stations_by_iata.values())

        self._stations_by_city_title = self._get_stations_by_title()
        self.stations.update(self._stations_by_city_title.values())

        self._stations_by_buscomua_2code, self._stations_by_buscomua_with_sale_server = \
            self._get_buscomua_stations(self._buscomua_stations)
        self.stations.update(self._stations_by_buscomua_with_sale_server.values())

        for code_system, codes in self._stations_in_code_system.items():
            by_code = self._get_station_by_code_system(codes, code_system)
            self._stations_by_code_system[code_system] = by_code
            self.stations.update(by_code.values())

    def _get_station_by_code_system(self, codes, code_system):
        """Map each of ``codes`` to its Station within ``code_system``.

        Codes without a StationCode row are absent from the result.
        ``CodeSystem.objects.get`` raises if ``code_system`` is unknown.
        """
        if not codes:
            return {}

        rows = list(StationCode.objects.filter(
            system=CodeSystem.objects.get(code=code_system),
            code__in=list(codes),
        ).values_list('station_id', 'code'))

        station_by_id = Station.objects.in_bulk([station_id for station_id, _ in rows])

        return {code: station_by_id[station_id] for station_id, code in rows}

    def _get_stations_by_title(self):
        """Map registered ``(settlement title, station title)`` pairs to stations."""
        if not self._city_title_stations:
            return dict()

        q = Q()
        for city, title in self._city_title_stations:
            q |= Q(settlement__title=city, title=title)

        # NOTE(review): t_type id 2 is presumably the bus transport type — confirm.
        rows = list(
            Station.objects.filter(q, t_type__id=2).values_list('id', 'title', 'settlement__title')
        )

        stations_in_bulk = Station.objects.in_bulk([station_id for station_id, _, _ in rows])

        # (settlement.title, station.title) -> station
        return {
            (settlement_title, title): stations_in_bulk[station_id]
            for station_id, title, settlement_title in rows
        }

    def _get_buscomua_stations(self, buscomua_stations):
        """Resolve buscomua ``(server_code, city_code, station_code)`` keys.

        Rows are matched by city and station code only, so every server's
        rows for those codes are fetched.

        :returns: a pair ``(by (city_code, station_code),
            by (server_code, city_code, station_code))``.
        """
        if not buscomua_stations:
            return dict(), dict()

        q = Q()
        for _server_code, city_code, station_code in buscomua_stations:
            q |= Q(city_code=city_code, station_code=station_code)

        station_id_by_codes = {
            (server_code, city_code, station_code): station_id
            for server_code, city_code, station_code, station_id in
            BuscomuaStationCode.objects.filter(q).values_list(
                'server_code', 'city_code', 'station_code', 'station')
        }

        stations_in_bulk = Station.objects.in_bulk(list(station_id_by_codes.values()))

        stations_with_server_code = {}
        buscomua_stations_by_codes = {}
        for (server_code, city_code, station_code), station_id in station_id_by_codes.items():
            station = stations_in_bulk[station_id]
            stations_with_server_code[(server_code, city_code, station_code)] = station
            # Several servers may share (city_code, station_code); the last one wins.
            buscomua_stations_by_codes[(city_code, station_code)] = station

        return buscomua_stations_by_codes, stations_with_server_code
