# -*- coding: utf-8 -*-
from __future__ import unicode_literals

from lxml import etree
from six import ensure_text

from travel.cpa.collectors.lib.yandex_avia_collector import YandexAviaCollector
from travel.cpa.lib.common import with_retries
from travel.cpa.lib.errors import ProcessError, ErrorType
from travel.cpa.lib.lib_datetime import timestamp, parse_datetime_iso
from travel.cpa.lib.lib_logging import get_logger
from travel.cpa.lib.order_snapshot import NabortuAviaOrderSnapshot, OrderStatus, OrderCurrencyCode

LOG = get_logger(__name__)


class NabortuCollector(YandexAviaCollector):
    """Collects avia order snapshots from the Nabortu partner XML report.

    The report is downloaded from ``base_url`` (``BASE_URL`` unless overridden
    with ``--base-url``) with the ``--login``/``--password`` credentials, and
    every ``BOOKINGS_PATH`` node is converted into a ``NabortuAviaOrderSnapshot``.
    Malformed report data is reported as ``ProcessError(ET_PARTNER_DATA)``.
    """

    PARTNER_NAME = 'nabortu'
    PARTNER_CODE = 'nabortu'

    BASE_URL = 'https://media.zelenski.ru/partners/yandex/yandex_report.xml'
    # Passed as ``timeout`` to each report request (presumably seconds — confirm).
    REQUEST_TIMEOUT = 60

    # XPath selecting one node per order in the report.
    BOOKINGS_PATH = '//orders/order'

    # Partner status -> internal order status; any other value is rejected
    # by ``parse_snapshot``.
    STATUS_MAPPING = {
        'booking': OrderStatus.PENDING,
        'paid': OrderStatus.CONFIRMED,
        'cancel': OrderStatus.CANCELLED,
    }

    # Child elements every order node must contain with non-empty text.
    MANDATORY_FIELDS = frozenset([
        'marker', 'status', 'orderid', 'price', 'currency', 'created_at', 'changed_at',
    ])

    # NOTE(review): flag consumed outside this class; presumably tells the base
    # collector not to fill airport data — confirm.
    FILL_AIRPORTS = False

    def __init__(self, options):
        super(NabortuCollector, self).__init__(options)

        # Credentials passed as ``auth`` to every report request.
        self.auth = (options.login, options.password)

        # Retrying wrapper around a single download attempt; presumably failed
        # attempts are counted in ``self.metrics`` under ``key`` (see with_retries).
        self.get_all_reports = with_retries(
            self.get_all_reports_once,
            counter=self.metrics,
            key='collector.events.invalid_response',
        )

    @classmethod
    def configure(cls, parser):
        """Register Nabortu-specific command line options on top of the base ones."""
        super(NabortuCollector, cls).configure(parser)

        parser.add_argument('--login')
        parser.add_argument('--password')

        parser.add_argument('--base-url', default=cls.BASE_URL)

    def _get_snapshots(self):
        """Yield parsed snapshots, dropping empty orders and failing on duplicate ids."""
        LOG.info('Getting snapshots')
        snapshots = self.break_on_duplicate_order_id(
            self.filter_none(self.get_all_snapshots())
        )
        for snapshot in snapshots:
            yield snapshot

    @staticmethod
    def break_on_duplicate_order_id(snapshots):
        """Pass snapshots through unchanged, failing on a repeated ``travel_order_id``.

        Raises:
            ProcessError: ET_PARTNER_DATA when an order id is seen a second time.
        """
        seen = set()
        for snapshot in snapshots:
            order_id = snapshot.travel_order_id
            if order_id in seen:
                LOG.error('Duplicate order id in Nabortu report: %s', order_id)
                raise ProcessError(ErrorType.ET_PARTNER_DATA)
            seen.add(order_id)
            yield snapshot

    @staticmethod
    def filter_none(values):
        """Yield every value that is not None, preserving order."""
        for value in values:
            if value is None:
                continue
            yield value

    def get_all_snapshots(self):
        """Download the report and yield a snapshot for every non-empty order node.

        Raises:
            ProcessError: ET_PARTNER_DATA when the report is not well-formed XML
                or an order node fails validation in ``parse_snapshot``.
        """
        LOG.info('Capturing snapshots from Nabortu')
        report = self.get_all_reports()

        # The report is external input: do not expand custom/external entities
        # and never touch the network while parsing it (XXE hardening).
        parser = etree.XMLParser(resolve_entities=False, no_network=True)
        try:
            tree = etree.fromstring(report, parser)
        except etree.XMLSyntaxError:
            LOG.error('Nabortu report is not well-formed XML')
            raise ProcessError(ErrorType.ET_PARTNER_DATA)

        # xpath over an element path only returns elements, never None.
        for raw_snapshot in tree.xpath(self.BOOKINGS_PATH):
            snapshot = self.parse_snapshot(raw_snapshot)
            if snapshot is not None:
                yield snapshot

    def get_all_reports_once(self):
        """Perform a single request for the report and return the raw response body."""
        LOG.info('Requesting nabortu reports')
        # NOTE(review): ``base_url`` is expected to be populated from
        # ``--base-url`` by the base collector — confirm.
        rsp = self.request_get(self.base_url, auth=self.auth, timeout=self.REQUEST_TIMEOUT)
        return rsp.content

    @staticmethod
    def get_node_dict(node):
        """Convert an XML node into a ``{child tag: child text}`` dict.

        Only element children are considered, so comments and processing
        instructions inside the node are ignored. An empty child element
        (e.g. ``<marker/>``, whose text is None) maps to an empty string
        instead of crashing ``ensure_text``.
        """
        return {
            ensure_text(child.tag): ensure_text(child.text or '')
            for child in node.iterchildren(tag=etree.Element)
        }

    def parse_snapshot(self, raw_snapshot):
        """Build a ``NabortuAviaOrderSnapshot`` from a single order node.

        Returns:
            The populated snapshot, or None when the node has no child elements.

        Raises:
            ProcessError: ET_PARTNER_DATA when a mandatory field is missing or
                empty, or the status, price or currency value is not recognised.
        """
        snapshot_dict = self.get_node_dict(raw_snapshot)

        if not snapshot_dict:
            return None

        missing = [
            name for name in sorted(self.MANDATORY_FIELDS)
            if not snapshot_dict.get(name)
        ]
        if missing:
            LOG.error('Nabortu order lacks mandatory fields: %s', ', '.join(missing))
            raise ProcessError(ErrorType.ET_PARTNER_DATA)

        status = snapshot_dict['status']
        if status not in self.STATUS_MAPPING:
            LOG.error('Unknown Nabortu order status: %s', status)
            raise ProcessError(ErrorType.ET_PARTNER_DATA)

        try:
            order_amount = float(snapshot_dict['price'])
        except ValueError:
            LOG.error('Invalid Nabortu order price: %s', snapshot_dict['price'])
            raise ProcessError(ErrorType.ET_PARTNER_DATA)

        try:
            currency_code = OrderCurrencyCode(snapshot_dict['currency'])
        except ValueError:
            LOG.error('Unknown Nabortu order currency: %s', snapshot_dict['currency'])
            raise ProcessError(ErrorType.ET_PARTNER_DATA)

        snapshot = NabortuAviaOrderSnapshot()
        snapshot.label = snapshot_dict['marker']
        snapshot.status = self.STATUS_MAPPING[status]
        snapshot.update_partner_order_id(snapshot_dict['orderid'])
        snapshot.order_amount = order_amount
        snapshot.currency_code = currency_code

        # NOTE(review): errors raised by parse_datetime_iso are not converted to
        # ProcessError here; confirm what it raises on malformed input.
        snapshot.created_at = timestamp(parse_datetime_iso(snapshot_dict['created_at']))
        snapshot.source_updated_at = timestamp(parse_datetime_iso(snapshot_dict['changed_at']))

        snapshot.partner_id = self.partner_id
        snapshot.billing_order_id = self.billing_order_id

        return snapshot
