package ru.yandex.stockpile.server.shard.merge;

import java.time.Instant;
import java.util.ArrayList;
import java.util.List;

import javax.annotation.Nullable;

import ru.yandex.solomon.model.protobuf.MetricType;
import ru.yandex.solomon.model.timeseries.AggrGraphDataListIterator;
import ru.yandex.solomon.model.timeseries.ConcatAggrGraphDataIterator;
import ru.yandex.solomon.model.timeseries.MergingAggrGraphDataIterator;

/**
 * Merges several {@link Iterator}s into a single {@link Iterator}.
 *
 * <p>Each call to {@link #next()} builds a time window. The window starts at the smallest first
 * timestamp among the pending items of all sources. It is then widened until no pending item that
 * starts inside the window ends after it.
 *
 * <p>All pending items that start inside the window are taken and combined as follows:
 * <ul>
 *   <li>several items from one source are concatenated ({@link ConcatAggrGraphDataIterator});</li>
 *   <li>items from different sources are combined with
 *       {@link MergingAggrGraphDataIterator#ofCombineAggregate};</li>
 *   <li>a single item is returned as is.</li>
 * </ul>
 *
 * <p>NOTE(review): the window logic assumes every source yields items ordered by time. Confirm
 * this with the {@link Iterator} implementations.
 *
 * <p>Instances hold mutable, unsynchronized state: the per-source prefetch buffers and the last
 * emitted timestamp.
 *
 * @author Vladimir Gordiychuk
 */
public class MergeIterator implements Iterator {
    /** Cursors over the sources that yielded at least one item at construction time. */
    private final Cursor[] cursors;
    /** Last timestamp of the most recently returned frame; only meaningful when {@link #hasLastTs} is set. */
    private long lastTsMillis;
    /**
     * Whether a frame has been returned yet. Tracked explicitly so that a frame ending at epoch 0
     * still takes part in the ordering check.
     */
    private boolean hasLastTs;

    private MergeIterator(List<? extends Iterator> iterators) {
        List<Cursor> nonEmpty = new ArrayList<>(iterators.size());
        for (Iterator iterator : iterators) {
            Cursor cursor = new Cursor(iterator);
            // Prefetch the first item; sources that are already exhausted are dropped.
            if (cursor.next()) {
                nonEmpty.add(cursor);
            }
        }
        this.cursors = nonEmpty.toArray(new Cursor[0]);
    }

    /**
     * Creates an iterator over the union of {@code iterators}.
     *
     * @param iterators sources to merge
     * @return {@link EmptyIterator#INSTANCE} when there are no sources; the single source itself
     *     when there is exactly one; otherwise a new {@code MergeIterator}
     */
    public static Iterator of(List<? extends Iterator> iterators) {
        if (iterators.isEmpty()) {
            return EmptyIterator.INSTANCE;
        } else if (iterators.size() == 1) {
            return iterators.get(0);
        }

        return new MergeIterator(iterators);
    }

    /**
     * Returns the first type other than {@code METRIC_TYPE_UNSPECIFIED} reported by the non-empty
     * sources, in source order.
     */
    @Override
    public MetricType type() {
        for (var cursor : cursors) {
            MetricType type = cursor.it.type();
            if (type != MetricType.METRIC_TYPE_UNSPECIFIED) {
                return type;
            }
        }
        return MetricType.METRIC_TYPE_UNSPECIFIED;
    }

    /** Returns the bitwise OR of the column set masks of all non-empty sources. */
    @Override
    public int columnSetMask() {
        int mask = 0;
        for (var cursor : cursors) {
            mask |= cursor.it.columnSetMask();
        }
        return mask;
    }

    /** Returns the sum of {@code elapsedRecords()} over all non-empty sources. */
    @Override
    public int elapsedRecords() {
        int records = 0;
        for (var cursor : cursors) {
            records += cursor.it.elapsedRecords();
        }
        return records;
    }

    /**
     * Returns the next merged frame, or {@code null} when every source is exhausted.
     *
     * @throws IllegalStateException if the produced frame starts at or before the last timestamp
     *     of the previously returned frame
     */
    @Nullable
    @Override
    public Item next() {
        long fromTsMillis = prefetchFirst();
        if (fromTsMillis == Long.MAX_VALUE) {
            // no source has a pending item
            return null;
        }

        long toTsMillis = prefetchLast(fromTsMillis);
        var result = poll(fromTsMillis, toTsMillis);

        if (result != null) {
            if (hasLastTs && result.getFirstTsMillis() <= lastTsMillis) {
                throw new IllegalStateException("First ts at frame  " + result
                    + " <= then last ts from prev frame " + Instant.ofEpochMilli(lastTsMillis));
            }

            lastTsMillis = result.getLastTsMillis();
            hasLastTs = true;
        }

        return result;
    }

    /**
     * Returns the smallest first timestamp among the pending items of all cursors. Each cursor with
     * an empty buffer first tries to fetch one item. Returns {@code Long.MAX_VALUE} if no cursor has
     * a pending item.
     */
    private long prefetchFirst() {
        long first = Long.MAX_VALUE;
        for (Cursor cursor : cursors) {
            first = cursor.prefetchFirst(first);
        }
        return first;
    }

