package ru.yandex.travel.orders.services.train;

import java.math.BigDecimal;
import java.util.List;

import javax.money.Monetary;

import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.javamoney.moneta.Money;
import org.springframework.stereotype.Service;

import ru.yandex.travel.commons.proto.ProtoCurrencyUnit;
import ru.yandex.travel.commons.proto.ProtoUtils;
import ru.yandex.travel.orders.entities.TrainOrderItem;
import ru.yandex.travel.orders.services.train.bandit.TrainBanditClient;
import ru.yandex.travel.orders.services.train.bandit.TrainBanditProperties;
import ru.yandex.travel.train.model.CarType;
import ru.yandex.travel.train.model.FeeSource;
import ru.yandex.travel.train.model.TrainPassenger;
import ru.yandex.travel.train.model.TrainReservation;
import ru.yandex.travel.train.model.TrainTicket;
import ru.yandex.travel.train_bandit_api.proto.TChargeStringCtx;
import ru.yandex.travel.train_bandit_api.proto.TGetChargeByTokenRequest;
import ru.yandex.travel.train_bandit_api.proto.TGetChargeByTokenResponse;
import ru.yandex.travel.train_bandit_api.proto.TGetChargeStringCtxRequest;
import ru.yandex.travel.train_bandit_api.proto.TGetChargeStringCtxResponse;
import ru.yandex.travel.train_bandit_api.proto.TTicketFee;
import ru.yandex.travel.train_bandit_api.proto.TTicketPrice;
import ru.yandex.travel.train_bandit_api.proto.TTicketPricesStringCtx;

@Service
@Slf4j
@RequiredArgsConstructor
public class FeeCalculationService {
    private final TrainBanditProperties properties;
    private final TrainBanditClient banditClient;
    private final TrainMeters meters;

    /**
     * Calculates the fee of every passenger ticket of the order item and stores the result in the item's
     * payload ({@link TrainReservation}) in place.
     * <p>
     * Source selection, in order:
     * <ol>
     *     <li>bandit enabled and {@code tryNumber} is greater than the configured max tries: local calculation,
     *     and the bandit-fallback meter is incremented;</li>
     *     <li>bandit enabled and the payload has a bandit token: train bandit API, by token;</li>
     *     <li>bandit enabled and the payload has a bandit context: train bandit API, by string context;</li>
     *     <li>otherwise: local calculation.</li>
     * </ol>
     *
     * @param orderItem order item whose payload and tickets are updated
     * @param tryNumber number of the current calculation attempt, compared with
     *                  {@link TrainBanditProperties#getMaxTriesToFallback()}
     */
    public void calculateAndAssignFees(TrainOrderItem orderItem, int tryNumber) {
        if (properties.isEnabled() && tryNumber > properties.getMaxTriesToFallback()) {
            meters.getTrainOrdersBanditFallbacks().increment();
            calculateFeesLocal(orderItem);
        } else if (properties.isEnabled() && !Strings.isNullOrEmpty(orderItem.getPayload().getBanditToken())) {
            calculateFeesBanditByToken(orderItem);
        } else if (properties.isEnabled() && !Strings.isNullOrEmpty(orderItem.getPayload().getBanditContext())) {
            calculateFeesBandit(orderItem);
        } else {
            calculateFeesLocal(orderItem);
        }
    }

    /**
     * Calculates fees from locally configured properties, without calling the bandit.
     * <p>
     * Per ticket: {@code fee = serviceAmount * permille + max(tariff * permille, tariff * minFeeFactor + partnerFee)},
     * rounded with the default monetary rounding. The service amount is only charged separately for
     * {@link CarType#RESERVED_SEAT} cars when it is smaller than the total amount. Otherwise it is folded into
     * the tariff part.
     */
    private void calculateFeesLocal(TrainOrderItem orderItem) {
        TrainReservation payload = orderItem.getPayload();
        payload.setBanditType(null);
        payload.setBanditVersion(0L);
        payload.setPermille(properties.getLocalPermille());
        payload.setFeeSource(FeeSource.LOCAL);
        // permille -> fraction, e.g. 110 -> 0.110
        var feeFactor = BigDecimal.valueOf(properties.getLocalPermille(), 3);
        var partnerFee = Money.of(properties.getLocalPartnerFee(), ProtoCurrencyUnit.RUB);
        var partnerRefundFee = Money.of(properties.getLocalPartnerRefundFee(), ProtoCurrencyUnit.RUB);
        List<TrainPassenger> passengers = payload.getPassengers();
        for (TrainPassenger passenger : passengers) {
            TrainTicket ticket = passenger.getTicket();
            ticket.setPartnerFee(partnerFee);
            ticket.setPartnerRefundFee(partnerRefundFee);
            ticket.setBanditFeeApplied(false);

            Money serviceAmount = ticket.getServiceAmount();
            Money amount = ticket.getTariffAmount().add(serviceAmount);
            // The service part is charged separately only for reserved-seat cars; otherwise it counts as tariff.
            if (amount.isLessThanOrEqualTo(serviceAmount) || payload.getCarType() != CarType.RESERVED_SEAT) {
                serviceAmount = Money.zero(ProtoCurrencyUnit.RUB);
            }
            Money tariffAmount = amount.subtract(serviceAmount);
            Money serviceFee = serviceAmount.multiply(feeFactor);
            Money fee = tariffAmount.multiply(feeFactor);
            Money minimalFee = tariffAmount.multiply(properties.getLocalMinFeeFactor()).add(partnerFee);

            Money fullFee = serviceFee.add(minimalFee.isGreaterThan(fee) ? minimalFee : fee);
            fullFee = fullFee.with(Monetary.getDefaultRounding());
            ticket.setFeeAmount(fullFee);
        }
    }

