# -*- coding: utf-8 -*-
import logging
from collections import namedtuple, defaultdict
from copy import deepcopy
from datetime import datetime
from typing import Generator

import requests
from django.template import loader
from lxml import etree

from travel.avia.ticket_daemon.ticket_daemon.api.flights import Variant, OperatingFlight
from travel.avia.ticket_daemon.ticket_daemon.api.query import Query
from travel.avia.ticket_daemon.ticket_daemon.daemon.utils import sleep_every, BadPartnerResponse
from travel.avia.ticket_daemon.ticket_daemon.settings.environment import env_variable_provider
from travel.avia.ticket_daemon.ticket_daemon.lib.currency import Price
from travel.avia.ticket_daemon.ticket_daemon.lib.partner_secret_storage import partner_secret_storage
from travel.avia.ticket_daemon.ticket_daemon.lib.tracker import QueryTracker
from travel.avia.ticket_daemon.ticket_daemon.settings import YANDEX_ENVIRONMENT_TYPE

logger = logging.getLogger(__name__)

# Endpoint, agency id, client secret and User-Agent depend on the deployment
# environment: production uses the live NDC search endpoint, every other
# environment uses the test endpoint. AEROFLOT_AGENCY_ID / AEROFLOT_USER_AGENT
# are optional overrides; CLIENT_ID is read from partner secret storage.
if YANDEX_ENVIRONMENT_TYPE == 'production':
    API_URL = 'https://ndc-search.aeroflot.ru/api/pr/ndc/v3.0/asrq'
    AGENCY_ID = env_variable_provider.get('AEROFLOT_AGENCY_ID', required=False, default='Yandex18')
    CLIENT_ID = partner_secret_storage.get(importer_name='aeroflot2_yandex18', namespace='PASSWORD')
    USER_AGENT = env_variable_provider.get('AEROFLOT_USER_AGENT', required=False, default='Yandex')
else:
    API_URL = 'https://ndc-search.test.aeroflot.ru/api/sb/ndc/v3.0/asrq'
    AGENCY_ID = env_variable_provider.get('AEROFLOT_AGENCY_ID', required=False, default='Yandex')
    CLIENT_ID = partner_secret_storage.get(importer_name='aeroflot2_yandex', namespace='PASSWORD')
    USER_AGENT = env_variable_provider.get('AEROFLOT_USER_AGENT', required=False, default='YandexTest')

# Query cabin class (q.klass) -> numeric cabin code sent to the partner, both as
# the 'klass' template parameter and as booking_info['CabinType'].
# Classes not listed here make KLASS_MAP[q.klass] raise KeyError.
KLASS_MAP = {
    'economy': 3,
    'business': 2,
    'comfort': 4,
}
# Query language / national version -> partner language and country codes.
# Lookups fall back to 'ru' / 'RU' for unmapped values.
# NOTE(review): national version 'com' maps to country 'DE'; the reason is not
# visible in this module — confirm with the partner integration docs.
LANG_MAP = {'ru': 'ru', 'ua': 'ua', 'kz': 'kz', 'tr': 'tr', 'en': 'en'}
COUNTRY_MAP = {'ru': 'RU', 'ua': 'UA', 'kz': 'KZ', 'tr': 'TR', 'com': 'DE'}

# Headers for every search request; X-IBM-Client-Id carries CLIENT_ID (a secret).
HEADERS = {
    'Content-Type': 'application/x-iata.ndc.v1+xml',
    'X-IBM-Client-Id': CLIENT_ID,
    'User-Agent': USER_AGENT,
}

# Format of ns:AircraftScheduledDateTime values in the response.
DATETIME_FORMAT = '%Y-%m-%dT%H:%M:%S'

# Marker generation: milliseconds elapsed since MARKER_START_DATETIME, written
# with the 36-character MARKER_ALPHABET (least significant digit first),
# left-padded with '0' to MARKER_RND_LENGTH and substituted into MARKER_FORMAT.
MARKER_ALPHABET = '0123456789QWERTYUIOPASDFGHJKLZXCVBNM'
MARKER_FORMAT = 'YA{rnd}'
MARKER_RND_LENGTH = 8
MARKER_START_DATETIME = datetime(2020, 12, 11)

