package ru.yandex.travel.orders.grpc.helpers;

import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;

import com.google.common.base.Strings;
import com.google.protobuf.Message;
import io.grpc.stub.StreamObserver;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.springframework.stereotype.Service;
import org.springframework.transaction.support.TransactionTemplate;

import ru.yandex.travel.commons.grpc.ServerUtils;
import ru.yandex.travel.commons.logging.NestedMdc;
import ru.yandex.travel.commons.proto.Error;
import ru.yandex.travel.orders.OrdersApplicationProperties;
import ru.yandex.travel.orders.configurations.jdbc.TxScope;
import ru.yandex.travel.orders.configurations.jdbc.TxScopeType;
import ru.yandex.travel.orders.grpc.CallType;
import ru.yandex.travel.orders.grpc.GrpcExceptionHelper;
import ru.yandex.travel.orders.infrastructure.CallDescriptor;
import ru.yandex.travel.orders.services.orders.OrderAggregateStateRefresher;
import ru.yandex.travel.workflow.ClientCallService;

@Service
@Slf4j
@RequiredArgsConstructor
public class TxCallWrapper {

    private final OrdersApplicationProperties ordersApplicationProperties;

    private final TransactionTemplate transactionTemplate;

    private final ClientCallService clientCallService;

    private final OrderAggregateStateRefresher orderAggregateStateRefresher;

    /**
     * Convenience overload of
     * {@link #synchronouslyWithTx(CallDescriptor, StreamObserver, Logger, Function)} that logs through
     * this class's own logger.
     */
    public <ReqT extends Message, RspT extends Message> void synchronouslyWithTx(
            CallDescriptor<ReqT> callDescriptor,
            StreamObserver<RspT> observer,
            Function<ReqT, RspT> handler) {
        synchronouslyWithTx(callDescriptor, observer, log, handler);
    }

    /**
     * Same as the five-argument overload, always using the {@link TxScopeType#GRPC} transaction scope.
     */
    public <ReqT extends Message, RspT extends Message> void synchronouslyWithTx(
            CallDescriptor<ReqT> callDescriptor,
            StreamObserver<RspT> observer,
            Logger log,
            Function<ReqT, RspT> handler
    ) {
        synchronouslyWithTx(callDescriptor, observer, log, handler, TxScopeType.GRPC);
    }

    /**
     * Handles a gRPC call synchronously, running {@code handler} inside a transaction.
     * <p>
     * The call's preconditions are checked first (see {@link #ensureOperationPreconditions(CallDescriptor)}).
     * The handler then runs in a {@link TransactionTemplate} transaction opened under {@code txScopeType}.
     * For READ_WRITE calls that carry a non-empty app call id, the handler is routed through
     * {@link ClientCallService#computeIfAbsent}, keyed by that id, so the response can be cached per call id.
     * Any exception is translated by {@link GrpcExceptionHelper#mapStatusException}.
     *
     * @param callDescriptor the incoming request together with its call id and call type
     * @param observer       receives the handler's response
     * @param log            logger used by the call plumbing and the error mapper
     * @param handler        business logic turning the request into a response
     * @param txScopeType    transaction scope the handler runs under
     */
    public <ReqT extends Message, RspT extends Message> void synchronouslyWithTx(
            CallDescriptor<ReqT> callDescriptor,
            StreamObserver<RspT> observer,
            Logger log,
            Function<ReqT, RspT> handler,
            TxScopeType txScopeType
    ) {
        ServerUtils.synchronously(log, callDescriptor.getRequest(), observer,
                rq -> {
                    String callId = callDescriptor.getCallId();
                    ensureOperationPreconditions(callDescriptor);
                    return TxScope.supplyWithTxScope(txScopeType,
                            () -> transactionTemplate.execute(status -> {
                                boolean deduplicate = !Strings.isNullOrEmpty(callId)
                                        && callDescriptor.getCallType() == CallType.READ_WRITE;
                                if (!deduplicate) {
                                    return handler.apply(rq);
                                }
                                return clientCallService.computeIfAbsent(callId, () -> handler.apply(rq));
                            }));
                },
                ex -> GrpcExceptionHelper.mapStatusException(log, callDescriptor.getRequest(), ex)
        );
    }

    /**
     * Runs a database action in a transaction, then hands its result to an asynchronous network call that
     * runs outside of that transaction.
     * <p>
     * Use this when a call must touch the database and then query another remote service without tying up
     * a thread while waiting for the remote answer.
     *
     * @param request  the incoming gRPC request
     * @param observer receives the response produced by {@code netCall}
     * @param log      logger used by the call plumbing and the error mapper
     * @param dbAction transactional step, run under {@link TxScopeType#GRPC}
     * @param netCall  non-transactional step, fed with the result of {@code dbAction}
     * @implNote unlike {@link #synchronouslyWithTx(CallDescriptor, StreamObserver, Logger, Function)},
     * responses are not cached by app call id here.
     */
    public <ReqT extends Message, Inter, RspT extends Message> void synchronouslyWithTxAndThenAsync(
            ReqT request,
            StreamObserver<RspT> observer,
            Logger log,
            Function<ReqT, Inter> dbAction,
            Function<Inter, CompletableFuture<RspT>> netCall
    ) {
        ServerUtils.asynchronously(log, request, observer,
                rq -> {
                    Inter dbResult = TxScope.supplyWithTxScope(TxScopeType.GRPC,
                            () -> transactionTemplate.execute(status -> dbAction.apply(rq)));
                    // execute(...) has already returned here, so the remote call is made outside the transaction.
                    return netCall.apply(dbResult);
                },
                ex -> GrpcExceptionHelper.mapStatusException(log, request, ex)
        );
    }

