package ru.yandex.intranet.d.datasource.model;

import java.util.function.BiFunction;
import java.util.function.Function;

import com.yandex.ydb.table.transaction.TransactionMode;
import reactor.core.publisher.Mono;
import reactor.util.function.Tuples;

import ru.yandex.intranet.d.util.result.Result;

/**
 * Internal extension of {@link YdbQueryChain} that dispatches {@code append(YdbQueryChain)} and
 * {@code combine(...)} to per-implementation overloads.
 *
 * <p>Only {@link YdbQueryChainImpl}, {@code YdbQueryChainImpl.YdbQueryChain2} and
 * {@code YdbQueryChainImpl.YdbQueryChainN} are supported as arguments; any other implementation
 * is rejected with an {@link IllegalArgumentException}.
 *
 * @param <I> input type of this chain
 * @param <O> output type of this chain
 */
interface YdbQueryChainAppendable<I, O> extends YdbQueryChain<I, O> {
    /**
     * Appends all stages of {@code chain} after the stages of this chain.
     *
     * @param chain chain consuming this chain's output
     * @param <TN>  output type of the resulting chain
     * @return a chain from {@code I} to {@code TN}
     * @throws IllegalArgumentException if {@code chain} is not one of the supported implementations
     */
    @Override
    default <TN> YdbQueryChain<I, TN> append(YdbQueryChain<O, TN> chain) {
        if (chain instanceof YdbQueryChainImpl) {
            return append1((YdbQueryChainImpl<O, TN>) chain);
        } else if (chain instanceof YdbQueryChainImpl.YdbQueryChain2) {
            return append2((YdbQueryChainImpl.YdbQueryChain2<O, ?, TN>) chain);
        } else if (chain instanceof YdbQueryChainImpl.YdbQueryChainN) {
            return appendN((YdbQueryChainImpl.YdbQueryChainN<O, ?, ?, TN>) chain);
        } else {
            // getName() is never null, unlike getCanonicalName() for anonymous/local/hidden classes.
            throw new IllegalArgumentException("Unsupported implementation: " + chain.getClass().getName());
        }
    }

    /**
     * Runs {@code chain} after this chain and merges both outputs.
     *
     * <p>{@code chain} receives {@code glue.apply(out)}, where {@code out} is this chain's output.
     * The final result is {@code combinator.apply(out, chainOut)}.
     *
     * @param chain      chain to run after this one
     * @param glue       maps this chain's output to {@code chain}'s input
     * @param combinator merges this chain's output with {@code chain}'s output
     * @param <I1>       input type of {@code chain}
     * @param <O1>       output type of {@code chain}
     * @param <TN>       output type of the resulting chain
     * @return a chain from {@code I} to {@code TN}
     * @throws IllegalArgumentException if {@code chain} is not one of the supported implementations
     */
    @Override
    default <I1, O1, TN> YdbQueryChain<I, TN> combine(
            YdbQueryChain<I1, O1> chain,
            Function<O, I1> glue,
            BiFunction<O, O1, TN> combinator
    ) {
        if (chain instanceof YdbQueryChainImpl) {
            return combine1((YdbQueryChainImpl<I1, O1>) chain, glue, combinator);
        } else if (chain instanceof YdbQueryChainImpl.YdbQueryChain2) {
            return combine2((YdbQueryChainImpl.YdbQueryChain2<I1, ?, O1>) chain, glue, combinator);
        } else if (chain instanceof YdbQueryChainImpl.YdbQueryChainN) {
            return combineN((YdbQueryChainImpl.YdbQueryChainN<I1, ?, ?, O1>) chain, glue, combinator);
        } else {
            // getName() is never null, unlike getCanonicalName() for anonymous/local/hidden classes.
            throw new IllegalArgumentException("Unsupported implementation: " + chain.getClass().getName());
        }
    }

    /** Appends the single stage of a one-stage chain. */
    <TN> YdbQueryChain<I, TN> append1(YdbQueryChainImpl<O, TN> chain);

    /** Appends both stages of a two-stage chain, in order. */
    <T2, TN> YdbQueryChain<I, TN> append2(YdbQueryChainImpl.YdbQueryChain2<O, T2, TN> chain);

    /** Appends all three stages of a three-stage chain, in order. */
    <T2, T3, TN> YdbQueryChain<I, TN> appendN(YdbQueryChainImpl.YdbQueryChainN<O, T2, T3, TN> chain);

