# coding=utf-8
from __future__ import unicode_literals

import logging

from sandbox import sdk2
from datetime import datetime, timedelta, date
from sandbox.sandboxsdk.environments import PipEnvironment
from sandbox.projects.avia.lib.yt_helpers import YtClientFactory, tables_for_daterange
from sandbox.projects.avia.lib.marker import MarkerWriter
from sandbox.projects.avia.lib.logs import configure_logging, get_sentry_dsn
from itertools import groupby


# One does not simply define reducer because it needs to import yt which is not loaded by sandbox until top-level
# classes are parsed.
def get_reducer(partner_codes):
    """
    Build the join-reduce callable that attaches redirect data to success-log rows.

    :param partner_codes: collection of partner codes whose redirects are accepted
    :return: the reducer wrapped with ``yt.with_context`` so that it receives the
        operation context (used here to read ``table_index``)
    """
    import yt.wrapper as yt

    def join_success_with_redirect(key, rows, context):
        # Last accepted redirect row seen for a key; rows from table 0 fill it,
        # all other rows consume it.
        redirect_by_key = {}
        for row in rows:
            is_accepted_redirect = context.table_index == 0 and row['partnerCode'] in partner_codes
            if is_accepted_redirect:
                redirect_by_key[key] = row
                continue

            if key not in redirect_by_key:
                # Nothing to attach this row to, drop it.
                continue

            matched_redirect = redirect_by_key[key]
            booking = {
                'unixtime': row['unixtime'],
                'source': matched_redirect['partnerCode'],
                'partner': matched_redirect['partnerCode'],
                'partner_id': matched_redirect['partnerId'],
                'billing_order_id': matched_redirect['billing_order_id'],
                'marker': row['marker'],
                'status': row['status'],
                # not actually a creation date but we decided it will suffice,
                # see https://st.yandex-team.ru/RASPTICKETS-17309 for details
                'created_at': row['iso_eventtime'],
                'order_id': row['order_id'],
                'currency': row['currency'],
                'order_price': row['order_price']
            }
            yield booking

    return yt.with_context(join_success_with_redirect)


def get_redirect_log_mapper(partner_codes):
    """
    Build a mapper that keeps redirect rows of the given partners only.

    :param partner_codes: collection of partner codes to keep
    :return: generator function taking a row and yielding at most one row that
        contains only the columns needed further down the pipeline
    """
    kept_columns = ('partnerCode', 'partnerId', 'billing_order_id', 'marker')

    def select_partner_redirect(row):
        if row['partnerCode'] not in partner_codes:
            return
        yield {column: row[column] for column in kept_columns}

    return select_partner_redirect


def get_success_log_mapper():
    """
    Build a mapper that projects success-log rows onto a fixed set of columns.

    Selecting the columns explicitly lets us avoid troubles with different schemas
    on different days.

    :return: generator function taking a row and yielding exactly one projected row
    """
    nullable_columns = ('unixtime', 'marker', 'status', 'iso_eventtime', 'order_id', 'currency')

    def project_row(row):
        # Missing columns are emitted as None.
        projected = {column: row.get(column) for column in nullable_columns}
        # The price is emitted only when present, and always as a string.
        price = row.get('order_price')
        if price is not None:
            projected['order_price'] = str(price)
        yield projected

    return project_row