    /**
     * Requests fees from the train bandit API using the payload's string bandit context and applies them.
     *
     * @throws IllegalStateException if the bandit does not return exactly one charge context
     */
    private void calculateFeesBandit(TrainOrderItem orderItem) {
        TrainReservation payload = orderItem.getPayload();
        var chargeRequest = TGetChargeStringCtxRequest.newBuilder();
        if (!Strings.isNullOrEmpty(payload.getReservationRequestData().getBanditType())) {
            chargeRequest.setBanditType(payload.getReservationRequestData().getBanditType());
        }
        var requestPrices = TTicketPricesStringCtx.newBuilder();
        // A single context is sent, so its internal id is fixed to 0.
        requestPrices.setInternalId(0);
        requestPrices.setContext(payload.getBanditContext());
        List<TrainPassenger> passengers = payload.getPassengers();
        for (int i = 0; i < passengers.size(); i++) {
            requestPrices.putTicketPrices(i, toTicketPrice(passengers.get(i).getTicket()));
        }
        chargeRequest.addContextsWithPrices(requestPrices.build());

        TGetChargeStringCtxResponse chargeResponse = banditClient.getCharge(chargeRequest.build());
        int contextCount = chargeResponse.getChargesByContextsCount();
        Preconditions.checkState(contextCount == 1,
                "Expected exactly one charge context from train bandit, got %s", contextCount);

        applyBanditCharges(payload, chargeResponse.getChargesByContexts(0),
                ProtoUtils.fromTPrice(chargeResponse.getPartnerFee()),
                ProtoUtils.fromTPrice(chargeResponse.getPartnerRefundFee()));
    }

    /**
     * Requests fees from the train bandit API using the payload's bandit token and applies them.
     */
    private void calculateFeesBanditByToken(TrainOrderItem orderItem) {
        TrainReservation payload = orderItem.getPayload();
        var chargeRequest = TGetChargeByTokenRequest.newBuilder();
        chargeRequest.setFeeCalculationToken(payload.getBanditToken());
        List<TrainPassenger> passengers = payload.getPassengers();
        for (int i = 0; i < passengers.size(); i++) {
            chargeRequest.putTicketPrices(i, toTicketPrice(passengers.get(i).getTicket()));
        }

        TGetChargeByTokenResponse chargeResponse = banditClient.getChargeByToken(chargeRequest.build());
        applyBanditCharges(payload, chargeResponse.getChargesByContext(),
                ProtoUtils.fromTPrice(chargeResponse.getPartnerFee()),
                ProtoUtils.fromTPrice(chargeResponse.getPartnerRefundFee()));
    }

    /**
     * Builds the bandit request price of a ticket: total amount (tariff + service) and the service amount.
     */
    private static TTicketPrice toTicketPrice(TrainTicket ticket) {
        return TTicketPrice.newBuilder()
                .setAmount(ProtoUtils.toTPrice(ticket.getTariffAmount().add(ticket.getServiceAmount())))
                .setServiceAmount(ProtoUtils.toTPrice(ticket.getServiceAmount()))
                .build();
    }

    /**
     * Stores bandit charges in the payload and its tickets. Ticket fees are looked up by the passenger's index,
     * which is the same key used when building the request prices.
     *
     * @param payload          reservation to update
     * @param charges          charges returned by the bandit for this reservation
     * @param partnerFee       partner fee from the bandit response, assigned to every ticket
     * @param partnerRefundFee partner refund fee from the bandit response, assigned to every ticket
     * @throws IllegalArgumentException if the bandit returned no fee for some passenger index
     */
    private static void applyBanditCharges(TrainReservation payload, TChargeStringCtx charges,
                                           Money partnerFee, Money partnerRefundFee) {
        payload.setBanditType(charges.getBanditType());
        payload.setBanditVersion(charges.getBanditVersion());
        payload.setPermille(charges.getPermille());
        payload.setFeeSource(FeeSource.TRAIN_BANDIT_API);
        List<TrainPassenger> passengers = payload.getPassengers();
        for (int i = 0; i < passengers.size(); i++) {
            TTicketFee responseFees = charges.getTicketFeesOrThrow(i);
            TrainTicket ticket = passengers.get(i).getTicket();
            ticket.setPartnerFee(partnerFee);
            ticket.setPartnerRefundFee(partnerRefundFee);
            Money fullFee = ProtoUtils.fromTPrice(responseFees.getFee())
                    .add(ProtoUtils.fromTPrice(responseFees.getServiceFee()));
            ticket.setFeeAmount(fullFee);
            ticket.setBanditFeeApplied(responseFees.getIsBanditFeeApplied());
        }
    }
}
