package ru.yandex.crypta.graph2.dao.yt.local.fastyt.recs;

import java.io.InputStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import ru.yandex.crypta.graph2.dao.yt.local.fastyt.fs.LocalYtDataLayer;
import ru.yandex.crypta.graph2.dao.yt.local.fastyt.testdata.LoggingIterator;
import ru.yandex.crypta.graph2.utils.IteratorUtils;
import ru.yandex.inside.yt.kosher.cypress.YPath;
import ru.yandex.inside.yt.kosher.impl.ytree.builder.YTree;
import ru.yandex.inside.yt.kosher.tables.YTableEntryTypes;
import ru.yandex.inside.yt.kosher.ytree.YTreeMapNode;

import static ru.yandex.crypta.graph2.dao.yt.bendable.YsonMultiEntitySupport.TABLE_INDEX_COLUMN;

/**
 * Layer between bytes and yson recs representation.
 * Emulates mapper and reducer input assuming recs are stored as yson in underlying dataLayer.
 * <p>
 * Streams returned by {@link #readMap} and {@link #readReduce} hold open input streams obtained
 * from the data layer; callers should close the returned stream to release them.
 */
public class YsonRecsLayer implements RecsLayer {

    private final LocalYtDataLayer dataLayer;
    private final boolean debugLogging = true;

    public YsonRecsLayer(LocalYtDataLayer dataLayer) {
        this.dataLayer = dataLayer;
    }