class AviaSuccessBookLogMarker(sdk2.Task):
    """
    Sandbox task that joins the success-book log with the redirect log by marker
    and writes the matched orders to the partner booking log through MarkerWriter.
    """

    # Statuses accepted from the success log, see _validate_status.
    STATUS_CHOICES = {'booking', 'paid', 'cancel'}
    # The script was designed to work only with these partners.
    # If you want to add new ones, you should test and make all the necessary adjustments before.
    PARTNER_CHOICES = {'supersaver', 'gogate', 'trip_ru'}

    # Schema of the temporary table holding the filtered redirect rows.
    PROCESSED_REDIRECT_LOG_SCHEMA = [
        dict(name='marker', type='string'),
        dict(name='partnerCode', type='string'),
        dict(name='billing_order_id', type='int64'),
        dict(name='partnerId', type='int64'),
    ]

    # sort cannot infer schema because it accepts range of tables and they can have different schemas
    PROCESSED_SUCCESS_LOG_SCHEMA = [
        dict(name='marker', type='string'),
        dict(name='unixtime', type='int64'),
        dict(name='status', type='string'),
        dict(name='iso_eventtime', type='string'),
        dict(name='order_id', type='string'),
        dict(name='currency', type='string'),
        dict(name='order_price', type='string'),
    ]

    class Requirements(sdk2.Requirements):
        # configure this for your task, the more accurate - the better
        cores = 1  # exactly 1 core
        disk_space = 128  # 128 Megs or less
        ram = 128  # 128 Megs or less

        environments = (
            PipEnvironment('raven'),
            PipEnvironment('yandex-yt', version='0.10.8'),
            PipEnvironment('yandex-yt-yson-bindings-skynet', version='0.3.32-0'),
        )

        class Caches(sdk2.Requirements.Caches):
            pass  # means that task do not use any shared caches

    class Parameters(sdk2.Parameters):
        with sdk2.parameters.Group('Task parameters') as task_parameters:
            # Number of days back from today whose success logs are processed.
            days = sdk2.parameters.Integer(
                'Days',
                default_value=1,
                required=True,
            )
            # Checked against PARTNER_CHOICES in on_prepare.
            partner_codes = sdk2.parameters.List('Partner\'s codes', required=True)
            max_lag = sdk2.parameters.Integer(
                'Max lag',
                default_value=5,
                required=True,
                description='the maximum expected number of days between redirecting '
                            'client to partner and receiving booking confirmation of the '
                            'client',
            )

        with sdk2.parameters.Group('YT Settings') as yt_settings:
            vaults_owner = sdk2.parameters.String('Token vault owner', required=True)
            yt_vault_name = sdk2.parameters.String('YT Token vault name', required=True, default='YT_TOKEN')
            yt_proxy = sdk2.parameters.String('YT cluster', required=True, default='hahn')

        with sdk2.parameters.Group('Log directories') as paths:
            success_dir = sdk2.parameters.String(
                'Success dir',
                required=True,
                default='//home/logfeller/logs/avia-success-book-log/30min',
            )
            booking_dir = sdk2.parameters.String(
                'Booking dir',
                required=True,
                default='//home/avia/logs/avia-partner-booking-log',
            )
            redirect_dir = sdk2.parameters.String(
                'Redirect dir',
                required=True,
                default='//home/avia/logs/avia-json-redir-log',
            )

    def update_partner_booking_log(self):
        """
        Join success-log rows with redirect rows by marker and store the result.

        Success logs are taken for the last ``days`` days; redirect logs are taken for
        the same range extended ``max_lag`` days into the past, because the redirect
        precedes the booking confirmation.

        :raises Exception: if the success or the booking directory does not exist
        """
        # Read the current date once so that both ends of the range are derived from
        # the same day even when the task runs across midnight.
        date_to = date.today()
        date_from = date_to - timedelta(days=self.Parameters.days)

        # NOTE(review): redirect_dir is not checked for existence here - confirm
        # whether that is intentional.
        if not self.yt_client.exists(self.Parameters.success_dir):
            raise Exception('Directory {} does not exist'.format(self.Parameters.success_dir))
        if not self.yt_client.exists(self.Parameters.booking_dir):
            raise Exception('Directory {} does not exist'.format(self.Parameters.booking_dir))

        success_logs = tables_for_daterange(
            self.yt_client,
            self.Parameters.success_dir,
            date_from,
            date_to,
        )

        redirect_logs = tables_for_daterange(
            self.yt_client,
            self.Parameters.redirect_dir,
            date_from - timedelta(days=self.Parameters.max_lag),
            date_to,
        )

        with self.yt_client.Transaction():
            with self.yt_client.TempTable(attributes=dict(schema=self.PROCESSED_REDIRECT_LOG_SCHEMA)) as sorted_redirect_log,\
                    self.yt_client.TempTable(attributes=dict(schema=self.PROCESSED_SUCCESS_LOG_SCHEMA)) as sorted_success_log,\
                    self.yt_client.TempTable() as temp_booking_log:

                # The redirect table is passed to the join as the foreign table.
                foreign_redirect_log = '<foreign=%true>{}'.format(sorted_redirect_log)
                self.yt_client.run_map(get_redirect_log_mapper(self.Parameters.partner_codes), redirect_logs, sorted_redirect_log)
                self.yt_client.run_sort(sorted_redirect_log, foreign_redirect_log, sort_by=['marker'])

                self.yt_client.run_map(get_success_log_mapper(), success_logs, sorted_success_log)
                self.yt_client.run_sort(sorted_success_log, sort_by=['marker'])

                self.yt_client.run_join_reduce(
                    get_reducer(self.Parameters.partner_codes),
                    [foreign_redirect_log, sorted_success_log], temp_booking_log,
                    join_by=['marker']
                )

                # The temporary table is read while it still exists, i.e. inside the
                # TempTable context.
                self.move_data_from_tmp_to_booking_log(temp_booking_log)

    def move_data_from_tmp_to_booking_log(self, tmp):
        """
        Read the joined rows from the temporary table and pass them to the writer.

        :param tmp: path of the table produced by the join-reduce operation
        :raises ValueError: if a row has a status outside STATUS_CHOICES or a
            ``created_at`` value that does not match ``%Y-%m-%d %H:%M:%S``
        """
        # Imported locally because yt is not available at module import time;
        # aliased so that the builtin ``format`` is not shadowed.
        from yt.wrapper import format as yt_format
        rows = self.yt_client.read_table(tmp, format=yt_format.JsonFormat())
        parsed_rows = []
        for row in rows:
            parsed_row = MarkerWriter.Row(
                partner=row['partner'],
                partner_id=row['partner_id'],
                billing_order_id=row['billing_order_id'],
                order_id=row.get('order_id'),
                created_at=datetime.strptime(row['created_at'], "%Y-%m-%d %H:%M:%S"),
                price=row.get('order_price'),
                currency=row.get('currency'),
                status=AviaSuccessBookLogMarker._validate_status(row['status']),
                marker=row['marker'],
                trip_type=row.get('trip_type'),
                ticket_number=row.get('ticket_number'),
            )
            parsed_rows.append(parsed_row)
        self.write_data_to_booking_log(parsed_rows)

    def write_data_to_booking_log(self, rows):
        """
        Group orders by date and write to corresponding table

        :param rows: iterable of MarkerWriter.Row whose ``created_at`` is a datetime
        :raises Exception: any error raised while writing is logged and re-raised
        """

        try:
            # Sorting by the full timestamp also orders the rows by calendar day, which
            # is what groupby needs, and keeps rows of one day in chronological order.
            orders = sorted(rows, key=lambda order: order.created_at)
            # Group by calendar day, not by the exact timestamp: otherwise every
            # distinct second would get its own writer and its own write.
            # NOTE(review): write_to_yt now receives a date instead of a datetime -
            # confirm MarkerWriter accepts it.
            for order_date, date_orders in groupby(orders, lambda order: order.created_at.date()):
                writer = MarkerWriter('avia-success-book-log', self._logger, self.Parameters.booking_dir,
                                      self.yt_client)
                writer.add_rows(date_orders)
                writer.write_to_yt(order_date)
        except Exception:
            self._logger.exception('Error writing data in  {}'.format(self.Parameters.booking_dir))
            raise
        else:
            self._logger.info('Data transfer complete')

    @staticmethod
    def _validate_status(status):
        """
        Check that the status is a known one and return the status to store.

        :param status: status value taken from the success log
        :return: always ``'paid'`` for a known status
        :raises ValueError: if the status is not in STATUS_CHOICES
        """
        if status not in AviaSuccessBookLogMarker.STATUS_CHOICES:
            raise ValueError('Bad status choice {}'.format(status))
        # They send us only paid orders, but call the status 'booking'
        # https://st.yandex-team.ru/RASPTICKETS-17047#5e27f3ca1a94137cc8bcbcbb
        return 'paid'

    def on_prepare(self):
        """
        Validate the partner list and set up logging and the YT client.

        :raises ValueError: if partner_codes contains a partner outside PARTNER_CHOICES
        """
        unknown_partners = set(self.Parameters.partner_codes) - self.PARTNER_CHOICES
        if unknown_partners:
            raise ValueError('The script was not designed to work with following partners: {}'.format(unknown_partners))
        configure_logging(
            sentry_dsn=get_sentry_dsn(self)
        )
        self._logger = logging.getLogger(__name__)
        # The YT token is read from the vault configured in the task parameters.
        self.yt_client = YtClientFactory().create(
            proxy=self.Parameters.yt_proxy,
            token=sdk2.Vault.data(self.Parameters.vaults_owner, self.Parameters.yt_vault_name),
        )

    def on_execute(self):
        """Task entry point: rebuild the partner booking log."""
        self.update_partner_booking_log()