    /** {@link #combine} specialization for a one-stage chain. */
    <TN, I1, O1> YdbQueryChain<I, TN> combine1(
            YdbQueryChainImpl<I1, O1> chain, Function<O, I1> glue, BiFunction<O, O1, TN> combinator
    );

    /** {@link #combine} specialization for a two-stage chain. */
    <T2N, TN, I1, O1> YdbQueryChain<I, TN> combine2(
            YdbQueryChainImpl.YdbQueryChain2<I1, T2N, O1> chain, Function<O, I1> glue, BiFunction<O, O1, TN> combinator
    );

    /** {@link #combine} specialization for a three-stage chain. */
    <T2N, T3N, TN, I1, O1> YdbQueryChain<I, TN> combineN(
            YdbQueryChainImpl.YdbQueryChainN<I1, T2N, T3N, O1> chain,
            Function<O, I1> glue,
            BiFunction<O, O1, TN> combinator
    );
}

/**
 * Query chain implementation.
 *
 * <p>This class holds a chain of exactly one stage ({@code preamble}).
 * <ul>
 *   <li>Appending a stage yields a {@link YdbQueryChain2} (preamble + trailer).</li>
 *   <li>Appending to that yields a {@link YdbQueryChainN} (preamble + body + trailer).</li>
 *   <li>Appending to a {@code YdbQueryChainN} folds its current body and trailer into a new body
 *   and makes the appended stage the new trailer, so no chain ever holds more than three stage
 *   fields.</li>
 * </ul>
 *
 * @param <I>  input type of the chain
 * @param <T1> output type of the chain
 * @author Vladimir Zaytsev <vzay@yandex-team.ru>
 * @since 17.09.2020
 */
public class YdbQueryChainImpl<I, T1> implements YdbQueryChainAppendable<I, T1> {
    /** The only stage of this chain. */
    private final Stage<I, T1> preamble;

    /**
     * Creates a single-stage chain.
     *
     * @param preamble the stage to run
     */
    public YdbQueryChainImpl(Stage<I, T1> preamble) {
        this.preamble = preamble;
    }

    /**
     * Appends {@code stage} after the preamble.
     *
     * @return a two-stage chain
     */
    @Override
    public <T2> YdbQueryChain<I, T2> append(Stage<T1, T2> stage) {
        return new YdbQueryChain2<>(preamble, stage);
    }

    @Override
    public <TN> YdbQueryChain<I, TN> append1(YdbQueryChainImpl<T1, TN> chain) {
        return append(chain.preamble);
    }

    @Override
    public <T2, O> YdbQueryChain<I, O> append2(YdbQueryChain2<T1, T2, O> chain) {
        return this
                .append(chain.preamble)
                .append(chain.trailer);
    }

    @Override
    public <T2, T3, O> YdbQueryChain<I, O> appendN(YdbQueryChainN<T1, T2, T3, O> chain) {
        return this
                .append(chain.preamble)
                .append(chain.body)
                .append(chain.trailer);
    }

    /**
     * Appends one stage that does three things.
     * <ul>
     *   <li>Runs {@code chain}'s preamble on {@code glue.apply(t1)}.</li>
     *   <li>Combines {@code t1} with that result via {@code combinator}.</li>
     *   <li>Keeps the transaction id reported by {@code chain}'s preamble.</li>
     * </ul>
     */
    @Override
    public <TN, I1, O1> YdbQueryChain<I, TN> combine1(
            YdbQueryChainImpl<I1, O1> chain,
            Function<T1, I1> glue,
            BiFunction<T1, O1, TN> combinator
    ) {
        return append((ts, t1) -> chain.preamble.apply(ts, glue.apply(t1))
                .map(r -> new WithTxId<>(combinator.apply(t1, r.get()), r.getTransactionId()))
        );
    }

    /**
     * Appends two stages.
     * <ol>
     *   <li>Run {@code chain}'s preamble on {@code glue.apply(t1)} and pair {@code t1} with its
     *   output.</li>
     *   <li>Run {@code chain}'s trailer on the paired output and combine it with {@code t1}.</li>
     * </ol>
     */
    @Override
    public <T2, TN, I1, O1> YdbQueryChain<I, TN> combine2(
            YdbQueryChain2<I1, T2, O1> chain,
            Function<T1, I1> glue,
            BiFunction<T1, O1, TN> combinator
    ) {
        return this
                .append((ts, t1) -> chain.preamble.apply(ts, glue.apply(t1))
                        .map(r -> r.map(t2 -> Tuples.of(t1, t2))))
                .append((ts, t) -> chain.trailer.apply(ts, t.getT2())
                        .map(r -> r.map(o1 -> combinator.apply(t.getT1(), o1))));
    }

