package ru.yandex.crypta.graph2.dao.yt.local;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.crypta.graph2.dao.yt.bendable.YsonMultiEntityReducerWithKey;
import ru.yandex.crypta.graph2.dao.yt.local.fastyt.testdata.LoggingIterator;
import ru.yandex.inside.yt.kosher.Yt;
import ru.yandex.inside.yt.kosher.cypress.YPath;
import ru.yandex.inside.yt.kosher.impl.operations.utils.ReducerWithKey;
import ru.yandex.inside.yt.kosher.impl.transactions.utils.YtTransactionsUtils;
import ru.yandex.inside.yt.kosher.impl.ytree.builder.YTree;
import ru.yandex.inside.yt.kosher.operations.OperationContext;
import ru.yandex.inside.yt.kosher.tables.YTableEntryType;
import ru.yandex.inside.yt.kosher.tables.types.NativeProtobufEntryType;
import ru.yandex.inside.yt.kosher.ytree.YTreeMapNode;

import static ru.yandex.crypta.graph2.dao.yt.bendable.YsonMultiEntitySupport.TABLE_INDEX_COLUMN;


/**
 * Runs YT reducers inside the current JVM against tables read through a {@link Yt} client.
 * <p>
 * Each {@code reduce} variant calls the reducer's {@code reduce} exactly once, with an iterator over
 * all input rows, and collects whatever the reducer yields into a {@link LocalYield}. The
 * multi-table variants materialize all rows in memory (inside one transaction) before reducing.
 */
public class LocalModeYt {

    /**
     * Duration passed (both directly and as the optional argument) to
     * {@code YtTransactionsUtils.withTransaction} by the multi-table reducers.
     */
    public static final Duration PING_INTERVAL = Duration.ofSeconds(10);

    private final Yt yt;
    // Forwarded to LoggingIterator by the single-table reduce; 0 unless set via the constructor.
    // NOTE(review): presumably "log every N-th record" with 0 meaning disabled — confirm in LoggingIterator.
    private final int debugRecFreq;

    public LocalModeYt(Yt yt) {
        this(yt, 0);
    }

    public LocalModeYt(Yt yt, int debugRecFreq) {
        this.yt = yt;
        this.debugRecFreq = debugRecFreq;
    }

    /**
     * Reduces a single table locally by handing the read iterator directly to {@code reducer}.
     * <p>
     * The table is read without a transaction id, and the iterator is wrapped in a
     * {@link LoggingIterator} configured with this instance's {@code debugRecFreq}.
     *
     * @param reducer reducer to run; its {@code inputType()} decides how rows are decoded
     * @param table   table to reduce
     * @return the output yielded by the reducer
     * @throws IllegalStateException if {@code table} does not exist
     */
    public <TInput, TOutput, TKey> LocalYield<TOutput> reduce(
            ReducerWithKey<TInput, TOutput, TKey> reducer, YPath table) {
        checkTableExists(table);

        LocalYield<TOutput> localYield = new LocalYield<>();
        // The read iterator goes straight to the reducer instead of being collected into a list first.
        yt.tables().<TInput, Void>read(
                table,
                reducer.inputType(),
                iteratorF -> {
                    reducer.reduce(
                            new LoggingIterator<>(iteratorF, debugRecFreq),
                            localYield,
                            new StatisticsSlf4jLoggingImpl(),
                            new OperationContext()
                    );
                    return null;
                });
        return localYield;
    }

