package ru.yandex.travel.orders.entities;

import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Stream;

import javax.persistence.Entity;
import javax.persistence.GeneratedValue;
import javax.persistence.GenerationType;
import javax.persistence.Id;
import javax.persistence.ManyToOne;
import javax.persistence.SequenceGenerator;
import javax.persistence.Table;

import com.google.common.base.Preconditions;
import com.google.common.collect.Sets;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.ToString;
import org.hibernate.annotations.BatchSize;
import org.hibernate.annotations.Type;
import org.javamoney.moneta.Money;

import ru.yandex.travel.commons.lang.ComparatorUtils;
import ru.yandex.travel.commons.proto.TPrice;
import ru.yandex.travel.commons.streams.CustomCollectors;
import ru.yandex.travel.orders.commons.proto.EOrderType;
import ru.yandex.travel.orders.workflow.invoice.proto.TMoneyMarkup;
import ru.yandex.travel.orders.workflows.orderitem.RefundingUtils;

import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;
import static java.util.stream.Collectors.toSet;

/**
 * A money refund requested for an {@link Order}.
 *
 * <p>A refund is described by its <em>target</em> prices: for every affected fiscal item it stores the amount
 * that should remain charged once the refund is done (the amount to give back is "paid so far" minus "target").
 * {@link #createPendingRefund} additionally plans, per paid invoice, which new item prices have to be set to
 * reach those targets; that plan is kept in the refund's {@link MoneyRefundContext}.
 */
@Entity
@Table(name = "money_refunds")
@Data
@EqualsAndHashCode(exclude = {"order"})
@ToString(exclude = "order")
@BatchSize(size = 100)
public class MoneyRefund {
    @Id
    @GeneratedValue(
            strategy = GenerationType.SEQUENCE,
            generator = "money_refunds_id_seq"
    )
    @SequenceGenerator(
            name = "money_refunds_id_seq",
            sequenceName = "money_refunds_id_seq",
            allocationSize = 1
    )
    private Long id;

    /**
     * Order this refund belongs to. Excluded from equals/hashCode/toString, presumably to avoid cycles through
     * the order's own refund collection.
     */
    @ManyToOne
    private Order order;

    /** Target (post-refund) amount per fiscal item id. */
    @Type(type = "jsonb-object")
    private Map<Long, Money> targetFiscalItems;

    /**
     * Optional breakdown of {@link #targetFiscalItems} per fiscal item id; each entry's total must equal the
     * corresponding target amount (see {@link #validateTargetValues}). May be {@code null} for refunds created
     * without a markup.
     */
    @Type(type = "jsonb-object")
    private Map<Long, MoneyMarkup> targetFiscalItemsMarkup;

    /** Processing state; {@link #createPendingRefund} starts it as {@code PENDING}. */
    @Type(type = "custom-enum")
    private MoneyRefundState state;

    /** Id passed at creation; presumably the id of the order-level refund this money refund belongs to. */
    private UUID orderRefundId;

    /** Free-form reason passed at creation. */
    private String reason;

    /** Processing context; holds the per-invoice refund plan built by {@link #createPendingRefund}. */
    @Type(type = "jsonb-object")
    private MoneyRefundContext context;

    /**
     * Same as {@link #createPendingRefund}, but takes the target prices in their proto form and converts them
     * with {@link RefundingUtils}.
     */
    public static MoneyRefund createPendingRefundFromProto(Order order,
                                                           Map<Long, TPrice> targetFiscalItems,
                                                           Map<Long, TMoneyMarkup> targetFiscalItemsMarkup,
                                                           UUID orderRefundId, String reason) {
        return createPendingRefund(order,
                RefundingUtils.convertTargetFiscalItemsFromProto(targetFiscalItems),
                RefundingUtils.convertTargetFiscalItemsMarkupFromProto(targetFiscalItemsMarkup),
                orderRefundId, reason);
    }

