package ru.yandex.chemodan.ydb.dao.twin;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import org.joda.time.Duration;
import org.joda.time.Instant;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.Option;
import ru.yandex.bolts.collection.SetF;
import ru.yandex.bolts.function.Function0;
import ru.yandex.bolts.function.Function2B;
import ru.yandex.chemodan.http.YandexCloudRequestIdHolder;
import ru.yandex.misc.ExceptionUtils;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.log.mlf.LoggerFactory;


/**
 * {@link InvocationHandler} for a dynamic proxy that dispatches each DAO call to one or both of two
 * implementations of the same interface: a "primary" and a "secondary" DAO.
 * <p>
 * The routing mode is re-read from {@code getModeF} on every call, so it can be switched at runtime.
 * Asynchronous calls are submitted to two fixed-size thread pools (one per DAO) with bounded queues.
 * A submission that does not fit is rejected; the rejection is logged and surfaced as an exceptionally
 * completed future instead of blocking the caller.
 *
 * @author yashunsky
 */
public class TwinDaoInvocationHandler<T> implements InvocationHandler {
    private static final Logger logger = LoggerFactory.getLogger(TwinDaoInvocationHandler.class);

    /** Method names whose per-call timing/failure log lines are suppressed in {@link #invokeWithLogging}. */
    private static final SetF<String> DO_NOT_LOG_INVOCATION_METHODS = Cf.set("equals", "hashCode");

    /**
     * Ready-made {@code looksSuspiciousF}: an empty {@link Collection} or {@link Map} is suspicious,
     * which makes fallback mode wait for the other DAO's answer. Any other value, including
     * {@code null}, is not suspicious. The method name is ignored.
     */
    public static final Function2B<String, Object> EMPTY_RESULT = (methodName, result) -> {
        if (result instanceof Collection) {
            return ((Collection<?>) result).isEmpty();
        } else if (result instanceof Map) {
            return ((Map<?, ?>) result).isEmpty();
        }
        return false;
    };

    private final T primaryDao;
    private final T secondaryDao;

    /** Runs asynchronous calls against {@link #primaryDao}. */
    private final ExecutorService primaryExecutor;
    /** Runs asynchronous calls against {@link #secondaryDao}. */
    private final ExecutorService secondaryExecutor;

    /** Supplies how long fallback mode waits for a fast answer; read on each fallback-mode call. */
    private final Function0<Duration> getFallbackTimeoutF;
    /** Supplies the routing mode; read on every call. */
    private final Function0<TwinDaoMode> getModeF;

    /** Returns true when a (method name, result) pair should not be trusted without the other DAO's answer. */
    private final Function2B<String, Object> looksSuspiciousF;

    /**
     * @param primaryDao          primary implementation
     * @param secondaryDao        secondary implementation
     * @param threadsCount        number of worker threads in each of the two pools
     * @param queueSize           extra queue capacity; each pool's queue holds at most
     *                            {@code threadsCount + queueSize} pending tasks
     * @param getFallbackTimeoutF supplier of the fallback timeout
     * @param getModeF            supplier of the routing mode
     * @param looksSuspiciousF    predicate over (method name, result); see {@link #EMPTY_RESULT}
     */
    public TwinDaoInvocationHandler(
            T primaryDao, T secondaryDao,
            int threadsCount,
            int queueSize,
            Function0<Duration> getFallbackTimeoutF,
            Function0<TwinDaoMode> getModeF,
            Function2B<String, Object> looksSuspiciousF)
    {
        this.primaryDao = primaryDao;
        this.secondaryDao = secondaryDao;
        this.primaryExecutor = newBoundedExecutor(threadsCount, queueSize);
        this.secondaryExecutor = newBoundedExecutor(threadsCount, queueSize);
        this.getFallbackTimeoutF = getFallbackTimeoutF;
        this.getModeF = getModeF;
        this.looksSuspiciousF = looksSuspiciousF;
    }

    /**
     * Creates a pool of exactly {@code threadsCount} threads whose work queue holds at most
     * {@code threadsCount + queueSize} tasks. Further submissions throw {@link RejectedExecutionException}
     * (abort policy).
     */
    private static ExecutorService newBoundedExecutor(int threadsCount, int queueSize) {
        return new ThreadPoolExecutor(threadsCount, threadsCount,
                0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(threadsCount + queueSize),
                new ThreadPoolExecutor.AbortPolicy());
    }