    /**
     * Widens the window end until it stops changing. On each pass every cursor extends the end over
     * its items that start inside the window, fetching more items from its source when needed.
     *
     * @param windowLastTsMillis initial window end
     * @return the stabilized window end
     */
    private long prefetchLast(long windowLastTsMillis) {
        long current = windowLastTsMillis;
        long previous;
        do {
            previous = current;
            for (Cursor cursor : cursors) {
                current = cursor.prefetchLast(current);
            }
        } while (previous != current);
        return current;
    }

    /**
     * Removes from every cursor the buffered items that start at or before {@code lastTsMillis} and
     * combines them into one item.
     *
     * <p>A single contributing item is returned unchanged. Several items are wrapped into an
     * {@link ItemIterator} bounded by {@code [firstTsMillis, lastTsMillis]}, whose elapsed bytes are
     * the sum over the inputs.
     *
     * @return the combined item, or {@code null} if no cursor contributed anything
     */
    @Nullable
    public Item poll(long firstTsMillis, long lastTsMillis) {
        List<Item> items = new ArrayList<>(cursors.length);
        for (var cursor : cursors) {
            var item = cursor.poll(firstTsMillis, lastTsMillis);
            if (item != null) {
                items.add(item);
            }
        }

        if (items.isEmpty()) {
            return null;
        } else if (items.size() == 1) {
            // no merge at all
            return items.get(0);
        }

        List<AggrGraphDataListIterator> iterators = new ArrayList<>(items.size());
        int elapsedBytes = 0;
        for (var item : items) {
            iterators.add(item.iterator());
            elapsedBytes += item.getElapsedBytes();
        }
        var merge = MergingAggrGraphDataIterator.ofCombineAggregate(iterators);
        return new ItemIterator(merge, firstTsMillis, lastTsMillis, elapsedBytes);
    }

    /** Read-ahead buffer over a single source iterator. */
    private static class Cursor {
        private final Iterator it;
        /** Items read from {@link #it} but not yet polled, in the order the source produced them. */
        private final List<Item> prefetch = new ArrayList<>(1);

        Cursor(Iterator it) {
            this.it = it;
        }

        /**
         * Removes the buffered items that start at or before {@code lastTsMillis}. Several such
         * items are concatenated into one {@link ItemIterator} bounded by the given window.
         *
         * @return the polled item, or {@code null} if no buffered item starts inside the window
         */
        @Nullable
        public Item poll(long firstTsMillis, long lastTsMillis) {
            if (prefetch.isEmpty()) {
                return null;
            }

            // subList is a view: clearing it in finally removes the polled items from the buffer
            var items = prefetch.subList(0, indexOfLastItem(lastTsMillis));
            try {
                if (items.isEmpty()) {
                    return null;
                } else if (items.size() == 1) {
                    return items.get(0);
                } else {
                    List<AggrGraphDataListIterator> iterators = new ArrayList<>(items.size());
                    int elapsedBytes = 0;
                    for (var item : items) {
                        iterators.add(item.iterator());
                        elapsedBytes += item.getElapsedBytes();
                    }

                    var concat = ConcatAggrGraphDataIterator.of(iterators);
                    return new ItemIterator(concat, firstTsMillis, lastTsMillis, elapsedBytes);
                }
            } finally {
                items.clear();
            }
        }

        /**
         * Returns one past the index of the last buffered item whose first timestamp is at or before
         * {@code lastTsMillis}, or 0 if there is no such item.
         */
        private int indexOfLastItem(long lastTsMillis) {
            for (int index = prefetch.size() - 1; index >= 0; index--) {
                var item = prefetch.get(index);
                if (item.getFirstTsMillis() <= lastTsMillis) {
                    return index + 1;
                }
            }

            return 0;
        }

        /**
         * Returns the minimum of {@code firstTimeMillis} and the first timestamp of the oldest
         * buffered item. Fetches one item if the buffer is empty. Returns {@code firstTimeMillis}
         * unchanged when the source is exhausted.
         */
        long prefetchFirst(long firstTimeMillis) {
            if (prefetch.isEmpty() && !next()) {
                return firstTimeMillis;
            }

            return Math.min(firstTimeMillis, prefetch.get(0).getFirstTsMillis());
        }

        /**
         * Extends {@code lastTimeMillis} over every buffered item that starts at or before it.
         *
         * <p>When all buffered items overlap, further items are read from the source until one
         * starts after the window or the source is exhausted. An item read that starts after the
         * window stays buffered for a later call.
         *
         * @return the extended window end
         */
        long prefetchLast(long lastTimeMillis) {
            if (prefetch.isEmpty() && !next()) {
                return lastTimeMillis;
            }

            for (var item : prefetch) {
                if (item.getFirstTsMillis() <= lastTimeMillis) {
                    lastTimeMillis = Math.max(lastTimeMillis, item.getLastTsMillis());
                } else {
                    return lastTimeMillis;
                }
            }

            while (next()) {
                Item last = prefetch.get(prefetch.size() - 1);
                if (last.getFirstTsMillis() > lastTimeMillis) {
                    return lastTimeMillis;
                }

                lastTimeMillis = Math.max(lastTimeMillis, last.getLastTsMillis());
            }

            return lastTimeMillis;
        }

        /** Reads one item from the source into the buffer; returns {@code false} if the source is exhausted. */
        private boolean next() {
            var item = it.next();
            if (item != null) {
                prefetch.add(item);
                return true;
            }
            return false;
        }
    }
}
