package ru.yandex.yp.discovery.impl;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.protobuf.ByteString;
import io.grpc.Metadata;
import io.grpc.StatusRuntimeException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.yp.discovery.model.YpError;
import ru.yandex.yp.discovery.model.YpException;
import ru.yandex.yt.TError;
import ru.yandex.yt.ytree.TAttribute;
import ru.yandex.yt.ytree.TAttributeDictionary;

/**
 * {@link CompletableFuture} adapter over a Guava {@link ListenableFuture}.
 *
 * <p>Successful results and failures of the inner future are mirrored into this future on the
 * thread that completes the inner future ({@link MoreExecutors#directExecutor()}). A gRPC
 * {@link StatusRuntimeException} whose trailers carry a binary {@code yt-error-bin} entry is
 * translated into a {@link YpException} holding the decoded {@link YpError}. Any other failure,
 * or a trailer that cannot be decoded, is propagated unchanged.
 *
 * @author Dmitriy Timashov <dm-tim@yandex-team.ru>
 */
public class YpCompletableFuture<T> extends CompletableFuture<T> {

    private static final Logger logger = LoggerFactory.getLogger(YpCompletableFuture.class);

    /** gRPC trailer holding a protobuf-serialized {@link TError}. */
    private static final Metadata.Key<byte[]> YT_ERROR_BIN =
            Metadata.Key.of("yt-error-bin", Metadata.BINARY_BYTE_MARSHALLER);

    private final ListenableFuture<T> inner;

    /**
     * Wraps {@code inner} so that its outcome completes this future.
     *
     * @param inner the future to adapt; its success value completes this future normally, and
     *              its failure completes this future exceptionally (see {@link #translateFailure})
     */
    public YpCompletableFuture(ListenableFuture<T> inner) {
        this.inner = inner;
        Futures.addCallback(inner, new FutureCallback<T>() {

            @Override
            public void onSuccess(T result) {
                YpCompletableFuture.this.complete(result);
            }

            @Override
            public void onFailure(Throwable t) {
                // translateFailure never throws, so this future is always completed; an exception
                // escaping a Guava callback would leave it pending forever.
                YpCompletableFuture.this.completeExceptionally(translateFailure(t));
            }
        }, MoreExecutors.directExecutor());
    }

    /**
     * Cancels the inner future and, only if that succeeded, this future as well.
     *
     * <p>If the inner future could not be cancelled (typically because it already completed and
     * its callback is about to deliver the result), this future is left untouched so the real
     * outcome is not lost, and {@code false} is returned as the {@link java.util.concurrent.Future}
     * contract requires.
     *
     * @param mayInterruptIfRunning passed through to the inner future
     * @return {@code true} if the inner future was cancelled by this call
     */
    @Override
    public boolean cancel(boolean mayInterruptIfRunning) {
        if (isDone()) {
            return false;
        }
        boolean cancelled = inner.cancel(mayInterruptIfRunning);
        if (cancelled) {
            // If the inner future notifies listeners synchronously, the callback has already
            // propagated the cancellation and this is a no-op; otherwise it cancels us directly.
            super.cancel(mayInterruptIfRunning);
        }
        return cancelled;
    }

    /**
     * Converts a failure of the inner future into the exception this future completes with.
     *
     * @param t the failure reported by the inner future
     * @return a {@link YpException} when {@code t} is a {@link StatusRuntimeException} carrying a
     *         decodable {@code yt-error-bin} trailer; otherwise {@code t} itself. Never throws.
     */
    private Throwable translateFailure(Throwable t) {
        if (!(t instanceof StatusRuntimeException)) {
            return t;
        }
        StatusRuntimeException grpcException = (StatusRuntimeException) t;
        // getTrailers() is nullable: it is null when the exception was created from a bare Status.
        Metadata trailers = grpcException.getTrailers();
        if (trailers == null || !trailers.containsKey(YT_ERROR_BIN)) {
            return t;
        }
        try {
            TError error = TError.parseFrom(trailers.get(YT_ERROR_BIN));
            YpError ypError = mapError(error);
            String message = ypError.getMessage().orElse(grpcException.getMessage());
            return new YpException(message, grpcException.getStatus().getDescription(),
                    ypError, grpcException);
        } catch (Exception e) {
            // Broad catch on purpose: any decoding/mapping problem must fall back to the original
            // gRPC failure rather than leave this future uncompleted.
            logger.error("Failed to deserialize YP error from response", e);
            return t;
        }
    }

    /**
     * Recursively converts a protobuf {@link TError} into a {@link YpError}.
     *
     * @param value the error to convert
     * @return the converted error; the message is {@code null} when absent, attributes are empty
     *         when absent, and inner errors are converted recursively
     */
    private YpError mapError(TError value) {
        return new YpError(value.getCode(),
                value.hasMessage() ? value.getMessage() : null,
                value.hasAttributes() ? mapAttributes(value.getAttributes()) : new HashMap<>(),
                value.getInnerErrorsList().stream().map(this::mapError).collect(Collectors.toList()));
    }

    /**
     * Converts an attribute dictionary into a key-to-raw-value map.
     *
     * @param value the dictionary to convert
     * @return a mutable map of attribute key to its serialized value; on duplicate keys the last
     *         occurrence wins (instead of failing the whole error translation)
     */
    private Map<String, ByteString> mapAttributes(TAttributeDictionary value) {
        return value.getAttributesList().stream()
                .collect(Collectors.toMap(TAttribute::getKey, TAttribute::getValue,
                        (first, second) -> second));
    }

}
