# -*- coding: utf-8 -*-
from __future__ import unicode_literals

import logging
from collections import defaultdict
from contextlib import closing
from itertools import product
from typing import Any, Callable, Dict, Iterable, List

from django.template import loader
from lxml import etree

from travel.avia.library.python.sirena_client import SirenaClient
from travel.avia.ticket_daemon.ticket_daemon.api.flights import Variant, Flight
from travel.avia.ticket_daemon.ticket_daemon.api.query import Query
from travel.avia.ticket_daemon.ticket_daemon.daemon.utils import sleep_every
from travel.avia.ticket_daemon.ticket_daemon.lib.sirena.variant_models import (
    SirenaFlight, SirenaVariant, SirenaJoint, SirenaVariantBuilder,
)

# Module logger: used for debug dumps of Sirena responses, skipped-variant
# notices, segment/joint warnings and per-stage variant counts.
logger = logging.getLogger(__name__)


class SirenaFetcher(object):
    """Entry point for fetching Sirena offers and company routes.

    Each public operation obtains a fresh client from ``sirena_client_factory``.
    """

    def __init__(self, sirena_client_factory):
        # type: (Callable[[],SirenaClient])->None
        """
        :param sirena_client_factory: zero-argument callable returning a SirenaClient.
         Its ``get`` method is used to request pricing_route / mono_brands data.
        """
        self.sirena_client_factory = sirena_client_factory

    def get_variants(self, query, company_iata, deeplink, on_finish=None):
        # type: (Query, str, Callable, Callable)->Iterable[Iterable[Variant]]
        """Yield one list holding every ticket-daemon variant for ``query``.

        Pipeline: pricing_route answer -> branded variants only -> deduplicated ->
        converted to ``Variant``. Each stage is wrapped in a ``Counter`` whose total
        is logged when the ``with`` block is left.
        """
        tracker = DuplicateTracker()
        total_counter = Counter('pricing route')
        branded_counter = Counter('branded pricing route')
        unique_counter = Counter('deduplicated pricing route')
        with closing(total_counter), closing(branded_counter), closing(unique_counter):
            raw_variants = total_counter(self.sirena_price_route_variants(query, company_iata))
            with_brand = branded_counter(filter_branded(raw_variants))
            unique_variants = unique_counter(tracker.sirena_deduplicate(with_brand))
            converted = sirena_to_ticket_daemon_variants(
                q=query,
                variants=unique_variants,
                deeplink=deeplink,
                on_finish=on_finish,
            )
            # Fully consume the pipeline here so every counter is final before closing.
            yield list(converted)

    def sirena_price_route_variants(self, q, company_iata):
        """Lazily yield SirenaVariant objects; the client is created on first iteration."""
        route_fetcher = SirenaPricingRoute(self.sirena_client_factory())
        for sirena_variant in route_fetcher.pricing_route(q, company_iata):
            yield sirena_variant

    def get_company_routes(self, q, company_code):
        # type: (Query, str)->Dict[str, List[str]]
        """Return the routes of ``company_code`` as reported by Sirena."""
        return SirenaCompanyRoute(self.sirena_client_factory()).get(q, company_code)


class SirenaCompanyRoute(object):
    """Thin adapter that asks a Sirena client for the routes of one airline."""

    def __init__(self, client):
        # type: (SirenaClient)->None
        self.client = client

    def get(self, q, company_code):
        """Return ``client.get_company_routes(company_code, q['lang'])``."""
        fetch_routes = self.client.get_company_routes
        return fetch_routes(company_code, q['lang'])


