# coding: utf-8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging
import unicodecsv
from datetime import datetime
from decimal import Decimal, ROUND_HALF_UP

from bson import ObjectId

from common.data_api.billing.trust_client import TrustClient
from travel.rasp.train_api.train_purchase.core.models import TrainOrder
from travel.rasp.train_api.train_purchase.utils.order import should_refund_yandex_fee

# Root logger: messages go wherever the hosting process configured root logging.
logger = logging.getLogger()
# Marker written at the start of TrainOrder.admin_comment by reset_refund_yandex_fee_amount;
# it is followed by "<blank_id>:<old refund_yandex_fee_amount>," entries.
ADMIN_COMMENT_PREFIX = 'invalid_refund_yandex_fee_amount:'
# Number of orders fetched per page when paging through TrainOrder by _id.
CHUNK_SIZE_IN_ROWS = 5
# Paging starts after ObjectId.from_datetime(START_AT), i.e. older orders are never scanned.
START_AT = datetime(2019, 9, 1)
# Quantization exponent: money amounts are rounded to 2 decimal places.
EXP = Decimal('.01')
# Fallback for a ticket payment's partner_fee / partner_refund_fee when the value is not stored.
PARTNER_FEE_DEFAULT = '31.7'


def _get_resize_amounts_by_order_id(purchase_token):
    """Ask Trust for the payment behind ``purchase_token`` and diff its orders.

    :param purchase_token: token of the billing payment to inspect.
    :return: dict mapping each Trust ``order_id`` to
        ``Decimal(orig_amount) - Decimal(paid_amount)`` — presumably the part of
        the order that was removed by a resize (inferred from the name; confirm).
    """
    payment_info = TrustClient().get_raw_payment_info(purchase_token)

    amounts = {}
    for trust_order in payment_info['orders']:
        original = Decimal(trust_order['orig_amount'])
        paid = Decimal(trust_order['paid_amount'])
        amounts[trust_order['order_id']] = original - paid
    return amounts


def reset_refund_yandex_fee_amount():
    """Overwrite stored refund_yandex_fee_amount values with the amounts reported by Trust.

    Pages through TrainOrder documents sorted by ``_id`` (CHUNK_SIZE_IN_ROWS per page,
    starting after ``ObjectId.from_datetime(START_AT)``). Only orders that have no
    ``admin_comment`` and at least one ticket with ``refund.refund_yandex_fee_amount > 0``
    are selected.

    For every passenger's first ticket that has a truthy refund_yandex_fee_amount, the
    stored value is compared with the Trust resize amount for ``ticket.payment.fee_order_id``.
    On mismatch the ticket gets the Trust value and the old value is appended to
    ``order.admin_comment`` as ``"<blank_id>:<old amount>,"`` after ADMIN_COMMENT_PREFIX.
    Every fetched order is saved.

    Raises KeyError if Trust returns no order for a ticket's fee_order_id.
    """
    cursor_id = ObjectId.from_datetime(START_AT)
    processed_count = 0
    batch_len = CHUNK_SIZE_IN_ROWS

    # A page shorter than CHUNK_SIZE_IN_ROWS means there is nothing left to read.
    while batch_len == CHUNK_SIZE_IN_ROWS:
        pipeline = [
            {'$match': {
                '_id': {'$gt': cursor_id},
                'passengers.tickets.refund.refund_yandex_fee_amount': {'$gt': 0},
                # This function writes admin_comment on mismatch, so annotated orders drop out.
                'admin_comment': None,
            }},
            {'$sort': {'_id': 1}},
            {'$limit': CHUNK_SIZE_IN_ROWS},
        ]
        batch = list(TrainOrder.objects.aggregate(*pipeline))
        batch_len = len(batch)
        logger.info('Next batch: {} rows'.format(batch_len))

        for raw_order in batch:
            uid = raw_order['uid']
            order = TrainOrder.objects.get(uid=uid)
            logger.info('Processing {}, passengers: {}'.format(uid, len(order.passengers)))

            token = order.current_billing_payment.purchase_token
            trust_amounts = _get_resize_amounts_by_order_id(token)

            for passenger in order.passengers:
                ticket = passenger.tickets[0]
                blank_id = ticket.blank_id
                stored_amount = ticket.refund.refund_yandex_fee_amount if ticket.refund else None

                if not stored_amount:
                    logger.info('Skipping blank {} due missing refund_yandex_fee_amount'.format(blank_id))
                    continue

                trust_amount = trust_amounts[ticket.payment.fee_order_id]

                if stored_amount == trust_amount:
                    logger.info('{}, blank {}: refund_amount match {}'.format(
                        uid, blank_id, stored_amount))
                    continue

                logger.info('{}, blank {}: refund_amount {}, real {}'.format(
                    uid, blank_id, stored_amount, trust_amount))
                ticket.refund.refund_yandex_fee_amount = trust_amount
                # Keep the wrong per-blank resize values so the losses can be computed later.
                comment_head = order.admin_comment or ADMIN_COMMENT_PREFIX
                order.admin_comment = comment_head + '{}:{},'.format(blank_id, stored_amount)

            order.save()
            cursor_id = order.id
        processed_count += batch_len
    logger.info('Total processed: {}'.format(processed_count))