    /**
     * Handles a gRPC call that targets a single order, running {@code handler} inside a transaction.
     * <p>
     * Steps:
     * <ul>
     *     <li>{@code orderIdStr} is validated as a UUID.</li>
     *     <li>Call preconditions are checked.</li>
     *     <li>The order id is pushed into the logging MDC for the duration of the call.</li>
     *     <li>The handler runs under the {@link TxScopeType#READ_ONLY} scope for READ_ONLY calls and under
     *     {@link TxScopeType#GRPC} otherwise.</li>
     *     <li>READ_WRITE calls with a non-empty app call id go through
     *     {@link ClientCallService#computeIfAbsent}.</li>
     *     <li>After a READ_WRITE handler runs, the order's aggregate state is refreshed when
     *     {@code refreshOrderAggregateStateOnWrite} is enabled. The refresh happens in the same supplier as
     *     the handler, so it runs exactly when the handler does.</li>
     *     <li>Exceptions are logged at error level through {@code log}, rethrown, and then mapped by
     *     {@link GrpcExceptionHelper#mapStatusException}.</li>
     * </ul>
     */
    public <ReqT extends Message, RspT extends Message> void synchronouslyWithTxForOrder(
            String orderIdStr,
            CallDescriptor<ReqT> callDescriptor,
            StreamObserver<RspT> observer,
            Logger log,
            Function<ReqT, RspT> handler
    ) {
        ServerUtils.synchronously(log, callDescriptor.getRequest(), observer,
                rq -> {
                    UUID orderId = ProtoChecks.checkStringIsUuid("Order id", orderIdStr);
                    String callId = callDescriptor.getCallId();
                    CallType callType = callDescriptor.getCallType();
                    ensureOperationPreconditions(callDescriptor);
                    TxScopeType scopeType = callType == CallType.READ_ONLY
                            ? TxScopeType.READ_ONLY
                            : TxScopeType.GRPC;
                    try (var mdcScope = NestedMdc.forEntity(orderId, null)) {
                        // The catch sits inside the resource block so the error is logged while the
                        // order's MDC entry is still in place.
                        try {
                            return TxScope.supplyWithTxScope(scopeType,
                                    () -> transactionTemplate.execute(status -> {
                                        boolean deduplicate = !Strings.isNullOrEmpty(callId)
                                                && callType == CallType.READ_WRITE;
                                        if (!deduplicate) {
                                            return invokeAndRefreshAggregate(rq, handler, orderId, callType);
                                        }
                                        return clientCallService.computeIfAbsent(callId,
                                                () -> invokeAndRefreshAggregate(rq, handler, orderId, callType));
                                    }));
                        } catch (Exception e) {
                            log.error("An exception has happened during the call", e);
                            throw e;
                        }
                    }
                },
                ex -> GrpcExceptionHelper.mapStatusException(log, callDescriptor.getRequest(), ex)
        );
    }

    /**
     * Validates the preconditions of a call before it is executed.
     * <p>
     * Only READ_WRITE calls are checked. When {@code requireAppCallIdForWrite} is enabled, a missing or empty
     * app call id fails the {@link Error#checkState} check. Otherwise a missing call id only produces a warning.
     *
     * @param callDescriptor descriptor of the call being validated
     */
    public void ensureOperationPreconditions(CallDescriptor<?> callDescriptor) {
        if (callDescriptor.getCallType() != CallType.READ_WRITE) {
            return;
        }
        boolean callIdMissing = Strings.isNullOrEmpty(callDescriptor.getCallId());
        if (ordersApplicationProperties.isRequireAppCallIdForWrite()) {
            Error.checkState(!callIdMissing, "App call id required but not present");
            return;
        }
        if (callIdMissing) {
            log.warn("Calling READ_WRITE operation {} without call id. Will fail if call id is required",
                    callDescriptor.getName());
        }
    }

    /**
     * Applies {@code handler} to {@code request}. For READ_WRITE calls with
     * {@code refreshOrderAggregateStateOnWrite} enabled, it then refreshes the aggregate state of
     * {@code orderId}.
     *
     * @return the handler's response, unchanged
     */
    private <ReqT, RspT> RspT invokeAndRefreshAggregate(
            ReqT request, Function<ReqT, RspT> handler, UUID orderId, CallType callType) {
        RspT response = handler.apply(request);
        boolean refreshNeeded = ordersApplicationProperties.isRefreshOrderAggregateStateOnWrite()
                && callType == CallType.READ_WRITE;
        if (refreshNeeded) {
            orderAggregateStateRefresher.refreshOrderAggregateState(orderId);
        }
        return response;
    }

}