    /**
     * Appends three stages. {@code chain}'s preamble, body and trailer each run on the previous
     * output. Each stage's output is paired with {@code t1}, and the final output is merged with
     * {@code t1} via {@code combinator}.
     */
    @Override
    public <T2, T3, TN, I1, O1> YdbQueryChain<I, TN> combineN(
            YdbQueryChainN<I1, T2, T3, O1> chain,
            Function<T1, I1> glue,
            BiFunction<T1, O1, TN> combinator
    ) {
        return this
                .append((ts, t1) -> chain.preamble.apply(ts, glue.apply(t1))
                        .map(r -> r.map(t2 -> Tuples.of(t1, t2))))
                .append((ts, t) -> chain.body.apply(ts, t.getT2())
                        .map(r -> r.map(t3 -> Tuples.of(t.getT1(), t3))))
                .append((ts, t) -> chain.trailer.apply(ts, t.getT2())
                        .map(r -> r.map(o1 -> combinator.apply(t.getT1(), o1))));
    }

    /**
     * Runs the single stage via {@code session.usingTxMono} and unwraps the {@code WithTxId} result.
     */
    @Override
    public Mono<T1> execute(YdbSession session, TransactionMode txMode, I input) {
        return session.usingTxMono(txMode, ts -> preamble.apply(ts, input)
                .map(WithTxId::get));
    }

    /**
     * Delegates to {@code session.usingCompResultTxRetryable} with three stages.
     * <ol>
     *   <li>The validator applied to {@code input}.</li>
     *   <li>A pass-through stage that wraps its argument with {@code ts.getId()}.</li>
     *   <li>The preamble.</li>
     * </ol>
     */
    @Override
    public Mono<Result<T1>> validateAndExecute(
            YdbSession session, TransactionMode txMode, I input, Validator<I> validator
    ) {
        return session.usingCompResultTxRetryable(txMode,
                ts -> validator.apply(ts, input),
                (ts, t1) -> Mono.just(new WithTxId<>(t1, ts.getId())),
                preamble
        );
    }

    /**
     * Two-stage chain: {@code preamble} followed by {@code trailer}.
     *
     * @param <I>  input type of the chain
     * @param <T1> output type of the preamble
     * @param <T2> output type of the chain
     */
    static class YdbQueryChain2<I, T1, T2> implements YdbQueryChainAppendable<I, T2> {
        private final Stage<I, T1> preamble;
        private final Stage<T1, T2> trailer;

        YdbQueryChain2(Stage<I, T1> preamble, Stage<T1, T2> trailer) {
            this.preamble = preamble;
            this.trailer = trailer;
        }

        /**
         * Appends {@code stage} after the trailer.
         *
         * @return a three-stage chain
         */
        @Override
        public <T3> YdbQueryChain<I, T3> append(Stage<T2, T3> stage) {
            return new YdbQueryChainN<>(preamble, trailer, stage);
        }

        /**
         * Runs both stages via {@code session.usingCompTxMonoRetryable}.
         * <ul>
         *   <li>The preamble result is converted with {@code WithTxId::asTuple}.</li>
         *   <li>The middle stage passes its argument through unchanged.</li>
         *   <li>The trailer result is unwrapped with {@code WithTxId::get}.</li>
         * </ul>
         */
        @Override
        public Mono<T2> execute(YdbSession session, TransactionMode txMode, I input) {
            return session.usingCompTxMonoRetryable(txMode,
                    ts -> preamble.apply(ts, input).map(WithTxId::asTuple),
                    (ts, t1) -> Mono.just(t1),
                    (ts, t1) -> trailer.apply(ts, t1).map(WithTxId::get)
            );
        }

        /**
         * Runs the validator, then the preamble, then the trailer via
         * {@code session.usingCompResultTxRetryable}.
         */
        @Override
        public Mono<Result<T2>> validateAndExecute(
                YdbSession session, TransactionMode txMode, I input, Validator<I> validator
        ) {
            // NOTE(review): the second stage ignores its argument t1 and re-reads `input`, whereas
            // YdbQueryChainImpl#validateAndExecute forwards t1 to the preamble. Confirm intentional.
            return session.usingCompResultTxRetryable(txMode,
                    ts -> validator.apply(ts, input),
                    (ts, t1) -> preamble.apply(ts, input),
                    trailer
            );
        }