    /**
     * Creates a {@code PENDING} refund, attaches it to {@code order} and plans how it is spread over the
     * order's paid invoices.
     *
     * <p>For every fiscal item in {@code targetFiscalItems} the amount to refund is the sum paid over the paid
     * invoices minus the target. That amount is taken from invoice items, newest invoice first, and the resulting
     * new item prices are stored per invoice in the refund's context.
     *
     * @param order                   order to refund; the new refund is registered on it
     * @param targetFiscalItems       target (post-refund) amount per fiscal item id
     * @param targetFiscalItemsMarkup optional per-item breakdown of the targets; may be {@code null}
     * @param orderRefundId           id stored on the refund as-is
     * @param reason                  reason stored on the refund as-is
     * @return the newly created refund
     * @throws IllegalArgumentException if the target values fail {@link #validateTargetValues}
     * @throws IllegalStateException    if a target exceeds the paid amount, or the refund cannot be fully
     *                                  distributed over the paid invoice items
     */
    public static MoneyRefund createPendingRefund(Order order,
                                                  Map<Long, Money> targetFiscalItems,
                                                  Map<Long, MoneyMarkup> targetFiscalItemsMarkup,
                                                  UUID orderRefundId, String reason) {
        validateTargetValues(order, targetFiscalItems, targetFiscalItemsMarkup);

        MoneyRefund refund = new MoneyRefund();
        refund.setContext(new MoneyRefundContext());
        refund.getContext().setInvoicesToRefund(new HashMap<>());
        refund.setState(MoneyRefundState.PENDING);
        refund.setTargetFiscalItems(targetFiscalItems);
        refund.setTargetFiscalItemsMarkup(targetFiscalItemsMarkup);
        refund.setOrderRefundId(orderRefundId);
        refund.setReason(reason);
        order.addMoneyRefund(refund);

        @SuppressWarnings("SuspiciousMethodCalls")
        List<Invoice> paidInvoices = order.getInvoices().stream()
                .filter(i -> !TrustInvoice.UNPAID_INVOICE_STATES.contains(i.getInvoiceState()))
                .collect(toList());

        // outstanding amounts to refund per fiscal item; both maps are drained invoice by invoice below
        Map<Long, Money> fiscalItemRefundMap = buildRefundMoneyMap(order, targetFiscalItems, paidInvoices);
        Map<Long, MoneyMarkup> fiscalItemRefundMarkupMap = buildRefundMarkupMap(targetFiscalItemsMarkup, paidInvoices);

        paidInvoices.stream()
                // reversed sort - newest first
                .sorted(Comparator.comparing(Invoice::getCreatedAt).reversed())
                .forEach(i -> planInvoiceRefund(refund, order, i, fiscalItemRefundMap, fiscalItemRefundMarkupMap));

        // a non-zero remainder means the refund could not be fully distributed over the paid invoice items
        fiscalItemRefundMap.forEach((key, value) ->
                Preconditions.checkState(value.isZero(),
                        "Unable to meet target refund amount for item %s: left %s", key, value));
        fiscalItemRefundMarkupMap.forEach((key, value) ->
                Preconditions.checkState(value.getTotal().isZero(),
                        "Unable to meet target refund amount markup for item %s: left %s", key, value));
        return refund;
    }

    /**
     * Adds target prices for fiscal items not yet covered by this refund.
     *
     * <p>All checks run before any field is modified, so a rejected call leaves the refund unchanged. The stored
     * maps are replaced by merged copies rather than mutated in place.
     *
     * @param targetFiscalItemsProto       additional target amounts per fiscal item id (proto form)
     * @param targetFiscalItemsMarkupProto additional markup breakdown per fiscal item id (proto form)
     * @throws IllegalStateException    if the order is not a generic order
     * @throws IllegalArgumentException if the new values fail {@link #validateTargetValues} or overlap with
     *                                  fiscal items already present in this refund
     */
    public void addTargetFiscalItems(Map<Long, TPrice> targetFiscalItemsProto,
                                     Map<Long, TMoneyMarkup> targetFiscalItemsMarkupProto) {
        // be aware refund.getContext().getInvoicesToRefund() was only supported for hotel orders,
        // while this is generic-order-only functionality
        Preconditions.checkState(order.getPublicType() == EOrderType.OT_GENERIC,
                "Only generic orders (without deferred payments) support items addition",
                order.getPublicType());

        Map<Long, Money> addedFiscalItems =
                RefundingUtils.convertTargetFiscalItemsFromProto(targetFiscalItemsProto);
        Map<Long, MoneyMarkup> addedFiscalItemsMarkup =
                RefundingUtils.convertTargetFiscalItemsMarkupFromProto(targetFiscalItemsMarkupProto);
        validateTargetValues(order, addedFiscalItems, addedFiscalItemsMarkup);

        // a refund created without markup (or a legacy row) may hold null maps; treat them as empty
        Map<Long, Money> currentFiscalItems =
                this.targetFiscalItems != null ? this.targetFiscalItems : Map.of();
        Map<Long, MoneyMarkup> currentMarkup =
                this.targetFiscalItemsMarkup != null ? this.targetFiscalItemsMarkup : Map.of();
        Map<Long, MoneyMarkup> addedMarkup =
                addedFiscalItemsMarkup != null ? addedFiscalItemsMarkup : Map.of();

        checkNoFiscalItemsIntersection(currentFiscalItems, addedFiscalItems);
        checkNoFiscalItemsIntersection(currentMarkup, addedMarkup);

        Map<Long, Money> mergedFiscalItems = new HashMap<>(currentFiscalItems);
        mergedFiscalItems.putAll(addedFiscalItems);
        this.targetFiscalItems = mergedFiscalItems;

        // keep a missing markup missing unless there is something to add
        if (this.targetFiscalItemsMarkup != null || !addedMarkup.isEmpty()) {
            Map<Long, MoneyMarkup> mergedMarkup = new HashMap<>(currentMarkup);
            mergedMarkup.putAll(addedMarkup);
            this.targetFiscalItemsMarkup = mergedMarkup;
        }
    }