# NOTE(review): not referenced anywhere in this module; presumably the partner's
# "empty result" warning type code — confirm before relying on it.
EMPTY_RESPONSE_WARNING_TYPE = '102'

# Namespace map; the 'ns' prefix is used by every XPath/find call in this module.
NSMAP = {'ns': 'http://www.iata.org/IATA/2015/00/2018.2/IATA_AirShoppingRS'}

# flight: keyword arguments for flight_fabric.create (everything except
# fare_code); is_backward_segment: True for return-direction segments.
FlightRef = namedtuple('FlightRef', ('flight', 'is_backward_segment'))
# One traveller in the search request: 1-based ordinal and type code (ADT/CHD/INF).
Passenger = namedtuple('Passenger', ('id', 'type'))


def remove_xml_formatting(xml):
    """Collapse pretty-printed XML into a single line.

    Every line is stripped of surrounding whitespace, blank lines are
    dropped, and the remaining pieces are concatenated with no separator.
    """
    stripped_lines = (raw_line.strip() for raw_line in xml.splitlines())
    return ''.join(piece for piece in stripped_lines if piece)


def _to_string(elem):
    """Serialize an lxml element to UTF-8 encoded XML, or return '' for None.

    The guard is an explicit ``is not None`` test. An lxml element's truth
    value reflects whether it has child elements, so a plain ``if elem``
    turned a childless element (e.g. an ``<Error>`` carrying only text) into ''
    and triggered lxml's FutureWarning.
    """
    return etree.tostring(elem, encoding='utf8') if elem is not None else ''


def _to_unicode(elem):
    """Serialize an lxml element to a text (unicode) string, or return '' for None.

    The guard is an explicit ``is not None`` test. An lxml element's truth
    value reflects whether it has child elements, so a plain ``if elem``
    would serialize a childless element as ''.
    """
    return etree.tostring(elem, encoding='unicode') if elem is not None else ''