def _read_partner_fee(payment, key):
    """Return ``payment[key]`` as a Decimal rounded to EXP with ROUND_HALF_UP.

    PARTNER_FEE_DEFAULT is used when the key is absent or stored as null
    (``Decimal(None)`` would raise TypeError).
    """
    value = payment.get(key)
    if value is None:
        value = PARTNER_FEE_DEFAULT
    return Decimal(value).quantize(EXP, ROUND_HALF_UP)


def _write_invalid_refund_yandex_fee(writer):
    """Write CSV rows for yandex fees that were refunded although they should not have been.

    Pages through TrainOrder documents sorted by ``_id`` (CHUNK_SIZE_IN_ROWS per page,
    starting after ``ObjectId.from_datetime(START_AT)``) that have at least one ticket
    with a non-null ``refund.refund_yandex_fee``. For each passenger only the first ticket
    is examined:

    * if ``should_refund_yandex_fee(order, refund.created_at)`` is false, the whole
      refunded amount is a loss, written with comment ``'full'``;
    * otherwise the loss is ``min(refunded amount, partner_fee + partner_refund_fee)``,
      written with comment ``'partial'``.

    :param writer: csv-style writer; rows are ``[order_uid, blank_id, loss, comment]``.
    :return: Decimal sum of all losses written.
    """
    total_loss = Decimal(0)
    total_processed = 0
    selected_len = CHUNK_SIZE_IN_ROWS
    last_processed_object_id = ObjectId.from_datetime(START_AT)

    while selected_len == CHUNK_SIZE_IN_ROWS:
        plain_orders = list(TrainOrder.objects.aggregate(*[
            {'$match': {
                '_id': {'$gt': last_processed_object_id},
                # "Any ticket has a non-null refund_yandex_fee". A dotted-path
                # {'$ne': None} across the passengers/tickets arrays would instead
                # require EVERY ticket to have one (on arrays $ne means "no element
                # equals null", and a missing field counts as null), silently
                # dropping partially refunded orders.
                'passengers': {'$elemMatch': {
                    'tickets': {'$elemMatch': {
                        'refund.refund_yandex_fee': {'$ne': None},
                    }},
                }},
            }},
            {'$sort': {'_id': 1}},
            {'$limit': CHUNK_SIZE_IN_ROWS},
        ]))

        selected_len = len(plain_orders)
        logger.info('Next batch: {} rows'.format(selected_len))

        for plain_order in plain_orders:
            order_uid = plain_order['uid']
            plain_passengers = plain_order['passengers']
            logger.info('Processing {}, passengers: {}'.format(order_uid, len(plain_passengers)))

            order = TrainOrder.objects.get(uid=order_uid)
            refund_by_blank_id = {blank_id: refund for refund in order.iter_refunds() for blank_id in refund.blank_ids}

            for plain_passenger in plain_passengers:
                plain_ticket = plain_passenger['tickets'][0]
                blank_id = plain_ticket['blank_id']
                # Other passengers of the same order may not be refunded at all
                # (refund missing or null), so never call .get() on None.
                plain_refund = plain_ticket.get('refund') or {}
                plain_refund_yandex_fee = plain_refund.get('refund_yandex_fee')

                if not plain_refund_yandex_fee:
                    logger.info('Skipping blank {} due missing refund_yandex_fee'.format(blank_id))
                    continue

                refund = refund_by_blank_id.get(blank_id)
                if refund is None:
                    logger.info('Skipping blank {} due missing refund'.format(blank_id))
                    continue

                refund_amount = Decimal(plain_refund_yandex_fee['amount']).quantize(EXP, ROUND_HALF_UP)
                if not should_refund_yandex_fee(order, refund.created_at):
                    loss = refund_amount
                    logger.info('Blank {}: invalid full refund {}'.format(blank_id, loss))
                    writer.writerow([order_uid, blank_id, loss, 'full'])
                else:
                    payment = plain_ticket['payment']
                    partner_fee = _read_partner_fee(payment, 'partner_fee')
                    partner_refund_fee = _read_partner_fee(payment, 'partner_refund_fee')
                    loss = refund_amount.min(partner_fee + partner_refund_fee)
                    logger.info('Blank {}: invalid partial refund {}'.format(blank_id, loss))
                    writer.writerow([order_uid, blank_id, loss, 'partial'])

                total_loss += loss
            last_processed_object_id = plain_order['_id']
        total_processed += selected_len
    logger.info('Total processed: {}, loss: {}'.format(total_processed, total_loss))
    return total_loss


