import collections
import decimal
import itertools
import logging

from django.db import transaction
from django.db.models import Q
from django.utils import timezone

import cars.settings
from cars.billing.core.payment_processor import PaymentProcessor
from cars.billing.iface.payment import IPayment
from cars.billing.models.bonus_payment import BonusPayment
from cars.billing.models.card_payment import CardPayment
from ..core.push_client import PUSH_CLIENT
from ..models.order import Order
from ..models.order_item_payment import OrderItemPayment
from ..models.order_payment_method import OrderPaymentMethod
from ..models.order_item_tariff import OrderItemTariff
from .order_updater import OrderUpdater


LOGGER = logging.getLogger(__name__)


class OrderPaymentProcessor:
    """
    Computes order costs and charges users for orders.

    While an order is ongoing it is charged in chunks once enough unpaid cost
    has accumulated; after the order is completed the remaining cost is charged.
    Each charge first tries a bonus payment, then reuses unused amounts of the
    order's existing card payments, and finally creates a new card payment.
    """

    class Error(Exception):
        """Processor-level error; raised e.g. for invalid refund requests."""

    class PaymentMethodMissingError(Exception):
        """Raised when an order has to be paid but no payment method is known."""

    def __init__(self, payment_processor, chunked_payment_amount, push_client):
        """
        :param payment_processor: PaymentProcessor used to create, resize,
            process, refund and wait for bonus/card payments.
        :param chunked_payment_amount: default unpaid amount an ongoing order has
            to accumulate before it gets charged (see _get_chunked_payment_amount).
        :param push_client: push client handed to OrderUpdater on status updates.
        """
        self._payment_processor = payment_processor
        self._chunked_payment_amount = chunked_payment_amount
        self._push_client = push_client

    @classmethod
    def from_settings(cls):
        """Build a processor from ``cars.settings.ORDERS`` and the shared PUSH_CLIENT."""
        return cls(
            payment_processor=PaymentProcessor.from_settings(push_client=PUSH_CLIENT),
            chunked_payment_amount=cars.settings.ORDERS['chunked_payment_amount'],
            push_client=PUSH_CLIENT,
        )

    def get_order_payment_summary(self, order):
        """
        Sum the per-item payment summaries over all items of ``order``.

        :returns: PaymentSummary with total cost, bonus amount and card amount.
        """
        cost = bonus_amount = card_amount = decimal.Decimal(0)
        for item in order.items.all():
            item_payment_summary = self.get_order_item_payment_summary(item)
            cost += item_payment_summary.cost
            bonus_amount += item_payment_summary.bonus_amount
            card_amount += item_payment_summary.card_amount

        return PaymentSummary(
            cost=cost,
            bonus_amount=bonus_amount,
            card_amount=card_amount,
        )

    def get_order_cost(self, order):
        """
        Return the total cost of all items of ``order``.

        The sum starts from ``Decimal(0)``, so an order without items yields a
        Decimal zero rather than the int ``0``.
        """
        return sum(
            (self.get_order_item_cost(item) for item in order.items.all()),
            decimal.Decimal(0),
        )

    def get_order_payment_progress(self, order):
        """
        Summarize how much of ``order``'s cost is paid, in progress or failed.

        Payments with SUCCESS or REFUNDED generic status count as paid.
        ``errors_amount`` is the cost not covered by paid and in-progress
        payments. It is reported only if at least one payment failed.
        ``has_errors`` is cleared when that amount is not positive.

        :returns: OrderPaymentProgress.
        :raises RuntimeError: on an unexpected generic payment status.
        """
        ok_statuses = {
            IPayment.GenericStatus.REFUNDED,
            IPayment.GenericStatus.SUCCESS,
        }

        order_cost = self.get_order_cost(order)

        has_errors = False
        paid_amount = decimal.Decimal(0)
        in_progress_amount = decimal.Decimal(0)

        for item in order.get_sorted_items():
            for payment in item.payments.all():
                payment_impl = payment.get_impl()
                payment_status = payment_impl.get_generic_status()
                if payment_status in ok_statuses:
                    paid_amount += payment.amount
                elif payment_status is IPayment.GenericStatus.IN_PROGRESS:
                    in_progress_amount += payment.amount
                elif payment_status is IPayment.GenericStatus.ERROR:
                    has_errors = True
                else:
                    raise RuntimeError('unreachable: {}'.format(payment_status))

        errors_amount = order_cost - paid_amount - in_progress_amount

        # Zero out errors amount if there were no errors.
        # It happens because of intermediate payments.
        if not has_errors:
            errors_amount = decimal.Decimal(0)

        # Turn off has_errors flags if all errors have been compensated.
        if errors_amount <= 0:
            has_errors = False

        return OrderPaymentProgress(
            cost=order_cost,
            has_errors=has_errors,
            in_progress_amount=in_progress_amount,
            paid_amount=paid_amount,
            errors_amount=errors_amount,
        )

    def get_order_item_payment_summary(self, order_item):
        """
        Return the cost of a single order item and its bonus/card split.

        Only bonus payments are summed explicitly. The card amount is derived
        as ``cost - bonus_amount``. Payments with an unknown method are logged
        and ignored.

        :returns: PaymentSummary.
        """
        cost = self.get_order_item_cost(order_item)

        bonus_amount = decimal.Decimal(0)
        for payment in order_item.payments.all():
            payment_method = payment.get_payment_method()
            if payment_method is OrderItemPayment.PaymentMethod.BONUS:
                bonus_amount += payment.amount
            elif payment_method is not OrderItemPayment.PaymentMethod.CARD:
                LOGGER.error('unknown payment method: %s', payment_method)

        # There may be multiple not_authorized card payments which screws the summary.
        # With only two payment methods card amount is the cost except for bonus payments.
        card_amount = cost - bonus_amount

        return PaymentSummary(
            cost=cost,
            bonus_amount=bonus_amount,
            card_amount=card_amount,
        )

    def get_order_item_cost(self, item):
        """
        Return the cost of ``item`` rounded to two decimal places.

        Rules:

        - Items without a tariff cost 0.
        - A FIX tariff is charged only once the item is finished.
        - A PER_MINUTE tariff is charged for the elapsed time so far.

        :raises RuntimeError: on an unknown tariff type.
        """
        if item.tariff is None:
            return decimal.Decimal(0)

        tariff_type = item.tariff.get_type()

        if tariff_type is OrderItemTariff.Type.FIX:
            # Assume FIX tariff activates only after the related item is finished.
            if item.finished_at is None:
                cost = decimal.Decimal(0)
            else:
                cost = self._get_fix_cost(item)
        elif tariff_type is OrderItemTariff.Type.PER_MINUTE:
            cost = self._get_per_minute_cost(item)
        else:
            raise RuntimeError('unreachable: {}'.format(tariff_type))

        # Cost should be rounded to two digits to match DB field precision.
        return round(cost, 2)

    def _get_fix_cost(self, item):
        """Return the fixed cost configured in the item's FIX tariff."""
        return item.tariff.fix_params.cost

    def _get_per_minute_cost(self, item):
        """
        Return ``cost_per_minute`` times the elapsed whole seconds divided by 60.

        Unfinished items are measured up to the current time.
        """
        finished_at = item.finished_at or timezone.now()
        duration_seconds = int((finished_at - item.started_at).total_seconds())
        return item.tariff.per_minute_params.cost_per_minute * duration_seconds / 60

    def make_payments_for_unpaid_orders(self):
        """
        Run make_payments for every order with a payment method and NEW payment status.

        A failure for one order is logged and does not stop processing of the others.
        """
        orders = (
            Order.objects
            .with_related()
            .select_related(
                'user__bonus_account',
            )
            .filter(
                payment_method__isnull=False,
                payment_status=Order.PaymentStatus.NEW.value,
            )
        )

        for order in orders:
            try:
                self.make_payments(order)
            except Exception:
                LOGGER.exception('failed to make payments for order %s', order.id)

    def make_payments(self, order, payment_method=None):
        """
        Charge the unpaid part of ``order``'s cost.

        For each item, the unpaid amount is its cost minus its payments. Error
        payments are not counted once the order is completed.

        While the order is ongoing, nothing is charged until the total unpaid
        amount reaches the chunk amount. A total just above the chunk (by less
        than a tenth of it) is rounded down to the chunk.

        :param order: order to charge.
        :param payment_method: OrderPaymentMethod to use; defaults to
            ``order.payment_method``.
        :returns: list of created OrderItemPayment objects (empty if nothing was charged).
        :raises PaymentMethodMissingError: if no payment method is available.
        """
        if payment_method is None:
            payment_method = order.payment_method
        if payment_method is None:
            raise self.PaymentMethodMissingError

        items = order.get_sorted_items()
        item_payments = list(
            OrderItemPayment.objects
            .select_related(
                'bonus_payment',
                'card_payment',
            )
            .filter(order_item__in=items)
        )

        # Group payments related to each item.
        payments_per_item = collections.defaultdict(list)
        for item_payment in item_payments:
            payments_per_item[item_payment.order_item_id].append(item_payment)

        item_payment_specs = []
        for item in items:
            current_payments_amount = decimal.Decimal(0)
            for payment in payments_per_item[item.id]:
                payment_impl = payment.get_impl()
                if payment_impl.is_error() and order.completed_at is not None:
                    # Count error payments only for ongoing orders to correctly track chunks.
                    continue
                current_payments_amount += payment.amount

            item_cost = self.get_order_item_cost(item)

            unprocessed_amount = item_cost - current_payments_amount
            if unprocessed_amount < 0:
                LOGGER.error(
                    'order item %s billed over its cost for %s',
                    item.id,
                    abs(unprocessed_amount),
                )
                continue

            if unprocessed_amount == 0:
                continue

            item_payment_specs.append(
                OrderItemPaymentSpec(order_item=item, amount=unprocessed_amount),
            )

        total_unprocessed_amount = sum(x.amount for x in item_payment_specs)
        if total_unprocessed_amount == 0:
            return []

        chunked_payment_amount = self._get_chunked_payment_amount(order)

        if order.completed_at is None and total_unprocessed_amount < chunked_payment_amount:
            # Wait for order items to accumulate enough cost.
            return []

        # Round payment amount down to the nearest chunk
        # if it's not the last payment and the amount is close to chunk boundary.
        if order.completed_at is None:
            unprocessed_over_chunk_amount = total_unprocessed_amount - chunked_payment_amount
            if unprocessed_over_chunk_amount < chunked_payment_amount / 10:
                LOGGER.info(
                    'rounding payment for order %s from %s to %s',
                    order.id,
                    total_unprocessed_amount,
                    total_unprocessed_amount - unprocessed_over_chunk_amount,
                )
                item_payment_specs = self._deduct_from_item_specs(
                    item_payment_specs,
                    unprocessed_over_chunk_amount,
                )

        order_payment_spec = OrderPaymentSpec(
            order=order,
            payment_method=payment_method,
            item_specs=item_payment_specs,
        )

        return self.make_payments_by_spec(order_payment_spec)

    def _deduct_from_item_specs(self, item_specs, amount):
        """
        Reduce the total amount of ``item_specs`` by ``amount``.

        How the amount is taken:

        - Preferably, the whole ``amount`` is taken from the last spec that can
          absorb it on its own.
        - If no such spec exists, it is taken from the specs starting from the
          end of the list.

        Specs are mutated in place. Specs left with a zero amount are dropped,
        so no zero-amount item payments get created for them.

        :param item_specs: list of objects with a mutable ``amount`` attribute.
        :param amount: amount to deduct; must not exceed the specs' total.
        :returns: list of specs with a positive amount, in their original order.
        :raises ValueError: if ``amount`` exceeds the total amount of ``item_specs``.
        """
        total_amount = sum(item_spec.amount for item_spec in item_specs)
        if amount > total_amount:
            # Check up front so the specs are left untouched on failure.
            raise ValueError(
                'cannot deduct {} from item specs totaling {}'.format(amount, total_amount),
            )

        for item_spec in reversed(item_specs):
            if item_spec.amount >= amount:
                item_spec.amount -= amount
                break
        else:
            # No single spec is large enough: spread the deduction from the end.
            remaining = amount
            for item_spec in reversed(item_specs):
                if remaining <= 0:
                    break
                deducted = min(item_spec.amount, remaining)
                item_spec.amount -= deducted
                remaining -= deducted

        return [item_spec for item_spec in item_specs if item_spec.amount > 0]

    def make_payments_by_spec(self, order_payment_spec):
        """
        Create payments for ``order_payment_spec`` and allocate them to order items.

        Runs inside one ``transaction.atomic()`` block. It obtains payment
        usages (bonus, reused card, new card) and splits them across the item
        specs, creating one OrderItemPayment per allocation.

        Item specs are consumed from the end of the list. A spec larger than
        the payment's remaining amount is split, and the remainder is covered
        by the next payment.

        :returns: list of created OrderItemPayment objects.
        :raises RuntimeError: if a payment is neither a bonus nor a card payment.
        """
        payments = []
        item_specs_to_satisfy = order_payment_spec.item_specs.copy()

        with transaction.atomic():
            item_payments = []
            payment_usages = self._get_or_create_payments_for_spec(order_payment_spec)

            for payment_usage in payment_usages:
                payment_amount_available = payment_usage.amount_available
                payment = payment_usage.payment
                payments.append(payment)

                if isinstance(payment, BonusPayment):
                    params = {
                        'payment_method': OrderItemPayment.PaymentMethod.BONUS.value,
                        'bonus_payment': payment,
                    }
                elif isinstance(payment, CardPayment):
                    params = {
                        'payment_method': OrderItemPayment.PaymentMethod.CARD.value,
                        'card_payment': payment,
                    }
                else:
                    raise RuntimeError('unreachable: {}'.format(type(payment)))

                while payment_amount_available > 0 and item_specs_to_satisfy:
                    item_spec = item_specs_to_satisfy.pop()
                    if item_spec.amount > payment_amount_available:
                        # Split: pay what this payment can cover now and push the
                        # rest back so the next payment picks it up first.
                        item_specs_to_satisfy.append(
                            OrderItemPaymentSpec(
                                order_item=item_spec.order_item,
                                amount=item_spec.amount - payment_amount_available,
                            )
                        )
                        item_spec = OrderItemPaymentSpec(
                            order_item=item_spec.order_item,
                            amount=payment_amount_available,
                        )

                    item_payment = OrderItemPayment.objects.create(
                        order_item=item_spec.order_item,
                        amount=item_spec.amount,
                        created_at=timezone.now(),
                        **params
                    )
                    item_payments.append(item_payment)

                    payment_amount_available -= item_payment.amount
                    assert payment_amount_available >= 0

            payments_amount = sum(p.get_amount() for p in payments)
            item_payments_amount = sum(p.amount for p in item_payments)

            assert payments_amount >= item_payments_amount
            if payments_amount > item_payments_amount:
                LOGGER.info(
                    'order %s has underused payments for %s',
                    order_payment_spec.order.id,
                    payments_amount - item_payments_amount,
                )

        return item_payments

    def _get_or_create_payments_for_spec(self, order_payment_spec):
        """
        Obtain payments covering ``order_payment_spec.get_payment_amount()``.

        Sources are used in this order:

        1. A bonus payment of up to the needed amount, if bonus is allowed.
        2. Unused parts of the order's existing card payments, if card is allowed.
        3. A new card payment for the rest, if card is allowed.

        Only CARD order payment methods are supported.

        :returns: list of PaymentUsage with the amount usable from each payment.
        """
        assert order_payment_spec.payment_method.get_type() is OrderPaymentMethod.Type.CARD

        existing_underused_payments = [
            payment_usage
            for payment_usage in self._get_payment_usages(order=order_payment_spec.order)
            if payment_usage.amount_available > 0
        ]
        for payment_usage in existing_underused_payments:
            # Only card payments can be underused in current implementation.
            assert isinstance(payment_usage.payment, CardPayment)

        result = []
        amount = order_payment_spec.get_payment_amount()

        if amount > 0 and order_payment_spec.is_bonus_allowed():
            bonus_payment = self._make_bonus_payment_for_amount(
                order=order_payment_spec.order,
                amount=amount,
            )
            if bonus_payment is not None:
                result.append(
                    PaymentUsage(
                        payment=bonus_payment,
                        amount_available=bonus_payment.get_amount(),
                    )
                )
                amount -= bonus_payment.get_amount()

        while amount > 0 and existing_underused_payments and order_payment_spec.is_card_allowed():
            payment_usage = existing_underused_payments.pop()
            amount_available = min(amount, payment_usage.amount_available)
            result.append(
                PaymentUsage(
                    payment=payment_usage.payment,
                    amount_available=amount_available,
                )
            )
            amount -= amount_available

        if amount > 0 and order_payment_spec.is_card_allowed():
            card_payment = self._payment_processor.make_card_payment(
                user=order_payment_spec.order.user,
                amount=amount,
                paymethod_id=order_payment_spec.payment_method.card_paymethod_id,
            )
            result.append(
                PaymentUsage(
                    payment=card_payment,
                    amount_available=card_payment.get_amount(),
                )
            )
            amount -= card_payment.get_amount()

        assert amount == 0

        return result

    def _get_payment_usages(self, order):
        """
        Return how much of each payment attached to ``order``'s items is unused.

        Error payments are reported with zero available amount.

        :returns: list of PaymentUsage, one per distinct payment implementation.
        """
        amount_used_per_payment = collections.defaultdict(decimal.Decimal)
        for order_item in order.get_sorted_items():
            for order_item_payment in order_item.payments.all():
                payment = order_item_payment.get_impl()
                amount_used = decimal.Decimal(0) if payment.is_error() else order_item_payment.amount
                amount_used_per_payment[payment] += amount_used

        payment_usages = []
        for payment, amount_used in amount_used_per_payment.items():
            if payment.is_error():
                amount_available = decimal.Decimal(0)
            else:
                amount_available = payment.amount - amount_used
            payment_usages.append(
                PaymentUsage(
                    payment=payment,
                    amount_available=amount_available,
                )
            )

        return payment_usages

    def _make_bonus_payment_for_amount(self, order, amount):
        """
        Try to pay up to ``amount`` for ``order.user`` with a bonus payment.

        :returns: the bonus payment, or possibly None (callers check for it).
        """
        return self._payment_processor.try_make_bonus_payment(
            user=order.user,
            max_amount=amount,
        )

    def finalize_all_orders_payment_statuses(self):
        """
        Finalize payment statuses of orders that may need it.

        Candidates are:

        - completed orders still in NEW payment status;
        - ERROR-status orders that have authorized/cleared card item payments
          or bonus item payments. These are taken from the 16384 most recent
          matching item payments.

        A failure for one order is logged and does not stop the others.
        """
        new_orders = (
            Order.objects
            .with_payments()
            .filter(
                completed_at__isnull=False,
                payment_status=Order.PaymentStatus.NEW.value,
            )
        )

        maybe_paid_error_orders_ids = (
            OrderItemPayment.objects
            .filter(
                Q(
                    payment_method=OrderItemPayment.PaymentMethod.CARD.value,
                    card_payment__status__in=[
                        CardPayment.Status.AUTHORIZED.value,
                        CardPayment.Status.CLEARED.value,
                    ],
                )
                |
                Q(
                    payment_method=OrderItemPayment.PaymentMethod.BONUS.value,
                ),
                order_item__order__payment_status=Order.PaymentStatus.ERROR.value,
            )
            .order_by('-order_item__order__created_at')
            .values('order_item__order_id')
            [:16384]  # Sanity limit.
        )
        maybe_paid_error_orders = (
            Order.objects
            .with_payments()
            .filter(id__in=maybe_paid_error_orders_ids)
        )

        for order in itertools.chain(new_orders, maybe_paid_error_orders):
            try:
                self.finalize_order_payment_status(order)
            except Exception:
                LOGGER.exception('failed to finalize payment status for order %s', order.id)

    def finalize_order_payment_status(self, order):
        """
        Settle payments of a completed ``order`` and update its payment status.

        Does nothing (besides logging) while any item payment is still in
        progress. Otherwise, it shrinks underused successful payments and
        recomputes the order payment status.
        """
        assert order.completed_at is not None

        for item in order.get_sorted_items():
            for payment in item.payments.all():
                if payment.get_impl().get_generic_status() is IPayment.GenericStatus.IN_PROGRESS:
                    LOGGER.info(
                        'cannot finalize order payments while some payments are in progress: %s',
                        order.id,
                    )
                    return

        self._resize_underused_payments(order)
        self._update_order_payment_status(order)

    def _get_chunked_payment_amount(self, order):
        """
        Return the unpaid amount an ongoing ``order`` must accumulate before charging.

        Orders whose first sorted item's car is made by Porsche use a fixed
        911.00. All others use the configured chunked payment amount.
        Assumes the order has at least one item.
        """
        car_manufacturer = order.get_sorted_items()[0].get_impl().car.model.manufacturer
        if car_manufacturer == 'Porsche':
            return decimal.Decimal('911.00')
        return self._chunked_payment_amount

    def _resize_underused_payments(self, order):
        """
        Shrink successful payments of a completed ``order`` to their used amount.

        Each resized payment is processed with ``force_clear=True`` and then
        reloaded from the DB.
        """
        assert order.completed_at is not None

        for payment_usage in self._get_payment_usages(order):
            payment = payment_usage.payment
            if payment.get_generic_status() is not IPayment.GenericStatus.SUCCESS:
                continue

            if payment_usage.amount_available > 0:
                new_amount = payment.get_amount() - payment_usage.amount_available
                self._payment_processor.resize(
                    payment=payment,
                    amount=new_amount,
                )
                self._payment_processor.process(
                    payment=payment,
                    card_processor_kwargs={'force_clear': True},
                )
                payment.refresh_from_db()

    def _update_order_payment_status(self, order):
        """
        Set ``order``'s payment status to ERROR or SUCCESS from its payment progress.

        The status is left untouched in these cases:

        - the order is not completed;
        - some amount is still in progress;
        - the status would not change.

        ERROR is chosen when there are uncompensated errors and the cost
        exceeds the paid amount.
        """
        if order.completed_at is None:
            assert order.get_payment_status() is Order.PaymentStatus.NEW
            return

        progress = self.get_order_payment_progress(order)
        if progress.in_progress_amount > 0:
            assert order.get_payment_status() is Order.PaymentStatus.NEW
            return

        if progress.has_errors and progress.cost > progress.paid_amount:
            new_status = Order.PaymentStatus.ERROR
        else:
            new_status = Order.PaymentStatus.SUCCESS

        if order.get_payment_status() is new_status:
            return

        updater = OrderUpdater(order=order, push_client=self._push_client)
        updater.update_payment_status(new_status)

    def refund_order(self, order):
        """
        Refund the successful card payments of ``order`` and mark it REFUNDED.

        This method refunds only card payments.

        TODO: bonus payments.

        :raises Error: ``payment_status.invalid`` if the order payment status is
            not ERROR, REFUNDED or SUCCESS.
        """
        expected_statuses = {
            Order.PaymentStatus.ERROR,
            Order.PaymentStatus.REFUNDED,
            Order.PaymentStatus.SUCCESS,
        }
        if order.get_payment_status() not in expected_statuses:
            raise self.Error('payment_status.invalid')

        card_payment_ids = (
            OrderItemPayment.objects
            .filter(
                order_item__order=order,
                payment_method=OrderItemPayment.PaymentMethod.CARD.value,
            )
            .values_list('card_payment_id', flat=True)
        )
        self._refund_successful_card_payments(card_payment_ids)

        updater = OrderUpdater(order, push_client=self._push_client)
        updater.update_payment_status(payment_status=Order.PaymentStatus.REFUNDED)

    def refund_order_item(self, order_item):
        """Refund the successful card payments of ``order_item`` (bonus payments are not refunded)."""
        card_payment_ids = (
            OrderItemPayment.objects
            .filter(
                order_item=order_item,
                payment_method=OrderItemPayment.PaymentMethod.CARD.value,
            )
            .values_list('card_payment_id', flat=True)
        )
        self._refund_successful_card_payments(card_payment_ids)

    def _refund_successful_card_payments(self, card_payment_ids):
        """Refund card payments with the given ids; those not in SUCCESS status are logged and skipped."""
        card_payments = CardPayment.objects.filter(id__in=card_payment_ids)
        for card_payment in card_payments:
            if card_payment.get_generic_status() is not IPayment.GenericStatus.SUCCESS:
                LOGGER.info(
                    'skip refunding card payment with status different from success: %s %s',
                    card_payment.id,
                    card_payment.purchase_token,
                )
                continue
            self._payment_processor.refund(card_payment)

    def refund(self, order_item_payment):
        """
        Refund the card payment behind a single ``order_item_payment``.

        :raises Error: ``payment_method.unsuported`` for non-card item payments.
        """
        if order_item_payment.get_payment_method() is not OrderItemPayment.PaymentMethod.CARD:
            raise self.Error('payment_method.unsuported')
        self._payment_processor.refund(order_item_payment.card_payment)

    def wait_for_payments(self, payments, timeout):
        """
        Wait for the payment implementations behind ``payments``, then reload them.

        Waiting is delegated to the payment processor's ``wait_for_completion``
        with the given ``timeout``.

        :returns: the same ``payments`` sequence.
        """
        payment_impls = [p.get_impl() for p in payments]
        self._payment_processor.wait_for_completion(payment_impls, timeout=timeout)

        for payment in payments:
            payment.get_impl().refresh_from_db()

        return payments


