package ru.yandex.travel.orders.workflows.orderitem;

import java.math.RoundingMode;
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import javax.money.CurrencyUnit;

import com.google.common.base.Preconditions;
import com.google.common.io.BaseEncoding;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.ProtocolMessageEnum;
import lombok.extern.slf4j.Slf4j;
import org.javamoney.moneta.Money;

import ru.yandex.travel.commons.lang.ComparatorUtils;
import ru.yandex.travel.commons.lang.MoneyUtils;
import ru.yandex.travel.commons.proto.ProtoUtils;
import ru.yandex.travel.commons.proto.TPrice;
import ru.yandex.travel.hotels.common.orders.BaseRate;
import ru.yandex.travel.hotels.common.orders.RefundInfo;
import ru.yandex.travel.hotels.common.orders.RefundReason;
import ru.yandex.travel.orders.entities.FiscalItem;
import ru.yandex.travel.orders.entities.FiscalItemType;
import ru.yandex.travel.orders.entities.HotelOrderItem;
import ru.yandex.travel.orders.entities.MoneyMarkup;
import ru.yandex.travel.orders.entities.MoneyRefund;
import ru.yandex.travel.orders.proto.THotelRefundToken;
import ru.yandex.travel.orders.services.finances.proto.EMoneyRefundMode;
import ru.yandex.travel.orders.workflow.hotels.proto.ERefundReason;
import ru.yandex.travel.orders.workflow.invoice.proto.TMoneyMarkup;
import ru.yandex.travel.orders.workflow.invoice.proto.TPaymentRefund;
import ru.yandex.travel.orders.workflow.order.proto.TServiceRefunded;
import ru.yandex.travel.orders.workflows.order.hotel.MoneyRefundUtils;
import ru.yandex.travel.workflow.StateContext;
import ru.yandex.travel.workflow.entities.WorkflowEntity;

import static java.util.stream.Collectors.toMap;
import static ru.yandex.travel.orders.workflows.invoice.trust.InvoiceUtils.toProtoPaymentMarkup;

@Slf4j
public class RefundingUtils {

    public static THotelRefundToken getRefundToken(String token) {
        try {
            return THotelRefundToken.parseFrom(BaseEncoding.base64Url().decode(token));
        } catch (InvalidProtocolBufferException e) {
            throw new RuntimeException("Unable to deserialize refund token", e);
        }
    }

    public static RefundInfo createRefundInfo(THotelRefundToken refundToken, ERefundReason reason) {
        RefundInfo refundInfo = new RefundInfo();

        refundInfo.setRefund(BaseRate.fromTPrice(refundToken.getRefundAmount()));
        refundInfo.setPenalty(BaseRate.fromTPrice(refundToken.getPenaltyAmount()));
        refundInfo.setPenaltyIntervalIndex(refundToken.getPenaltyIndex());
        refundInfo.setRefundDateTime(LocalDateTime.ofInstant(ProtoUtils.toInstant(refundToken.getRequestedAt()),
                ZoneOffset.UTC));
        switch (reason) {
            case RR_USER:
                refundInfo.setReason(RefundReason.USER);
                break;
            case RR_OPERATOR:
                refundInfo.setReason(RefundReason.OPERATOR);
                break;
            case RR_SCHEDULE:
                refundInfo.setReason(RefundReason.SCHEDULE);
                break;
        }
        return refundInfo;
    }

