# coding=utf-8
from __future__ import unicode_literals

import logging
from importlib import import_module

from sandbox import sdk2
from sandbox.projects.avia.import_marker import AviaImportMarker
from sandbox.projects.avia.lib.marker import BaseMarkerReader, MarkerTransfer, MarkerWriter
from sandbox.sandboxsdk.environments import PipEnvironment


class AeroflotMarkerReader(BaseMarkerReader):
    """Reads Aeroflot booking reports from an SQS queue and converts them to ``MarkerWriter.Row`` objects.

    Each SQS message body is JSON with a ``marker`` field and an XML ``message`` field; every
    ``bookings/booking`` element of that XML produces one row. Rows are grouped by the date of
    their ``created_at`` (see ``receive_and_group``) and served per date by ``import_data``,
    each paired with the receipt handle of the message it came from.
    """

    # Naive booking timestamps are interpreted in this timezone...
    BOOKING_ORIGINAL_TIME_ZONE = 'US/Central'
    # ...and every timestamp is converted to this one.
    BOOKING_DESIRED_TIME_ZONE = 'Europe/Moscow'
    # Upper bound SQS accepts for MaxNumberOfMessages in ReceiveMessage.
    SQS_MAX_BATCH_SIZE = 10

    def __init__(self, logger, geo_point_cache, sqs_client, sqs_queue_url, partner, max_messages_per_launch):
        """
        :param logging.Logger logger:
        :param geo_point_cache: passed through to ``BaseMarkerReader``
        :param sqs_client: boto3 sqs client
        :param sqs_queue_url: url of the queue to read from
        :param dict partner: must contain ``code``, ``partner_id`` and ``billing_order_id``
        :param int max_messages_per_launch: stop receiving after this many successfully parsed messages
        """
        super(AeroflotMarkerReader, self).__init__(
            statuses_map=None,
            logger=logger,
            geo_point_cache=geo_point_cache,
        )
        # {datetime.date: [(MarkerWriter.Row, receipt_handle), ...]}; None until messages are received.
        self.date_groups = None
        self.sqs_client = sqs_client
        self.sqs_queue_url = sqs_queue_url
        self._partner = partner
        self.max_messages_per_launch = max_messages_per_launch
        # Counters reported by the task as output parameters.
        self.messages_received = 0
        self.messages_delete_errors_count = 0
        self.messages_parse_errors_count = 0

    def import_data(self, date):
        """Return the ``(row, receipt_handle)`` pairs whose row was created on ``date``.

        Receives and groups the messages first if that has not happened yet.

        :param datetime.date date:
        :rtype: list
        """
        if self.date_groups is None:
            self.receive_and_group()
        self._logger.info('Getting data for date %s', date)
        return self.date_groups.get(date) or []

    def receive_and_group(self):
        """Receive messages from the queue and group their rows by ``created_at`` date into ``date_groups``."""
        self._logger.info('Start: Receiving data and grouping by date')
        self.date_groups = {}
        for order in self._receive_orders():
            date = order[0].created_at.date()
            self.date_groups.setdefault(date, []).append(order)
        self._logger.info('Done: Receiving data and grouping by date')

    def _receive_orders(self):
        """Yield ``(row, receipt_handle)`` pairs for the bookings found in the queue.

        Receiving stops when the queue returns no messages, or once ``max_messages_per_launch``
        messages have been parsed successfully. A message that fails to parse is logged, counted
        in ``messages_parse_errors_count`` and deleted right away; it does not count toward the limit.
        """
        from botocore.exceptions import ClientError
        while self.max_messages_per_launch > self.messages_received:
            # Never ask for more messages than are left of the per-launch budget.
            batch_size = min(self.SQS_MAX_BATCH_SIZE, self.max_messages_per_launch - self.messages_received)
            response = self.sqs_client.receive_message(
                QueueUrl=self.sqs_queue_url,
                AttributeNames=['All'],
                MaxNumberOfMessages=batch_size,
                VisibilityTimeout=600,
                WaitTimeSeconds=20,
            )

            # No 'Messages' key: nothing was returned, treat the queue as drained.
            if 'Messages' not in response:
                break

            for message in response['Messages']:
                receipt = message['ReceiptHandle']
                try:
                    bookings = list(self.parse_report(message))
                except Exception:
                    self._logger.exception(
                        'Cannot parse order from aeroflot: %r, deleting with receipt %s',
                        message,
                        receipt,
                    )
                    self.messages_parse_errors_count += 1
                    try:
                        self.sqs_client.delete_message(
                            QueueUrl=self.sqs_queue_url,
                            ReceiptHandle=receipt,
                        )
                    except ClientError:
                        self._logger.warning('Cannot delete order %s from sqs with receipt %s', message, receipt)
                        self.messages_delete_errors_count += 1
                    continue

                # Every booking of a message shares that message's receipt handle.
                for order in bookings:
                    yield (order, receipt)
                self.messages_received += 1
        self._logger.info('Messages received: %d', self.messages_received)

    def parse_report(self, message):
        """Yield one ``MarkerWriter.Row`` per ``bookings/booking`` element of an SQS message.

        :param dict message: SQS message whose ``Body`` is JSON with ``marker`` and ``message`` (XML) keys
        :raises Exception: any JSON/XML/key error; the caller treats it as an unparseable message
        """
        from sandbox.projects.avia.lib.safe_lxml import fromstring as safe_fromstring
        from six import text_type
        import json
        order_details = json.loads(message['Body'])
        marker = order_details['marker']

        for booking in safe_fromstring(order_details['message']).xpath('//bookings/booking'):
            booking_dict = {text_type(e.tag): text_type(e.text) for e in booking}

            order_id = booking_dict['rloc']
            created_at = self.parse_datetime(booking_dict['bookingDate'])

            # './/' keeps the search inside this booking. A bare '//' is evaluated from the
            # document root and would give every booking the segments/tickets of all bookings.
            try:
                flights = [self.get_mapped_flight(segment) for segment in booking.xpath('.//FlightSegment')]
            except TypeError:
                flights = []
            ticket_documents = booking.xpath('.//TicketDocuments')

            (airport_from, airport_to), trip_type = self._get_airports({'flights': flights})

            # creating row here to be able to validate ASAP
            yield MarkerWriter.Row(
                partner=self._partner['code'],
                partner_id=self._partner['partner_id'],
                billing_order_id=self._partner['billing_order_id'],
                order_id=order_id,
                created_at=created_at,
                airport_from=airport_from,
                airport_to=airport_to,
                status=self._parse_status(ticket_documents),
                marker=marker,
                trip_type=trip_type,
            )

    def _parse_status(self, ticket_documents):
        """Return ``'paid'`` when the first ``TicketDocuments`` element has children, else ``'booking'``."""
        if not ticket_documents or not len(ticket_documents[0]):
            return 'booking'
        else:
            return 'paid'

    def get_mapped_flight(self, segment):
        """Convert a ``FlightSegment`` element to a flight dict.

        :return: dict with ``from``/``to`` location codes and ``departure_dt``/``arrival_dt``
            converted by ``parse_datetime``
        :raises TypeError: when a datetime cannot be parsed (logged before re-raising)
        """
        from six import text_type
        segment_dict = {text_type(e.tag): e for e in segment}
        try:
            return {
                'from': segment_dict['OriginLocation'].attrib['LocationCode'],
                'to': segment_dict['DestinationLocation'].attrib['LocationCode'],
                'departure_dt': self.parse_datetime(segment.attrib['DepartureDateTime']),
                'arrival_dt': self.parse_datetime(segment.attrib['ArrivalDateTime']),
            }
        except TypeError:
            self._logger.exception(
                'Error mapping flights for segment %s',
                repr(segment_dict)
            )
            raise

    def parse_datetime(self, dt_text):
        """Parse an ISO-8601-like timestamp and convert it to ``BOOKING_DESIRED_TIME_ZONE``.

        A naive timestamp is interpreted in ``BOOKING_ORIGINAL_TIME_ZONE``. A timestamp that already
        carries a UTC offset is converted as-is, because pytz ``localize`` rejects aware datetimes.

        :raises TypeError: when ``dt_text`` does not match the ``{:ti}`` format (logged before re-raising)
        """
        from parse import parse
        from pytz import timezone
        try:
            # parse() returns None on mismatch, so indexing raises TypeError.
            parsed_datetime = parse('{:ti}', dt_text)[0]
        except TypeError:
            self._logger.exception(
                'Could not parse datetime %s',
                dt_text
            )
            raise
        if parsed_datetime.tzinfo is None:
            parsed_datetime = timezone(self.BOOKING_ORIGINAL_TIME_ZONE).localize(parsed_datetime)
        return parsed_datetime.astimezone(timezone(self.BOOKING_DESIRED_TIME_ZONE))


