# -*- coding: utf-8 -*-
import logging
from itertools import chain

from marshmallow import Schema, fields, post_load

from travel.avia.ticket_daemon.ticket_daemon.api.min_price_cacher import cache_min_price_from_variants
from travel.avia.ticket_daemon.ticket_daemon.lib.currency import Price
from travel.avia.ticket_daemon.ticket_daemon.lib.yt_loggers.min_price_logger import min_price_logger
from travel.avia.ticket_daemon.ticket_daemon.lib.utils import wrap

# Logger named after this module; used by the grouping/serialization code below.
log = logging.getLogger(name=__name__)


class PriceSchema(Schema):
    """Marshmallow schema for a price: ``value`` (float) and ``currency`` (string).

    Loading yields a ``Price`` object instead of a plain dict.
    """

    value = fields.Float()
    currency = fields.String()

    @post_load
    def make_object(self, data):
        # Pull both loaded fields out first, then build the domain object
        # positionally as Price(value, currency).
        amount, currency_code = data['value'], data['currency']
        return Price(amount, currency_code)


class PartnerSchema(Schema):
    """Marshmallow schema for the partner of a variant (only its ``code``)."""

    # Partner code; MinPriceVariantsSameStops joins these codes in its log line.
    code = fields.String()


class FlightSchema(Schema):
    """Marshmallow schema for one flight segment of a variant."""

    # Flight number, kept as a string.
    number = fields.String()
    # (De)serialized with the date-only format '%Y-%m-%d', so any time-of-day
    # component is not written out on dump.
    departure = fields.DateTime(format='%Y-%m-%d')
    # Presumably IATA codes of the departure/arrival stations (per field names) — TODO confirm.
    station_from_iata = fields.String()
    station_to_iata = fields.String()


class FlightsSchema(Schema):
    """Marshmallow schema for one direction (forward or backward) of a variant."""

    # Ordered list of segments; MinPriceVariantsSameStops.variant_stops_key
    # derives the stops key from len(segments).
    segments = fields.List(fields.Nested(FlightSchema))


class MinPriceVariantSchema(Schema):
    """Marshmallow schema for a single min-price variant.

    It has no post_load hook, so this schema by itself loads a variant as a
    dict. NOTE(review): attribute-style access such as ``v.forward.segments``
    presumably relies on ``wrap`` in MinPriceVariantsSameStopsSchema being
    recursive — confirm.
    """

    partner = fields.Nested(PartnerSchema)
    # Outbound and return legs.
    forward = fields.Nested(FlightsSchema)
    backward = fields.Nested(FlightsSchema)
    # Both tariffs load as Price objects via PriceSchema's post_load hook.
    national_tariff = fields.Nested(PriceSchema)
    tariff = fields.Nested(PriceSchema)
    # Service class; explicitly allowed to be None.
    klass = fields.String(allow_none=True)
    # Passed through as-is, without validation or conversion.
    raw_tariffs = fields.Raw()
    created_dt = fields.DateTime()
    expire_dt = fields.DateTime()


class MinPriceVariantsSameStopsSchema(Schema):
    """Marshmallow schema for a group of cheapest variants sharing a stops key.

    Loading returns ``wrap(...)`` of the loaded mapping, not a
    ``MinPriceVariantsSameStops`` instance.
    """

    variants = fields.Nested(MinPriceVariantSchema, many=True)
    stops = fields.String()
    national_tariff = fields.Nested(PriceSchema)
    tariff = fields.Nested(PriceSchema)

    @post_load
    def wrap_it(self, item):
        # Hand the loaded mapping to the project's ``wrap`` helper and return
        # whatever it produces.
        wrapped = wrap(item)
        return wrapped