    public static <S extends ProtocolMessageEnum,
            E extends HotelOrderItem & WorkflowEntity<? extends ProtocolMessageEnum>>
    void scheduleServiceRefundedEvent(StateContext<S, ? extends E> context, THotelRefundToken refundToken,
                                      boolean skipFinEvents, EMoneyRefundMode moneyRefundMode,
                                      String refundDescription) {
        HotelOrderItem order = context.getWorkflowEntity();
        order.setRefundedAt(Instant.now());
        Money newFiscalAmount = ProtoUtils.fromTPrice(refundToken.getNewInvoiceAmount());
        MoneyMarkup newFiscalAmountMarkup = MoneyRefundUtils.fromTMoneyMarkup(refundToken.getNewInvoiceAmountMarkup());
        TargetFiscalItems targetFiscalItems = calculateTargetFiscalItemsGroupingByType(
                order.getFiscalItems(), newFiscalAmount, newFiscalAmountMarkup);
        context.scheduleExternalEvent(order.getOrderWorkflowId(),
                TServiceRefunded.newBuilder()
                        .setServiceId(order.getId().toString())
                        .setSkipFinEvents(skipFinEvents)
                        .setMoneyRefundMode(moneyRefundMode)
                        .setRefundDescription(refundDescription)
                        .putAllTargetFiscalItems(convertTargetFiscalItemsToProto(targetFiscalItems.getPrices()))
                        .putAllTargetFiscalItemsMarkup(convertTargetFiscalItemsMarkupToProto(targetFiscalItems.getPricesMarkup()))
                        .build()
        );
    }

    /**
     * The default behaviour is to return all yandex plus points first.
     * So we keep as much card money as possible.
     * <b>WARN</b>: not the best way to calculate actual money distribution in case of repeated refunds.
     * todo(tlg-13): we need to avoid this method for new refunds
     */
    public static MoneyMarkup calculateDefaultTargetMoneyMarkup(List<FiscalItem> fiscalItems, Money newAmount) {
        CurrencyUnit currency = newAmount.getCurrency();
        MoneyMarkup currentAmount = fiscalItems.stream()
                .map(FiscalItem::getMoneyAmountMarkup)
                .reduce(MoneyMarkup.zero(currency), MoneyMarkup::add);
        return calculateDefaultTargetMoneyMarkup(currentAmount, newAmount);
    }

    /**
     * The default behaviour is to return all yandex plus points first.
     * So we keep as much card money as possible.
     * <b>WARN</b>: not the best way to calculate actual money distribution in case of repeated refunds.
     * todo(tlg-13): we need to avoid this method for new refunds
     */
    public static MoneyMarkup calculateDefaultTargetMoneyMarkup(MoneyMarkup currentAmount, Money newAmount) {
        Preconditions.checkArgument(newAmount.isPositiveOrZero(), "Negative new amount: %s", newAmount);
        MoneyUtils.ensureDecimalScale(newAmount);

        Money newCard = ComparatorUtils.min(newAmount, currentAmount.getCard());
        newAmount = newAmount.subtract(newCard);

        Money newYandexAccount = ComparatorUtils.min(newAmount, currentAmount.getYandexAccount());
        newAmount = newAmount.subtract(newYandexAccount);

        // it's important to keep Plus points as integer money values,
        // our payment gateways doesn't fail on decimals but just throws the decimal parts away
        Money newYandexAccountDecimalExtra = MoneyUtils.roundUpToInteger(newYandexAccount).subtract(newYandexAccount);
        if (!newYandexAccountDecimalExtra.isZero()) {
            newYandexAccount = newYandexAccount.add(newYandexAccountDecimalExtra);
            newCard = newCard.subtract(newYandexAccountDecimalExtra);
            Preconditions.checkArgument(newYandexAccount.isLessThanOrEqualTo(currentAmount.getYandexAccount()),
                    "No enough Yandex Account money: needed %s, current %s",
                    newYandexAccount, currentAmount.getYandexAccount());
            Preconditions.checkArgument(newCard.isPositiveOrZero(),
                    "No enough card money: needed %s, current %s",
                    newCard, currentAmount.getCard());
        }

        Preconditions.checkState(newAmount.isZero(),
                "Not enough fiscal item money to provide target sum; left %s, new target values [%s, %s]",
                newAmount, newCard, newYandexAccount);
        if (!newYandexAccount.isZero()) {
            log.warn("Implicitly calculating refund money markup for an order with Yandex Plus points: {}",
                    currentAmount);
        }
        return MoneyMarkup.builder()
                .card(newCard)
                .yandexAccount(newYandexAccount)
                .build();
    }