    /**
     * Reduces several protobuf tables locally, emulating a multi-table reduce.
     * <p>
     * {@code builder} describes a wrapper message whose first oneof has (at least) one message field
     * per input table. Rows of {@code tables.get(i)} are decoded with the builder of the i-th field
     * of that oneof and then wrapped into that field of a fresh wrapper message, so the oneof case
     * identifies the source table. Wrapped rows are passed to the reducer table by table, in the
     * order of {@code tables}.
     *
     * @param reducer reducer consuming wrapper messages
     * @param tables  input tables; the i-th table corresponds to the i-th oneof field
     * @param builder builder of the wrapper message type; used only as a type template and not modified
     * @return the output yielded by the reducer
     * @throws IllegalArgumentException if the wrapper type has no oneof, or its first oneof has fewer
     *                                  fields than there are tables
     * @throws IllegalStateException    if one of the tables does not exist
     */
    @SuppressWarnings("unchecked") // the wrapper built from `builder` is assumed to be a TInput
    public <TInput extends Message, TOutput, TKey> LocalYield<TOutput> reduce(
            ReducerWithKey<TInput, TOutput, TKey> reducer, ListF<YPath> tables, Message.Builder builder) {

        Descriptors.Descriptor descriptor = builder.getDescriptorForType();
        List<Descriptors.OneofDescriptor> oneofs = descriptor.getOneofs();
        if (oneofs.isEmpty()) {
            throw new IllegalArgumentException(String.format(
                    "Message type %s has no oneof to dispatch input tables on", descriptor.getFullName()));
        }
        Descriptors.OneofDescriptor oneofDescriptor = oneofs.get(0);
        List<Descriptors.FieldDescriptor> oneofFields = oneofDescriptor.getFields();
        if (tables.size() > oneofFields.size()) {
            throw new IllegalArgumentException(String.format(
                    "Got %d input tables, but oneof %s of %s has only %d fields",
                    tables.size(), oneofDescriptor.getName(), descriptor.getFullName(), oneofFields.size()));
        }

        // Work on a private builder of the same type so the caller's builder is left untouched.
        Message.Builder wrapperBuilder = builder.newBuilderForType();

        // Pre-compute one row builder per oneof field, for performance.
        List<Message.Builder> fieldsBuilders = new ArrayList<>(oneofFields.size());
        for (Descriptors.FieldDescriptor field : oneofFields) {
            fieldsBuilders.add(wrapperBuilder.getFieldBuilder(field));
        }

        return YtTransactionsUtils.withTransaction(yt, PING_INTERVAL, Optional.of(PING_INTERVAL), tr -> {
            // Created inside the transaction body so each invocation starts from a clean state.
            List<TInput> crossTableRecords = new ArrayList<>();
            LocalYield<TOutput> localYield = new LocalYield<>();

            // Concatenate the tables in order, tagging each row with its table's oneof field.
            for (int tableIndex = 0; tableIndex < tables.size(); tableIndex++) {
                YPath table = tables.get(tableIndex);
                checkTableExists(table);

                Message.Builder rowBuilder = fieldsBuilders.get(tableIndex);
                rowBuilder.clear();
                YTableEntryType<Message> inputType = new NativeProtobufEntryType<>(rowBuilder);

                // this interface assumes that iterator is consumed inside a transaction, so materialize it
                ListF<Message> rows = yt.tables().read(
                        Optional.of(tr.getId()),
                        false,
                        table,
                        inputType,
                        iter -> Cf.x(iter).toList()
                );

                Descriptors.FieldDescriptor oneofField = oneofFields.get(tableIndex);
                for (Message row : rows) {
                    wrapperBuilder.clear();
                    wrapperBuilder.setField(oneofField, row);
                    crossTableRecords.add((TInput) wrapperBuilder.build());
                }
            }

            reducer.reduce(crossTableRecords.iterator(), localYield, new StatisticsSlf4jLoggingImpl(),
                    new OperationContext());
            return localYield;
        });
    }

    /**
     * Reduces several YSON tables locally, emulating a multi-table reduce.
     * <p>
     * Every row read from {@code tables.get(i)} gets the {@code TABLE_INDEX_COLUMN} attribute set to
     * integer {@code i}. Rows are passed to the reducer table by table, in the order of {@code tables};
     * all of them are read into memory inside a single transaction before reducing.
     *
     * @param reducer reducer to run; its {@code inputType()} decides how rows are decoded
     * @param tables  input tables, in table-index order
     * @return the output yielded by the reducer
     * @throws IllegalStateException if one of the tables does not exist
     */
    public <TKey> LocalYield<YTreeMapNode> reduce(YsonMultiEntityReducerWithKey<TKey> reducer, ListF<YPath> tables) {
        return YtTransactionsUtils.withTransaction(yt, PING_INTERVAL, Optional.of(PING_INTERVAL), tr -> {
            // Created inside the transaction body so each invocation starts from a clean state.
            List<YTreeMapNode> crossTableRecords = new ArrayList<>();
            LocalYield<YTreeMapNode> localYield = new LocalYield<>();

            // Concatenate the tables in order, tagging each row with the index of its table.
            for (int tableIndex = 0; tableIndex < tables.size(); tableIndex++) {
                YPath table = tables.get(tableIndex);
                checkTableExists(table);

                int currentTableIndex = tableIndex;
                // this interface assumes that iterator is consumed inside a transaction, so materialize it
                ListF<YTreeMapNode> rows = yt.tables().read(
                        Optional.of(tr.getId()),
                        false,
                        table,
                        reducer.inputType(),
                        iter -> Cf.x(iter).map(r -> {
                            r.putAttribute(TABLE_INDEX_COLUMN, YTree.integerNode(currentTableIndex));
                            return r;
                        }).toList()
                );
                crossTableRecords.addAll(rows);
            }

            reducer.reduce(crossTableRecords.iterator(), localYield, new StatisticsSlf4jLoggingImpl(),
                    new OperationContext());
            return localYield;
        });
    }

    /**
     * Fails fast on a missing input table, because yt can't properly handle reading a
     * non-existing table (ICEBERG-727).
     *
     * @throws IllegalStateException if {@code table} does not exist
     */
    private void checkTableExists(YPath table) {
        if (!yt.cypress().exists(table)) {
            throw new IllegalStateException(String.format("Table %s doesn't exist", table));
        }
    }
}
