package ru.yandex.travel.yt.daos;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;
import java.util.function.Consumer;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.travel.yt.Factory;
import ru.yandex.travel.yt.YtDao;
import ru.yandex.travel.yt.queries.QueryPart;
import ru.yandex.travel.yt.switching.ErrorFlag;


public class MirroredYtDao<T> implements YtDao<T> {
    private YtDao<T> primary;
    private YtDao<T> secondary;
    private Logger logger;

    private ErrorFlag readError;
    private ErrorFlag writeError;


    public MirroredYtDao(Factory factory, String primaryCluster, String secondaryCluster, String basePath, Class<T> mappedClass) {
        this(factory, primaryCluster, secondaryCluster, basePath, mappedClass, null, null);
    }

    public MirroredYtDao(Factory factory, String primaryCluster, String secondaryCluster, String basePath, Class<T> mappedClass, String tableName, Long ttl) {
        primary = new SingleYtDao<T>(factory, primaryCluster, basePath, mappedClass, tableName, ttl);
        secondary = new SingleYtDao<T>(factory, secondaryCluster, basePath, mappedClass, tableName, ttl);
        this.logger = LoggerFactory.getLogger(this.getClass().getName() + "[" + mappedClass.getSimpleName() + "]");
        this.readError = new ErrorFlag(getTableName() + ":read", logger);
        this.writeError = new ErrorFlag(getTableName() + ":write", logger);
    }

    public MirroredYtDao(YtDao<T> primaryDao, YtDao<T> secondaryDao) {
        assert primaryDao.getTableName().equals(secondaryDao.getTableName());
        this.primary = primaryDao;
        this.secondary = secondaryDao;
        this.logger = LoggerFactory.getLogger(this.getClass().getName() + "[" + primaryDao.getMappedClass().getSimpleName() + "]");
        this.readError = new ErrorFlag(getTableName() + ":read", logger);
        this.writeError = new ErrorFlag(getTableName() + ":write", logger);
    }

    @Override
    public CompletableFuture<List<T>> select(List<QueryPart> queryParts) {
        return mirrorRead(YtDao<T>::select, queryParts, "select");
    }

    @Override
    public CompletableFuture<List<T>> select(QueryPart... queryParts) {
        return mirrorRead(YtDao<T>::select, queryParts, "select");
    }

    @Override
    public CompletableFuture<Optional<T>> get(List<Object> keyValues) {
        return mirrorRead(YtDao<T>::get, keyValues, "get");
    }

    @Override
    public CompletableFuture<Optional<T>> get(Object... keyValues) {
        return mirrorRead(YtDao<T>::get, keyValues, "get");
    }

    @Override
    public CompletableFuture<Void> put(T object) {
        return mirrorWrite(YtDao<T>::put, object, "put");
    }

    @Override
    public CompletableFuture<Void> put(List<T> objects) {
        return mirrorWrite(YtDao<T>::put, objects, "putMany");
    }

    @Override
    public CompletableFuture<Void> delete(List<Object> keyValues) {
        return mirrorWrite(YtDao<T>::delete, keyValues, "delete");
    }

    @Override
    public CompletableFuture<Void> delete(Object... keyValues) {
        return mirrorWrite(YtDao<T>::delete, keyValues, "delete");
    }

    protected <A, R> CompletableFuture<R> mirrorWrite(BiFunction<YtDao, A, CompletableFuture<R>> func, A args, String name) {
        CompletableFuture<R> result = new CompletableFuture<>();

        CompletableFuture<R> sec = func.apply(secondary, args).handle((r, t) -> {
            if (t != null) {
                logger.error("Secondary YT cluster error on {}: {}", name, t.getMessage());
                throw new RuntimeException("Secondary YT cluster error on " + name, t);
            } else {
                return r;
            }
        });

        func.apply(primary, args).whenComplete((r, t) -> {
            if (t != null) {
                logger.error("Primary YT cluster error on {}: {}", name, t.getMessage());
                writeError.error();
                sec.whenComplete((rr, tt) -> {
                    if (tt == null) {
                        result.complete(rr);
                    }
                    else {
                        result.completeExceptionally(new RuntimeException("Both clusters failed on " + name, t));
                    }
                });
            } else {
                writeError.ok();
                result.complete(r);
            }
        });

        if (shouldUseSecondary()) {
            return sec;
        }
        else {
            return result;
        }
    }


    protected <A, R> CompletableFuture<R> mirrorRead(BiFunction<YtDao, A, CompletableFuture<R>> func, A args, String name) {
        CompletableFuture<R> result = new CompletableFuture<>();

        CompletableFuture<R> alreadyRunning = shouldUseSecondary() ? getSecondaryReader(func, args, name) : null;
        func.apply(primary, args).whenComplete((r, t) -> {
            if (t != null) {
                logger.error("Primary YT cluster error on {}: {}", name, t.getMessage());
                readError.error();
                CompletableFuture<R> sec;
                if (alreadyRunning == null) {
                    sec = getSecondaryReader(func, args, name);
                } else {
                    sec = alreadyRunning;
                }
                sec.whenComplete((rr, tt) -> {
                    if (tt == null) {
                        result.complete(rr);
                    } else {
                        result.completeExceptionally(new RuntimeException("Both clusters failed on " + name, t));
                    }
                });
            } else {
                readError.ok();
                result.complete(r);
            }
        });
        if (alreadyRunning != null) {
            // If we are in secondary mode, then no need to wait for primary future to fail: return the first to be ok
            return anySuccessfull(result, alreadyRunning).exceptionally(t -> {
                logger.error("Both clusters failed", name);
                throw new RuntimeException("Both clusters failed on " + name, t);
            });
        } else {
            return result;
        }
    }

    private <A, R> CompletableFuture<R> getSecondaryReader(BiFunction<YtDao, A, CompletableFuture<R>> func, A args, String name) {
        logger.warn("Will read from secondary");
        return func.apply(secondary, args).handle((r, t) -> {
            if (t != null) {
                logger.error("Secondary YT cluster error on {}: {}", name, t.getMessage());
                throw new RuntimeException("Secondary YT cluster error on " + name, t);
            } else {
                return r;
            }
        });
    }

    protected <R> CompletableFuture<R> anySuccessfull(CompletableFuture<R>... futures) {
        CompletableFuture<R> future = new CompletableFuture<>();
        Consumer<R> complete = future::complete;
        CompletableFuture.allOf(Arrays.stream(futures).map(f -> f.thenAccept(complete)).toArray(CompletableFuture<?>[]::new))
                .exceptionally(ex -> {
                    future.completeExceptionally(ex);
                    return null;
                });
        return future;
    }

    protected boolean shouldUseSecondary() {
       return readError.check() || writeError.check();
    }


    @Override
    public String getTableName() {
        return primary.getTableName();
    }

    @Override
    public Class<?> getMappedClass() {
        return primary.getMappedClass();
    }

    @Override
    public String getClusterName() {
        return shouldUseSecondary() ? "secondary" : "primary";
    }
}