        @Override
        public <TN> YdbQueryChain<I, TN> append1(YdbQueryChainImpl<T2, TN> chain) {
            return append(chain.preamble);
        }

        @Override
        public <T21, TN> YdbQueryChain<I, TN> append2(YdbQueryChain2<T2, T21, TN> chain) {
            return this
                    .append(chain.preamble)
                    .append(chain.trailer);
        }

        /**
         * Appends the preamble, body and trailer of {@code chain}, in order.
         */
        @Override
        public <T21, T3, TN> YdbQueryChain<I, TN> appendN(YdbQueryChainN<T2, T21, T3, TN> chain) {
            // Same as the sibling implementations; returning null here would break append(YdbQueryChain).
            return this
                    .append(chain.preamble)
                    .append(chain.body)
                    .append(chain.trailer);
        }

        /** See {@link YdbQueryChainImpl#combine1}; the appended stage consumes this chain's output. */
        @Override
        public <TN, I1, O1> YdbQueryChain<I, TN> combine1(
                YdbQueryChainImpl<I1, O1> chain,
                Function<T2, I1> glue,
                BiFunction<T2, O1, TN> combinator
        ) {
            return append((ts, t1) -> chain.preamble.apply(ts, glue.apply(t1))
                    .map(r -> new WithTxId<>(combinator.apply(t1, r.get()), r.getTransactionId()))
            );
        }

        /** See {@link YdbQueryChainImpl#combine2}; the appended stages consume this chain's output. */
        @Override
        public <T2N, TN, I1, O1> YdbQueryChain<I, TN> combine2(
                YdbQueryChain2<I1, T2N, O1> chain,
                Function<T2, I1> glue,
                BiFunction<T2, O1, TN> combinator
        ) {
            return this
                    .append((ts, t1) -> chain.preamble.apply(ts, glue.apply(t1))
                            .map(r -> r.map(t2 -> Tuples.of(t1, t2))))
                    .append((ts, t) -> chain.trailer.apply(ts, t.getT2())
                            .map(r -> r.map(o1 -> combinator.apply(t.getT1(), o1))));
        }

        /** See {@link YdbQueryChainImpl#combineN}; the appended stages consume this chain's output. */
        @Override
        public <T2N, T3N, TN, I1, O1> YdbQueryChain<I, TN> combineN(
                YdbQueryChainN<I1, T2N, T3N, O1> chain,
                Function<T2, I1> glue,
                BiFunction<T2, O1, TN> combinator
        ) {
            return this
                    .append((ts, t1) -> chain.preamble.apply(ts, glue.apply(t1))
                            .map(r -> r.map(t2 -> Tuples.of(t1, t2))))
                    .append((ts, t) -> chain.body.apply(ts, t.getT2())
                            .map(r -> r.map(t3 -> Tuples.of(t.getT1(), t3))))
                    .append((ts, t) -> chain.trailer.apply(ts, t.getT2())
                            .map(r -> r.map(o1 -> combinator.apply(t.getT1(), o1))));
        }
    }

    /**
     * Three-stage chain: {@code preamble}, {@code body}, {@code trailer}.
     *
     * <p>Longer chains are represented by folding extra stages into {@code body} (see
     * {@link #append(Stage)}).
     *
     * @param <I>  input type of the chain
     * @param <T1> output type of the preamble
     * @param <T2> output type of the body
     * @param <T3> output type of the chain
     */
    static class YdbQueryChainN<I, T1, T2, T3> implements YdbQueryChainAppendable<I, T3> {
        private final Stage<I, T1> preamble;
        private final Stage<T1, T2> body;
        private final Stage<T2, T3> trailer;

        YdbQueryChainN(
                Stage<I, T1> preamble,
                Stage<T1, T2> body,
                Stage<T2, T3> trailer
        ) {
            this.preamble = preamble;
            this.body = body;
            this.trailer = trailer;
        }

        /**
         * Folds the current body and trailer into a new body stage and makes {@code stage} the
         * new trailer. The new body runs the old body and then the old trailer with the same
         * {@code ts}.
         */
        @Override
        public <TN> YdbQueryChain<I, TN> append(Stage<T3, TN> stage) {
            return new YdbQueryChainN<>(
                    preamble,
                    (ts, t1) -> body.apply(ts, t1)
                            .flatMap(t2 -> trailer.apply(ts, t2.get())),
                    stage
            );
        }

