package ru.yandex.travel.grpc.interceptors;

import java.util.Optional;

import com.google.common.base.Preconditions;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ClientInterceptors;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.stub.AbstractStub;
import lombok.extern.slf4j.Slf4j;

import ru.yandex.travel.tvm.TvmWrapper;

/**
 * Helper that attaches a TVM service-ticket header to outgoing gRPC calls.
 *
 * <p>The ticket for the destination alias is fetched from the wrapped {@link TvmWrapper} at the
 * start of every call; nothing is cached in this class.
 */
@Slf4j
public class TvmHeaderClientHelper {
    private final TvmWrapper tvm2;

    /**
     * @param tvm2 TVM client used to validate destination aliases and to obtain service tickets
     */
    public TvmHeaderClientHelper(TvmWrapper tvm2) {
        this.tvm2 = tvm2;
    }

    /**
     * Returns a stub derived from {@code stub} with a TVM header interceptor for {@code dstAlias}
     * attached.
     *
     * @param stub     gRPC stub to decorate
     * @param dstAlias TVM alias of the destination service; must not be null
     * @return the stub produced by {@code stub.withInterceptors(...)}
     * @throws IllegalArgumentException if {@code dstAlias} is null
     */
    public <T extends AbstractStub<T>> T withTvmInterceptor(T stub, String dstAlias) {
        // Reuse the single construction path so both entry points validate identically.
        return stub.withInterceptors(getInterceptor(dstAlias));
    }

    /**
     * Creates a client interceptor that adds the service ticket for {@code dstAlias} to each call.
     *
     * @param dstAlias TVM alias of the destination service; must not be null
     * @return a new interceptor bound to {@code dstAlias}
     * @throws IllegalArgumentException if {@code dstAlias} is null
     */
    public ClientInterceptor getInterceptor(String dstAlias) {
        return new TvmSettingHeaderInterceptor(dstAlias);
    }

    /**
     * Interceptor that puts the TVM service ticket for a fixed destination alias into the request
     * metadata before the call is started. Non-static on purpose: it uses the enclosing {@code tvm2}.
     */
    private final class TvmSettingHeaderInterceptor implements ClientInterceptor {
        private final String dstAlias;

        TvmSettingHeaderInterceptor(String dstAlias) {
            Preconditions.checkArgument(dstAlias != null,
                    "You must provide a non-null dstAlias for tvm client interceptor");
            this.dstAlias = dstAlias;
            // NOTE(review): presumably rejects aliases unknown to the TVM client so misconfiguration
            // fails at wiring time rather than on the first call — confirm against TvmWrapper.
            tvm2.validateAlias(dstAlias);
        }

        @Override
        public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method,
                                                                   CallOptions callOptions, Channel next) {
            return new ClientInterceptors.CheckedForwardingClientCall<>(next.newCall(method, callOptions)) {
                /**
                 * Looks up the service ticket and adds it to {@code headers} before starting the
                 * delegate call. If no ticket is available, a FAILED_PRECONDITION
                 * StatusRuntimeException is thrown; CheckedForwardingClientCall turns exceptions
                 * from checkedStart into a failed call reported to {@code responseListener}.
                 */
                @Override
                protected void checkedStart(Listener<RespT> responseListener, Metadata headers) {
                    Optional<String> maybeServiceTicket = tvm2.getServiceTicketOptional(dstAlias);
                    String serviceTicket = maybeServiceTicket.orElseThrow(
                            () -> Status.FAILED_PRECONDITION
                                    .withDescription(
                                            String.format("Couldn't get service ticket for dstAlias: %s", dstAlias)
                                    ).asRuntimeException()
                    );
                    headers.put(TvmHeaders.METADATA_SERVICE_TICKET_HEADER, serviceTicket);
                    delegate().start(responseListener, headers);
                }
            };
        }
    }
}
