package ru.yandex.intranet.d.services.integration.providers.grpc;

import java.util.function.Supplier;

import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall;
import io.grpc.ForwardingClientCallListener;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;

import ru.yandex.intranet.d.services.integration.providers.ProvidersIntegrationService;

/**
 * GRPC client interceptor to handle request id.
 * <p>
 * Outgoing direction: if the call options carry a {@link #REQUEST_ID_KEY} supplier, the supplier is
 * invoked when the call starts. A non-null result is sent in the
 * {@code ProvidersIntegrationService.REQUEST_ID_KEY} request header.
 * <p>
 * Incoming direction: if the call options carry a {@link #REQUEST_ID_HOLDER_KEY} holder, the holder is
 * reset to {@code null} when the call is created. It is then set to the request id found in the response
 * headers, if the server sends one.
 * <p>
 * Both call options are optional; a call without them passes through unchanged.
 *
 * @author Dmitriy Timashov <dm-tim@yandex-team.ru>
 */
public class RequestIdInterceptor implements ClientInterceptor {

    /** Call option supplying the request id to send; the supplier is invoked when the call starts. */
    public static final CallOptions.Key<Supplier<String>> REQUEST_ID_KEY = CallOptions.Key
            .create("X-Request-ID");
    /** Call option carrying the holder that receives the request id from the response headers. */
    public static final CallOptions.Key<RequestIdHolder> REQUEST_ID_HOLDER_KEY = CallOptions.Key
            .create("X-Request-ID-Holder");

    @Override
    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method,
                                                               CallOptions callOptions, Channel next) {
        // Both options default to null when the caller did not set them.
        Supplier<String> requestIdSupplier = callOptions.getOption(REQUEST_ID_KEY);
        RequestIdHolder requestIdHolder = callOptions.getOption(REQUEST_ID_HOLDER_KEY);
        if (requestIdHolder != null) {
            // Reset first: the holder is only written when a response header arrives, so without this
            // a value from an earlier use of the same holder could be observed.
            requestIdHolder.setRequestId(null);
        }
        return new ForwardingClientCall.SimpleForwardingClientCall<>(next.newCall(method, callOptions)) {
            @Override
            public void start(Listener<RespT> responseListener, Metadata headers) {
                if (requestIdSupplier != null) {
                    String requestId = requestIdSupplier.get();
                    // Metadata.put rejects null values; omit the header instead of failing the call.
                    if (requestId != null) {
                        headers.put(ProvidersIntegrationService.REQUEST_ID_KEY, requestId);
                    }
                }
                if (requestIdHolder == null) {
                    // Nobody to report the response request id to, so no need to wrap the listener.
                    super.start(responseListener, headers);
                    return;
                }
                super.start(new ForwardingClientCallListener.SimpleForwardingClientCallListener<>(responseListener) {
                    @Override
                    public void onHeaders(Metadata responseHeaders) {
                        String responseRequestId = responseHeaders.get(ProvidersIntegrationService.REQUEST_ID_KEY);
                        if (responseRequestId != null) {
                            requestIdHolder.setRequestId(responseRequestId);
                        }
                        super.onHeaders(responseHeaders);
                    }
                }, headers);
            }
        };
    }

}