class BookingInfoMeta(object):
    """Index an AirShoppingRS tree so reduced per-offer responses can be rebuilt.

    The constructor works on a deep copy of ``elements_tree`` (the caller's
    tree is left untouched). From that copy it detaches and indexes:

    * ``offers`` -- lists of ``Offer`` elements keyed by the frozenset of
      ``PaxSegmentRefID`` values of their fare components;
    * ``pax_segments`` / ``pax_journeys`` / ``price_classes`` /
      ``baggage_allowances`` -- elements keyed by their own ID child.

    What remains is ``truncated_tree``, the response skeleton. :meth:`build_tree`
    fills a fresh copy of it with the offers for one segment set plus only the
    data-list entries those offers reference.
    """

    # For every data-list element type:
    # (element tag, its ID child tag, tag offers use to reference it,
    #  enclosing list tag in the skeleton, name of the index attribute).
    # The order matches the order in which the lists are truncated and rebuilt.
    _DATA_LIST_SPECS = (
        ('PaxSegment', 'PaxSegmentID', 'PaxSegmentRefID', 'PaxSegmentList', 'pax_segments'),
        ('PaxJourney', 'PaxJourneyID', 'PaxJourneyRefID', 'PaxJourneyList', 'pax_journeys'),
        ('PriceClass', 'PriceClassID', 'PriceClassRefID', 'PriceClassList', 'price_classes'),
        (
            'BaggageAllowance', 'BaggageAllowanceID', 'BaggageAllowanceRefID',
            'BaggageAllowanceList', 'baggage_allowances',
        ),
    )

    def __init__(self, elements_tree):
        self.offers = defaultdict(list)
        self.pax_segments = {}
        self.pax_journeys = {}
        self.price_classes = {}
        self.baggage_allowances = {}
        self.truncated_tree = self._truncate_tree(deepcopy(elements_tree))

    def _truncate_tree(self, tree):
        """Detach and index the offer and data-list elements of ``tree``; return ``tree``."""
        for offer in tree.xpath('.//ns:Offer', namespaces=NSMAP):
            self.offers[self.get_unique_segment_refs(offer)].append(offer)
            offer.getparent().remove(offer)

        for tag, id_tag, _ref_tag, _list_tag, index_name in self._DATA_LIST_SPECS:
            index = getattr(self, index_name)
            for element in tree.xpath('.//ns:' + tag, namespaces=NSMAP):
                index[element.findtext('ns:' + id_tag, namespaces=NSMAP)] = element
                element.getparent().remove(element)

        return tree

    @staticmethod
    def _collect_ref_ids(offers, ref_tag):
        """Return the distinct texts of all ``ref_tag`` descendants of ``offers``, in first-seen order."""
        seen = set()
        ref_ids = []
        for offer in offers:
            for ref in offer.xpath('.//ns:' + ref_tag, namespaces=NSMAP):
                if ref.text not in seen:
                    seen.add(ref.text)
                    ref_ids.append(ref.text)
        return ref_ids

    def build_tree(self, segment_refs):
        """Return a new AirShoppingRS tree restricted to the offers for ``segment_refs``.

        ``segment_refs`` is any iterable of PaxSegmentRefID values; it is
        matched as a set against the offer index. Each data list receives only
        the entries referenced anywhere inside the selected offers, in
        first-reference order and without duplicates.

        Indexed elements are deep-copied into the result. An lxml element can
        have only one parent, so appending the cached element itself would move
        it: a tree returned by an earlier call would silently lose nodes, and
        edits to a returned tree would leak back into the cache.

        :raises KeyError: if an offer references an ID with no indexed element.
        """
        tree = deepcopy(self.truncated_tree)
        # .get() so that looking up an unknown segment set does not insert an
        # empty entry into the defaultdict index.
        offers = self.offers.get(frozenset(segment_refs), [])

        carrier_offers = tree.find('.//ns:CarrierOffers', namespaces=NSMAP)
        for offer in offers:
            carrier_offers.append(deepcopy(offer))

        for _tag, _id_tag, ref_tag, list_tag, index_name in self._DATA_LIST_SPECS:
            index = getattr(self, index_name)
            target_list = tree.find('.//ns:' + list_tag, namespaces=NSMAP)
            for element_id in self._collect_ref_ids(offers, ref_tag):
                target_list.append(deepcopy(index[element_id]))

        return tree

    def get_unique_segment_refs(self, offer):
        """Return the frozenset of PaxSegmentRefID texts in ``offer``'s fare components."""
        return frozenset(
            segment_ref.text for segment_ref in offer.xpath(
                'ns:OfferItem/ns:FareDetail/ns:FareComponent/ns:PaxSegmentRefID', namespaces=NSMAP
            )
        )


@QueryTracker.init_query
def query(tracker, q):
    """Run one Aeroflot search for ``q`` and return every parsed variant as a list.

    ``tracker`` is presumably supplied by the QueryTracker.init_query
    decorator (see lib.tracker) — confirm there.
    """
    return list(_parse_response(_get_response(tracker, q, q.iata_from, q.iata_to), q))


def _parse_dt(datetime_tag):
    """Parse the ns:AircraftScheduledDateTime child of ``datetime_tag`` into a naive datetime."""
    raw_value = datetime_tag.findtext('ns:AircraftScheduledDateTime', namespaces=NSMAP)
    return datetime.strptime(raw_value, DATETIME_FORMAT)


def _check_errors_and_warnings(response, elements_tree):
    """Raise BadPartnerResponse if the parsed response reports errors or warnings.

    Looks at the root's ``ns:Error`` children and at ``ns:Response/ns:Warning``.
    If any are present, their serialized XML is written to ``response.reason``
    (errors section first, then warnings) before raising.

    NOTE(review): every warning is treated as fatal here; the module-level
    EMPTY_RESPONSE_WARNING_TYPE is not consulted — confirm this is intended.
    """
    error_texts = [
        _to_string(node) for node in elements_tree.findall('ns:Error', namespaces=NSMAP)
    ]
    warning_texts = [
        _to_string(node) for node in elements_tree.findall('ns:Response/ns:Warning', namespaces=NSMAP)
    ]
    if not (error_texts or warning_texts):
        return

    sections = []
    if error_texts:
        sections.append('Errors:\n\t{}'.format('\n\t'.join(error_texts)))
    if warning_texts:
        sections.append('Unknown warnings:\n\t{}'.format('\n\t'.join(warning_texts)))
    response.reason = '\n'.join(sections)
    raise BadPartnerResponse('aeroflot', response)


