# -*- coding: utf-8 -*-
from __future__ import unicode_literals

from typing import Iterable

import logging
import sys
from collections import OrderedDict
from copy import copy, deepcopy
from datetime import date

from dateutil.relativedelta import relativedelta
from lxml import etree
from six import iteritems, text_type
from yt.wrapper import YtClient

from travel.avia.library.python.marker_helpers import AviaFlightRouteHelper
from travel.avia.library.python.references.partner import PartnerCache
from travel.avia.library.python.references.station import create_station_cache
from travel.cpa.collectors.lib.http_collector import HttpCollector
from travel.cpa.lib.common import with_retries
from travel.cpa.lib.errors import ErrorType, ProcessError
from travel.cpa.lib.lib_datetime import iter_day, parse_datetime_iso, timestamp
from travel.cpa.lib.lib_logging import get_logger
from travel.cpa.lib.order_snapshot import OrderCurrencyCode, OrderStatus

# Module-level logger. It is forced to DEBUG level and gets an extra handler that
# writes to stdout, in addition to whatever `get_logger` already configures
# (NOTE(review): `get_logger` is a project helper not visible here - confirm it does
# not attach a stdout handler itself, otherwise records are duplicated).
LOG = get_logger(__name__)
LOG.setLevel(logging.DEBUG)
LOG.addHandler(logging.StreamHandler(sys.stdout))


class _SnapshotContainer:
    def __init__(self, snapshot, raw_snapshot):
        self.snapshot = snapshot
        self.raw_snapshot = raw_snapshot


