# coding=utf-8
from __future__ import unicode_literals

import datetime as dt
import json
from argparse import ArgumentParser, Namespace  # noqa
from typing import Optional  # noqa

import boto3
import lxml.etree  # noqa
from pytz import timezone
from six import ensure_str, text_type
from yt.wrapper import YtClient

from travel.avia.library.python.marker_helpers import AviaFlightRouteHelper
from travel.avia.library.python.safe_lxml import fromstring as safe_fromstring
from travel.avia.library.python.references.partner import PartnerCache
from travel.avia.library.python.references.station import create_station_cache
from travel.cpa.collectors.lib.collector import Collector
from travel.cpa.lib.errors import ErrorType, ProcessError
from travel.cpa.lib.lib_datetime import timestamp
from travel.cpa.lib.lib_logging import get_logger
from travel.cpa.lib.order_snapshot import AeroflotClickoutAviaOrderSnapshot, OrderCurrencyCode, OrderStatus

# Module-level logger, named after this module.
LOG = get_logger(__name__)

# Constant value passed as ``profit_amount`` for every snapshot built in
# ``AeroflotCollector.parse_afl_message``.
# NOTE(review): the unit/currency of this amount is not stated in this file -- confirm.
PROFIT_AMOUNT = 215.


class AeroflotCollector(Collector):
    PARTNER_NAME = 'aeroflot'
    PARTNER_CODE = 'aeroflot'

    BOOKING_ORIGINAL_TIME_ZONE = 'US/Central'
    MESSAGE_SENT_TIME_ZONE = 'Europe/Moscow'

    @classmethod
    def configure(cls, parser):
        # type: (ArgumentParser) -> None
        parser.add_argument('--sqs-endpoint', default='http://sqs.yandex.net:8771')
        parser.add_argument('--sqs-access-key', default='avia')
        parser.add_argument('--sqs-secret-key', default='not used yet')
        parser.add_argument('--sqs-queue-name', default='testing_aeroflot_order_cpa')
        parser.add_argument('--sqs-queue-owner', default=None)
        parser.add_argument('--sqs-visibility-timeout', default=100)

        parser.add_argument('--yt-proxy', default='hahn')
        parser.add_argument('--yt-token', default=None)

    def __init__(self, options):
        # type: (Namespace) -> None
        """Build the collector from parsed command-line options.

        Creates a YT client from ``yt_proxy``/``yt_token``, uses it for the
        partner cache and the station cache behind the route helper, then
        creates the SQS client and stores the queue settings.
        """
        super(AeroflotCollector, self).__init__()

        yt = YtClient(options.yt_proxy, options.yt_token)

        self.partner = PartnerCache(yt)
        id_bundle = self.partner.partner_id_bundle(self.PARTNER_CODE)
        self.partner_id, self.billing_order_id = id_bundle

        self.route_helper = AviaFlightRouteHelper(create_station_cache(yt))

        boto3.set_stream_logger()
        sqs_settings = dict(
            region_name='yandex',
            endpoint_url=options.sqs_endpoint,
            aws_access_key_id=options.sqs_access_key,
            aws_secret_access_key=options.sqs_secret_key,
        )
        self._sqs_client = boto3.client('sqs', **sqs_settings)
        self._sqs_queue_name = options.sqs_queue_name
        self._sqs_queue_owner = options.sqs_queue_owner

        # The option may arrive as a string from the command line, so it is
        # coerced to int; None is kept and means "do not send the parameter".
        raw_timeout = options.sqs_visibility_timeout
        self._sqs_visibility_timeout = None if raw_timeout is None else int(raw_timeout)

    def _get_snapshots(self):
        snapshots_by_order_id = {}
        for message in self.get_raw_snapshots():
            for snapshot in self._parse_sqs_message(message):
                self._put_replace_with_most_recent(snapshots_by_order_id, snapshot)

        for snapshot in snapshots_by_order_id.values():
            yield snapshot

    def _put_replace_with_most_recent(self, snapshots, snapshot):
        if snapshot.travel_order_id not in snapshots:
            snapshots[snapshot.travel_order_id] = snapshot
            return

        current_snapshot = snapshots[snapshot.travel_order_id]
        if current_snapshot.updated_at < snapshot.updated_at:
            snapshots[snapshot.travel_order_id] = snapshot
            return

    def _parse_sqs_message(self, message):
        """Decode one SQS message and yield the snapshots parsed from it.

        ``message['Body']`` is JSON carrying the keys ``message`` (the payload
        handed to ``parse_afl_message``), ``marker`` and ``date_sent``.
        """
        payload = json.loads(message['Body'])

        # Leading whitespace is stripped before the text reaches the parser.
        xml_text = payload['message'].lstrip()
        label = payload['marker']
        sent_at = payload['date_sent']

        for snapshot in self.parse_afl_message(ensure_str(xml_text), label, sent_at):
            yield snapshot

    def parse_afl_message(self, message, marker, date_sent):
        """Yield an order snapshot for every ``bookings/booking`` element in ``message``.

        :param message: XML text of the partner message.
        :param marker: value stored as the snapshot ``label``; also the prefix
            of the partner order id (``<marker>_<rloc>``).
        :param date_sent: ``%Y-%m-%d %H:%M:%S`` text, interpreted in
            ``MESSAGE_SENT_TIME_ZONE`` and stored as ``updated_at``.
        :raises ProcessError: with ``ErrorType.ET_PARTNER_DATA`` when the
            route helper cannot fill in trip info for the parsed segments.
        """
        for booking in safe_fromstring(message).xpath('//bookings/booking'):
            booking_dict = {text_type(child.tag): text_type(child.text) for child in booking}

            # The leading '.' matters: a bare '//...' expression is absolute
            # and is evaluated against the whole document, which would mix in
            # ticket documents (and segments, below) of every other booking
            # in the same message.
            ticket_documents = booking.xpath('.//TicketDocuments')

            segments = self._parse_flight_segments(booking)

            flight_info = {'flights': segments}
            if segments:
                try:
                    flight_info = self.route_helper.fillin_trip_info(flight_info)
                except Exception:
                    LOG.exception('Cannot fill airports for order. %s', segments)
                    raise ProcessError(ErrorType.ET_PARTNER_DATA)

            amount, currency = self.parse_price(booking)

            snapshot = AeroflotClickoutAviaOrderSnapshot.from_dict(
                dict(
                    partner_id=self.partner_id,
                    billing_order_id=self.billing_order_id,
                    order_amount=amount,
                    currency_code=currency,
                    profit_amount=PROFIT_AMOUNT,
                    origin=flight_info.get('origin'),
                    destination=flight_info.get('destination'),
                    trip_type=flight_info.get('trip_type'),
                    date_forward=flight_info.get('date_forward'),
                    date_backward=flight_info.get('date_backward'),
                    status=self._parse_status(ticket_documents),
                    label=marker,
                    created_at=self._parse_datetime(
                        booking_dict['bookingDate'],
                        self.BOOKING_ORIGINAL_TIME_ZONE,
                    ),
                    updated_at=self._parse_datetime(
                        date_sent,
                        self.MESSAGE_SENT_TIME_ZONE,
                    ),
                ),
                convert_type=True,
            )

            order_id = '{}_{}'.format(marker, booking_dict['rloc'])
            snapshot.update_partner_order_id(order_id)

            yield snapshot

    def _parse_flight_segments(self, booking):
        """Return the localized flight segments found inside ``booking``.

        An empty list is returned when a ``ValueError`` is raised while the
        segments are being read or localized (e.g. an unparsable date), so
        the booking is still reported, just without route information.
        """
        try:
            segments = [
                self._parse_flight_segment(node)
                for node in booking.xpath('.//FlightSegment')
            ]
            LOG.debug('Flights: %s', segments)
            self.route_helper.localize_datetime(segments)
            LOG.debug('Flights localized: %s', segments)
        except ValueError:
            return []
        return segments

    @staticmethod
    def _parse_flight_segment(segment):
        """Convert one ``FlightSegment`` element into a plain flight dict."""
        datetime_format = '%Y-%m-%d %H:%M:%S'
        return {
            'from': segment.xpath('./OriginLocation')[0].attrib['LocationCode'],
            'to': segment.xpath('./DestinationLocation')[0].attrib['LocationCode'],
            'departure_dt': dt.datetime.strptime(
                segment.attrib['DepartureDateTime'],
                datetime_format,
            ),
            'arrival_dt': dt.datetime.strptime(
                segment.attrib['ArrivalDateTime'],
                datetime_format,
            ),
        }

    def get_raw_snapshots(self):
        sqs_request_kwargs = dict(QueueName=self._sqs_queue_name)
        if self._sqs_queue_owner is not None:
            sqs_request_kwargs['QueueOwnerAWSAccountId'] = self._sqs_queue_owner
        sqs_queue_response = self._sqs_client.get_queue_url(**sqs_request_kwargs)
        sqs_queue_url = sqs_queue_response['QueueUrl']
        while True:
            receive_message_kwargs = dict(
                QueueUrl=sqs_queue_url,
                MaxNumberOfMessages=10,
                WaitTimeSeconds=10,  # long polling
                # https://wiki.yandex-team.ru/users/radix/how-we-created-ymq/#osnovnyemetodyapiymq
            )
            if self._sqs_visibility_timeout is not None:
                receive_message_kwargs['VisibilityTimeout'] = self._sqs_visibility_timeout
            sqs_messages_response = self._sqs_client.receive_message(**receive_message_kwargs)
            messages = sqs_messages_response.get('Messages', [])
            if not messages:
                break
            for message in messages:
                yield message

    @staticmethod
    def _parse_status(ticket_documents):
        """Map the ticket-document lookup result to an order status.

        The order is ``CONFIRMED`` when the sequence is non-empty and its
        first item has a non-zero ``len``; otherwise it is ``PENDING``.
        NOTE(review): for lxml elements ``len`` is presumably the number of
        child elements -- confirm against the partner XML.
        """
        if ticket_documents and len(ticket_documents[0]):
            return OrderStatus.CONFIRMED
        return OrderStatus.PENDING

    @staticmethod
    def _parse_datetime(dt_text, tz):
        """Convert ``%Y-%m-%d %H:%M:%S`` text, taken as local time in ``tz``, via ``timestamp``.

        :param dt_text: date-time text without zone information.
        :param tz: time zone name understood by ``pytz.timezone``.
        """
        naive_moment = dt.datetime.strptime(dt_text, '%Y-%m-%d %H:%M:%S')
        return timestamp(timezone(tz).localize(naive_moment))

    @staticmethod
    def parse_price(booking):
        # type: (lxml.etree.ElementBase)-> (Optional[float], Optional[OrderCurrencyCode])
        """Read the order amount and currency from the booking's first ``PGWOrder``.

        Falls back to ``(0.0, OrderCurrencyCode.RUB)`` when either the
        ``Amount`` or the ``Currency`` text is absent. Conversion errors
        other than ``IndexError`` are not handled here and propagate.
        """
        try:
            amount_texts = booking.xpath('./PGWOrders/PGWOrder/Amount/text()')
            amount = float(amount_texts[0])
            currency_texts = booking.xpath('./PGWOrders/PGWOrder/Currency/text()')
            currency = OrderCurrencyCode(currency_texts[0])
        except IndexError:
            return float(0), OrderCurrencyCode.RUB
        return amount, currency