def _parse_response(r, q):
    # type: (requests.Response, Query) -> Generator[Variant, None, None]
    """Parse an AirShoppingRS response into Variant objects, one per Offer.

    This is a generator: nothing (including the error check) runs until the
    first variant is requested.

    :param r: partner HTTP response whose ``content`` is the AirShoppingRS XML.
    :param q: the search query; ``klass``, ``national_version``, ``lang`` and
        ``importer.flight_fabric`` are read from it.
    :raises BadPartnerResponse: via _check_errors_and_warnings, when the
        response contains Error or Response/Warning elements.
    """
    xml = r.content
    # remove_blank_text drops whitespace-only text nodes between elements.
    parser = etree.XMLParser(remove_blank_text=True)
    tree = etree.XML(xml, parser=parser)

    _check_errors_and_warnings(r, tree)

    # Split journeys by direction: OriginDest IDs starting with 'OD1' hold the
    # outbound journeys, 'OD2' the return ones; any other ID is ignored.
    # NOTE(review): sleep_every presumably yields/sleeps periodically while
    # iterating long lists — see daemon.utils.
    forward_journeys = set()
    backward_journeys = set()
    for origin_dest in sleep_every(tree.xpath('//ns:DataLists/ns:OriginDestList/ns:OriginDest', namespaces=NSMAP)):
        if origin_dest.findtext('ns:OriginDestID', namespaces=NSMAP).startswith('OD1'):
            for ref in origin_dest.findall('ns:PaxJourneyRefID', namespaces=NSMAP):
                forward_journeys.add(ref.text)
        elif origin_dest.findtext('ns:OriginDestID', namespaces=NSMAP).startswith('OD2'):
            for ref in origin_dest.findall('ns:PaxJourneyRefID', namespaces=NSMAP):
                backward_journeys.add(ref.text)

    # Propagate direction from journeys down to their segment IDs. Only
    # backward_segments is consulted below; segments of journeys referenced by
    # neither direction (logged here) therefore end up treated as forward.
    forward_segments = set()
    backward_segments = set()
    for pax_journey in sleep_every(tree.xpath('//ns:DataLists/ns:PaxJourneyList/ns:PaxJourney', namespaces=NSMAP)):
        pax_journey_id = pax_journey.findtext('ns:PaxJourneyID', namespaces=NSMAP)
        if pax_journey_id in forward_journeys:
            for ref in pax_journey.findall('ns:PaxSegmentRefID', namespaces=NSMAP):
                forward_segments.add(ref.text)
        elif pax_journey_id in backward_journeys:
            for ref in pax_journey.findall('ns:PaxSegmentRefID', namespaces=NSMAP):
                backward_segments.add(ref.text)
        else:
            logger.error('Unknown PaxJourneyID "%s"', pax_journey_id)

    # PaxSegmentID -> FlightRef(kwargs for flight_fabric.create, direction).
    flights = {}

    # Built once for all offers; it deep-copies the tree, so ``tree`` itself is
    # not modified by the indexing.
    booking_info_meta = BookingInfoMeta(tree)

    for segment_tag in sleep_every(tree.xpath('//ns:DataLists/ns:PaxSegmentList/ns:PaxSegment', namespaces=NSMAP)):
        segment_key = segment_tag.findtext('ns:PaxSegmentID', namespaces=NSMAP)

        departure_tag = segment_tag.find('ns:Dep', namespaces=NSMAP)
        arrival_tag = segment_tag.find('ns:Arrival', namespaces=NSMAP)
        carrier_tag = segment_tag.find('ns:MarketingCarrierInfo', namespaces=NSMAP)
        operating_carrier_tag = segment_tag.find('ns:OperatingCarrierInfo', namespaces=NSMAP)

        flights[segment_key] = FlightRef(
            flight=dict(
                station_from_iata=departure_tag.findtext('ns:IATALocationCode', namespaces=NSMAP),
                station_to_iata=arrival_tag.findtext('ns:IATALocationCode', namespaces=NSMAP),
                local_departure=_parse_dt(departure_tag),
                local_arrival=_parse_dt(arrival_tag),
                company_iata=carrier_tag.findtext('ns:CarrierDesigCode', namespaces=NSMAP),
                pure_number=carrier_tag.findtext('ns:MarketingCarrierFlightNumberText', namespaces=NSMAP),
                operating=OperatingFlight(
                    company_iata=operating_carrier_tag.findtext('ns:CarrierDesigCode', namespaces=NSMAP),
                    pure_number=operating_carrier_tag.findtext('ns:OperatingCarrierFlightNumberText', namespaces=NSMAP),
                ),
                baggage=None,
            ),
            is_backward_segment=segment_key in backward_segments,
        )

    for offer in sleep_every(tree.xpath('//ns:OffersGroup/ns:CarrierOffers/ns:Offer', namespaces=NSMAP)):
        v = Variant()
        v.klass = q.klass

        price_tag = offer.find('ns:TotalPrice/ns:TotalAmount', namespaces=NSMAP)

        # A decimal comma in the amount is tolerated; the currency comes from
        # the CurCode attribute.
        v.tariff = Price(
            float(price_tag.text.replace(',', '.')),
            currency=price_tag.get('CurCode'),
        )

        # Indexed by FlightRef.is_backward_segment: False -> forward, True -> backward.
        segments = [v.forward.segments, v.backward.segments]

        # PaxSegmentRefID -> fare basis codes of every fare component of this
        # offer (presumably one per passenger type — TODO confirm); only the part
        # before the first '/' is kept.
        fare_codes = {}
        for fare_component in offer.xpath('ns:OfferItem/ns:FareDetail/ns:FareComponent', namespaces=NSMAP):
            segment_tag_ref = fare_component.findtext('ns:PaxSegmentRefID', namespaces=NSMAP)
            fare_codes.setdefault(segment_tag_ref, []).append(
                fare_component.findtext('ns:FareBasisCode', namespaces=NSMAP).split('/')[0]
            )

        # Add each referenced segment once, in first-reference order, using the
        # most frequent fare code for it. NOTE(review): ties between equally
        # frequent codes are broken by set iteration order.
        unique_segment_refs = set()
        for segment_tag_ref in offer.xpath(
            'ns:OfferItem/ns:FareDetail/ns:FareComponent/ns:PaxSegmentRefID', namespaces=NSMAP
        ):
            if segment_tag_ref.text in unique_segment_refs:
                continue
            unique_segment_refs.add(segment_tag_ref.text)
            flight, is_backward_segment = flights[segment_tag_ref.text]

            segment_fare_codes = fare_codes[segment_tag_ref.text]
            segments[is_backward_segment].append(
                q.importer.flight_fabric.create(
                    fare_code=max(set(segment_fare_codes), key=segment_fare_codes.count),
                    **flight
                )
            )

        # Booking payload: offer id, locale and cabin code, plus (below) a
        # reduced AirShoppingRS containing only the offers for this segment set.
        booking_info = {
            'OfferId': offer.findtext('ns:OfferID', namespaces=NSMAP),
            'CountryCode': COUNTRY_MAP.get(q.national_version, 'RU'),
            'LanguageCode': LANG_MAP.get(q.lang, 'ru'),
            'CabinType': KLASS_MAP[q.klass],
        }
        v.order_data = {
            'url': offer.findtext('ns:WebAddressURL', namespaces=NSMAP),
            'booking_info': booking_info,
        }

        v.order_data['booking_info']['AirShoppingRS'] = _to_unicode(
            booking_info_meta.build_tree(unique_segment_refs)
        )
        yield v


