# -*- coding: utf-8 -*-
import itertools
import logging

import ujson
from more_itertools import partition

from travel.avia.ticket_daemon.ticket_daemon.api.models_utils import get_companies_by_iata
from travel.avia.ticket_daemon.ticket_daemon.api.cache import shared_cache

# Lifetime of a cached min-price entry; passed as ``timeout`` to
# shared_cache.set(). 60 * 90 is 90 minutes, assuming the cache timeout is
# expressed in seconds (presumably; confirm against shared_cache).
MIN_PRICE_CACHE_TIME = 60 * 90
logger = logging.getLogger(__name__)


def cache_min_price_from_variants(query, variants):
    """Cache the cheapest economy-class price found for the query's city pair.

    Up to two cache entries are written: one for direct variants and one for
    variants with transfers. A group that has no economy variants is skipped.
    Nothing is cached when either query point has no related settlement.

    :param query: search query; ``point_from`` / ``point_to`` are resolved to
        settlements, which are required to build the cache key.
    :param variants: iterable of variants; each must expose ``klass``,
        ``is_direct`` and ``national_tariff``.
    """
    city_from = query.point_from.get_related_settlement()
    city_to = query.point_to.get_related_settlement()

    if not city_from or not city_to:
        logger.warning('No related settlements. Do not cache: %r %r. %s', city_from, city_to, query.id)
        return

    # Only economy class is cached; business class and others are skipped.
    economy_variants = [v for v in variants if v.klass == 'economy']

    # more_itertools.partition yields (predicate is false, predicate is true).
    indirects, directs = partition(lambda variant: variant.is_direct, economy_variants)

    # Two prices must be cached: without transfers and with them.
    for are_direct, group in ((True, directs), (False, indirects)):
        group = list(group)
        if not group:
            continue

        min_tariff = min(v.national_tariff for v in group)
        # Keep every variant that shares the minimal price, so all of their
        # routes and companies end up in the cached entry.
        min_price_variants = [
            v for v in group
            if (v.national_tariff.currency == min_tariff.currency and
                v.national_tariff.value == min_tariff.value)
        ]

        _cache_variants_min_price(
            query, min_price_variants,
            city_from, city_to,
            tariff=min_tariff,
            are_direct=are_direct
        )


def _serialize_routes(variants, only_forward):
    forward_routes = u';'.join(
        u','.join(
            s.number for s in v.forward.segments
        )
        for v in variants
    )

    if only_forward:
        return forward_routes

    backward_routes = u';'.join(
        u','.join(
            s.number for s in v.backward.segments
        )
        for v in variants
    )

    return u'{}/{}'.format(forward_routes, backward_routes)


def _get_iata_codes(variants, only_forward):
    iata_codes = set()

    for v in variants:
        segment_groups = [v.forward.segments]

        if not only_forward:
            segment_groups.append(v.backward.segments)

        for segments in segment_groups:
            for s in segments:
                iata = s.number.split()[0]
                iata_codes.add(iata)

    return iata_codes - {None}


def _cache_variants_min_price(query, variants, city_from, city_to, tariff, are_direct):
    """Store a min price together with its routes and carriers in shared_cache.

    The cached value is a JSON object with ``price``, ``currency``, ``routes``
    (see ``_serialize_routes``) and ``companies`` (deduplicated company ids).

    :param query: search query; supplies dates, national version and passengers.
    :param variants: variants to describe; the caller in this module passes the
        variants that share the minimal ``tariff``.
    :param city_from: departure settlement; its ``point_key`` goes into the key.
    :param city_to: arrival settlement; its ``point_key`` goes into the key.
    :param tariff: min tariff; its ``value`` and ``currency`` are cached.
    :param are_direct: truthy for the direct-flights entry, falsy for transfers.
    """
    # Without a backward date the query is one way, so only forward legs count.
    are_one_way = not query.date_backward
    routes = _serialize_routes(variants, only_forward=are_one_way)
    iata_codes = _get_iata_codes(variants, only_forward=are_one_way)

    # Deduplicate company ids across all carriers seen in the routes.
    # A generator expression replaces Python-2-only itertools.imap with
    # identical behaviour.
    companies = list({
        company.id
        for company in itertools.chain.from_iterable(
            get_companies_by_iata(iata) for iata in iata_codes
        )
    })
    data = {
        'price': tariff.value,
        'currency': tariff.currency,
        'routes': routes,
        'companies': companies,
    }

    key = key_by_params(
        city_from_key=city_from.point_key,
        city_to_key=city_to.point_key,
        national_version=query.national_version,
        passengers_key=passengers_key_from_dict(query.passengers),
        date_forward=query.date_forward,
        date_backward=query.date_backward,
        is_direct=are_direct
    )

    shared_cache.set(key, ujson.dumps(data), timeout=MIN_PRICE_CACHE_TIME)

    logger.info(u'Cache min price[%s]: %r', key, data)


def passengers_key_from_dict(passengers):
    """Encode passenger counts as ``"<adults>_<children>_<infants>"``.

    Categories missing from ``passengers`` are encoded as 0.
    """
    counts = tuple(
        passengers.get(category, 0)
        for category in ('adults', 'children', 'infants')
    )
    return '%s_%s_%s' % counts


def key_by_params(
    city_from_key,
    city_to_key,
    national_version,
    passengers_key,
    date_forward,
    date_backward,
    is_direct,
):
    """Build the shared-cache key for a min-price entry.

    Dates are rendered as ``YYYY-MM-DD``. A missing (falsy) backward date is
    rendered as ``None``. The last component is ``direct`` or ``transfers``,
    depending on the truthiness of ``is_direct``.
    """
    date_format = '%Y-%m-%d'
    forward_part = date_forward.strftime(date_format)

    backward_part = None
    if date_backward:
        backward_part = date_backward.strftime(date_format)

    route_kind = 'transfers'
    if is_direct:
        route_kind = 'direct'

    parts = (
        city_from_key, city_to_key, national_version, passengers_key,
        forward_part, backward_part, route_kind,
    )
    return 'min_price_json_%s_%s_%s_%s_%s_%s_%s' % parts
