# -*- coding: utf-8 -*-
import io
import itertools
from datetime import datetime
from typing import Any, List

import requests
from django.template import loader
from lxml import etree, objectify

from travel.avia.ticket_daemon.ticket_daemon.daemon.utils import sleep_every
from travel.avia.ticket_daemon.ticket_daemon.lib.baggage import Baggage
from travel.avia.ticket_daemon.ticket_daemon.lib.currency import Price
from travel.avia.ticket_daemon.ticket_daemon.api.flights import Variant
from travel.avia.ticket_daemon.ticket_daemon.lib.decorators import pipe
from travel.avia.ticket_daemon.ticket_daemon.lib.http import url_complement_missing


# Service class identifiers compared against ``query.klass`` / ``Variant.klass``;
# SigImporter.KLASS_MAP maps them to the cabin names used in partner XML.
ECONOMY_CLASS = 'economy'
BUSINESS_CLASS = 'business'


class SigImporter(object):
    """Importer for the partner search API driven by the ``partners/sig.xml`` template.

    A search is rendered to XML, POSTed to ``search_url`` with HTTP basic auth,
    and the XML response is converted into ``Variant`` objects, one per
    ``ShopOption`` element.
    """

    # Internal service class -> cabin name used in the request (``klass``
    # template parameter) and compared with ``Reservation/@Cabin`` in responses.
    KLASS_MAP = {
        ECONOMY_CLASS: u'Economy',
        BUSINESS_CLASS: u'Business',
    }

    # Layout of the leading 19 characters of ``Departure/@Time`` / ``Arrival/@Time``.
    _SEGMENT_TIME_FORMAT = '%Y-%m-%dT%H:%M:%S'

    def __init__(
        self,
        customer_id,
        search_url,
        api_password,
        utm,
        logger,
        tariff_codes_from_deeplink_extractor=None,
        tariff_codes_from_fare_info_extractor=None,
    ):
        """Store connection settings and optional fare-code extractors.

        :param customer_id: sent as ``CustomerID`` in the request and used as
            the basic-auth user name.
        :param search_url: URL the search request is POSTed to.
        :param api_password: basic-auth password.
        :param utm: passed to ``url_complement_missing`` together with the deeplink.
        :param logger: logger used for warnings/errors and handed to the extractors.
        :param tariff_codes_from_deeplink_extractor: optional callable
            ``(deeplink, logger) -> (forward_codes, backward_codes)``; used
            only when no fare-info extractor is configured.
        :param tariff_codes_from_fare_info_extractor: optional callable
            ``(shop_option, logger) -> dict`` whose keys are matched against
            ``Reservation/@ReservationRef``; takes precedence over the
            deeplink extractor.
        """
        self._customer_id = customer_id
        self._search_url = search_url
        self._api_password = api_password
        self._logger = logger
        self._utm = utm
        self._tariff_codes_from_deeplink_extractor = tariff_codes_from_deeplink_extractor
        self._tariff_codes_from_fare_info_extractor = tariff_codes_from_fare_info_extractor

    def query(self, tracker, q, variants_filter=None):
        # type: (Any, Any, Any)->List[Variant]
        """Run the search for ``q`` and return the parsed variants as a list.

        :param tracker: object whose ``wrap_request`` performs the HTTP call.
        :param q: search query; ``q.importer.flight_fabric`` builds the flights.
        :param variants_filter: optional predicate; variants for which it
            returns a falsy value are dropped.
        """
        xml = self._get_data(tracker, q)

        return list(self._parse(xml, q, flight_fabric=q.importer.flight_fabric, variants_filter=variants_filter))

    @staticmethod
    def validate_query(q):
        """Ask the query to validate its class against the supported ones."""
        q.validate_klass(SigImporter.KLASS_MAP.keys())

    def _get_data(self, tracker, query):
        """Render the request XML, POST it and return the raw response body."""
        query_xml = loader.render_to_string(
            'partners/sig.xml', self._build_params(query)
        )

        # NOTE(review): ``verify=False`` disables TLS certificate verification
        # on a request that carries basic-auth credentials. It is left as is
        # because the partner endpoint may rely on it -- confirm and remove.
        r = tracker.wrap_request(
            requests.post,
            self._search_url,
            headers={'Content-Type': 'text/xml; charset=utf-8', },
            data=query_xml.encode('utf-8'),
            auth=requests.auth.HTTPBasicAuth(self._customer_id, self._api_password),
            verify=False,
        )

        return r.content

    def _build_params(self, query):
        """Build the template context for ``partners/sig.xml``.

        ``return_date`` is present only when the query has a backward date.
        """
        params = {
            'CustomerID': self._customer_id,
            'from_iata': query.iata_from,
            'to_iata': query.iata_to,
            'forward_date': query.date_forward.strftime('%Y-%m-%d'),
            'adult_count': query.passengers.get('adults', 0),
            'child_count': query.passengers.get('children', 0),
            'infant_count': query.passengers.get('infants', 0),
            'klass': self.KLASS_MAP.get(query.klass)
        }

        if query.date_backward:
            params['return_date'] = query.date_backward.strftime('%Y-%m-%d')

        return params

    def _parse(self, xml, q, flight_fabric, variants_filter):
        """Yield one ``Variant`` per ``ShopOption`` element of the response.

        :param xml: raw response body (bytes).
        :param q: search query; its ``klass`` is the variant's initial class.
        :param flight_fabric: factory passed through to ``parse_flights``.
        :param variants_filter: optional predicate; falsy result skips the variant.
        """
        tree = self.deannotated_tree_fromstring(xml)
        business_cabin = self.KLASS_MAP[BUSINESS_CLASS]

        for shop_option in sleep_every(tree.xpath('//ShopOption')):
            v = Variant()

            deeplink = shop_option.xpath('DeepLinks/DeepLink')[0].text
            tariff_codes_forward = []
            tariff_codes_backward = []
            tariff_codes_dict = {}

            # The fare-info extractor wins; the deeplink one is the fallback.
            if self._tariff_codes_from_fare_info_extractor:
                tariff_codes_dict = self._tariff_codes_from_fare_info_extractor(shop_option, self._logger)
            elif self._tariff_codes_from_deeplink_extractor:
                tariff_codes_forward, tariff_codes_backward = self._tariff_codes_from_deeplink_extractor(
                    deeplink,
                    self._logger,
                )

            v.url = url_complement_missing(deeplink, self._utm)
            v.klass = q.klass
            v.order_data = {'url': v.url}

            # Initialised once per shop option: the flag must cover every
            # itinerary of this option, must not be reset by a later
            # itinerary, and must not leak in from the previous shop option
            # when this one has no itineraries.
            is_business_cabin = False

            for itinerary_option in shop_option.xpath('ItineraryOptions/ItineraryOption'):
                od_ref = itinerary_option.attrib['ODRef']
                if od_ref == 'forward':
                    v.forward.segments = self.parse_flights(
                        flight_fabric,
                        itinerary_option,
                        tariff_codes_forward,
                        tariff_codes_dict,
                    )
                elif od_ref == 'backward':
                    v.backward.segments = self.parse_flights(
                        flight_fabric,
                        itinerary_option,
                        tariff_codes_backward,
                        tariff_codes_dict,
                    )

                for reservation in itinerary_option.xpath('FlightSegment/ReservationDetails/Reservation'):
                    if reservation.attrib['Cabin'] == business_cabin:
                        is_business_cabin = True

            # An economy search may return business reservations; report the
            # variant as business in that case.
            if v.klass == ECONOMY_CLASS and is_business_cabin:
                v.klass = BUSINESS_CLASS

            v.tariff = Price(
                float(shop_option.attrib['Total']),
                shop_option.attrib['Currency'].replace('RUB', 'RUR')
            )

            if variants_filter and not variants_filter(v):
                continue

            yield v

    @pipe(list)
    def parse_flights(self, flight_fabric, itinerary, tariff_codes, tariff_codes_dict):
        """Yield a flight for every ``FlightSegment`` child of ``itinerary``.

        NOTE(review): ``pipe(list)`` presumably turns the generator into a
        list -- confirm against ``lib.decorators``.

        :param flight_fabric: factory whose ``create`` builds a flight.
        :param itinerary: ``ItineraryOption`` element.
        :param tariff_codes: fare codes paired positionally with the segments.
        :param tariff_codes_dict: fare codes keyed by ``ReservationRef``; used
            for segments that got no positional fare code.
        """
        for flight_segment, fare_code in itertools.izip_longest(itinerary.xpath('FlightSegment'), tariff_codes):
            # izip_longest pads with None, so a None segment means there are
            # more tariff codes than segments. Childless segments are skipped
            # too, exactly as the former element truth-test did, but without
            # relying on lxml's deprecated element truthiness.
            if flight_segment is None or len(flight_segment) == 0:
                continue
            departure = flight_segment.xpath('Departure')[0]
            arrival = flight_segment.xpath('Arrival')[0]

            if fare_code is None:
                # Only the first ReservationDetails child is consulted; the
                # path yields an empty list (instead of raising IndexError)
                # when the segment has no ReservationDetails at all.
                for reservation in flight_segment.xpath('ReservationDetails[1]/Reservation'):
                    fare_code = tariff_codes_dict.get(reservation.get('ReservationRef'))
                    if fare_code is not None:
                        break

            yield flight_fabric.create(
                station_from_iata=departure.attrib['Airport'],
                station_to_iata=arrival.attrib['Airport'],
                # Only the first 19 characters (date and time) are parsed;
                # whatever follows them in the attribute is ignored.
                local_departure=datetime.strptime(
                    departure.attrib['Time'][:19],
                    self._SEGMENT_TIME_FORMAT
                ),
                local_arrival=datetime.strptime(
                    arrival.attrib['Time'][:19],
                    self._SEGMENT_TIME_FORMAT
                ),
                company_iata=flight_segment.attrib['Airline'],
                pure_number=flight_segment.attrib['Flight'],
                baggage=self._get_baggage(flight_segment),
                fare_code=fare_code,
            )

    def _get_baggage(self, segment):
        """Build baggage info from the segment's first ``BagDetails/BagDetail``.

        Returns ``None`` when the segment has no such element, a pieces- or
        weight-based ``Baggage`` for ``BagType`` ``PC`` / ``KG``, and
        ``Baggage.from_partner()`` without arguments for an unknown type or
        when parsing fails. Never raises.
        """
        try:
            bag_detail = segment.find('BagDetails/BagDetail')
            if bag_detail is None:
                return None
            bag_type = bag_detail.get('BagType')
            bag_allowance = bag_detail.get('BagAllowance')
            if bag_type == 'PC':
                return Baggage.from_partner(pieces=int(bag_allowance))
            elif bag_type == 'KG':
                return Baggage.from_partner(weight=int(bag_allowance))
            else:
                self._logger.warning('Unknown BagType: %s. BagAllowance: %s', bag_type, bag_allowance)
                return Baggage.from_partner()
        except Exception:
            # Baggage is best-effort, so the failure is swallowed on purpose;
            # keep the traceback in the log so the cause stays diagnosable.
            self._logger.error('Baggage parsing exception', exc_info=True)
            return Baggage.from_partner()

    @staticmethod
    def deannotated_tree_fromstring(xml):
        """Parse ``xml`` (bytes) and strip namespaces from every element tag.

        Returns the lxml element tree, so that plain XPath expressions such
        as ``//ShopOption`` match regardless of the namespaces in the document.
        """
        parser = etree.XMLParser(remove_blank_text=True)
        tree = etree.parse(io.BytesIO(xml), parser)
        root = tree.getroot()

        for elem in root.iter():
            # Comments and processing instructions have a non-string tag.
            if not hasattr(elem.tag, 'find'):
                continue

            # Tags of namespaced elements look like ``{uri}local``.
            i = elem.tag.find('}')

            if i >= 0:
                elem.tag = elem.tag[i + 1:]

        objectify.deannotate(root, cleanup_namespaces=True)

        return tree