def _get_response(tracker, q, iata_from, iata_to):
    """Send the AirShopping search request for ``q`` and return the tracked response.

    The body is the rendered ``partners/aeroflot2.xml`` template, UTF-8
    encoded, POSTed to API_URL with HEADERS through ``tracker.wrap_request``.
    """
    search_params = _build_avia_search_params(q, iata_from, iata_to)
    request_body = _build_xml('partners/aeroflot2.xml', search_params).encode('utf-8')
    # NOTE(review): verify=False disables TLS certificate verification for a
    # request that carries the client secret header; confirm this is required.
    return tracker.wrap_request(
        requests.post,
        API_URL,
        headers=HEADERS,
        data=request_body,
        verify=False,
    )


def book(order_data):
    """Return the booking redirect for a variant: the partner URL from its order data.

    :raises KeyError: if ``order_data`` has no 'url' key.
    """
    booking_url = order_data['url']
    return dict(url=booking_url)


def _encode_marker_digits(value, alphabet):
    """Write non-negative ``value`` in base ``len(alphabet)``, least significant digit first."""
    radix = len(alphabet)
    digits = []
    while value >= radix:
        digits.append(alphabet[value % radix])
        # Floor division: plain ``/`` yields a float on Python 3 (or under
        # ``from __future__ import division``), and indexing with it fails.
        value //= radix
    digits.append(alphabet[value])
    return ''.join(digits)