class AeroflotMarkerTransfer(MarkerTransfer):
    """Writes the rows of ``AeroflotMarkerReader`` to YT date by date and deletes their messages from SQS.

    Messages are deleted only after the rows of their date were written successfully; when the
    write fails, nothing is deleted here.
    """

    def __init__(self, partner, marker_writer, marker_reader, logger, sqs_client, sqs_queue_url):
        """
        :param dict partner:
        :param sandbox.projects.avia.lib.marker.MarkerWriter marker_writer:
        :param AeroflotMarkerReader marker_reader: its ``import_data`` must return ``(row, receipt_handle)`` pairs
        :param logging.Logger logger:
        :param sqs_client: boto3 sqs client
        :param sqs_queue_url: sqs queue url
        """
        super(AeroflotMarkerTransfer, self).__init__(
            partner=partner,
            marker_writer=marker_writer,
            marker_reader=marker_reader,
            logger=logger,
        )
        self.sqs_client = sqs_client
        self.sqs_queue_url = sqs_queue_url
        self.messages_delete_errors_count = 0

    def transfer(self, date):
        """
        Writes all rows of ``date`` to YT in one batch and, if that succeeds, deletes their messages from sqs.

        NOTE(review): a message whose bookings fall on different dates shares one receipt handle
        across those dates, so it is deleted as soon as any one of them is written.

        :param datetime.date date: one of the reader's date-group keys
        :return: None
        """
        from botocore.exceptions import ClientError
        self._logger.info('Start: Transfer data for date %s', date)
        orders = self._marker_reader.import_data(date)
        self._logger.info('Got %d orders for date %s', len(orders), date)
        try:
            self.export_data_rows([order for order, _ in orders], date)
        except Exception:
            self._logger.exception(
                'Error transfering data for partner %s, date %s',
                self._partner,
                date,
            )
            # Presumably drops the rows buffered by add_rows so they are not re-sent with the next
            # date; relies on MarkerWriter internals.
            self._marker_writer._rows = []
        else:
            # Bookings of the same message share its receipt handle: delete each message only once.
            handled_receipts = set()
            for order, receipt in orders:
                self._logger.info('Sent order %s, deleting from sqs with receipt %s', order, receipt)
                if receipt in handled_receipts:
                    continue
                handled_receipts.add(receipt)
                try:
                    self.sqs_client.delete_message(
                        QueueUrl=self.sqs_queue_url,
                        ReceiptHandle=receipt,
                    )
                except ClientError:
                    self._logger.warning('Cannot delete order %s from sqs with receipt %s', order, receipt)
                    self.messages_delete_errors_count += 1

        self._logger.info('Done: Transfer data for date %s', date)

    def export_data_rows(self, order_rows, date):
        """
        Adds the rows to the marker writer and writes them to yt for the given date
        :param list[MarkerWriter.Row] order_rows:
        :param datetime.date date:
        :return: None
        """
        self._marker_writer.add_rows(order_rows)
        self._marker_writer.write_to_yt(date)