def _parse_invalid_resizes(admin_comment):
    """Parse the ``"<blank_id>:<amount>,"`` entries that follow ADMIN_COMMENT_PREFIX.

    Each entry is split on its LAST ':' so the amount is always the tail. Entries
    without ':' are logged and skipped.

    :return: dict mapping blank_id (string) to Decimal(amount).
    """
    invalid_resizes = {}
    for entry in admin_comment[len(ADMIN_COMMENT_PREFIX):].split(','):
        if not entry:
            continue
        blank_id, separator, amount = entry.rpartition(':')
        if not separator:
            logger.warning('Skipping malformed admin_comment entry {!r}'.format(entry))
            continue
        invalid_resizes[blank_id] = Decimal(amount)
    return invalid_resizes


def _write_invalid_resize_yandex_fee(writer):
    """Write CSV rows for the resize amounts recorded in admin_comment by reset_refund_yandex_fee_amount.

    Pages through TrainOrder documents sorted by id (CHUNK_SIZE_IN_ROWS per page,
    starting after ``ObjectId.from_datetime(START_AT)``) whose admin_comment starts with
    ADMIN_COMMENT_PREFIX and that have a ticket with refund_yandex_fee_amount > 0.
    For each passenger's first ticket, the loss is the current refund_yandex_fee_amount
    minus the old amount recorded for its blank in admin_comment. Non-zero losses are
    written as ``[order_uid, blank_id, loss, 'resize']``.

    :param writer: csv-style writer with a ``writerow()`` method.
    :return: Decimal sum of the losses written.
    """
    total_loss = Decimal(0)
    total_processed = 0
    selected_len = CHUNK_SIZE_IN_ROWS
    last_processed_object_id = ObjectId.from_datetime(START_AT)

    while selected_len == CHUNK_SIZE_IN_ROWS:
        # Keyset pagination on id is only correct when every page is sorted by id;
        # without order_by the page order is unspecified, so orders could be skipped
        # or counted twice. list() runs the query once per page.
        orders = list(TrainOrder.objects.filter(
            id__gt=last_processed_object_id,
            admin_comment__startswith=ADMIN_COMMENT_PREFIX,
            passengers__tickets__refund__refund_yandex_fee_amount__gt=Decimal(0),
        ).order_by('id')[:CHUNK_SIZE_IN_ROWS])

        selected_len = len(orders)
        logger.info('Next batch: {} rows'.format(selected_len))

        for order in orders:
            order_uid = order.uid
            logger.info('Processing {}, passengers: {}'.format(order_uid, len(order.passengers)))
            invalid_resizes_by_blank_id = _parse_invalid_resizes(order.admin_comment)
            logger.info('invalid_refunds: {}'.format(invalid_resizes_by_blank_id))

            for passenger in order.passengers:
                ticket = passenger.tickets[0]
                blank_id = ticket.blank_id

                refund_yandex_fee_amount = ticket.refund.refund_yandex_fee_amount if ticket.refund else None

                if not refund_yandex_fee_amount:
                    logger.info('Skipping blank {} due missing refund_yandex_fee_amount'.format(blank_id))
                    continue

                invalid_resize_yandex_fee_amount = invalid_resizes_by_blank_id.get(blank_id)
                # A recorded amount of 0 is still a valid record, so test for absence explicitly.
                if invalid_resize_yandex_fee_amount is None:
                    logger.info('Skipping blank {} due missing invalid_resize_yandex_fee_amount'.format(blank_id))
                    continue

                loss = refund_yandex_fee_amount - invalid_resize_yandex_fee_amount
                if not loss:
                    logger.info('Skipping blank {}: it is correct'.format(blank_id))
                else:
                    logger.info('{}: blank {}, loss {}'.format(order_uid, blank_id, loss))
                    writer.writerow([order_uid, blank_id, loss, 'resize'])
                    total_loss += loss

            last_processed_object_id = order.id
        total_processed += selected_len
    logger.info('Total processed: {}, loss: {}'.format(total_processed, total_loss))
    return total_loss


def main(output_path='/tmp/losses.csv'):
    """Write every detected yandex-fee loss to a CSV report and log the grand total.

    The report has the header ``order_uid, blank_id, loss, comment``. It contains the
    rows from _write_invalid_refund_yandex_fee followed by the rows from
    _write_invalid_resize_yandex_fee.

    :param output_path: path of the CSV file to (over)write; defaults to /tmp/losses.csv.
    """
    # unicodecsv writes already-encoded bytes, so the file must be opened in binary
    # mode. The with-block closes it even if one of the passes raises.
    with open(output_path, 'wb') as output:
        writer = unicodecsv.writer(output, encoding='utf-8', dialect='excel', lineterminator='\n')
        writer.writerow(['order_uid', 'blank_id', 'loss', 'comment'])

        total_loss = Decimal(0)
        total_loss += _write_invalid_refund_yandex_fee(writer)
        total_loss += _write_invalid_resize_yandex_fee(writer)

    logger.info('Total loss {}'.format(total_loss))