class MinPriceVariantsSameStops(object):
    """The cheapest variants for one stops key.

    ``_gen_from_variants`` fills an instance with: the variants whose
    ``national_tariff`` equals the minimum for the key, the key itself, that
    minimum ``national_tariff``, and the ``tariff`` of the variant chosen by
    ``min``.
    """

    schema = MinPriceVariantsSameStopsSchema

    def __init__(self, variants, stops, national_tariff, tariff):
        self.variants, self.stops = variants, stops
        self.national_tariff, self.tariff = national_tariff, tariff

    def __repr__(self):
        template = '<{}: [{}] {} variants: {}>'
        return template.format(
            type(self).__name__,
            self.stops,
            self.national_tariff,
            len(self.variants),
        )

    @classmethod
    def variant_stops_key(cls, v):
        """Return the grouping key ``'<forward>_<backward>'`` for variant ``v``.

        Each half is ``min(1, number_of_segments - 1)``: 0 for one segment,
        1 for two or more, and -1 when the leg has no segments.
        """
        forward_part = min(1, len(v.forward.segments) - 1)
        backward_part = min(1, len(v.backward.segments) - 1)
        return '%d_%d' % (forward_part, backward_part)

    @classmethod
    def divide_variants(cls, variants, national_version, rates):
        """Split ``variants`` into per-stops-key groups of the cheapest ones.

        Variants whose ``national_tariff`` is falsy are dropped first.
        ``national_version`` and ``rates`` are used only in the log line.

        Returns:
            list of ``MinPriceVariantsSameStops``, one per stops key.
        """
        log.info('divide_variants')

        total_count = len(variants)
        # Skip variants that have no national tariff.
        priced = [variant for variant in variants if variant.national_tariff]
        used_count = len(priced)

        log.info(
            'divide_variants[%s] used:%d, dropped:%d, rates:%r',
            national_version, used_count, total_count - used_count, rates
        )

        # Materialise the generator into a list of MinPriceVariantsSameStops.
        return list(cls._gen_from_variants(priced))

    @classmethod
    def _gen_from_variants(cls, variants):
        """Lazily yield one instance per stops key, holding that key's cheapest variants."""
        groups = {}
        for variant in variants:
            key = cls.variant_stops_key(variant)
            bucket = groups.get(key)
            if bucket is None:
                bucket = groups[key] = []
            bucket.append(variant)

        for stops, candidates in groups.items():
            cheapest = min(candidates, key=lambda candidate: candidate.national_tariff)
            best_price = cheapest.national_tariff

            # Every variant that ties with the minimum price stays in the group.
            winners = [c for c in candidates if c.national_tariff == best_price]

            partner_codes = ','.join(sorted({c.partner.code for c in winners}))
            log.info(
                'gen_from_variants stops:%s min_price:%s variants:%d:%s',
                stops, best_price, len(winners), partner_codes
            )

            yield cls(winners, stops, best_price, cheapest.tariff)

    @classmethod
    def aggregate(cls, result_min_prices_groups, query):
        """Publish the overall min prices across ``result_min_prices_groups``.

            1) Store the min price in memcache for avia-api.
            2) Write it to a file to be shipped to YT rasp-min-price-log.
        """
        flat_variants = list(chain.from_iterable(
            group.variants for group in result_min_prices_groups
        ))

        cache_min_price_from_variants(query, flat_variants)

        regrouped = list(cls._gen_from_variants(flat_variants))
        min_price_logger.log(query, regrouped)


class PartnerMinPrices(object):
    """Holds one partner's min-price groups in ``material``.

    ``material`` is a list of same-stops groups:
    ``MinPriceVariantsSameStops`` instances when built by
    ``create_from_variants``, or whatever
    ``MinPriceVariantsSameStopsSchema().load(..., many=True).data`` returns
    when built by ``create_from_serialized``.
    """

    def __init__(self, material):
        self.material = material

    def __nonzero__(self):
        """Truthy iff ``material`` is truthy (i.e. a non-empty list)."""
        return bool(self.material)

    # Python 3 looks up ``__bool__`` and ignores ``__nonzero__``. Without this
    # alias every instance would be truthy there, even with empty material.
    __bool__ = __nonzero__

    @classmethod
    def aggregate(cls, min_prices_of_partners, query):
        """Flatten the groups of all partners and pass them, with ``query``,
        to ``MinPriceVariantsSameStops.aggregate``.
        """
        MinPriceVariantsSameStops.aggregate(
            list(chain.from_iterable([
                partner_min_prices.material
                for partner_min_prices in min_prices_of_partners
            ])),
            query
        )

    def to_dict(self):
        """Serialize to a plain dict that ``create_from_serialized`` can load back.

        Returns:
            dict with ``'__class_name__'`` and ``'material'`` (the groups
            dumped with ``MinPriceVariantsSameStopsSchema``, many=True).
        """
        log.info('PartnerMinPrices pack material %r', self.material)
        result = {
            '__class_name__': u'PartnerMinPrices',
            # TODO: __class_name__ is kept only for backward compatibility;
            # it can be removed once the rollout is complete.
            'material': MinPriceVariantsSameStopsSchema().dump(self.material, many=True).data
        }
        log.info('PartnerMinPrices pack material result %r', result)
        return result

    @classmethod
    def create_from_variants(cls, variants, national_version, rates):
        """Build from raw variants by grouping them with
        ``MinPriceVariantsSameStops.divide_variants``.
        """
        material = MinPriceVariantsSameStops.divide_variants(
            variants, national_version, rates
        )
        return cls(material)

    @classmethod
    def create_from_serialized(cls, partner_min_prices):
        """Build from a dict produced by ``to_dict``, loading its ``'material'``
        with ``MinPriceVariantsSameStopsSchema``.
        """
        material = MinPriceVariantsSameStopsSchema().load(partner_min_prices['material'], many=True).data
        return cls(material)