# Aggregated payment state of a whole order, as built by
# OrderPaymentProcessor.get_order_payment_progress().
OrderPaymentProgress = collections.namedtuple(
    'OrderPaymentProgress',
    'cost has_errors in_progress_amount paid_amount errors_amount',
)


# Cost of an order or order item together with its bonus/card split.
PaymentSummary = collections.namedtuple(
    'PaymentSummary',
    'cost bonus_amount card_amount',
)


# A payment paired with the amount of it that may still be allocated to items.
PaymentUsage = collections.namedtuple(
    'PaymentUsage',
    'payment amount_available',
)


class OrderItemPaymentSpec:
    """
    Mutable request to pay ``amount`` for ``order_item``.

    ``amount`` is a plain attribute so callers can adjust it in place.
    """

    def __init__(self, order_item, amount):
        self.order_item, self.amount = order_item, amount


class OrderPaymentSpec:
    """
    Describes one charge of an order: which items to pay and how.

    :param order: order being charged; every item spec must belong to it.
    :param payment_method: order payment method to charge with.
    :param item_specs: list of OrderItemPaymentSpec-like objects.
    :param allowed_payment_methods: collection of allowed item payment methods,
        or None to allow all of them.
    :param payment_amount: explicit amount to charge, or None to charge the
        sum of the item specs; when given it must not be below that sum.
    """

    def __init__(self, order, payment_method, item_specs,
                 allowed_payment_methods=None, payment_amount=None):
        assert all(spec.order_item.order_id == order.id for spec in item_specs)

        self.order = order
        self.payment_method = payment_method
        self.item_specs = item_specs
        self._allowed_payment_methods = allowed_payment_methods
        self._payment_amount = payment_amount

        if payment_amount is not None:
            assert payment_amount >= self._get_item_specs_total_amount()

    def is_bonus_allowed(self):
        """Return True if bonus payments may be used for this charge."""
        allowed = self._allowed_payment_methods
        if allowed is None:
            return True
        return OrderItemPayment.PaymentMethod.BONUS in allowed

    def is_card_allowed(self):
        """Return True if card payments may be used for this charge."""
        allowed = self._allowed_payment_methods
        if allowed is None:
            return True
        return OrderItemPayment.PaymentMethod.CARD in allowed

    def get_payment_amount(self):
        """Return the explicit payment amount, falling back to the item specs' total."""
        explicit_amount = self._payment_amount
        if explicit_amount is None:
            return self._get_item_specs_total_amount()
        return explicit_amount

    def _get_item_specs_total_amount(self):
        """Return the sum of ``amount`` over all item specs."""
        return sum([spec.amount for spec in self.item_specs])