    /**
     * Validates target prices against the fiscal items known to {@code order}.
     *
     * <ul>
     *     <li>every target id must be a fiscal item of an order item or of an invoice item;</li>
     *     <li>every non-zero target must refer to a fiscal item present in some invoice;</li>
     *     <li>if a markup map is given, each markup total must equal the target amount for the same id.</li>
     * </ul>
     *
     * @param order                   order whose order items and invoices define the known fiscal items
     * @param targetFiscalItems       target amount per fiscal item id
     * @param targetFiscalItemsMarkup optional breakdown per fiscal item id; skipped when {@code null}
     * @throws IllegalArgumentException if any of the checks above fails
     */
    public static void validateTargetValues(Order order,
                                            Map<Long, Money> targetFiscalItems,
                                            Map<Long, MoneyMarkup> targetFiscalItemsMarkup) {
        Set<Long> allFiscalItemIdsViaOrderItems = order.getOrderItems().stream()
                .flatMap(x -> x.getFiscalItems().stream())
                .map(FiscalItem::getId)
                .collect(toSet());
        Set<Long> allFiscalItemIdsViaInvoices = order.getInvoices().stream()
                .flatMap(i -> i.getInvoiceItems().stream())
                .map(InvoiceItem::getFiscalItemId)
                .collect(toSet());
        Set<Long> allFiscalItemIds = Stream.concat(allFiscalItemIdsViaOrderItems.stream(),
                allFiscalItemIdsViaInvoices.stream()).collect(toSet());

        Preconditions.checkArgument(allFiscalItemIds.containsAll(targetFiscalItems.keySet()),
                "Unknown fiscal item ids received: current %s, target %s",
                allFiscalItemIds, targetFiscalItems.keySet());

        var nonZeroTargetFiscalItemIds = targetFiscalItems.entrySet().stream()
                .filter(x -> !x.getValue().isZero())
                .map(Map.Entry::getKey)
                .collect(toSet());
        Preconditions.checkArgument(allFiscalItemIdsViaInvoices.containsAll(nonZeroTargetFiscalItemIds),
                "Unknown non-zero fiscal item ids received: current %s, target %s",
                allFiscalItemIdsViaInvoices, nonZeroTargetFiscalItemIds);

        if (targetFiscalItemsMarkup != null) {
            Preconditions.checkArgument(targetFiscalItemsMarkup.entrySet().stream().allMatch(e ->
                            Objects.equals(targetFiscalItems.get(e.getKey()), e.getValue().getTotal())),
                    "Broken markup map; markup %s, target items %s", targetFiscalItemsMarkup, targetFiscalItems);
        }
    }