def generate_marker():
    # type: () -> str
    """Return a marker built from the milliseconds elapsed since MARKER_START_DATETIME.

    The elapsed time (from naive local ``datetime.now()``) is written with
    MARKER_ALPHABET, least significant digit first, left-padded with '0' to at
    least MARKER_RND_LENGTH characters and substituted into MARKER_FORMAT.
    """
    elapsed = datetime.now() - MARKER_START_DATETIME
    millis = int(elapsed.total_seconds() * 1000)
    rnd = _encode_marker_digits(millis, MARKER_ALPHABET)
    return MARKER_FORMAT.format(rnd=rnd.zfill(MARKER_RND_LENGTH))


def add_marker_to_url(marker_field_name, marker_value, url):
    """Replace ``<marker_field_name>=<AGENCY_ID>`` in ``url`` with ``<marker_field_name>=<marker_value>``.

    Every occurrence is replaced (plain substring replacement). If the pair
    is absent, ``url`` is returned unchanged.
    """
    agency_pair = '%s=%s' % (marker_field_name, AGENCY_ID)
    marker_pair = '%s=%s' % (marker_field_name, marker_value)
    return url.replace(agency_pair, marker_pair)


def _build_xml(xml_template_file, params):
    """Render the Django template ``xml_template_file`` with ``params`` and collapse it onto one line."""
    return remove_xml_formatting(loader.render_to_string(xml_template_file, params))


def _build_avia_search_params(q, iata_from, iata_to):
    # type: (Query, basestring, basestring) -> dict
    """Build the template context for the search request XML.

    Dates are formatted as YYYY-MM-DD; ``backward_date`` is None for one-way
    queries. Country and language fall back to 'RU' / 'ru' for unmapped
    values, and KLASS_MAP[q.klass] raises KeyError for unsupported classes.
    """
    date_format = '%Y-%m-%d'
    forward_date = q.date_forward.strftime(date_format)
    backward_date = q.date_backward.strftime(date_format) if q.date_backward else None
    return dict(
        forward_date=forward_date,
        backward_date=backward_date,
        agency_id=AGENCY_ID,
        iata_from=iata_from,
        iata_to=iata_to,
        passenger_list=_format_passengers(q.passengers),
        klass=KLASS_MAP[q.klass],
        country_code=COUNTRY_MAP.get(q.national_version, 'RU'),
        lang=LANG_MAP.get(q.lang, 'ru'),
    )


def _format_passengers(passengers):
    """Expand passenger counts into a numbered list of Passenger records.

    Adults (ADT) come first, then children (CHD), then infants (INF). Missing
    count keys mean zero, and ids are assigned sequentially starting at 1.
    """
    type_codes = []
    for code, count_key in (('ADT', 'adults'), ('CHD', 'children'), ('INF', 'infants')):
        type_codes.extend([code] * passengers.get(count_key, 0))
    return [Passenger(ordinal, code) for ordinal, code in enumerate(type_codes, 1)]
