# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

from collections import defaultdict

from marshmallow import Schema, fields, post_load

from common.data_api.ticket_daemon.serialization.avia_partner import AviaPartner
from common.data_api.ticket_daemon.serialization.itinerary import Itinerary, ItinerarySchema
from common.models.geo import Station
from common.serialization.common_schemas import PriceSchema
from common.serialization.fields import DictField, DictNestedField
from common.serialization.schema import get_defaults_from_schema
from common.utils.namedtuple import namedtuple_with_defaults


class VariantSchema(Schema):
    """Deserialize one ticket-daemon variant into a ``Variant`` namedtuple.

    Context: ``partners`` maps a partner code to an ``AviaPartner`` and is
    read by ``partner_by_code``. Other context keys supplied by
    ``parse_variants`` are presumably consumed by the nested
    ``ItinerarySchema`` (marshmallow shares context with nested schemas) --
    verify against ``ItinerarySchema``.
    """
    deep_link = fields.String()  # link that redirects to the partner through the avia tickets service
    order_link = fields.String()  # link to the order page with other prices and flight information
    forward = fields.Nested(ItinerarySchema, allow_none=True)
    backward = fields.Nested(ItinerarySchema, allow_none=True)
    order_data = fields.Raw(allow_none=True)
    tariff = fields.Nested(PriceSchema)
    partner = fields.Method(deserialize='partner_by_code')
    raw_seats = DictField(fields.Integer(), allow_none=True)
    raw_tariffs = DictNestedField(PriceSchema, allow_none=True)
    raw_is_several_prices = DictField(fields.Boolean(), allow_none=True)
    query_time = fields.Int()
    from_company = fields.Boolean()

    @post_load
    def post_load(self, data):
        """Build a ``Variant`` from the loaded fields.

        A missing, null or otherwise falsy ``forward``/``backward`` value is
        replaced with an ``Itinerary`` that has no segments, so the resulting
        ``Variant`` always carries an ``Itinerary`` in both slots.
        """
        for itinerary_field in ('forward', 'backward'):
            if not data.get(itinerary_field):
                data[itinerary_field] = Itinerary(segments=[])

        return Variant(**data)

    def partner_by_code(self, code):
        """Resolve a partner code to the ``AviaPartner`` from the schema context.

        :param code: raw partner code from the input payload.
        :return: the matching ``AviaPartner``, or ``None`` when the context
            has no ``partners`` mapping.
        :raises KeyError: if a ``partners`` mapping exists but lacks ``code``.
        """
        # marshmallow stores ``context or {}``, so ``self.context`` is never
        # None; test for the mapping itself rather than the context object.
        partners = (self.context or {}).get('partners')
        if partners is None:
            return None
        return partners[code]

    class Meta:
        strict = True


# ``Variant`` mirrors ``VariantSchema``: one namedtuple field per schema
# field name, with per-field defaults computed by ``get_defaults_from_schema``.
Variant = namedtuple_with_defaults(
    'Variant',
    list(VariantSchema().fields),
    get_defaults_from_schema(VariantSchema),
)


def parse_variants(variants_data, reference):
    """Deserialize ticket-daemon variants and group them by partner code.

    :param variants_data: iterable of raw variant dicts; each must contain a
        ``'partner'`` key holding the partner code.
    :param reference: dict of shared reference data. ``'flights'`` (a list of
        flight dicts), ``'itineraries'`` and ``'partners'`` (a list of dicts
        with ``code``/``title``/``logoSvg``) are read; each key is optional.
    :return: ``defaultdict(list)`` mapping each raw partner code to the list
        of loaded ``Variant`` objects, in input order.
    """
    # Flights are a list of flight dicts (both helpers iterate them), so look
    # the key up once and fall back to an empty list, also for explicit None.
    flights = reference.get('flights') or []
    variant_schema = VariantSchema(context={
        'reference': reference,
        'stations': collect_stations(flights),
        'flights_by_key': _build_flights_by_key(flights),
        'itineraries': reference.get('itineraries', {}),
        'partners': {
            p['code']: AviaPartner(p['code'], p['title'], p['logoSvg'])
            for p in reference.get('partners', [])
        },
    })

    result = defaultdict(list)
    for partner_variant in variants_data:
        # Group by the raw code from the payload, not the resolved partner.
        partner_code = partner_variant['partner']
        result[partner_code].append(variant_schema.load(partner_variant).data)
    return result


def collect_stations(flights):
    """Fetch every ``Station`` referenced by the given flights via ``in_bulk``.

    :param flights: iterable of flight dicts. The ``station_from``,
        ``station_to``, ``first_station`` and ``last_station`` keys are read
        and may be absent.
    :return: the ``Station.objects.in_bulk`` result, keyed by station id.
    """
    station_ids = {
        flight.get(fname)
        for flight in flights
        for fname in ('station_from', 'station_to', 'first_station', 'last_station')
    }
    # Absent keys yield None, which can never match a primary key; drop it so
    # it is not sent to the database (and an empty id list skips the query).
    station_ids.discard(None)
    return Station.objects.in_bulk(list(station_ids))


def _build_flights_by_key(flights):
    return {flight['key']: flight for flight in flights}
