package ru.yandex.crypta.graph2.dao.yt.bendable;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Stream;

import ru.yandex.bolts.collection.IteratorF;
import ru.yandex.bolts.collection.Tuple2;
import ru.yandex.crypta.graph2.utils.IteratorUtils;
import ru.yandex.inside.yt.kosher.operations.MapperOrReducer;
import ru.yandex.inside.yt.kosher.tables.YTableEntryType;
import ru.yandex.inside.yt.kosher.tables.types.YsonTableEntryType;
import ru.yandex.inside.yt.kosher.ytree.YTreeMapNode;

import static java.util.stream.Collectors.toList;

/**
 * Support base for YT mappers/reducers that read {@link YTreeMapNode} records coming from several
 * input tables and tell them apart by the {@value #TABLE_INDEX_COLUMN} attribute of each record.
 *
 * <p>Besides the table-index helpers, the class offers utilities for splitting a record stream in two
 * by input table and for guarding against groups that contain too many records.
 */
public class YsonMultiEntitySupport extends YsonCachedSerializerSupport implements MapperOrReducer<YTreeMapNode, YTreeMapNode> {

    /**
     * Entry type used for both input and output of this operation.
     * NOTE(review): the two boolean flags presumably enable the table index attribute and disable
     * another per-row attribute; confirm against the {@code YsonTableEntryType} constructor.
     */
    public static final YTableEntryType<YTreeMapNode> YSON_WITH_TABLE_INDEX = new YsonTableEntryType(true, false);

    /** Name of the record attribute that carries the index of the table the record belongs to. */
    public static final String TABLE_INDEX_COLUMN = "table_index";

    /**
     * Reads the {@value #TABLE_INDEX_COLUMN} attribute of the record.
     *
     * @param rec record to inspect
     * @return value of the table index attribute as an int; the lookup is done with
     *         {@code getAttributeOrThrow}, so a record without the attribute is not accepted
     */
    public int getTableIndex(YTreeMapNode rec) {
        return rec.getAttributeOrThrow(TABLE_INDEX_COLUMN).intValue();
    }

    /**
     * Removes the {@value #TABLE_INDEX_COLUMN} attribute from the record.
     *
     * @param rec record to modify in place
     */
    public void resetTableIndex(YTreeMapNode rec) {
        rec.removeAttribute(TABLE_INDEX_COLUMN);
    }

    /**
     * Drains the iterator and parses each record into one of two entity types depending on its table
     * index: index {@code 0} goes to the first collection, any other index goes to the second one.
     * Parsing is delegated to {@code parse}, which is not defined in this class.
     *
     * @param recs   records to consume; fully exhausted by this call
     * @param clazz1 entity class for records with table index {@code 0}
     * @param clazz2 entity class for records with any other table index
     * @return tuple of (entities from table 0, entities from the other tables), each in input order
     */
    protected <E1, E2> Tuple2<Collection<E1>, Collection<E2>> splitLeftRight(Iterator<YTreeMapNode> recs,
                                                                             Class<E1> clazz1, Class<E2> clazz2) {
        List<E1> entities1 = new ArrayList<>();
        List<E2> entities2 = new ArrayList<>();

        recs.forEachRemaining(rec -> {
            if (getTableIndex(rec) == 0) {
                entities1.add(parse(rec, clazz1));
            } else {
                entities2.add(parse(rec, clazz2));
            }
        });

        return Tuple2.tuple(entities1, entities2);
    }

    /**
     * Same split as {@link #splitLeftRight(Iterator, Class, Class)} but keeps the raw records
     * instead of parsing them.
     *
     * @param recs records to consume; fully exhausted by this call
     * @return tuple of (records with table index 0, records with any other table index)
     */
    protected Tuple2<Collection<YTreeMapNode>, Collection<YTreeMapNode>> splitLeftRightRaw(IteratorF<YTreeMapNode> recs) {
        List<YTreeMapNode> entities1 = new ArrayList<>();
        List<YTreeMapNode> entities2 = new ArrayList<>();

        recs.forEachRemaining(rec -> {
            if (getTableIndex(rec) == 0) {
                entities1.add(rec);
            } else {
                entities2.add(rec);
            }
        });

        return Tuple2.tuple(entities1, entities2);
    }

    /**
     * Applies {@link #checkOverlimitExc(Iterator, int)} and, if the limit is not exceeded, splits the
     * records with {@link #splitLeftRight(Iterator, Class, Class)}.
     *
     * @param recs      records to consume
     * @param clazz1    entity class for records with table index {@code 0}
     * @param clazz2    entity class for records with any other table index
     * @param recsLimit maximum number of records that may be processed
     * @return tuple of parsed entities split by table index
     * @throws RecsOverlimitException if {@code recs} holds more than {@code recsLimit} records
     */
    protected <E1, E2> Tuple2<Collection<E1>, Collection<E2>> splitLeftRightCheckOverlimit(Iterator<YTreeMapNode> recs,
                                                                                           Class<E1> clazz1,
                                                                                           Class<E2> clazz2,
                                                                                           int recsLimit) throws RecsOverlimitException {
        return splitLeftRight(checkOverlimitExc(recs, recsLimit).iterator(), clazz1, clazz2);
    }

