package ru.yandex.crypta.graph2.dao.yt.ops;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.function.Function0;
import ru.yandex.bolts.function.Function0V;
import ru.yandex.crypta.graph2.utils.EnvUtils;
import ru.yandex.inside.yt.kosher.operations.Operation;
import ru.yandex.misc.concurrent.ExecutorUtils;
import ru.yandex.misc.lang.ObjectUtils;

/**
 * Helpers for running YT operations (and arbitrary tasks) with bounded parallelism and
 * waiting for them.
 *
 * <p>The parallelism cap comes from the {@code YT_POOL_MAX_OPS} environment variable
 * (default 50).
 */
public class Await {
    private static final Logger LOG = LoggerFactory.getLogger(Await.class);

    /**
     * Upper bound on the number of tasks run concurrently, read once from the
     * {@code YT_POOL_MAX_OPS} environment variable (default 50). Clamped to at least 1 so a
     * zero or negative setting can never translate into an empty worker pool.
     */
    private static final int YT_POOL_MAX_OPS =
            Math.max(1, EnvUtils.getIntEnvOrDefault("YT_POOL_MAX_OPS", 50));

    /**
     * Runs every operation supplier in parallel and awaits each resulting operation.
     *
     * <p>For each supplier: the supplier is invoked, the returned operation's id is logged
     * together with the supplier's index in {@code asyncOps}, and then {@link #op(Operation)}
     * is called on it. The suppliers are dispatched through {@code ExecutorUtils.executeAll}
     * using at most {@code min(YT_POOL_MAX_OPS, asyncOps.size())} workers.
     *
     * @param asyncOps suppliers that start YT operations; an empty list is a no-op
     */
    public static void all(ListF<Function0<Operation>> asyncOps) {
        // Nothing to schedule: bail out instead of requesting a zero-sized pool.
        if (asyncOps.isEmpty()) {
            LOG.info("No yt operations to run.");
            return;
        }

        ListF<Function0V> logAndCheckOps = asyncOps.zipWithIndex().map(opWithIndex -> {
            int opIndex = opWithIndex._2;
            Function0<Operation> opCallable = opWithIndex._1;
            return opCallable
                    .andThen(op -> {
                        // Log the id as soon as the operation is obtained, before awaiting it,
                        // so a long-running or failing op can be located by its index.
                        LOG.info("Op {}: {}", opIndex, op.getId());
                        return op;
                    })
                    .andThen(Await::op);
        });

        int numberOfParallelOps = parallelism(asyncOps.size());

        LOG.info("Scheduled parallel execution of {} yt operations in {} threads...",
                asyncOps.size(), numberOfParallelOps);

        ExecutorUtils.executeAll(numberOfParallelOps, logAndCheckOps, Function0V::apply);

        LOG.info("{} operations are done.", asyncOps.size());
    }

    /**
     * Varargs convenience for {@link #executeAll(ListF)}.
     *
     * @param tasks tasks to run
     */
    public static void executeAll(Runnable... tasks) {
        executeAll(Cf.list(tasks));
    }

    /**
     * Runs the given tasks through {@code ExecutorUtils.executeAll} using at most
     * {@code min(YT_POOL_MAX_OPS, tasks.size())} workers.
     *
     * @param tasks tasks to run; an empty list is a no-op
     */
    public static void executeAll(ListF<Runnable> tasks) {
        // Nothing to run: avoid requesting a zero-sized pool.
        if (tasks.isEmpty()) {
            return;
        }
        ExecutorUtils.executeAll(parallelism(tasks.size()), tasks, Runnable::run);
    }

    /**
     * Varargs convenience for {@link #all(ListF)}.
     *
     * @param asyncOps suppliers that start YT operations
     */
    @SafeVarargs
    public static void all(Function0<Operation>... asyncOps) {
        all(Cf.wrap(asyncOps));
    }

    /**
     * Awaits a single operation by delegating to
     * {@link Operation#awaitAndThrowIfNotSuccess()}.
     *
     * @param op the operation to await
     */
    public static void op(Operation op) {
        op.awaitAndThrowIfNotSuccess();
    }

    /**
     * Number of workers to use for {@code taskCount} tasks: never more than the tasks
     * themselves, never more than {@link #YT_POOL_MAX_OPS}.
     */
    private static int parallelism(int taskCount) {
        return Math.min(YT_POOL_MAX_OPS, taskCount);
    }
}