    /**
     * Closes the resource, rethrowing any failure as an unchecked {@link IllegalStateException}
     * so that it can be used from stream close handlers.
     */
    private static void closeOrThrow(AutoCloseable closeable) {
        try {
            closeable.close();
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Closes every stream in the list even if some of them fail to close.
     * The first failure is rethrown after all streams were attempted; later failures
     * are attached to it as suppressed exceptions.
     */
    private static void closeAll(List<Stream<YTreeMapNode>> streams) {
        RuntimeException firstFailure = null;
        for (Stream<YTreeMapNode> stream : streams) {
            try {
                // Stream.close() is idempotent, so streams already closed elsewhere are harmless here
                stream.close();
            } catch (RuntimeException e) {
                if (firstFailure == null) {
                    firstFailure = e;
                } else {
                    firstFailure.addSuppressed(e);
                }
            }
        }
        if (firstFailure != null) {
            throw firstFailure;
        }
    }

    /**
     * Best-effort cleanup used on error paths: closes all streams and records any close failure
     * as suppressed on the original failure instead of masking it.
     */
    private static void closeAfterFailure(Throwable failure, List<Stream<YTreeMapNode>> streams) {
        try {
            closeAll(streams);
        } catch (RuntimeException closeFailure) {
            failure.addSuppressed(closeFailure);
        }
    }

    /**
     * Opens one records stream per path, in the order of {@code paths}.
     * If opening any path fails, the streams opened so far are closed before the failure propagates.
     */
    private List<Stream<YTreeMapNode>> getRecsStreamsPerPath(List<YPath> paths) {
        List<Stream<YTreeMapNode>> opened = new ArrayList<>(paths.size());
        try {
            for (YPath path : paths) {
                opened.add(openRecsStream(path));
            }
        } catch (RuntimeException | Error e) {
            closeAfterFailure(e, opened);
            throw e;
        }
        return opened;
    }

    /**
     * Opens the input stream for {@code path} and wraps it into a stream of yson records.
     * Closing the returned stream closes the underlying input stream.
     */
    private Stream<YTreeMapNode> openRecsStream(YPath path) {
        InputStream inputStream = dataLayer.createInputStream(path);
        try {
            Stream<YTreeMapNode> recsStream = IteratorUtils
                    .stream(YTableEntryTypes.YSON.iterator(inputStream))
                    .onClose(() -> closeOrThrow(inputStream));

            if (debugLogging) {
                return LoggingIterator.logStream(recsStream, path.name(), 100000, false);
            } else {
                return recsStream;
            }
        } catch (RuntimeException | Error e) {
            // the close handler is not reachable by the caller yet, so release the input here
            try {
                closeOrThrow(inputStream);
            } catch (RuntimeException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Reads records of all {@code paths} one table after another, in the order of {@code paths}.
     * Every record gets the {@code TABLE_INDEX_COLUMN} attribute set to the position of its path in the list.
     * Closing the returned stream closes the streams of all tables.
     */
    @Override
    public Stream<YTreeMapNode> readMap(List<YPath> paths) {
        List<Stream<YTreeMapNode>> pathsIters = getRecsStreamsPerPath(paths);
        return mapOrderAndTableIndex(pathsIters);
    }

    /**
     * Concatenates table streams in list order, tagging each record with its table index.
     */
    private Stream<YTreeMapNode> mapOrderAndTableIndex(List<Stream<YTreeMapNode>> tableIterators) {
        Stream<Stream<YTreeMapNode>> withIndex = IntStream
                .range(0, tableIterators.size())
                .mapToObj(tableIndex -> tableIterators.get(tableIndex).map(rec -> {
                    rec.putAttribute(TABLE_INDEX_COLUMN, YTree.integerNode(tableIndex));
                    return rec;
                }));

        // put to single iterator;
        // flatMap closes a table stream only after it was fully consumed, so the close handler
        // is needed to release tables that were not reached when the result is closed early
        return withIndex
                .flatMap(tableStream -> tableStream)
                .onClose(() -> closeAll(tableIterators));
    }

    /**
     * Reads records of all {@code paths} merged through {@code MergingIteratorByKeyAndTableIndex},
     * using the string values of the {@code reduceBy} columns as the key.
     * Every record gets the {@code TABLE_INDEX_COLUMN} attribute set to the position of its path in the list.
     * Closing the returned stream closes the streams of all tables.
     */
    @Override
    public Stream<YTreeMapNode> readReduce(List<YPath> paths, List<String> reduceBy) {
        List<Stream<YTreeMapNode>> pathsIters = getRecsStreamsPerPath(paths);

        return reduceOrderAndTableIndex(pathsIters, reduceBy);
    }

    /**
     * Tags each record with its table index and key, then merges all tables into a single stream.
     * NOTE(review): the merge presumably expects every table to be already sorted by {@code keys} - confirm.
     */
    private Stream<YTreeMapNode> reduceOrderAndTableIndex(List<Stream<YTreeMapNode>> tableStreams,
                                                          List<String> keys) {
        try {
            List<Iterator<RecWithKeyAndTableIndex<YTreeMapNode>>> tableIterators =
                    new ArrayList<>(tableStreams.size());

            for (int tableIndex = 0; tableIndex < tableStreams.size(); tableIndex++) {
                int finalTableIndex = tableIndex;
                Stream<YTreeMapNode> tableStream = tableStreams.get(tableIndex);
                Iterator<RecWithKeyAndTableIndex<YTreeMapNode>> tableIterator = tableStream.map(rec -> {
                    rec.putAttribute(TABLE_INDEX_COLUMN, YTree.integerNode(finalTableIndex));
                    return new RecWithKeyAndTableIndex<>(
                            rec,
                            extractKey(rec, keys),
                            finalTableIndex
                    );
                }).iterator();

                tableIterators.add(tableIterator);
            }

            MergingIteratorByKeyAndTableIndex<YTreeMapNode> reduceIter = new MergingIteratorByKeyAndTableIndex<>(
                    tableIterators
            );

            return IteratorUtils.stream(reduceIter)
                    .map(RecWithKeyAndTableIndex::getRec)
                    .onClose(() -> closeAll(tableStreams));
        } catch (RuntimeException | Error e) {
            // nobody else can close the table streams if the merged stream was never created
            closeAfterFailure(e, tableStreams);
            throw e;
        }
    }

    /**
     * Builds the merge key of a record by concatenating the string values of the given columns.
     * NOTE(review): values are joined without a separator, so different tuples can produce the same key
     * (("ab", "c") and ("a", "bc")) and the concatenated order can differ from tuple order.
     * Left as is because the key format may have to agree with how tables are sorted elsewhere - confirm.
     */
    private String extractKey(YTreeMapNode rec, List<String> sortedBy) {
        return sortedBy.stream().map(rec::getString).collect(Collectors.joining(""));
    }

    /**
     * Scratch entry point printing the first element of a small list twice.
     * A {@link Stream} can be traversed only once, so every lookup uses a fresh stream.
     */
    public static void main(String[] args) {
        List<Integer> values = List.of(3, 2);
        System.out.println(values.stream().findFirst());
        System.out.println(values.stream().findFirst());
    }
}