class SirenaPricingRoute(object):
    """Request the pricing_route_mono_brand_cartesian answer from Sirena and parse it."""

    # Third argument passed to ``client.get``; presumably seconds — TODO confirm.
    PRICING_ROUTE_TIMEOUT = 10
    # Query service class -> Sirena class code sent in the request.
    KLASS_MAP = {'economy': u'Y', 'business': u'C', 'first': u'F'}

    def __init__(self, client):
        # type: (SirenaClient)->None
        self.client = client

    def pricing_route(self, q, company_code):
        # type: (Query, str)->Iterable[SirenaVariant]
        """Send a pricing_route request for ``q`` and return the parsed variants.

        The request is sent immediately; the returned iterable is lazy.
        """
        params = self._build_aviasearch_params(q, company_code)
        xml = _build_xml('partners/sirena_pricing_route.xml', params)
        response = self.client.get(xml, 0, SirenaPricingRoute.PRICING_ROUTE_TIMEOUT)
        return self._parse_sirena_variants(response)

    @staticmethod
    def _build_aviasearch_params(q, company_iata):
        """Build the template context for the pricing_route request.

        Dates are formatted as ``dd-mm-yy``; ``return_date`` is ``None`` for
        one-way queries. Missing passenger categories default to 0.

        :raises KeyError: if ``q.klass`` is not in ``KLASS_MAP``.
        """
        return {
            'forward_date': q.date_forward.strftime('%d-%m-%y'),
            'return_date': q.date_backward.strftime('%d-%m-%y') if q.date_backward else None,
            'iata_from': q.iata_from,
            'iata_to': q.iata_to,
            'passengersAdult': q.passengers.get('adults', 0),
            'passengersChild': q.passengers.get('children', 0),
            'passengersInfant': q.passengers.get('infants', 0),
            'class': SirenaPricingRoute.KLASS_MAP[q.klass],
            'company_iata': company_iata,
        }

    @staticmethod
    def _parse_sirena_variants(response):
        # type: (bytes)->Iterable[SirenaVariant]
        """Parse the raw XML answer into SirenaVariant objects.

        :param response: raw response body as returned by ``client.get``.
        """
        # Decoding a large answer is only worth it when it will actually be logged.
        # 'replace' keeps a non-UTF-8 payload from breaking parsing that lxml could
        # otherwise handle.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(response.decode('utf-8', 'replace'))
        tree = etree.fromstring(response)
        return _sirena_variants(tree.xpath('./answer/pricing_route_mono_brand_cartesian/variant'))


def parse_variants_brand(segments):
    """Return the first non-empty ``brand`` attribute of the segments' ``<price>``, or None.

    Every segment is expected to have a ``<price>`` child.
    """
    brands = (segment.find('price').get('brand') for segment in segments)
    return next((brand for brand in brands if brand), None)


def filter_branded(pricing_route_variants):
    # type: (Iterable[SirenaVariant])->Iterable[SirenaVariant]
    """Lazily pass through only variants with a truthy ``brand``; log the rest."""
    for candidate in pricing_route_variants:  # type: SirenaVariant
        if candidate.brand:
            yield candidate
        else:
            logger.info('Skipping variant without brand: %s', candidate)


def _build_xml(path, params):
    """Render the Django template at ``path`` with ``params`` and return the request text."""
    return loader.render_to_string(path, params)


def _sirena_variants(variants_tree):
    # type: (Iterable[etree._Element])->Iterable[SirenaVariant]
    """Combine raw ``<variant>`` elements into branded ``SirenaVariant`` objects.

    Elements are grouped by brand and then by joint id (the ``iSegmentNum`` of
    the variant's first flight). The result is a structure like this (keys are
    illustrative)::

        {
            'ECONOM': {
                'RTW-SVX': [variant1, variant2],
                'SVX-RTW': [variant3, variant4],
            }
        }

    For each brand, the cartesian product of its per-joint lists is taken
    (``[variant1, variant2] x [variant3, variant4]``). Every combination whose
    flights are in chronological order is turned into a variant by
    ``_build_variant``.

    :param variants_tree: ``<variant>`` elements of the Sirena answer.
    :return: a generator of ``SirenaVariant`` objects.
    :raises ValueError: propagated from ``_build_variant``.
    """
    sirena_variants = defaultdict(lambda: defaultdict(list))
    # Elements are visited in departure order, so each per-joint list is chronological.
    for variant in _sort_variants_by_departure(variants_tree):
        # A brand on any flight's price wins; otherwise fall back to requested_brands.
        brand = parse_variants_brand(variant.xpath('flight')) or variant.get('requested_brands')
        joint_id = _get_joint_id(variant)
        sirena_variants[brand][joint_id].append(variant)

    # The joint count of the "fullest" brand is taken as the number of segments the
    # query needs (one-way vs round-trip). In theory this can be wrong, for instance
    # if every segment is assigned a different brand name. Yet it is simple, needs no
    # extra parameter, and an error here most likely implies havoc in the partner's
    # response anyway.
    if sirena_variants:
        expected_number_of_segments = max(len(joints) for joints in sirena_variants.values())
    else:
        expected_number_of_segments = 0

    for brand, brand_variants in sirena_variants.items():
        # Every product tuple has exactly one element per joint, so the segment
        # count is checked once per brand (one warning, no wasted product iteration).
        if len(brand_variants) != expected_number_of_segments:
            logger.warning(
                'Unexpected number of segments ({} instead of {}) for the brand {}'.format(
                    len(brand_variants),
                    expected_number_of_segments,
                    brand,
                ),
            )
            continue
        for variants_tuple in product(*brand_variants.values()):
            variants = list(variants_tuple)
            # Verify that variants[i] lands before variants[i + 1] takes off,
            # otherwise we may do the wrong thing for same-day round-trip searches.
            if not _valid_flights_order(variants):
                continue
            yield _build_variant(variants, brand)