    /**
     * Dispatches {@code method} according to the current {@link TwinDaoMode}:
     * <ul>
     *   <li>{@code PRIMARY} / {@code SECONDARY}: call only that DAO;</li>
     *   <li>{@code PRIMARY_W_ASYNC} / {@code SECONDARY_W_ASYNC}: fire the other DAO asynchronously
     *       (its result is discarded) and return the named DAO's result;</li>
     *   <li>{@code SECONDARY_W_FALLBACK}: race secondary against primary, see {@link #invokeWithFallback};</li>
     *   <li>{@code BOTH_FROM_PRIMARY} / {@code BOTH_FROM_SECONDARY}: call both synchronously, compare
     *       the results and return the named DAO's result, see {@link #invokeBoth}.</li>
     * </ul>
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        TwinDaoMode mode = getModeF.apply();
        switch (mode) {
            case PRIMARY:
                return invokeInner(primaryDao, method, args);
            case SECONDARY:
                return invokeInner(secondaryDao, method, args);
            case PRIMARY_W_ASYNC:
                invokeAsync(secondaryExecutor, secondaryDao, method, args);
                return invokeInner(primaryDao, method, args);
            case SECONDARY_W_FALLBACK:
                return invokeWithFallback(secondaryExecutor, secondaryDao, primaryExecutor, primaryDao, method, args);
            case SECONDARY_W_ASYNC:
                invokeAsync(primaryExecutor, primaryDao, method, args);
                return invokeInner(secondaryDao, method, args);
            case BOTH_FROM_PRIMARY:
                return invokeBoth(primaryDao, secondaryDao, method, args);
            case BOTH_FROM_SECONDARY:
                return invokeBoth(secondaryDao, primaryDao, method, args);
            default:
                throw new IllegalStateException("Unsupported mode " + mode);
        }
    }

    /**
     * Calls {@code mainDao} and then {@code referenceDao} synchronously on the calling thread.
     * If the results differ, logs a warning. Returns the main result.
     * <p>
     * If either call throws, the failure is logged (naming which side failed) and rethrown.
     * Throwables that {@code ExceptionUtils.throwIfUnrecoverable} rethrows are not logged here.
     */
    private Object invokeBoth(T mainDao, T referenceDao, Method method, Object[] args) throws Throwable {
        boolean mainSuccessful = false;
        Object mainResult;
        Object referenceResult;
        try {
            mainResult = invokeInner(mainDao, method, args);
            mainSuccessful = true;
            referenceResult = invokeInner(referenceDao, method, args);
        } catch (Throwable t) {
            ExceptionUtils.throwIfUnrecoverable(t);
            String errorString = "failed to invoke "
                    + getInvokeString(mainSuccessful ? referenceDao : mainDao, method, args)
                    + " (" + (mainSuccessful ? "reference" : "main") + ")";
            logger.error(errorString, t);
            throw t;
        }
        // Null-safe comparison: void methods (and DAOs that return null) produce null results.
        if (!Objects.equals(mainResult, referenceResult)) {
            logger.warn("reference check failed for {}: main: {} reference: {}",
                    getInvokeString(mainDao, method, args),
                    mainResult, referenceResult);
        }
        return mainResult;
    }

    /**
     * Races {@code mainDao} against {@code fallbackDao} on their executors and returns the first
     * acceptable answer:
     * <ol>
     *   <li>If both submissions have already failed (e.g. both queues are full), calls
     *       {@code fallbackDao} synchronously.</li>
     *   <li>If the main call succeeds within the fallback timeout with a non-suspicious result,
     *       returns it.</li>
     *   <li>Otherwise, if the fallback call succeeds before the same deadline, returns its result.</li>
     *   <li>Otherwise waits without a time limit for whichever call completes first. If that call
     *       failed, waits for the other one. A failure of that remaining call propagates as
     *       {@link ExecutionException}.</li>
     * </ol>
     * The call that is not used is never cancelled; it keeps running in the background.
     */
    private Object invokeWithFallback(
            ExecutorService mainExecutor, T mainDao,
            ExecutorService fallbackExecutor, T fallbackDao, Method method, Object[] args) throws Throwable
    {
        long fallbackTimeout = getFallbackTimeoutF.apply().getMillis();
        CompletableFuture<Object> mainTask = invokeAsync(mainExecutor, mainDao, method, args);
        CompletableFuture<Object> fallbackTask = invokeAsync(fallbackExecutor, fallbackDao, method, args);

        if (mainTask.isCompletedExceptionally() && fallbackTask.isCompletedExceptionally()) {
            // both executors are full
            logger.warn("Both executors full, returning fallback result");
            return invokeInner(fallbackDao, method, args);
        }

        long deadline = Instant.now().getMillis() + fallbackTimeout;

        // Main task is done, or we ran out of time. Return the main result if it doesn't look suspicious.
        boolean mainTaskDoneInTime = await(mainTask, fallbackTimeout);
        if (mainTaskDoneInTime) {
            Object mainResult = mainTask.get();
            if (looksNonSuspicious(method.getName(), mainResult)) {
                logger.info("Returning fast main result");
                return mainResult;
            }
        }

        // Wait for the fallback task until the original deadline.
        // This only blocks for real if the main task returned a suspicious result early.
        if (await(fallbackTask, deadline - Instant.now().getMillis())) {
            Object fallbackResult = fallbackTask.get();
            if (!mainTaskDoneInTime) {
                logger.info("Only fallback dao was in time: " + getInvokeString(fallbackDao, method, args));
            } else if (looksNonSuspicious(method.getName(), fallbackResult)) {
                logger.info("Result only found in fallback dao: " + getInvokeString(fallbackDao, method, args));
            }
            logger.info("Returning fast fallback result");
            return fallbackResult;
        }

        // None of the tasks returned a result in time. Wait for whichever finishes first.
        CompletableFuture<Object> anyTask = CompletableFuture.anyOf(mainTask, fallbackTask);
        if (await(anyTask)) {
            // one of the daos returned a result
            return anyTask.get();
        } else {
            // one of the daos failed; wait for the other
            if (mainTask.isCompletedExceptionally()) {
                logger.info("Returning slow fallback result");
                return fallbackTask.get();
            } else {
                logger.info("Returning slow main result");
                return mainTask.get();
            }
        }
    }

