package ru.yandex.travel.grpc.interceptors;

import java.util.Set;

import com.google.common.base.Preconditions;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCallListener;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;

/**
 * Client interceptor that transparently retries calls in which the client sends exactly one
 * request message (unary and server-streaming methods).
 *
 * <p>An attempt is retried immediately, without backoff, when it closes with a status code from
 * {@link #RETRYABLE_CODES} before any response headers arrived. Once headers arrive the attempt is
 * committed and all of its further events are forwarded to the caller. Methods whose client
 * streams requests are passed through without retries.
 *
 * <p>Intended for integration tests only.
 */
public class RetryingClientInterceptor implements ClientInterceptor {

    /** Status codes for which an uncommitted attempt is retried. */
    public static final Set<Status.Code> RETRYABLE_CODES = Set.of(
            Status.Code.UNAVAILABLE, Status.Code.ABORTED
    );

    /** Total number of attempts per call, including the first one. */
    private final int maxAttempts;

    private RetryingClientInterceptor(int maxAttempts) {
        this.maxAttempts = maxAttempts;
    }

    /** Returns a new builder; by default a call is attempted at most 3 times. */
    public static Builder builder() {
        return new Builder();
    }

    @Override
    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
            final MethodDescriptor<ReqT, RespT> method,
            final CallOptions callOptions,
            final Channel next) {
        if (!method.getType().clientSendsOneMessage()) {
            // Replaying an attempt requires buffering the whole request. For client/bidi streaming
            // that would silently keep only the last message, so such calls are not retried.
            return next.newCall(method, callOptions);
        }
        return new RetryingCall<>(method, callOptions, next, maxAttempts);
    }

    /**
     * Buffers the {@code start()} arguments and the single request message, then replays them on
     * a fresh underlying call for each attempt.
     */
    private static final class RetryingCall<ReqT, RespT> extends ClientCall<ReqT, RespT> {

        private final MethodDescriptor<ReqT, RespT> method;
        private final CallOptions callOptions;
        private final Channel next;
        private final int maxAttempts;

        // Mutated while holding the lock on "this": cancel() and request() may be invoked from an
        // application thread while a retry is being started from a callback thread.
        private Listener<RespT> listener;
        private Metadata metadata;
        private ReqT requestMessage;
        private int requested;
        private ClientCall<ReqT, RespT> call;
        private boolean cancelled;

        RetryingCall(
                MethodDescriptor<ReqT, RespT> method,
                CallOptions callOptions,
                Channel next,
                int maxAttempts) {
            this.method = method;
            this.callOptions = callOptions;
            this.next = next;
            this.maxAttempts = maxAttempts;
        }

        @Override
        public synchronized void start(Listener<RespT> responseListener, Metadata headers) {
            this.listener = responseListener;
            this.metadata = headers;
        }

        @Override
        public synchronized void sendMessage(ReqT message) {
            Preconditions.checkState(requestMessage == null, "Only one request message is supported");
            this.requestMessage = message;
        }

        @Override
        public synchronized void request(int numMessages) {
            requested += numMessages;
            if (call != null) {
                // Server-streaming callers keep requesting messages after the attempt has started.
                call.request(numMessages);
            }
        }

        /** Always {@code false}: the request is buffered and sent by this call itself. */
        @Override
        public boolean isReady() {
            return false;
        }

        @Override
        public synchronized void halfClose() {
            if (!cancelled) {
                startAttempt(1);
            }
        }

        @Override
        public void cancel(String message, Throwable cause) {
            ClientCall<ReqT, RespT> inFlight;
            Listener<RespT> callerListener;
            synchronized (this) {
                if (cancelled) {
                    return; // cancelling an already cancelled call has no effect
                }
                cancelled = true;
                inFlight = call;
                callerListener = listener;
            }
            if (inFlight != null) {
                // The attempt's listener forwards the resulting close (no retry happens once
                // cancelled), so the caller's listener is closed exactly once.
                inFlight.cancel(message, cause);
            } else if (callerListener != null) {
                // No attempt was started yet, so nothing else will close the caller's listener.
                // Technically this should run on CallOptions.getExecutor() if one is set.
                callerListener.onClose(
                        Status.CANCELLED.withDescription(message).withCause(cause), new Metadata());
            }
        }

        /** Starts attempt number {@code attempt}; the caller must hold the lock on {@code this}. */
        private void startAttempt(int attempt) {
            Preconditions.checkState(requestMessage != null, "halfClose() called before sendMessage()");
            ClientCall<ReqT, RespT> attemptCall = next.newCall(method, callOptions);
            call = attemptCall;
            // The channel may modify the Metadata passed to start(), so every attempt gets a copy.
            Metadata headers = new Metadata();
            headers.merge(metadata);
            attemptCall.start(new AttemptListener(attempt), headers);
            attemptCall.request(requested);
            attemptCall.sendMessage(requestMessage);
            attemptCall.halfClose();
        }

        /**
         * Listener for a single attempt. Response headers commit the attempt to the caller's
         * listener; a close without prior headers either starts the next attempt or is forwarded.
         */
        private final class AttemptListener extends ForwardingClientCallListener<RespT> {

            private final int attempt;
            // Null until the attempt is committed. Only touched by this attempt's callbacks,
            // which gRPC invokes one at a time.
            private Listener<RespT> delegate;

            AttemptListener(int attempt) {
                this.attempt = attempt;
            }

            @Override
            protected Listener<RespT> delegate() {
                if (delegate == null) {
                    throw new IllegalStateException(
                            "Attempt " + attempt + " received an event before response headers");
                }
                return delegate;
            }

            @Override
            public void onHeaders(Metadata headers) {
                delegate = listener;
                super.onHeaders(headers);
            }

            @Override
            public void onClose(Status status, Metadata trailers) {
                if (delegate == null && retry(status)) {
                    return;
                }
                delegate = listener;
                super.onClose(status, trailers);
            }

            /** Starts the next attempt if this uncommitted attempt may be retried; returns whether it did. */
            private boolean retry(Status status) {
                synchronized (RetryingCall.this) {
                    boolean retryable = !cancelled
                            && attempt < maxAttempts
                            && RETRYABLE_CODES.contains(status.getCode());
                    if (retryable) {
                        startAttempt(attempt + 1);
                    }
                    return retryable;
                }
            }
        }
    }

    /** Builder for {@link RetryingClientInterceptor}. */
    public static class Builder {
        private int maxAttempts = 3;

        /**
         * Sets the total number of attempts per call, including the first one.
         *
         * @throws IllegalArgumentException if {@code maxAttempts} is not positive
         */
        public Builder maxAttempts(int maxAttempts) {
            Preconditions.checkArgument(maxAttempts > 0, "Max attempts must be positive");
            this.maxAttempts = maxAttempts;
            return this;
        }

        /** Creates the interceptor with the configured number of attempts. */
        public RetryingClientInterceptor build() {
            return new RetryingClientInterceptor(maxAttempts);
        }
    }

}