def _iter_pip_environments(names):
    """Если мы пытаемся добавить PipEnvironment для модуля, который уже есть в окружении, то что-то ломается.
    Если установка одного из PipEnvironment сломалась, то могут сломаться другие.
    Поэтому, мы проверяем есть ли в текущем окружении нужный модуль и возвращаем PipEnvironment только если такого модуля ещё нет.
    """
    for name in names:
        try:
            import_module(name)
        except ImportError:
            yield PipEnvironment(name)


class AviaImportAeroflotMarker(AviaImportMarker):
    """ Import marker from Aeroflot """

    class Requirements(AviaImportMarker.Requirements):
        # Only modules missing from the current environment get a PipEnvironment.
        environments = AviaImportMarker.Requirements.environments.default + tuple(
            _iter_pip_environments(('boto3', 'lxml', 'parse', 'pytz', 'six'))
        )

    class Parameters(AviaImportMarker.Parameters):
        with sdk2.parameters.Group('Import parameters') as import_block:
            days = sdk2.parameters.Integer('Days', default_value=5)
            partner_code = sdk2.parameters.String('Partner\'s code', required=True, default='aeroflot')
            source = sdk2.parameters.String('Source', required=True, default='aeroflot')
            max_messages_per_launch = sdk2.parameters.Integer(
                'Max messages per launch',
                required=True,
                default=1000,
                description='Limits maximum number of messages that would be read from sqs queue per single launch',
            )
        with sdk2.parameters.Output:
            messages_received = sdk2.parameters.Integer(
                'Messages received',
                description='Total number of messages that were received from sqs queue',
            )
            messages_delete_errors_count = sdk2.parameters.Integer(
                'Messages failed to delete',
                description='Total number of messages that failed to be deleted from queue',
            )
            messages_parse_errors_count = sdk2.parameters.Integer(
                'Messages failed to parse',
                description='Total number of messages that failed to be parsed',
            )

        with sdk2.parameters.Group('SQS parameters') as sqs_block:
            sqs_endpoint = sdk2.parameters.String(
                'SQS endpoint url',
                required=True,
                default='http://sqs.yandex.net:8771',
            )
            sqs_access_key = sdk2.parameters.String(
                'SQS access key',
                required=True,
                default='avia',
            )
            sqs_queue = sdk2.parameters.String(
                'Aeroflot queue',
                required=True,
                default='aeroflot_order.fifo',
            )

    @staticmethod
    def _find_sqs_queue_url(sqs, queue_name):
        """Return the url of the SQS queue called ``queue_name``.

        ListQueues matches by name prefix, so a url ending with the exact queue name is preferred.
        Otherwise the first listed url is used.

        :raises RuntimeError: when no queue matches the prefix
        """
        queue_urls = sqs.list_queues(QueueNamePrefix=queue_name).get('QueueUrls') or []
        if not queue_urls:
            raise RuntimeError('SQS queue with name prefix {!r} not found'.format(queue_name))
        exact_matches = [url for url in queue_urls if url.rstrip('/').endswith('/' + queue_name)]
        return (exact_matches or queue_urls)[0]

    def on_execute(self):
        """Receive Aeroflot orders from SQS, write them to YT date by date and publish the counters.

        The dates processed are the ones found in the received messages.
        NOTE(review): ``Parameters.days`` is not read in this class.
        """
        import boto3
        boto3.set_stream_logger('boto3')
        sqs = boto3.client(
            'sqs',
            region_name='yandex',
            endpoint_url=self.Parameters.sqs_endpoint,
            aws_access_key_id=self.Parameters.sqs_access_key,
            aws_secret_access_key='',
        )
        queue_url = self._find_sqs_queue_url(sqs, self.Parameters.sqs_queue)

        marker_reader = AeroflotMarkerReader(
            logger=self._logger,
            geo_point_cache=self.geo_point_cache,
            sqs_client=sqs,
            sqs_queue_url=queue_url,
            partner=self._partner,
            max_messages_per_launch=self.Parameters.max_messages_per_launch,
        )
        marker_transfer = AeroflotMarkerTransfer(
            partner=self._partner,
            marker_writer=MarkerWriter(
                self.Parameters.source,
                self._logger,
                self.Parameters.yt_partner_booking_root,
                self._yt,
            ),
            marker_reader=marker_reader,
            logger=self._logger,
            sqs_client=sqs,
            sqs_queue_url=queue_url,
        )

        self._logger.info('Start: Transferring data in date range')
        try:
            marker_reader.receive_and_group()

            for report_date in sorted(marker_reader.date_groups):
                marker_transfer.transfer(report_date)
        finally:
            # Publish the counters even if the run fails midway, so partial progress stays visible.
            self.Parameters.messages_delete_errors_count = (
                marker_transfer.messages_delete_errors_count + marker_reader.messages_delete_errors_count
            )
            self.Parameters.messages_parse_errors_count = marker_reader.messages_parse_errors_count
            self.Parameters.messages_received = marker_reader.messages_received

        self._logger.info('Stop: Transferring data in date range')