def _valid_flights_order(variants):
    for index in range(len(variants)-1):
        if not _flight_precedes(variants[index], variants[index+1]):
            return False
    return True


def _flight_precedes(variant1, variant2):
    """Tell whether ``variant1`` lands strictly before ``variant2`` takes off.

    Returns a falsy value if either timestamp is missing. Timestamps are compared
    as strings built by ``_get_arrival`` / ``_get_departure``.
    """
    landed_at = _get_arrival(variant1)
    takes_off_at = _get_departure(variant2)
    if not landed_at:
        return landed_at
    if not takes_off_at:
        return takes_off_at
    return takes_off_at > landed_at


def _get_departure(variant):
    for flight in variant.xpath('flight'):
        return '{} {}'.format(_reverse_date(flight.findtext('deptdate')), flight.findtext('depttime'))
    return None


def _get_arrival(variant):
    for flight in variant.xpath('flight[last()]'):
        return '{} {}'.format(_reverse_date(flight.findtext('arrvdate')), flight.findtext('arrvtime'))
    return None


def _get_joint_id(variant):
    for flight in variant.xpath('flight'):
        return flight.get('iSegmentNum')
    return None


def _reverse_date(str_date):
    if not str_date:
        return str_date
    return '.'.join(str_date.split('.')[::-1])


def _calculate_price(variants):
    price = 0.
    currency = None
    for variant in variants:
        variant_price = variant.find('variant_total')
        variant_currency = variant_price.get('currency')
        if not variant_currency:
            raise ValueError('Unknown currency in the response from Sirena')
        if not currency:
            currency = variant_currency
        elif currency != variant_currency:
            raise ValueError('Multilple currencies in the response from Sirena: {}, {}'.format(currency, variant_currency))
        price += float(variant_price.text)
    return price, currency


def _build_variant(variants, brand):
    """Assemble one SirenaVariant from a combination of ``<variant>`` elements.

    The price and currency come from ``_calculate_price``. Every ``<flight>`` of
    every element is added to the builder under its ``iSegmentNum`` joint id.

    :raises ValueError: if a flight lacks ``iSegmentNum`` (or from ``_calculate_price``).
    """
    total, currency = _calculate_price(variants)
    builder = SirenaVariantBuilder(total, currency, brand)
    all_segments = (segment for variant in variants for segment in variant.xpath('flight'))
    for segment in all_segments:
        joint_id = segment.get('iSegmentNum')
        if not joint_id:
            raise ValueError('Branded segment does not have iSegmentNum')
        price_node = segment.find('price')
        flight = SirenaFlight(
            company=segment.findtext('company'),
            num=segment.findtext('num'),
            origin=segment.findtext('origin'),
            destination=segment.findtext('destination'),
            departure_date=segment.findtext('deptdate'),
            departure_time=segment.findtext('depttime'),
            arrival_date=segment.findtext('arrvdate'),
            arrival_time=segment.findtext('arrvtime'),
            class_=segment.find('subclass').get('baseclass'),
            subclass=segment.findtext('subclass'),
            baggage=price_node.get('baggage'),
            fare_code=price_node.find('fare').get('base_code'),
        )
        builder.add_flight(joint_id, flight)
    return builder.build()


def _sort_variants_by_departure(variants_tree):
    # type: (etree)->list[SirenaVariant]
    result = [variant for variant in variants_tree]
    result.sort(key=_get_departure)
    return result