    private boolean looksNonSuspicious(String methodName, Object result) {
        return !looksSuspiciousF.apply(methodName, result);
    }

    /**
     * Submits {@link #invokeWithLogging} for {@code dao} to {@code executor}. The supplier is wrapped with
     * {@code YandexCloudRequestIdHolder.supplyWithYcrid}, presumably to carry the current request id
     * into the worker thread (TODO confirm).
     * <p>
     * If the executor rejects the task, logs a warning and returns a future already completed
     * exceptionally with the {@link RejectedExecutionException}.
     */
    private CompletableFuture<Object> invokeAsync(
            ExecutorService executor, T dao, Method method, Object[] args)
    {
        try {
            return CompletableFuture.supplyAsync(YandexCloudRequestIdHolder.supplyWithYcrid(
                    () -> invokeWithLogging(dao, method, args)), executor);
        } catch (RejectedExecutionException e) {
            logger.warn("failed to invoke {}. queue is full", getInvokeString(dao, method, args));
            CompletableFuture<Object> failedResult = new CompletableFuture<>();
            failedResult.completeExceptionally(e);
            return failedResult;
        }
    }

    /**
     * Invokes {@code method} on {@code dao} and logs its duration or failure. Logging is skipped for
     * methods in {@link #DO_NOT_LOG_INVOCATION_METHODS}.
     * Failures are rethrown wrapped in a {@link RuntimeException} so the call fits a {@code Supplier}.
     */
    private Object invokeWithLogging(T dao, Method method, Object[] args) {
        boolean doLog = !DO_NOT_LOG_INVOCATION_METHODS.containsTs(method.getName());
        String invokeString = getInvokeString(dao, method, args);
        Instant start = Instant.now();
        try {
            Object result = invokeInner(dao, method, args);
            if (doLog) {
                logger.info("{} took {}", invokeString, new Duration(start, Instant.now()));
            }
            return result;
        } catch (Throwable e) {
            ExceptionUtils.throwIfUnrecoverable(e);
            if (doLog) {
                logger.error("{} failed. took {}: {}", invokeString, new Duration(start, Instant.now()), e);
            }
            throw new RuntimeException(invokeString + " failed" , e);
        }
    }

    /** Builds a log-friendly description: DAO class simple name, method name and arguments. */
    private String getInvokeString(T dao, Method method, Object[] args) {
        return dao.getClass().getSimpleName() + " " +
                method.getName() + ((args == null) ? "()" : "(" + Cf.x(Arrays.asList(args)) + ")");
    }

    /**
     * Reflectively calls {@code method} on {@code dao}. For an exception thrown by the DAO itself,
     * that exception is rethrown unwrapped rather than as an {@link InvocationTargetException}.
     * {@link IllegalAccessException} is passed through {@code ExceptionUtils.translate}.
     */
    private Object invokeInner(T dao, Method method, Object[] args) throws Throwable {
        try {
            return method.invoke(dao, args);
        } catch (IllegalAccessException e) {
            throw ExceptionUtils.translate(e);
        } catch (InvocationTargetException e) {
            throw e.getTargetException();
        }
    }

    /** Waits up to {@code timeoutMillis} (negative means no wait); true iff the future completed normally. */
    private boolean await(CompletableFuture<?> future, long timeoutMillis) {
        return await(future, Option.of(timeoutMillis));
    }

    /** Waits without a time limit; true iff the future completed normally. */
    private boolean await(CompletableFuture<?> future) {
        return await(future, Option.empty());
    }

    /**
     * Blocks until {@code future} completes or the optional timeout elapses.
     *
     * @return true if the future completed normally; false on timeout, exceptional completion or interrupt
     */
    private boolean await(CompletableFuture<?> future, Option<Long> timeoutMillis) {
        try {
            if (timeoutMillis.isPresent()) {
                future.get(Math.max(timeoutMillis.get(), 0L), TimeUnit.MILLISECONDS);
            } else {
                future.get();
            }
            return true;
        } catch (InterruptedException e) {
            // Restore the flag so the caller and any later blocking get() observe the interrupt
            // instead of silently blocking on.
            Thread.currentThread().interrupt();
            return false;
        } catch (ExecutionException | TimeoutException e) {
            return false;
        }
    }
}