    /**
     * Computes the amount to refund per fiscal item: the sum paid for it over {@code paidInvoices} minus its
     * target. Fiscal items that are not paid by any of the invoices, or that have no target, are skipped.
     *
     * @return a mutable map; it is drained while the refund is distributed over invoices
     * @throws IllegalStateException if a target exceeds the paid amount
     */
    private static Map<Long, Money> buildRefundMoneyMap(Order order,
                                                        Map<Long, Money> targetFiscalItems,
                                                        List<Invoice> paidInvoices) {
        Map<Long, Money> takenMoneyByFiscalItemIdMap = paidInvoices.stream()
                .flatMap(i -> i.getInvoiceItems().stream())
                .collect(groupingBy(InvoiceItem::getFiscalItemId,
                        CustomCollectors.summingMoney(InvoiceItem::getPriceMoney,
                                order.getCurrency())));

        Map<Long, Money> refundMap = new HashMap<>();
        takenMoneyByFiscalItemIdMap.forEach((fiscalItemId, taken) -> {
            if (!targetFiscalItems.containsKey(fiscalItemId)) {
                return;
            }
            Money toRefund = taken.subtract(targetFiscalItems.get(fiscalItemId));
            Preconditions.checkState(toRefund.isPositiveOrZero(),
                    "Unexpected negative refund amount for fiscal item %s", fiscalItemId);
            refundMap.put(fiscalItemId, toRefund);
        });
        return refundMap;
    }

    /**
     * Markup counterpart of {@link #buildRefundMoneyMap}: for every id in {@code targetMarkup}, the markup sum
     * paid over {@code paidInvoices} minus the target markup.
     *
     * @return a mutable map (empty when {@code targetMarkup} is {@code null})
     * @throws IllegalStateException if any resulting refund markup has negative values
     */
    private static Map<Long, MoneyMarkup> buildRefundMarkupMap(Map<Long, MoneyMarkup> targetMarkup,
                                                               List<Invoice> paidInvoices) {
        if (targetMarkup == null) {
            return new HashMap<>();
        }

        Map<Long, MoneyMarkup> currentMarkup = new HashMap<>();
        for (var entry : targetMarkup.entrySet()) {
            currentMarkup.put(entry.getKey(), MoneyMarkup.zero(entry.getValue().getCard().getCurrency()));
        }
        for (Invoice invoice : paidInvoices) {
            for (InvoiceItem invoiceItem : invoice.getInvoiceItems()) {
                // NOTE(review): markupSum is null for items without a target markup; relies on addSafe accepting null
                MoneyMarkup markupSum = currentMarkup.get(invoiceItem.getFiscalItemId());
                currentMarkup.put(invoiceItem.getFiscalItemId(), invoiceItem.getPriceMarkup().addSafe(markupSum));
            }
        }

        Map<Long, MoneyMarkup> refundsMarkup = new HashMap<>();
        for (Long fiscalId : targetMarkup.keySet()) {
            refundsMarkup.put(fiscalId, currentMarkup.get(fiscalId).subtract(targetMarkup.get(fiscalId)));
        }

        for (Map.Entry<Long, MoneyMarkup> e : refundsMarkup.entrySet()) {
            Preconditions.checkState(!e.getValue().hasNegativeValues(),
                    "Negative refund markup for %s: %s", e.getKey(), e.getValue());
        }

        return refundsMarkup;
    }

    /**
     * Plans the refund for a single paid invoice.
     *
     * <p>For every invoice item with an outstanding refund, it takes as much as the item price allows and records
     * the item's new price. The outstanding amounts in the two maps are reduced accordingly. If at least one item
     * is affected, a {@code PENDING} invoice refund is registered in {@code refund}'s context.
     */
    private static void planInvoiceRefund(MoneyRefund refund, Order order, Invoice invoice,
                                          Map<Long, Money> fiscalItemRefundMap,
                                          Map<Long, MoneyMarkup> fiscalItemRefundMarkupMap) {
        Map<Long, Money> newPricesForInvoice = new HashMap<>();
        Map<Long, MoneyMarkup> newPricesForInvoiceMarkup = new HashMap<>();
        for (var invoiceItem : invoice.getInvoiceItems()) {
            Long fiscalItemId = invoiceItem.getFiscalItemId();
            Money itemPrice = invoiceItem.getPriceMoney();
            Money toRefund = fiscalItemRefundMap.get(fiscalItemId);
            MoneyMarkup itemPriceMarkup = invoiceItem.getPriceMarkup();
            MoneyMarkup toRefundMarkup = fiscalItemRefundMarkupMap.get(fiscalItemId);
            InvoiceItemRefund itemRefund = tryItemRefund(toRefund, toRefundMarkup, itemPrice, itemPriceMarkup);
            if (itemRefund == null) {
                continue;
            }
            fiscalItemRefundMap.put(fiscalItemId, itemRefund.refundRemainder);
            newPricesForInvoice.put(fiscalItemId, itemRefund.newItemPrice);
            if (toRefundMarkup != null) {
                fiscalItemRefundMarkupMap.put(fiscalItemId, itemRefund.refundRemainderMarkup);
                newPricesForInvoiceMarkup.put(fiscalItemId, itemRefund.newItemPriceMarkup);
            }
        }
        if (newPricesForInvoice.isEmpty()) {
            return;
        }
        validateTargetValues(order, newPricesForInvoice, newPricesForInvoiceMarkup);
        refund.getContext().getInvoicesToRefund().put(invoice.getId(),
                MoneyRefundContext.InvoiceRefund.builder()
                        .invoiceId(invoice.getId())
                        .state(MoneyRefundContext.InvoiceRefundState.PENDING)
                        .targetPricesByFiscalItem(newPricesForInvoice)
                        .targetPricesByFiscalItemMarkup(newPricesForInvoiceMarkup)
                        .build());
    }