    public static TargetFiscalItems calculateTargetFiscalItems(List<FiscalItem> fiscalItems,
                                                               Money newAmount,
                                                               MoneyMarkup newAmountMarkup) {
        Preconditions.checkArgument(newAmount.equals(newAmountMarkup.getTotal()),
                "Target amount doesn't match its markup; target %s, markup %s", newAmount, newAmountMarkup);
        Map<Long, MoneyMarkup> targetItemsMarkup = calculateTargetFiscalItemsMarkup(fiscalItems, newAmountMarkup);
        return TargetFiscalItems.builder()
                .prices(targetItemsMarkup.entrySet().stream().collect(
                        toMap(Map.Entry::getKey, e -> e.getValue().getTotal())))
                .pricesMarkup(targetItemsMarkup)
                .build();
    }

    public static TargetFiscalItems calculateTargetFiscalItemsGroupingByType(List<FiscalItem> fiscalItems,
                                                                             Money newFiscalAmount,
                                                                             MoneyMarkup newFiscalAmountMarkup) {
        if (newFiscalAmountMarkup != null) {
            return calculateTargetFiscalItems(fiscalItems, newFiscalAmount, newFiscalAmountMarkup);
        } else {
            CurrencyUnit currency = newFiscalAmount.getCurrency();
            Money totalCurrentAmount = fiscalItems.stream()
                    .map(FiscalItem::getMoneyAmountMarkup)
                    .reduce(MoneyMarkup.zero(currency), MoneyMarkup::add).getTotal();

            Map<Long, MoneyMarkup> result = new HashMap<>();
            Map<FiscalItemType, List<FiscalItem>> itemsByType =
                    fiscalItems.stream().collect(Collectors.groupingBy(FiscalItem::getType, Collectors.toList()));
            for (var typeItems : itemsByType.values()) {
                Money typeCurrentAmount = typeItems.stream()
                        .map(FiscalItem::getMoneyAmountMarkup)
                        .reduce(MoneyMarkup.zero(currency), MoneyMarkup::add).getTotal();
                Money typeShare = typeCurrentAmount.divide(totalCurrentAmount.getNumber());
                Money typeNewAmount = MoneyUtils.roundToDecimal(newFiscalAmount.multiply(typeShare.getNumber()), RoundingMode.HALF_UP);
                MoneyMarkup typeNewAmountMarkup = calculateDefaultTargetMoneyMarkup(typeItems, typeNewAmount);
                result.putAll(calculateTargetFiscalItemsMarkup(typeItems, typeNewAmountMarkup));
            }
            var newTotal = result.values().stream().reduce(MoneyMarkup.zero(currency), MoneyMarkup::add).getTotal();
            Preconditions.checkState(newTotal.isEqualTo(newFiscalAmount), "Unexpected total breakdown");
            return TargetFiscalItems.builder()
                    .prices(result.entrySet().stream().collect(
                            toMap(Map.Entry::getKey, e -> e.getValue().getTotal())))
                    .pricesMarkup(result)
                    .build();
        }
    }

    private static Map<Long, MoneyMarkup> calculateTargetFiscalItemsMarkup(List<FiscalItem> fiscalItems,
                                                                           MoneyMarkup newAmount) {
        CurrencyUnit currency = newAmount.getTotal().getCurrency();
        MoneyMarkup currentAmount = fiscalItems.stream()
                .map(FiscalItem::getMoneyAmountMarkup)
                .reduce(MoneyMarkup.zero(currency), MoneyMarkup::add);
        MoneyMarkup amountToSubtract = currentAmount.subtract(newAmount);
        Preconditions.checkState(!amountToSubtract.hasNegativeValues(),
                "Unexpected negative amountToSubtract; current %s, new %s", currentAmount, newAmount);

        Map<Long, MoneyMarkup> newPrices = new HashMap<>();
        for (int i = fiscalItems.size() - 1; i >= 0; i--) {  // reversed direction: latest items
            // are refunded first
            FiscalItem fi = fiscalItems.get(i);
            MoneyMarkup currentItemAmount = fi.getMoneyAmountMarkup();
            MoneyMarkup actualItemRefund = currentItemAmount.merge(amountToSubtract, ComparatorUtils::min);
            newPrices.put(fi.getFiscalItemId(), currentItemAmount.subtract(actualItemRefund));
            amountToSubtract = amountToSubtract.subtract(actualItemRefund);
        }
        log.info("New target prices markup: {}", newPrices);
        Preconditions.checkState(amountToSubtract.getTotal().isZero(),
                "Failed to reduce fiscal items for the the full refund amount, left %s", amountToSubtract);
        return newPrices;
    }