class YandexAviaCollector(HttpCollector):
    """
    Collector of avia order snapshots from a partner's XML bookings report.

    The report is requested over HTTP one day at a time, parsed with lxml and
    converted to `SNAPSHOT_CLS` instances. Several settings below default to None,
    so this class is presumably meant to be subclassed per partner.
    """
    # Partner identification; PARTNER_CODE is used to look up the partner id and billing order id.
    PARTNER_NAME = None
    PARTNER_CODE = None
    # Timeout passed to every report request (presumably seconds - confirm against `request_get`).
    REQUEST_TIMEOUT = 60

    # Snapshot class; it is instantiated through `SNAPSHOT_CLS.from_dict(..., ignore_unknown=True)`.
    SNAPSHOT_CLS = None
    # XPath selecting the booking nodes inside the parsed report.
    BOOKINGS_PATH = None
    # When False, station/route helpers are not created and no airport info is filled in.
    FILL_AIRPORTS = True
    # XPath, evaluated on a booking node, selecting its flight nodes.
    FLIGHT_NODE_XPATH = 'segment/flight'

    # Partner `state` value whose bookings are skipped entirely; None disables skipping.
    DROP_STATUS = None

    # Partner `state` value -> internal order status.
    # A state missing from this mapping makes `get_order_snapshot` fail with KeyError.
    STATUS_MAPPING = {
        'PROCESSING': OrderStatus.PENDING,
        'PAID': OrderStatus.CONFIRMED,
        'CANCELLED': OrderStatus.CANCELLED,
        'PART_REFUND': OrderStatus.CONFIRMED,
    }

    def __init__(self, options):
        """
        Sets up partner references, optional airport helpers, the wrapped report
        fetcher and the date range to collect.

        :param options: parsed options; `yt_proxy`, `yt_token`, `date_from` and `date_to`
            are read here, and the whole object is passed to the parent constructor
        """
        super(YandexAviaCollector, self).__init__(options)

        # Request credentials and query parameters start unset
        # (presumably filled in by concrete collectors - confirm in subclasses).
        self.auth = None
        self.params = None

        client = YtClient(options.yt_proxy, options.yt_token)

        self.partner = PartnerCache(client)
        partner_bundle = self.partner.partner_id_bundle(self.PARTNER_CODE)
        self.partner_id, self.billing_order_id = partner_bundle

        # Station cache and route helper are only needed when airports are filled in.
        if self.FILL_AIRPORTS:
            self.station = create_station_cache(client)
            self.route_helper = AviaFlightRouteHelper(self.station)
        else:
            self.station = None
            self.route_helper = None

        # `get_day_report` is the single-attempt fetcher wrapped by `with_retries`.
        self.get_day_report = with_retries(
            self.get_day_report_once,
            counter=self.metrics,
            key='collector.events.invalid_response'
        )

        first_day = parse_datetime_iso(options.date_from)
        self.date_from = first_day.date()
        last_day = parse_datetime_iso(options.date_to)
        self.date_to = last_day.date()

    @classmethod
    def configure(cls, parser):
        """
        Registers the collector's command line arguments on top of the parent's ones.

        The default date range is the last month: from one month before today up to today.

        :param parser: argument parser exposing `add_argument`
        """
        super(YandexAviaCollector, cls).configure(parser)

        # YT connection settings.
        parser.add_argument('--yt-proxy', default='hahn')
        parser.add_argument('--yt-token', default=None)

        # Date range defaults are ISO formatted dates.
        month_ago = date.today() + relativedelta(months=-1)
        parser.add_argument('--date-from', default=month_ago.isoformat())
        today = date.today()
        parser.add_argument('--date-to', default=today.isoformat())

    def _get_snapshots(self):
        """
        Does all the high level job of fetching data, grouping and squashing.
        Yields deduplicated snapshots.

        Snapshots are grouped per day by `_deduplication_key()` and every group is squashed
        into one snapshot. Across days a later snapshot replaces an earlier one with the same key,
        so exactly one snapshot is yielded per deduplication key (not per order id alone:
        the same order may be yielded once for each of its statuses).
        """

        latest_by_key = OrderedDict()

        for day_date in iter_day(self.date_from, self.date_to):
            LOG.info('Getting snapshots for %r', day_date)

            day_groups = OrderedDict()
            for snapshot_container in self.get_day_snapshots(day_date):
                if snapshot_container is None:
                    continue
                key = self._deduplication_key(snapshot_container.snapshot)
                day_groups.setdefault(key, []).append(snapshot_container)

            for key, snapshot_containers in iteritems(day_groups):
                snapshot = self._squash_snapshot(snapshot_containers)
                if snapshot is not None:
                    # Re-assigning an existing key keeps its original position in the output order
                    latest_by_key[key] = snapshot

        for snapshot in latest_by_key.values():
            yield snapshot

    def get_day_snapshots(self, day_date):
        """
        Fetches the report for one day and yields a container for every booking
        that produced a snapshot.

        :param day_date: the day to fetch the report for
        :raises ProcessError: with `ET_PARTNER_DATA` when the report is not valid XML
        """
        raw_report = self.get_day_report(day_date)
        try:
            root = etree.fromstring(raw_report)
        except etree.XMLSyntaxError:
            raise ProcessError(ErrorType.ET_PARTNER_DATA)

        for booking_node in root.xpath(self.BOOKINGS_PATH):
            if booking_node is None:
                continue
            order_snapshot = self.get_order_snapshot(booking_node)
            # Bookings that are dropped or empty produce no snapshot and are skipped.
            if order_snapshot is not None:
                yield _SnapshotContainer(snapshot=order_snapshot, raw_snapshot=booking_node)

    def get_day_report_once(self, day_date):
        """
        Makes a single attempt to download the report for one day.

        The request uses a copy of `self.params` with `date1` and `date2` both set
        to the requested day, so the shared parameters are not modified.

        :param day_date: the day to request
        :return: raw response body
        """
        query = copy(self.params)
        day = day_date.strftime('%Y-%m-%d')
        query.update({'date1': day, 'date2': day})
        response = self.request_get(
            self.base_url,
            params=query,
            auth=self.auth,
            timeout=self.REQUEST_TIMEOUT,
        )
        return response.content

    def get_raw_flights_node(self, raw_snapshot):
        return raw_snapshot.xpath(self.FLIGHT_NODE_XPATH)

    @staticmethod
    def get_node_dict(node):
        """
        Converts an XML node (a booking, or a flight node returned from `get_raw_flights_node`)
        to a dict mapping every child's tag to its text.

        A child without text (e.g. `<marker/>`) maps to None rather than to the string 'None',
        so an empty value cannot be mistaken for real data.
        When several children share a tag, the last one wins.
        """
        return {
            text_type(child.tag): None if child.text is None else text_type(child.text)
            for child in node
        }

    @staticmethod
    def get_mapped_flight(flight):
        """
        Maps a flight dict produced by `get_node_dict` to the structure passed to the route helper.

        Date and time fields are joined with a space and parsed with `parse_datetime_iso`.
        A missing field raises KeyError.
        """
        mapped = {
            'from': flight['departure'],
            'to': flight['arrival'],
        }
        departure_text = flight['departureDate'] + ' ' + flight['departureTime']
        mapped['departure_dt'] = parse_datetime_iso(departure_text)
        arrival_text = flight['arrivalDate'] + ' ' + flight['arrivalTime']
        mapped['arrival_dt'] = parse_datetime_iso(arrival_text)
        return mapped

    def get_flights_from_raw_snapshot(self, raw_snapshot):
        """
        Extracts the flights of a raw snapshot as mapped dicts with duplicate segments removed.

        :param raw_snapshot: booking node to read the flight nodes from
        :return: list of flight dicts in document order
        """
        mapped_flights = []
        for flight_node in self.get_raw_flights_node(raw_snapshot):
            flight_fields = self.get_node_dict(flight_node)
            mapped_flights.append(self.get_mapped_flight(flight_fields))
        return self.remove_duplicate_flight_segments(mapped_flights)

    @staticmethod
    def _deduplication_key(snapshot):
        return snapshot.travel_order_id, snapshot.status

    def fill_airports(self, raw_snapshots):
        """
        Returns a dict with data about departure and arrival airports and trip type if needed
        :param raw_snapshot:
        :return:
        """
        if not isinstance(raw_snapshots, list):
            raw_snapshots = [raw_snapshots]
        try:
            if self.FILL_AIRPORTS:
                flights = self.remove_duplicate_flight_segments([
                    flight
                    for raw_snapshot in raw_snapshots
                    for flight in self.get_flights_from_raw_snapshot(raw_snapshot)
                ])
                self.route_helper.localize_datetime(flights)
                return self.route_helper.fillin_trip_info({'flights': flights})
        except Exception:
            LOG.exception('Unable to get airport info')  # But, that's ok. We should not loose data because of that

        return dict()

    def _squash_snapshot(self, snapshot_containers):
        '''
        Takes several snapshot, grouped by some key using `_deduplication_key()` and merges them together
        using some sort of policy for different fields

        :param List[_SnapshotContainer] snapshot_containers:
        '''
        if not snapshot_containers:
            return None
        if len(snapshot_containers) == 1:
            return snapshot_containers[0].snapshot
        base_snapshot = deepcopy(snapshot_containers[0].snapshot)

        ordersum = sum(c.snapshot.order_amount for c in snapshot_containers)
        base_snapshot.order_amount = ordersum

        airports = self.fill_airports([c.raw_snapshot for c in snapshot_containers])
        if airports and airports.get('origin') and airports.get('destination') and airports.get('trip_type'):
            base_snapshot.origin = airports['origin']
            base_snapshot.destination = airports['destination']
            base_snapshot.trip_type = airports['trip_type']
            base_snapshot.date_forward = airports['date_forward']
            base_snapshot.date_backward = airports['date_backward']
        return base_snapshot

    def get_order_snapshot(self, raw_snapshot):
        """
        Builds an order snapshot from a raw booking node.

        :param raw_snapshot: booking node from the report
        :return: a `SNAPSHOT_CLS` instance, or None when the node has no children
            or its state equals `DROP_STATUS`
        :raises ProcessError: with `ET_PARTNER_DATA` when the currency is not a known `OrderCurrencyCode`
        """
        fields = self.get_node_dict(raw_snapshot)
        if not fields:
            return None

        drop_status = self.DROP_STATUS
        if drop_status is not None and fields['state'] == drop_status:
            return None

        # Trip info may be an empty dict; the snapshot is still created in that case.
        trip_info = self.fill_airports(raw_snapshot)
        order = self.SNAPSHOT_CLS.from_dict(trip_info, ignore_unknown=True)

        order.update_partner_order_id(fields['id'])
        order.status = self.STATUS_MAPPING[fields['state']]

        # The label is None when the booking has no `marker` child.
        order.label = fields.get('marker')

        created = parse_datetime_iso(fields['created_at'])
        order.created_at = timestamp(created)
        order.updated_at = timestamp(self.now)

        try:
            order.currency_code = OrderCurrencyCode(fields['currency'])
        except ValueError:
            raise ProcessError(ErrorType.ET_PARTNER_DATA)

        # Amounts of snapshots with the same deduplication key are summed when squashing.
        order.order_amount = float(fields['price'])

        order.partner_id = self.partner_id
        order.billing_order_id = self.billing_order_id

        return order

    @staticmethod
    def remove_duplicate_flight_segments(segments):
        # type: (Iterable[dict]) -> Iterable[dict]
        output = []
        tracker = set()
        for segment in segments:
            frozen = frozenset(segment.items())
            if frozen in tracker:
                continue
            output.append(segment)
            tracker.add(frozen)
        return output