    /**
     * Returns an unmodifiable view of the target amounts, or an empty map when none are set.
     */
    public Map<Long, Money> getTargetFiscalItems() {
        return targetFiscalItems != null ? Collections.unmodifiableMap(targetFiscalItems) : Collections.emptyMap();
    }

    /**
     * Rejects adding fiscal items that are already present.
     *
     * @throws IllegalArgumentException if {@code oldValues} and {@code newValues} share any key
     */
    private static void checkNoFiscalItemsIntersection(Map<Long, ?> oldValues, Map<Long, ?> newValues) {
        Set<Long> keysConflict = Sets.intersection(oldValues.keySet(), newValues.keySet());
        Preconditions.checkArgument(keysConflict.isEmpty(),
                "Old fiscal items conflict with the new ones: old - %s, new - %s",
                oldValues, newValues);
    }

    /**
     * Computes how much of an outstanding refund can be taken from one invoice item.
     *
     * <p>Without markup the refunded amount is {@code min(toRefund, itemPrice)}. With markup, the two markups are
     * combined with {@link ComparatorUtils#min} via {@code MoneyMarkup.merge} (presumably component by
     * component), and the money values are the totals of the resulting markups.
     *
     * @return the item's new price and the remaining refund, or {@code null} when nothing is left to refund for
     * this item or the item price is already zero
     */
    private static InvoiceItemRefund tryItemRefund(Money toRefund, MoneyMarkup toRefundMarkup,
                                                   Money itemPrice, MoneyMarkup itemPriceMarkup) {
        if (toRefund == null || toRefund.isZero() || itemPrice.isZero()) {
            return null;
        }
        if (toRefundMarkup == null) {
            Money actualItemRefund = ComparatorUtils.min(toRefund, itemPrice);
            Money refundRemainder = toRefund.subtract(actualItemRefund);
            Money remainingPrice = itemPrice.subtract(actualItemRefund);
            return InvoiceItemRefund.builder()
                    .newItemPrice(remainingPrice)
                    .refundRemainder(refundRemainder)
                    .build();
        } else {
            MoneyMarkup actualItemRefundMarkup = toRefundMarkup.merge(itemPriceMarkup, ComparatorUtils::min);
            MoneyMarkup refundRemainderMarkup = toRefundMarkup.subtract(actualItemRefundMarkup);
            MoneyMarkup remainingPriceMarkup = itemPriceMarkup.subtract(actualItemRefundMarkup);
            return InvoiceItemRefund.builder()
                    .newItemPrice(remainingPriceMarkup.getTotal())
                    .refundRemainder(refundRemainderMarkup.getTotal())
                    .newItemPriceMarkup(remainingPriceMarkup)
                    .refundRemainderMarkup(refundRemainderMarkup)
                    .build();
        }
    }

    /** Result of {@link #tryItemRefund} for a single invoice item. */
    @Builder
    private static class InvoiceItemRefund {
        /** Part of the outstanding refund that this item could not cover. */
        private final Money refundRemainder;
        /** Item price after the refund. */
        private final Money newItemPrice;
        // the fields below are only set for refunds with markups
        private final MoneyMarkup refundRemainderMarkup;
        private final MoneyMarkup newItemPriceMarkup;
    }
}