def sirena_joint_to_segments(flight_fabric, joint):
    # type: (Any, SirenaJoint)->Iterable[Flight]
    """Lazily create one ticket-daemon flight per Sirena flight of ``joint``.

    Each flight is produced by ``flight_fabric.create`` with keyword arguments
    mapped from the Sirena flight's attributes.
    """
    for sirena_flight in joint.flights:
        created = flight_fabric.create(
            company_iata=sirena_flight.company,
            pure_number=sirena_flight.num,
            station_from_iata=sirena_flight.origin,
            station_to_iata=sirena_flight.destination,
            local_departure=sirena_flight.departure,
            local_arrival=sirena_flight.arrival,
            fare_code=sirena_flight.fare_code,
            baggage=sirena_flight.baggage,
        )
        yield created


def sirena_to_ticket_daemon_variants(q, variants, deeplink, on_finish=None):
    # type: (Any, Iterable[SirenaVariant], Callable, Callable)->Iterable[Variant]
    """Lazily convert SirenaVariant objects into ticket-daemon ``Variant`` objects.

    The input is iterated through ``sleep_every``. The first joint becomes the
    forward leg. When ``q.date_backward`` is set, the second joint becomes the
    backward leg; variants with fewer than two joints are then logged and
    skipped. ``order_data`` comes from ``deeplink(q, variant)``.

    :param on_finish: optional zero-argument callback invoked once the input is
        exhausted (not if the consumer stops early or an exception propagates).
    """
    for sirena_variant in sleep_every(variants):
        ticket_variant = Variant()
        ticket_variant.klass = q.klass
        ticket_variant.tariff = sirena_variant.price
        ticket_variant.forward.segments = list(
            sirena_joint_to_segments(q.importer.flight_fabric, sirena_variant.joints[0]),
        )

        if q.date_backward:
            joints = sirena_variant.joints
            if len(joints) < 2:
                logger.warning(
                    'Expected at least 2 joints for RT, but got %d: %s',
                    len(joints),
                    q.qkey,
                )
                continue
            ticket_variant.backward.segments = list(
                sirena_joint_to_segments(q.importer.flight_fabric, joints[1]),
            )

        ticket_variant.order_data = deeplink(q, sirena_variant)
        yield ticket_variant

    if on_finish:
        on_finish()


class DuplicateTracker(object):
    """
    Originally written to deduplicate variants in the results of the pricing_mono_brand request.
    https://wiki.sirena-travel.ru/xmlgate:14fare_brands_main:03fb_pricing_mono_brand
    In some cases, some of the returned pricing variants can be almost identical and differ
    only in the value of the requested_brands attribute (this concerns the pricing_mono_brand answer).
    Such itineraries should not be shown to the passenger twice; instead they should be
    filtered out while processing the XML gateway response.

    Just in case, results of the pricing_route_mono_brand_cartesian request also pass through
    this filter now.
    TODO(u-jeen): Possibly this should be removed.

    Equality is decided by the variants' own ``__eq__``/``__hash__``. The set of seen
    variants lives on the instance, so it persists across ``sirena_deduplicate`` calls.
    """

    def __init__(self):
        self.cache = set()  # variants already yielded by this tracker

    def sirena_deduplicate(self, sirena_variants):
        # type: (Iterable[SirenaVariant])->Iterable[SirenaVariant]
        """Lazily yield only variants not seen before by this tracker."""
        for candidate in sirena_variants:
            if candidate in self.cache:
                continue
            self.cache.add(candidate)
            yield candidate


class Counter(object):
    """Pass-through wrapper that counts the items flowing through iterables.

    ``close()`` logs the total, so instances can be used with ``contextlib.closing``.
    The count accumulates across every iterable wrapped by the same instance.
    """

    def __init__(self, name):
        self.name = name  # label used in the log message
        self.i = 0  # number of items yielded so far

    def __call__(self, iterable):
        # type:(Iterable)->Iterable
        """Lazily re-yield every item of ``iterable``, incrementing ``self.i`` first."""
        for element in iterable:
            self.i = self.i + 1
            yield element

    def close(self):
        """Log how many items have passed through this counter."""
        logger.info('Number of %s variants: %d', self.name, self.i)