    public static Map<Long, TPrice> convertTargetFiscalItemsToProto(Map<Long, Money> targetFiscalItems) {
        return targetFiscalItems.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, entry -> ProtoUtils.toTPrice(entry.getValue())));
    }

    public static Map<Long, Money> convertTargetFiscalItemsFromProto(Map<Long, TPrice> targetFiscalItems) {
        return targetFiscalItems.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, entry -> ProtoUtils.fromTPrice(entry.getValue())));
    }

    public static Map<Long, TMoneyMarkup> convertTargetFiscalItemsMarkupToProto(Map<Long, MoneyMarkup> targetFiscalItemsMarkup) {
        return targetFiscalItemsMarkup == null ? Map.of() : targetFiscalItemsMarkup.entrySet().stream()
                .peek(e -> e.getValue().ensureValuesScale())
                .collect(Collectors.toMap(Map.Entry::getKey, entry -> toProtoPaymentMarkup(entry.getValue())));
    }

    public static Map<Long, MoneyMarkup> convertTargetFiscalItemsMarkupFromProto(Map<Long, TMoneyMarkup> targetFiscalItemsMarkup) {
        Map<Long, MoneyMarkup> result = targetFiscalItemsMarkup == null ? Map.of() :
                targetFiscalItemsMarkup.entrySet().stream()
                        .collect(Collectors.toMap(Map.Entry::getKey, entry -> MoneyMarkup.builder()
                                .card(ProtoUtils.fromTPrice(entry.getValue().getCard()))
                                .yandexAccount(ProtoUtils.fromTPrice(entry.getValue().getYandexAccount()))
                                .build()));
        for (MoneyMarkup markup : result.values()) {
            markup.ensureValuesScale();
        }
        return result;
    }

    // some orders don't support the multiple money sources at the moment
    public static Map<Long, MoneyMarkup> convertTargetFiscalItemsToCardOnlyMarkup(Map<Long, Money> targetFiscalItems) {
        return targetFiscalItems.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, entry -> MoneyMarkup.cardOnly(entry.getValue())));
    }

    public static TPaymentRefund createRefundEvent(Map<Long, Money> targetPricesByFiscalItem,
                                                   Map<Long, MoneyMarkup> targetPricesByFiscalItemMarkup,
                                                   String moneyRefundReason) {
        return TPaymentRefund.newBuilder()
                .putAllTargetFiscalItems(convertTargetFiscalItemsToProto(targetPricesByFiscalItem))
                .putAllTargetFiscalItemsMarkup(convertTargetFiscalItemsMarkupToProto(targetPricesByFiscalItemMarkup))
                .setReason(moneyRefundReason)
                .build();
    }

    public static TPaymentRefund createRefundEvent(MoneyRefund refund) {
        return TPaymentRefund.newBuilder()
                .putAllTargetFiscalItems(convertTargetFiscalItemsToProto(refund.getTargetFiscalItems()))
                .putAllTargetFiscalItemsMarkup(convertTargetFiscalItemsMarkupToProto(refund.getTargetFiscalItemsMarkup()))
                .setReason(refund.getReason() != null ? refund.getReason() : "Refund")
                .setOrderRefundId(ProtoUtils.toStringOrEmpty(refund.getOrderRefundId()))
                .build();
    }
}