    /**
     * Materializes the records if there are at most {@code recsLimit} of them.
     *
     * <p>Exactly {@code recsLimit} records is within the limit: the exception is raised only when the
     * iterator still has records after the first {@code recsLimit} were fetched, which matches the
     * "more than" wording of the exception message.
     *
     * @param recs      records to consume; on success it is exhausted, on failure the unread tail is
     *                  handed over to the exception
     * @param recsLimit maximum number of records allowed
     * @return all records from {@code recs} as a list
     * @throws RecsOverlimitException if {@code recs} holds more than {@code recsLimit} records
     */
    public List<YTreeMapNode> checkOverlimitExc(Iterator<YTreeMapNode> recs, int recsLimit) throws RecsOverlimitException {
        List<YTreeMapNode> firstN = fetchFirst(recs, recsLimit);
        // A shorter list means the input ran out; a full list is overlimit only if something is left.
        if (firstN.size() == recsLimit && recs.hasNext()) {
            throw new RecsOverlimitException(recsLimit, firstN, recs);
        }
        return firstN;
    }

    /**
     * Non-throwing variant of {@link #checkOverlimitExc(Iterator, int)}.
     *
     * @param entries records to inspect
     * @param limit   maximum number of records allowed
     * @return tuple of (iterator over all records, overlimit flag). When the flag is {@code false}
     *         the iterator walks an in-memory list; when it is {@code true} the iterator yields the
     *         already fetched records followed by the unread tail of {@code entries}
     */
    public Tuple2<Iterator<YTreeMapNode>, Boolean> checkOverlimit(IteratorF<YTreeMapNode> entries, int limit) {
        List<YTreeMapNode> firstN = fetchFirst(entries, limit);
        if (firstN.size() == limit && entries.hasNext()) {
            return Tuple2.tuple(concat(firstN, entries), true);
        }
        return Tuple2.tuple(firstN.iterator(), false);
    }

    /** Reads up to {@code count} leading records from the iterator into a list. */
    private static List<YTreeMapNode> fetchFirst(Iterator<YTreeMapNode> recs, int count) {
        return IteratorUtils.stream(recs).limit(count).collect(toList());
    }

    /** Builds an iterator over the fetched records followed by the not yet consumed ones. */
    private static Iterator<YTreeMapNode> concat(List<YTreeMapNode> fetched, Iterator<YTreeMapNode> rest) {
        return Stream.concat(fetched.stream(), IteratorUtils.stream(rest)).iterator();
    }

    /** @return {@link #YSON_WITH_TABLE_INDEX} */
    @Override
    public YTableEntryType<YTreeMapNode> inputType() {
        return YSON_WITH_TABLE_INDEX;
    }

    /** @return {@link #YSON_WITH_TABLE_INDEX} */
    @Override
    public YTableEntryType<YTreeMapNode> outputType() {
        return YSON_WITH_TABLE_INDEX;
    }

    /**
     * Signals that a record group exceeded the allowed size. Keeps both the records that were already
     * fetched and the iterator over the remaining ones, so the caller can still process everything.
     *
     * <p>{@link #countTotalRecs()}, {@link #allRecsIterator()} and {@link #forEachRec(Consumer)} all
     * read from the same underlying {@code rest} iterator, so records consumed through one of them
     * are no longer available to the others.
     */
    public static class RecsOverlimitException extends Exception {
        private final List<YTreeMapNode> fetched;
        private final Iterator<YTreeMapNode> rest;
        private final int limit;

        /**
         * @param limit   the limit that was exceeded; used in the message and in {@link #countTotalRecs()}
         * @param fetched records already read from the source
         * @param rest    iterator over the records that were not read yet
         */
        public RecsOverlimitException(int limit, List<YTreeMapNode> fetched, Iterator<YTreeMapNode> rest) {
            super(String.format("Reducer received more than %d recs", limit));
            this.limit = limit;
            this.fetched = fetched;
            this.rest = rest;
        }

        /**
         * Counts the records as {@code limit} plus the number of records left in the tail.
         * Counting exhausts the tail iterator.
         *
         * @return {@code limit} plus the number of remaining records
         */
        public long countTotalRecs() {
            return limit + IteratorUtils.stream(rest).count();
        }

        /**
         * @return iterator over the fetched records followed by whatever is still left in the tail
         */
        public Iterator<YTreeMapNode> allRecsIterator() {
            return Stream.concat(fetched.stream(), IteratorUtils.stream(rest)).iterator();
        }

        /**
         * Feeds every record from {@link #allRecsIterator()} to the callback.
         *
         * @param callback action to run for each record
         */
        public void forEachRec(Consumer<YTreeMapNode> callback) {
            allRecsIterator().forEachRemaining(callback);
        }
    }
}