        /**
         * Runs preamble, body and trailer via {@code session.usingCompTxMonoRetryable}. The
         * preamble result is converted with {@code WithTxId::asTuple}; body and trailer results
         * are unwrapped with {@code WithTxId::get}.
         */
        @Override
        public Mono<T3> execute(YdbSession session, TransactionMode txMode, I input) {
            return session.usingCompTxMonoRetryable(txMode,
                    ts -> preamble.apply(ts, input).map(WithTxId::asTuple),
                    (ts, t1) -> body.apply(ts, t1).map(WithTxId::get),
                    (ts, t2) -> trailer.apply(ts, t2).map(WithTxId::get)
            );
        }

        /**
         * Runs the validator, then preamble and body as a single stage, then the trailer, via
         * {@code session.usingCompResultTxRetryable}.
         */
        @Override
        public Mono<Result<T3>> validateAndExecute(
                YdbSession session, TransactionMode txMode, I input, Validator<I> validator
        ) {
            // NOTE(review): the second stage ignores its argument t1 and re-reads `input`, whereas
            // YdbQueryChainImpl#validateAndExecute forwards t1 to the preamble. Confirm intentional.
            return session.usingCompResultTxRetryable(txMode,
                    ts -> validator.apply(ts, input),
                    (ts, t1) -> preamble.apply(ts, input)
                            .flatMap(t -> body.apply(ts, t.get())),
                    trailer
            );
        }

        @Override
        public <TN> YdbQueryChain<I, TN> append1(YdbQueryChainImpl<T3, TN> chain) {
            return append(chain.preamble);
        }

        @Override
        public <NT2, TN> YdbQueryChain<I, TN> append2(YdbQueryChain2<T3, NT2, TN> chain) {
            return this
                    .append(chain.preamble)
                    .append(chain.trailer);
        }

        @Override
        public <NT2, NT3, TN> YdbQueryChain<I, TN> appendN(YdbQueryChainN<T3, NT2, NT3, TN> chain) {
            return this
                    .append(chain.preamble)
                    .append(chain.body)
                    .append(chain.trailer);
        }

        /** See {@link YdbQueryChainImpl#combine1}; the appended stage consumes this chain's output. */
        @Override
        public <TN, I1, O1> YdbQueryChain<I, TN> combine1(
                YdbQueryChainImpl<I1, O1> chain,
                Function<T3, I1> glue,
                BiFunction<T3, O1, TN> combinator
        ) {
            return append((ts, t1) -> chain.preamble.apply(ts, glue.apply(t1))
                    .map(r -> new WithTxId<>(combinator.apply(t1, r.get()), r.getTransactionId()))
            );
        }

        /** See {@link YdbQueryChainImpl#combine2}; the appended stages consume this chain's output. */
        @Override
        public <T2N, TN, I1, O1> YdbQueryChain<I, TN> combine2(
                YdbQueryChain2<I1, T2N, O1> chain,
                Function<T3, I1> glue,
                BiFunction<T3, O1, TN> combinator
        ) {
            return this
                    .append((ts, t1) -> chain.preamble.apply(ts, glue.apply(t1))
                            .map(r -> r.map(t2 -> Tuples.of(t1, t2))))
                    .append((ts, t) -> chain.trailer.apply(ts, t.getT2())
                            .map(r -> r.map(o1 -> combinator.apply(t.getT1(), o1))));
        }

        /** See {@link YdbQueryChainImpl#combineN}; the appended stages consume this chain's output. */
        @Override
        public <T2N, T3N, TN, I1, O1> YdbQueryChain<I, TN> combineN(
                YdbQueryChainN<I1, T2N, T3N, O1> chain,
                Function<T3, I1> glue,
                BiFunction<T3, O1, TN> combinator
        ) {
            return this
                    .append((ts, t1) -> chain.preamble.apply(ts, glue.apply(t1))
                            .map(r -> r.map(t2 -> Tuples.of(t1, t2))))
                    .append((ts, t) -> chain.body.apply(ts, t.getT2())
                            .map(r -> r.map(t3 -> Tuples.of(t.getT1(), t3))))
                    .append((ts, t) -> chain.trailer.apply(ts, t.getT2())
                            .map(r -> r.map(o1 -> combinator.apply(t.getT1(), o1))));
        }
    }
}
